/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/
use clap::Parser;
use diskann_tools::utils::{
get_num_threads, init_subscriber, range_search_disk_index, CMDToolError, DataType,
GraphDataF32Vector, GraphDataHalfVector, GraphDataInt8Vector, GraphDataU8Vector,
};
use diskann_vector::distance::Metric;
fn main() -> Result<(), CMDToolError> {
init_subscriber();
let args: SearchDiskIndexArgs = SearchDiskIndexArgs::parse();
let threads = get_num_threads(args.num_threads);
let result = match args.data_type {
DataType::Float => range_search_disk_index::<GraphDataF32Vector>(
args.dist_fn,
&args.index_path_prefix,
&args.query_file,
&args.gt_file,
threads,
args.range_threshold,
args.beam_width,
args.search_io_limit,
&args.search_list,
args.num_nodes_to_cache,
),
DataType::Int8 => range_search_disk_index::<GraphDataInt8Vector>(
args.dist_fn,
&args.index_path_prefix,
&args.query_file,
&args.gt_file,
threads,
args.range_threshold,
args.beam_width,
args.search_io_limit,
&args.search_list,
args.num_nodes_to_cache,
),
DataType::Uint8 => range_search_disk_index::<GraphDataU8Vector>(
args.dist_fn,
&args.index_path_prefix,
&args.query_file,
&args.gt_file,
threads,
args.range_threshold,
args.beam_width,
args.search_io_limit,
&args.search_list,
args.num_nodes_to_cache,
),
DataType::Fp16 => range_search_disk_index::<GraphDataHalfVector>(
args.dist_fn,
&args.index_path_prefix,
&args.query_file,
&args.gt_file,
threads,
args.range_threshold,
args.beam_width,
args.search_io_limit,
&args.search_list,
args.num_nodes_to_cache,
),
};
match result {
Ok(_) => {
println!("Index search completed successfully");
Ok(())
}
Err(err) => {
tracing::error!("Index search failed - see diagnostic");
Err(err.into())
}
}
}
// Command-line arguments of the range-search tool.
//
// The `///` comments on the fields below are picked up by the clap derive
// macro and shown as the per-option help text, so they are written for the
// end user of the binary. Flag names, short flags and defaults are part of the
// tool's command-line interface and must stay stable.
#[derive(Debug, Parser)]
struct SearchDiskIndexArgs {
    /// data type <int8/uint8/float/fp16> (required)
    #[arg(long = "data_type", required = true)]
    pub data_type: DataType,

    /// Distance function to use (l2, cosine)
    #[arg(long = "dist_fn", required = true)]
    pub dist_fn: Metric,

    /// Path prefix of the index files
    #[arg(long = "index_path_prefix", required = true)]
    pub index_path_prefix: String,

    /// Query file in binary format
    // Bare `short` lets clap derive the short flag from the field name.
    #[arg(long = "query_file", short, required = true)]
    pub query_file: String,

    /// Ground truth file for the queryset
    // Defaults to the empty string when the flag is not given.
    #[arg(long = "gt_file", default_value = "")]
    pub gt_file: String,

    /// Range threshold for the range search
    // This is a floating-point threshold forwarded to `range_search_disk_index`,
    // not a neighbor count; the `-K` short flag is kept for CLI compatibility.
    #[arg(long = "range_threshold", short = 'K', default_value = "10")]
    pub range_threshold: f32,

    /// List of L values of search
    // `num_args = 1..` accepts one or more values after the flag.
    #[arg(long = "search_list", short = 'L', required = true, num_args=1..)]
    pub search_list: Vec<u32>,

    /// Beam width for beam search
    #[arg(long = "beam_width", default_value = "2")]
    pub beam_width: u32,

    /// IO limit for each beam search, the default value is u32::MAX
    // 4294967295 is u32::MAX written out, since the default is given as a string.
    #[arg(long = "search_io_limit", default_value = "4294967295")]
    pub search_io_limit: u32,

    /// Number of threads used for querying the index
    // Optional; `main` passes it to `get_num_threads` to resolve the final count.
    #[arg(long = "num_threads", short = 'T')]
    pub num_threads: Option<usize>,

    /// Number of BFS nodes around medoid(s) to cache during query warm up
    #[arg(long = "num_nodes_to_cache", default_value = "0")]
    pub num_nodes_to_cache: usize,
}