use super::{AsymmetricDistance, Quantizer};
/// 8-bit scalar quantizer: every `f32` dimension is mapped independently to one `u8` code.
///
/// A value `v` in dimension `d` is encoded as `round((v - mins[d]) * inv_scales[d])`,
/// clamped to `0..=255`. A code `c` decodes back to `c * scales[d] + mins[d]`.
#[derive(Debug, Clone)]
pub struct SQ8Quantizer {
    /// Number of dimensions; set to `mins.len()` when the quantizer is constructed.
    dimension: usize,
    /// Per-dimension offset: the value that code `0` decodes to.
    mins: Vec<f32>,
    /// Per-dimension distance between adjacent codes.
    scales: Vec<f32>,
    /// Cached `1.0 / scales[d]` used when encoding, or `0.0` where the scale is
    /// not strictly positive.
    inv_scales: Vec<f32>,
}
/// A vector encoded by [`SQ8Quantizer`]: one `u8` code per dimension.
#[derive(Debug, Clone)]
pub struct SQ8EncodedVector {
    /// Quantized codes, indexed by dimension.
    pub codes: Vec<u8>,
}
impl SQ8Quantizer {
    /// Creates a quantizer from explicit per-dimension offsets and step sizes.
    ///
    /// Code `c` in dimension `d` decodes to `c * scales[d] + mins[d]`. The
    /// reciprocal of every scale is cached for encoding. A scale that is not
    /// strictly positive (zero, negative or NaN) caches `0.0`, which makes every
    /// value in that dimension encode to code `0`.
    ///
    /// `scales` must be as long as `mins`; this is only checked in debug builds.
    pub fn new(mins: Vec<f32>, scales: Vec<f32>) -> Self {
        let dimension = mins.len();
        debug_assert_eq!(scales.len(), dimension);

        let mut inv_scales = Vec::with_capacity(scales.len());
        for &step in &scales {
            inv_scales.push(if step > 0.0 { 1.0 / step } else { 0.0 });
        }

        Self {
            dimension,
            mins,
            scales,
            inv_scales,
        }
    }

    /// Creates a quantizer that spreads `[mins[d], maxs[d]]` over the 256 codes
    /// of each dimension, using a step of `(maxs[d] - mins[d]) / 255`.
    ///
    /// A dimension whose range is not strictly positive (constant, inverted or
    /// NaN) falls back to a step of `1.0`. Encoding its minimum then yields
    /// code `0`, which decodes back to that minimum exactly. Bounds are paired
    /// element-wise, so `mins` and `maxs` should have the same length.
    pub fn from_bounds(mins: Vec<f32>, maxs: Vec<f32>) -> Self {
        let mut scales = Vec::with_capacity(mins.len());
        for (&lo, &hi) in mins.iter().zip(&maxs) {
            let span = hi - lo;
            scales.push(if span > 0.0 { span / 255.0 } else { 1.0 });
        }
        Self::new(mins, scales)
    }

    /// Per-dimension offsets (the value decoded from code `0`).
    #[must_use]
    pub fn mins(&self) -> &[f32] {
        self.mins.as_slice()
    }

    /// Per-dimension step between adjacent codes.
    #[must_use]
    pub fn scales(&self) -> &[f32] {
        self.scales.as_slice()
    }

    /// Maps `value` in dimension `dim` to the nearest code, saturating at `0`
    /// and `255`. A NaN value becomes `0` through the saturating float-to-int cast.
    #[inline]
    fn encode_dim(&self, value: f32, dim: usize) -> u8 {
        let offset = value - self.mins[dim];
        let steps = offset * self.inv_scales[dim];
        steps.clamp(0.0, 255.0).round() as u8
    }

    /// Reconstructs the value that `code` represents in dimension `dim`.
    #[inline]
    fn decode_dim(&self, code: u8, dim: usize) -> f32 {
        let step = self.scales[dim];
        let base = self.mins[dim];
        f32::from(code) * step + base
    }

    /// Squared Euclidean distance between a full-precision `query` and the
    /// dequantized `encoded` vector.
    ///
    /// The final square root is skipped, which preserves the ordering of the
    /// true L2 distance. Input lengths are checked only in debug builds.
    #[inline]
    pub fn asymmetric_l2_squared(&self, query: &[f32], encoded: &SQ8EncodedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(encoded.codes.len(), self.dimension);
        query
            .iter()
            .zip(&encoded.codes)
            .enumerate()
            .fold(0.0f32, |acc, (dim, (&q, &code))| {
                let delta = q - self.decode_dim(code, dim);
                acc + delta * delta
            })
    }
}
impl Quantizer for SQ8Quantizer {
    type Encoded = SQ8EncodedVector;

