use core::ops::Range;
use crate::{
alloc::sync::Arc,
spanners::{
SpanRef,
TextSpanner,
span_lexers::SpanLexer,
},
support::ranges::offset_range,
vocab::SpecialFilter,
};
/// A [`TextSpanner`] that splits text using two lexers.
///
/// Matches from `special_lexer` (when present and accepted by the caller's
/// special filter) become `SpanRef::Special` spans. The text between them is
/// handed to `word_lexer`: its matches become `SpanRef::Word` spans and the
/// unmatched stretches become `SpanRef::Gap` spans.
#[derive(Clone)]
pub struct LexerTextSpanner {
    /// Lexer whose matches are reported as word spans.
    word_lexer: Arc<dyn SpanLexer>,

    /// Optional lexer for special tokens; `None` means no special spans
    /// are ever produced.
    special_lexer: Option<Arc<dyn SpanLexer>>,
}
impl LexerTextSpanner {
    /// Creates a spanner from its word lexer and optional special-token lexer.
    ///
    /// # Arguments
    /// * `word_scanner` - lexer whose matches become `SpanRef::Word` spans.
    /// * `special_scanner` - optional lexer for special tokens; `None` means
    ///   no `SpanRef::Special` spans are produced.
    pub fn new(
        word_scanner: Arc<dyn SpanLexer>,
        special_scanner: Option<Arc<dyn SpanLexer>>,
    ) -> Self {
        Self {
            special_lexer: special_scanner,
            word_lexer: word_scanner,
        }
    }

    /// Finds the first special-token span in `text` accepted by `special_filter`.
    ///
    /// With no filter, every special match is accepted. The returned range is
    /// relative to `text`. Returns `None` when there is no special lexer or no
    /// accepted match.
    fn next_special_span(
        &self,
        text: &str,
        special_filter: Option<&SpecialFilter>,
    ) -> Option<Range<usize>> {
        let lexer = self.special_lexer.as_ref()?;
        let mut candidates = lexer.find_span_iter(text);
        match special_filter {
            None => candidates.next(),
            Some(filter) => candidates.find(|span| filter.contains(&text[span.clone()])),
        }
    }

