// cobre_core/training_event.rs

//! Typed event system for iterative optimization training loops and simulation runners.
//!
//! This module defines the [`TrainingEvent`] enum and its companion [`StoppingRuleResult`]
//! struct. Events are emitted at each step of the iterative optimization lifecycle
//! (forward pass, backward pass, convergence update, etc.) and consumed by runtime
//! observers: text loggers, JSON-lines writers, TUI renderers, MCP progress
//! notifications, and Parquet convergence writers.
//!
//! ## Design principles
//!
//! - **Zero-overhead when unused.** The event channel uses
//!   `Option<std::sync::mpsc::Sender<TrainingEvent>>`: when `None`, no events are
//!   emitted and no allocation occurs. When `Some(sender)`, events are moved into
//!   the channel at each lifecycle step boundary.
//! - **No per-event timestamps.** Consumers capture wall-clock time upon receipt.
//!   This avoids `clock_gettime` syscall overhead on the hot path. The single
//!   exception is [`TrainingEvent::TrainingStarted::timestamp`], which records the
//!   run-level start time once at entry.
//! - **Consumer-agnostic.** This module is defined in `cobre-core` (not in the
//!   algorithm crate) so that interface crates (`cobre-cli`, `cobre-tui`,
//!   `cobre-mcp`) can consume events without depending on the algorithm crate.
//!
//! ## Event channel pattern
//!
//! ```rust
//! use std::sync::mpsc;
//! use cobre_core::TrainingEvent;
//!
//! let (tx, rx) = mpsc::channel::<TrainingEvent>();
//! // Pass `Some(tx)` to the training loop; pass `rx` to the consumer thread.
//! drop(tx);
//! drop(rx);
//! ```
//!
//! See [`TrainingEvent`] for the full variant catalogue.

/// Result of evaluating a single stopping rule at a given iteration.
///
/// The [`TrainingEvent::ConvergenceUpdate`] variant carries a [`Vec`] of these,
/// one per configured stopping rule. Each element reports the rule's identifier,
/// whether its condition is satisfied, and a human-readable description of the
/// current state (e.g., `"gap 0.42% <= 1.00%"`).
///
/// Equality is field-wise: two results are equal when the rule name, the
/// triggered flag, and the detail text all match.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct StoppingRuleResult {
    /// Rule identifier matching the variant name in the stopping rules config
    /// (e.g., `"gap_tolerance"`, `"bound_stalling"`, `"iteration_limit"`,
    /// `"time_limit"`, `"simulation"`).
    pub rule_name: String,
    /// Whether this rule's condition is satisfied at the current iteration.
    pub triggered: bool,
    /// Human-readable description of the rule's current state
    /// (e.g., `"gap 0.42% <= 1.00%"`, `"LB stable for 12/10 iterations"`).
    pub detail: String,
}

