// jax_rs/device.rs

//! Device and backend management.
2
3use std::fmt;
4use std::sync::OnceLock;
5
/// Compute device on which array operations run.
///
/// Mirrors the jax-js `Device` type, whose values are `"cpu" | "wasm" | "webgpu"`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// CPU backend — slow, intended for debugging.
    Cpu,
    /// WebAssembly backend, optionally using SIMD.
    Wasm,
    /// WebGPU backend — the primary accelerator.
    WebGpu,
}
18
19impl Device {
20    /// Returns all available devices.
21    pub fn all() -> &'static [Device] {
22        &[Device::Cpu, Device::Wasm, Device::WebGpu]
23    }
24
25    /// Returns the name of this device as a string.
26    pub fn as_str(&self) -> &'static str {
27        match self {
28            Device::Cpu => "cpu",
29            Device::Wasm => "wasm",
30            Device::WebGpu => "webgpu",
31        }
32    }
33}
34
35impl fmt::Display for Device {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        f.write_str(self.as_str())
38    }
39}
40
41impl std::str::FromStr for Device {
42    type Err = String;
43
44    fn from_str(s: &str) -> Result<Self, Self::Err> {
45        match s {
46            "cpu" => Ok(Device::Cpu),
47            "wasm" => Ok(Device::Wasm),
48            "webgpu" => Ok(Device::WebGpu),
49            _ => Err(format!("Unknown device: {}", s)),
50        }
51    }
52}
53
54/// Global default device for array operations.
55static DEFAULT_DEVICE: OnceLock<Device> = OnceLock::new();
56
57/// Get or set the default device.
58///
59/// # Examples
60///
61/// ```
62/// # use jax_rs::Device;
63/// // Get current default (CPU by default)
64/// let device = jax_rs::default_device();
65/// assert_eq!(device, Device::Cpu);
66/// ```
67pub fn default_device() -> Device {
68    *DEFAULT_DEVICE.get_or_init(|| Device::Cpu)
69}
70
71/// Set the default device for array operations.
72pub fn set_default_device(device: Device) {
73    // Note: OnceLock doesn't support mutation after init,
74    // so we need a different approach for runtime mutation.
75    // For now, using a global mutable would require unsafe or a Mutex.
76    // Let's use a simple global static that can be updated.
77    // This will be improved in later phases with proper backend management.
78    let _ = DEFAULT_DEVICE.set(device);
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_display() {
        let cases = [
            (Device::Cpu, "cpu"),
            (Device::Wasm, "wasm"),
            (Device::WebGpu, "webgpu"),
        ];
        for (device, expected) in cases {
            assert_eq!(device.to_string(), expected);
        }
    }

    #[test]
    fn test_device_from_str() {
        let cases = [
            ("cpu", Device::Cpu),
            ("wasm", Device::Wasm),
            ("webgpu", Device::WebGpu),
        ];
        for (text, expected) in cases {
            assert_eq!(text.parse::<Device>().unwrap(), expected);
        }
        assert!("unknown".parse::<Device>().is_err());
    }

    #[test]
    fn test_device_all() {
        let devices = Device::all();
        assert_eq!(devices.len(), 3);
        for expected in [Device::Cpu, Device::Wasm, Device::WebGpu] {
            assert!(devices.contains(&expected));
        }
    }

    #[test]
    fn test_default_device() {
        assert_eq!(default_device(), Device::Cpu);
    }
}