#[macro_use(c)]
extern crate cute;
use std::collections::HashMap;
use std::path::Path;
use std::process::{Command, Output, Stdio};
/// fastText release that `install` downloads and builds (used in the archive and directory names).
const VERSION: &str = "0.1.0";
/// When `true`, commands, their raw `Output`s and parsing steps are printed with `println!`,
/// and `predict_prob` additionally asserts that every label has a probability.
const DEBUG: bool = true;
/// Shorthand that turns a string slice into an owned `String`, so command lines can be
/// assembled with `+`.
fn s(v: &str) -> String {
    String::from(v)
}
/// Downloads fastText `VERSION` from GitHub, builds it with `make`, moves the resulting
/// `fasttext` binary into the current directory and deletes the downloaded archive and
/// source tree.
///
/// Returns the captured [`Output`] of every step, in order. Each step runs in its own
/// `sh -c` invocation, and every step runs even if an earlier one failed, so callers must
/// check each status themselves.
///
/// # Panics
/// Panics if `sh` itself cannot be spawned.
pub fn install() -> Vec<Output> {
    let cmds = [
        // Derive the archive name from VERSION like every other step, so bumping the
        // constant cannot leave the download pointing at a stale release.
        s("wget https://github.com/facebookresearch/fastText/archive/v") + VERSION + ".zip",
        s("unzip v") + VERSION + ".zip",
        s("cd fastText-") + VERSION + "; make",
        s("mv fastText-") + VERSION + "/fasttext .",
        s("rm -r fastText-") + VERSION,
        s("rm v") + VERSION + ".zip"
    ];
    c![
        Command::new("sh")
            .arg("-c")
            .arg(cmd)
            .stdout(Stdio::piped())
            .output()
            .expect("failed to execute process"),
        for cmd in cmds.iter()
    ]
}
fn wrap_install(cmds: &str) -> Output {
let st = s("./fasttext ") + cmds;
run_cmd(&st)
}
pub fn supervised(args: &HashMap<&str, &str>) {
gen_mod(s("supervised"), args);
}
pub fn quantize(args: &HashMap<&str, &str>) {
gen_mod(s("quantize"), args);
}
pub fn predict(model: &str, inp: &str, k: u32) -> Vec<Vec<String>> {
let mut out = Vec::new();
let s = s("predict ") + model + " " + inp + " " + &k.to_string();
let r = wrap_install(&s);
for p in String::from_utf8_lossy(&r.stdout).split("\n") {
let mut innerv = Vec::new();
for v in p.split(" ") {
if v != "" {
innerv.push(v.to_string());
}
}
if innerv.len() != 0 {
out.push(innerv);
}
}
out
}
/// Runs `./fasttext predict-prob <model> <inp> <k>` and returns, for each non-empty output
/// line, its `(label, probability)` pairs.
///
/// # Panics
/// Panics if a probability token is not a valid `f64`, or, when `DEBUG` is on, if a line
/// ends with a label that has no probability.
pub fn predict_prob(model: &str, inp: &str, k: u32) -> Vec<Vec<(String, f64)>> {
    /// Splits one output line into alternating `label probability` tokens.
    fn ext(l: &str) -> Vec<(String, f64)> {
        let tokens: Vec<&str> = l.split(' ').filter(|t| !t.is_empty()).collect();
        // `f` means "every label has a probability" (even token count). The name is kept
        // because it appears in the `assert!` failure message.
        let f = tokens.len() % 2 == 0;
        let pairs: Vec<(String, f64)> = tokens
            .chunks(2)
            .filter(|pair| pair.len() == 2)
            .map(|pair| (pair[0].to_string(), pair[1].parse::<f64>().unwrap()))
            .collect();
        if DEBUG { assert!(f); }
        pairs
    }

    let cmd = format!("predict-prob {} {} {}", model, inp, k);
    let output = wrap_install(&cmd);
    let text = String::from_utf8_lossy(&output.stdout);
    let rows: Vec<Vec<(String, f64)>> = text
        .split('\n')
        .map(ext)
        .filter(|row| !row.is_empty())
        .collect();
    rows
}
/// Runs `./fasttext <s>` with every `args` entry appended as a ` -<key> <value>` flag.
///
/// Flags follow the map's iteration order, which `HashMap` leaves unspecified.
///
/// # Panics
/// Panics if the command cannot be spawned or exits unsuccessfully.
fn gen_mod(mut s: String, args: &HashMap<&str, &str>) {
    // Iterate key/value pairs directly instead of looking every key up a second time.
    for (key, value) in args {
        s.push_str(" -");
        s.push_str(key);
        s.push(' ');
        s.push_str(value);
    }
    if !wrap_install(&s).status.success() {
        panic!("Gen_mod failed with given input: {}", s)
    }
}
pub fn skipgram(args: &HashMap<&str, &str>) {
gen_mod(s("skipgram"), args);
}
pub fn cbow(args: &HashMap<&str, &str>) {
gen_mod(s("cbow"), args);
}
pub fn min_skipgram(input: &str, output: &str) -> String {
let st = s("skipgram -input ") + input + " -output " + output;
let o = wrap_install(&st);
if o.status.success() {
s(output) + ".bin"
} else {
panic!("Min_skipgram failed with given input: {} \noutput: {:?}", st, o)
}
}
/// Runs `./fasttext cbow -input <input> -output <output>` and returns the model path
/// `<output>.bin`.
///
/// # Panics
/// Panics if the command exits unsuccessfully. The panic message includes the command's
/// full output, matching `min_skipgram`.
pub fn min_cbow(input: &str, output: &str) -> String {
    let st = s("cbow -input ") + input + " -output " + output;
    let o = wrap_install(&st);
    if o.status.success() {
        s(output) + ".bin"
    } else {
        // Include the captured output (status, stdout, stderr) so failures are diagnosable.
        panic!("Cbow failed with given input: {} \noutput: {:?}", st, o)
    }
}
/// Parses the output of an interactive fastText query loop into one `(word, score)` list per
/// occurrence of the prompt `sm`.
///
/// Scanning starts at each occurrence of `sm`:
/// - The first line after the prompt must split on single spaces into exactly four fields.
///   The last two are taken as the first `(word, score)`.
/// - Lines with exactly two fields are further `(word, score)` entries.
/// - A later four-field line, or a line equal to `sm`, ends the group.
/// - Any other line panics.
///
/// Groups with no entries are dropped.
///
/// # Panics
/// Panics on a misformatted line or a score that is not a valid `f64`.
fn resp(sm: &str, stdout: std::borrow::Cow<str>) -> Vec<Vec<(String, f64)>> {
    if DEBUG {
        println!("Beginning match iteration");
        println!("stdout: {}", stdout);
    }
    let mut groups = Vec::new();
    for (start, _) in stdout.match_indices(sm) {
        if DEBUG { println!("Match found: {}", start); }
        let mut group = Vec::new();
        // True until the prompt-prefixed first result line has been consumed.
        let mut at_prompt_line = true;
        for line in stdout[start..].split('\n') {
            let fields: Vec<&str> = line.split(' ').collect();
            if DEBUG { println!("{:?}", fields); }
            match fields.len() {
                2 => group.push((fields[0].to_string(), fields[1].parse::<f64>().unwrap())),
                4 if at_prompt_line => {
                    group.push((fields[2].to_string(), fields[3].parse::<f64>().unwrap()));
                    at_prompt_line = false;
                }
                // A second prompt-prefixed line starts the next query's results.
                4 => break,
                _ if line == sm => break,
                _ => panic!("misformatted line in input: {}", line),
            }
        }
        if !group.is_empty() {
            groups.push(group);
        }
    }
    groups
}
/// Runs `cmd` through `sh -c` and returns its captured output.
///
/// If the command fails and no `./fasttext` executable exists, fastText is built with
/// `install` and the command is run once more. The caller then receives the result of that
/// retry rather than the failure caused by the missing binary.
///
/// # Panics
/// Panics if `sh` cannot be spawned, or if any installation step exits unsuccessfully.
fn run_cmd(cmd: &str) -> Output {
    if DEBUG { println!("cmd: {}", cmd); }
    let mut r = run_sh(cmd);
    if !r.status.success() && !Path::new("./fasttext").exists() {
        for o in install().iter() {
            if !o.status.success() { panic!("Missing files / executable"); }
        }
        // The first attempt failed for lack of the binary; without this retry the caller
        // would get that stale failure even though the install just succeeded.
        r = run_sh(cmd);
    }
    if DEBUG { println!("{:?}", r); }
    r
}

