rust_transformers/preprocessing/
adapters.rs

// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

13use std::fs::File;
14use std::error::Error;
15
/// Classification label attached to an [`Example`].
///
/// `Positive` and `Negative` are produced from the label strings `"1"` and `"0"`
/// respectively when parsing labelled data; `Unassigned` is used by the
/// constructors that take no label.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Label {
    Positive,
    Negative,
    Unassigned,
}

/// A single input example: one or two sentences plus a [`Label`].
///
/// Single-sentence examples store an empty string in `sentence_2`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Example {
    /// First (or only) sentence.
    pub sentence_1: String,
    /// Second sentence; empty for single-sentence examples.
    pub sentence_2: String,
    /// Class label, or [`Label::Unassigned`] when no label is known.
    pub label: Label,
}

impl Example {
    /// Builds a labelled example from two sentences and a raw label string.
    ///
    /// # Errors
    ///
    /// Returns an error if `label` is anything other than exactly `"0"`
    /// (negative) or `"1"` (positive).
    fn new(sentence_1: &str, sentence_2: &str, label: &str) -> Result<Self, Box<dyn Error>> {
        // Validate the label first so no owned strings are allocated for a rejected record.
        let label = match label {
            "0" => Label::Negative,
            "1" => Label::Positive,
            _ => return Err("invalid label class (must be 0 or 1)".into()),
        };
        Ok(Example {
            sentence_1: sentence_1.to_owned(),
            sentence_2: sentence_2.to_owned(),
            label,
        })
    }

    /// Builds an unlabelled single-sentence example (`sentence_2` is empty).
    pub fn new_from_string(sentence: &str) -> Self {
        Example {
            sentence_1: sentence.to_owned(),
            sentence_2: String::new(),
            label: Label::Unassigned,
        }
    }

    /// Builds an unlabelled sentence-pair example.
    pub fn new_from_strings(sentence_1: &str, sentence_2: &str) -> Self {
        Example {
            sentence_1: sentence_1.to_owned(),
            sentence_2: sentence_2.to_owned(),
            label: Label::Unassigned,
        }
    }
}
59
60pub fn read_sst2(path: &str, sep: u8) -> Result<Vec<Example>, Box<dyn Error>> {
61    let mut examples: Vec<Example> = Vec::new();
62    let f = File::open(path).expect("Could not open source file.");
63
64    let mut rdr = csv::ReaderBuilder::new()
65        .has_headers(true)
66        .delimiter(sep)
67        .flexible(false)
68        .from_reader(f);
69
70    for result in rdr.records() {
71        let record = result?;
72        let example = Example::new(&record[0], "", &record[1])?;
73        examples.push(example);
74    };
75    Ok(examples)
76}