// bitnet_quantize/kernels/mod.rs
//! GPU kernels for BitNet quantization operations.
//!
//! This module provides CubeCL-based GPU kernels for efficient
//! ternary weight x activation matrix multiplication.
//!
//! ## Kernels
//!
//! - `absmean_quantize` - Quantize weights to ternary {-1, 0, +1}
//! - `ternary_dequantize` - Convert ternary back to float
//! - `ternary_matmul_gpu` - Optimized ternary matmul (no multiply ops!)
//! - `packed_ternary_matmul` - 2-bit packed weights for reduced bandwidth
//! - `bitlinear_forward` - Fused LayerNorm + ternary matmul
//!
//! ## Feature Gate
//!
//! Requires the `cuda` feature to be enabled:
//!
//! ```toml
//! [dependencies]
//! bitnet-quantize = { version = "0.1", features = ["cuda"] }
//! ```
22
// The actual CubeCL kernel implementations; the whole submodule is
// compiled only when the `cuda` feature is enabled.
#[cfg(feature = "cuda")]
mod cubecl;

// Flatten the kernel API into this module so callers can write
// `kernels::absmean_quantize` etc. without naming the submodule.
//
// NOTE: rustfmt keeps this list alphabetized, so the group labels below
// annotate the item that follows them rather than a contiguous group
// (e.g. `should_use_gpu` is a utility despite appearing after the
// packed operations).
#[cfg(feature = "cuda")]
pub use cubecl::{
    // Core operations
    absmean_quantize,
    // Fused operations
    bitlinear_forward,
    // Utilities
    has_cuda_support,
    // Packed operations
    pack_ternary_weights,
    packed_ternary_matmul,
    should_use_gpu,
    ternary_dequantize,
    ternary_matmul_gpu,
    ternary_matmul_raw,
    unpack_ternary_weights,
};
43
/// Check if CUDA kernels are available.
///
/// Returns `true` only when the crate was built with the `cuda` feature
/// *and* the CubeCL backend reports a usable CUDA device at runtime.
#[cfg(feature = "cuda")]
#[must_use]
pub fn cuda_available() -> bool {
    cubecl::has_cuda_support()
}

/// Check if CUDA kernels are available.
///
/// This build was compiled without the `cuda` feature, so GPU kernels
/// are never available and the answer is a compile-time `false`.
#[cfg(not(feature = "cuda"))]
#[must_use]
pub fn cuda_available() -> bool {
    false
}