axonml_core/backends/mod.rs

1//! Backends - Device-Specific Implementations
2//!
3//! This module contains backend implementations for different compute devices.
4//! Each backend provides device-specific memory operations and kernel execution.
5//!
6//! # Available Backends
7//! - `cpu` - CPU backend (always available)
8//! - `cuda` - NVIDIA CUDA backend (requires `cuda` feature)
9//! - `vulkan` - Vulkan backend (requires `vulkan` feature)
10//! - `metal` - Apple Metal backend (requires `metal` feature)
11//! - `wgpu` - WebGPU backend (requires `wgpu` feature)
12//!
13//! # Backend Trait
14//!
15//! All backends implement the `Backend` trait which provides a common interface
16//! for tensor operations. This enables device-agnostic code.
17//!
18//! @version 0.1.0
19//! @author `AutomataNexus` Development Team
20
21use crate::device::DeviceCapabilities;
22
23// =============================================================================
24// Backend Modules
25// =============================================================================
26
27pub mod cpu;
28
29#[cfg(feature = "cuda")]
30pub mod cuda;
31
32#[cfg(feature = "cuda")]
33pub mod cuda_kernels;
34
35#[cfg(feature = "vulkan")]
36pub mod vulkan;
37
38#[cfg(feature = "metal")]
39pub mod metal;
40
41#[cfg(feature = "wgpu")]
42pub mod wgpu_backend;
43
44// GPU testing infrastructure
45pub mod gpu_tests;
46
47// =============================================================================
48// Re-exports
49// =============================================================================
50
51pub use cpu::CpuBackend;
52
53#[cfg(feature = "cuda")]
54pub use cuda::CudaBackend;
55
56#[cfg(feature = "vulkan")]
57pub use vulkan::VulkanBackend;
58
59#[cfg(feature = "metal")]
60pub use metal::MetalBackend;
61
62#[cfg(feature = "wgpu")]
63pub use wgpu_backend::WgpuBackend;
64
65// =============================================================================
66// Backend Trait
67// =============================================================================
68
/// Common trait for all compute backends.
///
/// This trait defines the interface that all backends must implement,
/// enabling device-agnostic tensor operations. Implementors are
/// `Send + Sync`, so a backend instance may be shared across threads.
///
/// NOTE(review): the memory and copy methods traffic in raw pointers with
/// no lifetime or validity tracking — callers are responsible for pairing
/// `allocate`/`deallocate` and for passing valid, correctly-sized buffers.
pub trait Backend: Send + Sync {
    /// Returns the name of this backend.
    fn name(&self) -> &'static str;

    /// Returns whether this backend is available on the current system.
    fn is_available(&self) -> bool;

    /// Returns the device capabilities.
    fn capabilities(&self) -> DeviceCapabilities;

    /// Allocates `size` bytes on this backend and returns a raw pointer.
    ///
    /// NOTE(review): the trait does not specify allocation-failure behavior
    /// (null return vs. panic) — confirm against each backend implementation.
    fn allocate(&self, size: usize) -> *mut u8;

    /// Deallocates memory previously returned by [`Backend::allocate`].
    /// `size` should match the size passed at allocation time.
    fn deallocate(&self, ptr: *mut u8, size: usize);

