Skip to main content

axonml_core/backends/
mod.rs

1//! Backends - Device-Specific Implementations
2//!
3//! # File
4//! `crates/axonml-core/src/backends/mod.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use crate::device::DeviceCapabilities;
18
19// =============================================================================
20// Backend Modules
21// =============================================================================
22
23pub mod cpu;
24
25pub mod cuda;
26
27#[cfg(feature = "cuda")]
28pub mod cuda_kernels;
29
30pub mod cuda_pool;
31
32#[cfg(feature = "cudnn")]
33pub mod cudnn_ops;
34
35#[cfg(feature = "vulkan")]
36pub mod vulkan;
37
38#[cfg(feature = "metal")]
39pub mod metal;
40
41#[cfg(feature = "wgpu")]
42pub mod wgpu_backend;
43
44// GPU testing infrastructure
45pub mod gpu_tests;
46
47// =============================================================================
48// Re-exports
49// =============================================================================
50
51pub use cpu::CpuBackend;
52
53pub use cuda::CudaBackend;
54
55#[cfg(feature = "vulkan")]
56pub use vulkan::VulkanBackend;
57
58#[cfg(feature = "metal")]
59pub use metal::MetalBackend;
60
61#[cfg(feature = "wgpu")]
62pub use wgpu_backend::WgpuBackend;
63
64// =============================================================================
65// Backend Trait
66// =============================================================================
67
68/// Common trait for all compute backends.
69///
70/// This trait defines the interface that all backends must implement,
71/// enabling device-agnostic tensor operations.
72pub trait Backend: Send + Sync {
73    /// Returns the name of this backend.
74    fn name(&self) -> &'static str;
75
76    /// Returns whether this backend is available on the current system.
77    fn is_available(&self) -> bool;
78
79    /// Returns the device capabilities.
80    fn capabilities(&self) -> DeviceCapabilities;
81
82    /// Allocates memory on this backend.
83    fn allocate(&self, size: usize) -> *mut u8;
84
85    /// Deallocates memory on this backend.
86    fn deallocate(&self, ptr: *mut u8, size: usize);
87
88    /// Copies data from host to device.
89    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize);
90
91    /// Copies data from device to host.
92    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize);
93
94    /// Copies data within the device.
95    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize);
96
97    /// Synchronizes the device (waits for all operations to complete).
98    fn synchronize(&self);
99}
100
101// =============================================================================
102// GPU Memory Management
103// =============================================================================
104
/// Descriptor for a block of device memory.
///
/// Records the raw pointer, its size in bytes, the index of the device it
/// lives on, and which backend produced it. Construction stores the values as
/// given — no allocation or validation happens here.
///
/// NOTE(review): no `Drop` impl is visible alongside this type, so it appears
/// to *describe* an allocation rather than own it (dropping a `GpuMemory` does
/// not free `ptr`). Confirm whether release is handled elsewhere, e.g. through
/// [`Backend::deallocate`].
#[derive(Debug)]
pub struct GpuMemory {
    ptr: *mut u8,
    size: usize,
    device_index: usize,
    backend_type: BackendType,
}

/// Which compute backend owns a [`GpuMemory`] allocation or a [`GpuStream`].
///
/// `Cpu` always exists; every other variant is compiled in only when the
/// matching Cargo feature (`cuda`, `vulkan`, `metal`, `wgpu`) is enabled, so an
/// exhaustive `match` must mirror those `#[cfg]` gates.
///
/// `Hash` is derived so the type can key a `HashMap`/`HashSet` (e.g. per-backend
/// bookkeeping).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendType {
    /// CPU backend.
    Cpu,
    /// CUDA backend.
    #[cfg(feature = "cuda")]
    Cuda,
    /// Vulkan backend.
    #[cfg(feature = "vulkan")]
    Vulkan,
    /// Metal backend.
    #[cfg(feature = "metal")]
    Metal,
    /// WebGPU backend.
    #[cfg(feature = "wgpu")]
    Wgpu,
}

impl GpuMemory {
    /// Wraps an existing allocation in a descriptor.
    ///
    /// # Parameters
    /// - `ptr`: pointer to the allocation; stored as-is (it is not checked and
    ///   may be null).
    /// - `size`: size of the allocation in bytes.
    /// - `device_index`: index of the device holding the memory.
    /// - `backend_type`: backend that produced the allocation.
    #[must_use]
    pub fn new(ptr: *mut u8, size: usize, device_index: usize, backend_type: BackendType) -> Self {
        Self {
            ptr,
            size,
            device_index,
            backend_type,
        }
    }

