// memvid_rs/ml/device.rs

//! Device detection and optimization for ML inference
//!
//! This module handles automatic detection of the best available compute device
//! (CUDA GPU, Metal GPU, or CPU) and provides device management for ML operations.
5
6use crate::error::{MemvidError, Result};
7use candle_core::Device;
8use serde::{Deserialize, Serialize};
9use std::ptr;
10use std::sync::Once;
11
12/// Device types supported for ML inference
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub enum DeviceType {
15    /// CPU inference
16    Cpu,
17    /// CUDA GPU inference
18    Cuda(usize),
19    /// Metal GPU inference (macOS)
20    Metal,
21}
22
23/// Device information and capabilities
24#[derive(Debug, Clone)]
25pub struct DeviceInfo {
26    /// Device type
27    pub device_type: DeviceType,
28    /// Candle device instance
29    pub device: Device,
30    /// Device name/description
31    pub name: String,
32    /// Estimated compute capability (relative score)
33    pub compute_score: f32,
34    /// Available memory in bytes (estimate)
35    pub memory_bytes: Option<u64>,
36}
37
38/// Global device manager instance
39static mut DEVICE_MANAGER: Option<DeviceManager> = None;
40static DEVICE_MANAGER_INIT: Once = Once::new();
41
42/// Device manager for automatic device selection and optimization
43pub struct DeviceManager {
44    /// Current optimal device
45    current_device: DeviceInfo,
46    /// All available devices
47    available_devices: Vec<DeviceInfo>,
48}
49
50impl DeviceManager {
51    /// Initialize device manager with automatic device detection
52    pub fn initialize() -> Result<&'static DeviceManager> {
53        unsafe {
54            DEVICE_MANAGER_INIT.call_once(|| match Self::new() {
55                Ok(manager) => {
56                    log::info!(
57                        "Initialized device manager with optimal device: {}",
58                        manager.current_device.name
59                    );
60                    DEVICE_MANAGER = Some(manager);
61                }
62                Err(e) => {
63                    log::error!("Failed to initialize device manager: {}", e);
64                }
65            });
66
67            ptr::addr_of!(DEVICE_MANAGER)
68                .as_ref()
69                .unwrap()
70                .as_ref()
71                .ok_or_else(|| {
72                    MemvidError::MachineLearning("Device manager initialization failed".to_string())
73                })
74        }
75    }
76
77    /// Get global device manager instance
78    pub fn global() -> Result<&'static DeviceManager> {
79        unsafe {
80            ptr::addr_of!(DEVICE_MANAGER)
81                .as_ref()
82                .unwrap()
83                .as_ref()
84                .ok_or_else(|| {
85                    MemvidError::MachineLearning("Device manager not initialized".to_string())
86                })
87        }
88    }
89
90    /// Create new device manager with automatic device detection
91    fn new() -> Result<Self> {
92        let mut available_devices = Vec::new();
93
94        // Detect CPU
95        let cpu_device = DeviceInfo {
96            device_type: DeviceType::Cpu,
97            device: Device::Cpu,
98            name: "CPU".to_string(),
99            compute_score: 1.0, // Base score
100            memory_bytes: Self::estimate_system_memory(),
101        };
102        available_devices.push(cpu_device);
103
104        // Detect CUDA devices
105        #[cfg(feature = "cuda")]
106        {
107            for device_id in 0..8 {
108                // Check up to 8 CUDA devices
109                if let Ok(device) = Device::cuda_if_available(device_id) {
110                    let device_info = DeviceInfo {
111                        device_type: DeviceType::Cuda(device_id),
112                        device,
113                        name: format!("CUDA GPU {}", device_id),
114                        compute_score: 10.0 + device_id as f32, // Higher score for GPUs
115                        memory_bytes: Self::estimate_gpu_memory(device_id),
116                    };
117                    available_devices.push(device_info);
118                    log::info!("Detected CUDA device {}", device_id);
119                }
120            }
121        }
122
123        // Detect Metal device (macOS)
124        #[cfg(feature = "metal")]
125        {
126            if let Ok(device) = Device::new_metal(0) {
127                let device_info = DeviceInfo {
128                    device_type: DeviceType::Metal,
129                    device,
130                    name: "Metal GPU".to_string(),
131                    compute_score: 15.0, // High score for Metal
132                    memory_bytes: Self::estimate_metal_memory(),
133                };
134                available_devices.push(device_info);
135                log::info!("Detected Metal GPU");
136            }
137        }
138
139        // Select optimal device (highest compute score)
140        let current_device = available_devices
141            .iter()
142            .max_by(|a, b| a.compute_score.partial_cmp(&b.compute_score).unwrap())
143            .cloned()
144            .ok_or_else(|| MemvidError::MachineLearning("No devices available".to_string()))?;
145
146        log::info!("Selected optimal device: {}", current_device.name);
147
148        Ok(Self {
149            current_device,
150            available_devices,
151        })
152    }
153
154    /// Get current optimal device
155    pub fn current_device(&self) -> &DeviceInfo {
156        &self.current_device
157    }
158
159    /// Get all available devices
160    pub fn available_devices(&self) -> &[DeviceInfo] {
161        &self.available_devices
162    }
163
164    /// Get device by type
165    pub fn get_device(&self, device_type: &DeviceType) -> Option<&DeviceInfo> {
166        self.available_devices
167            .iter()
168            .find(|d| d.device_type == *device_type)
169    }
170
171    /// Switch to a specific device type
172    pub fn switch_device(&mut self, device_type: DeviceType) -> Result<()> {
173        if let Some(device_info) = self
174            .available_devices
175            .iter()
176            .find(|d| d.device_type == device_type)
177            .cloned()
178        {
179            self.current_device = device_info;
180            log::info!("Switched to device: {}", self.current_device.name);
181            Ok(())
182        } else {
183            Err(MemvidError::MachineLearning(format!(
184                "Device type {:?} not available",
185                device_type
186            )))
187        }
188    }
189
190    /// Get optimal batch size for current device
191    pub fn optimal_batch_size(&self, base_batch_size: usize) -> usize {
192        match self.current_device.device_type {
193            DeviceType::Cpu => base_batch_size.min(32), // Conservative for CPU
194            DeviceType::Cuda(_) => base_batch_size * 2, // Can handle larger batches
195            DeviceType::Metal => base_batch_size.max(16), // Good performance on Metal
196        }
197    }
198
199    /// Check if device supports half precision
200    pub fn supports_half_precision(&self) -> bool {
201        matches!(
202            self.current_device.device_type,
203            DeviceType::Cuda(_) | DeviceType::Metal
204        )
205    }
206
207    /// Estimate system memory
208    fn estimate_system_memory() -> Option<u64> {
209        // Simple heuristic - could be improved with platform-specific APIs
210        Some(8 * 1024 * 1024 * 1024) // Assume 8GB as conservative estimate
211    }
212
213    /// Estimate GPU memory for CUDA device
214    #[cfg(feature = "cuda")]
215    fn estimate_gpu_memory(_device_id: usize) -> Option<u64> {
216        // Would use CUDA APIs in production
217        Some(4 * 1024 * 1024 * 1024) // Assume 4GB as conservative estimate
218    }
219
220    /// Estimate Metal GPU memory
221    #[cfg(feature = "metal")]
222    fn estimate_metal_memory() -> Option<u64> {
223        // Would use Metal APIs in production
224        Some(8 * 1024 * 1024 * 1024) // Assume 8GB unified memory
225    }
226
227    #[cfg(not(feature = "metal"))]
228    fn estimate_metal_memory() -> Option<u64> {
229        None
230    }
231}
232
233/// Initialize device system
234pub fn initialize() -> Result<()> {
235    DeviceManager::initialize()?;
236    Ok(())
237}
238
239/// Get current optimal device
240pub fn current_device() -> Result<&'static DeviceInfo> {
241    Ok(DeviceManager::global()?.current_device())
242}
243
244/// Get all available devices
245pub fn available_devices() -> Result<&'static [DeviceInfo]> {
246    Ok(DeviceManager::global()?.available_devices())
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Detection must report a non-empty device list that contains the CPU.
    #[test]
    fn test_device_manager_initialization() {
        let manager = DeviceManager::initialize().unwrap();
        let devices = manager.available_devices();
        assert!(!devices.is_empty());

        // Should always have at least CPU
        let has_cpu = devices
            .iter()
            .any(|d| matches!(d.device_type, DeviceType::Cpu));
        assert!(has_cpu);
    }

    /// The selected device must carry a name and a positive score.
    #[test]
    fn test_device_selection() {
        let manager = DeviceManager::initialize().unwrap();
        let selected = manager.current_device();

        // Should select a valid device
        assert!(!selected.name.is_empty());
        assert!(selected.compute_score > 0.0);
    }

    /// The tuned batch size must be positive and within a sane multiple of the base.
    #[test]
    fn test_batch_size_optimization() {
        let base_size = 16;
        let optimal = DeviceManager::initialize()
            .unwrap()
            .optimal_batch_size(base_size);

        assert!(optimal > 0);
        assert!(optimal <= base_size * 4); // Reasonable upper bound
    }
}