Skip to main content

coreml_native/
compute.rs

1//! Compute device enumeration and inspection.
2//!
3//! Discover available compute devices (CPU, GPU, Neural Engine)
4//! and their capabilities.
5
6use crate::error::{Error, ErrorKind, Result};
7
8// Suppress unused warnings until callers use the error types.
9const _: () = {
10    fn _assert_imports() {
11        let _ = Error::new(ErrorKind::UnsupportedPlatform, "");
12        let _: Result<()> = Ok(());
13    }
14};
15
/// A compute device available for CoreML inference.
///
/// Equality is structural: two `Gpu` values are equal only when their
/// `name` fields are equal. The type also implements `Eq` and `Hash`, so it
/// can be stored in `HashSet`s or used as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// CPU compute device.
    Cpu,
    /// GPU (Metal) compute device.
    Gpu {
        /// Metal device name, if available; `None` when the name is unknown.
        name: Option<String>,
    },
    /// Apple Neural Engine.
    NeuralEngine,
}
29
30impl std::fmt::Display for ComputeDevice {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        match self {
33            Self::Cpu => write!(f, "CPU"),
34            Self::Gpu { name: Some(n) } => write!(f, "GPU ({n})"),
35            Self::Gpu { name: None } => write!(f, "GPU"),
36            Self::NeuralEngine => write!(f, "Neural Engine"),
37        }
38    }
39}
40
/// Returns the compute devices assumed to be available for CoreML on this
/// system.
///
/// The list is a compile-time default, not the result of a runtime query:
/// this function does not call `MLAllComputeDevices`. It always contains
/// [`ComputeDevice::Cpu`] followed by an unnamed [`ComputeDevice::Gpu`], and
/// additionally [`ComputeDevice::NeuralEngine`] when built for `aarch64`.
#[cfg(target_vendor = "apple")]
pub fn available_devices() -> Vec<ComputeDevice> {
    // No device name is looked up here, so the GPU entry carries `None`.
    let unnamed_gpu = ComputeDevice::Gpu { name: None };

    // The Neural Engine entry is tied to the target architecture at build
    // time; nothing is probed at runtime.
    if cfg!(target_arch = "aarch64") {
        vec![ComputeDevice::Cpu, unnamed_gpu, ComputeDevice::NeuralEngine]
    } else {
        vec![ComputeDevice::Cpu, unnamed_gpu]
    }
}
60
61#[cfg(not(target_vendor = "apple"))]
62pub fn available_devices() -> Vec<ComputeDevice> {
63    vec![]
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant renders to its expected display label.
    #[test]
    fn compute_device_display() {
        let named_gpu = ComputeDevice::Gpu {
            name: Some(String::from("M1 Pro")),
        };
        let unnamed_gpu = ComputeDevice::Gpu { name: None };

        assert_eq!(ComputeDevice::Cpu.to_string(), "CPU");
        assert_eq!(named_gpu.to_string(), "GPU (M1 Pro)");
        assert_eq!(unnamed_gpu.to_string(), "GPU");
        assert_eq!(ComputeDevice::NeuralEngine.to_string(), "Neural Engine");
    }

    /// Identical variants compare equal; different variants do not.
    #[test]
    fn compute_device_equality() {
        let cpu = ComputeDevice::Cpu;

        assert_eq!(cpu, ComputeDevice::Cpu);
        assert_ne!(cpu, ComputeDevice::NeuralEngine);
    }

    /// On Apple targets the device list is non-empty and includes the CPU.
    #[test]
    #[cfg(target_vendor = "apple")]
    fn available_devices_non_empty() {
        let found = available_devices();

        assert!(!found.is_empty());
        assert!(found.iter().any(|device| *device == ComputeDevice::Cpu));
    }

    /// On non-Apple targets no devices are reported.
    #[test]
    #[cfg(not(target_vendor = "apple"))]
    fn available_devices_empty_on_non_apple() {
        let found = available_devices();

        assert!(found.is_empty());
    }
}