use core::ops::Range;
use alloc::borrow::Cow;
use alloc::string::String;
use alloc::vec::Vec;
#[cfg(feature = "serialization")]
use serde::{Deserialize, Serialize};
mod decoding;
mod normalization;
mod processing;
mod split;
pub use decoding::*;
pub use normalization::*;
pub use processing::*;
pub use split::*;
use crate::TokenId;
/// A fallback strategy; [`Configuration::fallback`] holds a `Vec` of these.
///
/// NOTE(review): nothing in this chunk interprets the variants; the per-variant notes below
/// are inferred from the names only — confirm against the code that consumes them.
///
/// `#[repr(u8)]` with implicit discriminants (`Skip = 0`, `Unknown = 1`, `Bytes = 2`), so
/// reordering the variants changes their numeric values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[cfg_attr(feature = "serialization", derive(Deserialize, Serialize))]
pub enum Fallback {
/// Presumably: drop input that has no matching token — confirm.
Skip,
/// Presumably: emit an "unknown" token — confirm.
Unknown,
/// Presumably: fall back to byte-level tokens — confirm.
Bytes,
}
/// The position a [`Template`] is associated with (see [`Template::position`]).
///
/// NOTE(review): the insertion logic is not in this chunk. The word / sequence / sub-sequence
/// granularity and start / continuation / end placement are read off the variant names —
/// confirm against the code that consumes templates.
///
/// `#[repr(u8)]` with implicit discriminants in declaration order (`WordStart = 0` through
/// `SubSequenceEnd = 8`); reordering the variants changes their numeric values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
#[cfg_attr(feature = "serialization", derive(Deserialize, Serialize))]
pub enum InsertionPosition {
// Word-level positions.
WordStart,
WordContinuation,
WordEnd,
// Sequence-level positions.
SequenceStart,
SequenceContinuation,
SequenceEnd,
// Sub-sequence-level positions.
SubSequenceStart,
SubSequenceContinuation,
SubSequenceEnd,
}
/// Pairs template text (`content`) with the [`InsertionPosition`] it applies to.
///
/// NOTE(review): how `content` is turned into tokens is not shown here — presumably it is
/// special-token text; confirm with the consumer of [`Configuration::templates`].
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serialization", derive(Deserialize, Serialize))]
pub struct Template {
/// The template text.
pub content: String,
/// The position `content` is associated with.
pub position: InsertionPosition,
}
/// Error returned by [`Configuration::validate()`].
///
/// Marked `#[non_exhaustive]`: downstream `match`es must include a wildcard arm so that new
/// variants can be added without a breaking change.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum ConfigurationError {
/// The configuration uses a component whose Cargo feature is not compiled in. The payload is
/// the feature name, e.g. `"normalization-unicode"`.
#[error("required feature not enabled: {0}")]
FeatureDisabled(String),
}
/// Pipeline configuration: ordered lists of the stages run by the methods on this type.
///
/// `Default` yields every list empty. With empty lists, `normalize`, `process` and `decode`
/// do nothing, and `split` returns the whole input as one range.
///
/// Call [`Configuration::validate()`] to check that every referenced component is compiled in.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serialization", derive(Deserialize, Serialize))]
pub struct Configuration {
/// Fallback strategies (see [`Fallback`]).
pub fallback: Vec<Fallback>,
/// Normalization steps, run in order by [`Configuration::normalize()`].
pub normalization: Vec<Normalization>,
/// Split steps. [`Configuration::split()`] runs each one over the ranges produced by the
/// previous one.
pub split: Vec<Split>,
/// Processing steps, run in order over token ids by [`Configuration::process()`].
pub processing: Vec<Processing>,
/// Decoding steps, run in order over a byte buffer by [`Configuration::decode()`].
pub decoding: Vec<Decoding>,
/// Templates (see [`Template`]).
pub templates: Vec<Template>,
}
impl Configuration {
    /// Checks that every component referenced by this configuration is compiled in.
    ///
    /// Each check exists only when the corresponding Cargo feature is *disabled*. With all of
    /// these features enabled, this always returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigurationError::FeatureDisabled`] carrying the name of the first missing
    /// feature. The checks run in this order: `normalization-unicode`,
    /// `normalization-charsmap`, `split-unicode-script`.
    #[inline(never)]
    pub fn validate(&self) -> Result<(), ConfigurationError> {
        #[cfg(not(feature = "normalization-unicode"))]
        {
            let needs_unicode = self
                .normalization
                .iter()
                .any(|step| matches!(step, Normalization::Unicode { .. }));
            if needs_unicode {
                return Err(ConfigurationError::FeatureDisabled(String::from(
                    "normalization-unicode",
                )));
            }
        }

        #[cfg(not(feature = "normalization-charsmap"))]
        {
            let needs_charsmap = self
                .normalization
                .iter()
                .any(|step| matches!(step, Normalization::CharsMap { .. }));
            if needs_charsmap {
                return Err(ConfigurationError::FeatureDisabled(String::from(
                    "normalization-charsmap",
                )));
            }
        }

        #[cfg(not(feature = "split-unicode-script"))]
        {
            let needs_script = self
                .split
                .iter()
                .any(|step| matches!(step, Split::UnicodeScript));
            if needs_script {
                return Err(ConfigurationError::FeatureDisabled(String::from(
                    "split-unicode-script",
                )));
            }
        }

        Ok(())
    }

