use core::num::NonZeroUsize;
use crate::{
TokenType,
UnifiedTokenVocab,
alloc::sync::Arc,
spanners::{
TextSpanner,
TextSpanningConfig,
span_lexers::{
LexerTextSpanner,
SpanLexer,
build_regex_lexer,
},
},
};
/// Builder that assembles a [`TextSpanner`] from a [`TextSpanningConfig`].
///
/// The builder stores the spanning configuration together with a handful of
/// switches that are forwarded to `build_regex_lexer` when [`Self::build`]
/// is called.
#[derive(Clone, PartialEq)]
pub struct TextSpannerBuilder<T: TokenType> {
/// Spanning configuration supplying the word pattern and the specials.
config: TextSpanningConfig<T>,
/// Forwarded to `build_regex_lexer` for the word lexer only; the special
/// lexer is always built with this switch off.
enable_accelerated_lexers: bool,
/// Forwarded to `build_regex_lexer` for both the word and special lexers.
enable_regex_automata: bool,
/// Forwarded to `build_regex_lexer` for both the word and special lexers.
concurrent: bool,
/// Optional pool-size limit forwarded to `build_regex_lexer`; `None` until
/// set explicitly.
max_pool: Option<NonZeroUsize>,
}
impl<T: TokenType> From<TextSpanningConfig<T>> for TextSpannerBuilder<T> {
fn from(config: TextSpanningConfig<T>) -> Self {
Self::new(config)
}
}
/// Creates a builder from anything that can be borrowed as a
/// [`UnifiedTokenVocab`].
impl<T: TokenType, V: AsRef<UnifiedTokenVocab<T>>> From<V> for TextSpannerBuilder<T> {
    /// Equivalent to [`TextSpannerBuilder::from_vocab`] on the borrowed vocab.
    fn from(vocab: V) -> Self {
        let borrowed: &UnifiedTokenVocab<T> = vocab.as_ref();
        Self::from_vocab(borrowed)
    }
}
impl<T: TokenType> TextSpannerBuilder<T> {
    /// Builds a [`TextSpanner`] for `vocab` using the builder's default
    /// switches.
    ///
    /// Shorthand for `Self::from_vocab(vocab).build()`.
    pub fn default(vocab: &UnifiedTokenVocab<T>) -> Arc<dyn TextSpanner> {
        Self::from_vocab(vocab).build()
    }

    /// Creates a builder from a clone of the vocab's spanning configuration.
    pub fn from_vocab(vocab: &UnifiedTokenVocab<T>) -> Self {
        Self::new(vocab.spanning().clone())
    }

    /// Creates a builder for `config`.
    ///
    /// Accelerated lexers, regex automata and concurrency start enabled, and
    /// no maximum pool size is set.
    pub fn new(config: TextSpanningConfig<T>) -> Self {
        Self {
            config,
            enable_accelerated_lexers: true,
            enable_regex_automata: true,
            concurrent: true,
            max_pool: None,
        }
    }

    /// Returns the spanning configuration held by the builder.
    pub fn config(&self) -> &TextSpanningConfig<T> {
        &self.config
    }

    /// Returns whether accelerated lexers are enabled.
    pub fn accelerated_lexers(&self) -> bool {
        self.enable_accelerated_lexers
    }

    /// Enables or disables accelerated lexers in place.
    pub fn set_accelerated_lexers(
        &mut self,
        enable: bool,
    ) {
        self.enable_accelerated_lexers = enable;
    }

    /// Chaining form of [`Self::set_accelerated_lexers`].
    pub fn with_accelerated_lexers(
        mut self,
        enable: bool,
    ) -> Self {
        self.set_accelerated_lexers(enable);
        self
    }

    /// Returns whether regex automata are enabled.
    pub fn regex_automata(&self) -> bool {
        self.enable_regex_automata
    }

    /// Enables or disables regex automata in place.
    pub fn set_regex_automata(
        &mut self,
        enable: bool,
    ) {
        self.enable_regex_automata = enable;
    }

