use std::collections::HashMap;
use crate::error::{LmError, LmResult};
use crate::tokenizer::vocab::Vocab;
/// Byte-pair-encoding tokenizer: a vocabulary plus the lookup tables derived
/// from an ordered list of merge rules.
#[derive(Debug, Clone)]
pub struct BpeTokenizer {
/// Vocabulary used for every bytes <-> id lookup.
vocab: Vocab,
/// Maps a `(left id, right id)` pair to its rank, i.e. the index of the
/// rule in the merge list handed to `new`. Lower ranks are merged first.
merge_ranks: HashMap<(u32, u32), u32>,
/// Maps a `(left id, right id)` pair to the id of the token whose bytes are
/// the concatenation of both sides.
pair_to_merged: HashMap<(u32, u32), u32>,
}
impl BpeTokenizer {
    /// Builds a tokenizer from a vocabulary and an ordered list of merge rules.
    ///
    /// Each rule is a `(left, right)` pair of byte strings. Its index in
    /// `merges` becomes its rank, and [`encode`](Self::encode) applies lower
    /// ranks first. If the same pair is listed more than once, the rank of
    /// its last occurrence is the one kept.
    ///
    /// # Errors
    ///
    /// Returns [`LmError::InvalidMergePair`] when the left side, the right
    /// side, or their concatenation has no id in `vocab`. When the left side
    /// is missing both fields are `u32::MAX`; when only the right side is
    /// missing, `b` is `u32::MAX`.
    pub fn new(vocab: Vocab, merges: Vec<(Vec<u8>, Vec<u8>)>) -> LmResult<Self> {
        let n_merges = merges.len();
        let mut merge_ranks = HashMap::with_capacity(n_merges);
        let mut pair_to_merged = HashMap::with_capacity(n_merges);

        for (rank, (left_bytes, right_bytes)) in merges.into_iter().enumerate() {
            let a = vocab
                .bytes_to_id(&left_bytes)
                .ok_or(LmError::InvalidMergePair {
                    a: u32::MAX,
                    b: u32::MAX,
                })?;
            let b = vocab
                .bytes_to_id(&right_bytes)
                .ok_or(LmError::InvalidMergePair { a, b: u32::MAX })?;

            // `left_bytes` is not needed again, so extend it in place rather
            // than cloning it first.
            let mut merged_bytes = left_bytes;
            merged_bytes.extend_from_slice(&right_bytes);
            let merged = vocab
                .bytes_to_id(&merged_bytes)
                .ok_or(LmError::InvalidMergePair { a, b })?;

            merge_ranks.insert((a, b), rank as u32);
            pair_to_merged.insert((a, b), merged);
        }

        Ok(Self {
            vocab,
            merge_ranks,
            pair_to_merged,
        })
    }

    /// Returns the size reported by the underlying vocabulary.
    pub fn vocab_size(&self) -> usize {
        self.vocab.size()
    }

    /// Looks up the id registered for the special token `name`, if any.
    pub fn special_id(&self, name: &str) -> Option<u32> {
        self.vocab.special_id(name)
    }

    /// Encodes `text` into token ids.
    ///
    /// The text is first split into one token per byte. Then, until no
    /// adjacent pair has a merge rule, the pair with the lowest rank is
    /// replaced by its merged token; among equally ranked pairs the leftmost
    /// one is merged first. An empty `text` yields an empty vector.
    ///
    /// # Errors
    ///
    /// * [`LmError::OutOfVocab`] if a byte of `text` has no single-byte token.
    /// * [`LmError::InvalidMergePair`] if a ranked pair has no merged id.
    ///   [`new`](Self::new) fills both tables together, so this is not
    ///   expected for tokenizers built through it.
    pub fn encode(&self, text: &str) -> LmResult<Vec<u32>> {
        if text.is_empty() {
            return Ok(Vec::new());
        }

        let mut tokens = text
            .bytes()
            .map(|byte| {
                self.vocab.bytes_to_id(&[byte]).ok_or(LmError::OutOfVocab {
                    token: u32::from(byte),
                })
            })
            .collect::<LmResult<Vec<u32>>>()?;

        loop {
            // `(rank, position)` tuples order by rank first and position
            // second, so the minimum is the best-ranked, leftmost pair.
            let best = tokens
                .windows(2)
                .enumerate()
                .filter_map(|(pos, w)| {
                    self.merge_ranks
                        .get(&(w[0], w[1]))
                        .map(|&rank| (rank, pos))
                })
                .min();

            match best {
                None => break,
                Some((_, pos)) => {
                    let pair = (tokens[pos], tokens[pos + 1]);
                    let merged =
                        *self
                            .pair_to_merged
                            .get(&pair)
                            .ok_or(LmError::InvalidMergePair {
                                a: pair.0,
                                b: pair.1,
                            })?;
                    // Collapse the pair into a single slot.
                    tokens[pos] = merged;
                    tokens.remove(pos + 1);
                }
            }
        }

        Ok(tokens)
    }

