use std::borrow::Cow;
/// Pass that a worker timing sample belongs to.
///
/// Used as the `phase` field of [`TrainingEvent::WorkerTiming`].
///
/// Derives `PartialEq`/`Eq` so that consumers can compare phases directly
/// (`phase == WorkerTimingPhase::Forward`, `assert_eq!`) instead of falling
/// back to pattern matching or `Debug` output.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WorkerTimingPhase {
    /// The sample was taken for the forward pass.
    Forward,
    /// The sample was taken for the backward pass.
    Backward,
}
/// Total number of worker-timing slots.
///
/// NOTE(review): presumably the length of a flat per-worker timing buffer that
/// the `WORKER_TIMING_SLOT_*` indices address. The buffer and its
/// packing/unpacking code are not visible in this file — confirm there.
pub const WORKER_TIMING_SLOT_COUNT: usize = 16;
/// Slot index for the forward-pass wall time
/// (presumably pairs with `WorkerPhaseTimings::forward_wall_ms` — confirm).
pub const WORKER_TIMING_SLOT_FWD_WALL: usize = 0;
/// Slot index for the backward-pass wall time
/// (presumably pairs with `WorkerPhaseTimings::backward_wall_ms` — confirm).
pub const WORKER_TIMING_SLOT_BWD_WALL: usize = 1;
/// Slot index for the backward-pass setup time
/// (presumably pairs with `WorkerPhaseTimings::bwd_setup_ms` — confirm).
// The named indices are not contiguous (0, 1, 8, 11); the remaining slots are
// not assigned a name in this file.
pub const WORKER_TIMING_SLOT_BWD_SETUP: usize = 8;
/// Slot index for the forward-pass setup time
/// (presumably pairs with `WorkerPhaseTimings::fwd_setup_ms` — confirm).
pub const WORKER_TIMING_SLOT_FWD_SETUP: usize = 11;
/// Per-worker timing sample carried by [`TrainingEvent::WorkerTiming`].
///
/// A plain `Copy` value type; the derived `Default` sets every field to `0.0`.
/// The `_ms` suffix suggests milliseconds, stored as fractional `f64` values
/// (the code that fills these fields is not visible here — confirm the unit
/// at the producer).
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct WorkerPhaseTimings {
    /// Wall time attributed to the forward pass.
    pub forward_wall_ms: f64,
    /// Wall time attributed to the backward pass.
    pub backward_wall_ms: f64,
    /// Setup time attributed to the forward pass.
    pub fwd_setup_ms: f64,
    /// Setup time attributed to the backward pass.
    pub bwd_setup_ms: f64,
}
/// Outcome of evaluating a single stopping rule.
///
/// Instances are carried in the `rules_evaluated` list of
/// [`TrainingEvent::ConvergenceUpdate`].
#[derive(Debug, Clone)]
pub struct StoppingRuleResult {
    /// Static identifier of the rule that was evaluated.
    pub rule_name: &'static str,
    /// Whether the rule fired.
    pub triggered: bool,
    /// Free-form explanatory text. `Cow` lets callers pass either a static
    /// string (no allocation) or an owned, formatted `String`.
    pub detail: Cow<'static, str>,
}
/// Row-selection statistics for one stage.
///
/// Records of this type make up the `per_stage` list of
/// [`TrainingEvent::PolicySelectionComplete`]. The struct is a passive
/// container: no relationship between the counters is enforced by the type.
#[derive(Clone, Debug)]
pub struct StageRowSelectionRecord {
    /// Stage this record describes.
    pub stage: u32,
    /// Number of populated rows.
    pub rows_populated: u32,
    /// Number of active rows before selection.
    pub rows_active_before: u32,
    /// Number of rows deactivated by selection.
    pub rows_deactivated: u32,
    /// Number of active rows after selection.
    pub rows_active_after: u32,
    /// Time spent on selection for this stage; fractional, presumably
    /// milliseconds per the `_ms` suffix.
    pub selection_time_ms: f64,
    /// Rows evicted by a budget, when that information is available.
    // NOTE(review): `None` presumably means budget enforcement did not run
    // for this stage — confirm against the producer.
    pub budget_evicted: Option<u32>,
    /// Active rows remaining after the budget was applied, when available.
    pub active_after_budget: Option<u32>,
}
/// Progress and telemetry events covering training iterations, policy
/// maintenance, simulation runs and per-worker timing.
///
/// Every variant uses named fields and the enum is `Clone + Debug`.
///
/// NOTE(review): the code that emits and consumes these events is not part of
/// this file, so the per-variant summaries below are read off the identifiers
/// only — confirm details against the producers. Fields suffixed `_ms` are
/// presumably millisecond durations; within this enum they are integral
/// (`u64`) except `solve_time_ms`, which is fractional (`f64`).
#[derive(Debug, Clone)]
pub enum TrainingEvent {
    /// A forward pass finished.
    ForwardPassComplete {
        iteration: u64,
        scenarios: u32,
        ub_mean: f64,
        ub_std: f64,
        elapsed_ms: u64,
    },

