1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
/// Errors that can be returned while generating a Metal inference project.
#[derive(Debug, thiserror::Error)]
pub enum MetalCodegenError {
    /// The input graph carries no model configuration (`graph.config` is `None`),
    /// so there is nothing to size the generated code from.
    #[error("graph has no model config")]
    MissingConfig,

    /// A filesystem operation (creating a directory or writing a generated file)
    /// failed. Wraps the underlying [`std::io::Error`]; `?` converts into it
    /// automatically via `#[from]`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// A formatting operation failed. Wraps [`std::fmt::Error`]; `?` converts
    /// into it automatically via `#[from]`.
    // NOTE(review): presumably produced by the `write!`-based source generators
    // defined elsewhere in this file — confirm against their signatures.
    #[error("format error: {0}")]
    Fmt(#[from] std::fmt::Error),
}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Turns an arbitrary model name into a string usable as a Cargo package name.
///
/// The name is lower-cased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing dashes are stripped.
///
/// Two inputs would otherwise yield a name that Cargo refuses to parse, making
/// the generated project unbuildable, so they are patched up here:
///
/// * nothing usable remains (empty or all-punctuation input) — the fallback
///   name `model` is returned;
/// * the result starts with a digit (e.g. `7B-chat`) — it is prefixed with
///   `model-`.
///
/// Every other input produces exactly the lower-cased, dash-separated form.
fn sanitize_name(name: &str) -> String {
    // Used both as the stand-alone fallback and as the prefix for names that
    // would otherwise begin with a digit.
    const FALLBACK: &str = "model";

    let lowered = name.to_lowercase();
    let dashed = lowered.replace(|c: char| !c.is_alphanumeric() && c != '-', "-");
    let trimmed = dashed.trim_matches('-');

    // After trimming, the first character (if any) is alphanumeric, so only the
    // "empty" and "numeric first character" cases need special handling.
    match trimmed.chars().next() {
        None => FALLBACK.to_string(),
        Some(first) if first.is_numeric() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
125 let attn_scores_size = config.max_seq_len.min(4096);
130 r#"//
131// Auto-generated by ForgeLLM Metal codegen.
132// Metal Shading Language compute kernels for transformer inference.
133//
134// Optimized with simdgroup cooperative reductions, shared memory vector
135// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
136// lane, and fast:: math intrinsics for Apple Silicon throughput.
137//
138
139#include <metal_stdlib>
140#include <metal_simdgroup_matrix>
141using namespace metal;
142
143// ── Constants ───────────────────────────────────────────────────────────
144// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
145// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
146// per shared memory load vs 4-row, improving ILP and reducing launches.
147constant constexpr uint SIMDGROUPS_PER_TG = 8;
148constant constexpr uint ROWS_PER_SIMDGROUP = 8;
149constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
150
151// ── matmul_vec ──────────────────────────────────────────────────────────
152// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
153// Uses simdgroup cooperative dot product with shared memory vector caching
154// and float4 vectorized loads. Each simdgroup processes 8 rows for better
155// shared memory reuse (8x vector reuse per load) and instruction-level
156// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
157kernel void matmul_vec(
158 device const float* matrix [[buffer(0)]],
159 device const float* vector [[buffer(1)]],
160 device float* output [[buffer(2)]],
161 constant uint& rows [[buffer(3)]],
162 constant uint& cols [[buffer(4)]],
163 uint tgid [[threadgroup_position_in_grid]],
164 uint tid [[thread_index_in_threadgroup]],
165 uint simd_lane [[thread_index_in_simdgroup]],
166 uint simd_id [[simdgroup_index_in_threadgroup]])
167{
168 // Cooperatively load vector into threadgroup shared memory
169 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
170 for (uint i = tid; i < cols; i += 256) {
171 vec_tile[i] = vector[i];
172 }
173 threadgroup_barrier(mem_flags::mem_threadgroup);
174
175 // Each simdgroup handles 8 consecutive rows
176 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
177 if (row_base >= rows) return;
178
179 uint base0 = row_base * cols;
180 uint base1 = (row_base + 1) * cols;
181 uint base2 = (row_base + 2) * cols;
182 uint base3 = (row_base + 3) * cols;
183 uint base4 = (row_base + 4) * cols;
184 uint base5 = (row_base + 5) * cols;
185 uint base6 = (row_base + 6) * cols;
186 uint base7 = (row_base + 7) * cols;
187
188 // float4 vectorized accumulation across 8 rows
189 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
190 float4 sum4_0 = float4(0.0f);
191 float4 sum4_1 = float4(0.0f);
192 float4 sum4_2 = float4(0.0f);
193 float4 sum4_3 = float4(0.0f);
194 float4 sum4_4 = float4(0.0f);
195 float4 sum4_5 = float4(0.0f);
196 float4 sum4_6 = float4(0.0f);
197 float4 sum4_7 = float4(0.0f);
198
199 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
200 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
201 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
202 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
203 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
204 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
205 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
206 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
207 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
208 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
209 }
210
211 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
212 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
213 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
214 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
215 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
216 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
217 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
218 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
219
220 // Handle remaining elements (cols not divisible by 128)
221 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
222 float vv = vec_tile[j];
223 sum0 += matrix[base0 + j] * vv;
224 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
225 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
226 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
227 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
228 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
229 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
230 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
231 }
232
233 // Simdgroup hardware warp-level reduction
234 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
235 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
236 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
237 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
238
239 // Only first lane writes the results
240 if (simd_lane == 0) {
241 if (row_base < rows) output[row_base] = sum0;
242 if (row_base + 1 < rows) output[row_base + 1] = sum1;
243 if (row_base + 2 < rows) output[row_base + 2] = sum2;
244 if (row_base + 3 < rows) output[row_base + 3] = sum3;
245 if (row_base + 4 < rows) output[row_base + 4] = sum4;
246 if (row_base + 5 < rows) output[row_base + 5] = sum5;
247 if (row_base + 6 < rows) output[row_base + 6] = sum6;
248 if (row_base + 7 < rows) output[row_base + 7] = sum7;
249 }
250}
251
252// ── rms_norm ────────────────────────────────────────────────────────────
253// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
254// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
255// via shared memory for minimal synchronization overhead.
256kernel void rms_norm(
257 device const float* input [[buffer(0)]],
258 device const float* weight [[buffer(1)]],
259 device float* output [[buffer(2)]],
260 constant uint& n [[buffer(3)]],
261 constant float& eps [[buffer(4)]],
262 uint tid [[thread_index_in_threadgroup]])
263{
264 // Each thread accumulates partial sum-of-squares
265 float sum_sq = 0.0f;
266 for (uint i = tid; i < n; i += 256) {
267 float v = input[i];
268 sum_sq += v * v;
269 }
270
271 // Simdgroup-level reduction (hardware warp sum)
272 sum_sq = simd_sum(sum_sq);
273
274 // Cross-simdgroup reduction via shared memory
275 threadgroup float shared[8];
276 uint simd_id = tid / 32;
277 uint simd_lane = tid % 32;
278 if (simd_lane == 0) {
279 shared[simd_id] = sum_sq;
280 }
281 threadgroup_barrier(mem_flags::mem_threadgroup);
282
283 // First thread computes final inverse RMS
284 if (tid == 0) {
285 float total = 0.0f;
286 for (uint i = 0; i < 8; i++) {
287 total += shared[i];
288 }
289 shared[0] = fast::rsqrt(total / float(n) + eps);
290 }
291 threadgroup_barrier(mem_flags::mem_threadgroup);
292
293 float inv_rms = shared[0];
294
295 // Normalize
296 for (uint i = tid; i < n; i += 256) {
297 output[i] = input[i] * inv_rms * weight[i];
298 }
299}
300
301// ── rope ────────────────────────────────────────────────────────────────
302// Rotary Position Embedding applied in-place.
303// Each thread handles one (head, pair) combination.
304kernel void rope(
305 device float* data [[buffer(0)]],
306 constant uint& num_heads [[buffer(1)]],
307 constant uint& head_dim [[buffer(2)]],
308 constant uint& pos [[buffer(3)]],
309 constant float& theta [[buffer(4)]],
310 uint id [[thread_position_in_grid]])
311{
312 uint half_dim = head_dim / 2;
313 uint total_pairs = num_heads * half_dim;
314 if (id >= total_pairs) return;
315
316 uint h = id / half_dim;
317 uint i = id % half_dim;
318 uint off = h * head_dim;
319
320 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
321 float angle = float(pos) * freq;
322 float c = cos(angle);
323 float s = sin(angle);
324
325 float x0 = data[off + 2 * i];
326 float x1 = data[off + 2 * i + 1];
327 data[off + 2 * i] = x0 * c - x1 * s;
328 data[off + 2 * i + 1] = x0 * s + x1 * c;
329}
330
331// ── softmax ─────────────────────────────────────────────────────────────
332// Numerically stable softmax over a 1-D array.
333// Single-threadgroup kernel with cooperative reduction.
334kernel void softmax(
335 device float* data [[buffer(0)]],
336 constant uint& n [[buffer(1)]],
337 uint tid [[thread_index_in_threadgroup]],
338 uint tg_size [[threads_per_threadgroup]])
339{
340 threadgroup float shared_val[256];
341
342 // Pass 1: find max
343 float local_max = -INFINITY;
344 for (uint i = tid; i < n; i += tg_size) {
345 local_max = max(local_max, data[i]);
346 }
347 shared_val[tid] = local_max;
348 threadgroup_barrier(mem_flags::mem_threadgroup);
349
350 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
351 if (tid < stride) {
352 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
353 }
354 threadgroup_barrier(mem_flags::mem_threadgroup);
355 }
356 float max_val = shared_val[0];
357 threadgroup_barrier(mem_flags::mem_threadgroup);
358
359 // Pass 2: exp and sum
360 float local_sum = 0.0f;
361 for (uint i = tid; i < n; i += tg_size) {
362 float e = fast::exp(data[i] - max_val);
363 data[i] = e;
364 local_sum += e;
365 }
366 shared_val[tid] = local_sum;
367 threadgroup_barrier(mem_flags::mem_threadgroup);
368
369 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
370 if (tid < stride) {
371 shared_val[tid] += shared_val[tid + stride];
372 }
373 threadgroup_barrier(mem_flags::mem_threadgroup);
374 }
375 float inv_sum = 1.0f / shared_val[0];
376 threadgroup_barrier(mem_flags::mem_threadgroup);
377
378 // Pass 3: normalize
379 for (uint i = tid; i < n; i += tg_size) {
380 data[i] *= inv_sum;
381 }
382}
383
384// ── silu_mul ────────────────────────────────────────────────────────────
385// Fused SiLU activation * element-wise multiply:
386// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
387kernel void silu_mul(
388 device const float* gate [[buffer(0)]],
389 device const float* up [[buffer(1)]],
390 device float* output [[buffer(2)]],
391 constant uint& n [[buffer(3)]],
392 uint id [[thread_position_in_grid]])
393{
394 if (id >= n) return;
395 float g = gate[id];
396 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
397}
398
399// ── silu_mul_fused ─────────────────────────────────────────────────────
400// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
401// gate = gate_up[0..n], up = gate_up[n..2*n]
402// output[i] = silu(gate_up[i]) * gate_up[n + i]
403kernel void silu_mul_fused(
404 device const float* gate_up [[buffer(0)]],
405 device float* output [[buffer(1)]],
406 constant uint& n [[buffer(2)]],
407 uint id [[thread_position_in_grid]])
408{
409 if (id >= n) return;
410 float g = gate_up[id];
411 float u = gate_up[n + id];
412 output[id] = (g / (1.0f + fast::exp(-g))) * u;
413}
414
415// ── gelu_mul_fused ─────────────────────────────────────────────────────
416// Fused GELU-tanh-approx × multiply. Same layout as silu_mul_fused but
417// uses the tanh-approximate GELU (matches HF's `gelu_pytorch_tanh`, which
418// is what Gemma-1 uses):
419// GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
420// output[i] = GELU(gate_up[i]) * gate_up[n + i]
421kernel void gelu_mul_fused(
422 device const float* gate_up [[buffer(0)]],
423 device float* output [[buffer(1)]],
424 constant uint& n [[buffer(2)]],
425 uint id [[thread_position_in_grid]])
426{
427 if (id >= n) return;
428 constexpr float SQRT_2_OVER_PI = 0.7978845608f;
429 float g = gate_up[id];
430 float u = gate_up[n + id];
431 float inner = SQRT_2_OVER_PI * (g + 0.044715f * g * g * g);
432 float gelu = 0.5f * g * (1.0f + precise::tanh(inner));
433 output[id] = gelu * u;
434}
435
436// ── elementwise_add ─────────────────────────────────────────────────────
437// Residual connection: output[i] = a[i] + b[i]
438kernel void elementwise_add(
439 device const float* a [[buffer(0)]],
440 device const float* b [[buffer(1)]],
441 device float* output [[buffer(2)]],
442 constant uint& n [[buffer(3)]],
443 uint id [[thread_position_in_grid]])
444{
445 if (id >= n) return;
446 output[id] = a[id] + b[id];
447}
448
449// ── copy_buffer ─────────────────────────────────────────────────────────
450// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
451// transitions. Used for KV cache updates and embedding lookup.
452kernel void copy_buffer(
453 device const float* src [[buffer(0)]],
454 device float* dst [[buffer(1)]],
455 constant uint& count [[buffer(2)]],
456 uint id [[thread_position_in_grid]])
457{
458 if (id < count) dst[id] = src[id];
459}
460
461// ── copy_offset ─────────────────────────────────────────────────────────
462// Copy with source offset (in floats). Used for embedding table lookup
463// where we need to copy a specific row from a large table.
464kernel void copy_offset(
465 device const float* src [[buffer(0)]],
466 device float* dst [[buffer(1)]],
467 constant uint& src_offset [[buffer(2)]], // in floats
468 constant uint& count [[buffer(3)]],
469 uint id [[thread_position_in_grid]])
470{
471 if (id < count) dst[id] = src[src_offset + id];
472}
473
474// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
475// Copy f32 elements from src into a half-typed dst, converting to half on
476// write. Used by the single-token decode path to append a new K/V vector
477// to the f16 KV cache. Byte offsets into src/dst are supplied via the
478// Metal buffer binding offsets — no in-kernel offsets needed.
479kernel void copy_f32_to_f16_offset(
480 device const float* src [[buffer(0)]],
481 device half* dst [[buffer(1)]],
482 constant uint& count [[buffer(2)]],
483 uint id [[thread_position_in_grid]])
484{
485 if (id < count) dst[id] = half(src[id]);
486}
487
488// ── add_inplace ─────────────────────────────────────────────────────────
489// In-place residual connection: a[i] += b[i]
490// Avoids a separate blit copy for residual add, reducing encoder overhead.
491kernel void add_inplace(
492 device float* a [[buffer(0)]],
493 device const float* b [[buffer(1)]],
494 constant uint& n [[buffer(2)]],
495 uint id [[thread_position_in_grid]])
496{
497 if (id >= n) return;
498 a[id] += b[id];
499}
500
501// ── matmul_vec_q8 ─────────────────────────────────────────────────────
502// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
503// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
504// Operates directly on quantized weights to halve memory bandwidth vs f32,
505// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
506//
507// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
508// because int8->float conversion doubles register demand. Fully unrolled
509// inner loop with float4 vector loads from shared memory eliminates loop
510// overhead and enables better instruction scheduling.
511// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
512constant constexpr uint Q8_ROWS_PER_SG = 4;
513constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
514
515// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
516// doubles ALU work, so the same register budget applies).
517constant constexpr uint Q4_ROWS_PER_SG = 4;
518constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
519
520kernel void matmul_vec_q8(
521 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
522 device const float* vector [[buffer(1)]], // f32 input
523 device float* output [[buffer(2)]],
524 constant uint& rows [[buffer(3)]],
525 constant uint& cols [[buffer(4)]], // number of elements per row
526 uint tgid [[threadgroup_position_in_grid]],
527 uint tid [[thread_index_in_threadgroup]],
528 uint simd_lane [[thread_index_in_simdgroup]],
529 uint simd_id [[simdgroup_index_in_threadgroup]])
530{
531 // Load vector into shared memory
532 threadgroup float vec_tile[VEC_TILE_SIZE];
533 for (uint i = tid; i < cols; i += 256) {
534 vec_tile[i] = vector[i];
535 }
536 threadgroup_barrier(mem_flags::mem_threadgroup);
537
538 // Each simdgroup handles 4 consecutive rows (lower register pressure)
539 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
540 if (row_base >= rows) return;
541
542 // Q8_0: each block is 34 bytes for 32 elements
543 uint blocks_per_row = cols / 32;
544 uint row_bytes = blocks_per_row * 34;
545
546 // Pointers to each row's Q8_0 data
547 device const uchar* r0 = matrix + row_base * row_bytes;
548 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
549 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
550 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
551
552 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
553
554 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
555 uint bb = blk * 34;
556 uint vb = blk * 32;
557
558 // Prefetch all 4 scales
559 float sc0 = float(*(device const half*)(r0 + bb));
560 float sc1 = float(*(device const half*)(r1 + bb));
561 float sc2 = float(*(device const half*)(r2 + bb));
562 float sc3 = float(*(device const half*)(r3 + bb));
563
564 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
565 // Q8_0 block layout where the int8 data starts at offset +2 from a
566 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
567 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
568 // reduction in memory transactions. Metal's char16/packed_char16 are
569 // reserved types and packed_*int4 require >=4-byte alignment which
570 // this layout does not provide, so packed_short4 is the widest valid
571 // vectorized load.
572 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
573 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
574 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
575 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
576
577 // Load all 8 float4 vector values for this 32-element block from shared memory
578 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
579 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
580 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
581 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
582 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
583 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
584 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
585 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
586
587 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
588 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
589 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
590 short4 _s = short4(SHORT4); \
591 char2 _a = as_type<char2>(_s.x); \
592 char2 _b = as_type<char2>(_s.y); \
593 char2 _c = as_type<char2>(_s.z); \
594 char2 _d = as_type<char2>(_s.w); \
595 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
596 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
597 }
598
599 float4 f0, f1;
600 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
601
602 // Row 0: 4 short4 loads cover 32 int8 weights
603 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
604 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
605 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
606 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
607
608 // Row 1
609 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
610 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
611 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
612 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
613
614 // Row 2
615 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
616 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
617 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
618 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
619
620 // Row 3
621 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
622 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
623 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
624 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
625
626 #undef Q8_UNPACK8
627
628 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
629 }
630
631 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
632 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
633
634 if (simd_lane == 0) {
635 if (row_base < rows) output[row_base] = sum0;
636 if (row_base + 1 < rows) output[row_base + 1] = sum1;
637 if (row_base + 2 < rows) output[row_base + 2] = sum2;
638 if (row_base + 3 < rows) output[row_base + 3] = sum3;
639 }
640}
641
642// ── matmul_vec_q4 ─────────────────────────────────────────────────────
643// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
644// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
645// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
646// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
647//
648// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
649// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
650kernel void matmul_vec_q4(
651 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
652 device const float* vector [[buffer(1)]], // f32 input
653 device float* output [[buffer(2)]],
654 constant uint& rows [[buffer(3)]],
655 constant uint& cols [[buffer(4)]], // number of elements per row
656 uint tgid [[threadgroup_position_in_grid]],
657 uint tid [[thread_index_in_threadgroup]],
658 uint simd_lane [[thread_index_in_simdgroup]],
659 uint simd_id [[simdgroup_index_in_threadgroup]])
660{
661 // Load vector into shared memory
662 threadgroup float vec_tile[VEC_TILE_SIZE];
663 for (uint i = tid; i < cols; i += 256) {
664 vec_tile[i] = vector[i];
665 }
666 threadgroup_barrier(mem_flags::mem_threadgroup);
667
668 // Each simdgroup handles 4 consecutive rows
669 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
670 if (row_base >= rows) return;
671
672 // Q4_0: each block is 18 bytes for 32 elements
673 uint blocks_per_row = cols / 32;
674 uint row_bytes = blocks_per_row * 18;
675
676 // Pointers to each row's Q4_0 data
677 device const uchar* r0 = matrix + row_base * row_bytes;
678 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
679 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
680 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
681
682 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
683
684 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
685 uint bb = blk * 18;
686 uint vb = blk * 32;
687
688 // Prefetch all 4 scales
689 float sc0 = float(*(device const half*)(r0 + bb));
690 float sc1 = float(*(device const half*)(r1 + bb));
691 float sc2 = float(*(device const half*)(r2 + bb));
692 float sc3 = float(*(device const half*)(r3 + bb));
693
694 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
695 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
696 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
697 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
698 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
699
700 // Load 8 float4 vector values for 32 elements from shared memory
701 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
702 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
703 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
704 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
705 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
706 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
707 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
708 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
709 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
710
711 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
712 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
713 float bd0=0, bd1=0, bd2=0, bd3=0;
714 uchar4 b;
715
716 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
717 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
718 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
719 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
720 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
721 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
722 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
723 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
724 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
725 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
726 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
727 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
728 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
729 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
730 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
731 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
732 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
733
734 // Row 1
735 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
736 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
737 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
738 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
739 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
740 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
741 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
742 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
743 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
744 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
745 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
746 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
747 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
748 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
749 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
750 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
751
752 // Row 2
753 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
754 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
755 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
756 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
757 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
758 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
759 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
760 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
761 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
762 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
763 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
764 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
765 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
766 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
767 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
768 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
769
770 // Row 3
771 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
772 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
773 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
774 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
775 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
776 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
777 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
778 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
779 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
780 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
781 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
782 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
783 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
784 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
785 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
786 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
787
788 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
789 }
790
791 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
792 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
793
794 if (simd_lane == 0) {
795 if (row_base < rows) output[row_base] = sum0;
796 if (row_base + 1 < rows) output[row_base + 1] = sum1;
797 if (row_base + 2 < rows) output[row_base + 2] = sum2;
798 if (row_base + 3 < rows) output[row_base + 3] = sum3;
799 }
800}
801
// ── attention ───────────────────────────────────────────────────────────
// Single-query (decode step) attention. One threadgroup processes one query
// head; the loop strides below assume it runs 256 threads arranged as
// 8 simdgroups of 32 lanes.
//
// Phases:
//   1. scores[s] = (q · k[s]) * rsqrt(head_dim), one simdgroup per position.
//   2. Softmax over scores[0 .. seq_len) using simd_max / simd_sum, with a
//      small threadgroup array to combine the 8 simdgroup partials.
//   3. output = sum over s of scores[s] * v[s], one simdgroup per 4-wide
//      slice of head_dim, lanes striding over positions.
//
// Buffers:
//   q:       [num_heads * head_dim] float, current query
//   k_cache: half, indexed as [s][kv_head][d]
//   v_cache: half, same indexing as k_cache
//   output:  [num_heads * head_dim] float
//
// Requirements visible from the code:
//   - seq_len <= ATTN_SCORES_SIZE: scores[] is indexed by position without
//     a bounds check.
//   - num_heads >= num_kv_heads, so num_heads / num_kv_heads is nonzero.
//   - The vector paths reinterpret pointers as float4 / half4, so head_dim
//     is presumably expected to be a multiple of 4; a scalar loop handles
//     any leftover dimensions.
kernel void attention(
    device const float* q            [[buffer(0)]],
    device const half*  k_cache      [[buffer(1)]],
    device const half*  v_cache      [[buffer(2)]],
    device float*       output       [[buffer(3)]],
    constant uint&      seq_len      [[buffer(4)]],
    constant uint&      num_heads    [[buffer(5)]],
    constant uint&      num_kv_heads [[buffer(6)]],
    constant uint&      head_dim     [[buffer(7)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    // The whole threadgroup shares tgid, so this exit is uniform and cannot
    // strand threads at a later barrier.
    if (tgid >= num_heads) return;

    const uint head          = tgid;
    const uint kv_head       = head / (num_heads / num_kv_heads);
    const uint q_base        = head * head_dim;
    const uint kv_row_stride = num_kv_heads * head_dim;   // elements per cached position
    const uint kv_head_off   = kv_head * head_dim;
    const uint vec_chunks    = head_dim / 4;
    const uint tail_start    = vec_chunks * 4;

    threadgroup float scores[ATTN_SCORES_SIZE];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];

    // ── Phase 1: scaled Q·K^T ───────────────────────────────────────────
    // Simdgroup g handles positions g, g+8, g+16, ...; its 32 lanes split
    // head_dim in chunks of 4 and combine with simd_sum.
    for (uint pos = simd_id; pos < seq_len; pos += 8) {
        const uint k_base = pos * kv_row_stride + kv_head_off;
        float acc = 0.0;
        for (uint chunk = simd_lane; chunk < vec_chunks; chunk += 32) {
            const uint d = chunk * 4;
            const float4 qv = *(device const float4*)(q + q_base + d);
            const float4 kv = float4(*(device const half4*)(k_cache + k_base + d));
            acc += qv.x * kv.x + qv.y * kv.y + qv.z * kv.z + qv.w * kv.w;
        }
        // Scalar remainder when head_dim is not a multiple of 4.
        for (uint d = tail_start + simd_lane; d < head_dim; d += 32) {
            acc += q[q_base + d] * float(k_cache[k_base + d]);
        }
        acc = simd_sum(acc);
        if (simd_lane == 0) {
            scores[pos] = acc * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 2a: maximum score ─────────────────────────────────────────
    float lane_max = -INFINITY;
    for (uint pos = tid; pos < seq_len; pos += 256) {
        lane_max = max(lane_max, scores[pos]);
    }
    lane_max = simd_max(lane_max);
    if (simd_lane == 0) shared_max[simd_id] = lane_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint g = 1; g < 8; g++) m = max(m, shared_max[g]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float max_val = shared_max[0];

    // ── Phase 2b: exponentiate and sum ──────────────────────────────────
    float lane_sum = 0.0;
    for (uint pos = tid; pos < seq_len; pos += 256) {
        const float e = fast::exp(scores[pos] - max_val);
        scores[pos] = e;
        lane_sum += e;
    }
    lane_sum = simd_sum(lane_sum);
    if (simd_lane == 0) shared_sum[simd_id] = lane_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint g = 0; g < 8; g++) total += shared_sum[g];
        // Slot 0 is reused to publish the reciprocal of the total.
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float inv_sum = shared_sum[0];

    // ── Phase 2c: normalize ─────────────────────────────────────────────
    for (uint pos = tid; pos < seq_len; pos += 256) {
        scores[pos] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 3: weighted sum of V ──────────────────────────────────────
    // Simdgroup g owns the 4-wide slices g, g+8, ... of head_dim. Inside a
    // simdgroup, lane l accumulates positions l, l+32, l+64, ... and the
    // per-component partials are combined with simd_sum; lane 0 stores.
    for (uint chunk = simd_id; chunk < vec_chunks; chunk += 8) {
        const uint d = chunk * 4;
        const uint v_base = kv_head_off + d;
        float4 acc4 = float4(0.0);
        for (uint pos = simd_lane; pos < seq_len; pos += 32) {
            const half4 vv = *(device const half4*)(v_cache + pos * kv_row_stride + v_base);
            acc4 += scores[pos] * float4(vv);
        }
        acc4.x = simd_sum(acc4.x);
        acc4.y = simd_sum(acc4.y);
        acc4.z = simd_sum(acc4.z);
        acc4.w = simd_sum(acc4.w);
        if (simd_lane == 0) {
            *(device float4*)(output + q_base + d) = acc4;
        }
    }
    // Scalar remainder for dimensions past the last full group of 4, using
    // the same one-simdgroup-per-dimension layout.
    for (uint d = tail_start + simd_id; d < head_dim; d += 8) {
        const uint v_base = kv_head_off + d;
        float acc = 0.0;
        for (uint pos = simd_lane; pos < seq_len; pos += 32) {
            acc += scores[pos] * float(v_cache[pos * kv_row_stride + v_base]);
        }
        acc = simd_sum(acc);
        if (simd_lane == 0) {
            output[q_base + d] = acc;
        }
    }
}
951
952// ── Batched prefill kernels ────────────────────────────────────────────
953// These kernels process M input vectors against the same weight matrix
954// in a single dispatch, converting mat-vec into mat-mat for better GPU
955// utilization during prompt prefill.
956
// ── rms_norm_batch ─────────────────────────────────────────────────────
// RMS normalization applied independently to each row of an [M, n] batch:
//   output[t, i] = input[t, i] * rsqrt(mean(input[t, :]^2) + eps) * weight[i]
// Grid: one threadgroup per token (threadgroups with tgid >= num_tokens
// exit immediately). The strides assume 256 threads per threadgroup, i.e.
// 8 simdgroups of 32 lanes.
kernel void rms_norm_batch(
    device const float* input      [[buffer(0)]],  // [M, n]
    device const float* weight     [[buffer(1)]],  // [n]
    device float*       output     [[buffer(2)]],  // [M, n]
    constant uint&      n          [[buffer(3)]],
    constant float&     eps        [[buffer(4)]],
    constant uint&      num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]])
{
    if (tgid >= num_tokens) return;

    const uint row_off = tgid * n;

    // Per-thread partial sum of squares over a 256-strided slice of the row.
    float partial = 0.0f;
    for (uint i = tid; i < n; i += 256) {
        const float x = input[row_off + i];
        partial += x * x;
    }

    // First level of reduction: within each simdgroup.
    partial = simd_sum(partial);

    // Second level: lane 0 of every simdgroup publishes its total, then
    // thread 0 folds the 8 values and replaces slot 0 with 1 / rms.
    threadgroup float shared[8];
    const uint group = tid / 32;
    const uint lane  = tid % 32;
    if (lane == 0) {
        shared[group] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float total = 0.0f;
        for (uint g = 0; g < 8; g++) {
            total += shared[g];
        }
        shared[0] = fast::rsqrt(total / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float inv_rms = shared[0];

    // Scale and apply the learned per-channel weight.
    for (uint i = tid; i < n; i += 256) {
        output[row_off + i] = input[row_off + i] * inv_rms * weight[i];
    }
}
1006
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place rotary position embedding for a batch of tokens, each with its
// own position. data is laid out as [M, num_heads * head_dim] and
// positions as [M]. One thread rotates one adjacent pair
// (data[2i], data[2i+1]) of one head of one token, so the grid must cover
// num_tokens * num_heads * (head_dim / 2) threads; extra threads exit.
kernel void rope_batch(
    device float*      data       [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint&     num_heads  [[buffer(1)]],
    constant uint&     head_dim   [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float&    theta      [[buffer(4)]],
    constant uint&     num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint half_dim        = head_dim / 2;
    const uint pairs_per_token = num_heads * half_dim;
    // Also guards the divisions below: a zero pair count makes the bound 0.
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token  = id / pairs_per_token;
    const uint within = id % pairs_per_token;
    const uint h      = within / half_dim;
    const uint i      = within % half_dim;

    // Index of the first element of the pair within data.
    const uint first = token * (num_heads * head_dim) + h * head_dim + 2 * i;

    // Rotation angle: position * theta^(-2i / head_dim).
    const float freq  = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    const float angle = float(positions[token]) * freq;
    const float c = cos(angle);
    const float s = sin(angle);

    const float x0 = data[first];
    const float x1 = data[first + 1];
    data[first]     = x0 * c - x1 * s;
    data[first + 1] = x0 * s + x1 * c;
}
1041
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU(gate) * up for a batch. gate_up is [M, 2*n] with the gate
// values in the first n columns of each row and the up values in the last
// n; output is [M, n]:
//   output[t, i] = silu(gate_up[t, i]) * gate_up[t, n + i]
// One thread per output element; threads past M * n exit.
kernel void silu_mul_fused_batch(
    device const float* gate_up    [[buffer(0)]],  // [M, 2*n]
    device float*       output     [[buffer(1)]],  // [M, n]
    constant uint&      n          [[buffer(2)]],
    constant uint&      num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Bound check first: it also keeps the division by n below from running
    // when n is zero.
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col   = id % n;
    const uint row   = token * 2 * n;

    const float gate = gate_up[row + col];
    const float up   = gate_up[row + n + col];

    // silu(x) = x / (1 + exp(-x)); token * n + col is the output index.
    const float silu = gate / (1.0f + fast::exp(-gate));
    output[token * n + col] = silu * up;
}
1061
// ── gelu_mul_fused_batch ───────────────────────────────────────────────
// Fused GELU(gate) * up for a batch, using the tanh approximation
//   gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
// gate_up is [M, 2*n] (gate in the first n columns of each row, up in the
// last n) and output is [M, n]. One thread per output element.
kernel void gelu_mul_fused_batch(
    device const float* gate_up    [[buffer(0)]],  // [M, 2*n]
    device float*       output     [[buffer(1)]],  // [M, n]
    constant uint&      n          [[buffer(2)]],
    constant uint&      num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Bound check first: it also keeps the division by n below from running
    // when n is zero.
    if (id >= num_tokens * n) return;

    constexpr float SQRT_2_OVER_PI = 0.7978845608f;

    const uint token = id / n;
    const uint col   = id % n;
    const uint row   = token * 2 * n;

    const float gate = gate_up[row + col];
    const float up   = gate_up[row + n + col];

    const float inner = SQRT_2_OVER_PI * (gate + 0.044715f * gate * gate * gate);
    const float gelu  = 0.5f * gate * (1.0f + precise::tanh(inner));
    output[token * n + col] = gelu * up;
}
1084
// ── add_inplace_batch ──────────────────────────────────────────────────
// Element-wise in-place accumulation over a flattened batch:
//   a[i] = a[i] + b[i]   for i in [0, total)
// One thread per element; threads with id >= total do nothing.
kernel void add_inplace_batch(
    device float*       a     [[buffer(0)]],  // [M * n], updated in place
    device const float* b     [[buffer(1)]],  // [M * n]
    constant uint&      total [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1096
// ── scale_buffer ───────────────────────────────────────────────────────
// In-place multiplication of the first `count` floats of a buffer by a
// scalar. One thread per element; threads with id >= count do nothing.
// Per the original note, the Gemma-1 prefill path uses this to scale the
// batched embeddings by sqrt(hidden_size) right after lookup.
kernel void scale_buffer(
    device float*   data  [[buffer(0)]],
    constant float& scale [[buffer(1)]],
    constant uint&  count [[buffer(2)]],
    uint id [[thread_position_in_grid]])
{
    if (id < count) {
        data[id] = data[id] * scale;
    }
}
1111
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer:
//   output[t, d] = embed[tokens[t], d]
// One thread per output float; threads past M * dim exit. Token ids are
// not range-checked here, so they must be valid row indices of embed.
kernel void copy_embedding_batch(
    device const float* embed      [[buffer(0)]],  // [vocab_size, dim]
    device float*       output     [[buffer(1)]],  // [M, dim]
    device const uint*  tokens     [[buffer(2)]],  // [M]
    constant uint&      dim        [[buffer(3)]],
    constant uint&      num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    // Bound check first: it also keeps the division by dim below from
    // running when dim is zero.
    if (id >= num_tokens * dim) return;

    const uint slot   = id / dim;   // which batch entry
    const uint column = id % dim;   // which component of the embedding
    const uint source = tokens[slot] * dim + column;
    output[id] = embed[source];
}
1129
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched f32 matrix-vector product: outputs[t, r] = dot(matrix[r, :], inputs[t, :])
// for M input vectors against one [rows, cols] weight matrix.
//
// Grid: ceil(rows / ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * row_groups + row_group. The strides assume 256 threads
// per threadgroup (8 simdgroups of 32 lanes). Each simdgroup produces up to
// 8 consecutive output rows starting at
// row_group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP.
//
// The token's input vector is staged in threadgroup memory, so cols must
// not exceed VEC_TILE_SIZE (not checked here).
kernel void matmul_vec_batch(
    device const float* matrix     [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token = tgid / groups_per_token;
    const uint group = tgid % groups_per_token;
    // Uniform across the threadgroup, so safe ahead of the barrier below.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* x = inputs + token * cols;
    for (uint c = tid; c < cols; c += 256) {
        vec_tile[c] = x[c];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint row_base = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    // Uniform within a simdgroup; no barriers follow, so returning is safe.
    if (row_base >= rows) return;

    // Rows 1..7 of this simdgroup may fall past the end of the matrix; row 0
    // is known valid from the check above.
    const bool live1 = row_base + 1 < rows;
    const bool live2 = row_base + 2 < rows;
    const bool live3 = row_base + 3 < rows;
    const bool live4 = row_base + 4 < rows;
    const bool live5 = row_base + 5 < rows;
    const bool live6 = row_base + 6 < rows;
    const bool live7 = row_base + 7 < rows;

    // Element offsets of each row's first column.
    const uint off0 = row_base * cols;
    const uint off1 = (row_base + 1) * cols;
    const uint off2 = (row_base + 2) * cols;
    const uint off3 = (row_base + 3) * cols;
    const uint off4 = (row_base + 4) * cols;
    const uint off5 = (row_base + 5) * cols;
    const uint off6 = (row_base + 6) * cols;
    const uint off7 = (row_base + 7) * cols;

    // Vector part: columns [0, vec_end), vec_end being cols rounded down to
    // a multiple of 128 (32 lanes * float4). Each lane takes one float4 per
    // 128-column step.
    const uint vec_end = cols & ~127u;
    float4 acc0 = float4(0.0f);
    float4 acc1 = float4(0.0f);
    float4 acc2 = float4(0.0f);
    float4 acc3 = float4(0.0f);
    float4 acc4 = float4(0.0f);
    float4 acc5 = float4(0.0f);
    float4 acc6 = float4(0.0f);
    float4 acc7 = float4(0.0f);

    for (uint j = simd_lane * 4; j < vec_end; j += 128) {
        const float4 xv = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        acc0 += *reinterpret_cast<device const float4*>(matrix + off0 + j) * xv;
        if (live1) acc1 += *reinterpret_cast<device const float4*>(matrix + off1 + j) * xv;
        if (live2) acc2 += *reinterpret_cast<device const float4*>(matrix + off2 + j) * xv;
        if (live3) acc3 += *reinterpret_cast<device const float4*>(matrix + off3 + j) * xv;
        if (live4) acc4 += *reinterpret_cast<device const float4*>(matrix + off4 + j) * xv;
        if (live5) acc5 += *reinterpret_cast<device const float4*>(matrix + off5 + j) * xv;
        if (live6) acc6 += *reinterpret_cast<device const float4*>(matrix + off6 + j) * xv;
        if (live7) acc7 += *reinterpret_cast<device const float4*>(matrix + off7 + j) * xv;
    }

    // Collapse each float4 accumulator to a scalar.
    float sum0 = acc0.x + acc0.y + acc0.z + acc0.w;
    float sum1 = acc1.x + acc1.y + acc1.z + acc1.w;
    float sum2 = acc2.x + acc2.y + acc2.z + acc2.w;
    float sum3 = acc3.x + acc3.y + acc3.z + acc3.w;
    float sum4 = acc4.x + acc4.y + acc4.z + acc4.w;
    float sum5 = acc5.x + acc5.y + acc5.z + acc5.w;
    float sum6 = acc6.x + acc6.y + acc6.z + acc6.w;
    float sum7 = acc7.x + acc7.y + acc7.z + acc7.w;

    // Scalar tail: columns [vec_end, cols), one column per lane per step.
    for (uint j = vec_end + simd_lane; j < cols; j += 32) {
        const float xs = vec_tile[j];
        sum0 += matrix[off0 + j] * xs;
        if (live1) sum1 += matrix[off1 + j] * xs;
        if (live2) sum2 += matrix[off2 + j] * xs;
        if (live3) sum3 += matrix[off3 + j] * xs;
        if (live4) sum4 += matrix[off4 + j] * xs;
        if (live5) sum5 += matrix[off5 + j] * xs;
        if (live6) sum6 += matrix[off6 + j] * xs;
        if (live7) sum7 += matrix[off7 + j] * xs;
    }

    // Combine the 32 lane partials of every row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 writes the rows that exist.
    device float* out = outputs + token * rows;
    if (simd_lane == 0) {
        out[row_base] = sum0;
        if (live1) out[row_base + 1] = sum1;
        if (live2) out[row_base + 2] = sum2;
        if (live3) out[row_base + 3] = sum3;
        if (live4) out[row_base + 4] = sum4;
        if (live5) out[row_base + 5] = sum5;
        if (live6) out[row_base + 6] = sum6;
        if (live7) out[row_base + 7] = sum7;
    }
}
1231
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector product for M input vectors:
//   outputs[t, r] = sum over blocks of scale(r, blk) * dot(int8(r, blk), inputs[t, blk])
//
// Weight layout as read by this kernel: each row is cols / 32 blocks of
// 34 bytes, a 2-byte half scale followed by 32 signed 8-bit values.
// Columns past the last full block of 32 are ignored.
//
// Grid: ceil(rows / Q8_ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * row_groups + row_group. The strides assume 256 threads
// per threadgroup (8 simdgroups of 32 lanes). Each simdgroup produces 4
// consecutive rows starting at
// row_group * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG; its lanes stride
// over blocks and are combined with simd_sum.
//
// The token's input vector is staged in threadgroup memory, so cols must
// not exceed VEC_TILE_SIZE (not checked here).
kernel void matmul_vec_q8_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform across the threadgroup, so safe ahead of the barrier below.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    // Uniform within a simdgroup; no barriers follow, so returning is safe.
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Rows row_base+1..+3 may lie past the end of the matrix when rows is
    // not a multiple of 4. Clamp them to the last valid row so the block
    // loop never reads outside the weight buffer; the duplicated results
    // are dropped by the guarded stores at the end. rows >= 1 here because
    // row_base < rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Expand one packed_short4 (8 signed bytes) into two float4 values,
    // preserving byte order.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of the block within a row
        uint vb = blk * 32;   // element offset of the block within the input

        // Per-block scale of each of the four rows.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // The 32 quantized values start 2 bytes into the block, so they are
        // only 2-byte aligned; packed_short4 gives four 8-byte loads per
        // row per block.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input values matching this block.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        // Apply the block scale once per row per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    #undef Q8_UNPACK8

    // Combine the 32 lane partials of every row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes only the rows that really exist.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1347
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// GEMM-style Q8_0 kernel that reuses each weight block across a tile of
// tokens. Every simdgroup computes a 4-row by TOKENS_PER_TG_Q8-token tile:
// a weight block is loaded and unpacked once, then dotted against each
// valid token's input slice, so weight traffic is shared by the tile.
//
// Weight layout as read by this kernel: each row is cols / 32 blocks of
// 34 bytes, a 2-byte half scale followed by 32 signed 8-bit values.
// Columns past the last full block of 32 are ignored.
//
// Grid (2-D): x covers ceil(rows / Q8_ROWS_PER_TG) row groups and y covers
// ceil(M / TOKENS_PER_TG_Q8) token tiles. Within a threadgroup, simdgroup
// simd_id starts at row tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
// its 32 lanes stride over blocks and are combined with simd_sum.
//
// Token vectors are read straight from device memory inside the block
// loop; no threadgroup memory is used, so there is no tile-size limit on
// cols. The float4 loads assume each input row is 16-byte aligned.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

// Expand one packed_short4 (8 signed bytes) into two float4 values,
// preserving byte order.
#define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
    short4 _s = short4(SHORT4); \
    char2 _a = as_type<char2>(_s.x); \
    char2 _b = as_type<char2>(_s.y); \
    char2 _c = as_type<char2>(_s.z); \
    char2 _d = as_type<char2>(_s.w); \
    (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
    (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
}

// Unscaled 32-element dot product of one unpacked weight row (W_0 .. W_7)
// with the token slice v0 .. v7 that is in scope at the expansion site.
#define Q8_ROW_DOT(W) \
    (dot(W##_0, v0) + dot(W##_1, v1) + dot(W##_2, v2) + dot(W##_3, v3) \
     + dot(W##_4, v4) + dot(W##_5, v5) + dot(W##_6, v6) + dot(W##_7, v7))

// Load the 32 inputs of tile-token T for the current block and add the
// scaled block dot products of the four rows to accumulators S0 .. S3.
#define Q8_GEMM_TOKEN(T, S0, S1, S2, S3) { \
    device const float* _x = inputs + (tok_base + (T)) * cols + vb; \
    float4 v0 = *(device const float4*)(_x); \
    float4 v1 = *(device const float4*)(_x + 4); \
    float4 v2 = *(device const float4*)(_x + 8); \
    float4 v3 = *(device const float4*)(_x + 12); \
    float4 v4 = *(device const float4*)(_x + 16); \
    float4 v5 = *(device const float4*)(_x + 20); \
    float4 v6 = *(device const float4*)(_x + 24); \
    float4 v7 = *(device const float4*)(_x + 28); \
    float _bd0 = Q8_ROW_DOT(w0); \
    float _bd1 = Q8_ROW_DOT(w1); \
    float _bd2 = Q8_ROW_DOT(w2); \
    float _bd3 = Q8_ROW_DOT(w3); \
    S0 += sc0 * _bd0; S1 += sc1 * _bd1; S2 += sc2 * _bd2; S3 += sc3 * _bd3; \
}

// Store the four row results of tile-token T, skipping rows past the end
// of the matrix (row_base itself is known to be valid).
#define Q8_GEMM_STORE(T, S0, S1, S2, S3) { \
    device float* _o = outputs + (tok_base + (T)) * rows; \
    _o[row_base] = S0; \
    if (row_base + 1 < rows) _o[row_base + 1] = S1; \
    if (row_base + 2 < rows) _o[row_base + 2] = S2; \
    if (row_base + 3 < rows) _o[row_base + 3] = S3; \
}

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    // Uniform within a simdgroup; the kernel uses no barriers.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Number of valid tokens in this tile (1 .. TOKENS_PER_TG_Q8).
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Rows row_base+1..+3 may lie past the end of the matrix when rows is
    // not a multiple of 4. Clamp them to the last valid row so the block
    // loop never reads outside the weight buffer; the duplicated results
    // are dropped by the guarded stores. rows >= 1 here because
    // row_base < rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators sTR: tile-token T (0..3) by row R (0..3).
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of the block within a row
        uint vb = blk * 32;   // element offset of the block within an input row

        // ── Load weight data once per block (reused for every token) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Unpacked (still unscaled) weights: 4 rows by 8 float4 values,
        // kept live while every token of the tile is processed.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        // ── Accumulate each valid token against the shared weights ──
        // Token 0 always exists because tok_count >= 1.
        Q8_GEMM_TOKEN(0, s00, s01, s02, s03);
        if (tok_count > 1) Q8_GEMM_TOKEN(1, s10, s11, s12, s13);
        if (tok_count > 2) Q8_GEMM_TOKEN(2, s20, s21, s22, s23);
        if (tok_count > 3) Q8_GEMM_TOKEN(3, s30, s31, s32, s33);
    }

    // Combine the 32 lane partials of every accumulator.
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Lane 0 writes the tile, token by token.
    if (simd_lane == 0) {
        Q8_GEMM_STORE(0, s00, s01, s02, s03);
        if (tok_count > 1) Q8_GEMM_STORE(1, s10, s11, s12, s13);
        if (tok_count > 2) Q8_GEMM_STORE(2, s20, s21, s22, s23);
        if (tok_count > 3) Q8_GEMM_STORE(3, s30, s31, s32, s33);
    }
}

#undef Q8_GEMM_STORE
#undef Q8_GEMM_TOKEN
#undef Q8_ROW_DOT
#undef Q8_UNPACK8
1573
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate). Computes
//   outputs[tok, row] = sum_k inputs[tok, k] * dequant(matrix[row, k])
// and is meant for prompt prefill, where many tokens reuse each weight tile.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups per TG (128 threads), each computing a single 8×8 output
// sub-tile via one simdgroup_matrix<float, 8, 8> accumulator. Weight bytes
// are cooperatively dequantized into threadgroup memory once per block and
// reused by all simdgroups in the tile.
//
// Q8_0 block layout as read here: 34 bytes = one `half` scale followed by
// 32 int8 quants; a weight is float(quant) * float(scale).
//
// Threadgroup memory: two 16×32 float tiles (2 KB each) plus a 1 KB
// scratch array used only by the partial-token store path.
//
// Assumptions (not checked inside the kernel; the host dispatch code is
// expected to enforce them and fall back to another kernel otherwise):
//   * exactly 128 threads per threadgroup, simdgroup width 32
//   * cols % 32 == 0  (one Q8_0 block per K chunk)
//   * rows % 16 == 0  (tile-aligned; no row-dimension bounds checks here)
//   * num_tokens may be any value; a partial token tile is zero-padded on
//     load and written back through the scratch copy path.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],   // [M, cols]
    device float* outputs        [[buffer(2)]],   // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid    [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Both values are uniform across the threadgroup, so this early exit
    // cannot leave some threads waiting at the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;   // token offset inside the tile
    uint sg_row_base = (simd_id % 2) * 8;   // row offset inside the tile

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Every thread fills 4 consecutive tile slots [slot0, slot0 + 4).
    // 4 divides 32, so those slots share one tile row: the weight-row
    // pointer, the block scale and the token row are per-thread invariants
    // and are computed once instead of once per element.
    uint slot0  = tid * 4;
    uint tile_r = slot0 / 32;               // tile row (weight row / token)
    uint tile_k = slot0 % 32;               // first K index inside the block

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok = tok_base + tile_r;
    bool tok_valid = (tok < num_tokens);

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        {
            device const uchar* bp = w_row + blk * 34;
            float scale = float(*(device const half*)bp);
            device const int8_t* q = (device const int8_t*)(bp + 2 + tile_k);
            for (uint i = 0; i < 4; i++) {
                w_tile[slot0 + i] = float(int(q[i])) * scale;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens are zero-filled so they add nothing.
        if (tok_valid) {
            device const float* x = inputs + tok * cols + blk * 32 + tile_k;
            for (uint i = 0; i < 4; i++) {
                t_tile[slot0 + i] = x[i];
            }
        } else {
            for (uint i = 0; i < 4; i++) {
                t_tile[slot0 + i] = 0.0f;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]  (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]  (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B,
                           w_tile + sg_row_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Tiles are overwritten next iteration; wait until all reads finish.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    if (out_tok + 8 <= num_tokens) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token tile: stage the sub-tile in per-simdgroup scratch,
        // then let all lanes copy the valid rows (8 floats per token).
        threadgroup float scratch[4 * 64];
        threadgroup float* stage = scratch + simd_id * 64;
        simdgroup_store(C, stage, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid_elems = (num_tokens - out_tok) * 8;   // 8..56
        for (uint e = lane; e < valid_elems; e += 32) {
            outputs[(out_tok + e / 8) * rows + out_row + (e % 8)] = stage[e];
        }
    }
}
1702
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* 8×8 accumulators that are adjacent along the
// row axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half   (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0)  -> C_a
//     (sg_tok_idx*8, sg_row_half*16 + 8)  -> C_b
//
// The loaded A (token) fragment is used for two multiply-accumulates per
// K_sub iteration, and a 32×32 tile needs a quarter of the threadgroups
// that a 16×16 tile needs for the same output.
//
// Threadgroup memory: two 32×32 float tiles (4 KB each) plus a 4 KB
// scratch array used only by the partial-token store path.
//
// Assumptions (not checked inside the kernel; the host dispatch code is
// expected to enforce them and fall back otherwise):
//   * exactly 256 threads per threadgroup, simdgroup width 32
//   * cols % 32 == 0
//   * rows % 32 == 0
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],   // [M, cols]
    device float* outputs        [[buffer(2)]],   // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid    [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so exiting before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx    = simd_id / 2;              // 0..3
    uint sg_row_half   = simd_id % 2;              // 0..1
    uint sg_tok_base   = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // 32*32 slots / 256 threads = 4 consecutive slots per thread. 4 divides
    // 32, so the slots share one tile row; the weight-row pointer, block
    // scale and token row are hoisted out of the per-element loops.
    uint slot0  = tid * 4;
    uint tile_r = slot0 / 32;
    uint tile_k = slot0 % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok = tok_base + tile_r;
    bool tok_valid = (tok < num_tokens);

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: one half scale + 32 int8 per block.
        {
            device const uchar* bp = w_row + blk * 34;
            float scale = float(*(device const half*)bp);
            device const int8_t* q = (device const int8_t*)(bp + 2 + tile_k);
            for (uint i = 0; i < 4; i++) {
                w_tile[slot0 + i] = float(int(q[i])) * scale;
            }
        }

        // Cooperative token tile load; tokens past num_tokens become zeros.
        if (tok_valid) {
            device const float* x = inputs + tok * cols + blk * 32 + tile_k;
            for (uint i = 0; i < 4; i++) {
                t_tile[slot0 + i] = x[i];
            }
        } else {
            for (uint i = 0; i < 4; i++) {
                t_tile[slot0 + i] = 0.0f;
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Tiles are overwritten next iteration; wait until all reads finish.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators. The row dimension is never bounds-checked
    // (rows is assumed MMA32_ROW_TILE-aligned); only the token dimension of
    // the last tile may be partial.
    uint out_tok   = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Stage both accumulators in per-simdgroup scratch, then let all
        // lanes copy the valid token rows.
        threadgroup float scratch[8 * 2 * 64];   // 8 simdgroups × 2 accs × 64 floats
        threadgroup float* stage = scratch + simd_id * 128;
        simdgroup_store(C_a, stage, 8);
        simdgroup_store(C_b, stage + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid_elems = (num_tokens - out_tok) * 8;   // 8..56 per accumulator
        for (uint e = lane; e < valid_elems; e += 32) {
            device float* dst = outputs + (out_tok + e / 8) * rows;
            uint j = e % 8;
            dst[out_row_a + j] = stage[e];
            dst[out_row_b + j] = stage[64 + e];
        }
    }
}
1843
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory, so the two 32×32 tiles take 2 KB each (4 KB) instead of 4 KB
// each. The partial-token store path additionally declares a 4 KB float
// scratch array. The smaller tiles are intended to raise
// concurrent-threadgroup occupancy (NOTE(review): confirm with profiling).
//
// Precision: a Q8_0 weight is int8 × half scale; the product is computed
// in float and then rounded to half when written to the tile. Token
// activations are narrowed float → half the same way, so values outside
// half's range would become inf. Accumulation stays in float
// (simdgroup_matrix<float> accumulators).
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups (256 threads), each owning
// two 8×8 accumulators adjacent along the row axis.
//
// Assumptions (not checked inside the kernel): exactly 256 threads per
// threadgroup, simdgroup width 32, cols % 32 == 0, rows % 32 == 0.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],   // [M, cols]
    device float* outputs        [[buffer(2)]],   // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid    [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so exiting before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx    = simd_id / 2;              // 0..3
    uint sg_row_half   = simd_id % 2;              // 0..1
    uint sg_tok_base   = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // 1024 slots / 256 threads = 4 consecutive slots per thread, all in one
    // tile row, so row pointer, scale and token row are per-thread invariants.
    uint slot0  = tid * 4;
    uint tile_r = slot0 / 32;
    uint tile_k = slot0 % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok = tok_base + tile_r;
    bool tok_valid = (tok < num_tokens);

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16.
        {
            device const uchar* bp = w_row + blk * 34;
            float scale = float(*(device const half*)bp);
            device const int8_t* q = (device const int8_t*)(bp + 2 + tile_k);
            for (uint i = 0; i < 4; i++) {
                w_tile[slot0 + i] = half(float(int(q[i])) * scale);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past
        // num_tokens become zeros.
        if (tok_valid) {
            device const float* x = inputs + tok * cols + blk * 32 + tile_k;
            for (uint i = 0; i < 4; i++) {
                t_tile[slot0 + i] = half(x[i]);
            }
        } else {
            for (uint i = 0; i < 4; i++) {
                t_tile[slot0 + i] = half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Half A/B fragments feed float accumulators; A is reused for both.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Tiles are overwritten next iteration; wait until all reads finish.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    uint out_tok   = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        // Whole 8-token band is valid: store straight to device memory.
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial band: stage in per-simdgroup scratch, then let all lanes
        // copy the valid token rows.
        threadgroup float scratch[8 * 2 * 64];   // 8 simdgroups × 2 accs × 64 floats
        threadgroup float* stage = scratch + simd_id * 128;
        simdgroup_store(C_a, stage, 8);
        simdgroup_store(C_b, stage + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        uint valid_elems = (num_tokens - out_tok) * 8;   // 8..56 per accumulator
        for (uint e = lane; e < valid_elems; e += 32) {
            device float* dst = outputs + (out_tok + e / 8) * rows;
            uint j = e % 8;
            dst[out_row_a + j] = stage[e];
            dst[out_row_b + j] = stage[64 + e];
        }
    }
}
1972
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 accumulators, this kernel runs 4 simdgroups
// (128 threads), each holding a **2×2 grid** of 8×8 float accumulators:
//   C_00 (tok 0..8,  row 0..8)    C_01 (tok 0..8,  row 8..16)
//   C_10 (tok 8..16, row 0..8)    C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions (A_top and A_bot each paired with both B's),
// i.e. four MMAs per four loads versus two MMAs per three loads in the
// 2-accumulator kernel. The smaller threadgroup is intended to help
// occupancy (NOTE(review): confirm with profiling).
//
// Threadgroup memory: two 32×32 half tiles (2 KB each) plus a 4 KB float
// scratch array used only by the partial-token store path.
//
// Assumptions (not checked inside the kernel): exactly 128 threads per
// threadgroup, simdgroup width 32, cols % 32 == 0, rows % 32 == 0.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],   // [M, cols]
    device float* outputs        [[buffer(2)]],   // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid    [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so exiting before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 2 KB each.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_base = (simd_id / 2) * 16;
    uint sg_row_base = (simd_id % 2) * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // 128 threads × 8 slots = 1024 = 32*32. 8 divides 32, so a thread's
    // slots share one tile row; row pointer, scale and token row are
    // per-thread invariants and are hoisted out of the element loops.
    uint slot0  = tid * 8;
    uint tile_r = slot0 / 32;
    uint tile_k = slot0 % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok = tok_base + tile_r;
    bool tok_valid = (tok < num_tokens);

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant to FP16.
        {
            device const uchar* bp = w_row + blk * 34;
            float scale = float(*(device const half*)bp);
            device const int8_t* q = (device const int8_t*)(bp + 2 + tile_k);
            for (uint i = 0; i < 8; i++) {
                w_tile[slot0 + i] = half(float(int(q[i])) * scale);
            }
        }

        // Cooperative token tile load; tokens past num_tokens become zeros.
        if (tok_valid) {
            device const float* x = inputs + tok * cols + blk * 32 + tile_k;
            for (uint i = 0; i < 8; i++) {
                t_tile[slot0 + i] = half(x[i]);
            }
        } else {
            for (uint i = 0; i < 8; i++) {
                t_tile[slot0 + i] = half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles are overwritten next iteration; wait until all reads finish.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store the 4 output sub-tiles of this simdgroup's 16×16 quadrant.
    uint out_tok_top = tok_base + sg_tok_base;
    uint out_tok_bot = out_tok_top + 8;
    uint out_row_lo  = row_base + sg_row_base;
    uint out_row_hi  = out_row_lo + 8;
    if (out_tok_bot + 8 <= num_tokens) {
        // Fast path: all 16 token rows of the quadrant are valid.
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else if (out_tok_top < num_tokens) {
        // Partial-token fallback via per-simdgroup scratch. Quadrants that
        // start at or past num_tokens have nothing to write and skip this.
        threadgroup float scratch[4 * 4 * 64];   // 4 simdgroups × 4 accs × 64
        threadgroup float* stage = scratch + simd_id * 256;
        simdgroup_store(C_00, stage + 0,   8);
        simdgroup_store(C_01, stage + 64,  8);
        simdgroup_store(C_10, stage + 128, 8);
        simdgroup_store(C_11, stage + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // All lanes share the copy. Staged element e decodes as
        //   acc = e / 64 (0:C_00 1:C_01 2:C_10 3:C_11), m = token row, j = column.
        uint lane = tid % 32;
        for (uint e = lane; e < 256; e += 32) {
            uint acc = e / 64;
            uint m   = (e % 64) / 8;
            uint j   = e % 8;
            uint t   = out_tok_top + (acc / 2) * 8 + m;
            if (t < num_tokens) {
                uint r = out_row_lo + (acc % 2) * 8 + j;
                outputs[t * rows + r] = stage[e];
            }
        }
    }
}
2127
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input fragments and the accumulators are simdgroup_matrix<half>.
// The motivation is higher half-precision multiply-accumulate throughput
// (NOTE(review): the expected speed-up over float accumulators is a
// hardware assumption — confirm with measurements).
//
// Numerical notes: the whole K dimension (cols) is summed in half, so
// long reductions can lose precision or overflow half's range. Validate
// output against a float-accumulator kernel on each target model before
// enabling this path.
//
// Layout: 32×32 output tile, 4 simdgroups (128 threads), each owning a
// 2×2 grid of 8×8 accumulators for one 16×16 quadrant. Results are always
// staged through a half scratch array (2 KB) and widened to float while
// being copied to `outputs`, because the device output is float.
//
// Assumptions (not checked inside the kernel): exactly 128 threads per
// threadgroup, simdgroup width 32, cols % 32 == 0, rows % 32 == 0.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],   // [M, cols]
    device float* outputs        [[buffer(2)]],   // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid    [[threadgroup_position_in_grid]],
    uint tid      [[thread_index_in_threadgroup]],
    uint simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform across the threadgroup, so exiting before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 2 KB each.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_base = (simd_id / 2) * 16;
    uint sg_row_base = (simd_id % 2) * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // 128 threads × 8 slots = 1024 = 32*32. 8 divides 32, so a thread's
    // slots share one tile row; row pointer, scale and token row are
    // per-thread invariants and are hoisted out of the element loops.
    uint slot0  = tid * 8;
    uint tile_r = slot0 / 32;
    uint tile_k = slot0 % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok = tok_base + tile_r;
    bool tok_valid = (tok < num_tokens);

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant to FP16 (one half scale + 32 int8 per block).
        {
            device const uchar* bp = w_row + blk * 34;
            float scale = float(*(device const half*)bp);
            device const int8_t* q = (device const int8_t*)(bp + 2 + tile_k);
            for (uint i = 0; i < 8; i++) {
                w_tile[slot0 + i] = half(float(int(q[i])) * scale);
            }
        }

        // Cooperative token tile load; tokens past num_tokens become zeros.
        if (tok_valid) {
            device const float* x = inputs + tok * cols + blk * 32 + tile_k;
            for (uint i = 0; i < 8; i++) {
                t_tile[slot0 + i] = half(x[i]);
            }
        } else {
            for (uint i = 0; i < 8; i++) {
                t_tile[slot0 + i] = half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Tiles are overwritten next iteration; wait until all reads finish.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base;
    uint out_row_lo  = row_base + sg_row_base;

    // Quadrants that start at or past num_tokens have nothing to write.
    // out_tok_top is uniform within a simdgroup, so the simdgroup-level
    // operations below are executed by whole simdgroups only.
    if (out_tok_top < num_tokens) {
        threadgroup half scratch[4 * 4 * 64];   // 4 simdgroups × 4 accs × 64
        threadgroup half* stage = scratch + simd_id * 256;
        simdgroup_store(C_00, stage + 0,   8);
        simdgroup_store(C_01, stage + 64,  8);
        simdgroup_store(C_10, stage + 128, 8);
        simdgroup_store(C_11, stage + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // All lanes share the widening copy. Staged element e decodes as
        //   acc = e / 64 (0:C_00 1:C_01 2:C_10 3:C_11), m = token row, j = column.
        uint lane = tid % 32;
        for (uint e = lane; e < 256; e += 32) {
            uint acc = e / 64;
            uint m   = (e % 64) / 8;
            uint j   = e % 8;
            uint t   = out_tok_top + (acc / 2) * 8 + m;
            if (t < num_tokens) {
                uint r = out_row_lo + (acc % 2) * 8 + j;
                outputs[t * rows + r] = float(stage[e]);
            }
        }
    }
}
2265
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add of a bias vector over a row-major [num_tokens, rows]
// buffer: out[token, i] += bias[i]. One thread handles one element of `out`;
// threads whose index falls past num_tokens * rows do nothing, so the grid
// may be rounded up to any convenient size.
// Used for Qwen2 QKV bias after the fused qkv matmul.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],   // [num_tokens, rows]
    device const float* bias     [[buffer(1)]],   // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Flat element index → column within its row selects the bias entry.
    if (id < num_tokens * rows) {
        out[id] = out[id] + bias[id % rows];
    }
}
2282
2283// ── matmul_vec_q4_batch ────────────────────────────────────────────────
2284// Batched Q4_0 matrix-vector multiply for M input vectors.
2285// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
2286kernel void matmul_vec_q4_batch(
2287 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes [rows, cols]
2288 device const float* inputs [[buffer(1)]], // [M, cols] input batch
2289 device float* outputs [[buffer(2)]], // [M, rows] output batch
2290 constant uint& num_tokens [[buffer(3)]], // M
2291 constant uint& rows [[buffer(4)]],
2292 constant uint& cols [[buffer(5)]],
2293 uint tgid [[threadgroup_position_in_grid]],
2294 uint tid [[thread_index_in_threadgroup]],
2295 uint simd_lane [[thread_index_in_simdgroup]],
2296 uint simd_id [[simdgroup_index_in_threadgroup]])
2297{
2298 uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
2299 uint token = tgid / row_tgs;
2300 uint tg_in_token = tgid % row_tgs;
2301 if (token >= num_tokens) return;
2302
2303 threadgroup float vec_tile[VEC_TILE_SIZE];
2304 device const float* input = inputs + token * cols;
2305 for (uint i = tid; i < cols; i += 256) {
2306 vec_tile[i] = input[i];
2307 }
2308 threadgroup_barrier(mem_flags::mem_threadgroup);
2309
2310 uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
2311 if (row_base >= rows) return;
2312
2313 uint blocks_per_row = cols / 32;
2314 uint row_bytes = blocks_per_row * 18;
2315
2316 device const uchar* r0 = matrix + row_base * row_bytes;
2317 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
2318 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
2319 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
2320
2321 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
2322
2323 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
2324 uint bb = blk * 18;
2325 uint vb = blk * 32;
2326
2327 float sc0 = float(*(device const half*)(r0 + bb));
2328 float sc1 = float(*(device const half*)(r1 + bb));
2329 float sc2 = float(*(device const half*)(r2 + bb));
2330 float sc3 = float(*(device const half*)(r3 + bb));
2331
2332 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
2333 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
2334 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
2335 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
2336
2337 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
2338 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
2339 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
2340 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
2341 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
2342 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
2343 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
2344 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
2345
2346 float bd0=0, bd1=0, bd2=0, bd3=0;
2347 uchar4 b;
2348
2349 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2350 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2351 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2352 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2353
2354 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2355 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2356 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2357 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2358
2359 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2360 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2361 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2362 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2363
2364 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2365 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2366 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2367 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2368
2369 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
2370 }
2371
2372 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
2373 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
2374
2375 device float* output = outputs + token * rows;
2376 if (simd_lane == 0) {
2377 if (row_base < rows) output[row_base] = sum0;
2378 if (row_base + 1 < rows) output[row_base + 1] = sum1;
2379 if (row_base + 2 < rows) output[row_base + 2] = sum2;
2380 if (row_base + 3 < rows) output[row_base + 3] = sum3;
2381 }
2382}
2383
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter one projection (K or V) of a prefill batch into its f16 cache.
//   src: f32 rows of `src_stride` floats, one row per token; the slice that
//        is copied starts `src_offset` floats into each row and is `kv_dim`
//        elements long.
//   dst: f16 rows of `kv_dim` elements; token t lands in row base_pos + t.
// One thread handles one element, so at least M * kv_dim threads must be
// dispatched; any surplus threads do nothing.
kernel void copy_kv_batch(
    device const float* src     [[buffer(0)]],  // batch QKV buffer (f32)
    device half* dst            [[buffer(1)]],  // KV cache (f16)
    constant uint& M            [[buffer(2)]],  // num tokens in batch
    constant uint& kv_dim       [[buffer(3)]],  // elements per KV vector
    constant uint& base_pos     [[buffer(4)]],  // starting position in cache
    constant uint& src_stride   [[buffer(5)]],  // floats per row in src (qkv_rows)
    constant uint& src_offset   [[buffer(6)]],  // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    // Only threads that map onto a real (token, element) pair do any work.
    if (id < M * kv_dim) {
        const uint row = id / kv_dim;   // token index within the batch
        const uint col = id % kv_dim;   // element index within the KV vector
        const float value = src[row * src_stride + src_offset + col];
        dst[(base_pos + row) * kv_dim + col] = half(value);
    }
}
2406
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// The loop strides below hard-code 256 threads per threadgroup, i.e. 8
// simdgroups of 32 lanes (`s += 8`, `d += 32`, `+= 256`, and the 8-entry
// reduction arrays), so the kernel must be dispatched with that size.
// NOTE(review): every position s < seq_len is written to `scores[s]` with no
// bounds check, so seq_len must not exceed the size of `scores` — confirm the
// host enforces this before choosing this kernel.
kernel void attention_batch(
    device const float* q_batch    [[buffer(0)]],  // batch QKV buf (strided)
    device const half* k_cache     [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half* v_cache     [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch     [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M               [[buffer(4)]],  // num tokens in batch
    constant uint& base_pos        [[buffer(5)]],  // starting position in KV cache
    constant uint& num_heads       [[buffer(6)]],
    constant uint& num_kv_heads    [[buffer(7)]],
    constant uint& head_dim        [[buffer(8)]],
    constant uint& q_stride        [[buffer(9)]],  // floats per row in q_batch
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // The exit condition depends only on tgid, so a whole threadgroup leaves
    // together and no thread is left waiting at the barriers below.
    if (token_idx >= M) return;

    // Grouped-query attention: (num_heads / num_kv_heads) consecutive query
    // heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores — sized to the effective max_seq_len
    // (4096 for all supported models) so long-context attention doesn't overflow.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    // Each simdgroup takes every 8th position; its 32 lanes split head_dim and
    // simd_sum combines the partial dot products.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            // Scale by 1/sqrt(head_dim) while storing.
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // Two-level max reduction: per-thread strided scan, simd_max within each
    // simdgroup, then thread 0 folds the 8 per-simdgroup results.
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exponentiate relative to the row max and reduce the sum the same way.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Publish the reciprocal so every thread multiplies instead of divides;
        // a non-positive total yields 0 rather than a division by zero.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    // NOTE(review): the half4/float4 pointer casts presumably require
    // v_stride, v_base and out_off + d to be multiples of 4 elements for
    // alignment — confirm for any model whose head_dim is not a multiple of 4.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        // Positions are consumed four at a time; seq_len4 is seq_len rounded
        // down to a multiple of 4, the remainder is handled right after.
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2532
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming attention with online softmax. Same grid as attention_batch
// (M × num_heads threadgroups, one per (token, head) pair) but the scores
// matrix is never materialized. K/V positions are processed in a tile of
// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
// the standard flash-attention recurrence:
//
//   m_new   = max(m_old, tile_max)
//   alpha   = exp(m_old - m_new)
//   l_new   = alpha * l_old + sum(exp(S - m_new))
//   O_new   = alpha * O_old + sum(exp(S - m_new) * V)
//   O_final = O / l
//
// This removes the fixed-size `scores[]` cap in attention_batch (which
// writes one score per attended position with no bounds check, so it
// overflows once seq_len exceeds the array size) and keeps threadgroup
// memory use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
//
// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
// The loop strides and the 8-entry sg_scratch buffer also hard-code 256
// threads per threadgroup (8 simdgroups of 32 lanes).
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch    [[buffer(0)]],  // batch QKV buf (strided)
    device const half* k_cache     [[buffer(1)]],  // [max_seq, num_kv_heads * head_dim] f16
    device const half* v_cache     [[buffer(2)]],  // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch     [[buffer(3)]],  // [M, num_heads * head_dim]
    constant uint& M               [[buffer(4)]],
    constant uint& base_pos        [[buffer(5)]],
    constant uint& num_heads       [[buffer(6)]],
    constant uint& num_kv_heads    [[buffer(7)]],
    constant uint& head_dim        [[buffer(8)]],
    constant uint& q_stride        [[buffer(9)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    // Depends only on tgid, so the whole threadgroup exits together and the
    // barriers below stay uniform.
    if (token_idx >= M) return;

    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1;  // causal: attend to [0, base_pos + token_idx]
    uint q_off = token_idx * q_stride + head * head_dim;
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup state:
    //   q_sh:      Q vector for this (token, head), loaded once
    //   o_sh:      running output vector, updated each K tile
    //   scores_sh: scores for the current K tile only
    //   stats:     [running max, running sum] (see flash-attention recurrence)
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float scores_sh[FLASH_K_TILE];
    threadgroup float stats[2];
    threadgroup float sg_scratch[8];  // simdgroup-level reduction buffer

    // --- Load Q (one row) and zero the running O ---
    for (uint d = tid; d < head_dim; d += 256) {
        q_sh[d] = q_batch[q_off + d];
        o_sh[d] = 0.0f;
    }
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = fast::rsqrt(float(head_dim));
    // K and V caches share the same [position, kv_head, dim] layout, so one
    // stride/base pair addresses both.
    uint v_stride = num_kv_heads * head_dim;
    uint v_base = kv_head * head_dim;

    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
        // The final tile may be partial.
        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);

        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
        //     8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            uint k_off = (kv_base + ti) * v_stride + v_base;  // same layout as V stride
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q_sh[d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores_sh[ti] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Tile max via cooperative reduction.
        //     Threads with tid >= tile_n contribute -INFINITY, the identity for max.
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            local_max = max(local_max, scores_sh[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // [3] Merge with running max, compute alpha, rescale running l.
        //     Only thread 0 computes these; the other threads pick the values
        //     up from sg_scratch after the barrier.
        float m_new;
        float alpha;
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
            float m_old = stats[0];
            m_new = max(m_old, tile_max);
            // First iteration: m_old = -inf → alpha = 0 (reset O).
            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            stats[0] = m_new;
            stats[1] *= alpha;
            // Broadcast via sg_scratch.
            sg_scratch[0] = alpha;
            sg_scratch[1] = m_new;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        alpha = sg_scratch[0];
        m_new = sg_scratch[1];

        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        // This barrier also separates the sg_scratch reads above from the
        // sg_scratch writes in step [5].
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Tile sum → update running l.
        float local_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            local_sum += scores_sh[s];
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
        //     Each thread owns whole output dimensions, so no reduction is needed.
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
            }
            o_sh[d] += acc;
        }
        // Keep the next iteration from overwriting scores_sh / sg_scratch
        // while slower threads are still reading them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Normalize and write output ---
    // A non-positive denominator yields zeros instead of a division by zero.
    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[out_off + d] = o_sh[d] * inv_l;
    }
}
2697
// ── attention_mma_flash_batch ─────────────────────────────────────────
// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
// both Q·K^T and P·V. Processes Q_BLOCK=8 tokens of one head per
// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
// K/V loads across 8 Q rows and using hardware matrix-multiply for the
// arithmetic.
//
// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
// to attention_batch / attention_flash_batch otherwise.
// The integer divisions in Phase 1 (head_dim / 8) and Phase 4
// (head_dim / 4 / 8) truncate, so head_dim must also be a multiple of 32
// for every output dimension to be covered.
//
// Online softmax recurrence is identical to attention_flash_batch but
// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
constant constexpr uint FLASH_MMA_K_BLOCK = 32;
constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;

kernel void attention_mma_flash_batch(
    device const float* q_batch    [[buffer(0)]],
    device const half* k_cache     [[buffer(1)]],
    device const half* v_cache     [[buffer(2)]],
    device float* output_batch     [[buffer(3)]],
    constant uint& M               [[buffer(4)]],
    constant uint& base_pos        [[buffer(5)]],
    constant uint& num_heads       [[buffer(6)]],
    constant uint& num_kv_heads    [[buffer(7)]],
    constant uint& head_dim        [[buffer(8)]],
    constant uint& q_stride        [[buffer(9)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
    uint head = tgid.y;
    // Depends only on tgid, so the whole threadgroup exits together.
    if (q_block_start >= M) return;
    // The last Q block may hold fewer than Q_BLOCK real tokens.
    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);

    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
    uint seq_len = base_pos + q_block_start + q_valid;
    float scale = fast::rsqrt(float(head_dim));

    uint kv_stride = num_kv_heads * head_dim;
    uint kv_base_off = kv_head * head_dim;

    // ── Threadgroup memory ──
    //   q_sh: [Q_BLOCK, head_dim] half   — Q tile, loaded once
    //   k_sh: [K_BLOCK, head_dim] half   — K tile, refreshed per kv_base iter
    //   v_sh: [K_BLOCK, head_dim] half   — V tile, refreshed per kv_base iter
    //   s_sh: [Q_BLOCK, K_BLOCK] float   — raw Q·K^T scores, then scaled+masked
    //   p_sh: [Q_BLOCK, K_BLOCK] half    — softmax probabilities (for P·V MMA)
    //   o_sh: [Q_BLOCK, head_dim] float  — running output accumulator
    //   m_sh: [Q_BLOCK] float            — running max per Q row
    //   l_sh: [Q_BLOCK] float            — running softmax denominator per Q row
    //   scratch: 4*Q_BLOCK floats        — per-row reduction scratch
    // The arrays are declared for FLASH_MMA_MAX_HEAD_DIM but indexed with the
    // runtime head_dim as row stride, so rows are packed tightly.
    threadgroup half q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup half p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];

    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
    // Rows past q_valid are zero-filled so the 8×8 matrix loads never read
    // uninitialized threadgroup memory.
    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
    for (uint i = tid; i < qblock_elems; i += 128) {
        uint q = i / head_dim;
        uint d = i % head_dim;
        if (q < q_valid) {
            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
            q_sh[q * head_dim + d] = half(q_batch[q_off]);
        } else {
            q_sh[q * head_dim + d] = half(0);
        }
        o_sh[q * head_dim + d] = 0.0f;
    }
    if (tid < FLASH_MMA_Q_BLOCK) {
        m_sh[tid] = -INFINITY;
        l_sh[tid] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Stream K/V in K_BLOCK chunks ──
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
        // The final tile may be partial.
        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);

        // Load K and V tile into TG memory (as half).
        // Positions past tile_n are zero-filled; Phase 2a masks their scores.
        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
        for (uint i = tid; i < kv_tile_elems; i += 128) {
            uint k_pos = i / head_dim;
            uint d = i % head_dim;
            if (k_pos < tile_n) {
                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
            } else {
                k_sh[k_pos * head_dim + d] = half(0);
                v_sh[k_pos * head_dim + d] = half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 1: S = Q @ K^T via MMA ──
        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
        {
            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
            uint dim_chunks = head_dim / 8;
            for (uint dc = 0; dc < dim_chunks; dc++) {
                simdgroup_matrix<half, 8, 8> A, B;
                // A = Q[0:8, dc*8 : dc*8+8]  (rows of Q, no transpose)
                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
                // transpose=true, which places it in the register as K^T of that sub-block.
                simdgroup_load(B,
                               k_sh + (simd_id * 8) * head_dim + dc * 8,
                               head_dim, ulong2(0, 0), true);
                simdgroup_multiply_accumulate(C, A, B, C);
            }
            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
        // An entry is kept only if its Q row is real, its K position lies
        // inside the tile, and the K position is not in the row's future.
        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
        for (uint i = tid; i < s_elems; i += 128) {
            uint q = i / FLASH_MMA_K_BLOCK;
            uint k = i % FLASH_MMA_K_BLOCK;
            uint global_q = q_block_start + q;
            uint global_kv = kv_base + k;
            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2b: per-row max via simdgroup reduction ──
        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
        {
            uint row_base = simd_id * 2;
            for (uint qr = 0; qr < 2; qr++) {
                uint q = row_base + qr;
                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
                float row_max = simd_max(my);
                if (simd_lane == 0) {
                    scratch[q] = row_max;  // tile_max[q]
                }
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
        // One thread per Q row.
        if (tid < FLASH_MMA_Q_BLOCK) {
            uint q = tid;
            float m_old = m_sh[q];
            float tile_max = scratch[q];
            float m_new = max(m_old, tile_max);
            // While no finite score has been seen for the row (m_old = -inf),
            // alpha is forced to 0 instead of evaluating exp(-inf - m_new).
            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            m_sh[q] = m_new;
            l_sh[q] = l_sh[q] * alpha;
            // scratch[q]            = m_new   (for phase 2d)
            // scratch[Q_BLOCK + q]  = alpha   (for phase 3)
            scratch[q] = m_new;
            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
        // NOTE(review): padded rows (q >= q_valid) are masked in every tile, so
        // their m_new stays -inf and exp(-inf - -inf) presumably produces NaN.
        // Each P row only feeds the same O row and the finalize loop skips
        // padded rows, so this looks benign — confirm.
        {
            uint row_base = simd_id * 2;
            for (uint qr = 0; qr < 2; qr++) {
                uint q = row_base + qr;
                float m_new = scratch[q];
                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
                float row_sum = simd_sum(p);
                if (simd_lane == 0) {
                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
                }
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2e: l_sh += tile_sum ──
        if (tid < FLASH_MMA_Q_BLOCK) {
            uint q = tid;
            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
        }

        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
        threadgroup_barrier(mem_flags::mem_threadgroup);
        for (uint i = tid; i < qblock_elems; i += 128) {
            uint q = i / head_dim;
            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
            o_sh[i] *= alpha;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 4: O += P @ V via MMA ──
        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
        // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
        // For head_dim=64: 16 dims/sg = 2 tiles. head_dim=128: 32 dims/sg = 4 tiles.
        {
            uint dims_per_sg = head_dim / 4;      // 16 or 32
            uint tiles_per_sg = dims_per_sg / 8;  // 2 or 4
            uint sg_d_base = simd_id * dims_per_sg;
            for (uint t = 0; t < tiles_per_sg; t++) {
                uint d_base = sg_d_base + t * 8;
                // Start from the (already rescaled) running output so the MMA
                // accumulates directly onto it.
                simdgroup_matrix<float, 8, 8> O_acc;
                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
                uint k_chunks = FLASH_MMA_K_BLOCK / 8;  // 4
                for (uint kc = 0; kc < k_chunks; kc++) {
                    simdgroup_matrix<half, 8, 8> A, B;
                    // A = P[0:8, kc*8 : kc*8+8], B = V[kc*8 : kc*8+8, d_base : d_base+8]
                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
                                   ulong2(0, 0), false);
                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
                                   ulong2(0, 0), false);
                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
                }
                simdgroup_store(O_acc, o_sh + d_base, head_dim);
            }
        }
        // Keep the next iteration's K/V reload from racing with this
        // iteration's matrix loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Finalize: O /= l, write to output ──
    // Only real rows (q < q_valid) are written; a non-positive denominator
    // yields zeros instead of a division by zero.
    for (uint i = tid; i < qblock_elems; i += 128) {
        uint q = i / head_dim;
        uint d = i % head_dim;
        if (q < q_valid) {
            float l = l_sh[q];
            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
            uint token_idx = q_block_start + q;
            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
            output_batch[out_off] = o_sh[i] * inv_l;
        }
    }
}
2945
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Applies rotary position embedding to Q and K together, in place, with a
// single dispatch (one launch + barrier fewer per layer than rotating Q
// and K separately).
// Row layout of qkv_data: the num_q_heads Q heads come first (offset 0),
// immediately followed by the num_kv_heads K heads (offset
// num_q_heads * head_dim); each head spans head_dim floats.
// One thread rotates one adjacent (even, odd) pair of floats, so at least
// M * (num_q_heads + num_kv_heads) * (head_dim / 2) threads are required.
kernel void rope_qk_batch(
    device float* qkv_data         [[buffer(0)]],  // [M, qkv_stride]
    constant uint& M               [[buffer(1)]],  // num tokens
    constant uint& base_pos        [[buffer(2)]],  // starting position
    constant uint& num_q_heads     [[buffer(3)]],
    constant uint& num_kv_heads    [[buffer(4)]],
    constant uint& head_dim        [[buffer(5)]],
    constant uint& qkv_stride      [[buffer(6)]],  // floats per row
    constant float& theta          [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = (num_q_heads + num_kv_heads) * pairs_per_head;

    const uint token = id / pairs_per_token;
    if (token >= M) return;

    // Number the heads Q-first, then K — the same order they occupy inside a
    // row — so a single expression addresses both: K head j is slot
    // num_q_heads + j and therefore starts at (num_q_heads + j) * head_dim.
    const uint pair = id % pairs_per_token;
    const uint head_slot = pair / pairs_per_head;
    const uint i = pair % pairs_per_head;
    const uint offset = token * qkv_stride + head_slot * head_dim + i * 2;

    // Rotation angle for absolute position (base_pos + token) and pair i.
    const float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    const float angle = float(base_pos + token) * freq;
    const float cos_val = cos(angle);
    const float sin_val = sin(angle);

    // Read both halves of the pair before overwriting either one.
    const float even = qkv_data[offset];
    const float odd = qkv_data[offset + 1];
    qkv_data[offset] = even * cos_val - odd * sin_val;
    qkv_data[offset + 1] = even * sin_val + odd * cos_val;
}
2996
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Moves the K and the V slices of a prefill batch into their f16 caches
// with one dispatch, instead of two copy_kv_batch launches (one kernel
// launch + memory barrier fewer per layer).
// The grid is split in half: the first M * kv_dim threads copy K, the
// next M * kv_dim threads copy V; any threads beyond that do nothing.
kernel void copy_kv_both_batch(
    device const float* src     [[buffer(0)]],  // batch QKV buffer [M, qkv_stride] f32
    device half* k_dst          [[buffer(1)]],  // K cache [max_seq, kv_dim] f16
    device half* v_dst          [[buffer(2)]],  // V cache [max_seq, kv_dim] f16
    constant uint& M            [[buffer(3)]],  // num tokens in batch
    constant uint& kv_dim       [[buffer(4)]],  // elements per KV vector
    constant uint& base_pos     [[buffer(5)]],  // starting position in cache
    constant uint& src_stride   [[buffer(6)]],  // floats per row in src (qkv_stride)
    constant uint& k_offset     [[buffer(7)]],  // float offset of K within each src row
    constant uint& v_offset     [[buffer(8)]],  // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    const uint per_cache = M * kv_dim;  // elements written to each cache
    if (id >= per_cache * 2) return;

    // Lower half of the id range services K, upper half services V.
    const bool to_v = id >= per_cache;
    const uint elem = to_v ? (id - per_cache) : id;
    const uint row = elem / kv_dim;  // token index within the batch
    const uint col = elem % kv_dim;  // element index within the KV vector

    const uint from = row * src_stride + (to_v ? v_offset : k_offset) + col;
    const uint to = (base_pos + row) * kv_dim + col;

    // Narrow f32 → f16 on the way into whichever cache this thread targets.
    device half* cache = to_v ? v_dst : k_dst;
    cache[to] = half(src[from]);
}
3031"#
3032 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
3033 .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
3034}
3035
3036fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
3041 let mut code = String::with_capacity(48 * 1024);
3042 emit_model_header(&mut code, config)?;
3043 emit_metal_model_struct(&mut code, config)?;
3044 emit_layer_buffers_struct(&mut code, config)?;
3045 emit_metal_model_impl(&mut code, config)?;
3046 emit_helper_functions(&mut code)?;
3047 Ok(code)
3048}
3049
3050fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3051 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
3052 writeln!(
3053 code,
3054 "//! Model: {} ({} layers, hidden={})",
3055 config.architecture, config.num_layers, config.hidden_size
3056 )?;
3057 writeln!(code, "//!")?;
3058 writeln!(
3059 code,
3060 "//! Uses native Metal compute pipelines via the metal crate."
3061 )?;
3062 writeln!(code)?;
3063 writeln!(code, "#![allow(dead_code)]")?;
3064 writeln!(code)?;
3065 writeln!(code, "use metal::*;")?;
3066 writeln!(code, "#[allow(unused_imports)]")?;
3067 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
3068 writeln!(code, "use std::mem;")?;
3069 writeln!(code)?;
3070
3071 writeln!(
3073 code,
3074 "// ── Model constants ──────────────────────────────────"
3075 )?;
3076 writeln!(
3077 code,
3078 "pub const HIDDEN_SIZE: usize = {};",
3079 config.hidden_size
3080 )?;
3081 writeln!(
3082 code,
3083 "pub const INTERMEDIATE_SIZE: usize = {};",
3084 config.intermediate_size
3085 )?;
3086 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
3087 writeln!(
3088 code,
3089 "pub const NUM_HEADS: usize = {};",
3090 config.num_attention_heads
3091 )?;
3092 writeln!(
3093 code,
3094 "pub const NUM_KV_HEADS: usize = {};",
3095 config.num_kv_heads
3096 )?;
3097 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3098 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3099 let effective_seq_len = config.max_seq_len.min(4096);
3100 writeln!(
3101 code,
3102 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
3103 effective_seq_len, config.max_seq_len
3104 )?;
3105 writeln!(
3106 code,
3107 "pub const RMS_NORM_EPS: f32 = {:e};",
3108 config.rms_norm_eps
3109 )?;
3110 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3111 writeln!(
3112 code,
3113 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3114 )?;
3115 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3116 writeln!(code)?;
3117
3118 Ok(())
3119}
3120
3121fn emit_metal_model_struct(
3122 code: &mut String,
3123 config: &ModelConfig,
3124) -> Result<(), MetalCodegenError> {
3125 writeln!(
3126 code,
3127 "// ── MetalModel ──────────────────────────────────────────"
3128 )?;
3129 writeln!(code)?;
3130 writeln!(
3131 code,
3132 "/// Metal-accelerated transformer model for Apple Silicon."
3133 )?;
3134 writeln!(code, "///")?;
3135 writeln!(
3136 code,
3137 "/// Uses unified memory for zero-copy weight access and native Metal"
3138 )?;
3139 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3140 writeln!(code, "pub struct MetalModel {{")?;
3141 writeln!(code, " device: Device,")?;
3142 writeln!(code, " queue: CommandQueue,")?;
3143 writeln!(code)?;
3144 writeln!(code, " // ── Compute pipelines ──")?;
3145 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
3146 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
3147 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
3148 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
3149 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
3150 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
3151 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
3152 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
3153 writeln!(code, " add_pipeline: ComputePipelineState,")?;
3154 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
3155 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
3156 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
3157 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
3158 writeln!(
3159 code,
3160 " copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3161 )?;
3162 writeln!(code)?;
3163 writeln!(code, " // ── Batched prefill pipelines ──")?;
3164 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
3165 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
3166 writeln!(
3167 code,
3168 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3169 )?;
3170 writeln!(code, " matmul_q8_mma_pipeline: ComputePipelineState,")?;
3171 writeln!(code, " matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3172 writeln!(
3173 code,
3174 " matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3175 )?;
3176 writeln!(
3177 code,
3178 " matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3179 )?;
3180 writeln!(
3181 code,
3182 " matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3183 )?;
3184 if config.qkv_bias {
3185 writeln!(code, " add_bias_batch_pipeline: ComputePipelineState,")?;
3186 }
3187 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
3188 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
3189 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
3190 writeln!(
3191 code,
3192 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
3193 )?;
3194 writeln!(
3195 code,
3196 " add_inplace_batch_pipeline: ComputePipelineState,"
3197 )?;
3198 writeln!(
3199 code,
3200 " copy_embedding_batch_pipeline: ComputePipelineState,\n scale_buffer_pipeline: ComputePipelineState,"
3201 )?;
3202 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
3203 writeln!(
3204 code,
3205 " attention_flash_batch_pipeline: ComputePipelineState,"
3206 )?;
3207 writeln!(
3208 code,
3209 " attention_mma_flash_batch_pipeline: ComputePipelineState,"
3210 )?;
3211 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
3212 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
3213 writeln!(
3214 code,
3215 " copy_kv_both_batch_pipeline: ComputePipelineState,"
3216 )?;
3217 writeln!(code)?;
3218 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
3219 writeln!(
3220 code,
3221 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3222 )?;
3223 writeln!(code, " embed_buf: Buffer,")?;
3224 writeln!(code)?;
3225 writeln!(code, " /// Per-layer weight buffers")?;
3226 writeln!(code, " layers: Vec<LayerBuffers>,")?;
3227 writeln!(code)?;
3228 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
3229 writeln!(code, " norm_buf: Buffer,")?;
3230 writeln!(code)?;
3231 writeln!(
3232 code,
3233 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3234 )?;
3235 writeln!(code, " lm_head_buf: Buffer,")?;
3236 writeln!(code)?;
3237 writeln!(
3238 code,
3239 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
3240 )?;
3241 writeln!(code, " hidden_buf: Buffer,")?;
3242 writeln!(code, " residual_buf: Buffer,")?;
3243 writeln!(code, " normed_buf: Buffer,")?;
3244 writeln!(
3245 code,
3246 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3247 )?;
3248 writeln!(code, " qkv_buf: Buffer,")?;
3249 writeln!(code, " attn_out_buf: Buffer,")?;
3250 writeln!(code, " attn_proj_buf: Buffer,")?;
3251 writeln!(
3252 code,
3253 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3254 )?;
3255 writeln!(code, " gate_up_buf: Buffer,")?;
3256 writeln!(code, " ffn_hidden_buf: Buffer,")?;
3257 writeln!(code, " ffn_out_buf: Buffer,")?;
3258 writeln!(code, " add_tmp_buf: Buffer,")?;
3259 writeln!(code, " logits_buf: Buffer,")?;
3260 writeln!(code)?;
3261 writeln!(code, " // ── Batched prefill working buffers ──")?;
3262 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3263 writeln!(code, " batch_hidden_buf: Buffer,")?;
3264 writeln!(
3265 code,
3266 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3267 )?;
3268 writeln!(code, " batch_residual_buf: Buffer,")?;
3269 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3270 writeln!(code, " batch_qkv_buf: Buffer,")?;
3271 writeln!(
3272 code,
3273 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3274 )?;
3275 writeln!(code, " batch_attn_out_buf: Buffer,")?;
3276 writeln!(
3277 code,
3278 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3279 )?;
3280 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
3281 writeln!(
3282 code,
3283 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3284 )?;
3285 writeln!(code, " batch_gate_up_buf: Buffer,")?;
3286 writeln!(
3287 code,
3288 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3289 )?;
3290 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
3291 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3292 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
3293 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
3294 writeln!(code, " batch_tokens_buf: Buffer,")?;
3295 writeln!(code, " /// Positions buffer for batch RoPE")?;
3296 writeln!(code, " batch_positions_buf: Buffer,")?;
3297 writeln!(code)?;
3298 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
3299 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
3300 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
3301 writeln!(code)?;
3302 writeln!(code, " // ── Inference state ──")?;
3303 writeln!(code, " pos: usize,")?;
3304 writeln!(code)?;
3305 writeln!(
3306 code,
3307 " /// Previous command buffer for double-buffered prefill."
3308 )?;
3309 writeln!(
3310 code,
3311 " /// While the GPU executes token N, the CPU can encode token N+1."
3312 )?;
3313 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
3314 writeln!(code, "}}")?;
3315 writeln!(code)?;
3316
3317 Ok(())
3318}
3319
3320fn emit_layer_buffers_struct(
3321 code: &mut String,
3322 config: &ModelConfig,
3323) -> Result<(), MetalCodegenError> {
3324 writeln!(
3325 code,
3326 "/// Per-layer weight buffers for attention and FFN projections."
3327 )?;
3328 writeln!(code, "struct LayerBuffers {{")?;
3329 writeln!(code, " attn_norm: Buffer,")?;
3330 writeln!(
3331 code,
3332 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3333 )?;
3334 writeln!(code, " qkv_weight: Buffer,")?;
3335 if config.qkv_bias {
3336 writeln!(
3337 code,
3338 " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3339 )?;
3340 writeln!(code, " qkv_bias: Buffer,")?;
3341 }
3342 writeln!(code, " o_weight: Buffer,")?;
3343 writeln!(code, " ffn_norm: Buffer,")?;
3344 writeln!(
3345 code,
3346 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3347 )?;
3348 writeln!(code, " gate_up_weight: Buffer,")?;
3349 writeln!(code, " down_weight: Buffer,")?;
3350 writeln!(code, "}}")?;
3351 writeln!(code)?;
3352
3353 Ok(())
3354}
3355
3356fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3357 let hidden = config.hidden_size;
3358 let intermediate = config.intermediate_size;
3359 let _num_layers = config.num_layers;
3360 let _num_heads = config.num_attention_heads;
3361 let num_kv_heads = config.num_kv_heads;
3362 let head_dim = config.head_dim;
3363 let vocab = config.vocab_size;
3364 let effective_seq_len = config.max_seq_len.min(4096);
3365 let is_q8 = config.dtype == DType::Q8_0;
3366 let is_q4 = config.dtype == DType::Q4_0;
3367 let kv_dim = num_kv_heads * head_dim;
3368
3369 writeln!(code, "impl MetalModel {{")?;
3370
3371 writeln!(
3373 code,
3374 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3375 )?;
3376 writeln!(code, " ///")?;
3377 writeln!(
3378 code,
3379 " /// `weights` is the raw weight blob produced by `forge export-weights`."
3380 )?;
3381 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
3382 writeln!(
3383 code,
3384 " let device = Device::system_default().expect(\"no Metal device found\");"
3385 )?;
3386 writeln!(code, " let queue = device.new_command_queue();")?;
3387 writeln!(code)?;
3388
3389 writeln!(
3391 code,
3392 " // Compile Metal shaders from embedded source"
3393 )?;
3394 writeln!(
3395 code,
3396 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
3397 )?;
3398 writeln!(
3399 code,
3400 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3401 )?;
3402 writeln!(
3403 code,
3404 " .expect(\"failed to compile Metal shaders\");"
3405 )?;
3406 writeln!(code)?;
3407
3408 let activation_kernel = match config.hidden_activation {
3414 HiddenActivation::SiLU => "silu_mul_fused",
3415 HiddenActivation::GeluApprox => "gelu_mul_fused",
3416 };
3417 let activation_kernel_batch = match config.hidden_activation {
3418 HiddenActivation::SiLU => "silu_mul_fused_batch",
3419 HiddenActivation::GeluApprox => "gelu_mul_fused_batch",
3420 };
3421 writeln!(code, " // Create compute pipelines")?;
3422 for (var, fn_name) in [
3423 ("matmul_pipeline", "matmul_vec"),
3424 ("matmul_q8_pipeline", "matmul_vec_q8"),
3425 ("matmul_q4_pipeline", "matmul_vec_q4"),
3426 ("rms_norm_pipeline", "rms_norm"),
3427 ("rope_pipeline", "rope"),
3428 ("softmax_pipeline", "softmax"),
3429 ("silu_mul_pipeline", "silu_mul"),
3430 ("silu_mul_fused_pipeline", activation_kernel),
3431 ("add_pipeline", "elementwise_add"),
3432 ("attention_pipeline", "attention"),
3433 ("add_inplace_pipeline", "add_inplace"),
3434 ("copy_pipeline", "copy_buffer"),
3435 ("copy_offset_pipeline", "copy_offset"),
3436 ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3437 ("matmul_batch_pipeline", "matmul_vec_batch"),
3438 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3439 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3440 ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3441 ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3442 ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3443 ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3444 ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3445 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3446 ("rms_norm_batch_pipeline", "rms_norm_batch"),
3447 ("rope_batch_pipeline", "rope_batch"),
3448 ("silu_mul_fused_batch_pipeline", activation_kernel_batch),
3449 ("add_inplace_batch_pipeline", "add_inplace_batch"),
3450 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3451 ("scale_buffer_pipeline", "scale_buffer"),
3452 ("attention_batch_pipeline", "attention_batch"),
3453 ("attention_flash_batch_pipeline", "attention_flash_batch"),
3454 (
3455 "attention_mma_flash_batch_pipeline",
3456 "attention_mma_flash_batch",
3457 ),
3458 ("copy_kv_batch_pipeline", "copy_kv_batch"),
3459 ("rope_qk_batch_pipeline", "rope_qk_batch"),
3460 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3461 ] {
3462 writeln!(
3463 code,
3464 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3465 )?;
3466 }
3467 if config.qkv_bias {
3468 writeln!(
3469 code,
3470 " let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3471 )?;
3472 }
3473 writeln!(code)?;
3474
3475 writeln!(
3477 code,
3478 " // Load weights into Metal shared-memory buffers"
3479 )?;
3480 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
3481 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3482 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
3483 writeln!(code)?;
3484 writeln!(
3485 code,
3486 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
3487 )?;
3488 writeln!(code)?;
3489 writeln!(
3490 code,
3491 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3492 )?;
3493 writeln!(
3494 code,
3495 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3496 )?;
3497 writeln!(code, " let byte_len = n * f32_size;")?;
3498 writeln!(code, " let cur = cursor.get();")?;
3499 writeln!(
3500 code,
3501 " let data = &weights[cur..cur + byte_len];"
3502 )?;
3503 writeln!(code, " cursor.set(cur + byte_len);")?;
3504 writeln!(code, " device.new_buffer_with_data(")?;
3505 writeln!(code, " data.as_ptr() as *const _,")?;
3506 writeln!(code, " byte_len as u64,")?;
3507 writeln!(
3508 code,
3509 " MTLResourceOptions::StorageModeShared,"
3510 )?;
3511 writeln!(code, " )")?;
3512 writeln!(code, " }};")?;
3513 writeln!(code)?;
3514
3515 if is_q8 {
3516 writeln!(
3521 code,
3522 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3523 )?;
3524 writeln!(
3525 code,
3526 " // as raw bytes into a Metal buffer (no dequantization)."
3527 )?;
3528 writeln!(
3529 code,
3530 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3531 )?;
3532 writeln!(
3533 code,
3534 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3535 )?;
3536 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3537 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3538 writeln!(code, " let total_raw = rows * row_bytes;")?;
3539 writeln!(code, " let cur = cursor.get();")?;
3540 writeln!(
3541 code,
3542 " let data = &weights[cur..cur + total_raw];"
3543 )?;
3544 writeln!(code, " cursor.set(cur + total_raw);")?;
3545 writeln!(code, " device.new_buffer_with_data(")?;
3546 writeln!(code, " data.as_ptr() as *const _,")?;
3547 writeln!(code, " total_raw as u64,")?;
3548 writeln!(
3549 code,
3550 " MTLResourceOptions::StorageModeShared,"
3551 )?;
3552 writeln!(code, " )")?;
3553 writeln!(code, " }};")?;
3554 writeln!(code)?;
3555 writeln!(
3556 code,
3557 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3558 )?;
3559 writeln!(
3560 code,
3561 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3562 )?;
3563 writeln!(
3564 code,
3565 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3566 )?;
3567 writeln!(
3568 code,
3569 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3570 )?;
3571 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3572 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3573 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3574 writeln!(code, " let cur = cursor.get();")?;
3575 writeln!(
3576 code,
3577 " let data = &weights[cur..cur + total_raw];"
3578 )?;
3579 writeln!(code, " cursor.set(cur + total_raw);")?;
3580 writeln!(code, " device.new_buffer_with_data(")?;
3581 writeln!(code, " data.as_ptr() as *const _,")?;
3582 writeln!(code, " total_raw as u64,")?;
3583 writeln!(
3584 code,
3585 " MTLResourceOptions::StorageModeShared,"
3586 )?;
3587 writeln!(code, " )")?;
3588 writeln!(code, " }};")?;
3589 writeln!(code)?;
3590 }
3591
3592 if is_q4 {
3593 writeln!(
3598 code,
3599 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3600 )?;
3601 writeln!(
3602 code,
3603 " // as raw bytes into a Metal buffer (no dequantization)."
3604 )?;
3605 writeln!(
3606 code,
3607 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3608 )?;
3609 writeln!(
3610 code,
3611 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3612 )?;
3613 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3614 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3615 writeln!(code, " let total_raw = rows * row_bytes;")?;
3616 writeln!(code, " let cur = cursor.get();")?;
3617 writeln!(
3618 code,
3619 " let data = &weights[cur..cur + total_raw];"
3620 )?;
3621 writeln!(code, " cursor.set(cur + total_raw);")?;
3622 writeln!(code, " device.new_buffer_with_data(")?;
3623 writeln!(code, " data.as_ptr() as *const _,")?;
3624 writeln!(code, " total_raw as u64,")?;
3625 writeln!(
3626 code,
3627 " MTLResourceOptions::StorageModeShared,"
3628 )?;
3629 writeln!(code, " )")?;
3630 writeln!(code, " }};")?;
3631 writeln!(code)?;
3632 writeln!(
3633 code,
3634 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3635 )?;
3636 writeln!(
3637 code,
3638 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3639 )?;
3640 writeln!(
3641 code,
3642 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3643 )?;
3644 writeln!(
3645 code,
3646 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3647 )?;
3648 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3649 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3650 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3651 writeln!(code, " let cur = cursor.get();")?;
3652 writeln!(
3653 code,
3654 " let data = &weights[cur..cur + total_raw];"
3655 )?;
3656 writeln!(code, " cursor.set(cur + total_raw);")?;
3657 writeln!(code, " device.new_buffer_with_data(")?;
3658 writeln!(code, " data.as_ptr() as *const _,")?;
3659 writeln!(code, " total_raw as u64,")?;
3660 writeln!(
3661 code,
3662 " MTLResourceOptions::StorageModeShared,"
3663 )?;
3664 writeln!(code, " )")?;
3665 writeln!(code, " }};")?;
3666 writeln!(code)?;
3667 }
3668
3669 writeln!(
3670 code,
3671 " let embed_buf = next_f32_buffer(&device, embed_elems);"
3672 )?;
3673 writeln!(code)?;
3674
3675 writeln!(
3677 code,
3678 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3679 )?;
3680 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
3681
3682 writeln!(
3684 code,
3685 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
3686 )?;
3687
3688 let qkv_rows = hidden + 2 * kv_dim;
3689 if is_q8 {
3690 writeln!(
3692 code,
3693 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3694 )?;
3695 if config.qkv_bias {
3696 writeln!(
3697 code,
3698 " // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3699 )?;
3700 writeln!(
3701 code,
3702 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3703 )?;
3704 }
3705 writeln!(
3706 code,
3707 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3708 )?;
3709 } else if is_q4 {
3710 writeln!(
3711 code,
3712 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3713 )?;
3714 if config.qkv_bias {
3715 writeln!(
3716 code,
3717 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3718 )?;
3719 }
3720 writeln!(
3721 code,
3722 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3723 )?;
3724 } else {
3725 writeln!(
3726 code,
3727 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3728 )?;
3729 if config.qkv_bias {
3730 writeln!(
3731 code,
3732 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3733 )?;
3734 }
3735 writeln!(
3736 code,
3737 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3738 )?;
3739 }
3740
3741 writeln!(
3743 code,
3744 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3745 )?;
3746
3747 let gate_up_rows = 2 * intermediate;
3748 if is_q8 {
3749 writeln!(
3751 code,
3752 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3753 )?;
3754 writeln!(
3755 code,
3756 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3757 )?;
3758 } else if is_q4 {
3759 writeln!(
3761 code,
3762 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3763 )?;
3764 writeln!(
3765 code,
3766 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3767 )?;
3768 } else {
3769 writeln!(
3771 code,
3772 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3773 )?;
3774 writeln!(
3775 code,
3776 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3777 )?;
3778 }
3779
3780 writeln!(code, " layers.push(LayerBuffers {{")?;
3781 writeln!(code, " attn_norm,")?;
3782 writeln!(code, " qkv_weight,")?;
3783 if config.qkv_bias {
3784 writeln!(code, " qkv_bias,")?;
3785 }
3786 writeln!(code, " o_weight,")?;
3787 writeln!(code, " ffn_norm,")?;
3788 writeln!(code, " gate_up_weight,")?;
3789 writeln!(code, " down_weight,")?;
3790 writeln!(code, " }});")?;
3791 writeln!(code, " }}")?;
3792 writeln!(code)?;
3793
3794 writeln!(
3796 code,
3797 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
3798 )?;
3799 writeln!(code)?;
3800
3801 if is_q8 {
3803 writeln!(
3804 code,
3805 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3806 )?;
3807 } else if is_q4 {
3808 writeln!(
3809 code,
3810 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3811 )?;
3812 } else {
3813 writeln!(
3814 code,
3815 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3816 )?;
3817 }
3818 writeln!(code)?;
3819
3820 let hidden_bytes = hidden * 4;
3822 let _kv_dim_bytes = kv_dim * 4;
3823 let intermediate_bytes = intermediate * 4;
3824 let vocab_bytes = vocab * 4;
3825 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3828
3829 writeln!(
3830 code,
3831 " // Allocate working buffers (shared memory for zero-copy)"
3832 )?;
3833 writeln!(
3834 code,
3835 " let opts = MTLResourceOptions::StorageModeShared;"
3836 )?;
3837 writeln!(
3838 code,
3839 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3840 )?;
3841 writeln!(
3842 code,
3843 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3844 )?;
3845 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3846 writeln!(
3847 code,
3848 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3849 )?;
3850 writeln!(
3851 code,
3852 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3853 )?;
3854 writeln!(
3855 code,
3856 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3857 )?;
3858 writeln!(
3859 code,
3860 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3861 )?;
3862 writeln!(
3863 code,
3864 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3865 )?;
3866 let gate_up_buf_bytes = 2 * intermediate * 4;
3867 writeln!(
3868 code,
3869 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3870 )?;
3871 writeln!(
3872 code,
3873 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3874 )?;
3875 writeln!(
3876 code,
3877 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3878 )?;
3879 writeln!(
3880 code,
3881 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3882 )?;
3883 writeln!(
3884 code,
3885 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3886 )?;
3887 writeln!(
3888 code,
3889 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3890 )?;
3891 writeln!(code)?;
3892
3893 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3896 let batch_gate_up_bytes = 2 * intermediate * 4;
3897 let batch_intermediate_bytes = intermediate * 4;
3898 writeln!(
3899 code,
3900 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3901 )?;
3902 writeln!(
3903 code,
3904 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3905 )?;
3906 writeln!(
3907 code,
3908 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3909 )?;
3910 writeln!(
3911 code,
3912 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3913 )?;
3914 writeln!(
3915 code,
3916 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3917 )?;
3918 writeln!(
3919 code,
3920 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3921 )?;
3922 writeln!(
3923 code,
3924 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3925 )?;
3926 writeln!(
3927 code,
3928 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3929 )?;
3930 writeln!(
3931 code,
3932 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3933 )?;
3934 writeln!(
3935 code,
3936 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3937 )?;
3938 writeln!(
3939 code,
3940 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3941 )?;
3942 writeln!(code)?;
3943
3944 writeln!(code, " // KV cache buffers (per-layer)")?;
3946 writeln!(
3947 code,
3948 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3949 )?;
3950 writeln!(
3951 code,
3952 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3953 )?;
3954 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
3955 writeln!(
3956 code,
3957 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3958 )?;
3959 writeln!(
3960 code,
3961 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3962 )?;
3963 writeln!(code, " }}")?;
3964 writeln!(code)?;
3965
3966 writeln!(code, " Self {{")?;
3967 writeln!(code, " device,")?;
3968 writeln!(code, " queue,")?;
3969 writeln!(code, " matmul_pipeline,")?;
3970 writeln!(code, " matmul_q8_pipeline,")?;
3971 writeln!(code, " matmul_q4_pipeline,")?;
3972 writeln!(code, " rms_norm_pipeline,")?;
3973 writeln!(code, " rope_pipeline,")?;
3974 writeln!(code, " softmax_pipeline,")?;
3975 writeln!(code, " silu_mul_pipeline,")?;
3976 writeln!(code, " silu_mul_fused_pipeline,")?;
3977 writeln!(code, " add_pipeline,")?;
3978 writeln!(code, " attention_pipeline,")?;
3979 writeln!(code, " add_inplace_pipeline,")?;
3980 writeln!(code, " copy_pipeline,")?;
3981 writeln!(code, " copy_offset_pipeline,")?;
3982 writeln!(code, " copy_f32_to_f16_offset_pipeline,")?;
3983 writeln!(code, " matmul_batch_pipeline,")?;
3984 writeln!(code, " matmul_q8_batch_pipeline,")?;
3985 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
3986 writeln!(code, " matmul_q8_mma_pipeline,")?;
3987 writeln!(code, " matmul_q8_mma32_pipeline,")?;
3988 writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
3989 writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
3990 writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
3991 if config.qkv_bias {
3992 writeln!(code, " add_bias_batch_pipeline,")?;
3993 }
3994 writeln!(code, " matmul_q4_batch_pipeline,")?;
3995 writeln!(code, " rms_norm_batch_pipeline,")?;
3996 writeln!(code, " rope_batch_pipeline,")?;
3997 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
3998 writeln!(code, " add_inplace_batch_pipeline,")?;
3999 writeln!(code, " copy_embedding_batch_pipeline,")?;
4000 writeln!(code, " scale_buffer_pipeline,")?;
4001 writeln!(code, " attention_batch_pipeline,")?;
4002 writeln!(code, " attention_flash_batch_pipeline,")?;
4003 writeln!(code, " attention_mma_flash_batch_pipeline,")?;
4004 writeln!(code, " copy_kv_batch_pipeline,")?;
4005 writeln!(code, " rope_qk_batch_pipeline,")?;
4006 writeln!(code, " copy_kv_both_batch_pipeline,")?;
4007 writeln!(code, " embed_buf,")?;
4008 writeln!(code, " layers,")?;
4009 writeln!(code, " norm_buf,")?;
4010 writeln!(code, " lm_head_buf,")?;
4011 writeln!(code, " hidden_buf,")?;
4012 writeln!(code, " residual_buf,")?;
4013 writeln!(code, " normed_buf,")?;
4014 writeln!(code, " qkv_buf,")?;
4015 writeln!(code, " attn_out_buf,")?;
4016 writeln!(code, " attn_proj_buf,")?;
4017 writeln!(code, " gate_up_buf,")?;
4018 writeln!(code, " ffn_hidden_buf,")?;
4019 writeln!(code, " ffn_out_buf,")?;
4020 writeln!(code, " add_tmp_buf,")?;
4021 writeln!(code, " logits_buf,")?;
4022 writeln!(code, " batch_hidden_buf,")?;
4023 writeln!(code, " batch_residual_buf,")?;
4024 writeln!(code, " batch_qkv_buf,")?;
4025 writeln!(code, " batch_attn_out_buf,")?;
4026 writeln!(code, " batch_attn_proj_buf,")?;
4027 writeln!(code, " batch_gate_up_buf,")?;
4028 writeln!(code, " batch_ffn_hidden_buf,")?;
4029 writeln!(code, " batch_ffn_out_buf,")?;
4030 writeln!(code, " batch_tokens_buf,")?;
4031 writeln!(code, " batch_positions_buf,")?;
4032 writeln!(code, " k_cache,")?;
4033 writeln!(code, " v_cache,")?;
4034 writeln!(code, " pos: 0,")?;
4035 writeln!(code, " prev_cmd: None,")?;
4036 writeln!(code, " }}")?;
4037 writeln!(code, " }}")?;
4038 writeln!(code)?;
4039
4040 writeln!(
4042 code,
4043 " /// Run the forward pass for a single token at the current position."
4044 )?;
4045 writeln!(code, " ///")?;
4046 writeln!(
4047 code,
4048 " /// Returns logits over the vocabulary as a `Vec<f32>`."
4049 )?;
4050 writeln!(code, " ///")?;
4051 writeln!(
4052 code,
4053 " /// All GPU operations are encoded into a single command buffer and"
4054 )?;
4055 writeln!(
4056 code,
4057 " /// committed once at the end, avoiding per-operation synchronization."
4058 )?;
4059 writeln!(
4060 code,
4061 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
4062 )?;
4063 writeln!(
4064 code,
4065 " // Wait for any pending prefill command buffer"
4066 )?;
4067 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4068 writeln!(code, " prev.wait_until_completed();")?;
4069 writeln!(code, " }}")?;
4070 writeln!(code)?;
4071 writeln!(code, " let pos = self.pos;")?;
4072 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4073 writeln!(code)?;
4074
4075 let matmul_fn = if is_q8 {
4078 "dispatch_matmul_q8"
4079 } else if is_q4 {
4080 "dispatch_matmul_q4"
4081 } else {
4082 "dispatch_matmul"
4083 };
4084
4085 writeln!(
4086 code,
4087 " // Single compute encoder for the entire forward pass (no blit transitions)"
4088 )?;
4089 writeln!(code, " {{")?;
4090 writeln!(
4091 code,
4092 " let enc = cmd.new_compute_command_encoder();"
4093 )?;
4094 writeln!(code)?;
4095
4096 writeln!(
4098 code,
4099 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
4100 )?;
4101 writeln!(
4102 code,
4103 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
4104 )?;
4105 writeln!(
4106 code,
4107 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
4108 )?;
4109 writeln!(
4110 code,
4111 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4112 hidden * 4,
4113 )?;
4114 writeln!(code, " unsafe {{")?;
4115 writeln!(
4116 code,
4117 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4118 )?;
4119 writeln!(
4120 code,
4121 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4122 )?;
4123 writeln!(
4124 code,
4125 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4126 )?;
4127 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4128 writeln!(
4129 code,
4130 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4131 )?;
4132 writeln!(code, " hidden_ptr,")?;
4133 writeln!(code, " HIDDEN_SIZE,")?;
4134 writeln!(code, " );")?;
4135 if matches!(config.hidden_activation, HiddenActivation::GeluApprox) {
4136 writeln!(
4141 code,
4142 " const GEMMA_EMBED_SCALE: f32 = {scale:e}_f32;",
4143 scale = (hidden as f32).sqrt(),
4144 )?;
4145 writeln!(code, " for i in 0..HIDDEN_SIZE {{")?;
4146 writeln!(
4147 code,
4148 " *hidden_ptr.add(i) *= GEMMA_EMBED_SCALE;"
4149 )?;
4150 writeln!(code, " }}")?;
4151 }
4152 writeln!(
4153 code,
4154 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4155 )?;
4156 writeln!(code, " }}")?;
4157 writeln!(code)?;
4158
4159 writeln!(code, " // 2. Transformer layers")?;
4161 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4162 writeln!(code)?;
4163 let q_byte_offset = 0usize;
4164 let k_float_offset = hidden;
4165 let v_float_offset = hidden + kv_dim;
4166
4167 writeln!(
4168 code,
4169 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
4170 )?;
4171 writeln!(
4172 code,
4173 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4174 )?;
4175 writeln!(
4176 code,
4177 " // Fused Q+K+V matmul: single dispatch for all three projections"
4178 )?;
4179 writeln!(
4180 code,
4181 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4182 )?;
4183 if config.qkv_bias {
4184 writeln!(
4185 code,
4186 " // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4187 )?;
4188 writeln!(
4189 code,
4190 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4191 )?;
4192 }
4193 writeln!(
4194 code,
4195 " // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4196 )?;
4197 writeln!(
4198 code,
4199 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4200 )?;
4201 writeln!(code)?;
4202 writeln!(
4203 code,
4204 " // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4205 )?;
4206 writeln!(
4207 code,
4208 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4209 )?;
4210 writeln!(code)?;
4211 writeln!(
4212 code,
4213 " // Attention using Q from qkv_buf (offset 0)"
4214 )?;
4215 writeln!(
4216 code,
4217 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4218 )?;
4219 writeln!(
4220 code,
4221 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4222 )?;
4223 writeln!(
4224 code,
4225 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4226 )?;
4227 writeln!(
4228 code,
4229 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4230 )?;
4231 writeln!(
4232 code,
4233 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4234 )?;
4235 writeln!(
4236 code,
4237 " // Fused gate+up matmul: single dispatch for both projections"
4238 )?;
4239 writeln!(
4240 code,
4241 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4242 )?;
4243 writeln!(
4244 code,
4245 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4246 )?;
4247 writeln!(
4248 code,
4249 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4250 )?;
4251 writeln!(
4252 code,
4253 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4254 )?;
4255 writeln!(code, " }}")?;
4256 writeln!(code)?;
4257
4258 writeln!(code, " // 3. Final RMS norm + logits projection")?;
4260 writeln!(
4261 code,
4262 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4263 )?;
4264 writeln!(
4265 code,
4266 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4267 )?;
4268 writeln!(code)?;
4269 writeln!(code, " enc.end_encoding();")?;
4270 writeln!(code, " }}")?;
4271 writeln!(code)?;
4272
4273 writeln!(
4275 code,
4276 " // 5. Commit all GPU work and wait for completion"
4277 )?;
4278 writeln!(code, " cmd.commit();")?;
4279 writeln!(code, " cmd.wait_until_completed();")?;
4280 writeln!(code)?;
4281 writeln!(code, " // 6. Read back logits from GPU")?;
4282 writeln!(code, " let logits = unsafe {{")?;
4283 writeln!(
4284 code,
4285 " let ptr = self.logits_buf.contents() as *const f32;"
4286 )?;
4287 writeln!(
4288 code,
4289 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4290 )?;
4291 writeln!(code, " }};")?;
4292 writeln!(code)?;
4293 writeln!(code, " self.pos += 1;")?;
4294 writeln!(code, " logits")?;
4295 writeln!(code, " }}")?;
4296 writeln!(code)?;
4297
4298 writeln!(
4300 code,
4301 " /// Profiling forward pass that prints per-stage GPU timing."
4302 )?;
4303 writeln!(code, " ///")?;
4304 writeln!(
4305 code,
4306 " /// Each stage is committed and waited on separately so that GPU timestamps"
4307 )?;
4308 writeln!(
4309 code,
4310 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
4311 )?;
4312 writeln!(
4313 code,
4314 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
4315 )?;
4316 writeln!(
4317 code,
4318 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4319 )?;
4320 writeln!(code, " use std::time::Instant;")?;
4321 writeln!(code)?;
4322 writeln!(
4323 code,
4324 " // Wait for any pending prefill command buffer"
4325 )?;
4326 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4327 writeln!(code, " prev.wait_until_completed();")?;
4328 writeln!(code, " }}")?;
4329 writeln!(code)?;
4330 writeln!(code, " let pos = self.pos;")?;
4331 writeln!(code)?;
4332
4333 writeln!(
4335 code,
4336 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
4337 )?;
4338 writeln!(code, " let t_embed = Instant::now();")?;
4339 writeln!(code, " unsafe {{")?;
4340 writeln!(
4341 code,
4342 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4343 )?;
4344 writeln!(
4345 code,
4346 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4347 )?;
4348 writeln!(
4349 code,
4350 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4351 )?;
4352 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4353 writeln!(
4354 code,
4355 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4356 )?;
4357 writeln!(code, " hidden_ptr,")?;
4358 writeln!(code, " HIDDEN_SIZE,")?;
4359 writeln!(code, " );")?;
4360 if matches!(config.hidden_activation, HiddenActivation::GeluApprox) {
4361 writeln!(
4362 code,
4363 " const GEMMA_EMBED_SCALE: f32 = {scale:e}_f32;",
4364 scale = (hidden as f32).sqrt(),
4365 )?;
4366 writeln!(code, " for i in 0..HIDDEN_SIZE {{")?;
4367 writeln!(
4368 code,
4369 " *hidden_ptr.add(i) *= GEMMA_EMBED_SCALE;"
4370 )?;
4371 writeln!(code, " }}")?;
4372 }
4373 writeln!(
4374 code,
4375 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4376 )?;
4377 writeln!(code, " }}")?;
4378 writeln!(code, " let d_embed = t_embed.elapsed();")?;
4379 writeln!(code)?;
4380
4381 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
4383 writeln!(code, " let t_layers = Instant::now();")?;
4384 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4385 writeln!(code, " {{")?;
4386 writeln!(
4387 code,
4388 " let enc = cmd.new_compute_command_encoder();"
4389 )?;
4390 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4391 writeln!(
4392 code,
4393 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4394 )?;
4395 writeln!(
4396 code,
4397 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4398 )?;
4399 if config.qkv_bias {
4400 writeln!(
4401 code,
4402 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4403 )?;
4404 }
4405 writeln!(
4406 code,
4407 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4408 )?;
4409 writeln!(
4410 code,
4411 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4412 )?;
4413 writeln!(
4414 code,
4415 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4416 )?;
4417 writeln!(
4418 code,
4419 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4420 )?;
4421 writeln!(
4422 code,
4423 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4424 )?;
4425 writeln!(
4426 code,
4427 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4428 )?;
4429 writeln!(
4430 code,
4431 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4432 )?;
4433 writeln!(
4434 code,
4435 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4436 )?;
4437 writeln!(
4438 code,
4439 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4440 )?;
4441 writeln!(
4442 code,
4443 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4444 )?;
4445 writeln!(code, " }}")?;
4446 writeln!(code, " enc.end_encoding();")?;
4447 writeln!(code, " }}")?;
4448 writeln!(code, " cmd.commit();")?;
4449 writeln!(code, " cmd.wait_until_completed();")?;
4450 writeln!(code, " let d_layers = t_layers.elapsed();")?;
4451 writeln!(code)?;
4452
4453 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
4455 writeln!(code, " let t_logits = Instant::now();")?;
4456 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
4457 writeln!(code, " {{")?;
4458 writeln!(
4459 code,
4460 " let enc = cmd2.new_compute_command_encoder();"
4461 )?;
4462 writeln!(
4463 code,
4464 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4465 )?;
4466 writeln!(
4467 code,
4468 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4469 )?;
4470 writeln!(code, " enc.end_encoding();")?;
4471 writeln!(code, " }}")?;
4472 writeln!(code, " cmd2.commit();")?;
4473 writeln!(code, " cmd2.wait_until_completed();")?;
4474 writeln!(code, " let d_logits = t_logits.elapsed();")?;
4475 writeln!(code)?;
4476
4477 writeln!(
4479 code,
4480 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4481 )?;
4482 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
4483 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
4484 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
4485 writeln!(
4486 code,
4487 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4488 )?;
4489 writeln!(code)?;
4490
4491 writeln!(code, " let logits = unsafe {{")?;
4493 writeln!(
4494 code,
4495 " let ptr = self.logits_buf.contents() as *const f32;"
4496 )?;
4497 writeln!(
4498 code,
4499 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4500 )?;
4501 writeln!(code, " }};")?;
4502 writeln!(code)?;
4503 writeln!(code, " self.pos += 1;")?;
4504 writeln!(code, " logits")?;
4505 writeln!(code, " }}")?;
4506 writeln!(code)?;
4507
4508 writeln!(
4510 code,
4511 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
4512 )?;
4513 writeln!(code, " ///")?;
4514 writeln!(
4515 code,
4516 " /// Commits the command buffer without waiting, enabling double-buffered"
4517 )?;
4518 writeln!(
4519 code,
4520 " /// execution: GPU processes token N while CPU encodes token N+1."
4521 )?;
4522 writeln!(
4523 code,
4524 " pub fn forward_prefill(&mut self, token_id: u32) {{"
4525 )?;
4526 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
4527 writeln!(code, " }}")?;
4528 writeln!(code)?;
4529
4530 let batch_matmul_fn = if is_q8 {
4533 "dispatch_matmul_q8_batch"
4534 } else if is_q4 {
4535 "dispatch_matmul_q4_batch"
4536 } else {
4537 "dispatch_matmul_batch"
4538 };
4539
4540 writeln!(
4541 code,
4542 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4543 )?;
4544 writeln!(code, " ///")?;
4545 writeln!(
4546 code,
4547 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4548 )?;
4549 writeln!(
4550 code,
4551 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
4552 )?;
4553 writeln!(
4554 code,
4555 " /// This provides significant speedup during prompt prefill."
4556 )?;
4557 writeln!(
4558 code,
4559 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4560 )?;
4561 writeln!(code, " if tokens.is_empty() {{ return; }}")?;
4562 writeln!(
4563 code,
4564 " // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4565 )?;
4566 writeln!(
4567 code,
4568 " // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4569 )?;
4570 writeln!(
4571 code,
4572 " // longer than that must be processed iteratively. The KV cache"
4573 )?;
4574 writeln!(code, " // carries state across chunks via self.pos.")?;
4575 writeln!(
4576 code,
4577 " for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4578 )?;
4579 writeln!(code, " let m = chunk.len();")?;
4580 writeln!(code, " if m == 0 {{ continue; }}")?;
4581 writeln!(code, " let start_pos = self.pos;")?;
4582 writeln!(code)?;
4583 writeln!(code, " // Wait for any pending command buffer")?;
4584 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4585 writeln!(code, " prev.wait_until_completed();")?;
4586 writeln!(code, " }}")?;
4587 writeln!(code)?;
4588
4589 writeln!(
4591 code,
4592 " // Upload token IDs and positions to GPU buffers"
4593 )?;
4594 writeln!(code, " unsafe {{")?;
4595 writeln!(
4596 code,
4597 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4598 )?;
4599 writeln!(
4600 code,
4601 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4602 )?;
4603 writeln!(code, " for i in 0..m {{")?;
4604 writeln!(code, " *tok_ptr.add(i) = chunk[i];")?;
4605 writeln!(
4606 code,
4607 " *pos_ptr.add(i) = (start_pos + i) as u32;"
4608 )?;
4609 writeln!(code, " }}")?;
4610 writeln!(code, " }}")?;
4611 writeln!(code)?;
4612
4613 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4614 writeln!(code, " {{")?;
4615 writeln!(
4616 code,
4617 " let enc = cmd.new_compute_command_encoder();"
4618 )?;
4619 writeln!(code)?;
4620
4621 writeln!(
4623 code,
4624 " // 1. Batch embedding lookup: copy all token embeddings at once"
4625 )?;
4626 writeln!(
4627 code,
4628 " self.dispatch_copy_embedding_batch(&enc, m);"
4629 )?;
4630 if matches!(config.hidden_activation, HiddenActivation::GeluApprox) {
4631 writeln!(
4633 code,
4634 " // Gemma-1 embedding scale by sqrt(hidden_size)"
4635 )?;
4636 writeln!(
4637 code,
4638 " self.dispatch_scale_buffer(&enc, &self.batch_hidden_buf, {scale:e}_f32, m * {hidden});",
4639 scale = (hidden as f32).sqrt(),
4640 )?;
4641 }
4642 writeln!(
4644 code,
4645 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4646 )?;
4647 writeln!(code)?;
4648
4649 writeln!(code, " // 2. Transformer layers")?;
4651 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4652 writeln!(code)?;
4653
4654 writeln!(
4656 code,
4657 " // Batch RMS norm: batch_residual -> batch_hidden"
4658 )?;
4659 writeln!(
4660 code,
4661 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4662 )?;
4663
4664 writeln!(
4666 code,
4667 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4668 )?;
4669 writeln!(
4670 code,
4671 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4672 )?;
4673 if config.qkv_bias {
4674 writeln!(
4675 code,
4676 " // Qwen2: broadcast-add QKV bias across all M tokens."
4677 )?;
4678 writeln!(
4679 code,
4680 " self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4681 )?;
4682 }
4683 writeln!(code)?;
4684
4685 let k_float_offset = hidden;
4687 writeln!(
4688 code,
4689 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4690 )?;
4691 writeln!(
4692 code,
4693 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4694 )?;
4695 writeln!(code)?;
4696
4697 let v_float_offset = hidden + kv_dim;
4699 writeln!(
4700 code,
4701 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4702 )?;
4703 writeln!(
4704 code,
4705 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4706 )?;
4707 writeln!(code)?;
4708
4709 writeln!(
4711 code,
4712 " // Batched causal attention: one dispatch for all M tokens"
4713 )?;
4714 writeln!(
4715 code,
4716 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4717 )?;
4718 writeln!(code)?;
4719
4720 writeln!(code, " // Batched O projection")?;
4722 writeln!(
4723 code,
4724 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4725 )?;
4726 writeln!(code)?;
4727
4728 writeln!(
4730 code,
4731 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4732 )?;
4733 writeln!(code)?;
4734
4735 writeln!(
4737 code,
4738 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4739 )?;
4740 writeln!(
4741 code,
4742 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4743 )?;
4744 writeln!(
4745 code,
4746 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4747 )?;
4748 writeln!(
4749 code,
4750 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4751 )?;
4752 writeln!(
4753 code,
4754 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4755 )?;
4756 writeln!(
4757 code,
4758 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4759 )?;
4760 writeln!(code, " }}")?;
4761 writeln!(code)?;
4762
4763 writeln!(
4765 code,
4766 " // Copy last token's residual to single-token buffer for subsequent forward()"
4767 )?;
4768 writeln!(
4769 code,
4770 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4771 )?;
4772 writeln!(code)?;
4773 writeln!(code, " enc.end_encoding();")?;
4774 writeln!(code, " }}")?;
4775 writeln!(code)?;
4776
4777 writeln!(code, " cmd.commit();")?;
4778 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
4779 writeln!(code, " self.pos += m;")?;
4780 writeln!(code, " }} // end for chunk")?;
4781 writeln!(code, " }}")?;
4782 writeln!(code)?;
4783
4784 writeln!(
4786 code,
4787 " /// Reset the model state for a new inference request."
4788 )?;
4789 writeln!(code, " pub fn reset(&mut self) {{")?;
4790 writeln!(code, " self.pos = 0;")?;
4791 writeln!(code, " self.prev_cmd = None;")?;
4792 writeln!(code, " }}")?;
4793 writeln!(code)?;
4794
4795 writeln!(
4797 code,
4798 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
4799 )?;
4800 writeln!(
4801 code,
4802 " // These methods set pipeline state + buffers + dispatch on an existing"
4803 )?;
4804 writeln!(
4805 code,
4806 " // encoder, avoiding per-operation encoder creation overhead."
4807 )?;
4808 writeln!(code)?;
4809
4810 writeln!(
4812 code,
4813 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4814 )?;
4815 writeln!(
4816 code,
4817 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4818 )?;
4819 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4820 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4821 writeln!(
4822 code,
4823 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4824 )?;
4825 writeln!(
4826 code,
4827 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
4828 )?;
4829 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4830 writeln!(
4831 code,
4832 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4833 )?;
4834 writeln!(
4835 code,
4836 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4837 )?;
4838 writeln!(
4839 code,
4840 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4841 )?;
4842 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4843 writeln!(
4844 code,
4845 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
4846 )?;
4847 writeln!(
4848 code,
4849 " enc.dispatch_thread_groups(grid_size, tg_size);"
4850 )?;
4851 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4852 writeln!(code, " }}")?;
4853 writeln!(code)?;
4854
4855 writeln!(
4857 code,
4858 " /// Dispatch matrix-vector multiply: weight * input -> output."
4859 )?;
4860 writeln!(
4861 code,
4862 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4863 )?;
4864 writeln!(code, " let r: u32 = rows as u32;")?;
4865 writeln!(code, " let c: u32 = cols as u32;")?;
4866 writeln!(
4867 code,
4868 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4869 )?;
4870 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4871 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4872 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4873 writeln!(
4874 code,
4875 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4876 )?;
4877 writeln!(
4878 code,
4879 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4880 )?;
4881 writeln!(
4882 code,
4883 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4884 )?;
4885 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4886 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
4887 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4888 writeln!(
4889 code,
4890 " enc.dispatch_thread_groups(grid_size, tg_size);"
4891 )?;
4892 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4893 writeln!(code, " }}")?;
4894 writeln!(code)?;
4895
4896 writeln!(
4898 code,
4899 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4900 )?;
4901 writeln!(
4902 code,
4903 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4904 )?;
4905 writeln!(
4906 code,
4907 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4908 )?;
4909 writeln!(code, " let r: u32 = rows as u32;")?;
4910 writeln!(code, " let c: u32 = cols as u32;")?;
4911 writeln!(code, " if cols > 8192 {{")?;
4916 writeln!(code, " let nt: u32 = 1u32;")?;
4917 writeln!(
4918 code,
4919 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
4920 )?;
4921 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4922 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4923 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4924 writeln!(
4925 code,
4926 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4927 )?;
4928 writeln!(
4929 code,
4930 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4931 )?;
4932 writeln!(
4933 code,
4934 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4935 )?;
4936 writeln!(code, " let row_tgs = ((rows + 31) / 32) as u64;")?;
4937 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4938 writeln!(
4939 code,
4940 " let grid_size = MTLSize::new(row_tgs, 1, 1);"
4941 )?;
4942 writeln!(
4943 code,
4944 " enc.dispatch_thread_groups(grid_size, tg_size);"
4945 )?;
4946 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4947 writeln!(code, " return;")?;
4948 writeln!(code, " }}")?;
4949 writeln!(
4950 code,
4951 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4952 )?;
4953 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4954 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4955 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4956 writeln!(
4957 code,
4958 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4959 )?;
4960 writeln!(
4961 code,
4962 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4963 )?;
4964 writeln!(
4965 code,
4966 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4967 )?;
4968 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4969 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4970 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4971 writeln!(
4972 code,
4973 " enc.dispatch_thread_groups(grid_size, tg_size);"
4974 )?;
4975 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4976 writeln!(code, " }}")?;
4977 writeln!(code)?;
4978
4979 writeln!(
4981 code,
4982 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4983 )?;
4984 writeln!(
4985 code,
4986 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4987 )?;
4988 writeln!(
4989 code,
4990 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4991 )?;
4992 writeln!(code, " let r: u32 = rows as u32;")?;
4993 writeln!(code, " let c: u32 = cols as u32;")?;
4994 writeln!(
4995 code,
4996 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4997 )?;
4998 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4999 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5000 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5001 writeln!(
5002 code,
5003 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5004 )?;
5005 writeln!(
5006 code,
5007 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5008 )?;
5009 writeln!(
5010 code,
5011 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
5012 )?;
5013 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5014 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
5015 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5016 writeln!(
5017 code,
5018 " enc.dispatch_thread_groups(grid_size, tg_size);"
5019 )?;
5020 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5021 writeln!(code, " }}")?;
5022 writeln!(code)?;
5023
5024 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
5026 writeln!(
5027 code,
5028 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
5029 )?;
5030 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5031 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5032 writeln!(code, " let p: u32 = pos as u32;")?;
5033 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5034 writeln!(
5035 code,
5036 " let total_pairs = num_heads * (head_dim / 2);"
5037 )?;
5038 writeln!(
5039 code,
5040 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5041 )?;
5042 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
5043 writeln!(
5044 code,
5045 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5046 )?;
5047 writeln!(
5048 code,
5049 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5050 )?;
5051 writeln!(
5052 code,
5053 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5054 )?;
5055 writeln!(
5056 code,
5057 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5058 )?;
5059 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5060 writeln!(
5061 code,
5062 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5063 )?;
5064 writeln!(
5065 code,
5066 " enc.dispatch_thread_groups(grid_size, tg_size);"
5067 )?;
5068 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5069 writeln!(code, " }}")?;
5070 writeln!(code)?;
5071
5072 writeln!(
5074 code,
5075 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
5076 )?;
5077 writeln!(
5078 code,
5079 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
5080 )?;
5081 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5082 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5083 writeln!(code, " let p: u32 = pos as u32;")?;
5084 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5085 writeln!(
5086 code,
5087 " let total_pairs = num_heads * (head_dim / 2);"
5088 )?;
5089 writeln!(
5090 code,
5091 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5092 )?;
5093 writeln!(
5094 code,
5095 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
5096 )?;
5097 writeln!(
5098 code,
5099 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5100 )?;
5101 writeln!(
5102 code,
5103 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5104 )?;
5105 writeln!(
5106 code,
5107 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5108 )?;
5109 writeln!(
5110 code,
5111 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5112 )?;
5113 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5114 writeln!(
5115 code,
5116 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5117 )?;
5118 writeln!(
5119 code,
5120 " enc.dispatch_thread_groups(grid_size, tg_size);"
5121 )?;
5122 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5123 writeln!(code, " }}")?;
5124 writeln!(code)?;
5125
5126 writeln!(
5128 code,
5129 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
5130 )?;
5131 writeln!(
5132 code,
5133 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5134 )?;
5135 writeln!(code, " let sl: u32 = seq_len as u32;")?;
5136 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
5137 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5138 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5139 writeln!(
5140 code,
5141 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
5142 )?;
5143 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
5144 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5145 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5146 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5147 writeln!(
5148 code,
5149 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5150 )?;
5151 writeln!(
5152 code,
5153 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5154 )?;
5155 writeln!(
5156 code,
5157 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5158 )?;
5159 writeln!(
5160 code,
5161 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5162 )?;
5163 writeln!(
5164 code,
5165 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
5166 )?;
5167 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5168 writeln!(
5169 code,
5170 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5171 )?;
5172 writeln!(
5173 code,
5174 " enc.dispatch_thread_groups(grid_size, tg_size);"
5175 )?;
5176 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5177 writeln!(code, " }}")?;
5178 writeln!(code)?;
5179
5180 writeln!(
5182 code,
5183 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
5184 )?;
5185 writeln!(
5186 code,
5187 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5188 )?;
5189 writeln!(code, " let sl: u32 = seq_len as u32;")?;
5190 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
5191 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5192 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5193 writeln!(
5194 code,
5195 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
5196 )?;
5197 writeln!(
5198 code,
5199 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5200 )?;
5201 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5202 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5203 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5204 writeln!(
5205 code,
5206 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5207 )?;
5208 writeln!(
5209 code,
5210 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5211 )?;
5212 writeln!(
5213 code,
5214 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5215 )?;
5216 writeln!(
5217 code,
5218 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5219 )?;
5220 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5221 writeln!(
5222 code,
5223 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5224 )?;
5225 writeln!(
5226 code,
5227 " enc.dispatch_thread_groups(grid_size, tg_size);"
5228 )?;
5229 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5230 writeln!(code, " }}")?;
5231 writeln!(code)?;
5232
5233 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
5235 writeln!(
5236 code,
5237 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5238 )?;
5239 writeln!(code, " let count: u32 = n as u32;")?;
5240 writeln!(
5241 code,
5242 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5243 )?;
5244 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
5245 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
5246 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5247 writeln!(
5248 code,
5249 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5250 )?;
5251 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5252 writeln!(
5253 code,
5254 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5255 )?;
5256 writeln!(
5257 code,
5258 " enc.dispatch_thread_groups(grid_size, tg_size);"
5259 )?;
5260 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5261 writeln!(code, " }}")?;
5262 writeln!(code)?;
5263
5264 writeln!(
5266 code,
5267 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5268 )?;
5269 writeln!(
5270 code,
5271 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
5272 )?;
5273 writeln!(
5274 code,
5275 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5276 )?;
5277 writeln!(code, " let count: u32 = n as u32;")?;
5278 writeln!(
5279 code,
5280 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5281 )?;
5282 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5283 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5284 writeln!(
5285 code,
5286 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5287 )?;
5288 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5289 writeln!(
5290 code,
5291 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5292 )?;
5293 writeln!(
5294 code,
5295 " enc.dispatch_thread_groups(grid_size, tg_size);"
5296 )?;
5297 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5298 writeln!(code, " }}")?;
5299 writeln!(code)?;
5300
5301 writeln!(
5303 code,
5304 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5305 )?;
5306 writeln!(
5307 code,
5308 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5309 )?;
5310 writeln!(code, " let n: u32 = count as u32;")?;
5311 writeln!(
5312 code,
5313 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5314 )?;
5315 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5316 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5317 writeln!(
5318 code,
5319 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5320 )?;
5321 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5322 writeln!(
5323 code,
5324 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5325 )?;
5326 writeln!(
5327 code,
5328 " enc.dispatch_thread_groups(grid_size, tg_size);"
5329 )?;
5330 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5331 writeln!(code, " }}")?;
5332 writeln!(code)?;
5333
5334 writeln!(
5336 code,
5337 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5338 )?;
5339 writeln!(
5340 code,
5341 " /// Used for embedding table lookup (copy a specific row)."
5342 )?;
5343 writeln!(
5344 code,
5345 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5346 )?;
5347 writeln!(code, " let off: u32 = src_offset as u32;")?;
5348 writeln!(code, " let n: u32 = count as u32;")?;
5349 writeln!(
5350 code,
5351 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5352 )?;
5353 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5354 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5355 writeln!(
5356 code,
5357 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5358 )?;
5359 writeln!(
5360 code,
5361 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5362 )?;
5363 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5364 writeln!(
5365 code,
5366 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5367 )?;
5368 writeln!(
5369 code,
5370 " enc.dispatch_thread_groups(grid_size, tg_size);"
5371 )?;
5372 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5373 writeln!(code, " }}")?;
5374 writeln!(code)?;
5375
5376 writeln!(
5378 code,
5379 " /// Dispatch copy from source at byte offset to destination at float offset."
5380 )?;
5381 writeln!(
5382 code,
5383 " /// Used for KV cache updates from fused QKV buffer."
5384 )?;
5385 writeln!(
5386 code,
5387 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5388 )?;
5389 writeln!(code, " let n: u32 = count as u32;")?;
5390 writeln!(
5391 code,
5392 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5393 )?;
5394 writeln!(
5395 code,
5396 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5397 )?;
5398 writeln!(
5399 code,
5400 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5401 )?;
5402 writeln!(
5403 code,
5404 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5405 )?;
5406 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5407 writeln!(
5408 code,
5409 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5410 )?;
5411 writeln!(
5412 code,
5413 " enc.dispatch_thread_groups(grid_size, tg_size);"
5414 )?;
5415 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5416 writeln!(code, " }}")?;
5417 writeln!(code)?;
5418
5419 writeln!(
5421 code,
5422 " /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5423 )?;
5424 writeln!(
5425 code,
5426 " /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5427 )?;
5428 writeln!(
5429 code,
5430 " fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5431 )?;
5432 writeln!(code, " let n: u32 = count as u32;")?;
5433 writeln!(
5434 code,
5435 " enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5436 )?;
5437 writeln!(
5438 code,
5439 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5440 )?;
5441 writeln!(
5442 code,
5443 " enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5444 )?;
5445 writeln!(
5446 code,
5447 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5448 )?;
5449 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5450 writeln!(
5451 code,
5452 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5453 )?;
5454 writeln!(
5455 code,
5456 " enc.dispatch_thread_groups(grid_size, tg_size);"
5457 )?;
5458 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5459 writeln!(code, " }}")?;
5460 writeln!(code)?;
5461
5462 writeln!(
5464 code,
5465 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5466 )?;
5467 writeln!(
5468 code,
5469 " /// Used for KV cache updates (write to a specific position in the cache)."
5470 )?;
5471 writeln!(
5472 code,
5473 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5474 )?;
5475 writeln!(code, " let n: u32 = count as u32;")?;
5476 writeln!(
5477 code,
5478 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5479 )?;
5480 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5481 writeln!(
5482 code,
5483 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5484 )?;
5485 writeln!(
5486 code,
5487 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5488 )?;
5489 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5490 writeln!(
5491 code,
5492 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5493 )?;
5494 writeln!(
5495 code,
5496 " enc.dispatch_thread_groups(grid_size, tg_size);"
5497 )?;
5498 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5499 writeln!(code, " }}")?;
5500 writeln!(code)?;
5501
5502 writeln!(
5504 code,
5505 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5506 )?;
5507 writeln!(
5508 code,
5509 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5510 )?;
5511 writeln!(code, " let count: u32 = n as u32;")?;
5512 writeln!(
5513 code,
5514 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5515 )?;
5516 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5517 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5518 writeln!(
5519 code,
5520 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5521 )?;
5522 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5523 writeln!(
5524 code,
5525 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5526 )?;
5527 writeln!(
5528 code,
5529 " enc.dispatch_thread_groups(grid_size, tg_size);"
5530 )?;
5531 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5532 writeln!(code, " }}")?;
5533 writeln!(code)?;
5534
5535 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
5537 writeln!(code)?;
5538
5539 writeln!(
5541 code,
5542 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
5543 )?;
5544 writeln!(
5545 code,
5546 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5547 )?;
5548 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
5549 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5550 writeln!(
5551 code,
5552 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5553 )?;
5554 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5555 writeln!(
5556 code,
5557 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5558 )?;
5559 writeln!(
5560 code,
5561 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5562 )?;
5563 writeln!(
5564 code,
5565 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5566 )?;
5567 writeln!(
5568 code,
5569 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5570 )?;
5571 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
5572 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5573 writeln!(
5574 code,
5575 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5576 )?;
5577 writeln!(
5578 code,
5579 " enc.dispatch_thread_groups(grid_size, tg_size);"
5580 )?;
5581 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5582 writeln!(code, " }}")?;
5583 writeln!(code)?;
5584
5585 writeln!(
5587 code,
5588 " /// Multiply every element of `data` by `scale` in place."
5589 )?;
5590 writeln!(
5591 code,
5592 " fn dispatch_scale_buffer(&self, enc: &ComputeCommandEncoderRef, data: &Buffer, scale: f32, count: usize) {{"
5593 )?;
5594 writeln!(code, " let n: u32 = count as u32;")?;
5595 writeln!(
5596 code,
5597 " enc.set_compute_pipeline_state(&self.scale_buffer_pipeline);"
5598 )?;
5599 writeln!(code, " enc.set_buffer(0, Some(data), 0);")?;
5600 writeln!(
5601 code,
5602 " enc.set_bytes(1, mem::size_of::<f32>() as u64, &scale as *const f32 as *const _);"
5603 )?;
5604 writeln!(
5605 code,
5606 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5607 )?;
5608 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5609 writeln!(
5610 code,
5611 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5612 )?;
5613 writeln!(
5614 code,
5615 " enc.dispatch_thread_groups(grid_size, tg_size);"
5616 )?;
5617 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5618 writeln!(code, " }}")?;
5619 writeln!(code)?;
5620
5621 writeln!(
5623 code,
5624 " /// Dispatch batched RMS norm: normalizes M vectors at once."
5625 )?;
5626 writeln!(
5627 code,
5628 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5629 )?;
5630 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
5631 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
5632 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5633 writeln!(
5634 code,
5635 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5636 )?;
5637 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
5638 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
5639 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5640 writeln!(
5641 code,
5642 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5643 )?;
5644 writeln!(
5645 code,
5646 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5647 )?;
5648 writeln!(
5649 code,
5650 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5651 )?;
5652 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5653 writeln!(
5654 code,
5655 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5656 )?;
5657 writeln!(
5658 code,
5659 " enc.dispatch_thread_groups(grid_size, tg_size);"
5660 )?;
5661 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5662 writeln!(code, " }}")?;
5663 writeln!(code)?;
5664
5665 writeln!(
5667 code,
5668 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5669 )?;
5670 writeln!(
5671 code,
5672 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5673 )?;
5674 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5675 writeln!(code, " let r: u32 = rows as u32;")?;
5676 writeln!(code, " let c: u32 = cols as u32;")?;
5677 writeln!(
5678 code,
5679 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5680 )?;
5681 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5682 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5683 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5684 writeln!(
5685 code,
5686 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5687 )?;
5688 writeln!(
5689 code,
5690 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5691 )?;
5692 writeln!(
5693 code,
5694 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5695 )?;
5696 writeln!(
5697 code,
5698 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
5699 )?;
5700 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5701 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5702 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5703 writeln!(
5704 code,
5705 " enc.dispatch_thread_groups(grid_size, tg_size);"
5706 )?;
5707 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5708 writeln!(code, " }}")?;
5709 writeln!(code)?;
5710
5711 writeln!(
5713 code,
5714 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5715 )?;
5716 writeln!(code, " ///")?;
5717 writeln!(
5718 code,
5719 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5720 )?;
5721 writeln!(
5722 code,
5723 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5724 )?;
5725 writeln!(
5726 code,
5727 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5728 )?;
5729 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5730 writeln!(code, " let r: u32 = rows as u32;")?;
5731 writeln!(code, " let c: u32 = cols as u32;")?;
5732 writeln!(
5733 code,
5734 " // Tile sizes must match the Metal shader constants."
5735 )?;
5736 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
5737 writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
5738 writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
5739 writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
5740 writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
5741 writeln!(
5742 code,
5743 " // Hardware matrix-multiply paths (simdgroup_matrix)."
5744 )?;
5745 writeln!(
5746 code,
5747 " // Prefer the large 32×32 tile when the problem supports it — halves"
5748 )?;
5749 writeln!(
5750 code,
5751 " // dispatch count and reuses each weight load across 32 tokens."
5752 )?;
5753 writeln!(
5754 code,
5755 " if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5756 )?;
5757 writeln!(
5758 code,
5759 " // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5760 )?;
5761 writeln!(
5762 code,
5763 " // It wins at moderate prefill lengths where the GPU is wave-starved,"
5764 )?;
5765 writeln!(
5766 code,
5767 " // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5768 )?;
5769 writeln!(
5770 code,
5771 " // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
5772 )?;
5773 writeln!(
5774 code,
5775 " // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5776 )?;
5777 writeln!(
5778 code,
5779 " // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5780 )?;
5781 writeln!(
5782 code,
5783 " // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5784 )?;
5785 writeln!(
5786 code,
5787 " // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5788 )?;
5789 writeln!(code, " let use_h4 = cols >= 2048;")?;
5790 writeln!(code, " let pipe = if use_h4 {{")?;
5791 writeln!(code, " if num_tokens >= 256 {{")?;
5792 writeln!(
5793 code,
5794 " &self.matmul_q8_mma32_hh4_pipeline"
5795 )?;
5796 writeln!(code, " }} else {{")?;
5797 writeln!(
5798 code,
5799 " &self.matmul_q8_mma32_h4_pipeline"
5800 )?;
5801 writeln!(code, " }}")?;
5802 writeln!(code, " }} else {{")?;
5803 writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
5804 writeln!(code, " }};")?;
5805 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
5806 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5807 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5808 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5809 writeln!(
5810 code,
5811 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5812 )?;
5813 writeln!(
5814 code,
5815 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5816 )?;
5817 writeln!(
5818 code,
5819 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5820 )?;
5821 writeln!(code, " let row_tgs = rows / MMA32_ROW_TILE;")?;
5822 writeln!(
5823 code,
5824 " let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5825 )?;
5826 writeln!(
5827 code,
5828 " let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5829 )?;
5830 writeln!(
5831 code,
5832 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5833 )?;
5834 writeln!(
5835 code,
5836 " enc.dispatch_thread_groups(grid_size, tg_size);"
5837 )?;
5838 writeln!(
5839 code,
5840 " }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5841 )?;
5842 writeln!(
5843 code,
5844 " enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5845 )?;
5846 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5847 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5848 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5849 writeln!(
5850 code,
5851 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5852 )?;
5853 writeln!(
5854 code,
5855 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5856 )?;
5857 writeln!(
5858 code,
5859 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5860 )?;
5861 writeln!(code, " let row_tgs = rows / MMA_ROW_TILE;")?;
5862 writeln!(
5863 code,
5864 " let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5865 )?;
5866 writeln!(
5867 code,
5868 " let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
5869 )?;
5870 writeln!(
5871 code,
5872 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5873 )?;
5874 writeln!(
5875 code,
5876 " enc.dispatch_thread_groups(grid_size, tg_size);"
5877 )?;
5878 writeln!(
5884 code,
5885 " }} else if num_tokens >= TOKENS_PER_TG_Q8 || cols > 8192 {{"
5886 )?;
5887 writeln!(
5888 code,
5889 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5890 )?;
5891 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5892 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5893 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5894 writeln!(
5895 code,
5896 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5897 )?;
5898 writeln!(
5899 code,
5900 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5901 )?;
5902 writeln!(
5903 code,
5904 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5905 )?;
5906 writeln!(
5907 code,
5908 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5909 )?;
5910 writeln!(
5911 code,
5912 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5913 )?;
5914 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5915 writeln!(
5916 code,
5917 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5918 )?;
5919 writeln!(
5920 code,
5921 " enc.dispatch_thread_groups(grid_size, tg_size);"
5922 )?;
5923 writeln!(code, " }} else {{")?;
5924 writeln!(
5925 code,
5926 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5927 )?;
5928 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5929 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5930 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5931 writeln!(
5932 code,
5933 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5934 )?;
5935 writeln!(
5936 code,
5937 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5938 )?;
5939 writeln!(
5940 code,
5941 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5942 )?;
5943 writeln!(
5944 code,
5945 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5946 )?;
5947 writeln!(
5948 code,
5949 " let num_tg = (row_tgs * num_tokens) as u64;"
5950 )?;
5951 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5952 writeln!(
5953 code,
5954 " let grid_size = MTLSize::new(num_tg, 1, 1);"
5955 )?;
5956 writeln!(
5957 code,
5958 " enc.dispatch_thread_groups(grid_size, tg_size);"
5959 )?;
5960 writeln!(code, " }}")?;
5961 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5962 writeln!(code, " }}")?;
5963 writeln!(code)?;
5964
5965 writeln!(
5967 code,
5968 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5969 )?;
5970 writeln!(
5971 code,
5972 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5973 )?;
5974 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5975 writeln!(code, " let r: u32 = rows as u32;")?;
5976 writeln!(code, " let c: u32 = cols as u32;")?;
5977 writeln!(
5978 code,
5979 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5980 )?;
5981 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5982 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5983 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5984 writeln!(
5985 code,
5986 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5987 )?;
5988 writeln!(
5989 code,
5990 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5991 )?;
5992 writeln!(
5993 code,
5994 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5995 )?;
5996 writeln!(
5997 code,
5998 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
5999 )?;
6000 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
6001 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6002 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
6003 writeln!(
6004 code,
6005 " enc.dispatch_thread_groups(grid_size, tg_size);"
6006 )?;
6007 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6008 writeln!(code, " }}")?;
6009 writeln!(code)?;
6010
6011 if config.qkv_bias {
6013 writeln!(
6014 code,
6015 " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
6016 )?;
6017 writeln!(
6018 code,
6019 " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
6020 )?;
6021 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
6022 writeln!(code, " let r: u32 = rows as u32;")?;
6023 writeln!(
6024 code,
6025 " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
6026 )?;
6027 writeln!(code, " enc.set_buffer(0, Some(out), 0);")?;
6028 writeln!(code, " enc.set_buffer(1, Some(bias), 0);")?;
6029 writeln!(
6030 code,
6031 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
6032 )?;
6033 writeln!(
6034 code,
6035 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
6036 )?;
6037 writeln!(code, " let total = (num_tokens * rows) as u64;")?;
6038 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6039 writeln!(
6040 code,
6041 " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
6042 )?;
6043 writeln!(
6044 code,
6045 " enc.dispatch_thread_groups(grid_size, tg_size);"
6046 )?;
6047 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6048 writeln!(code, " }}")?;
6049 writeln!(code)?;
6050 }
6051
6052 writeln!(
6054 code,
6055 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
6056 )?;
6057 writeln!(
6058 code,
6059 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
6060 )?;
6061 writeln!(
6062 code,
6063 " /// `row_stride` is the number of floats per token row in the batch buffer."
6064 )?;
6065 writeln!(
6066 code,
6067 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
6068 )?;
6069 writeln!(code, " let nh: u32 = num_heads as u32;")?;
6070 writeln!(code, " let hd: u32 = head_dim as u32;")?;
6071 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
6072 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
6073 writeln!(
6074 code,
6075 " let pairs_per_token = num_heads * (head_dim / 2);"
6076 )?;
6077 writeln!(
6078 code,
6079 " let total_pairs = num_tokens * pairs_per_token;"
6080 )?;
6081 writeln!(
6094 code,
6095 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
6096 )?;
6097 writeln!(code, " for t in 0..num_tokens {{")?;
6098 writeln!(
6099 code,
6100 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
6101 )?;
6102 writeln!(
6103 code,
6104 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
6105 )?;
6106 writeln!(
6107 code,
6108 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
6109 )?;
6110 writeln!(
6111 code,
6112 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
6113 )?;
6114 writeln!(
6115 code,
6116 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6117 )?;
6118 writeln!(
6119 code,
6120 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6121 )?;
6122 writeln!(
6123 code,
6124 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
6125 )?;
6126 writeln!(
6127 code,
6128 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6129 )?;
6130 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6131 writeln!(
6132 code,
6133 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
6134 )?;
6135 writeln!(
6136 code,
6137 " enc.dispatch_thread_groups(grid_size, tg_size);"
6138 )?;
6139 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6140 writeln!(code, " }}")?;
6141 writeln!(code, " }}")?;
6142 writeln!(code)?;
6143
6144 writeln!(
6146 code,
6147 " /// Dispatch batched fused SiLU-multiply for M tokens."
6148 )?;
6149 writeln!(
6150 code,
6151 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
6152 )?;
6153 writeln!(code, " let count: u32 = n as u32;")?;
6154 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
6155 writeln!(
6156 code,
6157 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
6158 )?;
6159 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
6160 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
6161 writeln!(
6162 code,
6163 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
6164 )?;
6165 writeln!(
6166 code,
6167 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
6168 )?;
6169 writeln!(code, " let total = num_tokens * n;")?;
6170 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6171 writeln!(
6172 code,
6173 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6174 )?;
6175 writeln!(
6176 code,
6177 " enc.dispatch_thread_groups(grid_size, tg_size);"
6178 )?;
6179 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6180 writeln!(code, " }}")?;
6181 writeln!(code)?;
6182
6183 writeln!(
6185 code,
6186 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
6187 )?;
6188 writeln!(
6189 code,
6190 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
6191 )?;
6192 writeln!(code, " let count: u32 = total_n as u32;")?;
6193 writeln!(
6194 code,
6195 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
6196 )?;
6197 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
6198 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
6199 writeln!(
6200 code,
6201 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
6202 )?;
6203 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6204 writeln!(
6205 code,
6206 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
6207 )?;
6208 writeln!(
6209 code,
6210 " enc.dispatch_thread_groups(grid_size, tg_size);"
6211 )?;
6212 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6213 writeln!(code, " }}")?;
6214 writeln!(code)?;
6215
6216 writeln!(
6218 code,
6219 " /// Copy src to dst using compute copy kernel (for batch residual init)."
6220 )?;
6221 writeln!(
6222 code,
6223 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
6224 )?;
6225 writeln!(code, " let n: u32 = count as u32;")?;
6226 writeln!(
6227 code,
6228 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6229 )?;
6230 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6231 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6232 writeln!(
6233 code,
6234 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6235 )?;
6236 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6237 writeln!(
6238 code,
6239 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6240 )?;
6241 writeln!(
6242 code,
6243 " enc.dispatch_thread_groups(grid_size, tg_size);"
6244 )?;
6245 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6246 writeln!(code, " }}")?;
6247 writeln!(code)?;
6248
6249 writeln!(
6251 code,
6252 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6253 )?;
6254 writeln!(
6255 code,
6256 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6257 )?;
6258 writeln!(code, " let n: u32 = count as u32;")?;
6259 writeln!(
6260 code,
6261 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6262 )?;
6263 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6264 writeln!(
6265 code,
6266 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6267 )?;
6268 writeln!(
6269 code,
6270 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6271 )?;
6272 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6273 writeln!(
6274 code,
6275 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6276 )?;
6277 writeln!(
6278 code,
6279 " enc.dispatch_thread_groups(grid_size, tg_size);"
6280 )?;
6281 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6282 writeln!(code, " }}")?;
6283 writeln!(code)?;
6284
6285 writeln!(
6287 code,
6288 " /// Copy from src at byte offset to dst at float offset."
6289 )?;
6290 writeln!(
6291 code,
6292 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6293 )?;
6294 writeln!(code, " let n: u32 = count as u32;")?;
6295 writeln!(
6296 code,
6297 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6298 )?;
6299 writeln!(
6300 code,
6301 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6302 )?;
6303 writeln!(
6304 code,
6305 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6306 )?;
6307 writeln!(
6308 code,
6309 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6310 )?;
6311 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6312 writeln!(
6313 code,
6314 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6315 )?;
6316 writeln!(
6317 code,
6318 " enc.dispatch_thread_groups(grid_size, tg_size);"
6319 )?;
6320 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6321 writeln!(code, " }}")?;
6322 writeln!(code)?;
6323
6324 writeln!(
6326 code,
6327 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6328 )?;
6329 writeln!(
6330 code,
6331 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6332 )?;
6333 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6334 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6335 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6336 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6337 writeln!(code, " let so: u32 = src_offset as u32;")?;
6338 writeln!(
6339 code,
6340 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6341 )?;
6342 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6343 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6344 writeln!(
6345 code,
6346 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6347 )?;
6348 writeln!(
6349 code,
6350 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6351 )?;
6352 writeln!(
6353 code,
6354 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6355 )?;
6356 writeln!(
6357 code,
6358 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6359 )?;
6360 writeln!(
6361 code,
6362 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6363 )?;
6364 writeln!(code, " let total = num_tokens * kv_dim;")?;
6365 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6366 writeln!(
6367 code,
6368 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6369 )?;
6370 writeln!(
6371 code,
6372 " enc.dispatch_thread_groups(grid_size, tg_size);"
6373 )?;
6374 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6375 writeln!(code, " }}")?;
6376 writeln!(code)?;
6377
6378 writeln!(
6380 code,
6381 " /// Dispatch batched causal attention: one dispatch for all M tokens."
6382 )?;
6383 writeln!(
6384 code,
6385 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6386 )?;
6387 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6388 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6389 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
6390 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6391 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6392 writeln!(code, " let qs: u32 = q_stride as u32;")?;
6393 writeln!(code, " let max_seq = base_pos + num_tokens;")?;
6403 writeln!(code, " let _ = max_seq;")?;
6404 writeln!(
6405 code,
6406 " let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6407 )?;
6408 writeln!(code, " .map(|v| v == \"0\").unwrap_or(false);")?;
6409 writeln!(
6410 code,
6411 " let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6412 )?;
6413 writeln!(code, " if use_mma_flash {{")?;
6414 writeln!(
6415 code,
6416 " let pipe = &self.attention_mma_flash_batch_pipeline;"
6417 )?;
6418 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6419 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6420 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6421 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6422 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6423 writeln!(
6424 code,
6425 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6426 )?;
6427 writeln!(
6428 code,
6429 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6430 )?;
6431 writeln!(
6432 code,
6433 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6434 )?;
6435 writeln!(
6436 code,
6437 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6438 )?;
6439 writeln!(
6440 code,
6441 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6442 )?;
6443 writeln!(
6444 code,
6445 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6446 )?;
6447 writeln!(
6448 code,
6449 " // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6450 )?;
6451 writeln!(code, " let tg_size = MTLSize::new(128, 1, 1);")?;
6452 writeln!(
6453 code,
6454 " let q_blocks = ((num_tokens + 7) / 8) as u64;"
6455 )?;
6456 writeln!(
6457 code,
6458 " let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6459 )?;
6460 writeln!(
6461 code,
6462 " enc.dispatch_thread_groups(grid_size, tg_size);"
6463 )?;
6464 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6465 writeln!(code, " return;")?;
6466 writeln!(code, " }}")?;
6467 writeln!(code, " let pipe = &self.attention_batch_pipeline;")?;
6468 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6469 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6470 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6471 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6472 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6473 writeln!(
6474 code,
6475 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6476 )?;
6477 writeln!(
6478 code,
6479 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6480 )?;
6481 writeln!(
6482 code,
6483 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6484 )?;
6485 writeln!(
6486 code,
6487 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6488 )?;
6489 writeln!(
6490 code,
6491 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6492 )?;
6493 writeln!(
6494 code,
6495 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6496 )?;
6497 writeln!(
6498 code,
6499 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6500 )?;
6501 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6502 writeln!(
6503 code,
6504 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6505 )?;
6506 writeln!(
6507 code,
6508 " enc.dispatch_thread_groups(grid_size, tg_size);"
6509 )?;
6510 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6511 writeln!(code, " }}")?;
6512 writeln!(code)?;
6513
6514 writeln!(
6516 code,
6517 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6518 )?;
6519 writeln!(
6520 code,
6521 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6522 )?;
6523 writeln!(
6524 code,
6525 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6526 )?;
6527 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6528 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6529 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
6530 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6531 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6532 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
6533 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
6534 writeln!(
6535 code,
6536 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6537 )?;
6538 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
6539 writeln!(
6540 code,
6541 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6542 )?;
6543 writeln!(
6544 code,
6545 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6546 )?;
6547 writeln!(
6548 code,
6549 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6550 )?;
6551 writeln!(
6552 code,
6553 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6554 )?;
6555 writeln!(
6556 code,
6557 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6558 )?;
6559 writeln!(
6560 code,
6561 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6562 )?;
6563 writeln!(
6564 code,
6565 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6566 )?;
6567 writeln!(
6568 code,
6569 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6570 )?;
6571 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6572 writeln!(
6573 code,
6574 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6575 )?;
6576 writeln!(
6577 code,
6578 " enc.dispatch_thread_groups(grid_size, tg_size);"
6579 )?;
6580 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6581 writeln!(code, " }}")?;
6582 writeln!(code)?;
6583
6584 writeln!(
6586 code,
6587 " /// Dispatch fused K+V cache copy in one kernel launch."
6588 )?;
6589 writeln!(
6590 code,
6591 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6592 )?;
6593 writeln!(
6594 code,
6595 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6596 )?;
6597 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6598 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6599 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6600 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6601 writeln!(code, " let ko: u32 = k_offset as u32;")?;
6602 writeln!(code, " let vo: u32 = v_offset as u32;")?;
6603 writeln!(
6604 code,
6605 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6606 )?;
6607 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6608 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
6609 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
6610 writeln!(
6611 code,
6612 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6613 )?;
6614 writeln!(
6615 code,
6616 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6617 )?;
6618 writeln!(
6619 code,
6620 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6621 )?;
6622 writeln!(
6623 code,
6624 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6625 )?;
6626 writeln!(
6627 code,
6628 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6629 )?;
6630 writeln!(
6631 code,
6632 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6633 )?;
6634 writeln!(
6635 code,
6636 " let total = num_tokens * kv_dim * 2; // K + V"
6637 )?;
6638 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6639 writeln!(
6640 code,
6641 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6642 )?;
6643 writeln!(
6644 code,
6645 " enc.dispatch_thread_groups(grid_size, tg_size);"
6646 )?;
6647 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6648 writeln!(code, " }}")?;
6649
6650 writeln!(code, "}}")?;
6651 writeln!(code)?;
6652
6653 Ok(())
6654}
6655
6656fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6657 writeln!(
6658 code,
6659 "// ── Helper functions ──────────────────────────────────"
6660 )?;
6661 writeln!(code)?;
6662 writeln!(
6663 code,
6664 "/// Create a compute pipeline from a named function in the Metal library."
6665 )?;
6666 writeln!(
6667 code,
6668 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6669 )?;
6670 writeln!(
6671 code,
6672 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6673 )?;
6674 writeln!(
6675 code,
6676 " device.new_compute_pipeline_state_with_function(&func)"
6677 )?;
6678 writeln!(
6679 code,
6680 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6681 )?;
6682 writeln!(code, "}}")?;
6683 writeln!(code)?;
6684
6685 Ok(())
6686}
6687
6688fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6693 let _sanitized = sanitize_name(model_name);
6694 let _vocab = config.vocab_size;
6695
6696 let mut code = String::with_capacity(16 * 1024);
6697 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6698 writeln!(
6699 code,
6700 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6701 )?;
6702 writeln!(code)?;
6703 writeln!(code, "mod model;")?;
6704 writeln!(code)?;
6705 writeln!(code, "use std::io::Write;")?;
6706 writeln!(code, "use std::time::Instant;")?;
6707 writeln!(code, "use serde::Deserialize;")?;
6708 writeln!(code)?;
6709
6710 writeln!(code, "fn main() {{")?;
6712 writeln!(
6713 code,
6714 " let args: Vec<String> = std::env::args().collect();"
6715 )?;
6716 writeln!(code)?;
6717 writeln!(
6718 code,
6719 " // Detect --serve mode (only requires weights + tokenizer)"
6720 )?;
6721 writeln!(
6722 code,
6723 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
6724 )?;
6725 writeln!(code)?;
6726 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
6727 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6728 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6729 writeln!(code, " std::process::exit(1);")?;
6730 writeln!(code, " }}")?;
6731 writeln!(code)?;
6732 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
6733 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6734 writeln!(code, " std::process::exit(1);")?;
6735 writeln!(code, " }}")?;
6736 writeln!(code)?;
6737 writeln!(code, " let weights_path = &args[1];")?;
6738 writeln!(code, " let tokenizer_path = &args[2];")?;
6739 writeln!(code)?;
6740 writeln!(code, " // Parse optional flags")?;
6741 writeln!(code, " let mut max_tokens: usize = 128;")?;
6742 writeln!(code, " let mut port: u16 = 8080;")?;
6743 writeln!(
6744 code,
6745 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6746 )?;
6747 writeln!(
6748 code,
6749 " let profile = args.iter().any(|a| a == \"--profile\");"
6750 )?;
6751 writeln!(code, " let mut i = 3;")?;
6752 writeln!(code, " while i < args.len() {{")?;
6753 writeln!(
6754 code,
6755 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6756 )?;
6757 writeln!(
6758 code,
6759 " max_tokens = args[i + 1].parse().unwrap_or(128);"
6760 )?;
6761 writeln!(code, " i += 2;")?;
6762 writeln!(
6763 code,
6764 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6765 )?;
6766 writeln!(
6767 code,
6768 " port = args[i + 1].parse().unwrap_or(8080);"
6769 )?;
6770 writeln!(code, " i += 2;")?;
6771 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
6772 writeln!(code, " i += 1;")?;
6773 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
6774 writeln!(code, " i += 1;")?;
6775 writeln!(code, " }} else {{")?;
6776 writeln!(code, " i += 1;")?;
6777 writeln!(code, " }}")?;
6778 writeln!(code, " }}")?;
6779 writeln!(code)?;
6780
6781 writeln!(
6783 code,
6784 " // Memory-map weights for zero-copy loading on Apple Silicon"
6785 )?;
6786 writeln!(
6787 code,
6788 " let weights_file = std::fs::File::open(weights_path)"
6789 )?;
6790 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6791 writeln!(
6792 code,
6793 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6794 )?;
6795 writeln!(code)?;
6796 writeln!(code, " // Load tokenizer")?;
6797 writeln!(
6798 code,
6799 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6800 )?;
6801 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6802 writeln!(code)?;
6803 writeln!(code, " // Create Metal model")?;
6804 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
6805 writeln!(
6806 code,
6807 " let mut model = model::MetalModel::new(&weights_mmap);"
6808 )?;
6809 writeln!(code)?;
6810
6811 writeln!(code, " if serve_mode {{")?;
6813 writeln!(code, " serve(model, tokenizer, port);")?;
6814 writeln!(code, " }} else {{")?;
6815 writeln!(code, " let prompt = &args[3];")?;
6816 writeln!(
6817 code,
6818 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6819 )?;
6820 writeln!(code, " }}")?;
6821 writeln!(code, "}}")?;
6822 writeln!(code)?;
6823
6824 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6826 writeln!(code, " // Tokenize prompt")?;
6827 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
6828 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6829 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6830 writeln!(code)?;
6831 writeln!(
6832 code,
6833 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6834 )?;
6835 writeln!(
6836 code,
6837 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
6838 )?;
6839 writeln!(
6840 code,
6841 " // The last token uses synchronous forward() to get logits."
6842 )?;
6843 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
6844 writeln!(code, " let prefill_start = Instant::now();")?;
6845 writeln!(code, " let logits = if prompt_len > 1 {{")?;
6846 writeln!(
6847 code,
6848 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6849 )?;
6850 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
6851 writeln!(code, " }} else {{")?;
6852 writeln!(code, " model.forward(prompt_tokens[0])")?;
6853 writeln!(code, " }};")?;
6854 writeln!(
6855 code,
6856 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6857 )?;
6858 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
6859 writeln!(
6860 code,
6861 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6862 )?;
6863 writeln!(
6864 code,
6865 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6866 )?;
6867 writeln!(code)?;
6868 writeln!(code, " // Generate tokens")?;
6869 writeln!(code, " let mut next_token = argmax(&logits);")?;
6870 writeln!(code, " let gen_start = Instant::now();")?;
6871 writeln!(code, " let mut generated_count: usize = 0;")?;
6872 writeln!(code)?;
6873 writeln!(code, " for _ in 0..max_tokens {{")?;
6874 writeln!(
6875 code,
6876 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6877 )?;
6878 writeln!(code, " if !quiet {{")?;
6879 writeln!(code, " print!(\"{{}}\", text);")?;
6880 writeln!(code, " std::io::stdout().flush().ok();")?;
6881 writeln!(code, " }}")?;
6882 writeln!(code, " }}")?;
6883 writeln!(code, " generated_count += 1;")?;
6884 writeln!(code)?;
6885 writeln!(
6886 code,
6887 " // Use profiling forward for first token when --profile is set"
6888 )?;
6889 writeln!(
6890 code,
6891 " let logits = if profile && generated_count == 1 {{"
6892 )?;
6893 writeln!(code, " model.forward_profile(next_token)")?;
6894 writeln!(code, " }} else {{")?;
6895 writeln!(code, " model.forward(next_token)")?;
6896 writeln!(code, " }};")?;
6897 writeln!(code, " next_token = argmax(&logits);")?;
6898 writeln!(code)?;
6899 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
6900 writeln!(code, " if next_token == 2 {{")?;
6901 writeln!(code, " break;")?;
6902 writeln!(code, " }}")?;
6903 writeln!(code)?;
6904 writeln!(
6905 code,
6906 " // Yield between tokens to reduce sustained GPU thermal load."
6907 )?;
6908 writeln!(
6909 code,
6910 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6911 )?;
6912 writeln!(
6913 code,
6914 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6915 )?;
6916 writeln!(
6917 code,
6918 " // briefly, providing a micro-break that helps sustain peak throughput."
6919 )?;
6920 writeln!(code, " std::thread::yield_now();")?;
6921 writeln!(code, " }}")?;
6922 writeln!(code, " if !quiet {{")?;
6923 writeln!(code, " println!();")?;
6924 writeln!(code, " }}")?;
6925 writeln!(
6926 code,
6927 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6928 )?;
6929 writeln!(
6930 code,
6931 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6932 )?;
6933 writeln!(
6934 code,
6935 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6936 )?;
6937 writeln!(code, "}}")?;
6938 writeln!(code)?;
6939
6940 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6942 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6943 writeln!(code, " logits.iter()")?;
6944 writeln!(code, " .enumerate()")?;
6945 writeln!(
6946 code,
6947 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6948 )?;
6949 writeln!(code, " .map(|(i, _)| i as u32)")?;
6950 writeln!(code, " .unwrap_or(0)")?;
6951 writeln!(code, "}}")?;
6952 writeln!(code)?;
6953
6954 writeln!(
6956 code,
6957 "// -----------------------------------------------------------------------"
6958 )?;
6959 writeln!(code, "// OpenAI-compatible API server")?;
6960 writeln!(
6961 code,
6962 "// -----------------------------------------------------------------------"
6963 )?;
6964 writeln!(code)?;
6965 writeln!(code, "#[derive(Deserialize)]")?;
6966 writeln!(code, "struct ChatRequest {{")?;
6967 writeln!(code, " messages: Vec<ChatMessage>,")?;
6968 writeln!(code, " #[serde(default)]")?;
6969 writeln!(code, " stream: Option<bool>,")?;
6970 writeln!(code, " #[serde(default)]")?;
6971 writeln!(code, " max_tokens: Option<usize>,")?;
6972 writeln!(code, " #[serde(default)]")?;
6973 writeln!(code, " temperature: Option<f32>,")?;
6974 writeln!(code, " #[serde(default)]")?;
6975 writeln!(code, " model: Option<String>,")?;
6976 writeln!(code, "}}")?;
6977 writeln!(code)?;
6978 writeln!(code, "#[derive(Deserialize)]")?;
6979 writeln!(code, "struct ChatMessage {{")?;
6980 writeln!(code, " role: String,")?;
6981 writeln!(code, " content: String,")?;
6982 writeln!(code, "}}")?;
6983 writeln!(code)?;
6984
6985 writeln!(
6987 code,
6988 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6989 )?;
6990 writeln!(code, " let mut prompt = String::new();")?;
6991 writeln!(code, " for msg in messages {{")?;
6992 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6993 writeln!(code, " }}")?;
6994 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6995 writeln!(code, " prompt")?;
6996 writeln!(code, "}}")?;
6997 writeln!(code)?;
6998
6999 writeln!(
7001 code,
7002 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
7003 )?;
7004 writeln!(code, " let len = tokens.len();")?;
7005 writeln!(code, " if len > 1 {{")?;
7006 writeln!(
7007 code,
7008 " model.forward_prefill_batch(&tokens[..len - 1]);"
7009 )?;
7010 writeln!(code, " }}")?;
7011 writeln!(code, " model.forward(tokens[len - 1])")?;
7012 writeln!(code, "}}")?;
7013 writeln!(code)?;
7014
7015 writeln!(
7017 code,
7018 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
7019 )?;
7020 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
7021 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
7022 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
7023 writeln!(
7024 code,
7025 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
7026 )?;
7027 writeln!(code, " eprintln!(\"Endpoints:\");")?;
7028 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
7029 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
7030 writeln!(code, " eprintln!(\" GET /health\");")?;
7031 writeln!(code)?;
7032 writeln!(code, " for request in server.incoming_requests() {{")?;
7033 writeln!(code, " let method = request.method().to_string();")?;
7034 writeln!(code, " let url = request.url().to_string();")?;
7035 writeln!(code)?;
7036 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
7037
7038 writeln!(
7040 code,
7041 " (\"POST\", \"/v1/chat/completions\") => {{"
7042 )?;
7043 writeln!(
7044 code,
7045 " handle_chat_completion(&mut model, &tokenizer, request);"
7046 )?;
7047 writeln!(code, " }}")?;
7048
7049 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
7051 writeln!(code, " let body = serde_json::json!({{")?;
7052 writeln!(code, " \"object\": \"list\",")?;
7053 writeln!(code, " \"data\": [{{")?;
7054 writeln!(code, " \"id\": \"forgellm-metal\",")?;
7055 writeln!(code, " \"object\": \"model\",")?;
7056 writeln!(code, " \"owned_by\": \"forgellm\"")?;
7057 writeln!(code, " }}]")?;
7058 writeln!(code, " }});")?;
7059 writeln!(
7060 code,
7061 " let resp = tiny_http::Response::from_string(body.to_string())"
7062 )?;
7063 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7064 writeln!(code, " request.respond(resp).ok();")?;
7065 writeln!(code, " }}")?;
7066
7067 writeln!(code, " (\"GET\", \"/health\") => {{")?;
7069 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
7070 writeln!(code, " request.respond(resp).ok();")?;
7071 writeln!(code, " }}")?;
7072
7073 writeln!(code, " _ => {{")?;
7075 writeln!(
7076 code,
7077 " let resp = tiny_http::Response::from_string(\"Not Found\")"
7078 )?;
7079 writeln!(code, " .with_status_code(404);")?;
7080 writeln!(code, " request.respond(resp).ok();")?;
7081 writeln!(code, " }}")?;
7082 writeln!(code, " }}")?;
7083 writeln!(code, " }}")?;
7084 writeln!(code, "}}")?;
7085 writeln!(code)?;
7086
7087 writeln!(code, "fn handle_chat_completion(")?;
7089 writeln!(code, " model: &mut model::MetalModel,")?;
7090 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
7091 writeln!(code, " mut request: tiny_http::Request,")?;
7092 writeln!(code, ") {{")?;
7093 writeln!(code, " // Read request body")?;
7094 writeln!(code, " let mut body = String::new();")?;
7095 writeln!(
7096 code,
7097 " if request.as_reader().read_to_string(&mut body).is_err() {{"
7098 )?;
7099 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
7100 writeln!(code, " .with_status_code(400);")?;
7101 writeln!(code, " request.respond(resp).ok();")?;
7102 writeln!(code, " return;")?;
7103 writeln!(code, " }}")?;
7104 writeln!(code)?;
7105 writeln!(code, " // Parse JSON")?;
7106 writeln!(
7107 code,
7108 " let req: ChatRequest = match serde_json::from_str(&body) {{"
7109 )?;
7110 writeln!(code, " Ok(r) => r,")?;
7111 writeln!(code, " Err(e) => {{")?;
7112 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
7113 writeln!(code, " .with_status_code(400);")?;
7114 writeln!(code, " request.respond(resp).ok();")?;
7115 writeln!(code, " return;")?;
7116 writeln!(code, " }}")?;
7117 writeln!(code, " }};")?;
7118 writeln!(code)?;
7119 writeln!(
7120 code,
7121 " let prompt = format_chat_messages(&req.messages);"
7122 )?;
7123 writeln!(
7124 code,
7125 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
7126 )?;
7127 writeln!(code, " Ok(e) => e,")?;
7128 writeln!(code, " Err(e) => {{")?;
7129 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
7130 writeln!(code, " .with_status_code(500);")?;
7131 writeln!(code, " request.respond(resp).ok();")?;
7132 writeln!(code, " return;")?;
7133 writeln!(code, " }}")?;
7134 writeln!(code, " }};")?;
7135 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
7136 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
7137 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
7138 writeln!(
7139 code,
7140 " let _temperature = req.temperature.unwrap_or(1.0);"
7141 )?;
7142 writeln!(code)?;
7143
7144 writeln!(code, " model.reset();")?;
7146 writeln!(code)?;
7147
7148 writeln!(code, " let prefill_start = Instant::now();")?;
7150 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
7151 writeln!(
7152 code,
7153 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
7154 )?;
7155 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
7156 writeln!(code, " let mut next_token = argmax(&logits);")?;
7157 writeln!(code)?;
7158
7159 writeln!(code, " if stream {{")?;
7160
7161 writeln!(
7163 code,
7164 " // SSE streaming: generate tokens and build SSE body"
7165 )?;
7166 writeln!(code, " let gen_start = Instant::now();")?;
7167 writeln!(code, " let mut generated_count: usize = 0;")?;
7168 writeln!(code, " let mut sse_body = String::new();")?;
7169 writeln!(code, " for _ in 0..max_tokens {{")?;
7170 writeln!(code, " if next_token == 2 {{ break; }}")?;
7171 writeln!(
7172 code,
7173 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7174 )?;
7175 writeln!(
7176 code,
7177 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
7178 )?;
7179 writeln!(
7180 code,
7181 " // escaped includes surrounding quotes, strip them"
7182 )?;
7183 writeln!(
7184 code,
7185 " let inner = &escaped[1..escaped.len()-1];"
7186 )?;
7187 writeln!(code, " sse_body.push_str(&format!(")?;
7188 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
7189 writeln!(code, " inner")?;
7190 writeln!(code, " ));")?;
7191 writeln!(code, " }}")?;
7192 writeln!(code, " generated_count += 1;")?;
7193 writeln!(code, " let logits = model.forward(next_token);")?;
7194 writeln!(code, " next_token = argmax(&logits);")?;
7195 writeln!(code, " }}")?;
7196 writeln!(
7197 code,
7198 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7199 )?;
7200 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7201 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
7202 writeln!(code)?;
7203 writeln!(
7204 code,
7205 " // Final chunk with finish_reason, timing, and DONE sentinel"
7206 )?;
7207 writeln!(code, " sse_body.push_str(&format!(")?;
7208 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
7209 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
7210 writeln!(code, " ));")?;
7211 writeln!(code)?;
7212 writeln!(
7213 code,
7214 " let resp = tiny_http::Response::from_string(sse_body)"
7215 )?;
7216 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
7217 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
7218 writeln!(code, " request.respond(resp).ok();")?;
7219
7220 writeln!(code, " }} else {{")?;
7221
7222 writeln!(
7224 code,
7225 " // Non-streaming: generate all tokens, return JSON"
7226 )?;
7227 writeln!(code, " let gen_start = Instant::now();")?;
7228 writeln!(code, " let mut generated_count: usize = 0;")?;
7229 writeln!(code, " let mut generated = String::new();")?;
7230 writeln!(code, " for _ in 0..max_tokens {{")?;
7231 writeln!(code, " if next_token == 2 {{ break; }}")?;
7232 writeln!(
7233 code,
7234 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7235 )?;
7236 writeln!(code, " generated.push_str(&text);")?;
7237 writeln!(code, " }}")?;
7238 writeln!(code, " generated_count += 1;")?;
7239 writeln!(code, " let logits = model.forward(next_token);")?;
7240 writeln!(code, " next_token = argmax(&logits);")?;
7241 writeln!(code, " }}")?;
7242 writeln!(
7243 code,
7244 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7245 )?;
7246 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7247 writeln!(code)?;
7248 writeln!(code, " let resp_json = serde_json::json!({{")?;
7249 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
7250 writeln!(code, " \"object\": \"chat.completion\",")?;
7251 writeln!(code, " \"choices\": [{{")?;
7252 writeln!(code, " \"index\": 0,")?;
7253 writeln!(code, " \"message\": {{")?;
7254 writeln!(code, " \"role\": \"assistant\",")?;
7255 writeln!(code, " \"content\": generated")?;
7256 writeln!(code, " }},")?;
7257 writeln!(code, " \"finish_reason\": \"stop\"")?;
7258 writeln!(code, " }}],")?;
7259 writeln!(code, " \"usage\": {{")?;
7260 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
7261 writeln!(
7262 code,
7263 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7264 )?;
7265 writeln!(
7266 code,
7267 " \"generation_tokens\": generated_count,"
7268 )?;
7269 writeln!(
7270 code,
7271 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7272 )?;
7273 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
7274 writeln!(code, " }}")?;
7275 writeln!(code, " }});")?;
7276 writeln!(
7277 code,
7278 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
7279 )?;
7280 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7281 writeln!(code, " request.respond(resp).ok();")?;
7282 writeln!(code, " }}")?;
7283 writeln!(code, "}}")?;
7284
7285 Ok(code)
7286}
7287
7288#[cfg(test)]
7293mod tests {
7294 use super::*;
7295 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7296
7297 fn minimal_config() -> ModelConfig {
7298 ModelConfig {
7299 architecture: Architecture::Llama,
7300 hidden_size: 64,
7301 intermediate_size: 128,
7302 num_layers: 2,
7303 num_attention_heads: 4,
7304 num_kv_heads: 4,
7305 head_dim: 16,
7306 vocab_size: 256,
7307 max_seq_len: 512,
7308 rms_norm_eps: 1e-5,
7309 rope_theta: 10000.0,
7310 dtype: DType::F32,
7311 sliding_window_size: None,
7312 qkv_bias: false,
7313 hidden_activation: HiddenActivation::SiLU,
7314 }
7315 }
7316
7317 fn minimal_graph() -> Graph {
7318 Graph::new("test-metal").with_config(minimal_config())
7319 }
7320
/// Generating a project into a temporary directory must create all four output files.
#[test]
fn generate_metal_project_creates_files() {
    let out_dir = tempfile::tempdir().unwrap();
    generate_metal_project(&minimal_graph(), out_dir.path(), "test-model").unwrap();

    // Paths are relative to the output directory; checked in this order.
    for (relative, message) in [
        ("Cargo.toml", "Cargo.toml should be created"),
        ("src/model.rs", "src/model.rs should be created"),
        ("src/main.rs", "src/main.rs should be created"),
        (
            "shaders/kernels.metal",
            "shaders/kernels.metal should be created",
        ),
    ] {
        assert!(out_dir.path().join(relative).exists(), "{message}");
    }
}
7344
/// The generated manifest text must mention each crate the generated project relies on.
#[test]
fn generated_cargo_toml_has_metal_dep() {
    let manifest = generate_cargo_toml("my-model");

    for (dependency, message) in [
        ("metal", "Cargo.toml should depend on metal"),
        ("tokenizers", "Cargo.toml should depend on tokenizers"),
        ("memmap2", "Cargo.toml should depend on memmap2"),
        ("half", "Cargo.toml should depend on half"),
    ] {
        assert!(manifest.contains(dependency), "{message}");
    }
}
7359
/// The generated `model.rs` must contain the `MetalModel` definition and its core entry points.
#[test]
fn generated_model_rs_contains_metal_code() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    for (snippet, message) in [
        (
            "pub struct MetalModel",
            "model.rs should define MetalModel struct",
        ),
        (
            "matmul_pipeline: ComputePipelineState",
            "MetalModel should have matmul_pipeline field",
        ),
        (
            "Device::system_default()",
            "model.rs should use Metal device",
        ),
        (
            "new_library_with_source",
            "model.rs should compile Metal shaders",
        ),
        (
            "fn new(weights: &[u8])",
            "MetalModel should implement new()",
        ),
        (
            "fn forward(&mut self, token_id: u32)",
            "MetalModel should implement forward()",
        ),
    ] {
        assert!(source.contains(snippet), "{message}");
    }
}
7390
/// The generated shader source must declare every expected compute kernel.
#[test]
fn generated_shaders_contain_kernel_names() {
    let shader_src = generate_metal_shaders(&minimal_config());

    // (kernel declaration prefix, failure message), checked in order.
    for (declaration, message) in [
        (
            "kernel void matmul_vec",
            "shaders should contain matmul_vec kernel",
        ),
        (
            "kernel void rms_norm",
            "shaders should contain rms_norm kernel",
        ),
        ("kernel void rope", "shaders should contain rope kernel"),
        (
            "kernel void softmax",
            "shaders should contain softmax kernel",
        ),
        (
            "kernel void silu_mul(",
            "shaders should contain silu_mul kernel",
        ),
        (
            "kernel void silu_mul_fused",
            "shaders should contain silu_mul_fused kernel",
        ),
        (
            "kernel void elementwise_add",
            "shaders should contain elementwise_add kernel",
        ),
        (
            "kernel void attention",
            "shaders should contain attention kernel",
        ),
        (
            "kernel void add_inplace",
            "shaders should contain add_inplace kernel",
        ),
        (
            "kernel void copy_buffer",
            "shaders should contain copy_buffer kernel",
        ),
        (
            "kernel void copy_offset",
            "shaders should contain copy_offset kernel",
        ),
    ] {
        assert!(shader_src.contains(declaration), "{message}");
    }
}
7440
/// The generated shader source must reference the threadgroup/simdgroup constructs listed below.
#[test]
fn generated_shaders_use_simdgroup_features() {
    let shader_src = generate_metal_shaders(&minimal_config());

    for (token, message) in [
        (
            "threadgroup_barrier",
            "shaders should use threadgroup barriers",
        ),
        (
            "threadgroup float",
            "shaders should use threadgroup shared memory",
        ),
        (
            "thread_index_in_threadgroup",
            "shaders should use threadgroup indexing",
        ),
        (
            "simd_sum",
            "shaders should use simd_sum for warp-level reduction",
        ),
        (
            "simd_max",
            "attention kernel should use simd_max for cooperative softmax",
        ),
        (
            "thread_index_in_simdgroup",
            "shaders should use simdgroup lane indexing",
        ),
        (
            "simdgroup_index_in_threadgroup",
            "shaders should use simdgroup indexing within threadgroup",
        ),
        ("float4", "matmul_vec should use float4 vectorized loads"),
    ] {
        assert!(shader_src.contains(token), "{message}");
    }
}
7478
/// The generated `main.rs` must reference the tokenizer, the model constructor, the forward
/// call and `memmap2`.
#[test]
fn generated_main_rs_has_tokenizer_usage() {
    let entry_src = generate_main_rs("test-model", &minimal_config()).unwrap();
    let expect = |needle: &str, message: &str| assert!(entry_src.contains(needle), "{message}");

    expect(
        "tokenizers::Tokenizer",
        "main.rs should use tokenizers crate",
    );
    expect("MetalModel::new", "main.rs should call MetalModel::new");
    expect("model.forward", "main.rs should call model.forward");
    expect(
        "memmap2",
        "main.rs should use memmap2 for zero-copy weight loading",
    );
}
7501
/// A graph created without a config must be rejected with `MetalCodegenError::MissingConfig`.
#[test]
fn missing_config_returns_error() {
    let out_dir = tempfile::tempdir().unwrap();
    let configless = Graph::new("no-config");

    let outcome = generate_metal_project(&configless, out_dir.path(), "fail");
    let is_missing_config = matches!(outcome, Err(MetalCodegenError::MissingConfig));

    assert!(
        is_missing_config,
        "should fail with MissingConfig when graph has no config"
    );
}
7512
/// `sanitize_name` lowercases its input, maps other punctuation/whitespace to `-`, and trims
/// leading/trailing `-`, as shown by the three sample inputs.
#[test]
fn sanitize_name_works() {
    for (raw, expected) in [
        ("My Model!", "my-model"),
        ("test_model", "test-model"),
        ("simple", "simple"),
    ] {
        assert_eq!(sanitize_name(raw), expected);
    }
}
7519
/// The generated `forward()` must submit its work through a single command buffer.
///
/// The text of `forward()` is isolated by slicing from its signature up to the first
/// following-method marker that is found (or to the end of the file when none is found).
/// That slice must contain exactly one `new_command_buffer()` call, exactly one
/// `cmd.commit()`, and either one or two `wait_until_completed()` calls.
#[test]
fn generated_forward_uses_single_command_buffer() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("generated model.rs should define forward()");
    let forward_body = &model_rs[forward_start..];
    // Markers are tried in this order; the first one present ends the slice.
    let forward_end = forward_body
        .find("\n pub fn forward_profile")
        .or_else(|| forward_body.find("\n pub fn forward_prefill"))
        .or_else(|| forward_body.find("\n fn dispatch_"))
        .unwrap_or(forward_body.len());
    let forward_code = &forward_body[..forward_end];

    let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
    assert_eq!(
        cmd_buf_count, 1,
        "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
    );

    let commit_count = forward_code.matches("cmd.commit()").count();
    assert_eq!(
        commit_count, 1,
        "forward() should commit exactly once, found {commit_count}"
    );

    let wait_count = forward_code.matches("wait_until_completed()").count();
    assert!(
        (1..=2).contains(&wait_count),
        "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
    );
}
7560
/// `model.rs` must declare a `<name>: Buffer` field for every working buffer listed here.
#[test]
fn generated_model_has_preallocated_working_buffers() {
    const WORKING_BUFFERS: [&str; 8] = [
        "normed_buf",
        "qkv_buf",
        "attn_out_buf",
        "attn_proj_buf",
        "gate_up_buf",
        "ffn_hidden_buf",
        "ffn_out_buf",
        "add_tmp_buf",
    ];

    let source = generate_model_rs(&minimal_config()).unwrap();

    for buf_name in WORKING_BUFFERS {
        let field_decl = format!("{buf_name}: Buffer");
        assert!(
            source.contains(&field_decl),
            "MetalModel should have pre-allocated {buf_name} field"
        );
    }
}
7582
/// Every dispatch helper must be declared as `fn <name>(&self, enc: &ComputeCommandEncoderRef`.
#[test]
fn generated_dispatch_helpers_take_compute_encoder_ref() {
    const HELPERS: [&str; 13] = [
        "dispatch_rms_norm",
        "dispatch_matmul",
        "dispatch_rope",
        "dispatch_rope_offset",
        "dispatch_attention",
        "dispatch_attention_offset",
        "dispatch_silu_mul",
        "dispatch_silu_mul_fused",
        "dispatch_add_inplace",
        "dispatch_copy",
        "dispatch_copy_offset",
        "dispatch_copy_from_offset",
        "dispatch_copy_to_offset",
    ];

    let source = generate_model_rs(&minimal_config()).unwrap();

    for helper in HELPERS {
        // Signature prefix searched for verbatim in the generated source.
        let method = format!("fn {helper}(&self, enc: &ComputeCommandEncoderRef");
        assert!(
            source.contains(&method),
            "model.rs should contain dispatch helper: {method}"
        );
    }
}
7609
/// Nothing from `fn dispatch_rms_norm` to the end of `model.rs` may create, end, commit or
/// wait on command buffers/encoders.
#[test]
fn generated_helpers_do_not_create_command_buffers_or_encoders() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    // Everything from the first dispatch helper onward.
    let helpers_code = source
        .find("fn dispatch_rms_norm")
        .map(|start| &source[start..])
        .unwrap();

    for (forbidden, message) in [
        (
            "self.queue.new_command_buffer()",
            "dispatch helpers should not create their own command buffers",
        ),
        (
            "new_compute_command_encoder()",
            "dispatch helpers should not create their own compute encoders",
        ),
        (
            "end_encoding()",
            "dispatch helpers should not call end_encoding",
        ),
        (
            ".commit()",
            "dispatch helpers should not commit command buffers",
        ),
        (
            "wait_until_completed",
            "dispatch helpers should not wait on command buffers",
        ),
    ] {
        assert!(!helpers_code.contains(forbidden), "{message}");
    }
}
7647
/// Within the text of `forward()`: no `device.new_buffer` call, exactly one compute command
/// encoder, and no blit command encoders.
#[test]
fn generated_forward_batches_compute_encoders() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let from_forward = &source[source.find(signature).unwrap()..];
    // The first marker in this list that occurs at all ends the slice; otherwise it runs
    // to the end of the file.
    let end = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .find_map(|&marker| from_forward.find(marker))
    .unwrap_or(from_forward.len());
    let forward_code = &from_forward[..end];

    let allocates = forward_code.contains("device.new_buffer");
    assert!(
        !allocates,
        "forward() should not allocate new buffers per call"
    );

    let count_of = |call: &str| forward_code.matches(call).count();
    let compute_encoder_count = count_of("new_compute_command_encoder()");
    let blit_encoder_count = count_of("new_blit_command_encoder()");

    assert_eq!(
        compute_encoder_count, 1,
        "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
    );
    assert_eq!(
        blit_encoder_count, 0,
        "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
    );
}
7688
/// `model.rs` must mention both `dispatch_add_inplace` and `add_inplace_pipeline`.
///
/// NOTE(review): the search covers the whole generated file rather than only `forward()`,
/// so a helper definition alone would satisfy the first check; confirm this is intended.
#[test]
fn generated_forward_uses_add_inplace() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    for (identifier, message) in [
        (
            "dispatch_add_inplace",
            "forward() should use dispatch_add_inplace for residual connections",
        ),
        (
            "add_inplace_pipeline",
            "MetalModel should have add_inplace_pipeline",
        ),
    ] {
        assert!(source.contains(identifier), "{message}");
    }
}
7704
7705 fn minimal_q8_config() -> ModelConfig {
7706 ModelConfig {
7707 architecture: Architecture::Llama,
7708 hidden_size: 64,
7709 intermediate_size: 128,
7710 num_layers: 2,
7711 num_attention_heads: 4,
7712 num_kv_heads: 4,
7713 head_dim: 16,
7714 vocab_size: 256,
7715 max_seq_len: 512,
7716 rms_norm_eps: 1e-5,
7717 rope_theta: 10000.0,
7718 dtype: DType::Q8_0,
7719 sliding_window_size: None,
7720 qkv_bias: false,
7721 hidden_activation: HiddenActivation::SiLU,
7722 }
7723 }
7724
/// The shader source generated for the baseline config must contain the `matmul_vec_q8`
/// kernel and the type/cast tokens it is expected to use.
#[test]
fn generated_shaders_contain_q8_kernel() {
    let shader_src = generate_metal_shaders(&minimal_config());

    for (token, message) in [
        (
            "kernel void matmul_vec_q8",
            "shaders should contain matmul_vec_q8 kernel",
        ),
        (
            "device const uchar* matrix",
            "matmul_vec_q8 should accept raw Q8_0 bytes",
        ),
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "device const half*",
            "matmul_vec_q8 should read f16 scale via half pointer",
        ),
    ] {
        assert!(shader_src.contains(token), "{message}");
    }
}
7750
/// `model.rs` must use the fused QKV / gate+up weight fields and buffers, and must not
/// declare the separate per-projection weight fields.
#[test]
fn generated_model_uses_fused_qkv_projections() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    // (should the snippet be present?, snippet, failure message), checked in order.
    let expectations: [(bool, &str, &str); 12] = [
        (
            true,
            "qkv_weight: Buffer",
            "LayerBuffers should have fused qkv_weight field",
        ),
        (
            false,
            " q_weight: Buffer",
            "LayerBuffers should not have separate q_weight field",
        ),
        (
            false,
            " k_weight: Buffer",
            "LayerBuffers should not have separate k_weight field",
        ),
        (
            false,
            " v_weight: Buffer",
            "LayerBuffers should not have separate v_weight field",
        ),
        (
            true,
            "gate_up_weight: Buffer",
            "LayerBuffers should have fused gate_up_weight field",
        ),
        (
            false,
            " gate_weight: Buffer",
            "LayerBuffers should not have separate gate_weight field",
        ),
        (
            false,
            " up_weight: Buffer",
            "LayerBuffers should not have separate up_weight field",
        ),
        (
            true,
            "qkv_buf: Buffer",
            "MetalModel should have fused qkv_buf",
        ),
        (
            true,
            "gate_up_buf: Buffer",
            "MetalModel should have fused gate_up_buf",
        ),
        (
            true,
            "dispatch_silu_mul_fused",
            "forward pass should use dispatch_silu_mul_fused",
        ),
        (
            true,
            "dispatch_rope_offset",
            "forward pass should use dispatch_rope_offset for fused QKV",
        ),
        (
            true,
            "dispatch_attention_offset",
            "forward pass should use dispatch_attention_offset for fused QKV",
        ),
    ];

    for (should_be_present, snippet, message) in expectations {
        assert!(source.contains(snippet) == should_be_present, "{message}");
    }
}
7814
/// For the Q8_0 config, `model.rs` must both declare the `matmul_q8_pipeline` field and
/// contain the text `matmul_q8_pipeline,`.
#[test]
fn q8_model_has_matmul_q8_pipeline() {
    let source = generate_model_rs(&minimal_q8_config()).unwrap();
    let expect = |needle: &str, message: &str| assert!(source.contains(needle), "{message}");

    expect(
        "matmul_q8_pipeline: ComputePipelineState",
        "MetalModel should have matmul_q8_pipeline field",
    );
    expect(
        "matmul_q8_pipeline,",
        "MetalModel Self should include matmul_q8_pipeline",
    );
}
7829
/// For the Q8_0 config, `model.rs` must mention `dispatch_matmul_q8` and define it as a fn.
#[test]
fn q8_model_uses_dispatch_matmul_q8() {
    let source = generate_model_rs(&minimal_q8_config()).unwrap();

    for (needle, message) in [
        (
            "dispatch_matmul_q8",
            "Q8_0 model should use dispatch_matmul_q8 for projections",
        ),
        (
            "fn dispatch_matmul_q8",
            "model.rs should define dispatch_matmul_q8 method",
        ),
    ] {
        assert!(source.contains(needle), "{message}");
    }
}
7844
/// For the Q8_0 config, `model.rs` must not contain `f16_to_f32` or `f32_data`, and must
/// contain `total_raw as u64`.
#[test]
fn q8_model_loads_raw_bytes_not_dequantized() {
    let source = generate_model_rs(&minimal_q8_config()).unwrap();
    let mentions = |needle: &str| source.contains(needle);

    assert!(
        !mentions("f16_to_f32"),
        "Q8_0 model should not dequantize weights to f32"
    );
    assert!(
        !mentions("f32_data"),
        "Q8_0 model should not create f32 weight data"
    );
    assert!(
        mentions("total_raw as u64"),
        "Q8_0 model should load raw bytes into Metal buffer"
    );
}
7866
/// For the Q8_0 config, the three norm buffers must still be loaded via `next_f32_buffer`.
#[test]
fn q8_model_norms_stay_f32() {
    let source = generate_model_rs(&minimal_q8_config()).unwrap();

    for (binding, message) in [
        (
            "let attn_norm = next_f32_buffer",
            "attn_norm should use f32 buffer even for Q8_0 models",
        ),
        (
            "let ffn_norm = next_f32_buffer",
            "ffn_norm should use f32 buffer even for Q8_0 models",
        ),
        (
            "let norm_buf = next_f32_buffer",
            "final norm should use f32 buffer even for Q8_0 models",
        ),
    ] {
        assert!(source.contains(binding), "{message}");
    }
}
7886
/// For the Q8_0 config, fused weights must be loaded with `next_q8_fused_buffer` and the
/// O / down weights with `next_q8_buffer`.
#[test]
fn q8_model_uses_fused_weight_loading() {
    let source = generate_model_rs(&minimal_q8_config()).unwrap();

    for (binding, message) in [
        (
            "let qkv_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights",
        ),
        (
            "let gate_up_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights",
        ),
        (
            "let o_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for O weight",
        ),
        (
            "let down_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for down weight",
        ),
    ] {
        assert!(source.contains(binding), "{message}");
    }
}
7912
/// For the F32 config, the text from `forward()`'s signature up to the first `fn dispatch_`
/// marker (or the end of the file) must not mention `dispatch_matmul_q8`.
#[test]
fn f32_model_does_not_use_q8_dispatch() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let from_forward = &source[source.find(signature).unwrap()..];
    let forward_code = match from_forward.find("\n fn dispatch_") {
        Some(end) => &from_forward[..end],
        None => from_forward,
    };

    let uses_q8 = forward_code.contains("dispatch_matmul_q8");
    assert!(
        !uses_q8,
        "f32 model forward should not use dispatch_matmul_q8"
    );
}
7933
/// For the Q8_0 config, `dispatch_matmul_q8` must be declared with
/// `(&self, enc: &ComputeCommandEncoderRef` as the start of its parameter list.
#[test]
fn q8_dispatch_helper_takes_compute_encoder_ref() {
    let source = generate_model_rs(&minimal_q8_config()).unwrap();

    let signature = "fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef";
    let has_signature = source.contains(signature);

    assert!(
        has_signature,
        "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
    );
}
7944
/// `model.rs` must declare the `prev_cmd` field, the `forward_prefill` method, and the
/// `prev_cmd.take()` drain statement.
#[test]
fn generated_model_has_double_buffered_prefill() {
    let source = generate_model_rs(&minimal_config()).unwrap();

    for (snippet, message) in [
        (
            "prev_cmd: Option<CommandBuffer>",
            "MetalModel should have prev_cmd field for double-buffered prefill",
        ),
        (
            "pub fn forward_prefill(&mut self, token_id: u32)",
            "MetalModel should have forward_prefill method",
        ),
        (
            "if let Some(prev) = self.prev_cmd.take()",
            "forward() should drain prev_cmd from previous prefill",
        ),
    ] {
        assert!(source.contains(snippet), "{message}");
    }
}
7968
/// The generated `main.rs` must mention both `forward_prefill` and `double-buffered`.
#[test]
fn generated_main_rs_uses_forward_prefill_for_prompt() {
    let entry_src = generate_main_rs("test-model", &minimal_config()).unwrap();
    let expect = |needle: &str, message: &str| assert!(entry_src.contains(needle), "{message}");

    expect(
        "forward_prefill",
        "main.rs should use forward_prefill for intermediate prompt tokens",
    );
    expect(
        "double-buffered",
        "main.rs should document double-buffered prefill",
    );
}
7983
/// The shader source must contain the tokens expected of the wide-load Q8 matmul kernel:
/// `packed_short4`, `d0[0]`, `as_type<char2>` and `dot(`.
#[test]
fn generated_shaders_q8_uses_wide_vectorized_loads() {
    let shader_src = generate_metal_shaders(&minimal_config());

    for (token, message) in [
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads",
        ),
        (
            "d0[0]",
            "matmul_vec_q8 should index the wide pointer for row 0",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "dot(",
            "matmul_vec_q8 should use dot() intrinsic for fma accumulation",
        ),
    ] {
        assert!(shader_src.contains(token), "{message}");
    }
}
8005
8006 fn minimal_q4_config() -> ModelConfig {
8009 ModelConfig {
8010 architecture: Architecture::Llama,
8011 hidden_size: 64,
8012 intermediate_size: 128,
8013 num_layers: 2,
8014 num_attention_heads: 4,
8015 num_kv_heads: 4,
8016 head_dim: 16,
8017 vocab_size: 256,
8018 max_seq_len: 512,
8019 rms_norm_eps: 1e-5,
8020 rope_theta: 10000.0,
8021 dtype: DType::Q4_0,
8022 sliding_window_size: None,
8023 qkv_bias: false,
8024 hidden_activation: HiddenActivation::SiLU,
8025 }
8026 }
8027
/// Checks that the shader source generated for `minimal_config()` contains
/// the `matmul_vec_q4` kernel and the two `Q4_ROWS_PER_*` names.
#[test]
fn generated_shaders_contain_q4_kernel() {
    let shader_src = generate_metal_shaders(&minimal_config());
    let has = |snippet: &str| shader_src.contains(snippet);

    assert!(
        has("kernel void matmul_vec_q4"),
        "shaders should contain matmul_vec_q4 kernel"
    );
    assert!(
        has("Q4_ROWS_PER_TG"),
        "shaders should define Q4_ROWS_PER_TG constant"
    );
    assert!(
        has("Q4_ROWS_PER_SG"),
        "shaders should define Q4_ROWS_PER_SG constant"
    );
}
8045
/// Checks that the generated shader source contains the snippets expected
/// of the Q4 kernel's nibble unpacking: `uchar4`, `&0xF`, `>>4`, `-8)` and
/// `blk * 18`.
#[test]
fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
    let shader_src = generate_metal_shaders(&minimal_config());

    // (expected snippet, failure message) pairs, checked in order.
    let expectations = [
        (
            "uchar4",
            "matmul_vec_q4 should use uchar4 for packed byte loads",
        ),
        (
            "&0xF",
            "matmul_vec_q4 should extract low nibble with &0xF",
        ),
        (
            ">>4",
            "matmul_vec_q4 should extract high nibble with >>4",
        ),
        (
            "-8)",
            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion",
        ),
        (
            "blk * 18",
            "matmul_vec_q4 should use 18-byte block stride",
        ),
    ];
    for (snippet, why) in expectations {
        assert!(shader_src.contains(snippet), "{why}");
    }
}
8075
/// Checks that `model.rs` generated for the Q4_0 config both declares the
/// `matmul_q4_pipeline` field and lists it followed by a comma.
#[test]
fn q4_model_has_matmul_q4_pipeline() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let declares_field = generated.contains("matmul_q4_pipeline: ComputePipelineState");
    assert!(
        declares_field,
        "MetalModel should have matmul_q4_pipeline field"
    );

    let lists_field = generated.contains("matmul_q4_pipeline,");
    assert!(
        lists_field,
        "MetalModel Self should include matmul_q4_pipeline"
    );
}
8090
/// Checks that `model.rs` generated for the Q4_0 config mentions
/// `dispatch_matmul_q4` and contains its `fn` definition.
#[test]
fn q4_model_uses_dispatch_matmul_q4() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();

    for (snippet, why) in [
        (
            "dispatch_matmul_q4",
            "Q4_0 model should use dispatch_matmul_q4 for projections",
        ),
        (
            "fn dispatch_matmul_q4",
            "model.rs should define dispatch_matmul_q4 method",
        ),
    ] {
        assert!(generated.contains(snippet), "{why}");
    }
}
8105
/// Checks that `model.rs` generated for the Q4_0 config does not mention
/// `f16_to_f32` and does contain `total_raw as u64`.
#[test]
fn q4_model_loads_raw_bytes_not_dequantized() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let dequantizes = generated.contains("f16_to_f32");
    assert!(
        !dequantizes,
        "Q4_0 model should not dequantize weights to f32"
    );

    let loads_raw_bytes = generated.contains("total_raw as u64");
    assert!(
        loads_raw_bytes,
        "Q4_0 model should load raw bytes into Metal buffer"
    );
}
8123
/// Checks that `model.rs` generated for the Q4_0 config binds `attn_norm`,
/// `ffn_norm` and `norm_buf` from `next_f32_buffer`.
#[test]
fn q4_model_norms_stay_f32() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();

    // (expected snippet, failure message) pairs, checked in order.
    let expectations = [
        (
            "let attn_norm = next_f32_buffer",
            "attn_norm should use f32 buffer even for Q4_0 models",
        ),
        (
            "let ffn_norm = next_f32_buffer",
            "ffn_norm should use f32 buffer even for Q4_0 models",
        ),
        (
            "let norm_buf = next_f32_buffer",
            "final norm should use f32 buffer even for Q4_0 models",
        ),
    ];
    for (snippet, why) in expectations {
        assert!(generated.contains(snippet), "{why}");
    }
}
8142
/// Checks which loader each weight binding uses in `model.rs` generated for
/// the Q4_0 config: `next_q4_fused_buffer` for `qkv_weight` and
/// `gate_up_weight`, `next_q4_buffer` for `o_weight` and `down_weight`.
#[test]
fn q4_model_uses_fused_weight_loading() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();
    let has = |snippet: &str| generated.contains(snippet);

    assert!(
        has("let qkv_weight = next_q4_fused_buffer"),
        "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
    );
    assert!(
        has("let gate_up_weight = next_q4_fused_buffer"),
        "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
    );
    assert!(
        has("let o_weight = next_q4_buffer"),
        "Q4_0 model should use next_q4_buffer for O weight"
    );
    assert!(
        has("let down_weight = next_q4_buffer"),
        "Q4_0 model should use next_q4_buffer for down weight"
    );
}
8165
/// Checks that the shader source contains the `attention_flash_batch`
/// kernel and `FLASH_K_TILE`, and that `model.rs` mentions
/// `attention_flash_batch_pipeline`.
#[test]
fn attention_flash_batch_kernel_exists() {
    let cfg = minimal_config();
    let generated_model = generate_model_rs(&cfg).unwrap();
    let shader_src = generate_metal_shaders(&cfg);

    // Shader-side expectations first, then the host-side pipeline.
    for (snippet, why) in [
        (
            "kernel void attention_flash_batch",
            "shaders.metal must still contain the attention_flash_batch kernel",
        ),
        (
            "FLASH_K_TILE",
            "flash kernel must tile K/V with a FLASH_K_TILE constant",
        ),
    ] {
        assert!(shader_src.contains(snippet), "{why}");
    }

    assert!(
        generated_model.contains("attention_flash_batch_pipeline"),
        "MetalModel must register the flash attention pipeline"
    );
}
8188
/// Inspects the text between `pub fn forward(` and the following
/// `pub fn forward_prefill(` in the generated `model.rs` and checks that it
/// calls the fused rope and KV-copy dispatch helpers and not
/// `dispatch_copy_from_offset_f16`.
#[test]
fn decode_uses_fused_rope_and_kv_copy() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let forward_start = generated.find("pub fn forward(").expect("forward missing");
    let from_forward = &generated[forward_start..];
    let forward_len = from_forward
        .find("pub fn forward_prefill(")
        .expect("forward_prefill missing");
    let forward_body = &from_forward[..forward_len];

    let uses_fused_rope = forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1");
    assert!(
        uses_fused_rope,
        "single-token forward must use fused rope_qk_batch with M=1"
    );

    let uses_fused_kv_copy =
        forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,");
    assert!(
        uses_fused_kv_copy,
        "single-token forward must use fused copy_kv_both_batch"
    );

    let uses_split_copy = forward_body.contains("dispatch_copy_from_offset_f16");
    assert!(
        !uses_split_copy,
        "decode path should no longer call the per-K/per-V copy"
    );
}
8221
/// Checks that every attention kernel signature in the generated shaders
/// takes `k_cache` and `v_cache` as `device const half*`, and that the
/// f32->f16 copy kernel and its host-side dispatch helper are present.
///
/// For each kernel the signature is the text from the kernel header up to
/// the nearest `){` or `) {`. The float check covers `k_cache` and
/// `v_cache` separately so a regression in either one is caught.
#[test]
fn kv_cache_stored_as_f16() {
    let config = minimal_config();
    let shaders = generate_metal_shaders(&config);
    let model_rs = generate_model_rs(&config).unwrap();

    for kernel in [
        "kernel void attention(",
        "kernel void attention_batch(",
        "kernel void attention_flash_batch(",
        "kernel void attention_mma_flash_batch(",
    ] {
        let start = shaders
            .find(kernel)
            .unwrap_or_else(|| panic!("kernel {kernel} missing"));
        let rest = &shaders[start..];
        // Use whichever terminator comes first, so an unrelated `){` later
        // in the file cannot stretch the slice past this kernel's signature.
        let sig_end = [rest.find("){"), rest.find(") {")]
            .into_iter()
            .flatten()
            .min()
            .unwrap_or_else(|| panic!("kernel {kernel} has no signature terminator"));
        let sig = &rest[..sig_end];
        assert!(
            sig.contains("device const half*")
                && sig.contains("k_cache")
                && sig.contains("v_cache"),
            "{kernel} must read k_cache/v_cache as `device const half*`"
        );
        assert!(
            !sig.contains("device const float* k_cache"),
            "{kernel} still reads k_cache as float"
        );
        assert!(
            !sig.contains("device const float* v_cache"),
            "{kernel} still reads v_cache as float"
        );
    }

    assert!(
        shaders.contains("kernel void copy_f32_to_f16_offset"),
        "f32->f16 copy kernel must be present for single-token decode KV writes"
    );
    assert!(
        model_rs.contains("dispatch_copy_from_offset_f16"),
        "single-token decode must dispatch the f32->f16 KV copy"
    );
}
8270
/// Isolates the text of the `attention` kernel (from its header up to the
/// next `kernel void `) and checks it for the vectorized pointer types,
/// `head_dim4`, the four per-lane `simd_sum` calls and the `d4` loop header.
#[test]
fn decode_attention_uses_half4_vectorized_loads() {
    let shader_src = generate_metal_shaders(&minimal_config());

    let start = shader_src
        .find("kernel void attention(")
        .expect("decode attention kernel missing");
    // Search from one byte in so the kernel's own header is not re-matched.
    let after_header = &shader_src[start + 1..];
    let next_kernel = after_header
        .find("kernel void ")
        .expect("next kernel missing");
    let kernel_text = &shader_src[start..start + 1 + next_kernel];

    for (snippet, why) in [
        (
            "device const half4*",
            "decode attention must half4-load K/V",
        ),
        (
            "device const float4*",
            "decode attention must float4-load Q",
        ),
        (
            "device float4*",
            "decode attention must float4-store output",
        ),
        (
            "head_dim4",
            "decode attention must iterate head_dim in chunks of 4",
        ),
    ] {
        assert!(kernel_text.contains(snippet), "{why}");
    }

    let reduces_every_lane = [
        "simd_sum(partial.x)",
        "simd_sum(partial.y)",
        "simd_sum(partial.z)",
        "simd_sum(partial.w)",
    ]
    .iter()
    .all(|lane| kernel_text.contains(*lane));
    assert!(
        reduces_every_lane,
        "decode attention V-step must reduce float4 partials via simd_sum"
    );

    assert!(
        kernel_text.contains("for (uint d4 = simd_id; d4 < head_dim4; d4 += 8)"),
        "decode attention V-step must partition d4 across simdgroups"
    );
}
8319
/// Checks the shader source for the MMA flash kernel, its `FLASH_MMA_Q_BLOCK`
/// name and `simdgroup_multiply_accumulate`, then checks `model.rs` for the
/// pipeline name, `mma_opt_out`, and the exact selection condition.
#[test]
fn attention_mma_flash_batch_kernel_wired() {
    let cfg = minimal_config();
    let generated_model = generate_model_rs(&cfg).unwrap();
    let shader_src = generate_metal_shaders(&cfg);

    // Shader-side expectations.
    for (snippet, why) in [
        (
            "kernel void attention_mma_flash_batch",
            "shaders.metal must contain the MMA flash kernel",
        ),
        (
            "FLASH_MMA_Q_BLOCK",
            "MMA flash kernel must define Q_BLOCK tiling constant",
        ),
        (
            "simdgroup_multiply_accumulate",
            "MMA flash kernel must use hardware MMA",
        ),
    ] {
        assert!(shader_src.contains(snippet), "{why}");
    }

    // Host-side expectations.
    for (snippet, why) in [
        (
            "attention_mma_flash_batch_pipeline",
            "MetalModel must register the MMA flash pipeline",
        ),
        (
            "mma_opt_out",
            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out",
        ),
        (
            "!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8",
            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8",
        ),
    ] {
        assert!(generated_model.contains(snippet), "{why}");
    }
}
8353
/// Checks that the generated `model.rs` iterates `tokens.chunks(MAX_BATCH_SIZE)`
/// and no longer contains `tokens.len().min(MAX_BATCH_SIZE)`.
#[test]
fn forward_prefill_batch_chunks_by_max_batch_size() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let chunks_prompt = generated.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)");
    assert!(
        chunks_prompt,
        "forward_prefill_batch must chunk long prompts"
    );

    let truncates_prompt = generated.contains("tokens.len().min(MAX_BATCH_SIZE)");
    assert!(!truncates_prompt, "the old truncation path must be gone");
}
8371
/// Generates code for a Qwen2 config with `qkv_bias: true` and checks that
/// `model.rs` declares, loads and dispatches the QKV bias, and that the
/// shader source contains the `add_bias_batch` kernel.
#[test]
fn qwen2_qkv_bias_wired_through_metal_codegen() {
    let qwen_cfg = ModelConfig {
        architecture: Architecture::Qwen2,
        qkv_bias: true,
        ..minimal_config()
    };
    let generated_model = generate_model_rs(&qwen_cfg).unwrap();

    // (expected snippet, failure message) pairs for model.rs, checked in order.
    let model_expectations = [
        (
            "qkv_bias: Buffer",
            "Qwen2 LayerBuffers must declare qkv_bias field",
        ),
        (
            "let qkv_bias = next_f32_buffer",
            "Qwen2 layer init must load the bias from the weight blob",
        ),
        (
            "add_bias_batch_pipeline",
            "Qwen2 model struct must include the add_bias_batch_pipeline",
        ),
        (
            "fn dispatch_add_bias_batch",
            "Qwen2 codegen must emit dispatch_add_bias_batch helper",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.batch_qkv_buf",
            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.qkv_buf",
            "forward must call dispatch_add_bias_batch on the single-token qkv_buf",
        ),
    ];
    for (snippet, why) in model_expectations {
        assert!(generated_model.contains(snippet), "{why}");
    }

    let shader_src = generate_metal_shaders(&qwen_cfg);
    assert!(
        shader_src.contains("kernel void add_bias_batch"),
        "shaders.metal must contain the add_bias_batch kernel"
    );
}
8416
/// Confirms `minimal_config()` has `qkv_bias` off, then checks that the
/// generated `model.rs` contains none of the QKV-bias snippets.
#[test]
fn llama_does_not_emit_qkv_bias_machinery() {
    let cfg = minimal_config();
    assert!(!cfg.qkv_bias);
    let generated = generate_model_rs(&cfg).unwrap();

    // (forbidden snippet, failure message) pairs, checked in order.
    let forbidden = [
        ("qkv_bias: Buffer", "Llama must not have qkv_bias field"),
        (
            "add_bias_batch_pipeline",
            "Llama must not pull in add_bias_batch_pipeline",
        ),
        (
            "dispatch_add_bias_batch",
            "Llama must not call dispatch_add_bias_batch",
        ),
    ];
    for (snippet, why) in forbidden {
        assert!(!generated.contains(snippet), "{why}");
    }
}
8437
/// Checks the start of the `dispatch_matmul_q4` signature in `model.rs`
/// generated for the Q4_0 config.
#[test]
fn q4_dispatch_helper_takes_compute_encoder_ref() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let expected_signature = "fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef";
    assert!(
        generated.contains(expected_signature),
        "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
    );
}
8448
/// Takes the text of `forward()` from the generated `model.rs` (from its
/// signature up to the first `fn dispatch_` marker, or to the end when no
/// marker is found) and checks that it does not mention
/// `dispatch_matmul_q4`.
#[test]
fn f32_model_does_not_use_q4_dispatch() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature_at = generated
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let from_forward = &generated[signature_at..];
    let forward_code = match from_forward.find("\n fn dispatch_") {
        Some(end) => &from_forward[..end],
        None => from_forward,
    };

    let uses_q4 = forward_code.contains("dispatch_matmul_q4");
    assert!(
        !uses_q4,
        "f32 model forward should not use dispatch_matmul_q4"
    );
}
8468
/// Checks that `model.rs` generated for the Q4_0 config binds
/// `lm_head_buf` from `next_q4_buffer`.
#[test]
fn q4_model_lm_head_uses_q4_buffer() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let binds_q4_lm_head = generated.contains("let lm_head_buf = next_q4_buffer");
    assert!(
        binds_q4_lm_head,
        "Q4_0 model should use next_q4_buffer for lm_head"
    );
}
8479
/// Checks the `vec_tile[...]` size in the generated shaders for two
/// configs: `minimal_config()` (expects 128) and a copy with
/// `hidden_size = 2048`, `intermediate_size = 8192` (expects 8192 and no
/// 4096).
#[test]
fn vec_tile_size_matches_model_dimensions() {
    let small_shaders = generate_metal_shaders(&minimal_config());
    assert!(
        small_shaders.contains("vec_tile[128]"),
        "vec_tile should be sized to max(hidden, intermediate) = 128"
    );

    let large_cfg = ModelConfig {
        hidden_size: 2048,
        intermediate_size: 8192,
        ..minimal_config()
    };
    let large_shaders = generate_metal_shaders(&large_cfg);

    let sized_to_intermediate = large_shaders.contains("vec_tile[8192]");
    assert!(
        sized_to_intermediate,
        "vec_tile should be 8192 for models with intermediate=8192"
    );

    let still_hardcoded = large_shaders.contains("vec_tile[4096]");
    assert!(!still_hardcoded, "vec_tile should NOT be hardcoded to 4096");
}
8504
/// Checks that the generated `Cargo.toml` declares `tiny_http`, `serde`
/// and `serde_json` as dependency keys.
///
/// Each name is matched against the key to the left of `=` on a manifest
/// line. A plain substring search for `serde` would also be satisfied by the
/// `serde_json` line, so it could not detect a missing `serde` entry.
#[test]
fn generated_cargo_toml_has_server_deps() {
    let toml = generate_cargo_toml("my-model");

    // True when some manifest line has exactly `dep` as its key.
    let declares = |dep: &str| {
        toml.lines().any(|line| {
            line.split_once('=')
                .map_or(false, |(key, _)| key.trim() == dep)
        })
    };

    for dep in ["tiny_http", "serde", "serde_json"] {
        assert!(declares(dep), "Cargo.toml should depend on {dep}");
    }
}
8518
/// Checks that the generated `main.rs` mentions the `--serve` and `--port`
/// flags, defines `fn serve(`, and calls `tiny_http::Server::http`.
#[test]
fn generated_main_rs_has_serve_mode() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    // (expected snippet, failure message) pairs, checked in order.
    let expectations = [
        ("--serve", "main.rs should parse --serve flag"),
        ("--port", "main.rs should parse --port flag"),
        ("fn serve(", "main.rs should define serve function"),
        (
            "tiny_http::Server::http",
            "main.rs should create tiny_http server",
        ),
    ];
    for (snippet, why) in expectations {
        assert!(generated.contains(snippet), "{why}");
    }
}
8541
/// Checks that the generated `main.rs` mentions the three HTTP paths
/// `/v1/chat/completions`, `/v1/models` and `/health`.
#[test]
fn generated_main_rs_has_chat_completions_endpoint() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();
    let routes = |path: &str| generated.contains(path);

    assert!(
        routes("/v1/chat/completions"),
        "main.rs should handle /v1/chat/completions endpoint"
    );
    assert!(
        routes("/v1/models"),
        "main.rs should handle /v1/models endpoint"
    );
    assert!(
        routes("/health"),
        "main.rs should handle /health endpoint"
    );
}
8560
/// Checks that the generated `main.rs` contains `text/event-stream`,
/// `chat.completion.chunk` and `[DONE]`.
#[test]
fn generated_main_rs_has_sse_streaming() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    for (snippet, why) in [
        (
            "text/event-stream",
            "main.rs should set SSE content type for streaming",
        ),
        ("chat.completion.chunk", "main.rs should emit SSE chunks"),
        ("[DONE]", "main.rs should emit [DONE] sentinel"),
    ] {
        assert!(generated.contains(snippet), "{why}");
    }
}
8579
/// Checks that the generated `main.rs` defines `format_chat_messages` and
/// contains both the `<|im_start|>` and `<|im_end|>` markers.
#[test]
fn generated_main_rs_has_chat_message_formatting() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    let defines_formatter = generated.contains("fn format_chat_messages");
    assert!(
        defines_formatter,
        "main.rs should define format_chat_messages function"
    );

    // Both delimiters share the same failure message.
    for marker in ["<|im_start|>", "<|im_end|>"] {
        assert!(
            generated.contains(marker),
            "main.rs should use ChatML format"
        );
    }
}
8598
/// Checks that the generated `main.rs` declares `ChatRequest` and
/// `ChatMessage` structs and mentions `Deserialize`.
#[test]
fn generated_main_rs_has_request_types() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    // (expected snippet, failure message) pairs, checked in order.
    let expectations = [
        (
            "struct ChatRequest",
            "main.rs should define ChatRequest struct",
        ),
        (
            "struct ChatMessage",
            "main.rs should define ChatMessage struct",
        ),
        (
            "Deserialize",
            "main.rs should derive Deserialize for request types",
        ),
    ];
    for (snippet, why) in expectations {
        assert!(generated.contains(snippet), "{why}");
    }
}
8617
/// Checks that the generated `model.rs` contains the `reset` method
/// signature and the statement `self.pos = 0`.
#[test]
fn generated_model_has_reset_method() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let has_reset = generated.contains("pub fn reset(&mut self)");
    assert!(
        has_reset,
        "model.rs should have a reset() method for multi-request serving"
    );

    let rewinds_position = generated.contains("self.pos = 0");
    assert!(rewinds_position, "reset() should reset position to 0");
}
8632
/// Checks that the generated `main.rs` defines `cli_mode` and contains the
/// substrings `model.forward` and `model.forward_prefill`.
#[test]
fn generated_main_rs_cli_mode_still_works() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    for (snippet, why) in [
        ("fn cli_mode(", "main.rs should define cli_mode function"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "model.forward_prefill",
            "main.rs should call model.forward_prefill",
        ),
    ] {
        assert!(generated.contains(snippet), "{why}");
    }
}
8652
/// Checks that the shader source generated for `minimal_config()` contains
/// a `kernel void` definition for each of the eight batch kernels listed
/// below.
#[test]
fn generated_shaders_contain_batch_kernels() {
    let shader_src = generate_metal_shaders(&minimal_config());

    // (expected kernel header, failure message) pairs, checked in order.
    let expectations = [
        (
            "kernel void matmul_vec_batch",
            "shaders should contain matmul_vec_batch kernel",
        ),
        (
            "kernel void matmul_vec_q8_batch",
            "shaders should contain matmul_vec_q8_batch kernel",
        ),
        (
            "kernel void matmul_q8_gemm_batch",
            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)",
        ),
        (
            "kernel void matmul_vec_q4_batch",
            "shaders should contain matmul_vec_q4_batch kernel",
        ),
        (
            "kernel void rms_norm_batch",
            "shaders should contain rms_norm_batch kernel",
        ),
        (
            "kernel void silu_mul_fused_batch",
            "shaders should contain silu_mul_fused_batch kernel",
        ),
        (
            "kernel void add_inplace_batch",
            "shaders should contain add_inplace_batch kernel",
        ),
        (
            "kernel void copy_embedding_batch",
            "shaders should contain copy_embedding_batch kernel",
        ),
    ];
    for (header, why) in expectations {
        assert!(shader_src.contains(header), "{why}");
    }
}
8692
/// Checks that the generated `model.rs` declares every batch pipeline
/// listed below as a `ComputePipelineState` field.
#[test]
fn generated_model_has_batch_pipelines() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let pipelines = [
        "matmul_batch_pipeline",
        "matmul_q8_batch_pipeline",
        "matmul_q8_gemm_batch_pipeline",
        "matmul_q4_batch_pipeline",
        "rms_norm_batch_pipeline",
        "rope_batch_pipeline",
        "silu_mul_fused_batch_pipeline",
        "add_inplace_batch_pipeline",
        "copy_embedding_batch_pipeline",
        "attention_batch_pipeline",
        "copy_kv_batch_pipeline",
    ];
    pipelines.iter().for_each(|pipeline| {
        let declaration = format!("{pipeline}: ComputePipelineState");
        assert!(
            generated.contains(&declaration),
            "MetalModel should have {pipeline} field"
        );
    });
}
8717
/// Checks that the generated `model.rs` declares every batch buffer listed
/// below as a `Buffer` field.
#[test]
fn generated_model_has_batch_buffers() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let buffers = [
        "batch_hidden_buf",
        "batch_residual_buf",
        "batch_qkv_buf",
        "batch_attn_out_buf",
        "batch_attn_proj_buf",
        "batch_gate_up_buf",
        "batch_ffn_hidden_buf",
        "batch_ffn_out_buf",
        "batch_tokens_buf",
        "batch_positions_buf",
    ];
    buffers.iter().for_each(|buf| {
        let declaration = format!("{buf}: Buffer");
        assert!(
            generated.contains(&declaration),
            "MetalModel should have {buf} field"
        );
    });
}
8741
/// Checks that the generated `model.rs` contains the
/// `forward_prefill_batch` signature and a call that forwards a single
/// `token_id` to it.
#[test]
fn generated_model_has_forward_prefill_batch() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let has_batch_method =
        generated.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])");
    assert!(
        has_batch_method,
        "MetalModel should have forward_prefill_batch method"
    );

    let delegates = generated.contains("self.forward_prefill_batch(&[token_id])");
    assert!(
        delegates,
        "forward_prefill should delegate to forward_prefill_batch"
    );
}
8758
/// Checks that the generated `model.rs` declares
/// `pub const MAX_BATCH_SIZE: usize = 512`.
#[test]
fn generated_model_has_max_batch_size_constant() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let expected_declaration = "pub const MAX_BATCH_SIZE: usize = 512";
    assert!(
        generated.contains(expected_declaration),
        "model.rs should define MAX_BATCH_SIZE constant"
    );
}
8769
/// Takes the text of `forward_prefill_batch` from the generated `model.rs`
/// (from its signature up to the `pub fn reset` marker, or to the end when
/// no marker is found) and checks that it calls each batch dispatch helper
/// listed below.
#[test]
fn forward_prefill_batch_uses_batch_dispatch() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature_at = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let from_batch = &generated[signature_at..];
    let batch_code = match from_batch.find("\n pub fn reset") {
        Some(end) => &from_batch[..end],
        None => from_batch,
    };

    // (expected helper, failure message) pairs, checked in order.
    let expectations = [
        (
            "dispatch_rms_norm_batch",
            "forward_prefill_batch should use dispatch_rms_norm_batch",
        ),
        (
            "dispatch_copy_embedding_batch",
            "forward_prefill_batch should use dispatch_copy_embedding_batch",
        ),
        (
            "dispatch_silu_mul_fused_batch",
            "forward_prefill_batch should use dispatch_silu_mul_fused_batch",
        ),
        (
            "dispatch_attention_batch",
            "forward_prefill_batch should use dispatch_attention_batch",
        ),
        (
            "dispatch_copy_kv_both_batch",
            "forward_prefill_batch should use dispatch_copy_kv_both_batch",
        ),
        (
            "dispatch_rope_qk_batch",
            "forward_prefill_batch should use dispatch_rope_qk_batch",
        ),
    ];
    for (helper, why) in expectations {
        assert!(batch_code.contains(helper), "{why}");
    }
}
8813
/// Takes the text of `forward_prefill_batch` from `model.rs` generated for
/// the Q8 config (from its signature up to the `pub fn reset` marker, or to
/// the end when no marker is found) and checks that it mentions
/// `dispatch_matmul_q8_batch`.
#[test]
fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
    let cfg = minimal_q8_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature_at = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let from_batch = &generated[signature_at..];
    let batch_code = match from_batch.find("\n pub fn reset") {
        Some(end) => &from_batch[..end],
        None => from_batch,
    };

    let uses_q8_batch = batch_code.contains("dispatch_matmul_q8_batch");
    assert!(
        uses_q8_batch,
        "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
    );
}
8833
/// Takes the text of `forward_prefill_batch` from `model.rs` generated for
/// the Q4_0 config (from its signature up to the `pub fn reset` marker, or
/// to the end when no marker is found) and checks that it mentions
/// `dispatch_matmul_q4_batch`.
#[test]
fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature_at = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let from_batch = &generated[signature_at..];
    let batch_code = match from_batch.find("\n pub fn reset") {
        Some(end) => &from_batch[..end],
        None => from_batch,
    };

    let uses_q4_batch = batch_code.contains("dispatch_matmul_q4_batch");
    assert!(
        uses_q4_batch,
        "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
    );
}
8853
/// Checks that the generated `main.rs` mentions `forward_prefill_batch`.
#[test]
fn generated_main_rs_uses_batched_prefill() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    let uses_batched_prefill = generated.contains("forward_prefill_batch");
    assert!(
        uses_batched_prefill,
        "main.rs should use forward_prefill_batch for prompt tokens"
    );
}
8864
/// Takes the text of `forward_prefill_batch` from `model.rs` generated for
/// `minimal_config()` (from its signature up to the `pub fn reset` marker,
/// or to the end when no marker is found) and checks that it mentions
/// `dispatch_matmul_batch` but neither quantized batch helper.
#[test]
fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature_at = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let from_batch = &generated[signature_at..];
    let batch_code = match from_batch.find("\n pub fn reset") {
        Some(end) => &from_batch[..end],
        None => from_batch,
    };

    assert!(
        batch_code.contains("dispatch_matmul_batch"),
        "f32 forward_prefill_batch should use dispatch_matmul_batch"
    );

    // (forbidden helper, failure message) pairs, checked in order.
    for (helper, why) in [
        (
            "dispatch_matmul_q8_batch",
            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch",
        ),
        (
            "dispatch_matmul_q4_batch",
            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch",
        ),
    ] {
        assert!(!batch_code.contains(helper), "{why}");
    }
}
8893
/// Takes the text of `forward()` from the generated `model.rs` (ending at
/// the `forward_profile` marker, else the `forward_prefill` marker, else
/// the end of the file) and checks that it reads `embed_buf.contents()`,
/// uses `copy_nonoverlapping`, and does not dispatch `dispatch_copy_offset`
/// on `embed_buf`.
#[test]
fn forward_uses_cpu_embedding_lookup() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature_at = generated
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let from_forward = &generated[signature_at..];
    let profile_marker = from_forward.find("\n pub fn forward_profile");
    let end = match profile_marker {
        Some(at) => at,
        None => from_forward
            .find("\n pub fn forward_prefill")
            .unwrap_or(from_forward.len()),
    };
    let forward_code = &from_forward[..end];

    for (snippet, why) in [
        (
            "embed_buf.contents()",
            "forward() should access embed_buf via CPU unified memory for embedding lookup",
        ),
        (
            "copy_nonoverlapping",
            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy",
        ),
    ] {
        assert!(forward_code.contains(snippet), "{why}");
    }

    let uses_gpu_copy = forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf");
    assert!(
        !uses_gpu_copy,
        "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
    );
}
8925
/// Checks that the generated `model.rs` contains the `forward_profile`
/// signature, the `[profile]` prefix, and the `d_embed`, `d_layers` and
/// `d_logits` names.
#[test]
fn forward_profile_method_exists() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    // (expected snippet, failure message) pairs, checked in order.
    let expectations = [
        (
            "pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>",
            "MetalModel should have forward_profile() method",
        ),
        (
            "[profile]",
            "forward_profile() should print timing with [profile] prefix",
        ),
        (
            "d_embed",
            "forward_profile() should measure embedding time",
        ),
        ("d_layers", "forward_profile() should measure layer time"),
        ("d_logits", "forward_profile() should measure logits time"),
    ];
    for (snippet, why) in expectations {
        assert!(generated.contains(snippet), "{why}");
    }
}
8953
/// Checks that the generated `main.rs` mentions the `--profile` flag and
/// `forward_profile`.
#[test]
fn generated_cli_has_profile_flag() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    let parses_flag = generated.contains("--profile");
    assert!(parses_flag, "CLI should support --profile flag");

    let calls_profile = generated.contains("forward_profile");
    assert!(
        calls_profile,
        "CLI should call forward_profile when --profile is set"
    );
}
8968
/// Checks that the generated `main.rs` contains a `yield_now()` call.
#[test]
fn generated_cli_has_thermal_yield() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();

    let yields = generated.contains("yield_now()");
    assert!(
        yields,
        "CLI generation loop should include thread::yield_now() for thermal management"
    );
}
8979
/// Checks that the opening of the generated `forward()` does not assert
/// `self.pos > 0`, and that `model.rs` mentions `self.pos`.
///
/// Only the first 400 bytes of `forward()` are inspected. The window is
/// clamped to the end of the generated text and pulled back to a UTF-8
/// character boundary, so short or non-ASCII output produces the intended
/// assertion failure instead of a slice-index panic.
#[test]
fn generated_forward_handles_single_token_prompt() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");
    let mut window_end = (forward_start + 400).min(model_rs.len());
    // `forward_start` is itself a boundary, so this loop always terminates.
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let forward_body = &model_rs[forward_start..window_end];

    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );

    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
9008
/// Checks that the opening of the generated `reset()` sets `self.pos = 0`
/// and `self.prev_cmd = None`.
///
/// Only the first 200 bytes of `reset()` are inspected. The window is
/// clamped to the end of the generated text and pulled back to a UTF-8
/// character boundary, so a `reset()` near the end of the file (or
/// non-ASCII output) cannot cause a slice-index panic.
#[test]
fn generated_reset_clears_kv_cache_position() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let reset_start = model_rs
        .find("pub fn reset(&mut self)")
        .expect("reset() must exist");
    let mut window_end = (reset_start + 200).min(model_rs.len());
    // `reset_start` is itself a boundary, so this loop always terminates.
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let reset_body = &model_rs[reset_start..window_end];

    assert!(
        reset_body.contains("self.pos = 0"),
        "reset() must set self.pos = 0"
    );

    assert!(
        reset_body.contains("self.prev_cmd = None"),
        "reset() should clear prev_cmd for clean command buffer state"
    );
}
9033
/// Checks that the opening of the generated `format_chat_messages` iterates
/// over `messages` and appends the assistant header, and that the text from
/// `fn serve(` onwards calls `model.reset()`.
///
/// At most 500 bytes of `format_chat_messages` are inspected. The window is
/// clamped to the end of the generated text and pulled back to a UTF-8
/// character boundary so a multi-byte character at the cut cannot cause a
/// slice-index panic.
#[test]
fn generated_serve_handles_empty_messages_gracefully() {
    let config = minimal_config();
    let main_rs = generate_main_rs("test-model", &config).unwrap();

    let format_fn_start = main_rs
        .find("fn format_chat_messages")
        .expect("format_chat_messages must exist");
    let mut format_fn_end = (format_fn_start + 500).min(main_rs.len());
    // `format_fn_start` is itself a boundary, so this loop always terminates.
    while !main_rs.is_char_boundary(format_fn_end) {
        format_fn_end -= 1;
    }
    let format_fn_body = &main_rs[format_fn_start..format_fn_end];

    assert!(
        format_fn_body.contains("for msg in messages"),
        "format_chat_messages should iterate over the messages slice"
    );
    assert!(
        format_fn_body.contains("<|im_start|>assistant"),
        "format_chat_messages should always append assistant prompt header"
    );

    let serve_fn_start = main_rs
        .find("fn serve(")
        .expect("serve function must exist");
    let serve_fn_body = &main_rs[serve_fn_start..];
    assert!(
        serve_fn_body.contains("model.reset()"),
        "serve function should reset model between requests"
    );
}
9069
/// Takes the text of `forward()` from the generated `model.rs` (ending at
/// the first of the `forward_profile`, `forward_prefill` or `fn dispatch_`
/// markers that is found, tried in that order, else the end of the file)
/// and checks that it increments `self.pos` by one.
#[test]
fn generated_model_forward_increments_pos() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature_at = generated
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let from_forward = &generated[signature_at..];

    // Markers are tried in priority order; the first one present wins.
    let end = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .find_map(|marker| from_forward.find(*marker))
    .unwrap_or(from_forward.len());
    let forward_code = &from_forward[..end];

    let increments = ["self.pos += 1", "self.pos +=1"]
        .iter()
        .any(|form| forward_code.contains(*form));
    assert!(
        increments,
        "forward() must increment self.pos after processing a token"
    );
}
9093}