    /// Encodes `text` and optionally wraps it in special tokens.
    ///
    /// `bos_name` / `eos_name` are looked up as special-token names and, when
    /// found, their ids are placed before / after the encoded text. A name
    /// that is `None` or unknown to the vocabulary is skipped without error.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`encode`](Self::encode).
    pub fn encode_with_special(
        &self,
        text: &str,
        bos_name: Option<&str>,
        eos_name: Option<&str>,
    ) -> LmResult<Vec<u32>> {
        let bos = bos_name.and_then(|name| self.vocab.special_id(name));
        let eos = eos_name.and_then(|name| self.vocab.special_id(name));
        let body = self.encode(text)?;

        // At most two extra ids are added around the encoded text.
        let mut ids = Vec::with_capacity(body.len() + 2);
        ids.extend(bos);
        ids.extend(body);
        ids.extend(eos);
        Ok(ids)
    }

    /// Decodes `ids` by concatenating each token's bytes and interpreting the
    /// result as UTF-8.
    ///
    /// # Errors
    ///
    /// * Whatever `Vocab::id_to_bytes` returns for an id it rejects.
    /// * [`LmError::Utf8Decode`] if the concatenated bytes are not valid
    ///   UTF-8; `token` is the id whose bytes contain the first invalid byte.
    pub fn decode(&self, ids: &[u32]) -> LmResult<String> {
        let mut bytes = Vec::new();
        for &id in ids {
            let tok_bytes = self.vocab.id_to_bytes(id)?;
            bytes.extend_from_slice(tok_bytes);
        }

        String::from_utf8(bytes).map_err(|err| {
            // Error path only: walk the ids again, accumulating byte lengths,
            // until the running end offset passes the first invalid byte.
            let bad_offset = err.utf8_error().valid_up_to();
            let mut end = 0usize;
            let offender = ids.iter().copied().find(|&id| {
                end += self.vocab.id_to_bytes(id).map_or(0, |tok| tok.len());
                end > bad_offset
            });
            LmError::Utf8Decode {
                token: offender.or_else(|| ids.first().copied()).unwrap_or(0),
            }
        })
    }

    /// Decodes a single token id; equivalent to `decode(&[id])`.
    pub fn decode_one(&self, id: u32) -> LmResult<String> {
        self.decode(&[id])
    }

    /// Borrows the underlying vocabulary.
    pub fn vocab(&self) -> &Vocab {
        &self.vocab
    }
}
/// Accumulates merge rules and special tokens before building a
/// `BpeTokenizer`.
#[derive(Debug, Default)]
pub struct BpeBuilder {
/// Merge rules as `(left bytes, right bytes)`, in the order they were added.
merges: Vec<(Vec<u8>, Vec<u8>)>,
/// Concatenated bytes of every added merge; handed to
/// `Vocab::with_extra_tokens` when building.
extra_tokens: Vec<Vec<u8>>,
/// Special-token name -> id, also handed to `Vocab::with_extra_tokens`.
special: HashMap<String, u32>,
}
impl BpeBuilder {
    /// Creates a builder with no merges and no special tokens.
    pub fn new() -> Self {
        Default::default()
    }

    /// Records the merge rule `left + right`.
    ///
    /// The concatenation of both sides is queued as an extra vocabulary
    /// token, and the rule itself is appended after all rules added so far.
    pub fn add_merge(mut self, left: &[u8], right: &[u8]) -> Self {
        let rule = (left.to_vec(), right.to_vec());
        self.extra_tokens.push([left, right].concat());
        self.merges.push(rule);
        self
    }

    /// Registers the special token `name` under `id`, replacing any earlier
    /// registration of the same name.
    pub fn add_special(mut self, name: &str, id: u32) -> Self {
        self.special.insert(name.to_owned(), id);
        self
    }

