use super::backend::CoreMLBackend;
use super::common::*;
use super::device::CoreMLDeviceManager;
use super::operations::*;
use crate::error::RusTorchError;
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive};
use std::sync::Arc;
/// Policy used by `CoreMLHybridExecutor` to pick CoreML or the CPU for an operation.
///
/// [`HybridStrategy::Automatic`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum HybridStrategy {
    /// Use CoreML whenever the operation reports CoreML support, otherwise the CPU.
    CoreMLPreferred,
    /// Use the CPU unless the operation is CoreML-supported and estimated to be large.
    CPUPreferred,
    /// Decide per operation from CoreML support and the executor's collected statistics.
    #[default]
    Automatic,
    /// Always attempt CoreML first (the executor still falls back to the CPU on error).
    ///
    /// NOTE(review): spelled `ForceCoreMI` (capital I) — presumably meant `ForceCoreML`;
    /// kept as-is because renaming a public variant would break callers.
    ForceCoreMI,
    /// Always execute on the CPU.
    ForceCPU,
}
/// Runs CoreML-capable operations on CoreML or on a caller-supplied CPU closure, as chosen by
/// a [`HybridStrategy`], while recording execution statistics.
pub struct CoreMLHybridExecutor {
    /// Backend handle; `new` fills it with a clone of `CoreMLBackend::global()`.
    // NOTE(review): no method in this file reads this field.
    backend: Arc<CoreMLBackend>,
    /// Strategy consulted on every `execute` call.
    strategy: HybridStrategy,
    /// Execution statistics behind a mutex so `&self` methods can update them.
    performance_stats: Arc<std::sync::Mutex<PerformanceStats>>,
}
/// Execution statistics consulted by [`HybridStrategy::Automatic`].
///
/// Integer counters and summed durations are stored instead of pre-computed rates and
/// running averages. This keeps every derived value exact: each statistic is divided by its
/// *own* sample count, never by the total number of operations, and no count is ever
/// recovered from a rounded floating-point rate.
#[derive(Debug, Default)]
struct PerformanceStats {
    /// Operations that produced a result (via CoreML or the CPU).
    total_operations: usize,
    /// CoreML executions attempted.
    coreml_attempts: usize,
    /// CoreML executions that returned `Ok`.
    coreml_successes: usize,
    /// Summed wall-clock time of successful CoreML executions.
    coreml_total_time: std::time::Duration,
    /// Successful CPU executions.
    cpu_runs: usize,
    /// Summed wall-clock time of successful CPU executions.
    cpu_total_time: std::time::Duration,
}

impl PerformanceStats {
    /// Fraction of CoreML attempts that succeeded; `0.0` before the first attempt.
    fn coreml_success_rate(&self) -> f64 {
        if self.coreml_attempts == 0 {
            0.0
        } else {
            self.coreml_successes as f64 / self.coreml_attempts as f64
        }
    }

    /// Mean time of successful CoreML executions; zero if there are none yet.
    fn coreml_avg_time(&self) -> std::time::Duration {
        Self::mean(self.coreml_total_time, self.coreml_successes)
    }

    /// Mean time of successful CPU executions; zero if there are none yet.
    fn cpu_avg_time(&self) -> std::time::Duration {
        Self::mean(self.cpu_total_time, self.cpu_runs)
    }

    /// Records one CoreML attempt. `elapsed` only contributes to the average on success.
    fn record_coreml_attempt(&mut self, succeeded: bool, elapsed: std::time::Duration) {
        self.coreml_attempts += 1;
        if succeeded {
            self.coreml_successes += 1;
            self.coreml_total_time = self.coreml_total_time.saturating_add(elapsed);
        }
    }

    /// Records one successful CPU execution.
    fn record_cpu_run(&mut self, elapsed: std::time::Duration) {
        self.cpu_runs += 1;
        self.cpu_total_time = self.cpu_total_time.saturating_add(elapsed);
    }

    /// `total / count` with nanosecond precision; zero when `count == 0`.
    fn mean(total: std::time::Duration, count: usize) -> std::time::Duration {
        if count == 0 {
            return std::time::Duration::ZERO;
        }
        let nanos = total.as_nanos() / count as u128;
        std::time::Duration::from_nanos(u64::try_from(nanos).unwrap_or(u64::MAX))
    }
}