/// Typed events emitted by an iterative optimization training loop and
/// simulation runner.
///
/// The enum has 12 variants: 8 per-iteration events (one per lifecycle step)
/// and 4 lifecycle events (emitted once per training or simulation run).
///
/// ## Per-iteration events (steps 1–7 + 4a)
///
/// | Step | Variant                  | When emitted                                           |
/// |------|--------------------------|--------------------------------------------------------|
/// | 1    | [`Self::ForwardPassComplete`]  | Local forward pass done                                |
/// | 2    | [`Self::ForwardSyncComplete`]  | Global allreduce of bounds done                        |
/// | 3    | [`Self::BackwardPassComplete`] | Backward sweep done                                    |
/// | 4    | [`Self::CutSyncComplete`]      | Cut allgatherv done                                    |
/// | 4a   | [`Self::CutSelectionComplete`] | Cut selection done (conditional on `should_run`)       |
/// | 5    | [`Self::ConvergenceUpdate`]    | Stopping rules evaluated                               |
/// | 6    | [`Self::CheckpointComplete`]   | Checkpoint written (conditional on checkpoint interval)|
/// | 7    | [`Self::IterationSummary`]     | End-of-iteration aggregated summary                    |
///
/// ## Lifecycle events
///
/// | Variant                      | When emitted                        |
/// |------------------------------|-------------------------------------|
/// | [`Self::TrainingStarted`]    | Training loop entry                 |
/// | [`Self::TrainingFinished`]   | Training loop exit                  |
/// | [`Self::SimulationProgress`] | Simulation batch completion         |
/// | [`Self::SimulationFinished`] | Simulation completion               |
///
/// ## Equality
///
/// `PartialEq` is derived field-wise so that consumers and tests can compare
/// events directly. Several variants carry `f64` fields, so comparisons follow
/// floating-point semantics (an event holding a `NaN` is not equal to itself);
/// for that reason `Eq` is intentionally not implemented.
#[derive(Clone, Debug, PartialEq)]
pub enum TrainingEvent {
    // ── Per-iteration events (8) ─────────────────────────────────────────────
    /// Step 1: Forward pass completed for this iteration on the local rank.
    ForwardPassComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of forward scenarios evaluated on this rank.
        scenarios: u32,
        /// Mean total forward cost across local scenarios.
        ub_mean: f64,
        /// Standard deviation of total forward cost across local scenarios.
        ub_std: f64,
        /// Wall-clock time for the forward pass on this rank, in milliseconds.
        elapsed_ms: u64,
    },

    /// Step 2: Forward synchronization (allreduce) completed.
    ///
    /// Emitted after the global reduction of local bound estimates across all
    /// participating ranks.
    ForwardSyncComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Global upper bound mean after allreduce.
        global_ub_mean: f64,
        /// Global upper bound standard deviation after allreduce.
        global_ub_std: f64,
        /// Wall-clock time for the synchronization, in milliseconds.
        sync_time_ms: u64,
    },

    /// Step 3: Backward pass completed for this iteration.
    ///
    /// Emitted after the full backward sweep that generates new cuts for each
    /// stage.
    BackwardPassComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of new cuts generated across all stages.
        cuts_generated: u32,
        /// Number of stages processed in the backward sweep.
        stages_processed: u32,
        /// Wall-clock time for the backward pass, in milliseconds.
        elapsed_ms: u64,
    },

    /// Step 4: Cut synchronization (allgatherv) completed.
    ///
    /// Emitted after new cuts from all ranks have been gathered and distributed
    /// to every rank via allgatherv.
    CutSyncComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of cuts distributed to all ranks via allgatherv.
        cuts_distributed: u32,
        /// Total number of active cuts in the approximation after synchronization.
        cuts_active: u32,
        /// Number of cuts removed during synchronization.
        cuts_removed: u32,
        /// Wall-clock time for the synchronization, in milliseconds.
        sync_time_ms: u64,
    },

    /// Step 4a: Cut selection completed.
    ///
    /// Only emitted on iterations where cut selection runs (i.e., when
    /// `should_run(iteration)` returns `true`). On non-selection iterations
    /// this variant is skipped entirely.
    CutSelectionComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Number of cuts deactivated across all stages.
        cuts_deactivated: u32,
        /// Number of stages processed during cut selection.
        stages_processed: u32,
        /// Wall-clock time for the local cut selection phase, in milliseconds.
        selection_time_ms: u64,
        /// Wall-clock time for the allgatherv deactivation-set exchange, in
        /// milliseconds.
        allgatherv_time_ms: u64,
    },

    /// Step 5: Convergence check completed.
    ///
    /// Emitted after all configured stopping rules have been evaluated for the
    /// current iteration. Contains the current bounds, gap, and per-rule results.
    ConvergenceUpdate {
        /// Iteration number (1-based).
        iteration: u64,
        /// Current lower bound (non-decreasing across iterations).
        lower_bound: f64,
        /// Current upper bound (statistical estimate from forward costs).
        upper_bound: f64,
        /// Standard deviation of the upper bound estimate.
        upper_bound_std: f64,
        /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
        gap: f64,
        /// Evaluation result for each configured stopping rule.
        rules_evaluated: Vec<StoppingRuleResult>,
    },

    /// Step 6: Checkpoint written.
    ///
    /// Only emitted when the checkpoint interval triggers (i.e., when
    /// `iteration % checkpoint_interval == 0`). Not emitted on every iteration.
    CheckpointComplete {
        /// Iteration number (1-based).
        iteration: u64,
        /// Filesystem path where the checkpoint was written.
        checkpoint_path: String,
        /// Wall-clock time for the checkpoint write, in milliseconds.
        elapsed_ms: u64,
    },

    /// Step 7: Full iteration summary with aggregated timings.
    ///
    /// Emitted at the end of every iteration as the final per-iteration event.
    /// Contains all timing breakdowns for the completed iteration.
    IterationSummary {
        /// Iteration number (1-based).
        iteration: u64,
        /// Current lower bound.
        lower_bound: f64,
        /// Current upper bound.
        upper_bound: f64,
        /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
        gap: f64,
        /// Cumulative wall-clock time since training started, in milliseconds.
        wall_time_ms: u64,
        /// Wall-clock time for this iteration only, in milliseconds.
        iteration_time_ms: u64,
        /// Forward pass time for this iteration, in milliseconds.
        forward_ms: u64,
        /// Backward pass time for this iteration, in milliseconds.
        backward_ms: u64,
        /// Total number of LP solves in this iteration (forward + backward stages).
        lp_solves: u64,
    },

    // ── Lifecycle events (4) ─────────────────────────────────────────────────
    /// Emitted once when the training loop begins.
    ///
    /// Carries run-level metadata describing the problem size and parallelism
    /// configuration for this training run.
    TrainingStarted {
        /// Case study name from the input data directory.
        case_name: String,
        /// Total number of stages in the optimization horizon.
        stages: u32,
        /// Number of hydro plants in the system.
        hydros: u32,
        /// Number of thermal plants in the system.
        thermals: u32,
        /// Number of distributed ranks participating in training.
        ranks: u32,
        /// Number of threads per rank.
        threads_per_rank: u32,
        /// Wall-clock time at training start as an ISO 8601 string
        /// (run-level metadata, not a per-event timestamp).
        timestamp: String,
    },

    /// Emitted once when the training loop exits (converged or limit reached).
    TrainingFinished {
        /// Termination reason (e.g., `"gap_tolerance"`, `"iteration_limit"`,
        /// `"time_limit"`).
        reason: String,
        /// Total number of iterations completed.
        iterations: u64,
        /// Final lower bound at termination.
        final_lb: f64,
        /// Final upper bound at termination.
        final_ub: f64,
        /// Total wall-clock time for the training run, in milliseconds.
        total_time_ms: u64,
        /// Total number of cuts in the approximation at termination.
        total_cuts: u64,
    },

    /// Emitted periodically during policy simulation (not during training).
    ///
    /// Consumers can use this to display a progress indicator during the
    /// simulation phase, including live cost statistics as scenarios complete.
    SimulationProgress {
        /// Number of simulation scenarios completed so far.
        scenarios_complete: u32,
        /// Total number of simulation scenarios to run.
        scenarios_total: u32,
        /// Wall-clock time since simulation started, in milliseconds.
        elapsed_ms: u64,
        /// Running mean of total scenario costs computed so far, in cost units.
        ///
        /// Populated by the emitter after each batch of completed scenarios.
        mean_cost: f64,
        /// Running standard deviation of total scenario costs.
        ///
        /// Set to `0.0` when `scenarios_complete < 2` (insufficient data for
        /// variance estimation).
        std_cost: f64,
        /// Half-width of the 95% confidence interval, in cost units.
        ///
        /// Computed as `1.96 * std_cost / sqrt(scenarios_complete)`.
        /// Set to `0.0` when `scenarios_complete < 2`.
        ci_95_half_width: f64,
    },

    /// Emitted once when policy simulation completes.
    SimulationFinished {
        /// Total number of simulation scenarios evaluated.
        scenarios: u32,
        /// Directory where simulation output files were written.
        output_dir: String,
        /// Total wall-clock time for the simulation run, in milliseconds.
        elapsed_ms: u64,
    },
}
300
#[cfg(test)]
mod tests {
    //! Unit tests for [`TrainingEvent`] and [`StoppingRuleResult`]: every
    //! variant can be constructed, cloned, debug-printed, and destructured.

