use serde::{Deserialize, Serialize};
/// Configuration values for BitNet.
///
/// The struct derives `Serialize`/`Deserialize`, so it can be stored and
/// loaded through serde. Accepted ranges are enforced by `validate`, not by
/// construction: fields are public and can hold any value of their type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BitNetConfig {
/// Group size. `validate` requires a non-zero power of two; default is 64.
/// NOTE(review): presumably the number of weights sharing one
/// quantization scale; confirm against the quantizer.
pub group_size: usize,
/// Bit width used for activations. `validate` accepts 1 through 16;
/// default is 8.
pub activation_bits: u8,
/// Per-token activation switch; default is `true`.
/// NOTE(review): presumably selects per-token rather than per-tensor
/// activation scaling; confirm against the code that reads this flag.
pub per_token_activation: bool,
/// RMS-norm switch; default is `true`.
/// NOTE(review): where the normalization is applied is not visible here.
pub use_rms_norm: bool,
/// Epsilon value. `validate` requires it to be greater than zero;
/// default is `1e-5`.
/// NOTE(review): presumably a numerical-stability term; confirm usage.
pub eps: f32,
/// STE switch. `true` by default and in `training()`, `false` in
/// `inference()`.
/// NOTE(review): "STE" presumably means straight-through estimator; confirm.
pub enable_ste: bool,
}
impl Default for BitNetConfig {
    /// Builds the baseline configuration: group size 64, 8-bit activations,
    /// epsilon `1e-5`, and all three boolean switches turned on.
    fn default() -> Self {
        // Numeric defaults are named so the literal below reads as a mapping.
        const GROUP_SIZE: usize = 64;
        const ACTIVATION_BITS: u8 = 8;
        const EPS: f32 = 1e-5;

        Self {
            group_size: GROUP_SIZE,
            activation_bits: ACTIVATION_BITS,
            eps: EPS,
            per_token_activation: true,
            use_rms_norm: true,
            enable_ste: true,
        }
    }
}
impl BitNetConfig {
    /// Creates a configuration with the default settings.
    ///
    /// Equivalent to `BitNetConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a configuration with `enable_ste` turned off and every other
    /// field at its default value.
    #[must_use]
    pub fn inference() -> Self {
        Self {
            enable_ste: false,
            ..Default::default()
        }
    }

    /// Creates a configuration with `enable_ste` turned on and every other
    /// field at its default value.
    #[must_use]
    pub fn training() -> Self {
        Self {
            enable_ste: true,
            ..Default::default()
        }
    }

    /// Returns the configuration with `group_size` replaced.
    ///
    /// The value is stored as given; call [`Self::validate`] to check it.
    #[must_use]
    pub const fn with_group_size(mut self, group_size: usize) -> Self {
        self.group_size = group_size;
        self
    }

    /// Returns the configuration with `activation_bits` replaced.
    ///
    /// The value is stored as given; call [`Self::validate`] to check it.
    #[must_use]
    pub const fn with_activation_bits(mut self, bits: u8) -> Self {
        self.activation_bits = bits;
        self
    }

    /// Returns the configuration with `per_token_activation` replaced.
    #[must_use]
    pub const fn with_per_token_activation(mut self, enabled: bool) -> Self {
        self.per_token_activation = enabled;
        self
    }

    /// Returns the configuration with `use_rms_norm` replaced.
    #[must_use]
    pub const fn with_rms_norm(mut self, enabled: bool) -> Self {
        self.use_rms_norm = enabled;
        self
    }

    /// Returns the configuration with `enable_ste` replaced.
    #[must_use]
    pub const fn with_ste(mut self, enabled: bool) -> Self {
        self.enable_ste = enabled;
        self
    }

