// trueno/backends/gpu/device/linalg/matmul.rs

//! GPU matrix multiplication operations

use super::super::GpuDevice;
#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
use crate::backends::gpu::runtime;
use crate::backends::gpu::shaders;
8impl GpuDevice {
9    /// Execute matrix multiplication on GPU (sync, native only)
10    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
11    pub fn matmul(
12        &self,
13        a: &[f32],
14        b: &[f32],
15        result: &mut [f32],
16        m: usize,
17        k: usize,
18        n: usize,
19    ) -> Result<(), String> {
20        runtime::block_on(async { self.matmul_async(a, b, result, m, k, n).await })
21    }
22
23    /// Execute matrix multiplication on GPU (async, works on all platforms)
24    pub async fn matmul_async(
25        &self,
26        a: &[f32],
27        b: &[f32],
28        result: &mut [f32],
29        m: usize,
30        k: usize,
31        n: usize,
32    ) -> Result<(), String> {
33        contract_pre_matmul!();
34        // Guard: if B exceeds max buffer binding, chunk along N dimension.
35        // Each chunk computes result[:, n_start..n_end] = A @ B[:, n_start..n_end]
36        // This handles lm_head (152064 × 3584 × 4 = 2.18 GB > 2 GB limit).
37        let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
38        let b_bytes = (b.len() * 4) as u64;
39        if b_bytes > max_binding {
40            // Chunk B along N: each chunk has at most max_n_chunk columns
41            let max_elements = max_binding as usize / 4; // max f32 elements per buffer
42            let max_n_chunk = max_elements / k; // max columns per chunk
43            let max_n_chunk = max_n_chunk.max(1);
44
45            let mut n_start = 0;
46            while n_start < n {
47                let n_end = (n_start + max_n_chunk).min(n);
48                let chunk_n = n_end - n_start;
49
50                // Extract B chunk: B[:, n_start..n_end] from row-major B[K, N]
51                let mut b_chunk = vec![0.0f32; k * chunk_n];
52                for row in 0..k {
53                    for col in 0..chunk_n {
54                        b_chunk[row * chunk_n + col] = b[row * n + n_start + col];
55                    }
56                }
57
58                // Compute C_chunk = A @ B_chunk
59                let mut c_chunk = vec![0.0f32; m * chunk_n];
60                // Use recursive call — chunk fits in buffer now
61                Box::pin(self.matmul_async(a, &b_chunk, &mut c_chunk, m, k, chunk_n)).await?;
62
63                // Copy chunk into result: result[:, n_start..n_end]
64                for row in 0..m {
65                    for col in 0..chunk_n {
66                        result[row * n + n_start + col] = c_chunk[row * chunk_n + col];
67                    }
68                }
69
70                n_start = n_end;
71            }
72            return Ok(());
73        }
74
75        // Create shader module
76        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
77            label: Some("Matmul Shader"),
78            source: wgpu::ShaderSource::Wgsl(shaders::MATMUL_SHADER.into()),
79        });
80
81        // Create buffers
82        let a_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
83            label: Some("Matrix A"),
84            size: std::mem::size_of_val(a) as u64,
85            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
86            mapped_at_creation: false,
87        });
88
89        let b_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
90            label: Some("Matrix B"),
91            size: std::mem::size_of_val(b) as u64,
92            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
93            mapped_at_creation: false,
94        });
95
96        let c_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
97            label: Some("Matrix C"),
98            size: std::mem::size_of_val(result) as u64,
99            usage: wgpu::BufferUsages::STORAGE
100                | wgpu::BufferUsages::COPY_SRC
101                | wgpu::BufferUsages::COPY_DST,
102            mapped_at_creation: false,
103        });
104
105        // Dimensions uniform buffer
106        #[repr(C)]
107        #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
108        struct Dimensions {
109            m: u32,
110            k: u32,
111            n: u32,
112            _padding: u32,
113        }
114
115        let dims = Dimensions { m: m as u32, k: k as u32, n: n as u32, _padding: 0 };
116
117        let dims_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
118            label: Some("Dimensions"),
119            size: std::mem::size_of::<Dimensions>() as u64,
120            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
121            mapped_at_creation: false,
122        });
123
124        // Write data to buffers
125        self.queue.write_buffer(&a_buffer, 0, bytemuck::cast_slice(a));
126        self.queue.write_buffer(&b_buffer, 0, bytemuck::cast_slice(b));
127        self.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
128
129        // Create bind group layout
130        let bind_group_layout =
131            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
132                label: Some("Matmul Bind Group Layout"),
133                entries: &[
134                    wgpu::BindGroupLayoutEntry {
135                        binding: 0,
136                        visibility: wgpu::ShaderStages::COMPUTE,
137                        ty: wgpu::BindingType::Buffer {
138                            ty: wgpu::BufferBindingType::Storage { read_only: true },
139                            has_dynamic_offset: false,
140                            min_binding_size: None,
141                        },
142                        count: None,
143                    },
144                    wgpu::BindGroupLayoutEntry {
145                        binding: 1,
146                        visibility: wgpu::ShaderStages::COMPUTE,
147                        ty: wgpu::BindingType::Buffer {
148                            ty: wgpu::BufferBindingType::Storage { read_only: true },
149                            has_dynamic_offset: false,
150                            min_binding_size: None,
151                        },
152                        count: None,
153                    },
154                    wgpu::BindGroupLayoutEntry {
155                        binding: 2,
156                        visibility: wgpu::ShaderStages::COMPUTE,
157                        ty: wgpu::BindingType::Buffer {
158                            ty: wgpu::BufferBindingType::Storage { read_only: false },
159                            has_dynamic_offset: false,
160                            min_binding_size: None,
161                        },
162                        count: None,
163                    },
164                    wgpu::BindGroupLayoutEntry {
165                        binding: 3,
166                        visibility: wgpu::ShaderStages::COMPUTE,
167                        ty: wgpu::BindingType::Buffer {
168                            ty: wgpu::BufferBindingType::Uniform,
169                            has_dynamic_offset: false,
170                            min_binding_size: None,
171                        },
172                        count: None,
173                    },
174                ],
175            });
176
177        // Create bind group
178        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
179            label: Some("Matmul Bind Group"),
180            layout: &bind_group_layout,
181            entries: &[
182                wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
183                wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
184                wgpu::BindGroupEntry { binding: 2, resource: c_buffer.as_entire_binding() },
185                wgpu::BindGroupEntry { binding: 3, resource: dims_buffer.as_entire_binding() },
186            ],
187        });
188
189        // Create pipeline
190        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
191            label: Some("Matmul Pipeline Layout"),
192            bind_group_layouts: &[&bind_group_layout],
193            push_constant_ranges: &[],
194        });
195
196        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
197            label: Some("Matmul Pipeline"),
198            layout: Some(&pipeline_layout),
199            module: &shader,
200            entry_point: Some("main"),
201            compilation_options: Default::default(),
202            cache: None,
203        });
204
205        // Create staging buffer for reading results
206        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
207            label: Some("Staging Buffer"),
208            size: std::mem::size_of_val(result) as u64,
209            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
210            mapped_at_creation: false,
211        });
212
213        // Create command encoder
214        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
215            label: Some("Matmul Encoder"),
216        });
217
218        {
219            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
220                label: Some("Matmul Pass"),
221                timestamp_writes: None,
222            });
223            compute_pass.set_pipeline(&pipeline);
224            compute_pass.set_bind_group(0, &bind_group, &[]);
225
226            // Dispatch workgroups (16x16 threads per workgroup)
227            let workgroup_size_x = 16;
228            let workgroup_size_y = 16;
229            let num_workgroups_x = (m as u32).div_ceil(workgroup_size_x);
230            let num_workgroups_y = (n as u32).div_ceil(workgroup_size_y);
231
232            compute_pass.dispatch_workgroups(num_workgroups_x, num_workgroups_y, 1);
233        }
234
235        // Copy result to staging buffer
236        encoder.copy_buffer_to_buffer(
237            &c_buffer,
238            0,
239            &staging_buffer,
240            0,
241            std::mem::size_of_val(result) as u64,
242        );
243
244        // Submit commands
245        self.queue.submit(Some(encoder.finish()));
246
247        // Read back results
248        let buffer_slice = staging_buffer.slice(..);
249        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
250        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
251            sender.send(result).ok();
252        });
253
254        // Poll device to ensure GPU work completes and callbacks are invoked
255        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
256
257        receiver
258            .receive()
259            .await
260            .ok_or("Failed to receive mapping result")?
261            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
262
263        {
264            let data = buffer_slice.get_mapped_range();
265            result.copy_from_slice(bytemuck::cast_slice(&data));
266        }
267
268        staging_buffer.unmap();
269
270        contract_post_matmul!(result);
271        Ok(())
272    }
273}