1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16 #[error("graph has no model config")]
18 MissingConfig,
19
20 #[error("I/O error: {0}")]
22 Io(#[from] std::io::Error),
23
24 #[error("format error: {0}")]
26 Fmt(#[from] std::fmt::Error),
27}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Turns an arbitrary model name into a usable Cargo package / binary name.
///
/// Steps:
/// 1. The name is lowercased.
/// 2. Every character that is neither (Unicode) alphanumeric nor `-` becomes `-`.
/// 3. Leading and trailing dashes are trimmed.
///
/// Cargo rejects empty package names and names that start with a digit. An
/// empty result therefore falls back to `"model"`, and a result whose first
/// character is numeric gets a `"model-"` prefix (e.g. `"7B"` -> `"model-7b"`).
fn sanitize_name(name: &str) -> String {
    let cleaned = name
        .to_lowercase()
        .replace(|c: char| !c.is_alphanumeric() && c != '-', "-");
    let trimmed = cleaned.trim_matches('-');

    match trimmed.chars().next() {
        // Nothing usable survived (e.g. "", "!!!"): Cargo needs a non-empty name.
        None => "model".to_string(),
        // Cargo package names may not begin with a digit.
        Some(first) if first.is_numeric() => format!("model-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121 let attn_scores_size = config.max_seq_len.min(4096);
126 r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154 device const float* matrix [[buffer(0)]],
155 device const float* vector [[buffer(1)]],
156 device float* output [[buffer(2)]],
157 constant uint& rows [[buffer(3)]],
158 constant uint& cols [[buffer(4)]],
159 uint tgid [[threadgroup_position_in_grid]],
160 uint tid [[thread_index_in_threadgroup]],
161 uint simd_lane [[thread_index_in_simdgroup]],
162 uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164 // Cooperatively load vector into threadgroup shared memory
165 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166 for (uint i = tid; i < cols; i += 256) {
167 vec_tile[i] = vector[i];
168 }
169 threadgroup_barrier(mem_flags::mem_threadgroup);
170
171 // Each simdgroup handles 8 consecutive rows
172 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173 if (row_base >= rows) return;
174
175 uint base0 = row_base * cols;
176 uint base1 = (row_base + 1) * cols;
177 uint base2 = (row_base + 2) * cols;
178 uint base3 = (row_base + 3) * cols;
179 uint base4 = (row_base + 4) * cols;
180 uint base5 = (row_base + 5) * cols;
181 uint base6 = (row_base + 6) * cols;
182 uint base7 = (row_base + 7) * cols;
183
184 // float4 vectorized accumulation across 8 rows
185 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
186 float4 sum4_0 = float4(0.0f);
187 float4 sum4_1 = float4(0.0f);
188 float4 sum4_2 = float4(0.0f);
189 float4 sum4_3 = float4(0.0f);
190 float4 sum4_4 = float4(0.0f);
191 float4 sum4_5 = float4(0.0f);
192 float4 sum4_6 = float4(0.0f);
193 float4 sum4_7 = float4(0.0f);
194
195 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205 }
206
207 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216 // Handle remaining elements (cols not divisible by 128)
217 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218 float vv = vec_tile[j];
219 sum0 += matrix[base0 + j] * vv;
220 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227 }
228
229 // Simdgroup hardware warp-level reduction
230 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235 // Only first lane writes the results
236 if (simd_lane == 0) {
237 if (row_base < rows) output[row_base] = sum0;
238 if (row_base + 1 < rows) output[row_base + 1] = sum1;
239 if (row_base + 2 < rows) output[row_base + 2] = sum2;
240 if (row_base + 3 < rows) output[row_base + 3] = sum3;
241 if (row_base + 4 < rows) output[row_base + 4] = sum4;
242 if (row_base + 5 < rows) output[row_base + 5] = sum5;
243 if (row_base + 6 < rows) output[row_base + 6] = sum6;
244 if (row_base + 7 < rows) output[row_base + 7] = sum7;
245 }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253 device const float* input [[buffer(0)]],
254 device const float* weight [[buffer(1)]],
255 device float* output [[buffer(2)]],
256 constant uint& n [[buffer(3)]],
257 constant float& eps [[buffer(4)]],
258 uint tid [[thread_index_in_threadgroup]])
259{
260 // Each thread accumulates partial sum-of-squares
261 float sum_sq = 0.0f;
262 for (uint i = tid; i < n; i += 256) {
263 float v = input[i];
264 sum_sq += v * v;
265 }
266
267 // Simdgroup-level reduction (hardware warp sum)
268 sum_sq = simd_sum(sum_sq);
269
270 // Cross-simdgroup reduction via shared memory
271 threadgroup float shared[8];
272 uint simd_id = tid / 32;
273 uint simd_lane = tid % 32;
274 if (simd_lane == 0) {
275 shared[simd_id] = sum_sq;
276 }
277 threadgroup_barrier(mem_flags::mem_threadgroup);
278
279 // First thread computes final inverse RMS
280 if (tid == 0) {
281 float total = 0.0f;
282 for (uint i = 0; i < 8; i++) {
283 total += shared[i];
284 }
285 shared[0] = fast::rsqrt(total / float(n) + eps);
286 }
287 threadgroup_barrier(mem_flags::mem_threadgroup);
288
289 float inv_rms = shared[0];
290
291 // Normalize
292 for (uint i = tid; i < n; i += 256) {
293 output[i] = input[i] * inv_rms * weight[i];
294 }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301 device float* data [[buffer(0)]],
302 constant uint& num_heads [[buffer(1)]],
303 constant uint& head_dim [[buffer(2)]],
304 constant uint& pos [[buffer(3)]],
305 constant float& theta [[buffer(4)]],
306 uint id [[thread_position_in_grid]])
307{
308 uint half_dim = head_dim / 2;
309 uint total_pairs = num_heads * half_dim;
310 if (id >= total_pairs) return;
311
312 uint h = id / half_dim;
313 uint i = id % half_dim;
314 uint off = h * head_dim;
315
316 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317 float angle = float(pos) * freq;
318 float c = cos(angle);
319 float s = sin(angle);
320
321 float x0 = data[off + 2 * i];
322 float x1 = data[off + 2 * i + 1];
323 data[off + 2 * i] = x0 * c - x1 * s;
324 data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331 device float* data [[buffer(0)]],
332 constant uint& n [[buffer(1)]],
333 uint tid [[thread_index_in_threadgroup]],
334 uint tg_size [[threads_per_threadgroup]])
335{
336 threadgroup float shared_val[256];
337
338 // Pass 1: find max
339 float local_max = -INFINITY;
340 for (uint i = tid; i < n; i += tg_size) {
341 local_max = max(local_max, data[i]);
342 }
343 shared_val[tid] = local_max;
344 threadgroup_barrier(mem_flags::mem_threadgroup);
345
346 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347 if (tid < stride) {
348 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349 }
350 threadgroup_barrier(mem_flags::mem_threadgroup);
351 }
352 float max_val = shared_val[0];
353 threadgroup_barrier(mem_flags::mem_threadgroup);
354
355 // Pass 2: exp and sum
356 float local_sum = 0.0f;
357 for (uint i = tid; i < n; i += tg_size) {
358 float e = fast::exp(data[i] - max_val);
359 data[i] = e;
360 local_sum += e;
361 }
362 shared_val[tid] = local_sum;
363 threadgroup_barrier(mem_flags::mem_threadgroup);
364
365 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366 if (tid < stride) {
367 shared_val[tid] += shared_val[tid + stride];
368 }
369 threadgroup_barrier(mem_flags::mem_threadgroup);
370 }
371 float inv_sum = 1.0f / shared_val[0];
372 threadgroup_barrier(mem_flags::mem_threadgroup);
373
374 // Pass 3: normalize
375 for (uint i = tid; i < n; i += tg_size) {
376 data[i] *= inv_sum;
377 }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384 device const float* gate [[buffer(0)]],
385 device const float* up [[buffer(1)]],
386 device float* output [[buffer(2)]],
387 constant uint& n [[buffer(3)]],
388 uint id [[thread_position_in_grid]])
389{
390 if (id >= n) return;
391 float g = gate[id];
392 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397// gate = gate_up[0..n], up = gate_up[n..2*n]
398// output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400 device const float* gate_up [[buffer(0)]],
401 device float* output [[buffer(1)]],
402 constant uint& n [[buffer(2)]],
403 uint id [[thread_position_in_grid]])
404{
405 if (id >= n) return;
406 float g = gate_up[id];
407 float u = gate_up[n + id];
408 output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414 device const float* a [[buffer(0)]],
415 device const float* b [[buffer(1)]],
416 device float* output [[buffer(2)]],
417 constant uint& n [[buffer(3)]],
418 uint id [[thread_position_in_grid]])
419{
420 if (id >= n) return;
421 output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428 device const float* src [[buffer(0)]],
429 device float* dst [[buffer(1)]],
430 constant uint& count [[buffer(2)]],
431 uint id [[thread_position_in_grid]])
432{
433 if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440 device const float* src [[buffer(0)]],
441 device float* dst [[buffer(1)]],
442 constant uint& src_offset [[buffer(2)]], // in floats
443 constant uint& count [[buffer(3)]],
444 uint id [[thread_position_in_grid]])
445{
446 if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
450// Copy f32 elements from src into a half-typed dst, converting to half on
451// write. Used by the single-token decode path to append a new K/V vector
452// to the f16 KV cache. Byte offsets into src/dst are supplied via the
453// Metal buffer binding offsets — no in-kernel offsets needed.
454kernel void copy_f32_to_f16_offset(
455 device const float* src [[buffer(0)]],
456 device half* dst [[buffer(1)]],
457 constant uint& count [[buffer(2)]],
458 uint id [[thread_position_in_grid]])
459{
460 if (id < count) dst[id] = half(src[id]);
461}
462
463// ── add_inplace ─────────────────────────────────────────────────────────
464// In-place residual connection: a[i] += b[i]
465// Avoids a separate blit copy for residual add, reducing encoder overhead.
466kernel void add_inplace(
467 device float* a [[buffer(0)]],
468 device const float* b [[buffer(1)]],
469 constant uint& n [[buffer(2)]],
470 uint id [[thread_position_in_grid]])
471{
472 if (id >= n) return;
473 a[id] += b[id];
474}
475
476// ── matmul_vec_q8 ─────────────────────────────────────────────────────
477// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
478// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
479// Operates directly on quantized weights to halve memory bandwidth vs f32,
480// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
481//
482// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
483// because int8->float conversion doubles register demand. Fully unrolled
484// inner loop with float4 vector loads from shared memory eliminates loop
485// overhead and enables better instruction scheduling.
486// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
487constant constexpr uint Q8_ROWS_PER_SG = 4;
488constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
489
490// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
491// doubles ALU work, so the same register budget applies).
492constant constexpr uint Q4_ROWS_PER_SG = 4;
493constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
494
495kernel void matmul_vec_q8(
496 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
497 device const float* vector [[buffer(1)]], // f32 input
498 device float* output [[buffer(2)]],
499 constant uint& rows [[buffer(3)]],
500 constant uint& cols [[buffer(4)]], // number of elements per row
501 uint tgid [[threadgroup_position_in_grid]],
502 uint tid [[thread_index_in_threadgroup]],
503 uint simd_lane [[thread_index_in_simdgroup]],
504 uint simd_id [[simdgroup_index_in_threadgroup]])
505{
506 // Load vector into shared memory
507 threadgroup float vec_tile[VEC_TILE_SIZE];
508 for (uint i = tid; i < cols; i += 256) {
509 vec_tile[i] = vector[i];
510 }
511 threadgroup_barrier(mem_flags::mem_threadgroup);
512
513 // Each simdgroup handles 4 consecutive rows (lower register pressure)
514 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
515 if (row_base >= rows) return;
516
517 // Q8_0: each block is 34 bytes for 32 elements
518 uint blocks_per_row = cols / 32;
519 uint row_bytes = blocks_per_row * 34;
520
521 // Pointers to each row's Q8_0 data
522 device const uchar* r0 = matrix + row_base * row_bytes;
523 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
524 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
525 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
526
527 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
528
529 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
530 uint bb = blk * 34;
531 uint vb = blk * 32;
532
533 // Prefetch all 4 scales
534 float sc0 = float(*(device const half*)(r0 + bb));
535 float sc1 = float(*(device const half*)(r1 + bb));
536 float sc2 = float(*(device const half*)(r2 + bb));
537 float sc3 = float(*(device const half*)(r3 + bb));
538
539 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
540 // Q8_0 block layout where the int8 data starts at offset +2 from a
541 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
542 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
543 // reduction in memory transactions. Metal's char16/packed_char16 are
544 // reserved types and packed_*int4 require >=4-byte alignment which
545 // this layout does not provide, so packed_short4 is the widest valid
546 // vectorized load.
547 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
548 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
549 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
550 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
551
552 // Load all 8 float4 vector values for this 32-element block from shared memory
553 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
554 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
555 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
556 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
557 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
558 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
559 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
560 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
561
562 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
563 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
564 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
565 short4 _s = short4(SHORT4); \
566 char2 _a = as_type<char2>(_s.x); \
567 char2 _b = as_type<char2>(_s.y); \
568 char2 _c = as_type<char2>(_s.z); \
569 char2 _d = as_type<char2>(_s.w); \
570 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
571 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
572 }
573
574 float4 f0, f1;
575 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
576
577 // Row 0: 4 short4 loads cover 32 int8 weights
578 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
579 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
580 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
581 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
582
583 // Row 1
584 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
585 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
586 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
587 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
588
589 // Row 2
590 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
591 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
592 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
593 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
594
595 // Row 3
596 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
597 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
598 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
599 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
600
601 #undef Q8_UNPACK8
602
603 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
604 }
605
606 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
607 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
608
609 if (simd_lane == 0) {
610 if (row_base < rows) output[row_base] = sum0;
611 if (row_base + 1 < rows) output[row_base + 1] = sum1;
612 if (row_base + 2 < rows) output[row_base + 2] = sum2;
613 if (row_base + 3 < rows) output[row_base + 3] = sum3;
614 }
615}
616
617// ── matmul_vec_q4 ─────────────────────────────────────────────────────
618// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
619// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
620// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
621// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
622//
623// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
624// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
625kernel void matmul_vec_q4(
626 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
627 device const float* vector [[buffer(1)]], // f32 input
628 device float* output [[buffer(2)]],
629 constant uint& rows [[buffer(3)]],
630 constant uint& cols [[buffer(4)]], // number of elements per row
631 uint tgid [[threadgroup_position_in_grid]],
632 uint tid [[thread_index_in_threadgroup]],
633 uint simd_lane [[thread_index_in_simdgroup]],
634 uint simd_id [[simdgroup_index_in_threadgroup]])
635{
636 // Load vector into shared memory
637 threadgroup float vec_tile[VEC_TILE_SIZE];
638 for (uint i = tid; i < cols; i += 256) {
639 vec_tile[i] = vector[i];
640 }
641 threadgroup_barrier(mem_flags::mem_threadgroup);
642
643 // Each simdgroup handles 4 consecutive rows
644 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
645 if (row_base >= rows) return;
646
647 // Q4_0: each block is 18 bytes for 32 elements
648 uint blocks_per_row = cols / 32;
649 uint row_bytes = blocks_per_row * 18;
650
651 // Pointers to each row's Q4_0 data
652 device const uchar* r0 = matrix + row_base * row_bytes;
653 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
654 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
655 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
656
657 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
658
659 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
660 uint bb = blk * 18;
661 uint vb = blk * 32;
662
663 // Prefetch all 4 scales
664 float sc0 = float(*(device const half*)(r0 + bb));
665 float sc1 = float(*(device const half*)(r1 + bb));
666 float sc2 = float(*(device const half*)(r2 + bb));
667 float sc3 = float(*(device const half*)(r3 + bb));
668
669 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
670 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
671 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
672 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
673 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
674
675 // Load 8 float4 vector values for 32 elements from shared memory
676 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
677 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
678 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
679 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
680 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
681 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
682 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
683 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
684 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
685
686 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
687 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
688 float bd0=0, bd1=0, bd2=0, bd3=0;
689 uchar4 b;
690
691 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
692 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
693 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
694 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
695 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
696 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
697 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
698 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
699 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
700 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
701 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
702 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
703 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
704 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
705 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
706 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
707 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
708
709 // Row 1
710 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
711 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
712 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
713 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
714 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
715 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
716 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
717 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
718 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
719 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
720 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
721 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
722 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
723 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
724 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
725 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
726
727 // Row 2
728 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
729 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
730 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
731 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
732 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
733 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
734 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
735 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
736 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
737 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
738 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
739 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
740 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
741 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
742 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
743 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
744
745 // Row 3
746 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
747 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
748 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
749 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
750 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
751 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
752 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
753 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
754 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
755 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
756 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
757 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
758 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
759 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
760 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
761 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
762
763 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
764 }
765
766 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
767 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
768
769 if (simd_lane == 0) {
770 if (row_base < rows) output[row_base] = sum0;
771 if (row_base + 1 < rows) output[row_base + 1] = sum1;
772 if (row_base + 2 < rows) output[row_base + 2] = sum2;
773 if (row_base + 3 < rows) output[row_base + 3] = sum3;
774 }
775}
776
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products (scaled by
// 1/sqrt(head_dim)), applies softmax with simd_max/simd_sum reductions,
// then takes the score-weighted sum of V.
// Each threadgroup handles one head (tgid = head index). The strides and the
// [8]-sized reduction arrays below hard-code 256 threads per threadgroup
// (8 simdgroups of 32 lanes), so it must be dispatched with exactly that.
// Grouped-query mapping: query head h reads KV head h / (num_heads / num_kv_heads)
// (presumably num_heads is a multiple of num_kv_heads -- TODO confirm).
// If seq_len == 0, every output element of the head is written as 0.
//
// Buffers:
// q: [num_heads * head_dim] current query (float)
// k_cache: [max_seq_len * num_kv_heads * head_dim] (half)
// v_cache: [max_seq_len * num_kv_heads * head_dim] (half)
// output: [num_heads * head_dim]
kernel void attention(
    device const float* q [[buffer(0)]],
    device const half* k_cache [[buffer(1)]],
    device const half* v_cache [[buffer(2)]],
    device float* output [[buffer(3)]],
    constant uint& seq_len [[buffer(4)]],
    constant uint& num_heads [[buffer(5)]],
    constant uint& num_kv_heads [[buffer(6)]],
    constant uint& head_dim [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Whole threadgroup exits together (tgid is uniform), so the barriers
    // below are never reached by only part of the group.
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Step 1: Compute attention scores Q*K^T with simdgroup reduction.
    // One score per cached position is kept in threadgroup memory.
    // WARNING: there is no bound check against ATTN_SCORES_SIZE. The caller
    // must guarantee seq_len <= ATTN_SCORES_SIZE (the note below says the host
    // MAX_SEQ_LEN cap is meant to match -- confirm); a longer seq_len writes
    // past the end of this array.
    threadgroup float scores[ATTN_SCORES_SIZE]; // max seq_len for generation phase (matches MAX_SEQ_LEN cap)

    // Simdgroup simd_id scores positions simd_id, simd_id + 8, ...; its 32
    // lanes split the head_dim dot product and simd_sum combines them.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax over scores (cooperative, numerically stabilized by
    // subtracting the maximum score before exponentiating).
    // Find max: per-thread -> per-simdgroup (simd_max) -> thread 0 folds
    // the 8 simdgroup maxima and broadcasts through shared_max[0].
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exp and sum: each thread exponentiates only the scores it owns
    // (s = tid mod 256), then the same two-level reduction yields the
    // reciprocal of the total, broadcast through shared_sum[0].
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    // Step 3 reads every score, not just the ones this thread normalized.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores . V
    // Each thread accumulates output dimensions d = tid, tid + 256, ...
    // The position loop is unrolled 4x (four scalar score loads and four
    // V loads per iteration); a scalar tail handles the last seq_len % 4.
    uint seq_len4 = seq_len & ~3u; // largest multiple of 4 <= seq_len
    uint v_stride = num_kv_heads * head_dim;
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_cache[s * v_stride + v_base]
                 + sc1 * v_cache[(s+1) * v_stride + v_base]
                 + sc2 * v_cache[(s+2) * v_stride + v_base]
                 + sc3 * v_cache[(s+3) * v_stride + v_base];
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
893
894// ── Batched prefill kernels ────────────────────────────────────────────
// These kernels process a batch of M token vectors in a single dispatch.
// The matmul variants apply the same weight matrix to all M vectors,
// turning M mat-vec products into one mat-mat product for better GPU
// utilization during prompt prefill.
898
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization:
//   output[t, :] = input[t, :] * rsqrt(mean(input[t, :]^2) + eps) * weight
// applied independently to each of the num_tokens rows of length n.
// Dispatch: one threadgroup per token (tgid = token index). The reduction
// assumes 256 threads per threadgroup, i.e. 8 simdgroups of 32 lanes.
kernel void rms_norm_batch(
    device const float* input [[buffer(0)]], // [M, n]
    device const float* weight [[buffer(1)]], // [n]
    device float* output [[buffer(2)]], // [M, n]
    constant uint& n [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint& num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // Surplus threadgroups exit as a whole (tgid is uniform across the
    // group), so the barriers below are still reached uniformly.
    if (tgid >= num_tokens) return;

    device const float* row_in = input + tgid * n;
    device float* row_out = output + tgid * n;

    // Per-thread partial sum of squares over a 256-wide stride.
    float partial = 0.0f;
    for (uint col = tid; col < n; col += 256) {
        float x = row_in[col];
        partial += x * x;
    }
    partial = simd_sum(partial);

    // One slot per simdgroup; lane 0 of each simdgroup publishes its partial.
    threadgroup float shared[8];
    uint lane = tid % 32;
    uint sg = tid / 32;
    if (lane == 0) {
        shared[sg] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 partials and broadcasts the inverse RMS in slot 0.
    if (tid == 0) {
        float acc = 0.0f;
        for (uint k = 0; k < 8; ++k) {
            acc += shared[k];
        }
        shared[0] = fast::rsqrt(acc / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = shared[0];
    for (uint col = tid; col < n; col += 256) {
        row_out[col] = row_in[col] * scale * weight[col];
    }
}
948
// ── rope_batch ─────────────────────────────────────────────────────────
// Batched rotary position embedding, applied in place.
// data is laid out [M, num_heads * head_dim]; positions[t] is the position
// of token t. One thread per (token, head, pair): thread id rotates the
// adjacent element pair (2i, 2i+1) of one head by the angle
// positions[token] * theta^(-2i / head_dim).
kernel void rope_batch(
    device float* data [[buffer(0)]], // [M, num_heads * head_dim]
    constant uint& num_heads [[buffer(1)]],
    constant uint& head_dim [[buffer(2)]],
    device const uint* positions [[buffer(3)]], // [M] position per token
    constant float& theta [[buffer(4)]],
    constant uint& num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head = head_dim / 2;
    uint pairs_per_token = num_heads * pairs_per_head;
    // Threads past the last (token, head, pair) do nothing. This also covers
    // head_dim < 2 or num_heads == 0, where there is nothing to rotate and
    // the divisions below would otherwise be by zero.
    if (id >= num_tokens * pairs_per_token) return;

    uint token = id / pairs_per_token;
    uint within_token = id % pairs_per_token;
    uint head = within_token / pairs_per_head;
    uint pair = within_token % pairs_per_head;

    device float* elem = data + token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    float angle = float(positions[token]) * inv_freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float even = elem[0];
    float odd = elem[1];
    elem[0] = even * cos_a - odd * sin_a;
    elem[1] = even * sin_a + odd * cos_a;
}
983
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Batched fused SwiGLU activation. Each token row of gate_up holds the gate
// half followed by the up half ([M, 2*n]); the result row has n elements:
//   output[t, i] = silu(gate_up[t, i]) * gate_up[t, n + i]
// with silu(g) = g / (1 + exp(-g)). One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]], // [M, 2*n]
    device float* output [[buffer(1)]], // [M, n]
    constant uint& n [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    uint row = id / n;
    uint col = id % n;
    device const float* gu_row = gate_up + row * 2 * n;

    float gate = gu_row[col];
    float up = gu_row[n + col];
    float silu_gate = gate / (1.0f + fast::exp(-gate));

    // id == row * n + col, i.e. the same element as output[row * n + col].
    output[id] = silu_gate * up;
}
1003
// ── add_inplace_batch ──────────────────────────────────────────────────
// Batched residual add, performed in place on a:
//   a[i] = a[i] + b[i] for every i < total (total = M * n).
// One thread per element; threads at or past total do nothing.
kernel void add_inplace_batch(
    device float* a [[buffer(0)]], // [M * n]
    device const float* b [[buffer(1)]], // [M * n]
    constant uint& total [[buffer(2)]], // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1015
// ── copy_embedding_batch ───────────────────────────────────────────────
// Batched embedding lookup: gathers row tokens[t] of the embedding table
// into row t of a contiguous [M, dim] buffer. One thread per output float.
// Token ids are not range-checked here; the caller must keep every
// tokens[t] below the table's row count.
kernel void copy_embedding_batch(
    device const float* embed [[buffer(0)]], // [vocab_size, dim]
    device float* output [[buffer(1)]], // [M, dim]
    device const uint* tokens [[buffer(2)]], // [M]
    constant uint& dim [[buffer(3)]],
    constant uint& num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    uint row = id / dim;
    uint col = id % dim;
    device const float* src_row = embed + tokens[row] * dim;
    output[id] = src_row[col];
}
1033
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair:
//   token = tgid / ceil(rows/ROWS_PER_TG), row tile = tgid % ceil(rows/ROWS_PER_TG).
//   outputs[token, row] = sum_k matrix[row, k] * inputs[token, k]
// Caller contracts (not checked here):
// * the whole input row is staged in vec_tile, so cols <= VEC_TILE_SIZE;
// * the staging loop strides by 256, assuming 256 threads per threadgroup;
// * each simdgroup computes 8 consecutive rows (base0..base7), which assumes
//   ROWS_PER_SIMDGROUP == 8 -- confirm against its definition;
// * the float4 loads assume cols is a multiple of 4 so matrix rows stay
//   16-byte aligned -- NOTE(review): confirm for all supported shapes.
kernel void matmul_vec_batch(
    device const float* matrix [[buffer(0)]], // [rows, cols] weight
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so whole simdgroups past the last row may exit.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Largest multiple of 128 <= cols: one pass of 32 lanes x 4 floats.
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Vectorized main part: lane L covers columns 4L..4L+3 of every 128-wide
    // chunk. Rows past the end are skipped on both load and accumulate.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Collapse each float4 accumulator to a scalar per lane.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail for the remaining cols % 128 columns.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Reduce across the 32 lanes of this simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 writes this simdgroup's in-range rows for the current token.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1135
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors:
//   outputs[token, row] = sum_k dequant(matrix[row, k]) * inputs[token, k]
// Q8_0 row layout (as read below): cols/32 blocks of 34 bytes each -- a
// 2-byte half scale followed by 32 signed int8 quants.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; tgid selects the token
// (tgid / ceil(rows/Q8_ROWS_PER_TG)) and a tile of rows. Each simdgroup
// accumulates 4 consecutive rows with its lanes striding over Q8 blocks.
// NOTE(review): the whole input row is staged in vec_tile, so this assumes
// cols <= VEC_TILE_SIZE and 256 threads per threadgroup; the 4-row unroll
// assumes Q8_ROWS_PER_SG == 4. These constants are defined elsewhere --
// confirm against the dispatcher.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform per threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    // Load this token's input vector into threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so whole simdgroups past the last row may exit.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // The loads below are unconditional for all four rows, but rows
    // row_base+1..row_base+3 do not exist when rows is not a multiple of 4.
    // Clamp those base pointers to the last valid row so every load stays
    // inside the weight buffer; results for the duplicated rows are dropped
    // by the guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // first input element covered by this block

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads -- 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Reinterpret 4 shorts as 8 signed int8 quants (low byte first) and
        // widen them to two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply the per-block scale once to the block's integer dot product.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Reduce across the 32 lanes of this simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes only the rows that actually exist.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1251
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// GEMM-style Q8_0 kernel that reuses weight reads across a token tile:
//   outputs[token, row] = sum_k dequant(matrix[row, k]) * inputs[token, k]
// Each threadgroup covers Q8_ROWS_PER_TG rows and TOKENS_PER_TG_Q8 (4)
// consecutive tokens. Each Q8_0 weight block is fetched and unpacked once
// per simdgroup and then dotted against every token in the tile, cutting
// weight bandwidth to 1/TOKENS_PER_TG_Q8 of the per-token dispatch.
//
// Grid: (ceil(rows/Q8_ROWS_PER_TG), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each simdgroup owns 4 consecutive rows (the 4-row unroll assumes
// Q8_ROWS_PER_SG == 4 -- confirm) and reduces over Q8 blocks with simd_sum.
// Token vectors are read directly from device memory inside the block loop
// (no threadgroup staging), so cols is not limited by VEC_TILE_SIZE.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // No barriers in this kernel, so simdgroups past the last row or token
    // may exit independently.
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // How many tokens in this tile are valid? (1..TOKENS_PER_TG_Q8)
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    // Q8_0 row layout: cols/32 blocks of 34 bytes (half scale + 32 int8).
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // The weight loads below are unconditional for all four rows, but rows
    // row_base+1..row_base+3 do not exist when rows is not a multiple of 4.
    // Clamp those base pointers to the last valid row so every load stays
    // inside the weight buffer; results for the duplicated rows are dropped
    // by the guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators: s<token><row>, 4 tokens x 4 rows per simdgroup.
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // first input element covered by this block

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Reinterpret 4 shorts as 8 signed int8 quants (low byte first) and
        // widen them to two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        // Unpack all 4 rows x 8 float4 integer weights. They stay in locals
        // for the whole block and are dotted against every token's slice;
        // the per-block scale is applied afterwards.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        #undef Q8_UNPACK8

        // ── For each token, read its 32-float slice and accumulate against the unpacked weights ──
        // Token 0 (always valid: tok_count >= 1).
        {
            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
            float4 v0 = *(device const float4*)(a0);
            float4 v1 = *(device const float4*)(a0 + 4);
            float4 v2 = *(device const float4*)(a0 + 8);
            float4 v3 = *(device const float4*)(a0 + 12);
            float4 v4 = *(device const float4*)(a0 + 16);
            float4 v5 = *(device const float4*)(a0 + 20);
            float4 v6 = *(device const float4*)(a0 + 24);
            float4 v7 = *(device const float4*)(a0 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
        }
        // Token 1
        if (tok_count > 1) {
            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
            float4 v0 = *(device const float4*)(a1);
            float4 v1 = *(device const float4*)(a1 + 4);
            float4 v2 = *(device const float4*)(a1 + 8);
            float4 v3 = *(device const float4*)(a1 + 12);
            float4 v4 = *(device const float4*)(a1 + 16);
            float4 v5 = *(device const float4*)(a1 + 20);
            float4 v6 = *(device const float4*)(a1 + 24);
            float4 v7 = *(device const float4*)(a1 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
        }
        // Token 2
        if (tok_count > 2) {
            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
            float4 v0 = *(device const float4*)(a2);
            float4 v1 = *(device const float4*)(a2 + 4);
            float4 v2 = *(device const float4*)(a2 + 8);
            float4 v3 = *(device const float4*)(a2 + 12);
            float4 v4 = *(device const float4*)(a2 + 16);
            float4 v5 = *(device const float4*)(a2 + 20);
            float4 v6 = *(device const float4*)(a2 + 24);
            float4 v7 = *(device const float4*)(a2 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
        }
        // Token 3
        if (tok_count > 3) {
            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
            float4 v0 = *(device const float4*)(a3);
            float4 v1 = *(device const float4*)(a3 + 4);
            float4 v2 = *(device const float4*)(a3 + 8);
            float4 v3 = *(device const float4*)(a3 + 12);
            float4 v4 = *(device const float4*)(a3 + 16);
            float4 v5 = *(device const float4*)(a3 + 20);
            float4 v6 = *(device const float4*)(a3 + 24);
            float4 v7 = *(device const float4*)(a3 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
        }
    }

    // simdgroup reduction across the 32 lanes
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Lane 0 writes only (token, row) pairs that actually exist.
    if (simd_lane == 0) {
        device float* o0 = outputs + (tok_base + 0) * rows;
        if (row_base < rows) o0[row_base] = s00;
        if (row_base + 1 < rows) o0[row_base + 1] = s01;
        if (row_base + 2 < rows) o0[row_base + 2] = s02;
        if (row_base + 3 < rows) o0[row_base + 3] = s03;

        if (tok_count > 1) {
            device float* o1 = outputs + (tok_base + 1) * rows;
            if (row_base < rows) o1[row_base] = s10;
            if (row_base + 1 < rows) o1[row_base + 1] = s11;
            if (row_base + 2 < rows) o1[row_base + 2] = s12;
            if (row_base + 3 < rows) o1[row_base + 3] = s13;
        }
        if (tok_count > 2) {
            device float* o2 = outputs + (tok_base + 2) * rows;
            if (row_base < rows) o2[row_base] = s20;
            if (row_base + 1 < rows) o2[row_base + 1] = s21;
            if (row_base + 2 < rows) o2[row_base + 2] = s22;
            if (row_base + 3 < rows) o2[row_base + 3] = s23;
        }
        if (tok_count > 3) {
            device float* o3 = outputs + (tok_base + 3) * rows;
            if (row_base < rows) o3[row_base] = s30;
            if (row_base + 1 < rows) o3[row_base + 1] = s31;
            if (row_base + 2 < rows) o3[row_base + 2] = s32;
            if (row_base + 3 < rows) o3[row_base + 3] = s33;
        }
    }
}
1477
1478// ── matmul_q8_mma ──────────────────────────────────────────────────────
1479// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
1480// simdgroup_matrix tiles (simdgroup_multiply_accumulate). This dispatches
1481// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
1482// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
1483//
1484// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
1485// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
1486// simdgroup_matrix<float, 8, 8> accumulator. Weight bytes are cooperatively
1487// dequantized into threadgroup memory once per block and reused by all
1488// simdgroups in the tile.
1489//
1490// Assumptions (verified in the dispatch helper, falls back otherwise):
1491// * cols % 32 == 0 (one Q8_0 block per K chunk)
1492// * rows % 16 == 0 (tile-aligned; true for all supported architectures)
1493// * num_tokens may be any value; partial row at the tile boundary is handled
1494// via a scratch copy path.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

// Kernel contract (as read from the code below):
//   outputs[tok * rows + row] = sum_k inputs[tok * cols + k] * W[row, k]
//   where W is the Q8_0 matrix, dequantized one 32-wide block at a time.
// Threadgroup: 128 threads = 4 simdgroups (the cooperative loads hand each
// thread 4 of the 512 tile elements, so exactly 128 threads are required).
// Grid: tgid.x selects the 16-row tile, tgid.y selects the 16-token tile.
kernel void matmul_q8_mma(
    device const uchar* matrix [[buffer(0)]], // Q8_0 [rows, cols/32 * 34]
    device const float* inputs [[buffer(1)]], // [M, cols]
    device float* outputs [[buffer(2)]], // [M, rows]
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // The condition depends only on tgid and constant sizes, so it is the same
    // for every thread in the threadgroup; nobody is left waiting at the
    // threadgroup_barrier calls below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8; // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8; // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 block = 34 bytes: a 2-byte half scale followed by 32 int8 values.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // One iteration per 32-wide K chunk (exactly one Q8_0 block per weight row).
    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Token slots at or past num_tokens are zero-filled, so padded token
        // rows contribute zeros to the full 8×8 MMA fragments below.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Both tiles must be complete before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]   (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]   (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B,
                           w_tile + sg_row_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Stop the next iteration's cooperative loads from overwriting the
        // tiles while another simdgroup may still be reading them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial row at the last token tile: stage in per-simdgroup scratch
        // and scalar-copy the valid rows. Each simdgroup uses only its own
        // 64-float slice (offset simd_id * 64), so a simdgroup_barrier is
        // enough before lane 0 reads it back.
        threadgroup float scratch[4 * 64];
        simdgroup_store(C, scratch + simd_id * 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;   // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
    // Simdgroups whose first token is already >= num_tokens write nothing.
}
1606
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* 8×8 accumulators side by side along the row axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half   (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0) -> C_a
//     (sg_tok_idx*8, sg_row_half*16 + 8) -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration: 2 MMAs per 3 fragment loads, vs 1 MMA per 2 loads in the 16×16
// single-accumulator kernel. Because each threadgroup covers 4× the output
// area of the 16×16 tile, it needs a quarter as many threadgroups.
//
// Requires exactly 256 threads per threadgroup: the cooperative loads hand
// each thread 4 of the 1024 tile elements.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Same for every thread in the threadgroup (tgid and constant sizes only),
    // so this early exit cannot strand threads at the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx = simd_id / 2;    // 0..3
    uint sg_row_half = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 block = 34 bytes: a 2-byte half scale followed by 32 int8 values.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load; token slots past num_tokens are zeroed.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles must be fully written before any simdgroup loads fragments.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            // B fragments are loaded transposed so B[k, r] = w_tile[r * 32 + k].
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Prevent the next iteration's loads from overwriting tiles in use.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators. rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Each simdgroup stages into its own 128-float slice (simd_id * 128),
        // so a simdgroup-level barrier is sufficient before lane 0 copies out.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;   // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
    // Simdgroups whose first token is already >= num_tokens write nothing.
}
1747
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory, halving the tile footprint (4 KB total vs 8 KB). The smaller
// footprint may let more threadgroups run concurrently per GPU core on
// Apple Silicon.
//
// Precision: each weight is computed as int8 × (f16 block scale) in float
// and then rounded to f16. Token activations are narrowed f32 → f16 on load.
// The accumulators C_a / C_b stay `float`, so only the MMA operands, not the
// running sums, are rounded to half.
// NOTE(review): an activation with |x| > 65504 (f16 max) becomes inf on the
// narrowing load; the float accumulator does not protect against that.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each, 256 threads per threadgroup. Intended win vs mma32 is
// occupancy at moderate prefill lengths where the GPU is wave-starved.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit (tgid and constant sizes only).
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // simd_id = 2*sg_tok_idx + sg_row_half; same sub-tile layout as mma32.
    uint sg_tok_idx = simd_id / 2;
    uint sg_row_half = simd_id % 2;
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 block = 34 bytes: a 2-byte half scale followed by 32 int8 values.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (4 of 1024 elements per thread).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); padded slots are zero.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be fully written before any simdgroup loads fragments.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Half operands, float accumulators (mixed-precision MMA).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Prevent the next iteration's loads from overwriting tiles in use.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Output store; identical to matmul_q8_mma32 (float scratch for the
    // partial last token tile, 128 floats per simdgroup).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        threadgroup float scratch[8 * 2 * 64];
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1876
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each. Per simdgroup:
//   C_00 (tok 0..8,  row 0..8)    C_01 (tok 0..8,  row 8..16)
//   C_10 (tok 8..16, row 0..8)    C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's. That is 4 MMAs per 4 fragment loads vs 2 MMAs per 3 loads in
// the 2-accumulator kernel (1.5× the MMA-per-load ratio). It also halves the
// thread count per threadgroup (128 threads), which may improve occupancy on
// Apple GPUs when the per-core thread budget, rather than threadgroup memory,
// is the limiting factor.
//
// Requires exactly 128 threads per threadgroup: the cooperative loads hand
// each thread 8 of the 1024 tile elements.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit (tgid and constant sizes only).
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    // Q8_0 block = 34 bytes: a 2-byte half scale followed by 32 int8 values.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16); padded token slots are zero.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be fully written before any simdgroup loads fragments.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevent the next iteration's loads from overwriting tiles in use.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles. Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch. Every simdgroup on
        // this path stages all 4 accumulators into its own 256-float slice;
        // lane 0 then copies only the token rows below num_tokens (possibly
        // none for a quadrant that lies entirely past the end).
        threadgroup float scratch[4 * 4 * 64];   // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base + 0,   8);
        simdgroup_store(C_01, scratch + sg_base + 64,  8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base + 0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base + 64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2031
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)). Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
// NOTE(review): besides precision, the half accumulators C_xx sum over all
// of cols with no rescaling, so any partial sum with |value| > 65504 (f16
// max) saturates to inf. Inputs that are not RMSNorm outputs should be
// checked against that bound before routing them here.
//
// Layout and thread count match matmul_q8_mma32_h4 (128 threads, 2×2 grid of
// 16×16 quadrants, 2×2 accumulators per simdgroup). Unlike h4 there is no
// direct simdgroup_store fast path: outputs are f32 while the accumulators
// are half, so every tile goes through threadgroup scratch and is widened
// by lane 0.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit (tgid and constant sizes only).
    if (row_base >= rows || tok_base >= num_tokens) return;

    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    // Q8_0 block = 34 bytes: a 2-byte half scale followed by 32 int8 values.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant to half: 128 threads × 8 = 1024 elements.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        // Cooperative token tile load (f32 → f16); padded token slots are zero.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        // Tiles must be fully written before any simdgroup loads fragments.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevent the next iteration's loads from overwriting tiles in use.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;

    // Each simdgroup owns a 256-half slice (simd_id * 256), so a
    // simdgroup-level barrier suffices before lane 0 reads it back.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base + 0,   8);
    simdgroup_store(C_01, scratch + sg_base + 64,  8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);
    uint lane = tid % 32;
    if (lane == 0) {
        // Serial copy by one lane; rows past num_tokens are skipped.
        for (uint m = 0; m < 8; m++) {
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base + 0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2169
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add of a length-`rows` bias vector onto every token row
// of a row-major [num_tokens, rows] buffer:
//   out[token, i] += bias[i]   for i in 0..rows, token in 0..num_tokens
// Used for Qwen2 QKV bias after the fused qkv matmul.
// One thread per element; threads with id >= num_tokens * rows do nothing.
kernel void add_bias_batch(
    device float* out [[buffer(0)]],          // [num_tokens, rows]
    device const float* bias [[buffer(1)]],   // [rows]
    constant uint& num_tokens [[buffer(2)]],
    constant uint& rows [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // The bounds check comes first so `id % rows` is never evaluated when
    // the buffer is empty (rows == 0 or num_tokens == 0).
    if (id < num_tokens * rows) {
        out[id] += bias[id % rows];
    }
}
2186
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// outputs[token * rows + row] = sum_k inputs[token * cols + k] * W[row, k]
// Each threadgroup handles one token (tgid / row_tgs) and one slice of
// Q4_ROWS_PER_TG rows (tgid % row_tgs). Each simdgroup computes 4 consecutive
// rows, with its 32 lanes striding over the Q4_0 blocks and a simd_sum
// reduction at the end.
// NOTE(review): the code computes exactly 4 rows per simdgroup, so
// Q4_ROWS_PER_SG is presumably 4. The cooperative input load assumes 256
// threads per threadgroup and cols <= VEC_TILE_SIZE. Confirm both against
// the dispatch code.
//
// Q4_0 block = 18 bytes: a 2-byte half scale, then 16 bytes of nibbles.
// For byte j (0..15), the low nibble is element j and the high nibble is
// element j + 16; each nibble is offset by -8.
//
// Rows past the end of the matrix (when rows % 4 != 0) never have their
// sums stored. Their weight pointers are clamped to the last valid row so
// that every device read stays in bounds.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix [[buffer(0)]],   // Q4_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]],   // [M, cols] input batch
    device float* outputs [[buffer(2)]],        // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]],    // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Threadgroup-uniform exit, taken before the barrier below.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    // No barriers follow, so a per-simdgroup exit is safe here.
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // row_base < rows here, so rows >= 1 and last_row does not underflow.
    // Clamping keeps the reads for rows past the end in bounds; their sums
    // are dropped by the guarded stores at the bottom.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;
        uint vb = blk * 32;

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 hold elements 0..15 (low nibbles), v4..v7 elements 16..31 (high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // Apply the per-block scale once per block rather than per element.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Reduce each row's partial sums across the 32 lanes of the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows)     output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2287
2288// ── copy_kv_batch ─────────────────────────────────────────────────────
2289// Copy K or V from a strided batch QKV buffer to the KV cache.
2290// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
2291// dst layout: contiguous [max_seq, kv_dim] cache.
2292kernel void copy_kv_batch(
2293 device const float* src [[buffer(0)]], // batch QKV buffer (f32)
2294 device half* dst [[buffer(1)]], // KV cache (f16)
2295 constant uint& M [[buffer(2)]], // num tokens in batch
2296 constant uint& kv_dim [[buffer(3)]], // elements per KV vector
2297 constant uint& base_pos [[buffer(4)]], // starting position in cache
2298 constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
2299 constant uint& src_offset [[buffer(6)]], // float offset within each src row
2300 uint id [[thread_position_in_grid]])
2301{
2302 uint total = M * kv_dim;
2303 if (id >= total) return;
2304 uint token = id / kv_dim;
2305 uint d = id % kv_dim;
2306 uint dst_off = (base_pos + token) * kv_dim + d;
2307 uint src_off = token * src_stride + src_offset + d;
2308 dst[dst_off] = half(src[src_off]);
2309}
2310
2311// ── attention_batch ───────────────────────────────────────────────────
2312// Batched causal attention for prefill. Processes M tokens in one dispatch.
2313// Each threadgroup handles one (token, head) pair.
2314// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
2315// Causal masking: token i can only attend to positions 0..base_pos+i.
2316kernel void attention_batch(
2317 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2318 device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
2319 device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
2320 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2321 constant uint& M [[buffer(4)]], // num tokens in batch
2322 constant uint& base_pos [[buffer(5)]], // starting position in KV cache
2323 constant uint& num_heads [[buffer(6)]],
2324 constant uint& num_kv_heads [[buffer(7)]],
2325 constant uint& head_dim [[buffer(8)]],
2326 constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
2327 uint tgid [[threadgroup_position_in_grid]],
2328 uint tid [[thread_index_in_threadgroup]],
2329 uint simd_lane [[thread_index_in_simdgroup]],
2330 uint simd_id [[simdgroup_index_in_threadgroup]])
2331{
2332 // Grid: M * num_heads threadgroups
2333 uint token_idx = tgid / num_heads;
2334 uint head = tgid % num_heads;
2335 if (token_idx >= M) return;
2336
2337 uint kv_head = head / (num_heads / num_kv_heads);
2338 uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx
2339
2340 // Q offset uses strided layout (from batch QKV buffer)
2341 uint q_off = token_idx * q_stride + head * head_dim;
2342 // Output is contiguous [M, num_heads * head_dim]
2343 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2344
2345 // Shared memory for attention scores — sized to the effective max_seq_len
2346 // (4096 for all supported models) so long-context attention doesn't overflow.
2347 threadgroup float scores[ATTN_SCORES_SIZE];
2348
2349 // Step 1: Q * K^T with simdgroup reduction
2350 for (uint s = simd_id; s < seq_len; s += 8) {
2351 uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
2352 float dot = 0.0;
2353 for (uint d = simd_lane; d < head_dim; d += 32) {
2354 dot += q_batch[q_off + d] * k_cache[k_off + d];
2355 }
2356 dot = simd_sum(dot);
2357 if (simd_lane == 0) {
2358 scores[s] = dot * fast::rsqrt(float(head_dim));
2359 }
2360 }
2361 threadgroup_barrier(mem_flags::mem_threadgroup);
2362
2363 // Step 2: Softmax (cooperative)
2364 float local_max = -INFINITY;
2365 for (uint s = tid; s < seq_len; s += 256) {
2366 local_max = max(local_max, scores[s]);
2367 }
2368 local_max = simd_max(local_max);
2369 threadgroup float shared_max[8];
2370 if (simd_lane == 0) shared_max[simd_id] = local_max;
2371 threadgroup_barrier(mem_flags::mem_threadgroup);
2372 if (tid == 0) {
2373 float m = shared_max[0];
2374 for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
2375 shared_max[0] = m;
2376 }
2377 threadgroup_barrier(mem_flags::mem_threadgroup);
2378 float max_val = shared_max[0];
2379
2380 float local_sum = 0.0;
2381 for (uint s = tid; s < seq_len; s += 256) {
2382 scores[s] = fast::exp(scores[s] - max_val);
2383 local_sum += scores[s];
2384 }
2385 local_sum = simd_sum(local_sum);
2386 threadgroup float shared_sum[8];
2387 if (simd_lane == 0) shared_sum[simd_id] = local_sum;
2388 threadgroup_barrier(mem_flags::mem_threadgroup);
2389 if (tid == 0) {
2390 float total = 0.0;
2391 for (uint i = 0; i < 8; i++) total += shared_sum[i];
2392 shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
2393 }
2394 threadgroup_barrier(mem_flags::mem_threadgroup);
2395 float inv_sum = shared_sum[0];
2396 for (uint s = tid; s < seq_len; s += 256) {
2397 scores[s] *= inv_sum;
2398 }
2399 threadgroup_barrier(mem_flags::mem_threadgroup);
2400
2401 // Step 3: scores * V using float4 vectorized loads
2402 // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
2403 // This is much better than the scalar version where only 64 of 256 threads are active.
2404 uint v_stride = num_kv_heads * head_dim;
2405 uint head_dim4 = head_dim / 4;
2406 for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
2407 uint d = d4 * 4;
2408 float4 acc = float4(0.0);
2409 uint v_base = kv_head * head_dim + d;
2410 uint seq_len4 = seq_len & ~3u;
2411 for (uint s = 0; s < seq_len4; s += 4) {
2412 float sc0 = scores[s];
2413 float sc1 = scores[s + 1];
2414 float sc2 = scores[s + 2];
2415 float sc3 = scores[s + 3];
2416 acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
2417 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
2418 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
2419 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
2420 }
2421 for (uint s = seq_len4; s < seq_len; s++) {
2422 acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
2423 }
2424 *(device float4*)(output_batch + out_off + d) = acc;
2425 }
2426 // Handle remaining dimensions not divisible by 4 (scalar fallback)
2427 for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
2428 float acc = 0.0;
2429 uint v_base = kv_head * head_dim + d;
2430 for (uint s = 0; s < seq_len; s++) {
2431 acc += scores[s] * v_cache[s * v_stride + v_base];
2432 }
2433 output_batch[out_off + d] = acc;
2434 }
2435}
2436
2437// ── attention_flash_batch ─────────────────────────────────────────────
2438// Streaming attention with online softmax. Same grid as attention_batch
2439// (M × num_heads threadgroups, one per (token, head) pair) but the scores
2440// matrix is never materialized. K/V positions are processed in a tile of
2441// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
2442// the standard flash-attention recurrence:
2443//
2444// m_new = max(m_old, tile_max)
2445// alpha = exp(m_old - m_new)
2446// l_new = alpha * l_old + sum(exp(S - m_new))
2447// O_new = alpha * O_old + sum(exp(S - m_new) * V)
2448// O_final = O / l
2449//
2450// This removes the `scores[2048]` cap in attention_batch (which silently
2451// overflows for prompts with seq_len > 2048) and keeps threadgroup memory
2452// use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
2453//
2454// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
2455constant constexpr uint FLASH_K_TILE = 32;
2456constant constexpr uint FLASH_MAX_HEAD_DIM = 256;
2457
2458kernel void attention_flash_batch(
2459 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2460 device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
2461 device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
2462 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2463 constant uint& M [[buffer(4)]],
2464 constant uint& base_pos [[buffer(5)]],
2465 constant uint& num_heads [[buffer(6)]],
2466 constant uint& num_kv_heads [[buffer(7)]],
2467 constant uint& head_dim [[buffer(8)]],
2468 constant uint& q_stride [[buffer(9)]],
2469 uint tgid [[threadgroup_position_in_grid]],
2470 uint tid [[thread_index_in_threadgroup]],
2471 uint simd_lane [[thread_index_in_simdgroup]],
2472 uint simd_id [[simdgroup_index_in_threadgroup]])
2473{
2474 uint token_idx = tgid / num_heads;
2475 uint head = tgid % num_heads;
2476 if (token_idx >= M) return;
2477
2478 uint kv_head = head / (num_heads / num_kv_heads);
2479 uint seq_len = base_pos + token_idx + 1; // causal: attend to [0, base_pos + token_idx]
2480 uint q_off = token_idx * q_stride + head * head_dim;
2481 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2482
2483 // Threadgroup state:
2484 // q_sh: Q vector for this (token, head), loaded once
2485 // o_sh: running output vector, updated each K tile
2486 // scores_sh: scores for the current K tile only
2487 // stats: [running max, running sum] (see flash-attention recurrence)
2488 threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
2489 threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
2490 threadgroup float scores_sh[FLASH_K_TILE];
2491 threadgroup float stats[2];
2492 threadgroup float sg_scratch[8]; // simdgroup-level reduction buffer
2493
2494 // --- Load Q (one row) and zero the running O ---
2495 for (uint d = tid; d < head_dim; d += 256) {
2496 q_sh[d] = q_batch[q_off + d];
2497 o_sh[d] = 0.0f;
2498 }
2499 if (tid == 0) {
2500 stats[0] = -INFINITY;
2501 stats[1] = 0.0f;
2502 }
2503 threadgroup_barrier(mem_flags::mem_threadgroup);
2504
2505 float scale = fast::rsqrt(float(head_dim));
2506 uint v_stride = num_kv_heads * head_dim;
2507 uint v_base = kv_head * head_dim;
2508
2509 // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
2510 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
2511 uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
2512
2513 // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
2514 // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
2515 for (uint ti = simd_id; ti < tile_n; ti += 8) {
2516 uint k_off = (kv_base + ti) * v_stride + v_base; // same layout as V stride
2517 float dot = 0.0f;
2518 for (uint d = simd_lane; d < head_dim; d += 32) {
2519 dot += q_sh[d] * k_cache[k_off + d];
2520 }
2521 dot = simd_sum(dot);
2522 if (simd_lane == 0) {
2523 scores_sh[ti] = dot * scale;
2524 }
2525 }
2526 threadgroup_barrier(mem_flags::mem_threadgroup);
2527
2528 // [2] Tile max via cooperative reduction.
2529 float local_max = -INFINITY;
2530 for (uint s = tid; s < tile_n; s += 256) {
2531 local_max = max(local_max, scores_sh[s]);
2532 }
2533 local_max = simd_max(local_max);
2534 if (simd_lane == 0) {
2535 sg_scratch[simd_id] = local_max;
2536 }
2537 threadgroup_barrier(mem_flags::mem_threadgroup);
2538 // [3] Merge with running max, compute alpha, rescale running l.
2539 float m_new;
2540 float alpha;
2541 if (tid == 0) {
2542 float tile_max = sg_scratch[0];
2543 for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
2544 float m_old = stats[0];
2545 m_new = max(m_old, tile_max);
2546 // First iteration: m_old = -inf → alpha = 0 (reset O).
2547 alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2548 stats[0] = m_new;
2549 stats[1] *= alpha;
2550 // Broadcast via sg_scratch.
2551 sg_scratch[0] = alpha;
2552 sg_scratch[1] = m_new;
2553 }
2554 threadgroup_barrier(mem_flags::mem_threadgroup);
2555 alpha = sg_scratch[0];
2556 m_new = sg_scratch[1];
2557
2558 // [4] Rescale running output by alpha, then compute exp(scores - m_new).
2559 for (uint d = tid; d < head_dim; d += 256) {
2560 o_sh[d] *= alpha;
2561 }
2562 for (uint s = tid; s < tile_n; s += 256) {
2563 scores_sh[s] = fast::exp(scores_sh[s] - m_new);
2564 }
2565 threadgroup_barrier(mem_flags::mem_threadgroup);
2566
2567 // [5] Tile sum → update running l.
2568 float local_sum = 0.0f;
2569 for (uint s = tid; s < tile_n; s += 256) {
2570 local_sum += scores_sh[s];
2571 }
2572 local_sum = simd_sum(local_sum);
2573 if (simd_lane == 0) {
2574 sg_scratch[simd_id] = local_sum;
2575 }
2576 threadgroup_barrier(mem_flags::mem_threadgroup);
2577 if (tid == 0) {
2578 float tile_sum = 0.0f;
2579 for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
2580 stats[1] += tile_sum;
2581 }
2582 threadgroup_barrier(mem_flags::mem_threadgroup);
2583
2584 // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
2585 for (uint d = tid; d < head_dim; d += 256) {
2586 float acc = 0.0f;
2587 for (uint s = 0; s < tile_n; s++) {
2588 acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
2589 }
2590 o_sh[d] += acc;
2591 }
2592 threadgroup_barrier(mem_flags::mem_threadgroup);
2593 }
2594
2595 // --- Normalize and write output ---
2596 float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
2597 for (uint d = tid; d < head_dim; d += 256) {
2598 output_batch[out_off + d] = o_sh[d] * inv_l;
2599 }
2600}
2601
2602// ── attention_mma_flash_batch ─────────────────────────────────────────
2603// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2604// both Q·K^T and P·V. Processes Q_BLOCK=8 tokens of one head per
2605// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2606// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2607// arithmetic.
2608//
2609// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
2610// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
2611// to attention_batch / attention_flash_batch otherwise.
2612//
2613// Online softmax recurrence is identical to attention_flash_batch but
2614// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
2615constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
2616constant constexpr uint FLASH_MMA_K_BLOCK = 32;
2617constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;
2618
2619kernel void attention_mma_flash_batch(
2620 device const float* q_batch [[buffer(0)]],
2621 device const half* k_cache [[buffer(1)]],
2622 device const half* v_cache [[buffer(2)]],
2623 device float* output_batch [[buffer(3)]],
2624 constant uint& M [[buffer(4)]],
2625 constant uint& base_pos [[buffer(5)]],
2626 constant uint& num_heads [[buffer(6)]],
2627 constant uint& num_kv_heads [[buffer(7)]],
2628 constant uint& head_dim [[buffer(8)]],
2629 constant uint& q_stride [[buffer(9)]],
2630 uint2 tgid [[threadgroup_position_in_grid]],
2631 uint tid [[thread_index_in_threadgroup]],
2632 uint simd_lane [[thread_index_in_simdgroup]],
2633 uint simd_id [[simdgroup_index_in_threadgroup]])
2634{
2635 uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2636 uint head = tgid.y;
2637 if (q_block_start >= M) return;
2638 uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2639
2640 uint kv_head = head / (num_heads / num_kv_heads);
2641 // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2642 // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2643 uint seq_len = base_pos + q_block_start + q_valid;
2644 float scale = fast::rsqrt(float(head_dim));
2645
2646 uint kv_stride = num_kv_heads * head_dim;
2647 uint kv_base_off = kv_head * head_dim;
2648
2649 // ── Threadgroup memory ──
2650 // q_sh: [Q_BLOCK, head_dim] half — Q tile, loaded once
2651 // k_sh: [K_BLOCK, head_dim] half — K tile, refreshed per kv_base iter
2652 // v_sh: [K_BLOCK, head_dim] half — V tile, refreshed per kv_base iter
2653 // s_sh: [Q_BLOCK, K_BLOCK] float — raw Q·K^T scores, then scaled+masked
2654 // p_sh: [Q_BLOCK, K_BLOCK] half — softmax probabilities (for P·V MMA)
2655 // o_sh: [Q_BLOCK, head_dim] float — running output accumulator
2656 // m_sh: [Q_BLOCK] float — running max per Q row
2657 // l_sh: [Q_BLOCK] float — running softmax denominator per Q row
2658 // scratch: 4*Q_BLOCK floats — per-row reduction scratch
2659 threadgroup half q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2660 threadgroup half k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2661 threadgroup half v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2662 threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2663 threadgroup half p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2664 threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2665 threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2666 threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2667 threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2668
2669 // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2670 uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2671 for (uint i = tid; i < qblock_elems; i += 128) {
2672 uint q = i / head_dim;
2673 uint d = i % head_dim;
2674 if (q < q_valid) {
2675 uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2676 q_sh[q * head_dim + d] = half(q_batch[q_off]);
2677 } else {
2678 q_sh[q * head_dim + d] = half(0);
2679 }
2680 o_sh[q * head_dim + d] = 0.0f;
2681 }
2682 if (tid < FLASH_MMA_Q_BLOCK) {
2683 m_sh[tid] = -INFINITY;
2684 l_sh[tid] = 0.0f;
2685 }
2686 threadgroup_barrier(mem_flags::mem_threadgroup);
2687
2688 // ── Stream K/V in K_BLOCK chunks ──
2689 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2690 uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2691
2692 // Load K and V tile into TG memory (as half).
2693 uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2694 for (uint i = tid; i < kv_tile_elems; i += 128) {
2695 uint k_pos = i / head_dim;
2696 uint d = i % head_dim;
2697 if (k_pos < tile_n) {
2698 uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2699 k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2700 v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2701 } else {
2702 k_sh[k_pos * head_dim + d] = half(0);
2703 v_sh[k_pos * head_dim + d] = half(0);
2704 }
2705 }
2706 threadgroup_barrier(mem_flags::mem_threadgroup);
2707
2708 // ── Phase 1: S = Q @ K^T via MMA ──
2709 // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2710 // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2711 // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2712 {
2713 simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2714 uint dim_chunks = head_dim / 8;
2715 for (uint dc = 0; dc < dim_chunks; dc++) {
2716 simdgroup_matrix<half, 8, 8> A, B;
2717 // A = Q[0:8, dc*8 : dc*8+8] (rows of Q, no transpose)
2718 simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2719 // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2720 // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2721 // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2722 // transpose=true, which places it in the register as K^T of that sub-block.
2723 simdgroup_load(B,
2724 k_sh + (simd_id * 8) * head_dim + dc * 8,
2725 head_dim, ulong2(0, 0), true);
2726 simdgroup_multiply_accumulate(C, A, B, C);
2727 }
2728 // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2729 simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2730 }
2731 threadgroup_barrier(mem_flags::mem_threadgroup);
2732
2733 // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2734 // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2735 uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2736 for (uint i = tid; i < s_elems; i += 128) {
2737 uint q = i / FLASH_MMA_K_BLOCK;
2738 uint k = i % FLASH_MMA_K_BLOCK;
2739 uint global_q = q_block_start + q;
2740 uint global_kv = kv_base + k;
2741 bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2742 s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2743 }
2744 threadgroup_barrier(mem_flags::mem_threadgroup);
2745
2746 // ── Phase 2b: per-row max via simdgroup reduction ──
2747 // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2748 // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2749 {
2750 uint row_base = simd_id * 2;
2751 for (uint qr = 0; qr < 2; qr++) {
2752 uint q = row_base + qr;
2753 float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2754 float row_max = simd_max(my);
2755 if (simd_lane == 0) {
2756 scratch[q] = row_max; // tile_max[q]
2757 }
2758 }
2759 }
2760 threadgroup_barrier(mem_flags::mem_threadgroup);
2761
2762 // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2763 if (tid < FLASH_MMA_Q_BLOCK) {
2764 uint q = tid;
2765 float m_old = m_sh[q];
2766 float tile_max = scratch[q];
2767 float m_new = max(m_old, tile_max);
2768 float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2769 m_sh[q] = m_new;
2770 l_sh[q] = l_sh[q] * alpha;
2771 // scratch[q] = m_new (for phase 2d)
2772 // scratch[Q_BLOCK + q] = alpha (for phase 3)
2773 scratch[q] = m_new;
2774 scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2775 }
2776 threadgroup_barrier(mem_flags::mem_threadgroup);
2777
2778 // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2779 {
2780 uint row_base = simd_id * 2;
2781 for (uint qr = 0; qr < 2; qr++) {
2782 uint q = row_base + qr;
2783 float m_new = scratch[q];
2784 float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2785 p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2786 float row_sum = simd_sum(p);
2787 if (simd_lane == 0) {
2788 scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2789 }
2790 }
2791 }
2792 threadgroup_barrier(mem_flags::mem_threadgroup);
2793
2794 // ── Phase 2e: l_sh += tile_sum ──
2795 if (tid < FLASH_MMA_Q_BLOCK) {
2796 uint q = tid;
2797 l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2798 }
2799
2800 // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2801 threadgroup_barrier(mem_flags::mem_threadgroup);
2802 for (uint i = tid; i < qblock_elems; i += 128) {
2803 uint q = i / head_dim;
2804 float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2805 o_sh[i] *= alpha;
2806 }
2807 threadgroup_barrier(mem_flags::mem_threadgroup);
2808
2809 // ── Phase 4: O += P @ V via MMA ──
2810 // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2811 // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2812 // For head_dim=64: 16 dims/sg = 2 tiles. head_dim=128: 32 dims/sg = 4 tiles.
2813 {
2814 uint dims_per_sg = head_dim / 4; // 16 or 32
2815 uint tiles_per_sg = dims_per_sg / 8; // 2 or 4
2816 uint sg_d_base = simd_id * dims_per_sg;
2817 for (uint t = 0; t < tiles_per_sg; t++) {
2818 uint d_base = sg_d_base + t * 8;
2819 simdgroup_matrix<float, 8, 8> O_acc;
2820 simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2821 uint k_chunks = FLASH_MMA_K_BLOCK / 8; // 4
2822 for (uint kc = 0; kc < k_chunks; kc++) {
2823 simdgroup_matrix<half, 8, 8> A, B;
2824 simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2825 ulong2(0, 0), false);
2826 simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2827 ulong2(0, 0), false);
2828 simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2829 }
2830 simdgroup_store(O_acc, o_sh + d_base, head_dim);
2831 }
2832 }
2833 threadgroup_barrier(mem_flags::mem_threadgroup);
2834 }
2835
2836 // ── Finalize: O /= l, write to output ──
2837 for (uint i = tid; i < qblock_elems; i += 128) {
2838 uint q = i / head_dim;
2839 uint d = i % head_dim;
2840 if (q < q_valid) {
2841 float l = l_sh[q];
2842 float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2843 uint token_idx = q_block_start + q;
2844 uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2845 output_batch[out_off] = o_sh[i] * inv_l;
2846 }
2847 }
2848}
2849
2850// ── rope_qk_batch ─────────────────────────────────────────────────────
2851// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2852// launch + memory barrier per layer. Both Q and K live in the same
2853// qkv_data buffer at different offsets within each token's row.
2854// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2855kernel void rope_qk_batch(
2856 device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
2857 constant uint& M [[buffer(1)]], // num tokens
2858 constant uint& base_pos [[buffer(2)]], // starting position
2859 constant uint& num_q_heads [[buffer(3)]],
2860 constant uint& num_kv_heads [[buffer(4)]],
2861 constant uint& head_dim [[buffer(5)]],
2862 constant uint& qkv_stride [[buffer(6)]], // floats per row
2863 constant float& theta [[buffer(7)]],
2864 uint id [[thread_position_in_grid]])
2865{
2866 uint half_dim = head_dim / 2;
2867 uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2868 uint token = id / total_pairs;
2869 uint pair = id % total_pairs;
2870 if (token >= M) return;
2871
2872 uint pos = base_pos + token;
2873 uint q_pairs = num_q_heads * half_dim;
2874
2875 uint h, i, offset;
2876 if (pair < q_pairs) {
2877 // Q head
2878 h = pair / half_dim;
2879 i = pair % half_dim;
2880 offset = token * qkv_stride + h * head_dim + i * 2;
2881 } else {
2882 // K head
2883 uint kp = pair - q_pairs;
2884 h = kp / half_dim;
2885 i = kp % half_dim;
2886 uint k_start = num_q_heads * head_dim;
2887 offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2888 }
2889
2890 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2891 float angle = float(pos) * freq;
2892 float cos_val = cos(angle);
2893 float sin_val = sin(angle);
2894
2895 float x0 = qkv_data[offset];
2896 float x1 = qkv_data[offset + 1];
2897 qkv_data[offset] = x0 * cos_val - x1 * sin_val;
2898 qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2899}
2900
2901// ── copy_kv_both_batch ────────────────────────────────────────────────
2902// Fused K+V cache copy in a single dispatch: copies both K and V from
2903// the strided batch QKV buffer to their respective KV cache buffers.
2904// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2905kernel void copy_kv_both_batch(
2906 device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride] f32
2907 device half* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim] f16
2908 device half* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim] f16
2909 constant uint& M [[buffer(3)]], // num tokens in batch
2910 constant uint& kv_dim [[buffer(4)]], // elements per KV vector
2911 constant uint& base_pos [[buffer(5)]], // starting position in cache
2912 constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
2913 constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
2914 constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
2915 uint id [[thread_position_in_grid]])
2916{
2917 // Total elements = M * kv_dim * 2 (K + V)
2918 uint total_kv = M * kv_dim;
2919 if (id >= total_kv * 2) return;
2920
2921 uint is_v = id / total_kv; // 0 = K, 1 = V
2922 uint local_id = id % total_kv;
2923 uint token = local_id / kv_dim;
2924 uint d = local_id % kv_dim;
2925
2926 uint dst_off = (base_pos + token) * kv_dim + d;
2927 uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2928
2929 if (is_v) {
2930 v_dst[dst_off] = half(src[src_off]);
2931 } else {
2932 k_dst[dst_off] = half(src[src_off]);
2933 }
2934}
2935"#
2936 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2937 .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2938}
2939
2940fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2945 let mut code = String::with_capacity(48 * 1024);
2946 emit_model_header(&mut code, config)?;
2947 emit_metal_model_struct(&mut code, config)?;
2948 emit_layer_buffers_struct(&mut code, config)?;
2949 emit_metal_model_impl(&mut code, config)?;
2950 emit_helper_functions(&mut code)?;
2951 Ok(code)
2952}
2953
2954fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2955 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2956 writeln!(
2957 code,
2958 "//! Model: {} ({} layers, hidden={})",
2959 config.architecture, config.num_layers, config.hidden_size
2960 )?;
2961 writeln!(code, "//!")?;
2962 writeln!(
2963 code,
2964 "//! Uses native Metal compute pipelines via the metal crate."
2965 )?;
2966 writeln!(code)?;
2967 writeln!(code, "#![allow(dead_code)]")?;
2968 writeln!(code)?;
2969 writeln!(code, "use metal::*;")?;
2970 writeln!(code, "#[allow(unused_imports)]")?;
2971 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2972 writeln!(code, "use std::mem;")?;
2973 writeln!(code)?;
2974
2975 writeln!(
2977 code,
2978 "// ── Model constants ──────────────────────────────────"
2979 )?;
2980 writeln!(
2981 code,
2982 "pub const HIDDEN_SIZE: usize = {};",
2983 config.hidden_size
2984 )?;
2985 writeln!(
2986 code,
2987 "pub const INTERMEDIATE_SIZE: usize = {};",
2988 config.intermediate_size
2989 )?;
2990 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2991 writeln!(
2992 code,
2993 "pub const NUM_HEADS: usize = {};",
2994 config.num_attention_heads
2995 )?;
2996 writeln!(
2997 code,
2998 "pub const NUM_KV_HEADS: usize = {};",
2999 config.num_kv_heads
3000 )?;
3001 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3002 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3003 let effective_seq_len = config.max_seq_len.min(4096);
3004 writeln!(
3005 code,
3006 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
3007 effective_seq_len, config.max_seq_len
3008 )?;
3009 writeln!(
3010 code,
3011 "pub const RMS_NORM_EPS: f32 = {:e};",
3012 config.rms_norm_eps
3013 )?;
3014 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3015 writeln!(
3016 code,
3017 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3018 )?;
3019 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3020 writeln!(code)?;
3021
3022 Ok(())
3023}
3024
3025fn emit_metal_model_struct(
3026 code: &mut String,
3027 config: &ModelConfig,
3028) -> Result<(), MetalCodegenError> {
3029 writeln!(
3030 code,
3031 "// ── MetalModel ──────────────────────────────────────────"
3032 )?;
3033 writeln!(code)?;
3034 writeln!(
3035 code,
3036 "/// Metal-accelerated transformer model for Apple Silicon."
3037 )?;
3038 writeln!(code, "///")?;
3039 writeln!(
3040 code,
3041 "/// Uses unified memory for zero-copy weight access and native Metal"
3042 )?;
3043 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3044 writeln!(code, "pub struct MetalModel {{")?;
3045 writeln!(code, " device: Device,")?;
3046 writeln!(code, " queue: CommandQueue,")?;
3047 writeln!(code)?;
3048 writeln!(code, " // ── Compute pipelines ──")?;
3049 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
3050 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
3051 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
3052 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
3053 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
3054 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
3055 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
3056 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
3057 writeln!(code, " add_pipeline: ComputePipelineState,")?;
3058 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
3059 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
3060 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
3061 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
3062 writeln!(
3063 code,
3064 " copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3065 )?;
3066 writeln!(code)?;
3067 writeln!(code, " // ── Batched prefill pipelines ──")?;
3068 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
3069 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
3070 writeln!(
3071 code,
3072 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3073 )?;
3074 writeln!(code, " matmul_q8_mma_pipeline: ComputePipelineState,")?;
3075 writeln!(code, " matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3076 writeln!(
3077 code,
3078 " matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3079 )?;
3080 writeln!(
3081 code,
3082 " matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3083 )?;
3084 writeln!(
3085 code,
3086 " matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3087 )?;
3088 if config.qkv_bias {
3089 writeln!(code, " add_bias_batch_pipeline: ComputePipelineState,")?;
3090 }
3091 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
3092 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
3093 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
3094 writeln!(
3095 code,
3096 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
3097 )?;
3098 writeln!(
3099 code,
3100 " add_inplace_batch_pipeline: ComputePipelineState,"
3101 )?;
3102 writeln!(
3103 code,
3104 " copy_embedding_batch_pipeline: ComputePipelineState,"
3105 )?;
3106 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
3107 writeln!(
3108 code,
3109 " attention_flash_batch_pipeline: ComputePipelineState,"
3110 )?;
3111 writeln!(
3112 code,
3113 " attention_mma_flash_batch_pipeline: ComputePipelineState,"
3114 )?;
3115 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
3116 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
3117 writeln!(
3118 code,
3119 " copy_kv_both_batch_pipeline: ComputePipelineState,"
3120 )?;
3121 writeln!(code)?;
3122 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
3123 writeln!(
3124 code,
3125 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3126 )?;
3127 writeln!(code, " embed_buf: Buffer,")?;
3128 writeln!(code)?;
3129 writeln!(code, " /// Per-layer weight buffers")?;
3130 writeln!(code, " layers: Vec<LayerBuffers>,")?;
3131 writeln!(code)?;
3132 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
3133 writeln!(code, " norm_buf: Buffer,")?;
3134 writeln!(code)?;
3135 writeln!(
3136 code,
3137 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3138 )?;
3139 writeln!(code, " lm_head_buf: Buffer,")?;
3140 writeln!(code)?;
3141 writeln!(
3142 code,
3143 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
3144 )?;
3145 writeln!(code, " hidden_buf: Buffer,")?;
3146 writeln!(code, " residual_buf: Buffer,")?;
3147 writeln!(code, " normed_buf: Buffer,")?;
3148 writeln!(
3149 code,
3150 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3151 )?;
3152 writeln!(code, " qkv_buf: Buffer,")?;
3153 writeln!(code, " attn_out_buf: Buffer,")?;
3154 writeln!(code, " attn_proj_buf: Buffer,")?;
3155 writeln!(
3156 code,
3157 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3158 )?;
3159 writeln!(code, " gate_up_buf: Buffer,")?;
3160 writeln!(code, " ffn_hidden_buf: Buffer,")?;
3161 writeln!(code, " ffn_out_buf: Buffer,")?;
3162 writeln!(code, " add_tmp_buf: Buffer,")?;
3163 writeln!(code, " logits_buf: Buffer,")?;
3164 writeln!(code)?;
3165 writeln!(code, " // ── Batched prefill working buffers ──")?;
3166 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3167 writeln!(code, " batch_hidden_buf: Buffer,")?;
3168 writeln!(
3169 code,
3170 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3171 )?;
3172 writeln!(code, " batch_residual_buf: Buffer,")?;
3173 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3174 writeln!(code, " batch_qkv_buf: Buffer,")?;
3175 writeln!(
3176 code,
3177 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3178 )?;
3179 writeln!(code, " batch_attn_out_buf: Buffer,")?;
3180 writeln!(
3181 code,
3182 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3183 )?;
3184 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
3185 writeln!(
3186 code,
3187 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3188 )?;
3189 writeln!(code, " batch_gate_up_buf: Buffer,")?;
3190 writeln!(
3191 code,
3192 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3193 )?;
3194 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
3195 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3196 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
3197 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
3198 writeln!(code, " batch_tokens_buf: Buffer,")?;
3199 writeln!(code, " /// Positions buffer for batch RoPE")?;
3200 writeln!(code, " batch_positions_buf: Buffer,")?;
3201 writeln!(code)?;
3202 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
3203 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
3204 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
3205 writeln!(code)?;
3206 writeln!(code, " // ── Inference state ──")?;
3207 writeln!(code, " pos: usize,")?;
3208 writeln!(code)?;
3209 writeln!(
3210 code,
3211 " /// Previous command buffer for double-buffered prefill."
3212 )?;
3213 writeln!(
3214 code,
3215 " /// While the GPU executes token N, the CPU can encode token N+1."
3216 )?;
3217 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
3218 writeln!(code, "}}")?;
3219 writeln!(code)?;
3220
3221 Ok(())
3222}
3223
3224fn emit_layer_buffers_struct(
3225 code: &mut String,
3226 config: &ModelConfig,
3227) -> Result<(), MetalCodegenError> {
3228 writeln!(
3229 code,
3230 "/// Per-layer weight buffers for attention and FFN projections."
3231 )?;
3232 writeln!(code, "struct LayerBuffers {{")?;
3233 writeln!(code, " attn_norm: Buffer,")?;
3234 writeln!(
3235 code,
3236 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3237 )?;
3238 writeln!(code, " qkv_weight: Buffer,")?;
3239 if config.qkv_bias {
3240 writeln!(
3241 code,
3242 " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3243 )?;
3244 writeln!(code, " qkv_bias: Buffer,")?;
3245 }
3246 writeln!(code, " o_weight: Buffer,")?;
3247 writeln!(code, " ffn_norm: Buffer,")?;
3248 writeln!(
3249 code,
3250 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3251 )?;
3252 writeln!(code, " gate_up_weight: Buffer,")?;
3253 writeln!(code, " down_weight: Buffer,")?;
3254 writeln!(code, "}}")?;
3255 writeln!(code)?;
3256
3257 Ok(())
3258}
3259
3260fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3261 let hidden = config.hidden_size;
3262 let intermediate = config.intermediate_size;
3263 let _num_layers = config.num_layers;
3264 let _num_heads = config.num_attention_heads;
3265 let num_kv_heads = config.num_kv_heads;
3266 let head_dim = config.head_dim;
3267 let vocab = config.vocab_size;
3268 let effective_seq_len = config.max_seq_len.min(4096);
3269 let is_q8 = config.dtype == DType::Q8_0;
3270 let is_q4 = config.dtype == DType::Q4_0;
3271 let kv_dim = num_kv_heads * head_dim;
3272
3273 writeln!(code, "impl MetalModel {{")?;
3274
3275 writeln!(
3277 code,
3278 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3279 )?;
3280 writeln!(code, " ///")?;
3281 writeln!(
3282 code,
3283 " /// `weights` is the raw weight blob produced by `forge export-weights`."
3284 )?;
3285 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
3286 writeln!(
3287 code,
3288 " let device = Device::system_default().expect(\"no Metal device found\");"
3289 )?;
3290 writeln!(code, " let queue = device.new_command_queue();")?;
3291 writeln!(code)?;
3292
3293 writeln!(
3295 code,
3296 " // Compile Metal shaders from embedded source"
3297 )?;
3298 writeln!(
3299 code,
3300 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
3301 )?;
3302 writeln!(
3303 code,
3304 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3305 )?;
3306 writeln!(
3307 code,
3308 " .expect(\"failed to compile Metal shaders\");"
3309 )?;
3310 writeln!(code)?;
3311
3312 writeln!(code, " // Create compute pipelines")?;
3314 for (var, fn_name) in [
3315 ("matmul_pipeline", "matmul_vec"),
3316 ("matmul_q8_pipeline", "matmul_vec_q8"),
3317 ("matmul_q4_pipeline", "matmul_vec_q4"),
3318 ("rms_norm_pipeline", "rms_norm"),
3319 ("rope_pipeline", "rope"),
3320 ("softmax_pipeline", "softmax"),
3321 ("silu_mul_pipeline", "silu_mul"),
3322 ("silu_mul_fused_pipeline", "silu_mul_fused"),
3323 ("add_pipeline", "elementwise_add"),
3324 ("attention_pipeline", "attention"),
3325 ("add_inplace_pipeline", "add_inplace"),
3326 ("copy_pipeline", "copy_buffer"),
3327 ("copy_offset_pipeline", "copy_offset"),
3328 ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3329 ("matmul_batch_pipeline", "matmul_vec_batch"),
3330 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3331 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3332 ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3333 ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3334 ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3335 ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3336 ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3337 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3338 ("rms_norm_batch_pipeline", "rms_norm_batch"),
3339 ("rope_batch_pipeline", "rope_batch"),
3340 ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3341 ("add_inplace_batch_pipeline", "add_inplace_batch"),
3342 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3343 ("attention_batch_pipeline", "attention_batch"),
3344 ("attention_flash_batch_pipeline", "attention_flash_batch"),
3345 (
3346 "attention_mma_flash_batch_pipeline",
3347 "attention_mma_flash_batch",
3348 ),
3349 ("copy_kv_batch_pipeline", "copy_kv_batch"),
3350 ("rope_qk_batch_pipeline", "rope_qk_batch"),
3351 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3352 ] {
3353 writeln!(
3354 code,
3355 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3356 )?;
3357 }
3358 if config.qkv_bias {
3359 writeln!(
3360 code,
3361 " let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3362 )?;
3363 }
3364 writeln!(code)?;
3365
3366 writeln!(
3368 code,
3369 " // Load weights into Metal shared-memory buffers"
3370 )?;
3371 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
3372 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3373 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
3374 writeln!(code)?;
3375 writeln!(
3376 code,
3377 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
3378 )?;
3379 writeln!(code)?;
3380 writeln!(
3381 code,
3382 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3383 )?;
3384 writeln!(
3385 code,
3386 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3387 )?;
3388 writeln!(code, " let byte_len = n * f32_size;")?;
3389 writeln!(code, " let cur = cursor.get();")?;
3390 writeln!(
3391 code,
3392 " let data = &weights[cur..cur + byte_len];"
3393 )?;
3394 writeln!(code, " cursor.set(cur + byte_len);")?;
3395 writeln!(code, " device.new_buffer_with_data(")?;
3396 writeln!(code, " data.as_ptr() as *const _,")?;
3397 writeln!(code, " byte_len as u64,")?;
3398 writeln!(
3399 code,
3400 " MTLResourceOptions::StorageModeShared,"
3401 )?;
3402 writeln!(code, " )")?;
3403 writeln!(code, " }};")?;
3404 writeln!(code)?;
3405
3406 if is_q8 {
3407 writeln!(
3412 code,
3413 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3414 )?;
3415 writeln!(
3416 code,
3417 " // as raw bytes into a Metal buffer (no dequantization)."
3418 )?;
3419 writeln!(
3420 code,
3421 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3422 )?;
3423 writeln!(
3424 code,
3425 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3426 )?;
3427 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3428 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3429 writeln!(code, " let total_raw = rows * row_bytes;")?;
3430 writeln!(code, " let cur = cursor.get();")?;
3431 writeln!(
3432 code,
3433 " let data = &weights[cur..cur + total_raw];"
3434 )?;
3435 writeln!(code, " cursor.set(cur + total_raw);")?;
3436 writeln!(code, " device.new_buffer_with_data(")?;
3437 writeln!(code, " data.as_ptr() as *const _,")?;
3438 writeln!(code, " total_raw as u64,")?;
3439 writeln!(
3440 code,
3441 " MTLResourceOptions::StorageModeShared,"
3442 )?;
3443 writeln!(code, " )")?;
3444 writeln!(code, " }};")?;
3445 writeln!(code)?;
3446 writeln!(
3447 code,
3448 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3449 )?;
3450 writeln!(
3451 code,
3452 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3453 )?;
3454 writeln!(
3455 code,
3456 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3457 )?;
3458 writeln!(
3459 code,
3460 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3461 )?;
3462 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3463 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3464 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3465 writeln!(code, " let cur = cursor.get();")?;
3466 writeln!(
3467 code,
3468 " let data = &weights[cur..cur + total_raw];"
3469 )?;
3470 writeln!(code, " cursor.set(cur + total_raw);")?;
3471 writeln!(code, " device.new_buffer_with_data(")?;
3472 writeln!(code, " data.as_ptr() as *const _,")?;
3473 writeln!(code, " total_raw as u64,")?;
3474 writeln!(
3475 code,
3476 " MTLResourceOptions::StorageModeShared,"
3477 )?;
3478 writeln!(code, " )")?;
3479 writeln!(code, " }};")?;
3480 writeln!(code)?;
3481 }
3482
3483 if is_q4 {
3484 writeln!(
3489 code,
3490 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3491 )?;
3492 writeln!(
3493 code,
3494 " // as raw bytes into a Metal buffer (no dequantization)."
3495 )?;
3496 writeln!(
3497 code,
3498 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3499 )?;
3500 writeln!(
3501 code,
3502 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3503 )?;
3504 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3505 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3506 writeln!(code, " let total_raw = rows * row_bytes;")?;
3507 writeln!(code, " let cur = cursor.get();")?;
3508 writeln!(
3509 code,
3510 " let data = &weights[cur..cur + total_raw];"
3511 )?;
3512 writeln!(code, " cursor.set(cur + total_raw);")?;
3513 writeln!(code, " device.new_buffer_with_data(")?;
3514 writeln!(code, " data.as_ptr() as *const _,")?;
3515 writeln!(code, " total_raw as u64,")?;
3516 writeln!(
3517 code,
3518 " MTLResourceOptions::StorageModeShared,"
3519 )?;
3520 writeln!(code, " )")?;
3521 writeln!(code, " }};")?;
3522 writeln!(code)?;
3523 writeln!(
3524 code,
3525 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3526 )?;
3527 writeln!(
3528 code,
3529 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3530 )?;
3531 writeln!(
3532 code,
3533 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3534 )?;
3535 writeln!(
3536 code,
3537 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3538 )?;
3539 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3540 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3541 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3542 writeln!(code, " let cur = cursor.get();")?;
3543 writeln!(
3544 code,
3545 " let data = &weights[cur..cur + total_raw];"
3546 )?;
3547 writeln!(code, " cursor.set(cur + total_raw);")?;
3548 writeln!(code, " device.new_buffer_with_data(")?;
3549 writeln!(code, " data.as_ptr() as *const _,")?;
3550 writeln!(code, " total_raw as u64,")?;
3551 writeln!(
3552 code,
3553 " MTLResourceOptions::StorageModeShared,"
3554 )?;
3555 writeln!(code, " )")?;
3556 writeln!(code, " }};")?;
3557 writeln!(code)?;
3558 }
3559
3560 writeln!(
3561 code,
3562 " let embed_buf = next_f32_buffer(&device, embed_elems);"
3563 )?;
3564 writeln!(code)?;
3565
3566 writeln!(
3568 code,
3569 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3570 )?;
3571 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
3572
3573 writeln!(
3575 code,
3576 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
3577 )?;
3578
3579 let qkv_rows = hidden + 2 * kv_dim;
3580 if is_q8 {
3581 writeln!(
3583 code,
3584 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3585 )?;
3586 if config.qkv_bias {
3587 writeln!(
3588 code,
3589 " // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3590 )?;
3591 writeln!(
3592 code,
3593 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3594 )?;
3595 }
3596 writeln!(
3597 code,
3598 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3599 )?;
3600 } else if is_q4 {
3601 writeln!(
3602 code,
3603 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3604 )?;
3605 if config.qkv_bias {
3606 writeln!(
3607 code,
3608 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3609 )?;
3610 }
3611 writeln!(
3612 code,
3613 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3614 )?;
3615 } else {
3616 writeln!(
3617 code,
3618 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3619 )?;
3620 if config.qkv_bias {
3621 writeln!(
3622 code,
3623 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3624 )?;
3625 }
3626 writeln!(
3627 code,
3628 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3629 )?;
3630 }
3631
3632 writeln!(
3634 code,
3635 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3636 )?;
3637
3638 let gate_up_rows = 2 * intermediate;
3639 if is_q8 {
3640 writeln!(
3642 code,
3643 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3644 )?;
3645 writeln!(
3646 code,
3647 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3648 )?;
3649 } else if is_q4 {
3650 writeln!(
3652 code,
3653 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3654 )?;
3655 writeln!(
3656 code,
3657 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3658 )?;
3659 } else {
3660 writeln!(
3662 code,
3663 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3664 )?;
3665 writeln!(
3666 code,
3667 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3668 )?;
3669 }
3670
3671 writeln!(code, " layers.push(LayerBuffers {{")?;
3672 writeln!(code, " attn_norm,")?;
3673 writeln!(code, " qkv_weight,")?;
3674 if config.qkv_bias {
3675 writeln!(code, " qkv_bias,")?;
3676 }
3677 writeln!(code, " o_weight,")?;
3678 writeln!(code, " ffn_norm,")?;
3679 writeln!(code, " gate_up_weight,")?;
3680 writeln!(code, " down_weight,")?;
3681 writeln!(code, " }});")?;
3682 writeln!(code, " }}")?;
3683 writeln!(code)?;
3684
3685 writeln!(
3687 code,
3688 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
3689 )?;
3690 writeln!(code)?;
3691
3692 if is_q8 {
3694 writeln!(
3695 code,
3696 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3697 )?;
3698 } else if is_q4 {
3699 writeln!(
3700 code,
3701 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3702 )?;
3703 } else {
3704 writeln!(
3705 code,
3706 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3707 )?;
3708 }
3709 writeln!(code)?;
3710
3711 let hidden_bytes = hidden * 4;
3713 let _kv_dim_bytes = kv_dim * 4;
3714 let intermediate_bytes = intermediate * 4;
3715 let vocab_bytes = vocab * 4;
3716 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3719
3720 writeln!(
3721 code,
3722 " // Allocate working buffers (shared memory for zero-copy)"
3723 )?;
3724 writeln!(
3725 code,
3726 " let opts = MTLResourceOptions::StorageModeShared;"
3727 )?;
3728 writeln!(
3729 code,
3730 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3731 )?;
3732 writeln!(
3733 code,
3734 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3735 )?;
3736 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3737 writeln!(
3738 code,
3739 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3740 )?;
3741 writeln!(
3742 code,
3743 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3744 )?;
3745 writeln!(
3746 code,
3747 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3748 )?;
3749 writeln!(
3750 code,
3751 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3752 )?;
3753 writeln!(
3754 code,
3755 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3756 )?;
3757 let gate_up_buf_bytes = 2 * intermediate * 4;
3758 writeln!(
3759 code,
3760 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3761 )?;
3762 writeln!(
3763 code,
3764 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3765 )?;
3766 writeln!(
3767 code,
3768 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3769 )?;
3770 writeln!(
3771 code,
3772 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3773 )?;
3774 writeln!(
3775 code,
3776 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3777 )?;
3778 writeln!(
3779 code,
3780 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3781 )?;
3782 writeln!(code)?;
3783
3784 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3787 let batch_gate_up_bytes = 2 * intermediate * 4;
3788 let batch_intermediate_bytes = intermediate * 4;
3789 writeln!(
3790 code,
3791 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3792 )?;
3793 writeln!(
3794 code,
3795 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3796 )?;
3797 writeln!(
3798 code,
3799 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3800 )?;
3801 writeln!(
3802 code,
3803 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3804 )?;
3805 writeln!(
3806 code,
3807 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3808 )?;
3809 writeln!(
3810 code,
3811 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3812 )?;
3813 writeln!(
3814 code,
3815 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3816 )?;
3817 writeln!(
3818 code,
3819 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3820 )?;
3821 writeln!(
3822 code,
3823 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3824 )?;
3825 writeln!(
3826 code,
3827 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3828 )?;
3829 writeln!(
3830 code,
3831 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3832 )?;
3833 writeln!(code)?;
3834
3835 writeln!(code, " // KV cache buffers (per-layer)")?;
3837 writeln!(
3838 code,
3839 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3840 )?;
3841 writeln!(
3842 code,
3843 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3844 )?;
3845 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
3846 writeln!(
3847 code,
3848 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3849 )?;
3850 writeln!(
3851 code,
3852 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3853 )?;
3854 writeln!(code, " }}")?;
3855 writeln!(code)?;
3856
3857 writeln!(code, " Self {{")?;
3858 writeln!(code, " device,")?;
3859 writeln!(code, " queue,")?;
3860 writeln!(code, " matmul_pipeline,")?;
3861 writeln!(code, " matmul_q8_pipeline,")?;
3862 writeln!(code, " matmul_q4_pipeline,")?;
3863 writeln!(code, " rms_norm_pipeline,")?;
3864 writeln!(code, " rope_pipeline,")?;
3865 writeln!(code, " softmax_pipeline,")?;
3866 writeln!(code, " silu_mul_pipeline,")?;
3867 writeln!(code, " silu_mul_fused_pipeline,")?;
3868 writeln!(code, " add_pipeline,")?;
3869 writeln!(code, " attention_pipeline,")?;
3870 writeln!(code, " add_inplace_pipeline,")?;
3871 writeln!(code, " copy_pipeline,")?;
3872 writeln!(code, " copy_offset_pipeline,")?;
3873 writeln!(code, " copy_f32_to_f16_offset_pipeline,")?;
3874 writeln!(code, " matmul_batch_pipeline,")?;
3875 writeln!(code, " matmul_q8_batch_pipeline,")?;
3876 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
3877 writeln!(code, " matmul_q8_mma_pipeline,")?;
3878 writeln!(code, " matmul_q8_mma32_pipeline,")?;
3879 writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
3880 writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
3881 writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
3882 if config.qkv_bias {
3883 writeln!(code, " add_bias_batch_pipeline,")?;
3884 }
3885 writeln!(code, " matmul_q4_batch_pipeline,")?;
3886 writeln!(code, " rms_norm_batch_pipeline,")?;
3887 writeln!(code, " rope_batch_pipeline,")?;
3888 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
3889 writeln!(code, " add_inplace_batch_pipeline,")?;
3890 writeln!(code, " copy_embedding_batch_pipeline,")?;
3891 writeln!(code, " attention_batch_pipeline,")?;
3892 writeln!(code, " attention_flash_batch_pipeline,")?;
3893 writeln!(code, " attention_mma_flash_batch_pipeline,")?;
3894 writeln!(code, " copy_kv_batch_pipeline,")?;
3895 writeln!(code, " rope_qk_batch_pipeline,")?;
3896 writeln!(code, " copy_kv_both_batch_pipeline,")?;
3897 writeln!(code, " embed_buf,")?;
3898 writeln!(code, " layers,")?;
3899 writeln!(code, " norm_buf,")?;
3900 writeln!(code, " lm_head_buf,")?;
3901 writeln!(code, " hidden_buf,")?;
3902 writeln!(code, " residual_buf,")?;
3903 writeln!(code, " normed_buf,")?;
3904 writeln!(code, " qkv_buf,")?;
3905 writeln!(code, " attn_out_buf,")?;
3906 writeln!(code, " attn_proj_buf,")?;
3907 writeln!(code, " gate_up_buf,")?;
3908 writeln!(code, " ffn_hidden_buf,")?;
3909 writeln!(code, " ffn_out_buf,")?;
3910 writeln!(code, " add_tmp_buf,")?;
3911 writeln!(code, " logits_buf,")?;
3912 writeln!(code, " batch_hidden_buf,")?;
3913 writeln!(code, " batch_residual_buf,")?;
3914 writeln!(code, " batch_qkv_buf,")?;
3915 writeln!(code, " batch_attn_out_buf,")?;
3916 writeln!(code, " batch_attn_proj_buf,")?;
3917 writeln!(code, " batch_gate_up_buf,")?;
3918 writeln!(code, " batch_ffn_hidden_buf,")?;
3919 writeln!(code, " batch_ffn_out_buf,")?;
3920 writeln!(code, " batch_tokens_buf,")?;
3921 writeln!(code, " batch_positions_buf,")?;
3922 writeln!(code, " k_cache,")?;
3923 writeln!(code, " v_cache,")?;
3924 writeln!(code, " pos: 0,")?;
3925 writeln!(code, " prev_cmd: None,")?;
3926 writeln!(code, " }}")?;
3927 writeln!(code, " }}")?;
3928 writeln!(code)?;
3929
3930 writeln!(
3932 code,
3933 " /// Run the forward pass for a single token at the current position."
3934 )?;
3935 writeln!(code, " ///")?;
3936 writeln!(
3937 code,
3938 " /// Returns logits over the vocabulary as a `Vec<f32>`."
3939 )?;
3940 writeln!(code, " ///")?;
3941 writeln!(
3942 code,
3943 " /// All GPU operations are encoded into a single command buffer and"
3944 )?;
3945 writeln!(
3946 code,
3947 " /// committed once at the end, avoiding per-operation synchronization."
3948 )?;
3949 writeln!(
3950 code,
3951 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3952 )?;
3953 writeln!(
3954 code,
3955 " // Wait for any pending prefill command buffer"
3956 )?;
3957 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3958 writeln!(code, " prev.wait_until_completed();")?;
3959 writeln!(code, " }}")?;
3960 writeln!(code)?;
3961 writeln!(code, " let pos = self.pos;")?;
3962 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3963 writeln!(code)?;
3964
3965 let matmul_fn = if is_q8 {
3968 "dispatch_matmul_q8"
3969 } else if is_q4 {
3970 "dispatch_matmul_q4"
3971 } else {
3972 "dispatch_matmul"
3973 };
3974
3975 writeln!(
3976 code,
3977 " // Single compute encoder for the entire forward pass (no blit transitions)"
3978 )?;
3979 writeln!(code, " {{")?;
3980 writeln!(
3981 code,
3982 " let enc = cmd.new_compute_command_encoder();"
3983 )?;
3984 writeln!(code)?;
3985
3986 writeln!(
3988 code,
3989 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3990 )?;
3991 writeln!(
3992 code,
3993 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3994 )?;
3995 writeln!(
3996 code,
3997 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3998 )?;
3999 writeln!(
4000 code,
4001 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4002 hidden * 4,
4003 )?;
4004 writeln!(code, " unsafe {{")?;
4005 writeln!(
4006 code,
4007 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4008 )?;
4009 writeln!(
4010 code,
4011 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4012 )?;
4013 writeln!(
4014 code,
4015 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4016 )?;
4017 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4018 writeln!(
4019 code,
4020 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4021 )?;
4022 writeln!(code, " hidden_ptr,")?;
4023 writeln!(code, " HIDDEN_SIZE,")?;
4024 writeln!(code, " );")?;
4025 writeln!(
4026 code,
4027 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4028 )?;
4029 writeln!(code, " }}")?;
4030 writeln!(code)?;
4031
4032 writeln!(code, " // 2. Transformer layers")?;
4034 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4035 writeln!(code)?;
4036 let q_byte_offset = 0usize;
4037 let k_float_offset = hidden;
4038 let v_float_offset = hidden + kv_dim;
4039
4040 writeln!(
4041 code,
4042 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
4043 )?;
4044 writeln!(
4045 code,
4046 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4047 )?;
4048 writeln!(
4049 code,
4050 " // Fused Q+K+V matmul: single dispatch for all three projections"
4051 )?;
4052 writeln!(
4053 code,
4054 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4055 )?;
4056 if config.qkv_bias {
4057 writeln!(
4058 code,
4059 " // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4060 )?;
4061 writeln!(
4062 code,
4063 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4064 )?;
4065 }
4066 writeln!(
4067 code,
4068 " // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4069 )?;
4070 writeln!(
4071 code,
4072 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4073 )?;
4074 writeln!(code)?;
4075 writeln!(
4076 code,
4077 " // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4078 )?;
4079 writeln!(
4080 code,
4081 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4082 )?;
4083 writeln!(code)?;
4084 writeln!(
4085 code,
4086 " // Attention using Q from qkv_buf (offset 0)"
4087 )?;
4088 writeln!(
4089 code,
4090 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4091 )?;
4092 writeln!(
4093 code,
4094 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4095 )?;
4096 writeln!(
4097 code,
4098 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4099 )?;
4100 writeln!(
4101 code,
4102 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4103 )?;
4104 writeln!(
4105 code,
4106 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4107 )?;
4108 writeln!(
4109 code,
4110 " // Fused gate+up matmul: single dispatch for both projections"
4111 )?;
4112 writeln!(
4113 code,
4114 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4115 )?;
4116 writeln!(
4117 code,
4118 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4119 )?;
4120 writeln!(
4121 code,
4122 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4123 )?;
4124 writeln!(
4125 code,
4126 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4127 )?;
4128 writeln!(code, " }}")?;
4129 writeln!(code)?;
4130
4131 writeln!(code, " // 3. Final RMS norm + logits projection")?;
4133 writeln!(
4134 code,
4135 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4136 )?;
4137 writeln!(
4138 code,
4139 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4140 )?;
4141 writeln!(code)?;
4142 writeln!(code, " enc.end_encoding();")?;
4143 writeln!(code, " }}")?;
4144 writeln!(code)?;
4145
4146 writeln!(
4148 code,
4149 " // 5. Commit all GPU work and wait for completion"
4150 )?;
4151 writeln!(code, " cmd.commit();")?;
4152 writeln!(code, " cmd.wait_until_completed();")?;
4153 writeln!(code)?;
4154 writeln!(code, " // 6. Read back logits from GPU")?;
4155 writeln!(code, " let logits = unsafe {{")?;
4156 writeln!(
4157 code,
4158 " let ptr = self.logits_buf.contents() as *const f32;"
4159 )?;
4160 writeln!(
4161 code,
4162 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4163 )?;
4164 writeln!(code, " }};")?;
4165 writeln!(code)?;
4166 writeln!(code, " self.pos += 1;")?;
4167 writeln!(code, " logits")?;
4168 writeln!(code, " }}")?;
4169 writeln!(code)?;
4170
4171 writeln!(
4173 code,
4174 " /// Profiling forward pass that prints per-stage GPU timing."
4175 )?;
4176 writeln!(code, " ///")?;
4177 writeln!(
4178 code,
4179 " /// Each stage is committed and waited on separately so that GPU timestamps"
4180 )?;
4181 writeln!(
4182 code,
4183 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
4184 )?;
4185 writeln!(
4186 code,
4187 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
4188 )?;
4189 writeln!(
4190 code,
4191 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4192 )?;
4193 writeln!(code, " use std::time::Instant;")?;
4194 writeln!(code)?;
4195 writeln!(
4196 code,
4197 " // Wait for any pending prefill command buffer"
4198 )?;
4199 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4200 writeln!(code, " prev.wait_until_completed();")?;
4201 writeln!(code, " }}")?;
4202 writeln!(code)?;
4203 writeln!(code, " let pos = self.pos;")?;
4204 writeln!(code)?;
4205
4206 writeln!(
4208 code,
4209 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
4210 )?;
4211 writeln!(code, " let t_embed = Instant::now();")?;
4212 writeln!(code, " unsafe {{")?;
4213 writeln!(
4214 code,
4215 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4216 )?;
4217 writeln!(
4218 code,
4219 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4220 )?;
4221 writeln!(
4222 code,
4223 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4224 )?;
4225 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4226 writeln!(
4227 code,
4228 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4229 )?;
4230 writeln!(code, " hidden_ptr,")?;
4231 writeln!(code, " HIDDEN_SIZE,")?;
4232 writeln!(code, " );")?;
4233 writeln!(
4234 code,
4235 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4236 )?;
4237 writeln!(code, " }}")?;
4238 writeln!(code, " let d_embed = t_embed.elapsed();")?;
4239 writeln!(code)?;
4240
4241 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
4243 writeln!(code, " let t_layers = Instant::now();")?;
4244 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4245 writeln!(code, " {{")?;
4246 writeln!(
4247 code,
4248 " let enc = cmd.new_compute_command_encoder();"
4249 )?;
4250 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4251 writeln!(
4252 code,
4253 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4254 )?;
4255 writeln!(
4256 code,
4257 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4258 )?;
4259 if config.qkv_bias {
4260 writeln!(
4261 code,
4262 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4263 )?;
4264 }
4265 writeln!(
4266 code,
4267 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4268 )?;
4269 writeln!(
4270 code,
4271 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4272 )?;
4273 writeln!(
4274 code,
4275 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4276 )?;
4277 writeln!(
4278 code,
4279 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4280 )?;
4281 writeln!(
4282 code,
4283 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4284 )?;
4285 writeln!(
4286 code,
4287 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4288 )?;
4289 writeln!(
4290 code,
4291 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4292 )?;
4293 writeln!(
4294 code,
4295 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4296 )?;
4297 writeln!(
4298 code,
4299 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4300 )?;
4301 writeln!(
4302 code,
4303 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4304 )?;
4305 writeln!(code, " }}")?;
4306 writeln!(code, " enc.end_encoding();")?;
4307 writeln!(code, " }}")?;
4308 writeln!(code, " cmd.commit();")?;
4309 writeln!(code, " cmd.wait_until_completed();")?;
4310 writeln!(code, " let d_layers = t_layers.elapsed();")?;
4311 writeln!(code)?;
4312
4313 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
4315 writeln!(code, " let t_logits = Instant::now();")?;
4316 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
4317 writeln!(code, " {{")?;
4318 writeln!(
4319 code,
4320 " let enc = cmd2.new_compute_command_encoder();"
4321 )?;
4322 writeln!(
4323 code,
4324 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4325 )?;
4326 writeln!(
4327 code,
4328 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4329 )?;
4330 writeln!(code, " enc.end_encoding();")?;
4331 writeln!(code, " }}")?;
4332 writeln!(code, " cmd2.commit();")?;
4333 writeln!(code, " cmd2.wait_until_completed();")?;
4334 writeln!(code, " let d_logits = t_logits.elapsed();")?;
4335 writeln!(code)?;
4336
4337 writeln!(
4339 code,
4340 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4341 )?;
4342 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
4343 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
4344 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
4345 writeln!(
4346 code,
4347 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4348 )?;
4349 writeln!(code)?;
4350
4351 writeln!(code, " let logits = unsafe {{")?;
4353 writeln!(
4354 code,
4355 " let ptr = self.logits_buf.contents() as *const f32;"
4356 )?;
4357 writeln!(
4358 code,
4359 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4360 )?;
4361 writeln!(code, " }};")?;
4362 writeln!(code)?;
4363 writeln!(code, " self.pos += 1;")?;
4364 writeln!(code, " logits")?;
4365 writeln!(code, " }}")?;
4366 writeln!(code)?;
4367
4368 writeln!(
4370 code,
4371 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
4372 )?;
4373 writeln!(code, " ///")?;
4374 writeln!(
4375 code,
4376 " /// Commits the command buffer without waiting, enabling double-buffered"
4377 )?;
4378 writeln!(
4379 code,
4380 " /// execution: GPU processes token N while CPU encodes token N+1."
4381 )?;
4382 writeln!(
4383 code,
4384 " pub fn forward_prefill(&mut self, token_id: u32) {{"
4385 )?;
4386 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
4387 writeln!(code, " }}")?;
4388 writeln!(code)?;
4389
4390 let batch_matmul_fn = if is_q8 {
4393 "dispatch_matmul_q8_batch"
4394 } else if is_q4 {
4395 "dispatch_matmul_q4_batch"
4396 } else {
4397 "dispatch_matmul_batch"
4398 };
4399
4400 writeln!(
4401 code,
4402 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4403 )?;
4404 writeln!(code, " ///")?;
4405 writeln!(
4406 code,
4407 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4408 )?;
4409 writeln!(
4410 code,
4411 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
4412 )?;
4413 writeln!(
4414 code,
4415 " /// This provides significant speedup during prompt prefill."
4416 )?;
4417 writeln!(
4418 code,
4419 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4420 )?;
4421 writeln!(code, " if tokens.is_empty() {{ return; }}")?;
4422 writeln!(
4423 code,
4424 " // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4425 )?;
4426 writeln!(
4427 code,
4428 " // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4429 )?;
4430 writeln!(
4431 code,
4432 " // longer than that must be processed iteratively. The KV cache"
4433 )?;
4434 writeln!(code, " // carries state across chunks via self.pos.")?;
4435 writeln!(
4436 code,
4437 " for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4438 )?;
4439 writeln!(code, " let m = chunk.len();")?;
4440 writeln!(code, " if m == 0 {{ continue; }}")?;
4441 writeln!(code, " let start_pos = self.pos;")?;
4442 writeln!(code)?;
4443 writeln!(code, " // Wait for any pending command buffer")?;
4444 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4445 writeln!(code, " prev.wait_until_completed();")?;
4446 writeln!(code, " }}")?;
4447 writeln!(code)?;
4448
4449 writeln!(
4451 code,
4452 " // Upload token IDs and positions to GPU buffers"
4453 )?;
4454 writeln!(code, " unsafe {{")?;
4455 writeln!(
4456 code,
4457 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4458 )?;
4459 writeln!(
4460 code,
4461 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4462 )?;
4463 writeln!(code, " for i in 0..m {{")?;
4464 writeln!(code, " *tok_ptr.add(i) = chunk[i];")?;
4465 writeln!(
4466 code,
4467 " *pos_ptr.add(i) = (start_pos + i) as u32;"
4468 )?;
4469 writeln!(code, " }}")?;
4470 writeln!(code, " }}")?;
4471 writeln!(code)?;
4472
4473 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4474 writeln!(code, " {{")?;
4475 writeln!(
4476 code,
4477 " let enc = cmd.new_compute_command_encoder();"
4478 )?;
4479 writeln!(code)?;
4480
4481 writeln!(
4483 code,
4484 " // 1. Batch embedding lookup: copy all token embeddings at once"
4485 )?;
4486 writeln!(
4487 code,
4488 " self.dispatch_copy_embedding_batch(&enc, m);"
4489 )?;
4490 writeln!(
4492 code,
4493 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4494 )?;
4495 writeln!(code)?;
4496
4497 writeln!(code, " // 2. Transformer layers")?;
4499 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4500 writeln!(code)?;
4501
4502 writeln!(
4504 code,
4505 " // Batch RMS norm: batch_residual -> batch_hidden"
4506 )?;
4507 writeln!(
4508 code,
4509 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4510 )?;
4511
4512 writeln!(
4514 code,
4515 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4516 )?;
4517 writeln!(
4518 code,
4519 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4520 )?;
4521 if config.qkv_bias {
4522 writeln!(
4523 code,
4524 " // Qwen2: broadcast-add QKV bias across all M tokens."
4525 )?;
4526 writeln!(
4527 code,
4528 " self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4529 )?;
4530 }
4531 writeln!(code)?;
4532
4533 let k_float_offset = hidden;
4535 writeln!(
4536 code,
4537 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4538 )?;
4539 writeln!(
4540 code,
4541 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4542 )?;
4543 writeln!(code)?;
4544
4545 let v_float_offset = hidden + kv_dim;
4547 writeln!(
4548 code,
4549 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4550 )?;
4551 writeln!(
4552 code,
4553 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4554 )?;
4555 writeln!(code)?;
4556
4557 writeln!(
4559 code,
4560 " // Batched causal attention: one dispatch for all M tokens"
4561 )?;
4562 writeln!(
4563 code,
4564 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4565 )?;
4566 writeln!(code)?;
4567
4568 writeln!(code, " // Batched O projection")?;
4570 writeln!(
4571 code,
4572 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4573 )?;
4574 writeln!(code)?;
4575
4576 writeln!(
4578 code,
4579 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4580 )?;
4581 writeln!(code)?;
4582
4583 writeln!(
4585 code,
4586 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4587 )?;
4588 writeln!(
4589 code,
4590 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4591 )?;
4592 writeln!(
4593 code,
4594 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4595 )?;
4596 writeln!(
4597 code,
4598 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4599 )?;
4600 writeln!(
4601 code,
4602 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4603 )?;
4604 writeln!(
4605 code,
4606 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4607 )?;
4608 writeln!(code, " }}")?;
4609 writeln!(code)?;
4610
4611 writeln!(
4613 code,
4614 " // Copy last token's residual to single-token buffer for subsequent forward()"
4615 )?;
4616 writeln!(
4617 code,
4618 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4619 )?;
4620 writeln!(code)?;
4621 writeln!(code, " enc.end_encoding();")?;
4622 writeln!(code, " }}")?;
4623 writeln!(code)?;
4624
4625 writeln!(code, " cmd.commit();")?;
4626 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
4627 writeln!(code, " self.pos += m;")?;
4628 writeln!(code, " }} // end for chunk")?;
4629 writeln!(code, " }}")?;
4630 writeln!(code)?;
4631
4632 writeln!(
4634 code,
4635 " /// Reset the model state for a new inference request."
4636 )?;
4637 writeln!(code, " pub fn reset(&mut self) {{")?;
4638 writeln!(code, " self.pos = 0;")?;
4639 writeln!(code, " self.prev_cmd = None;")?;
4640 writeln!(code, " }}")?;
4641 writeln!(code)?;
4642
4643 writeln!(
4645 code,
4646 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
4647 )?;
4648 writeln!(
4649 code,
4650 " // These methods set pipeline state + buffers + dispatch on an existing"
4651 )?;
4652 writeln!(
4653 code,
4654 " // encoder, avoiding per-operation encoder creation overhead."
4655 )?;
4656 writeln!(code)?;
4657
4658 writeln!(
4660 code,
4661 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4662 )?;
4663 writeln!(
4664 code,
4665 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4666 )?;
4667 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4668 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4669 writeln!(
4670 code,
4671 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4672 )?;
4673 writeln!(
4674 code,
4675 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
4676 )?;
4677 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4678 writeln!(
4679 code,
4680 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4681 )?;
4682 writeln!(
4683 code,
4684 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4685 )?;
4686 writeln!(
4687 code,
4688 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4689 )?;
4690 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4691 writeln!(
4692 code,
4693 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
4694 )?;
4695 writeln!(
4696 code,
4697 " enc.dispatch_thread_groups(grid_size, tg_size);"
4698 )?;
4699 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4700 writeln!(code, " }}")?;
4701 writeln!(code)?;
4702
4703 writeln!(
4705 code,
4706 " /// Dispatch matrix-vector multiply: weight * input -> output."
4707 )?;
4708 writeln!(
4709 code,
4710 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4711 )?;
4712 writeln!(code, " let r: u32 = rows as u32;")?;
4713 writeln!(code, " let c: u32 = cols as u32;")?;
4714 writeln!(
4715 code,
4716 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4717 )?;
4718 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4719 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4720 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4721 writeln!(
4722 code,
4723 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4724 )?;
4725 writeln!(
4726 code,
4727 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4728 )?;
4729 writeln!(
4730 code,
4731 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4732 )?;
4733 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4734 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
4735 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4736 writeln!(
4737 code,
4738 " enc.dispatch_thread_groups(grid_size, tg_size);"
4739 )?;
4740 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4741 writeln!(code, " }}")?;
4742 writeln!(code)?;
4743
4744 writeln!(
4746 code,
4747 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4748 )?;
4749 writeln!(
4750 code,
4751 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4752 )?;
4753 writeln!(
4754 code,
4755 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4756 )?;
4757 writeln!(code, " let r: u32 = rows as u32;")?;
4758 writeln!(code, " let c: u32 = cols as u32;")?;
4759 writeln!(
4760 code,
4761 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4762 )?;
4763 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4764 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4765 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4766 writeln!(
4767 code,
4768 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4769 )?;
4770 writeln!(
4771 code,
4772 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4773 )?;
4774 writeln!(
4775 code,
4776 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4777 )?;
4778 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4779 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4780 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4781 writeln!(
4782 code,
4783 " enc.dispatch_thread_groups(grid_size, tg_size);"
4784 )?;
4785 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4786 writeln!(code, " }}")?;
4787 writeln!(code)?;
4788
4789 writeln!(
4791 code,
4792 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4793 )?;
4794 writeln!(
4795 code,
4796 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4797 )?;
4798 writeln!(
4799 code,
4800 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4801 )?;
4802 writeln!(code, " let r: u32 = rows as u32;")?;
4803 writeln!(code, " let c: u32 = cols as u32;")?;
4804 writeln!(
4805 code,
4806 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4807 )?;
4808 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4809 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4810 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4811 writeln!(
4812 code,
4813 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4814 )?;
4815 writeln!(
4816 code,
4817 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4818 )?;
4819 writeln!(
4820 code,
4821 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4822 )?;
4823 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4824 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4825 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4826 writeln!(
4827 code,
4828 " enc.dispatch_thread_groups(grid_size, tg_size);"
4829 )?;
4830 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4831 writeln!(code, " }}")?;
4832 writeln!(code)?;
4833
4834 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
4836 writeln!(
4837 code,
4838 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4839 )?;
4840 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4841 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4842 writeln!(code, " let p: u32 = pos as u32;")?;
4843 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4844 writeln!(
4845 code,
4846 " let total_pairs = num_heads * (head_dim / 2);"
4847 )?;
4848 writeln!(
4849 code,
4850 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4851 )?;
4852 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
4853 writeln!(
4854 code,
4855 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4856 )?;
4857 writeln!(
4858 code,
4859 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4860 )?;
4861 writeln!(
4862 code,
4863 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4864 )?;
4865 writeln!(
4866 code,
4867 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4868 )?;
4869 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4870 writeln!(
4871 code,
4872 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4873 )?;
4874 writeln!(
4875 code,
4876 " enc.dispatch_thread_groups(grid_size, tg_size);"
4877 )?;
4878 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4879 writeln!(code, " }}")?;
4880 writeln!(code)?;
4881
4882 writeln!(
4884 code,
4885 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4886 )?;
4887 writeln!(
4888 code,
4889 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4890 )?;
4891 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4892 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4893 writeln!(code, " let p: u32 = pos as u32;")?;
4894 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4895 writeln!(
4896 code,
4897 " let total_pairs = num_heads * (head_dim / 2);"
4898 )?;
4899 writeln!(
4900 code,
4901 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4902 )?;
4903 writeln!(
4904 code,
4905 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
4906 )?;
4907 writeln!(
4908 code,
4909 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4910 )?;
4911 writeln!(
4912 code,
4913 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4914 )?;
4915 writeln!(
4916 code,
4917 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4918 )?;
4919 writeln!(
4920 code,
4921 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4922 )?;
4923 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4924 writeln!(
4925 code,
4926 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4927 )?;
4928 writeln!(
4929 code,
4930 " enc.dispatch_thread_groups(grid_size, tg_size);"
4931 )?;
4932 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4933 writeln!(code, " }}")?;
4934 writeln!(code)?;
4935
4936 writeln!(
4938 code,
4939 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4940 )?;
4941 writeln!(
4942 code,
4943 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4944 )?;
4945 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4946 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4947 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4948 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4949 writeln!(
4950 code,
4951 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
4952 )?;
4953 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
4954 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4955 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4956 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4957 writeln!(
4958 code,
4959 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4960 )?;
4961 writeln!(
4962 code,
4963 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4964 )?;
4965 writeln!(
4966 code,
4967 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4968 )?;
4969 writeln!(
4970 code,
4971 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4972 )?;
4973 writeln!(
4974 code,
4975 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4976 )?;
4977 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4978 writeln!(
4979 code,
4980 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4981 )?;
4982 writeln!(
4983 code,
4984 " enc.dispatch_thread_groups(grid_size, tg_size);"
4985 )?;
4986 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4987 writeln!(code, " }}")?;
4988 writeln!(code)?;
4989
4990 writeln!(
4992 code,
4993 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4994 )?;
4995 writeln!(
4996 code,
4997 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4998 )?;
4999 writeln!(code, " let sl: u32 = seq_len as u32;")?;
5000 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
5001 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5002 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5003 writeln!(
5004 code,
5005 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
5006 )?;
5007 writeln!(
5008 code,
5009 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5010 )?;
5011 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5012 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5013 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5014 writeln!(
5015 code,
5016 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5017 )?;
5018 writeln!(
5019 code,
5020 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5021 )?;
5022 writeln!(
5023 code,
5024 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5025 )?;
5026 writeln!(
5027 code,
5028 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5029 )?;
5030 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5031 writeln!(
5032 code,
5033 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5034 )?;
5035 writeln!(
5036 code,
5037 " enc.dispatch_thread_groups(grid_size, tg_size);"
5038 )?;
5039 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5040 writeln!(code, " }}")?;
5041 writeln!(code)?;
5042
5043 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
5045 writeln!(
5046 code,
5047 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5048 )?;
5049 writeln!(code, " let count: u32 = n as u32;")?;
5050 writeln!(
5051 code,
5052 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5053 )?;
5054 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
5055 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
5056 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5057 writeln!(
5058 code,
5059 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5060 )?;
5061 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5062 writeln!(
5063 code,
5064 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5065 )?;
5066 writeln!(
5067 code,
5068 " enc.dispatch_thread_groups(grid_size, tg_size);"
5069 )?;
5070 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5071 writeln!(code, " }}")?;
5072 writeln!(code)?;
5073
5074 writeln!(
5076 code,
5077 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5078 )?;
5079 writeln!(
5080 code,
5081 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
5082 )?;
5083 writeln!(
5084 code,
5085 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5086 )?;
5087 writeln!(code, " let count: u32 = n as u32;")?;
5088 writeln!(
5089 code,
5090 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5091 )?;
5092 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5093 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5094 writeln!(
5095 code,
5096 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5097 )?;
5098 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5099 writeln!(
5100 code,
5101 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5102 )?;
5103 writeln!(
5104 code,
5105 " enc.dispatch_thread_groups(grid_size, tg_size);"
5106 )?;
5107 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5108 writeln!(code, " }}")?;
5109 writeln!(code)?;
5110
5111 writeln!(
5113 code,
5114 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5115 )?;
5116 writeln!(
5117 code,
5118 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5119 )?;
5120 writeln!(code, " let n: u32 = count as u32;")?;
5121 writeln!(
5122 code,
5123 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5124 )?;
5125 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5126 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5127 writeln!(
5128 code,
5129 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5130 )?;
5131 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5132 writeln!(
5133 code,
5134 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5135 )?;
5136 writeln!(
5137 code,
5138 " enc.dispatch_thread_groups(grid_size, tg_size);"
5139 )?;
5140 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5141 writeln!(code, " }}")?;
5142 writeln!(code)?;
5143
5144 writeln!(
5146 code,
5147 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5148 )?;
5149 writeln!(
5150 code,
5151 " /// Used for embedding table lookup (copy a specific row)."
5152 )?;
5153 writeln!(
5154 code,
5155 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5156 )?;
5157 writeln!(code, " let off: u32 = src_offset as u32;")?;
5158 writeln!(code, " let n: u32 = count as u32;")?;
5159 writeln!(
5160 code,
5161 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5162 )?;
5163 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5164 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5165 writeln!(
5166 code,
5167 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5168 )?;
5169 writeln!(
5170 code,
5171 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5172 )?;
5173 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5174 writeln!(
5175 code,
5176 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5177 )?;
5178 writeln!(
5179 code,
5180 " enc.dispatch_thread_groups(grid_size, tg_size);"
5181 )?;
5182 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5183 writeln!(code, " }}")?;
5184 writeln!(code)?;
5185
5186 writeln!(
5188 code,
5189 " /// Dispatch copy from source at byte offset to destination at float offset."
5190 )?;
5191 writeln!(
5192 code,
5193 " /// Used for KV cache updates from fused QKV buffer."
5194 )?;
5195 writeln!(
5196 code,
5197 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5198 )?;
5199 writeln!(code, " let n: u32 = count as u32;")?;
5200 writeln!(
5201 code,
5202 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5203 )?;
5204 writeln!(
5205 code,
5206 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5207 )?;
5208 writeln!(
5209 code,
5210 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5211 )?;
5212 writeln!(
5213 code,
5214 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5215 )?;
5216 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5217 writeln!(
5218 code,
5219 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5220 )?;
5221 writeln!(
5222 code,
5223 " enc.dispatch_thread_groups(grid_size, tg_size);"
5224 )?;
5225 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5226 writeln!(code, " }}")?;
5227 writeln!(code)?;
5228
5229 writeln!(
5231 code,
5232 " /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5233 )?;
5234 writeln!(
5235 code,
5236 " /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5237 )?;
5238 writeln!(
5239 code,
5240 " fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5241 )?;
5242 writeln!(code, " let n: u32 = count as u32;")?;
5243 writeln!(
5244 code,
5245 " enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5246 )?;
5247 writeln!(
5248 code,
5249 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5250 )?;
5251 writeln!(
5252 code,
5253 " enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5254 )?;
5255 writeln!(
5256 code,
5257 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5258 )?;
5259 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5260 writeln!(
5261 code,
5262 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5263 )?;
5264 writeln!(
5265 code,
5266 " enc.dispatch_thread_groups(grid_size, tg_size);"
5267 )?;
5268 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5269 writeln!(code, " }}")?;
5270 writeln!(code)?;
5271
5272 writeln!(
5274 code,
5275 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5276 )?;
5277 writeln!(
5278 code,
5279 " /// Used for KV cache updates (write to a specific position in the cache)."
5280 )?;
5281 writeln!(
5282 code,
5283 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5284 )?;
5285 writeln!(code, " let n: u32 = count as u32;")?;
5286 writeln!(
5287 code,
5288 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5289 )?;
5290 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5291 writeln!(
5292 code,
5293 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5294 )?;
5295 writeln!(
5296 code,
5297 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5298 )?;
5299 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5300 writeln!(
5301 code,
5302 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5303 )?;
5304 writeln!(
5305 code,
5306 " enc.dispatch_thread_groups(grid_size, tg_size);"
5307 )?;
5308 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5309 writeln!(code, " }}")?;
5310 writeln!(code)?;
5311
5312 writeln!(
5314 code,
5315 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5316 )?;
5317 writeln!(
5318 code,
5319 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5320 )?;
5321 writeln!(code, " let count: u32 = n as u32;")?;
5322 writeln!(
5323 code,
5324 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5325 )?;
5326 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5327 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5328 writeln!(
5329 code,
5330 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5331 )?;
5332 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5333 writeln!(
5334 code,
5335 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5336 )?;
5337 writeln!(
5338 code,
5339 " enc.dispatch_thread_groups(grid_size, tg_size);"
5340 )?;
5341 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5342 writeln!(code, " }}")?;
5343 writeln!(code)?;
5344
5345 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
5347 writeln!(code)?;
5348
5349 writeln!(
5351 code,
5352 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
5353 )?;
5354 writeln!(
5355 code,
5356 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5357 )?;
5358 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
5359 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5360 writeln!(
5361 code,
5362 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5363 )?;
5364 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5365 writeln!(
5366 code,
5367 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5368 )?;
5369 writeln!(
5370 code,
5371 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5372 )?;
5373 writeln!(
5374 code,
5375 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5376 )?;
5377 writeln!(
5378 code,
5379 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5380 )?;
5381 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
5382 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5383 writeln!(
5384 code,
5385 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5386 )?;
5387 writeln!(
5388 code,
5389 " enc.dispatch_thread_groups(grid_size, tg_size);"
5390 )?;
5391 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5392 writeln!(code, " }}")?;
5393 writeln!(code)?;
5394
5395 writeln!(
5397 code,
5398 " /// Dispatch batched RMS norm: normalizes M vectors at once."
5399 )?;
5400 writeln!(
5401 code,
5402 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5403 )?;
5404 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
5405 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
5406 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5407 writeln!(
5408 code,
5409 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5410 )?;
5411 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
5412 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
5413 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5414 writeln!(
5415 code,
5416 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5417 )?;
5418 writeln!(
5419 code,
5420 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5421 )?;
5422 writeln!(
5423 code,
5424 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5425 )?;
5426 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5427 writeln!(
5428 code,
5429 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5430 )?;
5431 writeln!(
5432 code,
5433 " enc.dispatch_thread_groups(grid_size, tg_size);"
5434 )?;
5435 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5436 writeln!(code, " }}")?;
5437 writeln!(code)?;
5438
5439 writeln!(
5441 code,
5442 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5443 )?;
5444 writeln!(
5445 code,
5446 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5447 )?;
5448 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5449 writeln!(code, " let r: u32 = rows as u32;")?;
5450 writeln!(code, " let c: u32 = cols as u32;")?;
5451 writeln!(
5452 code,
5453 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5454 )?;
5455 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5456 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5457 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5458 writeln!(
5459 code,
5460 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5461 )?;
5462 writeln!(
5463 code,
5464 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5465 )?;
5466 writeln!(
5467 code,
5468 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5469 )?;
5470 writeln!(
5471 code,
5472 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
5473 )?;
5474 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5475 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5476 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5477 writeln!(
5478 code,
5479 " enc.dispatch_thread_groups(grid_size, tg_size);"
5480 )?;
5481 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5482 writeln!(code, " }}")?;
5483 writeln!(code)?;
5484
5485 writeln!(
5487 code,
5488 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5489 )?;
5490 writeln!(code, " ///")?;
5491 writeln!(
5492 code,
5493 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5494 )?;
5495 writeln!(
5496 code,
5497 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5498 )?;
5499 writeln!(
5500 code,
5501 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5502 )?;
5503 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5504 writeln!(code, " let r: u32 = rows as u32;")?;
5505 writeln!(code, " let c: u32 = cols as u32;")?;
5506 writeln!(
5507 code,
5508 " // Tile sizes must match the Metal shader constants."
5509 )?;
5510 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
5511 writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
5512 writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
5513 writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
5514 writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
5515 writeln!(
5516 code,
5517 " // Hardware matrix-multiply paths (simdgroup_matrix)."
5518 )?;
5519 writeln!(
5520 code,
5521 " // Prefer the large 32×32 tile when the problem supports it — halves"
5522 )?;
5523 writeln!(
5524 code,
5525 " // dispatch count and reuses each weight load across 32 tokens."
5526 )?;
5527 writeln!(
5528 code,
5529 " if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5530 )?;
5531 writeln!(
5532 code,
5533 " // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5534 )?;
5535 writeln!(
5536 code,
5537 " // It wins at moderate prefill lengths where the GPU is wave-starved,"
5538 )?;
5539 writeln!(
5540 code,
5541 " // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5542 )?;
5543 writeln!(
5544 code,
5545 " // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
5546 )?;
5547 writeln!(
5548 code,
5549 " // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5550 )?;
5551 writeln!(
5552 code,
5553 " // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5554 )?;
5555 writeln!(
5556 code,
5557 " // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5558 )?;
5559 writeln!(
5560 code,
5561 " // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5562 )?;
5563 writeln!(code, " let use_h4 = cols >= 2048;")?;
5564 writeln!(code, " let pipe = if use_h4 {{")?;
5565 writeln!(code, " if num_tokens >= 256 {{")?;
5566 writeln!(
5567 code,
5568 " &self.matmul_q8_mma32_hh4_pipeline"
5569 )?;
5570 writeln!(code, " }} else {{")?;
5571 writeln!(
5572 code,
5573 " &self.matmul_q8_mma32_h4_pipeline"
5574 )?;
5575 writeln!(code, " }}")?;
5576 writeln!(code, " }} else {{")?;
5577 writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
5578 writeln!(code, " }};")?;
5579 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
5580 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5581 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5582 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5583 writeln!(
5584 code,
5585 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5586 )?;
5587 writeln!(
5588 code,
5589 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5590 )?;
5591 writeln!(
5592 code,
5593 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5594 )?;
5595 writeln!(code, " let row_tgs = rows / MMA32_ROW_TILE;")?;
5596 writeln!(
5597 code,
5598 " let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5599 )?;
5600 writeln!(
5601 code,
5602 " let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5603 )?;
5604 writeln!(
5605 code,
5606 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5607 )?;
5608 writeln!(
5609 code,
5610 " enc.dispatch_thread_groups(grid_size, tg_size);"
5611 )?;
5612 writeln!(
5613 code,
5614 " }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5615 )?;
5616 writeln!(
5617 code,
5618 " enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5619 )?;
5620 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5621 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5622 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5623 writeln!(
5624 code,
5625 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5626 )?;
5627 writeln!(
5628 code,
5629 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5630 )?;
5631 writeln!(
5632 code,
5633 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5634 )?;
5635 writeln!(code, " let row_tgs = rows / MMA_ROW_TILE;")?;
5636 writeln!(
5637 code,
5638 " let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5639 )?;
5640 writeln!(
5641 code,
5642 " let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
5643 )?;
5644 writeln!(
5645 code,
5646 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5647 )?;
5648 writeln!(
5649 code,
5650 " enc.dispatch_thread_groups(grid_size, tg_size);"
5651 )?;
5652 writeln!(code, " }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5653 writeln!(
5654 code,
5655 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5656 )?;
5657 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5658 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5659 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5660 writeln!(
5661 code,
5662 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5663 )?;
5664 writeln!(
5665 code,
5666 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5667 )?;
5668 writeln!(
5669 code,
5670 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5671 )?;
5672 writeln!(
5673 code,
5674 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5675 )?;
5676 writeln!(
5677 code,
5678 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5679 )?;
5680 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5681 writeln!(
5682 code,
5683 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5684 )?;
5685 writeln!(
5686 code,
5687 " enc.dispatch_thread_groups(grid_size, tg_size);"
5688 )?;
5689 writeln!(code, " }} else {{")?;
5690 writeln!(
5691 code,
5692 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5693 )?;
5694 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5695 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5696 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5697 writeln!(
5698 code,
5699 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5700 )?;
5701 writeln!(
5702 code,
5703 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5704 )?;
5705 writeln!(
5706 code,
5707 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5708 )?;
5709 writeln!(
5710 code,
5711 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5712 )?;
5713 writeln!(
5714 code,
5715 " let num_tg = (row_tgs * num_tokens) as u64;"
5716 )?;
5717 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5718 writeln!(
5719 code,
5720 " let grid_size = MTLSize::new(num_tg, 1, 1);"
5721 )?;
5722 writeln!(
5723 code,
5724 " enc.dispatch_thread_groups(grid_size, tg_size);"
5725 )?;
5726 writeln!(code, " }}")?;
5727 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5728 writeln!(code, " }}")?;
5729 writeln!(code)?;
5730
5731 writeln!(
5733 code,
5734 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5735 )?;
5736 writeln!(
5737 code,
5738 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5739 )?;
5740 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5741 writeln!(code, " let r: u32 = rows as u32;")?;
5742 writeln!(code, " let c: u32 = cols as u32;")?;
5743 writeln!(
5744 code,
5745 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5746 )?;
5747 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5748 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5749 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5750 writeln!(
5751 code,
5752 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5753 )?;
5754 writeln!(
5755 code,
5756 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5757 )?;
5758 writeln!(
5759 code,
5760 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5761 )?;
5762 writeln!(
5763 code,
5764 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
5765 )?;
5766 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5767 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5768 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5769 writeln!(
5770 code,
5771 " enc.dispatch_thread_groups(grid_size, tg_size);"
5772 )?;
5773 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5774 writeln!(code, " }}")?;
5775 writeln!(code)?;
5776
5777 if config.qkv_bias {
5779 writeln!(
5780 code,
5781 " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5782 )?;
5783 writeln!(
5784 code,
5785 " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5786 )?;
5787 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5788 writeln!(code, " let r: u32 = rows as u32;")?;
5789 writeln!(
5790 code,
5791 " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5792 )?;
5793 writeln!(code, " enc.set_buffer(0, Some(out), 0);")?;
5794 writeln!(code, " enc.set_buffer(1, Some(bias), 0);")?;
5795 writeln!(
5796 code,
5797 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5798 )?;
5799 writeln!(
5800 code,
5801 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5802 )?;
5803 writeln!(code, " let total = (num_tokens * rows) as u64;")?;
5804 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5805 writeln!(
5806 code,
5807 " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5808 )?;
5809 writeln!(
5810 code,
5811 " enc.dispatch_thread_groups(grid_size, tg_size);"
5812 )?;
5813 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5814 writeln!(code, " }}")?;
5815 writeln!(code)?;
5816 }
5817
5818 writeln!(
5820 code,
5821 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5822 )?;
5823 writeln!(
5824 code,
5825 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5826 )?;
5827 writeln!(
5828 code,
5829 " /// `row_stride` is the number of floats per token row in the batch buffer."
5830 )?;
5831 writeln!(
5832 code,
5833 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5834 )?;
5835 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5836 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5837 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5838 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5839 writeln!(
5840 code,
5841 " let pairs_per_token = num_heads * (head_dim / 2);"
5842 )?;
5843 writeln!(
5844 code,
5845 " let total_pairs = num_tokens * pairs_per_token;"
5846 )?;
5847 writeln!(
5860 code,
5861 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5862 )?;
5863 writeln!(code, " for t in 0..num_tokens {{")?;
5864 writeln!(
5865 code,
5866 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5867 )?;
5868 writeln!(
5869 code,
5870 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5871 )?;
5872 writeln!(
5873 code,
5874 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5875 )?;
5876 writeln!(
5877 code,
5878 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
5879 )?;
5880 writeln!(
5881 code,
5882 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5883 )?;
5884 writeln!(
5885 code,
5886 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5887 )?;
5888 writeln!(
5889 code,
5890 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5891 )?;
5892 writeln!(
5893 code,
5894 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5895 )?;
5896 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5897 writeln!(
5898 code,
5899 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5900 )?;
5901 writeln!(
5902 code,
5903 " enc.dispatch_thread_groups(grid_size, tg_size);"
5904 )?;
5905 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5906 writeln!(code, " }}")?;
5907 writeln!(code, " }}")?;
5908 writeln!(code)?;
5909
5910 writeln!(
5912 code,
5913 " /// Dispatch batched fused SiLU-multiply for M tokens."
5914 )?;
5915 writeln!(
5916 code,
5917 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5918 )?;
5919 writeln!(code, " let count: u32 = n as u32;")?;
5920 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5921 writeln!(
5922 code,
5923 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5924 )?;
5925 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5926 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5927 writeln!(
5928 code,
5929 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5930 )?;
5931 writeln!(
5932 code,
5933 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5934 )?;
5935 writeln!(code, " let total = num_tokens * n;")?;
5936 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5937 writeln!(
5938 code,
5939 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5940 )?;
5941 writeln!(
5942 code,
5943 " enc.dispatch_thread_groups(grid_size, tg_size);"
5944 )?;
5945 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5946 writeln!(code, " }}")?;
5947 writeln!(code)?;
5948
5949 writeln!(
5951 code,
5952 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5953 )?;
5954 writeln!(
5955 code,
5956 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5957 )?;
5958 writeln!(code, " let count: u32 = total_n as u32;")?;
5959 writeln!(
5960 code,
5961 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5962 )?;
5963 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5964 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5965 writeln!(
5966 code,
5967 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5968 )?;
5969 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5970 writeln!(
5971 code,
5972 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5973 )?;
5974 writeln!(
5975 code,
5976 " enc.dispatch_thread_groups(grid_size, tg_size);"
5977 )?;
5978 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5979 writeln!(code, " }}")?;
5980 writeln!(code)?;
5981
5982 writeln!(
5984 code,
5985 " /// Copy src to dst using compute copy kernel (for batch residual init)."
5986 )?;
5987 writeln!(
5988 code,
5989 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5990 )?;
5991 writeln!(code, " let n: u32 = count as u32;")?;
5992 writeln!(
5993 code,
5994 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5995 )?;
5996 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5997 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5998 writeln!(
5999 code,
6000 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6001 )?;
6002 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6003 writeln!(
6004 code,
6005 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6006 )?;
6007 writeln!(
6008 code,
6009 " enc.dispatch_thread_groups(grid_size, tg_size);"
6010 )?;
6011 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6012 writeln!(code, " }}")?;
6013 writeln!(code)?;
6014
6015 writeln!(
6017 code,
6018 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6019 )?;
6020 writeln!(
6021 code,
6022 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6023 )?;
6024 writeln!(code, " let n: u32 = count as u32;")?;
6025 writeln!(
6026 code,
6027 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6028 )?;
6029 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6030 writeln!(
6031 code,
6032 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6033 )?;
6034 writeln!(
6035 code,
6036 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6037 )?;
6038 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6039 writeln!(
6040 code,
6041 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6042 )?;
6043 writeln!(
6044 code,
6045 " enc.dispatch_thread_groups(grid_size, tg_size);"
6046 )?;
6047 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6048 writeln!(code, " }}")?;
6049 writeln!(code)?;
6050
6051 writeln!(
6053 code,
6054 " /// Copy from src at byte offset to dst at float offset."
6055 )?;
6056 writeln!(
6057 code,
6058 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6059 )?;
6060 writeln!(code, " let n: u32 = count as u32;")?;
6061 writeln!(
6062 code,
6063 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6064 )?;
6065 writeln!(
6066 code,
6067 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6068 )?;
6069 writeln!(
6070 code,
6071 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6072 )?;
6073 writeln!(
6074 code,
6075 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6076 )?;
6077 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6078 writeln!(
6079 code,
6080 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6081 )?;
6082 writeln!(
6083 code,
6084 " enc.dispatch_thread_groups(grid_size, tg_size);"
6085 )?;
6086 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6087 writeln!(code, " }}")?;
6088 writeln!(code)?;
6089
6090 writeln!(
6092 code,
6093 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6094 )?;
6095 writeln!(
6096 code,
6097 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6098 )?;
6099 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6100 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6101 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6102 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6103 writeln!(code, " let so: u32 = src_offset as u32;")?;
6104 writeln!(
6105 code,
6106 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6107 )?;
6108 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6109 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6110 writeln!(
6111 code,
6112 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6113 )?;
6114 writeln!(
6115 code,
6116 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6117 )?;
6118 writeln!(
6119 code,
6120 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6121 )?;
6122 writeln!(
6123 code,
6124 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6125 )?;
6126 writeln!(
6127 code,
6128 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6129 )?;
6130 writeln!(code, " let total = num_tokens * kv_dim;")?;
6131 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6132 writeln!(
6133 code,
6134 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6135 )?;
6136 writeln!(
6137 code,
6138 " enc.dispatch_thread_groups(grid_size, tg_size);"
6139 )?;
6140 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6141 writeln!(code, " }}")?;
6142 writeln!(code)?;
6143
6144 writeln!(
6146 code,
6147 " /// Dispatch batched causal attention: one dispatch for all M tokens."
6148 )?;
6149 writeln!(
6150 code,
6151 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6152 )?;
6153 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6154 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6155 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
6156 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6157 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6158 writeln!(code, " let qs: u32 = q_stride as u32;")?;
6159 writeln!(code, " let max_seq = base_pos + num_tokens;")?;
6169 writeln!(code, " let _ = max_seq;")?;
6170 writeln!(
6171 code,
6172 " let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6173 )?;
6174 writeln!(code, " .map(|v| v == \"0\").unwrap_or(false);")?;
6175 writeln!(
6176 code,
6177 " let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6178 )?;
6179 writeln!(code, " if use_mma_flash {{")?;
6180 writeln!(
6181 code,
6182 " let pipe = &self.attention_mma_flash_batch_pipeline;"
6183 )?;
6184 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6185 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6186 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6187 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6188 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6189 writeln!(
6190 code,
6191 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6192 )?;
6193 writeln!(
6194 code,
6195 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6196 )?;
6197 writeln!(
6198 code,
6199 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6200 )?;
6201 writeln!(
6202 code,
6203 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6204 )?;
6205 writeln!(
6206 code,
6207 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6208 )?;
6209 writeln!(
6210 code,
6211 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6212 )?;
6213 writeln!(
6214 code,
6215 " // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6216 )?;
6217 writeln!(code, " let tg_size = MTLSize::new(128, 1, 1);")?;
6218 writeln!(
6219 code,
6220 " let q_blocks = ((num_tokens + 7) / 8) as u64;"
6221 )?;
6222 writeln!(
6223 code,
6224 " let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6225 )?;
6226 writeln!(
6227 code,
6228 " enc.dispatch_thread_groups(grid_size, tg_size);"
6229 )?;
6230 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6231 writeln!(code, " return;")?;
6232 writeln!(code, " }}")?;
6233 writeln!(code, " let pipe = &self.attention_batch_pipeline;")?;
6234 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6235 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6236 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6237 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6238 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6239 writeln!(
6240 code,
6241 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6242 )?;
6243 writeln!(
6244 code,
6245 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6246 )?;
6247 writeln!(
6248 code,
6249 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6250 )?;
6251 writeln!(
6252 code,
6253 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6254 )?;
6255 writeln!(
6256 code,
6257 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6258 )?;
6259 writeln!(
6260 code,
6261 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6262 )?;
6263 writeln!(
6264 code,
6265 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6266 )?;
6267 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6268 writeln!(
6269 code,
6270 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6271 )?;
6272 writeln!(
6273 code,
6274 " enc.dispatch_thread_groups(grid_size, tg_size);"
6275 )?;
6276 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6277 writeln!(code, " }}")?;
6278 writeln!(code)?;
6279
6280 writeln!(
6282 code,
6283 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6284 )?;
6285 writeln!(
6286 code,
6287 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6288 )?;
6289 writeln!(
6290 code,
6291 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6292 )?;
6293 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6294 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6295 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
6296 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6297 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6298 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
6299 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
6300 writeln!(
6301 code,
6302 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6303 )?;
6304 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
6305 writeln!(
6306 code,
6307 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6308 )?;
6309 writeln!(
6310 code,
6311 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6312 )?;
6313 writeln!(
6314 code,
6315 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6316 )?;
6317 writeln!(
6318 code,
6319 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6320 )?;
6321 writeln!(
6322 code,
6323 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6324 )?;
6325 writeln!(
6326 code,
6327 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6328 )?;
6329 writeln!(
6330 code,
6331 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6332 )?;
6333 writeln!(
6334 code,
6335 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6336 )?;
6337 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6338 writeln!(
6339 code,
6340 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6341 )?;
6342 writeln!(
6343 code,
6344 " enc.dispatch_thread_groups(grid_size, tg_size);"
6345 )?;
6346 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6347 writeln!(code, " }}")?;
6348 writeln!(code)?;
6349
6350 writeln!(
6352 code,
6353 " /// Dispatch fused K+V cache copy in one kernel launch."
6354 )?;
6355 writeln!(
6356 code,
6357 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6358 )?;
6359 writeln!(
6360 code,
6361 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6362 )?;
6363 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6364 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6365 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6366 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6367 writeln!(code, " let ko: u32 = k_offset as u32;")?;
6368 writeln!(code, " let vo: u32 = v_offset as u32;")?;
6369 writeln!(
6370 code,
6371 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6372 )?;
6373 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6374 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
6375 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
6376 writeln!(
6377 code,
6378 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6379 )?;
6380 writeln!(
6381 code,
6382 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6383 )?;
6384 writeln!(
6385 code,
6386 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6387 )?;
6388 writeln!(
6389 code,
6390 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6391 )?;
6392 writeln!(
6393 code,
6394 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6395 )?;
6396 writeln!(
6397 code,
6398 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6399 )?;
6400 writeln!(
6401 code,
6402 " let total = num_tokens * kv_dim * 2; // K + V"
6403 )?;
6404 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6405 writeln!(
6406 code,
6407 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6408 )?;
6409 writeln!(
6410 code,
6411 " enc.dispatch_thread_groups(grid_size, tg_size);"
6412 )?;
6413 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6414 writeln!(code, " }}")?;
6415
6416 writeln!(code, "}}")?;
6417 writeln!(code)?;
6418
6419 Ok(())
6420}
6421
6422fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6423 writeln!(
6424 code,
6425 "// ── Helper functions ──────────────────────────────────"
6426 )?;
6427 writeln!(code)?;
6428 writeln!(
6429 code,
6430 "/// Create a compute pipeline from a named function in the Metal library."
6431 )?;
6432 writeln!(
6433 code,
6434 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6435 )?;
6436 writeln!(
6437 code,
6438 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6439 )?;
6440 writeln!(
6441 code,
6442 " device.new_compute_pipeline_state_with_function(&func)"
6443 )?;
6444 writeln!(
6445 code,
6446 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6447 )?;
6448 writeln!(code, "}}")?;
6449 writeln!(code)?;
6450
6451 Ok(())
6452}
6453
6454fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6459 let _sanitized = sanitize_name(model_name);
6460 let _vocab = config.vocab_size;
6461
6462 let mut code = String::with_capacity(16 * 1024);
6463 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6464 writeln!(
6465 code,
6466 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6467 )?;
6468 writeln!(code)?;
6469 writeln!(code, "mod model;")?;
6470 writeln!(code)?;
6471 writeln!(code, "use std::io::Write;")?;
6472 writeln!(code, "use std::time::Instant;")?;
6473 writeln!(code, "use serde::Deserialize;")?;
6474 writeln!(code)?;
6475
6476 writeln!(code, "fn main() {{")?;
6478 writeln!(
6479 code,
6480 " let args: Vec<String> = std::env::args().collect();"
6481 )?;
6482 writeln!(code)?;
6483 writeln!(
6484 code,
6485 " // Detect --serve mode (only requires weights + tokenizer)"
6486 )?;
6487 writeln!(
6488 code,
6489 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
6490 )?;
6491 writeln!(code)?;
6492 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
6493 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6494 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6495 writeln!(code, " std::process::exit(1);")?;
6496 writeln!(code, " }}")?;
6497 writeln!(code)?;
6498 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
6499 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6500 writeln!(code, " std::process::exit(1);")?;
6501 writeln!(code, " }}")?;
6502 writeln!(code)?;
6503 writeln!(code, " let weights_path = &args[1];")?;
6504 writeln!(code, " let tokenizer_path = &args[2];")?;
6505 writeln!(code)?;
6506 writeln!(code, " // Parse optional flags")?;
6507 writeln!(code, " let mut max_tokens: usize = 128;")?;
6508 writeln!(code, " let mut port: u16 = 8080;")?;
6509 writeln!(
6510 code,
6511 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6512 )?;
6513 writeln!(
6514 code,
6515 " let profile = args.iter().any(|a| a == \"--profile\");"
6516 )?;
6517 writeln!(code, " let mut i = 3;")?;
6518 writeln!(code, " while i < args.len() {{")?;
6519 writeln!(
6520 code,
6521 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6522 )?;
6523 writeln!(
6524 code,
6525 " max_tokens = args[i + 1].parse().unwrap_or(128);"
6526 )?;
6527 writeln!(code, " i += 2;")?;
6528 writeln!(
6529 code,
6530 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6531 )?;
6532 writeln!(
6533 code,
6534 " port = args[i + 1].parse().unwrap_or(8080);"
6535 )?;
6536 writeln!(code, " i += 2;")?;
6537 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
6538 writeln!(code, " i += 1;")?;
6539 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
6540 writeln!(code, " i += 1;")?;
6541 writeln!(code, " }} else {{")?;
6542 writeln!(code, " i += 1;")?;
6543 writeln!(code, " }}")?;
6544 writeln!(code, " }}")?;
6545 writeln!(code)?;
6546
6547 writeln!(
6549 code,
6550 " // Memory-map weights for zero-copy loading on Apple Silicon"
6551 )?;
6552 writeln!(
6553 code,
6554 " let weights_file = std::fs::File::open(weights_path)"
6555 )?;
6556 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6557 writeln!(
6558 code,
6559 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6560 )?;
6561 writeln!(code)?;
6562 writeln!(code, " // Load tokenizer")?;
6563 writeln!(
6564 code,
6565 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6566 )?;
6567 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6568 writeln!(code)?;
6569 writeln!(code, " // Create Metal model")?;
6570 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
6571 writeln!(
6572 code,
6573 " let mut model = model::MetalModel::new(&weights_mmap);"
6574 )?;
6575 writeln!(code)?;
6576
6577 writeln!(code, " if serve_mode {{")?;
6579 writeln!(code, " serve(model, tokenizer, port);")?;
6580 writeln!(code, " }} else {{")?;
6581 writeln!(code, " let prompt = &args[3];")?;
6582 writeln!(
6583 code,
6584 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6585 )?;
6586 writeln!(code, " }}")?;
6587 writeln!(code, "}}")?;
6588 writeln!(code)?;
6589
6590 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6592 writeln!(code, " // Tokenize prompt")?;
6593 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
6594 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6595 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6596 writeln!(code)?;
6597 writeln!(
6598 code,
6599 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6600 )?;
6601 writeln!(
6602 code,
6603 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
6604 )?;
6605 writeln!(
6606 code,
6607 " // The last token uses synchronous forward() to get logits."
6608 )?;
6609 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
6610 writeln!(code, " let prefill_start = Instant::now();")?;
6611 writeln!(code, " let logits = if prompt_len > 1 {{")?;
6612 writeln!(
6613 code,
6614 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6615 )?;
6616 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
6617 writeln!(code, " }} else {{")?;
6618 writeln!(code, " model.forward(prompt_tokens[0])")?;
6619 writeln!(code, " }};")?;
6620 writeln!(
6621 code,
6622 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6623 )?;
6624 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
6625 writeln!(
6626 code,
6627 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6628 )?;
6629 writeln!(
6630 code,
6631 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6632 )?;
6633 writeln!(code)?;
6634 writeln!(code, " // Generate tokens")?;
6635 writeln!(code, " let mut next_token = argmax(&logits);")?;
6636 writeln!(code, " let gen_start = Instant::now();")?;
6637 writeln!(code, " let mut generated_count: usize = 0;")?;
6638 writeln!(code)?;
6639 writeln!(code, " for _ in 0..max_tokens {{")?;
6640 writeln!(
6641 code,
6642 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6643 )?;
6644 writeln!(code, " if !quiet {{")?;
6645 writeln!(code, " print!(\"{{}}\", text);")?;
6646 writeln!(code, " std::io::stdout().flush().ok();")?;
6647 writeln!(code, " }}")?;
6648 writeln!(code, " }}")?;
6649 writeln!(code, " generated_count += 1;")?;
6650 writeln!(code)?;
6651 writeln!(
6652 code,
6653 " // Use profiling forward for first token when --profile is set"
6654 )?;
6655 writeln!(
6656 code,
6657 " let logits = if profile && generated_count == 1 {{"
6658 )?;
6659 writeln!(code, " model.forward_profile(next_token)")?;
6660 writeln!(code, " }} else {{")?;
6661 writeln!(code, " model.forward(next_token)")?;
6662 writeln!(code, " }};")?;
6663 writeln!(code, " next_token = argmax(&logits);")?;
6664 writeln!(code)?;
6665 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
6666 writeln!(code, " if next_token == 2 {{")?;
6667 writeln!(code, " break;")?;
6668 writeln!(code, " }}")?;
6669 writeln!(code)?;
6670 writeln!(
6671 code,
6672 " // Yield between tokens to reduce sustained GPU thermal load."
6673 )?;
6674 writeln!(
6675 code,
6676 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6677 )?;
6678 writeln!(
6679 code,
6680 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6681 )?;
6682 writeln!(
6683 code,
6684 " // briefly, providing a micro-break that helps sustain peak throughput."
6685 )?;
6686 writeln!(code, " std::thread::yield_now();")?;
6687 writeln!(code, " }}")?;
6688 writeln!(code, " if !quiet {{")?;
6689 writeln!(code, " println!();")?;
6690 writeln!(code, " }}")?;
6691 writeln!(
6692 code,
6693 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6694 )?;
6695 writeln!(
6696 code,
6697 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6698 )?;
6699 writeln!(
6700 code,
6701 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6702 )?;
6703 writeln!(code, "}}")?;
6704 writeln!(code)?;
6705
6706 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6708 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6709 writeln!(code, " logits.iter()")?;
6710 writeln!(code, " .enumerate()")?;
6711 writeln!(
6712 code,
6713 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6714 )?;
6715 writeln!(code, " .map(|(i, _)| i as u32)")?;
6716 writeln!(code, " .unwrap_or(0)")?;
6717 writeln!(code, "}}")?;
6718 writeln!(code)?;
6719
6720 writeln!(
6722 code,
6723 "// -----------------------------------------------------------------------"
6724 )?;
6725 writeln!(code, "// OpenAI-compatible API server")?;
6726 writeln!(
6727 code,
6728 "// -----------------------------------------------------------------------"
6729 )?;
6730 writeln!(code)?;
6731 writeln!(code, "#[derive(Deserialize)]")?;
6732 writeln!(code, "struct ChatRequest {{")?;
6733 writeln!(code, " messages: Vec<ChatMessage>,")?;
6734 writeln!(code, " #[serde(default)]")?;
6735 writeln!(code, " stream: Option<bool>,")?;
6736 writeln!(code, " #[serde(default)]")?;
6737 writeln!(code, " max_tokens: Option<usize>,")?;
6738 writeln!(code, " #[serde(default)]")?;
6739 writeln!(code, " temperature: Option<f32>,")?;
6740 writeln!(code, " #[serde(default)]")?;
6741 writeln!(code, " model: Option<String>,")?;
6742 writeln!(code, "}}")?;
6743 writeln!(code)?;
6744 writeln!(code, "#[derive(Deserialize)]")?;
6745 writeln!(code, "struct ChatMessage {{")?;
6746 writeln!(code, " role: String,")?;
6747 writeln!(code, " content: String,")?;
6748 writeln!(code, "}}")?;
6749 writeln!(code)?;
6750
6751 writeln!(
6753 code,
6754 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6755 )?;
6756 writeln!(code, " let mut prompt = String::new();")?;
6757 writeln!(code, " for msg in messages {{")?;
6758 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6759 writeln!(code, " }}")?;
6760 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6761 writeln!(code, " prompt")?;
6762 writeln!(code, "}}")?;
6763 writeln!(code)?;
6764
6765 writeln!(
6767 code,
6768 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6769 )?;
6770 writeln!(code, " let len = tokens.len();")?;
6771 writeln!(code, " if len > 1 {{")?;
6772 writeln!(
6773 code,
6774 " model.forward_prefill_batch(&tokens[..len - 1]);"
6775 )?;
6776 writeln!(code, " }}")?;
6777 writeln!(code, " model.forward(tokens[len - 1])")?;
6778 writeln!(code, "}}")?;
6779 writeln!(code)?;
6780
6781 writeln!(
6783 code,
6784 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6785 )?;
6786 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6787 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
6788 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6789 writeln!(
6790 code,
6791 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6792 )?;
6793 writeln!(code, " eprintln!(\"Endpoints:\");")?;
6794 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
6795 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
6796 writeln!(code, " eprintln!(\" GET /health\");")?;
6797 writeln!(code)?;
6798 writeln!(code, " for request in server.incoming_requests() {{")?;
6799 writeln!(code, " let method = request.method().to_string();")?;
6800 writeln!(code, " let url = request.url().to_string();")?;
6801 writeln!(code)?;
6802 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
6803
6804 writeln!(
6806 code,
6807 " (\"POST\", \"/v1/chat/completions\") => {{"
6808 )?;
6809 writeln!(
6810 code,
6811 " handle_chat_completion(&mut model, &tokenizer, request);"
6812 )?;
6813 writeln!(code, " }}")?;
6814
6815 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
6817 writeln!(code, " let body = serde_json::json!({{")?;
6818 writeln!(code, " \"object\": \"list\",")?;
6819 writeln!(code, " \"data\": [{{")?;
6820 writeln!(code, " \"id\": \"forgellm-metal\",")?;
6821 writeln!(code, " \"object\": \"model\",")?;
6822 writeln!(code, " \"owned_by\": \"forgellm\"")?;
6823 writeln!(code, " }}]")?;
6824 writeln!(code, " }});")?;
6825 writeln!(
6826 code,
6827 " let resp = tiny_http::Response::from_string(body.to_string())"
6828 )?;
6829 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6830 writeln!(code, " request.respond(resp).ok();")?;
6831 writeln!(code, " }}")?;
6832
6833 writeln!(code, " (\"GET\", \"/health\") => {{")?;
6835 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6836 writeln!(code, " request.respond(resp).ok();")?;
6837 writeln!(code, " }}")?;
6838
6839 writeln!(code, " _ => {{")?;
6841 writeln!(
6842 code,
6843 " let resp = tiny_http::Response::from_string(\"Not Found\")"
6844 )?;
6845 writeln!(code, " .with_status_code(404);")?;
6846 writeln!(code, " request.respond(resp).ok();")?;
6847 writeln!(code, " }}")?;
6848 writeln!(code, " }}")?;
6849 writeln!(code, " }}")?;
6850 writeln!(code, "}}")?;
6851 writeln!(code)?;
6852
6853 writeln!(code, "fn handle_chat_completion(")?;
6855 writeln!(code, " model: &mut model::MetalModel,")?;
6856 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
6857 writeln!(code, " mut request: tiny_http::Request,")?;
6858 writeln!(code, ") {{")?;
6859 writeln!(code, " // Read request body")?;
6860 writeln!(code, " let mut body = String::new();")?;
6861 writeln!(
6862 code,
6863 " if request.as_reader().read_to_string(&mut body).is_err() {{"
6864 )?;
6865 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6866 writeln!(code, " .with_status_code(400);")?;
6867 writeln!(code, " request.respond(resp).ok();")?;
6868 writeln!(code, " return;")?;
6869 writeln!(code, " }}")?;
6870 writeln!(code)?;
6871 writeln!(code, " // Parse JSON")?;
6872 writeln!(
6873 code,
6874 " let req: ChatRequest = match serde_json::from_str(&body) {{"
6875 )?;
6876 writeln!(code, " Ok(r) => r,")?;
6877 writeln!(code, " Err(e) => {{")?;
6878 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6879 writeln!(code, " .with_status_code(400);")?;
6880 writeln!(code, " request.respond(resp).ok();")?;
6881 writeln!(code, " return;")?;
6882 writeln!(code, " }}")?;
6883 writeln!(code, " }};")?;
6884 writeln!(code)?;
6885 writeln!(
6886 code,
6887 " let prompt = format_chat_messages(&req.messages);"
6888 )?;
6889 writeln!(
6890 code,
6891 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6892 )?;
6893 writeln!(code, " Ok(e) => e,")?;
6894 writeln!(code, " Err(e) => {{")?;
6895 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6896 writeln!(code, " .with_status_code(500);")?;
6897 writeln!(code, " request.respond(resp).ok();")?;
6898 writeln!(code, " return;")?;
6899 writeln!(code, " }}")?;
6900 writeln!(code, " }};")?;
6901 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6902 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
6903 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
6904 writeln!(
6905 code,
6906 " let _temperature = req.temperature.unwrap_or(1.0);"
6907 )?;
6908 writeln!(code)?;
6909
6910 writeln!(code, " model.reset();")?;
6912 writeln!(code)?;
6913
6914 writeln!(code, " let prefill_start = Instant::now();")?;
6916 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
6917 writeln!(
6918 code,
6919 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6920 )?;
6921 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
6922 writeln!(code, " let mut next_token = argmax(&logits);")?;
6923 writeln!(code)?;
6924
6925 writeln!(code, " if stream {{")?;
6926
6927 writeln!(
6929 code,
6930 " // SSE streaming: generate tokens and build SSE body"
6931 )?;
6932 writeln!(code, " let gen_start = Instant::now();")?;
6933 writeln!(code, " let mut generated_count: usize = 0;")?;
6934 writeln!(code, " let mut sse_body = String::new();")?;
6935 writeln!(code, " for _ in 0..max_tokens {{")?;
6936 writeln!(code, " if next_token == 2 {{ break; }}")?;
6937 writeln!(
6938 code,
6939 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6940 )?;
6941 writeln!(
6942 code,
6943 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
6944 )?;
6945 writeln!(
6946 code,
6947 " // escaped includes surrounding quotes, strip them"
6948 )?;
6949 writeln!(
6950 code,
6951 " let inner = &escaped[1..escaped.len()-1];"
6952 )?;
6953 writeln!(code, " sse_body.push_str(&format!(")?;
6954 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6955 writeln!(code, " inner")?;
6956 writeln!(code, " ));")?;
6957 writeln!(code, " }}")?;
6958 writeln!(code, " generated_count += 1;")?;
6959 writeln!(code, " let logits = model.forward(next_token);")?;
6960 writeln!(code, " next_token = argmax(&logits);")?;
6961 writeln!(code, " }}")?;
6962 writeln!(
6963 code,
6964 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6965 )?;
6966 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6967 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
6968 writeln!(code)?;
6969 writeln!(
6970 code,
6971 " // Final chunk with finish_reason, timing, and DONE sentinel"
6972 )?;
6973 writeln!(code, " sse_body.push_str(&format!(")?;
6974 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6975 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6976 writeln!(code, " ));")?;
6977 writeln!(code)?;
6978 writeln!(
6979 code,
6980 " let resp = tiny_http::Response::from_string(sse_body)"
6981 )?;
6982 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6983 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6984 writeln!(code, " request.respond(resp).ok();")?;
6985
6986 writeln!(code, " }} else {{")?;
6987
6988 writeln!(
6990 code,
6991 " // Non-streaming: generate all tokens, return JSON"
6992 )?;
6993 writeln!(code, " let gen_start = Instant::now();")?;
6994 writeln!(code, " let mut generated_count: usize = 0;")?;
6995 writeln!(code, " let mut generated = String::new();")?;
6996 writeln!(code, " for _ in 0..max_tokens {{")?;
6997 writeln!(code, " if next_token == 2 {{ break; }}")?;
6998 writeln!(
6999 code,
7000 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7001 )?;
7002 writeln!(code, " generated.push_str(&text);")?;
7003 writeln!(code, " }}")?;
7004 writeln!(code, " generated_count += 1;")?;
7005 writeln!(code, " let logits = model.forward(next_token);")?;
7006 writeln!(code, " next_token = argmax(&logits);")?;
7007 writeln!(code, " }}")?;
7008 writeln!(
7009 code,
7010 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7011 )?;
7012 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7013 writeln!(code)?;
7014 writeln!(code, " let resp_json = serde_json::json!({{")?;
7015 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
7016 writeln!(code, " \"object\": \"chat.completion\",")?;
7017 writeln!(code, " \"choices\": [{{")?;
7018 writeln!(code, " \"index\": 0,")?;
7019 writeln!(code, " \"message\": {{")?;
7020 writeln!(code, " \"role\": \"assistant\",")?;
7021 writeln!(code, " \"content\": generated")?;
7022 writeln!(code, " }},")?;
7023 writeln!(code, " \"finish_reason\": \"stop\"")?;
7024 writeln!(code, " }}],")?;
7025 writeln!(code, " \"usage\": {{")?;
7026 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
7027 writeln!(
7028 code,
7029 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7030 )?;
7031 writeln!(
7032 code,
7033 " \"generation_tokens\": generated_count,"
7034 )?;
7035 writeln!(
7036 code,
7037 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7038 )?;
7039 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
7040 writeln!(code, " }}")?;
7041 writeln!(code, " }});")?;
7042 writeln!(
7043 code,
7044 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
7045 )?;
7046 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7047 writeln!(code, " request.respond(resp).ok();")?;
7048 writeln!(code, " }}")?;
7049 writeln!(code, "}}")?;
7050
7051 Ok(code)
7052}
7053
7054#[cfg(test)]
7059mod tests {
7060 use super::*;
7061 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7062
7063 fn minimal_config() -> ModelConfig {
7064 ModelConfig {
7065 architecture: Architecture::Llama,
7066 hidden_size: 64,
7067 intermediate_size: 128,
7068 num_layers: 2,
7069 num_attention_heads: 4,
7070 num_kv_heads: 4,
7071 head_dim: 16,
7072 vocab_size: 256,
7073 max_seq_len: 512,
7074 rms_norm_eps: 1e-5,
7075 rope_theta: 10000.0,
7076 dtype: DType::F32,
7077 sliding_window_size: None,
7078 qkv_bias: false,
7079 }
7080 }
7081
7082 fn minimal_graph() -> Graph {
7083 Graph::new("test-metal").with_config(minimal_config())
7084 }
7085
7086 #[test]
7087 fn generate_metal_project_creates_files() {
7088 let dir = tempfile::tempdir().unwrap();
7089 let graph = minimal_graph();
7090 generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7091
7092 assert!(
7093 dir.path().join("Cargo.toml").exists(),
7094 "Cargo.toml should be created"
7095 );
7096 assert!(
7097 dir.path().join("src/model.rs").exists(),
7098 "src/model.rs should be created"
7099 );
7100 assert!(
7101 dir.path().join("src/main.rs").exists(),
7102 "src/main.rs should be created"
7103 );
7104 assert!(
7105 dir.path().join("shaders/kernels.metal").exists(),
7106 "shaders/kernels.metal should be created"
7107 );
7108 }
7109
7110 #[test]
7111 fn generated_cargo_toml_has_metal_dep() {
7112 let toml = generate_cargo_toml("my-model");
7113 assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7114 assert!(
7115 toml.contains("tokenizers"),
7116 "Cargo.toml should depend on tokenizers"
7117 );
7118 assert!(
7119 toml.contains("memmap2"),
7120 "Cargo.toml should depend on memmap2"
7121 );
7122 assert!(toml.contains("half"), "Cargo.toml should depend on half");
7123 }
7124
7125 #[test]
7126 fn generated_model_rs_contains_metal_code() {
7127 let config = minimal_config();
7128 let model_rs = generate_model_rs(&config).unwrap();
7129
7130 assert!(
7131 model_rs.contains("pub struct MetalModel"),
7132 "model.rs should define MetalModel struct"
7133 );
7134 assert!(
7135 model_rs.contains("matmul_pipeline: ComputePipelineState"),
7136 "MetalModel should have matmul_pipeline field"
7137 );
7138 assert!(
7139 model_rs.contains("Device::system_default()"),
7140 "model.rs should use Metal device"
7141 );
7142 assert!(
7143 model_rs.contains("new_library_with_source"),
7144 "model.rs should compile Metal shaders"
7145 );
7146 assert!(
7147 model_rs.contains("fn new(weights: &[u8])"),
7148 "MetalModel should implement new()"
7149 );
7150 assert!(
7151 model_rs.contains("fn forward(&mut self, token_id: u32)"),
7152 "MetalModel should implement forward()"
7153 );
7154 }
7155
7156 #[test]
7157 fn generated_shaders_contain_kernel_names() {
7158 let shaders = generate_metal_shaders(&minimal_config());
7159
7160 assert!(
7161 shaders.contains("kernel void matmul_vec"),
7162 "shaders should contain matmul_vec kernel"
7163 );
7164 assert!(
7165 shaders.contains("kernel void rms_norm"),
7166 "shaders should contain rms_norm kernel"
7167 );
7168 assert!(
7169 shaders.contains("kernel void rope"),
7170 "shaders should contain rope kernel"
7171 );
7172 assert!(
7173 shaders.contains("kernel void softmax"),
7174 "shaders should contain softmax kernel"
7175 );
7176 assert!(
7177 shaders.contains("kernel void silu_mul("),
7178 "shaders should contain silu_mul kernel"
7179 );
7180 assert!(
7181 shaders.contains("kernel void silu_mul_fused"),
7182 "shaders should contain silu_mul_fused kernel"
7183 );
7184 assert!(
7185 shaders.contains("kernel void elementwise_add"),
7186 "shaders should contain elementwise_add kernel"
7187 );
7188 assert!(
7189 shaders.contains("kernel void attention"),
7190 "shaders should contain attention kernel"
7191 );
7192 assert!(
7193 shaders.contains("kernel void add_inplace"),
7194 "shaders should contain add_inplace kernel"
7195 );
7196 assert!(
7197 shaders.contains("kernel void copy_buffer"),
7198 "shaders should contain copy_buffer kernel"
7199 );
7200 assert!(
7201 shaders.contains("kernel void copy_offset"),
7202 "shaders should contain copy_offset kernel"
7203 );
7204 }
7205
7206 #[test]
7207 fn generated_shaders_use_simdgroup_features() {
7208 let shaders = generate_metal_shaders(&minimal_config());
7209
7210 assert!(
7211 shaders.contains("threadgroup_barrier"),
7212 "shaders should use threadgroup barriers"
7213 );
7214 assert!(
7215 shaders.contains("threadgroup float"),
7216 "shaders should use threadgroup shared memory"
7217 );
7218 assert!(
7219 shaders.contains("thread_index_in_threadgroup"),
7220 "shaders should use threadgroup indexing"
7221 );
7222 assert!(
7223 shaders.contains("simd_sum"),
7224 "shaders should use simd_sum for warp-level reduction"
7225 );
7226 assert!(
7227 shaders.contains("simd_max"),
7228 "attention kernel should use simd_max for cooperative softmax"
7229 );
7230 assert!(
7231 shaders.contains("thread_index_in_simdgroup"),
7232 "shaders should use simdgroup lane indexing"
7233 );
7234 assert!(
7235 shaders.contains("simdgroup_index_in_threadgroup"),
7236 "shaders should use simdgroup indexing within threadgroup"
7237 );
7238 assert!(
7239 shaders.contains("float4"),
7240 "matmul_vec should use float4 vectorized loads"
7241 );
7242 }
7243
7244 #[test]
7245 fn generated_main_rs_has_tokenizer_usage() {
7246 let config = minimal_config();
7247 let main_rs = generate_main_rs("test-model", &config).unwrap();
7248
7249 assert!(
7250 main_rs.contains("tokenizers::Tokenizer"),
7251 "main.rs should use tokenizers crate"
7252 );
7253 assert!(
7254 main_rs.contains("MetalModel::new"),
7255 "main.rs should call MetalModel::new"
7256 );
7257 assert!(
7258 main_rs.contains("model.forward"),
7259 "main.rs should call model.forward"
7260 );
7261 assert!(
7262 main_rs.contains("memmap2"),
7263 "main.rs should use memmap2 for zero-copy weight loading"
7264 );
7265 }
7266
7267 #[test]
7268 fn missing_config_returns_error() {
7269 let dir = tempfile::tempdir().unwrap();
7270 let graph = Graph::new("no-config");
7271 let result = generate_metal_project(&graph, dir.path(), "fail");
7272 assert!(
7273 matches!(result, Err(MetalCodegenError::MissingConfig)),
7274 "should fail with MissingConfig when graph has no config"
7275 );
7276 }
7277
7278 #[test]
7279 fn sanitize_name_works() {
7280 assert_eq!(sanitize_name("My Model!"), "my-model");
7281 assert_eq!(sanitize_name("test_model"), "test-model");
7282 assert_eq!(sanitize_name("simple"), "simple");
7283 }
7284
7285 #[test]
7286 fn generated_forward_uses_single_command_buffer() {
7287 let config = minimal_config();
7288 let model_rs = generate_model_rs(&config).unwrap();
7289
7290 let forward_start = model_rs
7293 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7294 .unwrap();
7295 let forward_body = &model_rs[forward_start..];
7296 let forward_end = forward_body
7298 .find("\n pub fn forward_profile")
7299 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7300 .or_else(|| forward_body.find("\n fn dispatch_"))
7301 .unwrap_or(forward_body.len());
7302 let forward_code = &forward_body[..forward_end];
7303
7304 let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7306 assert_eq!(
7307 cmd_buf_count, 1,
7308 "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7309 );
7310
7311 let commit_count = forward_code.matches("cmd.commit()").count();
7313 assert_eq!(
7314 commit_count, 1,
7315 "forward() should commit exactly once, found {commit_count}"
7316 );
7317
7318 let wait_count = forward_code.matches("wait_until_completed()").count();
7320 assert!(
7321 wait_count >= 1 && wait_count <= 2,
7322 "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7323 );
7324 }
7325
7326 #[test]
7327 fn generated_model_has_preallocated_working_buffers() {
7328 let config = minimal_config();
7329 let model_rs = generate_model_rs(&config).unwrap();
7330
7331 for buf_name in &[
7332 "normed_buf",
7333 "qkv_buf",
7334 "attn_out_buf",
7335 "attn_proj_buf",
7336 "gate_up_buf",
7337 "ffn_hidden_buf",
7338 "ffn_out_buf",
7339 "add_tmp_buf",
7340 ] {
7341 assert!(
7342 model_rs.contains(&format!("{buf_name}: Buffer")),
7343 "MetalModel should have pre-allocated {buf_name} field"
7344 );
7345 }
7346 }
7347
7348 #[test]
7349 fn generated_dispatch_helpers_take_compute_encoder_ref() {
7350 let config = minimal_config();
7351 let model_rs = generate_model_rs(&config).unwrap();
7352
7353 for method in &[
7354 "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7355 "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7356 "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7357 "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7358 "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7359 "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7360 "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7361 "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7362 "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7363 "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7364 "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7365 "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7366 "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7367 ] {
7368 assert!(
7369 model_rs.contains(method),
7370 "model.rs should contain dispatch helper: {method}"
7371 );
7372 }
7373 }
7374
7375 #[test]
7376 fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7377 let config = minimal_config();
7378 let model_rs = generate_model_rs(&config).unwrap();
7379
7380 let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7382 let helpers_code = &model_rs[helpers_start..];
7383
7384 assert!(
7386 !helpers_code.contains("self.queue.new_command_buffer()"),
7387 "dispatch helpers should not create their own command buffers"
7388 );
7389
7390 assert!(
7392 !helpers_code.contains("new_compute_command_encoder()"),
7393 "dispatch helpers should not create their own compute encoders"
7394 );
7395
7396 assert!(
7398 !helpers_code.contains("end_encoding()"),
7399 "dispatch helpers should not call end_encoding"
7400 );
7401
7402 assert!(
7404 !helpers_code.contains(".commit()"),
7405 "dispatch helpers should not commit command buffers"
7406 );
7407 assert!(
7408 !helpers_code.contains("wait_until_completed"),
7409 "dispatch helpers should not wait on command buffers"
7410 );
7411 }
7412
7413 #[test]
7414 fn generated_forward_batches_compute_encoders() {
7415 let config = minimal_config();
7416 let model_rs = generate_model_rs(&config).unwrap();
7417
7418 let forward_start = model_rs
7420 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7421 .unwrap();
7422 let forward_body = &model_rs[forward_start..];
7423 let forward_end = forward_body
7424 .find("\n pub fn forward_profile")
7425 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7426 .or_else(|| forward_body.find("\n fn dispatch_"))
7427 .unwrap_or(forward_body.len());
7428 let forward_code = &forward_body[..forward_end];
7429
7430 assert!(
7432 !forward_code.contains("device.new_buffer"),
7433 "forward() should not allocate new buffers per call"
7434 );
7435
7436 let compute_encoder_count = forward_code
7439 .matches("new_compute_command_encoder()")
7440 .count();
7441 let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7442
7443 assert_eq!(
7445 compute_encoder_count, 1,
7446 "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7447 );
7448 assert_eq!(
7449 blit_encoder_count, 0,
7450 "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7451 );
7452 }
7453
7454 #[test]
7455 fn generated_forward_uses_add_inplace() {
7456 let config = minimal_config();
7457 let model_rs = generate_model_rs(&config).unwrap();
7458
7459 assert!(
7461 model_rs.contains("dispatch_add_inplace"),
7462 "forward() should use dispatch_add_inplace for residual connections"
7463 );
7464 assert!(
7465 model_rs.contains("add_inplace_pipeline"),
7466 "MetalModel should have add_inplace_pipeline"
7467 );
7468 }
7469
7470 fn minimal_q8_config() -> ModelConfig {
7471 ModelConfig {
7472 architecture: Architecture::Llama,
7473 hidden_size: 64,
7474 intermediate_size: 128,
7475 num_layers: 2,
7476 num_attention_heads: 4,
7477 num_kv_heads: 4,
7478 head_dim: 16,
7479 vocab_size: 256,
7480 max_seq_len: 512,
7481 rms_norm_eps: 1e-5,
7482 rope_theta: 10000.0,
7483 dtype: DType::Q8_0,
7484 sliding_window_size: None,
7485 qkv_bias: false,
7486 }
7487 }
7488
7489 #[test]
7490 fn generated_shaders_contain_q8_kernel() {
7491 let shaders = generate_metal_shaders(&minimal_config());
7492
7493 assert!(
7494 shaders.contains("kernel void matmul_vec_q8"),
7495 "shaders should contain matmul_vec_q8 kernel"
7496 );
7497 assert!(
7498 shaders.contains("device const uchar* matrix"),
7499 "matmul_vec_q8 should accept raw Q8_0 bytes"
7500 );
7501 assert!(
7502 shaders.contains("packed_short4"),
7503 "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7504 );
7505 assert!(
7506 shaders.contains("as_type<char2>"),
7507 "matmul_vec_q8 should bitcast short lanes to char2"
7508 );
7509 assert!(
7510 shaders.contains("device const half*"),
7511 "matmul_vec_q8 should read f16 scale via half pointer"
7512 );
7513 }
7514
7515 #[test]
7516 fn generated_model_uses_fused_qkv_projections() {
7517 let config = minimal_config();
7518 let model_rs = generate_model_rs(&config).unwrap();
7519
7520 assert!(
7522 model_rs.contains("qkv_weight: Buffer"),
7523 "LayerBuffers should have fused qkv_weight field"
7524 );
7525 assert!(
7527 !model_rs.contains(" q_weight: Buffer"),
7528 "LayerBuffers should not have separate q_weight field"
7529 );
7530 assert!(
7531 !model_rs.contains(" k_weight: Buffer"),
7532 "LayerBuffers should not have separate k_weight field"
7533 );
7534 assert!(
7535 !model_rs.contains(" v_weight: Buffer"),
7536 "LayerBuffers should not have separate v_weight field"
7537 );
7538
7539 assert!(
7541 model_rs.contains("gate_up_weight: Buffer"),
7542 "LayerBuffers should have fused gate_up_weight field"
7543 );
7544 assert!(
7546 !model_rs.contains(" gate_weight: Buffer"),
7547 "LayerBuffers should not have separate gate_weight field"
7548 );
7549 assert!(
7550 !model_rs.contains(" up_weight: Buffer"),
7551 "LayerBuffers should not have separate up_weight field"
7552 );
7553
7554 assert!(
7556 model_rs.contains("qkv_buf: Buffer"),
7557 "MetalModel should have fused qkv_buf"
7558 );
7559 assert!(
7560 model_rs.contains("gate_up_buf: Buffer"),
7561 "MetalModel should have fused gate_up_buf"
7562 );
7563
7564 assert!(
7566 model_rs.contains("dispatch_silu_mul_fused"),
7567 "forward pass should use dispatch_silu_mul_fused"
7568 );
7569 assert!(
7570 model_rs.contains("dispatch_rope_offset"),
7571 "forward pass should use dispatch_rope_offset for fused QKV"
7572 );
7573 assert!(
7574 model_rs.contains("dispatch_attention_offset"),
7575 "forward pass should use dispatch_attention_offset for fused QKV"
7576 );
7577 }
7578
7579 #[test]
7580 fn q8_model_has_matmul_q8_pipeline() {
7581 let config = minimal_q8_config();
7582 let model_rs = generate_model_rs(&config).unwrap();
7583
7584 assert!(
7585 model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7586 "MetalModel should have matmul_q8_pipeline field"
7587 );
7588 assert!(
7589 model_rs.contains("matmul_q8_pipeline,"),
7590 "MetalModel Self should include matmul_q8_pipeline"
7591 );
7592 }
7593
7594 #[test]
7595 fn q8_model_uses_dispatch_matmul_q8() {
7596 let config = minimal_q8_config();
7597 let model_rs = generate_model_rs(&config).unwrap();
7598
7599 assert!(
7600 model_rs.contains("dispatch_matmul_q8"),
7601 "Q8_0 model should use dispatch_matmul_q8 for projections"
7602 );
7603 assert!(
7604 model_rs.contains("fn dispatch_matmul_q8"),
7605 "model.rs should define dispatch_matmul_q8 method"
7606 );
7607 }
7608
7609 #[test]
7610 fn q8_model_loads_raw_bytes_not_dequantized() {
7611 let config = minimal_q8_config();
7612 let model_rs = generate_model_rs(&config).unwrap();
7613
7614 assert!(
7616 !model_rs.contains("f16_to_f32"),
7617 "Q8_0 model should not dequantize weights to f32"
7618 );
7619 assert!(
7620 !model_rs.contains("f32_data"),
7621 "Q8_0 model should not create f32 weight data"
7622 );
7623
7624 assert!(
7626 model_rs.contains("total_raw as u64"),
7627 "Q8_0 model should load raw bytes into Metal buffer"
7628 );
7629 }
7630
7631 #[test]
7632 fn q8_model_norms_stay_f32() {
7633 let config = minimal_q8_config();
7634 let model_rs = generate_model_rs(&config).unwrap();
7635
7636 assert!(
7638 model_rs.contains("let attn_norm = next_f32_buffer"),
7639 "attn_norm should use f32 buffer even for Q8_0 models"
7640 );
7641 assert!(
7642 model_rs.contains("let ffn_norm = next_f32_buffer"),
7643 "ffn_norm should use f32 buffer even for Q8_0 models"
7644 );
7645 assert!(
7646 model_rs.contains("let norm_buf = next_f32_buffer"),
7647 "final norm should use f32 buffer even for Q8_0 models"
7648 );
7649 }
7650
7651 #[test]
7652 fn q8_model_uses_fused_weight_loading() {
7653 let config = minimal_q8_config();
7654 let model_rs = generate_model_rs(&config).unwrap();
7655
7656 assert!(
7658 model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7659 "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7660 );
7661 assert!(
7663 model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7664 "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7665 );
7666 assert!(
7668 model_rs.contains("let o_weight = next_q8_buffer"),
7669 "Q8_0 model should use next_q8_buffer for O weight"
7670 );
7671 assert!(
7672 model_rs.contains("let down_weight = next_q8_buffer"),
7673 "Q8_0 model should use next_q8_buffer for down weight"
7674 );
7675 }
7676
7677 #[test]
7678 fn f32_model_does_not_use_q8_dispatch() {
7679 let config = minimal_config();
7680 let model_rs = generate_model_rs(&config).unwrap();
7681
7682 let forward_start = model_rs
7684 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7685 .unwrap();
7686 let forward_body = &model_rs[forward_start..];
7687 let forward_end = forward_body
7688 .find("\n fn dispatch_")
7689 .unwrap_or(forward_body.len());
7690 let forward_code = &forward_body[..forward_end];
7691
7692 assert!(
7693 !forward_code.contains("dispatch_matmul_q8"),
7694 "f32 model forward should not use dispatch_matmul_q8"
7695 );
7696 }
7697
7698 #[test]
7699 fn q8_dispatch_helper_takes_compute_encoder_ref() {
7700 let config = minimal_q8_config();
7701 let model_rs = generate_model_rs(&config).unwrap();
7702
7703 assert!(
7704 model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7705 "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7706 );
7707 }
7708
7709 #[test]
7710 fn generated_model_has_double_buffered_prefill() {
7711 let config = minimal_config();
7712 let model_rs = generate_model_rs(&config).unwrap();
7713
7714 assert!(
7716 model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7717 "MetalModel should have prev_cmd field for double-buffered prefill"
7718 );
7719
7720 assert!(
7722 model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7723 "MetalModel should have forward_prefill method"
7724 );
7725
7726 assert!(
7728 model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7729 "forward() should drain prev_cmd from previous prefill"
7730 );
7731 }
7732
7733 #[test]
7734 fn generated_main_rs_uses_forward_prefill_for_prompt() {
7735 let config = minimal_config();
7736 let main_rs = generate_main_rs("test-model", &config).unwrap();
7737
7738 assert!(
7739 main_rs.contains("forward_prefill"),
7740 "main.rs should use forward_prefill for intermediate prompt tokens"
7741 );
7742 assert!(
7743 main_rs.contains("double-buffered"),
7744 "main.rs should document double-buffered prefill"
7745 );
7746 }
7747
7748 #[test]
7749 fn generated_shaders_q8_uses_wide_vectorized_loads() {
7750 let shaders = generate_metal_shaders(&minimal_config());
7751
7752 assert!(
7753 shaders.contains("packed_short4"),
7754 "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7755 );
7756 assert!(
7757 shaders.contains("d0[0]"),
7758 "matmul_vec_q8 should index the wide pointer for row 0"
7759 );
7760 assert!(
7761 shaders.contains("as_type<char2>"),
7762 "matmul_vec_q8 should bitcast short lanes to char2"
7763 );
7764 assert!(
7765 shaders.contains("dot("),
7766 "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7767 );
7768 }
7769
7770 fn minimal_q4_config() -> ModelConfig {
7773 ModelConfig {
7774 architecture: Architecture::Llama,
7775 hidden_size: 64,
7776 intermediate_size: 128,
7777 num_layers: 2,
7778 num_attention_heads: 4,
7779 num_kv_heads: 4,
7780 head_dim: 16,
7781 vocab_size: 256,
7782 max_seq_len: 512,
7783 rms_norm_eps: 1e-5,
7784 rope_theta: 10000.0,
7785 dtype: DType::Q4_0,
7786 sliding_window_size: None,
7787 qkv_bias: false,
7788 }
7789 }
7790
7791 #[test]
7792 fn generated_shaders_contain_q4_kernel() {
7793 let shaders = generate_metal_shaders(&minimal_config());
7794
7795 assert!(
7796 shaders.contains("kernel void matmul_vec_q4"),
7797 "shaders should contain matmul_vec_q4 kernel"
7798 );
7799 assert!(
7800 shaders.contains("Q4_ROWS_PER_TG"),
7801 "shaders should define Q4_ROWS_PER_TG constant"
7802 );
7803 assert!(
7804 shaders.contains("Q4_ROWS_PER_SG"),
7805 "shaders should define Q4_ROWS_PER_SG constant"
7806 );
7807 }
7808
7809 #[test]
7810 fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7811 let shaders = generate_metal_shaders(&minimal_config());
7812
7813 assert!(
7815 shaders.contains("uchar4"),
7816 "matmul_vec_q4 should use uchar4 for packed byte loads"
7817 );
7818 assert!(
7820 shaders.contains("&0xF"),
7821 "matmul_vec_q4 should extract low nibble with &0xF"
7822 );
7823 assert!(
7824 shaders.contains(">>4"),
7825 "matmul_vec_q4 should extract high nibble with >>4"
7826 );
7827 assert!(
7829 shaders.contains("-8)"),
7830 "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7831 );
7832 assert!(
7834 shaders.contains("blk * 18"),
7835 "matmul_vec_q4 should use 18-byte block stride"
7836 );
7837 }
7838
7839 #[test]
7840 fn q4_model_has_matmul_q4_pipeline() {
7841 let config = minimal_q4_config();
7842 let model_rs = generate_model_rs(&config).unwrap();
7843
7844 assert!(
7845 model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7846 "MetalModel should have matmul_q4_pipeline field"
7847 );
7848 assert!(
7849 model_rs.contains("matmul_q4_pipeline,"),
7850 "MetalModel Self should include matmul_q4_pipeline"
7851 );
7852 }
7853
7854 #[test]
7855 fn q4_model_uses_dispatch_matmul_q4() {
7856 let config = minimal_q4_config();
7857 let model_rs = generate_model_rs(&config).unwrap();
7858
7859 assert!(
7860 model_rs.contains("dispatch_matmul_q4"),
7861 "Q4_0 model should use dispatch_matmul_q4 for projections"
7862 );
7863 assert!(
7864 model_rs.contains("fn dispatch_matmul_q4"),
7865 "model.rs should define dispatch_matmul_q4 method"
7866 );
7867 }
7868
7869 #[test]
7870 fn q4_model_loads_raw_bytes_not_dequantized() {
7871 let config = minimal_q4_config();
7872 let model_rs = generate_model_rs(&config).unwrap();
7873
7874 assert!(
7876 !model_rs.contains("f16_to_f32"),
7877 "Q4_0 model should not dequantize weights to f32"
7878 );
7879
7880 assert!(
7882 model_rs.contains("total_raw as u64"),
7883 "Q4_0 model should load raw bytes into Metal buffer"
7884 );
7885 }
7886
7887 #[test]
7888 fn q4_model_norms_stay_f32() {
7889 let config = minimal_q4_config();
7890 let model_rs = generate_model_rs(&config).unwrap();
7891
7892 assert!(
7893 model_rs.contains("let attn_norm = next_f32_buffer"),
7894 "attn_norm should use f32 buffer even for Q4_0 models"
7895 );
7896 assert!(
7897 model_rs.contains("let ffn_norm = next_f32_buffer"),
7898 "ffn_norm should use f32 buffer even for Q4_0 models"
7899 );
7900 assert!(
7901 model_rs.contains("let norm_buf = next_f32_buffer"),
7902 "final norm should use f32 buffer even for Q4_0 models"
7903 );
7904 }
7905
7906 #[test]
7907 fn q4_model_uses_fused_weight_loading() {
7908 let config = minimal_q4_config();
7909 let model_rs = generate_model_rs(&config).unwrap();
7910
7911 assert!(
7912 model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7913 "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7914 );
7915 assert!(
7916 model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7917 "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7918 );
7919 assert!(
7920 model_rs.contains("let o_weight = next_q4_buffer"),
7921 "Q4_0 model should use next_q4_buffer for O weight"
7922 );
7923 assert!(
7924 model_rs.contains("let down_weight = next_q4_buffer"),
7925 "Q4_0 model should use next_q4_buffer for down weight"
7926 );
7927 }
7928
7929 #[test]
7930 fn attention_flash_batch_kernel_exists() {
7931 let config = minimal_config();
7935 let model_rs = generate_model_rs(&config).unwrap();
7936 let shaders = generate_metal_shaders(&config);
7937
7938 assert!(
7939 shaders.contains("kernel void attention_flash_batch"),
7940 "shaders.metal must still contain the attention_flash_batch kernel"
7941 );
7942 assert!(
7943 shaders.contains("FLASH_K_TILE"),
7944 "flash kernel must tile K/V with a FLASH_K_TILE constant"
7945 );
7946 assert!(
7947 model_rs.contains("attention_flash_batch_pipeline"),
7948 "MetalModel must register the flash attention pipeline"
7949 );
7950 }
7951
7952 #[test]
7953 fn decode_uses_fused_rope_and_kv_copy() {
7954 let config = minimal_config();
7959 let model_rs = generate_model_rs(&config).unwrap();
7960
7961 let forward_start = model_rs.find("pub fn forward(").expect("forward missing");
7966 let forward_end = model_rs[forward_start..]
7967 .find("pub fn forward_prefill(")
7968 .expect("forward_prefill missing");
7969 let forward_body = &model_rs[forward_start..forward_start + forward_end];
7970
7971 assert!(
7972 forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1"),
7973 "single-token forward must use fused rope_qk_batch with M=1"
7974 );
7975 assert!(
7976 forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,"),
7977 "single-token forward must use fused copy_kv_both_batch"
7978 );
7979 assert!(
7980 !forward_body.contains("dispatch_copy_from_offset_f16"),
7981 "decode path should no longer call the per-K/per-V copy"
7982 );
7983 }
7984
7985 #[test]
7986 fn kv_cache_stored_as_f16() {
7987 let config = minimal_config();
7993 let shaders = generate_metal_shaders(&config);
7994 let model_rs = generate_model_rs(&config).unwrap();
7995
7996 for kernel in [
7997 "kernel void attention(",
7998 "kernel void attention_batch(",
7999 "kernel void attention_flash_batch(",
8000 "kernel void attention_mma_flash_batch(",
8001 ] {
8002 let start = shaders
8003 .find(kernel)
8004 .unwrap_or_else(|| panic!("kernel {kernel} missing"));
8005 let sig_end = shaders[start..]
8008 .find("){")
8009 .unwrap_or_else(|| shaders[start..].find(") {").unwrap());
8010 let sig = &shaders[start..start + sig_end];
8011 assert!(
8012 sig.contains("device const half*")
8013 && sig.contains("k_cache")
8014 && sig.contains("v_cache"),
8015 "{kernel} must read k_cache/v_cache as `device const half*`"
8016 );
8017 assert!(
8018 !sig.contains("device const float* k_cache")
8019 && !sig.contains("device const float* k_cache"),
8020 "{kernel} still reads k_cache as float"
8021 );
8022 }
8023
8024 assert!(
8025 shaders.contains("kernel void copy_f32_to_f16_offset"),
8026 "f32->f16 copy kernel must be present for single-token decode KV writes"
8027 );
8028 assert!(
8029 model_rs.contains("dispatch_copy_from_offset_f16"),
8030 "single-token decode must dispatch the f32->f16 KV copy"
8031 );
8032 }
8033
8034 #[test]
8035 fn attention_mma_flash_batch_kernel_wired() {
8036 let config = minimal_config();
8039 let model_rs = generate_model_rs(&config).unwrap();
8040 let shaders = generate_metal_shaders(&config);
8041
8042 assert!(
8043 shaders.contains("kernel void attention_mma_flash_batch"),
8044 "shaders.metal must contain the MMA flash kernel"
8045 );
8046 assert!(
8047 shaders.contains("FLASH_MMA_Q_BLOCK"),
8048 "MMA flash kernel must define Q_BLOCK tiling constant"
8049 );
8050 assert!(
8051 shaders.contains("simdgroup_multiply_accumulate"),
8052 "MMA flash kernel must use hardware MMA"
8053 );
8054 assert!(
8055 model_rs.contains("attention_mma_flash_batch_pipeline"),
8056 "MetalModel must register the MMA flash pipeline"
8057 );
8058 assert!(
8059 model_rs.contains("mma_opt_out"),
8060 "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
8061 );
8062 assert!(
8063 model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
8064 "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
8065 );
8066 }
8067
8068 #[test]
8069 fn forward_prefill_batch_chunks_by_max_batch_size() {
8070 let config = minimal_config();
8075 let model_rs = generate_model_rs(&config).unwrap();
8076 assert!(
8077 model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
8078 "forward_prefill_batch must chunk long prompts"
8079 );
8080 assert!(
8081 !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
8082 "the old truncation path must be gone"
8083 );
8084 }
8085
8086 #[test]
8087 fn qwen2_qkv_bias_wired_through_metal_codegen() {
8088 let config = ModelConfig {
8092 architecture: Architecture::Qwen2,
8093 qkv_bias: true,
8094 ..minimal_config()
8095 };
8096 let model_rs = generate_model_rs(&config).unwrap();
8097
8098 assert!(
8099 model_rs.contains("qkv_bias: Buffer"),
8100 "Qwen2 LayerBuffers must declare qkv_bias field"
8101 );
8102 assert!(
8103 model_rs.contains("let qkv_bias = next_f32_buffer"),
8104 "Qwen2 layer init must load the bias from the weight blob"
8105 );
8106 assert!(
8107 model_rs.contains("add_bias_batch_pipeline"),
8108 "Qwen2 model struct must include the add_bias_batch_pipeline"
8109 );
8110 assert!(
8111 model_rs.contains("fn dispatch_add_bias_batch"),
8112 "Qwen2 codegen must emit dispatch_add_bias_batch helper"
8113 );
8114 assert!(
8115 model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
8116 "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
8117 );
8118 assert!(
8119 model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
8120 "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
8121 );
8122
8123 let shaders = generate_metal_shaders(&config);
8125 assert!(
8126 shaders.contains("kernel void add_bias_batch"),
8127 "shaders.metal must contain the add_bias_batch kernel"
8128 );
8129 }
8130
8131 #[test]
8132 fn llama_does_not_emit_qkv_bias_machinery() {
8133 let config = minimal_config();
8136 assert!(!config.qkv_bias);
8137 let model_rs = generate_model_rs(&config).unwrap();
8138 assert!(
8139 !model_rs.contains("qkv_bias: Buffer"),
8140 "Llama must not have qkv_bias field"
8141 );
8142 assert!(
8143 !model_rs.contains("add_bias_batch_pipeline"),
8144 "Llama must not pull in add_bias_batch_pipeline"
8145 );
8146 assert!(
8147 !model_rs.contains("dispatch_add_bias_batch"),
8148 "Llama must not call dispatch_add_bias_batch"
8149 );
8150 }
8151
8152 #[test]
8153 fn q4_dispatch_helper_takes_compute_encoder_ref() {
8154 let config = minimal_q4_config();
8155 let model_rs = generate_model_rs(&config).unwrap();
8156
8157 assert!(
8158 model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8159 "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8160 );
8161 }
8162
8163 #[test]
8164 fn f32_model_does_not_use_q4_dispatch() {
8165 let config = minimal_config();
8166 let model_rs = generate_model_rs(&config).unwrap();
8167
8168 let forward_start = model_rs
8169 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8170 .unwrap();
8171 let forward_body = &model_rs[forward_start..];
8172 let forward_end = forward_body
8173 .find("\n fn dispatch_")
8174 .unwrap_or(forward_body.len());
8175 let forward_code = &forward_body[..forward_end];
8176
8177 assert!(
8178 !forward_code.contains("dispatch_matmul_q4"),
8179 "f32 model forward should not use dispatch_matmul_q4"
8180 );
8181 }
8182
8183 #[test]
8184 fn q4_model_lm_head_uses_q4_buffer() {
8185 let config = minimal_q4_config();
8186 let model_rs = generate_model_rs(&config).unwrap();
8187
8188 assert!(
8189 model_rs.contains("let lm_head_buf = next_q4_buffer"),
8190 "Q4_0 model should use next_q4_buffer for lm_head"
8191 );
8192 }
8193
8194 #[test]
8195 fn vec_tile_size_matches_model_dimensions() {
8196 let small = minimal_config();
8198 let shaders_small = generate_metal_shaders(&small);
8199 assert!(
8200 shaders_small.contains("vec_tile[128]"),
8201 "vec_tile should be sized to max(hidden, intermediate) = 128"
8202 );
8203
8204 let mut large = minimal_config();
8206 large.hidden_size = 2048;
8207 large.intermediate_size = 8192;
8208 let shaders_large = generate_metal_shaders(&large);
8209 assert!(
8210 shaders_large.contains("vec_tile[8192]"),
8211 "vec_tile should be 8192 for models with intermediate=8192"
8212 );
8213 assert!(
8214 !shaders_large.contains("vec_tile[4096]"),
8215 "vec_tile should NOT be hardcoded to 4096"
8216 );
8217 }
8218
8219 #[test]
8220 fn generated_cargo_toml_has_server_deps() {
8221 let toml = generate_cargo_toml("my-model");
8222 assert!(
8223 toml.contains("tiny_http"),
8224 "Cargo.toml should depend on tiny_http"
8225 );
8226 assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8227 assert!(
8228 toml.contains("serde_json"),
8229 "Cargo.toml should depend on serde_json"
8230 );
8231 }
8232
8233 #[test]
8234 fn generated_main_rs_has_serve_mode() {
8235 let config = minimal_config();
8236 let main_rs = generate_main_rs("test-model", &config).unwrap();
8237
8238 assert!(
8239 main_rs.contains("--serve"),
8240 "main.rs should parse --serve flag"
8241 );
8242 assert!(
8243 main_rs.contains("--port"),
8244 "main.rs should parse --port flag"
8245 );
8246 assert!(
8247 main_rs.contains("fn serve("),
8248 "main.rs should define serve function"
8249 );
8250 assert!(
8251 main_rs.contains("tiny_http::Server::http"),
8252 "main.rs should create tiny_http server"
8253 );
8254 }
8255
8256 #[test]
8257 fn generated_main_rs_has_chat_completions_endpoint() {
8258 let config = minimal_config();
8259 let main_rs = generate_main_rs("test-model", &config).unwrap();
8260
8261 assert!(
8262 main_rs.contains("/v1/chat/completions"),
8263 "main.rs should handle /v1/chat/completions endpoint"
8264 );
8265 assert!(
8266 main_rs.contains("/v1/models"),
8267 "main.rs should handle /v1/models endpoint"
8268 );
8269 assert!(
8270 main_rs.contains("/health"),
8271 "main.rs should handle /health endpoint"
8272 );
8273 }
8274
8275 #[test]
8276 fn generated_main_rs_has_sse_streaming() {
8277 let config = minimal_config();
8278 let main_rs = generate_main_rs("test-model", &config).unwrap();
8279
8280 assert!(
8281 main_rs.contains("text/event-stream"),
8282 "main.rs should set SSE content type for streaming"
8283 );
8284 assert!(
8285 main_rs.contains("chat.completion.chunk"),
8286 "main.rs should emit SSE chunks"
8287 );
8288 assert!(
8289 main_rs.contains("[DONE]"),
8290 "main.rs should emit [DONE] sentinel"
8291 );
8292 }
8293
8294 #[test]
8295 fn generated_main_rs_has_chat_message_formatting() {
8296 let config = minimal_config();
8297 let main_rs = generate_main_rs("test-model", &config).unwrap();
8298
8299 assert!(
8300 main_rs.contains("fn format_chat_messages"),
8301 "main.rs should define format_chat_messages function"
8302 );
8303 assert!(
8304 main_rs.contains("<|im_start|>"),
8305 "main.rs should use ChatML format"
8306 );
8307 assert!(
8308 main_rs.contains("<|im_end|>"),
8309 "main.rs should use ChatML format"
8310 );
8311 }
8312
8313 #[test]
8314 fn generated_main_rs_has_request_types() {
8315 let config = minimal_config();
8316 let main_rs = generate_main_rs("test-model", &config).unwrap();
8317
8318 assert!(
8319 main_rs.contains("struct ChatRequest"),
8320 "main.rs should define ChatRequest struct"
8321 );
8322 assert!(
8323 main_rs.contains("struct ChatMessage"),
8324 "main.rs should define ChatMessage struct"
8325 );
8326 assert!(
8327 main_rs.contains("Deserialize"),
8328 "main.rs should derive Deserialize for request types"
8329 );
8330 }
8331
8332 #[test]
8333 fn generated_model_has_reset_method() {
8334 let config = minimal_config();
8335 let model_rs = generate_model_rs(&config).unwrap();
8336
8337 assert!(
8338 model_rs.contains("pub fn reset(&mut self)"),
8339 "model.rs should have a reset() method for multi-request serving"
8340 );
8341 assert!(
8342 model_rs.contains("self.pos = 0"),
8343 "reset() should reset position to 0"
8344 );
8345 }
8346
8347 #[test]
8348 fn generated_main_rs_cli_mode_still_works() {
8349 let config = minimal_config();
8350 let main_rs = generate_main_rs("test-model", &config).unwrap();
8351
8352 assert!(
8354 main_rs.contains("fn cli_mode("),
8355 "main.rs should define cli_mode function"
8356 );
8357 assert!(
8358 main_rs.contains("model.forward"),
8359 "main.rs should call model.forward"
8360 );
8361 assert!(
8362 main_rs.contains("model.forward_prefill"),
8363 "main.rs should call model.forward_prefill"
8364 );
8365 }
8366
8367 #[test]
8370 fn generated_shaders_contain_batch_kernels() {
8371 let shaders = generate_metal_shaders(&minimal_config());
8372
8373 assert!(
8374 shaders.contains("kernel void matmul_vec_batch"),
8375 "shaders should contain matmul_vec_batch kernel"
8376 );
8377 assert!(
8378 shaders.contains("kernel void matmul_vec_q8_batch"),
8379 "shaders should contain matmul_vec_q8_batch kernel"
8380 );
8381 assert!(
8382 shaders.contains("kernel void matmul_q8_gemm_batch"),
8383 "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8384 );
8385 assert!(
8386 shaders.contains("kernel void matmul_vec_q4_batch"),
8387 "shaders should contain matmul_vec_q4_batch kernel"
8388 );
8389 assert!(
8390 shaders.contains("kernel void rms_norm_batch"),
8391 "shaders should contain rms_norm_batch kernel"
8392 );
8393 assert!(
8394 shaders.contains("kernel void silu_mul_fused_batch"),
8395 "shaders should contain silu_mul_fused_batch kernel"
8396 );
8397 assert!(
8398 shaders.contains("kernel void add_inplace_batch"),
8399 "shaders should contain add_inplace_batch kernel"
8400 );
8401 assert!(
8402 shaders.contains("kernel void copy_embedding_batch"),
8403 "shaders should contain copy_embedding_batch kernel"
8404 );
8405 }
8406
8407 #[test]
8408 fn generated_model_has_batch_pipelines() {
8409 let config = minimal_config();
8410 let model_rs = generate_model_rs(&config).unwrap();
8411
8412 for pipeline in &[
8413 "matmul_batch_pipeline",
8414 "matmul_q8_batch_pipeline",
8415 "matmul_q8_gemm_batch_pipeline",
8416 "matmul_q4_batch_pipeline",
8417 "rms_norm_batch_pipeline",
8418 "rope_batch_pipeline",
8419 "silu_mul_fused_batch_pipeline",
8420 "add_inplace_batch_pipeline",
8421 "copy_embedding_batch_pipeline",
8422 "attention_batch_pipeline",
8423 "copy_kv_batch_pipeline",
8424 ] {
8425 assert!(
8426 model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8427 "MetalModel should have {pipeline} field"
8428 );
8429 }
8430 }
8431
8432 #[test]
8433 fn generated_model_has_batch_buffers() {
8434 let config = minimal_config();
8435 let model_rs = generate_model_rs(&config).unwrap();
8436
8437 for buf in &[
8438 "batch_hidden_buf",
8439 "batch_residual_buf",
8440 "batch_qkv_buf",
8441 "batch_attn_out_buf",
8442 "batch_attn_proj_buf",
8443 "batch_gate_up_buf",
8444 "batch_ffn_hidden_buf",
8445 "batch_ffn_out_buf",
8446 "batch_tokens_buf",
8447 "batch_positions_buf",
8448 ] {
8449 assert!(
8450 model_rs.contains(&format!("{buf}: Buffer")),
8451 "MetalModel should have {buf} field"
8452 );
8453 }
8454 }
8455
8456 #[test]
8457 fn generated_model_has_forward_prefill_batch() {
8458 let config = minimal_config();
8459 let model_rs = generate_model_rs(&config).unwrap();
8460
8461 assert!(
8462 model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8463 "MetalModel should have forward_prefill_batch method"
8464 );
8465
8466 assert!(
8468 model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8469 "forward_prefill should delegate to forward_prefill_batch"
8470 );
8471 }
8472
8473 #[test]
8474 fn generated_model_has_max_batch_size_constant() {
8475 let config = minimal_config();
8476 let model_rs = generate_model_rs(&config).unwrap();
8477
8478 assert!(
8479 model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8480 "model.rs should define MAX_BATCH_SIZE constant"
8481 );
8482 }
8483
8484 #[test]
8485 fn forward_prefill_batch_uses_batch_dispatch() {
8486 let config = minimal_config();
8487 let model_rs = generate_model_rs(&config).unwrap();
8488
8489 let batch_start = model_rs
8490 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8491 .unwrap();
8492 let batch_body = &model_rs[batch_start..];
8493 let batch_end = batch_body
8494 .find("\n pub fn reset")
8495 .unwrap_or(batch_body.len());
8496 let batch_code = &batch_body[..batch_end];
8497
8498 assert!(
8500 batch_code.contains("dispatch_rms_norm_batch"),
8501 "forward_prefill_batch should use dispatch_rms_norm_batch"
8502 );
8503 assert!(
8504 batch_code.contains("dispatch_copy_embedding_batch"),
8505 "forward_prefill_batch should use dispatch_copy_embedding_batch"
8506 );
8507 assert!(
8508 batch_code.contains("dispatch_silu_mul_fused_batch"),
8509 "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8510 );
8511 assert!(
8513 batch_code.contains("dispatch_attention_batch"),
8514 "forward_prefill_batch should use dispatch_attention_batch"
8515 );
8516 assert!(
8518 batch_code.contains("dispatch_copy_kv_both_batch"),
8519 "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8520 );
8521 assert!(
8523 batch_code.contains("dispatch_rope_qk_batch"),
8524 "forward_prefill_batch should use dispatch_rope_qk_batch"
8525 );
8526 }
8527
8528 #[test]
8529 fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8530 let config = minimal_q8_config();
8531 let model_rs = generate_model_rs(&config).unwrap();
8532
8533 let batch_start = model_rs
8534 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8535 .unwrap();
8536 let batch_body = &model_rs[batch_start..];
8537 let batch_end = batch_body
8538 .find("\n pub fn reset")
8539 .unwrap_or(batch_body.len());
8540 let batch_code = &batch_body[..batch_end];
8541
8542 assert!(
8543 batch_code.contains("dispatch_matmul_q8_batch"),
8544 "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8545 );
8546 }
8547
8548 #[test]
8549 fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8550 let config = minimal_q4_config();
8551 let model_rs = generate_model_rs(&config).unwrap();
8552
8553 let batch_start = model_rs
8554 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8555 .unwrap();
8556 let batch_body = &model_rs[batch_start..];
8557 let batch_end = batch_body
8558 .find("\n pub fn reset")
8559 .unwrap_or(batch_body.len());
8560 let batch_code = &batch_body[..batch_end];
8561
8562 assert!(
8563 batch_code.contains("dispatch_matmul_q4_batch"),
8564 "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8565 );
8566 }
8567
8568 #[test]
8569 fn generated_main_rs_uses_batched_prefill() {
8570 let config = minimal_config();
8571 let main_rs = generate_main_rs("test-model", &config).unwrap();
8572
8573 assert!(
8574 main_rs.contains("forward_prefill_batch"),
8575 "main.rs should use forward_prefill_batch for prompt tokens"
8576 );
8577 }
8578
8579 #[test]
8580 fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8581 let config = minimal_config();
8582 let model_rs = generate_model_rs(&config).unwrap();
8583
8584 let batch_start = model_rs
8585 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8586 .unwrap();
8587 let batch_body = &model_rs[batch_start..];
8588 let batch_end = batch_body
8589 .find("\n pub fn reset")
8590 .unwrap_or(batch_body.len());
8591 let batch_code = &batch_body[..batch_end];
8592
8593 assert!(
8594 batch_code.contains("dispatch_matmul_batch"),
8595 "f32 forward_prefill_batch should use dispatch_matmul_batch"
8596 );
8597 assert!(
8599 !batch_code.contains("dispatch_matmul_q8_batch"),
8600 "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8601 );
8602 assert!(
8603 !batch_code.contains("dispatch_matmul_q4_batch"),
8604 "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8605 );
8606 }
8607
8608 #[test]
8609 fn forward_uses_cpu_embedding_lookup() {
8610 let config = minimal_config();
8611 let model_rs = generate_model_rs(&config).unwrap();
8612
8613 let forward_start = model_rs
8615 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8616 .unwrap();
8617 let forward_body = &model_rs[forward_start..];
8618 let forward_end = forward_body
8619 .find("\n pub fn forward_profile")
8620 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
8621 .unwrap_or(forward_body.len());
8622 let forward_code = &forward_body[..forward_end];
8623
8624 assert!(
8626 forward_code.contains("embed_buf.contents()"),
8627 "forward() should access embed_buf via CPU unified memory for embedding lookup"
8628 );
8629 assert!(
8630 forward_code.contains("copy_nonoverlapping"),
8631 "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8632 );
8633 assert!(
8635 !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8636 "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8637 );
8638 }
8639
8640 #[test]
8641 fn forward_profile_method_exists() {
8642 let config = minimal_config();
8643 let model_rs = generate_model_rs(&config).unwrap();
8644
8645 assert!(
8646 model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8647 "MetalModel should have forward_profile() method"
8648 );
8649 assert!(
8651 model_rs.contains("[profile]"),
8652 "forward_profile() should print timing with [profile] prefix"
8653 );
8654 assert!(
8655 model_rs.contains("d_embed"),
8656 "forward_profile() should measure embedding time"
8657 );
8658 assert!(
8659 model_rs.contains("d_layers"),
8660 "forward_profile() should measure layer time"
8661 );
8662 assert!(
8663 model_rs.contains("d_logits"),
8664 "forward_profile() should measure logits time"
8665 );
8666 }
8667
8668 #[test]
8669 fn generated_cli_has_profile_flag() {
8670 let config = minimal_config();
8671 let main_rs = generate_main_rs("test-model", &config).unwrap();
8672
8673 assert!(
8674 main_rs.contains("--profile"),
8675 "CLI should support --profile flag"
8676 );
8677 assert!(
8678 main_rs.contains("forward_profile"),
8679 "CLI should call forward_profile when --profile is set"
8680 );
8681 }
8682
8683 #[test]
8684 fn generated_cli_has_thermal_yield() {
8685 let config = minimal_config();
8686 let main_rs = generate_main_rs("test-model", &config).unwrap();
8687
8688 assert!(
8689 main_rs.contains("yield_now()"),
8690 "CLI generation loop should include thread::yield_now() for thermal management"
8691 );
8692 }
8693
8694 #[test]
8697 fn generated_forward_handles_single_token_prompt() {
8698 let config = minimal_config();
8702 let model_rs = generate_model_rs(&config).unwrap();
8703
8704 let forward_start = model_rs
8706 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8707 .expect("forward() must exist");
8708 let forward_body = &model_rs[forward_start..forward_start + 400];
8709
8710 assert!(
8712 !forward_body.contains("assert!(self.pos > 0"),
8713 "forward() must accept pos=0 (first token with no prefill)"
8714 );
8715
8716 assert!(
8718 model_rs.contains("self.pos"),
8719 "forward() should use self.pos to track sequence position"
8720 );
8721 }
8722
8723 #[test]
8724 fn generated_reset_clears_kv_cache_position() {
8725 let config = minimal_config();
8728 let model_rs = generate_model_rs(&config).unwrap();
8729
8730 let reset_start = model_rs
8731 .find("pub fn reset(&mut self)")
8732 .expect("reset() must exist");
8733 let reset_body = &model_rs[reset_start..reset_start + 200];
8734
8735 assert!(
8737 reset_body.contains("self.pos = 0"),
8738 "reset() must set self.pos = 0"
8739 );
8740
8741 assert!(
8743 reset_body.contains("self.prev_cmd = None"),
8744 "reset() should clear prev_cmd for clean command buffer state"
8745 );
8746 }
8747
8748 #[test]
8749 fn generated_serve_handles_empty_messages_gracefully() {
8750 let config = minimal_config();
8753 let main_rs = generate_main_rs("test-model", &config).unwrap();
8754
8755 let format_fn_start = main_rs
8757 .find("fn format_chat_messages")
8758 .expect("format_chat_messages must exist");
8759 let format_fn_body =
8760 &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8761
8762 assert!(
8764 format_fn_body.contains("for msg in messages"),
8765 "format_chat_messages should iterate over the messages slice"
8766 );
8767 assert!(
8769 format_fn_body.contains("<|im_start|>assistant"),
8770 "format_chat_messages should always append assistant prompt header"
8771 );
8772
8773 let serve_fn_start = main_rs
8775 .find("fn serve(")
8776 .expect("serve function must exist");
8777 let serve_fn_body = &main_rs[serve_fn_start..];
8778 assert!(
8779 serve_fn_body.contains("model.reset()"),
8780 "serve function should reset model between requests"
8781 );
8782 }
8783
8784 #[test]
8785 fn generated_model_forward_increments_pos() {
8786 let config = minimal_config();
8789 let model_rs = generate_model_rs(&config).unwrap();
8790
8791 let forward_start = model_rs
8792 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8793 .unwrap();
8794 let forward_body = &model_rs[forward_start..];
8795 let forward_end = forward_body
8796 .find("\n pub fn forward_profile")
8797 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
8798 .or_else(|| forward_body.find("\n fn dispatch_"))
8799 .unwrap_or(forward_body.len());
8800 let forward_code = &forward_body[..forward_end];
8801
8802 assert!(
8803 forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8804 "forward() must increment self.pos after processing a token"
8805 );
8806 }
8807}