use crate::device::{DeviceCapabilities, DeviceType};
use crate::error::Result;
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::{Arc, Mutex};
/// Common interface for a compute device backend.
///
/// The `Send + Sync + 'static` bounds allow devices to be shared across
/// threads, e.g. as the `Arc<dyn Device>` values cached by `DeviceRegistry`.
pub trait Device: Debug + Send + Sync + 'static {
    /// Returns the type (and, for accelerators, the index) of this device.
    fn device_type(&self) -> DeviceType;

    /// Returns a human-readable name for this device.
    fn name(&self) -> &str;

    /// Reports whether the device can currently be used.
    fn is_available(&self) -> Result<bool>;

    /// Queries the device's capabilities (memory sizes, feature support, score).
    fn capabilities(&self) -> Result<DeviceCapabilities>;

    /// Synchronizes with the device. Exact semantics are backend-defined
    /// (presumably waits for outstanding work to complete).
    fn synchronize(&self) -> Result<()>;

    /// Resets the device. Exact semantics are backend-defined.
    fn reset(&self) -> Result<()>;

    /// Upcasts to `Any` so callers can downcast to the concrete backend type.
    fn as_any(&self) -> &dyn Any;

    /// Mutable counterpart of [`Device::as_any`].
    fn as_any_mut(&mut self) -> &mut dyn Any;

    /// Returns a boxed copy of this device.
    fn clone_device(&self) -> Result<Box<dyn Device>>;

    /// Default: two devices are considered the same when their
    /// [`DeviceType`] values compare equal.
    fn is_same_device(&self, other: &dyn Device) -> bool {
        self.device_type() == other.device_type()
    }

    /// Returns a string identifier such as `"cpu"`, `"cuda:0"`, `"metal:1"`
    /// or `"wgpu:0"`.
    fn device_id(&self) -> String {
        match self.device_type() {
            DeviceType::Cpu => "cpu".to_string(),
            DeviceType::Cuda(idx) => format!("cuda:{}", idx),
            DeviceType::Metal(idx) => format!("metal:{}", idx),
            DeviceType::Wgpu(idx) => format!("wgpu:{}", idx),
        }
    }

    /// Returns whether `feature` is supported, as reported by the device's
    /// capabilities.
    ///
    /// # Errors
    /// Propagates any error from [`Device::capabilities`].
    fn supports_feature(&self, feature: &str) -> Result<bool> {
        Ok(self.capabilities()?.supports_feature(feature))
    }

    /// Builds a [`DeviceMemoryInfo`] snapshot from the device's capabilities.
    ///
    /// `used` is `total - available`, saturating at zero. A backend whose
    /// readings report `available > total` therefore cannot trigger an
    /// integer underflow, which would panic in debug builds and produce a
    /// huge bogus value in release builds.
    ///
    /// # Errors
    /// Propagates any error from [`Device::capabilities`].
    fn memory_info(&self) -> Result<DeviceMemoryInfo> {
        let caps = self.capabilities()?;
        let total = caps.total_memory();
        let available = caps.available_memory();
        Ok(DeviceMemoryInfo {
            total,
            available,
            used: total.saturating_sub(available),
        })
    }

    /// Releases backend resources. The default implementation does nothing.
    fn cleanup(&mut self) -> Result<()> {
        Ok(())
    }
}
/// Snapshot of a device's memory usage, expressed in bytes.
#[derive(Debug, Clone, Copy)]
pub struct DeviceMemoryInfo {
    /// Total memory, in bytes.
    pub total: u64,
    /// Memory reported as free, in bytes.
    pub available: u64,
    /// Memory reported as in use, in bytes.
    pub used: u64,
}

impl DeviceMemoryInfo {
    /// Percentage of `total` that is `used`; `0.0` when `total` is zero.
    pub fn utilization_percent(&self) -> f64 {
        Self::percent_of(self.used, self.total)
    }

    /// Percentage of `total` that is `available`; `0.0` when `total` is zero.
    pub fn available_percent(&self) -> f64 {
        Self::percent_of(self.available, self.total)
    }

    /// Returns `true` when utilization strictly exceeds `threshold_percent`.
    pub fn is_memory_pressure(&self, threshold_percent: f64) -> bool {
        threshold_percent < self.utilization_percent()
    }

