use crate::error::TokenizerError;
use crate::vocab::sentencepiece_proto::sentencepiece_model::ModelProto;
use protobuf::Message;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufRead, BufReader, Read};
use std::mem::ManuallyDrop;
use std::path::Path;
use std::ptr;
/// A borrowed pair of adjacent sub-tokens, used as a lookup key into [`BpePairVocab`].
///
/// The fields are deliberately `&String` rather than the more idiomatic `&str`:
/// `byte_pair_to_id` performs a `ptr::read` of each `String` to build a
/// zero-allocation `(String, String)` lookup key, which requires access to the
/// full `String` values.
#[derive(Eq, PartialEq, Hash, Clone, Debug)]
pub struct BpePairRef<'a> {
    // Left element of the merge pair.
    pub byte_1: &'a String,
    // Right element of the merge pair.
    pub byte_2: &'a String,
}
/// Vocabulary of byte-pair-encoding merges.
///
/// Maps an ordered `(left, right)` token pair to its id: for merges files this is
/// the 0-based line order, for SentencePiece models it is the id of the
/// concatenated piece.
#[derive(Debug, Clone)]
pub struct BpePairVocab {
    // Pair-to-id map consulted by `byte_pair_to_id`.
    pub values: HashMap<(String, String), i64>,
}
impl BpePairVocab {
    /// Reads a BPE merges file into a pair vocabulary.
    ///
    /// The first line (the `#version` header) is skipped. Each remaining line is
    /// expected to contain two space-separated merge tokens; any further tokens on
    /// the line are ignored, and lines with fewer than two tokens are skipped.
    /// Pairs are assigned consecutive ids in file order, starting at 0.
    ///
    /// # Errors
    ///
    /// Returns `TokenizerError::FileNotFound` if the file cannot be opened, or
    /// `TokenizerError::VocabularyParsingError` if a line cannot be read.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<BpePairVocab, TokenizerError> {
        let f = File::open(&path).map_err(|e| {
            TokenizerError::FileNotFound(format!(
                "{} vocabulary file not found :{}",
                path.as_ref().display(),
                e
            ))
        })?;
        let br = BufReader::new(f);
        let mut data = HashMap::new();
        let mut index = 0;
        for line in br.lines().skip(1) {
            let line = line.map_err(|e| TokenizerError::VocabularyParsingError(e.to_string()))?;
            // Take the first two space-separated tokens lazily instead of
            // collecting the whole line into a Vec<String> per iteration.
            let mut parts = line.trim().split(' ');
            if let (Some(first), Some(second)) = (parts.next(), parts.next()) {
                data.insert((first.to_owned(), second.to_owned()), index);
                index += 1;
            }
        }
        Ok(BpePairVocab { values: data })
    }

    /// Builds a BPE pair vocabulary from a SentencePiece protobuf model file.
    ///
    /// Every ordered pair of pieces `(l, r)` whose concatenation `lr` is itself a
    /// piece in the model is registered, keyed to the id of the concatenated piece.
    /// Note that this scans all ordered piece pairs, i.e. O(n²) in the number of
    /// pieces.
    ///
    /// # Errors
    ///
    /// Returns `TokenizerError::FileNotFound` if the file cannot be opened, or
    /// `TokenizerError::VocabularyParsingError` if it cannot be read or decoded.
    pub fn from_sentencepiece_file<P: AsRef<Path>>(
        path: P,
    ) -> Result<BpePairVocab, TokenizerError> {
        let mut f = File::open(&path).map_err(|e| {
            TokenizerError::FileNotFound(format!(
                "{} vocabulary file not found :{}",
                path.as_ref().display(),
                e
            ))
        })?;
        let mut contents = Vec::new();
        f.read_to_end(&mut contents)
            .map_err(|e| TokenizerError::VocabularyParsingError(e.to_string()))?;
        let proto = ModelProto::parse_from_bytes(contents.as_slice())
            .map_err(|e| TokenizerError::VocabularyParsingError(e.to_string()))?;

        // Index every piece by its position so concatenations can be looked up below.
        let mut values = HashMap::new();
        for (idx, piece) in proto.get_pieces().iter().enumerate() {
            values.insert(piece.get_piece().to_owned(), idx as i64);
        }
        let mut data = HashMap::new();
        for l_piece in proto.get_pieces().iter().map(|v| v.get_piece()) {
            for r_piece in proto.get_pieces().iter().map(|v| v.get_piece()) {
                if let Some(id) = values.get(&[l_piece, r_piece].concat()) {
                    data.insert((l_piece.to_string(), r_piece.to_string()), *id);
                }
            }
        }
        Ok(BpePairVocab { values: data })
    }

    /// Looks up the id of a byte pair, returning `None` if the pair is not a
    /// known merge. The lookup performs no allocation.
    pub fn byte_pair_to_id(&self, byte_pair: &BpePairRef) -> Option<&i64> {
        unsafe {
            let byte_1 = byte_pair.byte_1;
            let byte_2 = byte_pair.byte_2;
            // SAFETY: `ptr::read` makes bitwise copies of the two `String`s solely
            // to assemble a `(String, String)` key without allocating. The copies
            // are wrapped in `ManuallyDrop` so they are never dropped, leaving
            // ownership (and the eventual deallocation) with the data behind
            // `byte_pair`. `HashMap::get` only borrows the key, so no double-free
            // or use-after-free can occur.
            let k = (ptr::read(byte_1), ptr::read(byte_2));
            let k = ManuallyDrop::new(k);
            self.values.get(&k)
        }
    }
}
#[cfg(test)]
#[allow(clippy::type_complexity)]
mod tests {
    extern crate anyhow;
    use super::*;
    use std::io::Write;

