Skip to main content

bitnet_quantize/
config.rs

1//! Configuration for BitNet quantization and layers.
2
3use serde::{Deserialize, Serialize};
4
5/// Configuration for BitNet b1.58 quantization.
6///
7/// BitNet uses:
8/// - Ternary weights: {-1, 0, +1} via AbsMean quantization
9/// - INT8 activations: Per-token AbsMax scaling
10///
11/// # Reference
12///
13/// "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits"
14/// <https://arxiv.org/abs/2402.17764>
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct BitNetConfig {
17    /// Group size for weight quantization.
18    /// Weights are quantized in groups, each with its own scale.
19    /// Typical values: 64, 128, 256.
20    pub group_size: usize,
21
22    /// Number of bits for activation quantization.
23    /// BitNet b1.58 uses 8 bits (INT8).
24    pub activation_bits: u8,
25
26    /// Whether to use per-token activation scaling.
27    /// If true, each token gets its own scale factor.
28    /// If false, uses per-tensor scaling.
29    pub per_token_activation: bool,
30
31    /// Whether to apply RMS normalization before quantization.
32    pub use_rms_norm: bool,
33
34    /// Epsilon for numerical stability in normalization.
35    pub eps: f32,
36
37    /// Whether to enable Straight-Through Estimator for training.
38    pub enable_ste: bool,
39}
40
41impl Default for BitNetConfig {
42    fn default() -> Self {
43        Self {
44            group_size: 64,
45            activation_bits: 8,
46            per_token_activation: true,
47            use_rms_norm: true,
48            eps: 1e-5,
49            enable_ste: true,
50        }
51    }
52}
53
54impl BitNetConfig {
55    /// Create a new configuration with default values.
56    #[must_use]
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Create configuration optimized for inference.
62    ///
63    /// Disables training-specific features like STE.
64    #[must_use]
65    pub fn inference() -> Self {
66        Self {
67            enable_ste: false,
68            ..Default::default()
69        }
70    }
71
72    /// Create configuration for training.
73    ///
74    /// Enables STE for gradient estimation through quantization.
75    #[must_use]
76    pub fn training() -> Self {
77        Self {
78            enable_ste: true,
79            ..Default::default()
80        }
81    }
82
83    /// Set the group size for weight quantization.
84    #[must_use]
85    pub const fn with_group_size(mut self, group_size: usize) -> Self {
86        self.group_size = group_size;
87        self
88    }
89
90    /// Set the activation bit width.
91    #[must_use]
92    pub const fn with_activation_bits(mut self, bits: u8) -> Self {
93        self.activation_bits = bits;
94        self
95    }
96
97    /// Enable or disable per-token activation scaling.
98    #[must_use]
99    pub const fn with_per_token_activation(mut self, enabled: bool) -> Self {
100        self.per_token_activation = enabled;
101        self
102    }
103
104    /// Enable or disable RMS normalization.
105    #[must_use]
106    pub const fn with_rms_norm(mut self, enabled: bool) -> Self {
107        self.use_rms_norm = enabled;
108        self
109    }
110
111    /// Enable or disable Straight-Through Estimator.
112    #[must_use]
113    pub const fn with_ste(mut self, enabled: bool) -> Self {
114        self.enable_ste = enabled;
115        self
116    }
117
118    /// Validate the configuration.
119    ///
120    /// # Errors
121    ///
122    /// Returns error if configuration is invalid.
123    pub fn validate(&self) -> crate::Result<()> {
124        if self.group_size == 0 {
125            return Err(crate::BitNetError::InvalidConfig(
126                "group_size must be > 0".to_string(),
127            ));
128        }
129
130        if !self.group_size.is_power_of_two() {
131            return Err(crate::BitNetError::InvalidConfig(
132                "group_size must be a power of 2".to_string(),
133            ));
134        }
135
136        if self.activation_bits == 0 || self.activation_bits > 16 {
137            return Err(crate::BitNetError::InvalidConfig(
138                "activation_bits must be 1-16".to_string(),
139            ));
140        }
141
142        if self.eps <= 0.0 {
143            return Err(crate::BitNetError::InvalidConfig(
144                "eps must be > 0".to_string(),
145            ));
146        }
147
148        Ok(())
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = BitNetConfig::default();
        assert_eq!(config.group_size, 64);
        assert_eq!(config.activation_bits, 8);
        assert!(config.per_token_activation);
        assert!(config.use_rms_norm);
        assert!((config.eps - 1e-5).abs() < f32::EPSILON);
        assert!(config.enable_ste);
    }

    #[test]
    fn test_new_matches_default() {
        let a = BitNetConfig::new();
        let b = BitNetConfig::default();
        assert_eq!(a.group_size, b.group_size);
        assert_eq!(a.activation_bits, b.activation_bits);
        assert_eq!(a.per_token_activation, b.per_token_activation);
        assert_eq!(a.use_rms_norm, b.use_rms_norm);
        assert!((a.eps - b.eps).abs() < f32::EPSILON);
        assert_eq!(a.enable_ste, b.enable_ste);
    }

    #[test]
    fn test_inference_config() {
        let config = BitNetConfig::inference();
        assert!(!config.enable_ste);
        // Everything other than STE should stay at its default.
        assert_eq!(config.group_size, 64);
        assert_eq!(config.activation_bits, 8);
        assert!(config.use_rms_norm);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_training_config() {
        let config = BitNetConfig::training();
        assert!(config.enable_ste);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_builder_pattern() {
        let config = BitNetConfig::new()
            .with_group_size(128)
            .with_activation_bits(4)
            .with_per_token_activation(false)
            .with_rms_norm(false)
            .with_ste(false);

        assert_eq!(config.group_size, 128);
        assert_eq!(config.activation_bits, 4);
        assert!(!config.per_token_activation);
        assert!(!config.use_rms_norm);
        assert!(!config.enable_ste);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validation() {
        let valid = BitNetConfig::default();
        assert!(valid.validate().is_ok());

        let invalid_group = BitNetConfig {
            group_size: 0,
            ..Default::default()
        };
        assert!(matches!(
            invalid_group.validate(),
            Err(crate::BitNetError::InvalidConfig(_))
        ));

        let invalid_bits = BitNetConfig {
            activation_bits: 0,
            ..Default::default()
        };
        assert!(invalid_bits.validate().is_err());

        let non_power_of_two = BitNetConfig {
            group_size: 65,
            ..Default::default()
        };
        assert!(non_power_of_two.validate().is_err());

        let smallest_group = BitNetConfig {
            group_size: 1,
            ..Default::default()
        };
        assert!(smallest_group.validate().is_ok());
    }

    #[test]
    fn test_validation_activation_bits_bounds() {
        for bits in [1_u8, 8, 16] {
            let config = BitNetConfig::new().with_activation_bits(bits);
            assert!(config.validate().is_ok(), "bits = {bits} should be valid");
        }
        for bits in [0_u8, 17, u8::MAX] {
            let config = BitNetConfig::new().with_activation_bits(bits);
            assert!(config.validate().is_err(), "bits = {bits} should be invalid");
        }
    }

    #[test]
    fn test_validation_eps() {
        let zero = BitNetConfig {
            eps: 0.0,
            ..Default::default()
        };
        assert!(zero.validate().is_err());

        let negative = BitNetConfig {
            eps: -1e-5,
            ..Default::default()
        };
        assert!(negative.validate().is_err());

        let tiny = BitNetConfig {
            eps: f32::MIN_POSITIVE,
            ..Default::default()
        };
        assert!(tiny.validate().is_ok());
    }
}