    /// Runs every normalization step over `text`, in declaration order.
    ///
    /// Each step receives `text` mutably, together with its own clone of `position`.
    /// If `text` is empty on entry, nothing runs. The emptiness check is made once, up front,
    /// not between steps.
    #[inline(never)]
    pub fn normalize(&self, text: &mut Cow<str>, position: Range<usize>) {
        if !text.is_empty() {
            for step in self.normalization.iter() {
                step.normalize(text, position.clone());
            }
        }
    }

    /// Splits `text` into `(start, end)` byte ranges by running the configured split steps.
    ///
    /// - Empty `text` yields no ranges.
    /// - With no split steps, the whole of `text` is a single range.
    /// - With exactly one step, its result is returned as-is.
    /// - With several steps, each one runs on every range produced by the previous step. The
    ///   resulting sub-ranges are shifted back into `text`'s coordinates.
    ///
    /// # Panics
    ///
    /// With two or more steps, this panics if a step returns a range that is out of bounds or
    /// not on a `char` boundary of its input, because that range is used to slice `text`.
    #[inline(never)]
    pub fn split(&self, text: &str) -> Vec<(usize, usize)> {
        if text.is_empty() {
            return Vec::new();
        }
        match self.split.as_slice() {
            [] => Vec::from([(0, text.len())]),
            [only] => only.split(text),
            steps => {
                let mut ranges = Vec::from([(0, text.len())]);
                for step in steps {
                    let mut refined = Vec::new();
                    for &(offset, stop) in &ranges {
                        // Sub-ranges are relative to the slice; rebase them onto `text`.
                        let pieces = step.split(&text[offset..stop]);
                        refined.extend(
                            pieces
                                .into_iter()
                                .map(|(from, to)| (from + offset, to + offset)),
                        );
                    }
                    ranges = refined;
                }
                ranges
            }
        }
    }

    /// Runs every processing step over `tokens`, in declaration order.
    ///
    /// This is a no-op when `tokens` is empty on entry.
    #[inline(never)]
    pub fn process(&self, tokens: &mut Vec<TokenId>) {
        if !tokens.is_empty() {
            for step in self.processing.iter() {
                step.process(tokens);
            }
        }
    }

    /// Runs every decoding step over the byte buffer `tokens`, in declaration order.
    ///
    /// This is a no-op when the buffer is empty on entry.
    #[inline(never)]
    pub fn decode(&self, tokens: &mut Vec<u8>) {
        if !tokens.is_empty() {
            for step in self.decoding.iter() {
                step.decode(tokens);
            }
        }
    }
}