use crate::core::error::{VqError, VqResult};
use crate::core::quantizer::Quantizer;
/// Scalar quantizer that maps every `f32` component to one of two byte codes.
///
/// Components `>= threshold` are encoded as `high`; every other component
/// (including NaN, because any comparison with NaN is false) is encoded as `low`.
/// Construct instances with [`BinaryQuantizer::new`], which rejects a non-finite
/// threshold and requires `low < high`.
#[derive(Debug, Clone, PartialEq)]
pub struct BinaryQuantizer {
    /// Decision boundary; a component exactly equal to it is encoded as `high`.
    threshold: f32,
    /// Code emitted for components below `threshold`.
    low: u8,
    /// Code emitted for components at or above `threshold`.
    high: u8,
}
impl BinaryQuantizer {
pub fn new(threshold: f32, low: u8, high: u8) -> VqResult<Self> {
if !threshold.is_finite() {
return Err(VqError::InvalidParameter {
parameter: "threshold",
reason: "must be finite (not NaN or infinite)".to_string(),
});
}
if low >= high {
return Err(VqError::InvalidParameter {
parameter: "low/high",
reason: "low must be less than high".to_string(),
});
}
Ok(Self {
threshold,
low,
high,
})
}
pub fn threshold(&self) -> f32 {
self.threshold
}
pub fn low(&self) -> u8 {
self.low
}
pub fn high(&self) -> u8 {
self.high
}
}
impl Quantizer for BinaryQuantizer {
    type QuantizedOutput = Vec<u8>;

    /// Encodes each component as `high` when it is `>= threshold`, otherwise as `low`.
    ///
    /// NaN components compare false against the threshold and are therefore
    /// encoded as `low`. Never returns an error; the output has exactly one code
    /// per input component.
    fn quantize(&self, vector: &[f32]) -> VqResult<Self::QuantizedOutput> {
        let mut codes: Vec<u8> = Vec::with_capacity(vector.len());
        for &component in vector {
            // Keep the `>=` form (not `!(component < threshold)`) so NaN maps to `low`.
            let code = if component >= self.threshold {
                self.high
            } else {
                self.low
            };
            codes.push(code);
        }
        Ok(codes)
    }

    /// Decodes each byte back to a float.
    ///
    /// Codes `>= high` decode to `high as f32`. Every other code, including
    /// values strictly between `low` and `high` or below `low`, decodes to
    /// `low as f32`. Never returns an error; the output has one value per code.
    fn dequantize(&self, quantized: &Self::QuantizedOutput) -> VqResult<Vec<f32>> {
        let low_value = f32::from(self.low);
        let high_value = f32::from(self.high);
        let decoded: Vec<f32> = quantized
            .iter()
            .map(|&code| if code < self.high { low_value } else { high_value })
            .collect();
        Ok(decoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basic() {
        let bq = BinaryQuantizer::new(0.0, 0, 1).unwrap();
        let input = [-1.0, 0.0, 1.0, -0.5, 0.5];
        let result = bq.quantize(&input).unwrap();
        // A component exactly equal to the threshold is encoded as `high`.
        assert_eq!(result, vec![0, 1, 1, 0, 1]);
    }

    #[test]
    fn test_large_vector() {
        let bq = BinaryQuantizer::new(0.0, 0, 1).unwrap();
        let input: Vec<f32> = (0..1024).map(|i| (i as f32) - 512.0).collect();
        let result = bq.quantize(&input).unwrap();
        assert_eq!(result.len(), input.len());
        for (&x, &code) in input.iter().zip(&result) {
            let expected = if x >= 0.0 { 1 } else { 0 };
            assert_eq!(code, expected, "input {}", x);
        }
    }

    #[test]
    fn test_invalid_levels() {
        let result = BinaryQuantizer::new(0.0, 1, 0);
        assert!(matches!(
            result,
            Err(VqError::InvalidParameter { parameter: "low/high", .. })
        ));
    }

    #[test]
    fn test_getters() {
        let bq = BinaryQuantizer::new(0.5, 10, 20).unwrap();
        assert_eq!(bq.threshold(), 0.5);
        assert_eq!(bq.low(), 10);
        assert_eq!(bq.high(), 20);
    }

    #[test]
    fn test_invalid_parameters() {
        let result = BinaryQuantizer::new(0.0, 5, 5);
        assert!(result.is_err());
        let result = BinaryQuantizer::new(0.0, 6, 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_input() {
        let bq = BinaryQuantizer::new(0.0, 0, 1).unwrap();
        let input: [f32; 0] = [];
        let result = bq.quantize(&input).unwrap();
        assert!(result.is_empty());
        let empty_codes: Vec<u8> = Vec::new();
        let reconstructed = bq.dequantize(&empty_codes).unwrap();
        assert!(reconstructed.is_empty());
    }

    #[test]
    fn test_nan_threshold_rejected() {
        let result = BinaryQuantizer::new(f32::NAN, 0, 1);
        assert!(matches!(
            result,
            Err(VqError::InvalidParameter { parameter: "threshold", .. })
        ));
    }

    #[test]
    fn test_infinite_threshold_rejected() {
        for threshold in [f32::INFINITY, f32::NEG_INFINITY] {
            let result = BinaryQuantizer::new(threshold, 0, 1);
            assert!(matches!(
                result,
                Err(VqError::InvalidParameter { parameter: "threshold", .. })
            ));
        }
    }

    #[test]
    fn test_roundtrip_canonical_codes() {
        let bq = BinaryQuantizer::new(0.0, 0, 1).unwrap();
        let codes = bq.quantize(&[-2.0, 0.0, 3.5]).unwrap();
        assert_eq!(codes, vec![0, 1, 1]);
        assert_eq!(bq.dequantize(&codes).unwrap(), vec![0.0, 1.0, 1.0]);
    }

    #[test]
    fn test_dequantize_non_canonical_codes() {
        // Codes other than `low`/`high` are accepted: anything >= high decodes to
        // `high`, everything else (between or below the levels) decodes to `low`.
        let bq = BinaryQuantizer::new(0.0, 10, 20).unwrap();
        let codes: Vec<u8> = vec![0, 10, 15, 19, 20, 255];
        assert_eq!(
            bq.dequantize(&codes).unwrap(),
            vec![10.0, 10.0, 10.0, 10.0, 20.0, 20.0]
        );
    }

    #[test]
    fn test_nan_and_infinite_components() {
        // NaN never satisfies `>= threshold`, so it is encoded as `low`.
        let bq = BinaryQuantizer::new(0.0, 3, 7).unwrap();
        let input = [f32::NAN, f32::INFINITY, f32::NEG_INFINITY];
        assert_eq!(bq.quantize(&input).unwrap(), vec![3, 7, 3]);
    }
}