use crate::error::RusTorchResult;
use crate::tensor::Tensor;
use std::any::Any;
use std::collections::HashMap;
/// Identifies a compute device.
///
/// The `usize` payload carried by the non-CPU variants is a device number;
/// how it maps to physical hardware is not defined in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// The host CPU.
    Cpu,
    /// A CUDA device with the given number.
    Cuda(usize),
    /// A Metal device with the given number.
    Metal(usize),
    /// An OpenCL device with the given number.
    OpenCL(usize),
}
/// Direction of a memory transfer between host and device.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
    /// Data moves from the host to the device.
    HostToDevice,
    /// Data moves from the device back to the host.
    DeviceToHost,
}
/// A tensor operation, bundled with its operands, that can be handed to a
/// compute backend for execution.
#[derive(Debug, Clone)]
pub enum Operation<T: num_traits::Float> {
    /// Addition of `a` and `b`.
    Add {
        /// Left operand.
        a: Tensor<T>,
        /// Right operand.
        b: Tensor<T>,
    },
    /// Multiplication of `a` and `b`.
    // NOTE(review): presumably element-wise, as opposed to `MatMul` — confirm
    // against the backend implementation.
    Multiply {
        /// Left operand.
        a: Tensor<T>,
        /// Right operand.
        b: Tensor<T>,
    },
    /// Matrix multiplication of `a` and `b`.
    MatMul {
        /// Left operand.
        a: Tensor<T>,
        /// Right operand.
        b: Tensor<T>,
    },
    /// 2D convolution of `input` with `kernel`.
    Conv2D {
        /// Tensor being convolved.
        input: Tensor<T>,
        /// Convolution kernel.
        kernel: Tensor<T>,
        /// Stride as a pair of values.
        // NOTE(review): axis order of the pair is not defined here — confirm.
        stride: (usize, usize),
        /// Padding as a pair of values.
        // NOTE(review): axis order of the pair is not defined here — confirm.
        padding: (usize, usize),
    },
    /// Reduction of `input` using `operation`.
    Reduce {
        /// Tensor being reduced.
        input: Tensor<T>,
        /// Kind of reduction to apply.
        operation: ReduceOp,
        /// Axes to reduce over, if specified.
        // NOTE(review): the meaning of `None` is decided by the backend.
        axes: Option<Vec<usize>>,
    },
}
/// Kind of reduction carried by [`Operation::Reduce`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum reduction.
    Sum,
    /// Mean reduction.
    Mean,
    /// Maximum reduction.
    Max,
    /// Minimum reduction.
    Min,
    /// Product reduction.
    Product,
}
/// Performance figures reported for a backend.
#[derive(Debug, Clone)]
pub struct PerformanceMetrics {
    /// Execution time in nanoseconds.
    pub execution_time_ns: u64,
    /// Memory bandwidth in gigabytes per second.
    // NOTE(review): unit taken from the `_gbps` suffix — confirm GB/s vs Gbit/s.
    pub memory_bandwidth_gbps: f64,
    /// Device utilization.
    // NOTE(review): scale (0.0–1.0 vs percent) is not defined here — confirm.
    pub device_utilization: f64,
    /// Memory usage in bytes.
    pub memory_usage_bytes: usize,
}
/// Element-type-generic part of a compute backend.
///
/// These methods are generic over `T`; the remaining backend methods live in
/// the separate [`ComputeBackend`] trait, which is the one used as
/// `Box<dyn ComputeBackend>`.
pub trait ComputeBackendGeneric {
    /// Executes `operation` and returns the resulting tensor.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementing backend reports.
    fn execute_operation<T>(&self, operation: &Operation<T>) -> RusTorchResult<Tensor<T>>
    where
        T: Clone + Send + Sync + 'static + num_traits::Float;

    /// Transfers `data` in the given `direction` and returns the transferred
    /// elements as a new vector.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementing backend reports.
    fn memory_transfer<T>(
        &self,
        data: &[T],
        direction: TransferDirection,
    ) -> RusTorchResult<Vec<T>>
    where
        T: Clone + Send + Sync + 'static + num_traits::Float;
}
/// Non-generic backend interface: identification, lifecycle, metrics and
/// configuration.
///
/// Implementors must be `Send + Sync`.
pub trait ComputeBackend: Send + Sync {
    /// Returns the device this backend runs on.
    fn device_type(&self) -> DeviceType;

