use clap::Parser;
use diskann_providers::{storage::FileStorageProvider, utils::Timer};
use diskann_tools::utils::{
compute_ground_truth_from_datafiles, init_subscriber, CMDResult, DataType, GraphDataF32Vector,
GraphDataHalfVector, GraphDataInt8Vector, GraphDataU8Vector,
};
use diskann_vector::distance::Metric;
fn main() -> CMDResult<()> {
init_subscriber();
let timer = Timer::new();
let args = ComputeGroundTruthArgs::parse();
tracing::info!("Computing ground truth file");
let insert_file = None;
let skip_base = None;
let storage_provider = FileStorageProvider;
let err = match args.data_type {
DataType::Float => {
compute_ground_truth_from_datafiles::<GraphDataF32Vector, FileStorageProvider>(
&storage_provider,
args.distance_function,
&args.base_file,
&args.query_file,
&args.ground_truth_file,
args.vector_filters_file.as_deref(),
args.recall_at,
insert_file,
skip_base,
args.associated_data_file,
args.base_file_labels.as_deref(),
args.query_file_labels.as_deref(),
)
}
DataType::Fp16 => {
compute_ground_truth_from_datafiles::<GraphDataHalfVector, FileStorageProvider>(
&storage_provider,
args.distance_function,
&args.base_file,
&args.query_file,
&args.ground_truth_file,
args.vector_filters_file.as_deref(),
args.recall_at,
insert_file,
skip_base,
args.associated_data_file,
args.base_file_labels.as_deref(),
args.query_file_labels.as_deref(),
)
}
DataType::Uint8 => {
compute_ground_truth_from_datafiles::<GraphDataU8Vector, FileStorageProvider>(
&storage_provider,
args.distance_function,
&args.base_file,
&args.query_file,
&args.ground_truth_file,
args.vector_filters_file.as_deref(),
args.recall_at,
insert_file,
skip_base,
args.associated_data_file,
args.base_file_labels.as_deref(),
args.query_file_labels.as_deref(),
)
}
DataType::Int8 => {
compute_ground_truth_from_datafiles::<GraphDataInt8Vector, FileStorageProvider>(
&storage_provider,
args.distance_function,
&args.base_file,
&args.query_file,
&args.ground_truth_file,
args.vector_filters_file.as_deref(),
args.recall_at,
insert_file,
skip_base,
args.associated_data_file,
args.base_file_labels.as_deref(),
args.query_file_labels.as_deref(),
)
}
};
match err {
Ok(_) => {
tracing::info!(
"Compute ground-truth completed successfully in {:?}",
timer.elapsed()
);
Ok(())
}
Err(err) => {
tracing::error!("Error: {:?}", err);
Err(err)
}
}
}
// Command-line arguments of the ground-truth computation tool, parsed by
// clap's derive API (`ComputeGroundTruthArgs::parse()` in `main`).
//
// NOTE: plain `//` comments are used here on purpose. clap's derive turns
// `///` doc comments on the struct and its fields into about/help text, so
// converting these would change the `--help` output.
//
// Fields marked with a bare `short` get their short flag derived by clap from
// the field name (its first character).
#[derive(Debug, Parser)]
struct ComputeGroundTruthArgs {
// `--data_type`; defaults to "float". `main` matches on this to choose the
// vector type used for the computation.
#[arg(long = "data_type", default_value = "float")]
pub data_type: DataType,
// `--dist_fn`; defaults to "l2". Forwarded as the distance metric.
#[arg(long = "dist_fn", default_value = "l2")]
pub distance_function: Metric,
// `--base_file` / `-b`; required. Forwarded as the base data file path.
#[arg(long = "base_file", short, required = true)]
pub base_file: String,
// `--base_file_labels`; optional. Forwarded as `Option<&str>`.
// NOTE(review): presumably a label file matching `base_file` — confirm.
#[arg(long = "base_file_labels", default_value = None)]
pub base_file_labels: Option<String>,
// `--query_file` / `-q`; required. Forwarded as the query file path.
#[arg(long = "query_file", short, required = true)]
pub query_file: String,
// `--query_file_labels`; optional. Forwarded as `Option<&str>`.
// NOTE(review): presumably a label file matching `query_file` — confirm.
#[arg(long = "query_file_labels", default_value = None)]
pub query_file_labels: Option<String>,
// `--gt_file` / `-g`; required.
// NOTE(review): presumably the path the ground truth is written to — confirm.
#[arg(long = "gt_file", short, required = true)]
pub ground_truth_file: String,
// `--vector_filters_file` / `-v`; optional. Forwarded as `Option<&str>`.
#[arg(long = "vector_filters_file", short, default_value = None)]
pub vector_filters_file: Option<String>,
// `--recall_at` / `-K`; defaults to 10.
// NOTE(review): presumably the number of neighbours computed per query — confirm.
#[arg(long = "recall_at", short = 'K', default_value = "10")]
pub recall_at: u32,
// `--associated_data_file`; optional. Passed by value (ownership moves into
// the compute call).
#[arg(long = "associated_data_file", required = false, default_value = None)]
pub associated_data_file: Option<String>,
}