fast_text/lib.rs
1#[macro_use(c)]
2extern crate cute;
3
4use std::collections::HashMap;
5use std::path::Path;
6use std::process::{Command, Output, Stdio};
7
8
/// fastText archive version to pull.
const VERSION: &str = "0.1.0";

/// When `true`, the helpers in this file print the shell commands they run and the output
/// they parse, and `predict_prob` asserts that every label is followed by a probability.
const DEBUG: bool = true;
13
/// Shorthand for turning a string slice into an owned `String`.
fn s(v: &str) -> String {
    String::from(v)
}
17
18/// Installs fastText from the archive on Facebook's github.
19pub fn install() -> Vec<Output> {
20 let cmds = [
21 s("wget https://github.com/facebookresearch/fastText/archive/v0.1.0.zip"),
22 s("unzip v") + VERSION + ".zip",
23 s("cd fastText-") + VERSION + "; make",
24 s("mv fastText-") + VERSION + "/fasttext .",
25 s("rm -r fastText-") + VERSION,
26 s("rm v") + VERSION + ".zip"
27 ];
28 c![
29 Command::new("sh")
30 .arg("-c")
31 .arg(c)
32 .stdout(Stdio::piped())
33 .output()
34 .expect("failed to execute process"),
35
36 for c in cmds.iter()
37 ]
38}
39
40/// runs a fastText command and, if it fails because fastText DNE, installs fastText and tries again.
41fn wrap_install(cmds: &str) -> Output {
42 let st = s("./fasttext ") + cmds;
43 run_cmd(&st)
44}
45
46/// Interface for fastText's supervised learning algorithm.
47///
48/// Each line of input should have labels included as such:
49///
50/// __label__sauce __label__cheese How much does potato starch affect a cheese sauce recipe?
51/// __label__food-safety __label__acidity Dangerous pathogens capable of growing in acidic environments
52/// __label__cast-iron __label__stove How do I cover up the white spots on my cast iron stove?
53///
54///
55/// Documentation from fastText:
56///
57///The following arguments are mandatory:
58/// -input training file path
59/// -output output file path
60///
61///The following arguments are optional:
62/// -verbose verbosity level [2]
63///
64///The following arguments for the dictionary are optional:
65/// -minCount minimal number of word occurences [1]
66/// -minCountLabel minimal number of label occurences [0]
67/// -wordNgrams max length of word ngram [1]
68/// -bucket number of buckets [2000000]
69/// -minn min length of char ngram [0]
70/// -maxn max length of char ngram [0]
71/// -t sampling threshold [0.0001]
72/// -label labels prefix [__label__]
73///
74///The following arguments for training are optional:
75/// -lr learning rate [0.1]
76/// -lrUpdateRate change the rate of updates for the learning rate [100]
77/// -dim size of word vectors [100]
78/// -ws size of the context window [5]
79/// -epoch number of epochs [5]
80/// -neg number of negatives sampled [5]
81/// -loss loss function {ns, hs, softmax} [softmax]
82/// -thread number of threads [12]
83/// -pretrainedVectors pretrained word vectors for supervised learning []
84/// -saveOutput whether output params should be saved [0]
85///
86///The following arguments for quantization are optional:
87/// -cutoff number of words and ngrams to retain [0]
88/// -retrain finetune embeddings if a cutoff is applied [0]
89/// -qnorm quantizing the norm separately [0]
90/// -qout quantizing the classifier [0]
91/// -dsub size of each sub-vector [2]
92pub fn supervised(args: &HashMap<&str, &str>) {
93 gen_mod(s("supervised"), args);
94}
95
96/// Interface to shrink a model's memory requirements.
97///
98/// Full interface from fastText:
99///
100///usage: fasttext quantize <args>
101///
102///The following arguments are mandatory:
103/// -input training file path
104/// -output output file path
105///
106///The following arguments are optional:
107/// -verbose verbosity level [2]
108///
109///The following arguments for the dictionary are optional:
110/// -minCount minimal number of word occurences [5]
111/// -minCountLabel minimal number of label occurences [0]
112/// -wordNgrams max length of word ngram [1]
113/// -bucket number of buckets [2000000]
114/// -minn min length of char ngram [3]
115/// -maxn max length of char ngram [6]
116/// -t sampling threshold [0.0001]
117/// -label labels prefix [__label__]
118///
119///The following arguments for training are optional:
120/// -lr learning rate [0.05]
121/// -lrUpdateRate change the rate of updates for the learning rate [100]
122/// -dim size of word vectors [100]
123/// -ws size of the context window [5]
124/// -epoch number of epochs [5]
125/// -neg number of negatives sampled [5]
126/// -loss loss function {ns, hs, softmax} [ns]
127/// -thread number of threads [12]
128/// -pretrainedVectors pretrained word vectors for supervised learning []
129/// -saveOutput whether output params should be saved [0]
130///
131///The following arguments for quantization are optional:
132/// -cutoff number of words and ngrams to retain [0]
133/// -retrain finetune embeddings if a cutoff is applied [0]
134/// -qnorm quantizing the norm separately [0]
135/// -qout quantizing the classifier [0]
136/// -dsub size of each sub-vector [2]
137pub fn quantize(args: &HashMap<&str, &str>) {
138 gen_mod(s("quantize"), args);
139}
140
141
142/// Classify each line in an input file.
143///
144/// output: a vector of the same length as the number of lines in the input
145/// file where each length is value is a k-length vector of labels for the text from the line.
146///
147/// Documentation from fastText:
148///
149/// usage: fasttext predict[-prob] <model> <test-data> [<k>]
150///
151/// <model> model filename
152/// <test-data> test data filename (if -, read from stdin)
153/// <k> (optional; 1 by default) predict top k labels
154pub fn predict(model: &str, inp: &str, k: u32) -> Vec<Vec<String>> {
155 let mut out = Vec::new();
156 let s = s("predict ") + model + " " + inp + " " + &k.to_string();
157 let r = wrap_install(&s);
158 for p in String::from_utf8_lossy(&r.stdout).split("\n") {
159 let mut innerv = Vec::new();
160 for v in p.split(" ") {
161 if v != "" {
162 innerv.push(v.to_string());
163 }
164 }
165 if innerv.len() != 0 {
166 out.push(innerv);
167 }
168 }
169 out
170}
171
172/// Classify each line in an input file with probabilities of labels.
173///
174/// Documentation from fastText:
175///
176/// usage: fasttext predict[-prob] <model> <test-data> [<k>]
177///
178/// <model> model filename
179/// <test-data> test data filename (if -, read from stdin)
180/// <k> (optional; 1 by default) predict top k labels
181pub fn predict_prob(model: &str, inp: &str, k: u32) -> Vec<Vec<(String, f64)>> {
182 fn ext(l: &str) -> Vec<(String, f64)> {
183 let mut out = Vec::new();
184 let mut f = true;
185 let mut label = "";
186 for u in l.split(" ") {
187 if u != "" {
188 if f {
189 label = u;
190 } else {
191 out.push((label.to_string(), u.parse::<f64>().unwrap()));
192 }
193 f = !f;
194 }
195 }
196 if DEBUG { assert!(f); } // last value is a prob, not a label
197 out
198 }
199 let mut out = Vec::new();
200 let s = s("predict-prob ") + model + " " + inp + " " + &k.to_string();
201 let r = wrap_install(&s);
202 for l in String::from_utf8_lossy(&r.stdout).split("\n") {
203 let v = ext(l);
204 if v.len() > 0 {
205 out.push(v);
206 }
207 }
208 out
209}
210
211
212/// Helper function used to unspool arguments. S is a string with the primary fastText command
213/// (e.g. "skipgram") and args are the named arguments to be passed to it, with keys as the
214/// argument tag and values as the argument value.
215fn gen_mod<'a>(mut s: String, args: &HashMap<&str, &'a str>) {
216 for k in args.keys() {
217 s = s + " -" + k + " " + args.get(k).unwrap();
218 }
219 if !wrap_install(&s).status.success() {
220 panic!("Gen_mod failed with given input: {}", s)
221 }
222}
223
224/// Provides functionality for generating skipgrams.
225///
226/// Include argument names as HashMap keys and argument values as HashMap values, e.g.:
227/// "input" : "sample_text.txt"
228/// "output" : "sample"
229///
230/// Documentation from fastText:
231///
232///
233/// The following arguments are mandatory:
234/// -input training file path
235/// -output output file path
236///
237///The following arguments are optional:
238/// -verbose verbosity level [2]
239///
240///The following arguments for the dictionary are optional:
241/// -minCount minimal number of word occurences [5]
242/// -minCountLabel minimal number of label occurences [0]
243/// -wordNgrams max length of word ngram [1]
244/// -bucket number of buckets [2000000]
245/// -minn min length of char ngram [3]
246/// -maxn max length of char ngram [6]
247/// -t sampling threshold [0.0001]
248/// -label labels prefix [__label__]
249///
250///The following arguments for training are optional:
251/// -lr learning rate [0.05]
252/// -lrUpdateRate change the rate of updates for the learning rate [100]
253/// -dim size of word vectors [100]
254/// -ws size of the context window [5]
255/// -epoch number of epochs [5]
256/// -neg number of negatives sampled [5]
257/// -loss loss function {ns, hs, softmax} [ns]
258/// -thread number of threads [12]
259/// -pretrainedVectors pretrained word vectors for supervised learning []
260/// -saveOutput whether output params should be saved [0]
261///
262///The following arguments for quantization are optional:
263/// -cutoff number of words and ngrams to retain [0]
264/// -retrain finetune embeddings if a cutoff is applied [0]
265/// -qnorm quantizing the norm separately [0]
266/// -qout quantizing the classifier [0]
267/// -dsub size of each sub-vector [2]
268pub fn skipgram(args: &HashMap<&str, &str>) {
269 gen_mod(s("skipgram"), args);
270}
271
272/// Provides functionality for generating a continuous bag of words model.
273///
274/// Include argument names as HashMap keys and argument values as HashMap values, e.g.:
275/// "input" : "sample_text.txt"
276/// "output" : "sample"
277///
278/// Documentation from fastText:
279///
280/// The following arguments are mandatory:
281/// -input training file path
282/// -output output file path
283///
284///The following arguments are optional:
285/// -verbose verbosity level [2]
286///
287///The following arguments for the dictionary are optional:
288/// -minCount minimal number of word occurences [5]
289/// -minCountLabel minimal number of label occurences [0]
290/// -wordNgrams max length of word ngram [1]
291/// -bucket number of buckets [2000000]
292/// -minn min length of char ngram [3]
293/// -maxn max length of char ngram [6]
294/// -t sampling threshold [0.0001]
295/// -label labels prefix [__label__]
296///
297///The following arguments for training are optional:
298/// -lr learning rate [0.05]
299/// -lrUpdateRate change the rate of updates for the learning rate [100]
300/// -dim size of word vectors [100]
301/// -ws size of the context window [5]
302/// -epoch number of epochs [5]
303/// -neg number of negatives sampled [5]
304/// -loss loss function {ns, hs, softmax} [ns]
305/// -thread number of threads [12]
306/// -pretrainedVectors pretrained word vectors for supervised learning []
307/// -saveOutput whether output params should be saved [0]
308///
309///The following arguments for quantization are optional:
310/// -cutoff number of words and ngrams to retain [0]
311/// -retrain finetune embeddings if a cutoff is applied [0]
312/// -qnorm quantizing the norm separately [0]
313/// -qout quantizing the classifier [0]
314/// -dsub size of each sub-vector [2]
315pub fn cbow(args: &HashMap<&str, &str>) {
316 gen_mod(s("cbow"), args);
317}
318
319/// Provides minimal functionality for generating skipgrams.
320///
321/// Full documentation from fastText:
322///
323/// The following arguments are mandatory:
324/// -input training file path
325/// -output output file path
326///
327///The following arguments are optional:
328/// -verbose verbosity level [2]
329///
330///The following arguments for the dictionary are optional:
331/// -minCount minimal number of word occurences [5]
332/// -minCountLabel minimal number of label occurences [0]
333/// -wordNgrams max length of word ngram [1]
334/// -bucket number of buckets [2000000]
335/// -minn min length of char ngram [3]
336/// -maxn max length of char ngram [6]
337/// -t sampling threshold [0.0001]
338/// -label labels prefix [__label__]
339///
340///The following arguments for training are optional:
341/// -lr learning rate [0.05]
342/// -lrUpdateRate change the rate of updates for the learning rate [100]
343/// -dim size of word vectors [100]
344/// -ws size of the context window [5]
345/// -epoch number of epochs [5]
346/// -neg number of negatives sampled [5]
347/// -loss loss function {ns, hs, softmax} [ns]
348/// -thread number of threads [12]
349/// -pretrainedVectors pretrained word vectors for supervised learning []
350/// -saveOutput whether output params should be saved [0]
351///
352///The following arguments for quantization are optional:
353/// -cutoff number of words and ngrams to retain [0]
354/// -retrain finetune embeddings if a cutoff is applied [0]
355/// -qnorm quantizing the norm separately [0]
356/// -qout quantizing the classifier [0]
357/// -dsub size of each sub-vector [2]
358pub fn min_skipgram(input: &str, output: &str) -> String {
359 let st = s("skipgram -input ") + input + " -output " + output;
360 let o = wrap_install(&st);
361 if o.status.success() {
362 s(output) + ".bin"
363 } else {
364 panic!("Min_skipgram failed with given input: {} \noutput: {:?}", st, o)
365 }
366}
367
368
369/// Provides minimal functionality for generating a continuous bag of words model.
370///
371/// Documentation from fastText:
372///
373/// The following arguments are mandatory:
374/// -input training file path
375/// -output output file path
376///
377///The following arguments are optional:
378/// -verbose verbosity level [2]
379///
380///The following arguments for the dictionary are optional:
381/// -minCount minimal number of word occurences [5]
382/// -minCountLabel minimal number of label occurences [0]
383/// -wordNgrams max length of word ngram [1]
384/// -bucket number of buckets [2000000]
385/// -minn min length of char ngram [3]
386/// -maxn max length of char ngram [6]
387/// -t sampling threshold [0.0001]
388/// -label labels prefix [__label__]
389///
390///The following arguments for training are optional:
391/// -lr learning rate [0.05]
392/// -lrUpdateRate change the rate of updates for the learning rate [100]
393/// -dim size of word vectors [100]
394/// -ws size of the context window [5]
395/// -epoch number of epochs [5]
396/// -neg number of negatives sampled [5]
397/// -loss loss function {ns, hs, softmax} [ns]
398/// -thread number of threads [12]
399/// -pretrainedVectors pretrained word vectors for supervised learning []
400/// -saveOutput whether output params should be saved [0]
401///
402///The following arguments for quantization are optional:
403/// -cutoff number of words and ngrams to retain [0]
404/// -retrain finetune embeddings if a cutoff is applied [0]
405/// -qnorm quantizing the norm separately [0]
406/// -qout quantizing the classifier [0]
407/// -dsub size of each sub-vector [2]
408pub fn min_cbow(input: &str, output: &str) -> String {
409 let st = s("cbow -input ") + input + " -output " + output;
410 if wrap_install(&st).status.success() {
411 s(output) + ".bin"
412 } else {
413 panic!("Cbow failed with given input: {}", st)
414 }
415}
416
417
418fn resp(sm: &str, stdout: std::borrow::Cow<str>) -> Vec<Vec<(String, f64)>> {
419 let mut v0 = Vec::new();
420 if DEBUG {
421 println!("Beginning match iteration");
422 println!("stdout: {}", stdout);
423 }
424 for (start, _) in stdout.match_indices(sm) {
425 if DEBUG { println!("Match found: {}", start); }
426 let mut v1 = Vec::new();
427 let mut first = true;
428 for l in stdout[start..].split("\n") {
429 let lar: Vec<&str> = l.split(" ").collect();
430 if DEBUG { println!("{:?}", lar); }
431 if lar.len() == 2 {
432 v1.push((lar[0].to_string(), lar[1].parse::<f64>().unwrap()));
433 } else if lar.len() == 4 && first {
434 v1.push((lar[2].to_string(), lar[3].parse::<f64>().unwrap()));
435 first = false;
436 } else if l == sm || (lar.len() == 4 && !first) {
437 break;
438 } else {
439 panic!("misformatted line in input: {}", l);
440 }
441 }
442 if v1.len() > 0 {
443 v0.push(v1);
444 }
445 }
446 v0
447}
448
449/// runs an arbitrary command and makes sure that the fasttext is set up locally.
450fn run_cmd(cmd: &str) -> Output {
451 if DEBUG { println!("cmd: {}", cmd); }
452 let r = Command::new("sh")
453 .arg("-c")
454 .arg(cmd)
455 .stdout(Stdio::piped())
456 .output()
457 .expect("failed to execute process");
458 if !r.status.success() && !Path::new("./fasttext").exists() {
459 let ir = install();
460 for o in ir.iter() {
461 if !o.status.success() { panic!("Missing files / executable"); }
462 }
463 }
464 if DEBUG { println!("{:?}", r); }
465 r
466}
467
468
469/// Nearest neighbors. Input of "words" are single words separated by spaces.
470///
471/// Full documentation from FastText:
472///
473/// usage: fasttext nn <model> <k>
474///
475/// <model> model filename
476/// <k> (optional; 10 by default) predict top k labels
477pub fn nn(words: &str, model: &str, k: u32) -> Vec<Vec<(String, f64)>> {
478 if DEBUG { println!("NN begun") };
479 let cmd = s("echo ") + words + " | ./fasttext nn " + model + " " + &k.to_string();
480 resp("Query word? ", String::from_utf8_lossy(&run_cmd(&cmd).stdout))
481}
482
/// Access to the analogies function. Not supported: calling this always panics.
///
/// The approach used for the other queries — piping the input in with
/// `echo "<analogies>" | ./fasttext analogies <model> <k>` and parsing what follows each
/// `Query triplet (A - B + C)? ` prompt — won't work here, since fastText just keeps
/// checking stdin and re-outputting results.
///
/// # Panics
///
/// Always panics via `unimplemented!()`.
///
/// # Documentation from fastText
///
/// ```text
/// usage: fasttext analogies <model> <k>
///
///   <model>      model filename
///   <k>          (optional; 10 by default) predict top k labels
/// ```
pub fn analogies(analogies: &str, model: &str, k: u32) -> Vec<Vec<(String, f64)>> {
    // The parameters belong to the intended interface but nothing can use them yet.
    let _ = (analogies, model, k);
    unimplemented!()
}
498
499fn parse_vec(cmd: &str, sentence: Option<&str>) -> Vec<Vec<f64>> {
500 let mut out = Vec::new();
501 let mut st = String::from_utf8_lossy(&run_cmd(cmd).stdout).to_string();
502 match sentence {
503 None => (),
504 Some(sent) => {
505 st = st.replace(sent, "");
506 }
507 };
508 for l in st.split("\n") {
509 let mut wordvec = Vec::new();
510 let mut f = true;
511 for t in l.split(" ") {
512 if f {
513 f = false;
514 } else {
515 if t != "" {
516 wordvec.push(t.parse::<f64>().unwrap());
517 }
518 }
519 }
520 if wordvec.len() > 0 {
521 out.push(wordvec);
522 }
523 }
524 out
525}
526
527/// access to the vectors for a given set of words.
528///
529/// Input: one or more words (separated by spaces)
530/// Output: A vec of word vectors (one for each input word)
531pub fn word_vector(words: &str, model: &str) -> Vec<Vec<f64>> {
532 let cmd = s("echo \"") + words + "\" | ./fasttext print-word-vectors " + model;
533 parse_vec(&cmd, None)
534}
535
536
537/// access to the vectors for a given sentence.
538///
539/// Input: sentence
540/// Output: A vec of a sentence vector
541pub fn sentence_vector(sentence: &str, model: &str) -> Vec<Vec<f64>> {
542 let cmd = s("echo \"") + sentence + "\" | ./fasttext print-sentence-vectors " + model;
543 parse_vec(&cmd, Some(sentence))
544}
545
546
/// The objective for testing here is not to check that the fastText binary is working as
/// expected, but that it can be installed and that its output can be consistently read.

