1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16 #[error("graph has no model config")]
18 MissingConfig,
19
20 #[error("I/O error: {0}")]
22 Io(#[from] std::io::Error),
23
24 #[error("format error: {0}")]
26 Fmt(#[from] std::fmt::Error),
27}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Turns an arbitrary model name into a valid Cargo package / binary name.
///
/// ASCII letters are lowercased, while ASCII digits and `-` are kept. Every
/// other character (spaces, dots, underscores, non-ASCII) becomes `-`, and
/// leading/trailing dashes are trimmed.
///
/// Cargo rejects package names that are empty or start with a digit. An empty
/// result therefore falls back to `"model"`, and a result beginning with a
/// digit is prefixed with `"model-"` (e.g. `"7B-Chat"` -> `"model-7b-chat"`).
fn sanitize_name(name: &str) -> String {
    let dashed: String = name
        .chars()
        .map(|c| {
            if c.is_ascii_alphanumeric() || c == '-' {
                c.to_ascii_lowercase()
            } else {
                // Non-ASCII letters are rejected by crates.io, and some of them
                // lowercase into multi-char sequences; treat them as separators.
                '-'
            }
        })
        .collect();

    match dashed.trim_matches('-') {
        "" => String::from("model"),
        s if s.starts_with(|c: char| c.is_ascii_digit()) => format!("model-{s}"),
        s => s.to_owned(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121 r#"//
122// Auto-generated by ForgeLLM Metal codegen.
123// Metal Shading Language compute kernels for transformer inference.
124//
125// Optimized with simdgroup cooperative reductions, shared memory vector
126// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
127// lane, and fast:: math intrinsics for Apple Silicon throughput.
128//
129
130#include <metal_stdlib>
131#include <metal_simdgroup_matrix>
132using namespace metal;
133
134// ── Constants ───────────────────────────────────────────────────────────
135// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
136// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
137// per shared memory load vs 4-row, improving ILP and reducing launches.
138constant constexpr uint SIMDGROUPS_PER_TG = 8;
139constant constexpr uint ROWS_PER_SIMDGROUP = 8;
140constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
141
142// ── matmul_vec ──────────────────────────────────────────────────────────
143// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
144// Uses simdgroup cooperative dot product with shared memory vector caching
145// and float4 vectorized loads. Each simdgroup processes 8 rows for better
146// shared memory reuse (8x vector reuse per load) and instruction-level
147// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
148kernel void matmul_vec(
149 device const float* matrix [[buffer(0)]],
150 device const float* vector [[buffer(1)]],
151 device float* output [[buffer(2)]],
152 constant uint& rows [[buffer(3)]],
153 constant uint& cols [[buffer(4)]],
154 uint tgid [[threadgroup_position_in_grid]],
155 uint tid [[thread_index_in_threadgroup]],
156 uint simd_lane [[thread_index_in_simdgroup]],
157 uint simd_id [[simdgroup_index_in_threadgroup]])
158{
159 // Cooperatively load vector into threadgroup shared memory
160 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
161 for (uint i = tid; i < cols; i += 256) {
162 vec_tile[i] = vector[i];
163 }
164 threadgroup_barrier(mem_flags::mem_threadgroup);
165
166 // Each simdgroup handles 8 consecutive rows
167 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
168 if (row_base >= rows) return;
169
170 uint base0 = row_base * cols;
171 uint base1 = (row_base + 1) * cols;
172 uint base2 = (row_base + 2) * cols;
173 uint base3 = (row_base + 3) * cols;
174 uint base4 = (row_base + 4) * cols;
175 uint base5 = (row_base + 5) * cols;
176 uint base6 = (row_base + 6) * cols;
177 uint base7 = (row_base + 7) * cols;
178
179 // float4 vectorized accumulation across 8 rows
180 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
181 float4 sum4_0 = float4(0.0f);
182 float4 sum4_1 = float4(0.0f);
183 float4 sum4_2 = float4(0.0f);
184 float4 sum4_3 = float4(0.0f);
185 float4 sum4_4 = float4(0.0f);
186 float4 sum4_5 = float4(0.0f);
187 float4 sum4_6 = float4(0.0f);
188 float4 sum4_7 = float4(0.0f);
189
190 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
191 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
192 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
193 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
194 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
195 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
196 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
197 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
198 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
199 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
200 }
201
202 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
203 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
204 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
205 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
206 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
207 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
208 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
209 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
210
211 // Handle remaining elements (cols not divisible by 128)
212 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
213 float vv = vec_tile[j];
214 sum0 += matrix[base0 + j] * vv;
215 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
216 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
217 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
218 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
219 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
220 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
221 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
222 }
223
224 // Simdgroup hardware warp-level reduction
225 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
226 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
227 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
228 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
229
230 // Only first lane writes the results
231 if (simd_lane == 0) {
232 if (row_base < rows) output[row_base] = sum0;
233 if (row_base + 1 < rows) output[row_base + 1] = sum1;
234 if (row_base + 2 < rows) output[row_base + 2] = sum2;
235 if (row_base + 3 < rows) output[row_base + 3] = sum3;
236 if (row_base + 4 < rows) output[row_base + 4] = sum4;
237 if (row_base + 5 < rows) output[row_base + 5] = sum5;
238 if (row_base + 6 < rows) output[row_base + 6] = sum6;
239 if (row_base + 7 < rows) output[row_base + 7] = sum7;
240 }
241}
242
243// ── rms_norm ────────────────────────────────────────────────────────────
244// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
245// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
246// via shared memory for minimal synchronization overhead.
247kernel void rms_norm(
248 device const float* input [[buffer(0)]],
249 device const float* weight [[buffer(1)]],
250 device float* output [[buffer(2)]],
251 constant uint& n [[buffer(3)]],
252 constant float& eps [[buffer(4)]],
253 uint tid [[thread_index_in_threadgroup]])
254{
255 // Each thread accumulates partial sum-of-squares
256 float sum_sq = 0.0f;
257 for (uint i = tid; i < n; i += 256) {
258 float v = input[i];
259 sum_sq += v * v;
260 }
261
262 // Simdgroup-level reduction (hardware warp sum)
263 sum_sq = simd_sum(sum_sq);
264
265 // Cross-simdgroup reduction via shared memory
266 threadgroup float shared[8];
267 uint simd_id = tid / 32;
268 uint simd_lane = tid % 32;
269 if (simd_lane == 0) {
270 shared[simd_id] = sum_sq;
271 }
272 threadgroup_barrier(mem_flags::mem_threadgroup);
273
274 // First thread computes final inverse RMS
275 if (tid == 0) {
276 float total = 0.0f;
277 for (uint i = 0; i < 8; i++) {
278 total += shared[i];
279 }
280 shared[0] = fast::rsqrt(total / float(n) + eps);
281 }
282 threadgroup_barrier(mem_flags::mem_threadgroup);
283
284 float inv_rms = shared[0];
285
286 // Normalize
287 for (uint i = tid; i < n; i += 256) {
288 output[i] = input[i] * inv_rms * weight[i];
289 }
290}
291
292// ── rope ────────────────────────────────────────────────────────────────
293// Rotary Position Embedding applied in-place.
294// Each thread handles one (head, pair) combination.
295kernel void rope(
296 device float* data [[buffer(0)]],
297 constant uint& num_heads [[buffer(1)]],
298 constant uint& head_dim [[buffer(2)]],
299 constant uint& pos [[buffer(3)]],
300 constant float& theta [[buffer(4)]],
301 uint id [[thread_position_in_grid]])
302{
303 uint half_dim = head_dim / 2;
304 uint total_pairs = num_heads * half_dim;
305 if (id >= total_pairs) return;
306
307 uint h = id / half_dim;
308 uint i = id % half_dim;
309 uint off = h * head_dim;
310
311 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
312 float angle = float(pos) * freq;
313 float c = cos(angle);
314 float s = sin(angle);
315
316 float x0 = data[off + 2 * i];
317 float x1 = data[off + 2 * i + 1];
318 data[off + 2 * i] = x0 * c - x1 * s;
319 data[off + 2 * i + 1] = x0 * s + x1 * c;
320}
321
322// ── softmax ─────────────────────────────────────────────────────────────
323// Numerically stable softmax over a 1-D array.
324// Single-threadgroup kernel with cooperative reduction.
325kernel void softmax(
326 device float* data [[buffer(0)]],
327 constant uint& n [[buffer(1)]],
328 uint tid [[thread_index_in_threadgroup]],
329 uint tg_size [[threads_per_threadgroup]])
330{
331 threadgroup float shared_val[256];
332
333 // Pass 1: find max
334 float local_max = -INFINITY;
335 for (uint i = tid; i < n; i += tg_size) {
336 local_max = max(local_max, data[i]);
337 }
338 shared_val[tid] = local_max;
339 threadgroup_barrier(mem_flags::mem_threadgroup);
340
341 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
342 if (tid < stride) {
343 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
344 }
345 threadgroup_barrier(mem_flags::mem_threadgroup);
346 }
347 float max_val = shared_val[0];
348 threadgroup_barrier(mem_flags::mem_threadgroup);
349
350 // Pass 2: exp and sum
351 float local_sum = 0.0f;
352 for (uint i = tid; i < n; i += tg_size) {
353 float e = fast::exp(data[i] - max_val);
354 data[i] = e;
355 local_sum += e;
356 }
357 shared_val[tid] = local_sum;
358 threadgroup_barrier(mem_flags::mem_threadgroup);
359
360 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
361 if (tid < stride) {
362 shared_val[tid] += shared_val[tid + stride];
363 }
364 threadgroup_barrier(mem_flags::mem_threadgroup);
365 }
366 float inv_sum = 1.0f / shared_val[0];
367 threadgroup_barrier(mem_flags::mem_threadgroup);
368
369 // Pass 3: normalize
370 for (uint i = tid; i < n; i += tg_size) {
371 data[i] *= inv_sum;
372 }
373}
374
375// ── silu_mul ────────────────────────────────────────────────────────────
376// Fused SiLU activation * element-wise multiply:
377// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
378kernel void silu_mul(
379 device const float* gate [[buffer(0)]],
380 device const float* up [[buffer(1)]],
381 device float* output [[buffer(2)]],
382 constant uint& n [[buffer(3)]],
383 uint id [[thread_position_in_grid]])
384{
385 if (id >= n) return;
386 float g = gate[id];
387 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
388}
389
390// ── silu_mul_fused ─────────────────────────────────────────────────────
391// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
392// gate = gate_up[0..n], up = gate_up[n..2*n]
393// output[i] = silu(gate_up[i]) * gate_up[n + i]
394kernel void silu_mul_fused(
395 device const float* gate_up [[buffer(0)]],
396 device float* output [[buffer(1)]],
397 constant uint& n [[buffer(2)]],
398 uint id [[thread_position_in_grid]])
399{
400 if (id >= n) return;
401 float g = gate_up[id];
402 float u = gate_up[n + id];
403 output[id] = (g / (1.0f + fast::exp(-g))) * u;
404}
405
406// ── elementwise_add ─────────────────────────────────────────────────────
407// Residual connection: output[i] = a[i] + b[i]
408kernel void elementwise_add(
409 device const float* a [[buffer(0)]],
410 device const float* b [[buffer(1)]],
411 device float* output [[buffer(2)]],
412 constant uint& n [[buffer(3)]],
413 uint id [[thread_position_in_grid]])
414{
415 if (id >= n) return;
416 output[id] = a[id] + b[id];
417}
418
419// ── copy_buffer ─────────────────────────────────────────────────────────
420// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
421// transitions. Used for KV cache updates and embedding lookup.
422kernel void copy_buffer(
423 device const float* src [[buffer(0)]],
424 device float* dst [[buffer(1)]],
425 constant uint& count [[buffer(2)]],
426 uint id [[thread_position_in_grid]])
427{
428 if (id < count) dst[id] = src[id];
429}
430
431// ── copy_offset ─────────────────────────────────────────────────────────
432// Copy with source offset (in floats). Used for embedding table lookup
433// where we need to copy a specific row from a large table.
434kernel void copy_offset(
435 device const float* src [[buffer(0)]],
436 device float* dst [[buffer(1)]],
437 constant uint& src_offset [[buffer(2)]], // in floats
438 constant uint& count [[buffer(3)]],
439 uint id [[thread_position_in_grid]])
440{
441 if (id < count) dst[id] = src[src_offset + id];
442}
443
444// ── add_inplace ─────────────────────────────────────────────────────────
445// In-place residual connection: a[i] += b[i]
446// Avoids a separate blit copy for residual add, reducing encoder overhead.
447kernel void add_inplace(
448 device float* a [[buffer(0)]],
449 device const float* b [[buffer(1)]],
450 constant uint& n [[buffer(2)]],
451 uint id [[thread_position_in_grid]])
452{
453 if (id >= n) return;
454 a[id] += b[id];
455}
456
457// ── matmul_vec_q8 ─────────────────────────────────────────────────────
458// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
459// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
460// Operates directly on quantized weights to halve memory bandwidth vs f32,
461// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
462//
463// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
464// because int8->float conversion doubles register demand. Fully unrolled
465// inner loop with float4 vector loads from shared memory eliminates loop
466// overhead and enables better instruction scheduling.
467// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
468constant constexpr uint Q8_ROWS_PER_SG = 4;
469constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
470
471// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
472// doubles ALU work, so the same register budget applies).
473constant constexpr uint Q4_ROWS_PER_SG = 4;
474constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
475
476kernel void matmul_vec_q8(
477 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
478 device const float* vector [[buffer(1)]], // f32 input
479 device float* output [[buffer(2)]],
480 constant uint& rows [[buffer(3)]],
481 constant uint& cols [[buffer(4)]], // number of elements per row
482 uint tgid [[threadgroup_position_in_grid]],
483 uint tid [[thread_index_in_threadgroup]],
484 uint simd_lane [[thread_index_in_simdgroup]],
485 uint simd_id [[simdgroup_index_in_threadgroup]])
486{
487 // Load vector into shared memory
488 threadgroup float vec_tile[VEC_TILE_SIZE];
489 for (uint i = tid; i < cols; i += 256) {
490 vec_tile[i] = vector[i];
491 }
492 threadgroup_barrier(mem_flags::mem_threadgroup);
493
494 // Each simdgroup handles 4 consecutive rows (lower register pressure)
495 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
496 if (row_base >= rows) return;
497
498 // Q8_0: each block is 34 bytes for 32 elements
499 uint blocks_per_row = cols / 32;
500 uint row_bytes = blocks_per_row * 34;
501
502 // Pointers to each row's Q8_0 data
503 device const uchar* r0 = matrix + row_base * row_bytes;
504 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
505 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
506 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
507
508 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
509
510 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
511 uint bb = blk * 34;
512 uint vb = blk * 32;
513
514 // Prefetch all 4 scales
515 float sc0 = float(*(device const half*)(r0 + bb));
516 float sc1 = float(*(device const half*)(r1 + bb));
517 float sc2 = float(*(device const half*)(r2 + bb));
518 float sc3 = float(*(device const half*)(r3 + bb));
519
520 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
521 // Q8_0 block layout where the int8 data starts at offset +2 from a
522 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
523 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
524 // reduction in memory transactions. Metal's char16/packed_char16 are
525 // reserved types and packed_*int4 require >=4-byte alignment which
526 // this layout does not provide, so packed_short4 is the widest valid
527 // vectorized load.
528 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
529 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
530 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
531 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
532
533 // Load all 8 float4 vector values for this 32-element block from shared memory
534 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
535 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
536 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
537 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
538 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
539 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
540 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
541 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
542
543 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
544 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
545 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
546 short4 _s = short4(SHORT4); \
547 char2 _a = as_type<char2>(_s.x); \
548 char2 _b = as_type<char2>(_s.y); \
549 char2 _c = as_type<char2>(_s.z); \
550 char2 _d = as_type<char2>(_s.w); \
551 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
552 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
553 }
554
555 float4 f0, f1;
556 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
557
558 // Row 0: 4 short4 loads cover 32 int8 weights
559 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
560 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
561 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
562 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
563
564 // Row 1
565 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
566 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
567 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
568 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
569
570 // Row 2
571 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
572 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
573 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
574 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
575
576 // Row 3
577 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
578 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
579 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
580 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
581
582 #undef Q8_UNPACK8
583
584 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
585 }
586
587 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
588 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
589
590 if (simd_lane == 0) {
591 if (row_base < rows) output[row_base] = sum0;
592 if (row_base + 1 < rows) output[row_base + 1] = sum1;
593 if (row_base + 2 < rows) output[row_base + 2] = sum2;
594 if (row_base + 3 < rows) output[row_base + 3] = sum3;
595 }
596}
597
598// ── matmul_vec_q4 ─────────────────────────────────────────────────────
599// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
600// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
601// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
602// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
603//
604// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
605// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
606kernel void matmul_vec_q4(
607 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
608 device const float* vector [[buffer(1)]], // f32 input
609 device float* output [[buffer(2)]],
610 constant uint& rows [[buffer(3)]],
611 constant uint& cols [[buffer(4)]], // number of elements per row
612 uint tgid [[threadgroup_position_in_grid]],
613 uint tid [[thread_index_in_threadgroup]],
614 uint simd_lane [[thread_index_in_simdgroup]],
615 uint simd_id [[simdgroup_index_in_threadgroup]])
616{
617 // Load vector into shared memory
618 threadgroup float vec_tile[VEC_TILE_SIZE];
619 for (uint i = tid; i < cols; i += 256) {
620 vec_tile[i] = vector[i];
621 }
622 threadgroup_barrier(mem_flags::mem_threadgroup);
623
624 // Each simdgroup handles 4 consecutive rows
625 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
626 if (row_base >= rows) return;
627
628 // Q4_0: each block is 18 bytes for 32 elements
629 uint blocks_per_row = cols / 32;
630 uint row_bytes = blocks_per_row * 18;
631
632 // Pointers to each row's Q4_0 data
633 device const uchar* r0 = matrix + row_base * row_bytes;
634 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
635 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
636 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
637
638 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
639
640 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
641 uint bb = blk * 18;
642 uint vb = blk * 32;
643
644 // Prefetch all 4 scales
645 float sc0 = float(*(device const half*)(r0 + bb));
646 float sc1 = float(*(device const half*)(r1 + bb));
647 float sc2 = float(*(device const half*)(r2 + bb));
648 float sc3 = float(*(device const half*)(r3 + bb));
649
650 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
651 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
652 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
653 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
654 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
655
656 // Load 8 float4 vector values for 32 elements from shared memory
657 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
658 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
659 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
660 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
661 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
662 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
663 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
664 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
665 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
666
667 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
668 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
669 float bd0=0, bd1=0, bd2=0, bd3=0;
670 uchar4 b;
671
672 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
673 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
674 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
675 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
676 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
677 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
678 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
679 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
680 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
681 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
682 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
683 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
684 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
685 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
686 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
687 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
688 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
689
690 // Row 1
691 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
692 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
693 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
694 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
695 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
696 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
697 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
698 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
699 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
700 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
701 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
702 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
703 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
704 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
705 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
706 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
707
708 // Row 2
709 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
710 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
711 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
712 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
713 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
714 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
715 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
716 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
717 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
718 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
719 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
720 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
721 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
722 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
723 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
724 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
725
726 // Row 3
727 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
728 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
729 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
730 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
731 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
732 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
733 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
734 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
735 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
736 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
737 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
738 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
739 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
740 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
741 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
742 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
743
744 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
745 }
746
747 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
748 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
749
750 if (simd_lane == 0) {
751 if (row_base < rows) output[row_base] = sum0;
752 if (row_base + 1 < rows) output[row_base + 1] = sum1;
753 if (row_base + 2 < rows) output[row_base + 2] = sum2;
754 if (row_base + 3 < rows) output[row_base + 3] = sum3;
755 }
756}
757
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head with 256 threads (8 simdgroups).
//
// Scores are staged in a fixed ATTN_SCORE_TILE-entry threadgroup tile
// (8 KB). Longer contexts are processed tile by tile with an online softmax
// (running max + running denominator, rescaling the partial output whenever
// the max grows), so any seq_len is handled without writing past the tile.
// With seq_len <= ATTN_SCORE_TILE there is exactly one tile.
//
// Buffers:
// q: [num_heads * head_dim] current query
// k_cache: [max_seq_len * num_kv_heads * head_dim]
// v_cache: [max_seq_len * num_kv_heads * head_dim]
// output: [num_heads * head_dim]; also used as the running accumulator
//         across tiles, so it is assumed not to alias q (TODO confirm at
//         the dispatch site).
//
// Assumes num_heads is a multiple of num_kv_heads (grouped-query attention).
constant constexpr uint ATTN_SCORE_TILE = 2048;

