vq 0.2.0

A vector quantization library for Rust
Documentation
#![allow(dead_code)]
#[path = "common.rs"]
mod common;

use anyhow::Result;
use clap::Parser;
use half::f16;
use std::time::Instant;
use vq::{Distance, ProductQuantizer, Quantizer, Vector};

#[derive(Parser)]
#[command(name = "eval_pq")]
#[command(about = "Evaluate Product Quantizer performance")]
struct Args {
    #[arg(long, default_value_t = common::SEED)]
    seed: u64,

    #[arg(long, default_value_t = common::DIM)]
    dim: usize,

    #[arg(long, default_value_t = common::M)]
    m: usize,

    #[arg(long, default_value_t = common::K)]
    k: usize,

    #[arg(long, default_value_t = common::MAX_ITERS)]
    max_iters: usize,
}

fn main() -> Result<()> {
    let args = Args::parse();
    println!("Product Quantizer Evaluation");
    println!("============================");

    for &n_samples in &common::NUM_SAMPLES {
        println!("\nSamples: {}", n_samples);

        let original_data = common::generate_synthetic_data(n_samples, args.dim, args.seed);
        let training_refs: Vec<&[f32]> = original_data.iter().map(|v| v.data.as_slice()).collect();

        let start = Instant::now();
        let pq = ProductQuantizer::new(
            &training_refs,
            args.m,
            args.k,
            args.max_iters,
            Distance::Euclidean,
            args.seed,
        )?;
        let training_time = start.elapsed().as_millis() as f64;

        let start = Instant::now();
        let quantized: Vec<Vec<f16>> = original_data
            .iter()
            .filter_map(|vec| pq.quantize(&vec.data).ok())
            .collect();
        let quantization_time = start.elapsed().as_millis() as f64;

        let reconstructed: Vec<Vector<f32>> = quantized
            .iter()
            .map(|q| Vector::new(q.iter().map(|&x| f16::to_f32(x)).collect()))
            .collect();

        let error = common::calculate_reconstruction_error(&original_data, &reconstructed);

        println!("  Training time: {} ms", training_time);
        println!("  Quantization time: {} ms", quantization_time);
        println!("  Reconstruction error: {:.6}", error);
    }

    Ok(())
}