use core::ops::Range;
use crate::{
alloc::{
string::String,
vec::Vec,
},
vocab::{
DEFAULT_BYTE_PER_TOKEN_RATIO,
SpecialFilter,
},
};
/// A classified sub-range of a source text.
///
/// Each variant wraps the `Range<usize>` of the span within the text it was
/// produced from; the ranges are used to slice a `&str`, so they are byte
/// offsets.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SpanRef {
/// A regular word span.
Word(Range<usize>),
/// A special span (presumably a special token selected by a `SpecialFilter`
/// -- confirm against the spanner implementations).
Special(Range<usize>),
/// A gap span; these are the spans dropped by `TextSpanner::remove_gaps`.
Gap(Range<usize>),
}
impl SpanRef {
    /// Borrows the range carried by this span, whichever variant it is.
    pub fn range(&self) -> &Range<usize> {
        match self {
            // Every variant wraps exactly one range, so a single arm covers them all
            // while still failing to compile if a new variant is added.
            SpanRef::Word(inner) | SpanRef::Special(inner) | SpanRef::Gap(inner) => inner,
        }
    }
}
impl From<SpanRef> for Range<usize> {
fn from(span: SpanRef) -> Self {
span.range().clone()
}
}
pub trait TextSpanner: Send + Sync {
fn expected_bytes_per_span(&self) -> f32 {
DEFAULT_BYTE_PER_TOKEN_RATIO
}
fn expected_span_count(
&self,
text: &str,
) -> usize {
text.len() / self.expected_bytes_per_span() as usize
}
fn for_each_split_span(
&self,
text: &str,
special_filter: Option<&SpecialFilter>,
f: &mut dyn FnMut(SpanRef) -> bool,
) -> (bool, usize);
fn split_spans(
&self,
text: &str,
special_filter: Option<&SpecialFilter>,
) -> Vec<SpanRef> {
let capacity = self.expected_span_count(text) * 115 / 100;
let mut words = Vec::with_capacity(capacity);
self.for_each_split_span(text, special_filter, &mut |span_ref| {
words.push(span_ref);
true
});
words
}
fn remove_gaps(
&self,
text: &str,
special_filter: Option<&SpecialFilter>,
) -> String {
self.split_spans(text, special_filter)
.into_iter()
.filter_map(|m| match m {
SpanRef::Gap(_) => None,
_ => Some(&text[Range::<usize>::from(m)]),
})
.collect()
}
fn batch_remove_gaps(
&self,
texts: &[&str],
special_filter: Option<&SpecialFilter>,
) -> Vec<String> {
texts
.iter()
.map(|t| self.remove_gaps(t, special_filter))
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alloc::{
        boxed::Box,
        sync::Arc,
    };

    // Compile-time checks: these only type-check if `dyn TextSpanner` can be
    // named behind `Box` and `Arc`.
    const _TEXT_SPANNER_BOX_CHECK: Option<Box<dyn TextSpanner>> = None;
    const _TEXT_SPANNER_ARC_CHECK: Option<Arc<dyn TextSpanner>> = None;

    /// Every variant exposes its range both by reference and by conversion.
    #[test]
    fn test_spanref() {
        let expected = 0..3;
        let variants = [
            SpanRef::Word(expected.clone()),
            SpanRef::Gap(expected.clone()),
            SpanRef::Special(expected.clone()),
        ];

        for span in variants {
            assert_eq!(span.range(), &expected);
            assert_eq!(Range::<usize>::from(span), expected);
        }
    }
}