rust_transformers/preprocessing/
adapters.rs1use std::fs::File;
14use std::error::Error;
15
/// Class label attached to an [`Example`].
///
/// Plain data, so it is `Copy` and comparable; `Unassigned` marks examples
/// that were built without a label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Label {
    /// Positive class (the raw label `"1"` when parsed by `Example::new`).
    Positive,
    /// Negative class (the raw label `"0"` when parsed by `Example::new`).
    Negative,
    /// No label has been attached to the example.
    Unassigned,
}
22
23#[derive(Debug)]
24pub struct Example {
25 pub sentence_1: String,
26 pub sentence_2: String,
27 pub label: Label,
28}
29
30impl Example {
31 fn new(sentence_1: &str, sentence_2: &str, label: &str) -> Result<Self, Box<dyn Error>> {
32 Ok(Example {
33 sentence_1: String::from(sentence_1),
34 sentence_2: String::from(sentence_2),
35 label: match label {
36 "0" => Ok(Label::Negative),
37 "1" => Ok(Label::Positive),
38 _ => Err("invalid label class (must be 0 or 1)")
39 }?,
40 })
41 }
42
43 pub fn new_from_string(sentence: &str) -> Self {
44 Example {
45 sentence_1: String::from(sentence),
46 sentence_2: String::from(""),
47 label: Label::Unassigned,
48 }
49 }
50
51 pub fn new_from_strings(sentence_1: &str, sentence_2: &str) -> Self {
52 Example {
53 sentence_1: String::from(sentence_1),
54 sentence_2: String::from(sentence_2),
55 label: Label::Unassigned,
56 }
57 }
58}
59
60pub fn read_sst2(path: &str, sep: u8) -> Result<Vec<Example>, Box<dyn Error>> {
61 let mut examples: Vec<Example> = Vec::new();
62 let f = File::open(path).expect("Could not open source file.");
63
64 let mut rdr = csv::ReaderBuilder::new()
65 .has_headers(true)
66 .delimiter(sep)
67 .flexible(false)
68 .from_reader(f);
69
70 for result in rdr.records() {
71 let record = result?;
72 let example = Example::new(&record[0], "", &record[1])?;
73 examples.push(example);
74 };
75 Ok(examples)
76}