// trueno/backends/gpu/device/linalg/dot.rs
use super::super::GpuDevice;

#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
use crate::backends::gpu::runtime;
use crate::backends::gpu::shaders;

8impl GpuDevice {
9 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
11 pub fn dot(&self, a: &[f32], b: &[f32]) -> Result<f32, String> {
12 runtime::block_on(async { self.dot_async(a, b).await })
13 }
14
15 pub async fn dot_async(&self, a: &[f32], b: &[f32]) -> Result<f32, String> {
17 let len = a.len();
18 let workgroup_size = 256;
19 let num_workgroups = (len as u32).div_ceil(workgroup_size);
20
21 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
23 label: Some("Dot Product Shader"),
24 source: wgpu::ShaderSource::Wgsl(shaders::DOT_PRODUCT_SHADER.into()),
25 });
26
27 let a_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
29 label: Some("Vector A"),
30 size: std::mem::size_of_val(a) as u64,
31 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
32 mapped_at_creation: false,
33 });
34
35 let b_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
36 label: Some("Vector B"),
37 size: std::mem::size_of_val(b) as u64,
38 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
39 mapped_at_creation: false,
40 });
41
42 let partial_results = vec![0.0f32; num_workgroups as usize];
44 let result_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
45 label: Some("Partial Results"),
46 size: std::mem::size_of_val(partial_results.as_slice()) as u64,
47 usage: wgpu::BufferUsages::STORAGE
48 | wgpu::BufferUsages::COPY_SRC
49 | wgpu::BufferUsages::COPY_DST,
50 mapped_at_creation: false,
51 });
52
53 self.queue.write_buffer(&a_buffer, 0, bytemuck::cast_slice(a));
55 self.queue.write_buffer(&b_buffer, 0, bytemuck::cast_slice(b));
56
57 let bind_group_layout =
59 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
60 label: Some("Dot Product Bind Group Layout"),
61 entries: &[
62 wgpu::BindGroupLayoutEntry {
63 binding: 0,
64 visibility: wgpu::ShaderStages::COMPUTE,
65 ty: wgpu::BindingType::Buffer {
66 ty: wgpu::BufferBindingType::Storage { read_only: true },
67 has_dynamic_offset: false,
68 min_binding_size: None,
69 },
70 count: None,
71 },
72 wgpu::BindGroupLayoutEntry {
73 binding: 1,
74 visibility: wgpu::ShaderStages::COMPUTE,
75 ty: wgpu::BindingType::Buffer {
76 ty: wgpu::BufferBindingType::Storage { read_only: true },
77 has_dynamic_offset: false,
78 min_binding_size: None,
79 },
80 count: None,
81 },
82 wgpu::BindGroupLayoutEntry {
83 binding: 2,
84 visibility: wgpu::ShaderStages::COMPUTE,
85 ty: wgpu::BindingType::Buffer {
86 ty: wgpu::BufferBindingType::Storage { read_only: false },
87 has_dynamic_offset: false,
88 min_binding_size: None,
89 },
90 count: None,
91 },
92 ],
93 });
94
95 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
97 label: Some("Dot Product Bind Group"),
98 layout: &bind_group_layout,
99 entries: &[
100 wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
101 wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
102 wgpu::BindGroupEntry { binding: 2, resource: result_buffer.as_entire_binding() },
103 ],
104 });
105
106 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
108 label: Some("Dot Product Pipeline Layout"),
109 bind_group_layouts: &[&bind_group_layout],
110 push_constant_ranges: &[],
111 });
112
113 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
114 label: Some("Dot Product Pipeline"),
115 layout: Some(&pipeline_layout),
116 module: &shader,
117 entry_point: Some("main"),
118 compilation_options: Default::default(),
119 cache: None,
120 });
121
122 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
124 label: Some("Staging Buffer"),
125 size: std::mem::size_of_val(partial_results.as_slice()) as u64,
126 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
127 mapped_at_creation: false,
128 });
129
130 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
132 label: Some("Dot Product Encoder"),
133 });
134
135 {
136 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
137 label: Some("Dot Product Pass"),
138 timestamp_writes: None,
139 });
140 compute_pass.set_pipeline(&pipeline);
141 compute_pass.set_bind_group(0, &bind_group, &[]);
142
143 compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
145 }
146
147 encoder.copy_buffer_to_buffer(
149 &result_buffer,
150 0,
151 &staging_buffer,
152 0,
153 std::mem::size_of_val(partial_results.as_slice()) as u64,
154 );
155
156 self.queue.submit(Some(encoder.finish()));
158
159 let buffer_slice = staging_buffer.slice(..);
161 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
162 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
163 sender.send(result).ok();
164 });
165
166 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
168
169 receiver
170 .receive()
171 .await
172 .ok_or("Failed to receive mapping result")?
173 .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
174
175 let final_result = {
176 let data = buffer_slice.get_mapped_range();
177 let partial_sums: &[f32] = bytemuck::cast_slice(&data);
178
179 partial_sums.iter().sum()
181 };
182
183 staging_buffer.unmap();
184
185 Ok(final_result)
186 }
187}