impl CoreMLHybridExecutor {
    /// CoreML attempts `Automatic` makes before it starts trusting the collected statistics.
    const AUTOMATIC_WARMUP_ATTEMPTS: usize = 10;
    /// CoreML success rate that `Automatic` must exceed to keep choosing CoreML.
    const AUTOMATIC_MIN_SUCCESS_RATE: f64 = 0.8;
    /// Estimated runtime (milliseconds) above which `CPUPreferred` offloads to CoreML.
    const LARGE_OP_THRESHOLD_MS: u128 = 10;

    /// Creates an executor with the given strategy and empty statistics.
    ///
    /// # Errors
    /// The signature allows failure, but the current implementation always returns `Ok`.
    pub fn new(strategy: HybridStrategy) -> CoreMLResult<Self> {
        let backend = Arc::new((*CoreMLBackend::global()).clone());
        Ok(Self {
            backend,
            strategy,
            performance_stats: Arc::new(std::sync::Mutex::new(PerformanceStats::default())),
        })
    }

    /// Runs `operation` on the path chosen by the executor's strategy.
    ///
    /// `cpu_fallback` is invoked only when the CPU path is selected or when CoreML fails.
    ///
    /// # Errors
    /// - CPU path: the error returned by `cpu_fallback`.
    /// - CoreML path: if CoreML and the CPU fallback both fail, the CoreML error.
    pub fn execute<T, Op, CpuFn>(
        &self,
        operation: &Op,
        cpu_fallback: CpuFn,
    ) -> Result<Tensor<T>, RusTorchError>
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
        CpuFn: FnOnce() -> Result<Tensor<T>, RusTorchError>,
    {
        match self.decide_execution_path(operation) {
            ExecutionPath::CoreML => self.execute_coreml_with_fallback(operation, cpu_fallback),
            ExecutionPath::CPU => self.execute_cpu_with_stats(cpu_fallback),
            // Unlike the CoreML path, a double failure here reports the CPU error.
            ExecutionPath::TryBoth => self
                .timed_coreml_attempt(operation)
                .or_else(|_| self.execute_cpu_with_stats(cpu_fallback)),
        }
    }

    /// Maps the configured strategy (and, for some strategies, the operation) to a path.
    fn decide_execution_path<T, Op>(&self, operation: &Op) -> ExecutionPath
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
    {
        match self.strategy {
            HybridStrategy::ForceCoreMI => ExecutionPath::CoreML,
            HybridStrategy::ForceCPU => ExecutionPath::CPU,
            HybridStrategy::CoreMLPreferred => {
                if operation.is_supported_by_coreml() {
                    ExecutionPath::CoreML
                } else {
                    ExecutionPath::CPU
                }
            }
            HybridStrategy::CPUPreferred => {
                if self.should_use_coreml_for_large_ops(operation) {
                    ExecutionPath::CoreML
                } else {
                    ExecutionPath::CPU
                }
            }
            HybridStrategy::Automatic => self.automatic_decision(operation),
        }
    }

    /// Statistics-driven choice used by [`HybridStrategy::Automatic`].
    fn automatic_decision<T, Op>(&self, operation: &Op) -> ExecutionPath
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
    {
        if !operation.is_supported_by_coreml() {
            return ExecutionPath::CPU;
        }
        // A poisoned lock only makes the statistics unusable; default to CoreML.
        let Ok(stats) = self.performance_stats.lock() else {
            return ExecutionPath::CoreML;
        };
        // Keep sampling CoreML until enough CoreML attempts exist to judge it. Counting CoreML
        // attempts (rather than all operations) stops CPU-only work from using up the warm-up.
        if stats.coreml_attempts < Self::AUTOMATIC_WARMUP_ATTEMPTS {
            return ExecutionPath::CoreML;
        }
        let reliable = stats.coreml_success_rate() > Self::AUTOMATIC_MIN_SUCCESS_RATE;
        // CoreML is kept while it is less than twice as slow as the CPU average. With no CPU
        // samples yet the CPU average is zero, which routes work to the CPU so it gets measured.
        let fast_enough = stats.coreml_avg_time() < stats.cpu_avg_time().saturating_mul(2);
        if reliable && fast_enough {
            ExecutionPath::CoreML
        } else {
            ExecutionPath::CPU
        }
    }

