Skip to main content

axonml_core/
device.rs

//! Device Abstraction - Hardware Backend Management
//!
//! Provides a unified interface for managing compute devices, including the CPU
//! and the CUDA, Vulkan, Metal, and WebGPU GPU backends. Each GPU backend is only
//! compiled in when its Cargo feature (`cuda`, `vulkan`, `metal`, `wgpu`) is
//! enabled. Tensors can be moved between devices transparently.
//!
//! # Key Features
//! - Unified device abstraction across backends
//! - Device availability checking
//! - Device capability queries
//! - Seamless tensor transfer between devices
//!
//! # Example
//! ```rust
//! use axonml_core::Device;
//!
//! let cpu = Device::Cpu;
//! assert!(cpu.is_available());
//! assert!(cpu.is_cpu());
//!
//! // Use default device (CPU)
//! let device = Device::default();
//! assert_eq!(device, Device::Cpu);
//! ```
//!
//! @version 0.1.0
//! @author `AutomataNexus` Development Team

29use core::fmt;
30use sysinfo::System;
31
32// =============================================================================
33// Device Enum
34// =============================================================================
35
/// A compute target on which tensors can be allocated and operations executed.
///
/// `Cpu` is always compiled in and is the [`Default`] device. Every GPU variant
/// exists only when its Cargo feature is enabled and carries the backend's
/// device index.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, PartialEq)]
pub enum Device {
    /// The host CPU; always available.
    #[default]
    Cpu,

    /// An NVIDIA CUDA GPU, identified by its device index.
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// A Vulkan GPU (cross-platform), identified by its device index.
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// An Apple Metal GPU, identified by its device index.
    #[cfg(feature = "metal")]
    Metal(usize),

