1use candle_core::{Device, Tensor};
4use candle_nn::Module;
5
6use crate::config::BitNetConfig;
7use crate::error::Result;
8use crate::quantization::{
9 dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
10 TernaryWeight,
11};
12
13fn warn_cpu_fallback(device: &Device) {
14 static WARN_ONCE: std::sync::Once = std::sync::Once::new();
15 if matches!(device, Device::Cpu) {
16 WARN_ONCE.call_once(|| {
17 eprintln!(
18 "bitnet-quantize: CPU device in use. CUDA is the intended default; enable the 'cuda' feature and use Device::cuda_if_available(0) when possible."
19 );
20 });
21 }
22}
23
/// A linear layer whose weight matrix is stored in ternary-quantized form.
///
/// Built either from a full-precision weight via [`BitLinear::from_weight`]
/// (which quantizes on construction) or from an already-quantized weight via
/// [`BitLinear::from_quantized`].
#[derive(Debug)]
pub struct BitLinear {
    // Ternary-quantized weight matrix; also the source of truth for
    // in/out feature counts, sparsity, and compression statistics.
    weight: TernaryWeight,

    // Optional bias tensor, kept unquantized and broadcast-added after matmul.
    bias: Option<Tensor>,

    // Quantization configuration captured at construction time; used when
    // re-quantizing activations in `forward_quantized`.
    config: BitNetConfig,

    // Device all dequantized tensors are materialized on.
    device: Device,
}
61
62impl BitLinear {
63 pub fn from_weight(weight: &Tensor, bias: Option<&Tensor>, config: &BitNetConfig) -> Result<Self> {
75 config.validate()?;
76
77 let device = weight.device().clone();
78 warn_cpu_fallback(&device);
79 let quantized_weight = quantize_weights(weight, config)?;
80
81 Ok(Self {
82 weight: quantized_weight,
83 bias: bias.cloned(),
84 config: config.clone(),
85 device,
86 })
87 }
88
89 #[must_use]
98 pub fn from_quantized(
99 weight: TernaryWeight,
100 bias: Option<Tensor>,
101 config: BitNetConfig,
102 device: Device,
103 ) -> Self {
104 warn_cpu_fallback(&device);
105 Self {
106 weight,
107 bias,
108 config,
109 device,
110 }
111 }
112
113 #[must_use]
115 pub fn in_features(&self) -> usize {
116 self.weight.in_features()
117 }
118
119 #[must_use]
121 pub fn out_features(&self) -> usize {
122 self.weight.out_features()
123 }
124
125 #[must_use]
127 pub const fn quantized_weight(&self) -> &TernaryWeight {
128 &self.weight
129 }
130
131 #[must_use]
133 pub const fn bias(&self) -> Option<&Tensor> {
134 self.bias.as_ref()
135 }
136
137 #[must_use]
139 pub const fn config(&self) -> &BitNetConfig {
140 &self.config
141 }
142
143 #[must_use]
145 pub const fn device(&self) -> &Device {
146 &self.device
147 }
148
149 #[must_use]
151 pub fn sparsity(&self) -> f32 {
152 self.weight.sparsity()
153 }
154
155 #[must_use]
157 pub fn compression_ratio(&self) -> f32 {
158 self.weight.compression_ratio()
159 }
160
161 pub fn forward_quantized(&self, input: &Tensor) -> Result<Tensor> {
177 let quantized_input = quantize_activations(input, &self.config)?;
179 let dequant_input = dequantize_activations(&quantized_input, &self.device)?;
180
181 let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
183
184 let output = dequant_input.matmul(&dequant_weight.t()?)?;
186
187 let output = if let Some(ref bias) = self.bias {
189 output.broadcast_add(bias)?
190 } else {
191 output
192 };
193
194 Ok(output)
195 }
196}
197
198impl Module for BitLinear {
199 fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
200 let dequant_weight = dequantize_weights(&self.weight, &self.device)
203 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
204
205 let dims = input.dims();
206 let output = if dims.len() == 3 {
207 let (batch, seq_len, hidden) = (dims[0], dims[1], dims[2]);
209 let flat_input = input.reshape((batch * seq_len, hidden))?;
210 let flat_output = flat_input.matmul(&dequant_weight.t()?)?;
211 flat_output.reshape((batch, seq_len, self.out_features()))?
212 } else {
213 input.matmul(&dequant_weight.t()?)?
215 };
216
217 let output = if let Some(ref bias) = self.bias {
218 output.broadcast_add(bias)?
219 } else {
220 output
221 };
222
223 Ok(output)
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a layer from a fresh random `(out_f, in_f)` weight on the CPU.
    fn random_layer(out_f: usize, in_f: usize, config: &BitNetConfig) -> BitLinear {
        let weight = Tensor::randn(0.0f32, 1.0, (out_f, in_f), &Device::Cpu).unwrap();
        BitLinear::from_weight(&weight, None, config).unwrap()
    }

    #[test]
    fn test_bitlinear_creation() {
        let layer = random_layer(128, 256, &BitNetConfig::default());

        assert_eq!(layer.in_features(), 256);
        assert_eq!(layer.out_features(), 128);
    }

    #[test]
    fn test_bitlinear_forward() {
        let layer = random_layer(64, 128, &BitNetConfig::default().with_group_size(64));

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &Device::Cpu).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_forward_quantized() {
        let layer = random_layer(64, 128, &BitNetConfig::default().with_group_size(64));

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &Device::Cpu).unwrap();
        let output = layer.forward_quantized(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_with_bias() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        // Bias path is not covered by the helper; build the layer explicitly.
        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let bias = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_3d_input() {
        let layer = random_layer(64, 128, &BitNetConfig::default().with_group_size(64));

        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &Device::Cpu).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[2, 16, 64]);
    }

    #[test]
    fn test_bitlinear_sparsity() {
        let layer = random_layer(64, 128, &BitNetConfig::default().with_group_size(64));

        // Sparsity is a fraction and must lie in [0, 1].
        let sparsity = layer.sparsity();
        assert!((0.0f32..=1.0).contains(&sparsity));
    }

    #[test]
    fn test_bitlinear_compression() {
        let layer = random_layer(1024, 4096, &BitNetConfig::default());

        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve some compression");
    }
}