use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use base64::engine::general_purpose::STANDARD;
use base64::Engine;
use tokenizers::models::wordlevel::WordLevel;
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::Tokenizer;
use turboquant::real_model::{
RealModelGenerationConfig, RealModelQuantizationConfig, RealModelRunner,
};
use turboquant::QuantStrategy;
#[test]
fn tiny_onnx_fixture_runs_exact_and_quantized_decode() {
    // RAII guard: removes the fixture directory even if an `expect` or
    // `assert!` below panics. Without it a failing run leaks a temp dir per
    // invocation. Drop must not panic during unwind, so errors are ignored;
    // the happy path still asserts cleanup success explicitly at the end.
    struct FixtureGuard(PathBuf);
    impl Drop for FixtureGuard {
        fn drop(&mut self) {
            // Best-effort: the dir may already be gone after the explicit
            // cleanup below, in which case this Err is expected and ignored.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    let fixture_dir = temp_fixture_dir("real-model-fixture");
    let _guard = FixtureGuard(fixture_dir.clone());
    write_fixture_bundle(&fixture_dir);
    let runner = RealModelRunner::load(&fixture_dir).expect("fixture bundle should load");
    let generation = RealModelGenerationConfig {
        max_new_tokens: 3,
        stop_on_eos: false,
    };
    let quantization = RealModelQuantizationConfig {
        key_bits: 4,
        value_bits: 4,
        key_strategy: QuantStrategy::Prod,
        seed: 7,
    };
    let exact = runner
        .generate_exact("hello world", &generation)
        .expect("exact decode should succeed");
    let quantized = runner
        .generate_quantized("hello world", &generation, &quantization)
        .expect("quantized decode should succeed");
    assert_eq!(runner.model_id(), "tiny-fixture-decoder");
    // Both decode paths must honor max_new_tokens exactly (stop_on_eos is off).
    assert_eq!(exact.generated_tokens.len(), 3);
    assert_eq!(exact.step_logits.len(), 3);
    assert_eq!(quantized.generated_tokens.len(), 3);
    assert_eq!(quantized.step_logits.len(), 3);
    assert!(exact.kv_cache.exact_bytes > 0);
    assert!(quantized.kv_cache.quantized_bytes.is_some());
    assert!(
        quantized.kv_cache.quantized_bytes.unwrap() < quantized.kv_cache.exact_bytes,
        "quantized cache should use fewer bytes than exact cache"
    );
    // Strict cleanup check on success; the guard's Drop then finds nothing
    // left to remove and silently no-ops.
    fs::remove_dir_all(&fixture_dir).expect("fixture cleanup should succeed");
}
#[test]
#[ignore = "set TURBOQUANT_REAL_MODEL_DIR to an exported ONNX bundle to run this smoke test"]
fn manual_exported_real_model_smoke_test() {
    // Opt-in smoke test against a real exported bundle; the env var points
    // at the bundle directory and the test is skipped unless run explicitly.
    let bundle_dir = std::env::var("TURBOQUANT_REAL_MODEL_DIR")
        .expect("TURBOQUANT_REAL_MODEL_DIR should point to an exported ONNX bundle");
    let runner = RealModelRunner::load(&bundle_dir).expect("real model bundle should load");

    let gen_cfg = RealModelGenerationConfig {
        max_new_tokens: 2,
        stop_on_eos: true,
    };
    let quant_cfg = RealModelQuantizationConfig {
        key_bits: 4,
        value_bits: 4,
        key_strategy: QuantStrategy::Prod,
        seed: 42,
    };

    let prompt = "TurboQuant manual smoke test";
    let exact_out = runner
        .generate_exact(prompt, &gen_cfg)
        .expect("exact decode should succeed");
    let quant_out = runner
        .generate_quantized(prompt, &gen_cfg, &quant_cfg)
        .expect("quantized decode should succeed");

    // With a real model we only assert that decoding made progress at all;
    // exact token counts depend on the model and eos behavior.
    assert!(!exact_out.generated_tokens.is_empty());
    assert!(!quant_out.generated_tokens.is_empty());
}
/// Materializes a minimal decoder bundle (ONNX weights, HF-style config,
/// vocab, and a saved tokenizer) into `dir` for the fixture test.
fn write_fixture_bundle(dir: &Path) {
    fs::create_dir_all(dir).expect("fixture directory should be created");

    // The tiny ONNX graph is checked in as base64 to keep the repo text-only.
    let decoded_model = STANDARD
        .decode(include_str!("fixtures/tiny_decoder_with_past.onnx.b64").trim())
        .expect("fixture ONNX should decode from base64");
    fs::write(dir.join("model.onnx"), decoded_model).expect("fixture ONNX should be written");

    // Word-level vocab must exist on disk before the tokenizer model is
    // built from it below.
    let vocab_json = r#"{
  "[UNK]": 0,
  "<bos>": 1,
  "<eos>": 2,
  "hello": 3,
  "world": 4,
  "cache": 5,
  "quant": 6,
  "test": 7
}
"#;
    fs::write(dir.join("vocab.json"), vocab_json).expect("tokenizer vocab should be written");

    let config_json = r#"{
  "_name_or_path": "tiny-fixture-decoder",
  "num_hidden_layers": 1,
  "num_attention_heads": 2,
  "num_key_value_heads": 2,
  "hidden_size": 4,
  "bos_token_id": 1,
  "eos_token_id": 2
}
"#;
    fs::write(dir.join("config.json"), config_json).expect("fixture config should be written");

    // Build a whitespace-split word-level tokenizer over the vocab file and
    // save it in the single-file tokenizer.json format the runner loads.
    let word_level = WordLevel::builder()
        .files(dir.join("vocab.json").display().to_string())
        .unk_token("[UNK]".to_string())
        .build()
        .expect("word-level tokenizer model should build");
    let mut tok = Tokenizer::new(word_level);
    tok.with_pre_tokenizer(Some(Whitespace));
    tok.save(dir.join("tokenizer.json"), false)
        .expect("tokenizer fixture should save");
}
/// Returns a process- and time-unique path under the OS temp directory,
/// named `{prefix}-{pid}-{nanos}`; the directory itself is not created.
fn temp_fixture_dir(prefix: &str) -> PathBuf {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("system clock should be after unix epoch")
        .as_nanos();
    let pid = std::process::id();
    let mut dir = std::env::temp_dir();
    dir.push(format!("{prefix}-{pid}-{nanos}"));
    dir
}