use std::fs;
use std::io::{self, Write};
use std::path::Path;
/// Corpus directory scanned when no path is supplied on the command line.
const DEFAULT_CORPUS: &str = "/mnt/g/repos/rust";

/// Directory names that are never descended into (compared for exact equality
/// with a directory's final path component).
const SKIP_DIRS: &[&str] = &[".git", "target", "__pycache__", "node_modules", ".svn", ".hg"];

/// File extensions (lower-case, without the dot) whose files are excluded
/// from the scan.
const SKIP_EXTENSIONS: &[&str] = &[
    // Images
    "png", "jpg", "jpeg", "gif", "ico", "webp", "bmp", "tiff",
    // Fonts
    "woff", "woff2", "ttf", "eot", "otf",
    // Native binaries, libraries and build artifacts
    "exe", "dll", "so", "dylib", "a", "o", "rlib", "rmeta", "pdb", "d",
    // Archives and compressed data
    "gz", "zip", "tar", "zst", "xz", "bz2", "7z",
    // Bytecode, databases and documents
    "pyc", "db", "sqlite", "pdf", "class", "jar",
];

/// Size limit in bytes (10 MiB); files larger than this are not read.
const MAX_FILE_BYTES: u64 = 10 * 1024 * 1024;
/// Returns `true` when `name` is exactly equal to one of the entries in
/// `SKIP_DIRS`; the comparison is case-sensitive.
fn should_skip_dir(name: &str) -> bool {
    SKIP_DIRS.iter().any(|&skipped| skipped == name)
}
/// Returns `true` when `name` ends in an extension listed in `SKIP_EXTENSIONS`.
///
/// The extension is the text after the last `.` in `name` and is compared
/// ASCII-case-insensitively. A name that contains no `.` has no extension and
/// is never skipped, so extensionless files that merely happen to be called
/// `a`, `d`, `so`, ... are still scanned.
fn should_skip_file(name: &str) -> bool {
    // `rsplit_once` yields `None` when there is no dot at all, whereas
    // `rsplit('.').next()` would hand back the whole name and misclassify it
    // as an extension.
    match name.rsplit_once('.') {
        Some((_, ext)) => SKIP_EXTENSIONS
            .iter()
            .any(|skipped| skipped.eq_ignore_ascii_case(ext)),
        None => false,
    }
}
/// Recursively tallies byte values for every eligible file at or below `path`.
///
/// * Symlinks are ignored entirely (neither followed nor read).
/// * A regular file is read in full unless `should_skip_file` rejects its name
///   or its size exceeds `MAX_FILE_BYTES`; each byte increments the matching
///   slot of `counts`, and `total` grows by the file's length. Files that
///   cannot be read are silently ignored.
/// * A directory is descended into unless `should_skip_dir` rejects its name;
///   its entries are visited in file-name order.
/// * Anything else is ignored.
///
/// # Errors
///
/// Returns the error from `fs::read_dir` when a directory cannot be listed;
/// this aborts the whole traversal.
fn walk(path: &Path, counts: &mut [u64; 256], total: &mut u64) -> io::Result<()> {
    if path.is_symlink() {
        return Ok(());
    }

    // Final path component; empty when it is missing or not valid UTF-8.
    let leaf = path.file_name().and_then(|n| n.to_str()).unwrap_or("");

    if path.is_file() {
        if should_skip_file(leaf) {
            return Ok(());
        }
        // A failed stat is treated like size 0, so the read is still attempted.
        let oversized = matches!(fs::metadata(path), Ok(meta) if meta.len() > MAX_FILE_BYTES);
        if oversized {
            return Ok(());
        }
        if let Ok(bytes) = fs::read(path) {
            for byte in bytes.iter().copied() {
                counts[usize::from(byte)] += 1;
            }
            *total += bytes.len() as u64;
        }
        return Ok(());
    }

    if !path.is_dir() || should_skip_dir(leaf) {
        return Ok(());
    }

    // Entries that fail to enumerate are dropped; the rest are sorted by name
    // so the traversal order does not depend on the filesystem.
    let mut children: Vec<fs::DirEntry> = fs::read_dir(path)?.filter_map(Result::ok).collect();
    children.sort_by_cached_key(|child| child.file_name());
    for child in &children {
        walk(&child.path(), counts, total)?;
    }
    Ok(())
}
/// Renders the byte histogram `counts` as Rust source defining
/// `pub static BYTE_FREQ: [u16; 256]`.
///
/// Each non-zero count is scaled linearly (rounding to nearest) so that the
/// largest count maps to `u16::MAX`, and is clamped to at least 1 so that a
/// byte that occurred at all stays distinguishable from one that never did.
/// Zero counts stay 0. The table is emitted as 16 rows of 16 values, each row
/// preceded by a comment giving its byte range.
///
/// The scaling is carried out in `u128`, so it cannot overflow for any `u64`
/// count.
fn generate_source(counts: &[u64; 256]) -> String {
    // Clamp to 1 so an all-zero histogram does not divide by zero.
    let max_count = counts.iter().copied().max().unwrap_or(0).max(1);
    let peak = u128::from(max_count);
    let scale = u128::from(u16::MAX);

    let mut table = [0u16; 256];
    for (slot, &count) in table.iter_mut().zip(counts.iter()) {
        if count > 0 {
            // `count * u16::MAX` needs up to 80 bits, hence the u128 math.
            let scaled = (u128::from(count) * scale + peak / 2) / peak;
            // `count <= max_count` bounds `scaled` by `u16::MAX`, so the cast
            // is lossless.
            *slot = scaled.max(1) as u16;
        }
    }

    let mut out = String::with_capacity(6 * 1024);
    out.push_str(
        "// Auto-generated by `cargo run --bin build_byte_freq -- <corpus_path>`.\n\
         // Higher value = more common = worse prefilter candidate.\n",
    );
    out.push_str("pub static BYTE_FREQ: [u16; 256] = [\n");
    for (row, values) in table.chunks(16).enumerate() {
        let base = row * 16;
        out.push_str(&format!(" // 0x{:02X}..=0x{:02X}\n", base, base + 15));
        out.push(' ');
        // NOTE(review): width 4 is narrower than the 5 digits `u16::MAX`
        // needs, so rows containing large values are not column-aligned. It is
        // kept as is so that regenerated output stays byte-stable.
        for v in values {
            out.push_str(&format!("{v:4},"));
        }
        out.push('\n');
    }
    out.push_str("];\n");
    out
}
/// Scans a corpus directory and emits the generated `BYTE_FREQ` table.
///
/// Command-line arguments (both optional):
/// 1. corpus path; falls back to `DEFAULT_CORPUS` when omitted.
/// 2. output file; the generated source goes to stdout when omitted.
///
/// Progress messages are written to stderr so they never mix with the
/// generated source on stdout.
///
/// # Errors
///
/// Returns an error when the traversal fails, when writing the output fails,
/// or when no bytes were scanned at all. The last case covers a corpus path
/// that does not exist, is a symlink, is itself a skipped directory, or holds
/// no eligible files; without the check an all-zero table would be emitted
/// and could silently overwrite an existing output file.
fn main() -> io::Result<()> {
    let args: Vec<String> = std::env::args().collect();
    let corpus = args.get(1).map(|s| s.as_str()).unwrap_or(DEFAULT_CORPUS);
    let out_file = args.get(2).map(|s| s.as_str());

    let mut counts = [0u64; 256];
    let mut total: u64 = 0;
    eprintln!("Scanning {}...", corpus);
    walk(Path::new(corpus), &mut counts, &mut total)?;
    eprintln!("Total bytes scanned: {total}");

    // Bail out before touching the output so a bad corpus path cannot clobber
    // a previously generated table with zeros.
    if total == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("no bytes were scanned under {corpus}; refusing to emit an all-zero table"),
        ));
    }

    let source = generate_source(&counts);
    match out_file {
        Some(path) => {
            fs::write(path, source.as_bytes())?;
            eprintln!("Written to {path}");
        }
        None => io::stdout().write_all(source.as_bytes())?,
    }
    Ok(())
}