/// Outcome of evaluating one stopping rule.
///
/// Values of this type are carried in the `rules_evaluated` list of
/// [`TrainingEvent::ConvergenceUpdate`].
///
/// Derives `PartialEq`/`Eq` so results can be compared directly (for example
/// with `assert_eq!`) instead of field by field.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StoppingRuleResult {
    /// Identifier of the rule that was evaluated (e.g. `"gap_tolerance"`).
    pub rule_name: String,
    /// `true` when the rule's condition was met.
    pub triggered: bool,
    /// Free-form, human-readable explanation of the evaluation.
    pub detail: String,
}
/// Per-stage cut counts attached to a cut-selection event.
///
/// Values of this type are carried in the `per_stage` list of
/// [`TrainingEvent::CutSelectionComplete`].
///
/// Derives `PartialEq`/`Eq` so records can be compared directly. No
/// relationship between the counters is enforced by this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StageSelectionRecord {
    /// Stage this record refers to.
    pub stage: u32,
    /// Number of cuts populated for the stage.
    pub cuts_populated: u32,
    /// Number of active cuts before selection ran.
    pub cuts_active_before: u32,
    /// Number of cuts deactivated by selection.
    pub cuts_deactivated: u32,
    /// Number of active cuts after selection ran
    /// (presumably `cuts_active_before - cuts_deactivated`; verify with the producer).
    pub cuts_active_after: u32,
}
/// Progress events for a training run and for a simulation run, as the variant
/// names indicate.
///
/// All variants are struct-like and carry only plain data (integers, floats,
/// `String`s and `Vec`s of the record types declared in this file); no
/// behaviour is defined on the enum here.
///
/// NOTE(review): fields suffixed `_ms` are presumably durations in
/// milliseconds, and `ub`/`lb` presumably abbreviate upper/lower bound. These
/// readings come from the names only — confirm against the emitting code.
#[derive(Clone, Debug)]
pub enum TrainingEvent {
/// Forward-pass completion record for one iteration: a scenario count, the
/// `ub_mean`/`ub_std` statistics and an elapsed time.
ForwardPassComplete {
iteration: u64,
scenarios: u32,
ub_mean: f64,
ub_std: f64,
elapsed_ms: u64,
},
/// Forward-sync completion record: the `global_*` counterparts of the
/// forward-pass statistics plus the sync time.
/// Presumably aggregated across ranks — TODO confirm.
ForwardSyncComplete {
iteration: u64,
global_ub_mean: f64,
global_ub_std: f64,
sync_time_ms: u64,
},
/// Backward-pass completion record: cut and stage counts, an elapsed time,
/// and three further timing fields (whether those are included in
/// `elapsed_ms` is not visible in this file).
BackwardPassComplete {
iteration: u64,
cuts_generated: u32,
stages_processed: u32,
elapsed_ms: u64,
state_exchange_time_ms: u64,
cut_batch_build_time_ms: u64,
rayon_overhead_time_ms: u64,
},
/// Cut-sync completion record: distributed / active / removed cut counts
/// and the sync time.
CutSyncComplete {
iteration: u64,
cuts_distributed: u32,
cuts_active: u32,
cuts_removed: u32,
sync_time_ms: u64,
},
/// Cut-selection completion record: overall counts and timings plus a
/// per-stage breakdown.
CutSelectionComplete {
iteration: u64,
cuts_deactivated: u32,
stages_processed: u32,
selection_time_ms: u64,
allgatherv_time_ms: u64,
// One `StageSelectionRecord` per reported stage; may be empty (the tests
// in this file construct it empty).
per_stage: Vec<StageSelectionRecord>,
},
/// Convergence record for one iteration: bounds, gap and the outcome of each
/// evaluated stopping rule.
ConvergenceUpdate {
iteration: u64,
lower_bound: f64,
upper_bound: f64,
upper_bound_std: f64,
// The fixtures in this file's tests use a fraction here (0.0909 for bounds
// 100 / 110); the exact definition lives with the producer.
gap: f64,
rules_evaluated: Vec<StoppingRuleResult>,
},
/// Checkpoint completion record: iteration, checkpoint path (as a `String`)
/// and an elapsed time.
CheckpointComplete {
iteration: u64,
checkpoint_path: String,
elapsed_ms: u64,
},
/// Per-iteration summary: bounds, gap, several timings and an LP solve count.
IterationSummary {
iteration: u64,
lower_bound: f64,
upper_bound: f64,
gap: f64,
wall_time_ms: u64,
iteration_time_ms: u64,
forward_ms: u64,
backward_ms: u64,
lp_solves: u64,
// Unlike the other timing fields this one is an `f64`.
solve_time_ms: f64,
},
/// Start-of-training record: case name, problem sizes, rank / thread counts
/// and a timestamp string.
TrainingStarted {
case_name: String,
stages: u32,
hydros: u32,
thermals: u32,
ranks: u32,
threads_per_rank: u32,
// Stored as an opaque `String`; the test fixtures use values such as
// "2026-01-01T00:00:00Z", but no format is enforced here.
timestamp: String,
},
/// End-of-training record: termination reason, iteration count, final
/// bounds, total time and total cut count.
TrainingFinished {
reason: String,
iterations: u64,
final_lb: f64,
final_ub: f64,
total_time_ms: u64,
total_cuts: u64,
},
/// Simulation progress record: completed / total scenario counts, elapsed
/// time, a scenario cost, a solve time and an LP solve count.
SimulationProgress {
scenarios_complete: u32,
scenarios_total: u32,
elapsed_ms: u64,
scenario_cost: f64,
solve_time_ms: f64,
lp_solves: u64,
},
/// End-of-simulation record: scenario count, output directory (as a
/// `String`) and an elapsed time.
SimulationFinished {
scenarios: u32,
output_dir: String,
elapsed_ms: u64,
},
}
#[cfg(test)]
mod tests {
    //! Unit tests for the event types: every variant can be built, cloned and
    //! debug-printed, and selected payload fields survive a round trip through
    //! construction and destructuring.

