use std::io::{self, BufRead, BufReader, BufWriter, Write};
use std::path::PathBuf;
use std::process;
use clap::{Args, Parser, Subcommand};
use fasttext::args::{Args as FTArgs, LossName, ModelName};
use fasttext::matrix::Matrix;
use fasttext::meter::Meter;
use fasttext::utils::cpp_default_format;
use fasttext::FastText;
// Top-level CLI definition. `arg_required_else_help` makes a bare `fasttext`
// invocation print usage instead of doing nothing.
// (Regular `//` comments are used deliberately: `///` doc comments on clap
// derive items become help text and would change CLI output.)
#[derive(Parser)]
#[command(
name = "fasttext",
about = "fastText text classification and representation learning",
arg_required_else_help = true
)]
struct Cli {
    // The selected subcommand; clap dispatches on the first positional token.
    #[command(subcommand)]
    command: Commands,
}
// All supported subcommands. Multi-word commands carry explicit kebab-case
// names (e.g. `predict-prob`) since the variant name alone would not produce
// the desired spelling.
#[derive(Subcommand)]
enum Commands {
    // Train a supervised text classifier.
    Supervised(TrainArgs),
    // Train skipgram word embeddings.
    Skipgram(TrainArgs),
    // Train CBOW word embeddings.
    Cbow(TrainArgs),
    // Print the most likely label(s) for each input line.
    Predict(PredictArgs),
    // Like `predict`, but also prints each label's probability.
    #[command(name = "predict-prob")]
    PredictProb(PredictArgs),
    // Evaluate a classifier on a labeled file (aggregate metrics).
    Test(TestEvalArgs),
    // Evaluate with per-label F1/precision/recall plus aggregates.
    #[command(name = "test-label")]
    TestLabel(TestEvalArgs),
    // Compress a trained `.bin` model into a quantized `.ftz` model.
    Quantize(QuantizeArgs),
    // Print word vectors for whitespace-separated words read from stdin.
    #[command(name = "print-word-vectors")]
    PrintWordVectors(ModelPathArgs),
    // Print one sentence vector per stdin line.
    #[command(name = "print-sentence-vectors")]
    PrintSentenceVectors(ModelPathArgs),
    // Print the n-gram vectors of a single word.
    #[command(name = "print-ngrams")]
    PrintNgrams(PrintNgramsArgs),
    // Interactive nearest-neighbor queries.
    Nn(NnArgs),
    // Interactive analogy queries (A - B + C).
    Analogies(AnalogiesArgs),
    // Dump model internals: args, dict, input, or output.
    Dump(DumpArgs),
}
// Options shared by `supervised`, `skipgram`, and `cbow`.
//
// Every tuning field is an `Option`: `None` means "not given on the command
// line", in which case `apply_train_overrides` keeps the library default.
// Comments use `//` (not `///`) so clap's generated help is unchanged.
#[derive(Args, Debug)]
struct TrainArgs {
    // Path to the training data file (forwarded to `FTArgs::input`).
    #[arg(long)]
    input: String,
    // Basename for the saved model; `.bin` or `.ftz` is appended in `run_train`.
    #[arg(long)]
    output: String,
    // Learning rate.
    #[arg(long)]
    lr: Option<f64>,
    // Learning-rate update interval.
    #[arg(long)]
    lr_update_rate: Option<i32>,
    // Vector dimensionality.
    #[arg(long)]
    dim: Option<i32>,
    // Context window size.
    #[arg(long)]
    ws: Option<i32>,
    // Number of training epochs.
    #[arg(long)]
    epoch: Option<i32>,
    // Minimum word occurrence count.
    #[arg(long)]
    min_count: Option<i32>,
    // Minimum label occurrence count.
    #[arg(long)]
    min_count_label: Option<i32>,
    // Number of negative samples.
    #[arg(long)]
    neg: Option<i32>,
    // Maximum word n-gram length.
    #[arg(long)]
    word_ngrams: Option<i32>,
    // Loss function name; validated by `parse_loss` (ns, hs, softmax, ova).
    #[arg(long)]
    loss: Option<String>,
    // Number of hash buckets.
    #[arg(long)]
    bucket: Option<i32>,
    // Minimum character n-gram length.
    #[arg(long)]
    minn: Option<i32>,
    // Maximum character n-gram length.
    #[arg(long)]
    maxn: Option<i32>,
    // Number of worker threads.
    #[arg(long)]
    thread: Option<i32>,
    // Sampling threshold `t` (forwarded to `FTArgs::t`).
    #[arg(long)]
    t: Option<f64>,
    // Label prefix string.
    #[arg(long)]
    label: Option<String>,
    // Verbosity level.
    #[arg(long)]
    verbose: Option<i32>,
    // Path to pretrained vectors used for initialization.
    #[arg(long)]
    pretrained_vectors: Option<String>,
    // Flag: also save output parameters. Plain bool — absent means false.
    #[arg(long)]
    save_output: bool,
    // RNG seed.
    #[arg(long)]
    seed: Option<i32>,
    // Validation file that enables hyperparameter autotuning.
    // NOTE(review): the explicit `name = "..."` attributes below look
    // redundant — clap derive already kebab-cases field names for `--long`
    // flags — but they are harmless; confirm before removing.
    #[arg(long, name = "autotune-validation-file")]
    autotune_validation_file: Option<String>,
    // Metric optimized during autotune.
    #[arg(long, name = "autotune-metric")]
    autotune_metric: Option<String>,
    // Number of predictions used when computing the autotune metric.
    #[arg(long, name = "autotune-predictions")]
    autotune_predictions: Option<i32>,
    // Autotune time budget.
    #[arg(long, name = "autotune-duration")]
    autotune_duration: Option<i32>,
    // Model-size budget; when set, `run_train` saves `.ftz` instead of `.bin`.
    #[arg(long, name = "autotune-model-size")]
    autotune_model_size: Option<String>,
}
// Positional arguments for `predict` and `predict-prob`.
#[derive(Args, Debug)]
struct PredictArgs {
    // Path to the trained model file.
    model: String,
    // Input text file; "-" (the default) reads from stdin.
    #[arg(default_value = "-")]
    input: String,
    // Number of labels to predict per line.
    #[arg(default_value_t = 1)]
    k: usize,
    // Probability threshold forwarded to `FastText::predict`.
    #[arg(default_value_t = 0.0)]
    threshold: f32,
}
// Positional arguments for `test` and `test-label`.
#[derive(Args, Debug)]
struct TestEvalArgs {
    // Path to the trained model file.
    model: String,
    // Path to the labeled evaluation file.
    test_file: String,
    // Number of predictions considered per example (P@k / R@k).
    #[arg(default_value_t = 1)]
    k: usize,
    // Probability threshold forwarded to `FastText::test_model`.
    #[arg(default_value_t = 0.0)]
    threshold: f32,
}
// Options for the `quantize` subcommand. All quantization knobs are copied
// onto the model's own args in `run_quantize`.
#[derive(Args, Debug)]
struct QuantizeArgs {
    // Model basename: `<output>.bin` is read, `<output>.ftz` is written.
    #[arg(long)]
    output: String,
    // Training data path; only forwarded when non-empty (see `run_quantize`),
    // presumably needed for the `--retrain` path — confirm against the library.
    #[arg(long, default_value = "")]
    input: String,
    // Cutoff forwarded to the quantization args.
    #[arg(long, default_value_t = 0)]
    cutoff: usize,
    // Whether to retrain after applying the cutoff.
    #[arg(long, default_value_t = false)]
    retrain: bool,
    // Whether to quantize norms separately.
    #[arg(long, default_value_t = false)]
    qnorm: bool,
    // Whether to quantize the output matrix as well.
    #[arg(long, default_value_t = false)]
    qout: bool,
    // Sub-vector size forwarded to the quantization args.
    #[arg(long, default_value_t = 2)]
    dsub: usize,
}
// A single positional model path; shared by the `print-*-vectors` subcommands.
#[derive(Args, Debug)]
struct ModelPathArgs {
    // Path to the trained model file.
    model: String,
}
// Positional arguments for `print-ngrams`.
#[derive(Args, Debug)]
struct PrintNgramsArgs {
    // Path to the trained model file.
    model: String,
    // Word whose character n-gram vectors are printed.
    word: String,
}
// Positional arguments for the interactive `nn` subcommand.
#[derive(Args, Debug)]
struct NnArgs {
    // Path to the trained model file.
    model: String,
    // Number of nearest neighbors to print per query.
    #[arg(default_value_t = 10)]
    k: usize,
}
// Positional arguments for the interactive `analogies` subcommand.
#[derive(Args, Debug)]
struct AnalogiesArgs {
    // Path to the trained model file.
    model: String,
    // Number of analogy results to print per query triplet.
    #[arg(default_value_t = 10)]
    k: usize,
}
// Positional arguments for `dump`.
#[derive(Args, Debug)]
struct DumpArgs {
    // Path to the trained model file.
    model: String,
    // Which piece to dump: "args", "dict", "input", or "output".
    option: String,
}
/// Rewrites C++ fastText single-dash flags (e.g. `-epoch`) into the
/// double-dash kebab-case flags clap expects (e.g. `--epoch`), so existing
/// fastText command lines keep working with this binary.
///
/// Only exact matches are rewritten; `-epoch=5`, prefixed variants, and
/// unknown flags pass through unchanged for clap to accept or reject.
fn normalize_args(raw: impl Iterator<Item = String>) -> Vec<String> {
    const FLAG_MAP: &[(&str, &str)] = &[
        ("-epoch", "--epoch"),
        ("-lr", "--lr"),
        ("-lrUpdateRate", "--lr-update-rate"),
        ("-dim", "--dim"),
        ("-ws", "--ws"),
        ("-minCount", "--min-count"),
        ("-minCountLabel", "--min-count-label"),
        ("-neg", "--neg"),
        ("-wordNgrams", "--word-ngrams"),
        ("-loss", "--loss"),
        ("-bucket", "--bucket"),
        ("-minn", "--minn"),
        ("-maxn", "--maxn"),
        ("-thread", "--thread"),
        ("-t", "--t"),
        ("-label", "--label"),
        ("-verbose", "--verbose"),
        ("-seed", "--seed"),
        ("-input", "--input"),
        ("-output", "--output"),
        ("-pretrainedVectors", "--pretrained-vectors"),
        ("-saveOutput", "--save-output"),
        ("-cutoff", "--cutoff"),
        ("-retrain", "--retrain"),
        ("-qnorm", "--qnorm"),
        ("-qout", "--qout"),
        ("-dsub", "--dsub"),
        ("-autotuneValidationFile", "--autotune-validation-file"),
        ("-autotuneDuration", "--autotune-duration"),
        ("-autotuneModelSize", "--autotune-model-size"),
        ("-autotuneMetric", "--autotune-metric"),
        // Bug fix: this mapping was missing even though TrainArgs declares
        // --autotune-predictions and apply_train_overrides applies it, so the
        // C++-style `-autotunePredictions` flag was silently rejected by clap.
        ("-autotunePredictions", "--autotune-predictions"),
    ];
    raw.map(|arg| {
        // An exact match is replaced; everything else is passed through.
        match FLAG_MAP.iter().find(|&&(single, _)| arg == single) {
            Some(&(_, double)) => double.to_string(),
            None => arg,
        }
    })
    .collect()
}
fn main() {
let raw_args = std::env::args();
let normalized = normalize_args(raw_args);
let cli = Cli::parse_from(normalized);
match cli.command {
Commands::Supervised(args) => run_train(args, ModelName::Supervised),
Commands::Skipgram(args) => run_train(args, ModelName::SkipGram),
Commands::Cbow(args) => run_train(args, ModelName::Cbow),
Commands::Predict(args) => run_predict(args, false),
Commands::PredictProb(args) => run_predict(args, true),
Commands::Test(args) => run_test(args, false),
Commands::TestLabel(args) => run_test(args, true),
Commands::Quantize(args) => run_quantize(args),
Commands::PrintWordVectors(args) => run_print_word_vectors(args),
Commands::PrintSentenceVectors(args) => run_print_sentence_vectors(args),
Commands::PrintNgrams(args) => run_print_ngrams(args),
Commands::Nn(args) => run_nn(args),
Commands::Analogies(args) => run_analogies(args),
Commands::Dump(args) => run_dump(args),
}
}
/// Maps a user-supplied loss name (case-insensitive) to the library enum.
/// Returns `None` for unrecognized names; the caller reports the error.
fn parse_loss(s: &str) -> Option<LossName> {
    let lowered = s.to_lowercase();
    let loss = match lowered.as_str() {
        "ns" => LossName::NegativeSampling,
        "hs" => LossName::HierarchicalSoftmax,
        "softmax" => LossName::Softmax,
        // Several spellings are accepted for one-vs-all.
        "ova" | "one-vs-all" | "ovr" => LossName::OneVsAll,
        _ => return None,
    };
    Some(loss)
}
/// Loads a model from `path`, terminating the process with a diagnostic on
/// a missing file or a failed load. Never returns on error.
fn load_model_or_exit(path: &str) -> FastText {
    // Check existence up front so the user gets a clearer message than a
    // generic load error.
    if !std::path::Path::new(path).exists() {
        eprintln!("Error: model file '{}' does not exist", path);
        process::exit(1);
    }
    FastText::load_model(path).unwrap_or_else(|e| {
        eprintln!("Error loading model '{}': {}", path, e);
        process::exit(1);
    })
}
fn apply_train_overrides(args: &mut FTArgs, train_args: &TrainArgs) {
if let Some(v) = train_args.lr {
args.lr = v;
}
if let Some(v) = train_args.lr_update_rate {
args.lr_update_rate = v;
}
if let Some(v) = train_args.dim {
args.dim = v;
}
if let Some(v) = train_args.ws {
args.ws = v;
}
if let Some(v) = train_args.epoch {
args.epoch = v;
}
if let Some(v) = train_args.min_count {
args.min_count = v;
}
if let Some(v) = train_args.min_count_label {
args.min_count_label = v;
}
if let Some(v) = train_args.neg {
args.neg = v;
}
if let Some(v) = train_args.word_ngrams {
args.word_ngrams = v;
}
if let Some(ref loss_str) = train_args.loss {
match parse_loss(loss_str) {
Some(loss) => args.loss = loss,
None => {
eprintln!(
"Error: unknown loss function '{}'. Valid values: ns, hs, softmax, ova",
loss_str
);
process::exit(1);
}
}
}
if let Some(v) = train_args.bucket {
args.bucket = v;
}
if let Some(v) = train_args.minn {
args.minn = v;
}
if let Some(v) = train_args.maxn {
args.maxn = v;
}
if let Some(v) = train_args.thread {
args.thread = v;
}
if let Some(v) = train_args.t {
args.t = v;
}
if let Some(ref v) = train_args.label {
args.label = v.clone();
}
if let Some(v) = train_args.verbose {
args.verbose = v;
}
if let Some(ref v) = train_args.pretrained_vectors {
args.pretrained_vectors = PathBuf::from(v.as_str());
}
if train_args.save_output {
args.save_output = true;
}
if let Some(v) = train_args.seed {
args.seed = v;
}
if let Some(ref v) = train_args.autotune_validation_file {
args.autotune_validation_file = PathBuf::from(v.as_str());
}
if let Some(ref v) = train_args.autotune_metric {
args.autotune_metric = v.clone();
}
if let Some(v) = train_args.autotune_predictions {
args.autotune_predictions = v;
}
if let Some(v) = train_args.autotune_duration {
args.autotune_duration = v;
}
if let Some(ref v) = train_args.autotune_model_size {
args.autotune_model_size = v.clone();
}
}
/// Builds the library argument struct for the given model kind: library
/// defaults, then the model selection, then the user's CLI overrides.
fn build_ft_args(train_args: TrainArgs, model_name: ModelName) -> FTArgs {
    let mut args = FTArgs::default();
    // Supervised mode installs its own default set; the embedding modes just
    // record which model to train.
    match model_name {
        ModelName::Supervised => args.apply_supervised_defaults(),
        other => args.model = other,
    }
    args.input = PathBuf::from(&train_args.input);
    args.output = PathBuf::from(&train_args.output);
    apply_train_overrides(&mut args, &train_args);
    args
}
/// Trains (or autotunes) a model and saves it as `<output>.bin`, or as
/// `<output>.ftz` when autotune was given a model-size budget (which
/// produces a quantized model).
fn run_train(train_args: TrainArgs, model_name: ModelName) {
    // Keep the basename before `train_args` is consumed below.
    let output_base = train_args.output.clone();
    let args = build_ft_args(train_args, model_name);
    let autotuning = args.has_autotune();
    let size_constrained = autotuning && !args.autotune_model_size.is_empty();
    let outcome = if autotuning {
        fasttext::autotune::Autotune::run(args)
    } else {
        FastText::train(args)
    };
    let model = outcome.unwrap_or_else(|e| {
        eprintln!("Error training model: {}", e);
        process::exit(1);
    });
    let extension = if size_constrained { "ftz" } else { "bin" };
    let model_path = format!("{}.{}", output_base, extension);
    if let Err(e) = model.save_model(&model_path) {
        eprintln!("Error saving model to '{}': {}", model_path, e);
        process::exit(1);
    }
}
fn run_predict(predict_args: PredictArgs, with_prob: bool) {
let model = load_model_or_exit(&predict_args.model);
let k = predict_args.k;
let threshold = predict_args.threshold;
let stdout = io::stdout();
let mut out = BufWriter::new(stdout.lock());
let process_line = |line: &str, out: &mut dyn Write| {
let predictions = model.predict(line, k, threshold);
let mut first = true;
for pred in &predictions {
if !first {
write!(out, " ").unwrap_or_else(|_| process::exit(1));
}
first = false;
write!(out, "{}", pred.label).unwrap_or_else(|_| process::exit(1));
if with_prob {
write!(out, " {}", cpp_default_format(pred.prob as f64, 6))
.unwrap_or_else(|_| process::exit(1));
}
}
writeln!(out).unwrap_or_else(|_| process::exit(1));
};
if predict_args.input == "-" {
let stdin = io::stdin();
for line in stdin.lock().lines() {
let line = line.unwrap_or_else(|e| {
eprintln!("Error reading stdin: {}", e);
process::exit(1);
});
process_line(&line, &mut out);
}
} else {
let file = std::fs::File::open(&predict_args.input).unwrap_or_else(|e| {
eprintln!("Error opening input file '{}': {}", predict_args.input, e);
process::exit(1);
});
for line in BufReader::new(file).lines() {
let line = line.unwrap_or_else(|e| {
eprintln!("Error reading input file: {}", e);
process::exit(1);
});
process_line(&line, &mut out);
}
}
out.flush().unwrap_or_else(|_| process::exit(1));
}
/// Evaluates a classifier on a labeled file. Without `per_label` it prints
/// the library's aggregate metrics; with it (`test-label`) it first prints
/// F1/precision/recall per label, then N / P@k / R@k.
fn run_test(test_args: TestEvalArgs, per_label: bool) {
    let model = load_model_or_exit(&test_args.model);
    let file = std::fs::File::open(&test_args.test_file).unwrap_or_else(|e| {
        eprintln!("Error opening test file '{}': {}", test_args.test_file, e);
        process::exit(1);
    });
    let mut reader = BufReader::new(file);
    let meter: Meter = model
        .test_model(&mut reader, test_args.k, test_args.threshold)
        .unwrap_or_else(|e| {
            eprintln!("Error evaluating model: {}", e);
            process::exit(1);
        });
    if !per_label {
        // Plain `test`: aggregate metrics only.
        meter.write_general_metrics(test_args.k as i32);
        return;
    }
    // Non-finite metrics (e.g. undefined precision) are rendered as dashes.
    let metric_line = |name: &str, value: f64| -> String {
        if value.is_finite() {
            format!("{} : {:.6}", name, value)
        } else {
            format!("{} : --------", name)
        }
    };
    for lid in 0..model.dict().nlabels() {
        if let Ok(label) = model.dict().get_label(lid) {
            println!(
                "{} {} {} {}",
                metric_line("F1-Score", meter.f1_for_label(lid)),
                metric_line("Precision", meter.precision_for_label(lid)),
                metric_line("Recall", meter.recall_for_label(lid)),
                label
            );
        }
    }
    println!("N\t{}", meter.n_examples());
    println!("P@{}\t{:.3}", test_args.k, meter.precision());
    println!("R@{}\t{:.3}", test_args.k, meter.recall());
}
fn run_quantize(qargs: QuantizeArgs) {
let model_bin = format!("{}.bin", qargs.output);
let model_ftz = format!("{}.ftz", qargs.output);
if !std::path::Path::new(&model_bin).exists() {
eprintln!("Error: model file '{}' does not exist", model_bin);
process::exit(1);
}
let mut model = match FastText::load_model(&model_bin) {
Ok(m) => m,
Err(e) => {
eprintln!("Error loading model '{}': {}", model_bin, e);
process::exit(1);
}
};
let mut ft_qargs = model.args().clone();
ft_qargs.cutoff = qargs.cutoff;
ft_qargs.retrain = qargs.retrain;
ft_qargs.qnorm = qargs.qnorm;
ft_qargs.qout = qargs.qout;
ft_qargs.dsub = qargs.dsub;
if !qargs.input.is_empty() {
ft_qargs.input = qargs.input;
}
if let Err(e) = model.quantize(&ft_qargs) {
eprintln!("Error quantizing model: {}", e);
process::exit(1);
}
if let Err(e) = model.save_model(&model_ftz) {
eprintln!("Error saving quantized model to '{}': {}", model_ftz, e);
process::exit(1);
}
}
/// For every whitespace-separated word read from stdin, prints the word
/// followed by its vector components, one word per output line.
fn run_print_word_vectors(args: ModelPathArgs) {
    let model = load_model_or_exit(&args.model);
    let stdout = io::stdout();
    let mut out = BufWriter::new(stdout.lock());
    let stdin = io::stdin();
    for line in stdin.lock().lines() {
        let line = match line {
            Ok(l) => l,
            Err(e) => {
                eprintln!("Error reading stdin: {}", e);
                process::exit(1);
            }
        };
        for word in line.split_whitespace() {
            write!(out, "{} ", word).unwrap_or_else(|_| process::exit(1));
            // Each component is formatted like the C++ default float output.
            for &component in model.get_word_vector(word).iter() {
                write!(out, "{} ", cpp_default_format(component as f64, 5))
                    .unwrap_or_else(|_| process::exit(1));
            }
            writeln!(out).unwrap_or_else(|_| process::exit(1));
        }
    }
    out.flush().unwrap_or_else(|_| process::exit(1));
}
/// Prints one sentence vector per line read from stdin.
fn run_print_sentence_vectors(args: ModelPathArgs) {
    let model = load_model_or_exit(&args.model);
    let stdout = io::stdout();
    let mut out = BufWriter::new(stdout.lock());
    let stdin = io::stdin();
    for line in stdin.lock().lines() {
        let line = match line {
            Ok(l) => l,
            Err(e) => {
                eprintln!("Error reading stdin: {}", e);
                process::exit(1);
            }
        };
        for &component in model.get_sentence_vector(&line).iter() {
            write!(out, "{} ", cpp_default_format(component as f64, 5))
                .unwrap_or_else(|_| process::exit(1));
        }
        writeln!(out).unwrap_or_else(|_| process::exit(1));
    }
    out.flush().unwrap_or_else(|_| process::exit(1));
}
/// Prints each character n-gram of `args.word` followed by its vector,
/// one n-gram per line.
fn run_print_ngrams(args: PrintNgramsArgs) {
    let model = load_model_or_exit(&args.model);
    let stdout = io::stdout();
    let mut out = BufWriter::new(stdout.lock());
    for (ngram, components) in model.get_ngram_vectors(&args.word).iter() {
        write!(out, "{} ", ngram).unwrap_or_else(|_| process::exit(1));
        for &component in components.iter() {
            write!(out, "{} ", cpp_default_format(component as f64, 5))
                .unwrap_or_else(|_| process::exit(1));
        }
        writeln!(out).unwrap_or_else(|_| process::exit(1));
    }
    out.flush().unwrap_or_else(|_| process::exit(1));
}
/// Interactive nearest-neighbor loop: prompts for a query word on stdin and
/// prints its `k` nearest neighbors with similarity scores.
fn run_nn(args: NnArgs) {
    // Prints the prompt and flushes so it appears before input is read.
    fn prompt(out: &mut impl Write) {
        write!(out, "Query word? ").unwrap_or_else(|_| process::exit(1));
        out.flush().unwrap_or_else(|_| process::exit(1));
    }
    let model = load_model_or_exit(&args.model);
    let stdout = io::stdout();
    let mut out = BufWriter::new(stdout.lock());
    prompt(&mut out);
    let stdin = io::stdin();
    for line in stdin.lock().lines() {
        let line = line.unwrap_or_else(|e| {
            eprintln!("Error reading stdin: {}", e);
            process::exit(1);
        });
        let query = line.trim();
        // Blank input is skipped without re-printing the prompt.
        if query.is_empty() {
            continue;
        }
        for (similarity, word) in model.get_nn(query, args.k).iter() {
            writeln!(out, "{} {}", word, cpp_default_format(*similarity as f64, 6))
                .unwrap_or_else(|_| process::exit(1));
        }
        prompt(&mut out);
    }
}
/// Interactive analogy loop: reads "A B C" triplets from stdin and prints
/// the top `k` words completing the analogy A - B + C.
fn run_analogies(args: AnalogiesArgs) {
    // Prints the prompt and flushes so it appears before input is read.
    fn prompt(out: &mut impl Write) {
        write!(out, "Query triplet (A - B + C)? ").unwrap_or_else(|_| process::exit(1));
        out.flush().unwrap_or_else(|_| process::exit(1));
    }
    let stdout = io::stdout();
    let mut out = BufWriter::new(stdout.lock());
    // Announce the (potentially slow) load before it starts.
    writeln!(out, "Loading model {}", args.model).unwrap_or_else(|_| process::exit(1));
    out.flush().unwrap_or_else(|_| process::exit(1));
    let model = load_model_or_exit(&args.model);
    prompt(&mut out);
    let stdin = io::stdin();
    for line in stdin.lock().lines() {
        let line = line.unwrap_or_else(|e| {
            eprintln!("Error reading stdin: {}", e);
            process::exit(1);
        });
        let words: Vec<&str> = line.split_whitespace().collect();
        // Require at least the A, B, C triplet; extra tokens are ignored,
        // and short lines are skipped without re-printing the prompt.
        if words.len() < 3 {
            continue;
        }
        let results = model.get_analogies(words[0], words[1], words[2], args.k);
        for (similarity, word) in results.iter() {
            writeln!(out, "{} {}", word, cpp_default_format(*similarity as f64, 6))
                .unwrap_or_else(|_| process::exit(1));
        }
        prompt(&mut out);
    }
}
/// Writes a dense matrix as a "rows cols" header followed by one line per
/// row of space-separated, C++-default-formatted values.
fn dump_matrix(out: &mut impl Write, m: &fasttext::matrix::DenseMatrix) {
    writeln!(out, "{} {}", m.rows(), m.cols()).unwrap_or_else(|_| process::exit(1));
    for r in 0..m.rows() {
        for (idx, &value) in m.row(r).iter().enumerate() {
            // Single space between values, none before the first or after
            // the last.
            let sep = if idx == 0 { "" } else { " " };
            write!(out, "{}{}", sep, cpp_default_format(value as f64, 6))
                .unwrap_or_else(|_| process::exit(1));
        }
        writeln!(out).unwrap_or_else(|_| process::exit(1));
    }
}
// Dumps one aspect of a trained model to stdout, selected by `args.option`:
// "args" (training hyperparameters), "dict" (vocabulary entries), or the
// dense "input"/"output" matrices. Unknown options abort with an error
// listing the valid choices.
fn run_dump(args: DumpArgs) {
    let model = load_model_or_exit(&args.model);
    let stdout = io::stdout();
    let mut out = BufWriter::new(stdout.lock());
    match args.option.as_str() {
        "args" => {
            // One "name value" line per hyperparameter; the camelCase names
            // are part of the dump format and must not be renamed.
            let a = model.args();
            writeln!(out, "dim {}", a.dim).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "ws {}", a.ws).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "epoch {}", a.epoch).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "minCount {}", a.min_count).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "neg {}", a.neg).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "wordNgrams {}", a.word_ngrams).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "loss {}", a.loss).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "model {}", a.model).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "bucket {}", a.bucket).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "minn {}", a.minn).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "maxn {}", a.maxn).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "lrUpdateRate {}", a.lr_update_rate).unwrap_or_else(|_| process::exit(1));
            writeln!(out, "t {}", a.t).unwrap_or_else(|_| process::exit(1));
        }
        "dict" => {
            // Entry count first, then one "word count type" line per entry.
            let dict = model.dict();
            let words = dict.words();
            writeln!(out, "{}", words.len()).unwrap_or_else(|_| process::exit(1));
            for entry in words {
                let entry_type = match entry.entry_type {
                    fasttext::dictionary::EntryType::Word => "word",
                    fasttext::dictionary::EntryType::Label => "label",
                };
                writeln!(out, "{} {} {}", entry.word, entry.count, entry_type)
                    .unwrap_or_else(|_| process::exit(1));
            }
        }
        "input" | "output" => {
            // Dense matrix dumps are unavailable once a model is quantized.
            if model.is_quant() {
                eprintln!("Not supported for quantized models.");
                process::exit(1);
            }
            let m = if args.option == "input" {
                model.input_matrix()
            } else {
                model.output_matrix()
            };
            dump_matrix(&mut out, m);
        }
        other => {
            eprintln!(
                "Error: unknown dump option '{}'. Valid options: args, dict, input, output",
                other
            );
            process::exit(1);
        }
    }
    out.flush().unwrap_or_else(|_| process::exit(1));
}
#[cfg(test)]
mod tests {
    use super::normalize_args;

