use cobre_solver::{SolutionView, SolverError, SolverInterface};
use crate::{
basis_reconstruct::{ReconstructionTarget, enforce_basic_count_invariant, reconstruct_basis},
context::StageContext,
cut::pool::CutPool,
error::SddpError,
workspace::{CapturedBasis, SolverWorkspace},
};
/// Borrowed inputs for a single stage solve performed by [`run_stage_solve`].
///
/// Every field is a shared reference, an `Option` of one, or a primitive, so the
/// struct is `Copy`. Callers can build one value and override individual fields
/// with struct-update syntax (e.g. per scenario).
#[derive(Clone, Copy)]
pub struct StageInputs<'a> {
    /// Per-stage data; `templates[stage_index]` supplies the row/column counts
    /// used when reconstructing a warm-start basis.
    pub stage_context: &'a StageContext<'a>,
    /// Cut pool. Its `active_cuts()` feed basis reconstruction, and its
    /// `populated_count` sizes the reconstruction slot lookup.
    pub pool: &'a CutPool,
    /// Previously captured basis. `Some` makes the solve start from a
    /// reconstructed basis; `None` solves without a starting basis.
    pub stored_basis: Option<&'a CapturedBasis>,
    /// Index of the stage being solved; also reported in infeasibility errors.
    pub stage_index: usize,
    /// Scenario index reported in infeasibility errors.
    pub scenario_index: usize,
    /// Iteration number reported in infeasibility errors (`None` is reported as `0`).
    pub iteration: Option<u64>,
}
/// Solves the LP currently loaded in `ws.solver` for one stage/scenario.
///
/// When `inputs.stored_basis` is `Some`, the captured basis is first
/// reconstructed into `ws.scratch_basis`. Reconstruction uses the stage
/// template's row/column counts and the pool's active cuts. The basic-count
/// invariant is then enforced on that basis, and the solve starts from it.
/// Otherwise the solver is called with no starting basis.
///
/// # Errors
///
/// * `SolverError::Infeasible` is reported as [`SddpError::Infeasible`] tagged
///   with the stage, scenario, and iteration (`None` iteration becomes `0`).
/// * Any other solver error is wrapped as [`SddpError::Solver`].
///
/// # Panics
///
/// On the warm-start path, panics if `inputs.stage_index` is out of range for
/// `inputs.stage_context.templates`.
pub fn run_stage_solve<'ws, S: SolverInterface>(
    ws: &'ws mut SolverWorkspace<S>,
    inputs: &StageInputs<'_>,
) -> Result<SolutionView<'ws>, SddpError> {
    // The lookup only ever grows, so allocations are reused across solves.
    let required_slots = inputs.pool.populated_count;
    if required_slots > ws.scratch.recon_slot_lookup.len() {
        ws.scratch.recon_slot_lookup.resize(required_slots, None);
    }

    let warm_basis = match inputs.stored_basis {
        None => None,
        Some(captured) => {
            let template = &inputs.stage_context.templates[inputs.stage_index];
            let target = ReconstructionTarget {
                base_row_count: template.num_rows,
                num_cols: template.num_cols,
            };
            // NOTE(review): the reconstruction result is intentionally ignored
            // here; confirm nothing in it needs to be propagated.
            let _ = reconstruct_basis(
                captured,
                target,
                inputs.pool.active_cuts(),
                &mut ws.scratch_basis,
                &mut ws.scratch.recon_slot_lookup,
            );
            // NOTE(review): applied over every row of the scratch basis with a
            // base row of 0; confirm this range is intended.
            let total_rows = ws.scratch_basis.row_status.len();
            enforce_basic_count_invariant(&mut ws.scratch_basis, total_rows, 0);
            ws.solver.record_reconstruction_stats();
            Some(&ws.scratch_basis)
        }
    };

    ws.solver.solve(warm_basis).map_err(|err| {
        map_solver_error(
            err,
            inputs.stage_index,
            inputs.scenario_index,
            inputs.iteration,
        )
    })
}
/// Converts a solver error into the SDDP error type.
///
/// Only `SolverError::Infeasible` becomes [`SddpError::Infeasible`], tagged with
/// the stage, scenario, and iteration; a missing iteration is reported as `0`.
/// Every other variant is wrapped unchanged in [`SddpError::Solver`].
fn map_solver_error(
    e: SolverError,
    stage: usize,
    scenario: usize,
    iteration: Option<u64>,
) -> SddpError {
    if matches!(e, SolverError::Infeasible) {
        SddpError::Infeasible {
            stage,
            scenario,
            iteration: iteration.unwrap_or(0),
        }
    } else {
        SddpError::Solver(e)
    }
}
/// Writes the unscaled primal values into `out`, replacing its previous contents.
///
/// Each output element is `col_scale[j] * scaled[j]`. An empty `col_scale` means
/// "no scaling", and `scaled` is copied through unchanged. In debug builds a
/// non-empty `col_scale` must have the same length as `scaled`.
pub(crate) fn fill_unscaled(out: &mut Vec<f64>, scaled: &[f64], col_scale: &[f64]) {
    debug_assert!(
        col_scale.is_empty() || col_scale.len() == scaled.len(),
        "col_scale length {} != primal length {}",
        col_scale.len(),
        scaled.len()
    );
    out.clear();
    match col_scale {
        [] => out.extend_from_slice(scaled),
        factors => out.extend(
            scaled
                .iter()
                .zip(factors)
                .map(|(&xp, &d)| d * xp),
        ),
    }
}
/// Writes the unscaled dual values into `out`, replacing its previous contents.
///
/// Element `i` is `scaled[i] * row_scale[i]`. Rows beyond the end of `row_scale`
/// use a factor of `1.0`, so trailing rows without a scale factor pass through
/// unchanged. An empty `row_scale` copies `scaled` verbatim. In debug builds
/// `row_scale` must not be longer than `scaled`.
pub(crate) fn fill_unscaled_dual(out: &mut Vec<f64>, scaled: &[f64], row_scale: &[f64]) {
    debug_assert!(
        row_scale.len() <= scaled.len(),
        "row_scale length {} exceeds dual length {}",
        row_scale.len(),
        scaled.len()
    );
    out.clear();
    if row_scale.is_empty() {
        out.extend_from_slice(scaled);
        return;
    }
    // Pad the scale factors with 1.0 so every dual value gets a multiplier.
    let factors = row_scale.iter().copied().chain(std::iter::repeat(1.0));
    out.extend(scaled.iter().zip(factors).map(|(&d, scale)| d * scale));
}
#[cfg(test)]
mod tests {
    use cobre_solver::{ActiveSolver, SolverInterface, StageTemplate};
    #[cfg(feature = "highs")]
    use cobre_solver::SolverError;
    use super::{StageInputs, run_stage_solve};
    use crate::{
        SddpError,
        basis_reconstruct::{HIGHS_BASIS_STATUS_BASIC as B, HIGHS_BASIS_STATUS_LOWER as L},
        context::StageContext,
        cut::pool::CutPool,
        lp_builder::PatchBuffer,
        workspace::{CapturedBasis, SolverWorkspace, WorkspaceSizing},
    };

