Skip to main content

axonml_core/
device.rs

1//! Device Abstraction - Hardware Backend Management
2//!
3//! # File
4//! `crates/axonml-core/src/device.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use core::fmt;
18use sysinfo::System;
19
20// =============================================================================
21// Device Enum
22// =============================================================================
23
/// A compute target on which tensors are allocated and operations are executed.
///
/// `Cpu` is always compiled in and is the [`Default`] variant. Each GPU variant
/// exists only when its Cargo feature is enabled, and carries a device index.
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Hash)]
pub enum Device {
    /// Host CPU; present in every build and always available.
    #[default]
    Cpu,

    /// NVIDIA GPU driven through CUDA, selected by device index.
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// Cross-platform GPU driven through Vulkan, selected by device index.
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// Apple GPU driven through Metal, selected by device index.
    #[cfg(feature = "metal")]
    Metal(usize),

    /// WebGPU adapter (WASM/browser targets), selected by device index.
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
47
48impl Device {
49    /// Returns true if this device is available on the current system.
50    #[must_use]
51    pub fn is_available(self) -> bool {
52        match self {
53            Self::Cpu => true,
54            #[cfg(feature = "cuda")]
55            Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
56            #[cfg(feature = "vulkan")]
57            Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
58            #[cfg(feature = "metal")]
59            Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
60            #[cfg(feature = "wgpu")]
61            Self::Wgpu(idx) => crate::backends::wgpu_backend::is_device_available(idx),
62        }
63    }
64
65    /// Returns true if this is a CPU device.
66    #[must_use]
67    pub const fn is_cpu(self) -> bool {
68        matches!(self, Self::Cpu)
69    }
70
71    /// Returns true if this is a GPU device.
72    #[must_use]
73    pub const fn is_gpu(self) -> bool {
74        !self.is_cpu()
75    }
76
77    /// Returns the device index for GPU devices, or 0 for CPU.
78    #[must_use]
79    pub const fn index(self) -> usize {
80        match self {
81            Self::Cpu => 0,
82            #[cfg(feature = "cuda")]
83            Self::Cuda(idx) => idx,
84            #[cfg(feature = "vulkan")]
85            Self::Vulkan(idx) => idx,
86            #[cfg(feature = "metal")]
87            Self::Metal(idx) => idx,
88            #[cfg(feature = "wgpu")]
89            Self::Wgpu(idx) => idx,
90        }
91    }
92
93    /// Returns the name of this device type.
94    #[must_use]
95    pub const fn device_type(self) -> &'static str {
96        match self {
97            Self::Cpu => "cpu",
98            #[cfg(feature = "cuda")]
99            Self::Cuda(_) => "cuda",
100            #[cfg(feature = "vulkan")]
101            Self::Vulkan(_) => "vulkan",
102            #[cfg(feature = "metal")]
103            Self::Metal(_) => "metal",
104            #[cfg(feature = "wgpu")]
105            Self::Wgpu(_) => "wgpu",
106        }
107    }
108
109    /// Returns the default CPU device.
110    #[must_use]
111    pub const fn cpu() -> Self {
112        Self::Cpu
113    }
114
115    /// Returns a CUDA device with the given index.
116    #[cfg(feature = "cuda")]
117    #[must_use]
118    pub const fn cuda(index: usize) -> Self {
119        Self::Cuda(index)
120    }
121}
122
123impl fmt::Display for Device {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        match self {
126            Self::Cpu => write!(f, "cpu"),
127            #[cfg(feature = "cuda")]
128            Self::Cuda(idx) => write!(f, "cuda:{idx}"),
129            #[cfg(feature = "vulkan")]
130            Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
131            #[cfg(feature = "metal")]
132            Self::Metal(idx) => write!(f, "metal:{idx}"),
133            #[cfg(feature = "wgpu")]
134            Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
135        }
136    }
137}
138
139// =============================================================================
140// Device Capabilities
141// =============================================================================
142
/// Description of what a device offers, as returned by `Device::capabilities`.
///
/// Memory figures reflect the moment the capabilities were queried.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name.
    pub name: String,
    /// Total memory, in bytes.
    pub total_memory: usize,
    /// Available memory, in bytes, at query time.
    pub available_memory: usize,
    /// `true` when half-precision (`f16`) is supported.
    pub supports_f16: bool,
    /// `true` when double-precision (`f64`) is supported.
    pub supports_f64: bool,
    /// Maximum threads per block (a GPU notion).
    pub max_threads_per_block: usize,
    /// CUDA compute capability as `(major, minor)`; `None` for non-CUDA devices.
    pub compute_capability: Option<(usize, usize)>,
}
161
162impl Device {
163    /// Returns the capabilities of this device.
164    #[must_use]
165    pub fn capabilities(self) -> DeviceCapabilities {
166        match self {
167            Self::Cpu => DeviceCapabilities {
168                name: "CPU".to_string(),
169                total_memory: get_system_memory(),
170                available_memory: get_available_memory(),
171                supports_f16: true,
172                supports_f64: true,
173                max_threads_per_block: num_cpus(),
174                compute_capability: None,
175            },
176            #[cfg(feature = "cuda")]
177            Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
178            #[cfg(feature = "vulkan")]
179            Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
180            #[cfg(feature = "metal")]
181            Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
182            #[cfg(feature = "wgpu")]
183            Self::Wgpu(idx) => crate::backends::wgpu_backend::get_capabilities(idx),
184        }
185    }
186}
187
188// =============================================================================
189// Helper Functions
190// =============================================================================
191
192/// Returns the total system memory in bytes.
193fn get_system_memory() -> usize {
194    let sys = System::new_all();
195    sys.total_memory() as usize
196}
197
198/// Returns the available system memory in bytes.
199fn get_available_memory() -> usize {
200    let sys = System::new_all();
201    sys.available_memory() as usize
202}
203
/// Reports how many threads the host can run in parallel.
///
/// Falls back to a single thread when the platform cannot report this.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
210
211impl DeviceCapabilities {
212    /// Returns true if the device supports f32.
213    #[must_use]
214    pub const fn supports_f32(&self) -> bool {
215        true // All devices support f32
216    }
217}
218
219// =============================================================================
220// Device Count Functions
221// =============================================================================
222
/// Number of CUDA devices, as reported by the CUDA backend's `device_count`.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    use crate::backends::cuda as backend;
    backend::device_count()
}
229
/// Number of Vulkan devices, as reported by the Vulkan backend's `device_count`.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    use crate::backends::vulkan as backend;
    backend::device_count()
}
236
237// =============================================================================
238// Tests
239// =============================================================================
240
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let device = Device::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_gpu());
        assert!(device.is_available());
        assert_eq!(device.device_type(), "cpu");
    }

    #[test]
    fn test_cpu_constructor_and_index() {
        // The convenience constructor must agree with the variant, and the
        // CPU reports index 0 by contract.
        assert_eq!(Device::cpu(), Device::Cpu);
        assert_eq!(Device::Cpu.index(), 0);
    }

    #[test]
    fn test_device_display() {
        let cpu = Device::Cpu;
        assert_eq!(format!("{cpu}"), "cpu");
        // For the CPU, the Display form and device_type() must not drift apart.
        assert_eq!(cpu.to_string(), cpu.device_type());
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device, Device::Cpu);
    }

    #[test]
    fn test_device_capabilities() {
        let caps = Device::Cpu.capabilities();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_f32());
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
        // num_cpus() falls back to 1, so this can never be zero.
        assert!(caps.max_threads_per_block >= 1);
        assert_eq!(caps.compute_capability, None);
    }
}