use std::collections::{HashMap, HashSet};
use parking_lot::RwLock;
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
use common::{DistanceMetric, Vector, VectorId};
use crate::distance::calculate_distance;
/// Tuning parameters for an [`SpFreshIndex`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpFreshConfig {
    /// Number of clusters (k for k-means) created by `train`.
    pub num_clusters: usize,
    /// Live-vector count above which a cluster is split.
    pub max_cluster_size: usize,
    /// Live-vector count below which a cluster is a merge candidate.
    pub min_cluster_size: usize,
    /// Number of nearest clusters probed per search.
    pub n_probe: usize,
    /// Tombstone ratio (tombstones / stored vectors) at or above which a
    /// cluster is physically compacted.
    pub compaction_threshold: f32,
    /// Metric passed to every `calculate_distance` call.
    pub distance_metric: DistanceMetric,
}
impl Default for SpFreshConfig {
fn default() -> Self {
Self {
num_clusters: 16,
max_cluster_size: 1000,
min_cluster_size: 50,
n_probe: 4,
compaction_threshold: 0.3,
distance_metric: DistanceMetric::Cosine,
}
}
}
/// A single cluster: a centroid plus the vectors assigned to it.
///
/// Deletions are logical: an id in `tombstones` hides every stored copy of
/// that vector until `compact` physically removes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cluster {
    /// Index of this cluster within the index's cluster list.
    pub id: usize,
    /// Mean of the live member vectors (updated by `recompute_centroid`).
    pub centroid: Vec<f32>,
    /// All stored vectors, including logically deleted (tombstoned) ones.
    pub vectors: Vec<Vector>,
    /// Ids whose stored copies are logically deleted.
    pub tombstones: HashSet<VectorId>,
    /// Cached count of live (non-tombstoned) vectors.
    pub live_count: usize,
}
impl Cluster {
fn new(id: usize, centroid: Vec<f32>) -> Self {
Self {
id,
centroid,
vectors: Vec::new(),
tombstones: HashSet::new(),
live_count: 0,
}
}
fn live_vectors(&self) -> impl Iterator<Item = &Vector> {
self.vectors
.iter()
.filter(|v| !self.tombstones.contains(&v.id))
}
fn tombstone_ratio(&self) -> f32 {
if self.vectors.is_empty() {
0.0
} else {
self.tombstones.len() as f32 / self.vectors.len() as f32
}
}
fn recompute_centroid(&mut self) {
let live: Vec<&Vector> = self.live_vectors().collect();
if live.is_empty() {
return;
}
let dim = live[0].values.len();
let mut new_centroid = vec![0.0f32; dim];
for vector in &live {
for (i, &val) in vector.values.iter().enumerate() {
new_centroid[i] += val;
}
}
let count = live.len() as f32;
for val in &mut new_centroid {
*val /= count;
}
self.centroid = new_centroid;
}
fn compact(&mut self) {
self.vectors.retain(|v| !self.tombstones.contains(&v.id));
self.tombstones.clear();
self.live_count = self.vectors.len();
}
}
/// A single search hit.
#[derive(Debug, Clone)]
pub struct SpFreshSearchResult {
    /// Id of the matched vector.
    pub id: VectorId,
    /// Raw `calculate_distance` value between the query and this vector
    /// (a distance under the configured metric — smaller means closer).
    pub score: f32,
    /// The matched vector itself; always `Some` as produced by `search`.
    pub vector: Option<Vector>,
}
/// An incrementally updatable, SPFresh-style clustered (IVF) vector index.
///
/// All state lives behind `parking_lot::RwLock`s so a shared reference can
/// serve concurrent reads and writes. Note: `parking_lot` locks are not
/// reentrant, so methods must be careful about lock acquisition order.
pub struct SpFreshIndex {
    /// Tuning knobs; immutable after construction.
    config: SpFreshConfig,
    /// All clusters, indexed by cluster id.
    clusters: RwLock<Vec<Cluster>>,
    /// Maps a vector id to the cluster currently holding its newest copy.
    vector_cluster_map: RwLock<HashMap<VectorId, usize>>,
    /// Ids deleted before training; blocks later re-insertion of those ids.
    global_tombstones: RwLock<HashSet<VectorId>>,
    /// Vectors buffered by `add` before `train` has run.
    pending_vectors: RwLock<Vec<Vector>>,
    /// Whether `train` has completed (or a trained snapshot was restored).
    trained: RwLock<bool>,
    /// Vector dimensionality, fixed by the first vector seen.
    dimension: RwLock<Option<usize>>,
}
impl SpFreshIndex {
    /// Creates an empty, untrained index with the given configuration.
    pub fn new(config: SpFreshConfig) -> Self {
        Self {
            config,
            clusters: RwLock::new(Vec::new()),
            vector_cluster_map: RwLock::new(HashMap::new()),
            global_tombstones: RwLock::new(HashSet::new()),
            pending_vectors: RwLock::new(Vec::new()),
            trained: RwLock::new(false),
            dimension: RwLock::new(None),
        }
    }

