use rand::prelude::*;
use super::{AsymmetricDistance, Quantizer};
/// Default number of centroids per subspace (the maximum that fits in a `u8` code).
pub const DEFAULT_K: usize = 256;
/// Product Quantizer: splits each vector into `m` subvectors and encodes
/// every subvector as the `u8` index of its nearest centroid in a
/// per-subspace codebook.
#[derive(Debug, Clone)]
pub struct PQQuantizer {
    // Full input dimensionality (m * sub_dim).
    dimension: usize,
    // Number of subspaces = number of codes per encoded vector.
    m: usize,
    // Components per subvector (dimension / m).
    sub_dim: usize,
    // Centroids per subspace; <= 256 so codes fit in a u8.
    k: usize,
    // Flattened centroids laid out [subspace][code][component]:
    // element index = (subspace * k + code) * sub_dim + component.
    codebooks: Vec<f32>,
}
/// A PQ-compressed vector: one centroid code per subspace (`codes.len() == m`).
#[derive(Debug, Clone)]
pub struct PQEncodedVector {
    pub codes: Vec<u8>,
}
/// Per-query lookup table of squared distances from the query's subvectors
/// to every centroid, laid out as `distances[subspace * k + code]`.
/// Built once per query so encoded-vector distances reduce to `m` lookups.
#[derive(Debug)]
pub struct PQDistanceTable {
    distances: Vec<f32>,
    m: usize,
    k: usize,
}
/// Training configuration for [`PQQuantizer`].
#[derive(Debug, Clone)]
pub struct PQConfig {
    // Number of subspaces; must divide the vector dimension.
    pub m: usize,
    // Centroids per subspace (<= 256).
    pub k: usize,
    // Maximum Lloyd iterations per subspace k-means.
    pub kmeans_iters: usize,
    // Fixed RNG seed for reproducible training; `None` uses entropy.
    pub seed: Option<u64>,
}
impl Default for PQConfig {
fn default() -> Self {
Self {
m: 16,
k: 256,
kmeans_iters: 25,
seed: None,
}
}
}
impl PQQuantizer {
pub fn new(dimension: usize, m: usize, k: usize, codebooks: Vec<f32>) -> Self {
assert!(dimension % m == 0, "Dimension must be divisible by M");
assert!(k <= 256, "K must be <= 256 for u8 codes");
let sub_dim = dimension / m;
assert_eq!(
codebooks.len(),
m * k * sub_dim,
"Codebook size mismatch"
);
Self {
dimension,
m,
sub_dim,
k,
codebooks,
}
}
pub fn train_with_config(vectors: &[Vec<f32>], config: &PQConfig) -> Self {
if vectors.is_empty() {
return Self::new(0, config.m, config.k, vec![]);
}
let dimension = vectors[0].len();
assert!(
dimension % config.m == 0,
"Dimension {} must be divisible by M {}",
dimension,
config.m
);
let sub_dim = dimension / config.m;
let mut codebooks = Vec::with_capacity(config.m * config.k * sub_dim);
let mut rng = match config.seed {
Some(s) => StdRng::seed_from_u64(s),
None => StdRng::from_entropy(),
};
for subspace in 0..config.m {
let start = subspace * sub_dim;
let end = start + sub_dim;
let subvectors: Vec<Vec<f32>> = vectors
.iter()
.map(|v| v[start..end].to_vec())
.collect();
let centroids = kmeans(&subvectors, config.k, config.kmeans_iters, &mut rng);
for centroid in centroids {
codebooks.extend_from_slice(¢roid);
}
}
Self::new(dimension, config.m, config.k, codebooks)
}
#[inline]
fn get_centroid(&self, subspace: usize, code: u8) -> &[f32] {
let start = (subspace * self.k + code as usize) * self.sub_dim;
&self.codebooks[start..start + self.sub_dim]
}
fn find_nearest_centroid(&self, subspace: usize, subvector: &[f32]) -> u8 {
let mut best_code = 0u8;
let mut best_dist = f32::INFINITY;
for code in 0..self.k {
let centroid = self.get_centroid(subspace, code as u8);
let dist = l2_squared(subvector, centroid);
if dist < best_dist {
best_dist = dist;
best_code = code as u8;
}
}
best_code
}
pub fn compute_distance_table(&self, query: &[f32]) -> PQDistanceTable {
debug_assert_eq!(query.len(), self.dimension);
let mut distances = Vec::with_capacity(self.m * self.k);
for subspace in 0..self.m {
let query_sub = &query[subspace * self.sub_dim..(subspace + 1) * self.sub_dim];
for code in 0..self.k {
let centroid = self.get_centroid(subspace, code as u8);
let dist = l2_squared(query_sub, centroid);
distances.push(dist);
}
}
PQDistanceTable {
distances,
m: self.m,
k: self.k,
}
}
#[inline]
pub fn asymmetric_l2_with_table(&self, table: &PQDistanceTable, encoded: &PQEncodedVector) -> f32 {
debug_assert_eq!(encoded.codes.len(), self.m);
let mut sum = 0.0f32;
for (subspace, &code) in encoded.codes.iter().enumerate() {
sum += table.get(subspace, code);
}
sum.sqrt()
}
#[inline]
pub fn asymmetric_l2_squared_with_table(&self, table: &PQDistanceTable, encoded: &PQEncodedVector) -> f32 {
debug_assert_eq!(encoded.codes.len(), self.m);
let mut sum = 0.0f32;
for (subspace, &code) in encoded.codes.iter().enumerate() {
sum += table.get(subspace, code);
}
sum
}
#[must_use]
pub fn num_subspaces(&self) -> usize {
self.m
}
#[must_use]
pub fn num_centroids(&self) -> usize {
self.k
}
#[must_use]
pub fn sub_dimension(&self) -> usize {
self.sub_dim
}
#[must_use]
pub fn codebook_size(&self) -> usize {
self.codebooks.len() * 4
}
}
impl PQDistanceTable {
    /// Looks up the precomputed squared distance for `code` in `subspace`.
    #[inline]
    pub fn get(&self, subspace: usize, code: u8) -> f32 {
        let idx = subspace * self.k + usize::from(code);
        self.distances[idx]
    }
}
impl Quantizer for PQQuantizer {
    type Encoded = PQEncodedVector;