    /// Reports the word and gap spans of `text` to `f`, shifted by `offset`.
    ///
    /// Stops as soon as `f` returns `false`. Returns `(completed, pos)`, where
    /// `pos` is relative to `text`: on an early stop it is the start of the
    /// span `f` rejected; otherwise it is the end of the processed text.
    fn for_each_word(
        &self,
        text: &str,
        offset: usize,
        f: &mut dyn FnMut(SpanRef) -> bool,
    ) -> (bool, usize) {
        // First byte of `text` that has not yet been reported to `f`.
        let mut cursor = 0;
        for Range {
            start: word_start,
            end: word_end,
        } in self.word_lexer.find_span_iter(text)
        {
            // Any bytes skipped by the word lexer become a gap span.
            if word_start > cursor {
                let gap = SpanRef::Gap(offset_range::<usize>(cursor..word_start, offset));
                if !f(gap) {
                    return (false, cursor);
                }
                cursor = word_start;
            }
            let word = SpanRef::Word(offset_range::<usize>(word_start..word_end, offset));
            if !f(word) {
                return (false, cursor);
            }
            cursor = word_end;
        }

        // Trailing bytes after the final word, if any, form one last gap.
        let tail = cursor..text.len();
        if tail.is_empty() {
            return (true, cursor);
        }
        if !f(SpanRef::Gap(offset_range::<usize>(tail, offset))) {
            return (false, cursor);
        }
        (true, text.len())
    }
}
impl TextSpanner for LexerTextSpanner {
    /// Walks `text`, reporting every word, gap, and special span to `f` in order.
    ///
    /// Special spans (subject to `special_filter`) are located first. The text
    /// before each one is split into words and gaps by the word lexer. Iteration
    /// stops as soon as `f` returns `false`.
    ///
    /// Returns `(completed, pos)`, where `pos` is a byte offset into `text`.
    /// It is the start of the span `f` rejected, or the end of the text when
    /// every span was accepted.
    fn for_each_split_span(
        &self,
        text: &str,
        special_filter: Option<&SpecialFilter>,
        f: &mut dyn FnMut(SpanRef) -> bool,
    ) -> (bool, usize) {
        // `current` is the unprocessed suffix of `text`; `offset` is where it starts.
        let mut current = text;
        let mut offset = 0;
        // NOTE(review): an empty special match would leave `current` unchanged
        // (once it occurs at position 0); this assumes the special lexer never
        // yields empty spans.
        while let Some(Range { start, end }) = self.next_special_span(current, special_filter) {
            let pre = &current[..start];
            let (cont, used) = self.for_each_word(pre, offset, f);
            if !cont {
                // `used` is relative to `pre`; shift it back into `text` coordinates.
                return (false, offset + used);
            }
            if !f(SpanRef::Special(offset_range::<usize>(start..end, offset))) {
                return (false, offset + start);
            }
            current = &current[end..];
            offset += end;
        }
        // The tail's position is relative to `current`. Shift it like the
        // early-exit paths above so callers always get a `text`-relative offset.
        let (cont, used) = self.for_each_word(current, offset, f);
        (cont, offset + used)
    }
}
// Unit tests for `LexerTextSpanner`.
//
// NOTE(review): several expected ranges below imply runs of more than one
// whitespace byte inside the string literals (e.g. the final `Gap(32..35)`
// in `test_special_filter` spans three bytes). Preserve the literals' exact
// whitespace when editing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TokenType,
        WCHashSet,
        alloc::{
            boxed::Box,
            vec,
            vec::Vec,
        },
        prelude::*,
        pretrained::openai::OA_CL100K_BASE_PATTERN,
        spanners::{
            SpanRef,
            TextSpanningConfig,
        },
    };

    // Compile-time checks that `LexerTextSpanner` can be named inside `Box`
    // and `Arc`.
    const _LEXER_SPANNER_BOX_CHECK: Option<Box<LexerTextSpanner>> = None;
    const _LEXER_SPANNER_ARC_CHECK: Option<Arc<LexerTextSpanner>> = None;

    /// Builds a `LexerTextSpanner` from a config by compiling its word pattern
    /// and, if present, its special-token pattern.
    ///
    /// Panics (via `unwrap`) if either pattern fails to compile.
    fn from_config<T: TokenType>(config: &TextSpanningConfig<T>) -> LexerTextSpanner {
        LexerTextSpanner::new(
            Arc::new(config.pattern().clone().compile().unwrap()),
            config
                .special_pattern()
                .map(|p| Arc::new(p.compile().unwrap()) as Arc<dyn SpanLexer>),
        )
    }

    /// Checks how `SpecialFilter` controls which special tokens are split out.
    #[test]
    fn test_special_filter() {
        use crate::spanners::text_spanner::SpanRef::*;
        type T = u32;
        let config: TextSpanningConfig<T> = TextSpanningConfig::from_pattern(r"\w+")
            .with_special_words([("<|FNORD|>", 4000), ("<|NORP|>", 4001)]);
        let spanner = from_config(&config);
        // Allow-list containing only `<|NORP|>`.
        let mut allowed = WCHashSet::default();
        allowed.insert("<|NORP|>".to_string());
        let allowed = SpecialFilter::Include(allowed);
        let source = "abc 1<|FNORD|> def <|NORP|> ghi ";
        // No filter: both special tokens are split out.
        assert_eq!(
            spanner.split_spans(source, None),
            vec![
                Word(0..3),
                Gap(3..4),
                Word(4..5),
                Special(5..14),
                Gap(14..15),
                Word(15..18),
                Gap(18..20),
                Special(20..28),
                Gap(28..29),
                Word(29..32),
                Gap(32..35),
            ]
        );
        // `<|FNORD|>` is not allowed, so it falls through to the word lexer:
        // `FNORD` becomes a word and the surrounding `<|` / `|>` become gaps.
        assert_eq!(
            spanner.split_spans(source, Some(&allowed)),
            vec![
                Word(0..3),
                Gap(3..4),
                Word(4..5),
                Gap(5..7),
                Word(7..12),
                Gap(12..15),
                Word(15..18),
                Gap(18..20),
                Special(20..28),
                Gap(28..29),
                Word(29..32),
                Gap(32..35),
            ]
        );
        // `IncludeNone`: no special token is split out.
        assert_eq!(
            spanner.split_spans(source, Some(&SpecialFilter::IncludeNone)),
            vec![
                Word(0..3),
                Gap(3..4),
                Word(4..5),
                Gap(5..7),
                Word(7..12),
                Gap(12..15),
                Word(15..18),
                Gap(18..22),
                Word(22..26),
                Gap(26..29),
                Word(29..32),
                Gap(32..35),
            ]
        );
    }

    /// Checks the callback API, including early termination when the callback
    /// returns `false`.
    #[test]
    fn test_for_each_split_span() {
        use crate::spanners::text_spanner::SpanRef::*;
        type T = u32;
        let config: TextSpanningConfig<T> = TextSpanningConfig::from_pattern(r"\w+")
            .with_special_words([("<|FNORD|>", 4000), ("<|NORP|>", 4001)]);
        let spanner = from_config(&config);
        let source = "abc 1<|FNORD|> def <|NORP|> ghi ";
        // Accept everything: the full span sequence is reported.
        let mut spans: Vec<SpanRef> = Vec::new();
        spanner.for_each_split_span(source, None, &mut |span_ref| {
            spans.push(span_ref);
            true
        });
        assert_eq!(
            spans,
            vec![
                Word(0..3),
                Gap(3..4),
                Word(4..5),
                Special(5..14),
                Gap(14..15),
                Word(15..18),
                Gap(18..20),
                Special(20..28),
                Gap(28..29),
                Word(29..32),
                Gap(32..35),
            ]
        );
        // Reject the first word: only the leading gap is recorded.
        let mut spans: Vec<SpanRef> = Vec::new();
        spanner.for_each_split_span(" abc", None, &mut |span_ref| match span_ref {
            Word(_) => false,
            _ => {
                spans.push(span_ref);
                true
            }
        });
        assert_eq!(spans, vec![Gap(0..3)]);
        // Reject the special token: words and gaps before it are recorded.
        let mut spans: Vec<SpanRef> = Vec::new();
        spanner.for_each_split_span("abc def<|FNORD|>", None, &mut |span_ref| match span_ref {
            Special(_) => false,
            _ => {
                spans.push(span_ref);
                true
            }
        });
        assert_eq!(spans, vec![Word(0..3), Gap(3..6), Word(6..9)]);
        // Reject the first interior gap: iteration stops after the first word.
        let mut spans: Vec<SpanRef> = Vec::new();
        spanner.for_each_split_span("abc def", None, &mut |span_ref| match span_ref {
            Gap(_) => false,
            _ => {
                spans.push(span_ref);
                true
            }
        });
        assert_eq!(spans, vec![Word(0..3)]);
        // Reject the trailing gap: only the word before it is recorded.
        let mut spans: Vec<SpanRef> = Vec::new();
        spanner.for_each_split_span("foo ", None, &mut |span_ref| match span_ref {
            Gap(_) => false,
            _ => {
                spans.push(span_ref);
                true
            }
        });
        assert_eq!(spans, vec![Word(0..3)]);
    }

    /// Splits text containing special tokens using the cl100k word pattern.
    #[test]
    fn test_split_words() {
        type T = u32;
        let config: TextSpanningConfig<T> =
            TextSpanningConfig::from_pattern(OA_CL100K_BASE_PATTERN)
                .with_special_words([("<|FNORD|>", 4000), ("<|NORP|>", 4001)]);
        let spanner = from_config(&config);
        let buf = "hello<|FNORD|> wor<|NORP|>ld!";
        assert_eq!(
            &spanner.split_spans(buf, None),
            &vec![
                SpanRef::Word(0..5),
                SpanRef::Special(5..14),
                SpanRef::Word(14..18),
                SpanRef::Special(18..26),
                SpanRef::Word(26..28),
                SpanRef::Word(28..buf.len()),
            ]
        );
    }

    /// `batch_remove_gaps` keeps only the non-gap text of each input.
    #[test]
    fn test_rewrite() {
        type T = u32;
        let config: TextSpanningConfig<T> = TextSpanningConfig::from_pattern(r"\w+");
        let spanner = from_config(&config);
        let buf = vec!["hello world!", "abc def"];
        assert_eq!(
            spanner.batch_remove_gaps(&buf, None),
            vec!["helloworld", "abcdef"]
        );
    }
}