1use std::thread;
2use std::time::Duration;
3
4use clap::{App, AppSettings, Arg, ArgMatches};
5use indicatif::{ProgressBar, ProgressStyle};
6use stdinout::OrExit;
7
8use finalfrontier::{
9 CommonConfig, DepembedsConfig, LossType, ModelType, SimpleVocabConfig, SkipGramConfig,
10 SubwordVocabConfig, Trainer, Vocab, SGD,
11};
12
// Clap settings applied to every finalfrontier command-line application.
static DEFAULT_CLAP_SETTINGS: &[AppSettings] = &[
    AppSettings::DontCollapseArgsInUsage,
    AppSettings::UnifiedHelpMessage,
];

// Names under which options are registered and looked up in `ArgMatches`.
// NOTE(review): for most options the long flag equals the constant's value,
// but UNTYPED_DEPS ("untyped") and NORMALIZE_CONTEXT ("normalize") are
// registered with different long flags ("untyped_deps", "normalize_context").
static BUCKETS: &str = "buckets";
static CONTEXT: &str = "context";
static CONTEXT_MINCOUNT: &str = "context_mincount";
static CONTEXT_DISCARD: &str = "context_discard";
static DEPENDENCY_DEPTH: &str = "dependency_depth";
static DIMS: &str = "dims";
static DISCARD: &str = "discard";
static EPOCHS: &str = "epochs";
static LR: &str = "lr";
static MINCOUNT: &str = "mincount";
static MINN: &str = "minn";
static MAXN: &str = "maxn";
static MODEL: &str = "model";
static UNTYPED_DEPS: &str = "untyped";
static NORMALIZE_CONTEXT: &str = "normalize";
static NS: &str = "ns";
static PROJECTIVIZE: &str = "projectivize";
static THREADS: &str = "threads";
static USE_ROOT: &str = "use_root";
static ZIPF_EXPONENT: &str = "zipf";

