use crate::distance::DistanceMetric;
use crate::error::{Result, RetrieveError};
use std::collections::HashSet;
/// Tuning parameters for the LSM-style vector index.
#[derive(Debug, Clone)]
pub struct LsmConfig {
    /// Dimensionality every inserted or queried vector must have.
    pub dimension: usize,
    /// Number of vectors L0 may hold before a compaction is triggered.
    pub buffer_capacity: usize,
    /// Growth factor between adjacent levels; a level spills downward once it
    /// is `size_ratio` times the size of the level below it.
    pub size_ratio: usize,
    /// Upper bound on how deep the compaction cascade may go.
    pub max_levels: usize,
    /// Graph degree (`M`) for per-level HNSW indexes (only used with the
    /// `hnsw` feature).
    pub hnsw_m: usize,
    /// `ef` used while constructing per-level HNSW indexes (`hnsw` feature).
    pub hnsw_ef_construction: usize,
    /// `ef` used at query time; `search` lower-bounds it by `k`.
    pub ef_search: usize,
    /// Distance function used for both brute-force and HNSW search.
    pub distance_metric: DistanceMetric,
}
impl Default for LsmConfig {
    /// Reasonable starting point: 128-d cosine vectors, a 10k-vector L0
    /// buffer, 10x level growth, and common HNSW build/search settings.
    fn default() -> Self {
        Self {
            distance_metric: DistanceMetric::Cosine,
            dimension: 128,
            hnsw_m: 16,
            hnsw_ef_construction: 200,
            ef_search: 100,
            buffer_capacity: 10_000,
            size_ratio: 10,
            max_levels: 5,
        }
    }
}
/// One storage tier of the index: a flat row-major vector arena plus a
/// parallel list of document ids.
#[derive(Debug)]
struct Level {
    // Row-major storage: row `i` occupies `vectors[i*dim..(i+1)*dim]`.
    vectors: Vec<f32>,
    // `doc_ids[i]` identifies the vector stored in row `i`.
    doc_ids: Vec<u32>,
    // Number of rows stored; tombstoned rows still count until merged away.
    count: usize,
    // Optional per-level ANN graph; `None` until built during a merge.
    #[cfg(feature = "hnsw")]
    hnsw: Option<crate::hnsw::HNSWIndex>,
}
impl Level {
    /// An empty level: no rows, no ANN graph.
    fn new() -> Self {
        Self {
            vectors: Vec::new(),
            doc_ids: Vec::new(),
            count: 0,
            #[cfg(feature = "hnsw")]
            hnsw: None,
        }
    }

