use std::collections::HashMap;
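/// Executes matmuls on the GPU while keeping weight matrices resident in VRAM.
/// The input, output, uniform, and staging buffers persist across calls and
/// grow on demand, so steady-state calls upload only the activations.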
pub struct GpuMatmulCache {
device: wgpu::Device,
queue: wgpu::Queue,
pipeline: wgpu::ComputePipeline,
tiled_pipeline: wgpu::ComputePipeline,
gemv_pipeline: wgpu::ComputePipeline,
bind_group_layout: wgpu::BindGroupLayout,
weight_buffers: HashMap<String, WeightEntry>,
input_buffer: Option<wgpu::Buffer>,
input_size: u64,
output_buffer: Option<wgpu::Buffer>,
output_size: u64,
dims_buffer: Option<wgpu::Buffer>,
staging_buffer: Option<wgpu::Buffer>,
staging_size: u64,
}
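/// A weight matrix resident in a GPU storage buffer; `rows` is the matmul's
/// output dimension and `cols` its reduction (input) dimension.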
struct WeightEntry {
buffer: wgpu::Buffer,
rows: usize,
cols: usize,
}
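/// Uniform parameters shared by all three shaders; 16 bytes, matching the
/// `persistent_dims` buffer. `alpha_bits` holds an f32 scale factor
/// reinterpreted as `u32` so the struct stays [`bytemuck::Pod`].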
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Dimensions {
m: u32,
k: u32,
n: u32,
alpha_bits: u32,
}
impl GpuMatmulCache {
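/// Builds the three compute pipelines (naive matmul, tiled GEMM, GEMV) over a
/// shared layout: storage bindings 0-2 for input, weight, and output, plus a
/// uniform binding 3 for [`Dimensions`].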
pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self {
let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("CachedMatmul Shader"),
source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::MATMUL_SHADER.into()),
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("CachedMatmul BGL"),
entries: &[
bgl_entry(0, true),
bgl_entry(1, true),
bgl_entry(2, false),
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("CachedMatmul PL"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("CachedMatmul Pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let tiled_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("TiledGEMM Shader"),
source: wgpu::ShaderSource::Wgsl(
crate::backends::gpu::shaders::TILED_GEMM_SHADER.into(),
),
});
let tiled_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("TiledGEMM Pipeline"),
layout: Some(&pipeline_layout),
module: &tiled_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("GEMV Shader"),
source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::GEMV_SHADER.into()),
});
let gemv_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("GEMV Pipeline"),
layout: Some(&pipeline_layout),
module: &gemv_shader,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
Self {
device,
queue,
pipeline,
tiled_pipeline,
gemv_pipeline,
bind_group_layout,
weight_buffers: HashMap::new(),
input_buffer: None,
input_size: 0,
output_buffer: None,
output_size: 0,
dims_buffer: None,
staging_buffer: None,
staging_size: 0,
}
}
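/// Copies a `rows x cols` f32 weight matrix into its own GPU storage buffer,
/// keyed by `name`. Weights exceeding the device's maximum storage-buffer
/// binding size are skipped with a warning so callers can fall back to the
/// CPU path.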
pub fn upload_weight(&mut self, name: &str, data: &[f32], rows: usize, cols: usize) {
assert_eq!(data.len(), rows * cols, "weight size mismatch");
let size_bytes = (data.len() * 4) as u64;
let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
if size_bytes > max_binding {
eprintln!(
"[wgpu] Skipping weight '{}' ({:.1} MB > {:.1} MB max binding) — will use CPU fallback",
name,
size_bytes as f64 / 1e6,
max_binding as f64 / 1e6
);
return;
}
let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(name),
size: size_bytes,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
self.queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
self.weight_buffers.insert(name.to_string(), WeightEntry { buffer, rows, cols });
}
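/// Number of weight matrices currently resident on the GPU.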
pub fn weight_count(&self) -> usize {
self.weight_buffers.len()
}
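/// Total size in bytes of all resident weight matrices.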
pub fn weight_bytes(&self) -> usize {
self.weight_buffers.values().map(|w| w.rows * w.cols * 4).sum()
}
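// The ensure_* helpers grow the persistent buffers on demand and never
// shrink them, so repeated calls with stable shapes allocate nothing.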
fn ensure_input_buffer(&mut self, size: u64) {
if self.input_size < size {
self.input_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("persistent_input"),
size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}));
self.input_size = size;
}
}
fn ensure_output_buffer(&mut self, size: u64) {
if self.output_size < size {
self.output_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("persistent_output"),
size,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::COPY_SRC
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}));
self.output_size = size;
}
}
fn ensure_dims_buffer(&mut self) {
if self.dims_buffer.is_none() {
self.dims_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("persistent_dims"),
size: 16,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}));
}
}
fn ensure_staging_buffer(&mut self, size: u64) {
if self.staging_size < size {
self.staging_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("persistent_staging"),
size,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}));
self.staging_size = size;
}
}
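/// Multiplies an `m x k` input by the cached weight `weight_name` (`n` rows
/// by `k` cols), writing an `m x n` result into `output`. Dispatches the GEMV
/// kernel when `m == 1`, the tiled GEMM when `m >= 4`, and the naive kernel
/// otherwise, then blocks on a staging readback until the result is on the CPU.
///
/// A minimal usage sketch (the weight name and sizes are illustrative, and
/// `device`/`queue` are assumed to already exist):
///
/// ```ignore
/// let mut cache = GpuMatmulCache::new(device, queue);
/// cache.upload_weight("w_out", &vec![0.0f32; 1536 * 1536], 1536, 1536);
/// let input = vec![0.0f32; 1536]; // m = 1, k = 1536
/// let mut output = vec![0.0f32; 1536]; // n = 1536
/// cache.matmul_cached("w_out", &input, &mut output, 1)?;
/// ```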
pub fn matmul_cached(
&mut self,
weight_name: &str,
input: &[f32],
output: &mut [f32],
m: usize,
) -> Result<(), String> {
let (k, n) = {
let entry = self
.weight_buffers
.get(weight_name)
.ok_or_else(|| format!("Weight '{}' not uploaded", weight_name))?;
(entry.cols, entry.rows)
};
if input.len() < m * k {
return Err(format!("input too small: need {}, have {}", m * k, input.len()));
}
if output.len() < m * n {
return Err(format!("output too small: need {}, have {}", m * n, output.len()));
}
let input_bytes = (m * k * 4) as u64;
let output_bytes = (m * n * 4) as u64;
self.ensure_input_buffer(input_bytes);
self.ensure_output_buffer(output_bytes);
self.ensure_dims_buffer();
self.ensure_staging_buffer(output_bytes);
let input_buf = self.input_buffer.as_ref().unwrap();
self.queue.write_buffer(input_buf, 0, bytemuck::cast_slice(&input[..m * k]));
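// The GEMV shader reads its output length from the first uniform field, so
// for m == 1 we pack n into the `m` slot and zero the unused `n` field
// (see `test_gemv_params_layout`).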
let dims = if m == 1 {
Dimensions { m: n as u32, k: k as u32, n: 0, alpha_bits: 1.0_f32.to_bits() }
} else {
Dimensions { m: m as u32, k: k as u32, n: n as u32, alpha_bits: 1.0_f32.to_bits() }
};
let dims_buf = self.dims_buffer.as_ref().unwrap();
self.queue.write_buffer(dims_buf, 0, bytemuck::bytes_of(&dims));
let output_buf = self.output_buffer.as_ref().unwrap();
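// Look the weight up again: the ensure_* calls above needed `&mut self`, so
// the earlier borrow of the entry could not be held across them.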
let weight_buf = &self
.weight_buffers
.get(weight_name)
.ok_or_else(|| {
format!("weight '{}' not loaded — call load_weight() first", weight_name)
})?
.buffer;
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: None,
layout: &self.bind_group_layout,
entries: &[
wgpu::BindGroupEntry { binding: 0, resource: input_buf.as_entire_binding() },
wgpu::BindGroupEntry { binding: 1, resource: weight_buf.as_entire_binding() },
wgpu::BindGroupEntry { binding: 2, resource: output_buf.as_entire_binding() },
wgpu::BindGroupEntry { binding: 3, resource: dims_buf.as_entire_binding() },
],
});
let staging = self.staging_buffer.as_ref().unwrap();
let mut encoder =
self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("matmul"),
timestamp_writes: None,
});
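// Pipeline heuristic: GEMV for a single row, tiled GEMM for batches of four
// or more rows, the naive kernel in between. The workgroup-count divisors
// must match the tile sizes declared in the corresponding WGSL shaders.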
if m == 1 {
pass.set_pipeline(&self.gemv_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups(n as u32, 1, 1);
} else if m >= 4 {
pass.set_pipeline(&self.tiled_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((n as u32).div_ceil(64), (m as u32).div_ceil(64), 1);
} else {
pass.set_pipeline(&self.pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.dispatch_workgroups((m as u32).div_ceil(16), (n as u32).div_ceil(16), 1);
}
}
encoder.copy_buffer_to_buffer(output_buf, 0, staging, 0, output_bytes);
self.queue.submit(Some(encoder.finish()));
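// Synchronous readback: map the staging copy, block until the GPU finishes,
// then copy the mapped bytes into `output`.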
let slice = staging.slice(..output_bytes);
let (tx, rx) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |r| {
tx.send(r).ok();
});
self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
rx.recv().map_err(|e| format!("recv: {e}"))?.map_err(|e| format!("map: {e:?}"))?;
{
let data = slice.get_mapped_range();
output[..m * n].copy_from_slice(bytemuck::cast_slice(&data));
}
staging.unmap();
Ok(())
}
}
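/// Shorthand for a compute-visible storage-buffer bind group layout entry.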
fn bgl_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
wgpu::BindGroupLayoutEntry {
binding,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_dimensions_layout() {
let dims = Dimensions { m: 1, k: 1536, n: 1536, alpha_bits: 1.0_f32.to_bits() };
let bytes = bytemuck::bytes_of(&dims);
assert_eq!(bytes.len(), 16);
assert_eq!(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), 1);
assert_eq!(u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]), 1536);
}
#[test]
fn test_gemv_params_layout() {
let m = 1usize;
let k = 1536usize;
let n = 256usize;
let dims = if m == 1 {
Dimensions { m: n as u32, k: k as u32, n: 0, alpha_bits: 1.0_f32.to_bits() }
} else {
Dimensions { m: m as u32, k: k as u32, n: n as u32, alpha_bits: 1.0_f32.to_bits() }
};
let bytes = bytemuck::bytes_of(&dims);
let gemv_n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
assert_eq!(gemv_n, 256, "GEMV params.n must be output dimension, not m");
}
#[test]
fn test_matmul_params_layout() {
let dims = Dimensions { m: 4, k: 1536, n: 1536, alpha_bits: 1.0_f32.to_bits() };
let bytes = bytemuck::bytes_of(&dims);
assert_eq!(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), 4);
assert_eq!(u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]), 1536);
assert_eq!(u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]), 1536);
}
}