Skip to main content

trueno/backends/gpu/device/linalg/
cached_matmul.rs

1//! PMAT-322: Cached GPU matmul with persistent weight buffers.
2//!
3//! The default `matmul_async` creates all GPU objects per call (~8ms overhead).
4//! This module pre-uploads weight matrices and caches the pipeline, reducing
5//! per-call overhead to: upload input + dispatch + download output (~0.1ms).
6
7use std::collections::HashMap;
8
9/// Cached matmul state: pipeline + pre-uploaded weight buffers + persistent I/O.
10///
11/// PMAT-323: Three levels of buffer persistence:
12/// 1. Weight buffers — uploaded once at model init (PMAT-322)
13/// 2. I/O buffers — pre-allocated to max size, reused across calls (PMAT-323)
14/// 3. Pipeline + bind group layout — created once, reused forever
15pub struct GpuMatmulCache {
16    device: wgpu::Device,
17    queue: wgpu::Queue,
18    pipeline: wgpu::ComputePipeline,
19    /// CUTLASS-style tiled GEMM pipeline for M>16 (training batch, prefill)
20    tiled_pipeline: wgpu::ComputePipeline,
21    /// PMAT-326: Dedicated GEMV pipeline for M=1 (cooperative K-reduction)
22    gemv_pipeline: wgpu::ComputePipeline,
23    bind_group_layout: wgpu::BindGroupLayout,
24    /// Pre-uploaded weight buffers keyed by name
25    weight_buffers: HashMap<String, WeightEntry>,
26    /// PMAT-323: Persistent I/O buffers (grow-only, never deallocated)
27    input_buffer: Option<wgpu::Buffer>,
28    input_size: u64,
29    output_buffer: Option<wgpu::Buffer>,
30    output_size: u64,
31    dims_buffer: Option<wgpu::Buffer>,
32    /// Reusable staging buffer (grows as needed)
33    staging_size: u64,
34    staging_buffer: Option<wgpu::Buffer>,
35}
36
/// A weight matrix resident in GPU memory, together with its logical shape.
///
/// `matmul_cached` reads `rows` as the output dimension N and `cols` as the
/// reduction dimension K.
struct WeightEntry {
    /// Storage buffer holding `rows * cols` row-major f32 values.
    buffer: wgpu::Buffer,
    rows: usize,
    cols: usize,
}
42
/// Uniform passed to every matmul shader at binding 3: four `u32`s, 16 bytes.
///
/// The field order is a binary contract with the WGSL shaders — do not reorder.
/// The GEMM shaders read the words as `{ M, K, N, alpha/_padding }`.
/// PMAT-346: the GEMV shader reads the same 16 bytes as `Params { n, k, _, _ }`,
/// so for `m == 1` the output dimension N is stored in the `m` field.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct Dimensions {
    m: u32,
    k: u32,
    n: u32,
    /// Alpha scaling factor for tiled GEMM epilogue. Reinterpreted as f32.
    /// Old naive shader ignores this field (_padding). Tiled GEMM reads it as `dims.alpha`.
    alpha_bits: u32,
}
53
54impl GpuMatmulCache {
55    /// Create a new cached matmul context from an existing GpuDevice.
56    pub fn new(device: wgpu::Device, queue: wgpu::Queue) -> Self {
57        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
58            label: Some("CachedMatmul Shader"),
59            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::MATMUL_SHADER.into()),
60        });
61
62        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
63            label: Some("CachedMatmul BGL"),
64            entries: &[
65                bgl_entry(0, true),  // A (input, read-only)
66                bgl_entry(1, true),  // B (weight, read-only)
67                bgl_entry(2, false), // C (output, read-write)
68                wgpu::BindGroupLayoutEntry {
69                    binding: 3,
70                    visibility: wgpu::ShaderStages::COMPUTE,
71                    ty: wgpu::BindingType::Buffer {
72                        ty: wgpu::BufferBindingType::Uniform,
73                        has_dynamic_offset: false,
74                        min_binding_size: None,
75                    },
76                    count: None,
77                },
78            ],
79        });
80
81        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
82            label: Some("CachedMatmul PL"),
83            bind_group_layouts: &[&bind_group_layout],
84            push_constant_ranges: &[],
85        });
86
87        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
88            label: Some("CachedMatmul Pipeline"),
89            layout: Some(&pipeline_layout),
90            module: &shader,
91            entry_point: Some("main"),
92            compilation_options: Default::default(),
93            cache: None,
94        });
95
96        // CUTLASS-style tiled GEMM pipeline (64×64 tiles, 4×4 thread micro-tiles)
97        let tiled_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
98            label: Some("TiledGEMM Shader"),
99            source: wgpu::ShaderSource::Wgsl(
100                crate::backends::gpu::shaders::TILED_GEMM_SHADER.into(),
101            ),
102        });
103        let tiled_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
104            label: Some("TiledGEMM Pipeline"),
105            layout: Some(&pipeline_layout),
106            module: &tiled_shader,
107            entry_point: Some("main"),
108            compilation_options: Default::default(),
109            cache: None,
110        });
111
112        // PMAT-326: GEMV pipeline (cooperative K-reduction, optimal for M=1)
113        let gemv_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
114            label: Some("GEMV Shader"),
115            source: wgpu::ShaderSource::Wgsl(crate::backends::gpu::shaders::GEMV_SHADER.into()),
116        });
117        let gemv_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
118            label: Some("GEMV Pipeline"),
119            layout: Some(&pipeline_layout),
120            module: &gemv_shader,
121            entry_point: Some("main"),
122            compilation_options: Default::default(),
123            cache: None,
124        });
125
126        Self {
127            device,
128            queue,
129            pipeline,
130            tiled_pipeline,
131            gemv_pipeline,
132            bind_group_layout,
133            weight_buffers: HashMap::new(),
134            input_buffer: None,
135            input_size: 0,
136            output_buffer: None,
137            output_size: 0,
138            dims_buffer: None,
139            staging_size: 0,
140            staging_buffer: None,
141        }
142    }
143
144    /// Pre-upload a weight matrix (call once at model init).
145    /// Weight is stored in row-major f32: shape [rows, cols].
146    /// Silently skips weights that exceed the device's max buffer binding size.
147    pub fn upload_weight(&mut self, name: &str, data: &[f32], rows: usize, cols: usize) {
148        assert_eq!(data.len(), rows * cols, "weight size mismatch");
149        let size_bytes = (data.len() * 4) as u64;
150        let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
151        if size_bytes > max_binding {
152            eprintln!(
153                "[wgpu] Skipping weight '{}' ({:.1} MB > {:.1} MB max binding) — will use CPU fallback",
154                name,
155                size_bytes as f64 / 1e6,
156                max_binding as f64 / 1e6
157            );
158            return;
159        }
160        let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
161            label: Some(name),
162            size: size_bytes,
163            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
164            mapped_at_creation: false,
165        });
166        self.queue.write_buffer(&buffer, 0, bytemuck::cast_slice(data));
167        self.weight_buffers.insert(name.to_string(), WeightEntry { buffer, rows, cols });
168    }
169
170    /// Number of pre-uploaded weights.
171    pub fn weight_count(&self) -> usize {
172        self.weight_buffers.len()
173    }
174
175    /// Total VRAM used by weight buffers (bytes).
176    pub fn weight_bytes(&self) -> usize {
177        self.weight_buffers.values().map(|w| w.rows * w.cols * 4).sum()
178    }
179
180    /// PMAT-323: Ensure persistent I/O buffers are at least `size` bytes.
181    /// Grows only — never shrinks. Returns reference to the buffer.
182    fn ensure_input_buffer(&mut self, size: u64) {
183        if self.input_size < size {
184            self.input_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
185                label: Some("persistent_input"),
186                size,
187                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
188                mapped_at_creation: false,
189            }));
190            self.input_size = size;
191        }
192    }
193
194    fn ensure_output_buffer(&mut self, size: u64) {
195        if self.output_size < size {
196            self.output_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
197                label: Some("persistent_output"),
198                size,
199                usage: wgpu::BufferUsages::STORAGE
200                    | wgpu::BufferUsages::COPY_SRC
201                    | wgpu::BufferUsages::COPY_DST,
202                mapped_at_creation: false,
203            }));
204            self.output_size = size;
205        }
206    }
207
208    fn ensure_dims_buffer(&mut self) {
209        if self.dims_buffer.is_none() {
210            self.dims_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
211                label: Some("persistent_dims"),
212                size: 16,
213                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
214                mapped_at_creation: false,
215            }));
216        }
217    }
218
219    fn ensure_staging_buffer(&mut self, size: u64) {
220        if self.staging_size < size {
221            self.staging_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
222                label: Some("persistent_staging"),
223                size,
224                usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
225                mapped_at_creation: false,
226            }));
227            self.staging_size = size;
228        }
229    }
230
231    /// PMAT-323: Zero-alloc matmul using persistent I/O buffers.
232    /// Only creates bind group per call (required by WGPU — bind groups reference
233    /// specific buffer instances). Everything else is reused.
234    pub fn matmul_cached(
235        &mut self,
236        weight_name: &str,
237        input: &[f32],
238        output: &mut [f32],
239        m: usize,
240    ) -> Result<(), String> {
241        // Extract weight dims first to avoid borrow conflict
242        let (k, n) = {
243            let entry = self
244                .weight_buffers
245                .get(weight_name)
246                .ok_or_else(|| format!("Weight '{}' not uploaded", weight_name))?;
247            (entry.cols, entry.rows)
248        };
249
250        if input.len() < m * k {
251            return Err(format!("input too small: need {}, have {}", m * k, input.len()));
252        }
253        if output.len() < m * n {
254            return Err(format!("output too small: need {}, have {}", m * n, output.len()));
255        }
256
257        let input_bytes = (m * k * 4) as u64;
258        let output_bytes = (m * n * 4) as u64;
259
260        // Ensure persistent buffers are large enough (may alloc on first call / size increase)
261        self.ensure_input_buffer(input_bytes);
262        self.ensure_output_buffer(output_bytes);
263        self.ensure_dims_buffer();
264        self.ensure_staging_buffer(output_bytes);
265
266        // Write input + dims to persistent buffers (just memcpy, no alloc)
267        let input_buf = self.input_buffer.as_ref().expect("ensure_input_buffer was just called");
268        self.queue.write_buffer(input_buf, 0, bytemuck::cast_slice(&input[..m * k]));
269
270        // PMAT-346: GEMV shader expects Params { n (output dim), k, _, _ }
271        // but Dimensions struct has { m, k, n, _ }. When m=1, params.n reads m=1
272        // instead of the actual output dimension. Write different layout for GEMV.
273        let dims = if m == 1 {
274            Dimensions { m: n as u32, k: k as u32, n: 0, alpha_bits: 1.0_f32.to_bits() }
275        } else {
276            Dimensions { m: m as u32, k: k as u32, n: n as u32, alpha_bits: 1.0_f32.to_bits() }
277        };
278        let dims_buf = self.dims_buffer.as_ref().expect("ensure_dims_buffer was just called");
279        self.queue.write_buffer(dims_buf, 0, bytemuck::bytes_of(&dims));
280
281        // Bind group (per-call — WGPU requires new bind group when buffer references change)
282        let output_buf = self.output_buffer.as_ref().expect("ensure_output_buffer was just called");
283        let weight_buf = &self
284            .weight_buffers
285            .get(weight_name)
286            .ok_or_else(|| {
287                format!("weight '{}' not loaded — call load_weight() first", weight_name)
288            })?
289            .buffer;
290        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
291            label: None,
292            layout: &self.bind_group_layout,
293            entries: &[
294                wgpu::BindGroupEntry { binding: 0, resource: input_buf.as_entire_binding() },
295                wgpu::BindGroupEntry { binding: 1, resource: weight_buf.as_entire_binding() },
296                wgpu::BindGroupEntry { binding: 2, resource: output_buf.as_entire_binding() },
297                wgpu::BindGroupEntry { binding: 3, resource: dims_buf.as_entire_binding() },
298            ],
299        });
300
301        let staging = self.staging_buffer.as_ref().expect("ensure_staging_buffer was just called");
302
303        // Encode + dispatch
304        let mut encoder =
305            self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
306
307        {
308            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
309                label: Some("matmul"),
310                timestamp_writes: None,
311            });
312            if m == 1 {
313                // PMAT-326: Use GEMV shader for M=1 — cooperative K-reduction
314                // Each workgroup handles 1 output row with 256 threads reducing K
315                // Dispatch N workgroups (one per output element)
316                pass.set_pipeline(&self.gemv_pipeline);
317                pass.set_bind_group(0, &bind_group, &[]);
318                pass.dispatch_workgroups(n as u32, 1, 1);
319            } else if m >= 4 {
320                // CUTLASS-style tiled GEMM for M>=4 (training batch, prefill)
321                // 64×64 tiles, 4×4 thread micro-tiles, double-buffered shared memory
322                // ~10-30x faster than naive 16×16 for large M
323                pass.set_pipeline(&self.tiled_pipeline);
324                pass.set_bind_group(0, &bind_group, &[]);
325                pass.dispatch_workgroups((n as u32).div_ceil(64), (m as u32).div_ceil(64), 1);
326            } else {
327                // Naive 16×16 tiled GEMM for small M (2-3 rows)
328                pass.set_pipeline(&self.pipeline);
329                pass.set_bind_group(0, &bind_group, &[]);
330                pass.dispatch_workgroups((m as u32).div_ceil(16), (n as u32).div_ceil(16), 1);
331            }
332        }
333
334        encoder.copy_buffer_to_buffer(output_buf, 0, staging, 0, output_bytes);
335        self.queue.submit(Some(encoder.finish()));
336
337        // Readback
338        let slice = staging.slice(..output_bytes);
339        let (tx, rx) = std::sync::mpsc::channel();
340        slice.map_async(wgpu::MapMode::Read, move |r| {
341            tx.send(r).ok();
342        });
343        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
344        rx.recv().map_err(|e| format!("recv: {e}"))?.map_err(|e| format!("map: {e:?}"))?;
345
346        {
347            let data = slice.get_mapped_range();
348            output[..m * n].copy_from_slice(bytemuck::cast_slice(&data));
349        }
350        staging.unmap();
351
352        Ok(())
353    }
354}
355
356fn bgl_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
357    wgpu::BindGroupLayoutEntry {
358        binding,
359        visibility: wgpu::ShaderStages::COMPUTE,
360        ty: wgpu::BindingType::Buffer {
361            ty: wgpu::BufferBindingType::Storage { read_only },
362            has_dynamic_offset: false,
363            min_binding_size: None,
364        },
365        count: None,
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode the `index`-th little-endian `u32` word of a `Dimensions` uniform,
    /// i.e. what the shader sees at byte offset `index * 4`.
    fn word(dims: &Dimensions, index: usize) -> u32 {
        let bytes = bytemuck::bytes_of(dims);
        let at = index * 4;
        u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
    }

    #[test]
    fn test_dimensions_layout() {
        let dims = Dimensions { m: 1, k: 1536, n: 1536, alpha_bits: 1.0_f32.to_bits() };
        // 4 × u32 — must match the 16-byte uniform buffer bound at binding 3.
        assert_eq!(bytemuck::bytes_of(&dims).len(), 16);
        assert_eq!(std::mem::size_of::<Dimensions>(), 16);
        // Verify field order matches the shader uniform layout.
        assert_eq!(word(&dims, 0), 1);
        assert_eq!(word(&dims, 1), 1536);
        assert_eq!(word(&dims, 2), 1536);
        assert_eq!(word(&dims, 3), 1.0_f32.to_bits());
    }

    #[test]
    fn test_gemv_params_layout() {
        // PMAT-346: When m=1, first field must be n (output dim), not m.
        // Mirrors the m == 1 construction used by `matmul_cached`.
        let m = 1usize;
        let k = 1536usize;
        let n = 256usize;
        let dims = if m == 1 {
            Dimensions { m: n as u32, k: k as u32, n: 0, alpha_bits: 1.0_f32.to_bits() }
        } else {
            Dimensions { m: m as u32, k: k as u32, n: n as u32, alpha_bits: 1.0_f32.to_bits() }
        };
        // GEMV shader reads params.n (offset 0) as output dimension and params.k at offset 4.
        assert_eq!(word(&dims, 0), 256, "GEMV params.n must be output dimension, not m");
        assert_eq!(word(&dims, 1), 1536, "GEMV params.k must be the reduction dimension");
    }

    #[test]
    fn test_matmul_params_layout() {
        let dims = Dimensions { m: 4, k: 1536, n: 1536, alpha_bits: 1.0_f32.to_bits() };
        // Matmul shaders read dims.M, dims.K, dims.N, then alpha (tiled GEMM only).
        assert_eq!(word(&dims, 0), 4); // M
        assert_eq!(word(&dims, 1), 1536); // K
        assert_eq!(word(&dims, 2), 1536); // N
        assert_eq!(word(&dims, 3), 1.0_f32.to_bits()); // alpha
    }
}