use std::{
fs::File,
io::{BufWriter, Write},
num::NonZero,
};
use anyhow::{Context, Result};
use clap::Parser;
use diskann_providers::storage::StorageReadProvider;
use diskann_quantization::{
algorithms::transforms::{DoubleHadamard, TargetDim},
alloc::GlobalAllocator,
minmax::{DataMutRef, MinMaxQuantizer},
num::Positive,
CompressInto,
};
use diskann_utils::io::Metadata;
use half::f16;
use rand::{rngs::StdRng, SeedableRng};
// Command-line options of the quantization tool.
//
// This is deliberately a plain comment: a `///` doc comment on the struct could be picked
// up by clap as the command description and interfere with the bare `about` attribute
// below. The `///` comments on the fields, in contrast, are intended to become the
// per-option help text shown by `--help`.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path of the input vector file to quantize
    #[arg(short, long)]
    input: String,

    /// Path of the quantized output file to create
    #[arg(short, long)]
    output: String,

    /// Quantization bit width; supported values are 1, 2, 4 and 8
    #[arg(short, long, default_value = "4")]
    bits: u8,

    /// Element type of the input vectors: f32, float32, fp16 or f16
    #[arg(short, long, default_value = "f32")]
    precision: String,

    /// Seed of the random number generator used to build the transform
    #[arg(short, long, default_value = "2282129662191")]
    seed: u64,

    /// Grid scale handed to the MinMax quantizer (must be positive)
    #[arg(short, long, default_value = "1.0")]
    grid_scale: f32,
}
fn dispatch_process_file<T: Copy + Into<f32> + bytemuck::Pod>(
bits: u8,
input: &str,
output: &str,
seed: u64,
scale: f32,
) -> Result<()> {
match bits {
1 => process_file::<1, T>(input, output, seed, scale),
2 => process_file::<2, T>(input, output, seed, scale),
4 => process_file::<4, T>(input, output, seed, scale),
8 => process_file::<8, T>(input, output, seed, scale),
_ => anyhow::bail!(
"Unsupported bit width: {}. Supported values are 1, 2, 4, 8",
bits
),
}
}
/// Program entry point: parses the command line and runs the quantization for the
/// requested input precision.
///
/// # Errors
/// Fails when `--precision` is not one of the accepted spellings, or when
/// `dispatch_process_file` reports an error.
fn main() -> Result<()> {
    let args = Args::parse();

    match args.precision.as_str() {
        // `float` is accepted as an alias because earlier versions of the error message
        // advertised it as a supported value.
        "f32" | "float32" | "float" => dispatch_process_file::<f32>(
            args.bits,
            &args.input,
            &args.output,
            args.seed,
            args.grid_scale,
        ),
        "fp16" | "f16" => dispatch_process_file::<f16>(
            args.bits,
            &args.input,
            &args.output,
            args.seed,
            args.grid_scale,
        ),
        // Keep this list in sync with the match arms above.
        _ => anyhow::bail!(
            "Unsupported precision: {}. Supported values are f32, float32, float, fp16, f16",
            args.precision
        ),
    }
}
fn process_file<const NBITS: usize, T: Copy + Into<f32> + bytemuck::Pod>(
input_path: &str,
output_path: &str,
seed: u64,
scale: f32,
) -> Result<()>
where
diskann_quantization::bits::Unsigned: diskann_quantization::bits::Representation<NBITS>,
{
let input_data = diskann_utils::io::read_bin::<T>(
&mut diskann_providers::storage::FileStorageProvider
.open_reader(input_path)
.with_context(|| format!("Failed to open {}", input_path))?,
)
.with_context(|| format!("Failed to load data from {}", input_path))?;
let num_points = input_data.nrows();
let dim = input_data.ncols();
println!("Input file: {} points, {} dimensions", num_points, dim);
let mut rng = StdRng::seed_from_u64(seed);
let double_hadamard = DoubleHadamard::new(
NonZero::new(dim).unwrap(),
TargetDim::Same,
&mut rng,
GlobalAllocator,
)
.unwrap();
let transform = diskann_quantization::algorithms::Transform::DoubleHadamard(double_hadamard);
let quantizer = MinMaxQuantizer::new(transform, Positive::new(scale)?);
let output_dim = quantizer.output_dim();
let bytes_per_vector = diskann_quantization::minmax::Data::<NBITS>::canonical_bytes(output_dim);
println!("Bytes per quantized vector: {}", bytes_per_vector);
let output_file = File::create(output_path)
.with_context(|| format!("Failed to create output file {}", output_path))?;
let mut writer = BufWriter::new(output_file);
Metadata::new(num_points, bytes_per_vector)?
.write(&mut writer)
.context("Failed to write metadata header")?;
println!("Processing {} vectors...", num_points);
let mut loss = 0.0;
for i in 0..num_points {
let input_vector = input_data.row(i);
let mut quantized_buffer = vec![0u8; bytes_per_vector];
let quantized_data =
DataMutRef::<NBITS>::from_canonical_front_mut(&mut quantized_buffer, output_dim)
.with_context(|| format!("Failed to create quantized data ref for vector {}", i))?;
let loss_x = quantizer
.compress_into(input_vector, quantized_data)
.with_context(|| format!("Failed to compress vector {}", i))?;
loss += loss_x.as_f32();
writer
.write_all(&quantized_buffer)
.with_context(|| format!("Failed to write quantized vector {}", i))?;
}
writer.flush().context("Failed to flush output file")?;
println!(
"Successfully quantized {} vectors to {}",
num_points, output_path
);
println!("Average l2 loss : {}", loss / (num_points as f32));
println!("Output file format:");
println!(
" Header: {} bytes (num_points: u32, bytes_per_vector: u32)",
8
);
println!(
" Data: {} bytes ({} vectors × {} bytes each)",
num_points * bytes_per_vector,
num_points,
bytes_per_vector
);
Ok(())
}