    /// Computes `part / whole * 100`, guarding against a zero denominator.
    fn percent_of(part: u64, whole: u64) -> f64 {
        match whole {
            0 => 0.0,
            denom => part as f64 / denom as f64 * 100.0,
        }
    }
}

impl std::fmt::Display for DeviceMemoryInfo {
    /// Renders sizes in MB (1024 * 1024 bytes) with one decimal place.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let as_mb = |bytes: u64| bytes as f64 / (1024.0 * 1024.0);
        write!(
            f,
            "Memory(total={:.1}MB, used={:.1}MB, available={:.1}MB, utilization={:.1}%)",
            as_mb(self.total),
            as_mb(self.used),
            as_mb(self.available),
            self.utilization_percent()
        )
    }
}
/// Lifecycle states a device can be in. The legal transitions between them
/// are enforced by `DeviceLifecycle::set_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeviceState {
    Uninitialized,
    Initializing,
    Ready,
    Busy,
    Error,
    Resetting,
    ShuttingDown,
}

impl std::fmt::Display for DeviceState {
    /// Writes a human-readable label for the state.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Uninitialized => "Uninitialized",
            Self::Initializing => "Initializing",
            Self::Ready => "Ready",
            Self::Busy => "Busy",
            Self::Error => "Error",
            Self::Resetting => "Resetting",
            Self::ShuttingDown => "Shutting Down",
        };
        f.write_str(label)
    }
}
#[derive(Debug)]
pub struct DeviceLifecycle {
state: Mutex<DeviceState>,
error_info: Mutex<Option<String>>,
initialization_time: Mutex<Option<std::time::Instant>>,
}
impl DeviceLifecycle {
pub fn new() -> Self {
Self {
state: Mutex::new(DeviceState::Uninitialized),
error_info: Mutex::new(None),
initialization_time: Mutex::new(None),
}
}
pub fn state(&self) -> DeviceState {
*self.state.lock().expect("lock should not be poisoned")
}
pub fn set_state(&self, new_state: DeviceState) -> Result<()> {
let mut state = self.state.lock().expect("lock should not be poisoned");
match (*state, new_state) {
(DeviceState::Uninitialized, DeviceState::Initializing) => {
*self
.initialization_time
.lock()
.expect("lock should not be poisoned") = Some(std::time::Instant::now());
}
(DeviceState::Uninitialized, DeviceState::Ready) => {} (DeviceState::Initializing, DeviceState::Ready) => {}
(DeviceState::Ready, DeviceState::Busy) => {}
(DeviceState::Busy, DeviceState::Ready) => {}
(_, DeviceState::Error) => {} (_, DeviceState::Resetting) => {} (DeviceState::Resetting, DeviceState::Ready) => {}
(DeviceState::Resetting, DeviceState::Uninitialized) => {} (_, DeviceState::ShuttingDown) => {}
(current, target) => {
return Err(crate::error::TorshError::InvalidState(format!(
"Invalid state transition from {:?} to {:?}",
current, target
)));
}
}
*state = new_state;
Ok(())
}
pub fn set_error(&self, error_info: String) -> Result<()> {
*self.error_info.lock().expect("lock should not be poisoned") = Some(error_info);
self.set_state(DeviceState::Error)
}
pub fn error_info(&self) -> Option<String> {
self.error_info
.lock()
.expect("lock should not be poisoned")
.clone()
}
pub fn initialization_time(&self) -> Option<std::time::Duration> {
self.initialization_time
.lock()
.expect("lock should not be poisoned")
.map(|start| start.elapsed())
}
pub fn is_ready(&self) -> bool {
matches!(self.state(), DeviceState::Ready)
}
pub fn is_error(&self) -> bool {
matches!(self.state(), DeviceState::Error)
}
pub fn reset(&self) -> Result<()> {
self.set_state(DeviceState::Resetting)?;
*self.error_info.lock().expect("lock should not be poisoned") = None;
*self
.initialization_time
.lock()
.expect("lock should not be poisoned") = None;
self.set_state(DeviceState::Uninitialized)
}
}
impl Default for DeviceLifecycle {
fn default() -> Self {
Self::new()
}
}
/// Per-device bookkeeping: the device type, its lifecycle, free-form string
/// properties, and type-erased resource handles owned by the context.
#[derive(Debug)]
pub struct DeviceContext {
    device_type: DeviceType,
    lifecycle: DeviceLifecycle,
    properties: Mutex<HashMap<String, String>>,
    resource_handles: Mutex<Vec<Box<dyn Any + Send + Sync>>>,
}

