Skip to main content

oxicuda/
lib.rs

1//! # OxiCUDA — Pure Rust CUDA Replacement
2//!
3//! OxiCUDA provides a complete, pure Rust replacement for NVIDIA's CUDA
4//! software stack. It dynamically loads `libcuda.so` at runtime, requiring
5//! no CUDA Toolkit at build time.
6//!
7//! ## Architecture
8//!
9//! ```text
10//! ┌──────────────────────────────────────────────┐
11//! │           COOLJAPAN Ecosystem                 │
12//! │  SciRS2 │ oxionnx │ TrustformeRS │ ToRSh     │
13//! │         └────┬────┘              │            │
14//! │              └───────────────────┘            │
15//! │                      │                        │
16//! │              ┌───────▼────────┐               │
17//! │              │    OxiCUDA     │               │
18//! │              ├────────────────┤               │
19//! │              │ Driver (Vol.1) │               │
20//! │              │ Memory (Vol.1) │               │
21//! │              │ Launch (Vol.1) │               │
22//! │              │ PTX    (Vol.2) │               │
23//! │              │ Autotune(Vol.2)│               │
24//! │              │ BLAS   (Vol.3) │               │
25//! │              │ DNN    (Vol.4) │               │
26//! │              │ FFT    (Vol.5) │               │
27//! │              │ Sparse (Vol.5) │               │
28//! │              │ Solver (Vol.5) │               │
29//! │              │ Rand   (Vol.5) │               │
30//! │              └───────┬────────┘               │
31//! │              ┌───────▼────────┐               │
32//! │              │ libcuda.so     │               │
33//! │              │ (NVIDIA Driver)│               │
34//! │              └────────────────┘               │
35//! └──────────────────────────────────────────────┘
36//! ```
37//!
38//! ## Quick Start
39//!
40//! ```no_run
41//! use oxicuda::prelude::*;
42//!
43//! fn main() -> CudaResult<()> {
44//!     // Initialize the CUDA driver
45//!     oxicuda::init()?;
46//!
47//!     // Enumerate devices
48//!     let device = Device::get(0)?;
49//!     println!("GPU: {}", device.name()?);
50//!
51//!     // Create context and stream
52//!     let ctx = Context::new(&device)?;
53//!     let ctx = std::sync::Arc::new(ctx);
54//!     let stream = Stream::new(&ctx)?;
55//!
56//!     // Allocate device memory
57//!     let mut buf = DeviceBuffer::<f32>::alloc(1024)?;
58//!     let host_data = vec![1.0f32; 1024];
59//!     buf.copy_from_host(&host_data)?;
60//!
61//!     Ok(())
62//! }
63//! ```
64//!
65//! ## Feature Flags
66//!
67//! | Feature | Description | Default |
68//! |---------|-------------|---------|
69//! | `driver` | CUDA driver API wrapper | Yes |
70//! | `memory` | GPU memory management | Yes |
71//! | `launch` | Kernel launch infrastructure | Yes |
72//! | `ptx` | PTX code generation DSL | No |
73//! | `autotune` | Autotuner engine | No |
74//! | `blas` | cuBLAS equivalent | No |
75//! | `dnn` | cuDNN equivalent | No |
76//! | `fft` | cuFFT equivalent | No |
77//! | `sparse` | cuSPARSE equivalent | No |
78//! | `solver` | cuSOLVER equivalent | No |
79//! | `rand` | cuRAND equivalent | No |
80//! | `pool` | Stream-ordered memory pool | No |
81//! | `backend` | Abstract compute backend trait | No |
82//! | `full` | Enable all features | No |
83//!
84//! (C) 2026 COOLJAPAN OU (Team KitaSan)
85
86#![warn(missing_docs)]
87#![warn(clippy::all)]
88#![allow(clippy::module_name_repetitions)]
89#![allow(clippy::wildcard_imports)]
90
91// ─── Global initialization with device auto-selection ───────
92
93/// Global initialization with device auto-selection.
94///
95/// Provides [`lazy_init`](global_init::lazy_init),
96/// [`OxiCudaRuntimeBuilder`](global_init::OxiCudaRuntimeBuilder), and
97/// related helpers for one-call GPU setup.
98pub mod global_init;
99
100pub use global_init::{DeviceSelection, OxiCudaRuntime, OxiCudaRuntimeBuilder};
101
102// ─── Profiling & tracing ────────────────────────────────────
103
104/// Profiling and tracing hooks for kernel-level performance analysis.
105///
106/// Provides chrome://tracing compatible output for visualizing GPU kernel
107/// execution, memory transfers, and synchronization events.
108pub mod profiling;
109
110// ─── Multi-GPU device pool ─────────────────────────────────
111
112/// Thread-safe multi-GPU device pool with workload-aware scheduling.
113///
114/// Provides [`MultiGpuPool`](device_pool::MultiGpuPool),
115/// [`DeviceSelectionPolicy`](device_pool::DeviceSelectionPolicy),
116/// [`GpuLease`](device_pool::GpuLease), and
117/// [`WorkloadBalancer`](device_pool::WorkloadBalancer).
118pub mod device_pool;
119
120// ─── Abstract compute backend ───────────────────────────────
121
122/// Abstract compute backend for GPU-accelerated operations.
123///
124/// Provides the [`ComputeBackend`](backend::ComputeBackend) trait that
125/// higher-level crates use for GPU dispatch without coupling to a
126/// specific GPU API.
127#[cfg(feature = "backend")]
128pub mod backend;
129
130/// ONNX GPU inference backend.
131///
132/// Provides a complete ONNX operator runtime with IR types, 60+ operators,
133/// graph executor, memory planner, operator fusion, and shape inference.
134#[cfg(feature = "onnx-backend")]
135pub mod onnx_backend;
136
137/// ToRSh GPU tensor backend with autograd, optimizers, and mixed precision.
138///
139/// Provides [`GpuTensor`](tensor_backend::GpuTensor), an autograd tape,
140/// forward/backward ops (matmul, conv2d, softmax, loss functions, etc.),
141/// optimizers (SGD, Adam, AdaGrad, RMSProp, LAMB), and mixed-precision
142/// training (GradScaler, Autocast).
143#[cfg(feature = "tensor-backend")]
144pub mod tensor_backend;
145
146/// TrustformeRS Transformer GPU Backend.
147///
148/// Provides transformer model inference infrastructure: paged KV-cache,
149/// continuous batching, speculative decoding, attention dispatch,
150/// token sampling, and quantized inference.
151#[cfg(feature = "transformer-backend")]
152pub mod transformer_backend;
153
154/// WASM + WebGPU compute backend for browser environments.
155///
156/// Wraps [`oxicuda_webgpu::WebGpuBackend`] with WASM-specific bindings,
157/// making the OxiCUDA compute API usable from a browser via WebAssembly.
158/// On native targets the module is still available and compiles cleanly;
159/// the `#[wasm_bindgen]` exports are only emitted when targeting `wasm32`.
160#[cfg(feature = "wasm-backend")]
161pub mod wasm_backend;
162
163#[cfg(feature = "wasm-backend")]
164pub use wasm_backend::WasmComputeBackend;
165
166// ─── Collective communication (NCCL equivalent) ────────────
167
168/// NCCL-equivalent collective communication primitives for multi-GPU training.
169///
170/// Provides AllReduce, AllGather, ReduceScatter, Broadcast, Reduce, and
171/// AllToAll with ring / tree / recursive-halving algorithm support.
172pub mod collective;
173
174/// Pipeline parallelism primitives for multi-GPU model parallelism.
175///
176/// Provides scheduling algorithms (GPipe, 1F1B, Interleaved, ZeroBubble),
177/// bubble analysis, activation checkpointing, and ASCII visualization.
178pub mod pipeline_parallel;
179
180/// Multi-node distributed training support (TCP/IP based).
181///
182/// Provides [`DistributedRuntime`](distributed::DistributedRuntime),
183/// [`TcpStore`](distributed::TcpStore), [`FileStore`](distributed::FileStore),
184/// [`GradientBucket`](distributed::GradientBucket), and
185/// [`DistributedOptimizer`](distributed::DistributedOptimizer) for
186/// coordinating training across multiple machines.
187pub mod distributed;
188
189// ─── Core crates (always available) ─────────────────────────
190
191/// CUDA Driver API wrapper.
192pub use oxicuda_driver as driver;
193
194/// GPU memory management.
195pub use oxicuda_memory as memory;
196
197/// Kernel launch infrastructure.
198pub use oxicuda_launch as launch;
199
200// ─── Optional crates (feature-gated) ────────────────────────
201
202/// PTX code generation DSL.
203#[cfg(feature = "ptx")]
204pub use oxicuda_ptx as ptx;
205
206/// Autotuner engine.
207#[cfg(feature = "autotune")]
208pub use oxicuda_autotune as autotune;
209
210/// GPU-accelerated BLAS operations.
211#[cfg(feature = "blas")]
212pub use oxicuda_blas as blas;
213
214/// GPU-accelerated deep learning primitives.
215#[cfg(feature = "dnn")]
216pub use oxicuda_dnn as dnn;
217
218/// GPU-accelerated FFT operations.
219#[cfg(feature = "fft")]
220pub use oxicuda_fft as fft;
221
222/// GPU-accelerated sparse matrix operations.
223#[cfg(feature = "sparse")]
224pub use oxicuda_sparse as sparse;
225
226/// GPU-accelerated matrix decompositions.
227#[cfg(feature = "solver")]
228pub use oxicuda_solver as solver;
229
230/// GPU-accelerated random number generation.
231#[cfg(feature = "rand")]
232pub use oxicuda_rand as rand;
233
234/// CUB-equivalent high-performance parallel GPU primitives.
235///
236/// Provides PTX code generators for warp, block, and device-wide reduce, scan,
237/// histogram, radix sort, and merge sort — all without any CUDA SDK dependency.
238#[cfg(feature = "primitives")]
239pub use oxicuda_primitives as primitives;
240
241/// Vulkan Compute backend for cross-vendor GPU compute.
242#[cfg(feature = "vulkan")]
243pub use oxicuda_vulkan as vulkan;
244
245/// Apple Metal Compute backend (macOS/iOS).
246#[cfg(feature = "metal")]
247pub use oxicuda_metal as metal_backend;
248
249/// WebGPU Compute backend (cross-platform via wgpu).
250#[cfg(feature = "webgpu")]
251pub use oxicuda_webgpu as webgpu;
252
253/// AMD ROCm/HIP Compute backend (Linux with AMD GPU).
254#[cfg(feature = "rocm")]
255pub use oxicuda_rocm as rocm;
256
257/// Intel Level Zero Compute backend (Linux/Windows with Intel GPU).
258#[cfg(feature = "level-zero")]
259pub use oxicuda_levelzero as level_zero;
260
261// ─── Key type re-exports ─────────────────────────────────────
262
263// Error types
264pub use oxicuda_driver::{CudaError, CudaResult, DriverLoadError};
265
266// Core types
267pub use oxicuda_driver::{
268    Context, Device, Event, Function, JitDiagnostic, JitLog, JitOptions, JitSeverity, Module,
269    Stream,
270};
271pub use oxicuda_driver::{best_device, list_devices, try_driver};
272
273// Memory types
274pub use oxicuda_memory::copy;
275pub use oxicuda_memory::{DeviceBuffer, DeviceSlice, PinnedBuffer, UnifiedBuffer};
276
277// Launch types
278pub use oxicuda_launch::{
279    Dim3, Kernel, KernelArgs, LaunchParams, LaunchParamsBuilder, grid_size_for,
280};
281
282// Re-export the launch! macro
283pub use oxicuda_launch::launch;
284
285/// Initialize the CUDA driver API.
286///
287/// This must be called before any other OxiCUDA function.
288/// It dynamically loads `libcuda.so` (Linux), `nvcuda.dll` (Windows),
289/// and initializes the CUDA driver.
290///
291/// Returns `Err(CudaError::NotInitialized)` on macOS or systems
292/// without an NVIDIA GPU.
293pub fn init() -> CudaResult<()> {
294    oxicuda_driver::init()
295}
296
/// Compile-time feature availability.
///
/// Each feature-gated `HAS_*` constant is `cfg!(feature = "...")` for the
/// matching Cargo feature. Callers can therefore branch on feature
/// availability in ordinary code instead of adding their own `#[cfg]`
/// attributes. `HAS_GLOBAL_INIT` is the exception: it is always `true`,
/// because global initialization is not feature-gated.
pub mod features {
    /// Whether PTX code generation is available.
    pub const HAS_PTX: bool = cfg!(feature = "ptx");
    /// Whether the autotuner is available.
    pub const HAS_AUTOTUNE: bool = cfg!(feature = "autotune");
    /// Whether BLAS operations are available.
    pub const HAS_BLAS: bool = cfg!(feature = "blas");
    /// Whether DNN operations are available.
    pub const HAS_DNN: bool = cfg!(feature = "dnn");
    /// Whether FFT operations are available.
    pub const HAS_FFT: bool = cfg!(feature = "fft");
    /// Whether sparse matrix operations are available.
    pub const HAS_SPARSE: bool = cfg!(feature = "sparse");
    /// Whether solver operations are available.
    pub const HAS_SOLVER: bool = cfg!(feature = "solver");
    /// Whether random number generation is available.
    pub const HAS_RAND: bool = cfg!(feature = "rand");
    /// Whether the CUB-equivalent parallel primitives are available.
    pub const HAS_PRIMITIVES: bool = cfg!(feature = "primitives");
    /// Whether the abstract compute backend is available.
    pub const HAS_BACKEND: bool = cfg!(feature = "backend");
    /// Whether the ONNX inference backend is available.
    pub const HAS_ONNX_BACKEND: bool = cfg!(feature = "onnx-backend");
    /// Whether the ToRSh tensor backend is available.
    pub const HAS_TENSOR_BACKEND: bool = cfg!(feature = "tensor-backend");
    /// Whether the TrustformeRS transformer backend is available.
    pub const HAS_TRANSFORMER_BACKEND: bool = cfg!(feature = "transformer-backend");
    /// Whether stream-ordered memory pool is available.
    pub const HAS_POOL: bool = cfg!(feature = "pool");
    /// Whether GPU tests are enabled.
    pub const HAS_GPU_TESTS: bool = cfg!(feature = "gpu-tests");
    /// Whether global initialization is available (always `true`).
    pub const HAS_GLOBAL_INIT: bool = true;
    /// Whether the Vulkan Compute backend is available.
    pub const HAS_VULKAN: bool = cfg!(feature = "vulkan");
    /// Whether the Apple Metal Compute backend is available.
    pub const HAS_METAL: bool = cfg!(feature = "metal");
    /// Whether the WebGPU Compute backend is available.
    pub const HAS_WEBGPU: bool = cfg!(feature = "webgpu");
    /// Whether the AMD ROCm/HIP Compute backend is available.
    pub const HAS_ROCM: bool = cfg!(feature = "rocm");
    /// Whether the Intel Level Zero Compute backend is available.
    pub const HAS_LEVEL_ZERO: bool = cfg!(feature = "level-zero");
    /// Whether the WASM + WebGPU browser backend is available.
    pub const HAS_WASM_BACKEND: bool = cfg!(feature = "wasm-backend");
}
342
343// ---------------------------------------------------------------------------
344// ComputeBackend auto-selection threshold
345// ---------------------------------------------------------------------------
346
/// Auto-selection threshold for the compute backend, in bytes.
///
/// Workloads whose data size is strictly greater than this threshold are
/// dispatched to the GPU backend. Workloads at or below it stay on the CPU
/// backend (see [`exceeds_auto_select_threshold`]). The 64 KiB (65 536-byte)
/// default is tuned for SciRS2 workloads, where GPU launch overhead dominates
/// for small matrices.
pub const AUTO_SELECT_THRESHOLD_BYTES: usize = 64 * 1024;

