finalfrontier_utils/
util.rs

1use std::thread;
2use std::time::Duration;
3
4use clap::{App, AppSettings, Arg, ArgMatches};
5use indicatif::{ProgressBar, ProgressStyle};
6use stdinout::OrExit;
7
8use finalfrontier::{
9    CommonConfig, DepembedsConfig, LossType, ModelType, SimpleVocabConfig, SkipGramConfig,
10    SubwordVocabConfig, Trainer, Vocab, SGD,
11};
12
13static DEFAULT_CLAP_SETTINGS: &[AppSettings] = &[
14    AppSettings::DontCollapseArgsInUsage,
15    AppSettings::UnifiedHelpMessage,
16];
17
// Clap keys of the command-line options. These are the identifiers passed to
// `Arg::with_name` and looked up with `value_of`/`is_present`; they are not
// always spelled like the long flag (e.g. `UNTYPED_DEPS` is registered as
// `--untyped_deps`).
static BUCKETS: &str = "buckets";
static CONTEXT: &str = "context";
static CONTEXT_DISCARD: &str = "context_discard";
static CONTEXT_MINCOUNT: &str = "context_mincount";
static DEPENDENCY_DEPTH: &str = "dependency_depth";
static DIMS: &str = "dims";
static DISCARD: &str = "discard";
static EPOCHS: &str = "epochs";
static LR: &str = "lr";
static MAXN: &str = "maxn";
static MINCOUNT: &str = "mincount";
static MINN: &str = "minn";
static MODEL: &str = "model";
static NORMALIZE_CONTEXT: &str = "normalize";
static NS: &str = "ns";
static PROJECTIVIZE: &str = "projectivize";
static THREADS: &str = "threads";
static UNTYPED_DEPS: &str = "untyped";
static USE_ROOT: &str = "use_root";
static ZIPF_EXPONENT: &str = "zipf";

