trueno/backends/gpu/
pool.rs1use super::GpuDevice;
22
/// A set of opened GPU devices, each tagged with the wgpu adapter index it
/// was created from.
///
/// `devices[i]` was opened on adapter `indices[i]`; every constructor keeps
/// the two vectors the same length.
pub struct GpuDevicePool {
    /// Opened devices, in pool order.
    devices: Vec<GpuDevice>,
    /// Adapter index of the device at the same position in `devices`.
    indices: Vec<u32>,
}
31
32impl GpuDevicePool {
33 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
37 pub fn all() -> Result<Self, String> {
38 use super::runtime;
39 runtime::block_on(Self::all_async())
40 }
41
42 pub async fn all_async() -> Result<Self, String> {
44 let instance = wgpu::Instance::default();
45 let adapters = instance.enumerate_adapters(wgpu::Backends::all());
46
47 if adapters.is_empty() {
48 return Err("No GPU adapters found".to_string());
49 }
50
51 let gpu_adapters: Vec<(usize, _)> = adapters
53 .into_iter()
54 .enumerate()
55 .filter(|(_, adapter)| adapter.get_info().backend != wgpu::Backend::Noop)
56 .collect();
57
58 if gpu_adapters.is_empty() {
59 return Err("No non-CPU GPU adapters found".to_string());
60 }
61
62 let mut devices = Vec::with_capacity(gpu_adapters.len());
63 let mut indices = Vec::with_capacity(gpu_adapters.len());
64
65 for (idx, adapter) in gpu_adapters {
66 let mut limits = wgpu::Limits::default();
67 limits.max_buffer_size = adapter.limits().max_buffer_size;
68 limits.max_storage_buffer_binding_size =
69 adapter.limits().max_storage_buffer_binding_size;
70
71 let (device, queue) = adapter
72 .request_device(&wgpu::DeviceDescriptor {
73 label: Some(&format!("Trueno GPU Device [{}]", idx)),
74 required_features: wgpu::Features::empty(),
75 required_limits: limits,
76 memory_hints: wgpu::MemoryHints::Performance,
77 experimental_features: Default::default(),
78 trace: Default::default(),
79 })
80 .await
81 .map_err(|e| format!("Failed to create device at index {}: {}", idx, e))?;
82
83 devices.push(GpuDevice { device, queue });
84 indices.push(idx as u32);
85 }
86
87 Ok(Self { devices, indices })
88 }
89
90 #[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
92 pub fn with_indices(adapter_indices: &[u32]) -> Result<Self, String> {
93 use super::runtime;
94 runtime::block_on(Self::with_indices_async(adapter_indices))
95 }
96
97 pub async fn with_indices_async(adapter_indices: &[u32]) -> Result<Self, String> {
99 if adapter_indices.is_empty() {
100 return Err("No adapter indices specified".to_string());
101 }
102
103 let mut devices = Vec::with_capacity(adapter_indices.len());
104 let mut indices = Vec::with_capacity(adapter_indices.len());
105
106 for &idx in adapter_indices {
107 let device = GpuDevice::new_with_adapter_index_async(idx).await?;
108 devices.push(device);
109 indices.push(idx);
110 }
111
112 Ok(Self { devices, indices })
113 }
114
115 #[must_use]
117 pub fn len(&self) -> usize {
118 self.devices.len()
119 }
120
121 #[must_use]
123 pub fn is_empty(&self) -> bool {
124 self.devices.is_empty()
125 }
126
127 #[must_use]
129 pub fn get(&self, pool_index: usize) -> Option<&GpuDevice> {
130 self.devices.get(pool_index)
131 }
132
133 #[must_use]
135 pub fn adapter_index(&self, pool_index: usize) -> Option<u32> {
136 self.indices.get(pool_index).copied()
137 }
138
139 pub fn iter(&self) -> impl Iterator<Item = (u32, &GpuDevice)> {
141 self.indices.iter().copied().zip(self.devices.iter())
142 }
143
144 pub fn into_devices(self) -> Vec<GpuDevice> {
146 self.devices
147 }
148}
149
#[cfg(all(test, feature = "gpu", not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    /// Every accessor of a freshly built pool must agree on how many devices
    /// it holds.
    ///
    /// The test is skipped (with a message) when no GPU is available or pool
    /// creation fails, so CPU-only machines still pass.
    #[test]
    fn test_pool_len_matches_devices() {
        if !GpuDevice::is_available() {
            eprintln!("GPU not available, skipping pool test");
            return;
        }

        let pool = match GpuDevicePool::all() {
            Ok(pool) => pool,
            Err(e) => {
                eprintln!("Pool creation failed (expected on CPU-only): {}", e);
                return;
            }
        };

        assert!(!pool.is_empty());
        assert_eq!(pool.iter().count(), pool.len());
        for pool_index in 0..pool.len() {
            assert!(pool.get(pool_index).is_some());
            assert!(pool.adapter_index(pool_index).is_some());
        }
        // One past the end must be out of range for both lookups.
        assert!(pool.get(pool.len()).is_none());
        assert!(pool.adapter_index(pool.len()).is_none());
    }

    /// An empty index list is rejected before any adapter is touched, so this
    /// check does not need a GPU.
    #[test]
    fn test_with_indices_rejects_empty_slice() {
        let err = GpuDevicePool::with_indices(&[]).err();
        assert_eq!(err.as_deref(), Some("No adapter indices specified"));
    }
}