Skip to main content

trueno/backends/gpu/
mod.rs

1#![allow(missing_docs)]
2//! GPU backend using wgpu (Vulkan/Metal/DX12/WebGPU)
3//!
4//! This backend provides GPU-accelerated compute for large-scale operations.
5//! It uses wgpu for cross-platform GPU access and WGSL compute shaders.
6//!
7//! # Performance
8//!
9//! GPU backend is optimal for very large workloads (>100K elements for reductions,
10//! >1000×1000 for matrix operations) where transfer overhead is amortized.
11//!
12//! Expected speedups vs SIMD:
13//! - Matrix multiplication (large): 10-50x
14//! - Reductions (large): 5-20x
15//!
16//! # Architecture
17//!
18//! - Device initialization is lazy (first GPU operation)
19//! - Compute shaders written in WGSL
20//! - Asynchronous execution with pollster for blocking
21//! - Automatic fallback to CPU if GPU unavailable
22//!
23//! # Memory Hierarchy Abstractions
24//!
25//! - [`TensorView`] - Structured view into GPU memory with shape/stride metadata
26//! - [`PartitionView`] - Tiling strategy for efficient GPU work distribution
27//!
28//! Based on cuda-tile-behavior.md Section 3.2.
29
30#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
31mod batch;
32
33#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
34mod device;
35
36#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
37mod pool;
38
39#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
40pub mod shaders;
41
42#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
43pub mod runtime;
44
45// Memory hierarchy abstractions (always available, no GPU feature required)
46mod partition_view;
47mod tensor_view;
48mod tiled_reduction;
49
50pub use partition_view::{PartitionView, TileInfo};
51pub use tensor_view::{MemoryLayout, TensorView};
52pub use tiled_reduction::{
53    tiled_max_2d, tiled_min_2d, tiled_reduce_2d, tiled_reduce_partial, tiled_sum_2d, MaxOp, MinOp,
54    ReduceOp, SumOp, TILE_SIZE,
55};
56
57#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
58pub use batch::{BufferId, GpuCommandBatch, PipelineCache};
59
60// Export GpuDevice for both native and WASM GPU features
61#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
62pub use device::GpuDevice;
63
64/// Re-export wgpu types for downstream crates that need to create persistent
65/// GPU buffers (KAIZEN-015: GPU-resident weights).
66#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
67pub use wgpu;
68
69#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
70pub use pool::GpuDevicePool;
71
72/// PMAT-322: Cached matmul with persistent weight buffers for LLM inference.
73#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
74pub use device::linalg::cached_matmul::GpuMatmulCache;
75
76/// PMAT-324: WGSL transformer forward pass shaders.
77#[cfg(any(feature = "gpu", feature = "gpu-wasm"))]
78pub use device::linalg::wgsl_forward::{QkvLoRA, WgslForwardPass};
79
80#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
81mod backend_ops;
82
83/// GPU backend for compute operations (native only, uses sync wrappers)
84#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
85#[derive(Clone)]
86pub struct GpuBackend {
87    device: Option<GpuDevice>,
88}
89
90#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
91impl GpuBackend {
92    /// Create a new GPU backend
93    pub fn new() -> Self {
94        Self { device: None }
95    }
96
97    /// Initialize GPU device (lazy)
98    fn ensure_device(&mut self) -> Result<&GpuDevice, String> {
99        if self.device.is_none() {
100            self.device = Some(GpuDevice::new()?);
101        }
102        Ok(self.device.as_ref().expect("device initialized above"))
103    }
104
105    /// Check if GPU is available
106    pub fn is_available() -> bool {
107        GpuDevice::is_available()
108    }
109}
110
111#[cfg(all(feature = "gpu", not(target_arch = "wasm32")))]
112impl Default for GpuBackend {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118// Stub implementation when GPU feature is disabled or on WASM
119#[cfg(any(not(feature = "gpu"), target_arch = "wasm32"))]
120#[derive(Clone)]
121pub struct GpuBackend;
122
123#[cfg(any(not(feature = "gpu"), target_arch = "wasm32"))]
124impl GpuBackend {
125    pub fn new() -> Self {
126        Self
127    }
128
129    pub fn is_available() -> bool {
130        false
131    }
132}
133
134#[cfg(any(not(feature = "gpu"), target_arch = "wasm32"))]
135impl Default for GpuBackend {
136    fn default() -> Self {
137        Self
138    }
139}
140
141// Tests for stub implementation (when GPU feature is NOT enabled)
142#[cfg(test)]
143#[cfg(not(feature = "gpu"))]
144mod stub_tests {
145    use super::*;
146
147    #[test]
148    fn test_gpu_backend_stub_new() {
149        let _backend = GpuBackend::new();
150    }
151
152    #[test]
153    fn test_gpu_backend_stub_is_available() {
154        assert!(!GpuBackend::is_available());
155    }
156
157    #[test]
158    fn test_gpu_backend_stub_default() {
159        let _ = GpuBackend;
160    }
161
162    #[test]
163    fn test_gpu_backend_stub_clone() {
164        let backend = GpuBackend::new();
165        let _cloned = backend.clone();
166    }
167}
168
169#[cfg(test)]
170#[cfg(feature = "gpu")]
171mod tests_gpu;