use super::common::error_helpers;
use super::common::*;
use crate::gpu::coreml::common::coreml_feature;
use crate::gpu::{device_cache::DeviceCache, DeviceType};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
/// Manager for CoreML device availability and initialization.
///
/// Use `CoreMLDeviceManager::global()` for the shared, lazily-created instance,
/// or `CoreMLDeviceManager::new()` for an independent one.
pub struct CoreMLDeviceManager {
    /// Cache consulted by `is_available` for the state of CoreML device 0.
    cache: DeviceCache,
    /// Capabilities handed out by `capabilities()`; initialised to the default
    /// in `new` and not modified anywhere in this file.
    capabilities: CoreMLCapabilities,
    /// Time of the last check recorded by `capabilities()`; `None` until the first call.
    last_check: Arc<Mutex<Option<Instant>>>,
}

impl CoreMLDeviceManager {
    /// Returns the process-wide manager.
    ///
    /// The instance is built with `new()` on the first call; every later call
    /// returns a reference to that same instance.
    pub fn global() -> &'static CoreMLDeviceManager {
        static INSTANCE: OnceLock<CoreMLDeviceManager> = OnceLock::new();
        INSTANCE.get_or_init(Self::new)
    }

    /// Creates a manager with a fresh device cache, default capabilities and
    /// no recorded capability check.
    pub fn new() -> Self {
        let cache = DeviceCache::new();
        let capabilities = CoreMLCapabilities::default();
        let last_check = Arc::new(Mutex::new(None));
        Self {
            cache,
            capabilities,
            last_check,
        }
    }

    /// Reports whether the device cache lists CoreML device 0 as available.
    pub fn is_available(&self) -> bool {
        let primary = DeviceType::CoreML(0);
        self.cache.is_device_available(&primary)
    }

    /// Returns the manager's capabilities.
    ///
    /// If no check has been recorded yet, or the last one is more than 30
    /// seconds old, the current time is stored as the new check time. The
    /// capabilities themselves are not re-detected by this call.
    ///
    /// NOTE(review): the stored timestamp is not read anywhere else in this
    /// file, so the "refresh" does not affect the returned value — confirm intent.
    pub fn capabilities(&self) -> &CoreMLCapabilities {
        const REFRESH_INTERVAL: Duration = Duration::from_secs(30);

        // A poisoned lock is tolerated: the timestamp update is simply skipped.
        if let Ok(mut checked_at) = self.last_check.lock() {
            let stale = match *checked_at {
                Some(at) => at.elapsed() > REFRESH_INTERVAL,
                None => true,
            };
            if stale {
                *checked_at = Some(Instant::now());
            }
        }
        &self.capabilities
    }

    /// Creates a `CoreMLDevice` handle for `device_id`.
    ///
    /// Availability is always checked for device 0, whatever `device_id` is.
    ///
    /// # Errors
    /// - the "not available" error when `is_coreml_available()` is false;
    /// - a device error when `is_available()` is false;
    /// - any error returned by `CoreMLDevice::new`.
    pub fn initialize(&self, device_id: usize) -> CoreMLResult<CoreMLDevice> {
        if !is_coreml_available() {
            Err(error_helpers::not_available())
        } else if !self.is_available() {
            Err(error_helpers::device_error("CoreML device not available"))
        } else {
            CoreMLDevice::new(device_id)
        }
    }

    /// Initializes device 0 once and immediately discards the handle.
    ///
    /// # Errors
    /// Returns the "not available" error when `is_available()` is false,
    /// and otherwise propagates any error from `initialize(0)`.
    pub fn warmup(&self) -> CoreMLResult<()> {
        if self.is_available() {
            self.initialize(0).map(drop)
        } else {
            Err(error_helpers::not_available())
        }
    }

    /// Picks the device index on which to run `op_type`.
    ///
    /// Returns `None` when no CoreML device is available. Otherwise every
    /// operation kind is currently routed to device 0.
    pub fn optimal_device_for_op(&self, op_type: CoreMLOpType) -> Option<usize> {
        self.is_available().then(|| {
            // Kept exhaustive so that adding an op kind forces a routing decision here.
            match op_type {
                CoreMLOpType::MatrixMultiplication
                | CoreMLOpType::Convolution
                | CoreMLOpType::Activation
                | CoreMLOpType::ElementWise => 0,
            }
        })
    }
}

impl Default for CoreMLDeviceManager {
    /// Equivalent to `CoreMLDeviceManager::new()`.
    fn default() -> Self {
        CoreMLDeviceManager::new()
    }
}
/// Handle to an initialized CoreML device, together with its detected capabilities.
pub struct CoreMLDevice {
    /// Index this handle was created for.
    device_id: usize,
    /// Capabilities detected when the handle was created.
    capabilities: CoreMLCapabilities,
    /// Creation time; `is_ready` measures the handle's age from this point.
    initialized_at: Instant,
}

