#![warn(missing_docs)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_sign_loss)]
#![allow(clippy::cast_precision_loss)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::doc_markdown)]
mod adapter;
mod config;
mod error;
pub mod export;
pub mod kernels;
pub mod layer;
pub mod quantization;
pub use adapter::{BitNetAdapter, BitNetAdapterConfig};
pub use config::BitNetConfig;
pub use error::{BitNetError, Result};
pub use layer::BitLinear;
pub use quantization::{
dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
QuantizedActivations, TernaryWeight,
};
/// Convenience re-exports of the crate's most commonly used items.
///
/// Glob-importing this module (`use <this crate>::prelude::*;`) brings in the
/// configuration type, the error/result types, the [`BitLinear`] layer, and the
/// full quantize/dequantize round trip for both weights and activations,
/// together with the quantized container types those functions produce.
pub mod prelude {
    pub use crate::config::BitNetConfig;
    pub use crate::error::{BitNetError, Result};
    pub use crate::layer::BitLinear;
    // Export the dequantize counterparts and the quantized container types
    // alongside the quantizers so a prelude user can complete a round trip and
    // name the intermediate values, matching the crate-root re-exports.
    pub use crate::quantization::{
        dequantize_activations, dequantize_weights, quantize_activations, quantize_weights,
        QuantizedActivations, TernaryWeight,
    };
}
// End-to-end smoke tests of the public API, all run on the CPU device.
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{Device, Shape, Tensor};
    use candle_nn::Module;

    /// Samples an `f32` tensor from N(0, 1) with the given dimensions on `dev`.
    fn normal<S: Into<Shape>>(dims: S, dev: &Device) -> Tensor {
        Tensor::randn(0.0f32, 1.0, dims, dev).unwrap()
    }

    #[test]
    fn test_basic_workflow() {
        let dev = Device::Cpu;
        let cfg = BitNetConfig::default().with_group_size(64);
        let weight = normal((64, 128), &dev);
        let layer = BitLinear::from_weight(&weight, None, &cfg).unwrap();

        // A batch of 4 rows through a 128 -> 64 layer yields a (4, 64) result.
        let out = layer.forward(&normal((4, 128), &dev)).unwrap();
        assert_eq!(out.shape().dims(), &[4, 64]);

        assert!(layer.compression_ratio() > 1.0, "should achieve compression");
    }

    #[test]
    fn test_quantization_workflow() {
        let dev = Device::Cpu;
        let cfg = BitNetConfig::default();
        let weight = normal((64, 128), &dev);

        let ternary = quantize_weights(&weight, &cfg).unwrap();
        assert_eq!(ternary.shape, (64, 128));

        // Sparsity is a fraction and must lie in the closed unit interval.
        let sparsity = ternary.sparsity();
        assert!((0.0..=1.0).contains(&sparsity));

        let restored = dequantize_weights(&ternary, &dev).unwrap();
        assert_eq!(restored.shape().dims(), &[64, 128]);
    }

    #[test]
    fn test_activation_quantization() {
        let dev = Device::Cpu;
        let cfg = BitNetConfig::default();
        let acts = normal((4, 64), &dev);

        let quantized = quantize_activations(&acts, &cfg).unwrap();
        assert_eq!(quantized.shape, vec![4, 64]);
        // One scale per row for this 4-row input.
        assert_eq!(quantized.scales.len(), 4);

        let restored = dequantize_activations(&quantized, &dev).unwrap();
        assert_eq!(restored.shape().dims(), &[4, 64]);
    }

    #[test]
    fn test_config_builder() {
        let cfg = BitNetConfig::new()
            .with_group_size(128)
            .with_activation_bits(4)
            .with_per_token_activation(false)
            .with_rms_norm(false)
            .with_ste(true);

        assert_eq!(cfg.group_size, 128);
        assert_eq!(cfg.activation_bits, 4);
        assert!(!cfg.per_token_activation);
        assert!(!cfg.use_rms_norm);
        assert!(cfg.enable_ste);
    }
}