1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
#[macro_use] extern crate serde_derive; pub use self::utils::launch_run; use rusty_machine::analysis::score::accuracy; use serde_json::{json, Value}; use std::fs::File; pub use structopt::StructOpt; pub type Error = failure::Error; mod csvrow; mod utils; use csvrow::CsvRow; pub fn compute_scores(predictions: Vec<String>, effective_class: Vec<String>) -> () { let acc = accuracy(predictions.iter(), effective_class.iter()); println!("Final accuracy: {}", acc); } fn parse_row( url: &String, rw: Result<CsvRow, csv::Error>, client: &reqwest::Client, predictions: &mut Vec<String>, effective_class: &mut Vec<String>, ) -> Result<(), Error> { let test_row = rw?; let query = test_row.query; let real_intention = test_row.intention; let params = utils::get_params(&query); let resp: Value = client.post(url).json(¶ms).send()?.json()?; let predicted_intention: String = resp["intention"][0][1] .to_string() .trim_matches('\"') .to_string(); println!( " --- query: \'{}\', real_intention: \'{}\', predicted_intention: \'{}\'", query, real_intention, predicted_intention ); effective_class.push(real_intention); predictions.push(predicted_intention); Ok(()) } pub fn parse_csv<I>(url: String, paths: I) -> Result<(Vec<String>, Vec<String>), Error> where I: Iterator<Item = std::path::PathBuf>, { let mut predictions: Vec<String> = Vec::new(); let mut effective_class: Vec<String> = Vec::new(); let client = reqwest::Client::new(); for path in paths { println!( "Loading test file: \'{}\'...", path.clone().into_os_string().into_string().unwrap() ); let file = match File::open(path) { Ok(f) => f, Err(_) => continue, }; let mut rdr = csv::Reader::from_reader(file); for rw in rdr.deserialize::<CsvRow>() { match parse_row(&url, rw, &client, &mut predictions, &mut effective_class) { Ok(_) => (), Err(_) => continue, } } } Ok((predictions, effective_class)) }