    /// `true` when the operation is CoreML-supported and its estimated execution time exceeds
    /// [`Self::LARGE_OP_THRESHOLD_MS`]; operations without an estimate stay on the CPU.
    fn should_use_coreml_for_large_ops<T, Op>(&self, operation: &Op) -> bool
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
    {
        operation.is_supported_by_coreml()
            && operation
                .estimated_execution_time()
                .is_some_and(|estimate| estimate.as_millis() > Self::LARGE_OP_THRESHOLD_MS)
    }

    /// Tries CoreML and, if it fails, runs `cpu_fallback`.
    ///
    /// If both fail, the CoreML error is returned (the CPU error is discarded).
    fn execute_coreml_with_fallback<T, Op, CpuFn>(
        &self,
        operation: &Op,
        cpu_fallback: CpuFn,
    ) -> Result<Tensor<T>, RusTorchError>
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
        CpuFn: FnOnce() -> Result<Tensor<T>, RusTorchError>,
    {
        match self.timed_coreml_attempt(operation) {
            Ok(result) => Ok(result),
            Err(coreml_err) => self
                .execute_cpu_with_stats(cpu_fallback)
                .map_err(|_cpu_err| coreml_err),
        }
    }

    /// Runs `operation` on CoreML and records the attempt. On success it also counts the
    /// operation as completed.
    fn timed_coreml_attempt<T, Op>(&self, operation: &Op) -> CoreMLResult<Tensor<T>>
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
    {
        let start = std::time::Instant::now();
        let result = self.try_coreml_execution(operation);
        let elapsed = start.elapsed();
        let succeeded = result.is_ok();
        self.with_stats(|stats| {
            stats.record_coreml_attempt(succeeded, elapsed);
            if succeeded {
                stats.total_operations += 1;
            }
        });
        result
    }

    /// Runs `cpu_fallback`; on success it records its duration and counts the operation.
    fn execute_cpu_with_stats<T, CpuFn>(
        &self,
        cpu_fallback: CpuFn,
    ) -> Result<Tensor<T>, RusTorchError>
    where
        T: num_traits::Float + 'static,
        CpuFn: FnOnce() -> Result<Tensor<T>, RusTorchError>,
    {
        let start = std::time::Instant::now();
        let result = cpu_fallback();
        if result.is_ok() {
            let elapsed = start.elapsed();
            self.with_stats(|stats| {
                stats.record_cpu_run(elapsed);
                stats.total_operations += 1;
            });
        }
        result
    }

    /// Executes `operation` on CoreML with device index `0`.
    // NOTE(review): device 0 is hard-coded; presumably the default CoreML device — confirm.
    fn try_coreml_execution<T, Op>(&self, operation: &Op) -> CoreMLResult<Tensor<T>>
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
    {
        operation.execute_coreml(0)
    }

    /// Applies `update` to the statistics. Recording is best-effort: if the mutex is
    /// poisoned, the sample is dropped instead of panicking.
    fn with_stats(&self, update: impl FnOnce(&mut PerformanceStats)) {
        if let Ok(mut stats) = self.performance_stats.lock() {
            update(&mut *stats);
        }
    }

    /// Returns `(coreml_success_rate, coreml_avg_time, cpu_avg_time, total_operations)`.
    ///
    /// - `coreml_success_rate`: successful CoreML executions divided by CoreML attempts
    ///   (`0.0` before any attempt).
    /// - `coreml_avg_time` / `cpu_avg_time`: means over successful executions on each backend.
    /// - `total_operations`: operations that produced a result.
    ///
    /// Returns `None` if the statistics mutex is poisoned.
    pub fn get_performance_stats(
        &self,
    ) -> Option<(f64, std::time::Duration, std::time::Duration, usize)> {
        let stats = self.performance_stats.lock().ok()?;
        Some((
            stats.coreml_success_rate(),
            stats.coreml_avg_time(),
            stats.cpu_avg_time(),
            stats.total_operations,
        ))
    }

