use crate::calibration_utils;
use calibration_utils::{
busy_wait_ns, compute_estimation_stats_by_effect, init_effect_injection, CalibrationConfig,
EstimationPoint, TimerBackend, TrialRunner,
};
use tacet::helpers::InputPair;
use tacet::{AttackerModel, Outcome, TimingOracle};
/// Injected timing-effect sizes, in nanoseconds, exercised by the
/// estimation-accuracy tests below.
const ESTIMATION_EFFECTS_NS: [u64; 5] = [50, 100, 200, 500, 1000];
#[test]
fn estimation_accuracy_iteration() {
    // Fast smoke run: only the 200ns and 500ns effects, adjacent-network model.
    let effects = &ESTIMATION_EFFECTS_NS[2..4];
    run_estimation_test(
        "estimation_accuracy_iteration",
        AttackerModel::AdjacentNetwork,
        effects,
    );
}
#[test]
fn estimation_accuracy_quick_adjacent_network() {
    // Not meaningful at the iteration tier; bail out early there.
    let tier = std::env::var("CALIBRATION_TIER");
    if tier.as_deref() == Ok("iteration") {
        eprintln!("[estimation_accuracy_quick_adjacent_network] Skipped: iteration tier");
        return;
    }
    // All effects except the largest (1000ns).
    run_estimation_test(
        "estimation_accuracy_quick_adjacent_network",
        AttackerModel::AdjacentNetwork,
        &ESTIMATION_EFFECTS_NS[..4],
    );
}
#[test]
#[ignore]
fn estimation_accuracy_validation_adjacent_network() {
    // Force the full validation tier regardless of the caller's environment.
    std::env::set_var("CALIBRATION_TIER", "validation");
    let name = "estimation_accuracy_validation_adjacent_network";
    run_estimation_test(name, AttackerModel::AdjacentNetwork, &ESTIMATION_EFFECTS_NS);
}
#[test]
#[ignore]
fn estimation_accuracy_validation_pmu() {
    // This variant is only meaningful with a cycle-accurate (PMU) timer backend.
    let pmu_present = TimerBackend::cycle_accurate_available();
    if !pmu_present {
        eprintln!("[estimation_accuracy_validation_pmu] Skipped: PMU timer not available");
        return;
    }
    // Force the full validation tier regardless of the caller's environment.
    std::env::set_var("CALIBRATION_TIER", "validation");
    run_estimation_test(
        "estimation_accuracy_validation_pmu",
        AttackerModel::AdjacentNetwork,
        &ESTIMATION_EFFECTS_NS,
    );
}
#[test]
#[ignore]
fn estimation_accuracy_validation_remote_network() {
    // Force the full validation tier regardless of the caller's environment.
    std::env::set_var("CALIBRATION_TIER", "validation");
    // Remote-network attackers need much larger effects, so this tier uses
    // its own effect ladder instead of ESTIMATION_EFFECTS_NS.
    let effects: [u64; 4] = [25_000, 50_000, 100_000, 250_000];
    run_estimation_test(
        "estimation_accuracy_validation_remote_network",
        AttackerModel::RemoteNetwork,
        &effects,
    );
}
/// Drives the estimation-accuracy calibration for one attacker model.
///
/// For each injected effect size in `effects` (nanoseconds), runs up to
/// `trials_per_effect` oracle trials against a busy-wait workload whose two
/// input classes differ by exactly that effect, collects the oracle's effect
/// estimates, then prints per-effect statistics (bias, RMSE, CI coverage).
///
/// Pass/fail: panics when the mean bias at any effect >= 200ns exceeds the
/// tier's `max_estimation_bias()`. Returns without asserting (skips) when
/// calibration is disabled via CALIBRATION_DISABLED=1 or when fewer than 20
/// data points were collected overall.
fn run_estimation_test(test_name: &str, attacker_model: AttackerModel, effects: &[u64]) {
    init_effect_injection();
    if CalibrationConfig::is_disabled() {
        eprintln!("[{}] Skipped: CALIBRATION_DISABLED=1", test_name);
        return;
    }
    let config = CalibrationConfig::from_env(test_name);
    let trials_per_effect = config.tier.estimation_trials_per_effect();
    eprintln!(
        "[{}] Starting estimation accuracy test (tier: {}, {} trials per effect)",
        test_name, config.tier, trials_per_effect
    );
    let mut all_points: Vec<EstimationPoint> = Vec::new();
    let mut any_failed = false;
    let model_name = format!("{:?}", attacker_model);
    for &effect_ns in effects {
        let sub_test_name = format!("{}_{}ns", test_name, effect_ns);
        let mut runner = TrialRunner::new(&sub_test_name, config.clone(), trials_per_effect)
            .with_export_info(effect_ns as f64, &model_name);
        eprintln!("\n[{}] Testing effect = {}ns", test_name, effect_ns);
        for trial in 0..trials_per_effect {
            // The runner may decide it already has enough (or hopeless) data.
            if runner.should_stop() {
                eprintln!("[{}] Early stop at trial {}", sub_test_name, trial);
                break;
            }
            // Baseline class waits 2000ns; the other class adds the injected effect.
            let inputs = InputPair::new(|| 0u64, || effect_ns);
            let outcome = TimingOracle::for_attacker(attacker_model)
                .max_samples(config.samples_per_trial)
                .time_budget(config.time_budget_per_trial)
                .test(inputs, move |&effect| {
                    busy_wait_ns(2000 + effect);
                });
            runner.record(&outcome);
            if let Some(point) = extract_estimation_point(&outcome, effect_ns as f64) {
                all_points.push(point);
            }
            // Progress line every 20 trials and on the final trial.
            if (trial + 1) % 20 == 0 || trial + 1 == trials_per_effect {
                eprintln!(
                    " Trial {}/{}: {} points collected",
                    trial + 1,
                    trials_per_effect,
                    all_points
                        .iter()
                        .filter(|p| (p.true_effect_ns - effect_ns as f64).abs() < 1.0)
                        .count()
                );
            }
        }
    }
    eprintln!("\n[{}] Computing estimation statistics...", test_name);
    let stats_by_effect = compute_estimation_stats_by_effect(&all_points);
    eprintln!("\n[{}] Estimation Accuracy Summary:", test_name);
    eprintln!(" True Effect | Mean Est. | Bias | Bias % | RMSE | Coverage | N");
    eprintln!(" ------------|-----------|----------|----------|----------|----------|----");
    // Hoisted out of the loop: the bias tolerance is a property of the tier.
    let max_bias = config.tier.max_estimation_bias();
    for stats in &stats_by_effect {
        // Bias as a percentage is undefined for a zero true effect.
        let bias_pct = if stats.true_effect_ns > 0.0 {
            format!("{:>7.1}%", stats.bias_fraction * 100.0)
        } else {
            " N/A".to_string()
        };
        let bias_marker = if stats.bias_fraction.abs() > max_bias {
            " !!!"
        } else {
            ""
        };
        eprintln!(
            " {:>10.0}ns | {:>7.1}ns | {:>7.1}ns | {} | {:>7.1}ns | {:>7.1}% | {:>3}{}",
            stats.true_effect_ns,
            stats.mean_estimate,
            stats.bias,
            bias_pct,
            stats.rmse,
            stats.coverage * 100.0,
            stats.count,
            bias_marker
        );
        // Only effects >= 200ns gate the pass/fail decision; smaller effects
        // are reported for information only.
        if stats.true_effect_ns >= 200.0 && stats.bias_fraction.abs() > max_bias {
            eprintln!(
                "[{}] FAILED: Bias {:.1}% at {}ns exceeds {:.0}%",
                test_name,
                stats.bias_fraction * 100.0,
                stats.true_effect_ns,
                max_bias * 100.0
            );
            any_failed = true;
        }
    }
    let total_points: usize = stats_by_effect.iter().map(|s| s.count).sum();
    // Collect the |bias| values once instead of walking the same filtered
    // iterator twice (once for the sum, once for the count).
    let abs_biases: Vec<f64> = stats_by_effect
        .iter()
        .filter(|s| s.true_effect_ns > 0.0)
        .map(|s| s.bias_fraction.abs())
        .collect();
    let avg_bias: f64 = abs_biases.iter().sum::<f64>() / abs_biases.len().max(1) as f64;
    let avg_coverage: f64 = stats_by_effect.iter().map(|s| s.coverage).sum::<f64>()
        / stats_by_effect.len().max(1) as f64;
    eprintln!("\n[{}] Overall:", test_name);
    eprintln!(" Total points: {}", total_points);
    eprintln!(" Average |bias|: {:.1}%", avg_bias * 100.0);
    eprintln!(" Average coverage: {:.1}%", avg_coverage * 100.0);
    // Too little data to judge bias reliably: skip rather than fail.
    if total_points < 20 {
        eprintln!(
            "[{}] SKIPPED: Insufficient data ({} points)",
            test_name, total_points
        );
        return;
    }
    if any_failed {
        panic!("[{}] FAILED: Estimation accuracy check failed", test_name);
    }
    eprintln!(
        "\n[{}] PASSED: Estimation accuracy within acceptable bounds",
        test_name
    );
}
/// Pulls the oracle's effect estimate out of an `Outcome`, pairing it with
/// the effect size that was actually injected for this trial.
///
/// Returns `None` for outcomes that carry no usable estimate
/// (`Unmeasurable` and `Research`).
fn extract_estimation_point(outcome: &Outcome, true_effect_ns: f64) -> Option<EstimationPoint> {
    let effect = match outcome {
        Outcome::Pass { effect, .. }
        | Outcome::Fail { effect, .. }
        | Outcome::Inconclusive { effect, .. } => effect,
        Outcome::Unmeasurable { .. } | Outcome::Research(_) => return None,
    };
    let (ci_low_ns, ci_high_ns) = effect.credible_interval_ns;
    Some(EstimationPoint {
        true_effect_ns,
        estimated_effect_ns: effect.max_effect_ns,
        ci_low_ns,
        ci_high_ns,
    })
}