liquid_edge/
device.rs

1//! Device abstraction for liquid-edge inference
2//!
3//! This module provides device abstractions following the USLS pattern.
4//! Devices are simple enums that can be converted to ORT execution providers.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9#[allow(unused_imports)]
10#[cfg(feature = "onnx")]
11use ort::execution_providers::ExecutionProvider;
12
/// Device types for model execution, following USLS pattern.
///
/// Each variant carries a single `usize` payload. The type is a small `Copy`
/// value that can be compared, ordered, and (de)serialized with serde.
/// Its `Display` form is `"cpu:<n>"` / `"cuda:<n>"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Device {
    /// CPU device with thread count
    Cpu(usize),
    /// CUDA device with device ID
    Cuda(usize),
}
21
22impl Default for Device {
23    fn default() -> Self {
24        Self::Cpu(0)
25    }
26}
27
28impl fmt::Display for Device {
29    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
30        match self {
31            Self::Cpu(i) => write!(f, "cpu:{i}"),
32            Self::Cuda(i) => write!(f, "cuda:{i}"),
33        }
34    }
35}
36
37impl std::str::FromStr for Device {
38    type Err = crate::EdgeError;
39
40    fn from_str(s: &str) -> Result<Self, Self::Err> {
41        #[inline]
42        fn parse_device_id(id_str: Option<&str>) -> usize {
43            id_str
44                .map(|s| s.trim().parse::<usize>().unwrap_or(0))
45                .unwrap_or(0)
46        }
47
48        let (device_type, id_part) = s
49            .trim()
50            .split_once(':')
51            .map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
52
53        match device_type.to_lowercase().as_str() {
54            "cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
55            "cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
56            _ => Err(crate::EdgeError::runtime(format!(
57                "Unsupported device: {s}"
58            ))),
59        }
60    }
61}
62
63impl Device {
64    /// Get the device ID if applicable
65    pub fn id(&self) -> Option<usize> {
66        match self {
67            Self::Cpu(i) | Self::Cuda(i) => Some(*i),
68        }
69    }
70
71    /// Check if the device is available on the system
72    pub fn is_available(&self) -> bool {
73        match self {
74            Self::Cpu(_) => true, // CPU is always available
75            Self::Cuda(_) => {
76                #[cfg(all(feature = "onnx", feature = "cuda"))]
77                {
78                    use ort::execution_providers::CUDAExecutionProvider;
79                    CUDAExecutionProvider::default()
80                        .with_device_id(self.id().unwrap_or(0) as i32)
81                        .is_available()
82                        .unwrap_or(false)
83                }
84                #[cfg(not(all(feature = "onnx", feature = "cuda")))]
85                {
86                    false
87                }
88            }
89        }
90    }
91}
92
93/// Convenience functions for device creation
94pub fn cpu() -> Device {
95    Device::Cpu(1)
96}
97
98pub fn cpu_with_threads(threads: usize) -> Device {
99    Device::Cpu(threads)
100}
101
102pub fn cuda(device_id: usize) -> Device {
103    Device::Cuda(device_id)
104}
105
106pub fn cuda_default() -> Device {
107    Device::Cuda(0)
108}