axonml_core/
device.rs

1//! Device Abstraction - Hardware Backend Management
2//!
3//! Provides a unified interface for managing compute devices including CPU,
4//! CUDA GPUs, Vulkan, Metal, and WebGPU backends. Tensors can be moved between
5//! devices transparently.
6//!
7//! # Key Features
8//! - Unified device abstraction across backends
9//! - Device availability checking
10//! - Device capability queries
11//! - Seamless tensor transfer between devices
12//!
13//! # Example
14//! ```rust
15//! use axonml_core::Device;
16//!
17//! let cpu = Device::Cpu;
18//! assert!(cpu.is_available());
19//! assert!(cpu.is_cpu());
20//!
21//! // Use default device (CPU)
22//! let device = Device::default();
23//! assert_eq!(device, Device::Cpu);
24//! ```
25//!
26//! @version 0.1.0
27//! @author `AutomataNexus` Development Team
28
29use core::fmt;
30use sysinfo::System;
31
32// =============================================================================
33// Device Enum
34// =============================================================================
35
/// A compute device on which tensors can be allocated and operations executed.
///
/// [`Device::Cpu`] is always compiled in. Every other variant only exists when
/// the Cargo feature of the same name is enabled, and carries the index of the
/// device within that backend.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Device {
    /// The host CPU (always available).
    Cpu,

    /// An NVIDIA CUDA GPU, selected by device index.
    #[cfg(feature = "cuda")]
    Cuda(usize),

    /// A Vulkan GPU (cross-platform), selected by device index.
    #[cfg(feature = "vulkan")]
    Vulkan(usize),

    /// An Apple Metal GPU, selected by device index.
    #[cfg(feature = "metal")]
    Metal(usize),

    /// A WebGPU device (for WASM/browser), selected by device index.
    #[cfg(feature = "wgpu")]
    Wgpu(usize),
}
58
59impl Device {
60    /// Returns true if this device is available on the current system.
61    #[must_use]
62    pub fn is_available(self) -> bool {
63        match self {
64            Self::Cpu => true,
65            #[cfg(feature = "cuda")]
66            Self::Cuda(idx) => crate::backends::cuda::is_device_available(idx),
67            #[cfg(feature = "vulkan")]
68            Self::Vulkan(idx) => crate::backends::vulkan::is_device_available(idx),
69            #[cfg(feature = "metal")]
70            Self::Metal(idx) => crate::backends::metal::is_device_available(idx),
71            #[cfg(feature = "wgpu")]
72            Self::Wgpu(idx) => crate::backends::wgpu::is_device_available(idx),
73        }
74    }
75
76    /// Returns true if this is a CPU device.
77    #[must_use]
78    pub const fn is_cpu(self) -> bool {
79        matches!(self, Self::Cpu)
80    }
81
82    /// Returns true if this is a GPU device.
83    #[must_use]
84    pub const fn is_gpu(self) -> bool {
85        !self.is_cpu()
86    }
87
88    /// Returns the device index for GPU devices, or 0 for CPU.
89    #[must_use]
90    pub const fn index(self) -> usize {
91        match self {
92            Self::Cpu => 0,
93            #[cfg(feature = "cuda")]
94            Self::Cuda(idx) => idx,
95            #[cfg(feature = "vulkan")]
96            Self::Vulkan(idx) => idx,
97            #[cfg(feature = "metal")]
98            Self::Metal(idx) => idx,
99            #[cfg(feature = "wgpu")]
100            Self::Wgpu(idx) => idx,
101        }
102    }
103
104    /// Returns the name of this device type.
105    #[must_use]
106    pub const fn device_type(self) -> &'static str {
107        match self {
108            Self::Cpu => "cpu",
109            #[cfg(feature = "cuda")]
110            Self::Cuda(_) => "cuda",
111            #[cfg(feature = "vulkan")]
112            Self::Vulkan(_) => "vulkan",
113            #[cfg(feature = "metal")]
114            Self::Metal(_) => "metal",
115            #[cfg(feature = "wgpu")]
116            Self::Wgpu(_) => "wgpu",
117        }
118    }
119
120    /// Returns the default CPU device.
121    #[must_use]
122    pub const fn cpu() -> Self {
123        Self::Cpu
124    }
125
126    /// Returns a CUDA device with the given index.
127    #[cfg(feature = "cuda")]
128    #[must_use]
129    pub const fn cuda(index: usize) -> Self {
130        Self::Cuda(index)
131    }
132}
133
134impl Default for Device {
135    fn default() -> Self {
136        Self::Cpu
137    }
138}
139
140impl fmt::Display for Device {
141    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142        match self {
143            Self::Cpu => write!(f, "cpu"),
144            #[cfg(feature = "cuda")]
145            Self::Cuda(idx) => write!(f, "cuda:{idx}"),
146            #[cfg(feature = "vulkan")]
147            Self::Vulkan(idx) => write!(f, "vulkan:{idx}"),
148            #[cfg(feature = "metal")]
149            Self::Metal(idx) => write!(f, "metal:{idx}"),
150            #[cfg(feature = "wgpu")]
151            Self::Wgpu(idx) => write!(f, "wgpu:{idx}"),
152        }
153    }
154}
155
156// =============================================================================
157// Device Capabilities
158// =============================================================================
159
/// A snapshot describing what a device offers: its name, memory figures and
/// supported floating-point widths.
#[derive(Clone, Debug)]
pub struct DeviceCapabilities {
    /// Human-readable device name.
    pub name: String,
    /// Total device memory, in bytes.
    pub total_memory: usize,
    /// Memory currently available on the device, in bytes.
    pub available_memory: usize,
    /// `true` when the device supports `f16`.
    pub supports_f16: bool,
    /// `true` when the device supports `f64`.
    pub supports_f64: bool,
    /// Maximum number of threads per block (a GPU notion).
    pub max_threads_per_block: usize,
    /// Compute capability version as a pair of numbers (CUDA only);
    /// `None` where it does not apply.
    pub compute_capability: Option<(usize, usize)>,
}
178
179impl Device {
180    /// Returns the capabilities of this device.
181    #[must_use]
182    pub fn capabilities(self) -> DeviceCapabilities {
183        match self {
184            Self::Cpu => DeviceCapabilities {
185                name: "CPU".to_string(),
186                total_memory: get_system_memory(),
187                available_memory: get_available_memory(),
188                supports_f16: true,
189                supports_f64: true,
190                max_threads_per_block: num_cpus(),
191                compute_capability: None,
192            },
193            #[cfg(feature = "cuda")]
194            Self::Cuda(idx) => crate::backends::cuda::get_capabilities(idx),
195            #[cfg(feature = "vulkan")]
196            Self::Vulkan(idx) => crate::backends::vulkan::get_capabilities(idx),
197            #[cfg(feature = "metal")]
198            Self::Metal(idx) => crate::backends::metal::get_capabilities(idx),
199            #[cfg(feature = "wgpu")]
200            Self::Wgpu(idx) => crate::backends::wgpu::get_capabilities(idx),
201        }
202    }
203}
204
205// =============================================================================
206// Helper Functions
207// =============================================================================
208
209/// Returns the total system memory in bytes.
210fn get_system_memory() -> usize {
211    let sys = System::new_all();
212    sys.total_memory() as usize
213}
214
215/// Returns the available system memory in bytes.
216fn get_available_memory() -> usize {
217    let sys = System::new_all();
218    sys.available_memory() as usize
219}
220
/// Returns the value reported by [`std::thread::available_parallelism`],
/// or `1` when that query fails.
fn num_cpus() -> usize {
    match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    }
}
227
228impl DeviceCapabilities {
229    /// Returns true if the device supports f32.
230    #[must_use]
231    pub const fn supports_f32(&self) -> bool {
232        true // All devices support f32
233    }
234}
235
236// =============================================================================
237// Device Count Functions
238// =============================================================================
239
/// Returns the number of CUDA devices, as reported by the CUDA backend's
/// `device_count`.
///
/// Only compiled when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_device_count() -> usize {
    use crate::backends::cuda;

    cuda::device_count()
}
246
/// Returns the number of Vulkan devices, as reported by the Vulkan backend's
/// `device_count`.
///
/// Only compiled when the `vulkan` feature is enabled.
#[cfg(feature = "vulkan")]
#[must_use]
pub fn vulkan_device_count() -> usize {
    use crate::backends::vulkan;

    vulkan::device_count()
}
253
254// =============================================================================
255// Tests
256// =============================================================================
257
#[cfg(test)]
mod tests {
    use super::*;

    // The CPU device must identify as a CPU, never as a GPU, and always be usable.
    #[test]
    fn test_cpu_device() {
        let cpu = Device::Cpu;
        assert!(cpu.is_cpu());
        assert!(!cpu.is_gpu());
        assert!(cpu.is_available());
        assert_eq!("cpu", cpu.device_type());
    }

    // `Display` renders the CPU as the bare backend name.
    #[test]
    fn test_device_display() {
        assert_eq!(Device::Cpu.to_string(), "cpu");
    }

    // `Default` selects the CPU.
    #[test]
    fn test_device_default() {
        assert_eq!(Device::Cpu, Device::default());
    }

    // CPU capabilities carry the fixed name and report f32 support.
    #[test]
    fn test_device_capabilities() {
        let capabilities = Device::Cpu.capabilities();
        assert_eq!(capabilities.name, "CPU");
        assert!(capabilities.supports_f32());
    }
}
289}