// Clap keys of the positional arguments.
static CORPUS: &str = "CORPUS";
static OUTPUT: &str = "OUTPUT";
43
44/// SkipGramApp.
45pub struct SkipGramApp {
46    corpus: String,
47    output: String,
48    n_threads: usize,
49    common_config: CommonConfig,
50    skipgram_config: SkipGramConfig,
51    vocab_config: SubwordVocabConfig,
52}
53
54impl Default for SkipGramApp {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl SkipGramApp {
61    /// Construct new `SkipGramApp`.
62    pub fn new() -> Self {
63        let matches = build_with_common_opts("ff-train-skipgram")
64            .arg(
65                Arg::with_name(CONTEXT)
66                    .long("context")
67                    .value_name("CONTEXT_SIZE")
68                    .help("Context size")
69                    .takes_value(true)
70                    .default_value("10"),
71            )
72            .arg(
73                Arg::with_name(MODEL)
74                    .long(MODEL)
75                    .value_name("MODEL")
76                    .help("Model")
77                    .takes_value(true)
78                    .possible_values(&["dirgram", "skipgram", "structgram"])
79                    .default_value("skipgram"),
80            )
81            .get_matches();
82        let corpus = matches.value_of(CORPUS).unwrap().into();
83        let output = matches.value_of(OUTPUT).unwrap().into();
84        let n_threads = matches
85            .value_of("threads")
86            .map(|v| v.parse().or_exit("Cannot parse number of threads", 1))
87            .unwrap_or(num_cpus::get() / 2);
88        SkipGramApp {
89            corpus,
90            output,
91            n_threads,
92            common_config: common_config_from_matches(&matches),
93            skipgram_config: Self::skipgram_config_from_matches(&matches),
94            vocab_config: subword_config_from_matches(&matches),
95        }
96    }
97
98    /// Get the corpus path.
99    pub fn corpus(&self) -> &str {
100        self.corpus.as_str()
101    }
102
103    /// Get the output path.
104    pub fn output(&self) -> &str {
105        self.output.as_str()
106    }
107
108    /// Get the number of threads.
109    pub fn n_threads(&self) -> usize {
110        self.n_threads
111    }
112
113    /// Get the common config.
114    pub fn common_config(&self) -> CommonConfig {
115        self.common_config
116    }
117
118    /// Get the skipgram config.
119    pub fn skipgram_config(&self) -> SkipGramConfig {
120        self.skipgram_config
121    }
122
123    /// Get the vocab config.
124    pub fn vocab_config(&self) -> SubwordVocabConfig {
125        self.vocab_config
126    }
127
128    fn skipgram_config_from_matches(matches: &ArgMatches) -> SkipGramConfig {
129        let context_size = matches
130            .value_of(CONTEXT)
131            .map(|v| v.parse().or_exit("Cannot parse context size", 1))
132            .unwrap();
133        let model = matches
134            .value_of(MODEL)
135            .map(|v| ModelType::try_from_str(v).or_exit("Cannot parse model type", 1))
136            .unwrap();
137
138        SkipGramConfig {
139            context_size,
140            model,
141        }
142    }
143}
144
145/// DepembedsApp.
146pub struct DepembedsApp {
147    corpus: String,
148    output: String,
149    n_threads: usize,
150    common_config: CommonConfig,
151    depembeds_config: DepembedsConfig,
152    input_vocab_config: SubwordVocabConfig,
153    output_vocab_config: SimpleVocabConfig,
154}
155
156impl Default for DepembedsApp {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162impl DepembedsApp {
163    /// Construct a new `DepembedsApp`.
164    pub fn new() -> Self {
165        let matches =
166            Self::add_depembeds_opts(build_with_common_opts("ff-train-deps")).get_matches();
167        let corpus = matches.value_of(CORPUS).unwrap().into();
168        let output = matches.value_of(OUTPUT).unwrap().into();
169        let n_threads = matches
170            .value_of("threads")
171            .map(|v| v.parse().or_exit("Cannot parse number of threads", 1))
172            .unwrap_or(num_cpus::get() / 2);
173
174        let discard_threshold = matches
175            .value_of(CONTEXT_DISCARD)
176            .map(|v| v.parse().or_exit("Cannot parse discard threshold", 1))
177            .unwrap();
178        let min_count = matches
179            .value_of(CONTEXT_MINCOUNT)
180            .map(|v| v.parse().or_exit("Cannot parse mincount", 1))
181            .unwrap();
182
183        let output_vocab_config = SimpleVocabConfig {
184            min_count,
185            discard_threshold,
186        };
187
188        DepembedsApp {
189            corpus,
190            output,
191            n_threads,
192            common_config: common_config_from_matches(&matches),
193            depembeds_config: Self::depembeds_config_from_matches(&matches),
194            input_vocab_config: subword_config_from_matches(&matches),
195            output_vocab_config,
196        }
197    }
198
199    /// Get the corpus path.
200    pub fn corpus(&self) -> &str {
201        self.corpus.as_str()
202    }
203
204    /// Get the output path.
205    pub fn output(&self) -> &str {
206        self.output.as_str()
207    }
208
209    /// Get the number of threads.
210    pub fn n_threads(&self) -> usize {
211        self.n_threads
212    }
213
214    /// Get the common config.
215    pub fn common_config(&self) -> CommonConfig {
216        self.common_config
217    }
218
219    /// Get the depembeds config.
220    pub fn depembeds_config(&self) -> DepembedsConfig {
221        self.depembeds_config
222    }
223
224    /// Get the input vocab config.
225    pub fn input_vocab_config(&self) -> SubwordVocabConfig {
226        self.input_vocab_config
227    }
228
229    /// Get the output vocab config.
230    pub fn output_vocab_config(&self) -> SimpleVocabConfig {
231        self.output_vocab_config
232    }
233
234    fn add_depembeds_opts<'a, 'b>(app: App<'a, 'b>) -> App<'a, 'b> {
235        app.arg(
236            Arg::with_name(CONTEXT_DISCARD)
237                .long("context_discard")
238                .value_name("CONTEXT_THRESHOLD")
239                .help("Context discard threshold")
240                .takes_value(true)
241                .default_value("1e-4"),
242        )
243        .arg(
244            Arg::with_name(CONTEXT_MINCOUNT)
245                .long("context_mincount")
246                .value_name("CONTEXT_FREQ")
247                .help("Context mincount")
248                .takes_value(true)
249                .default_value("5"),
250        )
251        .arg(
252            Arg::with_name(DEPENDENCY_DEPTH)
253                .long("dependency_depth")
254                .value_name("DEPENDENCY_DEPTH")
255                .help("Dependency depth")
256                .takes_value(true)
257                .default_value("1"),
258        )
259        .arg(
260            Arg::with_name(UNTYPED_DEPS)
261                .long("untyped_deps")
262                .help("Don't use dependency relation labels."),
263        )
264        .arg(
265            Arg::with_name(NORMALIZE_CONTEXT)
266                .long("normalize_context")
267                .help("Normalize contexts"),
268        )
269        .arg(
270            Arg::with_name(PROJECTIVIZE)
271                .long("projectivize")
272                .help("Projectivize dependency graphs before training."),
273        )
274        .arg(
275            Arg::with_name(USE_ROOT)
276                .long("use_root")
277                .help("Use root when extracting dependency contexts."),
278        )
279    }
280
281    fn depembeds_config_from_matches(matches: &ArgMatches) -> DepembedsConfig {
282        let depth = matches
283            .value_of(DEPENDENCY_DEPTH)
284            .map(|v| v.parse().or_exit("Cannot parse dependency depth", 1))
285            .unwrap();
286        let untyped = matches.is_present(UNTYPED_DEPS);
287        let normalize = matches.is_present(NORMALIZE_CONTEXT);
288        let projectivize = matches.is_present(PROJECTIVIZE);
289        let use_root = matches.is_present(USE_ROOT);
290        DepembedsConfig {
291            depth,
292            untyped,
293            normalize,
294            projectivize,
295            use_root,
296        }
297    }
298}
299
300fn build_with_common_opts<'a, 'b>(name: &str) -> App<'a, 'b> {
301    App::new(name)
302        .settings(DEFAULT_CLAP_SETTINGS)
303        .arg(
304            Arg::with_name(BUCKETS)
305                .long("buckets")
306                .value_name("EXP")
307                .help("Number of buckets: 2^EXP")
308                .takes_value(true)
309                .default_value("21"),
310        )
311        .arg(
312            Arg::with_name(DIMS)
313                .long("dims")
314                .value_name("DIMENSIONS")
315                .help("Embedding dimensionality")
316                .takes_value(true)
317                .default_value("300"),
318        )
319        .arg(
320            Arg::with_name(DISCARD)
321                .long("discard")
322                .value_name("THRESHOLD")
323                .help("Discard threshold")
324                .takes_value(true)
325                .default_value("1e-4"),
326        )
327        .arg(
328            Arg::with_name(EPOCHS)
329                .long("epochs")
330                .value_name("N")
331                .help("Number of epochs")
332                .takes_value(true)
333                .default_value("15"),
334        )
335        .arg(
336            Arg::with_name(LR)
337                .long("lr")
338                .value_name("LEARNING_RATE")
339                .help("Initial learning rate")
340                .takes_value(true)
341                .default_value("0.05"),
342        )
343        .arg(
344            Arg::with_name(MINCOUNT)
345                .long("mincount")
346                .value_name("FREQ")
347                .help("Minimum token frequency")
348                .takes_value(true)
349                .default_value("5"),
350        )
351        .arg(
352            Arg::with_name(MINN)
353                .long("minn")
354                .value_name("LEN")
355                .help("Minimum ngram length")
356                .takes_value(true)
357                .default_value("3"),
358        )
359        .arg(
360            Arg::with_name(MAXN)
361                .long("maxn")
362                .value_name("LEN")
363                .help("Maximum ngram length")
364                .takes_value(true)
365                .default_value("6"),
366        )
367        .arg(
368            Arg::with_name(NS)
369                .long("ns")
370                .value_name("FREQ")
371                .help("Negative samples per word")
372                .takes_value(true)
373                .default_value("5"),
374        )
375        .arg(
376            Arg::with_name(THREADS)
377                .long("threads")
378                .value_name("N")
379                .help("Number of threads (default: logical_cpus / 2)")
380                .takes_value(true),
381        )
382        .arg(
383            Arg::with_name(ZIPF_EXPONENT)
384                .long("zipf")
385                .value_name("EXP")
386                .help("Exponent Zipf distribution for negative sampling")
387                .takes_value(true)
388                .default_value("0.5"),
389        )
390        .arg(
391            Arg::with_name(CORPUS)
392                .help("Tokenized corpus")
393                .index(1)
394                .required(true),
395        )
396        .arg(
397            Arg::with_name(OUTPUT)
398                .help("Embeddings output")
399                .index(2)
400                .required(true),
401        )
402}
403
404/// Construct `CommonConfig` from `matches`.
405fn common_config_from_matches(matches: &ArgMatches) -> CommonConfig {
406    let dims = matches
407        .value_of(DIMS)
408        .map(|v| v.parse().or_exit("Cannot parse dimensionality", 1))
409        .unwrap();
410    let epochs = matches
411        .value_of(EPOCHS)
412        .map(|v| v.parse().or_exit("Cannot parse number of epochs", 1))
413        .unwrap();
414    let lr = matches
415        .value_of(LR)
416        .map(|v| v.parse().or_exit("Cannot parse learning rate", 1))
417        .unwrap();
418    let negative_samples = matches
419        .value_of(NS)
420        .map(|v| {
421            v.parse()
422                .or_exit("Cannot parse number of negative samples", 1)
423        })
424        .unwrap();
425    let zipf_exponent = matches
426        .value_of(ZIPF_EXPONENT)
427        .map(|v| {
428            v.parse()
429                .or_exit("Cannot parse exponent zipf distribution", 1)
430        })
431        .unwrap();
432
433    CommonConfig {
434        loss: LossType::LogisticNegativeSampling,
435        dims,
436        epochs,
437        lr,
438        negative_samples,
439        zipf_exponent,
440    }
441}
442
443/// Construct `SubwordVocabConfig` from `matches`.
444fn subword_config_from_matches(matches: &ArgMatches) -> SubwordVocabConfig {
445    let buckets_exp = matches
446        .value_of(BUCKETS)
447        .map(|v| v.parse().or_exit("Cannot parse bucket exponent", 1))
448        .unwrap();
449    let discard_threshold = matches
450        .value_of(DISCARD)
451        .map(|v| v.parse().or_exit("Cannot parse discard threshold", 1))
452        .unwrap();
453    let min_count = matches
454        .value_of(MINCOUNT)
455        .map(|v| v.parse().or_exit("Cannot parse mincount", 1))
456        .unwrap();
457    let min_n = matches
458        .value_of(MINN)
459        .map(|v| v.parse().or_exit("Cannot parse minimum n-gram length", 1))
460        .unwrap();
461    let max_n = matches
462        .value_of(MAXN)
463        .map(|v| v.parse().or_exit("Cannot parse maximum n-gram length", 1))
464        .unwrap();
465
466    SubwordVocabConfig {
467        min_n,
468        max_n,
469        buckets_exp,
470        min_count,
471        discard_threshold,
472    }
473}
474
475pub fn show_progress<T, V>(config: &CommonConfig, sgd: &SGD<T>, update_interval: Duration)
476where
477    T: Trainer<InputVocab = V>,
478    V: Vocab,
479{
480    let n_tokens = sgd.model().input_vocab().n_types();
481
482    let pb = ProgressBar::new(u64::from(config.epochs) * n_tokens as u64);
483    pb.set_style(
484        ProgressStyle::default_bar().template("{bar:40} {percent}% {msg} ETA: {eta_precise}"),
485    );
486
487    while sgd.n_tokens_processed() < n_tokens * config.epochs as usize {
488        let lr = (1.0
489            - (sgd.n_tokens_processed() as f32 / (config.epochs as usize * n_tokens) as f32))
490            * config.lr;
491
492        pb.set_position(sgd.n_tokens_processed() as u64);
493        pb.set_message(&format!(
494            "loss: {:.*} lr: {:.*}",
495            5,
496            sgd.train_loss(),
497            5,
498            lr
499        ));
500
501        thread::sleep(update_interval);
502    }
503
504    pb.finish();
505}