#[cfg(test)]
mod tests {
    extern crate kolmogorov_smirnov as ks;

    use std::{thread, time};
    use std::collections::HashSet;
    use super::*;

    /// Makes sure `file` exists: if it is missing, waits 30 seconds (presumably to give a
    /// concurrently running test time to create it) and, if it is still missing, calls `or`.
    fn check_exists(file: &str, or: fn()) {
        if Path::new(file).exists() {
            return;
        }
        thread::sleep(time::Duration::from_secs(30));
        if !Path::new(file).exists() {
            or()
        }
    }

    /// Removes every given path (shell globs allowed) with `rm -r`, ignoring failures.
    fn rm(files: Vec<&str>) {
        for f in files.iter() {
            Command::new("sh")
                .arg("-c")
                .arg(format!("rm -r {}", f))
                .stdout(Stdio::piped())
                .output()
                .expect("failed to execute process");
        }
    }

    /// Collects the names (dropping the scores) of all query results into one set.
    fn set(v: Vec<Vec<(String, f64)>>) -> HashSet<String> {
        v.into_iter()
            .flat_map(|inner| inner.into_iter())
            .map(|(name, _)| name)
            .collect()
    }

    /// Number of entries the two sets have in common.
    fn sim(a: &HashSet<String>, b: &HashSet<String>) -> usize {
        a.intersection(b).count()
    }

