use crate::error::TokenizerError;
use crate::tokenizer::base_tokenizer::{Token, TokenRef};
use crate::tokenizer::tokenization_utils::{is_punctuation, is_whitespace};
use crate::vocab::sentencepiece_proto::sentencepiece_model::ModelProto;
use crate::{Mask, Offset, OffsetSize};
use hashbrown::HashMap;
use protobuf::Message;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::fs::File;
use std::io::Read;
use std::ops::Index;
/// Lookup table from a piece string to an integer value.
///
/// `SentencePieceBpeModel::from_file` fills it with each piece's position in the
/// loaded model's piece list, and `maybe_add_pair` uses the looked-up value as the
/// score of a candidate merge.
#[derive(Debug, Clone)]
pub struct BpeMergeVocab {
// Piece text -> value (the piece's index in the model file when built by `from_file`).
pub values: HashMap<String, i64>,
}
/// BPE model that splits a token into sub-tokens by repeatedly merging adjacent
/// symbols whose concatenated text is present in `bpe_ranks`.
pub struct SentencePieceBpeModel {
// Piece -> score table consulted when deciding which adjacent symbols may be merged.
bpe_ranks: BpeMergeVocab,
}
impl SentencePieceBpeModel {
/// Builds a model from a serialized SentencePiece `ModelProto` stored at `path`.
///
/// Every piece of the model is registered under its position in the piece list,
/// which later serves as the merge score of that piece.
///
/// # Errors
///
/// * `TokenizerError::FileNotFound` if the file cannot be opened.
/// * `TokenizerError::VocabularyParsingError` if the file cannot be read or its
///   content cannot be parsed as a `ModelProto`.
pub fn from_file(path: &str) -> Result<SentencePieceBpeModel, TokenizerError> {
    let mut vocabulary_file = File::open(path).map_err(|e| {
        TokenizerError::FileNotFound(format!("{} vocabulary file not found :{}", path, e))
    })?;

    let mut raw_bytes = Vec::new();
    vocabulary_file
        .read_to_end(&mut raw_bytes)
        .map_err(|e| TokenizerError::VocabularyParsingError(e.to_string()))?;

    let proto = ModelProto::parse_from_bytes(raw_bytes.as_slice())
        .map_err(|e| TokenizerError::VocabularyParsingError(e.to_string()))?;

    // A piece appearing more than once keeps the index of its last occurrence.
    let values = proto
        .get_pieces()
        .iter()
        .enumerate()
        .map(|(rank, piece)| (piece.get_piece().to_owned(), rank as i64))
        .collect();

    Ok(SentencePieceBpeModel {
        bpe_ranks: BpeMergeVocab { values },
    })
}
/// Splits `initial_token` into sub-tokens using the BPE merge table.
///
/// The token text is first broken into one symbol per character. Adjacent symbols
/// whose concatenation is a known piece are queued in a priority heap, and the
/// best-ranked pair is merged repeatedly until the heap is empty. Each surviving
/// symbol becomes one output `Token` whose offsets are shifted by the begin
/// offset of `initial_token`.
///
/// Tokens masked `Special` or `Unknown` are not split and are returned as a
/// single-element vector. In both cases `populate_masks` is run on the result.
pub fn tokenize_to_tokens(&self, initial_token: TokenRef) -> Vec<Token> {
    if initial_token.mask == Mask::Special || initial_token.mask == Mask::Unknown {
        let mut passthrough = vec![initial_token.to_owned()];
        // NOTE(review): `populate_masks` may overwrite the Special/Unknown mask of
        // this token (e.g. with `Continuation`) — confirm this is intended.
        self.populate_masks(passthrough.as_mut_slice(), '\u{2581}');
        return passthrough;
    }

    let text = initial_token.text;
    let mut symbols = SymbolList::from(initial_token);
    let mut agenda: BinaryHeap<SymbolPair> = BinaryHeap::new();

    // Seed the agenda with every pair of neighbouring characters.
    for right in 1..symbols.len() {
        let right = right as isize;
        self.maybe_add_pair(right - 1, right, text, &symbols, &mut agenda);
    }

    while let Some(candidate) = agenda.pop() {
        let (left, right) = (candidate.left, candidate.right);
        if left == -1 || right == -1 {
            continue;
        }
        // `merge_symbols` returns `None` for stale candidates (a side was already
        // merged away or has grown since the pair was queued).
        let merged = symbols.merge_symbols(left as usize, right as usize, candidate.pair_size);
        if let Some(merged) = merged {
            // The merged symbol lives at `left`; offer it to both of its neighbours.
            self.maybe_add_pair(merged.prev, left, text, &symbols, &mut agenda);
            self.maybe_add_pair(left, merged.next, text, &symbols, &mut agenda);
        }
    }

    // Emptied slots (`None`) are skipped; remaining symbols are emitted in order.
    let mut sub_tokens: Vec<Token> = symbols
        .into_iter()
        .flatten()
        .map(|symbol| Token {
            text: text[symbol.start_byte..symbol.end_byte].to_string(),
            offset: Offset {
                begin: symbol.start_offset as OffsetSize + initial_token.offset.begin,
                end: symbol.end_offset as OffsetSize + initial_token.offset.begin,
            },
            reference_offsets: initial_token.reference_offsets
                [symbol.start_offset..symbol.end_offset]
                .to_vec(),
            mask: Default::default(),
        })
        .collect();

    self.populate_masks(sub_tokens.as_mut_slice(), '\u{2581}');
    sub_tokens
}
/// Queues the pair (`left_symbol_index`, `right_symbol_index`) for merging when
/// the text spanned by both symbols is a known piece.
///
/// Nothing is queued when either index is `-1` (no neighbour), when either slot
/// of `symbols` is empty, or when the spanned text is absent from the merge table.
/// The queued pair records the combined size of both symbols so that a later
/// merge attempt can detect that the pair has become stale.
fn maybe_add_pair(
    &self,
    left_symbol_index: isize,
    right_symbol_index: isize,
    input_text: &str,
    symbols: &SymbolList,
    agenda: &mut BinaryHeap<SymbolPair>,
) {
    if left_symbol_index == -1 || right_symbol_index == -1 {
        return;
    }

    let left_slot = symbols[left_symbol_index as usize];
    let right_slot = symbols[right_symbol_index as usize];
    let (left, right) = match (left_slot, right_slot) {
        (Some(left), Some(right)) => (left, right),
        _ => return,
    };

    let candidate_text = &input_text[left.start_byte..right.end_byte];
    if let Some(&score) = self.bpe_ranks.values.get(candidate_text) {
        agenda.push(SymbolPair {
            left: left_symbol_index,
            right: right_symbol_index,
            score,
            pair_size: left.size + right.size,
        });
    }
}
/// Assigns a `Mask` to each token of `tokens`, in order.
///
/// * A token made of exactly one character for which `is_punctuation` is true is
///   masked `Punctuation`; one for which `is_whitespace` is true is masked
///   `Whitespace`.
/// * Any other token is masked `Continuation` when it does not start with
///   `whitespace_token` and the previous token was neither punctuation nor
///   whitespace.
/// * Otherwise the token's mask is left untouched.
///
/// # Arguments
///
/// * `tokens` - tokens whose masks are updated in place.
/// * `whitespace_token` - character marking the beginning of a word.
pub fn populate_masks(&self, tokens: &mut [Token], whitespace_token: char) {
    let mut previous_mask = Mask::None;
    for token in tokens {
        // Pull at most two characters: `(Some, None)` means the text is exactly one
        // character long, without walking the whole string.
        let mut characters = token.text.chars();
        if let (Some(only_char), None) = (characters.next(), characters.next()) {
            if is_punctuation(&only_char) {
                token.mask = Mask::Punctuation;
                previous_mask = Mask::Punctuation;
                continue;
            }
            if is_whitespace(&only_char) {
                token.mask = Mask::Whitespace;
                previous_mask = Mask::Whitespace;
                continue;
            }
        }

        let follows_boundary =
            previous_mask == Mask::Punctuation || previous_mask == Mask::Whitespace;
        if !token.text.starts_with(whitespace_token) && !follows_boundary {
            token.mask = Mask::Continuation;
            previous_mask = Mask::Continuation;
        } else {
            // The token starts a new word: its mask is kept as is.
            previous_mask = Mask::None;
        }
    }
}
}
/// A contiguous piece of the token being split, linked to its neighbours.
///
/// Symbols start out as single characters and grow as adjacent symbols are merged.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct Symbol {
// Byte range of the symbol inside the token text (`end_byte` is exclusive).
start_byte: usize,
end_byte: usize,
// Character-index range of the symbol inside the token (`end_offset` is exclusive).
start_offset: usize,
end_offset: usize,
// Indices of the neighbouring symbols in the `SymbolList`; -1 when there is none.
prev: isize,
next: isize,
// Number of initial (single-character) symbols merged into this one.
size: usize,
}
/// Candidate merge of two adjacent symbols, stored in the merge agenda.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct SymbolPair {
    /// Index of the left symbol.
    left: isize,
    /// Index of the right symbol.
    right: isize,
    /// Score of the merged piece; lower scores are merged first.
    score: i64,
    /// Combined size of both symbols when the pair was created.
    pair_size: usize,
}

