use std::collections::HashMap;
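
/// Ids of the special tokens (begin/end of sequence, padding, unknown).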
pub struct SpecialTokens {
pub bos_id: u32,
pub eos_id: u32,
pub pad_id: u32,
pub unk_id: u32,
}
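
// Defaults follow the common PAD=0, BOS=1, EOS=2, UNK=3 layout, matching
// the vocabulary built in the tests below.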
impl Default for SpecialTokens {
fn default() -> Self {
Self {
bos_id: 1,
eos_id: 2,
pad_id: 0,
unk_id: 3,
}
}
}
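
/// A byte-level BPE tokenizer: input bytes become `<XX>` hex tokens (or
/// single-character tokens when the vocab has them), and merge rules then
/// collapse adjacent pairs into larger units.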
pub struct BpeTokenizer {
vocab: Vec<String>,
token_to_id: HashMap<String, u32>,
merges: Vec<(String, String)>,
special_tokens: SpecialTokens,
}
impl BpeTokenizer {
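    /// Builds a tokenizer from an ordered vocabulary (index = token id),
    /// a list of BPE merge rules, and the special-token ids.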
pub fn from_vocab(
tokens: Vec<String>,
merges: Vec<(String, String)>,
special: SpecialTokens,
) -> Self {
let mut token_to_id = HashMap::with_capacity(tokens.len());
for (id, tok) in tokens.iter().enumerate() {
token_to_id.insert(tok.clone(), id as u32);
}
Self {
vocab: tokens,
token_to_id,
merges,
special_tokens: special,
}
}
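
    /// Encodes `text` into token ids, prepending BOS (EOS is not appended).
    /// Symbols missing from the vocab map to UNK; empty input yields just `[BOS]`.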
pub fn encode(&self, text: &str) -> Vec<u32> {
if text.is_empty() {
return vec![self.special_tokens.bos_id];
}
let bytes = text.as_bytes();
let mut symbols: Vec<String> = bytes.iter().map(|&b| self.byte_to_token(b)).collect();
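        // Apply each merge rule in list order, scanning left to right and
        // collapsing every adjacent (left, right) pair into its merged token.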
for (left, right) in &self.merges {
let merged = format!("{}{}", left, right);
if !self.token_to_id.contains_key(&merged) {
continue;
}
let mut i = 0;
while i + 1 < symbols.len() {
if symbols[i] == *left && symbols[i + 1] == *right {
symbols[i] = merged.clone();
symbols.remove(i + 1);
} else {
i += 1;
}
}
}
let mut ids = Vec::with_capacity(symbols.len() + 1);
ids.push(self.special_tokens.bos_id);
for sym in &symbols {
let id = self
.token_to_id
.get(sym)
.copied()
.unwrap_or(self.special_tokens.unk_id);
ids.push(id);
}
ids
}
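
    /// Decodes ids back into a string, skipping BOS/EOS/PAD and ignoring
    /// out-of-range ids; UNK decodes to its literal vocab entry. Invalid
    /// UTF-8 is replaced lossily.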
pub fn decode(&self, ids: &[u32]) -> String {
let mut bytes = Vec::new();
for &id in ids {
if id == self.special_tokens.bos_id
|| id == self.special_tokens.eos_id
|| id == self.special_tokens.pad_id
{
continue;
}
if let Some(token_str) = self.vocab.get(id as usize) {
let token_bytes = self.token_to_bytes(token_str);
bytes.extend_from_slice(&token_bytes);
}
}
        String::from_utf8_lossy(&bytes).into_owned()
}
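
    /// Total number of entries in the vocabulary.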
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
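
    /// Maps a raw byte to its vocab symbol: the `<XX>` hex token when present,
    /// then a single-character token, falling back to the hex form (which
    /// resolves to UNK at lookup time). The single-character fallback is only
    /// byte-exact for ASCII.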
fn byte_to_token(&self, byte: u8) -> String {
let hex_token = format!("<{:02X}>", byte);
if self.token_to_id.contains_key(&hex_token) {
return hex_token;
}
let char_token = String::from(byte as char);
if self.token_to_id.contains_key(&char_token) {
return char_token;
}
hex_token
}
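
    /// Expands a token string back to raw bytes, decoding `<XX>` hex escapes
    /// and copying all other characters through as UTF-8.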
fn token_to_bytes(&self, token: &str) -> Vec<u8> {
let mut result = Vec::new();
let mut chars = token.chars().peekable();
while let Some(ch) = chars.next() {
if ch == '<' {
let mut hex = String::new();
let mut found_close = false;
for c in chars.by_ref() {
if c == '>' {
found_close = true;
break;
}
hex.push(c);
}
if found_close && hex.len() == 2 {
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
result.push(byte);
continue;
}
}
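                // Not a valid two-digit hex escape: emit the consumed
                // characters verbatim.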
result.push(b'<');
result.extend_from_slice(hex.as_bytes());
if found_close {
result.push(b'>');
}
} else {
let mut buf = [0u8; 4];
let encoded = ch.encode_utf8(&mut buf);
result.extend_from_slice(encoded.as_bytes());
}
}
result
}
}
#[cfg(test)]
mod tests {
use super::*;
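
    // Builds a vocab of the 4 special tokens followed by the 256 byte tokens
    // <00>..<FF>, so byte `b` gets id `4 + b`.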
fn test_tokenizer(merges: Vec<(String, String)>, extra_tokens: Vec<String>) -> BpeTokenizer {
        let mut vocab = vec![
            "<PAD>".to_string(),
            "<BOS>".to_string(),
            "<EOS>".to_string(),
            "<UNK>".to_string(),
        ];
for b in 0..=255u8 {
vocab.push(format!("<{:02X}>", b));
}
for tok in extra_tokens {
vocab.push(tok);
}
BpeTokenizer::from_vocab(vocab, merges, SpecialTokens::default())
}
#[test]
fn test_roundtrip_ascii() {
let tok = test_tokenizer(vec![], vec![]);
let text = "Hello, world!";
let ids = tok.encode(text);
let decoded = tok.decode(&ids);
assert_eq!(decoded, text, "ASCII roundtrip failed");
}
#[test]
fn test_roundtrip_utf8() {
let tok = test_tokenizer(vec![], vec![]);
let text = "cafe\u{0301}"; let ids = tok.encode(text);
let decoded = tok.decode(&ids);
assert_eq!(decoded, text, "UTF-8 roundtrip failed");
}
#[test]
fn test_bos_prepended() {
let tok = test_tokenizer(vec![], vec![]);
let ids = tok.encode("A");
assert_eq!(ids[0], 1, "First token should be BOS (id=1)");
assert!(ids.len() >= 2, "Should have at least BOS + one token");
}
#[test]
fn test_eos_handling() {
let tok = test_tokenizer(vec![], vec![]);
        // BOS, 'H', 'i', EOS.
        let ids = vec![1, 4 + b'H' as u32, 4 + b'i' as u32, 2];
        let decoded = tok.decode(&ids);
assert_eq!(decoded, "Hi", "EOS should be skipped in decode");
}
#[test]
fn test_unknown_token() {
let tok = test_tokenizer(vec![], vec![]);
        // An id far beyond the vocabulary bounds.
        let ids = vec![99999];
        let decoded = tok.decode(&ids);
assert_eq!(decoded, "", "Unknown ID should produce empty output");
}
#[test]
fn test_empty_string() {
let tok = test_tokenizer(vec![], vec![]);
let ids = tok.encode("");
assert_eq!(ids, vec![1], "Empty string should encode to just BOS");
let decoded = tok.decode(&ids);
assert_eq!(decoded, "", "Decoding just BOS should give empty string");
}
#[test]
fn test_single_char() {
let tok = test_tokenizer(vec![], vec![]);
let ids = tok.encode("A");
assert_eq!(ids.len(), 2, "Single char should give BOS + 1 token");
assert_eq!(ids[0], 1, "First should be BOS");
let decoded = tok.decode(&ids);
assert_eq!(decoded, "A");
}
#[test]
fn test_bpe_merge_application() {
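        // 0x48 = 'H', 0x65 = 'e'.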
let merged_token = "<48><65>".to_string();
let merges = vec![("<48>".to_string(), "<65>".to_string())];
let tok = test_tokenizer(merges, vec![merged_token.clone()]);
let ids = tok.encode("He");
assert_eq!(
ids.len(),
2,
"Merge should reduce 'He' to BOS + 1 merged token"
);
}
#[test]
fn test_bpe_merge_multiple_occurrences() {
let merged_token = "<61><62>".to_string(); let merges = vec![("<61>".to_string(), "<62>".to_string())];
let tok = test_tokenizer(merges, vec![merged_token]);
let ids = tok.encode("ababab");
assert_eq!(ids.len(), 4, "Should merge all 'ab' pairs");
}
#[test]
fn test_vocab_size() {
let tok = test_tokenizer(vec![], vec![]);
assert_eq!(
tok.vocab_size(),
4 + 256,
"Should have 4 special + 256 byte tokens"
);
}
#[test]
fn test_decode_skips_pad() {
let tok = test_tokenizer(vec![], vec![]);
        // PAD, BOS, 'X', PAD, PAD.
        let ids = vec![0, 1, 4 + b'X' as u32, 0, 0];
        let decoded = tok.decode(&ids);
assert_eq!(decoded, "X", "PAD and BOS should be skipped");
}
}