use std::convert::TryFrom;
use std::fmt;
use std::io::{Read, Write};
use std::path::PathBuf;
use crate::error::Result;
use crate::utils;
/// Training objective of the model.
///
/// The explicit discriminants are the on-disk encoding: `save` writes
/// `self.model as i32` and `load` decodes it via `TryFrom<i32>`, so these
/// values must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum ModelName {
    /// Continuous bag-of-words (displayed as "cbow").
    Cbow = 1,
    /// Skip-gram (displayed as "sg").
    SkipGram = 2,
    /// Supervised text classification (displayed as "sup").
    Supervised = 3,
}
/// Loss function used during training.
///
/// The explicit discriminants are the on-disk encoding: `save` writes
/// `self.loss as i32` and `load` decodes it via `TryFrom<i32>`, so these
/// values must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum LossName {
    /// Hierarchical softmax (displayed as "hs").
    HierarchicalSoftmax = 1,
    /// Negative sampling (displayed as "ns").
    NegativeSampling = 2,
    /// Full softmax (displayed as "softmax").
    Softmax = 3,
    /// One-vs-all / independent binary losses (displayed as "one-vs-all").
    OneVsAll = 4,
}
/// Metric optimized by hyper-parameter autotuning.
///
/// Selected from the `autotune_metric` string by
/// `Args::get_autotune_metric_name`. The `*Label` variants are the
/// per-label forms (metric string carries a trailing `:label` part).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum MetricName {
    F1Score = 1,
    LabelF1Score = 2,
    PrecisionAtRecall = 3,
    PrecisionAtRecallLabel = 4,
    RecallAtPrecision = 5,
    RecallAtPrecisionLabel = 6,
}
impl TryFrom<i32> for ModelName {
type Error = i32;
fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
match value {
1 => Ok(ModelName::Cbow),
2 => Ok(ModelName::SkipGram),
3 => Ok(ModelName::Supervised),
_ => Err(value),
}
}
}
impl fmt::Display for ModelName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ModelName::Cbow => write!(f, "cbow"),
ModelName::SkipGram => write!(f, "sg"),
ModelName::Supervised => write!(f, "sup"),
}
}
}
impl TryFrom<i32> for LossName {
type Error = i32;
fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
match value {
1 => Ok(LossName::HierarchicalSoftmax),
2 => Ok(LossName::NegativeSampling),
3 => Ok(LossName::Softmax),
4 => Ok(LossName::OneVsAll),
_ => Err(value),
}
}
}
impl fmt::Display for LossName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
LossName::HierarchicalSoftmax => write!(f, "hs"),
LossName::NegativeSampling => write!(f, "ns"),
LossName::Softmax => write!(f, "softmax"),
LossName::OneVsAll => write!(f, "one-vs-all"),
}
}
}
/// Full hyper-parameter set for training, quantization, and autotuning.
///
/// Defaults come from the `Default` impl. Only the subset handled by
/// `save`/`load` is persisted alongside a model; everything else is
/// runtime-only configuration.
#[derive(Debug, Clone)]
pub struct Args {
    /// Path to the training data file.
    pub input: PathBuf,
    /// Path prefix where the trained model is written.
    pub output: PathBuf,
    /// Learning rate.
    pub lr: f64,
    /// How often the learning rate is adjusted, in tokens processed
    /// (fastText `-lrUpdateRate`) — presumably; confirm against the trainer.
    pub lr_update_rate: i32,
    /// Dimensionality of the word vectors.
    pub dim: i32,
    /// Context window size.
    pub ws: i32,
    /// Number of training epochs.
    pub epoch: i32,
    /// Minimal number of word occurrences to keep a word in the vocabulary.
    pub min_count: i32,
    /// Minimal number of label occurrences to keep a label.
    pub min_count_label: i32,
    /// Number of negatives sampled per positive example.
    pub neg: i32,
    /// Maximum length of word n-grams.
    pub word_ngrams: i32,
    /// Loss function (persisted as its `i32` discriminant).
    pub loss: LossName,
    /// Model objective (persisted as its `i32` discriminant).
    pub model: ModelName,
    /// Number of hash buckets for n-gram features; see
    /// `apply_supervised_defaults`, which may zero this out.
    pub bucket: i32,
    /// Minimum character n-gram length (0 disables subwords).
    pub minn: i32,
    /// Maximum character n-gram length (0 disables subwords).
    pub maxn: i32,
    /// Number of worker threads.
    pub thread: i32,
    /// Sampling threshold — presumably the word-frequency subsampling `t`;
    /// TODO confirm against the trainer.
    pub t: f64,
    /// Prefix that marks tokens as labels (default "__label__").
    pub label: String,
    /// Verbosity level.
    pub verbose: i32,
    /// Path to pretrained word vectors used to seed the input matrix
    /// (empty path = none).
    pub pretrained_vectors: PathBuf,
    /// Whether to also save the output matrix.
    pub save_output: bool,
    /// RNG seed (0 in the defaults).
    pub seed: i32,
    // NOTE(review): qout/retrain/qnorm/cutoff/dsub mirror fastText's
    // quantization flags; exact semantics live in the quantizer — verify there.
    pub qout: bool,
    pub retrain: bool,
    pub qnorm: bool,
    pub cutoff: usize,
    pub dsub: usize,
    /// Validation file for autotuning; a non-empty path enables autotune
    /// (see `has_autotune`).
    pub autotune_validation_file: PathBuf,
    /// Metric selector string, e.g. "f1", "f1:<label>",
    /// "precisionAtRecall:<n>[:<label>]" (see `get_autotune_metric_name`).
    pub autotune_metric: String,
    /// Number of predictions evaluated per example during autotune.
    pub autotune_predictions: i32,
    /// Autotune time budget in seconds.
    pub autotune_duration: i32,
    /// Target model-size constraint for autotune (empty = unconstrained).
    pub autotune_model_size: String,
}
impl Default for Args {
    /// Stock hyper-parameter set: skip-gram with negative sampling,
    /// empty input/output paths, quantization and autotune disabled.
    fn default() -> Self {
        Self {
            // I/O paths — must be filled in by the caller before training.
            input: PathBuf::new(),
            output: PathBuf::new(),
            // Optimization schedule.
            lr: 0.05,
            lr_update_rate: 100,
            epoch: 5,
            thread: 12,
            seed: 0,
            // Model shape.
            dim: 100,
            ws: 5,
            model: ModelName::SkipGram,
            loss: LossName::NegativeSampling,
            neg: 5,
            bucket: 2_000_000,
            minn: 3,
            maxn: 6,
            // Vocabulary / input processing.
            min_count: 5,
            min_count_label: 0,
            word_ngrams: 1,
            t: 1e-4,
            label: "__label__".to_string(),
            pretrained_vectors: PathBuf::new(),
            // Output & logging.
            verbose: 2,
            save_output: false,
            // Quantization — all disabled.
            qout: false,
            retrain: false,
            qnorm: false,
            cutoff: 0,
            dsub: 2,
            // Autotune — disabled (empty validation file).
            autotune_validation_file: PathBuf::new(),
            autotune_metric: "f1".to_string(),
            autotune_predictions: 1,
            autotune_duration: 300,
            autotune_model_size: String::new(),
        }
    }
}
impl Args {
    /// Creates an `Args` with the default hyper-parameters.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` when autotuning is requested, i.e. a validation
    /// file path has been set.
    pub fn has_autotune(&self) -> bool {
        !self.autotune_validation_file.as_os_str().is_empty()
    }

