1use candle_core::{Device, Tensor};
4use candle_nn::Module;
5
6use crate::config::BitNetConfig;
7use crate::error::Result;
8use crate::quantization::{
9 dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
10 TernaryWeight,
11};
12
13fn warn_cpu_fallback(device: &Device) {
18 static WARN_ONCE: std::sync::Once = std::sync::Once::new();
19 if matches!(device, Device::Cpu) {
20 WARN_ONCE.call_once(|| {
21 eprintln!(
22 "bitnet-quantize: CPU device in use. CUDA is the intended default; \
23 enable the 'cuda' feature and use Device::cuda_if_available(0) when possible."
24 );
25 });
26 }
27}
28
/// A linear layer whose weight matrix is stored in ternary-quantized form
/// (produced by `quantize_weights`), with an optional full-precision bias.
#[derive(Debug)]
pub struct BitLinear {
    /// Ternary-quantized weight; logical shape is `(out_features, in_features)`.
    weight: TernaryWeight,

    /// Optional bias, broadcast-added to the matmul output in the forward passes.
    bias: Option<Tensor>,

    /// Quantization configuration used to build the layer (and to quantize
    /// activations in `forward_quantized`).
    config: BitNetConfig,

    /// Device the layer's tensors live on; dequantization targets this device.
    device: Device,
}
66
67impl BitLinear {
68 pub fn from_weight(
80 weight: &Tensor,
81 bias: Option<&Tensor>,
82 config: &BitNetConfig,
83 ) -> Result<Self> {
84 config.validate()?;
85
86 let device = weight.device().clone();
87 warn_cpu_fallback(&device);
88 let quantized_weight = quantize_weights(weight, config)?;
89
90 Ok(Self {
91 weight: quantized_weight,
92 bias: bias.cloned(),
93 config: config.clone(),
94 device,
95 })
96 }
97
98 #[must_use]
107 pub fn from_quantized(
108 weight: TernaryWeight,
109 bias: Option<Tensor>,
110 config: BitNetConfig,
111 device: Device,
112 ) -> Self {
113 warn_cpu_fallback(&device);
114 Self {
115 weight,
116 bias,
117 config,
118 device,
119 }
120 }
121
122 #[must_use]
124 pub fn in_features(&self) -> usize {
125 self.weight.in_features()
126 }
127
128 #[must_use]
130 pub fn out_features(&self) -> usize {
131 self.weight.out_features()
132 }
133
134 #[must_use]
136 pub const fn quantized_weight(&self) -> &TernaryWeight {
137 &self.weight
138 }
139
140 #[must_use]
142 pub const fn bias(&self) -> Option<&Tensor> {
143 self.bias.as_ref()
144 }
145
146 #[must_use]
148 pub const fn config(&self) -> &BitNetConfig {
149 &self.config
150 }
151
152 #[must_use]
154 pub const fn device(&self) -> &Device {
155 &self.device
156 }
157
158 #[must_use]
160 pub fn sparsity(&self) -> f32 {
161 self.weight.sparsity()
162 }
163
164 #[must_use]
166 pub fn compression_ratio(&self) -> f32 {
167 self.weight.compression_ratio()
168 }
169
170 pub fn forward_quantized(&self, input: &Tensor) -> Result<Tensor> {
189 let quantized_input = quantize_activations(input, &self.config)?;
191 let dequant_input = dequantize_activations(&quantized_input, &self.device)?;
192
193 #[cfg(feature = "cuda")]
195 let output = {
196 if crate::kernels::should_use_gpu(&dequant_input, &self.weight) {
197 crate::kernels::ternary_matmul_gpu(&dequant_input, &self.weight)?
198 } else {
199 let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
201 dequant_input.matmul(&dequant_weight.t()?)?
202 }
203 };
204
205 #[cfg(not(feature = "cuda"))]
206 let output = {
207 let dequant_weight = dequantize_weights(&self.weight, &self.device)?;
209 dequant_input.matmul(&dequant_weight.t()?)?
210 };
211
212 let output = if let Some(ref bias) = self.bias {
214 output.broadcast_add(bias)?
215 } else {
216 output
217 };
218
219 Ok(output)
220 }
221}
222
223impl Module for BitLinear {
224 fn forward(&self, input: &Tensor) -> candle_core::Result<Tensor> {
225 let dims = input.dims();
226
227 let (flat_input, original_shape) = if dims.len() == 3 {
229 let (batch, seq_len, hidden) = (dims[0], dims[1], dims[2]);
230 (
231 input.reshape((batch * seq_len, hidden))?,
232 Some((batch, seq_len)),
233 )
234 } else {
235 (input.clone(), None)
236 };
237
238 #[cfg(feature = "cuda")]
240 let output = {
241 if crate::kernels::cuda_available()
242 && crate::kernels::should_use_gpu(&flat_input, &self.weight)
243 {
244 crate::kernels::ternary_matmul_gpu(&flat_input, &self.weight)
245 .map_err(|e| candle_core::Error::Msg(e.to_string()))?
246 } else {
247 let dequant_weight = dequantize_weights(&self.weight, &self.device)
249 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
250 flat_input.matmul(&dequant_weight.t()?)?
251 }
252 };
253
254 #[cfg(not(feature = "cuda"))]
255 let output = {
256 let dequant_weight = dequantize_weights(&self.weight, &self.device)
257 .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
258 flat_input.matmul(&dequant_weight.t()?)?
259 };
260
261 let output = if let Some((batch, seq_len)) = original_shape {
263 output.reshape((batch, seq_len, self.out_features()))?
264 } else {
265 output
266 };
267
268 let output = if let Some(ref bias) = self.bias {
270 output.broadcast_add(bias)?
271 } else {
272 output
273 };
274
275 Ok(output)
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Quantizes a freshly sampled Gaussian `(rows, cols)` weight into a
    /// bias-free layer on the given device.
    fn sample_layer(rows: usize, cols: usize, config: &BitNetConfig, device: &Device) -> BitLinear {
        let weight = Tensor::randn(0.0f32, 1.0, (rows, cols), device).unwrap();
        BitLinear::from_weight(&weight, None, config).unwrap()
    }

    #[test]
    fn test_bitlinear_creation() {
        let layer = sample_layer(128, 256, &BitNetConfig::default(), &Device::Cpu);

        assert_eq!(layer.in_features(), 256);
        assert_eq!(layer.out_features(), 128);
    }

    #[test]
    fn test_bitlinear_forward() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);
        let layer = sample_layer(64, 128, &config, &device);

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_forward_quantized() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);
        let layer = sample_layer(64, 128, &config, &device);

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward_quantized(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_with_bias() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

        // Bias requires the full constructor, so no helper here.
        let weight = Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let bias = Tensor::randn(0.0f32, 1.0, (64,), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, Some(&bias), &config).unwrap();

        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        assert_eq!(layer.forward(&input).unwrap().shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_bitlinear_3d_input() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);
        let layer = sample_layer(64, 128, &config, &device);

        // Leading (batch, seq) dims must survive the flatten/restore round trip.
        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &device).unwrap();
        assert_eq!(layer.forward(&input).unwrap().shape().dims(), &[2, 16, 64]);
    }

    #[test]
    fn test_bitlinear_sparsity() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);
        let layer = sample_layer(64, 128, &config, &device);

        let sparsity = layer.sparsity();
        assert!((0.0..=1.0).contains(&sparsity));
    }

    #[test]
    fn test_bitlinear_compression() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();
        let layer = sample_layer(1024, 4096, &config, &device);

        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve some compression");
    }
}