// oxigdal_gpu_advanced/multi_gpu/mod.rs

//! Multi-GPU device enumeration, selection, and work dispatch built on wgpu.

pub mod affinity;
pub mod device_manager;
pub mod load_balancer;
pub mod sync;
pub mod work_queue;

use crate::error::{GpuAdvancedError, Result};
use dashmap::DashMap;
use parking_lot::RwLock;
use std::sync::Arc;
use wgpu::{Adapter, Device, Queue};

/// Static identity and capability limits for a single GPU device.
#[derive(Debug, Clone)]
pub struct GpuDeviceInfo {
    pub index: usize,
    pub name: String,
    pub backend: wgpu::Backend,
    pub device_type: wgpu::DeviceType,
    pub max_buffer_size: u64,
    pub max_texture_dimension_1d: u32,
    pub max_texture_dimension_2d: u32,
    pub max_texture_dimension_3d: u32,
    pub max_compute_workgroup_size_x: u32,
    pub max_compute_workgroup_size_y: u32,
    pub max_compute_workgroup_size_z: u32,
    pub max_compute_workgroups_per_dimension: u32,
    pub max_bind_groups: u32,
    /// Total device memory in bytes, when the platform reports it.
    /// wgpu does not expose this directly, so it is `None` by default.
    pub memory_size: Option<u64>,
}

/// A usable GPU: the wgpu handles plus live usage tracking.
pub struct GpuDevice {
    pub info: GpuDeviceInfo,
    pub adapter: Arc<Adapter>,
    pub device: Arc<Device>,
    pub queue: Arc<Queue>,
    /// Estimated bytes currently allocated on this device.
    pub memory_usage: Arc<RwLock<u64>>,
    /// Normalized utilization in `[0.0, 1.0]`.
    pub workload: Arc<RwLock<f32>>,
}

impl GpuDevice {
    /// Builds a `GpuDevice`, capturing the adapter's identity and limits.
    pub fn new(index: usize, adapter: Adapter, device: Device, queue: Queue) -> Result<Self> {
        let info = adapter.get_info();
        let limits = device.limits();

        let device_info = GpuDeviceInfo {
            index,
            name: info.name.clone(),
            backend: info.backend,
            device_type: info.device_type,
            max_buffer_size: limits.max_buffer_size,
            max_texture_dimension_1d: limits.max_texture_dimension_1d,
            max_texture_dimension_2d: limits.max_texture_dimension_2d,
            max_texture_dimension_3d: limits.max_texture_dimension_3d,
            max_compute_workgroup_size_x: limits.max_compute_workgroup_size_x,
            max_compute_workgroup_size_y: limits.max_compute_workgroup_size_y,
            max_compute_workgroup_size_z: limits.max_compute_workgroup_size_z,
            max_compute_workgroups_per_dimension: limits.max_compute_workgroups_per_dimension,
            max_bind_groups: limits.max_bind_groups,
            memory_size: None,
        };

        Ok(Self {
            info: device_info,
            adapter: Arc::new(adapter),
            device: Arc::new(device),
            queue: Arc::new(queue),
            memory_usage: Arc::new(RwLock::new(0)),
            workload: Arc::new(RwLock::new(0.0)),
        })
    }

    pub fn get_memory_usage(&self) -> u64 {
        *self.memory_usage.read()
    }

    /// Adjusts the tracked memory usage by `delta` bytes (negative to
    /// release), saturating at the bounds of `u64`.
    pub fn update_memory_usage(&self, delta: i64) {
        let mut usage = self.memory_usage.write();
        if delta >= 0 {
            *usage = usage.saturating_add(delta as u64);
        } else {
            // `unsigned_abs` avoids the overflow of `-delta` at `i64::MIN`.
            *usage = usage.saturating_sub(delta.unsigned_abs());
        }
    }

    pub fn get_workload(&self) -> f32 {
        *self.workload.read()
    }

    pub fn set_workload(&self, workload: f32) {
        *self.workload.write() = workload.clamp(0.0, 1.0);
    }

    /// A device is considered available while its workload stays below 80%.
    pub fn is_available(&self) -> bool {
        self.get_workload() < 0.8
    }

    /// Scores the device for scheduling: a device-type weight scaled by the
    /// remaining capacity. For example, a discrete GPU at 25% workload
    /// scores 1.0 * (1.0 - 0.25) = 0.75.
    pub fn get_score(&self) -> f32 {
        let type_score = match self.info.device_type {
            wgpu::DeviceType::DiscreteGpu => 1.0,
            wgpu::DeviceType::IntegratedGpu => 0.7,
            wgpu::DeviceType::VirtualGpu => 0.5,
            wgpu::DeviceType::Cpu => 0.3,
            wgpu::DeviceType::Other => 0.1,
        };

        let workload = self.get_workload();
        type_score * (1.0 - workload)
    }
}

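/// Coordinates device selection and work submission across every detected
/// GPU. A minimal usage sketch, assuming a tokio runtime, at least one
/// non-CPU adapter, and that this module is exported from the crate root as
/// `multi_gpu` (the path is inferred from this file's location):
///
/// ```no_run
/// use oxigdal_gpu_advanced::multi_gpu::{MultiGpuManager, SelectionStrategy};
///
/// # async fn run() {
/// let manager = MultiGpuManager::new(SelectionStrategy::LeastLoaded)
///     .await
///     .expect("no compatible GPU found");
/// let gpu = manager.select_gpu().expect("selection failed");
/// println!("running on {}", gpu.info.name);
/// # }
/// ```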
pub struct MultiGpuManager {
    devices: Vec<Arc<GpuDevice>>,
    /// Kept for introspection; selection itself lives in the load balancer.
    #[allow(dead_code)]
    strategy: SelectionStrategy,
    /// One work queue per device index.
    work_queues: DashMap<usize, Arc<work_queue::WorkQueue>>,
    load_balancer: Arc<load_balancer::LoadBalancer>,
}

/// How the load balancer picks a device for incoming work.
#[derive(Debug, Clone, Copy)]
pub enum SelectionStrategy {
    /// Cycle through devices in index order.
    RoundRobin,
    /// Pick the device with the lowest current workload.
    LeastLoaded,
    /// Pick the device with the highest `GpuDevice::get_score`.
    BestScore,
    /// Honor affinity hints (see the `affinity` module).
    Affinity,
}

impl MultiGpuManager {
    /// Enumerates all usable devices and builds a work queue for each.
    /// Fails if no compatible GPU is found.
    pub async fn new(strategy: SelectionStrategy) -> Result<Self> {
        let devices = Self::enumerate_devices().await?;

        if devices.is_empty() {
            return Err(GpuAdvancedError::GpuNotFound(
                "No compatible GPU devices found".to_string(),
            ));
        }

        let work_queues = DashMap::new();
        for device in &devices {
            work_queues.insert(
                device.info.index,
                Arc::new(work_queue::WorkQueue::new(device.clone())),
            );
        }

        let load_balancer = Arc::new(load_balancer::LoadBalancer::new(devices.clone(), strategy));

        Ok(Self {
            devices,
            strategy,
            work_queues,
            load_balancer,
        })
    }

    async fn enumerate_devices() -> Result<Vec<Arc<GpuDevice>>> {
        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        let mut devices = Vec::new();
        let mut index = 0;

        // Walk every adapter across all backends. Calling `request_adapter`
        // in a loop would hand back the same "best" adapter each time
        // instead of enumerating the full device list.
        for adapter in instance.enumerate_adapters(wgpu::Backends::all()) {
            let info = adapter.get_info();

            // Skip software rasterizers.
            if info.device_type == wgpu::DeviceType::Cpu {
                continue;
            }

            // Skip adapters already registered under the same name and backend.
            if devices.iter().any(|d: &Arc<GpuDevice>| {
                d.info.name == info.name && d.info.backend == info.backend
            }) {
                continue;
            }

            let (device, queue) = match adapter
                .request_device(&wgpu::DeviceDescriptor {
                    label: Some(&format!("GPU Device {}", index)),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                    memory_hints: wgpu::MemoryHints::Performance,
                    experimental_features: wgpu::ExperimentalFeatures::disabled(),
                    trace: wgpu::Trace::Off,
                })
                .await
            {
                Ok((device, queue)) => (device, queue),
                Err(e) => {
                    tracing::warn!("Failed to request device {}: {}", index, e);
                    continue;
                }
            };

            let gpu_device = GpuDevice::new(index, adapter, device, queue)?;
            devices.push(Arc::new(gpu_device));
            index += 1;
        }

        Ok(devices)
    }

    pub fn gpu_count(&self) -> usize {
        self.devices.len()
    }

    pub fn get_gpu(&self, index: usize) -> Result<Arc<GpuDevice>> {
        self.devices
            .get(index)
            .cloned()
            .ok_or(GpuAdvancedError::InvalidGpuIndex {
                index,
                total: self.devices.len(),
            })
    }

    pub fn get_all_gpus(&self) -> &[Arc<GpuDevice>] {
        &self.devices
    }

    pub fn select_gpu(&self) -> Result<Arc<GpuDevice>> {
        self.load_balancer.select_device()
    }

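    /// Filters devices by an optional minimum memory size and preferred
    /// device type, then returns the highest-scoring available candidate.
    ///
    /// A small sketch of the call (hypothetical thresholds; devices whose
    /// memory size is unreported pass the memory filter):
    ///
    /// ```no_run
    /// # use oxigdal_gpu_advanced::multi_gpu::MultiGpuManager;
    /// # fn pick(manager: &MultiGpuManager) {
    /// // Ask for a discrete GPU with at least 4 GiB of reported memory.
    /// let gpu = manager
    ///     .select_gpu_with_requirements(Some(4 << 30), Some(wgpu::DeviceType::DiscreteGpu));
    /// # }
    /// ```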
    pub fn select_gpu_with_requirements(
        &self,
        min_memory: Option<u64>,
        preferred_type: Option<wgpu::DeviceType>,
    ) -> Result<Arc<GpuDevice>> {
        let mut candidates: Vec<_> = self
            .devices
            .iter()
            .filter(|device| {
                // A device with an unreported memory size is not rejected
                // by the memory filter.
                if let Some(min_mem) = min_memory {
                    if let Some(mem_size) = device.info.memory_size {
                        if mem_size < min_mem {
                            return false;
                        }
                    }
                }

                if let Some(pref_type) = preferred_type {
                    if device.info.device_type != pref_type {
                        return false;
                    }
                }

                device.is_available()
            })
            .collect();

        if candidates.is_empty() {
            return Err(GpuAdvancedError::GpuNotFound(
                "No GPU matching requirements".to_string(),
            ));
        }

        // Highest score first.
        candidates.sort_by(|a, b| {
            b.get_score()
                .partial_cmp(&a.get_score())
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        Ok(candidates[0].clone())
    }

    pub fn get_work_queue(&self, index: usize) -> Result<Arc<work_queue::WorkQueue>> {
        self.work_queues
            .get(&index)
            .map(|q| q.clone())
            .ok_or(GpuAdvancedError::InvalidGpuIndex {
                index,
                total: self.devices.len(),
            })
    }

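    /// Selects a device via the load balancer and runs `work` on that
    /// device's queue. A minimal sketch of the call pattern (assumes an
    /// async context; the closure receives the chosen `GpuDevice`):
    ///
    /// ```no_run
    /// # use oxigdal_gpu_advanced::multi_gpu::MultiGpuManager;
    /// # async fn run(manager: &MultiGpuManager) {
    /// let name = manager
    ///     .submit_work(|gpu| Ok(gpu.info.name.clone()))
    ///     .await
    ///     .expect("work submission failed");
    /// println!("ran on {}", name);
    /// # }
    /// ```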
    pub async fn submit_work<F, T>(&self, work: F) -> Result<T>
    where
        F: FnOnce(&GpuDevice) -> Result<T> + Send + 'static,
        T: Send + 'static,
    {
        let device = self.select_gpu()?;
        let queue = self.get_work_queue(device.info.index)?;
        queue.submit_work(work).await
    }

    pub fn get_load_balancer(&self) -> Arc<load_balancer::LoadBalancer> {
        self.load_balancer.clone()
    }

    pub fn print_gpu_info(&self) {
        println!("Multi-GPU Manager - {} devices found", self.devices.len());
        for device in &self.devices {
            println!(
                "  GPU {}: {} ({:?}, {:?})",
                device.info.index, device.info.name, device.info.backend, device.info.device_type
            );
            println!("    Max buffer size: {} bytes", device.info.max_buffer_size);
            println!(
                "    Max texture 2D: {}x{}",
                device.info.max_texture_dimension_2d, device.info.max_texture_dimension_2d
            );
            println!(
                "    Max workgroup size: {}x{}x{}",
                device.info.max_compute_workgroup_size_x,
                device.info.max_compute_workgroup_size_y,
                device.info.max_compute_workgroup_size_z
            );
            println!(
                "    Current workload: {:.1}%",
                device.get_workload() * 100.0
            );
            println!("    Memory usage: {} bytes", device.get_memory_usage());
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Both tests tolerate machines without a compatible GPU and pass
    // trivially in that case.
    #[tokio::test]
    async fn test_multi_gpu_manager_creation() {
        let result = MultiGpuManager::new(SelectionStrategy::LeastLoaded).await;

        match result {
            Ok(manager) => {
                assert!(manager.gpu_count() > 0);
                manager.print_gpu_info();
            }
            Err(e) => {
                println!("No GPU available: {}", e);
            }
        }
    }

    #[tokio::test]
    async fn test_gpu_selection() {
        let result = MultiGpuManager::new(SelectionStrategy::BestScore).await;

        if let Ok(manager) = result {
            let gpu = manager.select_gpu();
            assert!(gpu.is_ok());

            if let Ok(gpu) = gpu {
                println!("Selected GPU: {}", gpu.info.name);
                assert!(gpu.get_score() >= 0.0);
            }
        }
    }
}