use std::fmt::Write as FmtWrite;
use std::fs;
use std::path::Path;
use forgellm_frontend::ir::*;
/// Errors returned by the Metal project generator.
///
/// The `#[from]` variants let `?` convert the underlying standard-library
/// errors into this type automatically.
#[derive(Debug, thiserror::Error)]
pub enum MetalCodegenError {
/// The input graph's `config` was `None`, so no model parameters are
/// available to generate code from.
#[error("graph has no model config")]
MissingConfig,
/// A filesystem operation failed; wraps the originating `std::io::Error`.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// A formatting operation failed; wraps the originating `std::fmt::Error`.
#[error("format error: {0}")]
Fmt(#[from] std::fmt::Error),
}
/// Writes a complete Metal inference crate for `graph` into `output_dir`.
///
/// Files are produced in this order:
/// 1. `Cargo.toml`
/// 2. `shaders/kernels.metal`
/// 3. `src/model.rs`
/// 4. `src/main.rs`
///
/// # Errors
///
/// Returns [`MetalCodegenError::MissingConfig`] when `graph.config` is `None`;
/// this is checked before anything touches the filesystem. Any failure from
/// directory creation, a file write, or a source generator is propagated, and
/// files written before the failure are left in place.
pub fn generate_metal_project(
    graph: &Graph,
    output_dir: &Path,
    model_name: &str,
) -> Result<(), MetalCodegenError> {
    let config = match graph.config.as_ref() {
        Some(cfg) => cfg,
        None => return Err(MetalCodegenError::MissingConfig),
    };

    // Both sub-directories are created up front; `create_dir_all` also
    // creates `output_dir` itself when it does not exist yet.
    let rust_sources = output_dir.join("src");
    let shaders = output_dir.join("shaders");
    fs::create_dir_all(&rust_sources)?;
    fs::create_dir_all(&shaders)?;

    let manifest_path = output_dir.join("Cargo.toml");
    fs::write(manifest_path, generate_cargo_toml(model_name))?;

    let kernels_path = shaders.join("kernels.metal");
    fs::write(kernels_path, generate_metal_shaders(config))?;

    // Each Rust source is generated first and written second, so a generator
    // failure never leaves a partially written file behind.
    let model_source = generate_model_rs(config)?;
    fs::write(rust_sources.join("model.rs"), model_source)?;

    let main_source = generate_main_rs(model_name, config)?;
    fs::write(rust_sources.join("main.rs"), main_source)?;

    Ok(())
}
/// Converts an arbitrary model name into a string usable as a Cargo package name.
///
/// The name is lowercased, every character that is neither alphanumeric nor
/// `-` is replaced by `-`, and leading/trailing dashes are stripped.
///
/// Cargo rejects package names that are empty or that start with a digit, so
/// two fallbacks keep the generated manifest buildable:
/// - a name with no usable characters (e.g. `""` or `"???"`) becomes `"model"`;
/// - a name whose first character is not a letter (e.g. `"7B"`) is prefixed
///   with `"model-"`.
///
/// Names that already start with a letter are returned exactly as before.
fn sanitize_name(name: &str) -> String {
    const FALLBACK: &str = "model";

    let dashed: String = name
        .to_lowercase()
        .chars()
        .map(|c| if c.is_alphanumeric() || c == '-' { c } else { '-' })
        .collect();
    let trimmed = dashed.trim_matches('-');

    match trimmed.chars().next() {
        // Nothing survived sanitization: an empty package name is invalid.
        None => FALLBACK.to_string(),
        Some(first) if first.is_alphabetic() => trimmed.to_string(),
        // After trimming, a non-letter first character can only be numeric.
        Some(_) => format!("{}-{}", FALLBACK, trimmed),
    }
}
/// Renders the `Cargo.toml` manifest for the generated project.
///
/// The package and its single binary target (`src/main.rs`) share one name:
/// the sanitized form of `model_name`.
fn generate_cargo_toml(model_name: &str) -> String {
    // The sanitized name is bound as a named format argument, so it is
    // computed once and substituted at both `{sanitized}` placeholders.
    format!(
        r#"[package]
name = "{sanitized}"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "{sanitized}"
path = "src/main.rs"
[dependencies]
metal = "0.29"
objc = "0.2"
half = "2"
tokenizers = {{ version = "0.21", default-features = false, features = ["onig"] }}
memmap2 = "0.9"
tiny_http = "0.12"
serde = {{ version = "1", features = ["derive"] }}
serde_json = "1"
[profile.release]
opt-level = 3
lto = "fat"
codegen-units = 1
"#,
        sanitized = sanitize_name(model_name),
    )
}
fn generate_metal_shaders(config: &ModelConfig) -> String {
let vec_tile_size = config.hidden_size.max(config.intermediate_size).min(8192);
let attn_scores_size = config.max_seq_len.min(4096);
r#"//
// Auto-generated by ForgeLLM Metal codegen.
// Metal Shading Language compute kernels for transformer inference.
//
// Optimized with simdgroup cooperative reductions, shared memory vector
// caching, float4 vectorized loads, multi-block Q8_0/Q4_0 processing per SIMD
// lane, and fast:: math intrinsics for Apple Silicon throughput.
//
#include <metal_stdlib>
#include <metal_simdgroup_matrix>
using namespace metal;
// ── Constants ───────────────────────────────────────────────────────────
// 8 simdgroups per threadgroup = 256 threads, each simdgroup handles 8 rows
// = 64 rows per threadgroup. 8-row register blocking doubles vector reuse
// per shared memory load vs 4-row, improving ILP and reducing launches.
constant constexpr uint SIMDGROUPS_PER_TG = 8;
constant constexpr uint ROWS_PER_SIMDGROUP = 8;
constant constexpr uint ROWS_PER_TG = SIMDGROUPS_PER_TG * ROWS_PER_SIMDGROUP; // 64
// ── matmul_vec ──────────────────────────────────────────────────────────
// Matrix-vector multiply: output[row] = dot(matrix[row, :], vector[:])
// Uses simdgroup cooperative dot product with shared memory vector caching
// and float4 vectorized loads. Each simdgroup processes 8 rows for better
// shared memory reuse (8x vector reuse per load) and instruction-level
// parallelism. 8 simdgroups x 8 rows = 64 rows per threadgroup.
kernel void matmul_vec(
device const float* matrix [[buffer(0)]],
device const float* vector [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& rows [[buffer(3)]],
constant uint& cols [[buffer(4)]],
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Cooperatively load vector into threadgroup shared memory
threadgroup float vec_tile[VEC_TILE_SIZE]; // sized to max(hidden, intermediate), capped at 8192 (32 KB TG mem)
for (uint i = tid; i < cols; i += 256) {
vec_tile[i] = vector[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each simdgroup handles 8 consecutive rows
uint row_base = tgid * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
if (row_base >= rows) return;
uint base0 = row_base * cols;
uint base1 = (row_base + 1) * cols;
uint base2 = (row_base + 2) * cols;
uint base3 = (row_base + 3) * cols;
uint base4 = (row_base + 4) * cols;
uint base5 = (row_base + 5) * cols;
uint base6 = (row_base + 6) * cols;
uint base7 = (row_base + 7) * cols;
// float4 vectorized accumulation across 8 rows
uint cols_vec4 = cols & ~127u; // largest multiple of 128 <= cols
float4 sum4_0 = float4(0.0f);
float4 sum4_1 = float4(0.0f);
float4 sum4_2 = float4(0.0f);
float4 sum4_3 = float4(0.0f);
float4 sum4_4 = float4(0.0f);
float4 sum4_5 = float4(0.0f);
float4 sum4_6 = float4(0.0f);
float4 sum4_7 = float4(0.0f);
for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
}
float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
// Handle remaining elements (cols not divisible by 128)
for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
float vv = vec_tile[j];
sum0 += matrix[base0 + j] * vv;
if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
}
// Simdgroup hardware warp-level reduction
sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
// Only first lane writes the results
if (simd_lane == 0) {
if (row_base < rows) output[row_base] = sum0;
if (row_base + 1 < rows) output[row_base + 1] = sum1;
if (row_base + 2 < rows) output[row_base + 2] = sum2;
if (row_base + 3 < rows) output[row_base + 3] = sum3;
if (row_base + 4 < rows) output[row_base + 4] = sum4;
if (row_base + 5 < rows) output[row_base + 5] = sum5;
if (row_base + 6 < rows) output[row_base + 6] = sum6;
if (row_base + 7 < rows) output[row_base + 7] = sum7;
}
}
// ── rms_norm ────────────────────────────────────────────────────────────
// RMS normalization: output[i] = input[i] * rsqrt(mean(input^2) + eps) * weight[i]
// Uses simdgroup reduction within each warp, then cross-simdgroup reduction
// via shared memory for minimal synchronization overhead.
kernel void rms_norm(
device const float* input [[buffer(0)]],
device const float* weight [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& n [[buffer(3)]],
constant float& eps [[buffer(4)]],
uint tid [[thread_index_in_threadgroup]])
{
// Each thread accumulates partial sum-of-squares
float sum_sq = 0.0f;
for (uint i = tid; i < n; i += 256) {
float v = input[i];
sum_sq += v * v;
}
// Simdgroup-level reduction (hardware warp sum)
sum_sq = simd_sum(sum_sq);
// Cross-simdgroup reduction via shared memory
threadgroup float shared[8];
uint simd_id = tid / 32;
uint simd_lane = tid % 32;
if (simd_lane == 0) {
shared[simd_id] = sum_sq;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// First thread computes final inverse RMS
if (tid == 0) {
float total = 0.0f;
for (uint i = 0; i < 8; i++) {
total += shared[i];
}
shared[0] = fast::rsqrt(total / float(n) + eps);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float inv_rms = shared[0];
// Normalize
for (uint i = tid; i < n; i += 256) {
output[i] = input[i] * inv_rms * weight[i];
}
}
// ── rope ────────────────────────────────────────────────────────────────
// Rotary Position Embedding applied in-place.
// Each thread handles one (head, pair) combination.
kernel void rope(
device float* data [[buffer(0)]],
constant uint& num_heads [[buffer(1)]],
constant uint& head_dim [[buffer(2)]],
constant uint& pos [[buffer(3)]],
constant float& theta [[buffer(4)]],
uint id [[thread_position_in_grid]])
{
uint half_dim = head_dim / 2;
uint total_pairs = num_heads * half_dim;
if (id >= total_pairs) return;
uint h = id / half_dim;
uint i = id % half_dim;
uint off = h * head_dim;
float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
float angle = float(pos) * freq;
float c = cos(angle);
float s = sin(angle);
float x0 = data[off + 2 * i];
float x1 = data[off + 2 * i + 1];
data[off + 2 * i] = x0 * c - x1 * s;
data[off + 2 * i + 1] = x0 * s + x1 * c;
}
// ── softmax ─────────────────────────────────────────────────────────────
// Numerically stable softmax over a 1-D array.
// Single-threadgroup kernel with cooperative reduction.
kernel void softmax(
device float* data [[buffer(0)]],
constant uint& n [[buffer(1)]],
uint tid [[thread_index_in_threadgroup]],
uint tg_size [[threads_per_threadgroup]])
{
threadgroup float shared_val[256];
// Pass 1: find max
float local_max = -INFINITY;
for (uint i = tid; i < n; i += tg_size) {
local_max = max(local_max, data[i]);
}
shared_val[tid] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
shared_val[tid] = max(shared_val[tid], shared_val[tid + stride]);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float max_val = shared_val[0];
threadgroup_barrier(mem_flags::mem_threadgroup);
// Pass 2: exp and sum
float local_sum = 0.0f;
for (uint i = tid; i < n; i += tg_size) {
float e = fast::exp(data[i] - max_val);
data[i] = e;
local_sum += e;
}
shared_val[tid] = local_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint stride = tg_size / 2; stride > 0; stride >>= 1) {
if (tid < stride) {
shared_val[tid] += shared_val[tid + stride];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float inv_sum = 1.0f / shared_val[0];
threadgroup_barrier(mem_flags::mem_threadgroup);
// Pass 3: normalize
for (uint i = tid; i < n; i += tg_size) {
data[i] *= inv_sum;
}
}
// ── silu_mul ────────────────────────────────────────────────────────────
// Fused SiLU activation * element-wise multiply:
// output[i] = (gate[i] / (1 + exp(-gate[i]))) * up[i]
kernel void silu_mul(
device const float* gate [[buffer(0)]],
device const float* up [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& n [[buffer(3)]],
uint id [[thread_position_in_grid]])
{
if (id >= n) return;
float g = gate[id];
output[id] = (g / (1.0f + fast::exp(-g))) * up[id];
}
// ── silu_mul_fused ─────────────────────────────────────────────────────
// Fused SiLU-multiply reading gate and up from a single concatenated buffer:
// gate = gate_up[0..n], up = gate_up[n..2*n]
// output[i] = silu(gate_up[i]) * gate_up[n + i]
kernel void silu_mul_fused(
device const float* gate_up [[buffer(0)]],
device float* output [[buffer(1)]],
constant uint& n [[buffer(2)]],
uint id [[thread_position_in_grid]])
{
if (id >= n) return;
float g = gate_up[id];
float u = gate_up[n + id];
output[id] = (g / (1.0f + fast::exp(-g))) * u;
}
// ── elementwise_add ─────────────────────────────────────────────────────
// Residual connection: output[i] = a[i] + b[i]
kernel void elementwise_add(
device const float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
device float* output [[buffer(2)]],
constant uint& n [[buffer(3)]],
uint id [[thread_position_in_grid]])
{
if (id >= n) return;
output[id] = a[id] + b[id];
}
// ── copy_buffer ─────────────────────────────────────────────────────────
// Simple buffer-to-buffer copy via compute kernel, avoiding blit encoder
// transitions. Used for KV cache updates and embedding lookup.
kernel void copy_buffer(
device const float* src [[buffer(0)]],
device float* dst [[buffer(1)]],
constant uint& count [[buffer(2)]],
uint id [[thread_position_in_grid]])
{
if (id < count) dst[id] = src[id];
}
// ── copy_offset ─────────────────────────────────────────────────────────
// Copy with source offset (in floats). Used for embedding table lookup
// where we need to copy a specific row from a large table.
kernel void copy_offset(
device const float* src [[buffer(0)]],
device float* dst [[buffer(1)]],
constant uint& src_offset [[buffer(2)]], // in floats
constant uint& count [[buffer(3)]],
uint id [[thread_position_in_grid]])
{
if (id < count) dst[id] = src[src_offset + id];
}
// ── copy_f32_to_f16_offset ──────────────────────────────────────────────
// Copy f32 elements from src into a half-typed dst, converting to half on
// write. Used by the single-token decode path to append a new K/V vector
// to the f16 KV cache. Byte offsets into src/dst are supplied via the
// Metal buffer binding offsets — no in-kernel offsets needed.
kernel void copy_f32_to_f16_offset(
device const float* src [[buffer(0)]],
device half* dst [[buffer(1)]],
constant uint& count [[buffer(2)]],
uint id [[thread_position_in_grid]])
{
if (id < count) dst[id] = half(src[id]);
}
// ── add_inplace ─────────────────────────────────────────────────────────
// In-place residual connection: a[i] += b[i]
// Avoids a separate blit copy for residual add, reducing encoder overhead.
kernel void add_inplace(
device float* a [[buffer(0)]],
device const float* b [[buffer(1)]],
constant uint& n [[buffer(2)]],
uint id [[thread_position_in_grid]])
{
if (id >= n) return;
a[id] += b[id];
}
// ── matmul_vec_q8 ─────────────────────────────────────────────────────
// Matrix-vector multiply where the matrix is stored as Q8_0 blocks.
// Q8_0 block: 2 bytes f16 scale + 32 bytes int8 data = 34 bytes per 32 elements.
// Operates directly on quantized weights to halve memory bandwidth vs f32,
// yielding ~1.5-2x speedup on bandwidth-bound GPU matmul.
//
// Register-pressure-optimised: 4 rows per simdgroup (vs 8 for f32 matmul)
// because int8->float conversion doubles register demand. Fully unrolled
// inner loop with float4 vector loads from shared memory eliminates loop
// overhead and enables better instruction scheduling.
// 8 simdgroups x 4 rows = 32 rows per threadgroup of 256 threads.
constant constexpr uint Q8_ROWS_PER_SG = 4;
constant constexpr uint Q8_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q8_ROWS_PER_SG; // 32
// Q4_0 uses the same 4-row-per-simdgroup layout as Q8_0 (nibble unpacking
// doubles ALU work, so the same register budget applies).
constant constexpr uint Q4_ROWS_PER_SG = 4;
constant constexpr uint Q4_ROWS_PER_TG = SIMDGROUPS_PER_TG * Q4_ROWS_PER_SG; // 32
kernel void matmul_vec_q8(
device const uchar* matrix [[buffer(0)]], // Q8_0 raw bytes
device const float* vector [[buffer(1)]], // f32 input
device float* output [[buffer(2)]],
constant uint& rows [[buffer(3)]],
constant uint& cols [[buffer(4)]], // number of elements per row
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Load vector into shared memory
threadgroup float vec_tile[VEC_TILE_SIZE];
for (uint i = tid; i < cols; i += 256) {
vec_tile[i] = vector[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each simdgroup handles 4 consecutive rows (lower register pressure)
uint row_base = tgid * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
if (row_base >= rows) return;
// Q8_0: each block is 34 bytes for 32 elements
uint blocks_per_row = cols / 32;
uint row_bytes = blocks_per_row * 34;
// Pointers to each row's Q8_0 data
device const uchar* r0 = matrix + row_base * row_bytes;
device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
uint bb = blk * 34;
uint vb = blk * 32;
// Prefetch all 4 scales
float sc0 = float(*(device const half*)(r0 + bb));
float sc1 = float(*(device const half*)(r1 + bb));
float sc2 = float(*(device const half*)(r2 + bb));
float sc3 = float(*(device const half*)(r3 + bb));
// Wide 64-bit loads via packed_short4 (2-byte aligned — matches the
// Q8_0 block layout where the int8 data starts at offset +2 from a
// 34-byte block boundary). Each packed_short4 covers 8 int8 weights,
// so 4 loads per row per block vs the previous 8 char4 loads — a 2x
// reduction in memory transactions. Metal's char16/packed_char16 are
// reserved types and packed_*int4 require >=4-byte alignment which
// this layout does not provide, so packed_short4 is the widest valid
// vectorized load.
device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);
// Load all 8 float4 vector values for this 32-element block from shared memory
float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
// Helper: expand a packed_short4 into a float4 pair covering 8 int8 weights.
// char2(as_type<char2>(s)) yields (low_byte, high_byte) on little-endian.
#define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
short4 _s = short4(SHORT4); \
char2 _a = as_type<char2>(_s.x); \
char2 _b = as_type<char2>(_s.y); \
char2 _c = as_type<char2>(_s.z); \
char2 _d = as_type<char2>(_s.w); \
(OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
(OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
}
float4 f0, f1;
float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;
// Row 0: 4 short4 loads cover 32 int8 weights
Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);
// Row 1
Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);
// Row 2
Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);
// Row 3
Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
#undef Q8_UNPACK8
sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
}
sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
if (simd_lane == 0) {
if (row_base < rows) output[row_base] = sum0;
if (row_base + 1 < rows) output[row_base + 1] = sum1;
if (row_base + 2 < rows) output[row_base + 2] = sum2;
if (row_base + 3 < rows) output[row_base + 3] = sum3;
}
}
// ── matmul_vec_q4 ─────────────────────────────────────────────────────
// Matrix-vector multiply where the matrix is stored as Q4_0 blocks.
// Q4_0 block: 2 bytes f16 scale + 16 packed bytes (32 4-bit values) = 18 bytes per 32 elements.
// Each packed byte holds two 4-bit unsigned values; subtract 8 to get signed.
// Low nibble (& 0x0F) - 8 → element[i], high nibble (>> 4) - 8 → element[i+16].
//
// Same threadgroup geometry as Q8_0: 4 rows per simdgroup, 32 rows per TG.
// Inner loop fully unrolled with uchar4 loads and float4 vector reads.
kernel void matmul_vec_q4(
device const uchar* matrix [[buffer(0)]], // Q4_0 raw bytes
device const float* vector [[buffer(1)]], // f32 input
device float* output [[buffer(2)]],
constant uint& rows [[buffer(3)]],
constant uint& cols [[buffer(4)]], // number of elements per row
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Load vector into shared memory
threadgroup float vec_tile[VEC_TILE_SIZE];
for (uint i = tid; i < cols; i += 256) {
vec_tile[i] = vector[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Each simdgroup handles 4 consecutive rows
uint row_base = tgid * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
if (row_base >= rows) return;
// Q4_0: each block is 18 bytes for 32 elements
uint blocks_per_row = cols / 32;
uint row_bytes = blocks_per_row * 18;
// Pointers to each row's Q4_0 data
device const uchar* r0 = matrix + row_base * row_bytes;
device const uchar* r1 = matrix + (row_base + 1) * row_bytes;
device const uchar* r2 = matrix + (row_base + 2) * row_bytes;
device const uchar* r3 = matrix + (row_base + 3) * row_bytes;
float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
uint bb = blk * 18;
uint vb = blk * 32;
// Prefetch all 4 scales
float sc0 = float(*(device const half*)(r0 + bb));
float sc1 = float(*(device const half*)(r1 + bb));
float sc2 = float(*(device const half*)(r2 + bb));
float sc3 = float(*(device const half*)(r3 + bb));
// Packed byte pointers (16 bytes = 32 nibbles = 32 elements)
device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
// Load 8 float4 vector values for 32 elements from shared memory
// Low nibble elements: indices [0..15], High nibble elements: indices [16..31]
float4 v0 = *(threadgroup const float4*)(vec_tile + vb); // [0..3]
float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4); // [4..7]
float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8); // [8..11]
float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12); // [12..15]
float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16); // [16..19]
float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20); // [20..23]
float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24); // [24..27]
float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28); // [28..31]
// Fully unrolled block dot products — 4 rows x 4 uchar4 reads
// Each uchar4 has 4 packed bytes; low nibble → elem[j], high nibble → elem[j+16]
float bd0=0, bd1=0, bd2=0, bd3=0;
uchar4 b;
// Row 0: p0[0]→v0/v4, p0[1]→v1/v5, p0[2]→v2/v6, p0[3]→v3/v7
b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
// Row 1
b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
// Row 2
b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
// Row 3
b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x
+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y
+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z
+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x
+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y
+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z
+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x
+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y
+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z
+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x
+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y
+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z
+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
}
sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
if (simd_lane == 0) {
if (row_base < rows) output[row_base] = sum0;
if (row_base + 1 < rows) output[row_base + 1] = sum1;
if (row_base + 2 < rows) output[row_base + 2] = sum2;
if (row_base + 3 < rows) output[row_base + 3] = sum3;
}
}
// ── attention ───────────────────────────────────────────────────────────
// Single-query attention with simdgroup cooperative reductions.
// Computes Q*K^T scores using 32-lane simd dot products, applies softmax
// with simd_max/simd_sum reductions, then weighted sum of V.
// Each threadgroup handles one head with 256 threads (8 simdgroups).
//
// Buffers:
// q: [num_heads * head_dim] current query
// k_cache: [max_seq_len * num_kv_heads * head_dim]
// v_cache: [max_seq_len * num_kv_heads * head_dim]
// output: [num_heads * head_dim]
kernel void attention(
device const float* q [[buffer(0)]],
device const half* k_cache [[buffer(1)]],
device const half* v_cache [[buffer(2)]],
device float* output [[buffer(3)]],
constant uint& seq_len [[buffer(4)]],
constant uint& num_heads [[buffer(5)]],
constant uint& num_kv_heads [[buffer(6)]],
constant uint& head_dim [[buffer(7)]],
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
uint head = tgid;
if (head >= num_heads) return;
uint kv_head = head / (num_heads / num_kv_heads);
uint q_off = head * head_dim;
// Step 1: Compute attention scores Q·K^T with simdgroup reduction
// Use shared memory for scores — 2048 entries (8 KB) saves TG memory
// vs 4096. For seq_len > 2048, generation-phase attention is rare;
// most generation steps have short effective context.
threadgroup float scores[ATTN_SCORES_SIZE]; // max seq_len for generation phase (matches MAX_SEQ_LEN cap)
// Q·K^T with half4/float4 vectorized loads.
// Each simdgroup handles one s; 32 lanes cover head_dim in chunks of 4.
// For head_dim=128, every lane does exactly one half4 load (no loop).
// For head_dim=64, 16 lanes active; for head_dim=96, 24 lanes.
uint head_dim4 = head_dim / 4;
for (uint s = simd_id; s < seq_len; s += 8) {
uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
float dot = 0.0;
for (uint d4 = simd_lane; d4 < head_dim4; d4 += 32) {
uint d = d4 * 4;
float4 q4 = *(device const float4*)(q + q_off + d);
float4 k4 = float4(*(device const half4*)(k_cache + k_off + d));
dot += q4.x * k4.x + q4.y * k4.y + q4.z * k4.z + q4.w * k4.w;
}
// Scalar fallback for head_dim not divisible by 4 (unused for all
// current models: head_dim ∈ {64, 96, 128} are all multiples of 4).
for (uint d = head_dim4 * 4 + simd_lane; d < head_dim; d += 32) {
dot += q[q_off + d] * float(k_cache[k_off + d]);
}
dot = simd_sum(dot);
if (simd_lane == 0) {
scores[s] = dot * fast::rsqrt(float(head_dim));
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Step 2: Softmax over scores (cooperative)
// Find max
float local_max = -INFINITY;
for (uint s = tid; s < seq_len; s += 256) {
local_max = max(local_max, scores[s]);
}
local_max = simd_max(local_max);
threadgroup float shared_max[8];
if (simd_lane == 0) shared_max[simd_id] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float m = shared_max[0];
for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
shared_max[0] = m;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float max_val = shared_max[0];
// Exp and sum
float local_sum = 0.0;
for (uint s = tid; s < seq_len; s += 256) {
scores[s] = fast::exp(scores[s] - max_val);
local_sum += scores[s];
}
local_sum = simd_sum(local_sum);
threadgroup float shared_sum[8];
if (simd_lane == 0) shared_sum[simd_id] = local_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0;
for (uint i = 0; i < 8; i++) total += shared_sum[i];
shared_sum[0] = 1.0 / total;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float inv_sum = shared_sum[0];
for (uint s = tid; s < seq_len; s += 256) {
scores[s] *= inv_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Step 3: Weighted sum of V: output = scores · V, half4 vectorized.
// Mirrors attention_batch prefill: each thread handles 4 d-dims via
// half4 loads, with scalar fallback for head_dim not divisible by 4.
uint seq_len4 = seq_len & ~3u; // largest multiple of 4 <= seq_len
uint v_stride = num_kv_heads * head_dim;
for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
uint d = d4 * 4;
float4 acc = float4(0.0);
uint v_base = kv_head * head_dim + d;
for (uint s = 0; s < seq_len4; s += 4) {
float sc0 = scores[s];
float sc1 = scores[s + 1];
float sc2 = scores[s + 2];
float sc3 = scores[s + 3];
acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
+ sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
+ sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
+ sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
}
for (uint s = seq_len4; s < seq_len; s++) {
acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
}
*(device float4*)(output + q_off + d) = acc;
}
// Scalar fallback for remaining dimensions (unused for current models)
for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
float acc = 0.0;
uint v_base = kv_head * head_dim + d;
for (uint s = 0; s < seq_len; s++) {
acc += scores[s] * float(v_cache[s * v_stride + v_base]);
}
output[q_off + d] = acc;
}
}
// ── Batched prefill kernels ────────────────────────────────────────────
// These kernels process M input vectors against the same weight matrix
// in a single dispatch, converting mat-vec into mat-mat for better GPU
// utilization during prompt prefill.
// ── rms_norm_batch ─────────────────────────────────────────────────────
// Batched RMS normalization. For every token row t:
//   output[t, i] = input[t, i] * rsqrt(mean(input[t, :]^2) + eps) * weight[i]
// One threadgroup per token (Grid: M threadgroups). The loop strides and the
// 8-slot partial-sum buffer are written for 256 threads, i.e. 8 simdgroups
// of 32 lanes.
kernel void rms_norm_batch(
    device const float* input       [[buffer(0)]],  // [M, n]
    device const float* weight      [[buffer(1)]],  // [n]
    device float* output            [[buffer(2)]],  // [M, n]
    constant uint& n                [[buffer(3)]],
    constant float& eps             [[buffer(4)]],
    constant uint& num_tokens       [[buffer(5)]],
    uint tgid                       [[threadgroup_position_in_grid]],
    uint tid                        [[thread_index_in_threadgroup]])
{
    // tgid is the same for every thread of the group, so the whole group
    // leaves together and no thread is left waiting at a barrier.
    if (tgid >= num_tokens) return;

    const uint row_off = tgid * n;

    // Pass 1: per-thread partial sum of squares, then reduce inside the simdgroup.
    float partial = 0.0f;
    for (uint col = tid; col < n; col += 256) {
        float x = input[row_off + col];
        partial += x * x;
    }
    partial = simd_sum(partial);

    // Lane 0 of each simdgroup publishes its partial sum.
    threadgroup float sg_partials[8];
    const uint sg_index = tid / 32;
    const uint lane = tid % 32;
    if (lane == 0) {
        sg_partials[sg_index] = partial;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Thread 0 folds the 8 partials and broadcasts 1/rms through slot 0.
    if (tid == 0) {
        float sum_sq = 0.0f;
        for (uint g = 0; g < 8; g++) {
            sum_sq += sg_partials[g];
        }
        sg_partials[0] = fast::rsqrt(sum_sq / float(n) + eps);
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Pass 2: scale the row by 1/rms and the per-column weight.
    const float inv_rms = sg_partials[0];
    for (uint col = tid; col < n; col += 256) {
        output[row_off + col] = input[row_off + col] * inv_rms * weight[col];
    }
}
// ── rope_batch ─────────────────────────────────────────────────────────
// In-place Rotary Position Embedding over a batch of token vectors, each
// token carrying its own position.
// data layout: [M, num_heads * head_dim], positions: [M].
// One thread per (token, head, pair); pair i rotates the adjacent elements
// (2i, 2i+1) of a head by angle = position * theta^(-2i / head_dim).
kernel void rope_batch(
    device float* data              [[buffer(0)]],  // [M, num_heads * head_dim]
    constant uint& num_heads        [[buffer(1)]],
    constant uint& head_dim         [[buffer(2)]],
    device const uint* positions    [[buffer(3)]],  // [M] position per token
    constant float& theta           [[buffer(4)]],
    constant uint& num_tokens       [[buffer(5)]],
    uint id                         [[thread_position_in_grid]])
{
    const uint pairs_per_head = head_dim / 2;
    const uint pairs_per_token = num_heads * pairs_per_head;
    if (id >= num_tokens * pairs_per_token) return;

    // Decompose the flat thread id into (token, head, pair).
    const uint token = id / pairs_per_token;
    const uint in_token = id % pairs_per_token;
    const uint head = in_token / pairs_per_head;
    const uint pair = in_token % pairs_per_head;

    // Indices of the two elements this thread rotates.
    const uint head_off = token * (num_heads * head_dim) + head * head_dim;
    const uint idx_even = head_off + 2 * pair;
    const uint idx_odd = head_off + 2 * pair + 1;

    const float inv_freq = 1.0f / pow(theta, 2.0f * float(pair) / float(head_dim));
    const float phi = float(positions[token]) * inv_freq;
    const float cos_phi = cos(phi);
    const float sin_phi = sin(phi);

    // Read both inputs before writing either: the rotation is in place.
    const float even_val = data[idx_even];
    const float odd_val = data[idx_odd];
    data[idx_even] = even_val * cos_phi - odd_val * sin_phi;
    data[idx_odd] = even_val * sin_phi + odd_val * cos_phi;
}
// ── silu_mul_fused_batch ───────────────────────────────────────────────
// Fused SiLU-gate multiply over a batch. Each gate_up row holds n gate
// values followed by n up values (layout [M, 2*n]):
//   output[t, i] = silu(gate_up[t, i]) * gate_up[t, n + i]
// with silu(g) = g / (1 + exp(-g)). One thread per output element.
kernel void silu_mul_fused_batch(
    device const float* gate_up     [[buffer(0)]],  // [M, 2*n]
    device float* output            [[buffer(1)]],  // [M, n]
    constant uint& n                [[buffer(2)]],
    constant uint& num_tokens       [[buffer(3)]],
    uint id                         [[thread_position_in_grid]])
{
    if (id >= num_tokens * n) return;

    const uint token = id / n;
    const uint col = id % n;
    const uint row_off = token * 2 * n;

    const float gate = gate_up[row_off + col];
    const float up = gate_up[row_off + n + col];
    const float activated = gate / (1.0f + fast::exp(-gate));

    // token * n + col is exactly the flat thread id.
    output[id] = activated * up;
}
// ── add_inplace_batch ──────────────────────────────────────────────────
// Residual add over a flattened batch: a[i] = a[i] + b[i] for i in [0, total),
// where the caller passes total = M * n. One thread per element; threads
// beyond `total` do nothing.
kernel void add_inplace_batch(
    device float* a                 [[buffer(0)]],  // [M * n], updated in place
    device const float* b           [[buffer(1)]],  // [M * n]
    constant uint& total            [[buffer(2)]],  // M * n
    uint id                         [[thread_position_in_grid]])
{
    if (id < total) {
        a[id] = a[id] + b[id];
    }
}
// ── copy_embedding_batch ───────────────────────────────────────────────
// Gather M rows of the embedding table into a contiguous [M, dim] buffer.
// tokens[t] selects the source row for output row t. One thread per copied
// float. Token ids are not range-checked here (the vocabulary size is not
// passed to the kernel).
kernel void copy_embedding_batch(
    device const float* embed       [[buffer(0)]],  // [vocab_size, dim]
    device float* output            [[buffer(1)]],  // [M, dim]
    device const uint* tokens       [[buffer(2)]],  // [M]
    constant uint& dim              [[buffer(3)]],
    constant uint& num_tokens       [[buffer(4)]],
    uint id                         [[thread_position_in_grid]])
{
    if (id >= num_tokens * dim) return;

    const uint out_row = id / dim;
    const uint col = id % dim;
    const uint src_row = tokens[out_row];

    output[id] = embed[src_row * dim + col];
}
// ── matmul_vec_batch ───────────────────────────────────────────────────
// Batched matrix-vector multiply: process M input vectors against the same
// weight matrix. Grid: ceil(rows/ROWS_PER_TG) * M threadgroups.
// Each threadgroup handles one (token, row_group) pair.
//
// outputs[token, r] = dot(matrix[r, :], inputs[token, :]).
// Within a threadgroup every simdgroup accumulates 8 consecutive rows
// (base0..base7) and reduces each with simd_sum; lane 0 writes the results.
// The tile-load stride below is written for 256 threads per threadgroup.
// NOTE(review): the 8-row unroll presumably matches ROWS_PER_SIMDGROUP == 8
// (defined elsewhere in this file) — confirm.
kernel void matmul_vec_batch(
device const float* matrix [[buffer(0)]], // [rows, cols] weight
device const float* inputs [[buffer(1)]], // [M, cols] input batch
device float* outputs [[buffer(2)]], // [M, rows] output batch
constant uint& num_tokens [[buffer(3)]], // M
constant uint& rows [[buffer(4)]],
constant uint& cols [[buffer(5)]],
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Flat threadgroup id -> (token, row group within that token).
uint row_tgs = (rows + ROWS_PER_TG - 1) / ROWS_PER_TG;
uint token = tgid / row_tgs;
uint tg_in_token = tgid % row_tgs;
// token depends only on tgid, so the whole threadgroup exits together and
// nobody is left waiting at the barrier below.
if (token >= num_tokens) return;
// Load this token's input vector into shared memory
// NOTE(review): the loop writes `cols` entries into a VEC_TILE_SIZE-element
// array; this assumes cols <= VEC_TILE_SIZE — confirm the dispatcher
// guarantees it for every weight matrix routed to this kernel.
threadgroup float vec_tile[VEC_TILE_SIZE];
device const float* input = inputs + token * cols;
for (uint i = tid; i < cols; i += 256) {
vec_tile[i] = input[i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// First row owned by this simdgroup. This early return is after the only
// barrier in the kernel, so it is safe for a single simdgroup to leave.
uint row_base = tg_in_token * ROWS_PER_TG + simd_id * ROWS_PER_SIMDGROUP;
if (row_base >= rows) return;
// Element offsets of the 8 rows. Offsets for rows >= `rows` are computed
// but only dereferenced behind the `row_base + k < rows` guards below.
uint base0 = row_base * cols;
uint base1 = (row_base + 1) * cols;
uint base2 = (row_base + 2) * cols;
uint base3 = (row_base + 3) * cols;
uint base4 = (row_base + 4) * cols;
uint base5 = (row_base + 5) * cols;
uint base6 = (row_base + 6) * cols;
uint base7 = (row_base + 7) * cols;
// cols rounded down to a multiple of 128 = 32 lanes * 4 floats: the span
// the float4 loop covers. The scalar loop further down handles the rest.
uint cols_vec4 = cols & ~127u;
float4 sum4_0 = float4(0.0f);
float4 sum4_1 = float4(0.0f);
float4 sum4_2 = float4(0.0f);
float4 sum4_3 = float4(0.0f);
float4 sum4_4 = float4(0.0f);
float4 sum4_5 = float4(0.0f);
float4 sum4_6 = float4(0.0f);
float4 sum4_7 = float4(0.0f);
// Vectorized main loop: each lane takes one float4 column chunk per step.
// Row 0 needs no guard because row_base < rows was established above.
for (uint j = simd_lane * 4; j < cols_vec4; j += 128) {
float4 v = *reinterpret_cast<threadgroup const float4*>(vec_tile + j);
sum4_0 += *reinterpret_cast<device const float4*>(matrix + base0 + j) * v;
if (row_base + 1 < rows) sum4_1 += *reinterpret_cast<device const float4*>(matrix + base1 + j) * v;
if (row_base + 2 < rows) sum4_2 += *reinterpret_cast<device const float4*>(matrix + base2 + j) * v;
if (row_base + 3 < rows) sum4_3 += *reinterpret_cast<device const float4*>(matrix + base3 + j) * v;
if (row_base + 4 < rows) sum4_4 += *reinterpret_cast<device const float4*>(matrix + base4 + j) * v;
if (row_base + 5 < rows) sum4_5 += *reinterpret_cast<device const float4*>(matrix + base5 + j) * v;
if (row_base + 6 < rows) sum4_6 += *reinterpret_cast<device const float4*>(matrix + base6 + j) * v;
if (row_base + 7 < rows) sum4_7 += *reinterpret_cast<device const float4*>(matrix + base7 + j) * v;
}
// Collapse each float4 accumulator into a per-lane scalar.
float sum0 = sum4_0.x + sum4_0.y + sum4_0.z + sum4_0.w;
float sum1 = sum4_1.x + sum4_1.y + sum4_1.z + sum4_1.w;
float sum2 = sum4_2.x + sum4_2.y + sum4_2.z + sum4_2.w;
float sum3 = sum4_3.x + sum4_3.y + sum4_3.z + sum4_3.w;
float sum4 = sum4_4.x + sum4_4.y + sum4_4.z + sum4_4.w;
float sum5 = sum4_5.x + sum4_5.y + sum4_5.z + sum4_5.w;
float sum6 = sum4_6.x + sum4_6.y + sum4_6.z + sum4_6.w;
float sum7 = sum4_7.x + sum4_7.y + sum4_7.z + sum4_7.w;
// Scalar tail for the columns past cols_vec4, one column per lane per step.
for (uint j = cols_vec4 + simd_lane; j < cols; j += 32) {
float vv = vec_tile[j];
sum0 += matrix[base0 + j] * vv;
if (row_base + 1 < rows) sum1 += matrix[base1 + j] * vv;
if (row_base + 2 < rows) sum2 += matrix[base2 + j] * vv;
if (row_base + 3 < rows) sum3 += matrix[base3 + j] * vv;
if (row_base + 4 < rows) sum4 += matrix[base4 + j] * vv;
if (row_base + 5 < rows) sum5 += matrix[base5 + j] * vv;
if (row_base + 6 < rows) sum6 += matrix[base6 + j] * vv;
if (row_base + 7 < rows) sum7 += matrix[base7 + j] * vv;
}
// Cross-lane reduction: afterwards every lane holds the full dot product.
sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);
sum4 = simd_sum(sum4); sum5 = simd_sum(sum5);
sum6 = simd_sum(sum6); sum7 = simd_sum(sum7);
// Lane 0 writes this token's results; rows past the matrix end are skipped.
device float* output = outputs + token * rows;
if (simd_lane == 0) {
if (row_base < rows) output[row_base] = sum0;
if (row_base + 1 < rows) output[row_base + 1] = sum1;
if (row_base + 2 < rows) output[row_base + 2] = sum2;
if (row_base + 3 < rows) output[row_base + 3] = sum3;
if (row_base + 4 < rows) output[row_base + 4] = sum4;
if (row_base + 5 < rows) output[row_base + 5] = sum5;
if (row_base + 6 < rows) output[row_base + 6] = sum6;
if (row_base + 7 < rows) output[row_base + 7] = sum7;
}
}
// ── matmul_vec_q8_batch ────────────────────────────────────────────────
// Batched Q8_0 matrix-vector multiply for M input vectors:
//   outputs[token, r] = dot(dequant(matrix[r, :]), inputs[token, :])
// Grid: ceil(rows/Q8_ROWS_PER_TG) * M threadgroups; each threadgroup handles
// one (token, row_group) pair and each simdgroup reduces 4 consecutive rows.
// The tile-load stride is written for 256 threads per threadgroup.
//
// Weight layout as read here: a row is cols/32 blocks of 34 bytes, each
// block being one half scale followed by 32 signed 8-bit quants. Columns
// beyond the last full block of 32 are ignored (cols / 32 truncates).
//
// Rows past the end of the matrix (possible when rows is not a multiple of
// 4) are clamped to the last valid row for loading, so the kernel never
// reads outside the weight buffer; their sums are dropped by the guarded
// stores at the end.
// NOTE(review): the tile load writes `cols` floats into a VEC_TILE_SIZE
// array; this assumes cols <= VEC_TILE_SIZE — confirm in the dispatcher.
kernel void matmul_vec_q8_batch(
    device const uchar* matrix      [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs      [[buffer(1)]],  // [M, cols] input batch
    device float* outputs           [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens       [[buffer(3)]],  // M
    constant uint& rows             [[buffer(4)]],
    constant uint& cols             [[buffer(5)]],
    uint tgid                       [[threadgroup_position_in_grid]],
    uint tid                        [[thread_index_in_threadgroup]],
    uint simd_lane                  [[thread_index_in_simdgroup]],
    uint simd_id                    [[simdgroup_index_in_threadgroup]])
{
    // Flat threadgroup id -> (token, row group within that token).
    uint row_tgs = (rows + Q8_ROWS_PER_TG - 1) / Q8_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    // Uniform across the threadgroup, so returning before the barrier is safe.
    if (token >= num_tokens) return;

    // Load this token's input vector into shared memory
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // First row owned by this simdgroup (after the only barrier, so a
    // single simdgroup may leave here).
    uint row_base = tg_in_token * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Clamp the three trailing rows to the last valid row so that loads stay
    // inside the weight buffer. rows >= 1 here because row_base < rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;

    // Each lane takes every 32nd Q8_0 block of the row.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of the block inside a weight row
        uint vb = blk * 32;   // element offset of the block inside the input vector

        // Per-block scales.
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        // Wide 64-bit loads via packed_short4 (2-byte aligned): 4 loads per
        // row per block vs 8 char4 loads — 2x reduction in memory transactions.
        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // The 32 input values matching this block, shared by all four rows.
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);

        // Split one packed_short4 (8 bytes) into 8 signed quants as two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        float4 f0, f1;
        float bd0 = 0.0, bd1 = 0.0, bd2 = 0.0, bd3 = 0.0;

        Q8_UNPACK8(d0[0], f0, f1); bd0 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d0[1], f0, f1); bd0 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d0[2], f0, f1); bd0 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d0[3], f0, f1); bd0 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d1[0], f0, f1); bd1 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d1[1], f0, f1); bd1 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d1[2], f0, f1); bd1 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d1[3], f0, f1); bd1 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d2[0], f0, f1); bd2 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d2[1], f0, f1); bd2 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d2[2], f0, f1); bd2 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d2[3], f0, f1); bd2 += dot(f0, v6) + dot(f1, v7);

        Q8_UNPACK8(d3[0], f0, f1); bd3 += dot(f0, v0) + dot(f1, v1);
        Q8_UNPACK8(d3[1], f0, f1); bd3 += dot(f0, v2) + dot(f1, v3);
        Q8_UNPACK8(d3[2], f0, f1); bd3 += dot(f0, v4) + dot(f1, v5);
        Q8_UNPACK8(d3[3], f0, f1); bd3 += dot(f0, v6) + dot(f1, v7);
        #undef Q8_UNPACK8

        // Apply each row's block scale once per block.
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Cross-lane reduction: afterwards every lane holds the full dot product.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes this token's results; clamped duplicate rows are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base     < rows) output[row_base]     = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
// ── matmul_q8_gemm_batch ───────────────────────────────────────────────
// True GEMM-style Q8_0 kernel that reuses weight reads across a token tile.
// Each threadgroup covers 32 rows and TOKENS_PER_TG_Q8 consecutive tokens, so
// the Q8_0 weight blocks are fetched once from device memory and reused for
// every token in the tile (1/TOKENS_PER_TG_Q8 the weight bandwidth of the
// per-token dispatch).
//
// Grid: (ceil(rows/32), ceil(M/TOKENS_PER_TG_Q8)) threadgroups.
// Each TG: 8 simdgroups * 4 rows = 32 rows; each simdgroup reduces over blocks
// with simd_sum. Token vectors are read directly from device memory inside
// the block loop (not cached in shared memory) so intermediate_size up to
// 8192 fits without spilling threadgroup memory.
//
// Weight layout as read here: a row is cols/32 blocks of 34 bytes, each
// block being one half scale followed by 32 signed 8-bit quants.
//
// Rows past the end of the matrix (possible when rows is not a multiple of
// 4) are clamped to the last valid row for loading, so the kernel never
// reads outside the weight buffer; their sums are dropped by the guarded
// stores at the end.
constant constexpr uint TOKENS_PER_TG_Q8 = 4;
kernel void matmul_q8_gemm_batch(
    device const uchar* matrix      [[buffer(0)]],  // Q8_0 raw bytes [rows, cols]
    device const float* inputs      [[buffer(1)]],  // [M, cols] input batch
    device float* outputs           [[buffer(2)]],  // [M, rows] output batch
    constant uint& num_tokens       [[buffer(3)]],  // M
    constant uint& rows             [[buffer(4)]],
    constant uint& cols             [[buffer(5)]],
    uint2 tgid                      [[threadgroup_position_in_grid]],
    uint tid                        [[thread_index_in_threadgroup]],
    uint simd_lane                  [[thread_index_in_simdgroup]],
    uint simd_id                    [[simdgroup_index_in_threadgroup]])
{
    uint row_base = tgid.x * Q8_ROWS_PER_TG + simd_id * Q8_ROWS_PER_SG;
    uint tok_base = tgid.y * TOKENS_PER_TG_Q8;
    // No threadgroup barrier is used in this kernel, so a per-simdgroup
    // early exit is safe.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // How many tokens in this tile are valid? (1..TOKENS_PER_TG_Q8)
    uint tok_count = min(uint(TOKENS_PER_TG_Q8), num_tokens - tok_base);

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 34;

    // Clamp the three trailing rows to the last valid row so that loads stay
    // inside the weight buffer. rows >= 1 here because row_base < rows.
    uint last_row = rows - 1;
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    // Accumulators: 4 tokens × 4 rows per simdgroup (sTR = token T, row R).
    float s00 = 0, s01 = 0, s02 = 0, s03 = 0;
    float s10 = 0, s11 = 0, s12 = 0, s13 = 0;
    float s20 = 0, s21 = 0, s22 = 0, s23 = 0;
    float s30 = 0, s31 = 0, s32 = 0, s33 = 0;

    // Each lane takes every 32nd Q8_0 block of the row.
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 34;   // byte offset of the block inside a weight row
        uint vb = blk * 32;   // element offset of the block inside an input vector

        // ── Load weight data ONCE per block (reused across all tokens) ──
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));

        device const packed_short4* d0 = (device const packed_short4*)(r0 + bb + 2);
        device const packed_short4* d1 = (device const packed_short4*)(r1 + bb + 2);
        device const packed_short4* d2 = (device const packed_short4*)(r2 + bb + 2);
        device const packed_short4* d3 = (device const packed_short4*)(r3 + bb + 2);

        // Split one packed_short4 (8 bytes) into 8 signed quants as two float4s.
        #define Q8_UNPACK8(SHORT4, OUT_LO, OUT_HI) { \
            short4 _s = short4(SHORT4); \
            char2 _a = as_type<char2>(_s.x); \
            char2 _b = as_type<char2>(_s.y); \
            char2 _c = as_type<char2>(_s.z); \
            char2 _d = as_type<char2>(_s.w); \
            (OUT_LO) = float4(float(_a.x), float(_a.y), float(_b.x), float(_b.y)); \
            (OUT_HI) = float4(float(_c.x), float(_c.y), float(_d.x), float(_d.y)); \
        }

        // Unpack all 4 rows × 8 float4 weights (unscaled quants; the block
        // scale is applied after the dot products). These live in registers
        // for the duration of the block and are dotted against every
        // token's vector tile.
        float4 w0_0, w0_1, w0_2, w0_3, w0_4, w0_5, w0_6, w0_7;
        float4 w1_0, w1_1, w1_2, w1_3, w1_4, w1_5, w1_6, w1_7;
        float4 w2_0, w2_1, w2_2, w2_3, w2_4, w2_5, w2_6, w2_7;
        float4 w3_0, w3_1, w3_2, w3_3, w3_4, w3_5, w3_6, w3_7;

        Q8_UNPACK8(d0[0], w0_0, w0_1);
        Q8_UNPACK8(d0[1], w0_2, w0_3);
        Q8_UNPACK8(d0[2], w0_4, w0_5);
        Q8_UNPACK8(d0[3], w0_6, w0_7);

        Q8_UNPACK8(d1[0], w1_0, w1_1);
        Q8_UNPACK8(d1[1], w1_2, w1_3);
        Q8_UNPACK8(d1[2], w1_4, w1_5);
        Q8_UNPACK8(d1[3], w1_6, w1_7);

        Q8_UNPACK8(d2[0], w2_0, w2_1);
        Q8_UNPACK8(d2[1], w2_2, w2_3);
        Q8_UNPACK8(d2[2], w2_4, w2_5);
        Q8_UNPACK8(d2[3], w2_6, w2_7);

        Q8_UNPACK8(d3[0], w3_0, w3_1);
        Q8_UNPACK8(d3[1], w3_2, w3_3);
        Q8_UNPACK8(d3[2], w3_4, w3_5);
        Q8_UNPACK8(d3[3], w3_6, w3_7);
        #undef Q8_UNPACK8

        // Read the 32 inputs of tile token TOK that match this block and add
        // the four scaled row dot products into ACC0..ACC3.
        #define Q8_GEMM_ACCUM_TOKEN(TOK, ACC0, ACC1, ACC2, ACC3) { \
            device const float* _in = inputs + (tok_base + (TOK)) * cols + vb; \
            float4 _v0 = *(device const float4*)(_in); \
            float4 _v1 = *(device const float4*)(_in + 4); \
            float4 _v2 = *(device const float4*)(_in + 8); \
            float4 _v3 = *(device const float4*)(_in + 12); \
            float4 _v4 = *(device const float4*)(_in + 16); \
            float4 _v5 = *(device const float4*)(_in + 20); \
            float4 _v6 = *(device const float4*)(_in + 24); \
            float4 _v7 = *(device const float4*)(_in + 28); \
            float _bd0 = dot(w0_0,_v0)+dot(w0_1,_v1)+dot(w0_2,_v2)+dot(w0_3,_v3) \
                       + dot(w0_4,_v4)+dot(w0_5,_v5)+dot(w0_6,_v6)+dot(w0_7,_v7); \
            float _bd1 = dot(w1_0,_v0)+dot(w1_1,_v1)+dot(w1_2,_v2)+dot(w1_3,_v3) \
                       + dot(w1_4,_v4)+dot(w1_5,_v5)+dot(w1_6,_v6)+dot(w1_7,_v7); \
            float _bd2 = dot(w2_0,_v0)+dot(w2_1,_v1)+dot(w2_2,_v2)+dot(w2_3,_v3) \
                       + dot(w2_4,_v4)+dot(w2_5,_v5)+dot(w2_6,_v6)+dot(w2_7,_v7); \
            float _bd3 = dot(w3_0,_v0)+dot(w3_1,_v1)+dot(w3_2,_v2)+dot(w3_3,_v3) \
                       + dot(w3_4,_v4)+dot(w3_5,_v5)+dot(w3_6,_v6)+dot(w3_7,_v7); \
            (ACC0) += sc0 * _bd0; (ACC1) += sc1 * _bd1; \
            (ACC2) += sc2 * _bd2; (ACC3) += sc3 * _bd3; \
        }

        // ── For each token, read vector and accumulate against shared weights ──
        // Token 0 is always valid (tok_count >= 1); the others only when the
        // tile is not cut short by the end of the batch.
        Q8_GEMM_ACCUM_TOKEN(0, s00, s01, s02, s03)
        if (tok_count > 1) Q8_GEMM_ACCUM_TOKEN(1, s10, s11, s12, s13)
        if (tok_count > 2) Q8_GEMM_ACCUM_TOKEN(2, s20, s21, s22, s23)
        if (tok_count > 3) Q8_GEMM_ACCUM_TOKEN(3, s30, s31, s32, s33)
        #undef Q8_GEMM_ACCUM_TOKEN
    }

    // simdgroup reduction (executed by every lane, valid token or not)
    s00 = simd_sum(s00); s01 = simd_sum(s01); s02 = simd_sum(s02); s03 = simd_sum(s03);
    s10 = simd_sum(s10); s11 = simd_sum(s11); s12 = simd_sum(s12); s13 = simd_sum(s13);
    s20 = simd_sum(s20); s21 = simd_sum(s21); s22 = simd_sum(s22); s23 = simd_sum(s23);
    s30 = simd_sum(s30); s31 = simd_sum(s31); s32 = simd_sum(s32); s33 = simd_sum(s33);

    // Write the four row results of tile token TOK, skipping rows past the
    // end of the matrix (those hold clamped duplicates).
    #define Q8_GEMM_STORE_TOKEN(TOK, ACC0, ACC1, ACC2, ACC3) { \
        device float* _out = outputs + (tok_base + (TOK)) * rows; \
        if (row_base     < rows) _out[row_base]     = (ACC0); \
        if (row_base + 1 < rows) _out[row_base + 1] = (ACC1); \
        if (row_base + 2 < rows) _out[row_base + 2] = (ACC2); \
        if (row_base + 3 < rows) _out[row_base + 3] = (ACC3); \
    }

    if (simd_lane == 0) {
        Q8_GEMM_STORE_TOKEN(0, s00, s01, s02, s03)
        if (tok_count > 1) Q8_GEMM_STORE_TOKEN(1, s10, s11, s12, s13)
        if (tok_count > 2) Q8_GEMM_STORE_TOKEN(2, s20, s21, s22, s23)
        if (tok_count > 3) Q8_GEMM_STORE_TOKEN(3, s30, s31, s32, s33)
    }
    #undef Q8_GEMM_STORE_TOKEN
}
// ── matmul_q8_mma ──────────────────────────────────────────────────────
// Q8_0 GEMM built on simdgroup_matrix tiles (simdgroup_multiply_accumulate)
// instead of per-lane dot products; intended as the main prompt-prefill
// path once the batch is large enough to fill token tiles.
//
// Tile: 16 tokens × 16 rows per threadgroup, K=32 per iteration (one Q8 block).
// 4 simdgroups (128 threads) per TG, each owning one 8×8 output sub-tile in
// a single simdgroup_matrix<float, 8, 8> accumulator. For every K block the
// threadgroup dequantizes the 16 weight rows and stages the 16 token
// vectors in threadgroup memory, then all simdgroups consume those tiles.
//
// Assumptions (per the original notes, verified in the dispatch helper,
// which falls back otherwise):
//   * cols % 32 == 0  (one Q8_0 block per K chunk)
//   * rows % 16 == 0  (tile-aligned; no row guards in this kernel)
//   * num_tokens may be any value; a token tile cut short by the end of the
//     batch is zero-padded on load and written through a scratch copy.
constant constexpr uint MMA_TOK_TILE = 16;
constant constexpr uint MMA_ROW_TILE = 16;
kernel void matmul_q8_mma(
    device const uchar* matrix      [[buffer(0)]],  // Q8_0 [rows, cols/32 * 34]
    device const float* inputs      [[buffer(1)]],  // [M, cols]
    device float* outputs           [[buffer(2)]],  // [M, rows]
    constant uint& num_tokens       [[buffer(3)]],
    constant uint& rows             [[buffer(4)]],
    constant uint& cols             [[buffer(5)]],
    uint2 tgid                      [[threadgroup_position_in_grid]],
    uint tid                        [[thread_index_in_threadgroup]],
    uint simd_id                    [[simdgroup_index_in_threadgroup]])
{
    const uint row_base = tgid.x * MMA_ROW_TILE;
    const uint tok_base = tgid.y * MMA_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and the
    // barriers below are never reached by a partial group.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // Shared dequant tiles (16*32 = 512 floats = 2 KB each, 4 KB total).
    threadgroup float w_tile[MMA_ROW_TILE * 32];
    threadgroup float t_tile[MMA_TOK_TILE * 32];

    // 4 simdgroups → 2×2 grid of 8×8 sub-tiles inside the 16×16 output.
    const uint sg_tok_base = (simd_id / 2) * 8;   // token offset inside the tile
    const uint sg_row_base = (simd_id % 2) * 8;   // weight-row offset inside the tile

    simdgroup_matrix<float, 8, 8> acc = simdgroup_matrix<float, 8, 8>(0.0f);

    const uint blocks_per_row = cols / 32;
    const uint row_bytes = blocks_per_row * 34;

    // Every thread fills tile elements tid*4 .. tid*4+3 (512 floats / 128
    // threads). Four consecutive elements starting at a multiple of 4 never
    // cross a 32-wide tile line, so the line and first K index are fixed
    // per thread for the whole kernel.
    const uint tile_line = (tid * 4) / 32;
    const uint k_first = (tid * 4) % 32;
    const uint src_tok = tok_base + tile_line;
    const bool src_tok_valid = src_tok < num_tokens;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // ── Dequantize this thread's 4 weights of the block into w_tile ──
        device const uchar* wblock = matrix + (row_base + tile_line) * row_bytes + blk * 34;
        const float scale = float(*(device const half*)wblock);
        for (uint e = 0; e < 4; e++) {
            const uint k = k_first + e;
            const int quant = int(*(device const int8_t*)(wblock + 2 + k));
            w_tile[tile_line * 32 + k] = float(quant) * scale;
        }

        // ── Stage this thread's 4 token inputs into t_tile (0 past the batch end) ──
        for (uint e = 0; e < 4; e++) {
            const uint k = k_first + e;
            t_tile[tile_line * 32 + k] = src_tok_valid
                ? inputs[src_tok * cols + blk * 32 + k]
                : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // ── 4 × (8×8×8) MMA over the K=32 chunk ──
        //   A[m, k] = t_tile[(sg_tok_base + m) * 32 + k_sub*8 + k]   (M×K, no transpose)
        //   B[k, r] = w_tile[(sg_row_base + r) * 32 + k_sub*8 + k]   (loaded transposed → K×R)
        //   acc[m, r] += A[m, k] * B[k, r]
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> tok_frag, wgt_frag;
            simdgroup_load(tok_frag,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(wgt_frag,
                           w_tile + sg_row_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(acc, tok_frag, wgt_frag, acc);
        }

        // Tiles are overwritten on the next iteration; wait until every
        // simdgroup has finished reading them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // ── Store acc to outputs[(tok_base+sg_tok_base)+m, (row_base+sg_row_base)+r] ──
    // Output layout: outputs[tok * rows + row], so the store stride is rows.
    const uint out_tok = tok_base + sg_tok_base;
    const uint out_row = row_base + sg_row_base;

    if (out_tok + 8 <= num_tokens) {
        // Fast path: all 8 token lines of the sub-tile are inside the batch.
        simdgroup_store(acc, outputs + out_tok * rows + out_row, rows);
    } else if (out_tok < num_tokens) {
        // Sub-tile cut short by the end of the batch: stage it in a
        // per-simdgroup scratch area and copy only the valid token lines.
        threadgroup float scratch[4 * 64];
        simdgroup_store(acc, scratch + simd_id * 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);

        const uint lane = tid % 32;
        if (lane == 0) {
            const uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst = outputs + (out_tok + m) * rows + out_row;
                threadgroup const float* src = scratch + simd_id * 64 + m * 8;
                dst[0] = src[0]; dst[1] = src[1]; dst[2] = src[2]; dst[3] = src[3];
                dst[4] = src[4]; dst[5] = src[5]; dst[6] = src[6]; dst[7] = src[7];
            }
        }
    }
    // Sub-tiles that start at or past num_tokens write nothing.
}
// ── matmul_q8_mma32 ────────────────────────────────────────────────────
// Larger-tile variant of matmul_q8_mma for long-context prefill.
//
// Tile: 32 tokens × 32 rows per threadgroup, K=32 per iteration.
// 8 simdgroups (256 threads) cover the 4×4 grid of 8×8 output sub-tiles,
// with each simdgroup owning *two* adjacent 8×8 accumulators along the row
// axis:
//
//   simd_id = 2*sg_tok_idx + sg_row_half   (sg_tok_idx∈[0,3], sg_row_half∈[0,1])
//   output sub-tiles (tok, row):
//     (sg_tok_idx*8, sg_row_half*16 + 0)  -> acc_a
//     (sg_tok_idx*8, sg_row_half*16 + 8)  -> acc_b
//
// The token fragment loaded for a K sub-chunk feeds both accumulators, so
// each fragment load is used for two multiply-accumulates.
//
// Assumptions (per the original notes, verified in the dispatch helper,
// which falls back otherwise):
//   * cols % 32 == 0
//   * rows % 32 == 0  (no row guards in this kernel)
constant constexpr uint MMA32_TOK_TILE = 32;
constant constexpr uint MMA32_ROW_TILE = 32;
kernel void matmul_q8_mma32(
    device const uchar* matrix      [[buffer(0)]],  // Q8_0 weight bytes, 34 bytes per 32-column block
    device const float* inputs      [[buffer(1)]],  // token vectors, read as inputs[tok * cols + col]
    device float* outputs           [[buffer(2)]],  // written as outputs[tok * rows + row]
    constant uint& num_tokens       [[buffer(3)]],
    constant uint& rows             [[buffer(4)]],
    constant uint& cols             [[buffer(5)]],
    uint2 tgid                      [[threadgroup_position_in_grid]],
    uint tid                        [[thread_index_in_threadgroup]],
    uint simd_id                    [[simdgroup_index_in_threadgroup]])
{
    const uint row_base = tgid.x * MMA32_ROW_TILE;
    const uint tok_base = tgid.y * MMA32_TOK_TILE;
    // Depends only on tgid, so the whole threadgroup exits together and the
    // barriers below are never reached by a partial group.
    if (row_base >= rows || tok_base >= num_tokens) return;

    // 32×32 float tiles in threadgroup memory = 4 KB each, 8 KB total.
    threadgroup float w_tile[MMA32_ROW_TILE * 32];
    threadgroup float t_tile[MMA32_TOK_TILE * 32];

    const uint sg_tok_idx = simd_id / 2;    // 0..3
    const uint sg_row_half = simd_id % 2;   // 0..1
    const uint sg_tok_base = sg_tok_idx * 8;
    const uint sg_row_base_a = sg_row_half * 16 + 0;
    const uint sg_row_base_b = sg_row_half * 16 + 8;

    simdgroup_matrix<float, 8, 8> acc_a = simdgroup_matrix<float, 8, 8>(0.0f);
    simdgroup_matrix<float, 8, 8> acc_b = simdgroup_matrix<float, 8, 8>(0.0f);

    const uint blocks_per_row = cols / 32;
    const uint row_bytes = blocks_per_row * 34;

    // Every thread fills tile elements tid*4 .. tid*4+3 (1024 floats / 256
    // threads). Four consecutive elements starting at a multiple of 4 never
    // cross a 32-wide tile line, so the line and first K index are fixed
    // per thread for the whole kernel.
    const uint tile_line = (tid * 4) / 32;
    const uint k_first = (tid * 4) % 32;
    const uint src_tok = tok_base + tile_line;
    const bool src_tok_valid = src_tok < num_tokens;

    for (uint blk = 0; blk < blocks_per_row; blk++) {
        // Dequantize this thread's 4 weights of the block into w_tile.
        device const uchar* wblock = matrix + (row_base + tile_line) * row_bytes + blk * 34;
        const float scale = float(*(device const half*)wblock);
        for (uint e = 0; e < 4; e++) {
            const uint k = k_first + e;
            const int quant = int(*(device const int8_t*)(wblock + 2 + k));
            w_tile[tile_line * 32 + k] = float(quant) * scale;
        }

        // Stage this thread's 4 token inputs into t_tile (0 past the batch end).
        for (uint e = 0; e < 4; e++) {
            const uint k = k_first + e;
            t_tile[tile_line * 32 + k] = src_tok_valid
                ? inputs[src_tok * cols + blk * 32 + k]
                : 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        // 4 K-sub chunks of 8 each; the token fragment is shared by both
        // row accumulators.
        for (uint k_sub = 0; k_sub < 4; k_sub++) {
            simdgroup_matrix<float, 8, 8> tok_frag, wgt_frag_a, wgt_frag_b;
            simdgroup_load(tok_frag,
                           t_tile + sg_tok_base * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           false);
            simdgroup_load(wgt_frag_a,
                           w_tile + sg_row_base_a * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_load(wgt_frag_b,
                           w_tile + sg_row_base_b * 32 + k_sub * 8,
                           32,
                           ulong2(0, 0),
                           true);
            simdgroup_multiply_accumulate(acc_a, tok_frag, wgt_frag_a, acc_a);
            simdgroup_multiply_accumulate(acc_b, tok_frag, wgt_frag_b, acc_b);
        }

        // Tiles are overwritten on the next iteration; wait until every
        // simdgroup has finished reading them.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Store both 8×8 accumulators. The row dimension is assumed
    // MMA32_ROW_TILE-aligned (see header), so only the token dimension can
    // be cut short.
    const uint out_tok = tok_base + sg_tok_base;
    const uint out_row_a = row_base + sg_row_base_a;
    const uint out_row_b = row_base + sg_row_base_b;

    if (out_tok + 8 <= num_tokens) {
        // Fast path: all 8 token lines of both sub-tiles are inside the batch.
        simdgroup_store(acc_a, outputs + out_tok * rows + out_row_a, rows);
        simdgroup_store(acc_b, outputs + out_tok * rows + out_row_b, rows);
    } else if (out_tok < num_tokens) {
        // Cut short by the end of the batch: stage both sub-tiles in this
        // simdgroup's scratch slice and copy only the valid token lines.
        threadgroup float scratch[8 * 2 * 64];  // 8 simdgroups × 2 accs × 64 floats
        simdgroup_store(acc_a, scratch + simd_id * 128, 8);
        simdgroup_store(acc_b, scratch + simd_id * 128 + 64, 8);
        simdgroup_barrier(mem_flags::mem_threadgroup);

        const uint lane = tid % 32;
        if (lane == 0) {
            const uint valid = num_tokens - out_tok;  // 1..7
            for (uint m = 0; m < valid; m++) {
                device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
                device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
                threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
                threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
                for (uint j = 0; j < 8; j++) {
                    dst_a[j] = src_a[j];
                    dst_b[j] = src_b[j];
                }
            }
        }
    }
    // Sub-tiles that start at or past num_tokens write nothing.
}
// ── matmul_q8_mma32_h ──────────────────────────────────────────────────
// FP16 threadgroup-tile variant of matmul_q8_mma32.
//
// Stores dequantized weights and token inputs as `half` in threadgroup
// memory — halving the shared-memory footprint (4 KB total vs 8 KB) and
// doubling concurrent-threadgroup occupancy per GPU core on Apple Silicon.
// The Q8_0 weight range is already int8 × f16 scale (the scale is read as
// `half` below), so a f16 intermediate representation preserves the full
// quantized dynamic range. Token activations stay numerically safe because
// the subsequent `simdgroup_multiply_accumulate` keeps the accumulator in
// `float`.
//
// Tile: 32 × 32 (same as mma32), 8 simdgroups × 2 row-stacked 8×8
// accumulators each. Primary win vs mma32 is occupancy at moderate
// prefill lengths where the GPU is wave-starved.
//
// Buffers: matrix is Q8_0 (34-byte blocks: half scale + 32 int8 quants),
// inputs is [num_tokens, cols] f32, outputs is [num_tokens, rows] f32.
// The tile loads index 1024 elements as tid * 4 + ii, so 256 threads per
// threadgroup are expected.
kernel void matmul_q8_mma32_h(
device const uchar* matrix [[buffer(0)]],
device const float* inputs [[buffer(1)]],
device float* outputs [[buffer(2)]],
constant uint& num_tokens [[buffer(3)]],
constant uint& rows [[buffer(4)]],
constant uint& cols [[buffer(5)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Output tile owned by this threadgroup; the early exit depends only on
// tgid, so every thread of the group takes the same branch.
uint row_base = tgid.x * MMA32_ROW_TILE;
uint tok_base = tgid.y * MMA32_TOK_TILE;
if (row_base >= rows || tok_base >= num_tokens) return;
// 32×32 half tiles — 2 KB each, 4 KB total.
threadgroup half w_tile[MMA32_ROW_TILE * 32];
threadgroup half t_tile[MMA32_TOK_TILE * 32];
// simd_id / 2 selects an 8-token band, simd_id % 2 a 16-row half that is
// split into two 8-row groups (a and b).
uint sg_tok_idx = simd_id / 2;
uint sg_row_half = simd_id % 2;
uint sg_tok_base = sg_tok_idx * 8;
uint sg_row_base_a = sg_row_half * 16 + 0;
uint sg_row_base_b = sg_row_half * 16 + 8;
// Accumulators stay in float even though the operand fragments are half.
simdgroup_matrix<float, 8, 8> C_a = simdgroup_matrix<float, 8, 8>(0.0f);
simdgroup_matrix<float, 8, 8> C_b = simdgroup_matrix<float, 8, 8>(0.0f);
uint blocks_per_row = cols / 32;
uint row_bytes = blocks_per_row * 34;
for (uint blk = 0; blk < blocks_per_row; blk++) {
// Cooperative weight dequantization to FP16.
{
uint base = tid * 4;
for (uint ii = 0; ii < 4; ii++) {
uint idx = base + ii;
uint r = idx / 32;
uint k = idx % 32;
// rp addresses the 34-byte block: half scale at offset 0, quants at 2..33.
device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
float sc = float(*(device const half*)rp);
int ival = int(*(device const int8_t*)(rp + 2 + k));
// The product is formed in float and only then narrowed to half.
w_tile[r * 32 + k] = half(float(ival) * sc);
}
}
// Cooperative token tile load (f32 → f16 narrowing).
{
uint base = tid * 4;
for (uint ii = 0; ii < 4; ii++) {
uint idx = base + ii;
uint m = idx / 32;
uint k = idx % 32;
uint tok = tok_base + m;
// Token rows past num_tokens are zero-filled.
t_tile[m * 32 + k] = (tok < num_tokens)
? half(inputs[tok * cols + blk * 32 + k])
: half(0);
}
}
// Tiles must be complete before any simdgroup loads fragments from them.
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint k_sub = 0; k_sub < 4; k_sub++) {
simdgroup_matrix<half, 8, 8> A, B_a, B_b;
// A: 8×8 [token, k] fragment.
simdgroup_load(A,
t_tile + sg_tok_base * 32 + k_sub * 8,
32,
ulong2(0, 0),
false);
// B_a / B_b: [row, k] weight fragments loaded transposed to [k, row].
simdgroup_load(B_a,
w_tile + sg_row_base_a * 32 + k_sub * 8,
32,
ulong2(0, 0),
true);
simdgroup_load(B_b,
w_tile + sg_row_base_b * 32 + k_sub * 8,
32,
ulong2(0, 0),
true);
simdgroup_multiply_accumulate(C_a, A, B_a, C_a);
simdgroup_multiply_accumulate(C_b, A, B_b, C_b);
}
// Fragment reads must finish before the next block overwrites the tiles.
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store: direct simdgroup_store when all 8 token rows are valid, otherwise
// spill to scratch and copy only the valid rows.
uint out_tok = tok_base + sg_tok_base;
uint out_row_a = row_base + sg_row_base_a;
uint out_row_b = row_base + sg_row_base_b;
bool full_tok = (out_tok + 8 <= num_tokens);
if (full_tok) {
simdgroup_store(C_a, outputs + out_tok * rows + out_row_a, rows);
simdgroup_store(C_b, outputs + out_tok * rows + out_row_b, rows);
} else if (out_tok < num_tokens) {
// 8 simdgroups × 2 accumulators × 64 floats; each simdgroup uses its own
// 128-float slice.
threadgroup float scratch[8 * 2 * 64];
simdgroup_store(C_a, scratch + simd_id * 128, 8);
simdgroup_store(C_b, scratch + simd_id * 128 + 64, 8);
simdgroup_barrier(mem_flags::mem_threadgroup);
uint lane = tid % 32;
if (lane == 0) {
// Number of token rows of this 8-row band that are inside the batch.
uint valid = num_tokens - out_tok;
for (uint m = 0; m < valid; m++) {
device float* dst_a = outputs + (out_tok + m) * rows + out_row_a;
device float* dst_b = outputs + (out_tok + m) * rows + out_row_b;
threadgroup const float* src_a = scratch + simd_id * 128 + m * 8;
threadgroup const float* src_b = scratch + simd_id * 128 + 64 + m * 8;
for (uint j = 0; j < 8; j++) {
dst_a[j] = src_a[j];
dst_b[j] = src_b[j];
}
}
}
}
}
// ── matmul_q8_mma32_h4 ─────────────────────────────────────────────────
// 4-simdgroup variant of the FP16-tile 32×32 MMA kernel.
//
// Instead of 8 simdgroups × 2 row-stacked accumulators, this kernel runs
// 4 simdgroups × **2×2 grid** of 8×8 accumulators each. Per simdgroup:
// C_00 (tok 0..8, row 0..8) C_01 (tok 0..8, row 8..16)
// C_10 (tok 8..16, row 0..8) C_11 (tok 8..16, row 8..16)
// A simdgroup_id addresses one 16×16 quadrant of the 32×32 output tile.
//
// Per K_sub iteration: load two A fragments and two B fragments, then run
// **four** MMA instructions reusing A_top with both B's and A_bot with
// both B's. That's double the FLOP-per-simdgroup-load compared to the
// 2-accumulator kernel and halves the thread count per threadgroup (128
// threads), which often improves occupancy on Apple GPUs where the
// concurrent-thread budget is the tighter limit than shared-memory size.
//
// Buffers: matrix is Q8_0 (34-byte blocks: half scale + 32 int8 quants),
// inputs is [num_tokens, cols] f32, outputs is [num_tokens, rows] f32.
kernel void matmul_q8_mma32_h4(
device const uchar* matrix [[buffer(0)]],
device const float* inputs [[buffer(1)]],
device float* outputs [[buffer(2)]],
constant uint& num_tokens [[buffer(3)]],
constant uint& rows [[buffer(4)]],
constant uint& cols [[buffer(5)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Output tile owned by this threadgroup; the exit test depends only on tgid.
uint row_base = tgid.x * MMA32_ROW_TILE;
uint tok_base = tgid.y * MMA32_TOK_TILE;
if (row_base >= rows || tok_base >= num_tokens) return;
// 32×32 FP16 tiles, 4 KB total.
threadgroup half w_tile[MMA32_ROW_TILE * 32];
threadgroup half t_tile[MMA32_TOK_TILE * 32];
// 4 simdgroups laid out as a 2×2 grid of 16×16 quadrants.
uint sg_tok_q = simd_id / 2; // 0..1
uint sg_row_q = simd_id % 2; // 0..1
uint sg_tok_base = sg_tok_q * 16;
uint sg_row_base = sg_row_q * 16;
// C_<t><r>: t selects the upper/lower 8 tokens, r the lower/upper 8 rows
// of this simdgroup's quadrant. Accumulation is in float.
simdgroup_matrix<float, 8, 8> C_00 = simdgroup_matrix<float, 8, 8>(0.0f);
simdgroup_matrix<float, 8, 8> C_01 = simdgroup_matrix<float, 8, 8>(0.0f);
simdgroup_matrix<float, 8, 8> C_10 = simdgroup_matrix<float, 8, 8>(0.0f);
simdgroup_matrix<float, 8, 8> C_11 = simdgroup_matrix<float, 8, 8>(0.0f);
uint blocks_per_row = cols / 32;
uint row_bytes = blocks_per_row * 34;
for (uint blk = 0; blk < blocks_per_row; blk++) {
// Cooperative weight dequant — 128 threads × 8 halves = 1024 = 32*32.
{
uint base = tid * 8;
for (uint ii = 0; ii < 8; ii++) {
uint idx = base + ii;
uint r = idx / 32;
uint k = idx % 32;
// rp addresses the 34-byte block: half scale at offset 0, quants at 2..33.
device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
float sc = float(*(device const half*)rp);
int ival = int(*(device const int8_t*)(rp + 2 + k));
w_tile[r * 32 + k] = half(float(ival) * sc);
}
}
// Cooperative token tile load.
{
uint base = tid * 8;
for (uint ii = 0; ii < 8; ii++) {
uint idx = base + ii;
uint m = idx / 32;
uint k = idx % 32;
uint tok = tok_base + m;
// Token rows past num_tokens are zero-filled.
t_tile[m * 32 + k] = (tok < num_tokens)
? half(inputs[tok * cols + blk * 32 + k])
: half(0);
}
}
// Tiles must be complete before fragments are loaded from them.
threadgroup_barrier(mem_flags::mem_threadgroup);
// 4 K-sub chunks, 4 MMA ops each, reusing A's and B's.
for (uint k_sub = 0; k_sub < 4; k_sub++) {
simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
// A_top / A_bot: [token, k] fragments for tokens 0..7 and 8..15 of the quadrant.
simdgroup_load(A_top,
t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
32,
ulong2(0, 0),
false);
simdgroup_load(A_bot,
t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
32,
ulong2(0, 0),
false);
// B_lo / B_hi: [row, k] weight fragments loaded transposed to [k, row]
// for rows 0..7 and 8..15 of the quadrant.
simdgroup_load(B_lo,
w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
32,
ulong2(0, 0),
true);
simdgroup_load(B_hi,
w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
32,
ulong2(0, 0),
true);
simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
}
// Fragment reads must finish before the next block overwrites the tiles.
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store 4 output tiles. Full-tile fast path assumes full 16×16 valid.
uint out_tok_top = tok_base + sg_tok_base + 0;
uint out_tok_bot = tok_base + sg_tok_base + 8;
uint out_row_lo = row_base + sg_row_base + 0;
uint out_row_hi = row_base + sg_row_base + 8;
// `full` holds only when the last token of the bottom band is in range,
// which implies every token of the quadrant is valid.
bool full = (out_tok_bot + 8 <= num_tokens);
if (full) {
simdgroup_store(C_00, outputs + out_tok_top * rows + out_row_lo, rows);
simdgroup_store(C_01, outputs + out_tok_top * rows + out_row_hi, rows);
simdgroup_store(C_10, outputs + out_tok_bot * rows + out_row_lo, rows);
simdgroup_store(C_11, outputs + out_tok_bot * rows + out_row_hi, rows);
} else {
// Partial-token fallback via per-simdgroup scratch.
threadgroup float scratch[4 * 4 * 64]; // 4 simdgroups × 4 accs × 64
// Each simdgroup owns a 256-float slice: C_00, C_01, C_10, C_11 in order.
uint sg_base = simd_id * 256;
simdgroup_store(C_00, scratch + sg_base + 0, 8);
simdgroup_store(C_01, scratch + sg_base + 64, 8);
simdgroup_store(C_10, scratch + sg_base + 128, 8);
simdgroup_store(C_11, scratch + sg_base + 192, 8);
simdgroup_barrier(mem_flags::mem_threadgroup);
uint lane = tid % 32;
if (lane == 0) {
// Lane 0 copies row by row, skipping token rows at or past num_tokens
// (this also covers a quadrant that lies entirely out of range).
for (uint m = 0; m < 8; m++) {
uint t_top = out_tok_top + m;
if (t_top < num_tokens) {
device float* dst0 = outputs + t_top * rows + out_row_lo;
device float* dst1 = outputs + t_top * rows + out_row_hi;
threadgroup const float* src0 = scratch + sg_base + 0 + m * 8;
threadgroup const float* src1 = scratch + sg_base + 64 + m * 8;
for (uint j = 0; j < 8; j++) { dst0[j] = src0[j]; dst1[j] = src1[j]; }
}
uint t_bot = out_tok_bot + m;
if (t_bot < num_tokens) {
device float* dst2 = outputs + t_bot * rows + out_row_lo;
device float* dst3 = outputs + t_bot * rows + out_row_hi;
threadgroup const float* src2 = scratch + sg_base + 128 + m * 8;
threadgroup const float* src3 = scratch + sg_base + 192 + m * 8;
for (uint j = 0; j < 8; j++) { dst2[j] = src2[j]; dst3[j] = src3[j]; }
}
}
}
}
}
// ── matmul_q8_mma32_hh4 ────────────────────────────────────────────────
// All-half MMA variant of matmul_q8_mma32_h4.
//
// Both the input matrices and the accumulators are simdgroup_matrix<half>.
// On Apple Silicon, FP16 `simdgroup_multiply_accumulate` runs at 2x the FP32
// rate (dual-issue FMA), so if Q8_0 precision holds through half
// accumulation this kernel can double the effective FLOP throughput on
// matmul-bound prefill.
//
// Numerical notes: Q8_0 weights have only ~8 bits of mantissa and the token
// activations at each layer are bounded (post-RMSNorm ≈ O(1)). Summing
// 2048-wide K for 1B or 8192-wide for the FFN may exceed half's ~3.3-digit
// precision on extreme values, but the inputs are already quantized so the
// per-product error floor is higher than the half-precision rounding error.
// We verify correctness on 135M / 1B / 3B before enabling.
//
// Buffers: matrix is Q8_0 (34-byte blocks: half scale + 32 int8 quants),
// inputs is [num_tokens, cols] f32, outputs is [num_tokens, rows] f32.
// The tile loads index 1024 elements as tid * 8 + ii, so 128 threads
// (4 simdgroups) per threadgroup are expected.
kernel void matmul_q8_mma32_hh4(
device const uchar* matrix [[buffer(0)]],
device const float* inputs [[buffer(1)]],
device float* outputs [[buffer(2)]],
constant uint& num_tokens [[buffer(3)]],
constant uint& rows [[buffer(4)]],
constant uint& cols [[buffer(5)]],
uint2 tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Output tile owned by this threadgroup; the exit test depends only on tgid.
uint row_base = tgid.x * MMA32_ROW_TILE;
uint tok_base = tgid.y * MMA32_TOK_TILE;
if (row_base >= rows || tok_base >= num_tokens) return;
// 32×32 half tiles for dequantized weights and narrowed activations.
threadgroup half w_tile[MMA32_ROW_TILE * 32];
threadgroup half t_tile[MMA32_TOK_TILE * 32];
// simd_id selects one 16×16 quadrant of the tile (token half, row half).
uint sg_tok_q = simd_id / 2;
uint sg_row_q = simd_id % 2;
uint sg_tok_base = sg_tok_q * 16;
uint sg_row_base = sg_row_q * 16;
// Unlike matmul_q8_mma32_h4, the four accumulators are half as well.
simdgroup_matrix<half, 8, 8> C_00 = simdgroup_matrix<half, 8, 8>(half(0));
simdgroup_matrix<half, 8, 8> C_01 = simdgroup_matrix<half, 8, 8>(half(0));
simdgroup_matrix<half, 8, 8> C_10 = simdgroup_matrix<half, 8, 8>(half(0));
simdgroup_matrix<half, 8, 8> C_11 = simdgroup_matrix<half, 8, 8>(half(0));
uint blocks_per_row = cols / 32;
uint row_bytes = blocks_per_row * 34;
for (uint blk = 0; blk < blocks_per_row; blk++) {
// Cooperative weight dequantization: 8 tile elements per thread.
{
uint base = tid * 8;
for (uint ii = 0; ii < 8; ii++) {
uint idx = base + ii;
uint r = idx / 32;
uint k = idx % 32;
// rp addresses the 34-byte block: half scale at offset 0, quants at 2..33.
device const uchar* rp = matrix + (row_base + r) * row_bytes + blk * 34;
float sc = float(*(device const half*)rp);
int ival = int(*(device const int8_t*)(rp + 2 + k));
w_tile[r * 32 + k] = half(float(ival) * sc);
}
}
// Cooperative token tile load (f32 → f16), zero-filling rows past num_tokens.
{
uint base = tid * 8;
for (uint ii = 0; ii < 8; ii++) {
uint idx = base + ii;
uint m = idx / 32;
uint k = idx % 32;
uint tok = tok_base + m;
t_tile[m * 32 + k] = (tok < num_tokens)
? half(inputs[tok * cols + blk * 32 + k])
: half(0);
}
}
// Tiles must be complete before fragments are loaded from them.
threadgroup_barrier(mem_flags::mem_threadgroup);
// Four 8-wide K chunks per block; each chunk issues four MMAs that reuse
// the two A fragments and the two (transposed) B fragments.
for (uint k_sub = 0; k_sub < 4; k_sub++) {
simdgroup_matrix<half, 8, 8> A_top, A_bot, B_lo, B_hi;
simdgroup_load(A_top,
t_tile + (sg_tok_base + 0) * 32 + k_sub * 8,
32, ulong2(0, 0), false);
simdgroup_load(A_bot,
t_tile + (sg_tok_base + 8) * 32 + k_sub * 8,
32, ulong2(0, 0), false);
simdgroup_load(B_lo,
w_tile + (sg_row_base + 0) * 32 + k_sub * 8,
32, ulong2(0, 0), true);
simdgroup_load(B_hi,
w_tile + (sg_row_base + 8) * 32 + k_sub * 8,
32, ulong2(0, 0), true);
simdgroup_multiply_accumulate(C_00, A_top, B_lo, C_00);
simdgroup_multiply_accumulate(C_01, A_top, B_hi, C_01);
simdgroup_multiply_accumulate(C_10, A_bot, B_lo, C_10);
simdgroup_multiply_accumulate(C_11, A_bot, B_hi, C_11);
}
// Fragment reads must finish before the next block overwrites the tiles.
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// Store half accumulators via scratch (must widen to f32 for device output).
uint out_tok_top = tok_base + sg_tok_base + 0;
uint out_tok_bot = tok_base + sg_tok_base + 8;
uint out_row_lo = row_base + sg_row_base + 0;
uint out_row_hi = row_base + sg_row_base + 8;
// There is no direct-store fast path here: every simdgroup always goes
// through its own 256-half scratch slice (C_00, C_01, C_10, C_11 in order).
threadgroup half scratch[4 * 4 * 64];
uint sg_base = simd_id * 256;
simdgroup_store(C_00, scratch + sg_base + 0, 8);
simdgroup_store(C_01, scratch + sg_base + 64, 8);
simdgroup_store(C_10, scratch + sg_base + 128, 8);
simdgroup_store(C_11, scratch + sg_base + 192, 8);
simdgroup_barrier(mem_flags::mem_threadgroup);
uint lane = tid % 32;
if (lane == 0) {
// Lane 0 widens and copies row by row, skipping token rows at or past
// num_tokens.
for (uint m = 0; m < 8; m++) {
uint t_top = out_tok_top + m;
if (t_top < num_tokens) {
device float* dst0 = outputs + t_top * rows + out_row_lo;
device float* dst1 = outputs + t_top * rows + out_row_hi;
threadgroup const half* src0 = scratch + sg_base + 0 + m * 8;
threadgroup const half* src1 = scratch + sg_base + 64 + m * 8;
for (uint j = 0; j < 8; j++) {
dst0[j] = float(src0[j]);
dst1[j] = float(src1[j]);
}
}
uint t_bot = out_tok_bot + m;
if (t_bot < num_tokens) {
device float* dst2 = outputs + t_bot * rows + out_row_lo;
device float* dst3 = outputs + t_bot * rows + out_row_hi;
threadgroup const half* src2 = scratch + sg_base + 128 + m * 8;
threadgroup const half* src3 = scratch + sg_base + 192 + m * 8;
for (uint j = 0; j < 8; j++) {
dst2[j] = float(src2[j]);
dst3[j] = float(src3[j]);
}
}
}
}
}
// ── add_bias_batch ─────────────────────────────────────────────────────
// Adds a length-`rows` bias vector to each token row of a row-major
// [num_tokens, rows] buffer, in place. One thread per output element:
//   out[token * rows + i] += bias[i]
// Used for Qwen2 QKV bias after the fused qkv matmul.
kernel void add_bias_batch(
    device float*       out        [[buffer(0)]],  // [num_tokens, rows], updated in place
    device const float* bias       [[buffer(1)]],  // [rows]
    constant uint&      num_tokens [[buffer(2)]],
    constant uint&      rows       [[buffer(3)]],
    uint                id         [[thread_position_in_grid]])
{
    // Threads beyond the last element (grid rounded up) do nothing.
    if (id < num_tokens * rows) {
        // The column within the row selects the bias entry.
        out[id] = out[id] + bias[id % rows];
    }
}
// ── matmul_vec_q4_batch ────────────────────────────────────────────────
// Batched Q4_0 matrix-vector multiply for M input vectors.
// Grid: ceil(rows/Q4_ROWS_PER_TG) * M threadgroups.
//
// Each threadgroup stages one token's input vector in threadgroup memory;
// each simdgroup then computes four consecutive output rows, with its lanes
// striding over the 32-weight Q4_0 blocks of those rows.
//
// Q4_0 block layout as read here (18 bytes per 32 weights): a half scale
// followed by 16 bytes of packed nibbles, where the low nibble of byte j is
// weight j and the high nibble is weight j + 16; each nibble is offset by -8.
//
// The staging loop strides by 256, so a 256-thread threadgroup is expected.
// The body hard-codes four rows per simdgroup (r0..r3), so Q4_ROWS_PER_SG is
// presumably 4 — confirm against its definition.
kernel void matmul_vec_q4_batch(
    device const uchar* matrix     [[buffer(0)]],  // Q4_0 raw bytes [rows, cols]
    device const float* inputs     [[buffer(1)]],  // [M, cols] input batch
    device float*       outputs    [[buffer(2)]],  // [M, rows] output batch
    constant uint&      num_tokens [[buffer(3)]],  // M
    constant uint&      rows       [[buffer(4)]],
    constant uint&      cols       [[buffer(5)]],
    uint tgid      [[threadgroup_position_in_grid]],
    uint tid       [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id   [[simdgroup_index_in_threadgroup]])
{
    // Flattened grid: token-major, then row group within the token.
    uint row_tgs = (rows + Q4_ROWS_PER_TG - 1) / Q4_ROWS_PER_TG;
    uint token = tgid / row_tgs;
    uint tg_in_token = tgid % row_tgs;
    if (token >= num_tokens) return;

    // Stage this token's input vector in threadgroup memory.
    threadgroup float vec_tile[VEC_TILE_SIZE];
    device const float* input = inputs + token * cols;
    for (uint i = tid; i < cols; i += 256) {
        vec_tile[i] = input[i];
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // No barrier follows, so simdgroups with no rows to compute may leave now.
    uint row_base = tg_in_token * Q4_ROWS_PER_TG + simd_id * Q4_ROWS_PER_SG;
    if (row_base >= rows) return;

    uint blocks_per_row = cols / 32;
    uint row_bytes = blocks_per_row * 18;

    // When `rows` is not a multiple of 4, the last group has fewer than four
    // valid rows. Clamp the surplus row indices to the last valid row so the
    // loads below never read past the end of the matrix; the sums computed for
    // clamped rows are discarded by the guarded stores at the end.
    uint last_row = rows - 1;  // rows >= 1 here because row_base < rows
    device const uchar* r0 = matrix + row_base * row_bytes;
    device const uchar* r1 = matrix + min(row_base + 1, last_row) * row_bytes;
    device const uchar* r2 = matrix + min(row_base + 2, last_row) * row_bytes;
    device const uchar* r3 = matrix + min(row_base + 3, last_row) * row_bytes;

    float sum0 = 0.0, sum1 = 0.0, sum2 = 0.0, sum3 = 0.0;
    for (uint blk = simd_lane; blk < blocks_per_row; blk += 32) {
        uint bb = blk * 18;  // byte offset of this block within a row
        uint vb = blk * 32;  // element offset of this block within the input
        float sc0 = float(*(device const half*)(r0 + bb));
        float sc1 = float(*(device const half*)(r1 + bb));
        float sc2 = float(*(device const half*)(r2 + bb));
        float sc3 = float(*(device const half*)(r3 + bb));
        // 16 packed bytes per block, read as four groups of 4 bytes.
        device const packed_uchar4* p0 = (device const packed_uchar4*)(r0 + bb + 2);
        device const packed_uchar4* p1 = (device const packed_uchar4*)(r1 + bb + 2);
        device const packed_uchar4* p2 = (device const packed_uchar4*)(r2 + bb + 2);
        device const packed_uchar4* p3 = (device const packed_uchar4*)(r3 + bb + 2);
        // v0..v3 hold inputs 0..15 (paired with low nibbles),
        // v4..v7 hold inputs 16..31 (paired with high nibbles).
        float4 v0 = *(threadgroup const float4*)(vec_tile + vb);
        float4 v1 = *(threadgroup const float4*)(vec_tile + vb + 4);
        float4 v2 = *(threadgroup const float4*)(vec_tile + vb + 8);
        float4 v3 = *(threadgroup const float4*)(vec_tile + vb + 12);
        float4 v4 = *(threadgroup const float4*)(vec_tile + vb + 16);
        float4 v5 = *(threadgroup const float4*)(vec_tile + vb + 20);
        float4 v6 = *(threadgroup const float4*)(vec_tile + vb + 24);
        float4 v7 = *(threadgroup const float4*)(vec_tile + vb + 28);
        // Unscaled per-block dot products; the block scale is applied once below.
        float bd0=0, bd1=0, bd2=0, bd3=0;
        uchar4 b;
        b=p0[0]; bd0+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p0[1]; bd0+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p0[2]; bd0+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p0[3]; bd0+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
        b=p1[0]; bd1+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p1[1]; bd1+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p1[2]; bd1+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p1[3]; bd1+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
        b=p2[0]; bd2+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p2[1]; bd2+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p2[2]; bd2+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p2[3]; bd2+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
        b=p3[0]; bd3+=float(int(b.x&0xF)-8)*v0.x+float(int(b.x>>4)-8)*v4.x+float(int(b.y&0xF)-8)*v0.y+float(int(b.y>>4)-8)*v4.y+float(int(b.z&0xF)-8)*v0.z+float(int(b.z>>4)-8)*v4.z+float(int(b.w&0xF)-8)*v0.w+float(int(b.w>>4)-8)*v4.w;
        b=p3[1]; bd3+=float(int(b.x&0xF)-8)*v1.x+float(int(b.x>>4)-8)*v5.x+float(int(b.y&0xF)-8)*v1.y+float(int(b.y>>4)-8)*v5.y+float(int(b.z&0xF)-8)*v1.z+float(int(b.z>>4)-8)*v5.z+float(int(b.w&0xF)-8)*v1.w+float(int(b.w>>4)-8)*v5.w;
        b=p3[2]; bd3+=float(int(b.x&0xF)-8)*v2.x+float(int(b.x>>4)-8)*v6.x+float(int(b.y&0xF)-8)*v2.y+float(int(b.y>>4)-8)*v6.y+float(int(b.z&0xF)-8)*v2.z+float(int(b.z>>4)-8)*v6.z+float(int(b.w&0xF)-8)*v2.w+float(int(b.w>>4)-8)*v6.w;
        b=p3[3]; bd3+=float(int(b.x&0xF)-8)*v3.x+float(int(b.x>>4)-8)*v7.x+float(int(b.y&0xF)-8)*v3.y+float(int(b.y>>4)-8)*v7.y+float(int(b.z&0xF)-8)*v3.z+float(int(b.z>>4)-8)*v7.z+float(int(b.w&0xF)-8)*v3.w+float(int(b.w>>4)-8)*v7.w;
        sum0 += sc0 * bd0; sum1 += sc1 * bd1; sum2 += sc2 * bd2; sum3 += sc3 * bd3;
    }

    // Combine the per-lane partial sums across the simdgroup.
    sum0 = simd_sum(sum0); sum1 = simd_sum(sum1);
    sum2 = simd_sum(sum2); sum3 = simd_sum(sum3);

    // Lane 0 writes the results; rows past the end (clamped above) are skipped.
    device float* output = outputs + token * rows;
    if (simd_lane == 0) {
        if (row_base < rows) output[row_base] = sum0;
        if (row_base + 1 < rows) output[row_base + 1] = sum1;
        if (row_base + 2 < rows) output[row_base + 2] = sum2;
        if (row_base + 3 < rows) output[row_base + 3] = sum3;
    }
}
// ── copy_kv_batch ─────────────────────────────────────────────────────
// Scatter the K (or V) slice of every token in a batch QKV buffer into the
// KV cache, narrowing f32 → f16 on the way. One thread per copied element.
// src layout: [M, src_stride]; the slice starts src_offset floats into each row.
// dst layout: contiguous [max_seq, kv_dim] cache; token t lands at cache
//             position base_pos + t.
kernel void copy_kv_batch(
    device const float* src        [[buffer(0)]],  // batch QKV buffer (f32)
    device half*        dst        [[buffer(1)]],  // KV cache (f16)
    constant uint&      M          [[buffer(2)]],  // num tokens in batch
    constant uint&      kv_dim     [[buffer(3)]],  // elements per KV vector
    constant uint&      base_pos   [[buffer(4)]],  // starting position in cache
    constant uint&      src_stride [[buffer(5)]],  // floats per row in src (qkv_rows)
    constant uint&      src_offset [[buffer(6)]],  // float offset within each src row
    uint                id         [[thread_position_in_grid]])
{
    // Only the first M * kv_dim threads have an element to copy.
    if (id < M * kv_dim) {
        // Split the flat thread id into (token row, element within the vector).
        const uint row = id / kv_dim;
        const uint col = id % kv_dim;

        const uint cache_index = (base_pos + row) * kv_dim + col;
        const uint batch_index = row * src_stride + src_offset + col;

        dst[cache_index] = half(src[batch_index]);
    }
}
// ── attention_batch ───────────────────────────────────────────────────
// Batched causal attention for prefill. Processes M tokens in one dispatch.
// Each threadgroup handles one (token, head) pair.
// Q is read from a strided buffer (batch QKV), K/V from the KV cache.
// Causal masking: token i can only attend to positions 0..base_pos+i.
//
// Phases: (1) scores = Q·Kᵀ / sqrt(head_dim), (2) softmax over the scores,
// (3) output = scores · V. The loop strides (8 simdgroups, 256 threads) and
// the 8-entry reduction arrays mean a 256-thread threadgroup is expected.
kernel void attention_batch(
device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
constant uint& M [[buffer(4)]], // num tokens in batch
constant uint& base_pos [[buffer(5)]], // starting position in KV cache
constant uint& num_heads [[buffer(6)]],
constant uint& num_kv_heads [[buffer(7)]],
constant uint& head_dim [[buffer(8)]],
constant uint& q_stride [[buffer(9)]], // floats per row in q_batch
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
// Grid: M * num_heads threadgroups
uint token_idx = tgid / num_heads;
uint head = tgid % num_heads;
// Depends only on tgid, so the whole threadgroup exits together.
if (token_idx >= M) return;
// Grouped-query mapping: num_heads / num_kv_heads consecutive query heads
// share one KV head.
uint kv_head = head / (num_heads / num_kv_heads);
uint seq_len = base_pos + token_idx + 1; // causal: see positions 0..base_pos+token_idx
// Q offset uses strided layout (from batch QKV buffer)
uint q_off = token_idx * q_stride + head * head_dim;
// Output is contiguous [M, num_heads * head_dim]
uint out_off = token_idx * num_heads * head_dim + head * head_dim;
// Shared memory for attention scores — sized to the effective max_seq_len
// (4096 for all supported models) so long-context attention doesn't overflow.
// NOTE(review): nothing in this kernel checks seq_len against the array
// length; presumably the host keeps base_pos + M within it — confirm.
threadgroup float scores[ATTN_SCORES_SIZE];
// Step 1: Q * K^T with simdgroup reduction
// Each simdgroup takes every 8th position; its 32 lanes split head_dim and
// simd_sum combines the partial dot products.
for (uint s = simd_id; s < seq_len; s += 8) {
uint k_off = s * num_kv_heads * head_dim + kv_head * head_dim;
float dot = 0.0;
for (uint d = simd_lane; d < head_dim; d += 32) {
dot += q_batch[q_off + d] * k_cache[k_off + d];
}
dot = simd_sum(dot);
if (simd_lane == 0) {
// Scale by 1 / sqrt(head_dim).
scores[s] = dot * fast::rsqrt(float(head_dim));
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Step 2: Softmax (cooperative)
// 2a: global max — per-thread, then per-simdgroup, then thread 0 merges the
// eight simdgroup maxima into shared_max[0].
float local_max = -INFINITY;
for (uint s = tid; s < seq_len; s += 256) {
local_max = max(local_max, scores[s]);
}
local_max = simd_max(local_max);
threadgroup float shared_max[8];
if (simd_lane == 0) shared_max[simd_id] = local_max;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float m = shared_max[0];
for (uint i = 1; i < 8; i++) m = max(m, shared_max[i]);
shared_max[0] = m;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float max_val = shared_max[0];
// 2b: exponentiate in place (shifted by the max) and sum, using the same
// thread → simdgroup → thread-0 reduction.
float local_sum = 0.0;
for (uint s = tid; s < seq_len; s += 256) {
scores[s] = fast::exp(scores[s] - max_val);
local_sum += scores[s];
}
local_sum = simd_sum(local_sum);
threadgroup float shared_sum[8];
if (simd_lane == 0) shared_sum[simd_id] = local_sum;
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float total = 0.0;
for (uint i = 0; i < 8; i++) total += shared_sum[i];
// Store the reciprocal so normalization is a multiply; a zero total
// yields 0 instead of a division by zero.
shared_sum[0] = (total > 0.0) ? (1.0 / total) : 0.0;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float inv_sum = shared_sum[0];
// 2c: normalize the scores in place.
for (uint s = tid; s < seq_len; s += 256) {
scores[s] *= inv_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// Step 3: scores * V using float4 vectorized loads
// With head_dim=64, processing 4 dims per thread means 16 threads cover all dims.
// This is much better than the scalar version where only 64 of 256 threads are active.
// NOTE(review): the half4 / float4 pointer casts below presumably require
// the element offsets to be multiples of 4, which holds when head_dim is a
// multiple of 4 — confirm the host enforces this.
uint v_stride = num_kv_heads * head_dim;
uint head_dim4 = head_dim / 4;
for (uint d4 = tid; d4 < head_dim4; d4 += 256) {
uint d = d4 * 4;
float4 acc = float4(0.0);
uint v_base = kv_head * head_dim + d;
// Positions are consumed four at a time; seq_len4 is seq_len rounded down
// to a multiple of 4.
uint seq_len4 = seq_len & ~3u;
for (uint s = 0; s < seq_len4; s += 4) {
float sc0 = scores[s];
float sc1 = scores[s + 1];
float sc2 = scores[s + 2];
float sc3 = scores[s + 3];
acc += sc0 * float4(*(device const half4*)(v_cache + s * v_stride + v_base))
+ sc1 * float4(*(device const half4*)(v_cache + (s+1) * v_stride + v_base))
+ sc2 * float4(*(device const half4*)(v_cache + (s+2) * v_stride + v_base))
+ sc3 * float4(*(device const half4*)(v_cache + (s+3) * v_stride + v_base));
}
// Leftover 0..3 positions.
for (uint s = seq_len4; s < seq_len; s++) {
acc += scores[s] * float4(*(device const half4*)(v_cache + s * v_stride + v_base));
}
*(device float4*)(output_batch + out_off + d) = acc;
}
// Handle remaining dimensions not divisible by 4 (scalar fallback)
for (uint d = head_dim4 * 4 + tid; d < head_dim; d += 256) {
float acc = 0.0;
uint v_base = kv_head * head_dim + d;
for (uint s = 0; s < seq_len; s++) {
acc += scores[s] * v_cache[s * v_stride + v_base];
}
output_batch[out_off + d] = acc;
}
}
// ── attention_flash_batch ─────────────────────────────────────────────
// Streaming attention with online softmax. Same grid as attention_batch
// (M × num_heads threadgroups, one per (token, head) pair) but the scores
// matrix is never materialized. K/V positions are processed in a tile of
// FLASH_K_TILE at a time, and the running (m, l, O) tuple is updated via
// the standard flash-attention recurrence:
//
// m_new = max(m_old, tile_max)
// alpha = exp(m_old - m_new)
// l_new = alpha * l_old + sum(exp(S - m_new))
// O_new = alpha * O_old + sum(exp(S - m_new) * V)
// O_final = O / l
//
// This removes the fixed-size `scores[]` cap in attention_batch (which
// silently overflows once seq_len exceeds that array's length) and keeps
// threadgroup memory use to O(head_dim + FLASH_K_TILE) instead of O(seq_len).
//
// Assumptions: head_dim ≤ 256 (Llama/Qwen/Mistral/Phi-3 all satisfy this).
// Number of K/V positions scored per iteration of the tile loop; also the size of scores_sh.
constant constexpr uint FLASH_K_TILE = 32;
// Capacity of the q_sh / o_sh threadgroup arrays, i.e. the largest head_dim they can hold.
constant constexpr uint FLASH_MAX_HEAD_DIM = 256;
// attention_flash_batch: causal attention for a batch of M query tokens.
// One threadgroup handles one (token, head) pair (tgid = token_idx * num_heads + head),
// reads Q as f32 from q_batch and K/V as f16 from the caches, and writes the
// normalized f32 result for that pair to output_batch.
// NOTE(review): every cooperative loop strides by 256 and the cross-simdgroup
// reductions read 8 scratch slots, so this kernel appears to expect 256 threads
// (8 simdgroups of 32 lanes) per threadgroup — confirm at the dispatch site.
kernel void attention_flash_batch(
device const float* q_batch [[buffer(0)]], // batch QKV buf (strided)
device const half* k_cache [[buffer(1)]], // [max_seq, num_kv_heads * head_dim] f16
device const half* v_cache [[buffer(2)]], // [max_seq, num_kv_heads * head_dim] f16
device float* output_batch [[buffer(3)]], // [M, num_heads * head_dim]
constant uint& M [[buffer(4)]],
constant uint& base_pos [[buffer(5)]],
constant uint& num_heads [[buffer(6)]],
constant uint& num_kv_heads [[buffer(7)]],
constant uint& head_dim [[buffer(8)]],
constant uint& q_stride [[buffer(9)]],
uint tgid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simd_lane [[thread_index_in_simdgroup]],
uint simd_id [[simdgroup_index_in_threadgroup]])
{
uint token_idx = tgid / num_heads;
uint head = tgid % num_heads;
// The condition depends only on tgid, so the whole threadgroup exits together
// and no thread is left waiting at a barrier below.
if (token_idx >= M) return;
// Grouped-query mapping: each run of (num_heads / num_kv_heads) query heads shares one KV head.
uint kv_head = head / (num_heads / num_kv_heads);
uint seq_len = base_pos + token_idx + 1; // causal: attend to [0, base_pos + token_idx]
uint q_off = token_idx * q_stride + head * head_dim;
uint out_off = token_idx * num_heads * head_dim + head * head_dim;
// Threadgroup state:
// q_sh: Q vector for this (token, head), loaded once
// o_sh: running output vector, updated each K tile
// scores_sh: scores for the current K tile only
// stats: [running max, running sum] (see flash-attention recurrence)
threadgroup float q_sh[FLASH_MAX_HEAD_DIM];
threadgroup float o_sh[FLASH_MAX_HEAD_DIM];
threadgroup float scores_sh[FLASH_K_TILE];
threadgroup float stats[2];
threadgroup float sg_scratch[8]; // simdgroup-level reduction buffer
// --- Load Q (one row) and zero the running O ---
for (uint d = tid; d < head_dim; d += 256) {
q_sh[d] = q_batch[q_off + d];
o_sh[d] = 0.0f;
}
if (tid == 0) {
stats[0] = -INFINITY;
stats[1] = 0.0f;
}
// Publish q_sh, o_sh and stats to the whole threadgroup before the tile loop.
threadgroup_barrier(mem_flags::mem_threadgroup);
float scale = fast::rsqrt(float(head_dim));
// Row stride and per-head column offset inside the K/V caches (K and V share this layout).
uint v_stride = num_kv_heads * head_dim;
uint v_base = kv_head * head_dim;
// --- Stream K/V in FLASH_K_TILE chunks, updating (m, l, O) each iteration ---
for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_K_TILE) {
// The last tile may be partial.
uint tile_n = min((uint)FLASH_K_TILE, seq_len - kv_base);
// [1] Compute scores for this tile: scores[ti] = dot(q, k[kv_base+ti]) * scale.
// 8 simdgroups cover up to FLASH_K_TILE/8 positions each, 32 lanes reduce head_dim.
for (uint ti = simd_id; ti < tile_n; ti += 8) {
uint k_off = (kv_base + ti) * v_stride + v_base; // same layout as V stride
float dot = 0.0f;
for (uint d = simd_lane; d < head_dim; d += 32) {
dot += q_sh[d] * k_cache[k_off + d];
}
dot = simd_sum(dot);
if (simd_lane == 0) {
scores_sh[ti] = dot * scale;
}
}
// scores_sh is complete for this tile once every simdgroup has passed this barrier.
threadgroup_barrier(mem_flags::mem_threadgroup);
// [2] Tile max via cooperative reduction.
// Threads with no score to read contribute -INFINITY, the identity for max.
float local_max = -INFINITY;
for (uint s = tid; s < tile_n; s += 256) {
local_max = max(local_max, scores_sh[s]);
}
local_max = simd_max(local_max);
if (simd_lane == 0) {
sg_scratch[simd_id] = local_max;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// [3] Merge with running max, compute alpha, rescale running l.
float m_new;
float alpha;
if (tid == 0) {
float tile_max = sg_scratch[0];
for (uint i = 1; i < 8; i++) tile_max = max(tile_max, sg_scratch[i]);
float m_old = stats[0];
m_new = max(m_old, tile_max);
// First iteration: m_old = -inf → alpha = 0 (reset O).
alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
stats[0] = m_new;
stats[1] *= alpha;
// Broadcast via sg_scratch.
// Thread 0 has already consumed the per-simdgroup maxima above, so slots 0 and 1 are free.
sg_scratch[0] = alpha;
sg_scratch[1] = m_new;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
alpha = sg_scratch[0];
m_new = sg_scratch[1];
// [4] Rescale running output by alpha, then compute exp(scores - m_new).
for (uint d = tid; d < head_dim; d += 256) {
o_sh[d] *= alpha;
}
for (uint s = tid; s < tile_n; s += 256) {
scores_sh[s] = fast::exp(scores_sh[s] - m_new);
}
// After this barrier every thread has read alpha/m_new, so step [5] may reuse sg_scratch.
threadgroup_barrier(mem_flags::mem_threadgroup);
// [5] Tile sum → update running l.
float local_sum = 0.0f;
for (uint s = tid; s < tile_n; s += 256) {
local_sum += scores_sh[s];
}
local_sum = simd_sum(local_sum);
if (simd_lane == 0) {
sg_scratch[simd_id] = local_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (tid == 0) {
float tile_sum = 0.0f;
for (uint i = 0; i < 8; i++) tile_sum += sg_scratch[i];
stats[1] += tile_sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
// [6] Accumulate P @ V into o_sh: o_sh[d] += sum_s P[s] * V[kv_base+s, d]
// Each thread owns whole output dimensions, so no reduction across threads is needed here.
for (uint d = tid; d < head_dim; d += 256) {
float acc = 0.0f;
for (uint s = 0; s < tile_n; s++) {
acc += scores_sh[s] * v_cache[(kv_base + s) * v_stride + v_base + d];
}
o_sh[d] += acc;
}
// scores_sh and sg_scratch are overwritten by the next tile; wait until all reads are done.
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// --- Normalize and write output ---
// stats[1] holds the accumulated softmax denominator l; emit zeros rather than divide by zero.
float inv_l = (stats[1] > 0.0f) ? (1.0f / stats[1]) : 0.0f;
for (uint d = tid; d < head_dim; d += 256) {
output_batch[out_off + d] = o_sh[d] * inv_l;
}
}
// ── attention_mma_flash_batch ─────────────────────────────────────────
// MMA-accelerated flash attention using simdgroup_matrix<half, 8, 8> for
// both Q·K^T and P·V. Processes Q_BLOCK=8 tokens of one head per
// threadgroup (vs 1 token per TG in attention_flash_batch), amortizing
// K/V loads across 8 Q rows and using hardware matrix-multiply for the
// arithmetic.
//
// Grid: [ceil(M / 8), num_heads, 1], 128 threads (4 simdgroups) per TG.
// Requires head_dim ≤ FLASH_MMA_MAX_HEAD_DIM (128). Dispatch falls back
// to attention_batch / attention_flash_batch otherwise.
// head_dim must also be a multiple of 8: both MMA phases walk head_dim in
// 8-wide tiles. It does NOT have to be a multiple of 32 — Phase 4 deals the
// output tiles round-robin to the simdgroups.
//
// Buffers: q_batch is f32 with q_stride floats per token row; k_cache and
// v_cache are f16 with num_kv_heads * head_dim elements per position;
// output_batch is f32 [M, num_heads * head_dim].
//
// Online softmax recurrence is identical to attention_flash_batch but
// per-Q-row: each K tile updates m[q], l[q], O[q] for q in 0..8.
constant constexpr uint FLASH_MMA_Q_BLOCK = 8;
constant constexpr uint FLASH_MMA_K_BLOCK = 32;
constant constexpr uint FLASH_MMA_MAX_HEAD_DIM = 128;
kernel void attention_mma_flash_batch(
    device const float* q_batch [[buffer(0)]],
    device const half* k_cache [[buffer(1)]],
    device const half* v_cache [[buffer(2)]],
    device float* output_batch [[buffer(3)]],
    constant uint& M [[buffer(4)]],
    constant uint& base_pos [[buffer(5)]],
    constant uint& num_heads [[buffer(6)]],
    constant uint& num_kv_heads [[buffer(7)]],
    constant uint& head_dim [[buffer(8)]],
    constant uint& q_stride [[buffer(9)]],
    uint2 tgid [[threadgroup_position_in_grid]],
    uint tid [[thread_index_in_threadgroup]],
    uint simd_lane [[thread_index_in_simdgroup]],
    uint simd_id [[simdgroup_index_in_threadgroup]])
{
    uint q_block_start = tgid.x * FLASH_MMA_Q_BLOCK;
    uint head = tgid.y;
    // Depends only on tgid, so the whole threadgroup exits together (no barrier is orphaned).
    if (q_block_start >= M) return;
    // The last Q block may hold fewer than Q_BLOCK real tokens; the rest are zero-padded rows.
    uint q_valid = min((uint)FLASH_MMA_Q_BLOCK, M - q_block_start);
    // Grouped-query mapping: each run of (num_heads / num_kv_heads) query heads shares one KV head.
    uint kv_head = head / (num_heads / num_kv_heads);
    // Causal: Q row q (0..q_valid-1) attends to kv_pos in [0, base_pos + q_block_start + q].
    // Max attended pos across the block = base_pos + q_block_start + q_valid - 1.
    uint seq_len = base_pos + q_block_start + q_valid;
    float scale = fast::rsqrt(float(head_dim));
    uint kv_stride = num_kv_heads * head_dim;
    uint kv_base_off = kv_head * head_dim;
    // ── Threadgroup memory ──
    // q_sh: [Q_BLOCK, head_dim] half — Q tile, loaded once
    // k_sh: [K_BLOCK, head_dim] half — K tile, refreshed per kv_base iter
    // v_sh: [K_BLOCK, head_dim] half — V tile, refreshed per kv_base iter
    // s_sh: [Q_BLOCK, K_BLOCK] float — raw Q·K^T scores, then scaled+masked
    // p_sh: [Q_BLOCK, K_BLOCK] half — softmax probabilities (for P·V MMA)
    // o_sh: [Q_BLOCK, head_dim] float — running output accumulator
    // m_sh: [Q_BLOCK] float — running max per Q row
    // l_sh: [Q_BLOCK] float — running softmax denominator per Q row
    // scratch: 4*Q_BLOCK floats — per-row reduction scratch
    threadgroup half q_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half k_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup half v_sh[FLASH_MMA_K_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float s_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup half p_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK];
    threadgroup float o_sh[FLASH_MMA_Q_BLOCK * FLASH_MMA_MAX_HEAD_DIM];
    threadgroup float m_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float l_sh[FLASH_MMA_Q_BLOCK];
    threadgroup float scratch[4 * FLASH_MMA_Q_BLOCK];
    // ── Load Q tile (Q_BLOCK rows, head_dim cols), init o_sh=0, m_sh=-INF, l_sh=0 ──
    uint qblock_elems = FLASH_MMA_Q_BLOCK * head_dim;
    for (uint i = tid; i < qblock_elems; i += 128) {
        uint q = i / head_dim;
        uint d = i % head_dim;
        if (q < q_valid) {
            uint q_off = (q_block_start + q) * q_stride + head * head_dim + d;
            q_sh[q * head_dim + d] = half(q_batch[q_off]);
        } else {
            q_sh[q * head_dim + d] = half(0);
        }
        o_sh[q * head_dim + d] = 0.0f;
    }
    if (tid < FLASH_MMA_Q_BLOCK) {
        m_sh[tid] = -INFINITY;
        l_sh[tid] = 0.0f;
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    // ── Stream K/V in K_BLOCK chunks ──
    for (uint kv_base = 0; kv_base < seq_len; kv_base += FLASH_MMA_K_BLOCK) {
        uint tile_n = min((uint)FLASH_MMA_K_BLOCK, seq_len - kv_base);
        // Load K and V tile into TG memory (as half).
        // Rows past tile_n are zero-filled so the MMA never reads stale data.
        uint kv_tile_elems = FLASH_MMA_K_BLOCK * head_dim;
        for (uint i = tid; i < kv_tile_elems; i += 128) {
            uint k_pos = i / head_dim;
            uint d = i % head_dim;
            if (k_pos < tile_n) {
                uint off = (kv_base + k_pos) * kv_stride + kv_base_off + d;
                k_sh[k_pos * head_dim + d] = half(k_cache[off]);
                v_sh[k_pos * head_dim + d] = half(v_cache[off]);
            } else {
                k_sh[k_pos * head_dim + d] = half(0);
                v_sh[k_pos * head_dim + d] = half(0);
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // ── Phase 1: S = Q @ K^T via MMA ──
        // 4 simdgroups × 1 tile each → 4 tiles of [8,8] covering [Q_BLOCK=8, K_BLOCK=32].
        // Each simdgroup owns S columns [simd_id*8, simd_id*8+8).
        // Q is [Q_BLOCK, head_dim]; K is [K_BLOCK, head_dim]; we want K^T via transposed load.
        {
            simdgroup_matrix<float, 8, 8> C = simdgroup_matrix<float, 8, 8>(0.0f);
            uint dim_chunks = head_dim / 8;
            for (uint dc = 0; dc < dim_chunks; dc++) {
                simdgroup_matrix<half, 8, 8> A, B;
                // A = Q[0:8, dc*8 : dc*8+8] (rows of Q, no transpose)
                simdgroup_load(A, q_sh + dc * 8, head_dim, ulong2(0, 0), false);
                // B = K^T[dc*8 : dc*8+8, simd_id*8 : simd_id*8+8]
                // K in TG mem is laid out [K_BLOCK, head_dim]. We load the tile
                // K[simd_id*8 : simd_id*8+8, dc*8 : dc*8+8] (stride=head_dim) with
                // transpose=true, which places it in the register as K^T of that sub-block.
                simdgroup_load(B,
                               k_sh + (simd_id * 8) * head_dim + dc * 8,
                               head_dim, ulong2(0, 0), true);
                simdgroup_multiply_accumulate(C, A, B, C);
            }
            // Store S tile into s_sh[0..8, simd_id*8..simd_id*8+8], stride=K_BLOCK.
            simdgroup_store(C, s_sh + simd_id * 8, FLASH_MMA_K_BLOCK);
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // ── Phase 2a: Apply scale + causal mask in place on s_sh ──
        // s_sh is [Q_BLOCK=8, K_BLOCK=32] = 256 elements; 128 threads → 2 each.
        uint s_elems = FLASH_MMA_Q_BLOCK * FLASH_MMA_K_BLOCK;
        for (uint i = tid; i < s_elems; i += 128) {
            uint q = i / FLASH_MMA_K_BLOCK;
            uint k = i % FLASH_MMA_K_BLOCK;
            uint global_q = q_block_start + q;
            uint global_kv = kv_base + k;
            // Masked out: padded Q rows, padded K columns, and future positions.
            bool valid = (q < q_valid) && (k < tile_n) && (global_kv <= base_pos + global_q);
            s_sh[i] = valid ? (s_sh[i] * scale) : -INFINITY;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // ── Phase 2b: per-row max via simdgroup reduction ──
        // 4 simdgroups × 2 rows each = 8 rows (= Q_BLOCK).
        // simd_lane (0..31) covers all K_BLOCK=32 positions in one pass.
        {
            uint row_base = simd_id * 2;
            for (uint qr = 0; qr < 2; qr++) {
                uint q = row_base + qr;
                float my = s_sh[q * FLASH_MMA_K_BLOCK + simd_lane];
                float row_max = simd_max(my);
                if (simd_lane == 0) {
                    scratch[q] = row_max; // tile_max[q]
                }
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // ── Phase 2c: update m, alpha, rescale l; publish m_new and alpha ──
        if (tid < FLASH_MMA_Q_BLOCK) {
            uint q = tid;
            float m_old = m_sh[q];
            float tile_max = scratch[q];
            float m_new = max(m_old, tile_max);
            // While nothing has been accumulated yet (m_old = -inf), alpha = 0.
            float alpha = (m_old == -INFINITY) ? 0.0f : fast::exp(m_old - m_new);
            m_sh[q] = m_new;
            l_sh[q] = l_sh[q] * alpha;
            // scratch[q] = m_new (for phase 2d)
            // scratch[Q_BLOCK + q] = alpha (for phase 3)
            scratch[q] = m_new;
            scratch[FLASH_MMA_Q_BLOCK + q] = alpha;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // ── Phase 2d: P = exp(S - m_new), populate p_sh (half) and row-sum ──
        {
            uint row_base = simd_id * 2;
            for (uint qr = 0; qr < 2; qr++) {
                uint q = row_base + qr;
                float m_new = scratch[q];
                float p = fast::exp(s_sh[q * FLASH_MMA_K_BLOCK + simd_lane] - m_new);
                p_sh[q * FLASH_MMA_K_BLOCK + simd_lane] = half(p);
                float row_sum = simd_sum(p);
                if (simd_lane == 0) {
                    scratch[2 * FLASH_MMA_Q_BLOCK + q] = row_sum;
                }
            }
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // ── Phase 2e: l_sh += tile_sum ──
        if (tid < FLASH_MMA_Q_BLOCK) {
            uint q = tid;
            l_sh[q] += scratch[2 * FLASH_MMA_Q_BLOCK + q];
        }
        // ── Phase 3: Rescale o_sh[q,:] *= alpha[q] ──
        threadgroup_barrier(mem_flags::mem_threadgroup);
        for (uint i = tid; i < qblock_elems; i += 128) {
            uint q = i / head_dim;
            float alpha = scratch[FLASH_MMA_Q_BLOCK + q];
            o_sh[i] *= alpha;
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
        // ── Phase 4: O += P @ V via MMA ──
        // P is [Q_BLOCK=8, K_BLOCK=32] half; V is [K_BLOCK=32, head_dim] half.
        // The output has head_dim / 8 column tiles of width 8. They are dealt
        // round-robin to the 4 simdgroups (tile t goes to simdgroup t % 4), so
        // every tile is covered for any head_dim that is a multiple of 8
        // (e.g. head_dim=80 → 10 tiles split 3/3/2/2). A contiguous head_dim/4
        // split per simdgroup would leave columns untouched whenever head_dim
        // is not a multiple of 32. Tiles are disjoint, so simdgroups never
        // write the same o_sh element.
        {
            uint d_tiles = head_dim / 8;
            for (uint t = simd_id; t < d_tiles; t += 4) {
                uint d_base = t * 8;
                simdgroup_matrix<float, 8, 8> O_acc;
                simdgroup_load(O_acc, o_sh + d_base, head_dim, ulong2(0, 0), false);
                uint k_chunks = FLASH_MMA_K_BLOCK / 8; // 4
                for (uint kc = 0; kc < k_chunks; kc++) {
                    simdgroup_matrix<half, 8, 8> A, B;
                    simdgroup_load(A, p_sh + kc * 8, FLASH_MMA_K_BLOCK,
                                   ulong2(0, 0), false);
                    simdgroup_load(B, v_sh + (kc * 8) * head_dim + d_base, head_dim,
                                   ulong2(0, 0), false);
                    simdgroup_multiply_accumulate(O_acc, A, B, O_acc);
                }
                simdgroup_store(O_acc, o_sh + d_base, head_dim);
            }
        }
        // k_sh / v_sh / s_sh / p_sh / scratch are overwritten by the next tile.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }
    // ── Finalize: O /= l, write to output ──
    // Padded rows (q >= q_valid) are never written back.
    for (uint i = tid; i < qblock_elems; i += 128) {
        uint q = i / head_dim;
        uint d = i % head_dim;
        if (q < q_valid) {
            float l = l_sh[q];
            float inv_l = (l > 0.0f) ? (1.0f / l) : 0.0f;
            uint token_idx = q_block_start + q;
            uint out_off = token_idx * num_heads * head_dim + head * head_dim + d;
            output_batch[out_off] = o_sh[i] * inv_l;
        }
    }
}
// ── rope_qk_batch ─────────────────────────────────────────────────────
// Rotary position embedding applied to Q and K in one dispatch, which saves
// one kernel launch + memory barrier per layer compared with rotating them
// separately. Q and K share the qkv_data buffer; within each token row the
// Q heads start at offset 0 (num_q_heads heads) and the K heads start at
// offset num_q_heads*head_dim (num_kv_heads heads).
//
// One thread rotates one adjacent element pair (2i, 2i+1) of one head of one
// token; threads whose token index is >= M do nothing.
kernel void rope_qk_batch(
    device float* qkv_data [[buffer(0)]], // [M, qkv_stride]
    constant uint& M [[buffer(1)]], // num tokens
    constant uint& base_pos [[buffer(2)]], // starting position
    constant uint& num_q_heads [[buffer(3)]],
    constant uint& num_kv_heads [[buffer(4)]],
    constant uint& head_dim [[buffer(5)]],
    constant uint& qkv_stride [[buffer(6)]], // floats per row
    constant float& theta [[buffer(7)]],
    uint id [[thread_position_in_grid]])
{
    uint pairs_per_head = head_dim / 2;
    uint pairs_per_token = (num_q_heads + num_kv_heads) * pairs_per_head;
    uint token = id / pairs_per_token;
    uint pair = id % pairs_per_token;
    if (token >= M) return;

    // Pairs [0, q_pair_count) belong to Q heads, the rest to K heads.
    uint q_pair_count = num_q_heads * pairs_per_head;
    bool in_k = pair >= q_pair_count;
    uint section_pair = in_k ? (pair - q_pair_count) : pair;
    uint section_start = in_k ? (num_q_heads * head_dim) : 0u;

    uint h = section_pair / pairs_per_head;
    uint i = section_pair % pairs_per_head;
    uint offset = token * qkv_stride + section_start + h * head_dim + i * 2;

    // Rotation angle for pair i at absolute position base_pos + token.
    uint pos = base_pos + token;
    float freq = 1.0f / pow(theta, 2.0f * float(i) / float(head_dim));
    float angle = float(pos) * freq;
    float cos_val = cos(angle);
    float sin_val = sin(angle);

    // Read both elements before writing either one: the rotation is in place.
    float x0 = qkv_data[offset];
    float x1 = qkv_data[offset + 1];
    qkv_data[offset] = x0 * cos_val - x1 * sin_val;
    qkv_data[offset + 1] = x0 * sin_val + x1 * cos_val;
}
// ── copy_kv_both_batch ────────────────────────────────────────────────
// Fused K+V cache fill: one dispatch converts the K and V slices of every
// token row in the strided f32 batch QKV buffer to f16 and stores them in
// the K cache and V cache at rows base_pos .. base_pos + M - 1.
// Saves one kernel launch + memory barrier per layer vs two copy_kv_batch calls.
//
// Thread ids [0, M*kv_dim) copy K elements, ids [M*kv_dim, 2*M*kv_dim) copy
// V elements; larger ids do nothing.
kernel void copy_kv_both_batch(
    device const float* src [[buffer(0)]], // batch QKV buffer [M, qkv_stride] f32
    device half* k_dst [[buffer(1)]], // K cache [max_seq, kv_dim] f16
    device half* v_dst [[buffer(2)]], // V cache [max_seq, kv_dim] f16
    constant uint& M [[buffer(3)]], // num tokens in batch
    constant uint& kv_dim [[buffer(4)]], // elements per KV vector
    constant uint& base_pos [[buffer(5)]], // starting position in cache
    constant uint& src_stride [[buffer(6)]], // floats per row in src (qkv_stride)
    constant uint& k_offset [[buffer(7)]], // float offset of K within each src row
    constant uint& v_offset [[buffer(8)]], // float offset of V within each src row
    uint id [[thread_position_in_grid]])
{
    uint elems_per_cache = M * kv_dim;
    if (id >= elems_per_cache * 2) return;

    // Past the bounds check id < 2 * elems_per_cache, so a compare and a
    // subtract give the same split as id / elems_per_cache and id % elems_per_cache.
    bool to_v = id >= elems_per_cache;
    uint elem = to_v ? (id - elems_per_cache) : id;

    uint token = elem / kv_dim;
    uint d = elem % kv_dim;
    uint row_offset = to_v ? v_offset : k_offset;
    uint src_off = token * src_stride + row_offset + d;
    uint dst_off = (base_pos + token) * kv_dim + d;

    half value = half(src[src_off]);
    if (to_v) {
        v_dst[dst_off] = value;
    } else {
        k_dst[dst_off] = value;
    }
}
"#
.replace("VEC_TILE_SIZE", &vec_tile_size.to_string())
.replace("ATTN_SCORES_SIZE", &attn_scores_size.to_string())
}
/// Build the complete source text of the generated `model.rs`.
///
/// The sections are appended in a fixed order: header, `MetalModel` struct,
/// `LayerBuffers` struct, `MetalModel` impl, helper functions.
///
/// # Errors
///
/// Returns the first error reported by any of the section emitters.
fn generate_model_rs(config: &ModelConfig) -> Result<String, MetalCodegenError> {
    // Start with 48 KiB reserved so the emitters append without early reallocations.
    let mut source = String::with_capacity(48 * 1024);
    let out = &mut source;

    emit_model_header(out, config)?;
    emit_metal_model_struct(out, config)?;
    emit_layer_buffers_struct(out, config)?;
    emit_metal_model_impl(out, config)?;
    emit_helper_functions(out)?;

    Ok(source)
}
/// Append the top of the generated `model.rs` to `code`.
///
/// Emits, in order: the `//!` module docs (including architecture, layer count
/// and hidden size), the `dead_code` allow attribute, the `metal` / `std::mem`
/// imports, and one `pub const` per model dimension. `MAX_SEQ_LEN` is capped
/// at 4096 and the uncapped value is recorded in a trailing comment; the two
/// `f32` constants are written in exponent notation.
///
/// # Errors
///
/// Propagates any formatting error from writing into `code`.
fn emit_model_header(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
    // Module docs, lint attribute, imports and the constants banner.
    writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
    writeln!(
        code,
        "//! Model: {} ({} layers, hidden={})",
        config.architecture, config.num_layers, config.hidden_size
    )?;
    for line in [
        "//!",
        "//! Uses native Metal compute pipelines via the metal crate.",
        "",
        "#![allow(dead_code)]",
        "",
        "use metal::*;",
        "#[allow(unused_imports)]",
        "use metal::objc::{sel, sel_impl};",
        "use std::mem;",
        "",
        "// ── Model constants ──────────────────────────────────",
    ] {
        writeln!(code, "{line}")?;
    }

    // Integer dimensions copied straight from the config, in declaration order.
    let dimensions: [(&str, &dyn std::fmt::Display); 7] = [
        ("HIDDEN_SIZE", &config.hidden_size),
        ("INTERMEDIATE_SIZE", &config.intermediate_size),
        ("NUM_LAYERS", &config.num_layers),
        ("NUM_HEADS", &config.num_attention_heads),
        ("NUM_KV_HEADS", &config.num_kv_heads),
        ("HEAD_DIM", &config.head_dim),
        ("VOCAB_SIZE", &config.vocab_size),
    ];
    for (name, value) in dimensions {
        writeln!(code, "pub const {name}: usize = {value};")?;
    }

    let capped_seq_len = config.max_seq_len.min(4096);
    writeln!(
        code,
        "pub const MAX_SEQ_LEN: usize = {}; // capped from model's {}",
        capped_seq_len, config.max_seq_len
    )?;

    // Float hyper-parameters use `{:e}` so the emitted literal is always a float.
    writeln!(
        code,
        "pub const RMS_NORM_EPS: f32 = {:e};",
        config.rms_norm_eps
    )?;
    writeln!(code, "pub const ROPE_THETA: f32 = {:e};", config.rope_theta)?;

    for line in [
        "/// Maximum batch size for batched prefill (prompt tokens processed at once).",
        "pub const MAX_BATCH_SIZE: usize = 512;",
        "",
    ] {
        writeln!(code, "{line}")?;
    }
    Ok(())
}
/// Append the declaration of the generated `MetalModel` struct to `code`.
///
/// The struct lists, in order: device and queue, the single-token compute
/// pipelines, the batched-prefill pipelines (with `add_bias_batch_pipeline`
/// inserted only when `config.qkv_bias` is set), weight buffers, working
/// buffers, batched-prefill working buffers, per-layer KV caches and the
/// inference state. A blank line follows the closing brace.
///
/// # Errors
///
/// Propagates any formatting error from writing into `code`.
fn emit_metal_model_struct(
    code: &mut String,
    config: &ModelConfig,
) -> Result<(), MetalCodegenError> {
    // Pipeline fields are `<name>_pipeline: ComputePipelineState`.
    const DECODE_PIPELINES: &[&str] = &[
        "matmul",
        "matmul_q8",
        "matmul_q4",
        "rms_norm",
        "rope",
        "softmax",
        "silu_mul",
        "silu_mul_fused",
        "add",
        "attention",
        "add_inplace",
        "copy",
        "copy_offset",
        "copy_f32_to_f16_offset",
    ];
    // Batched-prefill pipelines that precede the optional bias pipeline.
    const PREFILL_PIPELINES_HEAD: &[&str] = &[
        "matmul_batch",
        "matmul_q8_batch",
        "matmul_q8_gemm_batch",
        "matmul_q8_mma",
        "matmul_q8_mma32",
        "matmul_q8_mma32_h",
        "matmul_q8_mma32_h4",
        "matmul_q8_mma32_hh4",
    ];
    // Batched-prefill pipelines that follow the optional bias pipeline.
    const PREFILL_PIPELINES_TAIL: &[&str] = &[
        "matmul_q4_batch",
        "rms_norm_batch",
        "rope_batch",
        "silu_mul_fused_batch",
        "add_inplace_batch",
        "copy_embedding_batch",
        "attention_batch",
        "attention_flash_batch",
        "attention_mma_flash_batch",
        "copy_kv_batch",
        "rope_qk_batch",
        "copy_kv_both_batch",
    ];
    // (doc comment, field name) for each batched-prefill working buffer.
    const BATCH_BUFFERS: &[(&str, &str)] = &[
        (
            "Batch hidden states [MAX_BATCH, HIDDEN_SIZE]",
            "batch_hidden_buf",
        ),
        (
            "Batch residual states [MAX_BATCH, HIDDEN_SIZE]",
            "batch_residual_buf",
        ),
        ("Batch QKV output [MAX_BATCH, qkv_rows]", "batch_qkv_buf"),
        (
            "Batch attention output [MAX_BATCH, HIDDEN_SIZE]",
            "batch_attn_out_buf",
        ),
        (
            "Batch attention projection [MAX_BATCH, HIDDEN_SIZE]",
            "batch_attn_proj_buf",
        ),
        (
            "Batch gate+up output [MAX_BATCH, 2*INTERMEDIATE_SIZE]",
            "batch_gate_up_buf",
        ),
        (
            "Batch FFN hidden [MAX_BATCH, INTERMEDIATE_SIZE]",
            "batch_ffn_hidden_buf",
        ),
        (
            "Batch FFN output [MAX_BATCH, HIDDEN_SIZE]",
            "batch_ffn_out_buf",
        ),
        (
            "Token IDs buffer for batch embedding lookup",
            "batch_tokens_buf",
        ),
        ("Positions buffer for batch RoPE", "batch_positions_buf"),
    ];

    // Banner, doc comment, struct opening and the first two fields.
    for line in [
        "// ── MetalModel ──────────────────────────────────────────",
        "",
        "/// Metal-accelerated transformer model for Apple Silicon.",
        "///",
        "/// Uses unified memory for zero-copy weight access and native Metal",
        "/// compute pipelines for GPU-accelerated inference.",
        "pub struct MetalModel {",
        " device: Device,",
        " queue: CommandQueue,",
        "",
        " // ── Compute pipelines ──",
    ] {
        writeln!(code, "{line}")?;
    }
    for name in DECODE_PIPELINES {
        writeln!(code, " {name}_pipeline: ComputePipelineState,")?;
    }
    writeln!(code)?;

    writeln!(code, " // ── Batched prefill pipelines ──")?;
    // The bias pipeline exists only for models with a QKV bias.
    let bias_pipeline: &[&str] = if config.qkv_bias {
        &["add_bias_batch"]
    } else {
        &[]
    };
    let prefill_pipelines = PREFILL_PIPELINES_HEAD
        .iter()
        .chain(bias_pipeline)
        .chain(PREFILL_PIPELINES_TAIL);
    for name in prefill_pipelines {
        writeln!(code, " {name}_pipeline: ComputePipelineState,")?;
    }

    // Weight buffers and single-token working buffers.
    for line in [
        "",
        " // ── Weight buffers (Metal shared memory) ──",
        " /// Token embedding table [VOCAB_SIZE, HIDDEN_SIZE]",
        " embed_buf: Buffer,",
        "",
        " /// Per-layer weight buffers",
        " layers: Vec<LayerBuffers>,",
        "",
        " /// Final layer-norm weight [HIDDEN_SIZE]",
        " norm_buf: Buffer,",
        "",
        " /// LM head projection weight [VOCAB_SIZE, HIDDEN_SIZE] (f32)",
        " lm_head_buf: Buffer,",
        "",
        " // ── Working buffers (pre-allocated, reused every forward pass) ──",
        " hidden_buf: Buffer,",
        " residual_buf: Buffer,",
        " normed_buf: Buffer,",
        " /// Fused QKV output buffer [hidden + 2*kv_dim] — Q, K, V are contiguous slices",
        " qkv_buf: Buffer,",
        " attn_out_buf: Buffer,",
        " attn_proj_buf: Buffer,",
        " /// Fused gate+up output buffer [2*intermediate] — gate and up are contiguous slices",
        " gate_up_buf: Buffer,",
        " ffn_hidden_buf: Buffer,",
        " ffn_out_buf: Buffer,",
        " add_tmp_buf: Buffer,",
        " logits_buf: Buffer,",
        "",
        " // ── Batched prefill working buffers ──",
    ] {
        writeln!(code, "{line}")?;
    }
    for (doc, field) in BATCH_BUFFERS {
        writeln!(code, " /// {doc}")?;
        writeln!(code, " {field}: Buffer,")?;
    }

    // KV caches, inference state and the closing brace.
    for line in [
        "",
        " // ── KV cache buffers (per-layer) ──",
        " k_cache: Vec<Buffer>, // per-layer",
        " v_cache: Vec<Buffer>, // per-layer",
        "",
        " // ── Inference state ──",
        " pos: usize,",
        "",
        " /// Previous command buffer for double-buffered prefill.",
        " /// While the GPU executes token N, the CPU can encode token N+1.",
        " prev_cmd: Option<CommandBuffer>,",
        "}",
        "",
    ] {
        writeln!(code, "{line}")?;
    }
    Ok(())
}
/// Append the declaration of the generated `LayerBuffers` struct to `code`.
///
/// The struct always carries the attention norm, fused QKV weight, output
/// projection, FFN norm, fused gate+up weight and down projection buffers.
/// A documented `qkv_bias` field is inserted after `qkv_weight` only when
/// `config.qkv_bias` is set. A blank line follows the closing brace.
///
/// # Errors
///
/// Propagates any formatting error from writing into `code`.
fn emit_layer_buffers_struct(
    code: &mut String,
    config: &ModelConfig,
) -> Result<(), MetalCodegenError> {
    const BEFORE_BIAS: &[&str] = &[
        "/// Per-layer weight buffers for attention and FFN projections.",
        "struct LayerBuffers {",
        " attn_norm: Buffer,",
        " /// Fused Q+K+V weight [hidden+2*kv_dim, hidden] (concatenated rows)",
        " qkv_weight: Buffer,",
    ];
    const BIAS: &[&str] = &[
        " /// Fused Q+K+V bias [hidden+2*kv_dim] (f32) — Qwen2 only.",
        " qkv_bias: Buffer,",
    ];
    const AFTER_BIAS: &[&str] = &[
        " o_weight: Buffer,",
        " ffn_norm: Buffer,",
        " /// Fused gate+up weight [2*intermediate, hidden] (concatenated rows)",
        " gate_up_weight: Buffer,",
        " down_weight: Buffer,",
        "}",
        "",
    ];

    // The bias field sits between the fused QKV weight and the output projection.
    let bias_fields: &[&str] = if config.qkv_bias { BIAS } else { &[] };
    for line in BEFORE_BIAS.iter().chain(bias_fields).chain(AFTER_BIAS) {
        writeln!(code, "{line}")?;
    }
    Ok(())
}
fn emit_metal_model_impl(code: &mut String, config: &ModelConfig) -> Result<(), MetalCodegenError> {
let hidden = config.hidden_size;
let intermediate = config.intermediate_size;
let _num_layers = config.num_layers;
let _num_heads = config.num_attention_heads;
let num_kv_heads = config.num_kv_heads;
let head_dim = config.head_dim;
let vocab = config.vocab_size;
let effective_seq_len = config.max_seq_len.min(4096);
let is_q8 = config.dtype == DType::Q8_0;
let is_q4 = config.dtype == DType::Q4_0;
let kv_dim = num_kv_heads * head_dim;
writeln!(code, "impl MetalModel {{")?;
writeln!(
code,
" /// Create a new MetalModel: compile shaders, load weights, allocate buffers."
)?;
writeln!(code, " ///")?;
writeln!(
code,
" /// `weights` is the raw weight blob produced by `forge export-weights`."
)?;
writeln!(code, " pub fn new(weights: &[u8]) -> Self {{")?;
writeln!(
code,
" let device = Device::system_default().expect(\"no Metal device found\");"
)?;
writeln!(code, " let queue = device.new_command_queue();")?;
writeln!(code)?;
writeln!(
code,
" // Compile Metal shaders from embedded source"
)?;
writeln!(
code,
" let shader_source = include_str!(\"../shaders/kernels.metal\");"
)?;
writeln!(
code,
" let library = device.new_library_with_source(shader_source, &CompileOptions::new())"
)?;
writeln!(
code,
" .expect(\"failed to compile Metal shaders\");"
)?;
writeln!(code)?;
writeln!(code, " // Create compute pipelines")?;
for (var, fn_name) in [
("matmul_pipeline", "matmul_vec"),
("matmul_q8_pipeline", "matmul_vec_q8"),
("matmul_q4_pipeline", "matmul_vec_q4"),
("rms_norm_pipeline", "rms_norm"),
("rope_pipeline", "rope"),
("softmax_pipeline", "softmax"),
("silu_mul_pipeline", "silu_mul"),
("silu_mul_fused_pipeline", "silu_mul_fused"),
("add_pipeline", "elementwise_add"),
("attention_pipeline", "attention"),
("add_inplace_pipeline", "add_inplace"),
("copy_pipeline", "copy_buffer"),
("copy_offset_pipeline", "copy_offset"),
("copy_f32_to_f16_offset_pipeline", "copy_f32_to_f16_offset"),
("matmul_batch_pipeline", "matmul_vec_batch"),
("matmul_q8_batch_pipeline", "matmul_vec_q8_batch"),
("matmul_q8_gemm_batch_pipeline", "matmul_q8_gemm_batch"),
("matmul_q8_mma_pipeline", "matmul_q8_mma"),
("matmul_q8_mma32_pipeline", "matmul_q8_mma32"),
("matmul_q8_mma32_h_pipeline", "matmul_q8_mma32_h"),
("matmul_q8_mma32_h4_pipeline", "matmul_q8_mma32_h4"),
("matmul_q8_mma32_hh4_pipeline", "matmul_q8_mma32_hh4"),
("matmul_q4_batch_pipeline", "matmul_vec_q4_batch"),
("rms_norm_batch_pipeline", "rms_norm_batch"),
("rope_batch_pipeline", "rope_batch"),
("silu_mul_fused_batch_pipeline", "silu_mul_fused_batch"),
("add_inplace_batch_pipeline", "add_inplace_batch"),
("copy_embedding_batch_pipeline", "copy_embedding_batch"),
("attention_batch_pipeline", "attention_batch"),
("attention_flash_batch_pipeline", "attention_flash_batch"),
(
"attention_mma_flash_batch_pipeline",
"attention_mma_flash_batch",
),
("copy_kv_batch_pipeline", "copy_kv_batch"),
("rope_qk_batch_pipeline", "rope_qk_batch"),
("copy_kv_both_batch_pipeline", "copy_kv_both_batch"),
] {
writeln!(
code,
" let {var} = make_pipeline(&device, &library, \"{fn_name}\");"
)?;
}
if config.qkv_bias {
writeln!(
code,
" let add_bias_batch_pipeline = make_pipeline(&device, &library, \"add_bias_batch\");"
)?;
}
writeln!(code)?;
writeln!(
code,
" // Load weights into Metal shared-memory buffers"
)?;
writeln!(code, " let f32_size = mem::size_of::<f32>();")?;
writeln!(code, " let embed_elems = VOCAB_SIZE * HIDDEN_SIZE;")?;
writeln!(code, " let hidden_elems = HIDDEN_SIZE;")?;
writeln!(code)?;
writeln!(
code,
" let cursor = std::cell::Cell::new(0usize); // byte cursor into `weights`"
)?;
writeln!(code)?;
writeln!(
code,
" // Helper: read the next `n` f32s from the weight blob as a Metal buffer."
)?;
writeln!(
code,
" let next_f32_buffer = |device: &Device, n: usize| -> Buffer {{"
)?;
writeln!(code, " let byte_len = n * f32_size;")?;
writeln!(code, " let cur = cursor.get();")?;
writeln!(
code,
" let data = &weights[cur..cur + byte_len];"
)?;
writeln!(code, " cursor.set(cur + byte_len);")?;
writeln!(code, " device.new_buffer_with_data(")?;
writeln!(code, " data.as_ptr() as *const _,")?;
writeln!(code, " byte_len as u64,")?;
writeln!(
code,
" MTLResourceOptions::StorageModeShared,"
)?;
writeln!(code, " )")?;
writeln!(code, " }};")?;
writeln!(code)?;
if is_q8 {
writeln!(
code,
" // Helper: read the next Q8_0 weight matrix (rows x cols elements)"
)?;
writeln!(
code,
" // as raw bytes into a Metal buffer (no dequantization)."
)?;
writeln!(
code,
" // Q8_0 format: each block of 32 values = 2 bytes (f16 scale) + 32 bytes (int8) = 34 bytes."
)?;
writeln!(
code,
" let next_q8_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
)?;
writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
writeln!(code, " let total_raw = rows * row_bytes;")?;
writeln!(code, " let cur = cursor.get();")?;
writeln!(
code,
" let data = &weights[cur..cur + total_raw];"
)?;
writeln!(code, " cursor.set(cur + total_raw);")?;
writeln!(code, " device.new_buffer_with_data(")?;
writeln!(code, " data.as_ptr() as *const _,")?;
writeln!(code, " total_raw as u64,")?;
writeln!(
code,
" MTLResourceOptions::StorageModeShared,"
)?;
writeln!(code, " )")?;
writeln!(code, " }};")?;
writeln!(code)?;
writeln!(
code,
" // Helper: read consecutive Q8_0 weight matrices as a single fused buffer."
)?;
writeln!(
code,
" // Reads `total_rows` rows of Q8_0 data (each row has `cols` elements)."
)?;
writeln!(
code,
" // Used for fused QKV and gate+up projections where weights are adjacent in the file."
)?;
writeln!(
code,
" let next_q8_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
)?;
writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
writeln!(code, " let row_bytes = blocks_per_row * 34;")?;
writeln!(code, " let total_raw = total_rows * row_bytes;")?;
writeln!(code, " let cur = cursor.get();")?;
writeln!(
code,
" let data = &weights[cur..cur + total_raw];"
)?;
writeln!(code, " cursor.set(cur + total_raw);")?;
writeln!(code, " device.new_buffer_with_data(")?;
writeln!(code, " data.as_ptr() as *const _,")?;
writeln!(code, " total_raw as u64,")?;
writeln!(
code,
" MTLResourceOptions::StorageModeShared,"
)?;
writeln!(code, " )")?;
writeln!(code, " }};")?;
writeln!(code)?;
}
if is_q4 {
writeln!(
code,
" // Helper: read the next Q4_0 weight matrix (rows x cols elements)"
)?;
writeln!(
code,
" // as raw bytes into a Metal buffer (no dequantization)."
)?;
writeln!(
code,
" // Q4_0 format: each block of 32 values = 2 bytes (f16 scale) + 16 bytes (4-bit pairs) = 18 bytes."
)?;
writeln!(
code,
" let next_q4_buffer = |device: &Device, rows: usize, cols: usize| -> Buffer {{"
)?;
writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
writeln!(code, " let total_raw = rows * row_bytes;")?;
writeln!(code, " let cur = cursor.get();")?;
writeln!(
code,
" let data = &weights[cur..cur + total_raw];"
)?;
writeln!(code, " cursor.set(cur + total_raw);")?;
writeln!(code, " device.new_buffer_with_data(")?;
writeln!(code, " data.as_ptr() as *const _,")?;
writeln!(code, " total_raw as u64,")?;
writeln!(
code,
" MTLResourceOptions::StorageModeShared,"
)?;
writeln!(code, " )")?;
writeln!(code, " }};")?;
writeln!(code)?;
writeln!(
code,
" // Helper: read consecutive Q4_0 weight matrices as a single fused buffer."
)?;
writeln!(
code,
" // Reads `total_rows` rows of Q4_0 data (each row has `cols` elements)."
)?;
writeln!(
code,
" // Used for fused QKV and gate+up projections where weights are adjacent in the file."
)?;
writeln!(
code,
" let next_q4_fused_buffer = |device: &Device, total_rows: usize, cols: usize| -> Buffer {{"
)?;
writeln!(code, " let blocks_per_row = cols.div_ceil(32);")?;
writeln!(code, " let row_bytes = blocks_per_row * 18;")?;
writeln!(code, " let total_raw = total_rows * row_bytes;")?;
writeln!(code, " let cur = cursor.get();")?;
writeln!(
code,
" let data = &weights[cur..cur + total_raw];"
)?;
writeln!(code, " cursor.set(cur + total_raw);")?;
writeln!(code, " device.new_buffer_with_data(")?;
writeln!(code, " data.as_ptr() as *const _,")?;
writeln!(code, " total_raw as u64,")?;
writeln!(
code,
" MTLResourceOptions::StorageModeShared,"
)?;
writeln!(code, " )")?;
writeln!(code, " }};")?;
writeln!(code)?;
}
writeln!(
code,
" let embed_buf = next_f32_buffer(&device, embed_elems);"
)?;
writeln!(code)?;
writeln!(
code,
" let mut layers: Vec<LayerBuffers> = Vec::with_capacity(NUM_LAYERS);"
)?;
writeln!(code, " for _layer in 0..NUM_LAYERS {{")?;
writeln!(
code,
" let attn_norm = next_f32_buffer(&device, hidden_elems);"
)?;
let qkv_rows = hidden + 2 * kv_dim;
if is_q8 {
writeln!(
code,
" let qkv_weight = next_q8_fused_buffer(&device, {qkv_rows}, {hidden});"
)?;
if config.qkv_bias {
writeln!(
code,
" // Qwen2 QKV bias triplet (F32): {qkv_rows} floats, loaded immediately after the fused weight."
)?;
writeln!(
code,
" let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
)?;
}
writeln!(
code,
" let o_weight = next_q8_buffer(&device, {hidden}, {hidden});"
)?;
} else if is_q4 {
writeln!(
code,
" let qkv_weight = next_q4_fused_buffer(&device, {qkv_rows}, {hidden});"
)?;
if config.qkv_bias {
writeln!(
code,
" let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
)?;
}
writeln!(
code,
" let o_weight = next_q4_buffer(&device, {hidden}, {hidden});"
)?;
} else {
writeln!(
code,
" let qkv_weight = next_f32_buffer(&device, {qkv_rows} * {hidden});"
)?;
if config.qkv_bias {
writeln!(
code,
" let qkv_bias = next_f32_buffer(&device, {qkv_rows});"
)?;
}
writeln!(
code,
" let o_weight = next_f32_buffer(&device, {hidden} * {hidden});"
)?;
}
writeln!(
code,
" let ffn_norm = next_f32_buffer(&device, hidden_elems);"
)?;
let gate_up_rows = 2 * intermediate;
if is_q8 {
writeln!(
code,
" let gate_up_weight = next_q8_fused_buffer(&device, {gate_up_rows}, {hidden});"
)?;
writeln!(
code,
" let down_weight = next_q8_buffer(&device, {hidden}, {intermediate});"
)?;
} else if is_q4 {
writeln!(
code,
" let gate_up_weight = next_q4_fused_buffer(&device, {gate_up_rows}, {hidden});"
)?;
writeln!(
code,
" let down_weight = next_q4_buffer(&device, {hidden}, {intermediate});"
)?;
} else {
writeln!(
code,
" let gate_up_weight = next_f32_buffer(&device, {gate_up_rows} * {hidden});"
)?;
writeln!(
code,
" let down_weight = next_f32_buffer(&device, {hidden} * {intermediate});"
)?;
}
writeln!(code, " layers.push(LayerBuffers {{")?;
writeln!(code, " attn_norm,")?;
writeln!(code, " qkv_weight,")?;
if config.qkv_bias {
writeln!(code, " qkv_bias,")?;
}
writeln!(code, " o_weight,")?;
writeln!(code, " ffn_norm,")?;
writeln!(code, " gate_up_weight,")?;
writeln!(code, " down_weight,")?;
writeln!(code, " }});")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" let norm_buf = next_f32_buffer(&device, hidden_elems);"
)?;
writeln!(code)?;
if is_q8 {
writeln!(
code,
" let lm_head_buf = next_q8_buffer(&device, {vocab}, {hidden});"
)?;
} else if is_q4 {
writeln!(
code,
" let lm_head_buf = next_q4_buffer(&device, {vocab}, {hidden});"
)?;
} else {
writeln!(
code,
" let lm_head_buf = next_f32_buffer(&device, {vocab} * {hidden});"
)?;
}
writeln!(code)?;
let hidden_bytes = hidden * 4;
let _kv_dim_bytes = kv_dim * 4;
let intermediate_bytes = intermediate * 4;
let vocab_bytes = vocab * 4;
let kv_cache_bytes = effective_seq_len * num_kv_heads * head_dim * 2;
writeln!(
code,
" // Allocate working buffers (shared memory for zero-copy)"
)?;
writeln!(
code,
" let opts = MTLResourceOptions::StorageModeShared;"
)?;
writeln!(
code,
" let hidden_buf = device.new_buffer({hidden_bytes} as u64, opts);"
)?;
writeln!(
code,
" let residual_buf = device.new_buffer({hidden_bytes} as u64, opts);"
)?;
let qkv_buf_bytes = (hidden + 2 * kv_dim) * 4;
writeln!(
code,
" let normed_buf = device.new_buffer({hidden_bytes} as u64, opts);"
)?;
writeln!(
code,
" // Fused QKV output: [Q ({hidden}), K ({kv_dim}), V ({kv_dim})] = {qkv_rows} f32s"
)?;
writeln!(
code,
" let qkv_buf = device.new_buffer({qkv_buf_bytes} as u64, opts);"
)?;
writeln!(
code,
" let attn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
)?;
writeln!(
code,
" let attn_proj_buf = device.new_buffer({hidden_bytes} as u64, opts);"
)?;
let gate_up_buf_bytes = 2 * intermediate * 4;
writeln!(
code,
" // Fused gate+up output: [gate ({intermediate}), up ({intermediate})] = {gate_up_rows} f32s"
)?;
writeln!(
code,
" let gate_up_buf = device.new_buffer({gate_up_buf_bytes} as u64, opts);"
)?;
writeln!(
code,
" let ffn_hidden_buf = device.new_buffer({intermediate_bytes} as u64, opts);"
)?;
writeln!(
code,
" let ffn_out_buf = device.new_buffer({hidden_bytes} as u64, opts);"
)?;
writeln!(
code,
" let add_tmp_buf = device.new_buffer({hidden_bytes} as u64, opts);"
)?;
writeln!(
code,
" let logits_buf = device.new_buffer({vocab_bytes} as u64, opts);"
)?;
writeln!(code)?;
let batch_hidden_bytes = hidden * 4; let batch_qkv_bytes = (hidden + 2 * kv_dim) * 4;
let batch_gate_up_bytes = 2 * intermediate * 4;
let batch_intermediate_bytes = intermediate * 4;
writeln!(
code,
" // Batched prefill working buffers (sized for MAX_BATCH_SIZE tokens)"
)?;
writeln!(
code,
" let batch_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_residual_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_qkv_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_qkv_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_attn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_attn_proj_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_gate_up_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_gate_up_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_ffn_hidden_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_intermediate_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_ffn_out_buf = device.new_buffer((MAX_BATCH_SIZE * {batch_hidden_bytes}) as u64, opts);"
)?;
writeln!(
code,
" let batch_tokens_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
)?;
writeln!(
code,
" let batch_positions_buf = device.new_buffer((MAX_BATCH_SIZE * mem::size_of::<u32>()) as u64, opts);"
)?;
writeln!(code)?;
writeln!(code, " // KV cache buffers (per-layer)")?;
writeln!(
code,
" let mut k_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
)?;
writeln!(
code,
" let mut v_cache: Vec<Buffer> = Vec::with_capacity(NUM_LAYERS);"
)?;
writeln!(code, " for _ in 0..NUM_LAYERS {{")?;
writeln!(
code,
" k_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
)?;
writeln!(
code,
" v_cache.push(device.new_buffer({kv_cache_bytes} as u64, opts));"
)?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " Self {{")?;
writeln!(code, " device,")?;
writeln!(code, " queue,")?;
writeln!(code, " matmul_pipeline,")?;
writeln!(code, " matmul_q8_pipeline,")?;
writeln!(code, " matmul_q4_pipeline,")?;
writeln!(code, " rms_norm_pipeline,")?;
writeln!(code, " rope_pipeline,")?;
writeln!(code, " softmax_pipeline,")?;
writeln!(code, " silu_mul_pipeline,")?;
writeln!(code, " silu_mul_fused_pipeline,")?;
writeln!(code, " add_pipeline,")?;
writeln!(code, " attention_pipeline,")?;
writeln!(code, " add_inplace_pipeline,")?;
writeln!(code, " copy_pipeline,")?;
writeln!(code, " copy_offset_pipeline,")?;
writeln!(code, " copy_f32_to_f16_offset_pipeline,")?;
writeln!(code, " matmul_batch_pipeline,")?;
writeln!(code, " matmul_q8_batch_pipeline,")?;
writeln!(code, " matmul_q8_gemm_batch_pipeline,")?;
writeln!(code, " matmul_q8_mma_pipeline,")?;
writeln!(code, " matmul_q8_mma32_pipeline,")?;
writeln!(code, " matmul_q8_mma32_h_pipeline,")?;
writeln!(code, " matmul_q8_mma32_h4_pipeline,")?;
writeln!(code, " matmul_q8_mma32_hh4_pipeline,")?;
if config.qkv_bias {
writeln!(code, " add_bias_batch_pipeline,")?;
}
writeln!(code, " matmul_q4_batch_pipeline,")?;
writeln!(code, " rms_norm_batch_pipeline,")?;
writeln!(code, " rope_batch_pipeline,")?;
writeln!(code, " silu_mul_fused_batch_pipeline,")?;
writeln!(code, " add_inplace_batch_pipeline,")?;
writeln!(code, " copy_embedding_batch_pipeline,")?;
writeln!(code, " attention_batch_pipeline,")?;
writeln!(code, " attention_flash_batch_pipeline,")?;
writeln!(code, " attention_mma_flash_batch_pipeline,")?;
writeln!(code, " copy_kv_batch_pipeline,")?;
writeln!(code, " rope_qk_batch_pipeline,")?;
writeln!(code, " copy_kv_both_batch_pipeline,")?;
writeln!(code, " embed_buf,")?;
writeln!(code, " layers,")?;
writeln!(code, " norm_buf,")?;
writeln!(code, " lm_head_buf,")?;
writeln!(code, " hidden_buf,")?;
writeln!(code, " residual_buf,")?;
writeln!(code, " normed_buf,")?;
writeln!(code, " qkv_buf,")?;
writeln!(code, " attn_out_buf,")?;
writeln!(code, " attn_proj_buf,")?;
writeln!(code, " gate_up_buf,")?;
writeln!(code, " ffn_hidden_buf,")?;
writeln!(code, " ffn_out_buf,")?;
writeln!(code, " add_tmp_buf,")?;
writeln!(code, " logits_buf,")?;
writeln!(code, " batch_hidden_buf,")?;
writeln!(code, " batch_residual_buf,")?;
writeln!(code, " batch_qkv_buf,")?;
writeln!(code, " batch_attn_out_buf,")?;
writeln!(code, " batch_attn_proj_buf,")?;
writeln!(code, " batch_gate_up_buf,")?;
writeln!(code, " batch_ffn_hidden_buf,")?;
writeln!(code, " batch_ffn_out_buf,")?;
writeln!(code, " batch_tokens_buf,")?;
writeln!(code, " batch_positions_buf,")?;
writeln!(code, " k_cache,")?;
writeln!(code, " v_cache,")?;
writeln!(code, " pos: 0,")?;
writeln!(code, " prev_cmd: None,")?;
writeln!(code, " }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Run the forward pass for a single token at the current position."
)?;
writeln!(code, " ///")?;
writeln!(
code,
" /// Returns logits over the vocabulary as a `Vec<f32>`."
)?;
writeln!(code, " ///")?;
writeln!(
code,
" /// All GPU operations are encoded into a single command buffer and"
)?;
writeln!(
code,
" /// committed once at the end, avoiding per-operation synchronization."
)?;
writeln!(
code,
" pub fn forward(&mut self, token_id: u32) -> Vec<f32> {{"
)?;
writeln!(
code,
" // Wait for any pending prefill command buffer"
)?;
writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
writeln!(code, " prev.wait_until_completed();")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " let pos = self.pos;")?;
writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
writeln!(code)?;
let matmul_fn = if is_q8 {
"dispatch_matmul_q8"
} else if is_q4 {
"dispatch_matmul_q4"
} else {
"dispatch_matmul"
};
writeln!(
code,
" // Single compute encoder for the entire forward pass (no blit transitions)"
)?;
writeln!(code, " {{")?;
writeln!(
code,
" let enc = cmd.new_compute_command_encoder();"
)?;
writeln!(code)?;
writeln!(
code,
" // 1. Embedding lookup via CPU memcpy (unified memory = zero-copy, avoids GPU dispatch overhead)"
)?;
writeln!(
code,
" // On Apple Silicon, Metal shared buffers use unified memory so CPU and GPU see the"
)?;
writeln!(
code,
" // same physical memory. CPU memcpy avoids GPU dispatch + memory barrier overhead for"
)?;
writeln!(
code,
" // this tiny copy ({hidden} floats = {} bytes), reducing GPU thermal load.",
hidden * 4,
)?;
writeln!(code, " unsafe {{")?;
writeln!(
code,
" let embed_ptr = self.embed_buf.contents() as *const f32;"
)?;
writeln!(
code,
" let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
)?;
writeln!(
code,
" let residual_ptr = self.residual_buf.contents() as *mut f32;"
)?;
writeln!(code, " std::ptr::copy_nonoverlapping(")?;
writeln!(
code,
" embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
)?;
writeln!(code, " hidden_ptr,")?;
writeln!(code, " HIDDEN_SIZE,")?;
writeln!(code, " );")?;
writeln!(
code,
" std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
)?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " // 2. Transformer layers")?;
writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
writeln!(code)?;
let q_byte_offset = 0usize;
let k_float_offset = hidden;
let v_float_offset = hidden + kv_dim;
writeln!(
code,
" // Pre-attention: rms_norm, fused QKV projection, RoPE"
)?;
writeln!(
code,
" self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
)?;
writeln!(
code,
" // Fused Q+K+V matmul: single dispatch for all three projections"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
)?;
if config.qkv_bias {
writeln!(
code,
" // Qwen2: broadcast-add per-row QKV bias after the fused matmul."
)?;
writeln!(
code,
" self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
)?;
}
writeln!(
code,
" // Fused Q+K RoPE in one dispatch (saves 1 dispatch + barrier vs separate Q and K rope)"
)?;
writeln!(
code,
" self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
)?;
writeln!(code)?;
writeln!(
code,
" // Fused K+V cache update in one dispatch (f32 qkv_buf -> f16 KV cache)"
)?;
writeln!(
code,
" self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
)?;
writeln!(code)?;
writeln!(
code,
" // Attention using Q from qkv_buf (offset 0)"
)?;
writeln!(
code,
" self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
)?;
writeln!(
code,
" self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
)?;
writeln!(
code,
" // FFN: rms_norm, fused gate+up projection, silu_mul, down projection"
)?;
writeln!(
code,
" self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
)?;
writeln!(
code,
" // Fused gate+up matmul: single dispatch for both projections"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
)?;
writeln!(
code,
" self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
)?;
writeln!(
code,
" self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
)?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " // 3. Final RMS norm + logits projection")?;
writeln!(
code,
" self.dispatch_rms_norm(&enc, &self.norm_buf);"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
)?;
writeln!(code)?;
writeln!(code, " enc.end_encoding();")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" // 5. Commit all GPU work and wait for completion"
)?;
writeln!(code, " cmd.commit();")?;
writeln!(code, " cmd.wait_until_completed();")?;
writeln!(code)?;
writeln!(code, " // 6. Read back logits from GPU")?;
writeln!(code, " let logits = unsafe {{")?;
writeln!(
code,
" let ptr = self.logits_buf.contents() as *const f32;"
)?;
writeln!(
code,
" std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
)?;
writeln!(code, " }};")?;
writeln!(code)?;
writeln!(code, " self.pos += 1;")?;
writeln!(code, " logits")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Profiling forward pass that prints per-stage GPU timing."
)?;
writeln!(code, " ///")?;
writeln!(
code,
" /// Each stage is committed and waited on separately so that GPU timestamps"
)?;
writeln!(
code,
" /// accurately reflect per-operation cost. This is slower than `forward()` due"
)?;
writeln!(
code,
" /// to the per-stage synchronization, but useful for identifying bottlenecks."
)?;
writeln!(
code,
" pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32> {{"
)?;
writeln!(code, " use std::time::Instant;")?;
writeln!(code)?;
writeln!(
code,
" // Wait for any pending prefill command buffer"
)?;
writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
writeln!(code, " prev.wait_until_completed();")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " let pos = self.pos;")?;
writeln!(code)?;
writeln!(
code,
" // ── Stage: Embedding lookup (CPU via unified memory) ──"
)?;
writeln!(code, " let t_embed = Instant::now();")?;
writeln!(code, " unsafe {{")?;
writeln!(
code,
" let embed_ptr = self.embed_buf.contents() as *const f32;"
)?;
writeln!(
code,
" let hidden_ptr = self.hidden_buf.contents() as *mut f32;"
)?;
writeln!(
code,
" let residual_ptr = self.residual_buf.contents() as *mut f32;"
)?;
writeln!(code, " std::ptr::copy_nonoverlapping(")?;
writeln!(
code,
" embed_ptr.add(token_id as usize * HIDDEN_SIZE),"
)?;
writeln!(code, " hidden_ptr,")?;
writeln!(code, " HIDDEN_SIZE,")?;
writeln!(code, " );")?;
writeln!(
code,
" std::ptr::copy_nonoverlapping(hidden_ptr, residual_ptr, HIDDEN_SIZE);"
)?;
writeln!(code, " }}")?;
writeln!(code, " let d_embed = t_embed.elapsed();")?;
writeln!(code)?;
writeln!(code, " // ── Stage: Transformer layers (GPU) ──")?;
writeln!(code, " let t_layers = Instant::now();")?;
writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
writeln!(code, " {{")?;
writeln!(
code,
" let enc = cmd.new_compute_command_encoder();"
)?;
writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
writeln!(
code,
" self.dispatch_rms_norm(&enc, &self.layers[layer].attn_norm);"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].qkv_weight, &self.qkv_buf, {qkv_rows}, {hidden});"
)?;
if config.qkv_bias {
writeln!(
code,
" self.dispatch_add_bias_batch(&enc, &self.qkv_buf, &self.layers[layer].qkv_bias, 1, {qkv_rows});"
)?;
}
writeln!(
code,
" self.dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1, pos, {qkv_rows});"
)?;
writeln!(
code,
" self.dispatch_copy_kv_both_batch(&enc, &self.qkv_buf, &self.k_cache[layer], &self.v_cache[layer], 1, {kv_dim}, pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
)?;
writeln!(
code,
" self.dispatch_attention_offset(&enc, &self.qkv_buf, {q_byte_offset}, &self.k_cache[layer], &self.v_cache[layer], &self.attn_out_buf, pos + 1);"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.attn_out_buf, &self.layers[layer].o_weight, &self.attn_proj_buf, {hidden}, {hidden});"
)?;
writeln!(
code,
" self.dispatch_add_inplace(&enc, &self.residual_buf, &self.attn_proj_buf, {hidden});"
)?;
writeln!(
code,
" self.dispatch_rms_norm(&enc, &self.layers[layer].ffn_norm);"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.hidden_buf, &self.layers[layer].gate_up_weight, &self.gate_up_buf, {gate_up_rows}, {hidden});"
)?;
writeln!(
code,
" self.dispatch_silu_mul_fused(&enc, &self.gate_up_buf, &self.ffn_hidden_buf, {intermediate});"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.ffn_hidden_buf, &self.layers[layer].down_weight, &self.ffn_out_buf, {hidden}, {intermediate});"
)?;
writeln!(
code,
" self.dispatch_add_inplace(&enc, &self.residual_buf, &self.ffn_out_buf, {hidden});"
)?;
writeln!(code, " }}")?;
writeln!(code, " enc.end_encoding();")?;
writeln!(code, " }}")?;
writeln!(code, " cmd.commit();")?;
writeln!(code, " cmd.wait_until_completed();")?;
writeln!(code, " let d_layers = t_layers.elapsed();")?;
writeln!(code)?;
writeln!(code, " // ── Stage: Final norm + logits (GPU) ──")?;
writeln!(code, " let t_logits = Instant::now();")?;
writeln!(code, " let cmd2 = self.queue.new_command_buffer();")?;
writeln!(code, " {{")?;
writeln!(
code,
" let enc = cmd2.new_compute_command_encoder();"
)?;
writeln!(
code,
" self.dispatch_rms_norm(&enc, &self.norm_buf);"
)?;
writeln!(
code,
" self.{matmul_fn}(&enc, &self.hidden_buf, &self.lm_head_buf, &self.logits_buf, {vocab}, {hidden});"
)?;
writeln!(code, " enc.end_encoding();")?;
writeln!(code, " }}")?;
writeln!(code, " cmd2.commit();")?;
writeln!(code, " cmd2.wait_until_completed();")?;
writeln!(code, " let d_logits = t_logits.elapsed();")?;
writeln!(code)?;
writeln!(
code,
" eprintln!(\"[profile] embed: {{:.3}}ms, layers: {{:.3}}ms, norm+logits: {{:.3}}ms, total: {{:.3}}ms\","
)?;
writeln!(code, " d_embed.as_secs_f64() * 1000.0,")?;
writeln!(code, " d_layers.as_secs_f64() * 1000.0,")?;
writeln!(code, " d_logits.as_secs_f64() * 1000.0,")?;
writeln!(
code,
" (d_embed + d_layers + d_logits).as_secs_f64() * 1000.0);"
)?;
writeln!(code)?;
writeln!(code, " let logits = unsafe {{")?;
writeln!(
code,
" let ptr = self.logits_buf.contents() as *const f32;"
)?;
writeln!(
code,
" std::slice::from_raw_parts(ptr, VOCAB_SIZE).to_vec()"
)?;
writeln!(code, " }};")?;
writeln!(code)?;
writeln!(code, " self.pos += 1;")?;
writeln!(code, " logits")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Asynchronous forward pass for a single prefill token (no logits readback)."
)?;
writeln!(code, " ///")?;
writeln!(
code,
" /// Commits the command buffer without waiting, enabling double-buffered"
)?;
writeln!(
code,
" /// execution: GPU processes token N while CPU encodes token N+1."
)?;
writeln!(
code,
" pub fn forward_prefill(&mut self, token_id: u32) {{"
)?;
writeln!(code, " self.forward_prefill_batch(&[token_id]);")?;
writeln!(code, " }}")?;
writeln!(code)?;
let batch_matmul_fn = if is_q8 {
"dispatch_matmul_q8_batch"
} else if is_q4 {
"dispatch_matmul_q4_batch"
} else {
"dispatch_matmul_batch"
};
writeln!(
code,
" /// Batched prefill: process multiple prompt tokens in one GPU dispatch."
)?;
writeln!(code, " ///")?;
writeln!(
code,
" /// Uses batched matmul kernels for QKV/O/FFN projections (mat-mat instead"
)?;
writeln!(
code,
" /// of mat-vec), and batched causal attention with a single GPU dispatch."
)?;
writeln!(
code,
" /// This provides significant speedup during prompt prefill."
)?;
writeln!(
code,
" pub fn forward_prefill_batch(&mut self, tokens: &[u32]) {{"
)?;
writeln!(code, " if tokens.is_empty() {{ return; }}")?;
writeln!(
code,
" // Chunk long prompts into MAX_BATCH_SIZE-sized slices — the batched"
)?;
writeln!(
code,
" // prefill buffers are sized for MAX_BATCH_SIZE tokens, so prompts"
)?;
writeln!(
code,
" // longer than that must be processed iteratively. The KV cache"
)?;
writeln!(code, " // carries state across chunks via self.pos.")?;
writeln!(
code,
" for chunk in tokens.chunks(MAX_BATCH_SIZE) {{"
)?;
writeln!(code, " let m = chunk.len();")?;
writeln!(code, " if m == 0 {{ continue; }}")?;
writeln!(code, " let start_pos = self.pos;")?;
writeln!(code)?;
writeln!(code, " // Wait for any pending command buffer")?;
writeln!(code, " if let Some(prev) = self.prev_cmd.take() {{")?;
writeln!(code, " prev.wait_until_completed();")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" // Upload token IDs and positions to GPU buffers"
)?;
writeln!(code, " unsafe {{")?;
writeln!(
code,
" let tok_ptr = self.batch_tokens_buf.contents() as *mut u32;"
)?;
writeln!(
code,
" let pos_ptr = self.batch_positions_buf.contents() as *mut u32;"
)?;
writeln!(code, " for i in 0..m {{")?;
writeln!(code, " *tok_ptr.add(i) = chunk[i];")?;
writeln!(
code,
" *pos_ptr.add(i) = (start_pos + i) as u32;"
)?;
writeln!(code, " }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " let cmd = self.queue.new_command_buffer();")?;
writeln!(code, " {{")?;
writeln!(
code,
" let enc = cmd.new_compute_command_encoder();"
)?;
writeln!(code)?;
writeln!(
code,
" // 1. Batch embedding lookup: copy all token embeddings at once"
)?;
writeln!(
code,
" self.dispatch_copy_embedding_batch(&enc, m);"
)?;
writeln!(
code,
" self.dispatch_add_inplace_batch_copy(&enc, &self.batch_hidden_buf, &self.batch_residual_buf, m * {hidden});"
)?;
writeln!(code)?;
writeln!(code, " // 2. Transformer layers")?;
writeln!(code, " for layer in 0..NUM_LAYERS {{")?;
writeln!(code)?;
writeln!(
code,
" // Batch RMS norm: batch_residual -> batch_hidden"
)?;
writeln!(
code,
" self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].attn_norm, &self.batch_hidden_buf, m);"
)?;
writeln!(
code,
" // Batch QKV matmul: [M, hidden] x [qkv_rows, hidden]^T -> [M, qkv_rows]"
)?;
writeln!(
code,
" self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].qkv_weight, &self.batch_qkv_buf, m, {qkv_rows}, {hidden});"
)?;
if config.qkv_bias {
writeln!(
code,
" // Qwen2: broadcast-add QKV bias across all M tokens."
)?;
writeln!(
code,
" self.dispatch_add_bias_batch(&enc, &self.batch_qkv_buf, &self.layers[layer].qkv_bias, m, {qkv_rows});"
)?;
}
writeln!(code)?;
let k_float_offset = hidden;
writeln!(
code,
" // Fused RoPE on Q+K in one dispatch (saves 1 dispatch + barrier per layer)"
)?;
writeln!(
code,
" self.dispatch_rope_qk_batch(&enc, &self.batch_qkv_buf, m, start_pos, {qkv_rows});"
)?;
writeln!(code)?;
let v_float_offset = hidden + kv_dim;
writeln!(
code,
" // Fused KV cache update: copy K+V in one dispatch (saves 1 dispatch + barrier per layer)"
)?;
writeln!(
code,
" self.dispatch_copy_kv_both_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], m, {kv_dim}, start_pos, {qkv_rows}, {k_float_offset}, {v_float_offset});"
)?;
writeln!(code)?;
writeln!(
code,
" // Batched causal attention: one dispatch for all M tokens"
)?;
writeln!(
code,
" self.dispatch_attention_batch(&enc, &self.batch_qkv_buf, &self.k_cache[layer], &self.v_cache[layer], &self.batch_attn_out_buf, m, start_pos, {qkv_rows});"
)?;
writeln!(code)?;
writeln!(code, " // Batched O projection")?;
writeln!(
code,
" self.{batch_matmul_fn}(&enc, &self.batch_attn_out_buf, &self.layers[layer].o_weight, &self.batch_attn_proj_buf, m, {hidden}, {hidden});"
)?;
writeln!(code)?;
writeln!(
code,
" self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_attn_proj_buf, m * {hidden});"
)?;
writeln!(code)?;
writeln!(
code,
" // Batch FFN: rms_norm, gate+up matmul, silu_mul, down matmul"
)?;
writeln!(
code,
" self.dispatch_rms_norm_batch(&enc, &self.batch_residual_buf, &self.layers[layer].ffn_norm, &self.batch_hidden_buf, m);"
)?;
writeln!(
code,
" self.{batch_matmul_fn}(&enc, &self.batch_hidden_buf, &self.layers[layer].gate_up_weight, &self.batch_gate_up_buf, m, {gate_up_rows}, {hidden});"
)?;
writeln!(
code,
" self.dispatch_silu_mul_fused_batch(&enc, &self.batch_gate_up_buf, &self.batch_ffn_hidden_buf, {intermediate}, m);"
)?;
writeln!(
code,
" self.{batch_matmul_fn}(&enc, &self.batch_ffn_hidden_buf, &self.layers[layer].down_weight, &self.batch_ffn_out_buf, m, {hidden}, {intermediate});"
)?;
writeln!(
code,
" self.dispatch_add_inplace_batch_n(&enc, &self.batch_residual_buf, &self.batch_ffn_out_buf, m * {hidden});"
)?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" // Copy last token's residual to single-token buffer for subsequent forward()"
)?;
writeln!(
code,
" self.dispatch_copy_from_offset_bytes(&enc, &self.batch_residual_buf, (m - 1) * {hidden} * 4, &self.residual_buf, 0, {hidden});"
)?;
writeln!(code)?;
writeln!(code, " enc.end_encoding();")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " cmd.commit();")?;
writeln!(code, " self.prev_cmd = Some(cmd.to_owned());")?;
writeln!(code, " self.pos += m;")?;
writeln!(code, " }} // end for chunk")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Reset the model state for a new inference request."
)?;
writeln!(code, " pub fn reset(&mut self) {{")?;
writeln!(code, " self.pos = 0;")?;
writeln!(code, " self.prev_cmd = None;")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" // ── Dispatch helpers (append to a shared compute command encoder) ──"
)?;
writeln!(
code,
" // These methods set pipeline state + buffers + dispatch on an existing"
)?;
writeln!(
code,
" // encoder, avoiding per-operation encoder creation overhead."
)?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch RMS norm: normalizes residual_buf -> hidden_buf using given weight."
)?;
writeln!(
code,
" fn dispatch_rms_norm(&self, enc: &ComputeCommandEncoderRef, weight: &Buffer) {{"
)?;
writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.rms_norm_pipeline);"
)?;
writeln!(
code,
" enc.set_buffer(0, Some(&self.residual_buf), 0);"
)?;
writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
writeln!(
code,
" enc.set_buffer(2, Some(&self.hidden_buf), 0);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(1, 1, 1); // single threadgroup"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch matrix-vector multiply: weight * input -> output."
)?;
writeln!(
code,
" fn dispatch_matmul(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
)?;
writeln!(code, " let r: u32 = rows as u32;")?;
writeln!(code, " let c: u32 = cols as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(
code,
" // 64 rows per threadgroup (8 simdgroups x 8 rows each = 256 threads)"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(code, " let num_tg = ((rows + 63) / 64) as u64;")?;
writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch Q8_0 quantized matrix-vector multiply: weight_q8 * input -> output."
)?;
writeln!(
code,
" /// Weights are raw Q8_0 bytes (34 bytes per block of 32 elements)."
)?;
writeln!(
code,
" fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
)?;
writeln!(code, " let r: u32 = rows as u32;")?;
writeln!(code, " let c: u32 = cols as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_q8_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(
code,
" // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch Q4_0 quantized matrix-vector multiply: weight_q4 * input -> output."
)?;
writeln!(
code,
" /// Weights are raw Q4_0 bytes (18 bytes per block of 32 elements)."
)?;
writeln!(
code,
" fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, rows: usize, cols: usize) {{"
)?;
writeln!(code, " let r: u32 = rows as u32;")?;
writeln!(code, " let c: u32 = cols as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_q4_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(
code,
" // 32 rows per threadgroup (8 simdgroups x 4 rows each = 256 threads)"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(code, " let num_tg = ((rows + 31) / 32) as u64;")?;
writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " /// Dispatch RoPE on a buffer in-place.")?;
writeln!(
code,
" fn dispatch_rope(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_heads: usize, head_dim: usize, pos: usize) {{"
)?;
writeln!(code, " let nh: u32 = num_heads as u32;")?;
writeln!(code, " let hd: u32 = head_dim as u32;")?;
writeln!(code, " let p: u32 = pos as u32;")?;
writeln!(code, " let theta: f32 = ROPE_THETA;")?;
writeln!(
code,
" let total_pairs = num_heads * (head_dim / 2);"
)?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.rope_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
writeln!(
code,
" enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch RoPE on a buffer at a given byte offset (for fused QKV buffer)."
)?;
writeln!(
code,
" fn dispatch_rope_offset(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, byte_offset: usize, num_heads: usize, head_dim: usize, pos: usize) {{"
)?;
writeln!(code, " let nh: u32 = num_heads as u32;")?;
writeln!(code, " let hd: u32 = head_dim as u32;")?;
writeln!(code, " let p: u32 = pos as u32;")?;
writeln!(code, " let theta: f32 = ROPE_THETA;")?;
writeln!(
code,
" let total_pairs = num_heads * (head_dim / 2);"
)?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.rope_pipeline);"
)?;
writeln!(
code,
" enc.set_buffer(0, Some(buf), byte_offset as u64);"
)?;
writeln!(
code,
" enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch attention kernel: Q * K^T -> softmax -> * V."
)?;
writeln!(
code,
" fn dispatch_attention(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
)?;
writeln!(code, " let sl: u32 = seq_len as u32;")?;
writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.attention_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(
code,
" // One threadgroup per head, 256 threads (8 simdgroups for cooperative reductions)"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch attention with Q at a byte offset in the source buffer (for fused QKV)."
)?;
writeln!(
code,
" fn dispatch_attention_offset(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, q_byte_offset: usize, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, seq_len: usize) {{"
)?;
writeln!(code, " let sl: u32 = seq_len as u32;")?;
writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.attention_pipeline);"
)?;
writeln!(
code,
" enc.set_buffer(0, Some(q_buf), q_byte_offset as u64);"
)?;
writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &sl as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(6, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(7, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(NUM_HEADS as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " /// Dispatch fused SiLU-multiply kernel.")?;
writeln!(
code,
" fn dispatch_silu_mul(&self, enc: &ComputeCommandEncoderRef, gate: &Buffer, up: &Buffer, output: &Buffer, n: usize) {{"
)?;
writeln!(code, " let count: u32 = n as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.silu_mul_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(gate), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(up), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch fused SiLU-multiply reading gate and up from a single concatenated buffer."
)?;
writeln!(
code,
" /// gate_up_buf contains [gate(n), up(n)] contiguously."
)?;
writeln!(
code,
" fn dispatch_silu_mul_fused(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize) {{"
)?;
writeln!(code, " let count: u32 = n as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.silu_mul_fused_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch buffer copy via compute kernel: dst[i] = src[i] for i in 0..count."
)?;
writeln!(
code,
" fn dispatch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
)?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch offset copy via compute kernel: dst[i] = src[src_offset + i]."
)?;
writeln!(
code,
" /// Used for embedding table lookup (copy a specific row)."
)?;
writeln!(
code,
" fn dispatch_copy_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, src_offset: usize, count: usize) {{"
)?;
writeln!(code, " let off: u32 = src_offset as u32;")?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_offset_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &off as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch copy from source at byte offset to destination at float offset."
)?;
writeln!(
code,
" /// Used for KV cache updates from fused QKV buffer."
)?;
writeln!(
code,
" fn dispatch_copy_from_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
)?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_pipeline);"
)?;
writeln!(
code,
" enc.set_buffer(0, Some(src), src_byte_offset as u64);"
)?;
writeln!(
code,
" enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch f32 -> f16 copy from src at byte offset to half-typed dst at element offset."
)?;
writeln!(
code,
" /// Used for single-token decode KV cache updates (f32 QKV buf -> f16 KV cache)."
)?;
writeln!(
code,
" fn dispatch_copy_from_offset_f16(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_elem_offset: usize, count: usize) {{"
)?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_f32_to_f16_offset_pipeline);"
)?;
writeln!(
code,
" enc.set_buffer(0, Some(src), src_byte_offset as u64);"
)?;
writeln!(
code,
" enc.set_buffer(1, Some(dst), (dst_elem_offset * 2) as u64);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch copy to destination offset: dst[dst_offset + i] = src[i]."
)?;
writeln!(
code,
" /// Used for KV cache updates (write to a specific position in the cache)."
)?;
writeln!(
code,
" fn dispatch_copy_to_offset(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_offset: usize, count: usize) {{"
)?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
writeln!(
code,
" enc.set_buffer(1, Some(dst), (dst_offset * mem::size_of::<f32>()) as u64);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch in-place element-wise add (residual connection): a[i] += b[i]."
)?;
writeln!(
code,
" fn dispatch_add_inplace(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, n: usize) {{"
)?;
writeln!(code, " let count: u32 = n as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.add_inplace_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((n + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " // ── Batched prefill dispatch helpers ──")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched embedding lookup: copy M token embeddings at once."
)?;
writeln!(
code,
" fn dispatch_copy_embedding_batch(&self, enc: &ComputeCommandEncoderRef, num_tokens: usize) {{"
)?;
writeln!(code, " let dim: u32 = HIDDEN_SIZE as u32;")?;
writeln!(code, " let nt: u32 = num_tokens as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_embedding_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(&self.embed_buf), 0);")?;
writeln!(
code,
" enc.set_buffer(1, Some(&self.batch_hidden_buf), 0);"
)?;
writeln!(
code,
" enc.set_buffer(2, Some(&self.batch_tokens_buf), 0);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &dim as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(code, " let total = num_tokens * HIDDEN_SIZE;")?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched RMS norm: normalizes M vectors at once."
)?;
writeln!(
code,
" fn dispatch_rms_norm_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize) {{"
)?;
writeln!(code, " let n: u32 = HIDDEN_SIZE as u32;")?;
writeln!(code, " let eps: f32 = RMS_NORM_EPS;")?;
writeln!(code, " let nt: u32 = num_tokens as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.rms_norm_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<f32>() as u64, &eps as *const f32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(num_tokens as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched f32 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
)?;
writeln!(
code,
" fn dispatch_matmul_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
)?;
writeln!(code, " let nt: u32 = num_tokens as u32;")?;
writeln!(code, " let r: u32 = rows as u32;")?;
writeln!(code, " let c: u32 = cols as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(
code,
" let row_tgs = (rows + 63) / 64; // 64 rows per threadgroup for f32"
)?;
writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched Q8_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
)?;
writeln!(code, " ///")?;
writeln!(
code,
" /// Selects between a per-token mat-vec kernel (M < 4) and a GEMM-style"
)?;
writeln!(
code,
" /// kernel that reuses weight reads across a tile of 4 tokens (M >= 4)."
)?;
writeln!(
code,
" fn dispatch_matmul_q8_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
)?;
writeln!(code, " let nt: u32 = num_tokens as u32;")?;
writeln!(code, " let r: u32 = rows as u32;")?;
writeln!(code, " let c: u32 = cols as u32;")?;
writeln!(
code,
" // Tile sizes must match the Metal shader constants."
)?;
writeln!(code, " const TOKENS_PER_TG_Q8: usize = 4;")?;
writeln!(code, " const MMA_TOK_TILE: usize = 16;")?;
writeln!(code, " const MMA_ROW_TILE: usize = 16;")?;
writeln!(code, " const MMA32_TOK_TILE: usize = 32;")?;
writeln!(code, " const MMA32_ROW_TILE: usize = 32;")?;
writeln!(
code,
" // Hardware matrix-multiply paths (simdgroup_matrix)."
)?;
writeln!(
code,
" // Prefer the large 32×32 tile when the problem supports it — halves"
)?;
writeln!(
code,
" // dispatch count and reuses each weight load across 32 tokens."
)?;
writeln!(
code,
" if num_tokens >= MMA32_TOK_TILE && rows % MMA32_ROW_TILE == 0 && cols % 32 == 0 {{"
)?;
writeln!(
code,
" // FP16-tile variant: 4 KB shared mem vs 8 KB doubles TG occupancy."
)?;
writeln!(
code,
" // It wins at moderate prefill lengths where the GPU is wave-starved,"
)?;
writeln!(
code,
" // but the f32→f16 conversion overhead slightly hurts the small-hidden"
)?;
writeln!(
code,
" // case (135M / 360M). Switch at cols >= 2048 — a clean split that"
)?;
writeln!(
code,
" // keeps the FP32 path for small-hidden models and gives 1B/3B the win."
)?;
writeln!(
code,
" // All-FP16 MMA (hh4) has a scalar-widening store path that costs a"
)?;
writeln!(
code,
" // little at low M but wins at higher M via ~2x FP16 MMA throughput."
)?;
writeln!(
code,
" // Empirically the crossover is around M=256 on M5 Pro for 1B/3B."
)?;
writeln!(code, " let use_h4 = cols >= 2048;")?;
writeln!(code, " let pipe = if use_h4 {{")?;
writeln!(code, " if num_tokens >= 256 {{")?;
writeln!(
code,
" &self.matmul_q8_mma32_hh4_pipeline"
)?;
writeln!(code, " }} else {{")?;
writeln!(
code,
" &self.matmul_q8_mma32_h4_pipeline"
)?;
writeln!(code, " }}")?;
writeln!(code, " }} else {{")?;
writeln!(code, " &self.matmul_q8_mma32_pipeline")?;
writeln!(code, " }};")?;
writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(code, " let row_tgs = rows / MMA32_ROW_TILE;")?;
writeln!(
code,
" let tok_tgs = (num_tokens + MMA32_TOK_TILE - 1) / MMA32_TOK_TILE;"
)?;
writeln!(
code,
" let tg_size = if use_h4 {{ MTLSize::new(128, 1, 1) }} else {{ MTLSize::new(256, 1, 1) }};"
)?;
writeln!(
code,
" let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(
code,
" }} else if num_tokens >= MMA_TOK_TILE && rows % MMA_ROW_TILE == 0 && cols % 32 == 0 {{"
)?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_q8_mma_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(code, " let row_tgs = rows / MMA_ROW_TILE;")?;
writeln!(
code,
" let tok_tgs = (num_tokens + MMA_TOK_TILE - 1) / MMA_TOK_TILE;"
)?;
writeln!(
code,
" let tg_size = MTLSize::new(128, 1, 1); // 4 simdgroups × 32 lanes"
)?;
writeln!(
code,
" let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " }} else if num_tokens >= TOKENS_PER_TG_Q8 {{")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_q8_gemm_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(
code,
" let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
)?;
writeln!(
code,
" let tok_tgs = (num_tokens + TOKENS_PER_TG_Q8 - 1) / TOKENS_PER_TG_Q8;"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(row_tgs as u64, tok_tgs as u64, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " }} else {{")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_q8_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(
code,
" let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q8"
)?;
writeln!(
code,
" let num_tg = (row_tgs * num_tokens) as u64;"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(num_tg, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " }}")?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched Q4_0 matmul: [M, cols] x [rows, cols]^T -> [M, rows]."
)?;
writeln!(
code,
" fn dispatch_matmul_q4_batch(&self, enc: &ComputeCommandEncoderRef, input: &Buffer, weight: &Buffer, output: &Buffer, num_tokens: usize, rows: usize, cols: usize) {{"
)?;
writeln!(code, " let nt: u32 = num_tokens as u32;")?;
writeln!(code, " let r: u32 = rows as u32;")?;
writeln!(code, " let c: u32 = cols as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.matmul_q4_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(weight), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(input), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &c as *const u32 as *const _);"
)?;
writeln!(
code,
" let row_tgs = (rows + 31) / 32; // 32 rows per threadgroup for Q4"
)?;
writeln!(code, " let num_tg = (row_tgs * num_tokens) as u64;")?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(code, " let grid_size = MTLSize::new(num_tg, 1, 1);")?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
// Only models configured with `qkv_bias` get the generated
// `dispatch_add_bias_batch` helper. The emitted body binds out/bias at slots
// 0/1, passes num_tokens/rows as u32 at slots 2/3, and launches
// ceil(num_tokens * rows / 256) threadgroups of 256 threads.
if config.qkv_bias {
    for line in [
        " /// Broadcast-add a per-row bias vector to every row of an [M, rows] buffer.",
        " fn dispatch_add_bias_batch(&self, enc: &ComputeCommandEncoderRef, out: &Buffer, bias: &Buffer, num_tokens: usize, rows: usize) {",
        " let nt: u32 = num_tokens as u32;",
        " let r: u32 = rows as u32;",
        " enc.set_compute_pipeline_state(&self.add_bias_batch_pipeline);",
        " enc.set_buffer(0, Some(out), 0);",
        " enc.set_buffer(1, Some(bias), 0);",
        " enc.set_bytes(2, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);",
        " enc.set_bytes(3, mem::size_of::<u32>() as u64, &r as *const u32 as *const _);",
        " let total = (num_tokens * rows) as u64;",
        " let tg_size = MTLSize::new(256, 1, 1);",
        " let grid_size = MTLSize::new((total + 255) / 256, 1, 1);",
        " enc.dispatch_thread_groups(grid_size, tg_size);",
        " unsafe { let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }",
        " }",
        // Blank separator line after the method.
        "",
    ] {
        writeln!(code, "{line}")?;
    }
}
writeln!(
code,
" /// Dispatch batched RoPE: apply RoPE to M vectors with different positions."
)?;
writeln!(
code,
" /// `data_float_offset` is the offset in floats within each token's row in the batch buffer."
)?;
writeln!(
code,
" /// `row_stride` is the number of floats per token row in the batch buffer."
)?;
writeln!(
code,
" fn dispatch_rope_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, data_float_offset: usize, num_heads: usize, head_dim: usize, num_tokens: usize, row_stride: usize) {{"
)?;
writeln!(code, " let nh: u32 = num_heads as u32;")?;
writeln!(code, " let hd: u32 = head_dim as u32;")?;
writeln!(code, " let theta: f32 = ROPE_THETA;")?;
writeln!(code, " let nt: u32 = num_tokens as u32;")?;
writeln!(
code,
" let pairs_per_token = num_heads * (head_dim / 2);"
)?;
writeln!(
code,
" let total_pairs = num_tokens * pairs_per_token;"
)?;
writeln!(
code,
" // Apply RoPE to each token individually (different positions, non-contiguous layout)"
)?;
writeln!(code, " for t in 0..num_tokens {{")?;
writeln!(
code,
" let byte_offset = (t * row_stride + data_float_offset) * mem::size_of::<f32>();"
)?;
writeln!(
code,
" let p: u32 = unsafe {{ *(self.batch_positions_buf.contents() as *const u32).add(t) }};"
)?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.rope_pipeline);"
)?;
writeln!(
code,
" enc.set_buffer(0, Some(buf), byte_offset as u64);"
)?;
writeln!(
code,
" enc.set_bytes(1, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &p as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((pairs_per_token + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched fused SiLU-multiply for M tokens."
)?;
writeln!(
code,
" fn dispatch_silu_mul_fused_batch(&self, enc: &ComputeCommandEncoderRef, gate_up: &Buffer, output: &Buffer, n: usize, num_tokens: usize) {{"
)?;
writeln!(code, " let count: u32 = n as u32;")?;
writeln!(code, " let nt: u32 = num_tokens as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.silu_mul_fused_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(gate_up), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nt as *const u32 as *const _);"
)?;
writeln!(code, " let total = num_tokens * n;")?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch in-place add for total_n elements: a[i] += b[i]."
)?;
writeln!(
code,
" fn dispatch_add_inplace_batch_n(&self, enc: &ComputeCommandEncoderRef, a: &Buffer, b: &Buffer, total_n: usize) {{"
)?;
writeln!(code, " let count: u32 = total_n as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.add_inplace_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(a), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(b), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &count as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total_n + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Copy src to dst using compute copy kernel (for batch residual init)."
)?;
writeln!(
code,
" fn dispatch_add_inplace_batch_copy(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, count: usize) {{"
)?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Copy src buffer to dst at a float offset (for scatter into batch buffers)."
)?;
writeln!(
code,
" fn dispatch_copy_to_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
)?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
writeln!(
code,
" enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Copy from src at byte offset to dst at float offset."
)?;
writeln!(
code,
" fn dispatch_copy_from_offset_bytes(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, src_byte_offset: usize, dst: &Buffer, dst_float_offset: usize, count: usize) {{"
)?;
writeln!(code, " let n: u32 = count as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_pipeline);"
)?;
writeln!(
code,
" enc.set_buffer(0, Some(src), src_byte_offset as u64);"
)?;
writeln!(
code,
" enc.set_buffer(1, Some(dst), (dst_float_offset * mem::size_of::<f32>()) as u64);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &n as *const u32 as *const _);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((count + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched KV cache copy: copy K or V from strided batch QKV buffer to KV cache."
)?;
writeln!(
code,
" fn dispatch_copy_kv_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, src_offset: usize) {{"
)?;
writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
writeln!(code, " let kv: u32 = kv_dim as u32;")?;
writeln!(code, " let bp: u32 = base_pos as u32;")?;
writeln!(code, " let ss: u32 = src_stride as u32;")?;
writeln!(code, " let so: u32 = src_offset as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_kv_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(dst), 0);")?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(6, mem::size_of::<u32>() as u64, &so as *const u32 as *const _);"
)?;
writeln!(code, " let total = num_tokens * kv_dim;")?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch batched causal attention: one dispatch for all M tokens."
)?;
writeln!(
code,
" fn dispatch_attention_batch(&self, enc: &ComputeCommandEncoderRef, q_buf: &Buffer, k_cache: &Buffer, v_cache: &Buffer, output: &Buffer, num_tokens: usize, base_pos: usize, q_stride: usize) {{"
)?;
writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
writeln!(code, " let bp: u32 = base_pos as u32;")?;
writeln!(code, " let nh: u32 = NUM_HEADS as u32;")?;
writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
writeln!(code, " let qs: u32 = q_stride as u32;")?;
writeln!(code, " let max_seq = base_pos + num_tokens;")?;
writeln!(code, " let _ = max_seq;")?;
writeln!(
code,
" let mma_opt_out = std::env::var(\"FORGE_MMA_ATTN\")"
)?;
writeln!(code, " .map(|v| v == \"0\").unwrap_or(false);")?;
writeln!(
code,
" let use_mma_flash = !mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8;"
)?;
writeln!(code, " if use_mma_flash {{")?;
writeln!(
code,
" let pipe = &self.attention_mma_flash_batch_pipeline;"
)?;
writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
)?;
writeln!(
code,
" // Grid: [ceil(M/8), NUM_HEADS, 1], 128 threads (4 simdgroups) per TG"
)?;
writeln!(code, " let tg_size = MTLSize::new(128, 1, 1);")?;
writeln!(
code,
" let q_blocks = ((num_tokens + 7) / 8) as u64;"
)?;
writeln!(
code,
" let grid_size = MTLSize::new(q_blocks, NUM_HEADS as u64, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " return;")?;
writeln!(code, " }}")?;
writeln!(code, " let pipe = &self.attention_batch_pipeline;")?;
writeln!(code, " enc.set_compute_pipeline_state(pipe);")?;
writeln!(code, " enc.set_buffer(0, Some(q_buf), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(k_cache), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(v_cache), 0);")?;
writeln!(code, " enc.set_buffer(3, Some(output), 0);")?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(6, mem::size_of::<u32>() as u64, &nh as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(7, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(8, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(9, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
)?;
writeln!(
code,
" // One threadgroup per (token, head) pair, 256 threads for cooperative reductions"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new((num_tokens * NUM_HEADS) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch fused RoPE for both Q and K heads in one kernel launch."
)?;
writeln!(
code,
" /// Saves one dispatch + barrier per layer vs separate Q and K RoPE calls."
)?;
writeln!(
code,
" fn dispatch_rope_qk_batch(&self, enc: &ComputeCommandEncoderRef, buf: &Buffer, num_tokens: usize, base_pos: usize, qkv_stride: usize) {{"
)?;
writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
writeln!(code, " let bp: u32 = base_pos as u32;")?;
writeln!(code, " let nq: u32 = NUM_HEADS as u32;")?;
writeln!(code, " let nkv: u32 = NUM_KV_HEADS as u32;")?;
writeln!(code, " let hd: u32 = HEAD_DIM as u32;")?;
writeln!(code, " let qs: u32 = qkv_stride as u32;")?;
writeln!(code, " let theta: f32 = ROPE_THETA;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.rope_qk_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(buf), 0);")?;
writeln!(
code,
" enc.set_bytes(1, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(2, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &nq as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &nkv as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &hd as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(6, mem::size_of::<u32>() as u64, &qs as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(7, mem::size_of::<f32>() as u64, &theta as *const f32 as *const _);"
)?;
writeln!(
code,
" let total_pairs = num_tokens * (NUM_HEADS + NUM_KV_HEADS) * (HEAD_DIM / 2);"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total_pairs + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" /// Dispatch fused K+V cache copy in one kernel launch."
)?;
writeln!(
code,
" /// Saves one dispatch + barrier per layer vs separate K and V copy calls."
)?;
writeln!(
code,
" fn dispatch_copy_kv_both_batch(&self, enc: &ComputeCommandEncoderRef, src: &Buffer, k_dst: &Buffer, v_dst: &Buffer, num_tokens: usize, kv_dim: usize, base_pos: usize, src_stride: usize, k_offset: usize, v_offset: usize) {{"
)?;
writeln!(code, " let m_val: u32 = num_tokens as u32;")?;
writeln!(code, " let kv: u32 = kv_dim as u32;")?;
writeln!(code, " let bp: u32 = base_pos as u32;")?;
writeln!(code, " let ss: u32 = src_stride as u32;")?;
writeln!(code, " let ko: u32 = k_offset as u32;")?;
writeln!(code, " let vo: u32 = v_offset as u32;")?;
writeln!(
code,
" enc.set_compute_pipeline_state(&self.copy_kv_both_batch_pipeline);"
)?;
writeln!(code, " enc.set_buffer(0, Some(src), 0);")?;
writeln!(code, " enc.set_buffer(1, Some(k_dst), 0);")?;
writeln!(code, " enc.set_buffer(2, Some(v_dst), 0);")?;
writeln!(
code,
" enc.set_bytes(3, mem::size_of::<u32>() as u64, &m_val as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(4, mem::size_of::<u32>() as u64, &kv as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(5, mem::size_of::<u32>() as u64, &bp as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(6, mem::size_of::<u32>() as u64, &ss as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(7, mem::size_of::<u32>() as u64, &ko as *const u32 as *const _);"
)?;
writeln!(
code,
" enc.set_bytes(8, mem::size_of::<u32>() as u64, &vo as *const u32 as *const _);"
)?;
writeln!(
code,
" let total = num_tokens * kv_dim * 2; // K + V"
)?;
writeln!(code, " let tg_size = MTLSize::new(256, 1, 1);")?;
writeln!(
code,
" let grid_size = MTLSize::new(((total + 255) / 256) as u64, 1, 1);"
)?;
writeln!(
code,
" enc.dispatch_thread_groups(grid_size, tg_size);"
)?;
writeln!(code, " unsafe {{ let _: () = metal::objc::msg_send![enc, memoryBarrierWithScope:1u64]; }}")?;
writeln!(code, " }}")?;
writeln!(code, "}}")?;
writeln!(code)?;
Ok(())
}
/// Appends the free-standing helper section of the generated `model.rs` to `code`.
///
/// The section consists of a banner comment followed by `make_pipeline`, a small
/// function in the generated program that looks up a named function in a Metal
/// library and builds a compute pipeline from it, panicking with the function
/// name if either step fails.
///
/// # Errors
///
/// Propagates any formatting error raised while appending to `code`.
fn emit_helper_functions(code: &mut String) -> Result<(), MetalCodegenError> {
    // Emitted verbatim, one entry per output line; an empty entry is a blank line.
    // These are plain strings (not format strings), so braces are written once.
    const HELPER_SECTION: [&str; 9] = [
        "// ── Helper functions ──────────────────────────────────",
        "",
        "/// Create a compute pipeline from a named function in the Metal library.",
        "fn make_pipeline(device: &Device, library: &Library, name: &str) -> ComputePipelineState {",
        " let func = library.get_function(name, None).unwrap_or_else(|_| panic!(\"Metal function '{name}' not found\"));",
        " device.new_compute_pipeline_state_with_function(&func)",
        " .unwrap_or_else(|_| panic!(\"failed to create pipeline for '{name}'\"))",
        "}",
        "",
    ];

    for line in HELPER_SECTION.iter() {
        writeln!(code, "{}", line)?;
    }
    Ok(())
}
fn generate_main_rs(model_name: &str, config: &ModelConfig) -> Result<String, MetalCodegenError> {
let _sanitized = sanitize_name(model_name);
let _vocab = config.vocab_size;
let mut code = String::with_capacity(16 * 1024);
writeln!(code, "//! Auto-generated by ForgeLLM Metal codegen.")?;
writeln!(
code,
"//! CLI and HTTP server for Metal GPU inference on Apple Silicon."
)?;
writeln!(code)?;
writeln!(code, "mod model;")?;
writeln!(code)?;
writeln!(code, "use std::io::Write;")?;
writeln!(code, "use std::time::Instant;")?;
writeln!(code, "use serde::Deserialize;")?;
writeln!(code)?;
writeln!(code, "fn main() {{")?;
writeln!(
code,
" let args: Vec<String> = std::env::args().collect();"
)?;
writeln!(code)?;
writeln!(
code,
" // Detect --serve mode (only requires weights + tokenizer)"
)?;
writeln!(
code,
" let serve_mode = args.iter().any(|a| a == \"--serve\");"
)?;
writeln!(code)?;
writeln!(code, " if !serve_mode && args.len() < 4 {{")?;
writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> <prompt> [--max-tokens N] [--quiet] [--profile]\", args[0]);")?;
writeln!(code, " eprintln!(\" {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
writeln!(code, " std::process::exit(1);")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " if serve_mode && args.len() < 3 {{")?;
writeln!(code, " eprintln!(\"Usage: {{}} <weights.bin> <tokenizer.json> --serve [--port 8080]\", args[0]);")?;
writeln!(code, " std::process::exit(1);")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " let weights_path = &args[1];")?;
writeln!(code, " let tokenizer_path = &args[2];")?;
writeln!(code)?;
writeln!(code, " // Parse optional flags")?;
writeln!(code, " let mut max_tokens: usize = 128;")?;
writeln!(code, " let mut port: u16 = 8080;")?;
writeln!(
code,
" let quiet = args.iter().any(|a| a == \"--quiet\" || a == \"-q\");"
)?;
writeln!(
code,
" let profile = args.iter().any(|a| a == \"--profile\");"
)?;
writeln!(code, " let mut i = 3;")?;
writeln!(code, " while i < args.len() {{")?;
writeln!(
code,
" if args[i] == \"--max-tokens\" && i + 1 < args.len() {{"
)?;
writeln!(
code,
" max_tokens = args[i + 1].parse().unwrap_or(128);"
)?;
writeln!(code, " i += 2;")?;
writeln!(
code,
" }} else if args[i] == \"--port\" && i + 1 < args.len() {{"
)?;
writeln!(
code,
" port = args[i + 1].parse().unwrap_or(8080);"
)?;
writeln!(code, " i += 2;")?;
writeln!(code, " }} else if args[i] == \"--serve\" {{")?;
writeln!(code, " i += 1;")?;
writeln!(code, " }} else if args[i] == \"--profile\" {{")?;
writeln!(code, " i += 1;")?;
writeln!(code, " }} else {{")?;
writeln!(code, " i += 1;")?;
writeln!(code, " }}")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" // Memory-map weights for zero-copy loading on Apple Silicon"
)?;
writeln!(
code,
" let weights_file = std::fs::File::open(weights_path)"
)?;
writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to open weights: {{e}}\"); std::process::exit(1); }});")?;
writeln!(
code,
" let weights_mmap = unsafe {{ memmap2::Mmap::map(&weights_file).unwrap() }};"
)?;
writeln!(code)?;
writeln!(code, " // Load tokenizer")?;
writeln!(
code,
" let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)"
)?;
writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to load tokenizer: {{e}}\"); std::process::exit(1); }});")?;
writeln!(code)?;
writeln!(code, " // Create Metal model")?;
writeln!(code, " eprintln!(\"Loading model onto Metal GPU...\");")?;
writeln!(
code,
" let mut model = model::MetalModel::new(&weights_mmap);"
)?;
writeln!(code)?;
writeln!(code, " if serve_mode {{")?;
writeln!(code, " serve(model, tokenizer, port);")?;
writeln!(code, " }} else {{")?;
writeln!(code, " let prompt = &args[3];")?;
writeln!(
code,
" cli_mode(&mut model, &tokenizer, prompt, max_tokens, quiet, profile);"
)?;
writeln!(code, " }}")?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(code, "fn cli_mode(model: &mut model::MetalModel, tokenizer: &tokenizers::Tokenizer, prompt: &str, max_tokens: usize, quiet: bool, profile: bool) {{")?;
writeln!(code, " // Tokenize prompt")?;
writeln!(code, " let encoding = tokenizer.encode(prompt, true)")?;
writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Tokenization failed: {{e}}\"); std::process::exit(1); }});")?;
writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
writeln!(code)?;
writeln!(
code,
" // Process prompt tokens with batched prefill (mat-mat instead of mat-vec)."
)?;
writeln!(
code,
" // Uses double-buffered batch dispatch for GPU-efficient matmul."
)?;
writeln!(
code,
" // The last token uses synchronous forward() to get logits."
)?;
writeln!(code, " let prompt_len = prompt_tokens.len();")?;
writeln!(code, " let prefill_start = Instant::now();")?;
writeln!(code, " let logits = if prompt_len > 1 {{")?;
writeln!(
code,
" model.forward_prefill_batch(&prompt_tokens[..prompt_len - 1]);"
)?;
writeln!(code, " model.forward(prompt_tokens[prompt_len - 1])")?;
writeln!(code, " }} else {{")?;
writeln!(code, " model.forward(prompt_tokens[0])")?;
writeln!(code, " }};")?;
writeln!(
code,
" let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
)?;
writeln!(code, " let prefill_tokens = prompt_tokens.len();")?;
writeln!(
code,
" eprintln!(\"Prefill: {{}} tokens in {{:.3}}s ({{:.1}} tok/s)\","
)?;
writeln!(
code,
" prefill_tokens, prefill_elapsed, prefill_tokens as f64 / prefill_elapsed);"
)?;
writeln!(code)?;
writeln!(code, " // Generate tokens")?;
writeln!(code, " let mut next_token = argmax(&logits);")?;
writeln!(code, " let gen_start = Instant::now();")?;
writeln!(code, " let mut generated_count: usize = 0;")?;
writeln!(code)?;
writeln!(code, " for _ in 0..max_tokens {{")?;
writeln!(
code,
" if let Some(text) = tokenizer.decode(&[next_token], false).ok() {{"
)?;
writeln!(code, " if !quiet {{")?;
writeln!(code, " print!(\"{{}}\", text);")?;
writeln!(code, " std::io::stdout().flush().ok();")?;
writeln!(code, " }}")?;
writeln!(code, " }}")?;
writeln!(code, " generated_count += 1;")?;
writeln!(code)?;
writeln!(
code,
" // Use profiling forward for first token when --profile is set"
)?;
writeln!(
code,
" let logits = if profile && generated_count == 1 {{"
)?;
writeln!(code, " model.forward_profile(next_token)")?;
writeln!(code, " }} else {{")?;
writeln!(code, " model.forward(next_token)")?;
writeln!(code, " }};")?;
writeln!(code, " next_token = argmax(&logits);")?;
writeln!(code)?;
writeln!(code, " // Stop on EOS (token 2 for most models)")?;
writeln!(code, " if next_token == 2 {{")?;
writeln!(code, " break;")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(
code,
" // Yield between tokens to reduce sustained GPU thermal load."
)?;
writeln!(
code,
" // On Apple Silicon, continuous GPU saturation causes thermal throttling"
)?;
writeln!(
code,
" // (500 tok/s -> 300 tok/s). A yield_now() lets the OS scheduler run"
)?;
writeln!(
code,
" // briefly, providing a micro-break that helps sustain peak throughput."
)?;
writeln!(code, " std::thread::yield_now();")?;
writeln!(code, " }}")?;
writeln!(code, " if !quiet {{")?;
writeln!(code, " println!();")?;
writeln!(code, " }}")?;
writeln!(
code,
" let gen_elapsed = gen_start.elapsed().as_secs_f64();"
)?;
writeln!(
code,
" eprintln!(\"Generate: {{}} tokens in {{:.2}}s ({{:.1}} tok/s)\","
)?;
writeln!(
code,
" generated_count, gen_elapsed, generated_count as f64 / gen_elapsed);"
)?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(code, "/// Simple argmax over a slice of f32 logits.")?;
writeln!(code, "fn argmax(logits: &[f32]) -> u32 {{")?;
writeln!(code, " logits.iter()")?;
writeln!(code, " .enumerate()")?;
writeln!(
code,
" .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))"
)?;
writeln!(code, " .map(|(i, _)| i as u32)")?;
writeln!(code, " .unwrap_or(0)")?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(
code,
"// -----------------------------------------------------------------------"
)?;
writeln!(code, "// OpenAI-compatible API server")?;
writeln!(
code,
"// -----------------------------------------------------------------------"
)?;
writeln!(code)?;
writeln!(code, "#[derive(Deserialize)]")?;
writeln!(code, "struct ChatRequest {{")?;
writeln!(code, " messages: Vec<ChatMessage>,")?;
writeln!(code, " #[serde(default)]")?;
writeln!(code, " stream: Option<bool>,")?;
writeln!(code, " #[serde(default)]")?;
writeln!(code, " max_tokens: Option<usize>,")?;
writeln!(code, " #[serde(default)]")?;
writeln!(code, " temperature: Option<f32>,")?;
writeln!(code, " #[serde(default)]")?;
writeln!(code, " model: Option<String>,")?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(code, "#[derive(Deserialize)]")?;
writeln!(code, "struct ChatMessage {{")?;
writeln!(code, " role: String,")?;
writeln!(code, " content: String,")?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(
code,
"fn format_chat_messages(messages: &[ChatMessage]) -> String {{"
)?;
writeln!(code, " let mut prompt = String::new();")?;
writeln!(code, " for msg in messages {{")?;
writeln!(code, " prompt.push_str(&format!(\"<|im_start|>{{}}\\n{{}}<|im_end|>\\n\", msg.role, msg.content));")?;
writeln!(code, " }}")?;
writeln!(code, " prompt.push_str(\"<|im_start|>assistant\\n\");")?;
writeln!(code, " prompt")?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(
code,
"fn prefill(model: &mut model::MetalModel, tokens: &[u32]) -> Vec<f32> {{"
)?;
writeln!(code, " let len = tokens.len();")?;
writeln!(code, " if len > 1 {{")?;
writeln!(
code,
" model.forward_prefill_batch(&tokens[..len - 1]);"
)?;
writeln!(code, " }}")?;
writeln!(code, " model.forward(tokens[len - 1])")?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(
code,
"fn serve(mut model: model::MetalModel, tokenizer: tokenizers::Tokenizer, port: u16) {{"
)?;
writeln!(code, " let addr = format!(\"0.0.0.0:{{}}\", port);")?;
writeln!(code, " let server = tiny_http::Server::http(&addr)")?;
writeln!(code, " .unwrap_or_else(|e| {{ eprintln!(\"Failed to bind {{}}: {{e}}\", addr); std::process::exit(1); }});")?;
writeln!(
code,
" eprintln!(\"ForgeLLM Metal server listening on http://0.0.0.0:{{}}\", port);"
)?;
writeln!(code, " eprintln!(\"Endpoints:\");")?;
writeln!(code, " eprintln!(\" POST /v1/chat/completions\");")?;
writeln!(code, " eprintln!(\" GET /v1/models\");")?;
writeln!(code, " eprintln!(\" GET /health\");")?;
writeln!(code)?;
writeln!(code, " for request in server.incoming_requests() {{")?;
writeln!(code, " let method = request.method().to_string();")?;
writeln!(code, " let url = request.url().to_string();")?;
writeln!(code)?;
writeln!(code, " match (method.as_str(), url.as_str()) {{")?;
writeln!(
code,
" (\"POST\", \"/v1/chat/completions\") => {{"
)?;
writeln!(
code,
" handle_chat_completion(&mut model, &tokenizer, request);"
)?;
writeln!(code, " }}")?;
writeln!(code, " (\"GET\", \"/v1/models\") => {{")?;
writeln!(code, " let body = serde_json::json!({{")?;
writeln!(code, " \"object\": \"list\",")?;
writeln!(code, " \"data\": [{{")?;
writeln!(code, " \"id\": \"forgellm-metal\",")?;
writeln!(code, " \"object\": \"model\",")?;
writeln!(code, " \"owned_by\": \"forgellm\"")?;
writeln!(code, " }}]")?;
writeln!(code, " }});")?;
writeln!(
code,
" let resp = tiny_http::Response::from_string(body.to_string())"
)?;
writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " }}")?;
writeln!(code, " (\"GET\", \"/health\") => {{")?;
writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"status\\\":\\\"ok\\\"}}\");")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " }}")?;
writeln!(code, " _ => {{")?;
writeln!(
code,
" let resp = tiny_http::Response::from_string(\"Not Found\")"
)?;
writeln!(code, " .with_status_code(404);")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " }}")?;
writeln!(code, " }}")?;
writeln!(code, " }}")?;
writeln!(code, "}}")?;
writeln!(code)?;
writeln!(code, "fn handle_chat_completion(")?;
writeln!(code, " model: &mut model::MetalModel,")?;
writeln!(code, " tokenizer: &tokenizers::Tokenizer,")?;
writeln!(code, " mut request: tiny_http::Request,")?;
writeln!(code, ") {{")?;
writeln!(code, " // Read request body")?;
writeln!(code, " let mut body = String::new();")?;
writeln!(
code,
" if request.as_reader().read_to_string(&mut body).is_err() {{"
)?;
writeln!(code, " let resp = tiny_http::Response::from_string(\"{{\\\"error\\\":\\\"Failed to read request body\\\"}}\")")?;
writeln!(code, " .with_status_code(400);")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " return;")?;
writeln!(code, " }}")?;
writeln!(code)?;
writeln!(code, " // Parse JSON")?;
writeln!(
code,
" let req: ChatRequest = match serde_json::from_str(&body) {{"
)?;
writeln!(code, " Ok(r) => r,")?;
writeln!(code, " Err(e) => {{")?;
writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Invalid JSON: {{}}\\\"}}}}\", e))")?;
writeln!(code, " .with_status_code(400);")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " return;")?;
writeln!(code, " }}")?;
writeln!(code, " }};")?;
writeln!(code)?;
writeln!(
code,
" let prompt = format_chat_messages(&req.messages);"
)?;
writeln!(
code,
" let encoding = match tokenizer.encode(prompt.as_str(), true) {{"
)?;
writeln!(code, " Ok(e) => e,")?;
writeln!(code, " Err(e) => {{")?;
writeln!(code, " let resp = tiny_http::Response::from_string(format!(\"{{{{\\\"error\\\":\\\"Tokenization failed: {{}}\\\"}}}}\", e))")?;
writeln!(code, " .with_status_code(500);")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " return;")?;
writeln!(code, " }}")?;
writeln!(code, " }};")?;
writeln!(code, " let prompt_tokens = encoding.get_ids();")?;
writeln!(code, " let stream = req.stream.unwrap_or(false);")?;
writeln!(code, " let max_tokens = req.max_tokens.unwrap_or(256);")?;
writeln!(
code,
" let _temperature = req.temperature.unwrap_or(1.0);"
)?;
writeln!(code)?;
writeln!(code, " model.reset();")?;
writeln!(code)?;
writeln!(code, " let prefill_start = Instant::now();")?;
writeln!(code, " let logits = prefill(model, prompt_tokens);")?;
writeln!(
code,
" let prefill_elapsed = prefill_start.elapsed().as_secs_f64();"
)?;
writeln!(code, " let prefill_count = prompt_tokens.len();")?;
writeln!(code, " let mut next_token = argmax(&logits);")?;
writeln!(code)?;
writeln!(code, " if stream {{")?;
writeln!(
code,
" // SSE streaming: generate tokens and build SSE body"
)?;
writeln!(code, " let gen_start = Instant::now();")?;
writeln!(code, " let mut generated_count: usize = 0;")?;
writeln!(code, " let mut sse_body = String::new();")?;
writeln!(code, " for _ in 0..max_tokens {{")?;
writeln!(code, " if next_token == 2 {{ break; }}")?;
writeln!(
code,
" if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
)?;
writeln!(
code,
" let escaped = serde_json::to_string(&text).unwrap_or_default();"
)?;
writeln!(
code,
" // escaped includes surrounding quotes, strip them"
)?;
writeln!(
code,
" let inner = &escaped[1..escaped.len()-1];"
)?;
writeln!(code, " sse_body.push_str(&format!(")?;
writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{\\\"content\\\":\\\"{{}}\\\"}}}},\\\"finish_reason\\\":null}}}}]}}}}\\n\\n\",")?;
writeln!(code, " inner")?;
writeln!(code, " ));")?;
writeln!(code, " }}")?;
writeln!(code, " generated_count += 1;")?;
writeln!(code, " let logits = model.forward(next_token);")?;
writeln!(code, " next_token = argmax(&logits);")?;
writeln!(code, " }}")?;
writeln!(
code,
" let gen_elapsed = gen_start.elapsed().as_secs_f64();"
)?;
writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
writeln!(code, " let gen_time_ms = gen_elapsed * 1000.0;")?;
writeln!(code)?;
writeln!(
code,
" // Final chunk with finish_reason, timing, and DONE sentinel"
)?;
writeln!(code, " sse_body.push_str(&format!(")?;
writeln!(code, " \"data: {{{{\\\"id\\\":\\\"chatcmpl-1\\\",\\\"object\\\":\\\"chat.completion.chunk\\\",\\\"choices\\\":[{{{{\\\"index\\\":0,\\\"delta\\\":{{{{}}}},\\\"finish_reason\\\":\\\"stop\\\"}}}}],\\\"usage\\\":{{{{\\\"prefill_tokens\\\":{{}},\\\"prefill_time_ms\\\":{{:.1}},\\\"generation_tokens\\\":{{}},\\\"generation_time_ms\\\":{{:.1}},\\\"tokens_per_sec\\\":{{:.1}}}}}}}}}}\\n\\ndata: [DONE]\\n\\n\",")?;
writeln!(code, " prefill_count, prefill_elapsed * 1000.0, generated_count, gen_time_ms, gen_tok_s")?;
writeln!(code, " ));")?;
writeln!(code)?;
writeln!(
code,
" let resp = tiny_http::Response::from_string(sse_body)"
)?;
writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"text/event-stream\").unwrap())")?;
writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Cache-Control\", \"no-cache\").unwrap());")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " }} else {{")?;
writeln!(
code,
" // Non-streaming: generate all tokens, return JSON"
)?;
writeln!(code, " let gen_start = Instant::now();")?;
writeln!(code, " let mut generated_count: usize = 0;")?;
writeln!(code, " let mut generated = String::new();")?;
writeln!(code, " for _ in 0..max_tokens {{")?;
writeln!(code, " if next_token == 2 {{ break; }}")?;
writeln!(
code,
" if let Ok(text) = tokenizer.decode(&[next_token], false) {{"
)?;
writeln!(code, " generated.push_str(&text);")?;
writeln!(code, " }}")?;
writeln!(code, " generated_count += 1;")?;
writeln!(code, " let logits = model.forward(next_token);")?;
writeln!(code, " next_token = argmax(&logits);")?;
writeln!(code, " }}")?;
writeln!(
code,
" let gen_elapsed = gen_start.elapsed().as_secs_f64();"
)?;
writeln!(code, " let gen_tok_s = if gen_elapsed > 0.0 {{ generated_count as f64 / gen_elapsed }} else {{ 0.0 }};")?;
writeln!(code)?;
writeln!(code, " let resp_json = serde_json::json!({{")?;
writeln!(code, " \"id\": \"chatcmpl-1\",")?;
writeln!(code, " \"object\": \"chat.completion\",")?;
writeln!(code, " \"choices\": [{{")?;
writeln!(code, " \"index\": 0,")?;
writeln!(code, " \"message\": {{")?;
writeln!(code, " \"role\": \"assistant\",")?;
writeln!(code, " \"content\": generated")?;
writeln!(code, " }},")?;
writeln!(code, " \"finish_reason\": \"stop\"")?;
writeln!(code, " }}],")?;
writeln!(code, " \"usage\": {{")?;
writeln!(code, " \"prefill_tokens\": prefill_count,")?;
writeln!(
code,
" \"prefill_time_ms\": (prefill_elapsed * 1000.0) as u64,"
)?;
writeln!(
code,
" \"generation_tokens\": generated_count,"
)?;
writeln!(
code,
" \"generation_time_ms\": (gen_elapsed * 1000.0) as u64,"
)?;
writeln!(code, " \"tokens_per_sec\": gen_tok_s")?;
writeln!(code, " }}")?;
writeln!(code, " }});")?;
writeln!(
code,
" let resp = tiny_http::Response::from_string(resp_json.to_string())"
)?;
writeln!(code, " .with_header(tiny_http::Header::from_bytes(\"Content-Type\", \"application/json\").unwrap());")?;
writeln!(code, " request.respond(resp).ok();")?;
writeln!(code, " }}")?;
writeln!(code, "}}")?;
Ok(code)
}
#[cfg(test)]
mod tests {
use super::*;
use forgellm_frontend::ir::{Architecture, DType, ModelConfig};
/// Baseline `ModelConfig` shared by the tests in this module: a two-layer
/// `Architecture::Llama` model with `DType::F32` weights and `qkv_bias` off.
fn minimal_config() -> ModelConfig {
    ModelConfig {
        // Model family, weight dtype and optional features.
        architecture: Architecture::Llama,
        dtype: DType::F32,
        qkv_bias: false,
        sliding_window_size: None,
        // Integer dimensions.
        num_layers: 2,
        hidden_size: 64,
        intermediate_size: 128,
        num_attention_heads: 4,
        num_kv_heads: 4,
        head_dim: 16,
        vocab_size: 256,
        max_seq_len: 512,
        // Floating-point hyper-parameters.
        rms_norm_eps: 1e-5,
        rope_theta: 10000.0,
    }
}
/// Returns a graph named `"test-metal"` that carries `minimal_config()`.
fn minimal_graph() -> Graph {
    let config = minimal_config();
    Graph::new("test-metal").with_config(config)
}
/// Runs the project generator into a temporary directory and checks that each
/// of the four expected output files exists afterwards.
#[test]
fn generate_metal_project_creates_files() {
    let out_dir = tempfile::tempdir().unwrap();
    let graph = minimal_graph();
    generate_metal_project(&graph, out_dir.path(), "test-model").unwrap();

    for (relative_path, why) in [
        ("Cargo.toml", "Cargo.toml should be created"),
        ("src/model.rs", "src/model.rs should be created"),
        ("src/main.rs", "src/main.rs should be created"),
        (
            "shaders/kernels.metal",
            "shaders/kernels.metal should be created",
        ),
    ] {
        assert!(out_dir.path().join(relative_path).exists(), "{why}");
    }
}
/// Checks that the generated manifest text mentions the core runtime crates.
#[test]
fn generated_cargo_toml_has_metal_dep() {
    let manifest = generate_cargo_toml("my-model");
    for (needle, why) in [
        ("metal", "Cargo.toml should depend on metal"),
        ("tokenizers", "Cargo.toml should depend on tokenizers"),
        ("memmap2", "Cargo.toml should depend on memmap2"),
        ("half", "Cargo.toml should depend on half"),
    ] {
        assert!(manifest.contains(needle), "{why}");
    }
}
/// Checks that the generated `model.rs` text contains the `MetalModel` struct,
/// its matmul pipeline field, Metal device/library setup and the `new`/`forward` signatures.
#[test]
fn generated_model_rs_contains_metal_code() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    for (needle, why) in [
        (
            "pub struct MetalModel",
            "model.rs should define MetalModel struct",
        ),
        (
            "matmul_pipeline: ComputePipelineState",
            "MetalModel should have matmul_pipeline field",
        ),
        (
            "Device::system_default()",
            "model.rs should use Metal device",
        ),
        (
            "new_library_with_source",
            "model.rs should compile Metal shaders",
        ),
        (
            "fn new(weights: &[u8])",
            "MetalModel should implement new()",
        ),
        (
            "fn forward(&mut self, token_id: u32)",
            "MetalModel should implement forward()",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated shader source contains a `kernel void <name>`
/// declaration for each of the basic kernels.
#[test]
fn generated_shaders_contain_kernel_names() {
    let shaders = generate_metal_shaders(&minimal_config());
    // (text searched for after "kernel void ", name used in the failure message)
    for (suffix, label) in [
        ("matmul_vec", "matmul_vec"),
        ("rms_norm", "rms_norm"),
        ("rope", "rope"),
        ("softmax", "softmax"),
        // The trailing "(" keeps this from matching `silu_mul_fused`.
        ("silu_mul(", "silu_mul"),
        ("silu_mul_fused", "silu_mul_fused"),
        ("elementwise_add", "elementwise_add"),
        ("attention", "attention"),
        ("add_inplace", "add_inplace"),
        ("copy_buffer", "copy_buffer"),
        ("copy_offset", "copy_offset"),
    ] {
        let needle = format!("kernel void {suffix}");
        assert!(
            shaders.contains(&needle),
            "shaders should contain {label} kernel"
        );
    }
}
/// Checks that the generated shader source mentions the threadgroup / simdgroup
/// identifiers and the `float4` type listed below.
#[test]
fn generated_shaders_use_simdgroup_features() {
    let shaders = generate_metal_shaders(&minimal_config());
    for (needle, why) in [
        (
            "threadgroup_barrier",
            "shaders should use threadgroup barriers",
        ),
        (
            "threadgroup float",
            "shaders should use threadgroup shared memory",
        ),
        (
            "thread_index_in_threadgroup",
            "shaders should use threadgroup indexing",
        ),
        (
            "simd_sum",
            "shaders should use simd_sum for warp-level reduction",
        ),
        (
            "simd_max",
            "attention kernel should use simd_max for cooperative softmax",
        ),
        (
            "thread_index_in_simdgroup",
            "shaders should use simdgroup lane indexing",
        ),
        (
            "simdgroup_index_in_threadgroup",
            "shaders should use simdgroup indexing within threadgroup",
        ),
        ("float4", "matmul_vec should use float4 vectorized loads"),
    ] {
        assert!(shaders.contains(needle), "{why}");
    }
}
/// Checks that the generated `main.rs` text references the tokenizer, the
/// model constructor and forward call, and `memmap2`.
#[test]
fn generated_main_rs_has_tokenizer_usage() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for (needle, why) in [
        (
            "tokenizers::Tokenizer",
            "main.rs should use tokenizers crate",
        ),
        ("MetalModel::new", "main.rs should call MetalModel::new"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "memmap2",
            "main.rs should use memmap2 for zero-copy weight loading",
        ),
    ] {
        assert!(main_rs.contains(needle), "{why}");
    }
}
/// A graph built without a config must make the generator return
/// `MetalCodegenError::MissingConfig`.
#[test]
fn missing_config_returns_error() {
    let out_dir = tempfile::tempdir().unwrap();
    let graph_without_config = Graph::new("no-config");
    let outcome = generate_metal_project(&graph_without_config, out_dir.path(), "fail");
    match outcome {
        Err(MetalCodegenError::MissingConfig) => {}
        _ => panic!("should fail with MissingConfig when graph has no config"),
    }
}
/// Spot-checks `sanitize_name` on a name with a space and punctuation, a name
/// with an underscore, and a name that needs no change.
#[test]
fn sanitize_name_works() {
    for (raw, expected) in [
        ("My Model!", "my-model"),
        ("test_model", "test-model"),
        ("simple", "simple"),
    ] {
        assert_eq!(sanitize_name(raw), expected);
    }
}
/// Slices the generated `forward()` out of `model.rs` and checks it creates one
/// command buffer, commits once and waits one or two times.
#[test]
fn generated_forward_uses_single_command_buffer() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();

    // The slice starts at the `forward` signature and ends at the first marker
    // (tried in this order) that is present; otherwise it runs to the end.
    let start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &model_rs[start..];
    let end = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .find_map(|&marker| tail.find(marker))
    .unwrap_or(tail.len());
    let forward_code = &tail[..end];

    let cmd_buf_count = forward_code.matches("new_command_buffer()").count();
    assert_eq!(
        cmd_buf_count, 1,
        "forward() should create exactly 1 command buffer, found {cmd_buf_count}"
    );

    let commit_count = forward_code.matches("cmd.commit()").count();
    assert_eq!(
        commit_count, 1,
        "forward() should commit exactly once, found {commit_count}"
    );

    let wait_count = forward_code.matches("wait_until_completed()").count();
    assert!(
        (1..=2).contains(&wait_count),
        "forward() should wait 1-2 times (cmd + optional prev_cmd drain), found {wait_count}"
    );
}
/// Checks that the generated `model.rs` declares a `<name>: Buffer` field for
/// each working buffer listed below.
#[test]
fn generated_model_has_preallocated_working_buffers() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    let working_buffers = [
        "normed_buf",
        "qkv_buf",
        "attn_out_buf",
        "attn_proj_buf",
        "gate_up_buf",
        "ffn_hidden_buf",
        "ffn_out_buf",
        "add_tmp_buf",
    ];
    for buf_name in working_buffers {
        let field_decl = format!("{buf_name}: Buffer");
        assert!(
            model_rs.contains(&field_decl),
            "MetalModel should have pre-allocated {buf_name} field"
        );
    }
}
/// Checks that every listed `dispatch_*` helper in the generated `model.rs`
/// starts its parameter list with `&self, enc: &ComputeCommandEncoderRef`.
#[test]
fn generated_dispatch_helpers_take_compute_encoder_ref() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    // Helper names without the common `dispatch_` prefix.
    let helpers = [
        "rms_norm",
        "matmul",
        "rope",
        "rope_offset",
        "attention",
        "attention_offset",
        "silu_mul",
        "silu_mul_fused",
        "add_inplace",
        "copy",
        "copy_offset",
        "copy_from_offset",
        "copy_to_offset",
    ];
    for helper in helpers {
        let method = format!("fn dispatch_{helper}(&self, enc: &ComputeCommandEncoderRef");
        assert!(
            model_rs.contains(&method),
            "model.rs should contain dispatch helper: {method}"
        );
    }
}
/// Checks that nothing from `fn dispatch_rms_norm` to the end of the generated
/// `model.rs` creates, ends, commits or waits on command buffers/encoders.
#[test]
fn generated_helpers_do_not_create_command_buffers_or_encoders() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    let first_helper = model_rs.find("fn dispatch_rms_norm").unwrap();
    let helpers_code = &model_rs[first_helper..];

    for (forbidden, why) in [
        (
            "self.queue.new_command_buffer()",
            "dispatch helpers should not create their own command buffers",
        ),
        (
            "new_compute_command_encoder()",
            "dispatch helpers should not create their own compute encoders",
        ),
        (
            "end_encoding()",
            "dispatch helpers should not call end_encoding",
        ),
        (
            ".commit()",
            "dispatch helpers should not commit command buffers",
        ),
        (
            "wait_until_completed",
            "dispatch helpers should not wait on command buffers",
        ),
    ] {
        assert!(!helpers_code.contains(forbidden), "{why}");
    }
}
/// Slices the generated `forward()` out of `model.rs` and checks it allocates
/// no buffers, opens exactly one compute encoder and opens no blit encoders.
#[test]
fn generated_forward_batches_compute_encoders() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();

    // The slice starts at the `forward` signature and ends at the first marker
    // (tried in this order) that is present; otherwise it runs to the end.
    let start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &model_rs[start..];
    let end = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ]
    .iter()
    .find_map(|&marker| tail.find(marker))
    .unwrap_or(tail.len());
    let forward_code = &tail[..end];

    assert!(
        !forward_code.contains("device.new_buffer"),
        "forward() should not allocate new buffers per call"
    );

    let compute_encoder_count = forward_code
        .matches("new_compute_command_encoder()")
        .count();
    let blit_encoder_count = forward_code.matches("new_blit_command_encoder()").count();
    assert_eq!(
        compute_encoder_count, 1,
        "forward() should use exactly 1 compute encoder for the entire pass, found {compute_encoder_count}"
    );
    assert_eq!(
        blit_encoder_count, 0,
        "forward() should have zero blit encoders (replaced by compute copy kernels), found {blit_encoder_count}"
    );
}
/// Checks that the generated `model.rs` mentions `dispatch_add_inplace` and
/// `add_inplace_pipeline`.
// NOTE(review): both searches cover the whole of model.rs, not just the body of
// forward(), so the helper definition alone may satisfy the first check.
#[test]
fn generated_forward_uses_add_inplace() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    for (needle, why) in [
        (
            "dispatch_add_inplace",
            "forward() should use dispatch_add_inplace for residual connections",
        ),
        (
            "add_inplace_pipeline",
            "MetalModel should have add_inplace_pipeline",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Same small Llama-architecture config as the F32 baseline, but with
/// `DType::Q8_0` weights.
fn minimal_q8_config() -> ModelConfig {
    ModelConfig {
        // Model family, weight dtype and optional features.
        architecture: Architecture::Llama,
        dtype: DType::Q8_0,
        qkv_bias: false,
        sliding_window_size: None,
        // Integer dimensions.
        num_layers: 2,
        hidden_size: 64,
        intermediate_size: 128,
        num_attention_heads: 4,
        num_kv_heads: 4,
        head_dim: 16,
        vocab_size: 256,
        max_seq_len: 512,
        // Floating-point hyper-parameters.
        rms_norm_eps: 1e-5,
        rope_theta: 10000.0,
    }
}
/// Checks that the shader source generated for the baseline config contains the
/// `matmul_vec_q8` kernel and the type/bitcast spellings listed below.
#[test]
fn generated_shaders_contain_q8_kernel() {
    let shaders = generate_metal_shaders(&minimal_config());
    for (needle, why) in [
        (
            "kernel void matmul_vec_q8",
            "shaders should contain matmul_vec_q8 kernel",
        ),
        (
            "device const uchar* matrix",
            "matmul_vec_q8 should accept raw Q8_0 bytes",
        ),
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads for int8 data",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "device const half*",
            "matmul_vec_q8 should read f16 scale via half pointer",
        ),
    ] {
        assert!(shaders.contains(needle), "{why}");
    }
}
/// Checks that the generated `model.rs` declares the fused QKV / gate+up weight
/// and working buffers, has none of the separate per-projection weight fields,
/// and mentions the fused/offset dispatch helpers.
#[test]
fn generated_model_uses_fused_qkv_projections() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    // (needle, whether it must be present, failure message)
    let expectations = [
        (
            "qkv_weight: Buffer",
            true,
            "LayerBuffers should have fused qkv_weight field",
        ),
        (
            " q_weight: Buffer",
            false,
            "LayerBuffers should not have separate q_weight field",
        ),
        (
            " k_weight: Buffer",
            false,
            "LayerBuffers should not have separate k_weight field",
        ),
        (
            " v_weight: Buffer",
            false,
            "LayerBuffers should not have separate v_weight field",
        ),
        (
            "gate_up_weight: Buffer",
            true,
            "LayerBuffers should have fused gate_up_weight field",
        ),
        (
            " gate_weight: Buffer",
            false,
            "LayerBuffers should not have separate gate_weight field",
        ),
        (
            " up_weight: Buffer",
            false,
            "LayerBuffers should not have separate up_weight field",
        ),
        (
            "qkv_buf: Buffer",
            true,
            "MetalModel should have fused qkv_buf",
        ),
        (
            "gate_up_buf: Buffer",
            true,
            "MetalModel should have fused gate_up_buf",
        ),
        (
            "dispatch_silu_mul_fused",
            true,
            "forward pass should use dispatch_silu_mul_fused",
        ),
        (
            "dispatch_rope_offset",
            true,
            "forward pass should use dispatch_rope_offset for fused QKV",
        ),
        (
            "dispatch_attention_offset",
            true,
            "forward pass should use dispatch_attention_offset for fused QKV",
        ),
    ];
    for (needle, must_be_present, why) in expectations {
        assert!(model_rs.contains(needle) == must_be_present, "{why}");
    }
}
/// Checks that `model.rs` generated for the Q8_0 config both declares the
/// `matmul_q8_pipeline` field and lists it followed by a comma.
#[test]
fn q8_model_has_matmul_q8_pipeline() {
    let model_rs = generate_model_rs(&minimal_q8_config()).unwrap();
    for (needle, why) in [
        (
            "matmul_q8_pipeline: ComputePipelineState",
            "MetalModel should have matmul_q8_pipeline field",
        ),
        (
            "matmul_q8_pipeline,",
            "MetalModel Self should include matmul_q8_pipeline",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that `model.rs` generated for the Q8_0 config mentions and defines
/// `dispatch_matmul_q8`.
#[test]
fn q8_model_uses_dispatch_matmul_q8() {
    let model_rs = generate_model_rs(&minimal_q8_config()).unwrap();
    for (needle, why) in [
        (
            "dispatch_matmul_q8",
            "Q8_0 model should use dispatch_matmul_q8 for projections",
        ),
        (
            "fn dispatch_matmul_q8",
            "model.rs should define dispatch_matmul_q8 method",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that `model.rs` generated for the Q8_0 config contains no
/// `f16_to_f32` / `f32_data` text and does contain `total_raw as u64`.
#[test]
fn q8_model_loads_raw_bytes_not_dequantized() {
    let model_rs = generate_model_rs(&minimal_q8_config()).unwrap();
    // (needle, whether it must be present, failure message)
    for (needle, must_be_present, why) in [
        (
            "f16_to_f32",
            false,
            "Q8_0 model should not dequantize weights to f32",
        ),
        (
            "f32_data",
            false,
            "Q8_0 model should not create f32 weight data",
        ),
        (
            "total_raw as u64",
            true,
            "Q8_0 model should load raw bytes into Metal buffer",
        ),
    ] {
        assert!(model_rs.contains(needle) == must_be_present, "{why}");
    }
}
/// Checks that `model.rs` generated for the Q8_0 config initialises the three
/// norm buffers with `next_f32_buffer`.
#[test]
fn q8_model_norms_stay_f32() {
    let model_rs = generate_model_rs(&minimal_q8_config()).unwrap();
    for (needle, why) in [
        (
            "let attn_norm = next_f32_buffer",
            "attn_norm should use f32 buffer even for Q8_0 models",
        ),
        (
            "let ffn_norm = next_f32_buffer",
            "ffn_norm should use f32 buffer even for Q8_0 models",
        ),
        (
            "let norm_buf = next_f32_buffer",
            "final norm should use f32 buffer even for Q8_0 models",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks which `next_q8_*` loader the Q8_0 `model.rs` uses for each of the
/// four projection weights.
#[test]
fn q8_model_uses_fused_weight_loading() {
    let model_rs = generate_model_rs(&minimal_q8_config()).unwrap();
    for (needle, why) in [
        (
            "let qkv_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused QKV weights",
        ),
        (
            "let gate_up_weight = next_q8_fused_buffer",
            "Q8_0 model should use next_q8_fused_buffer for fused gate+up weights",
        ),
        (
            "let o_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for O weight",
        ),
        (
            "let down_weight = next_q8_buffer",
            "Q8_0 model should use next_q8_buffer for down weight",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// For the F32 config, checks that the text from the `forward` signature up to
/// the first `"\n fn dispatch_"` marker never mentions `dispatch_matmul_q8`.
#[test]
fn f32_model_does_not_use_q8_dispatch() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    let start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &model_rs[start..];
    // Without the marker the whole tail is inspected.
    let forward_code = match tail.find("\n fn dispatch_") {
        Some(end) => &tail[..end],
        None => tail,
    };
    assert!(
        !forward_code.contains("dispatch_matmul_q8"),
        "f32 model forward should not use dispatch_matmul_q8"
    );
}
/// Checks the leading part of the `dispatch_matmul_q8` signature in the Q8_0
/// `model.rs`.
#[test]
fn q8_dispatch_helper_takes_compute_encoder_ref() {
    let model_rs = generate_model_rs(&minimal_q8_config()).unwrap();
    let signature_prefix = "fn dispatch_matmul_q8(&self, enc: &ComputeCommandEncoderRef";
    assert!(
        model_rs.contains(signature_prefix),
        "dispatch_matmul_q8 should take ComputeCommandEncoderRef"
    );
}
/// Checks that the generated `model.rs` has the `prev_cmd` field, the
/// `forward_prefill` method and the `prev_cmd.take()` statement.
#[test]
fn generated_model_has_double_buffered_prefill() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    for (needle, why) in [
        (
            "prev_cmd: Option<CommandBuffer>",
            "MetalModel should have prev_cmd field for double-buffered prefill",
        ),
        (
            "pub fn forward_prefill(&mut self, token_id: u32)",
            "MetalModel should have forward_prefill method",
        ),
        (
            "if let Some(prev) = self.prev_cmd.take()",
            "forward() should drain prev_cmd from previous prefill",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated `main.rs` mentions `forward_prefill` and the
/// phrase `double-buffered`.
#[test]
fn generated_main_rs_uses_forward_prefill_for_prompt() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for (needle, why) in [
        (
            "forward_prefill",
            "main.rs should use forward_prefill for intermediate prompt tokens",
        ),
        (
            "double-buffered",
            "main.rs should document double-buffered prefill",
        ),
    ] {
        assert!(main_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated shader source contains the wide-load, bitcast and
/// `dot(` spellings listed below.
#[test]
fn generated_shaders_q8_uses_wide_vectorized_loads() {
    let shaders = generate_metal_shaders(&minimal_config());
    for (needle, why) in [
        (
            "packed_short4",
            "matmul_vec_q8 should use packed_short4 wide 64-bit loads",
        ),
        (
            "d0[0]",
            "matmul_vec_q8 should index the wide pointer for row 0",
        ),
        (
            "as_type<char2>",
            "matmul_vec_q8 should bitcast short lanes to char2",
        ),
        (
            "dot(",
            "matmul_vec_q8 should use dot() intrinsic for fma accumulation",
        ),
    ] {
        assert!(shaders.contains(needle), "{why}");
    }
}
/// Same small Llama-architecture config as the F32 baseline, but with
/// `DType::Q4_0` weights.
fn minimal_q4_config() -> ModelConfig {
    ModelConfig {
        // Model family, weight dtype and optional features.
        architecture: Architecture::Llama,
        dtype: DType::Q4_0,
        qkv_bias: false,
        sliding_window_size: None,
        // Integer dimensions.
        num_layers: 2,
        hidden_size: 64,
        intermediate_size: 128,
        num_attention_heads: 4,
        num_kv_heads: 4,
        head_dim: 16,
        vocab_size: 256,
        max_seq_len: 512,
        // Floating-point hyper-parameters.
        rms_norm_eps: 1e-5,
        rope_theta: 10000.0,
    }
}
/// Checks that the generated shader source contains the `matmul_vec_q4` kernel
/// and the two `Q4_ROWS_PER_*` identifiers.
#[test]
fn generated_shaders_contain_q4_kernel() {
    let shaders = generate_metal_shaders(&minimal_config());
    for (needle, why) in [
        (
            "kernel void matmul_vec_q4",
            "shaders should contain matmul_vec_q4 kernel",
        ),
        (
            "Q4_ROWS_PER_TG",
            "shaders should define Q4_ROWS_PER_TG constant",
        ),
        (
            "Q4_ROWS_PER_SG",
            "shaders should define Q4_ROWS_PER_SG constant",
        ),
    ] {
        assert!(shaders.contains(needle), "{why}");
    }
}
/// Checks that the generated shader source contains the `uchar4`, nibble-mask,
/// shift, `-8)` and `blk * 18` spellings listed below.
#[test]
fn generated_shaders_q4_uses_uchar4_nibble_unpacking() {
    let shaders = generate_metal_shaders(&minimal_config());
    for (needle, why) in [
        (
            "uchar4",
            "matmul_vec_q4 should use uchar4 for packed byte loads",
        ),
        (
            "&0xF",
            "matmul_vec_q4 should extract low nibble with &0xF",
        ),
        (
            ">>4",
            "matmul_vec_q4 should extract high nibble with >>4",
        ),
        (
            "-8)",
            "matmul_vec_q4 should subtract 8 for unsigned-to-signed conversion",
        ),
        (
            "blk * 18",
            "matmul_vec_q4 should use 18-byte block stride",
        ),
    ] {
        assert!(shaders.contains(needle), "{why}");
    }
}
/// Checks that `model.rs` generated for the Q4_0 config both declares the
/// `matmul_q4_pipeline` field and lists it followed by a comma.
#[test]
fn q4_model_has_matmul_q4_pipeline() {
    let model_rs = generate_model_rs(&minimal_q4_config()).unwrap();
    for (needle, why) in [
        (
            "matmul_q4_pipeline: ComputePipelineState",
            "MetalModel should have matmul_q4_pipeline field",
        ),
        (
            "matmul_q4_pipeline,",
            "MetalModel Self should include matmul_q4_pipeline",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that `model.rs` generated for the Q4_0 config mentions and defines
/// `dispatch_matmul_q4`.
#[test]
fn q4_model_uses_dispatch_matmul_q4() {
    let model_rs = generate_model_rs(&minimal_q4_config()).unwrap();
    for (needle, why) in [
        (
            "dispatch_matmul_q4",
            "Q4_0 model should use dispatch_matmul_q4 for projections",
        ),
        (
            "fn dispatch_matmul_q4",
            "model.rs should define dispatch_matmul_q4 method",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that `model.rs` generated for the Q4_0 config contains no
/// `f16_to_f32` text and does contain `total_raw as u64`.
#[test]
fn q4_model_loads_raw_bytes_not_dequantized() {
    let model_rs = generate_model_rs(&minimal_q4_config()).unwrap();
    // (needle, whether it must be present, failure message)
    for (needle, must_be_present, why) in [
        (
            "f16_to_f32",
            false,
            "Q4_0 model should not dequantize weights to f32",
        ),
        (
            "total_raw as u64",
            true,
            "Q4_0 model should load raw bytes into Metal buffer",
        ),
    ] {
        assert!(model_rs.contains(needle) == must_be_present, "{why}");
    }
}
/// Checks that `model.rs` generated for the Q4_0 config initialises the three
/// norm buffers with `next_f32_buffer`.
#[test]
fn q4_model_norms_stay_f32() {
    let model_rs = generate_model_rs(&minimal_q4_config()).unwrap();
    for (needle, why) in [
        (
            "let attn_norm = next_f32_buffer",
            "attn_norm should use f32 buffer even for Q4_0 models",
        ),
        (
            "let ffn_norm = next_f32_buffer",
            "ffn_norm should use f32 buffer even for Q4_0 models",
        ),
        (
            "let norm_buf = next_f32_buffer",
            "final norm should use f32 buffer even for Q4_0 models",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks which `next_q4_*` loader the Q4_0 `model.rs` uses for each of the
/// four projection weights.
#[test]
fn q4_model_uses_fused_weight_loading() {
    let model_rs = generate_model_rs(&minimal_q4_config()).unwrap();
    for (needle, why) in [
        (
            "let qkv_weight = next_q4_fused_buffer",
            "Q4_0 model should use next_q4_fused_buffer for fused QKV weights",
        ),
        (
            "let gate_up_weight = next_q4_fused_buffer",
            "Q4_0 model should use next_q4_fused_buffer for fused gate+up weights",
        ),
        (
            "let o_weight = next_q4_buffer",
            "Q4_0 model should use next_q4_buffer for O weight",
        ),
        (
            "let down_weight = next_q4_buffer",
            "Q4_0 model should use next_q4_buffer for down weight",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that the shader source contains the `attention_flash_batch` kernel and
/// `FLASH_K_TILE`, and that `model.rs` mentions the matching pipeline.
#[test]
fn attention_flash_batch_kernel_exists() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let shaders = generate_metal_shaders(&config);
    // (generated text to search, needle, failure message)
    for (text, needle, why) in [
        (
            &shaders,
            "kernel void attention_flash_batch",
            "shaders.metal must still contain the attention_flash_batch kernel",
        ),
        (
            &shaders,
            "FLASH_K_TILE",
            "flash kernel must tile K/V with a FLASH_K_TILE constant",
        ),
        (
            &model_rs,
            "attention_flash_batch_pipeline",
            "MetalModel must register the flash attention pipeline",
        ),
    ] {
        assert!(text.contains(needle), "{why}");
    }
}
/// Inspects the generated text between `pub fn forward(` and
/// `pub fn forward_prefill(`: it must call the fused rope and KV-copy helpers
/// and must not mention `dispatch_copy_from_offset_f16`.
#[test]
fn decode_uses_fused_rope_and_kv_copy() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    let forward_at = model_rs.find("pub fn forward(").expect("forward missing");
    let (forward_body, _) = model_rs[forward_at..]
        .split_once("pub fn forward_prefill(")
        .expect("forward_prefill missing");

    // (needle, whether it must be present, failure message)
    for (needle, must_be_present, why) in [
        (
            "dispatch_rope_qk_batch(&enc, &self.qkv_buf, 1",
            true,
            "single-token forward must use fused rope_qk_batch with M=1",
        ),
        (
            "dispatch_copy_kv_both_batch(&enc, &self.qkv_buf,",
            true,
            "single-token forward must use fused copy_kv_both_batch",
        ),
        (
            "dispatch_copy_from_offset_f16",
            false,
            "decode path should no longer call the per-K/per-V copy",
        ),
    ] {
        assert!(forward_body.contains(needle) == must_be_present, "{why}");
    }
}
/// Checks that each listed attention kernel declares `k_cache`/`v_cache` behind
/// `device const half*` (and neither as `device const float*`), that the
/// f32->f16 copy kernel is emitted, and that `model.rs` mentions its dispatch helper.
#[test]
fn kv_cache_stored_as_f16() {
    let config = minimal_config();
    let shaders = generate_metal_shaders(&config);
    let model_rs = generate_model_rs(&config).unwrap();
    for kernel in [
        "kernel void attention(",
        "kernel void attention_batch(",
        "kernel void attention_flash_batch(",
        "kernel void attention_mma_flash_batch(",
    ] {
        let start = shaders
            .find(kernel)
            .unwrap_or_else(|| panic!("kernel {kernel} missing"));
        let rest = &shaders[start..];
        // The signature ends at whichever of `){` / `) {` comes first. Preferring
        // one spelling outright could pick a match far inside a later body and
        // drag unrelated code into `sig`.
        let sig_end = match (rest.find("){"), rest.find(") {")) {
            (Some(tight), Some(spaced)) => tight.min(spaced),
            (Some(tight), None) => tight,
            (None, Some(spaced)) => spaced,
            (None, None) => panic!("kernel {kernel} has no `){{` or `) {{` after its signature"),
        };
        let sig = &rest[..sig_end];
        assert!(
            sig.contains("device const half*")
                && sig.contains("k_cache")
                && sig.contains("v_cache"),
            "{kernel} must read k_cache/v_cache as `device const half*`"
        );
        // Both caches are checked; testing `k_cache` twice would let a float
        // `v_cache` parameter slip through.
        assert!(
            !sig.contains("device const float* k_cache")
                && !sig.contains("device const float* v_cache"),
            "{kernel} still reads k_cache/v_cache as float"
        );
    }
    assert!(
        shaders.contains("kernel void copy_f32_to_f16_offset"),
        "f32->f16 copy kernel must be present for single-token decode KV writes"
    );
    // NOTE(review): this searches all of model.rs; another test in this module
    // requires forward() itself NOT to mention this helper — confirm the message.
    assert!(
        model_rs.contains("dispatch_copy_from_offset_f16"),
        "single-token decode must dispatch the f32->f16 KV copy"
    );
}
/// Takes the shader text from `kernel void attention(` up to the next
/// `kernel void ` and checks it contains the vector pointer types and
/// `head_dim4` listed below.
#[test]
fn decode_attention_uses_half4_vectorized_loads() {
    let shaders = generate_metal_shaders(&minimal_config());
    let kernel_at = shaders
        .find("kernel void attention(")
        .expect("decode attention kernel missing");
    let from_kernel = &shaders[kernel_at..];
    // Skip the first byte so the search cannot match this kernel's own header.
    let body_len = 1 + from_kernel[1..]
        .find("kernel void ")
        .expect("next kernel missing");
    let body = &from_kernel[..body_len];

    for (needle, why) in [
        (
            "device const half4*",
            "decode attention must half4-load K/V",
        ),
        (
            "device const float4*",
            "decode attention must float4-load Q",
        ),
        (
            "device float4*",
            "decode attention must float4-store output",
        ),
        (
            "head_dim4",
            "decode attention must iterate head_dim in chunks of 4",
        ),
    ] {
        assert!(body.contains(needle), "{why}");
    }
}
/// Checks the shader source for the MMA flash kernel and its identifiers, and
/// `model.rs` for the matching pipeline, `mma_opt_out` and the selection condition.
#[test]
fn attention_mma_flash_batch_kernel_wired() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let shaders = generate_metal_shaders(&config);
    // (generated text to search, needle, failure message)
    for (text, needle, why) in [
        (
            &shaders,
            "kernel void attention_mma_flash_batch",
            "shaders.metal must contain the MMA flash kernel",
        ),
        (
            &shaders,
            "FLASH_MMA_Q_BLOCK",
            "MMA flash kernel must define Q_BLOCK tiling constant",
        ),
        (
            &shaders,
            "simdgroup_multiply_accumulate",
            "MMA flash kernel must use hardware MMA",
        ),
        (
            &model_rs,
            "attention_mma_flash_batch_pipeline",
            "MetalModel must register the MMA flash pipeline",
        ),
        (
            &model_rs,
            "mma_opt_out",
            "dispatch_attention_batch must read FORGE_MMA_ATTN as opt-out",
        ),
        (
            &model_rs,
            "!mma_opt_out && HEAD_DIM <= 128 && num_tokens >= 8",
            "MMA flash must be default-on when HEAD_DIM ≤ 128 and num_tokens ≥ 8",
        ),
    ] {
        assert!(text.contains(needle), "{why}");
    }
}
/// Checks that the generated `model.rs` contains the `chunks(MAX_BATCH_SIZE)`
/// loop and no longer contains the `tokens.len().min(MAX_BATCH_SIZE)` expression.
#[test]
fn forward_prefill_batch_chunks_by_max_batch_size() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    // (needle, whether it must be present, failure message)
    for (needle, must_be_present, why) in [
        (
            "for chunk in tokens.chunks(MAX_BATCH_SIZE)",
            true,
            "forward_prefill_batch must chunk long prompts",
        ),
        (
            "tokens.len().min(MAX_BATCH_SIZE)",
            false,
            "the old truncation path must be gone",
        ),
    ] {
        assert!(model_rs.contains(needle) == must_be_present, "{why}");
    }
}
/// With `Architecture::Qwen2` and `qkv_bias` enabled, checks that `model.rs`
/// declares, loads and dispatches the QKV bias, and that the shader source
/// contains the `add_bias_batch` kernel.
#[test]
fn qwen2_qkv_bias_wired_through_metal_codegen() {
    let mut config = minimal_config();
    config.architecture = Architecture::Qwen2;
    config.qkv_bias = true;

    let model_rs = generate_model_rs(&config).unwrap();
    for (needle, why) in [
        (
            "qkv_bias: Buffer",
            "Qwen2 LayerBuffers must declare qkv_bias field",
        ),
        (
            "let qkv_bias = next_f32_buffer",
            "Qwen2 layer init must load the bias from the weight blob",
        ),
        (
            "add_bias_batch_pipeline",
            "Qwen2 model struct must include the add_bias_batch_pipeline",
        ),
        (
            "fn dispatch_add_bias_batch",
            "Qwen2 codegen must emit dispatch_add_bias_batch helper",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.batch_qkv_buf",
            "forward_prefill_batch must call dispatch_add_bias_batch on batch_qkv_buf",
        ),
        (
            "dispatch_add_bias_batch(&enc, &self.qkv_buf",
            "forward must call dispatch_add_bias_batch on the single-token qkv_buf",
        ),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }

    let shaders = generate_metal_shaders(&config);
    assert!(
        shaders.contains("kernel void add_bias_batch"),
        "shaders.metal must contain the add_bias_batch kernel"
    );
}
/// With the baseline config (where `qkv_bias` is false), checks that `model.rs`
/// contains none of the QKV-bias field, pipeline or dispatch names.
#[test]
fn llama_does_not_emit_qkv_bias_machinery() {
    let config = minimal_config();
    assert!(!config.qkv_bias);

    let model_rs = generate_model_rs(&config).unwrap();
    for (forbidden, why) in [
        ("qkv_bias: Buffer", "Llama must not have qkv_bias field"),
        (
            "add_bias_batch_pipeline",
            "Llama must not pull in add_bias_batch_pipeline",
        ),
        (
            "dispatch_add_bias_batch",
            "Llama must not call dispatch_add_bias_batch",
        ),
    ] {
        assert!(!model_rs.contains(forbidden), "{why}");
    }
}
/// Checks the leading part of the `dispatch_matmul_q4` signature in the Q4_0
/// `model.rs`.
#[test]
fn q4_dispatch_helper_takes_compute_encoder_ref() {
    let model_rs = generate_model_rs(&minimal_q4_config()).unwrap();
    let signature_prefix = "fn dispatch_matmul_q4(&self, enc: &ComputeCommandEncoderRef";
    assert!(
        model_rs.contains(signature_prefix),
        "dispatch_matmul_q4 should take ComputeCommandEncoderRef"
    );
}
/// For the F32 config, checks that the text from the `forward` signature up to
/// the first `"\n fn dispatch_"` marker never mentions `dispatch_matmul_q4`.
#[test]
fn f32_model_does_not_use_q4_dispatch() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    let start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .unwrap();
    let tail = &model_rs[start..];
    // Without the marker the whole tail is inspected.
    let forward_code = match tail.find("\n fn dispatch_") {
        Some(end) => &tail[..end],
        None => tail,
    };
    assert!(
        !forward_code.contains("dispatch_matmul_q4"),
        "f32 model forward should not use dispatch_matmul_q4"
    );
}
/// Checks that the Q4_0 `model.rs` initialises `lm_head_buf` with `next_q4_buffer`.
#[test]
fn q4_model_lm_head_uses_q4_buffer() {
    let model_rs = generate_model_rs(&minimal_q4_config()).unwrap();
    let lm_head_init = "let lm_head_buf = next_q4_buffer";
    assert!(
        model_rs.contains(lm_head_init),
        "Q4_0 model should use next_q4_buffer for lm_head"
    );
}
/// Checks that the `vec_tile[...]` size in the generated shaders follows the
/// config: 128 for the baseline, 8192 (and not 4096) when
/// `hidden_size = 2048` and `intermediate_size = 8192`.
#[test]
fn vec_tile_size_matches_model_dimensions() {
    let shaders_small = generate_metal_shaders(&minimal_config());
    assert!(
        shaders_small.contains("vec_tile[128]"),
        "vec_tile should be sized to max(hidden, intermediate) = 128"
    );

    let large = ModelConfig {
        hidden_size: 2048,
        intermediate_size: 8192,
        ..minimal_config()
    };
    let shaders_large = generate_metal_shaders(&large);
    // (needle, whether it must be present, failure message)
    for (needle, must_be_present, why) in [
        (
            "vec_tile[8192]",
            true,
            "vec_tile should be 8192 for models with intermediate=8192",
        ),
        (
            "vec_tile[4096]",
            false,
            "vec_tile should NOT be hardcoded to 4096",
        ),
    ] {
        assert!(shaders_large.contains(needle) == must_be_present, "{why}");
    }
}
/// Checks that the generated manifest declares the HTTP-server dependencies:
/// `tiny_http`, `serde` and `serde_json`.
#[test]
fn generated_cargo_toml_has_server_deps() {
    let toml = generate_cargo_toml("my-model");
    assert!(
        toml.contains("tiny_http"),
        "Cargo.toml should depend on tiny_http"
    );
    // Match the `serde = ` dependency key: a bare "serde" is a substring of
    // "serde_json", so it would pass even if the serde entry were removed.
    assert!(
        toml.contains("serde = "),
        "Cargo.toml should depend on serde"
    );
    assert!(
        toml.contains("serde_json"),
        "Cargo.toml should depend on serde_json"
    );
}
/// Checks that the generated `main.rs` mentions the `--serve` / `--port` flags,
/// a `serve` function and `tiny_http::Server::http`.
#[test]
fn generated_main_rs_has_serve_mode() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for (needle, why) in [
        ("--serve", "main.rs should parse --serve flag"),
        ("--port", "main.rs should parse --port flag"),
        ("fn serve(", "main.rs should define serve function"),
        (
            "tiny_http::Server::http",
            "main.rs should create tiny_http server",
        ),
    ] {
        assert!(main_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated `main.rs` mentions each of the three HTTP paths.
#[test]
fn generated_main_rs_has_chat_completions_endpoint() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for path in ["/v1/chat/completions", "/v1/models", "/health"] {
        assert!(
            main_rs.contains(path),
            "main.rs should handle {path} endpoint"
        );
    }
}
/// Checks that the generated `main.rs` contains the event-stream content type,
/// the chunk object name and the `[DONE]` marker.
#[test]
fn generated_main_rs_has_sse_streaming() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for (needle, why) in [
        (
            "text/event-stream",
            "main.rs should set SSE content type for streaming",
        ),
        ("chat.completion.chunk", "main.rs should emit SSE chunks"),
        ("[DONE]", "main.rs should emit [DONE] sentinel"),
    ] {
        assert!(main_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated `main.rs` defines `format_chat_messages` and
/// contains the `<|im_start|>` / `<|im_end|>` markers.
#[test]
fn generated_main_rs_has_chat_message_formatting() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for (needle, why) in [
        (
            "fn format_chat_messages",
            "main.rs should define format_chat_messages function",
        ),
        ("<|im_start|>", "main.rs should use ChatML format"),
        ("<|im_end|>", "main.rs should use ChatML format"),
    ] {
        assert!(main_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated `main.rs` declares `ChatRequest` and `ChatMessage`
/// and mentions `Deserialize`.
#[test]
fn generated_main_rs_has_request_types() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for (needle, why) in [
        (
            "struct ChatRequest",
            "main.rs should define ChatRequest struct",
        ),
        (
            "struct ChatMessage",
            "main.rs should define ChatMessage struct",
        ),
        (
            "Deserialize",
            "main.rs should derive Deserialize for request types",
        ),
    ] {
        assert!(main_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated `model.rs` has a `reset` method and a
/// `self.pos = 0` assignment.
#[test]
fn generated_model_has_reset_method() {
    let model_rs = generate_model_rs(&minimal_config()).unwrap();
    for (needle, why) in [
        (
            "pub fn reset(&mut self)",
            "model.rs should have a reset() method for multi-request serving",
        ),
        ("self.pos = 0", "reset() should reset position to 0"),
    ] {
        assert!(model_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated `main.rs` defines `cli_mode` and mentions both
/// `model.forward` and `model.forward_prefill`.
#[test]
fn generated_main_rs_cli_mode_still_works() {
    let main_rs = generate_main_rs("test-model", &minimal_config()).unwrap();
    for (needle, why) in [
        ("fn cli_mode(", "main.rs should define cli_mode function"),
        ("model.forward", "main.rs should call model.forward"),
        (
            "model.forward_prefill",
            "main.rs should call model.forward_prefill",
        ),
    ] {
        assert!(main_rs.contains(needle), "{why}");
    }
}
/// Checks that the generated shader source contains a `kernel void <name>`
/// declaration for each batch kernel listed below.
#[test]
fn generated_shaders_contain_batch_kernels() {
    let shaders = generate_metal_shaders(&minimal_config());
    // (kernel name, extra text appended to the failure message)
    for (kernel, note) in [
        ("matmul_vec_batch", ""),
        ("matmul_vec_q8_batch", ""),
        ("matmul_q8_gemm_batch", " (weight-reuse GEMM)"),
        ("matmul_vec_q4_batch", ""),
        ("rms_norm_batch", ""),
        ("silu_mul_fused_batch", ""),
        ("add_inplace_batch", ""),
        ("copy_embedding_batch", ""),
    ] {
        let needle = format!("kernel void {kernel}");
        assert!(
            shaders.contains(&needle),
            "shaders should contain {kernel} kernel{note}"
        );
    }
}
/// Every batched pipeline must appear in the generated `model.rs` as a
/// `<name>: ComputePipelineState` declaration.
#[test]
fn generated_model_has_batch_pipelines() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let expected_fields = [
        "matmul_batch_pipeline",
        "matmul_q8_batch_pipeline",
        "matmul_q8_gemm_batch_pipeline",
        "matmul_q4_batch_pipeline",
        "rms_norm_batch_pipeline",
        "rope_batch_pipeline",
        "silu_mul_fused_batch_pipeline",
        "add_inplace_batch_pipeline",
        "copy_embedding_batch_pipeline",
        "attention_batch_pipeline",
        "copy_kv_batch_pipeline",
    ];
    for field in expected_fields {
        let declaration = format!("{field}: ComputePipelineState");
        assert!(
            generated.contains(&declaration),
            "MetalModel should have {field} field"
        );
    }
}
/// Every batched scratch buffer must appear in the generated `model.rs` as a
/// `<name>: Buffer` declaration.
#[test]
fn generated_model_has_batch_buffers() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let expected_fields = [
        "batch_hidden_buf",
        "batch_residual_buf",
        "batch_qkv_buf",
        "batch_attn_out_buf",
        "batch_attn_proj_buf",
        "batch_gate_up_buf",
        "batch_ffn_hidden_buf",
        "batch_ffn_out_buf",
        "batch_tokens_buf",
        "batch_positions_buf",
    ];
    for field in expected_fields {
        let declaration = format!("{field}: Buffer");
        assert!(
            generated.contains(&declaration),
            "MetalModel should have {field} field"
        );
    }
}
/// The generated `model.rs` must declare `forward_prefill_batch` and contain a
/// single-token call to it (`self.forward_prefill_batch(&[token_id])`).
#[test]
fn generated_model_has_forward_prefill_batch() {
    let cfg = minimal_config();
    let generated = generate_model_rs(&cfg).unwrap();

    let signature = "pub fn forward_prefill_batch(&mut self, tokens: &[u32])";
    assert!(
        generated.contains(signature),
        "MetalModel should have forward_prefill_batch method"
    );

    let delegation = "self.forward_prefill_batch(&[token_id])";
    assert!(
        generated.contains(delegation),
        "forward_prefill should delegate to forward_prefill_batch"
    );
}
/// The generated `model.rs` must declare `MAX_BATCH_SIZE` as a public `usize` constant of 512.
#[test]
fn generated_model_has_max_batch_size_constant() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let declaration = "pub const MAX_BATCH_SIZE: usize = 512";
    assert!(
        generated.contains(declaration),
        "model.rs should define MAX_BATCH_SIZE constant"
    );
}
/// The text of `forward_prefill_batch` (from its signature up to the next
/// `pub fn reset`, or to the end of the source when that marker is absent)
/// must reference every batched dispatch helper listed below.
#[test]
fn forward_prefill_batch_uses_batch_dispatch() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let signature = "pub fn forward_prefill_batch(&mut self, tokens: &[u32])";
    let from_signature = &generated[generated.find(signature).unwrap()..];
    let method_text = match from_signature.find("\n pub fn reset") {
        Some(end) => &from_signature[..end],
        None => from_signature,
    };
    let helpers = [
        "dispatch_rms_norm_batch",
        "dispatch_copy_embedding_batch",
        "dispatch_silu_mul_fused_batch",
        "dispatch_attention_batch",
        "dispatch_copy_kv_both_batch",
        "dispatch_rope_qk_batch",
    ];
    for helper in helpers {
        assert!(
            method_text.contains(helper),
            "forward_prefill_batch should use {helper}"
        );
    }
}
/// With the Q8 config, the text of `forward_prefill_batch` (up to the next
/// `pub fn reset`, if present) must reference `dispatch_matmul_q8_batch`.
#[test]
fn q8_forward_prefill_batch_uses_q8_batch_dispatch() {
    let generated = generate_model_rs(&minimal_q8_config()).unwrap();
    let signature_at = generated
        .find("pub fn forward_prefill_batch(&mut self, tokens: &[u32])")
        .unwrap();
    let from_signature = &generated[signature_at..];
    // Keep only the text before the first reset marker; keep everything if there is none.
    let method_text = from_signature
        .split_once("\n pub fn reset")
        .map_or(from_signature, |(before, _)| before);
    assert!(
        method_text.contains("dispatch_matmul_q8_batch"),
        "Q8 forward_prefill_batch should use dispatch_matmul_q8_batch"
    );
}
/// With the Q4 config, the text of `forward_prefill_batch` (up to the next
/// `pub fn reset`, if present) must reference `dispatch_matmul_q4_batch`.
#[test]
fn q4_forward_prefill_batch_uses_q4_batch_dispatch() {
    let cfg = minimal_q4_config();
    let generated = generate_model_rs(&cfg).unwrap();
    let signature = "pub fn forward_prefill_batch(&mut self, tokens: &[u32])";
    let from_signature = &generated[generated.find(signature).unwrap()..];
    // The first piece of the split is everything before the first reset marker
    // (or the whole remainder when the marker does not occur).
    let method_text = from_signature
        .split("\n pub fn reset")
        .next()
        .unwrap_or(from_signature);
    assert!(
        method_text.contains("dispatch_matmul_q4_batch"),
        "Q4 forward_prefill_batch should use dispatch_matmul_q4_batch"
    );
}
/// The generated `main.rs` must reference `forward_prefill_batch`.
#[test]
fn generated_main_rs_uses_batched_prefill() {
    let generated = generate_main_rs("test-model", &minimal_config()).unwrap();
    let mentions_batched_prefill = generated.contains("forward_prefill_batch");
    assert!(
        mentions_batched_prefill,
        "main.rs should use forward_prefill_batch for prompt tokens"
    );
}
/// With the default (f32) config, the text of `forward_prefill_batch` (up to
/// the next `pub fn reset`, if present) must reference `dispatch_matmul_batch`
/// and must not reference either quantized batch dispatch helper.
#[test]
fn f32_forward_prefill_batch_uses_f32_batch_dispatch() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let signature = "pub fn forward_prefill_batch(&mut self, tokens: &[u32])";
    let from_signature = &generated[generated.find(signature).unwrap()..];
    let method_text = match from_signature.find("\n pub fn reset") {
        Some(end) => &from_signature[..end],
        None => from_signature,
    };

    assert!(
        method_text.contains("dispatch_matmul_batch"),
        "f32 forward_prefill_batch should use dispatch_matmul_batch"
    );
    for quantized in ["dispatch_matmul_q8_batch", "dispatch_matmul_q4_batch"] {
        assert!(
            !method_text.contains(quantized),
            "f32 forward_prefill_batch should not use {quantized}"
        );
    }
}
/// The text of `forward()` must read `embed_buf.contents()` and use
/// `copy_nonoverlapping`, and must not call `dispatch_copy_offset` on
/// `self.embed_buf`.
///
/// The inspected text runs from the `forward` signature to the next
/// `forward_profile` marker, else the next `forward_prefill` marker, else the
/// end of the source.
#[test]
fn forward_uses_cpu_embedding_lookup() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let from_signature = &generated[generated.find(signature).unwrap()..];
    // Markers are tried in priority order; the first one that occurs ends the window.
    let end_markers = ["\n pub fn forward_profile", "\n pub fn forward_prefill"];
    let window_end = end_markers
        .iter()
        .find_map(|marker| from_signature.find(*marker))
        .unwrap_or(from_signature.len());
    let method_text = &from_signature[..window_end];

    let reads_embed_buf = method_text.contains("embed_buf.contents()");
    assert!(
        reads_embed_buf,
        "forward() should access embed_buf via CPU unified memory for embedding lookup"
    );
    let copies_on_cpu = method_text.contains("copy_nonoverlapping");
    assert!(
        copies_on_cpu,
        "forward() should use ptr::copy_nonoverlapping for CPU embedding copy"
    );
    let copies_on_gpu = method_text.contains("dispatch_copy_offset(&enc, &self.embed_buf");
    assert!(
        !copies_on_gpu,
        "forward() should not use GPU dispatch_copy_offset for embedding (use CPU memcpy instead)"
    );
}
/// The generated `model.rs` must declare `forward_profile` and contain the
/// `[profile]` prefix plus the `d_embed`, `d_layers` and `d_logits` identifiers.
#[test]
fn forward_profile_method_exists() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    // (substring that must be present, failure message), checked in order.
    let expectations = [
        (
            "pub fn forward_profile(&mut self, token_id: u32) -> Vec<f32>",
            "MetalModel should have forward_profile() method",
        ),
        (
            "[profile]",
            "forward_profile() should print timing with [profile] prefix",
        ),
        (
            "d_embed",
            "forward_profile() should measure embedding time",
        ),
        ("d_layers", "forward_profile() should measure layer time"),
        ("d_logits", "forward_profile() should measure logits time"),
    ];
    for (needle, reason) in expectations {
        assert!(generated.contains(needle), "{reason}");
    }
}
/// The generated `main.rs` must mention the `--profile` flag and reference `forward_profile`.
#[test]
fn generated_cli_has_profile_flag() {
    let cfg = minimal_config();
    let generated = generate_main_rs("test-model", &cfg).unwrap();
    let checks = [
        ("--profile", "CLI should support --profile flag"),
        (
            "forward_profile",
            "CLI should call forward_profile when --profile is set",
        ),
    ];
    for (needle, reason) in checks {
        assert!(generated.contains(needle), "{reason}");
    }
}
/// The generated `main.rs` must contain a `yield_now()` call.
#[test]
fn generated_cli_has_thermal_yield() {
    let generated = generate_main_rs("test-model", &minimal_config()).unwrap();
    let yields = generated.contains("yield_now()");
    assert!(
        yields,
        "CLI generation loop should include thread::yield_now() for thermal management"
    );
}
/// The opening of the generated `forward()` must not assert `self.pos > 0`,
/// and the generated source must reference `self.pos`.
///
/// Only the first 400 bytes following the `forward` signature are inspected
/// for the forbidden assertion.
#[test]
fn generated_forward_handles_single_token_prompt() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let forward_start = model_rs
        .find("pub fn forward(&mut self, token_id: u32) -> Vec<f32>")
        .expect("forward() must exist");
    // Clamp the inspection window to the end of the source and back off to a
    // char boundary, so a short or non-ASCII source yields a shorter window
    // instead of a slice panic. `forward_start` itself is always a boundary,
    // so the loop terminates.
    let mut window_end = model_rs.len().min(forward_start + 400);
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let forward_body = &model_rs[forward_start..window_end];
    assert!(
        !forward_body.contains("assert!(self.pos > 0"),
        "forward() must accept pos=0 (first token with no prefill)"
    );
    assert!(
        model_rs.contains("self.pos"),
        "forward() should use self.pos to track sequence position"
    );
}
/// The opening of the generated `reset()` must contain both `self.pos = 0`
/// and `self.prev_cmd = None`.
///
/// Only the first 200 bytes following the `reset` signature are inspected.
#[test]
fn generated_reset_clears_kv_cache_position() {
    let config = minimal_config();
    let model_rs = generate_model_rs(&config).unwrap();
    let reset_start = model_rs
        .find("pub fn reset(&mut self)")
        .expect("reset() must exist");
    // Clamp the inspection window to the end of the source and back off to a
    // char boundary: if fewer than 200 bytes follow the signature, an
    // unclamped slice would panic rather than report a failed assertion.
    // `reset_start` itself is always a boundary, so the loop terminates.
    let mut window_end = model_rs.len().min(reset_start + 200);
    while !model_rs.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let reset_body = &model_rs[reset_start..window_end];
    assert!(
        reset_body.contains("self.pos = 0"),
        "reset() must set self.pos = 0"
    );
    assert!(
        reset_body.contains("self.prev_cmd = None"),
        "reset() should clear prev_cmd for clean command buffer state"
    );
}
/// The first 500 bytes of the generated `format_chat_messages` (fewer if the
/// source ends sooner) must iterate over `messages` and append the assistant
/// header. Everything from `fn serve(` onward must contain `model.reset()`.
#[test]
fn generated_serve_handles_empty_messages_gracefully() {
    let generated = generate_main_rs("test-model", &minimal_config()).unwrap();

    let formatter_at = generated
        .find("fn format_chat_messages")
        .expect("format_chat_messages must exist");
    // Window of at most 500 bytes, clamped to the end of the source.
    let formatter_end = formatter_at + 500.min(generated.len() - formatter_at);
    let formatter_text = &generated[formatter_at..formatter_end];
    let formatter_checks = [
        (
            "for msg in messages",
            "format_chat_messages should iterate over the messages slice",
        ),
        (
            "<|im_start|>assistant",
            "format_chat_messages should always append assistant prompt header",
        ),
    ];
    for (needle, reason) in formatter_checks {
        assert!(formatter_text.contains(needle), "{reason}");
    }

    let serve_at = generated
        .find("fn serve(")
        .expect("serve function must exist");
    let serve_text = &generated[serve_at..];
    assert!(
        serve_text.contains("model.reset()"),
        "serve function should reset model between requests"
    );
}
/// The text of `forward()` must contain a `self.pos += 1` increment (with or
/// without a space before the `1`).
///
/// The inspected text runs from the `forward` signature to the first of these
/// markers that occurs, tried in this order: `forward_profile`,
/// `forward_prefill`, a private `dispatch_` helper; otherwise to the end of
/// the source.
#[test]
fn generated_model_forward_increments_pos() {
    let generated = generate_model_rs(&minimal_config()).unwrap();
    let signature = "pub fn forward(&mut self, token_id: u32) -> Vec<f32>";
    let from_signature = &generated[generated.find(signature).unwrap()..];
    let end_markers = [
        "\n pub fn forward_profile",
        "\n pub fn forward_prefill",
        "\n fn dispatch_",
    ];
    let window_end = end_markers
        .iter()
        .find_map(|marker| from_signature.find(*marker))
        .unwrap_or(from_signature.len());
    let method_text = &from_signature[..window_end];

    let increments_pos = ["self.pos += 1", "self.pos +=1"]
        .iter()
        .any(|spelling| method_text.contains(*spelling));
    assert!(
        increments_pos,
        "forward() must increment self.pos after processing a token"
    );
}
}