oxicuda/lib.rs
1//! # OxiCUDA — Pure Rust CUDA Replacement
2//!
3//! OxiCUDA provides a complete, pure Rust replacement for NVIDIA's CUDA
4//! software stack. It dynamically loads `libcuda.so` at runtime, requiring
5//! no CUDA Toolkit at build time.
6//!
7//! ## Architecture
8//!
9//! ```text
10//! ┌──────────────────────────────────────────────┐
11//! │ COOLJAPAN Ecosystem │
12//! │ SciRS2 │ oxionnx │ TrustformeRS │ ToRSh │
13//! │ └────┬────┘ │ │
14//! │ └───────────────────┘ │
15//! │ │ │
16//! │ ┌───────▼────────┐ │
17//! │ │ OxiCUDA │ │
18//! │ ├────────────────┤ │
19//! │ │ Driver (Vol.1) │ │
20//! │ │ Memory (Vol.1) │ │
21//! │ │ Launch (Vol.1) │ │
22//! │ │ PTX (Vol.2) │ │
23//! │ │ Autotune(Vol.2)│ │
24//! │ │ BLAS (Vol.3) │ │
25//! │ │ DNN (Vol.4) │ │
26//! │ │ FFT (Vol.5) │ │
27//! │ │ Sparse (Vol.5) │ │
28//! │ │ Solver (Vol.5) │ │
29//! │ │ Rand (Vol.5) │ │
30//! │ └───────┬────────┘ │
31//! │ ┌───────▼────────┐ │
32//! │ │ libcuda.so │ │
33//! │ │ (NVIDIA Driver)│ │
34//! │ └────────────────┘ │
35//! └──────────────────────────────────────────────┘
36//! ```
37//!
38//! ## Quick Start
39//!
40//! ```no_run
41//! use oxicuda::prelude::*;
42//!
43//! fn main() -> CudaResult<()> {
44//! // Initialize the CUDA driver
45//! oxicuda::init()?;
46//!
47//! // Enumerate devices
48//! let device = Device::get(0)?;
49//! println!("GPU: {}", device.name()?);
50//!
51//! // Create context and stream
52//! let ctx = Context::new(&device)?;
53//! let ctx = std::sync::Arc::new(ctx);
54//! let stream = Stream::new(&ctx)?;
55//!
56//! // Allocate device memory
57//! let mut buf = DeviceBuffer::<f32>::alloc(1024)?;
58//! let host_data = vec![1.0f32; 1024];
59//! buf.copy_from_host(&host_data)?;
60//!
61//! Ok(())
62//! }
63//! ```
64//!
65//! ## Feature Flags
66//!
67//! | Feature | Description | Default |
68//! |---------|-------------|---------|
69//! | `driver` | CUDA driver API wrapper | Yes |
70//! | `memory` | GPU memory management | Yes |
71//! | `launch` | Kernel launch infrastructure | Yes |
72//! | `ptx` | PTX code generation DSL | No |
73//! | `autotune` | Autotuner engine | No |
74//! | `blas` | cuBLAS equivalent | No |
75//! | `dnn` | cuDNN equivalent | No |
76//! | `fft` | cuFFT equivalent | No |
77//! | `sparse` | cuSPARSE equivalent | No |
78//! | `solver` | cuSOLVER equivalent | No |
79//! | `rand` | cuRAND equivalent | No |
//! | `pool` | Stream-ordered memory pool | No |
//! | `backend` | Abstract compute backend trait | No |
//! | `primitives` | CUB-equivalent parallel GPU primitives | No |
//! | `vulkan` | Vulkan Compute backend | No |
//! | `metal` | Apple Metal Compute backend | No |
//! | `webgpu` | WebGPU Compute backend (via wgpu) | No |
//! | `rocm` | AMD ROCm/HIP Compute backend | No |
//! | `level-zero` | Intel Level Zero Compute backend | No |
//! | `full` | Enable all features | No |
83//!
84//! (C) 2026 COOLJAPAN OU (Team KitaSan)
85
86#![warn(missing_docs)]
87#![warn(clippy::all)]
88#![allow(clippy::module_name_repetitions)]
89#![allow(clippy::wildcard_imports)]
90
// ─── Global initialization with device auto-selection ───────

