use aletheiadb::core::error::Result;
use aletheiadb::core::id::NodeId;
use aletheiadb::core::temporal::TimeRange;
use aletheiadb::index::VectorIndex;
use aletheiadb::index::vector::temporal::*;
use aletheiadb::index::vector::{DistanceMetric, HnswConfig};
/// Builds an index whose snapshot interval (1000 transactions) is large
/// enough that the basic tests never trigger an automatic snapshot.
fn create_test_index() -> Result<TemporalVectorIndex> {
    TemporalVectorIndex::new(TemporalVectorConfig {
        snapshot_strategy: SnapshotStrategy::TransactionInterval(1000),
        retention_policy: RetentionPolicy::KeepN(100),
        max_snapshots: 100,
        full_snapshot_interval: 10,
        hnsw_config: Some(HnswConfig::new(4, DistanceMetric::Cosine)),
    })
}
/// Builds an index that snapshots every 2 transactions, so tests that need
/// several snapshots get them quickly.
fn create_test_index_with_snapshots() -> Result<TemporalVectorIndex> {
    TemporalVectorIndex::new(TemporalVectorConfig {
        snapshot_strategy: SnapshotStrategy::TransactionInterval(2),
        retention_policy: RetentionPolicy::KeepN(10),
        max_snapshots: 10,
        full_snapshot_interval: 10,
        hnsw_config: Some(HnswConfig::new(4, DistanceMetric::Cosine)),
    })
}
/// A single `add` should leave exactly one vector in the current index.
#[test]
fn integration_test_add_vector() -> Result<()> {
    let index = create_test_index()?;
    let node = NodeId::new(1).unwrap();
    let embedding = vec![1.0, 0.0, 0.0, 0.0];
    index.add(node, &embedding, 1000000.into())?;
    assert_eq!(index.current_index().len(), 1);
    Ok(())
}
/// Two adds at distinct timestamps should both land in the current index.
#[test]
fn integration_test_multiple_adds() -> Result<()> {
    let index = create_test_index()?;
    let first = NodeId::new(1).unwrap();
    let second = NodeId::new(2).unwrap();
    let base = 1000000.into();
    index.add(first, &[1.0, 0.0, 0.0, 0.0], base)?;
    // Second insert 100 ticks later than the first.
    index.add(second, &[0.0, 1.0, 0.0, 0.0], (base.wallclock() + 100).into())?;
    assert_eq!(index.current_index().len(), 2);
    Ok(())
}
/// An as-of query at t=2500 must only see vectors inserted at or before that
/// time — node3 (inserted at t=3000) must be invisible.
#[test]
fn test_find_similar_as_of() -> Result<()> {
    let index = create_test_index_with_snapshots()?;
    let node1 = NodeId::new(1).unwrap();
    let node2 = NodeId::new(2).unwrap();
    let node3 = NodeId::new(3).unwrap();
    // Insert three vectors, committing a transaction after each one so the
    // snapshot strategy (every 2 transactions) gets a chance to fire.
    for (node, vector, ts) in [
        (node1, vec![1.0, 0.0, 0.0, 0.0], 1000i64),
        (node2, vec![0.9, 0.1, 0.0, 0.0], 2000),
        (node3, vec![0.0, 0.0, 1.0, 0.0], 3000),
    ] {
        index.add(node, &vector, ts.into())?;
        index.on_transaction_at(ts.into())?;
    }
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let results = index.find_similar_as_of(&query, 5, 2500.into())?;
    assert!(results.len() >= 2, "Should find at least 2 similar vectors");
    assert!(
        !results.iter().any(|(id, _)| *id == node3),
        "Should not find node3"
    );
    Ok(())
}
/// A range query over [1500, 2500] should pick up the snapshot taken at
/// t=2000 and therefore return at least one result.
#[test]
fn test_find_similar_in_range() -> Result<()> {
    let index = create_test_index_with_snapshots()?;
    index.add(NodeId::new(1).unwrap(), &[1.0, 0.0, 0.0, 0.0], 1000.into())?;
    index.on_transaction_at(1000.into())?;
    index.add(NodeId::new(2).unwrap(), &[0.9, 0.1, 0.0, 0.0], 2000.into())?;
    index.on_transaction_at(2000.into())?;
    // One extra transaction with no insert, to advance the snapshot clock.
    index.on_transaction_at(3000.into())?;
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let window = TimeRange::new(1500.into(), 2500.into()).unwrap();
    let results = index.find_similar_in_range(&query, 5, window)?;
    assert!(!results.is_empty(), "Should have results in time range");
    Ok(())
}
/// `create_manual_snapshot` must raise the snapshot count by exactly one.
#[test]
fn test_create_manual_snapshot() -> Result<()> {
    let index = create_test_index()?;
    index.add(NodeId::new(1).unwrap(), &[1.0, 0.0, 0.0, 0.0], 1000.into())?;
    let before = index.snapshot_count();
    index.create_manual_snapshot()?;
    assert_eq!(
        index.snapshot_count(),
        before + 1,
        "Should have created one snapshot"
    );
    Ok(())
}
#[test]
fn test_prune_snapshots() -> Result<()> {
let config = TemporalVectorConfig {
snapshot_strategy: SnapshotStrategy::TransactionInterval(1),
retention_policy: RetentionPolicy::KeepN(2), max_snapshots: 10,
full_snapshot_interval: 10,
hnsw_config: Some(HnswConfig::new(4, DistanceMetric::Cosine)),
};
let index = TemporalVectorIndex::new(config)?;
for i in 1..=5 {
let node = NodeId::new(i).unwrap();
let vec = vec![i as f32, 0.0, 0.0, 0.0];
index.add(node, &vec, ((i * 1000) as i64).into())?;
index.on_transaction_at(((i * 1000) as i64).into())?; }
let pruned = index.prune_snapshots()?;
assert!(pruned > 0, "Should have pruned snapshots");
assert!(
index.snapshot_count() <= 2,
"Should keep at most 2 snapshots"
);
Ok(())
}
/// After at least one snapshot exists, `get_snapshot_info` must not be empty.
#[test]
fn test_get_snapshot_info() -> Result<()> {
    let index = create_test_index_with_snapshots()?;
    index.add(NodeId::new(1).unwrap(), &[1.0, 0.0, 0.0, 0.0], 1000.into())?;
    index.on_transaction_at(1000.into())?;
    index.create_manual_snapshot()?;
    let info = index.get_snapshot_info()?;
    assert!(!info.is_empty(), "Should have snapshot info");
    Ok(())
}
#[test]
// The index must report the dimensionality and distance metric it was
// configured with (create_test_index uses HnswConfig::new(4, Cosine)).
fn test_dimensions_and_metric() -> Result<()> {
let index = create_test_index()?;
assert_eq!(index.dimensions(), 4, "Should have 4 dimensions");
assert_eq!(
index.distance_metric(),
DistanceMetric::Cosine,
"Should use Cosine metric"
);
Ok(())
}
/// Each `TemporalVectorConfig` builder must select the matching snapshot
/// strategy variant.
#[test]
fn test_config_builders() -> Result<()> {
    let hnsw = HnswConfig::new(128, DistanceMetric::Euclidean);
    let by_default = TemporalVectorConfig::default_with_hnsw(hnsw.clone());
    assert!(matches!(
        by_default.snapshot_strategy,
        SnapshotStrategy::TransactionInterval(_)
    ));
    let by_time = TemporalVectorConfig::with_time_interval(hnsw.clone(), 3600);
    assert!(matches!(
        by_time.snapshot_strategy,
        SnapshotStrategy::TimeInterval(_)
    ));
    let by_change = TemporalVectorConfig::with_change_threshold(hnsw, 0.1);
    assert!(matches!(
        by_change.snapshot_strategy,
        SnapshotStrategy::ChangeThreshold(_)
    ));
    Ok(())
}
/// Range-query results must come back sorted by timestamp.
#[test]
fn test_find_similar_in_range_chronological_order() -> Result<()> {
    let index = create_test_index_with_snapshots()?;
    // 20 identical vectors, one per second, each followed by a transaction.
    for i in 0i64..20 {
        let node_id = NodeId::new(i as u64).unwrap();
        index.add(node_id, &[1.0, 0.0, 0.0, 0.0], (i * 1000).into())?;
        index.on_transaction_at((i * 1000).into())?;
    }
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let time_range = TimeRange::new(0.into(), 20000.into()).unwrap();
    let results = index.find_similar_in_range(&query, 5, time_range)?;
    // Every adjacent pair must be non-decreasing in timestamp.
    for pair in results.windows(2) {
        assert!(
            pair[0].0 <= pair[1].0,
            "Results should be sorted chronologically, but found {:?} after {:?}",
            pair[1].0,
            pair[0].0
        );
    }
    Ok(())
}
/// With 25 snapshot batches of 10 vectors each, a full-range query must
/// surface results from multiple snapshots, chronologically ordered, with
/// each per-snapshot result list non-empty and capped at k=5.
#[test]
fn test_find_similar_in_range_many_snapshots() -> Result<()> {
    let index = create_test_index_with_snapshots()?;
    for i in 0i64..25 {
        for j in 0i64..10 {
            let node_id = NodeId::new((i * 10 + j) as u64).unwrap();
            // Vectors vary with both loop indices so snapshots differ.
            let vector = vec![1.0 / (i + 1) as f32, (j as f32) / 10.0, 0.0, 0.0];
            index.add(node_id, &vector, (i * 1000 + j * 10).into())?;
        }
        index.on_transaction_at((i * 1000).into())?;
    }
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let time_range = TimeRange::new(0.into(), 25000.into()).unwrap();
    let results = index.find_similar_in_range(&query, 5, time_range)?;
    assert!(
        results.len() >= 10,
        "Should have results from multiple snapshots, got {}",
        results.len()
    );
    for pair in results.windows(2) {
        assert!(
            pair[0].0 <= pair[1].0,
            "Results must be chronologically ordered"
        );
    }
    for (timestamp, snapshot_results) in &results {
        assert!(
            !snapshot_results.is_empty(),
            "Snapshot at timestamp {} should have results",
            timestamp.wallclock()
        );
        assert!(snapshot_results.len() <= 5, "Should respect k=5 limit");
    }
    Ok(())
}
/// A range after the last snapshot yields nothing; a range covering a
/// snapshot yields something.
#[test]
fn test_find_similar_in_range_edge_cases() -> Result<()> {
    let index = create_test_index_with_snapshots()?;
    let vector = vec![1.0, 0.0, 0.0, 0.0];
    // Three inserts at t=1000/2000/3000, each committed as a transaction.
    for id in 1u64..=3 {
        let ts = ((id * 1000) as i64).into();
        index.add(NodeId::new(id).unwrap(), &vector, ts)?;
        index.on_transaction_at(ts)?;
    }
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let empty_range = TimeRange::new(10000.into(), 11000.into()).unwrap();
    let empty_results = index.find_similar_in_range(&query, 5, empty_range)?;
    assert!(
        empty_results.is_empty(),
        "Should have no results for empty range"
    );
    let range_with_snapshots = TimeRange::new(1500.into(), 2500.into()).unwrap();
    let results = index.find_similar_in_range(&query, 5, range_with_snapshots)?;
    assert!(
        !results.is_empty(),
        "Should have results when snapshots exist in range"
    );
    Ok(())
}
/// Repeated identical range queries must be deterministic.
///
/// Runs the same `find_similar_in_range` query three times and verifies that
/// lengths, timestamps, node ids, and scores all agree across runs. The
/// original spelled out the results1-vs-results2 and results1-vs-results3
/// comparisons twice (~50 duplicated lines); they are now one loop over the
/// labeled repeat runs, producing the same assertion messages.
#[test]
fn test_find_similar_in_range_deterministic() -> Result<()> {
    let index = create_test_index_with_snapshots()?;
    // 15 snapshots of 20 vectors spread around the unit circle, with a
    // per-snapshot third component so snapshots are distinguishable.
    for i in 0i64..15 {
        for j in 0i64..20 {
            let node_id = NodeId::new((i * 20 + j) as u64).unwrap();
            let angle = (j as f32) * std::f32::consts::PI / 10.0;
            let vector = vec![angle.cos(), angle.sin(), (i as f32) / 15.0, 0.0];
            index.add(node_id, &vector, (i * 1000 + j * 10).into())?;
        }
        index.on_transaction_at((i * 1000).into())?;
    }
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let time_range = TimeRange::new(0.into(), 15000.into()).unwrap();
    let results1 = index.find_similar_in_range(&query, 10, time_range)?;
    let results2 = index.find_similar_in_range(&query, 10, time_range)?;
    let results3 = index.find_similar_in_range(&query, 10, time_range)?;
    // Compare each repeat run against the first, element by element.
    for (other, label) in [(&results2, "results2"), (&results3, "results3")] {
        assert_eq!(
            results1.len(),
            other.len(),
            "Results should be deterministic (same length)"
        );
        for i in 0..results1.len() {
            assert_eq!(
                results1[i].0, other[i].0,
                "Timestamp mismatch for {} at index {}",
                label, i
            );
            let r1 = &results1[i].1;
            let r2 = &other[i].1;
            assert_eq!(
                r1.len(),
                r2.len(),
                "Result count mismatch for {} at index {}",
                label, i
            );
            for j in 0..r1.len() {
                assert_eq!(
                    r1[j].0, r2[j].0,
                    "NodeId mismatch for {} at index {}/{}",
                    label, i, j
                );
                // Scores are floats; allow a tiny tolerance rather than
                // demanding bitwise equality.
                assert!(
                    (r1[j].1 - r2[j].1).abs() < 1e-6,
                    "Score mismatch for {} at index {}/{}: {} vs {}",
                    label, i, j, r1[j].1, r2[j].1
                );
            }
        }
    }
    Ok(())
}
/// A three-entry batch should insert all three vectors.
#[test]
fn test_add_batch_basic() -> Result<()> {
    let index = create_test_index()?;
    let batch = vec![
        (NodeId::new(1).unwrap(), vec![1.0, 0.0, 0.0, 0.0], 1000.into()),
        (NodeId::new(2).unwrap(), vec![0.0, 1.0, 0.0, 0.0], 1100.into()),
        (NodeId::new(3).unwrap(), vec![0.0, 0.0, 1.0, 0.0], 1200.into()),
    ];
    index.add_batch(&batch)?;
    assert_eq!(
        index.current_index().len(),
        3,
        "All 3 vectors should be added"
    );
    Ok(())
}
/// An empty batch is a no-op, not an error.
#[test]
fn test_add_batch_empty() -> Result<()> {
    let index = create_test_index()?;
    let empty: Vec<(NodeId, Vec<f32>, _)> = Vec::new();
    index.add_batch(&empty)?;
    assert_eq!(index.current_index().len(), 0);
    Ok(())
}
/// After a batch insert, a k=2 search along the x axis must return the two
/// vectors that point (almost) along x, not the orthogonal one.
#[test]
fn test_add_batch_correctness() -> Result<()> {
    let index = create_test_index()?;
    let batch = vec![
        (NodeId::new(1).unwrap(), vec![1.0, 0.0, 0.0, 0.0], 1000.into()),
        (NodeId::new(2).unwrap(), vec![0.9, 0.1, 0.0, 0.0], 1100.into()),
        (NodeId::new(3).unwrap(), vec![0.0, 0.0, 1.0, 0.0], 1200.into()),
    ];
    index.add_batch(&batch)?;
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let results = index.current_index().search(&query, 2)?;
    assert_eq!(results.len(), 2, "Should find 2 most similar vectors");
    assert!(
        results.iter().any(|(id, _)| *id == NodeId::new(1).unwrap()),
        "Node 1 should be in results"
    );
    assert!(
        results.iter().any(|(id, _)| *id == NodeId::new(2).unwrap()),
        "Node 2 should be in results"
    );
    Ok(())
}
/// A batch containing a NaN component must be rejected as a whole.
#[test]
fn test_add_batch_nan_validation() -> Result<()> {
    let index = create_test_index()?;
    let batch = vec![
        (NodeId::new(1).unwrap(), vec![1.0, 0.0, 0.0, 0.0], 1000.into()),
        // Middle entry is poisoned with NaN.
        (NodeId::new(2).unwrap(), vec![f32::NAN, 1.0, 0.0, 0.0], 1100.into()),
        (NodeId::new(3).unwrap(), vec![0.0, 0.0, 1.0, 0.0], 1200.into()),
    ];
    let result = index.add_batch(&batch);
    assert!(result.is_err(), "Should reject batch with NaN values");
    Ok(())
}
/// A batch containing an infinite component must be rejected as a whole.
#[test]
fn test_add_batch_infinity_validation() -> Result<()> {
    let index = create_test_index()?;
    let batch = vec![
        (NodeId::new(1).unwrap(), vec![1.0, 0.0, 0.0, 0.0], 1000.into()),
        (
            NodeId::new(2).unwrap(),
            vec![f32::INFINITY, 1.0, 0.0, 0.0],
            1100.into(),
        ),
    ];
    let result = index.add_batch(&batch);
    assert!(result.is_err(), "Should reject batch with Infinity values");
    Ok(())
}
/// Batch insertion and one-by-one insertion of the same vectors must yield
/// the same index contents and the same search result set.
#[test]
fn test_add_batch_equivalence() -> Result<()> {
    let batched = create_test_index()?;
    let sequential = create_test_index()?;
    let vectors = vec![
        (NodeId::new(1).unwrap(), vec![1.0, 0.0, 0.0, 0.0], 1000.into()),
        (NodeId::new(2).unwrap(), vec![0.0, 1.0, 0.0, 0.0], 1100.into()),
        (NodeId::new(3).unwrap(), vec![0.0, 0.0, 1.0, 0.0], 1200.into()),
        (NodeId::new(4).unwrap(), vec![0.5, 0.5, 0.0, 0.0], 1300.into()),
    ];
    batched.add_batch(&vectors)?;
    for (id, vec, ts) in &vectors {
        sequential.add(*id, vec, *ts)?;
    }
    assert_eq!(
        batched.current_index().len(),
        sequential.current_index().len()
    );
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let from_batch = batched.current_index().search(&query, 3)?;
    let from_sequential = sequential.current_index().search(&query, 3)?;
    assert_eq!(
        from_batch.len(),
        from_sequential.len(),
        "Should return same number of results"
    );
    // Order may legitimately differ, so compare as sets of node ids.
    let batch_ids: std::collections::HashSet<_> =
        from_batch.iter().map(|(id, _)| id).collect();
    let sequential_ids: std::collections::HashSet<_> =
        from_sequential.iter().map(|(id, _)| id).collect();
    assert_eq!(batch_ids, sequential_ids, "Should return same nodes");
    Ok(())
}
/// A 1000-entry batch should insert every vector.
#[test]
fn test_add_batch_large() -> Result<()> {
    let index = create_test_index()?;
    let mut batch = Vec::with_capacity(1000);
    for i in 0..1000 {
        batch.push((
            NodeId::new(i).unwrap(),
            vec![
                (i as f32) / 1000.0,
                ((i + 1) as f32) / 1000.0,
                ((i + 2) as f32) / 1000.0,
                ((i + 3) as f32) / 1000.0,
            ],
            ((i * 100) as i64).into(),
        ));
    }
    index.add_batch(&batch)?;
    assert_eq!(
        index.current_index().len(),
        1000,
        "All 1000 vectors should be added"
    );
    Ok(())
}
/// A batch whose last entry has the wrong dimensionality must fail without
/// inserting any of the earlier, valid entries (all-or-nothing semantics).
#[test]
fn test_add_batch_atomicity_on_dimension_mismatch() -> Result<()> {
    let index = create_test_index()?;
    let batch = vec![
        (NodeId::new(1).unwrap(), vec![1.0, 0.0, 0.0, 0.0], 1000.into()),
        (NodeId::new(2).unwrap(), vec![0.0, 1.0, 0.0, 0.0], 1100.into()),
        // Only 3 components — the index expects 4.
        (NodeId::new(3).unwrap(), vec![0.0, 0.0, 1.0], 1200.into()),
    ];
    let result = index.add_batch(&batch);
    assert!(
        result.is_err(),
        "Batch should fail due to dimension mismatch"
    );
    assert_eq!(
        index.current_index().len(),
        0,
        "No vectors should be added when batch fails"
    );
    let stats = index.memory_stats();
    assert_eq!(
        stats.current_vectors, 0,
        "current_state should have no vectors after failed batch"
    );
    Ok(())
}
/// Two successive batches must keep the HNSW index and the memory-stats view
/// of current state in agreement, and all vectors must be searchable.
#[test]
fn test_add_batch_consistency() -> Result<()> {
    let index = create_test_index()?;
    let first_batch = vec![
        (NodeId::new(1).unwrap(), vec![1.0, 0.0, 0.0, 0.0], 1000.into()),
        (NodeId::new(2).unwrap(), vec![0.9, 0.1, 0.0, 0.0], 1100.into()),
    ];
    index.add_batch(&first_batch)?;
    assert_eq!(index.current_index().len(), 2, "HNSW should have 2 vectors");
    let stats = index.memory_stats();
    assert_eq!(
        stats.current_vectors, 2,
        "current_state should have 2 vectors"
    );
    let second_batch = vec![
        (NodeId::new(3).unwrap(), vec![0.0, 1.0, 0.0, 0.0], 2000.into()),
        (NodeId::new(4).unwrap(), vec![0.0, 0.9, 0.1, 0.0], 2100.into()),
    ];
    index.add_batch(&second_batch)?;
    assert_eq!(index.current_index().len(), 4, "HNSW should have 4 vectors");
    let stats = index.memory_stats();
    assert_eq!(
        stats.current_vectors, 4,
        "current_state should have 4 vectors"
    );
    let query = vec![1.0, 0.0, 0.0, 0.0];
    let results = index.current_index().search(&query, 4)?;
    assert_eq!(results.len(), 4, "Should find all 4 vectors via search");
    Ok(())
}
/// Four threads each batch-insert 25 disjoint vectors; all 100 must land in
/// both the HNSW index and the current-state stats.
#[test]
fn test_concurrent_add_batch() -> Result<()> {
    use std::sync::Arc;
    use std::thread;
    let index = Arc::new(create_test_index()?);
    // Spawn first, join afterwards, so the four batches actually overlap.
    let handles: Vec<_> = (0..4)
        .map(|thread_id| {
            let shared = Arc::clone(&index);
            thread::spawn(move || {
                let batch: Vec<_> = (0..25)
                    .map(|i| {
                        (
                            // Node ids are disjoint across threads.
                            NodeId::new((thread_id * 25 + i) as u64).unwrap(),
                            vec![(thread_id as f32) / 10.0, (i as f32) / 100.0, 0.0, 0.0],
                            ((thread_id * 25 + i) as i64 * 1000).into(),
                        )
                    })
                    .collect();
                shared.add_batch(&batch)
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap()?;
    }
    assert_eq!(
        index.current_index().len(),
        100,
        "All 100 vectors should be added"
    );
    let stats = index.memory_stats();
    assert_eq!(
        stats.current_vectors, 100,
        "current_state should have 100 vectors"
    );
    Ok(())
}
/// Removing a node that was never inserted must leave existing entries alone.
#[test]
fn test_remove_on_nonexistent_node() -> Result<()> {
    let index = create_test_index()?;
    let existing = NodeId::new(1).unwrap();
    index.add(existing, &[1.0, 0.0, 0.0, 0.0], 1000.into())?;
    assert_eq!(index.current_index().len(), 1);
    // NOTE(review): the remove result is deliberately ignored — whether
    // removing a missing node returns Ok or Err is not pinned down here;
    // confirm against the `remove` API if that distinction matters.
    let _result = index.remove(NodeId::new(2).unwrap(), 2000.into());
    assert_eq!(
        index.current_index().len(),
        1,
        "Node1 should still be in HNSW"
    );
    let stats = index.memory_stats();
    assert_eq!(
        stats.current_vectors, 1,
        "Node1 should still be in current_state"
    );
    Ok(())
}