//! Umbrella crate for the OxiCuda stack: re-exports the driver, memory, and
//! launch layers under short aliases, and exposes optional compute backends
//! (ONNX, tensor, transformer, WASM, …) behind Cargo feature flags.

#![warn(missing_docs)]
#![warn(clippy::all)]
// The umbrella crate intentionally repeats sub-crate names in module paths.
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::wildcard_imports)]

/// Global one-time runtime initialization; source of the re-exported
/// [`OxiCudaRuntime`], [`OxiCudaRuntimeBuilder`], and [`DeviceSelection`].
pub mod global_init;
pub use global_init::{DeviceSelection, OxiCudaRuntime, OxiCudaRuntimeBuilder};
/// Profiling support.
pub mod profiling;
/// Device pooling support.
pub mod device_pool;
/// Generic compute backend (feature `backend`).
#[cfg(feature = "backend")]
pub mod backend;
/// ONNX model backend (feature `onnx-backend`).
#[cfg(feature = "onnx-backend")]
pub mod onnx_backend;
/// Tensor backend (feature `tensor-backend`).
#[cfg(feature = "tensor-backend")]
pub mod tensor_backend;
/// Transformer backend (feature `transformer-backend`).
#[cfg(feature = "transformer-backend")]
pub mod transformer_backend;
/// WebAssembly backend (feature `wasm-backend`).
#[cfg(feature = "wasm-backend")]
pub mod wasm_backend;
#[cfg(feature = "wasm-backend")]
pub use wasm_backend::WasmComputeBackend;
/// Collective communication support.
pub mod collective;
/// Pipeline-parallel execution support.
pub mod pipeline_parallel;
/// Distributed execution support.
pub mod distributed;

// Sub-crate re-exports under short aliases; optional layers are feature-gated.
pub use oxicuda_driver as driver;
pub use oxicuda_memory as memory;
pub use oxicuda_launch as launch;
#[cfg(feature = "ptx")]
pub use oxicuda_ptx as ptx;
#[cfg(feature = "autotune")]
pub use oxicuda_autotune as autotune;
#[cfg(feature = "blas")]
pub use oxicuda_blas as blas;
#[cfg(feature = "dnn")]
pub use oxicuda_dnn as dnn;
#[cfg(feature = "fft")]
pub use oxicuda_fft as fft;
#[cfg(feature = "sparse")]
pub use oxicuda_sparse as sparse;
#[cfg(feature = "solver")]
pub use oxicuda_solver as solver;
#[cfg(feature = "rand")]
pub use oxicuda_rand as rand;
#[cfg(feature = "primitives")]
pub use oxicuda_primitives as primitives;

// Alternative GPU/accelerator backends, also feature-gated.
#[cfg(feature = "vulkan")]
pub use oxicuda_vulkan as vulkan;
#[cfg(feature = "metal")]
pub use oxicuda_metal as metal_backend;
#[cfg(feature = "webgpu")]
pub use oxicuda_webgpu as webgpu;
#[cfg(feature = "rocm")]
pub use oxicuda_rocm as rocm;
#[cfg(feature = "level-zero")]
pub use oxicuda_levelzero as level_zero;

