// turboquant 0.1.1
//
// Implementation of Google's TurboQuant algorithm for vector quantization.
// Documentation
use serde::{Deserialize, Serialize};

use crate::codebook::Codebook;
use crate::error::{Result, TurboQuantError};

/// A scalar quantizer backed by a Lloyd-Max codebook optimized for the
/// coordinate distribution of unit-sphere vectors.
///
/// The type is a thin wrapper around a [`Codebook`]: the scalar
/// quantize/dequantize methods forward directly to the codebook, and the
/// batch helpers are built on top of those scalar methods.
///
/// It derives `Serialize`/`Deserialize`, so a quantizer (including its
/// codebook) can be persisted and restored through serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScalarQuantizer {
    /// Codebook that every quantize/dequantize call on this type delegates to.
    pub codebook: Codebook,
}

impl ScalarQuantizer {
    /// Create a quantizer from a pre-computed codebook.
    ///
    /// The codebook is stored as-is; no validation is performed here.
    pub fn from_codebook(codebook: Codebook) -> Self {
        Self { codebook }
    }

    /// Quantize a single scalar value to its codebook index.
    ///
    /// Forwards to the codebook's `quantize_scalar`. How values that the
    /// codebook cannot represent (e.g. NaN) are treated is decided by the
    /// codebook; use [`ScalarQuantizer::checked_quantize_scalar`] to get a
    /// `Result` instead.
    pub fn quantize_scalar(&self, value: f64) -> u8 {
        self.codebook.quantize_scalar(value)
    }

    /// Checked variant of [`ScalarQuantizer::quantize_scalar`].
    ///
    /// # Errors
    ///
    /// Propagates whatever error the codebook's `checked_quantize_scalar`
    /// returns (the unit tests expect `TurboQuantError::InvalidValue` for NaN).
    pub fn checked_quantize_scalar(&self, value: f64) -> Result<u8> {
        self.codebook.checked_quantize_scalar(value)
    }

    /// Dequantize a codebook index back to its centroid value.
    ///
    /// Forwards to the codebook's `dequantize_scalar`. Handling of an index
    /// outside the codebook is decided by the codebook; use
    /// [`ScalarQuantizer::checked_dequantize_scalar`] or
    /// [`ScalarQuantizer::validate_indices`] when the index is untrusted.
    pub fn dequantize_scalar(&self, index: u8) -> f64 {
        self.codebook.dequantize_scalar(index)
    }

    /// Checked variant of [`ScalarQuantizer::dequantize_scalar`].
    ///
    /// # Errors
    ///
    /// Propagates whatever error the codebook's `checked_dequantize_scalar`
    /// returns (the unit tests expect
    /// `TurboQuantError::InvalidQuantizationIndex` for an out-of-range index).
    pub fn checked_dequantize_scalar(&self, index: u8) -> Result<f64> {
        self.codebook.checked_dequantize_scalar(index)
    }

    /// Quantize a batch of values to their indices.
    ///
    /// The output has the same length and order as `values`; each element is
    /// produced by [`ScalarQuantizer::quantize_scalar`].
    pub fn quantize_batch(&self, values: &[f64]) -> Vec<u8> {
        values.iter().map(|&v| self.quantize_scalar(v)).collect()
    }

    /// Dequantize a batch of indices back to centroid values.
    ///
    /// The output has the same length and order as `indices`; each element is
    /// produced by [`ScalarQuantizer::dequantize_scalar`].
    pub fn dequantize_batch(&self, indices: &[u8]) -> Vec<f64> {
        indices.iter().map(|&i| self.dequantize_scalar(i)).collect()
    }

    /// Compute mean squared error of quantization over a batch.
    ///
    /// Each value is quantized, reconstructed, and compared with the
    /// original; the squared differences are averaged over `values.len()`.
    ///
    /// Returns 0.0 for empty input.
    pub fn mse(&self, values: &[f64]) -> f64 {
        if values.is_empty() {
            return 0.0;
        }
        // Round-trip each value in a single pass; the summation order matches
        // the slice order, so no intermediate index/reconstruction buffers are
        // needed.
        let sq_err: f64 = values
            .iter()
            .map(|&v| {
                let r = self.dequantize_scalar(self.quantize_scalar(v));
                (v - r) * (v - r)
            })
            .sum();
        sq_err / values.len() as f64
    }

