use std::fmt;
/// Identifies which execution provider a model should run on.
///
/// `Cpu` is always available and is the `Default` value. The `CoreML` variant only exists when
/// the crate is compiled with the `ort-coreml` feature, and carries its `CoreMLConfig`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ExecutionProviderKind {
    /// Plain CPU execution; the value returned by `ExecutionProviderKind::default()`.
    #[default]
    Cpu,
    /// CoreML execution using the attached configuration.
    #[cfg(feature = "ort-coreml")]
    CoreML(CoreMLConfig),
}
/// Renders the provider as a short lowercase label.
///
/// `Cpu` is written as `cpu`. A CoreML provider is written as `coreml-` followed by the
/// `Display` form of its compute units (for example `coreml-ane`).
impl fmt::Display for ExecutionProviderKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Cpu => f.write_str("cpu"),
            #[cfg(feature = "ort-coreml")]
            Self::CoreML(CoreMLConfig { compute_units, .. }) => {
                write!(f, "coreml-{}", compute_units)
            }
        }
    }
}
impl ExecutionProviderKind {
    /// Returns the provider family name: `"cpu"` or `"coreml"`.
    ///
    /// Unlike the `Display` output, the CoreML name does not include the compute-unit suffix.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Cpu => "cpu",
            #[cfg(feature = "ort-coreml")]
            Self::CoreML(..) => "coreml",
        }
    }

    /// Reports whether the provider is flagged as needing dedicated hardware.
    ///
    /// `false` for `Cpu`, `true` for every CoreML configuration. The match is kept exhaustive so
    /// that adding a variant forces a decision here.
    pub fn requires_hardware(&self) -> bool {
        match *self {
            Self::Cpu => false,
            #[cfg(feature = "ort-coreml")]
            Self::CoreML(..) => true,
        }
    }
}
/// Settings carried by `ExecutionProviderKind::CoreML`.
///
/// Only compiled when the `ort-coreml` feature is enabled.
#[cfg(feature = "ort-coreml")]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct CoreMLConfig {
    /// Which compute units CoreML is allowed to use.
    pub compute_units: CoreMLComputeUnits,
    /// Defaults to `true`.
    /// NOTE(review): presumably lets CoreML also run on subgraphs — confirm against the code
    /// that consumes this config.
    pub use_subgraphs: bool,
    /// Defaults to `false`.
    /// NOTE(review): presumably restricts CoreML to inputs with static shapes — confirm against
    /// the code that consumes this config.
    pub require_static_shapes: bool,
}
#[cfg(feature = "ort-coreml")]
impl Default for CoreMLConfig {
    /// Builds the baseline configuration: the default compute units, `use_subgraphs` enabled and
    /// `require_static_shapes` disabled.
    fn default() -> Self {
        let compute_units = CoreMLComputeUnits::default();
        Self {
            compute_units,
            use_subgraphs: true,
            require_static_shapes: false,
        }
    }
}
#[cfg(feature = "ort-coreml")]
impl CoreMLConfig {
    /// Default configuration targeting the CPU plus the Neural Engine.
    pub fn with_neural_engine() -> Self {
        Self::for_units(CoreMLComputeUnits::CpuAndNeuralEngine)
    }

    /// Default configuration targeting the CPU plus the GPU.
    pub fn with_gpu() -> Self {
        Self::for_units(CoreMLComputeUnits::CpuAndGpu)
    }

    /// Default configuration restricted to the CPU.
    pub fn cpu_only() -> Self {
        Self::for_units(CoreMLComputeUnits::CpuOnly)
    }