    /// Trains with the default [`PQConfig`].
    fn train(vectors: &[Vec<f32>]) -> Self {
        Self::train_with_config(vectors, &PQConfig::default())
    }

    /// Quantizes `vector` to one nearest-centroid code per subspace.
    fn encode(&self, vector: &[f32]) -> PQEncodedVector {
        debug_assert_eq!(vector.len(), self.dimension);
        let mut codes = Vec::with_capacity(self.m);
        for subspace in 0..self.m {
            let lo = subspace * self.sub_dim;
            let subvector = &vector[lo..lo + self.sub_dim];
            codes.push(self.find_nearest_centroid(subspace, subvector));
        }
        PQEncodedVector { codes }
    }

    /// Reconstructs an approximate vector by concatenating the centroids
    /// named by each subspace's code.
    fn decode(&self, encoded: &PQEncodedVector) -> Vec<f32> {
        debug_assert_eq!(encoded.codes.len(), self.m);
        let mut reconstructed = Vec::with_capacity(self.dimension);
        for (subspace, &code) in encoded.codes.iter().enumerate() {
            reconstructed.extend_from_slice(self.get_centroid(subspace, code));
        }
        reconstructed
    }

    /// Original (uncompressed) dimensionality.
    fn dimension(&self) -> usize {
        self.dimension
    }

    /// Bytes per encoded vector: one `u8` code per subspace.
    fn encoded_size(&self) -> usize {
        self.m
    }
}
impl AsymmetricDistance<PQEncodedVector> for PQQuantizer {
    /// L2 distance between a raw query and an encoded vector: sums the
    /// per-subspace squared distances to the coded centroids, then sqrt.
    fn asymmetric_l2(&self, query: &[f32], encoded: &PQEncodedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(encoded.codes.len(), self.m);
        let total: f32 = encoded
            .codes
            .iter()
            .enumerate()
            .map(|(subspace, &code)| {
                let lo = subspace * self.sub_dim;
                let query_sub = &query[lo..lo + self.sub_dim];
                l2_squared(query_sub, self.get_centroid(subspace, code))
            })
            .sum();
        total.sqrt()
    }

