use crate::{
TokenType,
WCResult,
support::regex::RegexPattern,
vocab::SpecialVocab,
};
/// Configuration that pairs a [`RegexPattern`] with a [`SpecialVocab`].
///
/// Both fields are private; read them through
/// [`TextSpanningConfig::pattern`] and [`TextSpanningConfig::specials`],
/// and replace them through the `with_*` builder methods.
#[derive(Debug, Clone, PartialEq)]
pub struct TextSpanningConfig<T>
where
    T: TokenType,
{
    /// The regex pattern held by this config.
    pattern: RegexPattern,

    /// The special vocabulary held by this config, keyed to token type `T`.
    specials: SpecialVocab<T>,
}
/// Converts a bare [`RegexPattern`] into a config by delegating to
/// [`TextSpanningConfig::from_pattern`].
impl<T> From<RegexPattern> for TextSpanningConfig<T>
where
    T: TokenType,
{
    fn from(value: RegexPattern) -> Self {
        Self::from_pattern(value)
    }
}
impl<T> TextSpanningConfig<T>
where
    T: TokenType,
{
    /// Creates a config from `pattern`, using `SpecialVocab::default()` for
    /// the special vocabulary.
    ///
    /// ## Arguments
    /// * `pattern` - anything convertible into a [`RegexPattern`].
    pub fn from_pattern<P>(pattern: P) -> Self
    where
        P: Into<RegexPattern>,
    {
        let pattern: RegexPattern = pattern.into();
        let specials = SpecialVocab::default();
        Self { pattern, specials }
    }

    /// Consumes the config and returns it with its pattern replaced by
    /// `pattern`; the special vocabulary is carried over unchanged.
    ///
    /// ## Arguments
    /// * `pattern` - anything convertible into a [`RegexPattern`].
    pub fn with_pattern<P>(
        mut self,
        pattern: P,
    ) -> Self
    where
        P: Into<RegexPattern>,
    {
        self.pattern = pattern.into();
        self
    }

    /// Consumes the config and returns it with its special vocabulary
    /// replaced by `specials`; the pattern is carried over unchanged.
    ///
    /// ## Arguments
    /// * `specials` - anything convertible into a [`SpecialVocab<T>`].
    pub fn with_specials<S>(
        mut self,
        specials: S,
    ) -> Self
    where
        S: Into<SpecialVocab<T>>,
    {
        self.specials = specials.into();
        self
    }

    /// Consumes the config and returns it with its special vocabulary
    /// replaced by the result of `SpecialVocab::with_special_words` applied
    /// to the current vocabulary; the pattern is carried over unchanged.
    ///
    /// ## Arguments
    /// * `special_words` - `(word, token)` pairs forwarded as-is to
    ///   `SpecialVocab::with_special_words`.
    pub fn with_special_words<W, S>(
        self,
        special_words: W,
    ) -> Self
    where
        W: IntoIterator<Item = (S, T)>,
        S: AsRef<str>,
    {
        // Take the config apart so the vocabulary can be rebuilt on its own
        // while the pattern is moved across untouched.
        let Self { pattern, specials } = self;
        Self {
            pattern,
            specials: specials.with_special_words(special_words),
        }
    }

    /// Produces a copy of this config over a different token type `G`.
    ///
    /// The pattern is cloned; the special vocabulary is converted through
    /// `SpecialVocab::to_token_type`.
    ///
    /// ## Errors
    /// Propagates any error returned by `SpecialVocab::to_token_type`.
    pub fn to_token_type<G: TokenType>(&self) -> WCResult<TextSpanningConfig<G>> {
        let specials: SpecialVocab<G> = self.specials.to_token_type()?;
        let pattern = self.pattern.clone();
        Ok(TextSpanningConfig { pattern, specials })
    }

    /// Borrows the regex pattern.
    pub fn pattern(&self) -> &RegexPattern {
        &self.pattern
    }

    /// Borrows the special vocabulary.
    pub fn specials(&self) -> &SpecialVocab<T> {
        &self.specials
    }

    /// Returns whatever `SpecialVocab::special_pattern` yields for the
    /// special vocabulary (an optional [`RegexPattern`]).
    pub fn special_pattern(&self) -> Option<RegexPattern> {
        self.specials.special_pattern()
    }

    /// Mutably borrows the special vocabulary for in-place edits.
    pub fn specials_mut(&mut self) -> &mut SpecialVocab<T> {
        &mut self.specials
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        alloc::string::ToString,
        vocab::SpecialVocab,
    };

    /// Walks a config through construction from a pattern, in-place
    /// vocabulary mutation, and the `with_pattern` / `with_specials`
    /// builders, checking the accessors after each step.
    #[test]
    fn test_from_pattern() {
        type T = u32;

        // A config converted from a bare pattern starts with no specials.
        let initial = RegexPattern::Adaptive("hello".to_string());
        let mut config = TextSpanningConfig::<T>::from(initial);
        assert_eq!(config.pattern().as_str(), "hello");
        assert_eq!(config.specials().len(), 0);

        // Mutating through `specials_mut` is visible through `specials`.
        config.specials_mut().add_str_word("hello", 1);
        assert_eq!(config.specials().len(), 1);

        // `with_pattern` accepts a `&str` and swaps the stored pattern.
        let config = config.with_pattern("hi");
        let expected_pattern = RegexPattern::Adaptive("hi".to_string());
        assert_eq!(config.pattern(), &expected_pattern);

        // `with_specials` replaces the vocabulary wholesale.
        let mut replacement: SpecialVocab<T> = SpecialVocab::default();
        replacement.add_str_word("apple", 1);
        replacement.add_str_word("pear", 1);

        let config = config.with_specials(replacement.clone());
        assert_eq!(config.specials(), &replacement);
    }
}