// Flat re-exports of the most common driver/memory/launch types so callers
// don't need to know which sub-crate defines each one.
pub use oxicuda_driver::{CudaError, CudaResult, DriverLoadError};
pub use oxicuda_driver::{
    Context, Device, Event, Function, JitDiagnostic, JitLog, JitOptions, JitSeverity, Module,
    Stream,
};
pub use oxicuda_driver::{best_device, list_devices, try_driver};
pub use oxicuda_memory::copy;
pub use oxicuda_memory::{DeviceBuffer, DeviceSlice, PinnedBuffer, UnifiedBuffer};
pub use oxicuda_launch::{
    Dim3, Kernel, KernelArgs, LaunchParams, LaunchParamsBuilder, grid_size_for,
};
pub use oxicuda_launch::launch;
/// Initializes the underlying CUDA driver API.
///
/// Thin convenience wrapper around [`oxicuda_driver::init`]; presumably this
/// should be called before other driver operations — confirm against
/// `oxicuda_driver`'s own docs.
///
/// # Errors
///
/// Propagates any [`CudaError`] returned by the driver initialization.
pub fn init() -> CudaResult<()> {
    oxicuda_driver::init()
}
/// Compile-time feature probes: each constant captures `cfg!(feature = "…")`
/// at build time so downstream code can branch at runtime on which optional
/// capabilities were compiled into this crate.
pub mod features {
    /// `true` when the `ptx` feature is enabled.
    pub const HAS_PTX: bool = cfg!(feature = "ptx");
    /// `true` when the `autotune` feature is enabled.
    pub const HAS_AUTOTUNE: bool = cfg!(feature = "autotune");
    /// `true` when the `blas` feature is enabled.
    pub const HAS_BLAS: bool = cfg!(feature = "blas");
    /// `true` when the `dnn` feature is enabled.
    pub const HAS_DNN: bool = cfg!(feature = "dnn");
    /// `true` when the `fft` feature is enabled.
    pub const HAS_FFT: bool = cfg!(feature = "fft");
    /// `true` when the `sparse` feature is enabled.
    pub const HAS_SPARSE: bool = cfg!(feature = "sparse");
    /// `true` when the `solver` feature is enabled.
    pub const HAS_SOLVER: bool = cfg!(feature = "solver");
    /// `true` when the `rand` feature is enabled.
    pub const HAS_RAND: bool = cfg!(feature = "rand");
    /// `true` when the `backend` feature is enabled.
    pub const HAS_BACKEND: bool = cfg!(feature = "backend");
    /// `true` when the `onnx-backend` feature is enabled.
    pub const HAS_ONNX_BACKEND: bool = cfg!(feature = "onnx-backend");
    /// `true` when the `tensor-backend` feature is enabled.
    pub const HAS_TENSOR_BACKEND: bool = cfg!(feature = "tensor-backend");
    /// `true` when the `transformer-backend` feature is enabled.
    pub const HAS_TRANSFORMER_BACKEND: bool = cfg!(feature = "transformer-backend");
    /// `true` when the `pool` feature is enabled.
    pub const HAS_POOL: bool = cfg!(feature = "pool");
    /// `true` when the `gpu-tests` feature is enabled.
    pub const HAS_GPU_TESTS: bool = cfg!(feature = "gpu-tests");
    /// Always `true`: `global_init` is compiled unconditionally (no feature gate).
    pub const HAS_GLOBAL_INIT: bool = true;
    /// `true` when the `vulkan` feature is enabled.
    pub const HAS_VULKAN: bool = cfg!(feature = "vulkan");
    /// `true` when the `metal` feature is enabled.
    pub const HAS_METAL: bool = cfg!(feature = "metal");
    /// `true` when the `webgpu` feature is enabled.
    pub const HAS_WEBGPU: bool = cfg!(feature = "webgpu");
    /// `true` when the `rocm` feature is enabled.
    pub const HAS_ROCM: bool = cfg!(feature = "rocm");
    /// `true` when the `level-zero` feature is enabled.
    pub const HAS_LEVEL_ZERO: bool = cfg!(feature = "level-zero");
    /// `true` when the `wasm-backend` feature is enabled.
    pub const HAS_WASM_BACKEND: bool = cfg!(feature = "wasm-backend");
}
/// Byte-size threshold (64 KiB = 65 536) for automatic backend selection.
/// Per this crate's tests, sizes strictly below it are CPU-side and strictly
/// above it are GPU-side; the exact policy at the boundary itself is decided
/// by the consuming backend — TODO confirm.
pub const AUTO_SELECT_THRESHOLD_BYTES: usize = 64 * 1024;
/// ONNX operator names the ONNX backend claims support for.
/// NOTE(review): presumably must stay in sync with the `onnx_backend`
/// implementation — not verifiable from this file alone.
pub const SUPPORTED_ONNX_OPS: &[&str] = &[
    "MatMul",
    "Conv",
    "Relu",
    "BatchNormalization",
    "Softmax",
    "LayerNormalization",
    "Add",
    "Mul",
    "Transpose",
    "Reshape",
    "Concat",
];
#[cfg(test)]
mod umbrella_tests {
    //! Smoke tests for the umbrella crate's public constants plus
    //! feature-gated backend integration checks (no GPU required).
    use super::*;

    /// The auto-select threshold is a fixed contract: exactly 64 KiB.
    #[test]
    fn compute_backend_threshold_is_64kb() {
        assert_eq!(
            AUTO_SELECT_THRESHOLD_BYTES,
            64 * 1024,
            "auto-select threshold must be exactly 64 KiB = 65536 bytes"
        );
    }

    #[test]
    fn small_tensor_uses_cpu_backend() {
        let small_data_bytes: usize = 1024;
        assert!(
            small_data_bytes < AUTO_SELECT_THRESHOLD_BYTES,
            "1 KB should be below threshold → CPU backend"
        );
    }

    #[test]
    fn large_tensor_uses_gpu_backend() {
        let large_data_bytes: usize = 1024 * 1024;
        assert!(
            large_data_bytes > AUTO_SELECT_THRESHOLD_BYTES,
            "1 MB should be above threshold → GPU backend"
        );
    }

    /// The previous version of this test only asserted tautologies
    /// (`x <= x` and `x + 1 > x`) inside `const` blocks — it could never
    /// fail and therefore verified nothing. Pin real boundary properties
    /// of the threshold instead.
    #[test]
    fn threshold_boundary_values() {
        // Exact numeric value, spelled out.
        assert_eq!(AUTO_SELECT_THRESHOLD_BYTES, 65536);
        // Keep the threshold a power of two (page/alignment friendly).
        assert!(
            AUTO_SELECT_THRESHOLD_BYTES.is_power_of_two(),
            "threshold should remain a power of two"
        );
        // One byte under the threshold stays on the CPU side of the split.
        let just_under = AUTO_SELECT_THRESHOLD_BYTES - 1;
        assert!(just_under < AUTO_SELECT_THRESHOLD_BYTES);
    }

