1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16 #[error("graph has no model config")]
18 MissingConfig,
19
20 #[error("I/O error: {0}")]
22 Io(#[from] std::io::Error),
23
24 #[error("format error: {0}")]
26 Fmt(#[from] std::fmt::Error),
27}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Turns an arbitrary model name into a string usable as a Cargo package name.
///
/// The name is lower-cased, every character that is neither alphanumeric nor
/// `-` is replaced with `-`, and leading/trailing `-` characters are removed.
///
/// If nothing is left after that (empty input, or input consisting only of
/// separators such as `"!!!"`), the fallback `"model"` is returned, because an
/// empty package name would make the generated `Cargo.toml` invalid.
///
/// NOTE(review): Cargo may also reject package names that begin with a digit;
/// that case is not handled here — confirm whether a prefix is wanted.
fn sanitize_name(name: &str) -> String {
    // Returned when sanitizing leaves no usable characters.
    const FALLBACK_NAME: &str = "model";

    let lowered = name.to_lowercase();
    let dashed = lowered.replace(|c: char| !c.is_alphanumeric() && c != '-', "-");
    let trimmed = dashed.trim_matches('-');

    if trimmed.is_empty() {
        FALLBACK_NAME.to_string()
    } else {
        trimmed.to_string()
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121 let attn_scores_size = config.max_seq_len.min(4096);
126 r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154 device const float* matrix [[buffer(0)]],
155 device const float* vector [[buffer(1)]],
156 device float* output [[buffer(2)]],
157 constant uint& rows [[buffer(3)]],
158 constant uint& cols [[buffer(4)]],
159 uint tgid [[threadgroup_position_in_grid]],
160 uint tid [[thread_index_in_threadgroup]],
161 uint simd_lane [[thread_index_in_simdgroup]],
162 uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164 // Cooperatively load vector into threadgroup shared memory
165 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166 for (uint i = tid; i < cols; i += 256) {
167 vec_tile[i] = vector[i];
168 }
169 threadgroup_barrier(mem_flags::mem_threadgroup);
170
171 // Each simdgroup handles 8 consecutive rows
172 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173 if (row_base >= rows) return;
174
175 uint base0 = row_base * cols;
176 uint base1 = (row_base + 1) * cols;
177 uint base2 = (row_base + 2) * cols;
178 uint base3 = (row_base + 3) * cols;
179 uint base4 = (row_base + 4) * cols;
180 uint base5 = (row_base + 5) * cols;
181 uint base6 = (row_base + 6) * cols;
182 uint base7 = (row_base + 7) * cols;
183
184 // float4 vectorized accumulation across 8 rows
185 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
186 float4 sum4_0 = float4(0.0f);
187 float4 sum4_1 = float4(0.0f);
188 float4 sum4_2 = float4(0.0f);
189 float4 sum4_3 = float4(0.0f);
190 float4 sum4_4 = float4(0.0f);
191 float4 sum4_5 = float4(0.0f);
192 float4 sum4_6 = float4(0.0f);
193 float4 sum4_7 = float4(0.0f);
194
195 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205 }
206
207 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216 // Handle remaining elements (cols not divisible by 128)
217 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218 float vv = vec_tile[j];
219 sum0 += matrix[base0 + j] * vv;
220 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227 }
228
229 // Simdgroup hardware warp-level reduction
230 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235 // Only first lane writes the results
236 if (simd_lane == 0) {
237 if (row_base < rows) output[row_base] = sum0;
238 if (row_base + 1 < rows) output[row_base + 1] = sum1;
239 if (row_base + 2 < rows) output[row_base + 2] = sum2;
240 if (row_base + 3 < rows) output[row_base + 3] = sum3;
241 if (row_base + 4 < rows) output[row_base + 4] = sum4;
242 if (row_base + 5 < rows) output[row_base + 5] = sum5;
243 if (row_base + 6 < rows) output[row_base + 6] = sum6;
244 if (row_base + 7 < rows) output[row_base + 7] = sum7;
245 }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253 device const float* input [[buffer(0)]],
254 device const float* weight [[buffer(1)]],
255 device float* output [[buffer(2)]],
256 constant uint& n [[buffer(3)]],
257 constant float& eps [[buffer(4)]],
258 uint tid [[thread_index_in_threadgroup]])
259{
260 // Each thread accumulates partial sum-of-squares
261 float sum_sq = 0.0f;
262 for (uint i = tid; i < n; i += 256) {
263 float v = input[i];
264 sum_sq += v * v;
265 }
266
267 // Simdgroup-level reduction (hardware warp sum)
268 sum_sq = simd_sum(sum_sq);
269
270 // Cross-simdgroup reduction via shared memory
271 threadgroup float shared[8];
272 uint simd_id = tid / 32;
273 uint simd_lane = tid % 32;
274 if (simd_lane == 0) {
275 shared[simd_id] = sum_sq;
276 }
277 threadgroup_barrier(mem_flags::mem_threadgroup);
278
279 // First thread computes final inverse RMS
280 if (tid == 0) {
281 float total = 0.0f;
282 for (uint i = 0; i < 8; i++) {
283 total += shared[i];
284 }
285 shared[0] = fast::rsqrt(total / float(n) + eps);
286 }
287 threadgroup_barrier(mem_flags::mem_threadgroup);
288
289 float inv_rms = shared[0];
290
291 // Normalize
292 for (uint i = tid; i < n; i += 256) {
293 output[i] = input[i] * inv_rms * weight[i];
294 }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301 device float* data [[buffer(0)]],
302 constant uint& num_heads [[buffer(1)]],
303 constant uint& head_dim [[buffer(2)]],
304 constant uint& pos [[buffer(3)]],
305 constant float& theta [[buffer(4)]],
306 uint id [[thread_position_in_grid]])
307{
308 uint half_dim = head_dim / 2;
309 uint total_pairs = num_heads * half_dim;
310 if (id >= total_pairs) return;
311
312 uint h = id / half_dim;
313 uint i = id % half_dim;
314 uint off = h * head_dim;
315
316 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317 float angle = float(pos) * freq;
318 float c = cos(angle);
319 float s = sin(angle);
320
321 float x0 = data[off + 2 * i];
322 float x1 = data[off + 2 * i + 1];
323 data[off + 2 * i] = x0 * c - x1 * s;
324 data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331 device float* data [[buffer(0)]],
332 constant uint& n [[buffer(1)]],
333 uint tid [[thread_index_in_threadgroup]],
334 uint tg_size [[threads_per_threadgroup]])
335{
336 threadgroup float shared_val[256];
337
338 // Pass 1: find max
339 float local_max = -INFINITY;
340 for (uint i = tid; i < n; i += tg_size) {
341 local_max = max(local_max, data[i]);
342 }
343 shared_val[tid] = local_max;
344 threadgroup_barrier(mem_flags::mem_threadgroup);
345
346 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347 if (tid < stride) {
348 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349 }
350 threadgroup_barrier(mem_flags::mem_threadgroup);
351 }
352 float max_val = shared_val[0];
353 threadgroup_barrier(mem_flags::mem_threadgroup);
354
355 // Pass 2: exp and sum
356 float local_sum = 0.0f;
357 for (uint i = tid; i < n; i += tg_size) {
358 float e = fast::exp(data[i] - max_val);
359 data[i] = e;
360 local_sum += e;
361 }
362 shared_val[tid] = local_sum;
363 threadgroup_barrier(mem_flags::mem_threadgroup);
364
365 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366 if (tid < stride) {
367 shared_val[tid] += shared_val[tid + stride];
368 }
369 threadgroup_barrier(mem_flags::mem_threadgroup);
370 }
371 float inv_sum = 1.0f / shared_val[0];
372 threadgroup_barrier(mem_flags::mem_threadgroup);
373
374 // Pass 3: normalize
375 for (uint i = tid; i < n; i += tg_size) {
376 data[i] *= inv_sum;
377 }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384 device const float* gate [[buffer(0)]],
385 device const float* up [[buffer(1)]],
386 device float* output [[buffer(2)]],
387 constant uint& n [[buffer(3)]],
388 uint id [[thread_position_in_grid]])
389{
390 if (id >= n) return;
391 float g = gate[id];
392 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397// gate = gate_up[0..n], up = gate_up[n..2*n]
398// output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400 device const float* gate_up [[buffer(0)]],
401 device float* output [[buffer(1)]],
402 constant uint& n [[buffer(2)]],
403 uint id [[thread_position_in_grid]])
404{
405 if (id >= n) return;
406 float g = gate_up[id];
407 float u = gate_up[n + id];
408 output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414 device const float* a [[buffer(0)]],
415 device const float* b [[buffer(1)]],
416 device float* output [[buffer(2)]],
417 constant uint& n [[buffer(3)]],
418 uint id [[thread_position_in_grid]])
419{
420 if (id >= n) return;
421 output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428 device const float* src [[buffer(0)]],
429 device float* dst [[buffer(1)]],
430 constant uint& count [[buffer(2)]],
431 uint id [[thread_position_in_grid]])
432{
433 if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440 device const float* src [[buffer(0)]],
441 device float* dst [[buffer(1)]],
442 constant uint& src_offset [[buffer(2)]], // in floats
443 constant uint& count [[buffer(3)]],
444 uint id [[thread_position_in_grid]])
445{
446 if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
450// Copy f32 elements from src into a half-typed dst, converting to half on
451// write. Used by the single-token decode path to append a new K/V vector
452// to the f16 KV cache. Byte offsets into src/dst are supplied via the
453// Metal buffer binding offsets — no in-kernel offsets needed.
454kernel void copy_f32_to_f16_offset(
455 device const float* src [[buffer(0)]],
456 device half* dst [[buffer(1)]],
457 constant uint& count [[buffer(2)]],
458 uint id [[thread_position_in_grid]])
459{
460 if (id < count) dst[id] = half(src[id]);
461}
462
463// ── add_inplace ─────────────────────────────────────────────────────────
464// In-place residual connection: a[i] += b[i]
465// Avoids a separate blit copy for residual add, reducing encoder overhead.
466kernel void add_inplace(
467 device float* a [[buffer(0)]],
468 device const float* b [[buffer(1)]],
469 constant uint& n [[buffer(2)]],
470 uint id [[thread_position_in_grid]])
471{
472 if (id >= n) return;
473 a[id] += b[id];
474}
475
476// ── matmul_vec_q8 ─────────────────────────────────────────────────────
477// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
478// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
479// Operates directly on quantized weights to halve memory bandwidth vs f32,
480// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
481//
482// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
483// because int8->float conversion doubles register demand. Fully unrolled
484// inner loop with float4 vector loads from shared memory eliminates loop
485// overhead and enables better instruction scheduling.
486// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
487constant constexpr uint Q8_ROWS_PER_SG = 4;
488constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
489
490// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
491// doubles ALU work, so the same register budget applies).
492constant constexpr uint Q4_ROWS_PER_SG = 4;
493constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
494
495kernel void matmul_vec_q8(
496 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
497 device const float* vector [[buffer(1)]], // f32 input
498 device float* output [[buffer(2)]],
499 constant uint& rows [[buffer(3)]],
500 constant uint& cols [[buffer(4)]], // number of elements per row
501 uint tgid [[threadgroup_position_in_grid]],
502 uint tid [[thread_index_in_threadgroup]],
503 uint simd_lane [[thread_index_in_simdgroup]],
504 uint simd_id [[simdgroup_index_in_threadgroup]])
505{
506 // Load vector into shared memory
507 threadgroup float vec_tile[VEC_TILE_SIZE];
508 for (uint i = tid; i < cols; i += 256) {
509 vec_tile[i] = vector[i];
510 }
511 threadgroup_barrier(mem_flags::mem_threadgroup);
512
513 // Each simdgroup handles 4 consecutive rows (lower register pressure)
514 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
515 if (row_base >= rows) return;
516
517 // Q8_0: each block is 34 bytes for 32 elements
518 uint blocks_per_row = cols / 32;
519 uint row_bytes = blocks_per_row * 34;
520
521 // Pointers to each row's Q8_0 data
522 device const uchar* r0 = matrix + row_base * row_bytes;
523 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
524 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
525 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
526
527 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
528
529 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
530 uint bb = blk * 34;
531 uint vb = blk * 32;
532
533 // Prefetch all 4 scales
534 float sc0 = float(*(device const half*)(r0 + bb));
535 float sc1 = float(*(device const half*)(r1 + bb));
536 float sc2 = float(*(device const half*)(r2 + bb));
537 float sc3 = float(*(device const half*)(r3 + bb));
538
539 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
540 // Q8_0 block layout where the int8 data starts at offset +2 from a
541 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
542 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
543 // reduction in memory transactions. Metal's char16/packed_char16 are
544 // reserved types and packed_*int4 require >=4-byte alignment which
545 // this layout does not provide, so packed_short4 is the widest valid
546 // vectorized load.
547 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
548 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
549 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
550 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
551
552 // Load all 8 float4 vector values for this 32-element block from shared memory
553 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
554 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
555 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
556 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
557 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
558 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
559 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
560 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
561
562 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
563 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
564 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
565 short4 _s = short4(SHORT4); \
566 char2 _a = as_type<char2>(_s.x); \
567 char2 _b = as_type<char2>(_s.y); \
568 char2 _c = as_type<char2>(_s.z); \
569 char2 _d = as_type<char2>(_s.w); \
570 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
571 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
572 }
573
574 float4 f0, f1;
575 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
576
577 // Row 0: 4 short4 loads cover 32 int8 weights
578 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
579 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
580 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
581 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
582
583 // Row 1
584 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
585 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
586 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
587 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
588
589 // Row 2
590 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
591 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
592 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
593 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
594
595 // Row 3
596 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
597 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
598 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
599 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
600
601 #undef Q8_UNPACK8
602
603 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
604 }
605
606 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
607 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
608
609 if (simd_lane == 0) {
610 if (row_base < rows) output[row_base] = sum0;
611 if (row_base + 1 < rows) output[row_base + 1] = sum1;
612 if (row_base + 2 < rows) output[row_base + 2] = sum2;
613 if (row_base + 3 < rows) output[row_base + 3] = sum3;
614 }
615}
616
617// ── matmul_vec_q4 ─────────────────────────────────────────────────────
618// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
619// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
620// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
621// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
622//
623// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
624// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
625kernel void matmul_vec_q4(
626 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
627 device const float* vector [[buffer(1)]], // f32 input
628 device float* output [[buffer(2)]],
629 constant uint& rows [[buffer(3)]],
630 constant uint& cols [[buffer(4)]], // number of elements per row
631 uint tgid [[threadgroup_position_in_grid]],
632 uint tid [[thread_index_in_threadgroup]],
633 uint simd_lane [[thread_index_in_simdgroup]],
634 uint simd_id [[simdgroup_index_in_threadgroup]])
635{
636 // Load vector into shared memory
637 threadgroup float vec_tile[VEC_TILE_SIZE];
638 for (uint i = tid; i < cols; i += 256) {
639 vec_tile[i] = vector[i];
640 }
641 threadgroup_barrier(mem_flags::mem_threadgroup);
642
643 // Each simdgroup handles 4 consecutive rows
644 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
645 if (row_base >= rows) return;
646
647 // Q4_0: each block is 18 bytes for 32 elements
648 uint blocks_per_row = cols / 32;
649 uint row_bytes = blocks_per_row * 18;
650
651 // Pointers to each row's Q4_0 data
652 device const uchar* r0 = matrix + row_base * row_bytes;
653 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
654 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
655 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
656
657 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
658
659 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
660 uint bb = blk * 18;
661 uint vb = blk * 32;
662
663 // Prefetch all 4 scales
664 float sc0 = float(*(device const half*)(r0 + bb));
665 float sc1 = float(*(device const half*)(r1 + bb));
666 float sc2 = float(*(device const half*)(r2 + bb));
667 float sc3 = float(*(device const half*)(r3 + bb));
668
669 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
670 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
671 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
672 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
673 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
674
675 // Load 8 float4 vector values for 32 elements from shared memory
676 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
677 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
678 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
679 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
680 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
681 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
682 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
683 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
684 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
685
686 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
687 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
688 float bd0=0, bd1=0, bd2=0, bd3=0;
689 uchar4 b;
690
691 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
692 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
693 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
694 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
695 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
696 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
697 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
698 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
699 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
700 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
701 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
702 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
703 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
704 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
705 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
706 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
707 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
708
709 // Row 1
710 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
711 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
712 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
713 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
714 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
715 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
716 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
717 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
718 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
719 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
720 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
721 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
722 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
723 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
724 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
725 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
726
727 // Row 2
728 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
729 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
730 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
731 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
732 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
733 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
734 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
735 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
736 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
737 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
738 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
739 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
740 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
741 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
742 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
743 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
744
745 // Row 3
746 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
747 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
748 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
749 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
750 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
751 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
752 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
753 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
754 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
755 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
756 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
757 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
758 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
759 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
760 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
761 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
762
763 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
764 }
765
766 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
767 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
768
769 if (simd_lane == 0) {
770 if (row_base < rows) output[row_base] = sum0;
771 if (row_base + 1 < rows) output[row_base + 1] = sum1;
772 if (row_base + 2 < rows) output[row_base + 2] = sum2;
773 if (row_base + 3 < rows) output[row_base + 3] = sum3;
774 }
775}
776
// ── attention ───────────────────────────────────────────────────────────
// Single-query (decode step) attention; one threadgroup per head.
//
// Pipeline:
//   1. scores[s] = (q . k[s]) / sqrt(head_dim)  - one simdgroup per cached
//      position, lanes split head_dim, reduced with simd_sum.
//   2. softmax over scores                       - simd_max / simd_sum plus a
//      small threadgroup stage.
//   3. output = sum_s scores[s] * v[s]           - one simdgroup per group of
//      four output dims, lanes split the positions.
//
// Dispatch contract (the hard-coded strides 256 / 8 / 32 rely on it):
// 256 threads per threadgroup, i.e. 8 simdgroups of 32 lanes.
//
// Buffers:
//   q:       [num_heads * head_dim]                   current query (float)
//   k_cache: [max_seq_len * num_kv_heads * head_dim]  half
//   v_cache: [max_seq_len * num_kv_heads * head_dim]  half
//   output:  [num_heads * head_dim]                   float
//
// Context bound: the score scratch array holds ATTN_SCORES_SIZE entries. If
// seq_len is larger, only the most recent ATTN_SCORES_SIZE cached positions
// are attended to, instead of writing past the end of threadgroup memory.
// For seq_len <= ATTN_SCORES_SIZE the result is unchanged.
kernel void attention(
    device const float* q            [[buffer(0)]],
    device const half*  k_cache      [[buffer(1)]],
    device const half*  v_cache      [[buffer(2)]],
    device float*       output       [[buffer(3)]],
    constant uint&      seq_len      [[buffer(4)]],
    constant uint&      num_heads    [[buffer(5)]],
    constant uint&      num_kv_heads [[buffer(6)]],
    constant uint&      head_dim     [[buffer(7)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    // The whole threadgroup takes this exit together, so no thread is left
    // waiting at a barrier below.
    if (head >= num_heads) return;

    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off     = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim;   // elements per cached position
    uint kv_base   = kv_head * head_dim;        // this head's slice in a position
    uint head_dim4 = head_dim / 4;

    threadgroup float scores[ATTN_SCORES_SIZE];

    // Never index scores[] beyond its capacity: attend to the window
    // [ctx_start, seq_len) of length ctx. ctx_start is 0 whenever seq_len fits.
    uint ctx       = min(uint(seq_len), uint(ATTN_SCORES_SIZE));
    uint ctx_start = seq_len - ctx;

    // ── Step 1: scaled Q.K^T ────────────────────────────────────────────
    // Each simdgroup owns one position s at a time; its 32 lanes cover
    // head_dim in chunks of 4 (head_dim=128 -> one half4 load per lane,
    // head_dim=64 -> 16 active lanes, head_dim=96 -> 24 active lanes).
    float scale = fast::rsqrt(float(head_dim));
    for (uint s = simd_id; s < ctx; s += 8) {
        uint k_off = (ctx_start + s) * kv_stride + kv_base;
        float acc = 0.0f;
        for (uint d4 = simd_lane; d4 < head_dim4; d4 += 32) {
            uint d = d4 * 4;
            float4 q4 = *(device const float4*)(q + q_off + d);
            float4 k4 = float4(*(device const half4*)(k_cache + k_off + d));
            acc += q4.x * k4.x + q4.y * k4.y + q4.z * k4.z + q4.w * k4.w;
        }
        // Scalar tail for head_dim not divisible by 4 (not hit when head_dim
        // is a multiple of 4, e.g. 64, 96 or 128).
        for (uint d = head_dim4 * 4 + simd_lane; d < head_dim; d += 32) {
            acc += q[q_off + d] * float(k_cache[k_off + d]);
        }
        acc = simd_sum(acc);
        if (simd_lane == 0) {
            scores[s] = acc * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Step 2: softmax over scores[0 .. ctx) ───────────────────────────
    // Maximum: per-thread, then per-simdgroup, then across the 8 simdgroups.
    float local_max = -INFINITY;
    for (uint s = tid; s < ctx; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exponentiate (shifted by the max for numerical stability) and sum.
    float local_sum = 0.0f;
    for (uint s = tid; s < ctx; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0f;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0f / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    for (uint s = tid; s < ctx; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Step 3: output = scores . V ─────────────────────────────────────
    // One simdgroup per d4 chunk; the 32 lanes partition the positions
    // (lane handles s = simd_lane, simd_lane + 32, ...) and combine their
    // partial sums with simd_sum. With 8 simdgroups this takes
    // head_dim4 / 8 outer iterations (2 / 3 / 4 for head_dim 64 / 96 / 128)
    // and keeps all 256 threads busy.
    for (uint d4 = simd_id; d4 < head_dim4; d4 += 8) {
        uint d = d4 * 4;
        uint v_base = kv_base + d;
        float4 partial = float4(0.0f);
        for (uint s = simd_lane; s < ctx; s += 32) {
            half4 v = *(device const half4*)(v_cache + (ctx_start + s) * kv_stride + v_base);
            partial += scores[s] * float4(v);
        }
        // Per-component reduction across the simdgroup.
        partial.x = simd_sum(partial.x);
        partial.y = simd_sum(partial.y);
        partial.z = simd_sum(partial.z);
        partial.w = simd_sum(partial.w);
        if (simd_lane == 0) {
            *(device float4*)(output + q_off + d) = partial;
        }
    }
    // Scalar tail for dims not covered by the float4 path; same
    // simdgroup-per-dim layout.
    for (uint d = head_dim4 * 4 + simd_id; d < head_dim; d += 8) {
        uint v_base = kv_base + d;
        float partial = 0.0f;
        for (uint s = simd_lane; s < ctx; s += 32) {
            partial += scores[s] * float(v_cache[(ctx_start + s) * kv_stride + v_base]);
        }
        partial = simd_sum(partial);
        if (simd_lane == 0) {
            output[q_off + d] = partial;
        }
    }
}
926
927// ── Batched prefill kernels ────────────────────────────────────────────
928// These kernels process M input vectors against the same weight matrix
929// in a single dispatch, converting mat-vec into mat-mat for better GPU
930// utilization during prompt prefill.
931
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization: for every token row x of length n,
//   output = x * rsqrt(mean(x^2) + eps) * weight
// Grid: M threadgroups, one per token. The reduction is written for
// 256 threads per threadgroup (8 simdgroups of 32 lanes).
kernel void rms_norm_batch(
    device const float* input      [[buffer(0)]],   // [M, n]
    device const float* weight     [[buffer(1)]],   // [n]
    device float*       output     [[buffer(2)]],   // [M, n]
    constant uint&      n          [[buffer(3)]],
    constant float&     eps        [[buffer(4)]],
    constant uint&      num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]])
{
    // Uniform exit for surplus threadgroups; no barrier is skipped.
    if (tgid >= num_tokens) return;

    const uint row_off = tgid * n;

    // Stage 1: every thread accumulates squares over a strided slice.
    float sq_acc = 0.0f;
    for (uint k = tid; k < n; k += 256) {
        const float x = input[row_off + k];
        sq_acc += x * x;
    }

    // Stage 2: fold the 32 lanes of each simdgroup together.
    sq_acc = simd_sum(sq_acc);

    // Stage 3: lane 0 of each simdgroup publishes its subtotal, then thread 0
    // combines the 8 subtotals and stores the inverse RMS in slot 0.
    threadgroup float shared[8];
    const uint group = tid / 32;
    const uint lane  = tid % 32;
    if (lane == 0) {
        shared[group] = sq_acc;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float total = 0.0f;
        for (uint g = 0; g < 8; ++g) {
            total += shared[g];
        }
        shared[0] = fast::rsqrt(total / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float inv_rms = shared[0];

    // Stage 4: scale the row and apply the learned per-channel weight.
    for (uint k = tid; k < n; k += 256) {
        output[row_off + k] = input[row_off + k] * inv_rms * weight[k];
    }
}
981
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place rotary position embedding for a batch of tokens, each with its
// own position. data is [M, num_heads * head_dim]; positions is [M].
// One thread rotates one adjacent pair (2i, 2i + 1) of one head of one token
// by the angle positions[token] / theta^(2i / head_dim).
kernel void rope_batch(
    device float*      data       [[buffer(0)]],   // [M, num_heads * head_dim]
    constant uint&     num_heads  [[buffer(1)]],
    constant uint&     head_dim   [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],   // [M] position per token
    constant float&    theta      [[buffer(4)]],
    constant uint&     num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint half_dim        = head_dim / 2;
    const uint pairs_per_token = num_heads * half_dim;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token  = id / pairs_per_token;
    const uint within = id % pairs_per_token;
    const uint head   = within / half_dim;
    const uint pair   = within % half_dim;

    // Indices of the two elements that form this pair.
    const uint even_idx = token * (num_heads * head_dim) + head * head_dim + 2 * pair;
    const uint odd_idx  = even_idx + 1;

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle    = float(positions[token]) * inv_freq;
    const float cos_a    = cos(angle);
    const float sin_a    = sin(angle);

    // Read both values before writing either; the rotation needs the originals.
    const float x0 = data[even_idx];
    const float x1 = data[odd_idx];
    data[even_idx] = x0 * cos_a - x1 * sin_a;
    data[odd_idx]  = x0 * sin_a + x1 * cos_a;
}
1016
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-gate multiply over a batch. Each token row of gate_up holds
// n gate values followed by n up values ([M, 2*n]); the result is
//   output[token, i] = silu(gate[token, i]) * up[token, i]
// with silu(g) = g / (1 + exp(-g)). One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up    [[buffer(0)]],   // [M, 2*n]
    device float*       output     [[buffer(1)]],   // [M, n]
    constant uint&      n          [[buffer(2)]],
    constant uint&      num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col   = id % n;

    // Gate lives in the first half of the token row, up in the second half.
    const uint gate_idx = token * 2 * n + col;
    const float gate = gate_up[gate_idx];
    const float up   = gate_up[gate_idx + n];

    // token * n + col is exactly id, so the flat id doubles as output index.
    const float activated = gate / (1.0f + fast::exp(-gate));
    output[id] = activated * up;
}
1036
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] = a[i] + b[i] for i < total,
// where total is M * n. One thread per element; surplus threads do nothing.
kernel void add_inplace_batch(
    device float*       a     [[buffer(0)]],   // [M * n], updated in place
    device const float* b     [[buffer(1)]],   // [M * n]
    constant uint&      total [[buffer(2)]],   // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1048
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M embedding rows into a contiguous [M, dim] buffer.
// tokens[m] is the token id whose row of dim floats is copied into output
// row m. One thread per output element.
kernel void copy_embedding_batch(
    device const float* embed      [[buffer(0)]],   // [vocab_size, dim]
    device float*       output     [[buffer(1)]],   // [M, dim]
    device const uint*  tokens     [[buffer(2)]],   // [M]
    constant uint&      dim        [[buffer(3)]],
    constant uint&      num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    const uint slot      = id / dim;       // which batch entry
    const uint col       = id % dim;       // which channel within the row
    const uint vocab_row = tokens[slot];   // source row in the embedding table

    output[id] = embed[vocab_row * dim + col];
}
1066
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched float matrix-vector product: outputs[m] = matrix * inputs[m] for
// M input vectors sharing one weight matrix.
// Grid: ceil(rows / ROWS_PER_TG) * M threadgroups; each threadgroup owns one
// (token, row group) pair and each simdgroup owns 8 consecutive rows of it.
// The token vector is staged in threadgroup memory, so cols must not exceed
// VEC_TILE_SIZE. The staging loop is written for 256 threads per threadgroup.
kernel void matmul_vec_batch(
    device const float* matrix     [[buffer(0)]],   // [rows, cols] weight
    device const float* inputs     [[buffer(1)]],   // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],   // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],   // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup id into (token, row group).
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token            = tgid / groups_per_token;
    const uint group            = tgid % groups_per_token;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint c = tid; c < cols; c += 256) {
        vec_tile[c] = input[c];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barrier follows, so simdgroups past the last row can leave here.
    const uint row_base = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Element offset of each of the 8 rows handled by this simdgroup, and
    // one float4 accumulator per row.
    uint   row_off[8];
    float4 acc4[8];
    for (uint r = 0; r < 8; ++r) {
        row_off[r] = (row_base + r) * cols;
        acc4[r]    = float4(0.0f);
    }

    // Vector part: each lane takes one float4 per 128-column stripe. Rows
    // beyond the matrix are skipped so their memory is never touched.
    const uint cols_vec4 = cols & ~127u;
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        const float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        for (uint r = 0; r < 8; ++r) {
            if (row_base + r < rows) {
                acc4[r] += *reinterpret_cast<device const float4*>(matrix + row_off[r] + j) * v;
            }
        }
    }

    // Collapse each float4 accumulator to a scalar.
    float acc[8];
    for (uint r = 0; r < 8; ++r) {
        acc[r] = acc4[r].x + acc4[r].y + acc4[r].z + acc4[r].w;
    }

    // Scalar part: the columns left over after the last full 128 stripe.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        const float vv = vec_tile[j];
        for (uint r = 0; r < 8; ++r) {
            if (row_base + r < rows) {
                acc[r] += matrix[row_off[r] + j] * vv;
            }
        }
    }

    // Combine the 32 lane partials of every row.
    for (uint r = 0; r < 8; ++r) {
        acc[r] = simd_sum(acc[r]);
    }

    // Lane 0 writes the rows that actually exist.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        for (uint r = 0; r < 8; ++r) {
            if (row_base + r < rows) {
                output[row_base + r] = acc[r];
            }
        }
    }
}
1168
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector product for M input vectors.
// Grid: ceil(rows / Q8_ROWS_PER_TG) * M threadgroups; each threadgroup owns
// one (token, row group) pair and each simdgroup owns 4 consecutive rows.
//
// Weight layout as read here: every row is cols / 32 blocks of 34 bytes,
// a half scale followed by 32 signed 8-bit quants. cols is expected to be a
// multiple of 32; a partial trailing block is not processed.
//
// The token vector is staged in threadgroup memory (cols <= VEC_TILE_SIZE)
// by a loop written for 256 threads per threadgroup.
//
// Tail rows: when rows is not a multiple of 4, the row pointers of the
// missing rows are clamped to the last valid row so no load leaves the
// weight buffer. Their sums are computed but never stored.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix     [[buffer(0)]],   // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],   // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],   // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],   // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    // Split the flat threadgroup id into (token, row group).
    uint row_tgs     = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token       = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barrier follows, so simdgroups past the last row can leave here.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes      = blocks_per_row * 34;

    // rows >= 1 here because row_base < rows. Clamping keeps every pointer
    // inside the matrix even for a ragged final group of rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;

    // Expands one packed_short4 (8 quant bytes) into two float4 values:
    // OUT_LO receives bytes 0..3, OUT_HI receives bytes 4..7.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    // Each lane processes every 32nd block of the row.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of the block inside a row
        uint vb = blk * 32;   // element offset of the block inside the vector

        // Per-block scale of each of the four rows.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block instead of 8 char4 loads.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 vector elements facing this block, shared by all four rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float4 f0, f1;
        float bd0 = 0.0f, bd1 = 0.0f, bd2 = 0.0f, bd3 = 0.0f;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        // The scale is applied once per block, after the integer-valued dot.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    #undef Q8_UNPACK8

    // Combine the 32 lane partials of every row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the rows that actually exist.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1284
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// GEMM-style Q8_0 kernel that reuses weight reads across a tile of tokens.
// Each threadgroup covers Q8_ROWS_PER_TG rows (8 simdgroups * 4 rows = 32
// per the original design notes) and TOKENS_PER_TG_Q8 consecutive tokens, so
// every Q8_0 weight block is fetched and unpacked once and then used for all
// tokens of the tile (1 / TOKENS_PER_TG_Q8 of the weight traffic of the
// per-token dispatch).
//
// Grid: (ceil(rows / Q8_ROWS_PER_TG), ceil(M / TOKENS_PER_TG_Q8)) threadgroups.
// Each simdgroup reduces over blocks with simd_sum. Token vectors are read
// straight from device memory inside the block loop (not cached in
// threadgroup memory), so the kernel is not limited by VEC_TILE_SIZE.
//
// Weight layout as read here: every row is cols / 32 blocks of 34 bytes,
// a half scale followed by 32 signed 8-bit quants. cols is expected to be a
// multiple of 32.
//
// Tail rows: when rows is not a multiple of 4, the row pointers of the
// missing rows are clamped to the last valid row so no load leaves the
// weight buffer. Their sums are computed but never stored.
//
// The body is unrolled for exactly 4 tokens; changing TOKENS_PER_TG_Q8
// requires adding or removing the matching accumulator rows below.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix     [[buffer(0)]],   // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],   // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],   // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],   // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Number of valid tokens in this tile (at least 1 after the check above).
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes      = blocks_per_row * 34;

    // rows >= 1 here because row_base < rows. Clamping keeps every pointer
    // inside the matrix even for a ragged final group of rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators sTR: token T of the tile, row R of this simdgroup.
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    // Expands one packed_short4 (8 quant bytes) into two float4 values:
    // OUT_LO receives bytes 0..3, OUT_HI receives bytes 4..7.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    // Accumulates the current block for tile token T into S0..S3 (one per
    // row). Skipped when the tile has fewer than T + 1 valid tokens. Uses the
    // unpacked weights wR_K, the scales scR and the offset vb of the
    // enclosing block-loop iteration.
    #define Q8_GEMM_TOKEN(T, S0, S1, S2, S3) \
        if (tok_count > (T)) { \
            device const float* _tv = inputs + (tok_base + (T)) * cols + vb; \
            float4 v0 = *(device const float4*)(_tv); \
            float4 v1 = *(device const float4*)(_tv + 4); \
            float4 v2 = *(device const float4*)(_tv + 8); \
            float4 v3 = *(device const float4*)(_tv + 12); \
            float4 v4 = *(device const float4*)(_tv + 16); \
            float4 v5 = *(device const float4*)(_tv + 20); \
            float4 v6 = *(device const float4*)(_tv + 24); \
            float4 v7 = *(device const float4*)(_tv + 28); \
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3) \
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7); \
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3) \
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7); \
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3) \
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7); \
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3) \
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7); \
            S0 += sc0 * bd0; S1 += sc1 * bd1; S2 += sc2 * bd2; S3 += sc3 * bd3; \
        }

    // Stores the four row results of tile token T, skipping rows and tokens
    // that do not exist. row_base itself is always a valid row here.
    #define Q8_GEMM_STORE(T, S0, S1, S2, S3) \
        if (tok_count > (T)) { \
            device float* _o = outputs + (tok_base + (T)) * rows; \
            _o[row_base] = S0; \
            if (row_base + 1 < rows) _o[row_base + 1] = S1; \
            if (row_base + 2 < rows) _o[row_base + 2] = S2; \
            if (row_base + 3 < rows) _o[row_base + 3] = S3; \
        }

    // Each lane processes every 32nd block of the row.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of the block inside a row
        uint vb = blk * 32;   // element offset of the block inside a vector

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Unpack 4 rows x 8 float4 of raw (unscaled) quant values. They stay
        // in registers for the whole block; the per-row scale is applied
        // when a block dot product is added to its accumulator.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        // ── For each token, read its vector and accumulate against the
        //    shared weights. Token 0 is always valid (tok_count >= 1). ──
        Q8_GEMM_TOKEN(0, s00, s01, s02, s03)
        Q8_GEMM_TOKEN(1, s10, s11, s12, s13)
        Q8_GEMM_TOKEN(2, s20, s21, s22, s23)
        Q8_GEMM_TOKEN(3, s30, s31, s32, s33)
    }

    // simdgroup reduction; executed by every lane for all 16 accumulators.
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    if (simd_lane == 0) {
        Q8_GEMM_STORE(0, s00, s01, s02, s03)
        Q8_GEMM_STORE(1, s10, s11, s12, s13)
        Q8_GEMM_STORE(2, s20, s21, s22, s23)
        Q8_GEMM_STORE(3, s30, s31, s32, s33)
    }

    #undef Q8_GEMM_STORE
    #undef Q8_GEMM_TOKEN
    #undef Q8_UNPACK8
}
1510
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Q8_0 GEMM built on simdgroup_matrix tiles (simdgroup_multiply_accumulate):
//
//   outputs[tok, row] = sum_k inputs[tok, k] * dequant(matrix[row, k])
//
// Tile: 16 tokens × 16 rows per threadgroup, K = 32 per iteration, i.e. one
// Q8_0 block (a 2-byte half scale followed by 32 int8 quants = 34 bytes).
// 4 simdgroups per threadgroup, each owning a single 8×8 float accumulator.
// Weight bytes are cooperatively dequantized into threadgroup memory once
// per block and reused by all simdgroups in the tile.
//
// Preconditions (expected to be enforced by the dispatch helper, which falls
// back to another kernel otherwise):
//   * threadgroup size is 128 threads (4 simdgroups of 32)
//   * cols % 32 == 0   (one Q8_0 block per K chunk)
//   * rows % 16 == 0   (every weight row touched by a tile exists)
//   * num_tokens may be any value; a partial 8-token sub-tile at the end is
//     staged in threadgroup scratch and copied out element by element.
//
// Threadgroup memory: 2 × 2 KB tiles + 1 KB scratch.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix      [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs      [[buffer(1)]],   // [M, cols]
    device float*       outputs     [[buffer(2)]],   // [M, rows]
    constant uint&      num_tokens  [[buffer(3)]],
    constant uint&      rows        [[buffer(4)]],
    constant uint&      cols        [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint  tid      [[thread_index_in_threadgroup]],
    uint  simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Uniform for the whole threadgroup, so leaving before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Dequantized weight tile and token tile, 16 × 32 floats each.
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];
    // Staging area for a partial token sub-tile: 64 floats per simdgroup.
    threadgroup float scratch[4 * 64];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;   // token offset inside the tile
    uint sg_row_base = (simd_id % 2) * 8;   // row offset inside the tile

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes      = blocks_per_row * 34;

    // Each thread fills 4 consecutive K slots of one tile line
    // (512 floats / 128 threads). tid * 4 is 4-aligned and 32 % 4 == 0, so
    // the 4 slots never straddle two lines: the weight row, its block scale
    // and the token bounds check are therefore identical for all 4 slots and
    // are computed once instead of once per element.
    uint slot   = tid * 4;
    uint tile_r = slot / 32;
    uint tile_k = slot % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;   // only dereferenced when tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Dequantize this thread's 4 weights of the current Q8_0 block ──
        device const uchar*  bp    = w_row + blk * 34;
        float                scale = float(*(device const half*)bp);
        device const int8_t* q     = (device const int8_t*)(bp + 2 + tile_k);
        for (uint i = 0; i < 4; i++) {
            w_tile[slot + i] = float(int(q[i])) * scale;
        }

        // ── Load this thread's 4 activations; tokens past the end read as 0 ──
        for (uint i = 0; i < 4; i++) {
            t_tile[slot + i] = tok_valid ? inputs[in_off + blk * 32 + i] : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K = 32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]   (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]   (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A, t_tile + sg_tok_base * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B, w_tile + sg_row_base * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Nobody may overwrite the tiles until every simdgroup is done reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[out_tok + m, out_row + r], row stride = rows ──
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    if (out_tok + 8 <= num_tokens) {
        // Fast path: all 8 tokens of the sub-tile exist.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial sub-tile (1..7 valid tokens): stage the accumulator in this
        // simdgroup's scratch slice, then let all 32 lanes copy the valid
        // elements. The slice is laid out with stride 8, so flat element e
        // maps to token e / 8 and row e % 8.
        threadgroup float* stage = scratch + simd_id * 64;
        simdgroup_store(C, stage, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint live = (num_tokens - out_tok) * 8;
        for (uint e = tid % 32; e < live; e += 32) {
            outputs[(out_tok + e / 8) * rows + out_row + (e % 8)] = stage[e];
        }
    }
}
1639
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma, intended for long prefill batches.
//
// Tile: 32 tokens × 32 rows per threadgroup, K = 32 per iteration.
// 8 simdgroups (256 threads) cover the 4×4 grid of 8×8 output sub-tiles,
// with each simdgroup owning *two* 8×8 accumulators that are adjacent along
// the row axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half   (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0)  -> C_a
//     (sg_tok_idx*8, sg_row_half*16 + 8)  -> C_b
//
// The loaded A (token) fragment is used for two multiply-accumulates per
// K_sub step (2 MMAs per 3 fragment loads, versus 1 per 2 in the 16×16
// kernel), and a 32×32 tile covers four 16×16 tiles, so the grid needs only
// a quarter of the threadgroups.
//
// Preconditions (expected to be enforced by the dispatch helper):
//   * threadgroup size is 256 threads (8 simdgroups of 32)
//   * cols % 32 == 0
//   * rows % 32 == 0
//
// Threadgroup memory: 2 × 4 KB tiles + 4 KB scratch.
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix      [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs      [[buffer(1)]],   // [M, cols]
    device float*       outputs     [[buffer(2)]],   // [M, rows]
    constant uint&      num_tokens  [[buffer(3)]],
    constant uint&      rows        [[buffer(4)]],
    constant uint&      cols        [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint  tid      [[thread_index_in_threadgroup]],
    uint  simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform for the whole threadgroup, so leaving before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32 × 32 float tiles for dequantized weights and token activations.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];
    // Partial-tile staging: 8 simdgroups × 2 accumulators × 64 floats.
    threadgroup float scratch[8 * 2 * 64];

    uint sg_tok_idx    = simd_id / 2;          // 0..3
    uint sg_row_half   = simd_id % 2;          // 0..1
    uint sg_tok_base   = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes      = blocks_per_row * 34;

    // Each thread fills 4 consecutive K slots of one tile line
    // (1024 floats / 256 threads). tid * 4 is 4-aligned and 32 % 4 == 0, so
    // the slots never straddle two lines; the weight row, block scale and
    // token bounds check are computed once rather than per element.
    uint slot   = tid * 4;
    uint tile_r = slot / 32;
    uint tile_k = slot % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;   // only dereferenced when tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization (one scale read per thread per block).
        device const uchar*  bp    = w_row + blk * 34;
        float                scale = float(*(device const half*)bp);
        device const int8_t* q     = (device const int8_t*)(bp + 2 + tile_k);
        for (uint i = 0; i < 4; i++) {
            w_tile[slot + i] = float(int(q[i])) * scale;
        }

        // Cooperative token tile load; tokens past the end read as 0.
        for (uint i = 0; i < 4; i++) {
            t_tile[slot + i] = tok_valid ? inputs[in_off + blk * 32 + i] : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each; A is shared by both row accumulators.
        // The weight fragments are loaded transposed (K×R).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,   t_tile + sg_tok_base   * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_a, w_tile + sg_row_base_a * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_b, w_tile + sg_row_base_b * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Nobody may overwrite the tiles until every simdgroup is done reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both accumulators. Given rows % 32 == 0 the row dimension is
    // always complete; only the token dimension can be partial.
    uint out_tok   = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial sub-tile (1..7 valid tokens): stage both accumulators in
        // this simdgroup's scratch slice (stride 8), then let all 32 lanes
        // copy the valid elements. Flat element e is token e / 8, row e % 8.
        threadgroup float* stage = scratch + simd_id * 128;
        simdgroup_store(C_a, stage,      8);
        simdgroup_store(C_b, stage + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint live = (num_tokens - out_tok) * 8;
        for (uint e = tid % 32; e < live; e += 32) {
            device float* dst = outputs + (out_tok + e / 8) * rows + (e % 8);
            dst[out_row_a] = stage[e];
            dst[out_row_b] = stage[64 + e];
        }
    }
}
1780
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Dequantized weights and token inputs are stored as `half` in threadgroup
// memory, which halves the tile footprint (2 × 2 KB instead of 2 × 4 KB).
// The 4 KB float scratch used for partial token tiles is unchanged, so the
// total static threadgroup allocation is 8 KB versus 12 KB for mma32. The
// smaller footprint is intended to allow more concurrent threadgroups per
// GPU core.
//
// Numerics: a Q8_0 weight is int8 × half scale, and the product is rounded
// to half when written to the tile. Activations are narrowed f32 → f16 on
// load, so values outside the half range would saturate to ±inf; the caller
// is assumed to feed activations of moderate magnitude. Accumulation stays
// in float (`simdgroup_matrix<float>` accumulators).
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-adjacent 8×8
// accumulators each.
//
// Preconditions (same as mma32, expected to be enforced by the dispatcher):
//   * threadgroup size is 256 threads, cols % 32 == 0, rows % 32 == 0
kernel void matmul_q8_mma32_h(
    device const uchar* matrix      [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs      [[buffer(1)]],   // [M, cols]
    device float*       outputs     [[buffer(2)]],   // [M, rows]
    constant uint&      num_tokens  [[buffer(3)]],
    constant uint&      rows        [[buffer(4)]],
    constant uint&      cols        [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint  tid      [[thread_index_in_threadgroup]],
    uint  simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform for the whole threadgroup, so leaving before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32 × 32 half tiles, 2 KB each.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];
    // Partial-tile staging: 8 simdgroups × 2 accumulators × 64 floats.
    threadgroup float scratch[8 * 2 * 64];

    uint sg_tok_idx    = simd_id / 2;          // 0..3
    uint sg_row_half   = simd_id % 2;          // 0..1
    uint sg_tok_base   = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes      = blocks_per_row * 34;

    // Each thread fills 4 consecutive K slots of one tile line
    // (1024 halves / 256 threads). tid * 4 is 4-aligned and 32 % 4 == 0, so
    // the slots never straddle two lines; the weight row, block scale and
    // token bounds check are computed once rather than per element.
    uint slot   = tid * 4;
    uint tile_r = slot / 32;
    uint tile_k = slot % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;   // only dereferenced when tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization, rounded to FP16.
        device const uchar*  bp    = w_row + blk * 34;
        float                scale = float(*(device const half*)bp);
        device const int8_t* q     = (device const int8_t*)(bp + 2 + tile_k);
        for (uint i = 0; i < 4; i++) {
            w_tile[slot + i] = half(float(int(q[i])) * scale);
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past the
        // end read as 0.
        for (uint i = 0; i < 4; i++) {
            t_tile[slot + i] = tok_valid ? half(inputs[in_off + blk * 32 + i]) : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Half operand fragments, float accumulators. The weight fragments
        // are loaded transposed (K×R).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,   t_tile + sg_tok_base   * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_a, w_tile + sg_row_base_a * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_b, w_tile + sg_row_base_b * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Nobody may overwrite the tiles until every simdgroup is done reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both accumulators; only the token dimension can be partial.
    uint out_tok   = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    if (out_tok + 8 <= num_tokens) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial sub-tile (1..7 valid tokens): stage both accumulators in
        // this simdgroup's scratch slice (stride 8), then let all 32 lanes
        // copy the valid elements. Flat element e is token e / 8, row e % 8.
        threadgroup float* stage = scratch + simd_id * 128;
        simdgroup_store(C_a, stage,      8);
        simdgroup_store(C_b, stage + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint live = (num_tokens - out_tok) * 8;
        for (uint e = tid % 32; e < live; e += 32) {
            device float* dst = outputs + (out_tok + e / 8) * rows + (e % 8);
            dst[out_row_a] = stage[e];
            dst[out_row_b] = stage[64 + e];
        }
    }
}
1909
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 accumulators, this kernel runs 4 simdgroups,
// each owning a **2×2 grid** of 8×8 float accumulators. Relative to the
// simdgroup's 16×16 quadrant of the 32×32 output tile:
//   C_00 (tok 0..8,  row 0..8)    C_01 (tok 0..8,  row 8..16)
//   C_10 (tok 8..16, row 0..8)    C_11 (tok 8..16, row 8..16)
// simd_id selects the quadrant: token half = simd_id / 2, row half = simd_id % 2.
//
// Per K_sub iteration: load two A fragments and two B fragments, then issue
// **four** MMA instructions (A_top and A_bot each paired with both B's).
// That is 4 MMAs per 4 fragment loads versus 2 per 3 in the 2-accumulator
// kernel (1.5× the MMA work per load), with half the threads per
// threadgroup (128).
//
// Preconditions (expected to be enforced by the dispatch helper):
//   * threadgroup size is 128 threads (4 simdgroups of 32)
//   * cols % 32 == 0, rows % 32 == 0
//
// Threadgroup memory: 2 × 2 KB half tiles + 4 KB float scratch.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix      [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs      [[buffer(1)]],   // [M, cols]
    device float*       outputs     [[buffer(2)]],   // [M, rows]
    constant uint&      num_tokens  [[buffer(3)]],
    constant uint&      rows        [[buffer(4)]],
    constant uint&      cols        [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint  tid      [[thread_index_in_threadgroup]],
    uint  simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform for the whole threadgroup, so leaving before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32 × 32 FP16 tiles.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];
    // Partial-tile staging: 4 simdgroups × 4 accumulators × 64 floats.
    threadgroup float scratch[4 * 4 * 64];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_base = (simd_id / 2) * 16;
    uint sg_row_base = (simd_id % 2) * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes      = blocks_per_row * 34;

    // Each thread fills 8 consecutive K slots of one tile line
    // (128 threads × 8 = 1024 = 32 × 32). tid * 8 is 8-aligned and
    // 32 % 8 == 0, so the slots never straddle two lines; the weight row,
    // block scale and token bounds check are computed once per thread
    // rather than once per element.
    uint slot   = tid * 8;
    uint tile_r = slot / 32;
    uint tile_k = slot % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;   // only dereferenced when tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization, rounded to FP16.
        device const uchar*  bp    = w_row + blk * 34;
        float                scale = float(*(device const half*)bp);
        device const int8_t* q     = (device const int8_t*)(bp + 2 + tile_k);
        for (uint i = 0; i < 8; i++) {
            w_tile[slot + i] = half(float(int(q[i])) * scale);
        }

        // Cooperative token tile load (f32 → f16); tokens past the end read as 0.
        for (uint i = 0; i < 8; i++) {
            t_tile[slot + i] = tok_valid ? half(inputs[in_off + blk * 32 + i]) : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing both A's and both B's.
        // The weight fragments are loaded transposed (K×R).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top, t_tile + (sg_tok_base + 0) * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(A_bot, t_tile + (sg_tok_base + 8) * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_lo,  w_tile + (sg_row_base + 0) * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_hi,  w_tile + (sg_row_base + 8) * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Nobody may overwrite the tiles until every simdgroup is done reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store the four output sub-tiles of this simdgroup's quadrant.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;
    if (out_tok_bot + 8 <= num_tokens) {
        // Fast path: all 16 tokens of the quadrant exist.
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback: stage all four accumulators in this
        // simdgroup's scratch slice (stride 8), then let all 32 lanes copy
        // the elements whose token index is in range. Flat element e is
        // token e / 8, row e % 8 within each 8×8 sub-tile.
        threadgroup float* stage = scratch + simd_id * 256;
        simdgroup_store(C_00, stage +   0, 8);
        simdgroup_store(C_01, stage +  64, 8);
        simdgroup_store(C_10, stage + 128, 8);
        simdgroup_store(C_11, stage + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        for (uint e = tid % 32; e < 64; e += 32) {
            uint m = e / 8;
            uint j = e % 8;
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst = outputs + t_top * rows + j;
                dst[out_row_lo] = stage[e];
                dst[out_row_hi] = stage[64 + e];
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst = outputs + t_bot * rows + j;
                dst[out_row_lo] = stage[128 + e];
                dst[out_row_hi] = stage[192 + e];
            }
        }
    }
}
2064
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the operand fragments and the four accumulators are
// simdgroup_matrix<half>. The motivation recorded by the original author is
// that FP16 simdgroup_multiply_accumulate runs at up to twice the FP32 rate
// on Apple Silicon; that figure is not verifiable from this file and should
// be confirmed by benchmarking.
//
// Numerical notes: the running sum over K is rounded to half after every
// 8-wide MMA step, so long reductions (K in the thousands) lose precision
// relative to the float-accumulator kernels, and sums beyond the half range
// overflow to ±inf. The original author's position is that Q8_0
// quantization error dominates this rounding for bounded (post-norm)
// activations, and that the kernel is to be validated on 135M / 1B / 3B
// models before being enabled.
//
// Because device output is float, every accumulator is staged in a half
// threadgroup scratch slice and widened on the way out; there is no direct
// simdgroup_store fast path.
//
// Preconditions (expected to be enforced by the dispatch helper):
//   * threadgroup size is 128 threads (4 simdgroups of 32)
//   * cols % 32 == 0, rows % 32 == 0
//
// Threadgroup memory: 2 × 2 KB half tiles + 2 KB half scratch.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix      [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs      [[buffer(1)]],   // [M, cols]
    device float*       outputs     [[buffer(2)]],   // [M, rows]
    constant uint&      num_tokens  [[buffer(3)]],
    constant uint&      rows        [[buffer(4)]],
    constant uint&      cols        [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint  tid      [[thread_index_in_threadgroup]],
    uint  simd_id  [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform for the whole threadgroup, so leaving before the barriers is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32 × 32 FP16 tiles.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];
    // Output staging: 4 simdgroups × 4 accumulators × 64 halves.
    threadgroup half scratch[4 * 4 * 64];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_base = (simd_id / 2) * 16;
    uint sg_row_base = (simd_id % 2) * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes      = blocks_per_row * 34;

    // Each thread fills 8 consecutive K slots of one tile line
    // (128 threads × 8 = 1024 = 32 × 32). tid * 8 is 8-aligned and
    // 32 % 8 == 0, so the slots never straddle two lines; the weight row,
    // block scale and token bounds check are computed once per thread
    // rather than once per element.
    uint slot   = tid * 8;
    uint tile_r = slot / 32;
    uint tile_k = slot % 32;

    device const uchar* w_row = matrix + (row_base + tile_r) * row_bytes;
    uint tok       = tok_base + tile_r;
    bool tok_valid = tok < num_tokens;
    uint in_off    = tok * cols + tile_k;   // only dereferenced when tok_valid

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization, rounded to FP16.
        device const uchar*  bp    = w_row + blk * 34;
        float                scale = float(*(device const half*)bp);
        device const int8_t* q     = (device const int8_t*)(bp + 2 + tile_k);
        for (uint i = 0; i < 8; i++) {
            w_tile[slot + i] = half(float(int(q[i])) * scale);
        }

        // Cooperative token tile load (f32 → f16); tokens past the end read as 0.
        for (uint i = 0; i < 8; i++) {
            t_tile[slot + i] = tok_valid ? half(inputs[in_off + blk * 32 + i]) : half(0);
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 half-precision MMA ops each. The weight
        // fragments are loaded transposed (K×R).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top, t_tile + (sg_tok_base + 0) * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(A_bot, t_tile + (sg_tok_base + 8) * 32 + k_sub * 8, 32, ulong2(0, 0), false);
            simdgroup_load(B_lo,  w_tile + (sg_row_base + 0) * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_load(B_hi,  w_tile + (sg_row_base + 8) * 32 + k_sub * 8, 32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Nobody may overwrite the tiles until every simdgroup is done reading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Stage the half accumulators in this simdgroup's scratch slice
    // (stride 8) and widen to f32 while copying to device memory.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo  = row_base + sg_row_base + 0;
    uint out_row_hi  = row_base + sg_row_base + 8;

    threadgroup half* stage = scratch + simd_id * 256;
    simdgroup_store(C_00, stage +   0, 8);
    simdgroup_store(C_01, stage +  64, 8);
    simdgroup_store(C_10, stage + 128, 8);
    simdgroup_store(C_11, stage + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // This copy runs for every output element, so it is spread over all 32
    // lanes instead of being serialized on lane 0. Flat element e is token
    // e / 8, row e % 8 within each 8×8 sub-tile; tokens past num_tokens are
    // skipped.
    for (uint e = tid % 32; e < 64; e += 32) {
        uint m = e / 8;
        uint j = e % 8;
        uint t_top = out_tok_top + m;
        if (t_top < num_tokens) {
            device float* dst = outputs + t_top * rows + j;
            dst[out_row_lo] = float(stage[e]);
            dst[out_row_hi] = float(stage[64 + e]);
        }
        uint t_bot = out_tok_bot + m;
        if (t_bot < num_tokens) {
            device float* dst = outputs + t_bot * rows + j;
            dst[out_row_lo] = float(stage[128 + e]);
            dst[out_row_hi] = float(stage[192 + e]);
        }
    }
}
2202
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add of a length-`rows` bias vector onto every token row
// of a row-major [num_tokens, rows] buffer:
//   out[token * rows + i] += bias[i]
// One thread per output element; per the original note this is applied to
// the fused QKV projection for Qwen2-style models.
kernel void add_bias_batch(
    device float*       out         [[buffer(0)]],   // [num_tokens, rows]
    device const float* bias        [[buffer(1)]],   // [rows]
    constant uint&      num_tokens  [[buffer(2)]],
    constant uint&      rows        [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Threads beyond the last element (grid rounded up to the threadgroup
    // size) do nothing. The guard also keeps the modulo from running when
    // rows == 0, because the element count is then 0.
    if (id < num_tokens * rows) {
        uint column = id % rows;
        out[id] += bias[column];
    }
}
2219
2220// ── matmul_vec_q4_batch ────────────────────────────────────────────────
2221// Batched Q4_0 matrix-vector multiply for M input vectors.
2222// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
2223kernel void matmul_vec_q4_batch(
2224 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes [rows, cols]
2225 device const float* inputs [[buffer(1)]], // [M, cols] input batch
2226 device float* outputs [[buffer(2)]], // [M, rows] output batch
2227 constant uint& num_tokens [[buffer(3)]], // M
2228 constant uint& rows [[buffer(4)]],
2229 constant uint& cols [[buffer(5)]],
2230 uint tgid [[threadgroup_position_in_grid]],
2231 uint tid [[thread_index_in_threadgroup]],
2232 uint simd_lane [[thread_index_in_simdgroup]],
2233 uint simd_id [[simdgroup_index_in_threadgroup]])
2234{
2235 uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
2236 uint token = tgid / row_tgs;
2237 uint tg_in_token = tgid % row_tgs;
2238 if (token >= num_tokens) return;
2239
2240 threadgroup float vec_tile[VEC_TILE_SIZE];
2241 device const float* input = inputs + token * cols;
2242 for (uint i = tid; i < cols; i += 256) {
2243 vec_tile[i] = input[i];
2244 }
2245 threadgroup_barrier(mem_flags::mem_threadgroup);
2246
2247 uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
2248 if (row_base >= rows) return;
2249
2250 uint blocks_per_row = cols / 32;
2251 uint row_bytes = blocks_per_row * 18;
2252
2253 device const uchar* r0 = matrix + row_base * row_bytes;
2254 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
2255 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
2256 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
2257
2258 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
2259
2260 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
2261 uint bb = blk * 18;
2262 uint vb = blk * 32;
2263
2264 float sc0 = float(*(device const half*)(r0 + bb));
2265 float sc1 = float(*(device const half*)(r1 + bb));
2266 float sc2 = float(*(device const half*)(r2 + bb));
2267 float sc3 = float(*(device const half*)(r3 + bb));
2268
2269 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
2270 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
2271 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
2272 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
2273
2274 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
2275 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
2276 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
2277 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
2278 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
2279 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
2280 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
2281 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
2282
2283 float bd0=0, bd1=0, bd2=0, bd3=0;
2284 uchar4 b;
2285
2286 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2287 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2288 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2289 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2290
2291 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2292 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2293 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2294 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2295
2296 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2297 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2298 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2299 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2300
2301 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
2302 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
2303 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
2304 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
2305
2306 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
2307 }
2308
2309 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
2310 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
2311
2312 device float* output = outputs + token * rows;
2313 if (simd_lane == 0) {
2314 if (row_base < rows) output[row_base] = sum0;
2315 if (row_base + 1 < rows) output[row_base + 1] = sum1;
2316 if (row_base + 2 < rows) output[row_base + 2] = sum2;
2317 if (row_base + 3 < rows) output[row_base + 3] = sum3;
2318 }
2319}
2320
2321// ── copy_kv_batch ─────────────────────────────────────────────────────
2322// Copy K or V from a strided batch QKV buffer to the KV cache.
2323// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
2324// dst layout: contiguous [max_seq, kv_dim] cache.
2325kernel void copy_kv_batch(
2326 device const float* src [[buffer(0)]], // batch QKV buffer (f32)
2327 device half* dst [[buffer(1)]], // KV cache (f16)
2328 constant uint& M [[buffer(2)]], // num tokens in batch
2329 constant uint& kv_dim [[buffer(3)]], // elements per KV vector
2330 constant uint& base_pos [[buffer(4)]], // starting position in cache
2331 constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
2332 constant uint& src_offset [[buffer(6)]], // float offset within each src row
2333 uint id [[thread_position_in_grid]])
2334{
2335 uint total = M * kv_dim;
2336 if (id >= total) return;
2337 uint token = id / kv_dim;
2338 uint d = id % kv_dim;
2339 uint dst_off = (base_pos + token) * kv_dim + d;
2340 uint src_off = token * src_stride + src_offset + d;
2341 dst[dst_off] = half(src[src_off]);
2342}
2343
2344// ── attention_batch ───────────────────────────────────────────────────
2345// Batched causal attention for prefill. Processes M tokens in one dispatch.
2346// Each threadgroup handles one (token, head) pair.
2347// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
2348// Causal masking: token i can only attend to positions 0..base_pos+i.
2349kernel void attention_batch(
2350 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2351 device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
2352 device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
2353 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2354 constant uint& M [[buffer(4)]], // num tokens in batch
2355 constant uint& base_pos [[buffer(5)]], // starting position in KV cache
2356 constant uint& num_heads [[buffer(6)]],
2357 constant uint& num_kv_heads [[buffer(7)]],
2358 constant uint& head_dim [[buffer(8)]],
2359 constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
2360 uint tgid [[threadgroup_position_in_grid]],
2361 uint tid [[thread_index_in_threadgroup]],
2362 uint simd_lane [[thread_index_in_simdgroup]],
2363 uint simd_id [[simdgroup_index_in_threadgroup]])
2364{
2365 // Grid: M * num_heads threadgroups
2366 uint token_idx = tgid / num_heads;
2367 uint head = tgid % num_heads;
2368 if (token_idx >= M) return;
2369
2370 uint kv_head = head / (num_heads / num_kv_heads);
2371 uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx
2372
2373 // Q offset uses strided layout (from batch QKV buffer)
2374 uint q_off = token_idx * q_stride + head * head_dim;
2375 // Output is contiguous [M, num_heads * head_dim]
2376 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2377
2378 // Shared memory for attention scores — sized to the effective max_seq_len
2379 // (4096 for all supported models) so long-context attention doesn't overflow.
2380 threadgroup float scores[ATTN_SCORES_SIZE];
2381
2382 // Step 1: Q * K^T with simdgroup reduction
2383 for (uint s = simd_id; s < seq_len; s += 8) {
2384 uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
2385 float dot = 0.0;
2386 for (uint d = simd_lane; d < head_dim; d += 32) {
2387 dot += q_batch[q_off + d] * k_cache[k_off + d];
2388 }
2389 dot = simd_sum(dot);
2390 if (simd_lane == 0) {
2391 scores[s] = dot * fast::rsqrt(float(head_dim));
2392 }
2393 }
2394 threadgroup_barrier(mem_flags::mem_threadgroup);
2395
2396 // Step 2: Softmax (cooperative)
2397 float local_max = -INFINITY;
2398 for (uint s = tid; s < seq_len; s += 256) {
2399 local_max = max(local_max, scores[s]);
2400 }
2401 local_max = simd_max(local_max);
2402 threadgroup float shared_max[8];
2403 if (simd_lane == 0) shared_max[simd_id] = local_max;
2404 threadgroup_barrier(mem_flags::mem_threadgroup);
2405 if (tid == 0) {
2406 float m = shared_max[0];
2407 for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
2408 shared_max[0] = m;
2409 }
2410 threadgroup_barrier(mem_flags::mem_threadgroup);
2411 float max_val = shared_max[0];
2412
2413 float local_sum = 0.0;
2414 for (uint s = tid; s < seq_len; s += 256) {
2415 scores[s] = fast::exp(scores[s] - max_val);
2416 local_sum += scores[s];
2417 }
2418 local_sum = simd_sum(local_sum);
2419 threadgroup float shared_sum[8];
2420 if (simd_lane == 0) shared_sum[simd_id] = local_sum;
2421 threadgroup_barrier(mem_flags::mem_threadgroup);
2422 if (tid == 0) {
2423 float total = 0.0;
2424 for (uint i = 0; i < 8; i++) total += shared_sum[i];
2425 shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
2426 }
2427 threadgroup_barrier(mem_flags::mem_threadgroup);
2428 float inv_sum = shared_sum[0];
2429 for (uint s = tid; s < seq_len; s += 256) {
2430 scores[s] *= inv_sum;
2431 }
2432 threadgroup_barrier(mem_flags::mem_threadgroup);
2433
2434 // Step 3: scores * V using float4 vectorized loads
2435 // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
2436 // This is much better than the scalar version where only 64 of 256 threads are active.
2437 uint v_stride = num_kv_heads * head_dim;
2438 uint head_dim4 = head_dim / 4;
2439 for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
2440 uint d = d4 * 4;
2441 float4 acc = float4(0.0);
2442 uint v_base = kv_head * head_dim + d;
2443 uint seq_len4 = seq_len & ~3u;
2444 for (uint s = 0; s < seq_len4; s += 4) {
2445 float sc0 = scores[s];
2446 float sc1 = scores[s + 1];
2447 float sc2 = scores[s + 2];
2448 float sc3 = scores[s + 3];
2449 acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
2450 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
2451 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
2452 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
2453 }
2454 for (uint s = seq_len4; s < seq_len; s++) {
2455 acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
2456 }
2457 *(device float4*)(output_batch + out_off + d) = acc;
2458 }
2459 // Handle remaining dimensions not divisible by 4 (scalar fallback)
2460 for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
2461 float acc = 0.0;
2462 uint v_base = kv_head * head_dim + d;
2463 for (uint s = 0; s < seq_len; s++) {
2464 acc += scores[s] * v_cache[s * v_stride + v_base];
2465 }
2466 output_batch[out_off + d] = acc;
2467 }
2468}
2469
2470// ── attention_flash_batch ─────────────────────────────────────────────
2471// Streaming attention with online softmax. Same grid as attention_batch
2472// (M × num_heads threadgroups, one per (token, head) pair) but the scores
2473// matrix is never materialized. K/V positions are processed in a tile of
2474// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
2475// the standard flash-attention recurrence:
2476//
2477// m_new = max(m_old, tile_max)
2478// alpha = exp(m_old - m_new)
2479// l_new = alpha * l_old + sum(exp(S - m_new))
2480// O_new = alpha * O_old + sum(exp(S - m_new) * V)
2481// O_final = O / l
2482//
// This removes the dependence on attention_batch's fixed-size threadgroup
// `scores[]` array (which is indexed out of bounds once seq_len exceeds its
// compile-time size) and keeps threadgroup memory use to
// O(head_dim + FLASH_K_TILE) instead of O(seq_len).
2486//
2487// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
2488constant constexpr uint FLASH_K_TILE = 32;
2489constant constexpr uint FLASH_MAX_HEAD_DIM = 256;
2490
2491kernel void attention_flash_batch(
2492 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2493 device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
2494 device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
2495 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2496 constant uint& M [[buffer(4)]],
2497 constant uint& base_pos [[buffer(5)]],
2498 constant uint& num_heads [[buffer(6)]],
2499 constant uint& num_kv_heads [[buffer(7)]],
2500 constant uint& head_dim [[buffer(8)]],
2501 constant uint& q_stride [[buffer(9)]],
2502 uint tgid [[threadgroup_position_in_grid]],
2503 uint tid [[thread_index_in_threadgroup]],
2504 uint simd_lane [[thread_index_in_simdgroup]],
2505 uint simd_id [[simdgroup_index_in_threadgroup]])
2506{
2507 uint token_idx = tgid / num_heads;
2508 uint head = tgid % num_heads;
2509 if (token_idx >= M) return;
2510
2511 uint kv_head = head / (num_heads / num_kv_heads);
2512 uint seq_len = base_pos + token_idx + 1; // causal: attend to [0, base_pos + token_idx]
2513 uint q_off = token_idx * q_stride + head * head_dim;
2514 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2515
2516 // Threadgroup state:
2517 // q_sh: Q vector for this (token, head), loaded once
2518 // o_sh: running output vector, updated each K tile
2519 // scores_sh: scores for the current K tile only
2520 // stats: [running max, running sum] (see flash-attention recurrence)
2521 threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
2522 threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
2523 threadgroup float scores_sh[FLASH_K_TILE];
2524 threadgroup float stats[2];
2525 threadgroup float sg_scratch[8]; // simdgroup-level reduction buffer
2526
2527 // --- Load Q (one row) and zero the running O ---
2528 for (uint d = tid; d < head_dim; d += 256) {
2529 q_sh[d] = q_batch[q_off + d];
2530 o_sh[d] = 0.0f;
2531 }
2532 if (tid == 0) {
2533 stats[0] = -INFINITY;
2534 stats[1] = 0.0f;
2535 }
2536 threadgroup_barrier(mem_flags::mem_threadgroup);
2537
2538 float scale = fast::rsqrt(float(head_dim));
2539 uint v_stride = num_kv_heads * head_dim;
2540 uint v_base = kv_head * head_dim;
2541
2542 // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
2543 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
2544 uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
2545
2546 // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
2547 // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
2548 for (uint ti = simd_id; ti < tile_n; ti += 8) {
2549 uint k_off = (kv_base + ti) * v_stride + v_base; // same layout as V stride
2550 float dot = 0.0f;
2551 for (uint d = simd_lane; d < head_dim; d += 32) {
2552 dot += q_sh[d] * k_cache[k_off + d];
2553 }
2554 dot = simd_sum(dot);
2555 if (simd_lane == 0) {
2556 scores_sh[ti] = dot * scale;
2557 }
2558 }
2559 threadgroup_barrier(mem_flags::mem_threadgroup);
2560
2561 // [2] Tile max via cooperative reduction.
2562 float local_max = -INFINITY;
2563 for (uint s = tid; s < tile_n; s += 256) {
2564 local_max = max(local_max, scores_sh[s]);
2565 }
2566 local_max = simd_max(local_max);
2567 if (simd_lane == 0) {
2568 sg_scratch[simd_id] = local_max;
2569 }
2570 threadgroup_barrier(mem_flags::mem_threadgroup);
2571 // [3] Merge with running max, compute alpha, rescale running l.
2572 float m_new;
2573 float alpha;
2574 if (tid == 0) {
2575 float tile_max = sg_scratch[0];
2576 for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
2577 float m_old = stats[0];
2578 m_new = max(m_old, tile_max);
2579 // First iteration: m_old = -inf → alpha = 0 (reset O).
2580 alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2581 stats[0] = m_new;
2582 stats[1] *= alpha;
2583 // Broadcast via sg_scratch.
2584 sg_scratch[0] = alpha;
2585 sg_scratch[1] = m_new;
2586 }
2587 threadgroup_barrier(mem_flags::mem_threadgroup);
2588 alpha = sg_scratch[0];
2589 m_new = sg_scratch[1];
2590
2591 // [4] Rescale running output by alpha, then compute exp(scores - m_new).
2592 for (uint d = tid; d < head_dim; d += 256) {
2593 o_sh[d] *= alpha;
2594 }
2595 for (uint s = tid; s < tile_n; s += 256) {
2596 scores_sh[s] = fast::exp(scores_sh[s] - m_new);
2597 }
2598 threadgroup_barrier(mem_flags::mem_threadgroup);
2599
2600 // [5] Tile sum → update running l.
2601 float local_sum = 0.0f;
2602 for (uint s = tid; s < tile_n; s += 256) {
2603 local_sum += scores_sh[s];
2604 }
2605 local_sum = simd_sum(local_sum);
2606 if (simd_lane == 0) {
2607 sg_scratch[simd_id] = local_sum;
2608 }
2609 threadgroup_barrier(mem_flags::mem_threadgroup);
2610 if (tid == 0) {
2611 float tile_sum = 0.0f;
2612 for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
2613 stats[1] += tile_sum;
2614 }
2615 threadgroup_barrier(mem_flags::mem_threadgroup);
2616
2617 // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
2618 for (uint d = tid; d < head_dim; d += 256) {
2619 float acc = 0.0f;
2620 for (uint s = 0; s < tile_n; s++) {
2621 acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
2622 }
2623 o_sh[d] += acc;
2624 }
2625 threadgroup_barrier(mem_flags::mem_threadgroup);
2626 }
2627
2628 // --- Normalize and write output ---
2629 float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
2630 for (uint d = tid; d < head_dim; d += 256) {
2631 output_batch[out_off + d] = o_sh[d] * inv_l;
2632 }
2633}
2634
2635// ── attention_mma_flash_batch ─────────────────────────────────────────
2636// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2637// both Q·K^T and P·V. Processes Q_BLOCK=8 tokens of one head per
2638// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2639// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2640// arithmetic.
2641//
2642// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
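// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
// to attention_batch / attention_flash_batch otherwise.
// The kernel also assumes head_dim is a multiple of 32: Phase 4 splits the
// output columns evenly across the 4 simdgroups in 8-wide MMA tiles, so any
// remainder columns would never be accumulated.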
2645//
2646// Online softmax recurrence is identical to attention_flash_batch but
2647// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
2648constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
2649constant constexpr uint FLASH_MMA_K_BLOCK = 32;
2650constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;
2651
2652kernel void attention_mma_flash_batch(
2653 device const float* q_batch [[buffer(0)]],
2654 device const half* k_cache [[buffer(1)]],
2655 device const half* v_cache [[buffer(2)]],
2656 device float* output_batch [[buffer(3)]],
2657 constant uint& M [[buffer(4)]],
2658 constant uint& base_pos [[buffer(5)]],
2659 constant uint& num_heads [[buffer(6)]],
2660 constant uint& num_kv_heads [[buffer(7)]],
2661 constant uint& head_dim [[buffer(8)]],
2662 constant uint& q_stride [[buffer(9)]],
2663 uint2 tgid [[threadgroup_position_in_grid]],
2664 uint tid [[thread_index_in_threadgroup]],
2665 uint simd_lane [[thread_index_in_simdgroup]],
2666 uint simd_id [[simdgroup_index_in_threadgroup]])
2667{
2668 uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2669 uint head = tgid.y;
2670 if (q_block_start >= M) return;
2671 uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2672
2673 uint kv_head = head / (num_heads / num_kv_heads);
2674 // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2675 // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2676 uint seq_len = base_pos + q_block_start + q_valid;
2677 float scale = fast::rsqrt(float(head_dim));
2678
2679 uint kv_stride = num_kv_heads * head_dim;
2680 uint kv_base_off = kv_head * head_dim;
2681
2682 // ── Threadgroup memory ──
2683 // q_sh: [Q_BLOCK, head_dim] half — Q tile, loaded once
2684 // k_sh: [K_BLOCK, head_dim] half — K tile, refreshed per kv_base iter
2685 // v_sh: [K_BLOCK, head_dim] half — V tile, refreshed per kv_base iter
2686 // s_sh: [Q_BLOCK, K_BLOCK] float — raw Q·K^T scores, then scaled+masked
2687 // p_sh: [Q_BLOCK, K_BLOCK] half — softmax probabilities (for P·V MMA)
2688 // o_sh: [Q_BLOCK, head_dim] float — running output accumulator
2689 // m_sh: [Q_BLOCK] float — running max per Q row
2690 // l_sh: [Q_BLOCK] float — running softmax denominator per Q row
2691 // scratch: 4*Q_BLOCK floats — per-row reduction scratch
2692 threadgroup half q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2693 threadgroup half k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2694 threadgroup half v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2695 threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2696 threadgroup half p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2697 threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2698 threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2699 threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2700 threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2701
2702 // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2703 uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2704 for (uint i = tid; i < qblock_elems; i += 128) {
2705 uint q = i / head_dim;
2706 uint d = i % head_dim;
2707 if (q < q_valid) {
2708 uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2709 q_sh[q * head_dim + d] = half(q_batch[q_off]);
2710 } else {
2711 q_sh[q * head_dim + d] = half(0);
2712 }
2713 o_sh[q * head_dim + d] = 0.0f;
2714 }
2715 if (tid < FLASH_MMA_Q_BLOCK) {
2716 m_sh[tid] = -INFINITY;
2717 l_sh[tid] = 0.0f;
2718 }
2719 threadgroup_barrier(mem_flags::mem_threadgroup);
2720
2721 // ── Stream K/V in K_BLOCK chunks ──
2722 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2723 uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2724
2725 // Load K and V tile into TG memory (as half).
2726 uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2727 for (uint i = tid; i < kv_tile_elems; i += 128) {
2728 uint k_pos = i / head_dim;
2729 uint d = i % head_dim;
2730 if (k_pos < tile_n) {
2731 uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2732 k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2733 v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2734 } else {
2735 k_sh[k_pos * head_dim + d] = half(0);
2736 v_sh[k_pos * head_dim + d] = half(0);
2737 }
2738 }
2739 threadgroup_barrier(mem_flags::mem_threadgroup);
2740
2741 // ── Phase 1: S = Q @ K^T via MMA ──
2742 // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2743 // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2744 // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2745 {
2746 simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2747 uint dim_chunks = head_dim / 8;
2748 for (uint dc = 0; dc < dim_chunks; dc++) {
2749 simdgroup_matrix<half, 8, 8> A, B;
2750 // A = Q[0:8, dc*8 : dc*8+8] (rows of Q, no transpose)
2751 simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2752 // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2753 // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2754 // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2755 // transpose=true, which places it in the register as K^T of that sub-block.
2756 simdgroup_load(B,
2757 k_sh + (simd_id * 8) * head_dim + dc * 8,
2758 head_dim, ulong2(0, 0), true);
2759 simdgroup_multiply_accumulate(C, A, B, C);
2760 }
2761 // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2762 simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2763 }
2764 threadgroup_barrier(mem_flags::mem_threadgroup);
2765
2766 // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2767 // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2768 uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2769 for (uint i = tid; i < s_elems; i += 128) {
2770 uint q = i / FLASH_MMA_K_BLOCK;
2771 uint k = i % FLASH_MMA_K_BLOCK;
2772 uint global_q = q_block_start + q;
2773 uint global_kv = kv_base + k;
2774 bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2775 s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2776 }
2777 threadgroup_barrier(mem_flags::mem_threadgroup);
2778
2779 // ── Phase 2b: per-row max via simdgroup reduction ──
2780 // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2781 // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2782 {
2783 uint row_base = simd_id * 2;
2784 for (uint qr = 0; qr < 2; qr++) {
2785 uint q = row_base + qr;
2786 float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2787 float row_max = simd_max(my);
2788 if (simd_lane == 0) {
2789 scratch[q] = row_max; // tile_max[q]
2790 }
2791 }
2792 }
2793 threadgroup_barrier(mem_flags::mem_threadgroup);
2794
2795 // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2796 if (tid < FLASH_MMA_Q_BLOCK) {
2797 uint q = tid;
2798 float m_old = m_sh[q];
2799 float tile_max = scratch[q];
2800 float m_new = max(m_old, tile_max);
2801 float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2802 m_sh[q] = m_new;
2803 l_sh[q] = l_sh[q] * alpha;
2804 // scratch[q] = m_new (for phase 2d)
2805 // scratch[Q_BLOCK + q] = alpha (for phase 3)
2806 scratch[q] = m_new;
2807 scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2808 }
2809 threadgroup_barrier(mem_flags::mem_threadgroup);
2810
2811 // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2812 {
2813 uint row_base = simd_id * 2;
2814 for (uint qr = 0; qr < 2; qr++) {
2815 uint q = row_base + qr;
2816 float m_new = scratch[q];
2817 float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2818 p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2819 float row_sum = simd_sum(p);
2820 if (simd_lane == 0) {
2821 scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2822 }
2823 }
2824 }
2825 threadgroup_barrier(mem_flags::mem_threadgroup);
2826
2827 // ── Phase 2e: l_sh += tile_sum ──
2828 if (tid < FLASH_MMA_Q_BLOCK) {
2829 uint q = tid;
2830 l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2831 }
2832
2833 // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2834 threadgroup_barrier(mem_flags::mem_threadgroup);
2835 for (uint i = tid; i < qblock_elems; i += 128) {
2836 uint q = i / head_dim;
2837 float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2838 o_sh[i] *= alpha;
2839 }
2840 threadgroup_barrier(mem_flags::mem_threadgroup);
2841
2842 // ── Phase 4: O += P @ V via MMA ──
2843 // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2844 // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2845 // For head_dim=64: 16 dims/sg = 2 tiles. head_dim=128: 32 dims/sg = 4 tiles.
2846 {
2847 uint dims_per_sg = head_dim / 4; // 16 or 32
2848 uint tiles_per_sg = dims_per_sg / 8; // 2 or 4
2849 uint sg_d_base = simd_id * dims_per_sg;
2850 for (uint t = 0; t < tiles_per_sg; t++) {
2851 uint d_base = sg_d_base + t * 8;
2852 simdgroup_matrix<float, 8, 8> O_acc;
2853 simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2854 uint k_chunks = FLASH_MMA_K_BLOCK / 8; // 4
2855 for (uint kc = 0; kc < k_chunks; kc++) {
2856 simdgroup_matrix<half, 8, 8> A, B;
2857 simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2858 ulong2(0, 0), false);
2859 simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2860 ulong2(0, 0), false);
2861 simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2862 }
2863 simdgroup_store(O_acc, o_sh + d_base, head_dim);
2864 }
2865 }
2866 threadgroup_barrier(mem_flags::mem_threadgroup);
2867 }
2868
2869 // ── Finalize: O /= l, write to output ──
2870 for (uint i = tid; i < qblock_elems; i += 128) {
2871 uint q = i / head_dim;
2872 uint d = i % head_dim;
2873 if (q < q_valid) {
2874 float l = l_sh[q];
2875 float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2876 uint token_idx = q_block_start + q;
2877 uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2878 output_batch[out_off] = o_sh[i] * inv_l;
2879 }
2880 }
2881}
2882
2883// ── rope_qk_batch ─────────────────────────────────────────────────────
2884// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2885// launch + memory barrier per layer. Both Q and K live in the same
2886// qkv_data buffer at different offsets within each token's row.
2887// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2888kernel void rope_qk_batch(
2889 device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
2890 constant uint& M [[buffer(1)]], // num tokens
2891 constant uint& base_pos [[buffer(2)]], // starting position
2892 constant uint& num_q_heads [[buffer(3)]],
2893 constant uint& num_kv_heads [[buffer(4)]],
2894 constant uint& head_dim [[buffer(5)]],
2895 constant uint& qkv_stride [[buffer(6)]], // floats per row
2896 constant float& theta [[buffer(7)]],
2897 uint id [[thread_position_in_grid]])
2898{
2899 uint half_dim = head_dim / 2;
2900 uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2901 uint token = id / total_pairs;
2902 uint pair = id % total_pairs;
2903 if (token >= M) return;
2904
2905 uint pos = base_pos + token;
2906 uint q_pairs = num_q_heads * half_dim;
2907
2908 uint h, i, offset;
2909 if (pair < q_pairs) {
2910 // Q head
2911 h = pair / half_dim;
2912 i = pair % half_dim;
2913 offset = token * qkv_stride + h * head_dim + i * 2;
2914 } else {
2915 // K head
2916 uint kp = pair - q_pairs;
2917 h = kp / half_dim;
2918 i = kp % half_dim;
2919 uint k_start = num_q_heads * head_dim;
2920 offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2921 }
2922
2923 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2924 float angle = float(pos) * freq;
2925 float cos_val = cos(angle);
2926 float sin_val = sin(angle);
2927
2928 float x0 = qkv_data[offset];
2929 float x1 = qkv_data[offset + 1];
2930 qkv_data[offset] = x0 * cos_val - x1 * sin_val;
2931 qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2932}
2933
2934// ── copy_kv_both_batch ────────────────────────────────────────────────
2935// Fused K+V cache copy in a single dispatch: copies both K and V from
2936// the strided batch QKV buffer to their respective KV cache buffers.
2937// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2938kernel void copy_kv_both_batch(
2939 device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride] f32
2940 device half* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim] f16
2941 device half* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim] f16
2942 constant uint& M [[buffer(3)]], // num tokens in batch
2943 constant uint& kv_dim [[buffer(4)]], // elements per KV vector
2944 constant uint& base_pos [[buffer(5)]], // starting position in cache
2945 constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
2946 constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
2947 constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
2948 uint id [[thread_position_in_grid]])
2949{
2950 // Total elements = M * kv_dim * 2 (K + V)
2951 uint total_kv = M * kv_dim;
2952 if (id >= total_kv * 2) return;
2953
2954 uint is_v = id / total_kv; // 0 = K, 1 = V
2955 uint local_id = id % total_kv;
2956 uint token = local_id / kv_dim;
2957 uint d = local_id % kv_dim;
2958
2959 uint dst_off = (base_pos + token) * kv_dim + d;
2960 uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2961
2962 if (is_v) {
2963 v_dst[dst_off] = half(src[src_off]);
2964 } else {
2965 k_dst[dst_off] = half(src[src_off]);
2966 }
2967}
2968"#
2969 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2970 .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2971}
2972
2973fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2978 let mut code = String::with_capacity(48 * 1024);
2979 emit_model_header(&mut code, config)?;
2980 emit_metal_model_struct(&mut code, config)?;
2981 emit_layer_buffers_struct(&mut code, config)?;
2982 emit_metal_model_impl(&mut code, config)?;
2983 emit_helper_functions(&mut code)?;
2984 Ok(code)
2985}
2986
2987fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2988 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2989 writeln!(
2990 code,
2991 "//! Model: {} ({} layers, hidden={})",
2992 config.architecture, config.num_layers, config.hidden_size
2993 )?;
2994 writeln!(code, "//!")?;
2995 writeln!(
2996 code,
2997 "//! Uses native Metal compute pipelines via the metal crate."
2998 )?;
2999 writeln!(code)?;
3000 writeln!(code, "#![allow(dead_code)]")?;
3001 writeln!(code)?;
3002 writeln!(code, "use metal::*;")?;
3003 writeln!(code, "#[allow(unused_imports)]")?;
3004 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
3005 writeln!(code, "use std::mem;")?;
3006 writeln!(code)?;
3007
3008 writeln!(
3010 code,
3011 "// ── Model constants ──────────────────────────────────"
3012 )?;
3013 writeln!(
3014 code,
3015 "pub const HIDDEN_SIZE: usize = {};",
3016 config.hidden_size
3017 )?;
3018 writeln!(
3019 code,
3020 "pub const INTERMEDIATE_SIZE: usize = {};",
3021 config.intermediate_size
3022 )?;
3023 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
3024 writeln!(
3025 code,
3026 "pub const NUM_HEADS: usize = {};",
3027 config.num_attention_heads
3028 )?;
3029 writeln!(
3030 code,
3031 "pub const NUM_KV_HEADS: usize = {};",
3032 config.num_kv_heads
3033 )?;
3034 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3035 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3036 let effective_seq_len = config.max_seq_len.min(4096);
3037 writeln!(
3038 code,
3039 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
3040 effective_seq_len, config.max_seq_len
3041 )?;
3042 writeln!(
3043 code,
3044 "pub const RMS_NORM_EPS: f32 = {:e};",
3045 config.rms_norm_eps
3046 )?;
3047 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3048 writeln!(
3049 code,
3050 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3051 )?;
3052 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3053 writeln!(code)?;
3054
3055 Ok(())
3056}
3057
3058fn emit_metal_model_struct(
3059 code: &mut String,
3060 config: &ModelConfig,
3061) -> Result<(), MetalCodegenError> {
3062 writeln!(
3063 code,
3064 "// ── MetalModel ──────────────────────────────────────────"
3065 )?;
3066 writeln!(code)?;
3067 writeln!(
3068 code,
3069 "/// Metal-accelerated transformer model for Apple Silicon."
3070 )?;
3071 writeln!(code, "///")?;
3072 writeln!(
3073 code,
3074 "/// Uses unified memory for zero-copy weight access and native Metal"
3075 )?;
3076 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3077 writeln!(code, "pub struct MetalModel {{")?;
3078 writeln!(code, " device: Device,")?;
3079 writeln!(code, " queue: CommandQueue,")?;
3080 writeln!(code)?;
3081 writeln!(code, " // ── Compute pipelines ──")?;
3082 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
3083 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
3084 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
3085 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
3086 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
3087 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
3088 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
3089 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
3090 writeln!(code, " add_pipeline: ComputePipelineState,")?;
3091 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
3092 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
3093 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
3094 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
3095 writeln!(
3096 code,
3097 " copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3098 )?;
3099 writeln!(code)?;
3100 writeln!(code, " // ── Batched prefill pipelines ──")?;
3101 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
3102 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
3103 writeln!(
3104 code,
3105 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3106 )?;
3107 writeln!(code, " matmul_q8_mma_pipeline: ComputePipelineState,")?;
3108 writeln!(code, " matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3109 writeln!(
3110 code,
3111 " matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3112 )?;
3113 writeln!(
3114 code,
3115 " matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3116 )?;
3117 writeln!(
3118 code,
3119 " matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3120 )?;
3121 if config.qkv_bias {
3122 writeln!(code, " add_bias_batch_pipeline: ComputePipelineState,")?;
3123 }
3124 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
3125 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
3126 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
3127 writeln!(
3128 code,
3129 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
3130 )?;
3131 writeln!(
3132 code,
3133 " add_inplace_batch_pipeline: ComputePipelineState,"
3134 )?;
3135 writeln!(
3136 code,
3137 " copy_embedding_batch_pipeline: ComputePipelineState,"
3138 )?;
3139 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
3140 writeln!(
3141 code,
3142 " attention_flash_batch_pipeline: ComputePipelineState,"
3143 )?;
3144 writeln!(
3145 code,
3146 " attention_mma_flash_batch_pipeline: ComputePipelineState,"
3147 )?;
3148 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
3149 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
3150 writeln!(
3151 code,
3152 " copy_kv_both_batch_pipeline: ComputePipelineState,"
3153 )?;
3154 writeln!(code)?;
3155 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
3156 writeln!(
3157 code,
3158 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3159 )?;
3160 writeln!(code, " embed_buf: Buffer,")?;
3161 writeln!(code)?;
3162 writeln!(code, " /// Per-layer weight buffers")?;
3163 writeln!(code, " layers: Vec<LayerBuffers>,")?;
3164 writeln!(code)?;
3165 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
3166 writeln!(code, " norm_buf: Buffer,")?;
3167 writeln!(code)?;
3168 writeln!(
3169 code,
3170 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3171 )?;
3172 writeln!(code, " lm_head_buf: Buffer,")?;
3173 writeln!(code)?;
3174 writeln!(
3175 code,
3176 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
3177 )?;
3178 writeln!(code, " hidden_buf: Buffer,")?;
3179 writeln!(code, " residual_buf: Buffer,")?;
3180 writeln!(code, " normed_buf: Buffer,")?;
3181 writeln!(
3182 code,
3183 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3184 )?;
3185 writeln!(code, " qkv_buf: Buffer,")?;
3186 writeln!(code, " attn_out_buf: Buffer,")?;
3187 writeln!(code, " attn_proj_buf: Buffer,")?;
3188 writeln!(
3189 code,
3190 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3191 )?;
3192 writeln!(code, " gate_up_buf: Buffer,")?;
3193 writeln!(code, " ffn_hidden_buf: Buffer,")?;
3194 writeln!(code, " ffn_out_buf: Buffer,")?;
3195 writeln!(code, " add_tmp_buf: Buffer,")?;
3196 writeln!(code, " logits_buf: Buffer,")?;
3197 writeln!(code)?;
3198 writeln!(code, " // ── Batched prefill working buffers ──")?;
3199 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3200 writeln!(code, " batch_hidden_buf: Buffer,")?;
3201 writeln!(
3202 code,
3203 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3204 )?;
3205 writeln!(code, " batch_residual_buf: Buffer,")?;
3206 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3207 writeln!(code, " batch_qkv_buf: Buffer,")?;
3208 writeln!(
3209 code,
3210 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3211 )?;
3212 writeln!(code, " batch_attn_out_buf: Buffer,")?;
3213 writeln!(
3214 code,
3215 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3216 )?;
3217 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
3218 writeln!(
3219 code,
3220 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3221 )?;
3222 writeln!(code, " batch_gate_up_buf: Buffer,")?;
3223 writeln!(
3224 code,
3225 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3226 )?;
3227 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
3228 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3229 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
3230 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
3231 writeln!(code, " batch_tokens_buf: Buffer,")?;
3232 writeln!(code, " /// Positions buffer for batch RoPE")?;
3233 writeln!(code, " batch_positions_buf: Buffer,")?;
3234 writeln!(code)?;
3235 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
3236 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
3237 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
3238 writeln!(code)?;
3239 writeln!(code, " // ── Inference state ──")?;
3240 writeln!(code, " pos: usize,")?;
3241 writeln!(code)?;
3242 writeln!(
3243 code,
3244 " /// Previous command buffer for double-buffered prefill."
3245 )?;
3246 writeln!(
3247 code,
3248 " /// While the GPU executes token N, the CPU can encode token N+1."
3249 )?;
3250 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
3251 writeln!(code, "}}")?;
3252 writeln!(code)?;
3253
3254 Ok(())
3255}
3256
3257fn emit_layer_buffers_struct(
3258 code: &mut String,
3259 config: &ModelConfig,
3260) -> Result<(), MetalCodegenError> {
3261 writeln!(
3262 code,
3263 "/// Per-layer weight buffers for attention and FFN projections."
3264 )?;
3265 writeln!(code, "struct LayerBuffers {{")?;
3266 writeln!(code, " attn_norm: Buffer,")?;
3267 writeln!(
3268 code,
3269 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3270 )?;
3271 writeln!(code, " qkv_weight: Buffer,")?;
3272 if config.qkv_bias {
3273 writeln!(
3274 code,
3275 " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3276 )?;
3277 writeln!(code, " qkv_bias: Buffer,")?;
3278 }
3279 writeln!(code, " o_weight: Buffer,")?;
3280 writeln!(code, " ffn_norm: Buffer,")?;
3281 writeln!(
3282 code,
3283 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3284 )?;
3285 writeln!(code, " gate_up_weight: Buffer,")?;
3286 writeln!(code, " down_weight: Buffer,")?;
3287 writeln!(code, "}}")?;
3288 writeln!(code)?;
3289
3290 Ok(())
3291}
3292
3293fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3294 let hidden = config.hidden_size;
3295 let intermediate = config.intermediate_size;
3296 let _num_layers = config.num_layers;
3297 let _num_heads = config.num_attention_heads;
3298 let num_kv_heads = config.num_kv_heads;
3299 let head_dim = config.head_dim;
3300 let vocab = config.vocab_size;
3301 let effective_seq_len = config.max_seq_len.min(4096);
3302 let is_q8 = config.dtype == DType::Q8_0;
3303 let is_q4 = config.dtype == DType::Q4_0;
3304 let kv_dim = num_kv_heads * head_dim;
3305
3306 writeln!(code, "impl MetalModel {{")?;
3307
3308 writeln!(
3310 code,
3311 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3312 )?;
3313 writeln!(code, " ///")?;
3314 writeln!(
3315 code,
3316 " /// `weights` is the raw weight blob produced by `forge export-weights`."
3317 )?;
3318 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
3319 writeln!(
3320 code,
3321 " let device = Device::system_default().expect(\"no Metal device found\");"
3322 )?;
3323 writeln!(code, " let queue = device.new_command_queue();")?;
3324 writeln!(code)?;
3325
3326 writeln!(
3328 code,
3329 " // Compile Metal shaders from embedded source"
3330 )?;
3331 writeln!(
3332 code,
3333 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
3334 )?;
3335 writeln!(
3336 code,
3337 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3338 )?;
3339 writeln!(
3340 code,
3341 " .expect(\"failed to compile Metal shaders\");"
3342 )?;
3343 writeln!(code)?;
3344
3345 writeln!(code, " // Create compute pipelines")?;
3347 for (var, fn_name) in [
3348 ("matmul_pipeline", "matmul_vec"),
3349 ("matmul_q8_pipeline", "matmul_vec_q8"),
3350 ("matmul_q4_pipeline", "matmul_vec_q4"),
3351 ("rms_norm_pipeline", "rms_norm"),
3352 ("rope_pipeline", "rope"),
3353 ("softmax_pipeline", "softmax"),
3354 ("silu_mul_pipeline", "silu_mul"),
3355 ("silu_mul_fused_pipeline", "silu_mul_fused"),
3356 ("add_pipeline", "elementwise_add"),
3357 ("attention_pipeline", "attention"),
3358 ("add_inplace_pipeline", "add_inplace"),
3359 ("copy_pipeline", "copy_buffer"),
3360 ("copy_offset_pipeline", "copy_offset"),
3361 ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3362 ("matmul_batch_pipeline", "matmul_vec_batch"),
3363 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3364 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3365 ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3366 ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3367 ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3368 ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3369 ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3370 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3371 ("rms_norm_batch_pipeline", "rms_norm_batch"),
3372 ("rope_batch_pipeline", "rope_batch"),
3373 ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3374 ("add_inplace_batch_pipeline", "add_inplace_batch"),
3375 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3376 ("attention_batch_pipeline", "attention_batch"),
3377 ("attention_flash_batch_pipeline", "attention_flash_batch"),
3378 (
3379 "attention_mma_flash_batch_pipeline",
3380 "attention_mma_flash_batch",
3381 ),
3382 ("copy_kv_batch_pipeline", "copy_kv_batch"),
3383 ("rope_qk_batch_pipeline", "rope_qk_batch"),
3384 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3385 ] {
3386 writeln!(
3387 code,
3388 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3389 )?;
3390 }
3391 if config.qkv_bias {
3392 writeln!(
3393 code,
3394 " let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3395 )?;
3396 }
3397 writeln!(code)?;
3398
3399 writeln!(
3401 code,
3402 " // Load weights into Metal shared-memory buffers"
3403 )?;
3404 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
3405 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3406 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
3407 writeln!(code)?;
3408 writeln!(
3409 code,
3410 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
3411 )?;
3412 writeln!(code)?;
3413 writeln!(
3414 code,
3415 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3416 )?;
3417 writeln!(
3418 code,
3419 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3420 )?;
3421 writeln!(code, " let byte_len = n * f32_size;")?;
3422 writeln!(code, " let cur = cursor.get();")?;
3423 writeln!(
3424 code,
3425 " let data = &weights[cur..cur + byte_len];"
3426 )?;
3427 writeln!(code, " cursor.set(cur + byte_len);")?;
3428 writeln!(code, " device.new_buffer_with_data(")?;
3429 writeln!(code, " data.as_ptr() as *const _,")?;
3430 writeln!(code, " byte_len as u64,")?;
3431 writeln!(
3432 code,
3433 " MTLResourceOptions::StorageModeShared,"
3434 )?;
3435 writeln!(code, " )")?;
3436 writeln!(code, " }};")?;
3437 writeln!(code)?;
3438
3439 if is_q8 {
3440 writeln!(
3445 code,
3446 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3447 )?;
3448 writeln!(
3449 code,
3450 " // as raw bytes into a Metal buffer (no dequantization)."
3451 )?;
3452 writeln!(
3453 code,
3454 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3455 )?;
3456 writeln!(
3457 code,
3458 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3459 )?;
3460 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3461 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3462 writeln!(code, " let total_raw = rows * row_bytes;")?;
3463 writeln!(code, " let cur = cursor.get();")?;
3464 writeln!(
3465 code,
3466 " let data = &weights[cur..cur + total_raw];"
3467 )?;
3468 writeln!(code, " cursor.set(cur + total_raw);")?;
3469 writeln!(code, " device.new_buffer_with_data(")?;
3470 writeln!(code, " data.as_ptr() as *const _,")?;
3471 writeln!(code, " total_raw as u64,")?;
3472 writeln!(
3473 code,
3474 " MTLResourceOptions::StorageModeShared,"
3475 )?;
3476 writeln!(code, " )")?;
3477 writeln!(code, " }};")?;
3478 writeln!(code)?;
3479 writeln!(
3480 code,
3481 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3482 )?;
3483 writeln!(
3484 code,
3485 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3486 )?;
3487 writeln!(
3488 code,
3489 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3490 )?;
3491 writeln!(
3492 code,
3493 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3494 )?;
3495 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3496 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3497 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3498 writeln!(code, " let cur = cursor.get();")?;
3499 writeln!(
3500 code,
3501 " let data = &weights[cur..cur + total_raw];"
3502 )?;
3503 writeln!(code, " cursor.set(cur + total_raw);")?;
3504 writeln!(code, " device.new_buffer_with_data(")?;
3505 writeln!(code, " data.as_ptr() as *const _,")?;
3506 writeln!(code, " total_raw as u64,")?;
3507 writeln!(
3508 code,
3509 " MTLResourceOptions::StorageModeShared,"
3510 )?;
3511 writeln!(code, " )")?;
3512 writeln!(code, " }};")?;
3513 writeln!(code)?;
3514 }
3515
3516 if is_q4 {
3517 writeln!(
3522 code,
3523 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3524 )?;
3525 writeln!(
3526 code,
3527 " // as raw bytes into a Metal buffer (no dequantization)."
3528 )?;
3529 writeln!(
3530 code,
3531 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3532 )?;
3533 writeln!(
3534 code,
3535 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3536 )?;
3537 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3538 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3539 writeln!(code, " let total_raw = rows * row_bytes;")?;
3540 writeln!(code, " let cur = cursor.get();")?;
3541 writeln!(
3542 code,
3543 " let data = &weights[cur..cur + total_raw];"
3544 )?;
3545 writeln!(code, " cursor.set(cur + total_raw);")?;
3546 writeln!(code, " device.new_buffer_with_data(")?;
3547 writeln!(code, " data.as_ptr() as *const _,")?;
3548 writeln!(code, " total_raw as u64,")?;
3549 writeln!(
3550 code,
3551 " MTLResourceOptions::StorageModeShared,"
3552 )?;
3553 writeln!(code, " )")?;
3554 writeln!(code, " }};")?;
3555 writeln!(code)?;
3556 writeln!(
3557 code,
3558 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3559 )?;
3560 writeln!(
3561 code,
3562 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3563 )?;
3564 writeln!(
3565 code,
3566 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3567 )?;
3568 writeln!(
3569 code,
3570 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3571 )?;
3572 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3573 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3574 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3575 writeln!(code, " let cur = cursor.get();")?;
3576 writeln!(
3577 code,
3578 " let data = &weights[cur..cur + total_raw];"
3579 )?;
3580 writeln!(code, " cursor.set(cur + total_raw);")?;
3581 writeln!(code, " device.new_buffer_with_data(")?;
3582 writeln!(code, " data.as_ptr() as *const _,")?;
3583 writeln!(code, " total_raw as u64,")?;
3584 writeln!(
3585 code,
3586 " MTLResourceOptions::StorageModeShared,"
3587 )?;
3588 writeln!(code, " )")?;
3589 writeln!(code, " }};")?;
3590 writeln!(code)?;
3591 }
3592
3593 writeln!(
3594 code,
3595 " let embed_buf = next_f32_buffer(&device, embed_elems);"
3596 )?;
3597 writeln!(code)?;
3598
3599 writeln!(
3601 code,
3602 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3603 )?;
3604 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
3605
3606 writeln!(
3608 code,
3609 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
3610 )?;
3611
3612 let qkv_rows = hidden + 2 * kv_dim;
3613 if is_q8 {
3614 writeln!(
3616 code,
3617 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3618 )?;
3619 if config.qkv_bias {
3620 writeln!(
3621 code,
3622 " // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3623 )?;
3624 writeln!(
3625 code,
3626 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3627 )?;
3628 }
3629 writeln!(
3630 code,
3631 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3632 )?;
3633 } else if is_q4 {
3634 writeln!(
3635 code,
3636 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3637 )?;
3638 if config.qkv_bias {
3639 writeln!(
3640 code,
3641 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3642 )?;
3643 }
3644 writeln!(
3645 code,
3646 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3647 )?;
3648 } else {
3649 writeln!(
3650 code,
3651 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3652 )?;
3653 if config.qkv_bias {
3654 writeln!(
3655 code,
3656 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3657 )?;
3658 }
3659 writeln!(
3660 code,
3661 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3662 )?;
3663 }
3664
3665 writeln!(
3667 code,
3668 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3669 )?;
3670
3671 let gate_up_rows = 2 * intermediate;
3672 if is_q8 {
3673 writeln!(
3675 code,
3676 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3677 )?;
3678 writeln!(
3679 code,
3680 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3681 )?;
3682 } else if is_q4 {
3683 writeln!(
3685 code,
3686 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3687 )?;
3688 writeln!(
3689 code,
3690 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3691 )?;
3692 } else {
3693 writeln!(
3695 code,
3696 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3697 )?;
3698 writeln!(
3699 code,
3700 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3701 )?;
3702 }
3703
3704 writeln!(code, " layers.push(LayerBuffers {{")?;
3705 writeln!(code, " attn_norm,")?;
3706 writeln!(code, " qkv_weight,")?;
3707 if config.qkv_bias {
3708 writeln!(code, " qkv_bias,")?;
3709 }
3710 writeln!(code, " o_weight,")?;
3711 writeln!(code, " ffn_norm,")?;
3712 writeln!(code, " gate_up_weight,")?;
3713 writeln!(code, " down_weight,")?;
3714 writeln!(code, " }});")?;
3715 writeln!(code, " }}")?;
3716 writeln!(code)?;
3717
3718 writeln!(
3720 code,
3721 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
3722 )?;
3723 writeln!(code)?;
3724
3725 if is_q8 {
3727 writeln!(
3728 code,
3729 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3730 )?;
3731 } else if is_q4 {
3732 writeln!(
3733 code,
3734 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3735 )?;
3736 } else {
3737 writeln!(
3738 code,
3739 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3740 )?;
3741 }
3742 writeln!(code)?;
3743
3744 let hidden_bytes = hidden * 4;
3746 let _kv_dim_bytes = kv_dim * 4;
3747 let intermediate_bytes = intermediate * 4;
3748 let vocab_bytes = vocab * 4;
3749 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3752
3753 writeln!(
3754 code,
3755 " // Allocate working buffers (shared memory for zero-copy)"
3756 )?;
3757 writeln!(
3758 code,
3759 " let opts = MTLResourceOptions::StorageModeShared;"
3760 )?;
3761 writeln!(
3762 code,
3763 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3764 )?;
3765 writeln!(
3766 code,
3767 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3768 )?;
3769 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3770 writeln!(
3771 code,
3772 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3773 )?;
3774 writeln!(
3775 code,
3776 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3777 )?;
3778 writeln!(
3779 code,
3780 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3781 )?;
3782 writeln!(
3783 code,
3784 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3785 )?;
3786 writeln!(
3787 code,
3788 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3789 )?;
3790 let gate_up_buf_bytes = 2 * intermediate * 4;
3791 writeln!(
3792 code,
3793 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3794 )?;
3795 writeln!(
3796 code,
3797 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3798 )?;
3799 writeln!(
3800 code,
3801 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3802 )?;
3803 writeln!(
3804 code,
3805 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3806 )?;
3807 writeln!(
3808 code,
3809 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3810 )?;
3811 writeln!(
3812 code,
3813 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3814 )?;
3815 writeln!(code)?;
3816
3817 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3820 let batch_gate_up_bytes = 2 * intermediate * 4;
3821 let batch_intermediate_bytes = intermediate * 4;
3822 writeln!(
3823 code,
3824 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3825 )?;
3826 writeln!(
3827 code,
3828 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3829 )?;
3830 writeln!(
3831 code,
3832 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3833 )?;
3834 writeln!(
3835 code,
3836 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3837 )?;
3838 writeln!(
3839 code,
3840 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3841 )?;
3842 writeln!(
3843 code,
3844 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3845 )?;
3846 writeln!(
3847 code,
3848 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3849 )?;
3850 writeln!(
3851 code,
3852 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3853 )?;
3854 writeln!(
3855 code,
3856 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3857 )?;
3858 writeln!(
3859 code,
3860 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3861 )?;
3862 writeln!(
3863 code,
3864 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3865 )?;
3866 writeln!(code)?;
3867
3868 writeln!(code, " // KV cache buffers (per-layer)")?;
3870 writeln!(
3871 code,
3872 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3873 )?;
3874 writeln!(
3875 code,
3876 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3877 )?;
3878 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
3879 writeln!(
3880 code,
3881 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3882 )?;
3883 writeln!(
3884 code,
3885 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3886 )?;
3887 writeln!(code, " }}")?;
3888 writeln!(code)?;
3889
3890 writeln!(code, " Self {{")?;
3891 writeln!(code, " device,")?;
3892 writeln!(code, " queue,")?;
3893 writeln!(code, " matmul_pipeline,")?;
3894 writeln!(code, " matmul_q8_pipeline,")?;
3895 writeln!(code, " matmul_q4_pipeline,")?;
3896 writeln!(code, " rms_norm_pipeline,")?;
3897 writeln!(code, " rope_pipeline,")?;
3898 writeln!(code, " softmax_pipeline,")?;
3899 writeln!(code, " silu_mul_pipeline,")?;
3900 writeln!(code, " silu_mul_fused_pipeline,")?;
3901 writeln!(code, " add_pipeline,")?;
3902 writeln!(code, " attention_pipeline,")?;
3903 writeln!(code, " add_inplace_pipeline,")?;
3904 writeln!(code, " copy_pipeline,")?;
3905 writeln!(code, " copy_offset_pipeline,")?;
3906 writeln!(code, " copy_f32_to_f16_offset_pipeline,")?;
3907 writeln!(code, " matmul_batch_pipeline,")?;
3908 writeln!(code, " matmul_q8_batch_pipeline,")?;
3909 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
3910 writeln!(code, " matmul_q8_mma_pipeline,")?;
3911 writeln!(code, " matmul_q8_mma32_pipeline,")?;
3912 writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
3913 writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
3914 writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
3915 if config.qkv_bias {
3916 writeln!(code, " add_bias_batch_pipeline,")?;
3917 }
3918 writeln!(code, " matmul_q4_batch_pipeline,")?;
3919 writeln!(code, " rms_norm_batch_pipeline,")?;
3920 writeln!(code, " rope_batch_pipeline,")?;
3921 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
3922 writeln!(code, " add_inplace_batch_pipeline,")?;
3923 writeln!(code, " copy_embedding_batch_pipeline,")?;
3924 writeln!(code, " attention_batch_pipeline,")?;
3925 writeln!(code, " attention_flash_batch_pipeline,")?;
3926 writeln!(code, " attention_mma_flash_batch_pipeline,")?;
3927 writeln!(code, " copy_kv_batch_pipeline,")?;
3928 writeln!(code, " rope_qk_batch_pipeline,")?;
3929 writeln!(code, " copy_kv_both_batch_pipeline,")?;
3930 writeln!(code, " embed_buf,")?;
3931 writeln!(code, " layers,")?;
3932 writeln!(code, " norm_buf,")?;
3933 writeln!(code, " lm_head_buf,")?;
3934 writeln!(code, " hidden_buf,")?;
3935 writeln!(code, " residual_buf,")?;
3936 writeln!(code, " normed_buf,")?;
3937 writeln!(code, " qkv_buf,")?;
3938 writeln!(code, " attn_out_buf,")?;
3939 writeln!(code, " attn_proj_buf,")?;
3940 writeln!(code, " gate_up_buf,")?;
3941 writeln!(code, " ffn_hidden_buf,")?;
3942 writeln!(code, " ffn_out_buf,")?;
3943 writeln!(code, " add_tmp_buf,")?;
3944 writeln!(code, " logits_buf,")?;
3945 writeln!(code, " batch_hidden_buf,")?;
3946 writeln!(code, " batch_residual_buf,")?;
3947 writeln!(code, " batch_qkv_buf,")?;
3948 writeln!(code, " batch_attn_out_buf,")?;
3949 writeln!(code, " batch_attn_proj_buf,")?;
3950 writeln!(code, " batch_gate_up_buf,")?;
3951 writeln!(code, " batch_ffn_hidden_buf,")?;
3952 writeln!(code, " batch_ffn_out_buf,")?;
3953 writeln!(code, " batch_tokens_buf,")?;
3954 writeln!(code, " batch_positions_buf,")?;
3955 writeln!(code, " k_cache,")?;
3956 writeln!(code, " v_cache,")?;
3957 writeln!(code, " pos: 0,")?;
3958 writeln!(code, " prev_cmd: None,")?;
3959 writeln!(code, " }}")?;
3960 writeln!(code, " }}")?;
3961 writeln!(code)?;
3962
3963 writeln!(
3965 code,
3966 " /// Run the forward pass for a single token at the current position."
3967 )?;
3968 writeln!(code, " ///")?;
3969 writeln!(
3970 code,
3971 " /// Returns logits over the vocabulary as a `Vec<f32>`."
3972 )?;
3973 writeln!(code, " ///")?;
3974 writeln!(
3975 code,
3976 " /// All GPU operations are encoded into a single command buffer and"
3977 )?;
3978 writeln!(
3979 code,
3980 " /// committed once at the end, avoiding per-operation synchronization."
3981 )?;
3982 writeln!(
3983 code,
3984 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3985 )?;
3986 writeln!(
3987 code,
3988 " // Wait for any pending prefill command buffer"
3989 )?;
3990 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3991 writeln!(code, " prev.wait_until_completed();")?;
3992 writeln!(code, " }}")?;
3993 writeln!(code)?;
3994 writeln!(code, " let pos = self.pos;")?;
3995 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3996 writeln!(code)?;
3997
3998 let matmul_fn = if is_q8 {
4001 "dispatch_matmul_q8"
4002 } else if is_q4 {
4003 "dispatch_matmul_q4"
4004 } else {
4005 "dispatch_matmul"
4006 };
4007
4008 writeln!(
4009 code,
4010 " // Single compute encoder for the entire forward pass (no blit transitions)"
4011 )?;
4012 writeln!(code, " {{")?;
4013 writeln!(
4014 code,
4015 " let enc = cmd.new_compute_command_encoder();"
4016 )?;
4017 writeln!(code)?;
4018
4019 writeln!(
4021 code,
4022 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
4023 )?;
4024 writeln!(
4025 code,
4026 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
4027 )?;
4028 writeln!(
4029 code,
4030 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
4031 )?;
4032 writeln!(
4033 code,
4034 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4035 hidden * 4,
4036 )?;
4037 writeln!(code, " unsafe {{")?;
4038 writeln!(
4039 code,
4040 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4041 )?;
4042 writeln!(
4043 code,
4044 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4045 )?;
4046 writeln!(
4047 code,
4048 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4049 )?;
4050 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4051 writeln!(
4052 code,
4053 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4054 )?;
4055 writeln!(code, " hidden_ptr,")?;
4056 writeln!(code, " HIDDEN_SIZE,")?;
4057 writeln!(code, " );")?;
4058 writeln!(
4059 code,
4060 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4061 )?;
4062 writeln!(code, " }}")?;
4063 writeln!(code)?;
4064
4065 writeln!(code, " // 2. Transformer layers")?;
4067 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4068 writeln!(code)?;
4069 let q_byte_offset = 0usize;
4070 let k_float_offset = hidden;
4071 let v_float_offset = hidden + kv_dim;
4072
4073 writeln!(
4074 code,
4075 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
4076 )?;
4077 writeln!(
4078 code,
4079 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4080 )?;
4081 writeln!(
4082 code,
4083 " // Fused Q+K+V matmul: single dispatch for all three projections"
4084 )?;
4085 writeln!(
4086 code,
4087 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4088 )?;
4089 if config.qkv_bias {
4090 writeln!(
4091 code,
4092 " // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4093 )?;
4094 writeln!(
4095 code,
4096 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4097 )?;
4098 }
4099 writeln!(
4100 code,
4101 " // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
4102 )?;
4103 writeln!(
4104 code,
4105 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4106 )?;
4107 writeln!(code)?;
4108 writeln!(
4109 code,
4110 " // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
4111 )?;
4112 writeln!(
4113 code,
4114 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4115 )?;
4116 writeln!(code)?;
4117 writeln!(
4118 code,
4119 " // Attention using Q from qkv_buf (offset 0)"
4120 )?;
4121 writeln!(
4122 code,
4123 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4124 )?;
4125 writeln!(
4126 code,
4127 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4128 )?;
4129 writeln!(
4130 code,
4131 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4132 )?;
4133 writeln!(
4134 code,
4135 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4136 )?;
4137 writeln!(
4138 code,
4139 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4140 )?;
4141 writeln!(
4142 code,
4143 " // Fused gate+up matmul: single dispatch for both projections"
4144 )?;
4145 writeln!(
4146 code,
4147 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4148 )?;
4149 writeln!(
4150 code,
4151 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4152 )?;
4153 writeln!(
4154 code,
4155 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4156 )?;
4157 writeln!(
4158 code,
4159 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4160 )?;
4161 writeln!(code, " }}")?;
4162 writeln!(code)?;
4163
4164 writeln!(code, " // 3. Final RMS norm + logits projection")?;
4166 writeln!(
4167 code,
4168 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4169 )?;
4170 writeln!(
4171 code,
4172 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4173 )?;
4174 writeln!(code)?;
4175 writeln!(code, " enc.end_encoding();")?;
4176 writeln!(code, " }}")?;
4177 writeln!(code)?;
4178
4179 writeln!(
4181 code,
4182 " // 5. Commit all GPU work and wait for completion"
4183 )?;
4184 writeln!(code, " cmd.commit();")?;
4185 writeln!(code, " cmd.wait_until_completed();")?;
4186 writeln!(code)?;
4187 writeln!(code, " // 6. Read back logits from GPU")?;
4188 writeln!(code, " let logits = unsafe {{")?;
4189 writeln!(
4190 code,
4191 " let ptr = self.logits_buf.contents() as *const f32;"
4192 )?;
4193 writeln!(
4194 code,
4195 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4196 )?;
4197 writeln!(code, " }};")?;
4198 writeln!(code)?;
4199 writeln!(code, " self.pos += 1;")?;
4200 writeln!(code, " logits")?;
4201 writeln!(code, " }}")?;
4202 writeln!(code)?;
4203
4204 writeln!(
4206 code,
4207 " /// Profiling forward pass that prints per-stage GPU timing."
4208 )?;
4209 writeln!(code, " ///")?;
4210 writeln!(
4211 code,
4212 " /// Each stage is committed and waited on separately so that GPU timestamps"
4213 )?;
4214 writeln!(
4215 code,
4216 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
4217 )?;
4218 writeln!(
4219 code,
4220 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
4221 )?;
4222 writeln!(
4223 code,
4224 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4225 )?;
4226 writeln!(code, " use std::time::Instant;")?;
4227 writeln!(code)?;
4228 writeln!(
4229 code,
4230 " // Wait for any pending prefill command buffer"
4231 )?;
4232 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4233 writeln!(code, " prev.wait_until_completed();")?;
4234 writeln!(code, " }}")?;
4235 writeln!(code)?;
4236 writeln!(code, " let pos = self.pos;")?;
4237 writeln!(code)?;
4238
4239 writeln!(
4241 code,
4242 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
4243 )?;
4244 writeln!(code, " let t_embed = Instant::now();")?;
4245 writeln!(code, " unsafe {{")?;
4246 writeln!(
4247 code,
4248 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4249 )?;
4250 writeln!(
4251 code,
4252 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4253 )?;
4254 writeln!(
4255 code,
4256 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4257 )?;
4258 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4259 writeln!(
4260 code,
4261 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4262 )?;
4263 writeln!(code, " hidden_ptr,")?;
4264 writeln!(code, " HIDDEN_SIZE,")?;
4265 writeln!(code, " );")?;
4266 writeln!(
4267 code,
4268 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4269 )?;
4270 writeln!(code, " }}")?;
4271 writeln!(code, " let d_embed = t_embed.elapsed();")?;
4272 writeln!(code)?;
4273
4274 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
4276 writeln!(code, " let t_layers = Instant::now();")?;
4277 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4278 writeln!(code, " {{")?;
4279 writeln!(
4280 code,
4281 " let enc = cmd.new_compute_command_encoder();"
4282 )?;
4283 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4284 writeln!(
4285 code,
4286 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4287 )?;
4288 writeln!(
4289 code,
4290 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4291 )?;
4292 if config.qkv_bias {
4293 writeln!(
4294 code,
4295 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4296 )?;
4297 }
4298 writeln!(
4299 code,
4300 " self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
4301 )?;
4302 writeln!(
4303 code,
4304 " self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4305 )?;
4306 writeln!(
4307 code,
4308 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4309 )?;
4310 writeln!(
4311 code,
4312 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4313 )?;
4314 writeln!(
4315 code,
4316 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4317 )?;
4318 writeln!(
4319 code,
4320 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4321 )?;
4322 writeln!(
4323 code,
4324 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4325 )?;
4326 writeln!(
4327 code,
4328 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4329 )?;
4330 writeln!(
4331 code,
4332 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4333 )?;
4334 writeln!(
4335 code,
4336 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4337 )?;
4338 writeln!(code, " }}")?;
4339 writeln!(code, " enc.end_encoding();")?;
4340 writeln!(code, " }}")?;
4341 writeln!(code, " cmd.commit();")?;
4342 writeln!(code, " cmd.wait_until_completed();")?;
4343 writeln!(code, " let d_layers = t_layers.elapsed();")?;
4344 writeln!(code)?;
4345
4346 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
4348 writeln!(code, " let t_logits = Instant::now();")?;
4349 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
4350 writeln!(code, " {{")?;
4351 writeln!(
4352 code,
4353 " let enc = cmd2.new_compute_command_encoder();"
4354 )?;
4355 writeln!(
4356 code,
4357 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4358 )?;
4359 writeln!(
4360 code,
4361 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4362 )?;
4363 writeln!(code, " enc.end_encoding();")?;
4364 writeln!(code, " }}")?;
4365 writeln!(code, " cmd2.commit();")?;
4366 writeln!(code, " cmd2.wait_until_completed();")?;
4367 writeln!(code, " let d_logits = t_logits.elapsed();")?;
4368 writeln!(code)?;
4369
4370 writeln!(
4372 code,
4373 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4374 )?;
4375 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
4376 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
4377 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
4378 writeln!(
4379 code,
4380 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4381 )?;
4382 writeln!(code)?;
4383
4384 writeln!(code, " let logits = unsafe {{")?;
4386 writeln!(
4387 code,
4388 " let ptr = self.logits_buf.contents() as *const f32;"
4389 )?;
4390 writeln!(
4391 code,
4392 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4393 )?;
4394 writeln!(code, " }};")?;
4395 writeln!(code)?;
4396 writeln!(code, " self.pos += 1;")?;
4397 writeln!(code, " logits")?;
4398 writeln!(code, " }}")?;
4399 writeln!(code)?;
4400
4401 writeln!(
4403 code,
4404 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
4405 )?;
4406 writeln!(code, " ///")?;
4407 writeln!(
4408 code,
4409 " /// Commits the command buffer without waiting, enabling double-buffered"
4410 )?;
4411 writeln!(
4412 code,
4413 " /// execution: GPU processes token N while CPU encodes token N+1."
4414 )?;
4415 writeln!(
4416 code,
4417 " pub fn forward_prefill(&mut self, token_id: u32) {{"
4418 )?;
4419 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
4420 writeln!(code, " }}")?;
4421 writeln!(code)?;
4422
4423 let batch_matmul_fn = if is_q8 {
4426 "dispatch_matmul_q8_batch"
4427 } else if is_q4 {
4428 "dispatch_matmul_q4_batch"
4429 } else {
4430 "dispatch_matmul_batch"
4431 };
4432
4433 writeln!(
4434 code,
4435 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4436 )?;
4437 writeln!(code, " ///")?;
4438 writeln!(
4439 code,
4440 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4441 )?;
4442 writeln!(
4443 code,
4444 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
4445 )?;
4446 writeln!(
4447 code,
4448 " /// This provides significant speedup during prompt prefill."
4449 )?;
4450 writeln!(
4451 code,
4452 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4453 )?;
4454 writeln!(code, " if tokens.is_empty() {{ return; }}")?;
4455 writeln!(
4456 code,
4457 " // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4458 )?;
4459 writeln!(
4460 code,
4461 " // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4462 )?;
4463 writeln!(
4464 code,
4465 " // longer than that must be processed iteratively. The KV cache"
4466 )?;
4467 writeln!(code, " // carries state across chunks via self.pos.")?;
4468 writeln!(
4469 code,
4470 " for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4471 )?;
4472 writeln!(code, " let m = chunk.len();")?;
4473 writeln!(code, " if m == 0 {{ continue; }}")?;
4474 writeln!(code, " let start_pos = self.pos;")?;
4475 writeln!(code)?;
4476 writeln!(code, " // Wait for any pending command buffer")?;
4477 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4478 writeln!(code, " prev.wait_until_completed();")?;
4479 writeln!(code, " }}")?;
4480 writeln!(code)?;
4481
4482 writeln!(
4484 code,
4485 " // Upload token IDs and positions to GPU buffers"
4486 )?;
4487 writeln!(code, " unsafe {{")?;
4488 writeln!(
4489 code,
4490 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4491 )?;
4492 writeln!(
4493 code,
4494 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4495 )?;
4496 writeln!(code, " for i in 0..m {{")?;
4497 writeln!(code, " *tok_ptr.add(i) = chunk[i];")?;
4498 writeln!(
4499 code,
4500 " *pos_ptr.add(i) = (start_pos + i) as u32;"
4501 )?;
4502 writeln!(code, " }}")?;
4503 writeln!(code, " }}")?;
4504 writeln!(code)?;
4505
4506 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4507 writeln!(code, " {{")?;
4508 writeln!(
4509 code,
4510 " let enc = cmd.new_compute_command_encoder();"
4511 )?;
4512 writeln!(code)?;
4513
4514 writeln!(
4516 code,
4517 " // 1. Batch embedding lookup: copy all token embeddings at once"
4518 )?;
4519 writeln!(
4520 code,
4521 " self.dispatch_copy_embedding_batch(&enc, m);"
4522 )?;
4523 writeln!(
4525 code,
4526 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4527 )?;
4528 writeln!(code)?;
4529
4530 writeln!(code, " // 2. Transformer layers")?;
4532 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4533 writeln!(code)?;
4534
4535 writeln!(
4537 code,
4538 " // Batch RMS norm: batch_residual -> batch_hidden"
4539 )?;
4540 writeln!(
4541 code,
4542 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4543 )?;
4544
4545 writeln!(
4547 code,
4548 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4549 )?;
4550 writeln!(
4551 code,
4552 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4553 )?;
4554 if config.qkv_bias {
4555 writeln!(
4556 code,
4557 " // Qwen2: broadcast-add QKV bias across all M tokens."
4558 )?;
4559 writeln!(
4560 code,
4561 " self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4562 )?;
4563 }
4564 writeln!(code)?;
4565
4566 let k_float_offset = hidden;
4568 writeln!(
4569 code,
4570 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4571 )?;
4572 writeln!(
4573 code,
4574 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4575 )?;
4576 writeln!(code)?;
4577
4578 let v_float_offset = hidden + kv_dim;
4580 writeln!(
4581 code,
4582 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4583 )?;
4584 writeln!(
4585 code,
4586 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4587 )?;
4588 writeln!(code)?;
4589
4590 writeln!(
4592 code,
4593 " // Batched causal attention: one dispatch for all M tokens"
4594 )?;
4595 writeln!(
4596 code,
4597 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4598 )?;
4599 writeln!(code)?;
4600
4601 writeln!(code, " // Batched O projection")?;
4603 writeln!(
4604 code,
4605 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4606 )?;
4607 writeln!(code)?;
4608
4609 writeln!(
4611 code,
4612 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4613 )?;
4614 writeln!(code)?;
4615
4616 writeln!(
4618 code,
4619 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4620 )?;
4621 writeln!(
4622 code,
4623 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4624 )?;
4625 writeln!(
4626 code,
4627 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4628 )?;
4629 writeln!(
4630 code,
4631 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4632 )?;
4633 writeln!(
4634 code,
4635 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4636 )?;
4637 writeln!(
4638 code,
4639 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4640 )?;
4641 writeln!(code, " }}")?;
4642 writeln!(code)?;
4643
4644 writeln!(
4646 code,
4647 " // Copy last token's residual to single-token buffer for subsequent forward()"
4648 )?;
4649 writeln!(
4650 code,
4651 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4652 )?;
4653 writeln!(code)?;
4654 writeln!(code, " enc.end_encoding();")?;
4655 writeln!(code, " }}")?;
4656 writeln!(code)?;
4657
4658 writeln!(code, " cmd.commit();")?;
4659 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
4660 writeln!(code, " self.pos += m;")?;
4661 writeln!(code, " }} // end for chunk")?;
4662 writeln!(code, " }}")?;
4663 writeln!(code)?;
4664
4665 writeln!(
4667 code,
4668 " /// Reset the model state for a new inference request."
4669 )?;
4670 writeln!(code, " pub fn reset(&mut self) {{")?;
4671 writeln!(code, " self.pos = 0;")?;
4672 writeln!(code, " self.prev_cmd = None;")?;
4673 writeln!(code, " }}")?;
4674 writeln!(code)?;
4675
4676 writeln!(
4678 code,
4679 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
4680 )?;
4681 writeln!(
4682 code,
4683 " // These methods set pipeline state + buffers + dispatch on an existing"
4684 )?;
4685 writeln!(
4686 code,
4687 " // encoder, avoiding per-operation encoder creation overhead."
4688 )?;
4689 writeln!(code)?;
4690
4691 writeln!(
4693 code,
4694 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4695 )?;
4696 writeln!(
4697 code,
4698 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4699 )?;
4700 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4701 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4702 writeln!(
4703 code,
4704 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4705 )?;
4706 writeln!(
4707 code,
4708 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
4709 )?;
4710 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4711 writeln!(
4712 code,
4713 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4714 )?;
4715 writeln!(
4716 code,
4717 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4718 )?;
4719 writeln!(
4720 code,
4721 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4722 )?;
4723 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4724 writeln!(
4725 code,
4726 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
4727 )?;
4728 writeln!(
4729 code,
4730 " enc.dispatch_thread_groups(grid_size, tg_size);"
4731 )?;
4732 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4733 writeln!(code, " }}")?;
4734 writeln!(code)?;
4735
4736 writeln!(
4738 code,
4739 " /// Dispatch matrix-vector multiply: weight * input -> output."
4740 )?;
4741 writeln!(
4742 code,
4743 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4744 )?;
4745 writeln!(code, " let r: u32 = rows as u32;")?;
4746 writeln!(code, " let c: u32 = cols as u32;")?;
4747 writeln!(
4748 code,
4749 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4750 )?;
4751 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4752 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4753 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4754 writeln!(
4755 code,
4756 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4757 )?;
4758 writeln!(
4759 code,
4760 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4761 )?;
4762 writeln!(
4763 code,
4764 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4765 )?;
4766 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4767 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
4768 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4769 writeln!(
4770 code,
4771 " enc.dispatch_thread_groups(grid_size, tg_size);"
4772 )?;
4773 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4774 writeln!(code, " }}")?;
4775 writeln!(code)?;
4776
4777 writeln!(
4779 code,
4780 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4781 )?;
4782 writeln!(
4783 code,
4784 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4785 )?;
4786 writeln!(
4787 code,
4788 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4789 )?;
4790 writeln!(code, " let r: u32 = rows as u32;")?;
4791 writeln!(code, " let c: u32 = cols as u32;")?;
4792 writeln!(
4793 code,
4794 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4795 )?;
4796 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4797 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4798 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4799 writeln!(
4800 code,
4801 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4802 )?;
4803 writeln!(
4804 code,
4805 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4806 )?;
4807 writeln!(
4808 code,
4809 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4810 )?;
4811 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4812 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4813 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4814 writeln!(
4815 code,
4816 " enc.dispatch_thread_groups(grid_size, tg_size);"
4817 )?;
4818 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4819 writeln!(code, " }}")?;
4820 writeln!(code)?;
4821
4822 writeln!(
4824 code,
4825 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4826 )?;
4827 writeln!(
4828 code,
4829 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4830 )?;
4831 writeln!(
4832 code,
4833 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4834 )?;
4835 writeln!(code, " let r: u32 = rows as u32;")?;
4836 writeln!(code, " let c: u32 = cols as u32;")?;
4837 writeln!(
4838 code,
4839 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4840 )?;
4841 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4842 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4843 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4844 writeln!(
4845 code,
4846 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4847 )?;
4848 writeln!(
4849 code,
4850 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4851 )?;
4852 writeln!(
4853 code,
4854 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4855 )?;
4856 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4857 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4858 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4859 writeln!(
4860 code,
4861 " enc.dispatch_thread_groups(grid_size, tg_size);"
4862 )?;
4863 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4864 writeln!(code, " }}")?;
4865 writeln!(code)?;
4866
4867 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
4869 writeln!(
4870 code,
4871 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4872 )?;
4873 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4874 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4875 writeln!(code, " let p: u32 = pos as u32;")?;
4876 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4877 writeln!(
4878 code,
4879 " let total_pairs = num_heads * (head_dim / 2);"
4880 )?;
4881 writeln!(
4882 code,
4883 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4884 )?;
4885 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
4886 writeln!(
4887 code,
4888 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4889 )?;
4890 writeln!(
4891 code,
4892 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4893 )?;
4894 writeln!(
4895 code,
4896 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4897 )?;
4898 writeln!(
4899 code,
4900 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4901 )?;
4902 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4903 writeln!(
4904 code,
4905 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4906 )?;
4907 writeln!(
4908 code,
4909 " enc.dispatch_thread_groups(grid_size, tg_size);"
4910 )?;
4911 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4912 writeln!(code, " }}")?;
4913 writeln!(code)?;
4914
4915 writeln!(
4917 code,
4918 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4919 )?;
4920 writeln!(
4921 code,
4922 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4923 )?;
4924 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4925 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4926 writeln!(code, " let p: u32 = pos as u32;")?;
4927 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4928 writeln!(
4929 code,
4930 " let total_pairs = num_heads * (head_dim / 2);"
4931 )?;
4932 writeln!(
4933 code,
4934 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4935 )?;
4936 writeln!(
4937 code,
4938 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
4939 )?;
4940 writeln!(
4941 code,
4942 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4943 )?;
4944 writeln!(
4945 code,
4946 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4947 )?;
4948 writeln!(
4949 code,
4950 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4951 )?;
4952 writeln!(
4953 code,
4954 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4955 )?;
4956 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4957 writeln!(
4958 code,
4959 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4960 )?;
4961 writeln!(
4962 code,
4963 " enc.dispatch_thread_groups(grid_size, tg_size);"
4964 )?;
4965 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4966 writeln!(code, " }}")?;
4967 writeln!(code)?;
4968
4969 writeln!(
4971 code,
4972 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4973 )?;
4974 writeln!(
4975 code,
4976 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4977 )?;
4978 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4979 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4980 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4981 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4982 writeln!(
4983 code,
4984 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
4985 )?;
4986 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
4987 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4988 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4989 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4990 writeln!(
4991 code,
4992 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4993 )?;
4994 writeln!(
4995 code,
4996 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4997 )?;
4998 writeln!(
4999 code,
5000 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5001 )?;
5002 writeln!(
5003 code,
5004 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5005 )?;
5006 writeln!(
5007 code,
5008 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
5009 )?;
5010 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5011 writeln!(
5012 code,
5013 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5014 )?;
5015 writeln!(
5016 code,
5017 " enc.dispatch_thread_groups(grid_size, tg_size);"
5018 )?;
5019 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5020 writeln!(code, " }}")?;
5021 writeln!(code)?;
5022
5023 writeln!(
5025 code,
5026 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
5027 )?;
5028 writeln!(
5029 code,
5030 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5031 )?;
5032 writeln!(code, " let sl: u32 = seq_len as u32;")?;
5033 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
5034 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5035 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5036 writeln!(
5037 code,
5038 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
5039 )?;
5040 writeln!(
5041 code,
5042 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5043 )?;
5044 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5045 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5046 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5047 writeln!(
5048 code,
5049 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5050 )?;
5051 writeln!(
5052 code,
5053 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5054 )?;
5055 writeln!(
5056 code,
5057 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5058 )?;
5059 writeln!(
5060 code,
5061 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5062 )?;
5063 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5064 writeln!(
5065 code,
5066 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5067 )?;
5068 writeln!(
5069 code,
5070 " enc.dispatch_thread_groups(grid_size, tg_size);"
5071 )?;
5072 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5073 writeln!(code, " }}")?;
5074 writeln!(code)?;
5075
5076 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
5078 writeln!(
5079 code,
5080 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5081 )?;
5082 writeln!(code, " let count: u32 = n as u32;")?;
5083 writeln!(
5084 code,
5085 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5086 )?;
5087 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
5088 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
5089 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5090 writeln!(
5091 code,
5092 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5093 )?;
5094 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5095 writeln!(
5096 code,
5097 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5098 )?;
5099 writeln!(
5100 code,
5101 " enc.dispatch_thread_groups(grid_size, tg_size);"
5102 )?;
5103 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5104 writeln!(code, " }}")?;
5105 writeln!(code)?;
5106
5107 writeln!(
5109 code,
5110 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5111 )?;
5112 writeln!(
5113 code,
5114 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
5115 )?;
5116 writeln!(
5117 code,
5118 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5119 )?;
5120 writeln!(code, " let count: u32 = n as u32;")?;
5121 writeln!(
5122 code,
5123 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5124 )?;
5125 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5126 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5127 writeln!(
5128 code,
5129 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5130 )?;
5131 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5132 writeln!(
5133 code,
5134 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5135 )?;
5136 writeln!(
5137 code,
5138 " enc.dispatch_thread_groups(grid_size, tg_size);"
5139 )?;
5140 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5141 writeln!(code, " }}")?;
5142 writeln!(code)?;
5143
5144 writeln!(
5146 code,
5147 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5148 )?;
5149 writeln!(
5150 code,
5151 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5152 )?;
5153 writeln!(code, " let n: u32 = count as u32;")?;
5154 writeln!(
5155 code,
5156 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5157 )?;
5158 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5159 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5160 writeln!(
5161 code,
5162 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5163 )?;
5164 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5165 writeln!(
5166 code,
5167 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5168 )?;
5169 writeln!(
5170 code,
5171 " enc.dispatch_thread_groups(grid_size, tg_size);"
5172 )?;
5173 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5174 writeln!(code, " }}")?;
5175 writeln!(code)?;
5176
5177 writeln!(
5179 code,
5180 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5181 )?;
5182 writeln!(
5183 code,
5184 " /// Used for embedding table lookup (copy a specific row)."
5185 )?;
5186 writeln!(
5187 code,
5188 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5189 )?;
5190 writeln!(code, " let off: u32 = src_offset as u32;")?;
5191 writeln!(code, " let n: u32 = count as u32;")?;
5192 writeln!(
5193 code,
5194 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5195 )?;
5196 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5197 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5198 writeln!(
5199 code,
5200 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5201 )?;
5202 writeln!(
5203 code,
5204 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5205 )?;
5206 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5207 writeln!(
5208 code,
5209 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5210 )?;
5211 writeln!(
5212 code,
5213 " enc.dispatch_thread_groups(grid_size, tg_size);"
5214 )?;
5215 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5216 writeln!(code, " }}")?;
5217 writeln!(code)?;
5218
5219 writeln!(
5221 code,
5222 " /// Dispatch copy from source at byte offset to destination at float offset."
5223 )?;
5224 writeln!(
5225 code,
5226 " /// Used for KV cache updates from fused QKV buffer."
5227 )?;
5228 writeln!(
5229 code,
5230 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5231 )?;
5232 writeln!(code, " let n: u32 = count as u32;")?;
5233 writeln!(
5234 code,
5235 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5236 )?;
5237 writeln!(
5238 code,
5239 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5240 )?;
5241 writeln!(
5242 code,
5243 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5244 )?;
5245 writeln!(
5246 code,
5247 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5248 )?;
5249 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5250 writeln!(
5251 code,
5252 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5253 )?;
5254 writeln!(
5255 code,
5256 " enc.dispatch_thread_groups(grid_size, tg_size);"
5257 )?;
5258 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5259 writeln!(code, " }}")?;
5260 writeln!(code)?;
5261
5262 writeln!(
5264 code,
5265 " /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5266 )?;
5267 writeln!(
5268 code,
5269 " /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5270 )?;
5271 writeln!(
5272 code,
5273 " fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5274 )?;
5275 writeln!(code, " let n: u32 = count as u32;")?;
5276 writeln!(
5277 code,
5278 " enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5279 )?;
5280 writeln!(
5281 code,
5282 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5283 )?;
5284 writeln!(
5285 code,
5286 " enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5287 )?;
5288 writeln!(
5289 code,
5290 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5291 )?;
5292 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5293 writeln!(
5294 code,
5295 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5296 )?;
5297 writeln!(
5298 code,
5299 " enc.dispatch_thread_groups(grid_size, tg_size);"
5300 )?;
5301 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5302 writeln!(code, " }}")?;
5303 writeln!(code)?;
5304
5305 writeln!(
5307 code,
5308 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5309 )?;
5310 writeln!(
5311 code,
5312 " /// Used for KV cache updates (write to a specific position in the cache)."
5313 )?;
5314 writeln!(
5315 code,
5316 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5317 )?;
5318 writeln!(code, " let n: u32 = count as u32;")?;
5319 writeln!(
5320 code,
5321 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5322 )?;
5323 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5324 writeln!(
5325 code,
5326 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5327 )?;
5328 writeln!(
5329 code,
5330 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5331 )?;
5332 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5333 writeln!(
5334 code,
5335 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5336 )?;
5337 writeln!(
5338 code,
5339 " enc.dispatch_thread_groups(grid_size, tg_size);"
5340 )?;
5341 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5342 writeln!(code, " }}")?;
5343 writeln!(code)?;
5344
5345 writeln!(
5347 code,
5348 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5349 )?;
5350 writeln!(
5351 code,
5352 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5353 )?;
5354 writeln!(code, " let count: u32 = n as u32;")?;
5355 writeln!(
5356 code,
5357 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5358 )?;
5359 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5360 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5361 writeln!(
5362 code,
5363 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5364 )?;
5365 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5366 writeln!(
5367 code,
5368 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5369 )?;
5370 writeln!(
5371 code,
5372 " enc.dispatch_thread_groups(grid_size, tg_size);"
5373 )?;
5374 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5375 writeln!(code, " }}")?;
5376 writeln!(code)?;
5377
5378 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
5380 writeln!(code)?;
5381
5382 writeln!(
5384 code,
5385 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
5386 )?;
5387 writeln!(
5388 code,
5389 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5390 )?;
5391 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
5392 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5393 writeln!(
5394 code,
5395 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5396 )?;
5397 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5398 writeln!(
5399 code,
5400 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5401 )?;
5402 writeln!(
5403 code,
5404 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5405 )?;
5406 writeln!(
5407 code,
5408 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5409 )?;
5410 writeln!(
5411 code,
5412 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5413 )?;
5414 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
5415 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5416 writeln!(
5417 code,
5418 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5419 )?;
5420 writeln!(
5421 code,
5422 " enc.dispatch_thread_groups(grid_size, tg_size);"
5423 )?;
5424 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5425 writeln!(code, " }}")?;
5426 writeln!(code)?;
5427
5428 writeln!(
5430 code,
5431 " /// Dispatch batched RMS norm: normalizes M vectors at once."
5432 )?;
5433 writeln!(
5434 code,
5435 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5436 )?;
5437 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
5438 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
5439 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5440 writeln!(
5441 code,
5442 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5443 )?;
5444 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
5445 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
5446 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5447 writeln!(
5448 code,
5449 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5450 )?;
5451 writeln!(
5452 code,
5453 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5454 )?;
5455 writeln!(
5456 code,
5457 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5458 )?;
5459 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5460 writeln!(
5461 code,
5462 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5463 )?;
5464 writeln!(
5465 code,
5466 " enc.dispatch_thread_groups(grid_size, tg_size);"
5467 )?;
5468 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5469 writeln!(code, " }}")?;
5470 writeln!(code)?;
5471
5472 writeln!(
5474 code,
5475 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5476 )?;
5477 writeln!(
5478 code,
5479 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5480 )?;
5481 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5482 writeln!(code, " let r: u32 = rows as u32;")?;
5483 writeln!(code, " let c: u32 = cols as u32;")?;
5484 writeln!(
5485 code,
5486 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5487 )?;
5488 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5489 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5490 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5491 writeln!(
5492 code,
5493 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5494 )?;
5495 writeln!(
5496 code,
5497 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5498 )?;
5499 writeln!(
5500 code,
5501 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5502 )?;
5503 writeln!(
5504 code,
5505 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
5506 )?;
5507 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5508 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5509 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5510 writeln!(
5511 code,
5512 " enc.dispatch_thread_groups(grid_size, tg_size);"
5513 )?;
5514 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5515 writeln!(code, " }}")?;
5516 writeln!(code)?;
5517
5518 writeln!(
5520 code,
5521 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5522 )?;
5523 writeln!(code, " ///")?;
5524 writeln!(
5525 code,
5526 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5527 )?;
5528 writeln!(
5529 code,
5530 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5531 )?;
5532 writeln!(
5533 code,
5534 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5535 )?;
5536 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5537 writeln!(code, " let r: u32 = rows as u32;")?;
5538 writeln!(code, " let c: u32 = cols as u32;")?;
5539 writeln!(
5540 code,
5541 " // Tile sizes must match the Metal shader constants."
5542 )?;
5543 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
5544 writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
5545 writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
5546 writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
5547 writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
5548 writeln!(
5549 code,
5550 " // Hardware matrix-multiply paths (simdgroup_matrix)."
5551 )?;
5552 writeln!(
5553 code,
5554 " // Prefer the large 32×32 tile when the problem supports it — halves"
5555 )?;
5556 writeln!(
5557 code,
5558 " // dispatch count and reuses each weight load across 32 tokens."
5559 )?;
5560 writeln!(
5561 code,
5562 " if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5563 )?;
5564 writeln!(
5565 code,
5566 " // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5567 )?;
5568 writeln!(
5569 code,
5570 " // It wins at moderate prefill lengths where the GPU is wave-starved,"
5571 )?;
5572 writeln!(
5573 code,
5574 " // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5575 )?;
5576 writeln!(
5577 code,
5578 " // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
5579 )?;
5580 writeln!(
5581 code,
5582 " // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5583 )?;
5584 writeln!(
5585 code,
5586 " // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5587 )?;
5588 writeln!(
5589 code,
5590 " // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5591 )?;
5592 writeln!(
5593 code,
5594 " // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5595 )?;
5596 writeln!(code, " let use_h4 = cols >= 2048;")?;
5597 writeln!(code, " let pipe = if use_h4 {{")?;
5598 writeln!(code, " if num_tokens >= 256 {{")?;
5599 writeln!(
5600 code,
5601 " &self.matmul_q8_mma32_hh4_pipeline"
5602 )?;
5603 writeln!(code, " }} else {{")?;
5604 writeln!(
5605 code,
5606 " &self.matmul_q8_mma32_h4_pipeline"
5607 )?;
5608 writeln!(code, " }}")?;
5609 writeln!(code, " }} else {{")?;
5610 writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
5611 writeln!(code, " }};")?;
5612 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
5613 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5614 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5615 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5616 writeln!(
5617 code,
5618 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5619 )?;
5620 writeln!(
5621 code,
5622 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5623 )?;
5624 writeln!(
5625 code,
5626 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5627 )?;
5628 writeln!(code, " let row_tgs = rows / MMA32_ROW_TILE;")?;
5629 writeln!(
5630 code,
5631 " let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5632 )?;
5633 writeln!(
5634 code,
5635 " let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5636 )?;
5637 writeln!(
5638 code,
5639 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5640 )?;
5641 writeln!(
5642 code,
5643 " enc.dispatch_thread_groups(grid_size, tg_size);"
5644 )?;
5645 writeln!(
5646 code,
5647 " }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5648 )?;
5649 writeln!(
5650 code,
5651 " enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5652 )?;
5653 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5654 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5655 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5656 writeln!(
5657 code,
5658 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5659 )?;
5660 writeln!(
5661 code,
5662 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5663 )?;
5664 writeln!(
5665 code,
5666 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5667 )?;
5668 writeln!(code, " let row_tgs = rows / MMA_ROW_TILE;")?;
5669 writeln!(
5670 code,
5671 " let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5672 )?;
5673 writeln!(
5674 code,
5675 " let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
5676 )?;
5677 writeln!(
5678 code,
5679 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5680 )?;
5681 writeln!(
5682 code,
5683 " enc.dispatch_thread_groups(grid_size, tg_size);"
5684 )?;
5685 writeln!(code, " }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5686 writeln!(
5687 code,
5688 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5689 )?;
5690 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5691 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5692 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5693 writeln!(
5694 code,
5695 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5696 )?;
5697 writeln!(
5698 code,
5699 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5700 )?;
5701 writeln!(
5702 code,
5703 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5704 )?;
5705 writeln!(
5706 code,
5707 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5708 )?;
5709 writeln!(
5710 code,
5711 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5712 )?;
5713 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5714 writeln!(
5715 code,
5716 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5717 )?;
5718 writeln!(
5719 code,
5720 " enc.dispatch_thread_groups(grid_size, tg_size);"
5721 )?;
5722 writeln!(code, " }} else {{")?;
5723 writeln!(
5724 code,
5725 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5726 )?;
5727 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5728 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5729 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5730 writeln!(
5731 code,
5732 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5733 )?;
5734 writeln!(
5735 code,
5736 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5737 )?;
5738 writeln!(
5739 code,
5740 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5741 )?;
5742 writeln!(
5743 code,
5744 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5745 )?;
5746 writeln!(
5747 code,
5748 " let num_tg = (row_tgs * num_tokens) as u64;"
5749 )?;
5750 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5751 writeln!(
5752 code,
5753 " let grid_size = MTLSize::new(num_tg, 1, 1);"
5754 )?;
5755 writeln!(
5756 code,
5757 " enc.dispatch_thread_groups(grid_size, tg_size);"
5758 )?;
5759 writeln!(code, " }}")?;
5760 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5761 writeln!(code, " }}")?;
5762 writeln!(code)?;
5763
5764 writeln!(
5766 code,
5767 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5768 )?;
5769 writeln!(
5770 code,
5771 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5772 )?;
5773 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5774 writeln!(code, " let r: u32 = rows as u32;")?;
5775 writeln!(code, " let c: u32 = cols as u32;")?;
5776 writeln!(
5777 code,
5778 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5779 )?;
5780 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5781 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5782 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5783 writeln!(
5784 code,
5785 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5786 )?;
5787 writeln!(
5788 code,
5789 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5790 )?;
5791 writeln!(
5792 code,
5793 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5794 )?;
5795 writeln!(
5796 code,
5797 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
5798 )?;
5799 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5800 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5801 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5802 writeln!(
5803 code,
5804 " enc.dispatch_thread_groups(grid_size, tg_size);"
5805 )?;
5806 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5807 writeln!(code, " }}")?;
5808 writeln!(code)?;
5809
5810 if config.qkv_bias {
5812 writeln!(
5813 code,
5814 " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5815 )?;
5816 writeln!(
5817 code,
5818 " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5819 )?;
5820 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5821 writeln!(code, " let r: u32 = rows as u32;")?;
5822 writeln!(
5823 code,
5824 " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5825 )?;
5826 writeln!(code, " enc.set_buffer(0, Some(out), 0);")?;
5827 writeln!(code, " enc.set_buffer(1, Some(bias), 0);")?;
5828 writeln!(
5829 code,
5830 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5831 )?;
5832 writeln!(
5833 code,
5834 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5835 )?;
5836 writeln!(code, " let total = (num_tokens * rows) as u64;")?;
5837 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5838 writeln!(
5839 code,
5840 " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5841 )?;
5842 writeln!(
5843 code,
5844 " enc.dispatch_thread_groups(grid_size, tg_size);"
5845 )?;
5846 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5847 writeln!(code, " }}")?;
5848 writeln!(code)?;
5849 }
5850
5851 writeln!(
5853 code,
5854 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5855 )?;
5856 writeln!(
5857 code,
5858 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5859 )?;
5860 writeln!(
5861 code,
5862 " /// `row_stride` is the number of floats per token row in the batch buffer."
5863 )?;
5864 writeln!(
5865 code,
5866 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5867 )?;
5868 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5869 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5870 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5871 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5872 writeln!(
5873 code,
5874 " let pairs_per_token = num_heads * (head_dim / 2);"
5875 )?;
5876 writeln!(
5877 code,
5878 " let total_pairs = num_tokens * pairs_per_token;"
5879 )?;
5880 writeln!(
5893 code,
5894 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5895 )?;
5896 writeln!(code, " for t in 0..num_tokens {{")?;
5897 writeln!(
5898 code,
5899 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5900 )?;
5901 writeln!(
5902 code,
5903 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5904 )?;
5905 writeln!(
5906 code,
5907 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5908 )?;
5909 writeln!(
5910 code,
5911 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
5912 )?;
5913 writeln!(
5914 code,
5915 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5916 )?;
5917 writeln!(
5918 code,
5919 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5920 )?;
5921 writeln!(
5922 code,
5923 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5924 )?;
5925 writeln!(
5926 code,
5927 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5928 )?;
5929 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5930 writeln!(
5931 code,
5932 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5933 )?;
5934 writeln!(
5935 code,
5936 " enc.dispatch_thread_groups(grid_size, tg_size);"
5937 )?;
5938 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5939 writeln!(code, " }}")?;
5940 writeln!(code, " }}")?;
5941 writeln!(code)?;
5942
5943 writeln!(
5945 code,
5946 " /// Dispatch batched fused SiLU-multiply for M tokens."
5947 )?;
5948 writeln!(
5949 code,
5950 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5951 )?;
5952 writeln!(code, " let count: u32 = n as u32;")?;
5953 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5954 writeln!(
5955 code,
5956 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5957 )?;
5958 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5959 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5960 writeln!(
5961 code,
5962 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5963 )?;
5964 writeln!(
5965 code,
5966 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5967 )?;
5968 writeln!(code, " let total = num_tokens * n;")?;
5969 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5970 writeln!(
5971 code,
5972 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5973 )?;
5974 writeln!(
5975 code,
5976 " enc.dispatch_thread_groups(grid_size, tg_size);"
5977 )?;
5978 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5979 writeln!(code, " }}")?;
5980 writeln!(code)?;
5981
5982 writeln!(
5984 code,
5985 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5986 )?;
5987 writeln!(
5988 code,
5989 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5990 )?;
5991 writeln!(code, " let count: u32 = total_n as u32;")?;
5992 writeln!(
5993 code,
5994 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5995 )?;
5996 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5997 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5998 writeln!(
5999 code,
6000 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
6001 )?;
6002 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6003 writeln!(
6004 code,
6005 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
6006 )?;
6007 writeln!(
6008 code,
6009 " enc.dispatch_thread_groups(grid_size, tg_size);"
6010 )?;
6011 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6012 writeln!(code, " }}")?;
6013 writeln!(code)?;
6014
6015 writeln!(
6017 code,
6018 " /// Copy src to dst using compute copy kernel (for batch residual init)."
6019 )?;
6020 writeln!(
6021 code,
6022 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
6023 )?;
6024 writeln!(code, " let n: u32 = count as u32;")?;
6025 writeln!(
6026 code,
6027 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6028 )?;
6029 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6030 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6031 writeln!(
6032 code,
6033 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6034 )?;
6035 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6036 writeln!(
6037 code,
6038 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6039 )?;
6040 writeln!(
6041 code,
6042 " enc.dispatch_thread_groups(grid_size, tg_size);"
6043 )?;
6044 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6045 writeln!(code, " }}")?;
6046 writeln!(code)?;
6047
6048 writeln!(
6050 code,
6051 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6052 )?;
6053 writeln!(
6054 code,
6055 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6056 )?;
6057 writeln!(code, " let n: u32 = count as u32;")?;
6058 writeln!(
6059 code,
6060 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6061 )?;
6062 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6063 writeln!(
6064 code,
6065 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6066 )?;
6067 writeln!(
6068 code,
6069 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6070 )?;
6071 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6072 writeln!(
6073 code,
6074 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6075 )?;
6076 writeln!(
6077 code,
6078 " enc.dispatch_thread_groups(grid_size, tg_size);"
6079 )?;
6080 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6081 writeln!(code, " }}")?;
6082 writeln!(code)?;
6083
6084 writeln!(
6086 code,
6087 " /// Copy from src at byte offset to dst at float offset."
6088 )?;
6089 writeln!(
6090 code,
6091 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6092 )?;
6093 writeln!(code, " let n: u32 = count as u32;")?;
6094 writeln!(
6095 code,
6096 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6097 )?;
6098 writeln!(
6099 code,
6100 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6101 )?;
6102 writeln!(
6103 code,
6104 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6105 )?;
6106 writeln!(
6107 code,
6108 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6109 )?;
6110 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6111 writeln!(
6112 code,
6113 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6114 )?;
6115 writeln!(
6116 code,
6117 " enc.dispatch_thread_groups(grid_size, tg_size);"
6118 )?;
6119 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6120 writeln!(code, " }}")?;
6121 writeln!(code)?;
6122
6123 writeln!(
6125 code,
6126 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6127 )?;
6128 writeln!(
6129 code,
6130 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6131 )?;
6132 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6133 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6134 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6135 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6136 writeln!(code, " let so: u32 = src_offset as u32;")?;
6137 writeln!(
6138 code,
6139 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6140 )?;
6141 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6142 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6143 writeln!(
6144 code,
6145 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6146 )?;
6147 writeln!(
6148 code,
6149 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6150 )?;
6151 writeln!(
6152 code,
6153 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6154 )?;
6155 writeln!(
6156 code,
6157 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6158 )?;
6159 writeln!(
6160 code,
6161 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6162 )?;
6163 writeln!(code, " let total = num_tokens * kv_dim;")?;
6164 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6165 writeln!(
6166 code,
6167 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6168 )?;
6169 writeln!(
6170 code,
6171 " enc.dispatch_thread_groups(grid_size, tg_size);"
6172 )?;
6173 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6174 writeln!(code, " }}")?;
6175 writeln!(code)?;
6176
6177 writeln!(
6179 code,
6180 " /// Dispatch batched causal attention: one dispatch for all M tokens."
6181 )?;
6182 writeln!(
6183 code,
6184 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6185 )?;
6186 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6187 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6188 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
6189 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6190 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6191 writeln!(code, " let qs: u32 = q_stride as u32;")?;
6192 writeln!(code, " let max_seq = base_pos + num_tokens;")?;
6202 writeln!(code, " let _ = max_seq;")?;
6203 writeln!(
6204 code,
6205 " let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6206 )?;
6207 writeln!(code, " .map(|v| v == \"0\").unwrap_or(false);")?;
6208 writeln!(
6209 code,
6210 " let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6211 )?;
6212 writeln!(code, " if use_mma_flash {{")?;
6213 writeln!(
6214 code,
6215 " let pipe = &self.attention_mma_flash_batch_pipeline;"
6216 )?;
6217 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6218 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6219 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6220 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6221 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6222 writeln!(
6223 code,
6224 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6225 )?;
6226 writeln!(
6227 code,
6228 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6229 )?;
6230 writeln!(
6231 code,
6232 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6233 )?;
6234 writeln!(
6235 code,
6236 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6237 )?;
6238 writeln!(
6239 code,
6240 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6241 )?;
6242 writeln!(
6243 code,
6244 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6245 )?;
6246 writeln!(
6247 code,
6248 " // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6249 )?;
6250 writeln!(code, " let tg_size = MTLSize::new(128, 1, 1);")?;
6251 writeln!(
6252 code,
6253 " let q_blocks = ((num_tokens + 7) / 8) as u64;"
6254 )?;
6255 writeln!(
6256 code,
6257 " let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6258 )?;
6259 writeln!(
6260 code,
6261 " enc.dispatch_thread_groups(grid_size, tg_size);"
6262 )?;
6263 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6264 writeln!(code, " return;")?;
6265 writeln!(code, " }}")?;
6266 writeln!(code, " let pipe = &self.attention_batch_pipeline;")?;
6267 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6268 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6269 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6270 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6271 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6272 writeln!(
6273 code,
6274 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6275 )?;
6276 writeln!(
6277 code,
6278 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6279 )?;
6280 writeln!(
6281 code,
6282 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6283 )?;
6284 writeln!(
6285 code,
6286 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6287 )?;
6288 writeln!(
6289 code,
6290 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6291 )?;
6292 writeln!(
6293 code,
6294 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6295 )?;
6296 writeln!(
6297 code,
6298 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6299 )?;
6300 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6301 writeln!(
6302 code,
6303 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6304 )?;
6305 writeln!(
6306 code,
6307 " enc.dispatch_thread_groups(grid_size, tg_size);"
6308 )?;
6309 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6310 writeln!(code, " }}")?;
6311 writeln!(code)?;
6312
6313 writeln!(
6315 code,
6316 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6317 )?;
6318 writeln!(
6319 code,
6320 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6321 )?;
6322 writeln!(
6323 code,
6324 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6325 )?;
6326 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6327 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6328 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
6329 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6330 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6331 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
6332 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
6333 writeln!(
6334 code,
6335 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6336 )?;
6337 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
6338 writeln!(
6339 code,
6340 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6341 )?;
6342 writeln!(
6343 code,
6344 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6345 )?;
6346 writeln!(
6347 code,
6348 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6349 )?;
6350 writeln!(
6351 code,
6352 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6353 )?;
6354 writeln!(
6355 code,
6356 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6357 )?;
6358 writeln!(
6359 code,
6360 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6361 )?;
6362 writeln!(
6363 code,
6364 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6365 )?;
6366 writeln!(
6367 code,
6368 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6369 )?;
6370 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6371 writeln!(
6372 code,
6373 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6374 )?;
6375 writeln!(
6376 code,
6377 " enc.dispatch_thread_groups(grid_size, tg_size);"
6378 )?;
6379 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6380 writeln!(code, " }}")?;
6381 writeln!(code)?;
6382
6383 writeln!(
6385 code,
6386 " /// Dispatch fused K+V cache copy in one kernel launch."
6387 )?;
6388 writeln!(
6389 code,
6390 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6391 )?;
6392 writeln!(
6393 code,
6394 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6395 )?;
6396 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6397 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6398 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6399 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6400 writeln!(code, " let ko: u32 = k_offset as u32;")?;
6401 writeln!(code, " let vo: u32 = v_offset as u32;")?;
6402 writeln!(
6403 code,
6404 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6405 )?;
6406 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6407 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
6408 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
6409 writeln!(
6410 code,
6411 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6412 )?;
6413 writeln!(
6414 code,
6415 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6416 )?;
6417 writeln!(
6418 code,
6419 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6420 )?;
6421 writeln!(
6422 code,
6423 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6424 )?;
6425 writeln!(
6426 code,
6427 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6428 )?;
6429 writeln!(
6430 code,
6431 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6432 )?;
6433 writeln!(
6434 code,
6435 " let total = num_tokens * kv_dim * 2; // K + V"
6436 )?;
6437 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6438 writeln!(
6439 code,
6440 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6441 )?;
6442 writeln!(
6443 code,
6444 " enc.dispatch_thread_groups(grid_size, tg_size);"
6445 )?;
6446 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6447 writeln!(code, " }}")?;
6448
6449 writeln!(code, "}}")?;
6450 writeln!(code)?;
6451
6452 Ok(())
6453}
6454
6455fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6456 writeln!(
6457 code,
6458 "// ── Helper functions ──────────────────────────────────"
6459 )?;
6460 writeln!(code)?;
6461 writeln!(
6462 code,
6463 "/// Create a compute pipeline from a named function in the Metal library."
6464 )?;
6465 writeln!(
6466 code,
6467 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6468 )?;
6469 writeln!(
6470 code,
6471 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6472 )?;
6473 writeln!(
6474 code,
6475 " device.new_compute_pipeline_state_with_function(&func)"
6476 )?;
6477 writeln!(
6478 code,
6479 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6480 )?;
6481 writeln!(code, "}}")?;
6482 writeln!(code)?;
6483
6484 Ok(())
6485}
6486
6487fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6492 let _sanitized = sanitize_name(model_name);
6493 let _vocab = config.vocab_size;
6494
6495 let mut code = String::with_capacity(16 * 1024);
6496 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6497 writeln!(
6498 code,
6499 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6500 )?;
6501 writeln!(code)?;
6502 writeln!(code, "mod model;")?;
6503 writeln!(code)?;
6504 writeln!(code, "use std::io::Write;")?;
6505 writeln!(code, "use std::time::Instant;")?;
6506 writeln!(code, "use serde::Deserialize;")?;
6507 writeln!(code)?;
6508
6509 writeln!(code, "fn main() {{")?;
6511 writeln!(
6512 code,
6513 " let args: Vec<String> = std::env::args().collect();"
6514 )?;
6515 writeln!(code)?;
6516 writeln!(
6517 code,
6518 " // Detect --serve mode (only requires weights + tokenizer)"
6519 )?;
6520 writeln!(
6521 code,
6522 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
6523 )?;
6524 writeln!(code)?;
6525 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
6526 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6527 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6528 writeln!(code, " std::process::exit(1);")?;
6529 writeln!(code, " }}")?;
6530 writeln!(code)?;
6531 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
6532 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6533 writeln!(code, " std::process::exit(1);")?;
6534 writeln!(code, " }}")?;
6535 writeln!(code)?;
6536 writeln!(code, " let weights_path = &args[1];")?;
6537 writeln!(code, " let tokenizer_path = &args[2];")?;
6538 writeln!(code)?;
6539 writeln!(code, " // Parse optional flags")?;
6540 writeln!(code, " let mut max_tokens: usize = 128;")?;
6541 writeln!(code, " let mut port: u16 = 8080;")?;
6542 writeln!(
6543 code,
6544 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6545 )?;
6546 writeln!(
6547 code,
6548 " let profile = args.iter().any(|a| a == \"--profile\");"
6549 )?;
6550 writeln!(code, " let mut i = 3;")?;
6551 writeln!(code, " while i < args.len() {{")?;
6552 writeln!(
6553 code,
6554 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6555 )?;
6556 writeln!(
6557 code,
6558 " max_tokens = args[i + 1].parse().unwrap_or(128);"
6559 )?;
6560 writeln!(code, " i += 2;")?;
6561 writeln!(
6562 code,
6563 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6564 )?;
6565 writeln!(
6566 code,
6567 " port = args[i + 1].parse().unwrap_or(8080);"
6568 )?;
6569 writeln!(code, " i += 2;")?;
6570 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
6571 writeln!(code, " i += 1;")?;
6572 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
6573 writeln!(code, " i += 1;")?;
6574 writeln!(code, " }} else {{")?;
6575 writeln!(code, " i += 1;")?;
6576 writeln!(code, " }}")?;
6577 writeln!(code, " }}")?;
6578 writeln!(code)?;
6579
6580 writeln!(
6582 code,
6583 " // Memory-map weights for zero-copy loading on Apple Silicon"
6584 )?;
6585 writeln!(
6586 code,
6587 " let weights_file = std::fs::File::open(weights_path)"
6588 )?;
6589 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6590 writeln!(
6591 code,
6592 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6593 )?;
6594 writeln!(code)?;
6595 writeln!(code, " // Load tokenizer")?;
6596 writeln!(
6597 code,
6598 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6599 )?;
6600 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6601 writeln!(code)?;
6602 writeln!(code, " // Create Metal model")?;
6603 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
6604 writeln!(
6605 code,
6606 " let mut model = model::MetalModel::new(&weights_mmap);"
6607 )?;
6608 writeln!(code)?;
6609
6610 writeln!(code, " if serve_mode {{")?;
6612 writeln!(code, " serve(model, tokenizer, port);")?;
6613 writeln!(code, " }} else {{")?;
6614 writeln!(code, " let prompt = &args[3];")?;
6615 writeln!(
6616 code,
6617 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6618 )?;
6619 writeln!(code, " }}")?;
6620 writeln!(code, "}}")?;
6621 writeln!(code)?;
6622
6623 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6625 writeln!(code, " // Tokenize prompt")?;
6626 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
6627 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6628 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6629 writeln!(code)?;
6630 writeln!(
6631 code,
6632 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6633 )?;
6634 writeln!(
6635 code,
6636 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
6637 )?;
6638 writeln!(
6639 code,
6640 " // The last token uses synchronous forward() to get logits."
6641 )?;
6642 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
6643 writeln!(code, " let prefill_start = Instant::now();")?;
6644 writeln!(code, " let logits = if prompt_len > 1 {{")?;
6645 writeln!(
6646 code,
6647 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6648 )?;
6649 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
6650 writeln!(code, " }} else {{")?;
6651 writeln!(code, " model.forward(prompt_tokens[0])")?;
6652 writeln!(code, " }};")?;
6653 writeln!(
6654 code,
6655 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6656 )?;
6657 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
6658 writeln!(
6659 code,
6660 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6661 )?;
6662 writeln!(
6663 code,
6664 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6665 )?;
6666 writeln!(code)?;
6667 writeln!(code, " // Generate tokens")?;
6668 writeln!(code, " let mut next_token = argmax(&logits);")?;
6669 writeln!(code, " let gen_start = Instant::now();")?;
6670 writeln!(code, " let mut generated_count: usize = 0;")?;
6671 writeln!(code)?;
6672 writeln!(code, " for _ in 0..max_tokens {{")?;
6673 writeln!(
6674 code,
6675 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6676 )?;
6677 writeln!(code, " if !quiet {{")?;
6678 writeln!(code, " print!(\"{{}}\", text);")?;
6679 writeln!(code, " std::io::stdout().flush().ok();")?;
6680 writeln!(code, " }}")?;
6681 writeln!(code, " }}")?;
6682 writeln!(code, " generated_count += 1;")?;
6683 writeln!(code)?;
6684 writeln!(
6685 code,
6686 " // Use profiling forward for first token when --profile is set"
6687 )?;
6688 writeln!(
6689 code,
6690 " let logits = if profile && generated_count == 1 {{"
6691 )?;
6692 writeln!(code, " model.forward_profile(next_token)")?;
6693 writeln!(code, " }} else {{")?;
6694 writeln!(code, " model.forward(next_token)")?;
6695 writeln!(code, " }};")?;
6696 writeln!(code, " next_token = argmax(&logits);")?;
6697 writeln!(code)?;
6698 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
6699 writeln!(code, " if next_token == 2 {{")?;
6700 writeln!(code, " break;")?;
6701 writeln!(code, " }}")?;
6702 writeln!(code)?;
6703 writeln!(
6704 code,
6705 " // Yield between tokens to reduce sustained GPU thermal load."
6706 )?;
6707 writeln!(
6708 code,
6709 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6710 )?;
6711 writeln!(
6712 code,
6713 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6714 )?;
6715 writeln!(
6716 code,
6717 " // briefly, providing a micro-break that helps sustain peak throughput."
6718 )?;
6719 writeln!(code, " std::thread::yield_now();")?;
6720 writeln!(code, " }}")?;
6721 writeln!(code, " if !quiet {{")?;
6722 writeln!(code, " println!();")?;
6723 writeln!(code, " }}")?;
6724 writeln!(
6725 code,
6726 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6727 )?;
6728 writeln!(
6729 code,
6730 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6731 )?;
6732 writeln!(
6733 code,
6734 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6735 )?;
6736 writeln!(code, "}}")?;
6737 writeln!(code)?;
6738
6739 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6741 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6742 writeln!(code, " logits.iter()")?;
6743 writeln!(code, " .enumerate()")?;
6744 writeln!(
6745 code,
6746 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6747 )?;
6748 writeln!(code, " .map(|(i, _)| i as u32)")?;
6749 writeln!(code, " .unwrap_or(0)")?;
6750 writeln!(code, "}}")?;
6751 writeln!(code)?;
6752
6753 writeln!(
6755 code,
6756 "// -----------------------------------------------------------------------"
6757 )?;
6758 writeln!(code, "// OpenAI-compatible API server")?;
6759 writeln!(
6760 code,
6761 "// -----------------------------------------------------------------------"
6762 )?;
6763 writeln!(code)?;
6764 writeln!(code, "#[derive(Deserialize)]")?;
6765 writeln!(code, "struct ChatRequest {{")?;
6766 writeln!(code, " messages: Vec<ChatMessage>,")?;
6767 writeln!(code, " #[serde(default)]")?;
6768 writeln!(code, " stream: Option<bool>,")?;
6769 writeln!(code, " #[serde(default)]")?;
6770 writeln!(code, " max_tokens: Option<usize>,")?;
6771 writeln!(code, " #[serde(default)]")?;
6772 writeln!(code, " temperature: Option<f32>,")?;
6773 writeln!(code, " #[serde(default)]")?;
6774 writeln!(code, " model: Option<String>,")?;
6775 writeln!(code, "}}")?;
6776 writeln!(code)?;
6777 writeln!(code, "#[derive(Deserialize)]")?;
6778 writeln!(code, "struct ChatMessage {{")?;
6779 writeln!(code, " role: String,")?;
6780 writeln!(code, " content: String,")?;
6781 writeln!(code, "}}")?;
6782 writeln!(code)?;
6783
6784 writeln!(
6786 code,
6787 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6788 )?;
6789 writeln!(code, " let mut prompt = String::new();")?;
6790 writeln!(code, " for msg in messages {{")?;
6791 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6792 writeln!(code, " }}")?;
6793 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6794 writeln!(code, " prompt")?;
6795 writeln!(code, "}}")?;
6796 writeln!(code)?;
6797
6798 writeln!(
6800 code,
6801 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6802 )?;
6803 writeln!(code, " let len = tokens.len();")?;
6804 writeln!(code, " if len > 1 {{")?;
6805 writeln!(
6806 code,
6807 " model.forward_prefill_batch(&tokens[..len - 1]);"
6808 )?;
6809 writeln!(code, " }}")?;
6810 writeln!(code, " model.forward(tokens[len - 1])")?;
6811 writeln!(code, "}}")?;
6812 writeln!(code)?;
6813
6814 writeln!(
6816 code,
6817 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6818 )?;
6819 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6820 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
6821 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6822 writeln!(
6823 code,
6824 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6825 )?;
6826 writeln!(code, " eprintln!(\"Endpoints:\");")?;
6827 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
6828 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
6829 writeln!(code, " eprintln!(\" GET /health\");")?;
6830 writeln!(code)?;
6831 writeln!(code, " for request in server.incoming_requests() {{")?;
6832 writeln!(code, " let method = request.method().to_string();")?;
6833 writeln!(code, " let url = request.url().to_string();")?;
6834 writeln!(code)?;
6835 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
6836
6837 writeln!(
6839 code,
6840 " (\"POST\", \"/v1/chat/completions\") => {{"
6841 )?;
6842 writeln!(
6843 code,
6844 " handle_chat_completion(&mut model, &tokenizer, request);"
6845 )?;
6846 writeln!(code, " }}")?;
6847
6848 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
6850 writeln!(code, " let body = serde_json::json!({{")?;
6851 writeln!(code, " \"object\": \"list\",")?;
6852 writeln!(code, " \"data\": [{{")?;
6853 writeln!(code, " \"id\": \"forgellm-metal\",")?;
6854 writeln!(code, " \"object\": \"model\",")?;
6855 writeln!(code, " \"owned_by\": \"forgellm\"")?;
6856 writeln!(code, " }}]")?;
6857 writeln!(code, " }});")?;
6858 writeln!(
6859 code,
6860 " let resp = tiny_http::Response::from_string(body.to_string())"
6861 )?;
6862 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6863 writeln!(code, " request.respond(resp).ok();")?;
6864 writeln!(code, " }}")?;
6865
6866 writeln!(code, " (\"GET\", \"/health\") => {{")?;
6868 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6869 writeln!(code, " request.respond(resp).ok();")?;
6870 writeln!(code, " }}")?;
6871
6872 writeln!(code, " _ => {{")?;
6874 writeln!(
6875 code,
6876 " let resp = tiny_http::Response::from_string(\"Not Found\")"
6877 )?;
6878 writeln!(code, " .with_status_code(404);")?;
6879 writeln!(code, " request.respond(resp).ok();")?;
6880 writeln!(code, " }}")?;
6881 writeln!(code, " }}")?;
6882 writeln!(code, " }}")?;
6883 writeln!(code, "}}")?;
6884 writeln!(code)?;
6885
6886 writeln!(code, "fn handle_chat_completion(")?;
6888 writeln!(code, " model: &mut model::MetalModel,")?;
6889 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
6890 writeln!(code, " mut request: tiny_http::Request,")?;
6891 writeln!(code, ") {{")?;
6892 writeln!(code, " // Read request body")?;
6893 writeln!(code, " let mut body = String::new();")?;
6894 writeln!(
6895 code,
6896 " if request.as_reader().read_to_string(&mut body).is_err() {{"
6897 )?;
6898 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6899 writeln!(code, " .with_status_code(400);")?;
6900 writeln!(code, " request.respond(resp).ok();")?;
6901 writeln!(code, " return;")?;
6902 writeln!(code, " }}")?;
6903 writeln!(code)?;
6904 writeln!(code, " // Parse JSON")?;
6905 writeln!(
6906 code,
6907 " let req: ChatRequest = match serde_json::from_str(&body) {{"
6908 )?;
6909 writeln!(code, " Ok(r) => r,")?;
6910 writeln!(code, " Err(e) => {{")?;
6911 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6912 writeln!(code, " .with_status_code(400);")?;
6913 writeln!(code, " request.respond(resp).ok();")?;
6914 writeln!(code, " return;")?;
6915 writeln!(code, " }}")?;
6916 writeln!(code, " }};")?;
6917 writeln!(code)?;
6918 writeln!(
6919 code,
6920 " let prompt = format_chat_messages(&req.messages);"
6921 )?;
6922 writeln!(
6923 code,
6924 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6925 )?;
6926 writeln!(code, " Ok(e) => e,")?;
6927 writeln!(code, " Err(e) => {{")?;
6928 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6929 writeln!(code, " .with_status_code(500);")?;
6930 writeln!(code, " request.respond(resp).ok();")?;
6931 writeln!(code, " return;")?;
6932 writeln!(code, " }}")?;
6933 writeln!(code, " }};")?;
6934 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6935 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
6936 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
6937 writeln!(
6938 code,
6939 " let _temperature = req.temperature.unwrap_or(1.0);"
6940 )?;
6941 writeln!(code)?;
6942
6943 writeln!(code, " model.reset();")?;
6945 writeln!(code)?;
6946
6947 writeln!(code, " let prefill_start = Instant::now();")?;
6949 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
6950 writeln!(
6951 code,
6952 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6953 )?;
6954 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
6955 writeln!(code, " let mut next_token = argmax(&logits);")?;
6956 writeln!(code)?;
6957
6958 writeln!(code, " if stream {{")?;
6959
6960 writeln!(
6962 code,
6963 " // SSE streaming: generate tokens and build SSE body"
6964 )?;
6965 writeln!(code, " let gen_start = Instant::now();")?;
6966 writeln!(code, " let mut generated_count: usize = 0;")?;
6967 writeln!(code, " let mut sse_body = String::new();")?;
6968 writeln!(code, " for _ in 0..max_tokens {{")?;
6969 writeln!(code, " if next_token == 2 {{ break; }}")?;
6970 writeln!(
6971 code,
6972 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6973 )?;
6974 writeln!(
6975 code,
6976 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
6977 )?;
6978 writeln!(
6979 code,
6980 " // escaped includes surrounding quotes, strip them"
6981 )?;
6982 writeln!(
6983 code,
6984 " let inner = &escaped[1..escaped.len()-1];"
6985 )?;
6986 writeln!(code, " sse_body.push_str(&format!(")?;
6987 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6988 writeln!(code, " inner")?;
6989 writeln!(code, " ));")?;
6990 writeln!(code, " }}")?;
6991 writeln!(code, " generated_count += 1;")?;
6992 writeln!(code, " let logits = model.forward(next_token);")?;
6993 writeln!(code, " next_token = argmax(&logits);")?;
6994 writeln!(code, " }}")?;
6995 writeln!(
6996 code,
6997 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6998 )?;
6999 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7000 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
7001 writeln!(code)?;
7002 writeln!(
7003 code,
7004 " // Final chunk with finish_reason, timing, and DONE sentinel"
7005 )?;
7006 writeln!(code, " sse_body.push_str(&format!(")?;
7007 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
7008 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
7009 writeln!(code, " ));")?;
7010 writeln!(code)?;
7011 writeln!(
7012 code,
7013 " let resp = tiny_http::Response::from_string(sse_body)"
7014 )?;
7015 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
7016 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
7017 writeln!(code, " request.respond(resp).ok();")?;
7018
7019 writeln!(code, " }} else {{")?;
7020
7021 writeln!(
7023 code,
7024 " // Non-streaming: generate all tokens, return JSON"
7025 )?;
7026 writeln!(code, " let gen_start = Instant::now();")?;
7027 writeln!(code, " let mut generated_count: usize = 0;")?;
7028 writeln!(code, " let mut generated = String::new();")?;
7029 writeln!(code, " for _ in 0..max_tokens {{")?;
7030 writeln!(code, " if next_token == 2 {{ break; }}")?;
7031 writeln!(
7032 code,
7033 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7034 )?;
7035 writeln!(code, " generated.push_str(&text);")?;
7036 writeln!(code, " }}")?;
7037 writeln!(code, " generated_count += 1;")?;
7038 writeln!(code, " let logits = model.forward(next_token);")?;
7039 writeln!(code, " next_token = argmax(&logits);")?;
7040 writeln!(code, " }}")?;
7041 writeln!(
7042 code,
7043 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7044 )?;
7045 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7046 writeln!(code)?;
7047 writeln!(code, " let resp_json = serde_json::json!({{")?;
7048 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
7049 writeln!(code, " \"object\": \"chat.completion\",")?;
7050 writeln!(code, " \"choices\": [{{")?;
7051 writeln!(code, " \"index\": 0,")?;
7052 writeln!(code, " \"message\": {{")?;
7053 writeln!(code, " \"role\": \"assistant\",")?;
7054 writeln!(code, " \"content\": generated")?;
7055 writeln!(code, " }},")?;
7056 writeln!(code, " \"finish_reason\": \"stop\"")?;
7057 writeln!(code, " }}],")?;
7058 writeln!(code, " \"usage\": {{")?;
7059 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
7060 writeln!(
7061 code,
7062 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7063 )?;
7064 writeln!(
7065 code,
7066 " \"generation_tokens\": generated_count,"
7067 )?;
7068 writeln!(
7069 code,
7070 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7071 )?;
7072 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
7073 writeln!(code, " }}")?;
7074 writeln!(code, " }});")?;
7075 writeln!(
7076 code,
7077 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
7078 )?;
7079 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7080 writeln!(code, " request.respond(resp).ok();")?;
7081 writeln!(code, " }}")?;
7082 writeln!(code, "}}")?;
7083
7084 Ok(code)
7085}
7086
7087#[cfg(test)]
7092mod tests {
7093 use super::*;
7094 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7095
7096 fn minimal_config() -> ModelConfig {
7097 ModelConfig {
7098 architecture: Architecture::Llama,
7099 hidden_size: 64,
7100 intermediate_size: 128,
7101 num_layers: 2,
7102 num_attention_heads: 4,
7103 num_kv_heads: 4,
7104 head_dim: 16,
7105 vocab_size: 256,
7106 max_seq_len: 512,
7107 rms_norm_eps: 1e-5,
7108 rope_theta: 10000.0,
7109 dtype: DType::F32,
7110 sliding_window_size: None,
7111 qkv_bias: false,
7112 hidden_activation: HiddenActivation::SiLU,
7113 }
7114 }
7115
7116 fn minimal_graph() -> Graph {
7117 Graph::new("test-metal").with_config(minimal_config())
7118 }
7119
7120 #[test]
7121 fn generate_metal_project_creates_files() {
7122 let dir = tempfile::tempdir().unwrap();
7123 let graph = minimal_graph();
7124 generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7125
7126 assert!(
7127 dir.path().join("Cargo.toml").exists(),
7128 "Cargo.toml should be created"
7129 );
7130 assert!(
7131 dir.path().join("src/model.rs").exists(),
7132 "src/model.rs should be created"
7133 );
7134 assert!(
7135 dir.path().join("src/main.rs").exists(),
7136 "src/main.rs should be created"
7137 );
7138 assert!(
7139 dir.path().join("shaders/kernels.metal").exists(),
7140 "shaders/kernels.metal should be created"
7141 );
7142 }
7143
7144 #[test]
7145 fn generated_cargo_toml_has_metal_dep() {
7146 let toml = generate_cargo_toml("my-model");
7147 assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7148 assert!(
7149 toml.contains("tokenizers"),
7150 "Cargo.toml should depend on tokenizers"
7151 );
7152 assert!(
7153 toml.contains("memmap2"),
7154 "Cargo.toml should depend on memmap2"
7155 );
7156 assert!(toml.contains("half"), "Cargo.toml should depend on half");
7157 }
7158
7159 #[test]
7160 fn generated_model_rs_contains_metal_code() {
7161 let config = minimal_config();
7162 let model_rs = generate_model_rs(&config).unwrap();
7163
7164 assert!(
7165 model_rs.contains("pub struct MetalModel"),
7166 "model.rs should define MetalModel struct"
7167 );
7168 assert!(
7169 model_rs.contains("matmul_pipeline: ComputePipelineState"),
7170 "MetalModel should have matmul_pipeline field"
7171 );
7172 assert!(
7173 model_rs.contains("Device::system_default()"),
7174 "model.rs should use Metal device"
7175 );
7176 assert!(
7177 model_rs.contains("new_library_with_source"),
7178 "model.rs should compile Metal shaders"
7179 );
7180 assert!(
7181 model_rs.contains("fn new(weights: &[u8])"),
7182 "MetalModel should implement new()"
7183 );
7184 assert!(
7185 model_rs.contains("fn forward(&mut self, token_id: u32)"),
7186 "MetalModel should implement forward()"
7187 );
7188 }
7189
7190 #[test]
7191 fn generated_shaders_contain_kernel_names() {
7192 let shaders = generate_metal_shaders(&minimal_config());
7193
7194 assert!(
7195 shaders.contains("kernel void matmul_vec"),
7196 "shaders should contain matmul_vec kernel"
7197 );
7198 assert!(
7199 shaders.contains("kernel void rms_norm"),
7200 "shaders should contain rms_norm kernel"
7201 );
7202 assert!(
7203 shaders.contains("kernel void rope"),
7204 "shaders should contain rope kernel"
7205 );
7206 assert!(
7207 shaders.contains("kernel void softmax"),
7208 "shaders should contain softmax kernel"
7209 );
7210 assert!(
7211 shaders.contains("kernel void silu_mul("),
7212 "shaders should contain silu_mul kernel"
7213 );
7214 assert!(
7215 shaders.contains("kernel void silu_mul_fused"),
7216 "shaders should contain silu_mul_fused kernel"
7217 );
7218 assert!(
7219 shaders.contains("kernel void elementwise_add"),
7220 "shaders should contain elementwise_add kernel"
7221 );
7222 assert!(
7223 shaders.contains("kernel void attention"),
7224 "shaders should contain attention kernel"
7225 );
7226 assert!(
7227 shaders.contains("kernel void add_inplace"),
7228 "shaders should contain add_inplace kernel"
7229 );
7230 assert!(
7231 shaders.contains("kernel void copy_buffer"),
7232 "shaders should contain copy_buffer kernel"
7233 );
7234 assert!(
7235 shaders.contains("kernel void copy_offset"),
7236 "shaders should contain copy_offset kernel"
7237 );
7238 }
7239
7240 #[test]
7241 fn generated_shaders_use_simdgroup_features() {
7242 let shaders = generate_metal_shaders(&minimal_config());
7243
7244 assert!(
7245 shaders.contains("threadgroup_barrier"),
7246 "shaders should use threadgroup barriers"
7247 );
7248 assert!(
7249 shaders.contains("threadgroup float"),
7250 "shaders should use threadgroup shared memory"
7251 );
7252 assert!(
7253 shaders.contains("thread_index_in_threadgroup"),
7254 "shaders should use threadgroup indexing"
7255 );
7256 assert!(
7257 shaders.contains("simd_sum"),
7258 "shaders should use simd_sum for warp-level reduction"
7259 );
7260 assert!(
7261 shaders.contains("simd_max"),
7262 "attention kernel should use simd_max for cooperative softmax"
7263 );
7264 assert!(
7265 shaders.contains("thread_index_in_simdgroup"),
7266 "shaders should use simdgroup lane indexing"
7267 );
7268 assert!(
7269 shaders.contains("simdgroup_index_in_threadgroup"),
7270 "shaders should use simdgroup indexing within threadgroup"
7271 );
7272 assert!(
7273 shaders.contains("float4"),
7274 "matmul_vec should use float4 vectorized loads"
7275 );
7276 }
7277
7278 #[test]
7279 fn generated_main_rs_has_tokenizer_usage() {
7280 let config = minimal_config();
7281 let main_rs = generate_main_rs("test-model", &config).unwrap();
7282
7283 assert!(
7284 main_rs.contains("tokenizers::Tokenizer"),
7285 "main.rs should use tokenizers crate"
7286 );
7287 assert!(
7288 main_rs.contains("MetalModel::new"),
7289 "main.rs should call MetalModel::new"
7290 );
7291 assert!(
7292 main_rs.contains("model.forward"),
7293 "main.rs should call model.forward"
7294 );
7295 assert!(
7296 main_rs.contains("memmap2"),
7297 "main.rs should use memmap2 for zero-copy weight loading"
7298 );
7299 }
7300
7301 #[test]
7302 fn missing_config_returns_error() {
7303 let dir = tempfile::tempdir().unwrap();
7304 let graph = Graph::new("no-config");
7305 let result = generate_metal_project(&graph, dir.path(), "fail");
7306 assert!(
7307 matches!(result, Err(MetalCodegenError::MissingConfig)),
7308 "should fail with MissingConfig when graph has no config"
7309 );
7310 }
7311
7312 #[test]
7313 fn sanitize_name_works() {
7314 assert_eq!(sanitize_name("My Model!"), "my-model");
7315 assert_eq!(sanitize_name("test_model"), "test-model");
7316 assert_eq!(sanitize_name("simple"), "simple");
7317 }
7318
7319 #[test]
7320 fn generated_forward_uses_single_command_buffer() {
7321 let config = minimal_config();
7322 let model_rs = generate_model_rs(&config).unwrap();
7323
7324 let forward_start = model_rs
7327 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7328 .unwrap();
7329 let forward_body = &model_rs[forward_start..];
7330 let forward_end = forward_body
7332 .find("\n pub fn forward_profile")
7333 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7334 .or_else(|| forward_body.find("\n fn dispatch_"))
7335 .unwrap_or(forward_body.len());
7336 let forward_code = &forward_body[..forward_end];
7337
7338 let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7340 assert_eq!(
7341 cmd_buf_count, 1,
7342 "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7343 );
7344
7345 let commit_count = forward_code.matches("cmd.commit()").count();
7347 assert_eq!(
7348 commit_count, 1,
7349 "forward() should commit exactly once, found {commit_count}"
7350 );
7351
7352 let wait_count = forward_code.matches("wait_until_completed()").count();
7354 assert!(
7355 wait_count >= 1 && wait_count <= 2,
7356 "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7357 );
7358 }
7359
7360 #[test]
7361 fn generated_model_has_preallocated_working_buffers() {
7362 let config = minimal_config();
7363 let model_rs = generate_model_rs(&config).unwrap();
7364
7365 for buf_name in &[
7366 "normed_buf",
7367 "qkv_buf",
7368 "attn_out_buf",
7369 "attn_proj_buf",
7370 "gate_up_buf",
7371 "ffn_hidden_buf",
7372 "ffn_out_buf",
7373 "add_tmp_buf",
7374 ] {
7375 assert!(
7376 model_rs.contains(&format!("{buf_name}: Buffer")),
7377 "MetalModel should have pre-allocated {buf_name} field"
7378 );
7379 }
7380 }
7381
7382 #[test]
7383 fn generated_dispatch_helpers_take_compute_encoder_ref() {
7384 let config = minimal_config();
7385 let model_rs = generate_model_rs(&config).unwrap();
7386
7387 for method in &[
7388 "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7389 "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7390 "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7391 "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7392 "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7393 "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7394 "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7395 "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7396 "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7397 "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7398 "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7399 "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7400 "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7401 ] {
7402 assert!(
7403 model_rs.contains(method),
7404 "model.rs should contain dispatch helper: {method}"
7405 );
7406 }
7407 }
7408
7409 #[test]
7410 fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7411 let config = minimal_config();
7412 let model_rs = generate_model_rs(&config).unwrap();
7413
7414 let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7416 let helpers_code = &model_rs[helpers_start..];
7417
7418 assert!(
7420 !helpers_code.contains("self.queue.new_command_buffer()"),
7421 "dispatch helpers should not create their own command buffers"
7422 );
7423
7424 assert!(
7426 !helpers_code.contains("new_compute_command_encoder()"),
7427 "dispatch helpers should not create their own compute encoders"
7428 );
7429
7430 assert!(
7432 !helpers_code.contains("end_encoding()"),
7433 "dispatch helpers should not call end_encoding"
7434 );
7435
7436 assert!(
7438 !helpers_code.contains(".commit()"),
7439 "dispatch helpers should not commit command buffers"
7440 );
7441 assert!(
7442 !helpers_code.contains("wait_until_completed"),
7443 "dispatch helpers should not wait on command buffers"
7444 );
7445 }
7446
7447 #[test]
7448 fn generated_forward_batches_compute_encoders() {
7449 let config = minimal_config();
7450 let model_rs = generate_model_rs(&config).unwrap();
7451
7452 let forward_start = model_rs
7454 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7455 .unwrap();
7456 let forward_body = &model_rs[forward_start..];
7457 let forward_end = forward_body
7458 .find("\n pub fn forward_profile")
7459 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7460 .or_else(|| forward_body.find("\n fn dispatch_"))
7461 .unwrap_or(forward_body.len());
7462 let forward_code = &forward_body[..forward_end];
7463
7464 assert!(
7466 !forward_code.contains("device.new_buffer"),
7467 "forward() should not allocate new buffers per call"
7468 );
7469
7470 let compute_encoder_count = forward_code
7473 .matches("new_compute_command_encoder()")
7474 .count();
7475 let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7476
7477 assert_eq!(
7479 compute_encoder_count, 1,
7480 "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7481 );
7482 assert_eq!(
7483 blit_encoder_count, 0,
7484 "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7485 );
7486 }
7487
7488 #[test]
7489 fn generated_forward_uses_add_inplace() {
7490 let config = minimal_config();
7491 let model_rs = generate_model_rs(&config).unwrap();
7492
7493 assert!(
7495 model_rs.contains("dispatch_add_inplace"),
7496 "forward() should use dispatch_add_inplace for residual connections"
7497 );
7498 assert!(
7499 model_rs.contains("add_inplace_pipeline"),
7500 "MetalModel should have add_inplace_pipeline"
7501 );
7502 }
7503
7504 fn minimal_q8_config() -> ModelConfig {
7505 ModelConfig {
7506 architecture: Architecture::Llama,
7507 hidden_size: 64,
7508 intermediate_size: 128,
7509 num_layers: 2,
7510 num_attention_heads: 4,
7511 num_kv_heads: 4,
7512 head_dim: 16,
7513 vocab_size: 256,
7514 max_seq_len: 512,
7515 rms_norm_eps: 1e-5,
7516 rope_theta: 10000.0,
7517 dtype: DType::Q8_0,
7518 sliding_window_size: None,
7519 qkv_bias: false,
7520 hidden_activation: HiddenActivation::SiLU,
7521 }
7522 }
7523
7524 #[test]
7525 fn generated_shaders_contain_q8_kernel() {
7526 let shaders = generate_metal_shaders(&minimal_config());
7527
7528 assert!(
7529 shaders.contains("kernel void matmul_vec_q8"),
7530 "shaders should contain matmul_vec_q8 kernel"
7531 );
7532 assert!(
7533 shaders.contains("device const uchar* matrix"),
7534 "matmul_vec_q8 should accept raw Q8_0 bytes"
7535 );
7536 assert!(
7537 shaders.contains("packed_short4"),
7538 "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7539 );
7540 assert!(
7541 shaders.contains("as_type<char2>"),
7542 "matmul_vec_q8 should bitcast short lanes to char2"
7543 );
7544 assert!(
7545 shaders.contains("device const half*"),
7546 "matmul_vec_q8 should read f16 scale via half pointer"
7547 );
7548 }
7549
7550 #[test]
7551 fn generated_model_uses_fused_qkv_projections() {
7552 let config = minimal_config();
7553 let model_rs = generate_model_rs(&config).unwrap();
7554
7555 assert!(
7557 model_rs.contains("qkv_weight: Buffer"),
7558 "LayerBuffers should have fused qkv_weight field"
7559 );
7560 assert!(
7562 !model_rs.contains(" q_weight: Buffer"),
7563 "LayerBuffers should not have separate q_weight field"
7564 );
7565 assert!(
7566 !model_rs.contains(" k_weight: Buffer"),
7567 "LayerBuffers should not have separate k_weight field"
7568 );
7569 assert!(
7570 !model_rs.contains(" v_weight: Buffer"),
7571 "LayerBuffers should not have separate v_weight field"
7572 );
7573
7574 assert!(
7576 model_rs.contains("gate_up_weight: Buffer"),
7577 "LayerBuffers should have fused gate_up_weight field"
7578 );
7579 assert!(
7581 !model_rs.contains(" gate_weight: Buffer"),
7582 "LayerBuffers should not have separate gate_weight field"
7583 );
7584 assert!(
7585 !model_rs.contains(" up_weight: Buffer"),
7586 "LayerBuffers should not have separate up_weight field"
7587 );
7588
7589 assert!(
7591 model_rs.contains("qkv_buf: Buffer"),
7592 "MetalModel should have fused qkv_buf"
7593 );
7594 assert!(
7595 model_rs.contains("gate_up_buf: Buffer"),
7596 "MetalModel should have fused gate_up_buf"
7597 );
7598
7599 assert!(
7601 model_rs.contains("dispatch_silu_mul_fused"),
7602 "forward pass should use dispatch_silu_mul_fused"
7603 );
7604 assert!(
7605 model_rs.contains("dispatch_rope_offset"),
7606 "forward pass should use dispatch_rope_offset for fused QKV"
7607 );
7608 assert!(
7609 model_rs.contains("dispatch_attention_offset"),
7610 "forward pass should use dispatch_attention_offset for fused QKV"
7611 );
7612 }
7613
7614 #[test]
7615 fn q8_model_has_matmul_q8_pipeline() {
7616 let config = minimal_q8_config();
7617 let model_rs = generate_model_rs(&config).unwrap();
7618
7619 assert!(
7620 model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7621 "MetalModel should have matmul_q8_pipeline field"
7622 );
7623 assert!(
7624 model_rs.contains("matmul_q8_pipeline,"),
7625 "MetalModel Self should include matmul_q8_pipeline"
7626 );
7627 }
7628
7629 #[test]
7630 fn q8_model_uses_dispatch_matmul_q8() {
7631 let config = minimal_q8_config();
7632 let model_rs = generate_model_rs(&config).unwrap();
7633
7634 assert!(
7635 model_rs.contains("dispatch_matmul_q8"),
7636 "Q8_0 model should use dispatch_matmul_q8 for projections"
7637 );
7638 assert!(
7639 model_rs.contains("fn dispatch_matmul_q8"),
7640 "model.rs should define dispatch_matmul_q8 method"
7641 );
7642 }
7643
7644 #[test]
7645 fn q8_model_loads_raw_bytes_not_dequantized() {
7646 let config = minimal_q8_config();
7647 let model_rs = generate_model_rs(&config).unwrap();
7648
7649 assert!(
7651 !model_rs.contains("f16_to_f32"),
7652 "Q8_0 model should not dequantize weights to f32"
7653 );
7654 assert!(
7655 !model_rs.contains("f32_data"),
7656 "Q8_0 model should not create f32 weight data"
7657 );
7658
7659 assert!(
7661 model_rs.contains("total_raw as u64"),
7662 "Q8_0 model should load raw bytes into Metal buffer"
7663 );
7664 }
7665
7666 #[test]
7667 fn q8_model_norms_stay_f32() {
7668 let config = minimal_q8_config();
7669 let model_rs = generate_model_rs(&config).unwrap();
7670
7671 assert!(
7673 model_rs.contains("let attn_norm = next_f32_buffer"),
7674 "attn_norm should use f32 buffer even for Q8_0 models"
7675 );
7676 assert!(
7677 model_rs.contains("let ffn_norm = next_f32_buffer"),
7678 "ffn_norm should use f32 buffer even for Q8_0 models"
7679 );
7680 assert!(
7681 model_rs.contains("let norm_buf = next_f32_buffer"),
7682 "final norm should use f32 buffer even for Q8_0 models"
7683 );
7684 }
7685
7686 #[test]
7687 fn q8_model_uses_fused_weight_loading() {
7688 let config = minimal_q8_config();
7689 let model_rs = generate_model_rs(&config).unwrap();
7690
7691 assert!(
7693 model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7694 "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7695 );
7696 assert!(
7698 model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7699 "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7700 );
7701 assert!(
7703 model_rs.contains("let o_weight = next_q8_buffer"),
7704 "Q8_0 model should use next_q8_buffer for O weight"
7705 );
7706 assert!(
7707 model_rs.contains("let down_weight = next_q8_buffer"),
7708 "Q8_0 model should use next_q8_buffer for down weight"
7709 );
7710 }
7711
7712 #[test]
7713 fn f32_model_does_not_use_q8_dispatch() {
7714 let config = minimal_config();
7715 let model_rs = generate_model_rs(&config).unwrap();
7716
7717 let forward_start = model_rs
7719 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7720 .unwrap();
7721 let forward_body = &model_rs[forward_start..];
7722 let forward_end = forward_body
7723 .find("\n fn dispatch_")
7724 .unwrap_or(forward_body.len());
7725 let forward_code = &forward_body[..forward_end];
7726
7727 assert!(
7728 !forward_code.contains("dispatch_matmul_q8"),
7729 "f32 model forward should not use dispatch_matmul_q8"
7730 );
7731 }
7732
7733 #[test]
7734 fn q8_dispatch_helper_takes_compute_encoder_ref() {
7735 let config = minimal_q8_config();
7736 let model_rs = generate_model_rs(&config).unwrap();
7737
7738 assert!(
7739 model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7740 "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7741 );
7742 }
7743
7744 #[test]
7745 fn generated_model_has_double_buffered_prefill() {
7746 let config = minimal_config();
7747 let model_rs = generate_model_rs(&config).unwrap();
7748
7749 assert!(
7751 model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7752 "MetalModel should have prev_cmd field for double-buffered prefill"
7753 );
7754
7755 assert!(
7757 model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7758 "MetalModel should have forward_prefill method"
7759 );
7760
7761 assert!(
7763 model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7764 "forward() should drain prev_cmd from previous prefill"
7765 );
7766 }
7767
7768 #[test]
7769 fn generated_main_rs_uses_forward_prefill_for_prompt() {
7770 let config = minimal_config();
7771 let main_rs = generate_main_rs("test-model", &config).unwrap();
7772
7773 assert!(
7774 main_rs.contains("forward_prefill"),
7775 "main.rs should use forward_prefill for intermediate prompt tokens"
7776 );
7777 assert!(
7778 main_rs.contains("double-buffered"),
7779 "main.rs should document double-buffered prefill"
7780 );
7781 }
7782
7783 #[test]
7784 fn generated_shaders_q8_uses_wide_vectorized_loads() {
7785 let shaders = generate_metal_shaders(&minimal_config());
7786
7787 assert!(
7788 shaders.contains("packed_short4"),
7789 "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7790 );
7791 assert!(
7792 shaders.contains("d0[0]"),
7793 "matmul_vec_q8 should index the wide pointer for row 0"
7794 );
7795 assert!(
7796 shaders.contains("as_type<char2>"),
7797 "matmul_vec_q8 should bitcast short lanes to char2"
7798 );
7799 assert!(
7800 shaders.contains("dot("),
7801 "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7802 );
7803 }
7804
7805 fn minimal_q4_config() -> ModelConfig {
7808 ModelConfig {
7809 architecture: Architecture::Llama,
7810 hidden_size: 64,
7811 intermediate_size: 128,
7812 num_layers: 2,
7813 num_attention_heads: 4,
7814 num_kv_heads: 4,
7815 head_dim: 16,
7816 vocab_size: 256,
7817 max_seq_len: 512,
7818 rms_norm_eps: 1e-5,
7819 rope_theta: 10000.0,
7820 dtype: DType::Q4_0,
7821 sliding_window_size: None,
7822 qkv_bias: false,
7823 hidden_activation: HiddenActivation::SiLU,
7824 }
7825 }
7826
7827 #[test]
7828 fn generated_shaders_contain_q4_kernel() {
7829 let shaders = generate_metal_shaders(&minimal_config());
7830
7831 assert!(
7832 shaders.contains("kernel void matmul_vec_q4"),
7833 "shaders should contain matmul_vec_q4 kernel"
7834 );
7835 assert!(
7836 shaders.contains("Q4_ROWS_PER_TG"),
7837 "shaders should define Q4_ROWS_PER_TG constant"
7838 );
7839 assert!(
7840 shaders.contains("Q4_ROWS_PER_SG"),
7841 "shaders should define Q4_ROWS_PER_SG constant"
7842 );
7843 }
7844
7845 #[test]
7846 fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7847 let shaders = generate_metal_shaders(&minimal_config());
7848
7849 assert!(
7851 shaders.contains("uchar4"),
7852 "matmul_vec_q4 should use uchar4 for packed byte loads"
7853 );
7854 assert!(
7856 shaders.contains("&0xF"),
7857 "matmul_vec_q4 should extract low nibble with &0xF"
7858 );
7859 assert!(
7860 shaders.contains(">>4"),
7861 "matmul_vec_q4 should extract high nibble with >>4"
7862 );
7863 assert!(
7865 shaders.contains("-8)"),
7866 "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7867 );
7868 assert!(
7870 shaders.contains("blk * 18"),
7871 "matmul_vec_q4 should use 18-byte block stride"
7872 );
7873 }
7874
7875 #[test]
7876 fn q4_model_has_matmul_q4_pipeline() {
7877 let config = minimal_q4_config();
7878 let model_rs = generate_model_rs(&config).unwrap();
7879
7880 assert!(
7881 model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7882 "MetalModel should have matmul_q4_pipeline field"
7883 );
7884 assert!(
7885 model_rs.contains("matmul_q4_pipeline,"),
7886 "MetalModel Self should include matmul_q4_pipeline"
7887 );
7888 }
7889
7890 #[test]
7891 fn q4_model_uses_dispatch_matmul_q4() {
7892 let config = minimal_q4_config();
7893 let model_rs = generate_model_rs(&config).unwrap();
7894
7895 assert!(
7896 model_rs.contains("dispatch_matmul_q4"),
7897 "Q4_0 model should use dispatch_matmul_q4 for projections"
7898 );
7899 assert!(
7900 model_rs.contains("fn dispatch_matmul_q4"),
7901 "model.rs should define dispatch_matmul_q4 method"
7902 );
7903 }
7904
7905 #[test]
7906 fn q4_model_loads_raw_bytes_not_dequantized() {
7907 let config = minimal_q4_config();
7908 let model_rs = generate_model_rs(&config).unwrap();
7909
7910 assert!(
7912 !model_rs.contains("f16_to_f32"),
7913 "Q4_0 model should not dequantize weights to f32"
7914 );
7915
7916 assert!(
7918 model_rs.contains("total_raw as u64"),
7919 "Q4_0 model should load raw bytes into Metal buffer"
7920 );
7921 }
7922
7923 #[test]
7924 fn q4_model_norms_stay_f32() {
7925 let config = minimal_q4_config();
7926 let model_rs = generate_model_rs(&config).unwrap();
7927
7928 assert!(
7929 model_rs.contains("let attn_norm = next_f32_buffer"),
7930 "attn_norm should use f32 buffer even for Q4_0 models"
7931 );
7932 assert!(
7933 model_rs.contains("let ffn_norm = next_f32_buffer"),
7934 "ffn_norm should use f32 buffer even for Q4_0 models"
7935 );
7936 assert!(
7937 model_rs.contains("let norm_buf = next_f32_buffer"),
7938 "final norm should use f32 buffer even for Q4_0 models"
7939 );
7940 }
7941
7942 #[test]
7943 fn q4_model_uses_fused_weight_loading() {
7944 let config = minimal_q4_config();
7945 let model_rs = generate_model_rs(&config).unwrap();
7946
7947 assert!(
7948 model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7949 "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7950 );
7951 assert!(
7952 model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7953 "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7954 );
7955 assert!(
7956 model_rs.contains("let o_weight = next_q4_buffer"),
7957 "Q4_0 model should use next_q4_buffer for O weight"
7958 );
7959 assert!(
7960 model_rs.contains("let down_weight = next_q4_buffer"),
7961 "Q4_0 model should use next_q4_buffer for down weight"
7962 );
7963 }
7964
7965 #[test]
7966 fn attention_flash_batch_kernel_exists() {
7967 let config = minimal_config();
7971 let model_rs = generate_model_rs(&config).unwrap();
7972 let shaders = generate_metal_shaders(&config);
7973
7974 assert!(
7975 shaders.contains("kernel void attention_flash_batch"),
7976 "shaders.metal must still contain the attention_flash_batch kernel"
7977 );
7978 assert!(
7979 shaders.contains("FLASH_K_TILE"),
7980 "flash kernel must tile K/V with a FLASH_K_TILE constant"
7981 );
7982 assert!(
7983 model_rs.contains("attention_flash_batch_pipeline"),
7984 "MetalModel must register the flash attention pipeline"
7985 );
7986 }
7987
7988 #[test]
7989 fn decode_uses_fused_rope_and_kv_copy() {
7990 let config = minimal_config();
7995 let model_rs = generate_model_rs(&config).unwrap();
7996
7997 let forward_start = model_rs.find("pub fn forward(").expect("forward missing");
8002 let forward_end = model_rs[forward_start..]
8003 .find("pub fn forward_prefill(")
8004 .expect("forward_prefill missing");
8005 let forward_body = &model_rs[forward_start..forward_start + forward_end];
8006
8007 assert!(
8008 forward_body.contains("dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1"),
8009 "single-token forward must use fused rope_qk_batch with M=1"
8010 );
8011 assert!(
8012 forward_body.contains("dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,"),
8013 "single-token forward must use fused copy_kv_both_batch"
8014 );
8015 assert!(
8016 !forward_body.contains("dispatch_copy_from_offset_f16"),
8017 "decode path should no longer call the per-K/per-V copy"
8018 );
8019 }
8020
/// Every attention kernel signature must declare `k_cache` and `v_cache` as
/// `device const half*`, and the f32->f16 copy kernel and its dispatch helper
/// must be present in the generated text.
#[test]
fn kv_cache_stored_as_f16() {
    let config = minimal_config();
    let shaders = generate_metal_shaders(&config);
    let model_rs = generate_model_rs(&config).unwrap();

    for kernel in [
        "kernel void attention(",
        "kernel void attention_batch(",
        "kernel void attention_flash_batch(",
        "kernel void attention_mma_flash_batch(",
    ] {
        let start = shaders
            .find(kernel)
            .unwrap_or_else(|| panic!("kernel {kernel} missing"));
        let rest = &shaders[start..];
        // The signature ends where the parameter list closes and the body
        // opens; accept both `){` and `) {` spellings.
        let sig_end = rest
            .find("){")
            .or_else(|| rest.find(") {"))
            .unwrap_or_else(|| panic!("{kernel} has no body after its signature"));
        let sig = &rest[..sig_end];
        assert!(
            sig.contains("device const half*")
                && sig.contains("k_cache")
                && sig.contains("v_cache"),
            "{kernel} must read k_cache/v_cache as `device const half*`"
        );
        // Both caches are checked: previously the k_cache test was repeated
        // and a float v_cache would have gone unnoticed.
        assert!(
            !sig.contains("device const float* k_cache")
                && !sig.contains("device const float* v_cache"),
            "{kernel} still reads k_cache/v_cache as float"
        );
    }

    assert!(
        shaders.contains("kernel void copy_f32_to_f16_offset"),
        "f32->f16 copy kernel must be present for single-token decode KV writes"
    );
    assert!(
        model_rs.contains("dispatch_copy_from_offset_f16"),
        "single-token decode must dispatch the f32->f16 KV copy"
    );
}
8069
/// The text of the `attention` kernel (up to the next `kernel void `) must
/// contain the vectorized pointer types, the `head_dim4` loop bound, the four
/// `simd_sum` reductions and the strided `d4` loop header.
#[test]
fn decode_attention_uses_half4_vectorized_loads() {
    let msl = generate_metal_shaders(&minimal_config());

    let kernel_at = msl
        .find("kernel void attention(")
        .expect("decode attention kernel missing");
    // Start one byte in so the search cannot match this kernel's own header.
    let search_from = kernel_at + 1;
    let next_kernel = msl[search_from..]
        .find("kernel void ")
        .expect("next kernel missing");
    let body = &msl[kernel_at..search_from + next_kernel];

    let required = [
        ("device const half4*", "decode attention must half4-load K/V"),
        ("device const float4*", "decode attention must float4-load Q"),
        ("device float4*", "decode attention must float4-store output"),
        (
            "head_dim4",
            "decode attention must iterate head_dim in chunks of 4",
        ),
    ];
    for (needle, why) in required {
        assert!(body.contains(needle), "{why}");
    }

    let reduces_every_lane = [
        "simd_sum(partial.x)",
        "simd_sum(partial.y)",
        "simd_sum(partial.z)",
        "simd_sum(partial.w)",
    ]
    .iter()
    .all(|call| body.contains(*call));
    assert!(
        reduces_every_lane,
        "decode attention V-step must reduce float4 partials via simd_sum"
    );

    let strided_loop = body.contains("for (uint d4 = simd_id; d4 < head_dim4; d4 += 8)");
    assert!(
        strided_loop,
        "decode attention V-step must partition d4 across simdgroups"
    );
}
8118
/// The MMA flash-attention kernel must appear in the shader text, and the
/// generated model text must register its pipeline and contain the opt-out
/// gating expression.
#[test]
fn attention_mma_flash_batch_kernel_wired() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let shaders = generate_metal_shaders(&config);

    // The shader text is written to `kernels.metal`, so the failure message
    // names that file.
    assert!(
        shaders.contains("kernel void attention_mma_flash_batch"),
        "kernels.metal must contain the MMA flash kernel"
    );
    assert!(
        shaders.contains("FLASH_MMA_Q_BLOCK"),
        "MMA flash kernel must define Q_BLOCK tiling constant"
    );
    assert!(
        shaders.contains("simdgroup_multiply_accumulate"),
        "MMA flash kernel must use hardware MMA"
    );
    assert!(
        model_rs.contains("attention_mma_flash_batch_pipeline"),
        "MetalModel must register the MMA flash pipeline"
    );
    assert!(
        model_rs.contains("mma_opt_out"),
        "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
    );
    assert!(
        model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
        "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
    );
}
8152
/// The generated model text must iterate the prompt in `MAX_BATCH_SIZE`
/// chunks and must not contain the `min(MAX_BATCH_SIZE)` truncation expression.
#[test]
fn forward_prefill_batch_chunks_by_max_batch_size() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let chunks_prompt = generated.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)");
    let truncates_prompt = generated.contains("tokens.len().min(MAX_BATCH_SIZE)");

    assert!(
        chunks_prompt,
        "forward_prefill_batch must chunk long prompts"
    );
    assert!(!truncates_prompt, "the old truncation path must be gone");
}
8170
/// With `Architecture::Qwen2` and `qkv_bias: true`, the generated model text
/// must declare, load and dispatch the QKV bias, and the shader text must
/// contain the `add_bias_batch` kernel.
#[test]
fn qwen2_qkv_bias_wired_through_metal_codegen() {
    let config = ModelConfig {
        architecture: Architecture::Qwen2,
        qkv_bias: true,
        ..minimal_config()
    };
    let model_rs = generate_model_rs(&config).unwrap();

    assert!(
        model_rs.contains("qkv_bias: Buffer"),
        "Qwen2 LayerBuffers must declare qkv_bias field"
    );
    assert!(
        model_rs.contains("let qkv_bias = next_f32_buffer"),
        "Qwen2 layer init must load the bias from the weight blob"
    );
    assert!(
        model_rs.contains("add_bias_batch_pipeline"),
        "Qwen2 model struct must include the add_bias_batch_pipeline"
    );
    assert!(
        model_rs.contains("fn dispatch_add_bias_batch"),
        "Qwen2 codegen must emit dispatch_add_bias_batch helper"
    );
    assert!(
        model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
        "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
    );
    assert!(
        model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
        "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
    );

    let shaders = generate_metal_shaders(&config);
    // The shader text is written to `kernels.metal`, so the failure message
    // names that file.
    assert!(
        shaders.contains("kernel void add_bias_batch"),
        "kernels.metal must contain the add_bias_batch kernel"
    );
}
8215
/// With `qkv_bias` unset in the minimal config, none of the bias-related
/// identifiers may appear in the generated model text.
#[test]
fn llama_does_not_emit_qkv_bias_machinery() {
    let config = minimal_config();
    assert!(!config.qkv_bias);
    let generated = generate_model_rs(&config).unwrap();

    let forbidden = [
        ("qkv_bias: Buffer", "Llama must not have qkv_bias field"),
        (
            "add_bias_batch_pipeline",
            "Llama must not pull in add_bias_batch_pipeline",
        ),
        (
            "dispatch_add_bias_batch",
            "Llama must not call dispatch_add_bias_batch",
        ),
    ];
    for (needle, why) in forbidden {
        assert!(!generated.contains(needle), "{why}");
    }
}
8236
/// For the Q4 config, the generated `dispatch_matmul_q4` signature must take
/// its encoder as `&ComputeCommandEncoderRef`.
#[test]
fn q4_dispatch_helper_takes_compute_encoder_ref() {
    let generated = generate_model_rs(&minimal_q4_config()).unwrap();

    let takes_encoder_ref =
        generated.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef");
    assert!(
        takes_encoder_ref,
        "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
    );
}
8247
/// For the minimal (non-Q4) config, the `forward` text must not mention
/// `dispatch_matmul_q4`.
#[test]
fn f32_model_does_not_use_q4_dispatch() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let begin = generated
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &generated[begin..];
    // Stop at the first dispatch helper; if none follows, inspect the rest.
    let forward_code = match tail.find("\n fn dispatch_") {
        Some(end) => &tail[..end],
        None => tail,
    };

    let mentions_q4 = forward_code.contains("dispatch_matmul_q4");
    assert!(
        !mentions_q4,
        "f32 model forward should not use dispatch_matmul_q4"
    );
}
8267
/// For the Q4 config, the generated model text must load `lm_head_buf`
/// through `next_q4_buffer`.
#[test]
fn q4_model_lm_head_uses_q4_buffer() {
    let generated = generate_model_rs(&minimal_q4_config()).unwrap();

    let loads_q4_head = generated.contains("let lm_head_buf = next_q4_buffer");
    assert!(
        loads_q4_head,
        "Q4_0 model should use next_q4_buffer for lm_head"
    );
}
8278
/// The `vec_tile` array length in the shader text must follow the configured
/// sizes: 128 for the minimal config, 8192 once `intermediate_size` is 8192,
/// and never a fixed 4096.
#[test]
fn vec_tile_size_matches_model_dimensions() {
    let shaders_small = generate_metal_shaders(&minimal_config());
    assert!(
        shaders_small.contains("vec_tile[128]"),
        "vec_tile should be sized to max(hidden, intermediate) = 128"
    );

    let large = ModelConfig {
        hidden_size: 2048,
        intermediate_size: 8192,
        ..minimal_config()
    };
    let shaders_large = generate_metal_shaders(&large);

    let sized_to_intermediate = shaders_large.contains("vec_tile[8192]");
    let stuck_at_4096 = shaders_large.contains("vec_tile[4096]");
    assert!(
        sized_to_intermediate,
        "vec_tile should be 8192 for models with intermediate=8192"
    );
    assert!(!stuck_at_4096, "vec_tile should NOT be hardcoded to 4096");
}
8303
/// The generated manifest must declare `tiny_http`, `serde` and `serde_json`
/// as separate dependency keys.
#[test]
fn generated_cargo_toml_has_server_deps() {
    let toml = generate_cargo_toml("my-model");

    // Match whole manifest keys (`name =`), not substrings: a plain
    // `contains("serde")` is already satisfied by the `serde_json` line.
    let declares = |dep: &str| {
        toml.lines().any(|line| {
            line.trim_start()
                .strip_prefix(dep)
                .map_or(false, |rest| rest.trim_start().starts_with('='))
        })
    };

    assert!(
        declares("tiny_http"),
        "Cargo.toml should depend on tiny_http"
    );
    assert!(declares("serde"), "Cargo.toml should depend on serde");
    assert!(
        declares("serde_json"),
        "Cargo.toml should depend on serde_json"
    );
}
8317
/// The generated main.rs text must mention both server flags, define `serve`
/// and reference the tiny_http server constructor.
#[test]
fn generated_main_rs_has_serve_mode() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let expectations = [
        ("--serve", "main.rs should parse --serve flag"),
        ("--port", "main.rs should parse --port flag"),
        ("fn serve(", "main.rs should define serve function"),
        (
            "tiny_http::Server::http",
            "main.rs should create tiny_http server",
        ),
    ];
    for (needle, why) in expectations {
        assert!(source.contains(needle), "{why}");
    }
}
8340
/// The generated main.rs text must contain the three HTTP route paths.
#[test]
fn generated_main_rs_has_chat_completions_endpoint() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let routes = [
        (
            "/v1/chat/completions",
            "main.rs should handle /v1/chat/completions endpoint",
        ),
        ("/v1/models", "main.rs should handle /v1/models endpoint"),
        ("/health", "main.rs should handle /health endpoint"),
    ];
    for (path, why) in routes {
        assert!(source.contains(path), "{why}");
    }
}
8359
/// The generated main.rs text must contain the SSE content type, the chunk
/// object name and the `[DONE]` marker.
#[test]
fn generated_main_rs_has_sse_streaming() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let sets_sse_type = source.contains("text/event-stream");
    assert!(
        sets_sse_type,
        "main.rs should set SSE content type for streaming"
    );

    let emits_chunks = source.contains("chat.completion.chunk");
    assert!(emits_chunks, "main.rs should emit SSE chunks");

    let emits_done = source.contains("[DONE]");
    assert!(emits_done, "main.rs should emit [DONE] sentinel");
}
8378
/// The generated main.rs text must define `format_chat_messages` and contain
/// both ChatML delimiter tokens.
#[test]
fn generated_main_rs_has_chat_message_formatting() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let expectations = [
        (
            "fn format_chat_messages",
            "main.rs should define format_chat_messages function",
        ),
        ("<|im_start|>", "main.rs should use ChatML format"),
        ("<|im_end|>", "main.rs should use ChatML format"),
    ];
    for (needle, why) in expectations {
        assert!(source.contains(needle), "{why}");
    }
}
8397
/// The generated main.rs text must declare the two request structs and
/// mention `Deserialize`.
#[test]
fn generated_main_rs_has_request_types() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let expectations = [
        (
            "struct ChatRequest",
            "main.rs should define ChatRequest struct",
        ),
        (
            "struct ChatMessage",
            "main.rs should define ChatMessage struct",
        ),
        (
            "Deserialize",
            "main.rs should derive Deserialize for request types",
        ),
    ];
    for (needle, why) in expectations {
        assert!(source.contains(needle), "{why}");
    }
}
8416
/// The generated model text must declare `reset(&mut self)` and contain the
/// `self.pos = 0` assignment.
#[test]
fn generated_model_has_reset_method() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let declares_reset = generated.contains("pub fn reset(&mut self)");
    assert!(
        declares_reset,
        "model.rs should have a reset() method for multi-request serving"
    );

    let zeroes_position = generated.contains("self.pos = 0");
    assert!(zeroes_position, "reset() should reset position to 0");
}
8431
/// The generated main.rs text must define `cli_mode` and reference both
/// `model.forward` and `model.forward_prefill`.
#[test]
fn generated_main_rs_cli_mode_still_works() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let expectations = [
        ("fn cli_mode(", "main.rs should define cli_mode function"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "model.forward_prefill",
            "main.rs should call model.forward_prefill",
        ),
    ];
    for (needle, why) in expectations {
        assert!(source.contains(needle), "{why}");
    }
}
8451
/// The shader text for the minimal config must declare every batched kernel
/// listed below.
#[test]
fn generated_shaders_contain_batch_kernels() {
    let msl = generate_metal_shaders(&minimal_config());

    // (kernel declaration to look for, failure message)
    let kernels = [
        (
            "kernel void matmul_vec_batch",
            "shaders should contain matmul_vec_batch kernel",
        ),
        (
            "kernel void matmul_vec_q8_batch",
            "shaders should contain matmul_vec_q8_batch kernel",
        ),
        (
            "kernel void matmul_q8_gemm_batch",
            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)",
        ),
        (
            "kernel void matmul_vec_q4_batch",
            "shaders should contain matmul_vec_q4_batch kernel",
        ),
        (
            "kernel void rms_norm_batch",
            "shaders should contain rms_norm_batch kernel",
        ),
        (
            "kernel void silu_mul_fused_batch",
            "shaders should contain silu_mul_fused_batch kernel",
        ),
        (
            "kernel void add_inplace_batch",
            "shaders should contain add_inplace_batch kernel",
        ),
        (
            "kernel void copy_embedding_batch",
            "shaders should contain copy_embedding_batch kernel",
        ),
    ];
    for (declaration, why) in kernels {
        assert!(msl.contains(declaration), "{why}");
    }
}
8491
/// Each batched pipeline must appear in the generated model text as a
/// `<name>: ComputePipelineState` field declaration.
#[test]
fn generated_model_has_batch_pipelines() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let pipelines = [
        "matmul_batch_pipeline",
        "matmul_q8_batch_pipeline",
        "matmul_q8_gemm_batch_pipeline",
        "matmul_q4_batch_pipeline",
        "rms_norm_batch_pipeline",
        "rope_batch_pipeline",
        "silu_mul_fused_batch_pipeline",
        "add_inplace_batch_pipeline",
        "copy_embedding_batch_pipeline",
        "attention_batch_pipeline",
        "copy_kv_batch_pipeline",
    ];
    pipelines.iter().for_each(|name| {
        let field = format!("{name}: ComputePipelineState");
        assert!(
            generated.contains(field.as_str()),
            "MetalModel should have {name} field"
        );
    });
}
8516
/// Each batched scratch buffer must appear in the generated model text as a
/// `<name>: Buffer` field declaration.
#[test]
fn generated_model_has_batch_buffers() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let buffers = [
        "batch_hidden_buf",
        "batch_residual_buf",
        "batch_qkv_buf",
        "batch_attn_out_buf",
        "batch_attn_proj_buf",
        "batch_gate_up_buf",
        "batch_ffn_hidden_buf",
        "batch_ffn_out_buf",
        "batch_tokens_buf",
        "batch_positions_buf",
    ];
    buffers.iter().for_each(|name| {
        let field = format!("{name}: Buffer");
        assert!(
            generated.contains(field.as_str()),
            "MetalModel should have {name} field"
        );
    });
}
8540
/// The generated model text must declare `forward_prefill_batch` and contain
/// a call that passes it a one-element `[token_id]` slice.
#[test]
fn generated_model_has_forward_prefill_batch() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let declares_batch =
        generated.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])");
    assert!(
        declares_batch,
        "MetalModel should have forward_prefill_batch method"
    );

    let delegates = generated.contains("self.forward_prefill_batch(&[token_id])");
    assert!(
        delegates,
        "forward_prefill should delegate to forward_prefill_batch"
    );
}
8557
/// The generated model text must declare `MAX_BATCH_SIZE` as a public
/// `usize` constant equal to 512.
#[test]
fn generated_model_has_max_batch_size_constant() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let declares_limit = generated.contains("pub const MAX_BATCH_SIZE: usize = 512");
    assert!(
        declares_limit,
        "model.rs should define MAX_BATCH_SIZE constant"
    );
}
8568
/// The `forward_prefill_batch` text (up to `reset`, or to the end when no
/// `reset` marker follows) must mention each batched dispatch helper.
#[test]
fn forward_prefill_batch_uses_batch_dispatch() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let begin = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &generated[begin..];
    let batch_code = tail.find("\n pub fn reset").map_or(tail, |end| &tail[..end]);

    for helper in [
        "dispatch_rms_norm_batch",
        "dispatch_copy_embedding_batch",
        "dispatch_silu_mul_fused_batch",
        "dispatch_attention_batch",
        "dispatch_copy_kv_both_batch",
        "dispatch_rope_qk_batch",
    ] {
        assert!(
            batch_code.contains(helper),
            "forward_prefill_batch should use {helper}"
        );
    }
}
8612
/// For the Q8 config, the `forward_prefill_batch` text must mention
/// `dispatch_matmul_q8_batch`.
#[test]
fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
    let generated = generate_model_rs(&minimal_q8_config()).unwrap();

    let begin = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &generated[begin..];
    // Cut at `reset` when present; otherwise inspect everything that follows.
    let batch_code = match tail.find("\n pub fn reset") {
        Some(end) => &tail[..end],
        None => tail,
    };

    let uses_q8_batch = batch_code.contains("dispatch_matmul_q8_batch");
    assert!(
        uses_q8_batch,
        "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
    );
}
8632
/// For the Q4 config, the `forward_prefill_batch` text must mention
/// `dispatch_matmul_q4_batch`.
#[test]
fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
    let generated = generate_model_rs(&minimal_q4_config()).unwrap();

    let begin = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &generated[begin..];
    // Cut at `reset` when present; otherwise inspect everything that follows.
    let batch_code = match tail.find("\n pub fn reset") {
        Some(end) => &tail[..end],
        None => tail,
    };

    let uses_q4_batch = batch_code.contains("dispatch_matmul_q4_batch");
    assert!(
        uses_q4_batch,
        "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
    );
}
8652
/// The generated main.rs text must reference `forward_prefill_batch`.
#[test]
fn generated_main_rs_uses_batched_prefill() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let batches_prompt = source.contains("forward_prefill_batch");
    assert!(
        batches_prompt,
        "main.rs should use forward_prefill_batch for prompt tokens"
    );
}
8663
/// For the minimal config, the `forward_prefill_batch` text must mention
/// `dispatch_matmul_batch` and neither quantized batch helper.
#[test]
fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let begin = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &generated[begin..];
    let batch_code = tail.find("\n pub fn reset").map_or(tail, |end| &tail[..end]);

    assert!(
        batch_code.contains("dispatch_matmul_batch"),
        "f32 forward_prefill_batch should use dispatch_matmul_batch"
    );
    for quantized in ["dispatch_matmul_q8_batch", "dispatch_matmul_q4_batch"] {
        assert!(
            !batch_code.contains(quantized),
            "f32 forward_prefill_batch should not use {quantized}"
        );
    }
}
8692
/// The `forward` text must read `embed_buf.contents()` and use
/// `copy_nonoverlapping`, and must not dispatch `dispatch_copy_offset` on
/// `embed_buf`.
#[test]
fn forward_uses_cpu_embedding_lookup() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let begin = generated
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &generated[begin..];
    // End at the first marker, tried in this order, that occurs in the text.
    let end = ["\n pub fn forward_profile", "\n pub fn forward_prefill"]
        .iter()
        .find_map(|marker| tail.find(*marker))
        .unwrap_or(tail.len());
    let forward_code = &tail[..end];

    let reads_embed_on_cpu = forward_code.contains("embed_buf.contents()");
    assert!(
        reads_embed_on_cpu,
        "forward() should access embed_buf via CPU unified memory for embedding lookup"
    );

    let copies_on_cpu = forward_code.contains("copy_nonoverlapping");
    assert!(
        copies_on_cpu,
        "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
    );

    let copies_on_gpu = forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf");
    assert!(
        !copies_on_gpu,
        "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
    );
}
8724
/// The generated model text must declare `forward_profile` and contain the
/// `[profile]` tag and the three `d_*` timing identifiers.
#[test]
fn forward_profile_method_exists() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let expectations = [
        (
            "pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>",
            "MetalModel should have forward_profile() method",
        ),
        (
            "[profile]",
            "forward_profile() should print timing with [profile] prefix",
        ),
        (
            "d_embed",
            "forward_profile() should measure embedding time",
        ),
        ("d_layers", "forward_profile() should measure layer time"),
        ("d_logits", "forward_profile() should measure logits time"),
    ];
    for (needle, why) in expectations {
        assert!(generated.contains(needle), "{why}");
    }
}
8752
/// The generated main.rs text must mention the `--profile` flag and
/// `forward_profile`.
#[test]
fn generated_cli_has_profile_flag() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let parses_flag = source.contains("--profile");
    assert!(parses_flag, "CLI should support --profile flag");

    let calls_profiler = source.contains("forward_profile");
    assert!(
        calls_profiler,
        "CLI should call forward_profile when --profile is set"
    );
}
8767
/// The generated main.rs text must contain a `yield_now()` call.
#[test]
fn generated_cli_has_thermal_yield() {
    let source = generate_main_rs("test-model", &minimal_config()).unwrap();

    let yields = source.contains("yield_now()");
    assert!(
        yields,
        "CLI generation loop should include thread::yield_now() for thermal management"
    );
}
8778
/// The opening of the generated `forward` must not assert `self.pos > 0`, and
/// the generated model text must reference `self.pos`.
#[test]
fn generated_forward_handles_single_token_prompt() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");
    // Inspect at most the first 400 bytes of `forward`. Clamp to the end of
    // the text and back off to a char boundary so the slice cannot panic.
    let mut window_end = model_rs.len().min(forward_start + 400);
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let forward_body = &model_rs[forward_start..window_end];

    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );

    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
8807
/// The opening of the generated `reset` must set `self.pos = 0` and
/// `self.prev_cmd = None`.
#[test]
fn generated_reset_clears_kv_cache_position() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let reset_start = model_rs
        .find("pub fn reset(&mut self)")
        .expect("reset() must exist");
    // Inspect at most the first 200 bytes of `reset`. Clamp to the end of the
    // text and back off to a char boundary so the slice cannot panic when
    // fewer than 200 bytes follow.
    let mut window_end = model_rs.len().min(reset_start + 200);
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let reset_body = &model_rs[reset_start..window_end];

    assert!(
        reset_body.contains("self.pos = 0"),
        "reset() must set self.pos = 0"
    );

    assert!(
        reset_body.contains("self.prev_cmd = None"),
        "reset() should clear prev_cmd for clean command buffer state"
    );
}
8832
/// The opening of `format_chat_messages` must loop over `messages` and append
/// the assistant header, and the text from `fn serve(` onward must call
/// `model.reset()`.
#[test]
fn generated_serve_handles_empty_messages_gracefully() {
    let config = minimal_config();
    let main_rs = generate_main_rs("test-model", &config).unwrap();

    let format_fn_start = main_rs
        .find("fn format_chat_messages")
        .expect("format_chat_messages must exist");
    // Inspect at most the first 500 bytes of the function. Besides clamping to
    // the end of the text, back off to a char boundary so the slice cannot
    // split a multi-byte character.
    let mut window_end = main_rs.len().min(format_fn_start + 500);
    while !main_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let format_fn_body = &main_rs[format_fn_start..window_end];

    assert!(
        format_fn_body.contains("for msg in messages"),
        "format_chat_messages should iterate over the messages slice"
    );
    assert!(
        format_fn_body.contains("<|im_start|>assistant"),
        "format_chat_messages should always append assistant prompt header"
    );

    let serve_fn_start = main_rs
        .find("fn serve(")
        .expect("serve function must exist");
    let serve_fn_body = &main_rs[serve_fn_start..];
    assert!(
        serve_fn_body.contains("model.reset()"),
        "serve function should reset model between requests"
    );
}
8868
/// The `forward` text must contain a `self.pos += 1` increment (with or
/// without a space before the `1`).
#[test]
fn generated_model_forward_increments_pos() {
    let generated = generate_model_rs(&minimal_config()).unwrap();

    let begin = generated
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &generated[begin..];
    // End at the first marker, tried in this order, that occurs in the text.
    let end = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .find_map(|marker| tail.find(*marker))
    .unwrap_or(tail.len());
    let forward_code = &tail[..end];

    let increments = ["self.pos += 1", "self.pos +=1"]
        .iter()
        .any(|form| forward_code.contains(*form));
    assert!(
        increments,
        "forward() must increment self.pos after processing a token"
    );
}
8892}