    /// Clears all statistics; a poisoned mutex is silently ignored.
    pub fn reset_stats(&self) {
        if let Ok(mut stats) = self.performance_stats.lock() {
            *stats = PerformanceStats::default();
        }
    }
}
/// Backend route chosen for a single `execute` call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExecutionPath {
    /// Run on CoreML, falling back to the CPU closure if CoreML fails.
    CoreML,
    /// Run only the CPU closure.
    CPU,
    /// Try CoreML; on failure run the CPU closure and return its result.
    // NOTE(review): `decide_execution_path` never returns this variant, so it is currently
    // unreachable from `execute`.
    TryBoth,
}
/// Executes an operation through a freshly created `CoreMLHybridExecutor`.
///
/// - `coreml_hybrid!(op, cpu_expr)` uses `HybridStrategy::Automatic`.
/// - `coreml_hybrid!(op, cpu_expr, strategy)` uses the given strategy.
///
/// `cpu_expr` is wrapped in a closure, so it is only evaluated if the executor takes the
/// CPU path. Executor construction errors are propagated with `?`, so the macro must be used
/// inside a function returning a compatible `Result`. `CoreMLHybridExecutor` and
/// `HybridStrategy` are resolved at the call site and must be in scope there.
#[macro_export]
macro_rules! coreml_hybrid {
    ($operation:expr, $cpu_fallback:expr) => {
        $crate::coreml_hybrid!($operation, $cpu_fallback, HybridStrategy::Automatic)
    };
    ($operation:expr, $cpu_fallback:expr, $strategy:expr) => {{
        let executor = CoreMLHybridExecutor::new($strategy)?;
        executor.execute(&$operation, || $cpu_fallback)
    }};
}
/// Lazily initialised process-wide executor returned by [`global_hybrid_executor`].
static GLOBAL_HYBRID_EXECUTOR: std::sync::OnceLock<CoreMLHybridExecutor> =
    std::sync::OnceLock::new();

/// Returns the shared hybrid executor, creating it on first use.
///
/// Creation tries `HybridStrategy::Automatic` first and, if that fails, retries with
/// `HybridStrategy::CPUPreferred`.
///
/// # Panics
/// Panics if the `CPUPreferred` fallback executor cannot be created either.
pub fn global_hybrid_executor() -> &'static CoreMLHybridExecutor {
    GLOBAL_HYBRID_EXECUTOR.get_or_init(|| {
        match CoreMLHybridExecutor::new(HybridStrategy::Automatic) {
            Ok(executor) => executor,
            Err(_) => CoreMLHybridExecutor::new(HybridStrategy::CPUPreferred)
                .expect("Failed to create fallback hybrid executor"),
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction is accepted either way; the outcome is only reported.
    #[test]
    fn test_hybrid_executor_creation() {
        if let Err(err) = CoreMLHybridExecutor::new(HybridStrategy::Automatic) {
            println!("Expected failure on platforms without CoreML: {}", err);
        } else {
            println!("Hybrid executor created successfully");
        }
    }

    /// A fresh global executor reports no recorded operations.
    #[test]
    fn test_global_hybrid_executor() {
        let (success_rate, _, _, operations) = global_hybrid_executor()
            .get_performance_stats()
            .expect("performance statistics should be readable");
        assert_eq!(operations, 0);
        assert_eq!(success_rate, 0.0);
    }

    /// `CPUPreferred` keeps a tiny 2x2 matmul on the CPU.
    #[test]
    fn test_strategy_decisions() {
        use super::super::operations::linear_algebra::MatMulOperation;
        use crate::tensor::Tensor;

        let lhs = Tensor::<f32>::zeros(&[2, 2]);
        let rhs = Tensor::<f32>::zeros(&[2, 2]);
        let operation = MatMulOperation::new(lhs, rhs);

        let Ok(executor) = CoreMLHybridExecutor::new(HybridStrategy::CPUPreferred) else {
            return;
        };
        let chosen = executor.decide_execution_path(&operation);
        assert_eq!(chosen, ExecutionPath::CPU);
    }
}