use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::Sender;
use std::time::Instant;
use cobre_comm::Communicator;
use cobre_core::TrainingEvent;
use cobre_solver::Basis;
use cobre_solver::SolverInterface;
use cobre_solver::StageTemplate;
use cobre_stochastic::OpeningTree;
use crate::{
SddpError, StoppingRuleSet, TrainingConfig, TrajectoryRecord,
backward::run_backward_pass,
context::{StageContext, TrainingContext},
convergence::ConvergenceMonitor,
cut::fcf::FutureCostFunction,
cut_selection::CutSelectionStrategy,
cut_sync::CutSyncBuffers,
evaluate_lower_bound,
forward::{ForwardPassBatch, run_forward_pass, sync_forward},
lower_bound::LbEvalSpec,
lp_builder::PatchBuffer,
risk_measure::RiskMeasure,
state_exchange::ExchangeBuffers,
stopping_rule::RULE_ITERATION_LIMIT,
workspace::{BasisStore, WorkspacePool},
};
#[derive(Debug, Clone)]
pub struct TrainingResult {
pub final_lb: f64,
pub final_ub: f64,
pub final_gap: f64,
pub iterations: u64,
pub reason: String,
pub total_time_ms: u64,
pub basis_cache: Vec<Option<Basis>>,
}
#[inline]
fn emit(sender: Option<&Sender<TrainingEvent>>, event: TrainingEvent) {
if let Some(s) = sender {
let _ = s.send(event);
}
}
fn collect_local_cuts_for_stage(
fcf: &FutureCostFunction,
stage: usize,
iteration: u64,
) -> Vec<(u32, u32, u32, f64, Vec<f64>)> {
let pool = &fcf.pools[stage];
let mut result = Vec::new();
for slot in 0..pool.populated_count {
if !pool.active[slot] {
continue;
}
let meta = &pool.metadata[slot];
if meta.iteration_generated != iteration {
continue;
}
let intercept = pool.intercepts[slot];
let coefficients = pool.coefficients[slot].clone();
#[allow(clippy::cast_possible_truncation)]
let slot_u32 = slot as u32;
#[allow(clippy::cast_possible_truncation)]
let iter_u32 = iteration as u32;
result.push((
slot_u32,
iter_u32,
meta.forward_pass_index,
intercept,
coefficients,
));
}
result
}
#[allow(
clippy::too_many_arguments,
clippy::too_many_lines,
clippy::similar_names
)]
pub fn train<S: SolverInterface + Send, C: Communicator>(
solver: &mut S,
config: TrainingConfig,
fcf: &mut FutureCostFunction,
templates: &[StageTemplate],
base_rows: &[usize],
training_ctx: &TrainingContext<'_>,
opening_tree: &OpeningTree,
risk_measures: &[RiskMeasure],
stopping_rules: StoppingRuleSet,
cut_selection: Option<&CutSelectionStrategy>,
shutdown_flag: Option<&Arc<AtomicBool>>,
comm: &C,
n_fwd_threads: usize,
solver_factory: impl Fn() -> Result<S, cobre_solver::SolverError>,
noise_scale: &[f64],
n_hydros: usize,
n_load_buses: usize,
max_blocks: usize,
load_balance_row_starts: &[usize],
load_bus_indices: &[usize],
block_counts_per_stage: &[usize],
) -> Result<TrainingResult, SddpError> {
let horizon = training_ctx.horizon;
let indexer = training_ctx.indexer;
let initial_state = training_ctx.initial_state;
let num_stages = horizon.num_stages();
let num_ranks = comm.size();
let my_rank = comm.rank();
let total_forward_passes = config.forward_passes as usize;
let n_state = indexer.n_state;
let base_fwd = total_forward_passes / num_ranks;
let remainder_fwd = total_forward_passes % num_ranks;
let my_actual_fwd = base_fwd + usize::from(my_rank < remainder_fwd);
let my_fwd_offset = base_fwd * my_rank + my_rank.min(remainder_fwd);
let max_local_fwd = base_fwd + usize::from(remainder_fwd > 0);
let empty_record = TrajectoryRecord {
primal: vec![],
dual: vec![],
stage_cost: 0.0,
state: vec![0.0; n_state],
};
let mut records = vec![empty_record; max_local_fwd * num_stages];
let n_threads = n_fwd_threads.max(1);
let mut fwd_pool = WorkspacePool::try_new(
n_threads,
indexer.hydro_count,
indexer.max_par_order,
n_state,
n_load_buses,
max_blocks,
solver_factory,
)
.map_err(SddpError::Solver)?;
let mut basis_store = BasisStore::new(max_local_fwd, num_stages);
let mut patch_buf = PatchBuffer::new(indexer.hydro_count, indexer.max_par_order, 0, 0);
let mut convergence_monitor = ConvergenceMonitor::new(stopping_rules);
let mut exchange_bufs = ExchangeBuffers::new(n_state, max_local_fwd, num_ranks);
let mut cut_sync_bufs =
CutSyncBuffers::with_distribution(n_state, max_local_fwd, num_ranks, total_forward_passes);
let start_time = Instant::now();
let stage_ctx = StageContext {
templates,
base_rows,
noise_scale,
n_hydros,
n_load_buses,
load_balance_row_starts,
load_bus_indices,
block_counts_per_stage,
};
let TrainingConfig {
forward_passes: config_forward_passes,
max_iterations,
event_sender,
..
} = config;
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::TrainingStarted {
case_name: String::new(),
stages: num_stages as u32,
hydros: indexer.hydro_count as u32,
thermals: 0,
ranks: num_ranks as u32,
#[allow(clippy::cast_possible_truncation)]
threads_per_rank: n_threads as u32,
timestamp: String::new(),
},
);
let mut final_lb = 0.0;
let mut final_ub = 0.0;
let mut final_gap = 0.0;
let mut completed_iterations = 0u64;
let mut termination_reason = RULE_ITERATION_LIMIT.to_string();
for iteration in 1..=max_iterations {
if let Some(flag) = shutdown_flag {
if flag.load(Ordering::Relaxed) {
convergence_monitor.set_shutdown();
}
}
let iter_start = Instant::now();
let fwd_record_len = my_actual_fwd * num_stages;
let fwd_batch = ForwardPassBatch {
local_forward_passes: my_actual_fwd,
iteration,
fwd_offset: my_fwd_offset,
};
let forward_result = run_forward_pass(
&mut fwd_pool.workspaces,
&mut basis_store,
&stage_ctx,
fcf,
training_ctx,
&fwd_batch,
&mut records[..fwd_record_len],
)?;
let forward_elapsed_ms = forward_result.elapsed_ms;
let local_n = forward_result.scenario_costs.len();
let local_cost_sum: f64 = forward_result.scenario_costs.iter().sum();
emit(
event_sender.as_ref(),
TrainingEvent::ForwardPassComplete {
iteration,
scenarios: config_forward_passes,
#[allow(clippy::cast_precision_loss)]
ub_mean: if local_n > 0 {
local_cost_sum / local_n as f64
} else {
0.0
},
ub_std: 0.0,
elapsed_ms: forward_elapsed_ms,
},
);
let sync_result = sync_forward(&forward_result, comm, total_forward_passes)?;
emit(
event_sender.as_ref(),
TrainingEvent::ForwardSyncComplete {
iteration,
global_ub_mean: sync_result.global_ub_mean,
global_ub_std: sync_result.global_ub_std,
sync_time_ms: sync_result.sync_time_ms,
},
);
for stage in 0..num_stages {
exchange_bufs.exchange(&records, stage, num_stages, comm)?;
}
let bwd_spec = crate::backward::BackwardPassSpec {
exchange: &exchange_bufs,
iteration,
local_work: my_actual_fwd,
fwd_offset: my_fwd_offset,
risk_measures,
};
let backward_result = run_backward_pass(
&mut fwd_pool.workspaces,
&basis_store,
&stage_ctx,
fcf,
training_ctx,
&bwd_spec,
comm,
)?;
let backward_elapsed_ms = backward_result.elapsed_ms;
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::BackwardPassComplete {
iteration,
cuts_generated: backward_result.cuts_generated as u32,
stages_processed: num_stages.saturating_sub(1) as u32,
elapsed_ms: backward_elapsed_ms,
},
);
for stage in 0..num_stages.saturating_sub(1) {
let owned_cuts = collect_local_cuts_for_stage(fcf, stage, iteration);
let local_cuts: Vec<(u32, u32, u32, f64, &[f64])> = owned_cuts
.iter()
.map(|(slot, iter, fp, intercept, coeffs)| {
(*slot, *iter, *fp, *intercept, coeffs.as_slice())
})
.collect();
cut_sync_bufs.sync_cuts(stage, &local_cuts, fcf, comm)?;
}
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::CutSyncComplete {
iteration,
cuts_distributed: backward_result.cuts_generated as u32,
cuts_active: fcf.total_active_cuts() as u32,
cuts_removed: 0,
sync_time_ms: 0,
},
);
if let Some(strategy) = cut_selection {
if strategy.should_run(iteration) {
let sel_start = Instant::now();
let num_sel_stages = num_stages.saturating_sub(1);
let mut cuts_deactivated = 0u32;
#[allow(clippy::cast_possible_truncation)]
for stage in 0..num_sel_stages {
let stage_u32 = stage as u32;
let deact =
strategy.select_for_stage(&fcf.pools[stage].metadata, iteration, stage_u32);
cuts_deactivated += deact.indices.len() as u32;
fcf.pools[stage].deactivate(&deact.indices);
}
#[allow(clippy::cast_possible_truncation)]
let selection_time_ms = sel_start.elapsed().as_millis() as u64;
#[allow(clippy::cast_possible_truncation)]
let stages_processed = num_sel_stages as u32;
emit(
event_sender.as_ref(),
TrainingEvent::CutSelectionComplete {
iteration,
cuts_deactivated,
stages_processed,
selection_time_ms,
allgatherv_time_ms: 0,
},
);
}
}
let lb_solves_before = solver.statistics().solve_count;
let lb_spec = LbEvalSpec {
template: &templates[0],
base_row: base_rows[0],
noise_scale,
n_hydros,
opening_tree,
risk_measure: &risk_measures[0],
};
let lb = evaluate_lower_bound(
solver,
fcf,
initial_state,
indexer,
&mut patch_buf,
&lb_spec,
comm,
)?;
let lb_lp_solves = solver.statistics().solve_count - lb_solves_before;
let (should_stop, rule_results) = convergence_monitor.update(lb, &sync_result);
final_lb = convergence_monitor.lower_bound();
final_ub = convergence_monitor.upper_bound();
final_gap = convergence_monitor.gap();
emit(
event_sender.as_ref(),
TrainingEvent::ConvergenceUpdate {
iteration,
lower_bound: final_lb,
upper_bound: final_ub,
upper_bound_std: convergence_monitor.upper_bound_std(),
gap: final_gap,
rules_evaluated: rule_results.clone(),
},
);
#[allow(clippy::cast_possible_truncation)]
let wall_time_ms = start_time.elapsed().as_millis() as u64;
#[allow(clippy::cast_possible_truncation)]
let iteration_time_ms = iter_start.elapsed().as_millis() as u64;
emit(
event_sender.as_ref(),
TrainingEvent::IterationSummary {
iteration,
lower_bound: final_lb,
upper_bound: final_ub,
gap: final_gap,
wall_time_ms,
iteration_time_ms,
forward_ms: forward_elapsed_ms,
backward_ms: backward_elapsed_ms,
lp_solves: forward_result.lp_solves + backward_result.lp_solves + lb_lp_solves,
},
);
completed_iterations = iteration;
if should_stop {
termination_reason = rule_results
.iter()
.find(|r| r.triggered)
.map_or_else(|| "unknown".to_string(), |r| r.rule_name.clone());
break;
}
}
#[allow(clippy::cast_possible_truncation)]
let total_time_ms = (start_time.elapsed().as_millis() as u64).max(1);
#[allow(clippy::cast_possible_truncation)]
emit(
event_sender.as_ref(),
TrainingEvent::TrainingFinished {
reason: termination_reason.clone(),
iterations: completed_iterations,
final_lb,
final_ub,
total_time_ms,
total_cuts: fcf.total_active_cuts() as u64,
},
);
let last_scenario = my_actual_fwd.saturating_sub(1);
let basis_cache = (0..num_stages)
.map(|t| basis_store.get(last_scenario, t).cloned())
.collect();
Ok(TrainingResult {
final_lb,
final_ub,
final_gap,
iterations: completed_iterations,
reason: termination_reason,
total_time_ms,
basis_cache,
})
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::float_cmp,
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::too_many_lines
)]
mod tests {
use std::collections::BTreeMap;
use std::sync::mpsc;
use chrono::NaiveDate;
use cobre_comm::{CommData, CommError, Communicator, ReduceOp};
use cobre_core::{
Bus, EntityId, SystemBuilder, TrainingEvent,
scenario::{CorrelationEntity, CorrelationGroup, CorrelationModel, CorrelationProfile},
temporal::{
Block, BlockMode, NoiseMethod, ScenarioSourceConfig, Stage, StageRiskConfig,
StageStateConfig,
},
};
use cobre_solver::{
Basis, LpSolution, RowBatch, SolverError, SolverInterface, SolverStatistics, StageTemplate,
};
use cobre_stochastic::{
StochasticContext, build_stochastic_context, tree::opening_tree::OpeningTree,
};
use super::train;
use crate::{
HorizonMode, InflowNonNegativityMethod, RiskMeasure, SddpError, StageIndexer, StoppingMode,
StoppingRule, StoppingRuleSet, TrainingConfig, context::TrainingContext,
cut::fcf::FutureCostFunction,
};
fn minimal_template(n_state: usize) -> StageTemplate {
let _ = n_state;
StageTemplate {
num_cols: 3,
num_rows: 1,
num_nz: 1,
col_starts: vec![0_i32, 0, 1, 1],
row_indices: vec![0_i32],
values: vec![1.0],
col_lower: vec![0.0, 0.0, 0.0],
col_upper: vec![f64::INFINITY, f64::INFINITY, f64::INFINITY],
objective: vec![0.0, 0.0, 1.0],
row_lower: vec![0.0],
row_upper: vec![0.0],
n_state: 1,
n_transfer: 0,
n_dual_relevant: 1,
n_hydro: 1,
max_par_order: 0,
}
}
fn fixed_solution(objective: f64) -> LpSolution {
LpSolution {
objective,
primal: vec![0.0; 3],
dual: vec![0.0; 1],
reduced_costs: vec![0.0; 3],
iterations: 0,
solve_time_seconds: 0.0,
}
}
struct MockSolver {
objectives: Vec<f64>,
call_count: usize,
infeasible_on_first: bool,
}
impl MockSolver {
fn with_fixed(objective: f64) -> Self {
Self {
objectives: vec![objective],
call_count: 0,
infeasible_on_first: false,
}
}
fn infeasible() -> Self {
Self {
objectives: vec![0.0],
call_count: 0,
infeasible_on_first: true,
}
}
}
impl SolverInterface for MockSolver {
fn load_model(&mut self, _t: &StageTemplate) {}
fn add_rows(&mut self, _r: &RowBatch) {}
fn set_row_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
fn set_col_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
fn solve(&mut self) -> Result<cobre_solver::SolutionView<'_>, SolverError> {
let call = self.call_count;
self.call_count += 1;
if self.infeasible_on_first && call == 0 {
return Err(SolverError::Infeasible);
}
let obj = self.objectives[call % self.objectives.len()];
let sol = fixed_solution(obj);
let _ = sol;
Ok(cobre_solver::SolutionView {
objective: obj,
primal: &[0.0, 0.0, 0.0],
dual: &[0.0],
reduced_costs: &[0.0, 0.0, 0.0],
iterations: 0,
solve_time_seconds: 0.0,
})
}
fn reset(&mut self) {
self.call_count = 0;
}
fn get_basis(&mut self, _out: &mut Basis) {}
fn solve_with_basis(
&mut self,
_basis: &Basis,
) -> Result<cobre_solver::SolutionView<'_>, SolverError> {
self.solve()
}
fn statistics(&self) -> SolverStatistics {
SolverStatistics::default()
}
fn name(&self) -> &'static str {
"Mock"
}
}
struct StubComm;
impl Communicator for StubComm {
fn allgatherv<T: CommData>(
&self,
send: &[T],
recv: &mut [T],
_counts: &[usize],
_displs: &[usize],
) -> Result<(), CommError> {
recv[..send.len()].clone_from_slice(send);
Ok(())
}
fn allreduce<T: CommData>(
&self,
send: &[T],
recv: &mut [T],
_op: ReduceOp,
) -> Result<(), CommError> {
recv.clone_from_slice(send);
Ok(())
}
fn broadcast<T: CommData>(&self, _buf: &mut [T], _root: usize) -> Result<(), CommError> {
Ok(())
}
fn barrier(&self) -> Result<(), CommError> {
Ok(())
}
fn rank(&self) -> usize {
0
}
fn size(&self) -> usize {
1
}
}
fn make_opening_tree(n_openings: usize) -> OpeningTree {
use chrono::NaiveDate;
use cobre_core::{
EntityId,
scenario::{CorrelationEntity, CorrelationGroup, CorrelationModel, CorrelationProfile},
temporal::{
Block, BlockMode, NoiseMethod, ScenarioSourceConfig, Stage, StageRiskConfig,
StageStateConfig,
},
};
use cobre_stochastic::correlation::resolve::DecomposedCorrelation;
use std::collections::BTreeMap;
let stage = Stage {
index: 0,
id: 0,
start_date: NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
end_date: NaiveDate::from_ymd_opt(2024, 2, 1).unwrap(),
season_id: Some(0),
blocks: vec![Block {
index: 0,
name: "S".to_string(),
duration_hours: 744.0,
}],
block_mode: BlockMode::Parallel,
state_config: StageStateConfig {
storage: true,
inflow_lags: false,
},
risk_config: StageRiskConfig::Expectation,
scenario_config: ScenarioSourceConfig {
branching_factor: n_openings,
noise_method: NoiseMethod::Saa,
},
};
let entity_id = EntityId(1);
let mut profiles = BTreeMap::new();
profiles.insert(
"default".to_string(),
CorrelationProfile {
groups: vec![CorrelationGroup {
name: "g1".to_string(),
entities: vec![CorrelationEntity {
entity_type: "inflow".to_string(),
id: entity_id,
}],
matrix: vec![vec![1.0]],
}],
},
);
let corr_model = CorrelationModel {
method: "cholesky".to_string(),
profiles,
schedule: vec![],
};
let mut decomposed = DecomposedCorrelation::build(&corr_model).unwrap();
let entity_order = vec![entity_id];
cobre_stochastic::tree::generate::generate_opening_tree(
42,
&[stage],
1,
&mut decomposed,
&entity_order,
)
}
#[allow(clippy::cast_possible_wrap)]
fn make_stochastic_context(n_stages: usize, n_openings: usize) -> StochasticContext {
use cobre_core::entities::hydro::{Hydro, HydroGenerationModel, HydroPenalties};
use cobre_core::scenario::InflowModel;
let bus = Bus {
id: EntityId(0),
name: "B0".to_string(),
deficit_segments: vec![cobre_core::DeficitSegment {
depth_mw: None,
cost_per_mwh: 1000.0,
}],
excess_cost: 0.0,
};
let hydro = Hydro {
id: EntityId(1),
name: "H1".to_string(),
bus_id: EntityId(0),
downstream_id: None,
entry_stage_id: None,
exit_stage_id: None,
min_storage_hm3: 0.0,
max_storage_hm3: 100.0,
min_outflow_m3s: 0.0,
max_outflow_m3s: None,
generation_model: HydroGenerationModel::ConstantProductivity {
productivity_mw_per_m3s: 1.0,
},
min_turbined_m3s: 0.0,
max_turbined_m3s: 100.0,
min_generation_mw: 0.0,
max_generation_mw: 100.0,
tailrace: None,
hydraulic_losses: None,
efficiency: None,
evaporation_coefficients_mm: None,
diversion: None,
filling: None,
penalties: HydroPenalties {
spillage_cost: 0.0,
diversion_cost: 0.0,
fpha_turbined_cost: 0.0,
storage_violation_below_cost: 0.0,
filling_target_violation_cost: 0.0,
turbined_violation_below_cost: 0.0,
outflow_violation_below_cost: 0.0,
outflow_violation_above_cost: 0.0,
generation_violation_below_cost: 0.0,
evaporation_violation_cost: 0.0,
water_withdrawal_violation_cost: 0.0,
},
};
let make_stage = |idx: usize| Stage {
index: idx,
id: idx as i32,
start_date: NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
end_date: NaiveDate::from_ymd_opt(2024, 2, 1).unwrap(),
season_id: Some(0),
blocks: vec![Block {
index: 0,
name: "S".to_string(),
duration_hours: 744.0,
}],
block_mode: BlockMode::Parallel,
state_config: StageStateConfig {
storage: true,
inflow_lags: false,
},
risk_config: StageRiskConfig::Expectation,
scenario_config: ScenarioSourceConfig {
branching_factor: n_openings,
noise_method: NoiseMethod::Saa,
},
};
let stages: Vec<Stage> = (0..n_stages).map(make_stage).collect();
let inflow_models: Vec<InflowModel> = (0..n_stages)
.map(|i| InflowModel {
hydro_id: EntityId(1),
stage_id: i as i32,
mean_m3s: 100.0,
std_m3s: 30.0,
ar_coefficients: vec![],
residual_std_ratio: 1.0,
})
.collect();
let mut profiles = BTreeMap::new();
profiles.insert(
"default".to_string(),
CorrelationProfile {
groups: vec![CorrelationGroup {
name: "g1".to_string(),
entities: vec![CorrelationEntity {
entity_type: "inflow".to_string(),
id: EntityId(1),
}],
matrix: vec![vec![1.0]],
}],
},
);
let correlation = CorrelationModel {
method: "cholesky".to_string(),
profiles,
schedule: vec![],
};
let system = SystemBuilder::new()
.buses(vec![bus])
.hydros(vec![hydro])
.stages(stages)
.inflow_models(inflow_models)
.correlation(correlation)
.build()
.unwrap();
build_stochastic_context(&system, 42, &[]).unwrap()
}
fn make_fcf(
n_stages: usize,
n_state: usize,
forward_passes: u32,
max_iter: u64,
) -> FutureCostFunction {
FutureCostFunction::new(n_stages, n_state, forward_passes, max_iter, 0)
}
fn iteration_limit_rules(limit: u64) -> StoppingRuleSet {
StoppingRuleSet {
rules: vec![StoppingRule::IterationLimit { limit }],
mode: StoppingMode::Any,
}
}
#[test]
fn ac_train_completes_with_iteration_limit() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0); let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 5,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let result = train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
None,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
assert_eq!(result.iterations, 5, "expected 5 iterations");
assert_eq!(result.reason, "iteration_limit");
}
#[test]
fn ac_train_returns_error_on_infeasible() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 5,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::infeasible();
let comm = StubComm;
let result = train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
None,
None,
&comm,
1,
|| Ok(MockSolver::infeasible()),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
);
assert!(
matches!(result, Err(SddpError::Infeasible { stage: 0, .. })),
"expected SddpError::Infeasible at stage 0, got: {result:?}"
);
}
#[test]
fn ac_train_emits_correct_event_sequence() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let (tx, rx) = mpsc::channel::<TrainingEvent>();
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 10,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: Some(tx),
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(2),
None,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
drop(fcf); let events: Vec<TrainingEvent> = rx.try_iter().collect();
assert_eq!(
events.len(),
14,
"expected 14 events, got {} ({events:?})",
events.len()
);
assert!(
matches!(events[0], TrainingEvent::TrainingStarted { .. }),
"first event must be TrainingStarted"
);
assert!(
matches!(events.last(), Some(TrainingEvent::TrainingFinished { .. })),
"last event must be TrainingFinished"
);
assert!(matches!(
events[1],
TrainingEvent::ForwardPassComplete { .. }
));
assert!(matches!(
events[2],
TrainingEvent::ForwardSyncComplete { .. }
));
assert!(matches!(
events[3],
TrainingEvent::BackwardPassComplete { .. }
));
assert!(matches!(events[4], TrainingEvent::CutSyncComplete { .. }));
assert!(matches!(events[5], TrainingEvent::ConvergenceUpdate { .. }));
assert!(matches!(events[6], TrainingEvent::IterationSummary { .. }));
assert!(matches!(
events[7],
TrainingEvent::ForwardPassComplete { .. }
));
assert!(matches!(
events[8],
TrainingEvent::ForwardSyncComplete { .. }
));
assert!(matches!(
events[9],
TrainingEvent::BackwardPassComplete { .. }
));
assert!(matches!(events[10], TrainingEvent::CutSyncComplete { .. }));
assert!(matches!(
events[11],
TrainingEvent::ConvergenceUpdate { .. }
));
assert!(matches!(events[12], TrainingEvent::IterationSummary { .. }));
}
#[test]
fn ac_train_result_fields_populated() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 5,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let result = train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
None,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
assert_eq!(result.iterations, 5);
assert!(!result.reason.is_empty(), "reason must not be empty");
}
#[test]
fn ac_train_with_no_event_sender() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 2,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let result = train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(2),
None,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
);
assert!(result.is_ok(), "train with no event_sender must not panic");
}
#[test]
fn ac_total_time_ms_is_non_negative() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 1,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let result = train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(1),
None,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
assert!(
result.total_time_ms > 0,
"total_time_ms must be > 0, got {}",
result.total_time_ms,
);
}
#[test]
fn cut_selection_none_skips_step() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let (tx, rx) = mpsc::channel::<TrainingEvent>();
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 10,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: Some(tx),
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
None,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
let events: Vec<TrainingEvent> = rx.try_iter().collect();
let cut_sel_count = events
.iter()
.filter(|e| matches!(e, TrainingEvent::CutSelectionComplete { .. }))
.count();
assert_eq!(
cut_sel_count, 0,
"expected no CutSelectionComplete events with cut_selection: None"
);
}
#[test]
fn cut_selection_level1_runs_at_frequency() {
use crate::cut_selection::CutSelectionStrategy;
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let (tx, rx) = mpsc::channel::<TrainingEvent>();
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 10,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: Some(tx),
};
let strategy = CutSelectionStrategy::Level1 {
threshold: 0,
check_frequency: 3,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(5),
Some(&strategy),
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
let events: Vec<TrainingEvent> = rx.try_iter().collect();
let sel_events: Vec<&TrainingEvent> = events
.iter()
.filter(|e| matches!(e, TrainingEvent::CutSelectionComplete { .. }))
.collect();
assert_eq!(
sel_events.len(),
1,
"expected exactly 1 CutSelectionComplete event for check_frequency=3 over 5 iterations"
);
let TrainingEvent::CutSelectionComplete { iteration, .. } = sel_events[0] else {
panic!("wrong variant");
};
assert_eq!(
*iteration, 3,
"CutSelectionComplete must fire at iteration 3"
);
}
#[test]
fn cut_selection_deactivates_inactive_cuts() {
use crate::cut_selection::CutSelectionStrategy;
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let (tx, rx) = mpsc::channel::<TrainingEvent>();
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 10,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: Some(tx),
};
let strategy = CutSelectionStrategy::Level1 {
threshold: 0,
check_frequency: 2,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(2),
Some(&strategy),
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
let events: Vec<TrainingEvent> = rx.try_iter().collect();
let sel_events: Vec<&TrainingEvent> = events
.iter()
.filter(|e| matches!(e, TrainingEvent::CutSelectionComplete { .. }))
.collect();
assert_eq!(
sel_events.len(),
1,
"expected exactly 1 CutSelectionComplete event at iteration 2"
);
let TrainingEvent::CutSelectionComplete {
iteration,
cuts_deactivated,
..
} = sel_events[0]
else {
panic!("wrong variant");
};
assert_eq!(*iteration, 2, "selection must fire at iteration 2");
assert!(
*cuts_deactivated > 0,
"expected cuts_deactivated > 0 (mock solver produces zero-activity cuts), got 0"
);
}
#[test]
fn existing_train_tests_pass_with_none() {
let n_stages = 2;
let indexer = StageIndexer::new(1, 0);
let templates = vec![minimal_template(indexer.n_state); n_stages];
let base_rows = vec![1usize; n_stages];
let initial_state = vec![0.0_f64; indexer.n_state];
let opening_tree = make_opening_tree(1);
let stochastic = make_stochastic_context(n_stages, 1);
let horizon = HorizonMode::Finite {
num_stages: n_stages,
};
let risk_measures = vec![RiskMeasure::Expectation; n_stages];
let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
let config = TrainingConfig {
forward_passes: 1,
max_iterations: 3,
checkpoint_interval: None,
warm_start_cuts: 0,
event_sender: None,
};
let mut solver = MockSolver::with_fixed(100.0);
let comm = StubComm;
let result = train(
&mut solver,
config,
&mut fcf,
&templates,
&base_rows,
&TrainingContext {
horizon: &horizon,
indexer: &indexer,
inflow_method: &InflowNonNegativityMethod::None,
stochastic: &stochastic,
initial_state: &initial_state,
},
&opening_tree,
&risk_measures,
iteration_limit_rules(3),
None,
None,
&comm,
1,
|| Ok(MockSolver::with_fixed(100.0)),
&[],
0,
0,
1,
&[],
&[],
&[1usize, 1],
)
.unwrap();
assert_eq!(result.iterations, 3);
assert_eq!(result.reason, "iteration_limit");
}
}