#![cfg(feature = "rocm-hip")]
use std::env;
use std::fs;
use std::io::{BufReader, Read, Write};
use std::path::{Path, PathBuf};
use std::process::{Command, Stdio};
use std::time::{Instant, SystemTime, UNIX_EPOCH};
use serde::Serialize;
use serde_json::{Value, json};
/// Directory name appended to `<search-root>/docs` (or to the current
/// directory when no search root is found) to form the default output location.
const DEFAULT_ARTIFACT_SUBDIR: &str = "artifacts";
/// Default value of the `--test` option.
const DEFAULT_TEST_NAME: &str = "mlp_10k_e2e";
/// Fraction of the initial loss a sample must reach to count as having hit
/// the reduction target (0.5 = loss halved).
const LOSS_REDUCTION_TARGET: f32 = 0.5;
/// Minimum absolute initial loss for which a reduction percentage is
/// computed; guards the division by the initial loss.
const EPS: f32 = 1e-6;
/// One per-step measurement parsed from the test subprocess output.
#[derive(Debug, Clone, Serialize)]
struct RunSample {
/// Step number parsed from the line's `step=` token.
step: u32,
/// Value parsed from a `loss=` token; `f32::NAN` when the sample carries no
/// loss value (such samples are skipped by the aggregate computation and
/// written as JSON `null`).
loss: f32,
/// Value parsed from a `pre_clip_norm=` token, if one was seen.
pre_clip_norm: Option<f32>,
}
/// Snapshot of the host and toolchain the run was executed on.
#[derive(Debug, Clone, Serialize)]
struct SystemContext {
/// Trimmed output of `hostname`, else `$HOSTNAME`, else `"unknown"`.
hostname: String,
/// Trimmed output of `uname -r`, or `"unknown"`.
kernel_release: String,
/// Leading lines of `/opt/rocm/bin/hipcc --version`; `None` when the tool
/// cannot be run or exits unsuccessfully.
hipcc_version: Option<String>,
/// Leading lines of `/opt/rocm/bin/rocminfo` output; `None` when the tool
/// cannot be run or exits unsuccessfully.
rocminfo_head: Option<String>,
/// Trimmed output of `cargo --version`, or `"unknown"`.
cargo_version: String,
/// Trimmed output of `rustc --version`, or `"unknown"`.
rustc_version: String,
/// Output of `git rev-parse HEAD` in the operator root, if available.
commit_sha: Option<String>,
/// Display form of the HIP kernel cache directory path.
hip_cache_path: String,
/// Total size in bytes of the regular files under the cache directory.
hip_cache_size_bytes: u64,
/// Number of regular files under the cache directory.
hip_cache_file_count: usize,
/// Trimmed output of `uname -a`, or `"unknown"`.
os_uname: String,
}
/// Description of the subprocess invocation and how it ended.
#[derive(Debug, Clone, Serialize)]
struct RunCommand {
/// Program that was executed (`"<dry-run>"` when nothing was run).
program: String,
/// Arguments passed to the program.
args: Vec<String>,
/// Process exit code; `None` when the process was not run, could not be
/// spawned or waited on, or ended without an exit code.
exit_code: Option<i32>,
/// Wall-clock duration of the invocation in seconds.
duration_sec: f64,
/// Start time as seconds since the Unix epoch.
started_at_unix: u64,
}
/// Summary statistics derived from the loss samples of one run.
///
/// Every `Option` field is `None` when no sample carried a loss value.
#[derive(Debug, Clone, Serialize)]
struct RunAggregate {
/// Loss of the first loss-bearing sample.
initial_loss: Option<f32>,
/// Loss of the last loss-bearing sample.
final_loss: Option<f32>,
/// `100 * (1 - final / initial)`; `None` when `|initial| <= EPS`.
loss_reduction_pct: Option<f32>,
/// Step number of the last loss-bearing sample (a step index, not a count).
total_steps: Option<u32>,
/// First step greater than zero whose loss is at or below
/// `LOSS_REDUCTION_TARGET * initial_loss`.
time_to_50pct_reduction_step: Option<u32>,
/// Smallest loss observed.
min_loss: Option<f32>,
/// Largest loss observed.
max_loss: Option<f32>,
/// Step of the first sample whose loss exceeds the previous sample's loss;
/// `Some(0)` when the loss never rises.
loss_curve_monotonic_after_step: Option<u32>,
/// Whether the test subprocess exited successfully.
test_passed: bool,
}
/// Identification metadata written at the top level of each artifact record.
#[derive(Debug, Clone, Serialize)]
struct ArtifactHeader {
/// Schema identifier of the artifact format.
schema_version: String,
/// Kind of artifact the record describes.
artifact_kind: String,
/// Generation time as seconds since the Unix epoch.
generated_at_unix: u64,
/// Generation time as a text timestamp.
generated_at_iso: String,
/// Display form of the path the artifact is written to.
artifact_path: String,
}
/// Complete artifact record for one invocation of the tool.
#[derive(Debug, Clone, Serialize)]
struct RunRecord {
/// Identification metadata.
header: ArtifactHeader,
/// Host and toolchain snapshot.
system: SystemContext,
/// The subprocess invocation and its outcome.
command: RunCommand,
/// Per-step samples parsed from the subprocess stderr.
samples: Vec<RunSample>,
/// Summary statistics over `samples`.
aggregate: RunAggregate,
/// Trailing portion of the subprocess stderr, kept for diagnostics.
raw_stderr_tail: String,
}
/// Parsed command-line options.
struct Cli {
/// Value of `--out-dir`, if given.
out_dir: Option<PathBuf>,
/// Value of `--search-dir`, if given.
search_dir: Option<PathBuf>,
/// Value of `--test`; defaults to `DEFAULT_TEST_NAME`.
test_name: String,
/// Set by `--dry-run`: capture context only, do not run the test.
dry_run: bool,
/// Set by `--keep-cache`: leave the HIP kernel cache untouched before the run.
keep_cache: bool,
/// Set by `--help` / `-h`.
show_help: bool,
}
/// Prints the command-line usage text (options and exit codes) to stderr.
fn print_usage() {
eprintln!(
"Usage: collect_paper_artifacts [options]\n\
\n\
Re-runs the 10K MLP convergence test in a subprocess and writes\n\
a JSONL artifact capturing per-step losses + system context.\n\
\n\
Options:\n \
--dry-run capture context only, do not run the test\n \
--out-dir <path> override artifact output directory\n \
(default: <search-dir>/docs/artifacts)\n \
--search-dir <path> override tokitai-search repo root\n \
(default: $TOKITAI_SEARCH_HOME, or\n \
sibling of tokitai-operator)\n \
--test <name> test binary filter (default: mlp_10k_e2e)\n \
--keep-cache do not clean HIP kernel cache before run\n \
--help show this help and exit\n\
\n\
Exit codes:\n \
0 test passed, artifact written\n \
1 test failed or artifact not writable\n \
2 argument parsing error"
);
}
/// Parses the process arguments (program name excluded) into a [`Cli`].
///
/// Value-taking flags consume the argument that follows them. Parsing does
/// not stop at `--help`; the remaining arguments are still validated.
///
/// # Errors
///
/// Returns a human-readable message when an unrecognised argument is found
/// or when a value-taking flag is the final argument.
fn parse_args() -> Result<Cli, String> {
    let argv: Vec<String> = env::args().skip(1).collect();
    let mut parsed = Cli {
        out_dir: None,
        search_dir: None,
        test_name: DEFAULT_TEST_NAME.to_string(),
        dry_run: false,
        keep_cache: false,
        show_help: false,
    };
    let mut pos = 0;
    while pos < argv.len() {
        // Every arm yields the number of arguments it consumed.
        let consumed = match argv[pos].as_str() {
            "--help" | "-h" => {
                parsed.show_help = true;
                1
            }
            "--dry-run" => {
                parsed.dry_run = true;
                1
            }
            "--keep-cache" => {
                parsed.keep_cache = true;
                1
            }
            "--out-dir" => {
                parsed.out_dir = Some(PathBuf::from(next_value(&argv, pos, "--out-dir")?));
                2
            }
            "--search-dir" => {
                parsed.search_dir = Some(PathBuf::from(next_value(&argv, pos, "--search-dir")?));
                2
            }
            "--test" => {
                parsed.test_name = next_value(&argv, pos, "--test")?;
                2
            }
            unknown => return Err(format!("unknown flag: {unknown}")),
        };
        pos += consumed;
    }
    Ok(parsed)
}
/// Returns a copy of the argument that follows position `i`.
///
/// # Errors
///
/// Returns `"<flag> requires a value"` when there is no argument after `i`.
fn next_value(args: &[String], i: usize, flag: &str) -> Result<String, String> {
    match args.get(i + 1) {
        Some(value) => Ok(value.clone()),
        None => Err(format!("{flag} requires a value")),
    }
}
/// Locates the `tokitai-search` repository root.
///
/// Candidates are tried in order and the first existing directory wins:
/// 1. the `--search-dir` value;
/// 2. `$TOKITAI_SEARCH_HOME`;
/// 3. `tokitai-search` next to `$CARGO_MANIFEST_DIR`;
/// 4. for each ancestor of the current directory: a `tokitai-search` child
///    containing `docs`, or — when the ancestor is a `tokitai-operator`
///    directory holding a `Cargo.toml` — its `tokitai-search` sibling.
///
/// Returns `None` when nothing matches.
fn resolve_search_root(cli: &Cli) -> Option<PathBuf> {
    cli.search_dir
        .as_ref()
        .filter(|dir| dir.is_dir())
        .cloned()
        .or_else(|| {
            env::var("TOKITAI_SEARCH_HOME")
                .ok()
                .map(PathBuf::from)
                .filter(|dir| dir.is_dir())
        })
        .or_else(|| {
            let manifest = env::var("CARGO_MANIFEST_DIR").ok()?;
            let sibling = Path::new(&manifest).parent()?.join("tokitai-search");
            if sibling.is_dir() {
                Some(sibling)
            } else {
                None
            }
        })
        .or_else(|| {
            let cwd = env::current_dir().ok()?;
            cwd.ancestors().find_map(|ancestor| {
                let child = ancestor.join("tokitai-search");
                if child.join("docs").is_dir() {
                    return Some(child);
                }
                let is_operator_root = ancestor.join("Cargo.toml").is_file()
                    && ancestor
                        .file_name()
                        .map_or(false, |name| name == "tokitai-operator");
                if !is_operator_root {
                    return None;
                }
                let sibling = ancestor.parent()?.join("tokitai-search");
                if sibling.is_dir() {
                    Some(sibling)
                } else {
                    None
                }
            })
        })
}
/// Chooses the directory the artifact is written to.
///
/// An explicit `--out-dir` wins. Otherwise the result is
/// `<search_root>/docs/<DEFAULT_ARTIFACT_SUBDIR>`, or — with no search root —
/// `<cwd>/<DEFAULT_ARTIFACT_SUBDIR>` (using `.` when the current directory
/// cannot be determined).
fn resolve_out_dir(cli: &Cli, search_root: Option<&Path>) -> PathBuf {
    match (&cli.out_dir, search_root) {
        (Some(explicit), _) => explicit.clone(),
        (None, Some(root)) => root.join("docs").join(DEFAULT_ARTIFACT_SUBDIR),
        (None, None) => {
            let cwd = match env::current_dir() {
                Ok(dir) => dir,
                Err(_) => PathBuf::from("."),
            };
            cwd.join(DEFAULT_ARTIFACT_SUBDIR)
        }
    }
}
/// Returns the HIP kernel cache directory.
///
/// Uses `$TOKITAI_HIP_CACHE` when set, else
/// `$CARGO_MANIFEST_DIR/target/rocm-hip-cache`, else the relative path
/// `target/rocm-hip-cache`.
fn hip_cache_dir() -> PathBuf {
    env::var("TOKITAI_HIP_CACHE")
        .map(PathBuf::from)
        .or_else(|_| {
            env::var("CARGO_MANIFEST_DIR")
                .map(|manifest| Path::new(&manifest).join("target").join("rocm-hip-cache"))
        })
        .unwrap_or_else(|_| PathBuf::from("target/rocm-hip-cache"))
}
/// Runs `cmd` with `args` and returns its stdout, lossily decoded as UTF-8.
///
/// Returns `None` when the command cannot be run or exits unsuccessfully.
fn read_pipe(cmd: &str, args: &[&str]) -> Option<String> {
    match Command::new(cmd).args(args).output() {
        Ok(output) if output.status.success() => {
            Some(String::from_utf8_lossy(&output.stdout).into_owned())
        }
        _ => None,
    }
}
/// Runs `cmd` with `args` and returns at most the first `head_lines` lines of
/// its stdout, joined with `\n`.
///
/// A `head_lines` of `0` disables limiting and returns the full output. When
/// the selected lines exceed 4096 bytes the text is cut to at most 4096 bytes
/// and `"\n... [truncated]\n"` is appended.
///
/// Returns `None` when the command cannot be run or exits unsuccessfully.
fn read_pipe_with_limit(cmd: &str, args: &[&str], head_lines: usize) -> Option<String> {
    const MAX_BYTES: usize = 4096;

    let raw = read_pipe(cmd, args)?;
    if head_lines == 0 {
        return Some(raw);
    }
    let mut head = raw.lines().take(head_lines).collect::<Vec<_>>().join("\n");
    if head.len() > MAX_BYTES {
        // `String::truncate` panics when the cut falls inside a multi-byte
        // character (the output is lossily decoded, so it may contain
        // U+FFFD); back off to the nearest preceding boundary.
        let mut cut = MAX_BYTES;
        while !head.is_char_boundary(cut) {
            cut -= 1;
        }
        head.truncate(cut);
        head.push_str("\n... [truncated]\n");
    }
    Some(head)
}
/// Returns the machine's host name.
///
/// Prefers the trimmed output of `hostname`; falls back to `$HOSTNAME`, and
/// finally to `"unknown"`.
fn hostname_string() -> String {
    if let Some(raw) = read_pipe("hostname", &[]) {
        return raw.trim().to_string();
    }
    match env::var("HOSTNAME") {
        Ok(name) => name,
        Err(_) => "unknown".to_string(),
    }
}
/// Returns the trimmed output of `uname -r`, or `"unknown"` on failure.
fn uname_release() -> String {
    match read_pipe("uname", &["-r"]) {
        Some(raw) => raw.trim().to_owned(),
        None => String::from("unknown"),
    }
}
/// Returns the trimmed output of `uname -a`, or `"unknown"` on failure.
fn uname_all() -> String {
    match read_pipe("uname", &["-a"]) {
        Some(raw) => raw.trim().to_owned(),
        None => String::from("unknown"),
    }
}
/// Captures the first four lines printed by `/opt/rocm/bin/hipcc --version`.
///
/// Stdout and stderr are concatenated (stdout first); each kept line is
/// trimmed and terminated with `\n`. Returns `None` when the tool cannot be
/// run, exits unsuccessfully, or prints nothing.
fn hipcc_version() -> Option<String> {
    let output = match Command::new("/opt/rocm/bin/hipcc").arg("--version").output() {
        Ok(done) if done.status.success() => done,
        _ => return None,
    };
    let mut text = String::from_utf8_lossy(&output.stdout).into_owned();
    text.push_str(&String::from_utf8_lossy(&output.stderr));
    let banner: String = text
        .lines()
        .take(4)
        .map(|line| format!("{}\n", line.trim()))
        .collect();
    if banner.is_empty() {
        None
    } else {
        Some(banner)
    }
}
/// Returns the first `n` lines of `/opt/rocm/bin/rocminfo` output (limited
/// as described by `read_pipe_with_limit`), or `None` when the tool fails.
fn rocminfo_head(n: usize) -> Option<String> {
    let binary = "/opt/rocm/bin/rocminfo";
    let no_args: &[&str] = &[];
    read_pipe_with_limit(binary, no_args, n)
}
/// Returns the trimmed output of `cargo --version`, or `"unknown"` on failure.
fn cargo_version() -> String {
    match read_pipe("cargo", &["--version"]) {
        Some(raw) => raw.trim().to_owned(),
        None => String::from("unknown"),
    }
}
/// Returns the trimmed output of `rustc --version`, or `"unknown"` on failure.
fn rustc_version() -> String {
    match read_pipe("rustc", &["--version"]) {
        Some(raw) => raw.trim().to_owned(),
        None => String::from("unknown"),
    }
}
/// Returns the output of `git rev-parse HEAD` run inside `operator_root`.
///
/// Yields `None` when git cannot be run there, exits unsuccessfully, or
/// prints only whitespace.
fn commit_sha(operator_root: &Path) -> Option<String> {
    let result = Command::new("git")
        .args(["rev-parse", "HEAD"])
        .current_dir(operator_root)
        .output();
    let output = match result {
        Ok(done) if done.status.success() => done,
        _ => return None,
    };
    let decoded = String::from_utf8_lossy(&output.stdout);
    let sha = decoded.trim();
    if sha.is_empty() {
        None
    } else {
        Some(sha.to_owned())
    }
}
/// Walks the directory tree under `p` and returns
/// `(total bytes of regular files, number of regular files)`.
///
/// Directories or entries that cannot be read are skipped, so a missing or
/// unreadable `p` yields `(0, 0)`. Entries that are neither regular files nor
/// directories are ignored.
fn dir_size_and_count(p: &Path) -> (u64, usize) {
    let mut bytes = 0u64;
    let mut files = 0usize;
    // Explicit work list instead of recursion.
    let mut pending = vec![p.to_path_buf()];
    while let Some(dir) = pending.pop() {
        if let Ok(listing) = fs::read_dir(&dir) {
            for entry in listing.flatten() {
                if let Ok(meta) = entry.metadata() {
                    if meta.is_dir() {
                        pending.push(entry.path());
                    } else if meta.is_file() {
                        bytes = bytes.saturating_add(meta.len());
                        files += 1;
                    }
                }
            }
        }
    }
    (bytes, files)
}
/// Gathers host, toolchain, repository and HIP-cache information.
///
/// `operator_root` is the directory `git` is run in; `hip_cache` is the cache
/// directory whose size and file count are measured.
fn collect_system_context(operator_root: &Path, hip_cache: &Path) -> SystemContext {
    let (hip_cache_size_bytes, hip_cache_file_count) = dir_size_and_count(hip_cache);
    let hostname = hostname_string();
    let kernel_release = uname_release();
    let hipcc_version = hipcc_version();
    let rocminfo_head = rocminfo_head(30);
    let cargo_version = cargo_version();
    let rustc_version = rustc_version();
    let commit_sha = commit_sha(operator_root);
    let hip_cache_path = hip_cache.display().to_string();
    let os_uname = uname_all();
    SystemContext {
        hostname,
        kernel_release,
        hipcc_version,
        rocminfo_head,
        cargo_version,
        rustc_version,
        commit_sha,
        hip_cache_path,
        hip_cache_size_bytes,
        hip_cache_file_count,
        os_uname,
    }
}
/// Builds the `cargo test` command line for the convergence test.
///
/// Returns the program name and its arguments; `cli.test_name` is passed as
/// the value of `--test`.
fn build_cargo_invocation(cli: &Cli) -> (String, Vec<String>) {
    let argv = [
        "test",
        "--features",
        "rocm-hip",
        "--test",
        cli.test_name.as_str(),
        "mlp_10k_loss_decreases_50pct",
        "--",
        "--nocapture",
        "--test-threads=1",
    ];
    (
        String::from("cargo"),
        argv.iter().map(|arg| arg.to_string()).collect(),
    )
}
/// Empties the HIP cache directory unless `keep` is set.
///
/// The directory itself is left in place; only its entries are removed.
/// Removal is best-effort: individual failures are ignored, and a missing or
/// unreadable directory is a no-op.
fn maybe_clean_hip_cache(hip_cache: &Path, keep: bool) {
    if keep || !hip_cache.exists() {
        return;
    }
    let entries = match fs::read_dir(hip_cache) {
        Ok(listing) => listing,
        Err(_) => return,
    };
    for entry in entries.flatten() {
        let target = entry.path();
        // Best-effort: a failed removal must not abort the run.
        let _ = if target.is_dir() {
            fs::remove_dir_all(&target)
        } else {
            fs::remove_file(&target)
        };
    }
}
/// Runs the convergence test via `cargo test` inside `operator_root` and
/// captures its output.
///
/// Returns the invocation record, the captured stderr (lossily decoded as
/// UTF-8), and whether the process exited successfully. When the process
/// cannot be spawned, the returned text is a `"failed to spawn cargo: …"`
/// message and the flag is `false`. Stdout is drained but discarded.
fn run_test_subprocess(cli: &Cli, operator_root: &Path) -> (RunCommand, String, bool) {
    /// Reads `pipe` to EOF on a dedicated thread and returns the decoded text.
    fn drain<R: Read + Send + 'static>(pipe: R) -> std::thread::JoinHandle<String> {
        std::thread::spawn(move || {
            let mut bytes = Vec::new();
            // Best-effort: on a read error keep whatever was received.
            let _ = BufReader::new(pipe).read_to_end(&mut bytes);
            // Lossy decoding so one invalid byte does not discard the output.
            String::from_utf8_lossy(&bytes).into_owned()
        })
    }

    let (program, args) = build_cargo_invocation(cli);
    let started = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    let t0 = Instant::now();
    let mut child = match Command::new(&program)
        .args(&args)
        .current_dir(operator_root)
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn()
    {
        Ok(c) => c,
        Err(e) => {
            return (
                RunCommand {
                    program,
                    args,
                    exit_code: None,
                    duration_sec: t0.elapsed().as_secs_f64(),
                    started_at_unix: started,
                },
                format!("failed to spawn cargo: {e}"),
                false,
            );
        }
    };
    // Both pipes must be drained concurrently, and before `wait`: a child
    // blocked on a full pipe never exits, and reading one pipe to EOF while
    // the other fills up deadlocks in the same way.
    let stderr_handle = child.stderr.take().map(drain);
    let stdout_handle = child.stdout.take().map(drain);
    let status = child.wait();
    let stderr = stderr_handle
        .and_then(|h| h.join().ok())
        .unwrap_or_default();
    let _stdout = stdout_handle
        .and_then(|h| h.join().ok())
        .unwrap_or_default();
    let (exit_code, passed) = match status {
        Ok(s) => (s.code(), s.success()),
        Err(e) => {
            eprintln!("collect_paper_artifacts: wait error: {e}");
            (None, false)
        }
    };
    (
        RunCommand {
            program,
            args,
            exit_code,
            duration_sec: t0.elapsed().as_secs_f64(),
            started_at_unix: started,
        },
        stderr,
        passed,
    )
}
fn parse_samples(stderr: &str) -> Vec<RunSample> {
let mut out: Vec<RunSample> = Vec::new();
for line in stderr.lines() {
let line = line.trim();
if !line.contains("mlp_10k_e2e") {
continue;
}
if let Some(sample) = parse_loss_line(line) {
if let Some(idx) = out.iter().position(|s| s.step == sample.step) {
out[idx] = sample;
} else {
out.push(sample);
}
} else if let Some(norm_sample) = parse_pre_clip_line(line) {
out.push(norm_sample);
}
}
out.sort_by_key(|s| s.step);
out
}
/// Parses a line carrying both a `step=<digits>` and a `loss=<number>` token.
///
/// Whitespace directly after each `=` is skipped. The step is the leading run
/// of ASCII digits; the loss is the leading run of digits, `.`, `-`, `+`, `e`
/// and `E`. Returns `None` when either token is missing or unparsable. The
/// returned sample has no `pre_clip_norm`.
fn parse_loss_line(line: &str) -> Option<RunSample> {
    let (_, step_tail) = line.split_once("step=")?;
    let step_text = step_tail
        .trim_start()
        .split(|c: char| !c.is_ascii_digit())
        .next()
        .unwrap_or("");
    if step_text.is_empty() {
        return None;
    }
    let step = step_text.parse::<u32>().ok()?;

    let (_, loss_tail) = line.split_once("loss=")?;
    let loss_text = loss_tail
        .trim_start()
        .split(|c: char| !matches!(c, '0'..='9' | '.' | '-' | '+' | 'e' | 'E'))
        .next()
        .unwrap_or("");
    if loss_text.is_empty() {
        return None;
    }
    let loss = loss_text.parse::<f32>().ok()?;

    Some(RunSample {
        step,
        loss,
        pre_clip_norm: None,
    })
}
/// Parses a line carrying both a `step=<digits>` and a
/// `pre_clip_norm=<number>` token.
///
/// Whitespace directly after each `=` is skipped. The step is the leading run
/// of ASCII digits; the norm is the leading run of digits, `.`, `-`, `+`, `e`
/// and `E`. Returns `None` when either token is missing or unparsable. The
/// returned sample has `loss` set to `f32::NAN`.
fn parse_pre_clip_line(line: &str) -> Option<RunSample> {
    let (_, step_tail) = line.split_once("step=")?;
    let step_text = step_tail
        .trim_start()
        .split(|c: char| !c.is_ascii_digit())
        .next()
        .unwrap_or("");
    if step_text.is_empty() {
        return None;
    }
    let step = step_text.parse::<u32>().ok()?;

    let (_, norm_tail) = line.split_once("pre_clip_norm=")?;
    let norm_text = norm_tail
        .trim_start()
        .split(|c: char| !matches!(c, '0'..='9' | '.' | '-' | '+' | 'e' | 'E'))
        .next()
        .unwrap_or("");
    if norm_text.is_empty() {
        return None;
    }
    let norm = norm_text.parse::<f32>().ok()?;

    Some(RunSample {
        step,
        loss: f32::NAN,
        pre_clip_norm: Some(norm),
    })
}
/// Summarises the loss curve contained in `samples`.
///
/// Samples whose loss is NaN are ignored. The first and last remaining
/// samples (in slice order) provide the initial and final loss. `passed` is
/// copied into `test_passed`. When no sample carries a loss, every optional
/// field is `None`.
fn compute_aggregate(samples: &[RunSample], passed: bool) -> RunAggregate {
    let curve: Vec<(u32, f32)> = samples
        .iter()
        .filter_map(|s| {
            if s.loss.is_nan() {
                None
            } else {
                Some((s.step, s.loss))
            }
        })
        .collect();

    let ((_, initial_loss), (last_step, final_loss)) = match (curve.first(), curve.last()) {
        (Some(first), Some(last)) => (*first, *last),
        _ => {
            return RunAggregate {
                initial_loss: None,
                final_loss: None,
                loss_reduction_pct: None,
                total_steps: None,
                time_to_50pct_reduction_step: None,
                min_loss: None,
                max_loss: None,
                loss_curve_monotonic_after_step: None,
                test_passed: passed,
            };
        }
    };

    // Skip the percentage when the initial loss is too close to zero.
    let loss_reduction_pct = if initial_loss.abs() > EPS {
        Some(100.0 * (1.0 - final_loss / initial_loss))
    } else {
        None
    };

    let target = LOSS_REDUCTION_TARGET * initial_loss;
    let time_to_50 = curve
        .iter()
        .find(|&&(step, loss)| step > 0 && loss <= target)
        .map(|&(step, _)| step);

    let mut min_loss = f32::INFINITY;
    let mut max_loss = f32::NEG_INFINITY;
    for &(_, loss) in &curve {
        min_loss = min_loss.min(loss);
        max_loss = max_loss.max(loss);
    }

    // First sample whose loss is higher than its predecessor's, if any.
    let first_rise = curve
        .windows(2)
        .find(|pair| pair[1].1 > pair[0].1)
        .map(|pair| pair[1].0);

    RunAggregate {
        initial_loss: Some(initial_loss),
        final_loss: Some(final_loss),
        loss_reduction_pct,
        total_steps: Some(last_step),
        time_to_50pct_reduction_step: time_to_50,
        min_loss: Some(min_loss),
        max_loss: Some(max_loss),
        loss_curve_monotonic_after_step: Some(first_rise.unwrap_or(0)),
        test_passed: passed,
    }
}
/// Serialises `record` as a single JSON line and writes it to `path`,
/// replacing any existing file.
///
/// # Errors
///
/// Returns an `ErrorKind::Other` error wrapping the serialisation failure,
/// or the I/O error from writing the file.
fn write_jsonl(record: &RunRecord, path: &Path) -> std::io::Result<()> {
    let mut line = match serde_json::to_string(&record_to_json(record)) {
        Ok(text) => text,
        Err(err) => return Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
    };
    line.push('\n');
    fs::write(path, line)
}
/// Builds the JSON value written to the artifact file.
///
/// The header fields are flattened into the top-level object; `system`,
/// `command` and `aggregate` are nested objects and `samples` is an array.
fn record_to_json(r: &RunRecord) -> Value {
json!({
"schema_version": r.header.schema_version,
"artifact_kind": r.header.artifact_kind,
"generated_at_unix": r.header.generated_at_unix,
"generated_at_iso": r.header.generated_at_iso,
"artifact_path": r.header.artifact_path,
"system": {
"hostname": r.system.hostname,
"kernel_release": r.system.kernel_release,
"hipcc_version": r.system.hipcc_version,
"rocminfo_head": r.system.rocminfo_head,
"cargo_version": r.system.cargo_version,
"rustc_version": r.system.rustc_version,
"commit_sha": r.system.commit_sha,
"hip_cache_path": r.system.hip_cache_path,
"hip_cache_size_bytes": r.system.hip_cache_size_bytes,
"hip_cache_file_count": r.system.hip_cache_file_count,
"os_uname": r.system.os_uname,
},
"command": {
"program": r.command.program,
"args": r.command.args,
"exit_code": r.command.exit_code,
"duration_sec": r.command.duration_sec,
"started_at_unix": r.command.started_at_unix,
},
"aggregate": {
"initial_loss": r.aggregate.initial_loss,
"final_loss": r.aggregate.final_loss,
"loss_reduction_pct": r.aggregate.loss_reduction_pct,
"total_steps": r.aggregate.total_steps,
"time_to_50pct_reduction_step": r.aggregate.time_to_50pct_reduction_step,
"min_loss": r.aggregate.min_loss,
"max_loss": r.aggregate.max_loss,
"loss_curve_monotonic_after_step": r.aggregate.loss_curve_monotonic_after_step,
"test_passed": r.aggregate.test_passed,
},
"samples": r.samples.iter().map(|s| json!({
"step": s.step,
// A NaN loss marks a sample without a loss value; emit it as null.
"loss": if s.loss.is_nan() { Value::Null } else { json!(s.loss) },
"pre_clip_norm": match s.pre_clip_norm {
Some(v) => json!(v),
None => Value::Null,
},
})).collect::<Vec<_>>(),
"raw_stderr_tail": r.raw_stderr_tail,
})
}
/// Returns the current UTC time as an ISO-8601 timestamp of the form
/// `YYYY-MM-DDTHH:MM:SSZ`.
///
/// Falls back to the Unix epoch when the system clock is before 1970.
fn iso8601_now() -> String {
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    iso8601_from_unix(secs)
}