    /// A WebGPU device (for WASM/browser), identified by its device index.
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
59
60impl Device {
61    /// Returns true if this device is available on the current system.
62    #[must_use]
63    pub fn is_available(self) -> bool {
64        match self {
65            Self::Cpu => true,
66            #[cfg(feature = "cuda")]
67            Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
68            #[cfg(feature = "vulkan")]
69            Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
70            #[cfg(feature = "metal")]
71            Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
72            #[cfg(feature = "wgpu")]
73            Self::Wgpu(idx) => crate::backends::wgpu_backend::is_device_available(idx),
74        }
75    }
76
77    /// Returns true if this is a CPU device.
78    #[must_use]
79    pub const fn is_cpu(self) -> bool {
80        matches!(self, Self::Cpu)
81    }
82
83    /// Returns true if this is a GPU device.
84    #[must_use]
85    pub const fn is_gpu(self) -> bool {
86        !self.is_cpu()
87    }
88
89    /// Returns the device index for GPU devices, or 0 for CPU.
90    #[must_use]
91    pub const fn index(self) -> usize {
92        match self {
93            Self::Cpu => 0,
94            #[cfg(feature = "cuda")]
95            Self::Cuda(idx) => idx,
96            #[cfg(feature = "vulkan")]
97            Self::Vulkan(idx) => idx,
98            #[cfg(feature = "metal")]
99            Self::Metal(idx) => idx,
100            #[cfg(feature = "wgpu")]
101            Self::Wgpu(idx) => idx,
102        }
103    }
104
105    /// Returns the name of this device type.
106    #[must_use]
107    pub const fn device_type(self) -> &'static str {
108        match self {
109            Self::Cpu => "cpu",
110            #[cfg(feature = "cuda")]
111            Self::Cuda(_) => "cuda",
112            #[cfg(feature = "vulkan")]
113            Self::Vulkan(_) => "vulkan",
114            #[cfg(feature = "metal")]
115            Self::Metal(_) => "metal",
116            #[cfg(feature = "wgpu")]
117            Self::Wgpu(_) => "wgpu",
118        }
119    }
120
121    /// Returns the default CPU device.
122    #[must_use]
123    pub const fn cpu() -> Self {
124        Self::Cpu
125    }
126
127    /// Returns a CUDA device with the given index.
128    #[cfg(feature = "cuda")]
129    #[must_use]
130    pub const fn cuda(index: usize) -> Self {
131        Self::Cuda(index)
132    }
133}
134
135impl fmt::Display for Device {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        match self {
138            Self::Cpu => write!(f, "cpu"),
139            #[cfg(feature = "cuda")]
140            Self::Cuda(idx) => write!(f, "cuda:{idx}"),
141            #[cfg(feature = "vulkan")]
142            Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
143            #[cfg(feature = "metal")]
144            Self::Metal(idx) => write!(f, "metal:{idx}"),
145            #[cfg(feature = "wgpu")]
146            Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
147        }
148    }
149}
150
151// =============================================================================
152// Device Capabilities
153// =============================================================================
154
/// Snapshot of what a compute device offers, as reported by its backend.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory, in bytes.
    pub total_memory: usize,
    /// Currently available device memory, in bytes.
    pub available_memory: usize,
    /// `true` if half-precision (`f16`) arithmetic is supported.
    pub supports_f16: bool,
    /// `true` if double-precision (`f64`) arithmetic is supported.
    pub supports_f64: bool,
    /// Upper bound on threads per block (a GPU notion; CPU backends fill in
    /// their own parallelism figure).
    pub max_threads_per_block: usize,
    /// CUDA compute capability as `(major, minor)`; `None` for non-CUDA devices.
    pub compute_capability: Option<(usize, usize)>,
}
173
174impl Device {
175    /// Returns the capabilities of this device.
176    #[must_use]
177    pub fn capabilities(self) -> DeviceCapabilities {
178        match self {
179            Self::Cpu => DeviceCapabilities {
180                name: "CPU".to_string(),
181                total_memory: get_system_memory(),
182                available_memory: get_available_memory(),
183                supports_f16: true,
184                supports_f64: true,
185                max_threads_per_block: num_cpus(),
186                compute_capability: None,
187            },
188            #[cfg(feature = "cuda")]
189            Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
190            #[cfg(feature = "vulkan")]
191            Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
192            #[cfg(feature = "metal")]
193            Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
194            #[cfg(feature = "wgpu")]
195            Self::Wgpu(idx) => crate::backends::wgpu_backend::get_capabilities(idx),
196        }
197    }
198}
199
200// =============================================================================
201// Helper Functions
202// =============================================================================
203
204/// Returns the total system memory in bytes.
205fn get_system_memory() -> usize {
206    let sys = System::new_all();
207    sys.total_memory() as usize
208}
209
210/// Returns the available system memory in bytes.
211fn get_available_memory() -> usize {
212    let sys = System::new_all();
213    sys.available_memory() as usize
214}
215
/// Returns how many threads this process can usefully run in parallel.
///
/// Delegates to [`std::thread::available_parallelism`], which may be lower than
/// the physical core count (e.g. under affinity masks or cgroup quotas). If the
/// platform cannot report a value, a single thread is assumed.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(parallelism) => parallelism.get(),
        Err(_) => 1,
    }
}
222
223impl DeviceCapabilities {
224    /// Returns true if the device supports f32.
225    #[must_use]
226    pub const fn supports_f32(&self) -> bool {
227        true // All devices support f32
228    }
229}
230
231// =============================================================================
232// Device Count Functions
233// =============================================================================
234
/// Number of CUDA devices, as reported by the CUDA backend's `device_count`.
#[must_use]
#[cfg(feature = "cuda")]
pub fn cuda_device_count() -> usize {
    crate::backends::cuda::device_count()
}

/// Number of Vulkan devices, as reported by the Vulkan backend's `device_count`.
#[must_use]
#[cfg(feature = "vulkan")]
pub fn vulkan_device_count() -> usize {
    crate::backends::vulkan::device_count()
}
248
249// =============================================================================
250// Tests
251// =============================================================================
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_device() {
        let device = Device::Cpu;
        assert!(device.is_cpu());
        assert!(!device.is_gpu());
        assert!(device.is_available());
        assert_eq!(device.device_type(), "cpu");
    }

    #[test]
    fn test_device_display() {
        let cpu = Device::Cpu;
        assert_eq!(format!("{cpu}"), "cpu");
        assert_eq!(cpu.to_string(), "cpu");
    }

    #[test]
    fn test_device_default() {
        let device = Device::default();
        assert_eq!(device, Device::Cpu);
    }

    /// The `cpu()` constructor and `index()` accessor were previously untested.
    #[test]
    fn test_cpu_constructor_and_index() {
        assert_eq!(Device::cpu(), Device::Cpu);
        assert_eq!(Device::Cpu.index(), 0);
    }

    #[test]
    fn test_device_capabilities() {
        let caps = Device::Cpu.capabilities();
        assert_eq!(caps.name, "CPU");
        assert!(caps.supports_f32());
        assert!(caps.supports_f16);
        assert!(caps.supports_f64);
        // For the CPU this field is filled from the available parallelism (>= 1).
        assert!(caps.max_threads_per_block >= 1);
        assert!(caps.compute_capability.is_none());
    }

    #[test]
    fn test_num_cpus_is_positive() {
        assert!(num_cpus() >= 1);
    }
}