/// Returns `true` when a workload of `bytes` bytes should go to the GPU
/// backend under the [`AUTO_SELECT_THRESHOLD_BYTES`] policy.
///
/// The comparison is strict: a workload of exactly the threshold size stays
/// on the CPU backend.
#[must_use]
pub const fn exceeds_auto_select_threshold(bytes: usize) -> bool {
    bytes > AUTO_SELECT_THRESHOLD_BYTES
}
354
355// ---------------------------------------------------------------------------
356// ONNX supported operator list
357// ---------------------------------------------------------------------------
358
/// ONNX operator names (ONNX `op_type` strings) that OxiCUDA advertises as
/// supported.
///
/// This constant is available regardless of which Cargo features are enabled.
/// It lists 11 core operators, while the feature-gated `onnx_backend` module
/// documents 60+ operators. Treat it as a baseline subset for quick
/// compatibility checks, not as an exhaustive dispatch table.
///
/// NOTE(review): confirm whether `onnx_backend` consults this list at dispatch
/// time. If it does, the list should be extended to match the operators
/// implemented there.
pub const SUPPORTED_ONNX_OPS: &[&str] = &[
    "MatMul",
    "Conv",
    "Relu",
    "BatchNormalization",
    "Softmax",
    "LayerNormalization",
    "Add",
    "Mul",
    "Transpose",
    "Reshape",
    "Concat",
];
376
377// ---------------------------------------------------------------------------
378// Tests
379// ---------------------------------------------------------------------------
380
#[cfg(test)]
mod umbrella_tests {
    use super::*;