    use std::collections::BTreeSet;

    use super::{StageSelectionRecord, StoppingRuleResult, TrainingEvent};

    /// Builds one representative value of every [`TrainingEvent`] variant.
    fn make_all_variants() -> Vec<TrainingEvent> {
        vec![
            TrainingEvent::ForwardPassComplete {
                iteration: 1,
                scenarios: 10,
                ub_mean: 110.0,
                ub_std: 5.0,
                elapsed_ms: 42,
            },
            TrainingEvent::ForwardSyncComplete {
                iteration: 1,
                global_ub_mean: 110.0,
                global_ub_std: 5.0,
                sync_time_ms: 3,
            },
            TrainingEvent::BackwardPassComplete {
                iteration: 1,
                cuts_generated: 48,
                stages_processed: 12,
                elapsed_ms: 87,
                state_exchange_time_ms: 0,
                cut_batch_build_time_ms: 0,
                rayon_overhead_time_ms: 0,
            },
            TrainingEvent::CutSyncComplete {
                iteration: 1,
                cuts_distributed: 48,
                cuts_active: 200,
                cuts_removed: 0,
                sync_time_ms: 2,
            },
            TrainingEvent::CutSelectionComplete {
                iteration: 10,
                cuts_deactivated: 15,
                stages_processed: 12,
                selection_time_ms: 20,
                allgatherv_time_ms: 1,
                per_stage: vec![],
            },
            TrainingEvent::ConvergenceUpdate {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                upper_bound_std: 5.0,
                gap: 0.0909,
                rules_evaluated: vec![StoppingRuleResult {
                    rule_name: "gap_tolerance".to_string(),
                    triggered: false,
                    detail: "gap 9.09% > 1.00%".to_string(),
                }],
            },
            TrainingEvent::CheckpointComplete {
                iteration: 5,
                checkpoint_path: "/tmp/checkpoint.bin".to_string(),
                elapsed_ms: 150,
            },
            TrainingEvent::IterationSummary {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                gap: 0.0909,
                wall_time_ms: 1000,
                iteration_time_ms: 200,
                forward_ms: 80,
                backward_ms: 100,
                lp_solves: 240,
                solve_time_ms: 45.2,
            },
            TrainingEvent::TrainingStarted {
                case_name: "test_case".to_string(),
                stages: 60,
                hydros: 5,
                thermals: 10,
                ranks: 4,
                threads_per_rank: 8,
                timestamp: "2026-01-01T00:00:00Z".to_string(),
            },
            TrainingEvent::TrainingFinished {
                reason: "gap_tolerance".to_string(),
                iterations: 50,
                final_lb: 105.0,
                final_ub: 106.0,
                total_time_ms: 300_000,
                total_cuts: 2400,
            },
            TrainingEvent::SimulationProgress {
                scenarios_complete: 50,
                scenarios_total: 200,
                elapsed_ms: 5_000,
                scenario_cost: 45_230.0,
                solve_time_ms: 0.0,
                lp_solves: 0,
            },
            TrainingEvent::SimulationFinished {
                scenarios: 200,
                output_dir: "/tmp/output".to_string(),
                elapsed_ms: 20_000,
            },
        ]
    }

