#![warn(missing_docs)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::doc_markdown)]

mod adapter;
mod config;
mod error;
pub mod export;
pub mod kernels;
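/// Linear layers with ternary quantized weights, including [`BitLinear`].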
pub mod layer;
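/// Weight and activation quantization and dequantization primitives.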
pub mod quantization;

pub use adapter::{BitNetAdapter, BitNetAdapterConfig};
pub use config::BitNetConfig;
pub use error::{BitNetError, Result};
pub use layer::BitLinear;
pub use quantization::{
    dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
    QuantizedActivations, TernaryWeight,
};

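/// Commonly used types and functions, re-exported for convenient imports.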
pub mod prelude {
    pub use crate::config::BitNetConfig;
    pub use crate::error::{BitNetError, Result};
    pub use crate::layer::BitLinear;
    pub use crate::quantization::{quantize_activations, quantize_weights};
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;
    use candle_nn::Module;

    #[test]
    fn test_basic_workflow() {
        let device = Device::Cpu;
        let config = BitNetConfig::default().with_group_size(64);

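        // Build a BitLinear layer directly from a dense weight matrix.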
        let weight = candle_core::Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let layer = BitLinear::from_weight(&weight, None, &config).unwrap();

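        // Run a forward pass; the output keeps the (batch, out_features) shape.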
        let input = candle_core::Tensor::randn(0.0f32, 1.0, (4, 128), &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[4, 64]);

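        // Ternary weights should take less memory than the original f32 weights.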
        let ratio = layer.compression_ratio();
        assert!(ratio > 1.0, "should achieve compression");
    }

    #[test]
    fn test_quantization_workflow() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

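        // Quantize a dense weight matrix into ternary form.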
        let weight = candle_core::Tensor::randn(0.0f32, 1.0, (64, 128), &device).unwrap();
        let ternary = quantize_weights(&weight, &config).unwrap();

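        // Sparsity is reported as a fraction, so it must lie in [0, 1].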
        assert_eq!(ternary.shape, (64, 128));
        assert!(ternary.sparsity() >= 0.0);
        assert!(ternary.sparsity() <= 1.0);

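        // Dequantization restores a tensor with the original shape.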
        let restored = dequantize_weights(&ternary, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[64, 128]);
    }

    #[test]
    fn test_activation_quantization() {
        let device = Device::Cpu;
        let config = BitNetConfig::default();

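        // Quantize a (4, 64) activation batch; one scale is produced per row.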
        let activations = candle_core::Tensor::randn(0.0f32, 1.0, (4, 64), &device).unwrap();
        let quantized = quantize_activations(&activations, &config).unwrap();

        assert_eq!(quantized.shape, vec![4, 64]);
        assert_eq!(quantized.scales.len(), 4);

        let restored = dequantize_activations(&quantized, &device).unwrap();
        assert_eq!(restored.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_config_builder() {
        let config = BitNetConfig::new()
            .with_group_size(128)
            .with_activation_bits(4)
            .with_per_token_activation(false)
            .with_rms_norm(false)
            .with_ste(true);

        assert_eq!(config.group_size, 128);
        assert_eq!(config.activation_bits, 4);
        assert!(!config.per_token_activation);
        assert!(!config.use_rms_norm);
        assert!(config.enable_ste);
    }
}