// trueno/backends/gpu/pool.rs
1//! Multi-GPU device pool for data-parallel workloads
2//!
3//! Provides [`GpuDevicePool`] for managing multiple GPU devices simultaneously.
4//! Designed for data-parallel training where each GPU processes a shard of the
5//! mini-batch independently.
6//!
7//! # Example
8//!
9//! ```rust,no_run
10//! use trueno::backends::gpu::GpuDevicePool;
11//!
12//! // Open all available GPUs
13//! let pool = GpuDevicePool::all()?;
14//! println!("Found {} GPUs", pool.len());
15//!
16//! // Or select specific GPUs by index
17//! let pool = GpuDevicePool::with_indices(&[0, 1])?;
18//! # Ok::<(), String>(())
19//! ```
20
21use super::GpuDevice;
22
/// Pool of GPU devices for multi-GPU workloads
///
/// Thin wrapper over `Vec<GpuDevice>` that handles enumeration
/// and selective device opening.
pub struct GpuDevicePool {
    // Opened devices; position in this vec is the "pool index".
    devices: Vec<GpuDevice>,
    // Adapter enumeration index for each device, parallel to `devices`
    // (same length, same order).
    indices: Vec<u32>,
}
31
32impl GpuDevicePool {
33    /// Open all available non-CPU GPU adapters
34    ///
35    /// Filters out adapters with `Noop` backend (CPU fallback).
36    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
37    pub fn all() -> Result<Self, String> {
38        use super::runtime;
39        runtime::block_on(Self::all_async())
40    }
41
42    /// Open all available non-CPU GPU adapters (async)
43    pub async fn all_async() -> Result<Self, String> {
44        let instance = wgpu::Instance::default();
45        let adapters = instance.enumerate_adapters(wgpu::Backends::all());
46
47        if adapters.is_empty() {
48            return Err("No GPU adapters found".to_string());
49        }
50
51        // Filter out CPU/Noop backends
52        let gpu_adapters: Vec<(usize, _)> = adapters
53            .into_iter()
54            .enumerate()
55            .filter(|(_, adapter)| adapter.get_info().backend != wgpu::Backend::Noop)
56            .collect();
57
58        if gpu_adapters.is_empty() {
59            return Err("No non-CPU GPU adapters found".to_string());
60        }
61
62        let mut devices = Vec::with_capacity(gpu_adapters.len());
63        let mut indices = Vec::with_capacity(gpu_adapters.len());
64
65        for (idx, adapter) in gpu_adapters {
66            let mut limits = wgpu::Limits::default();
67            limits.max_buffer_size = adapter.limits().max_buffer_size;
68            limits.max_storage_buffer_binding_size =
69                adapter.limits().max_storage_buffer_binding_size;
70
71            let (device, queue) = adapter
72                .request_device(&wgpu::DeviceDescriptor {
73                    label: Some(&format!("Trueno GPU Device [{}]", idx)),
74                    required_features: wgpu::Features::empty(),
75                    required_limits: limits,
76                    memory_hints: wgpu::MemoryHints::Performance,
77                    experimental_features: Default::default(),
78                    trace: Default::default(),
79                })
80                .await
81                .map_err(|e| format!("Failed to create device at index {}: {}", idx, e))?;
82
83            devices.push(GpuDevice { device, queue });
84            indices.push(idx as u32);
85        }
86
87        Ok(Self { devices, indices })
88    }
89
90    /// Open specific GPU adapters by index
91    #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
92    pub fn with_indices(adapter_indices: &[u32]) -> Result<Self, String> {
93        use super::runtime;
94        runtime::block_on(Self::with_indices_async(adapter_indices))
95    }
96
97    /// Open specific GPU adapters by index (async)
98    pub async fn with_indices_async(adapter_indices: &[u32]) -> Result<Self, String> {
99        if adapter_indices.is_empty() {
100            return Err("No adapter indices specified".to_string());
101        }
102
103        let mut devices = Vec::with_capacity(adapter_indices.len());
104        let mut indices = Vec::with_capacity(adapter_indices.len());
105
106        for &idx in adapter_indices {
107            let device = GpuDevice::new_with_adapter_index_async(idx).await?;
108            devices.push(device);
109            indices.push(idx);
110        }
111
112        Ok(Self { devices, indices })
113    }
114
115    /// Number of devices in the pool
116    #[must_use]
117    pub fn len(&self) -> usize {
118        self.devices.len()
119    }
120
121    /// Whether the pool is empty
122    #[must_use]
123    pub fn is_empty(&self) -> bool {
124        self.devices.is_empty()
125    }
126
127    /// Get a device by pool position (0-based within this pool)
128    #[must_use]
129    pub fn get(&self, pool_index: usize) -> Option<&GpuDevice> {
130        self.devices.get(pool_index)
131    }
132
133    /// Get the adapter index for a pool position
134    #[must_use]
135    pub fn adapter_index(&self, pool_index: usize) -> Option<u32> {
136        self.indices.get(pool_index).copied()
137    }
138
139    /// Iterate over (adapter_index, device) pairs
140    pub fn iter(&self) -> impl Iterator<Item = (u32, &GpuDevice)> {
141        self.indices.iter().copied().zip(self.devices.iter())
142    }
143
144    /// Consume the pool and return the devices
145    pub fn into_devices(self) -> Vec<GpuDevice> {
146        self.devices
147    }
148}
149
#[cfg(all(test, feature = "gpu", not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// Smoke test: a successfully created pool is non-empty and its
    /// accessors agree with each other. Skips (with a message) when no
    /// GPU is available, since CI runners are often CPU-only.
    #[test]
    fn test_pool_len_matches_devices() {
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping pool test");
            return;
        }

        match GpuDevicePool::all() {
            Ok(p) => {
                // Constructors guarantee at least one device on success.
                assert!(!p.is_empty());
                // iter() must yield exactly len() (index, device) pairs.
                assert_eq!(p.iter().count(), p.len());
                assert!(p.get(0).is_some());
                assert!(p.adapter_index(0).is_some());
            }
            Err(e) => {
                eprintln!("Pool creation failed (expected on CPU-only): {}", e);
            }
        }
    }
}
174}