    /// Consumes the builder and produces the tokenizer.
    ///
    /// The vocabulary starts from `Vocab::gpt2_byte_vocab()` and is extended
    /// with the queued merge tokens and special tokens.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Vocab::with_extra_tokens` and
    /// `BpeTokenizer::new`.
    pub fn build(self) -> LmResult<BpeTokenizer> {
        let Self {
            merges,
            extra_tokens,
            special,
        } = self;
        let byte_level = Vocab::gpt2_byte_vocab();
        let vocab = byte_level.with_extra_tokens(extra_tokens, special)?;
        BpeTokenizer::new(vocab, merges)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture with tokens `a`, `b`, `c`, `d`, `ab`, `cd`, `abcd` and the
    /// merges a+b, c+d, ab+cd (in that order). The assertions below expect
    /// ids to follow the order of the token list.
    fn minimal_tokenizer() -> BpeTokenizer {
        let tokens: Vec<Vec<u8>> = vec![
            b"a".to_vec(),
            b"b".to_vec(),
            b"c".to_vec(),
            b"d".to_vec(),
            b"ab".to_vec(),
            b"cd".to_vec(),
            b"abcd".to_vec(),
        ];
        let merges = vec![
            (b"a".to_vec(), b"b".to_vec()),
            (b"c".to_vec(), b"d".to_vec()),
            (b"ab".to_vec(), b"cd".to_vec()),
        ];
        let vocab = Vocab::from_tokens(tokens, HashMap::new())
            .expect("minimal 7-token vocabulary should be valid");
        BpeTokenizer::new(vocab, merges)
            .expect("minimal BPE tokenizer with 3 merge rules should be valid")
    }

    #[test]
    fn encode_single_char() {
        let tok = minimal_tokenizer();
        let a_ids = tok.encode("a").expect("single byte 'a' should encode");
        assert_eq!(a_ids, vec![0u32]);
        let b_ids = tok.encode("b").expect("single byte 'b' should encode");
        assert_eq!(b_ids, vec![1u32]);
    }

    #[test]
    fn encode_empty_string() {
        let tok = minimal_tokenizer();
        let ids = tok
            .encode("")
            .expect("empty string should encode to empty vec");
        assert_eq!(ids, Vec::<u32>::new());
    }

    #[test]
    fn encode_ab_merges_to_one_token() {
        let tok = minimal_tokenizer();
        let ids = tok.encode("ab").expect("'ab' should merge to token 4");
        assert_eq!(ids, vec![4u32]);
    }

    #[test]
    fn encode_cd_merges_to_one_token() {
        let tok = minimal_tokenizer();
        let ids = tok.encode("cd").expect("'cd' should merge to token 5");
        assert_eq!(ids, vec![5u32]);
    }

    #[test]
    fn encode_abcd_fully_merged() {
        let tok = minimal_tokenizer();
        let ids = tok
            .encode("abcd")
            .expect("'abcd' should fully merge to token 6");
        assert_eq!(ids, vec![6u32]);
    }

    #[test]
    fn encode_abc_partial_merge() {
        let tok = minimal_tokenizer();
        let ids = tok
            .encode("abc")
            .expect("'abc' should partially merge to [4, 2]");
        assert_eq!(ids, vec![4u32, 2]);
    }

    #[test]
    fn encode_abcdabcd_two_full_merges() {
        let tok = minimal_tokenizer();
        let ids = tok
            .encode("abcdabcd")
            .expect("'abcdabcd' should merge to [6, 6]");
        assert_eq!(ids, vec![6u32, 6]);
    }

    #[test]
    fn decode_single_token() {
        let tok = minimal_tokenizer();
        let a = tok.decode(&[0]).expect("token 0 should decode to 'a'");
        assert_eq!(a, "a");
        let ab = tok.decode(&[4]).expect("token 4 should decode to 'ab'");
        assert_eq!(ab, "ab");
        let abcd = tok.decode(&[6]).expect("token 6 should decode to 'abcd'");
        assert_eq!(abcd, "abcd");
    }

    #[test]
    fn decode_multiple_tokens() {
        let tok = minimal_tokenizer();
        let abc = tok
            .decode(&[4, 2])
            .expect("tokens [4,2] should decode to 'abc'");
        assert_eq!(abc, "abc");
        let doubled = tok
            .decode(&[6, 6])
            .expect("tokens [6,6] should decode to 'abcdabcd'");
        assert_eq!(doubled, "abcdabcd");
    }

    #[test]
    fn encode_then_decode_roundtrip() {
        let tok = minimal_tokenizer();
        for text in ["a", "ab", "abc", "abcd", "abcdabcd", "ba", "dcba"] {
            let ids = tok
                .encode(text)
                .unwrap_or_else(|e| panic!("encode '{text}' should succeed: {e}"));
            let decoded = tok
                .decode(&ids)
                .unwrap_or_else(|e| panic!("decode of '{text}' ids should succeed: {e}"));
            assert_eq!(decoded, text, "roundtrip failed for '{text}'");
        }
    }

    #[test]
    fn out_of_vocab_token_errors() {
        let tok = minimal_tokenizer();
        let outcome = tok.decode(&[99]);
        assert!(matches!(outcome, Err(LmError::OutOfVocab { token: 99 })));
    }

    #[test]
    fn encode_with_special_bos_eos() {
        // Two ordinary byte tokens followed by two placeholder entries that
        // the special-token map points at.
        let tokens: Vec<Vec<u8>> = vec![
            b"a".to_vec(),
            b"b".to_vec(),
            vec![1_u8, 0_u8],
            vec![2_u8, 0_u8],
        ];
        let mut special: HashMap<String, u32> = HashMap::new();
        special.insert("<bos>".into(), 2u32);
        special.insert("<eos>".into(), 3u32);

        let vocab = Vocab::from_tokens(tokens, special)
            .expect("4-token vocabulary with valid special ids should succeed");
        let tok =
            BpeTokenizer::new(vocab, vec![]).expect("BPE tokenizer with no merges should succeed");
        let ids = tok
            .encode_with_special("ab", Some("<bos>"), Some("<eos>"))
            .expect("encode_with_special for 'ab' with valid BOS/EOS should succeed");
        assert_eq!(ids, vec![2, 0, 1, 3]);
    }

    #[test]
    fn encode_with_special_no_bos_eos_absent() {
        // The fixture has no special tokens, so the requested BOS is skipped.
        let tok = minimal_tokenizer();
        let ids = tok
            .encode_with_special("ab", Some("<bos>"), None)
            .expect("encode_with_special with absent BOS should skip silently and succeed");
        assert_eq!(ids, vec![4u32]);
    }

    #[test]
    fn bpe_builder_basic() {
        let tok = BpeBuilder::new()
            .add_merge(b"a", b"b")
            .build()
            .expect("BpeBuilder with single merge should produce valid tokenizer");
        assert_eq!(tok.vocab_size(), 257);
        let ids = tok.encode("ab").expect("'ab' should merge to id 256");
        assert_eq!(ids, vec![256u32]);
        let text = tok.decode(&[256]).expect("id 256 should decode to 'ab'");
        assert_eq!(text, "ab");
    }

    #[test]
    fn bpe_builder_chained_merges() {
        let tok = BpeBuilder::new()
            .add_merge(b"a", b"b")
            .add_merge(b"ab", b"c")
            .build()
            .expect("BpeBuilder with chained merges a+b→ab then ab+c→abc should succeed");
        let ids = tok.encode("abc").expect("'abc' should chain-merge to id 257");
        assert_eq!(ids, vec![257u32]);
        let text = tok.decode(&[257]).expect("id 257 should decode to 'abc'");
        assert_eq!(text, "abc");
    }

    #[test]
    fn bpe_builder_with_special_token() {
        let tok = BpeBuilder::new()
            .add_special("<eos>", 10)
            .build()
            .expect("BpeBuilder with special token at valid id 10 should succeed");
        assert_eq!(tok.special_id("<eos>"), Some(10));
    }

    #[test]
    fn vocab_size_matches_builder() {
        let tok = BpeBuilder::new()
            .add_merge(b"x", b"y")
            .add_merge(b"xy", b"z")
            .build()
            .expect("BpeBuilder with x+y and xy+z merges should succeed");
        assert_eq!(tok.vocab_size(), 258);
    }

    #[test]
    fn priority_order_matters() {
        // b+c is added first, so it outranks a+b and wins inside "abcd".
        let tok = BpeBuilder::new()
            .add_merge(b"b", b"c")
            .add_merge(b"a", b"b")
            .build()
            .expect("BpeBuilder with priority-order merges b+c then a+b should succeed");
        let ids = tok
            .encode("abcd")
            .expect("'abcd' encoding with priority b+c merge should succeed");
        assert_eq!(ids, vec![u32::from(b'a'), 256u32, u32::from(b'd')]);
    }
}