use aletheiadb::core::id::NodeId;
use aletheiadb::index::vector::{DistanceMetric, HnswIndexBuilder, VectorIndex};
/// Re-adding an already-present node id must update that node in place rather
/// than inserting a second entry: the index length stays at one after every
/// write, and search finds the most recently written vector.
#[test]
fn test_update_exercises_contains_check() {
    let index = HnswIndexBuilder::new(4, DistanceMetric::Cosine)
        .build()
        .unwrap();
    let node = NodeId::new(1).unwrap();

    // Successive versions of the same node; every add after the first is an update.
    let versions: [[f32; 4]; 3] = [
        [1.0, 0.0, 0.0, 0.0],
        [0.9, 0.1, 0.0, 0.0],
        [0.8, 0.2, 0.0, 0.0],
    ];
    for version in &versions {
        index.add(node, version).unwrap();
        assert_eq!(index.len(), 1);
    }

    // Querying with the latest version should give a near-perfect cosine match.
    let latest = versions.last().unwrap();
    let hits = index.search(latest, 1).unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0, node);
    assert!(hits[0].1 > 0.99);
}
/// Repeatedly overwriting a single node with rotating one-hot vectors must keep
/// the index at exactly one entry. The final write must be the vector that
/// search observes.
#[test]
fn test_multiple_updates_for_coverage() {
    let index = HnswIndexBuilder::new(4, DistanceMetric::Cosine)
        .build()
        .unwrap();
    let node = NodeId::new(1).unwrap();
    for i in 0..10 {
        // One-hot vector whose hot axis cycles 0, 1, 2, 3, 0, ...
        let mut vector = vec![0.0f32; 4];
        vector[i % 4] = 1.0;
        index.add(node, &vector).unwrap();
        assert_eq!(index.len(), 1);
    }

    // The last iteration (i = 9) left the hot axis at 9 % 4 == 1.
    let query = vec![0.0, 1.0, 0.0, 0.0];
    let results = index.search(&query, 1).unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, node);
    // With only one node in the index the id check above always passes. The
    // similarity check proves the stored vector is the last one written and
    // not an earlier, orthogonal one.
    assert!(results[0].1 > 0.99);
}
/// Walks a node through insert -> update -> remove -> re-insert -> update and
/// checks the index length at every step. The vector from the final update
/// must be the one search matches.
#[test]
fn test_add_remove_readd_update_sequence() {
    let index = HnswIndexBuilder::new(4, DistanceMetric::Cosine)
        .build()
        .unwrap();
    let node = NodeId::new(1).unwrap();

    // Phase 1: first insert, then an in-place update.
    for v in [[1.0f32, 0.0, 0.0, 0.0], [0.9, 0.1, 0.0, 0.0]] {
        index.add(node, &v).unwrap();
        assert_eq!(index.len(), 1);
    }

    // Phase 2: removal empties the index.
    index.remove(node).unwrap();
    assert_eq!(index.len(), 0);

    // Phase 3: re-insert under the same id, then update once more.
    for v in [[0.8f32, 0.2, 0.0, 0.0], [0.7, 0.3, 0.0, 0.0]] {
        index.add(node, &v).unwrap();
        assert_eq!(index.len(), 1);
    }

    let final_vector = [0.7f32, 0.3, 0.0, 0.0];
    let hits = index.search(&final_vector, 1).unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0, node);
    assert!(hits[0].1 > 0.99);
}
/// Seeds five nodes, then lets one thread per node overwrite its own node three
/// times concurrently. Concurrent updates must neither add nor drop entries.
#[test]
fn test_concurrent_updates_coverage() {
    use std::sync::Arc;
    use std::thread;

    let index = Arc::new(
        HnswIndexBuilder::new(4, DistanceMetric::Cosine)
            .build()
            .unwrap(),
    );

    // Seed nodes 1..=5, each pointing along the first axis.
    (1..=5).for_each(|id| {
        let node = NodeId::new(id).unwrap();
        let seed = vec![id as f32, 0.0, 0.0, 0.0];
        index.add(node, &seed).unwrap();
    });
    assert_eq!(index.len(), 5);

    // One writer thread per node; each thread only ever touches its own node id.
    let mut workers = Vec::new();
    for id in 1..=5 {
        let shared = Arc::clone(&index);
        workers.push(thread::spawn(move || {
            let node = NodeId::new(id).unwrap();
            for step in 0..3 {
                let mut update = vec![0.0f32; 4];
                update[(id + step) as usize % 4] = 1.0;
                shared.add(node, &update).unwrap();
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    assert_eq!(index.len(), 5);
}
/// An update whose vector has the wrong dimensionality must be rejected. The
/// rejection must leave the existing entry in place and searchable.
#[test]
fn test_update_dimension_mismatch() {
    let index = HnswIndexBuilder::new(4, DistanceMetric::Cosine)
        .build()
        .unwrap();
    let node = NodeId::new(1).unwrap();
    let original = [1.0f32, 0.0, 0.0, 0.0];

    index.add(node, &original).unwrap();
    assert_eq!(index.len(), 1);

    // Three components for a four-dimensional index: must fail.
    let too_short = [1.0f32, 0.0, 0.0];
    assert!(index.add(node, &too_short).is_err());

    // The failed update must not disturb the stored node.
    assert_eq!(index.len(), 1);
    let hits = index.search(&original, 1).unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0, node);
}
/// Updating a node with a vector containing NaN must be rejected. The rejection
/// must leave the previously stored vector intact and searchable.
#[test]
fn test_update_with_nan() {
    let index = HnswIndexBuilder::new(4, DistanceMetric::Cosine)
        .build()
        .unwrap();
    let node = NodeId::new(1).unwrap();
    index.add(node, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    assert_eq!(index.len(), 1);

    let result = index.add(node, &[f32::NAN, 0.0, 0.0, 0.0]);
    assert!(result.is_err());
    assert_eq!(index.len(), 1);

    // A rejected update must not corrupt the entry: the original vector should
    // still be a near-perfect cosine match for itself.
    let results = index.search(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, node);
    assert!(results[0].1 > 0.99);
}
/// Updating a node with a vector containing +infinity must be rejected. The
/// rejection must leave the previously stored vector intact and searchable.
#[test]
fn test_update_with_infinity() {
    let index = HnswIndexBuilder::new(4, DistanceMetric::Cosine)
        .build()
        .unwrap();
    let node = NodeId::new(1).unwrap();
    index.add(node, &[1.0, 0.0, 0.0, 0.0]).unwrap();
    assert_eq!(index.len(), 1);

    let result = index.add(node, &[f32::INFINITY, 0.0, 0.0, 0.0]);
    assert!(result.is_err());
    assert_eq!(index.len(), 1);

    // A rejected update must not corrupt the entry: the original vector should
    // still be a near-perfect cosine match for itself.
    let results = index.search(&[1.0, 0.0, 0.0, 0.0], 1).unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].0, node);
    assert!(results[0].1 > 0.99);
}