use std::time::Instant;
use cobre_core::StoppingRuleResult;
use crate::{
forward::SyncResult,
stopping_rule::{MonitorState, StoppingRuleSet},
};
/// Tracks bound statistics across iterations and evaluates a [`StoppingRuleSet`]
/// each time [`ConvergenceMonitor::update`] is called.
#[derive(Debug)]
pub struct ConvergenceMonitor {
    /// Rules evaluated on every call to `update`.
    rule_set: StoppingRuleSet,
    /// Lower bound passed to the most recent `update` (`0.0` before the first call).
    lower_bound: f64,
    /// `global_ub_mean` of the most recent sync result (`0.0` before the first call).
    upper_bound: f64,
    /// `global_ub_std` of the most recent sync result.
    upper_bound_std: f64,
    /// `ci_95_half_width` of the most recent sync result.
    ci_95_half_width: f64,
    /// Relative gap `(upper_bound - lower_bound) / max(|upper_bound|, 1.0)`,
    /// recomputed on every `update`.
    gap: f64,
    /// Every lower bound passed to `update`, in call order.
    lower_bound_history: Vec<f64>,
    /// Number of `update` calls made so far.
    iteration_count: u64,
    /// Captured at construction; its elapsed time is reported to the rules as
    /// wall-clock seconds.
    start_time: Instant,
    /// Set by `set_shutdown` and forwarded to the rules; never reset.
    shutdown_requested: bool,
    /// Costs supplied via `set_simulation_costs`; `None` until the first call.
    simulation_costs: Option<Vec<f64>>,
}
impl ConvergenceMonitor {
    /// Creates a monitor that evaluates `rule_set` on every [`update`](Self::update).
    ///
    /// All bound statistics start at `0.0` and the iteration counter at `0`. The
    /// wall-clock timer starts here, at construction, not at the first update.
    #[must_use]
    pub fn new(rule_set: StoppingRuleSet) -> Self {
        Self {
            rule_set,
            lower_bound: 0.0,
            upper_bound: 0.0,
            upper_bound_std: 0.0,
            ci_95_half_width: 0.0,
            gap: 0.0,
            lower_bound_history: Vec::new(),
            iteration_count: 0,
            start_time: Instant::now(),
            shutdown_requested: false,
            simulation_costs: None,
        }
    }

    /// Records the outcome of one iteration and evaluates the stopping rules.
    ///
    /// This method:
    /// - stores `lb` and the upper-bound statistics from `sync_result`;
    /// - recomputes the relative gap `(UB - LB) / max(|UB|, 1.0)`;
    /// - increments the iteration counter and appends `lb` to the history;
    /// - passes a [`MonitorState`] snapshot to the rule set.
    ///
    /// Returns the rule set's verdict: a flag telling whether to stop, plus the
    /// per-rule results.
    pub fn update(
        &mut self,
        lb: f64,
        sync_result: &SyncResult,
    ) -> (bool, Vec<StoppingRuleResult>) {
        self.lower_bound = lb;
        self.upper_bound = sync_result.global_ub_mean;
        self.upper_bound_std = sync_result.global_ub_std;
        self.ci_95_half_width = sync_result.ci_95_half_width;
        // Floor the denominator at 1.0 so the gap stays finite when |UB| is near
        // zero; in that regime it degrades to an absolute gap.
        let denominator = self.upper_bound.abs().max(1.0_f64);
        self.gap = (self.upper_bound - lb) / denominator;
        self.iteration_count += 1;
        self.lower_bound_history.push(lb);

        // `MonitorState` owns its vectors. Move ours into the snapshot instead of
        // cloning them. The history grows every iteration, so a clone per call
        // makes the total cost quadratic. Move them back once the rules have run.
        let state = MonitorState {
            iteration: self.iteration_count,
            wall_time_seconds: self.start_time.elapsed().as_secs_f64(),
            lower_bound: self.lower_bound,
            lower_bound_history: std::mem::take(&mut self.lower_bound_history),
            shutdown_requested: self.shutdown_requested,
            simulation_costs: self.simulation_costs.take(),
        };
        let outcome = self.rule_set.evaluate(&state);
        self.lower_bound_history = state.lower_bound_history;
        self.simulation_costs = state.simulation_costs;
        outcome
    }

    /// Marks that a graceful shutdown was requested.
    ///
    /// The flag is forwarded to the rules on every subsequent
    /// [`update`](Self::update) and is never cleared.
    pub fn set_shutdown(&mut self) {
        self.shutdown_requested = true;
    }

    /// Supplies simulation costs for simulation-based rules.
    ///
    /// The costs are forwarded on every later [`update`](Self::update) until
    /// replaced by another call.
    pub fn set_simulation_costs(&mut self, costs: Vec<f64>) {
        self.simulation_costs = Some(costs);
    }

    /// Lower bound from the most recent update (`0.0` before the first).
    #[must_use]
    pub fn lower_bound(&self) -> f64 {
        self.lower_bound
    }

    /// Upper-bound mean from the most recent update (`0.0` before the first).
    #[must_use]
    pub fn upper_bound(&self) -> f64 {
        self.upper_bound
    }

