Skip to main content

cobre_core/
training_event.rs

1//! Typed event system for iterative optimization training loops and simulation runners.
2//!
//! This module defines the [`TrainingEvent`] enum and its companion payload
//! structs, [`StoppingRuleResult`] and [`StageSelectionRecord`]. Events are
//! emitted at each step of the iterative optimization lifecycle (forward pass,
//! backward pass, convergence update, etc.) and consumed by runtime observers:
//! text loggers, JSON-lines writers, TUI renderers, MCP progress notifications,
//! and Parquet convergence writers.
8//!
9//! ## Design principles
10//!
11//! - **Zero-overhead when unused.** The event channel uses
12//!   `Option<std::sync::mpsc::Sender<TrainingEvent>>`: when `None`, no events are
13//!   emitted and no allocation occurs. When `Some(sender)`, events are moved into
14//!   the channel at each lifecycle step boundary.
15//! - **No per-event timestamps.** Consumers capture wall-clock time upon receipt.
16//!   This avoids `clock_gettime` syscall overhead on the hot path. The single
17//!   exception is [`TrainingEvent::TrainingStarted::timestamp`], which records the
18//!   run-level start time once at entry.
19//! - **Consumer-agnostic.** This module is defined in `cobre-core` (not in the
20//!   algorithm crate) so that interface crates (`cobre-cli`, `cobre-tui`,
21//!   `cobre-mcp`) can consume events without depending on the algorithm crate.
22//!
23//! ## Event channel pattern
24//!
25//! ```rust
26//! use std::sync::mpsc;
27//! use cobre_core::TrainingEvent;
28//!
29//! let (tx, rx) = mpsc::channel::<TrainingEvent>();
30//! // Pass `Some(tx)` to the training loop; pass `rx` to the consumer thread.
31//! drop(tx);
32//! drop(rx);
33//! ```
34//!
35//! See [`TrainingEvent`] for the full variant catalogue.
36
/// Result of evaluating a single stopping rule at a given iteration.
///
/// The [`TrainingEvent::ConvergenceUpdate`] variant carries a [`Vec`] of these,
/// one per configured stopping rule. Each element reports the rule's identifier,
/// whether its condition is satisfied, and a human-readable description of the
/// current state (e.g., `"gap 0.42% <= 1.00%"`).
///
/// All fields are `String`/`bool`, so `PartialEq` and `Eq` are derived. Results
/// can then be compared directly (e.g. in consumer tests) rather than through
/// their `Debug` rendering.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StoppingRuleResult {
    /// Rule identifier matching the variant name in the stopping rules config
    /// (e.g., `"gap_tolerance"`, `"bound_stalling"`, `"iteration_limit"`,
    /// `"time_limit"`, `"simulation"`).
    pub rule_name: String,
    /// Whether this rule's condition is satisfied at the current iteration.
    pub triggered: bool,
    /// Human-readable description of the rule's current state
    /// (e.g., `"gap 0.42% <= 1.00%"`, `"LB stable for 12/10 iterations"`).
    pub detail: String,
}
55
/// Per-stage cut selection statistics for one iteration.
///
/// Each instance describes the cut lifecycle at a single stage after a
/// selection step: how many cuts existed, how many were active before
/// selection, how many were deactivated, and how many remain active.
///
/// This type does not enforce any relationship between the counters. One would
/// expect `cuts_active_after == cuts_active_before - cuts_deactivated`, but that
/// is up to the producer (NOTE(review): confirm against the selection code).
///
/// All fields are `u32`, so `PartialEq` and `Eq` are derived for direct
/// comparison.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StageSelectionRecord {
    /// 0-based stage index.
    pub stage: u32,
    /// Total cuts ever generated at this stage (high-water mark).
    pub cuts_populated: u32,
    /// Active cuts before selection ran.
    pub cuts_active_before: u32,
    /// Cuts deactivated by selection at this stage.
    pub cuts_deactivated: u32,
    /// Active cuts after selection.
    pub cuts_active_after: u32,
}
74
75/// Typed events emitted by an iterative optimization training loop and
76/// simulation runner.
77///
78/// The enum has 12 variants: 8 per-iteration events (one per lifecycle step)
79/// and 4 lifecycle events (emitted once per training or simulation run).
80///
81/// ## Per-iteration events (steps 1–7 + 4a)
82///
83/// | Step | Variant                  | When emitted                                           |
84/// |------|--------------------------|--------------------------------------------------------|
85/// | 1    | [`Self::ForwardPassComplete`]  | Local forward pass done                                |
86/// | 2    | [`Self::ForwardSyncComplete`]  | Global allreduce of bounds done                        |
87/// | 3    | [`Self::BackwardPassComplete`] | Backward sweep done                                    |
88/// | 4    | [`Self::CutSyncComplete`]      | Cut allgatherv done                                    |
89/// | 4a   | [`Self::CutSelectionComplete`] | Cut selection done (conditional on `should_run`)       |
90/// | 5    | [`Self::ConvergenceUpdate`]    | Stopping rules evaluated                               |
91/// | 6    | [`Self::CheckpointComplete`]   | Checkpoint written (conditional on checkpoint interval)|
92/// | 7    | [`Self::IterationSummary`]     | End-of-iteration aggregated summary                    |
93///
94/// ## Lifecycle events
95///
96/// | Variant                      | When emitted                        |
97/// |------------------------------|-------------------------------------|
98/// | [`Self::TrainingStarted`]    | Training loop entry                 |
99/// | [`Self::TrainingFinished`]   | Training loop exit                  |
100/// | [`Self::SimulationProgress`] | Simulation batch completion         |
101/// | [`Self::SimulationFinished`] | Simulation completion               |
102#[derive(Clone, Debug)]
103pub enum TrainingEvent {
104    // ── Per-iteration events (8) ─────────────────────────────────────────────
105    /// Step 1: Forward pass completed for this iteration on the local rank.
106    ForwardPassComplete {
107        /// Iteration number (1-based).
108        iteration: u64,
109        /// Number of forward scenarios evaluated on this rank.
110        scenarios: u32,
111        /// Mean total forward cost across local scenarios.
112        ub_mean: f64,
113        /// Standard deviation of total forward cost across local scenarios.
114        ub_std: f64,
115        /// Wall-clock time for the forward pass on this rank, in milliseconds.
116        elapsed_ms: u64,
117    },
118
119    /// Step 2: Forward synchronization (allreduce) completed.
120    ///
121    /// Emitted after the global reduction of local bound estimates across all
122    /// participating ranks.
123    ForwardSyncComplete {
124        /// Iteration number (1-based).
125        iteration: u64,
126        /// Global upper bound mean after allreduce.
127        global_ub_mean: f64,
128        /// Global upper bound standard deviation after allreduce.
129        global_ub_std: f64,
130        /// Wall-clock time for the synchronization, in milliseconds.
131        sync_time_ms: u64,
132    },
133
134    /// Step 3: Backward pass completed for this iteration.
135    ///
136    /// Emitted after the full backward sweep that generates new cuts for each
137    /// stage.
138    BackwardPassComplete {
139        /// Iteration number (1-based).
140        iteration: u64,
141        /// Number of new cuts generated across all stages.
142        cuts_generated: u32,
143        /// Number of stages processed in the backward sweep.
144        stages_processed: u32,
145        /// Wall-clock time for the backward pass, in milliseconds.
146        elapsed_ms: u64,
147        /// Wall-clock time for state exchange (`allgatherv`) across all stages,
148        /// in milliseconds.
149        state_exchange_time_ms: u64,
150        /// Wall-clock time for cut batch assembly (`build_cut_row_batch_into`)
151        /// across all stages, in milliseconds.
152        cut_batch_build_time_ms: u64,
153        /// Estimated rayon barrier + scheduling overhead across all stages,
154        /// in milliseconds.
155        rayon_overhead_time_ms: u64,
156    },
157
158    /// Step 4: Cut synchronization (allgatherv) completed.
159    ///
160    /// Emitted after new cuts from all ranks have been gathered and distributed
161    /// to every rank via allgatherv.
162    CutSyncComplete {
163        /// Iteration number (1-based).
164        iteration: u64,
165        /// Number of cuts distributed to all ranks via allgatherv.
166        cuts_distributed: u32,
167        /// Total number of active cuts in the approximation after synchronization.
168        cuts_active: u32,
169        /// Number of cuts removed during synchronization.
170        cuts_removed: u32,
171        /// Wall-clock time for the synchronization, in milliseconds.
172        sync_time_ms: u64,
173    },
174
175    /// Step 4a: Cut selection completed.
176    ///
177    /// Only emitted on iterations where cut selection runs (i.e., when
178    /// `should_run(iteration)` returns `true`). On non-selection iterations
179    /// this variant is skipped entirely.
180    CutSelectionComplete {
181        /// Iteration number (1-based).
182        iteration: u64,
183        /// Number of cuts deactivated across all stages.
184        cuts_deactivated: u32,
185        /// Number of stages processed during cut selection.
186        stages_processed: u32,
187        /// Wall-clock time for the local cut selection phase, in milliseconds.
188        selection_time_ms: u64,
189        /// Wall-clock time for the allgatherv deactivation-set exchange, in
190        /// milliseconds.
191        allgatherv_time_ms: u64,
192        /// Per-stage breakdown of selection results.
193        per_stage: Vec<StageSelectionRecord>,
194    },
195
196    /// Step 5: Convergence check completed.
197    ///
198    /// Emitted after all configured stopping rules have been evaluated for the
199    /// current iteration. Contains the current bounds, gap, and per-rule results.
200    ConvergenceUpdate {
201        /// Iteration number (1-based).
202        iteration: u64,
203        /// Current lower bound (non-decreasing across iterations).
204        lower_bound: f64,
205        /// Current upper bound (statistical estimate from forward costs).
206        upper_bound: f64,
207        /// Standard deviation of the upper bound estimate.
208        upper_bound_std: f64,
209        /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
210        gap: f64,
211        /// Evaluation result for each configured stopping rule.
212        rules_evaluated: Vec<StoppingRuleResult>,
213    },
214
215    /// Step 6: Checkpoint written.
216    ///
217    /// Only emitted when the checkpoint interval triggers (i.e., when
218    /// `iteration % checkpoint_interval == 0`). Not emitted on every iteration.
219    CheckpointComplete {
220        /// Iteration number (1-based).
221        iteration: u64,
222        /// Filesystem path where the checkpoint was written.
223        checkpoint_path: String,
224        /// Wall-clock time for the checkpoint write, in milliseconds.
225        elapsed_ms: u64,
226    },
227
228    /// Step 7: Full iteration summary with aggregated timings.
229    ///
230    /// Emitted at the end of every iteration as the final per-iteration event.
231    /// Contains all timing breakdowns for the completed iteration.
232    IterationSummary {
233        /// Iteration number (1-based).
234        iteration: u64,
235        /// Current lower bound.
236        lower_bound: f64,
237        /// Current upper bound.
238        upper_bound: f64,
239        /// Relative optimality gap: `(upper_bound - lower_bound) / |upper_bound|`.
240        gap: f64,
241        /// Cumulative wall-clock time since training started, in milliseconds.
242        wall_time_ms: u64,
243        /// Wall-clock time for this iteration only, in milliseconds.
244        iteration_time_ms: u64,
245        /// Forward pass time for this iteration, in milliseconds.
246        forward_ms: u64,
247        /// Backward pass time for this iteration, in milliseconds.
248        backward_ms: u64,
249        /// Total number of LP solves in this iteration (forward + backward stages).
250        lp_solves: u64,
251        /// Cumulative LP solve wall-clock time for this iteration, in milliseconds.
252        solve_time_ms: f64,
253    },
254
255    // ── Lifecycle events (4) ─────────────────────────────────────────────────
256    /// Emitted once when the training loop begins.
257    ///
258    /// Carries run-level metadata describing the problem size and parallelism
259    /// configuration for this training run.
260    TrainingStarted {
261        /// Case study name from the input data directory.
262        case_name: String,
263        /// Total number of stages in the optimization horizon.
264        stages: u32,
265        /// Number of hydro plants in the system.
266        hydros: u32,
267        /// Number of thermal plants in the system.
268        thermals: u32,
269        /// Number of distributed ranks participating in training.
270        ranks: u32,
271        /// Number of threads per rank.
272        threads_per_rank: u32,
273        /// Wall-clock time at training start as an ISO 8601 string
274        /// (run-level metadata, not a per-event timestamp).
275        timestamp: String,
276    },
277
278    /// Emitted once when the training loop exits (converged or limit reached).
279    TrainingFinished {
280        /// Termination reason (e.g., `"gap_tolerance"`, `"iteration_limit"`,
281        /// `"time_limit"`).
282        reason: String,
283        /// Total number of iterations completed.
284        iterations: u64,
285        /// Final lower bound at termination.
286        final_lb: f64,
287        /// Final upper bound at termination.
288        final_ub: f64,
289        /// Total wall-clock time for the training run, in milliseconds.
290        total_time_ms: u64,
291        /// Total number of cuts in the approximation at termination.
292        total_cuts: u64,
293    },
294
295    /// Emitted periodically during policy simulation (not during training).
296    ///
297    /// Consumers can use this to display a progress indicator during the
298    /// simulation phase. Each event carries the cost of the most recently
299    /// completed scenario; the progress thread accumulates statistics across
300    /// events (see ticket-007).
301    SimulationProgress {
302        /// Number of simulation scenarios completed so far.
303        scenarios_complete: u32,
304        /// Total number of simulation scenarios to run.
305        scenarios_total: u32,
306        /// Wall-clock time since simulation started, in milliseconds.
307        elapsed_ms: u64,
308        /// Total cost of the most recently completed simulation scenario,
309        /// in cost units.
310        scenario_cost: f64,
311        /// Cumulative LP solve time for this scenario, in milliseconds.
312        solve_time_ms: f64,
313        /// Number of LP solves in this scenario.
314        lp_solves: u64,
315    },
316
317    /// Emitted once when policy simulation completes.
318    SimulationFinished {
319        /// Total number of simulation scenarios evaluated.
320        scenarios: u32,
321        /// Directory where simulation output files were written.
322        output_dir: String,
323        /// Total wall-clock time for the simulation run, in milliseconds.
324        elapsed_ms: u64,
325    },
326}
327
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::{StageSelectionRecord, StoppingRuleResult, TrainingEvent};

    /// Returns the variant name of `event`.
    ///
    /// The match is deliberately exhaustive (no `_` arm). Adding a variant to
    /// [`TrainingEvent`] makes this fail to compile until the tests, including
    /// `make_all_variants`, are updated to cover it.
    fn variant_name(event: &TrainingEvent) -> &'static str {
        match event {
            TrainingEvent::ForwardPassComplete { .. } => "ForwardPassComplete",
            TrainingEvent::ForwardSyncComplete { .. } => "ForwardSyncComplete",
            TrainingEvent::BackwardPassComplete { .. } => "BackwardPassComplete",
            TrainingEvent::CutSyncComplete { .. } => "CutSyncComplete",
            TrainingEvent::CutSelectionComplete { .. } => "CutSelectionComplete",
            TrainingEvent::ConvergenceUpdate { .. } => "ConvergenceUpdate",
            TrainingEvent::CheckpointComplete { .. } => "CheckpointComplete",
            TrainingEvent::IterationSummary { .. } => "IterationSummary",
            TrainingEvent::TrainingStarted { .. } => "TrainingStarted",
            TrainingEvent::TrainingFinished { .. } => "TrainingFinished",
            TrainingEvent::SimulationProgress { .. } => "SimulationProgress",
            TrainingEvent::SimulationFinished { .. } => "SimulationFinished",
        }
    }

    // Helper: build one of each variant with representative values.
    fn make_all_variants() -> Vec<TrainingEvent> {
        vec![
            TrainingEvent::ForwardPassComplete {
                iteration: 1,
                scenarios: 10,
                ub_mean: 110.0,
                ub_std: 5.0,
                elapsed_ms: 42,
            },
            TrainingEvent::ForwardSyncComplete {
                iteration: 1,
                global_ub_mean: 110.0,
                global_ub_std: 5.0,
                sync_time_ms: 3,
            },
            TrainingEvent::BackwardPassComplete {
                iteration: 1,
                cuts_generated: 48,
                stages_processed: 12,
                elapsed_ms: 87,
                state_exchange_time_ms: 0,
                cut_batch_build_time_ms: 0,
                rayon_overhead_time_ms: 0,
            },
            TrainingEvent::CutSyncComplete {
                iteration: 1,
                cuts_distributed: 48,
                cuts_active: 200,
                cuts_removed: 0,
                sync_time_ms: 2,
            },
            TrainingEvent::CutSelectionComplete {
                iteration: 10,
                cuts_deactivated: 15,
                stages_processed: 12,
                selection_time_ms: 20,
                allgatherv_time_ms: 1,
                per_stage: vec![],
            },
            TrainingEvent::ConvergenceUpdate {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                upper_bound_std: 5.0,
                gap: 0.0909,
                rules_evaluated: vec![StoppingRuleResult {
                    rule_name: "gap_tolerance".to_string(),
                    triggered: false,
                    detail: "gap 9.09% > 1.00%".to_string(),
                }],
            },
            TrainingEvent::CheckpointComplete {
                iteration: 5,
                checkpoint_path: "/tmp/checkpoint.bin".to_string(),
                elapsed_ms: 150,
            },
            TrainingEvent::IterationSummary {
                iteration: 1,
                lower_bound: 100.0,
                upper_bound: 110.0,
                gap: 0.0909,
                wall_time_ms: 1000,
                iteration_time_ms: 200,
                forward_ms: 80,
                backward_ms: 100,
                lp_solves: 240,
                solve_time_ms: 45.2,
            },
            TrainingEvent::TrainingStarted {
                case_name: "test_case".to_string(),
                stages: 60,
                hydros: 5,
                thermals: 10,
                ranks: 4,
                threads_per_rank: 8,
                timestamp: "2026-01-01T00:00:00Z".to_string(),
            },
            TrainingEvent::TrainingFinished {
                reason: "gap_tolerance".to_string(),
                iterations: 50,
                final_lb: 105.0,
                final_ub: 106.0,
                total_time_ms: 300_000,
                total_cuts: 2400,
            },
            TrainingEvent::SimulationProgress {
                scenarios_complete: 50,
                scenarios_total: 200,
                elapsed_ms: 5_000,
                scenario_cost: 45_230.0,
                solve_time_ms: 0.0,
                lp_solves: 0,
            },
            TrainingEvent::SimulationFinished {
                scenarios: 200,
                output_dir: "/tmp/output".to_string(),
                elapsed_ms: 20_000,
            },
        ]
    }

    #[test]
    fn all_twelve_variants_construct() {
        let variants = make_all_variants();
        assert_eq!(
            variants.len(),
            12,
            "expected exactly 12 TrainingEvent variants"
        );
        // The length alone would not catch a variant listed twice (and another
        // missing), so also require every element to be a distinct variant.
        let names: HashSet<&str> = variants.iter().map(variant_name).collect();
        assert_eq!(
            names.len(),
            12,
            "make_all_variants must contain each TrainingEvent variant exactly once"
        );
    }

    #[test]
    fn all_variants_clone() {
        for variant in make_all_variants() {
            let cloned = variant.clone();
            // `TrainingEvent` only derives `Clone` and `Debug`, so compare the
            // `Debug` renderings of the original and the clone field by field.
            assert_eq!(format!("{variant:?}"), format!("{cloned:?}"));
        }
    }

    #[test]
    fn all_variants_debug_non_empty() {
        for variant in make_all_variants() {
            let debug = format!("{variant:?}");
            assert!(!debug.is_empty(), "debug output must not be empty");
        }
    }

    #[test]
    fn forward_pass_complete_fields_accessible() {
        let event = TrainingEvent::ForwardPassComplete {
            iteration: 7,
            scenarios: 20,
            ub_mean: 210.0,
            ub_std: 3.5,
            elapsed_ms: 55,
        };
        let TrainingEvent::ForwardPassComplete {
            iteration,
            scenarios,
            ub_mean,
            ub_std,
            elapsed_ms,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 7);
        assert_eq!(scenarios, 20);
        assert!((ub_mean - 210.0).abs() < f64::EPSILON);
        assert!((ub_std - 3.5).abs() < f64::EPSILON);
        assert_eq!(elapsed_ms, 55);
    }

    #[test]
    fn convergence_update_rules_evaluated_field() {
        let rules = vec![
            StoppingRuleResult {
                rule_name: "gap_tolerance".to_string(),
                triggered: true,
                detail: "gap 0.42% <= 1.00%".to_string(),
            },
            StoppingRuleResult {
                rule_name: "iteration_limit".to_string(),
                triggered: false,
                detail: "iteration 10/100".to_string(),
            },
        ];
        let event = TrainingEvent::ConvergenceUpdate {
            iteration: 10,
            lower_bound: 99.0,
            upper_bound: 100.0,
            upper_bound_std: 0.5,
            gap: 0.0042,
            rules_evaluated: rules,
        };
        let TrainingEvent::ConvergenceUpdate {
            rules_evaluated, ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(rules_evaluated.len(), 2);
        assert_eq!(rules_evaluated[0].rule_name, "gap_tolerance");
        assert!(rules_evaluated[0].triggered);
        assert_eq!(rules_evaluated[1].rule_name, "iteration_limit");
        assert!(!rules_evaluated[1].triggered);
    }

    #[test]
    fn stopping_rule_result_fields_accessible() {
        let r = StoppingRuleResult {
            rule_name: "bound_stalling".to_string(),
            triggered: false,
            detail: "LB stable for 8/10 iterations".to_string(),
        };
        let cloned = r.clone();
        assert_eq!(cloned.rule_name, "bound_stalling");
        assert!(!cloned.triggered);
        assert_eq!(cloned.detail, "LB stable for 8/10 iterations");
    }

    #[test]
    fn stopping_rule_result_debug_non_empty() {
        let r = StoppingRuleResult {
            rule_name: "time_limit".to_string(),
            triggered: true,
            detail: "elapsed 3602s > 3600s limit".to_string(),
        };
        let debug = format!("{r:?}");
        assert!(!debug.is_empty());
        assert!(debug.contains("time_limit"));
    }

    #[test]
    fn cut_selection_complete_fields_accessible() {
        let event = TrainingEvent::CutSelectionComplete {
            iteration: 10,
            cuts_deactivated: 30,
            stages_processed: 12,
            selection_time_ms: 25,
            allgatherv_time_ms: 2,
            per_stage: vec![],
        };
        let TrainingEvent::CutSelectionComplete {
            iteration,
            cuts_deactivated,
            stages_processed,
            selection_time_ms,
            allgatherv_time_ms,
            per_stage,
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(iteration, 10);
        assert_eq!(cuts_deactivated, 30);
        assert_eq!(stages_processed, 12);
        assert_eq!(selection_time_ms, 25);
        assert_eq!(allgatherv_time_ms, 2);
        assert!(per_stage.is_empty());
    }

    #[test]
    fn cut_selection_complete_per_stage_records() {
        // Exercise a non-empty per-stage breakdown, which the other tests leave empty.
        let record = StageSelectionRecord {
            stage: 3,
            cuts_populated: 120,
            cuts_active_before: 80,
            cuts_deactivated: 15,
            cuts_active_after: 65,
        };
        let event = TrainingEvent::CutSelectionComplete {
            iteration: 10,
            cuts_deactivated: 15,
            stages_processed: 1,
            selection_time_ms: 4,
            allgatherv_time_ms: 1,
            per_stage: vec![record],
        };
        let TrainingEvent::CutSelectionComplete { per_stage, .. } = event else {
            panic!("wrong variant")
        };
        assert_eq!(per_stage.len(), 1);
        let only = &per_stage[0];
        assert_eq!(only.stage, 3);
        assert_eq!(only.cuts_populated, 120);
        assert_eq!(only.cuts_active_before, 80);
        assert_eq!(only.cuts_deactivated, 15);
        assert_eq!(only.cuts_active_after, 65);
    }

    #[test]
    fn training_started_timestamp_field() {
        let event = TrainingEvent::TrainingStarted {
            case_name: "hydro_sys".to_string(),
            stages: 120,
            hydros: 10,
            thermals: 20,
            ranks: 8,
            threads_per_rank: 4,
            timestamp: "2026-03-01T08:00:00Z".to_string(),
        };
        let TrainingEvent::TrainingStarted { timestamp, .. } = event else {
            panic!("wrong variant")
        };
        assert_eq!(timestamp, "2026-03-01T08:00:00Z");
    }

    #[test]
    fn simulation_progress_scenario_cost_field_accessible() {
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 100,
            scenarios_total: 500,
            elapsed_ms: 10_000,
            scenario_cost: 45_230.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress {
            scenarios_complete,
            scenarios_total,
            elapsed_ms,
            scenario_cost,
            ..
        } = event
        else {
            panic!("wrong variant")
        };
        assert_eq!(scenarios_complete, 100);
        assert_eq!(scenarios_total, 500);
        assert_eq!(elapsed_ms, 10_000);
        assert!((scenario_cost - 45_230.0).abs() < f64::EPSILON);
    }

    #[test]
    fn simulation_progress_first_scenario_cost_carried() {
        // The first scenario's cost is emitted directly — no aggregation needed.
        let event = TrainingEvent::SimulationProgress {
            scenarios_complete: 1,
            scenarios_total: 200,
            elapsed_ms: 100,
            scenario_cost: 50_000.0,
            solve_time_ms: 0.0,
            lp_solves: 0,
        };
        let TrainingEvent::SimulationProgress { scenario_cost, .. } = event else {
            panic!("wrong variant")
        };
        assert!((scenario_cost - 50_000.0).abs() < f64::EPSILON);
    }
}
631            panic!("wrong variant")
632        };
633        assert!((scenario_cost - 50_000.0).abs() < f64::EPSILON);
634    }
635}