#![no_std]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![cfg_attr(docsrs, doc(auto_cfg))]
extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
mod charsmap;
mod config;
mod decoder;
mod definition;
mod encoder;
mod regex;
mod vocab;
#[cfg(feature = "serialization")]
mod serialization;
#[cfg(feature = "web")]
mod web;
pub mod convert;
use alloc::boxed::Box;
use alloc::fmt::Debug;
use alloc::string::String;
use alloc::vec::Vec;
use core::str::Utf8Error;
use derive_more::{Deref, DerefMut};
use hashbrown::HashMap;
pub use crate::charsmap::*;
pub use crate::config::*;
pub use crate::decoder::*;
pub use crate::definition::*;
pub use crate::encoder::*;
pub use crate::regex::*;
pub use crate::vocab::*;
#[cfg(feature = "serialization")]
pub use crate::serialization::*;
#[cfg(feature = "web")]
pub use crate::web::*;
#[doc(hidden)]
pub mod util;
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum InitializationError {
#[error("invalid config: {0}")]
InvalidConfig(ConfigurationError),
#[error(
"encoder and scores must have the same length in unigram mode and every token must have a score"
)]
InvalidScores,
#[error("encoder and decoder must have the same length and vocab must not have duplicates")]
InvalidEncoder,
#[error(
"special encoder and decoder must have the same length and specials must not have duplicates"
)]
InvalidSpecialEncoder,
#[error("invalid regex: {0}")]
InvalidRegex(String),
#[error("invalid utf-8: {0}")]
InvalidUtf8(Utf8Error),
}
impl From<ConfigurationError> for InitializationError {
#[inline(always)]
fn from(e: ConfigurationError) -> Self {
Self::InvalidConfig(e)
}
}
impl From<RegexError> for InitializationError {
#[inline(always)]
fn from(e: RegexError) -> Self {
Self::InvalidRegex(e.0)
}
}
impl From<Utf8Error> for InitializationError {
#[inline(always)]
fn from(e: Utf8Error) -> Self {
Self::InvalidUtf8(e)
}
}
#[derive(Clone, Deref, DerefMut)]
struct SpecialsMap(HashMap<TokenBytes, SpecialToken>);
impl FromIterator<(TokenBytes, SpecialToken)> for SpecialsMap {
#[inline(always)]
fn from_iter<I: IntoIterator<Item = (TokenBytes, SpecialToken)>>(iter: I) -> Self {
Self(iter.into_iter().collect())
}
}
impl Debug for SpecialsMap {
#[inline(never)]
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_list().entries(self.0.values()).finish()
}
}
/// Selects which kinds of special tokens are handled as special tokens when encoding or
/// decoding.
///
/// Implemented for `bool`, a preset that depends on the tokenizer [`Metadata`]. It is also
/// implemented for explicit lists of [`SpecialTokenKind`] given as slices, fixed-size arrays
/// (owned or borrowed) or vectors.
pub trait SpecialTokenKinds {
    /// Returns the selected special token kinds, resolved against the tokenizer metadata `meta`.
    fn as_kinds(&self, meta: &Metadata) -> &[SpecialTokenKind];
}

/// `true` selects [`SpecialTokenKind::Control`], [`SpecialTokenKind::Priority`] and
/// [`SpecialTokenKind::Unknown`].
///
/// `false` selects only [`SpecialTokenKind::Priority`] when `meta.source` is `"tokenizers"`.
/// Otherwise it selects [`SpecialTokenKind::Priority`] and [`SpecialTokenKind::Unknown`].
impl SpecialTokenKinds for bool {
    #[inline(always)]
    fn as_kinds(&self, meta: &Metadata) -> &[SpecialTokenKind] {
        if *self {
            &[
                SpecialTokenKind::Control,
                SpecialTokenKind::Priority,
                SpecialTokenKind::Unknown,
            ]
        } else if meta.source == "tokenizers" {
            &[SpecialTokenKind::Priority]
        } else {
            &[SpecialTokenKind::Priority, SpecialTokenKind::Unknown]
        }
    }
}

