// axonml_core/backends/mod.rs
1//! Backends - Device-Specific Implementations
2//!
3//! This module contains backend implementations for different compute devices.
4//! Each backend provides device-specific memory operations and kernel execution.
5//!
6//! # Available Backends
7//! - `cpu` - CPU backend (always available)
8//! - `cuda` - NVIDIA CUDA backend (requires `cuda` feature)
9//! - `vulkan` - Vulkan backend (requires `vulkan` feature)
10//! - `metal` - Apple Metal backend (requires `metal` feature)
11//! - `wgpu` - WebGPU backend (requires `wgpu` feature)
12//!
13//! # Backend Trait
14//!
15//! All backends implement the `Backend` trait which provides a common interface
16//! for tensor operations. This enables device-agnostic code.
17//!
18//! @version 0.1.0
19//! @author `AutomataNexus` Development Team
20
21use crate::device::DeviceCapabilities;
22
23// =============================================================================
24// Backend Modules
25// =============================================================================
26
27pub mod cpu;
28
29#[cfg(feature = "cuda")]
30pub mod cuda;
31
32#[cfg(feature = "vulkan")]
33pub mod vulkan;
34
35#[cfg(feature = "metal")]
36pub mod metal;
37
38#[cfg(feature = "wgpu")]
39pub mod wgpu_backend;
40
41// =============================================================================
42// Re-exports
43// =============================================================================
44
45pub use cpu::CpuBackend;
46
47#[cfg(feature = "cuda")]
48pub use cuda::CudaBackend;
49
50#[cfg(feature = "vulkan")]
51pub use vulkan::VulkanBackend;
52
53#[cfg(feature = "metal")]
54pub use metal::MetalBackend;
55
56#[cfg(feature = "wgpu")]
57pub use wgpu_backend::WgpuBackend;
58
59// =============================================================================
60// Backend Trait
61// =============================================================================
62
/// Common trait for all compute backends.
///
/// This trait defines the interface that all backends must implement,
/// enabling device-agnostic tensor operations.
///
/// # Safety
///
/// NOTE(review): `allocate`, `deallocate`, and the three copy methods are
/// safe `fn`s that traffic in raw pointers. Callers must uphold the usual
/// pointer invariants (valid, correctly sized, non-overlapping where the
/// backend requires it) even though the compiler cannot enforce them here —
/// consider whether these should be `unsafe fn` with documented contracts.
pub trait Backend: Send + Sync {
    /// Returns the name of this backend.
    fn name(&self) -> &'static str;

    /// Returns whether this backend is available on the current system.
    fn is_available(&self) -> bool;

    /// Returns the device capabilities.
    fn capabilities(&self) -> DeviceCapabilities;

    /// Allocates memory on this backend.
    ///
    /// Returns a raw pointer to `size` bytes; ownership/freeing is the
    /// caller's responsibility via `deallocate`.
    fn allocate(&self, size: usize) -> *mut u8;

    /// Deallocates memory on this backend.
    ///
    /// `ptr` and `size` must describe an allocation previously returned by
    /// `allocate` on this same backend.
    fn deallocate(&self, ptr: *mut u8, size: usize);

    /// Copies data from host to device.
    fn copy_to_device(&self, dst: *mut u8, src: *const u8, size: usize);

    /// Copies data from device to host.
    fn copy_to_host(&self, dst: *mut u8, src: *const u8, size: usize);

    /// Copies data within the device.
    fn copy_device_to_device(&self, dst: *mut u8, src: *const u8, size: usize);

    /// Synchronizes the device (waits for all operations to complete).
    fn synchronize(&self);
}
95
96// =============================================================================
97// GPU Memory Management
98// =============================================================================
99
/// GPU memory handle for safe memory management.
///
/// Bundles a raw device pointer with its allocation size, the index of the
/// device it lives on, and the backend that produced it.
///
/// NOTE(review): no `Drop` impl is visible in this module, so this handle
/// does not itself free the allocation — confirm that deallocation is
/// performed by the owning backend (via `Backend::deallocate`).
#[derive(Debug)]
pub struct GpuMemory {
    // Raw pointer to the allocation; meaning is backend-specific.
    ptr: *mut u8,
    // Allocation size in bytes.
    size: usize,
    // Index of the device that owns the allocation.
    device_index: usize,
    // Backend that performed the allocation.
    backend_type: BackendType,
}
108
/// Type of backend for a GPU memory allocation.
///
/// Only variants whose corresponding Cargo feature is enabled are compiled
/// in; `Cpu` is always present. Derives `Hash` so the type can be used as a
/// `HashMap`/`HashSet` key, and implements `Default` (as `Cpu`, the only
/// unconditionally available backend — matching the fallback used by
/// `best_available_backend`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BackendType {
    /// CPU backend.
    Cpu,
    /// CUDA backend.
    #[cfg(feature = "cuda")]
    Cuda,
    /// Vulkan backend.
    #[cfg(feature = "vulkan")]
    Vulkan,
    /// Metal backend.
    #[cfg(feature = "metal")]
    Metal,
    /// WebGPU backend.
    #[cfg(feature = "wgpu")]
    Wgpu,
}

