use crate::gpu::DeviceType;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
/// Availability state of a compute device as recorded in the device cache.
#[derive(Clone, Debug)]
pub enum DeviceStatus {
    /// The device passed its availability check.
    Available,
    /// The device cannot be used; the payload carries a human-readable reason.
    Unavailable(String),
    /// Intermediate state; not produced anywhere in this module.
    Initializing,
}
/// One recorded availability probe for a device, plus when it was taken.
#[derive(Clone, Debug)]
pub struct CachedDevice {
    /// Outcome of the probe.
    pub status: DeviceStatus,
    /// Moment the entry was created; its age decides whether it is still fresh.
    pub last_checked: Instant,
    /// How long the probe took, when that was measured.
    pub initialization_time: Option<Duration>,
}
impl CachedDevice {
    /// Default freshness window: entries older than this are treated as stale by
    /// [`CachedDevice::is_valid`].
    pub const DEFAULT_TTL: Duration = Duration::from_secs(30);

    /// Creates an entry for `status`, timestamped now, with no recorded probe duration.
    pub fn new(status: DeviceStatus) -> Self {
        Self {
            status,
            last_checked: Instant::now(),
            initialization_time: None,
        }
    }

    /// Builder-style setter that records how long the probe producing this entry took.
    pub fn with_init_time(mut self, duration: Duration) -> Self {
        self.initialization_time = Some(duration);
        self
    }

    /// Returns `true` while the entry is younger than [`Self::DEFAULT_TTL`] (30 seconds).
    pub fn is_valid(&self) -> bool {
        self.is_valid_for(Self::DEFAULT_TTL)
    }