impl DeviceContext {
    /// Creates an empty context for `device_type` with a fresh lifecycle.
    pub fn new(device_type: DeviceType) -> Self {
        let lifecycle = DeviceLifecycle::new();
        let properties = Mutex::new(HashMap::new());
        let resource_handles = Mutex::new(Vec::new());
        Self {
            device_type,
            lifecycle,
            properties,
            resource_handles,
        }
    }

    /// Returns the device type this context was created for.
    pub fn device_type(&self) -> DeviceType {
        self.device_type
    }

    /// Borrows the context's lifecycle state machine.
    pub fn lifecycle(&self) -> &DeviceLifecycle {
        &self.lifecycle
    }

    /// Inserts or overwrites a property.
    pub fn set_property(&self, key: String, value: String) {
        self.props().insert(key, value);
    }

    /// Returns a copy of the property stored under `key`, if any.
    pub fn get_property(&self, key: &str) -> Option<String> {
        self.props().get(key).cloned()
    }

    /// Returns a snapshot copy of all properties.
    pub fn properties(&self) -> HashMap<String, String> {
        self.props().clone()
    }

    /// Stores `resource` in the context; it is dropped on `clear_resources`
    /// or when the context itself is dropped.
    pub fn add_resource<T: Any + Send + Sync + 'static>(&self, resource: T) {
        self.handles().push(Box::new(resource));
    }

    /// Drops every stored resource.
    pub fn clear_resources(&self) {
        self.handles().clear();
    }

    /// Returns how many resources are currently stored.
    pub fn resource_count(&self) -> usize {
        self.handles().len()
    }

    /// Locks the property map.
    fn props(&self) -> std::sync::MutexGuard<'_, HashMap<String, String>> {
        self.properties.lock().expect("lock should not be poisoned")
    }

    /// Locks the resource handle list.
    fn handles(&self) -> std::sync::MutexGuard<'_, Vec<Box<dyn Any + Send + Sync>>> {
        self.resource_handles
            .lock()
            .expect("lock should not be poisoned")
    }
}
/// Creates devices of one or more [`DeviceType`]s; registered with a
/// `DeviceRegistry`.
pub trait DeviceFactory: Debug + Send + Sync {
    /// Human-readable name of this factory.
    fn factory_name(&self) -> &str;

    /// Device types this factory can create. `DeviceRegistry::register_factory`
    /// consults this list when registering the factory.
    fn supported_device_types(&self) -> Vec<DeviceType>;

    /// Returns whether this factory can create a device of `device_type`.
    fn supports_device_type(&self, device_type: DeviceType) -> bool;

    /// Creates a new device of `device_type`.
    fn create_device(&self, device_type: DeviceType) -> Result<Box<dyn Device>>;
}
/// Registry mapping device types to factories, plus a cache of shared device
/// instances.
#[derive(Debug)]
pub struct DeviceRegistry {
    /// One entry per supported device type. A factory that supports several
    /// types is shared (via `Arc`) across all of its entries.
    factories: Mutex<HashMap<DeviceType, Arc<dyn DeviceFactory>>>,
    /// Cached devices keyed by the `Debug` rendering of their `DeviceType`.
    devices: Mutex<HashMap<String, Arc<dyn Device>>>,
}

