use std::collections::HashSet;
use tracing::debug;
use crate::{
bpe::{bpe_encode, byte_fallback_id, pretokenize, BpeMerges},
error::{TokenizerError, TokenizerResult},
vocab::Vocabulary,
};
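
/// Runtime options for the tokenizer: special-token ids, whether BOS/EOS are
/// added on encode, optional truncation, and byte-level decoding.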
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TokenizerConfig {
pub add_bos: bool,
pub add_eos: bool,
pub bos_token_id: u32,
pub eos_token_id: u32,
pub unk_token_id: u32,
pub pad_token_id: u32,
pub max_length: Option<usize>,
pub byte_level_decode: bool,
}
impl Default for TokenizerConfig {
fn default() -> Self {
Self {
add_bos: false,
add_eos: false,
bos_token_id: 1,
eos_token_id: 2,
unk_token_id: 0,
pad_token_id: 3,
max_length: None,
byte_level_decode: false,
}
}
}
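
/// A byte-pair-encoding tokenizer built from a vocabulary, a table of BPE
/// merges, and a `TokenizerConfig`.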
pub struct OxiTokenizer {
vocab: Vocabulary,
merges: BpeMerges,
config: TokenizerConfig,
special_ids: HashSet<u32>,
}
impl OxiTokenizer {
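    /// Creates a tokenizer from a vocabulary, a merge table, and a configuration.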
pub fn new(vocab: Vocabulary, merges: BpeMerges, config: TokenizerConfig) -> Self {
let special_ids = build_special_ids(&config);
Self {
vocab,
merges,
config,
special_ids,
}
}
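
    /// Encodes `text` into token ids. Adds BOS/EOS when configured, falls back
    /// to per-byte tokens (or the unknown token) for words the BPE step cannot
    /// encode, and truncates to `max_length` when set.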
pub fn encode(&self, text: &str) -> TokenizerResult<Vec<u32>> {
debug!(text_len = text.len(), "encoding text");
let mut ids: Vec<u32> = Vec::new();
if self.config.add_bos {
ids.push(self.config.bos_token_id);
}
let words = pretokenize(text);
for word in &words {
let word_ids = bpe_encode(word, &self.vocab, &self.merges);
if word_ids.is_empty() {
for byte in word.as_bytes() {
let fallback = byte_fallback_id(*byte);
let fallback_id = self.vocab.get_id(&fallback);
ids.push(fallback_id.unwrap_or(self.config.unk_token_id));
}
} else {
ids.extend_from_slice(&word_ids);
}
}
if self.config.add_eos {
ids.push(self.config.eos_token_id);
}
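        // Truncation happens after BOS/EOS insertion, so a trailing EOS may be cut off.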
if let Some(max) = self.config.max_length {
ids.truncate(max);
}
Ok(ids)
}
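
    /// Encodes each text in `texts`, returning the ids in the same order and
    /// failing on the first error.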
pub fn encode_batch(&self, texts: &[&str]) -> TokenizerResult<Vec<Vec<u32>>> {
texts.iter().map(|t| self.encode(t)).collect()
}
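
    /// Decodes token ids into a `String`, failing if the bytes are not valid UTF-8.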
pub fn decode(&self, ids: &[u32]) -> TokenizerResult<String> {
let bytes = self.decode_to_bytes(ids);
String::from_utf8(bytes).map_err(|e| TokenizerError::DecodeFailed(e.to_string()))
}
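
    /// Decodes token ids into raw bytes without UTF-8 validation.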
pub(crate) fn decode_to_bytes(&self, ids: &[u32]) -> Vec<u8> {
let mut bytes: Vec<u8> = Vec::with_capacity(ids.len() * 2);
for &id in ids {
self.decode_id_into(id, &mut bytes);
}
bytes
}
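
    /// Appends the bytes for a single token id. Special tokens are skipped and
    /// unknown ids are replaced with U+FFFD.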
pub(crate) fn decode_id_into(&self, id: u32, bytes: &mut Vec<u8>) {
if self.special_ids.contains(&id) {
return;
}
let token = match self.vocab.get_token(id) {
Some(t) => t,
None => {
bytes.extend_from_slice("\u{FFFD}".as_bytes());
return;
}
};
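        // Byte-fallback tokens of the form "<0xNN>" decode directly to the raw byte.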
if let Some(byte) = parse_byte_fallback(token) {
bytes.push(byte);
return;
}
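        // Byte-level vocabularies store raw bytes as remapped printable characters; map them back.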
if self.config.byte_level_decode {
for ch in token.chars() {
if let Some(b) = crate::hf_format::unicode_to_byte(ch) {
bytes.push(b);
} else {
let mut buf = [0u8; 4];
let s = ch.encode_utf8(&mut buf);
bytes.extend_from_slice(s.as_bytes());
}
}
} else {
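            // U+0120 ('Ġ') marks a word-initial space in GPT-2-style vocabularies;
            // emit a real space except at the very start of the output.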
let stripped = token.trim_start_matches('\u{0120}');
if token.starts_with('\u{0120}') && !bytes.is_empty() {
bytes.push(b' ');
}
bytes.extend_from_slice(stripped.as_bytes());
}
}
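
    /// Returns the raw token string for `id`, or an error if the id is unknown.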
pub fn decode_token(&self, id: u32) -> TokenizerResult<String> {
self.vocab
.get_token(id)
.map(|s| s.to_owned())
.ok_or_else(|| TokenizerError::DecodeFailed(format!("unknown token id {id}")))
}
pub fn vocab_size(&self) -> usize {
self.vocab.size()
}
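
    /// Builds a tokenizer from a vocabulary JSON object and a merges JSON array
    /// of `[left, right]` pairs; every merged pair must already be in the vocabulary.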
pub fn from_json(
vocab_json: &str,
merges_json: &str,
config: TokenizerConfig,
) -> TokenizerResult<Self> {
let vocab = Vocabulary::from_json(vocab_json)?;
let raw_merges: Vec<(String, String)> = serde_json::from_str(merges_json)
.map_err(|e| TokenizerError::InvalidJson(e.to_string()))?;
let mut merges = BpeMerges::new();
for (a, b) in &raw_merges {
let merged = format!("{a}{b}");
let result_id = vocab.get_id(&merged).ok_or_else(|| {
TokenizerError::InvalidVocab(format!("merged token {merged:?} not in vocabulary"))
})?;
merges.add_merge(a, b, result_id);
}
Ok(Self::new(vocab, merges, config))
}
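
    /// Loads a Hugging Face `tokenizer.json` file from disk.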
pub fn from_json_file(path: impl AsRef<std::path::Path>) -> TokenizerResult<Self> {
let json = std::fs::read_to_string(path)?;
Self::from_hf_tokenizer_json(&json)
}
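
    /// Builds a tokenizer from the contents of a Hugging Face `tokenizer.json`.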
pub fn from_hf_tokenizer_json(json: &str) -> TokenizerResult<Self> {
let parsed = crate::hf_format::HfTokenizerJson::parse(json)?;
parsed.into_tokenizer()
}
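
    /// Returns a streaming decoder borrowing this tokenizer for incremental decoding.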
pub fn streaming_decoder(&self) -> crate::streaming::StreamingDecoder<'_> {
crate::streaming::StreamingDecoder::new(self)
}
pub fn config(&self) -> &TokenizerConfig {
&self.config
}
pub fn vocab(&self) -> &Vocabulary {
&self.vocab
}
pub fn merges(&self) -> &BpeMerges {
&self.merges
}
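
    /// Builds a minimal character-level tokenizer for tests: the four special
    /// tokens, printable ASCII characters, then byte-fallback tokens until
    /// `vocab_size` is reached. No BPE merges are registered.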
pub fn char_level_stub(vocab_size: usize) -> Self {
assert!(
vocab_size >= 4,
"char_level_stub requires vocab_size >= 4 for special tokens"
);
let mut vocab = Vocabulary::new();
vocab.add_special("<unk>", 0);
vocab.add_special("<bos>", 1);
vocab.add_special("<eos>", 2);
vocab.add_special("<pad>", 3);
let mut next_id = 4u32;
for byte in 0x20u8..=0x7Eu8 {
if next_id as usize >= vocab_size {
break;
}
let ch = char::from(byte);
vocab.insert(&ch.to_string(), next_id);
next_id += 1;
}
for byte in 0u8..=255u8 {
if next_id as usize >= vocab_size {
break;
}
let fallback = byte_fallback_id(byte);
if vocab.get_id(&fallback).is_none() {
vocab.insert(&fallback, next_id);
next_id += 1;
}
}
        let config = TokenizerConfig::default();
let merges = BpeMerges::new();
Self::new(vocab, merges, config)
}
pub fn bos_id(&self) -> u32 {
self.config.bos_token_id
}
pub fn eos_id(&self) -> u32 {
self.config.eos_token_id
}
pub fn is_special(&self, id: u32) -> bool {
self.special_ids.contains(&id)
}
}
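
/// Collects the configured special-token ids so decoding can skip them.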
fn build_special_ids(config: &TokenizerConfig) -> HashSet<u32> {
let mut set = HashSet::new();
set.insert(config.bos_token_id);
set.insert(config.eos_token_id);
set.insert(config.unk_token_id);
set.insert(config.pad_token_id);
set
}
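
/// Parses a byte-fallback token of the form `<0xNN>` into its byte value.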
fn parse_byte_fallback(token: &str) -> Option<u8> {
let inner = token.strip_prefix("<0x")?.strip_suffix('>')?;
if inner.len() != 2 {
return None;
}
u8::from_str_radix(inner, 16).ok()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn char_level_stub_encode_ascii() {
let tok = OxiTokenizer::char_level_stub(200);
let ids = tok.encode("ab").expect("encode should succeed");
assert_eq!(ids.len(), 2);
        assert_ne!(ids[0], 0);
        assert_ne!(ids[1], 0);
        assert_ne!(ids[0], ids[1]);
    }
#[test]
fn char_level_stub_bos_eos() {
let mut tok = OxiTokenizer::char_level_stub(200);
tok.config.add_bos = true;
tok.config.add_eos = true;
tok.special_ids = build_special_ids(&tok.config);
let ids = tok.encode("hi").expect("encode should succeed");
        assert_eq!(ids[0], 1);
        assert_eq!(*ids.last().expect("must have last element"), 2);
    }
#[test]
fn char_level_stub_vocab_size() {
let tok = OxiTokenizer::char_level_stub(50);
assert!(tok.vocab_size() <= 50);
        assert!(tok.vocab_size() >= 4);
    }
#[test]
fn special_token_detection() {
let tok = OxiTokenizer::char_level_stub(200);
        assert!(tok.is_special(0));
        assert!(tok.is_special(1));
        assert!(tok.is_special(2));
        assert!(tok.is_special(3));
        assert!(!tok.is_special(4));
    }
#[test]
fn bos_eos_ids_match_config() {
let tok = OxiTokenizer::char_level_stub(200);
assert_eq!(tok.bos_id(), 1);
assert_eq!(tok.eos_id(), 2);
}
#[test]
fn decode_token_roundtrip() {
let tok = OxiTokenizer::char_level_stub(200);
let ids = tok.encode("a").expect("should encode");
if let Some(&id) = ids.first() {
let s = tok.decode_token(id).expect("decode_token should succeed");
assert_eq!(s, "a");
}
}
#[test]
fn decode_unknown_id_returns_error() {
let tok = OxiTokenizer::char_level_stub(50);
let result = tok.decode_token(99_999);
assert!(result.is_err());
}
#[test]
fn max_length_truncates() {
let mut tok = OxiTokenizer::char_level_stub(200);
tok.config.max_length = Some(3);
tok.special_ids = build_special_ids(&tok.config);
let ids = tok.encode("hello world").expect("encode should succeed");
assert!(ids.len() <= 3);
}
#[test]
fn encode_batch_consistency() {
let tok = OxiTokenizer::char_level_stub(200);
let texts = ["ab", "cd", "ef"];
let batch = tok
.encode_batch(&texts)
.expect("batch encode should succeed");
assert_eq!(batch.len(), 3);
for (i, ids) in batch.iter().enumerate() {
let single = tok.encode(texts[i]).expect("single encode should succeed");
assert_eq!(*ids, single);
}
}
#[test]
fn parse_byte_fallback_valid() {
assert_eq!(parse_byte_fallback("<0x41>"), Some(0x41));
assert_eq!(parse_byte_fallback("<0x00>"), Some(0x00));
assert_eq!(parse_byte_fallback("<0xFF>"), Some(0xFF));
}
#[test]
fn parse_byte_fallback_invalid() {
assert_eq!(parse_byte_fallback("hello"), None);
assert_eq!(parse_byte_fallback("<0x>"), None);
assert_eq!(parse_byte_fallback("<0x1>"), None);
}
#[test]
fn from_json_roundtrip() {
let vocab_json = r#"{"a":10,"b":11,"ab":20,"<unk>":0,"<bos>":1,"<eos>":2,"<pad>":3}"#;
let merges_json = r#"[["a","b"]]"#;
let config = TokenizerConfig::default();
let tok = OxiTokenizer::from_json(vocab_json, merges_json, config)
.expect("from_json should succeed");
assert_eq!(tok.vocab_size(), 7);
let ids = tok.encode("ab").expect("encode should succeed");
assert!(ids.contains(&20));
}
}