    /// Builds a small feasible LP with 3 columns and 2 equality rows
    /// (`row_lower == row_upper`), stored column-wise via
    /// `col_starts` / `row_indices` / `values`.
    ///
    /// Row 0 reads `x0 = 6` and row 1 reads `2*x0 + x2 = 14`. The point
    /// `x0 = 6`, `x2 = 2` satisfies both rows and all column bounds.
    fn make_template() -> StageTemplate {
        StageTemplate {
            num_cols: 3,
            num_rows: 2,
            num_nz: 3,
            col_starts: vec![0_i32, 2, 2, 3],
            row_indices: vec![0_i32, 1, 1],
            values: vec![1.0, 2.0, 1.0],
            col_lower: vec![0.0, 0.0, 0.0],
            col_upper: vec![10.0, f64::INFINITY, 8.0],
            objective: vec![0.0, 1.0, 50.0],
            row_lower: vec![6.0, 14.0],
            row_upper: vec![6.0, 14.0],
            n_state: 1,
            n_transfer: 0,
            n_dual_relevant: 1,
            n_hydro: 1,
            max_par_order: 0,
            col_scale: Vec::new(),
            row_scale: Vec::new(),
        }
    }

    /// Builds a one-column, zero-row LP whose only column has
    /// `col_lower = 5.0 > col_upper = 2.0`, so no point satisfies its bounds.
    fn make_infeasible_template() -> StageTemplate {
        StageTemplate {
            num_cols: 1,
            num_rows: 0,
            num_nz: 0,
            col_starts: vec![0_i32, 0],
            row_indices: vec![],
            values: vec![],
            col_lower: vec![5.0],
            col_upper: vec![2.0],
            objective: vec![1.0],
            row_lower: vec![],
            row_upper: vec![],
            n_state: 0,
            n_transfer: 0,
            n_dual_relevant: 0,
            n_hydro: 0,
            max_par_order: 0,
            col_scale: Vec::new(),
            row_scale: Vec::new(),
        }
    }

    /// Creates a workspace around a fresh `ActiveSolver` that has `template` loaded.
    ///
    /// The zero-valued constructor arguments are fixture placeholders; their
    /// meaning is defined by `SolverWorkspace::new` / `PatchBuffer::new`.
    fn make_workspace(template: &StageTemplate) -> SolverWorkspace<ActiveSolver> {
        let mut solver = ActiveSolver::new().expect("ActiveSolver::new()");
        solver.load_model(template);
        SolverWorkspace::new(
            0,
            0,
            solver,
            PatchBuffer::new(0, 0, 0, 0, 0, 0),
            0,
            WorkspaceSizing::default(),
        )
    }