kernel void attention(
    device const float* q [[buffer(0)]],
    device const float* k_cache [[buffer(1)]],
    device const float* v_cache [[buffer(2)]],
    device float* output [[buffer(3)]],
    constant uint& seq_len [[buffer(4)]],
    constant uint& num_heads [[buffer(5)]],
    constant uint& num_kv_heads [[buffer(6)]],
    constant uint& head_dim [[buffer(7)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Threadgroup scratch: one score tile plus one slot per simdgroup for
    // the max / sum reductions.
    threadgroup float scores[ATTN_SCORE_TILE];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];

    uint head = tgid;
    if (head >= num_heads) return;
    // Consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim; // floats between cache positions
    uint kv_base = kv_head * head_dim;

    // Empty context: produce the all-zero output and skip the normalisation
    // below, which would otherwise divide by a zero denominator.
    if (seq_len == 0) {
        for (uint d = tid; d < head_dim; d += 256) {
            output[q_off + d] = 0.0f;
        }
        return;
    }

    float scale = fast::rsqrt(float(head_dim));
    float run_max = -INFINITY; // max score over all tiles processed so far
    float run_sum = 0.0f;      // sum of exp(score - run_max) over those tiles

    for (uint t0 = 0; t0 < seq_len; t0 += ATTN_SCORE_TILE) {
        uint tlen = min(uint(ATTN_SCORE_TILE), seq_len - t0);

        // Step 1: scaled Q·K^T for this tile; each simdgroup owns every 8th
        // position and reduces the head_dim dot product with simd_sum.
        for (uint s = simd_id; s < tlen; s += 8) {
            uint k_off = (t0 + s) * kv_stride + kv_base;
            float qk = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                qk += q[q_off + d] * k_cache[k_off + d];
            }
            qk = simd_sum(qk);
            if (simd_lane == 0) {
                scores[s] = qk * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Step 2a: tile max (per-thread -> simd_max -> across simdgroups).
        float local_max = -INFINITY;
        for (uint s = tid; s < tlen; s += 256) {
            local_max = max(local_max, scores[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) shared_max[simd_id] = local_max;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float m = shared_max[0];
            for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
            shared_max[0] = m;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        float new_max = max(run_max, shared_max[0]);
        // Factor that rebases earlier tiles onto new_max. The first tile has
        // nothing to rebase; the branch also avoids exp(-inf) under fast math.
        float correction = (t0 == 0) ? 0.0f : fast::exp(run_max - new_max);

        // Step 2b: exponentiate in place and sum the tile.
        float local_sum = 0.0f;
        for (uint s = tid; s < tlen; s += 256) {
            float p = fast::exp(scores[s] - new_max);
            scores[s] = p;
            local_sum += p;
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) shared_sum[simd_id] = local_sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float total = 0.0f;
            for (uint i = 0; i < 8; i++) total += shared_sum[i];
            shared_sum[0] = total;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        run_sum = run_sum * correction + shared_sum[0];
        run_max = new_max;

        // Step 3: accumulate unnormalised weights · V into output. Each
        // thread owns its own head_dim entries, so the read-modify-write of
        // output needs no synchronisation. 4 positions per iteration for ILP.
        uint tlen4 = tlen & ~3u; // largest multiple of 4 <= tlen
        for (uint d = tid; d < head_dim; d += 256) {
            device const float* v_col = v_cache + t0 * kv_stride + kv_base + d;
            float acc = (t0 == 0) ? 0.0f : output[q_off + d] * correction;
            for (uint s = 0; s < tlen4; s += 4) {
                acc += scores[s] * v_col[s * kv_stride]
                     + scores[s + 1] * v_col[(s + 1) * kv_stride]
                     + scores[s + 2] * v_col[(s + 2) * kv_stride]
                     + scores[s + 3] * v_col[(s + 3) * kv_stride];
            }
            for (uint s = tlen4; s < tlen; s++) {
                acc += scores[s] * v_col[s * kv_stride];
            }
            output[q_off + d] = acc;
        }

        // All threads must finish reading scores / shared_max / shared_sum
        // before the next tile overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Final softmax normalisation by the accumulated denominator.
    float inv_sum = 1.0f / run_sum;
    for (uint d = tid; d < head_dim; d += 256) {
        output[q_off + d] *= inv_sum;
    }
}
874
// ── Batched prefill kernels ────────────────────────────────────────────
// These kernels process a batch of M token vectors in a single dispatch.
// The element-wise kernels extend their index space over all M rows; the
// matmul kernels multiply all M input vectors against the same weight
// matrix, turning mat-vec work into mat-mat work for better GPU
// utilization during prompt prefill.
879
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalisation: output = input * rsqrt(mean(input^2) + eps)
// * weight, applied independently to each of the M rows of length n.
// One threadgroup per token (grid: M threadgroups); the strides and the
// 8-slot partial array assume 256 threads per threadgroup.
kernel void rms_norm_batch(
    device const float* input [[buffer(0)]], // [M, n]
    device const float* weight [[buffer(1)]], // [n]
    device float* output [[buffer(2)]], // [M, n]
    constant uint& n [[buffer(3)]],
    constant float& eps [[buffer(4)]],
    constant uint& num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]])
{
    // One partial sum of squares per simdgroup.
    threadgroup float partials[8];

    // Surplus threadgroups past the last token have no work.
    if (tgid >= num_tokens) return;

    device const float* src = input + tgid * n;
    device float* dst = output + tgid * n;
    const uint lane = tid % 32;
    const uint sg = tid / 32;

    // Each thread sums the squares of every 256th element of its row.
    float acc = 0.0f;
    for (uint k = tid; k < n; k += 256) {
        const float x = src[k];
        acc += x * x;
    }
    acc = simd_sum(acc);
    if (lane == 0) {
        partials[sg] = acc;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the simdgroup partials and publishes 1/rms in slot 0.
    if (tid == 0) {
        float sum_sq = 0.0f;
        for (uint w = 0; w < 8; ++w) {
            sum_sq += partials[w];
        }
        partials[0] = fast::rsqrt(sum_sq / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = partials[0];
    for (uint k = tid; k < n; k += 256) {
        dst[k] = src[k] * scale * weight[k];
    }
}
929
// ── rope_batch ─────────────────────────────────────────────────────────
// Rotary position embedding applied in place to M token vectors, each
// rotated by its own position from positions[token].
// data layout: [M, num_heads * head_dim]; adjacent element pairs
// (2i, 2i+1) inside each head form one rotation.
// One thread per (token, head, pair); surplus threads exit immediately.
kernel void rope_batch(
    device float* data [[buffer(0)]], // [M, num_heads * head_dim]
    constant uint& num_heads [[buffer(1)]],
    constant uint& head_dim [[buffer(2)]],
    device const uint* positions [[buffer(3)]], // [M] position per token
    constant float& theta [[buffer(4)]],
    constant uint& num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint half_dim = head_dim / 2;
    const uint pairs = num_heads * half_dim; // rotation pairs per token
    if (id >= num_tokens * pairs) return;

    // Decompose the flat id into (token, head, pair).
    const uint token = id / pairs;
    const uint within = id % pairs;
    const uint head = within / half_dim;
    const uint pair = within % half_dim;

    device float* x = data + (token * (num_heads * head_dim) + head * head_dim + 2 * pair);

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float angle = float(positions[token]) * inv_freq;
    const float cs = cos(angle);
    const float sn = sin(angle);

    const float a = x[0];
    const float b = x[1];
    x[0] = a * cs - b * sn;
    x[1] = a * sn + b * cs;
}
964
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Batched fused SwiGLU activation. Each token row of gate_up holds the
// gate half followed by the up half (row length 2*n); the result is
// silu(gate) * up, written to a row of length n.
// One thread per output element (M*n threads).
kernel void silu_mul_fused_batch(
    device const float* gate_up [[buffer(0)]], // [M, 2*n]
    device float* output [[buffer(1)]], // [M, n]
    constant uint& n [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint row = id / n;
    const uint col = id - row * n; // same as id % n
    device const float* gu = gate_up + row * 2 * n;

    const float gate = gu[col];
    const float up = gu[n + col];
    const float silu = gate / (1.0f + fast::exp(-gate));

    // row * n + col == id, so the flat id addresses the output directly.
    output[id] = silu * up;
}
984
// ── add_inplace_batch ──────────────────────────────────────────────────
// Batched residual add: a[i] += b[i] over the flattened M*n elements.
// One thread per element; threads at or beyond total do nothing.
kernel void add_inplace_batch(
    device float* a [[buffer(0)]], // [M * n]
    device const float* b [[buffer(1)]], // [M * n]
    constant uint& total [[buffer(2)]], // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
996
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gathers M embedding rows into a contiguous [M, dim] buffer; row m is
// the embedding-table row selected by tokens[m].
// One thread per output float. Token IDs are not range-checked here.
kernel void copy_embedding_batch(
    device const float* embed [[buffer(0)]], // [vocab_size, dim]
    device float* output [[buffer(1)]], // [M, dim]
    device const uint* tokens [[buffer(2)]], // [M]
    constant uint& dim [[buffer(3)]],
    constant uint& num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    const uint tok = tokens[id / dim];
    output[id] = embed[tok * dim + id % dim];
}
1014
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair; each simdgroup
// computes 8 consecutive output rows (assumes ROWS_PER_SIMDGROUP == 8,
// TODO confirm) and reduces across lanes with simd_sum.
// The token's input vector is staged in threadgroup memory, so cols is
// assumed to be <= VEC_TILE_SIZE (presumably enforced by the dispatcher).
kernel void matmul_vec_batch(
    device const float* matrix [[buffer(0)]], // [rows, cols] weight
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barriers follow, so whole simdgroups past the last row may exit.
    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    if (row_base >= rows) return;

    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // Vectorised prefix: a multiple of 128 columns (32 lanes x float4).
    // The float4 loads at matrix + (row_base + k) * cols + j are only
    // 16-byte aligned when cols is a multiple of 4; otherwise every column
    // goes through the scalar loop below instead of issuing misaligned
    // vector loads (undefined behaviour in MSL).
    uint cols_vec4 = (cols % 4u == 0u) ? (cols & ~127u) : 0u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail (and the whole row when vector loads are disabled).
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1116
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups.
// Each Q8_0 block is 34 bytes: a half scale followed by 32 int8 quants.
// Each simdgroup computes 4 consecutive output rows (assumes
// Q8_ROWS_PER_SG == 4, TODO confirm); lanes stride over blocks and reduce
// with simd_sum. The token's input vector is staged in threadgroup memory,
// so cols is assumed to be <= VEC_TILE_SIZE and a multiple of 32.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4, the trailing rows of the last
    // simdgroup do not exist. Point them at the last valid row so every
    // weight read stays inside the buffer; their sums are computed but
    // never stored (see the row guards on the output writes).
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only rows that exist are written; clamped duplicate rows are dropped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1232
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers Q8_ROWS_PER_TG rows and TOKENS_PER_TG_Q8
// consecutive tokens, so each Q8_0 weight block is fetched once from device
// memory and reused for every token in the tile (1/TOKENS_PER_TG_Q8 the
// weight bandwidth of the per-token dispatch).
//
// Grid: (ceil(rows/Q8_ROWS_PER_TG), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each simdgroup owns 4 consecutive rows (assumes Q8_ROWS_PER_SG == 4,
// TODO confirm) and reduces over blocks with simd_sum. Token vectors are
// read directly from device memory inside the block loop (not cached in
// threadgroup memory), so large cols (e.g. intermediate_size) need no
// threadgroup storage. cols is assumed to be a multiple of 32.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // How many tokens in this tile are valid (1..TOKENS_PER_TG_Q8)?
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // When rows is not a multiple of 4, the trailing rows of the last
    // simdgroup do not exist. Point them at the last valid row so every
    // weight read stays inside the buffer; their sums are never stored.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators s<token><row>: 4 tokens × 4 rows per simdgroup.
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;
        uint vb = blk * 32;

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        // Unpack all 4 rows × 8 float4 int8 weights (unscaled; the per-row
        // scale is applied after the dot products). These stay in
        // registers for the block and are dotted against every token.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        #undef Q8_UNPACK8

        // Dot token T's 32-float slice for this block against the 4 unpacked
        // weight rows and add the scaled results into accumulators S0..S3.
        #define Q8_GEMM_TOKEN(T, S0, S1, S2, S3) { \
            device const float* _a = inputs + (tok_base + (T)) * cols + vb; \
            float4 _v0 = *(device const float4*)(_a); \
            float4 _v1 = *(device const float4*)(_a + 4); \
            float4 _v2 = *(device const float4*)(_a + 8); \
            float4 _v3 = *(device const float4*)(_a + 12); \
            float4 _v4 = *(device const float4*)(_a + 16); \
            float4 _v5 = *(device const float4*)(_a + 20); \
            float4 _v6 = *(device const float4*)(_a + 24); \
            float4 _v7 = *(device const float4*)(_a + 28); \
            float _bd0 = dot(w0_0,_v0)+dot(w0_1,_v1)+dot(w0_2,_v2)+dot(w0_3,_v3) \
                       + dot(w0_4,_v4)+dot(w0_5,_v5)+dot(w0_6,_v6)+dot(w0_7,_v7); \
            float _bd1 = dot(w1_0,_v0)+dot(w1_1,_v1)+dot(w1_2,_v2)+dot(w1_3,_v3) \
                       + dot(w1_4,_v4)+dot(w1_5,_v5)+dot(w1_6,_v6)+dot(w1_7,_v7); \
            float _bd2 = dot(w2_0,_v0)+dot(w2_1,_v1)+dot(w2_2,_v2)+dot(w2_3,_v3) \
                       + dot(w2_4,_v4)+dot(w2_5,_v5)+dot(w2_6,_v6)+dot(w2_7,_v7); \
            float _bd3 = dot(w3_0,_v0)+dot(w3_1,_v1)+dot(w3_2,_v2)+dot(w3_3,_v3) \
                       + dot(w3_4,_v4)+dot(w3_5,_v5)+dot(w3_6,_v6)+dot(w3_7,_v7); \
            (S0) += sc0 * _bd0; (S1) += sc1 * _bd1; (S2) += sc2 * _bd2; (S3) += sc3 * _bd3; \
        }

        // Token 0 is always valid (tok_count >= 1).
        Q8_GEMM_TOKEN(0, s00, s01, s02, s03);
        if (tok_count > 1) Q8_GEMM_TOKEN(1, s10, s11, s12, s13);
        if (tok_count > 2) Q8_GEMM_TOKEN(2, s20, s21, s22, s23);
        if (tok_count > 3) Q8_GEMM_TOKEN(3, s30, s31, s32, s33);

        #undef Q8_GEMM_TOKEN
    }

    // simdgroup reduction (executed by all lanes, outside the lane-0 branch)
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    if (simd_lane == 0) {
        // Store token T's 4 row results. Row row_base always exists (early
        // return above); rows past the end (clamped reads) are skipped.
        #define Q8_GEMM_STORE(T, S0, S1, S2, S3) { \
            device float* _o = outputs + (tok_base + (T)) * rows; \
            _o[row_base] = (S0); \
            if (row_base + 1 < rows) _o[row_base + 1] = (S1); \
            if (row_base + 2 < rows) _o[row_base + 2] = (S2); \
            if (row_base + 3 < rows) _o[row_base + 3] = (S3); \
        }

        Q8_GEMM_STORE(0, s00, s01, s02, s03);
        if (tok_count > 1) Q8_GEMM_STORE(1, s10, s11, s12, s13);
        if (tok_count > 2) Q8_GEMM_STORE(2, s20, s21, s22, s23);
        if (tok_count > 3) Q8_GEMM_STORE(3, s30, s31, s32, s33);

        #undef Q8_GEMM_STORE
    }
}
1458
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Hardware matrix-multiply GEMM for Q8_0 weights, using Apple Silicon
// simdgroup_matrix tiles (simdgroup_multiply_accumulate). This dispatches
// far higher FLOP/cycle than the scalar dot-product GEMM and is the primary
// driver of prompt-prefill throughput on M >= MMA_TOK_TILE inputs.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups per TG, each computing a single 8×8 output sub-tile via one
// simdgroup_matrix<float, 8, 8> accumulator. Weight bytes are cooperatively
// dequantized into threadgroup memory once per block and reused by all
// simdgroups in the tile.
//
// Grid: (rows/16, ceil(M/16)) threadgroups. The cooperative loads cover the
// 512-entry tiles with tid * 4 .. tid * 4 + 3, so the threadgroup must be
// exactly 128 threads (4 simdgroups of 32).
//
// Assumptions (the dispatch helper is expected to check these and fall
// back to another kernel otherwise):
// * cols % 32 == 0 (one Q8_0 block per K chunk)
// * rows % 16 == 0 (tile-aligned); weight rows are read and outputs
//   written without a row bound check
// * num_tokens may be any value; a partial final token tile is zero-padded
//   on load and written back through a per-simdgroup scratch copy.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;

kernel void matmul_q8_mma(
    device const uchar* matrix [[buffer(0)]], // Q8_0 [rows, cols/32 * 34]
    device const float* inputs [[buffer(1)]], // [M, cols]
    device float* outputs [[buffer(2)]], // [M, rows]
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    // Threadgroup-uniform early exit, so no thread skips a barrier below.
    uint row_base = tgid.x * MMA_ROW_TILE;
    uint tok_base = tgid.y * MMA_TOK_TILE;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    // Both are row-major with 32 floats (one K chunk) per row.
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    uint sg_tok_base = (simd_id / 2) * 8; // row within output tile (token dim)
    uint sg_row_base = (simd_id % 2) * 8; // col within output tile (row dim)

    simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34; // 34 bytes per block: half scale + 32 int8

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Cooperatively dequantize 16 weight rows × 32 K into w_tile ──
        // 512 floats / 128 threads = 4 floats per thread.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // ── Cooperatively load 16 token vectors × 32 K into t_tile ──
        // Tokens past num_tokens are zero-filled so the MMA stays uniform.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Tiles fully written before any simdgroup loads from them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        // A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k] (M×K, no transpose)
        // B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k] (loaded transposed → K×R)
        // C[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B,
                           w_tile + sg_row_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C, A, B, C);
        }

        // All simdgroups finish reading the tiles before the next block
        // overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store C to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], stride = rows (always tile-aligned).
    uint out_tok = tok_base + sg_tok_base;
    uint out_row = row_base + sg_row_base;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        // Fast path: entire 8×8 sub-tile is in-bounds.
        simdgroup_store(C, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Partial token sub-tile (fewer than 8 valid tokens): stage the 8×8
        // result in per-simdgroup scratch and have lane 0 copy the valid
        // token rows. Sub-tiles with no valid tokens store nothing.
        threadgroup float scratch[4 * 64];
        simdgroup_store(C, scratch + simd_id * 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
}
1587
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Computes outputs[tok, row] = sum_k inputs[tok, k] * W[row, k], where W is a
// Q8_0-quantized [rows, cols] matrix. Each Q8_0 block is 34 bytes: a `half`
// scale followed by 32 int8 quants (dequantized value = int8 * scale).
// Buffers are row-major: inputs[tok * cols + k], outputs[tok * rows + row].
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// tgid.x selects the 32-row tile, tgid.y the 32-token tile.
// 8 simdgroups (256 threads) cover the 16-tile 4×4 output grid, with each
// simdgroup owning *two* stacked 8×8 accumulators along the row axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0) -> C_a
//     (sg_tok_idx*8, sg_row_half*16 + 8) -> C_b
//
// This layout reuses the loaded A (token) simdgroup_matrix twice per K_sub
// iteration, giving a better FLOP/load ratio than the 16×16 single-accumulator
// kernel. Each threadgroup covers 4× the output area of the 16×16 tile, so
// it needs a quarter as many threadgroups.
//
// The cooperative tile loads (tid * 4 over 32*32 elements) require exactly
// 256 threads per threadgroup.
//
// Assumptions (described as verified in the dispatch helper, which falls
// back otherwise; NOT checked inside this kernel):
//   * cols % 32 == 0
//   * rows % 32 == 0
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;

kernel void matmul_q8_mma32(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and no
    // later threadgroup_barrier is reached by only a subset of threads.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    uint sg_tok_idx = simd_id / 2; // 0..3
    uint sg_row_half = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    // 34 bytes per Q8_0 block: 2-byte half scale + 32 int8 quants.
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization: 32*32 floats / 256 threads = 4 floats each.
        // Row reads are unguarded: relies on rows % 32 == 0 (see header).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = float(ival) * sc;
            }
        }

        // Cooperative token tile load. Tokens past num_tokens are zero-filled
        // so the last (partial) token tile contributes nothing for them.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? inputs[tok * cols + blk * 32 + k]
                    : 0.0f;
            }
        }

        // Both tiles must be fully written before any simdgroup reads them.
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each. For each, reuse A across both row accumulators.
        // A is read as-is (token-major, M×K); B is loaded with transpose=true
        // so the row-major w_tile (R×K) becomes K×R for C[m, r] += A * B.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Keep the next iteration's tile writes from racing with this
        // iteration's simdgroup_load reads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators. rows is always MMA32_ROW_TILE-aligned
    // (verified in dispatch), so full simdgroup_store is safe for the row
    // dimension; only the last token tile may be partial.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    // full_tok depends only on simd_id, so it is uniform within a simdgroup.
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Partial token sub-tile: stage both accumulators in this simdgroup's
        // private scratch slice, then lane 0 copies only the valid token rows.
        threadgroup float scratch[8 * 2 * 64]; // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1728
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory. This halves the tile footprint (4 KB total vs 8 KB; the
// partial-tile store path below still uses a 4 KB float scratch buffer),
// which can raise concurrent-threadgroup occupancy per GPU core on Apple
// Silicon.
//
// Q8_0 weights are int8 × a `half` scale (the scale is read as `half` below).
// The dequantized weight stays within half's range unless |scale| > ~515
// (65504 / 127). The product is rounded to half's 11-bit significand, so
// results are not bit-exact with matmul_q8_mma32.
//
// Token activations are narrowed f32 → f16 on load. That rounds them and
// would overflow to inf for |x| > 65504. NOTE(review): confirm activations
// stay in half range for every matmul this kernel is dispatched on.
//
// The subsequent `simdgroup_multiply_accumulate` keeps the accumulator in
// `float` (C_a / C_b are simdgroup_matrix<float>), so the K-sum itself is
// not done in half.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups (256 threads) × 2 row-stacked
// 8×8 accumulators each. The intended win vs mma32 is occupancy at moderate
// prefill lengths where the GPU is wave-starved.
// Same assumptions as matmul_q8_mma32: cols % 32 == 0, rows % 32 == 0.
kernel void matmul_q8_mma32_h(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Whole-threadgroup exit (depends only on tgid), so barriers stay uniform.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 half tiles — 2 KB each, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // Same simdgroup → sub-tile mapping as matmul_q8_mma32.
    uint sg_tok_idx = simd_id / 2;
    uint sg_row_half = simd_id % 2;
    uint sg_tok_base = sg_tok_idx * 8;
    uint sg_row_base_a = sg_row_half * 16 + 0;
    uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    // 34 bytes per Q8_0 block: 2-byte half scale + 32 int8 quants.
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequantization to FP16 (256 threads × 4 = 32*32).
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                // Multiply in float, then round once to half.
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load (f32 → f16 narrowing); zero-fill past num_tokens.
        {
            uint base = tid * 4;
            for (uint ii = 0; ii < 4; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // half × half MMA into float accumulators; B loaded transposed (K×R).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A, B_a, B_b;
            simdgroup_load(A,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
            simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
        }

        // Protect the tiles from being overwritten by the next iteration's loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store: fast path when all 8 token rows are valid, otherwise stage the
    // float accumulators in per-simdgroup scratch and copy valid rows only.
    uint out_tok = tok_base + sg_tok_base;
    uint out_row_a = row_base + sg_row_base_a;
    uint out_row_b = row_base + sg_row_base_b;
    bool full_tok = (out_tok + 8 <= num_tokens);
    if (full_tok) {
        simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        threadgroup float scratch[8 * 2 * 64]; // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(C_a, scratch + simd_id * 128, 8);
        simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            uint valid = num_tokens - out_tok; // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
}
1857
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each. Per simdgroup
// (offsets relative to that simdgroup's 16×16 quadrant):
//   C_00 (tok 0..8, row 0..8)   C_01 (tok 0..8, row 8..16)
//   C_10 (tok 8..16, row 0..8)  C_11 (tok 8..16, row 8..16)
// simd_id addresses one 16×16 quadrant of the 32×32 output tile:
// sg_tok_q = simd_id / 2 picks the token half, sg_row_q = simd_id % 2 the
// row half.
//
// Per K_sub iteration the kernel loads two A fragments and two B fragments,
// then runs **four** MMA instructions, reusing A_top with both B's and A_bot
// with both B's. That is 4 MMAs per 4 fragment loads, vs 2 MMAs per 3 loads
// in the 2-accumulator kernel (~1.5× the MMA-per-load ratio).
//
// It also halves the thread count per threadgroup to 128 threads. The tile
// loads use tid * 8 over 32*32 elements, so exactly 128 threads are
// required. Fewer threads per threadgroup may improve occupancy on Apple
// GPUs when the concurrent-thread budget, rather than shared-memory size,
// is the tighter limit.
//
// Accumulators are float; tiles are half (see matmul_q8_mma32_h for the
// precision/range caveats of the f16 tiles).
// Same assumptions as matmul_q8_mma32: cols % 32 == 0, rows % 32 == 0.
kernel void matmul_q8_mma32_h4(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Whole-threadgroup exit (depends only on tgid), so barriers stay uniform.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 FP16 tiles, 4 KB total.
    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
    uint sg_tok_q = simd_id / 2; // 0..1
    uint sg_row_q = simd_id % 2; // 0..1
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);

    uint blocks_per_row = cols / 32;
    // 34 bytes per Q8_0 block: 2-byte half scale + 32 int8 quants.
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }

        // Cooperative token tile load; zero-fill tokens past num_tokens.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
        // B fragments are loaded transposed (R×K row-major tile → K×R).
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Protect the tiles from being overwritten by the next iteration's loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store 4 output tiles. Full-tile fast path assumes full 16×16 valid.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;
    // Depends only on simd_id, so uniform within each simdgroup.
    bool full = (out_tok_bot + 8 <= num_tokens);
    if (full) {
        simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
        simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
        simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
        simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
    } else {
        // Partial-token fallback via per-simdgroup scratch. A quadrant that
        // lies entirely past num_tokens also lands here but writes nothing,
        // because every row copy below is individually bounds-checked.
        threadgroup float scratch[4 * 4 * 64]; // 4 simdgroups × 4 accs × 64
        uint sg_base = simd_id * 256;
        simdgroup_store(C_00, scratch + sg_base + 0, 8);
        simdgroup_store(C_01, scratch + sg_base + 64, 8);
        simdgroup_store(C_10, scratch + sg_base + 128, 8);
        simdgroup_store(C_11, scratch + sg_base + 192, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);
        uint lane = tid % 32;
        if (lane == 0) {
            for (uint m = 0; m < 8; m++) {
                uint t_top = out_tok_top + m;
                if (t_top < num_tokens) {
                    device float* dst0 = outputs + t_top * rows + out_row_lo;
                    device float* dst1 = outputs + t_top * rows + out_row_hi;
                    threadgroup const float* src0 = scratch + sg_base + 0 + m * 8;
                    threadgroup const float* src1 = scratch + sg_base + 64 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
                }
                uint t_bot = out_tok_bot + m;
                if (t_bot < num_tokens) {
                    device float* dst2 = outputs + t_bot * rows + out_row_lo;
                    device float* dst3 = outputs + t_bot * rows + out_row_hi;
                    threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
                    threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
                    for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
                }
            }
        }
    }
}
2012
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// The motivation is that FP16 `simdgroup_multiply_accumulate` may run at up
// to 2x the FP32 rate on Apple Silicon. NOTE(review): that throughput claim
// is not verified here. If Q8_0 precision holds through half accumulation,
// this kernel could double the effective FLOP throughput on matmul-bound
// prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa, and the token
// activations at each layer are expected to be bounded (post-RMSNorm ≈ O(1)).
// Summing 2048-wide K for 1B or 8192-wide for the FFN may exceed half's
// ~3.3-digit precision on extreme values. However, the inputs are already
// quantized, so the per-product error floor is higher than the
// half-precision rounding error.
//
// Unlike the float-accumulator kernels, partial sums here are held in half,
// so any |partial sum| > 65504 overflows to inf.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Layout, thread count (128 threads: tid * 8 over 32*32) and assumptions
// (cols % 32 == 0, rows % 32 == 0) match matmul_q8_mma32_h4.
kernel void matmul_q8_mma32_hh4(
    device const uchar* matrix [[buffer(0)]],
    device const float* inputs [[buffer(1)]],
    device float* outputs [[buffer(2)]],
    constant uint& num_tokens [[buffer(3)]],
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * MMA32_ROW_TILE;
    uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Whole-threadgroup exit (depends only on tgid), so barriers stay uniform.
    if (row_base >= rows || tok_base >= num_tokens) return;

    threadgroup half w_tile[MMA32_ROW_TILE * 32];
    threadgroup half t_tile[MMA32_TOK_TILE * 32];

    // 2×2 grid of 16×16 quadrants, one per simdgroup.
    uint sg_tok_q = simd_id / 2;
    uint sg_row_q = simd_id % 2;
    uint sg_tok_base = sg_tok_q * 16;
    uint sg_row_base = sg_row_q * 16;

    simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
    simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));

    uint blocks_per_row = cols / 32;
    // 34 bytes per Q8_0 block: 2-byte half scale + 32 int8 quants.
    uint row_bytes = blocks_per_row * 34;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Cooperative weight dequant to half (128 threads × 8 = 32*32).
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint r = idx / 32;
                uint k = idx % 32;
                device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
                float sc = float(*(device const half*)rp);
                int ival = int(*(device const int8_t*)(rp + 2 + k));
                w_tile[r * 32 + k] = half(float(ival) * sc);
            }
        }
        // Cooperative token tile load (f32 → f16); zero-fill past num_tokens.
        {
            uint base = tid * 8;
            for (uint ii = 0; ii < 8; ii++) {
                uint idx = base + ii;
                uint m = idx / 32;
                uint k = idx % 32;
                uint tok = tok_base + m;
                t_tile[m * 32 + k] = (tok < num_tokens)
                    ? half(inputs[tok * cols + blk * 32 + k])
                    : half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Same 4-MMA reuse pattern as matmul_q8_mma32_h4, but accumulating in half.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
            simdgroup_load(A_top,
                           t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(A_bot,
                           t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), false);
            simdgroup_load(B_lo,
                           w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_load(B_hi,
                           w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
                           32, ulong2(0, 0), true);
            simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
            simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
            simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
            simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
        }

        // Protect the tiles from being overwritten by the next iteration's loads.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store half accumulators via scratch (must widen to f32 for device output).
    // There is no direct-store fast path: every simdgroup stages through
    // scratch, and lane 0 widens and copies the in-range token rows.
    uint out_tok_top = tok_base + sg_tok_base + 0;
    uint out_tok_bot = tok_base + sg_tok_base + 8;
    uint out_row_lo = row_base + sg_row_base + 0;
    uint out_row_hi = row_base + sg_row_base + 8;

    threadgroup half scratch[4 * 4 * 64]; // 4 simdgroups × 4 accs × 64 halves
    uint sg_base = simd_id * 256;
    simdgroup_store(C_00, scratch + sg_base + 0, 8);
    simdgroup_store(C_01, scratch + sg_base + 64, 8);
    simdgroup_store(C_10, scratch + sg_base + 128, 8);
    simdgroup_store(C_11, scratch + sg_base + 192, 8);
    simdgroup_barrier(mem_flags::mem_threadgroup);
    uint lane = tid % 32;
    if (lane == 0) {
        for (uint m = 0; m < 8; m++) {
            uint t_top = out_tok_top + m;
            if (t_top < num_tokens) {
                device float* dst0 = outputs + t_top * rows + out_row_lo;
                device float* dst1 = outputs + t_top * rows + out_row_hi;
                threadgroup const half* src0 = scratch + sg_base + 0 + m * 8;
                threadgroup const half* src1 = scratch + sg_base + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst0[j] = float(src0[j]);
                    dst1[j] = float(src1[j]);
                }
            }
            uint t_bot = out_tok_bot + m;
            if (t_bot < num_tokens) {
                device float* dst2 = outputs + t_bot * rows + out_row_lo;
                device float* dst3 = outputs + t_bot * rows + out_row_hi;
                threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
                threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst2[j] = float(src2[j]);
                    dst3[j] = float(src3[j]);
                }
            }
        }
    }
}
2150
// ── add_bias_batch ─────────────────────────────────────────────────────
// In-place broadcast of a per-row bias over a row-major [num_tokens, rows]
// matrix: out[t * rows + i] += bias[i].
// Used for Qwen2 QKV bias after the fused qkv matmul.
// One thread per element; threads at or beyond num_tokens * rows do nothing.
kernel void add_bias_batch(
    device float* out [[buffer(0)]], // [num_tokens, rows]
    device const float* bias [[buffer(1)]], // [rows]
    constant uint& num_tokens [[buffer(2)]],
    constant uint& rows [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    bool in_range = id < num_tokens * rows;
    if (in_range) {
        // id % rows is the column (row-of-W) index within this token's output.
        out[id] += bias[id % rows];
    }
}
2167
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup handles one token (tgid / row_tgs) and one span of
// Q4_ROWS_PER_TG rows (tgid % row_tgs). The token's input vector is staged
// in threadgroup memory (vec_tile). Each simdgroup then computes four
// consecutive output rows: lanes stride over the 32-element Q4_0 blocks, and
// a simd_sum reduces the per-lane partial sums.
//
// Q4_0 block layout (18 bytes): a half scale, then 16 bytes of packed
// nibbles. Byte j holds element j in its low nibble and element j+16 in its
// high nibble, each stored with a +8 bias.
//
// Assumptions (not checked here; confirm at the dispatch site):
//   * threadgroup size is 256 threads (the vec_tile load strides by 256);
//   * cols % 32 == 0 and cols <= VEC_TILE_SIZE;
//   * presumably Q4_ROWS_PER_SG == 4 (four row pointers and accumulators
//     are hard-coded below).
//
// Tail rows: when rows is not a multiple of 4, row pointers that would fall
// past the end of the matrix are clamped to the last valid row. This
// prevents out-of-bounds device reads; the duplicate sums are discarded by
// the guarded stores at the end.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Depends only on tgid: the whole threadgroup exits before the barrier.
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Safe to exit per simdgroup now: no further threadgroup barriers follow.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // row_base < rows here, so rows >= 1 and last_row is valid. Clamping keeps
    // the r1..r3 reads in bounds when rows % 4 != 0; those rows' results are
    // dropped by the guarded stores below.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;  // byte offset of this block within a row
        uint vb = blk * 32;  // element offset of this block within the vector

        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);

        // v0..v3 = elements 0..15 (low nibbles), v4..v7 = elements 16..31 (high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;

        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;

        // Apply each block's scale once to its integer-weighted dot product.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only lane 0 writes; rows past the end (clamped above) are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