// Names of the required positional arguments.
static CORPUS: &str = "CORPUS";
static OUTPUT: &str = "OUTPUT";
43
/// Command-line application state for skip-gram-style training.
///
/// Constructed via [`SkipGramApp::new`], which parses the process'
/// command-line arguments.
pub struct SkipGramApp {
    // Path of the training corpus (CORPUS positional argument).
    corpus: String,
    // Path the embeddings are written to (OUTPUT positional argument).
    output: String,
    // Number of training threads.
    n_threads: usize,
    // Configuration shared by all model types.
    common_config: CommonConfig,
    // Skip-gram-specific configuration.
    skipgram_config: SkipGramConfig,
    // Input (subword) vocabulary configuration.
    vocab_config: SubwordVocabConfig,
}
53
impl Default for SkipGramApp {
    /// Delegates to [`SkipGramApp::new`], which parses the process'
    /// command-line arguments (and may exit on invalid input).
    fn default() -> Self {
        Self::new()
    }
}
59
60impl SkipGramApp {
61 pub fn new() -> Self {
63 let matches = build_with_common_opts("ff-train-skipgram")
64 .arg(
65 Arg::with_name(CONTEXT)
66 .long("context")
67 .value_name("CONTEXT_SIZE")
68 .help("Context size")
69 .takes_value(true)
70 .default_value("10"),
71 )
72 .arg(
73 Arg::with_name(MODEL)
74 .long(MODEL)
75 .value_name("MODEL")
76 .help("Model")
77 .takes_value(true)
78 .possible_values(&["dirgram", "skipgram", "structgram"])
79 .default_value("skipgram"),
80 )
81 .get_matches();
82 let corpus = matches.value_of(CORPUS).unwrap().into();
83 let output = matches.value_of(OUTPUT).unwrap().into();
84 let n_threads = matches
85 .value_of("threads")
86 .map(|v| v.parse().or_exit("Cannot parse number of threads", 1))
87 .unwrap_or(num_cpus::get() / 2);
88 SkipGramApp {
89 corpus,
90 output,
91 n_threads,
92 common_config: common_config_from_matches(&matches),
93 skipgram_config: Self::skipgram_config_from_matches(&matches),
94 vocab_config: subword_config_from_matches(&matches),
95 }
96 }
97
98 pub fn corpus(&self) -> &str {
100 self.corpus.as_str()
101 }
102
103 pub fn output(&self) -> &str {
105 self.output.as_str()
106 }
107
108 pub fn n_threads(&self) -> usize {
110 self.n_threads
111 }
112
113 pub fn common_config(&self) -> CommonConfig {
115 self.common_config
116 }
117
118 pub fn skipgram_config(&self) -> SkipGramConfig {
120 self.skipgram_config
121 }
122
123 pub fn vocab_config(&self) -> SubwordVocabConfig {
125 self.vocab_config
126 }
127
128 fn skipgram_config_from_matches(matches: &ArgMatches) -> SkipGramConfig {
129 let context_size = matches
130 .value_of(CONTEXT)
131 .map(|v| v.parse().or_exit("Cannot parse context size", 1))
132 .unwrap();
133 let model = matches
134 .value_of(MODEL)
135 .map(|v| ModelType::try_from_str(v).or_exit("Cannot parse model type", 1))
136 .unwrap();
137
138 SkipGramConfig {
139 context_size,
140 model,
141 }
142 }
143}
144
/// Command-line application state for dependency-embeddings training.
///
/// Constructed via [`DepembedsApp::new`], which parses the process'
/// command-line arguments.
pub struct DepembedsApp {
    // Path of the training corpus (CORPUS positional argument).
    corpus: String,
    // Path the embeddings are written to (OUTPUT positional argument).
    output: String,
    // Number of training threads.
    n_threads: usize,
    // Configuration shared by all model types.
    common_config: CommonConfig,
    // Dependency-embeddings-specific configuration.
    depembeds_config: DepembedsConfig,
    // Input (subword) vocabulary configuration.
    input_vocab_config: SubwordVocabConfig,
    // Output (context) vocabulary configuration.
    output_vocab_config: SimpleVocabConfig,
}
155
impl Default for DepembedsApp {
    /// Delegates to [`DepembedsApp::new`], which parses the process'
    /// command-line arguments (and may exit on invalid input).
    fn default() -> Self {
        Self::new()
    }
}
161
162impl DepembedsApp {
163 pub fn new() -> Self {
165 let matches =
166 Self::add_depembeds_opts(build_with_common_opts("ff-train-deps")).get_matches();
167 let corpus = matches.value_of(CORPUS).unwrap().into();
168 let output = matches.value_of(OUTPUT).unwrap().into();
169 let n_threads = matches
170 .value_of("threads")
171 .map(|v| v.parse().or_exit("Cannot parse number of threads", 1))
172 .unwrap_or(num_cpus::get() / 2);
173
174 let discard_threshold = matches
175 .value_of(CONTEXT_DISCARD)
176 .map(|v| v.parse().or_exit("Cannot parse discard threshold", 1))
177 .unwrap();
178 let min_count = matches
179 .value_of(CONTEXT_MINCOUNT)
180 .map(|v| v.parse().or_exit("Cannot parse mincount", 1))
181 .unwrap();
182
183 let output_vocab_config = SimpleVocabConfig {
184 min_count,
185 discard_threshold,
186 };
187
188 DepembedsApp {
189 corpus,
190 output,
191 n_threads,
192 common_config: common_config_from_matches(&matches),
193 depembeds_config: Self::depembeds_config_from_matches(&matches),
194 input_vocab_config: subword_config_from_matches(&matches),
195 output_vocab_config,
196 }
197 }
198
199 pub fn corpus(&self) -> &str {
201 self.corpus.as_str()
202 }
203
204 pub fn output(&self) -> &str {
206 self.output.as_str()
207 }
208
209 pub fn n_threads(&self) -> usize {
211 self.n_threads
212 }
213
214 pub fn common_config(&self) -> CommonConfig {
216 self.common_config
217 }
218
219 pub fn depembeds_config(&self) -> DepembedsConfig {
221 self.depembeds_config
222 }
223
224 pub fn input_vocab_config(&self) -> SubwordVocabConfig {
226 self.input_vocab_config
227 }
228
229 pub fn output_vocab_config(&self) -> SimpleVocabConfig {
231 self.output_vocab_config
232 }
233
234 fn add_depembeds_opts<'a, 'b>(app: App<'a, 'b>) -> App<'a, 'b> {
235 app.arg(
236 Arg::with_name(CONTEXT_DISCARD)
237 .long("context_discard")
238 .value_name("CONTEXT_THRESHOLD")
239 .help("Context discard threshold")
240 .takes_value(true)
241 .default_value("1e-4"),
242 )
243 .arg(
244 Arg::with_name(CONTEXT_MINCOUNT)
245 .long("context_mincount")
246 .value_name("CONTEXT_FREQ")
247 .help("Context mincount")
248 .takes_value(true)
249 .default_value("5"),
250 )
251 .arg(
252 Arg::with_name(DEPENDENCY_DEPTH)
253 .long("dependency_depth")
254 .value_name("DEPENDENCY_DEPTH")
255 .help("Dependency depth")
256 .takes_value(true)
257 .default_value("1"),
258 )
259 .arg(
260 Arg::with_name(UNTYPED_DEPS)
261 .long("untyped_deps")
262 .help("Don't use dependency relation labels."),
263 )
264 .arg(
265 Arg::with_name(NORMALIZE_CONTEXT)
266 .long("normalize_context")
267 .help("Normalize contexts"),
268 )
269 .arg(
270 Arg::with_name(PROJECTIVIZE)
271 .long("projectivize")
272 .help("Projectivize dependency graphs before training."),
273 )
274 .arg(
275 Arg::with_name(USE_ROOT)
276 .long("use_root")
277 .help("Use root when extracting dependency contexts."),
278 )
279 }
280
281 fn depembeds_config_from_matches(matches: &ArgMatches) -> DepembedsConfig {
282 let depth = matches
283 .value_of(DEPENDENCY_DEPTH)
284 .map(|v| v.parse().or_exit("Cannot parse dependency depth", 1))
285 .unwrap();
286 let untyped = matches.is_present(UNTYPED_DEPS);
287 let normalize = matches.is_present(NORMALIZE_CONTEXT);
288 let projectivize = matches.is_present(PROJECTIVIZE);
289 let use_root = matches.is_present(USE_ROOT);
290 DepembedsConfig {
291 depth,
292 untyped,
293 normalize,
294 projectivize,
295 use_root,
296 }
297 }
298}
299
300fn build_with_common_opts<'a, 'b>(name: &str) -> App<'a, 'b> {
301 App::new(name)
302 .settings(DEFAULT_CLAP_SETTINGS)
303 .arg(
304 Arg::with_name(BUCKETS)
305 .long("buckets")
306 .value_name("EXP")
307 .help("Number of buckets: 2^EXP")
308 .takes_value(true)
309 .default_value("21"),
310 )
311 .arg(
312 Arg::with_name(DIMS)
313 .long("dims")
314 .value_name("DIMENSIONS")
315 .help("Embedding dimensionality")
316 .takes_value(true)
317 .default_value("300"),
318 )
319 .arg(
320 Arg::with_name(DISCARD)
321 .long("discard")
322 .value_name("THRESHOLD")
323 .help("Discard threshold")
324 .takes_value(true)
325 .default_value("1e-4"),
326 )
327 .arg(
328 Arg::with_name(EPOCHS)
329 .long("epochs")
330 .value_name("N")
331 .help("Number of epochs")
332 .takes_value(true)
333 .default_value("15"),
334 )
335 .arg(
336 Arg::with_name(LR)
337 .long("lr")
338 .value_name("LEARNING_RATE")
339 .help("Initial learning rate")
340 .takes_value(true)
341 .default_value("0.05"),
342 )
343 .arg(
344 Arg::with_name(MINCOUNT)
345 .long("mincount")
346 .value_name("FREQ")
347 .help("Minimum token frequency")
348 .takes_value(true)
349 .default_value("5"),
350 )
351 .arg(
352 Arg::with_name(MINN)
353 .long("minn")
354 .value_name("LEN")
355 .help("Minimum ngram length")
356 .takes_value(true)
357 .default_value("3"),
358 )
359 .arg(
360 Arg::with_name(MAXN)
361 .long("maxn")
362 .value_name("LEN")
363 .help("Maximum ngram length")
364 .takes_value(true)
365 .default_value("6"),
366 )
367 .arg(
368 Arg::with_name(NS)
369 .long("ns")
370 .value_name("FREQ")
371 .help("Negative samples per word")
372 .takes_value(true)
373 .default_value("5"),
374 )
375 .arg(
376 Arg::with_name(THREADS)
377 .long("threads")
378 .value_name("N")
379 .help("Number of threads (default: logical_cpus / 2)")
380 .takes_value(true),
381 )
382 .arg(
383 Arg::with_name(ZIPF_EXPONENT)
384 .long("zipf")
385 .value_name("EXP")
386 .help("Exponent Zipf distribution for negative sampling")
387 .takes_value(true)
388 .default_value("0.5"),
389 )
390 .arg(
391 Arg::with_name(CORPUS)
392 .help("Tokenized corpus")
393 .index(1)
394 .required(true),
395 )
396 .arg(
397 Arg::with_name(OUTPUT)
398 .help("Embeddings output")
399 .index(2)
400 .required(true),
401 )
402}
403
404fn common_config_from_matches(matches: &ArgMatches) -> CommonConfig {
406 let dims = matches
407 .value_of(DIMS)
408 .map(|v| v.parse().or_exit("Cannot parse dimensionality", 1))
409 .unwrap();
410 let epochs = matches
411 .value_of(EPOCHS)
412 .map(|v| v.parse().or_exit("Cannot parse number of epochs", 1))
413 .unwrap();
414 let lr = matches
415 .value_of(LR)
416 .map(|v| v.parse().or_exit("Cannot parse learning rate", 1))
417 .unwrap();
418 let negative_samples = matches
419 .value_of(NS)
420 .map(|v| {
421 v.parse()
422 .or_exit("Cannot parse number of negative samples", 1)
423 })
424 .unwrap();
425 let zipf_exponent = matches
426 .value_of(ZIPF_EXPONENT)
427 .map(|v| {
428 v.parse()
429 .or_exit("Cannot parse exponent zipf distribution", 1)
430 })
431 .unwrap();
432
433 CommonConfig {
434 loss: LossType::LogisticNegativeSampling,
435 dims,
436 epochs,
437 lr,
438 negative_samples,
439 zipf_exponent,
440 }
441}
442
443fn subword_config_from_matches(matches: &ArgMatches) -> SubwordVocabConfig {
445 let buckets_exp = matches
446 .value_of(BUCKETS)
447 .map(|v| v.parse().or_exit("Cannot parse bucket exponent", 1))
448 .unwrap();
449 let discard_threshold = matches
450 .value_of(DISCARD)
451 .map(|v| v.parse().or_exit("Cannot parse discard threshold", 1))
452 .unwrap();
453 let min_count = matches
454 .value_of(MINCOUNT)
455 .map(|v| v.parse().or_exit("Cannot parse mincount", 1))
456 .unwrap();
457 let min_n = matches
458 .value_of(MINN)
459 .map(|v| v.parse().or_exit("Cannot parse minimum n-gram length", 1))
460 .unwrap();
461 let max_n = matches
462 .value_of(MAXN)
463 .map(|v| v.parse().or_exit("Cannot parse maximum n-gram length", 1))
464 .unwrap();
465
466 SubwordVocabConfig {
467 min_n,
468 max_n,
469 buckets_exp,
470 min_count,
471 discard_threshold,
472 }
473}
474
475pub fn show_progress<T, V>(config: &CommonConfig, sgd: &SGD<T>, update_interval: Duration)
476where
477 T: Trainer<InputVocab = V>,
478 V: Vocab,
479{
480 let n_tokens = sgd.model().input_vocab().n_types();
481
482 let pb = ProgressBar::new(u64::from(config.epochs) * n_tokens as u64);
483 pb.set_style(
484 ProgressStyle::default_bar().template("{bar:40} {percent}% {msg} ETA: {eta_precise}"),
485 );
486
487 while sgd.n_tokens_processed() < n_tokens * config.epochs as usize {
488 let lr = (1.0
489 - (sgd.n_tokens_processed() as f32 / (config.epochs as usize * n_tokens) as f32))
490 * config.lr;
491
492 pb.set_position(sgd.n_tokens_processed() as u64);
493 pb.set_message(&format!(
494 "loss: {:.*} lr: {:.*}",
495 5,
496 sgd.train_loss(),
497 5,
498 lr
499 ));
500
501 thread::sleep(update_interval);
502 }
503
504 pb.finish();
505}