    /// Returns the variant name, i.e. the first whitespace-delimited token of
    /// the derived `Debug` rendering (all variants are struct-like, so the
    /// rendering has the shape `Name { .. }`).
    fn variant_name(event: &TrainingEvent) -> String {
        format!("{event:?}")
            .split_whitespace()
            .next()
            .unwrap_or("")
            .to_string()
    }

    #[test]
    fn all_twelve_variants_construct() {
        let variants = make_all_variants();
        assert_eq!(
            variants.len(),
            12,
            "expected exactly 12 TrainingEvent variants"
        );
        // A bare length check would still pass if one variant were listed twice
        // and another dropped, so also require the names to be distinct.
        let distinct: BTreeSet<String> = variants.iter().map(variant_name).collect();
        assert_eq!(
            distinct.len(),
            variants.len(),
            "every fixture must be a different variant"
        );
    }

    #[test]
    fn all_variants_clone() {
        for variant in make_all_variants() {
            let cloned = variant.clone();
            // Compare renderings: the enum does not implement `PartialEq`, and
            // the derived `Debug` output covers every field.
            assert_eq!(
                format!("{cloned:?}"),
                format!("{variant:?}"),
                "clone must render identically to the original"
            );
        }
    }

    #[test]
    fn all_variants_debug_non_empty() {
        for variant in make_all_variants() {
            let debug = format!("{variant:?}");
            assert!(!debug.is_empty(), "debug output must not be empty");
        }
    }

    #[test]
    fn forward_pass_complete_fields_accessible() {
        let event = TrainingEvent::ForwardPassComplete {
            iteration: 7,
            scenarios: 20,
            ub_mean: 210.0,
            ub_std: 3.5,
            elapsed_ms: 55,
        };
        let TrainingEvent::ForwardPassComplete {
            iteration,
            scenarios,
            ub_mean,
            ub_std,
            elapsed_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 7);
        assert_eq!(scenarios, 20);
        assert!((ub_mean - 210.0).abs() < f64::EPSILON);
        assert!((ub_std - 3.5).abs() < f64::EPSILON);
        assert_eq!(elapsed_ms, 55);
    }

