1use std::fmt::Write as FmtWrite;
8use std::fs;
9use std::path::Path;
10
11use forgellm_frontend::ir::*;
12
13#[derive(Debug, thiserror::Error)]
15pub enum MetalCodegenError {
16 #[error("graph has no model config")]
18 MissingConfig,
19
20 #[error("I/O error: {0}")]
22 Io(#[from] std::io::Error),
23
24 #[error("format error: {0}")]
26 Fmt(#[from] std::fmt::Error),
27}
28
29pub fn generate_metal_project(
37 graph: &Graph,
38 output_dir: &Path,
39 model_name: &str,
40) -> Result<(), MetalCodegenError> {
41 let config = graph
42 .config
43 .as_ref()
44 .ok_or(MetalCodegenError::MissingConfig)?;
45
46 let src_dir = output_dir.join("src");
47 let shader_dir = output_dir.join("shaders");
48 fs::create_dir_all(&src_dir)?;
49 fs::create_dir_all(&shader_dir)?;
50
51 fs::write(
52 output_dir.join("Cargo.toml"),
53 generate_cargo_toml(model_name),
54 )?;
55
56 fs::write(
57 shader_dir.join("kernels.metal"),
58 generate_metal_shaders(config),
59 )?;
60
61 let model_rs = generate_model_rs(config)?;
62 fs::write(src_dir.join("model.rs"), model_rs)?;
63
64 let main_rs = generate_main_rs(model_name, config)?;
65 fs::write(src_dir.join("main.rs"), main_rs)?;
66
67 Ok(())
68}
69
/// Converts an arbitrary model name into a valid Cargo package / binary name.
///
/// The name is lowercased. Every character that is not an ASCII letter,
/// ASCII digit or `-` becomes `-`, and leading/trailing `-` are trimmed.
/// Cargo rejects package names that are empty or that start with a digit,
/// so two fallbacks apply:
/// - an empty result becomes `"model"`;
/// - a result starting with a digit gets a `model-` prefix.
///
/// Examples: `"TinyLlama-1.1B"` -> `"tinyllama-1-1b"`, `"7B-Chat"` -> `"model-7b-chat"`.
fn sanitize_name(name: &str) -> String {
    // ASCII-only: `char::is_alphanumeric` would also keep characters such as
    // `²` or `½` that Cargo does not accept in package names.
    let mapped: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = mapped.trim_matches('-');

    match trimmed.chars().next() {
        None => "model".to_string(),
        Some(first) if first.is_ascii_digit() => format!("model-{trimmed}"),
        Some(_) => trimmed.to_string(),
    }
}
80
81fn generate_cargo_toml(model_name: &str) -> String {
82 let sanitized = sanitize_name(model_name);
83 format!(
84 r#"[package]
85name = "{sanitized}"
86version = "0.1.0"
87edition = "2021"
88
89[[bin]]
90name = "{sanitized}"
91path = "src/main.rs"
92
93[dependencies]
94metal = "0.29"
95objc = "0.2"
96half = "2"
97tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
98memmap2 = "0.9"
99tiny_http = "0.12"
100serde = {{ version = "1", features = ["derive"] }}
101serde_json = "1"
102
103[profile.release]
104opt-level = 3
105lto = "fat"
106codegen-units = 1
107"#
108 )
109}
110
111fn generate_metal_shaders(config: &ModelConfig) -> String {
116 let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
121 r#"//
122// Auto-generated by ForgeLLM Metal codegen.
123// Metal Shading Language compute kernels for transformer inference.
124//
125// Optimized with simdgroup cooperative reductions, shared memory vector
126// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
127// lane, and fast:: math intrinsics for Apple Silicon throughput.
128//
129
130#include <metal_stdlib>
131using namespace metal;
132
133// ── Constants ───────────────────────────────────────────────────────────
134// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
135// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
136// per shared memory load vs 4-row, improving ILP and reducing launches.
137constant constexpr uint SIMDGROUPS_PER_TG = 8;
138constant constexpr uint ROWS_PER_SIMDGROUP = 8;
139constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
140
141// ── matmul_vec ──────────────────────────────────────────────────────────
142// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
143// Uses simdgroup cooperative dot product with shared memory vector caching
144// and float4 vectorized loads. Each simdgroup processes 8 rows for better
145// shared memory reuse (8x vector reuse per load) and instruction-level
146// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
147kernel void matmul_vec(
148 device const float* matrix [[buffer(0)]],
149 device const float* vector [[buffer(1)]],
150 device float* output [[buffer(2)]],
151 constant uint& rows [[buffer(3)]],
152 constant uint& cols [[buffer(4)]],
153 uint tgid [[threadgroup_position_in_grid]],
154 uint tid [[thread_index_in_threadgroup]],
155 uint simd_lane [[thread_index_in_simdgroup]],
156 uint simd_id [[simdgroup_index_in_threadgroup]])
157{
158 // Cooperatively load vector into threadgroup shared memory
159 threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
160 for (uint i = tid; i < cols; i += 256) {
161 vec_tile[i] = vector[i];
162 }
163 threadgroup_barrier(mem_flags::mem_threadgroup);
164
165 // Each simdgroup handles 8 consecutive rows
166 uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
167 if (row_base >= rows) return;
168
169 uint base0 = row_base * cols;
170 uint base1 = (row_base + 1) * cols;
171 uint base2 = (row_base + 2) * cols;
172 uint base3 = (row_base + 3) * cols;
173 uint base4 = (row_base + 4) * cols;
174 uint base5 = (row_base + 5) * cols;
175 uint base6 = (row_base + 6) * cols;
176 uint base7 = (row_base + 7) * cols;
177
178 // float4 vectorized accumulation across 8 rows
179 uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
180 float4 sum4_0 = float4(0.0f);
181 float4 sum4_1 = float4(0.0f);
182 float4 sum4_2 = float4(0.0f);
183 float4 sum4_3 = float4(0.0f);
184 float4 sum4_4 = float4(0.0f);
185 float4 sum4_5 = float4(0.0f);
186 float4 sum4_6 = float4(0.0f);
187 float4 sum4_7 = float4(0.0f);
188
189 for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
190 float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
191 sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
192 if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
193 if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
194 if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
195 if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
196 if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
197 if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
198 if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
199 }
200
201 float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
202 float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
203 float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
204 float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
205 float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
206 float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
207 float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
208 float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
209
210 // Handle remaining elements (cols not divisible by 128)
211 for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
212 float vv = vec_tile[j];
213 sum0 += matrix[base0 + j] * vv;
214 if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
215 if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
216 if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
217 if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
218 if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
219 if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
220 if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
221 }
222
223 // Simdgroup hardware warp-level reduction
224 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
225 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
226 sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
227 sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
228
229 // Only first lane writes the results
230 if (simd_lane == 0) {
231 if (row_base < rows) output[row_base] = sum0;
232 if (row_base + 1 < rows) output[row_base + 1] = sum1;
233 if (row_base + 2 < rows) output[row_base + 2] = sum2;
234 if (row_base + 3 < rows) output[row_base + 3] = sum3;
235 if (row_base + 4 < rows) output[row_base + 4] = sum4;
236 if (row_base + 5 < rows) output[row_base + 5] = sum5;
237 if (row_base + 6 < rows) output[row_base + 6] = sum6;
238 if (row_base + 7 < rows) output[row_base + 7] = sum7;
239 }
240}
241
242// ── rms_norm ────────────────────────────────────────────────────────────
243// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
244// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
245// via shared memory for minimal synchronization overhead.
246kernel void rms_norm(
247 device const float* input [[buffer(0)]],
248 device const float* weight [[buffer(1)]],
249 device float* output [[buffer(2)]],
250 constant uint& n [[buffer(3)]],
251 constant float& eps [[buffer(4)]],
252 uint tid [[thread_index_in_threadgroup]])
253{
254 // Each thread accumulates partial sum-of-squares
255 float sum_sq = 0.0f;
256 for (uint i = tid; i < n; i += 256) {
257 float v = input[i];
258 sum_sq += v * v;
259 }
260
261 // Simdgroup-level reduction (hardware warp sum)
262 sum_sq = simd_sum(sum_sq);
263
264 // Cross-simdgroup reduction via shared memory
265 threadgroup float shared[8];
266 uint simd_id = tid / 32;
267 uint simd_lane = tid % 32;
268 if (simd_lane == 0) {
269 shared[simd_id] = sum_sq;
270 }
271 threadgroup_barrier(mem_flags::mem_threadgroup);
272
273 // First thread computes final inverse RMS
274 if (tid == 0) {
275 float total = 0.0f;
276 for (uint i = 0; i < 8; i++) {
277 total += shared[i];
278 }
279 shared[0] = fast::rsqrt(total / float(n) + eps);
280 }
281 threadgroup_barrier(mem_flags::mem_threadgroup);
282
283 float inv_rms = shared[0];
284
285 // Normalize
286 for (uint i = tid; i < n; i += 256) {
287 output[i] = input[i] * inv_rms * weight[i];
288 }
289}
290
291// ── rope ────────────────────────────────────────────────────────────────
292// Rotary Position Embedding applied in-place.
293// Each thread handles one (head, pair) combination.
294kernel void rope(
295 device float* data [[buffer(0)]],
296 constant uint& num_heads [[buffer(1)]],
297 constant uint& head_dim [[buffer(2)]],
298 constant uint& pos [[buffer(3)]],
299 constant float& theta [[buffer(4)]],
300 uint id [[thread_position_in_grid]])
301{
302 uint half_dim = head_dim / 2;
303 uint total_pairs = num_heads * half_dim;
304 if (id >= total_pairs) return;
305
306 uint h = id / half_dim;
307 uint i = id % half_dim;
308 uint off = h * head_dim;
309
310 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
311 float angle = float(pos) * freq;
312 float c = cos(angle);
313 float s = sin(angle);
314
315 float x0 = data[off + 2 * i];
316 float x1 = data[off + 2 * i + 1];
317 data[off + 2 * i] = x0 * c - x1 * s;
318 data[off + 2 * i + 1] = x0 * s + x1 * c;
319}
320
321// ── softmax ─────────────────────────────────────────────────────────────
322// Numerically stable softmax over a 1-D array.
323// Single-threadgroup kernel with cooperative reduction.
324kernel void softmax(
325 device float* data [[buffer(0)]],
326 constant uint& n [[buffer(1)]],
327 uint tid [[thread_index_in_threadgroup]],
328 uint tg_size [[threads_per_threadgroup]])
329{
330 threadgroup float shared_val[256];
331
332 // Pass 1: find max
333 float local_max = -INFINITY;
334 for (uint i = tid; i < n; i += tg_size) {
335 local_max = max(local_max, data[i]);
336 }
337 shared_val[tid] = local_max;
338 threadgroup_barrier(mem_flags::mem_threadgroup);
339
340 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
341 if (tid < stride) {
342 shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
343 }
344 threadgroup_barrier(mem_flags::mem_threadgroup);
345 }
346 float max_val = shared_val[0];
347 threadgroup_barrier(mem_flags::mem_threadgroup);
348
349 // Pass 2: exp and sum
350 float local_sum = 0.0f;
351 for (uint i = tid; i < n; i += tg_size) {
352 float e = fast::exp(data[i] - max_val);
353 data[i] = e;
354 local_sum += e;
355 }
356 shared_val[tid] = local_sum;
357 threadgroup_barrier(mem_flags::mem_threadgroup);
358
359 for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
360 if (tid < stride) {
361 shared_val[tid] += shared_val[tid + stride];
362 }
363 threadgroup_barrier(mem_flags::mem_threadgroup);
364 }
365 float inv_sum = 1.0f / shared_val[0];
366 threadgroup_barrier(mem_flags::mem_threadgroup);
367
368 // Pass 3: normalize
369 for (uint i = tid; i < n; i += tg_size) {
370 data[i] *= inv_sum;
371 }
372}
373
374// ── silu_mul ────────────────────────────────────────────────────────────
375// Fused SiLU activation * element-wise multiply:
376// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
377kernel void silu_mul(
378 device const float* gate [[buffer(0)]],
379 device const float* up [[buffer(1)]],
380 device float* output [[buffer(2)]],
381 constant uint& n [[buffer(3)]],
382 uint id [[thread_position_in_grid]])
383{
384 if (id >= n) return;
385 float g = gate[id];
386 output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
387}
388
389// ── silu_mul_fused ─────────────────────────────────────────────────────
390// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
391// gate = gate_up[0..n], up = gate_up[n..2*n]
392// output[i] = silu(gate_up[i]) * gate_up[n + i]
393kernel void silu_mul_fused(
394 device const float* gate_up [[buffer(0)]],
395 device float* output [[buffer(1)]],
396 constant uint& n [[buffer(2)]],
397 uint id [[thread_position_in_grid]])
398{
399 if (id >= n) return;
400 float g = gate_up[id];
401 float u = gate_up[n + id];
402 output[id] = (g / (1.0f + fast::exp(-g))) * u;
403}
404
405// ── elementwise_add ─────────────────────────────────────────────────────
406// Residual connection: output[i] = a[i] + b[i]
407kernel void elementwise_add(
408 device const float* a [[buffer(0)]],
409 device const float* b [[buffer(1)]],
410 device float* output [[buffer(2)]],
411 constant uint& n [[buffer(3)]],
412 uint id [[thread_position_in_grid]])
413{
414 if (id >= n) return;
415 output[id] = a[id] + b[id];
416}
417
418// ── copy_buffer ─────────────────────────────────────────────────────────
419// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
420// transitions. Used for KV cache updates and embedding lookup.
421kernel void copy_buffer(
422 device const float* src [[buffer(0)]],
423 device float* dst [[buffer(1)]],
424 constant uint& count [[buffer(2)]],
425 uint id [[thread_position_in_grid]])
426{
427 if (id < count) dst[id] = src[id];
428}
429
430// ── copy_offset ─────────────────────────────────────────────────────────
431// Copy with source offset (in floats). Used for embedding table lookup
432// where we need to copy a specific row from a large table.
433kernel void copy_offset(
434 device const float* src [[buffer(0)]],
435 device float* dst [[buffer(1)]],
436 constant uint& src_offset [[buffer(2)]], // in floats
437 constant uint& count [[buffer(3)]],
438 uint id [[thread_position_in_grid]])
439{
440 if (id < count) dst[id] = src[src_offset + id];
441}
442
443// ── add_inplace ─────────────────────────────────────────────────────────
444// In-place residual connection: a[i] += b[i]
445// Avoids a separate blit copy for residual add, reducing encoder overhead.
446kernel void add_inplace(
447 device float* a [[buffer(0)]],
448 device const float* b [[buffer(1)]],
449 constant uint& n [[buffer(2)]],
450 uint id [[thread_position_in_grid]])
451{
452 if (id >= n) return;
453 a[id] += b[id];
454}
455
456// ── matmul_vec_q8 ─────────────────────────────────────────────────────
457// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
458// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
459// Operates directly on quantized weights to halve memory bandwidth vs f32,
460// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
461//
462// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
463// because int8->float conversion doubles register demand. Fully unrolled
464// inner loop with float4 vector loads from shared memory eliminates loop
465// overhead and enables better instruction scheduling.
466// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
467constant constexpr uint Q8_ROWS_PER_SG = 4;
468constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
469
470// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
471// doubles ALU work, so the same register budget applies).
472constant constexpr uint Q4_ROWS_PER_SG = 4;
473constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
474
475kernel void matmul_vec_q8(
476 device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
477 device const float* vector [[buffer(1)]], // f32 input
478 device float* output [[buffer(2)]],
479 constant uint& rows [[buffer(3)]],
480 constant uint& cols [[buffer(4)]], // number of elements per row
481 uint tgid [[threadgroup_position_in_grid]],
482 uint tid [[thread_index_in_threadgroup]],
483 uint simd_lane [[thread_index_in_simdgroup]],
484 uint simd_id [[simdgroup_index_in_threadgroup]])
485{
486 // Load vector into shared memory
487 threadgroup float vec_tile[VEC_TILE_SIZE];
488 for (uint i = tid; i < cols; i += 256) {
489 vec_tile[i] = vector[i];
490 }
491 threadgroup_barrier(mem_flags::mem_threadgroup);
492
493 // Each simdgroup handles 4 consecutive rows (lower register pressure)
494 uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
495 if (row_base >= rows) return;
496
497 // Q8_0: each block is 34 bytes for 32 elements
498 uint blocks_per_row = cols / 32;
499 uint row_bytes = blocks_per_row * 34;
500
501 // Pointers to each row's Q8_0 data
502 device const uchar* r0 = matrix + row_base * row_bytes;
503 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
504 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
505 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
506
507 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
508
509 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
510 uint bb = blk * 34;
511 uint vb = blk * 32;
512
513 // Prefetch all 4 scales
514 float sc0 = float(*(device const half*)(r0 + bb));
515 float sc1 = float(*(device const half*)(r1 + bb));
516 float sc2 = float(*(device const half*)(r2 + bb));
517 float sc3 = float(*(device const half*)(r3 + bb));
518
519 // Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
520 // Q8_0 block layout where the int8 data starts at offset +2 from a
521 // 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
522 // so 4 loads per row per block vs the previous 8 char4 loads — a 2x
523 // reduction in memory transactions. Metal's char16/packed_char16 are
524 // reserved types and packed_*int4 require >=4-byte alignment which
525 // this layout does not provide, so packed_short4 is the widest valid
526 // vectorized load.
527 device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
528 device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
529 device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
530 device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
531
532 // Load all 8 float4 vector values for this 32-element block from shared memory
533 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
534 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
535 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
536 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
537 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
538 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
539 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
540 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
541
542 // Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
543 // char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
544 #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
545 short4 _s = short4(SHORT4); \
546 char2 _a = as_type<char2>(_s.x); \
547 char2 _b = as_type<char2>(_s.y); \
548 char2 _c = as_type<char2>(_s.z); \
549 char2 _d = as_type<char2>(_s.w); \
550 (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
551 (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
552 }
553
554 float4 f0, f1;
555 float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
556
557 // Row 0: 4 short4 loads cover 32 int8 weights
558 Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
559 Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
560 Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
561 Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
562
563 // Row 1
564 Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
565 Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
566 Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
567 Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
568
569 // Row 2
570 Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
571 Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
572 Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
573 Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
574
575 // Row 3
576 Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
577 Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
578 Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
579 Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
580
581 #undef Q8_UNPACK8
582
583 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
584 }
585
586 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
587 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
588
589 if (simd_lane == 0) {
590 if (row_base < rows) output[row_base] = sum0;
591 if (row_base + 1 < rows) output[row_base + 1] = sum1;
592 if (row_base + 2 < rows) output[row_base + 2] = sum2;
593 if (row_base + 3 < rows) output[row_base + 3] = sum3;
594 }
595}
596
597// ── matmul_vec_q4 ─────────────────────────────────────────────────────
598// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
599// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
600// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
601// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
602//
603// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
604// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
605kernel void matmul_vec_q4(
606 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
607 device const float* vector [[buffer(1)]], // f32 input
608 device float* output [[buffer(2)]],
609 constant uint& rows [[buffer(3)]],
610 constant uint& cols [[buffer(4)]], // number of elements per row
611 uint tgid [[threadgroup_position_in_grid]],
612 uint tid [[thread_index_in_threadgroup]],
613 uint simd_lane [[thread_index_in_simdgroup]],
614 uint simd_id [[simdgroup_index_in_threadgroup]])
615{
616 // Load vector into shared memory
617 threadgroup float vec_tile[VEC_TILE_SIZE];
618 for (uint i = tid; i < cols; i += 256) {
619 vec_tile[i] = vector[i];
620 }
621 threadgroup_barrier(mem_flags::mem_threadgroup);
622
623 // Each simdgroup handles 4 consecutive rows
624 uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
625 if (row_base >= rows) return;
626
627 // Q4_0: each block is 18 bytes for 32 elements
628 uint blocks_per_row = cols / 32;
629 uint row_bytes = blocks_per_row * 18;
630
631 // Pointers to each row's Q4_0 data
632 device const uchar* r0 = matrix + row_base * row_bytes;
633 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
634 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
635 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
636
637 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
638
639 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
640 uint bb = blk * 18;
641 uint vb = blk * 32;
642
643 // Prefetch all 4 scales
644 float sc0 = float(*(device const half*)(r0 + bb));
645 float sc1 = float(*(device const half*)(r1 + bb));
646 float sc2 = float(*(device const half*)(r2 + bb));
647 float sc3 = float(*(device const half*)(r3 + bb));
648
649 // Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
650 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
651 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
652 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
653 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
654
655 // Load 8 float4 vector values for 32 elements from shared memory
656 // Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
657 float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
658 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
659 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
660 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
661 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
662 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
663 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
664 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
665
666 // Fully unrolled block dot products — 4 rows x 4 uchar4 reads
667 // Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
668 float bd0=0, bd1=0, bd2=0, bd3=0;
669 uchar4 b;
670
671 // Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
672 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
673 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
674 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
675 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
676 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
677 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
678 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
679 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
680 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
681 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
682 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
683 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
684 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
685 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
686 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
687 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
688
689 // Row 1
690 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
691 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
692 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
693 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
694 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
695 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
696 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
697 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
698 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
699 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
700 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
701 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
702 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
703 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
704 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
705 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
706
707 // Row 2
708 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
709 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
710 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
711 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
712 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
713 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
714 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
715 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
716 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
717 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
718 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
719 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
720 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
721 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
722 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
723 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
724
725 // Row 3
726 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
727 +float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
728 +float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
729 +float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
730 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
731 +float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
732 +float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
733 +float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
734 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
735 +float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
736 +float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
737 +float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
738 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
739 +float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
740 +float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
741 +float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
742
743 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
744 }
745
746 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
747 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
748
749 if (simd_lane == 0) {
750 if (row_base < rows) output[row_base] = sum0;
751 if (row_base + 1 < rows) output[row_base + 1] = sum1;
752 if (row_base + 2 < rows) output[row_base + 2] = sum2;
753 if (row_base + 3 < rows) output[row_base + 3] = sum3;
754 }
755}
756
// ── attention ───────────────────────────────────────────────────────────
// Single-query (decode-step) attention with simdgroup cooperative reductions.
// One threadgroup per query head. The reductions below assume 256 threads
// per threadgroup (8 simdgroups of 32 lanes).
//
// The key range [0, seq_len) is processed in tiles of ATTN_SCORE_TILE
// positions, so the threadgroup score buffer never overflows regardless of
// seq_len. Tiles are combined with an online softmax:
//   * the running max and running denominator are rescaled whenever a later
//     tile raises the max;
//   * the partially accumulated output row, kept in `output` with one element
//     per owning thread, is rescaled by the same factor;
//   * the row is divided by the final denominator once at the end.
// For seq_len <= ATTN_SCORE_TILE there is exactly one tile, and the result
// equals a plain two-pass softmax up to float rounding.
//
// Buffers:
//   q:       [num_heads * head_dim] current query
//   k_cache: [max_seq_len * num_kv_heads * head_dim]
//   v_cache: [max_seq_len * num_kv_heads * head_dim]
//   output:  [num_heads * head_dim]
constant constexpr uint ATTN_SCORE_TILE = 2048;

