// trueno/backends/gpu/device/linalg/matmul.rs
use super::super::GpuDevice;

#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
use crate::backends::gpu::runtime;
use crate::backends::gpu::shaders;
8impl GpuDevice {
9 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
11 pub fn matmul(
12 &self,
13 a: &[f32],
14 b: &[f32],
15 result: &mut [f32],
16 m: usize,
17 k: usize,
18 n: usize,
19 ) -> Result<(), String> {
20 runtime::block_on(async { self.matmul_async(a, b, result, m, k, n).await })
21 }
22
23 pub async fn matmul_async(
25 &self,
26 a: &[f32],
27 b: &[f32],
28 result: &mut [f32],
29 m: usize,
30 k: usize,
31 n: usize,
32 ) -> Result<(), String> {
33 contract_pre_matmul!();
34 let max_binding = self.device.limits().max_storage_buffer_binding_size as u64;
38 let b_bytes = (b.len() * 4) as u64;
39 if b_bytes > max_binding {
40 let max_elements = max_binding as usize / 4; let max_n_chunk = max_elements / k; let max_n_chunk = max_n_chunk.max(1);
44
45 let mut n_start = 0;
46 while n_start < n {
47 let n_end = (n_start + max_n_chunk).min(n);
48 let chunk_n = n_end - n_start;
49
50 let mut b_chunk = vec![0.0f32; k * chunk_n];
52 for row in 0..k {
53 for col in 0..chunk_n {
54 b_chunk[row * chunk_n + col] = b[row * n + n_start + col];
55 }
56 }
57
58 let mut c_chunk = vec![0.0f32; m * chunk_n];
60 Box::pin(self.matmul_async(a, &b_chunk, &mut c_chunk, m, k, chunk_n)).await?;
62
63 for row in 0..m {
65 for col in 0..chunk_n {
66 result[row * n + n_start + col] = c_chunk[row * chunk_n + col];
67 }
68 }
69
70 n_start = n_end;
71 }
72 return Ok(());
73 }
74
75 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
77 label: Some("Matmul Shader"),
78 source: wgpu::ShaderSource::Wgsl(shaders::MATMUL_SHADER.into()),
79 });
80
81 let a_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
83 label: Some("Matrix A"),
84 size: std::mem::size_of_val(a) as u64,
85 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
86 mapped_at_creation: false,
87 });
88
89 let b_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
90 label: Some("Matrix B"),
91 size: std::mem::size_of_val(b) as u64,
92 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
93 mapped_at_creation: false,
94 });
95
96 let c_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
97 label: Some("Matrix C"),
98 size: std::mem::size_of_val(result) as u64,
99 usage: wgpu::BufferUsages::STORAGE
100 | wgpu::BufferUsages::COPY_SRC
101 | wgpu::BufferUsages::COPY_DST,
102 mapped_at_creation: false,
103 });
104
105 #[repr(C)]
107 #[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
108 struct Dimensions {
109 m: u32,
110 k: u32,
111 n: u32,
112 _padding: u32,
113 }
114
115 let dims = Dimensions { m: m as u32, k: k as u32, n: n as u32, _padding: 0 };
116
117 let dims_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
118 label: Some("Dimensions"),
119 size: std::mem::size_of::<Dimensions>() as u64,
120 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
121 mapped_at_creation: false,
122 });
123
124 self.queue.write_buffer(&a_buffer, 0, bytemuck::cast_slice(a));
126 self.queue.write_buffer(&b_buffer, 0, bytemuck::cast_slice(b));
127 self.queue.write_buffer(&dims_buffer, 0, bytemuck::bytes_of(&dims));
128
129 let bind_group_layout =
131 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
132 label: Some("Matmul Bind Group Layout"),
133 entries: &[
134 wgpu::BindGroupLayoutEntry {
135 binding: 0,
136 visibility: wgpu::ShaderStages::COMPUTE,
137 ty: wgpu::BindingType::Buffer {
138 ty: wgpu::BufferBindingType::Storage { read_only: true },
139 has_dynamic_offset: false,
140 min_binding_size: None,
141 },
142 count: None,
143 },
144 wgpu::BindGroupLayoutEntry {
145 binding: 1,
146 visibility: wgpu::ShaderStages::COMPUTE,
147 ty: wgpu::BindingType::Buffer {
148 ty: wgpu::BufferBindingType::Storage { read_only: true },
149 has_dynamic_offset: false,
150 min_binding_size: None,
151 },
152 count: None,
153 },
154 wgpu::BindGroupLayoutEntry {
155 binding: 2,
156 visibility: wgpu::ShaderStages::COMPUTE,
157 ty: wgpu::BindingType::Buffer {
158 ty: wgpu::BufferBindingType::Storage { read_only: false },
159 has_dynamic_offset: false,
160 min_binding_size: None,
161 },
162 count: None,
163 },
164 wgpu::BindGroupLayoutEntry {
165 binding: 3,
166 visibility: wgpu::ShaderStages::COMPUTE,
167 ty: wgpu::BindingType::Buffer {
168 ty: wgpu::BufferBindingType::Uniform,
169 has_dynamic_offset: false,
170 min_binding_size: None,
171 },
172 count: None,
173 },
174 ],
175 });
176
177 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
179 label: Some("Matmul Bind Group"),
180 layout: &bind_group_layout,
181 entries: &[
182 wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
183 wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
184 wgpu::BindGroupEntry { binding: 2, resource: c_buffer.as_entire_binding() },
185 wgpu::BindGroupEntry { binding: 3, resource: dims_buffer.as_entire_binding() },
186 ],
187 });
188
189 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
191 label: Some("Matmul Pipeline Layout"),
192 bind_group_layouts: &[&bind_group_layout],
193 push_constant_ranges: &[],
194 });
195
196 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
197 label: Some("Matmul Pipeline"),
198 layout: Some(&pipeline_layout),
199 module: &shader,
200 entry_point: Some("main"),
201 compilation_options: Default::default(),
202 cache: None,
203 });
204
205 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
207 label: Some("Staging Buffer"),
208 size: std::mem::size_of_val(result) as u64,
209 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
210 mapped_at_creation: false,
211 });
212
213 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
215 label: Some("Matmul Encoder"),
216 });
217
218 {
219 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
220 label: Some("Matmul Pass"),
221 timestamp_writes: None,
222 });
223 compute_pass.set_pipeline(&pipeline);
224 compute_pass.set_bind_group(0, &bind_group, &[]);
225
226 let workgroup_size_x = 16;
228 let workgroup_size_y = 16;
229 let num_workgroups_x = (m as u32).div_ceil(workgroup_size_x);
230 let num_workgroups_y = (n as u32).div_ceil(workgroup_size_y);
231
232 compute_pass.dispatch_workgroups(num_workgroups_x, num_workgroups_y, 1);
233 }
234
235 encoder.copy_buffer_to_buffer(
237 &c_buffer,
238 0,
239 &staging_buffer,
240 0,
241 std::mem::size_of_val(result) as u64,
242 );
243
244 self.queue.submit(Some(encoder.finish()));
246
247 let buffer_slice = staging_buffer.slice(..);
249 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
250 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
251 sender.send(result).ok();
252 });
253
254 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
256
257 receiver
258 .receive()
259 .await
260 .ok_or("Failed to receive mapping result")?
261 .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
262
263 {
264 let data = buffer_slice.get_mapped_range();
265 result.copy_from_slice(bytemuck::cast_slice(&data));
266 }
267
268 staging_buffer.unmap();
269
270 contract_post_matmul!(result);
271 Ok(())
272 }
273}