use cobre_solver::{SolverInterface, SolverStatistics};
use crate::{
cut::{CutRowMap, pool::CutPool},
forward::write_capture_metadata,
solver_stats::SolverStatsDelta,
workspace::{BasisStoreSliceMut, CapturedBasis, SolverWorkspace},
};
use super::SuccessorSpec;
/// Records the result of one opening and updates the per-slot activity counters.
///
/// The call first delegates to [`write_opening_outcome`], which stores the
/// solver-statistics delta, the coefficients, the objective and the intercept
/// for opening `omega`.
///
/// It then walks `succ.successor_active_slots` in order. The entry at position
/// `i` is paired with `ws.backward_accum.cut_duals_buf[i]`, and the counter
/// `ws.backward_accum.slot_increments[slot]` is bumped by one whenever that
/// dual is strictly greater than `succ.cut_activity_tolerance`. Positions that
/// have no matching entry in the dual buffer are skipped.
///
/// # Panics
///
/// Panics if an active slot index is out of bounds for `slot_increments`, or
/// if `write_opening_outcome` panics.
pub(crate) fn accumulate_opening_outcome<S: SolverInterface + Send>(
    ws: &mut SolverWorkspace<S>,
    succ: &SuccessorSpec<'_>,
    omega: usize,
    objective: f64,
    x_hat: &[f64],
    stats_before: &SolverStatistics,
    stats_after: &SolverStatistics,
) {
    write_opening_outcome(ws, omega, objective, x_hat, stats_before, stats_after);

    let tolerance = succ.cut_activity_tolerance;
    let accum = &mut ws.backward_accum;
    for (position, &slot) in succ.successor_active_slots.iter().enumerate() {
        // A position past the end of the dual buffer contributes nothing.
        let Some(&dual) = accum.cut_duals_buf.get(position) else {
            continue;
        };
        if dual > tolerance {
            accum.slot_increments[slot] += 1;
        }
    }
}
/// Increments the counter of every populated slot whose LP-row dual exceeds
/// the activity tolerance.
///
/// Only the first `pool.populated_count` entries of `slot_increments` are
/// visited. For each visited slot the LP row is looked up through
/// `row_map.lp_row_for_slot`, and the counter is bumped by one when
/// `dual[lp_row]` is strictly greater than `cut_activity_tolerance`.
///
/// A slot is left untouched when the row map has no LP row for it or when the
/// LP row index lies outside `dual`.
pub(crate) fn accumulate_dcs_binding_counts(
    dual: &[f64],
    row_map: &CutRowMap,
    pool: &CutPool,
    cut_activity_tolerance: f64,
    slot_increments: &mut [u64],
) {
    let populated = slot_increments
        .iter_mut()
        .take(pool.populated_count)
        .enumerate();

    for (slot, counter) in populated {
        // Resolve slot -> LP row -> dual value; any missing link means "not active".
        let exceeds_tolerance = row_map
            .lp_row_for_slot(slot)
            .and_then(|lp_row| dual.get(lp_row))
            .is_some_and(|&value| value > cut_activity_tolerance);

        if exceeds_tolerance {
            *counter += 1;
        }
    }
}
/// Stores the result of opening `omega` in the workspace's backward accumulator.
///
/// Three things are written:
///
/// * the solver-statistics delta between `stats_before` and `stats_after` is
///   accumulated into `per_opening_stats[omega]`;
/// * `state_duals_buf` is copied into `outcomes[omega].coefficients`;
/// * `objective_value` is set to `objective`, and `intercept` is set to
///   `objective - Σ coefficients[i] * x_hat[i]`.
///
/// The sum pairs coefficients with `x_hat` element by element and stops at the
/// shorter of the two sequences.
///
/// # Panics
///
/// Panics if `omega` is out of bounds for `per_opening_stats` or `outcomes`,
/// or if the coefficient buffer and `state_duals_buf` differ in length
/// (a `copy_from_slice` requirement).
pub(crate) fn write_opening_outcome<S: SolverInterface + Send>(
    ws: &mut SolverWorkspace<S>,
    omega: usize,
    objective: f64,
    x_hat: &[f64],
    stats_before: &SolverStatistics,
    stats_after: &SolverStatistics,
) {
    let accum = &mut ws.backward_accum;

    let delta = SolverStatsDelta::from_snapshots(stats_before, stats_after);
    SolverStatsDelta::accumulate_into(&mut accum.per_opening_stats[omega], &delta);

    let outcome = &mut accum.outcomes[omega];
    outcome
        .coefficients
        .copy_from_slice(&accum.state_duals_buf);
    outcome.objective_value = objective;

    // Inner product of the freshly copied coefficients with the trial point.
    let state_term = outcome
        .coefficients
        .iter()
        .zip(x_hat)
        .map(|(pi, x)| pi * x)
        .sum::<f64>();
    outcome.intercept = objective - state_term;
}
/// Captures the solver's current basis into the store slot for
/// `(m, succ.successor)`, together with its capture metadata.
///
/// If the slot already holds a [`CapturedBasis`] it is reused in place.
/// Otherwise a new one is created, sized from `succ.baked_template.num_cols`,
/// `succ.template_num_rows + succ.num_cuts_at_successor` rows and
/// `x_hat.len()`, and stored in the slot. In both cases the basis is then read
/// from `ws.solver`, and `write_capture_metadata` records the successor pool,
/// the base and cut row counts and `x_hat`.
///
/// NOTE(review): the name suggests this is meant to run only for the first
/// opening (omega == 0); nothing in this function checks that — confirm at the
/// call site.
pub(crate) fn save_basis_at_omega_zero<S: SolverInterface + Send>(
    ws: &mut SolverWorkspace<S>,
    succ: &SuccessorSpec<'_>,
    basis_slice: &mut BasisStoreSliceMut<'_>,
    m: usize,
    x_hat: &[f64],
) {
    let num_cols = succ.baked_template.num_cols;
    let base_row_count = succ.template_num_rows;
    let cut_row_count = succ.num_cuts_at_successor;
    let basis_row_capacity = base_row_count + cut_row_count;

    // Look the slot up once and allocate only when it is still empty, so the
    // capture sequence below is shared by the "reuse" and "first capture" paths.
    let slot = basis_slice.get_mut(m, succ.successor);
    let captured = slot.get_or_insert_with(|| {
        CapturedBasis::new(
            num_cols,
            basis_row_capacity,
            base_row_count,
            cut_row_count,
            x_hat.len(),
        )
    });

    ws.solver.get_basis(&mut captured.basis);
    write_capture_metadata(
        captured,
        succ.successor_pool,
        base_row_count,
        cut_row_count,
        x_hat,
    );
}