    /// Wraps `templates` in a `StageContext` whose other per-stage tables are empty.
    fn make_context(templates: &[StageTemplate]) -> StageContext<'_> {
        StageContext {
            templates,
            base_rows: &[],
            noise_scale: &[],
            n_hydros: 0,
            n_load_buses: 0,
            load_balance_row_starts: &[],
            load_bus_indices: &[],
            block_counts_per_stage: &[],
            ncs_max_gen: &[],
            ncs_allow_curtailment: &[],
            discount_factors: &[],
            cumulative_discount_factors: &[],
            stage_lag_transitions: &[],
            noise_group_ids: &[],
            downstream_par_order: 0,
        }
    }

    /// Creates a cut pool; these tests never add cuts to it.
    fn make_empty_pool() -> CutPool {
        CutPool::new(16, 1, 1, 0)
    }

    /// A solve without a stored basis returns a view with a finite objective.
    #[test]
    fn run_stage_solve_cold_start_returns_view() {
        let template = make_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = make_empty_pool();
        let mut ws = make_workspace(&template);
        let inputs = StageInputs {
            stage_context: &ctx,
            pool: &pool,
            stored_basis: None,
            stage_index: 0,
            scenario_index: 0,
            iteration: Some(1),
        };
        let result = run_stage_solve(&mut ws, &inputs);
        let view = result.expect("cold start should succeed");
        assert!(view.objective.is_finite());
    }

    /// A stored basis with two BASIC columns and two LOWER rows succeeds.
    /// Both rows are still LOWER in the scratch basis afterwards.
    #[test]
    fn run_stage_solve_warm_start_baked_path_succeeds() {
        let template = make_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = make_empty_pool();
        let mut ws = make_workspace(&template);
        ws.scratch.recon_slot_lookup = vec![None; 16];
        let mut captured = CapturedBasis::new(
            template.num_cols,
            template.num_rows,
            template.num_rows,
            0,
            1,
        );
        // Columns: BASIC, BASIC, LOWER -> two basic columns for two rows.
        captured.basis.col_status.clear();
        captured.basis.col_status.push(B);
        captured.basis.col_status.push(B);
        captured.basis.col_status.push(L);
        // Rows: both LOWER.
        captured.basis.row_status.clear();
        captured.basis.row_status.push(L);
        captured.basis.row_status.push(L);
        captured.state_at_capture.push(6.0);
        let inputs = StageInputs {
            stage_context: &ctx,
            pool: &pool,
            stored_basis: Some(&captured),
            stage_index: 0,
            scenario_index: 0,
            iteration: None,
        };
        let result = run_stage_solve(&mut ws, &inputs);
        assert!(
            result.is_ok(),
            "warm start on baked path should succeed: {result:?}"
        );
        let lower_count = ws
            .scratch_basis
            .row_status
            .iter()
            .filter(|&&s| s == L)
            .count();
        assert_eq!(
            lower_count, 2,
            "both rows should be LOWER after reconstruction"
        );
    }

    /// Infeasibility is reported as `SddpError::Infeasible`, carrying the
    /// caller's stage, scenario, and iteration.
    #[test]
    fn run_stage_solve_propagates_infeasible() {
        let template = make_infeasible_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = CutPool::new(16, 0, 1, 0);
        let mut ws = make_workspace(&template);
        let inputs = StageInputs {
            stage_context: &ctx,
            pool: &pool,
            stored_basis: None,
            stage_index: 0,
            scenario_index: 7,
            iteration: Some(42),
        };
        let result = run_stage_solve(&mut ws, &inputs);
        match result {
            Err(SddpError::Infeasible {
                stage,
                scenario,
                iteration,
            }) => {
                assert_eq!(stage, 0, "stage must match inputs.stage_index");
                assert_eq!(scenario, 7, "scenario must match inputs.scenario_index");
                assert_eq!(iteration, 42, "iteration must match inputs.iteration");
            }
            other => panic!("expected SddpError::Infeasible, got {other:?}"),
        }
    }

    /// Uses a freshly constructed `CapturedBasis` (named `all_lower`; assumed
    /// to hold default statuses — TODO confirm) as the stored basis. Checks
    /// that a `BasisInconsistent` solver error surfaces as `SddpError::Solver`
    /// and is not remapped to `SddpError::Infeasible`.
    #[cfg(feature = "highs")]
    #[test]
    fn basis_inconsistent_propagates_as_sddp_solver_error() {
        let template = make_template();
        let templates = std::slice::from_ref(&template);
        let ctx = make_context(templates);
        let pool = make_empty_pool();
        let mut ws = make_workspace(&template);
        ws.scratch.recon_slot_lookup = vec![None; 16];
        let all_lower = CapturedBasis::new(
            template.num_cols,
            template.num_rows,
            template.num_rows,
            0,
            1,
        );
        let inputs = StageInputs {
            stage_context: &ctx,
            pool: &pool,
            stored_basis: Some(&all_lower),
            stage_index: 0,
            scenario_index: 3,
            iteration: Some(5),
        };
        let result = run_stage_solve(&mut ws, &inputs);
        match result {
            Err(SddpError::Solver(SolverError::BasisInconsistent { .. })) => {}
            Err(SddpError::Infeasible { .. }) => {
                panic!("BasisInconsistent must not map to SddpError::Infeasible")
            }
            other => panic!(
                "expected Err(SddpError::Solver(SolverError::BasisInconsistent {{ .. }})), \
                 got {other:?}"
            ),
        }
    }
}