cobre_core/training_event.rs
1//! Typed event system for iterative optimization training loops and simulation runners.
2//!
3//! This module defines the [`TrainingEvent`] enum and its companion [`StoppingRuleResult`]
4//! struct. Events are emitted at each step of the iterative optimization lifecycle
5//! (forward pass, backward pass, convergence update, etc.) and consumed by runtime
6//! observers: text loggers, JSON-lines writers, TUI renderers, MCP progress
7//! notifications, and Parquet convergence writers.
8//!
9//! ## Design principles
10//!
11//! - **Zero-overhead when unused.** The event channel uses
12//! `Option<std::sync::mpsc::Sender<TrainingEvent>>`: when `None`, no events are
13//! emitted and no allocation occurs. When `Some(sender)`, events are moved into
14//! the channel at each lifecycle step boundary.
15//! - **No per-event timestamps.** Consumers capture wall-clock time upon receipt.
16//! This avoids `clock_gettime` syscall overhead on the hot path. The single
17//! exception is [`TrainingEvent::TrainingStarted::timestamp`], which records the
18//! run-level start time once at entry.
19//! - **Consumer-agnostic.** This module is defined in `cobre-core` (not in the
20//! algorithm crate) so that interface crates (`cobre-cli`, `cobre-tui`,
21//! `cobre-mcp`) can consume events without depending on the algorithm crate.
22//!
23//! ## Event channel pattern
24//!
25//! ```rust
26//! use std::sync::mpsc;
27//! use cobre_core::TrainingEvent;
28//!
29//! let (tx, rx) = mpsc::channel::<TrainingEvent>();
30//! // Pass `Some(tx)` to the training loop; pass `rx` to the consumer thread.
31//! drop(tx);
32//! drop(rx);
33//! ```
34//!
35//! See [`TrainingEvent`] for the full variant catalogue.
36
/// Outcome of evaluating one stopping rule at a given iteration.
///
/// [`TrainingEvent::ConvergenceUpdate`] carries a [`Vec`] of these values, one
/// for each configured stopping rule. Every element names the rule, states
/// whether its condition currently holds, and provides a human-readable
/// description of the rule's state (e.g., `"gap 0.42% <= 1.00%"`).
///
/// Equality is field-wise (`PartialEq`/`Eq` are derived), so consumers and
/// tests can compare results directly instead of comparing `Debug` output.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StoppingRuleResult {
    /// Identifier of the rule, matching the variant name used in the stopping
    /// rules config (e.g., `"gap_tolerance"`, `"bound_stalling"`,
    /// `"iteration_limit"`, `"time_limit"`, `"simulation"`).
    pub rule_name: String,
    /// `true` when the rule's condition is satisfied at the current iteration.
    pub triggered: bool,
    /// Human-readable description of the rule's current state
    /// (e.g., `"gap 0.42% <= 1.00%"`, `"LB stable for 12/10 iterations"`).
    pub detail: String,
}
55
/// Cut selection statistics for a single stage in one iteration.
///
/// One record describes the cut lifecycle at one stage after a selection
/// step: how many cuts existed, how many were active before selection ran,
/// how many selection deactivated, and how many are still active afterwards.
///
/// Equality is field-wise (`PartialEq`/`Eq` are derived), so records can be
/// compared directly in assertions and by consumers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StageSelectionRecord {
    /// Stage index (0-based).
    pub stage: u32,
    /// Total number of cuts ever generated at this stage (high-water mark).
    pub cuts_populated: u32,
    /// Number of active cuts before selection ran.
    pub cuts_active_before: u32,
    /// Number of cuts that selection deactivated at this stage.
    pub cuts_deactivated: u32,
    /// Number of active cuts after selection.
    pub cuts_active_after: u32,
}
74
75/// Typed events emitted by an iterative optimization training loop and
76/// simulation runner.
77///
78/// The enum has 12 variants: 8 per-iteration events (one per lifecycle step)
79/// and 4 lifecycle events (emitted once per training or simulation run).
80///
81/// ## Per-iteration events (steps 1–7 + 4a)
82///
83/// | Step | Variant | When emitted |
84/// |------|--------------------------|--------------------------------------------------------|
85/// | 1 | [`Self::ForwardPassComplete`] | Local forward pass done |
86/// | 2 | [`Self::ForwardSyncComplete`] | Global allreduce of bounds done |
87/// | 3 | [`Self::BackwardPassComplete`] | Backward sweep done |
88/// | 4 | [`Self::CutSyncComplete`] | Cut allgatherv done |
89/// | 4a | [`Self::CutSelectionComplete`] | Cut selection done (conditional on `should_run`) |
90/// | 5 | [`Self::ConvergenceUpdate`] | Stopping rules evaluated |
91/// | 6 | [`Self::CheckpointComplete`] | Checkpoint written (conditional on checkpoint interval)|
92/// | 7 | [`Self::IterationSummary`] | End-of-iteration aggregated summary |
93///
94/// ## Lifecycle events
95///
96/// | Variant | When emitted |
97/// |------------------------------|-------------------------------------|
98/// | [`Self::TrainingStarted`] | Training loop entry |
99/// | [`Self::TrainingFinished`] | Training loop exit |
100/// | [`Self::SimulationProgress`] | Simulation batch completion |
101/// | [`Self::SimulationFinished`] | Simulation completion |
102#[derive(Clone, Debug)]
103pub enum TrainingEvent {
104 // ── Per-iteration events (8) ─────────────────────────────────────────────
105 /// Step 1: Forward pass completed for this iteration on the local rank.
106 ForwardPassComplete {
107 /// Iteration number (1-based).
108 iteration: u64,
109 /// Number of forward scenarios evaluated on this rank.
110 scenarios: u32,
111 /// Mean total forward cost across local scenarios.
112 ub_mean: f64,
113 /// Standard deviation of total forward cost across local scenarios.
114 ub_std: f64,
115 /// Wall-clock time for the forward pass on this rank, in milliseconds.
116 elapsed_ms: u64,
117 },
118
119 /// Step 2: Forward synchronization (allreduce) completed.
120 ///
121 /// Emitted after the global reduction of local bound estimates across all
122 /// participating ranks.
123 ForwardSyncComplete {
124 /// Iteration number (1-based).
125 iteration: u64,
126 /// Global upper bound mean after allreduce.
127 global_ub_mean: f64,
128 /// Global upper bound standard deviation after allreduce.
129 global_ub_std: f64,
130 /// Wall-clock time for the synchronization, in milliseconds.
131 sync_time_ms: u64,
132 },
133
134 /// Step 3: Backward pass completed for this iteration.
135 ///
136 /// Emitted after the full backward sweep that generates new cuts for each
137 /// stage.
138 BackwardPassComplete {
139 /// Iteration number (1-based).
140 iteration: u64,
141 /// Number of new cuts generated across all stages.
142 cuts_generated: u32,
143 /// Number of stages processed in the backward sweep.
144 stages_processed: u32,
145 /// Wall-clock time for the backward pass, in milliseconds.
146 elapsed_ms: u64,
147 /// Wall-clock time for state exchange (`allgatherv`) across all stages,
148 /// in milliseconds.
149 state_exchange_time_ms: u64,
150 /// Wall-clock time for cut batch assembly (`build_cut_row_batch_into`)
151 /// across all stages, in milliseconds.
152 cut_batch_build_time_ms: u64,
153 /// Estimated rayon barrier + scheduling overhead across all stages,
154 /// in milliseconds.
155 rayon_overhead_time_ms: u64,
156 },
157
158 /// Step 4: Cut synchronization (allgatherv) completed.
159 ///
160 /// Emitted after new cuts from all ranks have been gathered and distributed
161 /// to every rank via allgatherv.
162 CutSyncComplete {
163 /// Iteration number (1-based).
164 iteration: u64,
165 /// Number of cuts distributed to all ranks via allgatherv.
166 cuts_distributed: u32,
167 /// Total number of active cuts in the approximation after synchronization.
168 cuts_active: u32,
169 /// Number of cuts removed during synchronization.
170 cuts_removed: u32,
171 /// Wall-clock time for the synchronization, in milliseconds.
172 sync_time_ms: u64,
173 },
174
175 /// Step 4a: Cut selection completed.
176 ///
177 /// Only emitted on iterations where cut selection runs (i.e., when
178 /// `should_run(iteration)` returns `true`). On non-selection iterations
179 /// this variant is skipped entirely.
180 CutSelectionComplete {
181 /// Iteration number (1-based).
182 iteration: u64,
183 /// Number of cuts deactivated across all stages.
184 cuts_deactivated: u32,
185 /// Number of stages processed during cut selection.
186 stages_processed: u32,
187 /// Wall-clock time for the local cut selection phase, in milliseconds.
188 selection_time_ms: u64,
189 /// Wall-clock time for the allgatherv deactivation-set exchange, in
190 /// milliseconds.
191 allgatherv_time_ms: u64,
192 /// Per-stage breakdown of selection results.
193 per_stage: Vec<StageSelectionRecord>,
194 },
195
196 /// Step 5: Convergence check completed.
197 ///
198 /// Emitted after all configured stopping rules have been evaluated for the
199 /// current iteration. Contains the current bounds, gap, and per-rule results.
200 ConvergenceUpdate {
201 /// Iteration number (1-based).
202 iteration: u64,
203 /// Current lower bound (non-decreasing across iterations).
204 lower_bound: f64,
205 /// Current upper bound (statistical estimate from forward costs).
206 upper_bound: f64,
207 /// Standard deviation of the upper bound estimate.
208 upper_bound_std: f64,
209 /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
210 gap: f64,
211 /// Evaluation result for each configured stopping rule.
212 rules_evaluated: Vec<StoppingRuleResult>,
213 },
214
215 /// Step 6: Checkpoint written.
216 ///
217 /// Only emitted when the checkpoint interval triggers (i.e., when
218 /// `iteration % checkpoint_interval == 0`). Not emitted on every iteration.
219 CheckpointComplete {
220 /// Iteration number (1-based).
221 iteration: u64,
222 /// Filesystem path where the checkpoint was written.
223 checkpoint_path: String,
224 /// Wall-clock time for the checkpoint write, in milliseconds.
225 elapsed_ms: u64,
226 },
227
228 /// Step 7: Full iteration summary with aggregated timings.
229 ///
230 /// Emitted at the end of every iteration as the final per-iteration event.
231 /// Contains all timing breakdowns for the completed iteration.
232 IterationSummary {
233 /// Iteration number (1-based).
234 iteration: u64,
235 /// Current lower bound.
236 lower_bound: f64,
237 /// Current upper bound.
238 upper_bound: f64,
239 /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
240 gap: f64,
241 /// Cumulative wall-clock time since training started, in milliseconds.
242 wall_time_ms: u64,
243 /// Wall-clock time for this iteration only, in milliseconds.
244 iteration_time_ms: u64,
245 /// Forward pass time for this iteration, in milliseconds.
246 forward_ms: u64,
247 /// Backward pass time for this iteration, in milliseconds.
248 backward_ms: u64,
249 /// Total number of LP solves in this iteration (forward + backward stages).
250 lp_solves: u64,
251 /// Cumulative LP solve wall-clock time for this iteration, in milliseconds.
252 solve_time_ms: f64,
253 },
254
255 // ── Lifecycle events (4) ─────────────────────────────────────────────────
256 /// Emitted once when the training loop begins.
257 ///
258 /// Carries run-level metadata describing the problem size and parallelism
259 /// configuration for this training run.
260 TrainingStarted {
261 /// Case study name from the input data directory.
262 case_name: String,
263 /// Total number of stages in the optimization horizon.
264 stages: u32,
265 /// Number of hydro plants in the system.
266 hydros: u32,
267 /// Number of thermal plants in the system.
268 thermals: u32,
269 /// Number of distributed ranks participating in training.
270 ranks: u32,
271 /// Number of threads per rank.
272 threads_per_rank: u32,
273 /// Wall-clock time at training start as an ISO 8601 string
274 /// (run-level metadata, not a per-event timestamp).
275 timestamp: String,
276 },
277
278 /// Emitted once when the training loop exits (converged or limit reached).
279 TrainingFinished {
280 /// Termination reason (e.g., `"gap_tolerance"`, `"iteration_limit"`,
281 /// `"time_limit"`).
282 reason: String,
283 /// Total number of iterations completed.
284 iterations: u64,
285 /// Final lower bound at termination.
286 final_lb: f64,
287 /// Final upper bound at termination.
288 final_ub: f64,
289 /// Total wall-clock time for the training run, in milliseconds.
290 total_time_ms: u64,
291 /// Total number of cuts in the approximation at termination.
292 total_cuts: u64,
293 },
294
295 /// Emitted periodically during policy simulation (not during training).
296 ///
297 /// Consumers can use this to display a progress indicator during the
298 /// simulation phase. Each event carries the cost of the most recently
299 /// completed scenario; the progress thread accumulates statistics across
300 /// events (see ticket-007).
301 SimulationProgress {
302 /// Number of simulation scenarios completed so far.
303 scenarios_complete: u32,
304 /// Total number of simulation scenarios to run.
305 scenarios_total: u32,
306 /// Wall-clock time since simulation started, in milliseconds.
307 elapsed_ms: u64,
308 /// Total cost of the most recently completed simulation scenario,
309 /// in cost units.
310 scenario_cost: f64,
311 /// Cumulative LP solve time for this scenario, in milliseconds.
312 solve_time_ms: f64,
313 /// Number of LP solves in this scenario.
314 lp_solves: u64,
315 },
316
317 /// Emitted once when policy simulation completes.
318 SimulationFinished {
319 /// Total number of simulation scenarios evaluated.
320 scenarios: u32,
321 /// Directory where simulation output files were written.
322 output_dir: String,
323 /// Total wall-clock time for the simulation run, in milliseconds.
324 elapsed_ms: u64,
325 },
326}
327
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::mem::discriminant;

    use super::{StageSelectionRecord, StoppingRuleResult, TrainingEvent};

    /// Builds exactly one instance of every `TrainingEvent` variant, filled
    /// with representative values.
    fn make_all_variants() -> Vec<TrainingEvent> {
        vec![
            TrainingEvent::ForwardPassComplete {
                iteration: 1,
                scenarios: 10,
                ub_mean: 110.0,
                ub_std: 5.0,
                elapsed_ms: 42,
            },
            TrainingEvent::ForwardSyncComplete {
                iteration: 1,
                global_ub_mean: 110.0,
                global_ub_std: 5.0,
                sync_time_ms: 3,
            },
            TrainingEvent::BackwardPassComplete {
                iteration: 1,
                cuts_generated: 48,
                stages_processed: 12,
                elapsed_ms: 87,
                state_exchange_time_ms: 0,
                cut_batch_build_time_ms: 0,
                rayon_overhead_time_ms: 0,
            },
            TrainingEvent::CutSyncComplete {
                iteration: 1,
                cuts_distributed: 48,
                cuts_active: 200,
                cuts_removed: 0,
                sync_time_ms: 2,
            },
            TrainingEvent::CutSelectionComplete {
                iteration: 10,
                cuts_deactivated: 15,
                stages_processed: 12,
                selection_time_ms: 20,
                allgatherv_time_ms: 1,
                per_stage: vec![],
            },
            TrainingEvent::ConvergenceUpdate {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                upper_bound_std: 5.0,
                gap: 0.0909,
                rules_evaluated: vec![StoppingRuleResult {
                    rule_name: "gap_tolerance".to_string(),
                    triggered: false,
                    detail: "gap 9.09% > 1.00%".to_string(),
                }],
            },
            TrainingEvent::CheckpointComplete {
                iteration: 5,
                checkpoint_path: "/tmp/checkpoint.bin".to_string(),
                elapsed_ms: 150,
            },
            TrainingEvent::IterationSummary {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                gap: 0.0909,
                wall_time_ms: 1000,
                iteration_time_ms: 200,
                forward_ms: 80,
                backward_ms: 100,
                lp_solves: 240,
                solve_time_ms: 45.2,
            },
            TrainingEvent::TrainingStarted {
                case_name: "test_case".to_string(),
                stages: 60,
                hydros: 5,
                thermals: 10,
                ranks: 4,
                threads_per_rank: 8,
                timestamp: "2026-01-01T00:00:00Z".to_string(),
            },
            TrainingEvent::TrainingFinished {
                reason: "gap_tolerance".to_string(),
                iterations: 50,
                final_lb: 105.0,
                final_ub: 106.0,
                total_time_ms: 300_000,
                total_cuts: 2400,
            },
            TrainingEvent::SimulationProgress {
                scenarios_complete: 50,
                scenarios_total: 200,
                elapsed_ms: 5_000,
                scenario_cost: 45_230.0,
                solve_time_ms: 0.0,
                lp_solves: 0,
            },
            TrainingEvent::SimulationFinished {
                scenarios: 200,
                output_dir: "/tmp/output".to_string(),
                elapsed_ms: 20_000,
            },
        ]
    }

    #[test]
    fn all_twelve_variants_construct() {
        let variants = make_all_variants();
        assert_eq!(
            variants.len(),
            12,
            "expected exactly 12 TrainingEvent variants"
        );
        // A bare length check would still pass if the helper listed one
        // variant twice, so also require twelve distinct discriminants.
        let distinct: HashSet<_> = variants.iter().map(discriminant).collect();
        assert_eq!(
            distinct.len(),
            12,
            "helper must build each TrainingEvent variant exactly once"
        );
    }

    #[test]
    fn all_variants_clone() {
        for variant in make_all_variants() {
            let cloned = variant.clone();
            // Compare the Debug renderings, which include every field, so the
            // check does not require `PartialEq` on the event type.
            assert_eq!(
                format!("{cloned:?}"),
                format!("{variant:?}"),
                "clone must preserve every field"
            );
        }
    }

    #[test]
    fn all_variants_debug_non_empty() {
        for variant in make_all_variants() {
            let debug = format!("{variant:?}");
            assert!(!debug.is_empty(), "debug output must not be empty");
        }
    }

    #[test]
    fn forward_pass_complete_fields_accessible() {
        let event = TrainingEvent::ForwardPassComplete {
            iteration: 7,
            scenarios: 20,
            ub_mean: 210.0,
            ub_std: 3.5,
            elapsed_ms: 55,
        };
        let TrainingEvent::ForwardPassComplete {
            iteration,
            scenarios,
            ub_mean,
            ub_std,
            elapsed_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 7);
        assert_eq!(scenarios, 20);
        assert!((ub_mean - 210.0).abs() < f64::EPSILON);
        assert!((ub_std - 3.5).abs() < f64::EPSILON);
        assert_eq!(elapsed_ms, 55);
    }

    #[test]
    fn convergence_update_rules_evaluated_field() {
        let rules = vec![
            StoppingRuleResult {
                rule_name: "gap_tolerance".to_string(),
                triggered: true,
                detail: "gap 0.42% <= 1.00%".to_string(),
            },
            StoppingRuleResult {
                rule_name: "iteration_limit".to_string(),
                triggered: false,
                detail: "iteration 10/100".to_string(),
            },
        ];
        let event = TrainingEvent::ConvergenceUpdate {
            iteration: 10,
            lower_bound: 99.0,
            upper_bound: 100.0,
            upper_bound_std: 0.5,
            gap: 0.0042,
            rules_evaluated: rules,
        };
        let TrainingEvent::ConvergenceUpdate {
            rules_evaluated, ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(rules_evaluated.len(), 2);
        assert_eq!(rules_evaluated[0].rule_name, "gap_tolerance");
        assert!(rules_evaluated[0].triggered);
        assert_eq!(rules_evaluated[1].rule_name, "iteration_limit");
        assert!(!rules_evaluated[1].triggered);
    }

    #[test]
    fn stopping_rule_result_fields_accessible() {
        let r = StoppingRuleResult {
            rule_name: "bound_stalling".to_string(),
            triggered: false,
            detail: "LB stable for 8/10 iterations".to_string(),
        };
        let cloned = r.clone();
        assert_eq!(cloned.rule_name, "bound_stalling");
        assert!(!cloned.triggered);
        assert_eq!(cloned.detail, "LB stable for 8/10 iterations");
    }

    #[test]
    fn stopping_rule_result_debug_non_empty() {
        let r = StoppingRuleResult {
            rule_name: "time_limit".to_string(),
            triggered: true,
            detail: "elapsed 3602s > 3600s limit".to_string(),
        };
        let debug = format!("{r:?}");
        assert!(!debug.is_empty());
        assert!(debug.contains("time_limit"));
    }

    #[test]
    fn cut_selection_complete_fields_accessible() {
        let event = TrainingEvent::CutSelectionComplete {
            iteration: 10,
            cuts_deactivated: 30,
            stages_processed: 12,
            selection_time_ms: 25,
            allgatherv_time_ms: 2,
            per_stage: vec![],
        };
        let TrainingEvent::CutSelectionComplete {
            iteration,
            cuts_deactivated,
            stages_processed,
            selection_time_ms,
            allgatherv_time_ms,
            per_stage,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 10);
        assert_eq!(cuts_deactivated, 30);
        assert_eq!(stages_processed, 12);
        assert_eq!(selection_time_ms, 25);
        assert_eq!(allgatherv_time_ms, 2);
        assert!(per_stage.is_empty());
    }

    #[test]
    fn cut_selection_complete_per_stage_records_carried() {
        // The other tests only use an empty `per_stage`; this one checks that
        // a populated per-stage record comes back out unchanged.
        let event = TrainingEvent::CutSelectionComplete {
            iteration: 10,
            cuts_deactivated: 5,
            stages_processed: 1,
            selection_time_ms: 4,
            allgatherv_time_ms: 1,
            per_stage: vec![StageSelectionRecord {
                stage: 0,
                cuts_populated: 40,
                cuts_active_before: 30,
                cuts_deactivated: 5,
                cuts_active_after: 25,
            }],
        };
        let TrainingEvent::CutSelectionComplete { per_stage, .. } = event else {
            panic!("wrong variant")
        };
        assert_eq!(per_stage.len(), 1);
        let record = &per_stage[0];
        assert_eq!(record.stage, 0);
        assert_eq!(record.cuts_populated, 40);
        assert_eq!(record.cuts_active_before, 30);
        assert_eq!(record.cuts_deactivated, 5);
        assert_eq!(record.cuts_active_after, 25);
    }

    #[test]
    fn training_started_timestamp_field() {
        let event = TrainingEvent::TrainingStarted {
            case_name: "hydro_sys".to_string(),
            stages: 120,
            hydros: 10,
            thermals: 20,
            ranks: 8,
            threads_per_rank: 4,
            timestamp: "2026-03-01T08:00:00Z".to_string(),
        };
        let TrainingEvent::TrainingStarted { timestamp, .. } = event else {
            panic!("wrong variant")
        };
        assert_eq!(timestamp, "2026-03-01T08:00:00Z");
    }

    #[test]
    fn simulation_progress_scenario_cost_field_accessible() {
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 100,
            scenarios_total: 500,
            elapsed_ms: 10_000,
            scenario_cost: 45_230.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress {
            scenarios_complete,
            scenarios_total,
            elapsed_ms,
            scenario_cost,
            solve_time_ms,
            lp_solves,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(scenarios_complete, 100);
        assert_eq!(scenarios_total, 500);
        assert_eq!(elapsed_ms, 10_000);
        assert!((scenario_cost - 45_230.0).abs() < f64::EPSILON);
        assert!(solve_time_ms.abs() < f64::EPSILON);
        assert_eq!(lp_solves, 0);
    }

    #[test]
    fn simulation_progress_first_scenario_cost_carried() {
        // The first scenario's cost is emitted directly — no aggregation needed.
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 1,
            scenarios_total: 200,
            elapsed_ms: 100,
            scenario_cost: 50_000.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress { scenario_cost, .. } = event else {
            panic!("wrong variant")
        };
        assert!((scenario_cost - 50_000.0).abs() < f64::EPSILON);
    }
}