use std::io::{Seek, SeekFrom, Write};
use std::path::PathBuf;
use anyhow::{ensure, Result};
use clap::Parser;
use half::f16;
use rand::rngs::StdRng;
use rand_distr::{Distribution, StandardUniform};
use diskann::utils::VectorRepr;
use diskann_providers::storage::FileStorageProvider;
use diskann_providers::storage::StorageWriteProvider;
use diskann_providers::utils::{random, SampleVectorReader, SamplingDensity};
use diskann_tools::utils::DataType;
use diskann_utils::io::Metadata;
// Command-line arguments for `subsample_bin`; all arguments are positional.
// Field doc comments double as the clap `--help` text for each argument.
#[derive(Parser, Debug)]
#[command(name = "subsample_bin", about = "Subsample vectors from a binary file")]
struct Args {
    /// Element type of the vectors stored in the input file
    #[arg(value_enum)]
    data_type: DataType,
    /// Input binary vector file to sample from
    base_bin_file: PathBuf,
    /// Path of the sampled output file to create
    sampled_output_file: PathBuf,
    /// Probability in [0, 1] that each vector is kept
    sampling_probability: f64,
    /// Optional seed for the random number generator
    random_seed: Option<u64>,
}
/// Builds the RNG that drives sampling.
///
/// When a seed is supplied it is handed to `random::create_rnd_from_seed`;
/// otherwise the generator comes from `random::create_rnd()`.
fn create_rng(seed: Option<u64>) -> StdRng {
    if let Some(explicit_seed) = seed {
        random::create_rnd_from_seed(explicit_seed)
    } else {
        random::create_rnd()
    }
}
fn run_for_type<T>(args: &Args) -> Result<()>
where
T: VectorRepr,
{
ensure!(
(0.0..=1.0).contains(&args.sampling_probability),
"sampling_probability must be in the range 0 to 1"
);
let mut rng = create_rng(args.random_seed);
let storage_provider = FileStorageProvider;
let data_file = args.base_bin_file.to_string_lossy().to_string();
let mut reader: SampleVectorReader<T, _> = SampleVectorReader::new(
&data_file,
SamplingDensity::from_sample_rate(args.sampling_probability),
&storage_provider,
)?;
let (npts, dims) = reader.get_dataset_headers();
println!(
"Found base file {} with {} points of dimension {}",
data_file, npts, dims
);
let distribution = StandardUniform;
let sampled_indices = (0..npts).filter(|_| {
let p: f64 = distribution.sample(&mut rng);
p < args.sampling_probability
});
let output_file = args.sampled_output_file.to_string_lossy().to_string();
let mut writer = storage_provider.create_for_write(&output_file)?;
Metadata::new(npts, dims)?.write(&mut writer)?;
let mut sampled_count: u32 = 0;
reader.read_vectors(sampled_indices, |vec_t| {
sampled_count += 1;
writer.write_all(bytemuck::cast_slice(vec_t))?;
Ok(())
})?;
writer.seek(SeekFrom::Start(0))?;
Metadata::new(sampled_count, dims)?.write(&mut writer)?;
println!(
"Wrote {} points to sample file {}",
sampled_count, output_file
);
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
match args.data_type {
DataType::Float => run_for_type::<f32>(&args),
DataType::Int8 => run_for_type::<i8>(&args),
DataType::Uint8 => run_for_type::<u8>(&args),
DataType::Fp16 => run_for_type::<f16>(&args),
}
}