use std::sync::Arc;
use rand::SeedableRng;
use rand_chacha::ChaCha20Rng;
use tinyquant_bruteforce::BruteForceBackend;
use tinyquant_core::backend::SearchBackend;
use tinyquant_core::errors::BackendError;
/// Builds a shared, reference-counted vector identifier from a string slice.
fn make_id(s: &str) -> Arc<str> {
    s.into()
}
/// Returns a `dim`-length one-hot vector: all zeros except `1.0` at index `hot`.
///
/// Panics (index out of bounds) if `hot >= dim`.
fn unit_vec(dim: usize, hot: usize) -> Vec<f32> {
    let mut basis: Vec<f32> = std::iter::repeat(0.0_f32).take(dim).collect();
    basis[hot] = 1.0;
    basis
}
/// Draws `dim` components, each sampled uniformly from `[-1.0, 1.0)` using `rng`.
///
/// Components are drawn in index order, so a given seed always yields the same vector.
fn random_vec(rng: &mut ChaCha20Rng, dim: usize) -> Vec<f32> {
    use rand::Rng;
    std::iter::repeat_with(|| rng.gen_range(-1.0_f32..1.0_f32))
        .take(dim)
        .collect()
}
/// Ingests three orthogonal basis vectors and searches with a query that
/// leans mostly towards `a` and slightly towards `b`. The top-2 hits must be
/// `a` then `b`, in non-increasing score order.
#[test]
fn ingest_then_search_returns_top_k_descending() {
    let mut b = BruteForceBackend::new();
    b.ingest(&[
        (make_id("a"), unit_vec(3, 0)),
        (make_id("b"), unit_vec(3, 1)),
        (make_id("c"), unit_vec(3, 2)),
    ])
    .unwrap();
    let query = [0.9_f32, 0.1, 0.0];
    let results = b.search(&query, 2).unwrap();
    assert_eq!(results.len(), 2);
    assert_eq!(results[0].vector_id.as_ref(), "a");
    // `b` shares the query's small second component while `c` is orthogonal
    // to the query, so `b` must be the runner-up.
    assert_eq!(results[1].vector_id.as_ref(), "b");
    assert!(results[0].score >= results[1].score);
}
/// Searching with exactly the stored vector should return a single hit whose
/// score is 1.0 within a small floating-point tolerance.
#[test]
fn identical_query_scores_near_one() {
    let stored = vec![1.0_f32, 0.0, 0.0];
    let mut backend = BruteForceBackend::new();
    backend.ingest(&[(make_id("x"), stored.clone())]).unwrap();

    let hits = backend.search(&stored, 1).unwrap();
    assert_eq!(hits.len(), 1);
    let score = hits[0].score;
    assert!(
        (score - 1.0).abs() < 1e-5,
        "expected score ~1.0, got {}",
        score
    );
}
/// A query orthogonal to the only stored vector should score ~0.0.
#[test]
fn orthogonal_query_scores_near_zero() {
    let mut backend = BruteForceBackend::new();
    backend.ingest(&[(make_id("x"), vec![1.0_f32, 0.0])]).unwrap();

    let hits = backend.search(&[0.0_f32, 1.0], 1).unwrap();
    assert_eq!(hits.len(), 1);
    let score = hits[0].score;
    assert!(score.abs() < 1e-5, "expected score ~0.0, got {}", score);
}
/// A freshly constructed backend holds nothing, so a search succeeds and
/// yields no hits.
#[test]
fn empty_backend_returns_empty_list() {
    let backend = BruteForceBackend::new();
    let hits = backend.search(&[1.0_f32, 0.0], 10).unwrap();
    assert!(hits.is_empty());
}
/// Removing an id that was never ingested succeeds. This holds both on an
/// empty backend and on one holding other vectors, which are left in place.
#[test]
fn remove_ids_unknown_is_silent() {
    let mut backend = BruteForceBackend::new();
    assert!(backend.remove(&[make_id("ghost")]).is_ok());

    backend
        .ingest(&[(make_id("real"), vec![1.0_f32, 0.0])])
        .unwrap();
    assert!(backend.remove(&[make_id("ghost")]).is_ok());
    assert_eq!(backend.len(), 1);
}
/// Passing an empty id list to `remove` succeeds and leaves the stored
/// vector count unchanged.
#[test]
fn remove_ids_empty_list_is_noop() {
    let mut backend = BruteForceBackend::new();
    backend
        .ingest(&[(make_id("a"), vec![1.0_f32, 0.0])])
        .unwrap();

    backend.remove(&[]).unwrap();
    assert_eq!(backend.len(), 1);
}
/// Asking for more hits than there are stored vectors succeeds and returns
/// exactly as many hits as are stored.
#[test]
fn search_with_fewer_vectors_than_top_k_returns_all() {
    let corpus = [
        (make_id("a"), vec![1.0_f32, 0.0]),
        (make_id("b"), vec![0.0_f32, 1.0]),
    ];
    let mut backend = BruteForceBackend::new();
    backend.ingest(&corpus).unwrap();

    let hits = backend.search(&[1.0_f32, 0.0], 10).unwrap();
    assert_eq!(hits.len(), 2);
}
/// A `top_k` of zero is rejected with `BackendError::InvalidTopK`.
#[test]
fn search_with_top_k_zero_returns_invalid_top_k() {
    let backend = BruteForceBackend::new();
    let outcome = backend.search(&[1.0_f32, 0.0], 0);
    assert!(matches!(outcome.unwrap_err(), BackendError::InvalidTopK));
}
/// After a 3-dimensional vector has been ingested, a 2-dimensional query is
/// rejected with `BackendError::Adapter`.
#[test]
fn search_with_wrong_query_dim_returns_adapter_error() {
    let mut backend = BruteForceBackend::new();
    backend
        .ingest(&[(make_id("x"), vec![1.0_f32, 0.0, 0.0])])
        .unwrap();

    let err = backend.search(&[1.0_f32, 0.0], 1).unwrap_err();
    assert!(
        matches!(err, BackendError::Adapter(_)),
        "expected Adapter, got {err:?}"
    );
}
/// After a 3-dimensional vector has been ingested, ingesting a 2-dimensional
/// vector fails with `BackendError::Adapter` and stores nothing new.
#[test]
fn ingest_with_mismatched_dim_returns_adapter_error() {
    let mut backend = BruteForceBackend::new();
    backend
        .ingest(&[(make_id("x"), vec![1.0_f32, 0.0, 0.0])])
        .unwrap();

    let short_batch = [(make_id("y"), vec![1.0_f32, 0.0])];
    let err = backend.ingest(&short_batch).unwrap_err();
    assert!(
        matches!(err, BackendError::Adapter(_)),
        "expected Adapter, got {err:?}"
    );
    // The rejected vector must not have been stored alongside `x`.
    assert_eq!(backend.len(), 1);
}
/// Re-ingesting an existing id succeeds and replaces its vector rather than
/// adding a second entry. A search for the new vector finds `x` with a score
/// of ~1.0.
#[test]
fn ingest_overwrites_existing_id_silently() {
    let mut b = BruteForceBackend::new();
    b.ingest(&[(make_id("x"), vec![1.0_f32, 0.0])]).unwrap();
    b.ingest(&[(make_id("x"), vec![0.0_f32, 1.0])]).unwrap();
    assert_eq!(b.len(), 1);
    let results = b.search(&[0.0_f32, 1.0], 1).unwrap();
    // Check the length before indexing, so an empty result fails with a clear
    // assertion instead of an index-out-of-bounds panic.
    assert_eq!(results.len(), 1);
    assert_eq!(results[0].vector_id.as_ref(), "x");
    assert!(
        (results[0].score - 1.0).abs() < 1e-5,
        "expected overwritten vector to score ~1.0, got {}",
        results[0].score
    );
}
/// Ingesting an empty batch succeeds, stores nothing, and leaves `dim()`
/// returning `None`.
#[test]
fn ingest_empty_slice_is_noop() {
    let mut backend = BruteForceBackend::new();
    assert!(backend.ingest(&[]).is_ok());
    assert_eq!(backend.len(), 0);
    assert!(backend.dim().is_none());
}
/// Removing one id from the middle of the store leaves exactly the other
/// three searchable. The exact match for the query (`a`) still ranks first.
///
/// NOTE(review): the name promises insertion-order preservation, but `c` and
/// `d` tie for the query (both are orthogonal to it). This test deliberately
/// does not assert their relative order. Confirm the backend's tie-breaking
/// rule before tightening it.
#[test]
fn remove_id_preserves_insertion_order_of_survivors() {
    let mut b = BruteForceBackend::new();
    b.ingest(&[
        (make_id("a"), vec![1.0_f32, 0.0, 0.0, 0.0]),
        (make_id("b"), vec![0.0_f32, 1.0, 0.0, 0.0]),
        (make_id("c"), vec![0.0_f32, 0.0, 1.0, 0.0]),
        (make_id("d"), vec![0.0_f32, 0.0, 0.0, 1.0]),
    ])
    .unwrap();
    b.remove(&[make_id("b")]).unwrap();
    assert_eq!(b.len(), 3);
    let results = b.search(&[1.0_f32, 0.0, 0.0, 0.0], 10).unwrap();
    let ids: Vec<&str> = results.iter().map(|r| r.vector_id.as_ref()).collect();
    // top_k (10) exceeds the store size, so every survivor should come back.
    assert_eq!(ids.len(), 3, "expected exactly the three survivors, got {ids:?}");
    assert!(ids.contains(&"a"), "a should still be present");
    assert!(!ids.contains(&"b"), "b should have been removed");
    assert!(ids.contains(&"c"), "c should still be present");
    assert!(ids.contains(&"d"), "d should still be present");
    assert_eq!(ids[0], "a");
}
/// Seeded regression check. It ingests 100 random 32-dimensional vectors
/// (seed 42). Then, for five random queries from the same RNG stream, it
/// verifies that exactly 10 distinct hits come back in non-increasing score
/// order.
#[test]
fn golden_fixture_top_10_ordering() {
    let mut rng = ChaCha20Rng::seed_from_u64(42);
    let dim = 32_usize;
    let n = 100_usize;
    // Vectors are drawn before any query, so the RNG stream (and therefore the
    // fixture) is fully determined by the seed.
    let vectors: Vec<(Arc<str>, Vec<f32>)> = (0..n)
        .map(|i| (make_id(&format!("v{i:03}")), random_vec(&mut rng, dim)))
        .collect();
    let mut backend = BruteForceBackend::new();
    // `ingest` only borrows the batch, so no defensive copy is needed.
    backend.ingest(&vectors).unwrap();
    for q in 0..5_usize {
        let query = random_vec(&mut rng, dim);
        let results = backend.search(&query, 10).unwrap();
        assert_eq!(results.len(), 10, "query {q}: expected 10 results");
        // Every id was ingested once, so a correct top-k never repeats an id.
        let unique: std::collections::HashSet<&str> =
            results.iter().map(|r| r.vector_id.as_ref()).collect();
        assert_eq!(
            unique.len(),
            results.len(),
            "query {q}: duplicate ids in results"
        );
        for window in results.windows(2) {
            assert!(
                window[0].score >= window[1].score,
                "query {q}: results not sorted descending at window boundary: {} vs {}",
                window[0].score,
                window[1].score
            );
        }
    }
}