    use super::{StoppingRuleResult, TrainingEvent};

    /// Builds one instance of every `TrainingEvent` variant with
    /// representative values, in declaration order.
    fn make_all_variants() -> Vec<TrainingEvent> {
        vec![
            TrainingEvent::ForwardPassComplete {
                iteration: 1,
                scenarios: 10,
                ub_mean: 110.0,
                ub_std: 5.0,
                elapsed_ms: 42,
            },
            TrainingEvent::ForwardSyncComplete {
                iteration: 1,
                global_ub_mean: 110.0,
                global_ub_std: 5.0,
                sync_time_ms: 3,
            },
            TrainingEvent::BackwardPassComplete {
                iteration: 1,
                cuts_generated: 48,
                stages_processed: 12,
                elapsed_ms: 87,
            },
            TrainingEvent::CutSyncComplete {
                iteration: 1,
                cuts_distributed: 48,
                cuts_active: 200,
                cuts_removed: 0,
                sync_time_ms: 2,
            },
            TrainingEvent::CutSelectionComplete {
                iteration: 10,
                cuts_deactivated: 15,
                stages_processed: 12,
                selection_time_ms: 20,
                allgatherv_time_ms: 1,
            },
            TrainingEvent::ConvergenceUpdate {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                upper_bound_std: 5.0,
                gap: 0.0909,
                rules_evaluated: vec![StoppingRuleResult {
                    rule_name: "gap_tolerance".to_string(),
                    triggered: false,
                    detail: "gap 9.09% > 1.00%".to_string(),
                }],
            },
            TrainingEvent::CheckpointComplete {
                iteration: 5,
                checkpoint_path: "/tmp/checkpoint.bin".to_string(),
                elapsed_ms: 150,
            },
            TrainingEvent::IterationSummary {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                gap: 0.0909,
                wall_time_ms: 1000,
                iteration_time_ms: 200,
                forward_ms: 80,
                backward_ms: 100,
                lp_solves: 240,
            },
            TrainingEvent::TrainingStarted {
                case_name: "test_case".to_string(),
                stages: 60,
                hydros: 5,
                thermals: 10,
                ranks: 4,
                threads_per_rank: 8,
                timestamp: "2026-01-01T00:00:00Z".to_string(),
            },
            TrainingEvent::TrainingFinished {
                reason: "gap_tolerance".to_string(),
                iterations: 50,
                final_lb: 105.0,
                final_ub: 106.0,
                total_time_ms: 300_000,
                total_cuts: 2400,
            },
            TrainingEvent::SimulationProgress {
                scenarios_complete: 50,
                scenarios_total: 200,
                elapsed_ms: 5_000,
                mean_cost: 45_230.0,
                std_cost: 3_100.0,
                ci_95_half_width: 430.0,
            },
            TrainingEvent::SimulationFinished {
                scenarios: 200,
                output_dir: "/tmp/output".to_string(),
                elapsed_ms: 20_000,
            },
        ]
    }

