Skip to main content

trueno/backends/gpu/device/linalg/
vec_ops.rs

1//! GPU vector addition operations
2
3use super::super::GpuDevice;
4#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
5use crate::backends::gpu::runtime;
6use crate::backends::gpu::shaders;
7
8impl GpuDevice {
9    /// Execute vector addition on GPU: c = a + b (sync, native only)
10    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
11    pub fn vec_add(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<(), String> {
12        runtime::block_on(async { self.vec_add_async(a, b, result).await })
13    }
14
15    /// Execute vector addition on GPU: c = a + b (async, works on all platforms)
16    pub async fn vec_add_async(
17        &self,
18        a: &[f32],
19        b: &[f32],
20        result: &mut [f32],
21    ) -> Result<(), String> {
22        let len = a.len();
23
24        // Create shader module
25        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
26            label: Some("Vec Add Shader"),
27            source: wgpu::ShaderSource::Wgsl(shaders::VEC_ADD_SHADER.into()),
28        });
29
30        // Create buffers
31        let a_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
32            label: Some("Vector A"),
33            size: std::mem::size_of_val(a) as u64,
34            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
35            mapped_at_creation: false,
36        });
37
38        let b_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
39            label: Some("Vector B"),
40            size: std::mem::size_of_val(b) as u64,
41            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
42            mapped_at_creation: false,
43        });
44
45        let c_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
46            label: Some("Vector C"),
47            size: std::mem::size_of_val(result) as u64,
48            usage: wgpu::BufferUsages::STORAGE
49                | wgpu::BufferUsages::COPY_SRC
50                | wgpu::BufferUsages::COPY_DST,
51            mapped_at_creation: false,
52        });
53
54        // Write data to buffers
55        self.queue.write_buffer(&a_buffer, 0, bytemuck::cast_slice(a));
56        self.queue.write_buffer(&b_buffer, 0, bytemuck::cast_slice(b));
57
58        // Create bind group layout
59        let bind_group_layout =
60            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
61                label: Some("Vec Add Bind Group Layout"),
62                entries: &[
63                    wgpu::BindGroupLayoutEntry {
64                        binding: 0,
65                        visibility: wgpu::ShaderStages::COMPUTE,
66                        ty: wgpu::BindingType::Buffer {
67                            ty: wgpu::BufferBindingType::Storage { read_only: true },
68                            has_dynamic_offset: false,
69                            min_binding_size: None,
70                        },
71                        count: None,
72                    },
73                    wgpu::BindGroupLayoutEntry {
74                        binding: 1,
75                        visibility: wgpu::ShaderStages::COMPUTE,
76                        ty: wgpu::BindingType::Buffer {
77                            ty: wgpu::BufferBindingType::Storage { read_only: true },
78                            has_dynamic_offset: false,
79                            min_binding_size: None,
80                        },
81                        count: None,
82                    },
83                    wgpu::BindGroupLayoutEntry {
84                        binding: 2,
85                        visibility: wgpu::ShaderStages::COMPUTE,
86                        ty: wgpu::BindingType::Buffer {
87                            ty: wgpu::BufferBindingType::Storage { read_only: false },
88                            has_dynamic_offset: false,
89                            min_binding_size: None,
90                        },
91                        count: None,
92                    },
93                ],
94            });
95
96        // Create bind group
97        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
98            label: Some("Vec Add Bind Group"),
99            layout: &bind_group_layout,
100            entries: &[
101                wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
102                wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
103                wgpu::BindGroupEntry { binding: 2, resource: c_buffer.as_entire_binding() },
104            ],
105        });
106
107        // Create pipeline
108        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
109            label: Some("Vec Add Pipeline Layout"),
110            bind_group_layouts: &[&bind_group_layout],
111            push_constant_ranges: &[],
112        });
113
114        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
115            label: Some("Vec Add Pipeline"),
116            layout: Some(&pipeline_layout),
117            module: &shader,
118            entry_point: Some("main"),
119            compilation_options: Default::default(),
120            cache: None,
121        });
122
123        // Create staging buffer for reading results
124        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
125            label: Some("Staging Buffer"),
126            size: std::mem::size_of_val(result) as u64,
127            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
128            mapped_at_creation: false,
129        });
130
131        // Create command encoder
132        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
133            label: Some("Vec Add Encoder"),
134        });
135
136        {
137            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
138                label: Some("Vec Add Pass"),
139                timestamp_writes: None,
140            });
141            compute_pass.set_pipeline(&pipeline);
142            compute_pass.set_bind_group(0, &bind_group, &[]);
143
144            // Dispatch workgroups (256 threads per workgroup)
145            let workgroup_size = 256;
146            let num_workgroups = (len as u32).div_ceil(workgroup_size);
147
148            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
149        }
150
151        // Copy result to staging buffer
152        encoder.copy_buffer_to_buffer(
153            &c_buffer,
154            0,
155            &staging_buffer,
156            0,
157            std::mem::size_of_val(result) as u64,
158        );
159
160        // Submit commands
161        self.queue.submit(Some(encoder.finish()));
162
163        // Read back results
164        let buffer_slice = staging_buffer.slice(..);
165        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
166        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
167            sender.send(result).ok();
168        });
169
170        // Poll device to ensure GPU work completes and callbacks are invoked
171        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
172
173        receiver
174            .receive()
175            .await
176            .ok_or("Failed to receive mapping result")?
177            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
178
179        {
180            let data = buffer_slice.get_mapped_range();
181            result.copy_from_slice(bytemuck::cast_slice(&data));
182        }
183
184        staging_buffer.unmap();
185
186        Ok(())
187    }
188}