use std::path::Path;
use rayon::prelude::*;
use super::{
    Encoding, Error, Result, TokenIdType,
    hf::HuggingFaceTokenizer,
    traits::{DecodeResult, Decoder, Encoder, Tokenizer},
};
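
/// Tokenizer that encodes with the `fastokens` backend and delegates decoding
/// to a `HuggingFaceTokenizer` loaded from the same tokenizer file.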
pub struct FastTokenizer {
    fast_encoder: fastokens::Tokenizer,
    hf_decoder: HuggingFaceTokenizer,
}

impl FastTokenizer {
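    /// Loads both the `fastokens` encoder and the Hugging Face decoder from the
    /// tokenizer file at `path`.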
    pub fn from_file(path: &str) -> Result<Self> {
        let fast_encoder = fastokens::Tokenizer::from_file(Path::new(path))
            .map_err(|e| Error::msg(format!("Error loading fastokens tokenizer: {e}")))?;
        let hf_decoder = HuggingFaceTokenizer::from_file(path)?;
        Ok(Self {
            fast_encoder,
            hf_decoder,
        })
    }
}

impl Encoder for FastTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        let ids = self
            .fast_encoder
            .encode(input)
            .map_err(|e| Error::msg(format!("Fastokens encode error: {e}")))?;
        Ok(Encoding::Sp(ids))
    }
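
    /// Encodes a batch of inputs in parallel using rayon.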
    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.par_iter().map(|input| self.encode(input)).collect()
    }
}

impl Decoder for FastTokenizer {
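    /// Decoding is delegated to the wrapped `HuggingFaceTokenizer`.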
    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<DecodeResult> {
        self.hf_decoder.decode(token_ids, skip_special_tokens)
    }
}
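
// `Tokenizer` requires no methods beyond those provided by `Encoder` and
// `Decoder`, so an empty impl suffices.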
impl Tokenizer for FastTokenizer {}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizers::HuggingFaceTokenizer;

    const TOKENIZER_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/tests/data/sample-models/minimal-bpe/tokenizer.json"
    );

    #[test]
    fn test_fast_encode_decode_roundtrip() {
        let tokenizer = FastTokenizer::from_file(TOKENIZER_PATH).unwrap();
        let text = "Hello, world!";
        let encoding = tokenizer.encode(text).unwrap();
        assert!(!encoding.token_ids().is_empty());

        let decoded: String = tokenizer.decode(encoding.token_ids(), true).unwrap().into();
        assert!(!decoded.is_empty());
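
        // The decoder may normalize whitespace, so compare only the
        // non-whitespace characters of the original and decoded text.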
        let enc_chars: String = text.chars().filter(|c| !c.is_whitespace()).collect();
        let dec_chars: String = decoded.chars().filter(|c| !c.is_whitespace()).collect();
        assert_eq!(
            enc_chars, dec_chars,
            "non-space characters must be preserved"
        );
    }

    #[test]
    fn test_fast_matches_hf_encoding() {
        let fast = FastTokenizer::from_file(TOKENIZER_PATH).unwrap();
        let hf = HuggingFaceTokenizer::from_file(TOKENIZER_PATH).unwrap();
        for text in &["Hello, world!", "Hello", " world", "He llo"] {
            let fast_ids = fast.encode(text).unwrap();
            let hf_ids = hf.encode(text).unwrap();
            assert_eq!(
                fast_ids.token_ids(),
                hf_ids.token_ids(),
                "fastokens and HuggingFace must produce identical token IDs for '{text}'"
            );
        }
    }

    #[test]
    fn test_fast_batch_encode() {
        let tokenizer = FastTokenizer::from_file(TOKENIZER_PATH).unwrap();
        let inputs = &["Hello", " world", "Hello, world!"];
        let encodings = tokenizer.encode_batch(inputs).unwrap();
        assert_eq!(encodings.len(), inputs.len());
        for (enc, input) in encodings.iter().zip(inputs.iter()) {
            assert!(
                !enc.token_ids().is_empty(),
                "encoding for '{input}' must be non-empty"
            );
        }
    }

    #[test]
    fn test_fast_with_decode_stream() {
        use crate::tokenizers::Tokenizer as TokenizerWrapper;
        use std::sync::Arc;

        let tokenizer = Arc::new(FastTokenizer::from_file(TOKENIZER_PATH).unwrap());
        let wrapper = TokenizerWrapper::from(tokenizer);
        let prompt_ids = wrapper.encode("Hello").unwrap().token_ids().to_vec();
        let continuation = ", world!";
        let cont_ids = wrapper.encode(continuation).unwrap().token_ids().to_vec();

        let mut stream = wrapper.decode_stream(&prompt_ids, true);
        let mut accumulated = String::new();
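
        // Feed the continuation tokens one at a time and collect the incremental
        // text chunks emitted by the decode stream.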
        for id in &cont_ids {
            if let Some(chunk) = stream.step(*id).unwrap() {
                accumulated.push_str(&chunk);
            }
        }

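        // Reference: decode prompt + continuation in one call and strip the decoded
        // prompt prefix to obtain the expected continuation text.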
        let mut all_ids = prompt_ids.clone();
        all_ids.extend_from_slice(&cont_ids);
        let full_text: String = wrapper.decode(&all_ids, true).unwrap().into();
        let prompt_text: String = wrapper.decode(&prompt_ids, true).unwrap().into();
        let expected = &full_text[prompt_text.len()..];
        assert_eq!(
            accumulated, expected,
            "streamed chunks must equal context-aware decoded continuation"
        );
    }
}