use std::collections::HashMap;
use crate::error::{LmError, LmResult};
use crate::tokenizer::vocab::Vocab;
/// Byte-pair-encoding tokenizer: a vocabulary plus a table of ranked merge
/// rules keyed by pairs of token ids.
#[derive(Debug, Clone)]
pub struct BpeTokenizer {
    /// Vocabulary used to translate between byte strings and token ids.
    vocab: Vocab,
    /// Rank of every mergeable `(left, right)` id pair. The rank is the pair's
    /// position in the merge list given to `new`; `encode` applies the pair
    /// with the lowest rank first.
    merge_ranks: HashMap<(u32, u32), u32>,
    /// Id of the token that replaces each mergeable `(left, right)` id pair.
    pair_to_merged: HashMap<(u32, u32), u32>,
}
impl BpeTokenizer {
    /// Builds a tokenizer from a vocabulary and an ordered list of merge rules.
    ///
    /// Each merge is a `(left, right)` pair of byte strings. Its position in
    /// `merges` becomes its rank, and `encode` applies lower ranks first. If
    /// the same pair is listed more than once, the last occurrence determines
    /// its rank.
    ///
    /// # Errors
    ///
    /// Returns `LmError::InvalidMergePair` when the left side, the right side,
    /// or their concatenation is not present in `vocab`. Ids that could not be
    /// resolved are reported as `u32::MAX`.
    pub fn new(vocab: Vocab, merges: Vec<(Vec<u8>, Vec<u8>)>) -> LmResult<Self> {
        let n_merges = merges.len();
        let mut merge_ranks = HashMap::with_capacity(n_merges);
        let mut pair_to_merged = HashMap::with_capacity(n_merges);

        for (rank, (left_bytes, right_bytes)) in merges.into_iter().enumerate() {
            let a = vocab
                .bytes_to_id(&left_bytes)
                .ok_or(LmError::InvalidMergePair {
                    a: u32::MAX,
                    b: u32::MAX,
                })?;
            let b = vocab
                .bytes_to_id(&right_bytes)
                .ok_or(LmError::InvalidMergePair { a, b: u32::MAX })?;

            // `left_bytes` is not needed on its own any more, so grow it in
            // place instead of cloning it.
            let mut merged_bytes = left_bytes;
            merged_bytes.extend_from_slice(&right_bytes);
            let merged = vocab
                .bytes_to_id(&merged_bytes)
                .ok_or(LmError::InvalidMergePair { a, b })?;

            // The rank is the list position; the cast would only wrap for a
            // merge list longer than `u32::MAX` entries.
            merge_ranks.insert((a, b), rank as u32);
            pair_to_merged.insert((a, b), merged);
        }

        Ok(Self {
            vocab,
            merge_ranks,
            pair_to_merged,
        })
    }

