use std::collections::BTreeMap;
use cobre_core::TrainingEvent;
use cobre_io::{IterationRecord, RowPoolStatistics, RowSelectionRecord, TrainingOutput};
use crate::{FutureCostFunction, TrainingResult};
/// Scratch accumulator holding everything known about a single training iteration.
///
/// One instance exists per iteration number while the event stream is being
/// scanned. Every field starts at its zero `Default` value and is filled in by
/// whichever `TrainingEvent` variant carries it, so a field stays zero when the
/// corresponding event was never observed for that iteration.
#[derive(Default)]
struct PartialRecord {
    // ---- copied from `TrainingEvent::IterationSummary` ----
    /// Lower bound reported for the iteration.
    lower_bound: f64,
    /// Upper bound reported for the iteration (the event's `upper_bound`).
    upper_bound_mean: f64,
    /// Gap as reported by the event; scaled by 100 when exported as a percentage.
    gap: f64,
    /// Total time of the iteration, in milliseconds.
    iteration_time_ms: u64,
    /// Forward-pass time, in milliseconds.
    forward_ms: u64,
    /// Backward-pass time, in milliseconds.
    backward_ms: u64,
    /// Number of LP solves reported for the iteration.
    lp_solves: u64,
    /// Solver time reported for the iteration, in (fractional) milliseconds.
    solve_time_ms: f64,
    /// Time spent evaluating the lower bound, in milliseconds.
    lower_bound_eval_ms: u64,
    /// Forward setup time (the event's `fwd_setup_time_ms`).
    fwd_setup_ms: u64,
    /// Forward load-imbalance time, in milliseconds.
    fwd_load_imbalance_ms: u64,
    /// Forward scheduling-overhead time, in milliseconds.
    fwd_scheduling_overhead_ms: u64,
    /// Sum of the `rows_in_lp` samples; divided by `rows_in_lp_count` to get the mean.
    rows_in_lp_sum: u64,
    /// Number of `rows_in_lp` samples behind `rows_in_lp_sum`.
    rows_in_lp_count: u64,
    /// Largest `rows_in_lp` sample reported for the iteration.
    rows_in_lp_max: u64,

    // ---- copied from `TrainingEvent::ForwardSyncComplete` ----
    /// Upper-bound standard deviation (the event's `global_ub_std`).
    upper_bound_std: f64,
    /// Forward synchronisation time (the event's `sync_time_ms`).
    forward_sync_ms: u64,

    // ---- copied from `TrainingEvent::ForwardPassComplete` ----
    /// Number of forward passes (the event's `scenarios`).
    forward_passes: u32,

    // ---- copied from `TrainingEvent::BackwardPassComplete` ----
    /// Cuts added in the iteration (the event's `rows_generated`).
    cuts_added: u32,
    /// State-exchange time (the event's `state_exchange_time_ms`).
    state_exchange_ms: u64,
    /// Cut-batch build time (the event's `row_batch_build_time_ms`).
    cut_batch_build_ms: u64,
    /// Backward setup time (the event's `setup_time_ms`).
    bwd_setup_ms: u64,
    /// Backward load-imbalance time (the event's `load_imbalance_ms`).
    bwd_load_imbalance_ms: u64,
    /// Backward scheduling-overhead time (the event's `scheduling_overhead_ms`).
    bwd_scheduling_overhead_ms: u64,

    // ---- copied from `TrainingEvent::PolicySyncComplete` ----
    /// Active cuts (the event's `rows_active`); later reduced by the
    /// `rows_deactivated` of a `PolicySelectionComplete` event, saturating at zero.
    cuts_active: u32,
    /// Cuts removed (the event's `rows_removed`).
    cuts_removed: u32,
    /// Cut synchronisation time (the event's `sync_time_ms`).
    cut_sync_ms: u64,

