1use std::convert::TryFrom;
4use std::fmt;
5use std::io::{Read, Write};
6use std::path::PathBuf;
7
8use crate::error::Result;
9use crate::utils;
10
/// Training architecture (cbow / skipgram / supervised).
///
/// Discriminants are the i32 codes written to / read from the binary
/// model format (see `Args::save` and `Args::load`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum ModelName {
    Cbow = 1,
    SkipGram = 2,
    Supervised = 3,
}
22
/// Loss function used during training.
///
/// Discriminants are the i32 codes written to / read from the binary
/// model format (see `Args::save` and `Args::load`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum LossName {
    HierarchicalSoftmax = 1,
    NegativeSampling = 2,
    Softmax = 3,
    OneVsAll = 4,
}
36
/// Metric optimized during hyper-parameter autotuning.
///
/// Selected from the `autotune_metric` string by
/// `Args::get_autotune_metric_name`. The `…Label` variants are the
/// per-label versions of their counterparts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum MetricName {
    F1Score = 1,
    LabelF1Score = 2,
    PrecisionAtRecall = 3,
    PrecisionAtRecallLabel = 4,
    RecallAtPrecision = 5,
    RecallAtPrecisionLabel = 6,
}
54
55impl TryFrom<i32> for ModelName {
56 type Error = i32;
57
58 fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
59 match value {
60 1 => Ok(ModelName::Cbow),
61 2 => Ok(ModelName::SkipGram),
62 3 => Ok(ModelName::Supervised),
63 _ => Err(value),
64 }
65 }
66}
67
68impl fmt::Display for ModelName {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 match self {
71 ModelName::Cbow => write!(f, "cbow"),
72 ModelName::SkipGram => write!(f, "sg"),
73 ModelName::Supervised => write!(f, "sup"),
74 }
75 }
76}
77
78impl TryFrom<i32> for LossName {
79 type Error = i32;
80
81 fn try_from(value: i32) -> std::result::Result<Self, Self::Error> {
82 match value {
83 1 => Ok(LossName::HierarchicalSoftmax),
84 2 => Ok(LossName::NegativeSampling),
85 3 => Ok(LossName::Softmax),
86 4 => Ok(LossName::OneVsAll),
87 _ => Err(value),
88 }
89 }
90}
91
92impl fmt::Display for LossName {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 match self {
95 LossName::HierarchicalSoftmax => write!(f, "hs"),
96 LossName::NegativeSampling => write!(f, "ns"),
97 LossName::Softmax => write!(f, "softmax"),
98 LossName::OneVsAll => write!(f, "one-vs-all"),
99 }
100 }
101}
102
/// Training, quantization, and autotune hyper-parameters.
///
/// Field meanings follow the fastText command-line options of the same
/// name (see the fastText CLI documentation); defaults match
/// `Args::default()`.
#[derive(Debug, Clone)]
pub struct Args {
    // Input / output paths.
    pub input: PathBuf,
    pub output: PathBuf,
    // Core training hyper-parameters.
    pub lr: f64,            // learning rate
    pub lr_update_rate: i32, // tokens processed between lr updates
    pub dim: i32,           // embedding dimensionality
    pub ws: i32,            // context window size
    pub epoch: i32,
    pub min_count: i32,       // minimum word occurrence count
    pub min_count_label: i32, // minimum label occurrence count
    pub neg: i32,             // number of negative samples
    pub word_ngrams: i32,     // max word n-gram length
    pub loss: LossName,
    pub model: ModelName,
    pub bucket: i32, // number of hash buckets for n-gram features
    pub minn: i32,   // min character n-gram length
    pub maxn: i32,   // max character n-gram length
    pub thread: i32,
    pub t: f64, // sampling threshold
    pub label: String, // label prefix, e.g. "__label__"
    pub verbose: i32,
    pub pretrained_vectors: PathBuf,
    pub save_output: bool,
    pub seed: i32,
    // Quantization options (presumably mirroring fastText's quantize
    // command — confirm against the quantization code that reads them).
    pub qout: bool,
    pub retrain: bool,
    pub qnorm: bool,
    pub cutoff: usize,
    pub dsub: usize,
    // Autotune options; an empty validation file path means autotune is
    // disabled (see `has_autotune`).
    pub autotune_validation_file: PathBuf,
    pub autotune_metric: String, // e.g. "f1", "f1:<label>", "precisionAtRecall:…"
    pub autotune_predictions: i32,
    pub autotune_duration: i32, // seconds
    pub autotune_model_size: String,
}
140
impl Default for Args {
    /// Defaults matching the fastText CLI for unsupervised training
    /// (skipgram model with negative-sampling loss).
    fn default() -> Self {
        Args {
            input: PathBuf::new(),
            output: PathBuf::new(),
            lr: 0.05,
            lr_update_rate: 100,
            dim: 100,
            ws: 5,
            epoch: 5,
            min_count: 5,
            min_count_label: 0,
            neg: 5,
            word_ngrams: 1,
            loss: LossName::NegativeSampling,
            model: ModelName::SkipGram,
            bucket: 2_000_000,
            minn: 3,
            maxn: 6,
            thread: 12,
            t: 1e-4,
            label: "__label__".to_string(),
            verbose: 2,
            pretrained_vectors: PathBuf::new(),
            save_output: false,
            seed: 0,

            // Quantization defaults (cutoff 0 = no cutoff).
            qout: false,
            retrain: false,
            qnorm: false,
            cutoff: 0,
            dsub: 2,

            // Autotune defaults; empty validation file = autotune off.
            autotune_validation_file: PathBuf::new(),
            autotune_metric: "f1".to_string(),
            autotune_predictions: 1,
            autotune_duration: 300,
            autotune_model_size: String::new(),
        }
    }
}
182
183impl Args {
184 pub fn new() -> Self {
186 Self::default()
187 }
188
189 pub fn has_autotune(&self) -> bool {
191 !self.autotune_validation_file.as_os_str().is_empty()
192 }
193
194 pub fn apply_supervised_defaults(&mut self) {
199 self.model = ModelName::Supervised;
200 self.loss = LossName::Softmax;
201 self.min_count = 1;
202 self.minn = 0;
203 self.maxn = 0;
204 self.lr = 0.1;
205
206 if self.word_ngrams <= 1 && self.maxn == 0 && !self.has_autotune() {
207 self.bucket = 0;
208 }
209 }
210
211 pub fn save<W: Write>(&self, writer: &mut W) -> Result<()> {
220 utils::write_i32(writer, self.dim)?;
221 utils::write_i32(writer, self.ws)?;
222 utils::write_i32(writer, self.epoch)?;
223 utils::write_i32(writer, self.min_count)?;
224 utils::write_i32(writer, self.neg)?;
225 utils::write_i32(writer, self.word_ngrams)?;
226 utils::write_i32(writer, self.loss as i32)?;
227 utils::write_i32(writer, self.model as i32)?;
228 utils::write_i32(writer, self.bucket)?;
229 utils::write_i32(writer, self.minn)?;
230 utils::write_i32(writer, self.maxn)?;
231 utils::write_i32(writer, self.lr_update_rate)?;
232 utils::write_f64(writer, self.t)?;
233 Ok(())
234 }
235
236 pub fn load<R: Read>(&mut self, reader: &mut R) -> Result<()> {
238 self.dim = utils::read_i32(reader)?;
239 self.ws = utils::read_i32(reader)?;
240 self.epoch = utils::read_i32(reader)?;
241 self.min_count = utils::read_i32(reader)?;
242 self.neg = utils::read_i32(reader)?;
243 self.word_ngrams = utils::read_i32(reader)?;
244 let loss_val = utils::read_i32(reader)?;
245 self.loss = LossName::try_from(loss_val).map_err(|v| {
246 crate::error::FastTextError::InvalidModel(format!("Invalid loss value: {}", v))
247 })?;
248 let model_val = utils::read_i32(reader)?;
249 self.model = ModelName::try_from(model_val).map_err(|v| {
250 crate::error::FastTextError::InvalidModel(format!("Invalid model value: {}", v))
251 })?;
252 self.bucket = utils::read_i32(reader)?;
253 self.minn = utils::read_i32(reader)?;
254 self.maxn = utils::read_i32(reader)?;
255 self.lr_update_rate = utils::read_i32(reader)?;
256 self.t = utils::read_f64(reader)?;
257 Ok(())
258 }
259
260 pub fn get_autotune_metric_name(&self) -> Option<MetricName> {
262 if self.autotune_metric.starts_with("f1:") {
263 Some(MetricName::LabelF1Score)
264 } else if self.autotune_metric == "f1" {
265 Some(MetricName::F1Score)
266 } else if self.autotune_metric.starts_with("precisionAtRecall:") {
267 let rest = &self.autotune_metric[18..];
268 if rest.contains(':') {
269 Some(MetricName::PrecisionAtRecallLabel)
270 } else {
271 Some(MetricName::PrecisionAtRecall)
272 }
273 } else if self.autotune_metric.starts_with("recallAtPrecision:") {
274 let rest = &self.autotune_metric[18..];
275 if rest.contains(':') {
276 Some(MetricName::RecallAtPrecisionLabel)
277 } else {
278 Some(MetricName::RecallAtPrecision)
279 }
280 } else {
281 None
282 }
283 }
284}