use cobre_solver::{SolutionView, SolverInterface, StageTemplate};
use crate::{
basis_reconstruct::{
PaddingContext, ReconstructionStats, ReconstructionTarget, enforce_basic_count_invariant,
reconstruct_basis,
},
context::StageContext,
cut::pool::CutPool,
error::SddpError,
indexer::StageIndexer,
workspace::{CapturedBasis, SolverWorkspace},
};
use cobre_solver::SolverError;
/// The pass of the algorithm on whose behalf a stage solve is performed.
///
/// `run_stage_solve` only uses this to decide which `StageOutcome` variant
/// wraps the solver result; it does not change how the LP is solved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Phase {
    /// Solve requested by the forward pass.
    Forward,
    /// Solve requested by the backward pass.
    Backward,
    /// Solve requested during simulation.
    Simulation,
}
/// Borrowed, read-only inputs for a single `run_stage_solve` call.
///
/// `dead_code` is allowed presumably because several fields (`baked_template`,
/// `horizon_is_terminal`, `terminal_has_boundary_cuts`) are not read by
/// `run_stage_solve` in this module — confirm before removing the attribute.
#[allow(dead_code)]
pub struct StageInputs<'a> {
    /// Shared per-stage context; `templates[stage_index]` supplies the row and
    /// column counts used as the basis-reconstruction target.
    pub stage_context: &'a StageContext<'a>,
    /// Stage indexer; `n_state` is the length of the state prefix of
    /// `current_state` that is read.
    pub indexer: &'a StageIndexer,
    /// Cut pool: evaluated at the current state for padding, and the source of
    /// active cuts and cut metadata during basis reconstruction.
    pub pool: &'a CutPool,
    /// Current state vector; only the first `indexer.n_state` entries are read.
    pub current_state: &'a [f64],
    /// Previously captured basis. `Some` selects the warm-start (reconstruction)
    /// path, `None` a cold solve with no starting basis.
    pub stored_basis: Option<&'a CapturedBasis>,
    /// Baked stage template. Not read by `run_stage_solve`.
    pub baked_template: &'a StageTemplate,
    /// Stage index: selects the template and is reported in errors.
    pub stage_index: usize,
    /// Scenario index: reported in errors.
    pub scenario_index: usize,
    /// Not read by `run_stage_solve`.
    pub horizon_is_terminal: bool,
    /// Not read by `run_stage_solve`.
    pub terminal_has_boundary_cuts: bool,
    /// Iteration reported in infeasibility errors (`None` is reported as `0`).
    pub iteration: Option<u64>,
    /// Activity window forwarded unchanged to basis reconstruction.
    pub basis_activity_window: u32,
}
/// Result of a successful `run_stage_solve`, tagged by the `Phase` that
/// requested it.
///
/// All variants carry the same payload:
/// - `view`: the solver's `SolutionView` for the solved LP;
/// - `recon_stats`: statistics returned by basis reconstruction, or
///   `ReconstructionStats::default()` when the solve was cold (no stored basis).
///
/// `dead_code` is allowed presumably because not every variant/field is read
/// yet — confirm before removing the attribute.
#[allow(dead_code)]
#[derive(Debug)]
pub enum StageOutcome<'solver> {
    /// Outcome of a forward-pass solve.
    Forward {
        view: SolutionView<'solver>,
        recon_stats: ReconstructionStats,
    },
    /// Outcome of a backward-pass solve.
    Backward {
        view: SolutionView<'solver>,
        recon_stats: ReconstructionStats,
    },
    /// Outcome of a simulation solve.
    Simulation {
        view: SolutionView<'solver>,
        recon_stats: ReconstructionStats,
    },
}
pub fn run_stage_solve<'ws, S: SolverInterface>(
ws: &'ws mut SolverWorkspace<S>,
phase: Phase,
inputs: &StageInputs<'_>,
) -> Result<StageOutcome<'ws>, SddpError> {
if ws.scratch.recon_slot_lookup.len() < inputs.pool.populated_count {
ws.scratch
.recon_slot_lookup
.resize(inputs.pool.populated_count, None);
}
let (view, recon_stats) = if let Some(captured) = inputs.stored_basis {
let theta_value = inputs
.pool
.evaluate_at_state(&inputs.current_state[..inputs.indexer.n_state]);
let padding = PaddingContext {
state: &inputs.current_state[..inputs.indexer.n_state],
theta: theta_value,
tolerance: 1e-7,
};
let source = crate::basis_reconstruct::ReconstructionSource {
target: ReconstructionTarget {
base_row_count: inputs.stage_context.templates[inputs.stage_index].num_rows,
num_cols: inputs.stage_context.templates[inputs.stage_index].num_cols,
},
cut_metadata: &inputs.pool.metadata,
basis_activity_window: inputs.basis_activity_window,
};
let recon_stats = reconstruct_basis(
captured,
source,
inputs.pool.active_cuts(),
padding,
&mut ws.scratch_basis,
&mut ws.scratch.recon_slot_lookup,
&mut ws.scratch.promotion_scratch,
);
let num_row_for_invariant = ws.scratch_basis.row_status.len();
let base_row_for_invariant = inputs.indexer.n_state;
enforce_basic_count_invariant(
&mut ws.scratch_basis,
num_row_for_invariant,
base_row_for_invariant,
);
ws.solver.record_reconstruction_stats();
let view = ws.solver.solve(Some(&ws.scratch_basis)).map_err(|e| {
map_solver_error(
e,
inputs.stage_index,
inputs.scenario_index,
inputs.iteration,
)
})?;
(view, recon_stats)
} else {
let view = ws.solver.solve(None).map_err(|e| {
map_solver_error(
e,
inputs.stage_index,
inputs.scenario_index,
inputs.iteration,
)
})?;
(view, ReconstructionStats::default())
};
let outcome = match phase {
Phase::Forward => StageOutcome::Forward { view, recon_stats },
Phase::Backward => StageOutcome::Backward { view, recon_stats },
Phase::Simulation => StageOutcome::Simulation { view, recon_stats },
};
Ok(outcome)
}
/// Translates a solver failure into the crate's SDDP error type.
///
/// - `SolverError::Infeasible` becomes `SddpError::Infeasible`, tagged with the
///   given stage, scenario and iteration. A missing iteration is reported as `0`.
/// - Any other solver error is wrapped unchanged in `SddpError::Solver`.
fn map_solver_error(
    e: SolverError,
    stage: usize,
    scenario: usize,
    iteration: Option<u64>,
) -> SddpError {
    if matches!(e, SolverError::Infeasible) {
        return SddpError::Infeasible {
            stage,
            scenario,
            iteration: iteration.unwrap_or_default(),
        };
    }
    SddpError::Solver(e)
}
#[cfg(test)]
mod tests {
    use cobre_solver::{HighsSolver, SolverError, SolverInterface, StageTemplate};