    /// Switches the configuration to supervised-classification defaults:
    /// softmax loss, no character n-grams, every word kept, and a higher
    /// learning rate.
    pub fn apply_supervised_defaults(&mut self) {
        self.model = ModelName::Supervised;
        self.loss = LossName::Softmax;
        self.min_count = 1;
        self.minn = 0;
        self.maxn = 0;
        self.lr = 0.1;
        // NOTE(review): `self.maxn == 0` is always true here because maxn was
        // assigned 0 just above. Kept verbatim for parity with the reference
        // implementation, where this check runs after user flags may have
        // overridden maxn; revisit if callers set maxn between these steps.
        if self.word_ngrams <= 1 && self.maxn == 0 && !self.has_autotune() {
            // No n-gram features are in use, so the hash-bucket table
            // can be dropped entirely.
            self.bucket = 0;
        }
    }

    /// Serializes the persisted subset of the hyper-parameters.
    ///
    /// Field order must match `load` exactly; enums are written as their
    /// `i32` discriminants.
    ///
    /// # Errors
    /// Propagates any I/O error from the underlying writer.
    pub fn save<W: Write>(&self, writer: &mut W) -> Result<()> {
        utils::write_i32(writer, self.dim)?;
        utils::write_i32(writer, self.ws)?;
        utils::write_i32(writer, self.epoch)?;
        utils::write_i32(writer, self.min_count)?;
        utils::write_i32(writer, self.neg)?;
        utils::write_i32(writer, self.word_ngrams)?;
        utils::write_i32(writer, self.loss as i32)?;
        utils::write_i32(writer, self.model as i32)?;
        utils::write_i32(writer, self.bucket)?;
        utils::write_i32(writer, self.minn)?;
        utils::write_i32(writer, self.maxn)?;
        utils::write_i32(writer, self.lr_update_rate)?;
        utils::write_f64(writer, self.t)?;
        Ok(())
    }

