use rand::prelude::SliceRandom;
use rand::Rng;
/// Tuning parameters for a product quantizer.
#[derive(Clone, Debug)]
pub struct PQConfig {
    /// Number of subspaces each vector is split into; one code is produced per subspace.
    pub m: usize,
    /// Requested number of centroids per subspace codebook.
    pub k: usize,
    /// Seed for the random number generator used while training.
    pub seed: u64,
}
impl Default for PQConfig {
    /// Returns the default configuration: 8 subspaces, 256 centroids per
    /// subspace, and seed 42.
    fn default() -> Self {
        let (m, k, seed) = (8, 256, 42);
        Self { m, k, seed }
    }
}
/// Product quantizer: splits vectors into equally sized subspaces and
/// represents each sub-vector by the index of a centroid in that subspace's
/// codebook.
pub struct ProductQuantizer {
    /// Configuration supplied at construction time.
    config: PQConfig,
    /// Number of components in each subspace (`dimensions / config.m`).
    dims_per_subspace: usize,
    /// Length of the full, uncompressed vectors.
    dimensions: usize,
    /// Codebooks, indexed as `centroids[subspace][code]`; each entry is one
    /// centroid of that subspace.
    centroids: Vec<Vec<Vec<f32>>>,
    /// Whether training has produced codebooks that can be used.
    trained: bool,
}
impl ProductQuantizer {
    /// Largest codebook whose indices still fit into a single `u8` code.
    const MAX_CODEBOOK_SIZE: usize = u8::MAX as usize + 1;

    /// Number of k-means refinement passes run for every subspace.
    const KMEANS_ITERATIONS: usize = 10;

    /// Creates an untrained quantizer for vectors of length `dimensions`.
    ///
    /// # Panics
    ///
    /// Panics if `config.m` is zero or if `dimensions` is not a multiple of
    /// `config.m`.
    pub fn new(dimensions: usize, config: PQConfig) -> Self {
        // Checked explicitly so the caller gets a readable message instead of
        // the arithmetic "remainder with a divisor of zero" panic.
        assert!(
            config.m > 0,
            "Number of subspaces must be greater than zero"
        );
        assert!(
            dimensions % config.m == 0,
            "Dimensions must be divisible by number of subspaces"
        );
        let dims_per_subspace = dimensions / config.m;
        Self {
            config,
            dims_per_subspace,
            dimensions,
            centroids: Vec::new(),
            trained: false,
        }
    }

    /// Returns the `[start, end)` component bounds of `subspace` within a full vector.
    fn subspace_bounds(&self, subspace: usize) -> (usize, usize) {
        let start = subspace * self.dims_per_subspace;
        (start, start + self.dims_per_subspace)
    }

    /// Learns one codebook per subspace from `vectors` using k-means seeded
    /// with `config.seed`.
    ///
    /// The effective codebook size is `min(config.k, vectors.len())`. If the
    /// training set is empty or `config.k` is zero no centroids can be
    /// learned and the quantizer stays untrained.
    ///
    /// # Panics
    ///
    /// Panics if the effective codebook size exceeds 256 (codes are stored as
    /// `u8`), or if a training vector has fewer than `dimensions` components.
    pub fn train(&mut self, vectors: &[Vec<f32>]) {
        use rand::prelude::*;
        use rand_chacha::ChaCha8Rng;

        // `kmeans` clamps k to the number of samples, so validate that value:
        // anything above 256 would be truncated when stored in a `u8` code.
        let effective_k = self.config.k.min(vectors.len());
        assert!(
            effective_k <= Self::MAX_CODEBOOK_SIZE,
            "Codebook size {} does not fit in a u8 code (maximum is {})",
            effective_k,
            Self::MAX_CODEBOOK_SIZE
        );

        let mut rng = ChaCha8Rng::seed_from_u64(self.config.seed);
        let mut codebooks = Vec::with_capacity(self.config.m);
        for subspace in 0..self.config.m {
            let (start, end) = self.subspace_bounds(subspace);
            // Borrow the sub-vectors instead of copying them.
            let subvectors: Vec<&[f32]> = vectors.iter().map(|v| &v[start..end]).collect();
            let codebook = self.kmeans(
                &subvectors,
                self.config.k,
                Self::KMEANS_ITERATIONS,
                &mut rng,
            );
            codebooks.push(codebook);
        }

        // An empty codebook cannot encode or decode anything, so only report
        // the quantizer as trained when every subspace received centroids.
        self.trained = codebooks.iter().all(|codebook| !codebook.is_empty());
        self.centroids = codebooks;
    }

