use bit_set::BitSet;
use diskann_label_filter::{eval_query_expr, read_and_parse_queries, read_baselabels};
use std::{io::Write, mem::size_of, str::FromStr};
use bytemuck::cast_slice;
use diskann::{
neighbor::{Neighbor, NeighborPriorityQueue},
utils::VectorRepr,
};
use diskann_disk::data_model::GraphDataType;
use diskann_providers::storage::{StorageReadProvider, StorageWriteProvider};
use diskann_providers::utils::{
create_thread_pool, file_util, ParallelIteratorInPool, VectorDataIterator,
};
use diskann_utils::{
io::{read_bin, Metadata},
views::Matrix,
};
use diskann_vector::{distance::Metric, DistanceFunction};
use itertools::Itertools;
use rayon::prelude::*;
use crate::utils::{search_index_utils, CMDResult, CMDToolError};
/// Builds one filter bitmap per label query.
///
/// Base labels are read from `base_label_filename` and label queries from
/// `query_label_filename`. For every parsed query, the returned `BitSet` contains the `doc_id`
/// of each base label entry for which `eval_query_expr` evaluates to true. Bitmaps are returned
/// in the same order as the parsed queries.
///
/// # Errors
/// Propagates any error from reading or parsing either label file.
pub fn read_labels_and_compute_bitmap(
    base_label_filename: &str,
    query_label_filename: &str,
) -> CMDResult<Vec<BitSet>> {
    let base_labels = read_baselabels(base_label_filename)?;
    let parsed_queries = read_and_parse_queries(query_label_filename)?;

    // Queries are independent of each other, so each one scans the base labels in parallel.
    #[allow(clippy::disallowed_methods)]
    let bitmaps = parsed_queries
        .par_iter()
        .map(|(_, query_expr)| {
            base_labels
                .iter()
                .filter(|base_label| eval_query_expr(query_expr, &base_label.label))
                .map(|base_label| base_label.doc_id)
                .collect::<BitSet>()
        })
        .collect::<Vec<BitSet>>();

    Ok(bitmaps)
}
/// Computes exact (brute-force) nearest-neighbor ground truth for the vectors in `query_file`
/// against the vectors in `base_file` (optionally followed by those in `insert_file`) and
/// writes the result to `ground_truth_file`.
///
/// Filtering may be requested in one of two mutually exclusive ways:
/// * `base_file_labels` + `query_file_labels` (both or neither): label expressions are evaluated
///   into a per-query bitmap of admissible base ids (see [`read_labels_and_compute_bitmap`]);
/// * `vector_filters_file`: an explicit per-query list of admissible base ids.
///
/// When any filter is active, the neighbors kept in each query's queue (capacity `recall_at`)
/// are written in the variable-length range-search layout. Otherwise the fixed-size
/// `recall_at` layout is written, with associated data in place of ids when
/// `associated_data_file` is given.
///
/// # Errors
/// Returns an error when only one of the two label files is given, when labels and vector
/// filters are both given, when the number of vector filters differs from the number of
/// queries, when no ground truth was produced, or when any read/compute/write step fails.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::panic)]
pub fn compute_ground_truth_from_datafiles<
    Data: GraphDataType,
    StorageProvider: StorageReadProvider + StorageWriteProvider,
