#![forbid(unsafe_code)]
use anyhow::Result;
use clap::Parser;
use dsfb_database::grammar::{MotifClass, MotifEngine, MotifGrammar, MotifParams};
use dsfb_database::metrics::{evaluate, PerMotifMetrics};
use dsfb_database::non_claims;
use dsfb_database::perturbation::tpcds_with_perturbations;
use dsfb_database::report::plots;
use dsfb_database::residual::ResidualStream;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
// Command-line options for the `pr_sweep` binary. A plain `//` comment is used
// here (not `///`) so clap keeps the explicit `about` text unchanged.
#[derive(Parser)]
#[command(
    name = "pr_sweep",
    about = "Phase-A2: precision/recall/F1 sweep across the motif envelope grid.",
    version
)]
struct Cli {
    /// Seed for the TPC-DS perturbation scenario; also recorded in each CSV's `seed` column
    #[arg(long, default_value_t = 42)]
    seed: u64,
    /// Output directory for the per-motif `pr.<motif>.csv` and `pr.<motif>.png` files
    #[arg(long, default_value = "out")]
    out: PathBuf,
}
/// Multipliers applied to a motif's default drift and slew thresholds.
///
/// Each motif is swept over the full `FACTORS` x `FACTORS` grid. The `1.0`
/// entry is the unscaled point that is passed to the PR plot as the baseline.
const FACTORS: &[f64] = &[0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0];
/// Counts, for every motif class, how many samples `stream` holds in that
/// motif's residual class.
fn samples_per_motif(stream: &ResidualStream) -> HashMap<MotifClass, usize> {
    MotifClass::ALL
        .iter()
        .map(|&motif| {
            let count = stream.iter_class(motif.residual_class()).count();
            (motif, count)
        })
        .collect()
}
/// Returns a copy of `baseline` in which only the `target` motif's params are
/// replaced. The new params are `MotifParams::default_for(target)` with the
/// drift threshold multiplied by `drift_factor` and the slew threshold
/// multiplied by `slew_factor`. All other motifs keep `baseline`'s params.
// NOTE(review): scaling starts from `default_for(target)`, not from the params
// already stored in `baseline`. The two agree only if `baseline` holds the
// defaults for `target` (main passes `MotifGrammar::default()`); confirm if that
// assumption ever changes.
fn grammar_with_override(
    target: MotifClass,
    drift_factor: f64,
    slew_factor: f64,
    baseline: &MotifGrammar,
) -> MotifGrammar {
    let mut grammar = baseline.clone();

    let mut params = MotifParams::default_for(target);
    params.drift_threshold *= drift_factor;
    params.slew_threshold *= slew_factor;

    // Pick the single field that belongs to `target`, then overwrite it.
    let slot = match target {
        MotifClass::PlanRegressionOnset => &mut grammar.plan_regression_onset,
        MotifClass::CardinalityMismatchRegime => &mut grammar.cardinality_mismatch_regime,
        MotifClass::ContentionRamp => &mut grammar.contention_ramp,
        MotifClass::CacheCollapse => &mut grammar.cache_collapse,
        MotifClass::WorkloadPhaseTransition => &mut grammar.workload_phase_transition,
    };
    *slot = params;
    grammar
}
/// Evaluates one point of the threshold grid for `target`.
///
/// The function:
/// 1. Re-runs the motif engine over `stream` with `target`'s drift and slew
///    thresholds scaled (see `grammar_with_override`).
/// 2. Scores the detected episodes against the perturbation `windows`.
/// 3. Returns the metrics row whose `motif` equals `target.name()`.
///
/// Returns `None` if `evaluate` produced no row for `target`.
fn run_grid_point(
    target: MotifClass,
    drift_factor: f64,
    slew_factor: f64,
    stream: &ResidualStream,
    baseline: &MotifGrammar,
    windows: &[dsfb_database::perturbation::PerturbationWindow],
    samples: &HashMap<MotifClass, usize>,
) -> Option<PerMotifMetrics> {
    let grammar = grammar_with_override(target, drift_factor, slew_factor, baseline);
    let episodes = MotifEngine::new(grammar).run(stream);
    // `evaluate` only borrows the episodes; cloning them first (as before) copied
    // the whole collection for nothing.
    evaluate(&episodes, windows, samples, stream.duration())
        .into_iter()
        .find(|row| row.motif == target.name())
}
/// Writes one CSV row per grid point of a single motif's sweep to `path`.
///
/// Each row contains:
/// - the motif name and `seed`;
/// - both factors;
/// - the absolute drift/slew thresholds, recomputed as
///   `MotifParams::default_for(motif)` scaled by the row's factors;
/// - the tp/fp/fn counts and precision/recall/F1 from the metrics.
///
/// The parent directory of `path` is created first if it does not exist.
///
/// # Errors
/// Propagates failures from directory creation, opening or writing the CSV,
/// and the final flush.
fn write_pr_csv(
    path: &Path,
    rows: &[(f64, f64, PerMotifMetrics)],
    seed: u64,
    motif: MotifClass,
) -> Result<()> {
    const HEADER: [&str; 12] = [
        "motif",
        "seed",
        "drift_factor",
        "slew_factor",
        "drift_threshold",
        "slew_threshold",
        "tp",
        "fp",
        "fn",
        "precision",
        "recall",
        "f1",
    ];

    if let Some(dir) = path.parent() {
        fs::create_dir_all(dir)?;
    }
    let mut writer = csv::Writer::from_path(path)?;
    writer.write_record(HEADER)?;

    let defaults = MotifParams::default_for(motif);
    let seed_field = seed.to_string();
    for (drift_factor, slew_factor, metrics) in rows {
        let record = [
            metrics.motif.clone(),
            seed_field.clone(),
            format!("{:.3}", drift_factor),
            format!("{:.3}", slew_factor),
            format!("{:.6}", defaults.drift_threshold * drift_factor),
            format!("{:.6}", defaults.slew_threshold * slew_factor),
            metrics.tp.to_string(),
            metrics.fp.to_string(),
            metrics.fn_.to_string(),
            format!("{:.6}", metrics.precision),
            format!("{:.6}", metrics.recall),
            format!("{:.6}", metrics.f1),
        ];
        writer.write_record(&record)?;
    }

    writer.flush()?;
    Ok(())
}
fn main() -> Result<()> {
let cli = Cli::parse();
non_claims::print();
let (stream, windows) = tpcds_with_perturbations(cli.seed);
let samples = samples_per_motif(&stream);
let baseline_grammar = MotifGrammar::default();
fs::create_dir_all(&cli.out)?;
for target in MotifClass::ALL {
let mut grid_rows: Vec<(f64, f64, PerMotifMetrics)> =
Vec::with_capacity(FACTORS.len() * FACTORS.len());
for &df in FACTORS {
for &sf in FACTORS {
if let Some(m) = run_grid_point(
target,
df,
sf,
&stream,
&baseline_grammar,
&windows,
&samples,
) {
grid_rows.push((df, sf, m));
}
}
}
debug_assert_eq!(
grid_rows.len(),
FACTORS.len() * FACTORS.len(),
"one grid point per (drift_factor, slew_factor)"
);
let csv_path = cli.out.join(format!("pr.{}.csv", target.name()));
write_pr_csv(&csv_path, &grid_rows, cli.seed, target)?;
let baseline_point = grid_rows
.iter()
.find(|(df, sf, _)| (*df - 1.0).abs() < 1e-9 && (*sf - 1.0).abs() < 1e-9)
.map(|(_, _, m)| (m.precision, m.recall));
let plot_rows: Vec<(f64, f64, f64, String)> = grid_rows
.iter()
.map(|(df, sf, m)| {
(
m.precision,
m.recall,
m.f1,
format!("drift*{df:.2}, slew*{sf:.2}"),
)
})
.collect();
let png_path = cli.out.join(format!("pr.{}.png", target.name()));
plots::plot_pr_curve(
&png_path,
&format!("PR sweep: {}", target.name()),
&plot_rows,
baseline_point,
)?;
eprintln!(
"pr_sweep[{}]: {} points, wrote {} + {}",
target.name(),
grid_rows.len(),
csv_path.display(),
png_path.display()
);
}
Ok(())
}