Skip to main content

fast_text/
lib.rs

1#[macro_use(c)]
2extern crate cute;
3
4use std::collections::HashMap;
5use std::path::Path;
6use std::process::{Command, Output, Stdio};
7
8
/// fastText archive version that `install` unpacks and builds.
const VERSION: &str = "0.1.0";

/// Switches on the diagnostic `println!` output and extra assertions used by the helpers below.
const DEBUG: bool = true;
13
/// Shorthand for turning a string slice into an owned `String`.
fn s(v: &str) -> String {
    String::from(v)
}
17
18/// Installs fastText from the archive on Facebook's github.
19pub fn install() -> Vec<Output> {
20    let cmds = [
21        s("wget https://github.com/facebookresearch/fastText/archive/v0.1.0.zip"),
22        s("unzip v") + VERSION + ".zip",
23        s("cd fastText-") + VERSION + "; make",
24        s("mv fastText-") + VERSION + "/fasttext .",
25        s("rm -r fastText-") + VERSION,
26        s("rm v") + VERSION + ".zip"
27    ];
28    c![
29            Command::new("sh")
30                .arg("-c")
31                .arg(c)
32                .stdout(Stdio::piped())
33                .output()
34                .expect("failed to execute process"),
35
36            for c in cmds.iter()
37     ]
38}
39
40/// runs a fastText command and, if it fails because fastText DNE, installs fastText and tries again.
41fn wrap_install(cmds: &str) -> Output {
42    let st = s("./fasttext ") + cmds;
43    run_cmd(&st)
44}
45
46/// Interface for fastText's supervised learning algorithm.
47///
48/// Each line of input should have labels included as such:
49///
50/// __label__sauce __label__cheese How much does potato starch affect a cheese sauce recipe?
51/// __label__food-safety __label__acidity Dangerous pathogens capable of growing in acidic environments
52/// __label__cast-iron __label__stove How do I cover up the white spots on my cast iron stove?
53///
54///
55/// Documentation from fastText:
56///
57///The following arguments are mandatory:
58///  -input              training file path
59///  -output             output file path
60///
61///The following arguments are optional:
62///  -verbose            verbosity level [2]
63///
64///The following arguments for the dictionary are optional:
65///  -minCount           minimal number of word occurences [1]
66///  -minCountLabel      minimal number of label occurences [0]
67///  -wordNgrams         max length of word ngram [1]
68///  -bucket             number of buckets [2000000]
69///  -minn               min length of char ngram [0]
70///  -maxn               max length of char ngram [0]
71///  -t                  sampling threshold [0.0001]
72///  -label              labels prefix [__label__]
73///
74///The following arguments for training are optional:
75///  -lr                 learning rate [0.1]
76///  -lrUpdateRate       change the rate of updates for the learning rate [100]
77///  -dim                size of word vectors [100]
78///  -ws                 size of the context window [5]
79///  -epoch              number of epochs [5]
80///  -neg                number of negatives sampled [5]
81///  -loss               loss function {ns, hs, softmax} [softmax]
82///  -thread             number of threads [12]
83///  -pretrainedVectors  pretrained word vectors for supervised learning []
84///  -saveOutput         whether output params should be saved [0]
85///
86///The following arguments for quantization are optional:
87///  -cutoff             number of words and ngrams to retain [0]
88///  -retrain            finetune embeddings if a cutoff is applied [0]
89///  -qnorm              quantizing the norm separately [0]
90///  -qout               quantizing the classifier [0]
91///  -dsub               size of each sub-vector [2]
92pub fn supervised(args: &HashMap<&str, &str>) {
93    gen_mod(s("supervised"), args);
94}
95
96/// Interface to shrink a model's memory requirements.
97///
98/// Full interface from fastText:
99///
100///usage: fasttext quantize <args>
101///
102///The following arguments are mandatory:
103///  -input              training file path
104///  -output             output file path
105///
106///The following arguments are optional:
107///  -verbose            verbosity level [2]
108///
109///The following arguments for the dictionary are optional:
110///  -minCount           minimal number of word occurences [5]
111///  -minCountLabel      minimal number of label occurences [0]
112///  -wordNgrams         max length of word ngram [1]
113///  -bucket             number of buckets [2000000]
114///  -minn               min length of char ngram [3]
115///  -maxn               max length of char ngram [6]
116///  -t                  sampling threshold [0.0001]
117///  -label              labels prefix [__label__]
118///
119///The following arguments for training are optional:
120///  -lr                 learning rate [0.05]
121///  -lrUpdateRate       change the rate of updates for the learning rate [100]
122///  -dim                size of word vectors [100]
123///  -ws                 size of the context window [5]
124///  -epoch              number of epochs [5]
125///  -neg                number of negatives sampled [5]
126///  -loss               loss function {ns, hs, softmax} [ns]
127///  -thread             number of threads [12]
128///  -pretrainedVectors  pretrained word vectors for supervised learning []
129///  -saveOutput         whether output params should be saved [0]
130///
131///The following arguments for quantization are optional:
132///  -cutoff             number of words and ngrams to retain [0]
133///  -retrain            finetune embeddings if a cutoff is applied [0]
134///  -qnorm              quantizing the norm separately [0]
135///  -qout               quantizing the classifier [0]
136///  -dsub               size of each sub-vector [2]
137pub fn quantize(args: &HashMap<&str, &str>) {
138    gen_mod(s("quantize"), args);
139}
140
141
142/// Classify each line in an input file.
143///
144/// output: a vector of the same length as the number of lines in the input
145/// file where each length is value is a  k-length vector of labels for the text from the line.
146///
147/// Documentation from fastText:
148///
149/// usage: fasttext predict[-prob] <model> <test-data> [<k>]
150///
151///  <model>      model filename
152///  <test-data>  test data filename (if -, read from stdin)
153///  <k>          (optional; 1 by default) predict top k labels
154pub fn predict(model: &str, inp: &str, k: u32) -> Vec<Vec<String>> {
155    let mut out = Vec::new();
156    let s = s("predict ") + model + " " + inp + " " + &k.to_string();
157    let r = wrap_install(&s);
158    for p in String::from_utf8_lossy(&r.stdout).split("\n") {
159        let mut innerv = Vec::new();
160        for v in p.split(" ") {
161            if v != "" {
162                innerv.push(v.to_string());
163            }
164        }
165        if innerv.len() != 0 {
166            out.push(innerv);
167        }
168    }
169    out
170}
171
172/// Classify each line in an input file with probabilities of labels.
173///
174/// Documentation from fastText:
175///
176/// usage: fasttext predict[-prob] <model> <test-data> [<k>]
177///
178///  <model>      model filename
179///  <test-data>  test data filename (if -, read from stdin)
180///  <k>          (optional; 1 by default) predict top k labels
181pub fn predict_prob(model: &str, inp: &str, k: u32) -> Vec<Vec<(String, f64)>> {
182    fn ext(l: &str) -> Vec<(String, f64)> {
183        let mut out = Vec::new();
184        let mut f = true;
185        let mut label = "";
186        for u in l.split(" ") {
187            if u != "" {
188                if f {
189                    label = u;
190                } else {
191                    out.push((label.to_string(), u.parse::<f64>().unwrap()));
192                }
193                f = !f;
194            }
195        }
196        if DEBUG { assert!(f); } // last value is a prob, not a label
197        out
198    }
199    let mut out = Vec::new();
200    let s = s("predict-prob ") + model + " " + inp + " " + &k.to_string();
201    let r = wrap_install(&s);
202    for l in String::from_utf8_lossy(&r.stdout).split("\n") {
203        let v = ext(l);
204        if v.len() > 0 {
205            out.push(v);
206        }
207    }
208    out
209}
210
211
212/// Helper function used to unspool arguments. S is a string with the primary fastText command
213/// (e.g. "skipgram") and args are the named arguments to be passed to it, with keys as the
214/// argument tag and values as the argument value.
215fn gen_mod<'a>(mut s: String, args: &HashMap<&str, &'a str>) {
216    for k in args.keys() {
217        s = s + " -" + k + " " + args.get(k).unwrap();
218    }
219    if !wrap_install(&s).status.success() {
220        panic!("Gen_mod failed with given input: {}", s)
221    }
222}
223
224/// Provides functionality for generating skipgrams.
225///
226/// Include argument names as HashMap keys and argument values as HashMap values, e.g.:
227/// "input" : "sample_text.txt"
228/// "output" : "sample"
229///
230/// Documentation from fastText:
231///
232///
233/// The following arguments are mandatory:
234///  -input              training file path
235///  -output             output file path
236///
237///The following arguments are optional:
238///  -verbose            verbosity level [2]
239///
240///The following arguments for the dictionary are optional:
241///  -minCount           minimal number of word occurences [5]
242///  -minCountLabel      minimal number of label occurences [0]
243///  -wordNgrams         max length of word ngram [1]
244///  -bucket             number of buckets [2000000]
245///  -minn               min length of char ngram [3]
246///  -maxn               max length of char ngram [6]
247///  -t                  sampling threshold [0.0001]
248///  -label              labels prefix [__label__]
249///
250///The following arguments for training are optional:
251///  -lr                 learning rate [0.05]
252///  -lrUpdateRate       change the rate of updates for the learning rate [100]
253///  -dim                size of word vectors [100]
254///  -ws                 size of the context window [5]
255///  -epoch              number of epochs [5]
256///  -neg                number of negatives sampled [5]
257///  -loss               loss function {ns, hs, softmax} [ns]
258///  -thread             number of threads [12]
259///  -pretrainedVectors  pretrained word vectors for supervised learning []
260///  -saveOutput         whether output params should be saved [0]
261///
262///The following arguments for quantization are optional:
263///  -cutoff             number of words and ngrams to retain [0]
264///  -retrain            finetune embeddings if a cutoff is applied [0]
265///  -qnorm              quantizing the norm separately [0]
266///  -qout               quantizing the classifier [0]
267///  -dsub               size of each sub-vector [2]
268pub fn skipgram(args: &HashMap<&str, &str>) {
269    gen_mod(s("skipgram"), args);
270}
271
272/// Provides functionality for generating a continuous bag of words model.
273///
274/// Include argument names as HashMap keys and argument values as HashMap values, e.g.:
275/// "input" : "sample_text.txt"
276/// "output" : "sample"
277///
278/// Documentation from fastText:
279///
280/// The following arguments are mandatory:
281///  -input              training file path
282///  -output             output file path
283///
284///The following arguments are optional:
285///  -verbose            verbosity level [2]
286///
287///The following arguments for the dictionary are optional:
288///  -minCount           minimal number of word occurences [5]
289///  -minCountLabel      minimal number of label occurences [0]
290///  -wordNgrams         max length of word ngram [1]
291///  -bucket             number of buckets [2000000]
292///  -minn               min length of char ngram [3]
293///  -maxn               max length of char ngram [6]
294///  -t                  sampling threshold [0.0001]
295///  -label              labels prefix [__label__]
296///
297///The following arguments for training are optional:
298///  -lr                 learning rate [0.05]
299///  -lrUpdateRate       change the rate of updates for the learning rate [100]
300///  -dim                size of word vectors [100]
301///  -ws                 size of the context window [5]
302///  -epoch              number of epochs [5]
303///  -neg                number of negatives sampled [5]
304///  -loss               loss function {ns, hs, softmax} [ns]
305///  -thread             number of threads [12]
306///  -pretrainedVectors  pretrained word vectors for supervised learning []
307///  -saveOutput         whether output params should be saved [0]
308///
309///The following arguments for quantization are optional:
310///  -cutoff             number of words and ngrams to retain [0]
311///  -retrain            finetune embeddings if a cutoff is applied [0]
312///  -qnorm              quantizing the norm separately [0]
313///  -qout               quantizing the classifier [0]
314///  -dsub               size of each sub-vector [2]
315pub fn cbow(args: &HashMap<&str, &str>) {
316    gen_mod(s("cbow"), args);
317}
318
319/// Provides minimal functionality for generating skipgrams.
320///
321/// Full documentation from fastText:
322///
323/// The following arguments are mandatory:
324///  -input              training file path
325///  -output             output file path
326///
327///The following arguments are optional:
328///  -verbose            verbosity level [2]
329///
330///The following arguments for the dictionary are optional:
331///  -minCount           minimal number of word occurences [5]
332///  -minCountLabel      minimal number of label occurences [0]
333///  -wordNgrams         max length of word ngram [1]
334///  -bucket             number of buckets [2000000]
335///  -minn               min length of char ngram [3]
336///  -maxn               max length of char ngram [6]
337///  -t                  sampling threshold [0.0001]
338///  -label              labels prefix [__label__]
339///
340///The following arguments for training are optional:
341///  -lr                 learning rate [0.05]
342///  -lrUpdateRate       change the rate of updates for the learning rate [100]
343///  -dim                size of word vectors [100]
344///  -ws                 size of the context window [5]
345///  -epoch              number of epochs [5]
346///  -neg                number of negatives sampled [5]
347///  -loss               loss function {ns, hs, softmax} [ns]
348///  -thread             number of threads [12]
349///  -pretrainedVectors  pretrained word vectors for supervised learning []
350///  -saveOutput         whether output params should be saved [0]
351///
352///The following arguments for quantization are optional:
353///  -cutoff             number of words and ngrams to retain [0]
354///  -retrain            finetune embeddings if a cutoff is applied [0]
355///  -qnorm              quantizing the norm separately [0]
356///  -qout               quantizing the classifier [0]
357///  -dsub               size of each sub-vector [2]
358pub fn min_skipgram(input: &str, output: &str) -> String {
359    let st = s("skipgram -input ") + input + " -output " + output;
360    let o = wrap_install(&st);
361    if o.status.success() {
362        s(output) + ".bin"
363    } else {
364        panic!("Min_skipgram failed with given input: {} \noutput: {:?}", st, o)
365    }
366}
367
368
369/// Provides minimal functionality for generating a continuous bag of words model.
370///
371/// Documentation from fastText:
372///
373/// The following arguments are mandatory:
374///  -input              training file path
375///  -output             output file path
376///
377///The following arguments are optional:
378///  -verbose            verbosity level [2]
379///
380///The following arguments for the dictionary are optional:
381///  -minCount           minimal number of word occurences [5]
382///  -minCountLabel      minimal number of label occurences [0]
383///  -wordNgrams         max length of word ngram [1]
384///  -bucket             number of buckets [2000000]
385///  -minn               min length of char ngram [3]
386///  -maxn               max length of char ngram [6]
387///  -t                  sampling threshold [0.0001]
388///  -label              labels prefix [__label__]
389///
390///The following arguments for training are optional:
391///  -lr                 learning rate [0.05]
392///  -lrUpdateRate       change the rate of updates for the learning rate [100]
393///  -dim                size of word vectors [100]
394///  -ws                 size of the context window [5]
395///  -epoch              number of epochs [5]
396///  -neg                number of negatives sampled [5]
397///  -loss               loss function {ns, hs, softmax} [ns]
398///  -thread             number of threads [12]
399///  -pretrainedVectors  pretrained word vectors for supervised learning []
400///  -saveOutput         whether output params should be saved [0]
401///
402///The following arguments for quantization are optional:
403///  -cutoff             number of words and ngrams to retain [0]
404///  -retrain            finetune embeddings if a cutoff is applied [0]
405///  -qnorm              quantizing the norm separately [0]
406///  -qout               quantizing the classifier [0]
407///  -dsub               size of each sub-vector [2]
408pub fn min_cbow(input: &str, output: &str) -> String {
409    let st = s("cbow -input ") + input + " -output " + output;
410    if wrap_install(&st).status.success() {
411        s(output) + ".bin"
412    } else {
413        panic!("Cbow failed with given input: {}", st)
414    }
415}
416
417
418fn resp(sm: &str, stdout: std::borrow::Cow<str>) -> Vec<Vec<(String, f64)>> {
419    let mut v0 = Vec::new();
420    if DEBUG {
421        println!("Beginning match iteration");
422        println!("stdout: {}", stdout);
423    }
424    for (start, _) in stdout.match_indices(sm) {
425        if DEBUG { println!("Match found: {}", start); }
426        let mut v1 = Vec::new();
427        let mut first = true;
428        for l in stdout[start..].split("\n") {
429            let lar: Vec<&str> = l.split(" ").collect();
430            if DEBUG { println!("{:?}", lar); }
431            if lar.len() == 2 {
432                v1.push((lar[0].to_string(), lar[1].parse::<f64>().unwrap()));
433            } else if lar.len() == 4 && first {
434                v1.push((lar[2].to_string(), lar[3].parse::<f64>().unwrap()));
435                first = false;
436            } else if l == sm || (lar.len() == 4 && !first) {
437                break;
438            } else {
439                panic!("misformatted line in input: {}", l);
440            }
441        }
442        if v1.len() > 0 {
443            v0.push(v1);
444        }
445    }
446    v0
447}
448
449/// runs an arbitrary command and makes sure that the fasttext is set up locally.
450fn run_cmd(cmd: &str) -> Output {
451    if DEBUG { println!("cmd: {}", cmd); }
452    let r = Command::new("sh")
453        .arg("-c")
454        .arg(cmd)
455        .stdout(Stdio::piped())
456        .output()
457        .expect("failed to execute process");
458    if !r.status.success() && !Path::new("./fasttext").exists() {
459        let ir = install();
460        for o in ir.iter() {
461            if !o.status.success() { panic!("Missing files / executable"); }
462        }
463    }
464    if DEBUG { println!("{:?}", r); }
465    r
466}
467
468
469/// Nearest neighbors. Input of "words" are single words separated by spaces.
470///
471/// Full documentation from FastText:
472///
473/// usage: fasttext nn <model> <k>
474///
475///  <model>      model filename
476///  <k>          (optional; 10 by default) predict top k labels
477pub fn nn(words: &str, model: &str, k: u32) -> Vec<Vec<(String, f64)>> {
478    if DEBUG { println!("NN begun") };
479    let cmd = s("echo ") + words + " | ./fasttext nn " + model + " " + &k.to_string();
480    resp("Query word? ", String::from_utf8_lossy(&run_cmd(&cmd).stdout))
481}
482
/// Access to the analogies function. Not supported.
///
/// # Panics
///
/// Always panics: the function is not implemented.
///
/// # Documentation from fastText
///
/// ```text
/// usage: fasttext analogies <model> <k>
///
///   <model>      model filename
///   <k>          (optional; 10 by default) predict top k labels
/// ```
pub fn analogies(analogies: &str, model: &str, k: u32) -> Vec<Vec<(String, f64)>> {
    // The intended approach was to pipe the triplets into fastText,
    //     echo "<analogies>" | ./fasttext analogies <model> <k>
    // and parse the answers that follow each "Query triplet (A - B + C)? " prompt.
    // That does not work here, since fastText just keeps checking stdin and
    // re-outputting results.
    let _ = (analogies, model, k); // parameters stay in the signature for the future implementation
    unimplemented!()
}
498
499fn parse_vec(cmd: &str, sentence: Option<&str>) -> Vec<Vec<f64>> {
500    let mut out = Vec::new();
501    let mut st = String::from_utf8_lossy(&run_cmd(cmd).stdout).to_string();
502    match sentence {
503        None => (),
504        Some(sent) => {
505            st = st.replace(sent, "");
506        }
507    };
508    for l in st.split("\n") {
509        let mut wordvec = Vec::new();
510        let mut f = true;
511        for t in l.split(" ") {
512            if f {
513                f = false;
514            } else {
515                if t != "" {
516                    wordvec.push(t.parse::<f64>().unwrap());
517                }
518            }
519        }
520        if wordvec.len() > 0 {
521            out.push(wordvec);
522        }
523    }
524    out
525}
526
527/// access to the vectors for a given set of words.
528///
529/// Input: one or more words (separated by spaces)
530/// Output: A vec of word vectors (one for each input word)
531pub fn word_vector(words: &str, model: &str) -> Vec<Vec<f64>> {
532    let cmd = s("echo \"") + words + "\" | ./fasttext print-word-vectors " + model;
533    parse_vec(&cmd, None)
534}
535
536
537/// access to the vectors for a given sentence.
538///
539/// Input: sentence
540/// Output: A vec of a sentence vector
541pub fn sentence_vector(sentence: &str, model: &str) -> Vec<Vec<f64>> {
542    let cmd = s("echo \"") + sentence + "\" | ./fasttext print-sentence-vectors " + model;
543    parse_vec(&cmd, Some(sentence))
544}
545
546
/// The objective for testing here is not to check that the fasttext binary is working as
/// expected, but that it can be installed and that its output can be consistently read.
#[cfg(test)]
mod tests {
    extern crate kolmogorov_smirnov as ks;