    /// Deserializes the persisted subset written by `save`, overwriting the
    /// corresponding fields in place.
    ///
    /// # Errors
    /// Propagates I/O errors from the reader, and returns
    /// `FastTextError::InvalidModel` when the loss or model tag is not a
    /// known discriminant.
    pub fn load<R: Read>(&mut self, reader: &mut R) -> Result<()> {
        self.dim = utils::read_i32(reader)?;
        self.ws = utils::read_i32(reader)?;
        self.epoch = utils::read_i32(reader)?;
        self.min_count = utils::read_i32(reader)?;
        self.neg = utils::read_i32(reader)?;
        self.word_ngrams = utils::read_i32(reader)?;
        let loss_val = utils::read_i32(reader)?;
        self.loss = LossName::try_from(loss_val).map_err(|v| {
            crate::error::FastTextError::InvalidModel(format!("Invalid loss value: {}", v))
        })?;
        let model_val = utils::read_i32(reader)?;
        self.model = ModelName::try_from(model_val).map_err(|v| {
            crate::error::FastTextError::InvalidModel(format!("Invalid model value: {}", v))
        })?;
        self.bucket = utils::read_i32(reader)?;
        self.minn = utils::read_i32(reader)?;
        self.maxn = utils::read_i32(reader)?;
        self.lr_update_rate = utils::read_i32(reader)?;
        self.t = utils::read_f64(reader)?;
        Ok(())
    }

    /// Maps the `autotune_metric` selector string to a `MetricName`.
    ///
    /// Recognized forms: "f1", "f1:<label>",
    /// "precisionAtRecall:<n>[:<label>]", "recallAtPrecision:<n>[:<label>]".
    /// A trailing ":<label>" selects the per-label variant. Returns `None`
    /// for anything else.
    pub fn get_autotune_metric_name(&self) -> Option<MetricName> {
        // `strip_prefix` replaces the previous hard-coded `[18..]` slice,
        // which silently depended on the prefix literals being exactly
        // 18 bytes long.
        if self.autotune_metric.starts_with("f1:") {
            Some(MetricName::LabelF1Score)
        } else if self.autotune_metric == "f1" {
            Some(MetricName::F1Score)
        } else if let Some(rest) = self.autotune_metric.strip_prefix("precisionAtRecall:") {
            if rest.contains(':') {
                Some(MetricName::PrecisionAtRecallLabel)
            } else {
                Some(MetricName::PrecisionAtRecall)
            }
        } else if let Some(rest) = self.autotune_metric.strip_prefix("recallAtPrecision:") {
            if rest.contains(':') {
                Some(MetricName::RecallAtPrecisionLabel)
            } else {
                Some(MetricName::RecallAtPrecision)
            }
        } else {
            None
        }
    }
}