use blake3::hash;
use serde::{Deserialize, Serialize};
use crate::vec::VecSearchHit;
use crate::{MemvidError, Result, types::FrameId};
fn vec_config() -> impl bincode::config::Config {
bincode::config::standard()
.with_fixed_int_encoding()
.with_little_endian()
}
#[allow(clippy::cast_possible_truncation)]
const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
const NUM_SUBSPACES: usize = 96; const SUBSPACE_DIM: usize = 4; const NUM_CENTROIDS: usize = 256; const TOTAL_DIM: usize = NUM_SUBSPACES * SUBSPACE_DIM;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubspaceCodebook {
centroids: Vec<f32>,
}
impl SubspaceCodebook {
fn new() -> Self {
Self {
centroids: vec![0.0; NUM_CENTROIDS * SUBSPACE_DIM],
}
}
fn get_centroid(&self, index: u8) -> &[f32] {
let start = (index as usize) * SUBSPACE_DIM;
&self.centroids[start..start + SUBSPACE_DIM]
}
fn set_centroid(&mut self, index: u8, values: &[f32]) {
assert_eq!(values.len(), SUBSPACE_DIM);
let start = (index as usize) * SUBSPACE_DIM;
self.centroids[start..start + SUBSPACE_DIM].copy_from_slice(values);
}
fn quantize(&self, subspace: &[f32]) -> u8 {
assert_eq!(subspace.len(), SUBSPACE_DIM);
let mut best_idx = 0u8;
let mut best_dist = f32::INFINITY;
for i in 0..NUM_CENTROIDS {
#[allow(clippy::cast_possible_truncation)]
let centroid = self.get_centroid(i as u8);
let dist = l2_distance_squared(subspace, centroid);
if dist < best_dist {
best_dist = dist;
#[allow(clippy::cast_possible_truncation)]
{
best_idx = i as u8;
}
}
}
best_idx
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
codebooks: Vec<SubspaceCodebook>,
dimension: u32,
}
impl ProductQuantizer {
pub fn new(dimension: u32) -> Result<Self> {
if dimension as usize != TOTAL_DIM {
return Err(MemvidError::InvalidQuery {
reason: format!("PQ only supports {TOTAL_DIM}-dim vectors, got {dimension}"),
});
}
Ok(Self {
codebooks: vec![SubspaceCodebook::new(); NUM_SUBSPACES],
dimension,
})
}
pub fn train(&mut self, training_vectors: &[Vec<f32>], max_iterations: usize) -> Result<()> {
if training_vectors.is_empty() {
return Err(MemvidError::InvalidQuery {
reason: "Cannot train PQ with empty training set".to_string(),
});
}
for vec in training_vectors {
if vec.len() != TOTAL_DIM {
return Err(MemvidError::InvalidQuery {
reason: format!(
"Training vector has wrong dimension: expected {}, got {}",
TOTAL_DIM,
vec.len()
),
});
}
}
for subspace_idx in 0..NUM_SUBSPACES {
let start_dim = subspace_idx * SUBSPACE_DIM;
let end_dim = start_dim + SUBSPACE_DIM;
let subspace_vecs: Vec<Vec<f32>> = training_vectors
.iter()
.map(|v| v[start_dim..end_dim].to_vec())
.collect();
let centroids = kmeans(&subspace_vecs, NUM_CENTROIDS, max_iterations)?;
for (i, centroid) in centroids.iter().enumerate() {
#[allow(clippy::cast_possible_truncation)]
self.codebooks[subspace_idx].set_centroid(i as u8, centroid);
}
}
Ok(())
}
pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
if vector.len() != TOTAL_DIM {
return Err(MemvidError::InvalidQuery {
reason: format!(
"Vector dimension mismatch: expected {}, got {}",
TOTAL_DIM,
vector.len()
),
});
}
let mut codes = Vec::with_capacity(NUM_SUBSPACES);
for subspace_idx in 0..NUM_SUBSPACES {
let start_dim = subspace_idx * SUBSPACE_DIM;
let end_dim = start_dim + SUBSPACE_DIM;
let subspace = &vector[start_dim..end_dim];
let code = self.codebooks[subspace_idx].quantize(subspace);
codes.push(code);
}
Ok(codes)
}
pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
if codes.len() != NUM_SUBSPACES {
return Err(MemvidError::InvalidQuery {
reason: format!(
"Invalid PQ codes length: expected {}, got {}",
NUM_SUBSPACES,
codes.len()
),
});
}
let mut vector = Vec::with_capacity(TOTAL_DIM);
for (subspace_idx, &code) in codes.iter().enumerate() {
let centroid = self.codebooks[subspace_idx].get_centroid(code);
vector.extend_from_slice(centroid);
}
Ok(vector)
}
#[must_use]
pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
if query.len() != TOTAL_DIM || codes.len() != NUM_SUBSPACES {
return f32::INFINITY;
}
let mut total_dist_sq = 0.0f32;
for subspace_idx in 0..NUM_SUBSPACES {
let start_dim = subspace_idx * SUBSPACE_DIM;
let end_dim = start_dim + SUBSPACE_DIM;
let query_subspace = &query[start_dim..end_dim];
let code = codes[subspace_idx];
let centroid = self.codebooks[subspace_idx].get_centroid(code);
total_dist_sq += l2_distance_squared(query_subspace, centroid);
}
total_dist_sq.sqrt()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVecDocument {
pub frame_id: FrameId,
pub codes: Vec<u8>,
}
#[derive(Default)]
pub struct QuantizedVecIndexBuilder {
documents: Vec<QuantizedVecDocument>,
quantizer: Option<ProductQuantizer>,
}
impl QuantizedVecIndexBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn train_quantizer(&mut self, training_vectors: &[Vec<f32>], dimension: u32) -> Result<()> {
let mut pq = ProductQuantizer::new(dimension)?;
pq.train(training_vectors, 25)?; self.quantizer = Some(pq);
Ok(())
}
pub fn add_document(&mut self, frame_id: FrameId, embedding: Vec<f32>) -> Result<()> {
let quantizer = self
.quantizer
.as_ref()
.ok_or_else(|| MemvidError::InvalidQuery {
reason: "Quantizer not trained. Call train_quantizer first".to_string(),
})?;
let codes = quantizer.encode(&embedding)?;
self.documents
.push(QuantizedVecDocument { frame_id, codes });
Ok(())
}
pub fn finish(self) -> Result<QuantizedVecIndexArtifact> {
let quantizer = self.quantizer.ok_or_else(|| MemvidError::InvalidQuery {
reason: "Quantizer not trained".to_string(),
})?;
let vector_count = self.documents.len() as u64;
let bytes =
bincode::serde::encode_to_vec(&(quantizer.clone(), self.documents), vec_config())?;
let checksum = *hash(&bytes).as_bytes();
Ok(QuantizedVecIndexArtifact {
bytes,
vector_count,
dimension: quantizer.dimension,
checksum,
compression_ratio: 16.0, })
}
}
#[derive(Debug, Clone)]
pub struct QuantizedVecIndexArtifact {
pub bytes: Vec<u8>,
pub vector_count: u64,
pub dimension: u32,
pub checksum: [u8; 32],
pub compression_ratio: f64,
}
#[derive(Debug, Clone)]
pub struct QuantizedVecIndex {
quantizer: ProductQuantizer,
documents: Vec<QuantizedVecDocument>,
}
impl QuantizedVecIndex {
pub fn decode(bytes: &[u8]) -> Result<Self> {
let config = bincode::config::standard()
.with_fixed_int_encoding()
.with_little_endian()
.with_limit::<VEC_DECODE_LIMIT>();
if let Ok(((quantizer, documents), read)) = bincode::serde::decode_from_slice::<
(ProductQuantizer, Vec<QuantizedVecDocument>),
_,
>(bytes, config)
{
if read == bytes.len() {
return Ok(Self {
quantizer,
documents,
});
}
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct OldProductQuantizer {
codebooks: Vec<SubspaceCodebook>,
}
let ((old_quantizer, documents), read): (
(OldProductQuantizer, Vec<QuantizedVecDocument>),
usize,
) = bincode::serde::decode_from_slice(bytes, config)?;
if read != bytes.len() {
return Err(MemvidError::InvalidToc {
reason: "unsupported quantized vector index encoding".into(),
});
}
let quantizer = ProductQuantizer {
codebooks: old_quantizer.codebooks,
dimension: u32::try_from(NUM_SUBSPACES * SUBSPACE_DIM).unwrap_or(u32::MAX),
};
Ok(Self {
quantizer,
documents,
})
}
#[must_use]
pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
if query.is_empty() {
return Vec::new();
}
let mut hits: Vec<VecSearchHit> = self
.documents
.iter()
.map(|doc| {
let distance = self.quantizer.asymmetric_distance(query, &doc.codes);
VecSearchHit {
frame_id: doc.frame_id,
distance,
}
})
.collect();
hits.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
hits.truncate(limit);
hits
}
pub fn remove(&mut self, frame_id: FrameId) {
self.documents.retain(|doc| doc.frame_id != frame_id);
}
#[must_use]
pub fn compression_stats(&self) -> CompressionStats {
let original_bytes = self.documents.len() * TOTAL_DIM * std::mem::size_of::<f32>();
let compressed_bytes = self.documents.len() * NUM_SUBSPACES; let codebook_bytes =
NUM_SUBSPACES * NUM_CENTROIDS * SUBSPACE_DIM * std::mem::size_of::<f32>();
CompressionStats {
vector_count: self.documents.len() as u64,
original_bytes: original_bytes as u64,
compressed_bytes: compressed_bytes as u64,
codebook_bytes: codebook_bytes as u64,
total_bytes: (compressed_bytes + codebook_bytes) as u64,
compression_ratio: original_bytes as f64 / (compressed_bytes + codebook_bytes) as f64,
}
}
}
#[derive(Debug, Clone)]
pub struct CompressionStats {
pub vector_count: u64,
pub original_bytes: u64,
pub compressed_bytes: u64,
pub codebook_bytes: u64,
pub total_bytes: u64,
pub compression_ratio: f64,
}
fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
if vectors.is_empty() {
return Err(MemvidError::InvalidQuery {
reason: "Cannot run k-means on empty vector set".to_string(),
});
}
let dim = vectors[0].len();
let mut centroids = kmeans_plus_plus_init(vectors, k)?;
for _iteration in 0..max_iterations {
let mut assignments = vec![Vec::new(); k];
for vec in vectors {
let mut best_cluster = 0;
let mut best_dist = f32::INFINITY;
for (cluster_idx, centroid) in centroids.iter().enumerate() {
let dist = l2_distance_squared(vec, centroid);
if dist < best_dist {
best_dist = dist;
best_cluster = cluster_idx;
}
}
assignments[best_cluster].push(vec.clone());
}
let mut changed = false;
for (cluster_idx, cluster_vecs) in assignments.iter().enumerate() {
if cluster_vecs.is_empty() {
centroids[cluster_idx] = vectors[cluster_idx % vectors.len()].clone();
changed = true;
continue;
}
let mut new_centroid = vec![0.0f32; dim];
for vec in cluster_vecs {
for (i, &val) in vec.iter().enumerate() {
new_centroid[i] += val;
}
}
for val in &mut new_centroid {
*val /= cluster_vecs.len() as f32;
}
if l2_distance_squared(¢roids[cluster_idx], &new_centroid) > 1e-6 {
changed = true;
}
centroids[cluster_idx] = new_centroid;
}
if !changed {
break; }
}
Ok(centroids)
}
fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
if vectors.is_empty() || k == 0 {
return Err(MemvidError::InvalidQuery {
reason: "Invalid k-means++ initialization".to_string(),
});
}
let mut centroids = Vec::new();
centroids.push(vectors[0].clone());
for _ in 1..k {
let mut distances = Vec::new();
for vec in vectors {
let mut min_dist = f32::INFINITY;
for centroid in ¢roids {
let dist = l2_distance_squared(vec, centroid);
min_dist = min_dist.min(dist);
}
distances.push(min_dist);
}
let max_idx = distances
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map_or(0, |(idx, _)| idx);
centroids.push(vectors[max_idx].clone());
}
Ok(centroids)
}
fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
crate::simd::l2_distance_squared_simd(a, b)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_subspace_codebook() {
let mut codebook = SubspaceCodebook::new();
codebook.set_centroid(0, &[1.0, 2.0, 3.0, 4.0]);
let centroid = codebook.get_centroid(0);
assert_eq!(centroid, &[1.0, 2.0, 3.0, 4.0]);
let code = codebook.quantize(&[1.1, 2.1, 3.1, 4.1]);
assert_eq!(code, 0);
}
#[test]
fn test_product_quantizer_roundtrip() {
let mut training_vecs = Vec::new();
for i in 0..100 {
let mut vec = vec![0.0f32; TOTAL_DIM];
for j in 0..TOTAL_DIM {
vec[j] = ((i * TOTAL_DIM + j) % 100) as f32 / 100.0;
}
training_vecs.push(vec);
}
let mut pq = ProductQuantizer::new(u32::try_from(TOTAL_DIM).unwrap()).unwrap();
pq.train(&training_vecs, 10).unwrap();
let test_vec = &training_vecs[0];
let codes = pq.encode(test_vec).unwrap();
assert_eq!(codes.len(), NUM_SUBSPACES);
let decoded = pq.decode(&codes).unwrap();
assert_eq!(decoded.len(), TOTAL_DIM);
let dist = l2_distance_squared(test_vec, &decoded).sqrt();
assert!(dist < 10.0, "Reconstruction error too large: {}", dist);
}
#[test]
fn test_quantized_index_builder() {
let mut training_vecs = Vec::new();
for i in 0..50 {
let mut vec = vec![0.0f32; TOTAL_DIM];
for j in 0..TOTAL_DIM {
vec[j] = ((i + j) % 10) as f32;
}
training_vecs.push(vec);
}
let mut builder = QuantizedVecIndexBuilder::new();
builder
.train_quantizer(&training_vecs, u32::try_from(TOTAL_DIM).unwrap())
.unwrap();
for (i, vec) in training_vecs.iter().take(10).enumerate() {
builder
.add_document((i + 1) as FrameId, vec.clone())
.unwrap();
}
let artifact = builder.finish().unwrap();
assert_eq!(artifact.vector_count, 10);
assert_eq!(artifact.dimension, u32::try_from(TOTAL_DIM).unwrap());
assert!(artifact.compression_ratio > 10.0);
let index = QuantizedVecIndex::decode(&artifact.bytes).unwrap();
let query = &training_vecs[0];
let hits = index.search(query, 5);
assert!(!hits.is_empty());
assert_eq!(hits[0].frame_id, 1); }
#[test]
fn test_kmeans_simple() {
let vectors = vec![
vec![0.0, 0.0],
vec![0.1, 0.1],
vec![10.0, 10.0],
vec![10.1, 10.1],
];
let centroids = kmeans(&vectors, 2, 100).unwrap();
assert_eq!(centroids.len(), 2);
let near_zero = centroids.iter().any(|c| c[0] < 5.0 && c[1] < 5.0);
let near_ten = centroids.iter().any(|c| c[0] > 5.0 && c[1] > 5.0);
assert!(near_zero && near_ten);
}
}