/// Spawns `sh -c <cmd>`, waits for it to finish, and captures stdout and stderr.
fn run_sh(cmd: &str) -> Output {
    Command::new("sh")
        .arg("-c")
        .arg(cmd)
        .stdout(Stdio::piped())
        .output()
        .expect("failed to execute process")
}
/// Queries the `k` nearest neighbours of each word in `words` with
/// `./fasttext nn <model> <k>`, returning one `(word, score)` list per query that produced
/// results.
///
/// # Panics
/// Panics if the command cannot be spawned or its output does not have the layout `resp`
/// expects.
pub fn nn(words: &str, model: &str, k: u32) -> Vec<Vec<(String, f64)>> {
    if DEBUG { println!("NN begun") };
    // Single-quote the query text, escaping embedded `'` as `'\''`, so the shell hands it to
    // fastText verbatim. The former unquoted `echo` exposed it to globbing, variable expansion
    // and command injection. `printf '%s\n'` also avoids `echo`'s implementation-defined
    // backslash handling.
    let quoted = words.replace('\'', "'\\''");
    let cmd = s("printf '%s\\n' '") + &quoted + "' | ./fasttext nn " + model + " " + &k.to_string();
    resp("Query word? ", String::from_utf8_lossy(&run_cmd(&cmd).stdout))
}
/// Intended to return, for each `A B C` triplet in `analogies`, the `k` words closest to
/// `A - B + C` in `model`. Not implemented yet.
///
/// # Panics
/// Always panics with a "not implemented" message.
pub fn analogies(analogies: &str, model: &str, k: u32) -> Vec<Vec<(String, f64)>> {
    // The parameters are kept for API stability; bind them to avoid unused-variable warnings.
    let _ = (analogies, model, k);
    // An earlier sketch sat after `unimplemented!()` as unreachable code, which caused a compiler
    // warning. It piped the triplets into `./fasttext analogies` and parsed the reply with
    // `resp`. `resp` only accepts a prompt-prefixed result line that splits into exactly four
    // space-separated fields. That fits `Query word? <word> <score>` but not the longer
    // `Query triplet (A - B + C)? ` prompt, so a dedicated parser is needed first.
    unimplemented!("analogies: parsing `./fasttext analogies` output is not supported yet")
}
/// Runs `cmd` and parses every output line into a vector of `f64`s.
///
/// For each line, the first space-separated token is skipped; it is expected to be the word
/// label, or what is left of it. The remaining non-empty tokens are parsed. Lines that yield
/// no numbers are dropped.
///
/// When `sentence` is given, every occurrence of it is first removed from the output. A
/// multi-word sentence at the start of a line therefore leaves an empty first token, which
/// is the one skipped.
///
/// # Panics
/// Panics if any remaining token is not a valid `f64`.
fn parse_vec(cmd: &str, sentence: Option<&str>) -> Vec<Vec<f64>> {
    let raw = String::from_utf8_lossy(&run_cmd(cmd).stdout).to_string();
    let text = if let Some(sent) = sentence { raw.replace(sent, "") } else { raw };
    let mut vectors = Vec::new();
    for line in text.split('\n') {
        let row: Vec<f64> = line
            .split(' ')
            .skip(1)
            .filter(|tok| !tok.is_empty())
            .map(|tok| tok.parse::<f64>().unwrap())
            .collect();
        if !row.is_empty() {
            vectors.push(row);
        }
    }
    vectors
}
/// Returns one vector per word in `words`, as printed by
/// `./fasttext print-word-vectors <model>`.
///
/// # Panics
/// Panics if the command cannot be spawned or prints a non-numeric vector component.
pub fn word_vector(words: &str, model: &str) -> Vec<Vec<f64>> {
    // Single-quote the words, escaping embedded `'` as `'\''`. Inside the former double quotes,
    // `"`, `$`, backticks and `\` were interpreted by the shell, which allowed command injection
    // and altered the text. `printf '%s\n'` avoids `echo`'s implementation-defined handling
    // of backslashes.
    let quoted = words.replace('\'', "'\\''");
    let cmd = s("printf '%s\\n' '") + &quoted + "' | ./fasttext print-word-vectors " + model;
    parse_vec(&cmd, None)
}
/// Returns the vector(s) `./fasttext print-sentence-vectors <model>` prints for `sentence`.
///
/// `parse_vec` removes every occurrence of `sentence` from the output before parsing. That
/// removal only works if fastText receives the sentence exactly as given, which is why the
/// text is shell-quoted below.
///
/// # Panics
/// Panics if the command cannot be spawned or a remaining token is not a valid `f64`.
pub fn sentence_vector(sentence: &str, model: &str) -> Vec<Vec<f64>> {
    // Single-quote the sentence, escaping embedded `'` as `'\''`. Inside the former double
    // quotes, `"`, `$`, backticks and `\` were interpreted by the shell. That altered the text
    // (so the prefix removal missed and parsing panicked) and allowed command injection.
    // `printf '%s\n'` avoids `echo`'s implementation-defined backslash handling.
    let quoted = sentence.replace('\'', "'\\''");
    let cmd = s("printf '%s\\n' '") + &quoted + "' | ./fasttext print-sentence-vectors " + model;
    parse_vec(&cmd, Some(sentence))
}
#[cfg(test)]
mod tests {
    extern crate kolmogorov_smirnov as ks;
    use std::{thread, time};
    use std::collections::HashSet;
    use super::*;

