#![allow(dead_code, clippy::many_single_char_names)]
use std::collections::HashMap;
/// Per-layer activations saved during a training forward pass for reuse by a
/// backward pass.
pub struct LayerActivations {
pub attn_norm_out: wgpu::Buffer,
pub attn_output: wgpu::Buffer,
pub ffn_norm_out: wgpu::Buffer,
pub silu_gate_output: wgpu::Buffer,
pub rstd_attn: wgpu::Buffer,
pub rstd_ffn: wgpu::Buffer,
pub softmax_logsumexp: wgpu::Buffer,
}
/// Borrowed LoRA adapter weights and launch resources for the Q/K/V
/// projections. The update applied is `y += scale * (x × A) × B`, with
/// A: in_dim×rank and B: rank×out_dim as consumed by `encode_lora_addmm`.
pub struct QkvLoRA<'a> {
pub q_a: &'a wgpu::Buffer,
pub q_b: &'a wgpu::Buffer,
pub k_a: &'a wgpu::Buffer,
pub k_b: &'a wgpu::Buffer,
pub v_a: &'a wgpu::Buffer,
pub v_b: &'a wgpu::Buffer,
pub rank: u32,
pub scale: f32,
pub in_dim: u32,
pub q_dim: u32,
pub kv_dim: u32,
pub lora_pipeline: &'a wgpu::ComputePipeline,
pub lora_bgl: &'a wgpu::BindGroupLayout,
}
/// WGSL compute forward pass. Projections, norms, and activations run on the
/// GPU; biases (and, in the single-token decode path, RoPE and attention) run
/// on the CPU.
pub struct WgslForwardPass {
device: wgpu::Device,
queue: wgpu::Queue,
matmul_pipeline: wgpu::ComputePipeline,
tiled_matmul_pipeline: wgpu::ComputePipeline,
gemv_pipeline: wgpu::ComputePipeline,
q4k_gemv_pipeline: wgpu::ComputePipeline,
attention_pipeline: wgpu::ComputePipeline,
attention_bgl: wgpu::BindGroupLayout,
rmsnorm_pipeline: wgpu::ComputePipeline,
silu_mul_pipeline: wgpu::ComputePipeline,
rope_pipeline: wgpu::ComputePipeline,
batch_rope_pipeline: wgpu::ComputePipeline,
batch_rope_bgl: wgpu::BindGroupLayout,
residual_pipeline: wgpu::ComputePipeline,
matmul_bgl: wgpu::BindGroupLayout,
elementwise_bgl: wgpu::BindGroupLayout,
weight_buffers: HashMap<String, wgpu::Buffer>,
q4k_weights: HashMap<String, wgpu::Buffer>,
cpu_biases: HashMap<String, Vec<f32>>,
kv_cache_k: Vec<wgpu::Buffer>,
kv_cache_v: Vec<wgpu::Buffer>,
hidden_buf: wgpu::Buffer,
q_buf: wgpu::Buffer,
k_buf: wgpu::Buffer,
v_buf: wgpu::Buffer,
attn_out_buf: wgpu::Buffer,
ffn_gate_buf: wgpu::Buffer,
ffn_up_buf: wgpu::Buffer,
ffn_silu_buf: wgpu::Buffer,
ffn_out_buf: wgpu::Buffer,
norm_buf: wgpu::Buffer,
staging_buf: wgpu::Buffer,
hidden_dim: u32,
num_heads: u32,
num_kv_heads: u32,
head_dim: u32,
intermediate_dim: u32,
}
const RMSNORM_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read> weight: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)
var<workgroup> shared_sum: array<f32, 256>;
@compute @workgroup_size(256)
fn main(@builtin(local_invocation_id) lid: vec3<u32>,
@builtin(workgroup_id) wg_id: vec3<u32>) {
let dim = params.x;
let row = wg_id.y;
let base = row * dim;
let tid = lid.x;
// Compute sum of squares (reduction) for this row
var local_sum: f32 = 0.0;
var i = tid;
while (i < dim) {
let val = input[base + i];
local_sum += val * val;
i += 256u;
}
shared_sum[tid] = local_sum;
workgroupBarrier();
// Tree reduction
var stride = 128u;
while (stride > 0u) {
if (tid < stride) {
shared_sum[tid] += shared_sum[tid + stride];
}
workgroupBarrier();
stride >>= 1u;
}
let rms = sqrt(shared_sum[0] / f32(dim) + 1e-6); // eps fixed at 1e-6; CPU paths must match
// Normalize and scale
i = tid;
while (i < dim) {
output[base + i] = (input[base + i] / rms) * weight[i];
i += 256u;
}
}
"#;
const SILU_MUL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> gate: array<f32>;
@group(0) @binding(1) var<storage, read> up: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= params.x) { return; }
let g = gate[idx];
let silu_g = g / (1.0 + exp(-g));
output[idx] = silu_g * up[idx];
}
"#;
const RESIDUAL_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;
@group(0) @binding(3) var<uniform> params: vec4<u32>; // (dim, 0, 0, 0)
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= params.x) { return; }
output[idx] = a[idx] + b[idx];
}
"#;
const BATCH_ROPE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> qk: array<f32>;
struct RopeParams {
seq_len: u32,
num_heads: u32,
head_dim: u32,
_pad: u32,
}
@group(0) @binding(1) var<uniform> params: RopeParams;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
let total = params.seq_len * params.num_heads * params.head_dim;
if (idx >= total) { return; }
let head_dim = params.head_dim;
let half_hd = head_dim / 2u;
// Decompose idx into (position, head, pos_in_head)
let elements_per_pos = params.num_heads * head_dim;
let position = idx / elements_per_pos;
let within_pos = idx % elements_per_pos;
let head_idx = within_pos / head_dim;
let pos_in_head = within_pos % head_dim;
// Only process the first half of each head (pairs with second half)
if (pos_in_head >= half_hd) { return; }
let theta = pow(1000000.0, -f32(pos_in_head * 2u) / f32(head_dim));
let angle = f32(position) * theta;
let cos_a = cos(angle);
let sin_a = sin(angle);
let base = position * elements_per_pos + head_idx * head_dim;
let i0 = base + pos_in_head;
let i1 = i0 + half_hd;
let x0 = qk[i0];
let x1 = qk[i1];
qk[i0] = x0 * cos_a - x1 * sin_a;
qk[i1] = x0 * sin_a + x1 * cos_a;
}
"#;
const ROPE_SHADER: &str = r#"
@group(0) @binding(0) var<storage, read_write> qk: array<f32>;
@group(0) @binding(1) var<uniform> params: vec4<u32>; // (dim, position, num_heads, head_dim)
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
let dim = params.x;
let position = params.y;
let head_dim = params.w;
if (idx >= dim) { return; }
let half_hd = head_dim / 2u;
let head_idx = idx / head_dim;
let pos_in_head = idx % head_dim;
if (pos_in_head >= half_hd) { return; }
let theta = pow(1000000.0, -f32(pos_in_head * 2u) / f32(head_dim));
let angle = f32(position) * theta;
let cos_a = cos(angle);
let sin_a = sin(angle);
let i0 = head_idx * head_dim + pos_in_head;
let i1 = i0 + half_hd;
let x0 = qk[i0];
let x1 = qk[i1];
qk[i0] = x0 * cos_a - x1 * sin_a;
qk[i1] = x0 * sin_a + x1 * cos_a;
}
"#;
impl WgslForwardPass {
pub fn rmsnorm_shader() -> &'static str {
RMSNORM_SHADER
}
pub fn silu_mul_shader() -> &'static str {
SILU_MUL_SHADER
}
pub fn residual_shader() -> &'static str {
RESIDUAL_SHADER
}
pub fn rope_shader() -> &'static str {
ROPE_SHADER
}
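/// Builds every pipeline up front and sizes all scratch buffers for a fixed
/// 2048-token window. A minimal construction sketch (the dimensions below are
/// illustrative, not taken from any particular checkpoint):
///
/// ```ignore
/// // Given an already-created `device: wgpu::Device` and `queue: wgpu::Queue`:
/// let mut fp = WgslForwardPass::new(device, queue, 896, 14, 2, 64, 4864);
/// fp.init_kv_cache(24); // one K/V buffer pair per transformer layer
/// ```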
pub fn new(
device: wgpu::Device,
queue: wgpu::Queue,
hidden_dim: usize,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
intermediate_dim: usize,
) -> Self {
let q_dim = num_heads * head_dim;
let kv_dim = num_kv_heads * head_dim;
let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("matmul"),
source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::MATMUL_SHADER.into()),
});
let rmsnorm_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("rmsnorm"),
source: wgpu::ShaderSource::Wgsl(RMSNORM_SHADER.into()),
});
let silu_mul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("silu_mul"),
source: wgpu::ShaderSource::Wgsl(SILU_MUL_SHADER.into()),
});
let rope_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("rope"),
source: wgpu::ShaderSource::Wgsl(ROPE_SHADER.into()),
});
let residual_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("residual"),
source: wgpu::ShaderSource::Wgsl(RESIDUAL_SHADER.into()),
});
let matmul_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("matmul_bgl"),
entries: &[
bgl_storage(0, true),
bgl_storage(1, true),
bgl_storage(2, false),
bgl_uniform(3),
],
});
let elementwise_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("ew_bgl"),
entries: &[
bgl_storage(0, true),
bgl_storage(1, true),
bgl_storage(2, false),
bgl_uniform(3),
],
});
let make_pipeline =
|shader: &wgpu::ShaderModule, bgl: &wgpu::BindGroupLayout, label: &str| {
let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(label),
bind_group_layouts: &[bgl],
push_constant_ranges: &[],
});
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(label),
layout: Some(&pl),
module: shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
})
};
let matmul_pipeline = make_pipeline(&matmul_shader, &matmul_bgl, "matmul_pipe");
let tiled_matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("tiled_matmul"),
source: wgpu::ShaderSource::Wgsl(
crate::backends::gpu::shaders::TILED_GEMM_SHADER.into(),
),
});
let tiled_matmul_pipeline =
make_pipeline(&tiled_matmul_shader, &matmul_bgl, "tiled_matmul_pipe");
let attention_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("causal_attention"),
source: wgpu::ShaderSource::Wgsl(
crate::backends::gpu::shaders::CAUSAL_ATTENTION_SHADER.into(),
),
});
let attention_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("attn_bgl"),
entries: &[
bgl_storage(0, true),
bgl_storage(1, true),
bgl_storage(2, true),
bgl_storage(3, false),
bgl_uniform(4),
],
});
let attention_pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("attn_pl"),
bind_group_layouts: &[&attention_bgl],
push_constant_ranges: &[],
});
let attention_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("attn_pipe"),
layout: Some(&attention_pl),
module: &attention_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("gemv"),
source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::GEMV_SHADER.into()),
});
let gemv_pipeline = make_pipeline(&gemv_shader, &matmul_bgl, "gemv_pipe");
let q4k_gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("q4k_gemv"),
source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::Q4K_GEMV_SHADER.into()),
});
let q4k_gemv_pipeline = make_pipeline(&q4k_gemv_shader, &matmul_bgl, "q4k_gemv_pipe");
let rmsnorm_pipeline = make_pipeline(&rmsnorm_shader, &elementwise_bgl, "rmsnorm_pipe");
let silu_mul_pipeline = make_pipeline(&silu_mul_shader, &elementwise_bgl, "silu_pipe");
let residual_pipeline = make_pipeline(&residual_shader_mod, &elementwise_bgl, "res_pipe");
let rope_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("rope_bgl"),
entries: &[bgl_storage(0, false), bgl_uniform(1)],
});
let rope_pipeline = {
let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("rope_pl"),
bind_group_layouts: &[&rope_bgl],
push_constant_ranges: &[],
});
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("rope_pipe"),
layout: Some(&pl),
module: &rope_shader_mod,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
})
};
let batch_rope_shader_mod = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("batch_rope"),
source: wgpu::ShaderSource::Wgsl(BATCH_ROPE_SHADER.into()),
});
let batch_rope_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("batch_rope_bgl"),
entries: &[bgl_storage(0, false), bgl_uniform(1)],
});
let batch_rope_pipeline = {
let pl = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("batch_rope_pl"),
bind_group_layouts: &[&batch_rope_bgl],
push_constant_ranges: &[],
});
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("batch_rope_pipe"),
layout: Some(&pl),
module: &batch_rope_shader_mod,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
})
};
let buf = |size: usize, label: &str| -> wgpu::Buffer {
device.create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: (size * 4) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
})
};
let max_seq = 2048; // fixed window; must match `init_kv_cache`
let hidden_buf = buf(max_seq * hidden_dim, "hidden");
let q_buf = buf(max_seq * q_dim, "q");
let k_buf = buf(max_seq * kv_dim, "k");
let v_buf = buf(max_seq * kv_dim, "v");
let attn_out_buf = buf(max_seq * hidden_dim, "attn_out");
let ffn_gate_buf = buf(max_seq * intermediate_dim, "ffn_gate");
let ffn_up_buf = buf(max_seq * intermediate_dim, "ffn_up");
let ffn_silu_buf = buf(max_seq * intermediate_dim, "ffn_silu");
let ffn_out_buf = buf(max_seq * hidden_dim, "ffn_out");
let norm_buf = buf(max_seq * hidden_dim, "norm");
let max_out = max_seq * hidden_dim.max(intermediate_dim);
let staging_buf = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("staging"),
size: (max_out * 4) as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
Self {
device,
queue,
matmul_pipeline,
tiled_matmul_pipeline,
attention_pipeline,
attention_bgl,
gemv_pipeline,
q4k_gemv_pipeline,
rmsnorm_pipeline,
silu_mul_pipeline,
rope_pipeline,
batch_rope_pipeline,
batch_rope_bgl,
residual_pipeline,
matmul_bgl,
elementwise_bgl,
weight_buffers: HashMap::new(),
q4k_weights: HashMap::new(),
kv_cache_k: Vec::new(),
kv_cache_v: Vec::new(),
cpu_biases: HashMap::new(),
hidden_buf,
q_buf,
k_buf,
v_buf,
attn_out_buf,
ffn_gate_buf,
ffn_up_buf,
ffn_silu_buf,
ffn_out_buf,
norm_buf,
staging_buf,
hidden_dim: hidden_dim as u32,
num_heads: num_heads as u32,
num_kv_heads: num_kv_heads as u32,
head_dim: head_dim as u32,
intermediate_dim: intermediate_dim as u32,
}
}
pub fn upload_weight(&mut self, name: &str, data: &[f32]) {
// Biases stay host-side: they are applied on the CPU in `forward_layer`,
// or broadcast via `encode_broadcast_bias` in the training path.
if name.contains("bias") {
self.cpu_biases.insert(name.to_string(), data.to_vec());
return;
}
let size_bytes = (data.len() * 4) as u64;
let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
if size_bytes > max_binding {
eprintln!(
"[wgpu] Skipping weight '{}' ({:.1} MB > {:.1} MB limit) — CPU fallback",
name,
size_bytes as f64 / 1e6,
max_binding as f64 / 1e6
);
return;
}
use wgpu::util::DeviceExt;
let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some(name),
contents: bytemuck::cast_slice(data),
usage: wgpu::BufferUsages::STORAGE,
});
self.weight_buffers.insert(name.to_string(), buffer);
}
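// Upload sketch (keys follow the `layer.{i}.{name}` scheme used by
// `forward_layer`; projection data is assumed row-major [out_dim, in_dim]):
//
// fp.upload_weight("layer.0.attn_norm", &norm_w);
// fp.upload_weight("layer.0.q_proj", &q_w);
// fp.upload_weight("layer.0.q_bias", &q_b); // kept host-side in cpu_biases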
pub fn upload_q4k_weight(&mut self, name: &str, data: &[u8]) {
use wgpu::util::DeviceExt;
let buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some(name),
contents: data,
usage: wgpu::BufferUsages::STORAGE,
});
self.q4k_weights.insert(name.to_string(), buffer);
}
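// The bytes are raw GGUF-style Q4_K super-blocks (144 bytes per 256 weights,
// assuming the standard Q4_K packing); dequantization happens entirely inside
// Q4K_GEMV_SHADER, matching `num_superblocks = ceil(k / 256)` in the encoder.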
pub fn init_kv_cache(&mut self, num_layers: usize) {
let kv_dim = (self.num_kv_heads * self.head_dim) as u64;
let max_seq = 2048u64; // must match the scratch-buffer window in `new`
for _ in 0..num_layers {
let k = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("kv_cache_k"),
size: max_seq * kv_dim * 4,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let v = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("kv_cache_v"),
size: max_seq * kv_dim * 4,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_DST
| wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
self.kv_cache_k.push(k);
self.kv_cache_v.push(v);
}
}
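// Each layer's K (and V) cache holds max_seq × kv_dim f32s, i.e.
// 2048 × num_kv_heads × head_dim × 4 bytes per buffer.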
pub fn weight_count(&self) -> usize {
self.weight_buffers.len()
}
pub fn weight_buffer(&self, name: &str) -> Option<&wgpu::Buffer> {
self.weight_buffers.get(name)
}
pub fn device_ref(&self) -> &wgpu::Device {
&self.device
}
pub fn queue_ref(&self) -> &wgpu::Queue {
&self.queue
}
pub fn hidden_buffer(&self) -> &wgpu::Buffer {
&self.hidden_buf
}
pub fn q_buffer(&self) -> &wgpu::Buffer {
&self.q_buf
}
pub fn k_buffer(&self) -> &wgpu::Buffer {
&self.k_buf
}
pub fn v_buffer(&self) -> &wgpu::Buffer {
&self.v_buf
}
pub fn gpu_residual_add(
&self,
a: &wgpu::Buffer,
b: &wgpu::Buffer,
output: &wgpu::Buffer,
len: u32,
) {
let mut encoder = self.device.create_command_encoder(&Default::default());
self.encode_residual(&mut encoder, a, b, output, len);
self.queue.submit(Some(encoder.finish()));
}
pub fn gpu_rmsnorm(&self, weight: &wgpu::Buffer, output: &wgpu::Buffer, _seq_len: u32) {
// Row count is derived from the input buffer's capacity inside
// `encode_rmsnorm`, so `_seq_len` is currently unused.
let mut encoder = self.device.create_command_encoder(&Default::default());
self.encode_rmsnorm(&mut encoder, &self.hidden_buf, weight, output, self.hidden_dim);
self.queue.submit(Some(encoder.finish()));
}
pub fn download_hidden(&self, len: usize) -> Vec<f32> {
let size = (len * 4) as u64;
let staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("hidden_download"),
size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let mut encoder = self.device.create_command_encoder(&Default::default());
encoder.copy_buffer_to_buffer(&self.hidden_buf, 0, &staging, 0, size);
self.queue.submit(Some(encoder.finish()));
let slice = staging.slice(..size);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
tx.send(r).ok();
});
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
rx.recv()
.expect("GPU map_async callback channel disconnected")
.expect("GPU buffer mapping failed");
let data = slice.get_mapped_range();
let result: Vec<f32> = bytemuck::cast_slice(&data)[..len].to_vec();
drop(data);
staging.unmap();
result
}
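// Note: `download_hidden` performs a blocking readback (submit + full device
// poll); it suits debugging and host-side consumers, not per-op hot loops.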
pub fn total_vram_bytes(&self) -> usize {
// Count what is actually resident: f32 weights, Q4_K weights, KV caches,
// and the persistent scratch buffers (all sized for the 2048-token window).
let weight_bytes: usize = self
.weight_buffers
.values()
.chain(self.q4k_weights.values())
.chain(self.kv_cache_k.iter())
.chain(self.kv_cache_v.iter())
.map(|b| b.size() as usize)
.sum();
let scratch = [
&self.hidden_buf, &self.q_buf, &self.k_buf, &self.v_buf, &self.attn_out_buf,
&self.ffn_gate_buf, &self.ffn_up_buf, &self.ffn_silu_buf, &self.ffn_out_buf,
&self.norm_buf, &self.staging_buf,
];
weight_bytes + scratch.iter().map(|b| b.size() as usize).sum::<usize>()
}
#[provable_contracts_macros::contract("wgpu-forward-pass-v1", equation = "rmsnorm_correctness")]
pub fn forward_model(
&self,
token_id: u32,
position: usize,
num_layers: usize,
token_embedding: &[f32],
output_norm_weight: &[f32],
lm_head_weight: &[f32],
vocab_size: usize,
eps: f32,
kv_caches: &mut Vec<(Vec<f32>, Vec<f32>)>,
) -> Result<Vec<f32>, String> {
let hd = self.hidden_dim as usize;
let embed_start = token_id as usize * hd;
if embed_start + hd > token_embedding.len() {
return Err(format!(
"Token {} out of range (embedding size {})",
token_id,
token_embedding.len() / hd
));
}
let mut hidden: Vec<f32> = token_embedding[embed_start..embed_start + hd].to_vec();
while kv_caches.len() < num_layers {
kv_caches.push((Vec::new(), Vec::new()));
}
for layer_idx in 0..num_layers {
let prefix = format!("layer.{layer_idx}");
let (ref mut k_cache, ref mut v_cache) = kv_caches[layer_idx];
self.forward_layer(&mut hidden, &prefix, position, k_cache, v_cache)?;
}
// Final RMSNorm on the CPU with the caller-provided eps.
let rms = (hidden.iter().map(|x| x * x).sum::<f32>() / hd as f32 + eps).sqrt();
for (h, w) in hidden.iter_mut().zip(output_norm_weight) {
*h = (*h / rms) * w;
}
// The lm_head also runs on the CPU: logits[v] = dot(lm_head row v, hidden).
let mut logits = vec![0.0f32; vocab_size];
for v in 0..vocab_size {
let mut sum = 0.0f32;
let row_start = v * hd;
for j in 0..hd {
sum += lm_head_weight[row_start + j] * hidden[j];
}
logits[v] = sum;
}
Ok(logits)
}
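// Decode-loop sketch for `forward_model` (greedy decoding; `argmax`, `embed`,
// and the other slices are caller-supplied assumptions):
//
// let mut kv: Vec<(Vec<f32>, Vec<f32>)> = Vec::new();
// let mut token = bos_token;
// for pos in 0..max_new_tokens {
//     let logits = fp.forward_model(token, pos, num_layers, &embed, &norm_w,
//         &lm_head, vocab_size, 1e-6, &mut kv)?;
//     token = argmax(&logits);
// }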
/// Single-token decode step for one transformer layer. Q/K/V projections run
/// on the GPU; bias add, RoPE, and attention over the KV cache run on the CPU
/// after a readback, then the output projection and FFN run on the GPU again.
pub fn forward_layer(
&self,
hidden: &mut [f32],
layer_prefix: &str,
position: usize,
kv_cache_k: &mut Vec<f32>,
kv_cache_v: &mut Vec<f32>,
) -> Result<(), String> {
let hd = self.hidden_dim;
self.queue.write_buffer(&self.hidden_buf, 0, bytemuck::cast_slice(hidden));
let mut encoder = self.device.create_command_encoder(&Default::default());
let norm_w = self
.weight_buffers
.get(&format!("{layer_prefix}.attn_norm"))
.ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
self.encode_rmsnorm(&mut encoder, &self.hidden_buf, norm_w, &self.norm_buf, hd);
let q_dim = self.num_heads * self.head_dim;
let kv_dim = self.num_kv_heads * self.head_dim;
self.encode_matmul(
&mut encoder,
&self.norm_buf,
layer_prefix,
"q_proj",
&self.q_buf,
1,
hd,
q_dim,
);
self.encode_matmul(
&mut encoder,
&self.norm_buf,
layer_prefix,
"k_proj",
&self.k_buf,
1,
hd,
kv_dim,
);
self.encode_matmul(
&mut encoder,
&self.norm_buf,
layer_prefix,
"v_proj",
&self.v_buf,
1,
hd,
kv_dim,
);
let q_bytes = (q_dim * 4) as u64;
let kv_bytes = (kv_dim * 4) as u64;
let q_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("q_stg"),
size: q_bytes,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let k_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("k_stg"),
size: kv_bytes,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let v_staging = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("v_stg"),
size: kv_bytes,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
encoder.copy_buffer_to_buffer(&self.q_buf, 0, &q_staging, 0, q_bytes);
encoder.copy_buffer_to_buffer(&self.k_buf, 0, &k_staging, 0, kv_bytes);
encoder.copy_buffer_to_buffer(&self.v_buf, 0, &v_staging, 0, kv_bytes);
self.queue.submit(Some(encoder.finish()));
// Read Q/K/V back to the CPU for this decode step; the three readbacks
// share one mapping helper.
let read_f32 = |staging: &wgpu::Buffer, bytes: u64, n: usize, tag: &str| -> Result<Vec<f32>, String> {
let slice = staging.slice(..bytes);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
tx.send(r).ok();
});
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
rx.recv().map_err(|e| format!("{tag} recv: {e}"))?.map_err(|e| format!("{tag} map: {e:?}"))?;
let data = slice.get_mapped_range();
let out = bytemuck::cast_slice::<u8, f32>(&data)[..n].to_vec();
drop(data);
staging.unmap();
Ok(out)
};
let mut q_data = read_f32(&q_staging, q_bytes, q_dim as usize, "q")?;
let mut k_data = read_f32(&k_staging, kv_bytes, kv_dim as usize, "k")?;
let mut v_data = read_f32(&v_staging, kv_bytes, kv_dim as usize, "v")?;
if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
for (q, b) in q_data.iter_mut().zip(q_bias.iter()) {
*q += *b;
}
}
if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
for (k, b) in k_data.iter_mut().zip(k_bias.iter()) {
*k += *b;
}
}
if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
for (v, b) in v_data.iter_mut().zip(v_bias.iter()) {
*v += *b;
}
}
// Apply RoPE on the CPU; base frequency must match the 1e6 used in the shaders.
let head_dim = self.head_dim as usize;
let rope_theta = 1_000_000.0f64;
for h in 0..(self.num_heads as usize) {
let offset = h * head_dim;
let half = head_dim / 2;
for i in 0..half {
let theta = rope_theta.powf(-((2 * i) as f64) / head_dim as f64);
let angle = position as f64 * theta;
let cos_a = angle.cos() as f32;
let sin_a = angle.sin() as f32;
let x0 = q_data[offset + i];
let x1 = q_data[offset + i + half];
q_data[offset + i] = x0 * cos_a - x1 * sin_a;
q_data[offset + i + half] = x0 * sin_a + x1 * cos_a;
}
}
for h in 0..(self.num_kv_heads as usize) {
let offset = h * head_dim;
let half = head_dim / 2;
for i in 0..half {
let theta = rope_theta.powf(-((2 * i) as f64) / head_dim as f64);
let angle = position as f64 * theta;
let cos_a = angle.cos() as f32;
let sin_a = angle.sin() as f32;
let x0 = k_data[offset + i];
let x1 = k_data[offset + i + half];
k_data[offset + i] = x0 * cos_a - x1 * sin_a;
k_data[offset + i + half] = x0 * sin_a + x1 * cos_a;
}
}
let num_heads = self.num_heads as usize;
let num_kv_heads = self.num_kv_heads as usize;
let kv_dim_usize = kv_dim as usize;
kv_cache_k.extend_from_slice(&k_data);
kv_cache_v.extend_from_slice(&v_data);
let seq_len = kv_cache_k.len() / kv_dim_usize;
// GQA: kv_group query heads share each KV head.
let kv_group = num_heads / num_kv_heads;
let scale = 1.0 / (head_dim as f32).sqrt();
let mut attn_out = vec![0.0f32; q_dim as usize];
for h in 0..num_heads {
let kv_h = h / kv_group;
let q_offset = h * head_dim;
let mut scores = vec![0.0f32; seq_len];
for s in 0..seq_len {
let k_offset = s * kv_dim_usize + kv_h * head_dim;
let mut dot = 0.0f32;
for d in 0..head_dim {
dot += q_data[q_offset + d] * kv_cache_k[k_offset + d];
}
scores[s] = dot * scale;
}
let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for s in scores.iter_mut() {
*s = (*s - max_score).exp();
sum += *s;
}
if sum > 0.0 {
for s in scores.iter_mut() {
*s /= sum;
}
}
let out_offset = h * head_dim;
for d in 0..head_dim {
let mut val = 0.0f32;
for s in 0..seq_len {
let v_offset = s * kv_dim_usize + kv_h * head_dim;
val += scores[s] * kv_cache_v[v_offset + d];
}
attn_out[out_offset + d] = val;
}
}
self.queue.write_buffer(&self.q_buf, 0, bytemuck::cast_slice(&attn_out));
let mut encoder = self.device.create_command_encoder(&Default::default());
self.encode_matmul(
&mut encoder,
&self.q_buf,
layer_prefix,
"o_proj",
&self.attn_out_buf,
1,
q_dim,
hd,
);
self.encode_residual(
&mut encoder,
&self.hidden_buf,
&self.attn_out_buf,
&self.ffn_out_buf,
hd,
);
let ffn_norm_w = self
.weight_buffers
.get(&format!("{layer_prefix}.ffn_norm"))
.ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
self.encode_rmsnorm(&mut encoder, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd);
let inter = self.intermediate_dim;
self.encode_matmul(
&mut encoder,
&self.norm_buf,
layer_prefix,
"gate_proj",
&self.ffn_gate_buf,
1,
hd,
inter,
);
self.encode_matmul(
&mut encoder,
&self.norm_buf,
layer_prefix,
"up_proj",
&self.ffn_up_buf,
1,
hd,
inter,
);
self.encode_silu_mul(
&mut encoder,
&self.ffn_gate_buf,
&self.ffn_up_buf,
&self.ffn_silu_buf,
inter,
);
self.encode_matmul(
&mut encoder,
&self.ffn_silu_buf,
layer_prefix,
"down_proj",
&self.norm_buf,
1,
inter,
hd,
);
self.encode_residual(&mut encoder, &self.ffn_out_buf, &self.norm_buf, &self.hidden_buf, hd);
encoder.copy_buffer_to_buffer(&self.hidden_buf, 0, &self.staging_buf, 0, (hd * 4) as u64);
self.queue.submit(Some(encoder.finish()));
let slice = self.staging_buf.slice(..(hd as u64 * 4));
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
tx.send(r).ok();
});
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
rx.recv().map_err(|e| format!("recv: {e}"))?.map_err(|e| format!("map: {e:?}"))?;
{
let data = slice.get_mapped_range();
hidden.copy_from_slice(
&bytemuck::cast_slice::<u8, f32>(&data)[..self.hidden_dim as usize],
);
}
self.staging_buf.unmap();
Ok(())
}
/// Encodes one training-mode layer forward pass into `encoder`, saving the
/// intermediates a backward pass needs into `saved` and optionally adding
/// LoRA deltas to the Q/K/V projections.
pub fn encode_forward_layer_training(
&self,
encoder: &mut wgpu::CommandEncoder,
seq_len: u32,
layer_prefix: &str,
saved: &LayerActivations,
lora: Option<&QkvLoRA<'_>>,
) -> Result<(), String> {
let hd = self.hidden_dim;
let q_dim = self.num_heads * self.head_dim;
let kv_dim = self.num_kv_heads * self.head_dim;
let inter = self.intermediate_dim;
let s = seq_len as usize;
let norm_w = self
.weight_buffers
.get(&format!("{layer_prefix}.attn_norm"))
.ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
self.encode_rmsnorm(encoder, &self.hidden_buf, norm_w, &self.norm_buf, hd);
encoder.copy_buffer_to_buffer(
&self.norm_buf,
0,
&saved.attn_norm_out,
0,
(s * hd as usize * 4) as u64,
);
self.encode_matmul(
encoder,
&self.norm_buf,
layer_prefix,
"q_proj",
&self.q_buf,
seq_len,
hd,
q_dim,
);
self.encode_matmul(
encoder,
&self.norm_buf,
layer_prefix,
"k_proj",
&self.k_buf,
seq_len,
hd,
kv_dim,
);
self.encode_matmul(
encoder,
&self.norm_buf,
layer_prefix,
"v_proj",
&self.v_buf,
seq_len,
hd,
kv_dim,
);
if let Some(lora) = lora {
self.encode_lora_addmm(
encoder,
&saved.attn_norm_out,
lora.q_a,
lora.q_b,
&self.q_buf,
seq_len,
lora.in_dim,
lora.rank,
lora.q_dim,
lora.scale,
lora.lora_pipeline,
lora.lora_bgl,
);
self.encode_lora_addmm(
encoder,
&saved.attn_norm_out,
lora.k_a,
lora.k_b,
&self.k_buf,
seq_len,
lora.in_dim,
lora.rank,
lora.kv_dim,
lora.scale,
lora.lora_pipeline,
lora.lora_bgl,
);
self.encode_lora_addmm(
encoder,
&saved.attn_norm_out,
lora.v_a,
lora.v_b,
&self.v_buf,
seq_len,
lora.in_dim,
lora.rank,
lora.kv_dim,
lora.scale,
lora.lora_pipeline,
lora.lora_bgl,
);
}
if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
self.encode_broadcast_bias(encoder, &self.q_buf, q_bias, seq_len);
}
if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
self.encode_broadcast_bias(encoder, &self.k_buf, k_bias, seq_len);
}
if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
self.encode_broadcast_bias(encoder, &self.v_buf, v_bias, seq_len);
}
self.encode_batch_rope(encoder, &self.q_buf, seq_len, self.num_heads, self.head_dim);
self.encode_batch_rope(encoder, &self.k_buf, seq_len, self.num_kv_heads, self.head_dim);
self.encode_attention(encoder, seq_len);
encoder.copy_buffer_to_buffer(
&self.attn_out_buf,
0,
&saved.attn_output,
0,
(s * q_dim as usize * 4) as u64,
);
self.encode_matmul(
encoder,
&self.attn_out_buf,
layer_prefix,
"o_proj",
&self.q_buf,
seq_len,
q_dim,
hd,
);
self.encode_residual(
encoder,
&self.hidden_buf,
&self.q_buf,
&self.ffn_out_buf,
hd * seq_len,
);
let ffn_norm_w = self
.weight_buffers
.get(&format!("{layer_prefix}.ffn_norm"))
.ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
self.encode_rmsnorm(encoder, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd);
encoder.copy_buffer_to_buffer(
&self.norm_buf,
0,
&saved.ffn_norm_out,
0,
(s * hd as usize * 4) as u64,
);
self.encode_matmul(
encoder,
&self.norm_buf,
layer_prefix,
"gate_proj",
&self.ffn_gate_buf,
seq_len,
hd,
inter,
);
self.encode_matmul(
encoder,
&self.norm_buf,
layer_prefix,
"up_proj",
&self.ffn_up_buf,
seq_len,
hd,
inter,
);
self.encode_silu_mul(
encoder,
&self.ffn_gate_buf,
&self.ffn_up_buf,
&self.ffn_silu_buf,
inter * seq_len,
);
encoder.copy_buffer_to_buffer(
&self.ffn_silu_buf,
0,
&saved.silu_gate_output,
0,
(s * inter as usize * 4) as u64,
);
self.encode_matmul(
encoder,
&self.ffn_silu_buf,
layer_prefix,
"down_proj",
&self.norm_buf,
seq_len,
inter,
hd,
);
self.encode_residual(
encoder,
&self.ffn_out_buf,
&self.norm_buf,
&self.hidden_buf,
hd * seq_len,
);
Ok(())
}
/// Same layer forward as `encode_forward_layer_training`, but submits and
/// waits on every op individually and prints a per-op timing trace to stderr.
pub fn forward_layer_traced(
&self,
seq_len: u32,
layer_prefix: &str,
saved: &LayerActivations,
lora: Option<&QkvLoRA<'_>>,
) -> Result<(), String> {
let hd = self.hidden_dim;
let q_dim = self.num_heads * self.head_dim;
let kv_dim = self.num_kv_heads * self.head_dim;
let inter = self.intermediate_dim;
let s = seq_len as usize;
let norm_w = self
.weight_buffers
.get(&format!("{layer_prefix}.attn_norm"))
.ok_or_else(|| format!("Missing {layer_prefix}.attn_norm"))?;
let mut trace = Vec::new();
// Each op is submitted on its own; the timer starts after submit, so the
// recorded value is the wall-clock wait for the GPU to drain that op.
let mut run = |name: &str, f: &dyn Fn(&mut wgpu::CommandEncoder)| {
let mut enc = self.device.create_command_encoder(&Default::default());
f(&mut enc);
self.queue.submit(Some(enc.finish()));
let t = std::time::Instant::now();
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
trace.push((name.to_string(), t.elapsed().as_millis() as u64));
};
run("rmsnorm1", &|e| self.encode_rmsnorm(e, &self.hidden_buf, norm_w, &self.norm_buf, hd));
{
let mut e = self.device.create_command_encoder(&Default::default());
e.copy_buffer_to_buffer(
&self.norm_buf,
0,
&saved.attn_norm_out,
0,
(s * hd as usize * 4) as u64,
);
self.queue.submit(Some(e.finish()));
}
run("q_proj", &|e| {
self.encode_matmul(
e,
&self.norm_buf,
layer_prefix,
"q_proj",
&self.q_buf,
seq_len,
hd,
q_dim,
)
});
run("k_proj", &|e| {
self.encode_matmul(
e,
&self.norm_buf,
layer_prefix,
"k_proj",
&self.k_buf,
seq_len,
hd,
kv_dim,
)
});
run("v_proj", &|e| {
self.encode_matmul(
e,
&self.norm_buf,
layer_prefix,
"v_proj",
&self.v_buf,
seq_len,
hd,
kv_dim,
)
});
if let Some(lr) = lora {
run("lora_qkv", &|e| {
self.encode_lora_addmm(
e,
&saved.attn_norm_out,
lr.q_a,
lr.q_b,
&self.q_buf,
seq_len,
lr.in_dim,
lr.rank,
lr.q_dim,
lr.scale,
lr.lora_pipeline,
lr.lora_bgl,
);
self.encode_lora_addmm(
e,
&saved.attn_norm_out,
lr.k_a,
lr.k_b,
&self.k_buf,
seq_len,
lr.in_dim,
lr.rank,
lr.kv_dim,
lr.scale,
lr.lora_pipeline,
lr.lora_bgl,
);
self.encode_lora_addmm(
e,
&saved.attn_norm_out,
lr.v_a,
lr.v_b,
&self.v_buf,
seq_len,
lr.in_dim,
lr.rank,
lr.kv_dim,
lr.scale,
lr.lora_pipeline,
lr.lora_bgl,
);
});
}
if let Some(q_bias) = self.cpu_biases.get(&format!("{layer_prefix}.q_bias")) {
run("q_bias", &|e| self.encode_broadcast_bias(e, &self.q_buf, q_bias, seq_len));
}
if let Some(k_bias) = self.cpu_biases.get(&format!("{layer_prefix}.k_bias")) {
run("k_bias", &|e| self.encode_broadcast_bias(e, &self.k_buf, k_bias, seq_len));
}
if let Some(v_bias) = self.cpu_biases.get(&format!("{layer_prefix}.v_bias")) {
run("v_bias", &|e| self.encode_broadcast_bias(e, &self.v_buf, v_bias, seq_len));
}
run("rope_q", &|e| {
self.encode_batch_rope(e, &self.q_buf, seq_len, self.num_heads, self.head_dim)
});
run("rope_k", &|e| {
self.encode_batch_rope(e, &self.k_buf, seq_len, self.num_kv_heads, self.head_dim)
});
run("attention", &|e| self.encode_attention(e, seq_len));
{
let mut e = self.device.create_command_encoder(&Default::default());
e.copy_buffer_to_buffer(
&self.attn_out_buf,
0,
&saved.attn_output,
0,
(s * q_dim as usize * 4) as u64,
);
self.queue.submit(Some(e.finish()));
}
run("o_proj", &|e| {
self.encode_matmul(
e,
&self.attn_out_buf,
layer_prefix,
"o_proj",
&self.q_buf,
seq_len,
q_dim,
hd,
)
});
run("residual1", &|e| {
self.encode_residual(e, &self.hidden_buf, &self.q_buf, &self.ffn_out_buf, hd * seq_len)
});
let ffn_norm_w = self
.weight_buffers
.get(&format!("{layer_prefix}.ffn_norm"))
.ok_or_else(|| format!("Missing {layer_prefix}.ffn_norm"))?;
run("rmsnorm2", &|e| {
self.encode_rmsnorm(e, &self.ffn_out_buf, ffn_norm_w, &self.norm_buf, hd)
});
{
let mut e = self.device.create_command_encoder(&Default::default());
e.copy_buffer_to_buffer(
&self.norm_buf,
0,
&saved.ffn_norm_out,
0,
(s * hd as usize * 4) as u64,
);
self.queue.submit(Some(e.finish()));
}
run("gate_proj", &|e| {
self.encode_matmul(
e,
&self.norm_buf,
layer_prefix,
"gate_proj",
&self.ffn_gate_buf,
seq_len,
hd,
inter,
)
});
run("up_proj", &|e| {
self.encode_matmul(
e,
&self.norm_buf,
layer_prefix,
"up_proj",
&self.ffn_up_buf,
seq_len,
hd,
inter,
)
});
run("silu", &|e| {
self.encode_silu_mul(
e,
&self.ffn_gate_buf,
&self.ffn_up_buf,
&self.ffn_silu_buf,
inter * seq_len,
)
});
{
let mut e = self.device.create_command_encoder(&Default::default());
e.copy_buffer_to_buffer(
&self.ffn_silu_buf,
0,
&saved.silu_gate_output,
0,
(s * inter as usize * 4) as u64,
);
self.queue.submit(Some(e.finish()));
}
run("down_proj", &|e| {
self.encode_matmul(
e,
&self.ffn_silu_buf,
layer_prefix,
"down_proj",
&self.norm_buf,
seq_len,
inter,
hd,
)
});
run("residual2", &|e| {
self.encode_residual(
e,
&self.ffn_out_buf,
&self.norm_buf,
&self.hidden_buf,
hd * seq_len,
)
});
let total: u64 = trace.iter().map(|(_, ms)| ms).sum();
let parts: Vec<String> = trace.iter().map(|(n, ms)| format!("{n}={ms}ms")).collect();
eprintln!("[OP-TRACE] layer {} total={}ms: {}", layer_prefix, total, parts.join(" "));
Ok(())
}
pub fn alloc_layer_activations(&self, seq_len: u32) -> LayerActivations {
let s = seq_len as usize;
let buf = |size: usize, label: &str| -> wgpu::Buffer {
self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: (size * 4) as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
})
};
LayerActivations {
attn_norm_out: buf(s * self.hidden_dim as usize, "saved_attn_norm"),
attn_output: buf(s * (self.num_heads * self.head_dim) as usize, "saved_attn_out"),
ffn_norm_out: buf(s * self.hidden_dim as usize, "saved_ffn_norm"),
silu_gate_output: buf(s * self.intermediate_dim as usize, "saved_silu"),
rstd_attn: buf(s, "saved_rstd_attn"),
rstd_ffn: buf(s, "saved_rstd_ffn"),
softmax_logsumexp: buf(self.num_heads as usize * s, "saved_logsumexp"),
}
}
pub fn forward_layer_training(
&self,
seq_len: u32,
layer_prefix: &str,
) -> Result<LayerActivations, String> {
let saved = self.alloc_layer_activations(seq_len);
let mut encoder = self.device.create_command_encoder(&Default::default());
self.encode_forward_layer_training(&mut encoder, seq_len, layer_prefix, &saved, None)?;
self.queue.submit(Some(encoder.finish()));
Ok(saved)
}
pub fn forward_all_layers_training(
&self,
seq_len: u32,
num_layers: usize,
) -> Result<Vec<LayerActivations>, String> {
let mut encoder = self.device.create_command_encoder(&Default::default());
let mut all_saved = Vec::with_capacity(num_layers);
for layer_idx in 0..num_layers {
let prefix = format!("layer.{layer_idx}");
let saved = self.alloc_layer_activations(seq_len);
self.encode_forward_layer_training(&mut encoder, seq_len, &prefix, &saved, None)?;
all_saved.push(saved);
}
self.queue.submit(Some(encoder.finish()));
Ok(all_saved)
}
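// Training-forward sketch: one submit covers all layers, and the returned
// activations stay resident on the GPU for a backward pass (backward itself
// is out of scope here; `embedded` is an assumed host-side embedding buffer):
//
// fp.queue_ref().write_buffer(fp.hidden_buffer(), 0, bytemuck::cast_slice(&embedded));
// let acts = fp.forward_all_layers_training(seq_len, num_layers)?;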
/// Adds a per-feature bias to every row of `buf` by materializing the bias
/// `seq_len` times and reusing the residual-add pipeline (simple, at the cost
/// of two transient buffers per call).
pub fn encode_broadcast_bias(
&self,
encoder: &mut wgpu::CommandEncoder,
buf: &wgpu::Buffer,
bias: &[f32],
seq_len: u32,
) {
let dim = bias.len();
let mut full_bias = Vec::with_capacity(seq_len as usize * dim);
for _ in 0..seq_len {
full_bias.extend_from_slice(bias);
}
let bias_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("broadcast_bias"),
size: (full_bias.len() * 4) as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.queue.write_buffer(&bias_buf, 0, bytemuck::cast_slice(&full_bias));
let total = seq_len * dim as u32;
let tmp = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("bias_tmp"),
size: (total as usize * 4) as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
self.encode_residual(encoder, buf, &bias_buf, &tmp, total);
encoder.copy_buffer_to_buffer(&tmp, 0, buf, 0, (total as u64) * 4);
}
fn encode_batch_rope(
&self,
encoder: &mut wgpu::CommandEncoder,
qk_buf: &wgpu::Buffer,
seq_len: u32,
num_heads: u32,
head_dim: u32,
) {
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct RopeParams {
seq_len: u32,
num_heads: u32,
head_dim: u32,
_pad: u32,
}
let params = RopeParams { seq_len, num_heads, head_dim, _pad: 0 };
let params_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("batch_rope_params"),
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.queue.write_buffer(¶ms_buf, 0, bytemuck::bytes_of(¶ms));
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("batch_rope_bg"),
layout: &self.batch_rope_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: qk_buf.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: params_buf.as_entire_binding() },
],
});
let total = seq_len * num_heads * head_dim;
let wg = total.div_ceil(256);
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.batch_rope_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(wg, 1, 1);
}
fn encode_attention(&self, encoder: &mut wgpu::CommandEncoder, seq_len: u32) {
// Dispatch grid is (num_heads, seq_len) over the full Q/K/V buffers; see
// the dispatch at the end of this function.
let params = [seq_len, self.num_heads, self.num_kv_heads, self.head_dim];
let params_buf = self.make_uniform(&params);
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.attention_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: self.q_buf.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: self.k_buf.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: self.v_buf.as_entire_binding() },
wgpu::BindGroupEntry {
binding: 3,
resource: self.attn_out_buf.as_entire_binding(),
},
wgpu::BindGroupEntry { binding: 4, resource: params_buf.as_entire_binding() },
],
});
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.attention_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(self.num_heads, seq_len, 1);
}
/// Encodes `output += scale * (input × A) × B` as two tiled GEMMs plus an
/// elementwise add (A: in_dim×rank, B: rank×out_dim, as consumed by
/// `encode_tiled_gemm`). The dedicated LoRA pipeline handles are currently
/// unused; the generic tiled-GEMM path covers this case.
#[allow(clippy::too_many_arguments)]
fn encode_lora_addmm(
&self,
encoder: &mut wgpu::CommandEncoder,
input: &wgpu::Buffer,
lora_a: &wgpu::Buffer,
lora_b: &wgpu::Buffer,
output: &wgpu::Buffer,
seq_len: u32,
in_dim: u32,
rank: u32,
out_dim: u32,
scale: f32,
_pipeline: &wgpu::ComputePipeline,
_bgl: &wgpu::BindGroupLayout,
) {
let temp_size = (seq_len * rank) as u64 * 4;
let temp = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("lora_temp"),
size: temp_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
self.encode_tiled_gemm(encoder, input, lora_a, &temp, seq_len, in_dim, rank, 1.0);
let delta_size = (seq_len * out_dim) as u64 * 4;
let delta = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("lora_delta"),
size: delta_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
self.encode_tiled_gemm(encoder, &temp, lora_b, &delta, seq_len, rank, out_dim, scale);
let sum_buf = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("lora_sum"),
size: delta_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
self.encode_residual(encoder, output, &delta, &sum_buf, seq_len * out_dim);
encoder.copy_buffer_to_buffer(&sum_buf, 0, output, 0, delta_size);
}
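// Note: each LoRA addmm allocates three transient buffers (temp, delta, sum)
// per encode. Pooling them would cut allocator churn on long runs, but
// per-call allocation keeps the encoding path stateless.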
fn encode_tiled_gemm(
&self,
encoder: &mut wgpu::CommandEncoder,
a: &wgpu::Buffer,
b: &wgpu::Buffer,
c: &wgpu::Buffer,
m: u32,
k: u32,
n: u32,
alpha: f32,
) {
// Pack the f32 alpha into the u32 params slot via bit-reinterpretation; the
// tiled shader is expected to `bitcast<f32>` it back.
let params = [m, k, n, alpha.to_bits()];
let params_buf = self.make_uniform(¶ms);
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.matmul_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: c.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
});
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.tiled_matmul_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(n.div_ceil(64), m.div_ceil(64), 1);
}
fn encode_rmsnorm(
&self,
encoder: &mut wgpu::CommandEncoder,
input: &wgpu::Buffer,
weight: &wgpu::Buffer,
output: &wgpu::Buffer,
dim: u32,
) {
let params = [dim, 0u32, 0, 0];
let params_buf = self.make_uniform(¶ms);
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.elementwise_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
});
// Rows are inferred from the input buffer's capacity; buffers sized for the
// full window therefore dispatch over every row they can hold (extra rows
// are wasted work, but writes stay within the equally sized output buffers).
let num_rows = (input.size() / (dim as u64 * 4)).max(1) as u32;
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.rmsnorm_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(1, num_rows, 1);
}
fn encode_matmul(
&self,
encoder: &mut wgpu::CommandEncoder,
input: &wgpu::Buffer,
layer_prefix: &str,
proj_name: &str,
output: &wgpu::Buffer,
m: u32,
k: u32,
n: u32,
) {
// Single-row case: prefer the quantized GEMV when a Q4_K weight exists.
if m == 1 && self.encode_q4k_gemv(encoder, input, output, layer_prefix, proj_name, n, k) {
return;
}
let weight_key = format!("{layer_prefix}.{proj_name}");
let weight = match self.weight_buffers.get(&weight_key) {
Some(w) => w,
// Weight was skipped at upload (see the size check in `upload_weight`);
// return silently so a CPU fallback can cover this projection.
None => return,
};
// GEMV packs (n, k); the matmul/tiled paths pack (m, k, n, alpha-bits).
let params = if m == 1 { [n, k, 0u32, 0u32] } else { [m, k, n, 1.0_f32.to_bits()] };
let params_buf = self.make_uniform(¶ms);
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.matmul_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
});
let mut pass = encoder.begin_compute_pass(&Default::default());
if m == 1 {
pass.set_pipeline(&self.gemv_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(n, 1, 1);
} else if m >= 4 {
pass.set_pipeline(&self.tiled_matmul_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(n.div_ceil(64), m.div_ceil(64), 1);
} else {
pass.set_pipeline(&self.matmul_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(m.div_ceil(16), n.div_ceil(16), 1);
}
}
fn encode_q4k_gemv(
&self,
encoder: &mut wgpu::CommandEncoder,
input: &wgpu::Buffer,
output: &wgpu::Buffer,
layer_prefix: &str,
proj_name: &str,
n: u32,
k: u32,
) -> bool {
let weight_key = format!("{layer_prefix}.{proj_name}");
let weight = match self.q4k_weights.get(&weight_key) {
Some(w) => w,
None => return false,
};
// Q4_K granularity: one super-block per 256 weights.
let num_superblocks = k.div_ceil(256);
let params = [n, k, num_superblocks, 0u32];
let params_buf = self.make_uniform(¶ms);
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.matmul_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: weight.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
});
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.q4k_gemv_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(n, 1, 1);
true
}
fn encode_silu_mul(
&self,
encoder: &mut wgpu::CommandEncoder,
gate: &wgpu::Buffer,
up: &wgpu::Buffer,
output: &wgpu::Buffer,
dim: u32,
) {
let params = [dim, 0u32, 0, 0];
let params_buf = self.make_uniform(¶ms);
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.elementwise_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: gate.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: up.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
});
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.silu_mul_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(dim.div_ceil(256), 1, 1);
}
fn encode_residual(
&self,
encoder: &mut wgpu::CommandEncoder,
a: &wgpu::Buffer,
b: &wgpu::Buffer,
output: &wgpu::Buffer,
dim: u32,
) {
let params = [dim, 0u32, 0, 0];
let params_buf = self.make_uniform(¶ms);
let bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.elementwise_bgl,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: a.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: b.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: output.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: params_buf.as_entire_binding() },
],
});
let mut pass = encoder.begin_compute_pass(&Default::default());
pass.set_pipeline(&self.residual_pipeline);
pass.set_bind_group(0, &bg, &[]);
pass.dispatch_workgroups(dim.div_ceil(256), 1, 1);
}
fn make_uniform(&self, data: &[u32; 4]) -> wgpu::Buffer {
use wgpu::util::DeviceExt;
self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: None,
contents: bytemuck::cast_slice(data),
usage: wgpu::BufferUsages::UNIFORM,
})
}
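// `make_uniform` allocates a fresh 16-byte uniform buffer per dispatch:
// simple and correct, at the cost of some allocator churn per encoded op.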
}
fn bgl_storage(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}
fn bgl_uniform(binding: u32) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}