    use super::{Phase, StageInputs, StageOutcome, run_stage_solve};
    use crate::{
        SddpError,
        basis_reconstruct::{
            HIGHS_BASIS_STATUS_BASIC as B, HIGHS_BASIS_STATUS_LOWER as L, ReconstructionStats,
        },
        context::StageContext,
        cut::pool::CutPool,
        indexer::StageIndexer,
        lp_builder::PatchBuffer,
        workspace::{CapturedBasis, SolverWorkspace, WorkspaceSizing},
    };

    /// Small feasible LP with 3 columns and 2 equality rows. Row 0 is `x0 = 6`,
    /// row 1 is `2*x0 + x2 = 14`, so `x0 = 6`, `x2 = 2`.
    fn make_template() -> StageTemplate {
        StageTemplate {
            num_cols: 3,
            num_rows: 2,
            num_nz: 3,
            col_starts: vec![0_i32, 2, 2, 3],
            row_indices: vec![0_i32, 1, 1],
            values: vec![1.0, 2.0, 1.0],
            col_lower: vec![0.0, 0.0, 0.0],
            col_upper: vec![10.0, f64::INFINITY, 8.0],
            objective: vec![0.0, 1.0, 50.0],
            row_lower: vec![6.0, 14.0],
            row_upper: vec![6.0, 14.0],
            n_state: 1,
            n_transfer: 0,
            n_dual_relevant: 1,
            n_hydro: 1,
            max_par_order: 0,
            col_scale: Vec::new(),
            row_scale: Vec::new(),
        }
    }

