powers_rs/
sddp.rs

1//! Implementation of the Stochastic Dual Dynamic Programming (SDDP)
2//! algorithm for the hydrothermal dispatch problem. In exchange for
3//! the simplified power system and state definition, some "smart"
4//! optimizations and features are already considered in this code.
5//!
6//! The underlying power system is modeled with only four entities:
7//! - Buses
8//! - Lines
9//! - Thermals
10//! - Hydros
11//!
12//! Some considerations about the implementation:
13//!
14//! 1. Only hydro storages are considered as state variables.
//! 2. Forward passes and backward steps run in parallel with rayon; the future cost functions are shared between threads through `Arc<Mutex<_>>`.
16//! 3. Only risk-neutral policy evaluation is supported (no risk-aversion)
//! 4. An exact cut selection strategy (inspired by SDDP.jl) is implemented
18//! 5. Only the "single-cut" (average cut) variant of the algorithm is supported.
19//!
//! The main external dependencies are:
//!
//! 1. Random number generation and distribution sampling from the rand* crates
//! 2. Low-level C-bindings from the highs-sys crate
//! 3. JSON and CSV serializers from the serde, serde_json and csv crates
//! 4. Date and time handling from the chrono crate
//! 5. Data parallelism from the rayon crate
25
26use crate::fcf;
27use crate::graph;
28use crate::initial_condition;
29use crate::log;
30use crate::risk_measure;
31use crate::scenario;
32use crate::stochastic_process;
33use crate::subproblem;
34use crate::system;
35use crate::utils;
36use chrono::prelude::*;
37use rand::prelude::*;
38
39use rand_xoshiro::Xoshiro256Plus;
40use rayon::prelude::*;
41use std::f64;
42use std::sync::{Arc, Mutex};
43use std::time::Instant;
44
45pub struct NodeData {
46    pub id: isize,
47    pub stage_id: usize,
48    pub season_id: usize,
49    pub start_date: DateTime<Utc>,
50    pub end_date: DateTime<Utc>,
51    pub kind: subproblem::StudyPeriodKind,
52    pub system: system::System,
53    pub risk_measure: Box<dyn risk_measure::RiskMeasure>,
54    pub load_stochastic_process: Box<dyn stochastic_process::StochasticProcess>,
55    pub inflow_stochastic_process:
56        Box<dyn stochastic_process::StochasticProcess>,
57    pub state_choice: String,
58}
59
60impl NodeData {
61    pub fn new(
62        node_id: isize,
63        stage_id: usize,
64        season_id: usize,
65        start_date_str: &str,
66        end_date_str: &str,
67        kind: subproblem::StudyPeriodKind,
68        system: system::System,
69        risk_measure_str: &str,
70        load_stochastic_process_str: &str,
71        inflow_stochastic_process_str: &str,
72        state_str: &str,
73    ) -> Result<Self, String> {
74        // Changed to return Result
75        let load_stochastic_process =
76            stochastic_process::factory(load_stochastic_process_str);
77        let inflow_stochastic_process =
78            stochastic_process::factory(inflow_stochastic_process_str);
79
80        Ok(Self {
81            id: node_id,
82            stage_id,
83            season_id,
84            start_date: start_date_str.parse::<DateTime<Utc>>().map_err(
85                |e| {
86                    format!(
87                        "Failed to parse start_date {}: {}",
88                        start_date_str, e
89                    )
90                },
91            )?,
92            end_date: end_date_str.parse::<DateTime<Utc>>().map_err(|e| {
93                format!("Failed to parse end_date {}: {}", end_date_str, e)
94            })?,
95            kind,
96            system,
97            risk_measure: risk_measure::factory(risk_measure_str),
98            load_stochastic_process,
99            inflow_stochastic_process,
100            state_choice: state_str.to_string(),
101        })
102    }
103}
104
105pub struct SddpTrainHandler {
106    subproblem_graph: graph::DirectedGraph<subproblem::Subproblem>,
107    realization_graph: graph::DirectedGraph<subproblem::Realization>,
108    branching_graph: graph::DirectedGraph<Vec<subproblem::Realization>>,
109}
110
111impl SddpTrainHandler {
112    pub fn new(
113        pre_study_id: &usize,
114        node_data_graph: &graph::DirectedGraph<NodeData>,
115        initial_condition: &initial_condition::InitialCondition,
116        saa: &scenario::SAA,
117    ) -> Result<Self, String> {
118        // allocates graph with all required memory for forward solutions
119        let mut realization_graph =
120            node_data_graph.map_topology_with(|node_data, _id| {
121                subproblem::Realization::with_capacity(
122                    &node_data.kind,
123                    &node_data.system,
124                )
125            });
126
127        let subproblem_graph =
128            node_data_graph.map_topology_with(|node_data, _id| {
129                subproblem::Subproblem::new(
130                    &node_data.system,
131                    &node_data.state_choice,
132                    &node_data.load_stochastic_process,
133                    &node_data.inflow_stochastic_process,
134                )
135            });
136
137        // add initial_condition to the PreStudy realization graph node
138        realization_graph
139            .get_node_mut(*pre_study_id)
140            .ok_or_else(|| {
141                "Failed to set initial condition to graph".to_string()
142            })?
143            .data
144            .final_storage
145            .clone_from_slice(initial_condition.get_storage());
146
147        // allocates branching graph with all required memory for backward solutions
148        let branching_graph =
149            node_data_graph.map_topology_with(|node_data, id| {
150                vec![
151                    subproblem::Realization::with_capacity(
152                        &node_data.kind,
153                        &node_data.system,
154                    );
155                    saa.get_branching_count_at_stage(id).expect(&format!(
156                        "Missing branching count for node {}",
157                        id
158                    ))
159                ]
160            });
161
162        Ok(Self {
163            subproblem_graph,
164            realization_graph,
165            branching_graph,
166        })
167    }
168
169    pub fn forward(
170        &mut self,
171        sampled_noises: Vec<&scenario::SampledBranchingNoises>,
172        node_data_graph: &graph::DirectedGraph<NodeData>,
173        graph_bfs_table: &Vec<Vec<usize>>,
174        study_period_ids: &Vec<usize>,
175    ) -> Result<f64, String> {
176        for (idx, id) in study_period_ids.iter().enumerate() {
177            let data_node = node_data_graph.get_node(*id).ok_or_else(|| {
178                format!("Could not find data for node {}", id)
179            })?;
180
181            let subproblem_node =
182                self.subproblem_graph.get_node_mut(*id).ok_or_else(|| {
183                    format!("Could not find subproblem for node {}", id)
184                })?;
185
186            let past_node_ids = graph_bfs_table.get(idx).ok_or_else(|| {
187                format!("Could not find past node ids for node {}", id)
188            })?;
189            let past_realizations: Vec<&subproblem::Realization> = past_node_ids
190                .iter()
191                .map(|&past_id| {
192                    self.realization_graph
193                        .get_node(past_id)
194                        .map(|node| &node.data)
195                        .ok_or_else(|| {
196                            format!("Could not find realization for past_node {} (current_id {})", past_id, id)
197                        })
198                    })
199                .collect::<Result<_, _>>()?;
200
201            subproblem_node
202                .data
203                .update_with_current_trajectory(past_realizations);
204
205            let realization_node =
206                self.realization_graph.get_node_mut(*id).ok_or_else(|| {
207                    format!("Could not find realization for node {}", id)
208                })?;
209
210            let current_stage_noises =
211                sampled_noises.get(*id).ok_or_else(|| {
212                    format!("Could not find noises for node {}", id)
213                })?;
214
215            step(
216                data_node,
217                &mut subproblem_node.data,
218                &mut realization_node.data,
219                current_stage_noises,
220            )?;
221
222            subproblem_node
223                .data
224                .update_with_current_realization(&realization_node.data);
225        }
226
227        let trajectory_cost: f64 = study_period_ids
228            .iter()
229            .map(|&id| {
230                self.realization_graph
231                    .get_node(id)
232                    .map(|node| node.data.current_stage_objective)
233                    .ok_or_else(|| {
234                        format!(
235                            "Could not find realization node {} in iterate",
236                            id
237                        )
238                    })
239            })
240            .sum::<Result<f64, String>>()?;
241        Ok(trajectory_cost)
242    }
243
244    pub fn backward_step_at_node(
245        &mut self,
246        id: usize,
247        past_node_ids: &Vec<usize>,
248        node_data_graph: &graph::DirectedGraph<NodeData>,
249        saa: &scenario::SAA,
250        future_cost_function_graph: &graph::DirectedGraph<
251            Arc<Mutex<fcf::FutureCostFunction>>,
252        >,
253    ) -> Result<(), String> {
254        let node_forward_trajectory: Vec<&subproblem::Realization> =
255                past_node_ids
256                    .iter()
257                    .map(|&past_id| {
258                        self.realization_graph
259                            .get_node(past_id)
260                            .map(|node| &node.data)
261                            .ok_or_else(|| {
262                                format!("Could not find realization for past_node {} (current_id {})", past_id, id)
263                            })
264                    })
265                    .collect::<Result<_, _>>()?;
266
267        let num_branchings =
268            saa.get_branching_count_at_stage(id).ok_or_else(|| {
269                format!(
270                    "Missing branching count for node {} in backward pass",
271                    id
272                )
273            })?;
274
275        solve_all_branchings(
276            &mut self.subproblem_graph,
277            &mut self.branching_graph,
278            id,
279            num_branchings,
280            &node_forward_trajectory,
281            node_data_graph,
282            saa,
283        )?;
284
285        let branching_node_data = &self
286            .branching_graph
287            .get_node(id)
288            .ok_or_else(|| {
289                format!("Could not find branching realizations for node {}", id)
290            })?
291            .data;
292
293        let parent_id = node_data_graph
294            .get_parents(id)
295            .and_then(|parents| parents.first().copied()) // Assumes a single parent for path graphs
296            .ok_or_else(|| {
297                format!("Could not find a unique parent for node {}", id)
298            })?;
299
300        update_future_cost_function(
301            &mut self.subproblem_graph,
302            future_cost_function_graph,
303            parent_id,
304            id,
305            node_data_graph,
306            &node_forward_trajectory,
307            branching_node_data,
308        )?;
309
310        Ok(())
311    }
312
313    pub fn eval_first_stage_bound(
314        &mut self,
315        id: usize,
316        past_node_ids: &Vec<usize>,
317        node_data_graph: &graph::DirectedGraph<NodeData>,
318        saa: &scenario::SAA,
319    ) -> Result<f64, String> {
320        let node_forward_trajectory: Vec<&subproblem::Realization> =
321                past_node_ids
322                    .iter()
323                    .map(|&past_id| {
324                        self.realization_graph
325                            .get_node(past_id)
326                            .map(|node| &node.data)
327                            .ok_or_else(|| {
328                                format!("Could not find realization for past_node {} (current_id {})", past_id, id)
329                            })
330                    })
331                    .collect::<Result<_, _>>()?;
332
333        let num_branchings =
334            saa.get_branching_count_at_stage(id).ok_or_else(|| {
335                format!(
336                    "Missing branching count for node {} in backward pass",
337                    id
338                )
339            })?;
340
341        solve_all_branchings(
342            &mut self.subproblem_graph,
343            &mut self.branching_graph,
344            id,
345            num_branchings,
346            &node_forward_trajectory,
347            node_data_graph,
348            saa,
349        )?;
350
351        let branching_node_data = &self
352            .branching_graph
353            .get_node(id)
354            .ok_or_else(|| {
355                format!("Could not find branching realizations for node {}", id)
356            })?
357            .data;
358
359        eval_first_stage_bound(
360            branching_node_data,
361            &node_data_graph
362                .get_node(id)
363                .ok_or_else(|| {
364                    format!("Could not find node data for node {}", id)
365                })?
366                .data
367                .risk_measure,
368        )
369    }
370}
371
372fn solve_all_branchings(
373    subproblem_graph: &mut graph::DirectedGraph<subproblem::Subproblem>,
374    branching_graph: &mut graph::DirectedGraph<Vec<subproblem::Realization>>,
375    node_id: usize,
376    num_branchings: usize,
377    node_forward_trajectory: &Vec<&subproblem::Realization>,
378    node_data_graph: &graph::DirectedGraph<NodeData>,
379    saa: &scenario::SAA,
380) -> Result<(), String> {
381    let data_node = node_data_graph.get_node(node_id).ok_or_else(|| {
382        format!("Could not find node data for node {}", node_id)
383    })?;
384
385    let subproblem_node =
386        subproblem_graph.get_node_mut(node_id).ok_or_else(|| {
387            format!("Could not find subproblem for node {}", node_id)
388        })?;
389
390    let node_forward_realization =
391        node_forward_trajectory.last().ok_or_else(|| {
392            format!("Could not find forward realization for node {}", node_id)
393        })?;
394
395    let current_branching_node =
396        branching_graph.get_node_mut(node_id).ok_or_else(|| {
397            format!(
398                "Could not find branching realizations for node {}",
399                node_id
400            )
401        })?;
402
403    for branching_id in 0..num_branchings {
404        reuse_forward_basis(
405            &mut subproblem_node.data,
406            node_forward_realization,
407        )?;
408
409        step(
410            data_node,
411            &mut subproblem_node.data,
412            current_branching_node
413                .data
414                .get_mut(branching_id)
415                .ok_or_else(|| {
416                    format!(
417                        "Could not find branching {} realization for node {}",
418                        branching_id, node_id
419                    )
420                })?,
421            saa.get_noises_by_stage_and_branching(node_id, branching_id)
422                .ok_or_else(|| {
423                    format!(
424                        "Could not find noises for branching {}, node {}",
425                        branching_id, node_id
426                    )
427                })?,
428        )?;
429    }
430    Ok(())
431}
432
433fn update_future_cost_function(
434    subproblem_graph: &mut graph::DirectedGraph<subproblem::Subproblem>,
435    future_cost_function_graph: &graph::DirectedGraph<
436        Arc<Mutex<fcf::FutureCostFunction>>,
437    >,
438    parent_id: usize,
439    child_id: usize,
440    node_data_graph: &graph::DirectedGraph<NodeData>,
441    forward_trajectory: &Vec<&subproblem::Realization>,
442    branching_realizations: &Vec<subproblem::Realization>,
443) -> Result<(), String> {
444    // evals cut with the state sampled by the child node, which will represent the
445    // future cost function of that node, for the parent one.
446
447    let child_data_node =
448        node_data_graph.get_node(child_id).ok_or_else(|| {
449            format!("Could not find node data for node {}", child_id)
450        })?;
451    let child_subproblem_node =
452        subproblem_graph.get_node(child_id).ok_or_else(|| {
453            format!("Could not find subproblem for node {}", child_id)
454        })?;
455    let cut_state_pair = child_subproblem_node.data.compute_new_cut(
456        forward_trajectory,
457        branching_realizations,
458        &child_data_node.data.risk_measure,
459    );
460
461    // adds cut to the pools in the parent node, applying cut selection
462    let parent_subproblem_node: &mut graph::Node<subproblem::Subproblem> =
463        subproblem_graph.get_node_mut(parent_id).ok_or_else(|| {
464            format!("Could not find subproblem for node {}", parent_id)
465        })?;
466    let parent_fcf_node: &graph::Node<Arc<Mutex<fcf::FutureCostFunction>>> =
467        future_cost_function_graph
468            .get_node(parent_id)
469            .ok_or_else(|| {
470                format!(
471                    "Could not find future cost function for node {}",
472                    parent_id
473                )
474            })?;
475
476    parent_subproblem_node
477        .data
478        .add_cut_and_evaluate_cut_selection(
479            cut_state_pair,
480            Arc::clone(&parent_fcf_node.data),
481        );
482    Ok(())
483}
484
485pub struct SddpSimulationHandler {
486    subproblem_graph: graph::DirectedGraph<subproblem::Subproblem>,
487    realization_graph: graph::DirectedGraph<subproblem::Realization>,
488}
489
490impl SddpSimulationHandler {
491    pub fn new(
492        pre_study_id: &usize,
493        node_data_graph: &graph::DirectedGraph<NodeData>,
494        initial_condition: &initial_condition::InitialCondition,
495    ) -> Result<Self, String> {
496        // allocates graph with all required memory for forward solutions
497        let mut realization_graph =
498            node_data_graph.map_topology_with(|node_data, _id| {
499                subproblem::Realization::with_capacity(
500                    &node_data.kind,
501                    &node_data.system,
502                )
503            });
504
505        let subproblem_graph =
506            node_data_graph.map_topology_with(|node_data, _id| {
507                subproblem::Subproblem::new(
508                    &node_data.system,
509                    &node_data.state_choice,
510                    &node_data.load_stochastic_process,
511                    &node_data.inflow_stochastic_process,
512                )
513            });
514
515        // add initial_condition to the PreStudy realization graph node
516        realization_graph
517            .get_node_mut(*pre_study_id)
518            .ok_or_else(|| {
519                "Failed to set initial condition to graph".to_string()
520            })?
521            .data
522            .final_storage
523            .clone_from_slice(initial_condition.get_storage());
524
525        Ok(Self {
526            subproblem_graph,
527            realization_graph,
528        })
529    }
530
531    pub fn forward(
532        &mut self,
533        sampled_noises: Vec<&scenario::SampledBranchingNoises>,
534        node_data_graph: &graph::DirectedGraph<NodeData>,
535        graph_bfs_table: &Vec<Vec<usize>>,
536        study_period_ids: &Vec<usize>,
537    ) -> Result<f64, String> {
538        for (idx, id) in study_period_ids.iter().enumerate() {
539            let data_node = node_data_graph.get_node(*id).ok_or_else(|| {
540                format!("Could not find data for node {}", id)
541            })?;
542
543            let subproblem_node =
544                self.subproblem_graph.get_node_mut(*id).ok_or_else(|| {
545                    format!("Could not find subproblem for node {}", id)
546                })?;
547
548            let past_node_ids = graph_bfs_table.get(idx).ok_or_else(|| {
549                format!("Could not find past node ids for node {}", id)
550            })?;
551            let past_realizations: Vec<&subproblem::Realization> = past_node_ids
552                .iter()
553                .map(|&past_id| {
554                    self.realization_graph
555                        .get_node(past_id)
556                        .map(|node| &node.data)
557                        .ok_or_else(|| {
558                            format!("Could not find realization for past_node {} (current_id {})", past_id, id)
559                        })
560                    })
561                .collect::<Result<_, _>>()?;
562
563            subproblem_node
564                .data
565                .update_with_current_trajectory(past_realizations);
566
567            let realization_node =
568                self.realization_graph.get_node_mut(*id).ok_or_else(|| {
569                    format!("Could not find realization for node {}", id)
570                })?;
571
572            let current_stage_noises =
573                sampled_noises.get(*id).ok_or_else(|| {
574                    format!("Could not find noises for node {}", id)
575                })?;
576
577            step(
578                data_node,
579                &mut subproblem_node.data,
580                &mut realization_node.data,
581                current_stage_noises,
582            )?;
583
584            subproblem_node
585                .data
586                .update_with_current_realization(&realization_node.data);
587        }
588
589        let trajectory_cost: f64 = study_period_ids
590            .iter()
591            .map(|&id| {
592                self.realization_graph
593                    .get_node(id)
594                    .map(|node| node.data.current_stage_objective)
595                    .ok_or_else(|| {
596                        format!(
597                            "Could not find realization node {} in iterate",
598                            id
599                        )
600                    })
601            })
602            .sum::<Result<f64, String>>()?;
603        Ok(trajectory_cost)
604    }
605
606    pub fn get_realization_at_node(
607        &self,
608        id: usize,
609    ) -> Option<&graph::Node<subproblem::Realization>> {
610        self.realization_graph.get_node(id)
611    }
612}
613
614pub struct SddpAlgorithm {
615    // core graphs and data
616    node_data_graph: graph::DirectedGraph<NodeData>,
617    pub future_cost_function_graph:
618        graph::DirectedGraph<Arc<Mutex<fcf::FutureCostFunction>>>,
619
620    // initial state
621    initial_condition: initial_condition::InitialCondition,
622
623    // for rng reproducibility
624    seed: u64,
625
626    // helpers for traversing the graphs
627    pre_study_id: usize,
628    pub study_period_ids: Vec<usize>,
629    graph_bfs_table: Vec<Vec<usize>>, // BFS table for study periods
630}
631
632impl SddpAlgorithm {
633    pub fn new(
634        node_data_graph: graph::DirectedGraph<NodeData>,
635        initial_condition: initial_condition::InitialCondition,
636        seed: u64,
637    ) -> Result<Self, String> {
638        let future_cost_function_graph =
639            node_data_graph.map_topology_with(|_node_data, _id| {
640                Arc::new(Mutex::new(fcf::FutureCostFunction::new()))
641            });
642
643        let pre_study_id = node_data_graph
644            .get_node_id_with(|node| {
645                node.kind == subproblem::StudyPeriodKind::PreStudy
646            })
647            .ok_or_else(|| {
648                "Failed to find initial condition info in graph".to_string()
649            })?;
650
651        let study_period_ids = node_data_graph.get_all_node_ids_with(|node| {
652            node.kind == subproblem::StudyPeriodKind::Study
653        });
654
655        // TODO - for the path graph case, this is enough. But for markovian graphs
656        // and cyclic graphs (infinite horizon) this might not be enough.
657        let graph_bfs_table = study_period_ids
658            .iter()
659            .map(|id| node_data_graph.get_bfs(*id, true))
660            .collect();
661
662        Ok(Self {
663            node_data_graph,
664            future_cost_function_graph,
665            initial_condition,
666            seed,
667            pre_study_id,
668            study_period_ids,
669            graph_bfs_table,
670        })
671    }
672
673    pub fn train(
674        &mut self,
675        num_iterations: usize,
676        num_forward_passes: usize,
677        saa: &scenario::SAA,
678    ) -> Result<(), String> {
679        // rng is always created for reproducibility
680        let mut rng = Xoshiro256Plus::seed_from_u64(self.seed);
681
682        let begin = Instant::now();
683
684        log::training_greeting(num_iterations, num_forward_passes);
685        log::training_table_divider();
686        log::training_table_header();
687        log::training_table_divider();
688
689        let mut train_handlers: Vec<SddpTrainHandler> = (0..num_forward_passes)
690            .map(|_| {
691                SddpTrainHandler::new(
692                    &self.pre_study_id,
693                    &self.node_data_graph,
694                    &self.initial_condition,
695                    saa,
696                )
697            })
698            .collect::<Result<_, _>>()?;
699
700        // Main training loop
701        for index in 0..num_iterations {
702            let iter_begin = Instant::now();
703
704            // Sample noises for each forward pass
705            let all_sampled_noises: Vec<_> = (0..num_forward_passes)
706                .map(|_| saa.sample_scenario(&mut rng))
707                .collect();
708
709            // --- Parallel Forward Passes ---
710            let forward_costs: Vec<f64> = train_handlers
711                .par_iter_mut()
712                .zip(all_sampled_noises.par_iter())
713                .map(|(handler, noises)| self.forward(noises.to_vec(), handler))
714                .collect::<Result<Vec<f64>, String>>()?;
715
716            let avg_forward_cost = utils::mean(&forward_costs);
717
718            // --- Parallel Backward Pass with Stage-wise Synchronization ---
719            let num_study_periods = self.study_period_ids.len();
720            let mut lower_bound = 0.0;
721            // Iterate backwards through study periods
722            for rev_idx in 0..num_study_periods {
723                let current_stage_original_idx =
724                    num_study_periods - 1 - rev_idx;
725                let id = self.study_period_ids[current_stage_original_idx];
726
727                let past_node_ids = self
728                .graph_bfs_table
729                .get(current_stage_original_idx)
730                .ok_or_else(||
731                    format!("Could not find past node ids for node {} (original_idx {})", id, current_stage_original_idx)
732                )?;
733                // If it's not the very first stage of the study (i.e., has a parent stage)
734                if current_stage_original_idx > 0 {
735                    train_handlers
736                        .par_iter_mut()
737                        .map(|handler| {
738                            handler.backward_step_at_node(
739                                id,
740                                past_node_ids,
741                                &self.node_data_graph,
742                                saa,
743                                &self.future_cost_function_graph,
744                            )
745                        })
746                        .collect::<Result<(), String>>()?;
747                    // TODO - try serial cut selection instead of selecting while the FCF is locked on each thread
748                } else {
749                    lower_bound = train_handlers
750                        .get_mut(0)
751                        .unwrap()
752                        .eval_first_stage_bound(
753                            id,
754                            past_node_ids,
755                            &self.node_data_graph,
756                            saa,
757                        )
758                        .unwrap();
759                }
760            }
761
762            let iter_time = iter_begin.elapsed();
763            log::training_table_row(
764                index + 1,
765                lower_bound,
766                avg_forward_cost,
767                iter_time,
768            );
769        }
770
771        log::training_table_divider();
772        let duration = begin.elapsed();
773        log::training_duration(duration);
774        log::policy_size(
775            self.future_cost_function_graph
776                .get_node(1)
777                .ok_or_else(|| {
778                    format!("Could not find node 1 for counting cuts")
779                })?
780                .data
781                .lock()
782                .unwrap()
783                .cut_pool
784                .total_cut_count,
785        );
786        Ok(())
787    }
788
789    pub fn forward(
790        &self,
791        sampled_noises: Vec<&scenario::SampledBranchingNoises>,
792        handler: &mut SddpTrainHandler,
793    ) -> Result<f64, String> {
794        let trajectory_cost = handler.forward(
795            sampled_noises,
796            &self.node_data_graph,
797            &self.graph_bfs_table,
798            &self.study_period_ids,
799        )?;
800        Ok(trajectory_cost)
801    }
802
803    pub fn simulate(
804        &mut self,
805        num_simulation_scenarios: usize,
806        saa: &scenario::SAA,
807    ) -> Result<Vec<SddpSimulationHandler>, String> {
808        let mut rng = Xoshiro256Plus::seed_from_u64(self.seed);
809
810        let begin = Instant::now();
811
812        log::simulation_greeting(num_simulation_scenarios);
813
814        let all_sampled_noises: Vec<_> = (0..num_simulation_scenarios)
815            .map(|_| saa.sample_scenario(&mut rng))
816            .collect();
817
818        let mut simulation_handlers: Vec<SddpSimulationHandler> = (0
819            ..num_simulation_scenarios)
820            .map(|_| {
821                SddpSimulationHandler::new(
822                    &self.pre_study_id,
823                    &self.node_data_graph,
824                    &self.initial_condition,
825                )
826            })
827            .collect::<Result<_, _>>()?;
828
829        let simulation_costs: Vec<f64> = simulation_handlers
830            .par_iter_mut()
831            .zip(all_sampled_noises.par_iter())
832            .map(|(handler, noises)| {
833                handler.forward(
834                    noises.to_vec(),
835                    &self.node_data_graph,
836                    &self.graph_bfs_table,
837                    &self.study_period_ids,
838                )
839            })
840            .collect::<Result<Vec<f64>, String>>()?;
841
842        let _simulation_costs: Vec<f64> = simulation_handlers
843            .par_iter()
844            .map(|t| {
845                Ok(self.study_period_ids
846                    .iter()
847                    .map(|&id| {
848                        t.get_realization_at_node(id)
849                            .map(|node| node.data.current_stage_objective)
850                            .ok_or_else(|| format!("Could not find realization for node {} in simulation_costs", id))
851                    })
852                    .collect::<Result<Vec<f64>, String>>()?
853                    .iter()
854                    .sum())
855            })
856            .collect::<Result<Vec<f64>, String>>()?;
857        let mean_cost = utils::mean(&simulation_costs);
858        let std_cost = utils::standard_deviation(&simulation_costs);
859        log::simulation_stats(mean_cost, std_cost);
860        let duration = begin.elapsed();
861        log::simulation_duration(duration);
862
863        Ok(simulation_handlers)
864    }
865}
866
867fn step(
868    data_node: &graph::Node<NodeData>,
869    subproblem: &mut subproblem::Subproblem,
870    realization_container: &mut subproblem::Realization,
871    noises: &scenario::SampledBranchingNoises,
872) -> Result<(), String> {
873    subproblem.realize_uncertainties(
874        noises,
875        &data_node.data.load_stochastic_process,
876        &data_node.data.inflow_stochastic_process,
877        realization_container,
878    )?;
879    Ok(())
880}
881
882fn reuse_forward_basis(
883    subproblem: &mut subproblem::Subproblem,
884    node_forward_realization: &subproblem::Realization,
885) -> Result<(), String> {
886    if node_forward_realization.basis.columns().len() > 0 {
887        if let Some(model) = subproblem.model.as_mut() {
888            let num_model_rows = model.num_rows();
889            let mut forward_rows =
890                node_forward_realization.basis.rows().to_vec();
891            let num_forward_rows = forward_rows.len();
892
893            // checks if should add zeros to the rows (new cuts added)
894            if num_forward_rows < num_model_rows {
895                let row_diff = num_model_rows - num_forward_rows;
896                forward_rows.append(&mut vec![0; row_diff]);
897            } else if num_forward_rows > num_model_rows {
898                forward_rows.truncate(num_model_rows);
899            }
900
901            model.set_basis(
902                Some(node_forward_realization.basis.columns()),
903                Some(&forward_rows),
904            );
905        }
906    }
907    Ok(())
908}
909
910fn eval_first_stage_bound(
911    branching_realizations: &Vec<subproblem::Realization>,
912    risk_measure: &Box<dyn risk_measure::RiskMeasure>,
913) -> Result<f64, String> {
914    let costs: Vec<f64> = branching_realizations
915        .iter()
916        .map(|r| r.total_stage_objective)
917        .collect();
918    let num_branchings = costs.len();
919    let probabilities = utils::uniform_prob_by_count(num_branchings);
920    let adjusted_probabilities =
921        risk_measure.adjust_probabilities(&probabilities, &costs);
922    let average_solution_cost =
923        utils::dot_product(adjusted_probabilities, &costs);
924    Ok(average_solution_cost)
925}
926
#[cfg(test)]
mod tests {

