use std::collections::HashMap;
use crate::vocab::Vocabulary;
/// Ordered table of BPE merge rules.
///
/// Each rule maps a pair of adjacent symbols `(left, right)` to the token id
/// of the merged symbol. A rule's *priority* (rank) is simply the order in
/// which it was added: lower rank means the merge was learned earlier and is
/// applied first during encoding.
#[derive(Debug, Clone, Default)]
pub struct BpeMerges {
    // (left, right) pair -> token id of the merged symbol.
    merges: HashMap<(String, String), u32>,
    // Insertion order of pairs; the index of a pair IS its priority.
    // Invariant: merge_order holds exactly the keys of `merges`, each once.
    merge_order: Vec<(String, String)>,
}

impl BpeMerges {
    /// Creates an empty merge table.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers the merge `(a, b) -> result_id`.
    ///
    /// Re-adding an existing pair updates its result id but keeps the pair's
    /// original priority (it is not pushed onto `merge_order` again).
    pub fn add_merge(&mut self, a: &str, b: &str, result_id: u32) {
        let key = (a.to_owned(), b.to_owned());
        if !self.merges.contains_key(&key) {
            self.merge_order.push(key.clone());
        }
        self.merges.insert(key, result_id);
    }

    /// Returns the priority (rank) of the merge `(a, b)`, or `None` if the
    /// pair was never registered. Lower values mean higher priority.
    pub fn get_merge_priority(&self, a: &str, b: &str) -> Option<usize> {
        // `merges` and `merge_order` always contain the same keys, so a single
        // scan of `merge_order` suffices; comparing the &str halves directly
        // avoids allocating a temporary (String, String) key and the redundant
        // `contains_key` hash lookup the scan would make pointless anyway.
        self.merge_order.iter().position(|(x, y)| x == a && y == b)
    }

    /// Returns the token id produced by merging `(a, b)`, if registered.
    pub fn get_merge_result(&self, a: &str, b: &str) -> Option<u32> {
        self.merges.get(&(a.to_owned(), b.to_owned())).copied()
    }

    /// Number of distinct merge rules.
    pub fn len(&self) -> usize {
        self.merges.len()
    }

