use super::common::*;
use crate::error::RusTorchResult;
use crate::tensor::Tensor;
use ndarray::ScalarOperand;
use num_traits::{Float, FromPrimitive};
/// Errors raised by the CoreML backend.
///
/// The `#[error]` attribute on each variant supplies its display message.
/// Every variant is convertible into `crate::error::RusTorchError` through the
/// `From` impl defined in this file.
#[derive(Debug, thiserror::Error)]
pub enum CoreMLError {
/// The requested operation is not supported by CoreML.
#[error("CoreML does not support operation: {operation}")]
UnsupportedOperation {
/// Name of the unsupported operation, interpolated into the message.
operation: String,
},
/// CoreML is not available.
#[error("CoreML not available")]
NotAvailable,
/// An error reported by the CoreML backend.
#[error("CoreML backend error: {message}")]
Backend {
/// Description of the backend failure, interpolated into the message.
message: String,
},
/// The input to a CoreML operation was invalid; the payload is the detail text.
#[error("Invalid input for CoreML operation: {0}")]
InvalidInput(String),
/// A conversion failed; the payload is the detail text.
#[error("Conversion error: {0}")]
ConversionError(String),
}
impl From<CoreMLError> for crate::error::RusTorchError {
fn from(err: CoreMLError) -> Self {
match err {
CoreMLError::UnsupportedOperation { operation } => {
crate::error::RusTorchError::InvalidOperation {
operation,
message: "CoreML does not support this operation".to_string(),
}
}
CoreMLError::NotAvailable => crate::error::RusTorchError::BackendUnavailable {
backend: "CoreML".to_string(),
},
CoreMLError::Backend { message } => crate::error::RusTorchError::Device {
device: "CoreML".to_string(),
message,
},
CoreMLError::InvalidInput(message) => crate::error::RusTorchError::InvalidParameters {
operation: "CoreML".to_string(),
message,
},
CoreMLError::ConversionError(message) => crate::error::RusTorchError::TensorOp {
message,
source: None,
},
}
}
}
// Public submodules of this module.
pub mod linear_algebra;
pub mod convolution;
pub mod activation;
// Glob re-exports: every public item of the submodules above is also
// reachable directly from this module.
pub use activation::*;
pub use convolution::*;
pub use linear_algebra::*;
/// An operation that can be executed through CoreML and yields a `Tensor<T>`.
///
/// `CoreMLExecutor::execute` drives implementors: it first calls
/// `is_supported_by_coreml` and only then `execute_coreml`.
pub trait CoreMLOperation<T>
where
T: Float + FromPrimitive + ScalarOperand + 'static,
{
/// Executes the operation and returns the resulting tensor.
///
/// `device_id` is the id of the device to run on; the executor passes the
/// id of its own device.
fn execute_coreml(&self, device_id: usize) -> CoreMLResult<Tensor<T>>;
/// Returns whether this operation can be run by CoreML. The executor
/// returns an error instead of executing when this is `false`.
fn is_supported_by_coreml(&self) -> bool;
/// Optional estimate of how long the operation takes to execute.
///
/// The default implementation provides no estimate and returns `None`.
fn estimated_execution_time(&self) -> Option<std::time::Duration> {
None }
}
/// Runs `CoreMLOperation`s against a single CoreML device.
pub struct CoreMLExecutor {
// Device handle created in `new`; supplies the device id and capabilities.
device: super::device::CoreMLDevice,
}
impl CoreMLExecutor {
    /// Builds an executor bound to the CoreML device identified by `device_id`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `CoreMLDevice::new`.
    pub fn new(device_id: usize) -> CoreMLResult<Self> {
        Ok(Self {
            device: super::device::CoreMLDevice::new(device_id)?,
        })
    }

    /// Runs `operation` on this executor's device.
    ///
    /// # Errors
    ///
    /// Returns the error built by `error_helpers::unsupported_operation` when
    /// `operation.is_supported_by_coreml()` is `false`; otherwise forwards
    /// whatever `execute_coreml` returns.
    pub fn execute<T, Op>(&self, operation: &Op) -> CoreMLResult<Tensor<T>>
    where
        T: Float + FromPrimitive + ScalarOperand + 'static,
        Op: CoreMLOperation<T>,
    {
        if operation.is_supported_by_coreml() {
            operation.execute_coreml(self.device.device_id())
        } else {
            Err(super::error_helpers::unsupported_operation(
                "Operation not supported by CoreML",
            ))
        }
    }

    /// Returns the capabilities reported by the underlying device.
    pub fn capabilities(&self) -> &CoreMLCapabilities {
        let device = &self.device;
        device.capabilities()
    }
}
/// Evaluates a CoreML operation and falls back to a CPU implementation when
/// CoreML reports the operation as unsupported.
///
/// # Arguments
///
/// * `$operation` - expression yielding a `Result` whose error is matched
///   against `CoreMLError`.
/// * `$cpu_fallback` - expression used in place of the CoreML result when the
///   operation is unsupported, or when no CoreML feature is enabled.
///
/// # Expansion
///
/// * Inside `coreml_feature!`: `Ok` values pass through unchanged,
///   `CoreMLError::UnsupportedOperation` selects `$cpu_fallback`, and any other
///   error is converted with `.into()` and returned as `Err`.
/// * When none of the `coreml`, `coreml-hybrid` or `coreml-fallback` features
///   is enabled, `$cpu_fallback` is emitted directly.
///
/// NOTE(review): `coreml_feature!` is defined elsewhere; presumably it keeps
/// its body only when one of the CoreML features is enabled — confirm.
/// Both `coreml_feature!` and `CoreMLError` must be in scope at the call site,
/// because the macro refers to them by unqualified name.
#[macro_export]
macro_rules! coreml_fallback {
    ($operation:expr, $cpu_fallback:expr) => {
        coreml_feature! {
            match $operation {
                Ok(result) => Ok(result),
                // `UnsupportedOperation` is a struct-like variant, so it must be
                // matched with `{ .. }`; a tuple pattern `(_)` does not compile
                // against it.
                Err(CoreMLError::UnsupportedOperation { .. }) => $cpu_fallback,
                Err(e) => Err(e.into()),
            }
        }
        #[cfg(not(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        )))]
        $cpu_fallback
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the availability query must not panic. The value is only
    /// printed, not asserted.
    #[test]
    fn test_coreml_availability() {
        println!("CoreML available: {}", is_coreml_available());
    }

    /// With a CoreML feature enabled, executor creation may succeed (then f32
    /// support is required) or fail (the failure is only logged). Without any
    /// CoreML feature, creation must fail.
    #[test]
    fn test_executor_creation() {
        let result = CoreMLExecutor::new(0);

        #[cfg(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        ))]
        {
            match result {
                Err(e) => {
                    println!("CoreML executor creation failed (expected): {}", e);
                }
                Ok(executor) => {
                    let capabilities = executor.capabilities();
                    assert!(capabilities.supports_f32);
                }
            }
        }

        #[cfg(not(any(
            feature = "coreml",
            feature = "coreml-hybrid",
            feature = "coreml-fallback"
        )))]
        {
            assert!(result.is_err());
        }
    }
}