/// Uses the given kinds as-is; `meta` is ignored.
impl<'a> SpecialTokenKinds for &'a [SpecialTokenKind] {
    #[inline(always)]
    fn as_kinds(&self, _meta: &Metadata) -> &'a [SpecialTokenKind] {
        self
    }
}

/// Uses the given kinds as-is; `meta` is ignored.
impl SpecialTokenKinds for Vec<SpecialTokenKind> {
    #[inline(always)]
    fn as_kinds(&self, _meta: &Metadata) -> &[SpecialTokenKind] {
        self
    }
}

/// Uses the given kinds as-is; `meta` is ignored.
///
/// Allows passing an array literal such as `[SpecialTokenKind::Control]` directly. Unsized
/// coercion to a slice does not apply to generic `impl SpecialTokenKinds` arguments.
impl<const N: usize> SpecialTokenKinds for [SpecialTokenKind; N] {
    #[inline(always)]
    fn as_kinds(&self, _meta: &Metadata) -> &[SpecialTokenKind] {
        self
    }
}

/// Uses the given kinds as-is; `meta` is ignored.
///
/// Allows passing a borrowed array such as `&[SpecialTokenKind::Control]` directly.
impl<'a, const N: usize> SpecialTokenKinds for &'a [SpecialTokenKind; N] {
    #[inline(always)]
    fn as_kinds(&self, _meta: &Metadata) -> &'a [SpecialTokenKind] {
        *self
    }
}
/// A tokenizer combining a model-specific encoder, a decoder, special tokens, a configuration
/// and metadata.
#[derive(Debug)]
pub struct Kitoken {
    /// Model-specific encoder (byte-pair, unigram or word-piece), chosen in [`Kitoken::new`].
    encoder: Box<dyn Encoder>,
    /// Decoder built from the model vocabulary and the special tokens.
    decoder: Decoder,
    /// Special token definitions keyed by their bytes.
    specials: SpecialsMap,
    /// Matches the special tokens flagged `extract`; applied to the raw input before
    /// normalization.
    extract_split: Regex,
    /// Matches the remaining special tokens; applied to already-normalized text parts.
    special_split: Regex,
    /// Configuration used for validation, normalization, splitting and post-processing.
    config: Configuration,
    /// Tokenizer metadata; used to resolve [`SpecialTokenKinds`] presets.
    meta: Metadata,
}
impl Kitoken {
#[inline(never)]
pub fn new(
model: Model, specials: SpecialVocab, config: Configuration, meta: Metadata,
) -> Result<Self, InitializationError> {
if let Err(error) = config.validate() {
return Err(InitializationError::InvalidConfig(error));
}
let special_split = Regex::new(
&specials
.iter()
.filter(|special| !special.extract)
.map(|special| core::str::from_utf8(&special.bytes))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.map(|s| regex::escape(s))
.collect::<Vec<_>>()
.join("|"),
)?;
let extract_split = Regex::new(
&specials
.iter()
.filter(|special| special.extract)
.map(|special| core::str::from_utf8(&special.bytes))
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.map(|s| regex::escape(s))
.collect::<Vec<_>>()
.join("|"),
)?;
let (encoder, decoder) = match model {
Model::BytePair { vocab, chars } => {
let decoder = Decoder::new(&vocab, &specials, &config);
let encoder = Box::new(BytePair::new(vocab, &specials, &config, chars)?) as _;
(encoder, decoder)
}
Model::Unigram { vocab, scores } => {
let decoder = Decoder::new(&vocab, &specials, &config);
let encoder = Box::new(Unigram::new(vocab, &specials, &config, scores)?) as _;
(encoder, decoder)
}
Model::WordPiece {
vocab,
max_word_chars,
} => {
let decoder = Decoder::new(&vocab, &specials, &config);
let encoder =
Box::new(WordPiece::new(vocab, &specials, &config, max_word_chars)) as _;
(encoder, decoder)
}
};
let specials_len = specials.len();
let specials = specials
.into_iter()
.map(|special| (special.bytes.clone(), special))
.collect::<SpecialsMap>();
if specials_len != specials.len() {
return Err(InitializationError::InvalidSpecialEncoder);
}
Ok(Self {
encoder,
decoder,
specials,
special_split,
extract_split,
config,
meta,
})
}
#[inline(always)]
pub fn encode(
&self, text: impl AsRef<str>, encode_specials: impl SpecialTokenKinds,
) -> Result<Vec<TokenId>, EncodeError> {
self.inner_encode(text, encode_specials.as_kinds(&self.meta))
}
#[inline(never)]
fn inner_encode(
&self, text: impl AsRef<str>, encode_specials: &[SpecialTokenKind],
) -> Result<Vec<TokenId>, EncodeError> {
let text = text.as_ref();
let mut extracted = if self.extract_split.is_empty() {
Vec::with_capacity(0)
} else {
let mut extracted = self.extract_split.find_iter(text);
extracted.reverse();
extracted
};
let mut parts = Vec::with_capacity(extracted.len() * 2 + 1);
let mut posit = 0;
while posit < text.len() {
if let Some(next) = extracted.pop() {
if next.0 > posit {
let mut text = text[posit..next.0].into();
self.config.normalize(&mut text, posit..next.0);
parts.push(TextPart {
text,
special: Token::INVALID,
})
}
let special = &self.specials[&text.as_bytes()[next.0..next.1]];
parts.push(TextPart {
text: text[next.0..next.1].into(),
special: if encode_specials.contains(&special.kind) {
special.id
} else {
Token::INVALID
},
});
posit = next.1;
} else {
let mut rest = text[posit..text.len()].into();
self.config.normalize(&mut rest, posit..usize::MAX);
parts.push(TextPart {
text: rest,
special: Token::INVALID,
});
posit = text.len();
}
}
let mut parts = parts.iter().fold(Vec::with_capacity(text.len() / 6), |mut acc, part| {
let mut specials = if part.special != Token::INVALID {
acc.push(part.clone());
return acc;
} else if self.special_split.is_empty() {
Vec::with_capacity(0)
} else {
let mut specials = self
.special_split
.find_iter(&part.text)
.into_iter()
.map(|(start, end)| {
(start, end, &self.specials[part.text[start..end].as_bytes()])
})
.filter(|(_, _, special)| encode_specials.contains(&special.kind))
.collect::<Vec<_>>();
specials.reverse();
specials
};
let mut posit = 0;
while posit < part.text.len() {
if let Some(next) = specials.pop() {
if next.0 > posit {
for (start, end) in self.config.split(&part.text[posit..next.0]) {
if end > start {
acc.push(TextPart {
text: part.text[posit + start..posit + end].into(),
special: Token::INVALID,
});
}
}
}
acc.push(TextPart {
text: part.text[next.0..next.1].into(),
special: next.2.id,
});
posit = next.1;
} else {
for (start, end) in self.config.split(&part.text[posit..part.text.len()]) {
if end > start {
acc.push(TextPart {
text: part.text[posit + start..posit + end].into(),
special: Token::INVALID,
});
}
}
posit = part.text.len();
}
}
acc
});
let mut result = self.encoder.encode(text, &mut parts)?;
self.config.process(&mut result);
Ok(result)
}
#[inline(never)]
pub fn decode(
&self, tokens: impl AsRef<[TokenId]>, decode_specials: impl SpecialTokenKinds,
) -> Result<Vec<u8>, DecodeError> {
let tokens = tokens.as_ref();
let mut result = self.decoder.decode(tokens, decode_specials.as_kinds(&self.meta))?;
self.config.decode(&mut result);
Ok(result)
}
#[inline(always)]
pub fn config(&self) -> &Configuration {
&self.config
}
#[inline(always)]
pub fn meta(&self) -> &Metadata {
&self.meta
}
}