import numpy as np
from collections import Counter
import struct
import os
def count_bigrams_from_huggingface(target_bytes=200_000_000):
    """Count byte-bigram frequencies over streamed code from the-stack-smol.

    Streams roughly ``target_bytes`` bytes, split evenly across a fixed list
    of languages, lowercases each document, and accumulates occurrence counts
    of every adjacent byte pair into a 65536-slot table indexed
    ``(b1 << 8) | b2``.

    Args:
        target_bytes: Approximate total corpus size to process, in bytes.

    Returns:
        np.ndarray of shape (65536,), dtype int64: bigram occurrence counts.
    """
    # Imported lazily so the local-directory code path works without `datasets`.
    from datasets import load_dataset

    counts = np.zeros(65536, dtype=np.int64)
    total_bytes = 0
    languages = [
        "rust", "python", "javascript", "go", "java", "c", "cpp", "typescript",
        "shell", "sql", "html", "css", "ruby", "php",
    ]
    bytes_per_lang = target_bytes // len(languages)
    for lang in languages:
        print(f"Processing {lang}...")
        lang_bytes = 0
        try:
            ds = load_dataset(
                "bigcode/the-stack-smol",
                data_dir=f"data/{lang}",
                split="train",
                streaming=True,
            )
            for sample in ds:
                content = sample["content"]
                if not content:
                    continue
                lowered = content.lower().encode("utf-8", errors="replace")
                # Vectorized bigram counting: one bincount per document
                # instead of a Python-level loop over every byte (this loop
                # was the hot spot for a multi-hundred-MB corpus).
                if len(lowered) >= 2:
                    arr = np.frombuffer(lowered, dtype=np.uint8).astype(np.int64)
                    pairs = (arr[:-1] << 8) | arr[1:]
                    counts += np.bincount(pairs, minlength=65536)
                lang_bytes += len(lowered)
                total_bytes += len(lowered)
                if lang_bytes >= bytes_per_lang:
                    break
        except Exception as e:
            # Best-effort: a missing or failing language split should not
            # abort the whole run.
            print(f" Skipping {lang}: {e}")
            continue
        print(f" {lang}: {lang_bytes / 1e6:.1f} MB processed")
    print(f"\nTotal: {total_bytes / 1e6:.1f} MB processed")
    return counts
def count_bigrams_from_local_dirs(dirs):
    """Count byte-bigram frequencies over code files found under local dirs.

    Walks each directory, reads up to 1 MB per recognized source file,
    ASCII-lowercases the bytes, and accumulates counts of every adjacent
    byte pair into a 65536-slot table indexed ``(b1 << 8) | b2``.

    Args:
        dirs: Iterable of root directory paths to scan.

    Returns:
        np.ndarray of shape (65536,), dtype int64: bigram occurrence counts.
    """
    CODE_EXTENSIONS = {
        ".rs", ".py", ".js", ".ts", ".go", ".java", ".c", ".h",
        ".cpp", ".hpp", ".cc", ".rb", ".swift", ".kt", ".scala",
        ".cs", ".jsx", ".tsx", ".vue", ".sh", ".bash", ".zsh",
        ".toml", ".yaml", ".yml", ".json", ".md", ".txt",
    }
    counts = np.zeros(65536, dtype=np.int64)
    total_bytes = 0
    file_count = 0
    for root_dir in dirs:
        for dirpath, _, filenames in os.walk(root_dir):
            # Skip hidden directories — but "." and ".." are plain path
            # components, not hidden dirs (a bare startswith('.') check
            # made `--dirs .` scan nothing at all).
            if any(part.startswith('.') and part not in ('.', '..')
                   for part in dirpath.split(os.sep)):
                continue
            if any(skip in dirpath for skip in ["node_modules", "target", "vendor", "__pycache__", ".git"]):
                continue
            for fname in filenames:
                ext = os.path.splitext(fname)[1].lower()
                if ext not in CODE_EXTENSIONS:
                    continue
                fpath = os.path.join(dirpath, fname)
                try:
                    # Cap reads at 1 MB per file so one huge blob can't
                    # dominate the statistics.
                    with open(fpath, "rb") as f:
                        raw = f.read(1_000_000)
                except OSError:
                    continue
                arr = np.frombuffer(raw, dtype=np.uint8).astype(np.int64)
                # ASCII lowercase: fold 'A'..'Z' (65..90) onto 'a'..'z'.
                arr = np.where((arr >= 65) & (arr <= 90), arr + 32, arr)
                # Vectorized bigram counting instead of a per-byte Python loop.
                if arr.size >= 2:
                    pairs = (arr[:-1] << 8) | arr[1:]
                    counts += np.bincount(pairs, minlength=65536)
                total_bytes += arr.size
                file_count += 1
    print(f"Processed {file_count} files, {total_bytes / 1e6:.1f} MB")
    return counts
def counts_to_weights(counts):
    """Map bigram counts to u16 weights: frequent pairs low, rare pairs high.

    Counts are log-compressed and scaled into [100, 65000]; pairs that never
    occurred in the corpus are pinned to the maximum weight 65535.

    Args:
        counts: np.ndarray of shape (65536,), bigram occurrence counts.

    Returns:
        np.ndarray of shape (65536,), dtype uint16: per-bigram weights.
    """
    log_counts = np.log1p(counts.astype(np.float64))
    peak = log_counts.max()
    if peak > 0:
        scaled = log_counts / peak
        weights = (1.0 - scaled) * 64900 + 100
    else:
        # Degenerate corpus (all counts zero): everything becomes 65535 below.
        weights = np.zeros(65536, dtype=np.float64)
    weights[counts == 0] = 65535
    return np.clip(weights, 0, 65535).astype(np.uint16)
