use log::{debug, warn};
use crate::{
builder::ArrowSpaceBuilder,
core::ArrowItem,
tests::{init, test_data::make_gaussian_hd},
};
fn create_test_data(n_train: usize, n_query: usize) -> (Vec<Vec<f64>>, Vec<Vec<f64>>) {
let total = n_train + n_query;
let all_data = make_gaussian_hd(total, 0.6);
let train = all_data[0..n_train].to_vec();
let queries = all_data[n_train..].to_vec();
(train, queries)
}
#[test]
fn test_query_without_projection() {
let (data, queries) = create_test_data(99, 18);
debug!(
"Building index with {} items, {} dims",
data.len(),
data[0].len()
);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 5, 2, 2.0, None)
.with_normalisation(true)
.with_dims_reduction(false, None) .with_sparsity_check(false)
.build(data);
debug!("Built index with {} clusters", aspace.n_clusters);
let query_lambda = aspace.prepare_query_item(&queries[0], &gl);
debug!("{:?}", query_lambda);
assert!(query_lambda.is_finite(), "Query lambda should be finite");
let query_item = ArrowItem::new(queries[0].as_ref(), query_lambda);
let results = aspace.search_lambda_aware(&query_item, 5, 0.7);
assert_eq!(results.len(), 5, "Should return exactly 5 results");
for i in 0..results.len() - 1 {
assert!(
results[i].1 >= results[i + 1].1,
"Results should be sorted by descending score"
);
}
debug!(
"✓ Query without projection: found {} results",
results.len()
);
}
#[test]
fn test_query_with_projection_enabled() {
let (data, queries) = create_test_data(99, 18);
debug!("Testing projection with {}D data", data[0].len());
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 6, 2, 2.0, None)
.with_normalisation(false)
.with_dims_reduction(true, Some(1.0)) .with_sparsity_check(false)
.build(data);
assert!(
aspace.projection_matrix.is_some(),
"Projection matrix should exist"
);
assert!(
aspace.reduced_dim.is_some(),
"Reduced dimension should be set"
);
let reduced_dim = aspace.reduced_dim.unwrap();
debug!(
"Projection: {} → {} dimensions ({:.1}x compression)",
aspace.nfeatures,
reduced_dim,
aspace.nfeatures as f64 / reduced_dim as f64
);
let query_original = queries[1].clone();
let query_projected = aspace.project_query(&query_original);
assert_eq!(
query_projected.len(),
reduced_dim,
"Projected query should match reduced dimension"
);
let query_lambda = aspace.prepare_query_item(&query_original, &gl);
let query_item = ArrowItem::new(query_original.as_ref(), query_lambda);
let results = aspace.search_lambda_aware(&query_item, 10, 0.7);
assert_eq!(results.len(), 10);
for (idx, score) in &results {
assert!(*idx < aspace.nitems, "Index should be valid cluster");
assert!(score.is_finite(), "Score should be finite");
assert!(
*score >= -1.0 && *score <= 1.0,
"Score should be in [-1, 1]"
);
}
debug!(
"✓ Query with projection: {} → {} dims, found {} results",
aspace.nfeatures,
reduced_dim,
results.len()
);
}
#[test]
fn test_prepare_query_item_consistency() {
let (data, queries) = create_test_data(99, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 6, 2, 2.0, None)
.with_normalisation(true)
.with_sparsity_check(false)
.build(data);
let query = queries[2].clone();
let lambda1 = aspace.prepare_query_item(&query, &gl);
let lambda2 = aspace.prepare_query_item(&query, &gl);
let lambda3 = aspace.prepare_query_item(&query, &gl);
assert!(
approx::relative_eq!(lambda1, lambda2, epsilon = 1e-10),
"Lambda computation should be deterministic"
);
assert!(
approx::relative_eq!(lambda2, lambda3, epsilon = 1e-10,),
"Lambda computation should be deterministic"
);
debug!("✓ Consistent lambda: {:.6}", lambda1);
}
#[test]
fn test_search_lambda_aware_alpha_effect() {
let (data, queries) = create_test_data(297, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 6, 2, 2.0, None)
.with_spectral(false) .with_sparsity_check(false)
.with_seed(42)
.build(data);
let query = queries[2].clone();
let query_lambda = aspace.prepare_query_item(&query, &gl);
let query_item = ArrowItem::new(query.as_ref(), query_lambda);
let results_high_alpha = aspace.search_lambda_aware(&query_item, 5, 0.9);
let results_low_alpha = aspace.search_lambda_aware(&query_item, 5, 0.1);
assert_eq!(results_high_alpha.len(), 5);
assert_eq!(results_low_alpha.len(), 5);
debug!(
"High alpha top result: idx={}, score={:.4}",
results_high_alpha[0].0, results_high_alpha[0].1
);
debug!(
"Low alpha top result: idx={}, score={:.4}",
results_low_alpha[0].0, results_low_alpha[0].1
);
let top_idx_high = results_high_alpha[0].0;
let top_item_high = aspace.get_item(top_idx_high);
let semantic_sim = query_item.cosine_similarity(&top_item_high.item);
assert!(
semantic_sim > 0.7,
"High alpha should favor semantic match: {:.4}",
semantic_sim
);
debug!("✓ Alpha effect verified: semantic_sim={:.4}", semantic_sim);
}
#[test]
fn test_search_lambda_aware_hybrid() {
let (data, queries) = create_test_data(99, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 5, 2, 2.0, None)
.with_normalisation(true)
.with_spectral(true)
.with_sparsity_check(false)
.build(data);
let query = queries[1].clone();
let query_lambda = aspace.prepare_query_item(&query, &gl);
let query_item = ArrowItem::new(query.as_ref(), query_lambda);
let results_regular = aspace.search_lambda_aware(&query_item, 10, 0.7);
let results_hybrid = aspace.search_lambda_aware_hybrid(&query_item, 10, 0.7);
assert_eq!(results_regular.len(), 10);
assert_eq!(results_hybrid.len(), 10);
for (idx, score) in results_regular.iter().chain(results_hybrid.iter()) {
assert!(*idx < aspace.nitems, "Index should be valid cluster");
assert!(score.is_finite(), "Score should be finite");
}
debug!(
"✓ Hybrid search: regular={} results, hybrid={} results",
results_regular.len(),
results_hybrid.len()
);
}
#[test]
#[should_panic]
fn test_query_dimension_mismatch_panics() {
let (data, queries) = create_test_data(99, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 4, 2, 2.0, None)
.build(data);
let wrong_query = queries[0][0..32].to_vec();
let _ = aspace.prepare_query_item(&wrong_query, &gl);
}
#[test]
#[should_panic]
fn test_query_with_nan_values() {
let (data, queries) = create_test_data(99, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 4, 2, 2.0, None)
.build(data);
let mut bad_query = queries[0].clone();
bad_query[3] = f64::NAN;
let _ = aspace.prepare_query_item(&bad_query, &gl);
}
#[test]
fn test_range_search_with_query_lambda() {
let (data, queries) = create_test_data(99, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 5, 2, 2.0, None)
.with_normalisation(true)
.with_sparsity_check(false)
.build(data);
let query = queries[1].clone();
let query_lambda = aspace.prepare_query_item(&query, &gl);
let query_item = ArrowItem::new(query.as_ref(), query_lambda);
debug!("query lambda {:?}", query_lambda);
let results = aspace.range_search(&query_item, &gl, 0.1);
debug!("{:?}", aspace.lambdas);
assert!(!results.is_empty(), "Range search should find some results");
for (idx, dist) in &results {
assert!(*idx < aspace.nitems, "Index should be valid cluster");
assert!(
*dist <= 0.5,
"Distance should be within radius: {:.4}",
dist
);
assert!(dist.is_finite(), "Distance should be finite");
}
debug!(
"✓ Range search found {} clusters within radius 0.5",
results.len()
);
}
#[test]
fn test_lambda_values_reasonable_range() {
init();
let (data, queries) = create_test_data(99, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 5, 2, 2.0, None)
.with_sparsity_check(false)
.with_seed(42)
.build(data);
for (i, query) in queries.iter().enumerate() {
let lambda = aspace.prepare_query_item(query, &gl);
assert!(lambda >= 0.0, "Lambda negative: query {}", i);
assert!(lambda.is_finite(), "Lambda not finite: query {}", i);
assert!(
lambda <= 1.0,
"Lambda out of bounds: query {}, λ={:.2e} (expected [0, 1])",
i,
lambda
);
if lambda > 1.0 {
warn!("Query {} lambda={:.6} exceeds [0,1] range", i, lambda);
}
debug!("Query {} lambda: {:.6}", i, lambda);
}
}
#[test]
fn test_search_returns_top_k_exactly() {
let (data, queries) = create_test_data(99, 18);
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 5, 2, 2.0, None)
.with_normalisation(true)
.with_sparsity_check(false)
.build(data);
let query = queries[2].clone();
let query_lambda = aspace.prepare_query_item(&query, &gl);
let query_item = ArrowItem::new(query.as_ref(), query_lambda);
let num_clusters = aspace.n_clusters;
debug!("Testing k-NN with {} total clusters", num_clusters);
for k in [1, 3, 5, 10] {
let results = aspace.search_lambda_aware(&query_item, k, 0.1);
assert_eq!(
results.len(),
k,
"Should return exactly {} results for k={}",
results.len(),
k
);
}
debug!("✓ k-NN returns correct number of results");
}
#[test]
fn test_projection_preserves_relative_distances() {
crate::tests::init();
let (data, _) = create_test_data(99, 18);
let query: Vec<f64> = data[10].clone().to_vec();
debug!("Testing JL projection with 200D data");
let (aspace, gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.5, 6, 2, 2.0, None)
.with_dims_reduction(true, Some(0.3)) .with_sparsity_check(false)
.with_seed(42)
.build(data);
assert!(
aspace.projection_matrix.is_some(),
"Projection should be enabled"
);
let reduced_dim = aspace.reduced_dim.unwrap();
debug!(
"Projection: {} → {} dimensions ({:.1}x compression)",
aspace.nfeatures,
reduced_dim,
aspace.nfeatures as f64 / reduced_dim as f64
);
let query1_orig: Vec<f64> = query.clone();
let query2_orig: Vec<f64> = query.clone().iter().map(|x| x + 0.01).collect(); let query3_orig: Vec<f64> = query.clone().iter().map(|x| x * 10.0).collect();
let query1_proj = aspace.project_query(&query1_orig);
let query2_proj = aspace.project_query(&query2_orig);
let query3_proj = aspace.project_query(&query3_orig);
assert_eq!(query1_proj.len(), reduced_dim);
assert_eq!(query2_proj.len(), reduced_dim);
assert_eq!(query3_proj.len(), reduced_dim);
let dist_12_orig: f64 = query1_orig
.iter()
.zip(&query2_orig)
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
.sqrt();
let dist_13_orig: f64 = query1_orig
.iter()
.zip(&query3_orig)
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
.sqrt();
let dist_12_proj: f64 = query1_proj
.iter()
.zip(&query2_proj)
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
.sqrt();
let dist_13_proj: f64 = query1_proj
.iter()
.zip(&query3_proj)
.map(|(a, b)| (a - b).powi(2))
.sum::<f64>()
.sqrt();
debug!(
"Original space: dist(q1,q2)={:.4}, dist(q1,q3)={:.4}",
dist_12_orig, dist_13_orig
);
debug!(
"Projected space: dist(q1,q2)={:.4}, dist(q1,q3)={:.4}",
dist_12_proj, dist_13_proj
);
assert!(
dist_12_orig < dist_13_orig,
"In original space, q1 should be closer to q2 than q3"
);
assert!(
dist_12_proj < dist_13_proj,
"In projected space, q1 should STILL be closer to q2 than q3 (ordering preserved)"
);
let epsilon = 0.5;
let ratio_12 = dist_12_proj / dist_12_orig;
let ratio_13 = dist_13_proj / dist_13_orig;
debug!(
"Distance preservation ratios: q1-q2={:.3}, q1-q3={:.3}",
ratio_12, ratio_13
);
assert!(
ratio_12 > 1.0 - epsilon && ratio_12 < 1.0 + epsilon,
"Distance q1-q2 not preserved: ratio {:.3} outside tolerance",
ratio_12
);
assert!(
ratio_13 > 1.0 - epsilon && ratio_13 < 1.0 + epsilon,
"Distance q1-q3 not preserved: ratio {:.3} outside tolerance",
ratio_13
);
let lambda1 = aspace.prepare_query_item(&query1_orig, &gl);
let lambda2 = aspace.prepare_query_item(&query2_orig, &gl);
let lambda3 = aspace.prepare_query_item(&query3_orig, &gl);
assert!(lambda1.is_finite() && lambda1 >= 0.0);
assert!(lambda2.is_finite() && lambda2 >= 0.0);
assert!(lambda3.is_finite() && lambda3 >= 0.0);
debug!(
"Query lambdas: q1={:.6}, q2={:.6}, q3={:.6}",
lambda1, lambda2, lambda3
);
debug!("✓ Projection preserves relative distances and enables lambda computation");
}
#[test]
fn test_project_query_no_projection() {
let (data, queries) = create_test_data(99, 18);
let (aspace, _gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 6, 2, 2.0, None)
.with_dims_reduction(false, None) .build(data);
let query = queries[0].clone();
let projected = aspace.project_query(&query);
assert_eq!(projected.len(), query.len());
assert_eq!(projected, query);
debug!("✓ Query passes through unchanged when projection disabled");
}
#[test]
fn test_project_query_consistency() {
let (data, queries) = create_test_data(99, 18);
let (aspace, _gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 4, 2, 2.0, None)
.with_dims_reduction(true, Some(0.4))
.with_sparsity_check(false)
.build(data);
let query = queries[1].clone();
let projected1 = aspace.project_query(&query);
let projected2 = aspace.project_query(&query);
let projected3 = aspace.project_query(&query);
assert_eq!(projected1, projected2, "Projection should be deterministic");
assert_eq!(projected2, projected3, "Projection should be deterministic");
debug!("✓ Projection is deterministic");
}
#[test]
fn test_project_query_linearity() {
let (data, queries) = create_test_data(99, 18);
let (aspace, _gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 6, 2, 2.0, None)
.with_dims_reduction(true, Some(0.35))
.with_sparsity_check(false)
.build(data);
let query = queries[1].clone();
let scaled_query: Vec<f64> = query.iter().map(|x| x * 2.0).collect();
let projected = aspace.project_query(&query);
let projected_scaled = aspace.project_query(&scaled_query);
for i in 0..projected.len() {
let expected = projected[i] * 2.0;
let actual = projected_scaled[i];
assert!(
(expected - actual).abs() < 1e-9,
"Linearity violation at index {}: expected {:.6}, got {:.6}",
i,
expected,
actual
);
}
debug!("✓ Projection is linear");
}
#[test]
fn test_project_query_zero_vector() {
let (data, _) = create_test_data(99, 18);
let (aspace, _gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.2, 4, 2, 2.0, None)
.with_dims_reduction(true, Some(0.8))
.with_sparsity_check(false)
.build(data);
let query_zero = vec![0.0; 100];
let projected = aspace.project_query(&query_zero);
assert_eq!(projected.len(), aspace.reduced_dim.unwrap());
for &val in &projected {
assert!(val.abs() < 1e-8, "Expected near-zero, got {:.6}", val);
}
debug!("✓ Zero vector projects to zero");
}
#[test]
fn test_project_query_preserves_scale_approximately() {
let (data, queries) = create_test_data(99, 18);
let (aspace, _gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 6, 2, 2.0, None)
.with_dims_reduction(true, Some(0.25))
.with_sparsity_check(false)
.build(data);
let query = queries[1].clone();
let projected = aspace.project_query(&query);
let orig_norm: f64 = query.iter().map(|x| x * x).sum::<f64>().sqrt();
let proj_norm: f64 = projected.iter().map(|x| x * x).sum::<f64>().sqrt();
let ratio = proj_norm / orig_norm;
debug!(
"Norm preservation: original={:.4}, projected={:.4}, ratio={:.3}",
orig_norm, proj_norm, ratio
);
assert!(
ratio > 0.5 && ratio < 2.0,
"Norm ratio {:.3} out of expected range [0.5, 2.0]",
ratio
);
debug!("✓ Projection preserves scale approximately");
}
#[test]
fn test_project_query_different_queries_differ() {
let (data, queries) = create_test_data(99, 18);
let (aspace, _gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(0.3, 4, 2, 2.0, None)
.with_dims_reduction(true, Some(0.8))
.build(data);
let proj1 = aspace.project_query(&queries[0]);
let proj2 = aspace.project_query(&queries[1]);
let proj3 = aspace.project_query(&queries[2]);
let reduced_dim = aspace.reduced_dim.unwrap();
assert_eq!(proj1.len(), reduced_dim);
assert_eq!(proj2.len(), reduced_dim);
assert_eq!(proj3.len(), reduced_dim);
assert_ne!(proj1, proj2);
assert_ne!(proj2, proj3);
assert_ne!(proj1, proj3);
debug!("✓ Different queries produce different projections");
}
#[test]
fn test_project_query_preserves_dot_product_sign() {
let (data, queries) = create_test_data(99, 18);
let (aspace, _gl) = ArrowSpaceBuilder::default()
.with_lambda_graph(1.0, 6, 2, 2.0, None)
.with_dims_reduction(true, Some(0.8))
.with_inline_sampling(None)
.with_seed(312)
.build(data);
let query_pos = queries[0].clone();
let query_neg: Vec<f64> = queries[0].iter().map(|&x| -x).collect();
let proj_pos = aspace.project_query(&query_pos);
let proj_neg = aspace.project_query(&query_neg);
let dot: f64 = proj_pos.iter().zip(&proj_neg).map(|(a, b)| a * b).sum();
assert!(
dot < 0.0,
"Projection should preserve opposite directions: dot={:.6}",
dot
);
debug!("✓ Projection preserves opposite directions (dot product sign)");
}