impl DeviceRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            factories: Mutex::new(HashMap::new()),
            devices: Mutex::new(HashMap::new()),
        }
    }

    /// Registers `factory` under every type in its `supported_device_types()`.
    ///
    /// Registration is all-or-nothing: if any of those types already has a
    /// factory, nothing is inserted. A factory reporting no types is accepted
    /// but registered under nothing.
    ///
    /// # Errors
    /// Returns `TorshError::InvalidArgument` naming the first conflicting type.
    pub fn register_factory<F: DeviceFactory + 'static>(&self, factory: F) -> Result<()> {
        // Query the factory before taking the lock so foreign code never runs
        // while the registry is locked.
        let device_types = factory.supported_device_types();
        let mut factories = self.factories.lock().expect("lock should not be poisoned");
        for device_type in &device_types {
            if factories.contains_key(device_type) {
                return Err(crate::error::TorshError::InvalidArgument(format!(
                    "Factory for device type {:?} already registered",
                    device_type
                )));
            }
        }
        let shared: Arc<dyn DeviceFactory> = Arc::new(factory);
        for device_type in device_types {
            factories.insert(device_type, Arc::clone(&shared));
        }
        Ok(())
    }

    /// Creates a new, uncached device through the factory registered for
    /// `device_type`.
    ///
    /// # Errors
    /// Returns a `DeviceError` when no factory is registered for the type, and
    /// propagates errors from the factory itself.
    pub fn create_device(&self, device_type: DeviceType) -> Result<Box<dyn Device>> {
        // Clone the factory handle and release the lock before creating, so a
        // factory that calls back into the registry cannot deadlock.
        let factory = self
            .factories
            .lock()
            .expect("lock should not be poisoned")
            .get(&device_type)
            .cloned();
        match factory {
            Some(factory) => factory.create_device(device_type),
            None => Err(crate::error::TorshError::General(
                crate::error::GeneralError::DeviceError(format!(
                    "No factory registered for device type {:?}",
                    device_type
                )),
            )),
        }
    }

    /// Returns the cached device for `device_type`, creating and caching it on
    /// first use.
    ///
    /// Creation happens outside the cache lock. If two threads race, the first
    /// instance inserted wins and both callers receive it; the loser's freshly
    /// created device is dropped.
    pub fn get_or_create_device(&self, device_type: DeviceType) -> Result<Arc<dyn Device>> {
        let device_id = format!("{:?}", device_type);
        if let Some(device) = self
            .devices
            .lock()
            .expect("lock should not be poisoned")
            .get(&device_id)
        {
            return Ok(Arc::clone(device));
        }

        // `Arc::from(Box<T>)` moves the value into a proper Arc allocation.
        // `Arc::from_raw` on a `Box` pointer would be undefined behaviour,
        // because the Box allocation lacks Arc's reference-count header.
        let created: Arc<dyn Device> = Arc::from(self.create_device(device_type)?);
        let mut devices = self.devices.lock().expect("lock should not be poisoned");
        let cached = devices.entry(device_id).or_insert(created);
        Ok(Arc::clone(cached))
    }

    /// Returns every device type that currently has a registered factory.
    pub fn registered_device_types(&self) -> Vec<DeviceType> {
        let factories = self.factories.lock().expect("lock should not be poisoned");
        factories.keys().copied().collect()
    }

    /// Drops the registry's references to all cached devices.
    pub fn clear_devices(&self) {
        let mut devices = self.devices.lock().expect("lock should not be poisoned");
        devices.clear();
    }

    /// Returns a snapshot of registry counts.
    ///
    /// `registered_factories` counts distinct factory instances, even when
    /// one factory is registered under several device types.
    pub fn statistics(&self) -> RegistryStatistics {
        let factories = self.factories.lock().expect("lock should not be poisoned");
        let devices = self.devices.lock().expect("lock should not be poisoned");
        let registered_factories = factories
            .values()
            .map(|factory| Arc::as_ptr(factory) as *const ())
            .collect::<std::collections::HashSet<_>>()
            .len();
        RegistryStatistics {
            registered_factories,
            cached_devices: devices.len(),
            supported_device_types: factories.keys().copied().collect(),
        }
    }
}

impl Default for DeviceRegistry {
    fn default() -> Self {
        Self::new()
    }
}
/// Point-in-time counts produced by `DeviceRegistry::statistics`.
#[derive(Debug, Clone)]
pub struct RegistryStatistics {
    /// Number of registered factories.
    pub registered_factories: usize,
    /// Number of devices currently held in the registry's cache.
    pub cached_devices: usize,
    /// Device types that have a registered factory.
    pub supported_device_types: Vec<DeviceType>,
}

impl std::fmt::Display for RegistryStatistics {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            registered_factories,
            cached_devices,
            supported_device_types,
        } = self;
        write!(
            f,
            "Registry(factories={registered_factories}, cached_devices={cached_devices}, types={supported_device_types:?})"
        )
    }
}
/// Process-wide registry, created lazily on first access.
static GLOBAL_REGISTRY: std::sync::OnceLock<DeviceRegistry> = std::sync::OnceLock::new();