/// Global initialization with device auto-selection.
///
/// Provides [`lazy_init`](global_init::lazy_init),
/// [`OxiCudaRuntimeBuilder`](global_init::OxiCudaRuntimeBuilder), and
/// related helpers for one-call GPU setup.
///
/// Always compiled — this module is not behind a feature gate.
pub mod global_init;

// Re-export the runtime handle types at the crate root for convenience.
pub use global_init::{DeviceSelection, OxiCudaRuntime, OxiCudaRuntimeBuilder};
101
// ─── Profiling & tracing ────────────────────────────────────

/// Profiling and tracing hooks for kernel-level performance analysis.
///
/// Provides chrome://tracing compatible output for visualizing GPU kernel
/// execution, memory transfers, and synchronization events.
///
/// Always compiled — not behind a feature gate.
pub mod profiling;

// ─── Multi-GPU device pool ─────────────────────────────────

/// Thread-safe multi-GPU device pool with workload-aware scheduling.
///
/// Provides [`MultiGpuPool`](device_pool::MultiGpuPool),
/// [`DeviceSelectionPolicy`](device_pool::DeviceSelectionPolicy),
/// [`GpuLease`](device_pool::GpuLease), and
/// [`WorkloadBalancer`](device_pool::WorkloadBalancer).
///
/// Always compiled — not behind a feature gate.
pub mod device_pool;
119
// ─── Abstract compute backend ───────────────────────────────

/// Abstract compute backend for GPU-accelerated operations.
///
/// Provides the [`ComputeBackend`](backend::ComputeBackend) trait that
/// higher-level crates use for GPU dispatch without coupling to a
/// specific GPU API.
///
/// Enabled by the `backend` feature.
#[cfg(feature = "backend")]
pub mod backend;

/// ONNX GPU inference backend.
///
/// Provides a complete ONNX operator runtime with IR types, 60+ operators,
/// graph executor, memory planner, operator fusion, and shape inference.
///
/// Enabled by the `onnx-backend` feature.
#[cfg(feature = "onnx-backend")]
pub mod onnx_backend;

/// ToRSh GPU tensor backend with autograd, optimizers, and mixed precision.
///
/// Provides [`GpuTensor`](tensor_backend::GpuTensor), an autograd tape,
/// forward/backward ops (matmul, conv2d, softmax, loss functions, etc.),
/// optimizers (SGD, Adam, AdaGrad, RMSProp, LAMB), and mixed-precision
/// training (GradScaler, Autocast).
///
/// Enabled by the `tensor-backend` feature.
#[cfg(feature = "tensor-backend")]
pub mod tensor_backend;

/// TrustformeRS Transformer GPU Backend.
///
/// Provides transformer model inference infrastructure: paged KV-cache,
/// continuous batching, speculative decoding, attention dispatch,
/// token sampling, and quantized inference.
///
/// Enabled by the `transformer-backend` feature.
#[cfg(feature = "transformer-backend")]
pub mod transformer_backend;

/// WASM + WebGPU compute backend for browser environments.
///
/// Wraps [`oxicuda_webgpu::WebGpuBackend`] with WASM-specific bindings,
/// making the OxiCUDA compute API usable from a browser via WebAssembly.
/// On native targets the module is still available and compiles cleanly;
/// the `#[wasm_bindgen]` exports are only emitted when targeting `wasm32`.
///
/// Enabled by the `wasm-backend` feature.
#[cfg(feature = "wasm-backend")]
pub mod wasm_backend;

// Root-level convenience re-export, gated on the same feature as the module.
#[cfg(feature = "wasm-backend")]
pub use wasm_backend::WasmComputeBackend;
165
// ─── Collective communication (NCCL equivalent) ────────────

/// NCCL-equivalent collective communication primitives for multi-GPU training.
///
/// Provides AllReduce, AllGather, ReduceScatter, Broadcast, Reduce, and
/// AllToAll with ring / tree / recursive-halving algorithm support.
///
/// Always compiled — not behind a feature gate.
pub mod collective;

/// Pipeline parallelism primitives for multi-GPU model parallelism.
///
/// Provides scheduling algorithms (GPipe, 1F1B, Interleaved, ZeroBubble),
/// bubble analysis, activation checkpointing, and ASCII visualization.
///
/// Always compiled — not behind a feature gate.
pub mod pipeline_parallel;

