1use std::fmt;
4use std::sync::OnceLock;
5
/// Identifies one of the devices known to this module.
///
/// The type is a plain field-less enum, so it is cheap to copy, can be compared
/// for equality, and can be used as a key in hashed collections.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Device {
    /// The `Cpu` device.
    Cpu,
    /// The `Wasm` device (presumably WebAssembly -- confirm against callers).
    Wasm,
    /// The `WebGpu` device.
    WebGpu,
}
18
19impl Device {
20 pub fn all() -> &'static [Device] {
22 &[Device::Cpu, Device::Wasm, Device::WebGpu]
23 }
24
25 pub fn as_str(&self) -> &'static str {
27 match self {
28 Device::Cpu => "cpu",
29 Device::Wasm => "wasm",
30 Device::WebGpu => "webgpu",
31 }
32 }
33}
34
35impl fmt::Display for Device {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 f.write_str(self.as_str())
38 }
39}
40
41impl std::str::FromStr for Device {
42 type Err = String;
43
44 fn from_str(s: &str) -> Result<Self, Self::Err> {
45 match s {
46 "cpu" => Ok(Device::Cpu),
47 "wasm" => Ok(Device::Wasm),
48 "webgpu" => Ok(Device::WebGpu),
49 _ => Err(format!("Unknown device: {}", s)),
50 }
51 }
52}
53
54static DEFAULT_DEVICE: OnceLock<Device> = OnceLock::new();
56
57pub fn default_device() -> Device {
68 *DEFAULT_DEVICE.get_or_init(|| Device::Cpu)
69}
70
71pub fn set_default_device(device: Device) {
73 let _ = DEFAULT_DEVICE.set(device);
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    /// Each variant paired with the name the tests expect for it.
    const NAMED_DEVICES: [(Device, &str); 3] = [
        (Device::Cpu, "cpu"),
        (Device::Wasm, "wasm"),
        (Device::WebGpu, "webgpu"),
    ];

    /// `Display` renders every variant as its expected name.
    #[test]
    fn test_device_display() {
        for (device, name) in NAMED_DEVICES {
            assert_eq!(device.to_string(), name);
        }
    }

    /// `FromStr` accepts every expected name and rejects an unrecognised one.
    #[test]
    fn test_device_from_str() {
        for (device, name) in NAMED_DEVICES {
            assert_eq!(name.parse::<Device>().unwrap(), device);
        }
        assert!("unknown".parse::<Device>().is_err());
    }

    /// `Device::all` lists exactly three devices and includes every variant.
    #[test]
    fn test_device_all() {
        let devices = Device::all();
        assert_eq!(devices.len(), 3);
        for (device, _) in NAMED_DEVICES {
            assert!(devices.contains(&device));
        }
    }

    /// With nothing stored beforehand, the default device is `Cpu`.
    #[test]
    fn test_default_device() {
        assert_eq!(default_device(), Device::Cpu);
    }
}