import regex as re
from collections import Counter, defaultdict
import time
import warnings
import rustbpe
import tiktoken
import pytest
# Default regex for splitting text into chunks before byte-pair merging.
# It is the fallback pattern for RegexTokenizer / FastRegexTokenizer and is
# also handed to the HuggingFace pre-tokenizer below. The pattern relies on
# \p{...} Unicode property classes and possessive quantifiers (`?+`, `++`),
# and is compiled through the `regex` package imported above as `re`.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
def get_stats(ids, counts=None):
    """Tally how often each adjacent pair of ids occurs.

    Args:
        ids: Sequence of token ids.
        counts: Optional dict to update in place; a fresh dict is used
            when omitted.

    Returns:
        The dict mapping ``(left, right)`` pairs to occurrence counts
        (the same object as ``counts`` when one was supplied).
    """
    if counts is None:
        counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts
def merge(ids, pair, idx):
    """Return a new list where each occurrence of ``pair`` becomes ``idx``.

    The scan is left to right and non-overlapping: once a pair is replaced,
    scanning resumes after both of its elements. ``ids`` is not modified.
    """
    out = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        # A match needs a right-hand neighbour, so the final slot never matches.
        if pos < last and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            out.append(idx)
            pos += 2
            continue
        out.append(ids[pos])
        pos += 1
    return out
class RegexTokenizer:
    """Straightforward (slow) byte-level BPE tokenizer used as a reference.

    Text is split into chunks by a regex, every chunk is converted to its
    UTF-8 bytes, and merges are learned and applied inside chunks only.
    """

    def __init__(self, pattern=None):
        """Create an untrained tokenizer.

        Args:
            pattern: Optional split regex; GPT4_SPLIT_PATTERN is used when
                omitted.
        """
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.merges = {}  # (int, int) -> int
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        """Derive the id -> bytes table from ``merges`` and ``special_tokens``."""
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # Relies on merges being stored in creation order so that both
        # halves of a merged token are already present in the table.
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        Pair statistics are recomputed from scratch for every merge. Training
        stops early when no adjacent pairs are left to merge, in which case
        fewer merges than requested are learned.

        Args:
            text: Training text.
            vocab_size: Target vocabulary size; must be at least 256.
            verbose: Print one line per merge when true.

        Returns:
            True if, at any step, more than one pair shared the maximum
            count (the merge order then depends on tie-breaking).
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        ambiguous = False

        text_chunks = re.findall(self.compiled_pattern, text)
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
                get_stats(chunk_ids, stats)
            if not stats:
                # Every chunk is down to a single token: nothing left to
                # merge. max() on an empty dict would raise ValueError.
                break
            pair = max(stats, key=stats.get)
            pair_count = stats[pair]
            ties = sum(1 for count in stats.values() if count == pair_count)
            if ties > 1:
                ambiguous = True
            idx = 256 + i
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        self.merges = merges
        self.vocab = vocab
        return ambiguous

    def _encode_chunk(self, text_bytes):
        """Encode the bytes of one chunk into token ids."""
        ids = list(text_bytes)
        while len(ids) >= 2:
            stats = get_stats(ids)
            # Choose the present pair with the lowest merge id (i.e. the one
            # learned earliest); pairs without a merge rank as infinity.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encode ``text`` into token ids; special tokens are not handled."""
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids
def fast_merge_inplace(ids, pair, idx):
    """Replace each occurrence of ``pair`` in ``ids`` with ``idx``, in place.

    The list is mutated and also returned. After a replacement the scan
    stays on the same position instead of stepping past it.
    """
    pos = 0
    while pos + 1 < len(ids):
        if not (ids[pos] == pair[0] and ids[pos + 1] == pair[1]):
            pos += 1
            continue
        ids[pos] = idx
        del ids[pos + 1]
    return ids