    /// Inner product between a raw query and an encoded vector, summed
    /// subspace-by-subspace against the coded centroids.
    fn asymmetric_inner_product(&self, query: &[f32], encoded: &PQEncodedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(encoded.codes.len(), self.m);
        encoded
            .codes
            .iter()
            .enumerate()
            .map(|(subspace, &code)| {
                let lo = subspace * self.sub_dim;
                let query_sub = &query[lo..lo + self.sub_dim];
                inner_product(query_sub, self.get_centroid(subspace, code))
            })
            .sum()
    }

    /// Cosine similarity between the query and the decoded approximation.
    /// Returns 0.0 when either side has zero norm.
    fn asymmetric_cosine(&self, query: &[f32], encoded: &PQEncodedVector) -> f32 {
        debug_assert_eq!(query.len(), self.dimension);
        debug_assert_eq!(encoded.codes.len(), self.m);
        let decoded = self.decode(encoded);
        let mut dot = 0.0f32;
        let mut q_sq = 0.0f32;
        let mut d_sq = 0.0f32;
        for (&q, &d) in query.iter().zip(decoded.iter()) {
            dot += q * d;
            q_sq += q * q;
            d_sq += d * d;
        }
        let norm_q = q_sq.sqrt();
        let norm_d = d_sq.sqrt();
        if norm_q == 0.0 || norm_d == 0.0 {
            0.0
        } else {
            dot / (norm_q * norm_d)
        }
    }
}
/// Squared Euclidean distance between two slices.
/// Zips the inputs, so comparison silently stops at the shorter slice.
#[inline]
fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        let diff = x - y;
        acc += diff * diff;
    }
    acc
}
/// Dot product of two slices.
/// Zips the inputs, so multiplication silently stops at the shorter slice.
#[inline]
fn inner_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
/// Lloyd's k-means over `vectors`, always returning exactly `k` centroids
/// (empty when `vectors` is empty or `k == 0`).
///
/// Seeds with k-means++ (at most `n` distinct seeds), then pads with
/// randomly chosen — possibly duplicate — input vectors up to `k`.
/// Iterates until no assignment changes or `max_iters` is reached; empty
/// clusters are re-seeded from a random input vector.
///
/// NOTE(review): for a fixed seed the output depends on the exact RNG
/// call order (init, padding, empty-cluster reseed) — preserve it when
/// refactoring, or seeded training becomes non-reproducible.
fn kmeans(vectors: &[Vec<f32>], k: usize, max_iters: usize, rng: &mut StdRng) -> Vec<Vec<f32>> {
    if vectors.is_empty() || k == 0 {
        return vec![];
    }
    let n = vectors.len();
    let dim = vectors[0].len();
    // k-means++ cannot produce more distinct seeds than there are points.
    let actual_k = k.min(n);
    let mut centroids = kmeans_plusplus_init(vectors, actual_k, rng);
    // Pad with random duplicates so the caller always gets k centroids.
    while centroids.len() < k {
        let random_idx = rng.gen_range(0..n);
        centroids.push(vectors[random_idx].clone());
    }
    let mut assignments = vec![0usize; n];
    for _ in 0..max_iters {
        // Assignment step: attach each vector to its nearest centroid.
        let mut changed = false;
        for (i, vector) in vectors.iter().enumerate() {
            let mut best_centroid = 0;
            let mut best_dist = f32::INFINITY;
            for (c, centroid) in centroids.iter().enumerate() {
                let dist = l2_squared(vector, centroid);
                if dist < best_dist {
                    best_dist = dist;
                    best_centroid = c;
                }
            }
            if assignments[i] != best_centroid {
                assignments[i] = best_centroid;
                changed = true;
            }
        }
        // Converged: no vector moved clusters this round.
        if !changed {
            break;
        }
        // Update step: recompute each centroid as the mean of its members.
        let mut new_centroids = vec![vec![0.0f32; dim]; k];
        let mut counts = vec![0usize; k];
        for (i, vector) in vectors.iter().enumerate() {
            let c = assignments[i];
            counts[c] += 1;
            for (j, &v) in vector.iter().enumerate() {
                new_centroids[c][j] += v;
            }
        }
        for (c, centroid) in new_centroids.iter_mut().enumerate() {
            if counts[c] > 0 {
                // Divide the accumulated sum by the member count.
                for v in centroid.iter_mut() {
                    *v /= counts[c] as f32;
                }
            } else {
                // Empty cluster: re-seed from a random input vector.
                let random_idx = rng.gen_range(0..n);
                *centroid = vectors[random_idx].clone();
            }
        }
        centroids = new_centroids;
    }
    centroids
}
/// k-means++ seeding: picks an initial centroid uniformly at random, then
/// draws each subsequent one with probability proportional to its squared
/// distance from the nearest centroid chosen so far.
///
/// May return fewer than `k` centroids: when `n < k`, or when every point
/// coincides with an existing centroid (all distances zero). Callers pad
/// the shortfall themselves (see `kmeans`).
fn kmeans_plusplus_init(vectors: &[Vec<f32>], k: usize, rng: &mut StdRng) -> Vec<Vec<f32>> {
    let n = vectors.len();
    if n == 0 || k == 0 {
        return vec![];
    }
    let mut centroids = Vec::with_capacity(k);
    let first_idx = rng.gen_range(0..n);
    centroids.push(vectors[first_idx].clone());
    // distances[i] = squared distance from point i to its nearest centroid.
    let mut distances = vec![f32::INFINITY; n];
    while centroids.len() < k && centroids.len() < n {
        // Only the most recently added centroid can shrink any distance.
        let last_centroid = centroids.last().unwrap();
        for (i, vector) in vectors.iter().enumerate() {
            let dist = l2_squared(vector, last_centroid);
            distances[i] = distances[i].min(dist);
        }
        let total: f32 = distances.iter().sum();
        // All points coincide with chosen centroids; nothing left to pick.
        if total == 0.0 {
            break;
        }
        // Sample an index with probability proportional to its distance
        // (inverse-CDF sampling over the cumulative sum).
        let threshold = rng.gen::<f32>() * total;
        let mut cumsum = 0.0f32;
        let mut chosen = 0;
        for (i, &d) in distances.iter().enumerate() {
            cumsum += d;
            if cumsum >= threshold {
                chosen = i;
                break;
            }
        }
        centroids.push(vectors[chosen].clone());
    }
    centroids
}
#[cfg(test)]
mod tests {
    use super::*;
    const EPSILON: f32 = 0.1;

