use crate::debug::debug_tokenizer;
use crate::token_list::TOKENS;
use std::sync::OnceLock;
/// One node of the byte-level token trie.
struct TrieNode {
// Token id recorded here when a vocabulary token ends exactly at this node.
token_id: Option<usize>,
// One child slot per possible byte value (0..=255); `None` = no edge.
children: [Option<Box<TrieNode>>; 256],
}
impl TrieNode {
#[inline]
fn new() -> Self {
const INIT: Option<Box<TrieNode>> = None;
Self {
token_id: None,
children: [INIT; 256],
}
}
}
/// Byte trie over the `TOKENS` vocabulary, supporting longest-prefix lookup.
struct TokenTrie {
// Root of the trie; paths from here spell out token byte sequences.
root: TrieNode,
}
impl TokenTrie {
fn new() -> Self {
debug_tokenizer!("Building token trie with {} tokens...", TOKENS.len());
let mut trie = Self {
root: TrieNode::new(),
};
for (idx, token) in TOKENS.iter().enumerate() {
trie.insert(token, idx);
}
debug_tokenizer!("Token trie built successfully");
trie
}
fn insert(&mut self, token: &[u8], token_id: usize) {
let mut node = &mut self.root;
for &byte in token {
node = node.children[byte as usize].get_or_insert_with(|| Box::new(TrieNode::new()));
}
node.token_id = Some(token_id);
}
#[inline]
fn find_longest_match(&self, text: &[u8], start: usize) -> Option<(usize, usize)> {
let mut node = &self.root;
let mut last_match = None;
let mut pos = start;
while pos < text.len() {
if let Some(next_node) = &node.children[text[pos] as usize] {
node = next_node;
pos += 1;
if let Some(token_id) = node.token_id {
last_match = Some((token_id, pos - start));
}
} else {
break;
}
}
last_match
}
}
/// Greedy longest-match tokenizer over the static `TOKENS` vocabulary.
pub struct SimpleTokenizer {
// Prebuilt byte trie used for longest-prefix token lookup during encoding.
trie: TokenTrie,
}
impl SimpleTokenizer {
    /// Returns the process-wide tokenizer, building the trie on first use.
    /// Thread-safe via `OnceLock`; subsequent calls are a cheap load.
    #[inline]
    pub fn global() -> &'static Self {
        static INSTANCE: OnceLock<SimpleTokenizer> = OnceLock::new();
        INSTANCE.get_or_init(|| Self {
            trie: TokenTrie::new(),
        })
    }

    /// Renders `byte` for diagnostics: the character itself when it is
    /// printable ASCII (or a space), `'?'` otherwise.
    ///
    /// Extracted because the same expression was previously duplicated in the
    /// debug log and the error message of `encode`.
    fn printable(byte: u8) -> char {
        if byte.is_ascii_graphic() || byte == b' ' {
            byte as char
        } else {
            '?'
        }
    }

    /// Greedily encodes `text` into vocabulary indices, always taking the
    /// longest token that matches at the current position.
    ///
    /// # Errors
    /// Returns a descriptive error when no vocabulary token matches at some
    /// position; the message includes the offending byte and nearby context.
    pub fn encode(&self, text: &[u8]) -> Result<Vec<usize>, String> {
        debug_tokenizer!("Encoding {} bytes...", text.len());
        // Heuristic pre-allocation: assume tokens average ~2 bytes.
        let mut indices = Vec::with_capacity(text.len() / 2);
        let mut pos = 0;
        while pos < text.len() {
            match self.trie.find_longest_match(text, pos) {
                Some((token_id, len)) => {
                    debug_tokenizer!(" pos {}: matched token {} (len {})", pos, token_id, len);
                    indices.push(token_id);
                    pos += len;
                }
                None => {
                    // Show up to 10 bytes on each side of the failure point.
                    let context_start = pos.saturating_sub(10);
                    let context_end = (pos + 10).min(text.len());
                    let context = &text[context_start..context_end];
                    debug_tokenizer!(
                        " pos {}: FAILED to match byte {:02x} ('{}')",
                        pos,
                        text[pos],
                        Self::printable(text[pos])
                    );
                    debug_tokenizer!(" Context: {:?}", String::from_utf8_lossy(context));
                    debug_tokenizer!(" Hex context: {:02x?}", context);
                    debug_tokenizer!(" Position {} out of {} total bytes", pos, text.len());
                    return Err(format!(
                        "Cannot tokenize at position {}: byte {:02x} ('{}')\nContext: {:?}",
                        pos,
                        text[pos],
                        Self::printable(text[pos]),
                        String::from_utf8_lossy(context)
                    ));
                }
            }
        }
        debug_tokenizer!("Encoded to {} tokens", indices.len());
        Ok(indices)
    }

    /// Decodes token `indices` back into the concatenated token bytes.
    ///
    /// # Errors
    /// Returns an error naming the first index outside the vocabulary.
    pub fn decode(&self, indices: &[usize]) -> Result<Vec<u8>, String> {
        debug_tokenizer!("Decoding {} tokens...", indices.len());
        // Pre-size the output buffer; invalid indices are simply skipped here
        // because the loop below reports them as errors anyway.
        let total_size: usize = indices
            .iter()
            .filter_map(|&idx| TOKENS.get(idx).map(|t| t.len()))
            .sum();
        let mut result = Vec::with_capacity(total_size);
        for &idx in indices {
            if let Some(&token) = TOKENS.get(idx) {
                debug_tokenizer!(" token {}: {} bytes", idx, token.len());
                result.extend_from_slice(token);
            } else {
                debug_tokenizer!(" token {}: INVALID INDEX", idx);
                return Err(format!("Invalid token index: {}", idx));
            }
        }
        debug_tokenizer!("Decoded to {} bytes", result.len());
        Ok(result)
    }

    /// Decodes token `indices` and validates the result as UTF-8 text.
    ///
    /// # Errors
    /// Propagates invalid-index errors from [`Self::decode`] and reports
    /// invalid UTF-8 in the decoded byte stream.
    pub fn decode_to_string(&self, indices: &[usize]) -> Result<String, String> {
        debug_tokenizer!("Decoding {} tokens to string...", indices.len());
        let bytes = self.decode(indices)?;
        match String::from_utf8(bytes) {
            Ok(text) => {
                debug_tokenizer!("Decoded to string: {} chars", text.len());
                Ok(text)
            }
            Err(e) => {
                debug_tokenizer!("Failed to decode to UTF-8: {}", e);
                Err(format!("Invalid UTF-8 in decoded output: {}", e))
            }
        }
    }
}
/// Convenience wrapper: encodes `text` via the global tokenizer instance.
#[inline]
pub fn encode(text: &[u8]) -> Result<Vec<usize>, String> {
    let tokenizer = SimpleTokenizer::global();
    tokenizer.encode(text)
}
/// Convenience wrapper: decodes `indices` via the global tokenizer instance.
#[inline]
pub fn decode(indices: &[usize]) -> Result<Vec<u8>, String> {
    let tokenizer = SimpleTokenizer::global();
    tokenizer.decode(indices)
}
/// Convenience wrapper: decodes `indices` to a `String` via the global
/// tokenizer instance.
#[inline]
pub fn decode_to_string(indices: &[usize]) -> Result<String, String> {
    let tokenizer = SimpleTokenizer::global();
    tokenizer.decode_to_string(indices)
}
#[cfg(test)]
mod tests {
use super::*;
// `global()` must hand back the same singleton on every call (pointer equality).
#[test]
fn test_singleton_access() {
let t1 = SimpleTokenizer::global();
let t2 = SimpleTokenizer::global();
assert!(std::ptr::eq(t1, t2), "Should be the same instance");
}
// Empty inputs must round-trip to empty outputs for all three entry points.
#[test]
fn test_empty_input() {
assert_eq!(encode(b"").unwrap(), Vec::<usize>::new());
assert_eq!(decode(&[]).unwrap(), Vec::<u8>::new());
assert_eq!(decode_to_string(&[]).unwrap(), "");
}
// encode -> decode round-trip; tolerant of vocabularies that cannot cover
// the sample text (only asserts when encoding succeeds).
#[test]
fn test_encode_decode_roundtrip() {
let text = b"Hello world";
match encode(text) {
Ok(indices) => {
println!("Encoded to {} tokens: {:?}", indices.len(), indices);
let decoded = decode(&indices[..]).expect("Decode failed");
println!("Decoded: {:?}", String::from_utf8_lossy(&decoded[..]));
assert!(!decoded.is_empty(), "Decoded output should not be empty");
}
Err(e) => {
println!("Cannot encode (vocab might not cover this text): {}", e);
}
}
}
// Smoke test: decoding a few low vocabulary indices to a string.
// Either outcome is acceptable; the point is that it does not panic.
#[test]
fn test_decode_to_string_basic() {
let indices = [0, 1, 2];
match decode_to_string(&indices) {
Ok(text) => {
println!("Decoded to string: '{}'", text);
assert!(!text.is_empty() || indices.is_empty());
}
Err(e) => println!("Error: {}", e),
}
}
// The free functions must behave exactly like the corresponding methods.
#[test]
fn test_method_and_function_equivalence() {
let text = b"test";
let tokenizer = SimpleTokenizer::global();
if let (Ok(e1), Ok(e2)) = (tokenizer.encode(text), encode(text)) {
assert_eq!(e1, e2, "Method and function should give same result");
}
}
// An out-of-range index must produce the "Invalid token index" error.
#[test]
fn test_decode_invalid_token_index() {
let invalid_indices = [usize::MAX];
let result = decode(&invalid_indices);
assert!(result.is_err(), "Should error on invalid token index");
assert!(result.unwrap_err().contains("Invalid token index"));
}
// A single bad index anywhere in the sequence must fail the whole decode.
#[test]
fn test_decode_partially_invalid() {
let vocab_size = TOKENS.len();
let indices = [0, vocab_size + 1000];
let result = decode(&indices);
assert!(result.is_err(), "Should error on any invalid token");
}
// Bytes likely absent from the vocabulary should yield a descriptive error
// containing the failure reason and the position.
#[test]
fn test_encode_untokenizable_input() {
let weird_bytes = b"\xFF\xFE\xFD\xFC";
let result = encode(weird_bytes);
if let Err(error) = result {
println!("Expected error: {}", error);
assert!(
error.contains("Cannot tokenize"),
"Error should explain tokenization failure"
);
assert!(error.contains("position"), "Error should include position");
}
}
// Any 1-byte token among the first few vocabulary entries must decode to
// exactly one byte.
#[test]
fn test_single_byte_tokens() {
for (i, _item) in TOKENS.iter().enumerate().take(TOKENS.len().min(10)) {
if TOKENS[i].len() == 1 {
let result = decode(&[i]);
assert!(result.is_ok(), "Single-byte token should decode");
assert_eq!(result.unwrap().len(), 1);
}
}
}
// Decoding a long sequence of small, valid indices must succeed.
#[test]
fn test_long_token_sequence() {
let long_sequence: Vec<usize> = (0..100).map(|i| i % TOKENS.len().min(10)).collect();
let result = decode(&long_sequence[..]);
assert!(result.is_ok(), "Should handle long sequences");
assert!(!result.unwrap().is_empty());
}
// Repeated indices are legal and must decode without error.
#[test]
fn test_repeated_tokens() {
let indices = [0, 0, 0, 1, 1, 2];
let result = decode(&indices);
assert!(result.is_ok(), "Should handle repeated tokens");
}
// Sanity check: input size should not change success/failure, only output
// size (no actual timing is measured here).
#[test]
fn test_encode_performance_scales_linearly() {
let small_text = b"a".repeat(100);
let large_text = b"a".repeat(1000);
let small_result = encode(&small_text[..]);
let large_result = encode(&large_text[..]);
assert_eq!(
small_result.is_ok(),
large_result.is_ok(),
"Both sizes should have same success/failure"
);
}
// Decoded byte count must equal the sum of the referenced token lengths.
#[test]
fn test_decode_allocates_correct_size() {
let indices = [0, 1, 2, 3, 4];
let result = decode(&indices);
if let Ok(bytes) = result {
let expected_len: usize = indices
.iter()
.filter_map(|&i| TOKENS.get(i).map(|t| t.len()))
.sum();
assert_eq!(bytes.len(), expected_len);
}
}
// Concurrent first-use of the singleton must not panic or deadlock.
#[test]
fn test_concurrent_access() {
use std::thread;
let handles: Vec<_> = (0..4)
.map(|_| {
thread::spawn(|| {
let tokenizer = SimpleTokenizer::global();
let _ = tokenizer.encode(b"test");
})
})
.collect();
for handle in handles {
handle.join().expect("Thread should not panic");
}
}
// Encoding must be deterministic: the same input yields the same tokens.
#[test]
fn test_longest_match_preference() {
let text = b"aaa";
let result1 = encode(text);
let result2 = encode(text);
if let (Ok(tokens1), Ok(tokens2)) = (result1, result2) {
assert_eq!(tokens1, tokens2, "Same input should produce same tokens");
}
}
// Input with no matching tokens should be rejected rather than partially
// encoded (observational test; no assertion on the Ok case).
#[test]
fn test_no_partial_matches() {
let invalid = b"\xFF\xFF";
let result = encode(invalid);
if result.is_err() {
println!("Correctly rejected invalid input");
}
}
}