    /// Calls `or` when `file` is missing, but only after waiting 30 seconds and checking again.
    ///
    /// Presumably the wait lets a concurrently running test finish producing `file`, instead of
    /// both tests producing it at once.
    fn check_exists(file: &str, or: fn()) {
        if !Path::new(file).exists() {
            thread::sleep(time::Duration::from_secs(30));
            if !Path::new(file).exists() {
                or()
            }
        }
    }

    /// Deletes every entry of `files` via `sh -c "rm -r <entry>"`, so shell globs are expanded.
    /// Exit statuses are ignored.
    fn rm(files: Vec<&str>) {
        for f in files.iter() {
            let cmd = s("rm -r ") + f;
            Command::new("sh")
                .arg("-c")
                .arg(&cmd)
                .stdout(Stdio::piped())
                .output()
                .expect("failed to execute process");
        }
    }

    /// Collects the distinct words of nearest-neighbour results, dropping the scores.
    fn set(v: Vec<Vec<(String, f64)>>) -> HashSet<String> {
        let mut out = HashSet::new();
        for v0 in v.into_iter() {
            for t in v0.into_iter() {
                let (st, _) = t;
                out.insert(st);
            }
        }
        out
    }

    /// Number of words present in both `a` and `b`.
    fn sim(a: &HashSet<String>, b: &HashSet<String>) -> usize {
        c![v, for v in a.intersection(b)].len()
    }