coreml_feature! {
    impl CoreMLDevice {
        /// How long a handle is reported as ready after creation (one hour).
        const READY_WINDOW: Duration = Duration::from_secs(3600);

        /// Creates a handle for `device_id`, detecting its capabilities.
        ///
        /// # Errors
        /// Returns the "not available" error when `is_coreml_available()` is
        /// false, or any error from capability detection. Detection always
        /// fails on non-macOS targets.
        pub fn new(device_id: usize) -> CoreMLResult<Self> {
            if !is_coreml_available() {
                return Err(error_helpers::not_available());
            }
            let capabilities = Self::detect_capabilities(device_id)?;
            Ok(Self {
                device_id,
                capabilities,
                initialized_at: Instant::now(),
            })
        }

        /// Index this handle was created for.
        pub fn device_id(&self) -> usize {
            self.device_id
        }

        /// Capabilities detected when this handle was created.
        pub fn capabilities(&self) -> &CoreMLCapabilities {
            &self.capabilities
        }

        /// Returns `true` while the handle is younger than one hour.
        pub fn is_ready(&self) -> bool {
            self.initialized_at.elapsed() < Self::READY_WINDOW
        }

        /// Builds the capability set for a device.
        ///
        /// On macOS this re-checks `DeviceManager::is_coreml_available()` and
        /// then returns a fixed capability set. On every other target it fails
        /// with a macOS-only device error. The device index is currently ignored,
        /// so every device reports the same capabilities.
        fn detect_capabilities(_device_id: usize) -> CoreMLResult<CoreMLCapabilities> {
            #[cfg(target_os = "macos")]
            {
                use crate::backends::DeviceManager;
                if !DeviceManager::is_coreml_available() {
                    return Err(error_helpers::not_available());
                }
                // NOTE(review): hard-coded rather than probed from the hardware.
                let neural_engine = true;
                Ok(CoreMLCapabilities {
                    max_tensor_size: 100 * 1024 * 1024,
                    supports_f32: true,
                    supports_f64: false,
                    supports_complex: false,
                    neural_engine_available: neural_engine,
                    gpu_acceleration_available: true,
                })
            }
            #[cfg(not(target_os = "macos"))]
            Err(error_helpers::device_error(COREML_MACOS_ONLY))
        }

        /// Checks whether this device can handle a tensor of `tensor_size` with
        /// element type `dtype`.
        ///
        /// The type parameter `T` is not used by the check; it is kept for API
        /// compatibility.
        ///
        /// # Errors
        /// Returns an unsupported-operation error when:
        /// - `tensor_size` exceeds the capability's `max_tensor_size`;
        /// - `dtype` is Float32, Float64, Complex64 or Complex128 and the
        ///   matching capability flag (`supports_f32`, `supports_f64` or
        ///   `supports_complex`) is off.
        ///
        /// All other dtypes are accepted.
        pub fn validate_tensor<T>(
            &self,
            tensor_size: usize,
            dtype: &crate::dtype::DType,
        ) -> CoreMLResult<()> {
            let caps = &self.capabilities;
            if tensor_size > caps.max_tensor_size {
                return Err(error_helpers::unsupported_operation(&format!(
                    "Tensor too large: {} bytes > {} bytes max",
                    tensor_size, caps.max_tensor_size
                )));
            }
            // Every dtype family is gated by its own capability flag, so the
            // check stays correct if detection ever reports wider support.
            match dtype {
                crate::dtype::DType::Float32 if !caps.supports_f32 => Err(
                    error_helpers::unsupported_operation("Float32 not supported"),
                ),
                crate::dtype::DType::Float64 if !caps.supports_f64 => Err(
                    error_helpers::unsupported_operation("Float64 not supported by CoreML"),
                ),
                crate::dtype::DType::Complex64 | crate::dtype::DType::Complex128
                    if !caps.supports_complex =>
                {
                    Err(error_helpers::unsupported_operation(
                        "Complex numbers not supported by CoreML",
                    ))
                }
                _ => Ok(()),
            }
        }
    }
}
#[cfg(not(any(
    feature = "coreml",
    feature = "coreml-hybrid",
    feature = "coreml-fallback"
)))]
impl CoreMLDevice {
    /// Shared panic path for accessors that need a live device.
    ///
    /// In this configuration `new` never returns a device, so reaching this
    /// is a logic error.
    fn no_device() -> ! {
        unreachable!("CoreML device not available without features")
    }

    /// Always fails, because CoreML support is compiled out.
    ///
    /// # Errors
    /// Always returns the feature-disabled error.
    pub fn new(_device_id: usize) -> CoreMLResult<Self> {
        Err(error_helpers::feature_disabled())
    }

    /// Unreachable without CoreML features; panics if ever called.
    pub fn device_id(&self) -> usize {
        Self::no_device()
    }

    /// Unreachable without CoreML features; panics if ever called.
    pub fn capabilities(&self) -> &CoreMLCapabilities {
        Self::no_device()
    }

    /// A device is never ready when CoreML support is compiled out.
    pub fn is_ready(&self) -> bool {
        false
    }

    /// Always fails, because CoreML support is compiled out.
    ///
    /// # Errors
    /// Always returns the feature-disabled error.
    pub fn validate_tensor<T>(
        &self,
        _tensor_size: usize,
        _dtype: &crate::dtype::DType,
    ) -> CoreMLResult<()> {
        Err(error_helpers::feature_disabled())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `global()` must hand out the same instance on every call.
    #[test]
    fn test_device_manager_singleton() {
        let first: *const CoreMLDeviceManager = CoreMLDeviceManager::global();
        let second: *const CoreMLDeviceManager = CoreMLDeviceManager::global();
        assert_eq!(first, second);
    }

    /// Default capabilities enable f32 only: no f64 and no complex numbers.
    #[test]
    fn test_capabilities_default() {
        let CoreMLCapabilities {
            supports_f32,
            supports_f64,
            supports_complex,
            ..
        } = CoreMLCapabilities::default();
        assert!(supports_f32);
        assert!(!supports_f64);
        assert!(!supports_complex);
    }

    /// A `Device` error keeps the device name and message it was built with.
    #[test]
    fn test_error_conversion() {
        let err = crate::error::RusTorchError::Device {
            device: "CoreML".to_string(),
            message: COREML_NOT_AVAILABLE.to_string(),
        };
        let crate::error::RusTorchError::Device { device, message } = err else {
            panic!("Unexpected error type");
        };
        assert_eq!(device, "CoreML");
        assert_eq!(message, COREML_NOT_AVAILABLE);
    }
}