use colored::Colorize;
use std::path::{Path, PathBuf};
use std::time::Instant;
use crate::{error::CliError, output};
/// Module-local result type: every fallible helper here fails with a [`CliError`].
type Result<T> = std::result::Result<T, CliError>;
/// `apr tokenize plan`: validate the arguments, analyze the corpus, and print a
/// tokenizer-training pre-flight report without training anything.
///
/// `json_output` forces JSON output regardless of `format`. `format == "yaml"`
/// is rejected. Any other `format` value falls back to the text report.
///
/// # Errors
/// - `CliError::ValidationFailed` for an unknown algorithm, an out-of-range
///   `vocab_size`, a YAML request, or a `"blocked"` verdict. The blocked error
///   is raised *after* the report has been printed.
/// - `CliError::FileNotFound` if `data` does not exist.
/// - `CliError::InvalidFormat` if JSON serialization fails.
#[provable_contracts_macros::contract(
    "apr-cli-operations-v1",
    equation = "side_effect_classification"
)]
pub(crate) fn run_plan(
    data: &Path,
    vocab_size: usize,
    algorithm: &str,
    output_dir: &Path,
    format: &str,
    json_output: bool,
) -> Result<()> {
    contract_pre_tokenizer_training_correctness!();
    validate_algorithm(algorithm)?;
    validate_vocab_size(vocab_size)?;
    if !data.exists() {
        return Err(CliError::FileNotFound(data.to_path_buf()));
    }

    let stats = analyze_corpus(data)?;
    let minutes = estimate_training_time(stats.bytes, vocab_size);
    let verdict = plan_verdict(&stats, vocab_size);
    let plan = TokenizePlan {
        algorithm: algorithm.to_string(),
        vocab_size,
        corpus_path: data.display().to_string(),
        corpus_lines: stats.lines,
        corpus_bytes: stats.bytes,
        unique_chars: stats.unique_chars,
        output_dir: output_dir.display().to_string(),
        estimated_minutes: minutes,
        verdict,
    };

    // `--json` wins over `--format`; unknown formats render as text.
    if json_output || format == "json" {
        let rendered = serde_json::to_string_pretty(&plan)
            .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
        println!("{rendered}");
    } else if format == "yaml" {
        return Err(CliError::ValidationFailed(
            "YAML output not supported. Use --format json or --format text.".to_string(),
        ));
    } else {
        print_plan_text(&plan);
    }

    // The report is shown first so the user can see *why* the plan is blocked.
    if plan.verdict == "blocked" {
        return Err(CliError::ValidationFailed(
            "Plan is blocked — resolve failures before applying".to_string(),
        ));
    }
    contract_post_tokenizer_training_correctness!(&());
    Ok(())
}
/// `apr tokenize apply`: train a tokenizer on `data` and write `vocab.json` and
/// `merges.txt` into `output_dir`.
///
/// `max_lines == 0` means "use every line"; otherwise only the first
/// `max_lines` lines are used. With `json_output` only a JSON result is
/// printed; otherwise a header and a result section are printed.
///
/// # Errors
/// - `CliError::ValidationFailed` for an invalid algorithm or vocab size, an
///   empty corpus, a training failure, or an I/O failure.
/// - `CliError::FileNotFound` if `data` does not exist.
/// - `CliError::InvalidFormat` if JSON serialization fails.
#[provable_contracts_macros::contract(
    "apr-cli-operations-v1",
    equation = "side_effect_classification"
)]
pub(crate) fn run_apply(
    data: &Path,
    vocab_size: usize,
    algorithm: &str,
    output_dir: &Path,
    max_lines: usize,
    json_output: bool,
) -> Result<()> {
    validate_algorithm(algorithm)?;
    validate_vocab_size(vocab_size)?;
    if !data.exists() {
        return Err(CliError::FileNotFound(data.to_path_buf()));
    }

    let lines = read_corpus(data, max_lines)?;
    if lines.is_empty() {
        return Err(CliError::ValidationFailed(
            "Corpus is empty — no text to train on".to_string(),
        ));
    }
    let corpus: Vec<&str> = lines.iter().map(String::as_str).collect();
    let line_count = corpus.len();

    if !json_output {
        print_apply_header(data, vocab_size, algorithm, output_dir, line_count);
    }

    // Time only the training call, not the output writes.
    let timer = Instant::now();
    let tokenizer = train_tokenizer(&corpus, vocab_size, algorithm)?;
    let training_seconds = timer.elapsed().as_secs_f64();

    std::fs::create_dir_all(output_dir).map_err(|e| {
        CliError::ValidationFailed(format!(
            "Cannot create output directory {}: {e}",
            output_dir.display()
        ))
    })?;
    write_vocab_json(output_dir, &tokenizer)?;
    write_merges_txt(output_dir, &tokenizer)?;

    let result = TokenizeResult {
        algorithm: algorithm.to_string(),
        vocab_size: tokenizer.vocab_size(),
        corpus_lines: line_count,
        training_seconds,
        output_dir: output_dir.display().to_string(),
    };
    if !json_output {
        print_apply_result(&result);
        return Ok(());
    }
    let rendered = serde_json::to_string_pretty(&result)
        .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
    println!("{rendered}");
    Ok(())
}
/// Accept only the tokenizer algorithms `train_tokenizer` knows how to train.
///
/// # Errors
/// `CliError::ValidationFailed` naming the unsupported algorithm.
fn validate_algorithm(algorithm: &str) -> Result<()> {
    const SUPPORTED: [&str; 3] = ["bpe", "wordpiece", "unigram"];
    if SUPPORTED.contains(&algorithm) {
        Ok(())
    } else {
        Err(CliError::ValidationFailed(format!(
            "Unknown algorithm: {algorithm}. Supported: bpe, wordpiece, unigram"
        )))
    }
}
/// Require `10 <= vocab_size <= 1_000_000`.
///
/// # Errors
/// `CliError::ValidationFailed` describing which bound was violated.
fn validate_vocab_size(vocab_size: usize) -> Result<()> {
    let problem = match vocab_size {
        v if v < 10 => format!("vocab_size must be at least 10, got {vocab_size}"),
        v if v > 1_000_000 => format!("vocab_size {vocab_size} is unreasonably large (max 1M)"),
        _ => return Ok(()),
    };
    Err(CliError::ValidationFailed(problem))
}
/// Pre-flight report produced by `run_plan` and printed as JSON or text.
/// Field order here is the key order of the JSON output.
#[derive(serde::Serialize)]
struct TokenizePlan {
    /// Requested algorithm (`bpe`, `wordpiece`, or `unigram`).
    algorithm: String,
    /// Requested vocabulary size.
    vocab_size: usize,
    /// Corpus path as displayed to the user.
    corpus_path: String,
    /// Line count of the corpus file.
    corpus_lines: usize,
    /// Corpus size in bytes (from file metadata).
    corpus_bytes: u64,
    /// Number of distinct `char`s in the corpus.
    unique_chars: usize,
    /// Output directory as displayed to the user.
    output_dir: String,
    /// Rough training-time estimate in minutes (see `estimate_training_time`).
    estimated_minutes: f64,
    /// One of `"ready"`, `"warning"`, `"blocked"` (see `plan_verdict`).
    verdict: String,
}
/// Outcome of `run_apply`, printed as JSON or as a text result section.
#[derive(serde::Serialize)]
struct TokenizeResult {
    /// Algorithm that was trained.
    algorithm: String,
    /// Vocabulary size reported by the trained tokenizer (not the requested value).
    vocab_size: usize,
    /// Number of corpus lines fed to training.
    corpus_lines: usize,
    /// Wall-clock seconds spent in the training call.
    training_seconds: f64,
    /// Directory that received `vocab.json` and `merges.txt`.
    output_dir: String,
}
/// Cheap whole-file statistics gathered by `analyze_corpus`.
struct CorpusStats {
    /// Line count as produced by `str::lines`.
    lines: usize,
    /// File size in bytes, from filesystem metadata.
    bytes: u64,
    /// Number of distinct `char` values in the file.
    unique_chars: usize,
}
/// Compute line count, byte size, and distinct-character count for a corpus file.
///
/// The whole file is read into memory as UTF-8.
///
/// # Errors
/// `CliError::ValidationFailed` if the file cannot be stat'ed or read as UTF-8.
fn analyze_corpus(path: &Path) -> Result<CorpusStats> {
    let shown = path.display();
    let bytes = std::fs::metadata(path)
        .map(|meta| meta.len())
        .map_err(|e| CliError::ValidationFailed(format!("Cannot stat {shown}: {e}")))?;
    let text = std::fs::read_to_string(path)
        .map_err(|e| CliError::ValidationFailed(format!("Cannot read {shown}: {e}")))?;
    let distinct = text
        .chars()
        .collect::<std::collections::HashSet<char>>()
        .len();
    Ok(CorpusStats {
        lines: text.lines().count(),
        bytes,
        unique_chars: distinct,
    })
}
/// Rough training-time estimate in minutes.
///
/// The model assumes one second per MiB of corpus at vocab sizes up to 32 000.
/// Above that, time scales linearly with `vocab_size / 32000`.
fn estimate_training_time(bytes: u64, vocab_size: usize) -> f64 {
    const BYTES_PER_MIB: f64 = 1024.0 * 1024.0;
    const BASELINE_VOCAB: f64 = 32000.0;
    const SECONDS_PER_MINUTE: f64 = 60.0;
    let mebibytes = bytes as f64 / BYTES_PER_MIB;
    let scale = f64::max(vocab_size as f64 / BASELINE_VOCAB, 1.0);
    (mebibytes * scale) / SECONDS_PER_MINUTE
}
/// Classify a plan as `"blocked"`, `"warning"`, or `"ready"`.
///
/// - `"blocked"`: the corpus has no lines.
/// - `"warning"`: the requested vocab exceeds 100× the distinct characters.
/// - `"ready"`: anything else.
fn plan_verdict(stats: &CorpusStats, vocab_size: usize) -> String {
    let verdict = match stats.lines {
        0 => "blocked",
        _ if vocab_size > stats.unique_chars * 100 => "warning",
        _ => "ready",
    };
    verdict.to_string()
}
/// Read a UTF-8 corpus file and return its lines (without terminators).
///
/// `max_lines == 0` means no limit. Otherwise at most `max_lines` lines from
/// the start of the file are returned.
///
/// # Errors
/// `CliError::ValidationFailed` if the file cannot be read as UTF-8.
fn read_corpus(path: &Path, max_lines: usize) -> Result<Vec<String>> {
    let text = std::fs::read_to_string(path).map_err(|e| {
        CliError::ValidationFailed(format!("Cannot read corpus {}: {e}", path.display()))
    })?;
    let limit = if max_lines == 0 { usize::MAX } else { max_lines };
    Ok(text.lines().take(limit).map(String::from).collect())
}
/// Algorithm-independent training output, in the shape the file writers need.
struct TrainedTokenizer {
    /// Token string → integer id.
    vocab: std::collections::HashMap<String, u32>,
    /// Merge pairs `(left, right)`, in order. Empty for WordPiece and Unigram.
    merges: Vec<(String, String)>,
}
impl TrainedTokenizer {
    /// Number of entries in the vocabulary.
    fn vocab_size(&self) -> usize {
        let Self { vocab, .. } = self;
        vocab.len()
    }
}
/// Train a BPE tokenizer with entrenar, honoring `min_frequency` and
/// `normalization` (`"none"` or `"nfc"`).
///
/// # Errors
/// `CliError::ValidationFailed` for an unknown normalization or a training failure.
#[cfg(feature = "training")]
fn train_bpe_via_entrenar(
    corpus: &[&str],
    vocab_size: usize,
    min_frequency: usize,
    normalization: &str,
) -> Result<TrainedTokenizer> {
    use entrenar::tokenizer::{BPETokenizer, Normalization, Tokenizer, TokenizerConfig};

    // Same accepted set and error text as the CLI-level check.
    validate_normalization(normalization)?;
    let norm = if normalization == "nfc" {
        Normalization::NFC
    } else {
        Normalization::None
    };
    let mut bpe = BPETokenizer::new(
        TokenizerConfig::bpe()
            .with_vocab_size(vocab_size)
            .with_min_frequency(min_frequency)
            .with_normalization(norm),
    );
    bpe.train(corpus)
        .map_err(|e| CliError::ValidationFailed(format!("BPE training failed: {e}")))?;
    let vocab = bpe.vocab().clone();
    let merges = bpe.merges().to_vec();
    Ok(TrainedTokenizer { vocab, merges })
}
/// Fallback BPE training used when the `training` feature is off.
///
/// NOTE(review): `min_frequency` and `normalization` are accepted for signature
/// parity but ignored here. The aprender trainer only receives the corpus and
/// `vocab_size`.
///
/// # Errors
/// `CliError::ValidationFailed` if training fails.
#[cfg(not(feature = "training"))]
fn train_bpe_via_entrenar(
    corpus: &[&str],
    vocab_size: usize,
    _min_frequency: usize,
    _normalization: &str,
) -> Result<TrainedTokenizer> {
    let trained = aprender::text::tokenize::BpeTokenizer::train(corpus, vocab_size)
        .map_err(|e| CliError::ValidationFailed(format!("BPE training failed: {e}")))?;
    let vocab = trained.vocab().clone();
    let merges = trained.merges().to_vec();
    Ok(TrainedTokenizer { vocab, merges })
}
fn train_tokenizer(
corpus: &[&str],
vocab_size: usize,
algorithm: &str,
) -> Result<TrainedTokenizer> {
match algorithm {
"bpe" => {
let tokenizer = aprender::text::tokenize::BpeTokenizer::train(corpus, vocab_size)
.map_err(|e| CliError::ValidationFailed(format!("BPE training failed: {e}")))?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab().clone(),
merges: tokenizer.merges().to_vec(),
})
}
"wordpiece" => {
let tokenizer = aprender::text::tokenize::WordPieceTokenizer::train(corpus, vocab_size)
.map_err(|e| {
CliError::ValidationFailed(format!("WordPiece training failed: {e}"))
})?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab().clone(),
merges: Vec::new(),
})
}
"unigram" => {
let tokenizer = aprender::text::tokenize::UnigramTokenizer::train(corpus, vocab_size)
.map_err(|e| {
CliError::ValidationFailed(format!("Unigram training failed: {e}"))
})?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab_ids(),
merges: Vec::new(),
})
}
_ => unreachable!("algorithm validated above"),
}
}
fn write_vocab_json(output_dir: &Path, tokenizer: &TrainedTokenizer) -> Result<()> {
let vocab_path = output_dir.join("vocab.json");
let mut entries: Vec<(&String, &u32)> = tokenizer.vocab.iter().collect();
entries.sort_by_key(|(_, id)| *id);
let ordered: serde_json::Map<String, serde_json::Value> = entries
.into_iter()
.map(|(k, v)| (k.clone(), serde_json::Value::Number((*v).into())))
.collect();
let json = serde_json::to_string_pretty(&ordered)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?;
std::fs::write(&vocab_path, json).map_err(|e| {
CliError::ValidationFailed(format!("Cannot write {}: {e}", vocab_path.display()))
})?;
Ok(())
}
/// Write `<output_dir>/merges.txt`.
///
/// The file starts with a `#version: 0.2` header line, followed by one
/// `left right` merge pair per line. For tokenizers without merges only the
/// header is written.
///
/// # Errors
/// `CliError::ValidationFailed` if the file cannot be written.
fn write_merges_txt(output_dir: &Path, tokenizer: &TrainedTokenizer) -> Result<()> {
    let merges_path = output_dir.join("merges.txt");
    let content = tokenizer.merges.iter().fold(
        String::from("#version: 0.2\n"),
        |mut acc, (left, right)| {
            acc.push_str(left);
            acc.push(' ');
            acc.push_str(right);
            acc.push('\n');
            acc
        },
    );
    std::fs::write(&merges_path, content).map_err(|e| {
        CliError::ValidationFailed(format!("Cannot write {}: {e}", merges_path.display()))
    })?;
    Ok(())
}
/// Print the human-readable `apr tokenize plan` report, ending with a colored verdict.
fn print_plan_text(plan: &TokenizePlan) {
    output::header("apr tokenize plan — Tokenizer Training Pre-flight");
    println!();

    output::section("Configuration");
    output::kv(" Algorithm", &plan.algorithm);
    output::kv(" Vocab size", format_number(plan.vocab_size));
    output::kv(" Corpus", &plan.corpus_path);
    output::kv(" Output", &plan.output_dir);
    println!();

    output::section("Corpus Analysis");
    output::kv(" Lines", format_number(plan.corpus_lines));
    output::kv(" Size", format_bytes(plan.corpus_bytes));
    output::kv(" Unique chars", format_number(plan.unique_chars));
    println!();

    output::section("Estimates");
    output::kv(" Training time", format_duration(plan.estimated_minutes));
    println!();

    // Known verdicts get a colored label; anything else is shown verbatim.
    let verdict_display = match plan.verdict.as_str() {
        "ready" => "READY".green().bold().to_string(),
        "warning" => "WARNING".yellow().bold().to_string(),
        "blocked" => "BLOCKED".red().bold().to_string(),
        other => other.to_string(),
    };
    output::kv(" Verdict", verdict_display);
    println!();
}
/// Print the banner and configuration summary shown before `apply` training starts.
fn print_apply_header(
    data: &Path,
    vocab_size: usize,
    algorithm: &str,
    output_dir: &Path,
    corpus_lines: usize,
) {
    let corpus_shown = data.display().to_string();
    let output_shown = output_dir.display().to_string();

    output::header("apr tokenize apply — Training Tokenizer");
    println!();
    output::kv(" Algorithm", algorithm);
    output::kv(" Vocab size", format_number(vocab_size));
    output::kv(" Corpus", corpus_shown);
    output::kv(" Lines", format_number(corpus_lines));
    output::kv(" Output", output_shown);
    println!();
}
/// Print the success section after `apply`: final vocab size, training time,
/// and the paths of the two written files.
fn print_apply_result(result: &TokenizeResult) {
    let dir = &result.output_dir;
    output::section("Result");
    println!(" {} Tokenizer trained successfully", "OK".green().bold());
    output::kv(" Final vocab size", format_number(result.vocab_size));
    output::kv(" Training time", format!("{:.1}s", result.training_seconds));
    output::kv(" vocab.json", format!("{dir}/vocab.json"));
    output::kv(" merges.txt", format!("{dir}/merges.txt"));
    println!();
}
/// Render a count compactly.
///
/// Values below 1 000 are shown verbatim; larger values get one decimal place
/// and a `K` (thousands) or `M` (millions) suffix. Examples: `"999"`,
/// `"1.5K"`, `"2.5M"`.
///
/// The unit is chosen on the *rounded* display value. Counts in
/// `999_950..1_000_000` therefore show as `"1.0M"` instead of `"1000.0K"`.
fn format_number(n: usize) -> String {
    // Smallest n whose one-decimal K rendering would round up to "1000.0K".
    const M_FROM: usize = 999_950;
    if n >= M_FROM {
        format!("{:.1}M", n as f64 / 1_000_000.0)
    } else if n >= 1_000 {
        format!("{:.1}K", n as f64 / 1_000.0)
    } else {
        n.to_string()
    }
}
/// Render a byte count with binary (1024-based) units and one decimal place.
///
/// Examples: `"512 B"`, `"1.5 KB"`, `"3.2 MB"`, `"1.0 GB"`. Sizes of 1 GiB and
/// above stay in GB.
///
/// The unit is chosen on the *rounded* display value. A size just below a
/// unit boundary therefore shows as `"1.0 MB"` / `"1.0 GB"` instead of
/// `"1024.0 KB"` / `"1024.0 MB"`.
fn format_bytes(bytes: u64) -> String {
    const KIB: u64 = 1024;
    const MIB: u64 = 1024 * KIB;
    const GIB: u64 = 1024 * MIB;
    // Smallest counts whose one-decimal rendering in the next-smaller unit
    // would round up to "1024.0": ceil(1023.95 * MIB) and ceil(1023.95 * KIB).
    const GB_FROM: u64 = 1_073_689_396;
    const MB_FROM: u64 = 1_048_525;
    if bytes >= GB_FROM {
        format!("{:.1} GB", bytes as f64 / GIB as f64)
    } else if bytes >= MB_FROM {
        format!("{:.1} MB", bytes as f64 / MIB as f64)
    } else if bytes >= KIB {
        format!("{:.1} KB", bytes as f64 / KIB as f64)
    } else {
        format!("{bytes} B")
    }
}
/// Render a duration given in minutes.
///
/// - Under about a minute: whole seconds (`"30 sec"`).
/// - Under about an hour: minutes with one decimal (`"30.0 min"`).
/// - Otherwise: hours with one decimal (`"2.0 hours"`).
///
/// The unit is chosen on the *rounded* display value. 59.7 s therefore shows
/// as `"1.0 min"` instead of `"60 sec"`, and 59.97 min as `"1.0 hours"`
/// instead of `"60.0 min"`.
fn format_duration(minutes: f64) -> String {
    let seconds = minutes * 60.0;
    if seconds.round() < 60.0 {
        format!("{:.0} sec", seconds)
    } else if (minutes * 10.0).round() < 600.0 {
        format!("{:.1} min", minutes)
    } else {
        format!("{:.1} hours", minutes / 60.0)
    }
}
/// Outcome of `run_train` (BPE training over a JSONL corpus), printed as JSON
/// or through `print_train_result`.
#[derive(serde::Serialize)]
struct TokenizeTrainResult {
    /// Always `"bpe"` for this command.
    algorithm: String,
    /// Vocabulary size reported by the trained tokenizer.
    vocab_size: usize,
    /// Number of `content` strings collected from the JSONL files.
    corpus_lines: usize,
    /// Number of `.jsonl` files read.
    corpus_files: usize,
    /// Requested minimum pair frequency.
    // NOTE(review): ignored by the non-`training` fallback trainer.
    min_frequency: usize,
    /// Requested normalization (`"none"` or `"nfc"`).
    // NOTE(review): ignored by the non-`training` fallback trainer.
    normalization: String,
    /// Wall-clock seconds spent in the training call.
    training_seconds: f64,
    /// Directory that received `vocab.json` and `merges.txt`.
    output_dir: String,
}
/// `apr tokenize train`: train a BPE tokenizer on the `content` fields of every
/// `.jsonl` file under `corpus`.
///
/// `corpus` may be a single `.jsonl` file or a directory. Writes `vocab.json`
/// and `merges.txt` into `output_dir`.
///
/// # Errors
/// - `CliError::ValidationFailed` for an invalid vocab size or normalization,
///   no JSONL files, no `content` fields, a training failure, or an I/O failure.
/// - `CliError::FileNotFound` if `corpus` does not exist.
/// - `CliError::InvalidFormat` if JSON serialization fails.
pub(crate) fn run_train(
    corpus: &Path,
    vocab_size: usize,
    min_frequency: usize,
    output_dir: &Path,
    normalization: &str,
    json_output: bool,
) -> Result<()> {
    validate_vocab_size(vocab_size)?;
    validate_normalization(normalization)?;
    if !corpus.exists() {
        return Err(CliError::FileNotFound(corpus.to_path_buf()));
    }

    let files = collect_jsonl_files(corpus)?;
    if files.is_empty() {
        return Err(CliError::ValidationFailed(format!(
            "No .jsonl files found under {}",
            corpus.display()
        )));
    }

    let mut documents: Vec<String> = Vec::new();
    for path in &files {
        read_jsonl_content(path, &mut documents)?;
    }
    if documents.is_empty() {
        return Err(CliError::ValidationFailed(
            "Corpus contained no `content` fields — nothing to train on".to_string(),
        ));
    }

    if !json_output {
        print_train_header(corpus, vocab_size, output_dir, files.len(), documents.len());
    }

    let texts: Vec<&str> = documents.iter().map(String::as_str).collect();
    let timer = Instant::now();
    let trained = train_bpe_via_entrenar(&texts, vocab_size, min_frequency, normalization)?;
    let training_seconds = timer.elapsed().as_secs_f64();

    std::fs::create_dir_all(output_dir).map_err(|e| {
        CliError::ValidationFailed(format!(
            "Cannot create output directory {}: {e}",
            output_dir.display()
        ))
    })?;
    write_vocab_json(output_dir, &trained)?;
    write_merges_txt(output_dir, &trained)?;

    let result = TokenizeTrainResult {
        algorithm: "bpe".to_string(),
        vocab_size: trained.vocab_size(),
        corpus_lines: documents.len(),
        corpus_files: files.len(),
        min_frequency,
        normalization: normalization.to_string(),
        training_seconds,
        output_dir: output_dir.display().to_string(),
    };
    if !json_output {
        print_train_result(&result);
        return Ok(());
    }
    let rendered = serde_json::to_string_pretty(&result)
        .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
    println!("{rendered}");
    Ok(())
}
/// Accept only the supported normalization modes: `"none"` and `"nfc"`.
///
/// # Errors
/// `CliError::ValidationFailed` naming the unsupported mode.
fn validate_normalization(norm: &str) -> Result<()> {
    if matches!(norm, "none" | "nfc") {
        return Ok(());
    }
    Err(CliError::ValidationFailed(format!(
        "Unknown normalization: {norm}. Supported: none, nfc"
    )))
}
/// Documents read and encoded per batch in `run_encode_corpus` when more than
/// one worker is used. Each batch is encoded in parallel, then written in order.
#[cfg(feature = "training")]
const ENCODE_CHUNK_SIZE: usize = 10_000;
/// Settings for the `[progress]` lines written to stderr during corpus encoding.
#[derive(Debug, Clone, Copy)]
pub(crate) struct ProgressConfig {
    /// Suppress all progress output.
    pub quiet: bool,
    /// A tick is due once at least this many documents passed since the last tick...
    pub interval_docs: u64,
    /// ...or once at least this many seconds passed, whichever happens first.
    pub interval_seconds: u64,
}
impl Default for ProgressConfig {
    /// Progress enabled, ticking every 1000 documents or every 60 seconds.
    fn default() -> Self {
        Self {
            interval_docs: 1000,
            interval_seconds: 60,
            quiet: false,
        }
    }
}
/// Rate-limited progress reporter for long-running encodes.
///
/// It tracks when and at which document count the last tick was emitted.
#[cfg(feature = "training")]
pub(crate) struct ProgressEmitter {
    /// Quiet flag and tick intervals.
    cfg: ProgressConfig,
    /// Construction time; used for overall rate and elapsed time.
    start: Instant,
    /// Document count at the last emitted tick.
    last_emit_docs: u64,
    /// Time of the last emitted tick.
    last_emit_time: Instant,
    /// Expected total documents, if known. Enables the `/total` and `eta=` fields.
    total_docs_hint: Option<u64>,
}
#[cfg(feature = "training")]
impl ProgressEmitter {
    /// Create an emitter.
    ///
    /// The start clock and the last-tick time are both set to now, with zero
    /// documents counted as emitted.
    pub(crate) fn new(cfg: ProgressConfig, total_docs_hint: Option<u64>) -> Self {
        let created_at = Instant::now();
        Self {
            cfg,
            total_docs_hint,
            start: created_at,
            last_emit_time: created_at,
            last_emit_docs: 0,
        }
    }

