#![allow(clippy::field_reassign_with_default)]
use std::sync::atomic::{AtomicU64, Ordering};
use fasttext::args::{Args, LossName, ModelName};
use fasttext::error::FastTextError;
use fasttext::FastText;
const COOKING_MODEL: &str = "tests/fixtures/cooking.model.bin";
/// Writes `content` to a uniquely named file in the system temp directory
/// and returns its path. Uniqueness comes from the process id combined
/// with a per-process monotonically increasing counter, so concurrent
/// tests never collide. The caller is responsible for removing the file.
fn write_temp_file(content: &str) -> std::path::PathBuf {
    static COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
    let file_name = format!("fasttext_quant_test_{}_{}.txt", std::process::id(), seq);
    let path = std::env::temp_dir().join(file_name);
    std::fs::write(&path, content).expect("Failed to write temp file");
    path
}
/// Like `write_temp_file`, but embeds a caller-supplied `tag` in the file
/// name so each test's artifacts are identifiable (and independent tests
/// cannot race on the same path). Returns the path; the caller deletes it.
fn write_unique_temp_file(content: &str, tag: &str) -> std::path::PathBuf {
    static UNIQUE_COUNTER: AtomicU64 = AtomicU64::new(0);
    let seq = UNIQUE_COUNTER.fetch_add(1, Ordering::Relaxed);
    let file_name = format!("fasttext_{}_{}_{}.txt", tag, std::process::id(), seq);
    let path = std::env::temp_dir().join(file_name);
    std::fs::write(&path, content).expect("Failed to write unique temp file");
    path
}
/// Builds a tiny balanced supervised corpus: 15 lines labeled `sports`
/// followed by 15 lines labeled `food`, in fastText `__label__` format.
fn supervised_train_data() -> String {
    let sports = "__label__sports basketball player sport game team score win\n";
    let food = "__label__food apple orange banana fruit eat cook recipe\n";
    let mut corpus = sports.repeat(15);
    corpus.push_str(&food.repeat(15));
    corpus
}
/// Builds a small unlabeled corpus (20 repetitions of three sentences)
/// for training unsupervised models such as CBOW.
fn unsupervised_train_data() -> String {
    let sentences = [
        "the quick brown fox jumps over the lazy dog\n",
        "machine learning algorithms work with data\n",
        "neural networks are powerful tools for classification\n",
    ];
    let mut corpus = String::new();
    for _ in 0..20 {
        for sentence in &sentences {
            corpus.push_str(sentence);
        }
    }
    corpus
}
/// Trains a tiny supervised classifier over the fixture corpus with the
/// given dimension, epoch count, and bucket size. Returns the trained
/// model together with the path of the training file; the caller is
/// responsible for removing that file.
fn train_small_supervised(dim: i32, epoch: i32, bucket: i32) -> (FastText, std::path::PathBuf) {
    let corpus = supervised_train_data();
    let corpus_path = write_unique_temp_file(&corpus, "quant_train");
    let mut args = Args::default();
    args.input = corpus_path.clone();
    // Discard the saved artifact; only the in-memory model is used.
    args.output = std::path::PathBuf::from(if cfg!(windows) { "NUL" } else { "/dev/null" });
    args.apply_supervised_defaults();
    args.dim = dim;
    args.epoch = epoch;
    args.min_count = 1;
    args.lr = 0.1;
    args.bucket = bucket;
    args.thread = 1;
    // Single thread plus a fixed seed keeps training deterministic.
    args.seed = 42;
    let model = FastText::train(args).expect("Training should succeed");
    (model, corpus_path)
}
/// A full-precision `.bin` model must not report itself as quantized.
#[test]
fn test_model_is_quant_false_for_bin() {
    let model = FastText::load_model(COOKING_MODEL).unwrap();
    let quantized = model.is_quant();
    assert!(!quantized, "cooking.model.bin should not be quantized");
}
/// Quantization is only defined for supervised classifiers: quantizing a
/// CBOW (unsupervised) model must fail with `InvalidArgument` and an
/// error message mentioning "supervised".
#[test]
fn test_quantize_unsupervised_rejected() {
    let data = unsupervised_train_data();
    let path = write_temp_file(&data);
    let path_str = path.to_str().unwrap().to_string();
    let mut args = Args::default();
    args.input = std::path::PathBuf::from(path_str);
    // Discard the saved artifact; only the in-memory model is used.
    args.output = std::path::PathBuf::from(if cfg!(windows) { "NUL" } else { "/dev/null" });
    args.model = ModelName::Cbow;
    args.loss = LossName::NegativeSampling;
    args.dim = 10;
    args.epoch = 1;
    args.min_count = 1;
    args.bucket = 100;
    let mut model = FastText::train(args).expect("CBOW training should succeed");
    std::fs::remove_file(&path).ok();
    let qargs = Args::default();
    let result = model.quantize(&qargs);
    assert!(result.is_err(), "CBOW model quantize should return error");
    match result.unwrap_err() {
        FastTextError::InvalidArgument(msg) => {
            // Fixed: the original asserted `msg.contains("supervised")`
            // twice (copy-pasted duplicate operand); one check suffices.
            assert!(
                msg.contains("supervised"),
                "Error should mention supervised: {}",
                msg
            );
        }
        e => panic!("Expected InvalidArgument, got: {:?}", e),
    }
}
/// Quantizing a freshly trained supervised model succeeds and flips
/// `is_quant()` to true.
#[test]
fn test_quantize_supervised_ok() {
    let (mut model, corpus_path) = train_small_supervised(16, 5, 0);
    std::fs::remove_file(&corpus_path).ok();
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    let outcome = model.quantize(&quant_args);
    assert!(
        outcome.is_ok(),
        "Supervised model quantize should succeed: {:?}",
        outcome.err()
    );
    assert!(
        model.is_quant(),
        "is_quant() should be true after quantization"
    );
}
/// After quantization the model must still yield well-formed predictions:
/// non-empty label, finite probability inside (0, 1].
#[test]
fn test_quantize_produces_valid_model() {
    let (mut classifier, corpus_path) = train_small_supervised(16, 5, 0);
    std::fs::remove_file(&corpus_path).ok();
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    classifier.quantize(&quant_args).expect("Quantize should succeed");
    assert!(classifier.is_quant(), "is_quant() should be true");
    let predictions = classifier.predict("basketball player sport game", 1, 0.0);
    assert!(
        !predictions.is_empty(),
        "Quantized model should produce predictions"
    );
    let top = &predictions[0];
    assert!(top.prob > 0.0, "Prediction probability should be > 0");
    assert!(top.prob <= 1.0, "Prediction probability should be <= 1.0");
    assert!(top.prob.is_finite(), "Prediction probability should be finite");
    assert!(!top.label.is_empty(), "Prediction label should not be empty");
}
/// A quantized `.ftz` file must be smaller on disk than the
/// corresponding full-precision `.bin` file saved from the same model.
#[test]
fn test_quantize_smaller_file() {
    // Generate many unique tokens so the input matrix is large enough
    // for quantization to show a measurable size reduction.
    let mut corpus = String::new();
    for i in 0..200 {
        corpus.push_str(&format!(
            "__label__sports basketball game sport player score word{} tok{} item{} entry{}\n",
            i * 3,
            i * 3 + 1,
            i * 3 + 2,
            i
        ));
        corpus.push_str(&format!(
            "__label__food apple banana fruit eat cook word{} tok{} item{} entry{}\n",
            i * 3 + 100,
            i * 3 + 101,
            i * 3 + 102,
            i + 100
        ));
    }
    let corpus_path = write_unique_temp_file(&corpus, "quantize_smaller");
    let mut train_args = Args::default();
    train_args.input = corpus_path.clone();
    train_args.output = std::path::PathBuf::from(if cfg!(windows) { "NUL" } else { "/dev/null" });
    train_args.apply_supervised_defaults();
    train_args.dim = 50;
    train_args.epoch = 1;
    train_args.min_count = 1;
    train_args.bucket = 0;
    train_args.thread = 1;
    let mut model = FastText::train(train_args).expect("Training should succeed");
    std::fs::remove_file(&corpus_path).ok();
    let tmp = std::env::temp_dir();
    let bin_path = tmp.join("test_quant_smaller.bin");
    let ftz_path = tmp.join("test_quant_smaller.ftz");
    model
        .save_model(bin_path.to_str().unwrap())
        .expect("Save .bin should succeed");
    let bin_size = std::fs::metadata(&bin_path).unwrap().len();
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    model.quantize(&quant_args).expect("Quantize should succeed");
    model
        .save_model(ftz_path.to_str().unwrap())
        .expect("Save .ftz should succeed");
    let ftz_size = std::fs::metadata(&ftz_path).unwrap().len();
    std::fs::remove_file(&bin_path).ok();
    std::fs::remove_file(&ftz_path).ok();
    assert!(
        ftz_size < bin_size,
        ".ftz ({} bytes) should be smaller than .bin ({} bytes), nwords={}",
        ftz_size,
        bin_size,
        model.dict().nwords()
    );
}
/// Quantization should rarely change the top-1 label: require at least
/// 90% agreement between pre- and post-quantization predictions on a
/// fixed query set.
#[test]
fn test_quantize_prediction_agreement() {
    // Top-1 label for each input; empty string when nothing is predicted.
    fn top_labels(model: &FastText, inputs: &[&str]) -> Vec<String> {
        inputs
            .iter()
            .map(|s| {
                model
                    .predict(s, 1, 0.0)
                    .first()
                    .map(|p| p.label.clone())
                    .unwrap_or_default()
            })
            .collect()
    }
    let (mut model, corpus_path) = train_small_supervised(16, 10, 0);
    std::fs::remove_file(&corpus_path).ok();
    let test_inputs = [
        "basketball player sport game score",
        "apple orange banana fruit eat",
        "team win lose tournament",
        "cook recipe meal dessert",
        "basketball game team score",
        "fruit eat recipe food",
        "sport player win",
        "meal cook eat banana",
        "game score win team",
        "food fruit apple banana",
    ];
    let preds_before = top_labels(&model, &test_inputs);
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    model.quantize(&quant_args).expect("Quantize should succeed");
    let preds_after = top_labels(&model, &test_inputs);
    let agreement = preds_before
        .iter()
        .zip(&preds_after)
        .filter(|(before, after)| !before.is_empty() && before == after)
        .count();
    let total = preds_before.iter().filter(|l| !l.is_empty()).count();
    assert!(total > 0, "Should have some predictions");
    let rate = agreement as f32 / total as f32;
    assert!(
        rate >= 0.9,
        "Quantized predictions should agree with unquantized >= 90%, got {:.1}% ({}/{})",
        rate * 100.0,
        agreement,
        total
    );
}
/// Quantization with `retrain = true` re-trains on the original corpus
/// after pruning; the resulting model must still be quantized and usable.
#[test]
fn test_quantize_retrain() {
    let sports_line = "__label__sports basketball player sport game team score win\n";
    let food_line = "__label__food apple orange banana fruit eat cook recipe\n";
    let mut corpus = String::new();
    for _ in 0..20 {
        corpus.push_str(sports_line);
        corpus.push_str(food_line);
    }
    let corpus_path = write_unique_temp_file(&corpus, "quantize_retrain");
    let mut train_args = Args::default();
    train_args.input = corpus_path.clone();
    train_args.output = std::path::PathBuf::from(if cfg!(windows) { "NUL" } else { "/dev/null" });
    train_args.apply_supervised_defaults();
    train_args.dim = 16;
    train_args.epoch = 5;
    train_args.min_count = 1;
    train_args.lr = 0.1;
    train_args.bucket = 0;
    train_args.thread = 1;
    train_args.seed = 42;
    let mut model = FastText::train(train_args).expect("Training should succeed");
    // Prune to half the vocabulary (at least 2 words) during quantization.
    let cutoff = (model.dict().nwords() as usize / 2).max(2);
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    quant_args.cutoff = cutoff;
    quant_args.retrain = true;
    // Retraining reads the corpus again, so keep the file until afterwards.
    quant_args.input = corpus_path.clone();
    quant_args.epoch = 1;
    quant_args.lr = 0.05;
    quant_args.thread = 1;
    let outcome = model.quantize(&quant_args);
    std::fs::remove_file(&corpus_path).ok();
    assert!(
        outcome.is_ok(),
        "retrain quantize should succeed: {:?}",
        outcome.err()
    );
    assert!(
        model.is_quant(),
        "is_quant() should be true after retrain quantize"
    );
    let predictions = model.predict("basketball player sport game", 1, 0.0);
    assert!(
        !predictions.is_empty(),
        "Retrained quantized model should produce predictions"
    );
    assert!(
        predictions[0].prob.is_finite() && predictions[0].prob > 0.0,
        "Retrained quantized model prediction prob should be valid"
    );
}
/// Saving a quantized model to `.ftz` and loading it back must preserve
/// the quantized flag and reproduce the same predictions (labels equal,
/// probabilities within 1e-5).
#[test]
fn test_quantize_save_load_roundtrip() {
    let (mut original, corpus_path) = train_small_supervised(16, 5, 0);
    std::fs::remove_file(&corpus_path).ok();
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    original.quantize(&quant_args).expect("Quantize should succeed");
    let query = "basketball player sport game score";
    let preds_before = original.predict(query, 2, 0.0);
    assert!(
        !preds_before.is_empty(),
        "Quantized model should have predictions before save"
    );
    let ftz_path = std::env::temp_dir().join("test_quant_roundtrip.ftz");
    original
        .save_model(ftz_path.to_str().unwrap())
        .expect("Save .ftz should succeed");
    let reloaded =
        FastText::load_model(ftz_path.to_str().unwrap()).expect("Load .ftz should succeed");
    std::fs::remove_file(&ftz_path).ok();
    assert!(
        reloaded.is_quant(),
        "Loaded .ftz model should have is_quant()=true"
    );
    let preds_after = reloaded.predict(query, 2, 0.0);
    assert_eq!(
        preds_before.len(),
        preds_after.len(),
        "Prediction count should match after .ftz round-trip"
    );
    for (i, (pb, pa)) in preds_before.iter().zip(preds_after.iter()).enumerate() {
        assert_eq!(
            pb.label, pa.label,
            "Prediction[{}] label should match after .ftz round-trip: '{}' vs '{}'",
            i, pb.label, pa.label
        );
        assert!(
            (pb.prob - pa.prob).abs() < 1e-5,
            "Prediction[{}] prob should be close after .ftz round-trip: {} vs {}",
            i,
            pb.prob,
            pa.prob
        );
    }
}
/// `is_quant()` tracks the model's state: false before quantization,
/// true after, and still true once saved to `.ftz` and loaded back.
#[test]
fn test_is_quant_true_for_ftz() {
    let (mut model, corpus_path) = train_small_supervised(16, 3, 0);
    std::fs::remove_file(&corpus_path).ok();
    assert!(
        !model.is_quant(),
        "Before quantization: is_quant() should be false"
    );
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    model.quantize(&quant_args).expect("Quantize should succeed");
    assert!(
        model.is_quant(),
        "After quantization: is_quant() should be true"
    );
    let ftz_path = std::env::temp_dir().join("test_is_quant_ftz.ftz");
    model
        .save_model(ftz_path.to_str().unwrap())
        .expect("Save should succeed");
    let loaded =
        FastText::load_model(ftz_path.to_str().unwrap()).expect("Load .ftz should succeed");
    std::fs::remove_file(&ftz_path).ok();
    assert!(loaded.is_quant(), "Loaded .ftz: is_quant() should be true");
}
/// Quantizing with a vocabulary cutoff must prune the dictionary to at
/// most `cutoff` words while still producing valid predictions.
#[test]
fn test_quantize_cutoff_predictions_valid() {
    let (mut model, corpus_path) = train_small_supervised(16, 5, 0);
    std::fs::remove_file(&corpus_path).ok();
    let vocab_size = model.dict().nwords();
    let cutoff = (vocab_size as usize / 2).max(1);
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    quant_args.cutoff = cutoff;
    model
        .quantize(&quant_args)
        .expect("cutoff quantize should succeed");
    let predictions = model.predict("basketball player sport game", 1, 0.0);
    assert!(
        !predictions.is_empty(),
        "Cutoff-pruned quantized model should produce predictions"
    );
    let top_prob = predictions[0].prob;
    assert!(
        top_prob.is_finite() && top_prob >= 0.0 && top_prob <= 1.0,
        "Cutoff-pruned model prediction prob {} should be in [0, 1]",
        top_prob
    );
    let remaining = model.dict().nwords();
    assert!(
        remaining <= cutoff as i32,
        "After cutoff={}, nwords should be <= {}, got {}",
        cutoff,
        cutoff,
        remaining
    );
}
/// A cutoff-pruned quantized model must survive a save/load round-trip:
/// the pruned dictionary and quantized matrix rows stay aligned, so the
/// reloaded model still predicts.
#[test]
fn test_quantize_cutoff_matrix_row_alignment() {
    let (mut model, corpus_path) = train_small_supervised(16, 5, 0);
    std::fs::remove_file(&corpus_path).ok();
    let cutoff = (model.dict().nwords() as usize / 2).max(2);
    let mut quant_args = Args::default();
    quant_args.dsub = 2;
    quant_args.cutoff = cutoff;
    model
        .quantize(&quant_args)
        .expect("cutoff quantize should succeed");
    let ftz_path =
        std::env::temp_dir().join(format!("test_cutoff_alignment_{}.ftz", std::process::id()));
    model
        .save_model(ftz_path.to_str().unwrap())
        .expect("Save pruned .ftz should succeed");
    let loaded =
        FastText::load_model(ftz_path.to_str().unwrap()).expect("Load pruned .ftz should succeed");
    std::fs::remove_file(&ftz_path).ok();
    assert!(loaded.is_quant(), "Loaded pruned .ftz should be quant");
    let predictions = loaded.predict("basketball player sport game", 1, 0.0);
    assert!(
        !predictions.is_empty(),
        "Loaded cutoff-pruned model should produce predictions"
    );
}