    #[test]
    fn convergence_update_rules_evaluated_field() {
        let rules = vec![
            StoppingRuleResult {
                rule_name: "gap_tolerance".to_string(),
                triggered: true,
                detail: "gap 0.42% <= 1.00%".to_string(),
            },
            StoppingRuleResult {
                rule_name: "iteration_limit".to_string(),
                triggered: false,
                detail: "iteration 10/100".to_string(),
            },
        ];
        // `rules` is not needed afterwards, so move it instead of cloning.
        let event = TrainingEvent::ConvergenceUpdate {
            iteration: 10,
            lower_bound: 99.0,
            upper_bound: 100.0,
            upper_bound_std: 0.5,
            gap: 0.0042,
            rules_evaluated: rules,
        };
        let TrainingEvent::ConvergenceUpdate {
            rules_evaluated, ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(rules_evaluated.len(), 2);
        assert_eq!(rules_evaluated[0].rule_name, "gap_tolerance");
        assert!(rules_evaluated[0].triggered);
        assert_eq!(rules_evaluated[1].rule_name, "iteration_limit");
        assert!(!rules_evaluated[1].triggered);
    }

    #[test]
    fn stopping_rule_result_fields_accessible() {
        let r = StoppingRuleResult {
            rule_name: "bound_stalling".to_string(),
            triggered: false,
            detail: "LB stable for 8/10 iterations".to_string(),
        };
        let cloned = r.clone();
        assert_eq!(cloned.rule_name, "bound_stalling");
        assert!(!cloned.triggered);
        assert_eq!(cloned.detail, "LB stable for 8/10 iterations");
    }

    #[test]
    fn stopping_rule_result_debug_non_empty() {
        let r = StoppingRuleResult {
            rule_name: "time_limit".to_string(),
            triggered: true,
            detail: "elapsed 3602s > 3600s limit".to_string(),
        };
        let debug = format!("{r:?}");
        assert!(!debug.is_empty());
        assert!(debug.contains("time_limit"));
    }

    #[test]
    fn cut_selection_complete_fields_accessible() {
        let event = TrainingEvent::CutSelectionComplete {
            iteration: 10,
            cuts_deactivated: 30,
            stages_processed: 12,
            selection_time_ms: 25,
            allgatherv_time_ms: 2,
            per_stage: vec![],
        };
        let TrainingEvent::CutSelectionComplete {
            iteration,
            cuts_deactivated,
            stages_processed,
            selection_time_ms,
            allgatherv_time_ms,
            per_stage,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 10);
        assert_eq!(cuts_deactivated, 30);
        assert_eq!(stages_processed, 12);
        assert_eq!(selection_time_ms, 25);
        assert_eq!(allgatherv_time_ms, 2);
        assert!(per_stage.is_empty());
    }

    #[test]
    fn cut_selection_complete_per_stage_records_carried() {
        let event = TrainingEvent::CutSelectionComplete {
            iteration: 10,
            cuts_deactivated: 5,
            stages_processed: 1,
            selection_time_ms: 4,
            allgatherv_time_ms: 1,
            per_stage: vec![StageSelectionRecord {
                stage: 3,
                cuts_populated: 40,
                cuts_active_before: 25,
                cuts_deactivated: 5,
                cuts_active_after: 20,
            }],
        };
        let TrainingEvent::CutSelectionComplete { per_stage, .. } = event else {
            panic!("wrong variant")
        };
        let [record] = per_stage.as_slice() else {
            panic!("expected exactly one per-stage record")
        };
        assert_eq!(record.stage, 3);
        assert_eq!(record.cuts_populated, 40);
        assert_eq!(record.cuts_active_before, 25);
        assert_eq!(record.cuts_deactivated, 5);
        assert_eq!(record.cuts_active_after, 20);
    }

    #[test]
    fn training_started_timestamp_field() {
        let event = TrainingEvent::TrainingStarted {
            case_name: "hydro_sys".to_string(),
            stages: 120,
            hydros: 10,
            thermals: 20,
            ranks: 8,
            threads_per_rank: 4,
            timestamp: "2026-03-01T08:00:00Z".to_string(),
        };
        let TrainingEvent::TrainingStarted { timestamp, .. } = event else {
            panic!("wrong variant")
        };
        assert_eq!(timestamp, "2026-03-01T08:00:00Z");
    }

    #[test]
    fn simulation_progress_scenario_cost_field_accessible() {
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 100,
            scenarios_total: 500,
            elapsed_ms: 10_000,
            scenario_cost: 45_230.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress {
            scenarios_complete,
            scenarios_total,
            elapsed_ms,
            scenario_cost,
            ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(scenarios_complete, 100);
        assert_eq!(scenarios_total, 500);
        assert_eq!(elapsed_ms, 10_000);
        assert!((scenario_cost - 45_230.0).abs() < f64::EPSILON);
    }

    #[test]
    fn simulation_progress_first_scenario_cost_carried() {
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 1,
            scenarios_total: 200,
            elapsed_ms: 100,
            scenario_cost: 50_000.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress { scenario_cost, .. } = event else {
            panic!("wrong variant")
        };
        assert!((scenario_cost - 50_000.0).abs() < f64::EPSILON);
    }
}