    /// True once `train` has completed (or a trained snapshot was restored).
    pub fn is_trained(&self) -> bool {
        *self.trained.read()
    }

    /// Dimensionality of indexed vectors, fixed by the first vector seen.
    pub fn dimension(&self) -> Option<usize> {
        *self.dimension.read()
    }

    /// Builds the clustering (k-means++ seeding followed by Lloyd iterations)
    /// and assigns every training vector — plus any vectors buffered by `add`
    /// before training — to its nearest cluster.
    ///
    /// Ids deleted before training (in `global_tombstones`) are skipped, so a
    /// pre-train `remove` stays effective across training.
    ///
    /// # Errors
    /// Returns an error if `vectors` is empty.
    pub fn train(&self, vectors: &[Vector]) -> Result<(), String> {
        if vectors.is_empty() {
            return Err("Cannot train with empty vectors".to_string());
        }
        let dim = vectors[0].values.len();
        *self.dimension.write() = Some(dim);
        let centroids = self.kmeans_plus_plus_init(vectors);
        let final_centroids = self.kmeans_iterate(vectors, centroids, 20);
        let mut clusters = Vec::with_capacity(self.config.num_clusters);
        for (i, centroid) in final_centroids.into_iter().enumerate() {
            clusters.push(Cluster::new(i, centroid));
        }
        // Drain the pre-training buffer: once `trained` flips, `search` stops
        // consulting `pending_vectors`, so anything left there would silently
        // become unsearchable.
        let pending: Vec<Vector> = self.pending_vectors.write().drain(..).collect();
        let tombstones = self.global_tombstones.read();
        let mut vector_cluster_map = HashMap::new();
        for vector in vectors.iter().chain(pending.iter()) {
            // Skip ids deleted before training and (defensively) any buffered
            // vector whose dimension disagrees with the training set.
            if tombstones.contains(&vector.id) || vector.values.len() != dim {
                continue;
            }
            let cluster_id = self.find_nearest_cluster_idx(&vector.values, &clusters);
            clusters[cluster_id].vectors.push(vector.clone());
            clusters[cluster_id].live_count += 1;
            vector_cluster_map.insert(vector.id.clone(), cluster_id);
        }
        drop(tombstones);
        // Snap each centroid to the mean of its actual members.
        for cluster in &mut clusters {
            cluster.recompute_centroid();
        }
        *self.clusters.write() = clusters;
        *self.vector_cluster_map.write() = vector_cluster_map;
        *self.trained.write() = true;
        Ok(())
    }