class FastRegexTokenizer:
    """Faster reference BPE tokenizer with incremental pair statistics.

    ``train`` collapses identical chunks into one weighted entry and, after
    each merge, adjusts pair counts only for the chunks that contained the
    merged pair instead of recounting everything.
    """

    def __init__(self, pattern=None):
        """Create an untrained tokenizer.

        Args:
            pattern: Optional split regex; GPT4_SPLIT_PATTERN is used when
                omitted.
        """
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {}
        self.inverse_special_tokens = {}
        self.merges = {}
        self.vocab = self._build_vocab()

    def _build_vocab(self):
        """Derive the id -> bytes table from ``merges`` and ``special_tokens``."""
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        Training stops early once no adjacent pair with a positive count
        remains, so fewer merges than requested may be learned.

        Args:
            text: Training text.
            vocab_size: Target vocabulary size; must be at least 256.
            verbose: Accepted for signature parity with RegexTokenizer.train;
                not used by this implementation.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # De-duplicate chunks; each unique chunk carries its multiplicity.
        text_chunks = re.findall(self.compiled_pattern, text)
        counts = Counter(text_chunks)
        unique_chunks = list(counts.keys())
        chunk_counts = list(counts.values())
        ids = [list(ch.encode("utf-8")) for ch in unique_chunks]

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}

        # stats: pair -> weighted occurrence count
        # positions: pair -> indices of the unique chunks that contain it
        stats = defaultdict(int)
        positions = defaultdict(set)
        for chunk_idx, (chunk_ids, count) in enumerate(zip(ids, chunk_counts)):
            for pair in zip(chunk_ids, chunk_ids[1:]):
                stats[pair] += count
                positions[pair].add(chunk_idx)

        for i in range(num_merges):
            if not stats:
                break
            pair = max(stats, key=stats.get)
            idx = 256 + i
            affected_chunks = positions[pair]

            # Apply the merge in place and record the net count change of
            # every pair that touches a merge site.
            count_changes = defaultdict(int)
            for chunk_idx in affected_chunks:
                chunk_ids = ids[chunk_idx]
                chunk_count = chunk_counts[chunk_idx]
                ix = 0
                while ix < len(chunk_ids) - 1:
                    if chunk_ids[ix] == pair[0] and chunk_ids[ix+1] == pair[1]:
                        # Pairs destroyed by this replacement.
                        if ix > 0:
                            old_left = (chunk_ids[ix-1], chunk_ids[ix])
                            count_changes[old_left] -= chunk_count
                        count_changes[pair] -= chunk_count
                        if ix + 2 < len(chunk_ids):
                            old_right = (chunk_ids[ix+1], chunk_ids[ix+2])
                            count_changes[old_right] -= chunk_count
                        chunk_ids[ix] = idx
                        chunk_ids.pop(ix+1)
                        # Pairs created around the new token.
                        if ix > 0:
                            new_left = (chunk_ids[ix-1], chunk_ids[ix])
                            count_changes[new_left] += chunk_count
                        if ix + 1 < len(chunk_ids):
                            new_right = (chunk_ids[ix], chunk_ids[ix+1])
                            count_changes[new_right] += chunk_count
                    else:
                        ix += 1

            for changed_pair, delta in count_changes.items():
                if changed_pair == pair:
                    # The merged pair itself is removed wholesale below.
                    continue
                new_count = stats.get(changed_pair, 0) + delta
                if new_count > 0:
                    stats[changed_pair] = new_count
                else:
                    # Drop exhausted pairs. Leaving zero-count entries in
                    # stats would keep it non-empty forever and let max()
                    # select pairs that no longer occur anywhere.
                    stats.pop(changed_pair, None)
                for chunk_idx in affected_chunks:
                    chunk_ids = ids[chunk_idx]
                    contains_pair = any(
                        (chunk_ids[j], chunk_ids[j+1]) == changed_pair
                        for j in range(len(chunk_ids) - 1)
                    )
                    if contains_pair:
                        positions[changed_pair].add(chunk_idx)
                    else:
                        positions[changed_pair].discard(chunk_idx)
                if not positions[changed_pair]:
                    del positions[changed_pair]

            del stats[pair]
            del positions[pair]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

        self.merges = merges
        self.vocab = vocab

    def register_special_tokens(self, special_tokens):
        """Store a ``str -> int`` special-token mapping and its inverse."""
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        """Turn token ids back into text.

        Bytes that do not form valid UTF-8 are substituted with the
        replacement character.

        Raises:
            ValueError: If an id is neither in the vocab nor a registered
                special token.
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def _encode_chunk(self, text_bytes):
        """Encode the bytes of one chunk into token ids."""
        ids = list(text_bytes)
        while len(ids) >= 2:
            stats = get_stats(ids)
            # Choose the present pair with the lowest merge id; pairs
            # without a merge rank as infinity.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break
            idx = self.merges[pair]
            ids = fast_merge_inplace(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encode ``text`` into token ids; special tokens are not handled."""
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = []
        for chunk in text_chunks:
            chunk_bytes = chunk.encode("utf-8")
            chunk_ids = self._encode_chunk(chunk_bytes)
            ids.extend(chunk_ids)
        return ids