def emit_rust_const(weights, output_path="weights.rs"):
    """Write the 65536-entry weight table as a Rust `static` array.

    Emits a commented header followed by the table, 16 right-aligned values
    per row, then prints the destination path and resulting file size.

    Args:
        weights: np.ndarray of shape (65536,), dtype uint16.
        output_path: Destination file path for the generated Rust source.
    """
    lines = [
        "// Auto-generated by generate_weights.py\n",
        "// Bigram frequency weights for sparse n-gram tokenization.\n",
        "// Rare byte-pairs get high weights (good gram boundaries).\n",
        "// Common byte-pairs get low weights (gram interiors).\n",
        "//\n",
        "// Index: (byte_a << 8) | byte_b\n",
        "// Usage: BIGRAM_WEIGHTS[(b1 as usize) << 8 | (b2 as usize)]\n",
        "// Corpus: mixed open-source code (Rust, Python, JS, Go, Java, C/C++, TS)\n",
        "\n",
        "pub static BIGRAM_WEIGHTS: [u16; 65536] = [\n",
    ]
    for start in range(0, 65536, 16):
        cells = ", ".join(f"{weights[start + k]:5}" for k in range(16))
        lines.append(f" {cells},\n")
    lines.append("];\n")
    with open(output_path, "w") as out:
        out.writelines(lines)
    print(f"Written to {output_path}")
    print(f"File size: {os.path.getsize(output_path) / 1024:.0f} KB")
def print_diagnostics(counts, weights):
    """Print human-readable sanity tables for the count/weight tables.

    Shows the 20 most common and 20 rarest observed byte pairs, overall
    weight statistics, and a spot check of a few known-common and
    known-rare pairs.

    Args:
        counts: np.ndarray of shape (65536,), bigram occurrence counts.
        weights: np.ndarray of shape (65536,), dtype uint16, derived weights.
    """
    def glyph(b):
        # Printable ASCII as-is; everything else as a \xNN escape.
        return chr(b) if 32 <= b < 127 else f"\\x{b:02x}"

    print("\n=== Diagnostics ===\n")

    most_common = np.argsort(counts)[::-1][:20]
    print("Top 20 most common byte pairs (LOW weight = gram interior):")
    for idx in most_common:
        c1, c2 = glyph(idx >> 8), glyph(idx & 0xFF)
        print(f" '{c1}{c2}' count={counts[idx]:>12,} weight={weights[idx]:>5}")

    # Push zero-count slots past any real count so the smallest nonzero
    # counts sort first.
    penalty = (counts == 0).astype(np.int64) * 10**18
    rarest = np.argsort(counts + penalty)[:20]
    print("\nTop 20 rarest byte pairs (HIGH weight = good gram boundaries):")
    for idx in rarest:
        if counts[idx] == 0:
            continue
        c1, c2 = glyph(idx >> 8), glyph(idx & 0xFF)
        print(f" '{c1}{c2}' count={counts[idx]:>12,} weight={weights[idx]:>5}")

    print("\nWeight stats:")
    print(f" Min weight: {weights.min()}")
    print(f" Max weight: {weights.max()}")
    print(f" Median weight: {np.median(weights):.0f}")
    print(f" Zero-count pairs (weight=65535): {(counts == 0).sum()}")
    print(f" Non-zero pairs: {(counts > 0).sum()}")

    print("\nSanity check (common patterns should have LOW weights):")
    test_pairs = [
        (ord('e'), ord(' ')),
        (ord('r'), ord('e')),
        (ord('f'), ord('n')),
        (ord('_'), ord('_')),
        (ord('q'), ord('z')),
        (ord('x'), ord('j')),
    ]
    for b1, b2 in test_pairs:
        idx = (b1 << 8) | b2
        c1, c2 = glyph(b1), glyph(b2)
        label = 'LOW/common' if weights[idx] < 10000 else 'HIGH/rare'
        print(f" '{c1}{c2}' weight={weights[idx]:>5} ({label})")
def main():
    """CLI entry point: count bigrams, derive weights, emit the Rust table.

    Reads options from the command line, builds the bigram count table from
    either the-stack-smol (streamed) or local directories, converts counts
    to weights, prints diagnostics, and writes the Rust source file.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Generate sparse n-gram weight table")
    parser.add_argument(
        "--source",
        choices=["huggingface", "local"],
        default="huggingface",
        help="Data source: 'huggingface' (download the-stack-smol) or 'local' (scan local dirs)",
    )
    parser.add_argument(
        "--dirs",
        nargs="*",
        default=[],
        help="Local directories to scan (only used with --source local)",
    )
    parser.add_argument(
        "--target-mb",
        type=int,
        default=200,
        help="Target corpus size in MB (for huggingface source)",
    )
    parser.add_argument(
        "--output",
        default="weights.rs",
        help="Output file path",
    )
    opts = parser.parse_args()

    if opts.source == "local":
        # Local scanning needs at least one directory to walk.
        if not opts.dirs:
            print("Error: --dirs required when using --source local")
            print("Example: python generate_weights.py --source local --dirs ~/projects/linux ~/projects/rustc")
            return
        counts = count_bigrams_from_local_dirs(opts.dirs)
    else:
        counts = count_bigrams_from_huggingface(target_bytes=opts.target_mb * 1_000_000)

    weights = counts_to_weights(counts)
    print_diagnostics(counts, weights)
    emit_rust_const(weights, opts.output)
    print(f"\nDone! Copy {opts.output} to src/tokenizer/weights.rs")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()