1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16 #[error("graph has no model config")]
18 MissingConfig,
19
20 #[error("I/O error: {0}")]
22 Io(#[from] std::io::Error),
23
24 #[error("format error: {0}")]
26 Fmt(#[from] std::fmt::Error),
27}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Turn an arbitrary model name into a valid Cargo package / binary name.
///
/// Rules:
/// - ASCII letters are lowercased and ASCII digits are kept.
/// - Every other character, including `_`, whitespace, punctuation and any
///   non-ASCII character, becomes `-`.
/// - Leading and trailing `-` are trimmed.
///
/// Cargo rejects empty package names and names that start with a digit.
/// So an input with no usable characters yields `"model"`, and a result
/// starting with a digit is prefixed with `"model-"`.
/// For example, `"Llama 3.1 8B"` becomes `"llama-3-1-8b"` and
/// `"3B-Instruct"` becomes `"model-3b-instruct"`.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    // Restrict to ASCII. `char::is_alphanumeric` is Unicode-aware and would
    // let through characters such as `é` or `²`, which are not safe in
    // package/bin names.
    let mapped: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() {
                c.to_ascii_lowercase()
            } else {
                '-'
            }
        })
        .collect();
    let trimmed = mapped.trim_matches('-');

    match trimmed.chars().next() {
        None => FALLBACK.to_string(),
        Some(first) if first.is_ascii_digit() => format!("{FALLBACK}-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121 let attn_scores_size = config.max_seq_len.min(4096);
126 r#"//
127// Auto-generated by ForgeLLM Metal codegen.
128// Metal Shading Language compute kernels for transformer inference.
129//
130// Optimized with simdgroup cooperative reductions, shared memory vector
131// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
132// lane, and fast:: math intrinsics for Apple Silicon throughput.
133//
134
135#include <metal_stdlib>
136#include <metal_simdgroup_matrix>
137using namespace metal;
138
139// ── Constants ───────────────────────────────────────────────────────────
140// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
141// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
142// per shared memory load vs 4-row, improving ILP and reducing launches.
143constant constexpr uint SIMDGROUPS_PER_TG = 8;
144constant constexpr uint ROWS_PER_SIMDGROUP = 8;
145constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
146
147// ── matmul_vec ──────────────────────────────────────────────────────────
148// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
149// Uses simdgroup cooperative dot product with shared memory vector caching
150// and float4 vectorized loads. Each simdgroup processes 8 rows for better
151// shared memory reuse (8x vector reuse per load) and instruction-level
152// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
153kernel void matmul_vec(
154 device const float* matrix [[buffer(0)]],
155 device const float* vector [[buffer(1)]],
156 device float* output [[buffer(2)]],
157 constant uint& rows [[buffer(3)]],
158 constant uint& cols [[buffer(4)]],
159 uint tgid [[threadgroup_position_in_grid]],
160 uint tid [[thread_index_in_threadgroup]],
161 uint simd_lane [[thread_index_in_simdgroup]],
162 uint simd_id [[simdgroup_index_in_threadgroup]])
163{
164 // Cooperatively load vector into threadgroup shared memory
165 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
166 for (uint i = tid; i < cols; i += 256) {
167 vec_tile[i] = vector[i];
168 }
169 threadgroup_barrier(mem_flags::mem_threadgroup);
170
171 // Each simdgroup handles 8 consecutive rows
172 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
173 if (row_base >= rows) return;
174
175 uint base0 = row_base * cols;
176 uint base1 = (row_base + 1) * cols;
177 uint base2 = (row_base + 2) * cols;
178 uint base3 = (row_base + 3) * cols;
179 uint base4 = (row_base + 4) * cols;
180 uint base5 = (row_base + 5) * cols;
181 uint base6 = (row_base + 6) * cols;
182 uint base7 = (row_base + 7) * cols;
183
184 // float4 vectorized accumulation across 8 rows
185 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
186 float4 sum4_0 = float4(0.0f);
187 float4 sum4_1 = float4(0.0f);
188 float4 sum4_2 = float4(0.0f);
189 float4 sum4_3 = float4(0.0f);
190 float4 sum4_4 = float4(0.0f);
191 float4 sum4_5 = float4(0.0f);
192 float4 sum4_6 = float4(0.0f);
193 float4 sum4_7 = float4(0.0f);
194
195 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
196 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
197 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
198 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
199 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
200 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
201 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
202 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
203 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
204 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
205 }
206
207 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
208 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
209 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
210 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
211 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
212 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
213 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
214 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
215
216 // Handle remaining elements (cols not divisible by 128)
217 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
218 float vv = vec_tile[j];
219 sum0 += matrix[base0 + j] * vv;
220 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
221 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
222 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
223 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
224 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
225 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
226 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
227 }
228
229 // Simdgroup hardware warp-level reduction
230 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
231 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
232 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
233 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
234
235 // Only first lane writes the results
236 if (simd_lane == 0) {
237 if (row_base < rows) output[row_base] = sum0;
238 if (row_base + 1 < rows) output[row_base + 1] = sum1;
239 if (row_base + 2 < rows) output[row_base + 2] = sum2;
240 if (row_base + 3 < rows) output[row_base + 3] = sum3;
241 if (row_base + 4 < rows) output[row_base + 4] = sum4;
242 if (row_base + 5 < rows) output[row_base + 5] = sum5;
243 if (row_base + 6 < rows) output[row_base + 6] = sum6;
244 if (row_base + 7 < rows) output[row_base + 7] = sum7;
245 }
246}
247
248// ── rms_norm ────────────────────────────────────────────────────────────
249// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
250// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
251// via shared memory for minimal synchronization overhead.
252kernel void rms_norm(
253 device const float* input [[buffer(0)]],
254 device const float* weight [[buffer(1)]],
255 device float* output [[buffer(2)]],
256 constant uint& n [[buffer(3)]],
257 constant float& eps [[buffer(4)]],
258 uint tid [[thread_index_in_threadgroup]])
259{
260 // Each thread accumulates partial sum-of-squares
261 float sum_sq = 0.0f;
262 for (uint i = tid; i < n; i += 256) {
263 float v = input[i];
264 sum_sq += v * v;
265 }
266
267 // Simdgroup-level reduction (hardware warp sum)
268 sum_sq = simd_sum(sum_sq);
269
270 // Cross-simdgroup reduction via shared memory
271 threadgroup float shared[8];
272 uint simd_id = tid / 32;
273 uint simd_lane = tid % 32;
274 if (simd_lane == 0) {
275 shared[simd_id] = sum_sq;
276 }
277 threadgroup_barrier(mem_flags::mem_threadgroup);
278
279 // First thread computes final inverse RMS
280 if (tid == 0) {
281 float total = 0.0f;
282 for (uint i = 0; i < 8; i++) {
283 total += shared[i];
284 }
285 shared[0] = fast::rsqrt(total / float(n) + eps);
286 }
287 threadgroup_barrier(mem_flags::mem_threadgroup);
288
289 float inv_rms = shared[0];
290
291 // Normalize
292 for (uint i = tid; i < n; i += 256) {
293 output[i] = input[i] * inv_rms * weight[i];
294 }
295}
296
297// ── rope ────────────────────────────────────────────────────────────────
298// Rotary Position Embedding applied in-place.
299// Each thread handles one (head, pair) combination.
300kernel void rope(
301 device float* data [[buffer(0)]],
302 constant uint& num_heads [[buffer(1)]],
303 constant uint& head_dim [[buffer(2)]],
304 constant uint& pos [[buffer(3)]],
305 constant float& theta [[buffer(4)]],
306 uint id [[thread_position_in_grid]])
307{
308 uint half_dim = head_dim / 2;
309 uint total_pairs = num_heads * half_dim;
310 if (id >= total_pairs) return;
311
312 uint h = id / half_dim;
313 uint i = id % half_dim;
314 uint off = h * head_dim;
315
316 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
317 float angle = float(pos) * freq;
318 float c = cos(angle);
319 float s = sin(angle);
320
321 float x0 = data[off + 2 * i];
322 float x1 = data[off + 2 * i + 1];
323 data[off + 2 * i] = x0 * c - x1 * s;
324 data[off + 2 * i + 1] = x0 * s + x1 * c;
325}
326
327// ── softmax ─────────────────────────────────────────────────────────────
328// Numerically stable softmax over a 1-D array.
329// Single-threadgroup kernel with cooperative reduction.
330kernel void softmax(
331 device float* data [[buffer(0)]],
332 constant uint& n [[buffer(1)]],
333 uint tid [[thread_index_in_threadgroup]],
334 uint tg_size [[threads_per_threadgroup]])
335{
336 threadgroup float shared_val[256];
337
338 // Pass 1: find max
339 float local_max = -INFINITY;
340 for (uint i = tid; i < n; i += tg_size) {
341 local_max = max(local_max, data[i]);
342 }
343 shared_val[tid] = local_max;
344 threadgroup_barrier(mem_flags::mem_threadgroup);
345
346 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
347 if (tid < stride) {
348 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
349 }
350 threadgroup_barrier(mem_flags::mem_threadgroup);
351 }
352 float max_val = shared_val[0];
353 threadgroup_barrier(mem_flags::mem_threadgroup);
354
355 // Pass 2: exp and sum
356 float local_sum = 0.0f;
357 for (uint i = tid; i < n; i += tg_size) {
358 float e = fast::exp(data[i] - max_val);
359 data[i] = e;
360 local_sum += e;
361 }
362 shared_val[tid] = local_sum;
363 threadgroup_barrier(mem_flags::mem_threadgroup);
364
365 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
366 if (tid < stride) {
367 shared_val[tid] += shared_val[tid + stride];
368 }
369 threadgroup_barrier(mem_flags::mem_threadgroup);
370 }
371 float inv_sum = 1.0f / shared_val[0];
372 threadgroup_barrier(mem_flags::mem_threadgroup);
373
374 // Pass 3: normalize
375 for (uint i = tid; i < n; i += tg_size) {
376 data[i] *= inv_sum;
377 }
378}
379
380// ── silu_mul ────────────────────────────────────────────────────────────
381// Fused SiLU activation * element-wise multiply:
382// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
383kernel void silu_mul(
384 device const float* gate [[buffer(0)]],
385 device const float* up [[buffer(1)]],
386 device float* output [[buffer(2)]],
387 constant uint& n [[buffer(3)]],
388 uint id [[thread_position_in_grid]])
389{
390 if (id >= n) return;
391 float g = gate[id];
392 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
393}
394
395// ── silu_mul_fused ─────────────────────────────────────────────────────
396// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
397// gate = gate_up[0..n], up = gate_up[n..2*n]
398// output[i] = silu(gate_up[i]) * gate_up[n + i]
399kernel void silu_mul_fused(
400 device const float* gate_up [[buffer(0)]],
401 device float* output [[buffer(1)]],
402 constant uint& n [[buffer(2)]],
403 uint id [[thread_position_in_grid]])
404{
405 if (id >= n) return;
406 float g = gate_up[id];
407 float u = gate_up[n + id];
408 output[id] = (g / (1.0f + fast::exp(-g))) * u;
409}
410
411// ── elementwise_add ─────────────────────────────────────────────────────
412// Residual connection: output[i] = a[i] + b[i]
413kernel void elementwise_add(
414 device const float* a [[buffer(0)]],
415 device const float* b [[buffer(1)]],
416 device float* output [[buffer(2)]],
417 constant uint& n [[buffer(3)]],
418 uint id [[thread_position_in_grid]])
419{
420 if (id >= n) return;
421 output[id] = a[id] + b[id];
422}
423
424// ── copy_buffer ─────────────────────────────────────────────────────────
425// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
426// transitions. Used for KV cache updates and embedding lookup.
427kernel void copy_buffer(
428 device const float* src [[buffer(0)]],
429 device float* dst [[buffer(1)]],
430 constant uint& count [[buffer(2)]],
431 uint id [[thread_position_in_grid]])
432{
433 if (id < count) dst[id] = src[id];
434}
435
436// ── copy_offset ─────────────────────────────────────────────────────────
437// Copy with source offset (in floats). Used for embedding table lookup
438// where we need to copy a specific row from a large table.
439kernel void copy_offset(
440 device const float* src [[buffer(0)]],
441 device float* dst [[buffer(1)]],
442 constant uint& src_offset [[buffer(2)]], // in floats
443 constant uint& count [[buffer(3)]],
444 uint id [[thread_position_in_grid]])
445{
446 if (id < count) dst[id] = src[src_offset + id];
447}
448
449// ── add_inplace ─────────────────────────────────────────────────────────
450// In-place residual connection: a[i] += b[i]
451// Avoids a separate blit copy for residual add, reducing encoder overhead.
452kernel void add_inplace(
453 device float* a [[buffer(0)]],
454 device const float* b [[buffer(1)]],
455 constant uint& n [[buffer(2)]],
456 uint id [[thread_position_in_grid]])
457{
458 if (id >= n) return;
459 a[id] += b[id];
460}
461
462// ── matmul_vec_q8 ─────────────────────────────────────────────────────
463// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
464// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
465// Operates directly on quantized weights to halve memory bandwidth vs f32,
466// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
467//
468// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
469// because int8->float conversion doubles register demand. Fully unrolled
470// inner loop with float4 vector loads from shared memory eliminates loop
471// overhead and enables better instruction scheduling.
472// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
473constant constexpr uint Q8_ROWS_PER_SG = 4;
474constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
475
476// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
477// doubles ALU work, so the same register budget applies).
478constant constexpr uint Q4_ROWS_PER_SG = 4;
479constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
480
481kernel void matmul_vec_q8(
482 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
483 device const float* vector [[buffer(1)]], // f32 input
484 device float* output [[buffer(2)]],
485 constant uint& rows [[buffer(3)]],
486 constant uint& cols [[buffer(4)]], // number of elements per row
487 uint tgid [[threadgroup_position_in_grid]],
488 uint tid [[thread_index_in_threadgroup]],
489 uint simd_lane [[thread_index_in_simdgroup]],
490 uint simd_id [[simdgroup_index_in_threadgroup]])
491{
492 // Load vector into shared memory
493 threadgroup float vec_tile[VEC_TILE_SIZE];
494 for (uint i = tid; i < cols; i += 256) {
495 vec_tile[i] = vector[i];
496 }
497 threadgroup_barrier(mem_flags::mem_threadgroup);
498
499 // Each simdgroup handles 4 consecutive rows (lower register pressure)
500 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
501 if (row_base >= rows) return;
502
503 // Q8_0: each block is 34 bytes for 32 elements
504 uint blocks_per_row = cols / 32;
505 uint row_bytes = blocks_per_row * 34;
506
507 // Pointers to each row's Q8_0 data
508 device const uchar* r0 = matrix + row_base * row_bytes;
509 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
510 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
511 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
512
513 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
514
515 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
516 uint bb = blk * 34;
517 uint vb = blk * 32;
518
519 // Prefetch all 4 scales
520 float sc0 = float(*(device const half*)(r0 + bb));
521 float sc1 = float(*(device const half*)(r1 + bb));
522 float sc2 = float(*(device const half*)(r2 + bb));
523 float sc3 = float(*(device const half*)(r3 + bb));
524
525 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
526 // Q8_0 block layout where the int8 data starts at offset +2 from a
527 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
528 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
529 // reduction in memory transactions. Metal's char16/packed_char16 are
530 // reserved types and packed_*int4 require >=4-byte alignment which
531 // this layout does not provide, so packed_short4 is the widest valid
532 // vectorized load.
533 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
534 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
535 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
536 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
537
538 // Load all 8 float4 vector values for this 32-element block from shared memory
539 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
540 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
541 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
542 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
543 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
544 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
545 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
546 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
547
548 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
549 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
550 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
551 short4 _s = short4(SHORT4); \
552 char2 _a = as_type<char2>(_s.x); \
553 char2 _b = as_type<char2>(_s.y); \
554 char2 _c = as_type<char2>(_s.z); \
555 char2 _d = as_type<char2>(_s.w); \
556 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
557 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
558 }
559
560 float4 f0, f1;
561 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
562
563 // Row 0: 4 short4 loads cover 32 int8 weights
564 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
565 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
566 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
567 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
568
569 // Row 1
570 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
571 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
572 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
573 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
574
575 // Row 2
576 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
577 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
578 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
579 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
580
581 // Row 3
582 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
583 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
584 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
585 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
586
587 #undef Q8_UNPACK8
588
589 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
590 }
591
592 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
593 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
594
595 if (simd_lane == 0) {
596 if (row_base < rows) output[row_base] = sum0;
597 if (row_base + 1 < rows) output[row_base + 1] = sum1;
598 if (row_base + 2 < rows) output[row_base + 2] = sum2;
599 if (row_base + 3 < rows) output[row_base + 3] = sum3;
600 }
601}
602
603// ── matmul_vec_q4 ─────────────────────────────────────────────────────
604// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
605// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
606// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
607// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
608//
609// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
610// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
611kernel void matmul_vec_q4(
612 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
613 device const float* vector [[buffer(1)]], // f32 input
614 device float* output [[buffer(2)]],
615 constant uint& rows [[buffer(3)]],
616 constant uint& cols [[buffer(4)]], // number of elements per row
617 uint tgid [[threadgroup_position_in_grid]],
618 uint tid [[thread_index_in_threadgroup]],
619 uint simd_lane [[thread_index_in_simdgroup]],
620 uint simd_id [[simdgroup_index_in_threadgroup]])
621{
622 // Load vector into shared memory
623 threadgroup float vec_tile[VEC_TILE_SIZE];
624 for (uint i = tid; i < cols; i += 256) {
625 vec_tile[i] = vector[i];
626 }
627 threadgroup_barrier(mem_flags::mem_threadgroup);
628
629 // Each simdgroup handles 4 consecutive rows
630 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
631 if (row_base >= rows) return;
632
633 // Q4_0: each block is 18 bytes for 32 elements
634 uint blocks_per_row = cols / 32;
635 uint row_bytes = blocks_per_row * 18;
636
637 // Pointers to each row's Q4_0 data
638 device const uchar* r0 = matrix + row_base * row_bytes;
639 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
640 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
641 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
642
643 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
644
645 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
646 uint bb = blk * 18;
647 uint vb = blk * 32;
648
649 // Prefetch all 4 scales
650 float sc0 = float(*(device const half*)(r0 + bb));
651 float sc1 = float(*(device const half*)(r1 + bb));
652 float sc2 = float(*(device const half*)(r2 + bb));
653 float sc3 = float(*(device const half*)(r3 + bb));
654
655 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
656 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
657 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
658 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
659 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
660
661 // Load 8 float4 vector values for 32 elements from shared memory
662 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
663 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
664 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
665 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
666 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
667 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
668 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
669 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
670 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
671
672 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
673 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
674 float bd0=0, bd1=0, bd2=0, bd3=0;
675 uchar4 b;
676
677 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
678 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
679 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
680 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
681 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
682 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
683 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
684 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
685 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
686 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
687 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
688 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
689 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
690 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
691 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
692 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
693 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
694
695 // Row 1
696 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
697 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
698 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
699 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
700 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
701 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
702 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
703 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
704 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
705 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
706 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
707 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
708 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
709 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
710 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
711 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
712
713 // Row 2
714 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
715 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
716 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
717 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
718 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
719 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
720 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
721 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
722 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
723 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
724 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
725 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
726 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
727 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
728 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
729 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
730
731 // Row 3
732 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
733 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
734 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
735 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
736 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
737 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
738 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
739 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
740 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
741 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
742 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
743 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
744 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
745 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
746 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
747 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
748
749 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
750 }
751
752 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
753 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
754
755 if (simd_lane == 0) {
756 if (row_base < rows) output[row_base] = sum0;
757 if (row_base + 1 < rows) output[row_base + 1] = sum1;
758 if (row_base + 2 < rows) output[row_base + 2] = sum2;
759 if (row_base + 3 < rows) output[row_base + 3] = sum3;
760 }
761}
762
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then the weighted sum of V.
// Each threadgroup handles one query head. The loop strides (8 simdgroups,
// 256 threads) and the 8-entry reduction arrays mean the kernel must be
// dispatched with exactly 256 threads per threadgroup.
//
// Grouped-query mapping: query head h reads KV head
// h / (num_heads / num_kv_heads). This assumes num_heads is a multiple of
// num_kv_heads. NOTE(review): confirm the host validates this.
//
// Buffers:
//   q:       [num_heads * head_dim]                  current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
// Only the first seq_len positions of each cache are read.
kernel void attention(
    device const float* q            [[buffer(0)]],
    device const float* k_cache      [[buffer(1)]],
    device const float* v_cache      [[buffer(2)]],
    device float*       output       [[buffer(3)]],
    constant uint&      seq_len      [[buffer(4)]],
    constant uint&      num_heads    [[buffer(5)]],
    constant uint&      num_kv_heads [[buffer(6)]],
    constant uint&      head_dim     [[buffer(7)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;

    // Step 1: Compute attention scores Q·K^T / sqrt(head_dim).
    // Simdgroup g handles positions g, g+8, g+16, ...; its 32 lanes split the
    // head_dim dot product and simd_sum combines them, so lane 0 holds the
    // full dot product.
    // scores holds one float per position, so seq_len must not exceed
    // ATTN_SCORES_SIZE. Nothing in this kernel checks that, and a longer
    // seq_len would write past the threadgroup array.
    // NOTE(review): confirm the host caps seq_len at ATTN_SCORES_SIZE.
    threadgroup float scores[ATTN_SCORES_SIZE]; // sized by the generated constant (intended to match the MAX_SEQ_LEN cap)

    for (uint s = simd_id; s < seq_len; s += 8) {
        uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
        float dot = 0.0;
        for (uint d = simd_lane; d < head_dim; d += 32) {
            dot += q[q_off + d] * k_cache[k_off + d];
        }
        dot = simd_sum(dot);
        if (simd_lane == 0) {
            scores[s] = dot * fast::rsqrt(float(head_dim));
        }
    }
    // Every score must be visible before any thread starts the softmax.
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: Numerically stable softmax over scores (cooperative).
    // Find max: per-thread strided max -> per-simdgroup max -> thread 0
    // folds the 8 simdgroup maxima and broadcasts through shared_max[0].
    float local_max = -INFINITY;
    for (uint s = tid; s < seq_len; s += 256) {
        local_max = max(local_max, scores[s]);
    }
    local_max = simd_max(local_max);
    threadgroup float shared_max[8];
    if (simd_lane == 0) shared_max[simd_id] = local_max;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float m = shared_max[0];
        for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
        shared_max[0] = m;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float max_val = shared_max[0];

    // Exp and sum: each thread rewrites only its own strided scores, then
    // the partial sums are reduced the same way as the max. Thread 0 stores
    // the reciprocal of the total in shared_sum[0].
    float local_sum = 0.0;
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] = fast::exp(scores[s] - max_val);
        local_sum += scores[s];
    }
    local_sum = simd_sum(local_sum);
    threadgroup float shared_sum[8];
    if (simd_lane == 0) shared_sum[simd_id] = local_sum;
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (tid == 0) {
        float total = 0.0;
        for (uint i = 0; i < 8; i++) total += shared_sum[i];
        shared_sum[0] = 1.0 / total;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    float inv_sum = shared_sum[0];

    // Normalize in place; Step 3 reads every score, so barrier afterwards.
    for (uint s = tid; s < seq_len; s += 256) {
        scores[s] *= inv_sum;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 3: Weighted sum of V: output = scores · V
    // Each thread handles a range of head_dim dimensions (threads with
    // tid >= head_dim idle in this step).
    // Process 4 sequence positions at a time for better ILP and reduced
    // loop overhead (4 score reads, 4 V loads per iteration).
    uint seq_len4 = seq_len & ~3u; // largest multiple of 4 <= seq_len
    uint v_stride = num_kv_heads * head_dim;
    for (uint d = tid; d < head_dim; d += 256) {
        float acc = 0.0;
        uint v_base = kv_head * head_dim + d;
        for (uint s = 0; s < seq_len4; s += 4) {
            float sc0 = scores[s];
            float sc1 = scores[s + 1];
            float sc2 = scores[s + 2];
            float sc3 = scores[s + 3];
            acc += sc0 * v_cache[s * v_stride + v_base]
                 + sc1 * v_cache[(s+1) * v_stride + v_base]
                 + sc2 * v_cache[(s+2) * v_stride + v_base]
                 + sc3 * v_cache[(s+3) * v_stride + v_base];
        }
        // Remaining 0..3 positions.
        for (uint s = seq_len4; s < seq_len; s++) {
            acc += scores[s] * v_cache[s * v_stride + v_base];
        }
        output[q_off + d] = acc;
    }
}
879
880// ── Batched prefill kernels ────────────────────────────────────────────
881// These kernels process M input vectors against the same weight matrix
882// in a single dispatch, converting mat-vec into mat-mat for better GPU
883// utilization during prompt prefill.
884
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization. For every token t < num_tokens:
//   output[t, :] = input[t, :] * rsqrt(mean(input[t, :]^2) + eps) * weight[:]
// One threadgroup per token (grid: M threadgroups). The element stride of
// 256 and the 8-slot partial array require 256 threads per threadgroup
// (8 simdgroups of 32 lanes).
kernel void rms_norm_batch(
    device const float* input      [[buffer(0)]],  // [M, n]
    device const float* weight     [[buffer(1)]],  // [n]
    device float*       output     [[buffer(2)]],  // [M, n]
    constant uint&      n          [[buffer(3)]],
    constant float&     eps        [[buffer(4)]],
    constant uint&      num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]])
{
    // One partial sum of squares per simdgroup; slot 0 is later reused to
    // broadcast the inverse RMS to the whole threadgroup.
    threadgroup float partial[8];

    // Surplus threadgroups have no token to normalize.
    if (tgid >= num_tokens) return;

    uint row = tgid * n;
    uint sg_lane = tid % 32;
    uint sg_index = tid / 32;

    // Strided per-thread accumulation of x^2 over this token's row.
    float acc = 0.0f;
    for (uint col = tid; col < n; col += 256) {
        float x = input[row + col];
        acc += x * x;
    }

    // Reduce within the simdgroup, then publish one value per simdgroup.
    acc = simd_sum(acc);
    if (sg_lane == 0) {
        partial[sg_index] = acc;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // A single thread folds the 8 partials and stores 1 / rms.
    if (tid == 0) {
        float sum_all = 0.0f;
        for (uint k = 0; k < 8; ++k) {
            sum_all += partial[k];
        }
        partial[0] = fast::rsqrt(sum_all / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    float scale = partial[0];
    for (uint col = tid; col < n; col += 256) {
        output[row + col] = input[row + col] * scale * weight[col];
    }
}
934
// ── rope_batch ─────────────────────────────────────────────────────────
// Batched rotary position embedding, applied in place.
// data layout: [M, num_heads * head_dim]; positions[t] is token t's position.
// One thread per (token, head, pair). Pair p of a head is the adjacent
// element pair (2p, 2p+1), rotated by the angle
//   positions[t] / theta^(2p / head_dim).
kernel void rope_batch(
    device float*       data       [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint&      num_heads  [[buffer(1)]],
    constant uint&      head_dim   [[buffer(2)]],
    device const uint*  positions  [[buffer(3)]],  // [M] position per token
    constant float&     theta      [[buffer(4)]],
    constant uint&      num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head = head_dim / 2;
    uint pairs_per_token = num_heads * pairs_per_head;

    // Threads past the last (token, head, pair) triple have nothing to do.
    if (id >= num_tokens * pairs_per_token) return;

    uint token = id / pairs_per_token;
    uint in_token = id % pairs_per_token;
    uint head = in_token / pairs_per_head;
    uint pair = in_token % pairs_per_head;

    uint even = token * (num_heads * head_dim) + head * head_dim + 2 * pair;
    uint odd = even + 1;

    float exponent = 2.0f * float(pair) / float(head_dim);
    float inv_freq = 1.0f / pow(theta, exponent);
    float angle = float(positions[token]) * inv_freq;
    float cos_a = cos(angle);
    float sin_a = sin(angle);

    float re = data[even];
    float im = data[odd];
    data[even] = re * cos_a - im * sin_a;
    data[odd] = re * sin_a + im * cos_a;
}
969
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-gated multiply for a batch. Each token's gate_up row holds the
// gate half followed by the up half ([M, 2*n]). One thread per output element:
//   output[t*n + i] = silu(gate_up[t*2*n + i]) * gate_up[t*2*n + n + i]
// with silu(g) = g / (1 + exp(-g)).
kernel void silu_mul_fused_batch(
    device const float* gate_up    [[buffer(0)]],  // [M, 2*n]
    device float*       output     [[buffer(1)]],  // [M, n]
    constant uint&      n          [[buffer(2)]],
    constant uint&      num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    // Surplus threads beyond M * n do nothing.
    if (id >= num_tokens * n) return;

    uint token = id / n;
    uint col = id % n;
    uint row = token * 2 * n;

    float gate = gate_up[row + col];
    float up = gate_up[row + n + col];
    float activated = gate / (1.0f + fast::exp(-gate));
    output[token * n + col] = activated * up;
}
989
// ── add_inplace_batch ──────────────────────────────────────────────────
// In-place residual add over a flattened batch: a[id] += b[id] for every
// id < total (total = M * n). One thread per element; surplus threads idle.
kernel void add_inplace_batch(
    device float*       a     [[buffer(0)]],  // [M * n]
    device const float* b     [[buffer(1)]],  // [M * n]
    constant uint&      total [[buffer(2)]],  // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
1001
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M embedding rows into a contiguous [M, dim] batch buffer.
// tokens[m] selects row tokens[m] of the embedding table. One thread copies
// one float. Token ids are not range-checked against vocab_size here.
kernel void copy_embedding_batch(
    device const float* embed      [[buffer(0)]],  // [vocab_size, dim]
    device float*       output     [[buffer(1)]],  // [M, dim]
    device const uint*  tokens     [[buffer(2)]],  // [M]
    constant uint&      dim        [[buffer(3)]],
    constant uint&      num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    uint count = num_tokens * dim;
    if (id >= count) return;

    uint slot = id / dim;  // which token in the batch
    uint col = id % dim;   // element within the embedding row
    uint src = tokens[slot] * dim + col;
    output[id] = embed[src];
}
1019
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched f32 matrix-vector multiply: process M input vectors against the
// same weight matrix, outputs[t, r] = dot(matrix[r, :], inputs[t, :]).
// Grid: ceil(rows/ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * ceil(rows/ROWS_PER_TG) + row_group.
// Each threadgroup handles one (token, row_group) pair. Each simdgroup
// computes 8 consecutive rows; the code below is unrolled for exactly 8, so
// ROWS_PER_SIMDGROUP is presumably 8 (confirm against its definition).
//
// The token's input vector is staged in threadgroup memory, so cols must not
// exceed VEC_TILE_SIZE. The vector-tile load strides by 256, so the
// threadgroup must have 256 threads. The float4 loads from matrix are only
// 16-byte aligned when cols is a multiple of 4.
// NOTE(review): confirm callers guarantee cols % 4 == 0.
kernel void matmul_vec_batch(
    device const float* matrix     [[buffer(0)]],  // [rows, cols] weight
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // token is uniform across the threadgroup, so this early return cannot
    // split the threadgroup around the barrier below.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Per-simdgroup exit is safe here: no barriers follow.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    // Row offsets for the 8 rows. Rows >= rows are computed but never read:
    // every access to rows 1..7 below is guarded.
    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Largest multiple of 128 <= cols: each iteration covers 32 lanes x one
    // float4, i.e. 128 columns.
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    // Vectorized main loop: lane L reads columns [j, j+4) with j = 4L + 128k.
    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Horizontal add of each float4 accumulator.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail for the final cols - cols_vec4 (< 128) columns.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Combine the 32 lanes' partial sums for each row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 writes the in-range rows of this token's output row.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1121
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors:
//   outputs[t, r] = sum_k dequant(matrix[r, k]) * inputs[t, k]
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups, flattened so that
// tgid = token * ceil(rows/Q8_ROWS_PER_TG) + row_group.
//
// Q8_0 layout: each row is cols/32 blocks of 34 bytes (a half scale followed
// by 32 int8 quants); any cols % 32 remainder is ignored. Each simdgroup
// computes 4 consecutive rows. The code is unrolled for 4, so Q8_ROWS_PER_SG
// is presumably 4. Its lanes stride over blocks, and the per-row partials
// are combined with simd_sum. The token's input vector is staged in
// threadgroup memory (256-thread strided load), so cols must not exceed
// VEC_TILE_SIZE.
//
// When rows is not a multiple of 4, the row pointers past the end of the
// matrix are clamped to the last valid row. No lane then reads outside the
// weight buffer, and the duplicate sums are discarded by the guarded stores.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // token is uniform across the threadgroup, so this return cannot split
    // the threadgroup around the barrier below.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Per-simdgroup exit is safe here: no barriers follow.
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Clamp trailing rows of this 4-row strip to the last valid row so the
    // unconditional loads below never run past the weight buffer.
    // row_base < rows here, so last_row >= row_base.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // first input element covered by this block

        // Per-block half-precision scales (first 2 bytes of each block).
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads, halving the memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 int8 quants, memory order) into two
        // float4s: low byte of each short first.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        // Apply the block scale once per block rather than per element.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the 32 lanes' partial sums for each row.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only in-range rows are stored; sums of clamped rows are discarded.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1237
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers Q8_ROWS_PER_TG rows and TOKENS_PER_TG_Q8
// consecutive tokens. The Q8_0 weight blocks are fetched once from device
// memory and reused for every token in the tile, so weight bandwidth is
// 1/TOKENS_PER_TG_Q8 that of the per-token dispatch.
//
// Grid: (ceil(rows/Q8_ROWS_PER_TG), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each simdgroup computes 4 consecutive rows. The accumulators are unrolled
// for 4 rows, so Q8_ROWS_PER_SG is presumably 4. Its lanes stride over
// blocks and combine partials with simd_sum. Token vectors are read directly
// from device memory inside the block loop rather than staged in threadgroup
// memory, so cols is not bounded by threadgroup memory. Any cols % 32
// remainder is ignored (Q8_0 blocks are 32 wide).
//
// When rows is not a multiple of 4, the row pointers past the end of the
// matrix are clamped to the last valid row. No lane then reads outside the
// weight buffer, and the duplicate sums are discarded by the guarded stores.
//
// TOKENS_PER_TG_Q8 must stay 4: the accumulators and per-token bodies below
// are hand-unrolled for exactly four tokens.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint2 tgid     [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    // No barriers in this kernel, so per-simdgroup early exit is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // How many tokens in this tile are valid? (1..TOKENS_PER_TG_Q8)
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Clamp trailing rows of this 4-row strip to the last valid row so the
    // unconditional weight loads never run past the weight buffer.
    // row_base < rows here, so last_row >= row_base.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators: s<token><row>, 4 tokens x 4 rows per simdgroup.
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of this block within a row
        uint vb = blk * 32;   // first input element covered by this block

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Split one packed_short4 (8 int8 quants, memory order) into two
        // float4s: low byte of each short first.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        // Unpack all 4 rows x 8 float4 quants (unscaled; the block scale is
        // applied after the dot products). They are kept in locals for the
        // duration of the block and dotted against every token's vector.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        #undef Q8_UNPACK8

        // ── For each token, read its 32-float slice and accumulate ──
        // Token 0 (always valid: tok_count >= 1).
        {
            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
            float4 v0 = *(device const float4*)(a0);
            float4 v1 = *(device const float4*)(a0 + 4);
            float4 v2 = *(device const float4*)(a0 + 8);
            float4 v3 = *(device const float4*)(a0 + 12);
            float4 v4 = *(device const float4*)(a0 + 16);
            float4 v5 = *(device const float4*)(a0 + 20);
            float4 v6 = *(device const float4*)(a0 + 24);
            float4 v7 = *(device const float4*)(a0 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
        }
        // Token 1
        if (tok_count > 1) {
            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
            float4 v0 = *(device const float4*)(a1);
            float4 v1 = *(device const float4*)(a1 + 4);
            float4 v2 = *(device const float4*)(a1 + 8);
            float4 v3 = *(device const float4*)(a1 + 12);
            float4 v4 = *(device const float4*)(a1 + 16);
            float4 v5 = *(device const float4*)(a1 + 20);
            float4 v6 = *(device const float4*)(a1 + 24);
            float4 v7 = *(device const float4*)(a1 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
        }
        // Token 2
        if (tok_count > 2) {
            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
            float4 v0 = *(device const float4*)(a2);
            float4 v1 = *(device const float4*)(a2 + 4);
            float4 v2 = *(device const float4*)(a2 + 8);
            float4 v3 = *(device const float4*)(a2 + 12);
            float4 v4 = *(device const float4*)(a2 + 16);
            float4 v5 = *(device const float4*)(a2 + 20);
            float4 v6 = *(device const float4*)(a2 + 24);
            float4 v7 = *(device const float4*)(a2 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
        }
        // Token 3
        if (tok_count > 3) {
            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
            float4 v0 = *(device const float4*)(a3);
            float4 v1 = *(device const float4*)(a3 + 4);
            float4 v2 = *(device const float4*)(a3 + 8);
            float4 v3 = *(device const float4*)(a3 + 12);
            float4 v4 = *(device const float4*)(a3 + 16);
            float4 v5 = *(device const float4*)(a3 + 20);
            float4 v6 = *(device const float4*)(a3 + 24);
            float4 v7 = *(device const float4*)(a3 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
        }
    }

    // simdgroup reduction
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Lane 0 stores valid (token, row) results; sums of clamped rows and of
    // tokens past tok_count are discarded.
    if (simd_lane == 0) {
        device float* o0 = outputs + (tok_base + 0) * rows;
        if (row_base < rows) o0[row_base] = s00;
        if (row_base + 1 < rows) o0[row_base + 1] = s01;
        if (row_base + 2 < rows) o0[row_base + 2] = s02;
        if (row_base + 3 < rows) o0[row_base + 3] = s03;

        if (tok_count > 1) {
            device float* o1 = outputs + (tok_base + 1) * rows;
            if (row_base < rows) o1[row_base] = s10;
            if (row_base + 1 < rows) o1[row_base + 1] = s11;
            if (row_base + 2 < rows) o1[row_base + 2] = s12;
            if (row_base + 3 < rows) o1[row_base + 3] = s13;
        }
        if (tok_count > 2) {
            device float* o2 = outputs + (tok_base + 2) * rows;
            if (row_base < rows) o2[row_base] = s20;
            if (row_base + 1 < rows) o2[row_base + 1] = s21;
            if (row_base + 2 < rows) o2[row_base + 2] = s22;
            if (row_base + 3 < rows) o2[row_base + 3] = s23;
        }
        if (tok_count > 3) {
            device float* o3 = outputs + (tok_base + 3) * rows;
            if (row_base < rows) o3[row_base] = s30;
            if (row_base + 1 < rows) o3[row_base + 1] = s31;
            if (row_base + 2 < rows) o3[row_base + 2] = s32;
            if (row_base + 3 < rows) o3[row_base + 3] = s33;
        }
    }
}
1463
1464// ── matmul_q8_mma ──────────────────────────────────────────────────────
1465// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
1466// simdgroup_matrix tiles (simdgroup_multiply_accumulate). This dispatches
1467// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
1468// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
1469//
1470// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
1471// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
1472// simdgroup_matrix<float, 8, 8> accumulator. Weight bytes are cooperatively
1473// dequantized into threadgroup memory once per block and reused by all
1474// simdgroups in the tile.
1475//
1476// Assumptions (verified in the dispatch helper, falls back otherwise):
1477// * cols % 32 == 0 (one Q8_0 block per K chunk)
1478// * rows % 16 == 0 (tile-aligned; true for all supported architectures)
1479// * num_tokens may be any value; partial row at the tile boundary is handled
1480// via a scratch copy path.
1481constant constexpr uint MMA_TOK_TILE = 16;
1482constant constexpr uint MMA_ROW_TILE = 16;
1483
1484kernel void matmul_q8_mma(
1485 device const uchar* matrix [[buffer(0)]], // Q8_0 [rows, cols/32 * 34]
1486 device const float* inputs [[buffer(1)]], // [M, cols]
1487 device float* outputs [[buffer(2)]], // [M, rows]
1488 constant uint& num_tokens [[buffer(3)]],
1489 constant uint& rows [[buffer(4)]],
1490 constant uint& cols [[buffer(5)]],
1491 uint2 tgid [[threadgroup_position_in_grid]],
1492 uint tid [[thread_index_in_threadgroup]],
1493 uint simd_id [[simdgroup_index_in_threadgroup]])
1494{
1495 uint row_base = tgid.x * MMA_ROW_TILE;
1496 uint tok_base = tgid.y * MMA_TOK_TILE;
1497 if (row_base >= rows || tok_base >= num_tokens) return;
1498
1499 // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
1500 threadgroup float w_tile[MMA_ROW_TILE * 32];
1501 threadgroup float t_tile[MMA_TOK_TILE * 32];
1502
1503 // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
1504 uint sg_tok_base = (simd_id / 2) * 8; // row within output tile (token dim)
1505 uint sg_row_base = (simd_id % 2) * 8; // col within output tile (row dim)
1506
1507 simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
1508
1509 uint blocks_per_row = cols / 32;
1510 uint row_bytes = blocks_per_row * 34;
1511
1512 for (uint blk = 0; blk < blocks_per_row; blk++) {
1513 // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
1514 // 512 floats / 128 threads = 4 floats per thread.
1515 {
1516 uint base = tid * 4;
1517 for (uint ii = 0; ii < 4; ii++) {
1518 uint idx = base + ii;
1519 uint r = idx / 32;
1520 uint k = idx % 32;
1521 device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
1522 float sc = float(*(device const half*)rp);
1523 int ival = int(*(device const int8_t*)(rp + 2 + k));
1524 w_tile[r * 32 + k] = float(ival) * sc;
1525 }
1526 }
1527
1528 // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
1529 {
1530 uint base = tid * 4;
1531 for (uint ii = 0; ii < 4; ii++) {
1532 uint idx = base + ii;
1533 uint m = idx / 32;
1534 uint k = idx % 32;
1535 uint tok = tok_base + m;
1536 t_tile[m * 32 + k] = (tok < num_tokens)
1537 ? inputs[tok * cols + blk * 32 + k]
1538 : 0.0f;
1539 }
1540 }
1541
1542 threadgroup_barrier(mem_flags::mem_threadgroup);
1543
1544 // ── 4 × (8×8×8) MMA over the K=32 chunk ──
1545 // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k] (M×K, no transpose)
1546 // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k] (loaded transposed → K×R)
1547 // C[m, r] += A[m, k] * B[k, r]
1548 for (uint k_sub = 0; k_sub < 4; k_sub++) {
1549 simdgroup_matrix<float, 8, 8> A, B;
1550 simdgroup_load(A,
1551 t_tile + sg_tok_base * 32 + k_sub * 8,
1552 32,
1553 ulong2(0, 0),
1554 false);
1555 simdgroup_load(B,
1556 w_tile + sg_row_base * 32 + k_sub * 8,
1557 32,
1558 ulong2(0, 0),
1559 true);
1560 simdgroup_multiply_accumulate(C, A, B, C);
1561 }
1562
1563 threadgroup_barrier(mem_flags::mem_threadgroup);
1564 }
1565
1566 // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
1567 // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
1568 uint out_tok = tok_base + sg_tok_base;
1569 uint out_row = row_base + sg_row_base;
1570 bool full_tok = (out_tok + 8 <= num_tokens);
1571 if (full_tok) {
1572 // Fast path: entire 8×8 sub-tile is in-bounds.
1573 simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
1574 } else if (out_tok < num_tokens) {
1575 // Partial row at the last token tile: stage in per-simdgroup scratch
1576 // and scalar-copy the valid rows.
1577 threadgroup float scratch[4 * 64];
1578 simdgroup_store(C, scratch + simd_id * 64, 8);
1579 simdgroup_barrier(mem_flags::mem_threadgroup);
1580 uint lane = tid % 32;
1581 if (lane == 0) {
1582 uint valid = num_tokens - out_tok; // 1..7
1583 for (uint m = 0; m < valid; m++) {
1584 device float* dst = outputs + (out_tok + m) * rows + out_row;
1585 threadgroup const float* src = scratch + simd_id * 64 + m * 8;
1586 dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
1587 dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
1588 }
1589 }
1590 }
1591}
1592
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0) -> C_a
//     (sg_tok_idx*8, sg_row_half*16 + 8) -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration: 3 fragment loads feed 2 MMAs, versus 2 loads per MMA in the
// 16×16 single-accumulator kernel. Because each threadgroup covers 4× the
// output area of the 16×16 tile, a quarter as many threadgroups are launched.
//
// Assumptions (per the original note: verified in the dispatch helper, with a
// fallback otherwise; that helper is not part of this shader source):
//   * cols % 32 == 0
//   * rows % 32 == 0
// The kernel itself does no bounds checks on rows or cols; only the token
// dimension is guarded.
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix [[buffer(0)]],   // per row: 34-byte blocks (half scale + 32 int8)
    device const float* inputs [[buffer(1)]],   // [num_tokens, cols], row-major
    device float* outputs [[buffer(2)]],        // [num_tokens, rows], row-major
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // tgid.x selects a 32-row weight tile, tgid.y a 32-token tile.
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and the
    // threadgroup_barrier calls below are never reached divergently.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total
    // (the partial-store path below declares another 4 KB of scratch).
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx = simd_id / 2; // 0..3
    uint sg_row_half = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    // Each 32-wide K block occupies 34 bytes: a half scale followed by 32
    // int8 quants (see the dequant loop below).
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        // idx = tid*4 + ii walks the tile row-major: r = weight row, k = K offset.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load. Tokens past num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Both tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        // A is tokens × K as stored; B_a/B_b are loaded transposed (K × rows),
        // so C[tok, row] += sum_k t_tile[tok, k] * w_tile[row, k].
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration's tile writes from clobbering data that
        // other simdgroups may still be loading.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators. rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial last token tile: stage both accumulators in this simdgroup's
        // 128-float slice of scratch, then copy only the valid token rows.
        threadgroup float scratch[8 * 2 * 64]; // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        // Make this simdgroup's scratch writes visible to its lane 0.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32; // assumes 32-wide simdgroups laid out contiguously by tid
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1733
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory. That halves the tile footprint: 4 KB of tiles here vs 8 KB in
// matmul_q8_mma32. Both kernels also declare a 4 KB float `scratch` array for
// the partial-token store path, so the declared totals are about 8 KB vs
// 12 KB. The intent is more concurrent threadgroups per GPU core on Apple
// Silicon.
// NOTE(review): the occupancy gain is the design goal, not something this
// source demonstrates.
//
// Precision:
//   * The Q8_0 scale is read as f16.
//   * Each int8 × scale product is computed in float and then rounded to
//     half (11-bit significand), so dequantized weights are not bit-identical
//     to the f32-tile kernel.
//   * Token activations are narrowed f32 → f16. That loses precision and
//     would overflow beyond ±65504.
//   * The `float` accumulator used by `simdgroup_multiply_accumulate`
//     protects only the accumulation, not that narrowing.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each. The intended win vs mma32 is occupancy at moderate
// prefill lengths where the GPU is wave-starved.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix [[buffer(0)]],   // per row: 34-byte blocks (half scale + 32 int8)
    device const float* inputs [[buffer(1)]],   // [num_tokens, cols], row-major
    device float* outputs [[buffer(2)]],        // [num_tokens, rows], row-major
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit (depends only on tgid), so the barriers below
    // are never divergent.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Same simdgroup layout as matmul_q8_mma32: simd_id = 2*sg_tok_idx + sg_row_half.
    uint sg_tok_idx = simd_id / 2;
    uint sg_row_half = simd_id % 2;
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    // Accumulators stay in float even though the input fragments are half.
    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    // 34 bytes per 32-wide K block: half scale + 32 int8 quants.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (256 threads × 4 elements).
        // The product is formed in float, then rounded once to half.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); tokens past
        // num_tokens are zero-filled.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be complete before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // A = tokens × K (as stored); B_a/B_b loaded transposed (K × rows).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Prevent the next iteration's tile writes from racing with reads above.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Row dimension is assumed tile-aligned (rows % 32 == 0); only the last
    // token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial tile: stage in this simdgroup's 128-float slice of scratch
        // (4 KB declared in total) and copy only the valid token rows.
        threadgroup float scratch[8 * 2 * 64];
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        // Make the scratch writes visible to lane 0 of this simdgroup.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32; // assumes 32-wide simdgroups
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1862
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each. Per simdgroup:
//   C_00 (tok 0..8, row 0..8)    C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)   C_11 (tok 8..16, row 8..16)
// Each simd_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration the kernel loads two A fragments and two B fragments,
// then runs **four** MMA instructions: A_top with both B's and A_bot with
// both B's. That is 4 MMAs per 4 fragment loads versus 2 per 3 in the
// 2-accumulator kernel, i.e. 1.5× the MMA-per-load ratio.
//
// It also halves the thread count per threadgroup (128 threads instead of
// 256). The intent is better occupancy on Apple GPUs where the
// concurrent-thread budget, rather than shared-memory size, is the tighter
// limit.
// NOTE(review): the occupancy benefit is not demonstrated by this source.
//
// Thread mapping assumes exactly 128 threads per threadgroup (tid*8 + ii
// covers the 1024 tile elements once).
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix [[buffer(0)]],   // per row: 34-byte blocks (half scale + 32 int8)
    device const float* inputs [[buffer(1)]],   // [num_tokens, cols], row-major
    device float* outputs [[buffer(2)]],        // [num_tokens, rows], row-major
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit, so the barriers below are never divergent.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total (the partial-store scratch adds 4 KB more).
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2; // 0..1
    uint sg_row_q = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    // Float accumulators; inputs are half fragments.
    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    // 34 bytes per 32-wide K block: half scale + 32 int8 quants.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16); tokens past num_tokens are zero.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        // Tiles must be complete before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        // A_* are tokens × K as stored; B_* are loaded transposed (K × rows).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevent the next iteration's tile writes from racing with reads above.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles. Full-tile fast path assumes full 16×16 valid
    // (rows assumed tile-aligned; only tokens are checked).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch. Quadrants that are
        // entirely past num_tokens also land here; the per-row checks below
        // then skip every store.
        threadgroup float scratch[4 * 4 * 64]; // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base + 0, 8);
        simdgroup_store(C_01, scratch + sg_base + 64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        // Make the scratch writes visible to lane 0 of this simdgroup.
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32; // assumes 32-wide simdgroups
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base + 0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base + 64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2017
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input fragments and the accumulators are simdgroup_matrix<half>.
// Motivation (per the original note, not demonstrated here): on Apple
// Silicon, FP16 `simdgroup_multiply_accumulate` runs at up to 2x the FP32
// rate. If Q8_0 precision holds through half accumulation, this kernel can
// raise effective FLOP throughput on matmul-bound prefill.
//
// Numerical notes:
//   * Each dequantized weight is an int8 value times an f16 scale, rounded to
//     half (11-bit significand).
//   * Token activations are narrowed f32 → f16. They are assumed bounded
//     (post-RMSNorm ≈ O(1)).
//   * Summing 2048-wide K (1B hidden) or 8192-wide K (FFN) in half may lose
//     precision, or overflow above ±65504, on extreme values.
//   * Because the inputs are already quantized, the per-product error floor
//     is higher than the half rounding error in the common case.
//   * Verify correctness on 135M / 1B / 3B before enabling.
//
// Layout matches matmul_q8_mma32_h4: 128 threads = 4 simdgroups, each owning
// one 16×16 quadrant of the 32×32 output tile as a 2×2 grid of 8×8
// accumulators.
//
// Epilogue: half accumulators cannot be simdgroup_store'd directly into the
// float output buffer, so EVERY tile (full or partial) is staged in
// per-simdgroup threadgroup scratch and widened to f32. All 32 lanes share
// that copy (8 elements each) rather than lane 0 performing all 256 writes
// serially.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix [[buffer(0)]],   // per row: 34-byte blocks (half scale + 32 int8)
    device const float* inputs [[buffer(1)]],   // [num_tokens, cols], row-major
    device float* outputs [[buffer(2)]],        // [num_tokens, rows], row-major
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Threadgroup-uniform exit, so the barriers below are never divergent.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 2 KB each.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    // 34 bytes per 32-wide K block: half scale + 32 int8 quants.
    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant: 128 threads × 8 halves = 32*32.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        // Cooperative token tile load (f32 → f16); tokens past num_tokens are zero.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        // Tiles must be complete before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // A_* are tokens × K as stored; B_* are loaded transposed (K × rows).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Prevent the next iteration's tile writes from racing with reads above.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;

    // Per-simdgroup 256-half slice: [C_00 | C_01 | C_10 | C_11], each an 8×8
    // row-major block with row stride 8.
    threadgroup half scratch[4 * 4 * 64];
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base + 0, 8);
    simdgroup_store(C_01, scratch + sg_base + 64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    // Make every lane's scratch writes visible to the whole simdgroup.
    simdgroup_barrier(mem_flags::mem_threadgroup);

    // Each lane copies 8 of the 256 staged values (assumes 32-wide simdgroups
    // laid out contiguously by tid). Consecutive lanes cover consecutive
    // output columns, so stores are grouped into 8-float row segments.
    uint lane = tid % 32;
    for (uint e = lane; e < 256; e += 32) {
        uint acc = e / 64;        // 0: C_00, 1: C_01, 2: C_10, 3: C_11
        uint m = (e % 64) / 8;    // row inside the 8×8 block (token offset)
        uint j = e % 8;           // column inside the 8×8 block (row offset)
        uint tok = ((acc & 2u) ? out_tok_bot : out_tok_top) + m;
        uint row = ((acc & 1u) ? out_row_hi : out_row_lo) + j;
        // Only the last token tile can be partial; rows are tile-aligned.
        if (tok < num_tokens) {
            outputs[tok * rows + row] = float(scratch[sg_base + e]);
        }
    }
}
2155
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast add: out[token, i] += bias[i] for every token in
// 0..num_tokens and every i in 0..rows, where `out` is row-major
// [num_tokens, rows]. One thread per output element; threads whose id is
// past num_tokens * rows do nothing.
// Used for Qwen2 QKV bias after the fused qkv matmul.
kernel void add_bias_batch(
    device float* out [[buffer(0)]],        // [num_tokens, rows]
    device const float* bias [[buffer(1)]], // [rows]
    constant uint& num_tokens [[buffer(2)]],
    constant uint& rows [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id < num_tokens * rows) {
        // id % rows is the column within this element's token row.
        out[id] += bias[id % rows];
    }
}
2172
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup stages one token's input vector in threadgroup memory.
// Each simdgroup then computes four consecutive output rows (assumes
// Q4_ROWS_PER_SG == 4; confirm), with lanes striding over 32-element blocks
// and a simd_sum reduction at the end.
//
// Assumptions, not checked here:
//   * 256 threads per threadgroup (the staging loop strides by 256).
//   * cols % 32 == 0.
//   * cols <= VEC_TILE_SIZE.
//
// Block layout (18 bytes): f16 scale, then 16 bytes of packed nibbles.
// Byte i holds element i in its low nibble and element i+16 in its high
// nibble, each stored with a +8 bias.
//
// Tail rows: if rows is not a multiple of 4, the last simdgroup's extra row
// pointers are clamped to the final valid row. That way no bytes past the end
// of `matrix` are read; results for those rows are dropped by the guarded
// stores.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]],      // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]],  // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Threadgroup-uniform exit (depends only on tgid), safe before the barrier.
    if (token >= num_tokens) return;

    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Simdgroup-uniform exit; no threadgroup barriers follow.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // Clamp the trailing rows of a partial final group so they never address
    // memory past the last row of `matrix` (row_base < rows, so rows >= 1).
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;
        uint vb = blk * 32;

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 hold block elements 0..15 (low nibbles), v4..v7 hold 16..31 (high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Rows beyond `rows` (computed from clamped pointers) are never stored.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        output[row_base] = sum0; // row_base < rows was checked above
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2273
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter the K (or V) slice of each token in a batch into the KV cache.
//   src: [M, src_stride] floats (batch QKV buffer); the kv_dim-wide slice
//        starts at float offset src_offset inside each row.
//   dst: contiguous [max_seq, kv_dim] cache; batch token t is written to
//        cache row base_pos + t.
// One thread per copied float; ids at or beyond M * kv_dim do nothing.
kernel void copy_kv_batch(
    device const float* src [[buffer(0)]],       // batch QKV buffer
    device float* dst [[buffer(1)]],             // KV cache
    constant uint& M [[buffer(2)]],              // num tokens in batch
    constant uint& kv_dim [[buffer(3)]],         // floats per KV vector
    constant uint& base_pos [[buffer(4)]],       // starting position in cache
    constant uint& src_stride [[buffer(5)]],     // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]],     // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    if (id >= M * kv_dim) return;

    uint tok = id / kv_dim;
    uint col = id % kv_dim;
    uint src_row_start = tok * src_stride + src_offset;
    uint dst_row_start = (base_pos + tok) * kv_dim;
    dst[dst_row_start + col] = src[src_row_start + col];
}
2296
2297// ── attention_batch ───────────────────────────────────────────────────
2298// Batched causal attention for prefill. Processes M tokens in one dispatch.
2299// Each threadgroup handles one (token, head) pair.
2300// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
2301// Causal masking: token i can only attend to positions 0..base_pos+i.
2302kernel void attention_batch(
2303 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2304 device const float* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim]
2305 device const float* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim]
2306 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2307 constant uint& M [[buffer(4)]], // num tokens in batch
2308 constant uint& base_pos [[buffer(5)]], // starting position in KV cache
2309 constant uint& num_heads [[buffer(6)]],
2310 constant uint& num_kv_heads [[buffer(7)]],
2311 constant uint& head_dim [[buffer(8)]],
2312 constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
2313 uint tgid [[threadgroup_position_in_grid]],
2314 uint tid [[thread_index_in_threadgroup]],
2315 uint simd_lane [[thread_index_in_simdgroup]],
2316 uint simd_id [[simdgroup_index_in_threadgroup]])
2317{
2318 // Grid: M * num_heads threadgroups
2319 uint token_idx = tgid / num_heads;
2320 uint head = tgid % num_heads;
2321 if (token_idx >= M) return;
2322
2323 uint kv_head = head / (num_heads / num_kv_heads);
2324 uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx
2325
2326 // Q offset uses strided layout (from batch QKV buffer)
2327 uint q_off = token_idx * q_stride + head * head_dim;
2328 // Output is contiguous [M, num_heads * head_dim]
2329 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2330
2331 // Shared memory for attention scores — sized to the effective max_seq_len
2332 // (4096 for all supported models) so long-context attention doesn't overflow.
2333 threadgroup float scores[ATTN_SCORES_SIZE];
2334
2335 // Step 1: Q * K^T with simdgroup reduction
2336 for (uint s = simd_id; s < seq_len; s += 8) {
2337 uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
2338 float dot = 0.0;
2339 for (uint d = simd_lane; d < head_dim; d += 32) {
2340 dot += q_batch[q_off + d] * k_cache[k_off + d];
2341 }
2342 dot = simd_sum(dot);
2343 if (simd_lane == 0) {
2344 scores[s] = dot * fast::rsqrt(float(head_dim));
2345 }
2346 }
2347 threadgroup_barrier(mem_flags::mem_threadgroup);
2348
2349 // Step 2: Softmax (cooperative)
2350 float local_max = -INFINITY;
2351 for (uint s = tid; s < seq_len; s += 256) {
2352 local_max = max(local_max, scores[s]);
2353 }
2354 local_max = simd_max(local_max);
2355 threadgroup float shared_max[8];
2356 if (simd_lane == 0) shared_max[simd_id] = local_max;
2357 threadgroup_barrier(mem_flags::mem_threadgroup);
2358 if (tid == 0) {
2359 float m = shared_max[0];
2360 for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
2361 shared_max[0] = m;
2362 }
2363 threadgroup_barrier(mem_flags::mem_threadgroup);
2364 float max_val = shared_max[0];
2365
2366 float local_sum = 0.0;
2367 for (uint s = tid; s < seq_len; s += 256) {
2368 scores[s] = fast::exp(scores[s] - max_val);
2369 local_sum += scores[s];
2370 }
2371 local_sum = simd_sum(local_sum);
2372 threadgroup float shared_sum[8];
2373 if (simd_lane == 0) shared_sum[simd_id] = local_sum;
2374 threadgroup_barrier(mem_flags::mem_threadgroup);
2375 if (tid == 0) {
2376 float total = 0.0;
2377 for (uint i = 0; i < 8; i++) total += shared_sum[i];
2378 shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
2379 }
2380 threadgroup_barrier(mem_flags::mem_threadgroup);
2381 float inv_sum = shared_sum[0];
2382 for (uint s = tid; s < seq_len; s += 256) {
2383 scores[s] *= inv_sum;
2384 }
2385 threadgroup_barrier(mem_flags::mem_threadgroup);
2386
2387 // Step 3: scores * V using float4 vectorized loads
2388 // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
2389 // This is much better than the scalar version where only 64 of 256 threads are active.
2390 uint v_stride = num_kv_heads * head_dim;
2391 uint head_dim4 = head_dim / 4;
2392 for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
2393 uint d = d4 * 4;
2394 float4 acc = float4(0.0);
2395 uint v_base = kv_head * head_dim + d;
2396 uint seq_len4 = seq_len & ~3u;
2397 for (uint s = 0; s < seq_len4; s += 4) {
2398 float sc0 = scores[s];
2399 float sc1 = scores[s + 1];
2400 float sc2 = scores[s + 2];
2401 float sc3 = scores[s + 3];
2402 acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
2403 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
2404 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
2405 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
2406 }
2407 for (uint s = seq_len4; s < seq_len; s++) {
2408 acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
2409 }
2410 *(device float4*)(output_batch + out_off + d) = acc;
2411 }
2412 // Handle remaining dimensions not divisible by 4 (scalar fallback)
2413 for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
2414 float acc = 0.0;
2415 uint v_base = kv_head * head_dim + d;
2416 for (uint s = 0; s < seq_len; s++) {
2417 acc += scores[s] * v_cache[s * v_stride + v_base];
2418 }
2419 output_batch[out_off + d] = acc;
2420 }
2421}
2422
2423// ── attention_flash_batch ─────────────────────────────────────────────
2424// Streaming attention with online softmax. Same grid as attention_batch
2425// (M × num_heads threadgroups, one per (token, head) pair) but the scores
2426// matrix is never materialized. K/V positions are processed in a tile of
2427// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
2428// the standard flash-attention recurrence:
2429//
2430// m_new = max(m_old, tile_max)
2431// alpha = exp(m_old - m_new)
2432// l_new = alpha * l_old + sum(exp(S - m_new))
2433// O_new = alpha * O_old + sum(exp(S - m_new) * V)
2434// O_final = O / l
2435//
2436// This removes the `scores[2048]` cap in attention_batch (which silently
2437// overflows for prompts with seq_len > 2048) and keeps threadgroup memory
2438// use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
2439//
2440// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
2441constant constexpr uint FLASH_K_TILE = 32;
2442constant constexpr uint FLASH_MAX_HEAD_DIM = 256;
2443
2444kernel void attention_flash_batch(
2445 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2446 device const float* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim]
2447 device const float* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim]
2448 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2449 constant uint& M [[buffer(4)]],
2450 constant uint& base_pos [[buffer(5)]],
2451 constant uint& num_heads [[buffer(6)]],
2452 constant uint& num_kv_heads [[buffer(7)]],
2453 constant uint& head_dim [[buffer(8)]],
2454 constant uint& q_stride [[buffer(9)]],
2455 uint tgid [[threadgroup_position_in_grid]],
2456 uint tid [[thread_index_in_threadgroup]],
2457 uint simd_lane [[thread_index_in_simdgroup]],
2458 uint simd_id [[simdgroup_index_in_threadgroup]])
2459{
2460 uint token_idx = tgid / num_heads;
2461 uint head = tgid % num_heads;
2462 if (token_idx >= M) return;
2463
2464 uint kv_head = head / (num_heads / num_kv_heads);
2465 uint seq_len = base_pos + token_idx + 1; // causal: attend to [0, base_pos + token_idx]
2466 uint q_off = token_idx * q_stride + head * head_dim;
2467 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2468
2469 // Threadgroup state:
2470 // q_sh: Q vector for this (token, head), loaded once
2471 // o_sh: running output vector, updated each K tile
2472 // scores_sh: scores for the current K tile only
2473 // stats: [running max, running sum] (see flash-attention recurrence)
2474 threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
2475 threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
2476 threadgroup float scores_sh[FLASH_K_TILE];
2477 threadgroup float stats[2];
2478 threadgroup float sg_scratch[8]; // simdgroup-level reduction buffer
2479
2480 // --- Load Q (one row) and zero the running O ---
2481 for (uint d = tid; d < head_dim; d += 256) {
2482 q_sh[d] = q_batch[q_off + d];
2483 o_sh[d] = 0.0f;
2484 }
2485 if (tid == 0) {
2486 stats[0] = -INFINITY;
2487 stats[1] = 0.0f;
2488 }
2489 threadgroup_barrier(mem_flags::mem_threadgroup);
2490
2491 float scale = fast::rsqrt(float(head_dim));
2492 uint v_stride = num_kv_heads * head_dim;
2493 uint v_base = kv_head * head_dim;
2494
2495 // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
2496 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
2497 uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
2498
2499 // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
2500 // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
2501 for (uint ti = simd_id; ti < tile_n; ti += 8) {
2502 uint k_off = (kv_base + ti) * v_stride + v_base; // same layout as V stride
2503 float dot = 0.0f;
2504 for (uint d = simd_lane; d < head_dim; d += 32) {
2505 dot += q_sh[d] * k_cache[k_off + d];
2506 }
2507 dot = simd_sum(dot);
2508 if (simd_lane == 0) {
2509 scores_sh[ti] = dot * scale;
2510 }
2511 }
2512 threadgroup_barrier(mem_flags::mem_threadgroup);
2513
2514 // [2] Tile max via cooperative reduction.
2515 float local_max = -INFINITY;
2516 for (uint s = tid; s < tile_n; s += 256) {
2517 local_max = max(local_max, scores_sh[s]);
2518 }
2519 local_max = simd_max(local_max);
2520 if (simd_lane == 0) {
2521 sg_scratch[simd_id] = local_max;
2522 }
2523 threadgroup_barrier(mem_flags::mem_threadgroup);
2524 // [3] Merge with running max, compute alpha, rescale running l.
2525 float m_new;
2526 float alpha;
2527 if (tid == 0) {
2528 float tile_max = sg_scratch[0];
2529 for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
2530 float m_old = stats[0];
2531 m_new = max(m_old, tile_max);
2532 // First iteration: m_old = -inf → alpha = 0 (reset O).
2533 alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2534 stats[0] = m_new;
2535 stats[1] *= alpha;
2536 // Broadcast via sg_scratch.
2537 sg_scratch[0] = alpha;
2538 sg_scratch[1] = m_new;
2539 }
2540 threadgroup_barrier(mem_flags::mem_threadgroup);
2541 alpha = sg_scratch[0];
2542 m_new = sg_scratch[1];
2543
2544 // [4] Rescale running output by alpha, then compute exp(scores - m_new).
2545 for (uint d = tid; d < head_dim; d += 256) {
2546 o_sh[d] *= alpha;
2547 }
2548 for (uint s = tid; s < tile_n; s += 256) {
2549 scores_sh[s] = fast::exp(scores_sh[s] - m_new);
2550 }
2551 threadgroup_barrier(mem_flags::mem_threadgroup);
2552
2553 // [5] Tile sum → update running l.
2554 float local_sum = 0.0f;
2555 for (uint s = tid; s < tile_n; s += 256) {
2556 local_sum += scores_sh[s];
2557 }
2558 local_sum = simd_sum(local_sum);
2559 if (simd_lane == 0) {
2560 sg_scratch[simd_id] = local_sum;
2561 }
2562 threadgroup_barrier(mem_flags::mem_threadgroup);
2563 if (tid == 0) {
2564 float tile_sum = 0.0f;
2565 for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
2566 stats[1] += tile_sum;
2567 }
2568 threadgroup_barrier(mem_flags::mem_threadgroup);
2569
2570 // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
2571 for (uint d = tid; d < head_dim; d += 256) {
2572 float acc = 0.0f;
2573 for (uint s = 0; s < tile_n; s++) {
2574 acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
2575 }
2576 o_sh[d] += acc;
2577 }
2578 threadgroup_barrier(mem_flags::mem_threadgroup);
2579 }
2580
2581 // --- Normalize and write output ---
2582 float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
2583 for (uint d = tid; d < head_dim; d += 256) {
2584 output_batch[out_off + d] = o_sh[d] * inv_l;
2585 }
2586}
2587
2588// ── attention_mma_flash_batch ─────────────────────────────────────────
2589// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
2590// both Q·K^T and P·V. Processes Q_BLOCK=8 tokens of one head per
2591// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
2592// K/V loads across 8 Q rows and using hardware matrix-multiply for the
2593// arithmetic.
2594//
2595// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
2596// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
2597// to attention_batch / attention_flash_batch otherwise.
2598//
2599// Online softmax recurrence is identical to attention_flash_batch but
2600// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
2601constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
2602constant constexpr uint FLASH_MMA_K_BLOCK = 32;
2603constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;
2604
2605kernel void attention_mma_flash_batch(
2606 device const float* q_batch [[buffer(0)]],
2607 device const float* k_cache [[buffer(1)]],
2608 device const float* v_cache [[buffer(2)]],
2609 device float* output_batch [[buffer(3)]],
2610 constant uint& M [[buffer(4)]],
2611 constant uint& base_pos [[buffer(5)]],
2612 constant uint& num_heads [[buffer(6)]],
2613 constant uint& num_kv_heads [[buffer(7)]],
2614 constant uint& head_dim [[buffer(8)]],
2615 constant uint& q_stride [[buffer(9)]],
2616 uint2 tgid [[threadgroup_position_in_grid]],
2617 uint tid [[thread_index_in_threadgroup]],
2618 uint simd_lane [[thread_index_in_simdgroup]],
2619 uint simd_id [[simdgroup_index_in_threadgroup]])
2620{
2621 uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
2622 uint head = tgid.y;
2623 if (q_block_start >= M) return;
2624 uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
2625
2626 uint kv_head = head / (num_heads / num_kv_heads);
2627 // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
2628 // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
2629 uint seq_len = base_pos + q_block_start + q_valid;
2630 float scale = fast::rsqrt(float(head_dim));
2631
2632 uint kv_stride = num_kv_heads * head_dim;
2633 uint kv_base_off = kv_head * head_dim;
2634
2635 // ── Threadgroup memory ──
2636 // q_sh: [Q_BLOCK, head_dim] half — Q tile, loaded once
2637 // k_sh: [K_BLOCK, head_dim] half — K tile, refreshed per kv_base iter
2638 // v_sh: [K_BLOCK, head_dim] half — V tile, refreshed per kv_base iter
2639 // s_sh: [Q_BLOCK, K_BLOCK] float — raw Q·K^T scores, then scaled+masked
2640 // p_sh: [Q_BLOCK, K_BLOCK] half — softmax probabilities (for P·V MMA)
2641 // o_sh: [Q_BLOCK, head_dim] float — running output accumulator
2642 // m_sh: [Q_BLOCK] float — running max per Q row
2643 // l_sh: [Q_BLOCK] float — running softmax denominator per Q row
2644 // scratch: 4*Q_BLOCK floats — per-row reduction scratch
2645 threadgroup half q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2646 threadgroup half k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2647 threadgroup half v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2648 threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2649 threadgroup half p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
2650 threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
2651 threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
2652 threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
2653 threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
2654
2655 // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
2656 uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
2657 for (uint i = tid; i < qblock_elems; i += 128) {
2658 uint q = i / head_dim;
2659 uint d = i % head_dim;
2660 if (q < q_valid) {
2661 uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
2662 q_sh[q * head_dim + d] = half(q_batch[q_off]);
2663 } else {
2664 q_sh[q * head_dim + d] = half(0);
2665 }
2666 o_sh[q * head_dim + d] = 0.0f;
2667 }
2668 if (tid < FLASH_MMA_Q_BLOCK) {
2669 m_sh[tid] = -INFINITY;
2670 l_sh[tid] = 0.0f;
2671 }
2672 threadgroup_barrier(mem_flags::mem_threadgroup);
2673
2674 // ── Stream K/V in K_BLOCK chunks ──
2675 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
2676 uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
2677
2678 // Load K and V tile into TG memory (as half).
2679 uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
2680 for (uint i = tid; i < kv_tile_elems; i += 128) {
2681 uint k_pos = i / head_dim;
2682 uint d = i % head_dim;
2683 if (k_pos < tile_n) {
2684 uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
2685 k_sh[k_pos * head_dim + d] = half(k_cache[off]);
2686 v_sh[k_pos * head_dim + d] = half(v_cache[off]);
2687 } else {
2688 k_sh[k_pos * head_dim + d] = half(0);
2689 v_sh[k_pos * head_dim + d] = half(0);
2690 }
2691 }
2692 threadgroup_barrier(mem_flags::mem_threadgroup);
2693
2694 // ── Phase 1: S = Q @ K^T via MMA ──
2695 // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
2696 // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
2697 // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
2698 {
2699 simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
2700 uint dim_chunks = head_dim / 8;
2701 for (uint dc = 0; dc < dim_chunks; dc++) {
2702 simdgroup_matrix<half, 8, 8> A, B;
2703 // A = Q[0:8, dc*8 : dc*8+8] (rows of Q, no transpose)
2704 simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
2705 // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
2706 // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
2707 // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
2708 // transpose=true, which places it in the register as K^T of that sub-block.
2709 simdgroup_load(B,
2710 k_sh + (simd_id * 8) * head_dim + dc * 8,
2711 head_dim, ulong2(0, 0), true);
2712 simdgroup_multiply_accumulate(C, A, B, C);
2713 }
2714 // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
2715 simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
2716 }
2717 threadgroup_barrier(mem_flags::mem_threadgroup);
2718
2719 // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
2720 // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
2721 uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
2722 for (uint i = tid; i < s_elems; i += 128) {
2723 uint q = i / FLASH_MMA_K_BLOCK;
2724 uint k = i % FLASH_MMA_K_BLOCK;
2725 uint global_q = q_block_start + q;
2726 uint global_kv = kv_base + k;
2727 bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
2728 s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
2729 }
2730 threadgroup_barrier(mem_flags::mem_threadgroup);
2731
2732 // ── Phase 2b: per-row max via simdgroup reduction ──
2733 // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
2734 // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
2735 {
2736 uint row_base = simd_id * 2;
2737 for (uint qr = 0; qr < 2; qr++) {
2738 uint q = row_base + qr;
2739 float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
2740 float row_max = simd_max(my);
2741 if (simd_lane == 0) {
2742 scratch[q] = row_max; // tile_max[q]
2743 }
2744 }
2745 }
2746 threadgroup_barrier(mem_flags::mem_threadgroup);
2747
2748 // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
2749 if (tid < FLASH_MMA_Q_BLOCK) {
2750 uint q = tid;
2751 float m_old = m_sh[q];
2752 float tile_max = scratch[q];
2753 float m_new = max(m_old, tile_max);
2754 float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2755 m_sh[q] = m_new;
2756 l_sh[q] = l_sh[q] * alpha;
2757 // scratch[q] = m_new (for phase 2d)
2758 // scratch[Q_BLOCK + q] = alpha (for phase 3)
2759 scratch[q] = m_new;
2760 scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
2761 }
2762 threadgroup_barrier(mem_flags::mem_threadgroup);
2763
2764 // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
2765 {
2766 uint row_base = simd_id * 2;
2767 for (uint qr = 0; qr < 2; qr++) {
2768 uint q = row_base + qr;
2769 float m_new = scratch[q];
2770 float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
2771 p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
2772 float row_sum = simd_sum(p);
2773 if (simd_lane == 0) {
2774 scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
2775 }
2776 }
2777 }
2778 threadgroup_barrier(mem_flags::mem_threadgroup);
2779
2780 // ── Phase 2e: l_sh += tile_sum ──
2781 if (tid < FLASH_MMA_Q_BLOCK) {
2782 uint q = tid;
2783 l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
2784 }
2785
2786 // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
2787 threadgroup_barrier(mem_flags::mem_threadgroup);
2788 for (uint i = tid; i < qblock_elems; i += 128) {
2789 uint q = i / head_dim;
2790 float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
2791 o_sh[i] *= alpha;
2792 }
2793 threadgroup_barrier(mem_flags::mem_threadgroup);
2794
2795 // ── Phase 4: O += P @ V via MMA ──
2796 // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
2797 // Output tile span for this simdgroup: head_dim / 4 dims, divided into 8-wide tiles.
2798 // For head_dim=64: 16 dims/sg = 2 tiles. head_dim=128: 32 dims/sg = 4 tiles.
2799 {
2800 uint dims_per_sg = head_dim / 4; // 16 or 32
2801 uint tiles_per_sg = dims_per_sg / 8; // 2 or 4
2802 uint sg_d_base = simd_id * dims_per_sg;
2803 for (uint t = 0; t < tiles_per_sg; t++) {
2804 uint d_base = sg_d_base + t * 8;
2805 simdgroup_matrix<float, 8, 8> O_acc;
2806 simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
2807 uint k_chunks = FLASH_MMA_K_BLOCK / 8; // 4
2808 for (uint kc = 0; kc < k_chunks; kc++) {
2809 simdgroup_matrix<half, 8, 8> A, B;
2810 simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
2811 ulong2(0, 0), false);
2812 simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
2813 ulong2(0, 0), false);
2814 simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
2815 }
2816 simdgroup_store(O_acc, o_sh + d_base, head_dim);
2817 }
2818 }
2819 threadgroup_barrier(mem_flags::mem_threadgroup);
2820 }
2821
2822 // ── Finalize: O /= l, write to output ──
2823 for (uint i = tid; i < qblock_elems; i += 128) {
2824 uint q = i / head_dim;
2825 uint d = i % head_dim;
2826 if (q < q_valid) {
2827 float l = l_sh[q];
2828 float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
2829 uint token_idx = q_block_start + q;
2830 uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
2831 output_batch[out_off] = o_sh[i] * inv_l;
2832 }
2833 }
2834}
2835
2836// ── rope_qk_batch ─────────────────────────────────────────────────────
2837// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2838// launch + memory barrier per layer. Both Q and K live in the same
2839// qkv_data buffer at different offsets within each token's row.
2840// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2841kernel void rope_qk_batch(
2842 device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
2843 constant uint& M [[buffer(1)]], // num tokens
2844 constant uint& base_pos [[buffer(2)]], // starting position
2845 constant uint& num_q_heads [[buffer(3)]],
2846 constant uint& num_kv_heads [[buffer(4)]],
2847 constant uint& head_dim [[buffer(5)]],
2848 constant uint& qkv_stride [[buffer(6)]], // floats per row
2849 constant float& theta [[buffer(7)]],
2850 uint id [[thread_position_in_grid]])
2851{
2852 uint half_dim = head_dim / 2;
2853 uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2854 uint token = id / total_pairs;
2855 uint pair = id % total_pairs;
2856 if (token >= M) return;
2857
2858 uint pos = base_pos + token;
2859 uint q_pairs = num_q_heads * half_dim;
2860
2861 uint h, i, offset;
2862 if (pair < q_pairs) {
2863 // Q head
2864 h = pair / half_dim;
2865 i = pair % half_dim;
2866 offset = token * qkv_stride + h * head_dim + i * 2;
2867 } else {
2868 // K head
2869 uint kp = pair - q_pairs;
2870 h = kp / half_dim;
2871 i = kp % half_dim;
2872 uint k_start = num_q_heads * head_dim;
2873 offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2874 }
2875
2876 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2877 float angle = float(pos) * freq;
2878 float cos_val = cos(angle);
2879 float sin_val = sin(angle);
2880
2881 float x0 = qkv_data[offset];
2882 float x1 = qkv_data[offset + 1];
2883 qkv_data[offset] = x0 * cos_val - x1 * sin_val;
2884 qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2885}
2886
2887// ── copy_kv_both_batch ────────────────────────────────────────────────
2888// Fused K+V cache copy in a single dispatch: copies both K and V from
2889// the strided batch QKV buffer to their respective KV cache buffers.
2890// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2891kernel void copy_kv_both_batch(
2892 device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride]
2893 device float* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim]
2894 device float* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim]
2895 constant uint& M [[buffer(3)]], // num tokens in batch
2896 constant uint& kv_dim [[buffer(4)]], // floats per KV vector
2897 constant uint& base_pos [[buffer(5)]], // starting position in cache
2898 constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
2899 constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
2900 constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
2901 uint id [[thread_position_in_grid]])
2902{
2903 // Total elements = M * kv_dim * 2 (K + V)
2904 uint total_kv = M * kv_dim;
2905 if (id >= total_kv * 2) return;
2906
2907 uint is_v = id / total_kv; // 0 = K, 1 = V
2908 uint local_id = id % total_kv;
2909 uint token = local_id / kv_dim;
2910 uint d = local_id % kv_dim;
2911
2912 uint dst_off = (base_pos + token) * kv_dim + d;
2913 uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2914
2915 if (is_v) {
2916 v_dst[dst_off] = src[src_off];
2917 } else {
2918 k_dst[dst_off] = src[src_off];
2919 }
2920}
2921"#
2922 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2923 .replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
2924}
2925
2926fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2931 let mut code = String::with_capacity(48 * 1024);
2932 emit_model_header(&mut code, config)?;
2933 emit_metal_model_struct(&mut code, config)?;
2934 emit_layer_buffers_struct(&mut code, config)?;
2935 emit_metal_model_impl(&mut code, config)?;
2936 emit_helper_functions(&mut code)?;
2937 Ok(code)
2938}
2939
2940fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2941 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2942 writeln!(
2943 code,
2944 "//! Model: {} ({} layers, hidden={})",
2945 config.architecture, config.num_layers, config.hidden_size
2946 )?;
2947 writeln!(code, "//!")?;
2948 writeln!(
2949 code,
2950 "//! Uses native Metal compute pipelines via the metal crate."
2951 )?;
2952 writeln!(code)?;
2953 writeln!(code, "#![allow(dead_code)]")?;
2954 writeln!(code)?;
2955 writeln!(code, "use metal::*;")?;
2956 writeln!(code, "#[allow(unused_imports)]")?;
2957 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2958 writeln!(code, "use std::mem;")?;
2959 writeln!(code)?;
2960
2961 writeln!(
2963 code,
2964 "// ── Model constants ──────────────────────────────────"
2965 )?;
2966 writeln!(
2967 code,
2968 "pub const HIDDEN_SIZE: usize = {};",
2969 config.hidden_size
2970 )?;
2971 writeln!(
2972 code,
2973 "pub const INTERMEDIATE_SIZE: usize = {};",
2974 config.intermediate_size
2975 )?;
2976 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2977 writeln!(
2978 code,
2979 "pub const NUM_HEADS: usize = {};",
2980 config.num_attention_heads
2981 )?;
2982 writeln!(
2983 code,
2984 "pub const NUM_KV_HEADS: usize = {};",
2985 config.num_kv_heads
2986 )?;
2987 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2988 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2989 let effective_seq_len = config.max_seq_len.min(4096);
2990 writeln!(
2991 code,
2992 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
2993 effective_seq_len, config.max_seq_len
2994 )?;
2995 writeln!(
2996 code,
2997 "pub const RMS_NORM_EPS: f32 = {:e};",
2998 config.rms_norm_eps
2999 )?;
3000 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
3001 writeln!(
3002 code,
3003 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
3004 )?;
3005 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
3006 writeln!(code)?;
3007
3008 Ok(())
3009}
3010
3011fn emit_metal_model_struct(
3012 code: &mut String,
3013 config: &ModelConfig,
3014) -> Result<(), MetalCodegenError> {
3015 writeln!(
3016 code,
3017 "// ── MetalModel ──────────────────────────────────────────"
3018 )?;
3019 writeln!(code)?;
3020 writeln!(
3021 code,
3022 "/// Metal-accelerated transformer model for Apple Silicon."
3023 )?;
3024 writeln!(code, "///")?;
3025 writeln!(
3026 code,
3027 "/// Uses unified memory for zero-copy weight access and native Metal"
3028 )?;
3029 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
3030 writeln!(code, "pub struct MetalModel {{")?;
3031 writeln!(code, " device: Device,")?;
3032 writeln!(code, " queue: CommandQueue,")?;
3033 writeln!(code)?;
3034 writeln!(code, " // ── Compute pipelines ──")?;
3035 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
3036 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
3037 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
3038 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
3039 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
3040 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
3041 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
3042 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
3043 writeln!(code, " add_pipeline: ComputePipelineState,")?;
3044 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
3045 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
3046 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
3047 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
3048 writeln!(code)?;
3049 writeln!(code, " // ── Batched prefill pipelines ──")?;
3050 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
3051 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
3052 writeln!(
3053 code,
3054 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
3055 )?;
3056 writeln!(code, " matmul_q8_mma_pipeline: ComputePipelineState,")?;
3057 writeln!(code, " matmul_q8_mma32_pipeline: ComputePipelineState,")?;
3058 writeln!(
3059 code,
3060 " matmul_q8_mma32_h_pipeline: ComputePipelineState,"
3061 )?;
3062 writeln!(
3063 code,
3064 " matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
3065 )?;
3066 writeln!(
3067 code,
3068 " matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
3069 )?;
3070 if config.qkv_bias {
3071 writeln!(code, " add_bias_batch_pipeline: ComputePipelineState,")?;
3072 }
3073 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
3074 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
3075 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
3076 writeln!(
3077 code,
3078 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
3079 )?;
3080 writeln!(
3081 code,
3082 " add_inplace_batch_pipeline: ComputePipelineState,"
3083 )?;
3084 writeln!(
3085 code,
3086 " copy_embedding_batch_pipeline: ComputePipelineState,"
3087 )?;
3088 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
3089 writeln!(
3090 code,
3091 " attention_flash_batch_pipeline: ComputePipelineState,"
3092 )?;
3093 writeln!(
3094 code,
3095 " attention_mma_flash_batch_pipeline: ComputePipelineState,"
3096 )?;
3097 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
3098 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
3099 writeln!(
3100 code,
3101 " copy_kv_both_batch_pipeline: ComputePipelineState,"
3102 )?;
3103 writeln!(code)?;
3104 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
3105 writeln!(
3106 code,
3107 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
3108 )?;
3109 writeln!(code, " embed_buf: Buffer,")?;
3110 writeln!(code)?;
3111 writeln!(code, " /// Per-layer weight buffers")?;
3112 writeln!(code, " layers: Vec<LayerBuffers>,")?;
3113 writeln!(code)?;
3114 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
3115 writeln!(code, " norm_buf: Buffer,")?;
3116 writeln!(code)?;
3117 writeln!(
3118 code,
3119 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
3120 )?;
3121 writeln!(code, " lm_head_buf: Buffer,")?;
3122 writeln!(code)?;
3123 writeln!(
3124 code,
3125 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
3126 )?;
3127 writeln!(code, " hidden_buf: Buffer,")?;
3128 writeln!(code, " residual_buf: Buffer,")?;
3129 writeln!(code, " normed_buf: Buffer,")?;
3130 writeln!(
3131 code,
3132 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
3133 )?;
3134 writeln!(code, " qkv_buf: Buffer,")?;
3135 writeln!(code, " attn_out_buf: Buffer,")?;
3136 writeln!(code, " attn_proj_buf: Buffer,")?;
3137 writeln!(
3138 code,
3139 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
3140 )?;
3141 writeln!(code, " gate_up_buf: Buffer,")?;
3142 writeln!(code, " ffn_hidden_buf: Buffer,")?;
3143 writeln!(code, " ffn_out_buf: Buffer,")?;
3144 writeln!(code, " add_tmp_buf: Buffer,")?;
3145 writeln!(code, " logits_buf: Buffer,")?;
3146 writeln!(code)?;
3147 writeln!(code, " // ── Batched prefill working buffers ──")?;
3148 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
3149 writeln!(code, " batch_hidden_buf: Buffer,")?;
3150 writeln!(
3151 code,
3152 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
3153 )?;
3154 writeln!(code, " batch_residual_buf: Buffer,")?;
3155 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
3156 writeln!(code, " batch_qkv_buf: Buffer,")?;
3157 writeln!(
3158 code,
3159 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
3160 )?;
3161 writeln!(code, " batch_attn_out_buf: Buffer,")?;
3162 writeln!(
3163 code,
3164 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
3165 )?;
3166 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
3167 writeln!(
3168 code,
3169 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
3170 )?;
3171 writeln!(code, " batch_gate_up_buf: Buffer,")?;
3172 writeln!(
3173 code,
3174 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
3175 )?;
3176 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
3177 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
3178 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
3179 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
3180 writeln!(code, " batch_tokens_buf: Buffer,")?;
3181 writeln!(code, " /// Positions buffer for batch RoPE")?;
3182 writeln!(code, " batch_positions_buf: Buffer,")?;
3183 writeln!(code)?;
3184 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
3185 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
3186 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
3187 writeln!(code)?;
3188 writeln!(code, " // ── Inference state ──")?;
3189 writeln!(code, " pos: usize,")?;
3190 writeln!(code)?;
3191 writeln!(
3192 code,
3193 " /// Previous command buffer for double-buffered prefill."
3194 )?;
3195 writeln!(
3196 code,
3197 " /// While the GPU executes token N, the CPU can encode token N+1."
3198 )?;
3199 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
3200 writeln!(code, "}}")?;
3201 writeln!(code)?;
3202
3203 Ok(())
3204}
3205
3206fn emit_layer_buffers_struct(
3207 code: &mut String,
3208 config: &ModelConfig,
3209) -> Result<(), MetalCodegenError> {
3210 writeln!(
3211 code,
3212 "/// Per-layer weight buffers for attention and FFN projections."
3213 )?;
3214 writeln!(code, "struct LayerBuffers {{")?;
3215 writeln!(code, " attn_norm: Buffer,")?;
3216 writeln!(
3217 code,
3218 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
3219 )?;
3220 writeln!(code, " qkv_weight: Buffer,")?;
3221 if config.qkv_bias {
3222 writeln!(
3223 code,
3224 " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
3225 )?;
3226 writeln!(code, " qkv_bias: Buffer,")?;
3227 }
3228 writeln!(code, " o_weight: Buffer,")?;
3229 writeln!(code, " ffn_norm: Buffer,")?;
3230 writeln!(
3231 code,
3232 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
3233 )?;
3234 writeln!(code, " gate_up_weight: Buffer,")?;
3235 writeln!(code, " down_weight: Buffer,")?;
3236 writeln!(code, "}}")?;
3237 writeln!(code)?;
3238
3239 Ok(())
3240}
3241
3242fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
3243 let hidden = config.hidden_size;
3244 let intermediate = config.intermediate_size;
3245 let _num_layers = config.num_layers;
3246 let num_heads = config.num_attention_heads;
3247 let num_kv_heads = config.num_kv_heads;
3248 let head_dim = config.head_dim;
3249 let vocab = config.vocab_size;
3250 let effective_seq_len = config.max_seq_len.min(4096);
3251 let is_q8 = config.dtype == DType::Q8_0;
3252 let is_q4 = config.dtype == DType::Q4_0;
3253 let kv_dim = num_kv_heads * head_dim;
3254
3255 writeln!(code, "impl MetalModel {{")?;
3256
3257 writeln!(
3259 code,
3260 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3261 )?;
3262 writeln!(code, " ///")?;
3263 writeln!(
3264 code,
3265 " /// `weights` is the raw weight blob produced by `forge export-weights`."
3266 )?;
3267 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
3268 writeln!(
3269 code,
3270 " let device = Device::system_default().expect(\"no Metal device found\");"
3271 )?;
3272 writeln!(code, " let queue = device.new_command_queue();")?;
3273 writeln!(code)?;
3274
3275 writeln!(
3277 code,
3278 " // Compile Metal shaders from embedded source"
3279 )?;
3280 writeln!(
3281 code,
3282 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
3283 )?;
3284 writeln!(
3285 code,
3286 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3287 )?;
3288 writeln!(
3289 code,
3290 " .expect(\"failed to compile Metal shaders\");"
3291 )?;
3292 writeln!(code)?;
3293
3294 writeln!(code, " // Create compute pipelines")?;
3296 for (var, fn_name) in [
3297 ("matmul_pipeline", "matmul_vec"),
3298 ("matmul_q8_pipeline", "matmul_vec_q8"),
3299 ("matmul_q4_pipeline", "matmul_vec_q4"),
3300 ("rms_norm_pipeline", "rms_norm"),
3301 ("rope_pipeline", "rope"),
3302 ("softmax_pipeline", "softmax"),
3303 ("silu_mul_pipeline", "silu_mul"),
3304 ("silu_mul_fused_pipeline", "silu_mul_fused"),
3305 ("add_pipeline", "elementwise_add"),
3306 ("attention_pipeline", "attention"),
3307 ("add_inplace_pipeline", "add_inplace"),
3308 ("copy_pipeline", "copy_buffer"),
3309 ("copy_offset_pipeline", "copy_offset"),
3310 ("matmul_batch_pipeline", "matmul_vec_batch"),
3311 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3312 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3313 ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3314 ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3315 ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3316 ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3317 ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3318 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3319 ("rms_norm_batch_pipeline", "rms_norm_batch"),
3320 ("rope_batch_pipeline", "rope_batch"),
3321 ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3322 ("add_inplace_batch_pipeline", "add_inplace_batch"),
3323 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3324 ("attention_batch_pipeline", "attention_batch"),
3325 ("attention_flash_batch_pipeline", "attention_flash_batch"),
3326 (
3327 "attention_mma_flash_batch_pipeline",
3328 "attention_mma_flash_batch",
3329 ),
3330 ("copy_kv_batch_pipeline", "copy_kv_batch"),
3331 ("rope_qk_batch_pipeline", "rope_qk_batch"),
3332 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3333 ] {
3334 writeln!(
3335 code,
3336 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3337 )?;
3338 }
3339 if config.qkv_bias {
3340 writeln!(
3341 code,
3342 " let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3343 )?;
3344 }
3345 writeln!(code)?;
3346
3347 writeln!(
3349 code,
3350 " // Load weights into Metal shared-memory buffers"
3351 )?;
3352 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
3353 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3354 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
3355 writeln!(code)?;
3356 writeln!(
3357 code,
3358 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
3359 )?;
3360 writeln!(code)?;
3361 writeln!(
3362 code,
3363 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3364 )?;
3365 writeln!(
3366 code,
3367 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3368 )?;
3369 writeln!(code, " let byte_len = n * f32_size;")?;
3370 writeln!(code, " let cur = cursor.get();")?;
3371 writeln!(
3372 code,
3373 " let data = &weights[cur..cur + byte_len];"
3374 )?;
3375 writeln!(code, " cursor.set(cur + byte_len);")?;
3376 writeln!(code, " device.new_buffer_with_data(")?;
3377 writeln!(code, " data.as_ptr() as *const _,")?;
3378 writeln!(code, " byte_len as u64,")?;
3379 writeln!(
3380 code,
3381 " MTLResourceOptions::StorageModeShared,"
3382 )?;
3383 writeln!(code, " )")?;
3384 writeln!(code, " }};")?;
3385 writeln!(code)?;
3386
3387 if is_q8 {
3388 writeln!(
3393 code,
3394 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3395 )?;
3396 writeln!(
3397 code,
3398 " // as raw bytes into a Metal buffer (no dequantization)."
3399 )?;
3400 writeln!(
3401 code,
3402 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3403 )?;
3404 writeln!(
3405 code,
3406 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3407 )?;
3408 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3409 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3410 writeln!(code, " let total_raw = rows * row_bytes;")?;
3411 writeln!(code, " let cur = cursor.get();")?;
3412 writeln!(
3413 code,
3414 " let data = &weights[cur..cur + total_raw];"
3415 )?;
3416 writeln!(code, " cursor.set(cur + total_raw);")?;
3417 writeln!(code, " device.new_buffer_with_data(")?;
3418 writeln!(code, " data.as_ptr() as *const _,")?;
3419 writeln!(code, " total_raw as u64,")?;
3420 writeln!(
3421 code,
3422 " MTLResourceOptions::StorageModeShared,"
3423 )?;
3424 writeln!(code, " )")?;
3425 writeln!(code, " }};")?;
3426 writeln!(code)?;
3427 writeln!(
3428 code,
3429 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3430 )?;
3431 writeln!(
3432 code,
3433 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3434 )?;
3435 writeln!(
3436 code,
3437 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3438 )?;
3439 writeln!(
3440 code,
3441 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3442 )?;
3443 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3444 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3445 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3446 writeln!(code, " let cur = cursor.get();")?;
3447 writeln!(
3448 code,
3449 " let data = &weights[cur..cur + total_raw];"
3450 )?;
3451 writeln!(code, " cursor.set(cur + total_raw);")?;
3452 writeln!(code, " device.new_buffer_with_data(")?;
3453 writeln!(code, " data.as_ptr() as *const _,")?;
3454 writeln!(code, " total_raw as u64,")?;
3455 writeln!(
3456 code,
3457 " MTLResourceOptions::StorageModeShared,"
3458 )?;
3459 writeln!(code, " )")?;
3460 writeln!(code, " }};")?;
3461 writeln!(code)?;
3462 }
3463
3464 if is_q4 {
3465 writeln!(
3470 code,
3471 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3472 )?;
3473 writeln!(
3474 code,
3475 " // as raw bytes into a Metal buffer (no dequantization)."
3476 )?;
3477 writeln!(
3478 code,
3479 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3480 )?;
3481 writeln!(
3482 code,
3483 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3484 )?;
3485 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3486 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3487 writeln!(code, " let total_raw = rows * row_bytes;")?;
3488 writeln!(code, " let cur = cursor.get();")?;
3489 writeln!(
3490 code,
3491 " let data = &weights[cur..cur + total_raw];"
3492 )?;
3493 writeln!(code, " cursor.set(cur + total_raw);")?;
3494 writeln!(code, " device.new_buffer_with_data(")?;
3495 writeln!(code, " data.as_ptr() as *const _,")?;
3496 writeln!(code, " total_raw as u64,")?;
3497 writeln!(
3498 code,
3499 " MTLResourceOptions::StorageModeShared,"
3500 )?;
3501 writeln!(code, " )")?;
3502 writeln!(code, " }};")?;
3503 writeln!(code)?;
3504 writeln!(
3505 code,
3506 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3507 )?;
3508 writeln!(
3509 code,
3510 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3511 )?;
3512 writeln!(
3513 code,
3514 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3515 )?;
3516 writeln!(
3517 code,
3518 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3519 )?;
3520 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3521 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3522 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3523 writeln!(code, " let cur = cursor.get();")?;
3524 writeln!(
3525 code,
3526 " let data = &weights[cur..cur + total_raw];"
3527 )?;
3528 writeln!(code, " cursor.set(cur + total_raw);")?;
3529 writeln!(code, " device.new_buffer_with_data(")?;
3530 writeln!(code, " data.as_ptr() as *const _,")?;
3531 writeln!(code, " total_raw as u64,")?;
3532 writeln!(
3533 code,
3534 " MTLResourceOptions::StorageModeShared,"
3535 )?;
3536 writeln!(code, " )")?;
3537 writeln!(code, " }};")?;
3538 writeln!(code)?;
3539 }
3540
3541 writeln!(
3542 code,
3543 " let embed_buf = next_f32_buffer(&device, embed_elems);"
3544 )?;
3545 writeln!(code)?;
3546
3547 writeln!(
3549 code,
3550 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3551 )?;
3552 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
3553
3554 writeln!(
3556 code,
3557 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
3558 )?;
3559
3560 let qkv_rows = hidden + 2 * kv_dim;
3561 if is_q8 {
3562 writeln!(
3564 code,
3565 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3566 )?;
3567 if config.qkv_bias {
3568 writeln!(
3569 code,
3570 " // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3571 )?;
3572 writeln!(
3573 code,
3574 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3575 )?;
3576 }
3577 writeln!(
3578 code,
3579 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3580 )?;
3581 } else if is_q4 {
3582 writeln!(
3583 code,
3584 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3585 )?;
3586 if config.qkv_bias {
3587 writeln!(
3588 code,
3589 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3590 )?;
3591 }
3592 writeln!(
3593 code,
3594 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3595 )?;
3596 } else {
3597 writeln!(
3598 code,
3599 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3600 )?;
3601 if config.qkv_bias {
3602 writeln!(
3603 code,
3604 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3605 )?;
3606 }
3607 writeln!(
3608 code,
3609 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3610 )?;
3611 }
3612
3613 writeln!(
3615 code,
3616 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3617 )?;
3618
3619 let gate_up_rows = 2 * intermediate;
3620 if is_q8 {
3621 writeln!(
3623 code,
3624 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3625 )?;
3626 writeln!(
3627 code,
3628 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3629 )?;
3630 } else if is_q4 {
3631 writeln!(
3633 code,
3634 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3635 )?;
3636 writeln!(
3637 code,
3638 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3639 )?;
3640 } else {
3641 writeln!(
3643 code,
3644 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3645 )?;
3646 writeln!(
3647 code,
3648 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3649 )?;
3650 }
3651
3652 writeln!(code, " layers.push(LayerBuffers {{")?;
3653 writeln!(code, " attn_norm,")?;
3654 writeln!(code, " qkv_weight,")?;
3655 if config.qkv_bias {
3656 writeln!(code, " qkv_bias,")?;
3657 }
3658 writeln!(code, " o_weight,")?;
3659 writeln!(code, " ffn_norm,")?;
3660 writeln!(code, " gate_up_weight,")?;
3661 writeln!(code, " down_weight,")?;
3662 writeln!(code, " }});")?;
3663 writeln!(code, " }}")?;
3664 writeln!(code)?;
3665
3666 writeln!(
3668 code,
3669 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
3670 )?;
3671 writeln!(code)?;
3672
3673 if is_q8 {
3675 writeln!(
3676 code,
3677 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3678 )?;
3679 } else if is_q4 {
3680 writeln!(
3681 code,
3682 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3683 )?;
3684 } else {
3685 writeln!(
3686 code,
3687 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3688 )?;
3689 }
3690 writeln!(code)?;
3691
3692 let hidden_bytes = hidden * 4;
3694 let _kv_dim_bytes = kv_dim * 4;
3695 let intermediate_bytes = intermediate * 4;
3696 let vocab_bytes = vocab * 4;
3697 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3698
3699 writeln!(
3700 code,
3701 " // Allocate working buffers (shared memory for zero-copy)"
3702 )?;
3703 writeln!(
3704 code,
3705 " let opts = MTLResourceOptions::StorageModeShared;"
3706 )?;
3707 writeln!(
3708 code,
3709 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3710 )?;
3711 writeln!(
3712 code,
3713 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3714 )?;
3715 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3716 writeln!(
3717 code,
3718 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3719 )?;
3720 writeln!(
3721 code,
3722 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3723 )?;
3724 writeln!(
3725 code,
3726 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3727 )?;
3728 writeln!(
3729 code,
3730 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3731 )?;
3732 writeln!(
3733 code,
3734 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3735 )?;
3736 let gate_up_buf_bytes = 2 * intermediate * 4;
3737 writeln!(
3738 code,
3739 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3740 )?;
3741 writeln!(
3742 code,
3743 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3744 )?;
3745 writeln!(
3746 code,
3747 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3748 )?;
3749 writeln!(
3750 code,
3751 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3752 )?;
3753 writeln!(
3754 code,
3755 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3756 )?;
3757 writeln!(
3758 code,
3759 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3760 )?;
3761 writeln!(code)?;
3762
3763 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3766 let batch_gate_up_bytes = 2 * intermediate * 4;
3767 let batch_intermediate_bytes = intermediate * 4;
3768 writeln!(
3769 code,
3770 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3771 )?;
3772 writeln!(
3773 code,
3774 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3775 )?;
3776 writeln!(
3777 code,
3778 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3779 )?;
3780 writeln!(
3781 code,
3782 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3783 )?;
3784 writeln!(
3785 code,
3786 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3787 )?;
3788 writeln!(
3789 code,
3790 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3791 )?;
3792 writeln!(
3793 code,
3794 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3795 )?;
3796 writeln!(
3797 code,
3798 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3799 )?;
3800 writeln!(
3801 code,
3802 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3803 )?;
3804 writeln!(
3805 code,
3806 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3807 )?;
3808 writeln!(
3809 code,
3810 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3811 )?;
3812 writeln!(code)?;
3813
3814 writeln!(code, " // KV cache buffers (per-layer)")?;
3816 writeln!(
3817 code,
3818 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3819 )?;
3820 writeln!(
3821 code,
3822 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3823 )?;
3824 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
3825 writeln!(
3826 code,
3827 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3828 )?;
3829 writeln!(
3830 code,
3831 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3832 )?;
3833 writeln!(code, " }}")?;
3834 writeln!(code)?;
3835
3836 writeln!(code, " Self {{")?;
3837 writeln!(code, " device,")?;
3838 writeln!(code, " queue,")?;
3839 writeln!(code, " matmul_pipeline,")?;
3840 writeln!(code, " matmul_q8_pipeline,")?;
3841 writeln!(code, " matmul_q4_pipeline,")?;
3842 writeln!(code, " rms_norm_pipeline,")?;
3843 writeln!(code, " rope_pipeline,")?;
3844 writeln!(code, " softmax_pipeline,")?;
3845 writeln!(code, " silu_mul_pipeline,")?;
3846 writeln!(code, " silu_mul_fused_pipeline,")?;
3847 writeln!(code, " add_pipeline,")?;
3848 writeln!(code, " attention_pipeline,")?;
3849 writeln!(code, " add_inplace_pipeline,")?;
3850 writeln!(code, " copy_pipeline,")?;
3851 writeln!(code, " copy_offset_pipeline,")?;
3852 writeln!(code, " matmul_batch_pipeline,")?;
3853 writeln!(code, " matmul_q8_batch_pipeline,")?;
3854 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
3855 writeln!(code, " matmul_q8_mma_pipeline,")?;
3856 writeln!(code, " matmul_q8_mma32_pipeline,")?;
3857 writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
3858 writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
3859 writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
3860 if config.qkv_bias {
3861 writeln!(code, " add_bias_batch_pipeline,")?;
3862 }
3863 writeln!(code, " matmul_q4_batch_pipeline,")?;
3864 writeln!(code, " rms_norm_batch_pipeline,")?;
3865 writeln!(code, " rope_batch_pipeline,")?;
3866 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
3867 writeln!(code, " add_inplace_batch_pipeline,")?;
3868 writeln!(code, " copy_embedding_batch_pipeline,")?;
3869 writeln!(code, " attention_batch_pipeline,")?;
3870 writeln!(code, " attention_flash_batch_pipeline,")?;
3871 writeln!(code, " attention_mma_flash_batch_pipeline,")?;
3872 writeln!(code, " copy_kv_batch_pipeline,")?;
3873 writeln!(code, " rope_qk_batch_pipeline,")?;
3874 writeln!(code, " copy_kv_both_batch_pipeline,")?;
3875 writeln!(code, " embed_buf,")?;
3876 writeln!(code, " layers,")?;
3877 writeln!(code, " norm_buf,")?;
3878 writeln!(code, " lm_head_buf,")?;
3879 writeln!(code, " hidden_buf,")?;
3880 writeln!(code, " residual_buf,")?;
3881 writeln!(code, " normed_buf,")?;
3882 writeln!(code, " qkv_buf,")?;
3883 writeln!(code, " attn_out_buf,")?;
3884 writeln!(code, " attn_proj_buf,")?;
3885 writeln!(code, " gate_up_buf,")?;
3886 writeln!(code, " ffn_hidden_buf,")?;
3887 writeln!(code, " ffn_out_buf,")?;
3888 writeln!(code, " add_tmp_buf,")?;
3889 writeln!(code, " logits_buf,")?;
3890 writeln!(code, " batch_hidden_buf,")?;
3891 writeln!(code, " batch_residual_buf,")?;
3892 writeln!(code, " batch_qkv_buf,")?;
3893 writeln!(code, " batch_attn_out_buf,")?;
3894 writeln!(code, " batch_attn_proj_buf,")?;
3895 writeln!(code, " batch_gate_up_buf,")?;
3896 writeln!(code, " batch_ffn_hidden_buf,")?;
3897 writeln!(code, " batch_ffn_out_buf,")?;
3898 writeln!(code, " batch_tokens_buf,")?;
3899 writeln!(code, " batch_positions_buf,")?;
3900 writeln!(code, " k_cache,")?;
3901 writeln!(code, " v_cache,")?;
3902 writeln!(code, " pos: 0,")?;
3903 writeln!(code, " prev_cmd: None,")?;
3904 writeln!(code, " }}")?;
3905 writeln!(code, " }}")?;
3906 writeln!(code)?;
3907
3908 writeln!(
3910 code,
3911 " /// Run the forward pass for a single token at the current position."
3912 )?;
3913 writeln!(code, " ///")?;
3914 writeln!(
3915 code,
3916 " /// Returns logits over the vocabulary as a `Vec<f32>`."
3917 )?;
3918 writeln!(code, " ///")?;
3919 writeln!(
3920 code,
3921 " /// All GPU operations are encoded into a single command buffer and"
3922 )?;
3923 writeln!(
3924 code,
3925 " /// committed once at the end, avoiding per-operation synchronization."
3926 )?;
3927 writeln!(
3928 code,
3929 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3930 )?;
3931 writeln!(
3932 code,
3933 " // Wait for any pending prefill command buffer"
3934 )?;
3935 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3936 writeln!(code, " prev.wait_until_completed();")?;
3937 writeln!(code, " }}")?;
3938 writeln!(code)?;
3939 writeln!(code, " let pos = self.pos;")?;
3940 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3941 writeln!(code)?;
3942
3943 let matmul_fn = if is_q8 {
3946 "dispatch_matmul_q8"
3947 } else if is_q4 {
3948 "dispatch_matmul_q4"
3949 } else {
3950 "dispatch_matmul"
3951 };
3952
3953 writeln!(
3954 code,
3955 " // Single compute encoder for the entire forward pass (no blit transitions)"
3956 )?;
3957 writeln!(code, " {{")?;
3958 writeln!(
3959 code,
3960 " let enc = cmd.new_compute_command_encoder();"
3961 )?;
3962 writeln!(code)?;
3963
3964 writeln!(
3966 code,
3967 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3968 )?;
3969 writeln!(
3970 code,
3971 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3972 )?;
3973 writeln!(
3974 code,
3975 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3976 )?;
3977 writeln!(
3978 code,
3979 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3980 hidden * 4,
3981 )?;
3982 writeln!(code, " unsafe {{")?;
3983 writeln!(
3984 code,
3985 " let embed_ptr = self.embed_buf.contents() as *const f32;"
3986 )?;
3987 writeln!(
3988 code,
3989 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3990 )?;
3991 writeln!(
3992 code,
3993 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
3994 )?;
3995 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
3996 writeln!(
3997 code,
3998 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3999 )?;
4000 writeln!(code, " hidden_ptr,")?;
4001 writeln!(code, " HIDDEN_SIZE,")?;
4002 writeln!(code, " );")?;
4003 writeln!(
4004 code,
4005 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4006 )?;
4007 writeln!(code, " }}")?;
4008 writeln!(code)?;
4009
4010 writeln!(code, " // 2. Transformer layers")?;
4012 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4013 writeln!(code)?;
4014 let q_byte_offset = 0usize;
4015 let k_byte_offset = hidden * 4;
4016 let v_byte_offset = (hidden + kv_dim) * 4;
4017
4018 writeln!(
4019 code,
4020 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
4021 )?;
4022 writeln!(
4023 code,
4024 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4025 )?;
4026 writeln!(
4027 code,
4028 " // Fused Q+K+V matmul: single dispatch for all three projections"
4029 )?;
4030 writeln!(
4031 code,
4032 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4033 )?;
4034 if config.qkv_bias {
4035 writeln!(
4036 code,
4037 " // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
4038 )?;
4039 writeln!(
4040 code,
4041 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4042 )?;
4043 }
4044 writeln!(
4045 code,
4046 " // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
4047 )?;
4048 writeln!(
4049 code,
4050 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4051 )?;
4052 writeln!(
4053 code,
4054 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4055 )?;
4056 writeln!(code)?;
4057 writeln!(
4058 code,
4059 " // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
4060 )?;
4061 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
4062 writeln!(
4063 code,
4064 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4065 )?;
4066 writeln!(
4067 code,
4068 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4069 )?;
4070 writeln!(code)?;
4071 writeln!(
4072 code,
4073 " // Attention using Q from qkv_buf (offset 0)"
4074 )?;
4075 writeln!(
4076 code,
4077 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4078 )?;
4079 writeln!(
4080 code,
4081 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4082 )?;
4083 writeln!(
4084 code,
4085 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4086 )?;
4087 writeln!(
4088 code,
4089 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
4090 )?;
4091 writeln!(
4092 code,
4093 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4094 )?;
4095 writeln!(
4096 code,
4097 " // Fused gate+up matmul: single dispatch for both projections"
4098 )?;
4099 writeln!(
4100 code,
4101 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4102 )?;
4103 writeln!(
4104 code,
4105 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4106 )?;
4107 writeln!(
4108 code,
4109 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4110 )?;
4111 writeln!(
4112 code,
4113 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4114 )?;
4115 writeln!(code, " }}")?;
4116 writeln!(code)?;
4117
4118 writeln!(code, " // 3. Final RMS norm + logits projection")?;
4120 writeln!(
4121 code,
4122 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4123 )?;
4124 writeln!(
4125 code,
4126 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4127 )?;
4128 writeln!(code)?;
4129 writeln!(code, " enc.end_encoding();")?;
4130 writeln!(code, " }}")?;
4131 writeln!(code)?;
4132
4133 writeln!(
4135 code,
4136 " // 5. Commit all GPU work and wait for completion"
4137 )?;
4138 writeln!(code, " cmd.commit();")?;
4139 writeln!(code, " cmd.wait_until_completed();")?;
4140 writeln!(code)?;
4141 writeln!(code, " // 6. Read back logits from GPU")?;
4142 writeln!(code, " let logits = unsafe {{")?;
4143 writeln!(
4144 code,
4145 " let ptr = self.logits_buf.contents() as *const f32;"
4146 )?;
4147 writeln!(
4148 code,
4149 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4150 )?;
4151 writeln!(code, " }};")?;
4152 writeln!(code)?;
4153 writeln!(code, " self.pos += 1;")?;
4154 writeln!(code, " logits")?;
4155 writeln!(code, " }}")?;
4156 writeln!(code)?;
4157
4158 writeln!(
4160 code,
4161 " /// Profiling forward pass that prints per-stage GPU timing."
4162 )?;
4163 writeln!(code, " ///")?;
4164 writeln!(
4165 code,
4166 " /// Each stage is committed and waited on separately so that GPU timestamps"
4167 )?;
4168 writeln!(
4169 code,
4170 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
4171 )?;
4172 writeln!(
4173 code,
4174 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
4175 )?;
4176 writeln!(
4177 code,
4178 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
4179 )?;
4180 writeln!(code, " use std::time::Instant;")?;
4181 writeln!(code)?;
4182 writeln!(
4183 code,
4184 " // Wait for any pending prefill command buffer"
4185 )?;
4186 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4187 writeln!(code, " prev.wait_until_completed();")?;
4188 writeln!(code, " }}")?;
4189 writeln!(code)?;
4190 writeln!(code, " let pos = self.pos;")?;
4191 writeln!(code)?;
4192
4193 writeln!(
4195 code,
4196 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
4197 )?;
4198 writeln!(code, " let t_embed = Instant::now();")?;
4199 writeln!(code, " unsafe {{")?;
4200 writeln!(
4201 code,
4202 " let embed_ptr = self.embed_buf.contents() as *const f32;"
4203 )?;
4204 writeln!(
4205 code,
4206 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
4207 )?;
4208 writeln!(
4209 code,
4210 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
4211 )?;
4212 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
4213 writeln!(
4214 code,
4215 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
4216 )?;
4217 writeln!(code, " hidden_ptr,")?;
4218 writeln!(code, " HIDDEN_SIZE,")?;
4219 writeln!(code, " );")?;
4220 writeln!(
4221 code,
4222 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
4223 )?;
4224 writeln!(code, " }}")?;
4225 writeln!(code, " let d_embed = t_embed.elapsed();")?;
4226 writeln!(code)?;
4227
4228 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
4230 writeln!(code, " let t_layers = Instant::now();")?;
4231 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4232 writeln!(code, " {{")?;
4233 writeln!(
4234 code,
4235 " let enc = cmd.new_compute_command_encoder();"
4236 )?;
4237 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4238 writeln!(
4239 code,
4240 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
4241 )?;
4242 writeln!(
4243 code,
4244 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
4245 )?;
4246 if config.qkv_bias {
4247 writeln!(
4248 code,
4249 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
4250 )?;
4251 }
4252 writeln!(
4253 code,
4254 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4255 )?;
4256 writeln!(
4257 code,
4258 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4259 )?;
4260 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
4261 writeln!(
4262 code,
4263 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4264 )?;
4265 writeln!(
4266 code,
4267 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4268 )?;
4269 writeln!(
4270 code,
4271 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4272 )?;
4273 writeln!(
4274 code,
4275 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4276 )?;
4277 writeln!(
4278 code,
4279 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4280 )?;
4281 writeln!(
4282 code,
4283 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4284 )?;
4285 writeln!(
4286 code,
4287 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4288 )?;
4289 writeln!(
4290 code,
4291 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4292 )?;
4293 writeln!(
4294 code,
4295 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4296 )?;
4297 writeln!(
4298 code,
4299 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4300 )?;
4301 writeln!(code, " }}")?;
4302 writeln!(code, " enc.end_encoding();")?;
4303 writeln!(code, " }}")?;
4304 writeln!(code, " cmd.commit();")?;
4305 writeln!(code, " cmd.wait_until_completed();")?;
4306 writeln!(code, " let d_layers = t_layers.elapsed();")?;
4307 writeln!(code)?;
4308
4309 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
4311 writeln!(code, " let t_logits = Instant::now();")?;
4312 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
4313 writeln!(code, " {{")?;
4314 writeln!(
4315 code,
4316 " let enc = cmd2.new_compute_command_encoder();"
4317 )?;
4318 writeln!(
4319 code,
4320 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4321 )?;
4322 writeln!(
4323 code,
4324 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4325 )?;
4326 writeln!(code, " enc.end_encoding();")?;
4327 writeln!(code, " }}")?;
4328 writeln!(code, " cmd2.commit();")?;
4329 writeln!(code, " cmd2.wait_until_completed();")?;
4330 writeln!(code, " let d_logits = t_logits.elapsed();")?;
4331 writeln!(code)?;
4332
4333 writeln!(
4335 code,
4336 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4337 )?;
4338 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
4339 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
4340 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
4341 writeln!(
4342 code,
4343 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4344 )?;
4345 writeln!(code)?;
4346
4347 writeln!(code, " let logits = unsafe {{")?;
4349 writeln!(
4350 code,
4351 " let ptr = self.logits_buf.contents() as *const f32;"
4352 )?;
4353 writeln!(
4354 code,
4355 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4356 )?;
4357 writeln!(code, " }};")?;
4358 writeln!(code)?;
4359 writeln!(code, " self.pos += 1;")?;
4360 writeln!(code, " logits")?;
4361 writeln!(code, " }}")?;
4362 writeln!(code)?;
4363
4364 writeln!(
4366 code,
4367 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
4368 )?;
4369 writeln!(code, " ///")?;
4370 writeln!(
4371 code,
4372 " /// Commits the command buffer without waiting, enabling double-buffered"
4373 )?;
4374 writeln!(
4375 code,
4376 " /// execution: GPU processes token N while CPU encodes token N+1."
4377 )?;
4378 writeln!(
4379 code,
4380 " pub fn forward_prefill(&mut self, token_id: u32) {{"
4381 )?;
4382 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
4383 writeln!(code, " }}")?;
4384 writeln!(code)?;
4385
4386 let batch_matmul_fn = if is_q8 {
4389 "dispatch_matmul_q8_batch"
4390 } else if is_q4 {
4391 "dispatch_matmul_q4_batch"
4392 } else {
4393 "dispatch_matmul_batch"
4394 };
4395
4396 writeln!(
4397 code,
4398 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4399 )?;
4400 writeln!(code, " ///")?;
4401 writeln!(
4402 code,
4403 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4404 )?;
4405 writeln!(
4406 code,
4407 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
4408 )?;
4409 writeln!(
4410 code,
4411 " /// This provides significant speedup during prompt prefill."
4412 )?;
4413 writeln!(
4414 code,
4415 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4416 )?;
4417 writeln!(code, " if tokens.is_empty() {{ return; }}")?;
4418 writeln!(
4419 code,
4420 " // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
4421 )?;
4422 writeln!(
4423 code,
4424 " // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
4425 )?;
4426 writeln!(
4427 code,
4428 " // longer than that must be processed iteratively. The KV cache"
4429 )?;
4430 writeln!(code, " // carries state across chunks via self.pos.")?;
4431 writeln!(
4432 code,
4433 " for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
4434 )?;
4435 writeln!(code, " let m = chunk.len();")?;
4436 writeln!(code, " if m == 0 {{ continue; }}")?;
4437 writeln!(code, " let start_pos = self.pos;")?;
4438 writeln!(code)?;
4439 writeln!(code, " // Wait for any pending command buffer")?;
4440 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4441 writeln!(code, " prev.wait_until_completed();")?;
4442 writeln!(code, " }}")?;
4443 writeln!(code)?;
4444
4445 writeln!(
4447 code,
4448 " // Upload token IDs and positions to GPU buffers"
4449 )?;
4450 writeln!(code, " unsafe {{")?;
4451 writeln!(
4452 code,
4453 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4454 )?;
4455 writeln!(
4456 code,
4457 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4458 )?;
4459 writeln!(code, " for i in 0..m {{")?;
4460 writeln!(code, " *tok_ptr.add(i) = chunk[i];")?;
4461 writeln!(
4462 code,
4463 " *pos_ptr.add(i) = (start_pos + i) as u32;"
4464 )?;
4465 writeln!(code, " }}")?;
4466 writeln!(code, " }}")?;
4467 writeln!(code)?;
4468
4469 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4470 writeln!(code, " {{")?;
4471 writeln!(
4472 code,
4473 " let enc = cmd.new_compute_command_encoder();"
4474 )?;
4475 writeln!(code)?;
4476
4477 writeln!(
4479 code,
4480 " // 1. Batch embedding lookup: copy all token embeddings at once"
4481 )?;
4482 writeln!(
4483 code,
4484 " self.dispatch_copy_embedding_batch(&enc, m);"
4485 )?;
4486 writeln!(
4488 code,
4489 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4490 )?;
4491 writeln!(code)?;
4492
4493 writeln!(code, " // 2. Transformer layers")?;
4495 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4496 writeln!(code)?;
4497
4498 writeln!(
4500 code,
4501 " // Batch RMS norm: batch_residual -> batch_hidden"
4502 )?;
4503 writeln!(
4504 code,
4505 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4506 )?;
4507
4508 writeln!(
4510 code,
4511 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4512 )?;
4513 writeln!(
4514 code,
4515 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4516 )?;
4517 if config.qkv_bias {
4518 writeln!(
4519 code,
4520 " // Qwen2: broadcast-add QKV bias across all M tokens."
4521 )?;
4522 writeln!(
4523 code,
4524 " self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4525 )?;
4526 }
4527 writeln!(code)?;
4528
4529 let k_float_offset = hidden;
4531 writeln!(
4532 code,
4533 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4534 )?;
4535 writeln!(
4536 code,
4537 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4538 )?;
4539 writeln!(code)?;
4540
4541 let v_float_offset = hidden + kv_dim;
4543 writeln!(
4544 code,
4545 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4546 )?;
4547 writeln!(
4548 code,
4549 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4550 )?;
4551 writeln!(code)?;
4552
4553 writeln!(
4555 code,
4556 " // Batched causal attention: one dispatch for all M tokens"
4557 )?;
4558 writeln!(
4559 code,
4560 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4561 )?;
4562 writeln!(code)?;
4563
4564 writeln!(code, " // Batched O projection")?;
4566 writeln!(
4567 code,
4568 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4569 )?;
4570 writeln!(code)?;
4571
4572 writeln!(
4574 code,
4575 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4576 )?;
4577 writeln!(code)?;
4578
4579 writeln!(
4581 code,
4582 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4583 )?;
4584 writeln!(
4585 code,
4586 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4587 )?;
4588 writeln!(
4589 code,
4590 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4591 )?;
4592 writeln!(
4593 code,
4594 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4595 )?;
4596 writeln!(
4597 code,
4598 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4599 )?;
4600 writeln!(
4601 code,
4602 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4603 )?;
4604 writeln!(code, " }}")?;
4605 writeln!(code)?;
4606
4607 writeln!(
4609 code,
4610 " // Copy last token's residual to single-token buffer for subsequent forward()"
4611 )?;
4612 writeln!(
4613 code,
4614 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4615 )?;
4616 writeln!(code)?;
4617 writeln!(code, " enc.end_encoding();")?;
4618 writeln!(code, " }}")?;
4619 writeln!(code)?;
4620
4621 writeln!(code, " cmd.commit();")?;
4622 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
4623 writeln!(code, " self.pos += m;")?;
4624 writeln!(code, " }} // end for chunk")?;
4625 writeln!(code, " }}")?;
4626 writeln!(code)?;
4627
4628 writeln!(
4630 code,
4631 " /// Reset the model state for a new inference request."
4632 )?;
4633 writeln!(code, " pub fn reset(&mut self) {{")?;
4634 writeln!(code, " self.pos = 0;")?;
4635 writeln!(code, " self.prev_cmd = None;")?;
4636 writeln!(code, " }}")?;
4637 writeln!(code)?;
4638
4639 writeln!(
4641 code,
4642 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
4643 )?;
4644 writeln!(
4645 code,
4646 " // These methods set pipeline state + buffers + dispatch on an existing"
4647 )?;
4648 writeln!(
4649 code,
4650 " // encoder, avoiding per-operation encoder creation overhead."
4651 )?;
4652 writeln!(code)?;
4653
4654 writeln!(
4656 code,
4657 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4658 )?;
4659 writeln!(
4660 code,
4661 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4662 )?;
4663 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4664 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4665 writeln!(
4666 code,
4667 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4668 )?;
4669 writeln!(
4670 code,
4671 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
4672 )?;
4673 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4674 writeln!(
4675 code,
4676 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4677 )?;
4678 writeln!(
4679 code,
4680 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4681 )?;
4682 writeln!(
4683 code,
4684 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4685 )?;
4686 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4687 writeln!(
4688 code,
4689 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
4690 )?;
4691 writeln!(
4692 code,
4693 " enc.dispatch_thread_groups(grid_size, tg_size);"
4694 )?;
4695 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4696 writeln!(code, " }}")?;
4697 writeln!(code)?;
4698
4699 writeln!(
4701 code,
4702 " /// Dispatch matrix-vector multiply: weight * input -> output."
4703 )?;
4704 writeln!(
4705 code,
4706 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4707 )?;
4708 writeln!(code, " let r: u32 = rows as u32;")?;
4709 writeln!(code, " let c: u32 = cols as u32;")?;
4710 writeln!(
4711 code,
4712 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4713 )?;
4714 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4715 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4716 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4717 writeln!(
4718 code,
4719 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4720 )?;
4721 writeln!(
4722 code,
4723 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4724 )?;
4725 writeln!(
4726 code,
4727 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4728 )?;
4729 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4730 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
4731 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4732 writeln!(
4733 code,
4734 " enc.dispatch_thread_groups(grid_size, tg_size);"
4735 )?;
4736 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4737 writeln!(code, " }}")?;
4738 writeln!(code)?;
4739
4740 writeln!(
4742 code,
4743 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4744 )?;
4745 writeln!(
4746 code,
4747 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4748 )?;
4749 writeln!(
4750 code,
4751 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4752 )?;
4753 writeln!(code, " let r: u32 = rows as u32;")?;
4754 writeln!(code, " let c: u32 = cols as u32;")?;
4755 writeln!(
4756 code,
4757 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4758 )?;
4759 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4760 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4761 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4762 writeln!(
4763 code,
4764 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4765 )?;
4766 writeln!(
4767 code,
4768 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4769 )?;
4770 writeln!(
4771 code,
4772 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4773 )?;
4774 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4775 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4776 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4777 writeln!(
4778 code,
4779 " enc.dispatch_thread_groups(grid_size, tg_size);"
4780 )?;
4781 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4782 writeln!(code, " }}")?;
4783 writeln!(code)?;
4784
4785 writeln!(
4787 code,
4788 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4789 )?;
4790 writeln!(
4791 code,
4792 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4793 )?;
4794 writeln!(
4795 code,
4796 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4797 )?;
4798 writeln!(code, " let r: u32 = rows as u32;")?;
4799 writeln!(code, " let c: u32 = cols as u32;")?;
4800 writeln!(
4801 code,
4802 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4803 )?;
4804 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4805 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4806 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4807 writeln!(
4808 code,
4809 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4810 )?;
4811 writeln!(
4812 code,
4813 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4814 )?;
4815 writeln!(
4816 code,
4817 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4818 )?;
4819 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4820 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4821 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4822 writeln!(
4823 code,
4824 " enc.dispatch_thread_groups(grid_size, tg_size);"
4825 )?;
4826 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4827 writeln!(code, " }}")?;
4828 writeln!(code)?;
4829
4830 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
4832 writeln!(
4833 code,
4834 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4835 )?;
4836 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4837 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4838 writeln!(code, " let p: u32 = pos as u32;")?;
4839 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4840 writeln!(
4841 code,
4842 " let total_pairs = num_heads * (head_dim / 2);"
4843 )?;
4844 writeln!(
4845 code,
4846 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4847 )?;
4848 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
4849 writeln!(
4850 code,
4851 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4852 )?;
4853 writeln!(
4854 code,
4855 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4856 )?;
4857 writeln!(
4858 code,
4859 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4860 )?;
4861 writeln!(
4862 code,
4863 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4864 )?;
4865 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4866 writeln!(
4867 code,
4868 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4869 )?;
4870 writeln!(
4871 code,
4872 " enc.dispatch_thread_groups(grid_size, tg_size);"
4873 )?;
4874 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4875 writeln!(code, " }}")?;
4876 writeln!(code)?;
4877
4878 writeln!(
4880 code,
4881 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4882 )?;
4883 writeln!(
4884 code,
4885 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4886 )?;
4887 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4888 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4889 writeln!(code, " let p: u32 = pos as u32;")?;
4890 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4891 writeln!(
4892 code,
4893 " let total_pairs = num_heads * (head_dim / 2);"
4894 )?;
4895 writeln!(
4896 code,
4897 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4898 )?;
4899 writeln!(
4900 code,
4901 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
4902 )?;
4903 writeln!(
4904 code,
4905 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4906 )?;
4907 writeln!(
4908 code,
4909 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4910 )?;
4911 writeln!(
4912 code,
4913 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4914 )?;
4915 writeln!(
4916 code,
4917 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4918 )?;
4919 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4920 writeln!(
4921 code,
4922 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4923 )?;
4924 writeln!(
4925 code,
4926 " enc.dispatch_thread_groups(grid_size, tg_size);"
4927 )?;
4928 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4929 writeln!(code, " }}")?;
4930 writeln!(code)?;
4931
4932 writeln!(
4934 code,
4935 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4936 )?;
4937 writeln!(
4938 code,
4939 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4940 )?;
4941 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4942 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4943 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4944 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4945 writeln!(
4946 code,
4947 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
4948 )?;
4949 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
4950 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4951 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4952 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4953 writeln!(
4954 code,
4955 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4956 )?;
4957 writeln!(
4958 code,
4959 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4960 )?;
4961 writeln!(
4962 code,
4963 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4964 )?;
4965 writeln!(
4966 code,
4967 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4968 )?;
4969 writeln!(
4970 code,
4971 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4972 )?;
4973 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4974 writeln!(
4975 code,
4976 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4977 )?;
4978 writeln!(
4979 code,
4980 " enc.dispatch_thread_groups(grid_size, tg_size);"
4981 )?;
4982 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4983 writeln!(code, " }}")?;
4984 writeln!(code)?;
4985
4986 writeln!(
4988 code,
4989 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4990 )?;
4991 writeln!(
4992 code,
4993 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4994 )?;
4995 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4996 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4997 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4998 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4999 writeln!(
5000 code,
5001 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
5002 )?;
5003 writeln!(
5004 code,
5005 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
5006 )?;
5007 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5008 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5009 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5010 writeln!(
5011 code,
5012 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
5013 )?;
5014 writeln!(
5015 code,
5016 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5017 )?;
5018 writeln!(
5019 code,
5020 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5021 )?;
5022 writeln!(
5023 code,
5024 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5025 )?;
5026 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5027 writeln!(
5028 code,
5029 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
5030 )?;
5031 writeln!(
5032 code,
5033 " enc.dispatch_thread_groups(grid_size, tg_size);"
5034 )?;
5035 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5036 writeln!(code, " }}")?;
5037 writeln!(code)?;
5038
5039 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
5041 writeln!(
5042 code,
5043 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
5044 )?;
5045 writeln!(code, " let count: u32 = n as u32;")?;
5046 writeln!(
5047 code,
5048 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
5049 )?;
5050 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
5051 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
5052 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5053 writeln!(
5054 code,
5055 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5056 )?;
5057 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5058 writeln!(
5059 code,
5060 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5061 )?;
5062 writeln!(
5063 code,
5064 " enc.dispatch_thread_groups(grid_size, tg_size);"
5065 )?;
5066 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5067 writeln!(code, " }}")?;
5068 writeln!(code)?;
5069
5070 writeln!(
5072 code,
5073 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
5074 )?;
5075 writeln!(
5076 code,
5077 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
5078 )?;
5079 writeln!(
5080 code,
5081 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
5082 )?;
5083 writeln!(code, " let count: u32 = n as u32;")?;
5084 writeln!(
5085 code,
5086 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
5087 )?;
5088 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5089 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5090 writeln!(
5091 code,
5092 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5093 )?;
5094 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5095 writeln!(
5096 code,
5097 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5098 )?;
5099 writeln!(
5100 code,
5101 " enc.dispatch_thread_groups(grid_size, tg_size);"
5102 )?;
5103 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5104 writeln!(code, " }}")?;
5105 writeln!(code)?;
5106
5107 writeln!(
5109 code,
5110 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
5111 )?;
5112 writeln!(
5113 code,
5114 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5115 )?;
5116 writeln!(code, " let n: u32 = count as u32;")?;
5117 writeln!(
5118 code,
5119 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5120 )?;
5121 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5122 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5123 writeln!(
5124 code,
5125 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5126 )?;
5127 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5128 writeln!(
5129 code,
5130 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5131 )?;
5132 writeln!(
5133 code,
5134 " enc.dispatch_thread_groups(grid_size, tg_size);"
5135 )?;
5136 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5137 writeln!(code, " }}")?;
5138 writeln!(code)?;
5139
5140 writeln!(
5142 code,
5143 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
5144 )?;
5145 writeln!(
5146 code,
5147 " /// Used for embedding table lookup (copy a specific row)."
5148 )?;
5149 writeln!(
5150 code,
5151 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
5152 )?;
5153 writeln!(code, " let off: u32 = src_offset as u32;")?;
5154 writeln!(code, " let n: u32 = count as u32;")?;
5155 writeln!(
5156 code,
5157 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
5158 )?;
5159 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5160 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5161 writeln!(
5162 code,
5163 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
5164 )?;
5165 writeln!(
5166 code,
5167 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5168 )?;
5169 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5170 writeln!(
5171 code,
5172 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5173 )?;
5174 writeln!(
5175 code,
5176 " enc.dispatch_thread_groups(grid_size, tg_size);"
5177 )?;
5178 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5179 writeln!(code, " }}")?;
5180 writeln!(code)?;
5181
5182 writeln!(
5184 code,
5185 " /// Dispatch copy from source at byte offset to destination at float offset."
5186 )?;
5187 writeln!(
5188 code,
5189 " /// Used for KV cache updates from fused QKV buffer."
5190 )?;
5191 writeln!(
5192 code,
5193 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5194 )?;
5195 writeln!(code, " let n: u32 = count as u32;")?;
5196 writeln!(
5197 code,
5198 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5199 )?;
5200 writeln!(
5201 code,
5202 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5203 )?;
5204 writeln!(
5205 code,
5206 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5207 )?;
5208 writeln!(
5209 code,
5210 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5211 )?;
5212 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5213 writeln!(
5214 code,
5215 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5216 )?;
5217 writeln!(
5218 code,
5219 " enc.dispatch_thread_groups(grid_size, tg_size);"
5220 )?;
5221 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5222 writeln!(code, " }}")?;
5223 writeln!(code)?;
5224
5225 writeln!(
5227 code,
5228 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
5229 )?;
5230 writeln!(
5231 code,
5232 " /// Used for KV cache updates (write to a specific position in the cache)."
5233 )?;
5234 writeln!(
5235 code,
5236 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
5237 )?;
5238 writeln!(code, " let n: u32 = count as u32;")?;
5239 writeln!(
5240 code,
5241 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5242 )?;
5243 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5244 writeln!(
5245 code,
5246 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
5247 )?;
5248 writeln!(
5249 code,
5250 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5251 )?;
5252 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5253 writeln!(
5254 code,
5255 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5256 )?;
5257 writeln!(
5258 code,
5259 " enc.dispatch_thread_groups(grid_size, tg_size);"
5260 )?;
5261 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5262 writeln!(code, " }}")?;
5263 writeln!(code)?;
5264
5265 writeln!(
5267 code,
5268 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
5269 )?;
5270 writeln!(
5271 code,
5272 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
5273 )?;
5274 writeln!(code, " let count: u32 = n as u32;")?;
5275 writeln!(
5276 code,
5277 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5278 )?;
5279 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5280 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5281 writeln!(
5282 code,
5283 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5284 )?;
5285 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5286 writeln!(
5287 code,
5288 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5289 )?;
5290 writeln!(
5291 code,
5292 " enc.dispatch_thread_groups(grid_size, tg_size);"
5293 )?;
5294 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5295 writeln!(code, " }}")?;
5296 writeln!(code)?;
5297
5298 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
5300 writeln!(code)?;
5301
5302 writeln!(
5304 code,
5305 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
5306 )?;
5307 writeln!(
5308 code,
5309 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5310 )?;
5311 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
5312 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5313 writeln!(
5314 code,
5315 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5316 )?;
5317 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5318 writeln!(
5319 code,
5320 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5321 )?;
5322 writeln!(
5323 code,
5324 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5325 )?;
5326 writeln!(
5327 code,
5328 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5329 )?;
5330 writeln!(
5331 code,
5332 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5333 )?;
5334 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
5335 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5336 writeln!(
5337 code,
5338 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5339 )?;
5340 writeln!(
5341 code,
5342 " enc.dispatch_thread_groups(grid_size, tg_size);"
5343 )?;
5344 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5345 writeln!(code, " }}")?;
5346 writeln!(code)?;
5347
5348 writeln!(
5350 code,
5351 " /// Dispatch batched RMS norm: normalizes M vectors at once."
5352 )?;
5353 writeln!(
5354 code,
5355 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5356 )?;
5357 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
5358 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
5359 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5360 writeln!(
5361 code,
5362 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5363 )?;
5364 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
5365 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
5366 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5367 writeln!(
5368 code,
5369 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5370 )?;
5371 writeln!(
5372 code,
5373 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5374 )?;
5375 writeln!(
5376 code,
5377 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5378 )?;
5379 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5380 writeln!(
5381 code,
5382 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5383 )?;
5384 writeln!(
5385 code,
5386 " enc.dispatch_thread_groups(grid_size, tg_size);"
5387 )?;
5388 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5389 writeln!(code, " }}")?;
5390 writeln!(code)?;
5391
5392 writeln!(
5394 code,
5395 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5396 )?;
5397 writeln!(
5398 code,
5399 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5400 )?;
5401 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5402 writeln!(code, " let r: u32 = rows as u32;")?;
5403 writeln!(code, " let c: u32 = cols as u32;")?;
5404 writeln!(
5405 code,
5406 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5407 )?;
5408 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5409 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5410 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5411 writeln!(
5412 code,
5413 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5414 )?;
5415 writeln!(
5416 code,
5417 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5418 )?;
5419 writeln!(
5420 code,
5421 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5422 )?;
5423 writeln!(
5424 code,
5425 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
5426 )?;
5427 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5428 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5429 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5430 writeln!(
5431 code,
5432 " enc.dispatch_thread_groups(grid_size, tg_size);"
5433 )?;
5434 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5435 writeln!(code, " }}")?;
5436 writeln!(code)?;
5437
5438 writeln!(
5440 code,
5441 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5442 )?;
5443 writeln!(code, " ///")?;
5444 writeln!(
5445 code,
5446 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5447 )?;
5448 writeln!(
5449 code,
5450 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5451 )?;
5452 writeln!(
5453 code,
5454 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5455 )?;
5456 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5457 writeln!(code, " let r: u32 = rows as u32;")?;
5458 writeln!(code, " let c: u32 = cols as u32;")?;
5459 writeln!(
5460 code,
5461 " // Tile sizes must match the Metal shader constants."
5462 )?;
5463 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
5464 writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
5465 writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
5466 writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
5467 writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
5468 writeln!(
5469 code,
5470 " // Hardware matrix-multiply paths (simdgroup_matrix)."
5471 )?;
5472 writeln!(
5473 code,
5474 " // Prefer the large 32×32 tile when the problem supports it — halves"
5475 )?;
5476 writeln!(
5477 code,
5478 " // dispatch count and reuses each weight load across 32 tokens."
5479 )?;
5480 writeln!(
5481 code,
5482 " if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5483 )?;
5484 writeln!(
5485 code,
5486 " // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5487 )?;
5488 writeln!(
5489 code,
5490 " // It wins at moderate prefill lengths where the GPU is wave-starved,"
5491 )?;
5492 writeln!(
5493 code,
5494 " // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5495 )?;
5496 writeln!(
5497 code,
5498 " // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
5499 )?;
5500 writeln!(
5501 code,
5502 " // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5503 )?;
5504 writeln!(
5505 code,
5506 " // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5507 )?;
5508 writeln!(
5509 code,
5510 " // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5511 )?;
5512 writeln!(
5513 code,
5514 " // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5515 )?;
5516 writeln!(code, " let use_h4 = cols >= 2048;")?;
5517 writeln!(code, " let pipe = if use_h4 {{")?;
5518 writeln!(code, " if num_tokens >= 256 {{")?;
5519 writeln!(
5520 code,
5521 " &self.matmul_q8_mma32_hh4_pipeline"
5522 )?;
5523 writeln!(code, " }} else {{")?;
5524 writeln!(
5525 code,
5526 " &self.matmul_q8_mma32_h4_pipeline"
5527 )?;
5528 writeln!(code, " }}")?;
5529 writeln!(code, " }} else {{")?;
5530 writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
5531 writeln!(code, " }};")?;
5532 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
5533 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5534 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5535 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5536 writeln!(
5537 code,
5538 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5539 )?;
5540 writeln!(
5541 code,
5542 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5543 )?;
5544 writeln!(
5545 code,
5546 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5547 )?;
5548 writeln!(code, " let row_tgs = rows / MMA32_ROW_TILE;")?;
5549 writeln!(
5550 code,
5551 " let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5552 )?;
5553 writeln!(
5554 code,
5555 " let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5556 )?;
5557 writeln!(
5558 code,
5559 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5560 )?;
5561 writeln!(
5562 code,
5563 " enc.dispatch_thread_groups(grid_size, tg_size);"
5564 )?;
5565 writeln!(
5566 code,
5567 " }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5568 )?;
5569 writeln!(
5570 code,
5571 " enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5572 )?;
5573 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5574 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5575 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5576 writeln!(
5577 code,
5578 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5579 )?;
5580 writeln!(
5581 code,
5582 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5583 )?;
5584 writeln!(
5585 code,
5586 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5587 )?;
5588 writeln!(code, " let row_tgs = rows / MMA_ROW_TILE;")?;
5589 writeln!(
5590 code,
5591 " let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5592 )?;
5593 writeln!(
5594 code,
5595 " let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
5596 )?;
5597 writeln!(
5598 code,
5599 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5600 )?;
5601 writeln!(
5602 code,
5603 " enc.dispatch_thread_groups(grid_size, tg_size);"
5604 )?;
5605 writeln!(code, " }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5606 writeln!(
5607 code,
5608 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5609 )?;
5610 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5611 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5612 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5613 writeln!(
5614 code,
5615 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5616 )?;
5617 writeln!(
5618 code,
5619 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5620 )?;
5621 writeln!(
5622 code,
5623 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5624 )?;
5625 writeln!(
5626 code,
5627 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5628 )?;
5629 writeln!(
5630 code,
5631 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5632 )?;
5633 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5634 writeln!(
5635 code,
5636 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5637 )?;
5638 writeln!(
5639 code,
5640 " enc.dispatch_thread_groups(grid_size, tg_size);"
5641 )?;
5642 writeln!(code, " }} else {{")?;
5643 writeln!(
5644 code,
5645 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5646 )?;
5647 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5648 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5649 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5650 writeln!(
5651 code,
5652 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5653 )?;
5654 writeln!(
5655 code,
5656 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5657 )?;
5658 writeln!(
5659 code,
5660 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5661 )?;
5662 writeln!(
5663 code,
5664 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5665 )?;
5666 writeln!(
5667 code,
5668 " let num_tg = (row_tgs * num_tokens) as u64;"
5669 )?;
5670 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5671 writeln!(
5672 code,
5673 " let grid_size = MTLSize::new(num_tg, 1, 1);"
5674 )?;
5675 writeln!(
5676 code,
5677 " enc.dispatch_thread_groups(grid_size, tg_size);"
5678 )?;
5679 writeln!(code, " }}")?;
5680 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5681 writeln!(code, " }}")?;
5682 writeln!(code)?;
5683
5684 writeln!(
5686 code,
5687 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5688 )?;
5689 writeln!(
5690 code,
5691 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5692 )?;
5693 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5694 writeln!(code, " let r: u32 = rows as u32;")?;
5695 writeln!(code, " let c: u32 = cols as u32;")?;
5696 writeln!(
5697 code,
5698 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5699 )?;
5700 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5701 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5702 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5703 writeln!(
5704 code,
5705 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5706 )?;
5707 writeln!(
5708 code,
5709 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5710 )?;
5711 writeln!(
5712 code,
5713 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5714 )?;
5715 writeln!(
5716 code,
5717 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
5718 )?;
5719 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5720 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5721 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5722 writeln!(
5723 code,
5724 " enc.dispatch_thread_groups(grid_size, tg_size);"
5725 )?;
5726 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5727 writeln!(code, " }}")?;
5728 writeln!(code)?;
5729
5730 if config.qkv_bias {
5732 writeln!(
5733 code,
5734 " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5735 )?;
5736 writeln!(
5737 code,
5738 " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5739 )?;
5740 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5741 writeln!(code, " let r: u32 = rows as u32;")?;
5742 writeln!(
5743 code,
5744 " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5745 )?;
5746 writeln!(code, " enc.set_buffer(0, Some(out), 0);")?;
5747 writeln!(code, " enc.set_buffer(1, Some(bias), 0);")?;
5748 writeln!(
5749 code,
5750 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5751 )?;
5752 writeln!(
5753 code,
5754 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5755 )?;
5756 writeln!(code, " let total = (num_tokens * rows) as u64;")?;
5757 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5758 writeln!(
5759 code,
5760 " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5761 )?;
5762 writeln!(
5763 code,
5764 " enc.dispatch_thread_groups(grid_size, tg_size);"
5765 )?;
5766 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5767 writeln!(code, " }}")?;
5768 writeln!(code)?;
5769 }
5770
5771 writeln!(
5773 code,
5774 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5775 )?;
5776 writeln!(
5777 code,
5778 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5779 )?;
5780 writeln!(
5781 code,
5782 " /// `row_stride` is the number of floats per token row in the batch buffer."
5783 )?;
5784 writeln!(
5785 code,
5786 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5787 )?;
5788 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5789 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5790 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5791 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5792 writeln!(
5793 code,
5794 " let pairs_per_token = num_heads * (head_dim / 2);"
5795 )?;
5796 writeln!(
5797 code,
5798 " let total_pairs = num_tokens * pairs_per_token;"
5799 )?;
5800 writeln!(
5813 code,
5814 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5815 )?;
5816 writeln!(code, " for t in 0..num_tokens {{")?;
5817 writeln!(
5818 code,
5819 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5820 )?;
5821 writeln!(
5822 code,
5823 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5824 )?;
5825 writeln!(
5826 code,
5827 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5828 )?;
5829 writeln!(
5830 code,
5831 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
5832 )?;
5833 writeln!(
5834 code,
5835 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5836 )?;
5837 writeln!(
5838 code,
5839 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5840 )?;
5841 writeln!(
5842 code,
5843 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5844 )?;
5845 writeln!(
5846 code,
5847 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5848 )?;
5849 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5850 writeln!(
5851 code,
5852 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5853 )?;
5854 writeln!(
5855 code,
5856 " enc.dispatch_thread_groups(grid_size, tg_size);"
5857 )?;
5858 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5859 writeln!(code, " }}")?;
5860 writeln!(code, " }}")?;
5861 writeln!(code)?;
5862
5863 writeln!(
5865 code,
5866 " /// Dispatch batched fused SiLU-multiply for M tokens."
5867 )?;
5868 writeln!(
5869 code,
5870 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5871 )?;
5872 writeln!(code, " let count: u32 = n as u32;")?;
5873 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5874 writeln!(
5875 code,
5876 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5877 )?;
5878 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5879 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5880 writeln!(
5881 code,
5882 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5883 )?;
5884 writeln!(
5885 code,
5886 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5887 )?;
5888 writeln!(code, " let total = num_tokens * n;")?;
5889 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5890 writeln!(
5891 code,
5892 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5893 )?;
5894 writeln!(
5895 code,
5896 " enc.dispatch_thread_groups(grid_size, tg_size);"
5897 )?;
5898 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5899 writeln!(code, " }}")?;
5900 writeln!(code)?;
5901
5902 writeln!(
5904 code,
5905 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5906 )?;
5907 writeln!(
5908 code,
5909 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5910 )?;
5911 writeln!(code, " let count: u32 = total_n as u32;")?;
5912 writeln!(
5913 code,
5914 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5915 )?;
5916 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5917 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5918 writeln!(
5919 code,
5920 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5921 )?;
5922 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5923 writeln!(
5924 code,
5925 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5926 )?;
5927 writeln!(
5928 code,
5929 " enc.dispatch_thread_groups(grid_size, tg_size);"
5930 )?;
5931 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5932 writeln!(code, " }}")?;
5933 writeln!(code)?;
5934
5935 writeln!(
5937 code,
5938 " /// Copy src to dst using compute copy kernel (for batch residual init)."
5939 )?;
5940 writeln!(
5941 code,
5942 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5943 )?;
5944 writeln!(code, " let n: u32 = count as u32;")?;
5945 writeln!(
5946 code,
5947 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5948 )?;
5949 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5950 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5951 writeln!(
5952 code,
5953 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5954 )?;
5955 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5956 writeln!(
5957 code,
5958 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5959 )?;
5960 writeln!(
5961 code,
5962 " enc.dispatch_thread_groups(grid_size, tg_size);"
5963 )?;
5964 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5965 writeln!(code, " }}")?;
5966 writeln!(code)?;
5967
5968 writeln!(
5970 code,
5971 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5972 )?;
5973 writeln!(
5974 code,
5975 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5976 )?;
5977 writeln!(code, " let n: u32 = count as u32;")?;
5978 writeln!(
5979 code,
5980 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5981 )?;
5982 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5983 writeln!(
5984 code,
5985 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5986 )?;
5987 writeln!(
5988 code,
5989 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5990 )?;
5991 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5992 writeln!(
5993 code,
5994 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5995 )?;
5996 writeln!(
5997 code,
5998 " enc.dispatch_thread_groups(grid_size, tg_size);"
5999 )?;
6000 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6001 writeln!(code, " }}")?;
6002 writeln!(code)?;
6003
6004 writeln!(
6006 code,
6007 " /// Copy from src at byte offset to dst at float offset."
6008 )?;
6009 writeln!(
6010 code,
6011 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
6012 )?;
6013 writeln!(code, " let n: u32 = count as u32;")?;
6014 writeln!(
6015 code,
6016 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
6017 )?;
6018 writeln!(
6019 code,
6020 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
6021 )?;
6022 writeln!(
6023 code,
6024 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
6025 )?;
6026 writeln!(
6027 code,
6028 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
6029 )?;
6030 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6031 writeln!(
6032 code,
6033 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
6034 )?;
6035 writeln!(
6036 code,
6037 " enc.dispatch_thread_groups(grid_size, tg_size);"
6038 )?;
6039 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6040 writeln!(code, " }}")?;
6041 writeln!(code)?;
6042
6043 writeln!(
6045 code,
6046 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
6047 )?;
6048 writeln!(
6049 code,
6050 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
6051 )?;
6052 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6053 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6054 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6055 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6056 writeln!(code, " let so: u32 = src_offset as u32;")?;
6057 writeln!(
6058 code,
6059 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
6060 )?;
6061 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6062 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
6063 writeln!(
6064 code,
6065 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6066 )?;
6067 writeln!(
6068 code,
6069 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6070 )?;
6071 writeln!(
6072 code,
6073 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6074 )?;
6075 writeln!(
6076 code,
6077 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6078 )?;
6079 writeln!(
6080 code,
6081 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
6082 )?;
6083 writeln!(code, " let total = num_tokens * kv_dim;")?;
6084 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6085 writeln!(
6086 code,
6087 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6088 )?;
6089 writeln!(
6090 code,
6091 " enc.dispatch_thread_groups(grid_size, tg_size);"
6092 )?;
6093 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6094 writeln!(code, " }}")?;
6095 writeln!(code)?;
6096
6097 writeln!(
6099 code,
6100 " /// Dispatch batched causal attention: one dispatch for all M tokens."
6101 )?;
6102 writeln!(
6103 code,
6104 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
6105 )?;
6106 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6107 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6108 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
6109 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6110 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6111 writeln!(code, " let qs: u32 = q_stride as u32;")?;
6112 writeln!(code, " let max_seq = base_pos + num_tokens;")?;
6122 writeln!(code, " let _ = max_seq;")?;
6123 writeln!(
6124 code,
6125 " let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
6126 )?;
6127 writeln!(code, " .map(|v| v == \"0\").unwrap_or(false);")?;
6128 writeln!(
6129 code,
6130 " let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
6131 )?;
6132 writeln!(code, " if use_mma_flash {{")?;
6133 writeln!(
6134 code,
6135 " let pipe = &self.attention_mma_flash_batch_pipeline;"
6136 )?;
6137 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6138 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6139 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6140 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6141 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6142 writeln!(
6143 code,
6144 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6145 )?;
6146 writeln!(
6147 code,
6148 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6149 )?;
6150 writeln!(
6151 code,
6152 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6153 )?;
6154 writeln!(
6155 code,
6156 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6157 )?;
6158 writeln!(
6159 code,
6160 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6161 )?;
6162 writeln!(
6163 code,
6164 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6165 )?;
6166 writeln!(
6167 code,
6168 " // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
6169 )?;
6170 writeln!(code, " let tg_size = MTLSize::new(128, 1, 1);")?;
6171 writeln!(
6172 code,
6173 " let q_blocks = ((num_tokens + 7) / 8) as u64;"
6174 )?;
6175 writeln!(
6176 code,
6177 " let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
6178 )?;
6179 writeln!(
6180 code,
6181 " enc.dispatch_thread_groups(grid_size, tg_size);"
6182 )?;
6183 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6184 writeln!(code, " return;")?;
6185 writeln!(code, " }}")?;
6186 writeln!(code, " let pipe = &self.attention_batch_pipeline;")?;
6187 writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
6188 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
6189 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
6190 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
6191 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
6192 writeln!(
6193 code,
6194 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6195 )?;
6196 writeln!(
6197 code,
6198 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6199 )?;
6200 writeln!(
6201 code,
6202 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
6203 )?;
6204 writeln!(
6205 code,
6206 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6207 )?;
6208 writeln!(
6209 code,
6210 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6211 )?;
6212 writeln!(
6213 code,
6214 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6215 )?;
6216 writeln!(
6217 code,
6218 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
6219 )?;
6220 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6221 writeln!(
6222 code,
6223 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
6224 )?;
6225 writeln!(
6226 code,
6227 " enc.dispatch_thread_groups(grid_size, tg_size);"
6228 )?;
6229 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6230 writeln!(code, " }}")?;
6231 writeln!(code)?;
6232
6233 writeln!(
6235 code,
6236 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
6237 )?;
6238 writeln!(
6239 code,
6240 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
6241 )?;
6242 writeln!(
6243 code,
6244 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
6245 )?;
6246 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6247 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6248 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
6249 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
6250 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
6251 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
6252 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
6253 writeln!(
6254 code,
6255 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
6256 )?;
6257 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
6258 writeln!(
6259 code,
6260 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6261 )?;
6262 writeln!(
6263 code,
6264 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6265 )?;
6266 writeln!(
6267 code,
6268 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
6269 )?;
6270 writeln!(
6271 code,
6272 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
6273 )?;
6274 writeln!(
6275 code,
6276 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
6277 )?;
6278 writeln!(
6279 code,
6280 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
6281 )?;
6282 writeln!(
6283 code,
6284 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
6285 )?;
6286 writeln!(
6287 code,
6288 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
6289 )?;
6290 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6291 writeln!(
6292 code,
6293 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
6294 )?;
6295 writeln!(
6296 code,
6297 " enc.dispatch_thread_groups(grid_size, tg_size);"
6298 )?;
6299 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6300 writeln!(code, " }}")?;
6301 writeln!(code)?;
6302
6303 writeln!(
6305 code,
6306 " /// Dispatch fused K+V cache copy in one kernel launch."
6307 )?;
6308 writeln!(
6309 code,
6310 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
6311 )?;
6312 writeln!(
6313 code,
6314 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
6315 )?;
6316 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
6317 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
6318 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6319 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6320 writeln!(code, " let ko: u32 = k_offset as u32;")?;
6321 writeln!(code, " let vo: u32 = v_offset as u32;")?;
6322 writeln!(
6323 code,
6324 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6325 )?;
6326 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6327 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
6328 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
6329 writeln!(
6330 code,
6331 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6332 )?;
6333 writeln!(
6334 code,
6335 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6336 )?;
6337 writeln!(
6338 code,
6339 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6340 )?;
6341 writeln!(
6342 code,
6343 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6344 )?;
6345 writeln!(
6346 code,
6347 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6348 )?;
6349 writeln!(
6350 code,
6351 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6352 )?;
6353 writeln!(
6354 code,
6355 " let total = num_tokens * kv_dim * 2; // K + V"
6356 )?;
6357 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6358 writeln!(
6359 code,
6360 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6361 )?;
6362 writeln!(
6363 code,
6364 " enc.dispatch_thread_groups(grid_size, tg_size);"
6365 )?;
6366 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6367 writeln!(code, " }}")?;
6368
6369 writeln!(code, "}}")?;
6370 writeln!(code)?;
6371
6372 Ok(())
6373}
6374
6375fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6376 writeln!(
6377 code,
6378 "// ── Helper functions ──────────────────────────────────"
6379 )?;
6380 writeln!(code)?;
6381 writeln!(
6382 code,
6383 "/// Create a compute pipeline from a named function in the Metal library."
6384 )?;
6385 writeln!(
6386 code,
6387 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6388 )?;
6389 writeln!(
6390 code,
6391 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6392 )?;
6393 writeln!(
6394 code,
6395 " device.new_compute_pipeline_state_with_function(&func)"
6396 )?;
6397 writeln!(
6398 code,
6399 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6400 )?;
6401 writeln!(code, "}}")?;
6402 writeln!(code)?;
6403
6404 Ok(())
6405}
6406
6407fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6412 let _sanitized = sanitize_name(model_name);
6413 let _vocab = config.vocab_size;
6414
6415 let mut code = String::with_capacity(16 * 1024);
6416 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6417 writeln!(
6418 code,
6419 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6420 )?;
6421 writeln!(code)?;
6422 writeln!(code, "mod model;")?;
6423 writeln!(code)?;
6424 writeln!(code, "use std::io::Write;")?;
6425 writeln!(code, "use std::time::Instant;")?;
6426 writeln!(code, "use serde::Deserialize;")?;
6427 writeln!(code)?;
6428
6429 writeln!(code, "fn main() {{")?;
6431 writeln!(
6432 code,
6433 " let args: Vec<String> = std::env::args().collect();"
6434 )?;
6435 writeln!(code)?;
6436 writeln!(
6437 code,
6438 " // Detect --serve mode (only requires weights + tokenizer)"
6439 )?;
6440 writeln!(
6441 code,
6442 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
6443 )?;
6444 writeln!(code)?;
6445 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
6446 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6447 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6448 writeln!(code, " std::process::exit(1);")?;
6449 writeln!(code, " }}")?;
6450 writeln!(code)?;
6451 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
6452 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6453 writeln!(code, " std::process::exit(1);")?;
6454 writeln!(code, " }}")?;
6455 writeln!(code)?;
6456 writeln!(code, " let weights_path = &args[1];")?;
6457 writeln!(code, " let tokenizer_path = &args[2];")?;
6458 writeln!(code)?;
6459 writeln!(code, " // Parse optional flags")?;
6460 writeln!(code, " let mut max_tokens: usize = 128;")?;
6461 writeln!(code, " let mut port: u16 = 8080;")?;
6462 writeln!(
6463 code,
6464 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6465 )?;
6466 writeln!(
6467 code,
6468 " let profile = args.iter().any(|a| a == \"--profile\");"
6469 )?;
6470 writeln!(code, " let mut i = 3;")?;
6471 writeln!(code, " while i < args.len() {{")?;
6472 writeln!(
6473 code,
6474 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6475 )?;
6476 writeln!(
6477 code,
6478 " max_tokens = args[i + 1].parse().unwrap_or(128);"
6479 )?;
6480 writeln!(code, " i += 2;")?;
6481 writeln!(
6482 code,
6483 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6484 )?;
6485 writeln!(
6486 code,
6487 " port = args[i + 1].parse().unwrap_or(8080);"
6488 )?;
6489 writeln!(code, " i += 2;")?;
6490 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
6491 writeln!(code, " i += 1;")?;
6492 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
6493 writeln!(code, " i += 1;")?;
6494 writeln!(code, " }} else {{")?;
6495 writeln!(code, " i += 1;")?;
6496 writeln!(code, " }}")?;
6497 writeln!(code, " }}")?;
6498 writeln!(code)?;
6499
6500 writeln!(
6502 code,
6503 " // Memory-map weights for zero-copy loading on Apple Silicon"
6504 )?;
6505 writeln!(
6506 code,
6507 " let weights_file = std::fs::File::open(weights_path)"
6508 )?;
6509 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6510 writeln!(
6511 code,
6512 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6513 )?;
6514 writeln!(code)?;
6515 writeln!(code, " // Load tokenizer")?;
6516 writeln!(
6517 code,
6518 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6519 )?;
6520 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6521 writeln!(code)?;
6522 writeln!(code, " // Create Metal model")?;
6523 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
6524 writeln!(
6525 code,
6526 " let mut model = model::MetalModel::new(&weights_mmap);"
6527 )?;
6528 writeln!(code)?;
6529
6530 writeln!(code, " if serve_mode {{")?;
6532 writeln!(code, " serve(model, tokenizer, port);")?;
6533 writeln!(code, " }} else {{")?;
6534 writeln!(code, " let prompt = &args[3];")?;
6535 writeln!(
6536 code,
6537 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6538 )?;
6539 writeln!(code, " }}")?;
6540 writeln!(code, "}}")?;
6541 writeln!(code)?;
6542
6543 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6545 writeln!(code, " // Tokenize prompt")?;
6546 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
6547 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6548 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6549 writeln!(code)?;
6550 writeln!(
6551 code,
6552 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6553 )?;
6554 writeln!(
6555 code,
6556 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
6557 )?;
6558 writeln!(
6559 code,
6560 " // The last token uses synchronous forward() to get logits."
6561 )?;
6562 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
6563 writeln!(code, " let prefill_start = Instant::now();")?;
6564 writeln!(code, " let logits = if prompt_len > 1 {{")?;
6565 writeln!(
6566 code,
6567 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6568 )?;
6569 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
6570 writeln!(code, " }} else {{")?;
6571 writeln!(code, " model.forward(prompt_tokens[0])")?;
6572 writeln!(code, " }};")?;
6573 writeln!(
6574 code,
6575 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6576 )?;
6577 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
6578 writeln!(
6579 code,
6580 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6581 )?;
6582 writeln!(
6583 code,
6584 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6585 )?;
6586 writeln!(code)?;
6587 writeln!(code, " // Generate tokens")?;
6588 writeln!(code, " let mut next_token = argmax(&logits);")?;
6589 writeln!(code, " let gen_start = Instant::now();")?;
6590 writeln!(code, " let mut generated_count: usize = 0;")?;
6591 writeln!(code)?;
6592 writeln!(code, " for _ in 0..max_tokens {{")?;
6593 writeln!(
6594 code,
6595 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6596 )?;
6597 writeln!(code, " if !quiet {{")?;
6598 writeln!(code, " print!(\"{{}}\", text);")?;
6599 writeln!(code, " std::io::stdout().flush().ok();")?;
6600 writeln!(code, " }}")?;
6601 writeln!(code, " }}")?;
6602 writeln!(code, " generated_count += 1;")?;
6603 writeln!(code)?;
6604 writeln!(
6605 code,
6606 " // Use profiling forward for first token when --profile is set"
6607 )?;
6608 writeln!(
6609 code,
6610 " let logits = if profile && generated_count == 1 {{"
6611 )?;
6612 writeln!(code, " model.forward_profile(next_token)")?;
6613 writeln!(code, " }} else {{")?;
6614 writeln!(code, " model.forward(next_token)")?;
6615 writeln!(code, " }};")?;
6616 writeln!(code, " next_token = argmax(&logits);")?;
6617 writeln!(code)?;
6618 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
6619 writeln!(code, " if next_token == 2 {{")?;
6620 writeln!(code, " break;")?;
6621 writeln!(code, " }}")?;
6622 writeln!(code)?;
6623 writeln!(
6624 code,
6625 " // Yield between tokens to reduce sustained GPU thermal load."
6626 )?;
6627 writeln!(
6628 code,
6629 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6630 )?;
6631 writeln!(
6632 code,
6633 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6634 )?;
6635 writeln!(
6636 code,
6637 " // briefly, providing a micro-break that helps sustain peak throughput."
6638 )?;
6639 writeln!(code, " std::thread::yield_now();")?;
6640 writeln!(code, " }}")?;
6641 writeln!(code, " if !quiet {{")?;
6642 writeln!(code, " println!();")?;
6643 writeln!(code, " }}")?;
6644 writeln!(
6645 code,
6646 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6647 )?;
6648 writeln!(
6649 code,
6650 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6651 )?;
6652 writeln!(
6653 code,
6654 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6655 )?;
6656 writeln!(code, "}}")?;
6657 writeln!(code)?;
6658
6659 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6661 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6662 writeln!(code, " logits.iter()")?;
6663 writeln!(code, " .enumerate()")?;
6664 writeln!(
6665 code,
6666 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6667 )?;
6668 writeln!(code, " .map(|(i, _)| i as u32)")?;
6669 writeln!(code, " .unwrap_or(0)")?;
6670 writeln!(code, "}}")?;
6671 writeln!(code)?;
6672
6673 writeln!(
6675 code,
6676 "// -----------------------------------------------------------------------"
6677 )?;
6678 writeln!(code, "// OpenAI-compatible API server")?;
6679 writeln!(
6680 code,
6681 "// -----------------------------------------------------------------------"
6682 )?;
6683 writeln!(code)?;
6684 writeln!(code, "#[derive(Deserialize)]")?;
6685 writeln!(code, "struct ChatRequest {{")?;
6686 writeln!(code, " messages: Vec<ChatMessage>,")?;
6687 writeln!(code, " #[serde(default)]")?;
6688 writeln!(code, " stream: Option<bool>,")?;
6689 writeln!(code, " #[serde(default)]")?;
6690 writeln!(code, " max_tokens: Option<usize>,")?;
6691 writeln!(code, " #[serde(default)]")?;
6692 writeln!(code, " temperature: Option<f32>,")?;
6693 writeln!(code, " #[serde(default)]")?;
6694 writeln!(code, " model: Option<String>,")?;
6695 writeln!(code, "}}")?;
6696 writeln!(code)?;
6697 writeln!(code, "#[derive(Deserialize)]")?;
6698 writeln!(code, "struct ChatMessage {{")?;
6699 writeln!(code, " role: String,")?;
6700 writeln!(code, " content: String,")?;
6701 writeln!(code, "}}")?;
6702 writeln!(code)?;
6703
6704 writeln!(
6706 code,
6707 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6708 )?;
6709 writeln!(code, " let mut prompt = String::new();")?;
6710 writeln!(code, " for msg in messages {{")?;
6711 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6712 writeln!(code, " }}")?;
6713 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6714 writeln!(code, " prompt")?;
6715 writeln!(code, "}}")?;
6716 writeln!(code)?;
6717
6718 writeln!(
6720 code,
6721 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6722 )?;
6723 writeln!(code, " let len = tokens.len();")?;
6724 writeln!(code, " if len > 1 {{")?;
6725 writeln!(
6726 code,
6727 " model.forward_prefill_batch(&tokens[..len - 1]);"
6728 )?;
6729 writeln!(code, " }}")?;
6730 writeln!(code, " model.forward(tokens[len - 1])")?;
6731 writeln!(code, "}}")?;
6732 writeln!(code)?;
6733
6734 writeln!(
6736 code,
6737 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6738 )?;
6739 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6740 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
6741 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6742 writeln!(
6743 code,
6744 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6745 )?;
6746 writeln!(code, " eprintln!(\"Endpoints:\");")?;
6747 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
6748 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
6749 writeln!(code, " eprintln!(\" GET /health\");")?;
6750 writeln!(code)?;
6751 writeln!(code, " for request in server.incoming_requests() {{")?;
6752 writeln!(code, " let method = request.method().to_string();")?;
6753 writeln!(code, " let url = request.url().to_string();")?;
6754 writeln!(code)?;
6755 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
6756
6757 writeln!(
6759 code,
6760 " (\"POST\", \"/v1/chat/completions\") => {{"
6761 )?;
6762 writeln!(
6763 code,
6764 " handle_chat_completion(&mut model, &tokenizer, request);"
6765 )?;
6766 writeln!(code, " }}")?;
6767
6768 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
6770 writeln!(code, " let body = serde_json::json!({{")?;
6771 writeln!(code, " \"object\": \"list\",")?;
6772 writeln!(code, " \"data\": [{{")?;
6773 writeln!(code, " \"id\": \"forgellm-metal\",")?;
6774 writeln!(code, " \"object\": \"model\",")?;
6775 writeln!(code, " \"owned_by\": \"forgellm\"")?;
6776 writeln!(code, " }}]")?;
6777 writeln!(code, " }});")?;
6778 writeln!(
6779 code,
6780 " let resp = tiny_http::Response::from_string(body.to_string())"
6781 )?;
6782 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6783 writeln!(code, " request.respond(resp).ok();")?;
6784 writeln!(code, " }}")?;
6785
6786 writeln!(code, " (\"GET\", \"/health\") => {{")?;
6788 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6789 writeln!(code, " request.respond(resp).ok();")?;
6790 writeln!(code, " }}")?;
6791
6792 writeln!(code, " _ => {{")?;
6794 writeln!(
6795 code,
6796 " let resp = tiny_http::Response::from_string(\"Not Found\")"
6797 )?;
6798 writeln!(code, " .with_status_code(404);")?;
6799 writeln!(code, " request.respond(resp).ok();")?;
6800 writeln!(code, " }}")?;
6801 writeln!(code, " }}")?;
6802 writeln!(code, " }}")?;
6803 writeln!(code, "}}")?;
6804 writeln!(code)?;
6805
6806 writeln!(code, "fn handle_chat_completion(")?;
6808 writeln!(code, " model: &mut model::MetalModel,")?;
6809 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
6810 writeln!(code, " mut request: tiny_http::Request,")?;
6811 writeln!(code, ") {{")?;
6812 writeln!(code, " // Read request body")?;
6813 writeln!(code, " let mut body = String::new();")?;
6814 writeln!(
6815 code,
6816 " if request.as_reader().read_to_string(&mut body).is_err() {{"
6817 )?;
6818 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6819 writeln!(code, " .with_status_code(400);")?;
6820 writeln!(code, " request.respond(resp).ok();")?;
6821 writeln!(code, " return;")?;
6822 writeln!(code, " }}")?;
6823 writeln!(code)?;
6824 writeln!(code, " // Parse JSON")?;
6825 writeln!(
6826 code,
6827 " let req: ChatRequest = match serde_json::from_str(&body) {{"
6828 )?;
6829 writeln!(code, " Ok(r) => r,")?;
6830 writeln!(code, " Err(e) => {{")?;
6831 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6832 writeln!(code, " .with_status_code(400);")?;
6833 writeln!(code, " request.respond(resp).ok();")?;
6834 writeln!(code, " return;")?;
6835 writeln!(code, " }}")?;
6836 writeln!(code, " }};")?;
6837 writeln!(code)?;
6838 writeln!(
6839 code,
6840 " let prompt = format_chat_messages(&req.messages);"
6841 )?;
6842 writeln!(
6843 code,
6844 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6845 )?;
6846 writeln!(code, " Ok(e) => e,")?;
6847 writeln!(code, " Err(e) => {{")?;
6848 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6849 writeln!(code, " .with_status_code(500);")?;
6850 writeln!(code, " request.respond(resp).ok();")?;
6851 writeln!(code, " return;")?;
6852 writeln!(code, " }}")?;
6853 writeln!(code, " }};")?;
6854 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6855 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
6856 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
6857 writeln!(
6858 code,
6859 " let _temperature = req.temperature.unwrap_or(1.0);"
6860 )?;
6861 writeln!(code)?;
6862
6863 writeln!(code, " model.reset();")?;
6865 writeln!(code)?;
6866
6867 writeln!(code, " let prefill_start = Instant::now();")?;
6869 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
6870 writeln!(
6871 code,
6872 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6873 )?;
6874 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
6875 writeln!(code, " let mut next_token = argmax(&logits);")?;
6876 writeln!(code)?;
6877
6878 writeln!(code, " if stream {{")?;
6879
6880 writeln!(
6882 code,
6883 " // SSE streaming: generate tokens and build SSE body"
6884 )?;
6885 writeln!(code, " let gen_start = Instant::now();")?;
6886 writeln!(code, " let mut generated_count: usize = 0;")?;
6887 writeln!(code, " let mut sse_body = String::new();")?;
6888 writeln!(code, " for _ in 0..max_tokens {{")?;
6889 writeln!(code, " if next_token == 2 {{ break; }}")?;
6890 writeln!(
6891 code,
6892 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6893 )?;
6894 writeln!(
6895 code,
6896 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
6897 )?;
6898 writeln!(
6899 code,
6900 " // escaped includes surrounding quotes, strip them"
6901 )?;
6902 writeln!(
6903 code,
6904 " let inner = &escaped[1..escaped.len()-1];"
6905 )?;
6906 writeln!(code, " sse_body.push_str(&format!(")?;
6907 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6908 writeln!(code, " inner")?;
6909 writeln!(code, " ));")?;
6910 writeln!(code, " }}")?;
6911 writeln!(code, " generated_count += 1;")?;
6912 writeln!(code, " let logits = model.forward(next_token);")?;
6913 writeln!(code, " next_token = argmax(&logits);")?;
6914 writeln!(code, " }}")?;
6915 writeln!(
6916 code,
6917 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6918 )?;
6919 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6920 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
6921 writeln!(code)?;
6922 writeln!(
6923 code,
6924 " // Final chunk with finish_reason, timing, and DONE sentinel"
6925 )?;
6926 writeln!(code, " sse_body.push_str(&format!(")?;
6927 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6928 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6929 writeln!(code, " ));")?;
6930 writeln!(code)?;
6931 writeln!(
6932 code,
6933 " let resp = tiny_http::Response::from_string(sse_body)"
6934 )?;
6935 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6936 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6937 writeln!(code, " request.respond(resp).ok();")?;
6938
6939 writeln!(code, " }} else {{")?;
6940
6941 writeln!(
6943 code,
6944 " // Non-streaming: generate all tokens, return JSON"
6945 )?;
6946 writeln!(code, " let gen_start = Instant::now();")?;
6947 writeln!(code, " let mut generated_count: usize = 0;")?;
6948 writeln!(code, " let mut generated = String::new();")?;
6949 writeln!(code, " for _ in 0..max_tokens {{")?;
6950 writeln!(code, " if next_token == 2 {{ break; }}")?;
6951 writeln!(
6952 code,
6953 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6954 )?;
6955 writeln!(code, " generated.push_str(&text);")?;
6956 writeln!(code, " }}")?;
6957 writeln!(code, " generated_count += 1;")?;
6958 writeln!(code, " let logits = model.forward(next_token);")?;
6959 writeln!(code, " next_token = argmax(&logits);")?;
6960 writeln!(code, " }}")?;
6961 writeln!(
6962 code,
6963 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6964 )?;
6965 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6966 writeln!(code)?;
6967 writeln!(code, " let resp_json = serde_json::json!({{")?;
6968 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
6969 writeln!(code, " \"object\": \"chat.completion\",")?;
6970 writeln!(code, " \"choices\": [{{")?;
6971 writeln!(code, " \"index\": 0,")?;
6972 writeln!(code, " \"message\": {{")?;
6973 writeln!(code, " \"role\": \"assistant\",")?;
6974 writeln!(code, " \"content\": generated")?;
6975 writeln!(code, " }},")?;
6976 writeln!(code, " \"finish_reason\": \"stop\"")?;
6977 writeln!(code, " }}],")?;
6978 writeln!(code, " \"usage\": {{")?;
6979 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
6980 writeln!(
6981 code,
6982 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6983 )?;
6984 writeln!(
6985 code,
6986 " \"generation_tokens\": generated_count,"
6987 )?;
6988 writeln!(
6989 code,
6990 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6991 )?;
6992 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
6993 writeln!(code, " }}")?;
6994 writeln!(code, " }});")?;
6995 writeln!(
6996 code,
6997 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
6998 )?;
6999 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
7000 writeln!(code, " request.respond(resp).ok();")?;
7001 writeln!(code, " }}")?;
7002 writeln!(code, "}}")?;
7003
7004 Ok(code)
7005}
7006
7007#[cfg(test)]
7012mod tests {
7013 use super::*;
7014 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
7015
7016 fn minimal_config() -> ModelConfig {
7017 ModelConfig {
7018 architecture: Architecture::Llama,
7019 hidden_size: 64,
7020 intermediate_size: 128,
7021 num_layers: 2,
7022 num_attention_heads: 4,
7023 num_kv_heads: 4,
7024 head_dim: 16,
7025 vocab_size: 256,
7026 max_seq_len: 512,
7027 rms_norm_eps: 1e-5,
7028 rope_theta: 10000.0,
7029 dtype: DType::F32,
7030 sliding_window_size: None,
7031 qkv_bias: false,
7032 }
7033 }
7034
7035 fn minimal_graph() -> Graph {
7036 Graph::new("test-metal").with_config(minimal_config())
7037 }
7038
7039 #[test]
7040 fn generate_metal_project_creates_files() {
7041 let dir = tempfile::tempdir().unwrap();
7042 let graph = minimal_graph();
7043 generate_metal_project(&graph, dir.path(), "test-model").unwrap();
7044
7045 assert!(
7046 dir.path().join("Cargo.toml").exists(),
7047 "Cargo.toml should be created"
7048 );
7049 assert!(
7050 dir.path().join("src/model.rs").exists(),
7051 "src/model.rs should be created"
7052 );
7053 assert!(
7054 dir.path().join("src/main.rs").exists(),
7055 "src/main.rs should be created"
7056 );
7057 assert!(
7058 dir.path().join("shaders/kernels.metal").exists(),
7059 "shaders/kernels.metal should be created"
7060 );
7061 }
7062
7063 #[test]
7064 fn generated_cargo_toml_has_metal_dep() {
7065 let toml = generate_cargo_toml("my-model");
7066 assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
7067 assert!(
7068 toml.contains("tokenizers"),
7069 "Cargo.toml should depend on tokenizers"
7070 );
7071 assert!(
7072 toml.contains("memmap2"),
7073 "Cargo.toml should depend on memmap2"
7074 );
7075 assert!(toml.contains("half"), "Cargo.toml should depend on half");
7076 }
7077
7078 #[test]
7079 fn generated_model_rs_contains_metal_code() {
7080 let config = minimal_config();
7081 let model_rs = generate_model_rs(&config).unwrap();
7082
7083 assert!(
7084 model_rs.contains("pub struct MetalModel"),
7085 "model.rs should define MetalModel struct"
7086 );
7087 assert!(
7088 model_rs.contains("matmul_pipeline: ComputePipelineState"),
7089 "MetalModel should have matmul_pipeline field"
7090 );
7091 assert!(
7092 model_rs.contains("Device::system_default()"),
7093 "model.rs should use Metal device"
7094 );
7095 assert!(
7096 model_rs.contains("new_library_with_source"),
7097 "model.rs should compile Metal shaders"
7098 );
7099 assert!(
7100 model_rs.contains("fn new(weights: &[u8])"),
7101 "MetalModel should implement new()"
7102 );
7103 assert!(
7104 model_rs.contains("fn forward(&mut self, token_id: u32)"),
7105 "MetalModel should implement forward()"
7106 );
7107 }
7108
7109 #[test]
7110 fn generated_shaders_contain_kernel_names() {
7111 let shaders = generate_metal_shaders(&minimal_config());
7112
7113 assert!(
7114 shaders.contains("kernel void matmul_vec"),
7115 "shaders should contain matmul_vec kernel"
7116 );
7117 assert!(
7118 shaders.contains("kernel void rms_norm"),
7119 "shaders should contain rms_norm kernel"
7120 );
7121 assert!(
7122 shaders.contains("kernel void rope"),
7123 "shaders should contain rope kernel"
7124 );
7125 assert!(
7126 shaders.contains("kernel void softmax"),
7127 "shaders should contain softmax kernel"
7128 );
7129 assert!(
7130 shaders.contains("kernel void silu_mul("),
7131 "shaders should contain silu_mul kernel"
7132 );
7133 assert!(
7134 shaders.contains("kernel void silu_mul_fused"),
7135 "shaders should contain silu_mul_fused kernel"
7136 );
7137 assert!(
7138 shaders.contains("kernel void elementwise_add"),
7139 "shaders should contain elementwise_add kernel"
7140 );
7141 assert!(
7142 shaders.contains("kernel void attention"),
7143 "shaders should contain attention kernel"
7144 );
7145 assert!(
7146 shaders.contains("kernel void add_inplace"),
7147 "shaders should contain add_inplace kernel"
7148 );
7149 assert!(
7150 shaders.contains("kernel void copy_buffer"),
7151 "shaders should contain copy_buffer kernel"
7152 );
7153 assert!(
7154 shaders.contains("kernel void copy_offset"),
7155 "shaders should contain copy_offset kernel"
7156 );
7157 }
7158
7159 #[test]
7160 fn generated_shaders_use_simdgroup_features() {
7161 let shaders = generate_metal_shaders(&minimal_config());
7162
7163 assert!(
7164 shaders.contains("threadgroup_barrier"),
7165 "shaders should use threadgroup barriers"
7166 );
7167 assert!(
7168 shaders.contains("threadgroup float"),
7169 "shaders should use threadgroup shared memory"
7170 );
7171 assert!(
7172 shaders.contains("thread_index_in_threadgroup"),
7173 "shaders should use threadgroup indexing"
7174 );
7175 assert!(
7176 shaders.contains("simd_sum"),
7177 "shaders should use simd_sum for warp-level reduction"
7178 );
7179 assert!(
7180 shaders.contains("simd_max"),
7181 "attention kernel should use simd_max for cooperative softmax"
7182 );
7183 assert!(
7184 shaders.contains("thread_index_in_simdgroup"),
7185 "shaders should use simdgroup lane indexing"
7186 );
7187 assert!(
7188 shaders.contains("simdgroup_index_in_threadgroup"),
7189 "shaders should use simdgroup indexing within threadgroup"
7190 );
7191 assert!(
7192 shaders.contains("float4"),
7193 "matmul_vec should use float4 vectorized loads"
7194 );
7195 }
7196
7197 #[test]
7198 fn generated_main_rs_has_tokenizer_usage() {
7199 let config = minimal_config();
7200 let main_rs = generate_main_rs("test-model", &config).unwrap();
7201
7202 assert!(
7203 main_rs.contains("tokenizers::Tokenizer"),
7204 "main.rs should use tokenizers crate"
7205 );
7206 assert!(
7207 main_rs.contains("MetalModel::new"),
7208 "main.rs should call MetalModel::new"
7209 );
7210 assert!(
7211 main_rs.contains("model.forward"),
7212 "main.rs should call model.forward"
7213 );
7214 assert!(
7215 main_rs.contains("memmap2"),
7216 "main.rs should use memmap2 for zero-copy weight loading"
7217 );
7218 }
7219
7220 #[test]
7221 fn missing_config_returns_error() {
7222 let dir = tempfile::tempdir().unwrap();
7223 let graph = Graph::new("no-config");
7224 let result = generate_metal_project(&graph, dir.path(), "fail");
7225 assert!(
7226 matches!(result, Err(MetalCodegenError::MissingConfig)),
7227 "should fail with MissingConfig when graph has no config"
7228 );
7229 }
7230
7231 #[test]
7232 fn sanitize_name_works() {
7233 assert_eq!(sanitize_name("My Model!"), "my-model");
7234 assert_eq!(sanitize_name("test_model"), "test-model");
7235 assert_eq!(sanitize_name("simple"), "simple");
7236 }
7237
7238 #[test]
7239 fn generated_forward_uses_single_command_buffer() {
7240 let config = minimal_config();
7241 let model_rs = generate_model_rs(&config).unwrap();
7242
7243 let forward_start = model_rs
7246 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7247 .unwrap();
7248 let forward_body = &model_rs[forward_start..];
7249 let forward_end = forward_body
7251 .find("\n pub fn forward_profile")
7252 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7253 .or_else(|| forward_body.find("\n fn dispatch_"))
7254 .unwrap_or(forward_body.len());
7255 let forward_code = &forward_body[..forward_end];
7256
7257 let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
7259 assert_eq!(
7260 cmd_buf_count, 1,
7261 "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
7262 );
7263
7264 let commit_count = forward_code.matches("cmd.commit()").count();
7266 assert_eq!(
7267 commit_count, 1,
7268 "forward() should commit exactly once, found {commit_count}"
7269 );
7270
7271 let wait_count = forward_code.matches("wait_until_completed()").count();
7273 assert!(
7274 wait_count >= 1 && wait_count <= 2,
7275 "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
7276 );
7277 }
7278
7279 #[test]
7280 fn generated_model_has_preallocated_working_buffers() {
7281 let config = minimal_config();
7282 let model_rs = generate_model_rs(&config).unwrap();
7283
7284 for buf_name in &[
7285 "normed_buf",
7286 "qkv_buf",
7287 "attn_out_buf",
7288 "attn_proj_buf",
7289 "gate_up_buf",
7290 "ffn_hidden_buf",
7291 "ffn_out_buf",
7292 "add_tmp_buf",
7293 ] {
7294 assert!(
7295 model_rs.contains(&format!("{buf_name}: Buffer")),
7296 "MetalModel should have pre-allocated {buf_name} field"
7297 );
7298 }
7299 }
7300
7301 #[test]
7302 fn generated_dispatch_helpers_take_compute_encoder_ref() {
7303 let config = minimal_config();
7304 let model_rs = generate_model_rs(&config).unwrap();
7305
7306 for method in &[
7307 "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
7308 "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
7309 "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
7310 "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
7311 "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
7312 "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
7313 "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
7314 "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
7315 "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
7316 "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
7317 "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
7318 "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7319 "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7320 ] {
7321 assert!(
7322 model_rs.contains(method),
7323 "model.rs should contain dispatch helper: {method}"
7324 );
7325 }
7326 }
7327
7328 #[test]
7329 fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7330 let config = minimal_config();
7331 let model_rs = generate_model_rs(&config).unwrap();
7332
7333 let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7335 let helpers_code = &model_rs[helpers_start..];
7336
7337 assert!(
7339 !helpers_code.contains("self.queue.new_command_buffer()"),
7340 "dispatch helpers should not create their own command buffers"
7341 );
7342
7343 assert!(
7345 !helpers_code.contains("new_compute_command_encoder()"),
7346 "dispatch helpers should not create their own compute encoders"
7347 );
7348
7349 assert!(
7351 !helpers_code.contains("end_encoding()"),
7352 "dispatch helpers should not call end_encoding"
7353 );
7354
7355 assert!(
7357 !helpers_code.contains(".commit()"),
7358 "dispatch helpers should not commit command buffers"
7359 );
7360 assert!(
7361 !helpers_code.contains("wait_until_completed"),
7362 "dispatch helpers should not wait on command buffers"
7363 );
7364 }
7365
7366 #[test]
7367 fn generated_forward_batches_compute_encoders() {
7368 let config = minimal_config();
7369 let model_rs = generate_model_rs(&config).unwrap();
7370
7371 let forward_start = model_rs
7373 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7374 .unwrap();
7375 let forward_body = &model_rs[forward_start..];
7376 let forward_end = forward_body
7377 .find("\n pub fn forward_profile")
7378 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7379 .or_else(|| forward_body.find("\n fn dispatch_"))
7380 .unwrap_or(forward_body.len());
7381 let forward_code = &forward_body[..forward_end];
7382
7383 assert!(
7385 !forward_code.contains("device.new_buffer"),
7386 "forward() should not allocate new buffers per call"
7387 );
7388
7389 let compute_encoder_count = forward_code
7392 .matches("new_compute_command_encoder()")
7393 .count();
7394 let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7395
7396 assert_eq!(
7398 compute_encoder_count, 1,
7399 "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7400 );
7401 assert_eq!(
7402 blit_encoder_count, 0,
7403 "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7404 );
7405 }
7406
7407 #[test]
7408 fn generated_forward_uses_add_inplace() {
7409 let config = minimal_config();
7410 let model_rs = generate_model_rs(&config).unwrap();
7411
7412 assert!(
7414 model_rs.contains("dispatch_add_inplace"),
7415 "forward() should use dispatch_add_inplace for residual connections"
7416 );
7417 assert!(
7418 model_rs.contains("add_inplace_pipeline"),
7419 "MetalModel should have add_inplace_pipeline"
7420 );
7421 }
7422
7423 fn minimal_q8_config() -> ModelConfig {
7424 ModelConfig {
7425 architecture: Architecture::Llama,
7426 hidden_size: 64,
7427 intermediate_size: 128,
7428 num_layers: 2,
7429 num_attention_heads: 4,
7430 num_kv_heads: 4,
7431 head_dim: 16,
7432 vocab_size: 256,
7433 max_seq_len: 512,
7434 rms_norm_eps: 1e-5,
7435 rope_theta: 10000.0,
7436 dtype: DType::Q8_0,
7437 sliding_window_size: None,
7438 qkv_bias: false,
7439 }
7440 }
7441
7442 #[test]
7443 fn generated_shaders_contain_q8_kernel() {
7444 let shaders = generate_metal_shaders(&minimal_config());
7445
7446 assert!(
7447 shaders.contains("kernel void matmul_vec_q8"),
7448 "shaders should contain matmul_vec_q8 kernel"
7449 );
7450 assert!(
7451 shaders.contains("device const uchar* matrix"),
7452 "matmul_vec_q8 should accept raw Q8_0 bytes"
7453 );
7454 assert!(
7455 shaders.contains("packed_short4"),
7456 "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
7457 );
7458 assert!(
7459 shaders.contains("as_type<char2>"),
7460 "matmul_vec_q8 should bitcast short lanes to char2"
7461 );
7462 assert!(
7463 shaders.contains("device const half*"),
7464 "matmul_vec_q8 should read f16 scale via half pointer"
7465 );
7466 }
7467
7468 #[test]
7469 fn generated_model_uses_fused_qkv_projections() {
7470 let config = minimal_config();
7471 let model_rs = generate_model_rs(&config).unwrap();
7472
7473 assert!(
7475 model_rs.contains("qkv_weight: Buffer"),
7476 "LayerBuffers should have fused qkv_weight field"
7477 );
7478 assert!(
7480 !model_rs.contains(" q_weight: Buffer"),
7481 "LayerBuffers should not have separate q_weight field"
7482 );
7483 assert!(
7484 !model_rs.contains(" k_weight: Buffer"),
7485 "LayerBuffers should not have separate k_weight field"
7486 );
7487 assert!(
7488 !model_rs.contains(" v_weight: Buffer"),
7489 "LayerBuffers should not have separate v_weight field"
7490 );
7491
7492 assert!(
7494 model_rs.contains("gate_up_weight: Buffer"),
7495 "LayerBuffers should have fused gate_up_weight field"
7496 );
7497 assert!(
7499 !model_rs.contains(" gate_weight: Buffer"),
7500 "LayerBuffers should not have separate gate_weight field"
7501 );
7502 assert!(
7503 !model_rs.contains(" up_weight: Buffer"),
7504 "LayerBuffers should not have separate up_weight field"
7505 );
7506
7507 assert!(
7509 model_rs.contains("qkv_buf: Buffer"),
7510 "MetalModel should have fused qkv_buf"
7511 );
7512 assert!(
7513 model_rs.contains("gate_up_buf: Buffer"),
7514 "MetalModel should have fused gate_up_buf"
7515 );
7516
7517 assert!(
7519 model_rs.contains("dispatch_silu_mul_fused"),
7520 "forward pass should use dispatch_silu_mul_fused"
7521 );
7522 assert!(
7523 model_rs.contains("dispatch_rope_offset"),
7524 "forward pass should use dispatch_rope_offset for fused QKV"
7525 );
7526 assert!(
7527 model_rs.contains("dispatch_attention_offset"),
7528 "forward pass should use dispatch_attention_offset for fused QKV"
7529 );
7530 }
7531
7532 #[test]
7533 fn q8_model_has_matmul_q8_pipeline() {
7534 let config = minimal_q8_config();
7535 let model_rs = generate_model_rs(&config).unwrap();
7536
7537 assert!(
7538 model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
7539 "MetalModel should have matmul_q8_pipeline field"
7540 );
7541 assert!(
7542 model_rs.contains("matmul_q8_pipeline,"),
7543 "MetalModel Self should include matmul_q8_pipeline"
7544 );
7545 }
7546
7547 #[test]
7548 fn q8_model_uses_dispatch_matmul_q8() {
7549 let config = minimal_q8_config();
7550 let model_rs = generate_model_rs(&config).unwrap();
7551
7552 assert!(
7553 model_rs.contains("dispatch_matmul_q8"),
7554 "Q8_0 model should use dispatch_matmul_q8 for projections"
7555 );
7556 assert!(
7557 model_rs.contains("fn dispatch_matmul_q8"),
7558 "model.rs should define dispatch_matmul_q8 method"
7559 );
7560 }
7561
7562 #[test]
7563 fn q8_model_loads_raw_bytes_not_dequantized() {
7564 let config = minimal_q8_config();
7565 let model_rs = generate_model_rs(&config).unwrap();
7566
7567 assert!(
7569 !model_rs.contains("f16_to_f32"),
7570 "Q8_0 model should not dequantize weights to f32"
7571 );
7572 assert!(
7573 !model_rs.contains("f32_data"),
7574 "Q8_0 model should not create f32 weight data"
7575 );
7576
7577 assert!(
7579 model_rs.contains("total_raw as u64"),
7580 "Q8_0 model should load raw bytes into Metal buffer"
7581 );
7582 }
7583
7584 #[test]
7585 fn q8_model_norms_stay_f32() {
7586 let config = minimal_q8_config();
7587 let model_rs = generate_model_rs(&config).unwrap();
7588
7589 assert!(
7591 model_rs.contains("let attn_norm = next_f32_buffer"),
7592 "attn_norm should use f32 buffer even for Q8_0 models"
7593 );
7594 assert!(
7595 model_rs.contains("let ffn_norm = next_f32_buffer"),
7596 "ffn_norm should use f32 buffer even for Q8_0 models"
7597 );
7598 assert!(
7599 model_rs.contains("let norm_buf = next_f32_buffer"),
7600 "final norm should use f32 buffer even for Q8_0 models"
7601 );
7602 }
7603
7604 #[test]
7605 fn q8_model_uses_fused_weight_loading() {
7606 let config = minimal_q8_config();
7607 let model_rs = generate_model_rs(&config).unwrap();
7608
7609 assert!(
7611 model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
7612 "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
7613 );
7614 assert!(
7616 model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
7617 "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
7618 );
7619 assert!(
7621 model_rs.contains("let o_weight = next_q8_buffer"),
7622 "Q8_0 model should use next_q8_buffer for O weight"
7623 );
7624 assert!(
7625 model_rs.contains("let down_weight = next_q8_buffer"),
7626 "Q8_0 model should use next_q8_buffer for down weight"
7627 );
7628 }
7629
7630 #[test]
7631 fn f32_model_does_not_use_q8_dispatch() {
7632 let config = minimal_config();
7633 let model_rs = generate_model_rs(&config).unwrap();
7634
7635 let forward_start = model_rs
7637 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7638 .unwrap();
7639 let forward_body = &model_rs[forward_start..];
7640 let forward_end = forward_body
7641 .find("\n fn dispatch_")
7642 .unwrap_or(forward_body.len());
7643 let forward_code = &forward_body[..forward_end];
7644
7645 assert!(
7646 !forward_code.contains("dispatch_matmul_q8"),
7647 "f32 model forward should not use dispatch_matmul_q8"
7648 );
7649 }
7650
7651 #[test]
7652 fn q8_dispatch_helper_takes_compute_encoder_ref() {
7653 let config = minimal_q8_config();
7654 let model_rs = generate_model_rs(&config).unwrap();
7655
7656 assert!(
7657 model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
7658 "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
7659 );
7660 }
7661
7662 #[test]
7663 fn generated_model_has_double_buffered_prefill() {
7664 let config = minimal_config();
7665 let model_rs = generate_model_rs(&config).unwrap();
7666
7667 assert!(
7669 model_rs.contains("prev_cmd: Option<CommandBuffer>"),
7670 "MetalModel should have prev_cmd field for double-buffered prefill"
7671 );
7672
7673 assert!(
7675 model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
7676 "MetalModel should have forward_prefill method"
7677 );
7678
7679 assert!(
7681 model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
7682 "forward() should drain prev_cmd from previous prefill"
7683 );
7684 }
7685
7686 #[test]
7687 fn generated_main_rs_uses_forward_prefill_for_prompt() {
7688 let config = minimal_config();
7689 let main_rs = generate_main_rs("test-model", &config).unwrap();
7690
7691 assert!(
7692 main_rs.contains("forward_prefill"),
7693 "main.rs should use forward_prefill for intermediate prompt tokens"
7694 );
7695 assert!(
7696 main_rs.contains("double-buffered"),
7697 "main.rs should document double-buffered prefill"
7698 );
7699 }
7700
7701 #[test]
7702 fn generated_shaders_q8_uses_wide_vectorized_loads() {
7703 let shaders = generate_metal_shaders(&minimal_config());
7704
7705 assert!(
7706 shaders.contains("packed_short4"),
7707 "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
7708 );
7709 assert!(
7710 shaders.contains("d0[0]"),
7711 "matmul_vec_q8 should index the wide pointer for row 0"
7712 );
7713 assert!(
7714 shaders.contains("as_type<char2>"),
7715 "matmul_vec_q8 should bitcast short lanes to char2"
7716 );
7717 assert!(
7718 shaders.contains("dot("),
7719 "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
7720 );
7721 }
7722
7723 fn minimal_q4_config() -> ModelConfig {
7726 ModelConfig {
7727 architecture: Architecture::Llama,
7728 hidden_size: 64,
7729 intermediate_size: 128,
7730 num_layers: 2,
7731 num_attention_heads: 4,
7732 num_kv_heads: 4,
7733 head_dim: 16,
7734 vocab_size: 256,
7735 max_seq_len: 512,
7736 rms_norm_eps: 1e-5,
7737 rope_theta: 10000.0,
7738 dtype: DType::Q4_0,
7739 sliding_window_size: None,
7740 qkv_bias: false,
7741 }
7742 }
7743
7744 #[test]
7745 fn generated_shaders_contain_q4_kernel() {
7746 let shaders = generate_metal_shaders(&minimal_config());
7747
7748 assert!(
7749 shaders.contains("kernel void matmul_vec_q4"),
7750 "shaders should contain matmul_vec_q4 kernel"
7751 );
7752 assert!(
7753 shaders.contains("Q4_ROWS_PER_TG"),
7754 "shaders should define Q4_ROWS_PER_TG constant"
7755 );
7756 assert!(
7757 shaders.contains("Q4_ROWS_PER_SG"),
7758 "shaders should define Q4_ROWS_PER_SG constant"
7759 );
7760 }
7761
7762 #[test]
7763 fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
7764 let shaders = generate_metal_shaders(&minimal_config());
7765
7766 assert!(
7768 shaders.contains("uchar4"),
7769 "matmul_vec_q4 should use uchar4 for packed byte loads"
7770 );
7771 assert!(
7773 shaders.contains("&0xF"),
7774 "matmul_vec_q4 should extract low nibble with &0xF"
7775 );
7776 assert!(
7777 shaders.contains(">>4"),
7778 "matmul_vec_q4 should extract high nibble with >>4"
7779 );
7780 assert!(
7782 shaders.contains("-8)"),
7783 "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
7784 );
7785 assert!(
7787 shaders.contains("blk * 18"),
7788 "matmul_vec_q4 should use 18-byte block stride"
7789 );
7790 }
7791
7792 #[test]
7793 fn q4_model_has_matmul_q4_pipeline() {
7794 let config = minimal_q4_config();
7795 let model_rs = generate_model_rs(&config).unwrap();
7796
7797 assert!(
7798 model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
7799 "MetalModel should have matmul_q4_pipeline field"
7800 );
7801 assert!(
7802 model_rs.contains("matmul_q4_pipeline,"),
7803 "MetalModel Self should include matmul_q4_pipeline"
7804 );
7805 }
7806
7807 #[test]
7808 fn q4_model_uses_dispatch_matmul_q4() {
7809 let config = minimal_q4_config();
7810 let model_rs = generate_model_rs(&config).unwrap();
7811
7812 assert!(
7813 model_rs.contains("dispatch_matmul_q4"),
7814 "Q4_0 model should use dispatch_matmul_q4 for projections"
7815 );
7816 assert!(
7817 model_rs.contains("fn dispatch_matmul_q4"),
7818 "model.rs should define dispatch_matmul_q4 method"
7819 );
7820 }
7821
7822 #[test]
7823 fn q4_model_loads_raw_bytes_not_dequantized() {
7824 let config = minimal_q4_config();
7825 let model_rs = generate_model_rs(&config).unwrap();
7826
7827 assert!(
7829 !model_rs.contains("f16_to_f32"),
7830 "Q4_0 model should not dequantize weights to f32"
7831 );
7832
7833 assert!(
7835 model_rs.contains("total_raw as u64"),
7836 "Q4_0 model should load raw bytes into Metal buffer"
7837 );
7838 }
7839
7840 #[test]
7841 fn q4_model_norms_stay_f32() {
7842 let config = minimal_q4_config();
7843 let model_rs = generate_model_rs(&config).unwrap();
7844
7845 assert!(
7846 model_rs.contains("let attn_norm = next_f32_buffer"),
7847 "attn_norm should use f32 buffer even for Q4_0 models"
7848 );
7849 assert!(
7850 model_rs.contains("let ffn_norm = next_f32_buffer"),
7851 "ffn_norm should use f32 buffer even for Q4_0 models"
7852 );
7853 assert!(
7854 model_rs.contains("let norm_buf = next_f32_buffer"),
7855 "final norm should use f32 buffer even for Q4_0 models"
7856 );
7857 }
7858
7859 #[test]
7860 fn q4_model_uses_fused_weight_loading() {
7861 let config = minimal_q4_config();
7862 let model_rs = generate_model_rs(&config).unwrap();
7863
7864 assert!(
7865 model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
7866 "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
7867 );
7868 assert!(
7869 model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
7870 "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
7871 );
7872 assert!(
7873 model_rs.contains("let o_weight = next_q4_buffer"),
7874 "Q4_0 model should use next_q4_buffer for O weight"
7875 );
7876 assert!(
7877 model_rs.contains("let down_weight = next_q4_buffer"),
7878 "Q4_0 model should use next_q4_buffer for down weight"
7879 );
7880 }
7881
7882 #[test]
7883 fn attention_flash_batch_kernel_exists() {
7884 let config = minimal_config();
7888 let model_rs = generate_model_rs(&config).unwrap();
7889 let shaders = generate_metal_shaders(&config);
7890
7891 assert!(
7892 shaders.contains("kernel void attention_flash_batch"),
7893 "shaders.metal must still contain the attention_flash_batch kernel"
7894 );
7895 assert!(
7896 shaders.contains("FLASH_K_TILE"),
7897 "flash kernel must tile K/V with a FLASH_K_TILE constant"
7898 );
7899 assert!(
7900 model_rs.contains("attention_flash_batch_pipeline"),
7901 "MetalModel must register the flash attention pipeline"
7902 );
7903 }
7904
7905 #[test]
7906 fn attention_mma_flash_batch_kernel_wired() {
7907 let config = minimal_config();
7910 let model_rs = generate_model_rs(&config).unwrap();
7911 let shaders = generate_metal_shaders(&config);
7912
7913 assert!(
7914 shaders.contains("kernel void attention_mma_flash_batch"),
7915 "shaders.metal must contain the MMA flash kernel"
7916 );
7917 assert!(
7918 shaders.contains("FLASH_MMA_Q_BLOCK"),
7919 "MMA flash kernel must define Q_BLOCK tiling constant"
7920 );
7921 assert!(
7922 shaders.contains("simdgroup_multiply_accumulate"),
7923 "MMA flash kernel must use hardware MMA"
7924 );
7925 assert!(
7926 model_rs.contains("attention_mma_flash_batch_pipeline"),
7927 "MetalModel must register the MMA flash pipeline"
7928 );
7929 assert!(
7930 model_rs.contains("mma_opt_out"),
7931 "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out"
7932 );
7933 assert!(
7934 model_rs.contains("!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8"),
7935 "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8"
7936 );
7937 }
7938
7939 #[test]
7940 fn forward_prefill_batch_chunks_by_max_batch_size() {
7941 let config = minimal_config();
7946 let model_rs = generate_model_rs(&config).unwrap();
7947 assert!(
7948 model_rs.contains("for chunk in tokens.chunks(MAX_BATCH_SIZE)"),
7949 "forward_prefill_batch must chunk long prompts"
7950 );
7951 assert!(
7952 !model_rs.contains("tokens.len().min(MAX_BATCH_SIZE)"),
7953 "the old truncation path must be gone"
7954 );
7955 }
7956
7957 #[test]
7958 fn qwen2_qkv_bias_wired_through_metal_codegen() {
7959 let config = ModelConfig {
7963 architecture: Architecture::Qwen2,
7964 qkv_bias: true,
7965 ..minimal_config()
7966 };
7967 let model_rs = generate_model_rs(&config).unwrap();
7968
7969 assert!(
7970 model_rs.contains("qkv_bias: Buffer"),
7971 "Qwen2 LayerBuffers must declare qkv_bias field"
7972 );
7973 assert!(
7974 model_rs.contains("let qkv_bias = next_f32_buffer"),
7975 "Qwen2 layer init must load the bias from the weight blob"
7976 );
7977 assert!(
7978 model_rs.contains("add_bias_batch_pipeline"),
7979 "Qwen2 model struct must include the add_bias_batch_pipeline"
7980 );
7981 assert!(
7982 model_rs.contains("fn dispatch_add_bias_batch"),
7983 "Qwen2 codegen must emit dispatch_add_bias_batch helper"
7984 );
7985 assert!(
7986 model_rs.contains("dispatch_add_bias_batch(&enc, &self.batch_qkv_buf"),
7987 "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf"
7988 );
7989 assert!(
7990 model_rs.contains("dispatch_add_bias_batch(&enc, &self.qkv_buf"),
7991 "forward must call dispatch_add_bias_batch on the single-token qkv_buf"
7992 );
7993
7994 let shaders = generate_metal_shaders(&config);
7996 assert!(
7997 shaders.contains("kernel void add_bias_batch"),
7998 "shaders.metal must contain the add_bias_batch kernel"
7999 );
8000 }
8001
8002 #[test]
8003 fn llama_does_not_emit_qkv_bias_machinery() {
8004 let config = minimal_config();
8007 assert!(!config.qkv_bias);
8008 let model_rs = generate_model_rs(&config).unwrap();
8009 assert!(
8010 !model_rs.contains("qkv_bias: Buffer"),
8011 "Llama must not have qkv_bias field"
8012 );
8013 assert!(
8014 !model_rs.contains("add_bias_batch_pipeline"),
8015 "Llama must not pull in add_bias_batch_pipeline"
8016 );
8017 assert!(
8018 !model_rs.contains("dispatch_add_bias_batch"),
8019 "Llama must not call dispatch_add_bias_batch"
8020 );
8021 }
8022
8023 #[test]
8024 fn q4_dispatch_helper_takes_compute_encoder_ref() {
8025 let config = minimal_q4_config();
8026 let model_rs = generate_model_rs(&config).unwrap();
8027
8028 assert!(
8029 model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
8030 "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
8031 );
8032 }
8033
8034 #[test]
8035 fn f32_model_does_not_use_q4_dispatch() {
8036 let config = minimal_config();
8037 let model_rs = generate_model_rs(&config).unwrap();
8038
8039 let forward_start = model_rs
8040 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8041 .unwrap();
8042 let forward_body = &model_rs[forward_start..];
8043 let forward_end = forward_body
8044 .find("\n fn dispatch_")
8045 .unwrap_or(forward_body.len());
8046 let forward_code = &forward_body[..forward_end];
8047
8048 assert!(
8049 !forward_code.contains("dispatch_matmul_q4"),
8050 "f32 model forward should not use dispatch_matmul_q4"
8051 );
8052 }
8053
8054 #[test]
8055 fn q4_model_lm_head_uses_q4_buffer() {
8056 let config = minimal_q4_config();
8057 let model_rs = generate_model_rs(&config).unwrap();
8058
8059 assert!(
8060 model_rs.contains("let lm_head_buf = next_q4_buffer"),
8061 "Q4_0 model should use next_q4_buffer for lm_head"
8062 );
8063 }
8064
8065 #[test]
8066 fn vec_tile_size_matches_model_dimensions() {
8067 let small = minimal_config();
8069 let shaders_small = generate_metal_shaders(&small);
8070 assert!(
8071 shaders_small.contains("vec_tile[128]"),
8072 "vec_tile should be sized to max(hidden, intermediate) = 128"
8073 );
8074
8075 let mut large = minimal_config();
8077 large.hidden_size = 2048;
8078 large.intermediate_size = 8192;
8079 let shaders_large = generate_metal_shaders(&large);
8080 assert!(
8081 shaders_large.contains("vec_tile[8192]"),
8082 "vec_tile should be 8192 for models with intermediate=8192"
8083 );
8084 assert!(
8085 !shaders_large.contains("vec_tile[4096]"),
8086 "vec_tile should NOT be hardcoded to 4096"
8087 );
8088 }
8089
8090 #[test]
8091 fn generated_cargo_toml_has_server_deps() {
8092 let toml = generate_cargo_toml("my-model");
8093 assert!(
8094 toml.contains("tiny_http"),
8095 "Cargo.toml should depend on tiny_http"
8096 );
8097 assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
8098 assert!(
8099 toml.contains("serde_json"),
8100 "Cargo.toml should depend on serde_json"
8101 );
8102 }
8103
8104 #[test]
8105 fn generated_main_rs_has_serve_mode() {
8106 let config = minimal_config();
8107 let main_rs = generate_main_rs("test-model", &config).unwrap();
8108
8109 assert!(
8110 main_rs.contains("--serve"),
8111 "main.rs should parse --serve flag"
8112 );
8113 assert!(
8114 main_rs.contains("--port"),
8115 "main.rs should parse --port flag"
8116 );
8117 assert!(
8118 main_rs.contains("fn serve("),
8119 "main.rs should define serve function"
8120 );
8121 assert!(
8122 main_rs.contains("tiny_http::Server::http"),
8123 "main.rs should create tiny_http server"
8124 );
8125 }
8126
8127 #[test]
8128 fn generated_main_rs_has_chat_completions_endpoint() {
8129 let config = minimal_config();
8130 let main_rs = generate_main_rs("test-model", &config).unwrap();
8131
8132 assert!(
8133 main_rs.contains("/v1/chat/completions"),
8134 "main.rs should handle /v1/chat/completions endpoint"
8135 );
8136 assert!(
8137 main_rs.contains("/v1/models"),
8138 "main.rs should handle /v1/models endpoint"
8139 );
8140 assert!(
8141 main_rs.contains("/health"),
8142 "main.rs should handle /health endpoint"
8143 );
8144 }
8145
8146 #[test]
8147 fn generated_main_rs_has_sse_streaming() {
8148 let config = minimal_config();
8149 let main_rs = generate_main_rs("test-model", &config).unwrap();
8150
8151 assert!(
8152 main_rs.contains("text/event-stream"),
8153 "main.rs should set SSE content type for streaming"
8154 );
8155 assert!(
8156 main_rs.contains("chat.completion.chunk"),
8157 "main.rs should emit SSE chunks"
8158 );
8159 assert!(
8160 main_rs.contains("[DONE]"),
8161 "main.rs should emit [DONE] sentinel"
8162 );
8163 }
8164
8165 #[test]
8166 fn generated_main_rs_has_chat_message_formatting() {
8167 let config = minimal_config();
8168 let main_rs = generate_main_rs("test-model", &config).unwrap();
8169
8170 assert!(
8171 main_rs.contains("fn format_chat_messages"),
8172 "main.rs should define format_chat_messages function"
8173 );
8174 assert!(
8175 main_rs.contains("<|im_start|>"),
8176 "main.rs should use ChatML format"
8177 );
8178 assert!(
8179 main_rs.contains("<|im_end|>"),
8180 "main.rs should use ChatML format"
8181 );
8182 }
8183
8184 #[test]
8185 fn generated_main_rs_has_request_types() {
8186 let config = minimal_config();
8187 let main_rs = generate_main_rs("test-model", &config).unwrap();
8188
8189 assert!(
8190 main_rs.contains("struct ChatRequest"),
8191 "main.rs should define ChatRequest struct"
8192 );
8193 assert!(
8194 main_rs.contains("struct ChatMessage"),
8195 "main.rs should define ChatMessage struct"
8196 );
8197 assert!(
8198 main_rs.contains("Deserialize"),
8199 "main.rs should derive Deserialize for request types"
8200 );
8201 }
8202
8203 #[test]
8204 fn generated_model_has_reset_method() {
8205 let config = minimal_config();
8206 let model_rs = generate_model_rs(&config).unwrap();
8207
8208 assert!(
8209 model_rs.contains("pub fn reset(&mut self)"),
8210 "model.rs should have a reset() method for multi-request serving"
8211 );
8212 assert!(
8213 model_rs.contains("self.pos = 0"),
8214 "reset() should reset position to 0"
8215 );
8216 }
8217
8218 #[test]
8219 fn generated_main_rs_cli_mode_still_works() {
8220 let config = minimal_config();
8221 let main_rs = generate_main_rs("test-model", &config).unwrap();
8222
8223 assert!(
8225 main_rs.contains("fn cli_mode("),
8226 "main.rs should define cli_mode function"
8227 );
8228 assert!(
8229 main_rs.contains("model.forward"),
8230 "main.rs should call model.forward"
8231 );
8232 assert!(
8233 main_rs.contains("model.forward_prefill"),
8234 "main.rs should call model.forward_prefill"
8235 );
8236 }
8237
8238 #[test]
8241 fn generated_shaders_contain_batch_kernels() {
8242 let shaders = generate_metal_shaders(&minimal_config());
8243
8244 assert!(
8245 shaders.contains("kernel void matmul_vec_batch"),
8246 "shaders should contain matmul_vec_batch kernel"
8247 );
8248 assert!(
8249 shaders.contains("kernel void matmul_vec_q8_batch"),
8250 "shaders should contain matmul_vec_q8_batch kernel"
8251 );
8252 assert!(
8253 shaders.contains("kernel void matmul_q8_gemm_batch"),
8254 "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
8255 );
8256 assert!(
8257 shaders.contains("kernel void matmul_vec_q4_batch"),
8258 "shaders should contain matmul_vec_q4_batch kernel"
8259 );
8260 assert!(
8261 shaders.contains("kernel void rms_norm_batch"),
8262 "shaders should contain rms_norm_batch kernel"
8263 );
8264 assert!(
8265 shaders.contains("kernel void silu_mul_fused_batch"),
8266 "shaders should contain silu_mul_fused_batch kernel"
8267 );
8268 assert!(
8269 shaders.contains("kernel void add_inplace_batch"),
8270 "shaders should contain add_inplace_batch kernel"
8271 );
8272 assert!(
8273 shaders.contains("kernel void copy_embedding_batch"),
8274 "shaders should contain copy_embedding_batch kernel"
8275 );
8276 }
8277
8278 #[test]
8279 fn generated_model_has_batch_pipelines() {
8280 let config = minimal_config();
8281 let model_rs = generate_model_rs(&config).unwrap();
8282
8283 for pipeline in &[
8284 "matmul_batch_pipeline",
8285 "matmul_q8_batch_pipeline",
8286 "matmul_q8_gemm_batch_pipeline",
8287 "matmul_q4_batch_pipeline",
8288 "rms_norm_batch_pipeline",
8289 "rope_batch_pipeline",
8290 "silu_mul_fused_batch_pipeline",
8291 "add_inplace_batch_pipeline",
8292 "copy_embedding_batch_pipeline",
8293 "attention_batch_pipeline",
8294 "copy_kv_batch_pipeline",
8295 ] {
8296 assert!(
8297 model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
8298 "MetalModel should have {pipeline} field"
8299 );
8300 }
8301 }
8302
8303 #[test]
8304 fn generated_model_has_batch_buffers() {
8305 let config = minimal_config();
8306 let model_rs = generate_model_rs(&config).unwrap();
8307
8308 for buf in &[
8309 "batch_hidden_buf",
8310 "batch_residual_buf",
8311 "batch_qkv_buf",
8312 "batch_attn_out_buf",
8313 "batch_attn_proj_buf",
8314 "batch_gate_up_buf",
8315 "batch_ffn_hidden_buf",
8316 "batch_ffn_out_buf",
8317 "batch_tokens_buf",
8318 "batch_positions_buf",
8319 ] {
8320 assert!(
8321 model_rs.contains(&format!("{buf}: Buffer")),
8322 "MetalModel should have {buf} field"
8323 );
8324 }
8325 }
8326
8327 #[test]
8328 fn generated_model_has_forward_prefill_batch() {
8329 let config = minimal_config();
8330 let model_rs = generate_model_rs(&config).unwrap();
8331
8332 assert!(
8333 model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
8334 "MetalModel should have forward_prefill_batch method"
8335 );
8336
8337 assert!(
8339 model_rs.contains("self.forward_prefill_batch(&[token_id])"),
8340 "forward_prefill should delegate to forward_prefill_batch"
8341 );
8342 }
8343
8344 #[test]
8345 fn generated_model_has_max_batch_size_constant() {
8346 let config = minimal_config();
8347 let model_rs = generate_model_rs(&config).unwrap();
8348
8349 assert!(
8350 model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
8351 "model.rs should define MAX_BATCH_SIZE constant"
8352 );
8353 }
8354
8355 #[test]
8356 fn forward_prefill_batch_uses_batch_dispatch() {
8357 let config = minimal_config();
8358 let model_rs = generate_model_rs(&config).unwrap();
8359
8360 let batch_start = model_rs
8361 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8362 .unwrap();
8363 let batch_body = &model_rs[batch_start..];
8364 let batch_end = batch_body
8365 .find("\n pub fn reset")
8366 .unwrap_or(batch_body.len());
8367 let batch_code = &batch_body[..batch_end];
8368
8369 assert!(
8371 batch_code.contains("dispatch_rms_norm_batch"),
8372 "forward_prefill_batch should use dispatch_rms_norm_batch"
8373 );
8374 assert!(
8375 batch_code.contains("dispatch_copy_embedding_batch"),
8376 "forward_prefill_batch should use dispatch_copy_embedding_batch"
8377 );
8378 assert!(
8379 batch_code.contains("dispatch_silu_mul_fused_batch"),
8380 "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
8381 );
8382 assert!(
8384 batch_code.contains("dispatch_attention_batch"),
8385 "forward_prefill_batch should use dispatch_attention_batch"
8386 );
8387 assert!(
8389 batch_code.contains("dispatch_copy_kv_both_batch"),
8390 "forward_prefill_batch should use dispatch_copy_kv_both_batch"
8391 );
8392 assert!(
8394 batch_code.contains("dispatch_rope_qk_batch"),
8395 "forward_prefill_batch should use dispatch_rope_qk_batch"
8396 );
8397 }
8398
8399 #[test]
8400 fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
8401 let config = minimal_q8_config();
8402 let model_rs = generate_model_rs(&config).unwrap();
8403
8404 let batch_start = model_rs
8405 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8406 .unwrap();
8407 let batch_body = &model_rs[batch_start..];
8408 let batch_end = batch_body
8409 .find("\n pub fn reset")
8410 .unwrap_or(batch_body.len());
8411 let batch_code = &batch_body[..batch_end];
8412
8413 assert!(
8414 batch_code.contains("dispatch_matmul_q8_batch"),
8415 "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
8416 );
8417 }
8418
8419 #[test]
8420 fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
8421 let config = minimal_q4_config();
8422 let model_rs = generate_model_rs(&config).unwrap();
8423
8424 let batch_start = model_rs
8425 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8426 .unwrap();
8427 let batch_body = &model_rs[batch_start..];
8428 let batch_end = batch_body
8429 .find("\n pub fn reset")
8430 .unwrap_or(batch_body.len());
8431 let batch_code = &batch_body[..batch_end];
8432
8433 assert!(
8434 batch_code.contains("dispatch_matmul_q4_batch"),
8435 "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
8436 );
8437 }
8438
8439 #[test]
8440 fn generated_main_rs_uses_batched_prefill() {
8441 let config = minimal_config();
8442 let main_rs = generate_main_rs("test-model", &config).unwrap();
8443
8444 assert!(
8445 main_rs.contains("forward_prefill_batch"),
8446 "main.rs should use forward_prefill_batch for prompt tokens"
8447 );
8448 }
8449
8450 #[test]
8451 fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
8452 let config = minimal_config();
8453 let model_rs = generate_model_rs(&config).unwrap();
8454
8455 let batch_start = model_rs
8456 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
8457 .unwrap();
8458 let batch_body = &model_rs[batch_start..];
8459 let batch_end = batch_body
8460 .find("\n pub fn reset")
8461 .unwrap_or(batch_body.len());
8462 let batch_code = &batch_body[..batch_end];
8463
8464 assert!(
8465 batch_code.contains("dispatch_matmul_batch"),
8466 "f32 forward_prefill_batch should use dispatch_matmul_batch"
8467 );
8468 assert!(
8470 !batch_code.contains("dispatch_matmul_q8_batch"),
8471 "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
8472 );
8473 assert!(
8474 !batch_code.contains("dispatch_matmul_q4_batch"),
8475 "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
8476 );
8477 }
8478
8479 #[test]
8480 fn forward_uses_cpu_embedding_lookup() {
8481 let config = minimal_config();
8482 let model_rs = generate_model_rs(&config).unwrap();
8483
8484 let forward_start = model_rs
8486 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8487 .unwrap();
8488 let forward_body = &model_rs[forward_start..];
8489 let forward_end = forward_body
8490 .find("\n pub fn forward_profile")
8491 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
8492 .unwrap_or(forward_body.len());
8493 let forward_code = &forward_body[..forward_end];
8494
8495 assert!(
8497 forward_code.contains("embed_buf.contents()"),
8498 "forward() should access embed_buf via CPU unified memory for embedding lookup"
8499 );
8500 assert!(
8501 forward_code.contains("copy_nonoverlapping"),
8502 "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
8503 );
8504 assert!(
8506 !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
8507 "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
8508 );
8509 }
8510
8511 #[test]
8512 fn forward_profile_method_exists() {
8513 let config = minimal_config();
8514 let model_rs = generate_model_rs(&config).unwrap();
8515
8516 assert!(
8517 model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
8518 "MetalModel should have forward_profile() method"
8519 );
8520 assert!(
8522 model_rs.contains("[profile]"),
8523 "forward_profile() should print timing with [profile] prefix"
8524 );
8525 assert!(
8526 model_rs.contains("d_embed"),
8527 "forward_profile() should measure embedding time"
8528 );
8529 assert!(
8530 model_rs.contains("d_layers"),
8531 "forward_profile() should measure layer time"
8532 );
8533 assert!(
8534 model_rs.contains("d_logits"),
8535 "forward_profile() should measure logits time"
8536 );
8537 }
8538
8539 #[test]
8540 fn generated_cli_has_profile_flag() {
8541 let config = minimal_config();
8542 let main_rs = generate_main_rs("test-model", &config).unwrap();
8543
8544 assert!(
8545 main_rs.contains("--profile"),
8546 "CLI should support --profile flag"
8547 );
8548 assert!(
8549 main_rs.contains("forward_profile"),
8550 "CLI should call forward_profile when --profile is set"
8551 );
8552 }
8553
8554 #[test]
8555 fn generated_cli_has_thermal_yield() {
8556 let config = minimal_config();
8557 let main_rs = generate_main_rs("test-model", &config).unwrap();
8558
8559 assert!(
8560 main_rs.contains("yield_now()"),
8561 "CLI generation loop should include thread::yield_now() for thermal management"
8562 );
8563 }
8564
8565 #[test]
8568 fn generated_forward_handles_single_token_prompt() {
8569 let config = minimal_config();
8573 let model_rs = generate_model_rs(&config).unwrap();
8574
8575 let forward_start = model_rs
8577 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8578 .expect("forward() must exist");
8579 let forward_body = &model_rs[forward_start..forward_start + 400];
8580
8581 assert!(
8583 !forward_body.contains("assert!(self.pos > 0"),
8584 "forward() must accept pos=0 (first token with no prefill)"
8585 );
8586
8587 assert!(
8589 model_rs.contains("self.pos"),
8590 "forward() should use self.pos to track sequence position"
8591 );
8592 }
8593
8594 #[test]
8595 fn generated_reset_clears_kv_cache_position() {
8596 let config = minimal_config();
8599 let model_rs = generate_model_rs(&config).unwrap();
8600
8601 let reset_start = model_rs
8602 .find("pub fn reset(&mut self)")
8603 .expect("reset() must exist");
8604 let reset_body = &model_rs[reset_start..reset_start + 200];
8605
8606 assert!(
8608 reset_body.contains("self.pos = 0"),
8609 "reset() must set self.pos = 0"
8610 );
8611
8612 assert!(
8614 reset_body.contains("self.prev_cmd = None"),
8615 "reset() should clear prev_cmd for clean command buffer state"
8616 );
8617 }
8618
8619 #[test]
8620 fn generated_serve_handles_empty_messages_gracefully() {
8621 let config = minimal_config();
8624 let main_rs = generate_main_rs("test-model", &config).unwrap();
8625
8626 let format_fn_start = main_rs
8628 .find("fn format_chat_messages")
8629 .expect("format_chat_messages must exist");
8630 let format_fn_body =
8631 &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8632
8633 assert!(
8635 format_fn_body.contains("for msg in messages"),
8636 "format_chat_messages should iterate over the messages slice"
8637 );
8638 assert!(
8640 format_fn_body.contains("<|im_start|>assistant"),
8641 "format_chat_messages should always append assistant prompt header"
8642 );
8643
8644 let serve_fn_start = main_rs
8646 .find("fn serve(")
8647 .expect("serve function must exist");
8648 let serve_fn_body = &main_rs[serve_fn_start..];
8649 assert!(
8650 serve_fn_body.contains("model.reset()"),
8651 "serve function should reset model between requests"
8652 );
8653 }
8654
8655 #[test]
8656 fn generated_model_forward_increments_pos() {
8657 let config = minimal_config();
8660 let model_rs = generate_model_rs(&config).unwrap();
8661
8662 let forward_start = model_rs
8663 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8664 .unwrap();
8665 let forward_body = &model_rs[forward_start..];
8666 let forward_end = forward_body
8667 .find("\n pub fn forward_profile")
8668 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
8669 .or_else(|| forward_body.find("\n fn dispatch_"))
8670 .unwrap_or(forward_body.len());
8671 let forward_code = &forward_body[..forward_end];
8672
8673 assert!(
8674 forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8675 "forward() must increment self.pos after processing a token"
8676 );
8677 }
8678}