    /// Single-column LP with no rows whose bounds cross (`5 <= x <= 2`), so it
    /// is infeasible by construction.
    fn make_infeasible_template() -> StageTemplate {
        StageTemplate {
            num_cols: 1,
            num_rows: 0,
            num_nz: 0,
            col_starts: vec![0_i32, 0],
            row_indices: vec![],
            values: vec![],
            col_lower: vec![5.0],
            col_upper: vec![2.0],
            objective: vec![1.0],
            row_lower: vec![],
            row_upper: vec![],
            n_state: 0,
            n_transfer: 0,
            n_dual_relevant: 0,
            n_hydro: 0,
            max_par_order: 0,
            col_scale: Vec::new(),
            row_scale: Vec::new(),
        }
    }

    /// Workspace around a HiGHS solver with `template` already loaded.
    fn make_workspace(template: &StageTemplate) -> SolverWorkspace<HighsSolver> {
        let mut solver = HighsSolver::new().expect("HighsSolver::new()");
        solver.load_model(template);
        SolverWorkspace::new(
            0,
            0,
            solver,
            PatchBuffer::new(0, 0, 0, 0),
            0,
            WorkspaceSizing::default(),
        )
    }

    /// Minimal stage context: only `templates` is populated.
    fn make_context(templates: &[StageTemplate]) -> StageContext<'_> {
        StageContext {
            templates,
            base_rows: &[],
            noise_scale: &[],
            n_hydros: 0,
            n_load_buses: 0,
            load_balance_row_starts: &[],
            load_bus_indices: &[],
            block_counts_per_stage: &[],
            ncs_max_gen: &[],
            ncs_allow_curtailment: &[],
            discount_factors: &[],
            cumulative_discount_factors: &[],
            stage_lag_transitions: &[],
            noise_group_ids: &[],
            downstream_par_order: 0,
        }
    }

    fn make_empty_pool() -> CutPool {
        CutPool::new(16, 1, 1, 0)
    }

    fn make_indexer() -> StageIndexer {
        StageIndexer::new(1, 0)
    }

    /// Without a stored basis the solve is cold and reports default stats.
    #[test]
    fn run_stage_solve_cold_start_returns_outcome() {
        let template = make_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = make_empty_pool();
        let indexer = make_indexer();
        let mut ws = make_workspace(&template);
        let inputs = StageInputs {
            stage_context: &ctx,
            indexer: &indexer,
            pool: &pool,
            current_state: &[0.0],
            stored_basis: None,
            baked_template: &template,
            stage_index: 0,
            scenario_index: 0,
            horizon_is_terminal: false,
            terminal_has_boundary_cuts: false,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            iteration: Some(1),
        };
        let result = run_stage_solve(&mut ws, Phase::Forward, &inputs);
        let outcome = result.expect("cold start should succeed");
        match outcome {
            StageOutcome::Forward { recon_stats, .. } => {
                assert_eq!(recon_stats, ReconstructionStats::default());
            }
            _ => panic!("expected Forward variant"),
        }
    }

    /// A stored basis (cols B, B, L; rows L, L) drives the reconstruction path
    /// and the solve succeeds, leaving both rows LOWER in the scratch basis.
    #[test]
    fn run_stage_solve_warm_start_baked_path_succeeds() {
        let template = make_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = make_empty_pool();
        let indexer = make_indexer();
        let mut ws = make_workspace(&template);
        ws.scratch.recon_slot_lookup = vec![None; 16];
        let mut captured = CapturedBasis::new(
            template.num_cols,
            template.num_rows,
            template.num_rows,
            0,
            1,
        );
        captured.basis.col_status.clear();
        captured.basis.col_status.extend([B, B, L]);
        captured.basis.row_status.clear();
        captured.basis.row_status.extend([L, L]);
        captured.state_at_capture.push(6.0);
        let inputs = StageInputs {
            stage_context: &ctx,
            indexer: &indexer,
            pool: &pool,
            current_state: &[6.0],
            stored_basis: Some(&captured),
            baked_template: &template,
            stage_index: 0,
            scenario_index: 0,
            horizon_is_terminal: false,
            terminal_has_boundary_cuts: false,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            iteration: None,
        };
        let result = run_stage_solve(&mut ws, Phase::Simulation, &inputs);
        assert!(
            result.is_ok(),
            "warm start on baked path should succeed: {result:?}"
        );
        let lower_count = ws
            .scratch_basis
            .row_status
            .iter()
            .filter(|&&s| s == L)
            .count();
        assert_eq!(
            lower_count, 2,
            "both rows should be LOWER after reconstruction"
        );
    }

    /// An infeasible LP surfaces as `SddpError::Infeasible` carrying the
    /// caller's stage, scenario and iteration.
    #[test]
    fn run_stage_solve_propagates_infeasible() {
        let template = make_infeasible_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = CutPool::new(16, 0, 1, 0);
        let indexer = StageIndexer::new(0, 0);
        let mut ws = make_workspace(&template);
        let inputs = StageInputs {
            stage_context: &ctx,
            indexer: &indexer,
            pool: &pool,
            current_state: &[],
            stored_basis: None,
            baked_template: &template,
            stage_index: 0,
            scenario_index: 7,
            horizon_is_terminal: false,
            terminal_has_boundary_cuts: false,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            iteration: Some(42),
        };
        let result = run_stage_solve(&mut ws, Phase::Forward, &inputs);
        match result {
            Err(SddpError::Infeasible {
                stage,
                scenario,
                iteration,
            }) => {
                assert_eq!(stage, 0, "stage must match inputs.stage_index");
                assert_eq!(scenario, 7, "scenario must match inputs.scenario_index");
                assert_eq!(iteration, 42, "iteration must match inputs.iteration");
            }
            other => panic!("expected SddpError::Infeasible, got {other:?}"),
        }
    }

    /// A freshly constructed (all-LOWER) captured basis is rejected by the
    /// solver, and the error must be wrapped as `SddpError::Solver`, not
    /// mapped to `SddpError::Infeasible`.
    #[test]
    fn basis_inconsistent_propagates_as_sddp_solver_error() {
        let template = make_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = make_empty_pool();
        let indexer = make_indexer();
        let mut ws = make_workspace(&template);
        ws.scratch.recon_slot_lookup = vec![None; 16];
        let all_lower = CapturedBasis::new(
            template.num_cols,
            template.num_rows,
            template.num_rows,
            0,
            1,
        );
        let inputs = StageInputs {
            stage_context: &ctx,
            indexer: &indexer,
            pool: &pool,
            current_state: &[0.0],
            stored_basis: Some(&all_lower),
            baked_template: &template,
            stage_index: 0,
            scenario_index: 3,
            horizon_is_terminal: false,
            terminal_has_boundary_cuts: false,
            basis_activity_window: crate::basis_reconstruct::DEFAULT_BASIS_ACTIVITY_WINDOW,
            iteration: Some(5),
        };
        let result = run_stage_solve(&mut ws, Phase::Forward, &inputs);
        match result {
            Err(SddpError::Solver(SolverError::BasisInconsistent { .. })) => {}
            Err(SddpError::Infeasible { .. }) => {
                panic!("BasisInconsistent must not map to SddpError::Infeasible")
            }
            other => panic!(
                "expected Err(SddpError::Solver(SolverError::BasisInconsistent {{ .. }})), \
                 got {other:?}"
            ),
        }
    }
}