use edgestore::{
distance, distance_scalar, Dtype, EdgestoreConfig, Engine, Metric, VectorEngine, VectorRecord,
};
use tempfile::TempDir;
/// Opens an edgestore [`Engine`] whose storage lives inside the temporary
/// directory `dir`; panics if the engine cannot be opened.
fn open_engine(dir: &TempDir) -> Engine {
    let config = EdgestoreConfig::new(dir.path());
    Engine::open(config).unwrap()
}
/// Produces `n` deterministic pseudo-random `f32` values in `[0, 1]` from `seed`.
///
/// Each step advances a 64-bit linear congruential generator
/// (`state * 1103515245 + 12345`, wrapping on overflow) and divides the new
/// state by `u64::MAX`. The same `seed` always yields the same sequence, which
/// lets the tests regenerate stored vectors when computing reference distances.
fn lcg_sequence(seed: u64, n: usize) -> Vec<f32> {
    let scale = u64::MAX as f32;
    let mut state = seed;
    (0..n)
        .map(|_| {
            state = state.wrapping_mul(1103515245).wrapping_add(12345);
            state as f32 / scale
        })
        .collect()
}
/// Serializes `vals` into one contiguous little-endian byte buffer
/// (4 bytes per `f32`, in input order).
///
/// The tests feed this buffer to `vector_put`/`distance` together with
/// `Dtype::F32`.
fn f32s_to_bytes(vals: &[f32]) -> Vec<u8> {
    // Size the buffer up front and copy each 4-byte array in directly; the
    // previous `flat_map(.. .to_vec())` heap-allocated a Vec per element.
    let mut bytes = Vec::with_capacity(vals.len() * std::mem::size_of::<f32>());
    for v in vals {
        bytes.extend_from_slice(&v.to_le_bytes());
    }
    bytes
}
/// SC1: vectors of every supported dtype round-trip through
/// `vector_put`/`vector_get` unchanged, and a payload whose length does not
/// match `dims` × element size is rejected with `DimensionMismatch`.
#[test]
fn test_sc1_roundtrip_and_validation() {
    let dir = TempDir::new().unwrap();
    let mut engine = open_engine(&dir);

    // Stores `$payload` under `$key` in namespace `ns`, reads it back, and checks
    // that dims, dtype and raw bytes all survive the round trip. A macro (rather
    // than a fn) lets `$dtype` be written once yet used for both the put and the
    // comparison without needing `Dtype: Copy`.
    macro_rules! assert_roundtrip {
        ($key:expr, $dims:expr, $dtype:expr, $payload:expr) => {{
            let payload = $payload;
            engine.vector_put(b"ns", $key, $dims, $dtype, &payload).unwrap();
            let stored = engine.vector_get(b"ns", $key).unwrap().unwrap();
            assert_eq!(stored.dims, $dims);
            assert_eq!(stored.dtype, $dtype);
            assert_eq!(stored.data, payload);
        }};
    }

    assert_roundtrip!(b"key1", 128, Dtype::F32, vec![0xAB; 128 * 4]);
    assert_roundtrip!(b"key2", 64, Dtype::F16, vec![0xCD; 64 * 2]);
    assert_roundtrip!(b"key3", 256, Dtype::I8, vec![0xEF; 256]);

    // 100 bytes cannot hold 128 F32 elements (512 bytes), so the put must fail.
    let err = engine
        .vector_put(b"ns", b"bad", 128, Dtype::F32, &[0x00; 100])
        .unwrap_err();
    assert!(
        matches!(err, edgestore::EdgestoreError::DimensionMismatch { .. }),
        "expected DimensionMismatch, got {:?}",
        err
    );
}
#[test]
fn test_sc2_search_correctness_cosine() {
let dir = TempDir::new().unwrap();
let mut engine = open_engine(&dir);
let n = 1000usize;
let dims = 128usize;
for i in 0..n {
let vals = lcg_sequence(i as u64 * 12345, dims);
let data = f32s_to_bytes(&vals);
let key = format!("key{:04}", i);
engine
.vector_put(b"ns", key.as_bytes(), dims as u16, Dtype::F32, &data)
.unwrap();
}
let query_vals = lcg_sequence(99999, dims);
let query_data = f32s_to_bytes(&query_vals);
let query = VectorRecord {
dims: dims as u16,
dtype: Dtype::F32,
data: query_data,
};
let results = engine.vector_search(b"ns", &query, 10, Metric::Cosine).unwrap();
assert_eq!(results.len(), 10);
let mut reference_map: std::collections::HashMap<String, f32> =
std::collections::HashMap::with_capacity(n);
for i in 0..n {
let vals = lcg_sequence(i as u64 * 12345, dims);
let dist = distance_scalar(&query_vals, &vals, Metric::Cosine);
reference_map.insert(format!("key{:04}", i), dist);
}
for (i, result) in results.iter().enumerate() {
let result_key = String::from_utf8_lossy(&result.key);
let ref_dist = reference_map.get(result_key.as_ref()).expect("key in reference");
let diff = (result.distance - *ref_dist).abs();
assert!(
diff < 1e-4,
"SC2 Cosine: distance mismatch for {} at rank {}: got {}, expected {}, diff={}",
result_key, i, result.distance, ref_dist, diff
);
}
for i in 1..results.len() {
assert!(
results[i - 1].distance <= results[i].distance,
"SC2 Cosine: results not sorted at rank {}-{}", i - 1, i
);
}
let mut reference_vec: Vec<(String, f32)> = reference_map.into_iter().collect();
reference_vec.sort_by(|a, b| {
a.1.partial_cmp(&b.1)
.unwrap_or(std::cmp::Ordering::Equal)
});
let search_set: std::collections::HashSet<String> = results
.iter()
.map(|r| String::from_utf8_lossy(&r.key).to_string())
.collect();
let ref_set: std::collections::HashSet<String> = reference_vec[..10]
.iter()
.map(|(k, _)| k.clone())
.collect();
assert_eq!(
search_set, ref_set,
"SC2 Cosine: top-10 set mismatch between search and reference"
);
}
#[test]
fn test_sc2_search_correctness_l2() {
let dir = TempDir::new().unwrap();
let mut engine = open_engine(&dir);
let n = 1000usize;
let dims = 128usize;
for i in 0..n {
let vals = lcg_sequence(i as u64 * 12345, dims);
let data = f32s_to_bytes(&vals);
let key = format!("key{:04}", i);
engine
.vector_put(b"ns", key.as_bytes(), dims as u16, Dtype::F32, &data)
.unwrap();
}
let query_vals = lcg_sequence(99999, dims);
let query_data = f32s_to_bytes(&query_vals);
let query = VectorRecord {
dims: dims as u16,
dtype: Dtype::F32,
data: query_data,
};
let results = engine.vector_search(b"ns", &query, 10, Metric::L2).unwrap();
assert_eq!(results.len(), 10);
let mut reference_map: std::collections::HashMap<String, f32> =
std::collections::HashMap::with_capacity(n);
for i in 0..n {
let vals = lcg_sequence(i as u64 * 12345, dims);
let dist = distance_scalar(&query_vals, &vals, Metric::L2);
reference_map.insert(format!("key{:04}", i), dist);
}
for (i, result) in results.iter().enumerate() {
let result_key = String::from_utf8_lossy(&result.key);
let ref_dist = reference_map.get(result_key.as_ref()).expect("key in reference");
let diff = (result.distance - *ref_dist).abs();
assert!(
diff < 1e-4,
"SC2 L2: distance mismatch for {} at rank {}: got {}, expected {}, diff={}",
result_key, i, result.distance, ref_dist, diff
);
}
for i in 1..results.len() {
assert!(
results[i - 1].distance <= results[i].distance,
"SC2 L2: results not sorted at rank {}-{}", i - 1, i
);
}
let mut reference_vec: Vec<(String, f32)> = reference_map.into_iter().collect();
reference_vec.sort_by(|a, b| {
a.1.partial_cmp(&b.1)
.unwrap_or(std::cmp::Ordering::Equal)
});
let search_set: std::collections::HashSet<String> = results
.iter()
.map(|r| String::from_utf8_lossy(&r.key).to_string())
.collect();
let ref_set: std::collections::HashSet<String> = reference_vec[..10]
.iter()
.map(|(k, _)| k.clone())
.collect();
assert_eq!(
search_set, ref_set,
"SC2 L2: top-10 set mismatch between search and reference"
);
}
#[test]
fn test_sc2_search_correctness_dotproduct() {
let dir = TempDir::new().unwrap();
let mut engine = open_engine(&dir);
let n = 1000usize;
let dims = 128usize;
for i in 0..n {
let vals = lcg_sequence(i as u64 * 12345, dims);
let data = f32s_to_bytes(&vals);
let key = format!("key{:04}", i);
engine
.vector_put(b"ns", key.as_bytes(), dims as u16, Dtype::F32, &data)
.unwrap();
}
let query_vals = lcg_sequence(99999, dims);
let query_data = f32s_to_bytes(&query_vals);
let query = VectorRecord {
dims: dims as u16,
dtype: Dtype::F32,
data: query_data,
};
let results = engine.vector_search(b"ns", &query, 10, Metric::DotProduct).unwrap();
assert_eq!(results.len(), 10);
let mut reference_map: std::collections::HashMap<String, f32> =
std::collections::HashMap::with_capacity(n);
for i in 0..n {
let vals = lcg_sequence(i as u64 * 12345, dims);
let dist = distance_scalar(&query_vals, &vals, Metric::DotProduct);
reference_map.insert(format!("key{:04}", i), dist);
}
for (i, result) in results.iter().enumerate() {
let result_key = String::from_utf8_lossy(&result.key);
let ref_dist = reference_map.get(result_key.as_ref()).expect("key in reference");
let diff = (result.distance - *ref_dist).abs();
assert!(
diff < 1e-4,
"SC2 DotProduct: distance mismatch for {} at rank {}: got {}, expected {}, diff={}",
result_key, i, result.distance, ref_dist, diff
);
}
for i in 1..results.len() {
assert!(
results[i - 1].distance <= results[i].distance,
"SC2 DotProduct: results not sorted at rank {}-{}", i - 1, i
);
}
let mut reference_vec: Vec<(String, f32)> = reference_map.into_iter().collect();
reference_vec.sort_by(|a, b| {
a.1.partial_cmp(&b.1)
.unwrap_or(std::cmp::Ordering::Equal)
});
let search_set: std::collections::HashSet<String> = results
.iter()
.map(|r| String::from_utf8_lossy(&r.key).to_string())
.collect();
let ref_set: std::collections::HashSet<String> = reference_vec[..10]
.iter()
.map(|(k, _)| k.clone())
.collect();
assert_eq!(
search_set, ref_set,
"SC2 DotProduct: top-10 set mismatch between search and reference"
);
}
/// SC3: the byte-level `distance` (presumably the SIMD path — TODO confirm)
/// agrees with `distance_scalar` for every metric over 100 candidate vectors.
#[test]
fn test_sc3_simd_scalar_parity() {
    let n = 100usize;
    let dims = 128usize;
    let query_vals = lcg_sequence(42, dims);
    // The query encoding is identical for every metric/candidate pair; build it once.
    let q_bytes = f32s_to_bytes(&query_vals);
    // Generate each candidate (and its byte encoding) once instead of once per metric.
    let candidates: Vec<(Vec<f32>, Vec<u8>)> = (0..n)
        .map(|i| {
            let vals = lcg_sequence(i as u64 * 12345, dims);
            let bytes = f32s_to_bytes(&vals);
            (vals, bytes)
        })
        .collect();
    for metric in [Metric::Cosine, Metric::L2, Metric::DotProduct] {
        for (i, (candidate_vals, c_bytes)) in candidates.iter().enumerate() {
            let simd_dist = distance(&q_bytes, c_bytes, Dtype::F32, metric).unwrap();
            let scalar_dist = distance_scalar(&query_vals, candidate_vals, metric);
            let diff = (simd_dist - scalar_dist).abs();
            // Relative tolerance: L2/DotProduct values over 128 dims may be far
            // above 1, where a fixed 1e-4 is tighter than f32 summation-order
            // noise between implementations. Equals 1e-4 when |scalar| <= 1.
            let tolerance = 1e-4 * scalar_dist.abs().max(1.0);
            assert!(
                diff < tolerance,
                "SC3 parity failed for {:?} candidate {}: simd={}, scalar={}, diff={}",
                metric, i, simd_dist, scalar_dist, diff
            );
        }
    }
}
/// SC5: the vector layer is additive — the KV tests pass independently of it.
///
/// This test performs no checks of its own; the criterion is satisfied by the
/// KV test suite continuing to pass alongside these vector tests. It exists only
/// so the criterion appears by name in the test report.
#[test]
fn test_sc5_kv_layer_independence() {
    // Intentionally empty: a constant `assert!(true, ..)` checks nothing and
    // trips clippy's `assertions_on_constants` lint.
}