    /// k-means++ seeding: the first centroid is chosen uniformly at random,
    /// each subsequent one is sampled with probability proportional to its
    /// distance from the nearest already-chosen centroid.
    fn kmeans_plus_plus_init(&self, vectors: &[Vector]) -> Vec<Vec<f32>> {
        let mut rng = rand::thread_rng();
        let k = self.config.num_clusters.min(vectors.len());
        let mut centroids = Vec::with_capacity(k);
        let first = vectors.choose(&mut rng).unwrap();
        centroids.push(first.values.clone());
        for _ in 1..k {
            // Distance of every point to its closest chosen centroid.
            let mut distances: Vec<f32> = vectors
                .iter()
                .map(|v| {
                    centroids
                        .iter()
                        .map(|c| calculate_distance(&v.values, c, self.config.distance_metric))
                        .fold(f32::MAX, f32::min)
                })
                .collect();
            let total: f32 = distances.iter().sum();
            if total == 0.0 {
                // Every remaining point coincides with a centroid; more
                // centroids would be duplicates.
                break;
            }
            for d in &mut distances {
                *d /= total;
            }
            // Roulette-wheel selection over the normalized distances.
            let threshold: f32 = rand::random();
            let mut cumsum = 0.0;
            let mut chosen = None;
            for (i, d) in distances.iter().enumerate() {
                cumsum += d;
                if cumsum >= threshold {
                    chosen = Some(i);
                    break;
                }
            }
            // Float rounding can leave the cumulative sum a hair below 1.0;
            // fall back to the last candidate instead of silently producing
            // fewer than k centroids.
            centroids.push(vectors[chosen.unwrap_or(vectors.len() - 1)].values.clone());
        }
        centroids
    }

    /// Standard Lloyd iterations: assign each point to its nearest centroid,
    /// then move every centroid to the mean of its assignment. A centroid
    /// with no assigned points keeps its previous position. Stops early once
    /// the centroids no longer move.
    fn kmeans_iterate(
        &self,
        vectors: &[Vector],
        mut centroids: Vec<Vec<f32>>,
        max_iters: usize,
    ) -> Vec<Vec<f32>> {
        let dim = vectors[0].values.len();
        for _ in 0..max_iters {
            let mut assignments: Vec<Vec<&Vector>> = vec![Vec::new(); centroids.len()];
            for vector in vectors {
                let mut best_idx = 0;
                let mut best_dist = f32::MAX;
                for (i, centroid) in centroids.iter().enumerate() {
                    let dist =
                        calculate_distance(&vector.values, centroid, self.config.distance_metric);
                    if dist < best_dist {
                        best_dist = dist;
                        best_idx = i;
                    }
                }
                assignments[best_idx].push(vector);
            }
            let mut new_centroids = Vec::with_capacity(centroids.len());
            for (i, assigned) in assignments.iter().enumerate() {
                if assigned.is_empty() {
                    new_centroids.push(centroids[i].clone());
                } else {
                    let mut new_centroid = vec![0.0f32; dim];
                    for vector in assigned {
                        for (j, &val) in vector.values.iter().enumerate() {
                            new_centroid[j] += val;
                        }
                    }
                    let count = assigned.len() as f32;
                    for val in &mut new_centroid {
                        *val /= count;
                    }
                    new_centroids.push(new_centroid);
                }
            }
            if new_centroids == centroids {
                // Converged: further iterations would be no-ops.
                break;
            }
            centroids = new_centroids;
        }
        centroids
    }

    /// Index of the cluster whose centroid is closest to `vector`.
    fn find_nearest_cluster_idx(&self, vector: &[f32], clusters: &[Cluster]) -> usize {
        let mut best_idx = 0;
        let mut best_dist = f32::MAX;
        for (i, cluster) in clusters.iter().enumerate() {
            let dist = calculate_distance(vector, &cluster.centroid, self.config.distance_metric);
            if dist < best_dist {
                best_dist = dist;
                best_idx = i;
            }
        }
        best_idx
    }

