use clap::{ColorChoice, Parser};
use libffm::{Error, Model};
use std::path::PathBuf;
use std::process;
// Command-line arguments for `ffm-train`.
//
// Maintainer notes below use plain `//` comments on purpose: clap's derive
// turns `///` doc comments into `--help` text, and adding them would change
// the CLI's help output.
#[derive(Debug, Parser)]
#[command(name = "ffm-train", version, color = ColorChoice::Never)]
struct Args {
    // `-l`, default 0.00002; forwarded to `params.lambda`.
    #[arg(short, default_value_t = 0.00002)]
    lambda: f32,
    // `-k`, default 4; forwarded to `params.factors`.
    #[arg(short = 'k', default_value_t = 4)]
    factor: i32,
    // `-t`, default 15; forwarded to `params.iterations`.
    #[arg(short = 't', default_value_t = 15)]
    iteration: i32,
    // `-r`, default 0.2; forwarded to `params.learning_rate`.
    #[arg(short = 'r', default_value_t = 0.2)]
    eta: f32,
    // `-s`, default 1. Parsed so the flag is still accepted, but neither
    // `train_on_disk` nor `main` reads it, so it has no effect on training.
    // `allow(dead_code)` silences rustc's "field is never read" warning (the
    // derived `Debug` impl is ignored by dead-code analysis).
    // NOTE(review): forward this if the libffm params builder supports threads.
    #[allow(dead_code)]
    #[arg(short = 's', default_value_t = 1)]
    nr_threads: i32,
    // `-p <path>`: validation set. When present, training uses `train_eval`;
    // `main` requires it whenever `--auto-stop` is given.
    #[arg(short = 'p', value_parser)]
    va_path: Option<PathBuf>,
    // `--quiet`; forwarded to `params.quiet`.
    #[arg(long)]
    quiet: bool,
    // `--no-norm`; disables `params.normalization`.
    #[arg(long)]
    no_norm: bool,
    // `--auto-stop`; forwarded to `params.auto_stop` (requires `-p`).
    #[arg(long)]
    auto_stop: bool,
    // `--in-memory`; requests `params.on_disk(false)` instead of the default on-disk mode.
    #[arg(long)]
    in_memory: bool,
    // Required positional: the training data file.
    #[arg(name = "train-file", value_parser)]
    tr_path: PathBuf,
    // Optional positional: output path. When omitted, the model is saved as
    // `<basename of train-file>.model` in the current directory.
    #[arg(name = "model-file", value_parser)]
    model_path: Option<PathBuf>,
}
/// Trains an FFM model from the parsed command-line arguments and saves it.
///
/// Despite the name, training is requested on disk only when `--in-memory` is
/// absent; with `--in-memory`, `on_disk(false)` is passed instead.
///
/// When a validation set (`-p`) is given, training goes through `train_eval`;
/// otherwise through `train`. The model is written to `model-file` if one was
/// supplied, else to `<basename of train-file>.model` in the current directory.
///
/// `args.nr_threads` is not read here. NOTE(review): confirm whether the
/// `libffm` params builder exposes a thread-count setting it should reach.
///
/// # Errors
///
/// Returns any error produced by `libffm` while training or saving the model.
fn train_on_disk(args: &Args) -> Result<(), Error> {
    let mut params = Model::params();
    params
        .learning_rate(args.eta)
        .lambda(args.lambda)
        .iterations(args.iteration)
        .factors(args.factor)
        .normalization(!args.no_norm)
        .auto_stop(args.auto_stop)
        .quiet(args.quiet)
        .on_disk(!args.in_memory);

    let model = match &args.va_path {
        // `va_path` is already a `&PathBuf`; no extra borrow needed.
        Some(va_path) => params.train_eval(&args.tr_path, va_path)?,
        None => params.train(&args.tr_path)?,
    };

    // Borrow the user-supplied output path instead of cloning the whole
    // `Option`; only the fallback path has to be built.
    match &args.model_path {
        Some(model_path) => model.save(model_path),
        None => model.save(&default_model_path(&args.tr_path)),
    }
}

/// Builds the default output path: the final component of `tr_path` with
/// `.model` appended (e.g. `data/train.ffm` -> `train.ffm.model`). The result
/// is a bare file name, so it resolves against the current directory.
///
/// # Panics
///
/// Panics if `tr_path` has no final component (empty, a root, or ending in
/// `..`). `train_on_disk` only calls this after training from `tr_path` has
/// succeeded, which presumably cannot happen for such paths.
/// NOTE(review): confirm libffm rejects directories as training input.
fn default_model_path(tr_path: &std::path::Path) -> PathBuf {
    let mut file_name = tr_path
        .file_name()
        .expect("train-file path must name a file to derive the default model path")
        .to_os_string();
    file_name.push(".model");
    PathBuf::from(file_name)
}
/// Entry point: parses arguments, validates option combinations, then trains.
///
/// Any failure prints a single message to stdout and exits with status 1.
fn main() {
    // Prints `message` and terminates the process with exit status 1.
    fn fail(message: impl std::fmt::Display) -> ! {
        println!("{}", message);
        process::exit(1)
    }

    let args = Args::parse();

    // Reject `--auto-stop` without `-p` before any training work starts.
    let needs_validation_set = args.auto_stop && args.va_path.is_none();
    if needs_validation_set {
        fail("To use auto-stop, you need to assign a validation set");
    }

    match train_on_disk(&args) {
        Ok(()) => {}
        Err(err) => fail(err),
    }
}