// trueno/backends/gpu/device/linalg/vec_ops.rs
use super::super::GpuDevice;

#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
use crate::backends::gpu::runtime;
use crate::backends::gpu::shaders;

8impl GpuDevice {
9 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
11 pub fn vec_add(&self, a: &[f32], b: &[f32], result: &mut [f32]) -> Result<(), String> {
12 runtime::block_on(async { self.vec_add_async(a, b, result).await })
13 }
14
15 pub async fn vec_add_async(
17 &self,
18 a: &[f32],
19 b: &[f32],
20 result: &mut [f32],
21 ) -> Result<(), String> {
22 let len = a.len();
23
24 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
26 label: Some("Vec Add Shader"),
27 source: wgpu::ShaderSource::Wgsl(shaders::VEC_ADD_SHADER.into()),
28 });
29
30 let a_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
32 label: Some("Vector A"),
33 size: std::mem::size_of_val(a) as u64,
34 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
35 mapped_at_creation: false,
36 });
37
38 let b_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
39 label: Some("Vector B"),
40 size: std::mem::size_of_val(b) as u64,
41 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
42 mapped_at_creation: false,
43 });
44
45 let c_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
46 label: Some("Vector C"),
47 size: std::mem::size_of_val(result) as u64,
48 usage: wgpu::BufferUsages::STORAGE
49 | wgpu::BufferUsages::COPY_SRC
50 | wgpu::BufferUsages::COPY_DST,
51 mapped_at_creation: false,
52 });
53
54 self.queue.write_buffer(&a_buffer, 0, bytemuck::cast_slice(a));
56 self.queue.write_buffer(&b_buffer, 0, bytemuck::cast_slice(b));
57
58 let bind_group_layout =
60 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
61 label: Some("Vec Add Bind Group Layout"),
62 entries: &[
63 wgpu::BindGroupLayoutEntry {
64 binding: 0,
65 visibility: wgpu::ShaderStages::COMPUTE,
66 ty: wgpu::BindingType::Buffer {
67 ty: wgpu::BufferBindingType::Storage { read_only: true },
68 has_dynamic_offset: false,
69 min_binding_size: None,
70 },
71 count: None,
72 },
73 wgpu::BindGroupLayoutEntry {
74 binding: 1,
75 visibility: wgpu::ShaderStages::COMPUTE,
76 ty: wgpu::BindingType::Buffer {
77 ty: wgpu::BufferBindingType::Storage { read_only: true },
78 has_dynamic_offset: false,
79 min_binding_size: None,
80 },
81 count: None,
82 },
83 wgpu::BindGroupLayoutEntry {
84 binding: 2,
85 visibility: wgpu::ShaderStages::COMPUTE,
86 ty: wgpu::BindingType::Buffer {
87 ty: wgpu::BufferBindingType::Storage { read_only: false },
88 has_dynamic_offset: false,
89 min_binding_size: None,
90 },
91 count: None,
92 },
93 ],
94 });
95
96 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
98 label: Some("Vec Add Bind Group"),
99 layout: &bind_group_layout,
100 entries: &[
101 wgpu::BindGroupEntry { binding: 0, resource: a_buffer.as_entire_binding() },
102 wgpu::BindGroupEntry { binding: 1, resource: b_buffer.as_entire_binding() },
103 wgpu::BindGroupEntry { binding: 2, resource: c_buffer.as_entire_binding() },
104 ],
105 });
106
107 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
109 label: Some("Vec Add Pipeline Layout"),
110 bind_group_layouts: &[&bind_group_layout],
111 push_constant_ranges: &[],
112 });
113
114 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
115 label: Some("Vec Add Pipeline"),
116 layout: Some(&pipeline_layout),
117 module: &shader,
118 entry_point: Some("main"),
119 compilation_options: Default::default(),
120 cache: None,
121 });
122
123 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
125 label: Some("Staging Buffer"),
126 size: std::mem::size_of_val(result) as u64,
127 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
128 mapped_at_creation: false,
129 });
130
131 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
133 label: Some("Vec Add Encoder"),
134 });
135
136 {
137 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
138 label: Some("Vec Add Pass"),
139 timestamp_writes: None,
140 });
141 compute_pass.set_pipeline(&pipeline);
142 compute_pass.set_bind_group(0, &bind_group, &[]);
143
144 let workgroup_size = 256;
146 let num_workgroups = (len as u32).div_ceil(workgroup_size);
147
148 compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
149 }
150
151 encoder.copy_buffer_to_buffer(
153 &c_buffer,
154 0,
155 &staging_buffer,
156 0,
157 std::mem::size_of_val(result) as u64,
158 );
159
160 self.queue.submit(Some(encoder.finish()));
162
163 let buffer_slice = staging_buffer.slice(..);
165 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
166 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
167 sender.send(result).ok();
168 });
169
170 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
172
173 receiver
174 .receive()
175 .await
176 .ok_or("Failed to receive mapping result")?
177 .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
178
179 {
180 let data = buffer_slice.get_mapped_range();
181 result.copy_from_slice(bytemuck::cast_slice(&data));
182 }
183
184 staging_buffer.unmap();
185
186 Ok(())
187 }
188}