    /// Mirror of the dispatch rule documented on `AUTO_SELECT_THRESHOLD_BYTES`:
    /// strictly larger workloads go to the GPU, everything else stays on the CPU.
    fn selects_gpu(bytes: usize) -> bool {
        bytes > AUTO_SELECT_THRESHOLD_BYTES
    }

    // -----------------------------------------------------------------------
    // ComputeBackend auto-selection threshold tests
    // -----------------------------------------------------------------------

    #[test]
    fn compute_backend_threshold_is_64kb() {
        assert_eq!(
            AUTO_SELECT_THRESHOLD_BYTES,
            64 * 1024,
            "auto-select threshold must be exactly 64 KiB = 65536 bytes"
        );
    }

    #[test]
    fn small_tensor_uses_cpu_backend() {
        // 1 KiB of data is below the threshold → CPU backend selected.
        assert!(
            !selects_gpu(1024),
            "1 KB should be below threshold → CPU backend"
        );
    }

    #[test]
    fn large_tensor_uses_gpu_backend() {
        // 1 MiB of data is above the threshold → GPU backend attempted.
        assert!(
            selects_gpu(1024 * 1024),
            "1 MB should be above threshold → GPU backend"
        );
    }

    #[test]
    fn threshold_boundary_values() {
        // The threshold must be a real, non-zero size...
        const { assert!(AUTO_SELECT_THRESHOLD_BYTES > 0) }
        // ...and must leave room above it; otherwise the GPU path is unreachable.
        let one_above = AUTO_SELECT_THRESHOLD_BYTES
            .checked_add(1)
            .expect("threshold must leave room for larger workloads");

        // One byte below the threshold → CPU backend.
        assert!(!selects_gpu(AUTO_SELECT_THRESHOLD_BYTES - 1));
        // Exactly at the threshold: not above → CPU backend.
        assert!(!selects_gpu(AUTO_SELECT_THRESHOLD_BYTES));
        // One byte above the threshold → GPU backend.
        assert!(selects_gpu(one_above));
    }