    /// Copies `size` bytes from host memory (`src`) to device memory (`dst`).
    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize);

    /// Copies `size` bytes from device memory (`src`) to host memory (`dst`).
    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize);

    /// Copies `size` bytes between two buffers on the same device.
    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize);

    /// Synchronizes the device (waits for all operations to complete).
    fn synchronize(&self);
}
101
102// =============================================================================
103// GPU Memory Management
104// =============================================================================
105
/// GPU memory handle for safe memory management.
///
/// Bundles a raw device pointer with its allocation size, owning device
/// index, and the backend that produced it.
///
/// NOTE(review): this type does not implement `Drop`, so it does not free
/// the allocation itself — confirm the owner deallocates via the backend.
#[derive(Debug)]
pub struct GpuMemory {
    // Raw pointer returned by the backend's allocator.
    ptr: *mut u8,
    // Allocation size in bytes.
    size: usize,
    // Index of the device the memory lives on.
    device_index: usize,
    // Which backend performed the allocation.
    backend_type: BackendType,
}
114
/// Type of backend for a GPU memory allocation.
///
/// GPU variants are feature-gated and exist only when the corresponding
/// Cargo feature is enabled; `Cpu` is always present, so exhaustive
/// matches compile under every feature combination.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendType {
    /// CPU backend.
    Cpu,
    /// CUDA backend.
    #[cfg(feature = "cuda")]
    Cuda,
    /// Vulkan backend.
    #[cfg(feature = "vulkan")]
    Vulkan,
    /// Metal backend.
    #[cfg(feature = "metal")]
    Metal,
    /// WebGPU backend.
    #[cfg(feature = "wgpu")]
    Wgpu,
}
133
134impl GpuMemory {
135    /// Creates a new GPU memory handle.
136    pub fn new(ptr: *mut u8, size: usize, device_index: usize, backend_type: BackendType) -> Self {
137        Self {
138            ptr,
139            size,
140            device_index,
141            backend_type,
142        }
143    }
144
145    /// Returns the raw pointer.
146    #[must_use]
147    pub fn ptr(&self) -> *mut u8 {
148        self.ptr
149    }
150
151    /// Returns the size in bytes.
152    #[must_use]
153    pub fn size(&self) -> usize {
154        self.size
155    }
156
157    /// Returns the device index.
158    #[must_use]
159    pub fn device_index(&self) -> usize {
160        self.device_index
161    }
162
163    /// Returns the backend type.
164    #[must_use]
165    pub fn backend_type(&self) -> BackendType {
166        self.backend_type
167    }
168}
169
170// =============================================================================
171// GPU Stream/Queue Abstraction
172// =============================================================================
173
/// GPU execution stream for async operations.
///
/// Wraps an opaque, backend-specific stream/queue handle together with the
/// device it belongs to, so [`GpuStream::synchronize`] can dispatch to the
/// right backend.
#[derive(Debug)]
pub struct GpuStream {
    /// Stream handle (backend-specific, stored as an opaque integer).
    handle: usize,
    /// Device index.
    device_index: usize,
    /// Backend type; selects the synchronization path.
    backend_type: BackendType,
}
184
185impl GpuStream {
186    /// Creates a new GPU stream.
187    #[must_use]
188    pub fn new(handle: usize, device_index: usize, backend_type: BackendType) -> Self {
189        Self {
190            handle,
191            device_index,
192            backend_type,
193        }
194    }
195
196    /// Returns the stream handle.
197    #[must_use]
198    pub fn handle(&self) -> usize {
199        self.handle
200    }
201
202    /// Returns the device index.
203    #[must_use]
204    pub fn device_index(&self) -> usize {
205        self.device_index
206    }
207
208    /// Synchronizes this stream (waits for all operations to complete).
209    ///
210    /// # Backend-specific behavior
211    /// - **CPU**: No-op (CPU operations are synchronous)
212    /// - **CUDA**: No-op at stream level; use `CudaBackend::synchronize()` for device sync
213    /// - **Vulkan**: Waits for queue to become idle
214    /// - **Metal**: Waits for command buffer completion
215    /// - **WebGPU**: Submits pending commands to queue
216    ///
217    /// For CUDA, proper synchronization should be done through `CudaBackend::synchronize()`
218    /// which performs device-level synchronization.
219    pub fn synchronize(&self) {
220        match self.backend_type {
221            BackendType::Cpu => {} // No-op for CPU (synchronous)
222            #[cfg(feature = "cuda")]
223            BackendType::Cuda => cuda::stream_synchronize(self.handle),
224            #[cfg(feature = "vulkan")]
225            BackendType::Vulkan => vulkan::queue_wait_idle(self.handle),
226            #[cfg(feature = "metal")]
227            BackendType::Metal => metal::command_buffer_wait(self.handle),
228            #[cfg(feature = "wgpu")]
229            BackendType::Wgpu => wgpu_backend::queue_submit(self.handle),
230        }
231    }
232}
233
234// =============================================================================
235// Device Selection Utilities
236// =============================================================================
237
/// Returns the best available GPU backend.
///
/// Probes compiled-in backends at runtime in fixed priority order —
/// CUDA, then Metal, then Vulkan, then WebGPU — and returns the first
/// whose `is_available()` reports true. Falls back to `Cpu` when no GPU
/// backend is compiled in or available. The textual order of the blocks
/// below *is* the priority order; do not reorder them casually.
#[must_use]
pub fn best_available_backend() -> BackendType {
    #[cfg(feature = "cuda")]
    if cuda::is_available() {
        return BackendType::Cuda;
    }

    #[cfg(feature = "metal")]
    if metal::is_available() {
        return BackendType::Metal;
    }

    #[cfg(feature = "vulkan")]
    if vulkan::is_available() {
        return BackendType::Vulkan;
    }

    #[cfg(feature = "wgpu")]
    if wgpu_backend::is_available() {
        return BackendType::Wgpu;
    }

    // Always-available fallback.
    BackendType::Cpu
}
263
/// Returns the number of available GPUs across all backends.
///
/// Sums the device counts reported by every backend compiled into this
/// build; with no GPU features enabled, the result is always zero.
///
/// NOTE: a device visible through multiple backends (e.g. via both Vulkan
/// and CUDA) is counted once per backend — presumably intentional; verify
/// against callers.
#[must_use]
pub fn gpu_count() -> usize {
    // `mut` is unused when no GPU feature is enabled, hence the allow.
    #[allow(unused_mut)]
    let mut total: usize = 0;

    #[cfg(feature = "cuda")]
    {
        total += cuda::device_count();
    }

    #[cfg(feature = "vulkan")]
    {
        total += vulkan::device_count();
    }

    #[cfg(feature = "metal")]
    {
        total += metal::device_count();
    }

    #[cfg(feature = "wgpu")]
    {
        total += wgpu_backend::device_count();
    }

    total
}