use std::path::{Path, PathBuf};
use std::time::Instant;
use colored::Colorize;
use crate::{error::CliError, output};
type Result<T> = std::result::Result<T, CliError>;
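/// Pre-flight for tokenizer training: validates the algorithm and vocab
/// size, analyzes the corpus, and prints a plan (text or JSON) without
/// writing anything to disk. Errors if the plan verdict is "blocked".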
#[provable_contracts_macros::contract(
"apr-cli-operations-v1",
equation = "side_effect_classification"
)]
pub(crate) fn run_plan(
data: &Path,
vocab_size: usize,
algorithm: &str,
output_dir: &Path,
format: &str,
json_output: bool,
) -> Result<()> {
contract_pre_tokenizer_training_correctness!();
validate_algorithm(algorithm)?;
validate_vocab_size(vocab_size)?;
if !data.exists() {
return Err(CliError::FileNotFound(data.to_path_buf()));
}
let corpus_stats = analyze_corpus(data)?;
let plan = TokenizePlan {
algorithm: algorithm.to_string(),
vocab_size,
corpus_path: data.display().to_string(),
corpus_lines: corpus_stats.lines,
corpus_bytes: corpus_stats.bytes,
unique_chars: corpus_stats.unique_chars,
output_dir: output_dir.display().to_string(),
estimated_minutes: estimate_training_time(corpus_stats.bytes, vocab_size),
verdict: plan_verdict(&corpus_stats, vocab_size),
};
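// An explicit --json flag takes precedence over --format.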
let effective_format = if json_output { "json" } else { format };
match effective_format {
"json" => {
let json = serde_json::to_string_pretty(&plan)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?;
println!("{json}");
}
"yaml" => {
return Err(CliError::ValidationFailed(
"YAML output not supported. Use --format json or --format text.".to_string(),
));
}
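// Any other format value falls through to the plain-text report.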
_ => print_plan_text(&plan),
}
if plan.verdict == "blocked" {
return Err(CliError::ValidationFailed(
"Plan is blocked — resolve failures before applying".to_string(),
));
}
contract_post_tokenizer_training_correctness!(&());
Ok(())
}
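/// Trains a tokenizer on `data` and writes vocab.json and merges.txt into
/// `output_dir`. `max_lines == 0` reads the entire corpus.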
#[provable_contracts_macros::contract(
"apr-cli-operations-v1",
equation = "side_effect_classification"
)]
pub(crate) fn run_apply(
data: &Path,
vocab_size: usize,
algorithm: &str,
output_dir: &Path,
max_lines: usize,
json_output: bool,
) -> Result<()> {
validate_algorithm(algorithm)?;
validate_vocab_size(vocab_size)?;
if !data.exists() {
return Err(CliError::FileNotFound(data.to_path_buf()));
}
let corpus_text = read_corpus(data, max_lines)?;
let corpus_refs: Vec<&str> = corpus_text.iter().map(String::as_str).collect();
if corpus_refs.is_empty() {
return Err(CliError::ValidationFailed(
"Corpus is empty — no text to train on".to_string(),
));
}
if !json_output {
print_apply_header(data, vocab_size, algorithm, output_dir, corpus_refs.len());
}
let start = Instant::now();
let tokenizer = train_tokenizer(&corpus_refs, vocab_size, algorithm)?;
let elapsed = start.elapsed();
std::fs::create_dir_all(output_dir).map_err(|e| {
CliError::ValidationFailed(format!(
"Cannot create output directory {}: {e}",
output_dir.display()
))
})?;
let actual_vocab_size = tokenizer.vocab_size();
write_vocab_json(output_dir, &tokenizer)?;
write_merges_txt(output_dir, &tokenizer)?;
let result = TokenizeResult {
algorithm: algorithm.to_string(),
vocab_size: actual_vocab_size,
corpus_lines: corpus_refs.len(),
training_seconds: elapsed.as_secs_f64(),
output_dir: output_dir.display().to_string(),
};
if json_output {
let json = serde_json::to_string_pretty(&result)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?;
println!("{json}");
} else {
print_apply_result(&result);
}
Ok(())
}
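/// Accepts only the algorithm names that `train_tokenizer` dispatches on.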
fn validate_algorithm(algorithm: &str) -> Result<()> {
match algorithm {
"bpe" | "wordpiece" | "unigram" => Ok(()),
_ => Err(CliError::ValidationFailed(format!(
"Unknown algorithm: {algorithm}. Supported: bpe, wordpiece, unigram"
))),
}
}
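/// Bounds check: vocab_size must fall in 10..=1_000_000.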
fn validate_vocab_size(vocab_size: usize) -> Result<()> {
if vocab_size < 10 {
return Err(CliError::ValidationFailed(format!(
"vocab_size must be at least 10, got {vocab_size}"
)));
}
if vocab_size > 1_000_000 {
return Err(CliError::ValidationFailed(format!(
"vocab_size {vocab_size} is unreasonably large (max 1M)"
)));
}
Ok(())
}
#[derive(serde::Serialize)]
struct TokenizePlan {
algorithm: String,
vocab_size: usize,
corpus_path: String,
corpus_lines: usize,
corpus_bytes: u64,
unique_chars: usize,
output_dir: String,
estimated_minutes: f64,
verdict: String,
}
#[derive(serde::Serialize)]
struct TokenizeResult {
algorithm: String,
vocab_size: usize,
corpus_lines: usize,
training_seconds: f64,
output_dir: String,
}
struct CorpusStats {
lines: usize,
bytes: u64,
unique_chars: usize,
}
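/// Reads the whole corpus into memory to count lines and distinct
/// characters; the byte size comes from file metadata.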
fn analyze_corpus(path: &Path) -> Result<CorpusStats> {
let metadata = std::fs::metadata(path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot stat {}: {e}", path.display())))?;
let bytes = metadata.len();
let content = std::fs::read_to_string(path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot read {}: {e}", path.display())))?;
let lines = content.lines().count();
let unique_chars: std::collections::HashSet<char> = content.chars().collect();
Ok(CorpusStats {
lines,
bytes,
unique_chars: unique_chars.len(),
})
}
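/// Rough heuristic, not a calibrated model: about one second per MiB at a
/// 32K-vocab baseline, scaled linearly for larger vocabularies.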
fn estimate_training_time(bytes: u64, vocab_size: usize) -> f64 {
let mb = bytes as f64 / (1024.0 * 1024.0);
let vocab_factor = (vocab_size as f64 / 32000.0).max(1.0);
(mb * vocab_factor) / 60.0
}
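/// "blocked" for an empty corpus, "warning" when the requested vocab
/// exceeds 100x the distinct character count (merges are unlikely to be
/// well supported), otherwise "ready".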
fn plan_verdict(stats: &CorpusStats, vocab_size: usize) -> String {
if stats.lines == 0 {
return "blocked".to_string();
}
if vocab_size > stats.unique_chars * 100 {
return "warning".to_string();
}
"ready".to_string()
}
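/// Reads corpus lines into owned strings; `max_lines == 0` means no limit.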
fn read_corpus(path: &Path, max_lines: usize) -> Result<Vec<String>> {
let content = std::fs::read_to_string(path).map_err(|e| {
CliError::ValidationFailed(format!("Cannot read corpus {}: {e}", path.display()))
})?;
let lines: Vec<String> = if max_lines > 0 {
content.lines().take(max_lines).map(String::from).collect()
} else {
content.lines().map(String::from).collect()
};
Ok(lines)
}
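/// Algorithm-agnostic training output. `merges` is empty for WordPiece
/// and Unigram, which have no merge table.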
struct TrainedTokenizer {
vocab: std::collections::HashMap<String, u32>,
merges: Vec<(String, String)>,
}
impl TrainedTokenizer {
fn vocab_size(&self) -> usize {
self.vocab.len()
}
}
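/// With the `training` feature, BPE training is delegated to entrenar,
/// which honors both `min_frequency` and `normalization`.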
#[cfg(feature = "training")]
fn train_bpe_via_entrenar(
corpus: &[&str],
vocab_size: usize,
min_frequency: usize,
normalization: &str,
) -> Result<TrainedTokenizer> {
use entrenar::tokenizer::{BPETokenizer, Normalization, Tokenizer, TokenizerConfig};
let norm = match normalization {
"nfc" => Normalization::NFC,
"none" => Normalization::None,
other => {
return Err(CliError::ValidationFailed(format!(
"Unknown normalization: {other}. Supported: none, nfc"
)));
}
};
let config = TokenizerConfig::bpe()
.with_vocab_size(vocab_size)
.with_min_frequency(min_frequency)
.with_normalization(norm);
let mut tokenizer = BPETokenizer::new(config);
tokenizer
.train(corpus)
.map_err(|e| CliError::ValidationFailed(format!("BPE training failed: {e}")))?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab().clone(),
merges: tokenizer.merges().to_vec(),
})
}
#[cfg(not(feature = "training"))]
fn train_bpe_via_entrenar(
corpus: &[&str],
vocab_size: usize,
_min_frequency: usize,
_normalization: &str,
) -> Result<TrainedTokenizer> {
let tokenizer = aprender::text::tokenize::BpeTokenizer::train(corpus, vocab_size)
.map_err(|e| CliError::ValidationFailed(format!("BPE training failed: {e}")))?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab().clone(),
merges: tokenizer.merges().to_vec(),
})
}
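/// Dispatches on the already-validated algorithm name to aprender's
/// trainers; callers run `validate_algorithm` first, so the catch-all arm
/// is unreachable.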
fn train_tokenizer(
corpus: &[&str],
vocab_size: usize,
algorithm: &str,
) -> Result<TrainedTokenizer> {
match algorithm {
"bpe" => {
let tokenizer = aprender::text::tokenize::BpeTokenizer::train(corpus, vocab_size)
.map_err(|e| CliError::ValidationFailed(format!("BPE training failed: {e}")))?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab().clone(),
merges: tokenizer.merges().to_vec(),
})
}
"wordpiece" => {
let tokenizer = aprender::text::tokenize::WordPieceTokenizer::train(corpus, vocab_size)
.map_err(|e| {
CliError::ValidationFailed(format!("WordPiece training failed: {e}"))
})?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab().clone(),
merges: Vec::new(),
})
}
"unigram" => {
let tokenizer = aprender::text::tokenize::UnigramTokenizer::train(corpus, vocab_size)
.map_err(|e| {
CliError::ValidationFailed(format!("Unigram training failed: {e}"))
})?;
Ok(TrainedTokenizer {
vocab: tokenizer.vocab_ids(),
merges: Vec::new(),
})
}
_ => unreachable!("algorithm validated above"),
}
}
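/// Writes the token -> id map as pretty-printed JSON, ordered by id.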
fn write_vocab_json(output_dir: &Path, tokenizer: &TrainedTokenizer) -> Result<()> {
let vocab_path = output_dir.join("vocab.json");
let mut entries: Vec<(&String, &u32)> = tokenizer.vocab.iter().collect();
entries.sort_by_key(|(_, id)| *id);
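// Note: serde_json::Map preserves this id ordering only when serde_json's
// "preserve_order" feature is enabled; the default BTreeMap-backed Map
// re-sorts keys lexicographically.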
let ordered: serde_json::Map<String, serde_json::Value> = entries
.into_iter()
.map(|(k, v)| (k.clone(), serde_json::Value::Number((*v).into())))
.collect();
let json = serde_json::to_string_pretty(&ordered)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?;
std::fs::write(&vocab_path, json).map_err(|e| {
CliError::ValidationFailed(format!("Cannot write {}: {e}", vocab_path.display()))
})?;
Ok(())
}
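/// Writes a GPT-2-style merges.txt: a "#version: 0.2" header followed by
/// one space-separated pair per line, in merge-priority order.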
fn write_merges_txt(output_dir: &Path, tokenizer: &TrainedTokenizer) -> Result<()> {
let merges_path = output_dir.join("merges.txt");
let mut content = String::from("#version: 0.2\n");
for (left, right) in &tokenizer.merges {
content.push_str(left);
content.push(' ');
content.push_str(right);
content.push('\n');
}
std::fs::write(&merges_path, content).map_err(|e| {
CliError::ValidationFailed(format!("Cannot write {}: {e}", merges_path.display()))
})?;
Ok(())
}
fn print_plan_text(plan: &TokenizePlan) {
output::header("apr tokenize plan — Tokenizer Training Pre-flight");
println!();
output::section("Configuration");
output::kv(" Algorithm", &plan.algorithm);
output::kv(" Vocab size", format_number(plan.vocab_size));
output::kv(" Corpus", &plan.corpus_path);
output::kv(" Output", &plan.output_dir);
println!();
output::section("Corpus Analysis");
output::kv(" Lines", format_number(plan.corpus_lines));
output::kv(" Size", format_bytes(plan.corpus_bytes));
output::kv(" Unique chars", format_number(plan.unique_chars));
println!();
output::section("Estimates");
output::kv(" Training time", format_duration(plan.estimated_minutes));
println!();
let verdict_display = match plan.verdict.as_str() {
"ready" => format!("{}", "READY".green().bold()),
"warning" => format!("{}", "WARNING".yellow().bold()),
"blocked" => format!("{}", "BLOCKED".red().bold()),
_ => plan.verdict.clone(),
};
output::kv(" Verdict", verdict_display);
println!();
}
fn print_apply_header(
data: &Path,
vocab_size: usize,
algorithm: &str,
output_dir: &Path,
corpus_lines: usize,
) {
output::header("apr tokenize apply — Training Tokenizer");
println!();
output::kv(" Algorithm", algorithm);
output::kv(" Vocab size", format_number(vocab_size));
output::kv(" Corpus", data.display().to_string());
output::kv(" Lines", format_number(corpus_lines));
output::kv(" Output", output_dir.display().to_string());
println!();
}
fn print_apply_result(result: &TokenizeResult) {
output::section("Result");
println!(" {} Tokenizer trained successfully", "OK".green().bold());
output::kv(" Final vocab size", format_number(result.vocab_size));
output::kv(
" Training time",
format!("{:.1}s", result.training_seconds),
);
output::kv(" vocab.json", format!("{}/vocab.json", result.output_dir));
output::kv(" merges.txt", format!("{}/merges.txt", result.output_dir));
println!();
}
fn format_number(n: usize) -> String {
if n >= 1_000_000 {
format!("{:.1}M", n as f64 / 1_000_000.0)
} else if n >= 1_000 {
format!("{:.1}K", n as f64 / 1_000.0)
} else {
n.to_string()
}
}
fn format_bytes(bytes: u64) -> String {
if bytes >= 1_073_741_824 {
format!("{:.1} GB", bytes as f64 / 1_073_741_824.0)
} else if bytes >= 1_048_576 {
format!("{:.1} MB", bytes as f64 / 1_048_576.0)
} else if bytes >= 1024 {
format!("{:.1} KB", bytes as f64 / 1024.0)
} else {
format!("{bytes} B")
}
}
fn format_duration(minutes: f64) -> String {
if minutes < 1.0 {
format!("{:.0} sec", minutes * 60.0)
} else if minutes < 60.0 {
format!("{:.1} min", minutes)
} else {
format!("{:.1} hours", minutes / 60.0)
}
}
#[derive(serde::Serialize)]
struct TokenizeTrainResult {
algorithm: String,
vocab_size: usize,
corpus_lines: usize,
corpus_files: usize,
min_frequency: usize,
normalization: String,
training_seconds: f64,
output_dir: String,
}
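/// Trains a BPE tokenizer from a JSONL corpus (a single .jsonl file or a
/// directory of them), using the `content` field of each line as text.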
pub(crate) fn run_train(
corpus: &Path,
vocab_size: usize,
min_frequency: usize,
output_dir: &Path,
normalization: &str,
json_output: bool,
) -> Result<()> {
validate_vocab_size(vocab_size)?;
validate_normalization(normalization)?;
if !corpus.exists() {
return Err(CliError::FileNotFound(corpus.to_path_buf()));
}
let files = collect_jsonl_files(corpus)?;
if files.is_empty() {
return Err(CliError::ValidationFailed(format!(
"No .jsonl files found in {} (the directory is not scanned recursively)",
corpus.display()
)));
}
let mut lines: Vec<String> = Vec::new();
for file in &files {
read_jsonl_content(file, &mut lines)?;
}
if lines.is_empty() {
return Err(CliError::ValidationFailed(
"Corpus contained no `content` fields — nothing to train on".to_string(),
));
}
if !json_output {
print_train_header(corpus, vocab_size, output_dir, files.len(), lines.len());
}
let refs: Vec<&str> = lines.iter().map(String::as_str).collect();
let start = Instant::now();
let trained = train_bpe_via_entrenar(&refs, vocab_size, min_frequency, normalization)?;
let elapsed = start.elapsed();
std::fs::create_dir_all(output_dir).map_err(|e| {
CliError::ValidationFailed(format!(
"Cannot create output directory {}: {e}",
output_dir.display()
))
})?;
write_vocab_json(output_dir, &trained)?;
write_merges_txt(output_dir, &trained)?;
let result = TokenizeTrainResult {
algorithm: "bpe".to_string(),
vocab_size: trained.vocab_size(),
corpus_lines: lines.len(),
corpus_files: files.len(),
min_frequency,
normalization: normalization.to_string(),
training_seconds: elapsed.as_secs_f64(),
output_dir: output_dir.display().to_string(),
};
if json_output {
let json = serde_json::to_string_pretty(&result)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?;
println!("{json}");
} else {
print_train_result(&result);
}
Ok(())
}
fn validate_normalization(norm: &str) -> Result<()> {
match norm {
"none" | "nfc" => Ok(()),
other => Err(CliError::ValidationFailed(format!(
"Unknown normalization: {other}. Supported: none, nfc"
))),
}
}
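/// Encodes a JSONL or Parquet corpus into binary token shards with a
/// previously trained BPE tokenizer. Token ids are written little-endian
/// (via `to_le_bytes`) into shard-NNNNN.bin files, EOS insertion is
/// governed by `eos_policy` (none | between | after), every id is checked
/// against the vocab size before being written (INV-PRETOK-001), and a
/// manifest.json describes the run.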
#[cfg(feature = "training")]
#[allow(clippy::too_many_arguments)]
pub(crate) fn run_encode_corpus(
corpus: &Path,
tokenizer_dir: &Path,
output_dir: &Path,
shard_tokens: usize,
content_field: &str,
normalization: &str,
eos_policy: &str,
json_output: bool,
) -> Result<()> {
use entrenar::tokenizer::{BPETokenizer, Normalization, Tokenizer, TokenizerConfig};
use std::io::Write as IoWrite;
validate_normalization(normalization)?;
match eos_policy {
"none" | "between" | "after" => {}
other => {
return Err(CliError::ValidationFailed(format!(
"Unknown eos_policy: {other}. Supported: none, between, after"
)));
}
}
if shard_tokens == 0 {
return Err(CliError::ValidationFailed(
"shard_tokens must be > 0".to_string(),
));
}
if !corpus.exists() {
return Err(CliError::FileNotFound(corpus.to_path_buf()));
}
let vocab_path = tokenizer_dir.join("vocab.json");
let merges_path = tokenizer_dir.join("merges.txt");
if !vocab_path.exists() {
return Err(CliError::FileNotFound(vocab_path));
}
if !merges_path.exists() {
return Err(CliError::FileNotFound(merges_path));
}
let norm = match normalization {
"nfc" => Normalization::NFC,
"none" => Normalization::None,
_ => unreachable!("validated above"),
};
let config = TokenizerConfig::bpe().with_normalization(norm);
let tokenizer = BPETokenizer::from_vocab_merges(
vocab_path.to_str().ok_or_else(|| {
CliError::ValidationFailed("vocab.json path is not valid UTF-8".to_string())
})?,
merges_path.to_str().ok_or_else(|| {
CliError::ValidationFailed("merges.txt path is not valid UTF-8".to_string())
})?,
config,
)
.map_err(|e| CliError::ValidationFailed(format!("Cannot load tokenizer: {e}")))?;
let vocab_size = tokenizer.vocab_size();
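// Probe common EOS spellings; the first one present in the vocab wins.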
let eos_id = ["</s>", "<|endoftext|>", "<eos>", "<|eos|>"]
.iter()
.find_map(|name| tokenizer.token_to_id(name));
let (files, corpus_format) = collect_corpus_files(corpus)?;
std::fs::create_dir_all(output_dir).map_err(|e| {
CliError::ValidationFailed(format!(
"Cannot create output directory {}: {e}",
output_dir.display()
))
})?;
let start = Instant::now();
let mut shard_idx: usize = 0;
let mut tokens_in_shard: usize = 0;
let mut total_tokens: u64 = 0;
let mut total_docs: u64 = 0;
let mut eos_count: u64 = 0;
let mut writer = open_shard(output_dir, shard_idx)?;
let mut doc_iter_count: u64 = 0;
for triple in iter_corpus_texts(&files, corpus_format, content_field) {
let (file_display, locator, text) = triple?;
let ids = tokenizer.encode(&text).map_err(|e| {
CliError::ValidationFailed(format!("Encoding failed at {file_display} {locator}: {e}"))
})?;
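// "between" emits an EOS before every document except the first; "after"
// (below) emits one following every document.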
if eos_policy == "between" && doc_iter_count > 0 {
if let Some(eos) = eos_id {
writer
.write_all(&eos.to_le_bytes())
.map_err(|e| CliError::ValidationFailed(format!("Shard write failed: {e}")))?;
tokens_in_shard += 1;
total_tokens += 1;
eos_count += 1;
}
}
for id in &ids {
if (*id as usize) >= vocab_size {
return Err(CliError::ValidationFailed(format!(
"Token id {id} >= vocab_size {vocab_size} at {file_display} {locator} \
(INV-PRETOK-001 violation)"
)));
}
writer
.write_all(&id.to_le_bytes())
.map_err(|e| CliError::ValidationFailed(format!("Shard write failed: {e}")))?;
tokens_in_shard += 1;
total_tokens += 1;
}
if eos_policy == "after" {
if let Some(eos) = eos_id {
writer
.write_all(&eos.to_le_bytes())
.map_err(|e| CliError::ValidationFailed(format!("Shard write failed: {e}")))?;
tokens_in_shard += 1;
total_tokens += 1;
eos_count += 1;
}
}
doc_iter_count += 1;
total_docs += 1;
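// Shards rotate only at document boundaries, so a shard can exceed
// shard_tokens by up to one document's worth of ids.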
if tokens_in_shard >= shard_tokens {
writer
.flush()
.map_err(|e| CliError::ValidationFailed(format!("Shard flush failed: {e}")))?;
shard_idx += 1;
tokens_in_shard = 0;
writer = open_shard(output_dir, shard_idx)?;
}
}
writer
.flush()
.map_err(|e| CliError::ValidationFailed(format!("Shard flush failed: {e}")))?;
let shard_count = shard_idx + 1;
let elapsed = start.elapsed();
let manifest = serde_json::json!({
"schema": "pretokenize-bin-v1",
"tokenizer_dir": tokenizer_dir.display().to_string(),
"vocab_size": vocab_size,
"eos_policy": eos_policy,
"eos_token_id": eos_id,
"eos_token_count": eos_count,
"shard_count": shard_count,
"total_tokens": total_tokens,
"total_documents": total_docs,
"content_field": content_field,
"normalization": normalization,
"input_format": match corpus_format {
CorpusFormat::Jsonl => "jsonl",
CorpusFormat::Parquet => "parquet",
},
"input_files": files.iter().map(|p| p.display().to_string()).collect::<Vec<_>>(),
"elapsed_seconds": elapsed.as_secs_f64(),
});
let manifest_path = output_dir.join("manifest.json");
std::fs::write(
&manifest_path,
serde_json::to_string_pretty(&manifest)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?,
)
.map_err(|e| CliError::ValidationFailed(format!("Cannot write manifest: {e}")))?;
if json_output {
println!(
"{}",
serde_json::to_string_pretty(&manifest)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?
);
} else {
output::header("apr tokenize encode-corpus — Pretokenization Result");
output::kv(" Shards", format_number(shard_count));
output::kv(" Total tokens", format_number(total_tokens as usize));
output::kv(" Total documents", format_number(total_docs as usize));
output::kv(" Vocab size", format_number(vocab_size));
output::kv(" Elapsed", format!("{:.1}s", elapsed.as_secs_f64()));
output::kv(" Manifest", manifest_path.display().to_string());
}
Ok(())
}
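/// Opens a buffered writer for `shard-NNNNN.bin` under `output_dir`.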
#[cfg(feature = "training")]
fn open_shard(output_dir: &Path, shard_idx: usize) -> Result<std::io::BufWriter<std::fs::File>> {
let path = output_dir.join(format!("shard-{shard_idx:05}.bin"));
let file = std::fs::File::create(&path).map_err(|e| {
CliError::ValidationFailed(format!("Cannot create shard {}: {e}", path.display()))
})?;
Ok(std::io::BufWriter::new(file))
}
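/// Collects .jsonl files from a file path or a directory. The directory
/// scan is non-recursive; results are sorted for determinism.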
fn collect_jsonl_files(path: &Path) -> Result<Vec<std::path::PathBuf>> {
let meta = std::fs::metadata(path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot stat {}: {e}", path.display())))?;
if meta.is_file() {
if is_jsonl(path) {
return Ok(vec![path.to_path_buf()]);
}
return Err(CliError::ValidationFailed(format!(
"Corpus file {} is not a .jsonl file",
path.display()
)));
}
let mut out = Vec::new();
let entries = std::fs::read_dir(path).map_err(|e| {
CliError::ValidationFailed(format!("Cannot read directory {}: {e}", path.display()))
})?;
for entry in entries {
let entry =
entry.map_err(|e| CliError::ValidationFailed(format!("Directory entry error: {e}")))?;
let p = entry.path();
if p.is_file() && is_jsonl(&p) {
out.push(p);
}
}
out.sort();
Ok(out)
}
fn is_jsonl(path: &Path) -> bool {
path.extension().and_then(|e| e.to_str()) == Some("jsonl")
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CorpusFormat {
Jsonl,
Parquet,
}
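/// Like `collect_jsonl_files` but Parquet-aware: when a directory contains
/// any .parquet files, they take precedence over .jsonl.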
fn collect_corpus_files(path: &Path) -> Result<(Vec<std::path::PathBuf>, CorpusFormat)> {
let meta = std::fs::metadata(path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot stat {}: {e}", path.display())))?;
if meta.is_file() {
if super::tokenize_parquet::is_parquet(path) {
return Ok((vec![path.to_path_buf()], CorpusFormat::Parquet));
}
if is_jsonl(path) {
return Ok((vec![path.to_path_buf()], CorpusFormat::Jsonl));
}
return Err(CliError::ValidationFailed(format!(
"Corpus file {} is not a .jsonl or .parquet file",
path.display()
)));
}
let parquet = super::tokenize_parquet::collect_parquet_files(path).unwrap_or_default();
if !parquet.is_empty() {
return Ok((parquet, CorpusFormat::Parquet));
}
let jsonl = collect_jsonl_files(path)?;
if jsonl.is_empty() {
return Err(CliError::ValidationFailed(format!(
"No .jsonl or .parquet files found under {}",
path.display()
)));
}
Ok((jsonl, CorpusFormat::Jsonl))
}
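/// Streams `(file, locator, text)` triples so encode errors can cite their
/// source. JSONL lines missing `content_field` are skipped silently, and
/// each JSONL file is read into memory in one go.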
#[cfg(feature = "training")]
fn iter_corpus_texts<'a>(
files: &'a [std::path::PathBuf],
format: CorpusFormat,
content_field: &'a str,
) -> Box<dyn Iterator<Item = Result<(String, String, String)>> + 'a> {
match format {
CorpusFormat::Parquet => Box::new(files.iter().flat_map(move |file| {
let file_display = file.display().to_string();
match super::tokenize_parquet::iter_parquet_content(file, content_field) {
Ok(it) => {
let fd = file_display;
let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
Box::new(it.enumerate().map(move |(idx, r)| {
r.map(|t| (fd.clone(), format!("row {}", idx + 1), t))
}));
inner
}
Err(e) => {
let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
Box::new(std::iter::once(Err(e)));
inner
}
}
})),
CorpusFormat::Jsonl => Box::new(files.iter().flat_map(move |file| {
let file_display = file.display().to_string();
match std::fs::read_to_string(file) {
Ok(content) => {
let fd = file_display;
let triples: Vec<Result<(String, String, String)>> = content
.lines()
.enumerate()
.filter_map(|(idx, line)| {
let trimmed = line.trim();
if trimmed.is_empty() {
return None;
}
match serde_json::from_str::<serde_json::Value>(trimmed) {
Ok(v) => v.get(content_field).and_then(|x| x.as_str()).map(|s| {
Ok((fd.clone(), format!("line {}", idx + 1), s.to_string()))
}),
Err(e) => Some(Err(CliError::ValidationFailed(format!(
"Invalid JSON in {fd} line {}: {e}",
idx + 1
)))),
}
})
.collect();
let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
Box::new(triples.into_iter());
inner
}
Err(e) => {
let msg = format!("Cannot read {file_display}: {e}");
let inner: Box<dyn Iterator<Item = Result<(String, String, String)>>> =
Box::new(std::iter::once(Err(CliError::ValidationFailed(msg))));
inner
}
}
})),
}
}
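/// Appends every `content` string field in the file to `out`. Blank lines
/// are skipped; invalid JSON is a hard error.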
fn read_jsonl_content(path: &Path, out: &mut Vec<String>) -> Result<()> {
let content = std::fs::read_to_string(path)
.map_err(|e| CliError::ValidationFailed(format!("Cannot read {}: {e}", path.display())))?;
for (line_idx, line) in content.lines().enumerate() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let value: serde_json::Value = serde_json::from_str(trimmed).map_err(|e| {
CliError::ValidationFailed(format!(
"Invalid JSON in {} line {}: {e}",
path.display(),
line_idx + 1
))
})?;
if let Some(text) = value.get("content").and_then(|v| v.as_str()) {
out.push(text.to_string());
}
}
Ok(())
}
fn print_train_header(
corpus: &Path,
vocab_size: usize,
output_dir: &Path,
files: usize,
lines: usize,
) {
output::header("apr tokenize train — Training BPE Tokenizer");
println!();
output::kv(" Algorithm", "bpe");
output::kv(" Vocab size", format_number(vocab_size));
output::kv(" Corpus", corpus.display().to_string());
output::kv(" Files", format_number(files));
output::kv(" Lines", format_number(lines));
output::kv(" Output", output_dir.display().to_string());
println!();
}
fn print_train_result(result: &TokenizeTrainResult) {
output::section("Result");
println!(
" {} BPE tokenizer trained successfully",
"OK".green().bold()
);
output::kv(" Final vocab size", format_number(result.vocab_size));
output::kv(" Normalization", &result.normalization);
output::kv(
" Training time",
format!("{:.1}s", result.training_seconds),
);
output::kv(" vocab.json", format!("{}/vocab.json", result.output_dir));
output::kv(" merges.txt", format!("{}/merges.txt", result.output_dir));
println!();
}
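/// Converts a Hugging Face tokenizer.json (BPE models only) into the
/// GPT-2-style vocab.json + merges.txt pair aprender loads, plus a
/// provenance manifest. `added_tokens` are merged into the vocab only when
/// `include_added_tokens` is set.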
pub(crate) fn run_import_hf(
input: &Path,
output: &Path,
include_added_tokens: bool,
json_output: bool,
) -> Result<()> {
if !input.exists() {
return Err(CliError::FileNotFound(input.to_path_buf()));
}
let raw = std::fs::read_to_string(input).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot read {}: {e}",
input.display()
))
})?;
let parsed: serde_json::Value = serde_json::from_str(&raw).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} is not valid JSON: {e}",
input.display()
))
})?;
let model_type = parsed
.get("model")
.and_then(|m| m.get("type"))
.and_then(serde_json::Value::as_str)
.ok_or_else(|| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} has no model.type field; \
not a recognizable HF tokenizer.json",
input.display()
))
})?;
if model_type != "BPE" {
return Err(CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] FALSIFY-TOK-IMPORT-HF-005: \
model.type = '{model_type}' but only 'BPE' is supported. \
{} cannot be imported with this subcommand. \
Aprender's BPE loader requires GPT-2-style vocab.json + merges.txt; \
Unigram and WordPiece use different state machines and need separate paths.",
input.display()
)));
}
let vocab_obj = parsed
.get("model")
.and_then(|m| m.get("vocab"))
.and_then(serde_json::Value::as_object)
.ok_or_else(|| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} has no model.vocab object",
input.display()
))
})?;
let bpe_vocab_count = vocab_obj.len();
let merges_arr = parsed
.get("model")
.and_then(|m| m.get("merges"))
.and_then(serde_json::Value::as_array)
.ok_or_else(|| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] {} has no model.merges array",
input.display()
))
})?;
let merges_count = merges_arr.len();
let added_tokens_arr = parsed
.get("added_tokens")
.and_then(serde_json::Value::as_array)
.cloned()
.unwrap_or_default();
let added_tokens_count = added_tokens_arr.len();
let mut effective_vocab: serde_json::Map<String, serde_json::Value> = vocab_obj.clone();
if include_added_tokens {
for tok in &added_tokens_arr {
if let (Some(content), Some(id)) = (
tok.get("content").and_then(serde_json::Value::as_str),
tok.get("id").and_then(serde_json::Value::as_u64),
) {
effective_vocab.insert(
content.to_string(),
serde_json::Value::Number(serde_json::Number::from(id)),
);
}
}
}
std::fs::create_dir_all(output).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot create output dir {}: {e}",
output.display()
))
})?;
let vocab_path = output.join("vocab.json");
let vocab_json = serde_json::to_string_pretty(&effective_vocab)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?;
std::fs::write(&vocab_path, vocab_json).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot write {}: {e}",
vocab_path.display()
))
})?;
let merges_path = output.join("merges.txt");
let mut merges_body = String::from("#version: 0.2\n");
for (idx, m) in merges_arr.iter().enumerate() {
let line = match m {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Array(parts) if parts.len() == 2 => {
let a = parts[0].as_str().unwrap_or("");
let b = parts[1].as_str().unwrap_or("");
format!("{a} {b}")
}
_ => {
return Err(CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] merges[{idx}] is neither a string \
nor a [a, b] tuple in {}",
input.display()
)));
}
};
merges_body.push_str(&line);
merges_body.push('\n');
}
std::fs::write(&merges_path, merges_body).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot write {}: {e}",
merges_path.display()
))
})?;
let manifest = serde_json::json!({
"schema": "apr-cli-tokenize-import-hf-v1",
"source": input.display().to_string(),
"source_sha256": sha256_file(input)?,
"model_type": "BPE",
"bpe_vocab_count": bpe_vocab_count,
"merges_count": merges_count,
"added_tokens_count": added_tokens_count,
"include_added_tokens": include_added_tokens,
"effective_vocab_count": effective_vocab.len(),
"extraction_timestamp_utc": chrono::Utc::now().to_rfc3339(),
});
let manifest_path = output.join("manifest.json");
std::fs::write(
&manifest_path,
serde_json::to_string_pretty(&manifest)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?,
)
.map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot write {}: {e}",
manifest_path.display()
))
})?;
if json_output {
println!(
"{}",
serde_json::to_string_pretty(&manifest)
.map_err(|e| CliError::InvalidFormat(e.to_string()))?
);
} else {
output::header("apr tokenize import-hf — HF BPE → aprender extraction");
println!();
output::kv(" Source", input.display().to_string());
output::kv(" BPE vocab", format_number(bpe_vocab_count));
output::kv(" Merges", format_number(merges_count));
output::kv(" Added tokens", format_number(added_tokens_count));
output::kv(
" Effective vocab",
format_number(effective_vocab.len()),
);
output::kv(" Output dir", output.display().to_string());
println!();
println!("{}", "Wrote:".green().bold());
output::kv(" vocab.json", format!("{}/vocab.json", output.display()));
output::kv(" merges.txt", format!("{}/merges.txt", output.display()));
output::kv(
" manifest.json",
format!("{}/manifest.json", output.display()),
);
}
Ok(())
}
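/// SHA-256 of the file contents, hex-encoded, for manifest provenance.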
fn sha256_file(path: &Path) -> Result<String> {
use sha2::{Digest, Sha256};
let bytes = std::fs::read(path).map_err(|e| {
CliError::ValidationFailed(format!(
"[apr-cli-tokenize-import-hf-v1] cannot read {} for sha256: {e}",
path.display()
))
})?;
let mut h = Sha256::new();
h.update(&bytes);
Ok(format!("{:x}", h.finalize()))
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn write_corpus_file(dir: &Path, name: &str, lines: &[&str]) -> std::path::PathBuf {
let p = dir.join(name);
let body = lines.join("\n");
std::fs::write(&p, body).expect("write corpus");
p
}
#[test]
fn run_train_happy_path_jsonl_file() {
let tmp = TempDir::new().expect("tempdir");
let corpus = write_corpus_file(
tmp.path(),
"corpus.jsonl",
&[
r#"{"content": "hello world hello"}"#,
r#"{"content": "hello there world"}"#,
],
);
let out = tmp.path().join("tok");
run_train(&corpus, 300, 1, &out, "nfc", true).expect("train");
assert!(out.join("vocab.json").exists());
assert!(out.join("merges.txt").exists());
let vocab = std::fs::read_to_string(out.join("vocab.json")).expect("read vocab");
assert!(
vocab.contains("\"<unk>\""),
"vocab.json missing <unk>: {}",
vocab
);
let merges = std::fs::read_to_string(out.join("merges.txt")).expect("read merges");
assert!(
merges.starts_with("#version: 0.2"),
"merges.txt missing header: {}",
merges
);
}
#[test]
fn run_train_directory_corpus_walks_jsonl() {
let tmp = TempDir::new().expect("tempdir");
let corpus_dir = tmp.path().join("corpus");
std::fs::create_dir_all(&corpus_dir).expect("mkdir");
write_corpus_file(
&corpus_dir,
"a.jsonl",
&[r#"{"content": "alpha beta alpha"}"#],
);
write_corpus_file(
&corpus_dir,
"b.jsonl",
&[r#"{"content": "gamma delta gamma"}"#],
);
std::fs::write(corpus_dir.join("notes.txt"), "ignore me").expect("write ignored");
let out = tmp.path().join("tok");
run_train(&corpus_dir, 300, 2, &out, "nfc", true).expect("train");
assert!(out.join("vocab.json").exists());
assert!(out.join("merges.txt").exists());
}
#[cfg(feature = "training")]
#[test]
fn run_train_honors_min_frequency_pruning() {
let tmp = TempDir::new().expect("tempdir");
let lines: Vec<String> = std::iter::repeat_n(r#"{"content": "abc"}"#.to_string(), 5)
.chain(std::iter::once(r#"{"content": "xyz"}"#.to_string()))
.collect();
let body = lines.join("\n");
let corpus = tmp.path().join("corpus.jsonl");
std::fs::write(&corpus, body).expect("write corpus");
let out = tmp.path().join("tok");
run_train(&corpus, 300, 2, &out, "nfc", true).expect("train");
let merges = std::fs::read_to_string(out.join("merges.txt")).expect("read merges.txt");
assert!(
merges.contains("61 62") || merges.contains("62 63"),
"Expected a merge from the frequent 'abc' pair, got: {}",
merges
);
assert!(
!merges.contains("78 79"),
"min_frequency=2 failed to prune singleton 'xy' pair: {}",
merges
);
assert!(
!merges.contains("79 7a"),
"min_frequency=2 failed to prune singleton 'yz' pair: {}",
merges
);
let vocab = std::fs::read_to_string(out.join("vocab.json")).expect("read vocab");
assert!(
!vocab.contains("\"78797a\""),
"min_frequency=2 failed to prune merged 'xyz' token from vocab: {}",
vocab
);
}
#[test]
fn run_train_rejects_unknown_normalization() {
let tmp = TempDir::new().expect("tempdir");
let corpus = write_corpus_file(tmp.path(), "corpus.jsonl", &[r#"{"content": "x y"}"#]);
let err = run_train(&corpus, 300, 1, tmp.path(), "nfkd", true)
.expect_err("should reject unsupported normalization");
match err {
CliError::ValidationFailed(msg) => assert!(msg.contains("nfkd")),
other => panic!("unexpected error: {other:?}"),
}
}
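/// Builds a minimal HF-style tokenizer.json: an `n_vocab`-entry BPE vocab,
/// `n_merges` string merges, and one added special token.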
fn write_minimal_bpe_tokenizer_json(dir: &Path, n_vocab: usize, n_merges: usize) -> PathBuf {
let mut vocab = serde_json::Map::new();
for i in 0..n_vocab {
vocab.insert(format!("tok{i}"), serde_json::Value::Number(i.into()));
}
let merges: Vec<serde_json::Value> = (0..n_merges)
.map(|i| serde_json::Value::String(format!("a{i} b{i}")))
.collect();
let added_tokens = vec![serde_json::json!({
"id": n_vocab,
"content": "<|endoftext|>",
"special": true,
})];
let tok = serde_json::json!({
"version": "1.0",
"added_tokens": added_tokens,
"model": {
"type": "BPE",
"vocab": vocab,
"merges": merges,
},
});
let path = dir.join("tokenizer.json");
std::fs::write(
&path,
serde_json::to_string_pretty(&tok).expect("serialize tok"),
)
.expect("write tok");
path
}
#[test]
fn import_hf_qwen_bpe_writes_vocab_and_merges() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 1000, 800);
let output = tmp.path().join("extracted");
run_import_hf(&input, &output, false, true).expect("import-hf should succeed on BPE input");
let vocab_path = output.join("vocab.json");
assert!(vocab_path.exists(), "vocab.json must exist");
let vocab_str = std::fs::read_to_string(&vocab_path).expect("read vocab.json");
let vocab_obj: serde_json::Map<String, serde_json::Value> =
serde_json::from_str(&vocab_str).expect("parse vocab.json");
assert_eq!(
vocab_obj.len(),
1000,
"FALSIFY-TOK-IMPORT-HF-002: vocab.json must have 1000 entries (default mode), got {}",
vocab_obj.len()
);
let merges_path = output.join("merges.txt");
assert!(merges_path.exists(), "merges.txt must exist");
let merges_str = std::fs::read_to_string(&merges_path).expect("read merges.txt");
let merge_lines = merges_str.lines().filter(|l| !l.starts_with('#')).count();
assert_eq!(
merge_lines, 800,
"FALSIFY-TOK-IMPORT-HF-002: merges.txt must have 800 merge lines, got {merge_lines}"
);
}
#[test]
fn import_hf_vocab_count_matches_input() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 12345, 100);
let output = tmp.path().join("extracted");
run_import_hf(&input, &output, false, true).expect("import-hf");
let vocab_obj: serde_json::Map<String, serde_json::Value> =
serde_json::from_str(&std::fs::read_to_string(output.join("vocab.json")).unwrap())
.unwrap();
assert_eq!(
vocab_obj.len(),
12345,
"FALSIFY-TOK-IMPORT-HF-003: vocab count must match input model.vocab"
);
}
#[test]
fn import_hf_merges_format_and_order() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 10, 5);
let output = tmp.path().join("extracted");
run_import_hf(&input, &output, false, true).expect("import-hf");
let body = std::fs::read_to_string(output.join("merges.txt")).expect("read merges");
let lines: Vec<&str> = body.lines().filter(|l| !l.starts_with('#')).collect();
assert_eq!(lines.len(), 5);
for (i, line) in lines.iter().enumerate() {
assert_eq!(
line.trim(),
format!("a{i} b{i}"),
"FALSIFY-TOK-IMPORT-HF-004: merge[{i}] order or format mismatch"
);
}
}
#[test]
fn import_hf_unigram_input_errors() {
let tmp = TempDir::new().expect("tempdir");
let input = tmp.path().join("tokenizer.json");
let unigram = serde_json::json!({
"model": { "type": "Unigram", "vocab": [] },
});
std::fs::write(&input, serde_json::to_string_pretty(&unigram).unwrap()).unwrap();
let output = tmp.path().join("extracted");
let err = run_import_hf(&input, &output, false, true)
.expect_err("FALSIFY-TOK-IMPORT-HF-005: Unigram input MUST fail-fast");
match err {
CliError::ValidationFailed(msg) => {
assert!(
msg.contains("FALSIFY-TOK-IMPORT-HF-005"),
"error must cite falsifier id (auditability): {msg}"
);
assert!(
msg.contains("Unigram"),
"error must name the actual model type: {msg}"
);
}
other => panic!("unexpected error variant: {other:?}"),
}
}
#[test]
fn import_hf_include_added_tokens_appends_specials() {
let tmp = TempDir::new().expect("tempdir");
let input = write_minimal_bpe_tokenizer_json(tmp.path(), 100, 50);
let out_default = tmp.path().join("default");
run_import_hf(&input, &out_default, false, true).expect("default import");
let v_default: serde_json::Map<String, serde_json::Value> = serde_json::from_str(
&std::fs::read_to_string(out_default.join("vocab.json")).unwrap(),
)
.unwrap();
assert_eq!(v_default.len(), 100);
assert!(
!v_default.contains_key("<|endoftext|>"),
"default mode must NOT include added_tokens"
);
let out_full = tmp.path().join("full");
run_import_hf(&input, &out_full, true, true).expect("full import");
let v_full: serde_json::Map<String, serde_json::Value> = serde_json::from_str(
&std::fs::read_to_string(out_full.join("vocab.json")).unwrap(),
)
.unwrap();
assert_eq!(v_full.len(), 101);
assert!(
v_full.contains_key("<|endoftext|>"),
"include-added-tokens mode must include the special"
);
}
}