    // -----------------------------------------------------------------------
    // ONNX operator interface tests
    // -----------------------------------------------------------------------

    #[test]
    fn onnx_matmul_op_name_correct() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"MatMul"),
            "SUPPORTED_ONNX_OPS must contain 'MatMul'"
        );
    }

    #[test]
    fn onnx_conv_op_name_correct() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Conv"),
            "SUPPORTED_ONNX_OPS must contain 'Conv'"
        );
    }

    #[test]
    fn onnx_op_list_includes_relu() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Relu"),
            "SUPPORTED_ONNX_OPS must contain 'Relu'"
        );
    }

    #[test]
    fn onnx_op_list_includes_softmax() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Softmax"),
            "SUPPORTED_ONNX_OPS must contain 'Softmax'"
        );
    }

    #[test]
    fn onnx_op_list_includes_layer_norm() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"LayerNormalization"),
            "SUPPORTED_ONNX_OPS must contain 'LayerNormalization'"
        );
    }

    #[test]
    fn onnx_op_list_includes_batch_norm() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"BatchNormalization"),
            "SUPPORTED_ONNX_OPS must contain 'BatchNormalization'"
        );
    }

    // -----------------------------------------------------------------------
    // ToRSh SDPA attention dispatch test (gated on transformer-backend feature)
    // -----------------------------------------------------------------------

    #[cfg(feature = "transformer-backend")]
    mod transformer_tests {
        use crate::transformer_backend::attention::{
            AttentionConfig, AttentionDispatch, AttentionKind, ComputeTier, HeadConfig,
        };

        #[test]
        fn torsh_sdpa_attention_config_exists() {
            // A standard 32-head MHA configuration targeting Hopper with a
            // long-sequence hint; construction must succeed.
            let cfg = AttentionConfig {
                head_config: HeadConfig::Mha { num_heads: 32 },
                head_dim: 128,
                use_paged_cache: false,
                compute_tier: ComputeTier::Hopper,
                sliding_window: None,
                causal: true,
                scale: None,
                max_seq_len_hint: Some(4096),
            };
            let mut dispatch =
                AttentionDispatch::new(cfg).expect("AttentionDispatch::new should succeed");

            // With long sequences on Hopper, a Flash variant (Flash or
            // FlashHopper) must be selected.
            let kernel = dispatch.select_kernel(4096);
            assert!(
                matches!(kernel, AttentionKind::Flash | AttentionKind::FlashHopper),
                "Hopper with 4096 tokens should use Flash attention, got {kernel:?}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // TrustformeRS MoE routing arithmetic (CPU-only, always compiled)
    // -----------------------------------------------------------------------

    /// MoE routing arithmetic sanity check (CPU-side, no GPU required).
    #[test]
    fn trustformers_moe_config_exists() {
        // Expected routing load: tokens_per_expert = batch * seq * top_k / num_experts.
        let num_experts: usize = 8;
        let top_k: usize = 2;
        let batch_size: usize = 4;
        let seq_len: usize = 512;

        // Expected tokens per expert on average (Mixtral 8x7B pattern).
        let total_tokens = batch_size * seq_len;
        let routed_tokens = total_tokens * top_k;
        // With uniform routing, each of the 8 experts receives routed_tokens / num_experts.
        let tokens_per_expert = routed_tokens / num_experts;

        // For batch=4, seq=512, top_k=2, num_experts=8:
        // total = 2048, routed = 4096, per_expert = 512.
        assert_eq!(total_tokens, 2048);
        assert_eq!(routed_tokens, 4096);
        assert_eq!(tokens_per_expert, 512);
    }

    #[test]
    fn moe_mixtral_config_8x7b() {
        // Standard Mixtral 8x7B MoE routing configuration.
        let num_experts: usize = 8;
        let top_k: usize = 2;

        // Activation rate: each token activates top_k / num_experts of the experts.
        let activation_rate = top_k as f64 / num_experts as f64;
        assert!(
            (activation_rate - 0.25).abs() < 1e-10,
            "Mixtral 8x7B: activation rate = {activation_rate}, expected 0.25"
        );

        // Load balance: with uniform routing, each expert receives the same
        // expected number of tokens.
        let batch_size: usize = 16;
        let seq_len: usize = 1024;
        let expected_per_expert = batch_size * seq_len * top_k / num_experts;
        // 16 * 1024 * 2 / 8 = 4096 tokens per expert.
        assert_eq!(expected_per_expert, 4096);
    }
}
559
560/// Convenience re-exports for common usage patterns.
561///
562/// ```no_run
563/// use oxicuda::prelude::*;
564/// ```
565pub mod prelude {
566    // Error handling
567    pub use crate::{CudaError, CudaResult};
568
569    // Initialization
570    pub use crate::{init, try_driver};
571
572    // Core GPU types
573    pub use crate::{Context, Device, Event, Function, Module, Stream};
574    pub use crate::{best_device, list_devices};
575
576    // Memory management
577    pub use crate::{DeviceBuffer, PinnedBuffer, UnifiedBuffer};
578
579    // Kernel launch
580    pub use crate::{Dim3, Kernel, KernelArgs, LaunchParams, grid_size_for};
581
582    // Global initialization
583    pub use crate::global_init::{
584        default_context, default_device, default_stream, is_initialized, lazy_init,
585    };
586
587    // Parallel primitives (feature = "primitives")
588    #[cfg(feature = "primitives")]
589    pub use oxicuda_primitives::{PrimitivesError, PrimitivesHandle, PrimitivesResult, ReduceOp};
590}