    use super::*;
    use rand_distr::{LogNormal, Normal};

    /// Start and end date used by every pre-study node.
    const PRE_STUDY_DATE: &str = "1970-01-01T00:00:00Z";
    const JAN_2025: &str = "2025-01-01T00:00:00Z";
    const FEB_2025: &str = "2025-02-01T00:00:00Z";
    const MAR_2025: &str = "2025-03-01T00:00:00Z";
    const APR_2025: &str = "2025-04-01T00:00:00Z";

    /// Builds a node with the default system and the settings shared by every
    /// node in these tests: "expectation" risk measure, "naive" load and
    /// inflow models and "storage" state.
    fn node_data(
        id: isize,
        stage_id: usize,
        season_id: usize,
        start_date: &str,
        end_date: &str,
        kind: subproblem::StudyPeriodKind,
    ) -> NodeData {
        NodeData::new(
            id,
            stage_id,
            season_id,
            start_date,
            end_date,
            kind,
            system::System::default(),
            "expectation",
            "naive",
            "naive",
            "storage",
        )
        .unwrap()
    }

    /// Builds the pre-study node (id -1, stage 0, season 0).
    fn pre_study_node_data() -> NodeData {
        node_data(
            -1,
            0,
            0,
            PRE_STUDY_DATE,
            PRE_STUDY_DATE,
            subproblem::StudyPeriodKind::PreStudy,
        )
    }

