use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use chrono::{DateTime, Local, TimeZone};
use clap::Parser;
use rand::SeedableRng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use ripmap::training::{
Agent,
CURATED_REPOS,
CaseMetrics,
EvalMetrics,
ParameterGrid,
ParameterPoint,
RankingFailure,
RepoSpec,
Scratchpad,
SearchStrategy,
SensitivityAnalysis,
WeightedCase,
apply_changes,
bayesian_next_sample,
compute_coupling_weights,
distill_scratchpad,
extract_cases,
full_analysis,
print_scratchpad_summary,
print_summary,
quick_repos,
reason_about_failures,
sample_points,
update_scratchpad,
weight_cases,
};
// Command-line arguments for the ripmap-bench binary.
//
// NOTE: field comments deliberately use `//` rather than `///` — clap turns
// doc comments into --help text, so `///` would change CLI behavior.
#[derive(Parser, Debug)]
#[command(name = "ripmap-bench")]
#[command(about = "Hyperparameter optimization for ripmap ranking")]
struct Args {
    // Benchmark a single local repository instead of a named corpus.
    #[arg(long)]
    repo: Option<PathBuf>,
    // Named corpus to benchmark: "quick" or "curated".
    #[arg(long)]
    corpus: Option<String>,
    // Parameter search strategy: grid | lhs | random | bayesian.
    #[arg(long, default_value = "lhs")]
    strategy: String,
    // Number of parameter points to evaluate.
    #[arg(long, default_value = "100")]
    budget: usize,
    // RNG seed for reproducible sampling.
    #[arg(long, default_value = "42")]
    seed: u64,
    // Where to write the benchmark results JSON.
    #[arg(long, default_value = "training/runs/default/results.json")]
    output: PathBuf,
    // Run sensitivity analysis around the best configuration.
    #[arg(long)]
    sensitivity: bool,
    // Optional starting ParameterPoint JSON (reasoning mode).
    #[arg(long)]
    config: Option<PathBuf>,
    // Extract weighted cases to --output and exit without evaluating.
    #[arg(long)]
    extract_only: bool,
    // Max commits to scan per repository during case extraction.
    #[arg(long, default_value = "500")]
    max_commits: usize,
    // Minimum files touched by a commit for it to become a case.
    #[arg(long, default_value = "2")]
    min_files: usize,
    // Maximum files touched by a commit for it to become a case.
    #[arg(long, default_value = "12")]
    max_files: usize,
    // Directory where corpus repositories are cloned.
    #[arg(long, default_value = "./training/corpus")]
    clone_dir: PathBuf,
    #[arg(short, long)]
    verbose: bool,
    // Optional semantic-distractor JSON file (see load_distractors).
    #[arg(long)]
    distractors: Option<PathBuf>,
    // Prompt template for reasoning mode (required with --reason).
    #[arg(long)]
    prompt: Option<PathBuf>,
    // Run agent-driven reasoning training instead of parameter search.
    #[arg(long)]
    reason: bool,
    // Number of reasoning episodes.
    #[arg(long, default_value = "10")]
    episodes: usize,
    // Cases scoring below this NDCG@10 are reported as failures.
    #[arg(long, default_value = "0.5")]
    failure_threshold: f64,
    // Scratchpad JSON path (episode history for reasoning mode).
    #[arg(long, default_value = "training/runs/default/scratchpad.json")]
    scratchpad: PathBuf,
    // Distill the scratchpad into "wisdom" via the agent, then exit.
    #[arg(long)]
    distill: bool,
    // Optional output path for the training-progress chart.
    #[arg(long)]
    plot: Option<PathBuf>,
    // Agent backend used for reasoning/distillation.
    #[arg(long, default_value = "claude")]
    agent: String,
    // Optional model override passed to the agent.
    #[arg(long, short = 'm')]
    model: Option<String>,
    // Save scratchpad/checkpoint every N episodes.
    #[arg(long, default_value = "1")]
    save_interval: usize,
    // Named run: redirects output/scratchpad/plot under training/runs/<name>.
    #[arg(long)]
    run_name: Option<String>,
    // Show a detailed visualization of one run, then exit.
    #[arg(long)]
    show: Option<String>,
    // Pivot view of structural insights for a run, then exit.
    #[arg(long)]
    show_insights: Option<String>,
    // Pivot view of parameter interactions for a run, then exit.
    #[arg(long)]
    show_interactions: Option<String>,
    // List all training runs, then exit.
    #[arg(long)]
    list: bool,
}
// One entry from the semantic-distractor file: a seed file plus distractor
// paths that look related but are not.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SemanticDistractorCase {
    // Lookup key in DistractorLookup.
    seed_file: String,
    // Ground-truth related files; optional in the JSON, not used by this binary.
    #[serde(default)]
    expected_related: Vec<String>,
    // Originating commit message; optional in the JSON, not used by this binary.
    #[serde(default)]
    commit_message: String,
    // Distractor paths injected into simulated rankings for this seed file.
    semantic_distractors: Vec<String>,
}
// Per-repository grouping of distractor cases, matching the top-level array
// layout of the distractors JSON file.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SemanticDistractorRepo {
    // Repository identifier (informational).
    repo: String,
    // Declared case count (informational; `cases.len()` is authoritative).
    n_cases: usize,
    cases: Vec<SemanticDistractorCase>,
}
// Maps a seed file path to its curated semantic distractor paths.
type DistractorLookup = std::collections::HashMap<String, Vec<String>>;
/// Load the semantic-distractor JSON file and flatten it into a
/// seed-file -> distractor-paths lookup.
///
/// If the same seed file appears under multiple repos, the last entry
/// wins (plain map overwrite), matching the original insert behavior.
fn load_distractors(path: &Path) -> anyhow::Result<DistractorLookup> {
    let repos: Vec<SemanticDistractorRepo> = serde_json::from_reader(File::open(path)?)?;
    let mut lookup = DistractorLookup::new();
    lookup.extend(
        repos
            .into_iter()
            .flat_map(|repo| repo.cases)
            .map(|case| (case.seed_file, case.semantic_distractors)),
    );
    println!("Loaded {} semantic distractor entries", lookup.len());
    Ok(lookup)
}
// Everything persisted to results.json after a parameter search.
#[derive(Debug, Serialize, Deserialize)]
struct BenchmarkResults {
    // Every (point, metrics) pair that was evaluated.
    evaluations: Vec<(ParameterPoint, EvalMetrics)>,
    // The point with the highest NDCG@10.
    best_config: ParameterPoint,
    // NDCG@10 achieved by best_config.
    best_score: f64,
    // Present only when --sensitivity was requested.
    sensitivity: Option<SensitivityAnalysis>,
    n_cases: usize,
    n_repos: usize,
    total_time_secs: f64,
    // The --strategy string as given on the command line.
    strategy: String,
}
// Summary row for one training run in the --list table.
struct RunInfo {
    name: String,
    // Number of recorded episodes.
    episodes: usize,
    // NDCG@10 entering the first episode.
    first_ndcg: f64,
    // NDCG@10 entering the last episode.
    last_ndcg: f64,
    // last_ndcg - first_ndcg.
    delta: f64,
    // Unix seconds; 0 means "unknown".
    start_ts: i64,
    end_ts: i64,
}
fn format_timestamp(ts: i64) -> String {
if ts == 0 {
return "—".to_string();
}
Local
.timestamp_opt(ts, 0)
.single()
.map(|dt: DateTime<Local>| dt.format("%b %d %H:%M").to_string())
.unwrap_or_else(|| "—".to_string())
}
/// Print a one-row-per-run summary table of every training run found under
/// `training/runs/`, sorted by start time. Runs with a readable, non-empty
/// scratchpad get a full row (episodes, NDCG trajectory, delta, time range);
/// directories without usable data are listed afterwards as "(no data)".
fn list_training_runs() -> anyhow::Result<()> {
    let runs_dir = PathBuf::from("training/runs");
    if !runs_dir.exists() {
        println!("No training runs found. Run with --run-name to create one.");
        return Ok(());
    }
    let mut runs: Vec<RunInfo> = Vec::new();
    // Directories without parseable episodes: (name, file mtime).
    let mut empty_runs: Vec<(String, i64)> = Vec::new();
    for entry in std::fs::read_dir(&runs_dir)?.filter_map(|e| e.ok()) {
        if !entry.path().is_dir() {
            continue;
        }
        let name = entry.file_name().to_string_lossy().to_string();
        let scratchpad_path = entry.path().join("scratchpad.json");
        // Scratchpad mtime (Unix secs) as fallback timestamp when an episode
        // carries no timestamp of its own; 0 means "unknown".
        let file_mtime = scratchpad_path
            .metadata()
            .and_then(|m| m.modified())
            .ok()
            .and_then(|t| t.duration_since(std::time::UNIX_EPOCH).ok())
            .map(|d| d.as_secs() as i64)
            .unwrap_or(0);
        if scratchpad_path.exists() {
            if let Ok(file) = File::open(&scratchpad_path) {
                if let Ok(sp) = serde_json::from_reader::<_, Scratchpad>(file) {
                    let eps = sp.episodes.len();
                    if eps > 0 {
                        // Trajectory is measured on ndcg_before at both ends,
                        // i.e. the score entering the first vs. last episode.
                        let first_ndcg = sp.episodes.first().map(|e| e.ndcg_before).unwrap_or(0.0);
                        let last_ndcg = sp.episodes.last().map(|e| e.ndcg_before).unwrap_or(0.0);
                        let delta = last_ndcg - first_ndcg;
                        // Prefer the episode's own timestamp; fall back to the
                        // file mtime when the episode predates timestamping (0).
                        let start_ts = sp
                            .episodes
                            .first()
                            .map(|e| {
                                if e.timestamp > 0 {
                                    e.timestamp
                                } else {
                                    file_mtime
                                }
                            })
                            .unwrap_or(file_mtime);
                        let end_ts = sp
                            .episodes
                            .last()
                            .map(|e| {
                                if e.timestamp > 0 {
                                    e.timestamp
                                } else {
                                    file_mtime
                                }
                            })
                            .unwrap_or(file_mtime);
                        runs.push(RunInfo {
                            name,
                            episodes: eps,
                            first_ndcg,
                            last_ndcg,
                            delta,
                            start_ts,
                            end_ts,
                        });
                        continue;
                    }
                }
            }
        }
        // Anything that fell through (missing, unreadable, or empty
        // scratchpad) is listed without metrics.
        empty_runs.push((name, file_mtime));
    }
    runs.sort_by_key(|r| r.start_ts);
    empty_runs.sort_by_key(|(_, ts)| *ts);
    use owo_colors::OwoColorize;
    println!();
    println!("{}", "TRAINING RUNS".bold());
    println!("{}", "─".repeat(90));
    println!(
        "{:26} {:>4} {:^17} {:>8} {:>24}",
        "NAME", "EPS", "NDCG TRAJECTORY", "DELTA", "STARTED -> LAST"
    );
    println!("{}", "─".repeat(90));
    for run in &runs {
        // Color the delta by direction; |delta| <= 0.01 counts as flat.
        let (trend, delta_colored) = if run.delta > 0.01 {
            ("+", format!("{:>+8.4}", run.delta).green().to_string())
        } else if run.delta < -0.01 {
            ("-", format!("{:>+8.4}", run.delta).red().to_string())
        } else {
            ("=", format!("{:>+8.4}", run.delta).dimmed().to_string())
        };
        let start_str = format_timestamp(run.start_ts);
        let end_str = format_timestamp(run.end_ts);
        let time_range = format!("{} -> {}", start_str, end_str);
        println!(
            "{:26} {:>4} {:.3} {} {:.3} {} {}",
            run.name,
            run.episodes,
            run.first_ndcg,
            trend,
            run.last_ndcg,
            delta_colored,
            time_range.dimmed()
        );
    }
    for (name, mtime) in &empty_runs {
        let ts_str = format_timestamp(*mtime);
        println!(
            "{:26} - (no data) {:>24}",
            name, ts_str
        );
    }
    println!("{}", "─".repeat(90));
    println!("\nUse --show <run-name> to see detailed visualization.\n");
    Ok(())
}
fn show_training_run(path: &str) -> anyhow::Result<()> {
use comfy_table::{Cell, ContentArrangement, Table, presets::UTF8_FULL_CONDENSED};
use owo_colors::OwoColorize;
let scratchpad_path = if path.ends_with(".json") {
PathBuf::from(path)
} else {
PathBuf::from(format!("training/runs/{}/scratchpad.json", path))
};
if !scratchpad_path.exists() {
anyhow::bail!("Scratchpad not found: {}", scratchpad_path.display());
}
let file = File::open(&scratchpad_path)?;
let scratchpad: Scratchpad = serde_json::from_reader(file)?;
if scratchpad.episodes.is_empty() {
println!("No episodes found in scratchpad.");
return Ok(());
}
let run_name = scratchpad_path
.parent()
.and_then(|p| p.file_name())
.map(|s| s.to_string_lossy().to_string())
.unwrap_or_else(|| "unknown".to_string());
println!();
println!(
"{}",
format!(" TRAINING RUN: {} ", run_name).bold().on_blue()
);
println!(" Episodes: {}", scratchpad.episodes.len());
println!();
let ndcgs: Vec<f64> = scratchpad.episodes.iter().map(|e| e.ndcg_before).collect();
let min_ndcg = ndcgs.iter().cloned().fold(f64::INFINITY, f64::min);
let max_ndcg = ndcgs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let range = (max_ndcg - min_ndcg).max(0.001);
let spark_chars = ['▁', '▂', '▃', '▄', '▅', '▆', '▇', '█'];
let sparkline: String = ndcgs
.iter()
.map(|&n| {
let normalized = ((n - min_ndcg) / range * 7.0).round() as usize;
spark_chars[normalized.min(7)]
})
.collect();
let delta = ndcgs.last().unwrap_or(&0.0) - ndcgs.first().unwrap_or(&0.0);
let delta_str = if delta > 0.0 {
format!("{:+.4}", delta).green().to_string()
} else if delta < 0.0 {
format!("{:+.4}", delta).red().to_string()
} else {
format!("{:+.4}", delta).dimmed().to_string()
};
println!("{}", "NDCG TRAJECTORY".bold());
println!(
" {:.3} {} {:.3} Δ = {}",
min_ndcg,
sparkline.cyan(),
max_ndcg,
delta_str
);
println!();
println!("{}", "EPISODES".bold());
let mut table = Table::new();
table.load_preset(UTF8_FULL_CONDENSED);
table.set_content_arrangement(ContentArrangement::Dynamic);
table.set_header(vec![
"#",
"Trend",
"NDCG",
"Fail",
"Conf",
"Strategy / Changes / Insight",
]);
for (i, ep) in scratchpad.episodes.iter().enumerate() {
let ep_num = i + 1;
let trend = if i == 0 {
"·".to_string()
} else {
let prev = scratchpad.episodes[i - 1].ndcg_before;
if ep.ndcg_before > prev + 0.005 {
"↗".green().to_string()
} else if ep.ndcg_before < prev - 0.005 {
"↘".red().to_string()
} else {
"→".dimmed().to_string()
}
};
let mut desc_parts: Vec<String> = Vec::new();
if !ep.strategy_capsule.is_empty() {
let capsule = if ep.strategy_capsule.len() > 60 {
format!("⟨{}...⟩", &ep.strategy_capsule[..57])
} else {
format!("⟨{}⟩", &ep.strategy_capsule)
};
desc_parts.push(capsule);
}
if !ep.proposed_changes.is_empty() {
let changes: Vec<String> = ep
.proposed_changes
.iter()
.map(|(k, (dir, mag, _))| {
let arrow = if dir == "increase" { "↑" } else { "↓" };
let mag_short = mag
.chars()
.next()
.map(|c| c.to_uppercase().to_string())
.unwrap_or_default();
format!("{}{}{}", k, arrow, mag_short)
})
.collect();
desc_parts.push(changes.join(", "));
}
if !ep.structural_insights.is_empty() {
let insight = &ep.structural_insights[0];
let truncated = if insight.len() > 55 {
format!("💡 {}...", &insight[..52])
} else {
format!("💡 {}", insight)
};
desc_parts.push(truncated);
}
table.add_row(vec![
Cell::new(format!("E{:02}", ep_num)),
Cell::new(&trend),
Cell::new(format!("{:.4}", ep.ndcg_before)),
Cell::new(format!("{}", ep.failures.len())),
Cell::new(format!("{:.2}", ep.confidence)),
Cell::new(desc_parts.join("\n")),
]);
}
println!("{table}");
println!();
if let Some(last_ep) = scratchpad.episodes.last() {
println!("{}", "FINAL PARAMETERS".bold());
let mut ptable = Table::new();
ptable.load_preset(UTF8_FULL_CONDENSED);
ptable.set_content_arrangement(ContentArrangement::Dynamic);
ptable.set_header(vec!["Category", "Parameters"]);
let p = &last_ep.params;
ptable.add_row(vec![
"PageRank",
&format!(
"α={:.3} chat_mult={:.1}",
p.pagerank_alpha, p.pagerank_chat_multiplier
),
]);
ptable.add_row(vec![
"Depth",
&format!(
"root={:.2} mod={:.2} deep={:.2} vendor={:.4}",
p.depth_weight_root,
p.depth_weight_moderate,
p.depth_weight_deep,
p.depth_weight_vendor
),
]);
ptable.add_row(vec![
"Boosts",
&format!(
"ident={:.1} file={:.1} chat={:.1} temp={:.2} focus={:.2}",
p.boost_mentioned_ident,
p.boost_mentioned_file,
p.boost_chat_file,
p.boost_temporal_coupling,
p.boost_focus_expansion
),
]);
ptable.add_row(vec![
"Git",
&format!(
"decay={:.0}d recency_max={:.1} churn_th={:.0} churn_max={:.1}",
p.git_recency_decay_days,
p.git_recency_max_boost,
p.git_churn_threshold,
p.git_churn_max_boost
),
]);
ptable.add_row(vec![
"Focus",
&format!(
"decay={:.2} max_hops={:.0}",
p.focus_decay, p.focus_max_hops
),
]);
println!("{ptable}");
println!();
}
Ok(())
}
/// Resolve a CLI argument to a scratchpad file location: a value ending in
/// ".json" is taken as a literal path, anything else is treated as a run
/// name under `training/runs/<name>/scratchpad.json`.
fn resolve_scratchpad_path(path: &str) -> PathBuf {
    match path.ends_with(".json") {
        true => PathBuf::from(path),
        false => PathBuf::from(format!("training/runs/{}/scratchpad.json", path)),
    }
}
/// Print a pivot view of structural insights across all episodes of a run:
/// each unique insight with the 1-based episode numbers it appeared in,
/// most frequent first (top 30 shown).
fn show_insights_pivot(path: &str) -> anyhow::Result<()> {
    use std::collections::HashMap;
    let scratchpad_path = resolve_scratchpad_path(path);
    if !scratchpad_path.exists() {
        anyhow::bail!("Scratchpad not found: {}", scratchpad_path.display());
    }
    let file = File::open(&scratchpad_path)?;
    let scratchpad: Scratchpad = serde_json::from_reader(file)?;
    // insight text -> 1-based episode numbers it appeared in.
    let mut insight_episodes: HashMap<String, Vec<usize>> = HashMap::new();
    for (i, ep) in scratchpad.episodes.iter().enumerate() {
        for insight in &ep.structural_insights {
            insight_episodes
                .entry(insight.clone())
                .or_default()
                .push(i + 1);
        }
    }
    let mut sorted: Vec<_> = insight_episodes.iter().collect();
    // Sort by frequency, then alphabetically. BUG FIX: without the text
    // tiebreak, equally-frequent insights inherited HashMap iteration
    // order and were printed in a different order on every invocation.
    sorted.sort_by(|a, b| b.1.len().cmp(&a.1.len()).then_with(|| a.0.cmp(b.0)));
    println!("\nSTRUCTURAL INSIGHTS (pivot view)");
    println!("─────────────────────────────────────────────────────────────────────────────────");
    println!(
        "{} unique insights from {} episodes\n",
        sorted.len(),
        scratchpad.episodes.len()
    );
    for (insight, episodes) in sorted.iter().take(30) {
        // Short lists in full; longer ones as "a, b, c, ... (+n more)".
        let ep_list = if episodes.len() <= 5 {
            episodes
                .iter()
                .map(|e| e.to_string())
                .collect::<Vec<_>>()
                .join(", ")
        } else {
            format!(
                "{}, ... (+{} more)",
                episodes[..3]
                    .iter()
                    .map(|e| e.to_string())
                    .collect::<Vec<_>>()
                    .join(", "),
                episodes.len() - 3
            )
        };
        println!("[E{}] {}", ep_list, insight);
        println!();
    }
    if sorted.len() > 30 {
        println!("... and {} more insights", sorted.len() - 30);
    }
    Ok(())
}
/// Print a pivot view of parameter interactions across all episodes of a
/// run: each unique interaction with the 1-based episode numbers it
/// appeared in, most frequent first (top 30 shown).
fn show_interactions_pivot(path: &str) -> anyhow::Result<()> {
    use std::collections::HashMap;
    let scratchpad_path = resolve_scratchpad_path(path);
    if !scratchpad_path.exists() {
        anyhow::bail!("Scratchpad not found: {}", scratchpad_path.display());
    }
    let file = File::open(&scratchpad_path)?;
    let scratchpad: Scratchpad = serde_json::from_reader(file)?;
    // interaction text -> 1-based episode numbers it appeared in.
    let mut interaction_episodes: HashMap<String, Vec<usize>> = HashMap::new();
    for (i, ep) in scratchpad.episodes.iter().enumerate() {
        for interaction in &ep.param_interactions {
            interaction_episodes
                .entry(interaction.clone())
                .or_default()
                .push(i + 1);
        }
    }
    let mut sorted: Vec<_> = interaction_episodes.iter().collect();
    // Sort by frequency, then alphabetically. BUG FIX: without the text
    // tiebreak, equally-frequent interactions inherited HashMap iteration
    // order and were printed in a different order on every invocation.
    sorted.sort_by(|a, b| b.1.len().cmp(&a.1.len()).then_with(|| a.0.cmp(b.0)));
    println!("\nPARAMETER INTERACTIONS (pivot view)");
    println!("─────────────────────────────────────────────────────────────────────────────────");
    println!(
        "{} unique interactions from {} episodes\n",
        sorted.len(),
        scratchpad.episodes.len()
    );
    for (interaction, episodes) in sorted.iter().take(30) {
        // Short lists in full; longer ones as "a, b, c, ... (+n more)".
        let ep_list = if episodes.len() <= 5 {
            episodes
                .iter()
                .map(|e| e.to_string())
                .collect::<Vec<_>>()
                .join(", ")
        } else {
            format!(
                "{}, ... (+{} more)",
                episodes[..3]
                    .iter()
                    .map(|e| e.to_string())
                    .collect::<Vec<_>>()
                    .join(", "),
                episodes.len() - 3
            )
        };
        println!("[E{}] {}", ep_list, interaction);
        println!();
    }
    if sorted.len() > 30 {
        println!("... and {} more interactions", sorted.len() - 30);
    }
    Ok(())
}
/// Entry point: dispatches to the read-only views (--list/--show/...), then
/// to reasoning training or distillation, and otherwise runs the
/// hyperparameter search over the selected repositories.
fn main() -> anyhow::Result<()> {
    let mut args = Args::parse();
    // Read-only views return early and write nothing.
    if args.list {
        return list_training_runs();
    }
    if let Some(ref path) = args.show {
        return show_training_run(path);
    }
    if let Some(ref path) = args.show_insights {
        return show_insights_pivot(path);
    }
    if let Some(ref path) = args.show_interactions {
        return show_interactions_pivot(path);
    }
    // --run-name redirects all outputs under training/runs/<name>/.
    if let Some(ref name) = args.run_name {
        let run_dir = PathBuf::from(format!("training/runs/{}", name));
        std::fs::create_dir_all(&run_dir)?;
        std::fs::create_dir_all(run_dir.join("checkpoints"))?;
        args.output = run_dir.join("results.json");
        args.scratchpad = run_dir.join("scratchpad.json");
        if args.plot.is_none() {
            args.plot = Some(run_dir.join("progress.png"));
        }
        println!("Training run: {}", name);
        println!(" Output dir: {}", run_dir.display());
    }
    // Resolve the repositories to benchmark: a single --repo, or a named
    // corpus that is cloned on demand. Missing both is a hard error.
    let repos: Vec<PathBuf> = if let Some(ref repo) = args.repo {
        vec![repo.clone()]
    } else if let Some(ref corpus) = args.corpus {
        let specs = match corpus.as_str() {
            "quick" => quick_repos(),
            "curated" => CURATED_REPOS.iter().collect(),
            other => {
                eprintln!(
                    "Error: Unknown corpus '{}'. Use 'quick' or 'curated'",
                    other
                );
                std::process::exit(1);
            }
        };
        ensure_repos_cloned(&specs, &args.clone_dir)?
    } else {
        eprintln!("Error: Specify --repo or --corpus");
        std::process::exit(1);
    };
    println!("Benchmarking {} repositories", repos.len());
    let start = Instant::now();
    // Build the weighted training cases from each repo's commit history.
    let mut all_cases: Vec<WeightedCase> = Vec::new();
    for repo_path in &repos {
        println!("\nProcessing {}...", repo_path.display());
        let cases = extract_cases(repo_path, args.max_commits, args.min_files, args.max_files);
        println!(" Extracted {} raw cases", cases.len());
        let coupling = compute_coupling_weights(repo_path, args.max_commits);
        println!(" Computed {} coupling pairs", coupling.len());
        let weighted = weight_cases(cases, &coupling);
        println!(" Weighted cases: {}", weighted.len());
        all_cases.extend(weighted);
    }
    println!("\nTotal training cases: {}", all_cases.len());
    // Curated distractors are optional; a load failure degrades to the
    // synthetic fallback rather than aborting.
    let distractors: Option<Arc<DistractorLookup>> = if let Some(ref path) = args.distractors {
        match load_distractors(path) {
            Ok(lookup) => Some(Arc::new(lookup)),
            Err(e) => {
                eprintln!(
                    "Warning: Failed to load distractors: {}. Using synthetic.",
                    e
                );
                None
            }
        }
    } else {
        None
    };
    if args.extract_only {
        let file = File::create(&args.output)?;
        serde_json::to_writer_pretty(file, &all_cases)?;
        println!("Saved cases to {}", args.output.display());
        return Ok(());
    }
    if all_cases.is_empty() {
        eprintln!("Error: No training cases extracted. Check repository paths.");
        std::process::exit(1);
    }
    // Alternate modes that replace the parameter search entirely.
    if args.reason {
        return run_reasoning_training(&args, &all_cases, distractors.as_deref());
    }
    if args.distill {
        return run_distillation(&args);
    }
    // Unknown strategy strings warn and fall back to LHS rather than abort.
    let strategy = match args.strategy.as_str() {
        "grid" => SearchStrategy::Grid { points_per_dim: 3 },
        "lhs" => SearchStrategy::LatinHypercube,
        "random" => SearchStrategy::Random,
        "bayesian" => SearchStrategy::Bayesian,
        _ => {
            eprintln!("Unknown strategy: {}. Using LHS.", args.strategy);
            SearchStrategy::LatinHypercube
        }
    };
    let grid = ParameterGrid::default();
    // NOTE(review): for Bayesian, sample_points presumably returns only an
    // initial batch (< budget) that the sequential loop below tops up —
    // confirm against sample_points' contract.
    let mut points = sample_points(&grid, strategy, args.budget, args.seed);
    use indicatif::{ProgressBar, ProgressStyle};
    let cases = Arc::new(all_cases);
    let distractors_ref = distractors.clone();
    let pb = ProgressBar::new(points.len() as u64);
    pb.set_style(
        ProgressStyle::with_template(
            "{prefix:.bold} {bar:40.cyan/dim} {pos}/{len} [{elapsed}<{eta}] {msg}",
        )
        .unwrap(),
    );
    pb.set_prefix("Evaluating");
    // Best-so-far NDCG shared across rayon workers, stored as NDCG*10000 in
    // an atomic u64 (there is no atomic f64 in std).
    let best_ndcg = std::sync::atomic::AtomicU64::new(0);
    let evaluations: Vec<(ParameterPoint, EvalMetrics)> = points
        .par_iter()
        .map(|point| {
            let metrics = evaluate_point(point, &cases, distractors_ref.as_deref());
            let current = (metrics.ndcg_at_10 * 10000.0) as u64;
            best_ndcg.fetch_max(current, std::sync::atomic::Ordering::Relaxed);
            let best = best_ndcg.load(std::sync::atomic::Ordering::Relaxed) as f64 / 10000.0;
            pb.set_message(format!("best={:.4}", best));
            pb.inc(1);
            (point.clone(), metrics)
        })
        .collect();
    pb.finish_with_message(format!(
        "best={:.4}",
        best_ndcg.load(std::sync::atomic::Ordering::Relaxed) as f64 / 10000.0
    ));
    // Rebind mutably: the Bayesian phase appends sequential samples below.
    let mut evaluations = evaluations;
    if matches!(strategy, SearchStrategy::Bayesian) && evaluations.len() < args.budget {
        let remaining = args.budget - evaluations.len();
        let pb = ProgressBar::new(remaining as u64);
        pb.set_style(
            ProgressStyle::with_template(
                "{prefix:.bold} {bar:40.green/dim} {pos}/{len} [{elapsed}<{eta}] {msg}",
            )
            .unwrap(),
        );
        pb.set_prefix("Bayesian");
        let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
        // NOTE(review): `history` is built once and never extended inside the
        // loop, so bayesian_next_sample only conditions on the initial batch,
        // not on its own previous proposals — confirm this is intended.
        let history: Vec<_> = evaluations
            .iter()
            .map(|(p, m)| (p.clone(), m.ndcg_at_10))
            .collect();
        let mut best_so_far = evaluations
            .iter()
            .map(|(_, m)| m.ndcg_at_10)
            .fold(0.0_f64, f64::max);
        for _ in evaluations.len()..args.budget {
            let next = bayesian_next_sample(&grid, &history, &mut rng);
            let metrics = evaluate_point(&next, &cases, distractors.as_deref());
            if metrics.ndcg_at_10 > best_so_far {
                best_so_far = metrics.ndcg_at_10;
            }
            pb.set_message(format!("best={:.4}", best_so_far));
            pb.inc(1);
            evaluations.push((next, metrics));
        }
        pb.finish_with_message(format!("best={:.4}", best_so_far));
    }
    // Pick the winner by NDCG@10 (scores are finite, so partial_cmp is safe).
    let (best_config, best_metrics) = evaluations
        .iter()
        .max_by(|a, b| a.1.ndcg_at_10.partial_cmp(&b.1.ndcg_at_10).unwrap())
        .cloned()
        .expect("No evaluations");
    let elapsed = start.elapsed().as_secs_f64();
    println!("\n=== Results ===\n");
    println!("Best NDCG@10: {:.4}", best_metrics.ndcg_at_10);
    println!("Best NDCG@5: {:.4}", best_metrics.ndcg_at_5);
    println!("Best MRR: {:.4}", best_metrics.mrr);
    println!("Best P@10: {:.4}", best_metrics.precision_at_10);
    println!("\nTotal time: {:.1}s", elapsed);
    println!("\n=== Best Configuration ===\n");
    print_config(&best_config);
    let sensitivity = if args.sensitivity {
        println!("\n=== Running Sensitivity Analysis ===\n");
        let distractors_ref = distractors.as_deref();
        let evaluator = |p: &ParameterPoint| evaluate_point(p, &cases, distractors_ref).ndcg_at_10;
        let analysis = full_analysis(&best_config, evaluator);
        print_summary(&analysis);
        Some(analysis)
    } else {
        None
    };
    // Persist the full result set plus the best config as a standalone file.
    let results = BenchmarkResults {
        evaluations: evaluations.clone(),
        best_config: best_config.clone(),
        best_score: best_metrics.ndcg_at_10,
        sensitivity,
        n_cases: cases.len(),
        n_repos: repos.len(),
        total_time_secs: elapsed,
        strategy: args.strategy,
    };
    let file = File::create(&args.output)?;
    serde_json::to_writer_pretty(file, &results)?;
    println!("\nResults saved to {}", args.output.display());
    let config_path = args.output.with_extension("best.json");
    let config_file = File::create(&config_path)?;
    serde_json::to_writer_pretty(config_file, &best_config)?;
    println!("Best config saved to {}", config_path.display());
    Ok(())
}
/// Score every case under `point` and aggregate the per-case metrics,
/// weighting each case by its `case_weight`.
fn evaluate_point(
    point: &ParameterPoint,
    cases: &[WeightedCase],
    distractors: Option<&DistractorLookup>,
) -> EvalMetrics {
    let per_case: Vec<CaseMetrics> = cases
        .iter()
        .map(|case| {
            let ranking = simulate_ranking(point, case, distractors);
            // Ground truth is passed by reference; the previous code cloned
            // the entire expected_related Vec for every case evaluation.
            CaseMetrics::compute(&ranking, &case.expected_related, 0.1)
        })
        .collect();
    // Pair each metric with its case weight for weighted aggregation.
    let weighted: Vec<_> = per_case
        .iter()
        .zip(cases.iter())
        .map(|(m, c)| (m.clone(), c.case_weight))
        .collect();
    EvalMetrics::aggregate_weighted(&weighted)
}
/// Produce a simulated ranking for one case: score each ground-truth file
/// and each distractor (curated when available, synthetic otherwise) with
/// `score_file`, then return the paths sorted by descending score.
fn simulate_ranking(
    point: &ParameterPoint,
    case: &WeightedCase,
    distractors: Option<&DistractorLookup>,
) -> Vec<String> {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    // Score the ground-truth files first (coupling weight > 0). The old
    // code also threaded an `is_ground_truth` bool through the tuple that
    // was never read; it has been dropped.
    let mut scored: Vec<(String, f64)> = case
        .expected_related
        .iter()
        .map(|(file, coupling_weight)| {
            let score = score_file(point, file, *coupling_weight, &mut rng);
            (file.clone(), score)
        })
        .collect();
    // Curated distractors for this seed file; fall back to synthetic paths
    // when no lookup was provided or the seed file is not in it.
    let distractor_paths: Vec<String> = distractors
        .and_then(|lookup| lookup.get(&case.seed_file).cloned())
        .unwrap_or_else(|| generate_synthetic_distractors(case.expected_related.len()));
    for distractor in distractor_paths {
        let score = score_file(point, &distractor, 0.0, &mut rng);
        scored.push((distractor, score));
    }
    // Descending by score; scores are finite, so the Equal fallback is
    // only defensive.
    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    scored.into_iter().map(|(f, _)| f).collect()
}
/// Fabricate plausible-looking distractor file paths when no curated
/// distractors are available: three per ground-truth file, minimum ten,
/// cycling through five path templates of varying directory depth.
fn generate_synthetic_distractors(n_ground_truth: usize) -> Vec<String> {
    let count = (n_ground_truth * 3).max(10);
    let mut paths = Vec::with_capacity(count);
    for i in 0..count {
        let path = match i % 5 {
            0 => format!("src/distractor_{}.rs", i),
            1 => format!("src/utils/helper_{}.rs", i),
            2 => format!("src/core/internal/deep_{}.rs", i),
            3 => format!("lib/mod_{}.rs", i),
            _ => format!("tests/test_{}.rs", i),
        };
        paths.push(path);
    }
    paths
}
/// Simulate a ranking score for one file under the given parameters.
///
/// Files with `coupling_weight > 0` are treated as ground truth: they keep
/// their coupling weight as the base score and draw recency/churn boosts
/// with higher probability (70%/40% vs. 20%/10% for distractors, whose base
/// is a flat 0.05). All scores are scaled by the depth multiplier and a
/// final ±15% noise factor. The RNG draw sequence is identical to the
/// original: recency check, optional recency magnitude, churn check, noise.
fn score_file(
    point: &ParameterPoint,
    file: &str,
    coupling_weight: f64,
    rng: &mut impl rand::Rng,
) -> f64 {
    let is_ground_truth = coupling_weight > 0.0;
    // Depth multiplier from the number of path separators.
    let depth = file.matches('/').count();
    let depth_mult = match depth {
        0..=2 => point.depth_weight_root,
        3..=4 => point.depth_weight_moderate,
        _ => point.depth_weight_deep,
    };
    let mut score = if is_ground_truth {
        coupling_weight * depth_mult
    } else {
        0.05 * depth_mult
    };
    // Recency boost, drawn with class-dependent probability.
    let recency_prob = if is_ground_truth { 0.7 } else { 0.2 };
    if rng.r#gen::<f64>() < recency_prob {
        score *= 1.0 + (point.git_recency_max_boost - 1.0) * (1.0 - rng.r#gen::<f64>() * 0.3);
    }
    // Churn boost at half strength, also class-dependent.
    let churn_prob = if is_ground_truth { 0.4 } else { 0.1 };
    if rng.r#gen::<f64>() < churn_prob {
        score *= 1.0 + (point.git_churn_max_boost - 1.0) * 0.5;
    }
    // Multiplicative noise in [0.85, 1.15).
    score * (0.85 + 0.3 * rng.r#gen::<f64>())
}
/// Ensure every corpus repository exists under `base_dir`, shallow-cloning
/// any that are missing, and return their local paths.
///
/// A failed clone is reported and skipped (its path is not returned);
/// existing directories are trusted without validation.
fn ensure_repos_cloned(specs: &[&RepoSpec], base_dir: &Path) -> anyhow::Result<Vec<PathBuf>> {
    use indicatif::{ProgressBar, ProgressStyle};
    std::fs::create_dir_all(base_dir)?;
    let mut paths = Vec::new();
    for spec in specs {
        let repo_dir = base_dir.join(&spec.name);
        if !repo_dir.exists() {
            let spinner = ProgressBar::new_spinner();
            spinner.set_style(
                ProgressStyle::with_template("{spinner:.green} {msg}")
                    .unwrap()
                    .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
            );
            spinner.set_message(format!("Cloning {}...", spec.name));
            spinner.enable_steady_tick(std::time::Duration::from_millis(80));
            // Shallow clone, depth sized to the number of commits we mine.
            // BUG FIX: the destination is now passed as a Path via `.arg`,
            // so non-UTF-8 paths no longer panic (the previous code used
            // repo_dir.to_str().unwrap()).
            let status = std::process::Command::new("git")
                .args([
                    "clone",
                    "--depth",
                    &spec.estimated_commits().to_string(),
                    &spec.url,
                ])
                .arg(&repo_dir)
                .stdout(std::process::Stdio::null())
                .stderr(std::process::Stdio::null())
                .status()?;
            if !status.success() {
                spinner.finish_with_message(format!("✗ Failed to clone {}", spec.name));
                continue;
            }
            spinner.finish_with_message(format!("✓ Cloned {}", spec.name));
        }
        paths.push(repo_dir);
    }
    Ok(paths)
}
/// Pretty-print every tunable parameter of a configuration, grouped by
/// subsystem, in the same layout used throughout the CLI output. Output
/// is byte-identical to the original per-line println version.
fn print_config(point: &ParameterPoint) {
    let sections = [
        format!(
            "PageRank:\n alpha: {:.2}\n chat_multiplier: {:.1}",
            point.pagerank_alpha, point.pagerank_chat_multiplier
        ),
        format!(
            "Depth Weights:\n root: {:.2}\n moderate: {:.2}\n deep: {:.2}\n vendor: {:.4}",
            point.depth_weight_root,
            point.depth_weight_moderate,
            point.depth_weight_deep,
            point.depth_weight_vendor
        ),
        format!(
            "Boosts:\n mentioned_ident: {:.1}\n mentioned_file: {:.1}\n chat_file: {:.1}\n temporal: {:.1}\n focus_expand: {:.1}",
            point.boost_mentioned_ident,
            point.boost_mentioned_file,
            point.boost_chat_file,
            point.boost_temporal_coupling,
            point.boost_focus_expansion
        ),
        format!(
            "Git:\n recency_decay_days: {:.1}\n recency_max_boost: {:.1}\n churn_threshold: {:.0}\n churn_max_boost: {:.1}",
            point.git_recency_decay_days,
            point.git_recency_max_boost,
            point.git_churn_threshold,
            point.git_churn_max_boost
        ),
        format!(
            "Focus Expansion:\n decay: {:.2}\n max_hops: {:.0}",
            point.focus_decay, point.focus_max_hops
        ),
    ];
    // A blank line between sections, trailing newline from println.
    println!("{}", sections.join("\n\n"));
}
/// Agent-in-the-loop training: evaluate, collect failing cases, ask the
/// reasoning agent for parameter changes, apply them when confidence is
/// high enough, and checkpoint scratchpad/params along the way.
///
/// Requires --prompt. Loops for --episodes or until no case falls below
/// --failure-threshold.
///
/// NOTE: the source view of this function contained mojibake (`¤t_params`,
/// an HTML-entity garbling of `&current_params`); the original token is
/// restored here — no code change intended.
fn run_reasoning_training(
    args: &Args,
    cases: &[WeightedCase],
    distractors: Option<&DistractorLookup>,
) -> anyhow::Result<()> {
    use ripmap::training::LiveProgress;
    let prompt_path = args.prompt.as_ref()
        .ok_or_else(|| anyhow::anyhow!("--prompt is required for reasoning mode. Example: --prompt training-outer/prompts/inner/v001.md"))?;
    let prompt_template = std::fs::read_to_string(prompt_path).map_err(|e| {
        anyhow::anyhow!(
            "Failed to read prompt template '{}': {}",
            prompt_path.display(),
            e
        )
    })?;
    let agent: Agent = args.agent.parse().map_err(|e: String| anyhow::anyhow!(e))?;
    let model = args.model.as_deref();
    println!("\n=== REASONING-BASED TRAINING ===");
    println!("Prompt: {}", prompt_path.display());
    println!("Using {} as universal function approximator\n", agent);
    // Resume from an existing scratchpad when present; a corrupt file
    // silently falls back to a fresh one (unwrap_or_default).
    let mut scratchpad = if args.scratchpad.exists() {
        let file = File::open(&args.scratchpad)?;
        serde_json::from_reader(file).unwrap_or_default()
    } else {
        Scratchpad::default()
    };
    // Starting point: --config JSON if given, otherwise defaults.
    let mut current_params = if let Some(ref config_path) = args.config {
        let file = File::open(config_path)?;
        serde_json::from_reader(file)?
    } else {
        ParameterPoint::default()
    };
    println!("Starting parameters:");
    print_config(&current_params);
    println!("\nRunning {} reasoning episodes...\n", args.episodes);
    let mut progress = LiveProgress::new();
    for episode_num in 0..args.episodes {
        // Score current parameters and gather a failure sample for the agent.
        let metrics = evaluate_point(&current_params, cases, distractors);
        let failures =
            collect_failures(&current_params, cases, distractors, args.failure_threshold);
        if failures.is_empty() {
            println!(
                "No failures below threshold {:.2}. Training converged!",
                args.failure_threshold
            );
            break;
        }
        use indicatif::{ProgressBar, ProgressStyle};
        let spinner = ProgressBar::new_spinner();
        spinner.set_style(
            ProgressStyle::with_template("{prefix:.bold} {spinner:.cyan} {msg}")
                .unwrap()
                .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
        );
        spinner.set_prefix(format!("E{:02}", episode_num + 1));
        spinner.set_message(format!(
            "{} reasoning ({} failures)...",
            agent,
            failures.len()
        ));
        spinner.enable_steady_tick(std::time::Duration::from_millis(80));
        // Run directory (for agent transcripts) derived from the scratchpad
        // location; None when the path is non-UTF-8 or has no parent.
        let run_dir = args
            .scratchpad
            .parent()
            .and_then(|p| p.to_str())
            .map(|s| s.to_string());
        match reason_about_failures(
            &prompt_template,
            &failures,
            &current_params,
            &scratchpad,
            metrics.ndcg_at_10,
            agent,
            model,
            run_dir.as_deref(),
        ) {
            Ok(episode) => {
                spinner.finish_and_clear();
                progress.record(
                    metrics.ndcg_at_10,
                    failures.len(),
                    episode.confidence,
                    current_params.pagerank_alpha,
                );
                progress.display(episode_num + 1, args.episodes);
                println!();
                println!(
                    "Confidence: {:.2} ⏱ {:.1}s",
                    episode.confidence, episode.duration_secs
                );
                println!("Proposed {} changes:", episode.proposed_changes.len());
                for (param, (dir, mag, rationale)) in &episode.proposed_changes {
                    println!(" {} {} {} - \"{}\"", param, dir, mag, rationale);
                }
                if !episode.structural_insights.is_empty() {
                    println!("\nStructural insights:");
                    for insight in &episode.structural_insights {
                        println!(" • {}", insight);
                    }
                }
                // Only act on reasonably confident proposals; low-confidence
                // episodes are recorded but leave the parameters untouched.
                if episode.confidence >= 0.3 {
                    current_params = apply_changes(&current_params, &episode.proposed_changes);
                    println!("\nApplied changes. New params:");
                    print_config(&current_params);
                } else {
                    println!(
                        "\nConfidence too low ({:.2}), skipping changes",
                        episode.confidence
                    );
                }
                update_scratchpad(&mut scratchpad, &episode);
            }
            Err(e) => {
                // Agent failure is non-fatal: record the episode with zero
                // confidence and move on.
                spinner.finish_with_message("failed");
                eprintln!("Warning: Reasoning failed: {}", e);
                progress.record(
                    metrics.ndcg_at_10,
                    failures.len(),
                    0.0,
                    current_params.pagerank_alpha,
                );
                progress.display(episode_num + 1, args.episodes);
                println!();
            }
        }
        // Periodic persistence: scratchpad plus a parameter checkpoint (and,
        // with the plotters feature, a refreshed progress chart).
        if (episode_num + 1) % args.save_interval == 0 {
            std::fs::create_dir_all(args.scratchpad.parent().unwrap_or(Path::new(".")))?;
            let scratchpad_file = File::create(&args.scratchpad)?;
            serde_json::to_writer_pretty(scratchpad_file, &scratchpad)?;
            // Named runs get checkpoints/epNNN.json; ad-hoc runs get
            // <output>.epN.json next to the results file.
            let checkpoint_path = if args.run_name.is_some() {
                args.output
                    .parent()
                    .unwrap()
                    .join("checkpoints")
                    .join(format!("ep{:03}.json", episode_num + 1))
            } else {
                args.output
                    .with_extension(format!("ep{}.json", episode_num + 1))
            };
            let checkpoint_file = File::create(&checkpoint_path)?;
            serde_json::to_writer_pretty(checkpoint_file, &current_params)?;
            #[cfg(feature = "plotters")]
            if let Some(plot_path) = &args.plot {
                use ripmap::training::plots::plot_training_progress;
                let _ = plot_training_progress(
                    &scratchpad,
                    plot_path.to_str().unwrap_or("training.png"),
                );
            }
            println!(" [checkpoint saved at episode {}]", episode_num + 1);
        }
        println!();
    }
    // Final evaluation plus agent-timing summary.
    let final_metrics = evaluate_point(&current_params, cases, distractors);
    progress.final_summary();
    println!("Final NDCG@10: {:.4}", final_metrics.ndcg_at_10);
    println!("Final MRR: {:.4}", final_metrics.mrr);
    // Episodes recorded before duration tracking existed carry 0.0;
    // exclude them from the timing stats.
    let durations: Vec<f64> = scratchpad
        .episodes
        .iter()
        .map(|e| e.duration_secs)
        .filter(|&d| d > 0.0)
        .collect();
    if !durations.is_empty() {
        let total_agent_time: f64 = durations.iter().sum();
        let avg_time = total_agent_time / durations.len() as f64;
        let min_time = durations.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_time = durations.iter().cloned().fold(0.0, f64::max);
        println!(
            "\n⏱ Agent timing ({} episodes with timing data):",
            durations.len()
        );
        println!(
            " Total: {:.1}s ({:.1}m) Avg: {:.1}s Min: {:.1}s Max: {:.1}s",
            total_agent_time,
            total_agent_time / 60.0,
            avg_time,
            min_time,
            max_time
        );
    }
    // Final persistence: scratchpad, trained config, optional chart.
    std::fs::create_dir_all(args.scratchpad.parent().unwrap_or(Path::new(".")))?;
    let scratchpad_file = File::create(&args.scratchpad)?;
    serde_json::to_writer_pretty(scratchpad_file, &scratchpad)?;
    println!("\nScratchpad saved to {}", args.scratchpad.display());
    print_scratchpad_summary(&scratchpad);
    let config_path = args.output.with_extension("trained.json");
    let config_file = File::create(&config_path)?;
    serde_json::to_writer_pretty(config_file, &current_params)?;
    println!("\nTrained config saved to {}", config_path.display());
    if let Some(plot_path) = &args.plot {
        #[cfg(feature = "plotters")]
        {
            use ripmap::training::plots::plot_training_progress;
            match plot_training_progress(&scratchpad, plot_path.to_str().unwrap_or("training.png"))
            {
                Ok(()) => println!("Training chart saved to {}", plot_path.display()),
                Err(e) => eprintln!("Failed to generate chart: {}", e),
            }
        }
        #[cfg(not(feature = "plotters"))]
        {
            eprintln!("Chart generation requires: cargo build --features plotters");
            // Suppress the unused-variable warning in the no-plotters build.
            let _ = plot_path;
        }
    }
    Ok(())
}
/// Score up to the first 50 cases with the current parameters and return
/// at most 10 whose NDCG@10 falls below `threshold`, formatted for the
/// reasoning agent.
fn collect_failures(
    params: &ParameterPoint,
    cases: &[WeightedCase],
    distractors: Option<&DistractorLookup>,
    threshold: f64,
) -> Vec<RankingFailure> {
    let mut failures = Vec::new();
    // Cap the scan at 50 cases: this runs every episode and only needs a
    // representative failure sample, not an exhaustive one.
    for case in cases.iter().take(50) {
        let ranking = simulate_ranking(params, case, distractors);
        // Pass the weighted ground truth by reference; the previous code
        // cloned the whole Vec per case just to take a borrow of it.
        let metrics = CaseMetrics::compute(&ranking, &case.expected_related, 0.1);
        if metrics.ndcg_at_10 < threshold {
            let expected_top: Vec<_> = case
                .expected_related
                .iter()
                .take(5)
                .map(|(f, _)| f.clone())
                .collect();
            failures.push(RankingFailure {
                query: case.seed_file.clone(),
                seed_file: case.seed_file.clone(),
                expected_top,
                actual_top: ranking.iter().take(5).cloned().collect(),
                ndcg: metrics.ndcg_at_10,
                commit_context: format!("intent: {:?}", case.inferred_intent),
                repo_name: "curated".to_string(),
                // Placeholder: real repo size is not tracked at this level.
                repo_file_count: 100,
            });
            if failures.len() >= 10 {
                break;
            }
        }
    }
    failures
}
/// Load the scratchpad and ask the agent to distill its episode history
/// into condensed "wisdom", saved next to --output as *.wisdom.json.
fn run_distillation(args: &Args) -> anyhow::Result<()> {
    let agent: Agent = args.agent.parse().map_err(|e: String| anyhow::anyhow!(e))?;
    let model = args.model.as_deref();
    println!("\n=== DISTILLING SCRATCHPAD ===\n");
    if !args.scratchpad.exists() {
        eprintln!(
            "Error: Scratchpad not found at {}",
            args.scratchpad.display()
        );
        std::process::exit(1);
    }
    let scratchpad_file = File::open(&args.scratchpad)?;
    let scratchpad: Scratchpad = serde_json::from_reader(scratchpad_file)?;
    println!(
        "Loaded {} episodes from scratchpad",
        scratchpad.episodes.len()
    );
    print_scratchpad_summary(&scratchpad);
    // Run directory (for agent transcripts) derived from the scratchpad
    // location; None when the path is non-UTF-8 or has no parent.
    let run_dir = args
        .scratchpad
        .parent()
        .and_then(|p| p.to_str())
        .map(|s| s.to_string());
    println!("\nCalling {} for distillation...", agent);
    // Distillation failure is reported but not fatal (exit code stays 0).
    match distill_scratchpad(&scratchpad, agent, model, run_dir.as_deref()) {
        Ok(wisdom) => {
            println!("\n=== DISTILLED WISDOM ===\n");
            println!("{}", wisdom);
            let wisdom_path = args.output.with_extension("wisdom.json");
            // BUG FIX: the old code opened the file with File::create,
            // called sync_all() on the still-empty handle, and then wrote
            // through a second open via fs::write — the sync flushed
            // nothing and the first open was redundant. A single fs::write
            // creates (or truncates) and writes the file in one step.
            std::fs::write(&wisdom_path, &wisdom)?;
            println!("\nWisdom saved to {}", wisdom_path.display());
        }
        Err(e) => {
            eprintln!("Error: Distillation failed: {}", e);
        }
    }
    Ok(())
}