autoagents_onnx/
device.rs

//! Device abstraction for liquid-edge inference.
//!
//! This module provides device abstractions following the USLS pattern:
//! devices are simple enums that can be converted to ORT execution providers.
5
6use serde::{Deserialize, Serialize};
7use std::fmt;
8
9#[allow(unused_imports)]
10use ort::execution_providers::ExecutionProvider;
11
12/// Device types for model execution, following USLS pattern
13#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
14pub enum Device {
15    /// CPU device with thread count
16    Cpu(usize),
17    /// CUDA device with device ID
18    #[cfg(feature = "cuda")]
19    Cuda(usize),
20}
21
22impl Default for Device {
23    fn default() -> Self {
24        {
25            Self::Cpu(0)
26        }
27    }
28}
29
30impl fmt::Display for Device {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        match self {
33            Self::Cpu(i) => write!(f, "cpu:{i}"),
34            #[cfg(feature = "cuda")]
35            Self::Cuda(i) => write!(f, "cuda:{i}"),
36        }
37    }
38}
39
40impl std::str::FromStr for Device {
41    type Err = crate::EdgeError;
42
43    fn from_str(s: &str) -> Result<Self, Self::Err> {
44        #[inline]
45        fn parse_device_id(id_str: Option<&str>) -> usize {
46            id_str
47                .map(|s| s.trim().parse::<usize>().unwrap_or(0))
48                .unwrap_or(0)
49        }
50
51        let (device_type, id_part) = s
52            .trim()
53            .split_once(':')
54            .map_or_else(|| (s.trim(), None), |(device, id)| (device, Some(id)));
55
56        match device_type.to_lowercase().as_str() {
57            "cpu" => Ok(Self::Cpu(parse_device_id(id_part))),
58            #[cfg(feature = "cuda")]
59            "cuda" => Ok(Self::Cuda(parse_device_id(id_part))),
60            _ => Err(crate::EdgeError::runtime(format!(
61                "Unsupported device: {s}"
62            ))),
63        }
64    }
65}
66
67impl Device {
68    /// Get the device ID if applicable
69    pub fn id(&self) -> Option<usize> {
70        match self {
71            Self::Cpu(i) => Some(*i),
72            #[cfg(feature = "cuda")]
73            Self::Cuda(i) => Some(*i),
74        }
75    }
76
77    /// Check if the device is available on the system
78    pub fn is_available(&self) -> bool {
79        match self {
80            Self::Cpu(_) => true, // CPU is always available
81            #[cfg(feature = "cuda")]
82            Self::Cuda(_) => {
83                use ort::execution_providers::CUDAExecutionProvider;
84                CUDAExecutionProvider::default()
85                    .with_device_id(self.id().unwrap_or(0) as i32)
86                    .is_available()
87                    .unwrap_or(false)
88            }
89        }
90    }
91}
92
93/// Convenience functions for device creation
94pub fn cpu() -> Device {
95    Device::Cpu(1)
96}
97
98pub fn cpu_with_threads(threads: usize) -> Device {
99    Device::Cpu(threads)
100}
101
/// Convenience constructor for the CUDA device with index `device_id`.
///
/// Only available with the `cuda` feature. Whether the device actually exists
/// is not checked here; see `Device::is_available`.
#[cfg(feature = "cuda")]
pub fn cuda(device_id: usize) -> Device {
    let device = Device::Cuda(device_id);
    device
}
106
/// Convenience constructor for the first CUDA device, `Device::Cuda(0)`.
///
/// Only available with the `cuda` feature.
#[cfg(feature = "cuda")]
pub fn cuda_default() -> Device {
    let first_device = 0;
    Device::Cuda(first_device)
}