    #[test]
    fn all_twelve_variants_construct() {
        use std::collections::HashSet;
        use std::mem::discriminant;

        let variants = make_all_variants();
        assert_eq!(
            variants.len(),
            12,
            "expected exactly 12 TrainingEvent variants"
        );

        // A length check alone would still pass if the helper listed the same
        // variant twice, so also require every discriminant to be unique.
        let distinct: HashSet<_> = variants.iter().map(discriminant).collect();
        assert_eq!(
            distinct.len(),
            variants.len(),
            "make_all_variants must contain each variant exactly once"
        );
    }

    #[test]
    fn all_variants_clone() {
        for variant in make_all_variants() {
            let cloned = variant.clone();
            // Compare Debug renderings of the clone and the original so the
            // test fails if any field is lost or altered by `Clone`.
            assert_eq!(
                format!("{cloned:?}"),
                format!("{variant:?}"),
                "clone must render identically to the original"
            );
        }
    }

    #[test]
    fn all_variants_debug_non_empty() {
        for variant in make_all_variants() {
            let debug = format!("{variant:?}");
            assert!(!debug.is_empty(), "debug output must not be empty");
        }
    }

    #[test]
    fn forward_pass_complete_fields_accessible() {
        let event = TrainingEvent::ForwardPassComplete {
            iteration: 7,
            scenarios: 20,
            ub_mean: 210.0,
            ub_std: 3.5,
            elapsed_ms: 55,
        };
        let TrainingEvent::ForwardPassComplete {
            iteration,
            scenarios,
            ub_mean,
            ub_std,
            elapsed_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 7);
        assert_eq!(scenarios, 20);
        assert!((ub_mean - 210.0).abs() < f64::EPSILON);
        assert!((ub_std - 3.5).abs() < f64::EPSILON);
        assert_eq!(elapsed_ms, 55);
    }