    /// Runs k-means over `vectors` and returns the resulting centroids.
    ///
    /// Initial centroids are the first `k` samples of a random permutation
    /// drawn from `rng`; `k` is clamped to `vectors.len()`. Returns an empty
    /// vector when there are no samples or `k` is zero. A cluster that
    /// receives no samples in an iteration keeps its previous centroid.
    fn kmeans<R: Rng>(
        &self,
        vectors: &[&[f32]],
        k: usize,
        iterations: usize,
        rng: &mut R,
    ) -> Vec<Vec<f32>> {
        if vectors.is_empty() || k == 0 {
            return Vec::new();
        }
        let dims = vectors[0].len();
        let k = k.min(vectors.len());

        let mut indices: Vec<usize> = (0..vectors.len()).collect();
        indices.shuffle(rng);
        let mut centroids: Vec<Vec<f32>> = indices
            .iter()
            .take(k)
            .map(|&i| vectors[i].to_vec())
            .collect();

        for _ in 0..iterations {
            // Accumulate per-cluster sums and counts in one pass. Samples are
            // visited in input order, so each sum is built in the same order
            // as summing a per-cluster member list would.
            let mut sums = vec![vec![0.0f32; dims]; k];
            let mut counts = vec![0usize; k];
            for &v in vectors {
                let nearest = self.find_nearest(v, &centroids);
                counts[nearest] += 1;
                for (acc, &val) in sums[nearest].iter_mut().zip(v.iter()) {
                    *acc += val;
                }
            }

            for (c, (sum, &count)) in sums.into_iter().zip(&counts).enumerate() {
                if count == 0 {
                    // Empty cluster: keep the previous centroid.
                    continue;
                }
                let divisor = count as f32;
                centroids[c] = sum.into_iter().map(|total| total / divisor).collect();
            }
        }
        centroids
    }

    /// Returns the index of the centroid with the smallest squared Euclidean
    /// distance to `vector`.
    ///
    /// Ties resolve to the lowest index; `0` is returned when `centroids` is
    /// empty or no distance compares below `f32::MAX`.
    fn find_nearest(&self, vector: &[f32], centroids: &[Vec<f32>]) -> usize {
        let mut best = 0;
        let mut best_dist = f32::MAX;
        for (i, c) in centroids.iter().enumerate() {
            let dist: f32 = vector
                .iter()
                .zip(c.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum();
            if dist < best_dist {
                best_dist = dist;
                best = i;
            }
        }
        best
    }

    /// Encodes `vector` as one centroid index per subspace.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer is not trained or if `vector.len()` differs
    /// from `dimensions`.
    pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
        assert!(self.trained, "Quantizer must be trained");
        assert_eq!(vector.len(), self.dimensions);
        let mut codes = Vec::with_capacity(self.config.m);
        for (subspace, codebook) in self.centroids.iter().enumerate() {
            let (start, end) = self.subspace_bounds(subspace);
            let nearest = self.find_nearest(&vector[start..end], codebook);
            // Checked conversion: a plain `as u8` would silently wrap for
            // indices above 255. `train` caps the codebook size, so a failure
            // here means an internal invariant was broken.
            let code = <u8 as std::convert::TryFrom<usize>>::try_from(nearest)
                .expect("centroid index does not fit in a u8 code");
            codes.push(code);
        }
        codes
    }

