1use crate::error::{Error, ErrorKind, Result};
7
8const _: () = {
10 fn _assert_imports() {
11 let _ = Error::new(ErrorKind::UnsupportedPlatform, "");
12 let _: Result<()> = Ok(());
13 }
14};
15
/// A kind of compute device.
///
/// Besides `Debug`, `Clone`, and `PartialEq`, the type derives `Eq` and `Hash`
/// so values can be deduplicated in a `HashSet` or used as `HashMap` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ComputeDevice {
    /// The CPU; displayed as `CPU`.
    Cpu,
    /// A GPU; displayed as `GPU`, or `GPU (<name>)` when a name is present.
    Gpu {
        /// Optional device name; `None` when no name is known.
        name: Option<String>,
    },
    /// The Neural Engine; displayed as `Neural Engine`.
    NeuralEngine,
}
29
30impl std::fmt::Display for ComputeDevice {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 Self::Cpu => write!(f, "CPU"),
34 Self::Gpu { name: Some(n) } => write!(f, "GPU ({n})"),
35 Self::Gpu { name: None } => write!(f, "GPU"),
36 Self::NeuralEngine => write!(f, "Neural Engine"),
37 }
38 }
39}
40
41#[cfg(target_vendor = "apple")]
43pub fn available_devices() -> Vec<ComputeDevice> {
44 let mut devices = vec![ComputeDevice::Cpu];
50
51 devices.push(ComputeDevice::Gpu { name: None });
53
54 #[cfg(target_arch = "aarch64")]
56 devices.push(ComputeDevice::NeuralEngine);
57
58 devices
59}
60
61#[cfg(not(target_vendor = "apple"))]
62pub fn available_devices() -> Vec<ComputeDevice> {
63 vec![]
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant renders its expected label through `Display`.
    #[test]
    fn compute_device_display() {
        let cases = [
            (ComputeDevice::Cpu, "CPU"),
            (
                ComputeDevice::Gpu {
                    name: Some("M1 Pro".into()),
                },
                "GPU (M1 Pro)",
            ),
            (ComputeDevice::Gpu { name: None }, "GPU"),
            (ComputeDevice::NeuralEngine, "Neural Engine"),
        ];

        for (device, expected) in cases {
            assert_eq!(device.to_string(), expected);
        }
    }

    /// Equal variants compare equal; different variants do not.
    #[test]
    fn compute_device_equality() {
        let cpu = ComputeDevice::Cpu;
        assert_eq!(cpu, ComputeDevice::Cpu);
        assert_ne!(cpu, ComputeDevice::NeuralEngine);
    }

    /// On Apple targets the device list is non-empty and includes the CPU.
    #[cfg(target_vendor = "apple")]
    #[test]
    fn available_devices_non_empty() {
        let found = available_devices();
        assert!(!found.is_empty());
        assert!(found.contains(&ComputeDevice::Cpu));
    }

    /// On every other target the device list is empty.
    #[cfg(not(target_vendor = "apple"))]
    #[test]
    fn available_devices_empty_on_non_apple() {
        let found = available_devices();
        assert!(found.is_empty());
    }
}