use serde::{Deserialize, Serialize};
use crate::tokenization::hf_tokenizers::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
/// Pre-tokenizer and decoder built around a single "replacement" character.
///
/// As a pre-tokenizer it substitutes the replacement for every space and then
/// splits on it; as a decoder it turns the replacement back into a space.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(tag = "type")]
pub struct Metaspace {
/// Character that stands in for a space.
replacement: char,
/// `replacement` rendered as a `String`; `new` and `set_replacement` keep the
/// two in sync.
// NOTE(review): this cached copy is included in the derived (de)serialized
// form, so deserialized input could presumably carry a `str_rep` that does not
// match `replacement` — confirm whether that needs guarding.
str_rep: String,
/// When `true`, pre-tokenization prepends the replacement to text that does
/// not already start with it, and decoding drops a replacement found at the
/// very first character position.
pub add_prefix_space: bool,
}
impl Metaspace {
    /// Builds a `Metaspace` that uses `replacement` as the stand-in for spaces.
    ///
    /// `add_prefix_space` is stored as given; the string form of
    /// `replacement` is computed once here and cached.
    pub fn new(replacement: char, add_prefix_space: bool) -> Self {
        let str_rep = replacement.to_string();
        Self {
            replacement,
            str_rep,
            add_prefix_space,
        }
    }

    /// Returns the character currently used in place of spaces.
    pub fn get_replacement(&self) -> char {
        self.replacement
    }

    /// Switches to a new replacement character and refreshes the cached
    /// string form so both fields keep describing the same character.
    pub fn set_replacement(&mut self, replacement: char) {
        self.replacement = replacement;
        // Rebuild the cached string in place from the new character.
        self.str_rep.clear();
        self.str_rep.push(replacement);
    }
}
impl Default for Metaspace {
fn default() -> Self {
Self::new('▁', true)
}
}
impl PreTokenizer for Metaspace {
    /// Rewrites every piece of `pretokenized` so that spaces become the
    /// replacement character, then splits each piece on that character.
    ///
    /// Errors from the replace step or from the split itself are propagated.
    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
        let marker = self.replacement;
        pretokenized.split(|_, mut piece| {
            // Spaces must be substituted first: the prefix check below looks
            // at the text *after* substitution.
            piece.replace(' ', &self.str_rep)?;

            let needs_prefix = self.add_prefix_space && !piece.get().starts_with(marker);
            if needs_prefix {
                piece.prepend(&self.str_rep);
            }

            // Split on the marker with the `MergedWithNext` delimiter behaviour.
            piece.split(marker, SplitDelimiterBehavior::MergedWithNext)
        })
    }
}
impl Decoder for Metaspace {
    /// Concatenates `tokens` into one string, turning each replacement
    /// character back into a space.
    ///
    /// When `add_prefix_space` is set, a replacement character sitting at the
    /// very first character position of the whole sequence is dropped instead
    /// of being emitted as a space. This implementation always returns `Ok`.
    fn decode(&self, tokens: Vec<String>) -> Result<String> {
        let mut decoded = String::new();
        // True only until the first character of the whole sequence is seen.
        let mut at_start = true;

        for token in &tokens {
            for ch in token.chars() {
                if ch != self.replacement {
                    decoded.push(ch);
                } else if !(at_start && self.add_prefix_space) {
                    decoded.push(' ');
                }
                at_start = false;
            }
        }

        Ok(decoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenization::hf_tokenizers::{OffsetReferential, OffsetType};

    /// A single space yields two splits, each carrying its leading replacement.
    #[test]
    fn basic() {
        let pretok = Metaspace::new('▁', true);
        let mut pretokenized = PreTokenizedString::from("Hey friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        // Offsets in the normalized text: '▁' is 3 bytes in UTF-8.
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("▁Hey", (0, 6)), ("▁friend!", (6, 16))]
        );
        // Offsets mapped back onto the 11-byte original input.
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("▁Hey", (0, 3)), ("▁friend!", (3, 11))]
        );
    }

    /// Consecutive spaces each produce their own split.
    #[test]
    fn multiple_spaces() {
        let pretok = Metaspace::new('▁', true);
        // Three spaces between the words: the expected original offsets below
        // end at byte 13 (3 + 3 + 7), and the two stand-alone "▁" splits come
        // from the first two of those spaces.
        let mut pretokenized = PreTokenizedString::from("Hey   friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁", (6, 9)),
                ("▁", (9, 12)),
                ("▁friend!", (12, 22)),
            ]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 3)),
                ("▁", (3, 4)),
                ("▁", (4, 5)),
                ("▁friend!", (5, 13)),
            ]
        );
    }

    /// Decoding drops the leading replacement and restores inner spaces.
    #[test]
    fn decode() {
        let decoder = Metaspace::new('▁', true);
        let res = decoder
            .decode(vec!["▁Hey".into(), "▁friend!".into()])
            .unwrap();
        assert_eq!(&res, "Hey friend!")
    }
}