    /// Returns the size reported by the underlying vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.vocab.size()
    }

    /// Looks up the id of the special token called `name` in the vocabulary.
    pub fn special_id(&self, name: &str) -> Option<u32> {
        self.vocab.special_id(name)
    }

    /// Encodes `text` into token ids.
    ///
    /// The text is first split into one token per UTF-8 byte. Then, as long as
    /// any adjacent pair has a merge rule, the pair with the lowest rank
    /// (leftmost on ties) is replaced by its merged token. Each step rescans
    /// the whole sequence, so the cost grows quadratically with input length.
    ///
    /// # Errors
    ///
    /// Returns `LmError::OutOfVocab` when a byte of `text` has no single-byte
    /// token, and `LmError::InvalidMergePair` if a ranked pair has no merged
    /// id (not expected, because `new` fills both tables together).
    pub fn encode(&self, text: &str) -> LmResult<Vec<u32>> {
        if text.is_empty() {
            return Ok(vec![]);
        }

        let mut tokens = text
            .bytes()
            .map(|b| {
                self.vocab
                    .bytes_to_id(&[b])
                    .ok_or(LmError::OutOfVocab {
                        token: u32::from(b),
                    })
            })
            .collect::<LmResult<Vec<u32>>>()?;

        while let Some(pos) = self.best_merge_position(&tokens) {
            let pair = (tokens[pos], tokens[pos + 1]);
            let merged = *self
                .pair_to_merged
                .get(&pair)
                .ok_or(LmError::InvalidMergePair {
                    a: pair.0,
                    b: pair.1,
                })?;
            // Collapse the pair: overwrite the left slot, drop the right one.
            tokens[pos] = merged;
            tokens.remove(pos + 1);
        }

        Ok(tokens)
    }

    /// Returns the index of the left element of the adjacent pair with the
    /// lowest merge rank, or `None` when no adjacent pair can be merged.
    ///
    /// Comparing `(rank, position)` tuples picks the leftmost pair among
    /// those sharing the lowest rank.
    fn best_merge_position(&self, tokens: &[u32]) -> Option<usize> {
        tokens
            .windows(2)
            .enumerate()
            .filter_map(|(pos, pair)| {
                self.merge_ranks
                    .get(&(pair[0], pair[1]))
                    .map(|&rank| (rank, pos))
            })
            .min()
            .map(|(_, pos)| pos)
    }

    /// Encodes `text` and optionally wraps it in special tokens.
    ///
    /// When `bos_name` / `eos_name` is given and the vocabulary knows a
    /// special token of that name, its id is placed before / after the
    /// encoded text. Names the vocabulary does not know are skipped silently.
    ///
    /// # Errors
    ///
    /// Propagates any error from `encode`.
    pub fn encode_with_special(
        &self,
        text: &str,
        bos_name: Option<&str>,
        eos_name: Option<&str>,
    ) -> LmResult<Vec<u32>> {
        let body = self.encode(text)?;

        // Room for the body plus at most one token on each side.
        let mut ids = Vec::with_capacity(body.len() + 2);
        if let Some(id) = bos_name.and_then(|name| self.vocab.special_id(name)) {
            ids.push(id);
        }
        ids.extend(body);
        if let Some(id) = eos_name.and_then(|name| self.vocab.special_id(name)) {
            ids.push(id);
        }
        Ok(ids)
    }

    /// Decodes token ids back into a string by concatenating the bytes of
    /// every token and validating the result as UTF-8.
    ///
    /// # Errors
    ///
    /// Propagates the error of `Vocab::id_to_bytes` for ids it rejects.
    /// Returns `LmError::Utf8Decode` when the concatenated bytes are not valid
    /// UTF-8; `token` is the id whose bytes contain the first invalid or
    /// incomplete byte sequence.
    pub fn decode(&self, ids: &[u32]) -> LmResult<String> {
        let mut bytes = Vec::new();
        for &id in ids {
            let tok_bytes = self.vocab.id_to_bytes(id)?;
            bytes.extend_from_slice(tok_bytes);
        }

        String::from_utf8(bytes).map_err(|err| {
            // Error path only: walk the ids again to find the token that
            // covers the first offending byte offset.
            let bad_offset = err.utf8_error().valid_up_to();
            let mut consumed = 0usize;
            let offender = ids.iter().copied().find(|&id| {
                consumed += self.vocab.id_to_bytes(id).map_or(0, |tok| tok.len());
                consumed > bad_offset
            });
            LmError::Utf8Decode {
                token: offender.or_else(|| ids.first().copied()).unwrap_or(0),
            }
        })
    }

    /// Decodes a single token id; equivalent to `decode(&[id])`.
    pub fn decode_one(&self, id: u32) -> LmResult<String> {
        self.decode(&[id])
    }

    /// Borrows the underlying vocabulary.
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }
}
/// Accumulates merge rules and special tokens before a `BpeTokenizer` is
/// built from them.
#[derive(Debug, Default)]
pub struct BpeBuilder {
    /// Merge rules as `(left, right)` byte strings, in the order they were added.
    merges: Vec<(Vec<u8>, Vec<u8>)>,
    /// Byte strings of the merged tokens, one per added merge rule.
    extra_tokens: Vec<Vec<u8>>,
    /// Special-token names mapped to their ids.
    special: HashMap<String, u32>,
}
impl BpeBuilder {
    /// Creates a builder with no merges and no special tokens.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends the merge rule `left + right`.
    ///
    /// The concatenated bytes are also recorded as an extra vocabulary token
    /// so the merge has a token to produce.
    pub fn add_merge(mut self, left: &[u8], right: &[u8]) -> Self {
        let joined = [left, right].concat();
        self.extra_tokens.push(joined);
        self.merges.push((left.to_vec(), right.to_vec()));
        self
    }