    use std::{thread, time};
    use std::collections::HashSet;
    use super::*;

    /// Calls `or` when `file` is still missing after a 30 second grace period. The wait is
    /// presumably there so a test running in parallel can finish creating the file.
    fn check_exists(file: &str, or: fn()) {
        if !Path::new(file).exists() {
            thread::sleep(time::Duration::from_secs(30));
            if !Path::new(file).exists() {
                or()
            }
        }
    }

    /// Removes each of `files` with `rm -r`; shell globs are allowed. Failures of `rm` itself
    /// are ignored.
    fn rm(files: Vec<&str>) {
        for f in files.iter() {
            let cmd = s("rm -r ") + f;
            Command::new("sh")
                .arg("-c")
                .arg(&cmd)
                .stdout(Stdio::piped())
                .output()
                .expect("failed to execute process");
        }
    }

    /// Collects every word that appears in any of the result groups, discarding the scores.
    fn set(v: Vec<Vec<(String, f64)>>) -> HashSet<String> {
        v.into_iter()
            .flat_map(|group| group.into_iter())
            .map(|(word, _)| word)
            .collect()
    }

    /// Number of words the two sets have in common.
    fn sim(a: &HashSet<String>, b: &HashSet<String>) -> usize {
        a.intersection(b).count()
    }

