mod backtracking;
mod sentencepiece;
mod simple;
mod unigram;
mod wordpiece;
pub use backtracking::{BacktrackingBytePairEncoder, EncodeIter};
pub use sentencepiece::{EncodeState, SentencePieceBPE};
pub use simple::BytePairEncoder;
pub use unigram::UnigramEncoder;
pub use wordpiece::WordPieceEncoder;
use crate::types::{Split, TokenId};
/// Numeric tag naming one of the encoder implementations wrapped by [`Encoder`].
///
/// The enum is `#[repr(u32)]` with explicit discriminants, so each variant has
/// a fixed `u32` value that [`EncoderType::from_u32`] maps back to the variant.
/// `Backtracking` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u32)]
pub enum EncoderType {
    /// Tag `0`, reported by [`Encoder::Backtracking`]; the default.
    #[default]
    Backtracking = 0,
    /// Tag `1`, reported by [`Encoder::Simple`].
    Simple = 1,
    /// Tag `2`, reported by [`Encoder::WordPiece`].
    WordPiece = 2,
    /// Tag `3`, reported by [`Encoder::SentencePiece`].
    SentencePiece = 3,
    /// Tag `4`, reported by [`Encoder::Unigram`].
    Unigram = 4,
}

impl EncoderType {
    /// Converts a raw `u32` tag back into an [`EncoderType`].
    ///
    /// Returns `None` when `v` is not the discriminant of any variant.
    pub fn from_u32(v: u32) -> Option<Self> {
        // Every variant, listed once; the lookup below compares against each
        // variant's declared discriminant rather than repeating the numbers.
        const ALL: [EncoderType; 5] = [
            EncoderType::Backtracking,
            EncoderType::Simple,
            EncoderType::WordPiece,
            EncoderType::SentencePiece,
            EncoderType::Unigram,
        ];

        ALL.iter().copied().find(|kind| *kind as u32 == v)
    }
}
/// One of the concrete encoder implementations re-exported by this module,
/// held behind a single type so callers can store and call any of them
/// uniformly through the methods on `Encoder`.
#[derive(Clone)]
pub enum Encoder {
/// Wraps a [`BacktrackingBytePairEncoder`].
Backtracking(BacktrackingBytePairEncoder),
/// Wraps a [`BytePairEncoder`].
Simple(BytePairEncoder),
/// Wraps a [`WordPieceEncoder`].
WordPiece(WordPieceEncoder),
/// Wraps a [`SentencePieceBPE`].
SentencePiece(SentencePieceBPE),
/// Wraps a [`UnigramEncoder`].
Unigram(UnigramEncoder),
}
impl Encoder {
pub fn encoder_type(&self) -> EncoderType {
match self {
Encoder::Backtracking(_) => EncoderType::Backtracking,
Encoder::Simple(_) => EncoderType::Simple,
Encoder::WordPiece(_) => EncoderType::WordPiece,
Encoder::SentencePiece(_) => EncoderType::SentencePiece,
Encoder::Unigram(_) => EncoderType::Unigram,
}
}
pub fn encode(&self, text: &[u8]) -> Vec<TokenId> {
match self {
Encoder::Backtracking(e) => e.encode(text),
Encoder::Simple(e) => e.encode(text),
Encoder::WordPiece(e) => e.encode(text),
Encoder::SentencePiece(e) => e.encode(text),
Encoder::Unigram(e) => e.encode(text),
}
}
pub fn vocab_size(&self) -> usize {
match self {
Encoder::Backtracking(e) => e.vocab_size(),
Encoder::Simple(e) => e.vocab_size(),
Encoder::WordPiece(e) => e.vocab_size(),
Encoder::SentencePiece(e) => e.vocab_size(),
Encoder::Unigram(e) => e.vocab_size(),
}
}
pub fn num_base_tokens(&self) -> usize {
match self {
Encoder::Backtracking(e) => e.num_base_tokens(),
Encoder::Simple(e) => e.num_base_tokens(),
Encoder::WordPiece(e) => e.num_base_tokens(),
Encoder::SentencePiece(e) => e.num_base_tokens(),
Encoder::Unigram(e) => e.num_base_tokens(),
}
}
pub fn split_table(&self) -> Option<&[Split]> {
match self {
Encoder::Backtracking(e) => Some(e.split_table()),
Encoder::Simple(_) | Encoder::WordPiece(_) | Encoder::SentencePiece(_) | Encoder::Unigram(_) => None,
}
}
pub fn encode_iter<'a>(&'a self, text: &'a [u8]) -> EncoderIter<'a> {
match self {
Encoder::Backtracking(e) => EncoderIter::Backtracking(e.encode_iter(text)),
Encoder::Simple(_) | Encoder::WordPiece(_) | Encoder::SentencePiece(_) | Encoder::Unigram(_) => {
EncoderIter::Collected(self.encode(text).into_iter())
}
}
}
pub fn as_backtracking(&self) -> Option<&BacktrackingBytePairEncoder> {
match self {
Encoder::Backtracking(e) => Some(e),
_ => None,
}
}
pub fn as_simple(&self) -> Option<&BytePairEncoder> {
match self {
Encoder::Simple(e) => Some(e),
_ => None,
}
}
pub fn as_wordpiece(&self) -> Option<&WordPieceEncoder> {
match self {
Encoder::WordPiece(e) => Some(e),
_ => None,
}
}
pub fn as_sentencepiece(&self) -> Option<&SentencePieceBPE> {
match self {
Encoder::SentencePiece(e) => Some(e),
_ => None,
}
}
pub fn as_unigram(&self) -> Option<&UnigramEncoder> {
match self {
Encoder::Unigram(e) => Some(e),
_ => None,
}
}
pub fn is_valid_pair(&self, token1: TokenId, token2: TokenId) -> bool {
match self {
Encoder::Backtracking(e) => e.is_valid_pair(token1, token2),
Encoder::Simple(e) => e.is_valid_pair(token1, token2),
Encoder::WordPiece(e) => e.is_valid_pair(token1, token2),
Encoder::SentencePiece(e) => e.is_valid_pair(token1, token2),
Encoder::Unigram(e) => e.is_valid_pair(token1, token2),
}
}
}
/// Iterator over token IDs, as returned by [`Encoder::encode_iter`].
pub enum EncoderIter<'a> {
/// Tokens produced by the backtracking encoder's own [`EncodeIter`].
Backtracking(EncodeIter<'a>),
/// Tokens that were already encoded into a `Vec` and are handed out one by one.
Collected(std::vec::IntoIter<TokenId>),
}
impl Iterator for EncoderIter<'_> {
    type Item = TokenId;

    /// Yields the next token ID from whichever underlying iterator is wrapped.
    fn next(&mut self) -> Option<TokenId> {
        match self {
            EncoderIter::Backtracking(iter) => iter.next(),
            EncoderIter::Collected(iter) => iter.next(),
        }
    }

    /// Forwards the wrapped iterator's bounds.
    ///
    /// Without this override the wrapper would always report the default
    /// `(0, None)`, even for the `Collected` variant whose `vec::IntoIter`
    /// knows its exact remaining length; forwarding lets consumers such as
    /// `collect` and `extend` reserve capacity up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        match self {
            EncoderIter::Backtracking(iter) => iter.size_hint(),
            EncoderIter::Collected(iter) => iter.size_hint(),
        }
    }
}
// Marker impl: `FusedIterator` promises that once `next` has returned `None`
// it keeps returning `None` on every later call.
// NOTE(review): for the `Backtracking` variant this relies on `EncodeIter`
// behaving that way; its definition is not visible here — confirm.
impl std::iter::FusedIterator for EncoderIter<'_> {}