impl Ord for SymbolPair {
    /// Orders pairs so that a max-heap pops the lowest `score` first and, among
    /// equal scores, the smallest `left` index first.
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: the comparison is reversed.
        (other.score, other.left).cmp(&(self.score, self.left))
    }
}

impl PartialOrd for SymbolPair {
    /// Delegates to the total order defined by `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
/// Index-addressable list of symbols for the token being split.
struct SymbolList {
// One slot per initial character; a slot is set to `None` by `merge_symbols`
// once its symbol has been absorbed into its left neighbour.
symbols: Vec<Option<Symbol>>,
}
impl Index<usize> for SymbolList {
    type Output = Option<Symbol>;

    /// Returns the slot at position `index`; panics when `index` is out of bounds.
    fn index(&self, index: usize) -> &Self::Output {
        &self.symbols[index]
    }
}
impl IntoIterator for SymbolList {
type Item = Option<Symbol>;
type IntoIter = <Vec<Option<Symbol>> as IntoIterator>::IntoIter;
fn into_iter(self) -> Self::IntoIter {
self.symbols.into_iter()
}
}
impl From<TokenRef<'_>> for SymbolList {
    /// Builds the initial symbol list of a token: one symbol per character.
    ///
    /// Each symbol records the byte range and character index of its character and
    /// is linked to the previous and next character by index. The first symbol has
    /// `prev == -1` and the last one has `next == -1`.
    fn from(token: TokenRef) -> Self {
        // Counted once up front: recounting inside the loop would make the
        // construction quadratic in the number of characters.
        let character_count = token.text.chars().count();
        let mut symbols = Vec::with_capacity(character_count);
        for (index, (character_start, character)) in token.text.char_indices().enumerate() {
            let next = if index + 1 == character_count {
                -1
            } else {
                (index + 1) as isize
            };
            symbols.push(Some(Symbol {
                start_byte: character_start,
                end_byte: character_start + character.len_utf8(),
                start_offset: index,
                end_offset: index + 1,
                prev: index as isize - 1,
                next,
                size: 1,
            }));
        }
        Self { symbols }
    }
}
impl SymbolList {
    /// Number of slots in the list, counting slots emptied by merges.
    pub fn len(&self) -> usize {
        self.symbols.len()
    }

