use candle_core::{Device, Tensor};
use serde::{Deserialize, Serialize};
use crate::config::BitNetConfig;
use crate::error::{BitNetError, Result};
/// Integer-quantized activations together with the scales needed to undo the quantization.
///
/// Produced by `quantize_activations` and consumed by `dequantize_activations` in this file.
/// The field order is left untouched because the type derives `Serialize`/`Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedActivations {
/// Quantized values, one `i8` per input element, stored row after row of the flattened
/// `(batch * seq_len, hidden_dim)` view.
pub data: Vec<i8>,
/// One scale per row of the flattened view; a row is recovered as `data * scale`.
pub scales: Vec<f32>,
/// Dimensions of the tensor that was quantized (`[batch, hidden]` or `[batch, seq, hidden]`).
pub shape: Vec<usize>,
/// Copied from `BitNetConfig::per_token_activation` when quantizing.
/// NOTE(review): the quantize/dequantize functions in this file never branch on this flag;
/// they always use one scale per row — confirm whether a per-tensor mode is expected.
pub per_token: bool,
}
impl QuantizedActivations {
    /// Size of the leading (batch) dimension, or `1` when `shape` is empty.
    #[must_use]
    pub fn batch_size(&self) -> usize {
        self.shape.first().copied().unwrap_or(1)
    }

    /// Sequence length of the quantized tensor.
    ///
    /// Only shapes with at least three dimensions carry a sequence axis (`shape[1]`).
    /// A 2D shape is `[batch, hidden]`, which the quantize/dequantize functions treat as a
    /// sequence length of `1`, so `1` is returned there instead of the hidden size.
    #[must_use]
    pub fn seq_len(&self) -> usize {
        match self.shape.as_slice() {
            [_, seq, _, ..] => *seq,
            _ => 1,
        }
    }

    /// Size of the trailing (hidden) dimension, or `0` when `shape` is empty.
    #[must_use]
    pub fn hidden_dim(&self) -> usize {
        self.shape.last().copied().unwrap_or(0)
    }

