use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use crate::connection::SochConnection;
use crate::error::{ClientError, Result};
use sochdb_core::soch::SochValue;
use sochdb_index::vector_simd::l2_squared_f32;
const VAMANA_THRESHOLD: usize = 100_000;
const DEFAULT_MIGRATION_BATCH_SIZE: usize = 1000;
const DEFAULT_PQ_SUBSPACES: usize = 8;
const DEFAULT_PQ_CENTROIDS: usize = 256;
const DEFAULT_PQ_ITERATIONS: usize = 20;
const MIN_PQ_TRAINING_VECTORS: usize = 1000;
#[derive(Clone)]
pub struct ProductQuantizer {
m: usize,
k: usize,
d_sub: usize,
dimension: usize,
codebooks: Vec<Vec<Vec<f32>>>,
trained: bool,
}
impl ProductQuantizer {
pub fn new(dimension: usize, m: usize, k: usize) -> Self {
assert!(
dimension.is_multiple_of(m),
"Dimension must be divisible by number of subspaces"
);
let d_sub = dimension / m;
Self {
m,
k,
d_sub,
dimension,
codebooks: Vec::new(),
trained: false,
}
}
pub fn new_default(dimension: usize) -> Self {
let m = (dimension / 8).min(DEFAULT_PQ_SUBSPACES.max(dimension / 16));
Self::new(dimension, m, DEFAULT_PQ_CENTROIDS)
}
pub fn train(
&mut self,
vectors: &[Vec<f32>],
iterations: usize,
) -> std::result::Result<(), String> {
if vectors.is_empty() {
return Err("Cannot train PQ with empty vector set".to_string());
}
if vectors.len() < MIN_PQ_TRAINING_VECTORS {
return Err(format!(
"Need at least {} vectors for PQ training, got {}",
MIN_PQ_TRAINING_VECTORS,
vectors.len()
));
}
for (i, v) in vectors.iter().enumerate() {
if v.len() != self.dimension {
return Err(format!(
"Vector {} has dimension {}, expected {}",
i,
v.len(),
self.dimension
));
}
}
self.codebooks = Vec::with_capacity(self.m);
for sub in 0..self.m {
let start_dim = sub * self.d_sub;
let end_dim = start_dim + self.d_sub;
let subvectors: Vec<Vec<f32>> = vectors
.iter()
.map(|v| v[start_dim..end_dim].to_vec())
.collect();
let centroids = kmeans_clustering(&subvectors, self.k, iterations);
self.codebooks.push(centroids);
}
self.trained = true;
Ok(())
}
pub fn is_trained(&self) -> bool {
self.trained
}
pub fn encode(&self, vector: &[f32]) -> Vec<u8> {
if !self.trained {
return simple_pq_encode(vector);
}
assert_eq!(vector.len(), self.dimension);
let mut codes = Vec::with_capacity(self.m);
for sub in 0..self.m {
let start_dim = sub * self.d_sub;
let end_dim = start_dim + self.d_sub;
let subvec = &vector[start_dim..end_dim];
let centroid_idx = find_nearest_centroid(subvec, &self.codebooks[sub]);
codes.push(centroid_idx as u8);
}
codes
}
pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
if !self.trained || codes.len() != self.m {
return vec![0.0; self.dimension];
}
let mut vector = Vec::with_capacity(self.dimension);
for (sub, &code) in codes.iter().enumerate() {
let centroid = &self.codebooks[sub][code as usize];
vector.extend_from_slice(centroid);
}
vector
}
pub fn compute_distance_table(&self, query: &[f32]) -> Vec<Vec<f32>> {
if !self.trained {
return Vec::new();
}
assert_eq!(query.len(), self.dimension);
let mut table = Vec::with_capacity(self.m);
for sub in 0..self.m {
let start_dim = sub * self.d_sub;
let end_dim = start_dim + self.d_sub;
let query_sub = &query[start_dim..end_dim];
let distances: Vec<f32> = self.codebooks[sub]
.iter()
.map(|centroid| squared_euclidean(query_sub, centroid))
.collect();
table.push(distances);
}
table
}
pub fn asymmetric_distance(&self, table: &[Vec<f32>], codes: &[u8]) -> f32 {
if table.len() != self.m || codes.len() != self.m {
return f32::MAX;
}
let mut dist_sq = 0.0;
for (sub, &code) in codes.iter().enumerate() {
dist_sq += table[sub][code as usize];
}
dist_sq.sqrt()
}
pub fn bytes_per_vector(&self) -> usize {
self.m
}
pub fn compression_ratio(&self) -> f32 {
(4 * self.dimension) as f32 / self.m as f32
}
pub fn num_subspaces(&self) -> usize {
self.m
}
}
fn kmeans_clustering(vectors: &[Vec<f32>], k: usize, iterations: usize) -> Vec<Vec<f32>> {
if vectors.is_empty() {
return Vec::new();
}
let dim = vectors[0].len();
let n = vectors.len();
let k = k.min(n);
let mut centroids = kmeans_plus_plus_init(vectors, k);
let mut assignments = vec![0usize; n];
for _iter in 0..iterations {
for (i, vec) in vectors.iter().enumerate() {
assignments[i] = find_nearest_centroid(vec, ¢roids);
}
let mut new_centroids = vec![vec![0.0f32; dim]; k];
let mut counts = vec![0usize; k];
for (i, vec) in vectors.iter().enumerate() {
let cluster = assignments[i];
counts[cluster] += 1;
for (j, &val) in vec.iter().enumerate() {
new_centroids[cluster][j] += val;
}
}
for (cluster, centroid) in new_centroids.iter_mut().enumerate() {
if counts[cluster] > 0 {
for val in centroid.iter_mut() {
*val /= counts[cluster] as f32;
}
} else {
*centroid = vectors[cluster % n].clone();
}
}
centroids = new_centroids;
}
centroids
}
fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize) -> Vec<Vec<f32>> {
let n = vectors.len();
if n == 0 || k == 0 {
return Vec::new();
}
let mut centroids = Vec::with_capacity(k);
let mut rng_state = 42u64;
let first_idx = (simple_random(&mut rng_state) % n as u64) as usize;
centroids.push(vectors[first_idx].clone());
let mut min_distances = vec![f32::MAX; n];
for _ in 1..k {
let last_centroid = centroids.last().unwrap();
for (i, vec) in vectors.iter().enumerate() {
let dist = squared_euclidean(vec, last_centroid);
min_distances[i] = min_distances[i].min(dist);
}
let total: f32 = min_distances.iter().sum();
if total == 0.0 {
break;
}
let threshold = (simple_random(&mut rng_state) as f32 / u64::MAX as f32) * total;
let mut cumsum = 0.0;
let mut chosen_idx = 0;
for (i, &d) in min_distances.iter().enumerate() {
cumsum += d;
if cumsum >= threshold {
chosen_idx = i;
break;
}
}
centroids.push(vectors[chosen_idx].clone());
}
centroids
}
fn simple_random(state: &mut u64) -> u64 {
*state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
*state
}
fn find_nearest_centroid(vector: &[f32], centroids: &[Vec<f32>]) -> usize {
let mut min_dist = f32::MAX;
let mut min_idx = 0;
for (i, centroid) in centroids.iter().enumerate() {
let dist = squared_euclidean(vector, centroid);
if dist < min_dist {
min_dist = dist;
min_idx = i;
}
}
min_idx
}
fn squared_euclidean(a: &[f32], b: &[f32]) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| (x - y).powi(2)).sum()
}
fn simple_pq_encode(vector: &[f32]) -> Vec<u8> {
vector
.iter()
.step_by(8)
.map(|&v| {
((v.clamp(-1.0, 1.0) + 1.0) * 127.5) as u8
})
.collect()
}
#[derive(Debug, Clone)]
pub struct SearchResult {
pub id: String,
pub distance: f32,
pub metadata: Option<HashMap<String, SochValue>>,
}
#[derive(Debug, Clone)]
pub struct VectorStats {
pub count: usize,
pub dimension: usize,
pub backend: String,
pub memory_bytes: usize,
pub pq_enabled: bool,
pub migration_progress: Option<u8>,
}
#[derive(Debug)]
pub struct MigrationState {
pub migrated_count: AtomicUsize,
pub total_count: usize,
pub batch_size: usize,
pub in_progress: AtomicBool,
}
impl MigrationState {
fn new(total_count: usize) -> Self {
Self {
migrated_count: AtomicUsize::new(0),
total_count,
batch_size: DEFAULT_MIGRATION_BATCH_SIZE,
in_progress: AtomicBool::new(true),
}
}
fn progress(&self) -> u8 {
if self.total_count == 0 {
return 100;
}
let migrated = self.migrated_count.load(Ordering::SeqCst);
((migrated as f64 / self.total_count as f64) * 100.0).min(100.0) as u8
}
fn is_complete(&self) -> bool {
self.migrated_count.load(Ordering::SeqCst) >= self.total_count
}
}
pub struct VectorCollection {
name: String,
dimension: usize,
backend: VectorBackend,
pq: Option<ProductQuantizer>,
#[allow(dead_code)]
conn: Arc<SochConnection>,
id_map: RwLock<HashMap<String, usize>>,
reverse_map: RwLock<HashMap<usize, String>>,
}
enum VectorBackend {
InMemory { vectors: RwLock<Vec<Vec<f32>>> },
Vamana {
vectors: RwLock<Vec<Vec<f32>>>,
pq_codes: RwLock<Vec<Vec<u8>>>,
},
Hybrid {
hnsw_vectors: RwLock<Vec<Vec<f32>>>,
vamana_vectors: RwLock<Vec<Vec<f32>>>,
pq_codes: RwLock<Vec<Vec<u8>>>,
migrated_bitmap: RwLock<Vec<bool>>,
migration: Arc<MigrationState>,
},
}
impl VectorCollection {
pub fn open(conn: &Arc<SochConnection>, name: &str) -> Result<Self> {
Ok(Self {
name: name.to_string(),
dimension: 0,
backend: VectorBackend::InMemory {
vectors: RwLock::new(Vec::new()),
},
pq: None,
conn: Arc::clone(conn),
id_map: RwLock::new(HashMap::new()),
reverse_map: RwLock::new(HashMap::new()),
})
}
pub fn create(conn: &Arc<SochConnection>, name: &str, dimension: usize) -> Result<Self> {
Ok(Self {
name: name.to_string(),
dimension,
backend: VectorBackend::InMemory {
vectors: RwLock::new(Vec::new()),
},
pq: None,
conn: Arc::clone(conn),
id_map: RwLock::new(HashMap::new()),
reverse_map: RwLock::new(HashMap::new()),
})
}
pub fn train_pq(&mut self) -> Result<()> {
let vectors = match &self.backend {
VectorBackend::InMemory { vectors } => vectors.read().clone(),
VectorBackend::Vamana { vectors, .. } => vectors.read().clone(),
VectorBackend::Hybrid { hnsw_vectors, .. } => hnsw_vectors.read().clone(),
};
if vectors.len() < MIN_PQ_TRAINING_VECTORS {
return Err(ClientError::Validation(format!(
"Need at least {} vectors for PQ training, got {}",
MIN_PQ_TRAINING_VECTORS,
vectors.len()
)));
}
let mut pq = ProductQuantizer::new_default(self.dimension);
pq.train(&vectors, DEFAULT_PQ_ITERATIONS)
.map_err(ClientError::Validation)?;
self.pq = Some(pq);
Ok(())
}
pub fn is_pq_trained(&self) -> bool {
self.pq.as_ref().map(|pq| pq.is_trained()).unwrap_or(false)
}
pub fn name(&self) -> &str {
&self.name
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn len(&self) -> usize {
self.id_map.read().len()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn add(&mut self, ids: &[&str], vectors: &[Vec<f32>]) -> Result<()> {
if ids.len() != vectors.len() {
return Err(ClientError::Validation(
"IDs and vectors must have same length".to_string(),
));
}
for vec in vectors {
if self.dimension == 0 {
self.dimension = vec.len();
} else if vec.len() != self.dimension {
return Err(ClientError::Validation(format!(
"Vector dimension mismatch: expected {}, got {}",
self.dimension,
vec.len()
)));
}
}
let current_size = self.len();
let new_size = current_size + vectors.len();
if new_size > VAMANA_THRESHOLD
&& let VectorBackend::InMemory { .. } = &self.backend
{
self.start_gradual_migration()?;
}
let pq_ref = self.pq.as_ref();
let encode_vec = |vec: &[f32]| -> Vec<u8> {
if let Some(pq) = pq_ref
&& pq.is_trained()
{
return pq.encode(vec);
}
simple_pq_encode(vec)
};
match &mut self.backend {
VectorBackend::InMemory { vectors: store } => {
let mut store = store.write();
let mut id_map = self.id_map.write();
let mut reverse_map = self.reverse_map.write();
for (id, vec) in ids.iter().zip(vectors.iter()) {
let internal_id = store.len();
store.push(vec.clone());
id_map.insert(id.to_string(), internal_id);
reverse_map.insert(internal_id, id.to_string());
}
}
VectorBackend::Vamana {
vectors: store,
pq_codes,
} => {
let mut store = store.write();
let mut codes = pq_codes.write();
let mut id_map = self.id_map.write();
let mut reverse_map = self.reverse_map.write();
for (id, vec) in ids.iter().zip(vectors.iter()) {
let internal_id = store.len();
store.push(vec.clone());
let code = encode_vec(vec);
codes.push(code);
id_map.insert(id.to_string(), internal_id);
reverse_map.insert(internal_id, id.to_string());
}
}
VectorBackend::Hybrid {
hnsw_vectors,
vamana_vectors,
pq_codes,
migrated_bitmap,
..
} => {
let mut hnsw = hnsw_vectors.write();
let mut vamana = vamana_vectors.write();
let mut codes = pq_codes.write();
let mut bitmap = migrated_bitmap.write();
let mut id_map = self.id_map.write();
let mut reverse_map = self.reverse_map.write();
for (id, vec) in ids.iter().zip(vectors.iter()) {
let internal_id = hnsw.len();
hnsw.push(vec.clone());
vamana.push(vec.clone());
let code = encode_vec(vec);
codes.push(code);
bitmap.push(true); id_map.insert(id.to_string(), internal_id);
reverse_map.insert(internal_id, id.to_string());
}
}
}
Ok(())
}
fn encode_vector(&self, vector: &[f32]) -> Vec<u8> {
if let Some(ref pq) = self.pq
&& pq.is_trained()
{
return pq.encode(vector);
}
simple_pq_encode(vector)
}
pub fn add_one(&mut self, id: &str, vector: Vec<f32>) -> Result<()> {
self.add(&[id], &[vector])
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
if query.len() != self.dimension {
return Err(ClientError::Validation(format!(
"Query dimension mismatch: expected {}, got {}",
self.dimension,
query.len()
)));
}
match &self.backend {
VectorBackend::InMemory { vectors } => {
self.search_brute_force(vectors.read().as_slice(), query, k)
}
VectorBackend::Vamana { vectors, pq_codes } => {
if self.is_pq_trained() {
self.search_with_pq(
vectors.read().as_slice(),
pq_codes.read().as_slice(),
query,
k,
)
} else {
self.search_brute_force(vectors.read().as_slice(), query, k)
}
}
VectorBackend::Hybrid {
hnsw_vectors,
vamana_vectors,
pq_codes,
migrated_bitmap,
migration,
} => {
self.search_hybrid(
hnsw_vectors.read().as_slice(),
vamana_vectors.read().as_slice(),
pq_codes.read().as_slice(),
migrated_bitmap.read().as_slice(),
query,
k,
migration,
)
}
}
}
pub fn get(&self, id: &str) -> Option<Vec<f32>> {
let internal_id = *self.id_map.read().get(id)?;
match &self.backend {
VectorBackend::InMemory { vectors } => vectors.read().get(internal_id).cloned(),
VectorBackend::Vamana { vectors, .. } => vectors.read().get(internal_id).cloned(),
VectorBackend::Hybrid { hnsw_vectors, .. } => {
hnsw_vectors.read().get(internal_id).cloned()
}
}
}
pub fn delete(&mut self, id: &str) -> Result<bool> {
let internal_id = match self.id_map.write().remove(id) {
Some(id) => id,
None => return Ok(false),
};
self.reverse_map.write().remove(&internal_id);
Ok(true)
}
pub fn stats(&self) -> VectorStats {
let pq_trained = self.is_pq_trained();
let (backend_name, pq_enabled, migration_progress) = match &self.backend {
VectorBackend::InMemory { .. } => ("InMemory/HNSW", pq_trained, None),
VectorBackend::Vamana { .. } => ("Vamana", pq_trained, None),
VectorBackend::Hybrid { migration, .. } => {
("Hybrid (migrating)", pq_trained, Some(migration.progress()))
}
};
let count = self.len();
let memory_bytes = if let Some(ref pq) = self.pq {
if pq.is_trained() {
count * pq.bytes_per_vector()
} else {
count * self.dimension * 4
}
} else {
count * self.dimension * 4 };
VectorStats {
count,
dimension: self.dimension,
backend: backend_name.to_string(),
memory_bytes,
pq_enabled,
migration_progress,
}
}
pub fn compression_ratio(&self) -> Option<f32> {
self.pq.as_ref().and_then(|pq| {
if pq.is_trained() {
Some(pq.compression_ratio())
} else {
None
}
})
}
fn start_gradual_migration(&mut self) -> Result<()> {
let old_vectors = match &self.backend {
VectorBackend::InMemory { vectors } => vectors.read().clone(),
_ => return Ok(()), };
let total_count = old_vectors.len();
let migration = Arc::new(MigrationState::new(total_count));
self.backend = VectorBackend::Hybrid {
hnsw_vectors: RwLock::new(old_vectors),
vamana_vectors: RwLock::new(Vec::with_capacity(total_count)),
pq_codes: RwLock::new(Vec::with_capacity(total_count)),
migrated_bitmap: RwLock::new(vec![false; total_count]),
migration,
};
Ok(())
}
pub fn migrate_batch(&mut self) -> Result<usize> {
let VectorBackend::Hybrid {
hnsw_vectors,
vamana_vectors,
pq_codes,
migrated_bitmap,
migration,
} = &self.backend
else {
return Ok(0); };
if migration.is_complete() {
self.finalize_migration()?;
return Ok(0);
}
let batch_size = migration.batch_size;
let migrated_so_far = migration.migrated_count.load(Ordering::SeqCst);
let to_migrate = batch_size.min(migration.total_count - migrated_so_far);
if to_migrate == 0 {
return Ok(0);
}
let hnsw = hnsw_vectors.read();
let mut vamana = vamana_vectors.write();
let mut codes = pq_codes.write();
let mut bitmap = migrated_bitmap.write();
let start = migrated_so_far;
let end = start + to_migrate;
let pq_ref = self.pq.as_ref();
for i in start..end {
if i < hnsw.len() && !bitmap[i] {
let vec = &hnsw[i];
vamana.push(vec.clone());
let code = if let Some(pq) = pq_ref {
if pq.is_trained() {
pq.encode(vec)
} else {
simple_pq_encode(vec)
}
} else {
simple_pq_encode(vec)
};
codes.push(code);
bitmap[i] = true;
}
}
migration
.migrated_count
.fetch_add(to_migrate, Ordering::SeqCst);
if migration.is_complete() {
drop(hnsw);
drop(vamana);
drop(codes);
drop(bitmap);
self.finalize_migration()?;
}
Ok(to_migrate)
}
fn finalize_migration(&mut self) -> Result<()> {
let VectorBackend::Hybrid {
vamana_vectors,
pq_codes,
..
} = std::mem::replace(
&mut self.backend,
VectorBackend::InMemory {
vectors: RwLock::new(Vec::new()),
},
)
else {
return Ok(()); };
self.backend = VectorBackend::Vamana {
vectors: vamana_vectors,
pq_codes,
};
Ok(())
}
#[allow(dead_code)]
fn promote_to_vamana(&mut self) -> Result<()> {
let old_vectors = match &self.backend {
VectorBackend::InMemory { vectors } => vectors.read().clone(),
_ => return Ok(()), };
let pq_codes: Vec<Vec<u8>> = old_vectors.iter().map(|v| self.encode_vector(v)).collect();
self.backend = VectorBackend::Vamana {
vectors: RwLock::new(old_vectors),
pq_codes: RwLock::new(pq_codes),
};
Ok(())
}
fn search_brute_force(
&self,
vectors: &[Vec<f32>],
query: &[f32],
k: usize,
) -> Result<Vec<SearchResult>> {
let reverse_map = self.reverse_map.read();
let mut distances: Vec<(usize, f32)> = vectors
.iter()
.enumerate()
.map(|(i, v)| (i, euclidean_distance(query, v)))
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(distances
.into_iter()
.take(k)
.filter_map(|(idx, dist)| {
reverse_map.get(&idx).map(|id| SearchResult {
id: id.clone(),
distance: dist,
metadata: None,
})
})
.collect())
}
fn search_with_pq(
&self,
vectors: &[Vec<f32>],
pq_codes: &[Vec<u8>],
query: &[f32],
k: usize,
) -> Result<Vec<SearchResult>> {
if let Some(ref pq) = self.pq
&& pq.is_trained()
{
return self.search_with_adc(pq_codes, query, k, pq, vectors);
}
self.search_brute_force(vectors, query, k)
}
fn search_with_adc(
&self,
pq_codes: &[Vec<u8>],
query: &[f32],
k: usize,
pq: &ProductQuantizer,
vectors: &[Vec<f32>],
) -> Result<Vec<SearchResult>> {
let reverse_map = self.reverse_map.read();
let distance_table = pq.compute_distance_table(query);
let mut candidates: Vec<(usize, f32)> = pq_codes
.iter()
.enumerate()
.map(|(i, code)| (i, pq.asymmetric_distance(&distance_table, code)))
.collect();
let rerank_count = (k * 4).min(candidates.len());
candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
candidates.truncate(rerank_count);
let mut reranked: Vec<(usize, f32)> = candidates
.iter()
.filter_map(|&(idx, _)| {
vectors
.get(idx)
.map(|v| (idx, euclidean_distance(query, v)))
})
.collect();
reranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(reranked
.into_iter()
.take(k)
.filter_map(|(idx, dist)| {
reverse_map.get(&idx).map(|id| SearchResult {
id: id.clone(),
distance: dist,
metadata: None,
})
})
.collect())
}
#[allow(clippy::too_many_arguments)]
fn search_hybrid(
&self,
hnsw_vectors: &[Vec<f32>],
vamana_vectors: &[Vec<f32>],
_pq_codes: &[Vec<u8>],
migrated_bitmap: &[bool],
query: &[f32],
k: usize,
migration: &MigrationState,
) -> Result<Vec<SearchResult>> {
let reverse_map = self.reverse_map.read();
let migration_progress = migration.progress();
let mut all_distances: Vec<(usize, f32)> = if migration_progress < 50 {
hnsw_vectors
.iter()
.enumerate()
.map(|(i, v)| (i, euclidean_distance(query, v)))
.collect()
} else if migration_progress >= 90 {
vamana_vectors
.iter()
.enumerate()
.map(|(i, v)| (i, euclidean_distance(query, v)))
.collect()
} else {
let mut distances = Vec::with_capacity(hnsw_vectors.len());
for (i, hnsw_vec) in hnsw_vectors.iter().enumerate() {
if i < migrated_bitmap.len() && migrated_bitmap[i] {
if i < vamana_vectors.len() {
distances.push((i, euclidean_distance(query, &vamana_vectors[i])));
}
} else {
distances.push((i, euclidean_distance(query, hnsw_vec)));
}
}
distances
};
all_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
Ok(all_distances
.into_iter()
.take(k)
.filter_map(|(idx, dist)| {
reverse_map.get(&idx).map(|id| SearchResult {
id: id.clone(),
distance: dist,
metadata: None,
})
})
.collect())
}
}
#[inline]
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
l2_squared_f32(a, b).sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
fn test_conn() -> Arc<SochConnection> {
Arc::new(SochConnection::open("./test").unwrap())
}
#[test]
fn test_create_collection() {
let conn = test_conn();
let coll = VectorCollection::create(&conn, "test_vectors", 128).unwrap();
assert_eq!(coll.name(), "test_vectors");
assert_eq!(coll.dimension(), 128);
assert!(coll.is_empty());
}
#[test]
fn test_add_vectors() {
let conn = test_conn();
let mut coll = VectorCollection::create(&conn, "test", 4).unwrap();
coll.add(
&["a", "b"],
&[vec![1.0, 0.0, 0.0, 0.0], vec![0.0, 1.0, 0.0, 0.0]],
)
.unwrap();
assert_eq!(coll.len(), 2);
assert!(coll.get("a").is_some());
}
#[test]
fn test_search() {
let conn = test_conn();
let mut coll = VectorCollection::create(&conn, "test", 4).unwrap();
coll.add(
&["a", "b", "c"],
&[
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
],
)
.unwrap();
let results = coll.search(&[1.0, 0.1, 0.0, 0.0], 2).unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].id, "a"); }
#[test]
fn test_delete() {
let conn = test_conn();
let mut coll = VectorCollection::create(&conn, "test", 4).unwrap();
coll.add_one("a", vec![1.0, 0.0, 0.0, 0.0]).unwrap();
assert_eq!(coll.len(), 1);
assert!(coll.delete("a").unwrap());
assert!(!coll.delete("a").unwrap()); }
#[test]
fn test_stats() {
let conn = test_conn();
let mut coll = VectorCollection::create(&conn, "test", 128).unwrap();
coll.add_one("a", vec![0.0; 128]).unwrap();
let stats = coll.stats();
assert_eq!(stats.count, 1);
assert_eq!(stats.dimension, 128);
assert!(!stats.pq_enabled);
}
#[test]
fn test_euclidean_distance() {
let a = vec![0.0, 0.0];
let b = vec![3.0, 4.0];
assert!((euclidean_distance(&a, &b) - 5.0).abs() < 0.001);
}
#[test]
fn test_pq_new() {
let pq = ProductQuantizer::new(128, 8, 256);
assert_eq!(pq.m, 8);
assert_eq!(pq.k, 256);
assert_eq!(pq.d_sub, 16);
assert!(!pq.is_trained());
}
#[test]
fn test_pq_new_default() {
let pq = ProductQuantizer::new_default(768);
assert!(pq.m > 0);
assert_eq!(pq.k, 256);
assert!(!pq.is_trained());
}
#[test]
fn test_pq_train_and_encode() {
let dimension = 32;
let n_vectors = 1500;
let mut vectors = Vec::with_capacity(n_vectors);
let mut rng = 42u64;
for _ in 0..n_vectors {
let mut vec = Vec::with_capacity(dimension);
for _ in 0..dimension {
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
let val = ((rng >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
vec.push(val);
}
vectors.push(vec);
}
let mut pq = ProductQuantizer::new(dimension, 4, 256);
pq.train(&vectors, 10).unwrap();
assert!(pq.is_trained());
assert_eq!(pq.codebooks.len(), 4);
assert_eq!(pq.codebooks[0].len(), 256);
let code = pq.encode(&vectors[0]);
assert_eq!(code.len(), 4);
let reconstructed = pq.decode(&code);
assert_eq!(reconstructed.len(), dimension);
}
#[test]
fn test_pq_asymmetric_distance() {
let dimension = 32;
let n_vectors = 1500;
let mut vectors = Vec::with_capacity(n_vectors);
let mut rng = 123u64;
for _ in 0..n_vectors {
let mut vec = Vec::with_capacity(dimension);
for _ in 0..dimension {
rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
let val = ((rng >> 33) as f32 / u32::MAX as f32) * 2.0 - 1.0;
vec.push(val);
}
vectors.push(vec);
}
let mut pq = ProductQuantizer::new(dimension, 4, 256);
pq.train(&vectors, 10).unwrap();
let codes: Vec<Vec<u8>> = vectors.iter().take(100).map(|v| pq.encode(v)).collect();
let query = &vectors[0];
let table = pq.compute_distance_table(query);
let self_dist = pq.asymmetric_distance(&table, &codes[0]);
assert!(
self_dist < 1.0,
"Self distance should be small, got {}",
self_dist
);
let mut adc_distances: Vec<(usize, f32)> = codes
.iter()
.enumerate()
.map(|(i, c)| (i, pq.asymmetric_distance(&table, c)))
.collect();
adc_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
assert_eq!(adc_distances[0].0, 0, "Query should be nearest to itself");
}
#[test]
fn test_pq_compression_ratio() {
let pq = ProductQuantizer::new(768, 96, 256);
let ratio = pq.compression_ratio();
assert!((ratio - 32.0).abs() < 0.1);
}
#[test]
fn test_kmeans_clustering() {
let mut vectors = Vec::new();
for _ in 0..50 {
vectors.push(vec![1.0 + 0.1, 1.0 + 0.1]);
}
for _ in 0..50 {
vectors.push(vec![-1.0 - 0.1, -1.0 - 0.1]);
}
let centroids = kmeans_clustering(&vectors, 2, 20);
assert_eq!(centroids.len(), 2);
let has_positive = centroids.iter().any(|c| c[0] > 0.5 && c[1] > 0.5);
let has_negative = centroids.iter().any(|c| c[0] < -0.5 && c[1] < -0.5);
assert!(has_positive, "Should have centroid near (1, 1)");
assert!(has_negative, "Should have centroid near (-1, -1)");
}
#[test]
fn test_squared_euclidean() {
let a = vec![0.0, 0.0, 0.0];
let b = vec![1.0, 2.0, 2.0];
let dist_sq = squared_euclidean(&a, &b);
assert!((dist_sq - 9.0).abs() < 0.001);
}
}