>(
    storage_provider: &StorageProvider,
    distance_function: Metric,
    base_file: &str,
    query_file: &str,
    ground_truth_file: &str,
    vector_filters_file: Option<&str>,
    recall_at: u32,
    insert_file: Option<&str>,
    skip_base: Option<usize>,
    associated_data_file: Option<String>,
    base_file_labels: Option<&str>,
    query_file_labels: Option<&str>,
) -> CMDResult<()> {
    let dataset_iterator = VectorDataIterator::<
        StorageProvider,
        Data::VectorDataType,
        Data::AssociatedDataType,
    >::new(base_file, associated_data_file.clone(), storage_provider)?;

    // Label files only make sense as a pair.
    if base_file_labels.is_some() != query_file_labels.is_some() {
        return Err(CMDToolError {
            details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
        });
    }
    if base_file_labels.is_some() && vector_filters_file.is_some() {
        return Err(CMDToolError {
            details: "Both base_file_labels and vector_filters_file cannot be provided."
                .to_string(),
        });
    }

    let insert_iterator = match insert_file {
        Some(insert_file) => Some(VectorDataIterator::<
            StorageProvider,
            Data::VectorDataType,
            Data::AssociatedDataType,
        >::new(insert_file, None, storage_provider)?),
        None => None,
    };

    let query_data =
        read_bin::<Data::VectorDataType>(&mut storage_provider.open_reader(query_file)?)?;
    let query_num = query_data.nrows();

    let mut query_bitmaps: Option<Vec<BitSet>> = match (base_file_labels, query_file_labels) {
        (Some(base_labels), Some(query_labels)) => {
            Some(read_labels_and_compute_bitmap(base_labels, query_labels)?)
        }
        _ => None,
    };

    if let Some(vector_filters_file) = vector_filters_file {
        let filters =
            search_index_utils::load_vector_filters(storage_provider, vector_filters_file)?;
        // Bad input is reported as an error rather than a panic.
        if filters.len() != query_num {
            return Err(CMDToolError {
                details: "Mismatch in query and vector filter sizes".to_string(),
            });
        }
        // Labels and vector filters are mutually exclusive (checked above), so this never
        // overwrites label-derived bitmaps.
        let mut bitmaps = vec![BitSet::new(); query_num];
        for (bitmap, filter) in bitmaps.iter_mut().zip(filters.iter()) {
            for item in filter.iter() {
                // Ids that cannot be represented as a bitmap index are skipped.
                if let Ok(idx) = (*item).try_into() {
                    bitmap.insert(idx);
                }
            }
        }
        query_bitmaps = Some(bitmaps);
    }

    let is_filtered = query_bitmaps.is_some();

    // Propagate the underlying error instead of panicking, so callers can see why it failed.
    let (ground_truth, id_to_associated_data) =
        compute_ground_truth_from_data::<Data, StorageProvider>(
            distance_function,
            dataset_iterator,
            &query_data,
            recall_at,
            insert_iterator,
            skip_base,
            query_bitmaps,
        )?;
    if ground_truth.is_empty() {
        return Err(CMDToolError {
            details: "No ground-truth results computed".to_string(),
        });
    }

    if is_filtered {
        let ground_truth_collection = ground_truth
            .into_iter()
            .map(|npq| npq.into_iter().collect())
            .collect();
        write_range_search_ground_truth(
            storage_provider,
            ground_truth_file,
            query_num,
            ground_truth_collection,
        )
    } else {
        // Associated data is only written when the caller asked for it.
        let id_to_associated_data = associated_data_file.map(|_| id_to_associated_data);
        write_ground_truth::<Data>(
            storage_provider,
            ground_truth_file,
            query_num,
            recall_at as usize,
            ground_truth,
            id_to_associated_data,
        )
    }
}
/// Strategy for collapsing the row-to-row distances between a query multivector and a base
/// multivector into a single score.
#[derive(Debug, Clone)]
pub enum MultivecAggregationMethod {
    /// Mean distance over every (query row, base row) pair.
    AveragePairwise,
    /// Smallest distance over every (query row, base row) pair.
    MinPairwise,
    /// For each query row, the smallest distance to any base row, averaged over query rows.
    AvgofMins,
}

/// Error returned when a string does not name a [`MultivecAggregationMethod`].
#[derive(Debug)]
pub enum ParseAggrError {
    /// Holds the rejected input exactly as it was given.
    InvalidFormat(String),
}

impl std::fmt::Display for ParseAggrError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self::InvalidFormat(input) = self;
        write!(f, "Invalid format for Aggregation Method: {}", input)
    }
}

impl std::error::Error for ParseAggrError {}

impl FromStr for MultivecAggregationMethod {
    type Err = ParseAggrError;