    /// Forward-pass results were synchronised.
    ForwardSyncComplete {
        iteration: u64,
        global_ub_mean: f64,
        global_ub_std: f64,
        sync_time_ms: u64,
    },

    /// A backward pass finished; carries a breakdown of where time went.
    BackwardPassComplete {
        iteration: u64,
        rows_generated: u32,
        stages_processed: u32,
        elapsed_ms: u64,
        state_exchange_time_ms: u64,
        row_batch_build_time_ms: u64,
        setup_time_ms: u64,
        load_imbalance_ms: u64,
        scheduling_overhead_ms: u64,
    },

    /// Policy rows were synchronised.
    PolicySyncComplete {
        iteration: u64,
        rows_distributed: u32,
        rows_active: u32,
        rows_removed: u32,
        sync_time_ms: u64,
    },

    /// Row selection finished; `per_stage` holds one record per reported stage.
    PolicySelectionComplete {
        iteration: u64,
        rows_deactivated: u32,
        stages_processed: u32,
        selection_time_ms: u64,
        allgatherv_time_ms: u64,
        per_stage: Vec<StageRowSelectionRecord>,
    },

    /// Budget enforcement finished.
    PolicyBudgetEnforcementComplete {
        iteration: u64,
        rows_evicted: u32,
        stages_processed: u32,
        enforcement_time_ms: u64,
    },

    /// Template baking finished.
    PolicyTemplateBakeComplete {
        iteration: u64,
        stages_processed: u32,
        total_rows_baked: u64,
        bake_time_ms: u64,
    },

    /// Bounds/gap update together with the result of each evaluated stopping rule.
    ConvergenceUpdate {
        iteration: u64,
        lower_bound: f64,
        upper_bound: f64,
        upper_bound_std: f64,
        gap: f64,
        rules_evaluated: Vec<StoppingRuleResult>,
    },

    /// A checkpoint was written to `checkpoint_path`.
    CheckpointComplete {
        iteration: u64,
        checkpoint_path: String,
        elapsed_ms: u64,
    },

    /// End-of-iteration summary of bounds, timings and solver counters.
    IterationSummary {
        iteration: u64,
        lower_bound: f64,
        upper_bound: f64,
        gap: f64,
        wall_time_ms: u64,
        iteration_time_ms: u64,
        forward_ms: u64,
        backward_ms: u64,
        lp_solves: u64,
        solve_time_ms: f64,
        lower_bound_eval_ms: u64,
        fwd_setup_time_ms: u64,
        fwd_load_imbalance_ms: u64,
        fwd_scheduling_overhead_ms: u64,
    },

    /// Training run started; describes the case and the parallel layout.
    TrainingStarted {
        case_name: String,
        stages: u32,
        hydros: u32,
        thermals: u32,
        ranks: u32,
        threads_per_rank: u32,
        timestamp: String,
    },

    /// Training run finished; `reason` says why it stopped.
    TrainingFinished {
        reason: String,
        iterations: u64,
        final_lb: f64,
        final_ub: f64,
        total_time_ms: u64,
        total_rows: u64,
    },

    /// Simulation run started.
    SimulationStarted {
        case_name: String,
        n_scenarios: u32,
        n_stages: u32,
        ranks: u32,
        threads_per_rank: u32,
        timestamp: String,
    },

    /// Simulation progress report.
    SimulationProgress {
        scenarios_complete: u32,
        scenarios_total: u32,
        elapsed_ms: u64,
        scenario_cost: f64,
        solve_time_ms: f64,
        lp_solves: u64,
    },

