1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
/// Errors that can be returned while generating a Metal project.
#[derive(Debug, thiserror::Error)]
pub enum MetalCodegenError {
    /// The graph's `config` was `None`, so there is no model configuration
    /// to generate code from.
    #[error("graph has no model config")]
    MissingConfig,

    /// A filesystem operation failed. Converted automatically from
    /// [`std::io::Error`] via `?`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Formatting generated text failed. Converted automatically from
    /// [`std::fmt::Error`] via `?`.
    #[error("format error: {0}")]
    Fmt(#[from] std::fmt::Error),
}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Turns an arbitrary model name into a string usable as a Cargo package name.
///
/// The name is lower-cased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing dashes are stripped.
///
/// Cargo refuses package names that are empty or that begin with a digit, so
/// two fallbacks keep the generated manifest buildable:
///
/// * if nothing survives sanitization (e.g. `""` or `"!!!"`), `"model"` is
///   returned;
/// * if the sanitized name starts with a numeric character (e.g. `"7b-chat"`),
///   it is prefixed with `"model-"`.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let lowered = name.to_lowercase();
    let dashed = lowered.replace(|c: char| !c.is_alphanumeric() && c != '-', "-");
    let trimmed = dashed.trim_matches('-');

    match trimmed.chars().next() {
        // Nothing usable left: an empty `name = ""` would be rejected by Cargo.
        None => FALLBACK.to_string(),
        // A leading digit is not allowed in a package name; keep the original
        // text but give it a valid first character.
        Some(first) if first.is_numeric() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121 let attn_scores_size = config.max_seq_len.min(4096);
126 r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154 device const float* matrix [[buffer(0)]],
155 device const float* vector [[buffer(1)]],
156 device float* output [[buffer(2)]],
157 constant uint& rows [[buffer(3)]],
158 constant uint& cols [[buffer(4)]],
159 uint tgid [[threadgroup_position_in_grid]],
160 uint tid [[thread_index_in_threadgroup]],
161 uint simd_lane [[thread_index_in_simdgroup]],
162 uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164 // Cooperatively load vector into threadgroup shared memory
165 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166 for (uint i = tid; i < cols; i += 256) {
167 vec_tile[i] = vector[i];
168 }
169 threadgroup_barrier(mem_flags::mem_threadgroup);
170
171 // Each simdgroup handles 8 consecutive rows
172 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173 if (row_base >= rows) return;
174
175 uint base0 = row_base * cols;
176 uint base1 = (row_base + 1) * cols;
177 uint base2 = (row_base + 2) * cols;
178 uint base3 = (row_base + 3) * cols;
179 uint base4 = (row_base + 4) * cols;
180 uint base5 = (row_base + 5) * cols;
181 uint base6 = (row_base + 6) * cols;
182 uint base7 = (row_base + 7) * cols;
183
184 // float4 vectorized accumulation across 8 rows
185 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
186 float4 sum4_0 = float4(0.0f);
187 float4 sum4_1 = float4(0.0f);
188 float4 sum4_2 = float4(0.0f);
189 float4 sum4_3 = float4(0.0f);
190 float4 sum4_4 = float4(0.0f);
191 float4 sum4_5 = float4(0.0f);
192 float4 sum4_6 = float4(0.0f);
193 float4 sum4_7 = float4(0.0f);
194
195 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205 }
206
207 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216 // Handle remaining elements (cols not divisible by 128)
217 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218 float vv = vec_tile[j];
219 sum0 += matrix[base0 + j] * vv;
220 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227 }
228
229 // Simdgroup hardware warp-level reduction
230 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235 // Only first lane writes the results
236 if (simd_lane == 0) {
237 if (row_base < rows) output[row_base] = sum0;
238 if (row_base + 1 < rows) output[row_base + 1] = sum1;
239 if (row_base + 2 < rows) output[row_base + 2] = sum2;
240 if (row_base + 3 < rows) output[row_base + 3] = sum3;
241 if (row_base + 4 < rows) output[row_base + 4] = sum4;
242 if (row_base + 5 < rows) output[row_base + 5] = sum5;
243 if (row_base + 6 < rows) output[row_base + 6] = sum6;
244 if (row_base + 7 < rows) output[row_base + 7] = sum7;
245 }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253 device const float* input [[buffer(0)]],
254 device const float* weight [[buffer(1)]],
255 device float* output [[buffer(2)]],
256 constant uint& n [[buffer(3)]],
257 constant float& eps [[buffer(4)]],
258 uint tid [[thread_index_in_threadgroup]])
259{
260 // Each thread accumulates partial sum-of-squares
261 float sum_sq = 0.0f;
262 for (uint i = tid; i < n; i += 256) {
263 float v = input[i];
264 sum_sq += v * v;
265 }
266
267 // Simdgroup-level reduction (hardware warp sum)
268 sum_sq = simd_sum(sum_sq);
269
270 // Cross-simdgroup reduction via shared memory
271 threadgroup float shared[8];
272 uint simd_id = tid / 32;
273 uint simd_lane = tid % 32;
274 if (simd_lane == 0) {
275 shared[simd_id] = sum_sq;
276 }
277 threadgroup_barrier(mem_flags::mem_threadgroup);
278
279 // First thread computes final inverse RMS
280 if (tid == 0) {
281 float total = 0.0f;
282 for (uint i = 0; i < 8; i++) {
283 total += shared[i];
284 }
285 shared[0] = fast::rsqrt(total / float(n) + eps);
286 }
287 threadgroup_barrier(mem_flags::mem_threadgroup);
288
289 float inv_rms = shared[0];
290
291 // Normalize
292 for (uint i = tid; i < n; i += 256) {
293 output[i] = input[i] * inv_rms * weight[i];
294 }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301 device float* data [[buffer(0)]],
302 constant uint& num_heads [[buffer(1)]],
303 constant uint& head_dim [[buffer(2)]],
304 constant uint& pos [[buffer(3)]],
305 constant float& theta [[buffer(4)]],
306 uint id [[thread_position_in_grid]])
307{
308 uint half_dim = head_dim / 2;
309 uint total_pairs = num_heads * half_dim;
310 if (id >= total_pairs) return;
311
312 uint h = id / half_dim;
313 uint i = id % half_dim;
314 uint off = h * head_dim;
315
316 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317 float angle = float(pos) * freq;
318 float c = cos(angle);
319 float s = sin(angle);
320
321 float x0 = data[off + 2 * i];
322 float x1 = data[off + 2 * i + 1];
323 data[off + 2 * i] = x0 * c - x1 * s;
324 data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331 device float* data [[buffer(0)]],
332 constant uint& n [[buffer(1)]],
333 uint tid [[thread_index_in_threadgroup]],
334 uint tg_size [[threads_per_threadgroup]])
335{
336 threadgroup float shared_val[256];
337
338 // Pass 1: find max
339 float local_max = -INFINITY;
340 for (uint i = tid; i < n; i += tg_size) {
341 local_max = max(local_max, data[i]);
342 }
343 shared_val[tid] = local_max;
344 threadgroup_barrier(mem_flags::mem_threadgroup);
345
346 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347 if (tid < stride) {
348 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349 }
350 threadgroup_barrier(mem_flags::mem_threadgroup);
351 }
352 float max_val = shared_val[0];
353 threadgroup_barrier(mem_flags::mem_threadgroup);
354
355 // Pass 2: exp and sum
356 float local_sum = 0.0f;
357 for (uint i = tid; i < n; i += tg_size) {
358 float e = fast::exp(data[i] - max_val);
359 data[i] = e;
360 local_sum += e;
361 }
362 shared_val[tid] = local_sum;
363 threadgroup_barrier(mem_flags::mem_threadgroup);
364
365 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366 if (tid < stride) {
367 shared_val[tid] += shared_val[tid + stride];
368 }
369 threadgroup_barrier(mem_flags::mem_threadgroup);
370 }
371 float inv_sum = 1.0f / shared_val[0];
372 threadgroup_barrier(mem_flags::mem_threadgroup);
373
374 // Pass 3: normalize
375 for (uint i = tid; i < n; i += tg_size) {
376 data[i] *= inv_sum;
377 }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384 device const float* gate [[buffer(0)]],
385 device const float* up [[buffer(1)]],
386 device float* output [[buffer(2)]],
387 constant uint& n [[buffer(3)]],
388 uint id [[thread_position_in_grid]])
389{
390 if (id >= n) return;
391 float g = gate[id];
392 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397// gate = gate_up[0..n], up = gate_up[n..2*n]
398// output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400 device const float* gate_up [[buffer(0)]],
401 device float* output [[buffer(1)]],
402 constant uint& n [[buffer(2)]],
403 uint id [[thread_position_in_grid]])
404{
405 if (id >= n) return;
406 float g = gate_up[id];
407 float u = gate_up[n + id];
408 output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414 device const float* a [[buffer(0)]],
415 device const float* b [[buffer(1)]],
416 device float* output [[buffer(2)]],
417 constant uint& n [[buffer(3)]],
418 uint id [[thread_position_in_grid]])
419{
420 if (id >= n) return;
421 output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428 device const float* src [[buffer(0)]],
429 device float* dst [[buffer(1)]],
430 constant uint& count [[buffer(2)]],
431 uint id [[thread_position_in_grid]])
432{
433 if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440 device const float* src [[buffer(0)]],
441 device float* dst [[buffer(1)]],
442 constant uint& src_offset [[buffer(2)]], // in floats
443 constant uint& count [[buffer(3)]],
444 uint id [[thread_position_in_grid]])
445{
446 if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
450// Copy f32 elements from src into a half-typed dst, converting to half on
451// write. Used by the single-token decode path to append a new K/V vector
452// to the f16 KV cache. Byte offsets into src/dst are supplied via the
453// Metal buffer binding offsets — no in-kernel offsets needed.
454kernel void copy_f32_to_f16_offset(
455 device const float* src [[buffer(0)]],
456 device half* dst [[buffer(1)]],
457 constant uint& count [[buffer(2)]],
458 uint id [[thread_position_in_grid]])
459{
460 if (id < count) dst[id] = half(src[id]);
461}
462
463// ── add_inplace ─────────────────────────────────────────────────────────
464// In-place residual connection: a[i] += b[i]
465// Avoids a separate blit copy for residual add, reducing encoder overhead.
466kernel void add_inplace(
467 device float* a [[buffer(0)]],
468 device const float* b [[buffer(1)]],
469 constant uint& n [[buffer(2)]],
470 uint id [[thread_position_in_grid]])
471{
472 if (id >= n) return;
473 a[id] += b[id];
474}
475
476// ── matmul_vec_q8 ─────────────────────────────────────────────────────
477// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
478// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
479// Operates directly on quantized weights to halve memory bandwidth vs f32,
480// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
481//
482// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
483// because int8->float conversion doubles register demand. Fully unrolled
484// inner loop with float4 vector loads from shared memory eliminates loop
485// overhead and enables better instruction scheduling.
486// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
487constant constexpr uint Q8_ROWS_PER_SG = 4;
488constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
489
490// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
491// doubles ALU work, so the same register budget applies).
492constant constexpr uint Q4_ROWS_PER_SG = 4;
493constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
494
495kernel void matmul_vec_q8(
496 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
497 device const float* vector [[buffer(1)]], // f32 input
498 device float* output [[buffer(2)]],
499 constant uint& rows [[buffer(3)]],
500 constant uint& cols [[buffer(4)]], // number of elements per row
501 uint tgid [[threadgroup_position_in_grid]],
502 uint tid [[thread_index_in_threadgroup]],
503 uint simd_lane [[thread_index_in_simdgroup]],
504 uint simd_id [[simdgroup_index_in_threadgroup]])
505{
506 // Load vector into shared memory
507 threadgroup float vec_tile[VEC_TILE_SIZE];
508 for (uint i = tid; i < cols; i += 256) {
509 vec_tile[i] = vector[i];
510 }
511 threadgroup_barrier(mem_flags::mem_threadgroup);
512
513 // Each simdgroup handles 4 consecutive rows (lower register pressure)
514 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
515 if (row_base >= rows) return;
516
517 // Q8_0: each block is 34 bytes for 32 elements
518 uint blocks_per_row = cols / 32;
519 uint row_bytes = blocks_per_row * 34;
520
521 // Pointers to each row's Q8_0 data
522 device const uchar* r0 = matrix + row_base * row_bytes;
523 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
524 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
525 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
526
527 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
528
529 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
530 uint bb = blk * 34;
531 uint vb = blk * 32;
532
533 // Prefetch all 4 scales
534 float sc0 = float(*(device const half*)(r0 + bb));
535 float sc1 = float(*(device const half*)(r1 + bb));
536 float sc2 = float(*(device const half*)(r2 + bb));
537 float sc3 = float(*(device const half*)(r3 + bb));
538
539 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
540 // Q8_0 block layout where the int8 data starts at offset +2 from a
541 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
542 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
543 // reduction in memory transactions. Metal's char16/packed_char16 are
544 // reserved types and packed_*int4 require >=4-byte alignment which
545 // this layout does not provide, so packed_short4 is the widest valid
546 // vectorized load.
547 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
548 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
549 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
550 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
551
552 // Load all 8 float4 vector values for this 32-element block from shared memory
553 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
554 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
555 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
556 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
557 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
558 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
559 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
560 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
561
562 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
563 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
564 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
565 short4 _s = short4(SHORT4); \
566 char2 _a = as_type<char2>(_s.x); \
567 char2 _b = as_type<char2>(_s.y); \
568 char2 _c = as_type<char2>(_s.z); \
569 char2 _d = as_type<char2>(_s.w); \
570 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
571 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
572 }
573
574 float4 f0, f1;
575 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
576
577 // Row 0: 4 short4 loads cover 32 int8 weights
578 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
579 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
580 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
581 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
582
583 // Row 1
584 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
585 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
586 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
587 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
588
589 // Row 2
590 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
591 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
592 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
593 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
594
595 // Row 3
596 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
597 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
598 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
599 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
600
601 #undef Q8_UNPACK8
602
603 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
604 }
605
606 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
607 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
608
609 if (simd_lane == 0) {
610 if (row_base < rows) output[row_base] = sum0;
611 if (row_base + 1 < rows) output[row_base + 1] = sum1;
612 if (row_base + 2 < rows) output[row_base + 2] = sum2;
613 if (row_base + 3 < rows) output[row_base + 3] = sum3;
614 }
615}
616
617// ── matmul_vec_q4 ─────────────────────────────────────────────────────
618// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
619// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
620// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
621// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
622//
623// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
624// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
625kernel void matmul_vec_q4(
626 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
627 device const float* vector [[buffer(1)]], // f32 input
628 device float* output [[buffer(2)]],
629 constant uint& rows [[buffer(3)]],
630 constant uint& cols [[buffer(4)]], // number of elements per row
631 uint tgid [[threadgroup_position_in_grid]],
632 uint tid [[thread_index_in_threadgroup]],
633 uint simd_lane [[thread_index_in_simdgroup]],
634 uint simd_id [[simdgroup_index_in_threadgroup]])
635{
636 // Load vector into shared memory
637 threadgroup float vec_tile[VEC_TILE_SIZE];
638 for (uint i = tid; i < cols; i += 256) {
639 vec_tile[i] = vector[i];
640 }
641 threadgroup_barrier(mem_flags::mem_threadgroup);
642
643 // Each simdgroup handles 4 consecutive rows
644 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
645 if (row_base >= rows) return;
646
647 // Q4_0: each block is 18 bytes for 32 elements
648 uint blocks_per_row = cols / 32;
649 uint row_bytes = blocks_per_row * 18;
650
651 // Pointers to each row's Q4_0 data
652 device const uchar* r0 = matrix + row_base * row_bytes;
653 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
654 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
655 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
656
657 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
658
659 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
660 uint bb = blk * 18;
661 uint vb = blk * 32;
662
663 // Prefetch all 4 scales
664 float sc0 = float(*(device const half*)(r0 + bb));
665 float sc1 = float(*(device const half*)(r1 + bb));
666 float sc2 = float(*(device const half*)(r2 + bb));
667 float sc3 = float(*(device const half*)(r3 + bb));
668
669 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
670 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
671 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
672 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
673 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
674
675 // Load 8 float4 vector values for 32 elements from shared memory
676 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
677 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
678 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
679 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
680 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
681 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
682 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
683 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
684 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
685
686 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
687 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
688 float bd0=0, bd1=0, bd2=0, bd3=0;
689 uchar4 b;
690
691 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
692 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
693 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
694 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
695 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
696 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
697 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
698 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
699 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
700 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
701 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
702 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
703 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
704 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
705 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
706 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
707 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
708
709 // Row 1
710 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
711 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
712 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
713 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
714 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
715 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
716 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
717 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
718 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
719 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
720 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
721 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
722 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
723 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
724 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
725 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
726
727 // Row 2
728 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
729 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
730 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
731 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
732 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
733 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
734 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
735 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
736 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
737 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
738 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
739 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
740 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
741 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
742 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
743 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
744
745 // Row 3
746 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
747 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
748 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
749 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
750 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
751 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
752 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
753 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
754 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
755 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
756 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
757 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
758 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
759 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
760 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
761 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
762
763 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
764 }
765
766 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
767 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
768
769 if (simd_lane == 0) {
770 if (row_base < rows) output[row_base] = sum0;
771 if (row_base + 1 < rows) output[row_base + 1] = sum1;
772 if (row_base + 2 < rows) output[row_base + 2] = sum2;
773 if (row_base + 3 < rows) output[row_base + 3] = sum3;
774 }
775}
776
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention: one threadgroup per query head.
//
// Phases (all of them share the threadgroup `scores` array):
//   1. scores[s] = (q_head · K[s, kv_head]) * rsqrt(head_dim). Each
//      simdgroup owns every 8th cached position and reduces the per-lane
//      partial products with simd_sum.
//   2. Softmax over scores[0 .. seq_len): simd_max / simd_sum per
//      simdgroup, then thread 0 combines the 8 per-simdgroup partials.
//   3. output[head, d] = sum over s of scores[s] * V[s, kv_head, d].
//
// Query head -> KV head mapping: kv_head = head / (num_heads / num_kv_heads).
//
// The strides 8 (simdgroups), 32 (lanes) and 256 (threads) are hard-coded,
// so the kernel expects 256 threads per threadgroup.
// `scores` holds ATTN_SCORES_SIZE floats (defined elsewhere in this file);
// nothing here checks seq_len against that size.
// NOTE(review): confirm the host never dispatches with seq_len > ATTN_SCORES_SIZE.
//
// Buffers:
//   q:       [num_heads * head_dim]                 current query (float)
//   k_cache: [max_seq_len * num_kv_heads * head_dim] (half)
//   v_cache: [max_seq_len * num_kv_heads * head_dim] (half)
//   output:  [num_heads * head_dim]                 (float)
kernel void attention(
    device const float* q            [[buffer(0)]],
    device const half*  k_cache      [[buffer(1)]],
    device const half*  v_cache      [[buffer(2)]],
    device float*       output       [[buffer(3)]],
    constant uint&      seq_len      [[buffer(4)]],
    constant uint&      num_heads    [[buffer(5)]],
    constant uint&      num_kv_heads [[buffer(6)]],
    constant uint&      head_dim     [[buffer(7)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    // Whole threadgroup exits together: tgid is uniform across its threads.
    if (tgid >= num_heads) return;

    const uint head        = tgid;
    const uint kv_head     = head / (num_heads / num_kv_heads);
    const uint q_off       = head * head_dim;
    const uint kv_head_off = kv_head * head_dim;
    // Distance (in elements) between consecutive positions in the K/V caches.
    const uint kv_stride   = num_kv_heads * head_dim;

    threadgroup float scores[ATTN_SCORES_SIZE];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];

    // ── Phase 1: scaled dot-product scores ──
    const float scale = fast::rsqrt(float(head_dim));
    for (uint s = simd_id; s < seq_len; s += 8) {
        const uint k_off = s * kv_stride + kv_head_off;
        float qk = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            qk += q[q_off + d] * k_cache[k_off + d];
        }
        qk = simd_sum(qk);
        if (simd_lane == 0) {
            scores[s] = qk * scale;
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 2a: maximum score (subtracted before exp) ──
    float thread_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        thread_max = max(thread_max, scores[s]);
    }
    thread_max = simd_max(thread_max);
    if (simd_lane == 0) shared_max[simd_id] = thread_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float best = shared_max[0];
        for (uint g = 1; g < 8; g++) best = max(best, shared_max[g]);
        shared_max[0] = best;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float max_val = shared_max[0];

    // ── Phase 2b: exponentiate in place and accumulate the denominator ──
    float thread_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        const float e = fast::exp(scores[s] - max_val);
        scores[s] = e;
        thread_sum += e;
    }
    thread_sum = simd_sum(thread_sum);
    if (simd_lane == 0) shared_sum[simd_id] = thread_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float denom = 0.0;
        for (uint g = 0; g < 8; g++) denom += shared_sum[g];
        // Slot 0 is reused to broadcast the reciprocal to every thread.
        shared_sum[0] = 1.0 / denom;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    const float inv_sum = shared_sum[0];

    // ── Phase 2c: normalise ──
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Phase 3: weighted sum of V ──
    // One thread per output dimension; positions are consumed four at a
    // time, with a scalar loop for the remaining (seq_len % 4) positions.
    const uint seq_len4 = seq_len & ~3u;
    for (uint d = tid; d < head_dim; d += 256) {
        const uint col = kv_head_off + d;
        float acc = 0.0;
        for (uint s = 0; s < seq_len4; s += 4) {
            const uint row = s * kv_stride;
            const float p0 = scores[s];
            const float p1 = scores[s + 1];
            const float p2 = scores[s + 2];
            const float p3 = scores[s + 3];
            acc += p0 * v_cache[row + col]
                 + p1 * v_cache[row + kv_stride + col]
                 + p2 * v_cache[row + 2 * kv_stride + col]
                 + p3 * v_cache[row + 3 * kv_stride + col];
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * kv_stride + col];
        }
        output[q_off + d] = acc;
    }
}
893
894// ── Batched prefill kernels ────────────────────────────────────────────
895// These kernels process M input vectors against the same weight matrix
896// in a single dispatch, converting mat-vec into mat-mat for better GPU
897// utilization during prompt prefill.
898
// ── rms_norm_batch ─────────────────────────────────────────────────────
// RMS normalization applied independently to each row of an [M, n] batch:
//   output[t, i] = input[t, i] * rsqrt(mean_i(input[t, i]^2) + eps) * weight[i]
// Grid: one threadgroup per token (threadgroups with tgid >= num_tokens exit).
// The reduction reads 8 per-simdgroup partial sums and strides by 256, so it
// expects 256 threads per threadgroup.
kernel void rms_norm_batch(
    device const float* input      [[buffer(0)]],  // [M, n]
    device const float* weight     [[buffer(1)]],  // [n]
    device float*       output     [[buffer(2)]],  // [M, n]
    constant uint&      n          [[buffer(3)]],
    constant float&     eps        [[buffer(4)]],
    constant uint&      num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]])
{
    if (tgid >= num_tokens) return;

    // This threadgroup's row in the input and output batches.
    const uint row_off = tgid * n;
    device const float* x = input + row_off;
    device float*       y = output + row_off;

    // Per-thread partial sum of squares over a 256-strided slice of the row.
    float partial = 0.0f;
    for (uint i = tid; i < n; i += 256) {
        const float xi = x[i];
        partial += xi * xi;
    }
    partial = simd_sum(partial);

    // Lane 0 of each simdgroup publishes its partial; thread 0 folds them.
    threadgroup float shared[8];
    const uint sg   = tid >> 5;   // tid / 32
    const uint lane = tid & 31u;  // tid % 32
    if (lane == 0) {
        shared[sg] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float sum_sq = 0.0f;
        for (uint g = 0; g < 8; g++) {
            sum_sq += shared[g];
        }
        // Slot 0 is reused to broadcast 1/rms to the whole threadgroup.
        shared[0] = fast::rsqrt(sum_sq / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float inv_rms = shared[0];
    for (uint i = tid; i < n; i += 256) {
        y[i] = x[i] * inv_rms * weight[i];
    }
}
948
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place Rotary Position Embedding over a batch where every token carries
// its own position.
//   data:      [M, num_heads * head_dim]
//   positions: [M]
// One thread per (token, head, pair); pair i rotates the adjacent elements
// (2*i, 2*i + 1) of a head by angle = position * theta^(-2*i / head_dim).
kernel void rope_batch(
    device float*      data       [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint&     num_heads  [[buffer(1)]],
    constant uint&     head_dim   [[buffer(2)]],
    device const uint* positions  [[buffer(3)]],  // [M] position per token
    constant float&    theta      [[buffer(4)]],
    constant uint&     num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head  = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token   = id / pairs_per_token;
    const uint in_tok  = id % pairs_per_token;
    const uint head    = in_tok / pairs_per_head;
    const uint pair    = in_tok % pairs_per_head;

    // Index of the first element of the pair being rotated.
    const uint head_off = token * (num_heads * head_dim) + head * head_dim;
    const uint lo = head_off + 2 * pair;
    const uint hi = lo + 1;

    const float freq  = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle = float(positions[token]) * freq;
    const float cos_a = cos(angle);
    const float sin_a = sin(angle);

    // Read both inputs before writing either output.
    const float a = data[lo];
    const float b = data[hi];
    data[lo] = a * cos_a - b * sin_a;
    data[hi] = a * sin_a + b * cos_a;
}
983
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-gate multiply over a batch. Each gate_up row stores the gate
// half followed by the up half ([M, 2*n]):
//   output[t, i] = silu(gate_up[t, i]) * gate_up[t, n + i]
// with silu(g) = g / (1 + exp(-g)). One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up    [[buffer(0)]],  // [M, 2*n]
    device float*       output     [[buffer(1)]],  // [M, n]
    constant uint&      n          [[buffer(2)]],
    constant uint&      num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col   = id % n;

    // Start of this token's [gate | up] row.
    const uint row_off = token * 2 * n;
    const float gate = gate_up[row_off + col];
    const float up   = gate_up[row_off + n + col];

    const float silu = gate / (1.0f + fast::exp(-gate));
    // token * n + col is the flat thread id itself.
    output[id] = silu * up;
}
1003
// ── add_inplace_batch ──────────────────────────────────────────────────
// Element-wise in-place accumulate over a flattened batch:
//   a[i] = a[i] + b[i]  for i in [0, total), where total = M * n.
// One thread per element; threads past `total` do nothing.
kernel void add_inplace_batch(
    device float*       a     [[buffer(0)]],  // [M * n], updated in place
    device const float* b     [[buffer(1)]],  // [M * n]
    constant uint&      total [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1015
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer:
//   output[t, d] = embed[tokens[t], d]
// One thread per output element. Token ids are used as row indices without
// any range check against the table size (the size is not passed in).
kernel void copy_embedding_batch(
    device const float* embed      [[buffer(0)]],  // [vocab_size, dim]
    device float*       output     [[buffer(1)]],  // [M, dim]
    device const uint*  tokens     [[buffer(2)]],  // [M]
    constant uint&      dim        [[buffer(3)]],
    constant uint&      num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    const uint slot   = id / dim;   // which token of the batch
    const uint column = id % dim;   // which component of its embedding
    const uint vocab_row = tokens[slot];

    output[id] = embed[vocab_row * dim + column];
}
1033
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched f32 matrix-vector product: outputs[t, r] = dot(matrix[r, :], inputs[t, :])
// for M input vectors against one [rows, cols] weight matrix.
//
// Grid: ceil(rows / ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * groups_per_token + row_group.
// Each threadgroup first stages its token's input vector in threadgroup
// memory (vec_tile, VEC_TILE_SIZE floats; cols is not checked against it).
// Each simdgroup then accumulates 8 consecutive rows starting at row_base.
// NOTE(review): this hard-codes 8 rows; presumably ROWS_PER_SIMDGROUP == 8 — confirm.
kernel void matmul_vec_batch(
    device const float* matrix     [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    const uint groups_per_token = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    const uint token = tgid / groups_per_token;
    const uint group = tgid % groups_per_token;
    // Uniform across the threadgroup, so exiting before the barrier is safe.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* x = inputs + token * cols;
    for (uint c = tid; c < cols; c += 256) {
        vec_tile[c] = x[c];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint row_base = group * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Row 0 of the group is valid here; rows 1..7 may fall past the matrix.
    const bool has1 = (row_base + 1 < rows);
    const bool has2 = (row_base + 2 < rows);
    const bool has3 = (row_base + 3 < rows);
    const bool has4 = (row_base + 4 < rows);
    const bool has5 = (row_base + 5 < rows);
    const bool has6 = (row_base + 6 < rows);
    const bool has7 = (row_base + 7 < rows);

    // Element offsets of each row start; only dereferenced when the row is valid.
    const uint off0 = row_base * cols;
    const uint off1 = (row_base + 1) * cols;
    const uint off2 = (row_base + 2) * cols;
    const uint off3 = (row_base + 3) * cols;
    const uint off4 = (row_base + 4) * cols;
    const uint off5 = (row_base + 5) * cols;
    const uint off6 = (row_base + 6) * cols;
    const uint off7 = (row_base + 7) * cols;

    // Vectorised part: 32 lanes * float4 = 128 columns per step, covering the
    // largest multiple of 128 that fits in cols.
    const uint cols_vec4 = cols & ~127u;
    float4 acc0 = float4(0.0f);
    float4 acc1 = float4(0.0f);
    float4 acc2 = float4(0.0f);
    float4 acc3 = float4(0.0f);
    float4 acc4 = float4(0.0f);
    float4 acc5 = float4(0.0f);
    float4 acc6 = float4(0.0f);
    float4 acc7 = float4(0.0f);

    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        const float4 xv = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        acc0 += *reinterpret_cast<device const float4*>(matrix + off0 + j) * xv;
        if (has1) acc1 += *reinterpret_cast<device const float4*>(matrix + off1 + j) * xv;
        if (has2) acc2 += *reinterpret_cast<device const float4*>(matrix + off2 + j) * xv;
        if (has3) acc3 += *reinterpret_cast<device const float4*>(matrix + off3 + j) * xv;
        if (has4) acc4 += *reinterpret_cast<device const float4*>(matrix + off4 + j) * xv;
        if (has5) acc5 += *reinterpret_cast<device const float4*>(matrix + off5 + j) * xv;
        if (has6) acc6 += *reinterpret_cast<device const float4*>(matrix + off6 + j) * xv;
        if (has7) acc7 += *reinterpret_cast<device const float4*>(matrix + off7 + j) * xv;
    }

    // Collapse each float4 accumulator to a scalar (x, y, z, w order).
    float sum0 = acc0.x + acc0.y + acc0.z + acc0.w;
    float sum1 = acc1.x + acc1.y + acc1.z + acc1.w;
    float sum2 = acc2.x + acc2.y + acc2.z + acc2.w;
    float sum3 = acc3.x + acc3.y + acc3.z + acc3.w;
    float sum4 = acc4.x + acc4.y + acc4.z + acc4.w;
    float sum5 = acc5.x + acc5.y + acc5.z + acc5.w;
    float sum6 = acc6.x + acc6.y + acc6.z + acc6.w;
    float sum7 = acc7.x + acc7.y + acc7.z + acc7.w;

    // Scalar tail: the (cols % 128) columns the vector loop did not cover.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        const float xs = vec_tile[j];
        sum0 += matrix[off0 + j] * xs;
        if (has1) sum1 += matrix[off1 + j] * xs;
        if (has2) sum2 += matrix[off2 + j] * xs;
        if (has3) sum3 += matrix[off3 + j] * xs;
        if (has4) sum4 += matrix[off4 + j] * xs;
        if (has5) sum5 += matrix[off5 + j] * xs;
        if (has6) sum6 += matrix[off6 + j] * xs;
        if (has7) sum7 += matrix[off7 + j] * xs;
    }

    // Combine the 32 per-lane partials of every row.
    sum0 = simd_sum(sum0);
    sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2);
    sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4);
    sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6);
    sum7 = simd_sum(sum7);

    // Lane 0 writes the valid rows of this token's output vector.
    if (simd_lane == 0) {
        device float* y = outputs + token * rows;
        y[row_base] = sum0;
        if (has1) y[row_base + 1] = sum1;
        if (has2) y[row_base + 2] = sum2;
        if (has3) y[row_base + 3] = sum3;
        if (has4) y[row_base + 4] = sum4;
        if (has5) y[row_base + 5] = sum5;
        if (has6) y[row_base + 6] = sum6;
        if (has7) y[row_base + 7] = sum7;
    }
}
1135
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * row_tgs + row_group.
//
// Weight layout as read here: each row is cols/32 blocks of 34 bytes, a
// 2-byte half scale followed by 32 signed 8-bit quants. Columns beyond the
// last full block (cols % 32) are not processed.
//
// Each simdgroup produces 4 consecutive rows starting at row_base; lanes
// split the blocks of a row 32 ways and are combined with simd_sum.
//
// Tail handling: when fewer than 4 rows remain, the pointers for the missing
// rows are clamped to the last valid row. Their sums are computed but never
// stored, and no read leaves the weight buffer.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform across the threadgroup, so exiting before the barrier is safe.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // rows >= 1 here because row_base < rows. Clamp the tail rows so the
    // unconditional loads in the block loop stay inside the weight buffer.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Expands one packed_short4 (8 quant bytes) into two float4 values. Each
    // 16-bit lane is reinterpreted as two signed bytes, .x then .y.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // first input column covered by this block

        // Per-block scale of each row (first two bytes of the block).
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block instead of 8 char4 loads.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input values matching this block, shared by all four rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        // The scale is applied once per block, after the integer-valued dot.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    #undef Q8_UNPACK8

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 stores only the rows that exist; clamped duplicates are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1251
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers Q8_ROWS_PER_TG rows and TOKENS_PER_TG_Q8
// consecutive tokens, so every Q8_0 weight block is fetched and unpacked
// once and then dotted against each token of the tile (1/TOKENS_PER_TG_Q8
// of the weight bandwidth of the per-token dispatch).
//
// Grid: (ceil(rows/Q8_ROWS_PER_TG), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each simdgroup produces 4 consecutive rows starting at row_base; lanes
// split the blocks of a row 32 ways and are combined with simd_sum.
// Token vectors are read straight from device memory inside the block loop
// (no threadgroup staging), so cols is not limited by a tile size.
//
// Weight layout as read here: each row is cols/32 blocks of 34 bytes, a
// 2-byte half scale followed by 32 signed 8-bit quants.
//
// Tail handling: when fewer than 4 rows remain, the pointers for the missing
// rows are clamped to the last valid row. Their sums are computed but never
// stored, and no read leaves the weight buffer.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Number of valid tokens in this tile (1 ..= TOKENS_PER_TG_Q8).
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // rows >= 1 here because row_base < rows. Clamp the tail rows so the
    // unconditional loads in the block loop stay inside the weight buffer.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators sTR: token T of the tile, row R of the simdgroup.
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    // Expands one packed_short4 (8 quant bytes) into two float4 values. Each
    // 16-bit lane is reinterpreted as two signed bytes, .x then .y.
    #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
        short4 _s = short4(SHORT4); \
        char2 _a = as_type<char2>(_s.x); \
        char2 _b = as_type<char2>(_s.y); \
        char2 _c = as_type<char2>(_s.z); \
        char2 _d = as_type<char2>(_s.w); \
        (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
        (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
    }

    // Dots tile token T's 32-column slice against the four unpacked weight
    // rows of the current block and adds the scaled results to S0..S3.
    // Uses the block-loop locals vb, sc0..sc3 and w0_0..w3_7.
    #define Q8_GEMM_ACCUM_TOKEN(T, S0, S1, S2, S3) { \
        device const float* _x = inputs + (tok_base + (T)) * cols + vb; \
        float4 _v0 = *(device const float4*)(_x); \
        float4 _v1 = *(device const float4*)(_x + 4); \
        float4 _v2 = *(device const float4*)(_x + 8); \
        float4 _v3 = *(device const float4*)(_x + 12); \
        float4 _v4 = *(device const float4*)(_x + 16); \
        float4 _v5 = *(device const float4*)(_x + 20); \
        float4 _v6 = *(device const float4*)(_x + 24); \
        float4 _v7 = *(device const float4*)(_x + 28); \
        float _bd0 = dot(w0_0,_v0)+dot(w0_1,_v1)+dot(w0_2,_v2)+dot(w0_3,_v3) \
                   + dot(w0_4,_v4)+dot(w0_5,_v5)+dot(w0_6,_v6)+dot(w0_7,_v7); \
        float _bd1 = dot(w1_0,_v0)+dot(w1_1,_v1)+dot(w1_2,_v2)+dot(w1_3,_v3) \
                   + dot(w1_4,_v4)+dot(w1_5,_v5)+dot(w1_6,_v6)+dot(w1_7,_v7); \
        float _bd2 = dot(w2_0,_v0)+dot(w2_1,_v1)+dot(w2_2,_v2)+dot(w2_3,_v3) \
                   + dot(w2_4,_v4)+dot(w2_5,_v5)+dot(w2_6,_v6)+dot(w2_7,_v7); \
        float _bd3 = dot(w3_0,_v0)+dot(w3_1,_v1)+dot(w3_2,_v2)+dot(w3_3,_v3) \
                   + dot(w3_4,_v4)+dot(w3_5,_v5)+dot(w3_6,_v6)+dot(w3_7,_v7); \
        (S0) += sc0 * _bd0; (S1) += sc1 * _bd1; (S2) += sc2 * _bd2; (S3) += sc3 * _bd3; \
    }

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // first input column covered by this block

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Unpack 4 rows × 8 float4 of raw (unscaled) quants. The per-row scale
        // sc0..sc3 is applied when each block dot product is accumulated.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        // ── Accumulate every valid token of the tile against these weights ──
        // Token 0 is always valid (tok_count >= 1); the rest are guarded so
        // no input row past num_tokens is read.
        Q8_GEMM_ACCUM_TOKEN(0, s00, s01, s02, s03);
        if (tok_count > 1) Q8_GEMM_ACCUM_TOKEN(1, s10, s11, s12, s13);
        if (tok_count > 2) Q8_GEMM_ACCUM_TOKEN(2, s20, s21, s22, s23);
        if (tok_count > 3) Q8_GEMM_ACCUM_TOKEN(3, s30, s31, s32, s33);
    }

    #undef Q8_GEMM_ACCUM_TOKEN
    #undef Q8_UNPACK8

    // simdgroup reduction
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Lane 0 stores only rows and tokens that exist; clamped duplicate rows
    // are dropped here.
    if (simd_lane == 0) {
        bool has1 = (row_base + 1 < rows);
        bool has2 = (row_base + 2 < rows);
        bool has3 = (row_base + 3 < rows);

        device float* o0 = outputs + (tok_base + 0) * rows;
        o0[row_base] = s00;
        if (has1) o0[row_base + 1] = s01;
        if (has2) o0[row_base + 2] = s02;
        if (has3) o0[row_base + 3] = s03;

        if (tok_count > 1) {
            device float* o1 = outputs + (tok_base + 1) * rows;
            o1[row_base] = s10;
            if (has1) o1[row_base + 1] = s11;
            if (has2) o1[row_base + 2] = s12;
            if (has3) o1[row_base + 3] = s13;
        }
        if (tok_count > 2) {
            device float* o2 = outputs + (tok_base + 2) * rows;
            o2[row_base] = s20;
            if (has1) o2[row_base + 1] = s21;
            if (has2) o2[row_base + 2] = s22;
            if (has3) o2[row_base + 3] = s23;
        }
        if (tok_count > 3) {
            device float* o3 = outputs + (tok_base + 3) * rows;
            o3[row_base] = s30;
            if (has1) o3[row_base + 1] = s31;
            if (has2) o3[row_base + 2] = s32;
            if (has3) o3[row_base + 3] = s33;
        }
    }
}
1477
1478// ── matmul_q8_mma ──────────────────────────────────────────────────────
1479// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
1480// simdgroup_matrix tiles (simdgroup_multiply_accumulate). This dispatches
1481// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
1482// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
1483//
1484// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
1485// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
1486// simdgroup_matrix<float, 8, 8> accumulator. Weight bytes are cooperatively
1487// dequantized into threadgroup memory once per block and reused by all
1488// simdgroups in the tile.
1489//
1490// Assumptions (verified in the dispatch helper, falls back otherwise):
1491// * cols % 32 == 0 (one Q8_0 block per K chunk)
1492// * rows % 16 == 0 (tile-aligned; true for all supported architectures)
1493// * num_tokens may be any value; partial row at the tile boundary is handled
1494// via a scratch copy path.
// Edge lengths (tokens × weight rows) of the output tile produced by one
// threadgroup of matmul_q8_mma.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

