use onig::Regex;
use serde::{Deserialize, Serialize};
use crate::tokenization::hf_tokenizers::tokenizer::{
pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SplitPattern {
String(String),
Regex(String),
}
impl From<String> for SplitPattern {
fn from(v: String) -> Self {
SplitPattern::String(v)
}
}
impl From<&str> for SplitPattern {
fn from(v: &str) -> Self {
SplitPattern::String(v.to_owned())
}
}
#[doc(hidden)]
#[derive(Deserialize)]
#[serde(tag = "type")]
struct SplitDeserializer {
pattern: SplitPattern,
behavior: SplitDelimiterBehavior,
invert: bool,
}
impl std::convert::TryFrom<SplitDeserializer> for Split {
type Error = Box<dyn std::error::Error + Send + Sync>;
fn try_from(v: SplitDeserializer) -> Result<Self> {
Split::new(v.pattern, v.behavior, v.invert)
}
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", try_from = "SplitDeserializer")]
pub struct Split {
pattern: SplitPattern,
#[serde(skip)]
regex: Regex,
behavior: SplitDelimiterBehavior,
invert: bool,
}
impl Clone for Split {
fn clone(&self) -> Self {
Split::new(self.pattern.clone(), self.behavior, self.invert).unwrap()
}
}
impl PartialEq for Split {
fn eq(&self, other: &Split) -> bool {
self.pattern == other.pattern
&& self.behavior == other.behavior
&& self.invert == other.invert
}
}
impl Split {
pub fn new<I: Into<SplitPattern>>(
pattern: I,
behavior: SplitDelimiterBehavior,
invert: bool,
) -> Result<Self> {
let pattern: SplitPattern = pattern.into();
let regex = match &pattern {
SplitPattern::String(s) => Regex::new(®ex::escape(s))?,
SplitPattern::Regex(r) => Regex::new(r)?,
};
Ok(Self {
pattern,
regex,
behavior,
invert,
})
}
}
impl PreTokenizer for Split {
fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
if self.invert {
pretokenized.split(|_, normalized| normalized.split(Invert(&self.regex), self.behavior))
} else {
pretokenized.split(|_, normalized| normalized.split(&self.regex, self.behavior))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenization::hf_tokenizers::{OffsetReferential, OffsetType, PreTokenizer};
use SplitDelimiterBehavior::*;
#[test]
fn basic() {
let tests = vec![
(
Removed,
"How are you doing?",
vec![
("How", (0, 3)),
("are", (4, 7)),
("you", (8, 11)),
("doing", (12, 17)),
("?", (17, 18)),
],
),
(
Isolated,
"How are you doing?",
vec![
("How", (0, 3)),
(" ", (3, 4)),
("are", (4, 7)),
(" ", (7, 8)),
("you", (8, 11)),
(" ", (11, 12)),
("doing", (12, 17)),
("?", (17, 18)),
],
),
(
MergedWithPrevious,
"How are you doing?",
vec![
("How ", (0, 4)),
("are ", (4, 8)),
("you ", (8, 12)),
("doing", (12, 17)),
("?", (17, 18)),
],
),
(
MergedWithNext,
"How are you doing?",
vec![
("How", (0, 3)),
(" are", (3, 7)),
(" you", (7, 11)),
(" doing", (11, 17)),
("?", (17, 18)),
],
),
(
Contiguous,
"How are you doing?",
vec![
("How", (0, 3)),
(" ", (3, 4)),
("are", (4, 7)),
(" ", (7, 8)),
("you", (8, 11)),
(" ", (11, 12)),
("doing?", (12, 18)),
],
),
];
let regex = SplitPattern::Regex(r"\w+|[^\w\s]+".into());
for (behavior, s, res) in tests {
let mut pretokenized = PreTokenizedString::from(s);
let pretok = Split::new(regex.clone(), behavior, true).unwrap();
pretok.pre_tokenize(&mut pretokenized).unwrap();
assert_eq!(
pretokenized
.get_splits(OffsetReferential::Original, OffsetType::Byte)
.into_iter()
.map(|(s, o, _)| (s, o))
.collect::<Vec<_>>(),
res
);
}
}
#[test]
fn regex_string() {
let mut pretok_str_for_regex = PreTokenizedString::from("Hey, man!");
let mut pretok_str_for_string = pretok_str_for_regex.clone();
let pretokenizer_regex = Split::new(
SplitPattern::Regex(r"\s+".into()),
SplitDelimiterBehavior::Removed,
false,
)
.unwrap();
let pretokenizer_string = Split::new(" ", SplitDelimiterBehavior::Removed, false).unwrap();
pretokenizer_regex
.pre_tokenize(&mut pretok_str_for_regex)
.unwrap();
pretokenizer_string
.pre_tokenize(&mut pretok_str_for_string)
.unwrap();
assert_eq!(pretok_str_for_regex, pretok_str_for_string);
}
#[test]
fn invert() {
let mut pretok_str = PreTokenizedString::from("Hello Hello Hello");
let mut pretok_str_for_invert = pretok_str.clone();
let pretokenizer = Split::new(" ", SplitDelimiterBehavior::Removed, false).unwrap();
let pretokenizer_invert =
Split::new("Hello", SplitDelimiterBehavior::Removed, true).unwrap();
pretokenizer.pre_tokenize(&mut pretok_str).unwrap();
pretokenizer_invert
.pre_tokenize(&mut pretok_str_for_invert)
.unwrap();
assert_eq!(pretok_str, pretok_str_for_invert);
}
#[test]
fn serialization() {
use SplitDelimiterBehavior::*;
let split = Split::new("Hello", Removed, true).unwrap();
let split_s =
r#"{"type":"Split","pattern":{"String":"Hello"},"behavior":"Removed","invert":true}"#;
assert_eq!(serde_json::to_string(&split).unwrap(), split_s);
assert_eq!(serde_json::from_str::<Split>(split_s).unwrap(), split);
let split = Split::new(SplitPattern::Regex(r"\s+".into()), Isolated, false).unwrap();
let split_s =
r#"{"type":"Split","pattern":{"Regex":"\\s+"},"behavior":"Isolated","invert":false}"#;
assert_eq!(serde_json::to_string(&split).unwrap(), split_s);
assert_eq!(serde_json::from_str::<Split>(split_s).unwrap(), split);
}
}