    /// Simulation run finished; results location is `output_dir`.
    SimulationFinished {
        scenarios: u32,
        output_dir: String,
        elapsed_ms: u64,
    },

    /// Timing sample for one worker, tagged with the pass it belongs to.
    WorkerTiming {
        rank: i32,
        worker_id: i32,
        iteration: u64,
        phase: WorkerTimingPhase,
        timings: WorkerPhaseTimings,
    },
}
#[cfg(test)]
mod tests {
    //! Construction, `Clone`/`Debug` and field-access checks for the event types.

    use std::borrow::Cow;

    use super::{
        StageRowSelectionRecord, StoppingRuleResult, TrainingEvent, WorkerPhaseTimings,
        WorkerTimingPhase,
    };

    /// Number of `TrainingEvent` variants; must equal the number of arms in
    /// [`variant_index`].
    const VARIANT_COUNT: usize = 16;

    /// Maps each variant to a distinct index in `0..VARIANT_COUNT`.
    ///
    /// The `match` is deliberately exhaustive (no `_` arm): adding a variant to
    /// `TrainingEvent` makes this function fail to compile, which forces
    /// `VARIANT_COUNT` and `make_all_variants` to be updated together instead
    /// of silently leaving the new variant untested.
    fn variant_index(event: &TrainingEvent) -> usize {
        match event {
            TrainingEvent::ForwardPassComplete { .. } => 0,
            TrainingEvent::ForwardSyncComplete { .. } => 1,
            TrainingEvent::BackwardPassComplete { .. } => 2,
            TrainingEvent::PolicySyncComplete { .. } => 3,
            TrainingEvent::PolicySelectionComplete { .. } => 4,
            TrainingEvent::PolicyBudgetEnforcementComplete { .. } => 5,
            TrainingEvent::PolicyTemplateBakeComplete { .. } => 6,
            TrainingEvent::ConvergenceUpdate { .. } => 7,
            TrainingEvent::CheckpointComplete { .. } => 8,
            TrainingEvent::IterationSummary { .. } => 9,
            TrainingEvent::TrainingStarted { .. } => 10,
            TrainingEvent::TrainingFinished { .. } => 11,
            TrainingEvent::SimulationStarted { .. } => 12,
            TrainingEvent::SimulationProgress { .. } => 13,
            TrainingEvent::SimulationFinished { .. } => 14,
            TrainingEvent::WorkerTiming { .. } => 15,
        }
    }

