use std::io::BufReader;
use std::path::Path;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::args::{Args, LossName};
use crate::error::{FastTextError, Result};
use crate::fasttext::FastText;
use crate::model::MinstdRng;
// Process-wide counter used to give each model-size check a unique temp-file name,
// so concurrent checks in one process never collide on disk.
static SIZE_CHECK_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Draws one sample from the standard normal distribution N(0, 1) from `rng`
/// via the Box-Muller transform.
fn normal_sample(rng: &mut MinstdRng) -> f64 {
    let m = MinstdRng::M as f64;
    // Two uniform draws; u1 is clamped away from 0 so ln(u1) stays finite.
    let u1 = (rng.generate() as f64 / m).max(1e-15);
    let u2 = rng.generate() as f64 / m;
    // Box-Muller: radius from u1, angle from u2; keep the cosine branch.
    let radius = (-2.0 * u1.ln()).sqrt();
    let angle = 2.0 * std::f64::consts::PI * u2;
    radius * angle.cos()
}
/// Perturbs `val` with Gaussian noise whose standard deviation anneals from
/// `start_sigma` (while t <= 0.25) down to `end_sigma` (once t >= 0.75).
/// When `linear` is true the jitter is additive (`val + noise`); otherwise it
/// is multiplicative on a log2 scale (`val * 2^noise`). The result is clamped
/// to `[min, max]`.
#[allow(clippy::too_many_arguments)]
fn update_arg_gauss(
    val: f64,
    min: f64,
    max: f64,
    start_sigma: f64,
    end_sigma: f64,
    t: f64,
    linear: bool,
    rng: &mut MinstdRng,
) -> f64 {
    // Annealing progress: 0 until t = 0.25, ramping linearly to 0.5 at t = 0.75.
    let progress = 0.5f64.min((t - 0.25).max(0.0));
    let stddev = start_sigma - ((start_sigma - end_sigma) / 0.5) * progress;
    let noise = normal_sample(rng) * stddev;
    let perturbed = if linear {
        val + noise
    } else {
        val * 2.0f64.powf(noise)
    };
    // max-then-min chain preserves the original clamp semantics exactly.
    perturbed.max(min).min(max)
}
/// Integer variant of [`update_arg_gauss`]: perturbs in f64 space and
/// truncates the clamped result back down to i32.
#[allow(clippy::too_many_arguments)]
fn update_arg_gauss_i32(
    val: i32,
    min: i32,
    max: i32,
    start_sigma: f64,
    end_sigma: f64,
    t: f64,
    linear: bool,
    rng: &mut MinstdRng,
) -> i32 {
    // The cast truncates toward zero; the value is already inside [min, max].
    update_arg_gauss(
        val as f64,
        min as f64,
        max as f64,
        start_sigma,
        end_sigma,
        t,
        linear,
        rng,
    ) as i32
}
/// State for the randomized hyperparameter search: proposes new trial
/// arguments by perturbing the best-known arguments with annealed Gaussian
/// noise that narrows as the time budget runs out.
pub struct AutotuneStrategy {
// Best-scoring arguments found so far; the center of future perturbations.
best_args: Args,
// Total search budget in seconds, used to normalize elapsed time into [0, 1].
max_duration: f64,
// Seeded RNG driving all perturbations.
rng: MinstdRng,
// Number of ask() calls so far; trial 1 always replays the original args.
trials: u32,
// Index into `minn_choices` for the best minn seen so far.
best_minn_index: usize,
// log2 of the best dsub value seen so far.
best_dsub_exponent: i32,
// Last non-zero bucket count accepted as best.
best_nonzero_bucket: i32,
// The caller's original bucket setting, restored when hashed features are unused.
original_bucket: i32,
// Candidate minn values the search is allowed to pick from.
minn_choices: Vec<i32>,
}
impl AutotuneStrategy {
/// Builds a strategy centered on `original_args`, seeding the internal RNG
/// with `seed`. The originals are immediately registered as the current best
/// so the first trial reproduces them.
pub fn new(original_args: &Args, seed: u64) -> Self {
// Candidate subword minimum n-gram lengths (0 disables subword features).
let minn_choices = vec![0, 2, 3];
let mut strategy = AutotuneStrategy {
best_args: original_args.clone(),
max_duration: original_args.autotune_duration as f64,
rng: MinstdRng::new(seed),
trials: 0,
best_minn_index: 0,
best_dsub_exponent: 1,
best_nonzero_bucket: 2_000_000,
original_bucket: original_args.bucket,
minn_choices,
};
// Sync the derived fields (minn index, dsub exponent, bucket) with the originals.
strategy.update_best(original_args);
strategy
}
/// Proposes the next trial's arguments. `elapsed` is seconds since the search
/// started; as t = elapsed / max_duration approaches 1 the sampling sigmas
/// shrink, narrowing the search around the current best. The very first trial
/// returns the best (original) arguments unchanged so the baseline is scored.
pub fn ask(&mut self, elapsed: f64) -> Args {
let t = (elapsed / self.max_duration).min(1.0);
self.trials += 1;
if self.trials == 1 {
return self.best_args.clone();
}
let mut args = self.best_args.clone();
// Each parameter below is perturbed with an annealed Gaussian; the `false`
// flag means multiplicative (log2) jitter, `true` means additive jitter.
// NOTE: the draw order is part of the behavior — it fixes the RNG stream.
let epoch = update_arg_gauss_i32(args.epoch, 1, 100, 2.8, 2.5, t, false, &mut self.rng);
args.epoch = epoch;
let lr = update_arg_gauss(args.lr, 0.01, 5.0, 1.9, 1.0, t, false, &mut self.rng);
args.lr = lr;
let dim = update_arg_gauss_i32(args.dim, 1, 1000, 1.4, 0.3, t, false, &mut self.rng);
args.dim = dim;
let word_ngrams =
update_arg_gauss_i32(args.word_ngrams, 1, 5, 4.3, 2.4, t, true, &mut self.rng);
args.word_ngrams = word_ngrams;
// dsub is explored on a log scale: sample the exponent, then take 2^exp.
let dsub_exp = update_arg_gauss_i32(
self.best_dsub_exponent,
1,
4,
2.0,
1.0,
t,
true,
&mut self.rng,
);
args.dsub = 1usize << dsub_exp;
// minn is sampled as an index into `minn_choices` rather than a raw value.
let minn_idx = update_arg_gauss_i32(
self.best_minn_index as i32,
0,
(self.minn_choices.len() - 1) as i32,
4.0,
1.4,
t,
true,
&mut self.rng,
);
// Defensive clamp to a valid index on both sides.
let minn_idx_clamped = minn_idx.max(0) as usize;
let minn = self.minn_choices[minn_idx_clamped.min(self.minn_choices.len() - 1)];
args.minn = minn;
if minn == 0 {
args.maxn = 0;
} else {
// maxn tracks minn with a fixed window width of 3.
args.maxn = minn + 3;
}
// Always draw the bucket sample so the RNG stream stays aligned even
// when the sampled value ends up unused below.
let nonzero_bucket = update_arg_gauss_i32(
self.best_nonzero_bucket,
10_000,
10_000_000,
2.0,
1.5,
t,
false,
&mut self.rng,
);
// Buckets are only needed when hashed features (word n-grams or subwords) exist.
if args.word_ngrams > 1 || minn != 0 {
args.bucket = nonzero_bucket;
} else {
args.bucket = self.original_bucket;
}
if args.word_ngrams <= 1 && args.maxn == 0 {
args.bucket = 0;
}
// Autotune trials always train with softmax loss.
args.loss = LossName::Softmax;
args
}
/// Records `args` as the new best and refreshes the derived search state
/// (minn index, dsub exponent, last non-zero bucket) that centers future trials.
pub fn update_best(&mut self, args: &Args) {
self.best_args = args.clone();
self.best_minn_index = Self::find_index(args.minn, &self.minn_choices);
let dsub = args.dsub as f64;
// Recover the exponent from the stored power-of-two dsub; default to 1 for 0.
self.best_dsub_exponent = if dsub > 0.0 {
dsub.log2().round() as i32
} else {
1
};
// Keep the last non-zero bucket so zeroed buckets don't erase the estimate.
if args.bucket != 0 {
self.best_nonzero_bucket = args.bucket;
}
}
// Position of `val` in `choices`, falling back to index 0 when absent.
fn find_index(val: i32, choices: &[i32]) -> usize {
choices.iter().position(|&x| x == val).unwrap_or(0)
}
}
// Sentinel meaning "no trial scored yet"; any real score compares greater.
const UNKNOWN_BEST_SCORE: f64 = f64::NEG_INFINITY;
/// Parses a human-readable size such as "100", "2MB", or "1.5M" into bytes.
/// Units are case-insensitive and use binary multiples (K = 1024, etc.).
/// Returns `None` for empty input, an unparsable number, or an unknown unit.
fn parse_size_to_bytes(s: &str) -> Option<u64> {
    let trimmed = s.trim();
    if trimmed.is_empty() {
        return None;
    }
    // Split into a leading numeric part (digits and '.') and a unit suffix.
    let boundary = trimmed
        .find(|c: char| !c.is_ascii_digit() && c != '.')
        .unwrap_or(trimmed.len());
    let value: f64 = trimmed[..boundary].parse().ok()?;
    let multiplier: u64 = match trimmed[boundary..].trim().to_ascii_uppercase().as_str() {
        "" | "B" => 1,
        "K" | "KB" => 1_024,
        "M" | "MB" => 1_024 * 1_024,
        "G" | "GB" => 1_024 * 1_024 * 1_024,
        _ => return None,
    };
    // Fractional sizes like "1.5M" are supported, hence the f64 round-trip.
    Some((value * multiplier as f64) as u64)
}
// Result of a single autotune training trial.
enum TrialOutcome {
// The global time budget ran out mid-training; the search loop should stop.
TimedOut,
// Training returned an error or the worker thread panicked; try the next trial.
Failed,
// Training finished; boxed so the enum variant stays pointer-sized.
Success(Box<FastText>),
}
/// Namespace for the autotune entry point; see [`Autotune::run`].
pub struct Autotune;
impl Autotune {
/// Runs the hyperparameter search: repeatedly trains candidate argument sets
/// proposed by [`AutotuneStrategy`], scores each on the validation file, and
/// finally retrains the best candidate at the caller's verbosity.
///
/// Errors if the validation file is unset/unopenable, or if no trial
/// completes within `autotune_duration` seconds.
pub fn run(autotune_args: Args) -> Result<FastText> {
let val_path = autotune_args.autotune_validation_file.clone();
if val_path.as_os_str().is_empty() {
return Err(FastTextError::InvalidArgument(
"autotune validation file is not set".to_string(),
));
}
// Fail fast if the validation file cannot be opened at all.
let _ = std::fs::File::open(&val_path).map_err(FastTextError::IoError)?;
let seed = autotune_args.seed as u64;
let duration_secs = autotune_args.autotune_duration as f64;
let k = autotune_args.autotune_predictions.max(1) as usize;
let metric = autotune_args.autotune_metric.to_string();
// None when no size limit string was provided (or it failed to parse).
let model_size_bytes = parse_size_to_bytes(&autotune_args.autotune_model_size);
// Trials train silently; the caller's verbosity is restored for the final run.
let mut search_args = autotune_args.clone();
search_args.verbose = 0;
let mut strategy = AutotuneStrategy::new(&search_args, seed);
let start = Instant::now();
let mut best_args: Option<Args> = None;
let mut best_score = UNKNOWN_BEST_SCORE;
loop {
let elapsed = start.elapsed().as_secs_f64();
if elapsed >= duration_secs {
break;
}
let trial_args = strategy.ask(elapsed);
let model = match Self::train_trial(trial_args.clone(), &start, duration_secs) {
TrialOutcome::TimedOut => break,
TrialOutcome::Failed => continue,
TrialOutcome::Success(m) => *m,
};
// Optional model-size constraint: quantize, save, and measure the .ftz file.
let model = if let Some(max_bytes) = model_size_bytes {
match Self::check_model_size(model, max_bytes, &start, duration_secs) {
Some(m) => m,
None => continue,
}
} else {
model
};
// Evaluation errors count as a failed trial rather than aborting the search.
if let Ok(score) = Self::evaluate(&model, &val_path, k, &metric) {
if best_args.is_none() || score > best_score {
best_score = score;
strategy.update_best(&trial_args);
best_args = Some(trial_args);
}
}
}
Self::finish_with_best(best_args, autotune_args.verbose)
}
/// Trains one candidate on a worker thread, polling every 10 ms. If the
/// global deadline passes before training finishes, the abort flag is raised
/// and the trial is reported as timed out (the thread is still joined).
fn train_trial(trial_args: Args, start: &Instant, duration_secs: f64) -> TrialOutcome {
let abort_flag = Arc::new(AtomicBool::new(false));
let abort_clone = Arc::clone(&abort_flag);
let handle =
std::thread::spawn(move || FastText::train_with_abort(trial_args, abort_clone));
let timed_out = loop {
if handle.is_finished() {
break false;
}
if start.elapsed().as_secs_f64() >= duration_secs {
// Signal the trainer to stop; the join below waits for it to notice.
abort_flag.store(true, Ordering::Relaxed);
break true;
}
std::thread::sleep(Duration::from_millis(10));
};
let model_result = handle.join();
if timed_out {
return TrialOutcome::TimedOut;
}
match model_result {
Ok(Ok(m)) => TrialOutcome::Success(Box::new(m)),
// Both training errors and thread panics count as a failed trial.
Ok(Err(_)) | Err(_) => TrialOutcome::Failed,
}
}
/// Quantizes `model`, saves it to a unique temp file, and keeps the model
/// only if the on-disk size is within `max_bytes`. Returns `None` when the
/// deadline passes, quantization/saving fails, or the model is too large.
fn check_model_size(
mut model: FastText,
max_bytes: u64,
start: &Instant,
duration_secs: f64,
) -> Option<FastText> {
if start.elapsed().as_secs_f64() >= duration_secs {
return None;
}
let qargs = Args::default();
if model.quantize(&qargs).is_err() {
return None;
}
// Quantization can be slow; re-check the deadline before touching disk.
if start.elapsed().as_secs_f64() >= duration_secs {
return None;
}
// Unique path per process and per call so concurrent checks never collide.
let tmp_path = std::env::temp_dir().join(format!(
"fasttext_autotune_size_{}_{}.ftz",
std::process::id(),
SIZE_CHECK_COUNTER.fetch_add(1, Ordering::Relaxed)
));
if model.save_model(&tmp_path).is_err() {
std::fs::remove_file(&tmp_path).ok();
return None;
}
// A metadata failure maps to u64::MAX so the trial is rejected, not accepted.
let ftz_size = std::fs::metadata(&tmp_path)
.map(|m| m.len())
.unwrap_or(u64::MAX);
std::fs::remove_file(&tmp_path).ok();
if ftz_size > max_bytes {
None
} else {
Some(model)
}
}
/// Retrains the best argument set with the caller's verbosity restored, or
/// returns an error when no trial ever completed within the budget.
fn finish_with_best(best_args: Option<Args>, original_verbose: i32) -> Result<FastText> {
match best_args {
None => Err(FastTextError::InvalidArgument(
"Autotune: no trial completed successfully within the time budget. \
Consider increasing autotune-duration."
.to_string(),
)),
Some(mut final_args) => {
final_args.verbose = original_verbose;
// The abort flag stays false: the final training run is never cancelled.
FastText::train_with_abort(final_args, Arc::new(AtomicBool::new(false)))
}
}
}
/// Scores `model` on the validation file. `metric` is either "f1" (overall F1)
/// or "f1:<label>" for a single label's F1; anything else falls back to overall F1.
fn evaluate(model: &FastText, val_path: &Path, k: usize, metric: &str) -> Result<f64> {
let file = std::fs::File::open(val_path).map_err(FastTextError::IoError)?;
let mut reader = BufReader::new(file);
let meter = model.test_model(&mut reader, k, 0.0)?;
if metric == "f1" {
Ok(meter.f1())
} else if let Some(label_name) = metric.strip_prefix("f1:") {
// Unknown labels map to id -1; presumably f1_for_label handles that — TODO confirm.
let label_id = model.dict().get_id(label_name).unwrap_or(-1);
Ok(meter.f1_for_label(label_id))
} else {
Ok(meter.f1())
}
}
}
#[cfg(test)]
mod tests {
#![allow(clippy::field_reassign_with_default)]
use super::*;
use crate::args::Args;
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
// Monotonic counter giving each temp file written by the tests a unique name.
static FILE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Writes `content` to a uniquely named file in the system temp directory and
/// returns its path. Uniqueness comes from the tag, the pid, and a shared
/// counter, so parallel test runs never collide on file names.
fn write_temp(content: &str, tag: &str) -> std::path::PathBuf {
let id = FILE_COUNTER.fetch_add(1, AtomicOrdering::Relaxed);
let path = std::env::temp_dir().join(format!(
"fasttext_autotune_{}_{}_{}.txt",
tag,
std::process::id(),
id
));
std::fs::write(&path, content).expect("Failed to write temp file");
path
}
/// Builds a tiny two-class training corpus: 20 "sports" lines followed by
/// 20 "food" lines, each with distinctive vocabulary for easy separation.
fn make_train_data() -> String {
    let sports_line =
        "__label__sports basketball player sport game team score win tournament championship\n";
    let food_line =
        "__label__food apple orange banana mango fruit eat cook recipe meal dessert\n";
    format!("{}{}", sports_line.repeat(20), food_line.repeat(20))
}
/// Builds a matching validation set: 10 shorter lines per class, using a
/// subset of the training vocabulary.
fn make_val_data() -> String {
    let sports_line = "__label__sports sport player game win score\n";
    let food_line = "__label__food banana fruit eat recipe cook\n";
    format!("{}{}", sports_line.repeat(10), food_line.repeat(10))
}
/// Supervised-training Args tuned for test speed: tiny dim, few epochs,
/// single thread, fixed seed, and model output routed to the null device.
fn make_fast_supervised_args(input: &std::path::Path) -> Args {
let mut args = Args::default();
args.input = input.to_path_buf();
// Discard any model files the trainer writes.
args.output = std::path::PathBuf::from(if cfg!(windows) { "NUL" } else { "/dev/null" });
args.apply_supervised_defaults();
args.dim = 10;
args.epoch = 3;
args.min_count = 1;
args.lr = 0.1;
args.bucket = 0;
args.thread = 1;
// Fixed seed keeps trials deterministic across runs.
args.seed = 42;
args
}
// Autotune is off by default: an empty validation path means has_autotune() is false.
#[test]
fn test_autotune_activation_default_inactive() {
let args = Args::default();
assert!(
!args.has_autotune(),
"has_autotune() should be false when validation file is empty"
);
assert!(
args.autotune_validation_file.as_os_str().is_empty(),
"Default autotune_validation_file should be empty"
);
}
// Setting a validation file is what switches autotune on.
#[test]
fn test_autotune_activation_when_validation_file_set() {
let mut args = Args::default();
args.autotune_validation_file = std::path::PathBuf::from("some_validation_file.txt");
assert!(
args.has_autotune(),
"has_autotune() should be true when validation file is set"
);
}
// The default search budget is 300 seconds.
#[test]
fn test_autotune_duration_default() {
let args = Args::default();
assert_eq!(
args.autotune_duration, 300,
"Default autotune duration should be 300 seconds"
);
}
// The duration field stores small and large custom values unchanged.
#[test]
fn test_autotune_duration_custom() {
let mut args = Args::default();
args.autotune_duration = 60;
assert_eq!(
args.autotune_duration, 60,
"Autotune duration should reflect the custom value"
);
args.autotune_duration = 1;
assert_eq!(
args.autotune_duration, 1,
"Should accept duration of 1 second"
);
args.autotune_duration = 3600;
assert_eq!(
args.autotune_duration, 3600,
"Should accept large durations"
);
}
// A 3-second budget must finish within 3 s plus generous slack (30 s) for
// the in-flight trial to abort and join; only the wall-clock bound is mandatory.
#[test]
fn test_autotune_time_budget() {
let train_data = make_train_data();
let train_path = write_temp(&train_data, "time_budget_train");
let val_data = make_val_data();
let val_path = write_temp(&val_data, "time_budget_val");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 1; args.dim = 5;
args.autotune_validation_file = val_path.clone();
args.autotune_duration = 3;
let start = Instant::now();
let result = Autotune::run(args);
let elapsed = start.elapsed();
std::fs::remove_file(&train_path).ok();
std::fs::remove_file(&val_path).ok();
let max_allowed = Duration::from_secs(3 + 30);
assert!(
elapsed <= max_allowed,
"Autotune ran for {:?} which exceeds 3 + 30 seconds",
elapsed
);
match result {
Ok(model) => {
let preds = model.predict("basketball player sport", 1, 0.0);
assert!(
!preds.is_empty(),
"Model returned by autotune should produce predictions"
);
}
// A timeout-induced error is acceptable here; the time bound above is the point.
Err(_e) => {
}
}
}
// The winning arguments must stay inside the ranges the strategy samples from,
// and the resulting model must actually predict.
#[test]
fn test_autotune_tunes_params() {
let train_data = make_train_data();
let train_path = write_temp(&train_data, "tunes_params_train");
let val_data = make_val_data();
let val_path = write_temp(&val_data, "tunes_params_val");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 1; args.dim = 5;
args.autotune_validation_file = val_path.clone();
args.autotune_duration = 4;
let result = Autotune::run(args);
std::fs::remove_file(&train_path).ok();
std::fs::remove_file(&val_path).ok();
let model = result.expect("Autotune should succeed within 4 seconds");
let best_args = model.args();
assert!(
best_args.epoch >= 1 && best_args.epoch <= 100,
"Best epoch {} out of range [1, 100]",
best_args.epoch
);
assert!(
best_args.lr >= 0.01 && best_args.lr <= 5.0,
"Best lr {} out of range [0.01, 5.0]",
best_args.lr
);
assert!(
best_args.dim >= 1 && best_args.dim <= 1000,
"Best dim {} out of range [1, 1000]",
best_args.dim
);
assert!(
best_args.word_ngrams >= 1 && best_args.word_ngrams <= 5,
"Best wordNgrams {} out of range [1, 5]",
best_args.word_ngrams
);
let preds = model.predict("basketball player sport game", 1, 0.0);
assert!(
!preds.is_empty(),
"Autotune model should produce predictions"
);
}
// Trial 1 must replay the original args verbatim; trial 2 must perturb at
// least one of epoch / lr / dim.
#[test]
fn test_autotune_strategy_explores_params() {
let train_path = write_temp(&make_train_data(), "strategy_dummy");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 5;
args.autotune_duration = 300;
std::fs::remove_file(&train_path).ok();
let mut strategy = AutotuneStrategy::new(&args, 42);
let trial1 = strategy.ask(0.0);
assert_eq!(
trial1.epoch, args.epoch,
"Trial 1 must return original epoch"
);
assert_eq!(trial1.dim, args.dim, "Trial 1 must return original dim");
let trial2 = strategy.ask(1.0);
let epoch_differs = trial2.epoch != args.epoch;
let lr_differs = (trial2.lr - args.lr).abs() > 1e-9;
let dim_differs = trial2.dim != args.dim;
assert!(
epoch_differs || lr_differs || dim_differs,
"Trial 2 must differ from original in at least one parameter \
(epoch={}, lr={:.4}, dim={})",
trial2.epoch,
trial2.lr,
trial2.dim
);
}
// End-to-end: a short run must yield a usable model with sane probabilities
// and in-range best arguments.
#[test]
fn test_autotune_returns_model() {
let train_data = make_train_data();
let train_path = write_temp(&train_data, "returns_model_train");
let val_data = make_val_data();
let val_path = write_temp(&val_data, "returns_model_val");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 2;
args.dim = 5;
args.autotune_validation_file = val_path.clone();
args.autotune_duration = 3;
let result = Autotune::run(args);
std::fs::remove_file(&train_path).ok();
std::fs::remove_file(&val_path).ok();
let model = result.expect("Autotune should return a model");
let preds = model.predict("basketball player sport game", 1, 0.0);
assert!(!preds.is_empty(), "Predictions must be non-empty");
for p in &preds {
assert!(
p.prob >= 0.0 && p.prob <= 1.0,
"Prediction probability {} out of range [0, 1]",
p.prob
);
}
let best_args = model.args();
assert!(best_args.epoch >= 1, "Best epoch must be >= 1");
assert!(best_args.lr > 0.0, "Best lr must be positive");
assert!(best_args.dim >= 1, "Best dim must be >= 1");
}
// Even a 1-second budget must produce a working model (the baseline trial
// is always evaluated first).
#[test]
fn test_autotune_minimal_duration() {
let train_data = make_train_data();
let train_path = write_temp(&train_data, "minimal_dur_train");
let val_data = make_val_data();
let val_path = write_temp(&val_data, "minimal_dur_val");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 1;
args.dim = 5;
args.autotune_validation_file = val_path.clone();
args.autotune_duration = 1;
let result = Autotune::run(args);
std::fs::remove_file(&train_path).ok();
std::fs::remove_file(&val_path).ok();
let model = result.expect("Autotune with 1-second duration should return a model");
let preds = model.predict("sport game basketball player", 1, 0.0);
assert!(!preds.is_empty(), "Predictions must be non-empty");
assert!(
!preds[0].label.is_empty(),
"Prediction label must be non-empty"
);
}
// update_best must translate concrete args back into the strategy's derived
// state: minn=2 -> index 1 in [0, 2, 3]; dsub=4 -> exponent 2.
#[test]
fn test_autotune_strategy_update_best() {
let train_path = write_temp(&make_train_data(), "strategy_update");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 5;
args.autotune_duration = 300;
std::fs::remove_file(&train_path).ok();
let mut strategy = AutotuneStrategy::new(&args, 1);
let mut new_best = args.clone();
new_best.minn = 2;
new_best.dsub = 4; strategy.update_best(&new_best);
assert_eq!(strategy.best_minn_index, 1, "minn=2 should map to index 1");
assert_eq!(
strategy.best_dsub_exponent, 2,
"dsub=4 should give exponent 2"
);
}
// Sample statistics of normal_sample should be close to N(0, 1) over 1000 draws.
#[test]
fn test_normal_sample_basic() {
let mut rng = MinstdRng::new(42);
let samples: Vec<f64> = (0..1000).map(|_| normal_sample(&mut rng)).collect();
let mean: f64 = samples.iter().sum::<f64>() / samples.len() as f64;
let var: f64 = samples
.iter()
.map(|&x| (x - mean) * (x - mean))
.sum::<f64>()
/ samples.len() as f64;
let stddev = var.sqrt();
assert!(
mean.abs() < 0.15,
"Normal sample mean should be near 0, got {}",
mean
);
assert!(
(stddev - 1.0).abs() < 0.2,
"Normal sample stddev should be near 1, got {}",
stddev
);
}
// The clamped perturbation must never leave [min, max], for both f64 and i32.
#[test]
fn test_update_arg_gauss_bounds() {
let mut rng = MinstdRng::new(7);
for _ in 0..100 {
let v = update_arg_gauss(10.0, 5.0, 15.0, 3.0, 1.0, 0.0, false, &mut rng);
assert!((5.0..=15.0).contains(&v), "Value {} out of [5, 15]", v);
let vi = update_arg_gauss_i32(10, 5, 15, 3.0, 1.0, 0.0, false, &mut rng);
assert!((5..=15).contains(&vi), "Value {} out of [5, 15]", vi);
}
}
// Sigma anneals with t: samples at t=0 must have higher variance than at t=1.
#[test]
fn test_update_arg_gauss_sigma_decreases() {
let mut rng_early = MinstdRng::new(99);
let mut rng_late = MinstdRng::new(99);
let early_samples: Vec<f64> = (0..200)
.map(|_| update_arg_gauss(10.0, 1.0, 1000.0, 3.0, 1.0, 0.0, false, &mut rng_early))
.collect();
let late_samples: Vec<f64> = (0..200)
.map(|_| update_arg_gauss(10.0, 1.0, 1000.0, 3.0, 1.0, 1.0, false, &mut rng_late))
.collect();
let early_var: f64 = {
let mean = early_samples.iter().sum::<f64>() / early_samples.len() as f64;
early_samples
.iter()
.map(|&x| (x - mean) * (x - mean))
.sum::<f64>()
/ early_samples.len() as f64
};
let late_var: f64 = {
let mean = late_samples.iter().sum::<f64>() / late_samples.len() as f64;
late_samples
.iter()
.map(|&x| (x - mean) * (x - mean))
.sum::<f64>()
/ late_samples.len() as f64
};
assert!(
early_var > late_var,
"Early variance ({:.2}) should be greater than late variance ({:.2})",
early_var,
late_var
);
}
// run() must reject an unset validation file up front.
#[test]
fn test_autotune_requires_validation_file() {
let train_path = write_temp(&make_train_data(), "no_val_file");
let mut args = make_fast_supervised_args(&train_path);
args.autotune_duration = 1;
std::fs::remove_file(&train_path).ok();
let result = Autotune::run(args);
assert!(
result.is_err(),
"Autotune without validation file should return an error"
);
}
// run() must also reject a validation path that cannot be opened.
#[test]
fn test_autotune_missing_validation_file() {
let train_path = write_temp(&make_train_data(), "missing_val");
let mut args = make_fast_supervised_args(&train_path);
args.autotune_duration = 1;
args.autotune_validation_file =
std::path::PathBuf::from("/nonexistent/path/validation.txt");
std::fs::remove_file(&train_path).ok();
let result = Autotune::run(args);
assert!(
result.is_err(),
"Autotune with missing validation file should return an error"
);
}
// Size strings use binary multiples, are case-insensitive, and allow fractions.
#[test]
fn test_parse_size_to_bytes() {
assert_eq!(parse_size_to_bytes("100"), Some(100));
assert_eq!(parse_size_to_bytes("100B"), Some(100));
assert_eq!(parse_size_to_bytes("1K"), Some(1_024));
assert_eq!(parse_size_to_bytes("1KB"), Some(1_024));
assert_eq!(parse_size_to_bytes("2M"), Some(2 * 1_024 * 1_024));
assert_eq!(parse_size_to_bytes("2MB"), Some(2 * 1_024 * 1_024));
assert_eq!(parse_size_to_bytes("1G"), Some(1_024 * 1_024 * 1_024));
assert_eq!(parse_size_to_bytes("1GB"), Some(1_024 * 1_024 * 1_024));
let expected = (1.5 * 1_024.0 * 1_024.0) as u64;
assert_eq!(parse_size_to_bytes("1.5M"), Some(expected));
assert_eq!(parse_size_to_bytes(""), None);
assert_eq!(parse_size_to_bytes("100X"), None);
assert_eq!(parse_size_to_bytes("1m"), Some(1_024 * 1_024));
assert_eq!(parse_size_to_bytes("1k"), Some(1_024));
}
// A per-label metric string ("f1:<label>") must be accepted end-to-end.
#[test]
fn test_autotune_label_f1_metric() {
let train_data = make_train_data();
let train_path = write_temp(&train_data, "label_f1_train");
let val_data = make_val_data();
let val_path = write_temp(&val_data, "label_f1_val");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 1;
args.dim = 5;
args.autotune_validation_file = val_path.clone();
args.autotune_duration = 3;
args.autotune_metric = "f1:__label__sports".to_string();
let result = Autotune::run(args);
std::fs::remove_file(&train_path).ok();
std::fs::remove_file(&val_path).ok();
let model = result.expect("Autotune with label F1 metric should succeed");
let preds = model.predict("basketball player sport", 1, 0.0);
assert!(
!preds.is_empty(),
"Model from label-F1 autotune should produce predictions"
);
}
// The plain "f1" metric string must be accepted end-to-end as well.
#[test]
fn test_autotune_default_f1_metric() {
let train_data = make_train_data();
let train_path = write_temp(&train_data, "default_f1_train");
let val_data = make_val_data();
let val_path = write_temp(&val_data, "default_f1_val");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 1;
args.dim = 5;
args.autotune_validation_file = val_path.clone();
args.autotune_duration = 3;
args.autotune_metric = "f1".to_string();
let result = Autotune::run(args);
std::fs::remove_file(&train_path).ok();
std::fs::remove_file(&val_path).ok();
let model = result.expect("Autotune with default F1 metric should succeed");
let preds = model.predict("banana fruit eat recipe", 1, 0.0);
assert!(
!preds.is_empty(),
"Model from default-F1 autotune should produce predictions"
);
}
// Exercises the quantize-save-measure pipeline that check_model_size relies on:
// a tiny model must quantize, save to a non-trivial file, and stay small.
#[test]
fn test_autotune_model_size_check_mechanism() {
let train_data = make_train_data();
let train_path = write_temp(&train_data, "size_check_mech");
let mut args = make_fast_supervised_args(&train_path);
args.epoch = 1;
args.dim = 5;
let mut model = FastText::train(args).expect("Training should succeed");
std::fs::remove_file(&train_path).ok();
let qargs = Args::default();
model
.quantize(&qargs)
.expect("Quantize should succeed for size-check path");
let tmp_path =
std::env::temp_dir().join(format!("test_size_check_mech_{}.ftz", std::process::id()));
model
.save_model(tmp_path.to_str().unwrap())
.expect("Save should succeed");
let ftz_size = std::fs::metadata(&tmp_path).map(|m| m.len()).unwrap_or(0);
std::fs::remove_file(&tmp_path).ok();
assert!(
ftz_size > 1,
"Quantized model should be > 1 byte; got {} bytes",
ftz_size
);
let ten_mb = 10u64 * 1024 * 1024;
assert!(
ftz_size < ten_mb,
"Tiny test model should be < 10 MB; got {} bytes",
ftz_size
);
}
}