// Batched Q8_0 matmul: outputs[tok, row] = dot(inputs[tok, :], dequant(matrix[row, :])),
// one 16-token × 16-row output tile per threadgroup.
//
// Buffers:
//   matrix     - Q8_0 weights. Each row is cols/32 blocks of 34 bytes: a
//                2-byte `half` scale followed by 32 int8 quants.
//   inputs     - f32 activations, read as inputs[tok * cols + k].
//   outputs    - f32 results, written as outputs[tok * rows + row].
//   num_tokens - number of token rows in inputs/outputs.
//   rows, cols - weight matrix dimensions.
// Indexing:
//   tgid.x selects the row tile, tgid.y the token tile.
//   The cooperative loads cover the 512-element tiles as tid*4 .. tid*4+3, so
//   they assume exactly 128 threads (4 simdgroups) per threadgroup. The
//   dispatch code is not visible here; confirm it matches.
kernel void matmul_q8_mma(
    device const uchar* matrix   [[buffer(0)]],   // Q8_0 [rows, cols/32 * 34]
    device const float* inputs   [[buffer(1)]],   // [M, cols]
    device float* outputs        [[buffer(2)]],   // [M, rows]
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid                   [[threadgroup_position_in_grid]],
    uint tid                     [[thread_index_in_threadgroup]],
    uint simd_id                 [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    // Both conditions depend only on tgid, so the whole threadgroup leaves
    // together and no thread is left waiting at the barriers below.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8;   // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8;   // col within output tile (row dim)

    // Per-simdgroup f32 accumulator for its 8×8 output sub-tile.
    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;   // bytes per quantized weight row

    // Walk the K dimension one Q8_0 block (32 elements) at a time.
    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        // There is no bound check on row_base + r: this relies on rows being a
        // multiple of the tile height (see the assumptions in the header).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;   // weight row within the tile
                uint k = idx % 32;   // element within the Q8_0 block
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);             // block scale
                int ival = int(*(device const int8_t*)(rp + 2 + k));   // quant after the 2-byte scale
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;   // token within the tile
                uint k = idx % 32;   // element within the K chunk
                uint tok = tok_base + m;
                // Tokens past the end of the batch are zero-filled so the MMA
                // can always run on a full tile.
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Both tiles must be completely written before any simdgroup loads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]   (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]   (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B,
                           w_tile + sg_row_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // Keep the next iteration's tile writes from overtaking these reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial row at the last token tile: stage in per-simdgroup scratch
        // and scalar-copy the valid rows.
        threadgroup float scratch[4 * 64];   // one 8×8 slot per simdgroup
        simdgroup_store(C, scratch + simd_id * 64, 8);
        // Only this simdgroup touches its slot, so a simdgroup-level barrier is used.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // One thread per simdgroup does the copy; tid % 32 assumes 32-wide
        // simdgroups laid out contiguously in tid.
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;   // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
    // Otherwise the sub-tile lies entirely past num_tokens and nothing is stored.
}
1606
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half      (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0)  -> C_a
//     (sg_tok_idx*8, sg_row_half*16 + 8)  -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration — better FLOP/load ratio than the 16×16 single-accumulator
// kernel — and halves the number of threadgroups vs the 16×16 tile.
//
// Assumptions (verified in dispatch helper, fallback otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0
//
// Buffers:
//   matrix     - Q8_0 weights: per row, cols/32 blocks of 34 bytes (2-byte
//                `half` scale + 32 int8 quants).
//   inputs     - f32 activations, read as inputs[tok * cols + k].
//   outputs    - f32 results, written as outputs[tok * rows + row].
//   num_tokens, rows, cols - problem dimensions.
// tgid.x selects the row tile, tgid.y the token tile. The cooperative loads
// index tid*4 .. tid*4+3 and therefore need exactly 256 threads per threadgroup.

