use diskann::{utils::VectorRepr, ANNError};
use diskann_providers::storage::StorageReadProvider;
use diskann_utils::io::read_bin;
use rand::Rng;
use crate::utils::{CMDResult, CMDToolError};
/// Returns the squared Euclidean (L2) distance between `v1` and `v2`.
///
/// Both slices are first converted to `f32` through [`VectorRepr::as_f32`];
/// the per-component differences are then squared and summed.
///
/// If the converted slices differ in length, only the leading components
/// common to both contribute, because `zip` stops at the shorter input.
///
/// # Errors
///
/// Returns a [`CMDToolError`] (built from the intermediate [`ANNError`]) when
/// either slice cannot be converted to `f32`.
fn squared_distance<T: VectorRepr>(v1: &[T], v2: &[T]) -> CMDResult<f32> {
    let lhs = &*T::as_f32(v1).map_err(|err| {
        let ann_err: ANNError = err.into();
        CMDToolError::from(ann_err)
    })?;
    let rhs = &*T::as_f32(v2).map_err(|err| {
        let ann_err: ANNError = err.into();
        CMDToolError::from(ann_err)
    })?;

    let total: f32 = lhs
        .iter()
        .zip(rhs)
        .map(|(a, b)| *a - *b)
        .map(|delta| delta * delta)
        .sum();
    Ok(total)
}
/// Estimates the mean squared distance from `query` to the points in `base`.
///
/// Draws `num_random_samples` indices uniformly from `0..base.len()` (with
/// replacement) using `rng`, accumulates the squared distance from `query` to
/// each sampled point, and divides the total by `num_random_samples`.
///
/// With `num_random_samples == 0` the result is `0.0 / 0.0`, i.e. NaN.
///
/// # Errors
///
/// Propagates the first error returned by `squared_distance`; no further
/// samples are drawn after a failure.
///
/// # Panics
///
/// Presumably panics when `base` is empty and at least one sample is
/// requested, since the sampled range `0..0` is empty (confirm against the
/// `rand` documentation for `random_range`).
fn average_squared_distance<T: VectorRepr, R: Rng>(
    query: &[T],
    base: &[Vec<T>],
    num_random_samples: usize,
    rng: &mut R,
) -> CMDResult<f32> {
    let total = (0..num_random_samples).try_fold(0.0f32, |acc, _| {
        let pick = rng.random_range(0..base.len());
        squared_distance::<T>(query, &base[pick]).map(|dist| acc + dist)
    })?;

    Ok(total / num_random_samples as f32)
}
/// Computes the mean relative contrast of a dataset for a set of queries.
///
/// For every query the relative contrast is
///
/// ```text
/// (mean squared distance to randomly sampled base points)
/// -------------------------------------------------------
/// (mean squared distance to its ground-truth neighbors)
/// ```
///
/// and the returned value is the average of that ratio over all queries.
///
/// # Arguments
///
/// * `storage_provider` - provider used to open the three input files.
/// * `base_file` - binary file holding the base vectors of type `T`.
/// * `query_file` - binary file holding the query vectors of type `T`.
/// * `gt_file` - binary file holding, per query, the `u32` ids of its
///   ground-truth neighbors.
/// * `recall_at` - maximum number of ground-truth neighbors per query used
///   for the denominator. If the ground-truth file stores fewer neighbors per
///   query, only the available ones are used and averaged over.
/// * `num_random_samples` - number of random base points sampled per query
///   for the numerator.
/// * `rng` - random number generator used for the sampling.
///
/// # Errors
///
/// Propagates errors from opening or reading any of the input files and from
/// converting vectors to `f32` while computing distances.
///
/// # Panics
///
/// Panics if the ground-truth file has fewer rows than the query file, or if
/// a ground-truth id is not a valid row index of the base file.
pub fn compute_relative_contrast<T: VectorRepr, StorageProvider: StorageReadProvider, R: Rng>(
    storage_provider: &StorageProvider,
    base_file: &str,
    query_file: &str,
    gt_file: &str,
    recall_at: usize,
    num_random_samples: usize,
    rng: &mut R,
) -> CMDResult<f32> {
    let base_data = read_bin::<T>(&mut storage_provider.open_reader(base_file)?)?;
    let query_data = read_bin::<T>(&mut storage_provider.open_reader(query_file)?)?;
    let gt_data = read_bin::<u32>(&mut storage_provider.open_reader(gt_file)?)?;

    let nb = base_data.nrows();
    let dim = base_data.ncols();
    let nq = query_data.nrows();
    let ngt = gt_data.ncols();
    tracing::info!(
        "Loaded base: {} points, query: {} points, dimension: {}, ground truth neighbors: {}",
        nb,
        nq,
        dim,
        ngt
    );

    if ngt < recall_at {
        tracing::warn!(
            "Ground truth holds only {} neighbors per query but recall_at is {}; averaging over the available neighbors.",
            ngt,
            recall_at
        );
    }

    let base: Vec<Vec<T>> = base_data.row_iter().map(|x| x.to_vec()).collect();
    let query: Vec<Vec<T>> = query_data.row_iter().map(|x| x.to_vec()).collect();
    let gt: Vec<Vec<u32>> = gt_data.row_iter().map(|x| x.to_vec()).collect();

    let mut mean_rc = 0.0;
    for (i, q) in query.iter().enumerate() {
        let numerator = average_squared_distance::<T, R>(q, &base, num_random_samples, rng)?;

        // Only the first `recall_at` neighbors are considered, but the row may
        // hold fewer than that. Divide by the number of neighbors actually
        // summed so the mean is not underestimated for short rows.
        let neighbors = &gt[i][..gt[i].len().min(recall_at)];
        let mut denominator = 0.0;
        for &idx in neighbors {
            denominator += squared_distance::<T>(q, &base[idx as usize])?;
        }
        denominator /= neighbors.len() as f32;

        let rc = numerator / denominator;
        // Accumulate the per-query share directly so `mean_rc` ends up as the
        // average over all queries.
        mean_rc += rc / nq as f32;
    }

    // NOTE(review): values >= 2.0 are also reported as "not suitable" because
    // the range is half-open on both the verdict and the upper bound; confirm
    // that the upper limit is intentional.
    if (1.5..2.0).contains(&mean_rc) {
        tracing::info!(
            "Mean relative contrast = {}. The dataset is suitable for ANN.",
            mean_rc
        );
    } else {
        tracing::info!(
            "Mean relative contrast = {}. The dataset is not suitable for ANN.",
            mean_rc
        );
    }

    Ok(mean_rc)
}
// Tests for `compute_relative_contrast`.
//
// Both tests build their inputs through a `VirtualStorageProvider`, generate
// ground truth with `compute_ground_truth_from_datafiles`, and then check that
// the resulting mean relative contrast falls inside an expected range. The
// same test RNG is used for data generation and for the sampling inside
// `compute_relative_contrast`, so the order of the statements below matters
// for the asserted ranges.
#[cfg(test)]
mod relative_contrast_tests {
use diskann_providers::storage::{StorageWriteProvider, VirtualStorageProvider};
use diskann_providers::utils::random;
use diskann_utils::io::Metadata;
use diskann_vector::distance::Metric;
use half::f16;
use rand::Rng;
use super::*;
use crate::utils::ground_truth::compute_ground_truth_from_datafiles;
use diskann_disk::data_model::AdHoc;
use diskann_vector::Half;
// Random base and query vectors (components drawn from [0, 1)) written to an
// in-memory provider: 1000 base points and 10 queries of dimension 384.
// Expects a mean relative contrast strictly between 1.0 and 1.2.
#[test]
fn test_compute_relative_contrast_with_random_data() {
let storage_provider = VirtualStorageProvider::new_memory();
let num_vectors = 1000;
let dim = 384;
let mut rng = random::create_rnd_in_tests();
// Flat, row-major buffer of `num_vectors * dim` half-precision values.
let base: Vec<f16> = (0..num_vectors * dim)
.map(|_| f16::from_f32(rng.random_range(0.0..1.0)))
.collect();
let num_queries = 10;
let query: Vec<f16> = (0..num_queries * dim)
.map(|_| f16::from_f32(rng.random_range(0.0..1.0)))
.collect();
let base_file_path = "/base.bin";
// The writer lives in its own scope so it is dropped before the file is
// read back.
{
let mut base_writer = storage_provider.create_for_write(base_file_path).unwrap();
// Metadata header (point count and dimension) first, then the payload as
// little-endian bytes.
Metadata::new(num_vectors, dim)
.unwrap()
.write(&mut base_writer)
.unwrap();
for value in &base {
base_writer.write_all(&value.to_le_bytes()).unwrap();
}
}
let query_file_path = "/query.bin";
// Same layout as the base file, with `num_queries` rows.
{
let mut query_writer = storage_provider.create_for_write(query_file_path).unwrap();
Metadata::new(num_queries, dim)
.unwrap()
.write(&mut query_writer)
.unwrap();
for value in &query {
query_writer.write_all(&value.to_le_bytes()).unwrap();
}
}
let gt_file_path = "/ground_truth.bin";
let recall_at = 5;
// Produce the ground-truth file (L2 metric, `recall_at` neighbors per query)
// that `compute_relative_contrast` reads below.
compute_ground_truth_from_datafiles::<AdHoc<Half>, _>(
&storage_provider,
Metric::L2,
base_file_path,
query_file_path,
gt_file_path,
None,
recall_at as u32,
None,
None,
None,
None,
None,
)
.unwrap();
let num_random_samples = 5;
let mean_rc = compute_relative_contrast::<Half, _, _>(
&storage_provider,
base_file_path,
query_file_path,
gt_file_path,
recall_at,
num_random_samples,
&mut rng,
)
.unwrap();
println!("Mean relative contrast: {}", mean_rc);
assert!(
mean_rc > 1.0 && mean_rc < 1.2,
"Mean relative contrast is out of range: {}",
mean_rc
);
}
// Uses an existing base file from the `sift` test-data directory through an
// overlay provider, together with 10 randomly generated 128-dimensional
// queries. Expects a mean relative contrast strictly between 1.5 and 2.0.
#[test]
fn test_compute_relative_contrast_with_sift_files() {
let storage_provider =
VirtualStorageProvider::new_overlay(diskann_utils::test_data_root().join("sift"));
let base_file_path = "/siftsmall_learn_256pts.fbin";
// Fail early with a clear message if the test data is missing.
assert!(
storage_provider.exists(base_file_path),
"Base file does not exist"
);
let num_queries = 10;
let dim = 128;
let mut rng = random::create_rnd_in_tests();
let query: Vec<f16> = (0..num_queries * dim)
.map(|_| f16::from_f32(rng.random_range(0.0..1.0)))
.collect();
let query_file_path = "/query.bin";
// Write the metadata header followed by the little-endian query payload;
// the writer is dropped at the end of this scope.
{
let mut query_writer = storage_provider
.create_for_write(query_file_path)
.expect("Failed to create query file in memory");
Metadata::new(num_queries, dim)
.expect("Failed to create metadata")
.write(&mut query_writer)
.expect("Failed to write metadata");
for value in &query {
query_writer
.write_all(&value.to_le_bytes())
.expect("Failed to write query vector");
}
}
let gt_file_path = "/ground_truth.bin";
let recall_at = 3;
compute_ground_truth_from_datafiles::<AdHoc<Half>, _>(
&storage_provider,
Metric::L2,
base_file_path,
query_file_path,
gt_file_path,
None,
recall_at as u32,
None,
None,
None,
None,
None,
)
.unwrap();
let num_random_samples = 3;
let mean_rc = compute_relative_contrast::<Half, _, _>(
&storage_provider,
base_file_path,
query_file_path,
gt_file_path,
recall_at,
num_random_samples,
&mut rng,
)
.unwrap();
println!("Mean relative contrast: {}", mean_rc);
// Remove the generated files before asserting, so they are cleaned up even
// when the range check below fails.
storage_provider
.delete(query_file_path)
.expect("Failed to delete query file in disk");
storage_provider
.delete(gt_file_path)
.expect("Failed to delete ground truth file in disk");
assert!(
mean_rc > 1.5 && mean_rc < 2.0,
"Mean relative contrast is out of range: {}",
mean_rc
);
}
}