    /// True when no merge rules have been registered.
    pub fn is_empty(&self) -> bool {
        self.merges.is_empty()
    }
}
/// Splits `text` into pre-tokens for BPE encoding.
///
/// Whitespace separates tokens and is not itself emitted; instead, a token
/// that immediately follows whitespace is prefixed with U+0120 ('Ġ'), the
/// GPT-2-style space marker. Each ASCII punctuation character becomes its own
/// token. The very first token of the text gets no marker (no preceding
/// space). Returns an empty vector for empty input.
pub fn pretokenize(text: &str) -> Vec<String> {
    if text.is_empty() {
        return Vec::new();
    }
    let mut tokens: Vec<String> = Vec::new();
    let mut current = String::new();
    let mut after_space = false;
    for ch in text.chars() {
        if ch.is_whitespace() {
            if !current.is_empty() {
                // Flush the word in progress; mem::take leaves `current`
                // empty and reusable without an extra clone.
                tokens.push(std::mem::take(&mut current));
            }
            after_space = true;
        } else if ch.is_ascii_punctuation() {
            if !current.is_empty() {
                tokens.push(std::mem::take(&mut current));
            }
            // Punctuation always stands alone as its own token.
            let mut tok = String::new();
            if after_space {
                tok.push('\u{0120}');
            }
            tok.push(ch);
            tokens.push(tok);
            after_space = false;
        } else {
            if after_space {
                // `current` is necessarily empty here: `after_space` is only
                // set in the whitespace branch, which flushes `current` first.
                // (The original code also handled a non-empty `current` in
                // this state, but that branch was unreachable.)
                current.push('\u{0120}');
            }
            current.push(ch);
            after_space = false;
        }
    }
    // Flush a trailing word that was not terminated by whitespace/punctuation.
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}
/// Encodes a single pre-token into vocabulary ids via byte-pair encoding.
///
/// Starts from the word's individual characters and repeatedly applies the
/// registered merge with the lowest rank (earliest learned) among all
/// adjacent pairs, until no registered merge applies. Each remaining symbol
/// is then resolved to ids through the vocabulary (with byte fallback).
/// Returns an empty vector for an empty word.
pub fn bpe_encode(word: &str, vocab: &Vocabulary, merges: &BpeMerges) -> Vec<u32> {
    if word.is_empty() {
        return Vec::new();
    }
    let mut pieces: Vec<String> = word.chars().map(String::from).collect();
    while pieces.len() >= 2 {
        // Find the adjacent pair with the best (lowest) merge priority.
        // Ties keep the earliest position, matching a first-minimum scan.
        let mut candidate: Option<(usize, usize)> = None; // (priority, index)
        for (idx, pair) in pieces.windows(2).enumerate() {
            if let Some(prio) = merges.get_merge_priority(&pair[0], &pair[1]) {
                let improves = match candidate {
                    Some((best, _)) => prio < best,
                    None => true,
                };
                if improves {
                    candidate = Some((prio, idx));
                }
            }
        }
        let (_, idx) = match candidate {
            Some(found) => found,
            None => break, // no registered merge applies anywhere
        };
        // Fuse the winning pair in place and close the gap.
        let fused = format!("{}{}", pieces[idx], pieces[idx + 1]);
        pieces[idx] = fused;
        pieces.remove(idx + 1);
    }
    pieces
        .iter()
        .flat_map(|piece| symbol_to_ids(piece, vocab))
        .collect()
}
/// Resolves one merged symbol to vocabulary ids.
///
/// A symbol present in the vocabulary maps to exactly one id; otherwise each
/// of its UTF-8 bytes is looked up via its `<0xNN>` fallback token.
fn symbol_to_ids(sym: &str, vocab: &Vocabulary) -> Vec<u32> {
    match vocab.get_id(sym) {
        Some(id) => vec![id],
        None => {
            // Byte-level fallback for out-of-vocabulary symbols.
            // NOTE(review): bytes whose fallback token is absent from the
            // vocabulary are silently dropped — confirm this is intended
            // (a fully byte-covered vocab would make this unreachable).
            let mut ids = Vec::with_capacity(sym.len());
            for &byte in sym.as_bytes() {
                if let Some(id) = vocab.get_id(&byte_fallback_id(byte)) {
                    ids.push(id);
                }
            }
            ids
        }
    }
}
/// Formats a raw byte as its `<0xNN>` fallback token (two uppercase hex
/// digits, zero-padded), e.g. `0x0A` -> `"<0x0A>"`.
pub fn byte_fallback_id(byte: u8) -> String {
    format!("<0x{:02X}>", byte)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::vocab::Vocabulary;
// Builds a tiny vocabulary plus the merge chain that collapses "hello"
// step by step (h+e -> he, he+l -> hel, ...) into the single token id 23.
fn make_vocab_with_merges() -> (Vocabulary, BpeMerges) {
let mut vocab = Vocabulary::new();
vocab.insert("h", 10);
vocab.insert("e", 11);
vocab.insert("l", 12);
vocab.insert("o", 13);
vocab.insert("he", 20);
vocab.insert("hel", 21);
vocab.insert("hell", 22);
vocab.insert("hello", 23);
vocab.insert("lo", 24);
let mut merges = BpeMerges::new();
merges.add_merge("h", "e", 20);
merges.add_merge("he", "l", 21);
merges.add_merge("hel", "l", 22);
merges.add_merge("hell", "o", 23);
merges.add_merge("l", "o", 24);
(vocab, merges)
}
// byte_fallback_id must emit two uppercase hex digits, zero-padded.
#[test]
fn byte_fallback_format() {
assert_eq!(byte_fallback_id(0x00), "<0x00>");
assert_eq!(byte_fallback_id(0x20), "<0x20>");
assert_eq!(byte_fallback_id(0xFF), "<0xFF>");
assert_eq!(byte_fallback_id(0x0A), "<0x0A>");
}
// Priority is insertion order: first added pair has rank 0.
#[test]
fn bpe_merges_priority() {
let mut m = BpeMerges::new();
m.add_merge("a", "b", 1);
m.add_merge("b", "c", 2);
assert_eq!(m.get_merge_priority("a", "b"), Some(0));
assert_eq!(m.get_merge_priority("b", "c"), Some(1));
assert_eq!(m.get_merge_priority("x", "y"), None);
assert_eq!(m.len(), 2);
}
// With the full merge chain available, "hello" collapses to one id.
#[test]
fn bpe_encode_hello() {
let (vocab, merges) = make_vocab_with_merges();
let ids = bpe_encode("hello", &vocab, &merges);
assert_eq!(ids, vec![23]);
}
// Empty input must short-circuit to an empty id list.
#[test]
fn bpe_encode_empty() {
let (vocab, merges) = make_vocab_with_merges();
let ids = bpe_encode("", &vocab, &merges);
assert!(ids.is_empty());
}
// Whitespace-separated words yield at least the first word unchanged
// (the second carries the U+0120 space marker, hence the loose check).
#[test]
fn pretokenize_simple_sentence() {
let tokens = pretokenize("hello world");
assert!(!tokens.is_empty());
assert!(tokens.iter().any(|t| t.contains("hello") || t == "hello"));
}
#[test]
fn pretokenize_empty() {
assert!(pretokenize("").is_empty());
}
// ASCII punctuation splits a run into separate tokens.
#[test]
fn pretokenize_punctuation_splits() {
let tokens = pretokenize("hi,there");
assert!(tokens.len() >= 2);
}
}