kernel void attention(
    device const float* q            [[buffer(0)]],
    device const float* k_cache      [[buffer(1)]],
    device const float* v_cache      [[buffer(2)]],
    device float*       output       [[buffer(3)]],
    constant uint&      seq_len      [[buffer(4)]],
    constant uint&      num_heads    [[buffer(5)]],
    constant uint&      num_kv_heads [[buffer(6)]],
    constant uint&      head_dim     [[buffer(7)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    uint head = tgid;
    if (head >= num_heads) return;
    // Grouped-query attention: consecutive query heads share one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);

    uint q_off = head * head_dim;
    uint kv_stride = num_kv_heads * head_dim;  // distance between positions
    uint kv_off = kv_head * head_dim;          // this head's slice in a position
    float scale = fast::rsqrt(float(head_dim));

    // 8 KB of scores per tile plus two 8-entry reduction scratch arrays.
    threadgroup float scores[ATTN_SCORE_TILE];
    threadgroup float shared_max[8];
    threadgroup float shared_sum[8];

    // Online-softmax state. Every thread derives these from the same
    // threadgroup reductions, so they are uniform across the group.
    float run_max = -INFINITY;
    float run_sum = 0.0f;

    for (uint tile = 0; tile < seq_len; tile += ATTN_SCORE_TILE) {
        uint tile_len = min(uint(ATTN_SCORE_TILE), seq_len - tile);

        // Step 1: scores[s] = (q . k[tile + s]) / sqrt(head_dim). Each
        // simdgroup takes every 8th position and reduces its dot product
        // across the 32 lanes with simd_sum.
        for (uint s = simd_id; s < tile_len; s += 8) {
            uint k_off = (tile + s) * kv_stride + kv_off;
            float dot = 0.0f;
            for (uint d = simd_lane; d < head_dim; d += 32) {
                dot += q[q_off + d] * k_cache[k_off + d];
            }
            dot = simd_sum(dot);
            if (simd_lane == 0) {
                scores[s] = dot * scale;
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // Step 2a: tile maximum (thread -> simdgroup -> threadgroup).
        float local_max = -INFINITY;
        for (uint s = tid; s < tile_len; s += 256) {
            local_max = max(local_max, scores[s]);
        }
        local_max = simd_max(local_max);
        if (simd_lane == 0) shared_max[simd_id] = local_max;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float m = shared_max[0];
            for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
            shared_max[0] = m;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        float new_max = max(run_max, shared_max[0]);

        // Factor that rescales everything accumulated in earlier tiles to the
        // new maximum. The first tile has nothing to rescale. Special-casing
        // it avoids evaluating exp(-INFINITY) under fast math.
        float correction = (tile == 0) ? 0.0f : fast::exp(run_max - new_max);

        // Step 2b: exponentiate the tile in place and sum it.
        float local_sum = 0.0f;
        for (uint s = tid; s < tile_len; s += 256) {
            float e = fast::exp(scores[s] - new_max);
            scores[s] = e;
            local_sum += e;
        }
        local_sum = simd_sum(local_sum);
        if (simd_lane == 0) shared_sum[simd_id] = local_sum;
        threadgroup_barrier(mem_flags::mem_threadgroup);
        if (tid == 0) {
            float total = 0.0f;
            for (uint i = 0; i < 8; i++) total += shared_sum[i];
            shared_sum[0] = total;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        run_sum = run_sum * correction + shared_sum[0];
        run_max = new_max;

        // Step 3: accumulate the unnormalized weighted sum of V for this tile.
        // Each thread owns output dims d = tid, tid + 256, ... and keeps its
        // running value in output[q_off + d]. Only that thread touches that
        // element, so no barrier is needed for it. Four positions are handled
        // per iteration for ILP.
        uint tile_len4 = tile_len & ~3u;  // largest multiple of 4 <= tile_len
        for (uint d = tid; d < head_dim; d += 256) {
            float acc = (tile == 0) ? 0.0f : output[q_off + d] * correction;
            uint v_base = kv_off + d;
            for (uint s = 0; s < tile_len4; s += 4) {
                uint p = tile + s;
                acc += scores[s]     * v_cache[p * kv_stride + v_base]
                     + scores[s + 1] * v_cache[(p + 1) * kv_stride + v_base]
                     + scores[s + 2] * v_cache[(p + 2) * kv_stride + v_base]
                     + scores[s + 3] * v_cache[(p + 3) * kv_stride + v_base];
            }
            for (uint s = tile_len4; s < tile_len; s++) {
                acc += scores[s] * v_cache[(tile + s) * kv_stride + v_base];
            }
            output[q_off + d] = acc;
        }
        // Every thread must finish reading scores / shared_* before the next
        // tile overwrites them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Normalize by the softmax denominator. With seq_len == 0 there is
    // nothing to attend to, and the output row is zero.
    float inv_sum = 1.0f / run_sum;
    for (uint d = tid; d < head_dim; d += 256) {
        output[q_off + d] = (seq_len == 0) ? 0.0f : output[q_off + d] * inv_sum;
    }
}
873
874// ── Batched prefill kernels ────────────────────────────────────────────
875// These kernels process M input vectors against the same weight matrix
876// in a single dispatch, converting mat-vec into mat-mat for better GPU
877// utilization during prompt prefill.
878
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization. Threadgroup `tgid` normalizes row `tgid` of the
// [num_tokens, n] input:
//   out = x * rsqrt(mean(x^2) + eps) * weight.
// The 256-wide strides and the 8-entry partial array assume the host
// dispatches 256 threads (8 simdgroups of 32) per threadgroup.
kernel void rms_norm_batch(
    device const float* input      [[buffer(0)]], // [M, n]
    device const float* weight     [[buffer(1)]], // [n]
    device float*       output     [[buffer(2)]], // [M, n]
    constant uint&      n          [[buffer(3)]],
    constant float&     eps        [[buffer(4)]],
    constant uint&      num_tokens [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid  [[thread_index_in_threadgroup]])
{
    // One threadgroup per token; any surplus threadgroup has no work.
    if (tgid >= num_tokens) {
        return;
    }

    const uint row = tgid * n;

    // Per-thread partial sum of squares over a 256-strided slice of the row.
    float partial = 0.0f;
    for (uint col = tid; col < n; col += 256) {
        const float x = input[row + col];
        partial += x * x;
    }
    partial = simd_sum(partial);

    // Lane 0 of each simdgroup publishes its partial, then thread 0 folds them.
    threadgroup float group_partials[8];
    const uint group = tid / 32;
    const uint lane = tid % 32;
    if (lane == 0) {
        group_partials[group] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    if (tid == 0) {
        float sum_sq = 0.0f;
        for (uint g = 0; g < 8; ++g) {
            sum_sq += group_partials[g];
        }
        group_partials[0] = fast::rsqrt(sum_sq / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const float scale = group_partials[0];
    for (uint col = tid; col < n; col += 256) {
        output[row + col] = input[row + col] * scale * weight[col];
    }
}
928
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place rotary position embedding for a batch of vectors.
//   data:      [num_tokens, num_heads * head_dim]
//   positions: positions[t] is the position used for token t.
// One thread per (token, head, pair) rotates the adjacent element pair
// (2*pair, 2*pair + 1) of that head by the angle
//   positions[t] / theta^(2*pair / head_dim).
kernel void rope_batch(
    device float*      data       [[buffer(0)]], // [M, num_heads * head_dim]
    constant uint&     num_heads  [[buffer(1)]],
    constant uint&     head_dim   [[buffer(2)]],
    device const uint* positions  [[buffer(3)]], // [M] position per token
    constant float&    theta      [[buffer(4)]],
    constant uint&     num_tokens [[buffer(5)]],
    uint id [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) {
        return;
    }

    // Decompose the flat thread id into (token, head, pair).
    const uint token = id / pairs_per_token;
    const uint in_token = id % pairs_per_token;
    const uint head = in_token / pairs_per_head;
    const uint pair = in_token % pairs_per_head;

    const uint idx = token * (num_heads * head_dim) + head * head_dim + 2 * pair;

    const float exponent = 2.0f * float(pair) / float(head_dim);
    const float inv_freq = 1.0f / pow(theta, exponent);
    const float angle = float(positions[token]) * inv_freq;
    const float cos_a = cos(angle);
    const float sin_a = sin(angle);

    const float even = data[idx];
    const float odd = data[idx + 1];
    data[idx] = even * cos_a - odd * sin_a;
    data[idx + 1] = even * sin_a + odd * cos_a;
}
963
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-gate multiply for a batch. Each token's gate_up row stores the
// gate values in its first n entries and the up values in the next n.
// One thread per output element computes
//   output[t*n + i] = silu(gate[i]) * up[i],  where silu(g) = g / (1 + exp(-g)).
kernel void silu_mul_fused_batch(
    device const float* gate_up    [[buffer(0)]], // [M, 2*n]
    device float*       output     [[buffer(1)]], // [M, n]
    constant uint&      n          [[buffer(2)]],
    constant uint&      num_tokens [[buffer(3)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) {
        return;
    }
    const uint token = id / n;
    const uint col = id % n;

    const uint gate_idx = token * 2 * n + col;
    const float gate = gate_up[gate_idx];
    const float up = gate_up[gate_idx - col + n + col];
    const float silu = gate / (1.0f + fast::exp(-gate));
    output[token * n + col] = silu * up;
}
983
// ── add_inplace_batch ──────────────────────────────────────────────────
// Element-wise residual add over a flattened [M * n] batch: a[i] += b[i].
// Threads with id >= total exit without touching memory.
kernel void add_inplace_batch(
    device float*       a     [[buffer(0)]], // [M * n]
    device const float* b     [[buffer(1)]], // [M * n]
    constant uint&      total [[buffer(2)]], // M * n
    uint id [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
995
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gathers one embedding row per token into a contiguous [M, dim] buffer.
// One thread per output element. Token IDs are used as row indices without
// a bounds check; NOTE(review): callers must guarantee tokens[i] < vocab_size.
kernel void copy_embedding_batch(
    device const float* embed      [[buffer(0)]], // [vocab_size, dim]
    device float*       output     [[buffer(1)]], // [M, dim]
    device const uint*  tokens     [[buffer(2)]], // [M]
    constant uint&      dim        [[buffer(3)]],
    constant uint&      num_tokens [[buffer(4)]],
    uint id [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) {
        return;
    }
    const uint row = id / dim;
    const uint col = id - row * dim;
    output[id] = embed[tokens[row] * dim + col];
}
1013
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair.
//
// Threadgroup tgid maps to token = tgid / row_tgs and
// row group tg_in_token = tgid % row_tgs. The token's input row is staged in
// threadgroup memory (vec_tile). Each simdgroup then computes 8 consecutive
// output rows starting at row_base, each finished with a 32-lane simd_sum.
//
// NOTE(review): assumptions not checked here; confirm at the dispatch site:
//   * 256 threads per threadgroup (the vec_tile fill uses stride 256);
//   * cols <= VEC_TILE_SIZE (vec_tile is written for every i < cols);
//   * ROWS_PER_SIMDGROUP == 8 (the body hard-codes 8 accumulators);
//   * cols % 4 == 0, so the float4 loads from `matrix` are 16-byte aligned.
kernel void matmul_vec_batch(
    device const float* matrix [[buffer(0)]], // [rows, cols] weight
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform across the threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
    // Per-simdgroup exit. No barrier follows, so this cannot deadlock.
    if (row_base >= rows) return;

    // Start offsets (in floats) of the 8 rows. Offsets for rows >= `rows` are
    // computed but never dereferenced: every access for row_base + k (k >= 1)
    // below is guarded by `row_base + k < rows`.
    uint base0 = row_base * cols;
    uint base1 = (row_base + 1) * cols;
    uint base2 = (row_base + 2) * cols;
    uint base3 = (row_base + 3) * cols;
    uint base4 = (row_base + 4) * cols;
    uint base5 = (row_base + 5) * cols;
    uint base6 = (row_base + 6) * cols;
    uint base7 = (row_base + 7) * cols;

    // The vector loop covers cols rounded down to a multiple of 128:
    // 32 lanes each read one float4 per iteration (32 * 4 = 128 columns).
    uint cols_vec4 = cols & ~127u;
    float4 sum4_0 = float4(0.0f);
    float4 sum4_1 = float4(0.0f);
    float4 sum4_2 = float4(0.0f);
    float4 sum4_3 = float4(0.0f);
    float4 sum4_4 = float4(0.0f);
    float4 sum4_5 = float4(0.0f);
    float4 sum4_6 = float4(0.0f);
    float4 sum4_7 = float4(0.0f);

    for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
        float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
        sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
        if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
        if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
        if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
        if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
        if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
        if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
        if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
    }

    // Horizontal-add each float4 accumulator into a scalar partial sum.
    float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
    float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
    float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
    float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
    float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
    float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
    float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
    float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;

    // Scalar tail for the remaining (cols - cols_vec4) columns.
    for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
        float vv = vec_tile[j];
        sum0 += matrix[base0 + j] * vv;
        if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
        if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
        if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
        if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
        if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
        if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
        if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
    }

    // Reduce each row's per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
    sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
    sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);

    // Lane 0 stores the (up to) 8 valid results into this token's output row.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
        if (row_base + 4 < rows) output[row_base + 4] = sum4;
        if (row_base + 5 < rows) output[row_base + 5] = sum5;
        if (row_base + 6 < rows) output[row_base + 6] = sum6;
        if (row_base + 7 < rows) output[row_base + 7] = sum7;
    }
}
1115
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups. Threadgroup tgid handles
// token tgid / row_tgs and row group tgid % row_tgs.
//
// Q8_0 row layout: cols/32 blocks of 34 bytes each, laid out as a half scale
// followed by 32 int8 quants. Each simdgroup computes 4 consecutive output
// rows starting at row_base. Its lanes stride over blocks, and the per-row
// partials are finished with simd_sum.
//
// NOTE(review): assumes 256 threads per threadgroup (vec_tile fill stride),
// cols <= VEC_TILE_SIZE, cols % 32 == 0 and Q8_ROWS_PER_SG == 4. Confirm at
// the dispatch site.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform across the threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // The block loop below reads all 4 rows unconditionally. When rows is not
    // a multiple of 4, the last simdgroup has fewer than 4 valid rows. Clamp
    // the trailing row indices to the last valid row so those loads stay
    // inside the matrix buffer. The duplicate results are discarded by the
    // guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of this block within a row
        uint vb = blk * 32;  // first input element covered by this block

        // Per-block half-precision scale at the start of each block.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 int8 quants) into two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);

        #undef Q8_UNPACK8

        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Only rows < `rows` are stored, so clamped duplicates are dropped here.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
1231
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers 32 rows and TOKENS_PER_TG_Q8 consecutive tokens.
// The Q8_0 weight blocks are fetched once from device memory and reused for
// every token in the tile, which needs 1/TOKENS_PER_TG_Q8 of the weight
// bandwidth of the per-token dispatch.
//
// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
// with simd_sum. Token vectors are read directly from device memory inside
// the block loop (not cached in shared memory) so intermediate_size up to
// 8192 fits without spilling threadgroup memory.
//
// Q8_0 row layout: cols/32 blocks of 34 bytes each, laid out as a half scale
// followed by 32 int8 quants. NOTE(review): assumes cols % 32 == 0 and
// Q8_ROWS_PER_SG == 4. Confirm at the dispatch site.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;

kernel void matmul_q8_gemm_batch(
    device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes [rows, cols]
    device const float* inputs [[buffer(1)]], // [M, cols] input batch
    device float* outputs [[buffer(2)]], // [M, rows] output batch
    constant uint& num_tokens [[buffer(3)]], // M
    constant uint& rows [[buffer(4)]],
    constant uint& cols [[buffer(5)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    if (row_base >= rows || tok_base >= num_tokens) return;

    // How many tokens in this tile are valid?
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // The block loop below reads all 4 rows unconditionally. When rows is not
    // a multiple of 4, the last simdgroup has fewer than 4 valid rows. Clamp
    // the trailing row indices to the last valid row so those loads stay
    // inside the matrix buffer. The duplicate results are discarded by the
    // guarded stores at the end.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators: s<token><row>, i.e. 4 tokens x 4 rows per simdgroup.
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;  // byte offset of this block within a row
        uint vb = blk * 32;  // first input element covered by this block

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Split one packed_short4 (8 int8 quants) into two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        // Unpack all 4 rows x 8 float4 quants (unscaled; the per-block scale
        // is applied after the dot products). These live in registers for the
        // duration of the block and are dotted against every token's vector.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);

        #undef Q8_UNPACK8

        // ── For each token, read vector and accumulate against shared weights ──
        // Token 0 (always valid: tok_count >= 1).
        {
            device const float* a0 = inputs + (tok_base + 0) * cols + vb;
            float4 v0 = *(device const float4*)(a0);
            float4 v1 = *(device const float4*)(a0 + 4);
            float4 v2 = *(device const float4*)(a0 + 8);
            float4 v3 = *(device const float4*)(a0 + 12);
            float4 v4 = *(device const float4*)(a0 + 16);
            float4 v5 = *(device const float4*)(a0 + 20);
            float4 v6 = *(device const float4*)(a0 + 24);
            float4 v7 = *(device const float4*)(a0 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s00 += sc0 * bd0; s01 += sc1 * bd1; s02 += sc2 * bd2; s03 += sc3 * bd3;
        }
        // Token 1
        if (tok_count > 1) {
            device const float* a1 = inputs + (tok_base + 1) * cols + vb;
            float4 v0 = *(device const float4*)(a1);
            float4 v1 = *(device const float4*)(a1 + 4);
            float4 v2 = *(device const float4*)(a1 + 8);
            float4 v3 = *(device const float4*)(a1 + 12);
            float4 v4 = *(device const float4*)(a1 + 16);
            float4 v5 = *(device const float4*)(a1 + 20);
            float4 v6 = *(device const float4*)(a1 + 24);
            float4 v7 = *(device const float4*)(a1 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s10 += sc0 * bd0; s11 += sc1 * bd1; s12 += sc2 * bd2; s13 += sc3 * bd3;
        }
        // Token 2
        if (tok_count > 2) {
            device const float* a2 = inputs + (tok_base + 2) * cols + vb;
            float4 v0 = *(device const float4*)(a2);
            float4 v1 = *(device const float4*)(a2 + 4);
            float4 v2 = *(device const float4*)(a2 + 8);
            float4 v3 = *(device const float4*)(a2 + 12);
            float4 v4 = *(device const float4*)(a2 + 16);
            float4 v5 = *(device const float4*)(a2 + 20);
            float4 v6 = *(device const float4*)(a2 + 24);
            float4 v7 = *(device const float4*)(a2 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s20 += sc0 * bd0; s21 += sc1 * bd1; s22 += sc2 * bd2; s23 += sc3 * bd3;
        }
        // Token 3
        if (tok_count > 3) {
            device const float* a3 = inputs + (tok_base + 3) * cols + vb;
            float4 v0 = *(device const float4*)(a3);
            float4 v1 = *(device const float4*)(a3 + 4);
            float4 v2 = *(device const float4*)(a3 + 8);
            float4 v3 = *(device const float4*)(a3 + 12);
            float4 v4 = *(device const float4*)(a3 + 16);
            float4 v5 = *(device const float4*)(a3 + 20);
            float4 v6 = *(device const float4*)(a3 + 24);
            float4 v7 = *(device const float4*)(a3 + 28);
            float bd0 = dot(w0_0,v0)+dot(w0_1,v1)+dot(w0_2,v2)+dot(w0_3,v3)
                      + dot(w0_4,v4)+dot(w0_5,v5)+dot(w0_6,v6)+dot(w0_7,v7);
            float bd1 = dot(w1_0,v0)+dot(w1_1,v1)+dot(w1_2,v2)+dot(w1_3,v3)
                      + dot(w1_4,v4)+dot(w1_5,v5)+dot(w1_6,v6)+dot(w1_7,v7);
            float bd2 = dot(w2_0,v0)+dot(w2_1,v1)+dot(w2_2,v2)+dot(w2_3,v3)
                      + dot(w2_4,v4)+dot(w2_5,v5)+dot(w2_6,v6)+dot(w2_7,v7);
            float bd3 = dot(w3_0,v0)+dot(w3_1,v1)+dot(w3_2,v2)+dot(w3_3,v3)
                      + dot(w3_4,v4)+dot(w3_5,v5)+dot(w3_6,v6)+dot(w3_7,v7);
            s30 += sc0 * bd0; s31 += sc1 * bd1; s32 += sc2 * bd2; s33 += sc3 * bd3;
        }
    }

    // simdgroup reduction
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Only rows < `rows` and valid tokens are stored, so clamped duplicates
    // and padding tokens are dropped here.
    if (simd_lane == 0) {
        device float* o0 = outputs + (tok_base + 0) * rows;
        if (row_base < rows) o0[row_base] = s00;
        if (row_base + 1 < rows) o0[row_base + 1] = s01;
        if (row_base + 2 < rows) o0[row_base + 2] = s02;
        if (row_base + 3 < rows) o0[row_base + 3] = s03;

        if (tok_count > 1) {
            device float* o1 = outputs + (tok_base + 1) * rows;
            if (row_base < rows) o1[row_base] = s10;
            if (row_base + 1 < rows) o1[row_base + 1] = s11;
            if (row_base + 2 < rows) o1[row_base + 2] = s12;
            if (row_base + 3 < rows) o1[row_base + 3] = s13;
        }
        if (tok_count > 2) {
            device float* o2 = outputs + (tok_base + 2) * rows;
            if (row_base < rows) o2[row_base] = s20;
            if (row_base + 1 < rows) o2[row_base + 1] = s21;
            if (row_base + 2 < rows) o2[row_base + 2] = s22;
            if (row_base + 3 < rows) o2[row_base + 3] = s23;
        }
        if (tok_count > 3) {
            device float* o3 = outputs + (tok_base + 3) * rows;
            if (row_base < rows) o3[row_base] = s30;
            if (row_base + 1 < rows) o3[row_base + 1] = s31;
            if (row_base + 2 < rows) o3[row_base + 2] = s32;
            if (row_base + 3 < rows) o3[row_base + 3] = s33;
        }
    }
}
1457
1458// ── matmul_vec_q4_batch ────────────────────────────────────────────────
1459// Batched Q4_0 matrix-vector multiply for M input vectors.
1460// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
1461kernel void matmul_vec_q4_batch(
1462 device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes [rows, cols]
1463 device const float* inputs [[buffer(1)]], // [M, cols] input batch
1464 device float* outputs [[buffer(2)]], // [M, rows] output batch
1465 constant uint& num_tokens [[buffer(3)]], // M
1466 constant uint& rows [[buffer(4)]],
1467 constant uint& cols [[buffer(5)]],
1468 uint tgid [[threadgroup_position_in_grid]],
1469 uint tid [[thread_index_in_threadgroup]],
1470 uint simd_lane [[thread_index_in_simdgroup]],
1471 uint simd_id [[simdgroup_index_in_threadgroup]])
1472{
1473 uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
1474 uint token = tgid / row_tgs;
1475 uint tg_in_token = tgid % row_tgs;
1476 if (token >= num_tokens) return;
1477
1478 threadgroup float vec_tile[VEC_TILE_SIZE];
1479 device const float* input = inputs + token * cols;
1480 for (uint i = tid; i < cols; i += 256) {
1481 vec_tile[i] = input[i];
1482 }
1483 threadgroup_barrier(mem_flags::mem_threadgroup);
1484
1485 uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
1486 if (row_base >= rows) return;
1487
1488 uint blocks_per_row = cols / 32;
1489 uint row_bytes = blocks_per_row * 18;
1490
1491 device const uchar* r0 = matrix + row_base * row_bytes;
1492 device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
1493 device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
1494 device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
1495
1496 float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
1497
1498 for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
1499 uint bb = blk * 18;
1500 uint vb = blk * 32;
1501
1502 float sc0 = float(*(device const half*)(r0 + bb));
1503 float sc1 = float(*(device const half*)(r1 + bb));
1504 float sc2 = float(*(device const half*)(r2 + bb));
1505 float sc3 = float(*(device const half*)(r3 + bb));
1506
1507 device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
1508 device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
1509 device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
1510 device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
1511
1512 float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
1513 float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
1514 float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
1515 float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
1516 float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
1517 float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
1518 float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
1519 float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
1520
1521 float bd0=0, bd1=0, bd2=0, bd3=0;
1522 uchar4 b;
1523
1524 b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1525 b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1526 b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1527 b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1528
1529 b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1530 b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1531 b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1532 b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1533
1534 b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1535 b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1536 b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1537 b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1538
1539 b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
1540 b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
1541 b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
1542 b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
1543
1544 sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
1545 }
1546
1547 sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
1548 sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
1549
1550 device float* output = outputs + token * rows;
1551 if (simd_lane == 0) {
1552 if (row_base < rows) output[row_base] = sum0;
1553 if (row_base + 1 < rows) output[row_base + 1] = sum1;
1554 if (row_base + 2 < rows) output[row_base + 2] = sum2;
1555 if (row_base + 3 < rows) output[row_base + 3] = sum3;
1556 }
1557}
1558
1559// ── copy_kv_batch ─────────────────────────────────────────────────────
1560// Copy K or V from a strided batch QKV buffer to the KV cache.
1561// src layout: [M, qkv_stride] with K/V at src_float_offset within each row.
1562// dst layout: contiguous [max_seq, kv_dim] cache.
1563kernel void copy_kv_batch(
1564 device const float* src [[buffer(0)]], // batch QKV buffer
1565 device float* dst [[buffer(1)]], // KV cache
1566 constant uint& M [[buffer(2)]], // num tokens in batch
1567 constant uint& kv_dim [[buffer(3)]], // floats per KV vector
1568 constant uint& base_pos [[buffer(4)]], // starting position in cache
1569 constant uint& src_stride [[buffer(5)]], // floats per row in src (qkv_rows)
1570 constant uint& src_offset [[buffer(6)]], // float offset within each src row
1571 uint id [[thread_position_in_grid]])
1572{
1573 uint total = M * kv_dim;
1574 if (id >= total) return;
1575 uint token = id / kv_dim;
1576 uint d = id % kv_dim;
1577 uint dst_off = (base_pos + token) * kv_dim + d;
1578 uint src_off = token * src_stride + src_offset + d;
1579 dst[dst_off] = src[src_off];
1580}
1581
1582// ── attention_batch ───────────────────────────────────────────────────
1583// Batched causal attention for prefill. Processes M tokens in one dispatch.
1584// Each threadgroup handles one (token, head) pair.
1585// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
1586// Causal masking: token i can only attend to positions 0..base_pos+i.
1587kernel void attention_batch(
1588 device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
1589 device const float* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim]
1590 device const float* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim]
1591 device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
1592 constant uint& M [[buffer(4)]], // num tokens in batch
1593 constant uint& base_pos [[buffer(5)]], // starting position in KV cache
1594 constant uint& num_heads [[buffer(6)]],
1595 constant uint& num_kv_heads [[buffer(7)]],
1596 constant uint& head_dim [[buffer(8)]],
1597 constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
1598 uint tgid [[threadgroup_position_in_grid]],
1599 uint tid [[thread_index_in_threadgroup]],
1600 uint simd_lane [[thread_index_in_simdgroup]],
1601 uint simd_id [[simdgroup_index_in_threadgroup]])
1602{
1603 // Grid: M * num_heads threadgroups
1604 uint token_idx = tgid / num_heads;
1605 uint head = tgid % num_heads;
1606 if (token_idx >= M) return;
1607
1608 uint kv_head = head / (num_heads / num_kv_heads);
1609 uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx
1610
1611 // Q offset uses strided layout (from batch QKV buffer)
1612 uint q_off = token_idx * q_stride + head * head_dim;
1613 // Output is contiguous [M, num_heads * head_dim]
1614 uint out_off = token_idx * num_heads * head_dim + head * head_dim;
1615
1616 // Shared memory for attention scores
1617 threadgroup float scores[2048];
1618
1619 // Step 1: Q * K^T with simdgroup reduction
1620 for (uint s = simd_id; s < seq_len; s += 8) {
1621 uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
1622 float dot = 0.0;
1623 for (uint d = simd_lane; d < head_dim; d += 32) {
1624 dot += q_batch[q_off + d] * k_cache[k_off + d];
1625 }
1626 dot = simd_sum(dot);
1627 if (simd_lane == 0) {
1628 scores[s] = dot * fast::rsqrt(float(head_dim));
1629 }
1630 }
1631 threadgroup_barrier(mem_flags::mem_threadgroup);
1632
1633 // Step 2: Softmax (cooperative)
1634 float local_max = -INFINITY;
1635 for (uint s = tid; s < seq_len; s += 256) {
1636 local_max = max(local_max, scores[s]);
1637 }
1638 local_max = simd_max(local_max);
1639 threadgroup float shared_max[8];
1640 if (simd_lane == 0) shared_max[simd_id] = local_max;
1641 threadgroup_barrier(mem_flags::mem_threadgroup);
1642 if (tid == 0) {
1643 float m = shared_max[0];
1644 for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
1645 shared_max[0] = m;
1646 }
1647 threadgroup_barrier(mem_flags::mem_threadgroup);
1648 float max_val = shared_max[0];
1649
1650 float local_sum = 0.0;
1651 for (uint s = tid; s < seq_len; s += 256) {
1652 scores[s] = fast::exp(scores[s] - max_val);
1653 local_sum += scores[s];
1654 }
1655 local_sum = simd_sum(local_sum);
1656 threadgroup float shared_sum[8];
1657 if (simd_lane == 0) shared_sum[simd_id] = local_sum;
1658 threadgroup_barrier(mem_flags::mem_threadgroup);
1659 if (tid == 0) {
1660 float total = 0.0;
1661 for (uint i = 0; i < 8; i++) total += shared_sum[i];
1662 shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
1663 }
1664 threadgroup_barrier(mem_flags::mem_threadgroup);
1665 float inv_sum = shared_sum[0];
1666 for (uint s = tid; s < seq_len; s += 256) {
1667 scores[s] *= inv_sum;
1668 }
1669 threadgroup_barrier(mem_flags::mem_threadgroup);
1670
1671 // Step 3: scores * V using float4 vectorized loads
1672 // With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
1673 // This is much better than the scalar version where only 64 of 256 threads are active.
1674 uint v_stride = num_kv_heads * head_dim;
1675 uint head_dim4 = head_dim / 4;
1676 for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
1677 uint d = d4 * 4;
1678 float4 acc = float4(0.0);
1679 uint v_base = kv_head * head_dim + d;
1680 uint seq_len4 = seq_len & ~3u;
1681 for (uint s = 0; s < seq_len4; s += 4) {
1682 float sc0 = scores[s];
1683 float sc1 = scores[s + 1];
1684 float sc2 = scores[s + 2];
1685 float sc3 = scores[s + 3];
1686 acc += sc0 * *(device const float4*)(v_cache + s * v_stride + v_base)
1687 + sc1 * *(device const float4*)(v_cache + (s+1) * v_stride + v_base)
1688 + sc2 * *(device const float4*)(v_cache + (s+2) * v_stride + v_base)
1689 + sc3 * *(device const float4*)(v_cache + (s+3) * v_stride + v_base);
1690 }
1691 for (uint s = seq_len4; s < seq_len; s++) {
1692 acc += scores[s] * *(device const float4*)(v_cache + s * v_stride + v_base);
1693 }
1694 *(device float4*)(output_batch + out_off + d) = acc;
1695 }
1696 // Handle remaining dimensions not divisible by 4 (scalar fallback)
1697 for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
1698 float acc = 0.0;
1699 uint v_base = kv_head * head_dim + d;
1700 for (uint s = 0; s < seq_len; s++) {
1701 acc += scores[s] * v_cache[s * v_stride + v_base];
1702 }
1703 output_batch[out_off + d] = acc;
1704 }
1705}
1706
1707// ── rope_qk_batch ─────────────────────────────────────────────────────
1708// Fused RoPE for both Q and K in a single dispatch, saving one kernel
1709// launch + memory barrier per layer. Both Q and K live in the same
1710// qkv_data buffer at different offsets within each token's row.
1711// Q: offset 0, num_q_heads heads. K: offset num_q_heads*head_dim, num_kv_heads heads.
1712kernel void rope_qk_batch(
1713 device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
1714 constant uint& M [[buffer(1)]], // num tokens
1715 constant uint& base_pos [[buffer(2)]], // starting position
1716 constant uint& num_q_heads [[buffer(3)]],
1717 constant uint& num_kv_heads [[buffer(4)]],
1718 constant uint& head_dim [[buffer(5)]],
1719 constant uint& qkv_stride [[buffer(6)]], // floats per row
1720 constant float& theta [[buffer(7)]],
1721 uint id [[thread_position_in_grid]])
1722{
1723 uint half_dim = head_dim / 2;
1724 uint total_pairs = (num_q_heads + num_kv_heads) * half_dim;
1725 uint token = id / total_pairs;
1726 uint pair = id % total_pairs;
1727 if (token >= M) return;
1728
1729 uint pos = base_pos + token;
1730 uint q_pairs = num_q_heads * half_dim;
1731
1732 uint h, i, offset;
1733 if (pair < q_pairs) {
1734 // Q head
1735 h = pair / half_dim;
1736 i = pair % half_dim;
1737 offset = token * qkv_stride + h * head_dim + i * 2;
1738 } else {
1739 // K head
1740 uint kp = pair - q_pairs;
1741 h = kp / half_dim;
1742 i = kp % half_dim;
1743 uint k_start = num_q_heads * head_dim;
1744 offset = token * qkv_stride + k_start + h * head_dim + i * 2;
1745 }
1746
1747 float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
1748 float angle = float(pos) * freq;
1749 float cos_val = cos(angle);
1750 float sin_val = sin(angle);
1751
1752 float x0 = qkv_data[offset];
1753 float x1 = qkv_data[offset + 1];
1754 qkv_data[offset] = x0 * cos_val - x1 * sin_val;
1755 qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
1756}
1757
1758// ── copy_kv_both_batch ────────────────────────────────────────────────
1759// Fused K+V cache copy in a single dispatch: copies both K and V from
1760// the strided batch QKV buffer to their respective KV cache buffers.
1761// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
1762kernel void copy_kv_both_batch(
1763 device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride]
1764 device float* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim]
1765 device float* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim]
1766 constant uint& M [[buffer(3)]], // num tokens in batch
1767 constant uint& kv_dim [[buffer(4)]], // floats per KV vector
1768 constant uint& base_pos [[buffer(5)]], // starting position in cache
1769 constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
1770 constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
1771 constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
1772 uint id [[thread_position_in_grid]])
1773{
1774 // Total elements = M * kv_dim * 2 (K + V)
1775 uint total_kv = M * kv_dim;
1776 if (id >= total_kv * 2) return;
1777
1778 uint is_v = id / total_kv; // 0 = K, 1 = V
1779 uint local_id = id % total_kv;
1780 uint token = local_id / kv_dim;
1781 uint d = local_id % kv_dim;
1782
1783 uint dst_off = (base_pos + token) * kv_dim + d;
1784 uint src_off = token * src_stride + (is_v ? v_offset : k_offset) + d;
1785
1786 if (is_v) {
1787 v_dst[dst_off] = src[src_off];
1788 } else {
1789 k_dst[dst_off] = src[src_off];
1790 }
1791}
1792"#
1793 .replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
1794}
1795
1796fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
1801 let mut code = String::with_capacity(48 * 1024);
1802 emit_model_header(&mut code, config)?;
1803 emit_metal_model_struct(&mut code, config)?;
1804 emit_layer_buffers_struct(&mut code)?;
1805 emit_metal_model_impl(&mut code, config)?;
1806 emit_helper_functions(&mut code)?;
1807 Ok(code)
1808}
1809
1810fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
1811 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
1812 writeln!(
1813 code,
1814 "//! Model: {} ({} layers, hidden={})",
1815 config.architecture, config.num_layers, config.hidden_size
1816 )?;
1817 writeln!(code, "//!")?;
1818 writeln!(
1819 code,
1820 "//! Uses native Metal compute pipelines via the metal crate."
1821 )?;
1822 writeln!(code)?;
1823 writeln!(code, "#![allow(dead_code)]")?;
1824 writeln!(code)?;
1825 writeln!(code, "use metal::*;")?;
1826 writeln!(code, "#[allow(unused_imports)]")?;
1827 writeln!(code, "use metal::objc::{{sel, sel_impl}};")?;
1828 writeln!(code, "use std::mem;")?;
1829 writeln!(code)?;
1830
1831 writeln!(
1833 code,
1834 "// ── Model constants ──────────────────────────────────"
1835 )?;
1836 writeln!(
1837 code,
1838 "pub const HIDDEN_SIZE: usize = {};",
1839 config.hidden_size
1840 )?;
1841 writeln!(
1842 code,
1843 "pub const INTERMEDIATE_SIZE: usize = {};",
1844 config.intermediate_size
1845 )?;
1846 writeln!(code, "pub const NUM_LAYERS: usize = {};", config.num_layers)?;
1847 writeln!(
1848 code,
1849 "pub const NUM_HEADS: usize = {};",
1850 config.num_attention_heads
1851 )?;
1852 writeln!(
1853 code,
1854 "pub const NUM_KV_HEADS: usize = {};",
1855 config.num_kv_heads
1856 )?;
1857 writeln!(code, "pub const HEAD_DIM: usize = {};", config.head_dim)?;
1858 writeln!(code, "pub const VOCAB_SIZE: usize = {};", config.vocab_size)?;
1859 let effective_seq_len = config.max_seq_len.min(4096);
1860 writeln!(
1861 code,
1862 "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
1863 effective_seq_len, config.max_seq_len
1864 )?;
1865 writeln!(
1866 code,
1867 "pub const RMS_NORM_EPS: f32 = {:e};",
1868 config.rms_norm_eps
1869 )?;
1870 writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;
1871 writeln!(
1872 code,
1873 "/// Maximum batch size for batched prefill (prompt tokens processed at once)."
1874 )?;
1875 writeln!(code, "pub const MAX_BATCH_SIZE: usize = 512;")?;
1876 writeln!(code)?;
1877
1878 Ok(())
1879}
1880
1881fn emit_metal_model_struct(
1882 code: &mut String,
1883 _config: &ModelConfig,
1884) -> Result<(), MetalCodegenError> {
1885 writeln!(
1886 code,
1887 "// ── MetalModel ──────────────────────────────────────────"
1888 )?;
1889 writeln!(code)?;
1890 writeln!(
1891 code,
1892 "/// Metal-accelerated transformer model for Apple Silicon."
1893 )?;
1894 writeln!(code, "///")?;
1895 writeln!(
1896 code,
1897 "/// Uses unified memory for zero-copy weight access and native Metal"
1898 )?;
1899 writeln!(code, "/// compute pipelines for GPU-accelerated inference.")?;
1900 writeln!(code, "pub struct MetalModel {{")?;
1901 writeln!(code, " device: Device,")?;
1902 writeln!(code, " queue: CommandQueue,")?;
1903 writeln!(code)?;
1904 writeln!(code, " // ── Compute pipelines ──")?;
1905 writeln!(code, " matmul_pipeline: ComputePipelineState,")?;
1906 writeln!(code, " matmul_q8_pipeline: ComputePipelineState,")?;
1907 writeln!(code, " matmul_q4_pipeline: ComputePipelineState,")?;
1908 writeln!(code, " rms_norm_pipeline: ComputePipelineState,")?;
1909 writeln!(code, " rope_pipeline: ComputePipelineState,")?;
1910 writeln!(code, " softmax_pipeline: ComputePipelineState,")?;
1911 writeln!(code, " silu_mul_pipeline: ComputePipelineState,")?;
1912 writeln!(code, " silu_mul_fused_pipeline: ComputePipelineState,")?;
1913 writeln!(code, " add_pipeline: ComputePipelineState,")?;
1914 writeln!(code, " attention_pipeline: ComputePipelineState,")?;
1915 writeln!(code, " add_inplace_pipeline: ComputePipelineState,")?;
1916 writeln!(code, " copy_pipeline: ComputePipelineState,")?;
1917 writeln!(code, " copy_offset_pipeline: ComputePipelineState,")?;
1918 writeln!(code)?;
1919 writeln!(code, " // ── Batched prefill pipelines ──")?;
1920 writeln!(code, " matmul_batch_pipeline: ComputePipelineState,")?;
1921 writeln!(code, " matmul_q8_batch_pipeline: ComputePipelineState,")?;
1922 writeln!(
1923 code,
1924 " matmul_q8_gemm_batch_pipeline: ComputePipelineState,"
1925 )?;
1926 writeln!(code, " matmul_q4_batch_pipeline: ComputePipelineState,")?;
1927 writeln!(code, " rms_norm_batch_pipeline: ComputePipelineState,")?;
1928 writeln!(code, " rope_batch_pipeline: ComputePipelineState,")?;
1929 writeln!(
1930 code,
1931 " silu_mul_fused_batch_pipeline: ComputePipelineState,"
1932 )?;
1933 writeln!(
1934 code,
1935 " add_inplace_batch_pipeline: ComputePipelineState,"
1936 )?;
1937 writeln!(
1938 code,
1939 " copy_embedding_batch_pipeline: ComputePipelineState,"
1940 )?;
1941 writeln!(code, " attention_batch_pipeline: ComputePipelineState,")?;
1942 writeln!(code, " copy_kv_batch_pipeline: ComputePipelineState,")?;
1943 writeln!(code, " rope_qk_batch_pipeline: ComputePipelineState,")?;
1944 writeln!(
1945 code,
1946 " copy_kv_both_batch_pipeline: ComputePipelineState,"
1947 )?;
1948 writeln!(code)?;
1949 writeln!(code, " // ── Weight buffers (Metal shared memory) ──")?;
1950 writeln!(
1951 code,
1952 " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]"
1953 )?;
1954 writeln!(code, " embed_buf: Buffer,")?;
1955 writeln!(code)?;
1956 writeln!(code, " /// Per-layer weight buffers")?;
1957 writeln!(code, " layers: Vec<LayerBuffers>,")?;
1958 writeln!(code)?;
1959 writeln!(code, " /// Final layer-norm weight [HIDDEN_SIZE]")?;
1960 writeln!(code, " norm_buf: Buffer,")?;
1961 writeln!(code)?;
1962 writeln!(
1963 code,
1964 " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)"
1965 )?;
1966 writeln!(code, " lm_head_buf: Buffer,")?;
1967 writeln!(code)?;
1968 writeln!(
1969 code,
1970 " // ── Working buffers (pre-allocated, reused every forward pass) ──"
1971 )?;
1972 writeln!(code, " hidden_buf: Buffer,")?;
1973 writeln!(code, " residual_buf: Buffer,")?;
1974 writeln!(code, " normed_buf: Buffer,")?;
1975 writeln!(
1976 code,
1977 " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices"
1978 )?;
1979 writeln!(code, " qkv_buf: Buffer,")?;
1980 writeln!(code, " attn_out_buf: Buffer,")?;
1981 writeln!(code, " attn_proj_buf: Buffer,")?;
1982 writeln!(
1983 code,
1984 " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices"
1985 )?;
1986 writeln!(code, " gate_up_buf: Buffer,")?;
1987 writeln!(code, " ffn_hidden_buf: Buffer,")?;
1988 writeln!(code, " ffn_out_buf: Buffer,")?;
1989 writeln!(code, " add_tmp_buf: Buffer,")?;
1990 writeln!(code, " logits_buf: Buffer,")?;
1991 writeln!(code)?;
1992 writeln!(code, " // ── Batched prefill working buffers ──")?;
1993 writeln!(code, " /// Batch hidden states [MAX_BATCH, HIDDEN_SIZE]")?;
1994 writeln!(code, " batch_hidden_buf: Buffer,")?;
1995 writeln!(
1996 code,
1997 " /// Batch residual states [MAX_BATCH, HIDDEN_SIZE]"
1998 )?;
1999 writeln!(code, " batch_residual_buf: Buffer,")?;
2000 writeln!(code, " /// Batch QKV output [MAX_BATCH, qkv_rows]")?;
2001 writeln!(code, " batch_qkv_buf: Buffer,")?;
2002 writeln!(
2003 code,
2004 " /// Batch attention output [MAX_BATCH, HIDDEN_SIZE]"
2005 )?;
2006 writeln!(code, " batch_attn_out_buf: Buffer,")?;
2007 writeln!(
2008 code,
2009 " /// Batch attention projection [MAX_BATCH, HIDDEN_SIZE]"
2010 )?;
2011 writeln!(code, " batch_attn_proj_buf: Buffer,")?;
2012 writeln!(
2013 code,
2014 " /// Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]"
2015 )?;
2016 writeln!(code, " batch_gate_up_buf: Buffer,")?;
2017 writeln!(
2018 code,
2019 " /// Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]"
2020 )?;
2021 writeln!(code, " batch_ffn_hidden_buf: Buffer,")?;
2022 writeln!(code, " /// Batch FFN output [MAX_BATCH, HIDDEN_SIZE]")?;
2023 writeln!(code, " batch_ffn_out_buf: Buffer,")?;
2024 writeln!(code, " /// Token IDs buffer for batch embedding lookup")?;
2025 writeln!(code, " batch_tokens_buf: Buffer,")?;
2026 writeln!(code, " /// Positions buffer for batch RoPE")?;
2027 writeln!(code, " batch_positions_buf: Buffer,")?;
2028 writeln!(code)?;
2029 writeln!(code, " // ── KV cache buffers (per-layer) ──")?;
2030 writeln!(code, " k_cache: Vec<Buffer>, // per-layer")?;
2031 writeln!(code, " v_cache: Vec<Buffer>, // per-layer")?;
2032 writeln!(code)?;
2033 writeln!(code, " // ── Inference state ──")?;
2034 writeln!(code, " pos: usize,")?;
2035 writeln!(code)?;
2036 writeln!(
2037 code,
2038 " /// Previous command buffer for double-buffered prefill."
2039 )?;
2040 writeln!(
2041 code,
2042 " /// While the GPU executes token N, the CPU can encode token N+1."
2043 )?;
2044 writeln!(code, " prev_cmd: Option<CommandBuffer>,")?;
2045 writeln!(code, "}}")?;
2046 writeln!(code)?;
2047
2048 Ok(())
2049}
2050
2051fn emit_layer_buffers_struct(code: &mut String) -> Result<(), MetalCodegenError> {
2052 writeln!(
2053 code,
2054 "/// Per-layer weight buffers for attention and FFN projections."
2055 )?;
2056 writeln!(code, "struct LayerBuffers {{")?;
2057 writeln!(code, " attn_norm: Buffer,")?;
2058 writeln!(
2059 code,
2060 " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)"
2061 )?;
2062 writeln!(code, " qkv_weight: Buffer,")?;
2063 writeln!(code, " o_weight: Buffer,")?;
2064 writeln!(code, " ffn_norm: Buffer,")?;
2065 writeln!(
2066 code,
2067 " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)"
2068 )?;
2069 writeln!(code, " gate_up_weight: Buffer,")?;
2070 writeln!(code, " down_weight: Buffer,")?;
2071 writeln!(code, "}}")?;
2072 writeln!(code)?;
2073
2074 Ok(())
2075}
2076
2077fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
2078 let hidden = config.hidden_size;
2079 let intermediate = config.intermediate_size;
2080 let _num_layers = config.num_layers;
2081 let num_heads = config.num_attention_heads;
2082 let num_kv_heads = config.num_kv_heads;
2083 let head_dim = config.head_dim;
2084 let vocab = config.vocab_size;
2085 let effective_seq_len = config.max_seq_len.min(4096);
2086 let is_q8 = config.dtype == DType::Q8_0;
2087 let is_q4 = config.dtype == DType::Q4_0;
2088 let kv_dim = num_kv_heads * head_dim;
2089
2090 writeln!(code, "impl MetalModel {{")?;
2091
2092 writeln!(
2094 code,
2095 " /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
2096 )?;
2097 writeln!(code, " ///")?;
2098 writeln!(
2099 code,
2100 " /// `weights` is the raw weight blob produced by `forge export-weights`."
2101 )?;
2102 writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
2103 writeln!(
2104 code,
2105 " let device = Device::system_default().expect(\"no Metal device found\");"
2106 )?;
2107 writeln!(code, " let queue = device.new_command_queue();")?;
2108 writeln!(code)?;
2109
2110 writeln!(
2112 code,
2113 " // Compile Metal shaders from embedded source"
2114 )?;
2115 writeln!(
2116 code,
2117 " let shader_source = include_str!(\"../shaders/kernels.metal\");"
2118 )?;
2119 writeln!(
2120 code,
2121 " let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
2122 )?;
2123 writeln!(
2124 code,
2125 " .expect(\"failed to compile Metal shaders\");"
2126 )?;
2127 writeln!(code)?;
2128
2129 writeln!(code, " // Create compute pipelines")?;
2131 for (var, fn_name) in [
2132 ("matmul_pipeline", "matmul_vec"),
2133 ("matmul_q8_pipeline", "matmul_vec_q8"),
2134 ("matmul_q4_pipeline", "matmul_vec_q4"),
2135 ("rms_norm_pipeline", "rms_norm"),
2136 ("rope_pipeline", "rope"),
2137 ("softmax_pipeline", "softmax"),
2138 ("silu_mul_pipeline", "silu_mul"),
2139 ("silu_mul_fused_pipeline", "silu_mul_fused"),
2140 ("add_pipeline", "elementwise_add"),
2141 ("attention_pipeline", "attention"),
2142 ("add_inplace_pipeline", "add_inplace"),
2143 ("copy_pipeline", "copy_buffer"),
2144 ("copy_offset_pipeline", "copy_offset"),
2145 ("matmul_batch_pipeline", "matmul_vec_batch"),
2146 ("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
2147 ("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
2148 ("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
2149 ("rms_norm_batch_pipeline", "rms_norm_batch"),
2150 ("rope_batch_pipeline", "rope_batch"),
2151 ("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
2152 ("add_inplace_batch_pipeline", "add_inplace_batch"),
2153 ("copy_embedding_batch_pipeline", "copy_embedding_batch"),
2154 ("attention_batch_pipeline", "attention_batch"),
2155 ("copy_kv_batch_pipeline", "copy_kv_batch"),
2156 ("rope_qk_batch_pipeline", "rope_qk_batch"),
2157 ("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
2158 ] {
2159 writeln!(
2160 code,
2161 " let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
2162 )?;
2163 }
2164 writeln!(code)?;
2165
2166 writeln!(
2168 code,
2169 " // Load weights into Metal shared-memory buffers"
2170 )?;
2171 writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
2172 writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
2173 writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
2174 writeln!(code)?;
2175 writeln!(
2176 code,
2177 " let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
2178 )?;
2179 writeln!(code)?;
2180 writeln!(
2181 code,
2182 " // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
2183 )?;
2184 writeln!(
2185 code,
2186 " let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
2187 )?;
2188 writeln!(code, " let byte_len = n * f32_size;")?;
2189 writeln!(code, " let cur = cursor.get();")?;
2190 writeln!(
2191 code,
2192 " let data = &weights[cur..cur + byte_len];"
2193 )?;
2194 writeln!(code, " cursor.set(cur + byte_len);")?;
2195 writeln!(code, " device.new_buffer_with_data(")?;
2196 writeln!(code, " data.as_ptr() as *const _,")?;
2197 writeln!(code, " byte_len as u64,")?;
2198 writeln!(
2199 code,
2200 " MTLResourceOptions::StorageModeShared,"
2201 )?;
2202 writeln!(code, " )")?;
2203 writeln!(code, " }};")?;
2204 writeln!(code)?;
2205
2206 if is_q8 {
2207 writeln!(
2212 code,
2213 " // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
2214 )?;
2215 writeln!(
2216 code,
2217 " // as raw bytes into a Metal buffer (no dequantization)."
2218 )?;
2219 writeln!(
2220 code,
2221 " // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
2222 )?;
2223 writeln!(
2224 code,
2225 " let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2226 )?;
2227 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
2228 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
2229 writeln!(code, " let total_raw = rows * row_bytes;")?;
2230 writeln!(code, " let cur = cursor.get();")?;
2231 writeln!(
2232 code,
2233 " let data = &weights[cur..cur + total_raw];"
2234 )?;
2235 writeln!(code, " cursor.set(cur + total_raw);")?;
2236 writeln!(code, " device.new_buffer_with_data(")?;
2237 writeln!(code, " data.as_ptr() as *const _,")?;
2238 writeln!(code, " total_raw as u64,")?;
2239 writeln!(
2240 code,
2241 " MTLResourceOptions::StorageModeShared,"
2242 )?;
2243 writeln!(code, " )")?;
2244 writeln!(code, " }};")?;
2245 writeln!(code)?;
2246 writeln!(
2247 code,
2248 " // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
2249 )?;
2250 writeln!(
2251 code,
2252 " // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
2253 )?;
2254 writeln!(
2255 code,
2256 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
2257 )?;
2258 writeln!(
2259 code,
2260 " let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
2261 )?;
2262 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
2263 writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
2264 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
2265 writeln!(code, " let cur = cursor.get();")?;
2266 writeln!(
2267 code,
2268 " let data = &weights[cur..cur + total_raw];"
2269 )?;
2270 writeln!(code, " cursor.set(cur + total_raw);")?;
2271 writeln!(code, " device.new_buffer_with_data(")?;
2272 writeln!(code, " data.as_ptr() as *const _,")?;
2273 writeln!(code, " total_raw as u64,")?;
2274 writeln!(
2275 code,
2276 " MTLResourceOptions::StorageModeShared,"
2277 )?;
2278 writeln!(code, " )")?;
2279 writeln!(code, " }};")?;
2280 writeln!(code)?;
2281 }
2282
2283 if is_q4 {
2284 writeln!(
2289 code,
2290 " // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
2291 )?;
2292 writeln!(
2293 code,
2294 " // as raw bytes into a Metal buffer (no dequantization)."
2295 )?;
2296 writeln!(
2297 code,
2298 " // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
2299 )?;
2300 writeln!(
2301 code,
2302 " let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
2303 )?;
2304 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
2305 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
2306 writeln!(code, " let total_raw = rows * row_bytes;")?;
2307 writeln!(code, " let cur = cursor.get();")?;
2308 writeln!(
2309 code,
2310 " let data = &weights[cur..cur + total_raw];"
2311 )?;
2312 writeln!(code, " cursor.set(cur + total_raw);")?;
2313 writeln!(code, " device.new_buffer_with_data(")?;
2314 writeln!(code, " data.as_ptr() as *const _,")?;
2315 writeln!(code, " total_raw as u64,")?;
2316 writeln!(
2317 code,
2318 " MTLResourceOptions::StorageModeShared,"
2319 )?;
2320 writeln!(code, " )")?;
2321 writeln!(code, " }};")?;
2322 writeln!(code)?;
2323 writeln!(
2324 code,
2325 " // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
2326 )?;
2327 writeln!(
2328 code,
2329 " // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
2330 )?;
2331 writeln!(
2332 code,
2333 " // Used for fused QKV and gate+up projections where weights are adjacent in the file."
2334 )?;
2335 writeln!(
2336 code,
2337 " let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
2338 )?;
2339 writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
2340 writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
2341 writeln!(code, " let total_raw = total_rows * row_bytes;")?;
2342 writeln!(code, " let cur = cursor.get();")?;
2343 writeln!(
2344 code,
2345 " let data = &weights[cur..cur + total_raw];"
2346 )?;
2347 writeln!(code, " cursor.set(cur + total_raw);")?;
2348 writeln!(code, " device.new_buffer_with_data(")?;
2349 writeln!(code, " data.as_ptr() as *const _,")?;
2350 writeln!(code, " total_raw as u64,")?;
2351 writeln!(
2352 code,
2353 " MTLResourceOptions::StorageModeShared,"
2354 )?;
2355 writeln!(code, " )")?;
2356 writeln!(code, " }};")?;
2357 writeln!(code)?;
2358 }
2359
2360 writeln!(
2361 code,
2362 " let embed_buf = next_f32_buffer(&device, embed_elems);"
2363 )?;
2364 writeln!(code)?;
2365
2366 writeln!(
2368 code,
2369 " let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
2370 )?;
2371 writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
2372
2373 writeln!(
2375 code,
2376 " let attn_norm = next_f32_buffer(&device, hidden_elems);"
2377 )?;
2378
2379 let qkv_rows = hidden + 2 * kv_dim;
2380 if is_q8 {
2381 writeln!(
2383 code,
2384 " let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
2385 )?;
2386 writeln!(
2387 code,
2388 " let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
2389 )?;
2390 } else if is_q4 {
2391 writeln!(
2393 code,
2394 " let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
2395 )?;
2396 writeln!(
2397 code,
2398 " let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
2399 )?;
2400 } else {
2401 writeln!(
2403 code,
2404 " let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
2405 )?;
2406 writeln!(
2407 code,
2408 " let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
2409 )?;
2410 }
2411
2412 writeln!(
2414 code,
2415 " let ffn_norm = next_f32_buffer(&device, hidden_elems);"
2416 )?;
2417
2418 let gate_up_rows = 2 * intermediate;
2419 if is_q8 {
2420 writeln!(
2422 code,
2423 " let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
2424 )?;
2425 writeln!(
2426 code,
2427 " let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
2428 )?;
2429 } else if is_q4 {
2430 writeln!(
2432 code,
2433 " let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
2434 )?;
2435 writeln!(
2436 code,
2437 " let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
2438 )?;
2439 } else {
2440 writeln!(
2442 code,
2443 " let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
2444 )?;
2445 writeln!(
2446 code,
2447 " let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
2448 )?;
2449 }
2450
2451 writeln!(code, " layers.push(LayerBuffers {{")?;
2452 writeln!(code, " attn_norm,")?;
2453 writeln!(code, " qkv_weight,")?;
2454 writeln!(code, " o_weight,")?;
2455 writeln!(code, " ffn_norm,")?;
2456 writeln!(code, " gate_up_weight,")?;
2457 writeln!(code, " down_weight,")?;
2458 writeln!(code, " }});")?;
2459 writeln!(code, " }}")?;
2460 writeln!(code)?;
2461
2462 writeln!(
2464 code,
2465 " let norm_buf = next_f32_buffer(&device, hidden_elems);"
2466 )?;
2467 writeln!(code)?;
2468
2469 if is_q8 {
2471 writeln!(
2472 code,
2473 " let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
2474 )?;
2475 } else if is_q4 {
2476 writeln!(
2477 code,
2478 " let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
2479 )?;
2480 } else {
2481 writeln!(
2482 code,
2483 " let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
2484 )?;
2485 }
2486 writeln!(code)?;
2487
2488 let hidden_bytes = hidden * 4;
2490 let _kv_dim_bytes = kv_dim * 4;
2491 let intermediate_bytes = intermediate * 4;
2492 let vocab_bytes = vocab * 4;
2493 let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 4;
2494
2495 writeln!(
2496 code,
2497 " // Allocate working buffers (shared memory for zero-copy)"
2498 )?;
2499 writeln!(
2500 code,
2501 " let opts = MTLResourceOptions::StorageModeShared;"
2502 )?;
2503 writeln!(
2504 code,
2505 " let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2506 )?;
2507 writeln!(
2508 code,
2509 " let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2510 )?;
2511 let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
2512 writeln!(
2513 code,
2514 " let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2515 )?;
2516 writeln!(
2517 code,
2518 " // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
2519 )?;
2520 writeln!(
2521 code,
2522 " let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
2523 )?;
2524 writeln!(
2525 code,
2526 " let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2527 )?;
2528 writeln!(
2529 code,
2530 " let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2531 )?;
2532 let gate_up_buf_bytes = 2 * intermediate * 4;
2533 writeln!(
2534 code,
2535 " // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
2536 )?;
2537 writeln!(
2538 code,
2539 " let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
2540 )?;
2541 writeln!(
2542 code,
2543 " let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
2544 )?;
2545 writeln!(
2546 code,
2547 " let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2548 )?;
2549 writeln!(
2550 code,
2551 " let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
2552 )?;
2553 writeln!(
2554 code,
2555 " let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
2556 )?;
2557 writeln!(code)?;
2558
2559 let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
2562 let batch_gate_up_bytes = 2 * intermediate * 4;
2563 let batch_intermediate_bytes = intermediate * 4;
2564 writeln!(
2565 code,
2566 " // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
2567 )?;
2568 writeln!(
2569 code,
2570 " let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2571 )?;
2572 writeln!(
2573 code,
2574 " let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2575 )?;
2576 writeln!(
2577 code,
2578 " let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
2579 )?;
2580 writeln!(
2581 code,
2582 " let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2583 )?;
2584 writeln!(
2585 code,
2586 " let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2587 )?;
2588 writeln!(
2589 code,
2590 " let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
2591 )?;
2592 writeln!(
2593 code,
2594 " let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
2595 )?;
2596 writeln!(
2597 code,
2598 " let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
2599 )?;
2600 writeln!(
2601 code,
2602 " let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
2603 )?;
2604 writeln!(
2605 code,
2606 " let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
2607 )?;
2608 writeln!(code)?;
2609
2610 writeln!(code, " // KV cache buffers (per-layer)")?;
2612 writeln!(
2613 code,
2614 " let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
2615 )?;
2616 writeln!(
2617 code,
2618 " let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
2619 )?;
2620 writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
2621 writeln!(
2622 code,
2623 " k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
2624 )?;
2625 writeln!(
2626 code,
2627 " v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
2628 )?;
2629 writeln!(code, " }}")?;
2630 writeln!(code)?;
2631
2632 writeln!(code, " Self {{")?;
2633 writeln!(code, " device,")?;
2634 writeln!(code, " queue,")?;
2635 writeln!(code, " matmul_pipeline,")?;
2636 writeln!(code, " matmul_q8_pipeline,")?;
2637 writeln!(code, " matmul_q4_pipeline,")?;
2638 writeln!(code, " rms_norm_pipeline,")?;
2639 writeln!(code, " rope_pipeline,")?;
2640 writeln!(code, " softmax_pipeline,")?;
2641 writeln!(code, " silu_mul_pipeline,")?;
2642 writeln!(code, " silu_mul_fused_pipeline,")?;
2643 writeln!(code, " add_pipeline,")?;
2644 writeln!(code, " attention_pipeline,")?;
2645 writeln!(code, " add_inplace_pipeline,")?;
2646 writeln!(code, " copy_pipeline,")?;
2647 writeln!(code, " copy_offset_pipeline,")?;
2648 writeln!(code, " matmul_batch_pipeline,")?;
2649 writeln!(code, " matmul_q8_batch_pipeline,")?;
2650 writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
2651 writeln!(code, " matmul_q4_batch_pipeline,")?;
2652 writeln!(code, " rms_norm_batch_pipeline,")?;
2653 writeln!(code, " rope_batch_pipeline,")?;
2654 writeln!(code, " silu_mul_fused_batch_pipeline,")?;
2655 writeln!(code, " add_inplace_batch_pipeline,")?;
2656 writeln!(code, " copy_embedding_batch_pipeline,")?;
2657 writeln!(code, " attention_batch_pipeline,")?;
2658 writeln!(code, " copy_kv_batch_pipeline,")?;
2659 writeln!(code, " rope_qk_batch_pipeline,")?;
2660 writeln!(code, " copy_kv_both_batch_pipeline,")?;
2661 writeln!(code, " embed_buf,")?;
2662 writeln!(code, " layers,")?;
2663 writeln!(code, " norm_buf,")?;
2664 writeln!(code, " lm_head_buf,")?;
2665 writeln!(code, " hidden_buf,")?;
2666 writeln!(code, " residual_buf,")?;
2667 writeln!(code, " normed_buf,")?;
2668 writeln!(code, " qkv_buf,")?;
2669 writeln!(code, " attn_out_buf,")?;
2670 writeln!(code, " attn_proj_buf,")?;
2671 writeln!(code, " gate_up_buf,")?;
2672 writeln!(code, " ffn_hidden_buf,")?;
2673 writeln!(code, " ffn_out_buf,")?;
2674 writeln!(code, " add_tmp_buf,")?;
2675 writeln!(code, " logits_buf,")?;
2676 writeln!(code, " batch_hidden_buf,")?;
2677 writeln!(code, " batch_residual_buf,")?;
2678 writeln!(code, " batch_qkv_buf,")?;
2679 writeln!(code, " batch_attn_out_buf,")?;
2680 writeln!(code, " batch_attn_proj_buf,")?;
2681 writeln!(code, " batch_gate_up_buf,")?;
2682 writeln!(code, " batch_ffn_hidden_buf,")?;
2683 writeln!(code, " batch_ffn_out_buf,")?;
2684 writeln!(code, " batch_tokens_buf,")?;
2685 writeln!(code, " batch_positions_buf,")?;
2686 writeln!(code, " k_cache,")?;
2687 writeln!(code, " v_cache,")?;
2688 writeln!(code, " pos: 0,")?;
2689 writeln!(code, " prev_cmd: None,")?;
2690 writeln!(code, " }}")?;
2691 writeln!(code, " }}")?;
2692 writeln!(code)?;
2693
2694 writeln!(
2696 code,
2697 " /// Run the forward pass for a single token at the current position."
2698 )?;
2699 writeln!(code, " ///")?;
2700 writeln!(
2701 code,
2702 " /// Returns logits over the vocabulary as a `Vec<f32>`."
2703 )?;
2704 writeln!(code, " ///")?;
2705 writeln!(
2706 code,
2707 " /// All GPU operations are encoded into a single command buffer and"
2708 )?;
2709 writeln!(
2710 code,
2711 " /// committed once at the end, avoiding per-operation synchronization."
2712 )?;
2713 writeln!(
2714 code,
2715 " pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
2716 )?;
2717 writeln!(
2718 code,
2719 " // Wait for any pending prefill command buffer"
2720 )?;
2721 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
2722 writeln!(code, " prev.wait_until_completed();")?;
2723 writeln!(code, " }}")?;
2724 writeln!(code)?;
2725 writeln!(code, " let pos = self.pos;")?;
2726 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
2727 writeln!(code)?;
2728
2729 let matmul_fn = if is_q8 {
2732 "dispatch_matmul_q8"
2733 } else if is_q4 {
2734 "dispatch_matmul_q4"
2735 } else {
2736 "dispatch_matmul"
2737 };
2738
2739 writeln!(
2740 code,
2741 " // Single compute encoder for the entire forward pass (no blit transitions)"
2742 )?;
2743 writeln!(code, " {{")?;
2744 writeln!(
2745 code,
2746 " let enc = cmd.new_compute_command_encoder();"
2747 )?;
2748 writeln!(code)?;
2749
2750 writeln!(
2752 code,
2753 " // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
2754 )?;
2755 writeln!(
2756 code,
2757 " // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
2758 )?;
2759 writeln!(
2760 code,
2761 " // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
2762 )?;
2763 writeln!(
2764 code,
2765 " // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
2766 hidden * 4,
2767 )?;
2768 writeln!(code, " unsafe {{")?;
2769 writeln!(
2770 code,
2771 " let embed_ptr = self.embed_buf.contents() as *const f32;"
2772 )?;
2773 writeln!(
2774 code,
2775 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
2776 )?;
2777 writeln!(
2778 code,
2779 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
2780 )?;
2781 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
2782 writeln!(
2783 code,
2784 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
2785 )?;
2786 writeln!(code, " hidden_ptr,")?;
2787 writeln!(code, " HIDDEN_SIZE,")?;
2788 writeln!(code, " );")?;
2789 writeln!(
2790 code,
2791 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
2792 )?;
2793 writeln!(code, " }}")?;
2794 writeln!(code)?;
2795
2796 writeln!(code, " // 2. Transformer layers")?;
2798 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
2799 writeln!(code)?;
2800 let q_byte_offset = 0usize;
2801 let k_byte_offset = hidden * 4;
2802 let v_byte_offset = (hidden + kv_dim) * 4;
2803
2804 writeln!(
2805 code,
2806 " // Pre-attention: rms_norm, fused QKV projection, RoPE"
2807 )?;
2808 writeln!(
2809 code,
2810 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
2811 )?;
2812 writeln!(
2813 code,
2814 " // Fused Q+K+V matmul: single dispatch for all three projections"
2815 )?;
2816 writeln!(
2817 code,
2818 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
2819 )?;
2820 writeln!(
2821 code,
2822 " // RoPE on Q portion (qkv_buf offset 0) and K portion (qkv_buf offset {k_byte_offset})"
2823 )?;
2824 writeln!(
2825 code,
2826 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
2827 )?;
2828 writeln!(
2829 code,
2830 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
2831 )?;
2832 writeln!(code)?;
2833 writeln!(
2834 code,
2835 " // KV cache update from fused qkv_buf (K at offset {k_byte_offset}, V at offset {v_byte_offset})"
2836 )?;
2837 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
2838 writeln!(
2839 code,
2840 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
2841 )?;
2842 writeln!(
2843 code,
2844 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
2845 )?;
2846 writeln!(code)?;
2847 writeln!(
2848 code,
2849 " // Attention using Q from qkv_buf (offset 0)"
2850 )?;
2851 writeln!(
2852 code,
2853 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
2854 )?;
2855 writeln!(
2856 code,
2857 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
2858 )?;
2859 writeln!(
2860 code,
2861 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
2862 )?;
2863 writeln!(
2864 code,
2865 " // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
2866 )?;
2867 writeln!(
2868 code,
2869 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
2870 )?;
2871 writeln!(
2872 code,
2873 " // Fused gate+up matmul: single dispatch for both projections"
2874 )?;
2875 writeln!(
2876 code,
2877 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
2878 )?;
2879 writeln!(
2880 code,
2881 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
2882 )?;
2883 writeln!(
2884 code,
2885 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
2886 )?;
2887 writeln!(
2888 code,
2889 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
2890 )?;
2891 writeln!(code, " }}")?;
2892 writeln!(code)?;
2893
2894 writeln!(code, " // 3. Final RMS norm + logits projection")?;
2896 writeln!(
2897 code,
2898 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
2899 )?;
2900 writeln!(
2901 code,
2902 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
2903 )?;
2904 writeln!(code)?;
2905 writeln!(code, " enc.end_encoding();")?;
2906 writeln!(code, " }}")?;
2907 writeln!(code)?;
2908
2909 writeln!(
2911 code,
2912 " // 5. Commit all GPU work and wait for completion"
2913 )?;
2914 writeln!(code, " cmd.commit();")?;
2915 writeln!(code, " cmd.wait_until_completed();")?;
2916 writeln!(code)?;
2917 writeln!(code, " // 6. Read back logits from GPU")?;
2918 writeln!(code, " let logits = unsafe {{")?;
2919 writeln!(
2920 code,
2921 " let ptr = self.logits_buf.contents() as *const f32;"
2922 )?;
2923 writeln!(
2924 code,
2925 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
2926 )?;
2927 writeln!(code, " }};")?;
2928 writeln!(code)?;
2929 writeln!(code, " self.pos += 1;")?;
2930 writeln!(code, " logits")?;
2931 writeln!(code, " }}")?;
2932 writeln!(code)?;
2933
2934 writeln!(
2936 code,
2937 " /// Profiling forward pass that prints per-stage GPU timing."
2938 )?;
2939 writeln!(code, " ///")?;
2940 writeln!(
2941 code,
2942 " /// Each stage is committed and waited on separately so that GPU timestamps"
2943 )?;
2944 writeln!(
2945 code,
2946 " /// accurately reflect per-operation cost. This is slower than `forward()` due"
2947 )?;
2948 writeln!(
2949 code,
2950 " /// to the per-stage synchronization, but useful for identifying bottlenecks."
2951 )?;
2952 writeln!(
2953 code,
2954 " pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
2955 )?;
2956 writeln!(code, " use std::time::Instant;")?;
2957 writeln!(code)?;
2958 writeln!(
2959 code,
2960 " // Wait for any pending prefill command buffer"
2961 )?;
2962 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
2963 writeln!(code, " prev.wait_until_completed();")?;
2964 writeln!(code, " }}")?;
2965 writeln!(code)?;
2966 writeln!(code, " let pos = self.pos;")?;
2967 writeln!(code)?;
2968
2969 writeln!(
2971 code,
2972 " // ── Stage: Embedding lookup (CPU via unified memory) ──"
2973 )?;
2974 writeln!(code, " let t_embed = Instant::now();")?;
2975 writeln!(code, " unsafe {{")?;
2976 writeln!(
2977 code,
2978 " let embed_ptr = self.embed_buf.contents() as *const f32;"
2979 )?;
2980 writeln!(
2981 code,
2982 " let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
2983 )?;
2984 writeln!(
2985 code,
2986 " let residual_ptr = self.residual_buf.contents() as *mut f32;"
2987 )?;
2988 writeln!(code, " std::ptr::copy_nonoverlapping(")?;
2989 writeln!(
2990 code,
2991 " embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
2992 )?;
2993 writeln!(code, " hidden_ptr,")?;
2994 writeln!(code, " HIDDEN_SIZE,")?;
2995 writeln!(code, " );")?;
2996 writeln!(
2997 code,
2998 " std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
2999 )?;
3000 writeln!(code, " }}")?;
3001 writeln!(code, " let d_embed = t_embed.elapsed();")?;
3002 writeln!(code)?;
3003
3004 writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
3006 writeln!(code, " let t_layers = Instant::now();")?;
3007 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3008 writeln!(code, " {{")?;
3009 writeln!(
3010 code,
3011 " let enc = cmd.new_compute_command_encoder();"
3012 )?;
3013 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
3014 writeln!(
3015 code,
3016 " self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
3017 )?;
3018 writeln!(
3019 code,
3020 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
3021 )?;
3022 writeln!(
3023 code,
3024 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {q_byte_offset}, {num_heads}, {head_dim}, pos);"
3025 )?;
3026 writeln!(
3027 code,
3028 " self.dispatch_rope_offset(&enc, &self.qkv_buf, {k_byte_offset}, {num_kv_heads}, {head_dim}, pos);"
3029 )?;
3030 writeln!(code, " let kv_offset = pos * {kv_dim};")?;
3031 writeln!(
3032 code,
3033 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {k_byte_offset}, &self.k_cache[layer], kv_offset, {kv_dim});"
3034 )?;
3035 writeln!(
3036 code,
3037 " self.dispatch_copy_from_offset(&enc, &self.qkv_buf, {v_byte_offset}, &self.v_cache[layer], kv_offset, {kv_dim});"
3038 )?;
3039 writeln!(
3040 code,
3041 " self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
3042 )?;
3043 writeln!(
3044 code,
3045 " self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
3046 )?;
3047 writeln!(
3048 code,
3049 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
3050 )?;
3051 writeln!(
3052 code,
3053 " self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
3054 )?;
3055 writeln!(
3056 code,
3057 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
3058 )?;
3059 writeln!(
3060 code,
3061 " self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
3062 )?;
3063 writeln!(
3064 code,
3065 " self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
3066 )?;
3067 writeln!(
3068 code,
3069 " self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
3070 )?;
3071 writeln!(code, " }}")?;
3072 writeln!(code, " enc.end_encoding();")?;
3073 writeln!(code, " }}")?;
3074 writeln!(code, " cmd.commit();")?;
3075 writeln!(code, " cmd.wait_until_completed();")?;
3076 writeln!(code, " let d_layers = t_layers.elapsed();")?;
3077 writeln!(code)?;
3078
3079 writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
3081 writeln!(code, " let t_logits = Instant::now();")?;
3082 writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
3083 writeln!(code, " {{")?;
3084 writeln!(
3085 code,
3086 " let enc = cmd2.new_compute_command_encoder();"
3087 )?;
3088 writeln!(
3089 code,
3090 " self.dispatch_rms_norm(&enc, &self.norm_buf);"
3091 )?;
3092 writeln!(
3093 code,
3094 " self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
3095 )?;
3096 writeln!(code, " enc.end_encoding();")?;
3097 writeln!(code, " }}")?;
3098 writeln!(code, " cmd2.commit();")?;
3099 writeln!(code, " cmd2.wait_until_completed();")?;
3100 writeln!(code, " let d_logits = t_logits.elapsed();")?;
3101 writeln!(code)?;
3102
3103 writeln!(
3105 code,
3106 " eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
3107 )?;
3108 writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
3109 writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
3110 writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
3111 writeln!(
3112 code,
3113 " (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
3114 )?;
3115 writeln!(code)?;
3116
3117 writeln!(code, " let logits = unsafe {{")?;
3119 writeln!(
3120 code,
3121 " let ptr = self.logits_buf.contents() as *const f32;"
3122 )?;
3123 writeln!(
3124 code,
3125 " std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
3126 )?;
3127 writeln!(code, " }};")?;
3128 writeln!(code)?;
3129 writeln!(code, " self.pos += 1;")?;
3130 writeln!(code, " logits")?;
3131 writeln!(code, " }}")?;
3132 writeln!(code)?;
3133
3134 writeln!(
3136 code,
3137 " /// Asynchronous forward pass for a single prefill token (no logits readback)."
3138 )?;
3139 writeln!(code, " ///")?;
3140 writeln!(
3141 code,
3142 " /// Commits the command buffer without waiting, enabling double-buffered"
3143 )?;
3144 writeln!(
3145 code,
3146 " /// execution: GPU processes token N while CPU encodes token N+1."
3147 )?;
3148 writeln!(
3149 code,
3150 " pub fn forward_prefill(&mut self, token_id: u32) {{"
3151 )?;
3152 writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
3153 writeln!(code, " }}")?;
3154 writeln!(code)?;
3155
3156 let batch_matmul_fn = if is_q8 {
3159 "dispatch_matmul_q8_batch"
3160 } else if is_q4 {
3161 "dispatch_matmul_q4_batch"
3162 } else {
3163 "dispatch_matmul_batch"
3164 };
3165
3166 writeln!(
3167 code,
3168 " /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
3169 )?;
3170 writeln!(code, " ///")?;
3171 writeln!(
3172 code,
3173 " /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
3174 )?;
3175 writeln!(
3176 code,
3177 " /// of mat-vec), and batched causal attention with a single GPU dispatch."
3178 )?;
3179 writeln!(
3180 code,
3181 " /// This provides significant speedup during prompt prefill."
3182 )?;
3183 writeln!(
3184 code,
3185 " pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
3186 )?;
3187 writeln!(code, " let m = tokens.len().min(MAX_BATCH_SIZE);")?;
3188 writeln!(code, " if m == 0 {{ return; }}")?;
3189 writeln!(code, " let start_pos = self.pos;")?;
3190 writeln!(code)?;
3191 writeln!(code, " // Wait for any pending command buffer")?;
3192 writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
3193 writeln!(code, " prev.wait_until_completed();")?;
3194 writeln!(code, " }}")?;
3195 writeln!(code)?;
3196
3197 writeln!(
3199 code,
3200 " // Upload token IDs and positions to GPU buffers"
3201 )?;
3202 writeln!(code, " unsafe {{")?;
3203 writeln!(
3204 code,
3205 " let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
3206 )?;
3207 writeln!(
3208 code,
3209 " let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
3210 )?;
3211 writeln!(code, " for i in 0..m {{")?;
3212 writeln!(code, " *tok_ptr.add(i) = tokens[i];")?;
3213 writeln!(
3214 code,
3215 " *pos_ptr.add(i) = (start_pos + i) as u32;"
3216 )?;
3217 writeln!(code, " }}")?;
3218 writeln!(code, " }}")?;
3219 writeln!(code)?;
3220
3221 writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
3222 writeln!(code, " {{")?;
3223 writeln!(
3224 code,
3225 " let enc = cmd.new_compute_command_encoder();"
3226 )?;
3227 writeln!(code)?;
3228
3229 writeln!(
3231 code,
3232 " // 1. Batch embedding lookup: copy all token embeddings at once"
3233 )?;
3234 writeln!(
3235 code,
3236 " self.dispatch_copy_embedding_batch(&enc, m);"
3237 )?;
3238 writeln!(
3240 code,
3241 " self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
3242 )?;
3243 writeln!(code)?;
3244
3245 writeln!(code, " // 2. Transformer layers")?;
3247 writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
3248 writeln!(code)?;
3249
3250 writeln!(
3252 code,
3253 " // Batch RMS norm: batch_residual -> batch_hidden"
3254 )?;
3255 writeln!(
3256 code,
3257 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
3258 )?;
3259
3260 writeln!(
3262 code,
3263 " // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
3264 )?;
3265 writeln!(
3266 code,
3267 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
3268 )?;
3269 writeln!(code)?;
3270
3271 let k_float_offset = hidden;
3273 writeln!(
3274 code,
3275 " // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
3276 )?;
3277 writeln!(
3278 code,
3279 " self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
3280 )?;
3281 writeln!(code)?;
3282
3283 let v_float_offset = hidden + kv_dim;
3285 writeln!(
3286 code,
3287 " // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
3288 )?;
3289 writeln!(
3290 code,
3291 " self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
3292 )?;
3293 writeln!(code)?;
3294
3295 writeln!(
3297 code,
3298 " // Batched causal attention: one dispatch for all M tokens"
3299 )?;
3300 writeln!(
3301 code,
3302 " self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
3303 )?;
3304 writeln!(code)?;
3305
3306 writeln!(code, " // Batched O projection")?;
3308 writeln!(
3309 code,
3310 " self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
3311 )?;
3312 writeln!(code)?;
3313
3314 writeln!(
3316 code,
3317 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
3318 )?;
3319 writeln!(code)?;
3320
3321 writeln!(
3323 code,
3324 " // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
3325 )?;
3326 writeln!(
3327 code,
3328 " self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
3329 )?;
3330 writeln!(
3331 code,
3332 " self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
3333 )?;
3334 writeln!(
3335 code,
3336 " self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
3337 )?;
3338 writeln!(
3339 code,
3340 " self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
3341 )?;
3342 writeln!(
3343 code,
3344 " self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
3345 )?;
3346 writeln!(code, " }}")?;
3347 writeln!(code)?;
3348
3349 writeln!(
3351 code,
3352 " // Copy last token's residual to single-token buffer for subsequent forward()"
3353 )?;
3354 writeln!(
3355 code,
3356 " self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
3357 )?;
3358 writeln!(code)?;
3359 writeln!(code, " enc.end_encoding();")?;
3360 writeln!(code, " }}")?;
3361 writeln!(code)?;
3362
3363 writeln!(code, " cmd.commit();")?;
3364 writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
3365 writeln!(code, " self.pos += m;")?;
3366 writeln!(code, " }}")?;
3367 writeln!(code)?;
3368
3369 writeln!(
3371 code,
3372 " /// Reset the model state for a new inference request."
3373 )?;
3374 writeln!(code, " pub fn reset(&mut self) {{")?;
3375 writeln!(code, " self.pos = 0;")?;
3376 writeln!(code, " self.prev_cmd = None;")?;
3377 writeln!(code, " }}")?;
3378 writeln!(code)?;
3379
3380 writeln!(
3382 code,
3383 " // ── Dispatch helpers (append to a shared compute command encoder) ──"
3384 )?;
3385 writeln!(
3386 code,
3387 " // These methods set pipeline state + buffers + dispatch on an existing"
3388 )?;
3389 writeln!(
3390 code,
3391 " // encoder, avoiding per-operation encoder creation overhead."
3392 )?;
3393 writeln!(code)?;
3394
3395 writeln!(
3397 code,
3398 " /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
3399 )?;
3400 writeln!(
3401 code,
3402 " fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
3403 )?;
3404 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
3405 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
3406 writeln!(
3407 code,
3408 " enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
3409 )?;
3410 writeln!(
3411 code,
3412 " enc.set_buffer(0, Some(&self.residual_buf), 0);"
3413 )?;
3414 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
3415 writeln!(
3416 code,
3417 " enc.set_buffer(2, Some(&self.hidden_buf), 0);"
3418 )?;
3419 writeln!(
3420 code,
3421 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3422 )?;
3423 writeln!(
3424 code,
3425 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
3426 )?;
3427 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3428 writeln!(
3429 code,
3430 " let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
3431 )?;
3432 writeln!(
3433 code,
3434 " enc.dispatch_thread_groups(grid_size, tg_size);"
3435 )?;
3436 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3437 writeln!(code, " }}")?;
3438 writeln!(code)?;
3439
3440 writeln!(
3442 code,
3443 " /// Dispatch matrix-vector multiply: weight * input -> output."
3444 )?;
3445 writeln!(
3446 code,
3447 " fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
3448 )?;
3449 writeln!(code, " let r: u32 = rows as u32;")?;
3450 writeln!(code, " let c: u32 = cols as u32;")?;
3451 writeln!(
3452 code,
3453 " enc.set_compute_pipeline_state(&self.matmul_pipeline);"
3454 )?;
3455 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
3456 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
3457 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
3458 writeln!(
3459 code,
3460 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
3461 )?;
3462 writeln!(
3463 code,
3464 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
3465 )?;
3466 writeln!(
3467 code,
3468 " // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
3469 )?;
3470 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3471 writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
3472 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
3473 writeln!(
3474 code,
3475 " enc.dispatch_thread_groups(grid_size, tg_size);"
3476 )?;
3477 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3478 writeln!(code, " }}")?;
3479 writeln!(code)?;
3480
3481 writeln!(
3483 code,
3484 " /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
3485 )?;
3486 writeln!(
3487 code,
3488 " /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
3489 )?;
3490 writeln!(
3491 code,
3492 " fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
3493 )?;
3494 writeln!(code, " let r: u32 = rows as u32;")?;
3495 writeln!(code, " let c: u32 = cols as u32;")?;
3496 writeln!(
3497 code,
3498 " enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
3499 )?;
3500 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
3501 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
3502 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
3503 writeln!(
3504 code,
3505 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
3506 )?;
3507 writeln!(
3508 code,
3509 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
3510 )?;
3511 writeln!(
3512 code,
3513 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
3514 )?;
3515 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3516 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
3517 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
3518 writeln!(
3519 code,
3520 " enc.dispatch_thread_groups(grid_size, tg_size);"
3521 )?;
3522 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3523 writeln!(code, " }}")?;
3524 writeln!(code)?;
3525
3526 writeln!(
3528 code,
3529 " /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
3530 )?;
3531 writeln!(
3532 code,
3533 " /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
3534 )?;
3535 writeln!(
3536 code,
3537 " fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
3538 )?;
3539 writeln!(code, " let r: u32 = rows as u32;")?;
3540 writeln!(code, " let c: u32 = cols as u32;")?;
3541 writeln!(
3542 code,
3543 " enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
3544 )?;
3545 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
3546 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
3547 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
3548 writeln!(
3549 code,
3550 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
3551 )?;
3552 writeln!(
3553 code,
3554 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
3555 )?;
3556 writeln!(
3557 code,
3558 " // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
3559 )?;
3560 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3561 writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
3562 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
3563 writeln!(
3564 code,
3565 " enc.dispatch_thread_groups(grid_size, tg_size);"
3566 )?;
3567 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3568 writeln!(code, " }}")?;
3569 writeln!(code)?;
3570
3571 writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
3573 writeln!(
3574 code,
3575 " fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
3576 )?;
3577 writeln!(code, " let nh: u32 = num_heads as u32;")?;
3578 writeln!(code, " let hd: u32 = head_dim as u32;")?;
3579 writeln!(code, " let p: u32 = pos as u32;")?;
3580 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
3581 writeln!(
3582 code,
3583 " let total_pairs = num_heads * (head_dim / 2);"
3584 )?;
3585 writeln!(
3586 code,
3587 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
3588 )?;
3589 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
3590 writeln!(
3591 code,
3592 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3593 )?;
3594 writeln!(
3595 code,
3596 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3597 )?;
3598 writeln!(
3599 code,
3600 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
3601 )?;
3602 writeln!(
3603 code,
3604 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
3605 )?;
3606 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3607 writeln!(
3608 code,
3609 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
3610 )?;
3611 writeln!(
3612 code,
3613 " enc.dispatch_thread_groups(grid_size, tg_size);"
3614 )?;
3615 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3616 writeln!(code, " }}")?;
3617 writeln!(code)?;
3618
3619 writeln!(
3621 code,
3622 " /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
3623 )?;
3624 writeln!(
3625 code,
3626 " fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
3627 )?;
3628 writeln!(code, " let nh: u32 = num_heads as u32;")?;
3629 writeln!(code, " let hd: u32 = head_dim as u32;")?;
3630 writeln!(code, " let p: u32 = pos as u32;")?;
3631 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
3632 writeln!(
3633 code,
3634 " let total_pairs = num_heads * (head_dim / 2);"
3635 )?;
3636 writeln!(
3637 code,
3638 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
3639 )?;
3640 writeln!(
3641 code,
3642 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
3643 )?;
3644 writeln!(
3645 code,
3646 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3647 )?;
3648 writeln!(
3649 code,
3650 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3651 )?;
3652 writeln!(
3653 code,
3654 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
3655 )?;
3656 writeln!(
3657 code,
3658 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
3659 )?;
3660 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3661 writeln!(
3662 code,
3663 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
3664 )?;
3665 writeln!(
3666 code,
3667 " enc.dispatch_thread_groups(grid_size, tg_size);"
3668 )?;
3669 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3670 writeln!(code, " }}")?;
3671 writeln!(code)?;
3672
3673 writeln!(
3675 code,
3676 " /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
3677 )?;
3678 writeln!(
3679 code,
3680 " fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
3681 )?;
3682 writeln!(code, " let sl: u32 = seq_len as u32;")?;
3683 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
3684 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
3685 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
3686 writeln!(
3687 code,
3688 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
3689 )?;
3690 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
3691 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
3692 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
3693 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
3694 writeln!(
3695 code,
3696 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
3697 )?;
3698 writeln!(
3699 code,
3700 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3701 )?;
3702 writeln!(
3703 code,
3704 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
3705 )?;
3706 writeln!(
3707 code,
3708 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3709 )?;
3710 writeln!(
3711 code,
3712 " // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
3713 )?;
3714 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3715 writeln!(
3716 code,
3717 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
3718 )?;
3719 writeln!(
3720 code,
3721 " enc.dispatch_thread_groups(grid_size, tg_size);"
3722 )?;
3723 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3724 writeln!(code, " }}")?;
3725 writeln!(code)?;
3726
3727 writeln!(
3729 code,
3730 " /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
3731 )?;
3732 writeln!(
3733 code,
3734 " fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
3735 )?;
3736 writeln!(code, " let sl: u32 = seq_len as u32;")?;
3737 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
3738 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
3739 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
3740 writeln!(
3741 code,
3742 " enc.set_compute_pipeline_state(&self.attention_pipeline);"
3743 )?;
3744 writeln!(
3745 code,
3746 " enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
3747 )?;
3748 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
3749 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
3750 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
3751 writeln!(
3752 code,
3753 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
3754 )?;
3755 writeln!(
3756 code,
3757 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
3758 )?;
3759 writeln!(
3760 code,
3761 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
3762 )?;
3763 writeln!(
3764 code,
3765 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
3766 )?;
3767 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3768 writeln!(
3769 code,
3770 " let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
3771 )?;
3772 writeln!(
3773 code,
3774 " enc.dispatch_thread_groups(grid_size, tg_size);"
3775 )?;
3776 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3777 writeln!(code, " }}")?;
3778 writeln!(code)?;
3779
3780 writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
3782 writeln!(
3783 code,
3784 " fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
3785 )?;
3786 writeln!(code, " let count: u32 = n as u32;")?;
3787 writeln!(
3788 code,
3789 " enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
3790 )?;
3791 writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
3792 writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
3793 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
3794 writeln!(
3795 code,
3796 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
3797 )?;
3798 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3799 writeln!(
3800 code,
3801 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
3802 )?;
3803 writeln!(
3804 code,
3805 " enc.dispatch_thread_groups(grid_size, tg_size);"
3806 )?;
3807 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3808 writeln!(code, " }}")?;
3809 writeln!(code)?;
3810
3811 writeln!(
3813 code,
3814 " /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
3815 )?;
3816 writeln!(
3817 code,
3818 " /// gate_up_buf contains [gate(n), up(n)] contiguously."
3819 )?;
3820 writeln!(
3821 code,
3822 " fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
3823 )?;
3824 writeln!(code, " let count: u32 = n as u32;")?;
3825 writeln!(
3826 code,
3827 " enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
3828 )?;
3829 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
3830 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
3831 writeln!(
3832 code,
3833 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
3834 )?;
3835 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3836 writeln!(
3837 code,
3838 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
3839 )?;
3840 writeln!(
3841 code,
3842 " enc.dispatch_thread_groups(grid_size, tg_size);"
3843 )?;
3844 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3845 writeln!(code, " }}")?;
3846 writeln!(code)?;
3847
3848 writeln!(
3850 code,
3851 " /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
3852 )?;
3853 writeln!(
3854 code,
3855 " fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
3856 )?;
3857 writeln!(code, " let n: u32 = count as u32;")?;
3858 writeln!(
3859 code,
3860 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
3861 )?;
3862 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
3863 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
3864 writeln!(
3865 code,
3866 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3867 )?;
3868 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3869 writeln!(
3870 code,
3871 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3872 )?;
3873 writeln!(
3874 code,
3875 " enc.dispatch_thread_groups(grid_size, tg_size);"
3876 )?;
3877 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3878 writeln!(code, " }}")?;
3879 writeln!(code)?;
3880
3881 writeln!(
3883 code,
3884 " /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
3885 )?;
3886 writeln!(
3887 code,
3888 " /// Used for embedding table lookup (copy a specific row)."
3889 )?;
3890 writeln!(
3891 code,
3892 " fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
3893 )?;
3894 writeln!(code, " let off: u32 = src_offset as u32;")?;
3895 writeln!(code, " let n: u32 = count as u32;")?;
3896 writeln!(
3897 code,
3898 " enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
3899 )?;
3900 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
3901 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
3902 writeln!(
3903 code,
3904 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
3905 )?;
3906 writeln!(
3907 code,
3908 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3909 )?;
3910 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3911 writeln!(
3912 code,
3913 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3914 )?;
3915 writeln!(
3916 code,
3917 " enc.dispatch_thread_groups(grid_size, tg_size);"
3918 )?;
3919 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3920 writeln!(code, " }}")?;
3921 writeln!(code)?;
3922
3923 writeln!(
3925 code,
3926 " /// Dispatch copy from source at byte offset to destination at float offset."
3927 )?;
3928 writeln!(
3929 code,
3930 " /// Used for KV cache updates from fused QKV buffer."
3931 )?;
3932 writeln!(
3933 code,
3934 " fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
3935 )?;
3936 writeln!(code, " let n: u32 = count as u32;")?;
3937 writeln!(
3938 code,
3939 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
3940 )?;
3941 writeln!(
3942 code,
3943 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
3944 )?;
3945 writeln!(
3946 code,
3947 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
3948 )?;
3949 writeln!(
3950 code,
3951 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3952 )?;
3953 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3954 writeln!(
3955 code,
3956 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3957 )?;
3958 writeln!(
3959 code,
3960 " enc.dispatch_thread_groups(grid_size, tg_size);"
3961 )?;
3962 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
3963 writeln!(code, " }}")?;
3964 writeln!(code)?;
3965
3966 writeln!(
3968 code,
3969 " /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
3970 )?;
3971 writeln!(
3972 code,
3973 " /// Used for KV cache updates (write to a specific position in the cache)."
3974 )?;
3975 writeln!(
3976 code,
3977 " fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
3978 )?;
3979 writeln!(code, " let n: u32 = count as u32;")?;
3980 writeln!(
3981 code,
3982 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
3983 )?;
3984 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
3985 writeln!(
3986 code,
3987 " enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
3988 )?;
3989 writeln!(
3990 code,
3991 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
3992 )?;
3993 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
3994 writeln!(
3995 code,
3996 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
3997 )?;
3998 writeln!(
3999 code,
4000 " enc.dispatch_thread_groups(grid_size, tg_size);"
4001 )?;
4002 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4003 writeln!(code, " }}")?;
4004 writeln!(code)?;
4005
4006 writeln!(
4008 code,
4009 " /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
4010 )?;
4011 writeln!(
4012 code,
4013 " fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
4014 )?;
4015 writeln!(code, " let count: u32 = n as u32;")?;
4016 writeln!(
4017 code,
4018 " enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
4019 )?;
4020 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
4021 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
4022 writeln!(
4023 code,
4024 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4025 )?;
4026 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4027 writeln!(
4028 code,
4029 " let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
4030 )?;
4031 writeln!(
4032 code,
4033 " enc.dispatch_thread_groups(grid_size, tg_size);"
4034 )?;
4035 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4036 writeln!(code, " }}")?;
4037 writeln!(code)?;
4038
4039 writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
4041 writeln!(code)?;
4042
4043 writeln!(
4045 code,
4046 " /// Dispatch batched embedding lookup: copy M token embeddings at once."
4047 )?;
4048 writeln!(
4049 code,
4050 " fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
4051 )?;
4052 writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
4053 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
4054 writeln!(
4055 code,
4056 " enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
4057 )?;
4058 writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
4059 writeln!(
4060 code,
4061 " enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
4062 )?;
4063 writeln!(
4064 code,
4065 " enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
4066 )?;
4067 writeln!(
4068 code,
4069 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
4070 )?;
4071 writeln!(
4072 code,
4073 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4074 )?;
4075 writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
4076 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4077 writeln!(
4078 code,
4079 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4080 )?;
4081 writeln!(
4082 code,
4083 " enc.dispatch_thread_groups(grid_size, tg_size);"
4084 )?;
4085 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4086 writeln!(code, " }}")?;
4087 writeln!(code)?;
4088
4089 writeln!(
4091 code,
4092 " /// Dispatch batched RMS norm: normalizes M vectors at once."
4093 )?;
4094 writeln!(
4095 code,
4096 " fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
4097 )?;
4098 writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
4099 writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
4100 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
4101 writeln!(
4102 code,
4103 " enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
4104 )?;
4105 writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
4106 writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
4107 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4108 writeln!(
4109 code,
4110 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4111 )?;
4112 writeln!(
4113 code,
4114 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
4115 )?;
4116 writeln!(
4117 code,
4118 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4119 )?;
4120 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4121 writeln!(
4122 code,
4123 " let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
4124 )?;
4125 writeln!(
4126 code,
4127 " enc.dispatch_thread_groups(grid_size, tg_size);"
4128 )?;
4129 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4130 writeln!(code, " }}")?;
4131 writeln!(code)?;
4132
4133 writeln!(
4135 code,
4136 " /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4137 )?;
4138 writeln!(
4139 code,
4140 " fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4141 )?;
4142 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
4143 writeln!(code, " let r: u32 = rows as u32;")?;
4144 writeln!(code, " let c: u32 = cols as u32;")?;
4145 writeln!(
4146 code,
4147 " enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
4148 )?;
4149 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4150 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4151 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4152 writeln!(
4153 code,
4154 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4155 )?;
4156 writeln!(
4157 code,
4158 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4159 )?;
4160 writeln!(
4161 code,
4162 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4163 )?;
4164 writeln!(
4165 code,
4166 " let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
4167 )?;
4168 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
4169 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4170 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4171 writeln!(
4172 code,
4173 " enc.dispatch_thread_groups(grid_size, tg_size);"
4174 )?;
4175 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4176 writeln!(code, " }}")?;
4177 writeln!(code)?;
4178
4179 writeln!(
4181 code,
4182 " /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4183 )?;
4184 writeln!(code, " ///")?;
4185 writeln!(
4186 code,
4187 " /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
4188 )?;
4189 writeln!(
4190 code,
4191 " /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
4192 )?;
4193 writeln!(
4194 code,
4195 " fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4196 )?;
4197 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
4198 writeln!(code, " let r: u32 = rows as u32;")?;
4199 writeln!(code, " let c: u32 = cols as u32;")?;
4200 writeln!(
4201 code,
4202 " // Tile size must match TOKENS_PER_TG_Q8 in shaders."
4203 )?;
4204 writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
4205 writeln!(code, " if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
4206 writeln!(
4207 code,
4208 " enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
4209 )?;
4210 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4211 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4212 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4213 writeln!(
4214 code,
4215 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4216 )?;
4217 writeln!(
4218 code,
4219 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4220 )?;
4221 writeln!(
4222 code,
4223 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4224 )?;
4225 writeln!(
4226 code,
4227 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
4228 )?;
4229 writeln!(
4230 code,
4231 " let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
4232 )?;
4233 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4234 writeln!(
4235 code,
4236 " let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
4237 )?;
4238 writeln!(
4239 code,
4240 " enc.dispatch_thread_groups(grid_size, tg_size);"
4241 )?;
4242 writeln!(code, " }} else {{")?;
4243 writeln!(
4244 code,
4245 " enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
4246 )?;
4247 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4248 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4249 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4250 writeln!(
4251 code,
4252 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4253 )?;
4254 writeln!(
4255 code,
4256 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4257 )?;
4258 writeln!(
4259 code,
4260 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4261 )?;
4262 writeln!(
4263 code,
4264 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
4265 )?;
4266 writeln!(
4267 code,
4268 " let num_tg = (row_tgs * num_tokens) as u64;"
4269 )?;
4270 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4271 writeln!(
4272 code,
4273 " let grid_size = MTLSize::new(num_tg, 1, 1);"
4274 )?;
4275 writeln!(
4276 code,
4277 " enc.dispatch_thread_groups(grid_size, tg_size);"
4278 )?;
4279 writeln!(code, " }}")?;
4280 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4281 writeln!(code, " }}")?;
4282 writeln!(code)?;
4283
4284 writeln!(
4286 code,
4287 " /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
4288 )?;
4289 writeln!(
4290 code,
4291 " fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
4292 )?;
4293 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
4294 writeln!(code, " let r: u32 = rows as u32;")?;
4295 writeln!(code, " let c: u32 = cols as u32;")?;
4296 writeln!(
4297 code,
4298 " enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
4299 )?;
4300 writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
4301 writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
4302 writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
4303 writeln!(
4304 code,
4305 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4306 )?;
4307 writeln!(
4308 code,
4309 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
4310 )?;
4311 writeln!(
4312 code,
4313 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
4314 )?;
4315 writeln!(
4316 code,
4317 " let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
4318 )?;
4319 writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
4320 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4321 writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
4322 writeln!(
4323 code,
4324 " enc.dispatch_thread_groups(grid_size, tg_size);"
4325 )?;
4326 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4327 writeln!(code, " }}")?;
4328 writeln!(code)?;
4329
4330 writeln!(
4332 code,
4333 " /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
4334 )?;
4335 writeln!(
4336 code,
4337 " /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
4338 )?;
4339 writeln!(
4340 code,
4341 " /// `row_stride` is the number of floats per token row in the batch buffer."
4342 )?;
4343 writeln!(
4344 code,
4345 " fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
4346 )?;
4347 writeln!(code, " let nh: u32 = num_heads as u32;")?;
4348 writeln!(code, " let hd: u32 = head_dim as u32;")?;
4349 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4350 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
4351 writeln!(
4352 code,
4353 " let pairs_per_token = num_heads * (head_dim / 2);"
4354 )?;
4355 writeln!(
4356 code,
4357 " let total_pairs = num_tokens * pairs_per_token;"
4358 )?;
4359 writeln!(
4372 code,
4373 " // Apply RoPE to each token individually (different positions, non-contiguous layout)"
4374 )?;
4375 writeln!(code, " for t in 0..num_tokens {{")?;
4376 writeln!(
4377 code,
4378 " let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
4379 )?;
4380 writeln!(
4381 code,
4382 " let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
4383 )?;
4384 writeln!(
4385 code,
4386 " enc.set_compute_pipeline_state(&self.rope_pipeline);"
4387 )?;
4388 writeln!(
4389 code,
4390 " enc.set_buffer(0, Some(buf), byte_offset as u64);"
4391 )?;
4392 writeln!(
4393 code,
4394 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4395 )?;
4396 writeln!(
4397 code,
4398 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4399 )?;
4400 writeln!(
4401 code,
4402 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
4403 )?;
4404 writeln!(
4405 code,
4406 " enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4407 )?;
4408 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4409 writeln!(
4410 code,
4411 " let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
4412 )?;
4413 writeln!(
4414 code,
4415 " enc.dispatch_thread_groups(grid_size, tg_size);"
4416 )?;
4417 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4418 writeln!(code, " }}")?;
4419 writeln!(code, " }}")?;
4420 writeln!(code)?;
4421
4422 writeln!(
4424 code,
4425 " /// Dispatch batched fused SiLU-multiply for M tokens."
4426 )?;
4427 writeln!(
4428 code,
4429 " fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
4430 )?;
4431 writeln!(code, " let count: u32 = n as u32;")?;
4432 writeln!(code, " let nt: u32 = num_tokens as u32;")?;
4433 writeln!(
4434 code,
4435 " enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
4436 )?;
4437 writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
4438 writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
4439 writeln!(
4440 code,
4441 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4442 )?;
4443 writeln!(
4444 code,
4445 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
4446 )?;
4447 writeln!(code, " let total = num_tokens * n;")?;
4448 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4449 writeln!(
4450 code,
4451 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4452 )?;
4453 writeln!(
4454 code,
4455 " enc.dispatch_thread_groups(grid_size, tg_size);"
4456 )?;
4457 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4458 writeln!(code, " }}")?;
4459 writeln!(code)?;
4460
4461 writeln!(
4463 code,
4464 " /// Dispatch in-place add for total_n elements: a[i] += b[i]."
4465 )?;
4466 writeln!(
4467 code,
4468 " fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
4469 )?;
4470 writeln!(code, " let count: u32 = total_n as u32;")?;
4471 writeln!(
4472 code,
4473 " enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
4474 )?;
4475 writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
4476 writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
4477 writeln!(
4478 code,
4479 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
4480 )?;
4481 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4482 writeln!(
4483 code,
4484 " let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
4485 )?;
4486 writeln!(
4487 code,
4488 " enc.dispatch_thread_groups(grid_size, tg_size);"
4489 )?;
4490 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4491 writeln!(code, " }}")?;
4492 writeln!(code)?;
4493
4494 writeln!(
4496 code,
4497 " /// Copy src to dst using compute copy kernel (for batch residual init)."
4498 )?;
4499 writeln!(
4500 code,
4501 " fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
4502 )?;
4503 writeln!(code, " let n: u32 = count as u32;")?;
4504 writeln!(
4505 code,
4506 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
4507 )?;
4508 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
4509 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
4510 writeln!(
4511 code,
4512 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4513 )?;
4514 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4515 writeln!(
4516 code,
4517 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4518 )?;
4519 writeln!(
4520 code,
4521 " enc.dispatch_thread_groups(grid_size, tg_size);"
4522 )?;
4523 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4524 writeln!(code, " }}")?;
4525 writeln!(code)?;
4526
4527 writeln!(
4529 code,
4530 " /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
4531 )?;
4532 writeln!(
4533 code,
4534 " fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4535 )?;
4536 writeln!(code, " let n: u32 = count as u32;")?;
4537 writeln!(
4538 code,
4539 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
4540 )?;
4541 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
4542 writeln!(
4543 code,
4544 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4545 )?;
4546 writeln!(
4547 code,
4548 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4549 )?;
4550 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4551 writeln!(
4552 code,
4553 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4554 )?;
4555 writeln!(
4556 code,
4557 " enc.dispatch_thread_groups(grid_size, tg_size);"
4558 )?;
4559 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4560 writeln!(code, " }}")?;
4561 writeln!(code)?;
4562
4563 writeln!(
4565 code,
4566 " /// Copy from src at byte offset to dst at float offset."
4567 )?;
4568 writeln!(
4569 code,
4570 " fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
4571 )?;
4572 writeln!(code, " let n: u32 = count as u32;")?;
4573 writeln!(
4574 code,
4575 " enc.set_compute_pipeline_state(&self.copy_pipeline);"
4576 )?;
4577 writeln!(
4578 code,
4579 " enc.set_buffer(0, Some(src), src_byte_offset as u64);"
4580 )?;
4581 writeln!(
4582 code,
4583 " enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
4584 )?;
4585 writeln!(
4586 code,
4587 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
4588 )?;
4589 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4590 writeln!(
4591 code,
4592 " let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
4593 )?;
4594 writeln!(
4595 code,
4596 " enc.dispatch_thread_groups(grid_size, tg_size);"
4597 )?;
4598 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4599 writeln!(code, " }}")?;
4600 writeln!(code)?;
4601
4602 writeln!(
4604 code,
4605 " /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
4606 )?;
4607 writeln!(
4608 code,
4609 " fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
4610 )?;
4611 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
4612 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
4613 writeln!(code, " let bp: u32 = base_pos as u32;")?;
4614 writeln!(code, " let ss: u32 = src_stride as u32;")?;
4615 writeln!(code, " let so: u32 = src_offset as u32;")?;
4616 writeln!(
4617 code,
4618 " enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
4619 )?;
4620 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
4621 writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
4622 writeln!(
4623 code,
4624 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4625 )?;
4626 writeln!(
4627 code,
4628 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
4629 )?;
4630 writeln!(
4631 code,
4632 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4633 )?;
4634 writeln!(
4635 code,
4636 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
4637 )?;
4638 writeln!(
4639 code,
4640 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
4641 )?;
4642 writeln!(code, " let total = num_tokens * kv_dim;")?;
4643 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4644 writeln!(
4645 code,
4646 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4647 )?;
4648 writeln!(
4649 code,
4650 " enc.dispatch_thread_groups(grid_size, tg_size);"
4651 )?;
4652 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4653 writeln!(code, " }}")?;
4654 writeln!(code)?;
4655
4656 writeln!(
4658 code,
4659 " /// Dispatch batched causal attention: one dispatch for all M tokens."
4660 )?;
4661 writeln!(
4662 code,
4663 " fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
4664 )?;
4665 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
4666 writeln!(code, " let bp: u32 = base_pos as u32;")?;
4667 writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
4668 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4669 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4670 writeln!(code, " let qs: u32 = q_stride as u32;")?;
4671 writeln!(
4672 code,
4673 " enc.set_compute_pipeline_state(&self.attention_batch_pipeline);"
4674 )?;
4675 writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
4676 writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
4677 writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
4678 writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
4679 writeln!(
4680 code,
4681 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4682 )?;
4683 writeln!(
4684 code,
4685 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4686 )?;
4687 writeln!(
4688 code,
4689 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
4690 )?;
4691 writeln!(
4692 code,
4693 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4694 )?;
4695 writeln!(
4696 code,
4697 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4698 )?;
4699 writeln!(
4700 code,
4701 " enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
4702 )?;
4703 writeln!(
4704 code,
4705 " // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
4706 )?;
4707 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4708 writeln!(
4709 code,
4710 " let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
4711 )?;
4712 writeln!(
4713 code,
4714 " enc.dispatch_thread_groups(grid_size, tg_size);"
4715 )?;
4716 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4717 writeln!(code, " }}")?;
4718 writeln!(code)?;
4719
4720 writeln!(
4722 code,
4723 " /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
4724 )?;
4725 writeln!(
4726 code,
4727 " /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
4728 )?;
4729 writeln!(
4730 code,
4731 " fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
4732 )?;
4733 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
4734 writeln!(code, " let bp: u32 = base_pos as u32;")?;
4735 writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
4736 writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
4737 writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
4738 writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
4739 writeln!(code, " let theta: f32 = ROPE_THETA;")?;
4740 writeln!(
4741 code,
4742 " enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
4743 )?;
4744 writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
4745 writeln!(
4746 code,
4747 " enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4748 )?;
4749 writeln!(
4750 code,
4751 " enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4752 )?;
4753 writeln!(
4754 code,
4755 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
4756 )?;
4757 writeln!(
4758 code,
4759 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
4760 )?;
4761 writeln!(
4762 code,
4763 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
4764 )?;
4765 writeln!(
4766 code,
4767 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
4768 )?;
4769 writeln!(
4770 code,
4771 " enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
4772 )?;
4773 writeln!(
4774 code,
4775 " let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
4776 )?;
4777 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4778 writeln!(
4779 code,
4780 " let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
4781 )?;
4782 writeln!(
4783 code,
4784 " enc.dispatch_thread_groups(grid_size, tg_size);"
4785 )?;
4786 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4787 writeln!(code, " }}")?;
4788 writeln!(code)?;
4789
4790 writeln!(
4792 code,
4793 " /// Dispatch fused K+V cache copy in one kernel launch."
4794 )?;
4795 writeln!(
4796 code,
4797 " /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
4798 )?;
4799 writeln!(
4800 code,
4801 " fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
4802 )?;
4803 writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
4804 writeln!(code, " let kv: u32 = kv_dim as u32;")?;
4805 writeln!(code, " let bp: u32 = base_pos as u32;")?;
4806 writeln!(code, " let ss: u32 = src_stride as u32;")?;
4807 writeln!(code, " let ko: u32 = k_offset as u32;")?;
4808 writeln!(code, " let vo: u32 = v_offset as u32;")?;
4809 writeln!(
4810 code,
4811 " enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
4812 )?;
4813 writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
4814 writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
4815 writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
4816 writeln!(
4817 code,
4818 " enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
4819 )?;
4820 writeln!(
4821 code,
4822 " enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
4823 )?;
4824 writeln!(
4825 code,
4826 " enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
4827 )?;
4828 writeln!(
4829 code,
4830 " enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
4831 )?;
4832 writeln!(
4833 code,
4834 " enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
4835 )?;
4836 writeln!(
4837 code,
4838 " enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
4839 )?;
4840 writeln!(
4841 code,
4842 " let total = num_tokens * kv_dim * 2; // K + V"
4843 )?;
4844 writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
4845 writeln!(
4846 code,
4847 " let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
4848 )?;
4849 writeln!(
4850 code,
4851 " enc.dispatch_thread_groups(grid_size, tg_size);"
4852 )?;
4853 writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
4854 writeln!(code, " }}")?;
4855
4856 writeln!(code, "}}")?;
4857 writeln!(code)?;
4858
4859 Ok(())
4860}
4861
4862fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
4863 writeln!(
4864 code,
4865 "// ── Helper functions ──────────────────────────────────"
4866 )?;
4867 writeln!(code)?;
4868 writeln!(
4869 code,
4870 "/// Create a compute pipeline from a named function in the Metal library."
4871 )?;
4872 writeln!(
4873 code,
4874 "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {{"
4875 )?;
4876 writeln!(
4877 code,
4878 " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{{name}}' not found\"));"
4879 )?;
4880 writeln!(
4881 code,
4882 " device.new_compute_pipeline_state_with_function(&func)"
4883 )?;
4884 writeln!(
4885 code,
4886 " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{{name}}'\"))"
4887 )?;
4888 writeln!(code, "}}")?;
4889 writeln!(code)?;
4890
4891 Ok(())
4892}
4893
4894fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
4899 let _sanitized = sanitize_name(model_name);
4900 let _vocab = config.vocab_size;
4901
4902 let mut code = String::with_capacity(16 * 1024);
4903 writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
4904 writeln!(
4905 code,
4906 "//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
4907 )?;
4908 writeln!(code)?;
4909 writeln!(code, "mod model;")?;
4910 writeln!(code)?;
4911 writeln!(code, "use std::io::Write;")?;
4912 writeln!(code, "use std::time::Instant;")?;
4913 writeln!(code, "use serde::Deserialize;")?;
4914 writeln!(code)?;
4915
4916 writeln!(code, "fn main() {{")?;
4918 writeln!(
4919 code,
4920 " let args: Vec<String> = std::env::args().collect();"
4921 )?;
4922 writeln!(code)?;
4923 writeln!(
4924 code,
4925 " // Detect --serve mode (only requires weights + tokenizer)"
4926 )?;
4927 writeln!(
4928 code,
4929 " let serve_mode = args.iter().any(|a| a == \"--serve\");"
4930 )?;
4931 writeln!(code)?;
4932 writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
4933 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
4934 writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
4935 writeln!(code, " std::process::exit(1);")?;
4936 writeln!(code, " }}")?;
4937 writeln!(code)?;
4938 writeln!(code, " if serve_mode && args.len() < 3 {{")?;
4939 writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
4940 writeln!(code, " std::process::exit(1);")?;
4941 writeln!(code, " }}")?;
4942 writeln!(code)?;
4943 writeln!(code, " let weights_path = &args[1];")?;
4944 writeln!(code, " let tokenizer_path = &args[2];")?;
4945 writeln!(code)?;
4946 writeln!(code, " // Parse optional flags")?;
4947 writeln!(code, " let mut max_tokens: usize = 128;")?;
4948 writeln!(code, " let mut port: u16 = 8080;")?;
4949 writeln!(
4950 code,
4951 " let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
4952 )?;
4953 writeln!(
4954 code,
4955 " let profile = args.iter().any(|a| a == \"--profile\");"
4956 )?;
4957 writeln!(code, " let mut i = 3;")?;
4958 writeln!(code, " while i < args.len() {{")?;
4959 writeln!(
4960 code,
4961 " if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
4962 )?;
4963 writeln!(
4964 code,
4965 " max_tokens = args[i + 1].parse().unwrap_or(128);"
4966 )?;
4967 writeln!(code, " i += 2;")?;
4968 writeln!(
4969 code,
4970 " }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
4971 )?;
4972 writeln!(
4973 code,
4974 " port = args[i + 1].parse().unwrap_or(8080);"
4975 )?;
4976 writeln!(code, " i += 2;")?;
4977 writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
4978 writeln!(code, " i += 1;")?;
4979 writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
4980 writeln!(code, " i += 1;")?;
4981 writeln!(code, " }} else {{")?;
4982 writeln!(code, " i += 1;")?;
4983 writeln!(code, " }}")?;
4984 writeln!(code, " }}")?;
4985 writeln!(code)?;
4986
4987 writeln!(
4989 code,
4990 " // Memory-map weights for zero-copy loading on Apple Silicon"
4991 )?;
4992 writeln!(
4993 code,
4994 " let weights_file = std::fs::File::open(weights_path)"
4995 )?;
4996 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
4997 writeln!(
4998 code,
4999 " let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
5000 )?;
5001 writeln!(code)?;
5002 writeln!(code, " // Load tokenizer")?;
5003 writeln!(
5004 code,
5005 " let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
5006 )?;
5007 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
5008 writeln!(code)?;
5009 writeln!(code, " // Create Metal model")?;
5010 writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
5011 writeln!(
5012 code,
5013 " let mut model = model::MetalModel::new(&weights_mmap);"
5014 )?;
5015 writeln!(code)?;
5016
5017 writeln!(code, " if serve_mode {{")?;
5019 writeln!(code, " serve(model, tokenizer, port);")?;
5020 writeln!(code, " }} else {{")?;
5021 writeln!(code, " let prompt = &args[3];")?;
5022 writeln!(
5023 code,
5024 " cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
5025 )?;
5026 writeln!(code, " }}")?;
5027 writeln!(code, "}}")?;
5028 writeln!(code)?;
5029
5030 writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
5032 writeln!(code, " // Tokenize prompt")?;
5033 writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
5034 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
5035 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
5036 writeln!(code)?;
5037 writeln!(
5038 code,
5039 " // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
5040 )?;
5041 writeln!(
5042 code,
5043 " // Uses double-buffered batch dispatch for GPU-efficient matmul."
5044 )?;
5045 writeln!(
5046 code,
5047 " // The last token uses synchronous forward() to get logits."
5048 )?;
5049 writeln!(code, " let prompt_len = prompt_tokens.len();")?;
5050 writeln!(code, " let prefill_start = Instant::now();")?;
5051 writeln!(code, " let logits = if prompt_len > 1 {{")?;
5052 writeln!(
5053 code,
5054 " model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
5055 )?;
5056 writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
5057 writeln!(code, " }} else {{")?;
5058 writeln!(code, " model.forward(prompt_tokens[0])")?;
5059 writeln!(code, " }};")?;
5060 writeln!(
5061 code,
5062 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
5063 )?;
5064 writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
5065 writeln!(
5066 code,
5067 " eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
5068 )?;
5069 writeln!(
5070 code,
5071 " prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
5072 )?;
5073 writeln!(code)?;
5074 writeln!(code, " // Generate tokens")?;
5075 writeln!(code, " let mut next_token = argmax(&logits);")?;
5076 writeln!(code, " let gen_start = Instant::now();")?;
5077 writeln!(code, " let mut generated_count: usize = 0;")?;
5078 writeln!(code)?;
5079 writeln!(code, " for _ in 0..max_tokens {{")?;
5080 writeln!(
5081 code,
5082 " if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
5083 )?;
5084 writeln!(code, " if !quiet {{")?;
5085 writeln!(code, " print!(\"{{}}\", text);")?;
5086 writeln!(code, " std::io::stdout().flush().ok();")?;
5087 writeln!(code, " }}")?;
5088 writeln!(code, " }}")?;
5089 writeln!(code, " generated_count += 1;")?;
5090 writeln!(code)?;
5091 writeln!(
5092 code,
5093 " // Use profiling forward for first token when --profile is set"
5094 )?;
5095 writeln!(
5096 code,
5097 " let logits = if profile && generated_count == 1 {{"
5098 )?;
5099 writeln!(code, " model.forward_profile(next_token)")?;
5100 writeln!(code, " }} else {{")?;
5101 writeln!(code, " model.forward(next_token)")?;
5102 writeln!(code, " }};")?;
5103 writeln!(code, " next_token = argmax(&logits);")?;
5104 writeln!(code)?;
5105 writeln!(code, " // Stop on EOS (token 2 for most models)")?;
5106 writeln!(code, " if next_token == 2 {{")?;
5107 writeln!(code, " break;")?;
5108 writeln!(code, " }}")?;
5109 writeln!(code)?;
5110 writeln!(
5111 code,
5112 " // Yield between tokens to reduce sustained GPU thermal load."
5113 )?;
5114 writeln!(
5115 code,
5116 " // On Apple Silicon, continuous GPU saturation causes thermal throttling"
5117 )?;
5118 writeln!(
5119 code,
5120 " // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
5121 )?;
5122 writeln!(
5123 code,
5124 " // briefly, providing a micro-break that helps sustain peak throughput."
5125 )?;
5126 writeln!(code, " std::thread::yield_now();")?;
5127 writeln!(code, " }}")?;
5128 writeln!(code, " if !quiet {{")?;
5129 writeln!(code, " println!();")?;
5130 writeln!(code, " }}")?;
5131 writeln!(
5132 code,
5133 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
5134 )?;
5135 writeln!(
5136 code,
5137 " eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
5138 )?;
5139 writeln!(
5140 code,
5141 " generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
5142 )?;
5143 writeln!(code, "}}")?;
5144 writeln!(code)?;
5145
5146 writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
5148 writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
5149 writeln!(code, " logits.iter()")?;
5150 writeln!(code, " .enumerate()")?;
5151 writeln!(
5152 code,
5153 " .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
5154 )?;
5155 writeln!(code, " .map(|(i, _)| i as u32)")?;
5156 writeln!(code, " .unwrap_or(0)")?;
5157 writeln!(code, "}}")?;
5158 writeln!(code)?;
5159
5160 writeln!(
5162 code,
5163 "// -----------------------------------------------------------------------"
5164 )?;
5165 writeln!(code, "// OpenAI-compatible API server")?;
5166 writeln!(
5167 code,
5168 "// -----------------------------------------------------------------------"
5169 )?;
5170 writeln!(code)?;
5171 writeln!(code, "#[derive(Deserialize)]")?;
5172 writeln!(code, "struct ChatRequest {{")?;
5173 writeln!(code, " messages: Vec<ChatMessage>,")?;
5174 writeln!(code, " #[serde(default)]")?;
5175 writeln!(code, " stream: Option<bool>,")?;
5176 writeln!(code, " #[serde(default)]")?;
5177 writeln!(code, " max_tokens: Option<usize>,")?;
5178 writeln!(code, " #[serde(default)]")?;
5179 writeln!(code, " temperature: Option<f32>,")?;
5180 writeln!(code, " #[serde(default)]")?;
5181 writeln!(code, " model: Option<String>,")?;
5182 writeln!(code, "}}")?;
5183 writeln!(code)?;
5184 writeln!(code, "#[derive(Deserialize)]")?;
5185 writeln!(code, "struct ChatMessage {{")?;
5186 writeln!(code, " role: String,")?;
5187 writeln!(code, " content: String,")?;
5188 writeln!(code, "}}")?;
5189 writeln!(code)?;
5190
5191 writeln!(
5193 code,
5194 "fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
5195 )?;
5196 writeln!(code, " let mut prompt = String::new();")?;
5197 writeln!(code, " for msg in messages {{")?;
5198 writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
5199 writeln!(code, " }}")?;
5200 writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
5201 writeln!(code, " prompt")?;
5202 writeln!(code, "}}")?;
5203 writeln!(code)?;
5204
5205 writeln!(
5207 code,
5208 "fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
5209 )?;
5210 writeln!(code, " let len = tokens.len();")?;
5211 writeln!(code, " if len > 1 {{")?;
5212 writeln!(
5213 code,
5214 " model.forward_prefill_batch(&tokens[..len - 1]);"
5215 )?;
5216 writeln!(code, " }}")?;
5217 writeln!(code, " model.forward(tokens[len - 1])")?;
5218 writeln!(code, "}}")?;
5219 writeln!(code)?;
5220
5221 writeln!(
5223 code,
5224 "fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
5225 )?;
5226 writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
5227 writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
5228 writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
5229 writeln!(
5230 code,
5231 " eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
5232 )?;
5233 writeln!(code, " eprintln!(\"Endpoints:\");")?;
5234 writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
5235 writeln!(code, " eprintln!(\" GET /v1/models\");")?;
5236 writeln!(code, " eprintln!(\" GET /health\");")?;
5237 writeln!(code)?;
5238 writeln!(code, " for request in server.incoming_requests() {{")?;
5239 writeln!(code, " let method = request.method().to_string();")?;
5240 writeln!(code, " let url = request.url().to_string();")?;
5241 writeln!(code)?;
5242 writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
5243
5244 writeln!(
5246 code,
5247 " (\"POST\", \"/v1/chat/completions\") => {{"
5248 )?;
5249 writeln!(
5250 code,
5251 " handle_chat_completion(&mut model, &tokenizer, request);"
5252 )?;
5253 writeln!(code, " }}")?;
5254
5255 writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
5257 writeln!(code, " let body = serde_json::json!({{")?;
5258 writeln!(code, " \"object\": \"list\",")?;
5259 writeln!(code, " \"data\": [{{")?;
5260 writeln!(code, " \"id\": \"forgellm-metal\",")?;
5261 writeln!(code, " \"object\": \"model\",")?;
5262 writeln!(code, " \"owned_by\": \"forgellm\"")?;
5263 writeln!(code, " }}]")?;
5264 writeln!(code, " }});")?;
5265 writeln!(
5266 code,
5267 " let resp = tiny_http::Response::from_string(body.to_string())"
5268 )?;
5269 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
5270 writeln!(code, " request.respond(resp).ok();")?;
5271 writeln!(code, " }}")?;
5272
5273 writeln!(code, " (\"GET\", \"/health\") => {{")?;
5275 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
5276 writeln!(code, " request.respond(resp).ok();")?;
5277 writeln!(code, " }}")?;
5278
5279 writeln!(code, " _ => {{")?;
5281 writeln!(
5282 code,
5283 " let resp = tiny_http::Response::from_string(\"Not Found\")"
5284 )?;
5285 writeln!(code, " .with_status_code(404);")?;
5286 writeln!(code, " request.respond(resp).ok();")?;
5287 writeln!(code, " }}")?;
5288 writeln!(code, " }}")?;
5289 writeln!(code, " }}")?;
5290 writeln!(code, "}}")?;
5291 writeln!(code)?;
5292
5293 writeln!(code, "fn handle_chat_completion(")?;
5295 writeln!(code, " model: &mut model::MetalModel,")?;
5296 writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
5297 writeln!(code, " mut request: tiny_http::Request,")?;
5298 writeln!(code, ") {{")?;
5299 writeln!(code, " // Read request body")?;
5300 writeln!(code, " let mut body = String::new();")?;
5301 writeln!(
5302 code,
5303 " if request.as_reader().read_to_string(&mut body).is_err() {{"
5304 )?;
5305 writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
5306 writeln!(code, " .with_status_code(400);")?;
5307 writeln!(code, " request.respond(resp).ok();")?;
5308 writeln!(code, " return;")?;
5309 writeln!(code, " }}")?;
5310 writeln!(code)?;
5311 writeln!(code, " // Parse JSON")?;
5312 writeln!(
5313 code,
5314 " let req: ChatRequest = match serde_json::from_str(&body) {{"
5315 )?;
5316 writeln!(code, " Ok(r) => r,")?;
5317 writeln!(code, " Err(e) => {{")?;
5318 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
5319 writeln!(code, " .with_status_code(400);")?;
5320 writeln!(code, " request.respond(resp).ok();")?;
5321 writeln!(code, " return;")?;
5322 writeln!(code, " }}")?;
5323 writeln!(code, " }};")?;
5324 writeln!(code)?;
5325 writeln!(
5326 code,
5327 " let prompt = format_chat_messages(&req.messages);"
5328 )?;
5329 writeln!(
5330 code,
5331 " let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
5332 )?;
5333 writeln!(code, " Ok(e) => e,")?;
5334 writeln!(code, " Err(e) => {{")?;
5335 writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
5336 writeln!(code, " .with_status_code(500);")?;
5337 writeln!(code, " request.respond(resp).ok();")?;
5338 writeln!(code, " return;")?;
5339 writeln!(code, " }}")?;
5340 writeln!(code, " }};")?;
5341 writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
5342 writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
5343 writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
5344 writeln!(
5345 code,
5346 " let _temperature = req.temperature.unwrap_or(1.0);"
5347 )?;
5348 writeln!(code)?;
5349
5350 writeln!(code, " model.reset();")?;
5352 writeln!(code)?;
5353
5354 writeln!(code, " let prefill_start = Instant::now();")?;
5356 writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
5357 writeln!(
5358 code,
5359 " let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
5360 )?;
5361 writeln!(code, " let prefill_count = prompt_tokens.len();")?;
5362 writeln!(code, " let mut next_token = argmax(&logits);")?;
5363 writeln!(code)?;
5364
5365 writeln!(code, " if stream {{")?;
5366
5367 writeln!(
5369 code,
5370 " // SSE streaming: generate tokens and build SSE body"
5371 )?;
5372 writeln!(code, " let gen_start = Instant::now();")?;
5373 writeln!(code, " let mut generated_count: usize = 0;")?;
5374 writeln!(code, " let mut sse_body = String::new();")?;
5375 writeln!(code, " for _ in 0..max_tokens {{")?;
5376 writeln!(code, " if next_token == 2 {{ break; }}")?;
5377 writeln!(
5378 code,
5379 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
5380 )?;
5381 writeln!(
5382 code,
5383 " let escaped = serde_json::to_string(&text).unwrap_or_default();"
5384 )?;
5385 writeln!(
5386 code,
5387 " // escaped includes surrounding quotes, strip them"
5388 )?;
5389 writeln!(
5390 code,
5391 " let inner = &escaped[1..escaped.len()-1];"
5392 )?;
5393 writeln!(code, " sse_body.push_str(&format!(")?;
5394 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
5395 writeln!(code, " inner")?;
5396 writeln!(code, " ));")?;
5397 writeln!(code, " }}")?;
5398 writeln!(code, " generated_count += 1;")?;
5399 writeln!(code, " let logits = model.forward(next_token);")?;
5400 writeln!(code, " next_token = argmax(&logits);")?;
5401 writeln!(code, " }}")?;
5402 writeln!(
5403 code,
5404 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
5405 )?;
5406 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
5407 writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
5408 writeln!(code)?;
5409 writeln!(
5410 code,
5411 " // Final chunk with finish_reason, timing, and DONE sentinel"
5412 )?;
5413 writeln!(code, " sse_body.push_str(&format!(")?;
5414 writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
5415 writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
5416 writeln!(code, " ));")?;
5417 writeln!(code)?;
5418 writeln!(
5419 code,
5420 " let resp = tiny_http::Response::from_string(sse_body)"
5421 )?;
5422 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
5423 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
5424 writeln!(code, " request.respond(resp).ok();")?;
5425
5426 writeln!(code, " }} else {{")?;
5427
5428 writeln!(
5430 code,
5431 " // Non-streaming: generate all tokens, return JSON"
5432 )?;
5433 writeln!(code, " let gen_start = Instant::now();")?;
5434 writeln!(code, " let mut generated_count: usize = 0;")?;
5435 writeln!(code, " let mut generated = String::new();")?;
5436 writeln!(code, " for _ in 0..max_tokens {{")?;
5437 writeln!(code, " if next_token == 2 {{ break; }}")?;
5438 writeln!(
5439 code,
5440 " if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
5441 )?;
5442 writeln!(code, " generated.push_str(&text);")?;
5443 writeln!(code, " }}")?;
5444 writeln!(code, " generated_count += 1;")?;
5445 writeln!(code, " let logits = model.forward(next_token);")?;
5446 writeln!(code, " next_token = argmax(&logits);")?;
5447 writeln!(code, " }}")?;
5448 writeln!(
5449 code,
5450 " let gen_elapsed = gen_start.elapsed().as_secs_f64();"
5451 )?;
5452 writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
5453 writeln!(code)?;
5454 writeln!(code, " let resp_json = serde_json::json!({{")?;
5455 writeln!(code, " \"id\": \"chatcmpl-1\",")?;
5456 writeln!(code, " \"object\": \"chat.completion\",")?;
5457 writeln!(code, " \"choices\": [{{")?;
5458 writeln!(code, " \"index\": 0,")?;
5459 writeln!(code, " \"message\": {{")?;
5460 writeln!(code, " \"role\": \"assistant\",")?;
5461 writeln!(code, " \"content\": generated")?;
5462 writeln!(code, " }},")?;
5463 writeln!(code, " \"finish_reason\": \"stop\"")?;
5464 writeln!(code, " }}],")?;
5465 writeln!(code, " \"usage\": {{")?;
5466 writeln!(code, " \"prefill_tokens\": prefill_count,")?;
5467 writeln!(
5468 code,
5469 " \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
5470 )?;
5471 writeln!(
5472 code,
5473 " \"generation_tokens\": generated_count,"
5474 )?;
5475 writeln!(
5476 code,
5477 " \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
5478 )?;
5479 writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
5480 writeln!(code, " }}")?;
5481 writeln!(code, " }});")?;
5482 writeln!(
5483 code,
5484 " let resp = tiny_http::Response::from_string(resp_json.to_string())"
5485 )?;
5486 writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
5487 writeln!(code, " request.respond(resp).ok();")?;
5488 writeln!(code, " }}")?;
5489 writeln!(code, "}}")?;
5490
5491 Ok(code)
5492}
5493
5494#[cfg(test)]
5499mod tests {
5500 use super::*;
5501 use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
5502
5503 fn minimal_config() -> ModelConfig {
5504 ModelConfig {
5505 architecture: Architecture::Llama,
5506 hidden_size: 64,
5507 intermediate_size: 128,
5508 num_layers: 2,
5509 num_attention_heads: 4,
5510 num_kv_heads: 4,
5511 head_dim: 16,
5512 vocab_size: 256,
5513 max_seq_len: 512,
5514 rms_norm_eps: 1e-5,
5515 rope_theta: 10000.0,
5516 dtype: DType::F32,
5517 sliding_window_size: None,
5518 qkv_bias: false,
5519 }
5520 }
5521
5522 fn minimal_graph() -> Graph {
5523 Graph::new("test-metal").with_config(minimal_config())
5524 }
5525
5526 #[test]
5527 fn generate_metal_project_creates_files() {
5528 let dir = tempfile::tempdir().unwrap();
5529 let graph = minimal_graph();
5530 generate_metal_project(&graph, dir.path(), "test-model").unwrap();
5531
5532 assert!(
5533 dir.path().join("Cargo.toml").exists(),
5534 "Cargo.toml should be created"
5535 );
5536 assert!(
5537 dir.path().join("src/model.rs").exists(),
5538 "src/model.rs should be created"
5539 );
5540 assert!(
5541 dir.path().join("src/main.rs").exists(),
5542 "src/main.rs should be created"
5543 );
5544 assert!(
5545 dir.path().join("shaders/kernels.metal").exists(),
5546 "shaders/kernels.metal should be created"
5547 );
5548 }
5549
5550 #[test]
5551 fn generated_cargo_toml_has_metal_dep() {
5552 let toml = generate_cargo_toml("my-model");
5553 assert!(toml.contains("metal"), "Cargo.toml should depend on metal");
5554 assert!(
5555 toml.contains("tokenizers"),
5556 "Cargo.toml should depend on tokenizers"
5557 );
5558 assert!(
5559 toml.contains("memmap2"),
5560 "Cargo.toml should depend on memmap2"
5561 );
5562 assert!(toml.contains("half"), "Cargo.toml should depend on half");
5563 }
5564
5565 #[test]
5566 fn generated_model_rs_contains_metal_code() {
5567 let config = minimal_config();
5568 let model_rs = generate_model_rs(&config).unwrap();
5569
5570 assert!(
5571 model_rs.contains("pub struct MetalModel"),
5572 "model.rs should define MetalModel struct"
5573 );
5574 assert!(
5575 model_rs.contains("matmul_pipeline: ComputePipelineState"),
5576 "MetalModel should have matmul_pipeline field"
5577 );
5578 assert!(
5579 model_rs.contains("Device::system_default()"),
5580 "model.rs should use Metal device"
5581 );
5582 assert!(
5583 model_rs.contains("new_library_with_source"),
5584 "model.rs should compile Metal shaders"
5585 );
5586 assert!(
5587 model_rs.contains("fn new(weights: &[u8])"),
5588 "MetalModel should implement new()"
5589 );
5590 assert!(
5591 model_rs.contains("fn forward(&mut self, token_id: u32)"),
5592 "MetalModel should implement forward()"
5593 );
5594 }
5595
5596 #[test]
5597 fn generated_shaders_contain_kernel_names() {
5598 let shaders = generate_metal_shaders(&minimal_config());
5599
5600 assert!(
5601 shaders.contains("kernel void matmul_vec"),
5602 "shaders should contain matmul_vec kernel"
5603 );
5604 assert!(
5605 shaders.contains("kernel void rms_norm"),
5606 "shaders should contain rms_norm kernel"
5607 );
5608 assert!(
5609 shaders.contains("kernel void rope"),
5610 "shaders should contain rope kernel"
5611 );
5612 assert!(
5613 shaders.contains("kernel void softmax"),
5614 "shaders should contain softmax kernel"
5615 );
5616 assert!(
5617 shaders.contains("kernel void silu_mul("),
5618 "shaders should contain silu_mul kernel"
5619 );
5620 assert!(
5621 shaders.contains("kernel void silu_mul_fused"),
5622 "shaders should contain silu_mul_fused kernel"
5623 );
5624 assert!(
5625 shaders.contains("kernel void elementwise_add"),
5626 "shaders should contain elementwise_add kernel"
5627 );
5628 assert!(
5629 shaders.contains("kernel void attention"),
5630 "shaders should contain attention kernel"
5631 );
5632 assert!(
5633 shaders.contains("kernel void add_inplace"),
5634 "shaders should contain add_inplace kernel"
5635 );
5636 assert!(
5637 shaders.contains("kernel void copy_buffer"),
5638 "shaders should contain copy_buffer kernel"
5639 );
5640 assert!(
5641 shaders.contains("kernel void copy_offset"),
5642 "shaders should contain copy_offset kernel"
5643 );
5644 }
5645
5646 #[test]
5647 fn generated_shaders_use_simdgroup_features() {
5648 let shaders = generate_metal_shaders(&minimal_config());
5649
5650 assert!(
5651 shaders.contains("threadgroup_barrier"),
5652 "shaders should use threadgroup barriers"
5653 );
5654 assert!(
5655 shaders.contains("threadgroup float"),
5656 "shaders should use threadgroup shared memory"
5657 );
5658 assert!(
5659 shaders.contains("thread_index_in_threadgroup"),
5660 "shaders should use threadgroup indexing"
5661 );
5662 assert!(
5663 shaders.contains("simd_sum"),
5664 "shaders should use simd_sum for warp-level reduction"
5665 );
5666 assert!(
5667 shaders.contains("simd_max"),
5668 "attention kernel should use simd_max for cooperative softmax"
5669 );
5670 assert!(
5671 shaders.contains("thread_index_in_simdgroup"),
5672 "shaders should use simdgroup lane indexing"
5673 );
5674 assert!(
5675 shaders.contains("simdgroup_index_in_threadgroup"),
5676 "shaders should use simdgroup indexing within threadgroup"
5677 );
5678 assert!(
5679 shaders.contains("float4"),
5680 "matmul_vec should use float4 vectorized loads"
5681 );
5682 }
5683
5684 #[test]
5685 fn generated_main_rs_has_tokenizer_usage() {
5686 let config = minimal_config();
5687 let main_rs = generate_main_rs("test-model", &config).unwrap();
5688
5689 assert!(
5690 main_rs.contains("tokenizers::Tokenizer"),
5691 "main.rs should use tokenizers crate"
5692 );
5693 assert!(
5694 main_rs.contains("MetalModel::new"),
5695 "main.rs should call MetalModel::new"
5696 );
5697 assert!(
5698 main_rs.contains("model.forward"),
5699 "main.rs should call model.forward"
5700 );
5701 assert!(
5702 main_rs.contains("memmap2"),
5703 "main.rs should use memmap2 for zero-copy weight loading"
5704 );
5705 }
5706
5707 #[test]
5708 fn missing_config_returns_error() {
5709 let dir = tempfile::tempdir().unwrap();
5710 let graph = Graph::new("no-config");
5711 let result = generate_metal_project(&graph, dir.path(), "fail");
5712 assert!(
5713 matches!(result, Err(MetalCodegenError::MissingConfig)),
5714 "should fail with MissingConfig when graph has no config"
5715 );
5716 }
5717
5718 #[test]
5719 fn sanitize_name_works() {
5720 assert_eq!(sanitize_name("My Model!"), "my-model");
5721 assert_eq!(sanitize_name("test_model"), "test-model");
5722 assert_eq!(sanitize_name("simple"), "simple");
5723 }
5724
5725 #[test]
5726 fn generated_forward_uses_single_command_buffer() {
5727 let config = minimal_config();
5728 let model_rs = generate_model_rs(&config).unwrap();
5729
5730 let forward_start = model_rs
5733 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
5734 .unwrap();
5735 let forward_body = &model_rs[forward_start..];
5736 let forward_end = forward_body
5738 .find("\n pub fn forward_profile")
5739 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
5740 .or_else(|| forward_body.find("\n fn dispatch_"))
5741 .unwrap_or(forward_body.len());
5742 let forward_code = &forward_body[..forward_end];
5743
5744 let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
5746 assert_eq!(
5747 cmd_buf_count, 1,
5748 "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
5749 );
5750
5751 let commit_count = forward_code.matches("cmd.commit()").count();
5753 assert_eq!(
5754 commit_count, 1,
5755 "forward() should commit exactly once, found {commit_count}"
5756 );
5757
5758 let wait_count = forward_code.matches("wait_until_completed()").count();
5760 assert!(
5761 wait_count >= 1 && wait_count <= 2,
5762 "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
5763 );
5764 }
5765
5766 #[test]
5767 fn generated_model_has_preallocated_working_buffers() {
5768 let config = minimal_config();
5769 let model_rs = generate_model_rs(&config).unwrap();
5770
5771 for buf_name in &[
5772 "normed_buf",
5773 "qkv_buf",
5774 "attn_out_buf",
5775 "attn_proj_buf",
5776 "gate_up_buf",
5777 "ffn_hidden_buf",
5778 "ffn_out_buf",
5779 "add_tmp_buf",
5780 ] {
5781 assert!(
5782 model_rs.contains(&format!("{buf_name}: Buffer")),
5783 "MetalModel should have pre-allocated {buf_name} field"
5784 );
5785 }
5786 }
5787
5788 #[test]
5789 fn generated_dispatch_helpers_take_compute_encoder_ref() {
5790 let config = minimal_config();
5791 let model_rs = generate_model_rs(&config).unwrap();
5792
5793 for method in &[
5794 "fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef",
5795 "fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef",
5796 "fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef",
5797 "fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef",
5798 "fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef",
5799 "fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef",
5800 "fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef",
5801 "fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef",
5802 "fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef",
5803 "fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef",
5804 "fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef",
5805 "fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef",
5806 "fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef",
5807 ] {
5808 assert!(
5809 model_rs.contains(method),
5810 "model.rs should contain dispatch helper: {method}"
5811 );
5812 }
5813 }
5814
5815 #[test]
5816 fn generated_helpers_do_not_create_command_buffers_or_encoders() {
5817 let config = minimal_config();
5818 let model_rs = generate_model_rs(&config).unwrap();
5819
5820 let helpers_start = model_rs.find("fn dispatch_rms_norm").unwrap();
5822 let helpers_code = &model_rs[helpers_start..];
5823
5824 assert!(
5826 !helpers_code.contains("self.queue.new_command_buffer()"),
5827 "dispatch helpers should not create their own command buffers"
5828 );
5829
5830 assert!(
5832 !helpers_code.contains("new_compute_command_encoder()"),
5833 "dispatch helpers should not create their own compute encoders"
5834 );
5835
5836 assert!(
5838 !helpers_code.contains("end_encoding()"),
5839 "dispatch helpers should not call end_encoding"
5840 );
5841
5842 assert!(
5844 !helpers_code.contains(".commit()"),
5845 "dispatch helpers should not commit command buffers"
5846 );
5847 assert!(
5848 !helpers_code.contains("wait_until_completed"),
5849 "dispatch helpers should not wait on command buffers"
5850 );
5851 }
5852
5853 #[test]
5854 fn generated_forward_batches_compute_encoders() {
5855 let config = minimal_config();
5856 let model_rs = generate_model_rs(&config).unwrap();
5857
5858 let forward_start = model_rs
5860 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
5861 .unwrap();
5862 let forward_body = &model_rs[forward_start..];
5863 let forward_end = forward_body
5864 .find("\n pub fn forward_profile")
5865 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
5866 .or_else(|| forward_body.find("\n fn dispatch_"))
5867 .unwrap_or(forward_body.len());
5868 let forward_code = &forward_body[..forward_end];
5869
5870 assert!(
5872 !forward_code.contains("device.new_buffer"),
5873 "forward() should not allocate new buffers per call"
5874 );
5875
5876 let compute_encoder_count = forward_code
5879 .matches("new_compute_command_encoder()")
5880 .count();
5881 let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
5882
5883 assert_eq!(
5885 compute_encoder_count, 1,
5886 "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
5887 );
5888 assert_eq!(
5889 blit_encoder_count, 0,
5890 "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
5891 );
5892 }
5893
5894 #[test]
5895 fn generated_forward_uses_add_inplace() {
5896 let config = minimal_config();
5897 let model_rs = generate_model_rs(&config).unwrap();
5898
5899 assert!(
5901 model_rs.contains("dispatch_add_inplace"),
5902 "forward() should use dispatch_add_inplace for residual connections"
5903 );
5904 assert!(
5905 model_rs.contains("add_inplace_pipeline"),
5906 "MetalModel should have add_inplace_pipeline"
5907 );
5908 }
5909
5910 fn minimal_q8_config() -> ModelConfig {
5911 ModelConfig {
5912 architecture: Architecture::Llama,
5913 hidden_size: 64,
5914 intermediate_size: 128,
5915 num_layers: 2,
5916 num_attention_heads: 4,
5917 num_kv_heads: 4,
5918 head_dim: 16,
5919 vocab_size: 256,
5920 max_seq_len: 512,
5921 rms_norm_eps: 1e-5,
5922 rope_theta: 10000.0,
5923 dtype: DType::Q8_0,
5924 sliding_window_size: None,
5925 qkv_bias: false,
5926 }
5927 }
5928
5929 #[test]
5930 fn generated_shaders_contain_q8_kernel() {
5931 let shaders = generate_metal_shaders(&minimal_config());
5932
5933 assert!(
5934 shaders.contains("kernel void matmul_vec_q8"),
5935 "shaders should contain matmul_vec_q8 kernel"
5936 );
5937 assert!(
5938 shaders.contains("device const uchar* matrix"),
5939 "matmul_vec_q8 should accept raw Q8_0 bytes"
5940 );
5941 assert!(
5942 shaders.contains("packed_short4"),
5943 "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data"
5944 );
5945 assert!(
5946 shaders.contains("as_type<char2>"),
5947 "matmul_vec_q8 should bitcast short lanes to char2"
5948 );
5949 assert!(
5950 shaders.contains("device const half*"),
5951 "matmul_vec_q8 should read f16 scale via half pointer"
5952 );
5953 }
5954
5955 #[test]
5956 fn generated_model_uses_fused_qkv_projections() {
5957 let config = minimal_config();
5958 let model_rs = generate_model_rs(&config).unwrap();
5959
5960 assert!(
5962 model_rs.contains("qkv_weight: Buffer"),
5963 "LayerBuffers should have fused qkv_weight field"
5964 );
5965 assert!(
5967 !model_rs.contains(" q_weight: Buffer"),
5968 "LayerBuffers should not have separate q_weight field"
5969 );
5970 assert!(
5971 !model_rs.contains(" k_weight: Buffer"),
5972 "LayerBuffers should not have separate k_weight field"
5973 );
5974 assert!(
5975 !model_rs.contains(" v_weight: Buffer"),
5976 "LayerBuffers should not have separate v_weight field"
5977 );
5978
5979 assert!(
5981 model_rs.contains("gate_up_weight: Buffer"),
5982 "LayerBuffers should have fused gate_up_weight field"
5983 );
5984 assert!(
5986 !model_rs.contains(" gate_weight: Buffer"),
5987 "LayerBuffers should not have separate gate_weight field"
5988 );
5989 assert!(
5990 !model_rs.contains(" up_weight: Buffer"),
5991 "LayerBuffers should not have separate up_weight field"
5992 );
5993
5994 assert!(
5996 model_rs.contains("qkv_buf: Buffer"),
5997 "MetalModel should have fused qkv_buf"
5998 );
5999 assert!(
6000 model_rs.contains("gate_up_buf: Buffer"),
6001 "MetalModel should have fused gate_up_buf"
6002 );
6003
6004 assert!(
6006 model_rs.contains("dispatch_silu_mul_fused"),
6007 "forward pass should use dispatch_silu_mul_fused"
6008 );
6009 assert!(
6010 model_rs.contains("dispatch_rope_offset"),
6011 "forward pass should use dispatch_rope_offset for fused QKV"
6012 );
6013 assert!(
6014 model_rs.contains("dispatch_attention_offset"),
6015 "forward pass should use dispatch_attention_offset for fused QKV"
6016 );
6017 }
6018
6019 #[test]
6020 fn q8_model_has_matmul_q8_pipeline() {
6021 let config = minimal_q8_config();
6022 let model_rs = generate_model_rs(&config).unwrap();
6023
6024 assert!(
6025 model_rs.contains("matmul_q8_pipeline: ComputePipelineState"),
6026 "MetalModel should have matmul_q8_pipeline field"
6027 );
6028 assert!(
6029 model_rs.contains("matmul_q8_pipeline,"),
6030 "MetalModel Self should include matmul_q8_pipeline"
6031 );
6032 }
6033
6034 #[test]
6035 fn q8_model_uses_dispatch_matmul_q8() {
6036 let config = minimal_q8_config();
6037 let model_rs = generate_model_rs(&config).unwrap();
6038
6039 assert!(
6040 model_rs.contains("dispatch_matmul_q8"),
6041 "Q8_0 model should use dispatch_matmul_q8 for projections"
6042 );
6043 assert!(
6044 model_rs.contains("fn dispatch_matmul_q8"),
6045 "model.rs should define dispatch_matmul_q8 method"
6046 );
6047 }
6048
6049 #[test]
6050 fn q8_model_loads_raw_bytes_not_dequantized() {
6051 let config = minimal_q8_config();
6052 let model_rs = generate_model_rs(&config).unwrap();
6053
6054 assert!(
6056 !model_rs.contains("f16_to_f32"),
6057 "Q8_0 model should not dequantize weights to f32"
6058 );
6059 assert!(
6060 !model_rs.contains("f32_data"),
6061 "Q8_0 model should not create f32 weight data"
6062 );
6063
6064 assert!(
6066 model_rs.contains("total_raw as u64"),
6067 "Q8_0 model should load raw bytes into Metal buffer"
6068 );
6069 }
6070
6071 #[test]
6072 fn q8_model_norms_stay_f32() {
6073 let config = minimal_q8_config();
6074 let model_rs = generate_model_rs(&config).unwrap();
6075
6076 assert!(
6078 model_rs.contains("let attn_norm = next_f32_buffer"),
6079 "attn_norm should use f32 buffer even for Q8_0 models"
6080 );
6081 assert!(
6082 model_rs.contains("let ffn_norm = next_f32_buffer"),
6083 "ffn_norm should use f32 buffer even for Q8_0 models"
6084 );
6085 assert!(
6086 model_rs.contains("let norm_buf = next_f32_buffer"),
6087 "final norm should use f32 buffer even for Q8_0 models"
6088 );
6089 }
6090
6091 #[test]
6092 fn q8_model_uses_fused_weight_loading() {
6093 let config = minimal_q8_config();
6094 let model_rs = generate_model_rs(&config).unwrap();
6095
6096 assert!(
6098 model_rs.contains("let qkv_weight = next_q8_fused_buffer"),
6099 "Q8_0 model should use next_q8_fused_buffer for fused QKV weights"
6100 );
6101 assert!(
6103 model_rs.contains("let gate_up_weight = next_q8_fused_buffer"),
6104 "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights"
6105 );
6106 assert!(
6108 model_rs.contains("let o_weight = next_q8_buffer"),
6109 "Q8_0 model should use next_q8_buffer for O weight"
6110 );
6111 assert!(
6112 model_rs.contains("let down_weight = next_q8_buffer"),
6113 "Q8_0 model should use next_q8_buffer for down weight"
6114 );
6115 }
6116
6117 #[test]
6118 fn f32_model_does_not_use_q8_dispatch() {
6119 let config = minimal_config();
6120 let model_rs = generate_model_rs(&config).unwrap();
6121
6122 let forward_start = model_rs
6124 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6125 .unwrap();
6126 let forward_body = &model_rs[forward_start..];
6127 let forward_end = forward_body
6128 .find("\n fn dispatch_")
6129 .unwrap_or(forward_body.len());
6130 let forward_code = &forward_body[..forward_end];
6131
6132 assert!(
6133 !forward_code.contains("dispatch_matmul_q8"),
6134 "f32 model forward should not use dispatch_matmul_q8"
6135 );
6136 }
6137
6138 #[test]
6139 fn q8_dispatch_helper_takes_compute_encoder_ref() {
6140 let config = minimal_q8_config();
6141 let model_rs = generate_model_rs(&config).unwrap();
6142
6143 assert!(
6144 model_rs.contains("fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef"),
6145 "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
6146 );
6147 }
6148
6149 #[test]
6150 fn generated_model_has_double_buffered_prefill() {
6151 let config = minimal_config();
6152 let model_rs = generate_model_rs(&config).unwrap();
6153
6154 assert!(
6156 model_rs.contains("prev_cmd: Option<CommandBuffer>"),
6157 "MetalModel should have prev_cmd field for double-buffered prefill"
6158 );
6159
6160 assert!(
6162 model_rs.contains("pub fn forward_prefill(&mut self, token_id: u32)"),
6163 "MetalModel should have forward_prefill method"
6164 );
6165
6166 assert!(
6168 model_rs.contains("if let Some(prev) = self.prev_cmd.take()"),
6169 "forward() should drain prev_cmd from previous prefill"
6170 );
6171 }
6172
6173 #[test]
6174 fn generated_main_rs_uses_forward_prefill_for_prompt() {
6175 let config = minimal_config();
6176 let main_rs = generate_main_rs("test-model", &config).unwrap();
6177
6178 assert!(
6179 main_rs.contains("forward_prefill"),
6180 "main.rs should use forward_prefill for intermediate prompt tokens"
6181 );
6182 assert!(
6183 main_rs.contains("double-buffered"),
6184 "main.rs should document double-buffered prefill"
6185 );
6186 }
6187
6188 #[test]
6189 fn generated_shaders_q8_uses_wide_vectorized_loads() {
6190 let shaders = generate_metal_shaders(&minimal_config());
6191
6192 assert!(
6193 shaders.contains("packed_short4"),
6194 "matmul_vec_q8 should use packed_short4 wide 64-bit loads"
6195 );
6196 assert!(
6197 shaders.contains("d0[0]"),
6198 "matmul_vec_q8 should index the wide pointer for row 0"
6199 );
6200 assert!(
6201 shaders.contains("as_type<char2>"),
6202 "matmul_vec_q8 should bitcast short lanes to char2"
6203 );
6204 assert!(
6205 shaders.contains("dot("),
6206 "matmul_vec_q8 should use dot() intrinsic for fma accumulation"
6207 );
6208 }
6209
6210 fn minimal_q4_config() -> ModelConfig {
6213 ModelConfig {
6214 architecture: Architecture::Llama,
6215 hidden_size: 64,
6216 intermediate_size: 128,
6217 num_layers: 2,
6218 num_attention_heads: 4,
6219 num_kv_heads: 4,
6220 head_dim: 16,
6221 vocab_size: 256,
6222 max_seq_len: 512,
6223 rms_norm_eps: 1e-5,
6224 rope_theta: 10000.0,
6225 dtype: DType::Q4_0,
6226 sliding_window_size: None,
6227 qkv_bias: false,
6228 }
6229 }
6230
6231 #[test]
6232 fn generated_shaders_contain_q4_kernel() {
6233 let shaders = generate_metal_shaders(&minimal_config());
6234
6235 assert!(
6236 shaders.contains("kernel void matmul_vec_q4"),
6237 "shaders should contain matmul_vec_q4 kernel"
6238 );
6239 assert!(
6240 shaders.contains("Q4_ROWS_PER_TG"),
6241 "shaders should define Q4_ROWS_PER_TG constant"
6242 );
6243 assert!(
6244 shaders.contains("Q4_ROWS_PER_SG"),
6245 "shaders should define Q4_ROWS_PER_SG constant"
6246 );
6247 }
6248
6249 #[test]
6250 fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
6251 let shaders = generate_metal_shaders(&minimal_config());
6252
6253 assert!(
6255 shaders.contains("uchar4"),
6256 "matmul_vec_q4 should use uchar4 for packed byte loads"
6257 );
6258 assert!(
6260 shaders.contains("&0xF"),
6261 "matmul_vec_q4 should extract low nibble with &0xF"
6262 );
6263 assert!(
6264 shaders.contains(">>4"),
6265 "matmul_vec_q4 should extract high nibble with >>4"
6266 );
6267 assert!(
6269 shaders.contains("-8)"),
6270 "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion"
6271 );
6272 assert!(
6274 shaders.contains("blk * 18"),
6275 "matmul_vec_q4 should use 18-byte block stride"
6276 );
6277 }
6278
6279 #[test]
6280 fn q4_model_has_matmul_q4_pipeline() {
6281 let config = minimal_q4_config();
6282 let model_rs = generate_model_rs(&config).unwrap();
6283
6284 assert!(
6285 model_rs.contains("matmul_q4_pipeline: ComputePipelineState"),
6286 "MetalModel should have matmul_q4_pipeline field"
6287 );
6288 assert!(
6289 model_rs.contains("matmul_q4_pipeline,"),
6290 "MetalModel Self should include matmul_q4_pipeline"
6291 );
6292 }
6293
6294 #[test]
6295 fn q4_model_uses_dispatch_matmul_q4() {
6296 let config = minimal_q4_config();
6297 let model_rs = generate_model_rs(&config).unwrap();
6298
6299 assert!(
6300 model_rs.contains("dispatch_matmul_q4"),
6301 "Q4_0 model should use dispatch_matmul_q4 for projections"
6302 );
6303 assert!(
6304 model_rs.contains("fn dispatch_matmul_q4"),
6305 "model.rs should define dispatch_matmul_q4 method"
6306 );
6307 }
6308
6309 #[test]
6310 fn q4_model_loads_raw_bytes_not_dequantized() {
6311 let config = minimal_q4_config();
6312 let model_rs = generate_model_rs(&config).unwrap();
6313
6314 assert!(
6316 !model_rs.contains("f16_to_f32"),
6317 "Q4_0 model should not dequantize weights to f32"
6318 );
6319
6320 assert!(
6322 model_rs.contains("total_raw as u64"),
6323 "Q4_0 model should load raw bytes into Metal buffer"
6324 );
6325 }
6326
6327 #[test]
6328 fn q4_model_norms_stay_f32() {
6329 let config = minimal_q4_config();
6330 let model_rs = generate_model_rs(&config).unwrap();
6331
6332 assert!(
6333 model_rs.contains("let attn_norm = next_f32_buffer"),
6334 "attn_norm should use f32 buffer even for Q4_0 models"
6335 );
6336 assert!(
6337 model_rs.contains("let ffn_norm = next_f32_buffer"),
6338 "ffn_norm should use f32 buffer even for Q4_0 models"
6339 );
6340 assert!(
6341 model_rs.contains("let norm_buf = next_f32_buffer"),
6342 "final norm should use f32 buffer even for Q4_0 models"
6343 );
6344 }
6345
6346 #[test]
6347 fn q4_model_uses_fused_weight_loading() {
6348 let config = minimal_q4_config();
6349 let model_rs = generate_model_rs(&config).unwrap();
6350
6351 assert!(
6352 model_rs.contains("let qkv_weight = next_q4_fused_buffer"),
6353 "Q4_0 model should use next_q4_fused_buffer for fused QKV weights"
6354 );
6355 assert!(
6356 model_rs.contains("let gate_up_weight = next_q4_fused_buffer"),
6357 "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights"
6358 );
6359 assert!(
6360 model_rs.contains("let o_weight = next_q4_buffer"),
6361 "Q4_0 model should use next_q4_buffer for O weight"
6362 );
6363 assert!(
6364 model_rs.contains("let down_weight = next_q4_buffer"),
6365 "Q4_0 model should use next_q4_buffer for down weight"
6366 );
6367 }
6368
6369 #[test]
6370 fn q4_dispatch_helper_takes_compute_encoder_ref() {
6371 let config = minimal_q4_config();
6372 let model_rs = generate_model_rs(&config).unwrap();
6373
6374 assert!(
6375 model_rs.contains("fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef"),
6376 "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
6377 );
6378 }
6379
6380 #[test]
6381 fn f32_model_does_not_use_q4_dispatch() {
6382 let config = minimal_config();
6383 let model_rs = generate_model_rs(&config).unwrap();
6384
6385 let forward_start = model_rs
6386 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6387 .unwrap();
6388 let forward_body = &model_rs[forward_start..];
6389 let forward_end = forward_body
6390 .find("\n fn dispatch_")
6391 .unwrap_or(forward_body.len());
6392 let forward_code = &forward_body[..forward_end];
6393
6394 assert!(
6395 !forward_code.contains("dispatch_matmul_q4"),
6396 "f32 model forward should not use dispatch_matmul_q4"
6397 );
6398 }
6399
6400 #[test]
6401 fn q4_model_lm_head_uses_q4_buffer() {
6402 let config = minimal_q4_config();
6403 let model_rs = generate_model_rs(&config).unwrap();
6404
6405 assert!(
6406 model_rs.contains("let lm_head_buf = next_q4_buffer"),
6407 "Q4_0 model should use next_q4_buffer for lm_head"
6408 );
6409 }
6410
6411 #[test]
6412 fn vec_tile_size_matches_model_dimensions() {
6413 let small = minimal_config();
6415 let shaders_small = generate_metal_shaders(&small);
6416 assert!(
6417 shaders_small.contains("vec_tile[128]"),
6418 "vec_tile should be sized to max(hidden, intermediate) = 128"
6419 );
6420
6421 let mut large = minimal_config();
6423 large.hidden_size = 2048;
6424 large.intermediate_size = 8192;
6425 let shaders_large = generate_metal_shaders(&large);
6426 assert!(
6427 shaders_large.contains("vec_tile[8192]"),
6428 "vec_tile should be 8192 for models with intermediate=8192"
6429 );
6430 assert!(
6431 !shaders_large.contains("vec_tile[4096]"),
6432 "vec_tile should NOT be hardcoded to 4096"
6433 );
6434 }
6435
6436 #[test]
6437 fn generated_cargo_toml_has_server_deps() {
6438 let toml = generate_cargo_toml("my-model");
6439 assert!(
6440 toml.contains("tiny_http"),
6441 "Cargo.toml should depend on tiny_http"
6442 );
6443 assert!(toml.contains("serde"), "Cargo.toml should depend on serde");
6444 assert!(
6445 toml.contains("serde_json"),
6446 "Cargo.toml should depend on serde_json"
6447 );
6448 }
6449
6450 #[test]
6451 fn generated_main_rs_has_serve_mode() {
6452 let config = minimal_config();
6453 let main_rs = generate_main_rs("test-model", &config).unwrap();
6454
6455 assert!(
6456 main_rs.contains("--serve"),
6457 "main.rs should parse --serve flag"
6458 );
6459 assert!(
6460 main_rs.contains("--port"),
6461 "main.rs should parse --port flag"
6462 );
6463 assert!(
6464 main_rs.contains("fn serve("),
6465 "main.rs should define serve function"
6466 );
6467 assert!(
6468 main_rs.contains("tiny_http::Server::http"),
6469 "main.rs should create tiny_http server"
6470 );
6471 }
6472
6473 #[test]
6474 fn generated_main_rs_has_chat_completions_endpoint() {
6475 let config = minimal_config();
6476 let main_rs = generate_main_rs("test-model", &config).unwrap();
6477
6478 assert!(
6479 main_rs.contains("/v1/chat/completions"),
6480 "main.rs should handle /v1/chat/completions endpoint"
6481 );
6482 assert!(
6483 main_rs.contains("/v1/models"),
6484 "main.rs should handle /v1/models endpoint"
6485 );
6486 assert!(
6487 main_rs.contains("/health"),
6488 "main.rs should handle /health endpoint"
6489 );
6490 }
6491
6492 #[test]
6493 fn generated_main_rs_has_sse_streaming() {
6494 let config = minimal_config();
6495 let main_rs = generate_main_rs("test-model", &config).unwrap();
6496
6497 assert!(
6498 main_rs.contains("text/event-stream"),
6499 "main.rs should set SSE content type for streaming"
6500 );
6501 assert!(
6502 main_rs.contains("chat.completion.chunk"),
6503 "main.rs should emit SSE chunks"
6504 );
6505 assert!(
6506 main_rs.contains("[DONE]"),
6507 "main.rs should emit [DONE] sentinel"
6508 );
6509 }
6510
6511 #[test]
6512 fn generated_main_rs_has_chat_message_formatting() {
6513 let config = minimal_config();
6514 let main_rs = generate_main_rs("test-model", &config).unwrap();
6515
6516 assert!(
6517 main_rs.contains("fn format_chat_messages"),
6518 "main.rs should define format_chat_messages function"
6519 );
6520 assert!(
6521 main_rs.contains("<|im_start|>"),
6522 "main.rs should use ChatML format"
6523 );
6524 assert!(
6525 main_rs.contains("<|im_end|>"),
6526 "main.rs should use ChatML format"
6527 );
6528 }
6529
6530 #[test]
6531 fn generated_main_rs_has_request_types() {
6532 let config = minimal_config();
6533 let main_rs = generate_main_rs("test-model", &config).unwrap();
6534
6535 assert!(
6536 main_rs.contains("struct ChatRequest"),
6537 "main.rs should define ChatRequest struct"
6538 );
6539 assert!(
6540 main_rs.contains("struct ChatMessage"),
6541 "main.rs should define ChatMessage struct"
6542 );
6543 assert!(
6544 main_rs.contains("Deserialize"),
6545 "main.rs should derive Deserialize for request types"
6546 );
6547 }
6548
6549 #[test]
6550 fn generated_model_has_reset_method() {
6551 let config = minimal_config();
6552 let model_rs = generate_model_rs(&config).unwrap();
6553
6554 assert!(
6555 model_rs.contains("pub fn reset(&mut self)"),
6556 "model.rs should have a reset() method for multi-request serving"
6557 );
6558 assert!(
6559 model_rs.contains("self.pos = 0"),
6560 "reset() should reset position to 0"
6561 );
6562 }
6563
6564 #[test]
6565 fn generated_main_rs_cli_mode_still_works() {
6566 let config = minimal_config();
6567 let main_rs = generate_main_rs("test-model", &config).unwrap();
6568
6569 assert!(
6571 main_rs.contains("fn cli_mode("),
6572 "main.rs should define cli_mode function"
6573 );
6574 assert!(
6575 main_rs.contains("model.forward"),
6576 "main.rs should call model.forward"
6577 );
6578 assert!(
6579 main_rs.contains("model.forward_prefill"),
6580 "main.rs should call model.forward_prefill"
6581 );
6582 }
6583
6584 #[test]
6587 fn generated_shaders_contain_batch_kernels() {
6588 let shaders = generate_metal_shaders(&minimal_config());
6589
6590 assert!(
6591 shaders.contains("kernel void matmul_vec_batch"),
6592 "shaders should contain matmul_vec_batch kernel"
6593 );
6594 assert!(
6595 shaders.contains("kernel void matmul_vec_q8_batch"),
6596 "shaders should contain matmul_vec_q8_batch kernel"
6597 );
6598 assert!(
6599 shaders.contains("kernel void matmul_q8_gemm_batch"),
6600 "shaders should contain matmul_q8_gemm_batch kernel (weight-reuse GEMM)"
6601 );
6602 assert!(
6603 shaders.contains("kernel void matmul_vec_q4_batch"),
6604 "shaders should contain matmul_vec_q4_batch kernel"
6605 );
6606 assert!(
6607 shaders.contains("kernel void rms_norm_batch"),
6608 "shaders should contain rms_norm_batch kernel"
6609 );
6610 assert!(
6611 shaders.contains("kernel void silu_mul_fused_batch"),
6612 "shaders should contain silu_mul_fused_batch kernel"
6613 );
6614 assert!(
6615 shaders.contains("kernel void add_inplace_batch"),
6616 "shaders should contain add_inplace_batch kernel"
6617 );
6618 assert!(
6619 shaders.contains("kernel void copy_embedding_batch"),
6620 "shaders should contain copy_embedding_batch kernel"
6621 );
6622 }
6623
6624 #[test]
6625 fn generated_model_has_batch_pipelines() {
6626 let config = minimal_config();
6627 let model_rs = generate_model_rs(&config).unwrap();
6628
6629 for pipeline in &[
6630 "matmul_batch_pipeline",
6631 "matmul_q8_batch_pipeline",
6632 "matmul_q8_gemm_batch_pipeline",
6633 "matmul_q4_batch_pipeline",
6634 "rms_norm_batch_pipeline",
6635 "rope_batch_pipeline",
6636 "silu_mul_fused_batch_pipeline",
6637 "add_inplace_batch_pipeline",
6638 "copy_embedding_batch_pipeline",
6639 "attention_batch_pipeline",
6640 "copy_kv_batch_pipeline",
6641 ] {
6642 assert!(
6643 model_rs.contains(&format!("{pipeline}: ComputePipelineState")),
6644 "MetalModel should have {pipeline} field"
6645 );
6646 }
6647 }
6648
6649 #[test]
6650 fn generated_model_has_batch_buffers() {
6651 let config = minimal_config();
6652 let model_rs = generate_model_rs(&config).unwrap();
6653
6654 for buf in &[
6655 "batch_hidden_buf",
6656 "batch_residual_buf",
6657 "batch_qkv_buf",
6658 "batch_attn_out_buf",
6659 "batch_attn_proj_buf",
6660 "batch_gate_up_buf",
6661 "batch_ffn_hidden_buf",
6662 "batch_ffn_out_buf",
6663 "batch_tokens_buf",
6664 "batch_positions_buf",
6665 ] {
6666 assert!(
6667 model_rs.contains(&format!("{buf}: Buffer")),
6668 "MetalModel should have {buf} field"
6669 );
6670 }
6671 }
6672
6673 #[test]
6674 fn generated_model_has_forward_prefill_batch() {
6675 let config = minimal_config();
6676 let model_rs = generate_model_rs(&config).unwrap();
6677
6678 assert!(
6679 model_rs.contains("pub fn forward_prefill_batch(&mut self, tokens: &[u32])"),
6680 "MetalModel should have forward_prefill_batch method"
6681 );
6682
6683 assert!(
6685 model_rs.contains("self.forward_prefill_batch(&[token_id])"),
6686 "forward_prefill should delegate to forward_prefill_batch"
6687 );
6688 }
6689
6690 #[test]
6691 fn generated_model_has_max_batch_size_constant() {
6692 let config = minimal_config();
6693 let model_rs = generate_model_rs(&config).unwrap();
6694
6695 assert!(
6696 model_rs.contains("pub const MAX_BATCH_SIZE: usize = 512"),
6697 "model.rs should define MAX_BATCH_SIZE constant"
6698 );
6699 }
6700
6701 #[test]
6702 fn forward_prefill_batch_uses_batch_dispatch() {
6703 let config = minimal_config();
6704 let model_rs = generate_model_rs(&config).unwrap();
6705
6706 let batch_start = model_rs
6707 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6708 .unwrap();
6709 let batch_body = &model_rs[batch_start..];
6710 let batch_end = batch_body
6711 .find("\n pub fn reset")
6712 .unwrap_or(batch_body.len());
6713 let batch_code = &batch_body[..batch_end];
6714
6715 assert!(
6717 batch_code.contains("dispatch_rms_norm_batch"),
6718 "forward_prefill_batch should use dispatch_rms_norm_batch"
6719 );
6720 assert!(
6721 batch_code.contains("dispatch_copy_embedding_batch"),
6722 "forward_prefill_batch should use dispatch_copy_embedding_batch"
6723 );
6724 assert!(
6725 batch_code.contains("dispatch_silu_mul_fused_batch"),
6726 "forward_prefill_batch should use dispatch_silu_mul_fused_batch"
6727 );
6728 assert!(
6730 batch_code.contains("dispatch_attention_batch"),
6731 "forward_prefill_batch should use dispatch_attention_batch"
6732 );
6733 assert!(
6735 batch_code.contains("dispatch_copy_kv_both_batch"),
6736 "forward_prefill_batch should use dispatch_copy_kv_both_batch"
6737 );
6738 assert!(
6740 batch_code.contains("dispatch_rope_qk_batch"),
6741 "forward_prefill_batch should use dispatch_rope_qk_batch"
6742 );
6743 }
6744
6745 #[test]
6746 fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
6747 let config = minimal_q8_config();
6748 let model_rs = generate_model_rs(&config).unwrap();
6749
6750 let batch_start = model_rs
6751 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6752 .unwrap();
6753 let batch_body = &model_rs[batch_start..];
6754 let batch_end = batch_body
6755 .find("\n pub fn reset")
6756 .unwrap_or(batch_body.len());
6757 let batch_code = &batch_body[..batch_end];
6758
6759 assert!(
6760 batch_code.contains("dispatch_matmul_q8_batch"),
6761 "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
6762 );
6763 }
6764
6765 #[test]
6766 fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
6767 let config = minimal_q4_config();
6768 let model_rs = generate_model_rs(&config).unwrap();
6769
6770 let batch_start = model_rs
6771 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6772 .unwrap();
6773 let batch_body = &model_rs[batch_start..];
6774 let batch_end = batch_body
6775 .find("\n pub fn reset")
6776 .unwrap_or(batch_body.len());
6777 let batch_code = &batch_body[..batch_end];
6778
6779 assert!(
6780 batch_code.contains("dispatch_matmul_q4_batch"),
6781 "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
6782 );
6783 }
6784
6785 #[test]
6786 fn generated_main_rs_uses_batched_prefill() {
6787 let config = minimal_config();
6788 let main_rs = generate_main_rs("test-model", &config).unwrap();
6789
6790 assert!(
6791 main_rs.contains("forward_prefill_batch"),
6792 "main.rs should use forward_prefill_batch for prompt tokens"
6793 );
6794 }
6795
6796 #[test]
6797 fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
6798 let config = minimal_config();
6799 let model_rs = generate_model_rs(&config).unwrap();
6800
6801 let batch_start = model_rs
6802 .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
6803 .unwrap();
6804 let batch_body = &model_rs[batch_start..];
6805 let batch_end = batch_body
6806 .find("\n pub fn reset")
6807 .unwrap_or(batch_body.len());
6808 let batch_code = &batch_body[..batch_end];
6809
6810 assert!(
6811 batch_code.contains("dispatch_matmul_batch"),
6812 "f32 forward_prefill_batch should use dispatch_matmul_batch"
6813 );
6814 assert!(
6816 !batch_code.contains("dispatch_matmul_q8_batch"),
6817 "f32 forward_prefill_batch should not use dispatch_matmul_q8_batch"
6818 );
6819 assert!(
6820 !batch_code.contains("dispatch_matmul_q4_batch"),
6821 "f32 forward_prefill_batch should not use dispatch_matmul_q4_batch"
6822 );
6823 }
6824
6825 #[test]
6826 fn forward_uses_cpu_embedding_lookup() {
6827 let config = minimal_config();
6828 let model_rs = generate_model_rs(&config).unwrap();
6829
6830 let forward_start = model_rs
6832 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6833 .unwrap();
6834 let forward_body = &model_rs[forward_start..];
6835 let forward_end = forward_body
6836 .find("\n pub fn forward_profile")
6837 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
6838 .unwrap_or(forward_body.len());
6839 let forward_code = &forward_body[..forward_end];
6840
6841 assert!(
6843 forward_code.contains("embed_buf.contents()"),
6844 "forward() should access embed_buf via CPU unified memory for embedding lookup"
6845 );
6846 assert!(
6847 forward_code.contains("copy_nonoverlapping"),
6848 "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
6849 );
6850 assert!(
6852 !forward_code.contains("dispatch_copy_offset(&enc, &self.embed_buf"),
6853 "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
6854 );
6855 }
6856
6857 #[test]
6858 fn forward_profile_method_exists() {
6859 let config = minimal_config();
6860 let model_rs = generate_model_rs(&config).unwrap();
6861
6862 assert!(
6863 model_rs.contains("pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>"),
6864 "MetalModel should have forward_profile() method"
6865 );
6866 assert!(
6868 model_rs.contains("[profile]"),
6869 "forward_profile() should print timing with [profile] prefix"
6870 );
6871 assert!(
6872 model_rs.contains("d_embed"),
6873 "forward_profile() should measure embedding time"
6874 );
6875 assert!(
6876 model_rs.contains("d_layers"),
6877 "forward_profile() should measure layer time"
6878 );
6879 assert!(
6880 model_rs.contains("d_logits"),
6881 "forward_profile() should measure logits time"
6882 );
6883 }
6884
6885 #[test]
6886 fn generated_cli_has_profile_flag() {
6887 let config = minimal_config();
6888 let main_rs = generate_main_rs("test-model", &config).unwrap();
6889
6890 assert!(
6891 main_rs.contains("--profile"),
6892 "CLI should support --profile flag"
6893 );
6894 assert!(
6895 main_rs.contains("forward_profile"),
6896 "CLI should call forward_profile when --profile is set"
6897 );
6898 }
6899
6900 #[test]
6901 fn generated_cli_has_thermal_yield() {
6902 let config = minimal_config();
6903 let main_rs = generate_main_rs("test-model", &config).unwrap();
6904
6905 assert!(
6906 main_rs.contains("yield_now()"),
6907 "CLI generation loop should include thread::yield_now() for thermal management"
6908 );
6909 }
6910
6911 #[test]
6914 fn generated_forward_handles_single_token_prompt() {
6915 let config = minimal_config();
6919 let model_rs = generate_model_rs(&config).unwrap();
6920
6921 let forward_start = model_rs
6923 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
6924 .expect("forward() must exist");
6925 let forward_body = &model_rs[forward_start..forward_start + 400];
6926
6927 assert!(
6929 !forward_body.contains("assert!(self.pos > 0"),
6930 "forward() must accept pos=0 (first token with no prefill)"
6931 );
6932
6933 assert!(
6935 model_rs.contains("self.pos"),
6936 "forward() should use self.pos to track sequence position"
6937 );
6938 }
6939
6940 #[test]
6941 fn generated_reset_clears_kv_cache_position() {
6942 let config = minimal_config();
6945 let model_rs = generate_model_rs(&config).unwrap();
6946
6947 let reset_start = model_rs
6948 .find("pub fn reset(&mut self)")
6949 .expect("reset() must exist");
6950 let reset_body = &model_rs[reset_start..reset_start + 200];
6951
6952 assert!(
6954 reset_body.contains("self.pos = 0"),
6955 "reset() must set self.pos = 0"
6956 );
6957
6958 assert!(
6960 reset_body.contains("self.prev_cmd = None"),
6961 "reset() should clear prev_cmd for clean command buffer state"
6962 );
6963 }
6964
6965 #[test]
6966 fn generated_serve_handles_empty_messages_gracefully() {
6967 let config = minimal_config();
6970 let main_rs = generate_main_rs("test-model", &config).unwrap();
6971
6972 let format_fn_start = main_rs
6974 .find("fn format_chat_messages")
6975 .expect("format_chat_messages must exist");
6976 let format_fn_body =
6977 &main_rs[format_fn_start..format_fn_start + 500.min(main_rs.len() - format_fn_start)];
6978
6979 assert!(
6981 format_fn_body.contains("for msg in messages"),
6982 "format_chat_messages should iterate over the messages slice"
6983 );
6984 assert!(
6986 format_fn_body.contains("<|im_start|>assistant"),
6987 "format_chat_messages should always append assistant prompt header"
6988 );
6989
6990 let serve_fn_start = main_rs
6992 .find("fn serve(")
6993 .expect("serve function must exist");
6994 let serve_fn_body = &main_rs[serve_fn_start..];
6995 assert!(
6996 serve_fn_body.contains("model.reset()"),
6997 "serve function should reset model between requests"
6998 );
6999 }
7000
7001 #[test]
7002 fn generated_model_forward_increments_pos() {
7003 let config = minimal_config();
7006 let model_rs = generate_model_rs(&config).unwrap();
7007
7008 let forward_start = model_rs
7009 .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
7010 .unwrap();
7011 let forward_body = &model_rs[forward_start..];
7012 let forward_end = forward_body
7013 .find("\n pub fn forward_profile")
7014 .or_else(|| forward_body.find("\n pub fn forward_prefill"))
7015 .or_else(|| forward_body.find("\n fn dispatch_"))
7016 .unwrap_or(forward_body.len());
7017 let forward_code = &forward_body[..forward_end];
7018
7019 assert!(
7020 forward_code.contains("self.pos += 1") || forward_code.contains("self.pos +=1"),
7021 "forward() must increment self.pos after processing a token"
7022 );
7023 }
7024}