
bitnet_quantize/kernels/mod.rs

//! GPU kernels for BitNet quantization operations.
//!
//! This module provides CubeCL-based GPU kernels for efficient
//! ternary-weight × activation matrix multiplication.
//!
//! ## Kernels
//!
//! - `absmean_quantize` - Quantize float weights to ternary {-1, 0, +1} (sketched below)
//! - `ternary_dequantize` - Convert ternary weights back to float
//! - `ternary_matmul_gpu` - Ternary matmul with no multiply ops (see the section below)
//! - `packed_ternary_matmul` - Matmul over 2-bit packed weights for reduced memory bandwidth
//! - `bitlinear_forward` - Fused LayerNorm + ternary matmul (see the fused sketch below)
//!
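//! As a rough reference for what `absmean_quantize` computes (a minimal CPU
//! sketch of absmean quantization, not the GPU kernel itself; the helper
//! name and signature are illustrative assumptions): scale by the mean
//! absolute value, then round each weight to the nearest ternary value.
//!
//! ```
//! /// Illustrative CPU reference: scale = mean |w|, quantize to {-1, 0, +1}.
//! fn absmean_quantize_ref(weights: &[f32]) -> (Vec<i8>, f32) {
//!     let scale = (weights.iter().map(|w| w.abs()).sum::<f32>()
//!         / weights.len() as f32)
//!         .max(f32::EPSILON); // guard against all-zero input
//!     let quantized = weights
//!         .iter()
//!         .map(|w| (w / scale).round().clamp(-1.0, 1.0) as i8)
//!         .collect();
//!     (quantized, scale)
//! }
//! ```
//!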
//! ## Feature Gate
//!
//! These kernels require the `cuda` feature to be enabled:
//!
//! ```toml
//! [dependencies]
//! bitnet-quantize = { version = "0.1", features = ["cuda"] }
//! ```
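//!
//! ## Why Ternary Matmul Needs No Multiplies
//!
//! With weights restricted to {-1, 0, +1}, every product `w * x` is just
//! `x`, `-x`, or `0`, so the inner loop reduces to additions and
//! subtractions. A minimal CPU sketch of the idea (the function name,
//! signature, and row-major layout are illustrative assumptions, not the
//! kernel's real interface):
//!
//! ```
//! /// Illustrative: y = W·x for a row-major `m` x `k` ternary weight matrix.
//! fn ternary_matvec_ref(w: &[i8], x: &[f32], m: usize, k: usize) -> Vec<f32> {
//!     (0..m)
//!         .map(|row| {
//!             w[row * k..(row + 1) * k]
//!                 .iter()
//!                 .zip(x)
//!                 .fold(0.0, |acc, (&wi, &xi)| match wi {
//!                     1 => acc + xi,  // +1: add the activation
//!                     -1 => acc - xi, // -1: subtract it
//!                     _ => acc,       //  0: skip entirely
//!                 })
//!         })
//!         .collect()
//! }
//! ```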
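//!
//! ## Fused BitLinear Forward
//!
//! `bitlinear_forward` fuses LayerNorm with the ternary matmul. A per-row
//! CPU sketch of the combined computation (illustrative only; the name
//! `bitlinear_row_ref` and the omission of LayerNorm's affine parameters
//! are assumptions, not the kernel's real interface):
//!
//! ```
//! /// Illustrative: LayerNorm the activations, then take a ternary dot product.
//! fn bitlinear_row_ref(w_row: &[i8], x: &[f32], eps: f32) -> f32 {
//!     let n = x.len() as f32;
//!     let mean = x.iter().sum::<f32>() / n;
//!     let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
//!     let inv_std = 1.0 / (var + eps).sqrt();
//!     w_row.iter().zip(x).fold(0.0, |acc, (&wi, &xi)| {
//!         let norm = (xi - mean) * inv_std;
//!         match wi {
//!             1 => acc + norm,
//!             -1 => acc - norm,
//!             _ => acc,
//!         }
//!     })
//! }
//! ```
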
#[cfg(feature = "cuda")]
mod cubecl;

#[cfg(feature = "cuda")]
pub use cubecl::{
    // Core operations
    absmean_quantize,
    ternary_dequantize,
    ternary_matmul_gpu,
    ternary_matmul_raw,
    // Packed operations
    pack_ternary_weights,
    packed_ternary_matmul,
    unpack_ternary_weights,
    // Fused operations
    bitlinear_forward,
    // Utilities
    has_cuda_support,
    should_use_gpu,
};
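
// The packed kernels store four ternary weights per byte (2 bits each).
// Below is a hypothetical CPU sketch of that packing idea; the encoding the
// actual GPU kernels use may differ, and `pack_ternary_ref` is not part of
// the public API.
#[allow(dead_code)] // illustration only
fn pack_ternary_ref(weights: &[i8]) -> Vec<u8> {
    weights
        .chunks(4)
        .map(|chunk| {
            chunk.iter().enumerate().fold(0u8, |byte, (i, &w)| {
                // Map 0 -> 0b00, +1 -> 0b01, -1 -> 0b10 (0b11 unused).
                let code = match w {
                    1 => 0b01,
                    -1 => 0b10,
                    _ => 0b00,
                };
                byte | (code << (2 * i))
            })
        })
        .collect()
}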

/// Check whether the CUDA kernels are available at runtime.
///
/// Always returns `false` when the crate is built without the `cuda`
/// feature; otherwise reports whether a usable CUDA device was found.
#[must_use]
pub fn cuda_available() -> bool {
    #[cfg(feature = "cuda")]
    {
        cubecl::has_cuda_support()
    }

    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
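
// A small sketch of how callers might gate dispatch on `cuda_available()`,
// checked as a unit test. The invariant tested follows directly from the
// function body above: without the `cuda` feature it is always `false`.
#[cfg(test)]
mod tests {
    use super::cuda_available;

    #[test]
    fn cuda_available_respects_feature_gate() {
        #[cfg(not(feature = "cuda"))]
        assert!(!cuda_available());

        // With the feature enabled the result depends on the machine, so we
        // only assert that the call itself is safe to make.
        #[cfg(feature = "cuda")]
        let _ = cuda_available();
    }
}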