// Edge lengths (tokens × weight rows) of the output tile of the 32×32 kernels.
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid                   [[threadgroup_position_in_grid]],
    uint tid                     [[thread_index_in_threadgroup]],
    uint simd_id                 [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and the
    // barriers below are never reached by a partial group.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx = simd_id / 2;     // 0..3
    uint sg_row_half = simd_id % 2;    // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    // Two f32 accumulators: the same 8 tokens against two adjacent 8-row groups.
    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;   // bytes per quantized weight row

    // Walk the K dimension one Q8_0 block (32 elements) at a time.
    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        // No bound check on row_base + r; relies on rows % 32 == 0.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;   // weight row within the tile
                uint k = idx % 32;   // element within the Q8_0 block
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);             // block scale
                int ival = int(*(device const int8_t*)(rp + 2 + k));   // quant after the 2-byte scale
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;   // token within the tile
                uint k = idx % 32;   // element within the K chunk
                uint tok = tok_base + m;
                // Zero-fill tokens beyond the batch so the MMA sees a full tile.
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles must be complete before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        // The weight fragments are loaded with transpose = true so that
        // B[k, r] = w_tile[(row_base + r) * 32 + k].
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Prevent the next iteration's tile writes from overtaking these reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators. rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token group: stage both accumulators in this simdgroup's
        // scratch slot, then copy only the valid token rows.
        threadgroup float scratch[8 * 2 * 64];   // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        // Each slot is private to its simdgroup, so a simdgroup barrier is used.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // One thread per simdgroup copies; tid % 32 assumes 32-wide simdgroups.
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;   // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
    // Sub-tiles that start at or beyond num_tokens store nothing.
}
1747
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory — halving the tile footprint (4 KB total vs 8 KB) and
// doubling concurrent-threadgroup occupancy per GPU core on Apple Silicon.
// NOTE(review): the partial-tile path below also declares a 4 KB float
// scratch array; if that counts toward the static threadgroup allocation
// the real footprint is larger than the tile figure — confirm the
// occupancy claim.
// Each Q8_0 weight is an int8 quant times a per-block f16 scale (the kernel
// reads the scale as `half`), so an f16 intermediate covers the quantized
// dynamic range; the int8 × scale product is rounded when narrowed to half.
// Token activations are narrowed to f16 as well; they stay numerically safe
// because the subsequent `simdgroup_multiply_accumulate` keeps the
// accumulator in `float`.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each. Primary win vs mma32 is occupancy at moderate
// prefill lengths where the GPU is wave-starved.
//
// Buffers and indexing are identical to matmul_q8_mma32: Q8_0 `matrix`
// (34-byte blocks), f32 `inputs` [tok * cols + k], f32 `outputs`
// [tok * rows + row]; 256 threads per threadgroup are assumed by the
// tid*4 cooperative loads.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid                   [[threadgroup_position_in_grid]],
    uint tid                     [[thread_index_in_threadgroup]],
    uint simd_id                 [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Same simdgroup → sub-tile mapping as matmul_q8_mma32.
    uint sg_tok_idx = simd_id / 2;
    uint sg_row_half = simd_id % 2;
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    // Accumulators remain f32 even though the operands are f16.
    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;   // bytes per quantized weight row

    // Walk the K dimension one Q8_0 block (32 elements) at a time.
    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16.
        // The product is formed in f32 and then narrowed once to half.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;   // weight row within the tile
                uint k = idx % 32;   // element within the Q8_0 block
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;   // token within the tile
                uint k = idx % 32;   // element within the K chunk
                uint tok = tok_base + m;
                // Zero-fill tokens beyond the batch.
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be complete before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Mixed precision: half A/B fragments feed float accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Prevent the next iteration's tile writes from overtaking these reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both accumulators; only the token dimension can be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token group: stage in a per-simdgroup scratch slot
        // (8 simdgroups × 2 accumulators × 64 floats) and copy valid rows only.
        threadgroup float scratch[8 * 2 * 64];
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // One thread per simdgroup copies; tid % 32 assumes 32-wide simdgroups.
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok;   // 1..7 on this branch
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
    // Sub-tiles that start at or beyond num_tokens store nothing.
}
1876
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each. Per simdgroup:
//   C_00 (tok 0..8,  row 0..8)    C_01 (tok 0..8,  row 8..16)
//   C_10 (tok 8..16, row 0..8)    C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's. That's double the FLOP-per-simdgroup-load compared to the
// 2-accumulator kernel and halves the thread count per threadgroup (128
// threads), which often improves occupancy on Apple GPUs where the
// concurrent-thread budget is the tighter limit than shared-memory size.
//
// Buffers match the other matmul_q8_mma* kernels: Q8_0 `matrix` (34-byte
// blocks: 2-byte half scale + 32 int8), f32 `inputs` [tok * cols + k],
// f32 `outputs` [tok * rows + row]. The tid*8 cooperative loads require
// exactly 128 threads per threadgroup.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid                   [[threadgroup_position_in_grid]],
    uint tid                     [[thread_index_in_threadgroup]],
    uint simd_id                 [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2;   // 0..1
    uint sg_row_q = simd_id % 2;   // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Four f32 accumulators, named C_<token half><row half> within the quadrant.
    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;   // bytes per quantized weight row

    // Walk the K dimension one Q8_0 block (32 elements) at a time.
    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        // No bound check on row_base + r; relies on rows being tile-aligned.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;   // weight row within the tile
                uint k = idx % 32;   // element within the Q8_0 block
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;   // token within the tile
                uint k = idx % 32;   // element within the K chunk
                uint tok = tok_base + m;
                // Zero-fill tokens beyond the batch; narrow f32 → f16.
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be complete before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevent the next iteration's tile writes from overtaking these reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles. Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;
    // True only when the last token of the bottom half is still in range.
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch.
        // This branch also covers quadrants lying entirely past num_tokens;
        // the per-token checks below then skip every row.
        threadgroup float scratch[4 * 4 * 64];   // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base + 0, 8);
        simdgroup_store(C_01, scratch + sg_base + 64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        // Each slot is private to its simdgroup, so a simdgroup barrier is used.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        // One thread per simdgroup copies; tid % 32 assumes 32-wide simdgroups.
        uint lane = tid % 32;
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                // Top half of the quadrant (C_00 / C_01).
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base + 0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base + 64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                // Bottom half of the quadrant (C_10 / C_11).
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2031
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)). Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Layout, buffers and the 128-thread requirement are the same as
// matmul_q8_mma32_h4. Because the accumulators are half and `outputs` is
// f32, every store goes through a half scratch buffer and is widened by a
// scalar copy; there is no direct simdgroup_store fast path.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix   [[buffer(0)]],
    device const float* inputs   [[buffer(1)]],
    device float* outputs        [[buffer(2)]],
    constant uint& num_tokens    [[buffer(3)]],
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint2 tgid                   [[threadgroup_position_in_grid]],
    uint tid                     [[thread_index_in_threadgroup]],
    uint simd_id                 [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Uniform per threadgroup (depends only on tgid), so barriers stay safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles for dequantized weights and narrowed activations.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups, each owning one 16×16 quadrant of the output tile.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Half-precision accumulators: the only difference from the h4 kernel's math.
    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;   // bytes per quantized weight row

    // Walk the K dimension one Q8_0 block (32 elements) at a time.
    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant: 128 threads × 8 elements = 32×32 tile.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;   // weight row within the tile
                uint k = idx % 32;   // element within the Q8_0 block
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        // Cooperative token tile load, zero-filling tokens beyond the batch.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;   // token within the tile
                uint k = idx % 32;   // element within the K chunk
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        // Tiles must be complete before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks; each reuses two A and two B fragments for four MMAs.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevent the next iteration's tile writes from overtaking these reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;

    // 4 simdgroups × 4 accumulators × 64 halves; each simdgroup owns 256 entries.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base + 0, 8);
    simdgroup_store(C_01, scratch + sg_base + 64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    // Each slot is private to its simdgroup, so a simdgroup barrier is used.
    simdgroup_barrier(mem_flags::mem_threadgroup);
    // One thread per simdgroup copies; tid % 32 assumes 32-wide simdgroups.
    uint lane = tid % 32;
    if (lane == 0) {
        for (uint m = 0; m < 8; m++) {
            // Top half of the quadrant (C_00 / C_01); skip tokens past the batch.
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base + 0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            // Bottom half of the quadrant (C_10 / C_11).
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2169
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add of a per-row bias onto an [M, rows] output:
//   out[token * rows + i] += bias[i]   for every token and every i < rows.
// Used for Qwen2 QKV bias after the fused qkv matmul.
// One thread handles one output element; threads whose id falls at or past
// num_tokens * rows do nothing.
kernel void add_bias_batch(
    device float* out            [[buffer(0)]],   // [num_tokens, rows], updated in place
    device const float* bias     [[buffer(1)]],   // [rows]
    constant uint& num_tokens    [[buffer(2)]],
    constant uint& rows          [[buffer(3)]],
    uint id                      [[thread_position_in_grid]])
{
    const uint element_count = num_tokens * rows;
    if (id < element_count) {
        // The flat index modulo the row length selects the bias entry.
        const uint bias_index = id % rows;
        out[id] = out[id] + bias[bias_index];
    }
}
2186
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Buffers:
//   matrix     - Q4_0 weights. Each row is cols/32 blocks of 18 bytes: a
//                2-byte `half` scale followed by 16 bytes of packed nibbles.
//                Byte j of a block holds element j in its low nibble and
//                element j+16 in its high nibble, each stored with a +8 offset.
//   inputs     - f32 activations, read as inputs[token * cols + k].
//   outputs    - f32 results, written as outputs[token * rows + row].
// Work split:
//   tgid encodes (token, row group); each simdgroup produces 4 consecutive
//   output rows, with lanes striding over the K blocks and a simd_sum at the
//   end. The input staging loop steps by 256, so 256 threads per threadgroup
//   are assumed, and cols must fit in VEC_TILE_SIZE — both are defined or
//   enforced outside this view; confirm against the dispatch code.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix   [[buffer(0)]],   // Q4_0 raw bytes [rows, cols]
    device const float* inputs   [[buffer(1)]],   // [M, cols] input batch
    device float* outputs        [[buffer(2)]],   // [M, rows] output batch
    constant uint& num_tokens    [[buffer(3)]],   // M
    constant uint& rows          [[buffer(4)]],
    constant uint& cols          [[buffer(5)]],
    uint tgid                    [[threadgroup_position_in_grid]],
    uint tid                     [[thread_index_in_threadgroup]],
    uint simd_lane               [[thread_index_in_simdgroup]],
    uint simd_id                 [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Depends only on tgid, so the whole threadgroup exits before the barrier.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory once; every
    // simdgroup in the group reuses it.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barrier follows, so simdgroups with no rows to compute may leave here.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;   // bytes per quantized weight row

    // Clamp the three trailing row indices to the last valid row. When rows is
    // not a multiple of 4 the final simdgroup would otherwise read scales and
    // quants past the end of `matrix`. The redundant sums computed for the
    // clamped rows are dropped by the guarded stores at the bottom.
    uint last_row = rows - 1;   // rows >= 1 here because row_base < rows
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Each lane handles every 32nd Q4_0 block of the four rows.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;   // byte offset of the block within a weight row
        uint vb = blk * 32;   // element offset of the block within the input

        // Per-block scales, one per row.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // 16 packed bytes per block, read as four uchar4 groups.
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 = input elements 0..15 (paired with low nibbles),
        // v4..v7 = input elements 16..31 (paired with high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Unscaled block dot products; the scale is applied once per block below.
        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the results; rows past the end of the matrix are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2287
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Copy K or V from a strided batch QKV buffer to the KV cache.
// src layout: [M, qkv_stride] with K/V at src_offset within each row.
// dst layout: contiguous [max_seq, kv_dim] cache.
//
// One thread handles one (token, element) pair: the flat thread id is split
// into token = id / kv_dim and d = id % kv_dim. Threads with
// id >= M * kv_dim return immediately, so the grid may be rounded up past
// the element count. Values are narrowed from f32 to f16 on the store.
kernel void copy_kv_batch(
    device const float* src [[buffer(0)]], // batch QKV buffer (f32)
    device half* dst [[buffer(1)]], // KV cache (f16)
    constant uint& M [[buffer(2)]], // num tokens in batch
    constant uint& kv_dim [[buffer(3)]], // elements per KV vector
    constant uint& base_pos [[buffer(4)]], // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    uint total = M * kv_dim;
    if (id >= total) return;
    uint token = id / kv_dim;
    uint d = id % kv_dim;
    // Destination row is the token's absolute cache position (base_pos + token).
    uint dst_off = (base_pos + token) * kv_dim + d;
    // Source row is the token's index within the batch; src_offset selects K or V.
    uint src_off = token * src_stride + src_offset + d;
    dst[dst_off] = half(src[src_off]);
}
2310
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// The loop strides used below (8 simdgroups, 32 lanes, 256 threads) are
// hard-coded, so the kernel presumably has to be dispatched with 256 threads
// per threadgroup — verify against the host-side dispatch.
kernel void attention_batch(
    device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
    device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
    device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
    constant uint& M [[buffer(4)]], // num tokens in batch
    constant uint& base_pos [[buffer(5)]], // starting position in KV cache
    constant uint& num_heads [[buffer(6)]],
    constant uint& num_kv_heads [[buffer(7)]],
    constant uint& head_dim [[buffer(8)]],
    constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Grid: M * num_heads threadgroups
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // Grouped-query mapping: each run of (num_heads / num_kv_heads) consecutive
    // query heads reads the same KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx

    // Q offset uses strided layout (from batch QKV buffer)
    uint q_off = token_idx * q_stride + head * head_dim;
    // Output is contiguous [M, num_heads * head_dim]
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Shared memory for attention scores — sized to the effective max_seq_len
    // (4096 for all supported models) so long-context attention doesn't overflow.
    // NOTE(review): scores[s] is written for every s < seq_len with no check
    // against the array length; this relies on the caller keeping
    // base_pos + M within that size — confirm on the host side.
    threadgroup float scores[ATTN_SCORES_SIZE];

    // Step 1: Q * K^T with simdgroup reduction
    // Each simdgroup takes every 8th position; its 32 lanes split head_dim and
    // simd_sum combines the partial dot products.
    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q_batch[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            // Scale by 1/sqrt(head_dim) before the softmax.
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Softmax (cooperative)
    // Pass 1: global max — per-thread max, then per-simdgroup max, then thread 0
    // folds the 8 simdgroup results and publishes the answer in shared_max[0].
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Pass 2: exponentiate in place (shifted by the max) and accumulate the sum
    // with the same thread -> simdgroup -> thread 0 reduction shape.
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        // Publish the reciprocal; a zero total yields 0 instead of dividing by zero.
        shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // Pass 3: normalize so the scores sum to 1.
    float inv_sum = shared_sum[0];
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: scores * V using float4 vectorized loads
    // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
    // This is much better than the scalar version where only 64 of 256 threads are active.
    uint v_stride = num_kv_heads * head_dim;
    uint head_dim4 = head_dim / 4;
    for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
        uint d = d4 * 4;
        float4 acc = float4(0.0);
        uint v_base = kv_head * head_dim + d;
        // Main loop is unrolled over 4 positions; seq_len4 is seq_len rounded
        // down to a multiple of 4, the remainder is handled right after.
        uint seq_len4 = seq_len & ~3u;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
                 + sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
                 + sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
                 + sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
        }
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
        }
        *(device float4*)(output_batch + out_off + d) = acc;
    }
    // Handle remaining dimensions not divisible by 4 (scalar fallback)
    for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output_batch[out_off + d] = acc;
    }
}
2436
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming attention with online softmax. Same grid as attention_batch
// (M × num_heads threadgroups, one per (token, head) pair) but the scores
// matrix is never materialized. K/V positions are processed in a tile of
// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
// the standard flash-attention recurrence:
//
// m_new = max(m_old, tile_max)
// alpha = exp(m_old - m_new)
// l_new = alpha * l_old + sum(exp(S - m_new))
// O_new = alpha * O_old + sum(exp(S - m_new) * V)
// O_final = O / l
//
// This removes the fixed-size `scores` array that attention_batch needs
// (it is indexed by position without a bounds check, so it overflows once
// seq_len exceeds the array size) and keeps threadgroup memory
// use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
//
// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
// The loop strides (8 simdgroups, 32 lanes, 256 threads) are hard-coded, so
// the kernel presumably expects 256 threads per threadgroup — verify against
// the host-side dispatch.
constant constexpr uint FLASH_K_TILE = 32;
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;

kernel void attention_flash_batch(
    device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
    device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
    device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
    device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
    constant uint& M [[buffer(4)]],
    constant uint& base_pos [[buffer(5)]],
    constant uint& num_heads [[buffer(6)]],
    constant uint& num_kv_heads [[buffer(7)]],
    constant uint& head_dim [[buffer(8)]],
    constant uint& q_stride [[buffer(9)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint token_idx = tgid / num_heads;
    uint head = tgid % num_heads;
    if (token_idx >= M) return;

    // Grouped-query mapping: each run of (num_heads / num_kv_heads) consecutive
    // query heads reads the same KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    uint seq_len = base_pos + token_idx + 1; // causal: attend to [0, base_pos + token_idx]
    uint q_off = token_idx * q_stride + head * head_dim;
    uint out_off = token_idx * num_heads * head_dim + head * head_dim;

    // Threadgroup state:
    // q_sh: Q vector for this (token, head), loaded once
    // o_sh: running output vector, updated each K tile
    // scores_sh: scores for the current K tile only
    // stats: [running max, running sum] (see flash-attention recurrence)
    threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
    threadgroup float scores_sh[FLASH_K_TILE];
    threadgroup float stats[2];
    threadgroup float sg_scratch[8]; // simdgroup-level reduction buffer

    // --- Load Q (one row) and zero the running O ---
    for (uint d = tid; d < head_dim; d += 256) {
        q_sh[d] = q_batch[q_off + d];
        o_sh[d] = 0.0f;
    }
    if (tid == 0) {
        stats[0] = -INFINITY;
        stats[1] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = fast::rsqrt(float(head_dim));
    uint v_stride = num_kv_heads * head_dim;
    uint v_base = kv_head * head_dim;

    // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
        // The last tile may be partial.
        uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);

        // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
        // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
        for (uint ti = simd_id; ti < tile_n; ti += 8) {
            uint k_off = (kv_base + ti) * v_stride + v_base; // same layout as V stride
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q_sh[d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores_sh[ti] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [2] Tile max via cooperative reduction.
        // Threads with tid >= tile_n contribute -INFINITY and so do not affect the max.
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_n; s += 256) {
            local_max = max(local_max, scores_sh[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_max;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // [3] Merge with running max, compute alpha, rescale running l.
        // m_new / alpha are only computed by thread 0 here; every thread picks up
        // the published values from sg_scratch after the barrier below.
        float m_new;
        float alpha;
        if (tid == 0) {
            float tile_max = sg_scratch[0];
            for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
            float m_old = stats[0];
            m_new = max(m_old, tile_max);
            // First iteration: m_old = -inf → alpha = 0 (reset O).
            alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            stats[0] = m_new;
            stats[1] *= alpha;
            // Broadcast via sg_scratch.
            sg_scratch[0] = alpha;
            sg_scratch[1] = m_new;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        alpha = sg_scratch[0];
        m_new = sg_scratch[1];

        // [4] Rescale running output by alpha, then compute exp(scores - m_new).
        for (uint d = tid; d < head_dim; d += 256) {
            o_sh[d] *= alpha;
        }
        for (uint s = tid; s < tile_n; s += 256) {
            scores_sh[s] = fast::exp(scores_sh[s] - m_new);
        }
        // This barrier also guarantees every thread has read alpha/m_new from
        // sg_scratch before step [5] overwrites that buffer.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [5] Tile sum → update running l.
        float local_sum = 0.0f;
        for (uint s = tid; s < tile_n; s += 256) {
            local_sum += scores_sh[s];
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) {
            sg_scratch[simd_id] = local_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float tile_sum = 0.0f;
            for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
            stats[1] += tile_sum;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
        // Each thread owns the same output dims d in steps [4] and [6], so o_sh
        // needs no extra synchronization between them.
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = 0.0f;
            for (uint s = 0; s < tile_n; s++) {
                acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
            }
            o_sh[d] += acc;
        }
        // Keep scores_sh stable until all threads are done before the next tile rewrites it.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // --- Normalize and write output ---
    // A zero denominator yields 0 instead of dividing by zero.
    float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
    for (uint d = tid; d < head_dim; d += 256) {
        output_batch[out_off + d] = o_sh[d] * inv_l;
    }
}
2601
// ── attention_mma_flash_batch ─────────────────────────────────────────
// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
// both Q·K^T and P·V. Processes Q_BLOCK=8 tokens of one head per
// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
// K/V loads across 8 Q rows and using hardware matrix-multiply for the
// arithmetic.
//
// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
// to attention_batch / attention_flash_batch otherwise.
//
// Online softmax recurrence is identical to attention_flash_batch but
// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
//
// NOTE(review): the integer divisions head_dim / 8 (phase 1) and
// (head_dim / 4) / 8 (phase 4) only cover every dimension when head_dim is a
// multiple of 32 — confirm the dispatch side enforces this.
constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
constant constexpr uint FLASH_MMA_K_BLOCK = 32;
constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;

kernel void attention_mma_flash_batch(
    device const float* q_batch [[buffer(0)]],
    device const half* k_cache [[buffer(1)]],
    device const half* v_cache [[buffer(2)]],
    device float* output_batch [[buffer(3)]],
    constant uint& M [[buffer(4)]],
    constant uint& base_pos [[buffer(5)]],
    constant uint& num_heads [[buffer(6)]],
    constant uint& num_kv_heads [[buffer(7)]],
    constant uint& head_dim [[buffer(8)]],
    constant uint& q_stride [[buffer(9)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
    uint head = tgid.y;
    if (q_block_start >= M) return;
    // The last Q block may hold fewer than Q_BLOCK real tokens.
    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);

    // Grouped-query mapping: each run of (num_heads / num_kv_heads) consecutive
    // query heads reads the same KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
    uint seq_len = base_pos + q_block_start + q_valid;
    float scale = fast::rsqrt(float(head_dim));

    uint kv_stride = num_kv_heads * head_dim;
    uint kv_base_off = kv_head * head_dim;

    // ── Threadgroup memory ──
    // q_sh: [Q_BLOCK, head_dim] half — Q tile, loaded once
    // k_sh: [K_BLOCK, head_dim] half — K tile, refreshed per kv_base iter
    // v_sh: [K_BLOCK, head_dim] half — V tile, refreshed per kv_base iter
    // s_sh: [Q_BLOCK, K_BLOCK] float — raw Q·K^T scores, then scaled+masked
    // p_sh: [Q_BLOCK, K_BLOCK] half — softmax probabilities (for P·V MMA)
    // o_sh: [Q_BLOCK, head_dim] float — running output accumulator
    // m_sh: [Q_BLOCK] float — running max per Q row
    // l_sh: [Q_BLOCK] float — running softmax denominator per Q row
    // scratch: 4*Q_BLOCK floats — per-row reduction scratch
    // The tiles are allocated for FLASH_MMA_MAX_HEAD_DIM columns but indexed
    // with a row stride of head_dim.
    threadgroup half q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup half p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];

    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
    // Rows past q_valid are zero-filled so the 8x8 MMA always sees a full tile.
    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
    for (uint i = tid; i < qblock_elems; i += 128) {
        uint q = i / head_dim;
        uint d = i % head_dim;
        if (q < q_valid) {
            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
            q_sh[q * head_dim + d] = half(q_batch[q_off]);
        } else {
            q_sh[q * head_dim + d] = half(0);
        }
        o_sh[q * head_dim + d] = 0.0f;
    }
    if (tid < FLASH_MMA_Q_BLOCK) {
        m_sh[tid] = -INFINITY;
        l_sh[tid] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // ── Stream K/V in K_BLOCK chunks ──
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
        // The last tile may be partial.
        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);

        // Load K and V tile into TG memory (as half).
        // Positions past tile_n are zero-filled; they are masked out in phase 2a.
        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
        for (uint i = tid; i < kv_tile_elems; i += 128) {
            uint k_pos = i / head_dim;
            uint d = i % head_dim;
            if (k_pos < tile_n) {
                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
            } else {
                k_sh[k_pos * head_dim + d] = half(0);
                v_sh[k_pos * head_dim + d] = half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 1: S = Q @ K^T via MMA ──
        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
        {
            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
            uint dim_chunks = head_dim / 8;
            for (uint dc = 0; dc < dim_chunks; dc++) {
                simdgroup_matrix<half, 8, 8> A, B;
                // A = Q[0:8, dc*8 : dc*8+8] (rows of Q, no transpose)
                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
                // transpose=true, which places it in the register as K^T of that sub-block.
                simdgroup_load(B,
                               k_sh + (simd_id * 8) * head_dim + dc * 8,
                               head_dim, ulong2(0, 0), true);
                simdgroup_multiply_accumulate(C, A, B, C);
            }
            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
        // An entry survives only if its Q row is real, its K position is inside
        // the tile, and the K position is not after the Q token (causal).
        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
        for (uint i = tid; i < s_elems; i += 128) {
            uint q = i / FLASH_MMA_K_BLOCK;
            uint k = i % FLASH_MMA_K_BLOCK;
            uint global_q = q_block_start + q;
            uint global_kv = kv_base + k;
            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2b: per-row max via simdgroup reduction ──
        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
        {
            uint row_base = simd_id * 2;
            for (uint qr = 0; qr < 2; qr++) {
                uint q = row_base + qr;
                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
                float row_max = simd_max(my);
                if (simd_lane == 0) {
                    scratch[q] = row_max; // tile_max[q]
                }
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
        // One thread per Q row.
        if (tid < FLASH_MMA_Q_BLOCK) {
            uint q = tid;
            float m_old = m_sh[q];
            float tile_max = scratch[q];
            float m_new = max(m_old, tile_max);
            // No previous maximum yet (m_old = -inf) → alpha = 0.
            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            m_sh[q] = m_new;
            l_sh[q] = l_sh[q] * alpha;
            // scratch[q] = m_new (for phase 2d)
            // scratch[Q_BLOCK + q] = alpha (for phase 3)
            scratch[q] = m_new;
            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
        // NOTE(review): for a row that is fully masked while m_new is still
        // -INFINITY (e.g. padding rows q >= q_valid), S - m_new is (-inf) - (-inf).
        // Such rows are never written out by the finalize loop; confirm the
        // resulting value cannot leak into valid rows through the MMA.
        {
            uint row_base = simd_id * 2;
            for (uint qr = 0; qr < 2; qr++) {
                uint q = row_base + qr;
                float m_new = scratch[q];
                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
                float row_sum = simd_sum(p);
                if (simd_lane == 0) {
                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
                }
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 2e: l_sh += tile_sum ──
        if (tid < FLASH_MMA_Q_BLOCK) {
            uint q = tid;
            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
        }

        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
        threadgroup_barrier(mem_flags::mem_threadgroup);
        for (uint i = tid; i < qblock_elems; i += 128) {
            uint q = i / head_dim;
            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
            o_sh[i] *= alpha;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── Phase 4: O += P @ V via MMA ──
        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
        // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
        // For head_dim=64: 16 dims/sg = 2 tiles. head_dim=128: 32 dims/sg = 4 tiles.
        {
            uint dims_per_sg = head_dim / 4; // 16 or 32
            uint tiles_per_sg = dims_per_sg / 8; // 2 or 4
            uint sg_d_base = simd_id * dims_per_sg;
            for (uint t = 0; t < tiles_per_sg; t++) {
                uint d_base = sg_d_base + t * 8;
                // Start from the already-rescaled running output so the MMA
                // accumulates on top of it.
                simdgroup_matrix<float, 8, 8> O_acc;
                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
                uint k_chunks = FLASH_MMA_K_BLOCK / 8; // 4
                for (uint kc = 0; kc < k_chunks; kc++) {
                    simdgroup_matrix<half, 8, 8> A, B;
                    // A = P[0:8, kc*8 : kc*8+8]
                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
                                   ulong2(0, 0), false);
                    // B = V[kc*8 : kc*8+8, d_base : d_base+8]
                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
                                   ulong2(0, 0), false);
                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
                }
                simdgroup_store(O_acc, o_sh + d_base, head_dim);
            }
        }
        // Finish all reads of k_sh / v_sh / p_sh before the next tile overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Finalize: O /= l, write to output ──
    // Only real rows (q < q_valid) are written; a zero denominator yields 0.
    for (uint i = tid; i < qblock_elems; i += 128) {
        uint q = i / head_dim;
        uint d = i % head_dim;
        if (q < q_valid) {
            float l = l_sh[q];
            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
            uint token_idx = q_block_start + q;
            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
            output_batch[out_off] = o_sh[i] * inv_l;
        }
    }
}
2849
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Fused RoPE for both Q and K in a single dispatch, saving one kernel
// launch + memory barrier per layer. Both Q and K live in the same
// qkv_data buffer at different offsets within each token's row.
// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
//
// One thread rotates one adjacent pair (elements 2i and 2i+1 of a head) in
// place. The flat thread id is split into token = id / total_pairs and
// pair = id % total_pairs; pairs below num_q_heads * (head_dim / 2) belong to
// Q, the rest to K. Threads whose token is >= M return without touching memory.
kernel void rope_qk_batch(
    device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
    constant uint& M [[buffer(1)]], // num tokens
    constant uint& base_pos [[buffer(2)]], // starting position
    constant uint& num_q_heads [[buffer(3)]],
    constant uint& num_kv_heads [[buffer(4)]],
    constant uint& head_dim [[buffer(5)]],
    constant uint& qkv_stride [[buffer(6)]], // floats per row
    constant float& theta [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    uint half_dim = head_dim / 2;
    uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
    uint token = id / total_pairs;
    uint pair = id % total_pairs;
    if (token >= M) return;

    // Absolute sequence position of this token.
    uint pos = base_pos + token;
    uint q_pairs = num_q_heads * half_dim;

    // h = head index, i = pair index within the head, offset = index of the
    // first element of the pair inside qkv_data.
    uint h, i, offset;
    if (pair < q_pairs) {
        // Q head
        h = pair / half_dim;
        i = pair % half_dim;
        offset = token * qkv_stride + h * head_dim + i * 2;
    } else {
        // K head
        uint kp = pair - q_pairs;
        h = kp / half_dim;
        i = kp % half_dim;
        uint k_start = num_q_heads * head_dim;
        offset = token * qkv_stride + k_start + h * head_dim + i * 2;
    }

    // Rotation frequency for pair i: theta^(-2i / head_dim).
    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    float angle = float(pos) * freq;
    float cos_val = cos(angle);
    float sin_val = sin(angle);

    // Read both elements first, then write the rotated pair back in place.
    float x0 = qkv_data[offset];
    float x1 = qkv_data[offset + 1];
    qkv_data[offset] = x0 * cos_val - x1 * sin_val;
    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
}
2900
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Fused K+V cache copy in a single dispatch: copies both K and V from
// the strided batch QKV buffer to their respective KV cache buffers.
// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
//
// Thread ids in [0, M * kv_dim) copy K and ids in [M * kv_dim, 2 * M * kv_dim)
// copy V; ids beyond that return immediately. Values are narrowed from f32
// to f16 on the store.
kernel void copy_kv_both_batch(
    device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride] f32
    device half* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim] f16
    device half* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim] f16
    constant uint& M [[buffer(3)]], // num tokens in batch
    constant uint& kv_dim [[buffer(4)]], // elements per KV vector
    constant uint& base_pos [[buffer(5)]], // starting position in cache
    constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
    constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
    constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    // Total elements = M * kv_dim * 2 (K + V)
    uint total_kv = M * kv_dim;
    if (id >= total_kv * 2) return;

    uint is_v = id / total_kv; // 0 = K, 1 = V
    uint local_id = id % total_kv;
    uint token = local_id / kv_dim;
    uint d = local_id % kv_dim;

    // Both caches share the same destination offset; only the source column
    // offset and the destination buffer differ between K and V.
    uint dst_off = (base_pos + token) * kv_dim + d;
    uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;

    if (is_v) {
        v_dst[dst_off] = half(src[src_off]);
    } else {
        k_dst[dst_off] = half(src[src_off]);
    }
}
2935"#
2936 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2937 .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2938}
2939
2940fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2945 let mut code = String::with_capacity(48 * 1024);
2946 emit_model_header(&mut code, config)?;
2947 emit_metal_model_struct(&mut code, config)?;
2948 emit_layer_buffers_struct(&mut code, config)?;
2949 emit_metal_model_impl(&mut code, config)?;
2950 emit_helper_functions(&mut code)?;
2951 Ok(code)
2952}
2953
2954fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2955 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2956 writeln!(
2957 code,
2958 "//! Model: {} ({} layers, hidden={})",
2959 config.architecture, config.num_layers, config.hidden_size
2960 )?;
2961 writeln!(code, "//!")?;
2962 writeln!(
2963 code,
2964 "//! Uses native Metal compute pipelines via the metal crate."
2965 )?;
2966 writeln!(code)?;
2967 writeln!(code, "#![allow(dead_code)]")?;
2968 writeln!(code)?;
2969 writeln!(code, "use metal::*;")?;
2970 writeln!(code, "#[allow(unused_imports)]")?;
2971 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2972 writeln!(code, "use std::mem;")?;
2973 writeln!(code)?;
2974
2975 writeln!(
2977 code,
2978 "// ── Model constants ──────────────────────────────────"
2979 )?;
2980 writeln!(
2981 code,
2982 "pub const HIDDEN_SIZE: usize = {};",
2983 config.hidden_size
2984 )?;
2985 writeln!(
2986 code,
2987 "pub const INTERMEDIATE_SIZE: usize = {};",
2988 config.intermediate_size
2989 )?;
2990 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2991 writeln!(
2992 code,
2993 "pub const NUM_HEADS: usize = {};",
2994 config.num_attention_heads
2995 )?;
2996 writeln!(
2997 code,
2998 "pub const NUM_KV_HEADS: usize = {};",
2999 config.num_kv_heads
3000 )?;
3001 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
3002 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
3003 let effective_seq_len = config.max_seq_len.min(4096);
3004 writeln!(
3005 code,
3006 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
3007 effective_seq_len, config.max_seq_len
3008 )?;
3009 writeln!(
3010 code,
3011 "pub const RMS_NORM_EPS: f32 = {:e};",
3012 config.rms_norm_eps
3013 )?;
3014 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3015 writeln!(
3016 code,
3017 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3018 )?;
3019 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3020 writeln!(code)?;
3021
3022 Ok(())
3023}
3024
3025fn emit_metal_model_struct(
3026 code: &mut String,
3027 config: &ModelConfig,
3028) -> Result<(), MetalCodegenError> {
3029 writeln!(
3030 code,
3031 "// ── MetalModel ──────────────────────────────────────────"
3032 )?;
3033 writeln!(code)?;
3034 writeln!(
3035 code,
3036 "/// Metal-accelerated transformer model for Apple Silicon."
3037 )?;
3038 writeln!(code, "///")?;
3039 writeln!(
3040 code,
3041 "/// Uses unified memory for zero-copy weight access and native Metal"
3042 )?;
3043 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3044 writeln!(code, "pub struct MetalModel {{")?;
3045 writeln!(code, " device: Device,")?;
3046 writeln!(code, " queue: CommandQueue,")?;
3047 writeln!(code)?;
3048 writeln!(code, " // ── Compute pipelines ──")?;
3049 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
3050 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
3051 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
3052 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
3053 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
3054 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
3055 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
3056 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
3057 writeln!(code, " add_pipeline: ComputePipelineState,")?;
3058 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
3059 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
3060 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
3061 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
3062 writeln!(
3063 code,
3064 " copy_f32_to_f16_offset_pipeline: ComputePipelineState,"
3065 )?;
3066 writeln!(code)?;
3067 writeln!(code, " // ── Batched prefill pipelines ──")?;
3068 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
3069 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
3070 writeln!(
3071 code,
3072 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3073 )?;
3074 writeln!(code, " matmul_q8_mma_pipeline: ComputePipelineState,")?;
3075 writeln!(code, " matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3076 writeln!(
3077 code,
3078 " matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3079 )?;
3080 writeln!(
3081 code,
3082 " matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3083 )?;
3084 writeln!(
3085 code,
3086 " matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3087 )?;
3088 if config.qkv_bias {
3089 writeln!(code, " add_bias_batch_pipeline: ComputePipelineState,")?;
3090 }
3091 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
3092 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
3093 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
3094 writeln!(
3095 code,
3096 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
3097 )?;
3098 writeln!(
3099 code,
3100 " add_inplace_batch_pipeline: ComputePipelineState,"
3101 )?;
3102 writeln!(
3103 code,
3104 " copy_embedding_batch_pipeline: ComputePipelineState,"
3105 )?;
3106 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
3107 writeln!(
3108 code,
3109 " attention_flash_batch_pipeline: ComputePipelineState,"
3110 )?;
3111 writeln!(
3112 code,
3113 " attention_mma_flash_batch_pipeline: ComputePipelineState,"
3114 )?;
3115 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
3116 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
3117 writeln!(
3118 code,
3119 " copy_kv_both_batch_pipeline: ComputePipelineState,"
3120 )?;
3121 writeln!(code)?;
3122 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
3123 writeln!(
3124 code,
3125 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3126 )?;
3127 writeln!(code, " embed_buf: Buffer,")?;
3128 writeln!(code)?;
3129 writeln!(code, " /// Per-layer weight buffers")?;
3130 writeln!(code, " layers: Vec<LayerBuffers>,")?;
3131 writeln!(code)?;
3132 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
3133 writeln!(code, " norm_buf: Buffer,")?;
3134 writeln!(code)?;
3135 writeln!(
3136 code,
3137 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3138 )?;
3139 writeln!(code, " lm_head_buf: Buffer,")?;
3140 writeln!(code)?;
3141 writeln!(
3142 code,
3143 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
3144 )?;
3145 writeln!(code, " hidden_buf: Buffer,")?;
3146 writeln!(code, " residual_buf: Buffer,")?;
3147 writeln!(code, " normed_buf: Buffer,")?;
3148 writeln!(
3149 code,
3150 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3151 )?;
3152 writeln!(code, " qkv_buf: Buffer,")?;
3153 writeln!(code, " attn_out_buf: Buffer,")?;
3154 writeln!(code, " attn_proj_buf: Buffer,")?;
3155 writeln!(
3156 code,
3157 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3158 )?;
3159 writeln!(code, " gate_up_buf: Buffer,")?;
3160 writeln!(code, " ffn_hidden_buf: Buffer,")?;
3161 writeln!(code, " ffn_out_buf: Buffer,")?;
3162 writeln!(code, " add_tmp_buf: Buffer,")?;
3163 writeln!(code, " logits_buf: Buffer,")?;
3164 writeln!(code)?;
3165 writeln!(code, " // ── Batched prefill working buffers ──")?;
3166 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3167 writeln!(code, " batch_hidden_buf: Buffer,")?;
3168 writeln!(
3169 code,
3170 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3171 )?;
3172 writeln!(code, " batch_residual_buf: Buffer,")?;
3173 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3174 writeln!(code, " batch_qkv_buf: Buffer,")?;
3175 writeln!(
3176 code,
3177 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3178 )?;
3179 writeln!(code, " batch_attn_out_buf: Buffer,")?;
3180 writeln!(
3181 code,
3182 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3183 )?;
3184 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
3185 writeln!(
3186 code,
3187 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3188 )?;
3189 writeln!(code, " batch_gate_up_buf: Buffer,")?;
3190 writeln!(
3191 code,
3192 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3193 )?;
3194 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
3195 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3196 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
3197 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
3198 writeln!(code, " batch_tokens_buf: Buffer,")?;
3199 writeln!(code, " /// Positions buffer for batch RoPE")?;
3200 writeln!(code, " batch_positions_buf: Buffer,")?;
3201 writeln!(code)?;
3202 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
3203 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
3204 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
3205 writeln!(code)?;
3206 writeln!(code, " // ── Inference state ──")?;
3207 writeln!(code, " pos: usize,")?;
3208 writeln!(code)?;
3209 writeln!(
3210 code,
3211 " /// Previous command buffer for double-buffered prefill."
3212 )?;
3213 writeln!(
3214 code,
3215 " /// While the GPU executes token N, the CPU can encode token N+1."
3216 )?;
3217 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
3218 writeln!(code, "}}")?;
3219 writeln!(code)?;
3220
3221 Ok(())
3222}
3223
3224fn emit_layer_buffers_struct(
3225 code: &mut String,
3226 config: &ModelConfig,
3227) -> Result<(), MetalCodegenError> {
3228 writeln!(
3229 code,
3230 "/// Per-layer weight buffers for attention and FFN projections."
3231 )?;
3232 writeln!(code, "struct LayerBuffers {{")?;
3233 writeln!(code, " attn_norm: Buffer,")?;
3234 writeln!(
3235 code,
3236 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3237 )?;
3238 writeln!(code, " qkv_weight: Buffer,")?;
3239 if config.qkv_bias {
3240 writeln!(
3241 code,
3242 " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3243 )?;
3244 writeln!(code, " qkv_bias: Buffer,")?;
3245 }
3246 writeln!(code, " o_weight: Buffer,")?;
3247 writeln!(code, " ffn_norm: Buffer,")?;
3248 writeln!(
3249 code,
3250 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3251 )?;
3252 writeln!(code, " gate_up_weight: Buffer,")?;
3253 writeln!(code, " down_weight: Buffer,")?;
3254 writeln!(code, "}}")?;
3255 writeln!(code)?;
3256
3257 Ok(())
3258}
3259
3260fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3261 let hidden = config.hidden_size;
3262 let intermediate = config.intermediate_size;
3263 let _num_layers = config.num_layers;
3264 let num_heads = config.num_attention_heads;
3265 let num_kv_heads = config.num_kv_heads;
3266 let head_dim = config.head_dim;
3267 let vocab = config.vocab_size;
3268 let effective_seq_len = config.max_seq_len.min(4096);
3269 let is_q8 = config.dtype == DType::Q8_0;
3270 let is_q4 = config.dtype == DType::Q4_0;
3271 let kv_dim = num_kv_heads * head_dim;
3272
3273 writeln!(code, "impl MetalModel {{")?;
3274
3275 writeln!(
3277 code,
3278 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3279 )?;
3280 writeln!(code, " ///")?;
3281 writeln!(
3282 code,
3283 " /// `weights` is the raw weight blob produced by `forge export-weights`."
3284 )?;
3285 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
3286 writeln!(
3287 code,
3288 " let device = Device::system_default().expect(\"no Metal device found\");"
3289 )?;
3290 writeln!(code, " let queue = device.new_command_queue();")?;
3291 writeln!(code)?;
3292
3293 writeln!(
3295 code,
3296 " // Compile Metal shaders from embedded source"
3297 )?;
3298 writeln!(
3299 code,
3300 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
3301 )?;
3302 writeln!(
3303 code,
3304 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3305 )?;
3306 writeln!(
3307 code,
3308 " .expect(\"failed to compile Metal shaders\");"
3309 )?;
3310 writeln!(code)?;
3311
3312 writeln!(code, " // Create compute pipelines")?;
3314 for (var, fn_name) in [
3315 ("matmul_pipeline", "matmul_vec"),
3316 ("matmul_q8_pipeline", "matmul_vec_q8"),
3317 ("matmul_q4_pipeline", "matmul_vec_q4"),
3318 ("rms_norm_pipeline", "rms_norm"),
3319 ("rope_pipeline", "rope"),
3320 ("softmax_pipeline", "softmax"),
3321 ("silu_mul_pipeline", "silu_mul"),
3322 ("silu_mul_fused_pipeline", "silu_mul_fused"),
3323 ("add_pipeline", "elementwise_add"),
3324 ("attention_pipeline", "attention"),
3325 ("add_inplace_pipeline", "add_inplace"),
3326 ("copy_pipeline", "copy_buffer"),
3327 ("copy_offset_pipeline", "copy_offset"),
3328 ("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
3329 ("matmul_batch_pipeline", "matmul_vec_batch"),
3330 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3331 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3332 ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3333 ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3334 ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3335 ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3336 ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3337 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3338 ("rms_norm_batch_pipeline", "rms_norm_batch"),
3339 ("rope_batch_pipeline", "rope_batch"),
3340 ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3341 ("add_inplace_batch_pipeline", "add_inplace_batch"),
3342 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3343 ("attention_batch_pipeline", "attention_batch"),
3344 ("attention_flash_batch_pipeline", "attention_flash_batch"),
3345 (
3346 "attention_mma_flash_batch_pipeline",
3347 "attention_mma_flash_batch",
3348 ),
3349 ("copy_kv_batch_pipeline", "copy_kv_batch"),
3350 ("rope_qk_batch_pipeline", "rope_qk_batch"),
3351 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3352 ] {
3353 writeln!(
3354 code,
3355 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3356 )?;
3357 }
3358 if config.qkv_bias {
3359 writeln!(
3360 code,
3361 " let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3362 )?;
3363 }
3364 writeln!(code)?;
3365
3366 writeln!(
3368 code,
3369 " // Load weights into Metal shared-memory buffers"
3370 )?;
3371 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
3372 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3373 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
3374 writeln!(code)?;
3375 writeln!(
3376 code,
3377 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
3378 )?;
3379 writeln!(code)?;
3380 writeln!(
3381 code,
3382 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3383 )?;
3384 writeln!(
3385 code,
3386 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3387 )?;
3388 writeln!(code, " let byte_len = n * f32_size;")?;
3389 writeln!(code, " let cur = cursor.get();")?;
3390 writeln!(
3391 code,
3392 " let data = &weights[cur..cur + byte_len];"
3393 )?;
3394 writeln!(code, " cursor.set(cur + byte_len);")?;
3395 writeln!(code, " device.new_buffer_with_data(")?;
3396 writeln!(code, " data.as_ptr() as *const _,")?;
3397 writeln!(code, " byte_len as u64,")?;
3398 writeln!(
3399 code,
3400 " MTLResourceOptions::StorageModeShared,"
3401 )?;
3402 writeln!(code, " )")?;
3403 writeln!(code, " }};")?;
3404 writeln!(code)?;
3405
3406 if is_q8 {
3407 writeln!(
3412 code,
3413 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3414 )?;
3415 writeln!(
3416 code,
3417 " // as raw bytes into a Metal buffer (no dequantization)."
3418 )?;
3419 writeln!(
3420 code,
3421 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3422 )?;
3423 writeln!(
3424 code,
3425 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3426 )?;
3427 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3428 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3429 writeln!(code, " let total_raw = rows * row_bytes;")?;
3430 writeln!(code, " let cur = cursor.get();")?;
3431 writeln!(
3432 code,
3433 " let data = &weights[cur..cur + total_raw];"
3434 )?;
3435 writeln!(code, " cursor.set(cur + total_raw);")?;
3436 writeln!(code, " device.new_buffer_with_data(")?;
3437 writeln!(code, " data.as_ptr() as *const _,")?;
3438 writeln!(code, " total_raw as u64,")?;
3439 writeln!(
3440 code,
3441 " MTLResourceOptions::StorageModeShared,"
3442 )?;
3443 writeln!(code, " )")?;
3444 writeln!(code, " }};")?;
3445 writeln!(code)?;
3446 writeln!(
3447 code,
3448 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3449 )?;
3450 writeln!(
3451 code,
3452 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3453 )?;
3454 writeln!(
3455 code,
3456 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3457 )?;
3458 writeln!(
3459 code,
3460 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3461 )?;
3462 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3463 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3464 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3465 writeln!(code, " let cur = cursor.get();")?;
3466 writeln!(
3467 code,
3468 " let data = &weights[cur..cur + total_raw];"
3469 )?;
3470 writeln!(code, " cursor.set(cur + total_raw);")?;
3471 writeln!(code, " device.new_buffer_with_data(")?;
3472 writeln!(code, " data.as_ptr() as *const _,")?;
3473 writeln!(code, " total_raw as u64,")?;
3474 writeln!(
3475 code,
3476 " MTLResourceOptions::StorageModeShared,"
3477 )?;
3478 writeln!(code, " )")?;
3479 writeln!(code, " }};")?;
3480 writeln!(code)?;
3481 }
3482
3483 if is_q4 {
3484 writeln!(
3489 code,
3490 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3491 )?;
3492 writeln!(
3493 code,
3494 " // as raw bytes into a Metal buffer (no dequantization)."
3495 )?;
3496 writeln!(
3497 code,
3498 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3499 )?;
3500 writeln!(
3501 code,
3502 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3503 )?;
3504 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3505 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3506 writeln!(code, " let total_raw = rows * row_bytes;")?;
3507 writeln!(code, " let cur = cursor.get();")?;
3508 writeln!(
3509 code,
3510 " let data = &weights[cur..cur + total_raw];"
3511 )?;
3512 writeln!(code, " cursor.set(cur + total_raw);")?;
3513 writeln!(code, " device.new_buffer_with_data(")?;
3514 writeln!(code, " data.as_ptr() as *const _,")?;
3515 writeln!(code, " total_raw as u64,")?;
3516 writeln!(
3517 code,
3518 " MTLResourceOptions::StorageModeShared,"
3519 )?;
3520 writeln!(code, " )")?;
3521 writeln!(code, " }};")?;
3522 writeln!(code)?;
3523 writeln!(
3524 code,
3525 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3526 )?;
3527 writeln!(
3528 code,
3529 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3530 )?;
3531 writeln!(
3532 code,
3533 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3534 )?;
3535 writeln!(
3536 code,
3537 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3538 )?;
3539 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3540 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3541 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3542 writeln!(code, " let cur = cursor.get();")?;
3543 writeln!(
3544 code,
3545 " let data = &weights[cur..cur + total_raw];"
3546 )?;
3547 writeln!(code, " cursor.set(cur + total_raw);")?;
3548 writeln!(code, " device.new_buffer_with_data(")?;
3549 writeln!(code, " data.as_ptr() as *const _,")?;
3550 writeln!(code, " total_raw as u64,")?;
3551 writeln!(
3552 code,
3553 " MTLResourceOptions::StorageModeShared,"
3554 )?;
3555 writeln!(code, " )")?;
3556 writeln!(code, " }};")?;
3557 writeln!(code)?;
3558 }
3559
3560 writeln!(
3561 code,
3562 " let embed_buf = next_f32_buffer(&device, embed_elems);"
3563 )?;
3564 writeln!(code)?;
3565
3566 writeln!(
3568 code,
3569 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3570 )?;
3571 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
3572
3573 writeln!(
3575 code,
3576 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
3577 )?;
3578
3579 let qkv_rows = hidden + 2 * kv_dim;
3580 if is_q8 {
3581 writeln!(
3583 code,
3584 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3585 )?;
3586 if config.qkv_bias {
3587 writeln!(
3588 code,
3589 " // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3590 )?;
3591 writeln!(
3592 code,
3593 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3594 )?;
3595 }
3596 writeln!(
3597 code,
3598 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3599 )?;
3600 } else if is_q4 {
3601 writeln!(
3602 code,
3603 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3604 )?;
3605 if config.qkv_bias {
3606 writeln!(
3607 code,
3608 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3609 )?;
3610 }
3611 writeln!(
3612 code,
3613 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3614 )?;
3615 } else {
3616 writeln!(
3617 code,
3618 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3619 )?;
3620 if config.qkv_bias {
3621 writeln!(
3622 code,
3623 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3624 )?;
3625 }
3626 writeln!(
3627 code,
3628 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3629 )?;
3630 }
3631
3632 writeln!(
3634 code,
3635 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3636 )?;
3637
3638 let gate_up_rows = 2 * intermediate;
3639 if is_q8 {
3640 writeln!(
3642 code,
3643 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3644 )?;
3645 writeln!(
3646 code,
3647 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3648 )?;
3649 } else if is_q4 {
3650 writeln!(
3652 code,
3653 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3654 )?;
3655 writeln!(
3656 code,
3657 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3658 )?;
3659 } else {
3660 writeln!(
3662 code,
3663 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3664 )?;
3665 writeln!(
3666 code,
3667 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3668 )?;
3669 }
3670
3671 writeln!(code, " layers.push(LayerBuffers {{")?;
3672 writeln!(code, " attn_norm,")?;
3673 writeln!(code, " qkv_weight,")?;
3674 if config.qkv_bias {
3675 writeln!(code, " qkv_bias,")?;
3676 }
3677 writeln!(code, " o_weight,")?;
3678 writeln!(code, " ffn_norm,")?;
3679 writeln!(code, " gate_up_weight,")?;
3680 writeln!(code, " down_weight,")?;
3681 writeln!(code, " }});")?;
3682 writeln!(code, " }}")?;
3683 writeln!(code)?;
3684
3685 writeln!(
3687 code,
3688 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
3689 )?;
3690 writeln!(code)?;
3691
3692 if is_q8 {
3694 writeln!(
3695 code,
3696 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3697 )?;
3698 } else if is_q4 {
3699 writeln!(
3700 code,
3701 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3702 )?;
3703 } else {
3704 writeln!(
3705 code,
3706 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3707 )?;
3708 }
3709 writeln!(code)?;
3710
3711 let hidden_bytes = hidden * 4;
3713 let _kv_dim_bytes = kv_dim * 4;
3714 let intermediate_bytes = intermediate * 4;
3715 let vocab_bytes = vocab * 4;
3716 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
3719
3720 writeln!(
3721 code,
3722 " // Allocate working buffers (shared memory for zero-copy)"
3723 )?;
3724 writeln!(
3725 code,
3726 " let opts = MTLResourceOptions::StorageModeShared;"
3727 )?;
3728 writeln!(
3729 code,
3730 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3731 )?;
3732 writeln!(
3733 code,
3734 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3735 )?;
3736 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3737 writeln!(
3738 code,
3739 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3740 )?;
3741 writeln!(
3742 code,
3743 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3744 )?;
3745 writeln!(
3746 code,
3747 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3748 )?;
3749 writeln!(
3750 code,
3751 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3752 )?;
3753 writeln!(
3754 code,
3755 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3756 )?;
3757 let gate_up_buf_bytes = 2 * intermediate * 4;
3758 writeln!(
3759 code,
3760 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3761 )?;
3762 writeln!(
3763 code,
3764 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3765 )?;
3766 writeln!(
3767 code,
3768 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3769 )?;
3770 writeln!(
3771 code,
3772 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3773 )?;
3774 writeln!(
3775 code,
3776 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3777 )?;
3778 writeln!(
3779 code,
3780 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3781 )?;
3782 writeln!(code)?;
3783
3784 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3787 let batch_gate_up_bytes = 2 * intermediate * 4;
3788 let batch_intermediate_bytes = intermediate * 4;
3789 writeln!(
3790 code,
3791 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3792 )?;
3793 writeln!(
3794 code,
3795 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3796 )?;
3797 writeln!(
3798 code,
3799 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3800 )?;
3801 writeln!(
3802 code,
3803 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3804 )?;
3805 writeln!(
3806 code,
3807 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3808 )?;
3809 writeln!(
3810 code,
3811 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3812 )?;
3813 writeln!(
3814 code,
3815 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3816 )?;
3817 writeln!(
3818 code,
3819 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3820 )?;
3821 writeln!(
3822 code,
3823 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3824 )?;
3825 writeln!(
3826 code,
3827 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3828 )?;
3829 writeln!(
3830 code,
3831 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3832 )?;
3833 writeln!(code)?;
3834
3835 writeln!(code, " // KV cache buffers (per-layer)")?;
3837 writeln!(
3838 code,
3839 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3840 )?;
3841 writeln!(
3842 code,
3843 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3844 )?;
3845 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
3846 writeln!(
3847 code,
3848 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3849 )?;
3850 writeln!(
3851 code,
3852 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3853 )?;
3854 writeln!(code, " }}")?;
3855 writeln!(code)?;
3856
3857 writeln!(code, " Self {{")?;
3858 writeln!(code, " device,")?;
3859 writeln!(code, " queue,")?;
3860 writeln!(code, " matmul_pipeline,")?;
3861 writeln!(code, " matmul_q8_pipeline,")?;
3862 writeln!(code, " matmul_q4_pipeline,")?;
3863 writeln!(code, " rms_norm_pipeline,")?;
3864 writeln!(code, " rope_pipeline,")?;
3865 writeln!(code, " softmax_pipeline,")?;
3866 writeln!(code, " silu_mul_pipeline,")?;
3867 writeln!(code, " silu_mul_fused_pipeline,")?;
3868 writeln!(code, " add_pipeline,")?;
3869 writeln!(code, " attention_pipeline,")?;
3870 writeln!(code, " add_inplace_pipeline,")?;
3871 writeln!(code, " copy_pipeline,")?;
3872 writeln!(code, " copy_offset_pipeline,")?;
3873 writeln!(code, " copy_f32_to_f16_offset_pipeline,")?;
3874 writeln!(code, " matmul_batch_pipeline,")?;
3875 writeln!(code, " matmul_q8_batch_pipeline,")?;
3876 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
3877 writeln!(code, " matmul_q8_mma_pipeline,")?;
3878 writeln!(code, " matmul_q8_mma32_pipeline,")?;
3879 writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
3880 writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
3881 writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
3882 if config.qkv_bias {
3883 writeln!(code, " add_bias_batch_pipeline,")?;
3884 }
3885 writeln!(code, " matmul_q4_batch_pipeline,")?;
3886 writeln!(code, " rms_norm_batch_pipeline,")?;
3887 writeln!(code, " rope_batch_pipeline,")?;
3888 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
3889 writeln!(code, " add_inplace_batch_pipeline,")?;
3890 writeln!(code, " copy_embedding_batch_pipeline,")?;
3891 writeln!(code, " attention_batch_pipeline,")?;
3892 writeln!(code, " attention_flash_batch_pipeline,")?;
3893 writeln!(code, " attention_mma_flash_batch_pipeline,")?;
3894 writeln!(code, " copy_kv_batch_pipeline,")?;
3895 writeln!(code, " rope_qk_batch_pipeline,")?;
3896 writeln!(code, " copy_kv_both_batch_pipeline,")?;
3897 writeln!(code, " embed_buf,")?;
3898 writeln!(code, " layers,")?;
3899 writeln!(code, " norm_buf,")?;
3900 writeln!(code, " lm_head_buf,")?;
3901 writeln!(code, " hidden_buf,")?;
3902 writeln!(code, " residual_buf,")?;
3903 writeln!(code, " normed_buf,")?;
3904 writeln!(code, " qkv_buf,")?;
3905 writeln!(code, " attn_out_buf,")?;
3906 writeln!(code, " attn_proj_buf,")?;
3907 writeln!(code, " gate_up_buf,")?;
3908 writeln!(code, " ffn_hidden_buf,")?;
3909 writeln!(code, " ffn_out_buf,")?;
3910 writeln!(code, " add_tmp_buf,")?;
3911 writeln!(code, " logits_buf,")?;
3912 writeln!(code, " batch_hidden_buf,")?;
3913 writeln!(code, " batch_residual_buf,")?;
3914 writeln!(code, " batch_qkv_buf,")?;
3915 writeln!(code, " batch_attn_out_buf,")?;
3916 writeln!(code, " batch_attn_proj_buf,")?;
3917 writeln!(code, " batch_gate_up_buf,")?;
3918 writeln!(code, " batch_ffn_hidden_buf,")?;
3919 writeln!(code, " batch_ffn_out_buf,")?;
3920 writeln!(code, " batch_tokens_buf,")?;
3921 writeln!(code, " batch_positions_buf,")?;
3922 writeln!(code, " k_cache,")?;
3923 writeln!(code, " v_cache,")?;
3924 writeln!(code, " pos: 0,")?;
3925 writeln!(code, " prev_cmd: None,")?;
3926 writeln!(code, " }}")?;
3927 writeln!(code, " }}")?;
3928 writeln!(code)?;
3929
3930 writeln!(
3932 code,
3933 " /// Run the forward pass for a single token at the current position."
3934 )?;
3935 writeln!(code, " ///")?;
3936 writeln!(
3937 code,
3938 " /// Returns logits over the vocabulary as a `Vec<f32>`."
3939 )?;
3940 writeln!(code, " ///")?;
3941 writeln!(
3942 code,
3943 " /// All GPU operations are encoded into a single command buffer and"
3944 )?;
3945 writeln!(
3946 code,
3947 " /// committed once at the end, avoiding per-operation synchronization."
3948 )?;
3949 writeln!(
3950 code,
3951 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3952 )?;
3953 writeln!(
3954 code,
3955 " // Wait for any pending prefill command buffer"
3956 )?;
3957 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3958 writeln!(code, " prev.wait_until_completed();")?;
3959 writeln!(code, " }}")?;
3960 writeln!(code)?;
3961 writeln!(code, " let pos = self.pos;")?;
3962 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3963 writeln!(code)?;
3964
3965 let matmul_fn = if is_q8 {
3968 "dispatch_matmul_q8"
3969 } else if is_q4 {
3970 "dispatch_matmul_q4"
3971 } else {
3972 "dispatch_matmul"
3973 };
3974
3975 writeln!(
3976 code,
3977 " // Single compute encoder for the entire forward pass (no blit transitions)"
3978 )?;
3979 writeln!(code, " {{")?;
3980 writeln!(
3981 code,
3982 " let enc = cmd.new_compute_command_encoder();"
3983 )?;
3984 writeln!(code)?;
3985
3986 writeln!(
3988 code,
3989 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3990 )?;
3991 writeln!(
3992 code,
3993 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3994 )?;
3995 writeln!(
3996 code,
3997 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3998 )?;
3999 writeln!(
4000 code,
4001 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
4002 hidden * 4,
4003 )?;
4004 writeln!(code, " unsafe {{")?;
4005 writeln!(
4006 code,
4007 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4008 )?;
4009 writeln!(
4010 code,
4011 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4012 )?;
4013 writeln!(
4014 code,
4015 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4016 )?;
4017 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4018 writeln!(
4019 code,
4020 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4021 )?;
4022 writeln!(code, " hidden_ptr,")?;
4023 writeln!(code, " HIDDEN_SIZE,")?;
4024 writeln!(code, " );")?;
4025 writeln!(
4026 code,
4027 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4028 )?;
4029 writeln!(code, " }}")?;
4030 writeln!(code)?;
4031
4032 writeln!(code, " // 2. Transformer layers")?;
4034 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4035 writeln!(code)?;
4036 let q_byte_offset = 0usize;
4037 let k_byte_offset = hidden * 4;
4038 let v_byte_offset = (hidden + kv_dim) * 4;
4039
4040 writeln!(
4041 code,
4042 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
4043 )?;
4044 writeln!(
4045 code,
4046 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4047 )?;
4048 writeln!(
4049 code,
4050 " // Fused Q+K+V matmul: single dispatch for all three projections"
4051 )?;
4052 writeln!(
4053 code,
4054 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4055 )?;
4056 if config.qkv_bias {
4057 writeln!(
4058 code,
4059 " // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4060 )?;
4061 writeln!(
4062 code,
4063 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4064 )?;
4065 }
4066 writeln!(
4067 code,
4068 " // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
4069 )?;
4070 writeln!(
4071 code,
4072 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4073 )?;
4074 writeln!(
4075 code,
4076 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4077 )?;
4078 writeln!(code)?;
4079 writeln!(
4080 code,
4081 " // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
4082 )?;
4083 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
4084 writeln!(
4085 code,
4086 " self.dispatch_copy_from_offset_f16(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4087 )?;
4088 writeln!(
4089 code,
4090 " self.dispatch_copy_from_offset_f16(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4091 )?;
4092 writeln!(code)?;
4093 writeln!(
4094 code,
4095 " // Attention using Q from qkv_buf (offset 0)"
4096 )?;
4097 writeln!(
4098 code,
4099 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4100 )?;
4101 writeln!(
4102 code,
4103 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4104 )?;
4105 writeln!(
4106 code,
4107 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4108 )?;
4109 writeln!(
4110 code,
4111 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4112 )?;
4113 writeln!(
4114 code,
4115 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4116 )?;
4117 writeln!(
4118 code,
4119 " // Fused gate+up matmul: single dispatch for both projections"
4120 )?;
4121 writeln!(
4122 code,
4123 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4124 )?;
4125 writeln!(
4126 code,
4127 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4128 )?;
4129 writeln!(
4130 code,
4131 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4132 )?;
4133 writeln!(
4134 code,
4135 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4136 )?;
4137 writeln!(code, " }}")?;
4138 writeln!(code)?;
4139
4140 writeln!(code, " // 3. Final RMS norm + logits projection")?;
4142 writeln!(
4143 code,
4144 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4145 )?;
4146 writeln!(
4147 code,
4148 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4149 )?;
4150 writeln!(code)?;
4151 writeln!(code, " enc.end_encoding();")?;
4152 writeln!(code, " }}")?;
4153 writeln!(code)?;
4154
4155 writeln!(
4157 code,
4158 " // 5. Commit all GPU work and wait for completion"
4159 )?;
4160 writeln!(code, " cmd.commit();")?;
4161 writeln!(code, " cmd.wait_until_completed();")?;
4162 writeln!(code)?;
4163 writeln!(code, " // 6. Read back logits from GPU")?;
4164 writeln!(code, " let logits = unsafe {{")?;
4165 writeln!(
4166 code,
4167 " let ptr = self.logits_buf.contents() as *const f32;"
4168 )?;
4169 writeln!(
4170 code,
4171 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4172 )?;
4173 writeln!(code, " }};")?;
4174 writeln!(code)?;
4175 writeln!(code, " self.pos += 1;")?;
4176 writeln!(code, " logits")?;
4177 writeln!(code, " }}")?;
4178 writeln!(code)?;
4179
4180 writeln!(
4182 code,
4183 " /// Profiling forward pass that prints per-stage GPU timing."
4184 )?;
4185 writeln!(code, " ///")?;
4186 writeln!(
4187 code,
4188 " /// Each stage is committed and waited on separately so that GPU timestamps"
4189 )?;
4190 writeln!(
4191 code,
4192 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
4193 )?;
4194 writeln!(
4195 code,
4196 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
4197 )?;
4198 writeln!(
4199 code,
4200 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4201 )?;
4202 writeln!(code, " use std::time::Instant;")?;
4203 writeln!(code)?;
4204 writeln!(
4205 code,
4206 " // Wait for any pending prefill command buffer"
4207 )?;
4208 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4209 writeln!(code, " prev.wait_until_completed();")?;
4210 writeln!(code, " }}")?;
4211 writeln!(code)?;
4212 writeln!(code, " let pos = self.pos;")?;
4213 writeln!(code)?;
4214
4215 writeln!(
4217 code,
4218 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
4219 )?;
4220 writeln!(code, " let t_embed = Instant::now();")?;
4221 writeln!(code, " unsafe {{")?;
4222 writeln!(
4223 code,
4224 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4225 )?;
4226 writeln!(
4227 code,
4228 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4229 )?;
4230 writeln!(
4231 code,
4232 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4233 )?;
4234 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4235 writeln!(
4236 code,
4237 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4238 )?;
4239 writeln!(code, " hidden_ptr,")?;
4240 writeln!(code, " HIDDEN_SIZE,")?;
4241 writeln!(code, " );")?;
4242 writeln!(
4243 code,
4244 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4245 )?;
4246 writeln!(code, " }}")?;
4247 writeln!(code, " let d_embed = t_embed.elapsed();")?;
4248 writeln!(code)?;
4249
4250 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
4252 writeln!(code, " let t_layers = Instant::now();")?;
4253 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4254 writeln!(code, " {{")?;
4255 writeln!(
4256 code,
4257 " let enc = cmd.new_compute_command_encoder();"
4258 )?;
4259 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4260 writeln!(
4261 code,
4262 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4263 )?;
4264 writeln!(
4265 code,
4266 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4267 )?;
4268 if config.qkv_bias {
4269 writeln!(
4270 code,
4271 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4272 )?;
4273 }
4274 writeln!(
4275 code,
4276 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4277 )?;
4278 writeln!(
4279 code,
4280 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4281 )?;
4282 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
4283 writeln!(
4284 code,
4285 " self.dispatch_copy_from_offset_f16(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4286 )?;
4287 writeln!(
4288 code,
4289 " self.dispatch_copy_from_offset_f16(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4290 )?;
4291 writeln!(
4292 code,
4293 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4294 )?;
4295 writeln!(
4296 code,
4297 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4298 )?;
4299 writeln!(
4300 code,
4301 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4302 )?;
4303 writeln!(
4304 code,
4305 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4306 )?;
4307 writeln!(
4308 code,
4309 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4310 )?;
4311 writeln!(
4312 code,
4313 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4314 )?;
4315 writeln!(
4316 code,
4317 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4318 )?;
4319 writeln!(
4320 code,
4321 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4322 )?;
4323 writeln!(code, " }}")?;
4324 writeln!(code, " enc.end_encoding();")?;
4325 writeln!(code, " }}")?;
4326 writeln!(code, " cmd.commit();")?;
4327 writeln!(code, " cmd.wait_until_completed();")?;
4328 writeln!(code, " let d_layers = t_layers.elapsed();")?;
4329 writeln!(code)?;
4330
4331 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
4333 writeln!(code, " let t_logits = Instant::now();")?;
4334 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
4335 writeln!(code, " {{")?;
4336 writeln!(
4337 code,
4338 " let enc = cmd2.new_compute_command_encoder();"
4339 )?;
4340 writeln!(
4341 code,
4342 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4343 )?;
4344 writeln!(
4345 code,
4346 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4347 )?;
4348 writeln!(code, " enc.end_encoding();")?;
4349 writeln!(code, " }}")?;
4350 writeln!(code, " cmd2.commit();")?;
4351 writeln!(code, " cmd2.wait_until_completed();")?;
4352 writeln!(code, " let d_logits = t_logits.elapsed();")?;
4353 writeln!(code)?;
4354
4355 writeln!(
4357 code,
4358 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4359 )?;
4360 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
4361 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
4362 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
4363 writeln!(
4364 code,
4365 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4366 )?;
4367 writeln!(code)?;
4368
4369 writeln!(code, " let logits = unsafe {{")?;
4371 writeln!(
4372 code,
4373 " let ptr = self.logits_buf.contents() as *const f32;"
4374 )?;
4375 writeln!(
4376 code,
4377 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4378 )?;
4379 writeln!(code, " }};")?;
4380 writeln!(code)?;
4381 writeln!(code, " self.pos += 1;")?;
4382 writeln!(code, " logits")?;
4383 writeln!(code, " }}")?;
4384 writeln!(code)?;
4385
4386 writeln!(
4388 code,
4389 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
4390 )?;
4391 writeln!(code, " ///")?;
4392 writeln!(
4393 code,
4394 " /// Commits the command buffer without waiting, enabling double-buffered"
4395 )?;
4396 writeln!(
4397 code,
4398 " /// execution: GPU processes token N while CPU encodes token N+1."
4399 )?;
4400 writeln!(
4401 code,
4402 " pub fn forward_prefill(&mut self, token_id: u32) {{"
4403 )?;
4404 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
4405 writeln!(code, " }}")?;
4406 writeln!(code)?;
4407
4408 let batch_matmul_fn = if is_q8 {
4411 "dispatch_matmul_q8_batch"
4412 } else if is_q4 {
4413 "dispatch_matmul_q4_batch"
4414 } else {
4415 "dispatch_matmul_batch"
4416 };
4417
4418 writeln!(
4419 code,
4420 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4421 )?;
4422 writeln!(code, " ///")?;
4423 writeln!(
4424 code,
4425 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4426 )?;
4427 writeln!(
4428 code,
4429 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
4430 )?;
4431 writeln!(
4432 code,
4433 " /// This provides significant speedup during prompt prefill."
4434 )?;
4435 writeln!(
4436 code,
4437 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4438 )?;
4439 writeln!(code, " if tokens.is_empty() {{ return; }}")?;
4440 writeln!(
4441 code,
4442 " // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4443 )?;
4444 writeln!(
4445 code,
4446 " // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4447 )?;
4448 writeln!(
4449 code,
4450 " // longer than that must be processed iteratively. The KV cache"
4451 )?;
4452 writeln!(code, " // carries state across chunks via self.pos.")?;
4453 writeln!(
4454 code,
4455 " for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4456 )?;
4457 writeln!(code, " let m = chunk.len();")?;
4458 writeln!(code, " if m == 0 {{ continue; }}")?;
4459 writeln!(code, " let start_pos = self.pos;")?;
4460 writeln!(code)?;
4461 writeln!(code, " // Wait for any pending command buffer")?;
4462 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4463 writeln!(code, " prev.wait_until_completed();")?;
4464 writeln!(code, " }}")?;
4465 writeln!(code)?;
4466
4467 writeln!(
4469 code,
4470 " // Upload token IDs and positions to GPU buffers"
4471 )?;
4472 writeln!(code, " unsafe {{")?;
4473 writeln!(
4474 code,
4475 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4476 )?;
4477 writeln!(
4478 code,
4479 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4480 )?;
4481 writeln!(code, " for i in 0..m {{")?;
4482 writeln!(code, " *tok_ptr.add(i) = chunk[i];")?;
4483 writeln!(
4484 code,
4485 " *pos_ptr.add(i) = (start_pos + i) as u32;"
4486 )?;
4487 writeln!(code, " }}")?;
4488 writeln!(code, " }}")?;
4489 writeln!(code)?;
4490
4491 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4492 writeln!(code, " {{")?;
4493 writeln!(
4494 code,
4495 " let enc = cmd.new_compute_command_encoder();"
4496 )?;
4497 writeln!(code)?;
4498
4499 writeln!(
4501 code,
4502 " // 1. Batch embedding lookup: copy all token embeddings at once"
4503 )?;
4504 writeln!(
4505 code,
4506 " self.dispatch_copy_embedding_batch(&enc, m);"
4507 )?;
4508 writeln!(
4510 code,
4511 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4512 )?;
4513 writeln!(code)?;
4514
4515 writeln!(code, " // 2. Transformer layers")?;
4517 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4518 writeln!(code)?;
4519
4520 writeln!(
4522 code,
4523 " // Batch RMS norm: batch_residual -> batch_hidden"
4524 )?;
4525 writeln!(
4526 code,
4527 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4528 )?;
4529
4530 writeln!(
4532 code,
4533 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4534 )?;
4535 writeln!(
4536 code,
4537 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4538 )?;
4539 if config.qkv_bias {
4540 writeln!(
4541 code,
4542 " // Qwen2: broadcast-add QKV bias across all M tokens."
4543 )?;
4544 writeln!(
4545 code,
4546 " self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4547 )?;
4548 }
4549 writeln!(code)?;
4550
4551 let k_float_offset = hidden;
4553 writeln!(
4554 code,
4555 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4556 )?;
4557 writeln!(
4558 code,
4559 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4560 )?;
4561 writeln!(code)?;
4562
4563 let v_float_offset = hidden + kv_dim;
4565 writeln!(
4566 code,
4567 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4568 )?;
4569 writeln!(
4570 code,
4571 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4572 )?;
4573 writeln!(code)?;
4574
4575 writeln!(
4577 code,
4578 " // Batched causal attention: one dispatch for all M tokens"
4579 )?;
4580 writeln!(
4581 code,
4582 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4583 )?;
4584 writeln!(code)?;
4585
4586 writeln!(code, " // Batched O projection")?;
4588 writeln!(
4589 code,
4590 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4591 )?;
4592 writeln!(code)?;
4593
4594 writeln!(
4596 code,
4597 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4598 )?;
4599 writeln!(code)?;
4600
4601 writeln!(
4603 code,
4604 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4605 )?;
4606 writeln!(
4607 code,
4608 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4609 )?;
4610 writeln!(
4611 code,
4612 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4613 )?;
4614 writeln!(
4615 code,
4616 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4617 )?;
4618 writeln!(
4619 code,
4620 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4621 )?;
4622 writeln!(
4623 code,
4624 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4625 )?;
4626 writeln!(code, " }}")?;
4627 writeln!(code)?;
4628
4629 writeln!(
4631 code,
4632 " // Copy last token's residual to single-token buffer for subsequent forward()"
4633 )?;
4634 writeln!(
4635 code,
4636 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4637 )?;
4638 writeln!(code)?;
4639 writeln!(code, " enc.end_encoding();")?;
4640 writeln!(code, " }}")?;
4641 writeln!(code)?;
4642
4643 writeln!(code, " cmd.commit();")?;
4644 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
4645 writeln!(code, " self.pos += m;")?;
4646 writeln!(code, " }} // end for chunk")?;
4647 writeln!(code, " }}")?;
4648 writeln!(code)?;
4649
4650 writeln!(
4652 code,
4653 " /// Reset the model state for a new inference request."
4654 )?;
4655 writeln!(code, " pub fn reset(&mut self) {{")?;
4656 writeln!(code, " self.pos = 0;")?;
4657 writeln!(code, " self.prev_cmd = None;")?;
4658 writeln!(code, " }}")?;
4659 writeln!(code)?;
4660
4661 writeln!(
4663 code,
4664 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
4665 )?;
4666 writeln!(
4667 code,
4668 " // These methods set pipeline state + buffers + dispatch on an existing"
4669 )?;
4670 writeln!(
4671 code,
4672 " // encoder, avoiding per-operation encoder creation overhead."
4673 )?;
4674 writeln!(code)?;
4675
4676 writeln!(
4678 code,
4679 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4680 )?;
4681 writeln!(
4682 code,
4683 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4684 )?;
4685 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4686 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4687 writeln!(
4688 code,
4689 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4690 )?;
4691 writeln!(
4692 code,
4693 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
4694 )?;
4695 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4696 writeln!(
4697 code,
4698 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4699 )?;
4700 writeln!(
4701 code,
4702 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4703 )?;
4704 writeln!(
4705 code,
4706 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4707 )?;
4708 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4709 writeln!(
4710 code,
4711 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
4712 )?;
4713 writeln!(
4714 code,
4715 " enc.dispatch_thread_groups(grid_size, tg_size);"
4716 )?;
4717 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4718 writeln!(code, " }}")?;
4719 writeln!(code)?;
4720
4721 writeln!(
4723 code,
4724 " /// Dispatch matrix-vector multiply: weight * input -> output."
4725 )?;
4726 writeln!(
4727 code,
4728 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4729 )?;
4730 writeln!(code, " let r: u32 = rows as u32;")?;
4731 writeln!(code, " let c: u32 = cols as u32;")?;
4732 writeln!(
4733 code,
4734 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4735 )?;
4736 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4737 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4738 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4739 writeln!(
4740 code,
4741 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4742 )?;
4743 writeln!(
4744 code,
4745 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4746 )?;
4747 writeln!(
4748 code,
4749 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4750 )?;
4751 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4752 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
4753 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4754 writeln!(
4755 code,
4756 " enc.dispatch_thread_groups(grid_size, tg_size);"
4757 )?;
4758 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4759 writeln!(code, " }}")?;
4760 writeln!(code)?;
4761
4762 writeln!(
4764 code,
4765 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4766 )?;
4767 writeln!(
4768 code,
4769 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4770 )?;
4771 writeln!(
4772 code,
4773 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4774 )?;
4775 writeln!(code, " let r: u32 = rows as u32;")?;
4776 writeln!(code, " let c: u32 = cols as u32;")?;
4777 writeln!(
4778 code,
4779 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4780 )?;
4781 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4782 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4783 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4784 writeln!(
4785 code,
4786 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4787 )?;
4788 writeln!(
4789 code,
4790 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4791 )?;
4792 writeln!(
4793 code,
4794 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4795 )?;
4796 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4797 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4798 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4799 writeln!(
4800 code,
4801 " enc.dispatch_thread_groups(grid_size, tg_size);"
4802 )?;
4803 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4804 writeln!(code, " }}")?;
4805 writeln!(code)?;
4806
4807 writeln!(
4809 code,
4810 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4811 )?;
4812 writeln!(
4813 code,
4814 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4815 )?;
4816 writeln!(
4817 code,
4818 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4819 )?;
4820 writeln!(code, " let r: u32 = rows as u32;")?;
4821 writeln!(code, " let c: u32 = cols as u32;")?;
4822 writeln!(
4823 code,
4824 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4825 )?;
4826 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4827 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4828 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4829 writeln!(
4830 code,
4831 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4832 )?;
4833 writeln!(
4834 code,
4835 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4836 )?;
4837 writeln!(
4838 code,
4839 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4840 )?;
4841 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4842 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4843 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4844 writeln!(
4845 code,
4846 " enc.dispatch_thread_groups(grid_size, tg_size);"
4847 )?;
4848 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4849 writeln!(code, " }}")?;
4850 writeln!(code)?;
4851
4852 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
4854 writeln!(
4855 code,
4856 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4857 )?;
4858 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4859 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4860 writeln!(code, " let p: u32 = pos as u32;")?;
4861 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4862 writeln!(
4863 code,
4864 " let total_pairs = num_heads * (head_dim / 2);"
4865 )?;
4866 writeln!(
4867 code,
4868 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4869 )?;
4870 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
4871 writeln!(
4872 code,
4873 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4874 )?;
4875 writeln!(
4876 code,
4877 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4878 )?;
4879 writeln!(
4880 code,
4881 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4882 )?;
4883 writeln!(
4884 code,
4885 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4886 )?;
4887 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4888 writeln!(
4889 code,
4890 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4891 )?;
4892 writeln!(
4893 code,
4894 " enc.dispatch_thread_groups(grid_size, tg_size);"
4895 )?;
4896 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4897 writeln!(code, " }}")?;
4898 writeln!(code)?;
4899
4900 writeln!(
4902 code,
4903 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4904 )?;
4905 writeln!(
4906 code,
4907 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4908 )?;
4909 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4910 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4911 writeln!(code, " let p: u32 = pos as u32;")?;
4912 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4913 writeln!(
4914 code,
4915 " let total_pairs = num_heads * (head_dim / 2);"
4916 )?;
4917 writeln!(
4918 code,
4919 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4920 )?;
4921 writeln!(
4922 code,
4923 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
4924 )?;
4925 writeln!(
4926 code,
4927 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4928 )?;
4929 writeln!(
4930 code,
4931 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4932 )?;
4933 writeln!(
4934 code,
4935 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4936 )?;
4937 writeln!(
4938 code,
4939 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4940 )?;
4941 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4942 writeln!(
4943 code,
4944 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4945 )?;
4946 writeln!(
4947 code,
4948 " enc.dispatch_thread_groups(grid_size, tg_size);"
4949 )?;
4950 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4951 writeln!(code, " }}")?;
4952 writeln!(code)?;
4953
4954 writeln!(
4956 code,
4957 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4958 )?;
4959 writeln!(
4960 code,
4961 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4962 )?;
4963 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4964 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4965 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4966 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4967 writeln!(
4968 code,
4969 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
4970 )?;
4971 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
4972 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4973 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4974 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4975 writeln!(
4976 code,
4977 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4978 )?;
4979 writeln!(
4980 code,
4981 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4982 )?;
4983 writeln!(
4984 code,
4985 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4986 )?;
4987 writeln!(
4988 code,
4989 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4990 )?;
4991 writeln!(
4992 code,
4993 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4994 )?;
4995 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4996 writeln!(
4997 code,
4998 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4999 )?;
5000 writeln!(
5001 code,
5002 " enc.dispatch_thread_groups(grid_size, tg_size);"
5003 )?;
5004 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5005 writeln!(code, " }}")?;
5006 writeln!(code)?;
5007
5008 writeln!(
5010 code,
5011 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
5012 )?;
5013 writeln!(
5014 code,
5015 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
5016 )?;
5017 writeln!(code, " let sl: u32 = seq_len as u32;")?;
5018 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
5019 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5020 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5021 writeln!(
5022 code,
5023 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
5024 )?;
5025 writeln!(
5026 code,
5027 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5028 )?;
5029 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5030 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5031 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5032 writeln!(
5033 code,
5034 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5035 )?;
5036 writeln!(
5037 code,
5038 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5039 )?;
5040 writeln!(
5041 code,
5042 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5043 )?;
5044 writeln!(
5045 code,
5046 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5047 )?;
5048 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5049 writeln!(
5050 code,
5051 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5052 )?;
5053 writeln!(
5054 code,
5055 " enc.dispatch_thread_groups(grid_size, tg_size);"
5056 )?;
5057 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5058 writeln!(code, " }}")?;
5059 writeln!(code)?;
5060
5061 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
5063 writeln!(
5064 code,
5065 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5066 )?;
5067 writeln!(code, " let count: u32 = n as u32;")?;
5068 writeln!(
5069 code,
5070 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5071 )?;
5072 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
5073 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
5074 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5075 writeln!(
5076 code,
5077 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5078 )?;
5079 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5080 writeln!(
5081 code,
5082 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5083 )?;
5084 writeln!(
5085 code,
5086 " enc.dispatch_thread_groups(grid_size, tg_size);"
5087 )?;
5088 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5089 writeln!(code, " }}")?;
5090 writeln!(code)?;
5091
5092 writeln!(
5094 code,
5095 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5096 )?;
5097 writeln!(
5098 code,
5099 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
5100 )?;
5101 writeln!(
5102 code,
5103 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5104 )?;
5105 writeln!(code, " let count: u32 = n as u32;")?;
5106 writeln!(
5107 code,
5108 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5109 )?;
5110 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5111 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5112 writeln!(
5113 code,
5114 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5115 )?;
5116 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5117 writeln!(
5118 code,
5119 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5120 )?;
5121 writeln!(
5122 code,
5123 " enc.dispatch_thread_groups(grid_size, tg_size);"
5124 )?;
5125 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5126 writeln!(code, " }}")?;
5127 writeln!(code)?;
5128
5129 writeln!(
5131 code,
5132 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5133 )?;
5134 writeln!(
5135 code,
5136 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5137 )?;
5138 writeln!(code, " let n: u32 = count as u32;")?;
5139 writeln!(
5140 code,
5141 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5142 )?;
5143 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5144 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5145 writeln!(
5146 code,
5147 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5148 )?;
5149 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5150 writeln!(
5151 code,
5152 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5153 )?;
5154 writeln!(
5155 code,
5156 " enc.dispatch_thread_groups(grid_size, tg_size);"
5157 )?;
5158 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5159 writeln!(code, " }}")?;
5160 writeln!(code)?;
5161
5162 writeln!(
5164 code,
5165 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5166 )?;
5167 writeln!(
5168 code,
5169 " /// Used for embedding table lookup (copy a specific row)."
5170 )?;
5171 writeln!(
5172 code,
5173 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5174 )?;
5175 writeln!(code, " let off: u32 = src_offset as u32;")?;
5176 writeln!(code, " let n: u32 = count as u32;")?;
5177 writeln!(
5178 code,
5179 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5180 )?;
5181 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5182 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5183 writeln!(
5184 code,
5185 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5186 )?;
5187 writeln!(
5188 code,
5189 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5190 )?;
5191 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5192 writeln!(
5193 code,
5194 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5195 )?;
5196 writeln!(
5197 code,
5198 " enc.dispatch_thread_groups(grid_size, tg_size);"
5199 )?;
5200 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5201 writeln!(code, " }}")?;
5202 writeln!(code)?;
5203
5204 writeln!(
5206 code,
5207 " /// Dispatch copy from source at byte offset to destination at float offset."
5208 )?;
5209 writeln!(
5210 code,
5211 " /// Used for KV cache updates from fused QKV buffer."
5212 )?;
5213 writeln!(
5214 code,
5215 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5216 )?;
5217 writeln!(code, " let n: u32 = count as u32;")?;
5218 writeln!(
5219 code,
5220 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5221 )?;
5222 writeln!(
5223 code,
5224 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5225 )?;
5226 writeln!(
5227 code,
5228 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5229 )?;
5230 writeln!(
5231 code,
5232 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5233 )?;
5234 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5235 writeln!(
5236 code,
5237 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5238 )?;
5239 writeln!(
5240 code,
5241 " enc.dispatch_thread_groups(grid_size, tg_size);"
5242 )?;
5243 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5244 writeln!(code, " }}")?;
5245 writeln!(code)?;
5246
5247 writeln!(
5249 code,
5250 " /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
5251 )?;
5252 writeln!(
5253 code,
5254 " /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
5255 )?;
5256 writeln!(
5257 code,
5258 " fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
5259 )?;
5260 writeln!(code, " let n: u32 = count as u32;")?;
5261 writeln!(
5262 code,
5263 " enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
5264 )?;
5265 writeln!(
5266 code,
5267 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5268 )?;
5269 writeln!(
5270 code,
5271 " enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
5272 )?;
5273 writeln!(
5274 code,
5275 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5276 )?;
5277 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5278 writeln!(
5279 code,
5280 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5281 )?;
5282 writeln!(
5283 code,
5284 " enc.dispatch_thread_groups(grid_size, tg_size);"
5285 )?;
5286 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5287 writeln!(code, " }}")?;
5288 writeln!(code)?;
5289
5290 writeln!(
5292 code,
5293 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5294 )?;
5295 writeln!(
5296 code,
5297 " /// Used for KV cache updates (write to a specific position in the cache)."
5298 )?;
5299 writeln!(
5300 code,
5301 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5302 )?;
5303 writeln!(code, " let n: u32 = count as u32;")?;
5304 writeln!(
5305 code,
5306 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5307 )?;
5308 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5309 writeln!(
5310 code,
5311 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5312 )?;
5313 writeln!(
5314 code,
5315 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5316 )?;
5317 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5318 writeln!(
5319 code,
5320 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5321 )?;
5322 writeln!(
5323 code,
5324 " enc.dispatch_thread_groups(grid_size, tg_size);"
5325 )?;
5326 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5327 writeln!(code, " }}")?;
5328 writeln!(code)?;
5329
5330 writeln!(
5332 code,
5333 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5334 )?;
5335 writeln!(
5336 code,
5337 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5338 )?;
5339 writeln!(code, " let count: u32 = n as u32;")?;
5340 writeln!(
5341 code,
5342 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5343 )?;
5344 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5345 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5346 writeln!(
5347 code,
5348 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5349 )?;
5350 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5351 writeln!(
5352 code,
5353 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5354 )?;
5355 writeln!(
5356 code,
5357 " enc.dispatch_thread_groups(grid_size, tg_size);"
5358 )?;
5359 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5360 writeln!(code, " }}")?;
5361 writeln!(code)?;
5362
5363 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
5365 writeln!(code)?;
5366
5367 writeln!(
5369 code,
5370 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
5371 )?;
5372 writeln!(
5373 code,
5374 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5375 )?;
5376 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
5377 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5378 writeln!(
5379 code,
5380 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5381 )?;
5382 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5383 writeln!(
5384 code,
5385 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5386 )?;
5387 writeln!(
5388 code,
5389 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5390 )?;
5391 writeln!(
5392 code,
5393 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5394 )?;
5395 writeln!(
5396 code,
5397 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5398 )?;
5399 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
5400 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5401 writeln!(
5402 code,
5403 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5404 )?;
5405 writeln!(
5406 code,
5407 " enc.dispatch_thread_groups(grid_size, tg_size);"
5408 )?;
5409 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5410 writeln!(code, " }}")?;
5411 writeln!(code)?;
5412
5413 writeln!(
5415 code,
5416 " /// Dispatch batched RMS norm: normalizes M vectors at once."
5417 )?;
5418 writeln!(
5419 code,
5420 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5421 )?;
5422 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
5423 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
5424 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5425 writeln!(
5426 code,
5427 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5428 )?;
5429 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
5430 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
5431 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5432 writeln!(
5433 code,
5434 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5435 )?;
5436 writeln!(
5437 code,
5438 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5439 )?;
5440 writeln!(
5441 code,
5442 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5443 )?;
5444 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5445 writeln!(
5446 code,
5447 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5448 )?;
5449 writeln!(
5450 code,
5451 " enc.dispatch_thread_groups(grid_size, tg_size);"
5452 )?;
5453 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5454 writeln!(code, " }}")?;
5455 writeln!(code)?;
5456
5457 writeln!(
5459 code,
5460 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5461 )?;
5462 writeln!(
5463 code,
5464 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5465 )?;
5466 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5467 writeln!(code, " let r: u32 = rows as u32;")?;
5468 writeln!(code, " let c: u32 = cols as u32;")?;
5469 writeln!(
5470 code,
5471 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5472 )?;
5473 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5474 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5475 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5476 writeln!(
5477 code,
5478 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5479 )?;
5480 writeln!(
5481 code,
5482 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5483 )?;
5484 writeln!(
5485 code,
5486 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5487 )?;
5488 writeln!(
5489 code,
5490 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
5491 )?;
5492 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5493 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5494 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5495 writeln!(
5496 code,
5497 " enc.dispatch_thread_groups(grid_size, tg_size);"
5498 )?;
5499 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5500 writeln!(code, " }}")?;
5501 writeln!(code)?;
5502
5503 writeln!(
5505 code,
5506 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5507 )?;
5508 writeln!(code, " ///")?;
5509 writeln!(
5510 code,
5511 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5512 )?;
5513 writeln!(
5514 code,
5515 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5516 )?;
5517 writeln!(
5518 code,
5519 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5520 )?;
5521 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5522 writeln!(code, " let r: u32 = rows as u32;")?;
5523 writeln!(code, " let c: u32 = cols as u32;")?;
5524 writeln!(
5525 code,
5526 " // Tile sizes must match the Metal shader constants."
5527 )?;
5528 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
5529 writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
5530 writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
5531 writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
5532 writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
5533 writeln!(
5534 code,
5535 " // Hardware matrix-multiply paths (simdgroup_matrix)."
5536 )?;
5537 writeln!(
5538 code,
5539 " // Prefer the large 32×32 tile when the problem supports it — halves"
5540 )?;
5541 writeln!(
5542 code,
5543 " // dispatch count and reuses each weight load across 32 tokens."
5544 )?;
5545 writeln!(
5546 code,
5547 " if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5548 )?;
5549 writeln!(
5550 code,
5551 " // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5552 )?;
5553 writeln!(
5554 code,
5555 " // It wins at moderate prefill lengths where the GPU is wave-starved,"
5556 )?;
5557 writeln!(
5558 code,
5559 " // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5560 )?;
5561 writeln!(
5562 code,
5563 " // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
5564 )?;
5565 writeln!(
5566 code,
5567 " // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5568 )?;
5569 writeln!(
5570 code,
5571 " // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5572 )?;
5573 writeln!(
5574 code,
5575 " // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5576 )?;
5577 writeln!(
5578 code,
5579 " // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5580 )?;
5581 writeln!(code, " let use_h4 = cols >= 2048;")?;
5582 writeln!(code, " let pipe = if use_h4 {{")?;
5583 writeln!(code, " if num_tokens >= 256 {{")?;
5584 writeln!(
5585 code,
5586 " &self.matmul_q8_mma32_hh4_pipeline"
5587 )?;
5588 writeln!(code, " }} else {{")?;
5589 writeln!(
5590 code,
5591 " &self.matmul_q8_mma32_h4_pipeline"
5592 )?;
5593 writeln!(code, " }}")?;
5594 writeln!(code, " }} else {{")?;
5595 writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
5596 writeln!(code, " }};")?;
5597 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
5598 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5599 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5600 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5601 writeln!(
5602 code,
5603 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5604 )?;
5605 writeln!(
5606 code,
5607 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5608 )?;
5609 writeln!(
5610 code,
5611 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5612 )?;
5613 writeln!(code, " let row_tgs = rows / MMA32_ROW_TILE;")?;
5614 writeln!(
5615 code,
5616 " let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5617 )?;
5618 writeln!(
5619 code,
5620 " let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5621 )?;
5622 writeln!(
5623 code,
5624 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5625 )?;
5626 writeln!(
5627 code,
5628 " enc.dispatch_thread_groups(grid_size, tg_size);"
5629 )?;
5630 writeln!(
5631 code,
5632 " }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5633 )?;
5634 writeln!(
5635 code,
5636 " enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5637 )?;
5638 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5639 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5640 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5641 writeln!(
5642 code,
5643 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5644 )?;
5645 writeln!(
5646 code,
5647 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5648 )?;
5649 writeln!(
5650 code,
5651 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5652 )?;
5653 writeln!(code, " let row_tgs = rows / MMA_ROW_TILE;")?;
5654 writeln!(
5655 code,
5656 " let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5657 )?;
5658 writeln!(
5659 code,
5660 " let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
5661 )?;
5662 writeln!(
5663 code,
5664 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5665 )?;
5666 writeln!(
5667 code,
5668 " enc.dispatch_thread_groups(grid_size, tg_size);"
5669 )?;
5670 writeln!(code, " }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5671 writeln!(
5672 code,
5673 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5674 )?;
5675 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5676 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5677 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5678 writeln!(
5679 code,
5680 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5681 )?;
5682 writeln!(
5683 code,
5684 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5685 )?;
5686 writeln!(
5687 code,
5688 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5689 )?;
5690 writeln!(
5691 code,
5692 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5693 )?;
5694 writeln!(
5695 code,
5696 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5697 )?;
5698 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5699 writeln!(
5700 code,
5701 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5702 )?;
5703 writeln!(
5704 code,
5705 " enc.dispatch_thread_groups(grid_size, tg_size);"
5706 )?;
5707 writeln!(code, " }} else {{")?;
5708 writeln!(
5709 code,
5710 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5711 )?;
5712 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5713 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5714 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5715 writeln!(
5716 code,
5717 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5718 )?;
5719 writeln!(
5720 code,
5721 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5722 )?;
5723 writeln!(
5724 code,
5725 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5726 )?;
5727 writeln!(
5728 code,
5729 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5730 )?;
5731 writeln!(
5732 code,
5733 " let num_tg = (row_tgs * num_tokens) as u64;"
5734 )?;
5735 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5736 writeln!(
5737 code,
5738 " let grid_size = MTLSize::new(num_tg, 1, 1);"
5739 )?;
5740 writeln!(
5741 code,
5742 " enc.dispatch_thread_groups(grid_size, tg_size);"
5743 )?;
5744 writeln!(code, " }}")?;
5745 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5746 writeln!(code, " }}")?;
5747 writeln!(code)?;
5748
5749 writeln!(
5751 code,
5752 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5753 )?;
5754 writeln!(
5755 code,
5756 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5757 )?;
5758 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5759 writeln!(code, " let r: u32 = rows as u32;")?;
5760 writeln!(code, " let c: u32 = cols as u32;")?;
5761 writeln!(
5762 code,
5763 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5764 )?;
5765 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5766 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5767 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5768 writeln!(
5769 code,
5770 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5771 )?;
5772 writeln!(
5773 code,
5774 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5775 )?;
5776 writeln!(
5777 code,
5778 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5779 )?;
5780 writeln!(
5781 code,
5782 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
5783 )?;
5784 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5785 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5786 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5787 writeln!(
5788 code,
5789 " enc.dispatch_thread_groups(grid_size, tg_size);"
5790 )?;
5791 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5792 writeln!(code, " }}")?;
5793 writeln!(code)?;
5794
5795 if config.qkv_bias {
5797 writeln!(
5798 code,
5799 " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5800 )?;
5801 writeln!(
5802 code,
5803 " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5804 )?;
5805 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5806 writeln!(code, " let r: u32 = rows as u32;")?;
5807 writeln!(
5808 code,
5809 " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5810 )?;
5811 writeln!(code, " enc.set_buffer(0, Some(out), 0);")?;
5812 writeln!(code, " enc.set_buffer(1, Some(bias), 0);")?;
5813 writeln!(
5814 code,
5815 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5816 )?;
5817 writeln!(
5818 code,
5819 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5820 )?;
5821 writeln!(code, " let total = (num_tokens * rows) as u64;")?;
5822 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5823 writeln!(
5824 code,
5825 " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5826 )?;
5827 writeln!(
5828 code,
5829 " enc.dispatch_thread_groups(grid_size, tg_size);"
5830 )?;
5831 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5832 writeln!(code, " }}")?;
5833 writeln!(code)?;
5834 }
5835
5836 writeln!(
5838 code,
5839 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5840 )?;
5841 writeln!(
5842 code,
5843 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5844 )?;
5845 writeln!(
5846 code,
5847 " /// `row_stride` is the number of floats per token row in the batch buffer."
5848 )?;
5849 writeln!(
5850 code,
5851 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5852 )?;
5853 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5854 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5855 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5856 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5857 writeln!(
5858 code,
5859 " let pairs_per_token = num_heads * (head_dim / 2);"
5860 )?;
5861 writeln!(
5862 code,
5863 " let total_pairs = num_tokens * pairs_per_token;"
5864 )?;
5865 writeln!(
5878 code,
5879 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5880 )?;
5881 writeln!(code, " for t in 0..num_tokens {{")?;
5882 writeln!(
5883 code,
5884 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5885 )?;
5886 writeln!(
5887 code,
5888 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5889 )?;
5890 writeln!(
5891 code,
5892 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5893 )?;
5894 writeln!(
5895 code,
5896 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
5897 )?;
5898 writeln!(
5899 code,
5900 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5901 )?;
5902 writeln!(
5903 code,
5904 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5905 )?;
5906 writeln!(
5907 code,
5908 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5909 )?;
5910 writeln!(
5911 code,
5912 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5913 )?;
5914 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5915 writeln!(
5916 code,
5917 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5918 )?;
5919 writeln!(
5920 code,
5921 " enc.dispatch_thread_groups(grid_size, tg_size);"
5922 )?;
5923 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5924 writeln!(code, " }}")?;
5925 writeln!(code, " }}")?;
5926 writeln!(code)?;
5927
5928 writeln!(
5930 code,
5931 " /// Dispatch batched fused SiLU-multiply for M tokens."
5932 )?;
5933 writeln!(
5934 code,
5935 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5936 )?;
5937 writeln!(code, " let count: u32 = n as u32;")?;
5938 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5939 writeln!(
5940 code,
5941 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5942 )?;
5943 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5944 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5945 writeln!(
5946 code,
5947 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5948 )?;
5949 writeln!(
5950 code,
5951 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5952 )?;
5953 writeln!(code, " let total = num_tokens * n;")?;
5954 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5955 writeln!(
5956 code,
5957 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5958 )?;
5959 writeln!(
5960 code,
5961 " enc.dispatch_thread_groups(grid_size, tg_size);"
5962 )?;
5963 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5964 writeln!(code, " }}")?;
5965 writeln!(code)?;
5966
5967 writeln!(
5969 code,
5970 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5971 )?;
5972 writeln!(
5973 code,
5974 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5975 )?;
5976 writeln!(code, " let count: u32 = total_n as u32;")?;
5977 writeln!(
5978 code,
5979 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5980 )?;
5981 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5982 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5983 writeln!(
5984 code,
5985 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5986 )?;
5987 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5988 writeln!(
5989 code,
5990 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5991 )?;
5992 writeln!(
5993 code,
5994 " enc.dispatch_thread_groups(grid_size, tg_size);"
5995 )?;
5996 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5997 writeln!(code, " }}")?;
5998 writeln!(code)?;
5999
6000 writeln!(
6002 code,
6003 " /// Copy src to dst using compute copy kernel (for batch residual init)."
6004 )?;
6005 writeln!(
6006 code,
6007 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
6008 )?;
6009 writeln!(code, " let n: u32 = count as u32;")?;
6010 writeln!(
6011 code,
6012 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6013 )?;
6014 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6015 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6016 writeln!(
6017 code,
6018 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6019 )?;
6020 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6021 writeln!(
6022 code,
6023 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6024 )?;
6025 writeln!(
6026 code,
6027 " enc.dispatch_thread_groups(grid_size, tg_size);"
6028 )?;
6029 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6030 writeln!(code, " }}")?;
6031 writeln!(code)?;
6032
6033 writeln!(
6035 code,
6036 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
6037 )?;
6038 writeln!(
6039 code,
6040 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6041 )?;
6042 writeln!(code, " let n: u32 = count as u32;")?;
6043 writeln!(
6044 code,
6045 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6046 )?;
6047 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6048 writeln!(
6049 code,
6050 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6051 )?;
6052 writeln!(
6053 code,
6054 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6055 )?;
6056 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6057 writeln!(
6058 code,
6059 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6060 )?;
6061 writeln!(
6062 code,
6063 " enc.dispatch_thread_groups(grid_size, tg_size);"
6064 )?;
6065 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6066 writeln!(code, " }}")?;
6067 writeln!(code)?;
6068
6069 writeln!(
6071 code,
6072 " /// Copy from src at byte offset to dst at float offset."
6073 )?;
6074 writeln!(
6075 code,
6076 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6077 )?;
6078 writeln!(code, " let n: u32 = count as u32;")?;
6079 writeln!(
6080 code,
6081 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6082 )?;
6083 writeln!(
6084 code,
6085 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6086 )?;
6087 writeln!(
6088 code,
6089 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6090 )?;
6091 writeln!(
6092 code,
6093 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6094 )?;
6095 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6096 writeln!(
6097 code,
6098 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6099 )?;
6100 writeln!(
6101 code,
6102 " enc.dispatch_thread_groups(grid_size, tg_size);"
6103 )?;
6104 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6105 writeln!(code, " }}")?;
6106 writeln!(code)?;
6107
6108 writeln!(
6110 code,
6111 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6112 )?;
6113 writeln!(
6114 code,
6115 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6116 )?;
6117 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6118 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6119 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6120 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6121 writeln!(code, " let so: u32 = src_offset as u32;")?;
6122 writeln!(
6123 code,
6124 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6125 )?;
6126 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6127 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6128 writeln!(
6129 code,
6130 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6131 )?;
6132 writeln!(
6133 code,
6134 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6135 )?;
6136 writeln!(
6137 code,
6138 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6139 )?;
6140 writeln!(
6141 code,
6142 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6143 )?;
6144 writeln!(
6145 code,
6146 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6147 )?;
6148 writeln!(code, " let total = num_tokens * kv_dim;")?;
6149 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6150 writeln!(
6151 code,
6152 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6153 )?;
6154 writeln!(
6155 code,
6156 " enc.dispatch_thread_groups(grid_size, tg_size);"
6157 )?;
6158 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6159 writeln!(code, " }}")?;
6160 writeln!(code)?;
6161
6162 writeln!(
6164 code,
6165 " /// Dispatch batched causal attention: one dispatch for all M tokens."
6166 )?;
6167 writeln!(
6168 code,
6169 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6170 )?;
6171 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6172 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6173 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
6174 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6175 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6176 writeln!(code, " let qs: u32 = q_stride as u32;")?;
6177 writeln!(code, " let max_seq = base_pos + num_tokens;")?;
6187 writeln!(code, " let _ = max_seq;")?;
6188 writeln!(
6189 code,
6190 " let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6191 )?;
6192 writeln!(code, " .map(|v| v == \"0\").unwrap_or(false);")?;
6193 writeln!(
6194 code,
6195 " let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6196 )?;
6197 writeln!(code, " if use_mma_flash {{")?;
6198 writeln!(
6199 code,
6200 " let pipe = &self.attention_mma_flash_batch_pipeline;"
6201 )?;
6202 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6203 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6204 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6205 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6206 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6207 writeln!(
6208 code,
6209 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6210 )?;
6211 writeln!(
6212 code,
6213 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6214 )?;
6215 writeln!(
6216 code,
6217 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6218 )?;
6219 writeln!(
6220 code,
6221 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6222 )?;
6223 writeln!(
6224 code,
6225 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6226 )?;
6227 writeln!(
6228 code,
6229 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6230 )?;
6231 writeln!(
6232 code,
6233 " // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6234 )?;
6235 writeln!(code, " let tg_size = MTLSize::new(128, 1, 1);")?;
6236 writeln!(
6237 code,
6238 " let q_blocks = ((num_tokens + 7) / 8) as u64;"
6239 )?;
6240 writeln!(
6241 code,
6242 " let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6243 )?;
6244 writeln!(
6245 code,
6246 " enc.dispatch_thread_groups(grid_size, tg_size);"
6247 )?;
6248 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6249 writeln!(code, " return;")?;
6250 writeln!(code, " }}")?;
6251 writeln!(code, " let pipe = &self.attention_batch_pipeline;")?;
6252 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6253 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6254 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6255 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6256 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6257 writeln!(
6258 code,
6259 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6260 )?;
6261 writeln!(
6262 code,
6263 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6264 )?;
6265 writeln!(
6266 code,
6267 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6268 )?;
6269 writeln!(
6270 code,
6271 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6272 )?;
6273 writeln!(
6274 code,
6275 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6276 )?;
6277 writeln!(
6278 code,
6279 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6280 )?;
6281 writeln!(
6282 code,
6283 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6284 )?;
6285 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6286 writeln!(
6287 code,
6288 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6289 )?;
6290 writeln!(
6291 code,
6292 " enc.dispatch_thread_groups(grid_size, tg_size);"
6293 )?;
6294 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6295 writeln!(code, " }}")?;
6296 writeln!(code)?;
6297
6298 writeln!(
6300 code,
6301 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6302 )?;
6303 writeln!(
6304 code,
6305 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6306 )?;
6307 writeln!(
6308 code,
6309 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6310 )?;
6311 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6312 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6313 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
6314 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6315 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6316 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
6317 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
6318 writeln!(
6319 code,
6320 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6321 )?;
6322 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
6323 writeln!(
6324 code,
6325 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6326 )?;
6327 writeln!(
6328 code,
6329 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6330 )?;
6331 writeln!(
6332 code,
6333 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6334 )?;
6335 writeln!(
6336 code,
6337 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6338 )?;
6339 writeln!(
6340 code,
6341 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6342 )?;
6343 writeln!(
6344 code,
6345 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6346 )?;
6347 writeln!(
6348 code,
6349 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6350 )?;
6351 writeln!(
6352 code,
6353 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6354 )?;
6355 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6356 writeln!(
6357 code,
6358 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6359 )?;
6360 writeln!(
6361 code,
6362 " enc.dispatch_thread_groups(grid_size, tg_size);"
6363 )?;
6364 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6365 writeln!(code, " }}")?;
6366 writeln!(code)?;
6367
6368 writeln!(
6370 code,
6371 " /// Dispatch fused K+V cache copy in one kernel launch."
6372 )?;
6373 writeln!(
6374 code,
6375 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6376 )?;
6377 writeln!(
6378 code,
6379 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6380 )?;
6381 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6382 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6383 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6384 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6385 writeln!(code, " let ko: u32 = k_offset as u32;")?;
6386 writeln!(code, " let vo: u32 = v_offset as u32;")?;
6387 writeln!(
6388 code,
6389 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6390 )?;
6391 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6392 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
6393 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
6394 writeln!(
6395 code,
6396 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6397 )?;
6398 writeln!(
6399 code,
6400 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6401 )?;
6402 writeln!(
6403 code,
6404 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6405 )?;
6406 writeln!(
6407 code,
6408 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6409 )?;
6410 writeln!(
6411 code,
6412 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6413 )?;
6414 writeln!(
6415 code,
6416 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6417 )?;
6418 writeln!(
6419 code,
6420 " let total = num_tokens * kv_dim * 2; // K + V"
6421 )?;
6422 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6423 writeln!(
6424 code,
6425 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6426 )?;
6427 writeln!(
6428 code,
6429 " enc.dispatch_thread_groups(grid_size, tg_size);"
6430 )?;
6431 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6432 writeln!(code, " }}")?;
6433
6434 writeln!(code, "}}")?;
6435 writeln!(code)?;
6436
6437 Ok(())
6438}
6439
6440fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6441 writeln!(
6442 code,
6443 "// ── Helper functions ──────────────────────────────────"
6444 )?;
6445 writeln!(code)?;
6446 writeln!(
6447 code,
6448 "/// Create a compute pipeline from a named function in the Metal library."
6449 )?;
6450 writeln!(
6451 code,
6452 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6453 )?;
6454 writeln!(
6455 code,
6456 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6457 )?;
6458 writeln!(
6459 code,
6460 " device.new_compute_pipeline_state_with_function(&func)"
6461 )?;
6462 writeln!(
6463 code,
6464 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6465 )?;
6466 writeln!(code, "}}")?;
6467 writeln!(code)?;
6468
6469 Ok(())
6470}
6471
6472fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6477 let _sanitized = sanitize_name(model_name);
6478 let _vocab = config.vocab_size;
6479
6480 let mut code = String::with_capacity(16 * 1024);
6481 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6482 writeln!(
6483 code,
6484 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6485 )?;
6486 writeln!(code)?;
6487 writeln!(code, "mod model;")?;
6488 writeln!(code)?;
6489 writeln!(code, "use std::io::Write;")?;
6490 writeln!(code, "use std::time::Instant;")?;
6491 writeln!(code, "use serde::Deserialize;")?;
6492 writeln!(code)?;
6493
6494 writeln!(code, "fn main() {{")?;
6496 writeln!(
6497 code,
6498 " let args: Vec<String> = std::env::args().collect();"
6499 )?;
6500 writeln!(code)?;
6501 writeln!(
6502 code,
6503 " // Detect --serve mode (only requires weights + tokenizer)"
6504 )?;
6505 writeln!(
6506 code,
6507 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
6508 )?;
6509 writeln!(code)?;
6510 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
6511 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6512 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6513 writeln!(code, " std::process::exit(1);")?;
6514 writeln!(code, " }}")?;
6515 writeln!(code)?;
6516 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
6517 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6518 writeln!(code, " std::process::exit(1);")?;
6519 writeln!(code, " }}")?;
6520 writeln!(code)?;
6521 writeln!(code, " let weights_path = &args[1];")?;
6522 writeln!(code, " let tokenizer_path = &args[2];")?;
6523 writeln!(code)?;
6524 writeln!(code, " // Parse optional flags")?;
6525 writeln!(code, " let mut max_tokens: usize = 128;")?;
6526 writeln!(code, " let mut port: u16 = 8080;")?;
6527 writeln!(
6528 code,
6529 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6530 )?;
6531 writeln!(
6532 code,
6533 " let profile = args.iter().any(|a| a == \"--profile\");"
6534 )?;
6535 writeln!(code, " let mut i = 3;")?;
6536 writeln!(code, " while i < args.len() {{")?;
6537 writeln!(
6538 code,
6539 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6540 )?;
6541 writeln!(
6542 code,
6543 " max_tokens = args[i + 1].parse().unwrap_or(128);"
6544 )?;
6545 writeln!(code, " i += 2;")?;
6546 writeln!(
6547 code,
6548 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6549 )?;
6550 writeln!(
6551 code,
6552 " port = args[i + 1].parse().unwrap_or(8080);"
6553 )?;
6554 writeln!(code, " i += 2;")?;
6555 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
6556 writeln!(code, " i += 1;")?;
6557 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
6558 writeln!(code, " i += 1;")?;
6559 writeln!(code, " }} else {{")?;
6560 writeln!(code, " i += 1;")?;
6561 writeln!(code, " }}")?;
6562 writeln!(code, " }}")?;
6563 writeln!(code)?;
6564
6565 writeln!(
6567 code,
6568 " // Memory-map weights for zero-copy loading on Apple Silicon"
6569 )?;
6570 writeln!(
6571 code,
6572 " let weights_file = std::fs::File::open(weights_path)"
6573 )?;
6574 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6575 writeln!(
6576 code,
6577 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6578 )?;
6579 writeln!(code)?;
6580 writeln!(code, " // Load tokenizer")?;
6581 writeln!(
6582 code,
6583 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6584 )?;
6585 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6586 writeln!(code)?;
6587 writeln!(code, " // Create Metal model")?;
6588 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
6589 writeln!(
6590 code,
6591 " let mut model = model::MetalModel::new(&weights_mmap);"
6592 )?;
6593 writeln!(code)?;
6594
6595 writeln!(code, " if serve_mode {{")?;
6597 writeln!(code, " serve(model, tokenizer, port);")?;
6598 writeln!(code, " }} else {{")?;
6599 writeln!(code, " let prompt = &args[3];")?;
6600 writeln!(
6601 code,
6602 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6603 )?;
6604 writeln!(code, " }}")?;
6605 writeln!(code, "}}")?;
6606 writeln!(code)?;
6607
6608 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6610 writeln!(code, " // Tokenize prompt")?;
6611 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
6612 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6613 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6614 writeln!(code)?;
6615 writeln!(
6616 code,
6617 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6618 )?;
6619 writeln!(
6620 code,
6621 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
6622 )?;
6623 writeln!(
6624 code,
6625 " // The last token uses synchronous forward() to get logits."
6626 )?;
6627 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
6628 writeln!(code, " let prefill_start = Instant::now();")?;
6629 writeln!(code, " let logits = if prompt_len > 1 {{")?;
6630 writeln!(
6631 code,
6632 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6633 )?;
6634 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
6635 writeln!(code, " }} else {{")?;
6636 writeln!(code, " model.forward(prompt_tokens[0])")?;
6637 writeln!(code, " }};")?;
6638 writeln!(
6639 code,
6640 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6641 )?;
6642 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
6643 writeln!(
6644 code,
6645 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6646 )?;
6647 writeln!(
6648 code,
6649 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6650 )?;
6651 writeln!(code)?;
6652 writeln!(code, " // Generate tokens")?;
6653 writeln!(code, " let mut next_token = argmax(&logits);")?;
6654 writeln!(code, " let gen_start = Instant::now();")?;
6655 writeln!(code, " let mut generated_count: usize = 0;")?;
6656 writeln!(code)?;
6657 writeln!(code, " for _ in 0..max_tokens {{")?;
6658 writeln!(
6659 code,
6660 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6661 )?;
6662 writeln!(code, " if !quiet {{")?;
6663 writeln!(code, " print!(\"{{}}\", text);")?;
6664 writeln!(code, " std::io::stdout().flush().ok();")?;
6665 writeln!(code, " }}")?;
6666 writeln!(code, " }}")?;
6667 writeln!(code, " generated_count += 1;")?;
6668 writeln!(code)?;
6669 writeln!(
6670 code,
6671 " // Use profiling forward for first token when --profile is set"
6672 )?;
6673 writeln!(
6674 code,
6675 " let logits = if profile && generated_count == 1 {{"
6676 )?;
6677 writeln!(code, " model.forward_profile(next_token)")?;
6678 writeln!(code, " }} else {{")?;
6679 writeln!(code, " model.forward(next_token)")?;
6680 writeln!(code, " }};")?;
6681 writeln!(code, " next_token = argmax(&logits);")?;
6682 writeln!(code)?;
6683 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
6684 writeln!(code, " if next_token == 2 {{")?;
6685 writeln!(code, " break;")?;
6686 writeln!(code, " }}")?;
6687 writeln!(code)?;
6688 writeln!(
6689 code,
6690 " // Yield between tokens to reduce sustained GPU thermal load."
6691 )?;
6692 writeln!(
6693 code,
6694 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6695 )?;
6696 writeln!(
6697 code,
6698 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6699 )?;
6700 writeln!(
6701 code,
6702 " // briefly, providing a micro-break that helps sustain peak throughput."
6703 )?;
6704 writeln!(code, " std::thread::yield_now();")?;
6705 writeln!(code, " }}")?;
6706 writeln!(code, " if !quiet {{")?;
6707 writeln!(code, " println!();")?;
6708 writeln!(code, " }}")?;
6709 writeln!(
6710 code,
6711 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6712 )?;
6713 writeln!(
6714 code,
6715 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6716 )?;
6717 writeln!(
6718 code,
6719 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6720 )?;
6721 writeln!(code, "}}")?;
6722 writeln!(code)?;
6723
6724 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6726 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6727 writeln!(code, " logits.iter()")?;
6728 writeln!(code, " .enumerate()")?;
6729 writeln!(
6730 code,
6731 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6732 )?;
6733 writeln!(code, " .map(|(i, _)| i as u32)")?;
6734 writeln!(code, " .unwrap_or(0)")?;
6735 writeln!(code, "}}")?;
6736 writeln!(code)?;
6737
6738 writeln!(
6740 code,
6741 "// -----------------------------------------------------------------------"
6742 )?;
6743 writeln!(code, "// OpenAI-compatible API server")?;
6744 writeln!(
6745 code,
6746 "// -----------------------------------------------------------------------"
6747 )?;
6748 writeln!(code)?;
6749 writeln!(code, "#[derive(Deserialize)]")?;
6750 writeln!(code, "struct ChatRequest {{")?;
6751 writeln!(code, " messages: Vec<ChatMessage>,")?;
6752 writeln!(code, " #[serde(default)]")?;
6753 writeln!(code, " stream: Option<bool>,")?;
6754 writeln!(code, " #[serde(default)]")?;
6755 writeln!(code, " max_tokens: Option<usize>,")?;
6756 writeln!(code, " #[serde(default)]")?;
6757 writeln!(code, " temperature: Option<f32>,")?;
6758 writeln!(code, " #[serde(default)]")?;
6759 writeln!(code, " model: Option<String>,")?;
6760 writeln!(code, "}}")?;
6761 writeln!(code)?;
6762 writeln!(code, "#[derive(Deserialize)]")?;
6763 writeln!(code, "struct ChatMessage {{")?;
6764 writeln!(code, " role: String,")?;
6765 writeln!(code, " content: String,")?;
6766 writeln!(code, "}}")?;
6767 writeln!(code)?;
6768
6769 writeln!(
6771 code,
6772 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6773 )?;
6774 writeln!(code, " let mut prompt = String::new();")?;
6775 writeln!(code, " for msg in messages {{")?;
6776 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6777 writeln!(code, " }}")?;
6778 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6779 writeln!(code, " prompt")?;
6780 writeln!(code, "}}")?;
6781 writeln!(code)?;
6782
6783 writeln!(
6785 code,
6786 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6787 )?;
6788 writeln!(code, " let len = tokens.len();")?;
6789 writeln!(code, " if len > 1 {{")?;
6790 writeln!(
6791 code,
6792 " model.forward_prefill_batch(&tokens[..len - 1]);"
6793 )?;
6794 writeln!(code, " }}")?;
6795 writeln!(code, " model.forward(tokens[len - 1])")?;
6796 writeln!(code, "}}")?;
6797 writeln!(code)?;
6798
6799 writeln!(
6801 code,
6802 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6803 )?;
6804 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6805 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
6806 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6807 writeln!(
6808 code,
6809 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6810 )?;
6811 writeln!(code, " eprintln!(\"Endpoints:\");")?;
6812 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
6813 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
6814 writeln!(code, " eprintln!(\" GET /health\");")?;
6815 writeln!(code)?;
6816 writeln!(code, " for request in server.incoming_requests() {{")?;
6817 writeln!(code, " let method = request.method().to_string();")?;
6818 writeln!(code, " let url = request.url().to_string();")?;
6819 writeln!(code)?;
6820 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
6821
6822 writeln!(
6824 code,
6825 " (\"POST\", \"/v1/chat/completions\") => {{"
6826 )?;
6827 writeln!(
6828 code,
6829 " handle_chat_completion(&mut model, &tokenizer, request);"
6830 )?;
6831 writeln!(code, " }}")?;
6832
6833 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
6835 writeln!(code, " let body = serde_json::json!({{")?;
6836 writeln!(code, " \"object\": \"list\",")?;
6837 writeln!(code, " \"data\": [{{")?;
6838 writeln!(code, " \"id\": \"forgellm-metal\",")?;
6839 writeln!(code, " \"object\": \"model\",")?;
6840 writeln!(code, " \"owned_by\": \"forgellm\"")?;
6841 writeln!(code, " }}]")?;
6842 writeln!(code, " }});")?;
6843 writeln!(
6844 code,
6845 " let resp = tiny_http::Response::from_string(body.to_string())"
6846 )?;
6847 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6848 writeln!(code, " request.respond(resp).ok();")?;
6849 writeln!(code, " }}")?;
6850
6851 writeln!(code, " (\"GET\", \"/health\") => {{")?;
6853 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6854 writeln!(code, " request.respond(resp).ok();")?;
6855 writeln!(code, " }}")?;
6856
6857 writeln!(code, " _ => {{")?;
6859 writeln!(
6860 code,
6861 " let resp = tiny_http::Response::from_string(\"Not Found\")"
6862 )?;
6863 writeln!(code, " .with_status_code(404);")?;
6864 writeln!(code, " request.respond(resp).ok();")?;
6865 writeln!(code, " }}")?;
6866 writeln!(code, " }}")?;
6867 writeln!(code, " }}")?;
6868 writeln!(code, "}}")?;
6869 writeln!(code)?;
6870
6871 writeln!(code, "fn handle_chat_completion(")?;
6873 writeln!(code, " model: &mut model::MetalModel,")?;
6874 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
6875 writeln!(code, " mut request: tiny_http::Request,")?;
6876 writeln!(code, ") {{")?;
6877 writeln!(code, " // Read request body")?;
6878 writeln!(code, " let mut body = String::new();")?;
6879 writeln!(
6880 code,
6881 " if request.as_reader().read_to_string(&mut body).is_err() {{"
6882 )?;
6883 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6884 writeln!(code, " .with_status_code(400);")?;
6885 writeln!(code, " request.respond(resp).ok();")?;
6886 writeln!(code, " return;")?;
6887 writeln!(code, " }}")?;
6888 writeln!(code)?;
6889 writeln!(code, " // Parse JSON")?;
6890 writeln!(
6891 code,
6892 " let req: ChatRequest = match serde_json::from_str(&body) {{"
6893 )?;
6894 writeln!(code, " Ok(r) => r,")?;
6895 writeln!(code, " Err(e) => {{")?;
6896 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6897 writeln!(code, " .with_status_code(400);")?;
6898 writeln!(code, " request.respond(resp).ok();")?;
6899 writeln!(code, " return;")?;
6900 writeln!(code, " }}")?;
6901 writeln!(code, " }};")?;
6902 writeln!(code)?;
6903 writeln!(
6904 code,
6905 " let prompt = format_chat_messages(&req.messages);"
6906 )?;
6907 writeln!(
6908 code,
6909 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6910 )?;
6911 writeln!(code, " Ok(e) => e,")?;
6912 writeln!(code, " Err(e) => {{")?;
6913 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6914 writeln!(code, " .with_status_code(500);")?;
6915 writeln!(code, " request.respond(resp).ok();")?;
6916 writeln!(code, " return;")?;
6917 writeln!(code, " }}")?;
6918 writeln!(code, " }};")?;
6919 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6920 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
6921 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
6922 writeln!(
6923 code,
6924 " let _temperature = req.temperature.unwrap_or(1.0);"
6925 )?;
6926 writeln!(code)?;
6927
6928 writeln!(code, " model.reset();")?;
6930 writeln!(code)?;
6931
6932 writeln!(code, " let prefill_start = Instant::now();")?;
6934 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
6935 writeln!(
6936 code,
6937 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6938 )?;
6939 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
6940 writeln!(code, " let mut next_token = argmax(&logits);")?;
6941 writeln!(code)?;
6942
6943 writeln!(code, " if stream {{")?;
6944
6945 writeln!(
6947 code,
6948 " // SSE streaming: generate tokens and build SSE body"
6949 )?;
6950 writeln!(code, " let gen_start = Instant::now();")?;
6951 writeln!(code, " let mut generated_count: usize = 0;")?;
6952 writeln!(code, " let mut sse_body = String::new();")?;
6953 writeln!(code, " for _ in 0..max_tokens {{")?;
6954 writeln!(code, " if next_token == 2 {{ break; }}")?;
6955 writeln!(
6956 code,
6957 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6958 )?;
6959 writeln!(
6960 code,
6961 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
6962 )?;
6963 writeln!(
6964 code,
6965 " // escaped includes surrounding quotes, strip them"
6966 )?;
6967 writeln!(
6968 code,
6969 " let inner = &escaped[1..escaped.len()-1];"
6970 )?;
6971 writeln!(code, " sse_body.push_str(&format!(")?;
6972 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6973 writeln!(code, " inner")?;
6974 writeln!(code, " ));")?;
6975 writeln!(code, " }}")?;
6976 writeln!(code, " generated_count += 1;")?;
6977 writeln!(code, " let logits = model.forward(next_token);")?;
6978 writeln!(code, " next_token = argmax(&logits);")?;
6979 writeln!(code, " }}")?;
6980 writeln!(
6981 code,
6982 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6983 )?;
6984 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6985 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
6986 writeln!(code)?;
6987 writeln!(
6988 code,
6989 " // Final chunk with finish_reason, timing, and DONE sentinel"
6990 )?;
6991 writeln!(code, " sse_body.push_str(&format!(")?;
6992 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6993 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6994 writeln!(code, " ));")?;
6995 writeln!(code)?;
6996 writeln!(
6997 code,
6998 " let resp = tiny_http::Response::from_string(sse_body)"
6999 )?;
7000 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
7001 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
7002 writeln!(code, " request.respond(resp).ok();")?;
7003
7004 writeln!(code, " }} else {{")?;
7005
7006 writeln!(
7008 code,
7009 " // Non-streaming: generate all tokens, return JSON"
7010 )?;
7011 writeln!(code, " let gen_start = Instant::now();")?;
7012 writeln!(code, " let mut generated_count: usize = 0;")?;
7013 writeln!(code, " let mut generated = String::new();")?;
7014 writeln!(code, " for _ in 0..max_tokens {{")?;
7015 writeln!(code, " if next_token == 2 {{ break; }}")?;
7016 writeln!(
7017 code,
7018 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
7019 )?;
7020 writeln!(code, " generated.push_str(&text);")?;
7021 writeln!(code, " }}")?;
7022 writeln!(code, " generated_count += 1;")?;
7023 writeln!(code, " let logits = model.forward(next_token);")?;
7024 writeln!(code, " next_token = argmax(&logits);")?;
7025 writeln!(code, " }}")?;
7026 writeln!(
7027 code,
7028 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
7029 )?;
7030 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
7031 writeln!(code)?;
7032 writeln!(code, " let resp_json = serde_json::json!({{")?;
7033 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
7034 writeln!(code, " \"object\": \"chat.completion\",")?;
7035 writeln!(code, " \"choices\": [{{")?;
7036 writeln!(code, " \"index\": 0,")?;
7037 writeln!(code, " \"message\": {{")?;
7038 writeln!(code, " \"role\": \"assistant\",")?;
7039 writeln!(code, " \"content\": generated")?;
7040 writeln!(code, " }},")?;
7041 writeln!(code, " \"finish_reason\": \"stop\"")?;
7042 writeln!(code, " }}],")?;
7043 writeln!(code, " \"usage\": {{")?;
7044 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
7045 writeln!(
7046 code,
7047 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
7048 )?;
7049 writeln!(
7050 code,
7051 " \"generation_tokens\": generated_count,"
7052 )?;
7053 writeln!(
7054 code,
7055 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
7056 )?;
7057 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
7058 writeln!(code, " }}")?;
7059 writeln!(code, " }});")?;
7060 writeln!(
7061 code,
7062 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
7063 )?;
7064 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7065 writeln!(code, " request.respond(resp).ok();")?;
7066 writeln!(code, " }}")?;
7067 writeln!(code, "}}")?;
7068
7069 Ok(code)
7070}
7071
7072#[cfg(test)]
7077mod tests {
7078 use super::*;
7079 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7080
7081 fn minimal_config() -> ModelConfig {
7082 ModelConfig {
7083 architecture: Architecture::Llama,
7084 hidden_size: 64,
7085 intermediate_size: 128,
7086 num_layers: 2,
7087 num_attention_heads: 4,
7088 num_kv_heads: 4,
7089 head_dim: 16,
7090 vocab_size: 256,
7091 max_seq_len: 512,
7092 rms_norm_eps: 1e-5,
7093 rope_theta: 10000.0,
7094 dtype: DType::F32,
7095 sliding_window_size: None,
7096 qkv_bias: false,
7097 }
7098 }
7099
7100 fn minimal_graph() -> Graph {
7101 Graph::new("test-metal").with_config(minimal_config())
7102 }
7103
/// Generating a project from the minimal graph must write every expected file
/// beneath the output directory.
#[test]
fn generate_metal_project_creates_files() {
    let scratch = tempfile::tempdir().unwrap();
    let graph = minimal_graph();
    let root = scratch.path();
    generate_metal_project(&graph, root, "test-model").unwrap();

    // (path relative to the output directory, failure message)
    let expected_files = [
        ("Cargo.toml", "Cargo.toml should be created"),
        ("src/model.rs", "src/model.rs should be created"),
        ("src/main.rs", "src/main.rs should be created"),
        (
            "shaders/kernels.metal",
            "shaders/kernels.metal should be created",
        ),
    ];
    for (relative, message) in expected_files {
        assert!(root.join(relative).exists(), "{}", message);
    }
}
7127
/// The generated manifest must mention each crate the emitted project relies on.
#[test]
fn generated_cargo_toml_has_metal_dep() {
    let manifest = generate_cargo_toml("my-model");

    // (substring that must appear, failure message)
    let required = [
        ("metal", "Cargo.toml should depend on metal"),
        ("tokenizers", "Cargo.toml should depend on tokenizers"),
        ("memmap2", "Cargo.toml should depend on memmap2"),
        ("half", "Cargo.toml should depend on half"),
    ];
    for (dependency, message) in required {
        assert!(manifest.contains(dependency), "{}", message);
    }
}
7142
/// The generated model.rs must declare the `MetalModel` struct, set up the
/// Metal device and shader library, and expose `new` and `forward`.
#[test]
fn generated_model_rs_contains_metal_code() {
    let config = minimal_config();
    let source = generate_model_rs(&config).unwrap();

    // (snippet that must appear, failure message)
    let required = [
        (
            "pub struct MetalModel",
            "model.rs should define MetalModel struct",
        ),
        (
            "matmul_pipeline: ComputePipelineState",
            "MetalModel should have matmul_pipeline field",
        ),
        (
            "Device::system_default()",
            "model.rs should use Metal device",
        ),
        (
            "new_library_with_source",
            "model.rs should compile Metal shaders",
        ),
        (
            "fn new(weights: &[u8])",
            "MetalModel should implement new()",
        ),
        (
            "fn forward(&mut self, token_id: u32)",
            "MetalModel should implement forward()",
        ),
    ];
    for (snippet, message) in required {
        assert!(source.contains(snippet), "{}", message);
    }
}
7173
/// Every core compute kernel must be declared in the generated shader source.
#[test]
fn generated_shaders_contain_kernel_names() {
    let msl = generate_metal_shaders(&minimal_config());

    // (kernel declaration to look for, failure message)
    let kernels = [
        (
            "kernel void matmul_vec",
            "shaders should contain matmul_vec kernel",
        ),
        (
            "kernel void rms_norm",
            "shaders should contain rms_norm kernel",
        ),
        ("kernel void rope", "shaders should contain rope kernel"),
        (
            "kernel void softmax",
            "shaders should contain softmax kernel",
        ),
        // The trailing parenthesis keeps this from matching silu_mul_fused.
        (
            "kernel void silu_mul(",
            "shaders should contain silu_mul kernel",
        ),
        (
            "kernel void silu_mul_fused",
            "shaders should contain silu_mul_fused kernel",
        ),
        (
            "kernel void elementwise_add",
            "shaders should contain elementwise_add kernel",
        ),
        (
            "kernel void attention",
            "shaders should contain attention kernel",
        ),
        (
            "kernel void add_inplace",
            "shaders should contain add_inplace kernel",
        ),
        (
            "kernel void copy_buffer",
            "shaders should contain copy_buffer kernel",
        ),
        (
            "kernel void copy_offset",
            "shaders should contain copy_offset kernel",
        ),
    ];
    for (declaration, message) in kernels {
        assert!(msl.contains(declaration), "{}", message);
    }
}
7223
/// The shader source must reference the threadgroup and simdgroup features
/// listed below.
#[test]
fn generated_shaders_use_simdgroup_features() {
    let msl = generate_metal_shaders(&minimal_config());
    let uses = |token: &str| msl.contains(token);

    // (token that must appear, failure message)
    let features = [
        (
            "threadgroup_barrier",
            "shaders should use threadgroup barriers",
        ),
        (
            "threadgroup float",
            "shaders should use threadgroup shared memory",
        ),
        (
            "thread_index_in_threadgroup",
            "shaders should use threadgroup indexing",
        ),
        (
            "simd_sum",
            "shaders should use simd_sum for warp-level reduction",
        ),
        (
            "simd_max",
            "attention kernel should use simd_max for cooperative softmax",
        ),
        (
            "thread_index_in_simdgroup",
            "shaders should use simdgroup lane indexing",
        ),
        (
            "simdgroup_index_in_threadgroup",
            "shaders should use simdgroup indexing within threadgroup",
        ),
        (
            "float4",
            "matmul_vec should use float4 vectorized loads",
        ),
    ];
    for (token, message) in features {
        assert!(uses(token), "{}", message);
    }
}
7261
/// The generated main.rs must reference the tokenizer, the model constructor,
/// the forward pass and memmap2.
#[test]
fn generated_main_rs_has_tokenizer_usage() {
    let config = minimal_config();
    let entry_point = generate_main_rs("test-model", &config).unwrap();

    // (snippet that must appear, failure message)
    let required = [
        (
            "tokenizers::Tokenizer",
            "main.rs should use tokenizers crate",
        ),
        ("MetalModel::new", "main.rs should call MetalModel::new"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "memmap2",
            "main.rs should use memmap2 for zero-copy weight loading",
        ),
    ];
    for (snippet, message) in required {
        assert!(entry_point.contains(snippet), "{}", message);
    }
}
7284
/// A graph built without a config must be rejected with `MissingConfig`.
#[test]
fn missing_config_returns_error() {
    let scratch = tempfile::tempdir().unwrap();
    let configless = Graph::new("no-config");

    let outcome = generate_metal_project(&configless, scratch.path(), "fail");
    let rejected = matches!(outcome, Err(MetalCodegenError::MissingConfig));

    assert!(
        rejected,
        "should fail with MissingConfig when graph has no config"
    );
}
7295
/// Spot-checks `sanitize_name` on names containing spaces, punctuation and
/// underscores, plus one that needs no change.
#[test]
fn sanitize_name_works() {
    // (raw name, expected sanitized form)
    let cases = [
        ("My Model!", "my-model"),
        ("test_model", "test-model"),
        ("simple", "simple"),
    ];
    for (raw, expected) in cases {
        assert_eq!(sanitize_name(raw), expected);
    }
}
7302
/// `forward()` must create one command buffer, commit it once, and wait at
/// most twice.
#[test]
fn generated_forward_uses_single_command_buffer() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // Slice out forward(): from its signature up to the first following item,
    // trying the candidate end markers in priority order.
    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let tail = &model_rs[model_rs.find(signature).unwrap()..];
    let body_len = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .find_map(|&marker| tail.find(marker))
    .unwrap_or(tail.len());
    let forward_code = &tail[..body_len];

    let occurrences = |needle: &str| forward_code.matches(needle).count();

    let cmd_buf_count = occurrences("new_command_buffer()");
    assert_eq!(
        cmd_buf_count, 1,
        "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
    );

    let commit_count = occurrences("cmd.commit()");
    assert_eq!(
        commit_count, 1,
        "forward() should commit exactly once, found {commit_count}"
    );

    let wait_count = occurrences("wait_until_completed()");
    assert!(
        (1..=2).contains(&wait_count),
        "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
    );
}
7343
/// Each working buffer must be declared as a `Buffer` field in model.rs.
#[test]
fn generated_model_has_preallocated_working_buffers() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let working_buffers = [
        "normed_buf",
        "qkv_buf",
        "attn_out_buf",
        "attn_proj_buf",
        "gate_up_buf",
        "ffn_hidden_buf",
        "ffn_out_buf",
        "add_tmp_buf",
    ];
    for buf_name in working_buffers {
        let field_decl = format!("{buf_name}: Buffer");
        assert!(
            model_rs.contains(&field_decl),
            "MetalModel should have pre-allocated {buf_name} field"
        );
    }
}
7365
/// Every dispatch helper must take `&self` followed by a
/// `&ComputeCommandEncoderRef` as its first parameters.
#[test]
fn generated_dispatch_helpers_take_compute_encoder_ref() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // Helper names without the shared `dispatch_` prefix; the expected
    // signature prefix is rebuilt from each name below.
    let helpers = [
        "rms_norm",
        "matmul",
        "rope",
        "rope_offset",
        "attention",
        "attention_offset",
        "silu_mul",
        "silu_mul_fused",
        "add_inplace",
        "copy",
        "copy_offset",
        "copy_from_offset",
        "copy_to_offset",
    ];
    for helper in helpers {
        let method = format!("fn dispatch_{helper}(&self, enc: &ComputeCommandEncoderRef");
        assert!(
            model_rs.contains(&method),
            "model.rs should contain dispatch helper: {method}"
        );
    }
}
7392
/// Nothing from `dispatch_rms_norm` to the end of model.rs may create, end,
/// commit or wait on command buffers or encoders.
#[test]
fn generated_helpers_do_not_create_command_buffers_or_encoders() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // Everything from the first dispatch helper onward is inspected.
    let first_helper = model_rs.find("fn dispatch_rms_norm").unwrap();
    let helpers_code = &model_rs[first_helper..];

    // (call that must be absent, failure message)
    let forbidden_calls = [
        (
            "self.queue.new_command_buffer()",
            "dispatch helpers should not create their own command buffers",
        ),
        (
            "new_compute_command_encoder()",
            "dispatch helpers should not create their own compute encoders",
        ),
        (
            "end_encoding()",
            "dispatch helpers should not call end_encoding",
        ),
        (
            ".commit()",
            "dispatch helpers should not commit command buffers",
        ),
        (
            "wait_until_completed",
            "dispatch helpers should not wait on command buffers",
        ),
    ];
    for (call, message) in forbidden_calls {
        assert!(!helpers_code.contains(call), "{}", message);
    }
}
7430
/// `forward()` must not allocate buffers, must open exactly one compute
/// encoder, and must not open any blit encoder.
#[test]
fn generated_forward_batches_compute_encoders() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // Slice out forward(): from its signature up to the first following item,
    // trying the candidate end markers in priority order.
    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let tail = &model_rs[model_rs.find(signature).unwrap()..];
    let body_len = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .find_map(|&marker| tail.find(marker))
    .unwrap_or(tail.len());
    let forward_code = &tail[..body_len];

    assert!(
        !forward_code.contains("device.new_buffer"),
        "forward() should not allocate new buffers per call"
    );

    let occurrences = |needle: &str| forward_code.matches(needle).count();
    let compute_encoder_count = occurrences("new_compute_command_encoder()");
    let blit_encoder_count = occurrences("new_blit_command_encoder()");

    assert_eq!(
        compute_encoder_count, 1,
        "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
    );
    assert_eq!(
        blit_encoder_count, 0,
        "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
    );
}
7471
/// model.rs must reference both the in-place add dispatch helper and its
/// pipeline. The whole file is searched, not only `forward()`.
#[test]
fn generated_forward_uses_add_inplace() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let uses_helper = model_rs.contains("dispatch_add_inplace");
    assert!(
        uses_helper,
        "forward() should use dispatch_add_inplace for residual connections"
    );

    let has_pipeline = model_rs.contains("add_inplace_pipeline");
    assert!(has_pipeline, "MetalModel should have add_inplace_pipeline");
}
7487
7488 fn minimal_q8_config() -> ModelConfig {
7489 ModelConfig {
7490 architecture: Architecture::Llama,
7491 hidden_size: 64,
7492 intermediate_size: 128,
7493 num_layers: 2,
7494 num_attention_heads: 4,
7495 num_kv_heads: 4,
7496 head_dim: 16,
7497 vocab_size: 256,
7498 max_seq_len: 512,
7499 rms_norm_eps: 1e-5,
7500 rope_theta: 10000.0,
7501 dtype: DType::Q8_0,
7502 sliding_window_size: None,
7503 qkv_bias: false,
7504 }
7505 }
7506
/// The shader source generated for the baseline config must contain the Q8_0
/// matmul kernel and the snippets it is expected to use.
#[test]
fn generated_shaders_contain_q8_kernel() {
    let msl = generate_metal_shaders(&minimal_config());

    // (snippet that must appear, failure message)
    let required = [
        (
            "kernel void matmul_vec_q8",
            "shaders should contain matmul_vec_q8 kernel",
        ),
        (
            "device const uchar* matrix",
            "matmul_vec_q8 should accept raw Q8_0 bytes",
        ),
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "device const half*",
            "matmul_vec_q8 should read f16 scale via half pointer",
        ),
    ];
    for (snippet, message) in required {
        assert!(msl.contains(snippet), "{}", message);
    }
}
7532
/// model.rs must declare the fused QKV and gate+up weights and buffers, must
/// not declare the separate per-projection weights, and must reference the
/// fused/offset dispatch helpers.
#[test]
fn generated_model_uses_fused_qkv_projections() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // (snippet, whether it must be present, failure message)
    let expectations = [
        (
            "qkv_weight: Buffer",
            true,
            "LayerBuffers should have fused qkv_weight field",
        ),
        (
            " q_weight: Buffer",
            false,
            "LayerBuffers should not have separate q_weight field",
        ),
        (
            " k_weight: Buffer",
            false,
            "LayerBuffers should not have separate k_weight field",
        ),
        (
            " v_weight: Buffer",
            false,
            "LayerBuffers should not have separate v_weight field",
        ),
        (
            "gate_up_weight: Buffer",
            true,
            "LayerBuffers should have fused gate_up_weight field",
        ),
        (
            " gate_weight: Buffer",
            false,
            "LayerBuffers should not have separate gate_weight field",
        ),
        (
            " up_weight: Buffer",
            false,
            "LayerBuffers should not have separate up_weight field",
        ),
        (
            "qkv_buf: Buffer",
            true,
            "MetalModel should have fused qkv_buf",
        ),
        (
            "gate_up_buf: Buffer",
            true,
            "MetalModel should have fused gate_up_buf",
        ),
        (
            "dispatch_silu_mul_fused",
            true,
            "forward pass should use dispatch_silu_mul_fused",
        ),
        (
            "dispatch_rope_offset",
            true,
            "forward pass should use dispatch_rope_offset for fused QKV",
        ),
        (
            "dispatch_attention_offset",
            true,
            "forward pass should use dispatch_attention_offset for fused QKV",
        ),
    ];
    for (snippet, must_be_present, message) in expectations {
        assert!(model_rs.contains(snippet) == must_be_present, "{}", message);
    }
}
7596
/// A Q8_0 model must declare the Q8 matmul pipeline field and also mention it
/// followed by a comma, as the struct initialiser would.
#[test]
fn q8_model_has_matmul_q8_pipeline() {
    let model_rs = {
        let config = minimal_q8_config();
        generate_model_rs(&config).unwrap()
    };

    let declares_field = model_rs.contains("matmul_q8_pipeline: ComputePipelineState");
    assert!(
        declares_field,
        "MetalModel should have matmul_q8_pipeline field"
    );

    let initialises_field = model_rs.contains("matmul_q8_pipeline,");
    assert!(
        initialises_field,
        "MetalModel Self should include matmul_q8_pipeline"
    );
}
7611
/// A Q8_0 model must both reference and define `dispatch_matmul_q8`.
#[test]
fn q8_model_uses_dispatch_matmul_q8() {
    let config = minimal_q8_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // (snippet that must appear, failure message)
    let required = [
        (
            "dispatch_matmul_q8",
            "Q8_0 model should use dispatch_matmul_q8 for projections",
        ),
        (
            "fn dispatch_matmul_q8",
            "model.rs should define dispatch_matmul_q8 method",
        ),
    ];
    for (snippet, message) in required {
        assert!(model_rs.contains(snippet), "{}", message);
    }
}
7626
/// Q8_0 model.rs must contain no f32 dequantisation snippets and must
/// reference the raw byte count.
#[test]
fn q8_model_loads_raw_bytes_not_dequantized() {
    let config = minimal_q8_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // (snippet that must be absent, failure message)
    let forbidden = [
        (
            "f16_to_f32",
            "Q8_0 model should not dequantize weights to f32",
        ),
        ("f32_data", "Q8_0 model should not create f32 weight data"),
    ];
    for (snippet, message) in forbidden {
        assert!(!model_rs.contains(snippet), "{}", message);
    }

    let loads_raw = model_rs.contains("total_raw as u64");
    assert!(
        loads_raw,
        "Q8_0 model should load raw bytes into Metal buffer"
    );
}
7648
/// With Q8_0 weights, each norm tensor must still be bound via
/// `next_f32_buffer`.
#[test]
fn q8_model_norms_stay_f32() {
    let config = minimal_q8_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // (binding that must appear, failure message)
    let norm_bindings = [
        (
            "let attn_norm = next_f32_buffer",
            "attn_norm should use f32 buffer even for Q8_0 models",
        ),
        (
            "let ffn_norm = next_f32_buffer",
            "ffn_norm should use f32 buffer even for Q8_0 models",
        ),
        (
            "let norm_buf = next_f32_buffer",
            "final norm should use f32 buffer even for Q8_0 models",
        ),
    ];
    for (binding, message) in norm_bindings {
        assert!(model_rs.contains(binding), "{}", message);
    }
}
7668
/// Q8_0 weights must be bound through `next_q8_fused_buffer` for QKV and
/// gate+up, and through `next_q8_buffer` for O and down.
#[test]
fn q8_model_uses_fused_weight_loading() {
    let config = minimal_q8_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let binds = |snippet: &str| model_rs.contains(snippet);

    // (binding that must appear, failure message)
    let weight_bindings = [
        (
            "let qkv_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights",
        ),
        (
            "let gate_up_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights",
        ),
        (
            "let o_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for O weight",
        ),
        (
            "let down_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for down weight",
        ),
    ];
    for (binding, message) in weight_bindings {
        assert!(binds(binding), "{}", message);
    }
}
7694
/// For an f32 model, nothing from `forward()`'s signature to the first
/// dispatch helper definition may reference the Q8 matmul dispatch.
#[test]
fn f32_model_does_not_use_q8_dispatch() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let tail = &model_rs[model_rs.find(signature).unwrap()..];
    // The inspected region stops at the first dispatch helper; if none is
    // found, the remainder of the file is inspected.
    let forward_code = match tail.find("\n fn dispatch_") {
        Some(end) => &tail[..end],
        None => tail,
    };

    assert!(
        !forward_code.contains("dispatch_matmul_q8"),
        "f32 model forward should not use dispatch_matmul_q8"
    );
}
7715
/// The Q8 matmul dispatch helper must take `&self` followed by a
/// `&ComputeCommandEncoderRef`.
#[test]
fn q8_dispatch_helper_takes_compute_encoder_ref() {
    let config = minimal_q8_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let expected_signature = "fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef";
    let has_signature = model_rs.contains(expected_signature);
    assert!(
        has_signature,
        "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
    );
}
7726
/// model.rs must contain the `prev_cmd` field, the `forward_prefill` method
/// and the snippet that takes `prev_cmd`.
#[test]
fn generated_model_has_double_buffered_prefill() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // (snippet that must appear, failure message)
    let required = [
        (
            "prev_cmd: Option<CommandBuffer>",
            "MetalModel should have prev_cmd field for double-buffered prefill",
        ),
        (
            "pub fn forward_prefill(&mut self, token_id: u32)",
            "MetalModel should have forward_prefill method",
        ),
        (
            "if let Some(prev) = self.prev_cmd.take()",
            "forward() should drain prev_cmd from previous prefill",
        ),
    ];
    for (snippet, message) in required {
        assert!(model_rs.contains(snippet), "{}", message);
    }
}
7750
/// main.rs must reference `forward_prefill` and contain the text
/// "double-buffered".
#[test]
fn generated_main_rs_uses_forward_prefill_for_prompt() {
    let config = minimal_config();
    let entry_point = generate_main_rs("test-model", &config).unwrap();

    // (snippet that must appear, failure message)
    let required = [
        (
            "forward_prefill",
            "main.rs should use forward_prefill for intermediate prompt tokens",
        ),
        (
            "double-buffered",
            "main.rs should document double-buffered prefill",
        ),
    ];
    for (snippet, message) in required {
        assert!(entry_point.contains(snippet), "{}", message);
    }
}
7765
/// The shader source must contain the wide-load, bitcast and `dot(` snippets
/// expected of the Q8 matmul kernel.
#[test]
fn generated_shaders_q8_uses_wide_vectorized_loads() {
    let msl = generate_metal_shaders(&minimal_config());

    // (snippet that must appear, failure message)
    let required = [
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads",
        ),
        (
            "d0[0]",
            "matmul_vec_q8 should index the wide pointer for row 0",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "dot(",
            "matmul_vec_q8 should use dot() intrinsic for fma accumulation",
        ),
    ];
    for (snippet, message) in required {
        assert!(msl.contains(snippet), "{}", message);
    }
}
7787
7788 fn minimal_q4_config() -> ModelConfig {
7791 ModelConfig {
7792 architecture: Architecture::Llama,
7793 hidden_size: 64,
7794 intermediate_size: 128,
7795 num_layers: 2,
7796 num_attention_heads: 4,
7797 num_kv_heads: 4,
7798 head_dim: 16,
7799 vocab_size: 256,
7800 max_seq_len: 512,
7801 rms_norm_eps: 1e-5,
7802 rope_theta: 10000.0,
7803 dtype: DType::Q4_0,
7804 sliding_window_size: None,
7805 qkv_bias: false,
7806 }
7807 }
7808
/// The shader source must contain the Q4 matmul kernel and both
/// `Q4_ROWS_PER_*` names.
#[test]
fn generated_shaders_contain_q4_kernel() {
    let msl = generate_metal_shaders(&minimal_config());

    // (snippet that must appear, failure message)
    let required = [
        (
            "kernel void matmul_vec_q4",
            "shaders should contain matmul_vec_q4 kernel",
        ),
        (
            "Q4_ROWS_PER_TG",
            "shaders should define Q4_ROWS_PER_TG constant",
        ),
        (
            "Q4_ROWS_PER_SG",
            "shaders should define Q4_ROWS_PER_SG constant",
        ),
    ];
    for (snippet, message) in required {
        assert!(msl.contains(snippet), "{}", message);
    }
}
7826
/// The shader source must contain the snippets expected of Q4 nibble
/// unpacking: `uchar4`, the low/high nibble extraction, the `-8)` offset and
/// the `blk * 18` stride.
#[test]
fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
    let msl = generate_metal_shaders(&minimal_config());
    let uses = |snippet: &str| msl.contains(snippet);

    // (snippet that must appear, failure message)
    let required = [
        (
            "uchar4",
            "matmul_vec_q4 should use uchar4 for packed byte loads",
        ),
        (
            "&0xF",
            "matmul_vec_q4 should extract low nibble with &0xF",
        ),
        (
            ">>4",
            "matmul_vec_q4 should extract high nibble with >>4",
        ),
        (
            "-8)",
            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion",
        ),
        (
            "blk * 18",
            "matmul_vec_q4 should use 18-byte block stride",
        ),
    ];
    for (snippet, message) in required {
        assert!(uses(snippet), "{}", message);
    }
}
7856
/// A Q4_0 model must declare the Q4 matmul pipeline field and also mention it
/// followed by a comma, as the struct initialiser would.
#[test]
fn q4_model_has_matmul_q4_pipeline() {
    let model_rs = {
        let config = minimal_q4_config();
        generate_model_rs(&config).unwrap()
    };

    let declares_field = model_rs.contains("matmul_q4_pipeline: ComputePipelineState");
    assert!(
        declares_field,
        "MetalModel should have matmul_q4_pipeline field"
    );

    let initialises_field = model_rs.contains("matmul_q4_pipeline,");
    assert!(
        initialises_field,
        "MetalModel Self should include matmul_q4_pipeline"
    );
}
7871
/// A Q4_0 model must both reference and define `dispatch_matmul_q4`.
#[test]
fn q4_model_uses_dispatch_matmul_q4() {
    let config = minimal_q4_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // (snippet that must appear, failure message)
    let required = [
        (
            "dispatch_matmul_q4",
            "Q4_0 model should use dispatch_matmul_q4 for projections",
        ),
        (
            "fn dispatch_matmul_q4",
            "model.rs should define dispatch_matmul_q4 method",
        ),
    ];
    for (snippet, message) in required {
        assert!(model_rs.contains(snippet), "{}", message);
    }
}
7886
/// Q4_0 model.rs must not contain `f16_to_f32` and must reference the raw
/// byte count.
#[test]
fn q4_model_loads_raw_bytes_not_dequantized() {
    let config = minimal_q4_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let dequantizes = model_rs.contains("f16_to_f32");
    assert!(
        !dequantizes,
        "Q4_0 model should not dequantize weights to f32"
    );

    let loads_raw = model_rs.contains("total_raw as u64");
    assert!(
        loads_raw,
        "Q4_0 model should load raw bytes into Metal buffer"
    );
}
7904
/// With Q4_0 weights, each norm tensor must still be bound via
/// `next_f32_buffer`.
#[test]
fn q4_model_norms_stay_f32() {
    let config = minimal_q4_config();
    let model_rs = generate_model_rs(&config).unwrap();

    // (binding that must appear, failure message)
    let norm_bindings = [
        (
            "let attn_norm = next_f32_buffer",
            "attn_norm should use f32 buffer even for Q4_0 models",
        ),
        (
            "let ffn_norm = next_f32_buffer",
            "ffn_norm should use f32 buffer even for Q4_0 models",
        ),
        (
            "let norm_buf = next_f32_buffer",
            "final norm should use f32 buffer even for Q4_0 models",
        ),
    ];
    for (binding, message) in norm_bindings {
        assert!(model_rs.contains(binding), "{}", message);
    }
}
7923
/// Q4_0 weights must be bound through `next_q4_fused_buffer` for QKV and
/// gate+up, and through `next_q4_buffer` for O and down.
#[test]
fn q4_model_uses_fused_weight_loading() {
    let config = minimal_q4_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let binds = |snippet: &str| model_rs.contains(snippet);

    // (binding that must appear, failure message)
    let weight_bindings = [
        (
            "let qkv_weight = next_q4_fused_buffer",
            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights",
        ),
        (
            "let gate_up_weight = next_q4_fused_buffer",
            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights",
        ),
        (
            "let o_weight = next_q4_buffer",
            "Q4_0 model should use next_q4_buffer for O weight",
        ),
        (
            "let down_weight = next_q4_buffer",
            "Q4_0 model should use next_q4_buffer for down weight",
        ),
    ];
    for (binding, message) in weight_bindings {
        assert!(binds(binding), "{}", message);
    }
}
7946
/// The shader source must contain the flash attention kernel and its K tile
/// name, and model.rs must reference the matching pipeline.
#[test]
fn attention_flash_batch_kernel_exists() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let shaders = generate_metal_shaders(&config);

    // (snippet that must appear in the shader source, failure message)
    let shader_requirements = [
        (
            "kernel void attention_flash_batch",
            "shaders.metal must still contain the attention_flash_batch kernel",
        ),
        (
            "FLASH_K_TILE",
            "flash kernel must tile K/V with a FLASH_K_TILE constant",
        ),
    ];
    for (snippet, message) in shader_requirements {
        assert!(shaders.contains(snippet), "{}", message);
    }

    let registers_pipeline = model_rs.contains("attention_flash_batch_pipeline");
    assert!(
        registers_pipeline,
        "MetalModel must register the flash attention pipeline"
    );
}
7969
/// Every attention kernel must take both `k_cache` and `v_cache` as
/// `device const half*`, and the f32->f16 copy path must be present.
///
/// For each kernel the signature is sliced from the kernel name to the
/// opening brace, then checked for the half-pointer parameters and for the
/// absence of a float declaration of either cache.
#[test]
fn kv_cache_stored_as_f16() {
    let config = minimal_config();
    let shaders = generate_metal_shaders(&config);
    let model_rs = generate_model_rs(&config).unwrap();

    for kernel in [
        "kernel void attention(",
        "kernel void attention_batch(",
        "kernel void attention_flash_batch(",
        "kernel void attention_mma_flash_batch(",
    ] {
        let start = shaders
            .find(kernel)
            .unwrap_or_else(|| panic!("kernel {kernel} missing"));
        // Accept either brace style at the end of the parameter list.
        let sig_end = shaders[start..]
            .find("){")
            .unwrap_or_else(|| shaders[start..].find(") {").unwrap());
        let sig = &shaders[start..start + sig_end];
        assert!(
            sig.contains("device const half*")
                && sig.contains("k_cache")
                && sig.contains("v_cache"),
            "{kernel} must read k_cache/v_cache as `device const half*`"
        );
        // Both caches are checked: the previous version tested k_cache twice,
        // so a float v_cache parameter went undetected.
        assert!(
            !sig.contains("device const float* k_cache")
                && !sig.contains("device const float* v_cache"),
            "{kernel} still reads k_cache/v_cache as float"
        );
    }

    assert!(
        shaders.contains("kernel void copy_f32_to_f16_offset"),
        "f32->f16 copy kernel must be present for single-token decode KV writes"
    );
    assert!(
        model_rs.contains("dispatch_copy_from_offset_f16"),
        "single-token decode must dispatch the f32->f16 KV copy"
    );
}
8018
/// The shader source must contain the MMA flash kernel and its supporting
/// names, and model.rs must reference the pipeline, the opt-out flag and the
/// dispatch condition.
#[test]
fn attention_mma_flash_batch_kernel_wired() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let shaders = generate_metal_shaders(&config);

    // (snippet that must appear in the shader source, failure message)
    let shader_requirements = [
        (
            "kernel void attention_mma_flash_batch",
            "shaders.metal must contain the MMA flash kernel",
        ),
        (
            "FLASH_MMA_Q_BLOCK",
            "MMA flash kernel must define Q_BLOCK tiling constant",
        ),
        (
            "simdgroup_multiply_accumulate",
            "MMA flash kernel must use hardware MMA",
        ),
    ];
    for (snippet, message) in shader_requirements {
        assert!(shaders.contains(snippet), "{}", message);
    }

    // (snippet that must appear in model.rs, failure message)
    let model_requirements = [
        (
            "attention_mma_flash_batch_pipeline",
            "MetalModel must register the MMA flash pipeline",
        ),
        (
            "mma_opt_out",
            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out",
        ),
        (
            "!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8",
            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8",
        ),
    ];
    for (snippet, message) in model_requirements {
        assert!(model_rs.contains(snippet), "{}", message);
    }
}
8052
/// model.rs must iterate the prompt in `MAX_BATCH_SIZE` chunks and must not
/// contain the `min(MAX_BATCH_SIZE)` clamp.
#[test]
fn forward_prefill_batch_chunks_by_max_batch_size() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let chunks_prompt = model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)");
    assert!(
        chunks_prompt,
        "forward_prefill_batch must chunk long prompts"
    );

    let truncates_prompt = model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)");
    assert!(!truncates_prompt, "the old truncation path must be gone");
}
8070
/// With `Architecture::Qwen2` and `qkv_bias` enabled, checks that the generated
/// model source contains the bias buffer, pipeline, and dispatch calls, and
/// that the generated shaders contain the `add_bias_batch` kernel.
#[test]
fn qwen2_qkv_bias_wired_through_metal_codegen() {
    let qwen_cfg = ModelConfig {
        architecture: Architecture::Qwen2,
        qkv_bias: true,
        ..minimal_config()
    };

    let model_src = generate_model_rs(&qwen_cfg).unwrap();
    let model_expectations = [
        (
            "qkv_bias: Buffer",
            "Qwen2 LayerBuffers must declare qkv_bias field",
        ),
        (
            "let qkv_bias = next_f32_buffer",
            "Qwen2 layer init must load the bias from the weight blob",
        ),
        (
            "add_bias_batch_pipeline",
            "Qwen2 model struct must include the add_bias_batch_pipeline",
        ),
        (
            "fn dispatch_add_bias_batch",
            "Qwen2 codegen must emit dispatch_add_bias_batch helper",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.batch_qkv_buf",
            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.qkv_buf",
            "forward must call dispatch_add_bias_batch on the single-token qkv_buf",
        ),
    ];
    for (needle, why) in model_expectations {
        assert!(model_src.contains(needle), "{why}");
    }

    // The shader text is generated only after the model-source checks pass.
    let shader_src = generate_metal_shaders(&qwen_cfg);
    let has_bias_kernel = shader_src.contains("kernel void add_bias_batch");
    assert!(
        has_bias_kernel,
        "shaders.metal must contain the add_bias_batch kernel"
    );
}
8115
/// With the minimal config (whose `qkv_bias` flag is asserted to be off),
/// checks that none of the QKV-bias identifiers appear in the generated model
/// source.
#[test]
fn llama_does_not_emit_qkv_bias_machinery() {
    let config = minimal_config();
    // Precondition: the baseline config must not request a QKV bias.
    assert!(!config.qkv_bias);

    let model_src = generate_model_rs(&config).unwrap();
    let forbidden = [
        ("qkv_bias: Buffer", "Llama must not have qkv_bias field"),
        (
            "add_bias_batch_pipeline",
            "Llama must not pull in add_bias_batch_pipeline",
        ),
        (
            "dispatch_add_bias_batch",
            "Llama must not call dispatch_add_bias_batch",
        ),
    ];
    for (needle, why) in forbidden {
        assert!(!model_src.contains(needle), "{why}");
    }
}
8136
/// Checks that the Q4 config's generated model source declares
/// `dispatch_matmul_q4` with a `&ComputeCommandEncoderRef` encoder parameter.
#[test]
fn q4_dispatch_helper_takes_compute_encoder_ref() {
    let cfg = minimal_q4_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let expected_signature = "fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef";
    let has_signature = model_src.contains(expected_signature);
    assert!(
        has_signature,
        "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
    );
}
8147
/// Checks that, for the minimal (f32) config, the text of the generated
/// `forward()` up to the first `fn dispatch_` helper does not mention
/// `dispatch_matmul_q4`.
#[test]
fn f32_model_does_not_use_q4_dispatch() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let start = model_src
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &model_src[start..];

    // forward() is taken to end where the first dispatch helper begins; if no
    // helper follows, the rest of the source is inspected.
    let forward_src = match tail.find("\n fn dispatch_") {
        Some(end) => &tail[..end],
        None => tail,
    };

    let uses_q4 = forward_src.contains("dispatch_matmul_q4");
    assert!(
        !uses_q4,
        "f32 model forward should not use dispatch_matmul_q4"
    );
}
8167
/// Checks that the Q4 config's generated model source loads `lm_head_buf`
/// through `next_q4_buffer`.
#[test]
fn q4_model_lm_head_uses_q4_buffer() {
    let cfg = minimal_q4_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let loads_q4_lm_head = model_src.contains("let lm_head_buf = next_q4_buffer");
    assert!(
        loads_q4_lm_head,
        "Q4_0 model should use next_q4_buffer for lm_head"
    );
}
8178
/// Checks that the `vec_tile` array size in the generated shaders follows the
/// config: `vec_tile[128]` for the minimal config and `vec_tile[8192]` (never
/// `vec_tile[4096]`) once `hidden_size`/`intermediate_size` are raised.
#[test]
fn vec_tile_size_matches_model_dimensions() {
    let small_shaders = generate_metal_shaders(&minimal_config());
    assert!(
        small_shaders.contains("vec_tile[128]"),
        "vec_tile should be sized to max(hidden, intermediate) = 128"
    );

    // Same baseline config with only the two size fields overridden.
    let large_cfg = ModelConfig {
        hidden_size: 2048,
        intermediate_size: 8192,
        ..minimal_config()
    };
    let large_shaders = generate_metal_shaders(&large_cfg);

    let has_scaled_tile = large_shaders.contains("vec_tile[8192]");
    assert!(
        has_scaled_tile,
        "vec_tile should be 8192 for models with intermediate=8192"
    );

    let has_legacy_tile = large_shaders.contains("vec_tile[4096]");
    assert!(
        !has_legacy_tile,
        "vec_tile should NOT be hardcoded to 4096"
    );
}
8203
/// Checks that the generated `Cargo.toml` text mentions the HTTP-server
/// dependencies.
///
/// Note: these are substring checks, so `"serde"` is also satisfied by
/// `"serde_json"`.
#[test]
fn generated_cargo_toml_has_server_deps() {
    let manifest = generate_cargo_toml("my-model");

    let required_deps = [
        ("tiny_http", "Cargo.toml should depend on tiny_http"),
        ("serde", "Cargo.toml should depend on serde"),
        ("serde_json", "Cargo.toml should depend on serde_json"),
    ];
    for (dep, why) in required_deps {
        assert!(manifest.contains(dep), "{why}");
    }
}
8217
/// Checks that the generated `main.rs` text contains the `--serve`/`--port`
/// flags, a `serve` function, and a `tiny_http` server construction.
#[test]
fn generated_main_rs_has_serve_mode() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let expectations = [
        ("--serve", "main.rs should parse --serve flag"),
        ("--port", "main.rs should parse --port flag"),
        ("fn serve(", "main.rs should define serve function"),
        (
            "tiny_http::Server::http",
            "main.rs should create tiny_http server",
        ),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
8240
/// Checks that the generated `main.rs` text mentions the three HTTP routes:
/// `/v1/chat/completions`, `/v1/models`, and `/health`.
#[test]
fn generated_main_rs_has_chat_completions_endpoint() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let routes = [
        (
            "/v1/chat/completions",
            "main.rs should handle /v1/chat/completions endpoint",
        ),
        ("/v1/models", "main.rs should handle /v1/models endpoint"),
        ("/health", "main.rs should handle /health endpoint"),
    ];
    for (route, why) in routes {
        assert!(main_src.contains(route), "{why}");
    }
}
8259
/// Checks that the generated `main.rs` text contains the server-sent-events
/// markers: the `text/event-stream` content type, the chunk object name, and
/// the `[DONE]` sentinel.
#[test]
fn generated_main_rs_has_sse_streaming() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let sse_markers = [
        (
            "text/event-stream",
            "main.rs should set SSE content type for streaming",
        ),
        ("chat.completion.chunk", "main.rs should emit SSE chunks"),
        ("[DONE]", "main.rs should emit [DONE] sentinel"),
    ];
    for (marker, why) in sse_markers {
        assert!(main_src.contains(marker), "{why}");
    }
}
8278
/// Checks that the generated `main.rs` text defines `format_chat_messages`
/// and contains both ChatML delimiters.
#[test]
fn generated_main_rs_has_chat_message_formatting() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let expectations = [
        (
            "fn format_chat_messages",
            "main.rs should define format_chat_messages function",
        ),
        ("<|im_start|>", "main.rs should use ChatML format"),
        ("<|im_end|>", "main.rs should use ChatML format"),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
8297
/// Checks that the generated `main.rs` text declares the `ChatRequest` and
/// `ChatMessage` structs and mentions `Deserialize`.
#[test]
fn generated_main_rs_has_request_types() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let expectations = [
        (
            "struct ChatRequest",
            "main.rs should define ChatRequest struct",
        ),
        (
            "struct ChatMessage",
            "main.rs should define ChatMessage struct",
        ),
        (
            "Deserialize",
            "main.rs should derive Deserialize for request types",
        ),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
8316
/// Checks that the generated model source declares `reset(&mut self)` and
/// contains a `self.pos = 0` assignment somewhere in the file.
#[test]
fn generated_model_has_reset_method() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let expectations = [
        (
            "pub fn reset(&mut self)",
            "model.rs should have a reset() method for multi-request serving",
        ),
        ("self.pos = 0", "reset() should reset position to 0"),
    ];
    for (needle, why) in expectations {
        assert!(model_src.contains(needle), "{why}");
    }
}
8331
/// Checks that the generated `main.rs` text still defines `cli_mode` and
/// calls into the model's forward paths.
///
/// Note: these are substring checks, so `"model.forward"` is also satisfied
/// by `"model.forward_prefill"`.
#[test]
fn generated_main_rs_cli_mode_still_works() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let expectations = [
        ("fn cli_mode(", "main.rs should define cli_mode function"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "model.forward_prefill",
            "main.rs should call model.forward_prefill",
        ),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
8351
/// Checks that the generated shader text declares each of the batched
/// kernels listed below.
#[test]
fn generated_shaders_contain_batch_kernels() {
    let shader_src = generate_metal_shaders(&minimal_config());

    let batch_kernels = [
        (
            "kernel void matmul_vec_batch",
            "shaders should contain matmul_vec_batch kernel",
        ),
        (
            "kernel void matmul_vec_q8_batch",
            "shaders should contain matmul_vec_q8_batch kernel",
        ),
        (
            "kernel void matmul_q8_gemm_batch",
            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)",
        ),
        (
            "kernel void matmul_vec_q4_batch",
            "shaders should contain matmul_vec_q4_batch kernel",
        ),
        (
            "kernel void rms_norm_batch",
            "shaders should contain rms_norm_batch kernel",
        ),
        (
            "kernel void silu_mul_fused_batch",
            "shaders should contain silu_mul_fused_batch kernel",
        ),
        (
            "kernel void add_inplace_batch",
            "shaders should contain add_inplace_batch kernel",
        ),
        (
            "kernel void copy_embedding_batch",
            "shaders should contain copy_embedding_batch kernel",
        ),
    ];
    for (declaration, why) in batch_kernels {
        assert!(shader_src.contains(declaration), "{why}");
    }
}
8391
/// Checks that the generated model source declares a
/// `<name>: ComputePipelineState` field for every batched pipeline.
#[test]
fn generated_model_has_batch_pipelines() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let expected_fields = [
        "matmul_batch_pipeline",
        "matmul_q8_batch_pipeline",
        "matmul_q8_gemm_batch_pipeline",
        "matmul_q4_batch_pipeline",
        "rms_norm_batch_pipeline",
        "rope_batch_pipeline",
        "silu_mul_fused_batch_pipeline",
        "add_inplace_batch_pipeline",
        "copy_embedding_batch_pipeline",
        "attention_batch_pipeline",
        "copy_kv_batch_pipeline",
    ];
    for field in expected_fields {
        let declaration = format!("{field}: ComputePipelineState");
        assert!(
            model_src.contains(&declaration),
            "MetalModel should have {field} field"
        );
    }
}
8416
/// Checks that the generated model source declares a `<name>: Buffer` field
/// for every batched scratch buffer.
#[test]
fn generated_model_has_batch_buffers() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let expected_fields = [
        "batch_hidden_buf",
        "batch_residual_buf",
        "batch_qkv_buf",
        "batch_attn_out_buf",
        "batch_attn_proj_buf",
        "batch_gate_up_buf",
        "batch_ffn_hidden_buf",
        "batch_ffn_out_buf",
        "batch_tokens_buf",
        "batch_positions_buf",
    ];
    for field in expected_fields {
        let declaration = format!("{field}: Buffer");
        assert!(
            model_src.contains(&declaration),
            "MetalModel should have {field} field"
        );
    }
}
8440
/// Checks that the generated model source declares `forward_prefill_batch`
/// and contains the single-token call `self.forward_prefill_batch(&[token_id])`.
#[test]
fn generated_model_has_forward_prefill_batch() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let expectations = [
        (
            "pub fn forward_prefill_batch(&mut self, tokens: &[u32])",
            "MetalModel should have forward_prefill_batch method",
        ),
        (
            "self.forward_prefill_batch(&[token_id])",
            "forward_prefill should delegate to forward_prefill_batch",
        ),
    ];
    for (needle, why) in expectations {
        assert!(model_src.contains(needle), "{why}");
    }
}
8457
/// Checks that the generated model source declares
/// `pub const MAX_BATCH_SIZE: usize = 512`.
#[test]
fn generated_model_has_max_batch_size_constant() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let declares_constant = model_src.contains("pub const MAX_BATCH_SIZE: usize = 512");
    assert!(
        declares_constant,
        "model.rs should define MAX_BATCH_SIZE constant"
    );
}
8468
/// Checks that the generated `forward_prefill_batch` text (from its signature
/// up to the following `pub fn reset`) calls each batched dispatch helper.
#[test]
fn forward_prefill_batch_uses_batch_dispatch() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let start = model_src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &model_src[start..];
    // Everything before `pub fn reset`; the whole tail if that marker is absent.
    let batch_src = tail
        .split_once("\n pub fn reset")
        .map_or(tail, |(before_reset, _)| before_reset);

    let required_helpers = [
        (
            "dispatch_rms_norm_batch",
            "forward_prefill_batch should use dispatch_rms_norm_batch",
        ),
        (
            "dispatch_copy_embedding_batch",
            "forward_prefill_batch should use dispatch_copy_embedding_batch",
        ),
        (
            "dispatch_silu_mul_fused_batch",
            "forward_prefill_batch should use dispatch_silu_mul_fused_batch",
        ),
        (
            "dispatch_attention_batch",
            "forward_prefill_batch should use dispatch_attention_batch",
        ),
        (
            "dispatch_copy_kv_both_batch",
            "forward_prefill_batch should use dispatch_copy_kv_both_batch",
        ),
        (
            "dispatch_rope_qk_batch",
            "forward_prefill_batch should use dispatch_rope_qk_batch",
        ),
    ];
    for (helper, why) in required_helpers {
        assert!(batch_src.contains(helper), "{why}");
    }
}
8512
/// Checks that, for the Q8 config, the generated `forward_prefill_batch` text
/// (up to the following `pub fn reset`) calls `dispatch_matmul_q8_batch`.
#[test]
fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
    let cfg = minimal_q8_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let start = model_src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &model_src[start..];
    // Stop at `pub fn reset` when present, otherwise inspect the whole tail.
    let batch_src = match tail.find("\n pub fn reset") {
        Some(end) => &tail[..end],
        None => tail,
    };

    let uses_q8_batch = batch_src.contains("dispatch_matmul_q8_batch");
    assert!(
        uses_q8_batch,
        "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
    );
}
8532
/// Checks that, for the Q4 config, the generated `forward_prefill_batch` text
/// (up to the following `pub fn reset`) calls `dispatch_matmul_q4_batch`.
#[test]
fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
    let cfg = minimal_q4_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let start = model_src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &model_src[start..];
    // Stop at `pub fn reset` when present, otherwise inspect the whole tail.
    let batch_src = tail
        .find("\n pub fn reset")
        .map_or(tail, |end| &tail[..end]);

    let uses_q4_batch = batch_src.contains("dispatch_matmul_q4_batch");
    assert!(
        uses_q4_batch,
        "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
    );
}
8552
/// Checks that the generated `main.rs` text mentions `forward_prefill_batch`.
#[test]
fn generated_main_rs_uses_batched_prefill() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let uses_batched_prefill = main_src.contains("forward_prefill_batch");
    assert!(
        uses_batched_prefill,
        "main.rs should use forward_prefill_batch for prompt tokens"
    );
}
8563
/// Checks that, for the minimal (f32) config, the generated
/// `forward_prefill_batch` text calls `dispatch_matmul_batch` and neither of
/// the quantized batch matmul helpers.
#[test]
fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let start = model_src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &model_src[start..];
    // Everything before `pub fn reset`; the whole tail if that marker is absent.
    let batch_src = tail
        .split_once("\n pub fn reset")
        .map_or(tail, |(before_reset, _)| before_reset);

    let uses_f32_batch = batch_src.contains("dispatch_matmul_batch");
    assert!(
        uses_f32_batch,
        "f32 forward_prefill_batch should use dispatch_matmul_batch"
    );

    let quantized_helpers = [
        (
            "dispatch_matmul_q8_batch",
            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch",
        ),
        (
            "dispatch_matmul_q4_batch",
            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch",
        ),
    ];
    for (helper, why) in quantized_helpers {
        assert!(!batch_src.contains(helper), "{why}");
    }
}
8592
/// Checks that the generated `forward()` text reads `embed_buf.contents()`
/// and uses `copy_nonoverlapping`, and does not issue
/// `dispatch_copy_offset` on `embed_buf`.
#[test]
fn forward_uses_cpu_embedding_lookup() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let start = model_src
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &model_src[start..];

    // forward() ends at the first marker (in this priority order) that is
    // present; with neither present the rest of the file is inspected.
    let end_markers = ["\n pub fn forward_profile", "\n pub fn forward_prefill"];
    let end = end_markers
        .iter()
        .find_map(|marker| tail.find(*marker))
        .unwrap_or(tail.len());
    let forward_src = &tail[..end];

    let required = [
        (
            "embed_buf.contents()",
            "forward() should access embed_buf via CPU unified memory for embedding lookup",
        ),
        (
            "copy_nonoverlapping",
            "forward() should use ptr::copy_nonoverlapping for CPU embedding copy",
        ),
    ];
    for (needle, why) in required {
        assert!(forward_src.contains(needle), "{why}");
    }

    let uses_gpu_copy = forward_src.contains("dispatch_copy_offset(&enc, &self.embed_buf");
    assert!(
        !uses_gpu_copy,
        "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
    );
}
8624
/// Checks that the generated model source declares `forward_profile` and
/// contains the `[profile]` tag and the three timing identifiers.
#[test]
fn forward_profile_method_exists() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let expectations = [
        (
            "pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>",
            "MetalModel should have forward_profile() method",
        ),
        (
            "[profile]",
            "forward_profile() should print timing with [profile] prefix",
        ),
        (
            "d_embed",
            "forward_profile() should measure embedding time",
        ),
        ("d_layers", "forward_profile() should measure layer time"),
        ("d_logits", "forward_profile() should measure logits time"),
    ];
    for (needle, why) in expectations {
        assert!(model_src.contains(needle), "{why}");
    }
}
8652
/// Checks that the generated `main.rs` text mentions the `--profile` flag and
/// `forward_profile`.
#[test]
fn generated_cli_has_profile_flag() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let expectations = [
        ("--profile", "CLI should support --profile flag"),
        (
            "forward_profile",
            "CLI should call forward_profile when --profile is set",
        ),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
8667
/// Checks that the generated `main.rs` text contains a `yield_now()` call.
#[test]
fn generated_cli_has_thermal_yield() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let yields = main_src.contains("yield_now()");
    assert!(
        yields,
        "CLI generation loop should include thread::yield_now() for thermal management"
    );
}
8678
/// Checks that the opening of the generated `forward()` does not assert
/// `self.pos > 0`, and that the generated model source mentions `self.pos`.
///
/// Only a window of at most 400 bytes starting at the `forward()` signature is
/// inspected for the forbidden assertion.
#[test]
fn generated_forward_handles_single_token_prompt() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");

    // Clamp the window to the end of the generated source and back it off to a
    // char boundary, so the slice itself can never panic when the output is
    // shorter than the window or contains multi-byte characters.
    // `forward_start` comes from `find`, so it is always a valid boundary and
    // the loop cannot step below it.
    let mut forward_end = model_rs.len().min(forward_start + 400);
    while !model_rs.is_char_boundary(forward_end) {
        forward_end -= 1;
    }
    let forward_body = &model_rs[forward_start..forward_end];

    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );

    // This check looks at the whole generated file, not just forward().
    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
8707
/// Checks that the opening of the generated `reset()` sets `self.pos = 0` and
/// clears `self.prev_cmd`.
///
/// Only a window of at most 200 bytes starting at the `reset()` signature is
/// inspected.
#[test]
fn generated_reset_clears_kv_cache_position() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let reset_start = model_rs
        .find("pub fn reset(&mut self)")
        .expect("reset() must exist");

    // Clamp the window to the end of the generated source and back it off to a
    // char boundary. If `reset()` sits within 200 bytes of the end of the
    // file, the test then reports a readable assertion failure (or passes)
    // instead of panicking on an out-of-range slice. `reset_start` comes from
    // `find`, so it is always a valid boundary and the loop cannot step below
    // it.
    let mut reset_end = model_rs.len().min(reset_start + 200);
    while !model_rs.is_char_boundary(reset_end) {
        reset_end -= 1;
    }
    let reset_body = &model_rs[reset_start..reset_end];

    assert!(
        reset_body.contains("self.pos = 0"),
        "reset() must set self.pos = 0"
    );

    assert!(
        reset_body.contains("self.prev_cmd = None"),
        "reset() should clear prev_cmd for clean command buffer state"
    );
}
8732
/// Checks that the generated `format_chat_messages` iterates `messages` and
/// emits the assistant header within its first 500 bytes, and that the text
/// from `fn serve(` onward calls `model.reset()`.
#[test]
fn generated_serve_handles_empty_messages_gracefully() {
    let cfg = minimal_config();
    let main_src = generate_main_rs("test-model", &cfg).unwrap();

    let format_start = main_src
        .find("fn format_chat_messages")
        .expect("format_chat_messages must exist");
    // Inspect up to 500 bytes of the function, or less if the file ends first.
    let format_tail = &main_src[format_start..];
    let format_window = &format_tail[..format_tail.len().min(500)];

    let format_expectations = [
        (
            "for msg in messages",
            "format_chat_messages should iterate over the messages slice",
        ),
        (
            "<|im_start|>assistant",
            "format_chat_messages should always append assistant prompt header",
        ),
    ];
    for (needle, why) in format_expectations {
        assert!(format_window.contains(needle), "{why}");
    }

    let serve_start = main_src
        .find("fn serve(")
        .expect("serve function must exist");
    let resets_model = main_src[serve_start..].contains("model.reset()");
    assert!(
        resets_model,
        "serve function should reset model between requests"
    );
}
8768
/// Checks that the generated `forward()` text increments `self.pos` by one
/// (accepting either `+= 1` or `+=1` spacing).
#[test]
fn generated_model_forward_increments_pos() {
    let cfg = minimal_config();
    let model_src = generate_model_rs(&cfg).unwrap();

    let start = model_src
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &model_src[start..];

    // forward() ends at the first marker (in this priority order) that is
    // present; with none present the rest of the file is inspected.
    let end_markers = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ];
    let end = end_markers
        .iter()
        .find_map(|marker| tail.find(*marker))
        .unwrap_or(tail.len());
    let forward_src = &tail[..end];

    let increments_pos = ["self.pos += 1", "self.pos +=1"]
        .iter()
        .any(|spelling| forward_src.contains(*spelling));
    assert!(
        increments_pos,
        "forward() must increment self.pos after processing a token"
    );
}
8792}