    /// Returns `true` while the entry is younger than `ttl`.
    ///
    /// Lets callers apply a freshness window other than the default 30 seconds.
    pub fn is_valid_for(&self, ttl: Duration) -> bool {
        self.last_checked.elapsed() < ttl
    }
}
/// Mutex-guarded, time-limited cache of device availability probes keyed by device.
///
/// The map sits behind `Arc<Mutex<..>>`. Cloning a `DeviceCache` therefore yields another
/// handle onto the same entries, not an independent copy.
#[derive(Clone, Debug)]
pub struct DeviceCache {
    cache: Arc<Mutex<HashMap<DeviceType, CachedDevice>>>,
}
impl DeviceCache {
    /// Returns the process-wide shared cache, creating it on first access.
    pub fn global() -> &'static DeviceCache {
        static CACHE: OnceLock<DeviceCache> = OnceLock::new();
        CACHE.get_or_init(DeviceCache::new)
    }

    /// Creates an empty cache that shares no state with [`DeviceCache::global`].
    pub fn new() -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Reports whether `device` can be used, probing it only when needed.
    ///
    /// A cached entry is answered directly while it is still fresh
    /// ([`CachedDevice::is_valid`]). Otherwise the device is probed, and the outcome plus
    /// the probe duration replace the cache entry.
    ///
    /// NOTE: the lock is released between lookup and update. Callers that miss at the
    /// same time may each probe; the last write wins.
    pub fn is_device_available(&self, device: &DeviceType) -> bool {
        if let Some(cached) = self.get_cached_status(device) {
            if cached.is_valid() {
                return matches!(cached.status, DeviceStatus::Available);
            }
        }
        let start = Instant::now();
        let is_available = self.check_device_availability_impl(device);
        let check_duration = start.elapsed();
        let status = if is_available {
            DeviceStatus::Available
        } else {
            DeviceStatus::Unavailable("Device check failed".to_string())
        };
        let cached_device = CachedDevice::new(status).with_init_time(check_duration);
        self.update_cache(device.clone(), cached_device);
        is_available
    }

    /// Returns a copy of the stored entry for `device`, fresh or stale.
    ///
    /// Returns `None` when there is no entry or the lock is poisoned.
    pub fn get_cached_status(&self, device: &DeviceType) -> Option<CachedDevice> {
        let cache = self.cache.lock().ok()?;
        cache.get(device).cloned()
    }

    /// Inserts or replaces the entry for `device`. Silently skipped if the lock is poisoned.
    pub fn update_cache(&self, device: DeviceType, cached_device: CachedDevice) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.insert(device, cached_device);
        }
    }

    /// Removes every entry that is no longer fresh. No-op if the lock is poisoned.
    pub fn cleanup_expired(&self) {
        if let Ok(mut cache) = self.cache.lock() {
            cache.retain(|_, cached| cached.is_valid());
        }
    }

    /// Summarizes the cache contents at the time of the call.
    ///
    /// `cache_hit_rate` is `valid_entries / total_entries` (0.0 for an empty cache). It is
    /// the share of fresh entries, not a measured lookup hit rate.
    pub fn get_stats(&self) -> CacheStats {
        // Recover the map from a poisoned lock instead of panicking; the other accessors
        // never panic on poisoning either.
        let cache = self
            .cache
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let total_entries = cache.len();
        // Single pass: each entry's freshness is evaluated once, so `available_devices`
        // is always counted from the same set of entries as `valid_entries`.
        let (valid_entries, available_devices) = cache
            .values()
            .filter(|entry| entry.is_valid())
            .fold((0_usize, 0_usize), |(valid, available), entry| {
                let counts_as_available = matches!(entry.status, DeviceStatus::Available);
                (valid + 1, available + usize::from(counts_as_available))
            });
        CacheStats {
            total_entries,
            valid_entries,
            available_devices,
            cache_hit_rate: if total_entries > 0 {
                valid_entries as f64 / total_entries as f64
            } else {
                0.0
            },
        }
    }

    /// Performs the uncached availability probe for `device`.
    ///
    /// Which arms exist depends on enabled cargo features.
    fn check_device_availability_impl(&self, device: &DeviceType) -> bool {
        match device {
            DeviceType::Cpu | DeviceType::Auto => true,
            #[cfg(any(
                feature = "coreml",
                feature = "coreml-hybrid",
                feature = "coreml-fallback"
            ))]
            DeviceType::CoreML(_) => {
                use crate::backends::DeviceManager;
                DeviceManager::is_coreml_available()
            }
            #[cfg(feature = "metal")]
            DeviceType::Metal(_) => {
                use crate::backends::DeviceManager;
                DeviceManager::is_metal_available()
            }
            #[cfg(not(feature = "metal"))]
            DeviceType::Metal(_) => false,
            // NOTE(review): no CUDA probe is implemented; reports unavailable even when
            // the `cuda` feature is enabled.
            #[cfg(feature = "cuda")]
            DeviceType::Cuda(_) => false,
            #[cfg(not(feature = "cuda"))]
            DeviceType::Cuda(_) => false,
            // NOTE(review): no OpenCL probe is implemented; reports unavailable even when
            // the `opencl` feature is enabled.
            #[cfg(feature = "opencl")]
            DeviceType::OpenCL(_) => false,
            #[cfg(not(feature = "opencl"))]
            DeviceType::OpenCL(_) => false,
            // Hybrid modes are only assumed usable when compiled for macOS.
            #[cfg(feature = "coreml-hybrid")]
            DeviceType::CoreMLHybrid { .. } => cfg!(target_os = "macos"),
            #[cfg(feature = "mac-hybrid")]
            DeviceType::MacHybrid => cfg!(target_os = "macos"),
        }
    }

    /// Probes device 0 of each compiled-in backend (plus the CPU) to pre-populate the cache.
    pub fn warmup(&self) {
        // `mut` is only exercised when at least one backend feature below is enabled.
        #[allow(unused_mut)]
        let mut devices_to_check = vec![DeviceType::Cpu];
        #[cfg(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        ))]
        devices_to_check.push(DeviceType::CoreML(0));
        #[cfg(feature = "metal")]
        devices_to_check.push(DeviceType::Metal(0));
        #[cfg(feature = "cuda")]
        devices_to_check.push(DeviceType::Cuda(0));
        for device in devices_to_check {
            self.is_device_available(&device);
        }
    }
}
impl Default for DeviceCache {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time summary of a device cache's contents.
#[derive(Clone, Debug)]
pub struct CacheStats {
    /// Number of stored entries, fresh or stale.
    pub total_entries: usize,
    /// Entries still inside their freshness window.
    pub valid_entries: usize,
    /// Fresh entries whose status is `Available`.
    pub available_devices: usize,
    /// `valid_entries / total_entries` (0.0 when empty). Despite the name, this is the
    /// share of fresh entries rather than a measured lookup hit rate.
    pub cache_hit_rate: f64,
}
/// Memoizes the outcome of the CoreML availability probe.
///
/// The first [`CoreMLCache::ensure_initialized`] call runs the probe while holding the
/// internal lock. Concurrent callers wait for that single probe and then share its result
/// instead of each running their own.
#[derive(Debug)]
pub struct CoreMLCache {
    /// `None` until a probe has completed; afterwards the memoized probe outcome.
    initialization_result: Mutex<Option<Result<(), String>>>,
}

impl CoreMLCache {
    /// Returns the process-wide instance, creating it on first access.
    pub fn global() -> &'static CoreMLCache {
        static COREML_CACHE: OnceLock<CoreMLCache> = OnceLock::new();
        COREML_CACHE.get_or_init(CoreMLCache::new)
    }