    /// Inserts (or upserts) vectors. Before training they are buffered in
    /// `pending_vectors`; afterwards each vector is routed to its nearest
    /// cluster and oversized clusters are split. Returns the number of
    /// vectors actually indexed (ids blacklisted by a pre-train delete are
    /// skipped and not counted).
    ///
    /// # Errors
    /// Returns an error if a vector's dimension does not match the index.
    pub fn add(&self, vectors: Vec<Vector>) -> Result<usize, String> {
        if vectors.is_empty() {
            return Ok(0);
        }
        let dim = vectors[0].values.len();
        {
            let current_dim = *self.dimension.read();
            if let Some(expected) = current_dim {
                if dim != expected {
                    return Err(format!(
                        "Dimension mismatch: expected {}, got {}",
                        expected, dim
                    ));
                }
            } else {
                // The first vectors seen fix the index dimensionality.
                *self.dimension.write() = Some(dim);
            }
        }
        if !self.is_trained() {
            let mut pending = self.pending_vectors.write();
            let tombstones = self.global_tombstones.read();
            let mut added = 0;
            for vector in vectors {
                if !tombstones.contains(&vector.id) {
                    pending.push(vector);
                    added += 1;
                }
            }
            return Ok(added);
        }
        let mut clusters = self.clusters.write();
        let mut vector_map = self.vector_cluster_map.write();
        let global_tombstones = self.global_tombstones.read();
        let mut added = 0;
        for vector in vectors {
            if global_tombstones.contains(&vector.id) {
                continue;
            }
            let cluster_id = self.find_nearest_cluster_idx(&vector.values, &clusters);
            if let Some(&old_cluster_id) = vector_map.get(&vector.id) {
                if old_cluster_id != cluster_id {
                    // The id moved clusters: tombstone the old copy. `insert`
                    // returns false when the id was already tombstoned (e.g.
                    // after a delete), in which case the live count was
                    // already decremented once and must not be paid again.
                    let old = &mut clusters[old_cluster_id];
                    if old.tombstones.insert(vector.id.clone()) {
                        old.live_count = old.live_count.saturating_sub(1);
                    }
                }
                // Purge stale state for this id in the target cluster: a
                // same-cluster update would otherwise leave a duplicate that
                // id-based tombstones can never kill selectively, and a stale
                // tombstone here would immediately hide the new copy.
                let target = &mut clusters[cluster_id];
                let was_dead = target.tombstones.remove(&vector.id);
                let before = target.vectors.len();
                target.vectors.retain(|v| v.id != vector.id);
                if !was_dead && target.vectors.len() < before {
                    target.live_count = target.live_count.saturating_sub(1);
                }
            }
            clusters[cluster_id].vectors.push(vector.clone());
            clusters[cluster_id].live_count += 1;
            vector_map.insert(vector.id.clone(), cluster_id);
            added += 1;
        }
        // Release the map lock first: `split_cluster` re-acquires it and
        // parking_lot locks are not reentrant.
        drop(vector_map);
        self.check_splits(&mut clusters);
        Ok(added)
    }

    /// Splits every original cluster whose live size exceeds
    /// `max_cluster_size`, appending the newly created halves to `clusters`.
    fn check_splits(&self, clusters: &mut Vec<Cluster>) {
        let mut new_clusters = Vec::new();
        let max_size = self.config.max_cluster_size;
        let base_len = clusters.len();
        for cluster in clusters.iter_mut().take(base_len) {
            if cluster.live_count > max_size {
                let new_id = base_len + new_clusters.len();
                if let Some(new_cluster) = self.split_cluster(cluster, new_id) {
                    new_clusters.push(new_cluster);
                }
            }
        }
        clusters.extend(new_clusters);
    }

