Skip to main content

trueno/backends/gpu/device/linalg/
dot.rs

1//! GPU dot product operations
2
3use super::super::GpuDevice;
4#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
5use crate::backends::gpu::runtime;
6use crate::backends::gpu::shaders;
7
8impl GpuDevice {
9    /// Execute dot product on GPU (sync, native only)
10    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
11    pub fn dot(&self, a: &[f32], b: &[f32]) -> Result<f32, String> {
12        runtime::block_on(async { self.dot_async(a, b).await })
13    }
14
15    /// Execute dot product on GPU (async, works on all platforms)
16    pub async fn dot_async(&self, a: &[f32], b: &[f32]) -> Result<f32, String> {
17        let len = a.len();
18        let workgroup_size = 256;
19        let num_workgroups = (len as u32).div_ceil(workgroup_size);
20
21        // Create shader module
22        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
23            label: Some("Dot Product Shader"),
24            source: wgpu::ShaderSource::Wgsl(shaders::DOT_PRODUCT_SHADER.into()),
25        });
26
27        // Create buffers
28        let a_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
29            label: Some("Vector A"),
30            size: std::mem::size_of_val(a) as u64,
31            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
32            mapped_at_creation: false,
33        });
34
35        let b_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
36            label: Some("Vector B"),
37            size: std::mem::size_of_val(b) as u64,
38            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
39            mapped_at_creation: false,
40        });
41
42        // Result buffer for partial sums (one per workgroup)
43        let partial_results = vec![0.0f32; num_workgroups as usize];
44        let result_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
45            label: Some("Partial Results"),
46            size: std::mem::size_of_val(partial_results.as_slice()) as u64,
47            usage: wgpu::BufferUsages::STORAGE
48                | wgpu::BufferUsages::COPY_SRC
49                | wgpu::BufferUsages::COPY_DST,
50            mapped_at_creation: false,
51        });
52
53        // Write data to buffers
54        self.queue.write_buffer(&a_buffer, 0, bytemuck::cast_slice(a));
55        self.queue.write_buffer(&b_buffer, 0, bytemuck::cast_slice(b));
56
57        // Create bind group layout
58        let bind_group_layout =
59            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
60                label: Some("Dot Product Bind Group Layout"),
61                entries: &[
62                    wgpu::BindGroupLayoutEntry {
63                        binding: 0,
64                        visibility: wgpu::ShaderStages::COMPUTE,
65                        ty: wgpu::BindingType::Buffer {
66                            ty: wgpu::BufferBindingType::Storage { read_only: true },
67                            has_dynamic_offset: false,
68                            min_binding_size: None,
69                        },
70                        count: None,
71                    },
72                    wgpu::BindGroupLayoutEntry {
73                        binding: 1,
74                        visibility: wgpu::ShaderStages::COMPUTE,
75                        ty: wgpu::BindingType::Buffer {
76                            ty: wgpu::BufferBindingType::Storage { read_only: true },
77                            has_dynamic_offset: false,
78                            min_binding_size: None,
79                        },
80                        count: None,
81                    },
82                    wgpu::BindGroupLayoutEntry {
83                        binding: 2,
84                        visibility: wgpu::ShaderStages::COMPUTE,
85                        ty: wgpu::BindingType::Buffer {
86                            ty: wgpu::BufferBindingType::Storage { read_only: false },
87                            has_dynamic_offset: false,
88                            min_binding_size: None,
89                        },
90                        count: None,
91                    },
92                ],
93            });
94
95        // Create bind group
96        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
97            label: Some("Dot Product Bind Group"),
98            layout: &bind_group_layout,
99            entries: &[
100                wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
101                wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
102                wgpu::BindGroupEntry { binding: 2, resource: result_buffer.as_entire_binding() },
103            ],
104        });
105
106        // Create pipeline
107        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
108            label: Some("Dot Product Pipeline Layout"),
109            bind_group_layouts: &[&bind_group_layout],
110            push_constant_ranges: &[],
111        });
112
113        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
114            label: Some("Dot Product Pipeline"),
115            layout: Some(&pipeline_layout),
116            module: &shader,
117            entry_point: Some("main"),
118            compilation_options: Default::default(),
119            cache: None,
120        });
121
122        // Create staging buffer for reading results
123        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
124            label: Some("Staging Buffer"),
125            size: std::mem::size_of_val(partial_results.as_slice()) as u64,
126            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
127            mapped_at_creation: false,
128        });
129
130        // Create command encoder
131        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
132            label: Some("Dot Product Encoder"),
133        });
134
135        {
136            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
137                label: Some("Dot Product Pass"),
138                timestamp_writes: None,
139            });
140            compute_pass.set_pipeline(&pipeline);
141            compute_pass.set_bind_group(0, &bind_group, &[]);
142
143            // Dispatch workgroups
144            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
145        }
146
147        // Copy result to staging buffer
148        encoder.copy_buffer_to_buffer(
149            &result_buffer,
150            0,
151            &staging_buffer,
152            0,
153            std::mem::size_of_val(partial_results.as_slice()) as u64,
154        );
155
156        // Submit commands
157        self.queue.submit(Some(encoder.finish()));
158
159        // Read back results
160        let buffer_slice = staging_buffer.slice(..);
161        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
162        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
163            sender.send(result).ok();
164        });
165
166        // Poll device to ensure GPU work completes and callbacks are invoked
167        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
168
169        receiver
170            .receive()
171            .await
172            .ok_or("Failed to receive mapping result")?
173            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
174
175        let final_result = {
176            let data = buffer_slice.get_mapped_range();
177            let partial_sums: &[f32] = bytemuck::cast_slice(&data);
178
179            // Sum the partial results from each workgroup on CPU
180            partial_sums.iter().sum()
181        };
182
183        staging_buffer.unmap();
184
185        Ok(final_result)
186    }
187}