    /// Builds `./fasttext` unless it already exists.
    fn inst() {
        check_exists("fasttext", || { install(); });
    }

    /// Trains `sample.bin` unless it already exists.
    fn samp() {
        check_exists("sample.bin", sample_skipgram);
    }

    #[test]
    fn test_install() {
        let rv = install();
        for r in rv.iter() {
            println!("{}", String::from_utf8_lossy(&r.stdout));
            println!("{}", String::from_utf8_lossy(&r.stderr));
            assert!(r.status.success());
        }
        // Running the freshly built binary without arguments is expected to exit with status 1.
        let r = Command::new("sh")
            .arg("-c")
            .arg("./fasttext")
            .stdout(Stdio::piped())
            .output()
            .expect("failed to execute process");
        println!("{}", String::from_utf8_lossy(&r.stdout));
        println!("{}", String::from_utf8_lossy(&r.stderr));
        assert_eq!(r.status.code(), Some(1));
    }

    /// Trains the shared `sample.bin` skipgram model from `sample_text.txt`.
    fn sample_skipgram() {
        inst();
        let model = min_skipgram("sample_text.txt", "sample");
        println!("Generated skipgram model: {}", model);
    }

    #[test]
    fn test_nn() {
        samp();
        let out = nn("lesbian", "sample.bin", 10);
        println!("{:?}", out);
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].len(), 10);
        let out = nn("lesbian gay", "sample.bin", 5);
        println!("{:?}", out);
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].len(), 5);
        let out = nn("lesbian gay bisexual", "sample.bin", 8);
        println!("{:?}", out);
        assert_eq!(out.len(), 3);
        assert_eq!(out[0].len(), 8);
        let out = nn("lesbian gay bisexual transgender", "sample.bin", 1);
        println!("{:?}", out);
        assert_eq!(out.len(), 4);
        assert_eq!(out[0].len(), 1);
    }

    /// Trains one model with `min_fn` (writing `<min_name>.bin`) and one with `reg_fn`
    /// (writing `<reg_name>.bin`) on the same input. Checks that both return `k` neighbours
    /// for a few probe words and that the neighbour sets largely overlap.
    fn test_embedding(min_fn: fn(&str, &str) -> String, reg_fn: fn(&HashMap<&str, &str>), min_name: &str, reg_name: &str) {
        inst();
        let input = "sample_text.txt";
        let mut args = HashMap::new();
        args.insert("input", input);
        args.insert("output", reg_name);
        let k = 10;
        // Train each model once up front instead of retraining both for every probe word.
        let m1 = min_fn(input, min_name);
        reg_fn(&args);
        // `reg_fn` was given `-output <reg_name>`. Deriving this path from `min_name` made the
        // test compare `m1` with itself.
        let m2 = s(reg_name) + ".bin";
        // NOTE(review): fastText training may not be deterministic between runs (e.g. with
        // several threads); if this overlap requirement proves flaky, the threshold may need
        // relaxing.
        for w in ["friend", "day", "door"].iter() {
            let r1 = nn(w, &m1, k);
            let r2 = nn(w, &m2, k);
            assert_eq!(r1.len(), r2.len());
            for i in 0..r1.len() {
                assert_eq!(r1[i].len(), r2[i].len());
                assert_eq!(r1[i].len(), k as usize);
            }
            assert!(sim(&set(r1), &set(r2)) > (0.9 * k as f64) as usize);
        }
        let r1 = s(min_name) + "*";
        let r2 = s(reg_name) + "*";
        rm(vec![&r1, &r2]);
    }

    #[test]
    fn test_skipgram() {
        test_embedding(min_skipgram, skipgram, "test_min_skipgram", "test_skipgram");
    }

    #[test]
    fn test_cbow() {
        test_embedding(min_cbow, cbow, "test_min_cbow", "test_cbow");
    }

    /// Checks `predict` on `t.txt`: two output rows with `k` labels each.
    fn test_predict(model: String) {
        let p = predict(&model, "t.txt", 1);
        println!("test_predict output: {:?}", p);
        // Check the row count before indexing so an empty result fails with a clear message.
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 1);
        let p = predict(&model, "t.txt", 2);
        println!("test_predict output: {:?}", p);
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 2);
    }

    /// Checks `predict_prob` on `t.txt`: two output rows with `k` pairs each.
    fn test_predict_prob(model: String) {
        let p = predict_prob(&model, "t.txt", 1);
        println!("output of predict_prob: {:?}", p);
        // Check the row count before indexing so an empty result fails with a clear message.
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 1);
        let p = predict_prob(&model, "t.txt", 2);
        println!("output of predict_prob: {:?}", p);
        assert_eq!(p.len(), 2);
        assert_eq!(p[0].len(), 2);
    }

    #[test]
    fn test_supervised_and_predicts() {
        inst();
        let model = "sup";
        let args: HashMap<_, _> = vec![
            ("input", "sample_text.txt"),
            ("output", model),
        ].into_iter().collect();
        supervised(&args);
        test_predict(s(model) + ".bin");
        test_predict_prob(s(model) + ".bin");
        let m = s(model) + "*";
        rm(vec![&m]);
    }

    #[test]
    fn test_word_vector() {
        samp();
        let v = word_vector("gay math queen", "sample.bin");
        assert_eq!(v.len(), 3);
        let mut hs = HashSet::new();
        for wv in v.iter() {
            hs.insert(wv.len());
        }
        // Every word vector must have the same dimension.
        assert_eq!(hs.len(), 1);
        let v = word_vector("naps", "sample.bin");
        assert_eq!(v.len(), 1);
    }

    #[test]
    fn test_sentence_vector() {
        samp();
        let v = sentence_vector("To die, to sleep – to sleep, perchance to dream – ay, there's the rub, for in this sleep of death what dreams may come…", "sample.bin");
        assert_eq!(v.len(), 1);
    }
}