use regex_automata::{meta, util::captures::Captures, Anchored, Input};
/// Multi-pattern pretokenizer built on `regex_automata`'s meta regex.
pub struct Regex {
    // All patterns compiled together; a match reports which pattern won.
    inner: meta::Regex,
    // One flag per pattern, in pattern order. `true` marks a pattern that
    // emulates a trailing lookahead: the iterator drops the last matched
    // character from the produced token (see `RegexIter::next`).
    lookahead: Vec<bool>,
}
impl std::fmt::Debug for Regex {
    /// Compact debug form: reports only the number of compiled patterns.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Regex");
        builder.field("patterns", &self.inner.pattern_len());
        builder.finish()
    }
}
impl Regex {
    /// Builds a pretokenizer from `(pattern, is_lookahead)` pairs.
    ///
    /// Every pattern is compiled into one multi-pattern regex; the boolean
    /// flags are stored alongside so the iterator can trim the final
    /// character of matches produced by lookahead-style patterns.
    pub fn new(patterns: &[(&str, bool)]) -> Result<Self, meta::BuildError> {
        let (pats, lookahead): (Vec<&str>, Vec<bool>) = patterns.iter().copied().unzip();
        let inner = meta::Regex::new_many(&pats)?;
        Ok(Self { inner, lookahead })
    }

    /// Pretokenizer using the cl100k-style splitting rules.
    ///
    /// The `\s+\s` lookahead pattern stands in for `\s+(?!\S)` (not
    /// supported by the non-backtracking engine); the trailing matched
    /// character is stripped again during iteration.
    pub fn cl100k() -> Self {
        let specs: [(&str, bool); 3] = [
            (r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+$", false),
            (r"\s+\s", true),
            (r"\s+", false),
        ];
        Self::new(&specs).expect("valid cl100k pattern")
    }

    /// Pretokenizer using the o200k-style splitting rules.
    pub fn o200k() -> Self {
        // Alternatives of the main pattern, joined with `|` below.
        let alternatives = [
            r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
            r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
            r"\p{N}{1,3}",
            r" ?[^\s\p{L}\p{N}]+[\r\n/]*",
            r"\s*[\r\n]+",
            r"\s+$",
        ];
        let combined = alternatives.join("|");
        Self::new(&[(&combined, false), (r"\s+\s", true), (r"\s+", false)])
            .expect("valid o200k pattern")
    }

    /// Pretokenizer using the GPT-2-style splitting rules.
    pub fn gpt2() -> Self {
        let specs: [(&str, bool); 3] = [
            (r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+$", false),
            (r"\s+\s", true),
            (r"\s+", false),
        ];
        Self::new(&specs).expect("valid GPT-2 pattern")
    }

    /// Returns a lazy iterator over the pretokens of `text`.
    pub fn split<'a>(&'a self, text: &'a str) -> RegexIter<'a> {
        let caps = Captures::matches(self.inner.group_info().clone());
        RegexIter {
            pretokenizer: self,
            text,
            pos: 0,
            caps,
        }
    }

    /// Convenience wrapper collecting all pretokens of `text` into a `Vec`.
    pub fn split_to_vec<'a>(&'a self, text: &'a str) -> Vec<&'a str> {
        self.split(text).collect()
    }
}
/// Iterator over the pretokens of a text, created by `Regex::split`.
pub struct RegexIter<'a> {
    // The pretokenizer whose compiled patterns drive the matching.
    pretokenizer: &'a Regex,
    // Full input; yielded tokens are borrowed subslices of this.
    text: &'a str,
    // Byte offset into `text` where the next match attempt starts.
    pos: usize,
    // Reusable capture buffer, cleared before each match.
    caps: Captures,
}
impl<'a> Iterator for RegexIter<'a> {
    type Item = &'a str;

    /// Yields the next pretoken, or `None` at end of input / first failed
    /// anchored match.
    ///
    /// Fixes over the previous version:
    /// - the empty-match fallback used to advance by one *byte*, which could
    ///   land inside a multi-byte UTF-8 character and make the next
    ///   `&self.text[self.pos..]` slice panic; it now skips a whole character;
    /// - the retry was a recursive `self.next()` call (tail calls are not
    ///   guaranteed in Rust), now an explicit loop.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            if self.pos >= self.text.len() {
                return None;
            }
            // Match offsets are relative to the slice, so the match always
            // starts at offset 0 (anchored) and `pos + end` is absolute.
            let input = Input::new(&self.text[self.pos..]).anchored(Anchored::Yes);
            self.caps.clear();
            self.pretokenizer.inner.captures(input, &mut self.caps);
            let m = self.caps.get_match()?;
            let start = self.pos;
            let mut end = self.pos + m.range().end;
            // Lookahead patterns match one character too many by design;
            // give that trailing character back to the remaining input.
            if self.pretokenizer.lookahead[m.pattern().as_usize()] {
                if let Some(last_char) = self.text[start..end].chars().next_back() {
                    end -= last_char.len_utf8();
                }
            }
            if end <= start {
                // Empty token: skip one full character (staying on a UTF-8
                // boundary) and try again.
                let step = self.text[start..]
                    .chars()
                    .next()
                    .map_or(1, char::len_utf8);
                self.pos = start + step;
                continue;
            }
            self.pos = end;
            return Some(&self.text[start..end]);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cl100k_numbers() {
        // Digit runs are capped at three characters per token.
        let pre = Regex::cl100k();
        assert_eq!(pre.split_to_vec("12345"), vec!["123", "45"]);
    }

    #[test]
    fn test_gpt2_basic() {
        let pre = Regex::gpt2();
        assert_eq!(pre.split_to_vec("Hello world"), vec!["Hello", " world"]);
    }

    #[test]
    fn test_gpt2_contractions() {
        // Contraction suffixes split off as their own tokens.
        let pre = Regex::gpt2();
        assert_eq!(
            pre.split_to_vec("How's it going?"),
            vec!["How", "'s", " it", " going", "?"]
        );
    }

    #[test]
    fn test_o200k_camelcase() {
        let pre = Regex::o200k();
        let cases: &[(&str, &[&str])] = &[
            ("CamelCase", &["Camel", "Case"]),
            ("JSONParser", &["JSONParser"]),
            ("parseJSON", &["parse", "JSON"]),
            ("XMLHttpRequest", &["XMLHttp", "Request"]),
            ("don't", &["don't"]),
            ("Hello world", &["Hello", " world"]),
        ];
        for (input, expected) in cases {
            assert_eq!(pre.split_to_vec(input), *expected);
        }
    }
}