    /// Builds a study-period node whose stage and season ids both equal `id`.
    fn study_node_data(id: isize, start_date: &str, end_date: &str) -> NodeData {
        let index: usize = id.try_into().unwrap();
        node_data(
            id,
            index,
            index,
            start_date,
            end_date,
            subproblem::StudyPeriodKind::Study,
        )
    }

    /// Initial storage of 83.222 for the single hydro of the default system.
    fn default_initial_condition() -> initial_condition::InitialCondition {
        initial_condition::InitialCondition::new(vec![83.222], vec![])
    }

    /// Noises for one branching: a load of 75.0 and the given inflow.
    fn noises_with_inflow(inflow: f64) -> scenario::SampledBranchingNoises {
        scenario::SampledBranchingNoises {
            load_noises: vec![75.0],
            inflow_noises: vec![inflow],
            num_load_entities: 1,
            num_inflow_entities: 1,
        }
    }

    /// A node with a single branching built by `noises_with_inflow`.
    fn single_branching(inflow: f64) -> scenario::SampledNodeBranchings {
        scenario::SampledNodeBranchings {
            num_branchings: 1,
            branching_noises: vec![noises_with_inflow(inflow)],
        }
    }

    /// Builds the linear graph pre-study -> 0 -> 1 -> 2, with the study
    /// nodes spanning consecutive months from January to April 2025.
    fn build_three_stage_chain() -> graph::DirectedGraph<NodeData> {
        let mut node_data_graph = graph::DirectedGraph::<NodeData>::new();
        let pre_study_id =
            node_data_graph.add_node(pre_study_node_data()).unwrap();
        let node_0_id = node_data_graph
            .add_node(study_node_data(0, JAN_2025, FEB_2025))
            .unwrap();
        let node_1_id = node_data_graph
            .add_node(study_node_data(1, FEB_2025, MAR_2025))
            .unwrap();
        let node_2_id = node_data_graph
            .add_node(study_node_data(2, MAR_2025, APR_2025))
            .unwrap();
        node_data_graph.add_edge(pre_study_id, node_0_id).unwrap();
        node_data_graph.add_edge(node_0_id, node_1_id).unwrap();
        node_data_graph.add_edge(node_1_id, node_2_id).unwrap();
        node_data_graph
    }

