use cobre_comm::Communicator;
use cobre_solver::{SolverInterface, StageTemplate};
use crate::{
SddpError, TrainingConfig,
context::{StageContext, TrainingContext},
cut::fcf::FutureCostFunction,
solver_stats::SolverStatsEntry,
training_session::{IterationOutcome, TrainingSession},
workspace::CapturedBasis,
};
/// Value returned by [`train`] once the training loop has ended.
///
/// `result` is always present. `error` carries the failure that ended the
/// loop early, if any; in that case `result` holds whatever had been
/// accumulated so far (the infeasible-solver test in this file observes
/// `result.reason == "error"` and `result.iterations == 0` alongside
/// `error == Some(SddpError::Infeasible { .. })`).
#[derive(Debug)]
pub struct TrainingOutcome {
    /// Summary of the run; possibly partial when `error` is `Some`.
    pub result: TrainingResult,
    /// Error that stopped training before a stopping rule fired, if any.
    pub error: Option<SddpError>,
}
/// Summary statistics and artifacts produced by a training run.
///
/// The type is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal and must go through [`TrainingResult::new`].
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct TrainingResult {
    /// Final lower bound (presumably the SDDP lower bound — confirm).
    pub final_lb: f64,
    /// Final upper-bound estimate (presumably statistical — confirm).
    pub final_ub: f64,
    /// Standard deviation attached to `final_ub` (presumably of the sample
    /// estimate — confirm).
    pub final_ub_std: f64,
    /// Final gap between the bounds; its normalization is not visible here —
    /// TODO confirm.
    pub final_gap: f64,
    /// Number of completed iterations (the tests observe 5 for a limit of 5,
    /// and 0 when the very first solve fails).
    pub iterations: u64,
    /// Why training stopped; the tests observe `"iteration_limit"` and
    /// `"error"`.
    pub reason: String,
    /// Total training time in milliseconds.
    pub total_time_ms: u64,
    /// Cached bases; presumably one entry per stage, mirroring the output of
    /// `broadcast_basis_cache` — confirm.
    pub basis_cache: Vec<Option<CapturedBasis>>,
    /// Solver statistics entries recorded during training.
    pub solver_stats_log: Vec<SolverStatsEntry>,
    /// Archive of visited states, when one was collected (NOTE(review):
    /// possibly tied to `EventConfig::export_states` — confirm).
    pub visited_archive: Option<crate::visited_states::VisitedStatesArchive>,
    /// Stage templates produced by the run, when available (semantics of
    /// "baked" not visible here — confirm).
    pub baked_templates: Option<Vec<StageTemplate>>,
}
impl TrainingResult {
#[must_use]
#[allow(clippy::too_many_arguments, clippy::similar_names)]
pub fn new(
final_lb: f64,
final_ub: f64,
final_ub_std: f64,
final_gap: f64,
iterations: u64,
reason: String,
total_time_ms: u64,
basis_cache: Vec<Option<CapturedBasis>>,
solver_stats_log: Vec<SolverStatsEntry>,
visited_archive: Option<crate::visited_states::VisitedStatesArchive>,
baked_templates: Option<Vec<StageTemplate>>,
) -> Self {
Self {
final_lb,
final_ub,
final_ub_std,
final_gap,
iterations,
reason,
total_time_ms,
basis_cache,
solver_stats_log,
visited_archive,
baked_templates,
}
}
}
/// Narrows a buffer length to the `i32` used by the length broadcasts.
///
/// `operation` identifies the call site and is copied into the error.
///
/// # Errors
///
/// Returns [`SddpError::Communication`] wrapping
/// `CommError::InvalidBufferSize` (with `expected` set to `i32::MAX`) when
/// `len` does not fit in an `i32`.
fn checked_broadcast_len(len: usize, operation: &'static str) -> Result<i32, SddpError> {
    let Ok(narrowed) = i32::try_from(len) else {
        return Err(SddpError::Communication(
            cobre_comm::CommError::InvalidBufferSize {
                operation,
                expected: i32::MAX as usize,
                actual: len,
            },
        ));
    };
    Ok(narrowed)
}
/// Makes rank 0's per-stage basis cache available on every rank.
///
/// Returns one entry per stage `0..num_stages`, taken from
/// `basis_store.get(0, t)` on rank 0.
///
/// - With a single rank, the entries are cloned directly.
/// - Otherwise, rank 0 serializes every stage into an `i32` payload and an
///   `f64` payload. Each payload is broadcast from rank 0, preceded by its
///   length, and every rank decodes them with
///   [`CapturedBasis::try_from_broadcast_payload`].
///
/// The multi-rank path issues `comm.broadcast` calls rooted at rank 0.
/// Assuming MPI-style collective semantics, every rank must reach each of
/// them.
///
/// # Errors
///
/// - [`SddpError::Communication`] on rank 0 when a payload is longer than
///   `i32::MAX` elements. The root still takes part in the length broadcast
///   (sending `-1`), so the other ranks return [`SddpError::Validation`]
///   instead of blocking in a collective the root never joins.
/// - Any error returned by `comm.broadcast`.
/// - Any error returned while decoding a stage's payload.
pub(crate) fn broadcast_basis_cache<C: Communicator>(
    basis_store: &crate::workspace::BasisStore,
    num_stages: usize,
    comm: &C,
) -> Result<Vec<Option<CapturedBasis>>, SddpError> {
    if comm.size() == 1 {
        let cache = (0..num_stages)
            .map(|t| basis_store.get(0, t).cloned())
            .collect();
        return Ok(cache);
    }
    let mut buf: Vec<i32> = Vec::new();
    let mut f64_buf: Vec<f64> = Vec::new();
    if comm.rank() == 0 {
        for t in 0..num_stages {
            match basis_store.get(0, t) {
                // A single `0` is written for a stage without a captured basis.
                None => buf.push(0_i32),
                Some(captured) => captured.to_broadcast_payload(&mut buf, &mut f64_buf),
            }
        }
    }
    // Non-root buffers are empty here; they are sized from the root's length
    // before the payload broadcast.
    let total_len = broadcast_root_len(comm, buf.len(), "broadcast_basis_cache_i32")?;
    buf.resize(total_len, 0_i32);
    comm.broadcast(&mut buf, 0).map_err(SddpError::from)?;
    let f64_total_len = broadcast_root_len(comm, f64_buf.len(), "broadcast_basis_cache_f64")?;
    f64_buf.resize(f64_total_len, 0.0_f64);
    comm.broadcast(&mut f64_buf, 0).map_err(SddpError::from)?;
    let mut cache: Vec<Option<CapturedBasis>> = Vec::with_capacity(num_stages);
    let mut pos = 0_usize;
    let mut f64_pos = 0_usize;
    for stage in 0..num_stages {
        let captured = CapturedBasis::try_from_broadcast_payload(
            stage,
            &buf,
            &mut pos,
            &f64_buf,
            &mut f64_pos,
        )?;
        cache.push(captured);
    }
    Ok(cache)
}