    /// Whether a tick is due at `now`.
    ///
    /// A tick is due when either the document interval or the time interval
    /// since the last tick has been reached. Always `false` when quiet.
    pub(crate) fn should_emit(&self, docs_seen: u64, now: Instant) -> bool {
        if self.cfg.quiet {
            return false;
        }
        let docs_since_tick = docs_seen.saturating_sub(self.last_emit_docs);
        let secs_since_tick = now.saturating_duration_since(self.last_emit_time).as_secs();
        docs_since_tick >= self.cfg.interval_docs || secs_since_tick >= self.cfg.interval_seconds
    }

    /// Record that a tick was emitted at `now` with `docs_seen` documents processed.
    pub(crate) fn mark_emitted(&mut self, docs_seen: u64, now: Instant) {
        self.last_emit_time = now;
        self.last_emit_docs = docs_seen;
    }

    /// Build one `[progress]` line.
    ///
    /// With a non-zero total hint, the line also includes `/total` and an
    /// absolute UTC ETA. The ETA extrapolates the average docs/s since start.
    pub(crate) fn format_line(&self, docs_seen: u64, tokens_seen: u64, now: Instant) -> String {
        let elapsed = now.saturating_duration_since(self.start).as_secs_f64();
        let rate = if elapsed > 0.0 {
            docs_seen as f64 / elapsed
        } else {
            0.0
        };
        let Some(total) = self.total_docs_hint.filter(|&t| t > 0) else {
            return format!("[progress] doc={docs_seen} tokens={tokens_seen} rate={rate:.1} docs/s");
        };
        let remaining = total.saturating_sub(docs_seen);
        let eta_secs = if rate > 0.0 {
            (remaining as f64 / rate).round() as i64
        } else {
            0
        };
        let eta = format_eta_iso8601_utc(eta_secs);
        format!(
            "[progress] doc={docs_seen}/{total} tokens={tokens_seen} \
             rate={rate:.1} docs/s eta={eta}"
        )
    }

