//! Ternary weight quantization (bitnet_quantize/quantization/weight.rs).

use candle_core::{Device, Tensor};
use serde::{Deserialize, Serialize};
use trit_vsa::PackedTritVec;

use crate::config::BitNetConfig;
use crate::error::{BitNetError, Result};
/// A 2D weight matrix quantized to ternary values {-1, 0, +1} with
/// per-group f32 scales for dequantization.
#[derive(Clone, Serialize, Deserialize)]
pub struct TernaryWeight {
    /// One packed trit vector per output row; `data.len() == shape.0`.
    pub data: Vec<PackedTritVec>,

    /// Per-group scales, row-major: `shape.0 * (shape.1 / group_size)` entries.
    pub scales: Vec<f32>,

    /// `(out_features, in_features)` of the original 2D weight.
    pub shape: (usize, usize),

    /// Number of consecutive in-feature elements that share one scale.
    pub group_size: usize,
}
32
/// Manual `Debug` impl: prints a compact summary (shape, group size, scale
/// count, sparsity) instead of dumping the packed trit data, which would be
/// enormous for real weight matrices.
impl std::fmt::Debug for TernaryWeight {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TernaryWeight")
            .field("shape", &self.shape)
            .field("group_size", &self.group_size)
            .field("num_scales", &self.scales.len())
            .field("sparsity", &self.sparsity())
            // `..` marker signals that `data`/`scales` contents are omitted.
            .finish_non_exhaustive()
    }
}
43
44impl TernaryWeight {
45 #[must_use]
47 pub const fn out_features(&self) -> usize {
48 self.shape.0
49 }
50
51 #[must_use]
53 pub const fn in_features(&self) -> usize {
54 self.shape.1
55 }
56
57 #[must_use]
59 #[allow(clippy::cast_precision_loss)]
60 pub fn sparsity(&self) -> f32 {
61 let total_nonzero: usize = self.data.iter().map(PackedTritVec::count_nonzero).sum();
62 let total_elements = self.shape.0 * self.shape.1;
63 1.0 - (total_nonzero as f32 / total_elements as f32)
64 }
65
66 #[must_use]
68 pub fn memory_bytes(&self) -> usize {
69 let trit_bytes: usize = self.data.iter().map(|v| v.num_words() * 8).sum();
71 let scale_bytes = self.scales.len() * 4;
73 trit_bytes + scale_bytes
74 }
75
76 #[must_use]
78 #[allow(clippy::cast_precision_loss)]
79 pub fn compression_ratio(&self) -> f32 {
80 let fp32_bytes = self.shape.0 * self.shape.1 * 4;
81 fp32_bytes as f32 / self.memory_bytes() as f32
82 }
83}
84
85pub fn quantize_weights(weight: &Tensor, config: &BitNetConfig) -> Result<TernaryWeight> {
102 let shape = weight.shape().dims();
103 if shape.len() != 2 {
104 return Err(BitNetError::InvalidConfig(
105 "weight must be 2D [out_features, in_features]".to_string(),
106 ));
107 }
108
109 let out_features = shape[0];
110 let in_features = shape[1];
111 let group_size = config.group_size;
112
113 if !in_features.is_multiple_of(group_size) {
115 return Err(BitNetError::InvalidConfig(format!(
116 "in_features ({in_features}) must be divisible by group_size ({group_size})"
117 )));
118 }
119
120 let num_groups_per_row = in_features / group_size;
121 let mut scales = Vec::with_capacity(out_features * num_groups_per_row);
122 let mut data = Vec::with_capacity(out_features);
123
124 let weight_f32 = weight.to_dtype(candle_core::DType::F32)?.to_vec2::<f32>()?;
126
127 for row in &weight_f32 {
128 let mut packed = PackedTritVec::new(in_features);
129
130 for g in 0..num_groups_per_row {
131 let start = g * group_size;
132 let end = start + group_size;
133 let group = &row[start..end];
134
135 let abs_mean: f32 = group.iter().map(|x| x.abs()).sum::<f32>() / group_size as f32;
137 let scale = if abs_mean > 0.0 { abs_mean } else { 1.0 };
138 scales.push(scale);
139
140 for (i, &val) in group.iter().enumerate() {
142 let normalized = val / scale;
143 let quantized = normalized.round().clamp(-1.0, 1.0) as i8;
144 let trit = trit_vsa::Trit::from_value(quantized as i32)?;
145 packed.set(start + i, trit);
146 }
147 }
148
149 data.push(packed);
150 }
151
152 Ok(TernaryWeight {
153 data,
154 scales,
155 shape: (out_features, in_features),
156 group_size,
157 })
158}
159
160pub fn dequantize_weights(ternary: &TernaryWeight, device: &Device) -> Result<Tensor> {
171 let out_features = ternary.out_features();
172 let in_features = ternary.in_features();
173 let group_size = ternary.group_size;
174 let num_groups_per_row = in_features / group_size;
175
176 let mut output = vec![0.0f32; out_features * in_features];
177
178 for (row_idx, packed) in ternary.data.iter().enumerate() {
179 let row_start = row_idx * in_features;
180
181 for g in 0..num_groups_per_row {
182 let scale_idx = row_idx * num_groups_per_row + g;
183 let scale = ternary.scales[scale_idx];
184 let group_start = g * group_size;
185
186 for i in 0..group_size {
187 let trit = packed.get(group_start + i);
188 let value = trit.value() as f32 * scale;
189 output[row_start + group_start + i] = value;
190 }
191 }
192 }
193
194 let tensor = Tensor::from_vec(output, (out_features, in_features), device)?;
195 Ok(tensor)
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    /// Quantize then dequantize a random matrix; shapes and bookkeeping
    /// lengths must survive the round trip.
    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();
        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();

        let ternary = quantize_weights(&weight, &config).unwrap();
        assert_eq!(ternary.shape, (64, 128));
        assert_eq!(ternary.data.len(), 64);
        assert_eq!(ternary.scales.len(), 64 * (128 / 64));

        let restored = dequantize_weights(&ternary, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[64, 128]);
    }

    /// Large positive/negative inputs must map to +1/-1 trits respectively.
    #[test]
    fn test_quantize_preserves_sign() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(4);

        let raw: Vec<f32> = vec![1.0, -1.0, 0.5, -0.5, 2.0, -2.0, 0.1, -0.1];
        let weight = Tensor::from_vec(raw, (2, 4), &device).unwrap();

        let ternary = quantize_weights(&weight, &config).unwrap();
        assert_eq!(ternary.data[0].get(0), trit_vsa::Trit::P);
        assert_eq!(ternary.data[0].get(1), trit_vsa::Trit::N);
    }

    /// A mostly-zero input should report high sparsity.
    #[test]
    fn test_sparsity() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(4);

        let raw: Vec<f32> = vec![1.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0];
        let weight = Tensor::from_vec(raw, (2, 4), &device).unwrap();

        let ternary = quantize_weights(&weight, &config).unwrap();
        let sparsity = ternary.sparsity();
        assert!(sparsity > 0.5, "expected high sparsity, got {sparsity}");
    }

    /// Ternary packing should beat dense f32 storage by a wide margin.
    #[test]
    fn test_compression_ratio() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (1024, 4096), &device).unwrap();
        let ternary = quantize_weights(&weight, &config).unwrap();

        let ratio = ternary.compression_ratio();
        assert!(ratio > 4.0, "expected >4x compression, got {ratio:.2}x");
    }

    /// Non-2D weights (1D and 3D) are rejected with an error.
    #[test]
    fn test_invalid_shape() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let one_dim = Tensor::zeros(&[64], DType::F32, &device).unwrap();
        assert!(quantize_weights(&one_dim, &config).is_err());

        let three_dim = Tensor::zeros(&[2, 64, 64], DType::F32, &device).unwrap();
        assert!(quantize_weights(&three_dim, &config).is_err());
    }

    /// in_features that isn't a multiple of group_size is rejected.
    #[test]
    fn test_indivisible_group_size() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::zeros(&[32, 100], DType::F32, &device).unwrap();
        assert!(quantize_weights(&weight, &config).is_err());
    }
}