/// Formats `secs` (seconds since the Unix epoch, UTC) as `YYYY-MM-DDTHH:MM:SSZ`.
fn iso8601_from_unix(secs: u64) -> String {
    let days = secs / 86_400;
    let tod = secs % 86_400;

    // Days-to-civil conversion (proleptic Gregorian calendar) working on a
    // year that starts on 1 March, so the leap day is the last day of the
    // year. 719_468 is the number of days from 0000-03-01 to 1970-01-01.
    let z = days + 719_468;
    let era = z / 146_097; // 400-year cycles
    let doe = z % 146_097; // day of era, 0..=146_096
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // 0..=399
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // 0..=365, from 1 March
    let mp = (5 * doy + 2) / 153; // month index from March, 0..=11
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the following civil year.
    let year = yoe + era * 400 + u64::from(month <= 2);

    format!(
        "{year:04}-{month:02}-{day:02}T{:02}:{:02}:{:02}Z",
        tod / 3_600,
        (tod % 3_600) / 60,
        tod % 60
    )
}
/// Entry point: captures system context, optionally runs the convergence
/// test, writes the artifact, and prints the artifact path to stdout.
///
/// Exit status:
/// * `0` — the test passed (or `--dry-run`/`--help` completed) and the
///   artifact, if any, was written;
/// * `1` — the test failed, or the output directory or artifact could not be
///   written;
/// * `2` — the arguments could not be parsed.
fn main() -> std::process::ExitCode {
    let cli = match parse_args() {
        Ok(c) => c,
        Err(e) => {
            eprintln!("collect_paper_artifacts: {e}");
            eprintln!();
            print_usage();
            return std::process::ExitCode::from(2);
        }
    };
    if cli.show_help {
        print_usage();
        return std::process::ExitCode::SUCCESS;
    }
    let operator_root = env::var("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(|_| env::current_dir().unwrap_or_else(|_| PathBuf::from(".")));
    let search_root = resolve_search_root(&cli);
    let out_dir = resolve_out_dir(&cli, search_root.as_deref());
    let hip_cache = hip_cache_dir();
    if let Err(e) = fs::create_dir_all(&out_dir) {
        eprintln!(
            "collect_paper_artifacts: cannot create out-dir {}: {e}",
            out_dir.display()
        );
        return std::process::ExitCode::from(1);
    }
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    let artifact_path = out_dir.join(format!("mlp_10k_run_{now}.jsonl"));
    // Context is captured before the cache is cleaned; the cache figures are
    // refreshed after the run below.
    let system = collect_system_context(&operator_root, &hip_cache);
    if !cli.dry_run {
        maybe_clean_hip_cache(&hip_cache, cli.keep_cache);
    }
    let (command, stderr, passed) = if cli.dry_run {
        (
            RunCommand {
                program: "<dry-run>".to_string(),
                args: Vec::new(),
                exit_code: None,
                duration_sec: 0.0,
                started_at_unix: now,
            },
            String::new(),
            false,
        )
    } else {
        run_test_subprocess(&cli, &operator_root)
    };
    let (post_size, post_count) = if cli.dry_run {
        (system.hip_cache_size_bytes, system.hip_cache_file_count)
    } else {
        dir_size_and_count(&hip_cache)
    };
    let mut system = system;
    system.hip_cache_size_bytes = post_size;
    system.hip_cache_file_count = post_count;
    let samples = parse_samples(&stderr);
    let aggregate = compute_aggregate(&samples, passed);
    // Keep at most the last 4096 bytes of stderr. The cut is moved forward to
    // a character boundary because slicing a `str` inside a multi-byte
    // character panics.
    let mut tail_start = stderr.len().saturating_sub(4096);
    while !stderr.is_char_boundary(tail_start) {
        tail_start += 1;
    }
    let raw_stderr_tail = stderr[tail_start..].to_string();
    let record = RunRecord {
        header: ArtifactHeader {
            schema_version: "mlp_10k_artifact/v1".to_string(),
            artifact_kind: "mlp_10k_loss_decreases_50pct".to_string(),
            generated_at_unix: now,
            generated_at_iso: iso8601_now(),
            artifact_path: artifact_path.display().to_string(),
        },
        system,
        command,
        samples,
        aggregate,
        raw_stderr_tail,
    };
    if let Err(e) = write_jsonl(&record, &artifact_path) {
        eprintln!(
            "collect_paper_artifacts: failed to write artifact {}: {e}",
            artifact_path.display()
        );
        return std::process::ExitCode::from(1);
    }
    let mut stdout = std::io::stdout();
    let _ = writeln!(stdout, "{}", artifact_path.display());
    // A dry run executes no test, so a written artifact is a success rather
    // than a "test did not pass" failure.
    if passed || cli.dry_run {
        std::process::ExitCode::SUCCESS
    } else {
        eprintln!(
            "collect_paper_artifacts: test did not pass (exit_code={:?}); artifact written to {}",
            record.command.exit_code,
            artifact_path.display()
        );
        std::process::ExitCode::from(1)
    }
}