trueno/backends/gpu/device/linalg/cached_matmul.rs

use std::collections::HashMap;

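/// GPU matmul executor that caches weight buffers and reuses persistent
/// input/output/staging buffers across calls, so steady-state use only
/// uploads activations and reads back results.
///
/// A minimal usage sketch (assumes a `wgpu` device and queue are already
/// available; names and sizes are illustrative):
///
/// ```ignore
/// let mut cache = GpuMatmulCache::new(device, queue);
/// // 4x8 weight: rows = output dim n, cols = input dim k. Uploaded once.
/// cache.upload_weight("w0", &vec![0.5f32; 4 * 8], 4, 8);
/// let input = vec![1.0f32; 8]; // m = 1 row of activations -> GEMV path
/// let mut output = vec![0.0f32; 4];
/// cache.matmul_cached("w0", &input, &mut output, 1)?;
/// ```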
pub struct GpuMatmulCache {
    device: wgpu::Device,
    queue: wgpu::Queue,
    pipeline: wgpu::ComputePipeline,
    tiled_pipeline: wgpu::ComputePipeline,
    gemv_pipeline: wgpu::ComputePipeline,
    bind_group_layout: wgpu::BindGroupLayout,
    weight_buffers: HashMap<String, WeightEntry>,
    input_buffer: Option<wgpu::Buffer>,
    input_size: u64,
    output_buffer: Option<wgpu::Buffer>,
    output_size: u64,
    dims_buffer: Option<wgpu::Buffer>,
    staging_size: u64,
    staging_buffer: Option<wgpu::Buffer>,
}

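/// A cached GPU weight buffer plus its logical dimensions (row-major).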
struct WeightEntry {
    buffer: wgpu::Buffer,
    rows: usize,
    cols: usize,
}

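/// Uniform parameters shared by all three kernels. `alpha_bits` carries the
/// scale factor as raw f32 bits so the struct stays `Pod`; on the GEMV path
/// the `m` slot is repurposed to hold the output length (see
/// `test_gemv_params_layout`).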
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Dimensions {
    m: u32,
    k: u32,
    n: u32,
    alpha_bits: u32,
}

impl GpuMatmulCache {
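    /// Compiles the naive, tiled-GEMM, and GEMV kernels from the embedded
    /// WGSL sources and builds the single bind group layout they all share
    /// (input, weight, output, dims).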
    pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("CachedMatmul Shader"),
            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::MATMUL_SHADER.into()),
        });

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("CachedMatmul BGL"),
            entries: &[
                bgl_entry(0, true),
                bgl_entry(1, true),
                bgl_entry(2, false),
                wgpu::BindGroupLayoutEntry {
                    binding: 3,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("CachedMatmul PL"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("CachedMatmul Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let tiled_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("TiledGEMM Shader"),
            source: wgpu::ShaderSource::Wgsl(
                crate::backends::gpu::shaders::TILED_GEMM_SHADER.into(),
            ),
        });
        let tiled_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("TiledGEMM Pipeline"),
            layout: Some(&pipeline_layout),
            module: &tiled_shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("GEMV Shader"),
            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::GEMV_SHADER.into()),
        });
        let gemv_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("GEMV Pipeline"),
            layout: Some(&pipeline_layout),
            module: &gemv_shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        Self {
            device,
            queue,
            pipeline,
            tiled_pipeline,
            gemv_pipeline,
            bind_group_layout,
            weight_buffers: HashMap::new(),
            input_buffer: None,
            input_size: 0,
            output_buffer: None,
            output_size: 0,
            dims_buffer: None,
            staging_size: 0,
            staging_buffer: None,
        }
    }

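    /// Uploads a row-major f32 weight matrix. Weights larger than the
    /// device's max storage-buffer binding are skipped with a warning, so a
    /// later `matmul_cached` for that name returns `Err` and the caller can
    /// fall back to the CPU path.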
    pub fn upload_weight(&mut self, name: &str, data: &[f32], rows: usize, cols: usize) {
        assert_eq!(data.len(), rows * cols, "weight size mismatch");
        let size_bytes = (data.len() * 4) as u64;
        let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
        if size_bytes > max_binding {
            eprintln!(
                "[wgpu] Skipping weight '{}' ({:.1} MB > {:.1} MB max binding) — will use CPU fallback",
                name,
                size_bytes as f64 / 1e6,
                max_binding as f64 / 1e6
            );
            return;
        }
        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some(name),
            size: size_bytes,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
        self.weight_buffers.insert(name.to_string(), WeightEntry { buffer, rows, cols });
    }

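    /// Number of weights currently resident on the GPU.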
    pub fn weight_count(&self) -> usize {
        self.weight_buffers.len()
    }

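    /// Total bytes held by cached weights (4 bytes per f32 element).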
    pub fn weight_bytes(&self) -> usize {
        self.weight_buffers.values().map(|w| w.rows * w.cols * 4).sum()
    }

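    // The ensure_* helpers below grow the persistent buffers monotonically:
    // a buffer is reallocated only when a call needs more space than the
    // current allocation, and is reused as-is otherwise.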
    fn ensure_input_buffer(&mut self, size: u64) {
        if self.input_size < size {
            self.input_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("persistent_input"),
                size,
                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            }));
            self.input_size = size;
        }
    }

    fn ensure_output_buffer(&mut self, size: u64) {
        if self.output_size < size {
            self.output_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("persistent_output"),
                size,
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_SRC
                    | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            }));
            self.output_size = size;
        }
    }

    fn ensure_dims_buffer(&mut self) {
        if self.dims_buffer.is_none() {
            self.dims_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("persistent_dims"),
                size: 16,
                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            }));
        }
    }

    fn ensure_staging_buffer(&mut self, size: u64) {
        if self.staging_size < size {
            self.staging_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("persistent_staging"),
                size,
                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            }));
            self.staging_size = size;
        }
    }

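    /// Computes `output[m x n]` from `input[m x k]` and the cached weight
    /// `weight_name`, which was uploaded as `rows x cols = n x k` (so the
    /// kernels presumably index one weight row per output column). Grows the
    /// persistent buffers as needed, uploads the input, picks a kernel based
    /// on `m`, and blocks on readback into `output`.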
    pub fn matmul_cached(
        &mut self,
        weight_name: &str,
        input: &[f32],
        output: &mut [f32],
        m: usize,
    ) -> Result<(), String> {
        let (k, n) = {
            let entry = self
                .weight_buffers
                .get(weight_name)
                .ok_or_else(|| format!("Weight '{}' not uploaded", weight_name))?;
            (entry.cols, entry.rows)
        };

        if input.len() < m * k {
            return Err(format!("input too small: need {}, have {}", m * k, input.len()));
        }
        if output.len() < m * n {
            return Err(format!("output too small: need {}, have {}", m * n, output.len()));
        }

        let input_bytes = (m * k * 4) as u64;
        let output_bytes = (m * n * 4) as u64;

        self.ensure_input_buffer(input_bytes);
        self.ensure_output_buffer(output_bytes);
        self.ensure_dims_buffer();
        self.ensure_staging_buffer(output_bytes);

        let input_buf = self.input_buffer.as_ref().expect("ensure_input_buffer was just called");
        self.queue.write_buffer(input_buf, 0, bytemuck::cast_slice(&input[..m * k]));

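        // The GEMV shader reads the first uniform word as its output length,
        // so for m == 1 we pack n into the `m` slot and zero the `n` slot
        // (see test_gemv_params_layout below).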
        let dims = if m == 1 {
            Dimensions { m: n as u32, k: k as u32, n: 0, alpha_bits: 1.0_f32.to_bits() }
        } else {
            Dimensions { m: m as u32, k: k as u32, n: n as u32, alpha_bits: 1.0_f32.to_bits() }
        };
        let dims_buf = self.dims_buffer.as_ref().expect("ensure_dims_buffer was just called");
        self.queue.write_buffer(dims_buf, 0, bytemuck::bytes_of(&dims));

        let output_buf = self.output_buffer.as_ref().expect("ensure_output_buffer was just called");
        let weight_buf = &self
            .weight_buffers
            .get(weight_name)
            .ok_or_else(|| {
                format!("Weight '{}' not uploaded; call upload_weight() first", weight_name)
            })?
            .buffer;
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: None,
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: weight_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 2, resource: output_buf.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 3, resource: dims_buf.as_entire_binding() },
            ],
        });

        let staging = self.staging_buffer.as_ref().expect("ensure_staging_buffer was just called");

        let mut encoder =
            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });

        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("matmul"),
                timestamp_writes: None,
            });
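            // Kernel heuristic: dedicated GEMV for a single input row, the
            // tiled GEMM for m >= 4, and the naive kernel for small batches.
            // The 64- and 16-wide dispatch math below presumably matches the
            // tile/workgroup sizes hard-coded in the WGSL shaders, and the
            // GEMV path assumes n stays under the per-dimension dispatch
            // limit (65,535 workgroups by default).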
            if m == 1 {
                pass.set_pipeline(&self.gemv_pipeline);
                pass.set_bind_group(0, &bind_group, &[]);
                pass.dispatch_workgroups(n as u32, 1, 1);
            } else if m >= 4 {
                pass.set_pipeline(&self.tiled_pipeline);
                pass.set_bind_group(0, &bind_group, &[]);
                pass.dispatch_workgroups((n as u32).div_ceil(64), (m as u32).div_ceil(64), 1);
            } else {
                pass.set_pipeline(&self.pipeline);
                pass.set_bind_group(0, &bind_group, &[]);
                pass.dispatch_workgroups((m as u32).div_ceil(16), (n as u32).div_ceil(16), 1);
            }
        }

        encoder.copy_buffer_to_buffer(output_buf, 0, staging, 0, output_bytes);
        self.queue.submit(Some(encoder.finish()));

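        // Blocking readback: map the staging copy, wait for the GPU to
        // finish, then copy the mapped bytes into the caller's slice (the
        // braced block drops the mapped view before unmap).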
        let slice = staging.slice(..output_bytes);
        let (tx, rx) = std::sync::mpsc::channel();
        slice.map_async(wgpu::MapMode::Read, move |r| {
            tx.send(r).ok();
        });
        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
        rx.recv().map_err(|e| format!("recv: {e}"))?.map_err(|e| format!("map: {e:?}"))?;

        {
            let data = slice.get_mapped_range();
            output[..m * n].copy_from_slice(bytemuck::cast_slice(&data));
        }
        staging.unmap();

        Ok(())
    }
}

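/// Shorthand for a compute-visible storage-buffer bind group layout entry.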
fn bgl_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
    wgpu::BindGroupLayoutEntry {
        binding,
        visibility: wgpu::ShaderStages::COMPUTE,
        ty: wgpu::BindingType::Buffer {
            ty: wgpu::BufferBindingType::Storage { read_only },
            has_dynamic_offset: false,
            min_binding_size: None,
        },
        count: None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dimensions_layout() {
        let dims = Dimensions { m: 1, k: 1536, n: 1536, alpha_bits: 1.0_f32.to_bits() };
        let bytes = bytemuck::bytes_of(&dims);
        assert_eq!(bytes.len(), 16);
        assert_eq!(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), 1);
        assert_eq!(u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]), 1536);
    }

    #[test]
    fn test_gemv_params_layout() {
        let m = 1usize;
        let k = 1536usize;
        let n = 256usize;
        let dims = if m == 1 {
            Dimensions { m: n as u32, k: k as u32, n: 0, alpha_bits: 1.0_f32.to_bits() }
        } else {
            Dimensions { m: m as u32, k: k as u32, n: n as u32, alpha_bits: 1.0_f32.to_bits() }
        };
        let bytes = bytemuck::bytes_of(&dims);
        let gemv_n = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        assert_eq!(gemv_n, 256, "GEMV params.n must be output dimension, not m");
    }

    #[test]
    fn test_matmul_params_layout() {
        let dims = Dimensions { m: 4, k: 1536, n: 1536, alpha_bits: 1.0_f32.to_bits() };
        let bytes = bytemuck::bytes_of(&dims);
        assert_eq!(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]), 4);
        assert_eq!(u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]), 1536);
        assert_eq!(u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]), 1536);
    }
}