oxicuda-dnn
GPU-accelerated deep learning primitives -- a pure Rust cuDNN equivalent.
Part of the OxiCUDA project.
Overview
oxicuda-dnn delivers the core building blocks for training and inference of
deep neural networks on NVIDIA GPUs, implemented entirely in Rust with no
C/Fortran dependencies. It covers convolution (five algorithm families with
forward and backward passes), multi-head attention with FlashAttention-2, MoE
routing, normalization layers, pooling, resize, and quantization.
The crate builds on oxicuda-blas for GEMM-based algorithms (im2col
convolution, MoE grouped GEMM) and on oxicuda-ptx for runtime PTX kernel
generation. DnnHandle manages a CUDA stream, a BLAS sub-handle, and a PTX
cache so that compiled kernels are reused across calls.
Algorithm selection is automatic: the convolution dispatcher benchmarks implicit-GEMM, im2col+GEMM, Winograd, direct, and FFT-based strategies and picks the fastest for each problem shape. Fused kernels (conv+BN+ReLU, LayerNorm+activation, fused_add_rms_norm) are provided to minimize global memory traffic.
Modules
| Module | Description |
|---|---|
handle |
DnnHandle -- central entry point binding context, stream, BLAS, PTX cache |
types |
TensorDesc, TensorLayout (NCHW/NHWC), Activation, ConvolutionDescriptor |
conv |
Convolution forward (5 algos), backward dgrad/wgrad, fused conv+BN+ReLU |
attn |
FlashAttention-2 (fwd+bwd), PagedAttention, decode attention, RoPE, KV-cache |
moe |
MoE routing (top-k softmax), token permutation, fused MoE kernel |
norm |
LayerNorm, RMSNorm, BatchNorm, GroupNorm, fused variants |
pool |
MaxPool2D, AvgPool2D, AdaptivePool, GlobalPool |
resize |
Nearest, bilinear, bicubic interpolation |
quantize |
FP8 (E4M3/E5M2), INT8 (symmetric/asymmetric), block-scaled FP4 |
error |
DnnError / DnnResult |
Quick Start
use Arc;
use Context;
use *;
Supported Operations
Convolution
| Algorithm | Forward | dgrad | wgrad |
|---|---|---|---|
| Implicit GEMM | yes | yes | yes |
| im2col + GEMM | yes | -- | -- |
| Winograd F(2,3) / F(4,3) | yes | -- | -- |
| Direct (1x1, depthwise) | yes | -- | -- |
| FFT-based | yes | -- | -- |
Fused: conv + BatchNorm + ReLU in a single kernel launch.
Attention
- FlashAttention-2 forward and backward with online softmax
- PagedAttention for LLM KV-cache serving
- Decode attention for autoregressive generation
- Rotary Position Embedding (RoPE)
Mixture of Experts (MoE)
- Top-k softmax routing with load balancing
- Token permutation / unpermutation
- Fused MoE kernel (TokenParallel / ExpertParallel strategies)
- Grouped GEMM backend
Normalization
| Layer | Training | Inference | Fused Variants |
|---|---|---|---|
| LayerNorm | yes | yes | LN + ReLU |
| RMSNorm | yes | yes | fused_add_rms_norm, RMSNorm + SiLU |
| BatchNorm | yes | yes | conv + BN + ReLU |
| GroupNorm | yes | yes | -- |
Pooling and Resize
- MaxPool2D, AvgPool2D, AdaptivePool, GlobalPool
- Nearest, bilinear, bicubic interpolation resize
Quantization
- FP8: E4M3 and E5M2 quantize / dequantize
- INT8: symmetric and asymmetric per-tensor / per-channel
- Block-scaled FP4 for weight compression
Feature Flags
| Feature | Description |
|---|---|
f16 |
Enable FP16 / BF16 tensor support (enables oxicuda-blas/f16) |
Tensor Layouts
NCHW (PyTorch default), NHWC (Tensor Core optimal), NCDHW, NDHWC (3-D), and generic RowMajor for 2-D intermediates. Channels-last layouts (NHWC/NDHWC) are recommended for best Tensor Core utilization.
Performance Targets
Convolution forward aims for 90% of cuDNN throughput on common CNN shapes (ResNet-50, EfficientNet). FlashAttention-2 targets parity with the reference Tri Dao kernel at sequence lengths 512--8192.
Status
| Item | Value |
|---|---|
| Version | 0.1.5 |
| Release date | 2026-05-01 |
| Tests | 960 passing |
| Warnings | 0 (clippy clean) |
unwrap() |
0 (production code) |
License
Apache-2.0 -- (C) 2026 COOLJAPAN OU (Team KitaSan)