    /// Builds the branching graph pre-study -> 0 -> {1, 2, 3} (all study
    /// nodes span January 2025). Also builds an SAA generated with seed 0
    /// from five identical noise generators (3 branchings each), one per
    /// graph node.
    fn build_branching_case() -> (graph::DirectedGraph<NodeData>, scenario::SAA)
    {
        let mut node_data_graph = graph::DirectedGraph::<NodeData>::new();
        let pre_study_id =
            node_data_graph.add_node(pre_study_node_data()).unwrap();
        let root_id = node_data_graph
            .add_node(study_node_data(0, JAN_2025, FEB_2025))
            .unwrap();
        node_data_graph.add_edge(pre_study_id, root_id).unwrap();
        for child in 1..4 {
            let child_id = node_data_graph
                .add_node(study_node_data(child, JAN_2025, FEB_2025))
                .unwrap();
            node_data_graph.add_edge(root_id, child_id).unwrap();
        }

        let mut scenario_generator = scenario::NoiseGenerator::new();
        for _ in 0..5 {
            scenario_generator.add_node_generator(
                vec![Normal::new(75.0, 0.0).unwrap()],
                vec![LogNormal::new(3.6, 0.6928).unwrap()],
                3,
            );
        }

        (node_data_graph, scenario_generator.generate(0))
    }