    /// Parses a method name case-insensitively. Accepted names are `average_pairwise`,
    /// `min_pairwise` and `avg_of_mins`. Any other input is returned inside the error unchanged.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let method = match normalized.as_str() {
            "average_pairwise" => Self::AveragePairwise,
            "min_pairwise" => Self::MinPairwise,
            "avg_of_mins" => Self::AvgofMins,
            _ => return Err(ParseAggrError::InvalidFormat(s.to_owned())),
        };
        Ok(method)
    }
}
#[allow(clippy::too_many_arguments)]
#[allow(clippy::panic)]
pub fn compute_multivec_ground_truth_from_datafiles<
Data: GraphDataType,
StorageProvider: StorageReadProvider + StorageWriteProvider,
>(
storage_provider: &StorageProvider,
distance_function: Metric,
aggregation_method: MultivecAggregationMethod,
base_file: &str,
query_file: &str,
ground_truth_file: &str,
recall_at: u32,
base_file_labels: Option<&str>,
query_file_labels: Option<&str>,
) -> CMDResult<()> {
let (base_vectors, _, _, _) = file_util::load_multivec_bin::<
Data::VectorDataType,
StorageProvider,
>(storage_provider, base_file)?;
let (query_vectors, query_num, query_dim, _) = file_util::load_multivec_bin::<
Data::VectorDataType,
StorageProvider,
>(storage_provider, query_file)?;
if !((base_file_labels.is_some() && query_file_labels.is_some())
|| (base_file_labels.is_none() && query_file_labels.is_none()))
{
return Err(CMDToolError {
details: "Both base_file_labels and query_file_labels must be provided or both must be not provided.".to_string(),
});
}
let mut query_bitmaps: Option<Vec<BitSet>> = None;
if let (Some(base_file_labels), Some(query_file_labels)) = (base_file_labels, query_file_labels)
{
query_bitmaps = Some(read_labels_and_compute_bitmap(
base_file_labels,
query_file_labels,
)?);
}
let has_query_bitmaps = query_bitmaps.is_some();
let ground_truth = compute_multivec_ground_truth_from_data::<Data::VectorDataType>(
distance_function,
aggregation_method,
base_vectors,
query_vectors,
query_dim,
recall_at,
query_bitmaps,
)?;
if has_query_bitmaps {
let ground_truth_collection = ground_truth
.into_iter()
.map(|npq| npq.into_iter().collect())
.collect();
write_range_search_ground_truth(
storage_provider,
ground_truth_file,
query_num,
ground_truth_collection,
)
} else {
write_ground_truth::<Data>(
storage_provider,
ground_truth_file,
query_num,
recall_at as usize,
ground_truth,
Option::None,
)
}
}
fn write_range_search_ground_truth<StorageProvider: StorageReadProvider + StorageWriteProvider>(
storage_provider: &StorageProvider,
ground_truth_file: &str,
number_of_queries: usize,
ground_truth: Vec<Vec<Neighbor<u32>>>,
) -> CMDResult<()> {
let mut file = storage_provider.create_for_write(ground_truth_file)?;
let queue_sizes: Vec<u32> = ground_truth
.iter()
.map(|queue| queue.len() as u32)
.collect();
let total_number_of_neighbors: usize = queue_sizes.iter().sum::<u32>() as usize;
Metadata::new(number_of_queries, total_number_of_neighbors)?.write(&mut file)?;
let mut queue_sizes_buffer = vec![0; queue_sizes.len() * size_of::<u32>()];
queue_sizes_buffer.clone_from_slice(cast_slice::<u32, u8>(&queue_sizes));
file.write_all(&queue_sizes_buffer)?;
let mut neighbor_ids: Vec<u32> = Vec::with_capacity(total_number_of_neighbors);
for query_neighbors in ground_truth {
for neighbor in query_neighbors.iter() {
neighbor_ids.push(neighbor.id);
}
}
let mut id_buffer = vec![0; total_number_of_neighbors * size_of::<u32>()];
id_buffer.clone_from_slice(cast_slice::<u32, u8>(&neighbor_ids));
file.write_all(&id_buffer)?;
file.flush()?;
Ok(())
}
fn write_ground_truth<Data: GraphDataType>(
storage_provider: &impl StorageWriteProvider,
ground_truth_file: &str,
number_of_queries: usize,
number_of_neighbors: usize,
ground_truth: Vec<NeighborPriorityQueue<u32>>,
id_to_associated_data: Option<Vec<Data::AssociatedDataType>>,
) -> CMDResult<()> {
let mut file = storage_provider.create_for_write(ground_truth_file)?;
Metadata::new(number_of_queries, number_of_neighbors)?.write(&mut file)?;
let mut gt_ids: Vec<u32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
let mut gt_distances: Vec<f32> = Vec::with_capacity(number_of_neighbors * number_of_queries);
for mut query_neighbors in ground_truth {
while let Some(closest_node) = query_neighbors.closest_notvisited() {
gt_ids.push(closest_node.id);
gt_distances.push(closest_node.distance);
}
}
if let Some(id_to_associated_data) = id_to_associated_data {
let mut associated_data_buffer = Vec::<u8>::new();
for id in gt_ids {
let associated_data = id_to_associated_data[id as usize];
let serialized_associated_data =
bincode::serialize(&associated_data).map_err(|e| CMDToolError {
details: format!("Failed to serialize associated data: {}", e),
})?;
associated_data_buffer.extend_from_slice(serialized_associated_data.as_slice());
}
file.write_all(&associated_data_buffer)?;
} else {
let mut id_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<u32>()];
id_buffer.clone_from_slice(cast_slice::<u32, u8>(>_ids));
file.write_all(&id_buffer)?;
}
let mut distance_buffer = vec![0; number_of_queries * number_of_neighbors * size_of::<f32>()];
distance_buffer.clone_from_slice(cast_slice::<f32, u8>(>_distances));
file.write_all(&distance_buffer)?;
file.flush()?;
Ok(())
}
type Npq = Vec<NeighborPriorityQueue<u32>>;
#[allow(clippy::too_many_arguments)]
pub fn compute_ground_truth_from_data<Data, VectorReader>(
distance_function: Metric,
dataset_iter: VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
queries: &Matrix<Data::VectorDataType>,
recall_at: u32,
insert_iter: Option<
VectorDataIterator<VectorReader, Data::VectorDataType, Data::AssociatedDataType>,
>,
skip_base: Option<usize>,
query_bitmaps: Option<Vec<BitSet>>,
) -> CMDResult<(Npq, Vec<Data::AssociatedDataType>)>
where
Data: GraphDataType,
VectorReader: StorageReadProvider,
{
let query_num = queries.nrows();
let query_dim = queries.ncols();
let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = (0..query_num)
.map(|_| NeighborPriorityQueue::new(recall_at as usize))
.collect();
let mut queries_and_neighbor_queue: Vec<_> =
queries.row_iter().zip(neighbor_queues.iter_mut()).collect();
let distance_comparer = Data::VectorDataType::distance(distance_function, Some(query_dim));
let batch_size = 10_000;
let mut data_batch: Vec<Box<[Data::VectorDataType]>> = Vec::with_capacity(batch_size);
let pool = create_thread_pool(0)?;
let mut num_base_points: usize = 0;
let mut id_to_associated_data = Vec::<Data::AssociatedDataType>::new();
let skip_base = skip_base.unwrap_or(0);
for chunk in dataset_iter.skip(skip_base).chunks(batch_size).into_iter() {
data_batch.clear();
for (data_vector, associated_data) in chunk {
data_batch.push(data_vector);
id_to_associated_data.push(associated_data);
}
let points = data_batch.len();
if points == 0 {
continue;
}
queries_and_neighbor_queue
.par_iter_mut()
.enumerate()
.for_each_in_pool(&pool, |(idx_query, (query, ref mut neighbor_queue))| {
for (idx_in_batch, data) in data_batch.iter().enumerate() {
let idx = (num_base_points + idx_in_batch) as u32;
let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
if let Ok(idx_usize) = idx.try_into() {
bitmaps[idx_query].contains(idx_usize)
} else {
false
}
} else {
true
};
if allowed_by_bitmap {
let distance = distance_comparer.evaluate_similarity(data, query);
neighbor_queue.insert(Neighbor { id: idx, distance });
}
}
});
num_base_points += points;
}
if let Some(insert_iter) = insert_iter {
for (insert_idx, (data_vector, _associated_data)) in insert_iter.enumerate() {
for (idx_query, (query, ref mut neighbor_queue)) in
queries_and_neighbor_queue.iter_mut().enumerate()
{
let idx = (num_base_points + insert_idx) as u32;
let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
if let Ok(idx_usize) = idx.try_into() {
bitmaps[idx_query].contains(idx_usize)
} else {
false
}
} else {
true
};
if allowed_by_bitmap {
let distance = distance_comparer.evaluate_similarity(&data_vector, query);
neighbor_queue.insert(Neighbor { id: idx, distance })
}
}
}
}
Ok((neighbor_queues, id_to_associated_data))
}
#[allow(clippy::too_many_arguments)]
pub fn compute_multivec_ground_truth_from_data<T>(
distance_function: Metric,
aggregation_method: MultivecAggregationMethod,
base_vectors: Vec<Matrix<T>>,
queries: Vec<Matrix<T>>,
query_dim: usize,
recall_at: u32,
query_bitmaps: Option<Vec<BitSet>>,
) -> CMDResult<Vec<NeighborPriorityQueue<u32>>>
where
T: VectorRepr,
{
let query_num = queries.len();
let mut neighbor_queues: Vec<NeighborPriorityQueue<u32>> = Vec::with_capacity(query_num);
for _ in 0..query_num {
neighbor_queues.push(NeighborPriorityQueue::new(recall_at as usize));
}
let mut query_multivecs_and_neighbor_queue: Vec<_> =
queries.iter().zip(neighbor_queues.iter_mut()).collect();
let distance_comparer = T::distance(distance_function, Some(query_dim));
let pool = create_thread_pool(0)?;
query_multivecs_and_neighbor_queue
.par_iter_mut()
.enumerate()
.for_each_in_pool(&pool, |(query_idx, (query_multivec, neighbor_queue))| {
for (idx_base, base_multivec) in base_vectors.iter().enumerate() {
let allowed_by_bitmap = if let Some(ref bitmaps) = query_bitmaps {
bitmaps[query_idx].contains(idx_base)
} else {
true
};
if allowed_by_bitmap {
let distance = match aggregation_method {
MultivecAggregationMethod::AveragePairwise => {
let mut total_distance = 0.0;
for query_vec in query_multivec.row_iter() {
for base_vec in base_multivec.row_iter() {
let dist =
distance_comparer.evaluate_similarity(query_vec, base_vec);
total_distance += dist;
}
}
total_distance / (query_multivec.nrows() * base_multivec.nrows()) as f32
}
MultivecAggregationMethod::MinPairwise => {
let mut min_distance = f32::MAX;
for query_vec in query_multivec.row_iter() {
for base_vec in base_multivec.row_iter() {
let dist =
distance_comparer.evaluate_similarity(query_vec, base_vec);
min_distance = min_distance.min(dist);
}
}
min_distance
}
MultivecAggregationMethod::AvgofMins => {
let mut distance = 0_f32;
for query_vec in query_multivec.row_iter() {
let mut local_min = f32::MAX;
for base_vec in base_multivec.row_iter() {
let dist =
distance_comparer.evaluate_similarity(query_vec, base_vec);
local_min = local_min.min(dist);
}
distance += local_min;
}
distance / query_multivec.nrows() as f32
}
};
let idx = idx_base as u32;
neighbor_queue.insert(Neighbor { id: idx, distance });
}
}
});
Ok(neighbor_queues)
}