bitnet_quantize/
config.rs1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct BitNetConfig {
17 pub group_size: usize,
21
22 pub activation_bits: u8,
25
26 pub per_token_activation: bool,
30
31 pub use_rms_norm: bool,
33
34 pub eps: f32,
36
37 pub enable_ste: bool,
39}
40
41impl Default for BitNetConfig {
42 fn default() -> Self {
43 Self {
44 group_size: 64,
45 activation_bits: 8,
46 per_token_activation: true,
47 use_rms_norm: true,
48 eps: 1e-5,
49 enable_ste: true,
50 }
51 }
52}
53
54impl BitNetConfig {
55 #[must_use]
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 #[must_use]
65 pub fn inference() -> Self {
66 Self {
67 enable_ste: false,
68 ..Default::default()
69 }
70 }
71
72 #[must_use]
76 pub fn training() -> Self {
77 Self {
78 enable_ste: true,
79 ..Default::default()
80 }
81 }
82
83 #[must_use]
85 pub const fn with_group_size(mut self, group_size: usize) -> Self {
86 self.group_size = group_size;
87 self
88 }
89
90 #[must_use]
92 pub const fn with_activation_bits(mut self, bits: u8) -> Self {
93 self.activation_bits = bits;
94 self
95 }
96
97 #[must_use]
99 pub const fn with_per_token_activation(mut self, enabled: bool) -> Self {
100 self.per_token_activation = enabled;
101 self
102 }
103
104 #[must_use]
106 pub const fn with_rms_norm(mut self, enabled: bool) -> Self {
107 self.use_rms_norm = enabled;
108 self
109 }
110
111 #[must_use]
113 pub const fn with_ste(mut self, enabled: bool) -> Self {
114 self.enable_ste = enabled;
115 self
116 }
117
118 pub fn validate(&self) -> crate::Result<()> {
124 if self.group_size == 0 {
125 return Err(crate::BitNetError::InvalidConfig(
126 "group_size must be > 0".to_string(),
127 ));
128 }
129
130 if !self.group_size.is_power_of_two() {
131 return Err(crate::BitNetError::InvalidConfig(
132 "group_size must be a power of 2".to_string(),
133 ));
134 }
135
136 if self.activation_bits == 0 || self.activation_bits > 16 {
137 return Err(crate::BitNetError::InvalidConfig(
138 "activation_bits must be 1-16".to_string(),
139 ));
140 }
141
142 if self.eps <= 0.0 {
143 return Err(crate::BitNetError::InvalidConfig(
144 "eps must be > 0".to_string(),
145 ));
146 }
147
148 Ok(())
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented baseline values.
    #[test]
    fn test_default_config() {
        let config = BitNetConfig::default();
        assert_eq!(config.group_size, 64);
        assert_eq!(config.activation_bits, 8);
        assert!(config.per_token_activation);
        assert!(config.use_rms_norm);
        assert!(config.enable_ste);
    }

    /// The inference preset disables STE and leaves the rest at defaults.
    #[test]
    fn test_inference_config() {
        let config = BitNetConfig::inference();
        assert!(!config.enable_ste);
        assert_eq!(config.group_size, 64);
        assert_eq!(config.activation_bits, 8);
        assert!(config.per_token_activation);
        assert!(config.use_rms_norm);
    }

    /// The training preset enables STE and leaves the rest at defaults.
    #[test]
    fn test_training_config() {
        let config = BitNetConfig::training();
        assert!(config.enable_ste);
        assert_eq!(config.group_size, 64);
        assert_eq!(config.activation_bits, 8);
    }

    /// Each builder method overwrites exactly its own field.
    #[test]
    fn test_builder_pattern() {
        let config = BitNetConfig::new()
            .with_group_size(128)
            .with_activation_bits(4)
            .with_per_token_activation(false)
            .with_rms_norm(false)
            .with_ste(false);

        assert_eq!(config.group_size, 128);
        assert_eq!(config.activation_bits, 4);
        assert!(!config.per_token_activation);
        assert!(!config.use_rms_norm);
        assert!(!config.enable_ste);
    }

    /// `validate` accepts the defaults and rejects each out-of-range field.
    #[test]
    fn test_validation() {
        let valid = BitNetConfig::default();
        assert!(valid.validate().is_ok());

        let invalid_group = BitNetConfig {
            group_size: 0,
            ..Default::default()
        };
        assert!(invalid_group.validate().is_err());

        let invalid_bits = BitNetConfig {
            activation_bits: 0,
            ..Default::default()
        };
        assert!(invalid_bits.validate().is_err());

        let non_power_of_two = BitNetConfig {
            group_size: 65,
            ..Default::default()
        };
        assert!(non_power_of_two.validate().is_err());
    }

    /// The activation bit width is bounded above as well as below.
    #[test]
    fn test_validation_activation_bits_upper_bound() {
        let at_limit = BitNetConfig::default().with_activation_bits(16);
        assert!(at_limit.validate().is_ok());

        let over_limit = BitNetConfig::default().with_activation_bits(17);
        assert!(over_limit.validate().is_err());
    }

    /// Zero and negative epsilons are rejected.
    #[test]
    fn test_validation_eps_must_be_positive() {
        let zero_eps = BitNetConfig {
            eps: 0.0,
            ..Default::default()
        };
        assert!(zero_eps.validate().is_err());

        let negative_eps = BitNetConfig {
            eps: -1e-5,
            ..Default::default()
        };
        assert!(negative_eps.validate().is_err());
    }
}