/// Multi-node distributed training support (TCP/IP based).
///
/// Provides [`DistributedRuntime`](distributed::DistributedRuntime),
/// [`TcpStore`](distributed::TcpStore), [`FileStore`](distributed::FileStore),
/// [`GradientBucket`](distributed::GradientBucket), and
/// [`DistributedOptimizer`](distributed::DistributedOptimizer) for
/// coordinating training across multiple machines.
///
/// Always compiled — not behind a feature gate.
pub mod distributed;
188
// ─── Core crates (always available) ─────────────────────────

/// CUDA Driver API wrapper.
pub use oxicuda_driver as driver;

/// GPU memory management.
pub use oxicuda_memory as memory;

/// Kernel launch infrastructure.
pub use oxicuda_launch as launch;

// ─── Optional crates (feature-gated) ────────────────────────

/// PTX code generation DSL. Enabled by the `ptx` feature.
#[cfg(feature = "ptx")]
pub use oxicuda_ptx as ptx;

/// Autotuner engine. Enabled by the `autotune` feature.
#[cfg(feature = "autotune")]
pub use oxicuda_autotune as autotune;

/// GPU-accelerated BLAS operations. Enabled by the `blas` feature.
#[cfg(feature = "blas")]
pub use oxicuda_blas as blas;

/// GPU-accelerated deep learning primitives. Enabled by the `dnn` feature.
#[cfg(feature = "dnn")]
pub use oxicuda_dnn as dnn;

/// GPU-accelerated FFT operations. Enabled by the `fft` feature.
#[cfg(feature = "fft")]
pub use oxicuda_fft as fft;

/// GPU-accelerated sparse matrix operations. Enabled by the `sparse` feature.
#[cfg(feature = "sparse")]
pub use oxicuda_sparse as sparse;

/// GPU-accelerated matrix decompositions. Enabled by the `solver` feature.
#[cfg(feature = "solver")]
pub use oxicuda_solver as solver;

/// GPU-accelerated random number generation. Enabled by the `rand` feature.
#[cfg(feature = "rand")]
pub use oxicuda_rand as rand;

/// CUB-equivalent high-performance parallel GPU primitives.
///
/// Provides PTX code generators for warp, block, and device-wide reduce, scan,
/// histogram, radix sort, and merge sort — all without any CUDA SDK dependency.
/// Enabled by the `primitives` feature.
#[cfg(feature = "primitives")]
pub use oxicuda_primitives as primitives;

/// Vulkan Compute backend for cross-vendor GPU compute. Enabled by the `vulkan` feature.
#[cfg(feature = "vulkan")]
pub use oxicuda_vulkan as vulkan;

/// Apple Metal Compute backend (macOS/iOS). Enabled by the `metal` feature.
#[cfg(feature = "metal")]
pub use oxicuda_metal as metal_backend;

/// WebGPU Compute backend (cross-platform via wgpu). Enabled by the `webgpu` feature.
#[cfg(feature = "webgpu")]
pub use oxicuda_webgpu as webgpu;

/// AMD ROCm/HIP Compute backend (Linux with AMD GPU). Enabled by the `rocm` feature.
#[cfg(feature = "rocm")]
pub use oxicuda_rocm as rocm;

/// Intel Level Zero Compute backend (Linux/Windows with Intel GPU).
/// Enabled by the `level-zero` feature.
#[cfg(feature = "level-zero")]
pub use oxicuda_levelzero as level_zero;

// ─── Key type re-exports ─────────────────────────────────────

// Error types shared across the OxiCUDA sub-crates.
pub use oxicuda_driver::{CudaError, CudaResult, DriverLoadError};

// Core driver-level types: devices, contexts, streams, modules, and JIT options.
pub use oxicuda_driver::{
    Context, Device, Event, Function, JitDiagnostic, JitLog, JitOptions, JitSeverity, Module,
    Stream,
};
pub use oxicuda_driver::{best_device, list_devices, try_driver};

// Memory types: device, pinned-host, and unified buffers, plus copy helpers.
pub use oxicuda_memory::copy;
pub use oxicuda_memory::{DeviceBuffer, DeviceSlice, PinnedBuffer, UnifiedBuffer};

// Launch types: grid/block dimensions, kernel handles, and launch parameters.
pub use oxicuda_launch::{
    Dim3, Kernel, KernelArgs, LaunchParams, LaunchParamsBuilder, grid_size_for,
};