    /// Merges the symbol at `symbol_2_index` into the one at `symbol_1_index`.
    ///
    /// The merge is performed only when both slots are occupied and their combined
    /// size equals `size_validation`; otherwise `None` is returned and the list is
    /// left unchanged. On success the merged symbol replaces the left slot, the
    /// right slot is emptied, the symbol following the right one is re-linked to
    /// the left slot, and a copy of the merged symbol is returned.
    pub fn merge_symbols(
        &mut self,
        symbol_1_index: usize,
        symbol_2_index: usize,
        size_validation: usize,
    ) -> Option<Symbol> {
        let (left, right) = match (self[symbol_1_index], self[symbol_2_index]) {
            (Some(left), Some(right)) => (left, right),
            _ => return None,
        };

        let merged_size = left.size + right.size;
        if merged_size != size_validation {
            return None;
        }

        // The symbol after the right one must now point back to the left slot.
        if right.next != -1 {
            let successor_slot = self.symbols.get_mut(right.next as usize).unwrap();
            if let Some(successor) = successor_slot.as_mut() {
                successor.prev = symbol_1_index as isize;
            }
        }

        let merged = Symbol {
            start_byte: left.start_byte,
            end_byte: right.end_byte,
            start_offset: left.start_offset,
            end_offset: right.end_offset,
            prev: left.prev,
            next: right.next,
            size: merged_size,
        };
        self.symbols[symbol_2_index] = None;
        self.symbols[symbol_1_index] = Some(merged);
        Some(merged)
    }
}