    /// Deterministic uniform test vectors in [-0.5, 0.5).
    fn random_test_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
        let mut rng = StdRng::seed_from_u64(seed);
        (0..n)
            .map(|_| (0..dim).map(|_| rng.gen::<f32>() - 0.5).collect())
            .collect()
    }

    #[test]
    fn test_pq_train() {
        let vectors = random_test_vectors(100, 64, 42);
        let config = PQConfig {
            m: 8,
            k: 16,
            kmeans_iters: 10,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        assert_eq!(quantizer.dimension(), 64);
        assert_eq!(quantizer.num_subspaces(), 8);
        assert_eq!(quantizer.num_centroids(), 16);
        assert_eq!(quantizer.sub_dimension(), 8);
    }

    #[test]
    fn test_pq_encode_decode() {
        let vectors = random_test_vectors(100, 64, 42);
        let config = PQConfig {
            m: 8,
            k: 32,
            kmeans_iters: 10,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        for vector in vectors.iter().take(10) {
            let encoded = quantizer.encode(vector);
            assert_eq!(encoded.codes.len(), 8);
            let decoded = quantizer.decode(&encoded);
            assert_eq!(decoded.len(), 64);
            let dist = l2_squared(vector, &decoded).sqrt();
            assert!(
                dist < 2.0,
                "Reconstruction error too high: {}",
                dist
            );
        }
    }

    #[test]
    fn test_pq_compression_ratio() {
        let vectors = random_test_vectors(1000, 512, 42);
        let config = PQConfig {
            m: 16,
            k: 256,
            kmeans_iters: 5,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        assert_eq!(quantizer.encoded_size(), 16);
        // 512 floats * 4 bytes = 2048 bytes raw vs 16 code bytes -> 128x.
        let ratio = quantizer.compression_ratio();
        assert!(
            (ratio - 128.0).abs() < 0.01,
            "Expected 128x compression, got {}x",
            ratio
        );
    }

    #[test]
    fn test_pq_asymmetric_l2() {
        let vectors = random_test_vectors(100, 64, 42);
        let config = PQConfig {
            m: 8,
            k: 32,
            kmeans_iters: 10,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        let query = &vectors[0];
        let encoded = quantizer.encode(query);
        let dist = quantizer.asymmetric_l2(query, &encoded);
        assert!(
            dist < 2.0,
            "Self-distance should be reasonable, got {}",
            dist
        );
    }

    #[test]
    fn test_pq_distance_table() {
        let vectors = random_test_vectors(100, 64, 42);
        let config = PQConfig {
            m: 8,
            k: 32,
            kmeans_iters: 10,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        let query = &vectors[0];
        let encoded = quantizer.encode(&vectors[1]);
        let direct = quantizer.asymmetric_l2(query, &encoded);
        let table = quantizer.compute_distance_table(query);
        let table_based = quantizer.asymmetric_l2_with_table(&table, &encoded);
        assert!(
            (direct - table_based).abs() < 1e-5,
            "Table-based distance should match direct: {} vs {}",
            direct,
            table_based
        );
    }

    #[test]
    fn test_pq_asymmetric_inner_product() {
        let vectors = random_test_vectors(100, 64, 42);
        let config = PQConfig {
            m: 8,
            k: 32,
            kmeans_iters: 10,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        let query = &vectors[0];
        let encoded = quantizer.encode(query);
        let ip = quantizer.asymmetric_inner_product(query, &encoded);
        assert!(ip > 0.0, "Self IP should be positive, got {}", ip);
    }

    #[test]
    fn test_pq_asymmetric_cosine() {
        let vectors = random_test_vectors(100, 64, 42);
        let config = PQConfig {
            m: 8,
            k: 32,
            kmeans_iters: 10,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        let query = &vectors[0];
        let encoded = quantizer.encode(query);
        let cos = quantizer.asymmetric_cosine(query, &encoded);
        assert!(
            cos > 0.8,
            "Self-cosine should be high, got {}",
            cos
        );
    }

    #[test]
    fn test_pq_empty_vectors() {
        let vectors: Vec<Vec<f32>> = vec![];
        let quantizer = PQQuantizer::train(&vectors);
        assert_eq!(quantizer.dimension(), 0);
    }

    #[test]
    fn test_pq_few_vectors() {
        // Fewer training vectors (5) than centroids (256): kmeans must pad.
        let vectors = random_test_vectors(5, 64, 42);
        let config = PQConfig {
            m: 8,
            k: 256,
            kmeans_iters: 5,
            seed: Some(42),
        };
        let quantizer = PQQuantizer::train_with_config(&vectors, &config);
        let encoded = quantizer.encode(&vectors[0]);
        let decoded = quantizer.decode(&encoded);
        assert_eq!(decoded.len(), 64);
    }

    #[test]
    fn test_kmeans() {
        let mut rng = StdRng::seed_from_u64(42);
        // Two tight clusters at (0,0) and (10,10).
        let vectors: Vec<Vec<f32>> = (0..50)
            .map(|i| {
                if i < 25 {
                    vec![0.0, 0.0]
                } else {
                    vec![10.0, 10.0]
                }
            })
            .collect();
        let centroids = kmeans(&vectors, 2, 10, &mut rng);
        assert_eq!(centroids.len(), 2);
        let mut found_origin = false;
        let mut found_corner = false;
        // FIX: was `¢roids` (mis-encoded `&centroids`), which does not
        // compile.
        for centroid in &centroids {
            if centroid[0] < 1.0 && centroid[1] < 1.0 {
                found_origin = true;
            }
            if centroid[0] > 9.0 && centroid[1] > 9.0 {
                found_corner = true;
            }
        }
        assert!(found_origin, "Should find centroid near origin");
        assert!(found_corner, "Should find centroid near (10, 10)");
    }
}