/// Broadcasts a payload length from rank 0 and returns the root's value on
/// every rank.
///
/// Every rank participates in the length broadcast, even when the root's
/// length does not fit in an `i32`. In that case the root sends `-1`:
/// non-root ranks then fail with a "received negative length" validation
/// error, and the root returns the original overflow error. No rank goes on
/// to wait in a later broadcast that the others skip.
fn broadcast_root_len<C: Communicator>(
    comm: &C,
    local_len: usize,
    operation: &'static str,
) -> Result<usize, SddpError> {
    let checked = checked_broadcast_len(local_len, operation);
    let mut len_buf = [checked.as_ref().map_or(-1, |&len| len)];
    comm.broadcast(&mut len_buf, 0).map_err(SddpError::from)?;
    if let Err(err) = checked {
        return Err(err);
    }
    usize::try_from(len_buf[0]).map_err(|_| {
        SddpError::Validation(format!(
            "{operation}: received negative length {}",
            len_buf[0]
        ))
    })
}
/// Runs the SDDP training loop and returns its outcome.
///
/// The loop works as follows:
///
/// 1. A [`TrainingSession`] is built from the arguments.
/// 2. `session.run_iteration` is called for every iteration yielded by
///    `session.iteration_range()`.
/// 3. The loop stops early when an iteration reports
///    [`IterationOutcome::Converged`] or [`IterationOutcome::Shutdown`].
/// 4. The session is then finalized.
///
/// `solver_factory` is passed to the session unchanged (presumably to build
/// additional solvers for worker threads — confirm).
///
/// # Errors
///
/// Returns `Err` when the session cannot be constructed or when
/// finalization fails. An error from an individual iteration is handed to
/// `TrainingSession::finalize_with_error`. In this module's tests, that
/// yields `Ok(TrainingOutcome)` with the failure stored in `error`.
pub fn train<S: SolverInterface + Send, C: Communicator>(
    solver: &mut S,
    config: TrainingConfig,
    fcf: &mut FutureCostFunction,
    stage_ctx: &StageContext<'_>,
    training_ctx: &TrainingContext<'_>,
    comm: &C,
    solver_factory: impl Fn() -> Result<S, cobre_solver::SolverError>,
) -> Result<TrainingOutcome, SddpError> {
    let mut session = TrainingSession::new(
        solver,
        config,
        fcf,
        stage_ctx,
        training_ctx,
        comm,
        solver_factory,
    )?;
    for iteration in session.iteration_range() {
        let step = match session.run_iteration(iteration) {
            Ok(step) => step,
            // Let the session package the partial result together with the error.
            Err(failure) => return session.finalize_with_error(failure),
        };
        match step {
            IterationOutcome::Continue => {}
            IterationOutcome::Converged | IterationOutcome::Shutdown => break,
        }
    }
    session.finalize()
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
clippy::panic,
clippy::float_cmp,
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_possible_wrap,
clippy::too_many_lines,
clippy::doc_markdown,
clippy::needless_range_loop
)]
mod tests {
use std::collections::BTreeMap;
use std::sync::mpsc;
use chrono::NaiveDate;
use cobre_comm::{CommData, CommError, Communicator, ReduceOp};
use cobre_core::{
Bus, EntityId, SystemBuilder, TrainingEvent, WorkerTimingPhase,
scenario::{
CorrelationEntity, CorrelationGroup, CorrelationModel, CorrelationProfile,
SamplingScheme,
},
temporal::{
Block, BlockMode, NoiseMethod, ScenarioSourceConfig, Stage, StageRiskConfig,
StageStateConfig,
},
};
use cobre_solver::{
Basis, LpSolution, RowBatch, SolverError, SolverInterface, SolverStatistics, StageTemplate,
};
use cobre_stochastic::{
ClassSchemes, OpeningTreeInputs, StochasticContext, build_stochastic_context,
};
use super::train;
use crate::{
StoppingMode, StoppingRule, StoppingRuleSet, TrainingConfig,
config::{CutManagementConfig, EventConfig, LoopConfig},
context::{StageContext, TrainingContext},
cut::fcf::FutureCostFunction,
error::SddpError,
horizon_mode::HorizonMode,
indexer::StageIndexer,
inflow_method::InflowNonNegativityMethod,
risk_measure::RiskMeasure,
solver_stats::SolverStatsDelta,
};
/// Returns a tiny, fixed LP template used for every stage in these tests.
///
/// The shape never changes:
/// - 4 columns and 2 rows;
/// - a single non-zero coefficient `1.0`, in column 2 / row 0 when
///   `col_starts` is read as CSC column pointers.
///
/// The template always declares `n_state: 1`. The `n_state` argument is
/// ignored (NOTE(review): it only agrees with the template when the caller's
/// indexer has one state variable).
fn minimal_template(n_state: usize) -> StageTemplate {
    // Deliberately unused; see the doc comment.
    let _ = n_state;
    StageTemplate {
        // Dimensions.
        num_cols: 4,
        num_rows: 2,
        num_nz: 1,
        // Constraint matrix.
        col_starts: vec![0_i32, 0, 0, 1, 1],
        row_indices: vec![0_i32],
        values: vec![1.0],
        // Column data.
        col_lower: vec![0.0, f64::NEG_INFINITY, 0.0, 0.0],
        col_upper: vec![f64::INFINITY; 4],
        objective: vec![0.0, 0.0, 0.0, 1.0],
        // Row data: both rows are equalities at zero.
        row_lower: vec![0.0; 2],
        row_upper: vec![0.0; 2],
        // Structural metadata.
        n_state: 1,
        n_transfer: 0,
        n_dual_relevant: 1,
        n_hydro: 1,
        max_par_order: 0,
        // No scaling.
        col_scale: vec![],
        row_scale: vec![],
    }
}
/// Builds an [`LpSolution`] with the given objective, zero iterations and
/// time, and all-zero vectors sized like `minimal_template`
/// (4 columns, 2 rows).
fn fixed_solution(objective: f64) -> LpSolution {
    let zero_cols = vec![0.0; 4];
    LpSolution {
        objective,
        primal: zero_cols.clone(),
        dual: vec![0.0; 2],
        reduced_costs: zero_cols,
        iterations: 0,
        solve_time_seconds: 0.0,
    }
}
/// Scripted solver double.
///
/// It cycles through `objectives` on successive solves and can be told to
/// report infeasibility on its very first solve.
struct MockSolver {
    /// Objective values returned round-robin by `solve`.
    objectives: Vec<f64>,
    /// Number of `solve` calls made so far.
    call_count: usize,
    /// When `true`, the first `solve` returns `SolverError::Infeasible`.
    infeasible_on_first: bool,
}
impl MockSolver {
    /// Shared constructor behind the named variants below.
    fn scripted(objectives: Vec<f64>, infeasible_on_first: bool) -> Self {
        Self {
            objectives,
            call_count: 0,
            infeasible_on_first,
        }
    }
    /// A solver whose every solve reports `objective`.
    fn with_fixed(objective: f64) -> Self {
        Self::scripted(vec![objective], false)
    }
    /// A solver whose first solve is infeasible; later solves report `0.0`.
    fn infeasible() -> Self {
        Self::scripted(vec![0.0], true)
    }
}
/// Minimal `SolverInterface` implementation.
///
/// Model-editing calls are ignored, and `solve` replays the scripted
/// objective with all-zero primal/dual/reduced-cost slices.
impl SolverInterface for MockSolver {
    fn solver_name_version(&self) -> String {
        String::from("MockSolver 0.0.0")
    }
    fn load_model(&mut self, _t: &StageTemplate) {}
    fn add_rows(&mut self, _r: &RowBatch) {}
    fn set_row_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
    fn set_col_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
    fn solve(
        &mut self,
        _basis: Option<&Basis>,
    ) -> Result<cobre_solver::SolutionView<'_>, SolverError> {
        let this_call = self.call_count;
        self.call_count = this_call + 1;
        let fail_now = self.infeasible_on_first && this_call == 0;
        if fail_now {
            return Err(SolverError::Infeasible);
        }
        let objective = self.objectives[this_call % self.objectives.len()];
        // Built and discarded. The returned view borrows constant zero
        // slices instead, since it cannot borrow from this local value.
        let _ = fixed_solution(objective);
        Ok(cobre_solver::SolutionView {
            objective,
            primal: &[0.0; 4],
            dual: &[0.0; 2],
            reduced_costs: &[0.0; 4],
            iterations: 0,
            solve_time_seconds: 0.0,
        })
    }
    fn get_basis(&mut self, _out: &mut Basis) {}
    fn statistics(&self) -> SolverStatistics {
        SolverStatistics::default()
    }
    fn name(&self) -> &'static str {
        "Mock"
    }
}
/// Single-process communicator stub: rank 0 of a world of size 1.
///
/// The gather and reduce collectives copy the local contribution into
/// `recv`. `broadcast` and `barrier` do nothing.
struct StubComm;
impl Communicator for StubComm {
    fn allgatherv<T: CommData>(
        &self,
        send: &[T],
        recv: &mut [T],
        _counts: &[usize],
        _displs: &[usize],
    ) -> Result<(), CommError> {
        // With one rank, the gathered result is just this rank's contribution.
        let contributed = send.len();
        recv[..contributed].clone_from_slice(send);
        Ok(())
    }
    fn allreduce<T: CommData>(
        &self,
        send: &[T],
        recv: &mut [T],
        _op: ReduceOp,
    ) -> Result<(), CommError> {
        // A single contribution is copied through regardless of `_op`.
        recv.clone_from_slice(send);
        Ok(())
    }
    fn broadcast<T: CommData>(&self, _buf: &mut [T], _root: usize) -> Result<(), CommError> {
        Ok(())
    }
    fn barrier(&self) -> Result<(), CommError> {
        Ok(())
    }
    fn rank(&self) -> usize {
        0
    }
    fn size(&self) -> usize {
        1
    }
    fn abort(&self, error_code: i32) -> ! {
        std::process::exit(error_code)
    }
}
/// Builds a `StochasticContext` for a one-bus, one-hydro toy system.
///
/// The system has `n_stages` identical stages:
/// - one 744-hour block, dated 2024-01-01 to 2024-02-01;
/// - SAA noise with a branching factor of `n_openings`.
///
/// Hydro `EntityId(1)` has, in every stage, an inflow model with mean 100,
/// standard deviation 30 and no AR terms. It sits alone in a single
/// correlation group with a 1x1 identity matrix.
///
/// All sampling classes use `SamplingScheme::InSample`. Panics (via
/// `unwrap`) if the system or the context cannot be built.
#[allow(clippy::cast_possible_wrap)]
fn make_stochastic_context(n_stages: usize, n_openings: usize) -> StochasticContext {
    use cobre_core::entities::hydro::{Hydro, HydroGenerationModel, HydroPenalties};
    use cobre_core::scenario::InflowModel;
    let bus = Bus {
        id: EntityId(0),
        name: "B0".to_string(),
        deficit_segments: vec![cobre_core::DeficitSegment {
            depth_mw: None,
            cost_per_mwh: 1000.0,
        }],
        excess_cost: 0.0,
    };
    // Single reservoir on bus 0; every penalty is zero except inflow
    // non-negativity.
    let hydro = Hydro {
        id: EntityId(1),
        name: "H1".to_string(),
        bus_id: EntityId(0),
        downstream_id: None,
        entry_stage_id: None,
        exit_stage_id: None,
        min_storage_hm3: 0.0,
        max_storage_hm3: 100.0,
        min_outflow_m3s: 0.0,
        max_outflow_m3s: None,
        generation_model: HydroGenerationModel::ConstantProductivity,
        min_turbined_m3s: 0.0,
        max_turbined_m3s: 100.0,
        specific_productivity_mw_per_m3s_per_m: None,
        min_generation_mw: 0.0,
        max_generation_mw: 100.0,
        tailrace: None,
        hydraulic_losses: None,
        efficiency: None,
        evaporation_coefficients_mm: None,
        evaporation_reference_volumes_hm3: None,
        diversion: None,
        filling: None,
        penalties: HydroPenalties {
            spillage_cost: 0.0,
            diversion_cost: 0.0,
            turbined_cost: 0.0,
            storage_violation_below_cost: 0.0,
            filling_target_violation_cost: 0.0,
            turbined_violation_below_cost: 0.0,
            outflow_violation_below_cost: 0.0,
            outflow_violation_above_cost: 0.0,
            generation_violation_below_cost: 0.0,
            evaporation_violation_cost: 0.0,
            water_withdrawal_violation_cost: 0.0,
            water_withdrawal_violation_pos_cost: 0.0,
            water_withdrawal_violation_neg_cost: 0.0,
            evaporation_violation_pos_cost: 0.0,
            evaporation_violation_neg_cost: 0.0,
            inflow_nonnegativity_cost: 1000.0,
        },
    };
    // Same layout as `make_stages`, except the branching factor comes from
    // `n_openings`.
    let make_stage = |idx: usize| Stage {
        index: idx,
        id: idx as i32,
        start_date: NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
        end_date: NaiveDate::from_ymd_opt(2024, 2, 1).unwrap(),
        season_id: Some(0),
        blocks: vec![Block {
            index: 0,
            name: "S".to_string(),
            duration_hours: 744.0,
        }],
        block_mode: BlockMode::Parallel,
        state_config: StageStateConfig {
            storage: true,
            inflow_lags: false,
        },
        risk_config: StageRiskConfig::Expectation,
        scenario_config: ScenarioSourceConfig {
            branching_factor: n_openings,
            noise_method: NoiseMethod::Saa,
        },
    };
    let stages: Vec<Stage> = (0..n_stages).map(make_stage).collect();
    // One inflow model per stage for hydro 1, with no autoregressive terms.
    let inflow_models: Vec<InflowModel> = (0..n_stages)
        .map(|i| InflowModel {
            hydro_id: EntityId(1),
            stage_id: i as i32,
            mean_m3s: 100.0,
            std_m3s: 30.0,
            ar_coefficients: vec![],
            residual_std_ratio: 1.0,
            annual: None,
        })
        .collect();
    let mut profiles = BTreeMap::new();
    profiles.insert(
        "default".to_string(),
        CorrelationProfile {
            groups: vec![CorrelationGroup {
                name: "g1".to_string(),
                entities: vec![CorrelationEntity {
                    entity_type: "inflow".to_string(),
                    id: EntityId(1),
                }],
                matrix: vec![vec![1.0]],
            }],
        },
    );
    let correlation = CorrelationModel {
        method: "spectral".to_string(),
        profiles,
        schedule: vec![],
    };
    let system = SystemBuilder::new()
        .buses(vec![bus])
        .hydros(vec![hydro])
        .stages(stages)
        .inflow_models(inflow_models)
        .correlation(correlation)
        .build()
        .unwrap();
    // NOTE(review): `42` is presumably the RNG seed — confirm against
    // `build_stochastic_context`'s signature.
    build_stochastic_context(
        &system,
        42,
        None,
        &[],
        &[],
        OpeningTreeInputs::default(),
        ClassSchemes {
            inflow: Some(SamplingScheme::InSample),
            load: Some(SamplingScheme::InSample),
            ncs: Some(SamplingScheme::InSample),
        },
    )
    .unwrap()
}
fn make_stages(n_stages: usize) -> Vec<Stage> {
(0..n_stages)
.map(|i| Stage {
index: i,
id: i as i32,
start_date: chrono::NaiveDate::from_ymd_opt(2024, 1, 1).unwrap(),
end_date: chrono::NaiveDate::from_ymd_opt(2024, 2, 1).unwrap(),
season_id: Some(0),
blocks: vec![Block {
index: 0,
name: "S".to_string(),
duration_hours: 744.0,
}],
block_mode: BlockMode::Parallel,
state_config: cobre_core::temporal::StageStateConfig {
storage: true,
inflow_lags: false,
},
risk_config: cobre_core::temporal::StageRiskConfig::Expectation,
scenario_config: ScenarioSourceConfig {
branching_factor: 1,
noise_method: NoiseMethod::Saa,
},
})
.collect()
}
/// Creates an empty future cost function for these tests.
///
/// `n_stages`, `n_state`, `forward_passes` and `max_iter` are forwarded
/// as-is. The last constructor argument is a per-stage vector of zeros
/// (meaning not visible here — confirm).
fn make_fcf(
    n_stages: usize,
    n_state: usize,
    forward_passes: u32,
    max_iter: u64,
) -> FutureCostFunction {
    let per_stage_zeros = vec![0; n_stages];
    FutureCostFunction::new(n_stages, n_state, forward_passes, max_iter, &per_stage_zeros)
}
/// Stopping rules consisting solely of an iteration limit of `limit`.
fn iteration_limit_rules(limit: u64) -> StoppingRuleSet {
    let only_rule = StoppingRule::IterationLimit { limit };
    StoppingRuleSet {
        mode: StoppingMode::Any,
        rules: vec![only_rule],
    }
}
/// With `max_iterations: 5` and an iteration-limit rule of 5, training must
/// finish without an error, report 5 iterations and give
/// `"iteration_limit"` as the stopping reason.
#[test]
fn ac_train_completes_with_iteration_limit() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 5,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    // Every solve reports objective 100.0, and the factory builds identical
    // solvers.
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let result = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    assert!(result.error.is_none(), "expected no error");
    assert_eq!(result.result.iterations, 5, "expected 5 iterations");
    assert_eq!(result.result.reason, "iteration_limit");
}
/// When the very first solve is infeasible, `train` must still return
/// `Ok(TrainingOutcome)`. The outcome carries `SddpError::Infeasible` at
/// stage 0, zero completed iterations and the reason `"error"`.
#[test]
fn ac_train_returns_partial_on_infeasible() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 5,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    // Both the primary solver and every factory-built solver fail their
    // first solve.
    let mut solver = MockSolver::infeasible();
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let result = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::infeasible()),
    );
    // The iteration error is reported inside the outcome, not as `Err`.
    let outcome = result.unwrap();
    assert!(
        outcome.error.is_some(),
        "expected error in TrainingOutcome, got: {outcome:?}"
    );
    assert!(
        matches!(outcome.error, Some(SddpError::Infeasible { stage: 0, .. })),
        "expected SddpError::Infeasible at stage 0, got: {:?}",
        outcome.error
    );
    assert_eq!(
        outcome.result.iterations, 0,
        "no iterations should have completed"
    );
    assert_eq!(outcome.result.reason, "error");
}
/// Checks the exact event stream for two iterations (iteration limit 2).
///
/// Expected stream:
/// - `TrainingStarted`;
/// - per iteration, 9 events: Forward `WorkerTiming`,
///   `ForwardPassComplete`, `ForwardSyncComplete`, Backward `WorkerTiming`,
///   `BackwardPassComplete`, `PolicySyncComplete`,
///   `PolicyTemplateBakeComplete`, `ConvergenceUpdate`, `IterationSummary`;
/// - `TrainingFinished`.
///
/// That is 1 + 2 * 9 + 1 = 20 events in total.
#[test]
fn ac_train_emits_correct_event_sequence() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (tx, rx) = mpsc::channel::<TrainingEvent>();
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 10,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            // The stopping rule (2), not `max_iterations` (10), ends the run.
            stopping_rules: iteration_limit_rules(2),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: Some(tx),
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    // NOTE(review): the purpose of this explicit drop is not visible here;
    // the `&mut fcf` borrow passed to `train` has already ended.
    drop(fcf);
    // `train` has returned, so every event it sent is already queued.
    let events: Vec<TrainingEvent> = rx.try_iter().collect();
    assert_eq!(
        events.len(),
        20,
        "expected 20 events, got {} ({events:?})",
        events.len()
    );
    assert!(
        matches!(events[0], TrainingEvent::TrainingStarted { .. }),
        "first event must be TrainingStarted"
    );
    assert!(
        matches!(events.last(), Some(TrainingEvent::TrainingFinished { .. })),
        "last event must be TrainingFinished"
    );
    // Iteration 1: events[1..=9].
    assert!(matches!(
        events[1],
        TrainingEvent::WorkerTiming {
            phase: WorkerTimingPhase::Forward,
            ..
        }
    ));
    assert!(matches!(
        events[2],
        TrainingEvent::ForwardPassComplete { .. }
    ));
    assert!(matches!(
        events[3],
        TrainingEvent::ForwardSyncComplete { .. }
    ));
    assert!(matches!(
        events[4],
        TrainingEvent::WorkerTiming {
            phase: WorkerTimingPhase::Backward,
            ..
        }
    ));
    assert!(matches!(
        events[5],
        TrainingEvent::BackwardPassComplete { .. }
    ));
    assert!(matches!(
        events[6],
        TrainingEvent::PolicySyncComplete { .. }
    ));
    assert!(matches!(
        events[7],
        TrainingEvent::PolicyTemplateBakeComplete { .. }
    ));
    assert!(matches!(events[8], TrainingEvent::ConvergenceUpdate { .. }));
    assert!(matches!(events[9], TrainingEvent::IterationSummary { .. }));
    // Iteration 2: events[10..=18].
    assert!(matches!(
        events[10],
        TrainingEvent::WorkerTiming {
            phase: WorkerTimingPhase::Forward,
            ..
        }
    ));
    assert!(matches!(
        events[11],
        TrainingEvent::ForwardPassComplete { .. }
    ));
    assert!(matches!(
        events[12],
        TrainingEvent::ForwardSyncComplete { .. }
    ));
    assert!(matches!(
        events[13],
        TrainingEvent::WorkerTiming {
            phase: WorkerTimingPhase::Backward,
            ..
        }
    ));
    assert!(matches!(
        events[14],
        TrainingEvent::BackwardPassComplete { .. }
    ));
    assert!(matches!(
        events[15],
        TrainingEvent::PolicySyncComplete { .. }
    ));
    assert!(matches!(
        events[16],
        TrainingEvent::PolicyTemplateBakeComplete { .. }
    ));
    assert!(matches!(
        events[17],
        TrainingEvent::ConvergenceUpdate { .. }
    ));
    assert!(matches!(events[18], TrainingEvent::IterationSummary { .. }));
}
/// With 4 forward threads and a single iteration, exactly one Forward and
/// one Backward `WorkerTiming` event must be emitted per worker (8 total).
///
/// For each event:
/// - a Forward event has `bwd_setup_ms == 0`;
/// - a Backward event has `fwd_setup_ms == 0`.
///
/// The per-worker backward setup times must sum to
/// `BackwardPassComplete.setup_time_ms` within 1 ms.
#[test]
fn ac_worker_timing_per_worker_event_count_and_setup_invariant() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (tx, rx) = mpsc::channel::<TrainingEvent>();
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 10,
            start_iteration: 0,
            n_fwd_threads: 4,
            max_blocks: 1,
            // The iteration-limit rule (1) is what ends the run, not
            // `max_iterations` (10).
            stopping_rules: iteration_limit_rules(1),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: Some(tx),
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    let events: Vec<TrainingEvent> = rx.try_iter().collect();
    let worker_events: Vec<&TrainingEvent> = events
        .iter()
        .filter(|e| matches!(e, TrainingEvent::WorkerTiming { .. }))
        .collect();
    assert_eq!(
        worker_events.len(),
        8,
        "expected 8 WorkerTiming events (4 workers × 2 phases × 1 iter), got {}",
        worker_events.len()
    );
    let mut fwd_workers = std::collections::BTreeSet::new();
    let mut bwd_workers = std::collections::BTreeSet::new();
    let mut bwd_setup_sum_ms = 0.0_f64;
    for ev in &worker_events {
        let TrainingEvent::WorkerTiming {
            rank,
            worker_id,
            iteration,
            phase,
            timings,
        } = ev
        else {
            unreachable!()
        };
        assert_eq!(*rank, 0, "expected rank=0 in single-rank stub");
        assert!(
            (0..4).contains(worker_id),
            "worker_id {worker_id} out of [0,4)"
        );
        assert_eq!(*iteration, 1, "expected iteration=1 (iteration limit is 1)");
        match phase {
            WorkerTimingPhase::Forward => {
                assert!(
                    fwd_workers.insert(*worker_id),
                    "worker_id {worker_id} duplicated in Forward emissions"
                );
                assert_eq!(timings.bwd_setup_ms, 0.0);
            }
            WorkerTimingPhase::Backward => {
                assert!(
                    bwd_workers.insert(*worker_id),
                    "worker_id {worker_id} duplicated in Backward emissions"
                );
                bwd_setup_sum_ms += timings.bwd_setup_ms;
                assert_eq!(timings.fwd_setup_ms, 0.0);
            }
        }
    }
    assert_eq!(fwd_workers.len(), 4, "expected 4 distinct forward workers");
    assert_eq!(bwd_workers.len(), 4, "expected 4 distinct backward workers");
    let bwd_setup_total_ms_u64 = events
        .iter()
        .find_map(|e| match e {
            TrainingEvent::BackwardPassComplete { setup_time_ms, .. } => Some(*setup_time_ms),
            _ => None,
        })
        .expect("BackwardPassComplete event must exist");
    // The aggregate is whole milliseconds while the per-worker values are
    // fractional, hence the 1 ms tolerance below.
    let bwd_setup_total_ms = bwd_setup_total_ms_u64 as f64;
    assert!(
        (bwd_setup_sum_ms - bwd_setup_total_ms).abs() < 1.0,
        "sum of per-worker BWD_SETUP ({bwd_setup_sum_ms} ms) must match \
         BackwardPassComplete.setup_time_ms ({bwd_setup_total_ms} ms) within ±1 ms"
    );
}
/// After a clean 5-iteration run, the result must carry no error, report
/// 5 iterations and a non-empty stopping reason.
#[test]
fn ac_train_result_fields_populated() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 5,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let result = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    assert!(result.error.is_none(), "expected no error");
    assert_eq!(result.result.iterations, 5);
    assert!(!result.result.reason.is_empty(), "reason must not be empty");
}
/// Training without an event channel must run normally.
///
/// `train` must return `Ok` with no recorded error, and complete both
/// iterations allowed by the iteration limit.
#[test]
fn ac_train_with_no_event_sender() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 2,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(2),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let result = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    );
    // `train` reports iteration failures inside `Ok(TrainingOutcome)`, so
    // `is_ok()` alone would not catch them; check the outcome itself.
    let outcome = result.expect("train with no event_sender must return Ok");
    assert!(
        outcome.error.is_none(),
        "train with no event_sender must not record an error, got: {:?}",
        outcome.error
    );
    assert_eq!(
        outcome.result.iterations, 2,
        "expected both iterations allowed by the limit to complete"
    );
}
/// `total_time_ms` must be consistent with the wall-clock time spent in
/// `train`.
///
/// The field is a `u64`, so it is non-negative by construction. The
/// previous `> 0` assertion was timing-dependent: a one-iteration mock run
/// can finish in under a millisecond. The reported duration is instead
/// checked against an `Instant` taken around the call, with 1 ms slack for
/// rounding.
#[test]
fn ac_total_time_ms_is_non_negative() {
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 1,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(1),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let wall_clock = std::time::Instant::now();
    let result = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    let wall_ms = wall_clock.elapsed().as_millis();
    assert!(result.error.is_none(), "expected no error");
    assert!(
        u128::from(result.result.total_time_ms) <= wall_ms + 1,
        "total_time_ms ({}) exceeds the wall-clock time measured around train ({wall_ms} ms)",
        result.result.total_time_ms,
    );
}
#[test]
fn cut_selection_none_skips_step() {
    // With `cut_selection: None`, a full training run must never announce a
    // PolicySelectionComplete event on the event channel.
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
    let training_config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 10,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: Some(event_tx),
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
        inflow_scheme: SamplingScheme::InSample,
        load_scheme: SamplingScheme::InSample,
        ncs_scheme: SamplingScheme::InSample,
        stages: &stages,
        historical_library: None,
        external_inflow_library: None,
        external_load_library: None,
        external_ncs_library: None,
        recent_accum_seed: &[],
        recent_weight_seed: 0.0,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    train(
        &mut solver,
        training_config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    // Count selection events straight off the receiver; no intermediate Vec.
    let selection_count = event_rx
        .try_iter()
        .filter(|event| matches!(event, TrainingEvent::PolicySelectionComplete { .. }))
        .count();
    assert_eq!(
        selection_count, 0,
        "expected no PolicySelectionComplete events with cut_selection: None"
    );
}
#[test]
fn cut_selection_level1_runs_at_frequency() {
    use crate::cut_selection::CutSelectionStrategy;
    // Level-1 selection with check_frequency = 3 over a run stopped after 5
    // iterations should fire exactly once, at iteration 3.
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
    let selection_strategy = CutSelectionStrategy::Level1 {
        threshold: 0,
        check_frequency: 3,
    };
    let training_config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 10,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: Some(selection_strategy),
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: Some(event_tx),
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
        inflow_scheme: SamplingScheme::InSample,
        load_scheme: SamplingScheme::InSample,
        ncs_scheme: SamplingScheme::InSample,
        stages: &stages,
        historical_library: None,
        external_inflow_library: None,
        external_load_library: None,
        external_ncs_library: None,
        recent_accum_seed: &[],
        recent_weight_seed: 0.0,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    train(
        &mut solver,
        training_config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    // Keep only the iteration number of each selection event.
    let fired_at: Vec<_> = event_rx
        .try_iter()
        .filter_map(|event| match event {
            TrainingEvent::PolicySelectionComplete { iteration, .. } => Some(iteration),
            _ => None,
        })
        .collect();
    assert_eq!(
        fired_at.len(),
        1,
        "expected exactly 1 PolicySelectionComplete event for check_frequency=3 over 5 iterations"
    );
    assert_eq!(
        fired_at[0], 3,
        "PolicySelectionComplete must fire at iteration 3"
    );
}
#[test]
fn cut_selection_stage0_exempt_preserves_cuts() {
    use crate::cut_selection::CutSelectionStrategy;
    // Selection runs once (iteration 2). Stage 0 is exempt, so the event must
    // report zero deactivations overall and in its stage-0 record.
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
    let selection_strategy = CutSelectionStrategy::Level1 {
        threshold: 0,
        check_frequency: 2,
    };
    let training_config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 10,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(2),
        },
        cut_management: CutManagementConfig {
            cut_selection: Some(selection_strategy),
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: Some(event_tx),
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
        inflow_scheme: SamplingScheme::InSample,
        load_scheme: SamplingScheme::InSample,
        ncs_scheme: SamplingScheme::InSample,
        stages: &stages,
        historical_library: None,
        external_inflow_library: None,
        external_load_library: None,
        external_ncs_library: None,
        recent_accum_seed: &[],
        recent_weight_seed: 0.0,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    train(
        &mut solver,
        training_config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    // Pull the fields of interest out of each selection event.
    let selections: Vec<_> = event_rx
        .try_iter()
        .filter_map(|event| match event {
            TrainingEvent::PolicySelectionComplete {
                iteration,
                rows_deactivated,
                per_stage,
                ..
            } => Some((iteration, rows_deactivated, per_stage)),
            _ => None,
        })
        .collect();
    assert_eq!(
        selections.len(),
        1,
        "expected exactly 1 PolicySelectionComplete event at iteration 2"
    );
    let (fired_iteration, total_deactivated, stage_records) = &selections[0];
    assert_eq!(*fired_iteration, 2, "selection must fire at iteration 2");
    assert_eq!(
        *total_deactivated, 0,
        "stage 0 is exempt from cut selection, so no cuts should be deactivated"
    );
    assert!(
        !stage_records.is_empty(),
        "per_stage must contain at least the stage 0 record"
    );
    let leading_record = &stage_records[0];
    assert_eq!(leading_record.stage, 0, "first record must be stage 0");
    assert_eq!(
        leading_record.rows_deactivated, 0,
        "stage 0 must have zero deactivations"
    );
}
#[test]
fn existing_train_tests_pass_with_none() {
    // Baseline run with cut selection disabled: three iterations, stopped by the
    // iteration limit, no error.
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let training_config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 3,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(3),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
        inflow_scheme: SamplingScheme::InSample,
        load_scheme: SamplingScheme::InSample,
        ncs_scheme: SamplingScheme::InSample,
        stages: &stages,
        historical_library: None,
        external_inflow_library: None,
        external_load_library: None,
        external_ncs_library: None,
        recent_accum_seed: &[],
        recent_weight_seed: 0.0,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let outcome = train(
        &mut solver,
        training_config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    assert!(outcome.error.is_none(), "expected no error");
    let summary = &outcome.result;
    assert_eq!(summary.iterations, 3);
    assert_eq!(summary.reason, "iteration_limit");
}
#[test]
fn ac_train_partial_result_on_mid_iteration_failure() {
    // Every solve is infeasible, so training fails during iteration 1. `train`
    // must still return Ok with a partial result and the error attached, and it
    // must emit a TrainingFinished event whose reason is "error".
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (tx, rx) = mpsc::channel::<TrainingEvent>();
    let config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 5,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: Some(tx),
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let mut solver = MockSolver::infeasible();
    let comm = StubComm;
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let outcome = train(
        &mut solver,
        config,
        &mut fcf,
        &stage_ctx,
        &TrainingContext {
            horizon: &horizon,
            indexer: &indexer,
            inflow_method: &InflowNonNegativityMethod::None,
            stochastic: &stochastic,
            initial_state: &initial_state,
            inflow_scheme: SamplingScheme::InSample,
            load_scheme: SamplingScheme::InSample,
            ncs_scheme: SamplingScheme::InSample,
            stages: &stages,
            historical_library: None,
            external_inflow_library: None,
            external_load_library: None,
            external_ncs_library: None,
            recent_accum_seed: &[],
            recent_weight_seed: 0.0,
        },
        &comm,
        || Ok(MockSolver::infeasible()),
    )
    .unwrap();
    assert!(outcome.error.is_some(), "expected error in TrainingOutcome");
    assert_eq!(
        outcome.result.iterations, 0,
        "no iterations should have completed (failure in iteration 1)"
    );
    assert_eq!(outcome.result.reason, "error");
    assert!(
        outcome.result.total_time_ms > 0,
        "total_time_ms must be > 0"
    );
    let events: Vec<TrainingEvent> = rx.try_iter().collect();
    // A single let-else asserts that the event exists and binds its reason.
    // The previous `assert!(is_some)` + `if let` pair would have skipped the
    // reason check silently if the `find` predicate and the pattern ever drifted.
    let Some(TrainingEvent::TrainingFinished { reason, .. }) = events
        .iter()
        .find(|e| matches!(e, TrainingEvent::TrainingFinished { .. }))
    else {
        panic!("TrainingFinished event must be emitted even on error");
    };
    assert_eq!(reason, "error", "TrainingFinished reason must be 'error'");
}
#[test]
fn start_iteration_resumes_from_offset() {
    // Resuming at iteration 3 with a limit of 5 must report the absolute count
    // (5), not the number of iterations run in this call (2).
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let training_config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 5,
            start_iteration: 3,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
        inflow_scheme: SamplingScheme::InSample,
        load_scheme: SamplingScheme::InSample,
        ncs_scheme: SamplingScheme::InSample,
        stages: &stages,
        historical_library: None,
        external_inflow_library: None,
        external_load_library: None,
        external_ncs_library: None,
        recent_accum_seed: &[],
        recent_weight_seed: 0.0,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let outcome = train(
        &mut solver,
        training_config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    let summary = &outcome.result;
    assert_eq!(
        summary.iterations, 5,
        "iterations must report the absolute iteration number (5), not the delta (2)"
    );
    assert_eq!(summary.reason, "iteration_limit");
}
#[test]
fn start_iteration_at_or_beyond_max_runs_zero_iterations() {
    // start_iteration == max_iterations leaves an empty loop range. The result
    // must still echo start_iteration and report the iteration limit as the reason.
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let training_config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 5,
            start_iteration: 5,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(5),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: None,
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
        inflow_scheme: SamplingScheme::InSample,
        load_scheme: SamplingScheme::InSample,
        ncs_scheme: SamplingScheme::InSample,
        stages: &stages,
        historical_library: None,
        external_inflow_library: None,
        external_load_library: None,
        external_ncs_library: None,
        recent_accum_seed: &[],
        recent_weight_seed: 0.0,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    let outcome = train(
        &mut solver,
        training_config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    let summary = &outcome.result;
    assert_eq!(
        summary.iterations, 5,
        "iterations must equal start_iteration when no loop iterations execute"
    );
    assert_eq!(
        summary.reason, "iteration_limit",
        "reason should be iteration_limit when loop range is empty"
    );
}
#[test]
fn ac_broadcast_basis_cache_uses_scenario_0_not_last() {
    use super::broadcast_basis_cache;
    use crate::workspace::BasisStore;
    // Scenario 0 gets stage-dependent statuses; scenario 3 (the last one) gets a
    // constant decoy. Any mix-up between them shows up in the per-stage checks.
    let num_scenarios = 4;
    let num_stages = 3;
    // Shared by setup and assertions so the expected values cannot drift apart.
    let expected_col_status = |t: usize| vec![10_i32 + t as i32, 20_i32 + t as i32];
    let expected_row_status = |t: usize| vec![30_i32 + t as i32];
    let mut store = BasisStore::new(num_scenarios, num_stages);
    for t in 0..num_stages {
        *store.get_mut(0, t) = Some(crate::workspace::CapturedBasis {
            basis: Basis {
                col_status: expected_col_status(t),
                row_status: expected_row_status(t),
            },
            base_row_count: 0,
            cut_row_slots: Vec::new(),
            state_at_capture: Vec::new(),
        });
    }
    for t in 0..num_stages {
        *store.get_mut(3, t) = Some(crate::workspace::CapturedBasis {
            basis: Basis {
                col_status: vec![99_i32, 88_i32],
                row_status: vec![77_i32],
            },
            base_row_count: 0,
            cut_row_slots: Vec::new(),
            state_at_capture: Vec::new(),
        });
    }
    let comm = StubComm;
    let cache = broadcast_basis_cache(&store, num_stages, &comm).unwrap();
    assert_eq!(cache.len(), num_stages);
    for (t, entry) in cache.iter().enumerate() {
        // `Option::expect` prints its message verbatim (no `{t}` interpolation),
        // so use `panic!` to report which stage is missing.
        let captured = entry
            .as_ref()
            .unwrap_or_else(|| panic!("stage {t} must have a captured basis"));
        assert_eq!(
            captured.basis.col_status,
            expected_col_status(t),
            "stage {t} col_status must come from scenario 0, not scenario 3"
        );
        assert_eq!(
            captured.basis.row_status,
            expected_row_status(t),
            "stage {t} row_status must come from scenario 0, not scenario 3"
        );
    }
}
#[test]
fn ac_broadcast_basis_cache_none_slots_preserved() {
    use super::broadcast_basis_cache;
    use crate::workspace::BasisStore;
    // No basis is ever written into the store, so every stage of the returned
    // cache must be `None`.
    let num_stages = 2;
    let store = BasisStore::new(1, num_stages);
    let comm = StubComm;
    let cache = broadcast_basis_cache(&store, num_stages, &comm).unwrap();
    assert_eq!(cache.len(), num_stages);
    // Iterate with `enumerate` instead of indexing by range (clippy::needless_range_loop).
    for (t, entry) in cache.iter().enumerate() {
        assert!(
            entry.is_none(),
            "stage {t} must be None when basis store has no entry for scenario 0"
        );
    }
}
#[test]
fn broadcast_basis_cache_single_rank_preserves_metadata() {
    use super::broadcast_basis_cache;
    use crate::workspace::{BasisStore, CapturedBasis};
    // With StubComm, the captured basis for stage 0 must come back unchanged,
    // field for field. Stage 1 was never populated and must stay `None`.
    let num_stages = 2;
    let mut store = BasisStore::new(1, num_stages);
    *store.get_mut(0, 0) = Some(CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32, 2_i32],
            row_status: vec![3_i32, 4_i32, 5_i32],
        },
        base_row_count: 2,
        cut_row_slots: vec![10_u32, 11_u32, 12_u32],
        state_at_capture: vec![1.5_f64, 2.5_f64],
    });
    let comm = StubComm;
    let cache = broadcast_basis_cache(&store, num_stages, &comm).unwrap();
    assert_eq!(cache.len(), num_stages);
    let cb = cache[0].as_ref().expect("stage 0 must have captured basis");
    assert_eq!(
        cb.basis.col_status,
        vec![1_i32, 2_i32],
        "col_status must be preserved"
    );
    assert_eq!(
        cb.basis.row_status,
        vec![3_i32, 4_i32, 5_i32],
        "row_status must be preserved"
    );
    // Compare contents, not just length: a corrupted copy of the right size
    // must also fail this check.
    assert_eq!(
        cb.cut_row_slots,
        vec![10_u32, 11_u32, 12_u32],
        "single-rank path must preserve cut_row_slots"
    );
    assert_eq!(cb.base_row_count, 2, "base_row_count must be preserved");
    assert_eq!(
        cb.state_at_capture,
        vec![1.5_f64, 2.5_f64],
        "state_at_capture must be preserved"
    );
    assert!(cache[1].is_none(), "stage 1 must remain None");
}
/// A single recorded broadcast payload, tagged by element type.
#[derive(Clone)]
enum MockPayload {
    Ints(Vec<i32>),
    Floats(Vec<f64>),
}
/// Two-rank communicator mock that records broadcasts on rank 0 and replays
/// them, in order, on rank 1.
struct MultiRankMockComm {
    rank: usize,
    queue: std::sync::Mutex<std::collections::VecDeque<MockPayload>>,
}
impl MultiRankMockComm {
    /// Shared constructor: wraps `queue` in a mutex and tags it with `rank`.
    fn with_rank_and_queue(
        rank: usize,
        queue: std::collections::VecDeque<MockPayload>,
    ) -> Self {
        Self {
            rank,
            queue: std::sync::Mutex::new(queue),
        }
    }
    /// Rank 0 with an empty recording queue.
    fn new_root() -> Self {
        Self::with_rank_and_queue(0, std::collections::VecDeque::new())
    }
    /// Rank 1 that will replay a copy of everything `root` has recorded so far.
    fn new_peer(root: &MultiRankMockComm) -> Self {
        Self::new_peer_from_queue(root.snapshot())
    }
    /// Rank 1 that replays `queue` verbatim (used to inject tampered payloads).
    fn new_peer_from_queue(queue: std::collections::VecDeque<MockPayload>) -> Self {
        Self::with_rank_and_queue(1, queue)
    }
    /// Copy of the current queue contents; the queue itself is left intact.
    fn snapshot(&self) -> std::collections::VecDeque<MockPayload> {
        self.queue.lock().unwrap().iter().cloned().collect()
    }
}
/// `Communicator` implementation for the two-rank mock.
///
/// World size is fixed at 2 and `broadcast` forwards to `broadcast_typed`.
/// The collectives that `broadcast_basis_cache` never uses are unreachable.
impl Communicator for MultiRankMockComm {
    fn size(&self) -> usize {
        // Always rank 0 (recorder) plus rank 1 (replayer).
        2
    }
    fn rank(&self) -> usize {
        self.rank
    }
    fn broadcast<T: CommData>(&self, buf: &mut [T], root: usize) -> Result<(), CommError> {
        Self::broadcast_typed(self, buf, root)
    }
    fn barrier(&self) -> Result<(), CommError> {
        // Nothing to synchronise in a single-process mock.
        Ok(())
    }
    fn allgatherv<T: CommData>(
        &self,
        _send: &[T],
        _recv: &mut [T],
        _counts: &[usize],
        _displs: &[usize],
    ) -> Result<(), CommError> {
        unreachable!("broadcast_basis_cache does not call allgatherv")
    }
    fn allreduce<T: CommData>(
        &self,
        _send: &[T],
        _recv: &mut [T],
        _op: ReduceOp,
    ) -> Result<(), CommError> {
        unreachable!("broadcast_basis_cache does not call allreduce")
    }
    fn abort(&self, code: i32) -> ! {
        std::process::exit(code)
    }
}
impl MultiRankMockComm {
/// Simulated broadcast behind `Communicator::broadcast` for the two-rank mock.
///
/// Rank 0 records a copy of `buf` at the back of its own queue and leaves `buf`
/// unchanged. `i32` data is stored as `MockPayload::Ints`, `f64` data as
/// `MockPayload::Floats`. Any other rank pops the oldest recorded payload from
/// its queue and copies it into `buf`. The `_root` argument is ignored; the
/// branch on `self.rank == 0` hard-codes rank 0 as the broadcaster.
///
/// # Panics
///
/// On a non-zero rank: if the queue is empty, if the next payload's variant
/// does not match `T`, or if its length differs from `buf.len()`. On any rank:
/// if `T` is neither `i32` nor `f64`, or if the queue mutex is poisoned.
// Always returns Ok; the Result only mirrors the `Communicator::broadcast`
// signature, hence the `unnecessary_wraps` allow.
#[allow(clippy::unnecessary_wraps)]
fn broadcast_typed<T: CommData>(
&self,
buf: &mut [T],
_root: usize,
) -> Result<(), CommError> {
use std::any::Any;
// Box a default `T` as `dyn Any` so its concrete type can be probed at runtime.
// This relies on `CommData` implying `Default + 'static`.
let probe: Box<dyn Any> = Box::new(T::default());
if probe.downcast_ref::<i32>().is_some() {
if self.rank == 0 {
// Recorder: convert each `T` (now known to be i32) into a plain i32 and enqueue.
let ints: Vec<i32> = buf
.iter()
.map(|v| {
*Box::<dyn Any>::from(Box::new(*v))
.downcast::<i32>()
.expect("T proved i32 above")
})
.collect();
self.queue
.lock()
.unwrap()
.push_back(MockPayload::Ints(ints));
} else {
// Replayer: take the oldest recording and write it into `buf`.
let payload = self
.queue
.lock()
.unwrap()
.pop_front()
.expect("MultiRankMockComm: no payload to replay for i32 broadcast");
let MockPayload::Ints(src) = payload else {
panic!("MultiRankMockComm: expected Ints payload for i32 broadcast");
};
assert_eq!(src.len(), buf.len(), "i32 replay length mismatch");
for (dst, v) in buf.iter_mut().zip(src.iter()) {
// Round-trip through `dyn Any` to turn the i32 back into `T`.
let boxed: Box<dyn Any> = Box::new(*v);
*dst = *boxed.downcast::<T>().expect("T proved i32 above");
}
}
} else if probe.downcast_ref::<f64>().is_some() {
// Same record/replay scheme as the i32 branch, for f64 payloads.
if self.rank == 0 {
let floats: Vec<f64> = buf
.iter()
.map(|v| {
*Box::<dyn Any>::from(Box::new(*v))
.downcast::<f64>()
.expect("T proved f64 above")
})
.collect();
self.queue
.lock()
.unwrap()
.push_back(MockPayload::Floats(floats));
} else {
let payload = self
.queue
.lock()
.unwrap()
.pop_front()
.expect("MultiRankMockComm: no payload to replay for f64 broadcast");
let MockPayload::Floats(src) = payload else {
panic!("MultiRankMockComm: expected Floats payload for f64 broadcast");
};
assert_eq!(src.len(), buf.len(), "f64 replay length mismatch");
for (dst, v) in buf.iter_mut().zip(src.iter()) {
let boxed: Box<dyn Any> = Box::new(*v);
*dst = *boxed.downcast::<T>().expect("T proved f64 above");
}
}
} else {
panic!("MultiRankMockComm: unsupported broadcast type (expected i32 or f64)");
}
Ok(())
}
}
#[test]
fn broadcast_basis_cache_multi_rank_round_trips_full_metadata() {
    use super::broadcast_basis_cache;
    use crate::workspace::{BasisStore, CapturedBasis};
    const STAGES: usize = 2;
    // Rank 0 owns a fully populated basis for stage 0; stage 1 is left empty.
    let mut root_store = BasisStore::new(1, STAGES);
    *root_store.get_mut(0, 0) = Some(CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32, 2_i32, 3_i32],
            row_status: vec![10_i32, 20_i32],
        },
        base_row_count: 4,
        cut_row_slots: vec![10_u32, 11_u32, 12_u32],
        state_at_capture: vec![1.5_f64, 2.5_f64],
    });
    // Record rank 0's broadcasts, then replay them on rank 1 with an empty store.
    let root_comm = MultiRankMockComm::new_root();
    let _rank0_cache = broadcast_basis_cache(&root_store, STAGES, &root_comm).unwrap();
    let peer_comm = MultiRankMockComm::new_peer(&root_comm);
    let peer_store = BasisStore::new(1, STAGES);
    let peer_cache = broadcast_basis_cache(&peer_store, STAGES, &peer_comm).unwrap();
    assert_eq!(peer_cache.len(), 2);
    let received = peer_cache[0]
        .as_ref()
        .expect("stage 0 must deserialise into CapturedBasis on rank 1");
    assert_eq!(
        received.basis.col_status,
        vec![1_i32, 2_i32, 3_i32],
        "col_status must round-trip"
    );
    assert_eq!(
        received.basis.row_status,
        vec![10_i32, 20_i32],
        "row_status must round-trip"
    );
    assert_eq!(
        received.cut_row_slots,
        vec![10_u32, 11_u32, 12_u32],
        "cut_row_slots must round-trip on non-root rank"
    );
    assert_eq!(
        received.state_at_capture,
        vec![1.5_f64, 2.5_f64],
        "state_at_capture must round-trip on non-root rank"
    );
    assert_eq!(
        received.base_row_count, 4,
        "base_row_count must round-trip on non-root rank"
    );
    assert!(peer_cache[1].is_none(), "stage 1 had no basis → None");
}
#[test]
fn broadcast_basis_cache_empty_cut_slots_round_trips_ok() {
    use super::broadcast_basis_cache;
    use crate::workspace::{BasisStore, CapturedBasis};
    // A basis with zero cut slots must still round-trip to rank 1, and the
    // fields after the empty slot list must survive.
    let mut root_store = BasisStore::new(1, 1);
    *root_store.get_mut(0, 0) = Some(CapturedBasis {
        basis: Basis {
            col_status: vec![5_i32, 6_i32],
            row_status: vec![7_i32],
        },
        base_row_count: 1,
        cut_row_slots: vec![],
        state_at_capture: vec![3.75_f64],
    });
    let root_comm = MultiRankMockComm::new_root();
    let _ = broadcast_basis_cache(&root_store, 1, &root_comm).unwrap();
    let peer_comm = MultiRankMockComm::new_peer(&root_comm);
    let peer_store = BasisStore::new(1, 1);
    let peer_cache = broadcast_basis_cache(&peer_store, 1, &peer_comm).unwrap();
    assert_eq!(peer_cache.len(), 1);
    let received = peer_cache[0]
        .as_ref()
        .expect("stage 0 must be Some after broadcast");
    assert!(
        received.cut_row_slots.is_empty(),
        "empty cut_row_slots must round-trip without error or panic"
    );
    assert_eq!(
        received.state_at_capture,
        vec![3.75_f64],
        "state_at_capture must still round-trip when cut_row_slots is empty"
    );
    assert_eq!(received.base_row_count, 1, "base_row_count must round-trip");
}
#[test]
fn broadcast_basis_cache_truncated_cut_slots_returns_validation() {
    use super::broadcast_basis_cache;
    use crate::workspace::{BasisStore, CapturedBasis};
    // Record rank 0's broadcasts for one stage with three cut slots.
    let mut root_store = BasisStore::new(1, 1);
    *root_store.get_mut(0, 0) = Some(CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32],
            row_status: vec![2_i32],
        },
        base_row_count: 1,
        cut_row_slots: vec![10_u32, 11_u32, 12_u32],
        state_at_capture: vec![0.0_f64],
    });
    let root_comm = MultiRankMockComm::new_root();
    let _ = broadcast_basis_cache(&root_store, 1, &root_comm).unwrap();
    // Drop the last element of the i32 payload (entry 1) and patch its length
    // header (entry 0) to match. The test expects rank 1 to flag the shortfall
    // as a cut_row_slots validation error.
    let mut replay = root_comm.snapshot();
    let shortened_len = match replay.get_mut(1).expect("i32 payload entry must exist") {
        MockPayload::Ints(ints) => {
            ints.pop();
            ints.len() as i32
        }
        _ => panic!("entry [1] must be Ints"),
    };
    match replay.get_mut(0).expect("i32 length entry must exist") {
        MockPayload::Ints(header) => {
            assert_eq!(header.len(), 1, "length entry must hold a single scalar");
            header[0] = shortened_len;
        }
        _ => panic!("entry [0] must be Ints"),
    }
    let peer_comm = MultiRankMockComm::new_peer_from_queue(replay);
    let peer_store = BasisStore::new(1, 1);
    match broadcast_basis_cache(&peer_store, 1, &peer_comm) {
        Err(SddpError::Validation(msg)) => {
            assert!(
                msg.contains("cut_row_slots"),
                "error message must mention 'cut_row_slots', got: {msg}"
            );
            assert!(
                msg.contains('0'),
                "error message must contain stage index 0, got: {msg}"
            );
        }
        other => panic!("expected SddpError::Validation, got: {other:?}"),
    }
}
#[test]
fn broadcast_basis_cache_truncated_state_returns_validation() {
    use super::broadcast_basis_cache;
    use crate::workspace::{BasisStore, CapturedBasis};
    // Record rank 0's broadcasts for one stage with a three-element state vector.
    let mut root_store = BasisStore::new(1, 1);
    *root_store.get_mut(0, 0) = Some(CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32],
            row_status: vec![2_i32],
        },
        base_row_count: 1,
        cut_row_slots: vec![],
        state_at_capture: vec![1.0_f64, 2.0_f64, 3.0_f64],
    });
    let root_comm = MultiRankMockComm::new_root();
    let _ = broadcast_basis_cache(&root_store, 1, &root_comm).unwrap();
    // Shrink the f64 payload (entry 3) to one element and patch its i32 length
    // header (entry 2) so rank 1 receives a consistent but short buffer.
    let mut replay = root_comm.snapshot();
    let shortened_len = match replay.get_mut(3).expect("f64 payload entry must exist") {
        MockPayload::Floats(floats) => {
            floats.truncate(1);
            floats.len() as i32
        }
        _ => panic!("entry [3] must be Floats"),
    };
    match replay.get_mut(2).expect("f64 length entry must exist") {
        MockPayload::Ints(header) => {
            assert_eq!(header.len(), 1, "f64 length entry must hold a single scalar");
            header[0] = shortened_len;
        }
        _ => panic!("entry [2] must be Ints (f64 length is broadcast as i32)"),
    }
    let peer_comm = MultiRankMockComm::new_peer_from_queue(replay);
    let peer_store = BasisStore::new(1, 1);
    match broadcast_basis_cache(&peer_store, 1, &peer_comm) {
        Err(SddpError::Validation(msg)) => {
            assert!(
                msg.contains("state_at_capture"),
                "error message must mention 'state_at_capture', got: {msg}"
            );
            assert!(
                msg.contains('0'),
                "error message must contain stage index 0, got: {msg}"
            );
        }
        other => panic!("expected SddpError::Validation, got: {other:?}"),
    }
}
#[test]
fn broadcast_basis_cache_rejects_oversized_i32_payload() {
    use super::checked_broadcast_len;
    // One past i32::MAX: the smallest length that cannot be represented as i32.
    let too_long: usize = (i32::MAX as usize) + 1;
    match checked_broadcast_len(too_long, "broadcast_basis_cache_i32") {
        Err(SddpError::Communication(CommError::InvalidBufferSize {
            operation: op_name,
            expected: limit,
            actual: got,
        })) => {
            assert_eq!(got, too_long, "actual must equal the oversized length");
            assert_eq!(
                limit,
                i32::MAX as usize,
                "expected must equal i32::MAX as usize"
            );
            assert_eq!(
                op_name, "broadcast_basis_cache_i32",
                "operation string must be 'broadcast_basis_cache_i32'"
            );
        }
        other => panic!(
            "expected SddpError::Communication(CommError::InvalidBufferSize {{ .. }}), got: {other:?}"
        ),
    }
}
#[test]
fn template_bake_event_emitted() {
    // Two iterations should produce one PolicyTemplateBakeComplete per iteration.
    // Each must cover every stage, and the second must carry cut rows generated
    // by the first backward pass.
    let n_stages = 2;
    let indexer = StageIndexer::new(1, 0);
    let templates = vec![minimal_template(indexer.n_state); n_stages];
    let base_rows = vec![2usize; n_stages];
    let initial_state = vec![0.0_f64; indexer.n_state];
    let stochastic = make_stochastic_context(n_stages, 1);
    let stages = make_stages(n_stages);
    let horizon = HorizonMode::Finite {
        num_stages: n_stages,
    };
    let mut fcf = make_fcf(n_stages, indexer.n_state, 1, 10);
    let (event_tx, event_rx) = mpsc::channel::<TrainingEvent>();
    let training_config = TrainingConfig {
        loop_config: LoopConfig {
            forward_passes: 1,
            max_iterations: 10,
            start_iteration: 0,
            n_fwd_threads: 1,
            max_blocks: 1,
            stopping_rules: iteration_limit_rules(2),
        },
        cut_management: CutManagementConfig {
            cut_selection: None,
            budget: None,
            cut_activity_tolerance: 0.0,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            warm_start_cuts: 0,
            risk_measures: vec![RiskMeasure::Expectation; n_stages],
        },
        events: EventConfig {
            event_sender: Some(event_tx),
            checkpoint_interval: None,
            shutdown_flag: None,
            export_states: false,
        },
    };
    let stage_ctx = StageContext {
        templates: &templates,
        base_rows: &base_rows,
        noise_scale: &[],
        n_hydros: 0,
        n_load_buses: 0,
        load_balance_row_starts: &[],
        load_bus_indices: &[],
        block_counts_per_stage: &[1usize, 1],
        ncs_max_gen: &[],
        ncs_allow_curtailment: &[],
        discount_factors: &[],
        cumulative_discount_factors: &[],
        stage_lag_transitions: &[],
        noise_group_ids: &[],
        downstream_par_order: 0,
    };
    let training_ctx = TrainingContext {
        horizon: &horizon,
        indexer: &indexer,
        inflow_method: &InflowNonNegativityMethod::None,
        stochastic: &stochastic,
        initial_state: &initial_state,
        inflow_scheme: SamplingScheme::InSample,
        load_scheme: SamplingScheme::InSample,
        ncs_scheme: SamplingScheme::InSample,
        stages: &stages,
        historical_library: None,
        external_inflow_library: None,
        external_load_library: None,
        external_ncs_library: None,
        recent_accum_seed: &[],
        recent_weight_seed: 0.0,
    };
    let mut solver = MockSolver::with_fixed(100.0);
    let comm = StubComm;
    train(
        &mut solver,
        training_config,
        &mut fcf,
        &stage_ctx,
        &training_ctx,
        &comm,
        || Ok(MockSolver::with_fixed(100.0)),
    )
    .unwrap();
    // (stages_processed, total_rows_baked) for every bake event, in emission order.
    let bakes: Vec<_> = event_rx
        .try_iter()
        .filter_map(|event| match event {
            TrainingEvent::PolicyTemplateBakeComplete {
                stages_processed,
                total_rows_baked,
                ..
            } => Some((stages_processed, total_rows_baked)),
            _ => None,
        })
        .collect();
    assert_eq!(
        bakes.len(),
        2,
        "expected exactly 2 PolicyTemplateBakeComplete events, got {}",
        bakes.len()
    );
    for (stages_processed, _) in &bakes {
        assert_eq!(
            *stages_processed, n_stages as u32,
            "stages_processed must equal num_stages"
        );
    }
    let (_, second_rows_baked) = &bakes[1];
    assert!(
        *second_rows_baked > 0,
        "iteration 2 bake must have baked at least one cut row (backward pass \
generated cuts on iteration 1)"
    );
}
#[test]
fn ac_training_result_new_assigns_all_fields() {
    use crate::workspace::CapturedBasis;
    // Every input carries a distinct sentinel value so that a swapped or
    // dropped positional argument in `TrainingResult::new` shows up as a
    // failed field assertion below.
    let basis_cache = vec![Some(CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32],
            row_status: vec![2_i32],
        },
        base_row_count: 3,
        cut_row_slots: vec![4_u32],
        state_at_capture: vec![5.0_f64],
    })];
    let solver_stats_log = vec![(
        7_u64,
        "forward",
        -1_i32,
        -1_i32,
        0_i32,
        -1_i32,
        SolverStatsDelta::default(),
    )];
    let result = super::TrainingResult::new(
        1.5_f64,                       // final_lb
        2.5_f64,                       // final_ub
        0.25_f64,                      // final_ub_std
        0.1_f64,                       // final_gap
        42_u64,                        // iterations
        "iteration_limit".to_string(), // reason
        9_999_u64,                     // total_time_ms
        basis_cache,                   // basis_cache
        solver_stats_log,              // solver_stats_log
        None,                          // visited_archive
        None,                          // baked_templates
    );

    // Scalar fields: each value is passed straight through, so exact
    // float equality is intended here.
    assert_eq!(result.final_lb, 1.5_f64, "final_lb");
    assert_eq!(result.final_ub, 2.5_f64, "final_ub");
    assert_eq!(result.final_ub_std, 0.25_f64, "final_ub_std");
    assert_eq!(result.final_gap, 0.1_f64, "final_gap");
    assert_eq!(result.iterations, 42_u64, "iterations");
    assert_eq!(result.reason, "iteration_limit", "reason");
    assert_eq!(result.total_time_ms, 9_999_u64, "total_time_ms");

    // basis_cache: verify the whole captured basis survived the move,
    // not just one of its fields.
    assert_eq!(result.basis_cache.len(), 1, "basis_cache length");
    let captured = result.basis_cache[0].as_ref().expect("basis_cache[0]");
    assert_eq!(
        captured.basis.col_status,
        vec![1_i32],
        "basis_cache[0].basis.col_status"
    );
    assert_eq!(
        captured.basis.row_status,
        vec![2_i32],
        "basis_cache[0].basis.row_status"
    );
    assert_eq!(captured.base_row_count, 3, "basis_cache[0].base_row_count");
    assert_eq!(
        captured.cut_row_slots,
        vec![4_u32],
        "basis_cache[0].cut_row_slots"
    );
    assert_eq!(
        captured.state_at_capture,
        vec![5.0_f64],
        "basis_cache[0].state_at_capture"
    );

    // solver_stats_log: check every scalar element of the single entry.
    assert_eq!(result.solver_stats_log.len(), 1, "solver_stats_log length");
    let entry = &result.solver_stats_log[0];
    assert_eq!(entry.0, 7_u64, "solver_stats_log[0].0 (iteration)");
    assert_eq!(entry.1, "forward", "solver_stats_log[0].1");
    assert_eq!(entry.2, -1_i32, "solver_stats_log[0].2");
    assert_eq!(entry.3, -1_i32, "solver_stats_log[0].3");
    assert_eq!(entry.4, 0_i32, "solver_stats_log[0].4");
    assert_eq!(entry.5, -1_i32, "solver_stats_log[0].5");

    // Optional outputs were passed as `None` and must stay `None`.
    assert!(result.visited_archive.is_none(), "visited_archive");
    assert!(result.baked_templates.is_none(), "baked_templates");
}
}