    /// Splits `cluster` in two around its farthest-apart pair of live vectors
    /// (an O(n^2) scan over the cluster) and updates the id→cluster map for
    /// both halves. Returns `None` when there are fewer than two live vectors.
    ///
    /// NOTE: acquires the `vector_cluster_map` write lock — callers must not
    /// hold it (see `add`).
    fn split_cluster(&self, cluster: &mut Cluster, new_id: usize) -> Option<Cluster> {
        let live_vectors: Vec<Vector> = cluster.live_vectors().cloned().collect();
        if live_vectors.len() < 2 {
            return None;
        }
        // Seed the two halves with the most distant pair of live vectors.
        let mut max_dist = 0.0f32;
        let mut idx1 = 0;
        let mut idx2 = 1;
        for (i, v1) in live_vectors.iter().enumerate() {
            for (j, v2) in live_vectors.iter().enumerate().skip(i + 1) {
                let dist = calculate_distance(&v1.values, &v2.values, self.config.distance_metric);
                if dist > max_dist {
                    max_dist = dist;
                    idx1 = i;
                    idx2 = j;
                }
            }
        }
        let centroid1 = live_vectors[idx1].values.clone();
        let centroid2 = live_vectors[idx2].values.clone();
        let mut vectors1 = Vec::new();
        let mut vectors2 = Vec::new();
        for vector in live_vectors {
            let dist1 = calculate_distance(&vector.values, &centroid1, self.config.distance_metric);
            let dist2 = calculate_distance(&vector.values, &centroid2, self.config.distance_metric);
            if dist1 <= dist2 {
                vectors1.push(vector);
            } else {
                vectors2.push(vector);
            }
        }
        // The first half replaces the existing cluster's contents, which also
        // drops its tombstoned entries for free.
        cluster.vectors = vectors1;
        cluster.tombstones.clear();
        cluster.live_count = cluster.vectors.len();
        cluster.recompute_centroid();
        let mut new_cluster = Cluster::new(new_id, centroid2);
        new_cluster.vectors = vectors2;
        new_cluster.live_count = new_cluster.vectors.len();
        new_cluster.recompute_centroid();
        let mut vector_map = self.vector_cluster_map.write();
        for v in &cluster.vectors {
            vector_map.insert(v.id.clone(), cluster.id);
        }
        for v in &new_cluster.vectors {
            vector_map.insert(v.id.clone(), new_cluster.id);
        }
        Some(new_cluster)
    }

    /// Deletes vectors by id. Before training this drops them from the buffer
    /// and blacklists the ids against re-insertion; after training it
    /// tombstones them in place. Returns the number of vectors newly removed
    /// (unknown or already-deleted ids are not counted).
    pub fn remove(&self, ids: &[VectorId]) -> usize {
        if !self.is_trained() {
            let mut pending = self.pending_vectors.write();
            let mut global_tombstones = self.global_tombstones.write();
            let before = pending.len();
            pending.retain(|v| !ids.contains(&v.id));
            for id in ids {
                // Remember the id so a later `add` of it is rejected.
                global_tombstones.insert(id.clone());
            }
            return before - pending.len();
        }
        let mut clusters = self.clusters.write();
        let vector_map = self.vector_cluster_map.read();
        let mut count = 0;
        for id in ids {
            if let Some(&cluster_id) = vector_map.get(id) {
                if cluster_id < clusters.len() {
                    // `insert` returns false for an already-tombstoned id;
                    // counting it again would double-decrement the live count.
                    if clusters[cluster_id].tombstones.insert(id.clone()) {
                        clusters[cluster_id].live_count =
                            clusters[cluster_id].live_count.saturating_sub(1);
                        count += 1;
                    }
                }
            }
        }
        count
    }

