1use crate::error::{MemvidError, Result};
7use candle_core::Device;
8use serde::{Deserialize, Serialize};
9use std::ptr;
10use std::sync::Once;
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
14pub enum DeviceType {
15 Cpu,
17 Cuda(usize),
19 Metal,
21}
22
/// Description of a single compute device known to the [`DeviceManager`].
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Which backend this entry belongs to (CPU, CUDA ordinal, or Metal).
    pub device_type: DeviceType,
    /// The candle device handle for this entry.
    pub device: Device,
    /// Human-readable name; this is what appears in the manager's log messages.
    pub name: String,
    /// Relative ranking value; the manager picks the device with the highest
    /// score as its initial current device.
    pub compute_score: f32,
    /// Memory size in bytes, if known. The values filled in by the manager are
    /// fixed estimates, not measurements.
    pub memory_bytes: Option<u64>,
}
37
/// Process-wide manager slot. It is assigned only inside the `call_once`
/// closure in `DeviceManager::initialize` and stays `None` when construction
/// of the manager failed.
static mut DEVICE_MANAGER: Option<DeviceManager> = None;
/// Ensures the initialization closure for `DEVICE_MANAGER` runs at most once.
static DEVICE_MANAGER_INIT: Once = Once::new();
41
/// Registry of detected compute devices plus the one currently selected.
pub struct DeviceManager {
    /// The device currently in use; initially the entry with the highest
    /// `compute_score`, replaceable via `switch_device`.
    current_device: DeviceInfo,
    /// Every device detected at construction time; the CPU entry is always present.
    available_devices: Vec<DeviceInfo>,
}
49
50impl DeviceManager {
51 pub fn initialize() -> Result<&'static DeviceManager> {
53 unsafe {
54 DEVICE_MANAGER_INIT.call_once(|| match Self::new() {
55 Ok(manager) => {
56 log::info!(
57 "Initialized device manager with optimal device: {}",
58 manager.current_device.name
59 );
60 DEVICE_MANAGER = Some(manager);
61 }
62 Err(e) => {
63 log::error!("Failed to initialize device manager: {}", e);
64 }
65 });
66
67 ptr::addr_of!(DEVICE_MANAGER)
68 .as_ref()
69 .unwrap()
70 .as_ref()
71 .ok_or_else(|| {
72 MemvidError::MachineLearning("Device manager initialization failed".to_string())
73 })
74 }
75 }
76
77 pub fn global() -> Result<&'static DeviceManager> {
79 unsafe {
80 ptr::addr_of!(DEVICE_MANAGER)
81 .as_ref()
82 .unwrap()
83 .as_ref()
84 .ok_or_else(|| {
85 MemvidError::MachineLearning("Device manager not initialized".to_string())
86 })
87 }
88 }
89
90 fn new() -> Result<Self> {
92 let mut available_devices = Vec::new();
93
94 let cpu_device = DeviceInfo {
96 device_type: DeviceType::Cpu,
97 device: Device::Cpu,
98 name: "CPU".to_string(),
99 compute_score: 1.0, memory_bytes: Self::estimate_system_memory(),
101 };
102 available_devices.push(cpu_device);
103
104 #[cfg(feature = "cuda")]
106 {
107 for device_id in 0..8 {
108 if let Ok(device) = Device::cuda_if_available(device_id) {
110 let device_info = DeviceInfo {
111 device_type: DeviceType::Cuda(device_id),
112 device,
113 name: format!("CUDA GPU {}", device_id),
114 compute_score: 10.0 + device_id as f32, memory_bytes: Self::estimate_gpu_memory(device_id),
116 };
117 available_devices.push(device_info);
118 log::info!("Detected CUDA device {}", device_id);
119 }
120 }
121 }
122
123 #[cfg(feature = "metal")]
125 {
126 if let Ok(device) = Device::new_metal(0) {
127 let device_info = DeviceInfo {
128 device_type: DeviceType::Metal,
129 device,
130 name: "Metal GPU".to_string(),
131 compute_score: 15.0, memory_bytes: Self::estimate_metal_memory(),
133 };
134 available_devices.push(device_info);
135 log::info!("Detected Metal GPU");
136 }
137 }
138
139 let current_device = available_devices
141 .iter()
142 .max_by(|a, b| a.compute_score.partial_cmp(&b.compute_score).unwrap())
143 .cloned()
144 .ok_or_else(|| MemvidError::MachineLearning("No devices available".to_string()))?;
145
146 log::info!("Selected optimal device: {}", current_device.name);
147
148 Ok(Self {
149 current_device,
150 available_devices,
151 })
152 }
153
154 pub fn current_device(&self) -> &DeviceInfo {
156 &self.current_device
157 }
158
159 pub fn available_devices(&self) -> &[DeviceInfo] {
161 &self.available_devices
162 }
163
164 pub fn get_device(&self, device_type: &DeviceType) -> Option<&DeviceInfo> {
166 self.available_devices
167 .iter()
168 .find(|d| d.device_type == *device_type)
169 }
170
171 pub fn switch_device(&mut self, device_type: DeviceType) -> Result<()> {
173 if let Some(device_info) = self
174 .available_devices
175 .iter()
176 .find(|d| d.device_type == device_type)
177 .cloned()
178 {
179 self.current_device = device_info;
180 log::info!("Switched to device: {}", self.current_device.name);
181 Ok(())
182 } else {
183 Err(MemvidError::MachineLearning(format!(
184 "Device type {:?} not available",
185 device_type
186 )))
187 }
188 }
189
190 pub fn optimal_batch_size(&self, base_batch_size: usize) -> usize {
192 match self.current_device.device_type {
193 DeviceType::Cpu => base_batch_size.min(32), DeviceType::Cuda(_) => base_batch_size * 2, DeviceType::Metal => base_batch_size.max(16), }
197 }
198
199 pub fn supports_half_precision(&self) -> bool {
201 matches!(
202 self.current_device.device_type,
203 DeviceType::Cuda(_) | DeviceType::Metal
204 )
205 }
206
207 fn estimate_system_memory() -> Option<u64> {
209 Some(8 * 1024 * 1024 * 1024) }
212
213 #[cfg(feature = "cuda")]
215 fn estimate_gpu_memory(_device_id: usize) -> Option<u64> {
216 Some(4 * 1024 * 1024 * 1024) }
219
220 #[cfg(feature = "metal")]
222 fn estimate_metal_memory() -> Option<u64> {
223 Some(8 * 1024 * 1024 * 1024) }
226
227 #[cfg(not(feature = "metal"))]
228 fn estimate_metal_memory() -> Option<u64> {
229 None
230 }
231}
232
233pub fn initialize() -> Result<()> {
235 DeviceManager::initialize()?;
236 Ok(())
237}
238
239pub fn current_device() -> Result<&'static DeviceInfo> {
241 Ok(DeviceManager::global()?.current_device())
242}
243
244pub fn available_devices() -> Result<&'static [DeviceInfo]> {
246 Ok(DeviceManager::global()?.available_devices())
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Every test works against the same global manager instance.
    fn manager() -> &'static DeviceManager {
        DeviceManager::initialize().unwrap()
    }

    #[test]
    fn test_device_manager_initialization() {
        let devices = manager().available_devices();
        assert!(!devices.is_empty());

        // The CPU entry is registered unconditionally, so it must be listed.
        let has_cpu = devices.iter().any(|d| d.device_type == DeviceType::Cpu);
        assert!(has_cpu);
    }

    #[test]
    fn test_device_selection() {
        let DeviceInfo {
            name,
            compute_score,
            ..
        } = manager().current_device();

        assert!(!name.is_empty());
        assert!(*compute_score > 0.0);
    }

    #[test]
    fn test_batch_size_optimization() {
        let base_size = 16;
        let optimal = manager().optimal_batch_size(base_size);

        // Non-zero and no more than four times the requested size.
        assert!((1..=base_size * 4).contains(&optimal));
    }
}