impl Default for BackendType {
    /// `Cpu` is the only backend compiled in unconditionally, so it is the
    /// natural default.
    fn default() -> Self {
        Self::Cpu
    }
}
127
128impl GpuMemory {
129    /// Creates a new GPU memory handle.
130    pub fn new(ptr: *mut u8, size: usize, device_index: usize, backend_type: BackendType) -> Self {
131        Self {
132            ptr,
133            size,
134            device_index,
135            backend_type,
136        }
137    }
138
139    /// Returns the raw pointer.
140    #[must_use]
141    pub fn ptr(&self) -> *mut u8 {
142        self.ptr
143    }
144
145    /// Returns the size in bytes.
146    #[must_use]
147    pub fn size(&self) -> usize {
148        self.size
149    }
150
151    /// Returns the device index.
152    #[must_use]
153    pub fn device_index(&self) -> usize {
154        self.device_index
155    }
156
157    /// Returns the backend type.
158    #[must_use]
159    pub fn backend_type(&self) -> BackendType {
160        self.backend_type
161    }
162}
163
164// =============================================================================
165// GPU Stream/Queue Abstraction
166// =============================================================================
167
/// GPU execution stream for async operations.
///
/// Wraps an opaque, backend-specific queue/stream identifier together with
/// the device it belongs to, so work submission and synchronization can be
/// dispatched without knowing the concrete backend.
#[derive(Debug)]
pub struct GpuStream {
    /// Stream handle (backend-specific).
    // Stored as a plain `usize`; interpretation is entirely up to the
    // backend module that issued it (e.g. a CUDA stream pointer value).
    handle: usize,
    /// Device index.
    device_index: usize,
    /// Backend type.
    backend_type: BackendType,
}
178
179impl GpuStream {
180    /// Creates a new GPU stream.
181    #[must_use]
182    pub fn new(handle: usize, device_index: usize, backend_type: BackendType) -> Self {
183        Self {
184            handle,
185            device_index,
186            backend_type,
187        }
188    }
189
190    /// Returns the stream handle.
191    #[must_use]
192    pub fn handle(&self) -> usize {
193        self.handle
194    }
195
196    /// Returns the device index.
197    #[must_use]
198    pub fn device_index(&self) -> usize {
199        self.device_index
200    }
201
202    /// Synchronizes this stream (waits for all operations to complete).
203    pub fn synchronize(&self) {
204        match self.backend_type {
205            BackendType::Cpu => {} // No-op for CPU
206            #[cfg(feature = "cuda")]
207            BackendType::Cuda => cuda::stream_synchronize(self.handle),
208            #[cfg(feature = "vulkan")]
209            BackendType::Vulkan => vulkan::queue_wait_idle(self.handle),
210            #[cfg(feature = "metal")]
211            BackendType::Metal => metal::command_buffer_wait(self.handle),
212            #[cfg(feature = "wgpu")]
213            BackendType::Wgpu => wgpu_backend::queue_submit(self.handle),
214        }
215    }
216}
217
218// =============================================================================
219// Device Selection Utilities
220// =============================================================================
221
/// Returns the best available GPU backend.
///
/// Probes the compiled-in backends in fixed priority order
/// CUDA > Metal > Vulkan > WebGPU, returning the first whose module
/// reports itself available at runtime. Falls back to `BackendType::Cpu`
/// when no GPU backend is compiled in or available.
#[must_use]
pub fn best_available_backend() -> BackendType {
    // Each probe is feature-gated, so with no GPU features enabled this
    // function compiles down to just the CPU fallback.
    #[cfg(feature = "cuda")]
    if cuda::is_available() {
        return BackendType::Cuda;
    }

    #[cfg(feature = "metal")]
    if metal::is_available() {
        return BackendType::Metal;
    }

    #[cfg(feature = "vulkan")]
    if vulkan::is_available() {
        return BackendType::Vulkan;
    }

    #[cfg(feature = "wgpu")]
    if wgpu_backend::is_available() {
        return BackendType::Wgpu;
    }

    // CPU is always available.
    BackendType::Cpu
}
247
/// Returns the number of available GPUs across all backends.
///
/// Sums the device counts reported by every compiled-in GPU backend;
/// returns 0 when no GPU features are enabled.
///
/// NOTE(review): a physical GPU exposed through more than one backend
/// (e.g. both Vulkan and WebGPU) would be counted once per backend —
/// confirm whether deduplication is intended.
#[must_use]
pub fn gpu_count() -> usize {
    // `unused_mut` is allowed because with no GPU features enabled none of
    // the feature-gated additions below are compiled in.
    #[allow(unused_mut)]
    let mut count = 0_usize;

    #[cfg(feature = "cuda")]
    {
        count += cuda::device_count();
    }

    #[cfg(feature = "vulkan")]
    {
        count += vulkan::device_count();
    }

    #[cfg(feature = "metal")]
    {
        count += metal::device_count();
    }

    #[cfg(feature = "wgpu")]
    {
        count += wgpu_backend::device_count();
    }

    count
}
275}