    /// Reconstructs an approximate vector by concatenating the centroids
    /// selected by `codes`.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer is not trained, if `codes.len()` differs from
    /// `config.m`, or if a code is not a valid index into its codebook.
    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
        assert!(self.trained, "Quantizer must be trained");
        assert_eq!(codes.len(), self.config.m);
        let mut vector = Vec::with_capacity(self.dimensions);
        for (codebook, &code) in self.centroids.iter().zip(codes) {
            vector.extend_from_slice(&codebook[usize::from(code)]);
        }
        vector
    }

    /// Euclidean distance between the raw `query` and the vector represented
    /// by `codes`, computed directly against the selected centroids.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer is not trained, if `query.len()` differs from
    /// `dimensions`, if `codes.len()` differs from `config.m`, or if a code
    /// is not a valid index into its codebook.
    pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
        assert!(self.trained, "Quantizer must be trained");
        assert_eq!(query.len(), self.dimensions);
        assert_eq!(codes.len(), self.config.m);
        let mut distance_sq = 0.0f32;
        for (subspace, &code) in codes.iter().enumerate() {
            let (start, end) = self.subspace_bounds(subspace);
            let centroid = &self.centroids[subspace][usize::from(code)];
            for (q, c) in query[start..end].iter().zip(centroid.iter()) {
                distance_sq += (q - c).powi(2);
            }
        }
        distance_sq.sqrt()
    }

    /// Builds a lookup table where `table[subspace][code]` is the squared
    /// distance between the query's sub-vector and that centroid.
    ///
    /// # Panics
    ///
    /// Panics if the quantizer is not trained or if `query.len()` differs
    /// from `dimensions`.
    pub fn precompute_distance_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
        assert!(self.trained, "Quantizer must be trained");
        assert_eq!(query.len(), self.dimensions);
        let mut table = Vec::with_capacity(self.config.m);
        for (subspace, codebook) in self.centroids.iter().enumerate() {
            let (start, end) = self.subspace_bounds(subspace);
            let query_sub = &query[start..end];
            let distances: Vec<f32> = codebook
                .iter()
                .map(|c| {
                    query_sub
                        .iter()
                        .zip(c.iter())
                        .map(|(q, v)| (q - v).powi(2))
                        .sum::<f32>()
                })
                .collect();
            table.push(distances);
        }
        table
    }

    /// Distance between a query and an encoded vector, looked up from a table
    /// produced by [`precompute_distance_table`](Self::precompute_distance_table).
    ///
    /// # Panics
    ///
    /// Panics if `table.len()` or `codes.len()` differs from `config.m`, or
    /// if a code is not a valid index into its table row.
    pub fn table_distance(&self, table: &[Vec<f32>], codes: &[u8]) -> f32 {
        // Without these checks a short `codes` slice would silently yield a
        // partial distance.
        assert_eq!(table.len(), self.config.m);
        assert_eq!(codes.len(), self.config.m);
        let mut distance_sq = 0.0f32;
        for (row, &code) in table.iter().zip(codes) {
            distance_sq += row[usize::from(code)];
        }
        distance_sq.sqrt()
    }

    /// Size in bytes of one encoded vector: one `u8` code per subspace.
    pub fn bytes_per_vector(&self) -> usize {
        self.config.m
    }

    /// Ratio of the raw vector size (4 bytes per `f32` component) to the
    /// encoded size (`config.m` bytes).
    pub fn compression_ratio(&self) -> f32 {
        (self.dimensions * 4) as f32 / self.config.m as f32
    }
}
/// A vector stored in its product-quantized form.
#[derive(Clone, Debug)]
pub struct PQVector {
    /// The quantization codes, one byte each.
    pub codes: Vec<u8>,
}
impl PQVector {
    /// Returns the size of the struct itself plus the number of stored codes.
    ///
    /// The count uses the length of `codes`, not its allocated capacity.
    pub fn memory_size(&self) -> usize {
        let inline_bytes = std::mem::size_of::<Self>();
        let code_bytes = self.codes.len();
        inline_bytes + code_bytes
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use rand_chacha::ChaCha8Rng;

    /// Generates `n` vectors of `dims` components each, drawn from
    /// `-1.0..1.0` with a generator seeded by `seed`.
    fn random_vectors(n: usize, dims: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = ChaCha8Rng::seed_from_u64(seed);
        let mut rows = Vec::with_capacity(n);
        for _ in 0..n {
            let row: Vec<f32> = (0..dims).map(|_| rng.gen_range(-1.0..1.0)).collect();
            rows.push(row);
        }
        rows
    }

    /// Generates a single vector with the same distribution as `random_vectors`.
    fn random_vector(dims: usize, seed: u64) -> Vec<f32> {
        random_vectors(1, dims, seed).swap_remove(0)
    }

    /// Encoding then decoding a vector yields `m` codes and a bounded
    /// reconstruction error.
    #[test]
    fn test_train_and_encode() {
        let dims = 128;
        let mut pq = ProductQuantizer::new(
            dims,
            PQConfig {
                m: 8,
                k: 64,
                seed: 42,
            },
        );
        pq.train(&random_vectors(1000, dims, 42));

        let vector = random_vector(dims, 123);
        let codes = pq.encode(&vector);
        assert_eq!(codes.len(), 8);

        let decoded = pq.decode(&codes);
        let mut squared_error = 0.0f32;
        for (a, b) in vector.iter().zip(decoded.iter()) {
            squared_error += (a - b).powi(2);
        }
        let error = squared_error.sqrt();
        assert!(error < 2.0, "Reconstruction error too high: {}", error);
    }

    /// The table-based distance agrees with the directly computed
    /// asymmetric distance.
    #[test]
    fn test_distance_table() {
        let dims = 64;
        let mut pq = ProductQuantizer::new(
            dims,
            PQConfig {
                m: 4,
                k: 16,
                seed: 42,
            },
        );
        pq.train(&random_vectors(500, dims, 42));

        let query = random_vector(dims, 123);
        let target = random_vector(dims, 456);
        let codes = pq.encode(&target);

        let asym_dist = pq.asymmetric_distance(&query, &codes);
        let table = pq.precompute_distance_table(&query);
        let table_dist = pq.table_distance(&table, &codes);
        let difference = (asym_dist - table_dist).abs();
        assert!(difference < 0.001);
    }

    /// 1536 `f32` components stored in 48 bytes compress by a factor of 128.
    #[test]
    fn test_compression_ratio() {
        let dims = 1536;
        let pq = ProductQuantizer::new(
            dims,
            PQConfig {
                m: 48,
                k: 256,
                seed: 42,
            },
        );
        assert_eq!(pq.bytes_per_vector(), 48);
        let deviation = (pq.compression_ratio() - 128.0).abs();
        assert!(deviation < 0.1);
    }
}