use super::{
normalizer::Range, Model, NormalizedString, Normalizer, Offsets, PreTokenizedString, Token,
};
use serde::{ser::SerializeSeq, Deserialize, Serialize, Serializer};
use std::collections::{HashMap, HashSet};
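/// A token added on top of the model's vocabulary.
///
/// The options control how the token is matched in the input text:
/// - `single_word`: only match when surrounded by word boundaries
/// - `lstrip` / `rstrip`: also consume one optional whitespace character to
///   the left / right of the match
/// - `normalized`: match against the normalized text instead of the raw input
///   (`true` by default, `!special` when built with `AddedToken::from`)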
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddedToken {
pub content: String,
pub single_word: bool,
pub lstrip: bool,
pub rstrip: bool,
pub normalized: bool,
}
impl AddedToken {
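/// Build an `AddedToken` from its content. Special tokens are matched on the
/// raw (non-normalized) text, hence `normalized: !special`.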
pub fn from<S: Into<String>>(content: S, special: bool) -> Self {
AddedToken {
content: content.into(),
normalized: !special,
..Default::default()
}
}
pub fn single_word(mut self, single_word: bool) -> Self {
self.single_word = single_word;
self
}
pub fn lstrip(mut self, lstrip: bool) -> Self {
self.lstrip = lstrip;
self
}
pub fn rstrip(mut self, rstrip: bool) -> Self {
self.rstrip = rstrip;
self
}
pub fn normalized(mut self, normalized: bool) -> Self {
self.normalized = normalized;
self
}
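/// Build the regex pattern used to locate this token in a sentence, honoring
/// the `single_word`, `lstrip` and `rstrip` options.
///
/// For illustration (with no normalizer): a token `"ony"` with
/// `single_word(true)` produces `\bony\b`, and `"my"` with both strips
/// produces `(\s)?my(\s)?`.
///
/// Note: the `unwrap`s below panic on an empty `content`; callers in this
/// module filter out empty tokens before getting here.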
pub fn get_pattern<N: Normalizer>(&self, normalizer: Option<&N>) -> String {
let mut r = if self.single_word {
let first_b = self
.content
.chars()
.next()
.map(|c| {
if regex_syntax::is_word_character(c) {
r"\b"
} else {
""
}
})
.unwrap();
let last_b = self
.content
.chars()
.last()
.map(|c| {
if regex_syntax::is_word_character(c) {
r"\b"
} else {
""
}
})
.unwrap();
let mut content = NormalizedString::from(self.content.as_ref());
if let Some(n) = normalizer {
let _ = n.normalize(&mut content);
}
format!(r"{}{}{}", first_b, regex::escape(content.get()), last_b)
} else {
regex::escape(&self.content)
};
if self.lstrip && self.rstrip {
r = format!(r"(\s)?{}(\s)?", r);
} else if self.lstrip {
r = format!(r"(\s)?{}", r);
} else if self.rstrip {
r = format!(r"{}(\s)?", r);
}
r
}
}
impl Default for AddedToken {
fn default() -> Self {
AddedToken {
content: String::new(),
single_word: false,
lstrip: false,
rstrip: false,
normalized: true,
}
}
}
impl std::hash::Hash for AddedToken {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.content.hash(state);
}
}
impl std::cmp::PartialEq for AddedToken {
fn eq(&self, other: &Self) -> bool {
self.content == other.content
}
}
impl std::cmp::Eq for AddedToken {}
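/// A set of regex patterns together with the token id each pattern resolves
/// to (same index in both collections).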
type MatchingSet = (regex::RegexSet, Vec<u32>);
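/// Vocabulary of tokens added on top of a `Model`'s own vocabulary.
///
/// Added tokens receive ids starting at `model.get_vocab_size()`, so they do
/// not clash with the model's ids as long as the model does not grow
/// afterwards. Two regex sets are maintained: one matched against the raw
/// text (for non-normalized tokens) and one matched after normalization.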
#[derive(Clone, Debug)]
pub(super) struct AddedVocabulary {
added_tokens_map: HashMap<String, u32>,
added_tokens_map_r: HashMap<u32, AddedToken>,
added_tokens: Vec<AddedToken>,
special_tokens: Vec<AddedToken>,
special_tokens_set: HashSet<String>,
split_re: MatchingSet,
split_normalized_re: MatchingSet,
}
impl AddedVocabulary {
pub fn new() -> Self {
Self {
added_tokens_map: HashMap::new(),
added_tokens_map_r: HashMap::new(),
added_tokens: vec![],
special_tokens: vec![],
special_tokens_set: HashSet::new(),
split_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
split_normalized_re: (regex::RegexSet::new::<_, &&str>(&[]).unwrap(), vec![]),
}
}
pub fn len(&self) -> usize {
self.added_tokens_map.len()
}
pub fn get_vocab(&self) -> &HashMap<String, u32> {
&self.added_tokens_map
}
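/// Resolve a token to its id, looking in the added vocabulary first and
/// falling back to the model.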
pub fn token_to_id(&self, token: &str, model: &impl Model) -> Option<u32> {
self.added_tokens_map
.get(token)
.copied()
.or_else(|| model.token_to_id(token))
}
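/// Resolve an id back to its token content, added vocabulary first.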
pub fn id_to_token(&self, id: u32, model: &impl Model) -> Option<String> {
self.added_tokens_map_r
.get(&id)
.map(|t| t.content.clone())
.or_else(|| model.id_to_token(id))
}
pub fn is_special_token(&self, token: &str) -> bool {
self.special_tokens_set.contains(token)
}
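/// Register special tokens. They are recorded in `special_tokens_set` (so
/// `is_special_token` reports them) and then added like regular tokens.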
pub fn add_special_tokens<N: Normalizer>(
&mut self,
tokens: &[AddedToken],
model: &impl Model,
normalizer: Option<&N>,
) -> usize {
for token in tokens {
if !token.content.is_empty() && !self.special_tokens_set.contains(&token.content) {
self.special_tokens.push(token.to_owned());
self.special_tokens_set.insert(token.content.clone());
}
}
self.add_tokens(tokens, model, normalizer)
}
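/// Add tokens to the vocabulary, skipping empty ones. Tokens already known to
/// the added vocabulary or the model keep their existing id, but their stored
/// options are refreshed. Returns the number of tokens actually added.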
pub fn add_tokens<N: Normalizer>(
&mut self,
tokens: &[AddedToken],
model: &impl Model,
normalizer: Option<&N>,
) -> usize {
let mut ignored = 0;
for token in tokens {
if token.content.is_empty() {
ignored += 1;
continue;
}
let id = if let Some(id) = self.token_to_id(&token.content, model) {
ignored += 1;
id
} else {
let new_id = (model.get_vocab_size() + self.added_tokens_map.len()) as u32;
self.added_tokens_map.insert(token.content.clone(), new_id);
if !self.special_tokens_set.contains(&token.content) {
self.added_tokens.push(token.clone());
}
new_id
};
self.added_tokens_map_r
.entry(id)
.and_modify(|t| *t = token.clone())
.or_insert_with(|| token.clone());
}
self.refresh_added_tokens(model, normalizer);
tokens.len() - ignored
}
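/// Rebuild both regex sets from the current special and added tokens,
/// partitioned on their `normalized` flag.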
fn refresh_added_tokens<N: Normalizer>(&mut self, model: &impl Model, normalizer: Option<&N>) {
type TupleTokenId<'a> = (&'a AddedToken, u32);
let (normalized, non_normalized): (Vec<TupleTokenId>, Vec<TupleTokenId>) = self
.special_tokens
.iter()
.chain(self.added_tokens.iter())
.map(|token| {
(
token,
self.token_to_id(&token.content, model)
.expect("Missing additional token"),
)
})
.partition(|(token, _)| token.normalized);
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = non_normalized.into_iter().unzip();
self.split_re = (
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
ids,
);
let (tokens, ids): (Vec<&AddedToken>, Vec<u32>) = normalized.into_iter().unzip();
self.split_normalized_re = (
regex::RegexSet::new(tokens.iter().map(|t| t.get_pattern(normalizer))).unwrap(),
ids,
);
}
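/// Find all matches of the given `MatchingSet` in `sentence`, returning
/// non-overlapping `(token_id, byte_offsets)` splits covering the whole
/// sentence; stretches without a match are reported with `None`.
///
/// Overlaps are resolved in favor of the earliest start, then the lowest
/// pattern index: with hypothetical patterns `["ab", "b"]` on `"ab"`, the
/// result is a single split `(Some(id_of_ab), (0, 2))`.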
fn find_matches<'a>(
&self,
sentence: &str,
split_re: &'a MatchingSet,
) -> Vec<(Option<u32>, Offsets)> {
if sentence.is_empty() {
return vec![(None, (0, 0))];
}
let mut matches = split_re
.0
.matches(sentence)
.into_iter()
.flat_map(|idx| {
regex::Regex::new(&split_re.0.patterns()[idx])
.unwrap()
.find_iter(sentence)
.map(|m| (idx, (m.start(), m.end())))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
matches.sort_by(|(idxa, (sa, _)), (idxb, (sb, _))| {
if sa != sb {
sa.cmp(sb)
} else {
idxa.cmp(idxb)
}
});
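// Keep only non-overlapping matches: within a run of overlapping matches,
// prefer the one that starts earliest, breaking ties by lowest pattern index.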
let mut i = 0;
let mut current_offset = 0;
let mut splits = Vec::with_capacity(matches.len());
while i < matches.len() {
let (idx, (start, end)) = matches[i];
if start < current_offset {
i += 1;
continue;
}
if i + 1 < matches.len() {
if let Some((idx, (s, e))) = matches[i..]
.iter()
.take_while(|(_, (s, e))| *s < end && start < *e)
.min()
.copied()
{
splits.push((idx, (s, e)));
current_offset = e;
i += 1;
continue;
}
}
splits.push((idx, (start, end)));
current_offset = end;
i += 1;
}
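// Fill the stretches between matches with `None` splits so the returned
// offsets cover the entire sentence.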
let mut start_offset = 0;
let mut splits = splits
.into_iter()
.flat_map(|(idx, (start, end))| {
let mut splits = vec![];
if start_offset < start {
splits.push((None, (start_offset, start)));
}
splits.push((Some(split_re.1[idx]), (start, end)));
start_offset = end;
splits
})
.collect::<Vec<_>>();
let total_byte_len = sentence.len();
if start_offset != total_byte_len {
splits.push((None, (start_offset, total_byte_len)));
}
splits
}
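/// Split a `NormalizedString` at the offsets found by `find_matches`. Matched
/// slices carry a single `Token` with the resolved id; unmatched slices carry
/// `None` so the model can tokenize them later.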
fn split_with_indices(
&self,
sentence: NormalizedString,
split_re: &MatchingSet,
) -> Vec<(NormalizedString, Option<Vec<Token>>)> {
self.find_matches(sentence.get(), split_re)
.into_iter()
.map(|(id, byte_offsets)| {
let slice = sentence
.slice(Range::Normalized(byte_offsets.0..byte_offsets.1))
.expect("AddedVocabulary bad split");
if let Some(id) = id {
let value = slice.get().to_owned();
let len = value.len();
(slice, Some(vec![Token::new(id, value, (0, len))]))
} else {
(slice, None)
}
})
.collect()
}
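/// Extract all added (and special) tokens from `sequence`, normalizing the
/// stretches in between.
///
/// Non-normalized tokens are matched against the raw text first; each
/// remaining split is then normalized and matched against the normalized set.
/// A rough sketch of the outcome (ids depend on insertion order):
///
/// ```text
/// "[CLS] My name is Anthony [SEP]"
///   -> ["[CLS]"] [" My "] ["name"] [" is Anthony "] ["[SEP]"]
/// ```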
pub fn extract_and_normalize<N: Normalizer>(
&self,
normalizer: Option<&N>,
sequence: &str,
) -> PreTokenizedString {
let mut pretokenized: PreTokenizedString = sequence.into();
pretokenized
.split(|_, sequence| Ok(self.split_with_indices(sequence, &self.split_re)))
.expect("AddedVocabulary bad split");
pretokenized
.split(|_, mut sequence| {
if let Some(n) = normalizer {
let _ = n.normalize(&mut sequence);
}
Ok(self.split_with_indices(sequence, &self.split_normalized_re))
})
.expect("AddedVocabulary bad split");
pretokenized
}
}
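/// An `AddedToken` together with its id and special flag, as it appears in
/// the serialized vocabulary.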
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct AddedTokenWithId {
pub id: u32,
pub special: bool,
#[serde(flatten)]
pub token: AddedToken,
}
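// Serialized as a sequence of `AddedTokenWithId`, sorted by id so the output
// is deterministic.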
impl Serialize for AddedVocabulary {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut added_tokens = self
.added_tokens_map_r
.iter()
.map(|(id, token)| AddedTokenWithId {
id: *id,
special: self.special_tokens_set.contains(&token.content),
token: token.clone(),
})
.collect::<Vec<_>>();
added_tokens.sort_unstable_by_key(|o| o.id);
let mut vocabulary = serializer.serialize_seq(Some(added_tokens.len()))?;
for token in added_tokens {
vocabulary.serialize_element(&token)?;
}
vocabulary.end()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenization::hf_tokenizers::normalizers::utils::Lowercase;
use crate::tokenization::hf_tokenizers::normalizers::NormalizerWrapper;
use crate::tokenization::hf_tokenizers::{OffsetReferential, OffsetType, Result, Token, Trainer};
use std::path::{Path, PathBuf};
#[derive(Serialize, Deserialize)]
struct ModelMock {
vocab: HashMap<String, u32>,
vocab_r: HashMap<u32, String>,
}
impl ModelMock {
pub fn new<I>(iter: I) -> Self
where
I: IntoIterator<Item = &'static (&'static str, u32)>,
{
let vocab: HashMap<String, u32> = iter
.into_iter()
.map(|&(tok, id)| (tok.to_string(), id))
.collect();
Self {
vocab_r: vocab
.iter()
.map(|(tok, id)| (*id, tok.to_owned()))
.collect(),
vocab,
}
}
}
struct TrainerMock;
impl Trainer for TrainerMock {
type Model = ModelMock;
fn should_show_progress(&self) -> bool {
true
}
fn train(&self, _model: &mut ModelMock) -> Result<Vec<AddedToken>> {
unimplemented!()
}
fn feed<I, S, F>(&mut self, _iterator: I, _process: F) -> Result<()>
where
I: Iterator<Item = S> + Send,
S: AsRef<str> + Send,
F: Fn(&str) -> Result<Vec<String>> + Sync,
{
unimplemented!()
}
}
impl Model for ModelMock {
type Trainer = TrainerMock;
fn tokenize(&self, _sequence: &str) -> Result<Vec<Token>> {
unimplemented!()
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.vocab_r.get(&id).cloned()
}
fn get_vocab(&self) -> HashMap<String, u32> {
self.vocab.clone()
}
fn get_vocab_size(&self) -> usize {
self.vocab.len()
}
fn save(&self, _folder: &Path, _name: Option<&str>) -> Result<Vec<PathBuf>> {
unimplemented!()
}
fn get_trainer(&self) -> Self::Trainer {
TrainerMock
}
}
#[test]
fn can_add_tokens() {
let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
let mut vocab = AddedVocabulary::new();
let normalizer: Option<&NormalizerWrapper> = None;
assert_eq!(
vocab.add_tokens(
&[AddedToken::from("added_token_1", false)],
&model,
normalizer
),
1
);
assert_eq!(vocab.len(), 1);
assert_eq!(
vocab.add_tokens(
&[
AddedToken::from("added_token_2", false),
AddedToken::from("added_token_2", false)
],
&model,
normalizer
),
1
);
assert_eq!(vocab.len(), 2);
assert_eq!(
vocab.add_tokens(&[AddedToken::from("test", false)], &model, normalizer),
0
);
assert_eq!(vocab.len(), 2);
}
#[test]
fn can_add_special_tokens() {
let model = ModelMock::new(&[("test", 0), ("tost", 1)]);
let mut vocab = AddedVocabulary::new();
let normalizer: Option<&NormalizerWrapper> = None;
assert_eq!(
vocab.add_special_tokens(
&[AddedToken::from("added_token_1", true)],
&model,
normalizer
),
1
);
assert_eq!(vocab.len(), 1);
assert_eq!(
vocab.add_special_tokens(
&[
AddedToken::from("added_token_2", true),
AddedToken::from("added_token_2", true)
],
&model,
normalizer
),
1
);
assert_eq!(vocab.len(), 2);
assert_eq!(
vocab.add_special_tokens(&[AddedToken::from("test", true)], &model, normalizer),
0
);
assert_eq!(vocab.len(), 2);
assert!(vocab.is_special_token("test"));
assert!(!vocab.added_tokens_map.contains_key("test"));
}
#[test]
fn can_extract_added_tokens() {
let model = ModelMock::new(&[]);
let mut vocab = AddedVocabulary::new();
let normalizer: Option<&NormalizerWrapper> = None;
vocab.add_tokens(
&[
AddedToken::from("my", false),
AddedToken::from("name", false),
],
&model,
normalizer,
);
vocab.add_special_tokens(
&[
AddedToken::from("[CLS]", true),
AddedToken::from("[SEP]", true),
],
&model,
normalizer,
);
let result = vocab.extract_and_normalize(normalizer, "[CLS] My name is Anthony [SEP]");
assert_eq!(
result
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, _, tokens)| (
s,
tokens
.as_ref()
.map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>())
))
.collect::<Vec<_>>(),
vec![
("[CLS]", Some(vec![2])),
(" My ", None),
("name", Some(vec![1])),
(" is Anthony ", None),
("[SEP]", Some(vec![3]))
]
);
}
#[test]
fn options_use_cases() {
let model = ModelMock::new(&[]);
let normalizer = Lowercase;
let mut vocab = AddedVocabulary::new();
vocab.add_tokens(
&[
AddedToken::from("my", false).lstrip(true).rstrip(true),
AddedToken::from("name", false),
AddedToken::from("ony", false).single_word(true),
],
&model,
Some(&normalizer),
);
vocab.add_special_tokens(
&[
AddedToken::from("[CLS]", true),
AddedToken::from("[SEP]", true),
],
&model,
Some(&normalizer),
);
let result =
vocab.extract_and_normalize(Some(&normalizer), "[CLS] My name is Anthony [SEP]");
assert_eq!(
result
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, _, tokens)| (
s,
tokens
.as_ref()
.map(|t| t.iter().map(|t| t.id).collect::<Vec<_>>())
))
.collect::<Vec<_>>(),
vec![
("[CLS]", Some(vec![3])),
(" my ", Some(vec![0])),
("name", Some(vec![1])),
(" is anthony ", None),
("[SEP]", Some(vec![4])),
]
);
}
#[test]
fn empty_matches() {
let vocab = AddedVocabulary::new();
let matches = vocab.find_matches("", &vocab.split_re);
assert_eq!(matches, vec![(None, (0, 0))]);
}
}