Skip to main content

bitnet_quantize/
lib.rs

1//! Microsoft BitNet b1.58 quantization and inference for Rust.
2//!
3//! This crate provides an implementation of BitNet, which uses:
4//! - **Ternary weights**: {-1, 0, +1} via AbsMean quantization
5//! - **INT8 activations**: Per-token AbsMax scaling
6//!
7//! # Features
8//!
//! - `BitLinear`: Drop-in replacement for `candle_nn::Linear`
10//! - Efficient ternary weight storage via `trit-vsa`
11//! - Straight-Through Estimator (STE) for training
12//! - Optional peft-rs adapter integration
13//! - Optional GGUF export via qlora-rs
14//!
15//! # Quick Start
16//!
17//! ```ignore
//! use bitnet_quantize::{BitLinear, BitNetConfig};
//! use candle_core::{Device, Tensor};
//! use candle_nn::Module; // provides `forward`
20//!
21//! let device = Device::Cpu;
22//! let config = BitNetConfig::default();
23//!
24//! // Create BitLinear from existing weights
25//! let weight = Tensor::randn(0.0f32, 1.0, (512, 256), &device)?;
26//! let layer = BitLinear::from_weight(&weight, None, &config)?;
27//!
28//! // Forward pass
29//! let input = Tensor::randn(0.0f32, 1.0, (4, 256), &device)?;
30//! let output = layer.forward(&input)?;
31//!
32//! println!("Compression ratio: {:.2}x", layer.compression_ratio());
33//! ```
34//!
35//! # Quantization
36//!
37//! ## Weight Quantization (AbsMean)
38//!
39//! Weights are quantized using the AbsMean method:
40//! ```text
41//! scale = mean(|W|)
42//! W_q = round(W / scale) clamped to {-1, 0, +1}
43//! ```
44//!
45//! ## Activation Quantization (AbsMax)
46//!
47//! Activations are quantized to INT8 using per-token AbsMax:
48//! ```text
49//! scale = max(|X|) / 127
50//! X_q = round(X / scale) clamped to [-127, 127]
51//! ```
52//!
53//! # Feature Flags
54//!
55//! - `default`: CPU-only
56//! - `cuda`: Enable CUDA GPU kernels
57//! - `peft`: Enable peft-rs adapter integration
58//! - `gguf-export`: Enable GGUF export via qlora-rs
59//!
60//! # References
61//!
62//! - "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits"
63//!   <https://arxiv.org/abs/2402.17764>
64
65#![warn(missing_docs)]
66#![warn(clippy::pedantic)]
67#![allow(clippy::module_name_repetitions)]
68#![allow(clippy::cast_possible_truncation)]
69#![allow(clippy::cast_possible_wrap)]
70#![allow(clippy::cast_sign_loss)]
71#![allow(clippy::cast_precision_loss)]
72#![allow(clippy::cast_lossless)]
73#![allow(clippy::doc_markdown)] // Many technical terms don't need backticks
74
75mod adapter;
76mod config;
77mod error;
78pub mod export;
79pub mod kernels;
80pub mod layer;
81pub mod quantization;
82
83pub use adapter::{BitNetAdapter, BitNetAdapterConfig};
84pub use config::BitNetConfig;
85pub use error::{BitNetError, Result};
86pub use layer::BitLinear;
87pub use quantization::{
88    dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
89    QuantizedActivations, TernaryWeight,
90};
91
92/// Prelude module for convenient imports.
93pub mod prelude {
94    pub use crate::config::BitNetConfig;
95    pub use crate::error::{BitNetError, Result};
96    pub use crate::layer::BitLinear;
97    pub use crate::quantization::{quantize_activations, quantize_weights};
98}
99
#[cfg(test)]
mod tests {
    //! Smoke tests for the crate-level API: a `BitLinear` forward pass,
    //! weight and activation quantize/dequantize round trips, and the
    //! `BitNetConfig` builder.

    use super::*;
    use candle_core::{Device, Tensor};
    use candle_nn::Module;

    /// Samples a standard-normal `f32` tensor of shape `(rows, cols)` on the CPU.
    fn randn_cpu(rows: usize, cols: usize) -> Tensor {
        Tensor::randn(0.0f32, 1.0, (rows, cols), &Device::Cpu)
            .expect("sampling a CPU tensor should not fail")
    }

    #[test]
    fn test_basic_workflow() {
        let config = BitNetConfig::default().with_group_size(64);

        // Weight layout is (out_features, in_features) = (64, 128).
        let weight = randn_cpu(64, 128);
        let layer =
            BitLinear::from_weight(&weight, None, &config).expect("BitLinear construction");

        // A batch of 4 inputs must map to 4 rows of 64 outputs.
        let input = randn_cpu(4, 128);
        let output = layer.forward(&input).expect("forward pass");
        assert_eq!(output.shape().dims(), &[4, 64]);

        assert!(layer.compression_ratio() > 1.0, "should achieve compression");
    }

    #[test]
    fn test_quantization_workflow() {
        let cpu = Device::Cpu;
        let weight = randn_cpu(64, 128);

        let ternary =
            quantize_weights(&weight, &BitNetConfig::default()).expect("weight quantization");
        assert_eq!(ternary.shape, (64, 128));

        // Sparsity is reported as a fraction, so it must lie in [0, 1].
        let sparsity = ternary.sparsity();
        assert!((0.0..=1.0).contains(&sparsity));

        let restored = dequantize_weights(&ternary, &cpu).expect("weight dequantization");
        assert_eq!(restored.shape().dims(), &[64, 128]);
    }

    #[test]
    fn test_activation_quantization() {
        let cpu = Device::Cpu;
        let activations = randn_cpu(4, 64);

        let quantized = quantize_activations(&activations, &BitNetConfig::default())
            .expect("activation quantization");
        assert_eq!(quantized.shape, vec![4, 64]);
        // One scale per token (row) under the default configuration.
        assert_eq!(quantized.scales.len(), 4);

        let restored =
            dequantize_activations(&quantized, &cpu).expect("activation dequantization");
        assert_eq!(restored.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_config_builder() {
        let config = BitNetConfig::new()
            .with_group_size(128)
            .with_activation_bits(4)
            .with_per_token_activation(false)
            .with_rms_norm(false)
            .with_ste(true);

        // Every builder call must be reflected in the corresponding field.
        assert_eq!(config.group_size, 128);
        assert_eq!(config.activation_bits, 4);
        assert!(config.enable_ste);
        assert!(!config.per_token_activation);
        assert!(!config.use_rms_norm);
    }
}