    /// Creates a cache that has not probed CoreML yet.
    pub fn new() -> Self {
        Self {
            initialization_result: Mutex::new(None),
        }
    }

    /// Runs the CoreML probe once and returns its cached outcome on every call.
    ///
    /// # Errors
    /// Returns the probe's error message. This happens when CoreML features are not
    /// compiled in, the target is not macOS, or the runtime availability check fails.
    /// Errors are cached as well, so later calls return the same `Err` until
    /// [`CoreMLCache::reset`] is called.
    pub fn ensure_initialized(&self) -> Result<(), String> {
        // Holding the guard across the probe keeps concurrent first callers from each
        // running it. The slot is only ever assigned whole values, so a poisoned lock
        // (a probe panicked) still holds `None` or a complete result and can be reused.
        let mut slot = self
            .initialization_result
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        slot.get_or_insert_with(|| self.initialize_coreml()).clone()
    }

    /// Performs the uncached CoreML availability probe.
    fn initialize_coreml(&self) -> Result<(), String> {
        #[cfg(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        ))]
        {
            if !cfg!(target_os = "macos") {
                return Err("CoreML is only available on macOS".to_string());
            }
            use crate::backends::DeviceManager;
            if DeviceManager::is_coreml_available() {
                Ok(())
            } else {
                Err("CoreML not available on this system".to_string())
            }
        }
        #[cfg(not(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        )))]
        {
            Err("CoreML features not enabled".to_string())
        }
    }

    /// Forgets the cached outcome so the next `ensure_initialized` call probes again.
    pub fn reset(&self) {
        *self
            .initialization_result
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner()) = None;
    }
}

impl Default for CoreMLCache {
    /// Same as [`CoreMLCache::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The CPU is always reported available, and the probe result is cached.
    #[test]
    fn test_device_cache_basic() {
        let cache = DeviceCache::new();
        assert!(cache.is_device_available(&DeviceType::Cpu));
        let cached = cache
            .get_cached_status(&DeviceType::Cpu)
            .expect("CPU probe result should be cached");
        assert!(matches!(cached.status, DeviceStatus::Available));
    }

    /// A stale entry must be ignored and replaced by a fresh probe.
    #[test]
    fn test_cache_expiration() {
        let cache = DeviceCache::new();
        // `Instant - Duration` panics when the result would precede the platform clock's
        // origin (possible shortly after boot); skip instead of failing spuriously.
        let Some(stale_time) = Instant::now().checked_sub(Duration::from_secs(60)) else {
            return;
        };
        // Seed a *wrong* stale answer so the assertions below only pass if the expired
        // entry is actually ignored and the CPU is re-probed.
        let expired_device = CachedDevice {
            status: DeviceStatus::Unavailable("stale".to_string()),
            last_checked: stale_time,
            initialization_time: None,
        };
        cache.update_cache(DeviceType::Cpu, expired_device);
        assert!(cache.is_device_available(&DeviceType::Cpu));
        let refreshed = cache
            .get_cached_status(&DeviceType::Cpu)
            .expect("re-probe should store a fresh entry");
        assert!(refreshed.is_valid());
        assert!(matches!(refreshed.status, DeviceStatus::Available));
        assert!(refreshed.initialization_time.is_some());
    }

    /// Repeated initialization returns the same cached outcome.
    #[test]
    fn test_coreml_cache() {
        let cache = CoreMLCache::new();
        let first = cache.ensure_initialized();
        let second = cache.ensure_initialized();
        assert_eq!(first, second, "Inconsistent cache results");
    }

    #[test]
    #[cfg(feature = "coreml")]
    fn test_unsupported_operation_bypass() {
        use crate::gpu::smart_device_selector::{
            OperationProfile, OperationType, SmartDeviceSelector,
        };
        let available_devices = vec![DeviceType::CoreML(0), DeviceType::Metal(0), DeviceType::Cpu];
        let selector = SmartDeviceSelector::new(available_devices);
        let supported_profile =
            OperationProfile::new(OperationType::MatrixMultiplication, &[128, 128], 4);
        let selected = selector.select_device(&supported_profile);
        assert!(matches!(selected, DeviceType::CoreML(_) | DeviceType::Metal(_)));
        let unsupported_ops = [
            OperationType::ComplexNumber,
            OperationType::StatisticalDistribution,
            OperationType::CustomKernel,
            OperationType::DistributedOp,
        ];
        for op_type in unsupported_ops {
            let unsupported_profile = OperationProfile::new(op_type, &[128, 128], 8);
            let selected = selector.select_device(&unsupported_profile);
            assert!(!matches!(selected, DeviceType::CoreML(_)));
            assert!(matches!(selected, DeviceType::Metal(_) | DeviceType::Cpu));
        }
    }
}