from tokenizers import Tokenizer as HFTokenizer
from tokenizers import pre_tokenizers, decoders, Regex
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
class HuggingFaceTokenizer:
    """Thin wrapper around a HuggingFace ``tokenizers`` Tokenizer object."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    @classmethod
    def train_from_iterator(cls, text_iterator, vocab_size):
        """Train a BPE tokenizer on ``text_iterator`` and wrap the result.

        The tokenizer is configured with no normalizer, a regex split on
        GPT4_SPLIT_PATTERN followed by a ByteLevel pre-tokenizer, a
        ByteLevel decoder, and no post-processor or special tokens.
        """
        bpe_model = BPE(
            byte_fallback=True,
            unk_token=None,
            fuse_unk=False,
        )
        hf = HFTokenizer(bpe_model)
        hf.normalizer = None

        split_step = pre_tokenizers.Split(
            pattern=Regex(GPT4_SPLIT_PATTERN),
            behavior="isolated",
            invert=False,
        )
        byte_step = pre_tokenizers.ByteLevel(add_prefix_space=False, use_regex=False)
        hf.pre_tokenizer = pre_tokenizers.Sequence([split_step, byte_step])

        hf.decoder = decoders.ByteLevel()
        hf.post_processor = None

        bpe_trainer = BpeTrainer(
            vocab_size=vocab_size,
            show_progress=True,
            min_frequency=0,
            initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
            special_tokens=[],
        )
        hf.train_from_iterator(text_iterator, bpe_trainer)
        return cls(hf)

    def encode_ordinary(self, text):
        """Encode ``text`` to token ids without adding special tokens."""
        return self.tokenizer.encode(text, add_special_tokens=False).ids
def get_cache_dir():
    """Return the ``rustbpe`` cache directory, creating it if necessary.

    The directory lives under ``$XDG_CACHE_HOME`` when that variable is
    set, and under ``~/.cache`` otherwise.
    """
    import os

    fallback_home = os.path.expanduser("~/.cache")
    root = os.environ.get("XDG_CACHE_HOME", fallback_home)
    target = os.path.join(root, "rustbpe")
    os.makedirs(target, exist_ok=True)
    return target
@pytest.fixture(scope="module")
def enwik8_path():
    """Fixture returning the local path of enwik8, downloading it if absent.

    The archive is fetched into the cache directory, unzipped there, and the
    zip file is removed afterwards.

    Raises:
        requests.HTTPError: If the download responds with an error status.
    """
    import os
    import zipfile
    base_dir = get_cache_dir()
    enwik8_url = "https://mattmahoney.net/dc/enwik8.zip"
    enwik8_local_path = os.path.join(base_dir, "enwik8")
    enwik8_local_path_zip = os.path.join(base_dir, "enwik8.zip")
    if not os.path.exists(enwik8_local_path):
        print(f"Downloading enwik8 to {enwik8_local_path_zip}")
        import requests
        # Bound the wait so a stalled server cannot hang the test run.
        response = requests.get(enwik8_url, timeout=60)
        # Fail here on an HTTP error rather than writing an error page to
        # disk and hitting a confusing BadZipFile further down.
        response.raise_for_status()
        with open(enwik8_local_path_zip, "wb") as f:
            f.write(response.content)
        with zipfile.ZipFile(enwik8_local_path_zip, "r") as zip_ref:
            zip_ref.extractall(base_dir)
        print(f"Unzipped enwik8 to {enwik8_local_path}")
        os.remove(enwik8_local_path_zip)
        print(f"Removed {enwik8_local_path_zip}")
    else:
        print(f"Using existing enwik8 at {enwik8_local_path}")
    return enwik8_local_path
@pytest.fixture(scope="module")
def enwik8_small(enwik8_path):
    """Fixture providing the first 100,000 characters of enwik8."""
    char_limit = 100_000
    with open(enwik8_path, "r", encoding="utf-8") as handle:
        prefix = handle.read(char_limit)
    return prefix
@pytest.fixture(scope="module")
def enwik8_large(enwik8_path):
    """Fixture providing the first 10**7 characters of enwik8."""
    char_limit = 10**7
    with open(enwik8_path, "r", encoding="utf-8") as handle:
        prefix = handle.read(char_limit)
    return prefix
def time_function(func, *args, **kwargs):
    """Call ``func(*args, **kwargs)`` and measure how long it takes.

    Uses ``time.perf_counter()``, a monotonic high-resolution clock, so the
    measurement is unaffected by system clock adjustments and does not
    collapse to zero for short calls the way a coarse wall clock can.

    Returns:
        A ``(result, elapsed_seconds)`` tuple.
    """
    start_time = time.perf_counter()
    result = func(*args, **kwargs)
    elapsed = time.perf_counter() - start_time
    return result, elapsed
def test_correctness(enwik8_small):
    """Check that every tokenizer implementation encodes the text identically.

    Trains the slow reference, the fast reference, HuggingFace and rustbpe
    on the same text with the same vocab size, compares their encodings,
    and finally checks that a tiktoken encoding built from the rustbpe
    export reproduces the rustbpe ids.
    """
    text = enwik8_small
    encode_text = text
    vocab_size = 256 + 20

    print("\nTraining slow reference...")
    slow_reference_tokenizer = RegexTokenizer()
    ambiguous_flag, slow_reference_train_time = time_function(slow_reference_tokenizer.train, text, vocab_size)
    slow_reference_ids, slow_reference_encode_time = time_function(slow_reference_tokenizer.encode_ordinary, encode_text)
    print(f"Slow reference train time: {slow_reference_train_time:.4f}s")
    print(f"Slow reference encode time: {slow_reference_encode_time:.4f}s")
    print(slow_reference_ids[:20])

    if ambiguous_flag:
        print("‼️ WARNING: merge order was detected to be ambiguous given current text and vocab size")
        print("The implementation could be correct but we might see different results below")
    else:
        print("✅ Merge order is NOT ambiguous")

    print("\nTraining fast reference...")
    fast_reference_tokenizer = FastRegexTokenizer()
    _, fast_reference_train_time = time_function(fast_reference_tokenizer.train, text, vocab_size)
    fast_reference_ids, fast_reference_encode_time = time_function(fast_reference_tokenizer.encode_ordinary, encode_text)
    print(f"Fast reference train time: {fast_reference_train_time:.4f}s")
    print(f"Fast reference encode time: {fast_reference_encode_time:.4f}s")
    print(fast_reference_ids[:20])
    assert fast_reference_ids == slow_reference_ids, "Fast reference should match slow reference"
    print("✅ Fast == Slow")

    print("\nTraining HuggingFace...")
    hf_tokenizer, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [text], vocab_size)
    hf_ids, hf_encode_time = time_function(hf_tokenizer.encode_ordinary, encode_text)
    print(f"HuggingFace train time: {hf_train_time:.4f}s")
    print(f"HuggingFace encode time: {hf_encode_time:.4f}s")
    print(hf_ids[:20])

    def custom_match(ids1, ids2):
        """Compare two id sequences, tolerating a relabelling of ids < 256.

        Ids below 256 in ``ids1`` only have to map consistently onto ids in
        ``ids2``; ids of 256 and above must be equal. (Presumably needed
        because HuggingFace numbers its base byte tokens differently.)
        """
        # zip() stops at the shorter input, so without this check a length
        # mismatch would be silently ignored.
        if len(ids1) != len(ids2):
            return False
        perm = {}
        for x, y in zip(ids1, ids2):
            if x < 256:
                if x in perm:
                    if perm[x] != y:
                        return False
                perm[x] = y
            if x >= 256 and x != y:
                return False
        return True

    assert custom_match(hf_ids, fast_reference_ids), "HuggingFace should match fast reference"
    print("✅ HuggingFace == Fast")

    print("\nTraining rustbpe...")
    rustbpe_tokenizer = rustbpe.Tokenizer()
    _, rustbpe_train_time = time_function(rustbpe_tokenizer.train_from_iterator, [text], vocab_size)
    rustbpe_ids, rustbpe_encode_time = time_function(rustbpe_tokenizer.encode, encode_text)
    print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
    print(f"RustBPE encode time: {rustbpe_encode_time:.4f}s")
    print(rustbpe_ids[:20])
    assert rustbpe_ids == fast_reference_ids, "RustBPE should match fast reference"
    print("✅ RustBPE == Fast")

    print("\nTesting tiktoken export...")
    pattern = rustbpe_tokenizer.get_pattern()
    mergeable_ranks_list = rustbpe_tokenizer.get_mergeable_ranks()
    mergeable_ranks = {bytes(k): v for k, v in mergeable_ranks_list}
    enc = tiktoken.Encoding(
        name="rustbpe",
        pat_str=pattern,
        mergeable_ranks=mergeable_ranks,
        special_tokens={},
    )
    tiktoken_ids, tiktoken_encode_time = time_function(enc.encode, encode_text)
    print(f"Tiktoken encode time: {tiktoken_encode_time:.4f}s")
    print(tiktoken_ids[:20])
    assert tiktoken_ids == rustbpe_ids, "Tiktoken should match RustBPE"
    print("✅ Tiktoken == RustBPE")
@pytest.mark.slow
def test_training_performance(enwik8_large):
    """Time rustbpe training against HuggingFace training on the large text."""
    corpus = enwik8_large
    target_vocab = 2048
    print(f"\nText length: {len(corpus)}")

    # rustbpe
    print("\nTraining rustbpe...")
    rust_tok = rustbpe.Tokenizer()
    _, rustbpe_train_time = time_function(rust_tok.train_from_iterator, [corpus], target_vocab)
    print(f"RustBPE train time: {rustbpe_train_time:.4f}s")
    assert rustbpe_train_time > 0, "Training should take some time"

    # HuggingFace
    print("\nTraining HuggingFace...")
    _hf_tok, hf_train_time = time_function(HuggingFaceTokenizer.train_from_iterator, [corpus], target_vocab)
    print(f"HuggingFace train time: {hf_train_time:.4f}s")
    assert hf_train_time > 0, "Training should take some time"

    # Summary
    ratio = hf_train_time / rustbpe_train_time
    print(f"\n📊 Performance comparison:")
    print(f" RustBPE: {rustbpe_train_time:.4f}s")
    print(f" HuggingFace: {hf_train_time:.4f}s")
    print(f" Speedup: {ratio:.2f}x")
def test_batch_encode_correctness(enwik8_small):
    """batch_encode() must return the same ids as encoding each text alone."""
    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator([enwik8_small], 512)

    samples = [
        "Hello world",
        "The quick brown fox",
        "jumps over the lazy dog",
        "",
        "a",
    ]
    one_by_one = []
    for sample in samples:
        one_by_one.append(tokenizer.encode(sample))
    all_at_once = tokenizer.batch_encode(samples)

    assert one_by_one == all_at_once, "Batch encoding should match individual encoding"
    print("✅ batch_encode() correctness verified")
def test_vocab_size():
    """vocab_size is 256 on a new tokenizer and 260 after training to 260."""
    tokenizer = rustbpe.Tokenizer()
    assert tokenizer.vocab_size == 256, "New tokenizer should have vocab_size=256"

    corpus = ["hello hello hello", "world world world"]
    tokenizer.train_from_iterator(corpus, vocab_size=260)
    assert tokenizer.vocab_size == 260, f"Expected vocab_size=260, got {tokenizer.vocab_size}"

    print("✅ vocab_size property works correctly")
def test_decode_roundtrip(enwik8_small):
    """decode(encode(s)) must give back s for sample strings and the corpus."""
    corpus = enwik8_small[:1000]
    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator([corpus], 512)

    samples = [
        "hello world",
        "The quick brown fox jumps over the lazy dog",
        "12345",
        " spaces ",
        "MixedCASE123",
        "",
    ]
    for s in samples:
        decoded = tokenizer.decode(tokenizer.encode(s))
        assert decoded == s, f"Roundtrip failed for {s!r}: got {decoded!r}"

    # The training text itself must survive the round trip as well.
    restored = tokenizer.decode(tokenizer.encode(corpus))
    assert restored == corpus, "Roundtrip failed on training text"

    print("✅ decode() roundtrip works correctly")
def test_decode_invalid_token():
    """decode() must raise ValueError for an id an untrained tokenizer lacks."""
    tokenizer = rustbpe.Tokenizer()
    # pytest.raises fails the test if nothing is raised. Unlike a bare
    # `assert False` inside try/except, it is not stripped under -O.
    with pytest.raises(ValueError) as excinfo:
        tokenizer.decode([300])
    message = str(excinfo.value)
    assert "Unknown token id" in message or "unknown" in message.lower()
    print("✅ decode() correctly rejects invalid tokens")
@pytest.mark.slow
def test_batch_encode_performance(enwik8_large):
    """Benchmark batch_encode() against a sequential encode() loop.

    Both approaches must produce identical results; a speedup below 1.5x
    only triggers a warning, not a failure.
    """
    corpus = enwik8_large
    target_vocab = 2048

    print("\nTraining tokenizer...")
    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator([corpus], target_vocab)

    # Slice the corpus into fixed-size pieces and keep the first 20.
    chunk_size = 50_000
    chunks = [corpus[start:start + chunk_size] for start in range(0, len(corpus), chunk_size)][:20]

    print(f"\nBatch encoding benchmark:")
    print(f" Number of texts: {len(chunks)}")
    print(f" Avg text length: {sum(len(c) for c in chunks) / len(chunks):.0f} chars")

    def encode_one_by_one():
        return [tokenizer.encode(chunk) for chunk in chunks]

    print("\n [1/3] Sequential encode() loop...")
    sequential_results, sequential_time = time_function(encode_one_by_one)
    print(f" Time: {sequential_time:.4f}s")

    print(" [2/3] Parallel batch_encode()...")
    batch_results, batch_time = time_function(tokenizer.batch_encode, chunks)
    print(f" Time: {batch_time:.4f}s")

    print(" [3/3] Verifying correctness...")
    assert len(batch_results) == len(sequential_results), "Result count mismatch"
    for i in range(len(sequential_results)):
        assert sequential_results[i] == batch_results[i], f"Mismatch at index {i}"
    print(" ✓ All results match")

    speedup = sequential_time / batch_time
    print(f"\n Performance Results:")
    print(f" Sequential: {sequential_time:.4f}s")
    print(f" Batch: {batch_time:.4f}s")
    print(f" Speedup: {speedup:.2f}x")

    if speedup < 1.5:
        warnings.warn(f"batch_encode() speedup was only {speedup:.2f}x (expected >1.5x)")