    #[test]
    fn convergence_update_rules_evaluated_field() {
        let rules = vec![
            StoppingRuleResult {
                rule_name: "gap_tolerance".to_string(),
                triggered: true,
                detail: "gap 0.42% <= 1.00%".to_string(),
            },
            StoppingRuleResult {
                rule_name: "iteration_limit".to_string(),
                triggered: false,
                detail: "iteration 10/100".to_string(),
            },
        ];
        // `rules` is not needed afterwards, so it is moved into the event
        // rather than cloned.
        let event = TrainingEvent::ConvergenceUpdate {
            iteration: 10,
            lower_bound: 99.0,
            upper_bound: 100.0,
            upper_bound_std: 0.5,
            gap: 0.0042,
            rules_evaluated: rules,
        };
        let TrainingEvent::ConvergenceUpdate {
            rules_evaluated, ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(rules_evaluated.len(), 2);
        assert_eq!(rules_evaluated[0].rule_name, "gap_tolerance");
        assert!(rules_evaluated[0].triggered);
        assert_eq!(rules_evaluated[1].rule_name, "iteration_limit");
        assert!(!rules_evaluated[1].triggered);
    }

    #[test]
    fn stopping_rule_result_fields_accessible() {
        let r = StoppingRuleResult {
            rule_name: "bound_stalling".to_string(),
            triggered: false,
            detail: "LB stable for 8/10 iterations".to_string(),
        };
        let cloned = r.clone();
        assert_eq!(cloned.rule_name, "bound_stalling");
        assert!(!cloned.triggered);
        assert_eq!(cloned.detail, "LB stable for 8/10 iterations");
        // The original must be left intact and agree with its clone.
        assert_eq!(r.rule_name, cloned.rule_name);
        assert_eq!(r.triggered, cloned.triggered);
        assert_eq!(r.detail, cloned.detail);
    }

    #[test]
    fn stopping_rule_result_debug_non_empty() {
        let r = StoppingRuleResult {
            rule_name: "time_limit".to_string(),
            triggered: true,
            detail: "elapsed 3602s > 3600s limit".to_string(),
        };
        let debug = format!("{r:?}");
        assert!(!debug.is_empty());
        assert!(debug.contains("time_limit"));
    }

    #[test]
    fn cut_selection_complete_fields_accessible() {
        let event = TrainingEvent::CutSelectionComplete {
            iteration: 10,
            cuts_deactivated: 30,
            stages_processed: 12,
            selection_time_ms: 25,
            allgatherv_time_ms: 2,
        };
        let TrainingEvent::CutSelectionComplete {
            iteration,
            cuts_deactivated,
            stages_processed,
            selection_time_ms,
            allgatherv_time_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 10);
        assert_eq!(cuts_deactivated, 30);
        assert_eq!(stages_processed, 12);
        assert_eq!(selection_time_ms, 25);
        assert_eq!(allgatherv_time_ms, 2);
    }

    #[test]
    fn training_started_timestamp_field() {
        let event = TrainingEvent::TrainingStarted {
            case_name: "hydro_sys".to_string(),
            stages: 120,
            hydros: 10,
            thermals: 20,
            ranks: 8,
            threads_per_rank: 4,
            timestamp: "2026-03-01T08:00:00Z".to_string(),
        };
        let TrainingEvent::TrainingStarted { timestamp, .. } = event else {
            panic!("wrong variant")
        };
        assert_eq!(timestamp, "2026-03-01T08:00:00Z");
    }

    #[test]
    fn simulation_progress_statistics_fields_accessible() {
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 100,
            scenarios_total: 500,
            elapsed_ms: 10_000,
            mean_cost: 45_230.0,
            std_cost: 3_100.0,
            ci_95_half_width: 430.0,
        };
        let TrainingEvent::SimulationProgress {
            scenarios_complete,
            scenarios_total,
            elapsed_ms,
            mean_cost,
            std_cost,
            ci_95_half_width,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(scenarios_complete, 100);
        assert_eq!(scenarios_total, 500);
        assert_eq!(elapsed_ms, 10_000);
        assert!((mean_cost - 45_230.0).abs() < f64::EPSILON);
        assert!((std_cost - 3_100.0).abs() < f64::EPSILON);
        assert!((ci_95_half_width - 430.0).abs() < f64::EPSILON);
    }

    #[test]
    fn simulation_progress_zero_stats_when_insufficient_data() {
        // When scenarios_complete < 2, std_cost and ci_95_half_width must be 0.0.
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 1,
            scenarios_total: 200,
            elapsed_ms: 100,
            mean_cost: 50_000.0,
            std_cost: 0.0,
            ci_95_half_width: 0.0,
        };
        let TrainingEvent::SimulationProgress {
            std_cost,
            ci_95_half_width,
            ..
        } = event
        else {
            panic!("wrong variant")
        };
        // Exact comparison is intended: `0.0` is a documented sentinel value,
        // not the result of a floating-point computation.
        assert_eq!(std_cost, 0.0);
        assert_eq!(ci_95_half_width, 0.0);
    }
}