    #[test]
    fn test_forward_with_default_system() {
        let node_data_graph = build_three_stage_chain();
        let initial_condition = default_initial_condition();

        let example_noises = noises_with_inflow(10.0);
        let sampled_noises = vec![&example_noises; 4];

        let pre_study_id = node_data_graph
            .get_node_id_with(|node| {
                node.kind == subproblem::StudyPeriodKind::PreStudy
            })
            .expect("the chain is built with a pre-study node");

        let study_period_ids = node_data_graph.get_all_node_ids_with(|node| {
            node.kind == subproblem::StudyPeriodKind::Study
        });

        let graph_bfs_table = study_period_ids
            .iter()
            .map(|id| node_data_graph.get_bfs(*id, true))
            .collect();

        let saa = generate_test_saa_for_four_stages();

        let mut handler = SddpTrainHandler::new(
            &pre_study_id,
            &node_data_graph,
            &initial_condition,
            &saa,
        )
        .unwrap();

        handler
            .forward(
                sampled_noises,
                &node_data_graph,
                &graph_bfs_table,
                &study_period_ids,
            )
            .unwrap();
    }

    fn generate_test_saa_for_four_stages() -> scenario::SAA {
        scenario::SAA {
            branching_samples: vec![
                single_branching(5.0),
                single_branching(10.0),
                single_branching(15.0),
                single_branching(15.0),
            ],
            index_samplers: vec![],
        }
    }