    /// Learns per-dimension bounds from `vectors` and builds the quantizer with
    /// [`SQ8Quantizer::from_bounds`].
    ///
    /// The dimension is taken from the first vector, and every vector is
    /// expected to share it (checked only in debug builds). An empty training
    /// set yields a zero-dimensional quantizer. `f32::min`/`f32::max` return the
    /// non-NaN operand, so NaN entries do not move the bounds.
    fn train(vectors: &[Vec<f32>]) -> Self {
        let dimension = match vectors.first() {
            Some(first) => first.len(),
            None => return Self::new(vec![], vec![]),
        };

        let (mins, maxs) = vectors.iter().fold(
            (
                vec![f32::INFINITY; dimension],
                vec![f32::NEG_INFINITY; dimension],
            ),
            |(mut lo, mut hi), vector| {
                debug_assert_eq!(vector.len(), dimension, "All vectors must have same dimension");
                for (d, &value) in vector.iter().enumerate() {
                    lo[d] = lo[d].min(value);
                    hi[d] = hi[d].max(value);
                }
                (lo, hi)
            },
        );

        Self::from_bounds(mins, maxs)
    }

    /// Quantizes every element of `vector` to a single `u8` code.
    fn encode(&self, vector: &[f32]) -> SQ8EncodedVector {
        debug_assert_eq!(vector.len(), self.dimension);
        let mut codes = Vec::with_capacity(vector.len());
        for (dim, &value) in vector.iter().enumerate() {
            codes.push(self.encode_dim(value, dim));
        }
        SQ8EncodedVector { codes }
    }

    /// Reconstructs an approximate `f32` vector from its codes.
    fn decode(&self, encoded: &SQ8EncodedVector) -> Vec<f32> {
        debug_assert_eq!(encoded.codes.len(), self.dimension);
        let mut values = Vec::with_capacity(encoded.codes.len());
        for (dim, &code) in encoded.codes.iter().enumerate() {
            values.push(self.decode_dim(code, dim));
        }
        values
    }

    /// Number of dimensions this quantizer was built for.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Bytes per encoded vector: one `u8` per dimension.
    fn encoded_size(&self) -> usize {
        self.dimension * std::mem::size_of::<u8>()
    }
}
impl AsymmetricDistance<SQ8EncodedVector> for SQ8Quantizer {
    /// Euclidean distance between the full-precision `query` and the
    /// dequantized `encoded` vector.
    ///
    /// This is the square root of [`SQ8Quantizer::asymmetric_l2_squared`].
    fn asymmetric_l2(&self, query: &[f32], encoded: &SQ8EncodedVector) -> f32 {
        self.asymmetric_l2_squared(query, encoded).sqrt()
    }

    /// Dot product between `query` and the dequantized `encoded` vector.
    ///
    /// Only the stored side is dequantized; the query is used at full precision.
    fn asymmetric_inner_product(&self, query: &[f32], encoded: &SQ8EncodedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(encoded.codes.len(), self.dimension);
        let mut sum = 0.0f32;
        for (i, (&q, &code)) in query.iter().zip(encoded.codes.iter()).enumerate() {
            let decoded = self.decode_dim(code, i);
            sum += q * decoded;
        }
        sum
    }