    /// Builds one sample value of every `TrainingEvent` variant.
    fn make_all_variants() -> Vec<TrainingEvent> {
        vec![
            TrainingEvent::ForwardPassComplete {
                iteration: 1,
                scenarios: 10,
                ub_mean: 110.0,
                ub_std: 5.0,
                elapsed_ms: 42,
            },
            TrainingEvent::ForwardSyncComplete {
                iteration: 1,
                global_ub_mean: 110.0,
                global_ub_std: 5.0,
                sync_time_ms: 3,
            },
            TrainingEvent::BackwardPassComplete {
                iteration: 1,
                rows_generated: 48,
                stages_processed: 12,
                elapsed_ms: 87,
                state_exchange_time_ms: 0,
                row_batch_build_time_ms: 0,
                setup_time_ms: 0,
                load_imbalance_ms: 0,
                scheduling_overhead_ms: 0,
            },
            TrainingEvent::PolicySyncComplete {
                iteration: 1,
                rows_distributed: 48,
                rows_active: 200,
                rows_removed: 0,
                sync_time_ms: 2,
            },
            TrainingEvent::PolicySelectionComplete {
                iteration: 10,
                rows_deactivated: 15,
                stages_processed: 12,
                selection_time_ms: 20,
                allgatherv_time_ms: 1,
                per_stage: vec![],
            },
            TrainingEvent::PolicyBudgetEnforcementComplete {
                iteration: 10,
                rows_evicted: 2,
                stages_processed: 12,
                enforcement_time_ms: 1,
            },
            TrainingEvent::PolicyTemplateBakeComplete {
                iteration: 10,
                stages_processed: 12,
                total_rows_baked: 48,
                bake_time_ms: 2,
            },
            TrainingEvent::ConvergenceUpdate {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                upper_bound_std: 5.0,
                gap: 0.0909,
                rules_evaluated: vec![StoppingRuleResult {
                    rule_name: "gap_tolerance",
                    triggered: false,
                    detail: Cow::Borrowed("gap 9.09% > 1.00%"),
                }],
            },
            TrainingEvent::CheckpointComplete {
                iteration: 5,
                checkpoint_path: "/tmp/checkpoint.bin".to_string(),
                elapsed_ms: 150,
            },
            TrainingEvent::IterationSummary {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                gap: 0.0909,
                wall_time_ms: 1000,
                iteration_time_ms: 200,
                forward_ms: 80,
                backward_ms: 100,
                lp_solves: 240,
                solve_time_ms: 45.2,
                lower_bound_eval_ms: 10,
                fwd_setup_time_ms: 2,
                fwd_load_imbalance_ms: 2,
                fwd_scheduling_overhead_ms: 1,
            },
            TrainingEvent::TrainingStarted {
                case_name: "test_case".to_string(),
                stages: 60,
                hydros: 5,
                thermals: 10,
                ranks: 4,
                threads_per_rank: 8,
                timestamp: "2026-01-01T00:00:00Z".to_string(),
            },
            TrainingEvent::TrainingFinished {
                reason: "gap_tolerance".to_string(),
                iterations: 50,
                final_lb: 105.0,
                final_ub: 106.0,
                total_time_ms: 300_000,
                total_rows: 2400,
            },
            // Previously missing from this list, so the clone/debug sweeps
            // below never exercised this variant.
            TrainingEvent::SimulationStarted {
                case_name: "test_case".to_string(),
                n_scenarios: 200,
                n_stages: 60,
                ranks: 4,
                threads_per_rank: 8,
                timestamp: "2026-01-01T00:00:00Z".to_string(),
            },
            TrainingEvent::SimulationProgress {
                scenarios_complete: 50,
                scenarios_total: 200,
                elapsed_ms: 5_000,
                scenario_cost: 45_230.0,
                solve_time_ms: 0.0,
                lp_solves: 0,
            },
            TrainingEvent::SimulationFinished {
                scenarios: 200,
                output_dir: "/tmp/output".to_string(),
                elapsed_ms: 20_000,
            },
            TrainingEvent::WorkerTiming {
                rank: 0,
                worker_id: 2,
                iteration: 1,
                phase: WorkerTimingPhase::Backward,
                timings: WorkerPhaseTimings::default(),
            },
        ]
    }

    #[test]
    fn all_variants_construct() {
        let variants = make_all_variants();
        assert_eq!(
            variants.len(),
            VARIANT_COUNT,
            "expected exactly {VARIANT_COUNT} TrainingEvent variants"
        );

        // With exactly VARIANT_COUNT samples, seeing every index means each
        // variant appears exactly once.
        let mut seen = [false; VARIANT_COUNT];
        for variant in &variants {
            seen[variant_index(variant)] = true;
        }
        assert!(
            seen.iter().all(|&hit| hit),
            "make_all_variants must contain every TrainingEvent variant"
        );
    }

    #[test]
    fn all_variants_clone() {
        for variant in make_all_variants() {
            let cloned = variant.clone();
            assert!(!format!("{cloned:?}").is_empty());
        }
    }

    #[test]
    fn all_variants_debug_non_empty() {
        for variant in make_all_variants() {
            let debug = format!("{variant:?}");
            assert!(!debug.is_empty(), "debug output must not be empty");
        }
    }

    #[test]
    fn forward_pass_complete_fields_accessible() {
        let event = TrainingEvent::ForwardPassComplete {
            iteration: 7,
            scenarios: 20,
            ub_mean: 210.0,
            ub_std: 3.5,
            elapsed_ms: 55,
        };
        let TrainingEvent::ForwardPassComplete {
            iteration,
            scenarios,
            ub_mean,
            ub_std,
            elapsed_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 7);
        assert_eq!(scenarios, 20);
        assert!((ub_mean - 210.0).abs() < f64::EPSILON);
        assert!((ub_std - 3.5).abs() < f64::EPSILON);
        assert_eq!(elapsed_ms, 55);
    }