// Re-export the launch! macro for ergonomic kernel invocation.
pub use oxicuda_launch::launch;
284
/// Initialize the CUDA driver API.
///
/// This must be called before any other OxiCUDA function.
/// It dynamically loads `libcuda.so` (Linux), `nvcuda.dll` (Windows),
/// and initializes the CUDA driver.
///
/// # Errors
///
/// Returns `Err(CudaError::NotInitialized)` on macOS or systems
/// without an NVIDIA GPU.
pub fn init() -> CudaResult<()> {
    // Delegate to the driver crate, which performs the dynamic library
    // load and the driver-level initialization call.
    oxicuda_driver::init()
}
296
/// Compile-time feature availability.
///
/// Each constant mirrors a Cargo feature flag via `cfg!`, so downstream
/// code can branch on capability without its own `#[cfg]` plumbing.
pub mod features {
    /// Whether PTX code generation is available.
    pub const HAS_PTX: bool = cfg!(feature = "ptx");
    /// Whether the autotuner is available.
    pub const HAS_AUTOTUNE: bool = cfg!(feature = "autotune");
    /// Whether BLAS operations are available.
    pub const HAS_BLAS: bool = cfg!(feature = "blas");
    /// Whether DNN operations are available.
    pub const HAS_DNN: bool = cfg!(feature = "dnn");
    /// Whether FFT operations are available.
    pub const HAS_FFT: bool = cfg!(feature = "fft");
    /// Whether sparse matrix operations are available.
    pub const HAS_SPARSE: bool = cfg!(feature = "sparse");
    /// Whether solver operations are available.
    pub const HAS_SOLVER: bool = cfg!(feature = "solver");
    /// Whether random number generation is available.
    pub const HAS_RAND: bool = cfg!(feature = "rand");
    /// Whether the CUB-equivalent parallel GPU primitives are available.
    ///
    /// Added for consistency: the `primitives` feature gates a crate
    /// re-export and a prelude export elsewhere in this crate, but
    /// previously had no `HAS_*` flag here.
    pub const HAS_PRIMITIVES: bool = cfg!(feature = "primitives");
    /// Whether the abstract compute backend is available.
    pub const HAS_BACKEND: bool = cfg!(feature = "backend");
    /// Whether the ONNX inference backend is available.
    pub const HAS_ONNX_BACKEND: bool = cfg!(feature = "onnx-backend");
    /// Whether the ToRSh tensor backend is available.
    pub const HAS_TENSOR_BACKEND: bool = cfg!(feature = "tensor-backend");
    /// Whether the TrustformeRS transformer backend is available.
    pub const HAS_TRANSFORMER_BACKEND: bool = cfg!(feature = "transformer-backend");
    /// Whether stream-ordered memory pool is available.
    pub const HAS_POOL: bool = cfg!(feature = "pool");
    /// Whether GPU tests are enabled.
    pub const HAS_GPU_TESTS: bool = cfg!(feature = "gpu-tests");
    /// Whether global initialization is available (always `true`).
    pub const HAS_GLOBAL_INIT: bool = true;
    /// Whether the Vulkan Compute backend is available.
    pub const HAS_VULKAN: bool = cfg!(feature = "vulkan");
    /// Whether the Apple Metal Compute backend is available.
    pub const HAS_METAL: bool = cfg!(feature = "metal");
    /// Whether the WebGPU Compute backend is available.
    pub const HAS_WEBGPU: bool = cfg!(feature = "webgpu");
    /// Whether the AMD ROCm/HIP Compute backend is available.
    pub const HAS_ROCM: bool = cfg!(feature = "rocm");
    /// Whether the Intel Level Zero Compute backend is available.
    pub const HAS_LEVEL_ZERO: bool = cfg!(feature = "level-zero");
    /// Whether the WASM + WebGPU browser backend is available.
    pub const HAS_WASM_BACKEND: bool = cfg!(feature = "wasm-backend");
}
342
// ---------------------------------------------------------------------------
// ComputeBackend auto-selection threshold
// ---------------------------------------------------------------------------

/// Auto-selection threshold for the compute backend, in bytes.
///
/// Workloads whose data exceeds this size are dispatched to the GPU
/// backend; anything at or below it runs on the CPU backend, since GPU
/// launch overhead dominates for small matrices. The 64 KiB default is
/// tuned for SciRS2 workloads.
pub const AUTO_SELECT_THRESHOLD_BYTES: usize = 1 << 16; // 64 KiB = 65536 bytes
354
// ---------------------------------------------------------------------------
// ONNX supported operator list
// ---------------------------------------------------------------------------

