use anyhow::{Context, Result, bail};
use clap::Parser;
use multiscreen_rs::prelude::*;
use rayon::prelude::*;
use sentencepiece_rs::SentencePieceProcessor;
use serde::{Deserialize, Serialize};
use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::{Hash, Hasher};
use std::io::{BufRead, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Instant;
/// Wrapper around a [`SentencePieceProcessor`] whose methods expose token ids as `u32`.
struct SpTokenizer {
// The loaded SentencePiece processor that every method on this type delegates to.
proc: SentencePieceProcessor,
}
impl SpTokenizer {
    /// Opens the SentencePiece model stored at `path`.
    ///
    /// # Errors
    /// Fails with the offending path attached as context when the model cannot be opened.
    fn load(path: &Path) -> Result<Self> {
        let proc = SentencePieceProcessor::open(path)
            .with_context(|| format!("failed to load {}", path.display()))?;
        Ok(Self { proc })
    }

    /// Encodes `text` into token ids.
    ///
    /// An encoding failure is not reported: the processor's default (empty) result is used.
    fn encode(&self, text: &str) -> Vec<u32> {
        let raw_ids = self.proc.encode_to_ids(text).unwrap_or_default();
        let mut ids = Vec::new();
        for id in raw_ids {
            ids.push(id as u32);
        }
        ids
    }

    /// Decodes token ids back into text; a decoding failure yields an empty string.
    fn decode(&self, ids: &[u32]) -> String {
        let widened = ids.iter().map(|&id| id as usize).collect::<Vec<usize>>();
        self.proc.decode_ids(&widened).unwrap_or_default()
    }

    /// Vocabulary size reported by the underlying model.
    fn vocab_size(&self) -> usize {
        let model = self.proc.model();
        model.vocab_size()
    }

    /// End-of-sequence token id, if the processor reports one.
    fn eos_id(&self) -> Option<u32> {
        let id = self.proc.eos_id()?;
        Some(id as u32)
    }
}
// Command-line options for the training example. The `///` comments on the fields are
// picked up by clap's derive and shown as per-flag help text in `--help`.
#[derive(Parser)]
#[command(
    name = "train_with_tokenizer",
    about = "Train a Multiscreen LM with SentencePiece and produce a full report"
)]
struct Args {
    /// Directory holding the training data (.csv, .txt, .jsonl) and tokenizer.model
    #[arg(long, default_value = "examples/data")]
    train_dir: PathBuf,
    /// Output directory for checkpoints, the token cache, the loss CSV and the reports
    #[arg(long, default_value = "runs/my-model")]
    run_dir: PathBuf,
    /// Model parameter budget: 1m, 5m, 10m, 50m, or 100m
    #[arg(long, default_value = "10m")]
    budget: String,
    /// Number of training steps
    #[arg(long, default_value_t = 10_000)]
    steps: usize,
    /// Training batch size
    #[arg(long, default_value_t = 4)]
    batch_size: usize,
    /// Sequence length used for the model config, training and evaluation
    #[arg(long, default_value_t = 128)]
    seq_len: usize,
    /// Learning rate
    #[arg(long, default_value_t = 2e-4)]
    lr: f64,
    /// Fraction of sequences held out for validation (a test set of the same size is held out too)
    #[arg(long, default_value_t = 0.1)]
    val_split: f64,
    /// Maximum number of new tokens generated for the inference latency measurement
    #[arg(long, default_value_t = 20)]
    latency_tokens: usize,
    /// Print a training progress line every N steps
    #[arg(long, default_value_t = 100)]
    log_interval: usize,
    /// Checkpoint interval in steps, forwarded to the trainer as-is
    #[arg(long, default_value_t = 0)]
    checkpoint_interval: usize,
    /// Maximum number of samples to load (0 = no limit)
    #[arg(long, default_value_t = 0)]
    max_samples: usize,
    /// Skip training and evaluate the existing checkpoint in the run directory
    #[arg(long, default_value_t = false)]
    eval_only: bool,
}
/// Splits `text` into trimmed, non-empty sentences.
///
/// A sentence ends at `.`, `!` or `?`. A run of consecutive terminators (`...`, `?!`)
/// stays attached to the sentence it ends instead of producing punctuation-only
/// fragments, and one directly following closing quote (`"`, `'` or `”`) is kept with
/// the sentence as well. Text after the last terminator becomes a final sentence.
fn split_sentences(text: &str) -> Vec<String> {
    let is_terminator = |c: char| matches!(c, '.' | '!' | '?');
    let is_closing_quote = |c: char| matches!(c, '"' | '\'' | '\u{201d}');

    let mut sentences = Vec::new();
    // Byte offset where the sentence currently being scanned starts. All offsets come
    // from `char_indices`, so every slice below lands on a char boundary.
    let mut start = 0;
    let mut chars = text.char_indices().peekable();

    while let Some((idx, c)) = chars.next() {
        if !is_terminator(c) {
            continue;
        }
        let mut end = idx + c.len_utf8();
        // Swallow the rest of a terminator run so "Wait..." stays one sentence.
        while let Some(&(next_idx, next)) = chars.peek() {
            if !is_terminator(next) {
                break;
            }
            end = next_idx + next.len_utf8();
            chars.next();
        }
        // Keep a single closing quote with the sentence it closes.
        if let Some(&(next_idx, next)) = chars.peek() {
            if is_closing_quote(next) {
                end = next_idx + next.len_utf8();
                chars.next();
            }
        }
        let sentence = text[start..end].trim();
        if !sentence.is_empty() {
            sentences.push(sentence.to_owned());
        }
        start = end;
    }

    let remaining = text[start..].trim();
    if !remaining.is_empty() {
        sentences.push(remaining.to_owned());
    }
    sentences
}
fn load_samples(dir: &Path, max_samples: usize) -> Result<Vec<(String, String)>> {
let mut samples = Vec::new();
let is_maxed = |len: usize| max_samples > 0 && len >= max_samples;
for entry in fs::read_dir(dir).with_context(|| format!("cannot read {}", dir.display()))? {
if is_maxed(samples.len()) {
break;
}
let entry = entry?;
let path = entry.path();
let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
match ext {
"csv" => {
let text = fs::read_to_string(&path)
.with_context(|| format!("cannot read {}", path.display()))?;
let mut reader = csv::ReaderBuilder::new()
.flexible(true)
.from_reader(text.as_bytes());
let headers = reader.headers()?.clone();
let has_prompt = headers.iter().any(|h| h.trim() == "prompt");
let has_response = headers.iter().any(|h| h.trim() == "response");
if !has_prompt || !has_response {
continue;
}
let prompt_idx = headers.iter().position(|h| h.trim() == "prompt").unwrap();
let response_idx = headers.iter().position(|h| h.trim() == "response").unwrap();
for result in reader.records() {
if is_maxed(samples.len()) {
break;
}
let record = match result {
Ok(r) => r,
Err(_) => continue,
};
let prompt = record.get(prompt_idx).unwrap_or("").trim().to_owned();
let response = record.get(response_idx).unwrap_or("").trim().to_owned();
if prompt.is_empty() || response.is_empty() {
continue;
}
samples.push((prompt, response));
}
}
"txt" => {
let file = fs::File::open(&path)
.with_context(|| format!("cannot open {}", path.display()))?;
let reader = std::io::BufReader::new(file);
let mut current_lines: Vec<String> = Vec::new();
let flush = |lines: &mut Vec<String>, out: &mut Vec<(String, String)>| {
if lines.is_empty() {
return;
}
let story = lines.join(" ");
lines.clear();
let story = story.trim().to_owned();
if story.is_empty() {
return;
}
let sentences = split_sentences(&story);
if sentences.len() < 2 {
out.push((String::new(), story));
return;
}
let split_point = (sentences.len() as f64 * 0.6).ceil() as usize;
let split_point = split_point.max(1).min(sentences.len() - 1);
let input = sentences[..split_point].join(" ");
let output = sentences[split_point..].join(" ");
if output.is_empty() {
out.push((String::new(), story));
} else {
out.push((input, output));
}
};
let mut line_count: u64 = 0;
for line_result in reader.lines() {
if is_maxed(samples.len()) {
break;
}
let line = match line_result {
Ok(l) => l,
Err(_) => continue,
};
line_count += 1;
if line_count.is_multiple_of(1_000_000) {
eprintln!(
" streaming {}: {line_count} lines, {} stories so far",
path.file_name().unwrap_or_default().to_string_lossy(),
samples.len()
);
}
if line.trim().is_empty() {
flush(&mut current_lines, &mut samples);
} else {
current_lines.push(line.trim().to_owned());
}
}
flush(&mut current_lines, &mut samples); eprintln!(
" loaded {} stories from {}",
samples.len(),
path.file_name().unwrap_or_default().to_string_lossy()
);
}
"jsonl" => {
let text = fs::read_to_string(&path)
.with_context(|| format!("cannot read {}", path.display()))?;
for line in text.lines() {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
if let Ok(val) = serde_json::from_str::<serde_json::Value>(trimmed) {
if let Some(s) = val.get("text").and_then(|v| v.as_str()) {
samples.push((String::new(), s.to_owned()));
}
else if let Some(messages) =
val.get("messages").and_then(|v| v.as_array())
{
let mut prompt_parts: Vec<String> = Vec::new();
let mut response_parts: Vec<String> = Vec::new();
for msg in messages {
let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("");
let content =
msg.get("content").and_then(|v| v.as_str()).unwrap_or("");
if content.is_empty() {
continue;
}
match role {
"assistant" => response_parts.push(content.to_owned()),
_ => prompt_parts.push(content.to_owned()),
}
}
if !response_parts.is_empty() {
let prompt = prompt_parts.join("\n");
let response = response_parts.join("\n");
samples.push((prompt, response));
}
}
}
}
}
_ => {}
}
}
Ok(samples)
}
/// Maps a budget label such as `"10m"` (matched case-insensitively) to its
/// [`MultiscreenParameterBudget`] variant.
///
/// # Errors
/// Fails for any label other than `1m`, `5m`, `10m`, `50m` or `100m`.
fn parse_budget(s: &str) -> Result<MultiscreenParameterBudget> {
    let label = s.to_lowercase();
    let budget = match label.as_str() {
        "1m" => MultiscreenParameterBudget::Params1M,
        "5m" => MultiscreenParameterBudget::Params5M,
        "10m" => MultiscreenParameterBudget::Params10M,
        "50m" => MultiscreenParameterBudget::Params50M,
        "100m" => MultiscreenParameterBudget::Params100M,
        other => bail!("unknown budget '{other}'; use 1m, 5m, 10m, 50m, or 100m"),
    };
    Ok(budget)
}
/// Run metadata that `main` writes to `latest.json` in the checkpoint directory after
/// training and reads back when `--eval-only` is given.
#[derive(Serialize, Deserialize, Clone)]
struct RunMeta {
// Step count taken from the trainer's report.
step: usize,
// `main` stores the run's best training loss here (not the final loss).
loss: f64,
// Parameter count taken from the trainer's report.
params: usize,
// Model configuration the run was built with.
model_config: MultiscreenModelConfig,
}
/// Full run summary that `main` serializes to `report.json` and renders to `report.md`.
#[derive(Serialize, Deserialize, Clone)]
struct TrainReport {
// Budget label exactly as given on the command line.
budget: String,
// Number of model parameters.
parameter_count: usize,
// Sequence length from the command line.
seq_len: usize,
// Batch size from the command line.
batch_size: usize,
// Learning rate from the command line.
learning_rate: f64,
// Number of training steps recorded for the run.
total_steps: usize,
// Wall-clock training time in seconds (0.0 when training was skipped).
train_duration_secs: f64,
// Training throughput in steps per second (0.0 when training was skipped).
steps_per_sec: f64,
// Training loss at the end of the run.
final_train_loss: f64,
// Lowest training loss of the run.
best_train_loss: f64,
// Validation metrics; `None` when the validation split is empty.
val: Option<EvalMetrics>,
// Test metrics; `None` when the test split is empty.
test: Option<EvalMetrics>,
// Inference latency measurement.
inference: Option<InferenceMetrics>,
// Number of sequences in the training split.
train_samples: usize,
// Number of sequences in the validation split.
val_samples: usize,
// Number of sequences in the test split.
test_samples: usize,
// Total token count over all deduplicated sequences.
total_tokens: usize,
// Label of the device the run used.
device: String,
}
/// Evaluation results for one data split (validation or test).
#[derive(Serialize, Deserialize, Clone)]
struct EvalMetrics {
// Loss reported by the evaluation.
loss: f64,
// Perplexity reported by the evaluation.
perplexity: f64,
// Accuracy as a fraction; the reports multiply it by 100 to show a percentage.
accuracy: f64,
// Number of tokens the evaluation covered.
tokens: usize,
}
/// Latency measurement taken from a single generation call at the end of the run.
#[derive(Serialize, Deserialize, Clone)]
struct InferenceMetrics {
// Total generation time in milliseconds divided by the number of new tokens (0.0 when none).
avg_ms_per_token: f64,
// Number of tokens generated beyond the prompt.
tokens_generated: usize,
// Wall-clock duration of the generation call in seconds.
total_secs: f64,
}
/// On-disk JSON cache of tokenized samples.
#[derive(Serialize, Deserialize)]
struct TokenCache {
// Key of the inputs the cache was built from; a cache whose key differs is ignored on load.
cache_key: String,
// Tokenized samples as (prompt token ids, response token ids).
pairs: Vec<(Vec<u32>, Vec<u32>)>,
}
/// Derives a 16-hex-digit key identifying the inputs of the token cache.
///
/// The key covers the name and byte size of every entry in `train_dir` (sorted by name),
/// `max_samples`, and the raw bytes of the tokenizer file. An unreadable directory or
/// tokenizer simply contributes nothing to the hash.
fn compute_cache_key(train_dir: &Path, max_samples: usize, tokenizer_path: &Path) -> String {
    let mut state = DefaultHasher::new();

    if let Ok(dir_iter) = fs::read_dir(train_dir) {
        let mut listing: Vec<(String, u64)> = Vec::new();
        for entry in dir_iter.filter_map(Result::ok) {
            let file_name = entry.file_name().to_string_lossy().to_string();
            // Entries whose metadata cannot be read count as zero bytes.
            let byte_len = match entry.metadata() {
                Ok(meta) => meta.len(),
                Err(_) => 0,
            };
            listing.push((file_name, byte_len));
        }
        // Directory iteration order is unspecified; sort so the key is stable.
        listing.sort_by(|left, right| left.0.cmp(&right.0));
        for (file_name, byte_len) in &listing {
            file_name.hash(&mut state);
            byte_len.hash(&mut state);
        }
    }

    max_samples.hash(&mut state);

    if let Ok(model_bytes) = fs::read(tokenizer_path) {
        model_bytes.hash(&mut state);
    }

    let digest = state.finish();
    format!("{:016x}", digest)
}
fn try_load_token_cache(path: &Path, expected_key: &str) -> Option<Vec<(Vec<u32>, Vec<u32>)>> {
let data = fs::read_to_string(path).ok()?;
let cache: TokenCache = serde_json::from_str(&data).ok()?;
if cache.cache_key != expected_key {
return None;
}
Some(cache.pairs)
}
fn save_token_cache(path: &Path, cache_key: &str, pairs: &[(Vec<u32>, Vec<u32>)]) -> Result<()> {
let cache = TokenCache {
cache_key: cache_key.to_owned(),
pairs: pairs.to_vec(),
};
let json = serde_json::to_string(&cache)?;
fs::write(path, json)?;
Ok(())
}
fn main() -> Result<()> {
let args = Args::parse();
let tokenizer_path = args.train_dir.join("tokenizer.model");
if !tokenizer_path.exists() {
bail!("tokenizer.model not found in {}", args.train_dir.display());
}
let ckpt_dir = args.run_dir.join("checkpoints");
fs::create_dir_all(&ckpt_dir)
.with_context(|| format!("cannot create {}", ckpt_dir.display()))?;
let sp = SpTokenizer::load(&tokenizer_path)?;
let vocab_size = sp.vocab_size();
let eos_id = sp.eos_id();
println!("tokenizer: vocab_size={vocab_size} eos={eos_id:?}");
let samples = load_samples(&args.train_dir, args.max_samples)?;
if samples.is_empty() {
bail!("no training samples found in {}", args.train_dir.display());
}
println!("loaded {} samples", samples.len());
let has_chat_data = samples.iter().any(|(prompt, _)| !prompt.is_empty());
if has_chat_data {
println!("detected chat-format data — using loss masking for prompt tokens");
}
let cache_path = args.run_dir.join("token_cache.json");
let cache_key = compute_cache_key(&args.train_dir, args.max_samples, &tokenizer_path);
let mut chat_pairs: Vec<(Vec<u32>, Vec<u32>)> =
if let Some(pairs) = try_load_token_cache(&cache_path, &cache_key) {
println!(
"loaded {} tokenized pairs from cache ({})",
pairs.len(),
cache_path.display()
);
pairs
} else {
eprintln!("tokenizing {} samples (parallel)...", samples.len());
let tokenize_start = Instant::now();
let newline_ids = sp.encode("\n");
let progress = AtomicUsize::new(0);
let total = samples.len();
let pairs: Vec<(Vec<u32>, Vec<u32>)> = samples
.par_iter()
.enumerate()
.filter_map(|(_, (prompt, response))| {
let done = progress.fetch_add(1, Ordering::Relaxed) + 1;
if done.is_multiple_of(1000) {
eprintln!(
" tokenized {}/{} ({:.1}s)",
done,
total,
tokenize_start.elapsed().as_secs_f64()
);
}
let prompt_ids = if prompt.is_empty() {
Vec::new()
} else {
let mut ids = sp.encode(prompt);
ids.extend_from_slice(&newline_ids);
ids
};
let mut response_ids = sp.encode(response);
if let Some(eos) = eos_id {
response_ids.push(eos);
}
if prompt_ids.len() + response_ids.len() >= 2 {
Some((prompt_ids, response_ids))
} else {
None
}
})
.collect();
eprintln!(
"tokenized {} pairs in {:.1}s",
pairs.len(),
tokenize_start.elapsed().as_secs_f64()
);
eprintln!("saving token cache...");
if let Err(e) = save_token_cache(&cache_path, &cache_key, &pairs) {
eprintln!("warning: failed to save token cache: {e}");
} else {
println!("token cache saved: {}", cache_path.display());
}
eprintln!("cache save done.");
pairs
};
if chat_pairs.is_empty() {
bail!("all samples tokenized to <2 tokens — cannot train");
}
eprintln!("deduplicating {} pairs...", chat_pairs.len());
let dedup_start = Instant::now();
chat_pairs.sort();
chat_pairs.dedup();
eprintln!(
"dedup done: {} pairs in {:.1}s",
chat_pairs.len(),
dedup_start.elapsed().as_secs_f64()
);
let sequences: Vec<Vec<u32>> = chat_pairs
.iter()
.map(|(p, r)| {
let mut seq = p.clone();
seq.extend_from_slice(r);
seq
})
.collect();
let total_tokens: usize = sequences.iter().map(|s| s.len()).sum();
println!(
"{} token sequences ({} tokens total, deduped)",
sequences.len(),
total_tokens
);
let n = chat_pairs.len();
let val_count = ((n as f64 * args.val_split).ceil() as usize)
.max(1)
.min(n / 2);
let test_count = val_count.min(n - val_count);
let train_count = n.saturating_sub(val_count + test_count);
let (train_pairs, rest) = chat_pairs.split_at(train_count);
let (val_pairs, test_pairs) = rest.split_at(rest.len().min(val_count));
let val_seqs: Vec<Vec<u32>> = val_pairs
.iter()
.map(|(p, r)| {
let mut seq = p.clone();
seq.extend_from_slice(r);
seq
})
.collect();
let test_seqs: Vec<Vec<u32>> = test_pairs
.iter()
.map(|(p, r)| {
let mut seq = p.clone();
seq.extend_from_slice(r);
seq
})
.collect();
println!(
"split: {} train, {} val, {} test sequences",
train_pairs.len(),
val_seqs.len(),
test_seqs.len()
);
eprintln!("building model config...");
let budget = parse_budget(&args.budget)?;
let config = MultiscreenModelConfig::for_parameter_budget(budget, vocab_size, args.seq_len);
let param_count = config.estimated_parameter_count();
println!("model: {} params, budget={}", param_count, args.budget);
let config_json = serde_json::to_string_pretty(&config)?;
fs::write(ckpt_dir.join("config.json"), &config_json)?;
fs::copy(&tokenizer_path, args.run_dir.join("tokenizer.model"))
.with_context(|| "failed to copy tokenizer to run dir")?;
eprintln!("initializing device...");
let device_start = Instant::now();
let device = auto_device()?;
let device_name = device_label(&device);
println!(
"device: {device_name} (took {:.1}s)",
device_start.elapsed().as_secs_f64()
);
let final_ckpt = ckpt_dir.join("latest.mpk");
let loss_csv_path = args.run_dir.join("loss.csv");
let (train_steps, final_loss, train_params, train_secs, steps_per_sec, best_loss) = if args
.eval_only
{
anyhow::ensure!(
final_ckpt.exists(),
"--eval-only requires an existing checkpoint at {}",
final_ckpt.display()
);
let meta_path = ckpt_dir.join("latest.json");
let meta: RunMeta = if meta_path.exists() {
let json = fs::read_to_string(&meta_path)
.with_context(|| format!("cannot read {}", meta_path.display()))?;
serde_json::from_str(&json)
.with_context(|| format!("cannot parse {}", meta_path.display()))?
} else {
anyhow::bail!(
"--eval-only requires {} from a previous training run",
meta_path.display()
);
};
println!("skipping training, loading checkpoint from previous run");
println!(
" step={} loss={:.6} params={}",
meta.step, meta.loss, meta.params
);
(
meta.step,
meta.loss as f32,
meta.params,
0.0,
0.0,
meta.loss,
)
} else {
let mut loss_csv = fs::File::create(&loss_csv_path)
.with_context(|| format!("cannot create loss CSV at {}", loss_csv_path.display()))?;
writeln!(loss_csv, "step,loss")?;
eprintln!("building trainer (model init on GPU)...");
let trainer_start = Instant::now();
let mut trainer = Trainer::builder()
.vocab_size(vocab_size)
.budget(budget)
.device({
#[cfg(feature = "cuda")]
{
device.clone()
}
#[cfg(not(feature = "cuda"))]
{
device
}
})
.batch_size(args.batch_size)
.seq_len(args.seq_len)
.steps(args.steps)
.learning_rate(args.lr)
.checkpoint_dir(ckpt_dir.to_string_lossy().into_owned())
.checkpoint_interval(args.checkpoint_interval)
.build()?;
eprintln!(
"trainer built in {:.1}s",
trainer_start.elapsed().as_secs_f64()
);
let log_interval = args.log_interval;
let mut bl = f64::MAX;
let _loss_values: Vec<(usize, f64)> = Vec::new();
println!("\ntraining {} steps...", args.steps);
let train_start = Instant::now();
let rpt = if has_chat_data {
trainer.train_on_chat_sequences_with_callback(train_pairs, |step, loss| {
let loss_f64 = loss as f64;
if loss_f64 < bl {
bl = loss_f64;
}
let _ = writeln!(&mut loss_csv, "{step},{loss_f64}");
let _ = loss_csv.flush();
if step == 0 || (step + 1) % log_interval == 0 {
let elapsed = train_start.elapsed().as_secs_f64();
let sps = if step > 0 {
(step + 1) as f64 / elapsed
} else {
0.0
};
println!(
" step {}/{} loss={:.6} best={:.6} {:.1} steps/s",
step + 1,
args.steps,
loss_f64,
bl,
sps
);
}
})
} else {
let train_seqs: Vec<Vec<u32>> = train_pairs
.iter()
.map(|(p, r)| {
let mut seq = p.clone();
seq.extend_from_slice(r);
seq
})
.collect();
trainer.train_on_token_sequences_with_callback(&train_seqs, |step, loss| {
let loss_f64 = loss as f64;
if loss_f64 < bl {
bl = loss_f64;
}
let _ = writeln!(&mut loss_csv, "{step},{loss_f64}");
let _ = loss_csv.flush();
if step == 0 || (step + 1) % log_interval == 0 {
let elapsed = train_start.elapsed().as_secs_f64();
let sps = if step > 0 {
(step + 1) as f64 / elapsed
} else {
0.0
};
println!(
" step {}/{} loss={:.6} best={:.6} {:.1} steps/s",
step + 1,
args.steps,
loss_f64,
bl,
sps
);
}
})
}?;
let train_duration = train_start.elapsed();
let ts = train_duration.as_secs_f64();
let sps = args.steps as f64 / ts;
drop(loss_csv);
println!("\ntraining complete in {:.1}s ({:.1} steps/s)", ts, sps);
println!(
" final loss: {:.6} best loss: {:.6} (step {}) params: {}",
rpt.final_loss,
rpt.best_loss,
rpt.best_loss_step + 1,
rpt.parameter_count
);
let final_path = ckpt_dir.join("final.mpk");
trainer.save_checkpoint(final_path.to_str().unwrap())?;
println!("final checkpoint: {}", final_path.display());
let best_path = ckpt_dir.join("best.mpk");
if best_path.exists() {
fs::copy(&best_path, &final_ckpt)
.with_context(|| format!("failed to copy {:?} → {:?}", best_path, final_ckpt))?;
println!(
"best checkpoint (loss {:.6} @ step {}): {}",
rpt.best_loss,
rpt.best_loss_step + 1,
final_ckpt.display()
);
} else {
trainer.save_checkpoint(final_ckpt.to_str().unwrap())?;
println!("checkpoint: {}", final_ckpt.display());
}
let meta = RunMeta {
step: rpt.steps,
loss: rpt.best_loss as f64,
params: rpt.parameter_count,
model_config: config.clone(),
};
fs::write(
ckpt_dir.join("latest.json"),
serde_json::to_string_pretty(&meta)?,
)?;
drop(trainer);
(
rpt.steps,
rpt.final_loss,
rpt.parameter_count,
ts,
sps,
rpt.best_loss as f64,
)
};
println!("\nevaluating...");
use burn::module::AutodiffModule;
let eval_model = {
let mut m = DefaultMultiscreenModel::new(config.clone(), &device)?;
m.load_parameters(&final_ckpt)?;
m.valid() };
let inner_device = device;
let val_metrics = if !val_seqs.is_empty() {
println!(" validation set ({} sequences)...", val_seqs.len());
let result =
eval_model.evaluate_on_sequences(&val_seqs, args.seq_len, 4, 0, &inner_device)?;
println!(
" loss={:.4} ppl={:.2} accuracy={:.2}% ({} tokens)",
result.loss,
result.perplexity,
result.accuracy * 100.0,
result.total_tokens
);
Some(EvalMetrics {
loss: result.loss as f64,
perplexity: result.perplexity as f64,
accuracy: result.accuracy,
tokens: result.total_tokens,
})
} else {
None
};
let test_metrics = if !test_seqs.is_empty() {
println!(" test set ({} sequences)...", test_seqs.len());
let result =
eval_model.evaluate_on_sequences(&test_seqs, args.seq_len, 4, 0, &inner_device)?;
println!(
" loss={:.4} ppl={:.2} accuracy={:.2}% ({} tokens)",
result.loss,
result.perplexity,
result.accuracy * 100.0,
result.total_tokens
);
Some(EvalMetrics {
loss: result.loss as f64,
perplexity: result.perplexity as f64,
accuracy: result.accuracy,
tokens: result.total_tokens,
})
} else {
None
};
drop(eval_model);
println!("\nmeasuring inference latency...");
let prompt = if has_chat_data {
match train_pairs.first() {
Some((p, _)) if !p.is_empty() => {
sp.decode(p)
}
_ => "Once upon a time, there was a little".to_owned(),
}
} else {
"Once upon a time, there was a little".to_owned()
};
println!(" sample prompt: {prompt}");
let prompt_ids = sp.encode(&prompt);
let chat_model = ChatModel::load(&final_ckpt)?;
let latency_start = Instant::now();
let output = chat_model.generate(
&prompt_ids,
GenerationConfig {
max_new_tokens: args.latency_tokens,
..Default::default()
},
)?;
let latency_secs = latency_start.elapsed().as_secs_f64();
let new_tokens = output.len().saturating_sub(prompt_ids.len());
let avg_ms_per_token = if new_tokens > 0 {
latency_secs * 1000.0 / new_tokens as f64
} else {
0.0
};
let inference_metrics = InferenceMetrics {
avg_ms_per_token,
tokens_generated: new_tokens,
total_secs: latency_secs,
};
println!(
" {} tokens in {:.3}s = {:.2} ms/token",
new_tokens, latency_secs, avg_ms_per_token
);
let new_token_ids: Vec<u32> = if output.len() > prompt_ids.len() {
output[prompt_ids.len()..].to_vec()
} else {
output.clone()
};
let generated_text = sp.decode(&new_token_ids);
let full_text = sp.decode(&output);
println!("\nsample output:");
println!(" prompt: {prompt}");
println!(" generated: {generated_text}");
println!(" full: {full_text}");
let full_report = TrainReport {
budget: args.budget.clone(),
parameter_count: train_params,
seq_len: args.seq_len,
batch_size: args.batch_size,
learning_rate: args.lr,
total_steps: train_steps,
train_duration_secs: train_secs,
steps_per_sec,
final_train_loss: final_loss as f64,
best_train_loss: best_loss,
val: val_metrics,
test: test_metrics,
inference: Some(inference_metrics),
train_samples: train_pairs.len(),
val_samples: val_seqs.len(),
test_samples: test_seqs.len(),
total_tokens,
device: device_name,
};
let report_json = serde_json::to_string_pretty(&full_report)?;
let report_path = args.run_dir.join("report.json");
fs::write(&report_path, &report_json)?;
println!("\nreport: {}", report_path.display());
let md = format_report_md(&full_report);
let report_md_path = args.run_dir.join("report.md");
fs::write(&report_md_path, &md)?;
println!("report: {}", report_md_path.display());
println!("\nloss CSV: {}", loss_csv_path.display());
println!(
"to generate a loss plot: python examples/plot_loss.py {}",
loss_csv_path.display()
);
println!(
"\nnext step: cargo run --release --example chat_with_tokenizer -- --run-dir {}",
args.run_dir.display()
);
Ok(())
}
fn format_report_md(r: &TrainReport) -> String {
let mut md = String::new();
md.push_str("# Training Report\n\n");
md.push_str("## Configuration\n\n");
md.push_str("| Parameter | Value |\n|---|---|\n");
md.push_str(&format!("| Budget | {} |\n", r.budget));
md.push_str(&format!(
"| Parameters | {} (~{:.1}M) |\n",
r.parameter_count,
r.parameter_count as f64 / 1e6
));
md.push_str(&format!("| Seq Length | {} |\n", r.seq_len));
md.push_str(&format!("| Batch Size | {} |\n", r.batch_size));
md.push_str(&format!("| Learning Rate | {} |\n", r.learning_rate));
md.push_str(&format!("| Total Steps | {} |\n", r.total_steps));
md.push_str(&format!("| Device | {} |\n", r.device));
md.push('\n');
md.push_str("## Data\n\n");
md.push_str("| Split | Sequences |\n|---|---|\n");
md.push_str(&format!("| Train | {} |\n", r.train_samples));
md.push_str(&format!("| Val | {} |\n", r.val_samples));
md.push_str(&format!("| Test | {} |\n", r.test_samples));
md.push_str(&format!("| Total Tokens | {} |\n", r.total_tokens));
md.push('\n');
md.push_str("## Training\n\n");
md.push_str("| Metric | Value |\n|---|---|\n");
md.push_str(&format!("| Duration | {:.1}s |\n", r.train_duration_secs));
md.push_str(&format!(
"| Throughput | {:.1} steps/s |\n",
r.steps_per_sec
));
md.push_str(&format!("| Final Loss | {:.6} |\n", r.final_train_loss));
md.push_str(&format!("| Best Loss | {:.6} |\n", r.best_train_loss));
md.push('\n');
if let Some(val) = &r.val {
md.push_str("## Validation\n\n");
md.push_str("| Metric | Value |\n|---|---|\n");
md.push_str(&format!("| Loss | {:.4} |\n", val.loss));
md.push_str(&format!("| Perplexity | {:.2} |\n", val.perplexity));
md.push_str(&format!("| Accuracy | {:.2}% |\n", val.accuracy * 100.0));
md.push_str(&format!("| Tokens | {} |\n", val.tokens));
md.push('\n');
}
if let Some(test) = &r.test {
md.push_str("## Test\n\n");
md.push_str("| Metric | Value |\n|---|---|\n");
md.push_str(&format!("| Loss | {:.4} |\n", test.loss));
md.push_str(&format!("| Perplexity | {:.2} |\n", test.perplexity));
md.push_str(&format!("| Accuracy | {:.2}% |\n", test.accuracy * 100.0));
md.push_str(&format!("| Tokens | {} |\n", test.tokens));
md.push('\n');
}
if let Some(inf) = &r.inference {
md.push_str("## Inference\n\n");
md.push_str("| Metric | Value |\n|---|---|\n");
md.push_str(&format!(
"| Avg Latency | {:.2} ms/token |\n",
inf.avg_ms_per_token
));
md.push_str(&format!(
"| Tokens Generated | {} |\n",
inf.tokens_generated
));
md.push_str(&format!("| Total Time | {:.3}s |\n", inf.total_secs));
md.push('\n');
}
md.push_str("## Loss Plot\n\n");
md.push_str("Generate with: `python examples/plot_loss.py runs/<name>/loss.csv`\n");
md
}