//! Integration tests for turboquant 0.1.1 — Google's TurboQuant algorithm
//! for vector quantization, exercised end-to-end via a tiny ONNX fixture.
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};

use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use tokenizers::models::wordlevel::WordLevel;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::Tokenizer;
use turboquant::real_model::{
    RealModelGenerationConfig, RealModelQuantizationConfig, RealModelRunner,
};
use turboquant::QuantStrategy;

#[test]
fn tiny_onnx_fixture_runs_exact_and_quantized_decode() {
    // End-to-end check against a tiny on-disk ONNX decoder bundle: both the
    // exact KV-cache decode path and the 4-bit quantized path must run, emit
    // the requested number of tokens, and the quantized cache must actually
    // be smaller than the exact one.
    let fixture_dir = temp_fixture_dir("real-model-fixture");
    write_fixture_bundle(&fixture_dir);

    let runner = RealModelRunner::load(&fixture_dir).expect("fixture bundle should load");
    let generation = RealModelGenerationConfig {
        max_new_tokens: 3,
        // Fixed-length decode so the token/logit counts below are deterministic.
        stop_on_eos: false,
    };
    let quantization = RealModelQuantizationConfig {
        key_bits: 4,
        value_bits: 4,
        key_strategy: QuantStrategy::Prod,
        seed: 7,
    };

    let exact = runner
        .generate_exact("hello world", &generation)
        .expect("exact decode should succeed");
    let quantized = runner
        .generate_quantized("hello world", &generation, &quantization)
        .expect("quantized decode should succeed");

    // Clean up BEFORE asserting: a failed assertion panics and would otherwise
    // skip the cleanup, leaking a fixture directory in the system temp dir on
    // every failing run.
    fs::remove_dir_all(&fixture_dir).expect("fixture cleanup should succeed");

    assert_eq!(runner.model_id(), "tiny-fixture-decoder");
    assert_eq!(exact.generated_tokens.len(), 3);
    assert_eq!(exact.step_logits.len(), 3);
    assert_eq!(quantized.generated_tokens.len(), 3);
    assert_eq!(quantized.step_logits.len(), 3);
    assert!(exact.kv_cache.exact_bytes > 0);
    assert!(quantized.kv_cache.quantized_bytes.is_some());
    assert!(
        quantized.kv_cache.quantized_bytes.unwrap() < quantized.kv_cache.exact_bytes,
        "quantized cache should use fewer bytes than exact cache"
    );
}

#[test]
#[ignore = "set TURBOQUANT_REAL_MODEL_DIR to an exported ONNX bundle to run this smoke test"]
fn manual_exported_real_model_smoke_test() {
    // Opt-in smoke test against a real exported bundle; the bundle location
    // comes from the environment so CI never runs this implicitly.
    let model_dir = std::env::var("TURBOQUANT_REAL_MODEL_DIR")
        .expect("TURBOQUANT_REAL_MODEL_DIR should point to an exported ONNX bundle");
    let runner = RealModelRunner::load(&model_dir).expect("real model bundle should load");

    let generation = RealModelGenerationConfig {
        max_new_tokens: 2,
        stop_on_eos: true,
    };
    let quantization = RealModelQuantizationConfig {
        key_bits: 4,
        value_bits: 4,
        key_strategy: QuantStrategy::Prod,
        seed: 42,
    };

    let prompt = "TurboQuant manual smoke test";
    let exact = runner
        .generate_exact(prompt, &generation)
        .expect("exact decode should succeed");
    let quantized = runner
        .generate_quantized(prompt, &generation, &quantization)
        .expect("quantized decode should succeed");

    // Liveness only: each decode path must produce at least one token.
    assert!(!exact.generated_tokens.is_empty());
    assert!(!quantized.generated_tokens.is_empty());
}

/// Materializes a minimal model bundle (ONNX graph, model config, vocab, and
/// a saved tokenizer) under `dir` for the fixture-based test above to load.
fn write_fixture_bundle(dir: &Path) {
    fs::create_dir_all(dir).expect("fixture directory should be created");

    // The tiny ONNX graph is checked in as base64 text; decode it back to bytes.
    let encoded_onnx = include_str!("fixtures/tiny_decoder_with_past.onnx.b64");
    let decoded_onnx = STANDARD
        .decode(encoded_onnx.trim())
        .expect("fixture ONNX should decode from base64");
    fs::write(dir.join("model.onnx"), decoded_onnx).expect("fixture ONNX should be written");

    // Minimal decoder config: one layer, two heads, hidden size 4.
    let config_json = r#"{
  "_name_or_path": "tiny-fixture-decoder",
  "num_hidden_layers": 1,
  "num_attention_heads": 2,
  "num_key_value_heads": 2,
  "hidden_size": 4,
  "bos_token_id": 1,
  "eos_token_id": 2
}
"#;
    fs::write(dir.join("config.json"), config_json).expect("fixture config should be written");

    // Tiny word-level vocabulary covering the prompts used by the tests.
    let vocab_json = r#"{
  "[UNK]": 0,
  "<bos>": 1,
  "<eos>": 2,
  "hello": 3,
  "world": 4,
  "cache": 5,
  "quant": 6,
  "test": 7
}
"#;
    fs::write(dir.join("vocab.json"), vocab_json).expect("tokenizer vocab should be written");

    // Build a whitespace-split word-level tokenizer over that vocab and save
    // it next to the model so the runner can load it.
    let word_level = WordLevel::builder()
        .files(dir.join("vocab.json").display().to_string())
        .unk_token("[UNK]".to_string())
        .build()
        .expect("word-level tokenizer model should build");
    let mut tokenizer = Tokenizer::new(word_level);
    tokenizer.with_pre_tokenizer(Some(Whitespace));
    tokenizer
        .save(dir.join("tokenizer.json"), false)
        .expect("tokenizer fixture should save");
}

/// Returns a unique, not-yet-created path under the system temp directory.
/// The name combines `prefix`, the process id, and a nanosecond timestamp so
/// concurrent test runs (and repeated runs) never collide.
fn temp_fixture_dir(prefix: &str) -> PathBuf {
    let pid = std::process::id();
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock should be after unix epoch")
        .as_nanos();
    std::env::temp_dir().join(format!("{prefix}-{pid}-{nanos}"))
}