use std::io::Write;
use std::{fmt::Debug, time::Instant};
use clap::Parser;
use kannolo::pq::ProductQuantizer;
use std::fs::File;
use kannolo::{
hnsw::graph_index::GraphIndex, hnsw_utils::config_hnsw::ConfigHnsw,
plain_quantizer::PlainQuantizer, read_numpy_f32_flatten_2d, Dataset, DenseDataset,
DistanceType, IndexSerializer,
};
#[derive(Parser, Debug)]
#[clap(author, version, about, long_about = None)]
struct Args {
#[clap(short, long, value_parser)]
index_file: String,
#[clap(short, long, value_parser)]
query_file: String,
#[clap(short, long, value_parser)]
output_path: Option<String>,
#[clap(short, long, value_parser)]
#[arg(default_value_t = 64)]
m_pq: usize,
#[clap(short, long, value_parser)]
#[arg(default_value_t = 10)]
k: usize,
#[clap(long, value_parser)]
#[arg(default_value_t = 40)]
ef_search: usize,
#[clap(long, value_parser)]
#[arg(default_value_t = 1)]
n_run: usize,
#[clap(long, value_parser)]
#[arg(default_value_t = false)]
warmup: bool,
}
fn main() {
let args: Args = Args::parse();
let query_path = args.query_file;
let index_path = args.index_file;
let m_pq = args.m_pq;
let k = args.k;
let ef_search = args.ef_search;
println!("Reading Queries");
let (queries_vec, d) = read_numpy_f32_flatten_2d(query_path);
let queries = DenseDataset::from_vec(
queries_vec,
d,
PlainQuantizer::<f32>::new(d, DistanceType::Euclidean),
);
println!("Starting search");
let num_queries = queries.len();
let mut config = ConfigHnsw::new().build();
config.set_ef_search(ef_search);
println!("N queries {num_queries}");
if args.warmup {
}
let mut total_time_search = 0;
let mut results = Vec::<(f32, usize)>::with_capacity(num_queries);
match m_pq {
8 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<8>>, ProductQuantizer<8>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
16 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<16>>, ProductQuantizer<16>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
32 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<32>>, ProductQuantizer<32>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
48 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<48>>, ProductQuantizer<48>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
64 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<64>>, ProductQuantizer<64>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
96 => {
println!("Sono entrato pd");
let index: GraphIndex<DenseDataset<ProductQuantizer<96>>, ProductQuantizer<96>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
128 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<128>>, ProductQuantizer<128>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
192 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<192>>, ProductQuantizer<192>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
256 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<256>>, ProductQuantizer<256>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
384 => {
let index: GraphIndex<DenseDataset<ProductQuantizer<384>>, ProductQuantizer<384>> =
IndexSerializer::load_index(&index_path);
for query in queries.iter() {
let start_time = Instant::now();
results.extend(
index.search::<DenseDataset<PlainQuantizer<f32>>, PlainQuantizer<f32>>(
query, k, &config,
),
);
let duration_search = start_time.elapsed();
total_time_search += duration_search.as_micros();
}
index.print_space_usage_byte();
}
_ => {
panic!("Unsupported M_PQ value");
}
}
let avg_time_search_per_query = total_time_search / (num_queries * args.n_run) as u128;
println!("[######] Average Query Time: {avg_time_search_per_query}");
let output_path = args.output_path.unwrap();
let mut output_file = File::create(output_path).unwrap();
for (query_id, result) in results.chunks_exact(k).enumerate() {
for (idx, (score, doc_id)) in result.iter().enumerate() {
writeln!(
&mut output_file,
"{query_id}\t{doc_id}\t{}\t{score}",
idx + 1,
)
.unwrap();
}
}
}