2268
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter one K or V slice per token from the batched QKV buffer into the
// KV cache.
//   src: [M, src_stride] floats; each token's slice starts at src_offset
//        within its row and is kv_dim floats long.
//   dst: contiguous [max_seq, kv_dim] cache; token t lands in row base_pos + t.
// One thread per copied float; threads at or past M * kv_dim exit immediately.
kernel void copy_kv_batch(
    device const float* src [[buffer(0)]], // batch QKV buffer
    device float* dst [[buffer(1)]], // KV cache
    constant uint& M [[buffer(2)]], // num tokens in batch
    constant uint& kv_dim [[buffer(3)]], // floats per KV vector
    constant uint& base_pos [[buffer(4)]], // starting position in cache
    constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
    constant uint& src_offset [[buffer(6)]], // float offset within each src row
    uint id [[thread_position_in_grid]])
{
    if (id < M * kv_dim) {
        uint t = id / kv_dim;
        uint c = id - t * kv_dim; // same as id % kv_dim
        uint from_idx = t * src_stride + src_offset + c;
        uint to_idx = (base_pos + t) * kv_dim + c;
        dst[to_idx] = src[from_idx];
    }
}
2291
2292// ── attention_batch ───────────────────────────────────────────────────
2293// Batched causal attention for prefill. Processes M tokens in one dispatch.
2294// Each threadgroup handles one (token, head) pair.
2295// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
2296// Causal masking: token i can only attend to positions 0..base_pos+i.
2297kernel void attention_batch(
2298 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2299 device const float* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim]
2300 device const float* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim]
2301 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2302 constant uint& M [[buffer(4)]], // num tokens in batch
2303 constant uint& base_pos [[buffer(5)]], // starting position in KV cache
2304 constant uint& num_heads [[buffer(6)]],
2305 constant uint& num_kv_heads [[buffer(7)]],
2306 constant uint& head_dim [[buffer(8)]],
2307 constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
2308 uint tgid [[threadgroup_position_in_grid]],
2309 uint tid [[thread_index_in_threadgroup]],
2310 uint simd_lane [[thread_index_in_simdgroup]],
2311 uint simd_id [[simdgroup_index_in_threadgroup]])
2312{
2313 // Grid: M * num_heads threadgroups
2314 uint token_idx = tgid / num_heads;
2315 uint head = tgid % num_heads;
2316 if (token_idx >= M) return;
2317
2318 uint kv_head = head / (num_heads / num_kv_heads);
2319 uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx
2320
2321 // Q offset uses strided layout (from batch QKV buffer)
2322 uint q_off = token_idx * q_stride + head * head_dim;
2323 // Output is contiguous [M, num_heads * head_dim]
2324 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2325
2326 // Shared memory for attention scores
2327 threadgroup float scores[2048];
2328
2329 // Step 1: Q * K^T with simdgroup reduction
2330 for (uint s = simd_id; s < seq_len; s += 8) {
2331 uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
2332 float dot = 0.0;
2333 for (uint d = simd_lane; d < head_dim; d += 32) {
2334 dot += q_batch[q_off + d] * k_cache[k_off + d];
2335 }
2336 dot = simd_sum(dot);
2337 if (simd_lane == 0) {
2338 scores[s] = dot * fast::rsqrt(float(head_dim));
2339 }
2340 }
2341 threadgroup_barrier(mem_flags::mem_threadgroup);
2342
2343 // Step 2: Softmax (cooperative)
2344 float local_max = -INFINITY;
2345 for (uint s = tid; s < seq_len; s += 256) {
2346 local_max = max(local_max, scores[s]);
2347 }
2348 local_max = simd_max(local_max);
2349 threadgroup float shared_max[8];
2350 if (simd_lane == 0) shared_max[simd_id] = local_max;
2351 threadgroup_barrier(mem_flags::mem_threadgroup);
2352 if (tid == 0) {
2353 float m = shared_max[0];
2354 for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
2355 shared_max[0] = m;
2356 }
2357 threadgroup_barrier(mem_flags::mem_threadgroup);
2358 float max_val = shared_max[0];
2359
2360 float local_sum = 0.0;
2361 for (uint s = tid; s < seq_len; s += 256) {
2362 scores[s] = fast::exp(scores[s] - max_val);
2363 local_sum += scores[s];
2364 }
2365 local_sum = simd_sum(local_sum);
2366 threadgroup float shared_sum[8];
2367 if (simd_lane == 0) shared_sum[simd_id] = local_sum;
2368 threadgroup_barrier(mem_flags::mem_threadgroup);
2369 if (tid == 0) {
2370 float total = 0.0;
2371 for (uint i = 0; i < 8; i++) total += shared_sum[i];
2372 shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
2373 }
2374 threadgroup_barrier(mem_flags::mem_threadgroup);
2375 float inv_sum = shared_sum[0];
2376 for (uint s = tid; s < seq_len; s += 256) {
2377 scores[s] *= inv_sum;
2378 }
2379 threadgroup_barrier(mem_flags::mem_threadgroup);
2380
2381 // Step 3: scores * V using float4 vectorized loads
2382 // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
2383 // This is much better than the scalar version where only 64 of 256 threads are active.
2384 uint v_stride = num_kv_heads * head_dim;
2385 uint head_dim4 = head_dim / 4;
2386 for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
2387 uint d = d4 * 4;
2388 float4 acc = float4(0.0);
2389 uint v_base = kv_head * head_dim + d;
2390 uint seq_len4 = seq_len & ~3u;
2391 for (uint s = 0; s < seq_len4; s += 4) {
2392 float sc0 = scores[s];
2393 float sc1 = scores[s + 1];
2394 float sc2 = scores[s + 2];
2395 float sc3 = scores[s + 3];
2396 acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
2397 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
2398 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
2399 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
2400 }
2401 for (uint s = seq_len4; s < seq_len; s++) {
2402 acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
2403 }
2404 *(device float4*)(output_batch + out_off + d) = acc;
2405 }
2406 // Handle remaining dimensions not divisible by 4 (scalar fallback)
2407 for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
2408 float acc = 0.0;
2409 uint v_base = kv_head * head_dim + d;
2410 for (uint s = 0; s < seq_len; s++) {
2411 acc += scores[s] * v_cache[s * v_stride + v_base];
2412 }
2413 output_batch[out_off + d] = acc;
2414 }
2415}
2416
2417// ── attention_flash_batch ─────────────────────────────────────────────
2418// Streaming attention with online softmax. Same grid as attention_batch
2419// (M × num_heads threadgroups, one per (token, head) pair) but the scores
2420// matrix is never materialized. K/V positions are processed in a tile of
2421// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
2422// the standard flash-attention recurrence:
2423//
2424// m_new = max(m_old, tile_max)
2425// alpha = exp(m_old - m_new)
2426// l_new = alpha * l_old + sum(exp(S - m_new))
2427// O_new = alpha * O_old + sum(exp(S - m_new) * V)
2428// O_final = O / l
2429//
2430// This removes the `scores[2048]` cap in attention_batch (which silently
2431// overflows for prompts with seq_len > 2048) and keeps threadgroup memory
2432// use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
2433//
2434// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
2435constant constexpr uint FLASH_K_TILE = 32;
2436constant constexpr uint FLASH_MAX_HEAD_DIM = 256;
2437
2438kernel void attention_flash_batch(
2439 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
2440 device const float* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim]
2441 device const float* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim]
2442 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
2443 constant uint& M [[buffer(4)]],
2444 constant uint& base_pos [[buffer(5)]],
2445 constant uint& num_heads [[buffer(6)]],
2446 constant uint& num_kv_heads [[buffer(7)]],
2447 constant uint& head_dim [[buffer(8)]],
2448 constant uint& q_stride [[buffer(9)]],
2449 uint tgid [[threadgroup_position_in_grid]],
2450 uint tid [[thread_index_in_threadgroup]],
2451 uint simd_lane [[thread_index_in_simdgroup]],
2452 uint simd_id [[simdgroup_index_in_threadgroup]])
2453{
2454 uint token_idx = tgid / num_heads;
2455 uint head = tgid % num_heads;
2456 if (token_idx >= M) return;
2457
2458 uint kv_head = head / (num_heads / num_kv_heads);
2459 uint seq_len = base_pos + token_idx + 1; // causal: attend to [0, base_pos + token_idx]
2460 uint q_off = token_idx * q_stride + head * head_dim;
2461 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
2462
2463 // Threadgroup state:
2464 // q_sh: Q vector for this (token, head), loaded once
2465 // o_sh: running output vector, updated each K tile
2466 // scores_sh: scores for the current K tile only
2467 // stats: [running max, running sum] (see flash-attention recurrence)
2468 threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
2469 threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
2470 threadgroup float scores_sh[FLASH_K_TILE];
2471 threadgroup float stats[2];
2472 threadgroup float sg_scratch[8]; // simdgroup-level reduction buffer
2473
2474 // --- Load Q (one row) and zero the running O ---
2475 for (uint d = tid; d < head_dim; d += 256) {
2476 q_sh[d] = q_batch[q_off + d];
2477 o_sh[d] = 0.0f;
2478 }
2479 if (tid == 0) {
2480 stats[0] = -INFINITY;
2481 stats[1] = 0.0f;
2482 }
2483 threadgroup_barrier(mem_flags::mem_threadgroup);
2484
2485 float scale = fast::rsqrt(float(head_dim));
2486 uint v_stride = num_kv_heads * head_dim;
2487 uint v_base = kv_head * head_dim;
2488
2489 // --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
2490 for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
2491 uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
2492
2493 // [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
2494 // 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
2495 for (uint ti = simd_id; ti < tile_n; ti += 8) {
2496 uint k_off = (kv_base + ti) * v_stride + v_base; // same layout as V stride
2497 float dot = 0.0f;
2498 for (uint d = simd_lane; d < head_dim; d += 32) {
2499 dot += q_sh[d] * k_cache[k_off + d];
2500 }
2501 dot = simd_sum(dot);
2502 if (simd_lane == 0) {
2503 scores_sh[ti] = dot * scale;
2504 }
2505 }
2506 threadgroup_barrier(mem_flags::mem_threadgroup);
2507
2508 // [2] Tile max via cooperative reduction.
2509 float local_max = -INFINITY;
2510 for (uint s = tid; s < tile_n; s += 256) {
2511 local_max = max(local_max, scores_sh[s]);
2512 }
2513 local_max = simd_max(local_max);
2514 if (simd_lane == 0) {
2515 sg_scratch[simd_id] = local_max;
2516 }
2517 threadgroup_barrier(mem_flags::mem_threadgroup);
2518 // [3] Merge with running max, compute alpha, rescale running l.
2519 float m_new;
2520 float alpha;
2521 if (tid == 0) {
2522 float tile_max = sg_scratch[0];
2523 for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
2524 float m_old = stats[0];
2525 m_new = max(m_old, tile_max);
2526 // First iteration: m_old = -inf → alpha = 0 (reset O).
2527 alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
2528 stats[0] = m_new;
2529 stats[1] *= alpha;
2530 // Broadcast via sg_scratch.
2531 sg_scratch[0] = alpha;
2532 sg_scratch[1] = m_new;
2533 }
2534 threadgroup_barrier(mem_flags::mem_threadgroup);
2535 alpha = sg_scratch[0];
2536 m_new = sg_scratch[1];
2537
2538 // [4] Rescale running output by alpha, then compute exp(scores - m_new).
2539 for (uint d = tid; d < head_dim; d += 256) {
2540 o_sh[d] *= alpha;
2541 }
2542 for (uint s = tid; s < tile_n; s += 256) {
2543 scores_sh[s] = fast::exp(scores_sh[s] - m_new);
2544 }
2545 threadgroup_barrier(mem_flags::mem_threadgroup);
2546
2547 // [5] Tile sum → update running l.
2548 float local_sum = 0.0f;
2549 for (uint s = tid; s < tile_n; s += 256) {
2550 local_sum += scores_sh[s];
2551 }
2552 local_sum = simd_sum(local_sum);
2553 if (simd_lane == 0) {
2554 sg_scratch[simd_id] = local_sum;
2555 }
2556 threadgroup_barrier(mem_flags::mem_threadgroup);
2557 if (tid == 0) {
2558 float tile_sum = 0.0f;
2559 for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
2560 stats[1] += tile_sum;
2561 }
2562 threadgroup_barrier(mem_flags::mem_threadgroup);
2563
2564 // [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
2565 for (uint d = tid; d < head_dim; d += 256) {
2566 float acc = 0.0f;
2567 for (uint s = 0; s < tile_n; s++) {
2568 acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
2569 }
2570 o_sh[d] += acc;
2571 }
2572 threadgroup_barrier(mem_flags::mem_threadgroup);
2573 }
2574
2575 // --- Normalize and write output ---
2576 float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
2577 for (uint d = tid; d < head_dim; d += 256) {
2578 output_batch[out_off + d] = o_sh[d] * inv_l;
2579 }
2580}
2581
2582// ── rope_qk_batch ─────────────────────────────────────────────────────
2583// Fused RoPE for both Q and K in a single dispatch, saving one kernel
2584// launch + memory barrier per layer. Both Q and K live in the same
2585// qkv_data buffer at different offsets within each token's row.
2586// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
2587kernel void rope_qk_batch(
2588 device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
2589 constant uint& M [[buffer(1)]], // num tokens
2590 constant uint& base_pos [[buffer(2)]], // starting position
2591 constant uint& num_q_heads [[buffer(3)]],
2592 constant uint& num_kv_heads [[buffer(4)]],
2593 constant uint& head_dim [[buffer(5)]],
2594 constant uint& qkv_stride [[buffer(6)]], // floats per row
2595 constant float& theta [[buffer(7)]],
2596 uint id [[thread_position_in_grid]])
2597{
2598 uint half_dim = head_dim / 2;
2599 uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
2600 uint token = id / total_pairs;
2601 uint pair = id % total_pairs;
2602 if (token >= M) return;
2603
2604 uint pos = base_pos + token;
2605 uint q_pairs = num_q_heads * half_dim;
2606
2607 uint h, i, offset;
2608 if (pair < q_pairs) {
2609 // Q head
2610 h = pair / half_dim;
2611 i = pair % half_dim;
2612 offset = token * qkv_stride + h * head_dim + i * 2;
2613 } else {
2614 // K head
2615 uint kp = pair - q_pairs;
2616 h = kp / half_dim;
2617 i = kp % half_dim;
2618 uint k_start = num_q_heads * head_dim;
2619 offset = token * qkv_stride + k_start + h * head_dim + i * 2;
2620 }
2621
2622 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
2623 float angle = float(pos) * freq;
2624 float cos_val = cos(angle);
2625 float sin_val = sin(angle);
2626
2627 float x0 = qkv_data[offset];
2628 float x1 = qkv_data[offset + 1];
2629 qkv_data[offset] = x0 * cos_val - x1 * sin_val;
2630 qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
2631}
2632
2633// ── copy_kv_both_batch ────────────────────────────────────────────────
2634// Fused K+V cache copy in a single dispatch: copies both K and V from
2635// the strided batch QKV buffer to their respective KV cache buffers.
2636// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
2637kernel void copy_kv_both_batch(
2638 device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride]
2639 device float* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim]
2640 device float* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim]
2641 constant uint& M [[buffer(3)]], // num tokens in batch
2642 constant uint& kv_dim [[buffer(4)]], // floats per KV vector
2643 constant uint& base_pos [[buffer(5)]], // starting position in cache
2644 constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
2645 constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
2646 constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
2647 uint id [[thread_position_in_grid]])
2648{
2649 // Total elements = M * kv_dim * 2 (K + V)
2650 uint total_kv = M * kv_dim;
2651 if (id >= total_kv * 2) return;
2652
2653 uint is_v = id / total_kv; // 0 = K, 1 = V
2654 uint local_id = id % total_kv;
2655 uint token = local_id / kv_dim;
2656 uint d = local_id % kv_dim;
2657
2658 uint dst_off = (base_pos + token) * kv_dim + d;
2659 uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
2660
2661 if (is_v) {
2662 v_dst[dst_off] = src[src_off];
2663 } else {
2664 k_dst[dst_off] = src[src_off];
2665 }
2666}
2667"#
2668 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
2669}
2670
2671fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
2676 let mut code = String::with_capacity(48 * 1024);
2677 emit_model_header(&mut code, config)?;
2678 emit_metal_model_struct(&mut code, config)?;
2679 emit_layer_buffers_struct(&mut code, config)?;
2680 emit_metal_model_impl(&mut code, config)?;
2681 emit_helper_functions(&mut code)?;
2682 Ok(code)
2683}
2684
2685fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2686 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
2687 writeln!(
2688 code,
2689 "//! Model: {} ({} layers, hidden={})",
2690 config.architecture, config.num_layers, config.hidden_size
2691 )?;
2692 writeln!(code, "//!")?;
2693 writeln!(
2694 code,
2695 "//! Uses native Metal compute pipelines via the metal crate."
2696 )?;
2697 writeln!(code)?;
2698 writeln!(code, "#![allow(dead_code)]")?;
2699 writeln!(code)?;
2700 writeln!(code, "use metal::*;")?;
2701 writeln!(code, "#[allow(unused_imports)]")?;
2702 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
2703 writeln!(code, "use std::mem;")?;
2704 writeln!(code)?;
2705
2706 writeln!(
2708 code,
2709 "// ── Model constants ──────────────────────────────────"
2710 )?;
2711 writeln!(
2712 code,
2713 "pub const HIDDEN_SIZE: usize = {};",
2714 config.hidden_size
2715 )?;
2716 writeln!(
2717 code,
2718 "pub const INTERMEDIATE_SIZE: usize = {};",
2719 config.intermediate_size
2720 )?;
2721 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
2722 writeln!(
2723 code,
2724 "pub const NUM_HEADS: usize = {};",
2725 config.num_attention_heads
2726 )?;
2727 writeln!(
2728 code,
2729 "pub const NUM_KV_HEADS: usize = {};",
2730 config.num_kv_heads
2731 )?;
2732 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
2733 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
2734 let effective_seq_len = config.max_seq_len.min(4096);
2735 writeln!(
2736 code,
2737 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
2738 effective_seq_len, config.max_seq_len
2739 )?;
2740 writeln!(
2741 code,
2742 "pub const RMS_NORM_EPS: f32 = {:e};",
2743 config.rms_norm_eps
2744 )?;
2745 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
2746 writeln!(
2747 code,
2748 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
2749 )?;
2750 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
2751 writeln!(code)?;
2752
2753 Ok(())
2754}
2755
2756fn emit_metal_model_struct(
2757 code: &mut String,
2758 config: &ModelConfig,
2759) -> Result<(), MetalCodegenError> {
2760 writeln!(
2761 code,
2762 "// ── MetalModel ──────────────────────────────────────────"
2763 )?;
2764 writeln!(code)?;
2765 writeln!(
2766 code,
2767 "/// Metal-accelerated transformer model for Apple Silicon."
2768 )?;
2769 writeln!(code, "///")?;
2770 writeln!(
2771 code,
2772 "/// Uses unified memory for zero-copy weight access and native Metal"
2773 )?;
2774 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
2775 writeln!(code, "pub struct MetalModel {{")?;
2776 writeln!(code, " device: Device,")?;
2777 writeln!(code, " queue: CommandQueue,")?;
2778 writeln!(code)?;
2779 writeln!(code, " // ── Compute pipelines ──")?;
2780 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
2781 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
2782 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
2783 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
2784 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
2785 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
2786 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
2787 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
2788 writeln!(code, " add_pipeline: ComputePipelineState,")?;
2789 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
2790 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
2791 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
2792 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
2793 writeln!(code)?;
2794 writeln!(code, " // ── Batched prefill pipelines ──")?;
2795 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
2796 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
2797 writeln!(
2798 code,
2799 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
2800 )?;
2801 writeln!(
2802 code,
2803 " matmul_q8_mma_pipeline: ComputePipelineState,"
2804 )?;
2805 writeln!(
2806 code,
2807 " matmul_q8_mma32_pipeline: ComputePipelineState,"
2808 )?;
2809 writeln!(
2810 code,
2811 " matmul_q8_mma32_h_pipeline: ComputePipelineState,"
2812 )?;
2813 writeln!(
2814 code,
2815 " matmul_q8_mma32_h4_pipeline: ComputePipelineState,"
2816 )?;
2817 writeln!(
2818 code,
2819 " matmul_q8_mma32_hh4_pipeline: ComputePipelineState,"
2820 )?;
2821 if config.qkv_bias {
2822 writeln!(
2823 code,
2824 " add_bias_batch_pipeline: ComputePipelineState,"
2825 )?;
2826 }
2827 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
2828 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
2829 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
2830 writeln!(
2831 code,
2832 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
2833 )?;
2834 writeln!(
2835 code,
2836 " add_inplace_batch_pipeline: ComputePipelineState,"
2837 )?;
2838 writeln!(
2839 code,
2840 " copy_embedding_batch_pipeline: ComputePipelineState,"
2841 )?;
2842 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
2843 writeln!(
2844 code,
2845 " attention_flash_batch_pipeline: ComputePipelineState,"
2846 )?;
2847 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
2848 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
2849 writeln!(
2850 code,
2851 " copy_kv_both_batch_pipeline: ComputePipelineState,"
2852 )?;
2853 writeln!(code)?;
2854 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
2855 writeln!(
2856 code,
2857 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
2858 )?;
2859 writeln!(code, " embed_buf: Buffer,")?;
2860 writeln!(code)?;
2861 writeln!(code, " /// Per-layer weight buffers")?;
2862 writeln!(code, " layers: Vec<LayerBuffers>,")?;
2863 writeln!(code)?;
2864 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
2865 writeln!(code, " norm_buf: Buffer,")?;
2866 writeln!(code)?;
2867 writeln!(
2868 code,
2869 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
2870 )?;
2871 writeln!(code, " lm_head_buf: Buffer,")?;
2872 writeln!(code)?;
2873 writeln!(
2874 code,
2875 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
2876 )?;
2877 writeln!(code, " hidden_buf: Buffer,")?;
2878 writeln!(code, " residual_buf: Buffer,")?;
2879 writeln!(code, " normed_buf: Buffer,")?;
2880 writeln!(
2881 code,
2882 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
2883 )?;
2884 writeln!(code, " qkv_buf: Buffer,")?;
2885 writeln!(code, " attn_out_buf: Buffer,")?;
2886 writeln!(code, " attn_proj_buf: Buffer,")?;
2887 writeln!(
2888 code,
2889 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
2890 )?;
2891 writeln!(code, " gate_up_buf: Buffer,")?;
2892 writeln!(code, " ffn_hidden_buf: Buffer,")?;
2893 writeln!(code, " ffn_out_buf: Buffer,")?;
2894 writeln!(code, " add_tmp_buf: Buffer,")?;
2895 writeln!(code, " logits_buf: Buffer,")?;
2896 writeln!(code)?;
2897 writeln!(code, " // ── Batched prefill working buffers ──")?;
2898 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
2899 writeln!(code, " batch_hidden_buf: Buffer,")?;
2900 writeln!(
2901 code,
2902 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
2903 )?;
2904 writeln!(code, " batch_residual_buf: Buffer,")?;
2905 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2906 writeln!(code, " batch_qkv_buf: Buffer,")?;
2907 writeln!(
2908 code,
2909 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2910 )?;
2911 writeln!(code, " batch_attn_out_buf: Buffer,")?;
2912 writeln!(
2913 code,
2914 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2915 )?;
2916 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
2917 writeln!(
2918 code,
2919 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2920 )?;
2921 writeln!(code, " batch_gate_up_buf: Buffer,")?;
2922 writeln!(
2923 code,
2924 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2925 )?;
2926 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
2927 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2928 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
2929 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
2930 writeln!(code, " batch_tokens_buf: Buffer,")?;
2931 writeln!(code, " /// Positions buffer for batch RoPE")?;
2932 writeln!(code, " batch_positions_buf: Buffer,")?;
2933 writeln!(code)?;
2934 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
2935 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
2936 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
2937 writeln!(code)?;
2938 writeln!(code, " // ── Inference state ──")?;
2939 writeln!(code, " pos: usize,")?;
2940 writeln!(code)?;
2941 writeln!(
2942 code,
2943 " /// Previous command buffer for double-buffered prefill."
2944 )?;
2945 writeln!(
2946 code,
2947 " /// While the GPU executes token N, the CPU can encode token N+1."
2948 )?;
2949 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
2950 writeln!(code, "}}")?;
2951 writeln!(code)?;
2952
2953 Ok(())
2954}
2955
2956fn emit_layer_buffers_struct(
2957 code: &mut String,
2958 config: &ModelConfig,
2959) -> Result<(), MetalCodegenError> {
2960 writeln!(
2961 code,
2962 "/// Per-layer weight buffers for attention and FFN projections."
2963 )?;
2964 writeln!(code, "struct LayerBuffers {{")?;
2965 writeln!(code, " attn_norm: Buffer,")?;
2966 writeln!(
2967 code,
2968 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2969 )?;
2970 writeln!(code, " qkv_weight: Buffer,")?;
2971 if config.qkv_bias {
2972 writeln!(
2973 code,
2974 " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only."
2975 )?;
2976 writeln!(code, " qkv_bias: Buffer,")?;
2977 }
2978 writeln!(code, " o_weight: Buffer,")?;
2979 writeln!(code, " ffn_norm: Buffer,")?;
2980 writeln!(
2981 code,
2982 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2983 )?;
2984 writeln!(code, " gate_up_weight: Buffer,")?;
2985 writeln!(code, " down_weight: Buffer,")?;
2986 writeln!(code, "}}")?;
2987 writeln!(code)?;
2988
2989 Ok(())
2990}
2991
2992fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2993 let hidden = config.hidden_size;
2994 let intermediate = config.intermediate_size;
2995 let _num_layers = config.num_layers;
2996 let num_heads = config.num_attention_heads;
2997 let num_kv_heads = config.num_kv_heads;
2998 let head_dim = config.head_dim;
2999 let vocab = config.vocab_size;
3000 let effective_seq_len = config.max_seq_len.min(4096);
3001 let is_q8 = config.dtype == DType::Q8_0;
3002 let is_q4 = config.dtype == DType::Q4_0;
3003 let kv_dim = num_kv_heads * head_dim;
3004
3005 writeln!(code, "impl MetalModel {{")?;
3006
3007 writeln!(
3009 code,
3010 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
3011 )?;
3012 writeln!(code, " ///")?;
3013 writeln!(
3014 code,
3015 " /// `weights` is the raw weight blob produced by `forge export-weights`."
3016 )?;
3017 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
3018 writeln!(
3019 code,
3020 " let device = Device::system_default().expect(\"no Metal device found\");"
3021 )?;
3022 writeln!(code, " let queue = device.new_command_queue();")?;
3023 writeln!(code)?;
3024
3025 writeln!(
3027 code,
3028 " // Compile Metal shaders from embedded source"
3029 )?;
3030 writeln!(
3031 code,
3032 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
3033 )?;
3034 writeln!(
3035 code,
3036 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
3037 )?;
3038 writeln!(
3039 code,
3040 " .expect(\"failed to compile Metal shaders\");"
3041 )?;
3042 writeln!(code)?;
3043
3044 writeln!(code, " // Create compute pipelines")?;
3046 for (var, fn_name) in [
3047 ("matmul_pipeline", "matmul_vec"),
3048 ("matmul_q8_pipeline", "matmul_vec_q8"),
3049 ("matmul_q4_pipeline", "matmul_vec_q4"),
3050 ("rms_norm_pipeline", "rms_norm"),
3051 ("rope_pipeline", "rope"),
3052 ("softmax_pipeline", "softmax"),
3053 ("silu_mul_pipeline", "silu_mul"),
3054 ("silu_mul_fused_pipeline", "silu_mul_fused"),
3055 ("add_pipeline", "elementwise_add"),
3056 ("attention_pipeline", "attention"),
3057 ("add_inplace_pipeline", "add_inplace"),
3058 ("copy_pipeline", "copy_buffer"),
3059 ("copy_offset_pipeline", "copy_offset"),
3060 ("matmul_batch_pipeline", "matmul_vec_batch"),
3061 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
3062 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
3063 ("matmul_q8_mma_pipeline", "matmul_q8_mma"),
3064 ("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
3065 ("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
3066 ("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
3067 ("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
3068 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
3069 ("rms_norm_batch_pipeline", "rms_norm_batch"),
3070 ("rope_batch_pipeline", "rope_batch"),
3071 ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
3072 ("add_inplace_batch_pipeline", "add_inplace_batch"),
3073 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
3074 ("attention_batch_pipeline", "attention_batch"),
3075 ("attention_flash_batch_pipeline", "attention_flash_batch"),
3076 ("copy_kv_batch_pipeline", "copy_kv_batch"),
3077 ("rope_qk_batch_pipeline", "rope_qk_batch"),
3078 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
3079 ] {
3080 writeln!(
3081 code,
3082 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
3083 )?;
3084 }
3085 if config.qkv_bias {
3086 writeln!(
3087 code,
3088 " let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
3089 )?;
3090 }
3091 writeln!(code)?;
3092
3093 writeln!(
3095 code,
3096 " // Load weights into Metal shared-memory buffers"
3097 )?;
3098 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
3099 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
3100 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
3101 writeln!(code)?;
3102 writeln!(
3103 code,
3104 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
3105 )?;
3106 writeln!(code)?;
3107 writeln!(
3108 code,
3109 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
3110 )?;
3111 writeln!(
3112 code,
3113 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
3114 )?;
3115 writeln!(code, " let byte_len = n * f32_size;")?;
3116 writeln!(code, " let cur = cursor.get();")?;
3117 writeln!(
3118 code,
3119 " let data = &weights[cur..cur + byte_len];"
3120 )?;
3121 writeln!(code, " cursor.set(cur + byte_len);")?;
3122 writeln!(code, " device.new_buffer_with_data(")?;
3123 writeln!(code, " data.as_ptr() as *const _,")?;
3124 writeln!(code, " byte_len as u64,")?;
3125 writeln!(
3126 code,
3127 " MTLResourceOptions::StorageModeShared,"
3128 )?;
3129 writeln!(code, " )")?;
3130 writeln!(code, " }};")?;
3131 writeln!(code)?;
3132
3133 if is_q8 {
3134 writeln!(
3139 code,
3140 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
3141 )?;
3142 writeln!(
3143 code,
3144 " // as raw bytes into a Metal buffer (no dequantization)."
3145 )?;
3146 writeln!(
3147 code,
3148 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
3149 )?;
3150 writeln!(
3151 code,
3152 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3153 )?;
3154 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3155 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3156 writeln!(code, " let total_raw = rows * row_bytes;")?;
3157 writeln!(code, " let cur = cursor.get();")?;
3158 writeln!(
3159 code,
3160 " let data = &weights[cur..cur + total_raw];"
3161 )?;
3162 writeln!(code, " cursor.set(cur + total_raw);")?;
3163 writeln!(code, " device.new_buffer_with_data(")?;
3164 writeln!(code, " data.as_ptr() as *const _,")?;
3165 writeln!(code, " total_raw as u64,")?;
3166 writeln!(
3167 code,
3168 " MTLResourceOptions::StorageModeShared,"
3169 )?;
3170 writeln!(code, " )")?;
3171 writeln!(code, " }};")?;
3172 writeln!(code)?;
3173 writeln!(
3174 code,
3175 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
3176 )?;
3177 writeln!(
3178 code,
3179 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
3180 )?;
3181 writeln!(
3182 code,
3183 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3184 )?;
3185 writeln!(
3186 code,
3187 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3188 )?;
3189 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3190 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
3191 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3192 writeln!(code, " let cur = cursor.get();")?;
3193 writeln!(
3194 code,
3195 " let data = &weights[cur..cur + total_raw];"
3196 )?;
3197 writeln!(code, " cursor.set(cur + total_raw);")?;
3198 writeln!(code, " device.new_buffer_with_data(")?;
3199 writeln!(code, " data.as_ptr() as *const _,")?;
3200 writeln!(code, " total_raw as u64,")?;
3201 writeln!(
3202 code,
3203 " MTLResourceOptions::StorageModeShared,"
3204 )?;
3205 writeln!(code, " )")?;
3206 writeln!(code, " }};")?;
3207 writeln!(code)?;
3208 }
3209
3210 if is_q4 {
3211 writeln!(
3216 code,
3217 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
3218 )?;
3219 writeln!(
3220 code,
3221 " // as raw bytes into a Metal buffer (no dequantization)."
3222 )?;
3223 writeln!(
3224 code,
3225 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
3226 )?;
3227 writeln!(
3228 code,
3229 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
3230 )?;
3231 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3232 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3233 writeln!(code, " let total_raw = rows * row_bytes;")?;
3234 writeln!(code, " let cur = cursor.get();")?;
3235 writeln!(
3236 code,
3237 " let data = &weights[cur..cur + total_raw];"
3238 )?;
3239 writeln!(code, " cursor.set(cur + total_raw);")?;
3240 writeln!(code, " device.new_buffer_with_data(")?;
3241 writeln!(code, " data.as_ptr() as *const _,")?;
3242 writeln!(code, " total_raw as u64,")?;
3243 writeln!(
3244 code,
3245 " MTLResourceOptions::StorageModeShared,"
3246 )?;
3247 writeln!(code, " )")?;
3248 writeln!(code, " }};")?;
3249 writeln!(code)?;
3250 writeln!(
3251 code,
3252 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
3253 )?;
3254 writeln!(
3255 code,
3256 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
3257 )?;
3258 writeln!(
3259 code,
3260 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
3261 )?;
3262 writeln!(
3263 code,
3264 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
3265 )?;
3266 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
3267 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
3268 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
3269 writeln!(code, " let cur = cursor.get();")?;
3270 writeln!(
3271 code,
3272 " let data = &weights[cur..cur + total_raw];"
3273 )?;
3274 writeln!(code, " cursor.set(cur + total_raw);")?;
3275 writeln!(code, " device.new_buffer_with_data(")?;
3276 writeln!(code, " data.as_ptr() as *const _,")?;
3277 writeln!(code, " total_raw as u64,")?;
3278 writeln!(
3279 code,
3280 " MTLResourceOptions::StorageModeShared,"
3281 )?;
3282 writeln!(code, " )")?;
3283 writeln!(code, " }};")?;
3284 writeln!(code)?;
3285 }
3286
3287 writeln!(
3288 code,
3289 " let embed_buf = next_f32_buffer(&device, embed_elems);"
3290 )?;
3291 writeln!(code)?;
3292
3293 writeln!(
3295 code,
3296 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
3297 )?;
3298 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
3299
3300 writeln!(
3302 code,
3303 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
3304 )?;
3305
3306 let qkv_rows = hidden + 2 * kv_dim;
3307 if is_q8 {
3308 writeln!(
3310 code,
3311 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
3312 )?;
3313 if config.qkv_bias {
3314 writeln!(
3315 code,
3316 " // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
3317 )?;
3318 writeln!(
3319 code,
3320 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3321 )?;
3322 }
3323 writeln!(
3324 code,
3325 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
3326 )?;
3327 } else if is_q4 {
3328 writeln!(
3329 code,
3330 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
3331 )?;
3332 if config.qkv_bias {
3333 writeln!(
3334 code,
3335 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3336 )?;
3337 }
3338 writeln!(
3339 code,
3340 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
3341 )?;
3342 } else {
3343 writeln!(
3344 code,
3345 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
3346 )?;
3347 if config.qkv_bias {
3348 writeln!(
3349 code,
3350 " let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
3351 )?;
3352 }
3353 writeln!(
3354 code,
3355 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
3356 )?;
3357 }
3358
3359 writeln!(
3361 code,
3362 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
3363 )?;
3364
3365 let gate_up_rows = 2 * intermediate;
3366 if is_q8 {
3367 writeln!(
3369 code,
3370 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
3371 )?;
3372 writeln!(
3373 code,
3374 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
3375 )?;
3376 } else if is_q4 {
3377 writeln!(
3379 code,
3380 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
3381 )?;
3382 writeln!(
3383 code,
3384 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
3385 )?;
3386 } else {
3387 writeln!(
3389 code,
3390 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
3391 )?;
3392 writeln!(
3393 code,
3394 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
3395 )?;
3396 }
3397
3398 writeln!(code, " layers.push(LayerBuffers {{")?;
3399 writeln!(code, " attn_norm,")?;
3400 writeln!(code, " qkv_weight,")?;
3401 if config.qkv_bias {
3402 writeln!(code, " qkv_bias,")?;
3403 }
3404 writeln!(code, " o_weight,")?;
3405 writeln!(code, " ffn_norm,")?;
3406 writeln!(code, " gate_up_weight,")?;
3407 writeln!(code, " down_weight,")?;
3408 writeln!(code, " }});")?;
3409 writeln!(code, " }}")?;
3410 writeln!(code)?;
3411
3412 writeln!(
3414 code,
3415 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
3416 )?;
3417 writeln!(code)?;
3418
3419 if is_q8 {
3421 writeln!(
3422 code,
3423 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
3424 )?;
3425 } else if is_q4 {
3426 writeln!(
3427 code,
3428 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
3429 )?;
3430 } else {
3431 writeln!(
3432 code,
3433 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
3434 )?;
3435 }
3436 writeln!(code)?;
3437
3438 let hidden_bytes = hidden * 4;
3440 let _kv_dim_bytes = kv_dim * 4;
3441 let intermediate_bytes = intermediate * 4;
3442 let vocab_bytes = vocab * 4;
3443 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
3444
3445 writeln!(
3446 code,
3447 " // Allocate working buffers (shared memory for zero-copy)"
3448 )?;
3449 writeln!(
3450 code,
3451 " let opts = MTLResourceOptions::StorageModeShared;"
3452 )?;
3453 writeln!(
3454 code,
3455 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3456 )?;
3457 writeln!(
3458 code,
3459 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3460 )?;
3461 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
3462 writeln!(
3463 code,
3464 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3465 )?;
3466 writeln!(
3467 code,
3468 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
3469 )?;
3470 writeln!(
3471 code,
3472 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
3473 )?;
3474 writeln!(
3475 code,
3476 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3477 )?;
3478 writeln!(
3479 code,
3480 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3481 )?;
3482 let gate_up_buf_bytes = 2 * intermediate * 4;
3483 writeln!(
3484 code,
3485 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
3486 )?;
3487 writeln!(
3488 code,
3489 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
3490 )?;
3491 writeln!(
3492 code,
3493 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
3494 )?;
3495 writeln!(
3496 code,
3497 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3498 )?;
3499 writeln!(
3500 code,
3501 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
3502 )?;
3503 writeln!(
3504 code,
3505 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
3506 )?;
3507 writeln!(code)?;
3508
3509 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
3512 let batch_gate_up_bytes = 2 * intermediate * 4;
3513 let batch_intermediate_bytes = intermediate * 4;
3514 writeln!(
3515 code,
3516 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
3517 )?;
3518 writeln!(
3519 code,
3520 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3521 )?;
3522 writeln!(
3523 code,
3524 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3525 )?;
3526 writeln!(
3527 code,
3528 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
3529 )?;
3530 writeln!(
3531 code,
3532 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3533 )?;
3534 writeln!(
3535 code,
3536 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3537 )?;
3538 writeln!(
3539 code,
3540 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
3541 )?;
3542 writeln!(
3543 code,
3544 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
3545 )?;
3546 writeln!(
3547 code,
3548 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
3549 )?;
3550 writeln!(
3551 code,
3552 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3553 )?;
3554 writeln!(
3555 code,
3556 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
3557 )?;
3558 writeln!(code)?;
3559
3560 writeln!(code, " // KV cache buffers (per-layer)")?;
3562 writeln!(
3563 code,
3564 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3565 )?;
3566 writeln!(
3567 code,
3568 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
3569 )?;
3570 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
3571 writeln!(
3572 code,
3573 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3574 )?;
3575 writeln!(
3576 code,
3577 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
3578 )?;
3579 writeln!(code, " }}")?;
3580 writeln!(code)?;
3581
3582 writeln!(code, " Self {{")?;
3583 writeln!(code, " device,")?;
3584 writeln!(code, " queue,")?;
3585 writeln!(code, " matmul_pipeline,")?;
3586 writeln!(code, " matmul_q8_pipeline,")?;
3587 writeln!(code, " matmul_q4_pipeline,")?;
3588 writeln!(code, " rms_norm_pipeline,")?;
3589 writeln!(code, " rope_pipeline,")?;
3590 writeln!(code, " softmax_pipeline,")?;
3591 writeln!(code, " silu_mul_pipeline,")?;
3592 writeln!(code, " silu_mul_fused_pipeline,")?;
3593 writeln!(code, " add_pipeline,")?;
3594 writeln!(code, " attention_pipeline,")?;
3595 writeln!(code, " add_inplace_pipeline,")?;
3596 writeln!(code, " copy_pipeline,")?;
3597 writeln!(code, " copy_offset_pipeline,")?;
3598 writeln!(code, " matmul_batch_pipeline,")?;
3599 writeln!(code, " matmul_q8_batch_pipeline,")?;
3600 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
3601 writeln!(code, " matmul_q8_mma_pipeline,")?;
3602 writeln!(code, " matmul_q8_mma32_pipeline,")?;
3603 writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
3604 writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
3605 writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
3606 if config.qkv_bias {
3607 writeln!(code, " add_bias_batch_pipeline,")?;
3608 }
3609 writeln!(code, " matmul_q4_batch_pipeline,")?;
3610 writeln!(code, " rms_norm_batch_pipeline,")?;
3611 writeln!(code, " rope_batch_pipeline,")?;
3612 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
3613 writeln!(code, " add_inplace_batch_pipeline,")?;
3614 writeln!(code, " copy_embedding_batch_pipeline,")?;
3615 writeln!(code, " attention_batch_pipeline,")?;
3616 writeln!(code, " attention_flash_batch_pipeline,")?;
3617 writeln!(code, " copy_kv_batch_pipeline,")?;
3618 writeln!(code, " rope_qk_batch_pipeline,")?;
3619 writeln!(code, " copy_kv_both_batch_pipeline,")?;
3620 writeln!(code, " embed_buf,")?;
3621 writeln!(code, " layers,")?;
3622 writeln!(code, " norm_buf,")?;
3623 writeln!(code, " lm_head_buf,")?;
3624 writeln!(code, " hidden_buf,")?;
3625 writeln!(code, " residual_buf,")?;
3626 writeln!(code, " normed_buf,")?;
3627 writeln!(code, " qkv_buf,")?;
3628 writeln!(code, " attn_out_buf,")?;
3629 writeln!(code, " attn_proj_buf,")?;
3630 writeln!(code, " gate_up_buf,")?;
3631 writeln!(code, " ffn_hidden_buf,")?;
3632 writeln!(code, " ffn_out_buf,")?;
3633 writeln!(code, " add_tmp_buf,")?;
3634 writeln!(code, " logits_buf,")?;
3635 writeln!(code, " batch_hidden_buf,")?;
3636 writeln!(code, " batch_residual_buf,")?;
3637 writeln!(code, " batch_qkv_buf,")?;
3638 writeln!(code, " batch_attn_out_buf,")?;
3639 writeln!(code, " batch_attn_proj_buf,")?;
3640 writeln!(code, " batch_gate_up_buf,")?;
3641 writeln!(code, " batch_ffn_hidden_buf,")?;
3642 writeln!(code, " batch_ffn_out_buf,")?;
3643 writeln!(code, " batch_tokens_buf,")?;
3644 writeln!(code, " batch_positions_buf,")?;
3645 writeln!(code, " k_cache,")?;
3646 writeln!(code, " v_cache,")?;
3647 writeln!(code, " pos: 0,")?;
3648 writeln!(code, " prev_cmd: None,")?;
3649 writeln!(code, " }}")?;
3650 writeln!(code, " }}")?;
3651 writeln!(code)?;
3652
3653 writeln!(
3655 code,
3656 " /// Run the forward pass for a single token at the current position."
3657 )?;
3658 writeln!(code, " ///")?;
3659 writeln!(
3660 code,
3661 " /// Returns logits over the vocabulary as a `Vec<f32>`."
3662 )?;
3663 writeln!(code, " ///")?;
3664 writeln!(
3665 code,
3666 " /// All GPU operations are encoded into a single command buffer and"
3667 )?;
3668 writeln!(
3669 code,
3670 " /// committed once at the end, avoiding per-operation synchronization."
3671 )?;
3672 writeln!(
3673 code,
3674 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
3675 )?;
3676 writeln!(
3677 code,
3678 " // Wait for any pending prefill command buffer"
3679 )?;
3680 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3681 writeln!(code, " prev.wait_until_completed();")?;
3682 writeln!(code, " }}")?;
3683 writeln!(code)?;
3684 writeln!(code, " let pos = self.pos;")?;
3685 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3686 writeln!(code)?;
3687
3688 let matmul_fn = if is_q8 {
3691 "dispatch_matmul_q8"
3692 } else if is_q4 {
3693 "dispatch_matmul_q4"
3694 } else {
3695 "dispatch_matmul"
3696 };
3697
3698 writeln!(
3699 code,
3700 " // Single compute encoder for the entire forward pass (no blit transitions)"
3701 )?;
3702 writeln!(code, " {{")?;
3703 writeln!(
3704 code,
3705 " let enc = cmd.new_compute_command_encoder();"
3706 )?;
3707 writeln!(code)?;
3708
3709 writeln!(
3711 code,
3712 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
3713 )?;
3714 writeln!(
3715 code,
3716 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
3717 )?;
3718 writeln!(
3719 code,
3720 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
3721 )?;
3722 writeln!(
3723 code,
3724 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
3725 hidden * 4,
3726 )?;
3727 writeln!(code, " unsafe {{")?;
3728 writeln!(
3729 code,
3730 " let embed_ptr = self.embed_buf.contents() as *const f32;"
3731 )?;
3732 writeln!(
3733 code,
3734 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3735 )?;
3736 writeln!(
3737 code,
3738 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
3739 )?;
3740 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
3741 writeln!(
3742 code,
3743 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3744 )?;
3745 writeln!(code, " hidden_ptr,")?;
3746 writeln!(code, " HIDDEN_SIZE,")?;
3747 writeln!(code, " );")?;
3748 writeln!(
3749 code,
3750 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3751 )?;
3752 writeln!(code, " }}")?;
3753 writeln!(code)?;
3754
3755 writeln!(code, " // 2. Transformer layers")?;
3757 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
3758 writeln!(code)?;
3759 let q_byte_offset = 0usize;
3760 let k_byte_offset = hidden * 4;
3761 let v_byte_offset = (hidden + kv_dim) * 4;
3762
3763 writeln!(
3764 code,
3765 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
3766 )?;
3767 writeln!(
3768 code,
3769 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3770 )?;
3771 writeln!(
3772 code,
3773 " // Fused Q+K+V matmul: single dispatch for all three projections"
3774 )?;
3775 writeln!(
3776 code,
3777 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3778 )?;
3779 if config.qkv_bias {
3780 writeln!(
3781 code,
3782 " // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
3783 )?;
3784 writeln!(
3785 code,
3786 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3787 )?;
3788 }
3789 writeln!(
3790 code,
3791 " // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
3792 )?;
3793 writeln!(
3794 code,
3795 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3796 )?;
3797 writeln!(
3798 code,
3799 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3800 )?;
3801 writeln!(code)?;
3802 writeln!(
3803 code,
3804 " // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
3805 )?;
3806 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
3807 writeln!(
3808 code,
3809 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3810 )?;
3811 writeln!(
3812 code,
3813 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3814 )?;
3815 writeln!(code)?;
3816 writeln!(
3817 code,
3818 " // Attention using Q from qkv_buf (offset 0)"
3819 )?;
3820 writeln!(
3821 code,
3822 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3823 )?;
3824 writeln!(
3825 code,
3826 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3827 )?;
3828 writeln!(
3829 code,
3830 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3831 )?;
3832 writeln!(
3833 code,
3834 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
3835 )?;
3836 writeln!(
3837 code,
3838 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3839 )?;
3840 writeln!(
3841 code,
3842 " // Fused gate+up matmul: single dispatch for both projections"
3843 )?;
3844 writeln!(
3845 code,
3846 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3847 )?;
3848 writeln!(
3849 code,
3850 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3851 )?;
3852 writeln!(
3853 code,
3854 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3855 )?;
3856 writeln!(
3857 code,
3858 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3859 )?;
3860 writeln!(code, " }}")?;
3861 writeln!(code)?;
3862
3863 writeln!(code, " // 3. Final RMS norm + logits projection")?;
3865 writeln!(
3866 code,
3867 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
3868 )?;
3869 writeln!(
3870 code,
3871 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3872 )?;
3873 writeln!(code)?;
3874 writeln!(code, " enc.end_encoding();")?;
3875 writeln!(code, " }}")?;
3876 writeln!(code)?;
3877
3878 writeln!(
3880 code,
3881 " // 5. Commit all GPU work and wait for completion"
3882 )?;
3883 writeln!(code, " cmd.commit();")?;
3884 writeln!(code, " cmd.wait_until_completed();")?;
3885 writeln!(code)?;
3886 writeln!(code, " // 6. Read back logits from GPU")?;
3887 writeln!(code, " let logits = unsafe {{")?;
3888 writeln!(
3889 code,
3890 " let ptr = self.logits_buf.contents() as *const f32;"
3891 )?;
3892 writeln!(
3893 code,
3894 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3895 )?;
3896 writeln!(code, " }};")?;
3897 writeln!(code)?;
3898 writeln!(code, " self.pos += 1;")?;
3899 writeln!(code, " logits")?;
3900 writeln!(code, " }}")?;
3901 writeln!(code)?;
3902
3903 writeln!(
3905 code,
3906 " /// Profiling forward pass that prints per-stage GPU timing."
3907 )?;
3908 writeln!(code, " ///")?;
3909 writeln!(
3910 code,
3911 " /// Each stage is committed and waited on separately so that GPU timestamps"
3912 )?;
3913 writeln!(
3914 code,
3915 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
3916 )?;
3917 writeln!(
3918 code,
3919 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
3920 )?;
3921 writeln!(
3922 code,
3923 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
3924 )?;
3925 writeln!(code, " use std::time::Instant;")?;
3926 writeln!(code)?;
3927 writeln!(
3928 code,
3929 " // Wait for any pending prefill command buffer"
3930 )?;
3931 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3932 writeln!(code, " prev.wait_until_completed();")?;
3933 writeln!(code, " }}")?;
3934 writeln!(code)?;
3935 writeln!(code, " let pos = self.pos;")?;
3936 writeln!(code)?;
3937
3938 writeln!(
3940 code,
3941 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
3942 )?;
3943 writeln!(code, " let t_embed = Instant::now();")?;
3944 writeln!(code, " unsafe {{")?;
3945 writeln!(
3946 code,
3947 " let embed_ptr = self.embed_buf.contents() as *const f32;"
3948 )?;
3949 writeln!(
3950 code,
3951 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
3952 )?;
3953 writeln!(
3954 code,
3955 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
3956 )?;
3957 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
3958 writeln!(
3959 code,
3960 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
3961 )?;
3962 writeln!(code, " hidden_ptr,")?;
3963 writeln!(code, " HIDDEN_SIZE,")?;
3964 writeln!(code, " );")?;
3965 writeln!(
3966 code,
3967 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
3968 )?;
3969 writeln!(code, " }}")?;
3970 writeln!(code, " let d_embed = t_embed.elapsed();")?;
3971 writeln!(code)?;
3972
3973 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
3975 writeln!(code, " let t_layers = Instant::now();")?;
3976 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3977 writeln!(code, " {{")?;
3978 writeln!(
3979 code,
3980 " let enc = cmd.new_compute_command_encoder();"
3981 )?;
3982 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
3983 writeln!(
3984 code,
3985 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3986 )?;
3987 writeln!(
3988 code,
3989 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3990 )?;
3991 if config.qkv_bias {
3992 writeln!(
3993 code,
3994 " self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
3995 )?;
3996 }
3997 writeln!(
3998 code,
3999 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
4000 )?;
4001 writeln!(
4002 code,
4003 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
4004 )?;
4005 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
4006 writeln!(
4007 code,
4008 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
4009 )?;
4010 writeln!(
4011 code,
4012 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
4013 )?;
4014 writeln!(
4015 code,
4016 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
4017 )?;
4018 writeln!(
4019 code,
4020 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
4021 )?;
4022 writeln!(
4023 code,
4024 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
4025 )?;
4026 writeln!(
4027 code,
4028 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
4029 )?;
4030 writeln!(
4031 code,
4032 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
4033 )?;
4034 writeln!(
4035 code,
4036 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
4037 )?;
4038 writeln!(
4039 code,
4040 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
4041 )?;
4042 writeln!(
4043 code,
4044 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
4045 )?;
4046 writeln!(code, " }}")?;
4047 writeln!(code, " enc.end_encoding();")?;
4048 writeln!(code, " }}")?;
4049 writeln!(code, " cmd.commit();")?;
4050 writeln!(code, " cmd.wait_until_completed();")?;
4051 writeln!(code, " let d_layers = t_layers.elapsed();")?;
4052 writeln!(code)?;
4053
4054 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
4056 writeln!(code, " let t_logits = Instant::now();")?;
4057 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
4058 writeln!(code, " {{")?;
4059 writeln!(
4060 code,
4061 " let enc = cmd2.new_compute_command_encoder();"
4062 )?;
4063 writeln!(
4064 code,
4065 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
4066 )?;
4067 writeln!(
4068 code,
4069 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
4070 )?;
4071 writeln!(code, " enc.end_encoding();")?;
4072 writeln!(code, " }}")?;
4073 writeln!(code, " cmd2.commit();")?;
4074 writeln!(code, " cmd2.wait_until_completed();")?;
4075 writeln!(code, " let d_logits = t_logits.elapsed();")?;
4076 writeln!(code)?;
4077
4078 writeln!(
4080 code,
4081 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
4082 )?;
4083 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
4084 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
4085 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
4086 writeln!(
4087 code,
4088 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
4089 )?;
4090 writeln!(code)?;
4091
4092 writeln!(code, " let logits = unsafe {{")?;
4094 writeln!(
4095 code,
4096 " let ptr = self.logits_buf.contents() as *const f32;"
4097 )?;
4098 writeln!(
4099 code,
4100 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
4101 )?;
4102 writeln!(code, " }};")?;
4103 writeln!(code)?;
4104 writeln!(code, " self.pos += 1;")?;
4105 writeln!(code, " logits")?;
4106 writeln!(code, " }}")?;
4107 writeln!(code)?;
4108
4109 writeln!(
4111 code,
4112 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
4113 )?;
4114 writeln!(code, " ///")?;
4115 writeln!(
4116 code,
4117 " /// Commits the command buffer without waiting, enabling double-buffered"
4118 )?;
4119 writeln!(
4120 code,
4121 " /// execution: GPU processes token N while CPU encodes token N+1."
4122 )?;
4123 writeln!(
4124 code,
4125 " pub fn forward_prefill(&mut self, token_id: u32) {{"
4126 )?;
4127 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
4128 writeln!(code, " }}")?;
4129 writeln!(code)?;
4130
4131 let batch_matmul_fn = if is_q8 {
4134 "dispatch_matmul_q8_batch"
4135 } else if is_q4 {
4136 "dispatch_matmul_q4_batch"
4137 } else {
4138 "dispatch_matmul_batch"
4139 };
4140
4141 writeln!(
4142 code,
4143 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
4144 )?;
4145 writeln!(code, " ///")?;
4146 writeln!(
4147 code,
4148 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
4149 )?;
4150 writeln!(
4151 code,
4152 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
4153 )?;
4154 writeln!(
4155 code,
4156 " /// This provides significant speedup during prompt prefill."
4157 )?;
4158 writeln!(
4159 code,
4160 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
4161 )?;
4162 writeln!(code, " let m = tokens.len().min(MAX_BATCH_SIZE);")?;
4163 writeln!(code, " if m == 0 {{ return; }}")?;
4164 writeln!(code, " let start_pos = self.pos;")?;
4165 writeln!(code)?;
4166 writeln!(code, " // Wait for any pending command buffer")?;
4167 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
4168 writeln!(code, " prev.wait_until_completed();")?;
4169 writeln!(code, " }}")?;
4170 writeln!(code)?;
4171
4172 writeln!(
4174 code,
4175 " // Upload token IDs and positions to GPU buffers"
4176 )?;
4177 writeln!(code, " unsafe {{")?;
4178 writeln!(
4179 code,
4180 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
4181 )?;
4182 writeln!(
4183 code,
4184 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
4185 )?;
4186 writeln!(code, " for i in 0..m {{")?;
4187 writeln!(code, " *tok_ptr.add(i) = tokens[i];")?;
4188 writeln!(
4189 code,
4190 " *pos_ptr.add(i) = (start_pos + i) as u32;"
4191 )?;
4192 writeln!(code, " }}")?;
4193 writeln!(code, " }}")?;
4194 writeln!(code)?;
4195
4196 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
4197 writeln!(code, " {{")?;
4198 writeln!(
4199 code,
4200 " let enc = cmd.new_compute_command_encoder();"
4201 )?;
4202 writeln!(code)?;
4203
4204 writeln!(
4206 code,
4207 " // 1. Batch embedding lookup: copy all token embeddings at once"
4208 )?;
4209 writeln!(
4210 code,
4211 " self.dispatch_copy_embedding_batch(&enc, m);"
4212 )?;
4213 writeln!(
4215 code,
4216 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
4217 )?;
4218 writeln!(code)?;
4219
4220 writeln!(code, " // 2. Transformer layers")?;
4222 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
4223 writeln!(code)?;
4224
4225 writeln!(
4227 code,
4228 " // Batch RMS norm: batch_residual -> batch_hidden"
4229 )?;
4230 writeln!(
4231 code,
4232 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
4233 )?;
4234
4235 writeln!(
4237 code,
4238 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
4239 )?;
4240 writeln!(
4241 code,
4242 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
4243 )?;
4244 if config.qkv_bias {
4245 writeln!(
4246 code,
4247 " // Qwen2: broadcast-add QKV bias across all M tokens."
4248 )?;
4249 writeln!(
4250 code,
4251 " self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
4252 )?;
4253 }
4254 writeln!(code)?;
4255
4256 let k_float_offset = hidden;
4258 writeln!(
4259 code,
4260 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
4261 )?;
4262 writeln!(
4263 code,
4264 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
4265 )?;
4266 writeln!(code)?;
4267
4268 let v_float_offset = hidden + kv_dim;
4270 writeln!(
4271 code,
4272 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
4273 )?;
4274 writeln!(
4275 code,
4276 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
4277 )?;
4278 writeln!(code)?;
4279
4280 writeln!(
4282 code,
4283 " // Batched causal attention: one dispatch for all M tokens"
4284 )?;
4285 writeln!(
4286 code,
4287 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
4288 )?;
4289 writeln!(code)?;
4290
4291 writeln!(code, " // Batched O projection")?;
4293 writeln!(
4294 code,
4295 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
4296 )?;
4297 writeln!(code)?;
4298
4299 writeln!(
4301 code,
4302 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
4303 )?;
4304 writeln!(code)?;
4305
4306 writeln!(
4308 code,
4309 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
4310 )?;
4311 writeln!(
4312 code,
4313 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
4314 )?;
4315 writeln!(
4316 code,
4317 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
4318 )?;
4319 writeln!(
4320 code,
4321 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
4322 )?;
4323 writeln!(
4324 code,
4325 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
4326 )?;
4327 writeln!(
4328 code,
4329 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
4330 )?;
4331 writeln!(code, " }}")?;
4332 writeln!(code)?;
4333
4334 writeln!(
4336 code,
4337 " // Copy last token's residual to single-token buffer for subsequent forward()"
4338 )?;
4339 writeln!(
4340 code,
4341 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
4342 )?;
4343 writeln!(code)?;
4344 writeln!(code, " enc.end_encoding();")?;
4345 writeln!(code, " }}")?;
4346 writeln!(code)?;
4347
4348 writeln!(code, " cmd.commit();")?;
4349 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
4350 writeln!(code, " self.pos += m;")?;
4351 writeln!(code, " }}")?;
4352 writeln!(code)?;
4353
4354 writeln!(
4356 code,
4357 " /// Reset the model state for a new inference request."
4358 )?;
4359 writeln!(code, " pub fn reset(&mut self) {{")?;
4360 writeln!(code, " self.pos = 0;")?;
4361 writeln!(code, " self.prev_cmd = None;")?;
4362 writeln!(code, " }}")?;
4363 writeln!(code)?;
4364
4365 writeln!(
4367 code,
4368 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
4369 )?;
4370 writeln!(
4371 code,
4372 " // These methods set pipeline state + buffers + dispatch on an existing"
4373 )?;
4374 writeln!(
4375 code,
4376 " // encoder, avoiding per-operation encoder creation overhead."
4377 )?;
4378 writeln!(code)?;
4379
4380 writeln!(
4382 code,
4383 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
4384 )?;
4385 writeln!(
4386 code,
4387 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
4388 )?;
4389 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4390 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4391 writeln!(
4392 code,
4393 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
4394 )?;
4395 writeln!(
4396 code,
4397 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
4398 )?;
4399 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4400 writeln!(
4401 code,
4402 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
4403 )?;
4404 writeln!(
4405 code,
4406 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4407 )?;
4408 writeln!(
4409 code,
4410 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4411 )?;
4412 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4413 writeln!(
4414 code,
4415 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
4416 )?;
4417 writeln!(
4418 code,
4419 " enc.dispatch_thread_groups(grid_size, tg_size);"
4420 )?;
4421 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4422 writeln!(code, " }}")?;
4423 writeln!(code)?;
4424
4425 writeln!(
4427 code,
4428 " /// Dispatch matrix-vector multiply: weight * input -> output."
4429 )?;
4430 writeln!(
4431 code,
4432 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4433 )?;
4434 writeln!(code, " let r: u32 = rows as u32;")?;
4435 writeln!(code, " let c: u32 = cols as u32;")?;
4436 writeln!(
4437 code,
4438 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
4439 )?;
4440 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4441 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4442 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4443 writeln!(
4444 code,
4445 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4446 )?;
4447 writeln!(
4448 code,
4449 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4450 )?;
4451 writeln!(
4452 code,
4453 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
4454 )?;
4455 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4456 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
4457 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4458 writeln!(
4459 code,
4460 " enc.dispatch_thread_groups(grid_size, tg_size);"
4461 )?;
4462 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4463 writeln!(code, " }}")?;
4464 writeln!(code)?;
4465
4466 writeln!(
4468 code,
4469 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
4470 )?;
4471 writeln!(
4472 code,
4473 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
4474 )?;
4475 writeln!(
4476 code,
4477 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4478 )?;
4479 writeln!(code, " let r: u32 = rows as u32;")?;
4480 writeln!(code, " let c: u32 = cols as u32;")?;
4481 writeln!(
4482 code,
4483 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
4484 )?;
4485 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4486 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4487 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4488 writeln!(
4489 code,
4490 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4491 )?;
4492 writeln!(
4493 code,
4494 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4495 )?;
4496 writeln!(
4497 code,
4498 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4499 )?;
4500 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4501 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4502 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4503 writeln!(
4504 code,
4505 " enc.dispatch_thread_groups(grid_size, tg_size);"
4506 )?;
4507 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4508 writeln!(code, " }}")?;
4509 writeln!(code)?;
4510
4511 writeln!(
4513 code,
4514 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
4515 )?;
4516 writeln!(
4517 code,
4518 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
4519 )?;
4520 writeln!(
4521 code,
4522 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
4523 )?;
4524 writeln!(code, " let r: u32 = rows as u32;")?;
4525 writeln!(code, " let c: u32 = cols as u32;")?;
4526 writeln!(
4527 code,
4528 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
4529 )?;
4530 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4531 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4532 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4533 writeln!(
4534 code,
4535 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4536 )?;
4537 writeln!(
4538 code,
4539 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4540 )?;
4541 writeln!(
4542 code,
4543 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
4544 )?;
4545 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4546 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
4547 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4548 writeln!(
4549 code,
4550 " enc.dispatch_thread_groups(grid_size, tg_size);"
4551 )?;
4552 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4553 writeln!(code, " }}")?;
4554 writeln!(code)?;
4555
4556 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
4558 writeln!(
4559 code,
4560 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
4561 )?;
4562 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4563 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4564 writeln!(code, " let p: u32 = pos as u32;")?;
4565 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4566 writeln!(
4567 code,
4568 " let total_pairs = num_heads * (head_dim / 2);"
4569 )?;
4570 writeln!(
4571 code,
4572 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4573 )?;
4574 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
4575 writeln!(
4576 code,
4577 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4578 )?;
4579 writeln!(
4580 code,
4581 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4582 )?;
4583 writeln!(
4584 code,
4585 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4586 )?;
4587 writeln!(
4588 code,
4589 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4590 )?;
4591 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4592 writeln!(
4593 code,
4594 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4595 )?;
4596 writeln!(
4597 code,
4598 " enc.dispatch_thread_groups(grid_size, tg_size);"
4599 )?;
4600 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4601 writeln!(code, " }}")?;
4602 writeln!(code)?;
4603
4604 writeln!(
4606 code,
4607 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
4608 )?;
4609 writeln!(
4610 code,
4611 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
4612 )?;
4613 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4614 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4615 writeln!(code, " let p: u32 = pos as u32;")?;
4616 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4617 writeln!(
4618 code,
4619 " let total_pairs = num_heads * (head_dim / 2);"
4620 )?;
4621 writeln!(
4622 code,
4623 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4624 )?;
4625 writeln!(
4626 code,
4627 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
4628 )?;
4629 writeln!(
4630 code,
4631 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4632 )?;
4633 writeln!(
4634 code,
4635 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4636 )?;
4637 writeln!(
4638 code,
4639 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4640 )?;
4641 writeln!(
4642 code,
4643 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4644 )?;
4645 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4646 writeln!(
4647 code,
4648 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4649 )?;
4650 writeln!(
4651 code,
4652 " enc.dispatch_thread_groups(grid_size, tg_size);"
4653 )?;
4654 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4655 writeln!(code, " }}")?;
4656 writeln!(code)?;
4657
4658 writeln!(
4660 code,
4661 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
4662 )?;
4663 writeln!(
4664 code,
4665 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4666 )?;
4667 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4668 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4669 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4670 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4671 writeln!(
4672 code,
4673 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
4674 )?;
4675 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
4676 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4677 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4678 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4679 writeln!(
4680 code,
4681 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4682 )?;
4683 writeln!(
4684 code,
4685 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4686 )?;
4687 writeln!(
4688 code,
4689 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4690 )?;
4691 writeln!(
4692 code,
4693 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4694 )?;
4695 writeln!(
4696 code,
4697 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
4698 )?;
4699 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4700 writeln!(
4701 code,
4702 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4703 )?;
4704 writeln!(
4705 code,
4706 " enc.dispatch_thread_groups(grid_size, tg_size);"
4707 )?;
4708 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4709 writeln!(code, " }}")?;
4710 writeln!(code)?;
4711
4712 writeln!(
4714 code,
4715 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
4716 )?;
4717 writeln!(
4718 code,
4719 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
4720 )?;
4721 writeln!(code, " let sl: u32 = seq_len as u32;")?;
4722 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4723 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4724 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4725 writeln!(
4726 code,
4727 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
4728 )?;
4729 writeln!(
4730 code,
4731 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
4732 )?;
4733 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4734 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4735 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4736 writeln!(
4737 code,
4738 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
4739 )?;
4740 writeln!(
4741 code,
4742 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4743 )?;
4744 writeln!(
4745 code,
4746 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4747 )?;
4748 writeln!(
4749 code,
4750 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4751 )?;
4752 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4753 writeln!(
4754 code,
4755 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
4756 )?;
4757 writeln!(
4758 code,
4759 " enc.dispatch_thread_groups(grid_size, tg_size);"
4760 )?;
4761 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4762 writeln!(code, " }}")?;
4763 writeln!(code)?;
4764
4765 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
4767 writeln!(
4768 code,
4769 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
4770 )?;
4771 writeln!(code, " let count: u32 = n as u32;")?;
4772 writeln!(
4773 code,
4774 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
4775 )?;
4776 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
4777 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
4778 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4779 writeln!(
4780 code,
4781 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4782 )?;
4783 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4784 writeln!(
4785 code,
4786 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4787 )?;
4788 writeln!(
4789 code,
4790 " enc.dispatch_thread_groups(grid_size, tg_size);"
4791 )?;
4792 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4793 writeln!(code, " }}")?;
4794 writeln!(code)?;
4795
4796 writeln!(
4798 code,
4799 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
4800 )?;
4801 writeln!(
4802 code,
4803 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
4804 )?;
4805 writeln!(
4806 code,
4807 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
4808 )?;
4809 writeln!(code, " let count: u32 = n as u32;")?;
4810 writeln!(
4811 code,
4812 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
4813 )?;
4814 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
4815 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
4816 writeln!(
4817 code,
4818 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4819 )?;
4820 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4821 writeln!(
4822 code,
4823 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4824 )?;
4825 writeln!(
4826 code,
4827 " enc.dispatch_thread_groups(grid_size, tg_size);"
4828 )?;
4829 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4830 writeln!(code, " }}")?;
4831 writeln!(code)?;
4832
4833 writeln!(
4835 code,
4836 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
4837 )?;
4838 writeln!(
4839 code,
4840 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4841 )?;
4842 writeln!(code, " let n: u32 = count as u32;")?;
4843 writeln!(
4844 code,
4845 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
4846 )?;
4847 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
4848 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
4849 writeln!(
4850 code,
4851 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4852 )?;
4853 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4854 writeln!(
4855 code,
4856 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4857 )?;
4858 writeln!(
4859 code,
4860 " enc.dispatch_thread_groups(grid_size, tg_size);"
4861 )?;
4862 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4863 writeln!(code, " }}")?;
4864 writeln!(code)?;
4865
4866 writeln!(
4868 code,
4869 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
4870 )?;
4871 writeln!(
4872 code,
4873 " /// Used for embedding table lookup (copy a specific row)."
4874 )?;
4875 writeln!(
4876 code,
4877 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
4878 )?;
4879 writeln!(code, " let off: u32 = src_offset as u32;")?;
4880 writeln!(code, " let n: u32 = count as u32;")?;
4881 writeln!(
4882 code,
4883 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
4884 )?;
4885 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
4886 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
4887 writeln!(
4888 code,
4889 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
4890 )?;
4891 writeln!(
4892 code,
4893 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4894 )?;
4895 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4896 writeln!(
4897 code,
4898 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4899 )?;
4900 writeln!(
4901 code,
4902 " enc.dispatch_thread_groups(grid_size, tg_size);"
4903 )?;
4904 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4905 writeln!(code, " }}")?;
4906 writeln!(code)?;
4907
4908 writeln!(
4910 code,
4911 " /// Dispatch copy from source at byte offset to destination at float offset."
4912 )?;
4913 writeln!(
4914 code,
4915 " /// Used for KV cache updates from fused QKV buffer."
4916 )?;
4917 writeln!(
4918 code,
4919 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4920 )?;
4921 writeln!(code, " let n: u32 = count as u32;")?;
4922 writeln!(
4923 code,
4924 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
4925 )?;
4926 writeln!(
4927 code,
4928 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4929 )?;
4930 writeln!(
4931 code,
4932 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4933 )?;
4934 writeln!(
4935 code,
4936 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4937 )?;
4938 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4939 writeln!(
4940 code,
4941 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4942 )?;
4943 writeln!(
4944 code,
4945 " enc.dispatch_thread_groups(grid_size, tg_size);"
4946 )?;
4947 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4948 writeln!(code, " }}")?;
4949 writeln!(code)?;
4950
4951 writeln!(
4953 code,
4954 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
4955 )?;
4956 writeln!(
4957 code,
4958 " /// Used for KV cache updates (write to a specific position in the cache)."
4959 )?;
4960 writeln!(
4961 code,
4962 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
4963 )?;
4964 writeln!(code, " let n: u32 = count as u32;")?;
4965 writeln!(
4966 code,
4967 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
4968 )?;
4969 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
4970 writeln!(
4971 code,
4972 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
4973 )?;
4974 writeln!(
4975 code,
4976 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4977 )?;
4978 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4979 writeln!(
4980 code,
4981 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4982 )?;
4983 writeln!(
4984 code,
4985 " enc.dispatch_thread_groups(grid_size, tg_size);"
4986 )?;
4987 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4988 writeln!(code, " }}")?;
4989 writeln!(code)?;
4990
4991 writeln!(
4993 code,
4994 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
4995 )?;
4996 writeln!(
4997 code,
4998 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
4999 )?;
5000 writeln!(code, " let count: u32 = n as u32;")?;
5001 writeln!(
5002 code,
5003 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
5004 )?;
5005 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5006 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5007 writeln!(
5008 code,
5009 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5010 )?;
5011 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5012 writeln!(
5013 code,
5014 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
5015 )?;
5016 writeln!(
5017 code,
5018 " enc.dispatch_thread_groups(grid_size, tg_size);"
5019 )?;
5020 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5021 writeln!(code, " }}")?;
5022 writeln!(code)?;
5023
5024 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
5026 writeln!(code)?;
5027
5028 writeln!(
5030 code,
5031 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
5032 )?;
5033 writeln!(
5034 code,
5035 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
5036 )?;
5037 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
5038 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5039 writeln!(
5040 code,
5041 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
5042 )?;
5043 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
5044 writeln!(
5045 code,
5046 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
5047 )?;
5048 writeln!(
5049 code,
5050 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
5051 )?;
5052 writeln!(
5053 code,
5054 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
5055 )?;
5056 writeln!(
5057 code,
5058 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5059 )?;
5060 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
5061 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5062 writeln!(
5063 code,
5064 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5065 )?;
5066 writeln!(
5067 code,
5068 " enc.dispatch_thread_groups(grid_size, tg_size);"
5069 )?;
5070 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5071 writeln!(code, " }}")?;
5072 writeln!(code)?;
5073
5074 writeln!(
5076 code,
5077 " /// Dispatch batched RMS norm: normalizes M vectors at once."
5078 )?;
5079 writeln!(
5080 code,
5081 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
5082 )?;
5083 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
5084 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
5085 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5086 writeln!(
5087 code,
5088 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
5089 )?;
5090 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
5091 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
5092 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5093 writeln!(
5094 code,
5095 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5096 )?;
5097 writeln!(
5098 code,
5099 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
5100 )?;
5101 writeln!(
5102 code,
5103 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5104 )?;
5105 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5106 writeln!(
5107 code,
5108 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
5109 )?;
5110 writeln!(
5111 code,
5112 " enc.dispatch_thread_groups(grid_size, tg_size);"
5113 )?;
5114 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5115 writeln!(code, " }}")?;
5116 writeln!(code)?;
5117
5118 writeln!(
5120 code,
5121 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5122 )?;
5123 writeln!(
5124 code,
5125 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5126 )?;
5127 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5128 writeln!(code, " let r: u32 = rows as u32;")?;
5129 writeln!(code, " let c: u32 = cols as u32;")?;
5130 writeln!(
5131 code,
5132 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
5133 )?;
5134 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5135 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5136 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5137 writeln!(
5138 code,
5139 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5140 )?;
5141 writeln!(
5142 code,
5143 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5144 )?;
5145 writeln!(
5146 code,
5147 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5148 )?;
5149 writeln!(
5150 code,
5151 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
5152 )?;
5153 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5154 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5155 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5156 writeln!(
5157 code,
5158 " enc.dispatch_thread_groups(grid_size, tg_size);"
5159 )?;
5160 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5161 writeln!(code, " }}")?;
5162 writeln!(code)?;
5163
5164 writeln!(
5166 code,
5167 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5168 )?;
5169 writeln!(code, " ///")?;
5170 writeln!(
5171 code,
5172 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
5173 )?;
5174 writeln!(
5175 code,
5176 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
5177 )?;
5178 writeln!(
5179 code,
5180 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5181 )?;
5182 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5183 writeln!(code, " let r: u32 = rows as u32;")?;
5184 writeln!(code, " let c: u32 = cols as u32;")?;
5185 writeln!(
5186 code,
5187 " // Tile sizes must match the Metal shader constants."
5188 )?;
5189 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
5190 writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
5191 writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
5192 writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
5193 writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
5194 writeln!(
5195 code,
5196 " // Hardware matrix-multiply paths (simdgroup_matrix)."
5197 )?;
5198 writeln!(
5199 code,
5200 " // Prefer the large 32×32 tile when the problem supports it — halves"
5201 )?;
5202 writeln!(
5203 code,
5204 " // dispatch count and reuses each weight load across 32 tokens."
5205 )?;
5206 writeln!(
5207 code,
5208 " if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
5209 )?;
5210 writeln!(
5211 code,
5212 " // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
5213 )?;
5214 writeln!(
5215 code,
5216 " // It wins at moderate prefill lengths where the GPU is wave-starved,"
5217 )?;
5218 writeln!(
5219 code,
5220 " // but the f32→f16 conversion overhead slightly hurts the small-hidden"
5221 )?;
5222 writeln!(
5223 code,
5224 " // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
5225 )?;
5226 writeln!(
5227 code,
5228 " // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
5229 )?;
5230 writeln!(
5231 code,
5232 " // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
5233 )?;
5234 writeln!(
5235 code,
5236 " // little at low M but wins at higher M via ~2x FP16 MMA throughput."
5237 )?;
5238 writeln!(
5239 code,
5240 " // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
5241 )?;
5242 writeln!(code, " let use_h4 = cols >= 2048;")?;
5243 writeln!(code, " let pipe = if use_h4 {{")?;
5244 writeln!(code, " if num_tokens >= 256 {{")?;
5245 writeln!(code, " &self.matmul_q8_mma32_hh4_pipeline")?;
5246 writeln!(code, " }} else {{")?;
5247 writeln!(code, " &self.matmul_q8_mma32_h4_pipeline")?;
5248 writeln!(code, " }}")?;
5249 writeln!(code, " }} else {{")?;
5250 writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
5251 writeln!(code, " }};")?;
5252 writeln!(
5253 code,
5254 " enc.set_compute_pipeline_state(pipe);"
5255 )?;
5256 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5257 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5258 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5259 writeln!(
5260 code,
5261 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5262 )?;
5263 writeln!(
5264 code,
5265 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5266 )?;
5267 writeln!(
5268 code,
5269 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5270 )?;
5271 writeln!(
5272 code,
5273 " let row_tgs = rows / MMA32_ROW_TILE;"
5274 )?;
5275 writeln!(
5276 code,
5277 " let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
5278 )?;
5279 writeln!(
5280 code,
5281 " let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
5282 )?;
5283 writeln!(
5284 code,
5285 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5286 )?;
5287 writeln!(
5288 code,
5289 " enc.dispatch_thread_groups(grid_size, tg_size);"
5290 )?;
5291 writeln!(
5292 code,
5293 " }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
5294 )?;
5295 writeln!(
5296 code,
5297 " enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
5298 )?;
5299 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5300 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5301 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5302 writeln!(
5303 code,
5304 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5305 )?;
5306 writeln!(
5307 code,
5308 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5309 )?;
5310 writeln!(
5311 code,
5312 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5313 )?;
5314 writeln!(
5315 code,
5316 " let row_tgs = rows / MMA_ROW_TILE;"
5317 )?;
5318 writeln!(
5319 code,
5320 " let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
5321 )?;
5322 writeln!(
5323 code,
5324 " let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
5325 )?;
5326 writeln!(
5327 code,
5328 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5329 )?;
5330 writeln!(
5331 code,
5332 " enc.dispatch_thread_groups(grid_size, tg_size);"
5333 )?;
5334 writeln!(code, " }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
5335 writeln!(
5336 code,
5337 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
5338 )?;
5339 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5340 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5341 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5342 writeln!(
5343 code,
5344 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5345 )?;
5346 writeln!(
5347 code,
5348 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5349 )?;
5350 writeln!(
5351 code,
5352 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5353 )?;
5354 writeln!(
5355 code,
5356 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5357 )?;
5358 writeln!(
5359 code,
5360 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
5361 )?;
5362 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5363 writeln!(
5364 code,
5365 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
5366 )?;
5367 writeln!(
5368 code,
5369 " enc.dispatch_thread_groups(grid_size, tg_size);"
5370 )?;
5371 writeln!(code, " }} else {{")?;
5372 writeln!(
5373 code,
5374 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
5375 )?;
5376 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5377 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5378 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5379 writeln!(
5380 code,
5381 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5382 )?;
5383 writeln!(
5384 code,
5385 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5386 )?;
5387 writeln!(
5388 code,
5389 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5390 )?;
5391 writeln!(
5392 code,
5393 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
5394 )?;
5395 writeln!(
5396 code,
5397 " let num_tg = (row_tgs * num_tokens) as u64;"
5398 )?;
5399 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5400 writeln!(
5401 code,
5402 " let grid_size = MTLSize::new(num_tg, 1, 1);"
5403 )?;
5404 writeln!(
5405 code,
5406 " enc.dispatch_thread_groups(grid_size, tg_size);"
5407 )?;
5408 writeln!(code, " }}")?;
5409 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5410 writeln!(code, " }}")?;
5411 writeln!(code)?;
5412
5413 writeln!(
5415 code,
5416 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
5417 )?;
5418 writeln!(
5419 code,
5420 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
5421 )?;
5422 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5423 writeln!(code, " let r: u32 = rows as u32;")?;
5424 writeln!(code, " let c: u32 = cols as u32;")?;
5425 writeln!(
5426 code,
5427 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
5428 )?;
5429 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
5430 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
5431 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
5432 writeln!(
5433 code,
5434 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5435 )?;
5436 writeln!(
5437 code,
5438 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5439 )?;
5440 writeln!(
5441 code,
5442 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
5443 )?;
5444 writeln!(
5445 code,
5446 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
5447 )?;
5448 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
5449 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5450 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
5451 writeln!(
5452 code,
5453 " enc.dispatch_thread_groups(grid_size, tg_size);"
5454 )?;
5455 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5456 writeln!(code, " }}")?;
5457 writeln!(code)?;
5458
5459 if config.qkv_bias {
5461 writeln!(
5462 code,
5463 " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer."
5464 )?;
5465 writeln!(
5466 code,
5467 " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {{"
5468 )?;
5469 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5470 writeln!(code, " let r: u32 = rows as u32;")?;
5471 writeln!(
5472 code,
5473 " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);"
5474 )?;
5475 writeln!(code, " enc.set_buffer(0, Some(out), 0);")?;
5476 writeln!(code, " enc.set_buffer(1, Some(bias), 0);")?;
5477 writeln!(
5478 code,
5479 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5480 )?;
5481 writeln!(
5482 code,
5483 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
5484 )?;
5485 writeln!(code, " let total = (num_tokens * rows) as u64;")?;
5486 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5487 writeln!(
5488 code,
5489 " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);"
5490 )?;
5491 writeln!(
5492 code,
5493 " enc.dispatch_thread_groups(grid_size, tg_size);"
5494 )?;
5495 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5496 writeln!(code, " }}")?;
5497 writeln!(code)?;
5498 }
5499
5500 writeln!(
5502 code,
5503 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
5504 )?;
5505 writeln!(
5506 code,
5507 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
5508 )?;
5509 writeln!(
5510 code,
5511 " /// `row_stride` is the number of floats per token row in the batch buffer."
5512 )?;
5513 writeln!(
5514 code,
5515 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
5516 )?;
5517 writeln!(code, " let nh: u32 = num_heads as u32;")?;
5518 writeln!(code, " let hd: u32 = head_dim as u32;")?;
5519 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5520 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5521 writeln!(
5522 code,
5523 " let pairs_per_token = num_heads * (head_dim / 2);"
5524 )?;
5525 writeln!(
5526 code,
5527 " let total_pairs = num_tokens * pairs_per_token;"
5528 )?;
5529 writeln!(
5542 code,
5543 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
5544 )?;
5545 writeln!(code, " for t in 0..num_tokens {{")?;
5546 writeln!(
5547 code,
5548 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
5549 )?;
5550 writeln!(
5551 code,
5552 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
5553 )?;
5554 writeln!(
5555 code,
5556 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
5557 )?;
5558 writeln!(
5559 code,
5560 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
5561 )?;
5562 writeln!(
5563 code,
5564 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5565 )?;
5566 writeln!(
5567 code,
5568 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5569 )?;
5570 writeln!(
5571 code,
5572 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
5573 )?;
5574 writeln!(
5575 code,
5576 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5577 )?;
5578 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5579 writeln!(
5580 code,
5581 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
5582 )?;
5583 writeln!(
5584 code,
5585 " enc.dispatch_thread_groups(grid_size, tg_size);"
5586 )?;
5587 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5588 writeln!(code, " }}")?;
5589 writeln!(code, " }}")?;
5590 writeln!(code)?;
5591
5592 writeln!(
5594 code,
5595 " /// Dispatch batched fused SiLU-multiply for M tokens."
5596 )?;
5597 writeln!(
5598 code,
5599 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
5600 )?;
5601 writeln!(code, " let count: u32 = n as u32;")?;
5602 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
5603 writeln!(
5604 code,
5605 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
5606 )?;
5607 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
5608 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
5609 writeln!(
5610 code,
5611 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5612 )?;
5613 writeln!(
5614 code,
5615 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
5616 )?;
5617 writeln!(code, " let total = num_tokens * n;")?;
5618 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5619 writeln!(
5620 code,
5621 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5622 )?;
5623 writeln!(
5624 code,
5625 " enc.dispatch_thread_groups(grid_size, tg_size);"
5626 )?;
5627 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5628 writeln!(code, " }}")?;
5629 writeln!(code)?;
5630
5631 writeln!(
5633 code,
5634 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
5635 )?;
5636 writeln!(
5637 code,
5638 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
5639 )?;
5640 writeln!(code, " let count: u32 = total_n as u32;")?;
5641 writeln!(
5642 code,
5643 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
5644 )?;
5645 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
5646 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
5647 writeln!(
5648 code,
5649 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
5650 )?;
5651 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5652 writeln!(
5653 code,
5654 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
5655 )?;
5656 writeln!(
5657 code,
5658 " enc.dispatch_thread_groups(grid_size, tg_size);"
5659 )?;
5660 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5661 writeln!(code, " }}")?;
5662 writeln!(code)?;
5663
5664 writeln!(
5666 code,
5667 " /// Copy src to dst using compute copy kernel (for batch residual init)."
5668 )?;
5669 writeln!(
5670 code,
5671 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
5672 )?;
5673 writeln!(code, " let n: u32 = count as u32;")?;
5674 writeln!(
5675 code,
5676 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5677 )?;
5678 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5679 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5680 writeln!(
5681 code,
5682 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5683 )?;
5684 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5685 writeln!(
5686 code,
5687 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5688 )?;
5689 writeln!(
5690 code,
5691 " enc.dispatch_thread_groups(grid_size, tg_size);"
5692 )?;
5693 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5694 writeln!(code, " }}")?;
5695 writeln!(code)?;
5696
5697 writeln!(
5699 code,
5700 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
5701 )?;
5702 writeln!(
5703 code,
5704 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5705 )?;
5706 writeln!(code, " let n: u32 = count as u32;")?;
5707 writeln!(
5708 code,
5709 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5710 )?;
5711 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5712 writeln!(
5713 code,
5714 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5715 )?;
5716 writeln!(
5717 code,
5718 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5719 )?;
5720 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5721 writeln!(
5722 code,
5723 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5724 )?;
5725 writeln!(
5726 code,
5727 " enc.dispatch_thread_groups(grid_size, tg_size);"
5728 )?;
5729 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5730 writeln!(code, " }}")?;
5731 writeln!(code)?;
5732
5733 writeln!(
5735 code,
5736 " /// Copy from src at byte offset to dst at float offset."
5737 )?;
5738 writeln!(
5739 code,
5740 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
5741 )?;
5742 writeln!(code, " let n: u32 = count as u32;")?;
5743 writeln!(
5744 code,
5745 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
5746 )?;
5747 writeln!(
5748 code,
5749 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
5750 )?;
5751 writeln!(
5752 code,
5753 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
5754 )?;
5755 writeln!(
5756 code,
5757 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
5758 )?;
5759 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5760 writeln!(
5761 code,
5762 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
5763 )?;
5764 writeln!(
5765 code,
5766 " enc.dispatch_thread_groups(grid_size, tg_size);"
5767 )?;
5768 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5769 writeln!(code, " }}")?;
5770 writeln!(code)?;
5771
5772 writeln!(
5774 code,
5775 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
5776 )?;
5777 writeln!(
5778 code,
5779 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
5780 )?;
5781 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
5782 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
5783 writeln!(code, " let bp: u32 = base_pos as u32;")?;
5784 writeln!(code, " let ss: u32 = src_stride as u32;")?;
5785 writeln!(code, " let so: u32 = src_offset as u32;")?;
5786 writeln!(
5787 code,
5788 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
5789 )?;
5790 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
5791 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
5792 writeln!(
5793 code,
5794 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5795 )?;
5796 writeln!(
5797 code,
5798 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
5799 )?;
5800 writeln!(
5801 code,
5802 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5803 )?;
5804 writeln!(
5805 code,
5806 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
5807 )?;
5808 writeln!(
5809 code,
5810 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
5811 )?;
5812 writeln!(code, " let total = num_tokens * kv_dim;")?;
5813 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5814 writeln!(
5815 code,
5816 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
5817 )?;
5818 writeln!(
5819 code,
5820 " enc.dispatch_thread_groups(grid_size, tg_size);"
5821 )?;
5822 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5823 writeln!(code, " }}")?;
5824 writeln!(code)?;
5825
5826 writeln!(
5828 code,
5829 " /// Dispatch batched causal attention: one dispatch for all M tokens."
5830 )?;
5831 writeln!(
5832 code,
5833 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
5834 )?;
5835 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
5836 writeln!(code, " let bp: u32 = base_pos as u32;")?;
5837 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
5838 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5839 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5840 writeln!(code, " let qs: u32 = q_stride as u32;")?;
5841 writeln!(
5848 code,
5849 " let max_seq = base_pos + num_tokens;"
5850 )?;
5851 writeln!(
5852 code,
5853 " let pipe = if max_seq > 2000 {{"
5854 )?;
5855 writeln!(
5856 code,
5857 " &self.attention_flash_batch_pipeline"
5858 )?;
5859 writeln!(code, " }} else {{")?;
5860 writeln!(
5861 code,
5862 " &self.attention_batch_pipeline"
5863 )?;
5864 writeln!(code, " }};")?;
5865 writeln!(
5866 code,
5867 " enc.set_compute_pipeline_state(pipe);"
5868 )?;
5869 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
5870 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
5871 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
5872 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
5873 writeln!(
5874 code,
5875 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5876 )?;
5877 writeln!(
5878 code,
5879 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5880 )?;
5881 writeln!(
5882 code,
5883 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
5884 )?;
5885 writeln!(
5886 code,
5887 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5888 )?;
5889 writeln!(
5890 code,
5891 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5892 )?;
5893 writeln!(
5894 code,
5895 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5896 )?;
5897 writeln!(
5898 code,
5899 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
5900 )?;
5901 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5902 writeln!(
5903 code,
5904 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
5905 )?;
5906 writeln!(
5907 code,
5908 " enc.dispatch_thread_groups(grid_size, tg_size);"
5909 )?;
5910 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5911 writeln!(code, " }}")?;
5912 writeln!(code)?;
5913
5914 writeln!(
5916 code,
5917 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
5918 )?;
5919 writeln!(
5920 code,
5921 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
5922 )?;
5923 writeln!(
5924 code,
5925 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
5926 )?;
5927 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
5928 writeln!(code, " let bp: u32 = base_pos as u32;")?;
5929 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
5930 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
5931 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
5932 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
5933 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
5934 writeln!(
5935 code,
5936 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
5937 )?;
5938 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
5939 writeln!(
5940 code,
5941 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
5942 )?;
5943 writeln!(
5944 code,
5945 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
5946 )?;
5947 writeln!(
5948 code,
5949 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
5950 )?;
5951 writeln!(
5952 code,
5953 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
5954 )?;
5955 writeln!(
5956 code,
5957 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
5958 )?;
5959 writeln!(
5960 code,
5961 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
5962 )?;
5963 writeln!(
5964 code,
5965 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
5966 )?;
5967 writeln!(
5968 code,
5969 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
5970 )?;
5971 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
5972 writeln!(
5973 code,
5974 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
5975 )?;
5976 writeln!(
5977 code,
5978 " enc.dispatch_thread_groups(grid_size, tg_size);"
5979 )?;
5980 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
5981 writeln!(code, " }}")?;
5982 writeln!(code)?;
5983
5984 writeln!(
5986 code,
5987 " /// Dispatch fused K+V cache copy in one kernel launch."
5988 )?;
5989 writeln!(
5990 code,
5991 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
5992 )?;
5993 writeln!(
5994 code,
5995 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
5996 )?;
5997 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
5998 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
5999 writeln!(code, " let bp: u32 = base_pos as u32;")?;
6000 writeln!(code, " let ss: u32 = src_stride as u32;")?;
6001 writeln!(code, " let ko: u32 = k_offset as u32;")?;
6002 writeln!(code, " let vo: u32 = v_offset as u32;")?;
6003 writeln!(
6004 code,
6005 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
6006 )?;
6007 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
6008 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
6009 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
6010 writeln!(
6011 code,
6012 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
6013 )?;
6014 writeln!(
6015 code,
6016 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
6017 )?;
6018 writeln!(
6019 code,
6020 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
6021 )?;
6022 writeln!(
6023 code,
6024 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
6025 )?;
6026 writeln!(
6027 code,
6028 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
6029 )?;
6030 writeln!(
6031 code,
6032 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
6033 )?;
6034 writeln!(
6035 code,
6036 " let total = num_tokens * kv_dim * 2; // K + V"
6037 )?;
6038 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
6039 writeln!(
6040 code,
6041 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
6042 )?;
6043 writeln!(
6044 code,
6045 " enc.dispatch_thread_groups(grid_size, tg_size);"
6046 )?;
6047 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
6048 writeln!(code, " }}")?;
6049
6050 writeln!(code, "}}")?;
6051 writeln!(code)?;
6052
6053 Ok(())
6054}
6055
6056fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
6057 writeln!(
6058 code,
6059 "// ── Helper functions ──────────────────────────────────"
6060 )?;
6061 writeln!(code)?;
6062 writeln!(
6063 code,
6064 "/// Create a compute pipeline from a named function in the Metal library."
6065 )?;
6066 writeln!(
6067 code,
6068 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
6069 )?;
6070 writeln!(
6071 code,
6072 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
6073 )?;
6074 writeln!(
6075 code,
6076 " device.new_compute_pipeline_state_with_function(&func)"
6077 )?;
6078 writeln!(
6079 code,
6080 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
6081 )?;
6082 writeln!(code, "}}")?;
6083 writeln!(code)?;
6084
6085 Ok(())
6086}
6087
6088fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
6093 let _sanitized = sanitize_name(model_name);
6094 let _vocab = config.vocab_size;
6095
6096 let mut code = String::with_capacity(16 * 1024);
6097 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
6098 writeln!(
6099 code,
6100 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
6101 )?;
6102 writeln!(code)?;
6103 writeln!(code, "mod model;")?;
6104 writeln!(code)?;
6105 writeln!(code, "use std::io::Write;")?;
6106 writeln!(code, "use std::time::Instant;")?;
6107 writeln!(code, "use serde::Deserialize;")?;
6108 writeln!(code)?;
6109
6110 writeln!(code, "fn main() {{")?;
6112 writeln!(
6113 code,
6114 " let args: Vec<String> = std::env::args().collect();"
6115 )?;
6116 writeln!(code)?;
6117 writeln!(
6118 code,
6119 " // Detect --serve mode (only requires weights + tokenizer)"
6120 )?;
6121 writeln!(
6122 code,
6123 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
6124 )?;
6125 writeln!(code)?;
6126 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
6127 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
6128 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6129 writeln!(code, " std::process::exit(1);")?;
6130 writeln!(code, " }}")?;
6131 writeln!(code)?;
6132 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
6133 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
6134 writeln!(code, " std::process::exit(1);")?;
6135 writeln!(code, " }}")?;
6136 writeln!(code)?;
6137 writeln!(code, " let weights_path = &args[1];")?;
6138 writeln!(code, " let tokenizer_path = &args[2];")?;
6139 writeln!(code)?;
6140 writeln!(code, " // Parse optional flags")?;
6141 writeln!(code, " let mut max_tokens: usize = 128;")?;
6142 writeln!(code, " let mut port: u16 = 8080;")?;
6143 writeln!(
6144 code,
6145 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
6146 )?;
6147 writeln!(
6148 code,
6149 " let profile = args.iter().any(|a| a == \"--profile\");"
6150 )?;
6151 writeln!(code, " let mut i = 3;")?;
6152 writeln!(code, " while i < args.len() {{")?;
6153 writeln!(
6154 code,
6155 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
6156 )?;
6157 writeln!(
6158 code,
6159 " max_tokens = args[i + 1].parse().unwrap_or(128);"
6160 )?;
6161 writeln!(code, " i += 2;")?;
6162 writeln!(
6163 code,
6164 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
6165 )?;
6166 writeln!(
6167 code,
6168 " port = args[i + 1].parse().unwrap_or(8080);"
6169 )?;
6170 writeln!(code, " i += 2;")?;
6171 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
6172 writeln!(code, " i += 1;")?;
6173 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
6174 writeln!(code, " i += 1;")?;
6175 writeln!(code, " }} else {{")?;
6176 writeln!(code, " i += 1;")?;
6177 writeln!(code, " }}")?;
6178 writeln!(code, " }}")?;
6179 writeln!(code)?;
6180
6181 writeln!(
6183 code,
6184 " // Memory-map weights for zero-copy loading on Apple Silicon"
6185 )?;
6186 writeln!(
6187 code,
6188 " let weights_file = std::fs::File::open(weights_path)"
6189 )?;
6190 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
6191 writeln!(
6192 code,
6193 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
6194 )?;
6195 writeln!(code)?;
6196 writeln!(code, " // Load tokenizer")?;
6197 writeln!(
6198 code,
6199 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
6200 )?;
6201 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
6202 writeln!(code)?;
6203 writeln!(code, " // Create Metal model")?;
6204 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
6205 writeln!(
6206 code,
6207 " let mut model = model::MetalModel::new(&weights_mmap);"
6208 )?;
6209 writeln!(code)?;
6210
6211 writeln!(code, " if serve_mode {{")?;
6213 writeln!(code, " serve(model, tokenizer, port);")?;
6214 writeln!(code, " }} else {{")?;
6215 writeln!(code, " let prompt = &args[3];")?;
6216 writeln!(
6217 code,
6218 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
6219 )?;
6220 writeln!(code, " }}")?;
6221 writeln!(code, "}}")?;
6222 writeln!(code)?;
6223
6224 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
6226 writeln!(code, " // Tokenize prompt")?;
6227 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
6228 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
6229 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6230 writeln!(code)?;
6231 writeln!(
6232 code,
6233 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
6234 )?;
6235 writeln!(
6236 code,
6237 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
6238 )?;
6239 writeln!(
6240 code,
6241 " // The last token uses synchronous forward() to get logits."
6242 )?;
6243 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
6244 writeln!(code, " let prefill_start = Instant::now();")?;
6245 writeln!(code, " let logits = if prompt_len > 1 {{")?;
6246 writeln!(
6247 code,
6248 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
6249 )?;
6250 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
6251 writeln!(code, " }} else {{")?;
6252 writeln!(code, " model.forward(prompt_tokens[0])")?;
6253 writeln!(code, " }};")?;
6254 writeln!(
6255 code,
6256 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6257 )?;
6258 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
6259 writeln!(
6260 code,
6261 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
6262 )?;
6263 writeln!(
6264 code,
6265 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
6266 )?;
6267 writeln!(code)?;
6268 writeln!(code, " // Generate tokens")?;
6269 writeln!(code, " let mut next_token = argmax(&logits);")?;
6270 writeln!(code, " let gen_start = Instant::now();")?;
6271 writeln!(code, " let mut generated_count: usize = 0;")?;
6272 writeln!(code)?;
6273 writeln!(code, " for _ in 0..max_tokens {{")?;
6274 writeln!(
6275 code,
6276 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
6277 )?;
6278 writeln!(code, " if !quiet {{")?;
6279 writeln!(code, " print!(\"{{}}\", text);")?;
6280 writeln!(code, " std::io::stdout().flush().ok();")?;
6281 writeln!(code, " }}")?;
6282 writeln!(code, " }}")?;
6283 writeln!(code, " generated_count += 1;")?;
6284 writeln!(code)?;
6285 writeln!(
6286 code,
6287 " // Use profiling forward for first token when --profile is set"
6288 )?;
6289 writeln!(
6290 code,
6291 " let logits = if profile && generated_count == 1 {{"
6292 )?;
6293 writeln!(code, " model.forward_profile(next_token)")?;
6294 writeln!(code, " }} else {{")?;
6295 writeln!(code, " model.forward(next_token)")?;
6296 writeln!(code, " }};")?;
6297 writeln!(code, " next_token = argmax(&logits);")?;
6298 writeln!(code)?;
6299 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
6300 writeln!(code, " if next_token == 2 {{")?;
6301 writeln!(code, " break;")?;
6302 writeln!(code, " }}")?;
6303 writeln!(code)?;
6304 writeln!(
6305 code,
6306 " // Yield between tokens to reduce sustained GPU thermal load."
6307 )?;
6308 writeln!(
6309 code,
6310 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
6311 )?;
6312 writeln!(
6313 code,
6314 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
6315 )?;
6316 writeln!(
6317 code,
6318 " // briefly, providing a micro-break that helps sustain peak throughput."
6319 )?;
6320 writeln!(code, " std::thread::yield_now();")?;
6321 writeln!(code, " }}")?;
6322 writeln!(code, " if !quiet {{")?;
6323 writeln!(code, " println!();")?;
6324 writeln!(code, " }}")?;
6325 writeln!(
6326 code,
6327 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6328 )?;
6329 writeln!(
6330 code,
6331 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
6332 )?;
6333 writeln!(
6334 code,
6335 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
6336 )?;
6337 writeln!(code, "}}")?;
6338 writeln!(code)?;
6339
6340 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
6342 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
6343 writeln!(code, " logits.iter()")?;
6344 writeln!(code, " .enumerate()")?;
6345 writeln!(
6346 code,
6347 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
6348 )?;
6349 writeln!(code, " .map(|(i, _)| i as u32)")?;
6350 writeln!(code, " .unwrap_or(0)")?;
6351 writeln!(code, "}}")?;
6352 writeln!(code)?;
6353
6354 writeln!(
6356 code,
6357 "// -----------------------------------------------------------------------"
6358 )?;
6359 writeln!(code, "// OpenAI-compatible API server")?;
6360 writeln!(
6361 code,
6362 "// -----------------------------------------------------------------------"
6363 )?;
6364 writeln!(code)?;
6365 writeln!(code, "#[derive(Deserialize)]")?;
6366 writeln!(code, "struct ChatRequest {{")?;
6367 writeln!(code, " messages: Vec<ChatMessage>,")?;
6368 writeln!(code, " #[serde(default)]")?;
6369 writeln!(code, " stream: Option<bool>,")?;
6370 writeln!(code, " #[serde(default)]")?;
6371 writeln!(code, " max_tokens: Option<usize>,")?;
6372 writeln!(code, " #[serde(default)]")?;
6373 writeln!(code, " temperature: Option<f32>,")?;
6374 writeln!(code, " #[serde(default)]")?;
6375 writeln!(code, " model: Option<String>,")?;
6376 writeln!(code, "}}")?;
6377 writeln!(code)?;
6378 writeln!(code, "#[derive(Deserialize)]")?;
6379 writeln!(code, "struct ChatMessage {{")?;
6380 writeln!(code, " role: String,")?;
6381 writeln!(code, " content: String,")?;
6382 writeln!(code, "}}")?;
6383 writeln!(code)?;
6384
6385 writeln!(
6387 code,
6388 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
6389 )?;
6390 writeln!(code, " let mut prompt = String::new();")?;
6391 writeln!(code, " for msg in messages {{")?;
6392 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
6393 writeln!(code, " }}")?;
6394 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
6395 writeln!(code, " prompt")?;
6396 writeln!(code, "}}")?;
6397 writeln!(code)?;
6398
6399 writeln!(
6401 code,
6402 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
6403 )?;
6404 writeln!(code, " let len = tokens.len();")?;
6405 writeln!(code, " if len > 1 {{")?;
6406 writeln!(
6407 code,
6408 " model.forward_prefill_batch(&tokens[..len - 1]);"
6409 )?;
6410 writeln!(code, " }}")?;
6411 writeln!(code, " model.forward(tokens[len - 1])")?;
6412 writeln!(code, "}}")?;
6413 writeln!(code)?;
6414
6415 writeln!(
6417 code,
6418 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
6419 )?;
6420 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
6421 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
6422 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
6423 writeln!(
6424 code,
6425 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
6426 )?;
6427 writeln!(code, " eprintln!(\"Endpoints:\");")?;
6428 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
6429 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
6430 writeln!(code, " eprintln!(\" GET /health\");")?;
6431 writeln!(code)?;
6432 writeln!(code, " for request in server.incoming_requests() {{")?;
6433 writeln!(code, " let method = request.method().to_string();")?;
6434 writeln!(code, " let url = request.url().to_string();")?;
6435 writeln!(code)?;
6436 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
6437
6438 writeln!(
6440 code,
6441 " (\"POST\", \"/v1/chat/completions\") => {{"
6442 )?;
6443 writeln!(
6444 code,
6445 " handle_chat_completion(&mut model, &tokenizer, request);"
6446 )?;
6447 writeln!(code, " }}")?;
6448
6449 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
6451 writeln!(code, " let body = serde_json::json!({{")?;
6452 writeln!(code, " \"object\": \"list\",")?;
6453 writeln!(code, " \"data\": [{{")?;
6454 writeln!(code, " \"id\": \"forgellm-metal\",")?;
6455 writeln!(code, " \"object\": \"model\",")?;
6456 writeln!(code, " \"owned_by\": \"forgellm\"")?;
6457 writeln!(code, " }}]")?;
6458 writeln!(code, " }});")?;
6459 writeln!(
6460 code,
6461 " let resp = tiny_http::Response::from_string(body.to_string())"
6462 )?;
6463 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6464 writeln!(code, " request.respond(resp).ok();")?;
6465 writeln!(code, " }}")?;
6466
6467 writeln!(code, " (\"GET\", \"/health\") => {{")?;
6469 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
6470 writeln!(code, " request.respond(resp).ok();")?;
6471 writeln!(code, " }}")?;
6472
6473 writeln!(code, " _ => {{")?;
6475 writeln!(
6476 code,
6477 " let resp = tiny_http::Response::from_string(\"Not Found\")"
6478 )?;
6479 writeln!(code, " .with_status_code(404);")?;
6480 writeln!(code, " request.respond(resp).ok();")?;
6481 writeln!(code, " }}")?;
6482 writeln!(code, " }}")?;
6483 writeln!(code, " }}")?;
6484 writeln!(code, "}}")?;
6485 writeln!(code)?;
6486
6487 writeln!(code, "fn handle_chat_completion(")?;
6489 writeln!(code, " model: &mut model::MetalModel,")?;
6490 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
6491 writeln!(code, " mut request: tiny_http::Request,")?;
6492 writeln!(code, ") {{")?;
6493 writeln!(code, " // Read request body")?;
6494 writeln!(code, " let mut body = String::new();")?;
6495 writeln!(
6496 code,
6497 " if request.as_reader().read_to_string(&mut body).is_err() {{"
6498 )?;
6499 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
6500 writeln!(code, " .with_status_code(400);")?;
6501 writeln!(code, " request.respond(resp).ok();")?;
6502 writeln!(code, " return;")?;
6503 writeln!(code, " }}")?;
6504 writeln!(code)?;
6505 writeln!(code, " // Parse JSON")?;
6506 writeln!(
6507 code,
6508 " let req: ChatRequest = match serde_json::from_str(&body) {{"
6509 )?;
6510 writeln!(code, " Ok(r) => r,")?;
6511 writeln!(code, " Err(e) => {{")?;
6512 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
6513 writeln!(code, " .with_status_code(400);")?;
6514 writeln!(code, " request.respond(resp).ok();")?;
6515 writeln!(code, " return;")?;
6516 writeln!(code, " }}")?;
6517 writeln!(code, " }};")?;
6518 writeln!(code)?;
6519 writeln!(
6520 code,
6521 " let prompt = format_chat_messages(&req.messages);"
6522 )?;
6523 writeln!(
6524 code,
6525 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
6526 )?;
6527 writeln!(code, " Ok(e) => e,")?;
6528 writeln!(code, " Err(e) => {{")?;
6529 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
6530 writeln!(code, " .with_status_code(500);")?;
6531 writeln!(code, " request.respond(resp).ok();")?;
6532 writeln!(code, " return;")?;
6533 writeln!(code, " }}")?;
6534 writeln!(code, " }};")?;
6535 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
6536 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
6537 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
6538 writeln!(
6539 code,
6540 " let _temperature = req.temperature.unwrap_or(1.0);"
6541 )?;
6542 writeln!(code)?;
6543
6544 writeln!(code, " model.reset();")?;
6546 writeln!(code)?;
6547
6548 writeln!(code, " let prefill_start = Instant::now();")?;
6550 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
6551 writeln!(
6552 code,
6553 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
6554 )?;
6555 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
6556 writeln!(code, " let mut next_token = argmax(&logits);")?;
6557 writeln!(code)?;
6558
6559 writeln!(code, " if stream {{")?;
6560
6561 writeln!(
6563 code,
6564 " // SSE streaming: generate tokens and build SSE body"
6565 )?;
6566 writeln!(code, " let gen_start = Instant::now();")?;
6567 writeln!(code, " let mut generated_count: usize = 0;")?;
6568 writeln!(code, " let mut sse_body = String::new();")?;
6569 writeln!(code, " for _ in 0..max_tokens {{")?;
6570 writeln!(code, " if next_token == 2 {{ break; }}")?;
6571 writeln!(
6572 code,
6573 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6574 )?;
6575 writeln!(
6576 code,
6577 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
6578 )?;
6579 writeln!(
6580 code,
6581 " // escaped includes surrounding quotes, strip them"
6582 )?;
6583 writeln!(
6584 code,
6585 " let inner = &escaped[1..escaped.len()-1];"
6586 )?;
6587 writeln!(code, " sse_body.push_str(&format!(")?;
6588 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
6589 writeln!(code, " inner")?;
6590 writeln!(code, " ));")?;
6591 writeln!(code, " }}")?;
6592 writeln!(code, " generated_count += 1;")?;
6593 writeln!(code, " let logits = model.forward(next_token);")?;
6594 writeln!(code, " next_token = argmax(&logits);")?;
6595 writeln!(code, " }}")?;
6596 writeln!(
6597 code,
6598 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6599 )?;
6600 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6601 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
6602 writeln!(code)?;
6603 writeln!(
6604 code,
6605 " // Final chunk with finish_reason, timing, and DONE sentinel"
6606 )?;
6607 writeln!(code, " sse_body.push_str(&format!(")?;
6608 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
6609 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
6610 writeln!(code, " ));")?;
6611 writeln!(code)?;
6612 writeln!(
6613 code,
6614 " let resp = tiny_http::Response::from_string(sse_body)"
6615 )?;
6616 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
6617 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
6618 writeln!(code, " request.respond(resp).ok();")?;
6619
6620 writeln!(code, " }} else {{")?;
6621
6622 writeln!(
6624 code,
6625 " // Non-streaming: generate all tokens, return JSON"
6626 )?;
6627 writeln!(code, " let gen_start = Instant::now();")?;
6628 writeln!(code, " let mut generated_count: usize = 0;")?;
6629 writeln!(code, " let mut generated = String::new();")?;
6630 writeln!(code, " for _ in 0..max_tokens {{")?;
6631 writeln!(code, " if next_token == 2 {{ break; }}")?;
6632 writeln!(
6633 code,
6634 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
6635 )?;
6636 writeln!(code, " generated.push_str(&text);")?;
6637 writeln!(code, " }}")?;
6638 writeln!(code, " generated_count += 1;")?;
6639 writeln!(code, " let logits = model.forward(next_token);")?;
6640 writeln!(code, " next_token = argmax(&logits);")?;
6641 writeln!(code, " }}")?;
6642 writeln!(
6643 code,
6644 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
6645 )?;
6646 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
6647 writeln!(code)?;
6648 writeln!(code, " let resp_json = serde_json::json!({{")?;
6649 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
6650 writeln!(code, " \"object\": \"chat.completion\",")?;
6651 writeln!(code, " \"choices\": [{{")?;
6652 writeln!(code, " \"index\": 0,")?;
6653 writeln!(code, " \"message\": {{")?;
6654 writeln!(code, " \"role\": \"assistant\",")?;
6655 writeln!(code, " \"content\": generated")?;
6656 writeln!(code, " }},")?;
6657 writeln!(code, " \"finish_reason\": \"stop\"")?;
6658 writeln!(code, " }}],")?;
6659 writeln!(code, " \"usage\": {{")?;
6660 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
6661 writeln!(
6662 code,
6663 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
6664 )?;
6665 writeln!(
6666 code,
6667 " \"generation_tokens\": generated_count,"
6668 )?;
6669 writeln!(
6670 code,
6671 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
6672 )?;
6673 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
6674 writeln!(code, " }}")?;
6675 writeln!(code, " }});")?;
6676 writeln!(
6677 code,
6678 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
6679 )?;
6680 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
6681 writeln!(code, " request.respond(resp).ok();")?;
6682 writeln!(code, " }}")?;
6683 writeln!(code, "}}")?;
6684
6685 Ok(code)
6686}
6687
6688#[cfg(test)]
6693mod tests {
6694 use super::*;
6695 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
6696
6697 fn minimal_config() -> ModelConfig {
6698 ModelConfig {
6699 architecture: Architecture::Llama,
6700 hidden_size: 64,
6701 intermediate_size: 128,
6702 num_layers: 2,
6703 num_attention_heads: 4,
6704 num_kv_heads: 4,
6705 head_dim: 16,
6706 vocab_size: 256,
6707 max_seq_len: 512,
6708 rms_norm_eps: 1e-5,
6709 rope_theta: 10000.0,
6710 dtype: DType::F32,
6711 sliding_window_size: None,
6712 qkv_bias: false,
6713 }
6714 }
6715
6716 fn minimal_graph() -> Graph {
6717 Graph::new("test-metal").with_config(minimal_config())
6718 }
6719
6720 #[test]
6721 fn generate_metal_project_creates_files() {
6722 let dir = tempfile::tempdir().unwrap();
6723 let graph = minimal_graph();
6724 generate_metal_project(&graph, dir.path(), "test-model").unwrap();
6725
6726 assert!(
6727 dir.path().join("Cargo.toml").exists(),
6728 "Cargo.toml should be created"
6729 );
6730 assert!(
6731 dir.path().join("src/model.rs").exists(),
6732 "src/model.rs should be created"
6733 );
6734 assert!(
6735 dir.path().join("src/main.rs").exists(),
6736 "src/main.rs should be created"
6737 );
6738 assert!(
6739 dir.path().join("shaders/kernels.metal").exists(),
6740 "shaders/kernels.metal should be created"
6741 );
6742 }
6743
6744 #[test]
6745 fn generated_cargo_toml_has_metal_dep() {
6746 let toml = generate_cargo_toml("my-model");
6747 assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
6748 assert!(
6749 toml.contains("tokenizers"),
6750 "Cargo.toml should depend on tokenizers"
6751 );
6752 assert!(
6753 toml.contains("memmap2"),
6754 "Cargo.toml should depend on memmap2"
6755 );
6756 assert!(toml.contains("half"), "Cargo.toml should depend on half");
6757 }
6758
6759 #[test]
6760 fn generated_model_rs_contains_metal_code() {
6761 let config = minimal_config();
6762 let model_rs = generate_model_rs(&config).unwrap();
6763
6764 assert!(
6765 model_rs.contains("pub struct MetalModel"),
6766 "model.rs should define MetalModel struct"
6767 );
6768 assert!(
6769 model_rs.contains("matmul_pipeline: ComputePipelineState"),
6770 "MetalModel should have matmul_pipeline field"
6771 );
6772 assert!(
6773 model_rs.contains("Device::system_default()"),
6774 "model.rs should use Metal device"
6775 );
6776 assert!(
6777 model_rs.contains("new_library_with_source"),
6778 "model.rs should compile Metal shaders"
6779 );
6780 assert!(
6781 model_rs.contains("fn new(weights: &[u8])"),
6782 "MetalModel should implement new()"
6783 );
6784 assert!(
6785 model_rs.contains("fn forward(&mut self, token_id: u32)"),
6786 "MetalModel should implement forward()"
6787 );
6788 }
6789
6790 #[test]
6791 fn generated_shaders_contain_kernel_names() {
6792 let shaders = generate_metal_shaders(&minimal_config());
6793
6794 assert!(
6795 shaders.contains("kernel void matmul_vec"),
6796 "shaders should contain matmul_vec kernel"
6797 );
6798 assert!(
6799 shaders.contains("kernel void rms_norm"),
6800 "shaders should contain rms_norm kernel"
6801 );
6802 assert!(
6803 shaders.contains("kernel void rope"),
6804 "shaders should contain rope kernel"
6805 );
6806 assert!(
6807 shaders.contains("kernel void softmax"),
6808 "shaders should contain softmax kernel"
6809 );
6810 assert!(
6811 shaders.contains("kernel void silu_mul("),
6812 "shaders should contain silu_mul kernel"
6813 );
6814 assert!(
6815 shaders.contains("kernel void silu_mul_fused"),
6816 "shaders should contain silu_mul_fused kernel"
6817 );
6818 assert!(
6819 shaders.contains("kernel void elementwise_add"),
6820 "shaders should contain elementwise_add kernel"
6821 );
6822 assert!(
6823 shaders.contains("kernel void attention"),
6824 "shaders should contain attention kernel"
6825 );
6826 assert!(
6827 shaders.contains("kernel void add_inplace"),
6828 "shaders should contain add_inplace kernel"
6829 );
6830 assert!(
6831 shaders.contains("kernel void copy_buffer"),
6832 "shaders should contain copy_buffer kernel"
6833 );
6834 assert!(
6835 shaders.contains("kernel void copy_offset"),
6836 "shaders should contain copy_offset kernel"
6837 );
6838 }
6839
6840 #[test]
6841 fn generated_shaders_use_simdgroup_features() {
6842 let shaders = generate_metal_shaders(&minimal_config());
6843
6844 assert!(
6845 shaders.contains("threadgroup_barrier"),
6846 "shaders should use threadgroup barriers"
6847 );
6848 assert!(
6849 shaders.contains("threadgroup float"),
6850 "shaders should use threadgroup shared memory"
6851 );
6852 assert!(
6853 shaders.contains("thread_index_in_threadgroup"),
6854 "shaders should use threadgroup indexing"
6855 );
6856 assert!(
6857 shaders.contains("simd_sum"),
6858 "shaders should use simd_sum for warp-level reduction"
6859 );
6860 assert!(
6861 shaders.contains("simd_max"),
6862 "attention kernel should use simd_max for cooperative softmax"
6863 );
6864 assert!(
6865 shaders.contains("thread_index_in_simdgroup"),
6866 "shaders should use simdgroup lane indexing"
6867 );
6868 assert!(
6869 shaders.contains("simdgroup_index_in_threadgroup"),
6870 "shaders should use simdgroup indexing within threadgroup"
6871 );
6872 assert!(
6873 shaders.contains("float4"),
6874 "matmul_vec should use float4 vectorized loads"
6875 );
6876 }
6877
6878 #[test]
6879 fn generated_main_rs_has_tokenizer_usage() {
6880 let config = minimal_config();
6881 let main_rs = generate_main_rs("test-model", &config).unwrap();
6882
6883 assert!(
6884 main_rs.contains("tokenizers::Tokenizer"),
6885 "main.rs should use tokenizers crate"
6886 );
6887 assert!(
6888 main_rs.contains("MetalModel::new"),
6889 "main.rs should call MetalModel::new"
6890 );
6891 assert!(
6892 main_rs.contains("model.forward"),
6893 "main.rs should call model.forward"
6894 );
6895 assert!(
6896 main_rs.contains("memmap2"),
6897 "main.rs should use memmap2 for zero-copy weight loading"
6898 );
6899 }
6900
6901 #[test]
6902 fn missing_config_returns_error() {
6903 let dir = tempfile::tempdir().unwrap();
6904 let graph = Graph::new("no-config");
6905 let result = generate_metal_project(&graph, dir.path(), "fail");
6906 assert!(
6907 matches!(result, Err(MetalCodegenError::MissingConfig)),
6908 "should fail with MissingConfig when graph has no config"
6909 );
6910 }
6911
6912 #[test]
6913 fn sanitize_name_works() {
6914 assert_eq!(sanitize_name("My Model!"), "my-model");
6915 assert_eq!(sanitize_name("test_model"), "test-model");
6916 assert_eq!(sanitize_name("simple"), "simple");
6917 }
6918
6919 #[test]
6920 fn generated_forward_uses_single_command_buffer() {
6921 let config = minimal_config();
6922 let model_rs = generate_model_rs(&config).unwrap();
6923
6924 let forward_start = model_rs
6927 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6928 .unwrap();
6929 let forward_body = &model_rs[forward_start..];
6930 let forward_end = forward_body
6932 .find("\n pub fn forward_profile")
6933 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
6934 .or_else(|| forward_body.find("\n fn dispatch_"))
6935 .unwrap_or(forward_body.len());
6936 let forward_code = &forward_body[..forward_end];
6937
6938 let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
6940 assert_eq!(
6941 cmd_buf_count, 1,
6942 "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
6943 );
6944
6945 let commit_count = forward_code.matches("cmd.commit()").count();
6947 assert_eq!(
6948 commit_count, 1,
6949 "forward() should commit exactly once, found {commit_count}"
6950 );
6951
6952 let wait_count = forward_code.matches("wait_until_completed()").count();
6954 assert!(
6955 wait_count >= 1 && wait_count <= 2,
6956 "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
6957 );
6958 }
6959
6960 #[test]
6961 fn generated_model_has_preallocated_working_buffers() {
6962 let config = minimal_config();
6963 let model_rs = generate_model_rs(&config).unwrap();
6964
6965 for buf_name in &[
6966 "normed_buf",
6967 "qkv_buf",
6968 "attn_out_buf",
6969 "attn_proj_buf",
6970 "gate_up_buf",
6971 "ffn_hidden_buf",
6972 "ffn_out_buf",
6973 "add_tmp_buf",
6974 ] {
6975 assert!(
6976 model_rs.contains(&format!("{buf_name}: Buffer")),
6977 "MetalModel should have pre-allocated {buf_name} field"
6978 );
6979 }
6980 }
6981
6982 #[test]
6983 fn generated_dispatch_helpers_take_compute_encoder_ref() {
6984 let config = minimal_config();
6985 let model_rs = generate_model_rs(&config).unwrap();
6986
6987 for method in &[
6988 "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
6989 "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
6990 "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
6991 "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
6992 "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
6993 "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
6994 "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
6995 "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
6996 "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
6997 "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
6998 "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
6999 "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
7000 "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
7001 ] {
7002 assert!(
7003 model_rs.contains(method),
7004 "model.rs should contain dispatch helper: {method}"
7005 );
7006 }
7007 }
7008
7009 #[test]
7010 fn generated_helpers_do_not_create_command_buffers_or_encoders() {
7011 let config = minimal_config();
7012 let model_rs = generate_model_rs(&config).unwrap();
7013
7014 let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
7016 let helpers_code = &model_rs[helpers_start..];
7017
7018 assert!(
7020 !helpers_code.contains("self.queue.new_command_buffer()"),
7021 "dispatch helpers should not create their own command buffers"
7022 );
7023
7024 assert!(
7026 !helpers_code.contains("new_compute_command_encoder()"),
7027 "dispatch helpers should not create their own compute encoders"
7028 );
7029
7030 assert!(
7032 !helpers_code.contains("end_encoding()"),
7033 "dispatch helpers should not call end_encoding"
7034 );
7035
7036 assert!(
7038 !helpers_code.contains(".commit()"),
7039 "dispatch helpers should not commit command buffers"
7040 );
7041 assert!(
7042 !helpers_code.contains("wait_until_completed"),
7043 "dispatch helpers should not wait on command buffers"
7044 );
7045 }
7046
7047 #[test]
7048 fn generated_forward_batches_compute_encoders() {
7049 let config = minimal_config();
7050 let model_rs = generate_model_rs(&config).unwrap();
7051
7052 let forward_start = model_rs
7054 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7055 .unwrap();
7056 let forward_body = &model_rs[forward_start..];
7057 let forward_end = forward_body
7058 .find("\n pub fn forward_profile")
7059 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7060 .or_else(|| forward_body.find("\n fn dispatch_"))
7061 .unwrap_or(forward_body.len());
7062 let forward_code = &forward_body[..forward_end];
7063
7064 assert!(
7066 !forward_code.contains("device.new_buffer"),
7067 "forward() should not allocate new buffers per call"
7068 );
7069
7070 let compute_encoder_count = forward_code
7073 .matches("new_compute_command_encoder()")
7074 .count();
7075 let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
7076
7077 assert_eq!(
7079 compute_encoder_count, 1,
7080 "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
7081 );
7082 assert_eq!(
7083 blit_encoder_count, 0,
7084 "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
7085 );
7086 }
7087
#[test]
fn generated_forward_uses_add_inplace() {
    // Residual connections are expected to be accumulated in place.
    let generated = generate_model_rs(&minimal_config()).unwrap();

    for (needle, why) in [
        (
            "dispatch_add_inplace",
            "forward() should use dispatch_add_inplace for residual connections",
        ),
        (
            "add_inplace_pipeline",
            "MetalModel should have add_inplace_pipeline",
        ),
    ] {
        assert!(generated.contains(needle), "{why}");
    }
}
7103
7104 fn minimal_q8_config() -> ModelConfig {
7105 ModelConfig {
7106 architecture: Architecture::Llama,
7107 hidden_size: 64,
7108 intermediate_size: 128,
7109 num_layers: 2,
7110 num_attention_heads: 4,
7111 num_kv_heads: 4,
7112 head_dim: 16,
7113 vocab_size: 256,
7114 max_seq_len: 512,
7115 rms_norm_eps: 1e-5,
7116 rope_theta: 10000.0,
7117 dtype: DType::Q8_0,
7118 sliding_window_size: None,
7119 qkv_bias: false,
7120 }
7121 }
7122
#[test]
fn generated_shaders_contain_q8_kernel() {
    let metal_src = generate_metal_shaders(&minimal_config());

    // (substring that must appear, failure message) in check order.
    let expectations = [
        (
            "kernel void matmul_vec_q8",
            "shaders should contain matmul_vec_q8 kernel",
        ),
        (
            "device const uchar* matrix",
            "matmul_vec_q8 should accept raw Q8_0 bytes",
        ),
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "device const half*",
            "matmul_vec_q8 should read f16 scale via half pointer",
        ),
    ];
    for (needle, why) in expectations {
        assert!(metal_src.contains(needle), "{why}");
    }
}
7148
#[test]
fn generated_model_uses_fused_qkv_projections() {
    let model_src = generate_model_rs(&minimal_config()).unwrap();

    // (should_be_present, substring, failure message), checked in order:
    // fused weight fields replace the separate per-projection ones, and the
    // forward pass uses the fused/offset dispatch helpers.
    let checks = [
        (true, "qkv_weight: Buffer", "LayerBuffers should have fused qkv_weight field"),
        (false, " q_weight: Buffer", "LayerBuffers should not have separate q_weight field"),
        (false, " k_weight: Buffer", "LayerBuffers should not have separate k_weight field"),
        (false, " v_weight: Buffer", "LayerBuffers should not have separate v_weight field"),
        (true, "gate_up_weight: Buffer", "LayerBuffers should have fused gate_up_weight field"),
        (false, " gate_weight: Buffer", "LayerBuffers should not have separate gate_weight field"),
        (false, " up_weight: Buffer", "LayerBuffers should not have separate up_weight field"),
        (true, "qkv_buf: Buffer", "MetalModel should have fused qkv_buf"),
        (true, "gate_up_buf: Buffer", "MetalModel should have fused gate_up_buf"),
        (true, "dispatch_silu_mul_fused", "forward pass should use dispatch_silu_mul_fused"),
        (
            true,
            "dispatch_rope_offset",
            "forward pass should use dispatch_rope_offset for fused QKV",
        ),
        (
            true,
            "dispatch_attention_offset",
            "forward pass should use dispatch_attention_offset for fused QKV",
        ),
    ];

    for (expected, needle, why) in checks {
        assert!(model_src.contains(needle) == expected, "{why}");
    }
}
7212
#[test]
fn q8_model_has_matmul_q8_pipeline() {
    let src = generate_model_rs(&minimal_q8_config()).unwrap();
    let has = |needle: &str| src.contains(needle);

    // Declared as a struct field ...
    assert!(
        has("matmul_q8_pipeline: ComputePipelineState"),
        "MetalModel should have matmul_q8_pipeline field"
    );
    // ... and moved into the constructed `Self { .. }`.
    assert!(
        has("matmul_q8_pipeline,"),
        "MetalModel Self should include matmul_q8_pipeline"
    );
}
7227
#[test]
fn q8_model_uses_dispatch_matmul_q8() {
    let src = generate_model_rs(&minimal_q8_config()).unwrap();

    // Q8_0 projections must go through the dedicated dispatch helper, which
    // must also be defined in the generated file.
    for (needle, why) in [
        (
            "dispatch_matmul_q8",
            "Q8_0 model should use dispatch_matmul_q8 for projections",
        ),
        (
            "fn dispatch_matmul_q8",
            "model.rs should define dispatch_matmul_q8 method",
        ),
    ] {
        assert!(src.contains(needle), "{why}");
    }
}
7242
#[test]
fn q8_model_loads_raw_bytes_not_dequantized() {
    let src = generate_model_rs(&minimal_q8_config()).unwrap();

    // (should_be_present, substring, failure message) in check order.
    let checks = [
        (false, "f16_to_f32", "Q8_0 model should not dequantize weights to f32"),
        (false, "f32_data", "Q8_0 model should not create f32 weight data"),
        (true, "total_raw as u64", "Q8_0 model should load raw bytes into Metal buffer"),
    ];
    for (expected, needle, why) in checks {
        assert!(src.contains(needle) == expected, "{why}");
    }
}
7264
#[test]
fn q8_model_norms_stay_f32() {
    let src = generate_model_rs(&minimal_q8_config()).unwrap();

    // (generated binding name, human label used in the failure message).
    // Norm weights stay f32 even when projections are quantized.
    for (binding, label) in [
        ("attn_norm", "attn_norm"),
        ("ffn_norm", "ffn_norm"),
        ("norm_buf", "final norm"),
    ] {
        let needle = format!("let {binding} = next_f32_buffer");
        assert!(
            src.contains(&needle),
            "{label} should use f32 buffer even for Q8_0 models"
        );
    }
}
7284
#[test]
fn q8_model_uses_fused_weight_loading() {
    let src = generate_model_rs(&minimal_q8_config()).unwrap();

    // Fused projections load via the fused Q8 loader; the rest use the plain one.
    let expectations = [
        (
            "let qkv_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights",
        ),
        (
            "let gate_up_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights",
        ),
        (
            "let o_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for O weight",
        ),
        (
            "let down_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for down weight",
        ),
    ];
    for (needle, why) in expectations {
        assert!(src.contains(needle), "{why}");
    }
}
7310
#[test]
fn f32_model_does_not_use_q8_dispatch() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let start = src
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &src[start..];
    // Stop at the first dispatch helper definition, or take the rest of the file.
    let forward_code = match tail.find("\n fn dispatch_") {
        Some(end) => &tail[..end],
        None => tail,
    };

    assert!(
        !forward_code.contains("dispatch_matmul_q8"),
        "f32 model forward should not use dispatch_matmul_q8"
    );
}
7331
#[test]
fn q8_dispatch_helper_takes_compute_encoder_ref() {
    // The helper records into a caller-owned encoder instead of creating one.
    let expected_signature = "fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef";
    let src = generate_model_rs(&minimal_q8_config()).unwrap();

    assert!(
        src.contains(expected_signature),
        "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
    );
}
7342
#[test]
fn generated_model_has_double_buffered_prefill() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    // Field, entry point, and the drain in forward(), in that order.
    let expectations = [
        (
            "prev_cmd: Option<CommandBuffer>",
            "MetalModel should have prev_cmd field for double-buffered prefill",
        ),
        (
            "pub fn forward_prefill(&mut self, token_id: u32)",
            "MetalModel should have forward_prefill method",
        ),
        (
            "if let Some(prev) = self.prev_cmd.take()",
            "forward() should drain prev_cmd from previous prefill",
        ),
    ];
    for (needle, why) in expectations {
        assert!(src.contains(needle), "{why}");
    }
}
7366
#[test]
fn generated_main_rs_uses_forward_prefill_for_prompt() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    for (needle, why) in [
        (
            "forward_prefill",
            "main.rs should use forward_prefill for intermediate prompt tokens",
        ),
        (
            "double-buffered",
            "main.rs should document double-buffered prefill",
        ),
    ] {
        assert!(main_src.contains(needle), "{why}");
    }
}
7381
#[test]
fn generated_shaders_q8_uses_wide_vectorized_loads() {
    let metal_src = generate_metal_shaders(&minimal_config());

    // Markers of the vectorized int8 inner loop in matmul_vec_q8.
    let expectations = [
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads",
        ),
        (
            "d0[0]",
            "matmul_vec_q8 should index the wide pointer for row 0",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "dot(",
            "matmul_vec_q8 should use dot() intrinsic for fma accumulation",
        ),
    ];
    for (needle, why) in expectations {
        assert!(metal_src.contains(needle), "{why}");
    }
}
7403
7404 fn minimal_q4_config() -> ModelConfig {
7407 ModelConfig {
7408 architecture: Architecture::Llama,
7409 hidden_size: 64,
7410 intermediate_size: 128,
7411 num_layers: 2,
7412 num_attention_heads: 4,
7413 num_kv_heads: 4,
7414 head_dim: 16,
7415 vocab_size: 256,
7416 max_seq_len: 512,
7417 rms_norm_eps: 1e-5,
7418 rope_theta: 10000.0,
7419 dtype: DType::Q4_0,
7420 sliding_window_size: None,
7421 qkv_bias: false,
7422 }
7423 }
7424
#[test]
fn generated_shaders_contain_q4_kernel() {
    let metal_src = generate_metal_shaders(&minimal_config());

    // The Q4 kernel and its row-tiling constants.
    let expectations = [
        (
            "kernel void matmul_vec_q4",
            "shaders should contain matmul_vec_q4 kernel",
        ),
        (
            "Q4_ROWS_PER_TG",
            "shaders should define Q4_ROWS_PER_TG constant",
        ),
        (
            "Q4_ROWS_PER_SG",
            "shaders should define Q4_ROWS_PER_SG constant",
        ),
    ];
    for (needle, why) in expectations {
        assert!(metal_src.contains(needle), "{why}");
    }
}
7442
#[test]
fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
    let metal_src = generate_metal_shaders(&minimal_config());

    // Textual fingerprints of Q4_0 nibble decoding in matmul_vec_q4.
    let expectations = [
        ("uchar4", "matmul_vec_q4 should use uchar4 for packed byte loads"),
        ("&0xF", "matmul_vec_q4 should extract low nibble with &0xF"),
        (">>4", "matmul_vec_q4 should extract high nibble with >>4"),
        (
            "-8)",
            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion",
        ),
        ("blk * 18", "matmul_vec_q4 should use 18-byte block stride"),
    ];
    for (needle, why) in expectations {
        assert!(metal_src.contains(needle), "{why}");
    }
}
7472
#[test]
fn q4_model_has_matmul_q4_pipeline() {
    let src = generate_model_rs(&minimal_q4_config()).unwrap();
    let has = |needle: &str| src.contains(needle);

    // Declared as a struct field ...
    assert!(
        has("matmul_q4_pipeline: ComputePipelineState"),
        "MetalModel should have matmul_q4_pipeline field"
    );
    // ... and moved into the constructed `Self { .. }`.
    assert!(
        has("matmul_q4_pipeline,"),
        "MetalModel Self should include matmul_q4_pipeline"
    );
}
7487
#[test]
fn q4_model_uses_dispatch_matmul_q4() {
    let src = generate_model_rs(&minimal_q4_config()).unwrap();

    // Q4_0 projections must go through the dedicated dispatch helper, which
    // must also be defined in the generated file.
    for (needle, why) in [
        (
            "dispatch_matmul_q4",
            "Q4_0 model should use dispatch_matmul_q4 for projections",
        ),
        (
            "fn dispatch_matmul_q4",
            "model.rs should define dispatch_matmul_q4 method",
        ),
    ] {
        assert!(src.contains(needle), "{why}");
    }
}
7502
#[test]
fn q4_model_loads_raw_bytes_not_dequantized() {
    let src = generate_model_rs(&minimal_q4_config()).unwrap();

    // (should_be_present, substring, failure message) in check order.
    let checks = [
        (false, "f16_to_f32", "Q4_0 model should not dequantize weights to f32"),
        (true, "total_raw as u64", "Q4_0 model should load raw bytes into Metal buffer"),
    ];
    for (expected, needle, why) in checks {
        assert!(src.contains(needle) == expected, "{why}");
    }
}
7520
#[test]
fn q4_model_norms_stay_f32() {
    let src = generate_model_rs(&minimal_q4_config()).unwrap();

    // (generated binding name, human label used in the failure message).
    // Norm weights stay f32 even when projections are quantized.
    for (binding, label) in [
        ("attn_norm", "attn_norm"),
        ("ffn_norm", "ffn_norm"),
        ("norm_buf", "final norm"),
    ] {
        let needle = format!("let {binding} = next_f32_buffer");
        assert!(
            src.contains(&needle),
            "{label} should use f32 buffer even for Q4_0 models"
        );
    }
}
7539
#[test]
fn q4_model_uses_fused_weight_loading() {
    let src = generate_model_rs(&minimal_q4_config()).unwrap();

    // Fused projections load via the fused Q4 loader; the rest use the plain one.
    let expectations = [
        (
            "let qkv_weight = next_q4_fused_buffer",
            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights",
        ),
        (
            "let gate_up_weight = next_q4_fused_buffer",
            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights",
        ),
        (
            "let o_weight = next_q4_buffer",
            "Q4_0 model should use next_q4_buffer for O weight",
        ),
        (
            "let down_weight = next_q4_buffer",
            "Q4_0 model should use next_q4_buffer for down weight",
        ),
    ];
    for (needle, why) in expectations {
        assert!(src.contains(needle), "{why}");
    }
}
7562
#[test]
fn attention_flash_batch_kernel_wired() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let shaders = generate_metal_shaders(&config);

    // (haystack, substring, failure message): the kernel must exist in the
    // shader source and be registered and selected on the Rust side.
    let checks: [(&str, &str, &str); 5] = [
        (
            shaders.as_str(),
            "kernel void attention_flash_batch",
            "shaders.metal must contain the attention_flash_batch kernel",
        ),
        (
            shaders.as_str(),
            "FLASH_K_TILE",
            "flash kernel must tile K/V with a FLASH_K_TILE constant",
        ),
        (
            model_rs.as_str(),
            "attention_flash_batch_pipeline",
            "MetalModel must register the flash attention pipeline",
        ),
        (
            model_rs.as_str(),
            "max_seq > 2000",
            "dispatch_attention_batch must branch on max_seq > 2000",
        ),
        (
            model_rs.as_str(),
            "&self.attention_flash_batch_pipeline",
            "dispatch must reference the flash pipeline",
        ),
    ];
    for (haystack, needle, why) in checks {
        assert!(haystack.contains(needle), "{why}");
    }
}
7594
#[test]
fn qwen2_qkv_bias_wired_through_metal_codegen() {
    let config = ModelConfig {
        architecture: Architecture::Qwen2,
        qkv_bias: true,
        ..minimal_config()
    };
    let model_src = generate_model_rs(&config).unwrap();

    // Bias must be declared, loaded, pipelined, and applied on both the
    // batched and the single-token paths.
    let model_expectations = [
        ("qkv_bias: Buffer", "Qwen2 LayerBuffers must declare qkv_bias field"),
        (
            "let qkv_bias = next_f32_buffer",
            "Qwen2 layer init must load the bias from the weight blob",
        ),
        (
            "add_bias_batch_pipeline",
            "Qwen2 model struct must include the add_bias_batch_pipeline",
        ),
        (
            "fn dispatch_add_bias_batch",
            "Qwen2 codegen must emit dispatch_add_bias_batch helper",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.batch_qkv_buf",
            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.qkv_buf",
            "forward must call dispatch_add_bias_batch on the single-token qkv_buf",
        ),
    ];
    for (needle, why) in model_expectations {
        assert!(model_src.contains(needle), "{why}");
    }

    let metal_src = generate_metal_shaders(&config);
    assert!(
        metal_src.contains("kernel void add_bias_batch"),
        "shaders.metal must contain the add_bias_batch kernel"
    );
}
7639
#[test]
fn llama_does_not_emit_qkv_bias_machinery() {
    let config = minimal_config();
    // Precondition: the baseline config must not request QKV bias.
    assert!(!config.qkv_bias);
    let model_src = generate_model_rs(&config).unwrap();

    let forbidden = [
        ("qkv_bias: Buffer", "Llama must not have qkv_bias field"),
        (
            "add_bias_batch_pipeline",
            "Llama must not pull in add_bias_batch_pipeline",
        ),
        (
            "dispatch_add_bias_batch",
            "Llama must not call dispatch_add_bias_batch",
        ),
    ];
    for (needle, why) in forbidden {
        assert!(!model_src.contains(needle), "{why}");
    }
}
7660
#[test]
fn q4_dispatch_helper_takes_compute_encoder_ref() {
    // The helper records into a caller-owned encoder instead of creating one.
    let expected_signature = "fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef";
    let src = generate_model_rs(&minimal_q4_config()).unwrap();

    assert!(
        src.contains(expected_signature),
        "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
    );
}
7671
#[test]
fn f32_model_does_not_use_q4_dispatch() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let start = src
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &src[start..];
    // Everything up to the first dispatch helper definition (or end of file).
    let forward_code = tail
        .find("\n fn dispatch_")
        .map_or(tail, |end| &tail[..end]);

    assert!(
        !forward_code.contains("dispatch_matmul_q4"),
        "f32 model forward should not use dispatch_matmul_q4"
    );
}
7691
#[test]
fn q4_model_lm_head_uses_q4_buffer() {
    // The output projection is loaded with the Q4 loader like other projections.
    let expected_load = "let lm_head_buf = next_q4_buffer";
    let src = generate_model_rs(&minimal_q4_config()).unwrap();

    assert!(
        src.contains(expected_load),
        "Q4_0 model should use next_q4_buffer for lm_head"
    );
}
7702
#[test]
fn vec_tile_size_matches_model_dimensions() {
    // Small model: tile sized to max(64, 128).
    let small_src = generate_metal_shaders(&minimal_config());
    assert!(
        small_src.contains("vec_tile[128]"),
        "vec_tile should be sized to max(hidden, intermediate) = 128"
    );

    // Larger model: the tile must grow with the intermediate size rather than
    // being pinned to a fixed constant.
    let large = ModelConfig {
        hidden_size: 2048,
        intermediate_size: 8192,
        ..minimal_config()
    };
    let large_src = generate_metal_shaders(&large);
    assert!(
        large_src.contains("vec_tile[8192]"),
        "vec_tile should be 8192 for models with intermediate=8192"
    );
    assert!(
        !large_src.contains("vec_tile[4096]"),
        "vec_tile should NOT be hardcoded to 4096"
    );
}
7727
#[test]
fn generated_cargo_toml_has_server_deps() {
    let manifest = generate_cargo_toml("my-model");

    // Crates needed by the generated HTTP serve mode.
    for (krate, why) in [
        ("tiny_http", "Cargo.toml should depend on tiny_http"),
        ("serde", "Cargo.toml should depend on serde"),
        ("serde_json", "Cargo.toml should depend on serde_json"),
    ] {
        assert!(manifest.contains(krate), "{why}");
    }
}
7741
#[test]
fn generated_main_rs_has_serve_mode() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    // CLI flags plus the server entry point.
    let expectations = [
        ("--serve", "main.rs should parse --serve flag"),
        ("--port", "main.rs should parse --port flag"),
        ("fn serve(", "main.rs should define serve function"),
        ("tiny_http::Server::http", "main.rs should create tiny_http server"),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
7764
#[test]
fn generated_main_rs_has_chat_completions_endpoint() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    // Routes exposed by the generated server.
    let routes = [
        (
            "/v1/chat/completions",
            "main.rs should handle /v1/chat/completions endpoint",
        ),
        ("/v1/models", "main.rs should handle /v1/models endpoint"),
        ("/health", "main.rs should handle /health endpoint"),
    ];
    for (route, why) in routes {
        assert!(main_src.contains(route), "{why}");
    }
}
7783
#[test]
fn generated_main_rs_has_sse_streaming() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    // Content type, chunk object name, and terminating sentinel of the SSE stream.
    let expectations = [
        (
            "text/event-stream",
            "main.rs should set SSE content type for streaming",
        ),
        ("chat.completion.chunk", "main.rs should emit SSE chunks"),
        ("[DONE]", "main.rs should emit [DONE] sentinel"),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
7802
#[test]
fn generated_main_rs_has_chat_message_formatting() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    // The formatter function and both ChatML delimiters.
    let expectations = [
        (
            "fn format_chat_messages",
            "main.rs should define format_chat_messages function",
        ),
        ("<|im_start|>", "main.rs should use ChatML format"),
        ("<|im_end|>", "main.rs should use ChatML format"),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
7821
#[test]
fn generated_main_rs_has_request_types() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    // Request payload types and their serde derive.
    let expectations = [
        ("struct ChatRequest", "main.rs should define ChatRequest struct"),
        ("struct ChatMessage", "main.rs should define ChatMessage struct"),
        (
            "Deserialize",
            "main.rs should derive Deserialize for request types",
        ),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
7840
#[test]
fn generated_model_has_reset_method() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    // reset() lets one loaded model serve many requests.
    for (needle, why) in [
        (
            "pub fn reset(&mut self)",
            "model.rs should have a reset() method for multi-request serving",
        ),
        ("self.pos = 0", "reset() should reset position to 0"),
    ] {
        assert!(src.contains(needle), "{why}");
    }
}
7855
#[test]
fn generated_main_rs_cli_mode_still_works() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    // The interactive CLI path must survive alongside serve mode.
    let expectations = [
        ("fn cli_mode(", "main.rs should define cli_mode function"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "model.forward_prefill",
            "main.rs should call model.forward_prefill",
        ),
    ];
    for (needle, why) in expectations {
        assert!(main_src.contains(needle), "{why}");
    }
}
7875
#[test]
fn generated_shaders_contain_batch_kernels() {
    let metal_src = generate_metal_shaders(&minimal_config());

    // (kernel name, failure message) for every batched-prefill kernel.
    let kernels = [
        (
            "matmul_vec_batch",
            "shaders should contain matmul_vec_batch kernel",
        ),
        (
            "matmul_vec_q8_batch",
            "shaders should contain matmul_vec_q8_batch kernel",
        ),
        (
            "matmul_q8_gemm_batch",
            "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)",
        ),
        (
            "matmul_vec_q4_batch",
            "shaders should contain matmul_vec_q4_batch kernel",
        ),
        (
            "rms_norm_batch",
            "shaders should contain rms_norm_batch kernel",
        ),
        (
            "silu_mul_fused_batch",
            "shaders should contain silu_mul_fused_batch kernel",
        ),
        (
            "add_inplace_batch",
            "shaders should contain add_inplace_batch kernel",
        ),
        (
            "copy_embedding_batch",
            "shaders should contain copy_embedding_batch kernel",
        ),
    ];
    for (kernel, why) in kernels {
        let declaration = format!("kernel void {kernel}");
        assert!(metal_src.contains(&declaration), "{why}");
    }
}
7915
#[test]
fn generated_model_has_batch_pipelines() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    const BATCH_PIPELINES: [&str; 11] = [
        "matmul_batch_pipeline",
        "matmul_q8_batch_pipeline",
        "matmul_q8_gemm_batch_pipeline",
        "matmul_q4_batch_pipeline",
        "rms_norm_batch_pipeline",
        "rope_batch_pipeline",
        "silu_mul_fused_batch_pipeline",
        "add_inplace_batch_pipeline",
        "copy_embedding_batch_pipeline",
        "attention_batch_pipeline",
        "copy_kv_batch_pipeline",
    ];

    // Each pipeline must be declared as a typed field on MetalModel.
    for pipeline in BATCH_PIPELINES {
        let field_decl = format!("{pipeline}: ComputePipelineState");
        assert!(
            src.contains(field_decl.as_str()),
            "MetalModel should have {pipeline} field"
        );
    }
}
7940
#[test]
fn generated_model_has_batch_buffers() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    const BATCH_BUFFERS: [&str; 10] = [
        "batch_hidden_buf",
        "batch_residual_buf",
        "batch_qkv_buf",
        "batch_attn_out_buf",
        "batch_attn_proj_buf",
        "batch_gate_up_buf",
        "batch_ffn_hidden_buf",
        "batch_ffn_out_buf",
        "batch_tokens_buf",
        "batch_positions_buf",
    ];

    // Each scratch buffer must be declared as a Buffer field on MetalModel.
    for buf in BATCH_BUFFERS {
        let field_decl = format!("{buf}: Buffer");
        assert!(
            src.contains(field_decl.as_str()),
            "MetalModel should have {buf} field"
        );
    }
}
7964
#[test]
fn generated_model_has_forward_prefill_batch() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    // The batched entry point exists and the single-token prefill is a thin
    // wrapper over it.
    for (needle, why) in [
        (
            "pub fn forward_prefill_batch(&mut self, tokens: &[u32])",
            "MetalModel should have forward_prefill_batch method",
        ),
        (
            "self.forward_prefill_batch(&[token_id])",
            "forward_prefill should delegate to forward_prefill_batch",
        ),
    ] {
        assert!(src.contains(needle), "{why}");
    }
}
7981
#[test]
fn generated_model_has_max_batch_size_constant() {
    let expected_const = "pub const MAX_BATCH_SIZE: usize = 512";
    let src = generate_model_rs(&minimal_config()).unwrap();

    assert!(
        src.contains(expected_const),
        "model.rs should define MAX_BATCH_SIZE constant"
    );
}
7992
#[test]
fn forward_prefill_batch_uses_batch_dispatch() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let start = src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &src[start..];
    // The batched method runs until reset() (or the end of the file).
    let batch_code = match tail.find("\n pub fn reset") {
        Some(end) => &tail[..end],
        None => tail,
    };

    // Every per-token op in the batched path must use its *_batch dispatcher.
    for helper in [
        "dispatch_rms_norm_batch",
        "dispatch_copy_embedding_batch",
        "dispatch_silu_mul_fused_batch",
        "dispatch_attention_batch",
        "dispatch_copy_kv_both_batch",
        "dispatch_rope_qk_batch",
    ] {
        assert!(
            batch_code.contains(helper),
            "forward_prefill_batch should use {helper}"
        );
    }
}
8036
#[test]
fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
    let src = generate_model_rs(&minimal_q8_config()).unwrap();

    let start = src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &src[start..];
    // Batched prefill body: everything up to reset(), or the rest of the file.
    let batch_code = tail
        .find("\n pub fn reset")
        .map_or(tail, |end| &tail[..end]);

    assert!(
        batch_code.contains("dispatch_matmul_q8_batch"),
        "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
    );
}
8056
#[test]
fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
    let src = generate_model_rs(&minimal_q4_config()).unwrap();

    let start = src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &src[start..];
    // Batched prefill body: everything up to reset(), or the rest of the file.
    let batch_code = tail
        .find("\n pub fn reset")
        .map_or(tail, |end| &tail[..end]);

    assert!(
        batch_code.contains("dispatch_matmul_q4_batch"),
        "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
    );
}
8076
#[test]
fn generated_main_rs_uses_batched_prefill() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();
    let uses_batched_prefill = main_src.contains("forward_prefill_batch");

    assert!(
        uses_batched_prefill,
        "main.rs should use forward_prefill_batch for prompt tokens"
    );
}
8087
#[test]
fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    let start = src
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let tail = &src[start..];
    // Batched prefill body: everything up to reset(), or the rest of the file.
    let batch_code = match tail.find("\n pub fn reset") {
        Some(end) => &tail[..end],
        None => tail,
    };

    // (should_be_present, dispatcher, failure message) in check order:
    // the f32 matmul is used and neither quantized variant leaks in.
    let checks = [
        (
            true,
            "dispatch_matmul_batch",
            "f32 forward_prefill_batch should use dispatch_matmul_batch",
        ),
        (
            false,
            "dispatch_matmul_q8_batch",
            "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch",
        ),
        (
            false,
            "dispatch_matmul_q4_batch",
            "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch",
        ),
    ];
    for (expected, needle, why) in checks {
        assert!(batch_code.contains(needle) == expected, "{why}");
    }
}
8116
#[test]
fn forward_uses_cpu_embedding_lookup() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("generated model.rs must define forward()");
    let forward_body = &model_rs[forward_start..];
    // forward() ends at the NEAREST following method definition. Taking the
    // minimum over all markers (rather than the first marker that exists
    // anywhere later) prevents forward_profile/forward_prefill/dispatch helpers
    // from being mistaken for part of forward() when emission order changes.
    let forward_end = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .filter_map(|&marker| forward_body.find(marker))
    .min()
    .unwrap_or(forward_body.len());
    let forward_code = &forward_body[..forward_end];

    // The embedding row is copied on the CPU through the shared buffer's
    // contents pointer, not via a GPU copy kernel.
    assert!(
        forward_code.contains("embed_buf.contents()"),
        "forward() should access embed_buf via CPU unified memory for embedding lookup"
    );
    assert!(
        forward_code.contains("copy_nonoverlapping"),
        "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
    );
    assert!(
        !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
        "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
    );
}
8148
#[test]
fn forward_profile_method_exists() {
    let src = generate_model_rs(&minimal_config()).unwrap();

    // The profiling entry point plus the timing labels it prints.
    let expectations = [
        (
            "pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>",
            "MetalModel should have forward_profile() method",
        ),
        (
            "[profile]",
            "forward_profile() should print timing with [profile] prefix",
        ),
        ("d_embed", "forward_profile() should measure embedding time"),
        ("d_layers", "forward_profile() should measure layer time"),
        ("d_logits", "forward_profile() should measure logits time"),
    ];
    for (needle, why) in expectations {
        assert!(src.contains(needle), "{why}");
    }
}
8176
#[test]
fn generated_cli_has_profile_flag() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();

    for (needle, why) in [
        ("--profile", "CLI should support --profile flag"),
        (
            "forward_profile",
            "CLI should call forward_profile when --profile is set",
        ),
    ] {
        assert!(main_src.contains(needle), "{why}");
    }
}
8191
#[test]
fn generated_cli_has_thermal_yield() {
    let main_src = generate_main_rs("test-model", &minimal_config()).unwrap();
    let yields_between_tokens = main_src.contains("yield_now()");

    assert!(
        yields_between_tokens,
        "CLI generation loop should include thread::yield_now() for thermal management"
    );
}
8202
#[test]
fn generated_forward_handles_single_token_prompt() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();

    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");
    // Only the opening of forward() is inspected. Clamp the 400-byte window to
    // the end of the generated source and back it off to a char boundary, so a
    // short or non-ASCII file cannot make the slice itself panic (which would
    // report a misleading index error instead of the real assertion).
    let mut window_end = (forward_start + 400).min(model_rs.len());
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let forward_body = &model_rs[forward_start..window_end];

    // A leading `assert!(self.pos > 0 ...)` guard would reject the first
    // token of a sequence that was never prefilled.
    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );

    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