    /// Chaining form of [`Self::set_regex_automata`].
    pub fn with_regex_automata(
        mut self,
        enable: bool,
    ) -> Self {
        self.set_regex_automata(enable);
        self
    }

    /// Returns the value of the `concurrent` switch.
    pub fn concurrent(&self) -> bool {
        self.concurrent
    }

    /// Sets the `concurrent` switch in place.
    pub fn set_concurrent(
        &mut self,
        concurrent: bool,
    ) {
        self.concurrent = concurrent;
    }

    /// Chaining form of [`Self::set_concurrent`].
    pub fn with_concurrent(
        mut self,
        concurrent: bool,
    ) -> Self {
        self.set_concurrent(concurrent);
        self
    }

    /// Returns the configured maximum pool size, if one has been set.
    pub fn max_pool(&self) -> Option<NonZeroUsize> {
        self.max_pool
    }

    /// Sets the maximum pool size in place.
    pub fn set_max_pool(
        &mut self,
        max_pool: NonZeroUsize,
    ) {
        self.max_pool = Some(max_pool);
    }

    /// Chaining form of [`Self::set_max_pool`].
    pub fn with_max_pool(
        mut self,
        max_pool: NonZeroUsize,
    ) -> Self {
        self.set_max_pool(max_pool);
        self
    }

    /// Builds a [`TextSpanner`] from the current settings.
    ///
    /// A word lexer is always built from the configured pattern. A special
    /// lexer is built only when the configured specials yield a pattern.
    /// The builder itself is left untouched and can be reused.
    pub fn build(&self) -> Arc<dyn TextSpanner> {
        let word_lexer: Arc<dyn SpanLexer> = build_regex_lexer(
            self.config.pattern().clone(),
            self.enable_accelerated_lexers,
            self.enable_regex_automata,
            self.concurrent,
            self.max_pool,
        );

        let special_lexer: Option<Arc<dyn SpanLexer>> =
            self.config.specials().special_pattern().map(|pattern| {
                build_regex_lexer(
                    pattern,
                    // Accelerated lexers are never requested for specials,
                    // regardless of the builder switch.
                    // NOTE(review): presumably acceleration only targets
                    // known word patterns; confirm against build_regex_lexer.
                    false,
                    self.enable_regex_automata,
                    self.concurrent,
                    self.max_pool,
                )
            });

        Arc::new(LexerTextSpanner::new(word_lexer, special_lexer))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pretrained::openai::OA_GPT2_PATTERN;

    /// Checks the builder defaults, the chaining `with_*` methods, and the
    /// in-place `set_*` mutators.
    #[test]
    fn test_builder() {
        type T = u32;

        let config: TextSpanningConfig<T> =
            TextSpanningConfig::<u32>::from_pattern(OA_GPT2_PATTERN);

        // A freshly converted builder keeps the config and uses the defaults.
        let builder: TextSpannerBuilder<T> = config.clone().into();
        assert_eq!(builder.config(), &config);
        assert!(builder.accelerated_lexers());
        assert!(builder.regex_automata());
        assert!(builder.concurrent());
        assert_eq!(builder.max_pool(), None);

        // Each chaining method changes exactly the switch it is named after.
        let one = NonZeroUsize::new(1).unwrap();
        let builder = builder
            .with_accelerated_lexers(false)
            .with_regex_automata(false)
            .with_concurrent(false)
            .with_max_pool(one);
        assert!(!builder.accelerated_lexers());
        assert!(!builder.regex_automata());
        assert!(!builder.concurrent());
        assert_eq!(builder.max_pool(), Some(one));

        // The in-place setters mirror the chaining methods.
        let two = NonZeroUsize::new(2).unwrap();
        let mut builder = builder;
        builder.set_accelerated_lexers(true);
        builder.set_regex_automata(true);
        builder.set_concurrent(true);
        builder.set_max_pool(two);
        assert!(builder.accelerated_lexers());
        assert!(builder.regex_automata());
        assert!(builder.concurrent());
        assert_eq!(builder.max_pool(), Some(two));
    }
}