// cobre_core/training_event.rs
1//! Typed event system for iterative optimization training loops and simulation runners.
2//!
3//! This module defines the [`TrainingEvent`] enum and its companion [`StoppingRuleResult`]
4//! struct. Events are emitted at each step of the iterative optimization lifecycle
5//! (forward pass, backward pass, convergence update, etc.) and consumed by runtime
6//! observers: text loggers, JSON-lines writers, TUI renderers, MCP progress
7//! notifications, and Parquet convergence writers.
8//!
9//! ## Design principles
10//!
11//! - **Zero-overhead when unused.** The event channel uses
12//! `Option<std::sync::mpsc::Sender<TrainingEvent>>`: when `None`, no events are
13//! emitted and no allocation occurs. When `Some(sender)`, events are moved into
14//! the channel at each lifecycle step boundary.
15//! - **No per-event timestamps.** Consumers capture wall-clock time upon receipt.
16//! This avoids `clock_gettime` syscall overhead on the hot path. The single
//!   exception is the `timestamp` field of [`TrainingEvent::TrainingStarted`],
//!   which records the run-level start time once at entry.
19//! - **Consumer-agnostic.** This module is defined in `cobre-core` (not in the
20//! algorithm crate) so that interface crates (`cobre-cli`, `cobre-tui`,
21//! `cobre-mcp`) can consume events without depending on the algorithm crate.
22//!
23//! ## Event channel pattern
24//!
25//! ```rust
26//! use std::sync::mpsc;
27//! use cobre_core::TrainingEvent;
28//!
29//! let (tx, rx) = mpsc::channel::<TrainingEvent>();
30//! // Pass `Some(tx)` to the training loop; pass `rx` to the consumer thread.
31//! drop(tx);
32//! drop(rx);
33//! ```
34//!
35//! See [`TrainingEvent`] for the full variant catalogue.
36
/// Result of evaluating a single stopping rule at a given iteration.
///
/// The [`TrainingEvent::ConvergenceUpdate`] variant carries a [`Vec`] of these,
/// one per configured stopping rule. Each element reports the rule's identifier,
/// whether its condition is satisfied, and a human-readable description of the
/// current state (e.g., `"gap 0.42% <= 1.00%"`).
///
/// Derives `PartialEq`/`Eq` (all fields are `String`/`bool`) so consumers and
/// tests can compare results structurally instead of via `Debug` output.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StoppingRuleResult {
    /// Rule identifier matching the variant name in the stopping rules config
    /// (e.g., `"gap_tolerance"`, `"bound_stalling"`, `"iteration_limit"`,
    /// `"time_limit"`, `"simulation"`).
    pub rule_name: String,
    /// Whether this rule's condition is satisfied at the current iteration.
    pub triggered: bool,
    /// Human-readable description of the rule's current state
    /// (e.g., `"gap 0.42% <= 1.00%"`, `"LB stable for 12/10 iterations"`).
    pub detail: String,
}
55
/// Typed events emitted by an iterative optimization training loop and
/// simulation runner.
///
/// The enum has 12 variants: 8 per-iteration events (one per lifecycle step)
/// and 4 lifecycle events (emitted once per training or simulation run).
/// All `*_ms` fields are wall-clock durations in milliseconds.
///
/// ## Per-iteration events (steps 1–7 + 4a)
///
/// | Step | Variant | When emitted |
/// |------|--------------------------|--------------------------------------------------------|
/// | 1 | [`Self::ForwardPassComplete`] | Local forward pass done |
/// | 2 | [`Self::ForwardSyncComplete`] | Global allreduce of bounds done |
/// | 3 | [`Self::BackwardPassComplete`] | Backward sweep done |
/// | 4 | [`Self::CutSyncComplete`] | Cut allgatherv done |
/// | 4a | [`Self::CutSelectionComplete`] | Cut selection done (conditional on `should_run`) |
/// | 5 | [`Self::ConvergenceUpdate`] | Stopping rules evaluated |
/// | 6 | [`Self::CheckpointComplete`] | Checkpoint written (conditional on checkpoint interval)|
/// | 7 | [`Self::IterationSummary`] | End-of-iteration aggregated summary |
///
/// ## Lifecycle events
///
/// | Variant | When emitted |
/// |------------------------------|-------------------------------------|
/// | [`Self::TrainingStarted`] | Training loop entry |
/// | [`Self::TrainingFinished`] | Training loop exit |
/// | [`Self::SimulationProgress`] | Simulation batch completion |
/// | [`Self::SimulationFinished`] | Simulation completion |
#[derive(Clone, Debug)]
pub enum TrainingEvent {
    // ── Per-iteration events (8) ─────────────────────────────────────────────
    /// Step 1: Forward pass completed for this iteration on the local rank.
    ForwardPassComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of forward scenarios evaluated on this rank.
        scenarios: u32,
        /// Mean total forward cost across local scenarios.
        ub_mean: f64,
        /// Standard deviation of total forward cost across local scenarios.
        ub_std: f64,
        /// Wall-clock time for the forward pass on this rank, in milliseconds.
        elapsed_ms: u64,
    },

    /// Step 2: Forward synchronization (allreduce) completed.
    ///
    /// Emitted after the global reduction of local bound estimates across all
    /// participating ranks.
    ForwardSyncComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Global upper bound mean after allreduce.
        global_ub_mean: f64,
        /// Global upper bound standard deviation after allreduce.
        global_ub_std: f64,
        /// Wall-clock time for the synchronization, in milliseconds.
        sync_time_ms: u64,
    },

    /// Step 3: Backward pass completed for this iteration.
    ///
    /// Emitted after the full backward sweep that generates new cuts for each
    /// stage.
    BackwardPassComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of new cuts generated across all stages.
        cuts_generated: u32,
        /// Number of stages processed in the backward sweep.
        stages_processed: u32,
        /// Wall-clock time for the backward pass, in milliseconds.
        elapsed_ms: u64,
    },

    /// Step 4: Cut synchronization (allgatherv) completed.
    ///
    /// Emitted after new cuts from all ranks have been gathered and distributed
    /// to every rank via allgatherv.
    CutSyncComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of cuts distributed to all ranks via allgatherv.
        cuts_distributed: u32,
        /// Total number of active cuts in the approximation after synchronization.
        cuts_active: u32,
        /// Number of cuts removed during synchronization.
        cuts_removed: u32,
        /// Wall-clock time for the synchronization, in milliseconds.
        sync_time_ms: u64,
    },

    /// Step 4a: Cut selection completed.
    ///
    /// Only emitted on iterations where cut selection runs (i.e., when
    /// `should_run(iteration)` returns `true`). On non-selection iterations
    /// this variant is skipped entirely.
    CutSelectionComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of cuts deactivated across all stages.
        cuts_deactivated: u32,
        /// Number of stages processed during cut selection.
        stages_processed: u32,
        /// Wall-clock time for the local cut selection phase, in milliseconds.
        selection_time_ms: u64,
        /// Wall-clock time for the allgatherv deactivation-set exchange, in
        /// milliseconds.
        allgatherv_time_ms: u64,
    },

    /// Step 5: Convergence check completed.
    ///
    /// Emitted after all configured stopping rules have been evaluated for the
    /// current iteration. Contains the current bounds, gap, and per-rule results.
    ConvergenceUpdate {
        /// Iteration number (1-based).
        iteration: u64,
        /// Current lower bound (non-decreasing across iterations).
        lower_bound: f64,
        /// Current upper bound (statistical estimate from forward costs).
        upper_bound: f64,
        /// Standard deviation of the upper bound estimate.
        upper_bound_std: f64,
        /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
        /// By this formula the value can be negative when the statistical upper
        /// bound dips below the lower bound.
        gap: f64,
        /// Evaluation result for each configured stopping rule.
        rules_evaluated: Vec<StoppingRuleResult>,
    },

    /// Step 6: Checkpoint written.
    ///
    /// Only emitted when the checkpoint interval triggers (i.e., when
    /// `iteration % checkpoint_interval == 0`). Not emitted on every iteration.
    CheckpointComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Filesystem path where the checkpoint was written.
        checkpoint_path: String,
        /// Wall-clock time for the checkpoint write, in milliseconds.
        elapsed_ms: u64,
    },

    /// Step 7: Full iteration summary with aggregated timings.
    ///
    /// Emitted at the end of every iteration as the final per-iteration event.
    /// Contains all timing breakdowns for the completed iteration.
    IterationSummary {
        /// Iteration number (1-based).
        iteration: u64,
        /// Current lower bound.
        lower_bound: f64,
        /// Current upper bound.
        upper_bound: f64,
        /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
        /// By this formula the value can be negative when the statistical upper
        /// bound dips below the lower bound.
        gap: f64,
        /// Cumulative wall-clock time since training started, in milliseconds.
        wall_time_ms: u64,
        /// Wall-clock time for this iteration only, in milliseconds.
        iteration_time_ms: u64,
        /// Forward pass time for this iteration, in milliseconds.
        forward_ms: u64,
        /// Backward pass time for this iteration, in milliseconds.
        backward_ms: u64,
        /// Total number of LP solves in this iteration (forward + backward stages).
        lp_solves: u64,
    },

    // ── Lifecycle events (4) ─────────────────────────────────────────────────
    /// Emitted once when the training loop begins.
    ///
    /// Carries run-level metadata describing the problem size and parallelism
    /// configuration for this training run.
    TrainingStarted {
        /// Case study name from the input data directory.
        case_name: String,
        /// Total number of stages in the optimization horizon.
        stages: u32,
        /// Number of hydro plants in the system.
        hydros: u32,
        /// Number of thermal plants in the system.
        thermals: u32,
        /// Number of distributed ranks participating in training.
        ranks: u32,
        /// Number of threads per rank.
        threads_per_rank: u32,
        /// Wall-clock time at training start as an ISO 8601 string
        /// (run-level metadata, not a per-event timestamp; see module docs).
        timestamp: String,
    },

    /// Emitted once when the training loop exits (converged or limit reached).
    TrainingFinished {
        /// Termination reason (e.g., `"gap_tolerance"`, `"iteration_limit"`,
        /// `"time_limit"`).
        reason: String,
        /// Total number of iterations completed.
        iterations: u64,
        /// Final lower bound at termination.
        final_lb: f64,
        /// Final upper bound at termination.
        final_ub: f64,
        /// Total wall-clock time for the training run, in milliseconds.
        total_time_ms: u64,
        /// Total number of cuts in the approximation at termination.
        total_cuts: u64,
    },

    /// Emitted periodically during policy simulation (not during training).
    ///
    /// Consumers can use this to display a progress indicator during the
    /// simulation phase, including live cost statistics as scenarios complete.
    SimulationProgress {
        /// Number of simulation scenarios completed so far.
        scenarios_complete: u32,
        /// Total number of simulation scenarios to run.
        scenarios_total: u32,
        /// Wall-clock time since simulation started, in milliseconds.
        elapsed_ms: u64,
        /// Running mean of total scenario costs computed so far, in cost units.
        ///
        /// Populated by the emitter after each batch of completed scenarios.
        mean_cost: f64,
        /// Running standard deviation of total scenario costs.
        ///
        /// Set to `0.0` when `scenarios_complete < 2` (insufficient data for
        /// variance estimation).
        std_cost: f64,
        /// Half-width of the 95% confidence interval, in cost units.
        ///
        /// Computed as `1.96 * std_cost / sqrt(scenarios_complete)`.
        /// Set to `0.0` when `scenarios_complete < 2`.
        ci_95_half_width: f64,
    },

    /// Emitted once when policy simulation completes.
    SimulationFinished {
        /// Total number of simulation scenarios evaluated.
        scenarios: u32,
        /// Directory where simulation output files were written.
        output_dir: String,
        /// Total wall-clock time for the simulation run, in milliseconds.
        elapsed_ms: u64,
    },
}
300
301#[cfg(test)]
302mod tests {
303 use super::{StoppingRuleResult, TrainingEvent};
304
305 // Helper: build one of each variant with representative values.
306 fn make_all_variants() -> Vec<TrainingEvent> {
307 vec![
308 TrainingEvent::ForwardPassComplete {
309 iteration: 1,
310 scenarios: 10,
311 ub_mean: 110.0,
312 ub_std: 5.0,
313 elapsed_ms: 42,
314 },
315 TrainingEvent::ForwardSyncComplete {
316 iteration: 1,
317 global_ub_mean: 110.0,
318 global_ub_std: 5.0,
319 sync_time_ms: 3,
320 },
321 TrainingEvent::BackwardPassComplete {
322 iteration: 1,
323 cuts_generated: 48,
324 stages_processed: 12,
325 elapsed_ms: 87,
326 },
327 TrainingEvent::CutSyncComplete {
328 iteration: 1,
329 cuts_distributed: 48,
330 cuts_active: 200,
331 cuts_removed: 0,
332 sync_time_ms: 2,
333 },
334 TrainingEvent::CutSelectionComplete {
335 iteration: 10,
336 cuts_deactivated: 15,
337 stages_processed: 12,
338 selection_time_ms: 20,
339 allgatherv_time_ms: 1,
340 },
341 TrainingEvent::ConvergenceUpdate {
342 iteration: 1,
343 lower_bound: 100.0,
344 upper_bound: 110.0,
345 upper_bound_std: 5.0,
346 gap: 0.0909,
347 rules_evaluated: vec![StoppingRuleResult {
348 rule_name: "gap_tolerance".to_string(),
349 triggered: false,
350 detail: "gap 9.09% > 1.00%".to_string(),
351 }],
352 },
353 TrainingEvent::CheckpointComplete {
354 iteration: 5,
355 checkpoint_path: "/tmp/checkpoint.bin".to_string(),
356 elapsed_ms: 150,
357 },
358 TrainingEvent::IterationSummary {
359 iteration: 1,
360 lower_bound: 100.0,
361 upper_bound: 110.0,
362 gap: 0.0909,
363 wall_time_ms: 1000,
364 iteration_time_ms: 200,
365 forward_ms: 80,
366 backward_ms: 100,
367 lp_solves: 240,
368 },
369 TrainingEvent::TrainingStarted {
370 case_name: "test_case".to_string(),
371 stages: 60,
372 hydros: 5,
373 thermals: 10,
374 ranks: 4,
375 threads_per_rank: 8,
376 timestamp: "2026-01-01T00:00:00Z".to_string(),
377 },
378 TrainingEvent::TrainingFinished {
379 reason: "gap_tolerance".to_string(),
380 iterations: 50,
381 final_lb: 105.0,
382 final_ub: 106.0,
383 total_time_ms: 300_000,
384 total_cuts: 2400,
385 },
386 TrainingEvent::SimulationProgress {
387 scenarios_complete: 50,
388 scenarios_total: 200,
389 elapsed_ms: 5_000,
390 mean_cost: 45_230.0,
391 std_cost: 3_100.0,
392 ci_95_half_width: 430.0,
393 },
394 TrainingEvent::SimulationFinished {
395 scenarios: 200,
396 output_dir: "/tmp/output".to_string(),
397 elapsed_ms: 20_000,
398 },
399 ]
400 }
401
402 #[test]
403 fn all_twelve_variants_construct() {
404 let variants = make_all_variants();
405 assert_eq!(
406 variants.len(),
407 12,
408 "expected exactly 12 TrainingEvent variants"
409 );
410 }
411
412 #[test]
413 fn all_variants_clone() {
414 for variant in make_all_variants() {
415 let cloned = variant.clone();
416 // Verify the clone produces a non-empty debug string (proxy for equality).
417 assert!(!format!("{cloned:?}").is_empty());
418 }
419 }
420
421 #[test]
422 fn all_variants_debug_non_empty() {
423 for variant in make_all_variants() {
424 let debug = format!("{variant:?}");
425 assert!(!debug.is_empty(), "debug output must not be empty");
426 }
427 }
428
429 #[test]
430 fn forward_pass_complete_fields_accessible() {
431 let event = TrainingEvent::ForwardPassComplete {
432 iteration: 7,
433 scenarios: 20,
434 ub_mean: 210.0,
435 ub_std: 3.5,
436 elapsed_ms: 55,
437 };
438 let TrainingEvent::ForwardPassComplete {
439 iteration,
440 scenarios,
441 ub_mean,
442 ub_std,
443 elapsed_ms,
444 } = event
445 else {
446 panic!("wrong variant")
447 };
448 assert_eq!(iteration, 7);
449 assert_eq!(scenarios, 20);
450 assert!((ub_mean - 210.0).abs() < f64::EPSILON);
451 assert!((ub_std - 3.5).abs() < f64::EPSILON);
452 assert_eq!(elapsed_ms, 55);
453 }
454
455 #[test]
456 fn convergence_update_rules_evaluated_field() {
457 let rules = vec![
458 StoppingRuleResult {
459 rule_name: "gap_tolerance".to_string(),
460 triggered: true,
461 detail: "gap 0.42% <= 1.00%".to_string(),
462 },
463 StoppingRuleResult {
464 rule_name: "iteration_limit".to_string(),
465 triggered: false,
466 detail: "iteration 10/100".to_string(),
467 },
468 ];
469 let event = TrainingEvent::ConvergenceUpdate {
470 iteration: 10,
471 lower_bound: 99.0,
472 upper_bound: 100.0,
473 upper_bound_std: 0.5,
474 gap: 0.0042,
475 rules_evaluated: rules.clone(),
476 };
477 let TrainingEvent::ConvergenceUpdate {
478 rules_evaluated, ..
479 } = event
480 else {
481 panic!("wrong variant")
482 };
483 assert_eq!(rules_evaluated.len(), 2);
484 assert_eq!(rules_evaluated[0].rule_name, "gap_tolerance");
485 assert!(rules_evaluated[0].triggered);
486 assert_eq!(rules_evaluated[1].rule_name, "iteration_limit");
487 assert!(!rules_evaluated[1].triggered);
488 }
489
490 #[test]
491 fn stopping_rule_result_fields_accessible() {
492 let r = StoppingRuleResult {
493 rule_name: "bound_stalling".to_string(),
494 triggered: false,
495 detail: "LB stable for 8/10 iterations".to_string(),
496 };
497 let cloned = r.clone();
498 assert_eq!(cloned.rule_name, "bound_stalling");
499 assert!(!cloned.triggered);
500 assert_eq!(cloned.detail, "LB stable for 8/10 iterations");
501 }
502
503 #[test]
504 fn stopping_rule_result_debug_non_empty() {
505 let r = StoppingRuleResult {
506 rule_name: "time_limit".to_string(),
507 triggered: true,
508 detail: "elapsed 3602s > 3600s limit".to_string(),
509 };
510 let debug = format!("{r:?}");
511 assert!(!debug.is_empty());
512 assert!(debug.contains("time_limit"));
513 }
514
515 #[test]
516 fn cut_selection_complete_fields_accessible() {
517 let event = TrainingEvent::CutSelectionComplete {
518 iteration: 10,
519 cuts_deactivated: 30,
520 stages_processed: 12,
521 selection_time_ms: 25,
522 allgatherv_time_ms: 2,
523 };
524 let TrainingEvent::CutSelectionComplete {
525 iteration,
526 cuts_deactivated,
527 stages_processed,
528 selection_time_ms,
529 allgatherv_time_ms,
530 } = event
531 else {
532 panic!("wrong variant")
533 };
534 assert_eq!(iteration, 10);
535 assert_eq!(cuts_deactivated, 30);
536 assert_eq!(stages_processed, 12);
537 assert_eq!(selection_time_ms, 25);
538 assert_eq!(allgatherv_time_ms, 2);
539 }
540
541 #[test]
542 fn training_started_timestamp_field() {
543 let event = TrainingEvent::TrainingStarted {
544 case_name: "hydro_sys".to_string(),
545 stages: 120,
546 hydros: 10,
547 thermals: 20,
548 ranks: 8,
549 threads_per_rank: 4,
550 timestamp: "2026-03-01T08:00:00Z".to_string(),
551 };
552 let TrainingEvent::TrainingStarted { timestamp, .. } = event else {
553 panic!("wrong variant")
554 };
555 assert_eq!(timestamp, "2026-03-01T08:00:00Z");
556 }
557
558 #[test]
559 fn simulation_progress_statistics_fields_accessible() {
560 let event = TrainingEvent::SimulationProgress {
561 scenarios_complete: 100,
562 scenarios_total: 500,
563 elapsed_ms: 10_000,
564 mean_cost: 45_230.0,
565 std_cost: 3_100.0,
566 ci_95_half_width: 430.0,
567 };
568 let TrainingEvent::SimulationProgress {
569 scenarios_complete,
570 scenarios_total,
571 elapsed_ms,
572 mean_cost,
573 std_cost,
574 ci_95_half_width,
575 } = event
576 else {
577 panic!("wrong variant")
578 };
579 assert_eq!(scenarios_complete, 100);
580 assert_eq!(scenarios_total, 500);
581 assert_eq!(elapsed_ms, 10_000);
582 assert!((mean_cost - 45_230.0).abs() < f64::EPSILON);
583 assert!((std_cost - 3_100.0).abs() < f64::EPSILON);
584 assert!((ci_95_half_width - 430.0).abs() < f64::EPSILON);
585 }
586
587 #[test]
588 fn simulation_progress_zero_stats_when_insufficient_data() {
589 // When scenarios_complete < 2, std_cost and ci_95_half_width must be 0.0.
590 let event = TrainingEvent::SimulationProgress {
591 scenarios_complete: 1,
592 scenarios_total: 200,
593 elapsed_ms: 100,
594 mean_cost: 50_000.0,
595 std_cost: 0.0,
596 ci_95_half_width: 0.0,
597 };
598 let TrainingEvent::SimulationProgress {
599 std_cost,
600 ci_95_half_width,
601 ..
602 } = event
603 else {
604 panic!("wrong variant")
605 };
606 assert_eq!(std_cost, 0.0);
607 assert_eq!(ci_95_half_width, 0.0);
608 }
609}