Skip to main content

cobre_sddp/setup/
orchestration.rs

1//! Orchestration methods: train, simulate, and workspace pool construction.
2
3use std::sync::Arc;
4use std::sync::atomic::AtomicBool;
5use std::sync::mpsc::{Sender, SyncSender};
6
7use cobre_comm::Communicator;
8use cobre_core::TrainingEvent;
9use cobre_solver::{SolverError, SolverInterface};
10
11use crate::{
12    config::{CutManagementConfig, EventConfig, LoopConfig, TrainingConfig},
13    context::{StageContext, TrainingContext},
14    dcs::DcsParams,
15    error::SddpError,
16    simulation::{
17        SimulationOutputSpec, error::SimulationError, pipeline::SimulationRunResult,
18        types::SimulationScenarioResult,
19    },
20    training::{TrainingOutcome, TrainingResult},
21    workspace::{CapturedBasis, SolverWorkspace, WorkspacePool, WorkspaceSizing},
22};
23
24use super::StudySetup;
25
26impl StudySetup {
27    /// Execute the training loop.
28    ///
29    /// Constructs [`TrainingConfig`] and [`TrainingContext`], then delegates to
30    /// [`crate::train`]. Mutates `self.fcf` to store generated cuts.
31    ///
32    /// # Errors
33    ///
34    /// Returns `SddpError::Infeasible`, `SddpError::Solver`, or
35    /// `SddpError::Communication` on LP, solver, or MPI failure.
36    pub fn train<S, C: Communicator>(
37        &mut self,
38        solver: &mut S,
39        comm: &C,
40        n_threads: usize,
41        solver_factory: impl Fn() -> Result<S, SolverError>,
42        event_sender: Option<Sender<TrainingEvent>>,
43        shutdown_flag: Option<&Arc<AtomicBool>>,
44    ) -> Result<TrainingOutcome, SddpError>
45    where
46        S: SolverInterface<Profile = cobre_solver::ActiveProfile> + Send,
47    {
48        let training_config = TrainingConfig {
49            loop_config: LoopConfig {
50                forward_passes: self.loop_params.forward_passes,
51                max_iterations: self.loop_params.max_iterations,
52                start_iteration: self.loop_params.start_iteration,
53                n_fwd_threads: n_threads,
54                max_blocks: self.loop_params.max_blocks,
55                stopping_rules: self.loop_params.stopping_rules.clone(),
56            },
57            cut_management: CutManagementConfig {
58                cut_selection: self.cut_management.cut_selection.clone(),
59                budget: self.cut_management.budget,
60                cut_activity_tolerance: self.cut_management.cut_activity_tolerance,
61                warm_start_cuts: 0,
62                risk_measures: self.cut_management.risk_measures.clone(),
63            },
64            events: EventConfig {
65                event_sender,
66                checkpoint_interval: None,
67                shutdown_flag: shutdown_flag.map(Arc::clone),
68                export_states: self.events.export_states,
69            },
70        };
71
72        let stage_ctx = StageContext {
73            templates: &self.stage_data.stage_templates.templates,
74            base_rows: &self.stage_data.stage_templates.base_rows,
75            noise_scale: &self.stage_data.stage_templates.noise_scale,
76            n_hydros: self.stage_data.stage_templates.n_hydros,
77            n_load_buses: self.stage_data.stage_templates.n_load_buses,
78            load_balance_row_starts: &self.stage_data.stage_templates.load_balance_row_starts,
79            load_bus_indices: &self.stage_data.stage_templates.load_bus_indices,
80            block_counts_per_stage: &self.stage_data.block_counts_per_stage,
81            ncs_max_gen: &self.ncs_max_gen,
82            ncs_allow_curtailment: &self.ncs_allow_curtailment,
83            discount_factors: &self.stage_data.stage_templates.discount_factors,
84            cumulative_discount_factors: &self
85                .stage_data
86                .stage_templates
87                .cumulative_discount_factors,
88            stage_lag_transitions: &self.stage_data.stage_lag_transitions,
89            noise_group_ids: &self.stage_data.noise_group_ids,
90            downstream_par_order: self.downstream_par_order,
91        };
92
93        let tr = &self.scenario_libraries.training;
94        let training_ctx = TrainingContext {
95            horizon: &self.methodology.horizon,
96            indexer: &self.stage_data.indexer,
97            inflow_method: &self.methodology.inflow_method,
98            stochastic: &self.stochastic,
99            initial_state: &self.initial_state,
100            inflow_scheme: tr.inflow_scheme,
101            load_scheme: tr.load_scheme,
102            ncs_scheme: tr.ncs_scheme,
103            stages: &self.stage_data.stages,
104            historical_library: tr.historical.as_ref(),
105            external_inflow_library: tr.external_inflow.as_ref(),
106            external_load_library: tr.external_load.as_ref(),
107            external_ncs_library: tr.external_ncs.as_ref(),
108            recent_accum_seed: &self.recent_observation_seed.accum_seed,
109            recent_weight_seed: self.recent_observation_seed.weight_seed,
110            // DCS params reach the backward hot path via this context field.
111            // `Some` only for the dynamic cut-selection method; `None` otherwise.
112            dcs: self
113                .cut_management
114                .cut_selection
115                .as_ref()
116                .and_then(DcsParams::from_strategy),
117            // Throwaway backward diagnostic; `Some` only when `COBRE_W1_DIAG`
118            // was set at setup, else `None` (byte-identical default path).
119            noise_key_diag: self.noise_key_diag.as_ref(),
120        };
121
122        // Hand the warm-start basis cache (if any) to the training session so
123        // iteration 1's cut-loaded LPs warm-start from the checkpoint's stored
124        // bases. `take` leaves `None` behind — fresh starts pass `None` and are
125        // untouched.
126        let warm_start_basis_cache = self.warm_start_basis_cache.take();
127
128        crate::train(
129            solver,
130            training_config,
131            &mut self.fcf,
132            &stage_ctx,
133            &training_ctx,
134            comm,
135            solver_factory,
136            warm_start_basis_cache,
137        )
138    }
139
140    /// Run simulation using the trained future cost function.
141    ///
142    /// The caller provides channels, event sender, and thread management.
143    /// `baked_templates` enables the baked-template LP load path (no `add_rows`
144    /// per stage); pass `None` for the legacy `load_model + add_rows` fallback.
145    /// `stage_bases` enables warm-start; pass `&[]` for cold-start.
146    ///
147    /// # Errors
148    ///
149    /// Returns `SimulationError` on LP infeasibility, solver failure, channel closure,
150    /// or if `baked_templates.len() != num_stages`.
151    pub fn simulate<S, C: Communicator>(
152        &self,
153        workspaces: &mut [SolverWorkspace<S>],
154        comm: &C,
155        result_tx: &SyncSender<SimulationScenarioResult>,
156        event_sender: Option<Sender<TrainingEvent>>,
157        baked_templates: Option<&[cobre_solver::StageTemplate]>,
158        stage_bases: &[Option<CapturedBasis>],
159    ) -> Result<SimulationRunResult, SimulationError>
160    where
161        S: SolverInterface<Profile = cobre_solver::ActiveProfile> + Send,
162    {
163        let stage_ctx = self.stage_ctx();
164        let training_ctx = self.simulation_ctx();
165
166        let output = SimulationOutputSpec {
167            result_tx,
168            zeta_per_stage: &self.stage_data.stage_templates.zeta_per_stage,
169            block_hours_per_stage: &self.stage_data.stage_templates.block_hours_per_stage,
170            entity_counts: &self.stage_data.entity_counts,
171            generic_constraint_row_entries: &self
172                .stage_data
173                .stage_templates
174                .generic_constraint_row_entries,
175            ncs_col_starts: &self.stage_data.stage_templates.ncs_col_starts,
176            n_ncs_per_stage: &self.stage_data.stage_templates.n_ncs_per_stage,
177            ncs_entity_ids_per_stage: &self.ncs_entity_ids_per_stage,
178            diversion_upstream: &self.stage_data.stage_templates.diversion_upstream,
179            hydro_productivities_per_stage: &self
180                .stage_data
181                .stage_templates
182                .hydro_productivities_per_stage,
183            energy_conversion: &self.energy_conversion,
184            hydro_min_storage_hm3: &self.hydro_min_storage_hm3,
185            event_sender,
186        };
187
188        crate::simulate(
189            workspaces,
190            &stage_ctx,
191            &self.fcf,
192            &training_ctx,
193            self.simulation_config(),
194            output,
195            baked_templates,
196            stage_bases,
197            comm,
198        )
199    }
200
201    /// Convert [`TrainingResult`] and events into training output.
202    ///
203    /// Delegates to [`crate::build_training_output`] with cut statistics from `self.fcf`.
204    #[must_use]
205    pub fn build_training_output(
206        &self,
207        result: &TrainingResult,
208        events: &[TrainingEvent],
209    ) -> cobre_io::TrainingOutput {
210        crate::build_training_output(result, events, &self.fcf)
211    }
212
213    /// Create a [`WorkspacePool`] sized for this study.
214    ///
215    /// Pool size equals `n_threads`. Each workspace gets a fresh solver instance.
216    /// `comm` is used to read the MPI rank that is stamped into each workspace's
217    /// `rank` field for downstream per-worker observability.
218    ///
219    /// # Errors
220    ///
221    /// Returns `SolverError` if solver creation fails.
222    ///
223    /// # Panics
224    ///
225    /// Panics if `comm.rank() > i32::MAX`. MPI world sizes are bounded well
226    /// below this on all real systems.
227    #[allow(clippy::expect_used)]
228    pub fn create_workspace_pool<S: SolverInterface + Send, C: Communicator>(
229        &self,
230        comm: &C,
231        n_threads: usize,
232        solver_factory: impl Fn() -> Result<S, SolverError>,
233    ) -> Result<WorkspacePool<S>, SolverError> {
234        let rank = i32::try_from(comm.rank()).expect("MPI rank fits in i32");
235        let mut pool = WorkspacePool::try_new(
236            rank,
237            n_threads,
238            self.stage_data.indexer.n_state,
239            WorkspaceSizing {
240                hydro_count: self.stage_data.indexer.hydro_count,
241                max_par_order: self.stage_data.indexer.max_par_order,
242                n_load_buses: self.stage_data.stage_templates.n_load_buses,
243                max_blocks: self.loop_params.max_blocks,
244                downstream_par_order: self.downstream_par_order,
245                max_openings: (0..self.stage_data.stage_templates.templates.len())
246                    .map(|t| self.stochastic.opening_tree().n_openings(t))
247                    .max()
248                    .unwrap_or(0),
249                initial_pool_capacity: 0,
250                n_state: self.stage_data.indexer.n_state,
251                // Simulation-only pool: forward-worker scratch fields unused.
252                max_local_fwd: 0,
253                total_forward_passes: 0,
254                noise_dim: 0,
255                n_anticipated: self.stage_data.indexer.n_anticipated,
256                k_max: self.stage_data.indexer.k_max,
257            },
258            solver_factory,
259        )?;
260        // Always pre-size scratch bases — basis reconstruction runs
261        // unconditionally on every forward/backward apply with a stored basis.
262        let max_cols = self
263            .stage_data
264            .stage_templates
265            .templates
266            .iter()
267            .map(|t| t.num_cols)
268            .max()
269            .unwrap_or(0);
270        let max_rows = self
271            .stage_data
272            .stage_templates
273            .templates
274            .iter()
275            .map(|t| t.num_rows)
276            .max()
277            .unwrap_or(0);
278        pool.resize_scratch_bases(max_cols, max_rows);
279        Ok(pool)
280    }
281}