use crate::error::{MemvidError, Result};
use candle_core::Device;
use serde::{Deserialize, Serialize};
use std::ptr;
use std::sync::Once;
/// Kind of compute device described by a [`DeviceInfo`].
///
/// The type is small and `Copy`, and it implements `Eq`/`Hash`, so it can be
/// passed by value and used as a map key.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// The host CPU.
    Cpu,
    /// A CUDA GPU, identified by its device ordinal.
    Cuda(usize),
    /// A Metal GPU.
    Metal,
}
/// Description of one detected compute device.
#[derive(Debug, Clone)]
pub struct DeviceInfo {
    /// Kind of device (CPU, CUDA ordinal, or Metal).
    pub device_type: DeviceType,
    /// Candle `Device` handle for this device.
    pub device: Device,
    /// Human-readable name, e.g. `"CPU"` or `"CUDA GPU 0"`.
    pub name: String,
    /// Relative preference score; `DeviceManager` selects the device with the
    /// highest score when it is constructed.
    pub compute_score: f32,
    /// Estimated memory in bytes, if known. The estimates produced in this
    /// module are fixed placeholder values, not measurements.
    pub memory_bytes: Option<u64>,
}
// Process-wide device manager slot. In this module it is written only inside
// `DEVICE_MANAGER_INIT.call_once` in `DeviceManager::initialize`, and it stays
// `None` if device detection failed there.
// NOTE(review): reading a `static mut` is only sound once that write is known to
// have finished (i.e. the `Once` has completed); `std::sync::OnceLock` would
// encode this invariant in the type system if this is ever reworked.
static mut DEVICE_MANAGER: Option<DeviceManager> = None;
// Guards the one-time initialization of `DEVICE_MANAGER`.
static DEVICE_MANAGER_INIT: Once = Once::new();
/// Registry of detected compute devices plus the currently selected one.
///
/// The shared instance is obtained through `DeviceManager::initialize` or
/// `DeviceManager::global`.
pub struct DeviceManager {
    /// Device currently used for computation; initially the highest-scoring
    /// detected device, changeable via `switch_device`.
    current_device: DeviceInfo,
    /// Every device detected at construction time; the CPU is always present.
    available_devices: Vec<DeviceInfo>,
}
impl DeviceManager {
    /// Initializes the process-wide device manager (at most once) and returns it.
    ///
    /// Device detection runs on the first call only; later calls return the
    /// same instance. Initialization is not retried: if detection failed on the
    /// first call, this and every later call return an error.
    ///
    /// # Errors
    ///
    /// Returns [`MemvidError::MachineLearning`] if device detection failed.
    ///
    /// # Panics
    ///
    /// A panic during detection propagates. It poisons the underlying `Once`, so
    /// later calls panic as well.
    pub fn initialize() -> Result<&'static DeviceManager> {
        DEVICE_MANAGER_INIT.call_once(|| match Self::new() {
            Ok(manager) => {
                log::info!(
                    "Initialized device manager with optimal device: {}",
                    manager.current_device.name
                );
                // SAFETY: this is the only write to `DEVICE_MANAGER` in this module
                // and it runs inside `call_once`. It therefore executes at most once,
                // and it finishes before any reader sees the `Once` as completed.
                unsafe {
                    DEVICE_MANAGER = Some(manager);
                }
            }
            Err(e) => {
                log::error!("Failed to initialize device manager: {}", e);
            }
        });
        Self::stored_instance("Device manager initialization failed")
    }

    /// Returns the global device manager without initializing it.
    ///
    /// # Errors
    ///
    /// Returns [`MemvidError::MachineLearning`] if [`DeviceManager::initialize`]
    /// has not completed yet (including while it is still running on another
    /// thread) or if initialization failed.
    pub fn global() -> Result<&'static DeviceManager> {
        Self::stored_instance("Device manager not initialized")
    }

    /// Reads the global slot once initialization has completed; otherwise it
    /// returns an error carrying `missing_msg`.
    fn stored_instance(missing_msg: &str) -> Result<&'static DeviceManager> {
        let missing = || MemvidError::MachineLearning(missing_msg.to_string());
        // Without this check, a concurrent `initialize` could be writing the slot
        // while it is being read here, which would be a data race.
        if !DEVICE_MANAGER_INIT.is_completed() {
            return Err(missing());
        }
        // SAFETY: `is_completed()` returned true. The single write performed inside
        // `call_once` has therefore finished and is visible to this thread, and no
        // write happens afterwards, so handing out a shared reference is sound.
        let slot: &'static Option<DeviceManager> = unsafe { &*ptr::addr_of!(DEVICE_MANAGER) };
        slot.as_ref().ok_or_else(missing)
    }

    /// Detects available compute devices and selects the highest-scoring one.
    ///
    /// The CPU is always registered. When the `cuda` feature is enabled, CUDA
    /// ordinals 0 through 7 are probed. When the `metal` feature is enabled,
    /// Metal device 0 is probed.
    fn new() -> Result<Self> {
        let mut available_devices = vec![DeviceInfo {
            device_type: DeviceType::Cpu,
            device: Device::Cpu,
            name: "CPU".to_string(),
            compute_score: 1.0,
            memory_bytes: Self::estimate_system_memory(),
        }];

        #[cfg(feature = "cuda")]
        {
            for device_id in 0..8 {
                // `Device::cuda_if_available` silently falls back to `Device::Cpu`
                // when CUDA is unavailable, which would register CPU devices
                // labelled as CUDA GPUs. `new_cuda` reports an error instead, so
                // only real GPUs are recorded.
                if let Ok(device) = Device::new_cuda(device_id) {
                    available_devices.push(DeviceInfo {
                        device_type: DeviceType::Cuda(device_id),
                        device,
                        name: format!("CUDA GPU {}", device_id),
                        // NOTE(review): the score grows with the ordinal, so the
                        // highest-numbered GPU is preferred — confirm this is intended.
                        compute_score: 10.0 + device_id as f32,
                        memory_bytes: Self::estimate_gpu_memory(device_id),
                    });
                    log::info!("Detected CUDA device {}", device_id);
                }
            }
        }

        #[cfg(feature = "metal")]
        {
            if let Ok(device) = Device::new_metal(0) {
                available_devices.push(DeviceInfo {
                    device_type: DeviceType::Metal,
                    device,
                    name: "Metal GPU".to_string(),
                    compute_score: 15.0,
                    memory_bytes: Self::estimate_metal_memory(),
                });
                log::info!("Detected Metal GPU");
            }
        }

        // `total_cmp` gives a total order on f32, so selection cannot panic the
        // way `partial_cmp(..).unwrap()` would on a NaN score.
        let current_device = available_devices
            .iter()
            .max_by(|a, b| a.compute_score.total_cmp(&b.compute_score))
            .cloned()
            .ok_or_else(|| MemvidError::MachineLearning("No devices available".to_string()))?;
        log::info!("Selected optimal device: {}", current_device.name);

        Ok(Self {
            current_device,
            available_devices,
        })
    }

    /// Returns the device currently selected for computation.
    pub fn current_device(&self) -> &DeviceInfo {
        &self.current_device
    }

    /// Returns every device detected at construction time.
    pub fn available_devices(&self) -> &[DeviceInfo] {
        &self.available_devices
    }

    /// Looks up a detected device by its type, returning `None` if it was not found.
    pub fn get_device(&self, device_type: &DeviceType) -> Option<&DeviceInfo> {
        self.available_devices
            .iter()
            .find(|d| d.device_type == *device_type)
    }

    /// Makes `device_type` the current device.
    ///
    /// # Errors
    ///
    /// Returns [`MemvidError::MachineLearning`] if no detected device has that
    /// type; the current device is left unchanged in that case.
    pub fn switch_device(&mut self, device_type: DeviceType) -> Result<()> {
        let Some(device_info) = self.get_device(&device_type).cloned() else {
            return Err(MemvidError::MachineLearning(format!(
                "Device type {:?} not available",
                device_type
            )));
        };
        self.current_device = device_info;
        log::info!("Switched to device: {}", self.current_device.name);
        Ok(())
    }

    /// Adjusts a requested batch size for the current device.
    ///
    /// The CPU is capped at 32, CUDA doubles the size (saturating at
    /// `usize::MAX`), and Metal raises it to at least 16.
    pub fn optimal_batch_size(&self, base_batch_size: usize) -> usize {
        match self.current_device.device_type {
            DeviceType::Cpu => base_batch_size.min(32),
            // Saturate rather than overflow, which would panic in debug builds.
            DeviceType::Cuda(_) => base_batch_size.saturating_mul(2),
            DeviceType::Metal => base_batch_size.max(16),
        }
    }

    /// Returns `true` when the current device is a CUDA or Metal GPU.
    pub fn supports_half_precision(&self) -> bool {
        matches!(
            self.current_device.device_type,
            DeviceType::Cuda(_) | DeviceType::Metal
        )
    }

    /// Placeholder system-memory estimate: a fixed 8 GiB, not measured.
    fn estimate_system_memory() -> Option<u64> {
        Some(8 * 1024 * 1024 * 1024)
    }

    /// Placeholder CUDA memory estimate: a fixed 4 GiB for every ordinal, not measured.
    #[cfg(feature = "cuda")]
    fn estimate_gpu_memory(_device_id: usize) -> Option<u64> {
        Some(4 * 1024 * 1024 * 1024)
    }

    /// Placeholder Metal memory estimate: a fixed 8 GiB, not measured.
    #[cfg(feature = "metal")]
    fn estimate_metal_memory() -> Option<u64> {
        Some(8 * 1024 * 1024 * 1024)
    }

    /// Fallback when the `metal` feature is disabled: memory is unknown.
    #[cfg(not(feature = "metal"))]
    fn estimate_metal_memory() -> Option<u64> {
        None
    }
}
/// Initializes the global [`DeviceManager`], discarding the returned reference.
///
/// # Errors
///
/// Propagates any error from [`DeviceManager::initialize`].
pub fn initialize() -> Result<()> {
    DeviceManager::initialize().map(|_manager| ())
}
/// Returns the currently selected device of the global [`DeviceManager`].
///
/// # Errors
///
/// Fails if the global manager is not (successfully) initialized.
pub fn current_device() -> Result<&'static DeviceInfo> {
    DeviceManager::global().map(|manager| manager.current_device())
}
/// Returns all devices detected by the global [`DeviceManager`].
///
/// # Errors
///
/// Fails if the global manager is not (successfully) initialized.
pub fn available_devices() -> Result<&'static [DeviceInfo]> {
    DeviceManager::global().map(|manager| manager.available_devices())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_manager_initialization() {
        let manager = DeviceManager::initialize().unwrap();
        assert!(!manager.available_devices().is_empty());
        assert!(
            manager
                .available_devices()
                .iter()
                .any(|d| matches!(d.device_type, DeviceType::Cpu))
        );
    }

    #[test]
    fn test_device_selection() {
        let manager = DeviceManager::initialize().unwrap();
        let current = manager.current_device();
        assert!(!current.name.is_empty());
        assert!(current.compute_score > 0.0);
    }

    #[test]
    fn test_batch_size_optimization() {
        let manager = DeviceManager::initialize().unwrap();
        let base_size = 16;
        let optimal = manager.optimal_batch_size(base_size);
        assert!(optimal > 0);
        assert!(optimal <= base_size * 4);
    }

    /// `global()` must hand out the very instance created by `initialize()`.
    #[test]
    fn test_global_returns_initialized_instance() {
        let initialized = DeviceManager::initialize().unwrap();
        let global = DeviceManager::global().unwrap();
        assert!(std::ptr::eq(initialized, global));
    }

    /// The selected device must be one of the detected devices.
    #[test]
    fn test_current_device_is_among_available() {
        let manager = DeviceManager::initialize().unwrap();
        let current = &manager.current_device().device_type;
        assert!(manager.get_device(current).is_some());
    }

    /// Switching succeeds for a detected device and fails without side effects
    /// for an unknown one.
    #[test]
    fn test_switch_device() {
        let mut manager = DeviceManager::new().unwrap();
        manager.switch_device(DeviceType::Cpu).unwrap();
        assert_eq!(manager.current_device().device_type, DeviceType::Cpu);

        assert!(manager.switch_device(DeviceType::Cuda(usize::MAX)).is_err());
        assert_eq!(manager.current_device().device_type, DeviceType::Cpu);
    }
}