    /// Approximate k-nearest-neighbor search: probes the `n_probe` clusters
    /// whose centroids are closest to the query and ranks their live vectors.
    ///
    /// `score` is the raw `calculate_distance` value, so results are returned
    /// in ascending order — the smallest score is the closest match.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SpFreshSearchResult>, String> {
        if !self.is_trained() {
            return self.search_pending(query, k);
        }
        let clusters = self.clusters.read();
        if clusters.is_empty() {
            return Ok(Vec::new());
        }
        let mut cluster_distances: Vec<(usize, f32)> = clusters
            .iter()
            .enumerate()
            .map(|(i, c)| {
                (
                    i,
                    calculate_distance(query, &c.centroid, self.config.distance_metric),
                )
            })
            .collect();
        cluster_distances
            .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        let n_probe = self.config.n_probe.min(clusters.len());
        let mut results: Vec<SpFreshSearchResult> = Vec::new();
        for (cluster_idx, _) in cluster_distances.iter().take(n_probe) {
            let cluster = &clusters[*cluster_idx];
            for vector in cluster.live_vectors() {
                let score = calculate_distance(query, &vector.values, self.config.distance_metric);
                results.push(SpFreshSearchResult {
                    id: vector.id.clone(),
                    score,
                    vector: Some(vector.clone()),
                });
            }
        }
        // Ascending by distance: nearest first. (Sorting descending here
        // would return the *farthest* vectors, contradicting the cluster
        // probe ordering above, which already treats smaller as better.)
        results.sort_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(k);
        Ok(results)
    }

    /// Brute-force search over the pre-training buffer; same ascending
    /// distance ordering as `search`.
    fn search_pending(&self, query: &[f32], k: usize) -> Result<Vec<SpFreshSearchResult>, String> {
        let pending = self.pending_vectors.read();
        let tombstones = self.global_tombstones.read();
        let mut results: Vec<SpFreshSearchResult> = pending
            .iter()
            .filter(|v| !tombstones.contains(&v.id))
            .map(|v| SpFreshSearchResult {
                id: v.id.clone(),
                score: calculate_distance(query, &v.values, self.config.distance_metric),
                vector: Some(v.clone()),
            })
            .collect();
        results.sort_by(|a, b| {
            a.score
                .partial_cmp(&b.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        results.truncate(k);
        Ok(results)
    }

    /// Physically drops tombstoned vectors from every cluster whose tombstone
    /// ratio reaches `compaction_threshold`, then rebuilds the id→cluster
    /// map. Returns the number of clusters compacted.
    pub fn compact(&self) -> usize {
        if !self.is_trained() {
            return 0;
        }
        let mut clusters = self.clusters.write();
        let mut compacted = 0;
        for cluster in clusters.iter_mut() {
            if cluster.tombstone_ratio() >= self.config.compaction_threshold {
                cluster.compact();
                compacted += 1;
            }
        }
        if compacted > 0 {
            let mut vector_map = self.vector_cluster_map.write();
            vector_map.clear();
            for cluster in clusters.iter() {
                for vector in &cluster.vectors {
                    // Clusters below the threshold were not compacted and may
                    // still hold tombstoned entries; deleted ids must not be
                    // resurrected into the map.
                    if !cluster.tombstones.contains(&vector.id) {
                        vector_map.insert(vector.id.clone(), cluster.id);
                    }
                }
            }
        }
        compacted
    }

    /// Pairs up clusters whose live size is below `min_cluster_size` (but not
    /// empty) and folds each pair's second cluster into its first, leaving the
    /// second empty. Returns the number of merges performed.
    pub fn merge_small_clusters(&self) -> usize {
        if !self.is_trained() {
            return 0;
        }
        let mut clusters = self.clusters.write();
        let min_size = self.config.min_cluster_size;
        let small_clusters: Vec<usize> = clusters
            .iter()
            .enumerate()
            .filter(|(_, c)| c.live_count < min_size && c.live_count > 0)
            .map(|(i, _)| i)
            .collect();
        if small_clusters.len() < 2 {
            return 0;
        }
        let mut merged = 0;
        for chunk in small_clusters.chunks(2) {
            if chunk.len() == 2 {
                let (idx1, idx2) = (chunk[0], chunk[1]);
                // Move only the live vectors; idx2's tombstoned entries die
                // with the clear below.
                let vectors_to_move: Vec<Vector> = clusters[idx2].live_vectors().cloned().collect();
                for vector in vectors_to_move {
                    clusters[idx1].vectors.push(vector);
                    clusters[idx1].live_count += 1;
                }
                clusters[idx2].vectors.clear();
                clusters[idx2].tombstones.clear();
                clusters[idx2].live_count = 0;
                clusters[idx1].recompute_centroid();
                merged += 1;
            }
        }
        if merged > 0 {
            // Refresh mappings for every live vector (moved ids now point at
            // their new cluster).
            let mut vector_map = self.vector_cluster_map.write();
            for cluster in clusters.iter() {
                for vector in &cluster.vectors {
                    if !cluster.tombstones.contains(&vector.id) {
                        vector_map.insert(vector.id.clone(), cluster.id);
                    }
                }
            }
        }
        merged
    }

    /// Point-in-time snapshot of index counters.
    pub fn stats(&self) -> SpFreshStats {
        let clusters = self.clusters.read();
        let pending = self.pending_vectors.read();
        let total_vectors: usize = clusters.iter().map(|c| c.live_count).sum();
        let total_tombstones: usize = clusters.iter().map(|c| c.tombstones.len()).sum();
        SpFreshStats {
            num_clusters: clusters.len(),
            total_vectors,
            total_tombstones,
            pending_vectors: pending.len(),
            trained: *self.trained.read(),
            dimension: *self.dimension.read(),
        }
    }

    /// The configuration this index was built with.
    pub fn config(&self) -> &SpFreshConfig {
        &self.config
    }

    /// Clone of the cluster list (for persistence/snapshotting).
    pub(crate) fn clusters_read(&self) -> Vec<Cluster> {
        self.clusters.read().clone()
    }

    /// Clone of the id→cluster map (for persistence/snapshotting).
    pub(crate) fn vector_cluster_map_read(&self) -> HashMap<VectorId, usize> {
        self.vector_cluster_map.read().clone()
    }

    /// Clone of the pre-train tombstone set (for persistence/snapshotting).
    pub(crate) fn global_tombstones_read(&self) -> HashSet<VectorId> {
        self.global_tombstones.read().clone()
    }

    /// Clone of the pre-train buffer (for persistence/snapshotting).
    pub(crate) fn pending_vectors_read(&self) -> Vec<Vector> {
        self.pending_vectors.read().clone()
    }

    /// Reconstructs an index from a previously persisted snapshot.
    pub fn from_snapshot(
        snapshot: crate::persistence::SpFreshFullSnapshot,
    ) -> Result<Self, String> {
        Ok(Self {
            config: snapshot.config,
            clusters: RwLock::new(snapshot.clusters),
            vector_cluster_map: RwLock::new(snapshot.vector_cluster_map),
            global_tombstones: RwLock::new(snapshot.global_tombstones),
            pending_vectors: RwLock::new(snapshot.pending_vectors),
            trained: RwLock::new(snapshot.trained),
            dimension: RwLock::new(snapshot.dimension),
        })
    }
}
/// Point-in-time counters reported by `SpFreshIndex::stats`.
#[derive(Debug, Clone)]
pub struct SpFreshStats {
    /// Number of clusters currently held (including empty ones).
    pub num_clusters: usize,
    /// Sum of the clusters' live (non-tombstoned) vector counts.
    pub total_vectors: usize,
    /// Sum of the clusters' tombstone counts.
    pub total_tombstones: usize,
    /// Number of vectors buffered before training.
    pub pending_vectors: usize,
    /// Whether the index has been trained.
    pub trained: bool,
    /// Vector dimensionality, if any vector has been seen.
    pub dimension: Option<usize>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` deterministic vectors of dimension `dim` with ids
    /// "v0".."v{n-1}"; vector i's components are i, i+0.01, i+0.02, ...
    fn test_vectors(n: usize, dim: usize) -> Vec<Vector> {
        (0..n)
            .map(|i| Vector {
                id: format!("v{}", i),
                values: (0..dim).map(|j| (i as f32) + (j as f32 * 0.01)).collect(),
                metadata: None,
                ttl_seconds: None,
                expires_at: None,
            })
            .collect()
    }

    #[test]
    fn test_train_and_search() {
        let config = SpFreshConfig {
            num_clusters: 1,
            n_probe: 1,
            distance_metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(50, 4);
        index.train(&vectors).unwrap();
        assert!(index.is_trained());
        assert_eq!(index.dimension(), Some(4));
        let results = index.search(&vectors[25].values, 5).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].id, "v25");
        assert!(results[0].score < 0.001, "Exact match should have score ~0");
        // Scores are distances, so ordering must be ascending: a descending
        // check would contradict the exact-match-first assertion above.
        for i in 1..results.len() {
            assert!(
                results[i - 1].score <= results[i].score,
                "Results should be sorted by distance ascending"
            );
        }
    }

    #[test]
    fn test_multi_cluster_search() {
        let config = SpFreshConfig {
            num_clusters: 4,
            n_probe: 4,
            distance_metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(100, 8);
        index.train(&vectors).unwrap();
        let results = index.search(&vectors[50].values, 10).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 10);
        // Ascending distance order, nearest first.
        for i in 1..results.len() {
            assert!(results[i - 1].score <= results[i].score);
        }
        let stats = index.stats();
        assert_eq!(stats.num_clusters, 4);
        assert_eq!(stats.total_vectors, 100);
    }

    #[test]
    fn test_add_after_train() {
        let config = SpFreshConfig {
            num_clusters: 4,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(50, 8);
        index.train(&vectors).unwrap();
        let new_vectors = vec![Vector {
            id: "new1".to_string(),
            values: vec![0.5; 8],
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }];
        let added = index.add(new_vectors).unwrap();
        assert_eq!(added, 1);
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 51);
    }

    #[test]
    fn test_remove_tombstone() {
        let config = SpFreshConfig {
            num_clusters: 4,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(50, 8);
        index.train(&vectors).unwrap();
        let removed = index.remove(&["v0".to_string(), "v1".to_string()]);
        assert_eq!(removed, 2);
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 48);
        assert_eq!(stats.total_tombstones, 2);
    }

    #[test]
    fn test_compaction() {
        let config = SpFreshConfig {
            num_clusters: 2,
            compaction_threshold: 0.1,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(20, 4);
        index.train(&vectors).unwrap();
        let ids: Vec<String> = (0..10).map(|i| format!("v{}", i)).collect();
        index.remove(&ids);
        let compacted = index.compact();
        assert!(compacted > 0);
        let stats = index.stats();
        assert_eq!(stats.total_tombstones, 0);
    }

    #[test]
    fn test_pending_before_train() {
        let config = SpFreshConfig::default();
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(10, 4);
        index.add(vectors.clone()).unwrap();
        assert!(!index.is_trained());
        let stats = index.stats();
        assert_eq!(stats.pending_vectors, 10);
        let results = index.search(&vectors[0].values, 3).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn test_dimension_mismatch() {
        let config = SpFreshConfig {
            num_clusters: 2,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(10, 4);
        index.train(&vectors).unwrap();
        let bad_vectors = vec![Vector {
            id: "bad".to_string(),
            values: vec![1.0, 2.0],
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }];
        let result = index.add(bad_vectors);
        assert!(result.is_err());
    }

    #[test]
    fn test_cluster_split() {
        let config = SpFreshConfig {
            num_clusters: 1,
            max_cluster_size: 10,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(15, 4);
        index.train(&vectors).unwrap();
        let more_vectors = test_vectors(20, 4)
            .into_iter()
            .enumerate()
            .map(|(i, mut v)| {
                v.id = format!("new{}", i);
                v
            })
            .collect();
        index.add(more_vectors).unwrap();
        let stats = index.stats();
        assert!(stats.num_clusters > 1);
    }

    #[test]
    fn test_stats() {
        let config = SpFreshConfig {
            num_clusters: 4,
            ..Default::default()
        };
        let index = SpFreshIndex::new(config);
        let vectors = test_vectors(100, 8);
        index.train(&vectors).unwrap();
        let stats = index.stats();
        assert_eq!(stats.total_vectors, 100);
        assert_eq!(stats.num_clusters, 4);
        assert!(stats.trained);
        assert_eq!(stats.dimension, Some(8));
    }
}