    /// True when the level stores no rows at all.
    fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Linear scan: score every non-tombstoned row against `query` using
    /// `metric`, returning up to `k` `(doc_id, distance)` pairs sorted
    /// ascending by distance.
    fn brute_force_search(
        &self,
        query: &[f32],
        k: usize,
        dimension: usize,
        tombstones: &HashSet<u32>,
        metric: DistanceMetric,
    ) -> Vec<(u32, f32)> {
        let mut hits: Vec<(u32, f32)> = Vec::new();
        for (row, &doc_id) in self.doc_ids.iter().enumerate().take(self.count) {
            // Logically deleted rows never surface in results.
            if tombstones.contains(&doc_id) {
                continue;
            }
            let begin = row * dimension;
            let stored = &self.vectors[begin..begin + dimension];
            hits.push((doc_id, metric.distance(query, stored)));
        }
        // total_cmp gives a total order over f32 (NaN-safe).
        hits.sort_unstable_by(|lhs, rhs| lhs.1.total_cmp(&rhs.1));
        hits.truncate(k);
        hits
    }
}
/// Log-structured-merge vector index: new vectors land in a brute-force L0
/// write buffer and are periodically compacted into larger levels, which are
/// optionally HNSW-indexed (behind the `hnsw` feature).
pub struct LsmIndex {
    config: LsmConfig,
    // levels[0] is the write buffer; higher indices hold older, larger runs.
    levels: Vec<Level>,
    // Ids deleted logically; physically dropped during merges.
    tombstones: HashSet<u32>,
    // Lifetime counters surfaced through `stats()`.
    total_inserts: u64,
    total_deletes: u64,
    total_compactions: u64,
}
impl LsmIndex {
/// Create an empty index for `config`, starting with just the L0 write
/// buffer; deeper levels are created lazily by compaction.
pub fn new(config: LsmConfig) -> Self {
    let levels = {
        let mut initial = Vec::with_capacity(config.max_levels);
        initial.push(Level::new());
        initial
    };
    Self {
        config,
        levels,
        tombstones: HashSet::new(),
        total_inserts: 0,
        total_deletes: 0,
        total_compactions: 0,
    }
}
/// Insert an owned vector; thin convenience wrapper over `insert_slice`.
pub fn insert(&mut self, doc_id: u32, vector: Vec<f32>) -> Result<()> {
    self.insert_slice(doc_id, vector.as_slice())
}
/// Validate, L2-normalize, and append `vector` to the L0 buffer under
/// `doc_id`, triggering a compaction once the buffer reaches capacity.
///
/// Re-inserting an id clears any tombstone for it.
/// NOTE(review): a re-insert does not remove older copies of the same id
/// from existing levels; search-time dedup decides which copy wins —
/// confirm this is the intended update semantics.
///
/// # Errors
/// Returns `DimensionMismatch` when `vector.len()` differs from the
/// configured dimension.
pub fn insert_slice(&mut self, doc_id: u32, vector: &[f32]) -> Result<()> {
    if vector.len() != self.config.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: vector.len(),
            doc_dim: self.config.dimension,
        });
    }
    // The id is live again, even if previously deleted.
    self.tombstones.remove(&doc_id);
    // Vectors are stored unit-length; near-zero vectors are stored as-is to
    // avoid dividing by ~0.
    let norm: f32 = vector.iter().map(|x| x * x).sum::<f32>().sqrt();
    let buffer = &mut self.levels[0];
    if norm > 1e-10 {
        buffer.vectors.extend(vector.iter().map(|x| x / norm));
    } else {
        buffer.vectors.extend_from_slice(vector);
    }
    buffer.doc_ids.push(doc_id);
    buffer.count += 1;
    self.total_inserts += 1;
    if self.levels[0].count >= self.config.buffer_capacity {
        self.compact()?;
    }
    Ok(())
}
/// Logically delete `doc_id` by recording a tombstone; the row is physically
/// removed during a later merge. The delete counter increments
/// unconditionally, even for ids never inserted or already deleted.
pub fn delete(&mut self, doc_id: u32) {
    self.total_deletes += 1;
    self.tombstones.insert(doc_id);
}
/// k-nearest-neighbor search across all levels.
///
/// The query is L2-normalized (mirroring `insert_slice`) before scoring.
/// Returns up to `k` `(doc_id, distance)` pairs sorted ascending by
/// distance, deduplicated by id.
///
/// # Errors
/// Returns `DimensionMismatch` when `query.len()` differs from the
/// configured dimension.
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>> {
    if query.len() != self.config.dimension {
        return Err(RetrieveError::DimensionMismatch {
            query_dim: query.len(),
            doc_dim: self.config.dimension,
        });
    }
    // Normalize the query the same way stored vectors were normalized on
    // insert; near-zero queries pass through untouched.
    let query_norm: f32 = query.iter().map(|x| x * x).sum::<f32>().sqrt();
    let query_normalized: Vec<f32> = if query_norm > 1e-10 {
        query.iter().map(|x| x / query_norm).collect()
    } else {
        query.to_vec()
    };
    let query = query_normalized.as_slice();
    let mut all_results: Vec<(u32, f32)> = Vec::new();
    for (level_idx, level) in self.levels.iter().enumerate() {
        if level.is_empty() {
            continue;
        }
        // L0 is always scanned brute-force; deeper levels use their HNSW
        // graph when the feature is enabled and the graph was built.
        let level_results = if level_idx == 0 {
            level.brute_force_search(
                query,
                k,
                self.config.dimension,
                &self.tombstones,
                self.config.distance_metric,
            )
        } else {
            #[cfg(feature = "hnsw")]
            {
                if let Some(ref hnsw) = level.hnsw {
                    let ef = self.config.ef_search.max(k);
                    match hnsw.search(query, k, ef) {
                        Ok(results) => results
                            .into_iter()
                            // The HNSW graph is unaware of deletions;
                            // filter tombstones here.
                            .filter(|(id, _)| !self.tombstones.contains(id))
                            .collect(),
                        // NOTE(review): an HNSW search error is silently
                        // treated as "no results" for this level — confirm
                        // this best-effort behavior is intended.
                        Err(_) => Vec::new(),
                    }
                } else {
                    level.brute_force_search(
                        query,
                        k,
                        self.config.dimension,
                        &self.tombstones,
                        self.config.distance_metric,
                    )
                }
            }
            #[cfg(not(feature = "hnsw"))]
            {
                level.brute_force_search(
                    query,
                    k,
                    self.config.dimension,
                    &self.tombstones,
                    self.config.distance_metric,
                )
            }
        };
        all_results.extend(level_results);
    }
    // Deduplicate by doc id, keeping the CLOSEST copy of each id.
    // NOTE(review): when the same id exists in several levels this keeps
    // the nearest copy, not necessarily the newest one — confirm intended.
    let mut seen = HashSet::new();
    all_results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
    all_results.retain(|(id, _)| seen.insert(*id));
    all_results.truncate(k);
    Ok(all_results)
}
/// Flush the L0 write buffer into level 1, then cascade any further merges
/// required by the size-ratio policy. No-op when L0 is empty.
pub fn compact(&mut self) -> Result<()> {
    if self.levels[0].is_empty() {
        return Ok(());
    }
    // Take L0 out, leaving a fresh empty write buffer in its place.
    let l0 = std::mem::replace(&mut self.levels[0], Level::new());
    // `merge_into_level` grows `self.levels` to reach the target index, so
    // no manual level creation is needed here (the original pre-pushed a
    // redundant empty level).
    self.merge_into_level(l0, 1)?;
    self.cascade_compact(1)?;
    self.total_compactions += 1;
    Ok(())
}
fn merge_into_level(&mut self, source: Level, target_idx: usize) -> Result<()> {
let dim = self.config.dimension;
let mut merged_vectors: Vec<f32> = Vec::new();
let mut merged_ids: Vec<u32> = Vec::new();
if target_idx < self.levels.len() {
let target = &self.levels[target_idx];
for i in 0..target.count {
let doc_id = target.doc_ids[i];
if !self.tombstones.contains(&doc_id) {
let start = i * dim;
merged_vectors.extend_from_slice(&target.vectors[start..start + dim]);
merged_ids.push(doc_id);
}
}
}
for i in 0..source.count {
let doc_id = source.doc_ids[i];
if !self.tombstones.contains(&doc_id) {
let start = i * dim;
merged_vectors.extend_from_slice(&source.vectors[start..start + dim]);
merged_ids.push(doc_id);
}
}
let merged_count = merged_ids.len();
#[cfg(feature = "hnsw")]
let hnsw = if merged_count > 0 {
let mut hnsw = crate::hnsw::HNSWIndex::builder(dim)
.m(self.config.hnsw_m)
.ef_construction(self.config.hnsw_ef_construction)
.auto_normalize(false) .build()?;
for (i, &doc_id) in merged_ids.iter().enumerate() {
let start = i * dim;
hnsw.add_slice(doc_id, &merged_vectors[start..start + dim])?;
}
hnsw.build()?;
Some(hnsw)
} else {
None
};
while self.levels.len() <= target_idx {
self.levels.push(Level::new());
}
self.levels[target_idx] = Level {
vectors: merged_vectors,
doc_ids: merged_ids,
count: merged_count,
#[cfg(feature = "hnsw")]
hnsw,
};
Ok(())
}
/// Recursively spill `level_idx` into `level_idx + 1` while the size-ratio
/// policy says the level is too large, stopping at `max_levels`.
///
/// The guard is written as `level_idx + 1 >= max_levels` rather than the
/// original `level_idx >= max_levels - 1`, which underflows (panicking in
/// debug builds) when `max_levels` is configured as 0.
fn cascade_compact(&mut self, level_idx: usize) -> Result<()> {
    if level_idx + 1 >= self.config.max_levels {
        return Ok(());
    }
    let level_size = self.levels.get(level_idx).map_or(0, |l| l.count);
    let next_size = self.levels.get(level_idx + 1).map_or(0, |l| l.count);
    // A level spills once it is `size_ratio`x the level below it; against an
    // empty next level, the threshold is the L0 capacity scaled by the ratio.
    let should_compact = if next_size == 0 {
        level_size >= self.config.buffer_capacity * self.config.size_ratio
    } else {
        level_size >= self.config.size_ratio * next_size
    };
    if should_compact {
        let source = std::mem::replace(&mut self.levels[level_idx], Level::new());
        while self.levels.len() <= level_idx + 1 {
            self.levels.push(Level::new());
        }
        self.merge_into_level(source, level_idx + 1)?;
        self.cascade_compact(level_idx + 1)?;
    }
    Ok(())
}
/// Approximate count of live vectors: total stored rows minus the number of
/// pending tombstones.
/// NOTE(review): tombstones for ids that were never inserted, or ids stored
/// more than once, can skew this figure until the next merge — confirm this
/// approximation is acceptable to callers.
pub fn len(&self) -> usize {
    let stored = self
        .levels
        .iter()
        .fold(0usize, |acc, level| acc + level.count);
    stored.saturating_sub(self.tombstones.len())
}
/// True when the index holds no live (non-deleted) vectors.
pub fn is_empty(&self) -> bool {
    self.len() == 0
}
/// Number of levels currently allocated (L0 included, even when empty).
pub fn num_levels(&self) -> usize {
    self.levels.len()
}
/// Raw per-level row counts (tombstoned rows included), L0 first.
pub fn level_sizes(&self) -> Vec<usize> {
    let mut sizes = Vec::with_capacity(self.levels.len());
    for level in &self.levels {
        sizes.push(level.count);
    }
    sizes
}
/// Snapshot of lifetime counters and the current level layout.
pub fn stats(&self) -> LsmStats {
    let level_sizes = self.level_sizes();
    LsmStats {
        num_levels: self.levels.len(),
        tombstone_count: self.tombstones.len(),
        total_inserts: self.total_inserts,
        total_deletes: self.total_deletes,
        total_compactions: self.total_compactions,
        level_sizes,
    }
}
/// Collapse every level into a single level 1 (plus an empty L0 buffer),
/// physically dropping all tombstoned rows.
///
/// Fix: tombstones are now cleared even when every row was deleted (the
/// original returned early on the empty-merge path without clearing them,
/// leaving stale tombstones that skew `len()` and future merges).
///
/// NOTE(review): if the same doc_id is stored in multiple levels, every
/// surviving copy is carried over here — confirm duplicates are acceptable
/// or deduplicate upstream.
pub fn force_merge_all(&mut self) -> Result<()> {
    if self.levels.is_empty() {
        return Ok(());
    }
    let dim = self.config.dimension;
    let mut all_vectors: Vec<f32> = Vec::new();
    let mut all_ids: Vec<u32> = Vec::new();
    // Gather every live (non-tombstoned) row across all levels.
    for level in &self.levels {
        for i in 0..level.count {
            let doc_id = level.doc_ids[i];
            if !self.tombstones.contains(&doc_id) {
                let start = i * dim;
                all_vectors.extend_from_slice(&level.vectors[start..start + dim]);
                all_ids.push(doc_id);
            }
        }
    }
    self.levels.clear();
    self.levels.push(Level::new());
    // Every deletion is physically applied above, so tombstones can be
    // dropped regardless of whether any rows survived.
    self.tombstones.clear();
    let count = all_ids.len();
    if count == 0 {
        return Ok(());
    }
    #[cfg(feature = "hnsw")]
    let hnsw = {
        let mut hnsw = crate::hnsw::HNSWIndex::builder(dim)
            .m(self.config.hnsw_m)
            .ef_construction(self.config.hnsw_ef_construction)
            // Vectors were already normalized on insert.
            .auto_normalize(false)
            .build()?;
        for (i, &doc_id) in all_ids.iter().enumerate() {
            let start = i * dim;
            hnsw.add_slice(doc_id, &all_vectors[start..start + dim])?;
        }
        hnsw.build()?;
        Some(hnsw)
    };
    self.levels.push(Level {
        vectors: all_vectors,
        doc_ids: all_ids,
        count,
        #[cfg(feature = "hnsw")]
        hnsw,
    });
    Ok(())
}
}
/// Point-in-time statistics reported by `LsmIndex::stats`.
#[derive(Debug, Clone)]
pub struct LsmStats {
    /// Lifetime number of successful inserts.
    pub total_inserts: u64,
    /// Lifetime number of `delete` calls (counted even for unknown ids).
    pub total_deletes: u64,
    /// Lifetime number of L0 flushes performed by `compact`.
    pub total_compactions: u64,
    /// Levels currently allocated.
    pub num_levels: usize,
    /// Raw per-level row counts, L0 first.
    pub level_sizes: Vec<usize>,
    /// Deletions not yet physically applied by a merge.
    pub tombstone_count: usize,
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // Tiny buffer/ratio so compactions trigger with few inserts.
    fn make_config(dim: usize) -> LsmConfig {
        LsmConfig {
            dimension: dim,
            buffer_capacity: 20,
            size_ratio: 5,
            max_levels: 4,
            hnsw_m: 8,
            hnsw_ef_construction: 50,
            ef_search: 50,
            distance_metric: DistanceMetric::L2,
        }
    }