    /// Validate that indices are within range for this codebook.
    ///
    /// An index is valid when it is strictly less than the codebook's number
    /// of levels. An empty `indices` slice is always valid.
    ///
    /// # Errors
    ///
    /// Returns `TurboQuantError::InvalidQuantizationIndex` describing the
    /// first offending index, the largest valid index, and the codebook's
    /// bit width.
    pub fn validate_indices(&self, indices: &[u8]) -> Result<()> {
        let num_levels = self.codebook.num_levels();
        // Subtract in the level count's own integer type *before* narrowing to
        // `u8`. Narrowing first turns a 256-level (8-bit) codebook into `0u8`,
        // and the following `- 1` would then overflow.
        let max_idx = u8::try_from(num_levels.saturating_sub(1)).unwrap_or(u8::MAX);
        // A codebook without levels has no valid index at all, not even 0.
        let out_of_range = |idx: u8| num_levels == 0 || idx > max_idx;

        match indices.iter().copied().find(|&idx| out_of_range(idx)) {
            Some(index) => Err(TurboQuantError::InvalidQuantizationIndex {
                index,
                max: max_idx,
                bit_width: self.codebook.bit_width,
            }),
            None => Ok(()),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::codebook::generate_codebook;

    /// Scalar probes spread over most of the [-1, 1] coordinate range.
    const PROBES: [f64; 5] = [-0.9, -0.5, 0.0, 0.5, 0.9];

    /// Build a quantizer from a freshly generated codebook for the given
    /// dimension and bit width (50 is passed as the third generator argument).
    fn make_quantizer(dim: usize, bits: u8) -> ScalarQuantizer {
        generate_codebook(dim, bits, 50)
            .map(ScalarQuantizer::from_codebook)
            .unwrap()
    }

    /// Every probe must map to one of the four 2-bit indices.
    #[test]
    fn test_quantize_scalar_in_range() {
        let quantizer = make_quantizer(64, 2);
        let highest = (1u8 << 2) - 1;
        for &probe in PROBES.iter() {
            let idx = quantizer.quantize_scalar(probe);
            assert!(idx <= highest, "index {} out of range for 2-bit", idx);
        }
    }

    /// Quantize-then-dequantize should stay close to the input.
    #[test]
    fn test_batch_roundtrip() {
        // For dim=128 the Beta distribution ≈ N(0, 1/128), σ≈0.088.
        // The inputs stay within ±3σ ≈ ±0.265, where the codebook has density.
        let quantizer = make_quantizer(128, 4);
        let inputs: Vec<f64> = (0..16_i32)
            .map(|step| (f64::from(step) - 7.5) * 0.03)
            .collect();
        let restored = quantizer.dequantize_batch(&quantizer.quantize_batch(&inputs));
        // 4-bit resolution over the ±0.265 range: each value must come back
        // within 0.1 of where it started.
        for (v, r) in inputs.iter().zip(&restored) {
            assert!((v - r).abs() < 0.1, "v={}, r={}", v, r);
        }
    }

    /// Indices 0..=3 are valid for 2 bits; 4 is the first invalid one.
    #[test]
    fn test_validate_indices_out_of_range() {
        let quantizer = make_quantizer(64, 2);
        let all_valid = quantizer.validate_indices(&[0, 1, 2, 3]);
        let one_past_end = quantizer.validate_indices(&[4]);
        assert!(all_valid.is_ok());
        assert!(one_past_end.is_err());
    }

    /// An empty batch has, by definition, zero error.
    #[test]
    fn test_mse_empty_input() {
        let quantizer = make_quantizer(64, 2);
        let error = quantizer.mse(&[]);
        assert_eq!(error, 0.0);
    }

    /// More bits per coordinate must yield a lower reconstruction error.
    #[test]
    fn test_mse_decreases_with_bits() {
        let samples: Vec<f64> = (0..1000).map(|i| ((i as f64) * 0.1).sin()).collect();
        let mse_at = |bits: u8| make_quantizer(128, bits).mse(&samples);
        let mse2 = mse_at(2);
        let mse4 = mse_at(4);
        assert!(
            mse4 < mse2,
            "4-bit MSE {} should be less than 2-bit MSE {}",
            mse4,
            mse2
        );
    }

    /// The checked variants report errors instead of relying on the
    /// unchecked behaviour.
    #[test]
    fn test_checked_scalar_methods() {
        let quantizer = make_quantizer(64, 2);

        let finite = quantizer.checked_quantize_scalar(0.0);
        assert!(finite.is_ok());

        let not_a_number = quantizer.checked_quantize_scalar(f64::NAN);
        assert!(matches!(
            not_a_number,
            Err(TurboQuantError::InvalidValue { .. })
        ));

        let bad_index = quantizer.checked_dequantize_scalar(4);
        assert!(matches!(
            bad_index,
            Err(TurboQuantError::InvalidQuantizationIndex { .. })
        ));
    }
}