// liquid_edge/device.rs

//! Device abstraction for liquid-edge inference.
//!
//! This module provides device abstractions following the USLS pattern.
//! Devices are simple enums that can be converted to ORT execution providers.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9#[allow(unused_imports)]
10#[cfg(all(feature = "onnx", not(target_arch = "wasm32")))]
11use ort::execution_providers::ExecutionProvider;
12
13/// Device types for model execution, following USLS pattern
14#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
15pub enum Device {
16    /// CPU device with thread count
17    Cpu(usize),
18    /// CUDA device with device ID
19    Cuda(usize),
20    /// WebGPU device (for WASM targets)
21    #[cfg(target_arch = "wasm32")]
22    WebGpu,
23}
24
25impl Default for Device {
26    fn default() -> Self {
27        #[cfg(target_arch = "wasm32")]
28        {
29            Self::WebGpu
30        }
31        #[cfg(not(target_arch = "wasm32"))]
32        {
33            Self::Cpu(0)
34        }
35    }
36}
37
38impl fmt::Display for Device {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            Self::Cpu(i) => write!(f, "cpu:{i}"),
42            Self::Cuda(i) => write!(f, "cuda:{i}"),
43            #[cfg(target_arch = "wasm32")]
44            Self::WebGpu => write!(f, "webgpu"),
45        }
46    }
47}
48
49impl std::str::FromStr for Device {
50    type Err = crate::EdgeError;
51
52    fn from_str(s: &str) -> Result<Self, Self::Err> {
53        #[inline]
54        fn parse_device_id(id_str: Option<&str>) -> usize {
55            id_str
56                .map(|s| s.trim().parse::<usize>().unwrap_or(0))
57                .unwrap_or(0)
58        }
59
60        let (device_type, id_part) = s
61            .trim()
62            .split_once(':')
63            .map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
64
65        match device_type.to_lowercase().as_str() {
66            "cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
67            "cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
68            #[cfg(target_arch = "wasm32")]
69            "webgpu" => Ok(Self::WebGpu),
70            _ => Err(crate::EdgeError::runtime(format!(
71                "Unsupported device: {s}"
72            ))),
73        }
74    }
75}
76
77impl Device {
78    /// Get the device ID if applicable
79    pub fn id(&self) -> Option<usize> {
80        match self {
81            Self::Cpu(i) | Self::Cuda(i) => Some(*i),
82            #[cfg(target_arch = "wasm32")]
83            Self::WebGpu => None,
84        }
85    }
86
87    /// Check if the device is available on the system
88    pub fn is_available(&self) -> bool {
89        match self {
90            Self::Cpu(_) => true, // CPU is always available
91            Self::Cuda(_) => {
92                #[cfg(all(feature = "onnx", feature = "cuda"))]
93                {
94                    use ort::execution_providers::CUDAExecutionProvider;
95                    CUDAExecutionProvider::default()
96                        .with_device_id(self.id().unwrap_or(0) as i32)
97                        .is_available()
98                        .unwrap_or(false)
99                }
100                #[cfg(not(all(feature = "onnx", feature = "cuda")))]
101                {
102                    false
103                }
104            }
105            #[cfg(target_arch = "wasm32")]
106            Self::WebGpu => {
107                // For WASM, WebGPU availability depends on browser support
108                // For now, assume it's available if the feature is enabled
109                cfg!(feature = "onnx")
110            }
111        }
112    }
113}
114
115/// Convenience functions for device creation
116pub fn cpu() -> Device {
117    Device::Cpu(1)
118}
119
120pub fn cpu_with_threads(threads: usize) -> Device {
121    Device::Cpu(threads)
122}
123
124pub fn cuda(device_id: usize) -> Device {
125    Device::Cuda(device_id)
126}
127
128pub fn cuda_default() -> Device {
129    Device::Cuda(0)
130}
131
/// Convenience constructor for the WebGPU device.
///
/// Only compiled for `wasm32` targets, where the `WebGpu` variant exists.
#[cfg(target_arch = "wasm32")]
pub fn webgpu() -> Device {
    Device::WebGpu
}