    /// Shared constructor: the default configuration with only `compute_units` overridden.
    fn for_units(compute_units: CoreMLComputeUnits) -> Self {
        Self {
            compute_units,
            ..Self::default()
        }
    }
}
/// The set of compute units a CoreML provider may use.
///
/// Only compiled when the `ort-coreml` feature is enabled.
#[cfg(feature = "ort-coreml")]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CoreMLComputeUnits {
    /// CPU only.
    CpuOnly,
    /// CPU and GPU.
    CpuAndGpu,
    /// CPU and Neural Engine; the `Default` value.
    #[default]
    CpuAndNeuralEngine,
    /// All compute units.
    All,
}
/// Renders the compute units as the short tag used in provider strings:
/// `cpu`, `gpu`, `ane` or `all`.
#[cfg(feature = "ort-coreml")]
impl fmt::Display for CoreMLComputeUnits {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let tag = match self {
            Self::CpuOnly => "cpu",
            Self::CpuAndGpu => "gpu",
            Self::CpuAndNeuralEngine => "ane",
            Self::All => "all",
        };
        f.write_str(tag)
    }
}
#[cfg(feature = "ort-coreml")]
impl CoreMLComputeUnits {
    /// Parses a compute-unit name, ignoring case.
    ///
    /// Accepted spellings:
    /// - `cpu`, `cpu-only` → `CpuOnly`
    /// - `gpu`, `cpu-gpu` → `CpuAndGpu`
    /// - `ane`, `neural-engine`, `cpu-ane` → `CpuAndNeuralEngine`
    /// - `all` → `All`
    ///
    /// Returns `None` for anything else. The input is not trimmed.
    ///
    /// NOTE(review): this inherent method has the same name as `std::str::FromStr::from_str`
    /// but returns `Option`; consider adding a `FromStr` impl so `str::parse` works too.
    pub fn from_str(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        let units = match lowered.as_str() {
            "cpu" | "cpu-only" => Self::CpuOnly,
            "gpu" | "cpu-gpu" => Self::CpuAndGpu,
            "ane" | "neural-engine" | "cpu-ane" => Self::CpuAndNeuralEngine,
            "all" => Self::All,
            _ => return None,
        };
        Some(units)
    }
}
/// Optional facts about a model that `select_optimal_provider` uses to pick a provider.
///
/// Every field is optional; `ModelHints::default()` leaves all of them as `None`.
#[derive(Clone, Debug, Default)]
pub struct ModelHints {
    /// Task label such as `"image_classification"` or `"text-to-speech"`. The `is_*_model`
    /// helpers match keywords against it case-insensitively.
    pub task: Option<String>,
    /// Input tensor shapes, one inner vector per input.
    pub input_shapes: Option<Vec<Vec<i64>>>,
    /// `Some(true)` when all input dimensions are known to be static, `Some(false)` when at
    /// least one is dynamic, `None` when unknown.
    pub static_shapes: Option<bool>,
    /// Model size in megabytes.
    pub model_size_mb: Option<f32>,
    /// `Some(true)` marks the model as autoregressive, which makes provider selection fall back
    /// to the CPU.
    pub autoregressive: Option<bool>,
    /// Provider name requested explicitly; when set, `select_optimal_provider` parses it and
    /// skips every heuristic.
    pub explicit_provider: Option<String>,
}
impl ModelHints {
pub fn from_metadata(metadata: &serde_json::Value) -> Self {
let mut hints = Self::default();
if let Some(task) = metadata.get("task").and_then(|v| v.as_str()) {
hints.task = Some(task.to_string());
}
if let Some(input_shape) = metadata.get("input_shape").and_then(|v| v.as_array()) {
let shape: Vec<i64> = input_shape.iter().filter_map(|v| v.as_i64()).collect();
if !shape.is_empty() {
let is_static = shape.iter().all(|&d| d > 0);
hints.static_shapes = Some(is_static);
hints.input_shapes = Some(vec![shape]);
}
}
if let Some(size) = metadata.get("model_size_mb").and_then(|v| v.as_f64()) {
hints.model_size_mb = Some(size as f32);
}
hints
}
pub fn is_vision_model(&self) -> bool {
if let Some(ref task) = self.task {
let task_lower = task.to_lowercase();
return task_lower.contains("image")
|| task_lower.contains("vision")
|| task_lower.contains("classification")
|| task_lower.contains("detection")
|| task_lower.contains("segmentation");
}
if let Some(ref shapes) = self.input_shapes {
if let Some(shape) = shapes.first() {
if shape.len() == 4 && shape[1] <= 4 {
return true;
}
}
}
false
}
pub fn is_tts_model(&self) -> bool {
if let Some(ref task) = self.task {
let task_lower = task.to_lowercase();
return task_lower.contains("tts")
|| task_lower.contains("text-to-speech")
|| task_lower.contains("speech-synthesis");
}
false
}
pub fn is_embedding_model(&self) -> bool {
if let Some(ref task) = self.task {
let task_lower = task.to_lowercase();
return task_lower.contains("embedding")
|| task_lower.contains("encoder")
|| task_lower.contains("sentence");
}
false
}
pub fn is_tiny_model(&self) -> bool {
if let Some(size) = self.model_size_mb {
return size < 1.0; }
false
}
}
/// Chooses an execution provider for a model described by `hints`.
///
/// An explicit provider request always wins and is resolved by `parse_provider_string`.
///
/// The heuristics below only exist when the crate is built with the `ort-coreml` feature for
/// macOS or iOS; every other build returns `Cpu`:
///
/// 1. `Cpu` when the model is tiny, is a TTS model, is marked autoregressive, or has dynamic
///    shapes (`static_shapes == Some(false)`).
/// 2. CoreML on the Neural Engine when the model is a vision model, is an embedding model, or
///    has static shapes and a known size between 1 MB and 500 MB inclusive.
/// 3. `Cpu` otherwise.
pub fn select_optimal_provider(hints: &ModelHints) -> ExecutionProviderKind {
    if let Some(requested) = hints.explicit_provider.as_deref() {
        return parse_provider_string(requested);
    }

    #[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
    {
        let prefers_cpu = hints.is_tiny_model()
            || hints.is_tts_model()
            || hints.autoregressive == Some(true)
            || hints.static_shapes == Some(false);
        if prefers_cpu {
            return ExecutionProviderKind::Cpu;
        }

        // Dynamic shapes were ruled out above, so the vision check needs no extra shape guard.
        let mid_sized_static = hints.static_shapes == Some(true)
            && matches!(hints.model_size_mb, Some(mb) if (1.0..=500.0).contains(&mb));
        if hints.is_vision_model() || hints.is_embedding_model() || mid_sized_static {
            return ExecutionProviderKind::CoreML(CoreMLConfig::with_neural_engine());
        }
    }

    ExecutionProviderKind::Cpu
}
/// Parses a provider name into an [`ExecutionProviderKind`].
///
/// Matching ignores case and surrounding whitespace. Without the `ort-coreml` feature every
/// input resolves to `Cpu`. With it, the accepted names are:
///
/// - `cpu` → `Cpu`
/// - `coreml`, `coreml-ane`, `ane`, `neural-engine` → CoreML on CPU + Neural Engine
/// - `coreml-gpu`, `gpu` → CoreML on CPU + GPU
/// - `coreml-all`, `all` → CoreML on all compute units
/// - `coreml-cpu` → CoreML restricted to the CPU
///
/// Every label produced by `ExecutionProviderKind`'s `Display` impl is accepted, including
/// `coreml-cpu`, so a displayed provider parses back to the same compute units.
///
/// Unrecognised names fall back to `Cpu`; this function never fails.
pub fn parse_provider_string(s: &str) -> ExecutionProviderKind {
    // Trim first: values read from config files or environment variables often carry stray
    // whitespace, which would otherwise silently select the CPU fallback.
    let normalized = s.trim().to_lowercase();
    match normalized.as_str() {
        "cpu" => ExecutionProviderKind::Cpu,
        #[cfg(feature = "ort-coreml")]
        "coreml" | "coreml-ane" | "ane" | "neural-engine" => {
            ExecutionProviderKind::CoreML(CoreMLConfig::with_neural_engine())
        }
        #[cfg(feature = "ort-coreml")]
        "coreml-gpu" | "gpu" => ExecutionProviderKind::CoreML(CoreMLConfig::with_gpu()),
        #[cfg(feature = "ort-coreml")]
        "coreml-all" | "all" => ExecutionProviderKind::CoreML(CoreMLConfig {
            compute_units: CoreMLComputeUnits::All,
            ..Default::default()
        }),
        // `Display` renders a CPU-only CoreML config as "coreml-cpu"; accept it here so it is
        // not mistaken for the plain CPU provider.
        #[cfg(feature = "ort-coreml")]
        "coreml-cpu" => ExecutionProviderKind::CoreML(CoreMLConfig::cpu_only()),
        _ => ExecutionProviderKind::Cpu,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the CoreML configuration from `provider`, panicking with `failure` otherwise.
    #[cfg(feature = "ort-coreml")]
    fn coreml_config_or_panic(provider: ExecutionProviderKind, failure: &str) -> CoreMLConfig {
        match provider {
            ExecutionProviderKind::CoreML(config) => config,
            _ => panic!("{}", failure),
        }
    }

    // The default provider is the CPU and needs no dedicated hardware.
    #[test]
    fn test_default_provider_is_cpu() {
        let default_kind = ExecutionProviderKind::default();
        assert_eq!(default_kind, ExecutionProviderKind::Cpu);
        assert_eq!(default_kind.name(), "cpu");
        assert!(!default_kind.requires_hardware());
    }

    #[test]
    fn test_cpu_provider_display() {
        assert_eq!(ExecutionProviderKind::Cpu.to_string(), "cpu");
    }

    #[cfg(feature = "ort-coreml")]
    #[test]
    fn test_coreml_config_default() {
        let CoreMLConfig {
            compute_units,
            use_subgraphs,
            require_static_shapes,
        } = CoreMLConfig::default();
        assert_eq!(compute_units, CoreMLComputeUnits::CpuAndNeuralEngine);
        assert!(use_subgraphs);
        assert!(!require_static_shapes);
    }

    #[cfg(feature = "ort-coreml")]
    #[test]
    fn test_coreml_compute_units_from_str() {
        let cases = [
            ("ane", Some(CoreMLComputeUnits::CpuAndNeuralEngine)),
            ("gpu", Some(CoreMLComputeUnits::CpuAndGpu)),
            ("cpu", Some(CoreMLComputeUnits::CpuOnly)),
            ("all", Some(CoreMLComputeUnits::All)),
            ("invalid", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(CoreMLComputeUnits::from_str(input), expected);
        }
    }

    #[cfg(feature = "ort-coreml")]
    #[test]
    fn test_coreml_provider_display() {
        let ane = ExecutionProviderKind::CoreML(CoreMLConfig::with_neural_engine());
        assert_eq!(ane.to_string(), "coreml-ane");
        assert_eq!(ane.name(), "coreml");
        assert!(ane.requires_hardware());
    }

    // Task, shape and size are all picked up from the metadata object.
    #[test]
    fn test_model_hints_from_metadata_vision() {
        let vision_metadata = serde_json::json!({
            "task": "image_classification",
            "input_shape": [1, 3, 224, 224],
            "model_size_mb": 13.3
        });
        let parsed = ModelHints::from_metadata(&vision_metadata);
        assert_eq!(parsed.task, Some("image_classification".to_string()));
        assert_eq!(parsed.static_shapes, Some(true));
        assert_eq!(parsed.model_size_mb, Some(13.3));
        assert!(parsed.is_vision_model());
        assert!(!parsed.is_tts_model());
    }

    #[test]
    fn test_model_hints_from_metadata_tts() {
        let tts_metadata = serde_json::json!({
            "task": "text-to-speech",
            "model_size_mb": 170.0
        });
        let parsed = ModelHints::from_metadata(&tts_metadata);
        assert!(parsed.is_tts_model());
        assert!(!parsed.is_vision_model());
    }

    // Without a task, a rank-4 shape with few channels is enough to look like a vision model.
    #[test]
    fn test_model_hints_vision_by_shape() {
        let shape_only = ModelHints {
            input_shapes: Some(vec![vec![1, 3, 224, 224]]),
            ..Default::default()
        };
        assert!(shape_only.is_vision_model());
    }

    #[test]
    fn test_tiny_model_detection() {
        let with_size = |mb: f32| ModelHints {
            model_size_mb: Some(mb),
            ..Default::default()
        };
        assert!(with_size(0.5).is_tiny_model());
        assert!(!with_size(13.0).is_tiny_model());
    }

    // Case is ignored and unknown names fall back to the CPU.
    #[test]
    fn test_parse_provider_string() {
        for &input in &["cpu", "CPU", "unknown"] {
            assert_eq!(parse_provider_string(input), ExecutionProviderKind::Cpu);
        }
    }

    #[cfg(feature = "ort-coreml")]
    #[test]
    fn test_parse_provider_string_coreml() {
        let ane = coreml_config_or_panic(
            parse_provider_string("coreml-ane"),
            "Expected CoreML provider",
        );
        assert_eq!(ane.compute_units, CoreMLComputeUnits::CpuAndNeuralEngine);

        let gpu = coreml_config_or_panic(
            parse_provider_string("gpu"),
            "Expected CoreML GPU provider",
        );
        assert_eq!(gpu.compute_units, CoreMLComputeUnits::CpuAndGpu);
    }

    // An explicit request beats heuristics that would otherwise pick CoreML.
    #[test]
    fn test_select_optimal_provider_explicit_override() {
        let forced_cpu = ModelHints {
            explicit_provider: Some("cpu".to_string()),
            task: Some("image_classification".to_string()),
            static_shapes: Some(true),
            model_size_mb: Some(13.0),
            ..Default::default()
        };
        assert_eq!(
            select_optimal_provider(&forced_cpu),
            ExecutionProviderKind::Cpu
        );
    }

    #[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
    #[test]
    fn test_select_optimal_provider_vision() {
        let vision = ModelHints {
            task: Some("image_classification".to_string()),
            static_shapes: Some(true),
            model_size_mb: Some(13.0),
            ..Default::default()
        };
        let config = coreml_config_or_panic(
            select_optimal_provider(&vision),
            "Expected CoreML ANE for vision model",
        );
        assert_eq!(config.compute_units, CoreMLComputeUnits::CpuAndNeuralEngine);
    }

    #[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
    #[test]
    fn test_select_optimal_provider_tts_returns_cpu() {
        let tts = ModelHints {
            task: Some("text-to-speech".to_string()),
            model_size_mb: Some(170.0),
            ..Default::default()
        };
        assert_eq!(select_optimal_provider(&tts), ExecutionProviderKind::Cpu);
    }

    // A tiny model stays on the CPU even though its task says "vision".
    #[cfg(all(feature = "ort-coreml", any(target_os = "macos", target_os = "ios")))]
    #[test]
    fn test_select_optimal_provider_tiny_model_returns_cpu() {
        let tiny_vision = ModelHints {
            task: Some("image_classification".to_string()),
            static_shapes: Some(true),
            model_size_mb: Some(0.02),
            ..Default::default()
        };
        assert_eq!(
            select_optimal_provider(&tiny_vision),
            ExecutionProviderKind::Cpu
        );
    }
}