    /// Total number of elements described by `shape` (the product of all dimensions;
    /// `1` for an empty shape).
    #[must_use]
    pub fn numel(&self) -> usize {
        self.shape.iter().product()
    }
}
/// Quantizes activations to `i8` with absmax scaling, using one scale per row.
///
/// The input is viewed as `(batch * seq_len, hidden_dim)` (a 2D input counts as
/// `seq_len == 1`). For every row the largest magnitude is mapped to
/// `2^(activation_bits - 1) - 1`; each element is divided by the row scale, rounded,
/// clamped to the symmetric range and stored as `i8`. A row that is entirely zero gets a
/// scale of `1.0` so that it dequantizes back to zeros.
///
/// # Errors
///
/// Returns [`BitNetError::InvalidConfig`] when the tensor is not 2D or 3D, or when
/// `config.activation_bits` is outside `2..=8`. Errors from the tensor
/// reshape/dtype/copy operations are propagated.
pub fn quantize_activations(
    activations: &Tensor,
    config: &BitNetConfig,
) -> Result<QuantizedActivations> {
    let shape = activations.shape().dims().to_vec();
    let (batch_size, seq_len, hidden_dim) = match shape.as_slice() {
        &[batch, hidden] => (batch, 1, hidden),
        &[batch, seq, hidden] => (batch, seq, hidden),
        _ => {
            return Err(BitNetError::InvalidConfig(
                "activations must be 2D or 3D".to_string(),
            ))
        }
    };

    // Values are stored as `i8`, so at most 8 bits are representable. Fewer than 2 bits
    // leaves no non-zero level (the scale would become infinite), and 0 bits would
    // overflow the shift below.
    if !(2..=8).contains(&config.activation_bits) {
        return Err(BitNetError::InvalidConfig(format!(
            "activation_bits must be in 2..=8 to fit i8 storage, got {}",
            config.activation_bits
        )));
    }

    let flat = activations.reshape((batch_size * seq_len, hidden_dim))?;
    let flat_f32 = flat.to_dtype(candle_core::DType::F32)?.to_vec2::<f32>()?;

    // Largest representable magnitude, e.g. 127 for 8 bits. It is at most 127 after the
    // check above, so the conversion to `f32` is exact.
    let max_val = (1_i32 << (config.activation_bits - 1)) - 1;
    let max_val_f32 = max_val as f32;

    let mut data = Vec::with_capacity(batch_size * seq_len * hidden_dim);
    let mut scales = Vec::with_capacity(batch_size * seq_len);

    for row in &flat_f32 {
        let abs_max = row.iter().fold(0.0_f32, |acc, x| acc.max(x.abs()));
        let scale = if abs_max > 0.0 {
            abs_max / max_val_f32
        } else {
            // All-zero row: any finite scale works; 1.0 keeps the division well defined.
            1.0
        };
        scales.push(scale);
        data.extend(
            row.iter()
                .map(|&val| (val / scale).round().clamp(-max_val_f32, max_val_f32) as i8),
        );
    }

    Ok(QuantizedActivations {
        data,
        scales,
        shape,
        per_token: config.per_token_activation,
    })
}
/// Rebuilds an `f32` tensor from quantized activations by multiplying every stored value
/// with the scale of its row.
///
/// `quantized.shape` must be `[batch, hidden]` or `[batch, seq, hidden]`; the result has
/// that same shape and is created on `device`.
///
/// # Errors
///
/// Returns [`BitNetError::InvalidConfig`] when the shape is not 2D or 3D, when the element
/// count implied by the shape overflows `usize`, or when `data`/`scales` hold fewer entries
/// than the shape requires (the fields are public and deserializable, so they are not
/// trusted to be consistent). Errors from tensor construction are propagated.
pub fn dequantize_activations(quantized: &QuantizedActivations, device: &Device) -> Result<Tensor> {
    let shape = &quantized.shape;
    let (batch_size, seq_len, hidden_dim) = match shape.as_slice() {
        &[batch, hidden] => (batch, 1, hidden),
        &[batch, seq, hidden] => (batch, seq, hidden),
        _ => {
            return Err(BitNetError::InvalidConfig(
                "invalid shape for dequantization".to_string(),
            ))
        }
    };

    let overflow = || {
        BitNetError::InvalidConfig("shape is too large for dequantization".to_string())
    };
    let num_tokens = batch_size.checked_mul(seq_len).ok_or_else(overflow)?;
    let num_values = num_tokens.checked_mul(hidden_dim).ok_or_else(overflow)?;

    // Reject short buffers up front instead of panicking on an out-of-bounds index.
    if quantized.scales.len() < num_tokens || quantized.data.len() < num_values {
        return Err(BitNetError::InvalidConfig(format!(
            "quantized buffers too short for shape {:?}: need {} scales and {} values, got {} and {}",
            shape,
            num_tokens,
            num_values,
            quantized.scales.len(),
            quantized.data.len()
        )));
    }

    let mut output = Vec::with_capacity(num_values);
    // `chunks_exact` panics on a chunk size of 0; with `hidden_dim == 0` there is nothing
    // to emit anyway.
    if hidden_dim > 0 {
        let rows = quantized.data[..num_values].chunks_exact(hidden_dim);
        for (row, &scale) in rows.zip(&quantized.scales[..num_tokens]) {
            output.extend(row.iter().map(|&q_val| f32::from(q_val) * scale));
        }
    }

    let tensor = Tensor::from_vec(output, shape.clone(), device)?;
    Ok(tensor)
}
/// Runs `activations` through a quantize → dequantize round trip.
///
/// The result is built from `f32` values, has the same dimensions as `activations` and is
/// created on the device of `activations`.
///
/// NOTE(review): the round trip copies values out of the input tensor and builds a fresh
/// tensor with `Tensor::from_vec`, so no gradient link back to `activations` is visible
/// here — confirm whether callers rely on straight-through gradients.
///
/// # Errors
///
/// Propagates any error from `quantize_activations` or `dequantize_activations`.
pub fn quantize_ste(activations: &Tensor, config: &BitNetConfig) -> Result<Tensor> {
    quantize_activations(activations, config)
        .and_then(|packed| dequantize_activations(&packed, activations.device()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 2D input: one scale per row, one value per element, shape kept by the round trip.
    #[test]
    fn test_quantize_dequantize_roundtrip_2d() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();
        let input = Tensor::randn(0.0f32, 1.0, (4, 128), &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();
        assert_eq!(packed.shape, vec![4, 128]);
        assert_eq!(packed.scales.len(), 4);
        assert_eq!(packed.data.len(), 4 * 128);

        let unpacked = dequantize_activations(&packed, &cpu).unwrap();
        assert_eq!(unpacked.shape().dims(), &[4, 128]);
    }

    /// 3D input: one scale per (batch, position) pair, shape kept by the round trip.
    #[test]
    fn test_quantize_dequantize_roundtrip_3d() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();
        let input = Tensor::randn(0.0f32, 1.0, (2, 16, 128), &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();
        assert_eq!(packed.shape, vec![2, 16, 128]);
        assert_eq!(packed.scales.len(), 2 * 16);
        assert_eq!(packed.data.len(), 2 * 16 * 128);

        let unpacked = dequantize_activations(&packed, &cpu).unwrap();
        assert_eq!(unpacked.shape().dims(), &[2, 16, 128]);
    }

    /// Quantized values must never reach -128, i.e. the range stays symmetric.
    #[test]
    fn test_quantization_range() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();
        let ramp: Vec<f32> = (0..64).map(|i| (i as f32 - 32.0) / 10.0).collect();
        let input = Tensor::from_vec(ramp, (1, 64), &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();
        for val in packed.data.iter().copied() {
            assert!(val >= -127, "value {val} below -127");
        }
    }

    /// The quantize/dequantize round trip keeps the shape and stays close to the input.
    #[test]
    fn test_ste_passthrough() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::training();
        let input = Tensor::randn(0.0f32, 1.0, (2, 64), &cpu).unwrap();

        let output = quantize_ste(&input, &cfg).unwrap();
        assert_eq!(output.shape().dims(), input.shape().dims());

        let before: Vec<f32> = input.flatten_all().unwrap().to_vec1().unwrap();
        let after: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        for (o, q) in before.iter().zip(after.iter()) {
            let error = (o - q).abs();
            assert!(error < 0.1, "error {error} too large");
        }
    }

    /// All-zero rows quantize to zeros and fall back to a scale of 1.0.
    #[test]
    fn test_zero_activations() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();
        let input = Tensor::zeros(&[4, 64], candle_core::DType::F32, &cpu).unwrap();

        let packed = quantize_activations(&input, &cfg).unwrap();
        for val in packed.data.iter().copied() {
            assert_eq!(val, 0);
        }
        for scale in packed.scales.iter().copied() {
            assert!((scale - 1.0).abs() < 0.001);
        }
    }

    /// Ranks other than 2 and 3 are rejected.
    #[test]
    fn test_invalid_shape() {
        let cpu = Device::Cpu;
        let cfg = BitNetConfig::default();

        let rank_one = Tensor::zeros(&[64], candle_core::DType::F32, &cpu).unwrap();
        assert!(quantize_activations(&rank_one, &cfg).is_err());

        let rank_four = Tensor::zeros(&[2, 4, 8, 16], candle_core::DType::F32, &cpu).unwrap();
        assert!(quantize_activations(&rank_four, &cfg).is_err());
    }
}