/// List of ONNX operators supported by the OxiCUDA ONNX backend.
///
/// This is the canonical list used both for operator dispatch and for
/// validating ONNX model compatibility.
///
/// Entries use the ONNX specification spelling (e.g. `"MatMul"`,
/// `"BatchNormalization"`); do not change the casing of these strings.
pub const SUPPORTED_ONNX_OPS: &[&str] = &[
    // Linear algebra / convolution
    "MatMul",
    "Conv",
    // Activations & normalization
    "Relu",
    "BatchNormalization",
    "Softmax",
    "LayerNormalization",
    // Elementwise arithmetic
    "Add",
    "Mul",
    // Shape manipulation
    "Transpose",
    "Reshape",
    "Concat",
];
376
377// ---------------------------------------------------------------------------
378// Tests
379// ---------------------------------------------------------------------------
380
#[cfg(test)]
mod umbrella_tests {
    use super::*;

    // -----------------------------------------------------------------------
    // ComputeBackend auto-selection threshold tests
    // -----------------------------------------------------------------------

    #[test]
    fn compute_backend_threshold_is_64kb() {
        assert_eq!(
            AUTO_SELECT_THRESHOLD_BYTES,
            64 * 1024,
            "auto-select threshold must be exactly 64 KiB = 65536 bytes"
        );
    }

    #[test]
    fn small_tensor_uses_cpu_backend() {
        // A tensor with < 64 KB data is below the threshold → CPU backend selected.
        let small_data_bytes: usize = 1024; // 1 KB
        assert!(
            small_data_bytes < AUTO_SELECT_THRESHOLD_BYTES,
            "1 KB should be below threshold → CPU backend"
        );
    }

    #[test]
    fn large_tensor_uses_gpu_backend() {
        // A tensor with > 64 KB data is above the threshold → GPU backend attempted.
        let large_data_bytes: usize = 1024 * 1024; // 1 MB
        assert!(
            large_data_bytes > AUTO_SELECT_THRESHOLD_BYTES,
            "1 MB should be above threshold → GPU backend"
        );
    }

    #[test]
    fn threshold_boundary_values() {
        // Fixed: the previous version asserted `T <= T` and `T + 1 > T`,
        // which are tautologies that hold for ANY threshold value and can
        // never fail. Compare against an independent literal so the test
        // actually pins the boundary behavior at 64 KiB.
        //
        // Exactly at the threshold: not strictly greater → CPU backend.
        let at_threshold: usize = 65536;
        assert!(
            !(at_threshold > AUTO_SELECT_THRESHOLD_BYTES),
            "a workload of exactly 64 KiB must stay on the CPU backend"
        );
        // One byte above the threshold → GPU backend.
        assert!(
            at_threshold + 1 > AUTO_SELECT_THRESHOLD_BYTES,
            "64 KiB + 1 byte must be dispatched to the GPU backend"
        );
    }

    // -----------------------------------------------------------------------
    // ONNX operator interface tests
    // -----------------------------------------------------------------------

