1use candle_core::{Device, Tensor};
4use candle_nn::Module;
5
6use crate::config::BitNetConfig;
7use crate::error::Result;
8use crate::quantization::{
9 dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
10 TernaryWeight,
11};
12
/// A linear layer whose weights are stored in quantized ternary form.
///
/// Weights are quantized once at construction (see [`BitLinear::from_weight`])
/// and dequantized on demand during the forward passes.
#[derive(Debug)]
pub struct BitLinear {
    /// Quantized ternary weight matrix, logically `(out_features, in_features)`
    /// (the forward paths multiply by its transpose).
    weight: TernaryWeight,

    /// Optional full-precision bias, broadcast-added to the layer output.
    bias: Option<Tensor>,

    /// Quantization configuration used for weights and activations.
    config: BitNetConfig,

    /// Device on which dequantized tensors are materialized.
    device: Device,
}
50
51impl BitLinear {
52 pub fn from_weight(weight: &Tensor, bias: Option<&Tensor>, config: &BitNetConfig) -> Result<Self> {
64 config.validate()?;
65
66 let device = weight.device().clone();
67 let quantized_weight = quantize_weights(weight, config)?;
68
69 Ok(Self {
70 weight: quantized_weight,
71 bias: bias.cloned(),
72 config: config.clone(),
73 device,
74 })
75 }
76
77 #[must_use]
86 pub fn from_quantized(
87 weight: TernaryWeight,
88 bias: Option<Tensor>,
89 config: BitNetConfig,
90 device: Device,
91 ) -> Self {
92 Self {
93 weight,
94 bias,
95 config,
96 device,
97 }
98 }
99
100 #[must_use]
102 pub fn in_features(&self) -> usize {
103 self.weight.in_features()
104 }
105
106 #[must_use]
108 pub fn out_features(&self) -> usize {
109 self.weight.out_features()
110 }
111
112 #[must_use]
114 pub const fn quantized_weight(&self) -> &TernaryWeight {
115 &self.weight
116 }
117
118 #[must_use]
120 pub const fn bias(&self) -> Option<&Tensor> {
121 self.bias.as_ref()
122 }
123
124 #[must_use]
126 pub const fn config(&self) -> &BitNetConfig {
127 &self.config
128 }
129
130 #[must_use]
132 pub const fn device(&self) -> &Device {
133 &self.device
134 }
135
136 #[must_use]
138 pub fn sparsity(&self) -> f32 {
139 self.weight.sparsity()
140 }
141
142 #[must_use]
144 pub fn compression_ratio(&self) -> f32 {
145 self.weight.compression_ratio()
146 }
147
148 pub fn forward_quantized(&self, input: &Tensor) -> Result<Tensor> {
164 let quantized_input = quantize_activations(input, &self.config)?;
166 let dequant_input = dequantize_activations(&quantized_input, &self.device)?;
167
168 let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
170
171 let output = dequant_input.matmul(&dequant_weight.t()?)?;
173
174 let output = if let Some(ref bias) = self.bias {
176 output.broadcast_add(bias)?
177 } else {
178 output
179 };
180
181 Ok(output)
182 }
183}
184
185impl Module for BitLinear {
186 fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
187 let dequant_weight = dequantize_weights(&self.weight, &self.device)
190 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
191
192 let dims = input.dims();
193 let output = if dims.len() == 3 {
194 let (batch, seq_len, hidden) = (dims[0], dims[1], dims[2]);
196 let flat_input = input.reshape((batch * seq_len, hidden))?;
197 let flat_output = flat_input.matmul(&dequant_weight.t()?)?;
198 flat_output.reshape((batch, seq_len, self.out_features()))?
199 } else {
200 input.matmul(&dequant_weight.t()?)?
202 };
203
204 let output = if let Some(ref bias) = self.bias {
205 output.broadcast_add(bias)?
206 } else {
207 output
208 };
209
210 Ok(output)
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction records in/out features from the weight's
    /// `(out_features, in_features)` layout.
    #[test]
    fn test_bitlinear_creation() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (128, 256), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        assert_eq!(layer.in_features(), 256);
        assert_eq!(layer.out_features(), 128);
    }

    /// 2-D forward produces `(batch, out_features)` output.
    #[test]
    fn test_bitlinear_forward() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    /// The activation-quantizing path preserves the output shape contract.
    #[test]
    fn test_bitlinear_forward_quantized() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward_quantized(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    /// A bias of shape `(out_features,)` broadcasts without changing the
    /// output shape.
    #[test]
    fn test_bitlinear_with_bias() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let bias = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    /// 3-D `(batch, seq, hidden)` input keeps its sequence dimension.
    #[test]
    fn test_bitlinear_3d_input() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[2, 16, 64]);
    }

    /// Sparsity is a fraction, so it must lie in `[0, 1]`.
    #[test]
    fn test_bitlinear_sparsity() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let sparsity = layer.sparsity();
        // Idiomatic range check (clippy::manual_range_contains).
        assert!((0.0..=1.0).contains(&sparsity));
    }

    /// Ternary storage should be strictly smaller than full precision.
    #[test]
    fn test_bitlinear_compression() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

        let weight = Tensor::randn(0.0f32, 1.0, (1024, 4096), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve some compression");
    }
}