    /// Checks that the configuration values are usable.
    ///
    /// Checks run in a fixed order and the first failure is reported.
    ///
    /// # Errors
    ///
    /// Returns `BitNetError::InvalidConfig` when:
    /// - `group_size` is zero or not a power of two,
    /// - `activation_bits` is outside `1..=16`,
    /// - `eps` is not a finite value greater than zero (NaN and infinities
    ///   are rejected).
    pub fn validate(&self) -> crate::Result<()> {
        let reject = |reason: &str| -> crate::Result<()> {
            Err(crate::BitNetError::InvalidConfig(reason.to_string()))
        };

        // Zero is reported separately so the message names the actual problem
        // (`0.is_power_of_two()` is also false).
        if self.group_size == 0 {
            return reject("group_size must be > 0");
        }
        if !self.group_size.is_power_of_two() {
            return reject("group_size must be a power of 2");
        }
        if self.activation_bits == 0 || self.activation_bits > 16 {
            return reject("activation_bits must be 1-16");
        }
        if self.eps <= 0.0 {
            return reject("eps must be > 0");
        }
        // Every comparison with NaN is false, so the `<=` test above lets NaN
        // through; +inf is positive and passes it too. Neither is usable.
        if !self.eps.is_finite() {
            return reject("eps must be finite");
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the documented baseline values.
    #[test]
    fn test_default_config() {
        let config = BitNetConfig::default();
        assert_eq!(config.group_size, 64);
        assert_eq!(config.activation_bits, 8);
        assert!(config.per_token_activation);
        assert!(config.use_rms_norm);
        assert!(config.enable_ste);
    }

    /// The inference preset switches STE off.
    #[test]
    fn test_inference_config() {
        let config = BitNetConfig::inference();
        assert!(!config.enable_ste);
    }

    /// The training preset keeps STE on.
    #[test]
    fn test_training_config() {
        let config = BitNetConfig::training();
        assert!(config.enable_ste);
    }

    /// Every `with_*` builder stores the value it was given.
    #[test]
    fn test_builder_pattern() {
        let config = BitNetConfig::new()
            .with_group_size(128)
            .with_activation_bits(4)
            .with_per_token_activation(false)
            .with_rms_norm(false)
            .with_ste(false);
        assert_eq!(config.group_size, 128);
        assert_eq!(config.activation_bits, 4);
        assert!(!config.per_token_activation);
        assert!(!config.use_rms_norm);
        assert!(!config.enable_ste);
    }

    /// Defaults validate; zero or non-power-of-two group sizes and zero
    /// activation bits do not.
    #[test]
    fn test_validation() {
        let valid = BitNetConfig::default();
        assert!(valid.validate().is_ok());
        let invalid_group = BitNetConfig {
            group_size: 0,
            ..Default::default()
        };
        assert!(invalid_group.validate().is_err());
        let invalid_bits = BitNetConfig {
            activation_bits: 0,
            ..Default::default()
        };
        assert!(invalid_bits.validate().is_err());
        let non_power_of_two = BitNetConfig {
            group_size: 65,
            ..Default::default()
        };
        assert!(non_power_of_two.validate().is_err());
    }

    /// Values at the edges of the accepted ranges are handled correctly.
    #[test]
    fn test_validation_boundaries() {
        // 1 is a power of two, so it is the smallest accepted group size.
        assert!(BitNetConfig::new().with_group_size(1).validate().is_ok());
        assert!(BitNetConfig::new().with_activation_bits(1).validate().is_ok());
        assert!(BitNetConfig::new().with_activation_bits(16).validate().is_ok());
        assert!(BitNetConfig::new().with_activation_bits(17).validate().is_err());
    }

    /// Zero and negative epsilon values are rejected.
    #[test]
    fn test_validation_eps() {
        let zero_eps = BitNetConfig {
            eps: 0.0,
            ..Default::default()
        };
        assert!(zero_eps.validate().is_err());
        let negative_eps = BitNetConfig {
            eps: -1e-5,
            ..Default::default()
        };
        assert!(negative_eps.validate().is_err());
    }

    /// Both presets produce configurations that pass validation.
    #[test]
    fn test_presets_validate() {
        assert!(BitNetConfig::inference().validate().is_ok());
        assert!(BitNetConfig::training().validate().is_ok());
    }
}