// fasttext/args.rs
1// Args: hyperparameter configuration for fastText
2
3use std::convert::TryFrom;
4use std::fmt;
5use std::io::{Read, Write};
6use std::path::PathBuf;
7
8use crate::error::Result;
9use crate::utils;
10
/// Model architecture type.
///
/// The discriminant values are the exact `i32` codes stored in the binary
/// model format (written by `Args::save`, validated via `TryFrom<i32>` in
/// `Args::load`), so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum ModelName {
    /// Continuous bag-of-words.
    Cbow = 1,
    /// Skip-gram.
    SkipGram = 2,
    /// Supervised classification.
    Supervised = 3,
}
22
/// Loss function type.
///
/// The discriminant values are the exact `i32` codes stored in the binary
/// model format (written by `Args::save`, validated via `TryFrom<i32>` in
/// `Args::load`), so they must not be renumbered.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum LossName {
    /// Hierarchical softmax.
    HierarchicalSoftmax = 1,
    /// Negative sampling.
    NegativeSampling = 2,
    /// Softmax.
    Softmax = 3,
    /// One-vs-all.
    OneVsAll = 4,
}
36
/// Autotune metric type.
///
/// Produced by `Args::get_autotune_metric_name` from the textual
/// `autotune_metric` setting. The `*Label` variants denote the same metric
/// restricted to a single label (selected by a trailing `:<label>` segment
/// in the metric string). Unlike `ModelName`/`LossName`, this enum is not
/// part of the binary model format in the visible code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum MetricName {
    /// F1 score (macro).
    F1Score = 1,
    /// F1 score for a specific label.
    LabelF1Score = 2,
    /// Precision at recall threshold.
    PrecisionAtRecall = 3,
    /// Precision at recall threshold for a specific label.
    PrecisionAtRecallLabel = 4,
    /// Recall at precision threshold.
    RecallAtPrecision = 5,
    /// Recall at precision threshold for a specific label.
    RecallAtPrecisionLabel = 6,
}
54
55impl TryFrom<i32> for ModelName {
56    type Error = i32;
57
58    fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
59        match value {
60            1 => Ok(ModelName::Cbow),
61            2 => Ok(ModelName::SkipGram),
62            3 => Ok(ModelName::Supervised),
63            _ => Err(value),
64        }
65    }
66}
67
68impl fmt::Display for ModelName {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        match self {
71            ModelName::Cbow => write!(f, "cbow"),
72            ModelName::SkipGram => write!(f, "sg"),
73            ModelName::Supervised => write!(f, "sup"),
74        }
75    }
76}
77
78impl TryFrom<i32> for LossName {
79    type Error = i32;
80
81    fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
82        match value {
83            1 => Ok(LossName::HierarchicalSoftmax),
84            2 => Ok(LossName::NegativeSampling),
85            3 => Ok(LossName::Softmax),
86            4 => Ok(LossName::OneVsAll),
87            _ => Err(value),
88        }
89    }
90}
91
92impl fmt::Display for LossName {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        match self {
95            LossName::HierarchicalSoftmax => write!(f, "hs"),
96            LossName::NegativeSampling => write!(f, "ns"),
97            LossName::Softmax => write!(f, "softmax"),
98            LossName::OneVsAll => write!(f, "one-vs-all"),
99        }
100    }
101}
102
/// All fastText hyperparameters.
///
/// Only the 13 fields serialized by `Args::save`/`Args::load` are part of
/// the binary model format; the rest are runtime-only settings. Field
/// meanings follow the fastText command-line options of the same names.
#[derive(Debug, Clone)]
pub struct Args {
    /// Training data file path.
    pub input: PathBuf,
    /// Output file path prefix for the trained model.
    pub output: PathBuf,
    /// Learning rate.
    pub lr: f64,
    /// How often (in tokens processed) the learning rate is updated.
    pub lr_update_rate: i32,
    /// Dimensionality of the word vectors.
    pub dim: i32,
    /// Context window size.
    pub ws: i32,
    /// Number of training epochs.
    pub epoch: i32,
    /// Minimal number of word occurrences to keep a word.
    pub min_count: i32,
    /// Minimal number of label occurrences to keep a label.
    pub min_count_label: i32,
    /// Number of negatives sampled (negative-sampling loss).
    pub neg: i32,
    /// Maximum length of word n-grams.
    pub word_ngrams: i32,
    /// Loss function.
    pub loss: LossName,
    /// Model architecture.
    pub model: ModelName,
    /// Number of hash buckets for n-gram features.
    pub bucket: i32,
    /// Minimum length of character n-grams.
    pub minn: i32,
    /// Maximum length of character n-grams.
    pub maxn: i32,
    /// Number of worker threads.
    pub thread: i32,
    /// Sub-sampling threshold.
    pub t: f64,
    /// Label prefix (defaults to "__label__").
    pub label: String,
    /// Verbosity level.
    pub verbose: i32,
    /// Path to pretrained word vectors used to initialize the input layer.
    pub pretrained_vectors: PathBuf,
    /// Whether to also save the output parameters.
    pub save_output: bool,
    /// RNG seed.
    pub seed: i32,
    /// Quantization: quantize the classifier output layer too.
    pub qout: bool,
    /// Quantization: finetune embeddings after quantization.
    pub retrain: bool,
    /// Quantization: quantize vector norms separately.
    pub qnorm: bool,
    /// Quantization: number of words/n-grams to retain (0 = keep all).
    pub cutoff: usize,
    /// Quantization: size of each product-quantizer sub-vector.
    pub dsub: usize,
    /// Validation file; a non-empty path enables autotune (`has_autotune`).
    pub autotune_validation_file: PathBuf,
    /// Metric specification string, parsed by `get_autotune_metric_name`.
    pub autotune_metric: String,
    /// Number of predictions used when evaluating the autotune metric.
    pub autotune_predictions: i32,
    /// Autotune time budget in seconds.
    pub autotune_duration: i32,
    /// Autotune model-size constraint (e.g. "2M"); empty = unconstrained.
    pub autotune_model_size: String,
}
140
141impl Default for Args {
142    fn default() -> Self {
143        Args {
144            input: PathBuf::new(),
145            output: PathBuf::new(),
146            lr: 0.05,
147            lr_update_rate: 100,
148            dim: 100,
149            ws: 5,
150            epoch: 5,
151            min_count: 5,
152            min_count_label: 0,
153            neg: 5,
154            word_ngrams: 1,
155            loss: LossName::NegativeSampling,
156            model: ModelName::SkipGram,
157            bucket: 2_000_000,
158            minn: 3,
159            maxn: 6,
160            thread: 12,
161            t: 1e-4,
162            label: "__label__".to_string(),
163            verbose: 2,
164            pretrained_vectors: PathBuf::new(),
165            save_output: false,
166            seed: 0,
167
168            qout: false,
169            retrain: false,
170            qnorm: false,
171            cutoff: 0,
172            dsub: 2,
173
174            autotune_validation_file: PathBuf::new(),
175            autotune_metric: "f1".to_string(),
176            autotune_predictions: 1,
177            autotune_duration: 300,
178            autotune_model_size: String::new(),
179        }
180    }
181}
182
183impl Args {
184    /// Create a new Args with default values.
185    pub fn new() -> Self {
186        Self::default()
187    }
188
189    /// Returns true if autotune is enabled (validation file is non-empty).
190    pub fn has_autotune(&self) -> bool {
191        !self.autotune_validation_file.as_os_str().is_empty()
192    }
193
194    /// Apply supervised mode overrides.
195    ///
196    /// Sets model=Supervised, loss=Softmax, minCount=1, minn=0, maxn=0, lr=0.1.
197    /// Also sets bucket=0 when wordNgrams<=1 and maxn==0 and autotune is not enabled.
198    pub fn apply_supervised_defaults(&mut self) {
199        self.model = ModelName::Supervised;
200        self.loss = LossName::Softmax;
201        self.min_count = 1;
202        self.minn = 0;
203        self.maxn = 0;
204        self.lr = 0.1;
205
206        if self.word_ngrams <= 1 && self.maxn == 0 && !self.has_autotune() {
207            self.bucket = 0;
208        }
209    }
210
211    //
212    // The binary format writes exactly 12 i32 fields + 1 f64 field = 56 bytes total.
213    // Fields in order (matching C++ Args::save/load):
214    //   dim, ws, epoch, minCount, neg, wordNgrams, loss, model, bucket, minn, maxn, lrUpdateRate, t
215    //
216    // Note: loss and model are written as their i32 discriminant values.
217
218    /// Save the 13-field Args block to a writer (56 bytes).
219    pub fn save<W: Write>(&self, writer: &mut W) -> Result<()> {
220        utils::write_i32(writer, self.dim)?;
221        utils::write_i32(writer, self.ws)?;
222        utils::write_i32(writer, self.epoch)?;
223        utils::write_i32(writer, self.min_count)?;
224        utils::write_i32(writer, self.neg)?;
225        utils::write_i32(writer, self.word_ngrams)?;
226        utils::write_i32(writer, self.loss as i32)?;
227        utils::write_i32(writer, self.model as i32)?;
228        utils::write_i32(writer, self.bucket)?;
229        utils::write_i32(writer, self.minn)?;
230        utils::write_i32(writer, self.maxn)?;
231        utils::write_i32(writer, self.lr_update_rate)?;
232        utils::write_f64(writer, self.t)?;
233        Ok(())
234    }
235
236    /// Load the 13-field Args block from a reader (56 bytes).
237    pub fn load<R: Read>(&mut self, reader: &mut R) -> Result<()> {
238        self.dim = utils::read_i32(reader)?;
239        self.ws = utils::read_i32(reader)?;
240        self.epoch = utils::read_i32(reader)?;
241        self.min_count = utils::read_i32(reader)?;
242        self.neg = utils::read_i32(reader)?;
243        self.word_ngrams = utils::read_i32(reader)?;
244        let loss_val = utils::read_i32(reader)?;
245        self.loss = LossName::try_from(loss_val).map_err(|v| {
246            crate::error::FastTextError::InvalidModel(format!("Invalid loss value: {}", v))
247        })?;
248        let model_val = utils::read_i32(reader)?;
249        self.model = ModelName::try_from(model_val).map_err(|v| {
250            crate::error::FastTextError::InvalidModel(format!("Invalid model value: {}", v))
251        })?;
252        self.bucket = utils::read_i32(reader)?;
253        self.minn = utils::read_i32(reader)?;
254        self.maxn = utils::read_i32(reader)?;
255        self.lr_update_rate = utils::read_i32(reader)?;
256        self.t = utils::read_f64(reader)?;
257        Ok(())
258    }
259
260    /// Parse the autotune metric string and return the corresponding MetricName.
261    pub fn get_autotune_metric_name(&self) -> Option<MetricName> {
262        if self.autotune_metric.starts_with("f1:") {
263            Some(MetricName::LabelF1Score)
264        } else if self.autotune_metric == "f1" {
265            Some(MetricName::F1Score)
266        } else if self.autotune_metric.starts_with("precisionAtRecall:") {
267            let rest = &self.autotune_metric[18..];
268            if rest.contains(':') {
269                Some(MetricName::PrecisionAtRecallLabel)
270            } else {
271                Some(MetricName::PrecisionAtRecall)
272            }
273        } else if self.autotune_metric.starts_with("recallAtPrecision:") {
274            let rest = &self.autotune_metric[18..];
275            if rest.contains(':') {
276                Some(MetricName::RecallAtPrecisionLabel)
277            } else {
278                Some(MetricName::RecallAtPrecision)
279            }
280        } else {
281            None
282        }
283    }
284}