use crate::error::{Error, ErrorKind, Result};
// Compile-time-only reference to the `crate::error` imports. The function is
// private to this anonymous const block and nothing calls it, so no code runs.
// It only forces `Result<()>` and `Error::new(ErrorKind::UnsupportedPlatform, _)`
// to type-check. NOTE(review): it presumably exists to silence `unused_imports`
// while the rest of the module does not use these names — confirm.
const _: () = {
    fn _assert_imports() {
        let _: Result<()> = Ok(());
        let _ = Error::new(ErrorKind::UnsupportedPlatform, "");
    }
};
/// A compute device that [`available_devices`] can report.
///
/// The `Display` form is a short human-readable label such as `"CPU"`,
/// `"GPU (M1 Pro)"` or `"Neural Engine"`.
///
/// Equality and hashing are structural. Two `Gpu` values are equal only if
/// their `name`s are equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// The CPU.
    Cpu,
    /// A GPU.
    Gpu {
        /// Optional human-readable GPU name, shown in parentheses by
        /// `Display` when present. [`available_devices`] currently leaves
        /// this as `None`.
        name: Option<String>,
    },
    /// The Neural Engine. [`available_devices`] reports it only when built
    /// for an Apple `aarch64` target.
    NeuralEngine,
}
impl std::fmt::Display for ComputeDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Cpu => write!(f, "CPU"),
Self::Gpu { name: Some(n) } => write!(f, "GPU ({n})"),
Self::Gpu { name: None } => write!(f, "GPU"),
Self::NeuralEngine => write!(f, "Neural Engine"),
}
}
}
/// Lists the compute devices this build reports.
///
/// On Apple targets this is `Cpu`, then an unnamed `Gpu`, then `NeuralEngine`
/// on `aarch64` targets only. The list is decided at compile time from the
/// target; no hardware is probed at runtime.
#[cfg(target_vendor = "apple")]
pub fn available_devices() -> Vec<ComputeDevice> {
    let mut found = vec![ComputeDevice::Cpu, ComputeDevice::Gpu { name: None }];
    // The Neural Engine is reported only when building for `aarch64`.
    if cfg!(target_arch = "aarch64") {
        found.push(ComputeDevice::NeuralEngine);
    }
    found
}

/// Fallback for non-Apple targets: no devices are reported, not even the CPU.
#[cfg(not(target_vendor = "apple"))]
pub fn available_devices() -> Vec<ComputeDevice> {
    Vec::new()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn compute_device_display() {
        assert_eq!(format!("{}", ComputeDevice::Cpu), "CPU");
        assert_eq!(
            format!(
                "{}",
                ComputeDevice::Gpu {
                    name: Some("M1 Pro".into())
                }
            ),
            "GPU (M1 Pro)"
        );
        assert_eq!(
            format!("{}", ComputeDevice::Gpu { name: None }),
            "GPU"
        );
        assert_eq!(format!("{}", ComputeDevice::NeuralEngine), "Neural Engine");
    }

    #[test]
    fn compute_device_equality() {
        assert_eq!(ComputeDevice::Cpu, ComputeDevice::Cpu);
        assert_ne!(ComputeDevice::Cpu, ComputeDevice::NeuralEngine);
        // `Gpu` values compare structurally, including the optional name.
        assert_eq!(
            ComputeDevice::Gpu { name: None },
            ComputeDevice::Gpu { name: None }
        );
        assert_ne!(
            ComputeDevice::Gpu { name: None },
            ComputeDevice::Gpu {
                name: Some("M1 Pro".into())
            }
        );
        assert_ne!(ComputeDevice::Gpu { name: None }, ComputeDevice::Cpu);
    }

    #[cfg(target_vendor = "apple")]
    #[test]
    fn available_devices_non_empty() {
        let devices = available_devices();
        assert!(!devices.is_empty());
        assert!(devices.contains(&ComputeDevice::Cpu));
    }

    // On Apple targets a GPU is always listed, and the Neural Engine is listed
    // exactly when building for aarch64.
    #[cfg(target_vendor = "apple")]
    #[test]
    fn available_devices_gpu_and_neural_engine_follow_target() {
        let devices = available_devices();
        assert!(devices.contains(&ComputeDevice::Gpu { name: None }));
        assert_eq!(
            devices.contains(&ComputeDevice::NeuralEngine),
            cfg!(target_arch = "aarch64")
        );
    }

    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn available_devices_empty_on_non_apple() {
        assert!(available_devices().is_empty());
    }
}