8231
8232 #[test]
8233 fn generated_reset_clears_kv_cache_position() {
8234 let config = minimal_config();
8237 let model_rs = generate_model_rs(&config).unwrap();
8238
8239 let reset_start = model_rs
8240 .find("pub fn reset(&mut self)")
8241 .expect("reset() must exist");
8242 let reset_body = &model_rs[reset_start..reset_start + 200];
8243
8244 assert!(
8246 reset_body.contains("self.pos = 0"),
8247 "reset() must set self.pos = 0"
8248 );
8249
8250 assert!(
8252 reset_body.contains("self.prev_cmd = None"),
8253 "reset() should clear prev_cmd for clean command buffer state"
8254 );
8255 }
8256
8257 #[test]
8258 fn generated_serve_handles_empty_messages_gracefully() {
8259 let config = minimal_config();
8262 let main_rs = generate_main_rs("test-model", &config).unwrap();
8263
8264 let format_fn_start = main_rs
8266 .find("fn format_chat_messages")
8267 .expect("format_chat_messages must exist");
8268 let format_fn_body =
8269 &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
8270
8271 assert!(
8273 format_fn_body.contains("for msg in messages"),
8274 "format_chat_messages should iterate over the messages slice"
8275 );
8276 assert!(
8278 format_fn_body.contains("<|im_start|>assistant"),
8279 "format_chat_messages should always append assistant prompt header"
8280 );
8281
8282 let serve_fn_start = main_rs
8284 .find("fn serve(")
8285 .expect("serve function must exist");
8286 let serve_fn_body = &main_rs[serve_fn_start..];
8287 assert!(
8288 serve_fn_body.contains("model.reset()"),
8289 "serve function should reset model between requests"
8290 );
8291 }
8292
8293 #[test]
8294 fn generated_model_forward_increments_pos() {
8295 let config = minimal_config();
8298 let model_rs = generate_model_rs(&config).unwrap();
8299
8300 let forward_start = model_rs
8301 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
8302 .unwrap();
8303 let forward_body = &model_rs[forward_start..];
8304 let forward_end = forward_body
8305 .find("\n pub fn forward_profile")
8306 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
8307 .or_else(|| forward_body.find("\n fn dispatch_"))
8308 .unwrap_or(forward_body.len());
8309 let forward_code = &forward_body[..forward_end];
8310
8311 assert!(
8312 forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
8313 "forward() must increment self.pos after processing a token"
8314 );
8315 }
8316}