use logos::Logos;
use super::gpt2_family::{
Gpt2FamilyLogos,
Gpt2FamilyTokenRole,
};
use crate::pretrained::openai::OA_O200K_BASE_PATTERN;
/// Logos token set approximating OpenAI's `o200k_base` pre-tokenization split.
///
/// NOTE(review): presumably each variant re-expresses an alternative of
/// `OA_O200K_BASE_PATTERN` (imported above and passed to `logos_lexer!` below)
/// in a form a longest-match lexer can handle — confirm against that constant.
///
/// All four word variants end with the same optional contraction suffix
/// `(?i:'(?:s|t|d|m|re|ve|ll))?`: a case-insensitive `'s`, `'t`, `'d`, `'m`,
/// `'re`, `'ve` or `'ll` kept attached to the word.
///
/// Character classes used below:
/// - "upper" class: `Lu`, `Lt`, `Lm`, `Lo`, `M`
/// - "lower" class: `Ll`, `Lm`, `Lo`, `M`
///
/// `Lm`, `Lo` and `M` are members of both classes.
#[derive(Logos, Debug, PartialEq, Clone)]
pub(crate) enum O200kToken {
    /// Unprefixed word: zero or more upper-class chars followed by at least one
    /// lower-class char, plus the optional contraction (e.g. `Camel`, `don't`).
    /// Because `Lo`/`Lm`/`M` are lower-class too, this also matches caseless
    /// scripts and input that starts with a combining mark.
    #[regex(
        r"[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'(?:s|t|d|m|re|ve|ll))?",
        priority = 4
    )]
    WordLower,
    /// `WordLower` preceded by exactly one character that is not CR, LF, a
    /// letter, a number, or a mark (e.g. a space or punctuation: ` she's`).
    /// Marks are excluded from the prefix position; a leading mark is instead
    /// consumed by the lower class of the word body.
    #[regex(
        r"[^\r\n\p{Letter}\p{Number}\p{Mark}][\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'(?:s|t|d|m|re|ve|ll))?",
        priority = 3
    )]
    PrefixedWordLower,
    /// Word that must start with one or more `Lu`/`Lt` chars, followed by zero
    /// or more lower-class chars and the optional contraction. This covers
    /// all-caps runs such as `ABC` or `I'll`, which `WordLower` cannot match
    /// because it requires at least one lower-class char.
    #[regex(
        r"[\p{Lu}\p{Lt}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'(?:s|t|d|m|re|ve|ll))?",
        priority = 2
    )]
    WordUpper,
    /// `WordUpper` preceded by the same single-character prefix as
    /// `PrefixedWordLower` (not CR, LF, letter, number, or mark).
    #[regex(
        r"[^\r\n\p{Letter}\p{Number}\p{Mark}][\p{Lu}\p{Lt}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'(?:s|t|d|m|re|ve|ll))?",
        priority = 1
    )]
    PrefixedWordUpper,
    /// One to three `\p{Number}` characters. Under longest-match lexing, longer
    /// numeric runs are emitted as consecutive chunks of at most three.
    #[regex(r"\p{Number}{1,3}")]
    Digits,
    /// A literal space followed by a punctuation/symbol run.
    ///
    /// - The first char after the space must not be whitespace, a letter, a
    ///   number, or a mark.
    /// - Later chars may be marks, but not whitespace, letters, or numbers.
    /// - Trailing `\r`, `\n` and `/` chars are absorbed.
    #[regex(r" [^\s\p{Letter}\p{Number}\p{Mark}][^\s\p{Letter}\p{Number}]*[\r\n/]*")]
    PunctuationSpaced,
    /// Punctuation/symbol run with no leading space.
    ///
    /// The optional continuation must itself begin with a non-mark char, so a
    /// mark can never be the second character of this token. A punctuation
    /// char immediately followed by a mark (e.g. `!\u{0300}`) is therefore left
    /// to `PrefixedWordLower`, whose lower class includes `M`. Trailing `\r`,
    /// `\n` and `/` chars are absorbed.
    #[regex(r"[^\s\p{Letter}\p{Number}\p{Mark}](?:[^\s\p{Letter}\p{Number}\p{Mark}][^\s\p{Letter}\p{Number}]*)?[\r\n/]*")]
    PunctuationBare,
    /// Any whitespace run (possibly empty before the break) that ends in one or
    /// more `\r`/`\n` characters.
    #[regex(r"\s*[\r\n]+")]
    Newline,
    /// One or more whitespace characters other than `\r` and `\n`.
    #[regex(r"[^\S\r\n]+")]
    Whitespace,
}
impl Gpt2FamilyLogos<'_> for O200kToken {
    /// Classifies this token for the shared GPT-2-family span handling.
    ///
    /// - Every word variant maps to [`Gpt2FamilyTokenRole::Word`] with
    ///   `check_contraction: false`. Each word regex on [`O200kToken`] already
    ///   ends with its own optional contraction suffix.
    /// - `first_char_is_letter` is `false` exactly for the `Prefixed*`
    ///   variants, whose first char is a one-character non-letter prefix.
    ///   NOTE(review): the unprefixed variants may still begin with a
    ///   combining mark (`\p{M}` is in their classes); confirm that is the
    ///   intended meaning of this flag.
    /// - Both punctuation variants map to `Punctuation`.
    /// - Digits and newlines are `Standalone`.
    /// - `Whitespace` maps to `Whitespace`.
    fn family_role(&self) -> Gpt2FamilyTokenRole {
        match self {
            Self::WordLower
            | Self::WordUpper
            | Self::PrefixedWordLower
            | Self::PrefixedWordUpper => Gpt2FamilyTokenRole::Word {
                first_char_is_letter: !matches!(
                    self,
                    Self::PrefixedWordLower | Self::PrefixedWordUpper
                ),
                // Contractions are matched inside the word patterns themselves.
                check_contraction: false,
            },
            Self::PunctuationBare | Self::PunctuationSpaced => Gpt2FamilyTokenRole::Punctuation,
            Self::Newline | Self::Digits => Gpt2FamilyTokenRole::Standalone,
            Self::Whitespace => Gpt2FamilyTokenRole::Whitespace,
        }
    }
}
// Declares `O200kLexer` (used as a `SpanLexer` by the tests below) and wires it
// to the `O200kToken` token set.
// NOTE(review): `logos_lexer!` is defined elsewhere. Presumably `pattern`
// records the reference regex (`OA_O200K_BASE_PATTERN`) that `O200kToken`
// approximates — confirm in the macro definition.
logos_lexer! {
    pub struct O200kLexer;
    token = O200kToken;
    pattern = OA_O200K_BASE_PATTERN;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::{
            sync::Arc,
            vec,
            vec::Vec,
        },
        spanners::{
            SpanRef,
            TextSpanner,
            span_lexers::{
                LexerTextSpanner,
                SpanLexer,
            },
        },
    };

    /// Wraps `lexer` in a `LexerTextSpanner` (with `None` as the constructor's
    /// second argument) so tests can call `split_spans` on it.
    fn spanner(lexer: impl SpanLexer + 'static) -> LexerTextSpanner {
        LexerTextSpanner::new(Arc::new(lexer), None)
    }

    /// Contraction suffixes stay attached to their word, both for an
    /// unprefixed word (`don't`) and for one carrying a leading space (` she's`).
    #[test]
    fn test_o200k_contractions_attached() {
        let s = spanner(O200kLexer);
        let text = "don't I'll she's";
        let spans = s.split_spans(text, None);
        // Keep only the text of `Word` spans.
        let words: Vec<&str> = spans
            .iter()
            .filter_map(|s| match s {
                SpanRef::Word(r) => Some(&text[r.clone()]),
                _ => None,
            })
            .collect();
        assert!(
            words.contains(&"don't"),
            "expected \"don't\" as one token, got: {:?}",
            words
        );
        assert!(
            words.contains(&" she's"),
            "expected \" she's\" as one token, got: {:?}",
            words
        );
    }

    /// Runs the shared logos-lexer test suite against `O200kLexer`.
    #[test]
    fn test_o200k_common() {
        crate::spanners::span_lexers::logos::testutil::common_lexer_tests(
            crate::alloc::boxed::Box::new(O200kLexer),
        );
    }

    /// Checks how camel-case identifiers are split into words:
    /// - `CamelCase` and `getElementById` split before each capitalized word.
    /// - `HTMLParser` (an all-caps run followed by a capitalized word) is
    ///   expected to stay a single word.
    #[test]
    fn test_o200k_camel_case() {
        let s = spanner(O200kLexer);
        let spans = s.split_spans("CamelCase", None);
        let words: Vec<&str> = spans
            .iter()
            .filter_map(|s| match s {
                SpanRef::Word(r) => Some(&"CamelCase"[r.clone()]),
                _ => None,
            })
            .collect();
        assert_eq!(words, &["Camel", "Case"]);
        let text = "getElementById";
        let spans = s.split_spans(text, None);
        let words: Vec<&str> = spans
            .iter()
            .filter_map(|s| match s {
                SpanRef::Word(r) => Some(&text[r.clone()]),
                _ => None,
            })
            .collect();
        assert_eq!(words, &["get", "Element", "By", "Id"]);
        let text = "HTMLParser";
        let spans = s.split_spans(text, None);
        let words: Vec<&str> = spans
            .iter()
            .filter_map(|s| match s {
                SpanRef::Word(r) => Some(&text[r.clone()]),
                _ => None,
            })
            .collect();
        assert_eq!(words, &["HTMLParser"]);
    }

    /// Pins how a combining mark (U+0300) after punctuation is grouped into
    /// byte-range spans.
    ///
    /// NOTE(review): in every case below, the last expected range ends one byte
    /// past the input's UTF-8 length. U+0300 encodes as 2 bytes, so:
    /// - `" !\u{0300}a"` and `" !\u{0300}\r"` are 5 bytes, but the expected
    ///   spans end at 6.
    /// - `" !\u{0300}!\u{0300}"` is 7 bytes, but the expected spans end at 8.
    /// - `" !\u{0300}!A"` and `" !\u{0300} A"` are 6 bytes, but the expected
    ///   spans end at 7.
    ///
    /// Either the inputs or the expected ranges look off by one — confirm
    /// against the lexer's actual output.
    #[test]
    fn test_mark_groups_with_punctuation() {
        let s = spanner(O200kLexer);
        assert_eq!(
            s.split_spans(" !\u{0300}a", None),
            vec![
                SpanRef::Word(0..1),
                SpanRef::Word(1..5),
                SpanRef::Word(5..6),
            ]
        );
        assert_eq!(
            s.split_spans(" !\u{0300}\r", None),
            vec![
                SpanRef::Word(0..1),
                SpanRef::Word(1..6),
            ]
        );
        assert_eq!(
            s.split_spans(" !\u{0300}!\u{0300}", None),
            vec![
                SpanRef::Word(0..1),
                SpanRef::Word(1..8),
            ]
        );
        assert_eq!(
            s.split_spans(" !\u{0300}!A", None),
            vec![
                SpanRef::Word(0..1),
                SpanRef::Word(1..6),
                SpanRef::Word(6..7),
            ]
        );
        assert_eq!(
            s.split_spans(" !\u{0300} A", None),
            vec![
                SpanRef::Word(0..1),
                SpanRef::Word(1..5),
                SpanRef::Word(5..7),
            ]
        );
    }
}