/// Returns the process-wide [`DeviceRegistry`], creating an empty one on the
/// first call.
pub fn global_device_registry() -> &'static DeviceRegistry {
    GLOBAL_REGISTRY.get_or_init(DeviceRegistry::default)
}

/// Runs `init_fn` against the global registry and returns its result.
///
/// Nothing prevents repeated calls; each call runs `init_fn` again on the
/// same registry instance.
pub fn initialize_global_registry<F>(init_fn: F) -> Result<()>
where
    F: FnOnce(&DeviceRegistry) -> Result<()>,
{
    init_fn(global_device_registry())
}
/// Helpers that operate on one or more devices.
pub mod utils {
    use super::*;

    /// Bytes per "MB" as used throughout this module (a mebibyte).
    const BYTES_PER_MB: u64 = 1024 * 1024;

    /// Returns `true` when `a` and `b` report the same [`DeviceType`].
    pub fn devices_compatible(a: &dyn Device, b: &dyn Device) -> bool {
        a.device_type() == b.device_type()
    }

    /// Picks the available device with the highest `compute_score()`.
    ///
    /// Returns `Ok(None)` for an empty slice, or when no available device
    /// scores above zero. On ties the earliest device wins. The returned
    /// reference borrows the device, not the slice, so the slice may be a
    /// short-lived temporary.
    ///
    /// # Errors
    /// Stops at and returns the first error from `is_available` or
    /// `capabilities`.
    pub fn find_best_device<'a>(devices: &[&'a dyn Device]) -> Result<Option<&'a dyn Device>> {
        let mut best: Option<(&'a dyn Device, u64)> = None;
        for &device in devices {
            if !device.is_available()? {
                continue;
            }
            let score = device.capabilities()?.compute_score();
            // A score must beat both the current best and zero to count.
            if score > best.map_or(0, |(_, best_score)| best_score) {
                best = Some((device, score));
            }
        }
        Ok(best.map(|(device, _)| device))
    }

    /// Calls `synchronize` on every device in order, stopping at the first error.
    pub fn synchronize_devices(devices: &[&dyn Device]) -> Result<()> {
        devices.iter().try_for_each(|device| device.synchronize())
    }

    /// Returns `Ok(false)` at the first unavailable device, `Ok(true)` if all
    /// are available (including an empty slice).
    ///
    /// # Errors
    /// Returns the first error from `is_available`.
    pub fn all_devices_available(devices: &[&dyn Device]) -> Result<bool> {
        for device in devices {
            if !device.is_available()? {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Collects `memory_info()` for every device, stopping at the first error.
    pub fn get_devices_memory_info(devices: &[&dyn Device]) -> Result<Vec<DeviceMemoryInfo>> {
        devices.iter().map(|device| device.memory_info()).collect()
    }

    /// Keeps devices whose available memory, truncated to whole MB, is at
    /// least `min_available_mb`. Input order is preserved.
    ///
    /// # Errors
    /// Returns the first error from `memory_info`.
    pub fn filter_devices_by_memory<'a>(
        devices: &[&'a dyn Device],
        min_available_mb: u64,
    ) -> Result<Vec<&'a dyn Device>> {
        let mut filtered = Vec::new();
        for &device in devices {
            let available_mb = device.memory_info()?.available / BYTES_PER_MB;
            if available_mb >= min_available_mb {
                filtered.push(device);
            }
        }
        Ok(filtered)
    }

    /// Formats a one-line summary: name, device type, available MB and
    /// utilization percentage.
    pub fn device_summary(device: &dyn Device) -> Result<String> {
        let caps = device.capabilities()?;
        let memory_info = device.memory_info()?;
        Ok(format!(
            "{} - {} ({:.1}MB available, {:.1}% used)",
            device.name(),
            caps.device_type(),
            memory_info.available as f64 / BYTES_PER_MB as f64,
            memory_info.utilization_percent()
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal in-memory device used to exercise the trait's default methods
    /// and the `utils` helpers without real hardware.
    #[derive(Debug, Clone)]
    struct MockDevice {
        device_type: DeviceType,
        name: String,
        available: bool,
    }

    impl MockDevice {
        /// Builds an always-available mock device.
        fn new(device_type: DeviceType, name: String) -> Self {
            Self {
                available: true,
                device_type,
                name,
            }
        }
    }

    impl Device for MockDevice {
        fn device_type(&self) -> DeviceType {
            self.device_type
        }
        fn name(&self) -> &str {
            self.name.as_str()
        }
        fn is_available(&self) -> Result<bool> {
            Ok(self.available)
        }
        fn capabilities(&self) -> Result<DeviceCapabilities> {
            DeviceCapabilities::detect(self.device_type)
        }
        fn synchronize(&self) -> Result<()> {
            Ok(())
        }
        fn reset(&self) -> Result<()> {
            Ok(())
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
        fn clone_device(&self) -> Result<Box<dyn Device>> {
            Ok(Box::new(self.clone()))
        }
    }

    #[test]
    fn test_device_lifecycle() {
        let lc = DeviceLifecycle::new();
        assert_eq!(lc.state(), DeviceState::Uninitialized);
        assert!(!lc.is_ready());

        for next in [DeviceState::Initializing, DeviceState::Ready] {
            lc.set_state(next).expect("set_state should succeed");
        }
        assert!(lc.is_ready());

        lc.set_error("Test error".to_string())
            .expect("set_error should succeed");
        assert!(lc.is_error());
        assert_eq!(lc.error_info().as_deref(), Some("Test error"));

        lc.reset().expect("reset should succeed");
        assert_eq!(lc.state(), DeviceState::Uninitialized);
        assert_eq!(lc.error_info(), None);
    }

    #[test]
    fn test_device_context() {
        let ctx = DeviceContext::new(DeviceType::Cpu);
        assert_eq!(ctx.device_type(), DeviceType::Cpu);

        ctx.set_property("test_prop".to_string(), "test_value".to_string());
        assert_eq!(ctx.get_property("test_prop").as_deref(), Some("test_value"));

        ctx.add_resource(42u32);
        assert_eq!(ctx.resource_count(), 1);
        ctx.clear_resources();
        assert_eq!(ctx.resource_count(), 0);
    }

    #[test]
    fn test_device_memory_info() {
        const MIB: u64 = 1024 * 1024;
        let info = DeviceMemoryInfo {
            total: 1024 * MIB,
            available: 512 * MIB,
            used: 512 * MIB,
        };
        assert_eq!(info.utilization_percent(), 50.0);
        assert_eq!(info.available_percent(), 50.0);
        assert!(!info.is_memory_pressure(75.0));
        assert!(info.is_memory_pressure(25.0));
    }

    #[test]
    fn test_mock_device() {
        let cpu = MockDevice::new(DeviceType::Cpu, "Test CPU".to_string());
        assert_eq!(cpu.device_type(), DeviceType::Cpu);
        assert_eq!(cpu.name(), "Test CPU");
        assert!(cpu.is_available().expect("is_available should succeed"));
        assert_eq!(cpu.device_id(), "cpu");

        let copy = cpu.clone_device().expect("clone_device should succeed");
        assert!(cpu.is_same_device(copy.as_ref()));
    }

    #[test]
    fn test_device_registry() {
        let stats = DeviceRegistry::new().statistics();
        assert_eq!(stats.registered_factories, 0);
        assert_eq!(stats.cached_devices, 0);
    }

    #[test]
    fn test_utils_functions() {
        let cpu_a = MockDevice::new(DeviceType::Cpu, "CPU 1".to_string());
        let cpu_b = MockDevice::new(DeviceType::Cpu, "CPU 2".to_string());
        let gpu = MockDevice::new(DeviceType::Cuda(0), "GPU 1".to_string());

        assert!(utils::devices_compatible(&cpu_a, &cpu_b));
        assert!(!utils::devices_compatible(&cpu_a, &gpu));

        let all: [&dyn Device; 3] = [&cpu_a, &cpu_b, &gpu];
        assert!(utils::all_devices_available(&all).expect("all_devices_available should succeed"));
        utils::synchronize_devices(&all).expect("synchronize_devices should succeed");

        let summary = utils::device_summary(&cpu_a).expect("device_summary should succeed");
        assert!(summary.contains("CPU 1"));
    }
}