    /// Upper-bound standard deviation from the most recent update.
    #[must_use]
    pub fn upper_bound_std(&self) -> f64 {
        self.upper_bound_std
    }

    /// Half-width of the 95% confidence interval from the most recent update.
    #[must_use]
    pub fn ci_95_half_width(&self) -> f64 {
        self.ci_95_half_width
    }

    /// Relative gap `(UB - LB) / max(|UB|, 1.0)` from the most recent update.
    #[must_use]
    pub fn gap(&self) -> f64 {
        self.gap
    }

    /// Number of [`update`](Self::update) calls made so far.
    #[must_use]
    pub fn iteration_count(&self) -> u64 {
        self.iteration_count
    }
}
#[cfg(test)]
mod tests {
    use super::ConvergenceMonitor;
    use crate::{
        forward::SyncResult,
        stopping_rule::{StoppingMode, StoppingRule, StoppingRuleSet},
    };

    /// Wraps a single rule in a rule set using `StoppingMode::Any`.
    fn make_rule_set(rule: StoppingRule) -> StoppingRuleSet {
        StoppingRuleSet {
            rules: vec![rule],
            mode: StoppingMode::Any,
        }
    }

    /// Sync result with the given UB mean, std 5.0 and CI half-width 2.0.
    fn make_sync(ub_mean: f64) -> SyncResult {
        SyncResult {
            global_ub_mean: ub_mean,
            global_ub_std: 5.0,
            ci_95_half_width: 2.0,
            sync_time_ms: 10,
        }
    }

    /// Sync result with UB mean 110.0.
    fn default_sync() -> SyncResult {
        make_sync(110.0)
    }