    // Deterministic pseudo-random vector derived from `seed`.
    fn make_vector(dim: usize, seed: u32) -> Vec<f32> {
        (0..dim)
            .map(|i| (seed as f32 * 0.1 + i as f32 * 0.01).sin())
            .collect()
    }

    #[test]
    fn insert_and_search_l0() {
        let mut index = LsmIndex::new(make_config(8));
        for i in 0..10u32 {
            index.insert(i, make_vector(8, i)).unwrap();
        }
        // 10 < buffer_capacity (20), so everything stays in L0.
        assert_eq!(index.level_sizes(), vec![10]);
        let results = index.search(&make_vector(8, 0), 3).unwrap();
        assert!(!results.is_empty());
        // Exact-match query: doc 0 must rank first.
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn compaction_moves_to_l1() {
        let mut index = LsmIndex::new(make_config(8));
        // 25 inserts exceed the 20-vector L0 buffer, forcing a flush.
        for i in 0..25u32 {
            index.insert(i, make_vector(8, i)).unwrap();
        }
        let sizes = index.level_sizes();
        assert!(sizes.len() >= 2, "expected at least 2 levels: {sizes:?}");
        assert!(
            sizes[1] > 0,
            "L1 should have vectors after compaction: {sizes:?}"
        );
        let results = index.search(&make_vector(8, 5), 3).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn delete_filters_from_search() {
        let mut index = LsmIndex::new(make_config(8));
        for i in 0..10u32 {
            index.insert(i, make_vector(8, i)).unwrap();
        }
        index.delete(0);
        index.delete(1);
        let results = index.search(&make_vector(8, 0), 10).unwrap();
        for (id, _) in &results {
            assert!(*id != 0 && *id != 1, "deleted ID {id} in results");
        }
    }

    #[test]
    fn delete_survives_compaction() {
        let mut index = LsmIndex::new(make_config(8));
        for i in 0..25u32 {
            index.insert(i, make_vector(8, i)).unwrap();
        }
        index.delete(5);
        // Further inserts trigger more compactions after the delete.
        for i in 25..50u32 {
            index.insert(i, make_vector(8, i)).unwrap();
        }
        let results = index.search(&make_vector(8, 5), 50).unwrap();
        for (id, _) in &results {
            assert_ne!(*id, 5, "deleted ID 5 in results after compaction");
        }
    }

    #[test]
    fn force_merge_all() {
        let mut index = LsmIndex::new(make_config(8));
        for i in 0..50u32 {
            index.insert(i, make_vector(8, i)).unwrap();
        }
        index.delete(10);
        index.force_merge_all().unwrap();
        let sizes = index.level_sizes();
        assert_eq!(sizes[0], 0, "L0 should be empty after merge");
        assert_eq!(sizes[1], 49, "L1 should have 49 vectors (50 - 1 deleted)");
        assert_eq!(index.stats().tombstone_count, 0);
        let results = index.search(&make_vector(8, 0), 3).unwrap();
        assert!(!results.is_empty());
    }

    #[test]
    fn reinsert_after_delete() {
        let mut index = LsmIndex::new(make_config(8));
        index.insert(0, make_vector(8, 0)).unwrap();
        index.delete(0);
        // Re-inserting the same id must clear its tombstone.
        index.insert(0, make_vector(8, 100)).unwrap();
        let results = index.search(&make_vector(8, 100), 1).unwrap();
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn empty_search() {
        let index = LsmIndex::new(make_config(8));
        let results = index.search(&make_vector(8, 0), 5).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn stats_tracking() {
        let mut index = LsmIndex::new(make_config(8));
        // 30 inserts > buffer (20): at least one compaction must occur.
        for i in 0..30u32 {
            index.insert(i, make_vector(8, i)).unwrap();
        }
        index.delete(0);
        let stats = index.stats();
        assert_eq!(stats.total_inserts, 30);
        assert_eq!(stats.total_deletes, 1);
        assert!(stats.total_compactions >= 1);
        assert_eq!(stats.tombstone_count, 1);
    }
}