use crate::{
TokenDecoder,
TokenEncoder,
TokenType,
UnifiedTokenVocab,
WCResult,
alloc::sync::Arc,
decoders::{
BatchDecodeResult,
DecodeResult,
},
prelude::*,
spanners::TextSpanner,
vocab::SpecialFilter,
};
/// Bundles a shared vocabulary with the encoder and decoder used alongside it.
///
/// Every component is held behind an [`Arc`], so cloning a `Tokenizer`
/// clones the three handles.
#[derive(Clone)]
pub struct Tokenizer<T>
where
    T: TokenType,
{
    /// Shared handle to the unified vocabulary.
    vocab: Arc<UnifiedTokenVocab<T>>,

    /// Encoder that all [`TokenEncoder`] calls on this type are forwarded to.
    encoder: Arc<dyn TokenEncoder<T>>,

    /// Decoder that all [`TokenDecoder`] calls on this type are forwarded to.
    decoder: Arc<dyn TokenDecoder<T>>,
}
impl<T: TokenType> Tokenizer<T> {
pub fn new(
vocab: Arc<UnifiedTokenVocab<T>>,
encoder: Arc<dyn TokenEncoder<T>>,
decoder: Arc<dyn TokenDecoder<T>>,
) -> Self {
Self {
vocab,
encoder,
decoder,
}
}
pub fn vocab(&self) -> &Arc<UnifiedTokenVocab<T>> {
&self.vocab
}
pub fn encoder(&self) -> &Arc<dyn TokenEncoder<T>> {
&self.encoder
}
pub fn decoder(&self) -> &Arc<dyn TokenDecoder<T>> {
&self.decoder
}
pub fn split_by_token(
&self,
text: &str,
special_filter: Option<&SpecialFilter>,
) -> WCResult<Vec<String>> {
let tokens = self.try_encode(text, special_filter)?;
tokens
.into_iter()
.map(|t| self.try_decode_to_string(&[t])?.try_result())
.collect()
}
}
/// Every [`TokenEncoder`] method is forwarded unchanged to the wrapped encoder.
impl<T> TokenEncoder<T> for Tokenizer<T>
where
    T: TokenType,
{
    /// Returns the wrapped encoder's text spanner.
    fn spanner(&self) -> &Arc<dyn TextSpanner> {
        let Self { encoder, .. } = self;
        encoder.spanner()
    }

    /// Returns the wrapped encoder's special vocabulary.
    fn special_vocab(&self) -> &crate::vocab::SpecialVocab<T> {
        let Self { encoder, .. } = self;
        encoder.special_vocab()
    }

    /// Forwards to the wrapped encoder's `try_encode_append`, passing
    /// `text`, the caller's `tokens` buffer and `special_filter` through.
    fn try_encode_append(
        &self,
        text: &str,
        tokens: &mut Vec<T>,
        special_filter: Option<&SpecialFilter>,
    ) -> WCResult<()> {
        let Self { encoder, .. } = self;
        encoder.try_encode_append(text, tokens, special_filter)
    }

    /// Forwards to the wrapped encoder's `try_encode`.
    fn try_encode(
        &self,
        text: &str,
        special_filter: Option<&SpecialFilter>,
    ) -> WCResult<Vec<T>> {
        let Self { encoder, .. } = self;
        encoder.try_encode(text, special_filter)
    }

    /// Forwards to the wrapped encoder's `try_encode_batch`.
    fn try_encode_batch(
        &self,
        batch: &[&str],
        special_filter: Option<&SpecialFilter>,
    ) -> WCResult<Vec<Vec<T>>> {
        let Self { encoder, .. } = self;
        encoder.try_encode_batch(batch, special_filter)
    }
}
/// Every [`TokenDecoder`] method is forwarded unchanged to the wrapped decoder.
impl<T> TokenDecoder<T> for Tokenizer<T>
where
    T: TokenType,
{
    /// Forwards to the wrapped decoder's `try_decode_to_bytes`.
    fn try_decode_to_bytes(
        &self,
        tokens: &[T],
    ) -> WCResult<DecodeResult<Vec<u8>>> {
        let Self { decoder, .. } = self;
        decoder.try_decode_to_bytes(tokens)
    }

    /// Forwards to the wrapped decoder's `try_decode_batch_to_bytes`.
    fn try_decode_batch_to_bytes(
        &self,
        batch: &[&[T]],
    ) -> WCResult<BatchDecodeResult<Vec<u8>>> {
        let Self { decoder, .. } = self;
        decoder.try_decode_batch_to_bytes(batch)
    }

    /// Forwards to the wrapped decoder's `try_decode_to_string`.
    fn try_decode_to_string(
        &self,
        tokens: &[T],
    ) -> WCResult<DecodeResult<String>> {
        let Self { decoder, .. } = self;
        decoder.try_decode_to_string(tokens)
    }

    /// Forwards to the wrapped decoder's `try_decode_batch_to_strings`.
    fn try_decode_batch_to_strings(
        &self,
        batch: &[&[T]],
    ) -> WCResult<BatchDecodeResult<String>> {
        let Self { decoder, .. } = self;
        decoder.try_decode_batch_to_strings(batch)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TokenizerOptions,
        alloc::vec,
        decoders::utility::testing::common_decoder_tests,
        encoders::testing::common_encoder_tests,
        pretrained::openai::OA_CL100K_BASE_PATTERN,
        spanners::TextSpanningConfig,
        vocab::utility::testing::{
            build_test_shift_byte_vocab,
            build_test_vocab,
        },
    };

    /// Runs the shared encoder and decoder test suites against a tokenizer
    /// built from the test vocabulary, then checks `split_by_token` on a
    /// short sample.
    #[test]
    fn test_tokenizer_impl() {
        type T = u32;

        let vocab: Arc<UnifiedTokenVocab<T>> = build_test_vocab(
            build_test_shift_byte_vocab(10),
            TextSpanningConfig::from_pattern(OA_CL100K_BASE_PATTERN),
        )
        .into();

        let tokenizer = TokenizerOptions::default().build(vocab.clone());

        // The built tokenizer is exercised through both trait suites.
        common_encoder_tests(vocab.clone(), tokenizer.clone());
        common_decoder_tests(vocab.clone(), tokenizer.clone());

        // Each token of the sample should decode to its own piece.
        let pieces = tokenizer.split_by_token("hell the world", None).unwrap();
        let expected = vec!["hell", " ", "the", " ", "world"];
        assert_eq!(&pieces, &expected);
    }
}