    #[test]
    fn test_backward_with_default_system() {
        let node_data_graph = build_three_stage_chain();
        let initial_condition = default_initial_condition();

        let future_cost_function_graph =
            node_data_graph.map_topology_with(|_node_data, _id| {
                Arc::new(Mutex::new(fcf::FutureCostFunction::new()))
            });

        let example_noises = noises_with_inflow(10.0);
        let sampled_noises = vec![&example_noises; 4];

        let pre_study_id = node_data_graph
            .get_node_id_with(|node| {
                node.kind == subproblem::StudyPeriodKind::PreStudy
            })
            .expect("the chain is built with a pre-study node");

        let study_period_ids = node_data_graph.get_all_node_ids_with(|node| {
            node.kind == subproblem::StudyPeriodKind::Study
        });

        let graph_bfs_table = study_period_ids
            .iter()
            .map(|id| node_data_graph.get_bfs(*id, true))
            .collect();

        let saa = generate_test_saa_for_four_stages();

        let mut handler = SddpTrainHandler::new(
            &pre_study_id,
            &node_data_graph,
            &initial_condition,
            &saa,
        )
        .unwrap();

        handler
            .forward(
                sampled_noises,
                &node_data_graph,
                &graph_bfs_table,
                &study_period_ids,
            )
            .unwrap();

        let current_stage_original_idx = 1; // Corresponds to node 1
        let id = study_period_ids[current_stage_original_idx];
        let past_node_ids =
            graph_bfs_table.get(current_stage_original_idx).unwrap();

        handler
            .backward_step_at_node(
                id,
                past_node_ids,
                &node_data_graph,
                &saa,
                &future_cost_function_graph,
            )
            .unwrap();
    }

    #[test]
    fn test_train_with_default_system() {
        let (node_data_graph, saa) = build_branching_case();

        let mut sddp_algo = SddpAlgorithm::new(
            node_data_graph,
            default_initial_condition(),
            0,
        )
        .unwrap();

        sddp_algo.train(24, 1, &saa).unwrap();
    }

    #[test]
    fn test_simulate_with_default_system() {
        let (node_data_graph, saa) = build_branching_case();

        let mut sddp_algo = SddpAlgorithm::new(
            node_data_graph,
            default_initial_condition(),
            0,
        )
        .unwrap();

        sddp_algo.train(24, 1, &saa).unwrap();

        sddp_algo.simulate(100, &saa).unwrap();
    }
}