    /// Reports whether the backend is available.
    fn is_available(&self) -> bool;

    /// Initializes the backend.
    ///
    /// # Errors
    ///
    /// Returns an error if initialization fails.
    fn initialize(&mut self) -> RusTorchResult<()>;

    /// Returns the backend's current performance metrics.
    fn get_metrics(&self) -> PerformanceMetrics;

    /// Returns backend information as a map from string keys to type-erased
    /// values; the keys and value types are backend-defined.
    fn get_info(&self) -> HashMap<String, Box<dyn Any + Send + Sync>>;

    /// Synchronizes the backend.
    // NOTE(review): presumably waits for outstanding device work — confirm
    // against implementations.
    ///
    /// # Errors
    ///
    /// Returns an error if synchronization fails.
    fn synchronize(&self) -> RusTorchResult<()>;

    /// Returns the amount of available memory.
    // NOTE(review): presumably in bytes — confirm against implementations.
    ///
    /// # Errors
    ///
    /// Returns an error if the amount cannot be determined.
    fn available_memory(&self) -> RusTorchResult<usize>;

    /// Sets the configuration entry `key` to the type-erased `value`.
    ///
    /// # Errors
    ///
    /// Returns an error if the backend rejects the setting.
    fn set_config(&mut self, key: &str, value: Box<dyn Any + Send + Sync>) -> RusTorchResult<()>;
}
/// Policy for choosing a device.
///
/// The strategy is stored on `DeviceManager`, but
/// `DeviceManager::select_backend` in this file does not consult it and
/// always selects the CPU.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SelectionStrategy {
    /// Performance-oriented selection.
    Performance,
    /// Memory-oriented selection.
    Memory,
    /// Balanced selection; this is the strategy used by
    /// `DeviceManager::default()`.
    Balanced,
    /// Caller-supplied list of devices.
    // NOTE(review): presumably in descending priority order — confirm.
    Manual(Vec<DeviceType>),
}
/// Owns the registered compute backend and records per-device performance
/// history.
pub struct DeviceManager {
    /// The CPU backend, once one has been registered; `None` until then.
    cpu_backend: Option<crate::backends::cpu_unified::CpuBackend>,
    /// Selection strategy the manager was created with.
    pub strategy: SelectionStrategy,
    /// Metric samples recorded per device, oldest first.
    performance_history: HashMap<DeviceType, Vec<PerformanceMetrics>>,
}
impl DeviceManager {
pub fn new(strategy: SelectionStrategy) -> Self {
Self {
cpu_backend: None,
strategy,
performance_history: HashMap::new(),
}
}
pub fn register_backend(&mut self, backend: Box<dyn ComputeBackend>) -> RusTorchResult<()> {
let device_type = backend.device_type();
if device_type == DeviceType::Cpu && backend.is_available() {
let cpu_backend = match backend.device_type() {
DeviceType::Cpu => {
crate::backends::cpu_unified::CpuBackend::new()?
}
_ => {
return Err(crate::error::RusTorchError::DeviceNotAvailable(
"Only CPU backend is currently supported".to_string(),
))
}
};
self.cpu_backend = Some(cpu_backend);
self.performance_history.insert(device_type, Vec::new());
Ok(())
} else {
Err(crate::error::RusTorchError::DeviceNotAvailable(format!(
"Backend for device {:?} is not available",
device_type
)))
}
}
pub fn available_devices(&self) -> Vec<DeviceType> {
if self.cpu_backend.is_some() {
vec![DeviceType::Cpu]
} else {
vec![]
}
}
pub fn select_backend<T: num_traits::Float>(
&self,
_operation: &Operation<T>,
) -> RusTorchResult<DeviceType> {
if self.cpu_backend.is_none() {
return Err(crate::error::RusTorchError::DeviceNotAvailable(
"No compute backends available".to_string(),
));
}
Ok(DeviceType::Cpu)
}
pub fn execute_operation<T>(&mut self, operation: &Operation<T>) -> RusTorchResult<Tensor<T>>
where
T: Clone + Send + Sync + 'static + num_traits::Float,
{
let selected_device = self.select_backend(operation)?;
if let Some(ref backend) = self.cpu_backend {
let start_time = std::time::Instant::now();
let result = backend.execute_operation(operation)?;
let execution_time = start_time.elapsed();
let mut metrics = backend.get_metrics();
metrics.execution_time_ns = execution_time.as_nanos() as u64;
self.performance_history
.get_mut(&selected_device)
.unwrap()
.push(metrics);
let history = self.performance_history.get_mut(&selected_device).unwrap();
if history.len() > 100 {
history.drain(0..history.len() - 100);
}
Ok(result)
} else {
Err(crate::error::RusTorchError::DeviceNotAvailable(
"CPU backend not available".to_string(),
))
}
}
pub fn get_device_stats(&self, device: DeviceType) -> Option<PerformanceMetrics> {
self.performance_history.get(&device).and_then(|history| {
if history.is_empty() {
None
} else {
let avg_time =
history.iter().map(|m| m.execution_time_ns).sum::<u64>() / history.len() as u64;
let avg_bandwidth = history.iter().map(|m| m.memory_bandwidth_gbps).sum::<f64>()
/ history.len() as f64;
let avg_utilization = history.iter().map(|m| m.device_utilization).sum::<f64>()
/ history.len() as f64;
let avg_memory =
history.iter().map(|m| m.memory_usage_bytes).sum::<usize>() / history.len();
Some(PerformanceMetrics {
execution_time_ns: avg_time,
memory_bandwidth_gbps: avg_bandwidth,
device_utilization: avg_utilization,
memory_usage_bytes: avg_memory,
})
}
})
}
#[cfg(feature = "coreml")]
pub fn is_coreml_available() -> bool {
cfg!(target_os = "macos")
}
#[cfg(not(feature = "coreml"))]
pub fn is_coreml_available() -> bool {
false
}
#[cfg(feature = "metal")]
pub fn is_metal_available() -> bool {
cfg!(target_os = "macos")
}
#[cfg(not(feature = "metal"))]
pub fn is_metal_available() -> bool {
false
}
#[cfg(feature = "cuda")]
pub fn is_cuda_available() -> bool {
!cfg!(target_os = "macos")
}
#[cfg(not(feature = "cuda"))]
pub fn is_cuda_available() -> bool {
false
}
}
impl Default for DeviceManager {
fn default() -> Self {
Self::new(SelectionStrategy::Balanced)
}
}
// Process-wide device manager, built from `DeviceManager::default()` by
// `lazy_static` and wrapped in an `RwLock`. It starts with no backend
// registered. Access it through `global_device_manager()`.
lazy_static::lazy_static! {
static ref GLOBAL_DEVICE_MANAGER: std::sync::RwLock<DeviceManager> =
std::sync::RwLock::new(DeviceManager::default());
}
/// Returns the lock guarding the process-wide [`DeviceManager`].
///
/// Callers take `.read()` or `.write()` on the returned lock to use the
/// manager.
pub fn global_device_manager() -> &'static std::sync::RwLock<DeviceManager> {
    let lock: &'static std::sync::RwLock<DeviceManager> = &GLOBAL_DEVICE_MANAGER;
    lock
}
/// Creates a CPU backend and registers it with the global device manager.
///
/// The write lock is taken before the backend is constructed and is held
/// until registration has finished.
///
/// # Errors
///
/// Propagates any error from `CpuBackend::new` or from
/// `DeviceManager::register_backend`.
///
/// # Panics
///
/// Panics if the global manager's lock is poisoned.
pub fn initialize_backends() -> RusTorchResult<()> {
    let mut guard = global_device_manager().write().unwrap();
    let backend = Box::new(crate::backends::cpu_unified::CpuBackend::new()?);
    guard.register_backend(backend)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built manager keeps the requested strategy and has no
    /// backend registered.
    #[test]
    fn test_device_manager_creation() {
        let manager = DeviceManager::new(SelectionStrategy::Performance);

        assert_eq!(manager.strategy, SelectionStrategy::Performance);
        assert!(manager.cpu_backend.is_none());
    }

    /// `SelectionStrategy::Manual` carries its device list unchanged.
    #[test]
    fn test_selection_strategy_manual() {
        let priorities = vec![DeviceType::Cuda(0), DeviceType::Cpu];
        let strategy = SelectionStrategy::Manual(priorities.clone());

        match strategy {
            SelectionStrategy::Manual(ref manual_priorities) => {
                assert_eq!(manual_priorities, &priorities);
            }
            _ => panic!("Strategy should be Manual"),
        }
    }
}