// trueno/backends/gpu/device/mod.rs
mod activations;
13mod backward;
14mod eigen;
15pub(crate) mod linalg;
16mod reductions;
17
18#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
19use super::runtime;
20
/// A GPU compute context: a wgpu logical device paired with the queue used
/// to submit work to it. Created via the `new*` constructors in the `impl`
/// below and shared across the compute kernels in this module's submodules.
#[derive(Clone)]
pub struct GpuDevice {
    // Logical device: creates buffers, shader modules, bind groups, pipelines.
    pub device: wgpu::Device,
    // Command queue: submits encoded command buffers and writes buffer data.
    pub queue: wgpu::Queue,
}
27
28impl GpuDevice {
29 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
31 pub fn new() -> Result<Self, String> {
32 runtime::block_on(async { Self::new_async().await })
33 }
34
35 pub async fn new_async() -> Result<Self, String> {
37 let instance = wgpu::Instance::default();
39
40 let adapter = instance
42 .request_adapter(&wgpu::RequestAdapterOptions {
43 power_preference: wgpu::PowerPreference::HighPerformance,
44 compatible_surface: None,
45 force_fallback_adapter: false,
46 })
47 .await
48 .map_err(|e| format!("Failed to find GPU adapter: {}", e))?;
49
50 let mut limits = wgpu::Limits::default();
54 limits.max_buffer_size = adapter.limits().max_buffer_size;
55 limits.max_storage_buffer_binding_size = adapter.limits().max_storage_buffer_binding_size;
56
57 let (device, queue) = adapter
58 .request_device(&wgpu::DeviceDescriptor {
59 label: Some("Trueno GPU Device"),
60 required_features: wgpu::Features::empty(),
61 required_limits: limits,
62 memory_hints: wgpu::MemoryHints::Performance,
63 experimental_features: Default::default(),
64 trace: Default::default(),
65 })
66 .await
67 .map_err(|e| format!("Failed to create device: {}", e))?;
68
69 Ok(Self { device, queue })
70 }
71
72 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
77 pub fn new_with_adapter_index(index: u32) -> Result<Self, String> {
78 runtime::block_on(async { Self::new_with_adapter_index_async(index).await })
79 }
80
81 pub async fn new_with_adapter_index_async(index: u32) -> Result<Self, String> {
86 let instance = wgpu::Instance::default();
87 let adapters = instance.enumerate_adapters(wgpu::Backends::all());
88
89 if adapters.is_empty() {
90 return Err("No GPU adapters found".to_string());
91 }
92
93 let adapter = adapters
94 .into_iter()
95 .nth(index as usize)
96 .ok_or_else(|| format!("GPU adapter index {} out of range", index))?;
97
98 let mut limits = wgpu::Limits::default();
99 limits.max_buffer_size = adapter.limits().max_buffer_size;
100 limits.max_storage_buffer_binding_size = adapter.limits().max_storage_buffer_binding_size;
101
102 let (device, queue) = adapter
103 .request_device(&wgpu::DeviceDescriptor {
104 label: Some(&format!("Trueno GPU Device [{}]", index)),
105 required_features: wgpu::Features::empty(),
106 required_limits: limits,
107 memory_hints: wgpu::MemoryHints::Performance,
108 experimental_features: Default::default(),
109 trace: Default::default(),
110 })
111 .await
112 .map_err(|e| format!("Failed to create device at index {}: {}", index, e))?;
113
114 Ok(Self { device, queue })
115 }
116
117 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
121 pub fn list_adapters() -> Vec<(u32, String, String)> {
122 runtime::block_on(Self::list_adapters_async())
123 }
124
125 pub async fn list_adapters_async() -> Vec<(u32, String, String)> {
127 let instance = wgpu::Instance::default();
128 let adapters = instance.enumerate_adapters(wgpu::Backends::all());
129
130 adapters
131 .iter()
132 .enumerate()
133 .map(|(idx, adapter)| {
134 let info = adapter.get_info();
135 (idx as u32, info.name, format!("{:?}", info.backend))
136 })
137 .collect()
138 }
139
140 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
142 pub fn is_available() -> bool {
143 runtime::block_on(Self::is_available_async())
144 }
145
146 pub async fn is_available_async() -> bool {
148 let instance = wgpu::Instance::default();
149 instance
150 .request_adapter(&wgpu::RequestAdapterOptions {
151 power_preference: wgpu::PowerPreference::HighPerformance,
152 compatible_surface: None,
153 force_fallback_adapter: false,
154 })
155 .await
156 .is_ok()
157 }
158
159 pub(super) async fn execute_element_wise_op(
172 &self,
173 op_name: &str,
174 shader_source: &str,
175 input: &[f32],
176 result: &mut [f32],
177 uniform_data: Option<&[u8]>,
178 ) -> Result<(), String> {
179 let len = input.len();
180
181 let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
183 label: Some(&format!("{} Shader", op_name)),
184 source: wgpu::ShaderSource::Wgsl(shader_source.into()),
185 });
186
187 let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
189 label: Some(&format!("{} Input", op_name)),
190 size: std::mem::size_of_val(input) as u64,
191 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
192 mapped_at_creation: false,
193 });
194
195 let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
197 label: Some(&format!("{} Output", op_name)),
198 size: std::mem::size_of_val(result) as u64,
199 usage: wgpu::BufferUsages::STORAGE
200 | wgpu::BufferUsages::COPY_SRC
201 | wgpu::BufferUsages::COPY_DST,
202 mapped_at_creation: false,
203 });
204
205 self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
207
208 let uniform_buffer = uniform_data.map(|data| {
210 let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
211 label: Some(&format!("{} Uniform", op_name)),
212 size: data.len() as u64,
213 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
214 mapped_at_creation: false,
215 });
216 self.queue.write_buffer(&buffer, 0, data);
217 buffer
218 });
219
220 let mut bind_group_entries = vec![
222 wgpu::BindGroupLayoutEntry {
223 binding: 0,
224 visibility: wgpu::ShaderStages::COMPUTE,
225 ty: wgpu::BindingType::Buffer {
226 ty: wgpu::BufferBindingType::Storage { read_only: true },
227 has_dynamic_offset: false,
228 min_binding_size: None,
229 },
230 count: None,
231 },
232 wgpu::BindGroupLayoutEntry {
233 binding: 1,
234 visibility: wgpu::ShaderStages::COMPUTE,
235 ty: wgpu::BindingType::Buffer {
236 ty: wgpu::BufferBindingType::Storage { read_only: false },
237 has_dynamic_offset: false,
238 min_binding_size: None,
239 },
240 count: None,
241 },
242 ];
243
244 if uniform_buffer.is_some() {
246 bind_group_entries.push(wgpu::BindGroupLayoutEntry {
247 binding: 2,
248 visibility: wgpu::ShaderStages::COMPUTE,
249 ty: wgpu::BindingType::Buffer {
250 ty: wgpu::BufferBindingType::Uniform,
251 has_dynamic_offset: false,
252 min_binding_size: None,
253 },
254 count: None,
255 });
256 }
257
258 let bind_group_layout =
260 self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
261 label: Some(&format!("{} Bind Group Layout", op_name)),
262 entries: &bind_group_entries,
263 });
264
265 let mut bind_entries = vec![
267 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
268 wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
269 ];
270
271 if let Some(ref uniform_buf) = uniform_buffer {
273 bind_entries.push(wgpu::BindGroupEntry {
274 binding: 2,
275 resource: uniform_buf.as_entire_binding(),
276 });
277 }
278
279 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
281 label: Some(&format!("{} Bind Group", op_name)),
282 layout: &bind_group_layout,
283 entries: &bind_entries,
284 });
285
286 let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
288 label: Some(&format!("{} Pipeline Layout", op_name)),
289 bind_group_layouts: &[&bind_group_layout],
290 push_constant_ranges: &[],
291 });
292
293 let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
294 label: Some(&format!("{} Pipeline", op_name)),
295 layout: Some(&pipeline_layout),
296 module: &shader,
297 entry_point: Some("main"),
298 compilation_options: Default::default(),
299 cache: None,
300 });
301
302 let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
304 label: Some(&format!("{} Staging Buffer", op_name)),
305 size: std::mem::size_of_val(result) as u64,
306 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
307 mapped_at_creation: false,
308 });
309
310 let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
312 label: Some(&format!("{} Encoder", op_name)),
313 });
314
315 {
316 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
317 label: Some(&format!("{} Pass", op_name)),
318 timestamp_writes: None,
319 });
320 compute_pass.set_pipeline(&pipeline);
321 compute_pass.set_bind_group(0, &bind_group, &[]);
322
323 let workgroup_size = 256;
325 let num_workgroups = (len as u32).div_ceil(workgroup_size);
326
327 compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
328 }
329
330 encoder.copy_buffer_to_buffer(
332 &output_buffer,
333 0,
334 &staging_buffer,
335 0,
336 std::mem::size_of_val(result) as u64,
337 );
338
339 self.queue.submit(Some(encoder.finish()));
341
342 let buffer_slice = staging_buffer.slice(..);
344 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
345 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
346 sender.send(result).ok();
347 });
348
349 self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
351
352 receiver
353 .receive()
354 .await
355 .ok_or("Failed to receive mapping result")?
356 .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
357
358 {
359 let data = buffer_slice.get_mapped_range();
360 result.copy_from_slice(bytemuck::cast_slice(&data));
361 }
362
363 staging_buffer.unmap();
364
365 Ok(())
366 }
367}
368
#[cfg(all(test, feature = "gpu", not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    // is_available() and new() must agree: if a GPU is reported available,
    // device creation must succeed.
    #[test]
    fn test_is_available_consistency() {
        let available = GpuDevice::is_available();
        let creation = GpuDevice::new();

        if !available {
            eprintln!(
                "GPU not available (is_available=false), device creation result: {:?}",
                creation.is_err()
            );
            return;
        }

        assert!(
            creation.is_ok(),
            "is_available() returned true, but GpuDevice::new() failed"
        );
    }

    // Kills the mutant where reduce_sum short-circuits to a constant, and
    // checks the actual sum within floating-point tolerance.
    #[test]
    fn test_reduce_sum_not_hardcoded() {
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping test");
            return;
        }

        let gpu = GpuDevice::new().expect("Failed to create GPU device");
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let expected: f32 = input.iter().sum();

        let total = runtime::block_on(gpu.reduce_sum(&input)).expect("reduce_sum failed");

        assert_ne!(total, -1.0, "reduce_sum returned hardcoded -1.0 (mutant not killed)");
        assert!(
            (total - expected).abs() < 1e-4,
            "reduce_sum({:?}) = {} (expected {})",
            input,
            total,
            expected
        );
    }
}