use super::ternary_tensor::{pack_ternary, TernaryTensor};
use crate::error::{Result, RuvLLMError};
/// Configuration for post-training BitNet-style (ternary / 1.58-bit) quantization.
#[derive(Debug, Clone)]
pub struct PtBitnetConfig {
    /// Number of calibration samples (consumed by the calibration pass; not
    /// referenced in this file).
    pub calibration_samples: usize,
    /// Elements per quantization block; each block gets its own scale
    /// (see `quantize_tensor`).
    pub block_size: usize,
    /// Whether to refine per-block scales during calibration — TODO confirm
    /// against the calibration code; unused here.
    pub optimize_scales: bool,
    /// Which layers to ternarize (experts only, all, or an explicit name list).
    pub layers_to_quantize: LayerMask,
    /// Packed ternary format to emit on export.
    pub export_format: TernaryFormat,
    /// Floating-point precision kept for the router (presumably layers left
    /// unquantized — verify against the exporter).
    pub router_precision: Precision,
    /// Memory-map source tensors instead of loading them eagerly.
    pub use_mmap: bool,
    /// Run calibration on Metal; defaults on only for macOS builds with the
    /// `metal-compute` feature (see `Default`).
    pub use_metal_calibration: bool,
    /// Soft memory cap during quantization, in GB.
    pub max_memory_gb: usize,
}
impl Default for PtBitnetConfig {
fn default() -> Self {
Self {
calibration_samples: 1000,
block_size: 256,
optimize_scales: true,
layers_to_quantize: LayerMask::ExpertsOnly,
export_format: TernaryFormat::BitnetT158,
router_precision: Precision::FP16,
use_mmap: true,
use_metal_calibration: cfg!(all(target_os = "macos", feature = "metal-compute")),
max_memory_gb: 64,
}
}
}
/// Selects which layers of the model get quantized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LayerMask {
    /// Quantize only expert weights (presumably MoE experts, leaving the
    /// router in float — confirm against the layer-selection code).
    ExpertsOnly,
    /// Quantize every eligible layer.
    All,
    /// Quantize only the layers whose names appear in the list.
    Custom(Vec<String>),
}
/// Packed storage format for exported ternary weights.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TernaryFormat {
    /// BitNet b1.58 ternary packing.
    BitnetT158,
    /// Presumably llama.cpp's IQ1_S format — confirm against the exporter.
    IQ1S,
}
/// Floating-point precision for weights kept unquantized (e.g. the router).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// IEEE 754 half precision (16-bit).
    FP16,
    /// bfloat16 (truncated single precision).
    BF16,
    /// IEEE 754 single precision (32-bit).
    FP32,
}
/// Quantizes one block of weights to ternary codes in {-1, 0, +1} using the
/// BitNet "absmean" rule: divide each weight by the block's mean absolute
/// value (gamma), clamp to [-1, 1], then round to the nearest integer.
///
/// Returns the ternary codes and the scale `gamma`. An empty block yields no
/// codes and the epsilon scale `1e-8`.
pub fn absmean_ternary(block: &[f32]) -> (Vec<i8>, f32) {
    let n = block.len();
    if n == 0 {
        return (Vec::new(), 1e-8);
    }

    // Mean absolute value; the epsilon keeps an all-zero block from
    // producing a zero divisor (all codes then round to 0).
    let mut abs_total = 0.0f32;
    for &w in block {
        abs_total += w.abs();
    }
    let gamma = abs_total / n as f32 + 1e-8;

    let mut codes = Vec::with_capacity(n);
    for &w in block {
        codes.push((w / gamma).clamp(-1.0, 1.0).round() as i8);
    }
    (codes, gamma)
}
/// Quantizes a 2-D weight tensor into block-wise ternary form.
///
/// The flat `weights` slice (`rows * cols` elements) is split into blocks of
/// `config.block_size` elements; each block is ternarized with
/// [`absmean_ternary`] (one scale per block) and the codes are packed with
/// `pack_ternary`.
///
/// # Errors
/// Returns `RuvLLMError::Model` when the shape has a zero dimension,
/// `block_size` is zero, the element count overflows `usize`, or
/// `weights.len()` does not match the shape.
pub fn quantize_tensor(
    weights: &[f32],
    shape: (usize, usize),
    config: &PtBitnetConfig,
) -> Result<TernaryTensor> {
    let (rows, cols) = shape;
    if rows == 0 || cols == 0 {
        return Err(RuvLLMError::Model(format!(
            "Invalid tensor shape: dimensions must be non-zero, got {:?}",
            shape
        )));
    }

    let block_size = config.block_size;
    if block_size == 0 {
        return Err(RuvLLMError::Model(
            "block_size must be non-zero".to_string(),
        ));
    }

    // Guard the rows * cols product before trusting it for allocation sizes.
    let total_elements = rows.checked_mul(cols).ok_or_else(|| {
        RuvLLMError::Model(format!(
            "Integer overflow computing total elements for shape {:?}",
            shape
        ))
    })?;
    if weights.len() != total_elements {
        return Err(RuvLLMError::Model(format!(
            "Weight size mismatch: expected {} elements for shape {:?}, got {}",
            total_elements,
            shape,
            weights.len()
        )));
    }

    // Ceiling division for the block count; the final block may be short.
    let num_blocks = total_elements.checked_add(block_size - 1).ok_or_else(|| {
        RuvLLMError::Model("Integer overflow in block count calculation".to_string())
    })? / block_size;

    let mut all_ternary = Vec::with_capacity(total_elements);
    let mut scales = Vec::with_capacity(num_blocks);
    // `chunks` yields `num_blocks` windows of at most `block_size` elements,
    // including the short tail block when the size doesn't divide evenly.
    for block in weights.chunks(block_size) {
        let (codes, gamma) = absmean_ternary(block);
        all_ternary.extend_from_slice(&codes);
        scales.push(gamma);
    }

    Ok(TernaryTensor {
        packed_data: pack_ternary(&all_ternary),
        scales,
        shape,
        block_size,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Mixed-sign block: every code stays in {-1, 0, 1}, the scale is
    /// strictly positive, and representative inputs map to the expected code.
    #[test]
    fn test_absmean_ternary_simple() {
        let block = vec![0.5, -0.5, 0.0, 1.0, -1.0, 0.25];
        let (ternary, scale) = absmean_ternary(&block);
        assert!(ternary.iter().all(|&v| v >= -1 && v <= 1));
        assert!(scale > 0.0);
        assert_eq!(ternary[0], 1);
        assert_eq!(ternary[1], -1);
        assert_eq!(ternary[2], 0);
    }

    /// All-zero block: every code is 0 and the scale collapses to the
    /// epsilon floor (below 1e-7 but still strictly positive).
    #[test]
    fn test_absmean_ternary_all_zeros() {
        let block = vec![0.0; 256];
        let (ternary, scale) = absmean_ternary(&block);
        assert!(ternary.iter().all(|&v| v == 0));
        assert!(scale < 1e-7 && scale > 0.0);
    }

    /// Values well above the mean magnitude saturate to +/-1 via clamping.
    #[test]
    fn test_absmean_ternary_large_values() {
        let block = vec![10.0, -10.0, 5.0, -5.0];
        let (ternary, _scale) = absmean_ternary(&block);
        assert!(ternary[0] == 1 || ternary[0] == -1);
        assert!(ternary[1] == 1 || ternary[1] == -1);
    }

    /// 512 elements at the default block_size of 256 => 2 blocks / 2 scales,
    /// and 128 packed bytes (implying 4 codes per byte in `pack_ternary`).
    #[test]
    fn test_quantize_tensor_simple() {
        let weights = vec![0.5; 512];
        let shape = (2, 256);
        let config = PtBitnetConfig::default();
        let ternary = quantize_tensor(&weights, shape, &config).unwrap();
        assert_eq!(ternary.shape, shape);
        assert_eq!(ternary.block_size, 256);
        assert_eq!(ternary.num_blocks(), 2);
        assert_eq!(ternary.scales.len(), 2);
        assert_eq!(ternary.packed_data.len(), 128);
    }

    /// A weights slice shorter than rows * cols must be rejected.
    #[test]
    fn test_quantize_tensor_size_mismatch() {
        let weights = vec![0.5; 100];
        let shape = (2, 256);
        let config = PtBitnetConfig::default();
        let result = quantize_tensor(&weights, shape, &config);
        assert!(result.is_err());
    }

    /// Ternary packing should land between 10x and 20x smaller than the
    /// original f32 storage (2-bit codes plus per-block scale overhead).
    #[test]
    fn test_quantize_tensor_memory_savings() {
        let weights = vec![0.5; 256 * 1024];
        let shape = (512, 512);
        let config = PtBitnetConfig::default();
        let ternary = quantize_tensor(&weights, shape, &config).unwrap();
        let original_bytes = weights.len() * 4;
        let compressed_bytes = ternary.memory_bytes();
        let compression_ratio = original_bytes as f32 / compressed_bytes as f32;
        assert!(compression_ratio > 10.0);
        assert!(compression_ratio < 20.0);
    }

    /// Spot-checks the Default impl's documented baseline values.
    #[test]
    fn test_config_default() {
        let config = PtBitnetConfig::default();
        assert_eq!(config.block_size, 256);
        assert_eq!(config.calibration_samples, 1000);
        assert!(config.optimize_scales);
        assert_eq!(config.layers_to_quantize, LayerMask::ExpertsOnly);
    }

    /// The three LayerMask variants are mutually distinct under PartialEq.
    #[test]
    fn test_layer_mask_variants() {
        let experts = LayerMask::ExpertsOnly;
        let all = LayerMask::All;
        let custom = LayerMask::Custom(vec!["layer.0".to_string()]);
        assert_ne!(experts, all);
        assert_ne!(all, custom);
        assert_ne!(experts, custom);
    }
}