use std::path::Path;
use tokenizers::Tokenizer;
/// Rewrites typographic Unicode characters into plain ASCII-friendly forms.
///
/// The following substitutions are applied:
/// - curly single quotes (U+2018, U+2019) become `'`
/// - curly double quotes (U+201C, U+201D) become `"`
/// - em dash (U+2014) becomes `--`, en dash (U+2013) becomes `-`
/// - horizontal ellipsis (U+2026) becomes `...`
/// - U+00A0, U+202F and U+3000 become a regular space
/// - `\r\n` and a lone `\r` both become `\n`
/// - U+200B, U+200C, U+200D and U+FEFF are removed
///
/// Every other character is copied through unchanged.
pub fn normalize_text(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut chars = text.chars().peekable();
    while let Some(ch) = chars.next() {
        match ch {
            '\u{2018}' | '\u{2019}' => out.push('\''),
            '\u{201C}' | '\u{201D}' => out.push('"'),
            '\u{2014}' => out.push_str("--"),
            '\u{2013}' => out.push('-'),
            '\u{2026}' => out.push_str("..."),
            '\u{00A0}' | '\u{202F}' | '\u{3000}' => out.push(' '),
            '\r' => {
                // A CRLF pair collapses to one newline; a lone CR is
                // converted on its own. Only an immediately following
                // '\n' counts as part of the pair.
                if chars.peek() == Some(&'\n') {
                    chars.next();
                }
                out.push('\n');
            }
            // Zero-width characters and the BOM are dropped entirely.
            '\u{200B}' | '\u{200C}' | '\u{200D}' | '\u{FEFF}' => {}
            other => out.push(other),
        }
    }
    out
}
/// Numeric IDs of the special tokens referenced by this module.
///
/// The values are plain `u32` token IDs; this type does not check them
/// against any loaded vocabulary.
#[derive(Debug, Clone, Copy)]
pub struct SpecialTokenIds {
/// ID of the turn-opening marker (presumably the `<|im_start|>` token — confirm against the vocabulary).
pub im_start: u32,
/// ID of the turn-closing marker (presumably the `<|im_end|>` token — confirm against the vocabulary).
pub im_end: u32,
/// ID used as the default padding token by `TextProcessor::batch_tokenize_padded`.
pub tts_pad: u32,
/// ID that `TextProcessor::add_tts_tokens` prepends to a token sequence.
pub tts_bos: u32,
/// ID of the TTS end-of-sequence token (not read anywhere in this file; presumably consumed by callers).
pub tts_eos: u32,
}
impl Default for SpecialTokenIds {
fn default() -> Self {
Self {
im_start: 151644,
im_end: 151645,
tts_pad: 151671,
tts_bos: 151672,
tts_eos: 151673,
}
}
}
/// Selects on which side of a sequence padding tokens are inserted.
///
/// The default is [`PaddingSide::Left`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum PaddingSide {
/// Padding is placed before the real tokens.
#[default]
Left,
/// Padding is placed after the real tokens.
Right,
}
/// Result of tokenizing and padding a batch of texts.
///
/// As produced by `TextProcessor::pad_sequences`, all three vectors have one
/// entry per input sequence, and every row of `input_ids` and
/// `attention_mask` has the length of the longest sequence in the batch.
#[derive(Debug, Clone)]
pub struct TokenizerOutput {
/// Token IDs per sequence, including padding tokens.
pub input_ids: Vec<Vec<u32>>,
/// Per-position mask: `1` for a real token, `0` for a padding position.
pub attention_mask: Vec<Vec<u32>>,
/// Length of each sequence before padding was applied.
pub lengths: Vec<usize>,
}
/// Text front-end that builds chat-style prompts and converts them to token IDs.
///
/// A processor can exist without a tokenizer (see `TextProcessor::new`); in
/// that state the prompt-building helpers still work, while tokenizing
/// methods return empty results or errors depending on the method.
#[derive(Debug, Clone)]
pub struct TextProcessor {
// Underlying tokenizer; `None` until one is loaded by a `from_*` constructor.
tokenizer: Option<Tokenizer>,
/// Special token IDs used for padding and for `add_tts_tokens`.
pub special_tokens: SpecialTokenIds,
}
impl TextProcessor {
pub fn new() -> Self {
Self {
tokenizer: None,
special_tokens: SpecialTokenIds::default(),
}
}
pub fn with_special_tokens(special_tokens: SpecialTokenIds) -> Self {
Self {
tokenizer: None,
special_tokens,
}
}
/// Loads a tokenizer from a model directory.
///
/// Lookup order:
/// 1. `tokenizer.json` — loaded through [`Self::from_file`].
/// 2. `vocab.json` — loaded through [`Self::from_bpe_files_with_config`],
///    together with `merges.txt` and `tokenizer_config.json` when they exist.
///
/// # Errors
///
/// Returns an error when neither `tokenizer.json` nor `vocab.json` exists in
/// `model_dir`, or when the selected loader fails.
pub fn from_pretrained<P: AsRef<Path>>(
    model_dir: P,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    let dir = model_dir.as_ref();

    let fast_tokenizer = dir.join("tokenizer.json");
    if fast_tokenizer.exists() {
        tracing::debug!("Loading tokenizer from tokenizer.json (HuggingFace fast format)");
        return Self::from_file(&fast_tokenizer);
    }

    let vocab_file = dir.join("vocab.json");
    if !vocab_file.exists() {
        let message = format!(
            "No tokenizer found in {}. Expected tokenizer.json or vocab.json",
            dir.display()
        );
        return Err(message.into());
    }

    tracing::debug!(
        "Loading tokenizer from vocab.json + merges.txt + tokenizer_config.json"
    );
    // The two companion files are optional; pass them on only when present.
    let merges_file = dir.join("merges.txt");
    let config_file = dir.join("tokenizer_config.json");
    let merges_arg = merges_file.exists().then_some(merges_file.as_path());
    let config_arg = config_file.exists().then_some(config_file.as_path());
    Self::from_bpe_files_with_config(&vocab_file, merges_arg, config_arg)
}
/// Builds a byte-level BPE tokenizer from `vocab.json`, an optional
/// `merges.txt`, and an optional `tokenizer_config.json`.
///
/// * `vocab_path` — JSON object mapping token strings to IDs.
/// * `merges_path` — one merge per line as two space-separated symbols. A
///   `#version` header line and empty lines are skipped; lines that do not
///   split into exactly two parts are ignored.
/// * `config_path` — JSON file whose `added_tokens_decoder` object (keyed by
///   token ID) lists the added tokens. Entries without a numeric key or a
///   string `content` are ignored; a missing `special` flag counts as `false`.
///
/// Every added token is registered in ascending ID order: special ones via
/// `add_special_tokens`, the rest via `add_tokens`. Registering all of them
/// keeps the assigned IDs aligned with the configuration when special and
/// non-special entries are interleaved. A warning is logged when an assigned
/// ID differs from the configured one.
///
/// The returned processor uses `SpecialTokenIds::default()`.
///
/// # Errors
///
/// Returns an error if a supplied file cannot be read, if `vocab.json` or
/// the config is not valid JSON of the expected shape, or if the
/// pre-tokenizer pattern fails to compile.
pub fn from_bpe_files_with_config(
    vocab_path: &Path,
    merges_path: Option<&Path>,
    config_path: Option<&Path>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    use std::fs;
    use tokenizers::decoders::byte_level::ByteLevel as ByteLevelDecoder;
    use tokenizers::models::bpe::BPE;
    use tokenizers::normalizers::NFC;
    use tokenizers::pre_tokenizers::byte_level::ByteLevel;
    use tokenizers::pre_tokenizers::sequence::Sequence;
    use tokenizers::pre_tokenizers::split::Split;
    use tokenizers::{AddedToken, SplitDelimiterBehavior, TokenizerImpl};
    use tokenizers::{
        DecoderWrapper, NormalizerWrapper, PostProcessorWrapper, PreTokenizerWrapper,
    };

    tracing::debug!("Building Qwen2-style tokenizer (replicating tokenization_qwen2.py)");

    let vocab_content = fs::read_to_string(vocab_path)?;
    let vocab_std: std::collections::HashMap<String, u32> =
        serde_json::from_str(&vocab_content)?;
    let vocab: tokenizers::models::bpe::Vocab = vocab_std.into_iter().collect();
    tracing::debug!(entries = vocab.len(), "Loaded vocab");

    let merges: tokenizers::models::bpe::Merges = if let Some(merges_path) = merges_path {
        let merges_content = fs::read_to_string(merges_path)?;
        merges_content
            .lines()
            // Only the `#version` header is metadata. Other lines starting
            // with '#' are real merges (e.g. "# #") and must be kept, and a
            // file without a header must not lose its first merge.
            .filter(|line| !line.is_empty() && !line.starts_with("#version"))
            .filter_map(|line| {
                let (left, right) = line.split_once(' ')?;
                // Exactly two space-separated parts are required.
                (!right.contains(' ')).then(|| (left.to_string(), right.to_string()))
            })
            .collect()
    } else {
        vec![]
    };
    tracing::debug!(count = merges.len(), "Loaded merges");

    // (id, content, special) triples, sorted by ID.
    let added_tokens: Vec<(u32, String, bool)> = if let Some(config_path) = config_path {
        let config_content = fs::read_to_string(config_path)?;
        let config: serde_json::Value = serde_json::from_str(&config_content)?;
        if let Some(added_tokens_decoder) = config
            .get("added_tokens_decoder")
            .and_then(|v| v.as_object())
        {
            let mut tokens: Vec<(u32, String, bool)> = added_tokens_decoder
                .iter()
                .filter_map(|(id_str, token_info)| {
                    let id: u32 = id_str.parse().ok()?;
                    let content = token_info.get("content")?.as_str()?.to_string();
                    let special = token_info
                        .get("special")
                        .and_then(|v| v.as_bool())
                        .unwrap_or(false);
                    Some((id, content, special))
                })
                .collect();
            tokens.sort_by_key(|(id, _, _)| *id);
            tracing::debug!(count = tokens.len(), "Found added tokens");
            tokens
        } else {
            vec![]
        }
    } else {
        vec![]
    };

    let bpe = BPE::new(vocab, merges);
    let pretokenize_regex = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+";
    let split = Split::new(pretokenize_regex, SplitDelimiterBehavior::Isolated, false)?;
    let byte_level = ByteLevel::new(false, true, false);
    // The regex split runs first, then byte-level mapping of each piece.
    let pre_tokenizer = Sequence::new(vec![split.into(), byte_level.into()]);

    type FullTokenizer = TokenizerImpl<
        BPE,
        NormalizerWrapper,
        PreTokenizerWrapper,
        PostProcessorWrapper,
        DecoderWrapper,
    >;
    let mut tokenizer: FullTokenizer = TokenizerImpl::new(bpe);
    tokenizer.with_normalizer(Some(NFC));
    tokenizer.with_pre_tokenizer(Some(pre_tokenizer));
    tokenizer.with_decoder(Some(ByteLevelDecoder::new(false, true, true)));
    let mut tokenizer: Tokenizer = tokenizer.into();

    if !added_tokens.is_empty() {
        tracing::debug!(count = added_tokens.len(), "Adding added tokens");
        // Tokens are registered one at a time in ascending ID order so that
        // the IDs assigned after the base vocabulary follow the configured
        // order. Skipping non-special entries would shift every later ID.
        for (expected_id, content, special) in &added_tokens {
            let token = AddedToken::from(content.clone(), *special)
                .single_word(false)
                .lstrip(false)
                .rstrip(false)
                .normalized(false);
            if *special {
                tokenizer.add_special_tokens(&[token]);
            } else {
                tokenizer.add_tokens(&[token]);
            }
            let assigned = tokenizer.token_to_id(content);
            if assigned != Some(*expected_id) {
                tracing::warn!(
                    token = %content,
                    expected = *expected_id,
                    assigned = ?assigned,
                    "Added token received an ID different from tokenizer_config.json"
                );
            }
        }
    }

    tracing::debug!("Tokenizer built successfully");
    Ok(Self {
        tokenizer: Some(tokenizer),
        special_tokens: SpecialTokenIds::default(),
    })
}
/// Builds a tokenizer from `vocab.json` and an optional `merges.txt`,
/// without a tokenizer config.
///
/// Delegates to [`Self::from_bpe_files_with_config`] with `config_path = None`.
///
/// # Errors
///
/// Propagates any error from [`Self::from_bpe_files_with_config`].
pub fn from_bpe_files<P: AsRef<Path>>(
    vocab_path: P,
    merges_path: Option<&Path>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    let vocab_file: &Path = vocab_path.as_ref();
    Self::from_bpe_files_with_config(vocab_file, merges_path, None)
}
pub fn from_file<P: AsRef<Path>>(
path: P,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let tokenizer = Tokenizer::from_file(path)?;
Ok(Self {
tokenizer: Some(tokenizer),
special_tokens: SpecialTokenIds::default(),
})
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let tokenizer = Tokenizer::from_bytes(bytes)?;
Ok(Self {
tokenizer: Some(tokenizer),
special_tokens: SpecialTokenIds::default(),
})
}
/// Returns `true` when a tokenizer has been loaded into this processor.
pub fn has_tokenizer(&self) -> bool {
    Option::is_some(&self.tokenizer)
}
/// Wraps `text` in a completed assistant turn followed by the opening of a
/// new assistant turn.
///
/// Produces `<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n`.
pub fn build_assistant_text(&self, text: &str) -> String {
    [
        "<|im_start|>assistant\n",
        text,
        "<|im_end|>\n<|im_start|>assistant\n",
    ]
    .concat()
}
/// Wraps `text` in a single completed assistant turn.
///
/// Produces `<|im_start|>assistant\n{text}<|im_end|>\n`.
pub fn build_ref_text(&self, text: &str) -> String {
    ["<|im_start|>assistant\n", text, "<|im_end|>\n"].concat()
}
/// Wraps `instruct` in a single completed user turn.
///
/// Produces `<|im_start|>user\n{instruct}<|im_end|>\n`.
pub fn build_instruct_text(&self, instruct: &str) -> String {
    ["<|im_start|>user\n", instruct, "<|im_end|>\n"].concat()
}
/// Normalizes `text` with `normalize_text` and encodes it to token IDs.
///
/// Encoding is done without asking the tokenizer to add special tokens.
/// When no tokenizer is loaded, a warning is logged and an empty vector is
/// returned.
///
/// # Panics
///
/// Panics if the tokenizer fails to encode the text. Use
/// [`Self::try_tokenize`] for a non-panicking variant.
pub fn tokenize(&self, text: &str) -> Vec<u32> {
    let Some(tokenizer) = self.tokenizer.as_ref() else {
        tracing::warn!("No tokenizer loaded, returning empty token list");
        return Vec::new();
    };

    let cleaned = normalize_text(text);
    match tokenizer.encode(cleaned.as_str(), false) {
        Ok(encoding) => encoding.get_ids().to_vec(),
        Err(e) => panic!(
            "Failed to encode text: {:?}\n\
            This may indicate that special tokens like <|im_start|> or <|im_end|> \
            are not properly defined in the tokenizer. \
            Ensure you're using a valid tokenizer.json from the Qwen model.\n\
            Text being tokenized (first 200 chars): {:?}",
            e,
            &text.chars().take(200).collect::<String>()
        ),
    }
}
/// Fallible counterpart of [`Self::tokenize`].
///
/// Normalizes `text` with `normalize_text` and encodes it without asking
/// the tokenizer to add special tokens.
///
/// # Errors
///
/// Returns `"No tokenizer loaded"` when the processor has no tokenizer, or
/// the tokenizer's own error when encoding fails.
pub fn try_tokenize(
    &self,
    text: &str,
) -> Result<Vec<u32>, Box<dyn std::error::Error + Send + Sync>> {
    let tokenizer = self.tokenizer.as_ref().ok_or("No tokenizer loaded")?;
    let cleaned = normalize_text(text);
    let encoding = tokenizer.encode(cleaned.as_str(), false)?;
    Ok(encoding.get_ids().to_vec())
}
pub fn tokenize_for_tts(&self, text: &str) -> Vec<u32> {
let formatted = self.build_assistant_text(text);
let tokens = self.tokenize(&formatted);
if tracing::enabled!(tracing::Level::DEBUG) {
tracing::debug!(
input_text = %formatted,
token_count = tokens.len(),
token_ids = ?tokens,
"tokenize_for_tts"
);
if let Some(ref tokenizer) = self.tokenizer {
let decoded_tokens: Vec<String> = tokens
.iter()
.map(|&id| {
tokenizer
.decode(&[id], true)
.unwrap_or_else(|_| format!("<{}>", id))
})
.collect();
tracing::debug!(decoded_tokens = ?decoded_tokens, "tokenize_for_tts decoded");
}
}
tokens
}
/// Formats `text` with [`Self::build_ref_text`] and tokenizes the result.
///
/// # Panics
///
/// Panics under the same conditions as [`Self::tokenize`].
pub fn tokenize_ref_text(&self, text: &str) -> Vec<u32> {
    self.tokenize(&self.build_ref_text(text))
}
/// Formats `instruct` with [`Self::build_instruct_text`] and tokenizes the result.
///
/// # Panics
///
/// Panics under the same conditions as [`Self::tokenize`].
pub fn tokenize_instruct(&self, instruct: &str) -> Vec<u32> {
    self.tokenize(&self.build_instruct_text(instruct))
}
/// Returns `token_ids` with `special_tokens.tts_bos` prepended.
///
/// Nothing is appended; only the leading token is added.
pub fn add_tts_tokens(&self, token_ids: Vec<u32>) -> Vec<u32> {
    std::iter::once(self.special_tokens.tts_bos)
        .chain(token_ids)
        .collect()
}
pub fn decode(&self, token_ids: &[u32]) -> Option<String> {
self.tokenizer
.as_ref()
.map(|tokenizer| tokenizer.decode(token_ids, true).unwrap_or_default())
}
/// Returns the tokenizer's vocabulary size including added tokens, or
/// `None` when no tokenizer is loaded.
pub fn vocab_size(&self) -> Option<usize> {
    let tokenizer = self.tokenizer.as_ref()?;
    Some(tokenizer.get_vocab_size(true))
}
/// Tokenizes each text with [`Self::tokenize`], preserving input order.
///
/// Without a loaded tokenizer every entry is an empty vector.
///
/// # Panics
///
/// Panics under the same conditions as [`Self::tokenize`].
pub fn batch_tokenize(&self, texts: &[&str]) -> Vec<Vec<u32>> {
    let mut encoded = Vec::with_capacity(texts.len());
    for text in texts {
        encoded.push(self.tokenize(text));
    }
    encoded
}
/// Fallible counterpart of [`Self::batch_tokenize`].
///
/// # Errors
///
/// Returns `"No tokenizer loaded"` when the processor has no tokenizer
/// (even for an empty `texts` slice), or the first error produced by
/// [`Self::try_tokenize`].
pub fn try_batch_tokenize(
    &self,
    texts: &[&str],
) -> Result<Vec<Vec<u32>>, Box<dyn std::error::Error + Send + Sync>> {
    // Checked up front so that an empty batch still reports the missing tokenizer.
    if self.tokenizer.is_none() {
        return Err("No tokenizer loaded".into());
    }
    let mut encoded = Vec::with_capacity(texts.len());
    for text in texts {
        encoded.push(self.try_tokenize(text)?);
    }
    Ok(encoded)
}
/// Tokenizes `texts` and pads them to a common length using
/// `special_tokens.tts_pad` as the padding token.
///
/// # Errors
///
/// Propagates any error from [`Self::batch_tokenize_padded_with_pad_token`].
pub fn batch_tokenize_padded(
    &self,
    texts: &[&str],
    padding_side: PaddingSide,
) -> Result<TokenizerOutput, Box<dyn std::error::Error + Send + Sync>> {
    let pad_token_id = self.special_tokens.tts_pad;
    self.batch_tokenize_padded_with_pad_token(texts, padding_side, pad_token_id)
}
/// Tokenizes `texts` and pads them to a common length with `pad_token_id`
/// on the requested side.
///
/// # Errors
///
/// Propagates any error from [`Self::try_batch_tokenize`].
pub fn batch_tokenize_padded_with_pad_token(
    &self,
    texts: &[&str],
    padding_side: PaddingSide,
    pad_token_id: u32,
) -> Result<TokenizerOutput, Box<dyn std::error::Error + Send + Sync>> {
    self.try_batch_tokenize(texts)
        .map(|encoded| Self::pad_sequences(&encoded, padding_side, pad_token_id))
}
/// Formats each text with [`Self::build_assistant_text`], then tokenizes and
/// pads the batch via [`Self::batch_tokenize_padded`].
///
/// # Errors
///
/// Propagates any error from [`Self::batch_tokenize_padded`].
pub fn batch_tokenize_for_tts(
    &self,
    texts: &[&str],
    padding_side: PaddingSide,
) -> Result<TokenizerOutput, Box<dyn std::error::Error + Send + Sync>> {
    let mut prompts: Vec<String> = Vec::with_capacity(texts.len());
    for text in texts {
        prompts.push(self.build_assistant_text(text));
    }
    let prompt_refs: Vec<&str> = prompts.iter().map(String::as_str).collect();
    self.batch_tokenize_padded(&prompt_refs, padding_side)
}
/// Pads every sequence to the length of the longest one.
///
/// * `sequences` — token ID sequences to pad.
/// * `padding_side` — whether padding goes before (`Left`) or after
///   (`Right`) the real tokens.
/// * `pad_token_id` — ID written into padding positions.
///
/// Returns padded IDs, an attention mask (`1` for real tokens, `0` for
/// padding) and the original lengths. An empty input yields three empty
/// vectors.
pub fn pad_sequences(
    sequences: &[Vec<u32>],
    padding_side: PaddingSide,
    pad_token_id: u32,
) -> TokenizerOutput {
    let lengths: Vec<usize> = sequences.iter().map(Vec::len).collect();
    // For an empty batch this is 0 and the loop below produces nothing.
    let max_len = lengths.iter().copied().max().unwrap_or(0);

    let (input_ids, attention_mask): (Vec<Vec<u32>>, Vec<Vec<u32>>) = sequences
        .iter()
        .map(|seq| {
            let real_len = seq.len();
            // Start from rows that are entirely padding, then overwrite the
            // window that holds the real tokens.
            let mut row_ids = vec![pad_token_id; max_len];
            let mut row_mask = vec![0u32; max_len];
            let start = match padding_side {
                PaddingSide::Left => max_len - real_len,
                PaddingSide::Right => 0,
            };
            row_ids[start..start + real_len].copy_from_slice(seq);
            row_mask[start..start + real_len].fill(1);
            (row_ids, row_mask)
        })
        .unzip();

    TokenizerOutput {
        input_ids,
        attention_mask,
        lengths,
    }
}
/// Decodes each ID sequence with [`Self::decode`], preserving input order.
///
/// Every entry is `None` when no tokenizer is loaded.
pub fn batch_decode(&self, batch_ids: &[&[u32]]) -> Vec<Option<String>> {
    let mut decoded = Vec::with_capacity(batch_ids.len());
    for ids in batch_ids {
        decoded.push(self.decode(ids));
    }
    decoded
}
/// Decodes each ID sequence, letting the caller choose whether special
/// tokens are skipped.
///
/// An entry is `None` when no tokenizer is loaded or when decoding that
/// sequence fails.
pub fn batch_decode_with_options(
    &self,
    batch_ids: &[&[u32]],
    skip_special_tokens: bool,
) -> Vec<Option<String>> {
    let Some(tokenizer) = self.tokenizer.as_ref() else {
        return vec![None; batch_ids.len()];
    };
    let mut decoded = Vec::with_capacity(batch_ids.len());
    for ids in batch_ids {
        decoded.push(tokenizer.decode(ids, skip_special_tokens).ok());
    }
    decoded
}
}
impl Default for TextProcessor {
fn default() -> Self {
Self::new()
}
}
// Unit tests for the parts of this module that work without a loaded
// tokenizer: text normalization, prompt templates, special-token defaults,
// padding, and the no-tokenizer fallbacks.
#[cfg(test)]
mod tests {
use super::*;
// Each substitution performed by `normalize_text`, plus a combined case and
// a plain-ASCII passthrough.
#[test]
fn test_normalize_text() {
assert_eq!(normalize_text("\u{2018}hello\u{2019}"), "'hello'");
assert_eq!(normalize_text("\u{201C}world\u{201D}"), "\"world\"");
assert_eq!(normalize_text("one\u{2013}two"), "one-two"); assert_eq!(normalize_text("one\u{2014}two"), "one--two");
assert_eq!(normalize_text("wait\u{2026}"), "wait...");
assert_eq!(normalize_text("hello\u{00A0}world"), "hello world"); assert_eq!(normalize_text("hello\u{202F}world"), "hello world"); assert_eq!(normalize_text("hello\u{3000}world"), "hello world");
assert_eq!(normalize_text("a\r\nb"), "a\nb"); assert_eq!(normalize_text("a\rb"), "a\nb");
assert_eq!(normalize_text("hel\u{200B}lo"), "hello"); assert_eq!(normalize_text("hel\u{FEFF}lo"), "hello"); assert_eq!(normalize_text("hel\u{200C}lo"), "hello"); assert_eq!(normalize_text("hel\u{200D}lo"), "hello");
assert_eq!(
normalize_text("He said, \u{201C}It\u{2019}s\u{2014}well\u{2026}\u{201D}"),
"He said, \"It's--well...\""
);
assert_eq!(normalize_text("Hello, world!"), "Hello, world!");
}
// The three prompt builders produce the exact expected template strings.
#[test]
fn test_chat_templates() {
let processor = TextProcessor::new();
let assistant = processor.build_assistant_text("Hello world");
assert_eq!(
assistant,
"<|im_start|>assistant\nHello world<|im_end|>\n<|im_start|>assistant\n"
);
let ref_text = processor.build_ref_text("Reference");
assert_eq!(ref_text, "<|im_start|>assistant\nReference<|im_end|>\n");
let instruct = processor.build_instruct_text("Speak slowly");
assert_eq!(instruct, "<|im_start|>user\nSpeak slowly<|im_end|>\n");
}
// Spot-checks the default special token IDs.
#[test]
fn test_special_token_defaults() {
let tokens = SpecialTokenIds::default();
assert_eq!(tokens.im_start, 151644);
assert_eq!(tokens.im_end, 151645);
assert_eq!(tokens.tts_bos, 151672);
}
// Without a tokenizer, `tokenize` returns an empty vector instead of failing.
#[test]
fn test_tokenize_without_tokenizer() {
let processor = TextProcessor::new();
let tokens = processor.tokenize("Hello");
assert!(tokens.is_empty());
}
// `add_tts_tokens` prepends the default `tts_bos` ID.
#[test]
fn test_add_tts_tokens() {
let processor = TextProcessor::new();
let tokens = vec![1, 2, 3];
let with_tts = processor.add_tts_tokens(tokens);
assert_eq!(with_tts, vec![151672, 1, 2, 3]);
}
// Left padding is the default side.
#[test]
fn test_padding_side_default() {
let side = PaddingSide::default();
assert_eq!(side, PaddingSide::Left);
}
// Left padding: pad tokens and zero mask entries come first.
#[test]
fn test_pad_sequences_left() {
let sequences = vec![vec![1, 2], vec![3, 4, 5, 6], vec![7]];
let output = TextProcessor::pad_sequences(&sequences, PaddingSide::Left, 0);
assert_eq!(output.input_ids.len(), 3);
assert_eq!(output.input_ids[0], vec![0, 0, 1, 2]); assert_eq!(output.input_ids[1], vec![3, 4, 5, 6]); assert_eq!(output.input_ids[2], vec![0, 0, 0, 7]);
assert_eq!(output.attention_mask[0], vec![0, 0, 1, 1]);
assert_eq!(output.attention_mask[1], vec![1, 1, 1, 1]);
assert_eq!(output.attention_mask[2], vec![0, 0, 0, 1]);
assert_eq!(output.lengths, vec![2, 4, 1]);
}
// Right padding: pad tokens and zero mask entries come last.
#[test]
fn test_pad_sequences_right() {
let sequences = vec![vec![1, 2], vec![3, 4, 5]];
let output = TextProcessor::pad_sequences(&sequences, PaddingSide::Right, 99);
assert_eq!(output.input_ids[0], vec![1, 2, 99]); assert_eq!(output.input_ids[1], vec![3, 4, 5]);
assert_eq!(output.attention_mask[0], vec![1, 1, 0]);
assert_eq!(output.attention_mask[1], vec![1, 1, 1]);
}
// An empty batch yields empty outputs.
#[test]
fn test_pad_sequences_empty() {
let sequences: Vec<Vec<u32>> = vec![];
let output = TextProcessor::pad_sequences(&sequences, PaddingSide::Left, 0);
assert!(output.input_ids.is_empty());
assert!(output.attention_mask.is_empty());
assert!(output.lengths.is_empty());
}
// A single sequence needs no padding.
#[test]
fn test_pad_sequences_single() {
let sequences = vec![vec![1, 2, 3]];
let output = TextProcessor::pad_sequences(&sequences, PaddingSide::Left, 0);
assert_eq!(output.input_ids.len(), 1);
assert_eq!(output.input_ids[0], vec![1, 2, 3]); assert_eq!(output.attention_mask[0], vec![1, 1, 1]);
assert_eq!(output.lengths, vec![3]);
}
// Without a tokenizer, `batch_tokenize` returns one empty vector per input.
#[test]
fn test_batch_tokenize_without_tokenizer() {
let processor = TextProcessor::new();
let results = processor.batch_tokenize(&["Hello", "World"]);
assert_eq!(results.len(), 2);
assert!(results[0].is_empty());
assert!(results[1].is_empty());
}
// Without a tokenizer, the fallible batch variant reports an error.
#[test]
fn test_try_batch_tokenize_without_tokenizer() {
let processor = TextProcessor::new();
let result = processor.try_batch_tokenize(&["Hello", "World"]);
assert!(result.is_err());
}
// Without a tokenizer, `batch_decode` returns `None` for every input.
#[test]
fn test_batch_decode_without_tokenizer() {
let processor = TextProcessor::new();
let ids: &[&[u32]] = &[&[1, 2, 3], &[4, 5]];
let results = processor.batch_decode(ids);
assert_eq!(results.len(), 2);
assert!(results[0].is_none());
assert!(results[1].is_none());
}
// `TokenizerOutput` can be constructed directly and exposes its three fields.
#[test]
fn test_tokenizer_output_structure() {
let output = TokenizerOutput {
input_ids: vec![vec![1, 2], vec![3, 4]],
attention_mask: vec![vec![1, 1], vec![1, 1]],
lengths: vec![2, 2],
};
assert_eq!(output.input_ids.len(), 2);
assert_eq!(output.attention_mask.len(), 2);
assert_eq!(output.lengths.len(), 2);
}
}