use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::Sender;
use std::time::Instant;
use cobre_comm::Communicator;
use cobre_core::{StageSelectionRecord, TrainingEvent};
use cobre_solver::Basis;
use cobre_solver::RowBatch;
use cobre_solver::SolverInterface;
use cobre_stochastic::OpeningTree;
use crate::{
SddpError, StoppingRuleSet, TrainingConfig, TrajectoryRecord,
backward::run_backward_pass,
context::{StageContext, TrainingContext},
convergence::ConvergenceMonitor,
cut::CutRowMap,
cut::fcf::FutureCostFunction,
cut_selection::CutSelectionStrategy,
cut_sync::CutSyncBuffers,
evaluate_lower_bound,
forward::{ForwardPassBatch, run_forward_pass, sync_forward},
lower_bound::LbEvalSpec,
lp_builder::PatchBuffer,
risk_measure::RiskMeasure,
solver_stats::{SolverStatsDelta, SolverStatsEntry, aggregate_solver_statistics},
state_exchange::ExchangeBuffers,
stopping_rule::RULE_ITERATION_LIMIT,
workspace::{BasisStore, WorkspacePool},
};
/// Outcome of a [`train`] run: the training summary plus the error, if any, that cut it short.
///
/// When an iteration fails, `train` still returns `Ok(TrainingOutcome)` with `error` set and
/// `result` describing the last fully completed iteration.
#[derive(Debug)]
pub struct TrainingOutcome {
/// Summary of the run; partial (reason `"error"`) when `error` is `Some`.
pub result: TrainingResult,
/// The error that aborted training mid-iteration, or `None` on normal termination.
pub error: Option<SddpError>,
}
/// Summary figures of a training run, as assembled by [`train`].
#[derive(Debug, Clone)]
pub struct TrainingResult {
/// Lower bound reported by the convergence monitor after the last completed iteration
/// (`0.0` if no iteration completed).
pub final_lb: f64,
/// Upper bound reported by the convergence monitor after the last completed iteration
/// (`0.0` if no iteration completed).
pub final_ub: f64,
/// Spread of the upper-bound estimate (`upper_bound_std()` of the convergence monitor).
pub final_ub_std: f64,
/// Gap reported by the convergence monitor after the last completed iteration.
pub final_gap: f64,
/// Number of the last fully completed iteration (iterations are numbered from 1, so `0`
/// means none completed).
pub iterations: u64,
/// Why training stopped: the name of the stopping rule that fired, the iteration-limit rule
/// name when `max_iterations` was exhausted, `"unknown"` if a stop was signalled without a
/// triggered rule, or `"error"` when an iteration failed.
pub reason: String,
/// Wall-clock duration of the run in milliseconds, clamped to at least 1.
pub total_time_ms: u64,
/// One entry per stage: the basis stored for this rank's last local forward scenario, or
/// `None` when the basis store holds nothing for that slot.
pub basis_cache: Vec<Option<Basis>>,
/// Solver-statistic deltas recorded by `train` as `(iteration, phase, stage, delta)` tuples,
/// with phase `"forward"`, `"backward"` (one entry per stage) or `"lower_bound"`, and stage
/// `-1` for phases that are not tied to a single stage.
pub solver_stats_log: Vec<SolverStatsEntry>,
}
/// Forwards `event` to the progress channel when one is configured.
///
/// Delivery is best-effort: `Sender::send` only fails once the receiving end has been
/// dropped, and that failure is deliberately ignored so a detached listener never
/// interrupts training.
#[inline]
fn emit(sender: Option<&Sender<TrainingEvent>>, event: TrainingEvent) {
    let Some(channel) = sender else {
        return;
    };
    // A disconnected listener is not a training error; drop the event silently.
    let _ = channel.send(event);
}
/// Decides whether a cut-row map should be rebuilt from scratch.
///
/// A rebuild is due when either of the following holds:
/// - more than one fifth of the cut rows tracked by `row_map` are phantom rows, meaning they
///   are counted in `total_cut_rows()` but not in `active_count()`;
/// - `iterations_since_rebuild` has reached the maximum rebuild interval (50 iterations).
fn needs_periodic_rebuild(row_map: &CutRowMap, iterations_since_rebuild: u64) -> bool {
    // Upper bound on the number of iterations between two consecutive rebuilds.
    const MAX_ITERATIONS_BETWEEN_REBUILDS: u64 = 50;

    let total = row_map.total_cut_rows();
    // Saturate so an inconsistent map (more active than total rows) cannot underflow; a
    // plain subtraction would panic in debug builds and wrap to a huge count in release.
    let phantom = total.saturating_sub(row_map.active_count());
    // `phantom / total > 1/5` in integer arithmetic. This is false whenever `total == 0`,
    // because then `phantom == 0`.
    phantom * 5 > total || iterations_since_rebuild >= MAX_ITERATIONS_BETWEEN_REBUILDS
}
#[allow(
clippy::too_many_arguments,
clippy::too_many_lines,
clippy::similar_names
)]
pub fn train<S: SolverInterface + Send, C: Communicator>(
solver: &mut S,
config: TrainingConfig,
fcf: &mut FutureCostFunction,
stage_ctx: &StageContext<'_>,
training_ctx: &TrainingContext<'_>,
opening_tree: &OpeningTree,
risk_measures: &[RiskMeasure],
stopping_rules: StoppingRuleSet,
cut_selection: Option<&CutSelectionStrategy>,
cut_activity_tolerance: f64,
shutdown_flag: Option<&Arc<AtomicBool>>,
comm: &C,
n_fwd_threads: usize,
solver_factory: impl Fn() -> Result<S, cobre_solver::SolverError>,
max_blocks: usize,
) -> Result<TrainingOutcome, SddpError> {
let horizon = training_ctx.horizon;
let indexer = training_ctx.indexer;
let initial_state = training_ctx.initial_state;
let num_stages = horizon.num_stages();
let num_ranks = comm.size();
let my_rank = comm.rank();
let total_forward_passes = config.forward_passes as usize;
let n_state = indexer.n_state;
let base_fwd = total_forward_passes / num_ranks;
let remainder_fwd = total_forward_passes % num_ranks;
let my_actual_fwd = base_fwd + usize::from(my_rank < remainder_fwd);
let my_fwd_offset = base_fwd * my_rank + my_rank.min(remainder_fwd);
let max_local_fwd = base_fwd + usize::from(remainder_fwd > 0);
let empty_record = TrajectoryRecord {
primal: vec![],
dual: vec![],
stage_cost: 0.0,
state: vec![0.0; n_state],
};
let mut records = vec![empty_record; max_local_fwd * num_stages];
let n_threads = n_fwd_threads.max(1);
let mut fwd_pool = WorkspacePool::try_new(
n_threads,
indexer.hydro_count,
indexer.max_par_order,
n_state,
stage_ctx.n_load_buses,
max_blocks,
solver_factory,
)
.map_err(SddpError::Solver)?;
let mut basis_store = BasisStore::new(max_local_fwd, num_stages);
let mut patch_buf = PatchBuffer::new(indexer.hydro_count, indexer.max_par_order, 0, 0);
let mut convergence_monitor = ConvergenceMonitor::new(stopping_rules);
let mut exchange_bufs = ExchangeBuffers::new(n_state, max_local_fwd, num_ranks);
let mut cut_sync_bufs =
CutSyncBuffers::with_distribution(n_state, max_local_fwd, num_ranks, total_forward_passes);
let start_time = Instant::now();
let TrainingConfig {
forward_passes: config_forward_passes,
max_iterations,
event_sender,
..
} = config;
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::TrainingStarted {
case_name: String::new(),
stages: num_stages as u32,
hydros: indexer.hydro_count as u32,
thermals: 0,
ranks: num_ranks as u32,
#[allow(clippy::cast_possible_truncation)]
threads_per_rank: n_threads as u32,
timestamp: String::new(),
},
);
let mut final_lb = 0.0;
let mut final_ub = 0.0;
let mut final_ub_std = 0.0;
let mut final_gap = 0.0;
let mut completed_iterations = 0u64;
let mut termination_reason = RULE_ITERATION_LIMIT.to_string();
let mut solver_stats_log: Vec<SolverStatsEntry> = Vec::new();
let mut cut_batches: Vec<RowBatch> = (0..num_stages)
.map(|_| RowBatch {
num_rows: 0,
row_starts: Vec::new(),
col_indices: Vec::new(),
values: Vec::new(),
row_lower: Vec::new(),
row_upper: Vec::new(),
})
.collect();
let mut lb_cut_batch = RowBatch {
num_rows: 0,
row_starts: Vec::new(),
col_indices: Vec::new(),
values: Vec::new(),
row_lower: Vec::new(),
row_upper: Vec::new(),
};
let mut lb_cut_row_map = CutRowMap::new(fcf.pools[0].capacity, stage_ctx.templates[0].num_rows);
let mut lb_iterations_since_rebuild: u64 = 0;
macro_rules! on_error {
($e:expr) => {{
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::TrainingFinished {
reason: "error".to_string(),
iterations: completed_iterations,
final_lb,
final_ub,
total_time_ms: (start_time.elapsed().as_millis() as u64).max(1),
total_cuts: fcf.total_active_cuts() as u64,
},
);
let last_scenario = my_actual_fwd.saturating_sub(1);
#[allow(clippy::cast_possible_truncation)]
let total_time_ms = (start_time.elapsed().as_millis() as u64).max(1);
let basis_cache = (0..num_stages)
.map(|t| basis_store.get(last_scenario, t).cloned())
.collect();
return Ok(TrainingOutcome {
result: TrainingResult {
final_lb,
final_ub,
final_ub_std,
final_gap,
iterations: completed_iterations,
reason: "error".to_string(),
total_time_ms,
basis_cache,
solver_stats_log,
},
error: Some($e),
});
}};
}
for iteration in 1..=max_iterations {
if let Some(flag) = shutdown_flag {
if flag.load(Ordering::Relaxed) {
convergence_monitor.set_shutdown();
}
}
let iter_start = Instant::now();
let fwd_record_len = my_actual_fwd * num_stages;
let fwd_batch = ForwardPassBatch {
local_forward_passes: my_actual_fwd,
iteration,
fwd_offset: my_fwd_offset,
};
let fwd_stats_before = {
let pool_stats: Vec<_> = fwd_pool
.workspaces
.iter()
.map(|w| w.solver.statistics())
.collect();
aggregate_solver_statistics(&pool_stats)
};
let forward_result = match run_forward_pass(
&mut fwd_pool.workspaces,
&mut basis_store,
stage_ctx,
fcf,
&mut cut_batches,
training_ctx,
&fwd_batch,
&mut records[..fwd_record_len],
) {
Ok(r) => r,
Err(e) => on_error!(e),
};
let fwd_delta = {
let pool_stats: Vec<_> = fwd_pool
.workspaces
.iter()
.map(|w| w.solver.statistics())
.collect();
let fwd_stats_after = aggregate_solver_statistics(&pool_stats);
SolverStatsDelta::from_snapshots(&fwd_stats_before, &fwd_stats_after)
};
let fwd_solve_time_ms = fwd_delta.solve_time_ms;
solver_stats_log.push((iteration, "forward".to_string(), -1, fwd_delta));
let forward_elapsed_ms = forward_result.elapsed_ms;
let local_n = forward_result.scenario_costs.len();
let local_cost_sum: f64 = forward_result.scenario_costs.iter().sum();
emit(
event_sender.as_ref(),
TrainingEvent::ForwardPassComplete {
iteration,
scenarios: config_forward_passes,
#[allow(clippy::cast_precision_loss)]
ub_mean: if local_n > 0 {
local_cost_sum / local_n as f64
} else {
0.0
},
ub_std: 0.0,
elapsed_ms: forward_elapsed_ms,
},
);
let sync_result = match sync_forward(&forward_result, comm, total_forward_passes) {
Ok(r) => r,
Err(e) => on_error!(e),
};
emit(
event_sender.as_ref(),
TrainingEvent::ForwardSyncComplete {
iteration,
global_ub_mean: sync_result.global_ub_mean,
global_ub_std: sync_result.global_ub_std,
sync_time_ms: sync_result.sync_time_ms,
},
);
let mut bwd_spec = crate::backward::BackwardPassSpec {
exchange: &mut exchange_bufs,
records: &records,
iteration,
local_work: my_actual_fwd,
fwd_offset: my_fwd_offset,
risk_measures,
cut_activity_tolerance,
cut_sync_bufs: &mut cut_sync_bufs,
};
let backward_result = match run_backward_pass(
&mut fwd_pool.workspaces,
&basis_store,
stage_ctx,
fcf,
&mut cut_batches,
training_ctx,
&mut bwd_spec,
comm,
) {
Ok(r) => r,
Err(e) => on_error!(e),
};
let bwd_solve_time_ms = {
let deltas: Vec<_> = backward_result
.stage_stats
.iter()
.map(|(_, d)| d.clone())
.collect();
let agg = SolverStatsDelta::aggregate(&deltas);
let total_ms = agg.solve_time_ms;
#[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)]
for (stage_idx, delta) in &backward_result.stage_stats {
solver_stats_log.push((
iteration,
"backward".to_string(),
*stage_idx as i32,
delta.clone(),
));
}
total_ms
};
let backward_elapsed_ms = backward_result.elapsed_ms;
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::BackwardPassComplete {
iteration,
cuts_generated: backward_result.cuts_generated as u32,
stages_processed: num_stages.saturating_sub(1) as u32,
elapsed_ms: backward_elapsed_ms,
state_exchange_time_ms: backward_result.state_exchange_time_ms,
cut_batch_build_time_ms: backward_result.cut_batch_build_time_ms,
rayon_overhead_time_ms: backward_result.rayon_overhead_time_ms,
},
);
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::CutSyncComplete {
iteration,
cuts_distributed: backward_result.cuts_generated as u32,
cuts_active: fcf.total_active_cuts() as u32,
cuts_removed: 0,
sync_time_ms: backward_result.cut_sync_time_ms,
},
);
if let Some(strategy) = cut_selection {
if strategy.should_run(iteration) {
let sel_start = Instant::now();
let num_sel_stages = num_stages.saturating_sub(1);
let mut cuts_deactivated = 0u32;
let mut per_stage = Vec::with_capacity(num_sel_stages);
#[allow(clippy::cast_possible_truncation)]
{
let pool0 = &fcf.pools[0];
let active_0 = pool0.active_count() as u32;
per_stage.push(StageSelectionRecord {
stage: 0,
cuts_populated: pool0.populated_count as u32,
cuts_active_before: active_0,
cuts_deactivated: 0,
cuts_active_after: active_0,
});
}
#[allow(clippy::cast_possible_truncation)]
for stage in 1..num_sel_stages {
let pool = &fcf.pools[stage];
let populated = pool.populated_count as u32;
let active_before = pool.active_count() as u32;
let stage_u32 = stage as u32;
let deact = strategy.select_for_stage(
&pool.metadata[..pool.populated_count],
&pool.active[..pool.populated_count],
iteration,
stage_u32,
);
let n_deact = deact.indices.len() as u32;
cuts_deactivated += n_deact;
fcf.pools[stage].deactivate(&deact.indices);
let active_after = fcf.pools[stage].active_count() as u32;
per_stage.push(StageSelectionRecord {
stage: stage_u32,
cuts_populated: populated,
cuts_active_before: active_before,
cuts_deactivated: n_deact,
cuts_active_after: active_after,
});
}
#[allow(clippy::cast_possible_truncation)]
let selection_time_ms = sel_start.elapsed().as_millis() as u64;
#[allow(clippy::cast_possible_truncation)]
let stages_processed = num_sel_stages as u32;
emit(
event_sender.as_ref(),
TrainingEvent::CutSelectionComplete {
iteration,
cuts_deactivated,
stages_processed,
selection_time_ms,
allgatherv_time_ms: 0,
per_stage,
},
);
}
}
lb_iterations_since_rebuild += 1;
if comm.rank() == 0 && needs_periodic_rebuild(&lb_cut_row_map, lb_iterations_since_rebuild)
{
lb_cut_row_map.reset(stage_ctx.templates[0].num_rows);
lb_iterations_since_rebuild = 0;
}
let lb_stats_before = solver.statistics();
let lb_spec = LbEvalSpec {
template: &stage_ctx.templates[0],
base_row: stage_ctx.base_rows[0],
noise_scale: stage_ctx.noise_scale,
n_hydros: stage_ctx.n_hydros,
opening_tree,
risk_measure: &risk_measures[0],
stochastic: Some(training_ctx.stochastic),
n_load_buses: stage_ctx.n_load_buses,
ncs_max_gen: stage_ctx.ncs_max_gen,
block_count: stage_ctx.block_counts_per_stage[0],
ncs_generation: indexer.ncs_generation.clone(),
};
let lb = match evaluate_lower_bound(
solver,
fcf,
initial_state,
indexer,
&mut patch_buf,
&mut lb_cut_batch,
&lb_spec,
comm,
Some(&mut lb_cut_row_map),
) {
Ok(r) => r,
Err(e) => on_error!(e),
};
let lb_stats_after = solver.statistics();
let lb_lp_solves = lb_stats_after.solve_count - lb_stats_before.solve_count;
let lb_delta = SolverStatsDelta::from_snapshots(&lb_stats_before, &lb_stats_after);
let lb_solve_time_ms = lb_delta.solve_time_ms;
solver_stats_log.push((iteration, "lower_bound".to_string(), -1, lb_delta));
let (should_stop, rule_results) = convergence_monitor.update(lb, &sync_result);
final_lb = convergence_monitor.lower_bound();
final_ub = convergence_monitor.upper_bound();
final_ub_std = convergence_monitor.upper_bound_std();
final_gap = convergence_monitor.gap();
emit(
event_sender.as_ref(),
TrainingEvent::ConvergenceUpdate {
iteration,
lower_bound: final_lb,
upper_bound: final_ub,
upper_bound_std: convergence_monitor.upper_bound_std(),
gap: final_gap,
rules_evaluated: rule_results.clone(),
},
);
#[allow(clippy::cast_possible_truncation)]
let wall_time_ms = start_time.elapsed().as_millis() as u64;
#[allow(clippy::cast_possible_truncation)]
let iteration_time_ms = iter_start.elapsed().as_millis() as u64;
emit(
event_sender.as_ref(),
TrainingEvent::IterationSummary {
iteration,
lower_bound: final_lb,
upper_bound: final_ub,
gap: final_gap,
wall_time_ms,
iteration_time_ms,
forward_ms: forward_elapsed_ms,
backward_ms: backward_elapsed_ms,
lp_solves: forward_result.lp_solves + backward_result.lp_solves + lb_lp_solves,
solve_time_ms: fwd_solve_time_ms + bwd_solve_time_ms + lb_solve_time_ms,
},
);
completed_iterations = iteration;
if should_stop {
termination_reason = rule_results
.iter()
.find(|r| r.triggered)
.map_or_else(|| "unknown".to_string(), |r| r.rule_name.clone());
break;
}
}
#[allow(clippy::cast_possible_truncation)]
let total_time_ms = (start_time.elapsed().as_millis() as u64).max(1);
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::TrainingFinished {
reason: termination_reason.clone(),
iterations: completed_iterations,
final_lb,
final_ub,
total_time_ms,
total_cuts: fcf.total_active_cuts() as u64,
},
);
let last_scenario = my_actual_fwd.saturating_sub(1);
let basis_cache = (0..num_stages)
.map(|t| basis_store.get(last_scenario, t).cloned())
.collect();
Ok(TrainingOutcome {
result: TrainingResult {
final_lb,
final_ub,
final_ub_std,
final_gap,
iterations: completed_iterations,
reason: termination_reason,
total_time_ms,
basis_cache,
solver_stats_log,
},
error: None,
})
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::float_cmp,
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::too_many_lines
)]
mod tests {
use std::collections::BTreeMap;
use std::sync::mpsc;
use chrono::NaiveDate;
use cobre_comm::{CommData, CommError, Communicator, ReduceOp};
use cobre_core::{
Bus, EntityId, SystemBuilder, TrainingEvent,
scenario::{CorrelationEntity, CorrelationGroup, CorrelationModel, CorrelationProfile},
temporal::{
Block, BlockMode, NoiseMethod, ScenarioSourceConfig, Stage, StageRiskConfig,
StageStateConfig,
},
};
use cobre_solver::{
Basis, LpSolution, RowBatch, SolverError, SolverInterface, SolverStatistics, StageTemplate,
};
use cobre_stochastic::{
StochasticContext, build_stochastic_context, tree::opening_tree::OpeningTree,
};
use super::train;
use crate::{
HorizonMode, InflowNonNegativityMethod, RiskMeasure, SddpError, StageIndexer, StoppingMode,
StoppingRule, StoppingRuleSet, TrainingConfig,
context::{StageContext, TrainingContext},
cut::fcf::FutureCostFunction,
};
fn minimal_template(n_state: usize) -> StageTemplate {
let _ = n_state;
StageTemplate {
num_cols: 4,
num_rows: 2,
num_nz: 1,
col_starts: vec![0_i32, 0, 0, 1, 1],
row_indices: vec![0_i32],
values: vec![1.0],
col_lower: vec![0.0, f64::NEG_INFINITY, 0.0, 0.0],
col_upper: vec![f64::INFINITY; 4],
objective: vec![0.0, 0.0, 0.0, 1.0],
row_lower: vec![0.0, 0.0],
row_upper: vec![0.0, 0.0],
n_state: 1,
n_transfer: 0,
n_dual_relevant: 1,
n_hydro: 1,
max_par_order: 0,
col_scale: Vec::new(),
row_scale: Vec::new(),
}
}
fn fixed_solution(objective: f64) -> LpSolution {
LpSolution {
objective,
primal: vec![0.0; 4],
dual: vec![0.0; 2],
reduced_costs: vec![0.0; 4],
iterations: 0,
solve_time_seconds: 0.0,
}
}
struct MockSolver {
objectives: Vec<f64>,
call_count: usize,
infeasible_on_first: bool,
}
impl MockSolver {
fn with_fixed(objective: f64) -> Self {
Self {
objectives: vec![objective],
call_count: 0,
infeasible_on_first: false,
}
}
fn infeasible() -> Self {
Self {
objectives: vec![0.0],
call_count: 0,
infeasible_on_first: true,
}
}
}
impl SolverInterface for MockSolver {
fn load_model(&mut self, _t: &StageTemplate) {}
fn add_rows(&mut self, _r: &RowBatch) {}
fn set_row_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
fn set_col_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
fn solve(&mut self) -> Result<cobre_solver::SolutionView<'_>, SolverError> {
let call = self.call_count;
self.call_count += 1;
if self.infeasible_on_first && call == 0 {
return Err(SolverError::Infeasible);
}
let obj = self.objectives[call % self.objectives.len()];
let sol = fixed_solution(obj);
let _ = sol;
Ok(cobre_solver::SolutionView {
objective: obj,
primal: &[0.0, 0.0, 0.0, 0.0],
dual: &[0.0, 0.0],
reduced_costs: &[0.0, 0.0, 0.0, 0.0],
iterations: 0,
solve_time_seconds: 0.0,
})
}
fn reset(&mut self) {
self.call_count = 0;
}
fn get_basis(&mut self, _out: &mut Basis) {}
fn solve_with_basis(
&mut self,
_basis: &Basis,
) -> Result<cobre_solver::SolutionView<'_>, SolverError> {
self.solve()
}
fn statistics(&self) -> SolverStatistics {
SolverStatistics::default()
}
fn name(&self) -> &'static str {
"Mock"
}
}
struct StubComm;
impl Communicator for StubComm {
fn allgatherv<T: CommData>(
&self,
send: &[T],
recv: &mut [T],
_counts: &[usize],
_displs: &[usize],
) -> Result<(), CommError> {
recv[..send.len()].clone_from_slice(send);
Ok(())
}
fn allreduce<T: CommData>(
&self,
send: &[T],
recv: &mut [T],
_op: ReduceOp,
) -> Result<(), CommError> {
recv.clone_from_slice(send);
Ok(())
}
fn broadcast<T: CommData>(&self, _buf: &mut [T], _root: usize) -> Result<(), CommError> {
Ok(())
}
fn barrier(&self) -> Result<(), CommError> {
Ok(())
}
fn rank(&self) -> usize {
0
}
fn size(&self) -> usize {
1
}
}
fn make_opening_tree(n_openings: usize) -> OpeningTree {
use chrono::NaiveDate;
use cobre_core::{
EntityId,
scenario::{CorrelationEntity, CorrelationGroup, CorrelationModel, CorrelationProfile},
temporal::{
Block, BlockMode, NoiseMethod, ScenarioSourceConfig, Stage, StageRiskConfig,
StageStateConfig,
},
};
use cobre_stochastic::correlation::resolve::DecomposedCorrelation;
use std::collections::BTreeMap;
let stage = Stage {
index: 0,
id: 0,
start_date: NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
end_date: NaiveDate::from_ymd_opt(2024, 2, 1).unwrap(),
season_id: Some(0),
blocks: vec![Block {
index: 0,
name: "S".to_string(),
duration_hours: 744.0,
}],
block_mode: BlockMode::Parallel,
state_config: StageStateConfig {
storage: true,
inflow_lags: false,
},
risk_config: StageRiskConfig::Expectation,
scenario_config: ScenarioSourceConfig {
branching_factor: n_openings,
noise_method: NoiseMethod::Saa,
},
};
let entity_id = EntityId(1);
let mut profiles = BTreeMap::new();
profiles.insert(
"default".to_string(),
CorrelationProfile {
groups: vec![CorrelationGroup {
name: "g1".to_string(),
entities: vec![CorrelationEntity {
entity_type: "inflow".to_string(),
id: entity_id,
}],
matrix: vec![vec![1.0]],
}],
},
);
let corr_model = CorrelationModel {
method: "cholesky".to_string(),
profiles,
schedule: vec![],
};
let mut decomposed = DecomposedCorrelation::build(&corr_model).unwrap();
let entity_order = vec![entity_id];
cobre_stochastic::tree::generate::generate_opening_tree(
42,
&[stage],
1,
&mut decomposed,
&entity_order,
)
}
#[allow(clippy::cast_possible_wrap)]
fn make_stochastic_context(n_stages: usize, n_openings: usize) -> StochasticContext {
use cobre_core::entities::hydro::{Hydro, HydroGenerationModel, HydroPenalties};
use cobre_core::scenario::InflowModel;
let bus = Bus {
id: EntityId(0),
name: "B0".to_string(),
deficit_segments: vec![cobre_core::DeficitSegment {
depth_mw: None,
cost_per_mwh: 1000.0,
}],
excess_cost: 0.0,
};
let hydro = Hydro {
id: EntityId(1),
name: "H1".to_string(),
bus_id: EntityId(0),
downstream_id: None,
entry_stage_id: None,
exit_stage_id: None,
min_storage_hm3: 0.0,
max_storage_hm3: 100.0,
min_outflow_m3s: 0.0,
max_outflow_m3s: None,
generation_model: HydroGenerationModel::ConstantProductivity {
productivity_mw_per_m3s: 1.0,
},
min_turbined_m3s: 0.0,
max_turbined_m3s: 100.0,
min_generation_mw: 0.0,
max_generation_mw: 100.0,
tailrace: None,
hydraulic_losses: None,
efficiency: None,
evaporation_coefficients_mm: None,
evaporation_reference_volumes_hm3: None,
diversion: None,
filling: None,
penalties: HydroPenalties {
spillage_cost: 0.0,
diversion_cost: 0.0,
fpha_turbined_cost: 0.0,
storage_violation_below_cost: 0.0,
filling_target_violation_cost: 0.0,
turbined_violation_below_cost: 0.0,
outflow_violation_below_cost: 0.0,
outflow_violation_above_cost: 0.0,
generation_violation_below_cost: 0.0,
evaporation_violation_cost: 0.0,
water_withdrawal_violation_cost: 0.0,
},
};
let make_stage = |idx: usize| Stage {
index: idx,
id: idx as i32,
start_date: NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
end_date: NaiveDate::from_ymd_opt(2024, 2, 1).unwrap(),
season_id: Some(0),
blocks: vec![Block {
index: 0,
name: "S".to_string(),
duration_hours: 744.0,
}],
block_mode: BlockMode::Parallel,
state_config: StageStateConfig {
storage: true,
inflow_lags: false,
},
risk_config: StageRiskConfig::Expectation,
scenario_config: ScenarioSourceConfig {
branching_factor: n_openings,
noise_method: NoiseMethod::Saa,
},
};
let stages: Vec<Stage> = (0..n_stages).map(make_stage).collect();
let inflow_models: Vec<InflowModel> = (0..n_stages)
.map(|i| InflowModel {
hydro_id: EntityId(1),
stage_id: i as i32,
mean_m3s: 100.0,
std_m3s: 30.0,
ar_coefficients: vec![],
residual_std_ratio: 1.0,
})
.collect();
let mut profiles = BTreeMap::new();
profiles.insert(
"default".to_string(),
CorrelationProfile {
groups: vec![CorrelationGroup {
name: "g1".to_string(),
entities: vec![CorrelationEntity {
entity_type: "inflow".to_string(),
id: EntityId(1),
}],
matrix: vec![vec![1.0]],
}],
},
);
let correlation = CorrelationModel {
method: "cholesky".to_string(),
profiles,
schedule: vec![],
};
let system = SystemBuilder::new()
.buses(vec![bus])
.hydros(vec![hydro])
.stages(stages)
.inflow_models(inflow_models)
.correlation(correlation)
.build()
.unwrap();
build_stochastic_context(&system, 42, &[], &[], None).unwrap()
}
fn make_fcf(
n_stages: usize,
n_state: usize,
forward_passes: u32,
max_iter: u64,
) -> FutureCostFunction {
FutureCostFunction::new(n_stages, n_state, forward_passes, max_iter, 0)
}
fn iteration_limit_rules(limit: u64) -> StoppingRuleSet {
StoppingRuleSet {
rules: vec![StoppingRule::IterationLimit { limit }],
mode: StoppingMode::Any,
}
}
#[test]
fn ac_train_completes_with_iteration_limit() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0); let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![2usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 5,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let stage_ctx = StageContext {
templates: &templates,
base_rows: &base_rows,
noise_scale: &[],
n_hydros: 0,
n_load_buses: 0,
load_balance_row_starts: &[],
load_bus_indices: &[],
block_counts_per_stage: &[1usize, 1],
ncs_max_gen: &[],
};
let result = train(
&mut solver,
config,
&mut fcf,
&stage_ctx,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
None,
0.0,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
1,
)
.unwrap();
assert!(result.error.is_none(), "expected no error");
assert_eq!(result.result.iterations, 5, "expected 5 iterations");
assert_eq!(result.result.reason, "iteration_limit");
}
#[test]
fn ac_train_returns_partial_on_infeasible() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![2usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 5,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::infeasible();
let comm = StubComm;
let stage_ctx = StageContext {
templates: &templates,
base_rows: &base_rows,
noise_scale: &[],
n_hydros: 0,
n_load_buses: 0,
load_balance_row_starts: &[],
load_bus_indices: &[],
block_counts_per_stage: &[1usize, 1],
ncs_max_gen: &[],
};
let result = train(
&mut solver,
config,
&mut fcf,
&stage_ctx,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
None,
0.0,
None,
&comm,
1,
|| Ok(MockSolver::infeasible()),
1,
);
let outcome = result.unwrap();
assert!(
outcome.error.is_some(),
"expected error in TrainingOutcome, got: {outcome:?}"
);
assert!(
matches!(outcome.error, Some(SddpError::Infeasible { stage: 0, .. })),
"expected SddpError::Infeasible at stage 0, got: {:?}",
outcome.error
);
assert_eq!(
outcome.result.iterations, 0,
"no iterations should have completed"
);
assert_eq!(outcome.result.reason, "error");
}
#[test]
fn ac_train_emits_correct_event_sequence() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![2usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let (tx, rx) = mpsc::channel::<TrainingEvent>();
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 10,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: Some(tx),
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let stage_ctx = StageContext {
templates: &templates,
base_rows: &base_rows,
noise_scale: &[],
n_hydros: 0,
n_load_buses: 0,
load_balance_row_starts: &[],
load_bus_indices: &[],
block_counts_per_stage: &[1usize, 1],
ncs_max_gen: &[],
};
train(
&mut solver,
config,
&mut fcf,
&stage_ctx,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(2),
None,
0.0,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
1,
)
.unwrap();
drop(fcf); let events: Vec<TrainingEvent> = rx.try_iter().collect();
assert_eq!(
events.len(),
14,
"expected 14 events, got {} ({events:?})",
events.len()
);
assert!(
matches!(events[0], TrainingEvent::TrainingStarted { .. }),
"first event must be TrainingStarted"
);
assert!(
matches!(events.last(), Some(TrainingEvent::TrainingFinished { .. })),
"last event must be TrainingFinished"
);
assert!(matches!(
events[1],
TrainingEvent::ForwardPassComplete { .. }
));
assert!(matches!(
events[2],
TrainingEvent::ForwardSyncComplete { .. }
));
assert!(matches!(
events[3],
TrainingEvent::BackwardPassComplete { .. }
));
assert!(matches!(events[4], TrainingEvent::CutSyncComplete { .. }));
assert!(matches!(events[5], TrainingEvent::ConvergenceUpdate { .. }));
assert!(matches!(events[6], TrainingEvent::IterationSummary { .. }));
assert!(matches!(
events[7],
TrainingEvent::ForwardPassComplete { .. }
));
assert!(matches!(
events[8],
TrainingEvent::ForwardSyncComplete { .. }
));
assert!(matches!(
events[9],
TrainingEvent::BackwardPassComplete { .. }
));
assert!(matches!(events[10], TrainingEvent::CutSyncComplete { .. }));
assert!(matches!(
events[11],
TrainingEvent::ConvergenceUpdate { .. }
));
assert!(matches!(events[12], TrainingEvent::IterationSummary { .. }));
}
#[test]
fn ac_train_result_fields_populated() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![2usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 5,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let stage_ctx = StageContext {
templates: &templates,
base_rows: &base_rows,
noise_scale: &[],
n_hydros: 0,
n_load_buses: 0,
load_balance_row_starts: &[],
load_bus_indices: &[],
block_counts_per_stage: &[1usize, 1],
ncs_max_gen: &[],
};
let result = train(
&mut solver,
config,
&mut fcf,
&stage_ctx,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
None,
0.0,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
1,
)
.unwrap();
assert!(result.error.is_none(), "expected no error");
assert_eq!(result.result.iterations, 5);
assert!(!result.result.reason.is_empty(), "reason must not be empty");
}
#[test]
fn ac_train_with_no_event_sender() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![2usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 2,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let stage_ctx = StageContext {
templates: &templates,
base_rows: &base_rows,
noise_scale: &[],
n_hydros: 0,
n_load_buses: 0,
load_balance_row_starts: &[],
load_bus_indices: &[],
block_counts_per_stage: &[1usize, 1],
ncs_max_gen: &[],
};
let result = train(
&mut solver,
config,
&mut fcf,
&stage_ctx,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(2),
None,
0.0,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
1,
);
assert!(result.is_ok(), "train with no event_sender must not panic");
}
/// Sanity-checks the `total_time_ms` reported by a short training run.
///
/// `total_time_ms` is a `u64`, so "non-negative" holds by construction. The
/// test therefore checks two things instead:
/// - the timer recorded something (`> 0`);
/// - the reported value is bounded by the wall-clock time of the `train` call
///   itself, which catches unit mix-ups (e.g. micro- or nanoseconds reported
///   as milliseconds).
#[test]
fn ac_total_time_ms_is_non_negative() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let opening_tree = make_opening_tree(1);
    let stochastic = make_stochastic_context(n_stages, 1);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let risk_measures = vec![RiskMeasure::Expectation; n_stages];
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let config = TrainingConfig {
        forward_passes: 1,
        max_iterations: 1,
        checkpoint_interval: None,
        warm_start_cuts: 0,
        event_sender: None,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
    };
    // Started before `train` begins its own clock and read after it returns,
    // so it brackets whatever interval `train` measures internally.
    let wall_clock = std::time::Instant::now();
    let result = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
        },
        &opening_tree,
        &risk_measures,
        iteration_limit_rules(1),
        None,
        0.0,
        None,
        &comm,
        1,
        || Ok(MockSolver::with_fixed(100.0)),
        1,
    )
    .unwrap();
    let wall_clock_ms = u64::try_from(wall_clock.elapsed().as_millis()).unwrap_or(u64::MAX);
    assert!(result.error.is_none(), "expected no error");
    // NOTE(review): for a sub-millisecond run this relies on how `train`
    // converts its elapsed time to milliseconds (e.g. clamping to >= 1);
    // confirm, otherwise this assertion is timing-dependent.
    assert!(
        result.result.total_time_ms > 0,
        "total_time_ms must be > 0, got {}",
        result.result.total_time_ms,
    );
    // The +1 slack keeps the bound valid whether `train` truncates, rounds,
    // ceils, or clamps its elapsed time to whole milliseconds.
    assert!(
        result.result.total_time_ms <= wall_clock_ms.saturating_add(1),
        "total_time_ms ({}) must not exceed the wall-clock duration of train ({} ms)",
        result.result.total_time_ms,
        wall_clock_ms,
    );
}
/// With no cut-selection strategy (`None`), training must never emit a
/// `CutSelectionComplete` event.
///
/// The absence check is only meaningful if training actually ran. A run that
/// failed in its first iteration would also emit no selection events. So the
/// test first requires an error-free run that reaches the 5-iteration limit.
#[test]
fn cut_selection_none_skips_step() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let opening_tree = make_opening_tree(1);
    let stochastic = make_stochastic_context(n_stages, 1);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let risk_measures = vec![RiskMeasure::Expectation; n_stages];
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (tx, rx) = mpsc::channel::<TrainingEvent>();
    let config = TrainingConfig {
        forward_passes: 1,
        max_iterations: 10,
        checkpoint_interval: None,
        warm_start_cuts: 0,
        event_sender: Some(tx),
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
    };
    let outcome = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
        },
        &opening_tree,
        &risk_measures,
        iteration_limit_rules(5),
        None,
        0.0,
        None,
        &comm,
        1,
        || Ok(MockSolver::with_fixed(100.0)),
        1,
    )
    .unwrap();
    // Guard against a vacuous pass: an aborted run emits no selection events either.
    assert!(
        outcome.error.is_none(),
        "expected no error, got {:?}",
        outcome.error
    );
    assert_eq!(
        outcome.result.iterations, 5,
        "training must reach the 5-iteration limit"
    );
    let events: Vec<TrainingEvent> = rx.try_iter().collect();
    let cut_sel_count = events
        .iter()
        .filter(|e| matches!(e, TrainingEvent::CutSelectionComplete { .. }))
        .count();
    assert_eq!(
        cut_sel_count, 0,
        "expected no CutSelectionComplete events with cut_selection: None"
    );
}
/// A `Level1` strategy with `check_frequency: 3` must run cut selection
/// exactly once during a 5-iteration run, on iteration 3.
#[test]
fn cut_selection_level1_runs_at_frequency() {
    use crate::cut_selection::CutSelectionStrategy;

    // Selection cadence under test: every 3rd iteration.
    let strategy = CutSelectionStrategy::Level1 {
        threshold: 0,
        check_frequency: 3,
    };

    let stage_count = 2;
    let stage_indexer = StageIndexer::new(1, 0);
    let state_dim = stage_indexer.n_state;
    let templates = vec![minimal_template(state_dim); stage_count];
    let base_rows = vec![2usize; stage_count];
    let initial_state = vec![0.0_f64; state_dim];
    let opening_tree = make_opening_tree(1);
    let stochastic = make_stochastic_context(stage_count, 1);
    let horizon = HorizonMode::Finite {
        num_stages: stage_count,
    };
    let risk_measures = vec![RiskMeasure::Expectation; stage_count];
    let mut fcf = make_fcf(stage_count, state_dim, 1, 10);

    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &stage_indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
    };

    let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
    let config = TrainingConfig {
        forward_passes: 1,
        max_iterations: 10,
        checkpoint_interval: None,
        warm_start_cuts: 0,
        event_sender: Some(event_tx),
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;

    train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &opening_tree,
        &risk_measures,
        iteration_limit_rules(5),
        Some(&strategy),
        0.0,
        None,
        &comm,
        1,
        || Ok(MockSolver::with_fixed(100.0)),
        1,
    )
    .unwrap();

    // Keep only the iteration numbers at which cut selection reported completion.
    let selection_iterations: Vec<_> = event_rx
        .try_iter()
        .filter_map(|event| match event {
            TrainingEvent::CutSelectionComplete { iteration, .. } => Some(iteration),
            _ => None,
        })
        .collect();

    assert_eq!(
        selection_iterations.len(),
        1,
        "expected exactly 1 CutSelectionComplete event for check_frequency=3 over 5 iterations"
    );
    assert_eq!(
        selection_iterations[0], 3,
        "CutSelectionComplete must fire at iteration 3"
    );
}
/// With `Level1 { threshold: 0, check_frequency: 2 }` and a 2-iteration run,
/// the single selection pass (at iteration 2) must deactivate no cuts.
/// Its per-stage report must start with a stage-0 record showing zero
/// deactivations.
#[test]
fn cut_selection_stage0_exempt_preserves_cuts() {
    use crate::cut_selection::CutSelectionStrategy;

    let stage_count = 2;
    let stage_indexer = StageIndexer::new(1, 0);
    let state_dim = stage_indexer.n_state;
    let templates = vec![minimal_template(state_dim); stage_count];
    let base_rows = vec![2usize; stage_count];
    let initial_state = vec![0.0_f64; state_dim];
    let opening_tree = make_opening_tree(1);
    let stochastic = make_stochastic_context(stage_count, 1);
    let horizon = HorizonMode::Finite {
        num_stages: stage_count,
    };
    let risk_measures = vec![RiskMeasure::Expectation; stage_count];
    let mut fcf = make_fcf(stage_count, state_dim, 1, 10);

    let strategy = CutSelectionStrategy::Level1 {
        threshold: 0,
        check_frequency: 2,
    };
    let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
    let config = TrainingConfig {
        forward_passes: 1,
        max_iterations: 10,
        checkpoint_interval: None,
        warm_start_cuts: 0,
        event_sender: Some(event_tx),
    };

    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &stage_indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;

    train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &opening_tree,
        &risk_measures,
        iteration_limit_rules(2),
        Some(&strategy),
        0.0,
        None,
        &comm,
        1,
        || Ok(MockSolver::with_fixed(100.0)),
        1,
    )
    .unwrap();

    // Drain the channel, keeping only the cut-selection completion events.
    let selections: Vec<TrainingEvent> = event_rx
        .try_iter()
        .filter(|event| matches!(event, TrainingEvent::CutSelectionComplete { .. }))
        .collect();
    assert_eq!(
        selections.len(),
        1,
        "expected exactly 1 CutSelectionComplete event at iteration 2"
    );

    let TrainingEvent::CutSelectionComplete {
        iteration,
        cuts_deactivated,
        per_stage,
        ..
    } = &selections[0]
    else {
        panic!("wrong variant");
    };

    assert_eq!(*iteration, 2, "selection must fire at iteration 2");
    assert_eq!(
        *cuts_deactivated, 0,
        "stage 0 is exempt from cut selection, so no cuts should be deactivated"
    );
    assert!(
        !per_stage.is_empty(),
        "per_stage must contain at least the stage 0 record"
    );

    let stage0_record = &per_stage[0];
    assert_eq!(stage0_record.stage, 0, "first record must be stage 0");
    assert_eq!(
        stage0_record.cuts_deactivated, 0,
        "stage 0 must have zero deactivations"
    );
}
/// Passing `None` as the cut-selection strategy must leave training behavior
/// intact. A 3-iteration run finishes without error and stops with reason
/// `"iteration_limit"`.
#[test]
fn existing_train_tests_pass_with_none() {
    let stage_count = 2;
    let stage_indexer = StageIndexer::new(1, 0);
    let state_dim = stage_indexer.n_state;

    let horizon = HorizonMode::Finite {
        num_stages: stage_count,
    };
    let templates = vec![minimal_template(state_dim); stage_count];
    let base_rows = vec![2usize; stage_count];
    let risk_measures = vec![RiskMeasure::Expectation; stage_count];
    let initial_state = vec![0.0_f64; state_dim];
    let stochastic = make_stochastic_context(stage_count, 1);
    let opening_tree = make_opening_tree(1);
    let mut fcf = make_fcf(stage_count, state_dim, 1, 10);

    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &stage_indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
    };
    let config = TrainingConfig {
        forward_passes: 1,
        max_iterations: 3,
        checkpoint_interval: None,
        warm_start_cuts: 0,
        event_sender: None,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;

    let outcome = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &opening_tree,
        &risk_measures,
        iteration_limit_rules(3),
        None,
        0.0,
        None,
        &comm,
        1,
        || Ok(MockSolver::with_fixed(100.0)),
        1,
    )
    .unwrap();

    assert!(outcome.error.is_none(), "expected no error");
    let summary = &outcome.result;
    assert_eq!(summary.iterations, 3);
    assert_eq!(summary.reason, "iteration_limit");
}
/// When every LP solve is infeasible, `train` must still return `Ok` with a
/// partial `TrainingOutcome`. The outcome must:
/// - carry the error;
/// - report zero completed iterations;
/// - report the reason `"error"` and a positive elapsed time;
/// - be accompanied by a `TrainingFinished` event whose reason is `"error"`.
#[test]
fn ac_train_partial_result_on_mid_iteration_failure() {
    let stage_count = 2;
    let stage_indexer = StageIndexer::new(1, 0);
    let state_dim = stage_indexer.n_state;
    let templates = vec![minimal_template(state_dim); stage_count];
    let base_rows = vec![2usize; stage_count];
    let initial_state = vec![0.0_f64; state_dim];
    let opening_tree = make_opening_tree(1);
    let stochastic = make_stochastic_context(stage_count, 1);
    let horizon = HorizonMode::Finite {
        num_stages: stage_count,
    };
    let risk_measures = vec![RiskMeasure::Expectation; stage_count];
    let mut fcf = make_fcf(stage_count, state_dim, 1, 10);

    let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
    let config = TrainingConfig {
        forward_passes: 1,
        max_iterations: 5,
        checkpoint_interval: None,
        warm_start_cuts: 0,
        event_sender: Some(event_tx),
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &stage_indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
    };
    // Both the primary solver and every worker solver report infeasibility.
    let mut solver = MockSolver::infeasible();
    let comm = StubComm;

    let outcome = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &opening_tree,
        &risk_measures,
        iteration_limit_rules(5),
        None,
        0.0,
        None,
        &comm,
        1,
        || Ok(MockSolver::infeasible()),
        1,
    )
    .unwrap();

    assert!(outcome.error.is_some(), "expected error in TrainingOutcome");
    let partial = &outcome.result;
    assert_eq!(
        partial.iterations, 0,
        "no iterations should have completed (failure in iteration 1)"
    );
    assert_eq!(partial.reason, "error");
    assert!(partial.total_time_ms > 0, "total_time_ms must be > 0");

    // The first TrainingFinished event on the channel must exist and carry reason "error".
    let finished = event_rx
        .try_iter()
        .find(|event| matches!(event, TrainingEvent::TrainingFinished { .. }));
    let Some(TrainingEvent::TrainingFinished { reason, .. }) = finished else {
        panic!("TrainingFinished event must be emitted even on error");
    };
    assert_eq!(reason, "error", "TrainingFinished reason must be 'error'");
}
}