    /// Installs fastText unless the binary is already present.
    fn inst() {
        check_exists("fasttext", || { install(); });
    }

    /// Generates the sample skipgram model unless it is already present.
    fn samp() {
        check_exists("sample.bin", sample_skipgram);
    }

    #[test]
    fn test_install() {
        let rv = install();
        for r in rv.iter() {
            println!("{}", String::from_utf8_lossy(&r.stdout));
            println!("{}", String::from_utf8_lossy(&r.stderr));
            assert!(r.status.success());
        }

        let r = Command::new("sh")
            .arg("-c")
            .arg("./fasttext")
            .stdout(Stdio::piped())
            .output()
            .expect("failed to execute process");
        println!("{}", String::from_utf8_lossy(&r.stdout));
        println!("{}", String::from_utf8_lossy(&r.stderr));
        assert_eq!(r.status.code(), Some(1)); // returns 127 if ./fasttext DNE
    }

    /// Generate a skipgram model for testing things like the nearest neighbor function.
    fn sample_skipgram() {
        inst();
        let model = min_skipgram("sample_text.txt", "sample");
        println!("Generated skipgram model: {}", model);
    }

    #[test]
    fn test_nn() {
        samp();

        let out = nn("lesbian", "sample.bin", 10);
        println!("{:?}", out);
        assert_eq!(out.len(), 1); // number of words queried
        assert_eq!(out[0].len(), 10); // k

        let out = nn("lesbian gay", "sample.bin", 5);
        println!("{:?}", out);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].len(), 5);

        let out = nn("lesbian gay bisexual", "sample.bin", 8);
        println!("{:?}", out);
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 8);

        let out = nn("lesbian gay bisexual transgender", "sample.bin", 1);
        println!("{:?}", out);
        assert_eq!(out.len(), 4);
        assert_eq!(out[0].len(), 1);
    }

    /// Test that nearest neighbors from the models produced by two functions (the minimal
    /// one and the HashMap-driven one) yield valid, largely overlapping results.
    fn test_embedding(min_fn: fn(&str, &str) -> String, reg_fn: fn(&HashMap<&str, &str>), min_name: &str, reg_name: &str) {
        inst();

        let input = "sample_text.txt";
        let mut args = HashMap::new();
        args.insert("input", input);
        args.insert("output", reg_name);

        let k = 10;

        // Training does not depend on the query word, so both models are built once.
        let m1 = min_fn(input, min_name);
        reg_fn(&args);
        // The second model is the one `reg_fn` wrote, named after `reg_name`.
        let m2 = s(reg_name) + ".bin";

        // Would iterate through a set of arbitrary words to compare on.
        for w in ["friend", "day", "door"].iter() {
            let r1 = nn(w, &m1, k);
            let r2 = nn(w, &m2, k);

            assert_eq!(r1.len(), r2.len());
            for i in 0..r1.len() {
                assert_eq!(r1[i].len(), r2[i].len());
                assert_eq!(r1[i].len(), k as usize);
            }

            // NOTE(review): with k = 10 this demands that all ten neighbors coincide;
            // confirm that two separate training runs are stable enough for that.
            assert!(sim(&set(r1), &set(r2)) > (0.9 * k as f64) as usize);
        }

        let r1 = s(min_name) + "*";
        let r2 = s(reg_name) + "*";
        rm(vec![&r1, &r2]);
    }

    #[test]
    fn test_skipgram() {
        test_embedding(min_skipgram, skipgram, "test_min_skipgram", "test_skipgram");
    }

    #[test]
    fn test_cbow() {
        test_embedding(min_cbow, cbow, "test_min_cbow", "test_cbow");
    }

    /// Checks the shape of `predict` output for k = 1 and k = 2 on `t.txt`.
    fn test_predict(model: String) {
        let p = predict(&model, "t.txt", 1);
        println!("test_predict output: {:?}", p);
        // Length first, so an empty result fails the assertion instead of the index.
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 1);

        let p = predict(&model, "t.txt", 2);
        println!("test_predict output: {:?}", p);
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 2);
    }

    /// Checks the shape of `predict_prob` output for k = 1 and k = 2 on `t.txt`.
    fn test_predict_prob(model: String) {
        let p = predict_prob(&model, "t.txt", 1);
        println!("output of predict_prob: {:?}", p);
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 1);

        let p = predict_prob(&model, "t.txt", 2);
        println!("output of predict_prob: {:?}", p);
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 2);
    }

    #[test]
    fn test_supervised_and_predicts() {
        inst();

        let model = "sup";

        let args: HashMap<_, _> = vec![
            ("input", "sample_text.txt"),
            ("output", model),
        ].into_iter().collect();

        supervised(&args);

        test_predict(s(model) + ".bin");
        test_predict_prob(s(model) + ".bin");

        let m = s(model) + "*";
        rm(vec![&m]);
    }

    #[test]
    fn test_word_vector() {
        samp();
        let v = word_vector("gay math queen", "sample.bin");
        assert_eq!(v.len(), 3); // three words go in, three wordvecs come out

        let lengths: HashSet<usize> = v.iter().map(|wv| wv.len()).collect();
        assert_eq!(lengths.len(), 1); // vectors are all the same length

        let v = word_vector("naps", "sample.bin");
        assert_eq!(v.len(), 1);
    }

    #[test]
    fn test_sentence_vector() {
        samp();
        let v = sentence_vector("To die, to sleep – to sleep, perchance to dream – ay, there's the rub, for in this sleep of death what dreams may come…", "sample.bin");
        assert_eq!(v.len(), 1);
    }
}
765}