    // ---- copied from `TrainingEvent::PolicySelectionComplete` ----
    /// Cut-selection time (the event's `selection_time_ms`).
    cut_selection_ms: u64,
    /// Cut-selection all-gather time (the event's `allgatherv_time_ms`).
    cut_selection_allgatherv_ms: u64,
}
/// Collapses the training event stream into one [`PartialRecord`] per iteration.
///
/// Events are processed in slice order. Each recognised variant writes only the
/// fields it carries into the record for its `iteration`, creating that record
/// on first use; all other variants are skipped. If the same variant appears
/// twice for one iteration, the later event overwrites the earlier values.
///
/// Returns the per-iteration map (ordered by iteration number) together with the
/// largest `rows_active` seen in any `PolicySyncComplete` event (0 if none).
fn accumulate_partial_records(events: &[TrainingEvent]) -> (BTreeMap<u64, PartialRecord>, u64) {
    let mut by_iteration: BTreeMap<u64, PartialRecord> = BTreeMap::new();
    let mut peak_rows_active: u64 = 0;

    for event in events {
        // Only `Copy` fields are bound below, so matching on the dereferenced
        // event copies them out without borrowing or cloning the event.
        match *event {
            TrainingEvent::IterationSummary {
                iteration,
                lower_bound,
                upper_bound,
                gap,
                iteration_time_ms,
                forward_ms,
                backward_ms,
                lp_solves,
                solve_time_ms,
                lower_bound_eval_ms,
                fwd_setup_time_ms,
                fwd_load_imbalance_ms,
                fwd_scheduling_overhead_ms,
                rows_in_lp_sum,
                rows_in_lp_count,
                rows_in_lp_max,
                ..
            } => {
                let slot = by_iteration.entry(iteration).or_default();
                // Bounds and gap.
                slot.lower_bound = lower_bound;
                slot.upper_bound_mean = upper_bound;
                slot.gap = gap;
                // Coarse timings and solver counters.
                slot.iteration_time_ms = iteration_time_ms;
                slot.forward_ms = forward_ms;
                slot.backward_ms = backward_ms;
                slot.lower_bound_eval_ms = lower_bound_eval_ms;
                slot.lp_solves = lp_solves;
                slot.solve_time_ms = solve_time_ms;
                // Forward-pass breakdown.
                slot.fwd_setup_ms = fwd_setup_time_ms;
                slot.fwd_load_imbalance_ms = fwd_load_imbalance_ms;
                slot.fwd_scheduling_overhead_ms = fwd_scheduling_overhead_ms;
                // LP size samples.
                slot.rows_in_lp_sum = rows_in_lp_sum;
                slot.rows_in_lp_count = rows_in_lp_count;
                slot.rows_in_lp_max = rows_in_lp_max;
            }
            TrainingEvent::ForwardSyncComplete {
                iteration,
                global_ub_std,
                sync_time_ms,
                ..
            } => {
                let slot = by_iteration.entry(iteration).or_default();
                slot.forward_sync_ms = sync_time_ms;
                slot.upper_bound_std = global_ub_std;
            }
            TrainingEvent::ForwardPassComplete {
                iteration,
                scenarios,
                ..
            } => {
                by_iteration.entry(iteration).or_default().forward_passes = scenarios;
            }
            TrainingEvent::BackwardPassComplete {
                iteration,
                rows_generated,
                state_exchange_time_ms,
                row_batch_build_time_ms,
                setup_time_ms,
                load_imbalance_ms,
                scheduling_overhead_ms,
                ..
            } => {
                let slot = by_iteration.entry(iteration).or_default();
                slot.cuts_added = rows_generated;
                slot.bwd_setup_ms = setup_time_ms;
                slot.bwd_load_imbalance_ms = load_imbalance_ms;
                slot.bwd_scheduling_overhead_ms = scheduling_overhead_ms;
                slot.state_exchange_ms = state_exchange_time_ms;
                slot.cut_batch_build_ms = row_batch_build_time_ms;
            }
            TrainingEvent::PolicySyncComplete {
                iteration,
                rows_active,
                rows_removed,
                sync_time_ms,
                ..
            } => {
                // The peak is tracked across all iterations, not per record.
                peak_rows_active = peak_rows_active.max(u64::from(rows_active));
                let slot = by_iteration.entry(iteration).or_default();
                slot.cuts_active = rows_active;
                slot.cuts_removed = rows_removed;
                slot.cut_sync_ms = sync_time_ms;
            }
            TrainingEvent::PolicySelectionComplete {
                iteration,
                rows_deactivated,
                selection_time_ms,
                allgatherv_time_ms,
                ..
            } => {
                let slot = by_iteration.entry(iteration).or_default();
                // Deactivated rows are subtracted from whatever active count is
                // currently stored, clamping at zero.
                slot.cuts_active = slot.cuts_active.saturating_sub(rows_deactivated);
                slot.cut_selection_ms = selection_time_ms;
                slot.cut_selection_allgatherv_ms = allgatherv_time_ms;
            }
            _ => {}
        }
    }

    (by_iteration, peak_rows_active)
}
/// Converts the accumulated data for iteration `iter` into an [`IterationRecord`].
///
/// Derived values:
/// * `gap_percent` is `gap * 100` when the lower bound is strictly positive,
///   otherwise `None`.
/// * `time_overhead_ms` is the iteration time minus the sum of the forward,
///   backward, cut-selection, cut-selection all-gather, forward-sync and
///   lower-bound phases, clamped at zero.
/// * `mean_rows_in_lp` is `rows_in_lp_sum / rows_in_lp_count`, or `0.0` when no
///   samples were recorded.
///
/// `iter` and `lp_solves` are narrowed to `u32` with a plain `as` cast.
fn partial_to_iteration_record(iter: u64, partial: &PartialRecord) -> IterationRecord {
    #[allow(clippy::cast_possible_truncation)]
    let iteration = iter as u32;
    #[allow(clippy::cast_possible_truncation)]
    let lp_solves = partial.lp_solves as u32;

    let gap_percent = (partial.lower_bound > 0.0).then(|| partial.gap * 100.0);

    // Phases counted as "explained" wall time; whatever is left is overhead.
    let attributed_ms = [
        partial.forward_ms,
        partial.backward_ms,
        partial.cut_selection_ms,
        partial.cut_selection_allgatherv_ms,
        partial.forward_sync_ms,
        partial.lower_bound_eval_ms,
    ]
    .iter()
    .fold(0_u64, |total, &phase_ms| total.saturating_add(phase_ms));
    let time_overhead_ms = partial.iteration_time_ms.saturating_sub(attributed_ms);

    #[allow(clippy::cast_precision_loss)]
    let mean_rows_in_lp = match partial.rows_in_lp_count {
        0 => 0.0,
        samples => partial.rows_in_lp_sum as f64 / samples as f64,
    };

    IterationRecord {
        iteration,
        // Bounds.
        lower_bound: partial.lower_bound,
        upper_bound_mean: partial.upper_bound_mean,
        upper_bound_std: partial.upper_bound_std,
        gap_percent,
        // Cut bookkeeping.
        cuts_added: partial.cuts_added,
        cuts_removed: partial.cuts_removed,
        cuts_active: partial.cuts_active,
        // Work counters.
        forward_passes: partial.forward_passes,
        lp_solves,
        solve_time_ms: partial.solve_time_ms,
        mean_rows_in_lp,
        // Coarse timings; the `*_wall_ms` fields mirror the forward/backward values.
        time_total_ms: partial.iteration_time_ms,
        time_forward_ms: partial.forward_ms,
        time_backward_ms: partial.backward_ms,
        time_forward_wall_ms: partial.forward_ms,
        time_backward_wall_ms: partial.backward_ms,
        // Per-phase timings.
        time_cut_selection_ms: partial.cut_selection_ms,
        time_mpi_allreduce_ms: partial.forward_sync_ms,
        time_cut_sync_ms: partial.cut_sync_ms,
        time_lower_bound_ms: partial.lower_bound_eval_ms,
        time_state_exchange_ms: partial.state_exchange_ms,
        time_cut_batch_build_ms: partial.cut_batch_build_ms,
        time_bwd_setup_ms: partial.bwd_setup_ms,
        time_bwd_load_imbalance_ms: partial.bwd_load_imbalance_ms,
        time_bwd_scheduling_overhead_ms: partial.bwd_scheduling_overhead_ms,
        time_fwd_setup_ms: partial.fwd_setup_ms,
        time_fwd_load_imbalance_ms: partial.fwd_load_imbalance_ms,
        time_fwd_scheduling_overhead_ms: partial.fwd_scheduling_overhead_ms,
        time_overhead_ms,
    }
}
/// Assembles the [`TrainingOutput`] for a finished training run.
///
/// # Arguments
///
/// * `result` - final bounds, gap, iteration count, termination reason and total
///   time of the run.
/// * `events` - the training events collected during the run.
/// * `fcf` - the future cost function, queried for generated/active cut totals.
///
/// # Returns
///
/// A `TrainingOutput` where:
///
/// * `convergence_records` has one entry, in ascending iteration order, for each
///   iteration that emitted a `TrainingEvent::IterationSummary`. Iterations seen
///   only through other event variants are dropped.
/// * `cut_stats` combines the totals from `fcf` with the peak active count and
///   the `rows_in_lp` aggregates taken from the events.
/// * `converged` is `true` only when the termination reason equals
///   `RULE_BOUND_STALLING` or `RULE_SIMULATION_BASED`.
/// * `final_gap_percent` is `final_gap * 100` when the final lower bound is
///   strictly positive, otherwise `None`.
/// * `cut_selection_records` has one entry per stage of every
///   `TrainingEvent::PolicySelectionComplete`, in event order.
/// * `training_solve_stats` is left at its default value.
#[must_use]
pub fn build_training_output(
    result: &TrainingResult,
    events: &[TrainingEvent],
    fcf: &FutureCostFunction,
) -> TrainingOutput {
    let (partials, peak_active) = accumulate_partial_records(events);

    // Iterations that produced a summary; only these become convergence rows.
    let summary_iterations: std::collections::BTreeSet<u64> = events
        .iter()
        .filter_map(|event| match event {
            TrainingEvent::IterationSummary { iteration, .. } => Some(*iteration),
            _ => None,
        })
        .collect();

    // Run-wide LP size aggregates. Saturating adds keep this consistent with the
    // other counters in this module and avoid an overflow panic in debug builds
    // (or a silent wrap in release builds).
    let (rows_in_lp_total, rows_in_lp_solve_count, rows_in_lp_max) = partials.values().fold(
        (0_u64, 0_u64, 0_u64),
        |(total, solves, largest), partial| {
            (
                total.saturating_add(partial.rows_in_lp_sum),
                solves.saturating_add(partial.rows_in_lp_count),
                largest.max(partial.rows_in_lp_max),
            )
        },
    );

    let convergence_records: Vec<IterationRecord> = partials
        .into_iter()
        .filter(|(iter, _)| summary_iterations.contains(iter))
        .map(|(iter, partial)| partial_to_iteration_record(iter, &partial))
        .collect();

    // `total_active` and `cuts_active` report the same quantity; query it once.
    let total_active = fcf.total_active_cuts() as u64;
    let cut_stats = RowPoolStatistics {
        total_generated: fcf.total_generated_cuts() as u64,
        total_active,
        peak_active,
        cuts_active: total_active,
        rows_in_lp_total,
        rows_in_lp_solve_count,
        rows_in_lp_max,
    };

    let converged = result.reason == crate::stopping_rule::RULE_BOUND_STALLING
        || result.reason == crate::stopping_rule::RULE_SIMULATION_BASED;

    // A percentage gap is only reported for a strictly positive lower bound.
    let final_gap_percent = if result.final_lb > 0.0 {
        Some(result.final_gap * 100.0)
    } else {
        None
    };

    #[allow(clippy::cast_possible_truncation)]
    let iterations_completed = result.iterations as u32;

    // Flatten every selection event into one record per stage, preserving the
    // order of the events and of the stages within each event.
    #[allow(clippy::cast_possible_truncation)]
    let cut_selection_records: Vec<RowSelectionRecord> = events
        .iter()
        .filter_map(|event| match event {
            TrainingEvent::PolicySelectionComplete {
                iteration,
                per_stage,
                ..
            } => Some((*iteration as u32, per_stage)),
            _ => None,
        })
        .flat_map(|(iteration, per_stage)| {
            per_stage.iter().map(move |rec| RowSelectionRecord {
                iteration,
                stage: rec.stage,
                cuts_populated: rec.rows_populated,
                cuts_active_before: rec.rows_active_before,
                cuts_deactivated: rec.rows_deactivated,
                cuts_reactivated: rec.rows_reactivated,
                cuts_active_after: rec.rows_active_after,
                selection_time_ms: rec.selection_time_ms,
                budget_evicted: rec.budget_evicted,
                active_after_budget: rec.active_after_budget,
            })
        })
        .collect();

    let worker_timing_records = build_worker_timing_records(events, &convergence_records);

    TrainingOutput {
        convergence_records,
        final_lower_bound: result.final_lb,
        final_upper_bound: Some(result.final_ub),
        final_gap_percent,
        final_upper_bound_std: Some(result.final_ub_std),
        iterations_completed,
        converged,
        termination_reason: result.reason.clone(),
        total_time_ms: result.total_time_ms,
        cut_stats,
        cut_selection_records,
        worker_timing_records,
        training_solve_stats: cobre_io::MetadataTrainingSolveStats::default(),
    }
}
/// Builds the worker-timing table from the events and the convergence records.
///
/// The result holds two groups of rows, in this order:
///
/// 1. One row per entry of `convergence_records`, with `rank` 0 and
///    `worker_id` `None`, whose slots carry the iteration-level phase timings.
/// 2. One row per distinct `(iteration, rank, worker_id)` found in
///    `TrainingEvent::WorkerTiming` events, sorted by that key, with
///    `worker_id` `Some(..)`. Repeated events for the same key are summed
///    (saturating) per slot.
///
/// Slots that are not written in a row stay at zero.
fn build_worker_timing_records(
    events: &[TrainingEvent],
    convergence_records: &[IterationRecord],
) -> Vec<cobre_io::WorkerTimingRecord> {
    use cobre_core::{
        WORKER_TIMING_SLOT_BWD_SETUP, WORKER_TIMING_SLOT_BWD_WALL, WORKER_TIMING_SLOT_COUNT,
        WORKER_TIMING_SLOT_FWD_SETUP, WORKER_TIMING_SLOT_FWD_WALL, WORKER_TIMING_SLOT_SCORING,
    };

    // Per-worker slot totals keyed by (iteration, rank, worker id).
    let mut per_worker: BTreeMap<(u32, i32, i32), [u64; WORKER_TIMING_SLOT_COUNT]> =
        BTreeMap::new();

    for event in events {
        if let TrainingEvent::WorkerTiming {
            iteration,
            rank,
            worker_id,
            timings,
            ..
        } = event
        {
            #[allow(clippy::cast_possible_truncation)]
            let key = (*iteration as u32, *rank, *worker_id);
            let slots = per_worker
                .entry(key)
                .or_insert([0_u64; WORKER_TIMING_SLOT_COUNT]);

            // Destination slot paired with the fractional-millisecond sample that
            // feeds it, applied in this order.
            let samples = [
                (WORKER_TIMING_SLOT_FWD_WALL, timings.forward_wall_ms),
                (WORKER_TIMING_SLOT_BWD_WALL, timings.backward_wall_ms),
                (WORKER_TIMING_SLOT_BWD_SETUP, timings.bwd_setup_ms),
                (WORKER_TIMING_SLOT_FWD_SETUP, timings.fwd_setup_ms),
                (WORKER_TIMING_SLOT_SCORING, timings.scoring_ms),
            ];
            #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
            {
                for &(slot, sample_ms) in &samples {
                    // Negative samples count as zero; the rest are rounded to
                    // whole milliseconds before being added.
                    let whole_ms = sample_ms.max(0.0).round() as u64;
                    slots[slot] = slots[slot].saturating_add(whole_ms);
                }
            }
        }
    }

    let mut rows: Vec<cobre_io::WorkerTimingRecord> =
        Vec::with_capacity(convergence_records.len() + per_worker.len());

    // Iteration-level rows come first, in the order of `convergence_records`.
    rows.extend(convergence_records.iter().map(|record| {
        let mut timings = [0_u64; WORKER_TIMING_SLOT_COUNT];
        // NOTE(review): these literal indices presumably mirror the
        // `WORKER_TIMING_SLOT_*` layout in `cobre_core` — confirm. Slots 0, 1, 8
        // and 11 are not written here.
        timings[2] = record.time_cut_selection_ms;
        timings[3] = record.time_mpi_allreduce_ms;
        timings[4] = record.time_cut_sync_ms;
        timings[5] = record.time_lower_bound_ms;
        timings[6] = record.time_state_exchange_ms;
        timings[7] = record.time_cut_batch_build_ms;
        timings[9] = record.time_bwd_load_imbalance_ms;
        timings[10] = record.time_bwd_scheduling_overhead_ms;
        timings[12] = record.time_fwd_load_imbalance_ms;
        timings[13] = record.time_fwd_scheduling_overhead_ms;
        timings[14] = record.time_overhead_ms;
        cobre_io::WorkerTimingRecord {
            iteration: record.iteration,
            rank: 0,
            worker_id: None,
            timings,
        }
    }));

    // Per-worker rows follow, in ascending key order.
    rows.extend(per_worker.into_iter().map(
        |((iteration, rank, worker_id), timings)| cobre_io::WorkerTimingRecord {
            iteration,
            rank,
            worker_id: Some(worker_id),
            timings,
        },
    ));

    rows
}
// Unit tests for `build_training_output`. Each test builds a `TrainingResult`,
// a hand-written slice of `TrainingEvent`s and a `FutureCostFunction`, then
// checks individual fields of the returned output.
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::doc_markdown)]
mod tests {
use cobre_core::TrainingEvent;
use super::build_training_output;
use crate::{FutureCostFunction, TrainingResult};
// Builds a `TrainingResult` from the given reason, bounds, gap and iteration
// count. The remaining constructor arguments are fixed placeholders; the
// `1_000` surfaces as `total_time_ms` (asserted in
// `empty_events_produces_zero_records`).
// NOTE(review): the literal `0.0` is presumably the upper-bound std — confirm
// against `TrainingResult::new`.
fn make_result(reason: &str, lb: f64, ub: f64, gap: f64, iterations: u64) -> TrainingResult {
TrainingResult::new(
lb,
ub,
0.0,
gap,
iterations,
reason.to_string(),
1_000,
Vec::new(),
Vec::new(),
None,
None,
)
}
// Builds an `IterationSummary` event for iteration `iter` with the given bounds
// and gap. Timings are fixed: 100 ms iteration time, 40 ms forward, 50 ms
// backward, 60 LP solves; every other numeric field is zero except
// `wall_time_ms`, which is `iter * 100`.
fn make_iteration_summary(iter: u64, lb: f64, ub: f64, gap: f64) -> TrainingEvent {
TrainingEvent::IterationSummary {
iteration: iter,
lower_bound: lb,
upper_bound: ub,
gap,
wall_time_ms: iter * 100,
iteration_time_ms: 100,
forward_ms: 40,
backward_ms: 50,
lp_solves: 60,
solve_time_ms: 0.0,
lower_bound_eval_ms: 0,
fwd_setup_time_ms: 0,
fwd_load_imbalance_ms: 0,
fwd_scheduling_overhead_ms: 0,
rows_in_lp_sum: 0,
rows_in_lp_count: 0,
rows_in_lp_max: 0,
}
}
// Builds a `FutureCostFunction` to which no cuts have been added.
fn make_empty_fcf() -> FutureCostFunction {
FutureCostFunction::new(2, 1, 4, 10, &[0; 2])
}
// Three `IterationSummary` events must yield three convergence records.
#[test]
fn records_count_matches_iteration_summaries() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.091, 3);
let events = vec![
make_iteration_summary(1, 95.0, 112.0, 0.15),
make_iteration_summary(2, 98.0, 111.0, 0.12),
make_iteration_summary(3, 100.0, 110.0, 0.091),
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert_eq!(output.convergence_records.len(), 3);
}
// Termination reason "bound_stalling" is reported as converged.
#[test]
fn converged_true_for_bound_stalling() {
let result = make_result("bound_stalling", 100.0, 101.0, 0.01, 5);
let events = vec![make_iteration_summary(1, 100.0, 101.0, 0.01)];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert!(output.converged);
}
// Termination reason "simulation_based" is reported as converged.
#[test]
fn converged_true_for_simulation_based() {
let result = make_result("simulation_based", 100.0, 101.0, 0.01, 5);
let events = vec![make_iteration_summary(1, 100.0, 101.0, 0.01)];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert!(output.converged);
}
// Termination reason "iteration_limit" is not reported as converged.
#[test]
fn converged_false_for_iteration_limit() {
let result = make_result("iteration_limit", 90.0, 110.0, 0.2, 100);
let events = vec![make_iteration_summary(1, 90.0, 110.0, 0.2)];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert!(!output.converged);
}
// Five cuts added to the FCF must show up as 5 generated and 5 active cuts.
#[test]
fn cut_stats_from_fcf() {
let result = make_result("iteration_limit", 80.0, 100.0, 0.2, 1);
let events = vec![make_iteration_summary(1, 80.0, 100.0, 0.2)];
let mut fcf = FutureCostFunction::new(2, 1, 4, 10, &[0; 2]);
fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
fcf.add_cut(0, 0, 1, 2.0, &[0.5]);
fcf.add_cut(0, 0, 2, 3.0, &[0.25]);
fcf.add_cut(1, 0, 0, 4.0, &[1.0]);
fcf.add_cut(1, 0, 1, 5.0, &[0.5]);
let output = build_training_output(&result, &events, &fcf);
assert_eq!(
output.cut_stats.total_generated, 5,
"total_generated must equal the true number of cuts added (iteration 0 has no gap)"
);
assert_eq!(
output.cut_stats.total_active, 5,
"total_active must equal active cuts in all pools"
);
}
// Adds three cuts using `1` as the second `add_cut` argument and checks that
// the pools' summed `populated_count` (7) exceeds `total_generated` (3).
// NOTE(review): presumably the second argument is the iteration and earlier
// slots are reserved but left empty — confirm against `FutureCostFunction`.
#[test]
fn total_generated_excludes_reserved_leading_slots() {
let result = make_result("iteration_limit", 80.0, 100.0, 0.2, 1);
let events = vec![make_iteration_summary(1, 80.0, 100.0, 0.2)];
let mut fcf = FutureCostFunction::new(2, 1, 2, 10, &[0; 2]);
fcf.add_cut(0, 1, 0, 1.0, &[1.0]);
fcf.add_cut(0, 1, 1, 2.0, &[1.0]);
fcf.add_cut(1, 1, 0, 3.0, &[1.0]);
let populated: u64 = fcf.pools.iter().map(|p| p.populated_count as u64).sum();
assert_eq!(
populated, 7,
"high-water mark includes the empty leading block"
);
let output = build_training_output(&result, &events, &fcf);
assert_eq!(
output.cut_stats.total_generated, 3,
"total_generated must count the 3 cuts actually added"
);
assert!(
output.cut_stats.total_generated < populated,
"total_generated must exclude reserved-but-empty leading slots"
);
}
// A final lower bound of 0.0 must leave `final_gap_percent` unset.
#[test]
fn gap_percent_none_when_lb_nonpositive() {
let result = make_result("iteration_limit", 0.0, 10.0, 1.0, 1);
let events = vec![make_iteration_summary(1, 0.0, 10.0, 1.0)];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert!(
output.final_gap_percent.is_none(),
"final_gap_percent must be None when final_lb <= 0"
);
}
// Every listed non-convergence reason must produce `converged == false`,
// even with an empty event slice.
#[test]
fn converged_false_for_all_other_reasons() {
let reasons = [
"iteration_limit",
"time_limit",
"graceful_shutdown",
"unknown",
];
let fcf = make_empty_fcf();
for reason in reasons {
let result = make_result(reason, 100.0, 110.0, 0.1, 1);
let output = build_training_output(&result, &[], &fcf);
assert!(
!output.converged,
"converged must be false for reason = {reason}"
);
}
}
// With no events there are no convergence records, while the final bounds and
// total time are still copied from the result.
#[test]
fn empty_events_produces_zero_records() {
let result = make_result("iteration_limit", 50.0, 60.0, 0.2, 0);
let fcf = make_empty_fcf();
let output = build_training_output(&result, &[], &fcf);
assert_eq!(output.convergence_records.len(), 0);
assert_eq!(output.final_lower_bound, 50.0);
assert_eq!(output.final_upper_bound, Some(60.0));
assert_eq!(output.total_time_ms, 1_000);
assert!(!output.converged);
}
// A final gap of 0.02 with a positive lower bound is exported as 2.0 percent.
#[test]
fn gap_percent_computed_correctly() {
let result = make_result("bound_stalling", 100.0, 102.0, 0.02, 3);
let fcf = make_empty_fcf();
let output = build_training_output(&result, &[], &fcf);
assert_eq!(output.final_gap_percent, Some(2.0));
}
// A per-iteration lower bound of 0.0 must leave that record's `gap_percent`
// unset. Only the zero case is exercised here.
#[test]
fn iteration_gap_percent_none_when_lb_zero_or_negative() {
let result = make_result("iteration_limit", 0.0, 10.0, 1.0, 1);
let events = vec![make_iteration_summary(1, 0.0, 10.0, 1.0)];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert!(output.convergence_records[0].gap_percent.is_none());
}
// `upper_bound_std` is taken from `ForwardSyncComplete::global_ub_std`.
#[test]
fn upper_bound_std_from_forward_sync_complete() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 1);
let events = vec![
make_iteration_summary(1, 100.0, 110.0, 0.1),
TrainingEvent::ForwardSyncComplete {
iteration: 1,
global_ub_mean: 110.0,
global_ub_std: 3.5,
sync_time_ms: 5,
},
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert_eq!(output.convergence_records[0].upper_bound_std, 3.5);
}
// `forward_passes` is taken from `ForwardPassComplete::scenarios`.
#[test]
fn forward_passes_from_forward_pass_complete() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 1);
let events = vec![
make_iteration_summary(1, 100.0, 110.0, 0.1),
TrainingEvent::ForwardPassComplete {
iteration: 1,
scenarios: 8,
ub_mean: 110.0,
ub_std: 2.0,
elapsed_ms: 40,
},
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert_eq!(output.convergence_records[0].forward_passes, 8);
}
// `cuts_added` comes from `BackwardPassComplete::rows_generated`; `cuts_removed`
// and `cuts_active` come from `PolicySyncComplete`.
#[test]
fn cut_fields_from_backward_and_sync_events() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 1);
let events = vec![
make_iteration_summary(1, 100.0, 110.0, 0.1),
TrainingEvent::BackwardPassComplete {
iteration: 1,
rows_generated: 12,
stages_processed: 3,
elapsed_ms: 80,
state_exchange_time_ms: 0,
row_batch_build_time_ms: 0,
setup_time_ms: 0,
load_imbalance_ms: 0,
scheduling_overhead_ms: 0,
},
TrainingEvent::PolicySyncComplete {
iteration: 1,
rows_distributed: 12,
rows_active: 24,
rows_removed: 2,
sync_time_ms: 4,
},
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
let rec = &output.convergence_records[0];
assert_eq!(rec.cuts_added, 12);
assert_eq!(rec.cuts_removed, 2);
assert_eq!(rec.cuts_active, 24);
}
// `peak_active` is the maximum `rows_active` over all `PolicySyncComplete`
// events (10, 20, 18 -> 20), not the last value.
#[test]
fn peak_active_tracks_maximum_cuts_active() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 3);
let events = vec![
make_iteration_summary(1, 95.0, 112.0, 0.15),
TrainingEvent::PolicySyncComplete {
iteration: 1,
rows_distributed: 10,
rows_active: 10,
rows_removed: 0,
sync_time_ms: 2,
},
make_iteration_summary(2, 98.0, 111.0, 0.12),
TrainingEvent::PolicySyncComplete {
iteration: 2,
rows_distributed: 10,
rows_active: 20,
rows_removed: 0,
sync_time_ms: 2,
},
make_iteration_summary(3, 100.0, 110.0, 0.1),
TrainingEvent::PolicySyncComplete {
iteration: 3,
rows_distributed: 5,
rows_active: 18, rows_removed: 7,
sync_time_ms: 2,
},
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert_eq!(output.cut_stats.peak_active, 20);
}
// `iterations_completed` is copied from the result, independent of events.
#[test]
fn iterations_completed_from_result() {
let result = make_result("iteration_limit", 80.0, 100.0, 0.2, 42);
let fcf = make_empty_fcf();
let output = build_training_output(&result, &[], &fcf);
assert_eq!(output.iterations_completed, 42);
}
// `termination_reason` is copied verbatim from the result.
#[test]
fn termination_reason_copied_from_result() {
let result = make_result("time_limit", 70.0, 100.0, 0.3, 20);
let fcf = make_empty_fcf();
let output = build_training_output(&result, &[], &fcf);
assert_eq!(output.termination_reason, "time_limit");
}
// Checks which event feeds each per-phase timing field of the record.
#[test]
fn per_phase_timing_captured_from_sync_and_selection_events() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 1);
let events = vec![
TrainingEvent::IterationSummary {
iteration: 1,
lower_bound: 100.0,
upper_bound: 110.0,
gap: 0.1,
wall_time_ms: 120,
iteration_time_ms: 120,
forward_ms: 40,
backward_ms: 50,
lp_solves: 60,
solve_time_ms: 0.0,
lower_bound_eval_ms: 0,
fwd_setup_time_ms: 0,
fwd_load_imbalance_ms: 0,
fwd_scheduling_overhead_ms: 0,
rows_in_lp_sum: 0,
rows_in_lp_count: 0,
rows_in_lp_max: 0,
},
TrainingEvent::ForwardSyncComplete {
iteration: 1,
global_ub_mean: 110.0,
global_ub_std: 2.0,
sync_time_ms: 7,
},
TrainingEvent::PolicySyncComplete {
iteration: 1,
rows_distributed: 10,
rows_active: 10,
rows_removed: 0,
sync_time_ms: 5,
},
TrainingEvent::PolicySelectionComplete {
iteration: 1,
rows_deactivated: 3,
stages_processed: 2,
selection_time_ms: 8,
allgatherv_time_ms: 2,
per_stage: vec![],
},
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
let rec = &output.convergence_records[0];
assert_eq!(
rec.time_forward_wall_ms, 40,
"forward wall must equal forward_ms"
);
assert_eq!(
rec.time_backward_wall_ms, 50,
"backward wall must equal backward_ms"
);
assert_eq!(
rec.time_mpi_allreduce_ms, 7,
"allreduce must come from ForwardSyncComplete"
);
assert_eq!(
rec.time_cut_sync_ms, 5,
"cut_sync must come from PolicySyncComplete"
);
assert_eq!(
rec.time_cut_selection_ms, 8,
"selection must come from PolicySelectionComplete"
);
}
// Overhead = 120 total - (40 forward + 50 backward + 8 selection
// + 2 allgatherv + 7 forward sync + 3 lower-bound eval) = 10.
// The 5 ms cut-sync time is not part of the attributed sum.
#[test]
fn overhead_ms_is_total_minus_attributed_phases() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 1);
let events = vec![
TrainingEvent::IterationSummary {
iteration: 1,
lower_bound: 100.0,
upper_bound: 110.0,
gap: 0.1,
wall_time_ms: 120,
iteration_time_ms: 120,
forward_ms: 40,
backward_ms: 50,
lp_solves: 60,
solve_time_ms: 0.0,
lower_bound_eval_ms: 3,
fwd_setup_time_ms: 0,
fwd_load_imbalance_ms: 0,
fwd_scheduling_overhead_ms: 0,
rows_in_lp_sum: 0,
rows_in_lp_count: 0,
rows_in_lp_max: 0,
},
TrainingEvent::ForwardSyncComplete {
iteration: 1,
global_ub_mean: 110.0,
global_ub_std: 2.0,
sync_time_ms: 7,
},
TrainingEvent::PolicySyncComplete {
iteration: 1,
rows_distributed: 10,
rows_active: 10,
rows_removed: 0,
sync_time_ms: 5,
},
TrainingEvent::PolicySelectionComplete {
iteration: 1,
rows_deactivated: 3,
stages_processed: 2,
selection_time_ms: 8,
allgatherv_time_ms: 2,
per_stage: vec![],
},
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
let rec = &output.convergence_records[0];
assert_eq!(
rec.time_overhead_ms, 10,
"overhead_ms must equal total(120) - attributed(110) = 10"
);
}
// When the attributed phases (50 + 50) exceed the iteration time (10), the
// overhead clamps to 0 instead of underflowing.
#[test]
fn overhead_ms_saturates_at_zero_when_attributed_exceeds_total() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 1);
let events = vec![TrainingEvent::IterationSummary {
iteration: 1,
lower_bound: 100.0,
upper_bound: 110.0,
gap: 0.1,
wall_time_ms: 10,
iteration_time_ms: 10,
forward_ms: 50,
backward_ms: 50,
lp_solves: 5,
solve_time_ms: 0.0,
lower_bound_eval_ms: 0,
fwd_setup_time_ms: 0,
fwd_load_imbalance_ms: 0,
fwd_scheduling_overhead_ms: 0,
rows_in_lp_sum: 0,
rows_in_lp_count: 0,
rows_in_lp_max: 0,
}];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
let rec = &output.convergence_records[0];
assert_eq!(
rec.time_overhead_ms, 0,
"overhead_ms must be 0 when attributed phases exceed total (saturating sub)"
);
}
// A single `PolicySelectionComplete` with two per-stage entries yields two
// selection records, in stage order, tagged with the event's iteration.
#[test]
fn cut_selection_records_extracted_from_events() {
use cobre_core::StageRowSelectionRecord;
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 3);
let events = vec![
make_iteration_summary(1, 95.0, 112.0, 0.15),
make_iteration_summary(2, 98.0, 111.0, 0.12),
make_iteration_summary(3, 100.0, 110.0, 0.1),
TrainingEvent::PolicySelectionComplete {
iteration: 2,
rows_deactivated: 3,
stages_processed: 2,
selection_time_ms: 10,
allgatherv_time_ms: 0,
per_stage: vec![
StageRowSelectionRecord {
stage: 0,
rows_populated: 5,
rows_active_before: 5,
rows_deactivated: 0,
rows_reactivated: 0,
rows_active_after: 5,
selection_time_ms: 0.0,
budget_evicted: None,
active_after_budget: None,
rows_in_lp: 5,
},
StageRowSelectionRecord {
stage: 1,
rows_populated: 8,
rows_active_before: 8,
rows_deactivated: 3,
rows_reactivated: 0,
rows_active_after: 5,
selection_time_ms: 2.5,
budget_evicted: None,
active_after_budget: None,
rows_in_lp: 8,
},
],
},
];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert_eq!(output.cut_selection_records.len(), 2);
assert_eq!(output.cut_selection_records[0].iteration, 2);
assert_eq!(output.cut_selection_records[0].stage, 0);
assert_eq!(output.cut_selection_records[0].cuts_deactivated, 0);
assert_eq!(output.cut_selection_records[1].stage, 1);
assert_eq!(output.cut_selection_records[1].cuts_deactivated, 3);
assert_eq!(output.cut_selection_records[1].cuts_active_after, 5);
}
// Without any `PolicySelectionComplete` event the selection table is empty.
#[test]
fn no_cut_selection_events_produces_empty_records() {
let result = make_result("iteration_limit", 100.0, 110.0, 0.1, 1);
let events = vec![make_iteration_summary(1, 100.0, 110.0, 0.1)];
let fcf = make_empty_fcf();
let output = build_training_output(&result, &events, &fcf);
assert!(output.cut_selection_records.is_empty());
}
}