use std::cmp::Ordering;
use std::collections::HashMap as StdHashMap;
use dary_heap::OctonaryHeap;
use fancy_regex::Regex;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
use ahash::{AHashMap, AHashSet};
use compact_str::CompactString;
use rayon::prelude::*;
pub const GPT4_PATTERN: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+";
type Pair = (u32, u32);
#[cfg_attr(feature = "python-bindings", pyclass)]
pub struct Tokenizer {
pub merges: StdHashMap<Pair, u32>,
pub pattern: String,
compiled_pattern: Regex,
}
impl Default for Tokenizer {
fn default() -> Self {
Self::new()
}
}
#[derive(Clone, Debug)]
struct Word {
ids: Vec<u32>,
}
impl Word {
#[inline]
fn new(ids: Vec<u32>) -> Self {
Self { ids }
}
#[inline]
fn pairs(&self) -> impl Iterator<Item = Pair> + '_ {
self.ids.windows(2).map(|w| (w[0], w[1]))
}
fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> {
let (a, b) = pair;
let n = self.ids.len();
if n < 2 {
return Vec::new();
}
let mut out: Vec<u32> = Vec::with_capacity(n);
let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6);
let mut i = 0;
while i < n {
if i + 1 < n && self.ids[i] == a && self.ids[i + 1] == b {
let left = out.last().copied();
let right = if i + 2 < n {
Some(self.ids[i + 2])
} else {
None
};
if let Some(x) = left {
deltas.push(((x, a), -1));
deltas.push(((x, new_id), 1));
}
deltas.push(((a, b), -1));
if let Some(y) = right {
deltas.push(((b, y), -1));
deltas.push(((new_id, y), 1));
}
out.push(new_id);
i += 2; } else {
out.push(self.ids[i]);
i += 1;
}
}
self.ids = out;
deltas
}
}
#[derive(Debug, Eq)]
struct MergeJob {
pair: Pair,
count: u64,
pos: AHashSet<usize>,
}
impl PartialEq for MergeJob {
fn eq(&self, other: &Self) -> bool {
self.count == other.count && self.pair == other.pair
}
}
impl PartialOrd for MergeJob {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for MergeJob {
fn cmp(&self, other: &Self) -> Ordering {
if self.count != other.count {
self.count.cmp(&other.count)
} else {
other.pair.cmp(&self.pair)
}
}
}
#[inline]
fn count_pairs_parallel(
words: &[Word],
counts: &[i32],
) -> (AHashMap<Pair, i32>, AHashMap<Pair, AHashSet<usize>>) {
words
.par_iter()
.enumerate()
.map(|(i, w)| {
let mut local_pc: AHashMap<Pair, i32> = AHashMap::new();
let mut local_wtu: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
if w.ids.len() >= 2 && counts[i] != 0 {
for (a, b) in w.pairs() {
*local_pc.entry((a, b)).or_default() += counts[i];
local_wtu.entry((a, b)).or_default().insert(i);
}
}
(local_pc, local_wtu)
})
.reduce(
|| (AHashMap::new(), AHashMap::new()),
|(mut acc_pc, mut acc_wtu), (pc, wtu)| {
for (k, v) in pc {
*acc_pc.entry(k).or_default() += v;
}
for (k, s) in wtu {
acc_wtu.entry(k).or_default().extend(s);
}
(acc_pc, acc_wtu)
},
)
}
impl Tokenizer {
pub fn new() -> Self {
Self {
merges: StdHashMap::new(),
pattern: String::new(),
compiled_pattern: Regex::new("").expect("Empty regex should be valid"),
}
}
fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
assert!(vocab_size >= 256, "vocab_size must be at least 256");
let num_merges = vocab_size - 256;
log::info!("Starting BPE training: {} merges to compute", num_merges);
self.merges.clear();
log::info!(
"Computing initial pair counts from {} unique sequences",
words.len()
);
let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
log::info!("Building heap with {} unique pairs", pair_counts.len());
let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
for (pair, pos) in where_to_update.drain() {
let c = *pair_counts.get(&pair).unwrap_or(&0);
if c > 0 {
heap.push(MergeJob {
pair,
count: c as u64,
pos,
});
}
}
log::info!("Starting merge loop");
let mut merges_done = 0u32;
let mut last_log_percent = 0u32;
while merges_done < num_merges {
let Some(mut top) = heap.pop() else {
break;
};
let current = *pair_counts.get(&top.pair).unwrap_or(&0);
if current <= 0 {
continue;
}
if top.count != current as u64 {
top.count = current as u64;
heap.push(top);
continue;
}
let new_id = 256 + merges_done;
self.merges.insert(top.pair, new_id);
let mut local_pos_updates: AHashMap<Pair, AHashSet<usize>> = AHashMap::new();
for &word_idx in &top.pos {
let changes = words[word_idx].merge_pair(top.pair, new_id);
for (pair, delta) in changes {
let delta_total = delta * counts[word_idx];
if delta_total != 0 {
*pair_counts.entry(pair).or_default() += delta_total;
if delta > 0 {
local_pos_updates.entry(pair).or_default().insert(word_idx);
}
}
}
}
for (pair, pos) in local_pos_updates {
let cnt = *pair_counts.get(&pair).unwrap_or(&0);
if cnt > 0 {
heap.push(MergeJob {
pair,
count: cnt as u64,
pos,
});
}
}
merges_done += 1;
let current_percent = (merges_done * 100) / num_merges;
if current_percent > last_log_percent {
log::info!(
"Progress: {}% ({}/{} merges) - Last merge: {:?} -> {} (frequency: {})",
current_percent,
merges_done,
num_merges,
top.pair,
new_id,
top.count
);
last_log_percent = current_percent;
}
}
log::info!("Finished training: {} merges completed", merges_done);
}
pub fn new_with_pattern(pattern: impl Into<String>) -> Result<Self, fancy_regex::Error> {
let pattern = pattern.into();
Ok(Self {
merges: StdHashMap::new(),
compiled_pattern: Regex::new(&pattern)?,
pattern,
})
}
pub fn from_merges(
pattern: impl Into<String>,
merges: StdHashMap<(u32, u32), u32>,
) -> Result<Self, fancy_regex::Error> {
let pattern = pattern.into();
Ok(Self {
merges,
compiled_pattern: Regex::new(&pattern)?,
pattern,
})
}
pub fn train_from_texts<'a, I>(
&mut self,
texts: I,
vocab_size: u32,
pattern: Option<&str>,
) -> Result<(), fancy_regex::Error>
where
I: IntoIterator<Item = &'a str>,
{
let pattern_str = pattern.unwrap_or(GPT4_PATTERN).to_string();
self.pattern = pattern_str.clone();
self.compiled_pattern = Regex::new(&pattern_str)?;
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
for text in texts {
for mat in self.compiled_pattern.find_iter(text) {
let piece = mat?.as_str();
*counts.entry(CompactString::from(piece)).or_default() += 1;
}
}
let mut words = Vec::with_capacity(counts.len());
let mut cvec = Vec::with_capacity(counts.len());
for (chunk, count) in counts.into_iter() {
words.push(Word::new(
chunk.as_bytes().iter().map(|&byte| byte as u32).collect(),
));
cvec.push(count);
}
self.train_core_incremental(words, cvec, vocab_size);
Ok(())
}
pub fn decode_to_string(&self, ids: &[u32]) -> Result<String, String> {
let mut vocab: Vec<Vec<u8>> = (0..256u32).map(|i| vec![i as u8]).collect();
let mut sorted_merges: Vec<_> = self.merges.iter().collect();
sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
for ((left, right), merged_id) in sorted_merges {
let left = *left;
let right = *right;
let merged_id = *merged_id;
let mut merged_bytes = vocab
.get(left as usize)
.ok_or_else(|| format!("invalid token id {left} in merge"))?
.clone();
merged_bytes.extend(
vocab
.get(right as usize)
.ok_or_else(|| format!("invalid token id {right} in merge"))?,
);
if vocab.len() <= merged_id as usize {
vocab.resize(merged_id as usize + 1, Vec::new());
}
vocab[merged_id as usize] = merged_bytes;
}
let mut bytes = Vec::new();
for &id in ids {
let token_bytes = vocab
.get(id as usize)
.ok_or_else(|| format!("unknown token id: {id}"))?;
bytes.extend(token_bytes);
}
String::from_utf8(bytes).map_err(|err| format!("decoded bytes are not valid UTF-8: {err}"))
}
pub fn get_pattern(&self) -> String {
self.pattern.clone()
}
pub fn vocab_size(&self) -> u32 {
256 + self.merges.len() as u32
}
pub fn get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
let mut mergeable_ranks = Vec::new();
let mut token_bytes: Vec<Vec<u8>> = (0..256_u32).map(|i| vec![i as u8]).collect();
for (i, bytes) in token_bytes.iter().enumerate() {
mergeable_ranks.push((bytes.clone(), i as u32));
}
let mut sorted_merges: Vec<_> = self.merges.iter().collect();
sorted_merges.sort_by_key(|&(_, &token_id)| token_id);
for (&pair, &merged_id) in sorted_merges {
let (left, right) = pair;
let mut merged_bytes = token_bytes[left as usize].clone();
merged_bytes.extend(&token_bytes[right as usize]);
if token_bytes.len() <= merged_id as usize {
token_bytes.resize(merged_id as usize + 1, Vec::new());
}
token_bytes[merged_id as usize] = merged_bytes.clone();
mergeable_ranks.push((merged_bytes, merged_id));
}
mergeable_ranks
}
pub fn encode(&self, text: &str) -> Vec<u32> {
let mut all_ids = Vec::new();
for m in self.compiled_pattern.find_iter(text) {
let chunk = match m {
Ok(mat) => mat.as_str(),
Err(err) => {
log::warn!("Regex match error, skipping chunk: {err}");
continue;
}
};
let mut ids: Vec<u32> = chunk.bytes().map(|byte| byte as u32).collect();
while ids.len() >= 2 {
let mut best_pair: Option<(usize, Pair, u32)> = None;
for i in 0..ids.len() - 1 {
let pair: Pair = (ids[i], ids[i + 1]);
if let Some(&new_id) = self.merges.get(&pair) {
let is_better = match best_pair {
Some((_, _, best_id)) => new_id < best_id,
None => true,
};
if is_better {
best_pair = Some((i, pair, new_id));
}
}
}
if let Some((idx, _pair, new_id)) = best_pair {
ids[idx] = new_id;
ids.remove(idx + 1);
} else {
break;
}
}
all_ids.extend(ids);
}
all_ids
}
}
#[cfg(feature = "python-bindings")]
#[pymethods]
impl Tokenizer {
#[new]
pub fn py_new() -> Self {
Self::new()
}
#[pyo3(signature = (iterator, vocab_size, buffer_size=8192, pattern=None))]
#[pyo3(text_signature = "(self, iterator, vocab_size, buffer_size=8192, pattern=None)")]
pub fn train_from_iterator(
&mut self,
py: pyo3::Python<'_>,
iterator: &pyo3::Bound<'_, pyo3::PyAny>,
vocab_size: u32,
buffer_size: usize,
pattern: Option<String>,
) -> PyResult<()> {
let pattern_str = pattern.unwrap_or_else(|| GPT4_PATTERN.to_string());
self.pattern = pattern_str.clone();
self.compiled_pattern = Regex::new(&pattern_str).map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e))
})?;
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
};
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
let mut buf: Vec<String> = Vec::with_capacity(buffer_size);
log::info!(
"Processing sequences from iterator (buffer_size: {})",
buffer_size
);
let mut total_sequences = 0u64;
let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
pyo3::Python::attach(|py| {
buf.clear();
let it = py_iter.bind(py);
loop {
if buf.len() >= buffer_size {
return Ok(false);
}
let next_obj = unsafe {
pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
};
match next_obj {
Some(obj) => {
let s: String = obj.extract()?;
buf.push(s);
}
None => {
if pyo3::PyErr::occurred(py) {
return Err(pyo3::PyErr::fetch(py));
} else {
return Ok(true); }
}
}
}
})
};
loop {
let exhausted = refill(&mut buf)?;
if buf.is_empty() && exhausted {
break;
}
total_sequences += buf.len() as u64;
let pattern = self.compiled_pattern.clone();
let local: AHashMap<CompactString, i32> = py.detach(|| {
buf.par_iter()
.map(|s| {
let mut m: AHashMap<CompactString, i32> = AHashMap::new();
for mat in pattern.find_iter(s) {
let piece = mat.expect("regex match failed").as_str();
*m.entry(CompactString::from(piece)).or_default() += 1;
}
m
})
.reduce(AHashMap::new, |mut a, b| {
for (k, v) in b {
*a.entry(k).or_default() += v;
}
a
})
});
for (k, v) in local {
*counts.entry(k).or_default() += v;
}
if exhausted {
break;
}
}
log::info!(
"Processed {} sequences total, {} unique",
total_sequences,
counts.len()
);
let mut words = Vec::with_capacity(counts.len());
let mut cvec = Vec::with_capacity(counts.len());
for (chunk, c) in counts.into_iter() {
words.push(Word::new(
chunk.as_bytes().iter().map(|&b| b as u32).collect(),
));
cvec.push(c);
}
self.train_core_incremental(words, cvec, vocab_size);
Ok(())
}
#[pyo3(name = "get_pattern")]
pub fn py_get_pattern(&self) -> String {
self.pattern.clone()
}
#[getter]
pub fn py_vocab_size(&self) -> u32 {
self.vocab_size()
}
#[pyo3(name = "get_mergeable_ranks")]
pub fn py_get_mergeable_ranks(&self) -> Vec<(Vec<u8>, u32)> {
self.get_mergeable_ranks()
}
#[pyo3(name = "encode")]
pub fn py_encode(&self, text: &str) -> Vec<u32> {
self.encode(text)
}
#[pyo3(name = "decode")]
pub fn py_decode(&self, ids: Vec<u32>) -> PyResult<String> {
self.decode_to_string(&ids)
.map_err(|err| pyo3::exceptions::PyValueError::new_err(format!("decode failed: {err}")))
}
#[pyo3(signature = (texts))]
#[pyo3(text_signature = "(self, texts)")]
pub fn batch_encode(&self, py: Python<'_>, texts: Vec<String>) -> PyResult<Vec<Vec<u32>>> {
let results = py.detach(|| {
texts
.par_iter()
.map(|text| self.encode(text))
.collect::<Vec<Vec<u32>>>()
});
Ok(results)
}
}
#[cfg(feature = "python-bindings")]
#[pymodule]
fn rustbpe(m: &Bound<'_, PyModule>) -> PyResult<()> {
pyo3_log::init(); m.add_class::<Tokenizer>()?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_word_pairs() {
let word = Word::new(vec![1, 2, 3, 4]);
let pairs: Vec<Pair> = word.pairs().collect();
assert_eq!(pairs, vec![(1, 2), (2, 3), (3, 4)]);
}
#[test]
fn test_word_pairs_empty() {
let word = Word::new(vec![]);
let pairs: Vec<Pair> = word.pairs().collect();
assert!(pairs.is_empty());
}
#[test]
fn test_word_pairs_single() {
let word = Word::new(vec![42]);
let pairs: Vec<Pair> = word.pairs().collect();
assert!(pairs.is_empty());
}
#[test]
fn test_word_merge_pair() {
let mut word = Word::new(vec![1, 2, 3, 1, 2]);
let _deltas = word.merge_pair((1, 2), 99);
assert_eq!(word.ids, vec![99, 3, 99]);
}
#[test]
fn test_word_merge_pair_adjacent() {
let mut word = Word::new(vec![1, 2, 1, 2, 1, 2]);
let _deltas = word.merge_pair((1, 2), 99);
assert_eq!(word.ids, vec![99, 99, 99]);
}
#[test]
fn test_word_merge_no_match() {
let mut word = Word::new(vec![1, 2, 3]);
let deltas = word.merge_pair((4, 5), 99);
assert_eq!(word.ids, vec![1, 2, 3]); assert!(deltas.is_empty() || deltas.iter().all(|(_, d)| *d == 0));
}
#[test]
fn test_tokenizer_new() {
let tok = Tokenizer::new();
assert!(tok.merges.is_empty());
assert!(tok.pattern.is_empty());
}
#[test]
fn test_encode_untrained_simple() {
let tok = Tokenizer::new();
let ids = tok.encode("hello");
assert!(ids.is_empty()); }
#[test]
fn test_encode_with_pattern_no_merges() {
let tok = Tokenizer {
merges: StdHashMap::new(),
pattern: r"\w+".to_string(),
compiled_pattern: Regex::new(r"\w+").unwrap(),
};
let ids = tok.encode("hi");
assert_eq!(ids, vec![104, 105]);
}
#[test]
fn test_encode_with_merges() {
let mut merges = StdHashMap::new();
merges.insert((104, 105), 256);
let tok = Tokenizer {
merges,
pattern: r"\w+".to_string(),
compiled_pattern: Regex::new(r"\w+").unwrap(),
};
let ids = tok.encode("hi");
assert_eq!(ids, vec![256]);
let ids2 = tok.encode("hip");
assert_eq!(ids2, vec![256, 112]);
}
#[test]
fn test_get_mergeable_ranks_empty() {
let tok = Tokenizer::new();
let ranks = tok.get_mergeable_ranks();
assert_eq!(ranks.len(), 256);
assert_eq!(ranks[0], (vec![0u8], 0));
assert_eq!(ranks[255], (vec![255u8], 255));
}
#[test]
fn test_get_mergeable_ranks_with_merge() {
let mut merges = StdHashMap::new();
merges.insert((65, 66), 256);
let tok = Tokenizer {
merges,
pattern: String::new(),
compiled_pattern: Regex::new("").unwrap(),
};
let ranks = tok.get_mergeable_ranks();
assert_eq!(ranks.len(), 257);
let last = &ranks[256];
assert_eq!(last.0, vec![65u8, 66u8]);
assert_eq!(last.1, 256);
}
#[test]
fn test_count_pairs_parallel() {
let words = vec![Word::new(vec![1, 2, 3]), Word::new(vec![1, 2, 4])];
let counts = vec![1, 2];
let (pair_counts, positions) = count_pairs_parallel(&words, &counts);
assert_eq!(pair_counts.get(&(1, 2)), Some(&3));
assert_eq!(pair_counts.get(&(2, 3)), Some(&1));
assert_eq!(pair_counts.get(&(2, 4)), Some(&2));
assert!(positions.get(&(1, 2)).unwrap().contains(&0));
assert!(positions.get(&(1, 2)).unwrap().contains(&1));
}
#[test]
fn test_train_core_incremental() {
let mut tok = Tokenizer {
merges: StdHashMap::new(),
pattern: String::new(),
compiled_pattern: Regex::new("").unwrap(),
};
let words = vec![
Word::new(vec![97, 98]), Word::new(vec![99, 100]), ];
let counts = vec![10, 5];
tok.train_core_incremental(words, counts, 257);
assert_eq!(tok.merges.len(), 1);
assert!(tok.merges.contains_key(&(97, 98)));
assert_eq!(tok.merges.get(&(97, 98)), Some(&256));
}
#[test]
fn test_default_trait() {
let tok = Tokenizer::default();
assert!(tok.merges.is_empty());
assert!(tok.pattern.is_empty());
}
#[test]
fn test_vocab_size() {
let mut tok = Tokenizer::new();
assert_eq!(tok.vocab_size(), 256);
tok.merges.insert((97, 98), 256);
assert_eq!(tok.vocab_size(), 257);
tok.merges.insert((256, 99), 257);
assert_eq!(tok.vocab_size(), 258);
}
#[test]
fn test_word_merge_overlapping_pairs() {
let mut word = Word::new(vec![97, 97, 97]);
let _deltas = word.merge_pair((97, 97), 256);
assert_eq!(word.ids, vec![256, 97]);
}
#[test]
fn test_word_merge_overlapping_pairs_even() {
let mut word = Word::new(vec![97, 97, 97, 97]);
let _deltas = word.merge_pair((97, 97), 256);
assert_eq!(word.ids, vec![256, 256]);
}
#[test]
fn test_word_merge_multiple_occurrences() {
let mut word = Word::new(vec![1, 2, 99, 1, 2]);
let deltas = word.merge_pair((1, 2), 256);
assert_eq!(word.ids, vec![256, 99, 256]);
let ab_removals: i32 = deltas
.iter()
.filter(|(p, _)| *p == (1, 2))
.map(|(_, d)| d)
.sum();
assert_eq!(ab_removals, -2); }
#[test]
fn test_encode_chained_merges() {
let mut merges = StdHashMap::new();
merges.insert((97, 97), 256); merges.insert((256, 97), 257);
let tok = Tokenizer {
merges,
pattern: r"\w+".to_string(),
compiled_pattern: Regex::new(r"\w+").unwrap(),
};
let ids = tok.encode("aaa");
assert_eq!(ids, vec![257]);
let ids = tok.encode("aaaa");
assert_eq!(ids, vec![256, 256]);
let ids = tok.encode("aaaaa");
assert_eq!(ids, vec![256, 257]);
}
#[test]
fn test_encode_decode_roundtrip_simple() {
let mut merges = StdHashMap::new();
merges.insert((104, 105), 256);
let tok = Tokenizer {
merges,
pattern: r"\w+|\s+".to_string(),
compiled_pattern: Regex::new(r"\w+|\s+").unwrap(),
};
let text = "hi";
let ids = tok.encode(text);
let decoded = tok.decode(ids).unwrap();
assert_eq!(decoded, text);
}
#[test]
fn test_encode_decode_roundtrip_with_spaces() {
let mut merges = StdHashMap::new();
merges.insert((104, 101), 256); merges.insert((108, 108), 257); merges.insert((256, 257), 258);
let tok = Tokenizer {
merges,
pattern: r"\w+|\s+".to_string(),
compiled_pattern: Regex::new(r"\w+|\s+").unwrap(),
};
let text = "hello world";
let ids = tok.encode(text);
let decoded = tok.decode(ids).unwrap();
assert_eq!(decoded, text);
}
#[test]
fn test_decode_byte_level() {
let tok = Tokenizer {
merges: StdHashMap::new(),
pattern: String::new(),
compiled_pattern: Regex::new("").unwrap(),
};
let decoded = tok.decode(vec![104, 105]).unwrap();
assert_eq!(decoded, "hi");
}
#[test]
fn test_decode_invalid_token() {
let tok = Tokenizer::new();
let result = tok.decode(vec![300]);
assert!(result.is_err());
}
#[test]
fn test_train_multiple_merges() {
let mut tok = Tokenizer {
merges: StdHashMap::new(),
pattern: String::new(),
compiled_pattern: Regex::new("").unwrap(),
};
let words = vec![
Word::new(vec![97, 98]), ];
let counts = vec![10];
tok.train_core_incremental(words, counts, 258);
assert_eq!(tok.merges.len(), 1);
}
#[test]
fn test_train_creates_chained_merges() {
let mut tok = Tokenizer {
merges: StdHashMap::new(),
pattern: String::new(),
compiled_pattern: Regex::new("").unwrap(),
};
let words = vec![Word::new(vec![97, 97, 97])];
let counts = vec![10];
tok.train_core_incremental(words, counts, 258);
assert_eq!(tok.merges.len(), 2);
assert_eq!(tok.merges.get(&(97, 97)), Some(&256));
assert_eq!(tok.merges.get(&(256, 97)), Some(&257));
}
#[test]
fn test_get_mergeable_ranks_chained() {
let mut merges = StdHashMap::new();
merges.insert((65, 66), 256); merges.insert((256, 67), 257);
let tok = Tokenizer {
merges,
pattern: String::new(),
compiled_pattern: Regex::new("").unwrap(),
};
let ranks = tok.get_mergeable_ranks();
assert_eq!(ranks.len(), 258);
assert_eq!(ranks[256], (vec![65u8, 66u8], 256));
assert_eq!(ranks[257], (vec![65u8, 66u8, 67u8], 257));
}
#[test]
fn test_encode_empty_string() {
let tok = Tokenizer {
merges: StdHashMap::new(),
pattern: r"\w+".to_string(),
compiled_pattern: Regex::new(r"\w+").unwrap(),
};
let ids = tok.encode("");
assert!(ids.is_empty());
}
#[test]
fn test_encode_no_matches() {
let tok = Tokenizer {
merges: StdHashMap::new(),
pattern: r"\w+".to_string(),
compiled_pattern: Regex::new(r"\w+").unwrap(),
};
let ids = tok.encode(" "); assert!(ids.is_empty());
}
#[test]
fn test_decode_empty() {
let tok = Tokenizer::new();
let decoded = tok.decode(vec![]).unwrap();
assert_eq!(decoded, "");
}
#[test]
fn test_word_merge_deltas_correctness() {
let mut word = Word::new(vec![1, 2, 3, 1, 2]);
let deltas = word.merge_pair((1, 2), 99);
let mut delta_map: StdHashMap<Pair, i32> = StdHashMap::new();
for (pair, delta) in deltas {
*delta_map.entry(pair).or_default() += delta;
}
assert_eq!(delta_map.get(&(1, 2)), Some(&-2));
assert_eq!(delta_map.get(&(2, 3)), Some(&-1));
assert_eq!(delta_map.get(&(3, 1)), Some(&-1));
assert_eq!(delta_map.get(&(99, 3)), Some(&1));
assert_eq!(delta_map.get(&(3, 99)), Some(&1));
}
#[test]
fn test_count_pairs_parallel_empty() {
let words: Vec<Word> = vec![];
let counts: Vec<i32> = vec![];
let (pair_counts, positions) = count_pairs_parallel(&words, &counts);
assert!(pair_counts.is_empty());
assert!(positions.is_empty());
}
#[test]
fn test_count_pairs_parallel_zero_count() {
let words = vec![Word::new(vec![1, 2, 3])];
let counts = vec![0];
let (pair_counts, _positions) = count_pairs_parallel(&words, &counts);
assert!(pair_counts.is_empty());
}
}