// trueno/backends/gpu/device/mod.rs
//! GPU device initialization and management
//!
//! This module provides cross-platform GPU compute via wgpu (WebGPU).
//!
//! # Platform differences
//!
//! - **Native**: Sync wrappers available using `pollster::block_on`
//! - **WASM**: Sync wrappers unavailable (can't block main thread); use `*_async` methods
//!
//! Use `runtime::sync_available()` to check at runtime.

mod activations;
mod backward;
mod eigen;
pub(crate) mod linalg;
mod reductions;

#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
use super::runtime;
20
/// GPU device manager
///
/// Thin wrapper bundling the wgpu logical device with its command queue.
/// All GPU operations in this module go through these two handles.
// NOTE(review): #[derive(Clone)] relies on wgpu::Device/wgpu::Queue being
// Clone (internally reference-counted handles in recent wgpu) — confirm
// against the pinned wgpu version.
#[derive(Clone)]
pub struct GpuDevice {
    // Logical device: creates buffers, shaders, pipelines, bind groups.
    pub device: wgpu::Device,
    // Command queue: writes buffer data and submits command encoders.
    pub queue: wgpu::Queue,
}
27
28impl GpuDevice {
29    /// Initialize GPU device (sync, native only)
30    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
31    pub fn new() -> Result<Self, String> {
32        runtime::block_on(async { Self::new_async().await })
33    }
34
35    /// Initialize GPU device (async, works on all platforms)
36    pub async fn new_async() -> Result<Self, String> {
37        // Create instance
38        let instance = wgpu::Instance::default();
39
40        // Request adapter (GPU)
41        let adapter = instance
42            .request_adapter(&wgpu::RequestAdapterOptions {
43                power_preference: wgpu::PowerPreference::HighPerformance,
44                compatible_surface: None,
45                force_fallback_adapter: false,
46            })
47            .await
48            .map_err(|e| format!("Failed to find GPU adapter: {}", e))?;
49
50        // Request device and queue with adapter's actual max buffer size
51        // Default wgpu limits cap buffers at 256MB, which is too small for
52        // 7B+ model weight matrices (e.g., FFN [18944, 3584] x f32 = 271MB)
53        let mut limits = wgpu::Limits::default();
54        limits.max_buffer_size = adapter.limits().max_buffer_size;
55        limits.max_storage_buffer_binding_size = adapter.limits().max_storage_buffer_binding_size;
56
57        let (device, queue) = adapter
58            .request_device(&wgpu::DeviceDescriptor {
59                label: Some("Trueno GPU Device"),
60                required_features: wgpu::Features::empty(),
61                required_limits: limits,
62                memory_hints: wgpu::MemoryHints::Performance,
63                experimental_features: Default::default(),
64                trace: Default::default(),
65            })
66            .await
67            .map_err(|e| format!("Failed to create device: {}", e))?;
68
69        Ok(Self { device, queue })
70    }
71
72    /// Initialize GPU device with a specific adapter index (sync, native only)
73    ///
74    /// Use this to select a specific GPU when multiple are available.
75    /// Adapter indices correspond to `Instance::enumerate_adapters()` ordering.
76    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
77    pub fn new_with_adapter_index(index: u32) -> Result<Self, String> {
78        runtime::block_on(async { Self::new_with_adapter_index_async(index).await })
79    }
80
81    /// Initialize GPU device with a specific adapter index (async, all platforms)
82    ///
83    /// Use this to select a specific GPU when multiple are available.
84    /// Adapter indices correspond to `Instance::enumerate_adapters()` ordering.
85    pub async fn new_with_adapter_index_async(index: u32) -> Result<Self, String> {
86        let instance = wgpu::Instance::default();
87        let adapters = instance.enumerate_adapters(wgpu::Backends::all());
88
89        if adapters.is_empty() {
90            return Err("No GPU adapters found".to_string());
91        }
92
93        let adapter = adapters
94            .into_iter()
95            .nth(index as usize)
96            .ok_or_else(|| format!("GPU adapter index {} out of range", index))?;
97
98        let mut limits = wgpu::Limits::default();
99        limits.max_buffer_size = adapter.limits().max_buffer_size;
100        limits.max_storage_buffer_binding_size = adapter.limits().max_storage_buffer_binding_size;
101
102        let (device, queue) = adapter
103            .request_device(&wgpu::DeviceDescriptor {
104                label: Some(&format!("Trueno GPU Device [{}]", index)),
105                required_features: wgpu::Features::empty(),
106                required_limits: limits,
107                memory_hints: wgpu::MemoryHints::Performance,
108                experimental_features: Default::default(),
109                trace: Default::default(),
110            })
111            .await
112            .map_err(|e| format!("Failed to create device at index {}: {}", index, e))?;
113
114        Ok(Self { device, queue })
115    }
116
117    /// List all available GPU adapters (sync, native only)
118    ///
119    /// Returns a list of (index, name, backend) tuples for each adapter.
120    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
121    pub fn list_adapters() -> Vec<(u32, String, String)> {
122        runtime::block_on(Self::list_adapters_async())
123    }
124
125    /// List all available GPU adapters (async, all platforms)
126    pub async fn list_adapters_async() -> Vec<(u32, String, String)> {
127        let instance = wgpu::Instance::default();
128        let adapters = instance.enumerate_adapters(wgpu::Backends::all());
129
130        adapters
131            .iter()
132            .enumerate()
133            .map(|(idx, adapter)| {
134                let info = adapter.get_info();
135                (idx as u32, info.name, format!("{:?}", info.backend))
136            })
137            .collect()
138    }
139
140    /// Check if GPU is available (sync, native only)
141    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
142    pub fn is_available() -> bool {
143        runtime::block_on(Self::is_available_async())
144    }
145
146    /// Check if GPU is available (async, works on all platforms)
147    pub async fn is_available_async() -> bool {
148        let instance = wgpu::Instance::default();
149        instance
150            .request_adapter(&wgpu::RequestAdapterOptions {
151                power_preference: wgpu::PowerPreference::HighPerformance,
152                compatible_surface: None,
153                force_fallback_adapter: false,
154            })
155            .await
156            .is_ok()
157    }
158
159    /// Generic helper for element-wise GPU operations
160    ///
161    /// This helper eliminates code duplication between element-wise operations
162    /// (relu, clip, sigmoid, tanh, etc.) by abstracting the common GPU compute pattern.
163    ///
164    /// # Arguments
165    ///
166    /// * `op_name` - Operation name for labels (e.g., "ReLU", "Clip")
167    /// * `shader_source` - WGSL shader source code
168    /// * `input` - Input data
169    /// * `result` - Output buffer
170    /// * `uniform_data` - Optional uniform buffer data (e.g., clip parameters)
171    pub(super) async fn execute_element_wise_op(
172        &self,
173        op_name: &str,
174        shader_source: &str,
175        input: &[f32],
176        result: &mut [f32],
177        uniform_data: Option<&[u8]>,
178    ) -> Result<(), String> {
179        let len = input.len();
180
181        // Create shader module
182        let shader = self.device.create_shader_module(wgpu::ShaderModuleDescriptor {
183            label: Some(&format!("{} Shader", op_name)),
184            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
185        });
186
187        // Create input buffer
188        let input_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
189            label: Some(&format!("{} Input", op_name)),
190            size: std::mem::size_of_val(input) as u64,
191            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
192            mapped_at_creation: false,
193        });
194
195        // Create output buffer
196        let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
197            label: Some(&format!("{} Output", op_name)),
198            size: std::mem::size_of_val(result) as u64,
199            usage: wgpu::BufferUsages::STORAGE
200                | wgpu::BufferUsages::COPY_SRC
201                | wgpu::BufferUsages::COPY_DST,
202            mapped_at_creation: false,
203        });
204
205        // Write input data
206        self.queue.write_buffer(&input_buffer, 0, bytemuck::cast_slice(input));
207
208        // Create optional uniform buffer
209        let uniform_buffer = uniform_data.map(|data| {
210            let buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
211                label: Some(&format!("{} Uniform", op_name)),
212                size: data.len() as u64,
213                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
214                mapped_at_creation: false,
215            });
216            self.queue.write_buffer(&buffer, 0, data);
217            buffer
218        });
219
220        // Create bind group layout entries (input + output + optional uniform)
221        let mut bind_group_entries = vec![
222            wgpu::BindGroupLayoutEntry {
223                binding: 0,
224                visibility: wgpu::ShaderStages::COMPUTE,
225                ty: wgpu::BindingType::Buffer {
226                    ty: wgpu::BufferBindingType::Storage { read_only: true },
227                    has_dynamic_offset: false,
228                    min_binding_size: None,
229                },
230                count: None,
231            },
232            wgpu::BindGroupLayoutEntry {
233                binding: 1,
234                visibility: wgpu::ShaderStages::COMPUTE,
235                ty: wgpu::BindingType::Buffer {
236                    ty: wgpu::BufferBindingType::Storage { read_only: false },
237                    has_dynamic_offset: false,
238                    min_binding_size: None,
239                },
240                count: None,
241            },
242        ];
243
244        // Add uniform buffer binding if present
245        if uniform_buffer.is_some() {
246            bind_group_entries.push(wgpu::BindGroupLayoutEntry {
247                binding: 2,
248                visibility: wgpu::ShaderStages::COMPUTE,
249                ty: wgpu::BindingType::Buffer {
250                    ty: wgpu::BufferBindingType::Uniform,
251                    has_dynamic_offset: false,
252                    min_binding_size: None,
253                },
254                count: None,
255            });
256        }
257
258        // Create bind group layout
259        let bind_group_layout =
260            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
261                label: Some(&format!("{} Bind Group Layout", op_name)),
262                entries: &bind_group_entries,
263            });
264
265        // Create bind group entries
266        let mut bind_entries = vec![
267            wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
268            wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
269        ];
270
271        // Add uniform buffer binding if present
272        if let Some(ref uniform_buf) = uniform_buffer {
273            bind_entries.push(wgpu::BindGroupEntry {
274                binding: 2,
275                resource: uniform_buf.as_entire_binding(),
276            });
277        }
278
279        // Create bind group
280        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
281            label: Some(&format!("{} Bind Group", op_name)),
282            layout: &bind_group_layout,
283            entries: &bind_entries,
284        });
285
286        // Create pipeline
287        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
288            label: Some(&format!("{} Pipeline Layout", op_name)),
289            bind_group_layouts: &[&bind_group_layout],
290            push_constant_ranges: &[],
291        });
292
293        let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
294            label: Some(&format!("{} Pipeline", op_name)),
295            layout: Some(&pipeline_layout),
296            module: &shader,
297            entry_point: Some("main"),
298            compilation_options: Default::default(),
299            cache: None,
300        });
301
302        // Create staging buffer for reading results
303        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
304            label: Some(&format!("{} Staging Buffer", op_name)),
305            size: std::mem::size_of_val(result) as u64,
306            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
307            mapped_at_creation: false,
308        });
309
310        // Create command encoder
311        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
312            label: Some(&format!("{} Encoder", op_name)),
313        });
314
315        {
316            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
317                label: Some(&format!("{} Pass", op_name)),
318                timestamp_writes: None,
319            });
320            compute_pass.set_pipeline(&pipeline);
321            compute_pass.set_bind_group(0, &bind_group, &[]);
322
323            // Dispatch workgroups (256 threads per workgroup)
324            let workgroup_size = 256;
325            let num_workgroups = (len as u32).div_ceil(workgroup_size);
326
327            compute_pass.dispatch_workgroups(num_workgroups, 1, 1);
328        }
329
330        // Copy result to staging buffer
331        encoder.copy_buffer_to_buffer(
332            &output_buffer,
333            0,
334            &staging_buffer,
335            0,
336            std::mem::size_of_val(result) as u64,
337        );
338
339        // Submit commands
340        self.queue.submit(Some(encoder.finish()));
341
342        // Read back results
343        let buffer_slice = staging_buffer.slice(..);
344        let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
345        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
346            sender.send(result).ok();
347        });
348
349        // Poll device to ensure GPU work completes and callbacks are invoked
350        self.device.poll(wgpu::PollType::Wait { submission_index: None, timeout: None }).ok();
351
352        receiver
353            .receive()
354            .await
355            .ok_or("Failed to receive mapping result")?
356            .map_err(|e| format!("Buffer mapping failed: {:?}", e))?;
357
358        {
359            let data = buffer_slice.get_mapped_range();
360            result.copy_from_slice(bytemuck::cast_slice(&data));
361        }
362
363        staging_buffer.unmap();
364
365        Ok(())
366    }
367}
368
// GPU smoke tests: compiled only with the native `gpu` feature; each test
// degrades gracefully when no physical GPU is present on the CI host.
#[cfg(all(test, feature = "gpu", not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    #[test]
    fn test_is_available_consistency() {
        // EXTREME TDD: Kill mutant that replaces is_available() with hardcoded false
        // Test that is_available() is consistent with GpuDevice::new()
        let available = GpuDevice::is_available();
        let device_result = GpuDevice::new();

        if available {
            // If is_available() returns true, device creation should succeed
            assert!(
                device_result.is_ok(),
                "is_available() returned true, but GpuDevice::new() failed"
            );
        } else {
            // If is_available() returns false, we can't make assertions about new()
            // (it might still succeed in some edge cases, but typically should fail)
            // The key test is: mutant always returns false, so on GPU systems this fails
            eprintln!(
                "GPU not available (is_available=false), device creation result: {:?}",
                device_result.is_err()
            );
        }
    }

    #[test]
    fn test_reduce_sum_not_hardcoded() {
        // EXTREME TDD: Kill mutant that replaces reduce_sum with Ok(-1.0)
        // Skip (not fail) when no GPU exists so the suite stays green on CPU-only hosts.
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping test");
            return;
        }

        let device = GpuDevice::new().expect("Failed to create GPU device");
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0]; // sum = 15.0

        // reduce_sum is async, so we use runtime::block_on
        // (reduce_sum is defined in the sibling `reductions` module)
        let result = runtime::block_on(device.reduce_sum(&input)).expect("reduce_sum failed");

        // Kill mutant: verify result is NOT -1.0
        assert_ne!(result, -1.0, "reduce_sum returned hardcoded -1.0 (mutant not killed)");

        // Verify correct computation
        // Loose tolerance (1e-4) because GPU reduction order differs from sequential sum
        let expected: f32 = input.iter().sum();
        assert!(
            (result - expected).abs() < 1e-4,
            "reduce_sum({:?}) = {} (expected {})",
            input,
            result,
            expected
        );
    }
}