    /// Registers the special token `name` with the given `id`; a later call
    /// with the same name replaces the earlier id.
    pub fn add_special(mut self, name: &str, id: u32) -> Self {
        self.special.insert(name.to_owned(), id);
        self
    }

    /// Consumes the builder and constructs the tokenizer.
    ///
    /// The vocabulary is `Vocab::gpt2_byte_vocab()` extended with the merged
    /// tokens and the special tokens collected so far.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Vocab::with_extra_tokens` and
    /// `BpeTokenizer::new`.
    pub fn build(self) -> LmResult<BpeTokenizer> {
        let Self {
            merges,
            extra_tokens,
            special,
        } = self;
        let vocab = Vocab::gpt2_byte_vocab().with_extra_tokens(extra_tokens, special)?;
        BpeTokenizer::new(vocab, merges)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small tokenizer over the tokens `a`, `b`, `c`, `d`, `ab`,
    /// `cd`, `abcd` with the merges `a+b`, `c+d`, `ab+cd` (in rank order).
    ///
    /// The assertions in this module assume `Vocab::from_tokens` assigns ids
    /// in list order, i.e. `a` = 0 through `abcd` = 6.
    fn minimal_tokenizer() -> BpeTokenizer {
        let tokens: Vec<Vec<u8>> = ["a", "b", "c", "d", "ab", "cd", "abcd"]
            .iter()
            .map(|s| s.as_bytes().to_vec())
            .collect();
        let merges: Vec<(Vec<u8>, Vec<u8>)> = [("a", "b"), ("c", "d"), ("ab", "cd")]
            .iter()
            .map(|(left, right)| (left.as_bytes().to_vec(), right.as_bytes().to_vec()))
            .collect();
        let vocab = Vocab::from_tokens(tokens, HashMap::new()).unwrap();
        BpeTokenizer::new(vocab, merges).unwrap()
    }

    // A lone byte maps straight to its single-byte token.
    #[test]
    fn encode_single_char() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.encode("a").unwrap(), vec![0u32]);
        assert_eq!(tok.encode("b").unwrap(), vec![1u32]);
    }

    // Empty input yields no tokens.
    #[test]
    fn encode_empty_string() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.encode("").unwrap(), Vec::<u32>::new());
    }

    #[test]
    fn encode_ab_merges_to_one_token() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.encode("ab").unwrap(), vec![4u32]);
    }

    #[test]
    fn encode_cd_merges_to_one_token() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.encode("cd").unwrap(), vec![5u32]);
    }

    // `ab` and `cd` are merged first, then combined into `abcd`.
    #[test]
    fn encode_abcd_fully_merged() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.encode("abcd").unwrap(), vec![6u32]);
    }

    // The trailing `c` has no partner and stays a single-byte token.
    #[test]
    fn encode_abc_partial_merge() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.encode("abc").unwrap(), vec![4u32, 2]);
    }

    #[test]
    fn encode_abcdabcd_two_full_merges() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.encode("abcdabcd").unwrap(), vec![6u32, 6]);
    }

    #[test]
    fn decode_single_token() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.decode(&[0]).unwrap(), "a");
        assert_eq!(tok.decode(&[4]).unwrap(), "ab");
        assert_eq!(tok.decode(&[6]).unwrap(), "abcd");
    }

    #[test]
    fn decode_multiple_tokens() {
        let tok = minimal_tokenizer();
        assert_eq!(tok.decode(&[4, 2]).unwrap(), "abc");
        assert_eq!(tok.decode(&[6, 6]).unwrap(), "abcdabcd");
    }

    // Decoding the encoding must reproduce the input, merges or not.
    #[test]
    fn encode_then_decode_roundtrip() {
        let tok = minimal_tokenizer();
        for text in ["a", "ab", "abc", "abcd", "abcdabcd", "ba", "dcba"] {
            let ids = tok.encode(text).unwrap();
            let decoded = tok.decode(&ids).unwrap();
            assert_eq!(decoded, text, "roundtrip failed for '{text}'");
        }
    }

    // An id outside the vocabulary is reported as out-of-vocab.
    #[test]
    fn out_of_vocab_token_errors() {
        let tok = minimal_tokenizer();
        let result = tok.decode(&[99]);
        assert!(matches!(result, Err(LmError::OutOfVocab { token: 99 })));
    }

    // Known special names are placed around the encoded text.
    #[test]
    fn encode_with_special_bos_eos() {
        let tokens: Vec<Vec<u8>> = vec![
            vec![b'a'],
            vec![b'b'],
            vec![1_u8, 0_u8],
            vec![2_u8, 0_u8],
        ];
        let mut special: HashMap<String, u32> = HashMap::new();
        special.insert("<bos>".to_owned(), 2u32);
        special.insert("<eos>".to_owned(), 3u32);

        let vocab = Vocab::from_tokens(tokens, special).unwrap();
        let tok = BpeTokenizer::new(vocab, vec![]).unwrap();

        let ids = tok
            .encode_with_special("ab", Some("<bos>"), Some("<eos>"))
            .unwrap();
        assert_eq!(ids, vec![2, 0, 1, 3]);
    }

    // A special name the vocabulary does not know is skipped, not an error.
    #[test]
    fn encode_with_special_no_bos_eos_absent() {
        let tok = minimal_tokenizer();
        let ids = tok.encode_with_special("ab", Some("<bos>"), None).unwrap();
        assert_eq!(ids, vec![4u32]);
    }

    // One merge on top of the builder's base vocabulary adds exactly one token.
    #[test]
    fn bpe_builder_basic() {
        let tok = BpeBuilder::new().add_merge(b"a", b"b").build().unwrap();
        assert_eq!(tok.vocab_size(), 257);
        assert_eq!(tok.encode("ab").unwrap(), vec![256u32]);
        assert_eq!(tok.decode(&[256]).unwrap(), "ab");
    }

    // A merge may build on the output of an earlier merge.
    #[test]
    fn bpe_builder_chained_merges() {
        let tok = BpeBuilder::new()
            .add_merge(b"a", b"b")
            .add_merge(b"ab", b"c")
            .build()
            .unwrap();
        assert_eq!(tok.encode("abc").unwrap(), vec![257u32]);
        assert_eq!(tok.decode(&[257]).unwrap(), "abc");
    }

    #[test]
    fn bpe_builder_with_special_token() {
        let tok = BpeBuilder::new().add_special("<eos>", 10).build().unwrap();
        assert_eq!(tok.special_id("<eos>"), Some(10));
    }

    #[test]
    fn vocab_size_matches_builder() {
        let tok = BpeBuilder::new()
            .add_merge(b"x", b"y")
            .add_merge(b"xy", b"z")
            .build()
            .unwrap();
        assert_eq!(tok.vocab_size(), 258);
    }

    // `b+c` was added first, so it outranks `a+b` and claims the shared `b`.
    #[test]
    fn priority_order_matters() {
        let tok = BpeBuilder::new()
            .add_merge(b"b", b"c")
            .add_merge(b"a", b"b")
            .build()
            .unwrap();
        let ids = tok.encode("abcd").unwrap();
        assert_eq!(ids, vec![u32::from(b'a'), 256u32, u32::from(b'd')]);
    }
}