    // An empty pair vocabulary should hold exactly the map it was built from.
    #[test]
    fn test_create_pair_vocab() {
        let values: HashMap<(String, String), i64> = HashMap::new();
        let pair_vocab = BpePairVocab {
            values: values.clone(),
        };
        assert_eq!(pair_vocab.values, values);
    }

    // Parsing a merges file assigns consecutive ids and skips the version header.
    #[test]
    fn test_create_pair_vocab_from_file() -> anyhow::Result<()> {
        let mut merges_file = tempfile::NamedTempFile::new()?;
        write!(merges_file, "#version: 0.1\n t h\na n\ni n\nth e</w>")?;
        let path = merges_file.into_temp_path();

        let pair_vocab = BpePairVocab::from_file(&path)?;

        let target_values: HashMap<(String, String), i64> = vec![
            (("t".to_owned(), "h".to_owned()), 0_i64),
            (("a".to_owned(), "n".to_owned()), 1),
            (("i".to_owned(), "n".to_owned()), 2),
            (("th".to_owned(), "e</w>".to_owned()), 3),
        ]
        .into_iter()
        .collect();
        assert_eq!(pair_vocab.values, target_values);

        drop(path);
        Ok(())
    }

    // Pair lookups return the merge id for known pairs and None otherwise.
    #[test]
    fn test_encode_byte_pairs() -> anyhow::Result<()> {
        let mut merges_file = tempfile::NamedTempFile::new()?;
        write!(merges_file, "#version: 0.1\n t h\na n\ni n\nth e</w>")?;
        let path = merges_file.into_temp_path();
        let pair_vocab = BpePairVocab::from_file(&path)?;

        // (left, right, expected id); None marks a pair absent from the merges.
        let cases: [(&str, &str, Option<i64>); 5] = [
            ("t", "h", Some(0)),
            ("a", "n", Some(1)),
            ("i", "n", Some(2)),
            ("th", "e</w>", Some(3)),
            ("a", "e</w>", None),
        ];
        for &(left, right, expected) in cases.iter() {
            let byte_1 = left.to_string();
            let byte_2 = right.to_string();
            let pair = BpePairRef {
                byte_1: &byte_1,
                byte_2: &byte_2,
            };
            assert_eq!(pair_vocab.byte_pair_to_id(&pair).copied(), expected);
        }

        drop(path);
        Ok(())
    }
}