    /// Returns the raw pointer exactly as passed to [`GpuMemory::new`].
    #[must_use]
    pub fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// Returns the size in bytes.
    #[must_use]
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns the index of the device holding this memory.
    #[must_use]
    pub fn device_index(&self) -> usize {
        self.device_index
    }

    /// Returns the backend that produced this allocation.
    #[must_use]
    pub fn backend_type(&self) -> BackendType {
        self.backend_type
    }
}
168
169// =============================================================================
170// GPU Stream/Queue Abstraction
171// =============================================================================
172
173/// GPU execution stream for async operations.
174#[derive(Debug)]
175pub struct GpuStream {
176    /// Stream handle (backend-specific).
177    handle: usize,
178    /// Device index.
179    device_index: usize,
180    /// Backend type.
181    backend_type: BackendType,
182}
183
184impl GpuStream {
185    /// Creates a new GPU stream.
186    #[must_use]
187    pub fn new(handle: usize, device_index: usize, backend_type: BackendType) -> Self {
188        Self {
189            handle,
190            device_index,
191            backend_type,
192        }
193    }
194
195    /// Returns the stream handle.
196    #[must_use]
197    pub fn handle(&self) -> usize {
198        self.handle
199    }
200
201    /// Returns the device index.
202    #[must_use]
203    pub fn device_index(&self) -> usize {
204        self.device_index
205    }
206
207    /// Synchronizes this stream (waits for all operations to complete).
208    ///
209    /// # Backend-specific behavior
210    /// - **CPU**: No-op (CPU operations are synchronous)
211    /// - **CUDA**: No-op at stream level; use `CudaBackend::synchronize()` for device sync
212    /// - **Vulkan**: Waits for queue to become idle
213    /// - **Metal**: Waits for command buffer completion
214    /// - **WebGPU**: Submits pending commands to queue
215    ///
216    /// For CUDA, proper synchronization should be done through `CudaBackend::synchronize()`
217    /// which performs device-level synchronization.
218    pub fn synchronize(&self) {
219        match self.backend_type {
220            BackendType::Cpu => {} // No-op for CPU (synchronous)
221            #[cfg(feature = "cuda")]
222            BackendType::Cuda => cuda::stream_synchronize(self.handle),
223            #[cfg(feature = "vulkan")]
224            BackendType::Vulkan => vulkan::queue_wait_idle(self.handle),
225            #[cfg(feature = "metal")]
226            BackendType::Metal => metal::command_buffer_wait(self.handle),
227            #[cfg(feature = "wgpu")]
228            BackendType::Wgpu => wgpu_backend::queue_submit(self.handle),
229        }
230    }
231}
232
233// =============================================================================
234// Device Selection Utilities
235// =============================================================================
236
237/// Returns the best available GPU backend.
238#[must_use]
239pub fn best_available_backend() -> BackendType {
240    #[cfg(feature = "cuda")]
241    if cuda::is_available() {
242        return BackendType::Cuda;
243    }
244
245    #[cfg(feature = "metal")]
246    if metal::is_available() {
247        return BackendType::Metal;
248    }
249
250    #[cfg(feature = "vulkan")]
251    if vulkan::is_available() {
252        return BackendType::Vulkan;
253    }
254
255    #[cfg(feature = "wgpu")]
256    if wgpu_backend::is_available() {
257        return BackendType::Wgpu;
258    }
259
260    BackendType::Cpu
261}
262
/// Counts the devices reported by every compiled-in GPU backend.
///
/// The result is the plain sum of `device_count()` from each backend whose
/// Cargo feature is enabled (CUDA, Vulkan, Metal, WebGPU, queried in that
/// order), so it is `0` when no GPU feature is enabled.
///
/// NOTE(review): a physical GPU reachable through more than one API (e.g. via
/// both CUDA and Vulkan) is presumably counted once per backend — confirm
/// whether callers expect distinct physical devices instead.
#[must_use]
pub fn gpu_count() -> usize {
    // Each enabled backend shadows `total` with its contribution added, which
    // keeps the binding immutable whatever feature set is active.
    let total = 0_usize;

    #[cfg(feature = "cuda")]
    let total = total + cuda::device_count();

    #[cfg(feature = "vulkan")]
    let total = total + vulkan::device_count();

    #[cfg(feature = "metal")]
    let total = total + metal::device_count();

    #[cfg(feature = "wgpu")]
    let total = total + wgpu_backend::device_count();

    total
}