    #[test]
    fn convergence_update_rules_evaluated_field() {
        let rules = vec![
            StoppingRuleResult {
                rule_name: "gap_tolerance",
                triggered: true,
                detail: Cow::Borrowed("gap 0.42% <= 1.00%"),
            },
            StoppingRuleResult {
                rule_name: "iteration_limit",
                triggered: false,
                detail: Cow::Borrowed("iteration 10/100"),
            },
        ];
        let event = TrainingEvent::ConvergenceUpdate {
            iteration: 10,
            lower_bound: 99.0,
            upper_bound: 100.0,
            upper_bound_std: 0.5,
            gap: 0.0042,
            rules_evaluated: rules.clone(),
        };
        let TrainingEvent::ConvergenceUpdate {
            rules_evaluated, ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(rules_evaluated.len(), 2);
        assert_eq!(rules_evaluated[0].rule_name, "gap_tolerance");
        assert!(rules_evaluated[0].triggered);
        assert_eq!(rules_evaluated[1].rule_name, "iteration_limit");
        assert!(!rules_evaluated[1].triggered);
    }

    #[test]
    fn stopping_rule_result_fields_accessible() {
        let r = StoppingRuleResult {
            rule_name: "bound_stalling",
            triggered: false,
            detail: Cow::Borrowed("LB stable for 8/10 iterations"),
        };
        let cloned = r.clone();
        assert_eq!(cloned.rule_name, "bound_stalling");
        assert!(!cloned.triggered);
        assert_eq!(cloned.detail, "LB stable for 8/10 iterations");
    }

    #[test]
    fn stopping_rule_result_debug_non_empty() {
        let r = StoppingRuleResult {
            rule_name: "time_limit",
            triggered: true,
            detail: Cow::Borrowed("elapsed 3602s > 3600s limit"),
        };
        let debug = format!("{r:?}");
        assert!(!debug.is_empty());
        assert!(debug.contains("time_limit"));
    }

    #[test]
    fn policy_selection_complete_fields_accessible() {
        let event = TrainingEvent::PolicySelectionComplete {
            iteration: 10,
            rows_deactivated: 30,
            stages_processed: 12,
            selection_time_ms: 25,
            allgatherv_time_ms: 2,
            per_stage: vec![],
        };
        let TrainingEvent::PolicySelectionComplete {
            iteration,
            rows_deactivated,
            stages_processed,
            selection_time_ms,
            allgatherv_time_ms,
            per_stage,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 10);
        assert_eq!(rows_deactivated, 30);
        assert_eq!(stages_processed, 12);
        assert_eq!(selection_time_ms, 25);
        assert_eq!(allgatherv_time_ms, 2);
        assert!(per_stage.is_empty());
    }

    #[test]
    fn training_started_timestamp_field() {
        let event = TrainingEvent::TrainingStarted {
            case_name: "hydro_sys".to_string(),
            stages: 120,
            hydros: 10,
            thermals: 20,
            ranks: 8,
            threads_per_rank: 4,
            timestamp: "2026-03-01T08:00:00Z".to_string(),
        };
        let TrainingEvent::TrainingStarted { timestamp, .. } = event else {
            panic!("wrong variant")
        };
        assert_eq!(timestamp, "2026-03-01T08:00:00Z");
    }

    #[test]
    fn simulation_started_fields_accessible() {
        let event = TrainingEvent::SimulationStarted {
            case_name: "hydro_sys".to_string(),
            n_scenarios: 500,
            n_stages: 120,
            ranks: 8,
            threads_per_rank: 4,
            timestamp: "2026-03-01T09:00:00Z".to_string(),
        };
        let TrainingEvent::SimulationStarted {
            case_name,
            n_scenarios,
            n_stages,
            ranks,
            threads_per_rank,
            timestamp,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(case_name, "hydro_sys");
        assert_eq!(n_scenarios, 500);
        assert_eq!(n_stages, 120);
        assert_eq!(ranks, 8);
        assert_eq!(threads_per_rank, 4);
        assert_eq!(timestamp, "2026-03-01T09:00:00Z");
    }

    #[test]
    fn simulation_progress_scenario_cost_field_accessible() {
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 100,
            scenarios_total: 500,
            elapsed_ms: 10_000,
            scenario_cost: 45_230.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress {
            scenarios_complete,
            scenarios_total,
            elapsed_ms,
            scenario_cost,
            ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(scenarios_complete, 100);
        assert_eq!(scenarios_total, 500);
        assert_eq!(elapsed_ms, 10_000);
        assert!((scenario_cost - 45_230.0).abs() < f64::EPSILON);
    }

    #[test]
    fn simulation_progress_first_scenario_cost_carried() {
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 1,
            scenarios_total: 200,
            elapsed_ms: 100,
            scenario_cost: 50_000.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress { scenario_cost, .. } = event else {
            panic!("wrong variant")
        };
        assert!((scenario_cost - 50_000.0).abs() < f64::EPSILON);
    }

    #[test]
    fn policy_budget_enforcement_complete_fields_accessible() {
        let event = TrainingEvent::PolicyBudgetEnforcementComplete {
            iteration: 7,
            rows_evicted: 5,
            stages_processed: 12,
            enforcement_time_ms: 3,
        };
        let TrainingEvent::PolicyBudgetEnforcementComplete {
            iteration,
            rows_evicted,
            stages_processed,
            enforcement_time_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 7);
        assert_eq!(rows_evicted, 5);
        assert_eq!(stages_processed, 12);
        assert_eq!(enforcement_time_ms, 3);
    }

    #[test]
    fn worker_timing_fields_accessible() {
        let timings = WorkerPhaseTimings {
            forward_wall_ms: 10.0,
            backward_wall_ms: 0.0,
            fwd_setup_ms: 2.5,
            bwd_setup_ms: 0.0,
        };
        let event = TrainingEvent::WorkerTiming {
            rank: 2,
            worker_id: 3,
            iteration: 7,
            phase: WorkerTimingPhase::Forward,
            timings,
        };
        let TrainingEvent::WorkerTiming {
            rank,
            worker_id,
            iteration,
            phase,
            timings: t,
        } = event
        else {
            panic!("wrong variant");
        };
        assert_eq!(rank, 2);
        assert_eq!(worker_id, 3);
        assert_eq!(iteration, 7);
        // `matches!` verifies the actual variant without relying on any
        // comparison trait being implemented for `WorkerTimingPhase`.
        assert!(
            matches!(phase, WorkerTimingPhase::Forward),
            "phase must round-trip as WorkerTimingPhase::Forward"
        );
        assert!((t.forward_wall_ms - 10.0).abs() < f64::EPSILON);
        assert!((t.fwd_setup_ms - 2.5).abs() < f64::EPSILON);
        assert_eq!(t.backward_wall_ms, 0.0);
        assert_eq!(t.bwd_setup_ms, 0.0);
        let bwd = WorkerTimingPhase::Backward;
        assert_eq!(
            format!("{bwd:?}"),
            "Backward",
            "WorkerTimingPhase::Backward debug must name the variant"
        );
    }

    #[test]
    fn policy_template_bake_complete_fields_accessible() {
        let event = TrainingEvent::PolicyTemplateBakeComplete {
            iteration: 5,
            stages_processed: 12,
            total_rows_baked: 96,
            bake_time_ms: 3,
        };
        let TrainingEvent::PolicyTemplateBakeComplete {
            iteration,
            stages_processed,
            total_rows_baked,
            bake_time_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 5);
        assert_eq!(stages_processed, 12);
        assert_eq!(total_rows_baked, 96);
        assert_eq!(bake_time_ms, 3);
    }

    #[test]
    fn stage_row_selection_record_fields_accessible() {
        let record = StageRowSelectionRecord {
            stage: 3,
            rows_populated: 100,
            rows_active_before: 80,
            rows_deactivated: 10,
            rows_active_after: 70,
            selection_time_ms: 1.5,
            budget_evicted: Some(5),
            active_after_budget: Some(65),
        };
        assert_eq!(record.stage, 3);
        assert_eq!(record.rows_populated, 100);
        assert_eq!(record.rows_active_before, 80);
        assert_eq!(record.rows_deactivated, 10);
        assert_eq!(record.rows_active_after, 70);
        assert!((record.selection_time_ms - 1.5).abs() < f64::EPSILON);
        assert_eq!(record.budget_evicted, Some(5));
        assert_eq!(record.active_after_budget, Some(65));
    }
}