    /// Makes sure the fastText binary is present, installing it if needed.
    fn inst() {
        check_exists("fasttext", || { install(); });
    }

    /// Makes sure the shared sample model is present, training it if needed.
    fn samp() {
        check_exists("sample.bin", sample_skipgram);
    }

    #[test]
    fn test_install() {
        let rv = install();
        for r in rv.iter() {
            println!("{}", String::from_utf8_lossy(&r.stdout));
            println!("{}", String::from_utf8_lossy(&r.stderr));
            assert!(r.status.success());
        }

        let r = Command::new("sh")
            .arg("-c")
            .arg("./fasttext")
            .stdout(Stdio::piped())
            .output()
            .expect("failed to execute process");
        println!("{}", String::from_utf8_lossy(&r.stdout));
        println!("{}", String::from_utf8_lossy(&r.stderr));
        assert_eq!(r.status.code(), Some(1)); // returns 127 if ./fasttext DNE
    }

    /// Generate a skipgram model for testing things like the nearest neighbor function.
    fn sample_skipgram() {
        inst();
        let model = min_skipgram("sample_text.txt", "sample");
        println!("Generated skipgram model: {}", model);
    }

    #[test]
    fn test_nn() {
        samp();

        let out = nn("lesbian", "sample.bin", 10);
        println!("{:?}", out);
        assert_eq!(out.len(), 1); // number of words queried
        assert_eq!(out[0].len(), 10); // k

        let out = nn("lesbian gay", "sample.bin", 5);
        println!("{:?}", out);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].len(), 5);

        let out = nn("lesbian gay bisexual", "sample.bin", 8);
        println!("{:?}", out);
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 8);

        let out = nn("lesbian gay bisexual transgender", "sample.bin", 1);
        println!("{:?}", out);
        assert_eq!(out.len(), 4);
        assert_eq!(out[0].len(), 1);
    }

    /// Test that nearest neighbors for two functions yield valid results: `min_fn` and
    /// `reg_fn` each train a model on the same input, and the neighbors reported by the two
    /// models are compared.
    fn test_embedding(min_fn: fn(&str, &str) -> String, reg_fn: fn(&HashMap<&str, &str>), min_name: &str, reg_name: &str) {
        inst();

        let input = "sample_text.txt";
        let mut args = HashMap::new();
        args.insert("input", input);
        args.insert("output", reg_name);

        let k = 10;

        // Training does not depend on the query word, so both models are built once up front.
        let m1 = min_fn(input, min_name);
        reg_fn(&args);
        // The model written by `reg_fn` lives under `reg_name`; querying `min_name` again
        // would only compare the first model with itself.
        let m2 = s(reg_name) + ".bin";

        // Would iterate through a set of arbitrary words to compare on.
        for w in ["friend", "day", "door"].iter() {
            let r1 = nn(w, &m1, k);
            let r2 = nn(w, &m2, k);

            assert_eq!(r1.len(), r2.len());
            for i in 0..r1.len() {
                assert_eq!(r1[i].len(), r2[i].len());
                assert_eq!(r1[i].len(), k as usize);
            }

            // NOTE(review): this demands that more than 90% of the neighbors coincide, i.e.
            // all of them for k = 10. Confirm that two independently trained models are
            // expected to agree that closely.
            assert!(sim(&set(r1), &set(r2)) > (0.9 * k as f64) as usize);
        }

        let r1 = s(min_name) + "*";
        let r2 = s(reg_name) + "*";
        rm(vec![&r1, &r2]);
    }

    #[test]
    fn test_skipgram() {
        test_embedding(min_skipgram, skipgram, "test_min_skipgram", "test_skipgram");
    }

    #[test]
    fn test_cbow() {
        test_embedding(min_cbow, cbow, "test_min_cbow", "test_cbow");
    }

    /// Checks the shape of `predict` output for k = 1 and k = 2 on the `t.txt` fixture.
    fn test_predict(model: String) {
        let p = predict(&model, "t.txt", 1);
        println!("test_predict output: {:?}", p);
        assert_eq!(p[0].len(), 1);
        assert_eq!(p.len(), 2);

        let p = predict(&model, "t.txt", 2);
        println!("test_predict output: {:?}", p);
        assert_eq!(p[0].len(), 2);
        assert_eq!(p.len(), 2);
    }

    /// Checks the shape of `predict_prob` output for k = 1 and k = 2 on the `t.txt` fixture.
    fn test_predict_prob(model: String) {
        let p = predict_prob(&model, "t.txt", 1);
        println!("output of predict_prob: {:?}", p);
        assert_eq!(p[0].len(), 1);
        assert_eq!(p.len(), 2);

        let p = predict_prob(&model, "t.txt", 2);
        println!("output of predict_prob: {:?}", p);
        assert_eq!(p[0].len(), 2);
        assert_eq!(p.len(), 2);
    }

    #[test]
    fn test_supervised_and_predicts() {
        inst();

        let model = "sup";

        let args: HashMap<_, _> = vec![
            ("input", "sample_text.txt"),
            ("output", model),
        ].into_iter().collect();

        supervised(&args);

        test_predict(s(model) + ".bin");
        test_predict_prob(s(model) + ".bin");

        let m = s(model) + "*";
        rm(vec![&m]);
    }

    #[test]
    fn test_word_vector() {
        samp();
        let v = word_vector("gay math queen", "sample.bin");
        assert_eq!(v.len(), 3); // three words go in, three wordvecs come out

        let mut hs = HashSet::new();
        for wv in v.iter() {
            hs.insert(wv.len());
        }
        assert_eq!(hs.len(), 1); // vectors are all the same length

        let v = word_vector("naps", "sample.bin");
        assert_eq!(v.len(), 1);
    }

    #[test]
    fn test_sentence_vector() {
        samp();
        let v = sentence_vector("To die, to sleep – to sleep, perchance to dream – ay, there's the rub, for in this sleep of death what dreams may come…", "sample.bin");
        assert_eq!(v.len(), 1);
    }
}