    /// Cosine similarity between `query` and the dequantized `encoded` vector.
    ///
    /// Returns `0.0` when either vector has zero norm. Otherwise the result is
    /// clamped to `[-1.0, 1.0]`. In `f32`, `sqrt(a) * sqrt(a)` can round below
    /// `a` (e.g. `a = 2`), which would push parallel or anti-parallel vectors
    /// just outside the valid cosine range.
    fn asymmetric_cosine(&self, query: &[f32], encoded: &SQ8EncodedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        // Same length check as the other distance methods.
        debug_assert_eq!(encoded.codes.len(), self.dimension);
        let mut dot = 0.0f32;
        let mut norm_q_sq = 0.0f32;
        let mut norm_e_sq = 0.0f32;
        for (i, (&q, &code)) in query.iter().zip(encoded.codes.iter()).enumerate() {
            let decoded = self.decode_dim(code, i);
            dot += q * decoded;
            norm_q_sq += q * q;
            norm_e_sq += decoded * decoded;
        }
        let norm_q = norm_q_sq.sqrt();
        let norm_e = norm_e_sq.sqrt();
        if norm_q == 0.0 || norm_e == 0.0 {
            0.0
        } else {
            // `clamp` returns NaN unchanged, so a NaN result is not masked.
            (dot / (norm_q * norm_e)).clamp(-1.0, 1.0)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing original and reconstructed values.
    const EPSILON: f32 = 0.05;

    /// Asserts `|a - b| < EPSILON`, reporting `msg` and the observed difference on failure.
    fn assert_approx_eq(a: f32, b: f32, msg: &str) {
        assert!(
            (a - b).abs() < EPSILON,
            "{}: Expected {} ≈ {}, diff = {}",
            msg,
            a,
            b,
            (a - b).abs()
        );
    }

    /// Small 3-dimensional training set shared by several tests.
    fn create_test_vectors() -> Vec<Vec<f32>> {
        vec![
            vec![0.0, 0.5, 1.0],
            vec![0.1, 0.6, 0.9],
            vec![0.2, 0.4, 0.8],
            vec![-0.1, 0.3, 1.1],
        ]
    }

    #[test]
    fn test_train_and_encode() {
        let vectors = create_test_vectors();
        let quantizer = SQ8Quantizer::train(&vectors);
        assert_eq!(quantizer.dimension(), 3);
        assert!(quantizer.mins()[0] <= 0.0);
        assert!(quantizer.mins()[1] <= 0.4);
        // Exercise the "encode" half of this test: one code per dimension.
        let encoded = quantizer.encode(&vectors[0]);
        assert_eq!(encoded.codes.len(), quantizer.encoded_size());
    }

    #[test]
    fn test_encode_decode_roundtrip() {
        let vectors = create_test_vectors();
        let quantizer = SQ8Quantizer::train(&vectors);
        for vector in &vectors {
            let encoded = quantizer.encode(vector);
            let decoded = quantizer.decode(&encoded);
            for (i, (&original, &reconstructed)) in vector.iter().zip(decoded.iter()).enumerate() {
                assert_approx_eq(original, reconstructed, &format!("Roundtrip dim {}", i));
            }
        }
    }

    #[test]
    fn test_asymmetric_l2() {
        let vectors = create_test_vectors();
        let quantizer = SQ8Quantizer::train(&vectors);
        let query = [0.0, 0.5, 1.0];
        let encoded = quantizer.encode(&query);
        let dist = quantizer.asymmetric_l2(&query, &encoded);
        assert!(dist < 0.1, "Self-distance should be near 0, got {}", dist);
    }

    #[test]
    fn test_asymmetric_inner_product() {
        let vectors = vec![
            vec![0.0, 0.0, 0.0],
            vec![1.0, 1.0, 1.0],
            vec![0.5, 0.5, 0.5],
        ];
        let quantizer = SQ8Quantizer::train(&vectors);
        let query = [1.0, 0.0, 0.0];
        let vector = [1.0, 0.0, 0.0];
        let encoded = quantizer.encode(&vector);
        let ip = quantizer.asymmetric_inner_product(&query, &encoded);
        assert!(ip > 0.9, "Parallel vectors should have IP near 1.0, got {}", ip);
    }

    #[test]
    fn test_asymmetric_cosine() {
        let vectors = vec![
            vec![0.0, 0.0, 0.0],
            vec![1.0, 1.0, 1.0],
            vec![0.5, 0.5, 0.5],
        ];
        let quantizer = SQ8Quantizer::train(&vectors);
        let query = [1.0, 0.0, 0.0];
        let encoded_parallel = quantizer.encode(&[1.0, 0.0, 0.0]);
        let encoded_ortho = quantizer.encode(&[0.0, 1.0, 0.0]);
        let cos_parallel = quantizer.asymmetric_cosine(&query, &encoded_parallel);
        let cos_ortho = quantizer.asymmetric_cosine(&query, &encoded_ortho);
        assert!(cos_parallel > 0.9, "Parallel should have cosine near 1.0");
        assert!(cos_ortho.abs() < 0.1, "Orthogonal should have cosine near 0.0");
    }

    #[test]
    fn test_compression_ratio() {
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..512).map(|j| (i * j) as f32 * 0.001).collect())
            .collect();
        let quantizer = SQ8Quantizer::train(&vectors);
        assert_eq!(quantizer.encoded_size(), 512);
        assert!((quantizer.compression_ratio() - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_empty_vectors() {
        let vectors: Vec<Vec<f32>> = vec![];
        let quantizer = SQ8Quantizer::train(&vectors);
        assert_eq!(quantizer.dimension(), 0);
    }

    #[test]
    fn test_constant_dimension() {
        let vectors = vec![
            vec![5.0, 1.0, 2.0],
            vec![5.0, 2.0, 3.0],
            vec![5.0, 3.0, 4.0],
        ];
        let quantizer = SQ8Quantizer::train(&vectors);
        let encoded = quantizer.encode(&[5.0, 2.0, 3.0]);
        let decoded = quantizer.decode(&encoded);
        assert!((decoded[0] - 5.0).abs() < 0.1);
    }

    #[test]
    fn test_high_dimension() {
        let dim = 1536;
        let vectors: Vec<Vec<f32>> = (0..100)
            .map(|i| (0..dim).map(|j| ((i * j) as f32 * 0.0001).sin()).collect())
            .collect();
        let quantizer = SQ8Quantizer::train(&vectors);
        let query: Vec<f32> = (0..dim).map(|i| (i as f32 * 0.001).cos()).collect();
        let encoded = quantizer.encode(&query);
        let l2 = quantizer.asymmetric_l2(&query, &encoded);
        let ip = quantizer.asymmetric_inner_product(&query, &encoded);
        let cos = quantizer.asymmetric_cosine(&query, &encoded);
        assert!(l2.is_finite());
        assert!(ip.is_finite());
        assert!(cos.is_finite() && (-1.0..=1.0).contains(&cos));
    }
}