use cobre_solver::SolverInterface;
use crate::{
SddpError,
context::{StageContext, TrainingContext},
dcs::{DcsParams, DcsSolveContext, build_initial_resident_set, lazy_solve_preloaded},
risk_measure::RiskMeasure,
state_exchange::ExchangeBuffers,
workspace::{BasisStoreSliceMut, SolverWorkspace},
};
use super::{
StagedCut, SuccessorSpec,
duals_extraction::{extract_duals_from_view, extract_state_duals_only},
lp_setup::{load_backward_lp, patch_opening_bounds, resolve_backward_basis},
outcome_aggregation::{
accumulate_dcs_binding_counts, accumulate_opening_outcome, save_basis_at_omega_zero,
write_opening_outcome,
},
};
/// Strategy used to solve the backward-pass openings of a successor stage.
pub(crate) enum StageOpeningSolver {
    /// Solve each opening on the backward LP installed by `load_backward_lp`,
    /// going through `run_stage_solve`.
    Baked,
    /// Solve each opening with `lazy_solve_preloaded`, configured by the
    /// wrapped DCS parameters.
    Lazy(DcsParams),
}
impl StageOpeningSolver {
pub(crate) fn from_dcs_params(dcs_params: Option<DcsParams>) -> Self {
match dcs_params {
Some(params) => StageOpeningSolver::Lazy(params),
None => StageOpeningSolver::Baked,
}
}
pub(crate) fn prepare<S: SolverInterface + Send>(
&self,
ws: &mut SolverWorkspace<S>,
ctx: &StageContext<'_>,
succ: &SuccessorSpec<'_>,
iteration: u64,
) {
match self {
StageOpeningSolver::Baked => {
load_backward_lp(ws, succ);
}
StageOpeningSolver::Lazy(params) => {
ws.solver.load_model(&ctx.templates[succ.successor]);
build_initial_resident_set(
succ.successor_pool,
iteration,
params.k2,
&mut ws.backward_accum.dcs_initial_resident,
);
}
}
}
#[allow(clippy::too_many_arguments)]
fn solve_opening<S: SolverInterface + Send>(
&self,
ws: &mut SolverWorkspace<S>,
ctx: &StageContext<'_>,
training_ctx: &TrainingContext<'_>,
succ: &SuccessorSpec<'_>,
basis_slice: &mut BasisStoreSliceMut<'_>,
raw_noise: &[f64],
x_hat: &[f64],
s: usize,
scenario: usize,
iteration: u64,
m: usize,
omega: usize,
is_first: bool,
) -> Result<(), SddpError> {
match self {
StageOpeningSolver::Baked => Self::solve_baked(
ws,
ctx,
training_ctx,
succ,
basis_slice,
raw_noise,
x_hat,
s,
scenario,
iteration,
m,
omega,
is_first,
),
StageOpeningSolver::Lazy(params) => Self::solve_lazy(
ws,
ctx,
training_ctx,
succ,
*params,
raw_noise,
x_hat,
s,
scenario,
iteration,
omega,
!is_first,
),
}
}
#[allow(clippy::too_many_arguments)]
fn solve_baked<S: SolverInterface + Send>(
ws: &mut SolverWorkspace<S>,
ctx: &StageContext<'_>,
training_ctx: &TrainingContext<'_>,
succ: &SuccessorSpec<'_>,
basis_slice: &mut BasisStoreSliceMut<'_>,
raw_noise: &[f64],
x_hat: &[f64],
s: usize,
scenario: usize,
iteration: u64,
m: usize,
omega: usize,
is_first: bool,
) -> Result<(), SddpError> {
let indexer = training_ctx.indexer;
patch_opening_bounds(ws, ctx, training_ctx, raw_noise, x_hat, s);
let mut state_duals = std::mem::take(&mut ws.backward_accum.state_duals_buf);
let mut cut_duals = std::mem::take(&mut ws.backward_accum.cut_duals_buf);
let mut stats_before_omega = std::mem::take(&mut ws.backward_accum.stats_before_buf);
ws.solver.statistics_into(&mut stats_before_omega);
let stored_basis = if is_first {
resolve_backward_basis(basis_slice, m, s)
} else {
None
};
let inputs = crate::stage_solve::StageInputs {
stage_context: ctx,
pool: succ.successor_pool,
stored_basis,
stage_index: s,
scenario_index: scenario,
iteration: Some(iteration),
};
let view = crate::stage_solve::run_stage_solve(ws, &inputs)?;
let objective = extract_duals_from_view(
&view,
indexer.n_state,
indexer,
&ctx.templates[s].col_scale,
succ,
&mut state_duals,
&mut cut_duals,
);
let _ = view;
ws.backward_accum.state_duals_buf = state_duals;
ws.backward_accum.cut_duals_buf = cut_duals;
let mut stats_after_omega = std::mem::take(&mut ws.backward_accum.stats_after_buf);
ws.solver.statistics_into(&mut stats_after_omega);
accumulate_opening_outcome(
ws,
succ,
omega,
objective,
x_hat,
&stats_before_omega,
&stats_after_omega,
);
ws.backward_accum.stats_before_buf = stats_before_omega;
ws.backward_accum.stats_after_buf = stats_after_omega;
if is_first {
save_basis_at_omega_zero(ws, succ, basis_slice, m, x_hat);
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn solve_lazy<S: SolverInterface + Send>(
ws: &mut SolverWorkspace<S>,
ctx: &StageContext<'_>,
training_ctx: &TrainingContext<'_>,
succ: &SuccessorSpec<'_>,
params: DcsParams,
raw_noise: &[f64],
x_hat: &[f64],
s: usize,
scenario: usize,
iteration: u64,
omega: usize,
continue_carry: bool,
) -> Result<(), SddpError> {
let indexer = training_ctx.indexer;
let core = &ctx.templates[s];
let col_scale = &ctx.templates[s].col_scale;
patch_opening_bounds(ws, ctx, training_ctx, raw_noise, x_hat, s);
let mut stats_before_omega = std::mem::take(&mut ws.backward_accum.stats_before_buf);
ws.solver.statistics_into(&mut stats_before_omega);
let mut state_duals = std::mem::take(&mut ws.backward_accum.state_duals_buf);
let dcs_ctx = DcsSolveContext {
stage_index: s,
scenario_index: scenario,
iteration: Some(iteration),
continue_carry,
};
lazy_solve_preloaded(
&mut ws.solver,
core,
succ.successor_pool,
indexer,
col_scale,
None,
&ws.backward_accum.dcs_initial_resident,
¶ms,
&mut ws.backward_accum.dcs_solve,
dcs_ctx,
)?;
let view = ws.backward_accum.dcs_solve.result_view();
let objective =
extract_state_duals_only(&view, indexer.n_state, indexer, col_scale, &mut state_duals);
accumulate_dcs_binding_counts(
view.dual,
&ws.backward_accum.dcs_solve.row_map,
succ.successor_pool,
succ.cut_activity_tolerance,
&mut ws.backward_accum.slot_increments,
);
let _ = view;
ws.backward_accum.state_duals_buf = state_duals;
let mut stats_after_omega = std::mem::take(&mut ws.backward_accum.stats_after_buf);
ws.solver.statistics_into(&mut stats_after_omega);
write_opening_outcome(
ws,
omega,
objective,
x_hat,
&stats_before_omega,
&stats_after_omega,
);
ws.backward_accum.stats_before_buf = stats_before_omega;
ws.backward_accum.stats_after_buf = stats_after_omega;
Ok(())
}
}
/// Solves every backward opening of successor stage `succ.successor` for
/// trial point `m` and aggregates the results into one [`StagedCut`].
///
/// Order of operations:
/// 1. Visit the openings in the order returned by the scenario tree's
///    `solve_order_data` for that stage. The opening listed first is passed
///    as `is_first` to the opening solver.
/// 2. Combine the per-opening outcomes with `risk_measures[succ.t]`.
/// 3. Copy the aggregated coefficients into
///    `agg_arena[arena_offset..arena_offset + n_state]`.
/// 4. Fold positive per-slot binding increments into
///    `metadata_sync_contribution`.
///
/// When `training_ctx.noise_key_diag` is set, one tab-separated
/// `COBRE_W1_DIAG` line per opening is written to stderr.
///
/// # Errors
/// Propagates the first `SddpError` returned by an opening solve.
///
/// # Panics
/// Panics via indexing in two cases:
/// * the stage's solve order is empty or shorter than `succ.probabilities`;
/// * `succ.t` is out of range for `risk_measures`.
// Every argument is distinct per-trial-point state threaded from the backward-pass driver.
#[allow(clippy::too_many_arguments)]
pub(crate) fn process_trial_point_backward<S: SolverInterface + Send>(
    ws: &mut SolverWorkspace<S>,
    ctx: &StageContext<'_>,
    training_ctx: &TrainingContext<'_>,
    exchange: &ExchangeBuffers,
    fwd_offset: usize,
    iteration: u64,
    risk_measures: &[RiskMeasure],
    succ: &SuccessorSpec<'_>,
    basis_slice: &mut BasisStoreSliceMut<'_>,
    opening_solver: &StageOpeningSolver,
    m: usize,
    arena_offset: usize,
) -> Result<StagedCut, SddpError> {
    let tree_view = training_ctx.stochastic.tree_view();
    let x_hat = exchange.state_at(succ.my_rank, m);
    let scenario = fwd_offset + m;
    let s = succ.successor;
    let n_openings = succ.probabilities.len();
    debug_assert_eq!(
        ws.backward_accum.per_opening_stats.len(),
        n_openings,
        "per_opening_stats must be initialised to n_openings before each stage's trial-point loop"
    );

    let noise_key_diag = training_ctx.noise_key_diag;
    let solve_order = tree_view.solve_order_data(s);
    debug_assert_eq!(
        solve_order.len(),
        n_openings,
        "solve_order(s) must be a permutation of 0..n_openings"
    );
    let first = solve_order[0] as usize;

    // Map each solve position to its opening index lazily, so positions are
    // indexed one at a time exactly as they are reached.
    for omega in (0..n_openings).map(|position| solve_order[position] as usize) {
        let is_first = omega == first;
        opening_solver.solve_opening(
            ws,
            ctx,
            training_ctx,
            succ,
            basis_slice,
            tree_view.opening(s, omega),
            x_hat,
            s,
            scenario,
            iteration,
            m,
            omega,
            is_first,
        )?;

        let Some(diag) = noise_key_diag else {
            continue;
        };
        let simplex_iterations = ws.backward_accum.per_opening_stats[omega].simplex_iterations;
        let noise_key = diag.key(s, omega).unwrap_or(f64::NAN);
        eprintln!(
            "COBRE_W1_DIAG\tstage={s}\ttrial={scenario}\tomega={omega}\t\
             noise_key={noise_key:.17e}\tsimplex_iterations={simplex_iterations}"
        );
    }

    // From here on only the backward accumulator is touched; bind it once.
    let accum = &mut ws.backward_accum;

    let mut agg_intercept = 0.0_f64;
    risk_measures[succ.t].aggregate_cut_into(
        &accum.outcomes[..n_openings],
        succ.probabilities,
        &mut agg_intercept,
        &mut accum.agg_coefficients,
        &mut accum.risk_scratch,
    );

    let n_state = accum.agg_coefficients.len();
    let coefficients_range = arena_offset..arena_offset + n_state;
    debug_assert!(
        coefficients_range.end <= accum.agg_arena.len(),
        "agg_arena must be sized to cover this trial point's slot before the solve"
    );
    accum.agg_arena[coefficients_range.clone()]
        .copy_from_slice(&accum.agg_coefficients[..n_state]);

    debug_assert!(
        u32::try_from(scenario).is_ok(),
        "global scenario index overflows u32"
    );
    // Range is checked by the debug assertion above.
    #[allow(clippy::cast_possible_truncation)]
    let forward_pass_index = scenario as u32;

    for (slot, &count) in accum.slot_increments.iter().enumerate() {
        if count > 0 {
            accum.metadata_sync_contribution[slot] += count;
        }
    }

    Ok(StagedCut {
        trial_point_idx: m,
        intercept: agg_intercept,
        coefficients_range,
        forward_pass_index,
    })
}