    // Builds an owned argv from string literals.
    fn argv(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| s.to_string()).collect()
    }

    #[test]
    fn normalize_args_rewrites_known_cpp_style_flags() {
        let raw = argv(&[
            "fasttext",
            "supervised",
            "-epoch",
            "5",
            "-lrUpdateRate",
            "100",
            "-pretrainedVectors",
            "vectors.vec",
        ]);
        let expected = argv(&[
            "fasttext",
            "supervised",
            "--epoch",
            "5",
            "--lr-update-rate",
            "100",
            "--pretrained-vectors",
            "vectors.vec",
        ]);
        assert_eq!(normalize_args(raw.into_iter()), expected);
    }

    #[test]
    fn normalize_args_leaves_unknown_args_unchanged() {
        let raw = argv(&["fasttext", "supervised", "-unknownFlag", "value", "input.txt"]);
        assert_eq!(normalize_args(raw.clone().into_iter()), raw);
    }

    #[test]
    fn normalize_args_only_rewrites_exact_matches() {
        // `=`-joined values, prefixed variants, and already-normalized flags
        // must all pass through untouched.
        let raw = argv(&[
            "fasttext",
            "supervised",
            "-epoch=5",
            "-epochExtra",
            "--epoch",
            "-lrUpdateRates",
        ]);
        assert_eq!(normalize_args(raw.clone().into_iter()), raw);
    }
}