    #[test]
    fn new_initializes_all_fields_to_default() {
        let monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 10 }));
        assert_eq!(monitor.lower_bound(), 0.0);
        assert_eq!(monitor.upper_bound(), 0.0);
        assert_eq!(monitor.upper_bound_std(), 0.0);
        assert_eq!(monitor.ci_95_half_width(), 0.0);
        assert_eq!(monitor.gap(), 0.0);
        assert_eq!(monitor.iteration_count(), 0);
    }

    #[test]
    fn update_increments_iteration_count() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        monitor.update(100.0, &default_sync());
        assert_eq!(monitor.iteration_count(), 1);
        monitor.update(101.0, &default_sync());
        assert_eq!(monitor.iteration_count(), 2);
    }

    #[test]
    fn update_stores_lb_and_ub_correctly() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        let sync = SyncResult {
            global_ub_mean: 200.0,
            global_ub_std: 10.0,
            ci_95_half_width: 3.0,
            sync_time_ms: 5,
        };
        monitor.update(150.0, &sync);
        assert!((monitor.lower_bound() - 150.0).abs() < 1e-10);
        assert!((monitor.upper_bound() - 200.0).abs() < 1e-10);
        assert!((monitor.upper_bound_std() - 10.0).abs() < 1e-10);
        assert!((monitor.ci_95_half_width() - 3.0).abs() < 1e-10);
    }

    #[test]
    fn gap_formula_uses_max_guard() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        let sync = make_sync(0.5);
        monitor.update(100.0, &sync);
        let expected = (0.5_f64 - 100.0) / 1.0_f64;
        assert!(
            (monitor.gap() - expected).abs() < 1e-10,
            "gap with UB=0.5 must use max guard of 1.0, got {}",
            monitor.gap()
        );
    }

    #[test]
    fn gap_formula_normal_case() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        let sync = make_sync(110.0);
        monitor.update(100.0, &sync);
        let expected = 10.0_f64 / 110.0_f64;
        assert!(
            (monitor.gap() - expected).abs() < 1e-10,
            "gap must be 10/110, got {}",
            monitor.gap()
        );
    }

    #[test]
    fn gap_formula_uses_absolute_ub_when_negative() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        monitor.update(-220.0, &make_sync(-200.0));
        // (UB - LB) / |UB| = (-200 + 220) / 200
        let expected = 20.0_f64 / 200.0_f64;
        assert!(
            (monitor.gap() - expected).abs() < 1e-10,
            "gap with negative UB must divide by |UB|, got {}",
            monitor.gap()
        );
    }

    #[test]
    fn lower_bound_history_grows() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        for i in 0..5 {
            monitor.update(f64::from(i) * 10.0, &default_sync());
        }
        // Every LB must be retained, in call order, across successive updates.
        assert_eq!(
            monitor.lower_bound_history,
            vec![0.0, 10.0, 20.0, 30.0, 40.0]
        );
    }

    #[test]
    fn set_shutdown_triggers_graceful_rule() {
        let rule_set = StoppingRuleSet {
            rules: vec![
                StoppingRule::GracefulShutdown,
                StoppingRule::IterationLimit { limit: 100 },
            ],
            mode: StoppingMode::Any,
        };
        let mut monitor = ConvergenceMonitor::new(rule_set);
        monitor.set_shutdown();
        let (stop, results) = monitor.update(100.0, &default_sync());
        assert!(stop, "should stop after shutdown signal");
        assert!(
            results[0].triggered,
            "GracefulShutdown result must be triggered"
        );
        assert_eq!(results[0].rule_name, "graceful_shutdown");
    }

    #[test]
    fn set_simulation_costs_populates_monitor_state() {
        let rule_set = StoppingRuleSet {
            rules: vec![StoppingRule::SimulationBased {
                period: 1,
                distance_tolerance: 1e6,
                replications: 10,
                bound_stability_window: 1,
            }],
            mode: StoppingMode::Any,
        };
        let mut monitor = ConvergenceMonitor::new(rule_set);
        monitor.set_simulation_costs(vec![100.0, 200.0, 300.0]);
        let (_stop, results) = monitor.update(80.0, &default_sync());
        assert_eq!(results[0].rule_name, "simulation_based");
        assert!(
            !results[0]
                .detail
                .contains("no simulation results available"),
            "detail should not indicate missing costs: {}",
            results[0].detail
        );
    }

    #[test]
    fn simulation_costs_persist_across_updates() {
        let mut monitor = ConvergenceMonitor::new(make_rule_set(StoppingRule::SimulationBased {
            period: 1,
            distance_tolerance: 1e6,
            replications: 10,
            bound_stability_window: 1,
        }));
        monitor.set_simulation_costs(vec![100.0, 200.0, 300.0]);
        for lb in [80.0_f64, 81.0] {
            let (_stop, results) = monitor.update(lb, &default_sync());
            assert!(
                !results[0]
                    .detail
                    .contains("no simulation results available"),
                "costs must still be available at iteration {}: {}",
                monitor.iteration_count(),
                results[0].detail
            );
        }
    }

    #[test]
    fn iteration_limit_triggers_at_limit() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 3 }));
        let sync = default_sync();
        let (stop1, _) = monitor.update(100.0, &sync);
        let (stop2, _) = monitor.update(100.0, &sync);
        let (stop3, results) = monitor.update(100.0, &sync);
        assert!(!stop1, "should not stop at iteration 1");
        assert!(!stop2, "should not stop at iteration 2");
        assert!(stop3, "should stop at iteration 3 (limit reached)");
        assert!(results[0].triggered);
        assert_eq!(results[0].rule_name, "iteration_limit");
    }

    #[test]
    fn bound_stalling_triggers_when_stable() {
        let mut monitor = ConvergenceMonitor::new(make_rule_set(StoppingRule::BoundStalling {
            tolerance: 0.011,
            iterations: 3,
        }));
        let sync = default_sync();
        monitor.update(90.0, &sync);
        monitor.update(99.0, &sync);
        monitor.update(99.5, &sync);
        let (stop, _) = monitor.update(100.0, &sync);
        assert!(
            stop,
            "BoundStalling should trigger when improvement is < 0.011"
        );
        assert!(
            (monitor.gap() - 10.0 / 110.0).abs() < 1e-10,
            "gap after 4th update must equal 10/110, got {}",
            monitor.gap()
        );
    }

    #[test]
    fn ac_iteration_limit_triggers_at_third_call() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 3 }));
        let sync = make_sync(110.0);
        monitor.update(100.0, &sync);
        monitor.update(100.0, &sync);
        let (stop, results) = monitor.update(100.0, &sync);
        assert!(stop, "third update must trigger IterationLimit(3)");
        assert!(results[0].triggered);
        assert_eq!(results[0].rule_name, "iteration_limit");
    }

    #[test]
    fn ac_gap_formula_with_ub_110_lb_100() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        let sync = make_sync(110.0);
        monitor.update(90.0, &sync);
        monitor.update(99.0, &sync);
        monitor.update(99.5, &sync);
        monitor.update(100.0, &sync);
        let expected = 10.0_f64 / 110.0_f64;
        assert!(
            (monitor.gap() - expected).abs() < 1e-10,
            "gap must equal {expected}, got {}",
            monitor.gap()
        );
    }

    #[test]
    fn ac_set_shutdown_triggers_graceful_shutdown_rule() {
        let rule_set = StoppingRuleSet {
            rules: vec![
                StoppingRule::GracefulShutdown,
                StoppingRule::IterationLimit { limit: 100 },
            ],
            mode: StoppingMode::Any,
        };
        let mut monitor = ConvergenceMonitor::new(rule_set);
        monitor.set_shutdown();
        let (stop, results) = monitor.update(100.0, &default_sync());
        assert!(stop);
        assert!(results[0].triggered);
        assert_eq!(results[0].rule_name, "graceful_shutdown");
    }

    #[test]
    fn ac_lb_and_iteration_count_track_correctly() {
        let mut monitor =
            ConvergenceMonitor::new(make_rule_set(StoppingRule::IterationLimit { limit: 100 }));
        monitor.update(50.0, &default_sync());
        monitor.update(60.0, &default_sync());
        assert!(
            (monitor.lower_bound() - 60.0).abs() < 1e-10,
            "lower_bound must return latest LB 60.0, got {}",
            monitor.lower_bound()
        );
        assert_eq!(monitor.iteration_count(), 2);
    }
}