    /// Print a progress line to stderr and reset the tick watermark (no-op when quiet).
    pub(crate) fn emit_tick(&mut self, docs_seen: u64, tokens_seen: u64, now: Instant) {
        if !self.cfg.quiet {
            eprintln!("{}", self.format_line(docs_seen, tokens_seen, now));
            self.mark_emitted(docs_seen, now);
        }
    }

    /// Print the closing `[progress] done` summary to stderr (no-op when quiet).
    pub(crate) fn emit_final(&self, total_docs: u64, total_tokens: u64) {
        if self.cfg.quiet {
            return;
        }
        let elapsed = self.start.elapsed().as_secs_f64();
        let rate = if elapsed > 0.0 {
            total_docs as f64 / elapsed
        } else {
            0.0
        };
        eprintln!(
            "[progress] done docs={total_docs} tokens={total_tokens} \
             elapsed={elapsed:.1}s rate={rate:.1} docs/s"
        );
    }
}
/// Settings for the "estimate only" mode of corpus encoding.
///
/// In this mode a sample of documents is encoded and the cost of the full run
/// is extrapolated from it.
#[derive(Debug, Clone, Copy)]
pub(crate) struct EstimateConfig {
    /// Run the estimate instead of a full encode.
    pub enabled: bool,
    /// How many documents to encode for the sample.
    pub sample_docs: u64,
}
impl Default for EstimateConfig {
    /// Estimation disabled; 1000-document sample when enabled.
    fn default() -> Self {
        Self {
            sample_docs: 1000,
            enabled: false,
        }
    }
}
/// Linearly extrapolate a sampled encode to the full corpus.
///
/// Returns `(estimated_total_tokens, estimated_shards, estimated_wall_seconds)`:
/// - tokens = average tokens per sampled doc × `total_docs`, rounded;
/// - shards = `ceil(tokens / shard_tokens)`, or 0 when `shard_tokens == 0`;
/// - wall = average seconds per sampled doc × `total_docs` ÷ workers, with
///   `num_workers == 0` treated as 1.
///
/// An empty sample yields `(0, 0, 0.0)`.
pub(crate) fn extrapolate_estimate(
    sample_size: u64,
    sample_tokens: u64,
    sample_wall_seconds: f64,
    total_docs: u64,
    shard_tokens: u64,
    num_workers: u64,
) -> (u64, u64, f64) {
    if sample_size == 0 {
        return (0, 0, 0.0);
    }
    let sampled = sample_size as f64;
    let population = total_docs as f64;

    let avg_tokens = sample_tokens as f64 / sampled;
    let projected_tokens = (avg_tokens * population).round() as u64;

    let projected_shards = match shard_tokens {
        0 => 0,
        per_shard => projected_tokens.div_ceil(per_shard),
    };

    let avg_wall = sample_wall_seconds / sampled;
    let parallelism = num_workers.max(1) as f64;
    let projected_wall = avg_wall * population / parallelism;

    (projected_tokens, projected_shards, projected_wall)
}
/// Format "now + `offset_secs`" as an ISO-8601 UTC timestamp (`YYYY-MM-DDTHH:MM:SSZ`).
///
/// The target is clamped with saturating addition. If the system clock is
/// before the Unix epoch, "now" is taken as 0.
fn format_eta_iso8601_utc(offset_secs: i64) -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Howard Hinnant's `civil_from_days`: days since 1970-01-01 → proleptic
    /// Gregorian `(year, month, day)`.
    fn civil_from_days(days: i64) -> (i64, i64, i64) {
        let shifted = days + 719_468;
        let era = shifted.div_euclid(146_097);
        let day_of_era = shifted - era * 146_097;
        let year_of_era =
            (day_of_era - day_of_era / 1460 + day_of_era / 36_524 - day_of_era / 146_096) / 365;
        let day_of_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
        // Month index counted from March (0 = March, ..., 11 = February).
        let march_month = (5 * day_of_year + 2) / 153;
        let day = day_of_year - (153 * march_month + 2) / 5 + 1;
        let month = if march_month < 10 {
            march_month + 3
        } else {
            march_month - 9
        };
        let base_year = year_of_era + era * 400;
        let year = if month <= 2 { base_year + 1 } else { base_year };
        (year, month, day)
    }

    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_secs() as i64);
    let target = now_secs.saturating_add(offset_secs);
    let (year, month, day) = civil_from_days(target.div_euclid(86_400));
    let secs_of_day = target.rem_euclid(86_400);
    let (hour, minute, second) = (secs_of_day / 3600, secs_of_day % 3600 / 60, secs_of_day % 60);
    format!("{year:04}-{month:02}-{day:02}T{hour:02}:{minute:02}:{second:02}Z")
}
/// Count documents across `files` without decoding their content.
///
/// All files are treated as `format`:
/// - JSONL: the number of non-blank lines;
/// - Parquet: the sum of row-group row counts from the file metadata.
///
/// # Errors
/// `CliError::ValidationFailed` on any open, read, or metadata failure.
#[cfg(feature = "training")]
fn count_corpus_docs_fast(files: &[std::path::PathBuf], format: CorpusFormat) -> Result<u64> {
    use std::io::BufRead;

    let mut total: u64 = 0;
    for file in files {
        total += match format {
            CorpusFormat::Jsonl => {
                let handle = std::fs::File::open(file).map_err(|e| {
                    CliError::ValidationFailed(format!("Cannot open {}: {e}", file.display()))
                })?;
                let mut non_blank: u64 = 0;
                for line in std::io::BufReader::new(handle).lines() {
                    let line = line.map_err(|e| {
                        CliError::ValidationFailed(format!("Read error in {}: {e}", file.display()))
                    })?;
                    non_blank += u64::from(!line.trim().is_empty());
                }
                non_blank
            }
            CorpusFormat::Parquet => {
                use parquet::file::reader::{FileReader, SerializedFileReader};
                let handle = std::fs::File::open(file).map_err(|e| {
                    CliError::ValidationFailed(format!(
                        "Cannot open parquet {}: {e}",
                        file.display()
                    ))
                })?;
                let reader = SerializedFileReader::new(handle).map_err(|e| {
                    CliError::ValidationFailed(format!(
                        "Cannot read parquet metadata {}: {e}",
                        file.display()
                    ))
                })?;
                reader
                    .metadata()
                    .row_groups()
                    .iter()
                    .map(|rg| u64::try_from(rg.num_rows()).unwrap_or(0))
                    .sum::<u64>()
            }
        };
    }
    Ok(total)
}
/// Estimate-only mode of `run_encode_corpus`.
///
/// Steps:
/// 1. Count all documents.
/// 2. Encode the first `min(sample_size, total)` of them with `encode`.
/// 3. Print `[estimate]` lines to stderr: extrapolated tokens, shard count, and
///    wall time (see `extrapolate_estimate`).
///
/// Nothing is written to disk.
///
/// # Errors
/// `CliError::ValidationFailed` when `sample_size == 0`, when the corpus has
/// zero documents, or on a read or encoding failure.
#[cfg(feature = "training")]
fn run_estimate_only_path<F>(
    mut encode: F,
    files: &[std::path::PathBuf],
    corpus_format: CorpusFormat,
    content_field: &str,
    sample_size: u64,
    shard_tokens: usize,
    num_workers: usize,
) -> Result<()>
where
    F: FnMut(&str) -> std::result::Result<Vec<u32>, String>,
{
    if sample_size == 0 {
        return Err(CliError::ValidationFailed(
            "--estimate-sample-docs must be >= 1 (got 0)".to_string(),
        ));
    }
    let total_docs = count_corpus_docs_fast(files, corpus_format)?;
    if total_docs == 0 {
        return Err(CliError::ValidationFailed(format!(
            "Corpus contains zero documents — nothing to estimate. \
             Files inspected: {}",
            files.len()
        )));
    }

    let wanted = sample_size.min(total_docs);
    let mut docs = iter_corpus_texts(files, corpus_format, content_field);
    // The timer starts after the iterator is created, so only read+encode of the sample is timed.
    let timer = Instant::now();
    let (mut sample_count, mut sample_tokens) = (0u64, 0u64);
    while sample_count < wanted {
        let Some(next) = docs.next() else { break };
        let (file_display, locator, text) = next?;
        let ids = encode(&text).map_err(|e| {
            CliError::ValidationFailed(format!("Encoding failed at {file_display} {locator}: {e}"))
        })?;
        sample_tokens += u64::try_from(ids.len()).unwrap_or(0);
        sample_count += 1;
    }
    let sample_wall = timer.elapsed().as_secs_f64();

    let shard_tokens_u64 = u64::try_from(shard_tokens).unwrap_or(0);
    let workers_u64 = u64::try_from(num_workers).unwrap_or(1);
    let (estimated_total_tokens, estimated_shards, estimated_wall) = extrapolate_estimate(
        sample_count,
        sample_tokens,
        sample_wall,
        total_docs,
        shard_tokens_u64,
        workers_u64,
    );

    eprintln!("[estimate] input_docs={total_docs}");
    eprintln!(
        "[estimate] sample_size={sample_count} sample_tokens={sample_tokens} \
         sample_wall={sample_wall:.3}s"
    );
    eprintln!("[estimate] estimated_total_tokens={estimated_total_tokens}");
    eprintln!("[estimate] estimated_shards={estimated_shards} (at shard_tokens={shard_tokens})");
    eprintln!(
        "[estimate] estimated_wall={estimated_wall:.0} seconds (at --num-workers={num_workers})"
    );
    Ok(())
}
/// Resolve the encode worker count.
///
/// An explicit value must be at least 1. `None` means "use
/// `available_parallelism`", falling back to 1 if that query fails.
///
/// # Errors
/// `CliError::ValidationFailed` for `Some(0)`.
#[cfg(feature = "training")]
fn resolve_num_workers(num_workers: Option<usize>) -> Result<usize> {
    match num_workers {
        None => Ok(std::thread::available_parallelism().map_or(1, std::num::NonZeroUsize::get)),
        Some(0) => Err(CliError::ValidationFailed(
            "--num-workers must be >= 1 (got 0)".to_string(),
        )),
        Some(n) => Ok(n),
    }
}
/// `apr tokenize encode-corpus`: pretokenize corpus roots into binary token shards.
///
/// Steps:
/// 1. Load a trained tokenizer from `tokenizer_dir` (`vocab.json` + `merges.txt`,
///    or `tokenizer.json` for the byte-level loader).
/// 2. Encode every document from the corpus files (JSONL and/or Parquet, as
///    tagged by `collect_corpus_files_multi`).
/// 3. Write each token id as 4 little-endian bytes (`u32::to_le_bytes`) into
///    `shard-NNNNN.bin` files under `output_dir`.
/// 4. Write `manifest.json`, then print it (JSON) or a summary table.
///
/// A new shard starts once the current one holds at least `shard_tokens`
/// tokens. Documents are never split across shards, so a shard can exceed
/// `shard_tokens` by up to one document (plus its trailing EOS).
///
/// Tokenizer selection: if `vocab.json` contains at least `MIN_HEX_BYTES` (200)
/// of the 256 two-digit lowercase hex tokens `"00"`..`"ff"`, the entrenar
/// `BPETokenizer` is used. Otherwise the aprender byte-level BPE loader is
/// used, preferring `tokenizer.json` when it exists.
///
/// `eos_policy`:
/// - `"none"`: no EOS tokens;
/// - `"between"`: EOS before every document except the first;
/// - `"after"`: EOS after every document.
///
/// The EOS id is the first of `</s>`, `<|endoftext|>`, `<eos>`, `<|eos|>` found
/// in the vocabulary. If none is found, no EOS is written under any policy.
///
/// With `estimate.enabled`, only a sample is encoded and extrapolated (see
/// `run_estimate_only_path`); no output directory or shard is written.
///
/// # Errors
/// - `CliError::ValidationFailed` for a bad `normalization` or `eos_policy`,
///   `shard_tokens == 0`, `num_workers == Some(0)`, unreadable or invalid
///   tokenizer files, encoding failures, token ids `>= vocab_size`, or any
///   I/O failure.
/// - `CliError::FileNotFound` for a missing corpus path or tokenizer file.
/// - `CliError::InvalidFormat` if the manifest cannot be serialized.
#[cfg(feature = "training")]
// Many independent CLI options are passed straight through as parameters.
#[allow(clippy::too_many_arguments)]
pub(crate) fn run_encode_corpus(
    corpus: &[std::path::PathBuf],
    tokenizer_dir: &Path,
    output_dir: &Path,
    shard_tokens: usize,
    content_field: &str,
    normalization: &str,
    eos_policy: &str,
    num_workers: Option<usize>,
    progress: ProgressConfig,
    estimate: EstimateConfig,
    json_output: bool,
) -> Result<()> {
    use entrenar::tokenizer::{BPETokenizer, Normalization, Tokenizer, TokenizerConfig};
    use std::io::Write as IoWrite;
    // Either of the two supported tokenizer back ends, behind one encode/lookup surface.
    enum EncodeTokenizer {
        // entrenar BPE, chosen when vocab.json contains hex byte tokens.
        Hex(BPETokenizer),
        // aprender byte-level BPE (the fallback).
        ByteLevel(aprender::text::bpe::BpeTokenizer),
    }
    impl EncodeTokenizer {
        fn vocab_size(&self) -> usize {
            match self {
                Self::Hex(t) => Tokenizer::vocab_size(t),
                Self::ByteLevel(t) => t.vocab_size(),
            }
        }
        fn token_to_id(&self, name: &str) -> Option<u32> {
            match self {
                Self::Hex(t) => Tokenizer::token_to_id(t, name),
                Self::ByteLevel(t) => t.token_to_id(name),
            }
        }
        // Both back ends are adapted to a `String` error so callers can format it uniformly.
        fn encode(&self, text: &str) -> std::result::Result<Vec<u32>, String> {
            match self {
                Self::Hex(t) => Tokenizer::encode(t, text).map_err(|e| format!("{e}")),
                Self::ByteLevel(t) => Ok(t.encode(text)),
            }
        }
    }
    // --- Argument validation (before touching any files) ---
    validate_normalization(normalization)?;
    match eos_policy {
        "none" | "between" | "after" => {}
        other => {
            return Err(CliError::ValidationFailed(format!(
                "Unknown eos_policy: {other}. Supported: none, between, after"
            )));
        }
    }
    if shard_tokens == 0 {
        return Err(CliError::ValidationFailed(
            "shard_tokens must be > 0".to_string(),
        ));
    }
    let workers = resolve_num_workers(num_workers)?;
    for path in corpus {
        if !path.exists() {
            return Err(CliError::FileNotFound(path.clone()));
        }
    }
    // --- Tokenizer loading ---
    let vocab_path = tokenizer_dir.join("vocab.json");
    let merges_path = tokenizer_dir.join("merges.txt");
    if !vocab_path.exists() {
        return Err(CliError::FileNotFound(vocab_path));
    }
    if !merges_path.exists() {
        return Err(CliError::FileNotFound(merges_path));
    }
    let norm = match normalization {
        "nfc" => Normalization::NFC,
        "none" => Normalization::None,
        _ => unreachable!("validated above"),
    };
    let config = TokenizerConfig::bpe().with_normalization(norm);
    let vocab_path_str = vocab_path
        .to_str()
        .ok_or_else(|| {
            CliError::ValidationFailed("vocab.json path has non-utf8 bytes".to_string())
        })?
        .to_string();
    let merges_path_str = merges_path
        .to_str()
        .ok_or_else(|| {
            CliError::ValidationFailed("merges.txt path has non-utf8 bytes".to_string())
        })?
        .to_string();
    // vocab.json is parsed once here only to decide which loader to use.
    let vocab_json_for_detect = std::fs::read_to_string(&vocab_path_str).map_err(|e| {
        CliError::ValidationFailed(format!("cannot read vocab.json {vocab_path_str}: {e}"))
    })?;
    let detected_vocab: std::collections::HashMap<String, u32> =
        serde_json::from_str(&vocab_json_for_detect).map_err(|e| {
            CliError::ValidationFailed(format!("vocab.json is not valid JSON: {e}"))
        })?;
    // How many of the 256 lowercase two-hex-digit byte names ("00".."ff") are vocab keys.
    let hex_byte_count = (0u8..=255)
        .map(|b| format!("{b:02x}"))
        .filter(|hex| detected_vocab.contains_key(hex))
        .count();
    // NOTE(review): presumably identifies vocabularies trained by entrenar, which
    // appear to name byte tokens by hex value — TODO confirm the threshold rationale.
    const MIN_HEX_BYTES: usize = 200;
    let tokenizer: EncodeTokenizer = if hex_byte_count >= MIN_HEX_BYTES {
        BPETokenizer::from_vocab_merges(&vocab_path_str, &merges_path_str, config)
            .map(EncodeTokenizer::Hex)
            .map_err(|e| CliError::ValidationFailed(format!("Cannot load tokenizer: {e}")))?
    } else {
        // Byte-level path: prefer tokenizer.json when present, else vocab.json + merges.txt.
        let tokenizer_json_path = tokenizer_dir.join("tokenizer.json");
        let bpe = if tokenizer_json_path.exists() {
            let json = std::fs::read_to_string(&tokenizer_json_path).map_err(|e| {
                CliError::ValidationFailed(format!(
                    "byte-level loader: cannot read {}: {e}",
                    tokenizer_json_path.display()
                ))
            })?;
            aprender::text::bpe::load_from_json(&json).map_err(|byte_err| {
                CliError::ValidationFailed(format!(
                    "byte-level loader (tokenizer.json): {byte_err}"
                ))
            })?
        } else {
            let merges_txt = std::fs::read_to_string(&merges_path_str).map_err(|e| {
                CliError::ValidationFailed(format!(
                    "byte-level loader: cannot read merges.txt {merges_path_str}: {e}"
                ))
            })?;
            aprender::text::bpe::load_from_files(&vocab_json_for_detect, &merges_txt).map_err(
                |byte_err| {
                    CliError::ValidationFailed(format!(
                        "byte-level loader (vocab.json+merges.txt): {byte_err}"
                    ))
                },
            )?
        };
        EncodeTokenizer::ByteLevel(bpe)
    };
    let vocab_size = tokenizer.vocab_size();
    // First matching conventional EOS name wins; `None` disables EOS insertion entirely.
    let eos_id = ["</s>", "<|endoftext|>", "<eos>", "<|eos|>"]
        .iter()
        .find_map(|name| tokenizer.token_to_id(name));
    // --- Corpus discovery ---
    let tagged_files = collect_corpus_files_multi(corpus)?;
    let files: Vec<std::path::PathBuf> = tagged_files.iter().map(|(p, _)| p.clone()).collect();
    let unique_formats: std::collections::HashSet<CorpusFormat> =
        tagged_files.iter().map(|(_, f)| *f).collect();
    // NOTE(review): a mixed JSONL+Parquet corpus is labelled Parquet here. The
    // full encode below reads files via their per-file tags, but the estimate
    // path and the manifest's "input_format" use this single value. Presumably
    // the estimate path would then mis-read JSONL files as Parquet — confirm.
    // NOTE(review): if `tagged_files` can be empty, `unique_formats` is empty
    // and the Parquet branch is taken; the `expect` is only on the len()==1 branch.
    let corpus_format = if unique_formats.len() == 1 {
        *unique_formats.iter().next().expect("non-empty")
    } else {
        CorpusFormat::Parquet
    };
    if estimate.enabled {
        // NOTE(review): re-resolves (and shadows) `workers` with the same input; redundant.
        let workers = resolve_num_workers(num_workers)?;
        return run_estimate_only_path(
            |text| tokenizer.encode(text),
            &files,
            corpus_format,
            content_field,
            estimate.sample_docs,
            shard_tokens,
            workers,
        );
    }
    // --- Full encode ---
    std::fs::create_dir_all(output_dir).map_err(|e| {
        CliError::ValidationFailed(format!(
            "Cannot create output directory {}: {e}",
            output_dir.display()
        ))
    })?;
    let start = Instant::now();
    let mut shard_idx: usize = 0;
    let mut tokens_in_shard: usize = 0;
    let mut total_tokens: u64 = 0;
    let mut total_docs: u64 = 0;
    let mut eos_count: u64 = 0;
    let mut writer = open_shard(output_dir, shard_idx)?;
    // Documents written so far; drives the "between" EOS policy (no EOS before the first doc).
    let mut doc_iter_count: u64 = 0;
    let mut emitter = ProgressEmitter::new(progress, None);
    let tagged_iter: Box<dyn Iterator<Item = (std::path::PathBuf, CorpusFormat)>> =
        Box::new(tagged_files.iter().cloned());
    let mut source = iter_corpus_texts_tagged(tagged_iter, content_field);
    // Write one encoded document (plus EOS per policy) into the current shard and
    // update counters. Every id is checked against `vocab_size` (INV-PRETOK-001).
    // Shard rotation is done by the caller after this returns.
    // NOTE(review): the `shard_idx` parameter is never used inside the closure.
    let emit = |writer: &mut std::io::BufWriter<std::fs::File>,
                shard_idx: &mut usize,
                tokens_in_shard: &mut usize,
                total_tokens: &mut u64,
                eos_count: &mut u64,
                doc_iter_count: &mut u64,
                file_display: &str,
                locator: &str,
                ids: &[u32]|
     -> Result<()> {
        if eos_policy == "between" && *doc_iter_count > 0 {
            if let Some(eos) = eos_id {
                writer
                    .write_all(&eos.to_le_bytes())
                    .map_err(|e| CliError::ValidationFailed(format!("Shard write failed: {e}")))?;
                *tokens_in_shard += 1;
                *total_tokens += 1;
                *eos_count += 1;
            }
        }
        for id in ids {
            if (*id as usize) >= vocab_size {
                return Err(CliError::ValidationFailed(format!(
                    "Token id {id} >= vocab_size {vocab_size} at {file_display} {locator} \
                     (INV-PRETOK-001 violation)"
                )));
            }
            writer
                .write_all(&id.to_le_bytes())
                .map_err(|e| CliError::ValidationFailed(format!("Shard write failed: {e}")))?;
            *tokens_in_shard += 1;
            *total_tokens += 1;
        }
        if eos_policy == "after" {
            if let Some(eos) = eos_id {
                writer
                    .write_all(&eos.to_le_bytes())
                    .map_err(|e| CliError::ValidationFailed(format!("Shard write failed: {e}")))?;
                *tokens_in_shard += 1;
                *total_tokens += 1;
                *eos_count += 1;
            }
        }
        *doc_iter_count += 1;
        Ok(())
    };
    if workers <= 1 {
        // Sequential path: read, encode, and write one document at a time.
        for triple in source.by_ref() {
            let (file_display, locator, text) = triple?;
            let ids = tokenizer.encode(&text).map_err(|e| {
                CliError::ValidationFailed(format!(
                    "Encoding failed at {file_display} {locator}: {e}"
                ))
            })?;
            emit(
                &mut writer,
                &mut shard_idx,
                &mut tokens_in_shard,
                &mut total_tokens,
                &mut eos_count,
                &mut doc_iter_count,
                &file_display,
                &locator,
                &ids,
            )?;
            total_docs += 1;
            // NOTE(review): the next shard is opened eagerly, so a corpus that ends
            // exactly at a rotation leaves an empty trailing shard that is still
            // counted in `shard_count`.
            if tokens_in_shard >= shard_tokens {
                writer
                    .flush()
                    .map_err(|e| CliError::ValidationFailed(format!("Shard flush failed: {e}")))?;
                shard_idx += 1;
                tokens_in_shard = 0;
                writer = open_shard(output_dir, shard_idx)?;
            }
            let now = Instant::now();
            if emitter.should_emit(total_docs, now) {
                emitter.emit_tick(total_docs, total_tokens, now);
            }
        }
    } else {
        // Parallel path: read up to ENCODE_CHUNK_SIZE documents, encode them on a
        // dedicated rayon pool, then write the results sequentially in input order.
        use rayon::prelude::*;
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(workers)
            .build()
            .map_err(|e| {
                CliError::ValidationFailed(format!("Cannot build rayon pool ({workers}): {e}"))
            })?;
        loop {
            let mut chunk: Vec<(String, String, String)> = Vec::with_capacity(ENCODE_CHUNK_SIZE);
            for _ in 0..ENCODE_CHUNK_SIZE {
                match source.next() {
                    Some(Ok(triple)) => chunk.push(triple),
                    Some(Err(e)) => return Err(e),
                    None => break,
                }
            }
            if chunk.is_empty() {
                break;
            }
            // `collect` on an indexed parallel iterator keeps input order, so
            // shard contents match the sequential path.
            let encoded: Vec<Result<(String, String, Vec<u32>)>> = pool.install(|| {
                chunk
                    .par_iter()
                    .map(|(file_display, locator, text)| {
                        tokenizer
                            .encode(text)
                            .map(|ids| (file_display.clone(), locator.clone(), ids))
                            .map_err(|e| {
                                CliError::ValidationFailed(format!(
                                    "Encoding failed at {file_display} {locator}: {e}"
                                ))
                            })
                    })
                    .collect()
            });
            for result in encoded {
                let (file_display, locator, ids) = result?;
                emit(
                    &mut writer,
                    &mut shard_idx,
                    &mut tokens_in_shard,
                    &mut total_tokens,
                    &mut eos_count,
                    &mut doc_iter_count,
                    &file_display,
                    &locator,
                    &ids,
                )?;
                total_docs += 1;
                if tokens_in_shard >= shard_tokens {
                    writer.flush().map_err(|e| {
                        CliError::ValidationFailed(format!("Shard flush failed: {e}"))
                    })?;
                    shard_idx += 1;
                    tokens_in_shard = 0;
                    writer = open_shard(output_dir, shard_idx)?;
                }
                let now = Instant::now();
                if emitter.should_emit(total_docs, now) {
                    emitter.emit_tick(total_docs, total_tokens, now);
                }
            }
        }
    }
    // Flush explicitly: BufWriter's Drop would swallow a final write error.
    writer
        .flush()
        .map_err(|e| CliError::ValidationFailed(format!("Shard flush failed: {e}")))?;
    // The currently open shard counts even if it received no tokens.
    let shard_count = shard_idx + 1;
    let elapsed = start.elapsed();
    emitter.emit_final(total_docs, total_tokens);
    let manifest = serde_json::json!({
        "schema": "pretokenize-bin-v1",
        "tokenizer_dir": tokenizer_dir.display().to_string(),
        "vocab_size": vocab_size,
        "eos_policy": eos_policy,
        "eos_token_id": eos_id,
        "eos_token_count": eos_count,
        "shard_count": shard_count,
        "total_tokens": total_tokens,
        "total_documents": total_docs,
        "content_field": content_field,
        "normalization": normalization,
        "input_format": match corpus_format {
            CorpusFormat::Jsonl => "jsonl",
            CorpusFormat::Parquet => "parquet",
        },
        "input_files": files.iter().map(|p| p.display().to_string()).collect::<Vec<_>>(),
        "corpus_roots": corpus.iter().map(|p| p.display().to_string()).collect::<Vec<_>>(),
        "num_workers": workers,
        "elapsed_seconds": elapsed.as_secs_f64(),
    });
    let manifest_path = output_dir.join("manifest.json");
    std::fs::write(
        &manifest_path,
        serde_json::to_string_pretty(&manifest)
            .map_err(|e| CliError::InvalidFormat(e.to_string()))?,
    )
    .map_err(|e| CliError::ValidationFailed(format!("Cannot write manifest: {e}")))?;
    if json_output {
        println!(
            "{}",
            serde_json::to_string_pretty(&manifest)
                .map_err(|e| CliError::InvalidFormat(e.to_string()))?
        );
    } else {
        output::header("apr tokenize encode-corpus — Pretokenization Result");
        output::kv(" Shards", format_number(shard_count));
        output::kv(" Total tokens", format_number(total_tokens as usize));
        output::kv(" Total documents", format_number(total_docs as usize));
        output::kv(" Vocab size", format_number(vocab_size));
        output::kv(" Workers", workers.to_string());
        output::kv(" Elapsed", format!("{:.1}s", elapsed.as_secs_f64()));
        output::kv(" Manifest", manifest_path.display().to_string());
    }
    Ok(())
}
/// Create (or truncate) `<output_dir>/shard-NNNNN.bin` and return a buffered writer to it.
///
/// # Errors
/// `CliError::ValidationFailed` if the file cannot be created.
#[cfg(feature = "training")]
fn open_shard(output_dir: &Path, shard_idx: usize) -> Result<std::io::BufWriter<std::fs::File>> {
    let shard_path = output_dir.join(format!("shard-{shard_idx:05}.bin"));
    std::fs::File::create(&shard_path)
        .map(std::io::BufWriter::new)
        .map_err(|e| {
            CliError::ValidationFailed(format!(
                "Cannot create shard {}: {e}",
                shard_path.display()
            ))
        })
}
fn collect_jsonl_files(path: &Path) -> Result<Vec<std::path::PathBuf>> {
let meta = std::fs::metadata(path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot stat {}: {e}", path.display())))?;
if meta.is_file() {
if is_jsonl(path) {
return Ok(vec![path.to_path_buf()]);
}
return Err(CliError::ValidationFailed(format!(
"Corpus file {} is not a .jsonl file",
path.display()
)));
}
let mut out = Vec::new();
let entries = std::fs::read_dir(path).map_err(|e| {
CliError::ValidationFailed(format!("Cannot read directory {}: {e}", path.display()))
})?;
for entry in entries {
let entry =
entry.map_err(|e| CliError::ValidationFailed(format!("Directory entry error: {e}")))?;
let p = entry.path();
if p.is_file() && is_jsonl(&p) {
out.push(p);
}
}
out.sort();
Ok(out)
}
/// Returns `true` when `path` has exactly the (case-sensitive) extension `jsonl`.
fn is_jsonl(path: &Path) -> bool {
    path.extension().is_some_and(|ext| ext == "jsonl")
}
/// Encoding of a corpus input file. All files resolved from one corpus root share
/// a single format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum CorpusFormat {
    /// Newline-delimited JSON records (`.jsonl`).
    Jsonl,
    /// Parquet file (`.parquet`).
    Parquet,
}
/// Resolves one corpus root to its input files plus their shared format.
///
/// * A regular file is accepted when it is Parquet (checked first) or `.jsonl`.
/// * A directory is scanned for Parquet files first. Only when that scan yields
///   nothing (or fails) are its top-level `.jsonl` files used instead, so formats
///   are never mixed within one root.
///
/// # Errors
/// `CliError::ValidationFailed` when `path` cannot be stat'ed, is a file of an
/// unsupported type, or is a directory with neither Parquet nor `.jsonl` files.
fn collect_corpus_files(path: &Path) -> Result<(Vec<std::path::PathBuf>, CorpusFormat)> {
    let meta = std::fs::metadata(path)
        .map_err(|e| CliError::ValidationFailed(format!("Cannot stat {}: {e}", path.display())))?;
    if meta.is_file() {
        let detected = if super::tokenize_parquet::is_parquet(path) {
            CorpusFormat::Parquet
        } else if is_jsonl(path) {
            CorpusFormat::Jsonl
        } else {
            return Err(CliError::ValidationFailed(format!(
                "Corpus file {} is not a .jsonl or .parquet file",
                path.display()
            )));
        };
        return Ok((vec![path.to_path_buf()], detected));
    }
    // A failed Parquet scan is deliberately treated like "no Parquet files", so the
    // JSONL scan below still gets a chance (and reports its own error if it fails).
    match super::tokenize_parquet::collect_parquet_files(path) {
        Ok(parquet) if !parquet.is_empty() => return Ok((parquet, CorpusFormat::Parquet)),
        _ => {}
    }
    let jsonl = collect_jsonl_files(path)?;
    if jsonl.is_empty() {
        Err(CliError::ValidationFailed(format!(
            "No .jsonl or .parquet files found under {}",
            path.display()
        )))
    } else {
        Ok((jsonl, CorpusFormat::Jsonl))
    }
}
/// Resolves every corpus root and tags each discovered file with its format.
///
/// Roots are processed in the order given, and each root contributes its files in
/// the order `collect_corpus_files` returns them. Different roots may resolve to
/// different formats. Files are not de-duplicated across roots.
///
/// # Errors
/// Fails when `corpora` is empty. Also fails when any root cannot be resolved; the
/// first failing root aborts the call.
fn collect_corpus_files_multi(
    corpora: &[std::path::PathBuf],
) -> Result<Vec<(std::path::PathBuf, CorpusFormat)>> {
    if corpora.is_empty() {
        return Err(CliError::ValidationFailed(
            "At least one --corpus path is required".to_string(),
        ));
    }
    let mut all_files = Vec::new();
    for root in corpora {
        let (files, root_format) = collect_corpus_files(root)?;
        all_files.extend(files.into_iter().map(|file| (file, root_format)));
    }
    Ok(all_files)
}
#[cfg(feature = "training")]
/// Streams `(file, location, text)` triples from `files`, all read as `format`.
///
/// Each path is paired with the same format tag and handed to
/// `iter_corpus_texts_tagged`, which does the actual reading.
fn iter_corpus_texts<'a>(
    files: &'a [std::path::PathBuf],
    format: CorpusFormat,
    content_field: &'a str,
) -> Box<dyn Iterator<Item = Result<(String, String, String)>> + 'a> {
    let tagged = files.iter().cloned().zip(std::iter::repeat(format));
    iter_corpus_texts_tagged(Box::new(tagged), content_field)
}
#[cfg(feature = "training")]
/// Streams `(file, location, text)` triples over files that each carry their own
/// [`CorpusFormat`] tag, so JSONL and Parquet inputs can be mixed in one pass.
///
/// Files are visited in the order `files` yields them. Each one is dispatched to
/// `iter_corpus_texts_parquet` or `iter_corpus_texts_jsonl` according to its tag.
/// `content_field` is forwarded unchanged as the text column / JSON key to extract.
fn iter_corpus_texts_tagged<'a>(
    files: Box<dyn Iterator<Item = (std::path::PathBuf, CorpusFormat)> + 'a>,
    content_field: &'a str,
) -> Box<dyn Iterator<Item = Result<(String, String, String)>> + 'a> {
    Box::new(files.flat_map(move |(file, format)| {
        // The per-format readers take a slice, so the single path is wrapped in a
        // one-element Vec owned by this closure invocation.
        let single = vec![file];
        let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> = match format {
            CorpusFormat::Parquet => iter_corpus_texts_parquet(&single, content_field),
            CorpusFormat::Jsonl => iter_corpus_texts_jsonl(&single, content_field),
        };
        // `inner` borrows `single`, which is dropped when this closure returns, so
        // the file's triples are materialized here before `flat_map` yields them.
        // NOTE(review): every row/line of the current file is held in memory at
        // once, so peak memory scales with the largest single input file.
        inner.collect::<Vec<_>>().into_iter()
    }))
}
#[cfg(feature = "training")]
/// Parquet-only entry point: forwards to `iter_corpus_texts_old` with
/// [`CorpusFormat::Parquet`], yielding `(file, "row N", text)` triples.
fn iter_corpus_texts_parquet<'a>(
    files: &'a [std::path::PathBuf],
    content_field: &'a str,
) -> Box<dyn Iterator<Item = Result<(String, String, String)>> + 'a> {
    let fixed_format = CorpusFormat::Parquet;
    iter_corpus_texts_old(files, fixed_format, content_field)
}
#[cfg(feature = "training")]
/// JSONL-only entry point: forwards to `iter_corpus_texts_old` with
/// [`CorpusFormat::Jsonl`], yielding `(file, "line N", text)` triples.
fn iter_corpus_texts_jsonl<'a>(
    files: &'a [std::path::PathBuf],
    content_field: &'a str,
) -> Box<dyn Iterator<Item = Result<(String, String, String)>> + 'a> {
    let fixed_format = CorpusFormat::Jsonl;
    iter_corpus_texts_old(files, fixed_format, content_field)
}
#[cfg(feature = "training")]
/// Streams `(file, location, text)` triples from `files`, all read as `format`.
///
/// * `file` is the path's `Display` string.
/// * `location` depends on the format:
///   * Parquet gives `"row N"`, where N is the 1-based position in the iterator
///     returned by `tokenize_parquet::iter_parquet_content`.
///   * JSONL gives `"line N"`, where N is the 1-based physical line number; blank
///     lines still count toward N.
/// * `text` is the extracted `content_field` value.
///
/// Errors are yielded in-band as `Err` items and iteration moves on:
/// * A file that cannot be opened or read yields one `Err`, then the next file is
///   processed.
/// * In JSONL, an unparseable line yields an `Err`.
/// * In JSONL, blank lines are skipped silently, as are records whose
///   `content_field` is missing or not a string.
fn iter_corpus_texts_old<'a>(
    files: &'a [std::path::PathBuf],
    format: CorpusFormat,
    content_field: &'a str,
) -> Box<dyn Iterator<Item = Result<(String, String, String)>> + 'a> {
    match format {
        CorpusFormat::Parquet => Box::new(files.iter().flat_map(move |file| {
            let file_display = file.display().to_string();
            match super::tokenize_parquet::iter_parquet_content(file, content_field) {
                Ok(it) => {
                    let fd = file_display;
                    // Tag each row's text with the file name and a 1-based row label.
                    let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
                        Box::new(it.enumerate().map(move |(idx, r)| {
                            r.map(|t| (fd.clone(), format!("row {}", idx + 1), t))
                        }));
                    inner
                }
                Err(e) => {
                    // Opening the file failed: surface one error item for this file.
                    let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
                        Box::new(std::iter::once(Err(e)));
                    inner
                }
            }
        })),
        CorpusFormat::Jsonl => Box::new(files.iter().flat_map(move |file| {
            let file_display = file.display().to_string();
            // The whole file is read into memory, then parsed line by line.
            match std::fs::read_to_string(file) {
                Ok(content) => {
                    let fd = file_display;
                    let triples: Vec<Result<(String, String, String)>> = content
                        .lines()
                        .enumerate()
                        .filter_map(|(idx, line)| {
                            let trimmed = line.trim();
                            if trimmed.is_empty() {
                                return None;
                            }
                            match serde_json::from_str::<serde_json::Value>(trimmed) {
                                // Records without a string `content_field` map to None
                                // and are dropped silently.
                                Ok(v) => v.get(content_field).and_then(|x| x.as_str()).map(|s| {
                                    Ok((fd.clone(), format!("line {}", idx + 1), s.to_string()))
                                }),
                                Err(e) => Some(Err(CliError::ValidationFailed(format!(
                                    "Invalid JSON in {fd} line {}: {e}",
                                    idx + 1
                                )))),
                            }
                        })
                        .collect();
                    let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
                        Box::new(triples.into_iter());
                    inner
                }
                Err(e) => {
                    // Unreadable file: one error item, then continue with the next file.
                    let msg = format!("Cannot read {file_display}: {e}");
                    let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
                        Box::new(std::iter::once(Err(CliError::ValidationFailed(msg))));
                    inner
                }
            }
        })),
    }
}
/// Appends the string `"content"` field of every JSONL record in `path` to `out`.
///
/// Blank lines, and records without a string `"content"` field, are skipped.
///
/// # Errors
/// `CliError::ValidationFailed` when the file cannot be read or a non-blank line is
/// not valid JSON (reported with its 1-based line number). Texts from lines before
/// the failing one remain appended to `out`.
fn read_jsonl_content(path: &Path, out: &mut Vec<String>) -> Result<()> {
    let body = std::fs::read_to_string(path)
        .map_err(|e| CliError::ValidationFailed(format!("Cannot read {}: {e}", path.display())))?;
    body.lines()
        .enumerate()
        .map(|(idx, raw)| (idx, raw.trim()))
        .filter(|(_, record)| !record.is_empty())
        .try_for_each(|(idx, record)| {
            let value: serde_json::Value = serde_json::from_str(record).map_err(|e| {
                CliError::ValidationFailed(format!(
                    "Invalid JSON in {} line {}: {e}",
                    path.display(),
                    idx + 1
                ))
            })?;
            if let Some(text) = value.get("content").and_then(serde_json::Value::as_str) {
                out.push(text.to_owned());
            }
            Ok(())
        })
}
/// Prints the banner and the parameter table shown before BPE training starts.
///
/// Counts are rendered with `format_number`; paths use their `Display` form.
fn print_train_header(
    corpus: &Path,
    vocab_size: usize,
    output_dir: &Path,
    files: usize,
    lines: usize,
) {
    let corpus_shown = corpus.display().to_string();
    let output_shown = output_dir.display().to_string();
    output::header("apr tokenize train — Training BPE Tokenizer");
    println!();
    output::kv(" Algorithm", "bpe");
    output::kv(" Vocab size", format_number(vocab_size));
    output::kv(" Corpus", corpus_shown);
    output::kv(" Files", format_number(files));
    output::kv(" Lines", format_number(lines));
    output::kv(" Output", output_shown);
    println!();
}
/// Prints the post-training summary: final vocab size, normalization, wall time
/// (one decimal place), and the locations of `vocab.json` / `merges.txt` under
/// `result.output_dir`.
fn print_train_result(result: &TokenizeTrainResult) {
    let out_dir = &result.output_dir;
    let seconds = &result.training_seconds;
    let ok_badge = "OK".green().bold();
    output::section("Result");
    println!(" {ok_badge} BPE tokenizer trained successfully");
    output::kv(" Final vocab size", format_number(result.vocab_size));
    output::kv(" Normalization", &result.normalization);
    output::kv(" Training time", format!("{seconds:.1}s"));
    output::kv(" vocab.json", format!("{out_dir}/vocab.json"));
    output::kv(" merges.txt", format!("{out_dir}/merges.txt"));
    println!();
}
pub(crate) fn run_import_hf(
input: &Path,
output: &Path,
include_added_tokens: bool,
json_output: bool,
) -> Result<()> {
if !input.exists() {
return Err(CliError::FileNotFound(input.to_path_buf()));
}
let raw = std::fs::read_to_string(input).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot read {}: {e}",
input.display()
))
})?;
let parsed: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} is not valid JSON: {e}",
input.display()
))
})?;
let model_type = parsed
.get("model")
.and_then(|m| m.get("type"))
.and_then(serde_json::Value::as_str)
.ok_or_else(|| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} has no model.type field; \
not a recognizable HF tokenizer.json",
input.display()
))
})?;
if model_type != "BPE" {
return Err(CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] FALSIFY-TOK-IMPORT-HF-005: \
model.type = '{model_type}' but only 'BPE' is supported. \
{} cannot be imported with this subcommand. \
Aprender's BPE loader requires GPT-2-style vocab.json + merges.txt; \
Unigram and WordPiece use different state machines and need separate paths.",
input.display()
)));
}
let vocab_obj = parsed
.get("model")
.and_then(|m| m.get("vocab"))
.and_then(serde_json::Value::as_object)
.ok_or_else(|| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} has no model.vocab object",
input.display()
))
})?;
let bpe_vocab_count = vocab_obj.len();
let merges_arr = parsed
.get("model")
.and_then(|m| m.get("merges"))
.and_then(serde_json::Value::as_array)
.ok_or_else(|| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} has no model.merges array",
input.display()
))
})?;
let merges_count = merges_arr.len();
let added_tokens_arr = parsed
.get("added_tokens")
.and_then(serde_json::Value::as_array)
.cloned()
.unwrap_or_default();
let added_tokens_count = added_tokens_arr.len();
let mut effective_vocab: serde_json::Map<String, serde_json::Value> = vocab_obj.clone();
if include_added_tokens {
for tok in &added_tokens_arr {
if let (Some(content), Some(id)) = (
tok.get("content").and_then(serde_json::Value::as_str),
tok.get("id").and_then(serde_json::Value::as_u64),
) {
effective_vocab.insert(
content.to_string(),
serde_json::Value::Number(serde_json::Number::from(id)),
);
}
}
}
std::fs::create_dir_all(output).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot create output dir {}: {e}",
output.display()
))
})?;
let vocab_path = output.join("vocab.json");
let vocab_json = serde_json::to_string_pretty(&effective_vocab)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?;
std::fs::write(&vocab_path, vocab_json).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot write {}: {e}",
vocab_path.display()
))
})?;
let merges_path = output.join("merges.txt");
let mut merges_body = String::from("#version: 0.2\n");
for (idx, m) in merges_arr.iter().enumerate() {
let line = match m {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(parts) if parts.len() == 2 => {
let a = parts[0].as_str().unwrap_or("");
let b = parts[1].as_str().unwrap_or("");
format!("{a} {b}")
}
_ => {
return Err(CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] merges[{idx}] is neither a string \
nor a [a, b] tuple in {}",
input.display()
)));
}
};
merges_body.push_str(&line);
merges_body.push('\n');
}
std::fs::write(&merges_path, merges_body).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot write {}: {e}",
merges_path.display()
))
})?;
let manifest = serde_json::json!({
"schema": "apr-cli-tokenize-import-hf-v1",
"source": input.display().to_string(),
"source_sha256": sha256_file(input)?,
"model_type": "BPE",
"bpe_vocab_count": bpe_vocab_count,
"merges_count": merges_count,
"added_tokens_count": added_tokens_count,
"include_added_tokens": include_added_tokens,
"effective_vocab_count": effective_vocab.len(),
"extraction_timestamp_utc": chrono::Utc::now().to_rfc3339(),
});
let manifest_path = output.join("manifest.json");
std::fs::write(
&manifest_path,
serde_json::to_string_pretty(&manifest)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?,
)
.map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot write {}: {e}",
manifest_path.display()
))
})?;
if json_output {
println!(
"{}",
serde_json::to_string_pretty(&manifest)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?
);
} else {
output::header("apr tokenize import-hf — HF BPE → aprender extraction");
println!();
output::kv(" Source", input.display().to_string());
output::kv(" BPE vocab", format_number(bpe_vocab_count));
output::kv(" Merges", format_number(merges_count));
output::kv(" Added tokens", format_number(added_tokens_count));
output::kv(" Effective vocab", format_number(effective_vocab.len()));
output::kv(" Output dir", output.display().to_string());
println!();
println!("{}", "Wrote:".green().bold());
output::kv(" vocab.json", format!("{}/vocab.json", output.display()));
output::kv(" merges.txt", format!("{}/merges.txt", output.display()));
output::kv(
" manifest.json",
format!("{}/manifest.json", output.display()),
);
}
Ok(())
}
/// Returns the lowercase hex SHA-256 digest of the file at `path`.
///
/// The whole file is read into memory before hashing.
///
/// # Errors
/// `CliError::ValidationFailed` when the file cannot be read.
fn sha256_file(path: &Path) -> Result<String> {
    use sha2::{Digest, Sha256};
    std::fs::read(path)
        .map(|contents| format!("{:x}", Sha256::digest(&contents)))
        .map_err(|e| {
            CliError::ValidationFailed(format!(
                "[apr-cli-tokenize-import-hf-v1] cannot read {} for sha256: {e}",
                path.display()
            ))
        })
}
#[cfg(feature = "training")]
/// Lists the `shard-*.bin` regular files directly inside `output_dir`, sorted by path.
///
/// Sorting is lexicographic. That matches numeric order for the zero-padded
/// five-digit `shard-NNNNN.bin` names that `open_shard` produces.
///
/// # Errors
/// `CliError::ValidationFailed` when the directory, or any entry in it, cannot be
/// read. Entry errors are propagated rather than skipped. A silently skipped shard
/// would make `repair-manifest` under-count `shard_count` and `total_tokens`.
fn collect_shard_paths(output_dir: &Path) -> Result<Vec<std::path::PathBuf>> {
    let entries = std::fs::read_dir(output_dir).map_err(|e| {
        CliError::ValidationFailed(format!(
            "[apr-tokenize-repair-manifest-v1] cannot read output dir {}: {e}",
            output_dir.display()
        ))
    })?;
    let mut shards: Vec<std::path::PathBuf> = Vec::new();
    for entry in entries {
        let entry = entry.map_err(|e| {
            CliError::ValidationFailed(format!(
                "[apr-tokenize-repair-manifest-v1] cannot read entry in {}: {e}",
                output_dir.display()
            ))
        })?;
        let path = entry.path();
        // Cheap name check first; only then pay for the `is_file` stat.
        let is_shard_name = path
            .file_name()
            .and_then(|n| n.to_str())
            .is_some_and(|n| n.starts_with("shard-") && n.ends_with(".bin"));
        if is_shard_name && path.is_file() {
            shards.push(path);
        }
    }
    shards.sort();
    Ok(shards)
}
#[cfg(feature = "training")]
/// Returns the number of entries in `<tokenizer_dir>/vocab.json`.
///
/// Only the top-level object's key count is used. Token ids are not inspected, so
/// a sparse or out-of-range id space is not detected here.
///
/// # Errors
/// `CliError::ValidationFailed` when `vocab.json` cannot be read, is not valid
/// JSON, or is valid JSON but not an object.
fn read_vocab_size_from_tokenizer(tokenizer_dir: &Path) -> Result<usize> {
    let vocab_path = tokenizer_dir.join("vocab.json");
    let shown = vocab_path.display();
    let raw = std::fs::read_to_string(&vocab_path).map_err(|e| {
        CliError::ValidationFailed(format!(
            "[apr-tokenize-repair-manifest-v1] cannot read {shown}: {e}"
        ))
    })?;
    match serde_json::from_str::<serde_json::Value>(&raw) {
        Ok(serde_json::Value::Object(entries)) => Ok(entries.len()),
        Ok(_) => Err(CliError::ValidationFailed(format!(
            "[apr-tokenize-repair-manifest-v1] {shown} is not a JSON object"
        ))),
        Err(e) => Err(CliError::ValidationFailed(format!(
            "[apr-tokenize-repair-manifest-v1] {shown} is not valid JSON: {e}"
        ))),
    }
}
#[cfg(feature = "training")]
/// Rebuilds `manifest.json` for a pretokenized output directory whose encoder
/// exited before writing it.
///
/// Every `shard-*.bin` file in `output_dir` must have a byte length divisible by 4,
/// since shards are treated as little-endian `u32` token streams. The token total
/// is the summed byte length divided by 4. When `tokenizer_dir` is given,
/// `vocab_size` is the entry count of its `vocab.json`; otherwise it is recorded
/// as `null`. An existing `manifest.json` is overwritten.
///
/// # Errors
/// * `CliError::ValidationFailed` in any of these cases:
///   * `output_dir` is not a directory, or holds no shards.
///   * A shard cannot be stat'ed, or its length is not divisible by 4.
///   * The tokenizer vocab cannot be read.
///   * The manifest cannot be written.
/// * `CliError::InvalidFormat` when serialization fails.
pub(crate) fn run_repair_manifest(
    output_dir: &Path,
    tokenizer_dir: Option<&Path>,
    json_output: bool,
) -> Result<()> {
    if !output_dir.is_dir() {
        return Err(CliError::ValidationFailed(format!(
            "[apr-tokenize-repair-manifest-v1] output dir {} does not exist or is not a directory",
            output_dir.display()
        )));
    }
    let shards = collect_shard_paths(output_dir)?;
    if shards.is_empty() {
        return Err(CliError::ValidationFailed(format!(
            "[apr-tokenize-repair-manifest-v1] no shard-*.bin files in {} — nothing to repair",
            output_dir.display()
        )));
    }
    // Sum shard sizes, stopping at the first file that is not a whole number of u32s.
    let total_bytes = shards
        .iter()
        .try_fold(0_u64, |running, shard| -> Result<u64> {
            let len = std::fs::metadata(shard)
                .map_err(|e| {
                    CliError::ValidationFailed(format!(
                        "[apr-tokenize-repair-manifest-v1] cannot stat {}: {e}",
                        shard.display()
                    ))
                })?
                .len();
            if len.is_multiple_of(4) {
                Ok(running + len)
            } else {
                Err(CliError::ValidationFailed(format!(
                    "[apr-tokenize-repair-manifest-v1] {} byte length {} is not a multiple of 4 \
                     (shards are little-endian u32 streams; corrupt or non-shard file)",
                    shard.display(),
                    len
                )))
            }
        })?;
    let total_tokens: u64 = total_bytes / 4;
    let shard_count = shards.len();
    let vocab_size = tokenizer_dir
        .map(read_vocab_size_from_tokenizer)
        .transpose()?;
    let manifest = serde_json::json!({
        "schema": "pretokenize-bin-v1",
        "shard_count": shard_count,
        "total_tokens": total_tokens,
        "vocab_size": vocab_size,
        "tokenizer_dir": tokenizer_dir.map(|p| p.display().to_string()),
        "repair": true,
        "repaired_at": chrono::Utc::now().to_rfc3339(),
        "source": "repair-manifest",
        "note": "Reconstructed from existing shard-*.bin file sizes; original \
                 encoder process exited before writing manifest.json. \
                 ShardBatchIter consumes this directory regardless — the \
                 manifest is provenance, not load-bearing.",
    });
    let manifest_path = output_dir.join("manifest.json");
    // Serialize once and reuse the text for both the file and stdout.
    let rendered = serde_json::to_string_pretty(&manifest)
        .map_err(|e| CliError::InvalidFormat(e.to_string()))?;
    std::fs::write(&manifest_path, &rendered).map_err(|e| {
        CliError::ValidationFailed(format!(
            "[apr-tokenize-repair-manifest-v1] cannot write {}: {e}",
            manifest_path.display()
        ))
    })?;
    if json_output {
        println!("{rendered}");
        return Ok(());
    }
    output::header("apr tokenize repair-manifest — Provenance Recovery");
    output::kv(" Shards", format_number(shard_count));
    output::kv(" Total tokens", format_number(total_tokens as usize));
    output::kv(" Total bytes", format_number(total_bytes as usize));
    match vocab_size {
        Some(v) => {
            output::kv(" Vocab size", format_number(v));
        }
        None => {
            output::kv(" Vocab size", "(unknown — pass --tokenizer)".to_string());
        }
    }
    output::kv(" Manifest", manifest_path.display().to_string());
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn write_corpus_file(dir: &Path, name: &str, lines: &[&str]) -> std::path::PathBuf {
let p = dir.join(name);
let body = lines.join("\n");
std::fs::write(&p, body).expect("write corpus");
p
}
#[test]
fn run_train_happy_path_jsonl_file() {
let tmp = TempDir::new().expect("tempdir");
let corpus = write_corpus_file(
tmp.path(),
"corpus.jsonl",
&[
r#"{"content": "hello world hello"}"#,
r#"{"content": "hello there world"}"#,
],
);
let out = tmp.path().join("tok");
run_train(&corpus, 300, 1, &out, "nfc", true).expect("train");
assert!(out.join("vocab.json").exists());
assert!(out.join("merges.txt").exists());
let vocab = std::fs::read_to_string(out.join("vocab.json")).expect("read vocab");
assert!(
vocab.contains("\"<unk>\""),
"vocab.json missing <unk>: {}",
vocab
);
let merges = std::fs::read_to_string(out.join("merges.txt")).expect("read merges");
assert!(
merges.starts_with("#version: 0.2"),
"merges.txt missing header: {}",
merges
);
}
#[test]
fn run_train_directory_corpus_walks_jsonl() {
let tmp = TempDir::new().expect("tempdir");
let corpus_dir = tmp.path().join("corpus");
std::fs::create_dir_all(&corpus_dir).expect("mkdir");
write_corpus_file(
&corpus_dir,
"a.jsonl",
&[r#"{"content": "alpha beta alpha"}"#],
);
write_corpus_file(
&corpus_dir,
"b.jsonl",
&[r#"{"content": "gamma delta gamma"}"#],
);
std::fs::write(corpus_dir.join("notes.txt"), "ignore me").expect("write ignored");
let out = tmp.path().join("tok");
run_train(&corpus_dir, 300, 2, &out, "nfc", true).expect("train");
assert!(out.join("vocab.json").exists());
assert!(out.join("merges.txt").exists());
}
#[cfg(feature = "training")]
#[test]
fn run_train_honors_min_frequency_pruning() {
let tmp = TempDir::new().expect("tempdir");
let lines: Vec<String> = std::iter::repeat_n(r#"{"content": "abc"}"#.to_string(), 5)
.chain(std::iter::once(r#"{"content": "xyz"}"#.to_string()))
.collect();
let body = lines.join("\n");
let corpus = tmp.path().join("corpus.jsonl");
std::fs::write(&corpus, body).expect("write corpus");
let out = tmp.path().join("tok");
run_train(&corpus, 300, 2, &out, "nfc", true).expect("train");
let merges = std::fs::read_to_string(out.join("merges.txt")).expect("read merges.txt");
assert!(
merges.contains("61 62") || merges.contains("62 63"),
"Expected a merge from the frequent 'abc' pair, got: {}",
merges
);
assert!(
!merges.contains("78 79"),
"min_frequency=2 failed to prune singleton 'xy' pair: {}",
merges
);
assert!(
!merges.contains("79 7a"),
"min_frequency=2 failed to prune singleton 'yz' pair: {}",
merges
);
let vocab = std::fs::read_to_string(out.join("vocab.json")).expect("read vocab");
assert!(
!vocab.contains("\"78797a\""),
"min_frequency=2 failed to prune merged 'xyz' token from vocab: {}",
vocab
);
}
#[test]
fn run_train_rejects_unknown_normalization() {
let tmp = TempDir::new().expect("tempdir");
let corpus = write_corpus_file(tmp.path(), "corpus.jsonl", &[r#"{"content": "x y"}"#]);
let err = run_train(&corpus, 300, 1, tmp.path(), "nfkd", true)
.expect_err("should reject unsupported normalization");
match err {
CliError::ValidationFailed(msg) => assert!(msg.contains("nfkd")),
other => panic!("unexpected error: {other:?}"),
}
}
fn write_minimal_bpe_tokenizer_json(dir: &Path, n_vocab: usize, n_merges: usize) -> PathBuf {
let mut vocab = serde_json::Map::new();
for i in 0..n_vocab {
vocab.insert(format!("tok{i}"), serde_json::Value::Number(i.into()));
}
let merges: Vec<serde_json::Value> = (0..n_merges)
.map(|i| serde_json::Value::String(format!("a{i} b{i}")))
.collect();
let added_tokens = vec![serde_json::json!({
"id": n_vocab,
"content": "<|endoftext|>",
"special": true,
})];
let tok = serde_json::json!({
"version": "1.0",
"added_tokens": added_tokens,
"model": {
"type": "BPE",
"vocab": vocab,
"merges": merges,
},
});
let path = dir.join("tokenizer.json");
std::fs::write(
&path,
serde_json::to_string_pretty(&tok).expect("serialize tok"),
)
.expect("write tok");
path
}
#[test]
fn import_hf_qwen_bpe_writes_vocab_and_merges() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 1000, 800);
let output = tmp.path().join("extracted");
run_import_hf(&input, &output, false, true).expect("import-hf should succeed on BPE input");
let vocab_path = output.join("vocab.json");
assert!(vocab_path.exists(), "vocab.json must exist");
let vocab_str = std::fs::read_to_string(&vocab_path).expect("read vocab.json");
let vocab_obj: serde_json::Map<String, serde_json::Value> =
serde_json::from_str(&vocab_str).expect("parse vocab.json");
assert_eq!(
vocab_obj.len(),
1000,
"FALSIFY-TOK-IMPORT-HF-002: vocab.json must have 1000 entries (default mode), got {}",
vocab_obj.len()
);
let merges_path = output.join("merges.txt");
assert!(merges_path.exists(), "merges.txt must exist");
let merges_str = std::fs::read_to_string(&merges_path).expect("read merges.txt");
let merge_lines = merges_str.lines().filter(|l| !l.starts_with('#')).count();
assert_eq!(
merge_lines, 800,
"FALSIFY-TOK-IMPORT-HF-002: merges.txt must have 800 merge lines, got {merge_lines}"
);
}
#[test]
fn import_hf_vocab_count_matches_input() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 12345, 100);
let output = tmp.path().join("extracted");
run_import_hf(&input, &output, false, true).expect("import-hf");
let vocab_obj: serde_json::Map<String, serde_json::Value> =
serde_json::from_str(&std::fs::read_to_string(output.join("vocab.json")).unwrap())
.unwrap();
assert_eq!(
vocab_obj.len(),
12345,
"FALSIFY-TOK-IMPORT-HF-003: vocab count must match input model.vocab"
);
}
#[test]
fn import_hf_merges_format_and_order() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 10, 5);
let output = tmp.path().join("extracted");
run_import_hf(&input, &output, false, true).expect("import-hf");
let body = std::fs::read_to_string(output.join("merges.txt")).expect("read merges");
let lines: Vec<&str> = body.lines().filter(|l| !l.starts_with('#')).collect();
assert_eq!(lines.len(), 5);
for (i, line) in lines.iter().enumerate() {
assert_eq!(
line.trim(),
format!("a{i} b{i}"),
"FALSIFY-TOK-IMPORT-HF-004: merge[{i}] order or format mismatch"
);
}
}
#[test]
fn import_hf_unigram_input_errors() {
let tmp = TempDir::new().expect("tempdir");
let input = tmp.path().join("tokenizer.json");
let unigram = serde_json::json!({
"model": { "type": "Unigram", "vocab": [] },
});
std::fs::write(&input, serde_json::to_string_pretty(&unigram).unwrap()).unwrap();
let output = tmp.path().join("extracted");
let err = run_import_hf(&input, &output, false, true)
.expect_err("FALSIFY-TOK-IMPORT-HF-005: Unigram input MUST fail-fast");
match err {
CliError::ValidationFailed(msg) => {
assert!(
msg.contains("FALSIFY-TOK-IMPORT-HF-005"),
"error must cite falsifier id (auditability): {msg}"
);
assert!(
msg.contains("Unigram"),
"error must name the actual model type: {msg}"
);
}
other => panic!("unexpected error variant: {other:?}"),
}
}
#[test]
fn import_hf_include_added_tokens_appends_specials() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 100, 50);
let out_default = tmp.path().join("default");
run_import_hf(&input, &out_default, false, true).expect("default import");
let v_default: serde_json::Map<String, serde_json::Value> =
serde_json::from_str(&std::fs::read_to_string(out_default.join("vocab.json")).unwrap())
.unwrap();
assert_eq!(v_default.len(), 100);
assert!(
!v_default.contains_key("<|endoftext|>"),
"default mode must NOT include added_tokens"
);
let out_full = tmp.path().join("full");
run_import_hf(&input, &out_full, true, true).expect("full import");
let v_full: serde_json::Map<String, serde_json::Value> =
serde_json::from_str(&std::fs::read_to_string(out_full.join("vocab.json")).unwrap())
.unwrap();
assert_eq!(v_full.len(), 101);
assert!(
v_full.contains_key("<|endoftext|>"),
"include-added-tokens mode must include the special"
);
}
#[cfg(feature = "training")]
#[test]
fn encode_corpus_resolve_workers_default_is_available_parallelism() {
let resolved = resolve_num_workers(None).expect("default must resolve");
let expected = std::thread::available_parallelism()
.map(std::num::NonZeroUsize::get)
.unwrap_or(1);
assert_eq!(
resolved, expected,
"default must equal available_parallelism (or 1 fallback)"
);
assert!(resolved >= 1, "resolved worker count must be >= 1");
}
#[cfg(feature = "training")]
#[test]
fn encode_corpus_resolve_workers_explicit_value_passes_through() {
assert_eq!(resolve_num_workers(Some(1)).expect("Some(1)"), 1);
assert_eq!(resolve_num_workers(Some(4)).expect("Some(4)"), 4);
assert_eq!(resolve_num_workers(Some(64)).expect("Some(64)"), 64);
}
#[cfg(feature = "training")]
#[test]
fn encode_corpus_resolve_workers_rejects_zero() {
let err = resolve_num_workers(Some(0)).expect_err("zero must error");
match err {
CliError::ValidationFailed(msg) => {
assert!(msg.contains("--num-workers"), "error must name flag: {msg}");
assert!(msg.contains(">= 1"), "error must state bound: {msg}");
}
other => panic!("unexpected error variant: {other:?}"),
}
}
#[cfg(feature = "training")]
#[test]
fn encode_corpus_num_workers_1_matches_num_workers_n_byte_for_byte() {
let tmp = TempDir::new().expect("tempdir");
let train_corpus = write_corpus_file(
tmp.path(),
"train.jsonl",
&[
r#"{"content": "hello world the quick brown fox"}"#,
r#"{"content": "the lazy dog jumped over the fence"}"#,
r#"{"content": "rust is a systems programming language"}"#,
],
);
let tok_dir = tmp.path().join("tok");
run_train(&train_corpus, 400, 1, &tok_dir, "nfc", true).expect("train tokenizer");
let encode_lines: Vec<String> = (0..10)
.map(|i| {
format!(
r#"{{"content": "doc {i} alpha beta gamma the quick brown fox jumps {i}"}}"#
)
})
.collect();
let encode_refs: Vec<&str> = encode_lines.iter().map(String::as_str).collect();
let corpus = write_corpus_file(tmp.path(), "encode.jsonl", &encode_refs);
let out_1 = tmp.path().join("out_1");
run_encode_corpus(
std::slice::from_ref(&corpus),
&tok_dir,
&out_1,
10_000_000,
"content",
"nfc",
"between",
Some(1),
ProgressConfig {
quiet: true,
..ProgressConfig::default()
},
EstimateConfig::default(),
true,
)
.expect("encode --num-workers 1");
let out_n = tmp.path().join("out_n");
run_encode_corpus(
std::slice::from_ref(&corpus),
&tok_dir,
&out_n,
10_000_000,
"content",
"nfc",
"between",
Some(4),
ProgressConfig {
quiet: true,
..ProgressConfig::default()
},
EstimateConfig::default(),
true,
)
.expect("encode --num-workers 4");
let shards_1: Vec<_> = std::fs::read_dir(&out_1)
.expect("read out_1")
.filter_map(std::result::Result::ok)
.filter(|e| e.path().extension().and_then(std::ffi::OsStr::to_str) == Some("bin"))
.collect();
let shards_n: Vec<_> = std::fs::read_dir(&out_n)
.expect("read out_n")
.filter_map(std::result::Result::ok)
.filter(|e| e.path().extension().and_then(std::ffi::OsStr::to_str) == Some("bin"))
.collect();
assert_eq!(
shards_1.len(),
shards_n.len(),
"shard count must match across worker counts"
);
assert!(!shards_1.is_empty(), "test must produce at least one shard");
let mut names_1: Vec<String> = shards_1
.iter()
.map(|e| e.file_name().to_string_lossy().into_owned())
.collect();
let mut names_n: Vec<String> = shards_n
.iter()
.map(|e| e.file_name().to_string_lossy().into_owned())
.collect();
names_1.sort();
names_n.sort();
assert_eq!(
names_1, names_n,
"shard filenames must match (deterministic naming)"
);
for name in &names_1 {
let bytes_1 = std::fs::read(out_1.join(name)).expect("read single-threaded shard");
let bytes_n = std::fs::read(out_n.join(name)).expect("read parallel shard");
assert_eq!(
bytes_1, bytes_n,
"FALSIFY-#1547-PARITY: shard {name} must be byte-identical \
between --num-workers 1 and --num-workers 4"
);
}
}
#[cfg(feature = "training")]
#[test]
fn progress_emit_every_n_docs_when_under_seconds_window() {
let cfg = ProgressConfig {
quiet: false,
interval_docs: 1000,
interval_seconds: 60,
};
let emitter = ProgressEmitter::new(cfg, None);
let now = emitter.start;
assert!(
!emitter.should_emit(999, now),
"999 docs (< 1000 threshold) must not trigger doc-tick emission"
);
assert!(
emitter.should_emit(1000, now),
"1000 docs (== threshold) must trigger doc-tick emission"
);
assert!(
emitter.should_emit(5000, now),
"5000 docs must trigger doc-tick emission"
);
}
#[cfg(feature = "training")]
#[test]
fn progress_emit_every_n_seconds_when_under_docs_window() {
let cfg = ProgressConfig {
quiet: false,
interval_docs: 1_000_000,
interval_seconds: 1, };
let emitter = ProgressEmitter::new(cfg, None);
let now0 = emitter.start;
assert!(
!emitter.should_emit(10, now0),
"0s elapsed must not trigger time-tick emission"
);
let now1 = emitter.start + std::time::Duration::from_secs(1);
assert!(
emitter.should_emit(10, now1),
"1s elapsed must trigger time-tick emission even with only 10 docs"
);
}
#[cfg(feature = "training")]
#[test]
fn progress_quiet_flag_suppresses_emission() {
let cfg = ProgressConfig {
quiet: true,
interval_docs: 1,
interval_seconds: 1,
};
let emitter = ProgressEmitter::new(cfg, None);
let now = emitter.start + std::time::Duration::from_secs(120);
assert!(
!emitter.should_emit(10_000, now),
"quiet=true must suppress emission regardless of doc/time window"
);
}
#[cfg(feature = "training")]
#[test]
fn progress_format_line_no_total_omits_eta_fragment() {
let cfg = ProgressConfig::default();
let emitter = ProgressEmitter::new(cfg, None);
let now = emitter.start + std::time::Duration::from_secs(10);
let line = emitter.format_line(2000, 50_000, now);
assert!(line.starts_with("[progress] "), "expected prefix: {line}");
assert!(line.contains("doc=2000"), "doc count missing: {line}");
assert!(
!line.contains("doc=2000/"),
"must not include /T fragment when total unknown: {line}"
);
assert!(
!line.contains("eta="),
"must not include eta= when total unknown: {line}"
);
assert!(
line.contains("tokens=50000"),
"tokens count missing: {line}"
);
assert!(
line.contains("rate=") && line.contains("docs/s"),
"rate fragment missing: {line}"
);
}
#[cfg(feature = "training")]
#[test]
fn progress_format_line_with_total_includes_eta_fragment() {
let cfg = ProgressConfig::default();
let emitter = ProgressEmitter::new(cfg, Some(10_000));
let now = emitter.start + std::time::Duration::from_secs(5);
let line = emitter.format_line(1000, 25_000, now);
assert!(
line.contains("doc=1000/10000"),
"doc/total fragment missing: {line}"
);
assert!(line.contains("eta="), "eta fragment missing: {line}");
let eta_idx = line.find("eta=").expect("eta= present");
let after_eta = &line[eta_idx + 4..];
assert!(
after_eta.contains('T') && after_eta.trim_end().ends_with('Z'),
"eta must be ISO-8601 UTC (`...T...Z`): {line}"
);
}
#[cfg(feature = "training")]
#[test]
fn progress_mark_emitted_resets_both_clocks() {
    // Both triggers are live: every 1000 docs OR every 60 seconds.
    let cfg = ProgressConfig {
        quiet: false,
        interval_docs: 1000,
        interval_seconds: 60,
    };
    let mut emitter = ProgressEmitter::new(cfg, None);
    let t0 = emitter.start;

    // Doc clock: after emitting at 1000 docs, the next emission is due 1000 docs later.
    assert!(emitter.should_emit(1000, t0));
    emitter.mark_emitted(1000, t0);
    assert!(
        !emitter.should_emit(1500, t0),
        "1500 - 1000 = 500 < 1000 threshold; must not re-emit"
    );
    assert!(
        emitter.should_emit(2000, t0),
        "2000 - 1000 = 1000 >= threshold; must re-emit"
    );

    // Time clock: marking at `t0` (== `start`) cannot tell "reset to the emit
    // instant" apart from "never moved off `start`". So emit again 90s in and
    // probe relative to that instant. If the time clock were not reset, 120s
    // would have elapsed since `start` at the first probe, forcing a re-emit.
    let t1 = t0 + std::time::Duration::from_secs(90);
    emitter.mark_emitted(2000, t1);
    assert!(
        !emitter.should_emit(2500, t1 + std::time::Duration::from_secs(30)),
        "30s since last emit < 60s interval and 500 docs < 1000; must not re-emit \
         (time clock must restart at mark_emitted)"
    );
    assert!(
        emitter.should_emit(2500, t1 + std::time::Duration::from_secs(61)),
        "61s since last emit >= 60s interval; must re-emit on time tick"
    );
}
#[cfg(feature = "training")]
#[test]
fn estimate_only_extrapolation_formula_correct() {
    // Positional arguments, as implied by the assertions below:
    // (sample_docs, sample_tokens, sample_wall_secs, total_docs, shard_tokens, num_workers).
    // NOTE(review): confirm against `extrapolate_estimate`'s signature.
    let (tokens, shards, wall) =
        extrapolate_estimate(1000, 50_000, 1.0, 100_000, 1_000_000, 4);
    assert_eq!(tokens, 5_000_000, "estimated_total_tokens math");
    assert_eq!(shards, 5, "estimated_shards must be ceil(total/shard)");
    assert!(
        (wall - 25.0).abs() < 0.01,
        "estimated_wall = wall_per_doc × total_docs / num_workers; got {wall}"
    );

    // 5_000_000 / 1_000_000 divides evenly, so the case above cannot tell ceil
    // from floor. 5_000_000 / 3_000_000 ≈ 1.67 must round UP to 2 shards.
    let (_, shards_partial, _) =
        extrapolate_estimate(1000, 50_000, 1.0, 100_000, 3_000_000, 4);
    assert_eq!(
        shards_partial, 2,
        "a partially filled final shard must still be counted (ceil, not floor)"
    );

    // An empty sample extrapolates to nothing.
    assert_eq!(extrapolate_estimate(0, 0, 0.0, 100, 1000, 4), (0, 0, 0.0));

    let (_, _, wall_zero_workers) =
        extrapolate_estimate(1000, 50_000, 1.0, 100_000, 1_000_000, 0);
    assert!(
        (wall_zero_workers - 100.0).abs() < 0.01,
        "0 workers must clamp to 1; got {wall_zero_workers}"
    );

    let (_, shards_zero, _) = extrapolate_estimate(1000, 50_000, 1.0, 100_000, 0, 4);
    assert_eq!(
        shards_zero, 0,
        "shard_tokens=0 must yield 0 estimated_shards"
    );
}
#[cfg(feature = "training")]
#[test]
fn estimate_only_no_shards_written() {
    let tmp = TempDir::new().expect("tempdir");
    let root = tmp.path();

    // Train a tiny tokenizer to drive the encode step.
    let train_corpus = write_corpus_file(
        root,
        "train.jsonl",
        &[
            r#"{"content": "alpha beta gamma delta"}"#,
            r#"{"content": "epsilon zeta eta theta"}"#,
        ],
    );
    let tok_dir = root.join("tok");
    run_train(&train_corpus, 400, 1, &tok_dir, "nfc", true).expect("train tokenizer");

    // 50-document corpus; the estimate samples only 10 of them.
    let docs: Vec<String> = (0..50)
        .map(|i| format!(r#"{{"content": "doc {i} alpha beta gamma {i}"}}"#))
        .collect();
    let doc_refs: Vec<&str> = docs.iter().map(String::as_str).collect();
    let corpus = write_corpus_file(root, "encode.jsonl", &doc_refs);

    let out = root.join("out_estimate");
    let progress = ProgressConfig {
        quiet: true,
        ..ProgressConfig::default()
    };
    let estimate = EstimateConfig {
        enabled: true,
        sample_docs: 10,
    };
    run_encode_corpus(
        std::slice::from_ref(&corpus),
        &tok_dir,
        &out,
        10_000_000,
        "content",
        "nfc",
        "between",
        Some(1),
        progress,
        estimate,
        true,
    )
    .expect("estimate-only must succeed without writing shards");

    // An output directory that was never created trivially holds no shards.
    if !out.exists() {
        return;
    }
    let shard_count = std::fs::read_dir(&out)
        .expect("read estimate out dir")
        .filter_map(std::result::Result::ok)
        .filter(|entry| {
            entry.path().extension().and_then(std::ffi::OsStr::to_str) == Some("bin")
        })
        .count();
    assert!(
        shard_count == 0,
        "FALSIFY-APR-TOK-PAR-012: --estimate-only produced {} \
         shard(s); estimate is supposed to write nothing",
        shard_count
    );
}
#[cfg(feature = "training")]
#[test]
fn estimate_only_no_manifest_written() {
    let tmp = TempDir::new().expect("tempdir");
    let root = tmp.path();

    // Minimal tokenizer for the encode step.
    let train_corpus =
        write_corpus_file(root, "train.jsonl", &[r#"{"content": "alpha beta gamma"}"#]);
    let tok_dir = root.join("tok");
    run_train(&train_corpus, 400, 1, &tok_dir, "nfc", true).expect("train tokenizer");

    // 20 short documents, of which the estimate samples 5.
    let docs: Vec<String> = (0..20)
        .map(|i| format!(r#"{{"content": "doc {i}"}}"#))
        .collect();
    let doc_refs: Vec<&str> = docs.iter().map(String::as_str).collect();
    let corpus = write_corpus_file(root, "encode.jsonl", &doc_refs);

    let out = root.join("out_no_manifest");
    let progress = ProgressConfig {
        quiet: true,
        ..ProgressConfig::default()
    };
    let estimate = EstimateConfig {
        enabled: true,
        sample_docs: 5,
    };
    run_encode_corpus(
        std::slice::from_ref(&corpus),
        &tok_dir,
        &out,
        10_000_000,
        "content",
        "nfc",
        "between",
        Some(1),
        progress,
        estimate,
        true,
    )
    .expect("estimate-only must succeed");

    // An estimate-only run must not leave a manifest behind.
    let manifest = out.join("manifest.json");
    assert!(
        !manifest.exists(),
        "FALSIFY-APR-TOK-PAR-013: --estimate-only produced manifest.json at {}",
        manifest.display()
    );
}
/// Smoke test for the `--estimate-only` path on a small, valid corpus.
///
/// Despite its name, this test does not capture stderr. It only checks that
/// the estimate run returns `Ok(())`.
#[cfg(feature = "training")]
#[test]
fn estimate_only_emits_estimate_lines_to_stderr() {
    let tmp = TempDir::new().expect("tempdir");
    let root = tmp.path();

    let train_corpus = write_corpus_file(
        root,
        "train.jsonl",
        &[r#"{"content": "alpha beta gamma delta epsilon"}"#],
    );
    let tok_dir = root.join("tok");
    run_train(&train_corpus, 400, 1, &tok_dir, "nfc", true).expect("train tokenizer");

    let docs: Vec<String> = (0..15)
        .map(|i| format!(r#"{{"content": "doc {i} content"}}"#))
        .collect();
    let doc_refs: Vec<&str> = docs.iter().map(String::as_str).collect();
    let corpus = write_corpus_file(root, "encode.jsonl", &doc_refs);

    let out = root.join("out_lines");
    let progress = ProgressConfig {
        quiet: true,
        ..ProgressConfig::default()
    };
    let estimate = EstimateConfig {
        enabled: true,
        sample_docs: 5,
    };
    let result = run_encode_corpus(
        std::slice::from_ref(&corpus),
        &tok_dir,
        &out,
        10_000_000,
        "content",
        "nfc",
        "between",
        Some(1),
        progress,
        estimate,
        true,
    );
    assert!(
        result.is_ok(),
        "FALSIFY-APR-TOK-PAR-014: --estimate-only path must return Ok(()) \
         on a valid corpus + tokenizer; got {result:?}"
    );
}
#[cfg(feature = "training")]
#[test]
fn estimate_only_rejects_zero_sample_size() {
    let tmp = TempDir::new().expect("tempdir");
    let root = tmp.path();

    let train_corpus = write_corpus_file(root, "train.jsonl", &[r#"{"content": "alpha"}"#]);
    let tok_dir = root.join("tok");
    run_train(&train_corpus, 400, 1, &tok_dir, "nfc", true).expect("train tokenizer");

    let corpus = write_corpus_file(root, "encode.jsonl", &[r#"{"content": "doc"}"#]);
    let out = root.join("out_zero_sample");

    // A zero-document sample is meaningless and must be rejected up front.
    let progress = ProgressConfig {
        quiet: true,
        ..ProgressConfig::default()
    };
    let estimate = EstimateConfig {
        enabled: true,
        sample_docs: 0,
    };
    let err = run_encode_corpus(
        std::slice::from_ref(&corpus),
        &tok_dir,
        &out,
        10_000_000,
        "content",
        "nfc",
        "between",
        Some(1),
        progress,
        estimate,
        true,
    )
    .expect_err("sample_docs=0 must error");

    let msg = match err {
        CliError::ValidationFailed(msg) => msg,
        other => panic!("unexpected error variant: {other:?}"),
    };
    assert!(
        msg.contains("--estimate-sample-docs"),
        "error must name flag: {msg}"
    );
}
// NOTE(review): unlike the `--estimate-only` tests above, this test has no
// `#[cfg(feature = "training")]` gate even though it calls `run_train`;
// confirm that is intended.
#[test]
fn encode_corpus_accepts_multiple_corpus_paths() {
    let tmp = TempDir::new().expect("tempdir");
    let root = tmp.path();

    let train_corpus = write_corpus_file(
        root,
        "train.jsonl",
        &[r#"{"content": "alpha beta gamma delta epsilon"}"#],
    );
    let tok_dir = root.join("tok");
    run_train(&train_corpus, 400, 1, &tok_dir, "nfc", true).expect("train tokenizer");

    // Two separate sources: 2 docs in src_a, 1 doc in src_b.
    let corpus_a = write_corpus_file(
        root,
        "src_a.jsonl",
        &[
            r#"{"content": "alpha alpha alpha"}"#,
            r#"{"content": "alpha beta"}"#,
        ],
    );
    let corpus_b = write_corpus_file(
        root,
        "src_b.jsonl",
        &[r#"{"content": "gamma delta epsilon"}"#],
    );

    let out = root.join("merged_out");
    let progress = ProgressConfig {
        quiet: true,
        ..ProgressConfig::default()
    };
    // The paths are not needed after this call, so they are moved in rather than cloned.
    let result = run_encode_corpus(
        &[corpus_a, corpus_b],
        &tok_dir,
        &out,
        10_000_000,
        "content",
        "nfc",
        "between",
        Some(1),
        progress,
        EstimateConfig::default(),
        true,
    );
    assert!(
        result.is_ok(),
        "multi-corpus encode must succeed: {result:?}"
    );

    let manifest: serde_json::Value = serde_json::from_slice(
        &std::fs::read(out.join("manifest.json")).expect("manifest written"),
    )
    .expect("manifest is valid JSON");

    let total_documents = manifest
        .get("total_documents")
        .and_then(serde_json::Value::as_u64)
        .expect("manifest has total_documents");
    assert_eq!(
        total_documents, 3,
        "merged corpus must contain 3 docs (2 from src_a + 1 from src_b), got {total_documents}"
    );

    let input_files = manifest
        .get("input_files")
        .and_then(serde_json::Value::as_array)
        .expect("manifest has input_files");
    assert!(
        input_files.len() >= 2,
        "input_files must list at least 2 source files, got {}",
        input_files.len(),
    );

    let corpus_roots = manifest
        .get("corpus_roots")
        .and_then(serde_json::Value::as_array)
        .expect("manifest has corpus_roots");
    assert_eq!(
        corpus_roots.len(),
        2,
        "corpus_roots must list exactly 2 sources (src_a + src_b), got {}",
        corpus_roots.len(),
    );

    // Each source file must be recorded among the roots.
    let root_strs: Vec<&str> = corpus_roots
        .iter()
        .filter_map(serde_json::Value::as_str)
        .collect();
    let mentions = |needle: &str| root_strs.iter().any(|s| s.contains(needle));
    assert!(
        mentions("src_a.jsonl"),
        "corpus_roots must reference src_a, got {root_strs:?}",
    );
    assert!(
        mentions("src_b.jsonl"),
        "corpus_roots must reference src_b, got {root_strs:?}",
    );
}
#[test]
fn encode_corpus_rejects_empty_corpus_list() {
    let tmp = TempDir::new().expect("tempdir");
    let root = tmp.path();

    let train_corpus = write_corpus_file(root, "train.jsonl", &[r#"{"content": "alpha"}"#]);
    let tok_dir = root.join("tok");
    run_train(&train_corpus, 400, 1, &tok_dir, "nfc", true).expect("train tokenizer");

    let out = root.join("out_empty");
    let progress = ProgressConfig {
        quiet: true,
        ..ProgressConfig::default()
    };
    // No corpus paths at all: this must be a validation error, not an empty encode.
    let err = run_encode_corpus(
        &[],
        &tok_dir,
        &out,
        10_000_000,
        "content",
        "nfc",
        "between",
        Some(1),
        progress,
        EstimateConfig::default(),
        true,
    )
    .expect_err("empty --corpus slice must error");

    let msg = match err {
        CliError::ValidationFailed(msg) => msg,
        other => panic!("unexpected error variant: {other:?}"),
    };
    assert!(
        msg.contains("--corpus") || msg.contains("required"),
        "error must mention missing corpus: {msg}"
    );
}
}