    #[test]
    fn onnx_matmul_op_name_correct() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"MatMul"),
            "SUPPORTED_ONNX_OPS must contain 'MatMul'"
        );
    }

    #[test]
    fn onnx_conv_op_name_correct() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Conv"),
            "SUPPORTED_ONNX_OPS must contain 'Conv'"
        );
    }

    #[test]
    fn onnx_op_list_includes_relu() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Relu"),
            "SUPPORTED_ONNX_OPS must contain 'Relu'"
        );
    }

    #[test]
    fn onnx_op_list_includes_softmax() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"Softmax"),
            "SUPPORTED_ONNX_OPS must contain 'Softmax'"
        );
    }

    #[test]
    fn onnx_op_list_includes_layer_norm() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"LayerNormalization"),
            "SUPPORTED_ONNX_OPS must contain 'LayerNormalization'"
        );
    }

    #[test]
    fn onnx_op_list_includes_batch_norm() {
        assert!(
            SUPPORTED_ONNX_OPS.contains(&"BatchNormalization"),
            "SUPPORTED_ONNX_OPS must contain 'BatchNormalization'"
        );
    }

    #[cfg(feature = "transformer-backend")]
    mod transformer_tests {
        // Previously split across two redundant `use` statements; merged.
        use crate::transformer_backend::attention::{
            AttentionConfig, AttentionDispatch, AttentionKind, ComputeTier, HeadConfig,
        };

        /// Hopper-tier config with a long sequence should dispatch to a
        /// Flash-attention kernel.
        #[test]
        fn torsh_sdpa_attention_config_exists() {
            let cfg = AttentionConfig {
                head_config: HeadConfig::Mha { num_heads: 32 },
                head_dim: 128,
                use_paged_cache: false,
                compute_tier: ComputeTier::Hopper,
                sliding_window: None,
                causal: true,
                scale: None,
                max_seq_len_hint: Some(4096),
            };
            let dispatch = AttentionDispatch::new(cfg);
            assert!(dispatch.is_ok(), "AttentionDispatch::new should succeed");
            let mut dispatch = dispatch.expect("AttentionDispatch creation failed");
            let kernel = dispatch.select_kernel(4096);
            assert!(
                matches!(kernel, AttentionKind::Flash | AttentionKind::FlashHopper),
                "Hopper with 4096 tokens should use Flash attention, got {kernel:?}"
            );
        }
    }

    /// Pure-arithmetic sanity check of MoE token-routing math.
    #[test]
    fn trustformers_moe_config_exists() {
        let num_experts: usize = 8;
        let top_k: usize = 2;
        let batch_size: usize = 4;
        let seq_len: usize = 512;
        let total_tokens = batch_size * seq_len;
        let routed_tokens = total_tokens * top_k;
        let tokens_per_expert = routed_tokens / num_experts;
        assert_eq!(total_tokens, 2048);
        assert_eq!(routed_tokens, 4096);
        assert_eq!(tokens_per_expert, 512);
    }

    /// Mixtral-8x7B-shaped routing: top-2 of 8 experts → 25% activation.
    #[test]
    fn moe_mixtral_config_8x7b() {
        let num_experts: usize = 8;
        let top_k: usize = 2;
        let activation_rate = top_k as f64 / num_experts as f64;
        assert!(
            (activation_rate - 0.25).abs() < 1e-10,
            "Mixtral 8x7B: activation rate = {activation_rate}, expected 0.25"
        );
        let batch_size: usize = 16;
        let seq_len: usize = 1024;
        let expected_per_expert = batch_size * seq_len * top_k / num_experts;
        assert_eq!(expected_per_expert, 4096);
    }
}
/// Convenience prelude: `use` this module with a glob to bring the most
/// common driver, memory, and launch items into scope in one line.
pub mod prelude {
    // Error handling.
    pub use crate::{CudaError, CudaResult};
    // Initialization entry points.
    pub use crate::{init, try_driver};
    // Core driver handles.
    pub use crate::{Context, Device, Event, Function, Module, Stream};
    // Device discovery.
    pub use crate::{best_device, list_devices};
    // Memory buffers.
    pub use crate::{DeviceBuffer, PinnedBuffer, UnifiedBuffer};
    // Kernel-launch types and helpers.
    pub use crate::{Dim3, Kernel, KernelArgs, LaunchParams, grid_size_for};
    // Lazy global-init helpers.
    pub use crate::global_init::{
        default_context, default_device, default_stream, is_initialized, lazy_init,
    };
    /// Primitives types, re-exported when the `primitives` feature is enabled.
    #[cfg(feature = "primitives")]
    pub use oxicuda_primitives::{PrimitivesError, PrimitivesHandle, PrimitivesResult, ReduceOp};
}