// wordchipper 0.9.2
//
// HPC Rust LLM Tokenizer Library
// Documentation
use crate::{
    TokenDecoder,
    TokenEncoder,
    TokenType,
    UnifiedTokenVocab,
    WCResult,
    alloc::sync::Arc,
    decoders::{
        BatchDecodeResult,
        DecodeResult,
    },
    prelude::*,
    spanners::TextSpanner,
    vocab::SpecialFilter,
};

/// Unified Tokenizer.
///
/// Bundles, behind a single handle:
///  * a [`UnifiedTokenVocab`],
///  * a [`TokenEncoder`], and
///  * a [`TokenDecoder`].
///
/// Every component is stored in an [`Arc`], so the derived [`Clone`]
/// copies the handles rather than the components themselves.
#[derive(Clone)]
pub struct Tokenizer<T: TokenType> {
    /// Shared vocabulary handle; exposed through [`Tokenizer::vocab`].
    vocab: Arc<UnifiedTokenVocab<T>>,
    /// Encoder that this type's [`TokenEncoder`] methods forward to.
    encoder: Arc<dyn TokenEncoder<T>>,
    /// Decoder that this type's [`TokenDecoder`] methods forward to.
    decoder: Arc<dyn TokenDecoder<T>>,
}

impl<T: TokenType> Tokenizer<T> {
    /// Create a new tokenizer.
    pub fn new(
        vocab: Arc<UnifiedTokenVocab<T>>,
        encoder: Arc<dyn TokenEncoder<T>>,
        decoder: Arc<dyn TokenDecoder<T>>,
    ) -> Self {
        Self {
            vocab,
            encoder,
            decoder,
        }
    }

    /// Get the underlying vocabulary.
    pub fn vocab(&self) -> &Arc<UnifiedTokenVocab<T>> {
        &self.vocab
    }

    /// Get the underlying encoder.
    pub fn encoder(&self) -> &Arc<dyn TokenEncoder<T>> {
        &self.encoder
    }

    /// Get the underlying decoder.
    pub fn decoder(&self) -> &Arc<dyn TokenDecoder<T>> {
        &self.decoder
    }

    /// Tokenize text, and return the decoded tokens as individual strings.
    ///
    /// ## Compat
    /// This is added for compatibility for `tiktoken` users.
    ///
    /// ## Arguments
    /// * `text` - The text to split.
    /// * `special_filter` - an optional [`SpecialFilter`]. If `None`, all
    ///   special tokens are accepted.
    pub fn split_by_token(
        &self,
        text: &str,
        special_filter: Option<&SpecialFilter>,
    ) -> WCResult<Vec<String>> {
        let tokens = self.try_encode(text, special_filter)?;
        tokens
            .into_iter()
            .map(|t| self.try_decode_to_string(&[t])?.try_result())
            .collect()
    }
}

/// Encoding half of [`Tokenizer`]: every method forwards to the wrapped
/// encoder with its arguments unchanged.
impl<T: TokenType> TokenEncoder<T> for Tokenizer<T> {
    /// Return the wrapped encoder's text spanner.
    fn spanner(&self) -> &Arc<dyn TextSpanner> {
        let inner = &self.encoder;
        inner.spanner()
    }

    /// Return the wrapped encoder's special-token vocabulary.
    fn special_vocab(&self) -> &crate::vocab::SpecialVocab<T> {
        let inner = &self.encoder;
        inner.special_vocab()
    }

    /// Forward an append-style encode of `text` into `tokens` to the
    /// wrapped encoder.
    fn try_encode_append(
        &self,
        text: &str,
        tokens: &mut Vec<T>,
        special_filter: Option<&SpecialFilter>,
    ) -> WCResult<()> {
        let inner = &self.encoder;
        inner.try_encode_append(text, tokens, special_filter)
    }

    /// Forward a single-text encode to the wrapped encoder.
    fn try_encode(
        &self,
        text: &str,
        special_filter: Option<&SpecialFilter>,
    ) -> WCResult<Vec<T>> {
        let inner = &self.encoder;
        inner.try_encode(text, special_filter)
    }

    /// Forward a batch encode to the wrapped encoder.
    fn try_encode_batch(
        &self,
        batch: &[&str],
        special_filter: Option<&SpecialFilter>,
    ) -> WCResult<Vec<Vec<T>>> {
        let inner = &self.encoder;
        inner.try_encode_batch(batch, special_filter)
    }
}

/// Decoding half of [`Tokenizer`]: every method forwards to the wrapped
/// decoder with its arguments unchanged.
impl<T: TokenType> TokenDecoder<T> for Tokenizer<T> {
    /// Forward a token-slice-to-bytes decode to the wrapped decoder.
    fn try_decode_to_bytes(
        &self,
        tokens: &[T],
    ) -> WCResult<DecodeResult<Vec<u8>>> {
        let inner = &self.decoder;
        inner.try_decode_to_bytes(tokens)
    }

    /// Forward a batch-to-bytes decode to the wrapped decoder.
    fn try_decode_batch_to_bytes(
        &self,
        batch: &[&[T]],
    ) -> WCResult<BatchDecodeResult<Vec<u8>>> {
        let inner = &self.decoder;
        inner.try_decode_batch_to_bytes(batch)
    }

    /// Forward a token-slice-to-string decode to the wrapped decoder.
    fn try_decode_to_string(
        &self,
        tokens: &[T],
    ) -> WCResult<DecodeResult<String>> {
        let inner = &self.decoder;
        inner.try_decode_to_string(tokens)
    }

    /// Forward a batch-to-strings decode to the wrapped decoder.
    fn try_decode_batch_to_strings(
        &self,
        batch: &[&[T]],
    ) -> WCResult<BatchDecodeResult<String>> {
        let inner = &self.decoder;
        inner.try_decode_batch_to_strings(batch)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        TokenizerOptions,
        alloc::vec,
        decoders::utility::testing::common_decoder_tests,
        encoders::testing::common_encoder_tests,
        pretrained::openai::OA_CL100K_BASE_PATTERN,
        spanners::TextSpanningConfig,
        vocab::utility::testing::{
            build_test_shift_byte_vocab,
            build_test_vocab,
        },
    };

    /// Runs the shared encoder and decoder test suites against a
    /// default-built tokenizer, then checks `split_by_token` on a sample.
    #[test]
    fn test_tokenizer_impl() {
        type T = u32;

        // Build the test vocabulary from a shifted byte vocab and the
        // cl100k spanning pattern.
        let byte_vocab = build_test_shift_byte_vocab(10);
        let spanning = TextSpanningConfig::from_pattern(OA_CL100K_BASE_PATTERN);
        let vocab: Arc<UnifiedTokenVocab<T>> = build_test_vocab(byte_vocab, spanning).into();

        let tokenizer = TokenizerOptions::default().build(vocab.clone());

        // The tokenizer must satisfy both shared trait test suites.
        common_encoder_tests(vocab.clone(), tokenizer.clone());
        common_decoder_tests(vocab.clone(), tokenizer.clone());

        let sample = "hell the world";
        let pieces = tokenizer.split_by_token(sample, None).unwrap();
        let expected = vec!["hell", " ", "the", " ", "world"];
        assert_eq!(pieces, expected);
    }
}