1use std::sync::Arc;
4use std::sync::atomic::AtomicBool;
5use std::sync::mpsc::{Sender, SyncSender};
6
7use cobre_comm::Communicator;
8use cobre_core::TrainingEvent;
9use cobre_solver::{SolverError, SolverInterface};
10
11use crate::{
12 config::{CutManagementConfig, EventConfig, LoopConfig, TrainingConfig},
13 context::{StageContext, TrainingContext},
14 dcs::DcsParams,
15 error::SddpError,
16 simulation::{
17 SimulationOutputSpec, error::SimulationError, pipeline::SimulationRunResult,
18 types::SimulationScenarioResult,
19 },
20 training::{TrainingOutcome, TrainingResult},
21 workspace::{CapturedBasis, SolverWorkspace, WorkspacePool, WorkspaceSizing},
22};
23
24use super::StudySetup;
25
26impl StudySetup {
27 pub fn train<S, C: Communicator>(
37 &mut self,
38 solver: &mut S,
39 comm: &C,
40 n_threads: usize,
41 solver_factory: impl Fn() -> Result<S, SolverError>,
42 event_sender: Option<Sender<TrainingEvent>>,
43 shutdown_flag: Option<&Arc<AtomicBool>>,
44 ) -> Result<TrainingOutcome, SddpError>
45 where
46 S: SolverInterface<Profile = cobre_solver::ActiveProfile> + Send,
47 {
48 let training_config = TrainingConfig {
49 loop_config: LoopConfig {
50 forward_passes: self.loop_params.forward_passes,
51 max_iterations: self.loop_params.max_iterations,
52 start_iteration: self.loop_params.start_iteration,
53 n_fwd_threads: n_threads,
54 max_blocks: self.loop_params.max_blocks,
55 stopping_rules: self.loop_params.stopping_rules.clone(),
56 },
57 cut_management: CutManagementConfig {
58 cut_selection: self.cut_management.cut_selection.clone(),
59 budget: self.cut_management.budget,
60 cut_activity_tolerance: self.cut_management.cut_activity_tolerance,
61 warm_start_cuts: 0,
62 risk_measures: self.cut_management.risk_measures.clone(),
63 },
64 events: EventConfig {
65 event_sender,
66 checkpoint_interval: None,
67 shutdown_flag: shutdown_flag.map(Arc::clone),
68 export_states: self.events.export_states,
69 },
70 };
71
72 let stage_ctx = StageContext {
73 templates: &self.stage_data.stage_templates.templates,
74 base_rows: &self.stage_data.stage_templates.base_rows,
75 noise_scale: &self.stage_data.stage_templates.noise_scale,
76 n_hydros: self.stage_data.stage_templates.n_hydros,
77 n_load_buses: self.stage_data.stage_templates.n_load_buses,
78 load_balance_row_starts: &self.stage_data.stage_templates.load_balance_row_starts,
79 load_bus_indices: &self.stage_data.stage_templates.load_bus_indices,
80 block_counts_per_stage: &self.stage_data.block_counts_per_stage,
81 ncs_max_gen: &self.ncs_max_gen,
82 ncs_allow_curtailment: &self.ncs_allow_curtailment,
83 discount_factors: &self.stage_data.stage_templates.discount_factors,
84 cumulative_discount_factors: &self
85 .stage_data
86 .stage_templates
87 .cumulative_discount_factors,
88 stage_lag_transitions: &self.stage_data.stage_lag_transitions,
89 noise_group_ids: &self.stage_data.noise_group_ids,
90 downstream_par_order: self.downstream_par_order,
91 };
92
93 let tr = &self.scenario_libraries.training;
94 let training_ctx = TrainingContext {
95 horizon: &self.methodology.horizon,
96 indexer: &self.stage_data.indexer,
97 inflow_method: &self.methodology.inflow_method,
98 stochastic: &self.stochastic,
99 initial_state: &self.initial_state,
100 inflow_scheme: tr.inflow_scheme,
101 load_scheme: tr.load_scheme,
102 ncs_scheme: tr.ncs_scheme,
103 stages: &self.stage_data.stages,
104 historical_library: tr.historical.as_ref(),
105 external_inflow_library: tr.external_inflow.as_ref(),
106 external_load_library: tr.external_load.as_ref(),
107 external_ncs_library: tr.external_ncs.as_ref(),
108 recent_accum_seed: &self.recent_observation_seed.accum_seed,
109 recent_weight_seed: self.recent_observation_seed.weight_seed,
110 dcs: self
113 .cut_management
114 .cut_selection
115 .as_ref()
116 .and_then(DcsParams::from_strategy),
117 noise_key_diag: self.noise_key_diag.as_ref(),
120 };
121
122 let warm_start_basis_cache = self.warm_start_basis_cache.take();
127
128 crate::train(
129 solver,
130 training_config,
131 &mut self.fcf,
132 &stage_ctx,
133 &training_ctx,
134 comm,
135 solver_factory,
136 warm_start_basis_cache,
137 )
138 }
139
140 pub fn simulate<S, C: Communicator>(
152 &self,
153 workspaces: &mut [SolverWorkspace<S>],
154 comm: &C,
155 result_tx: &SyncSender<SimulationScenarioResult>,
156 event_sender: Option<Sender<TrainingEvent>>,
157 baked_templates: Option<&[cobre_solver::StageTemplate]>,
158 stage_bases: &[Option<CapturedBasis>],
159 ) -> Result<SimulationRunResult, SimulationError>
160 where
161 S: SolverInterface<Profile = cobre_solver::ActiveProfile> + Send,
162 {
163 let stage_ctx = self.stage_ctx();
164 let training_ctx = self.simulation_ctx();
165
166 let output = SimulationOutputSpec {
167 result_tx,
168 zeta_per_stage: &self.stage_data.stage_templates.zeta_per_stage,
169 block_hours_per_stage: &self.stage_data.stage_templates.block_hours_per_stage,
170 entity_counts: &self.stage_data.entity_counts,
171 generic_constraint_row_entries: &self
172 .stage_data
173 .stage_templates
174 .generic_constraint_row_entries,
175 ncs_col_starts: &self.stage_data.stage_templates.ncs_col_starts,
176 n_ncs_per_stage: &self.stage_data.stage_templates.n_ncs_per_stage,
177 ncs_entity_ids_per_stage: &self.ncs_entity_ids_per_stage,
178 diversion_upstream: &self.stage_data.stage_templates.diversion_upstream,
179 hydro_productivities_per_stage: &self
180 .stage_data
181 .stage_templates
182 .hydro_productivities_per_stage,
183 energy_conversion: &self.energy_conversion,
184 hydro_min_storage_hm3: &self.hydro_min_storage_hm3,
185 event_sender,
186 };
187
188 crate::simulate(
189 workspaces,
190 &stage_ctx,
191 &self.fcf,
192 &training_ctx,
193 self.simulation_config(),
194 output,
195 baked_templates,
196 stage_bases,
197 comm,
198 )
199 }
200
201 #[must_use]
205 pub fn build_training_output(
206 &self,
207 result: &TrainingResult,
208 events: &[TrainingEvent],
209 ) -> cobre_io::TrainingOutput {
210 crate::build_training_output(result, events, &self.fcf)
211 }
212
213 #[allow(clippy::expect_used)]
228 pub fn create_workspace_pool<S: SolverInterface + Send, C: Communicator>(
229 &self,
230 comm: &C,
231 n_threads: usize,
232 solver_factory: impl Fn() -> Result<S, SolverError>,
233 ) -> Result<WorkspacePool<S>, SolverError> {
234 let rank = i32::try_from(comm.rank()).expect("MPI rank fits in i32");
235 let mut pool = WorkspacePool::try_new(
236 rank,
237 n_threads,
238 self.stage_data.indexer.n_state,
239 WorkspaceSizing {
240 hydro_count: self.stage_data.indexer.hydro_count,
241 max_par_order: self.stage_data.indexer.max_par_order,
242 n_load_buses: self.stage_data.stage_templates.n_load_buses,
243 max_blocks: self.loop_params.max_blocks,
244 downstream_par_order: self.downstream_par_order,
245 max_openings: (0..self.stage_data.stage_templates.templates.len())
246 .map(|t| self.stochastic.opening_tree().n_openings(t))
247 .max()
248 .unwrap_or(0),
249 initial_pool_capacity: 0,
250 n_state: self.stage_data.indexer.n_state,
251 max_local_fwd: 0,
253 total_forward_passes: 0,
254 noise_dim: 0,
255 n_anticipated: self.stage_data.indexer.n_anticipated,
256 k_max: self.stage_data.indexer.k_max,
257 },
258 solver_factory,
259 )?;
260 let max_cols = self
263 .stage_data
264 .stage_templates
265 .templates
266 .iter()
267 .map(|t| t.num_cols)
268 .max()
269 .unwrap_or(0);
270 let max_rows = self
271 .stage_data
272 .stage_templates
273 .templates
274 .iter()
275 .map(|t| t.num_rows)
276 .max()
277 .unwrap_or(0);
278 pool.resize_scratch_bases(max_cols, max_rows);
279 Ok(pool)
280 }
281}