use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufRead};
use std::ptr;
use std::mem::ManuallyDrop;
/// A borrowed pair of adjacent symbols, used to query a [`BpePairVocab`] without
/// cloning the two strings.
#[derive(Eq, PartialEq, Hash, Clone, Debug)]
pub struct BpePairRef<'a> {
    /// Left-hand symbol of the pair.
    pub byte_1: &'a String,
    /// Right-hand symbol of the pair.
    pub byte_2: &'a String,
}

/// Table of symbol pairs read from a BPE merges file.
pub struct BpePairVocab {
    /// Maps each `(left, right)` symbol pair to its 0-based position among the pair
    /// lines of the merges file (the skipped first line is not counted).
    pub values: HashMap<(String, String), i64>,
}

/// Object-safe view of a `(left, right)` string pair.
///
/// `HashMap::get` only accepts a borrowed key type `Q` for which the stored key type
/// implements `Borrow<Q>`, and `(String, String)` cannot be borrowed as
/// `(&str, &str)`. Both tuple forms therefore implement this trait, and lookups go
/// through `&dyn PairKey`. This lets [`BpePairVocab::byte_pair_to_id`] query the map
/// without allocating and without aliasing the caller's strings.
trait PairKey {
    /// Left-hand symbol.
    fn first(&self) -> &str;
    /// Right-hand symbol.
    fn second(&self) -> &str;
}

impl PairKey for (String, String) {
    fn first(&self) -> &str {
        &self.0
    }

    fn second(&self) -> &str {
        &self.1
    }
}

impl<'s> PairKey for (&'s str, &'s str) {
    fn first(&self) -> &str {
        self.0
    }

    fn second(&self) -> &str {
        self.1
    }
}

impl<'a> std::borrow::Borrow<dyn PairKey + 'a> for (String, String) {
    fn borrow(&self) -> &(dyn PairKey + 'a) {
        self
    }
}

// `Borrow` requires hashing and equality on the borrowed form to agree with the owned
// key. A tuple hashes its fields in order, and `String` hashes exactly like `str`, so
// hashing `first()` then `second()` as `str` reproduces the hash of `(String, String)`.
impl<'a> std::hash::Hash for dyn PairKey + 'a {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(self.first(), state);
        std::hash::Hash::hash(self.second(), state);
    }
}

impl<'a> PartialEq for dyn PairKey + 'a {
    fn eq(&self, other: &Self) -> bool {
        self.first() == other.first() && self.second() == other.second()
    }
}

impl<'a> Eq for dyn PairKey + 'a {}

impl BpePairVocab {
    /// Reads a merges file into a [`BpePairVocab`].
    ///
    /// The first line of the file is always skipped (e.g. a `#version: 0.1` header).
    /// Each following line is trimmed and split on single spaces. Lines that yield at
    /// least two fields contribute the pair formed by the first two fields, numbered
    /// consecutively from 0. Lines with fewer fields are ignored and do not consume a
    /// number. If a pair appears more than once, the later number wins.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened or a line cannot be read (for example,
    /// because it is not valid UTF-8).
    pub fn from_file(path: &str) -> BpePairVocab {
        let file = File::open(path).expect("Could not open vocabulary file.");
        let values = BufReader::new(file)
            .lines()
            .skip(1)
            .map(|line| line.expect("Could not read line from vocabulary file."))
            .filter_map(|line| {
                // `split(' ')` rather than `split_whitespace` keeps the established
                // parsing: consecutive spaces yield empty fields.
                let mut fields = line.trim().split(' ');
                let first = fields.next()?.to_owned();
                let second = fields.next()?.to_owned();
                Some((first, second))
            })
            // Number only the lines that produced a pair; collecting into a HashMap
            // lets a later duplicate overwrite an earlier one.
            .zip(0i64..)
            .collect();
        BpePairVocab { values }
    }

    /// Returns the number stored for `byte_pair`, or `None` if the pair is absent.
    ///
    /// The lookup only borrows the pair's strings; no owned key is built.
    pub fn byte_pair_to_id(&self, byte_pair: &BpePairRef) -> Option<&i64> {
        let pair = (byte_pair.byte_1.as_str(), byte_pair.byte_2.as_str());
        let key: &dyn PairKey = &pair;
        self.values.get(key)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::io::Write;

    /// Writes the shared merges fixture (header line plus four pairs) to a temporary
    /// file and returns its path; the file is deleted when the path is dropped.
    fn write_merges_file() -> Result<tempfile::TempPath, io::Error> {
        let mut merges_file = tempfile::NamedTempFile::new()?;
        write!(merges_file, "#version: 0.1\n t h\na n\ni n\nth e</w>")?;
        Ok(merges_file.into_temp_path())
    }

    /// Builds an owned vocabulary key from two string slices.
    fn pair(first: &str, second: &str) -> (String, String) {
        (first.to_owned(), second.to_owned())
    }

    #[test]
    fn test_create_pair_vocab() {
        let empty: HashMap<(String, String), i64> = HashMap::new();
        let pair_vocab = BpePairVocab {
            values: empty.clone(),
        };
        assert_eq!(pair_vocab.values, empty);
    }

    #[test]
    fn test_create_pair_vocab_from_file() -> Result<(), io::Error> {
        let merges_path = write_merges_file()?;
        let expected: HashMap<(String, String), i64> = [
            (pair("t", "h"), 0),
            (pair("a", "n"), 1),
            (pair("i", "n"), 2),
            (pair("th", "e</w>"), 3),
        ]
        .iter()
        .cloned()
        .collect();
        let pair_vocab = BpePairVocab::from_file(merges_path.to_str().unwrap());
        assert_eq!(pair_vocab.values, expected);
        drop(merges_path);
        Ok(())
    }

    #[test]
    fn test_encode_byte_pairs() -> Result<(), io::Error> {
        let merges_path = write_merges_file()?;
        let pair_vocab = BpePairVocab::from_file(merges_path.to_str().unwrap());
        // Each case: the queried (left, right) pair and the id the lookup must yield.
        let cases: [((&str, &str), Option<i64>); 5] = [
            (("t", "h"), Some(0)),
            (("a", "n"), Some(1)),
            (("i", "n"), Some(2)),
            (("th", "e</w>"), Some(3)),
            (("a", "e</w>"), None),
        ];
        for &((left, right), expected_id) in cases.iter() {
            let (left, right) = pair(left, right);
            let query = BpePairRef {
                byte_1: &left,
                byte_2: &right,
            };
            assert_eq!(pair_vocab.byte_pair_to_id(&query), expected_id.as_ref());
        }
        drop(merges_path);
        Ok(())
    }
}