    #[test]
    fn onnx_matmul_op_name_correct() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"MatMul"),
            "SUPPORTED_ONNX_OPS must contain 'MatMul'"
        );
    }

    #[test]
    fn onnx_conv_op_name_correct() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Conv"),
            "SUPPORTED_ONNX_OPS must contain 'Conv'"
        );
    }

    #[test]
    fn onnx_op_list_includes_relu() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Relu"),
            "SUPPORTED_ONNX_OPS must contain 'Relu'"
        );
    }

    #[test]
    fn onnx_op_list_includes_softmax() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Softmax"),
            "SUPPORTED_ONNX_OPS must contain 'Softmax'"
        );
    }

    #[test]
    fn onnx_op_list_includes_layer_norm() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"LayerNormalization"),
            "SUPPORTED_ONNX_OPS must contain 'LayerNormalization'"
        );
    }

    #[test]
    fn onnx_op_list_includes_batch_norm() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"BatchNormalization"),
            "SUPPORTED_ONNX_OPS must contain 'BatchNormalization'"
        );
    }

    // -----------------------------------------------------------------------
    // ToRSh SDPA + TrustformeRS MoE config tests
    // (gated on transformer-backend feature)
    // -----------------------------------------------------------------------

    #[cfg(feature = "transformer-backend")]
    mod transformer_tests {
        use crate::transformer_backend::attention::ComputeTier;
        use crate::transformer_backend::attention::{AttentionConfig, AttentionKind, HeadConfig};

        #[test]
        fn torsh_sdpa_attention_config_exists() {
            // Verify AttentionConfig exists and supports FlashAttention dispatch.
            let cfg = AttentionConfig {
                head_config: HeadConfig::Mha { num_heads: 32 },
                head_dim: 128,
                use_paged_cache: false,
                compute_tier: ComputeTier::Hopper,
                sliding_window: None,
                causal: true,
                scale: None,
                max_seq_len_hint: Some(4096),
            };
            // For Hopper + long sequences, kernel should be FlashHopper.
            use crate::transformer_backend::attention::AttentionDispatch;
            let dispatch = AttentionDispatch::new(cfg);
            assert!(dispatch.is_ok(), "AttentionDispatch::new should succeed");
            let mut dispatch = dispatch.expect("AttentionDispatch creation failed");
            // With long sequences on Hopper, Flash or FlashHopper should be selected.
            let kernel = dispatch.select_kernel(4096);
            assert!(
                matches!(kernel, AttentionKind::Flash | AttentionKind::FlashHopper),
                "Hopper with 4096 tokens should use Flash attention, got {kernel:?}"
            );
        }
    }

    /// MoE configuration verification (CPU-side, no GPU required).
    #[test]
    fn trustformers_moe_config_exists() {
        // Verify the MoE routing math: tokens_per_expert ≈ batch * seq * top_k / num_experts.
        let num_experts: usize = 8;
        let top_k: usize = 2;
        let batch_size: usize = 4;
        let seq_len: usize = 512;

        // Expected tokens per expert on average (Mixtral 8x7B pattern).
        let total_tokens = batch_size * seq_len;
        let routed_tokens = total_tokens * top_k;
        // Each of the 8 experts should receive routed_tokens / num_experts.
        let tokens_per_expert = routed_tokens / num_experts;

        // For batch=4, seq=512, top_k=2, num_experts=8:
        //   total = 2048, routed = 4096, per_expert = 512.
        assert_eq!(total_tokens, 2048);
        assert_eq!(routed_tokens, 4096);
        assert_eq!(tokens_per_expert, 512);
    }

    #[test]
    fn moe_mixtral_config_8x7b() {
        // Verify the standard Mixtral 8x7B MoE routing configuration.
        let num_experts: usize = 8;
        let top_k: usize = 2;

        // Activation rate: each token activates top_k / num_experts fraction of experts.
        let activation_rate = top_k as f64 / num_experts as f64;
        assert!(
            (activation_rate - 0.25).abs() < 1e-10,
            "Mixtral 8x7B: activation rate = {activation_rate}, expected 0.25"
        );

        // Load balance: with uniform routing, each expert receives the same
        // expected number of tokens.
        let batch_size: usize = 16;
        let seq_len: usize = 1024;
        let expected_per_expert = batch_size * seq_len * top_k / num_experts;
        // 16 * 1024 * 2 / 8 = 4096 tokens per expert.
        assert_eq!(expected_per_expert, 4096);
    }
}
559
560/// Convenience re-exports for common usage patterns.
561///
562/// ```no_run
563/// use oxicuda::prelude::*;
564/// ```
565pub mod prelude {
566 // Error handling
567 pub use crate::{CudaError, CudaResult};
568
569 // Initialization
570 pub use crate::{init, try_driver};
571
572 // Core GPU types
573 pub use crate::{Context, Device, Event, Function, Module, Stream};
574 pub use crate::{best_device, list_devices};
575
576 // Memory management
577 pub use crate::{DeviceBuffer, PinnedBuffer, UnifiedBuffer};
578
579 // Kernel launch
580 pub use crate::{Dim3, Kernel, KernelArgs, LaunchParams, grid_size_for};
581
582 // Global initialization
583 pub use crate::global_init::{
584 default_context, default_device, default_stream, is_initialized, lazy_init,
585 };
586
587 // Parallel primitives (feature = "primitives")
588 #[cfg(feature = "primitives")]
589 pub use oxicuda_primitives::{PrimitivesError, PrimitivesHandle, PrimitivesResult, ReduceOp};
590}