Skip to main content

axonml_core/
device.rs

1//! Device abstraction and hardware backend management.
2//!
3//! Defines the `Device` enum (Cpu, Cuda, Vulkan, Metal, Wgpu) with runtime
4//! availability checks via each backend's `is_device_available()`, capability
5//! queries (`DeviceCapabilities` — memory, f16/f64 support, compute
6//! capability), device counting, and the `best_available_backend()` selector
7//! that prefers CUDA > Metal > Vulkan > WebGPU > CPU.
8//!
9//! # File
10//! `crates/axonml-core/src/device.rs`
11//!
12//! # Author
13//! Andrew Jewell Sr. — AutomataNexus LLC
14//! ORCID: 0009-0005-2158-7060
15//!
16//! # Updated
17//! April 14, 2026 11:15 PM EST
18//!
19//! # Disclaimer
20//! Use at own risk. This software is provided "as is", without warranty of any
21//! kind, express or implied. The author and AutomataNexus shall not be held
22//! liable for any damages arising from the use of this software.
23
24use core::fmt;
25use sysinfo::System;
26
27// =============================================================================
28// Device Enum
29// =============================================================================
30
/// A compute target on which tensors live and operations execute.
///
/// `Cpu` is the only variant that always exists. Every accelerator variant is
/// compiled in only when its Cargo feature (`cuda`, `vulkan`, `metal`, `wgpu`)
/// is enabled, and carries the backend-specific device ordinal.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Device {
    /// Host CPU. This is the `Default` device and is always available.
    #[default]
    Cpu,

    /// NVIDIA GPU driven through the CUDA backend, identified by ordinal.
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// GPU driven through the cross-platform Vulkan backend, identified by ordinal.
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// Apple GPU driven through the Metal backend, identified by ordinal.
    #[cfg(feature = "metal")]
    Metal(usize),

    /// Device driven through the WebGPU backend (intended for WASM/browser
    /// targets), identified by ordinal.
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
54
55impl Device {
56    /// Returns true if this device is available on the current system.
57    #[must_use]
58    pub fn is_available(self) -> bool {
59        match self {
60            Self::Cpu => true,
61            #[cfg(feature = "cuda")]
62            Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
63            #[cfg(feature = "vulkan")]
64            Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
65            #[cfg(feature = "metal")]
66            Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
67            #[cfg(feature = "wgpu")]
68            Self::Wgpu(idx) => crate::backends::wgpu_backend::is_device_available(idx),
69        }
70    }
71
72    /// Returns true if this is a CPU device.
73    #[must_use]
74    pub const fn is_cpu(self) -> bool {
75        matches!(self, Self::Cpu)
76    }
77
78    /// Returns true if this is a GPU device.
79    #[must_use]
80    pub const fn is_gpu(self) -> bool {
81        !self.is_cpu()
82    }
83
84    /// Returns the device index for GPU devices, or 0 for CPU.
85    #[must_use]
86    pub const fn index(self) -> usize {
87        match self {
88            Self::Cpu => 0,
89            #[cfg(feature = "cuda")]
90            Self::Cuda(idx) => idx,
91            #[cfg(feature = "vulkan")]
92            Self::Vulkan(idx) => idx,
93            #[cfg(feature = "metal")]
94            Self::Metal(idx) => idx,
95            #[cfg(feature = "wgpu")]
96            Self::Wgpu(idx) => idx,
97        }
98    }
99
100    /// Returns the name of this device type.
101    #[must_use]
102    pub const fn device_type(self) -> &'static str {
103        match self {
104            Self::Cpu => "cpu",
105            #[cfg(feature = "cuda")]
106            Self::Cuda(_) => "cuda",
107            #[cfg(feature = "vulkan")]
108            Self::Vulkan(_) => "vulkan",
109            #[cfg(feature = "metal")]
110            Self::Metal(_) => "metal",
111            #[cfg(feature = "wgpu")]
112            Self::Wgpu(_) => "wgpu",
113        }
114    }
115
116    /// Returns the default CPU device.
117    #[must_use]
118    pub const fn cpu() -> Self {
119        Self::Cpu
120    }
121
122    /// Returns a CUDA device with the given index.
123    #[cfg(feature = "cuda")]
124    #[must_use]
125    pub const fn cuda(index: usize) -> Self {
126        Self::Cuda(index)
127    }
128}
129
130impl fmt::Display for Device {
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        match self {
133            Self::Cpu => write!(f, "cpu"),
134            #[cfg(feature = "cuda")]
135            Self::Cuda(idx) => write!(f, "cuda:{idx}"),
136            #[cfg(feature = "vulkan")]
137            Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
138            #[cfg(feature = "metal")]
139            Self::Metal(idx) => write!(f, "metal:{idx}"),
140            #[cfg(feature = "wgpu")]
141            Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
142        }
143    }
144}
145
146// =============================================================================
147// Device Capabilities
148// =============================================================================
149
/// Snapshot of what a device offers, as returned by `Device::capabilities`.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name (e.g. `"CPU"`).
    pub name: String,
    /// Total memory of the device, in bytes.
    pub total_memory: usize,
    /// Memory currently available on the device, in bytes.
    pub available_memory: usize,
    /// `true` when half-precision (`f16`) values are supported.
    pub supports_f16: bool,
    /// `true` when double-precision (`f64`) values are supported.
    pub supports_f64: bool,
    /// Upper bound on threads per block for GPU backends. For the CPU this
    /// field holds the host's available parallelism instead.
    pub max_threads_per_block: usize,
    /// `(major, minor)` compute capability. Only meaningful for CUDA devices;
    /// `None` for the CPU.
    pub compute_capability: Option<(usize, usize)>,
}
168
169impl Device {
170    /// Returns the capabilities of this device.
171    #[must_use]
172    pub fn capabilities(self) -> DeviceCapabilities {
173        match self {
174            Self::Cpu => DeviceCapabilities {
175                name: "CPU".to_string(),
176                total_memory: get_system_memory(),
177                available_memory: get_available_memory(),
178                supports_f16: true,
179                supports_f64: true,
180                max_threads_per_block: num_cpus(),
181                compute_capability: None,
182            },
183            #[cfg(feature = "cuda")]
184            Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
185            #[cfg(feature = "vulkan")]
186            Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
187            #[cfg(feature = "metal")]
188            Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
189            #[cfg(feature = "wgpu")]
190            Self::Wgpu(idx) => crate::backends::wgpu_backend::get_capabilities(idx),
191        }
192    }
193}
194
195// =============================================================================
196// Helper Functions
197// =============================================================================
198
199/// Returns the total system memory in bytes.
200fn get_system_memory() -> usize {
201    let sys = System::new_all();
202    sys.total_memory() as usize
203}
204
205/// Returns the available system memory in bytes.
206fn get_available_memory() -> usize {
207    let sys = System::new_all();
208    sys.available_memory() as usize
209}
210
/// Returns how many threads the OS says can run in parallel, never less than 1.
///
/// This is `std::thread::available_parallelism`, which may be lower than the
/// physical core count (e.g. under CPU affinity limits). If the query fails,
/// it falls back to a single thread.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
215
216impl DeviceCapabilities {
217    /// Returns true if the device supports f32.
218    #[must_use]
219    pub const fn supports_f32(&self) -> bool {
220        true // All devices support f32
221    }
222}
223
224// =============================================================================
225// Device Count Functions
226// =============================================================================
227
/// Number of CUDA devices, as reported by the CUDA backend's `device_count`.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    use crate::backends::cuda;
    cuda::device_count()
}
234
/// Number of Vulkan devices, as reported by the Vulkan backend's `device_count`.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    use crate::backends::vulkan;
    vulkan::device_count()
}
241
242// =============================================================================
243// Tests
244// =============================================================================
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let device = Device::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_gpu());
        assert!(device.is_available());
        assert_eq!(device.device_type(), "cpu");
    }

    /// `Device::cpu()` must agree with the variant, and the CPU has ordinal 0.
    #[test]
    fn test_cpu_constructor_and_index() {
        let device = Device::cpu();
        assert_eq!(device, Device::Cpu);
        assert_eq!(device.index(), 0);
    }

    #[test]
    fn test_device_display() {
        let cpu = Device::Cpu;
        assert_eq!(format!("{cpu}"), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device, Device::Cpu);
    }

    /// Pins every statically known field of the CPU capability report.
    #[test]
    fn test_device_capabilities() {
        let caps = Device::Cpu.capabilities();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_f32());
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
        // The CPU reports its available parallelism here, which is at least 1.
        assert!(caps.max_threads_per_block >= 1);
        assert_eq!(caps.compute_capability, None);
    }

    /// Constructing and formatting a CUDA device needs no GPU, so this runs
    /// wherever the `cuda` feature is enabled.
    #[cfg(feature = "cuda")]
    #[test]
    fn test_cuda_device_identity() {
        let device = Device::cuda(1);
        assert!(device.is_gpu());
        assert!(!device.is_cpu());
        assert_eq!(device.index(), 1);
        assert_eq!(device.device_type(), "cuda");
        assert_eq!(device.to_string(), "cuda:1");
    }
}