use std::time::Instant;
use cobre_solver::{
Basis, ProfiledSolver, RowBatch, SolutionView, SolverError, SolverInterface, StageTemplate,
};
use crate::basis_reconstruct::{
ReconstructionTarget, enforce_basic_count_invariant, reconstruct_basis_uniform_basic,
};
use crate::cut::row::append_slots_to_lp;
use crate::cut::{CutPool, CutRowMap};
use crate::cut_selection::CutSelectionStrategy;
use crate::error::SddpError;
use crate::gemm::gemm_block;
use crate::indexer::StageIndexer;
use crate::workspace::CapturedBasis;
/// Tuning knobs for dynamic cut selection (DCS).
///
/// Plain `Copy` data, so callers can override a single field on a local copy
/// without touching the original.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct DcsParams {
    /// Optional age window, in iterations, for candidate cuts. When set,
    /// scoring skips every cut whose age (current iteration minus generation
    /// iteration, saturating at zero) is `>= k1`. `None` disables the filter.
    pub k1: Option<u32>,
    /// Residency window, in iterations.
    /// NOTE(review): presumably the `k2` handed to
    /// `build_initial_resident_set` — confirm at the call site.
    pub k2: u32,
    /// Maximum number of violated cuts selected by a single scoring pass.
    pub nadic: u32,
    /// Violation tolerance: a cut counts as violated only when its violation
    /// is strictly greater than this value.
    pub epsilon_viol: f64,
    /// First iteration at which DCS is considered active.
    pub start_iteration: u64,
    /// Upper bound on the number of score / append / re-solve rounds performed
    /// by `lazy_solve_preloaded` before it falls back to adding every
    /// remaining violated cut at once.
    pub max_inner_iterations: u32,
}
impl Default for DcsParams {
    /// Returns the baseline configuration: no age window (`k1 = None`),
    /// `k2 = 5`, at most 10 cuts added per scoring pass, a `1e-10` violation
    /// tolerance, activation from iteration 2, and at most 50 inner rounds.
    fn default() -> Self {
        let k1 = None;
        let k2 = 5;
        let nadic = 10;
        let epsilon_viol = 1e-10;
        let start_iteration = 2;
        let max_inner_iterations = 50;
        Self {
            k1,
            k2,
            nadic,
            epsilon_viol,
            start_iteration,
            max_inner_iterations,
        }
    }
}
impl DcsParams {
    /// Extracts DCS parameters from a cut-selection strategy.
    ///
    /// Returns `Some` only for [`CutSelectionStrategy::Dynamic`], copying its
    /// fields verbatim; `max_inner_iterations` is not carried by the strategy
    /// and is taken from [`DcsParams::default`]. Every other strategy yields
    /// `None`.
    #[must_use]
    pub fn from_strategy(strategy: &CutSelectionStrategy) -> Option<Self> {
        // The non-dynamic variants are listed explicitly (no `_` arm) so that
        // adding a new strategy variant forces a decision here.
        match *strategy {
            CutSelectionStrategy::Dynamic {
                k1,
                k2,
                nadic,
                epsilon_viol,
                start_iteration,
            } => Some(Self {
                k1,
                k2,
                nadic,
                epsilon_viol,
                start_iteration,
                ..Self::default()
            }),
            CutSelectionStrategy::Level1 { .. }
            | CutSelectionStrategy::Lml1 { .. }
            | CutSelectionStrategy::Dominated { .. } => None,
        }
    }

    /// Returns `true` once `iteration` has reached `start_iteration`
    /// (inclusive).
    #[must_use]
    pub fn is_active(&self, iteration: u64) -> bool {
        self.start_iteration <= iteration
    }
}
/// Reusable buffers for [`score_violated_candidates`].
///
/// Every buffer is cleared and refilled on each scoring pass; keeping them in
/// one struct lets the caller reuse the allocations between passes.
#[derive(Default)]
pub struct DcsScoringScratch {
    /// State values read from the primal solution, multiplied by `col_scale`
    /// when a non-empty scale vector is supplied.
    pub unscaled_state: Vec<f64>,
    /// Coefficients of the gathered candidate cuts, one cut after another
    /// (`n_state` entries per candidate).
    pub cand_coef_block: Vec<f64>,
    /// One `coefficients · state` product per gathered candidate.
    pub alpha: Vec<f64>,
    /// Pool slot of each gathered candidate, in gather order.
    pub cand_slots: Vec<u32>,
    /// `(violation, slot)` pairs of the candidates found violated.
    pub violations: Vec<(f64, u32)>,
}

impl DcsScoringScratch {
    /// Ensures every buffer can hold a scoring pass over `pool_capacity`
    /// candidates of dimension `n_state` without reallocating.
    ///
    /// `Vec::reserve(additional)` guarantees `capacity >= len + additional`,
    /// so the amount requested is measured from the buffer's current
    /// *length*. Measuring it from the current capacity would under-reserve
    /// any buffer that already owns some spare capacity. The call is a no-op
    /// when the capacity is already sufficient, and existing contents are
    /// left untouched.
    pub fn reserve(&mut self, n_state: usize, pool_capacity: usize) {
        let coef_capacity = pool_capacity * n_state;

        self.unscaled_state
            .reserve(n_state.saturating_sub(self.unscaled_state.len()));
        self.cand_coef_block
            .reserve(coef_capacity.saturating_sub(self.cand_coef_block.len()));
        self.alpha
            .reserve(pool_capacity.saturating_sub(self.alpha.len()));
        self.cand_slots
            .reserve(pool_capacity.saturating_sub(self.cand_slots.len()));
        self.violations
            .reserve(pool_capacity.saturating_sub(self.violations.len()));
    }
}
/// Scores every eligible non-resident cut against the current solution and
/// selects the most violated ones.
///
/// A cut from `pool.active_cuts()` is a candidate when it passes the optional
/// `params.k1` age filter (age `< k1`) and has no LP row in `resident`. For
/// each candidate the cut value `intercept + coefficients · state` is compared
/// with the theta value read from `primal`; the difference is the violation.
/// When `col_scale` is non-empty, both state and theta are multiplied by their
/// column scale before use.
///
/// On return `out_selected` holds at most `params.nadic` slots, ordered by
/// decreasing violation with ties broken by ascending slot id, and `scratch`
/// holds the intermediate buffers of this pass.
///
/// Returns the total number of candidates whose violation is strictly greater
/// than `params.epsilon_viol`, regardless of the `nadic` cap.
pub fn score_violated_candidates(
    pool: &CutPool,
    indexer: &StageIndexer,
    primal: &[f64],
    col_scale: &[f64],
    resident: &CutRowMap,
    params: &DcsParams,
    current_iteration: u64,
    scratch: &mut DcsScoringScratch,
    out_selected: &mut Vec<u32>,
) -> usize {
    let n_state = indexer.n_state;
    let theta = indexer.theta;

    debug_assert!(
        primal.len() > theta,
        "score_violated_candidates: primal.len() {} <= theta ({theta})",
        primal.len(),
    );
    debug_assert_eq!(
        pool.state_dimension, n_state,
        "score_violated_candidates: pool.state_dimension {} != indexer.n_state {}",
        pool.state_dimension, n_state,
    );
    debug_assert!(
        col_scale.is_empty() || col_scale.len() == primal.len(),
        "score_violated_candidates: col_scale.len() {} != primal.len() {} (non-empty col_scale \
         must be per-column)",
        col_scale.len(),
        primal.len(),
    );

    // Start from a clean slate; only the allocations are reused.
    out_selected.clear();
    scratch.unscaled_state.clear();
    scratch.cand_coef_block.clear();
    scratch.cand_slots.clear();
    scratch.alpha.clear();
    scratch.violations.clear();

    // Reads one LP column, multiplied by its column scale when scaling is on.
    let unscaled = |col: usize| -> f64 {
        if col_scale.is_empty() {
            primal[col]
        } else {
            col_scale[col] * primal[col]
        }
    };

    let theta_raw = unscaled(theta);
    scratch
        .unscaled_state
        .extend((0..n_state).map(|j| unscaled(indexer.state_to_lp_column(j))));

    // Gather the coefficient rows of every eligible candidate into one block.
    for (slot, _, coefficients) in pool.active_cuts() {
        let outside_age_window = match params.k1 {
            Some(k1) => {
                let age =
                    current_iteration.saturating_sub(pool.metadata[slot].iteration_generated);
                age >= u64::from(k1)
            }
            None => false,
        };
        if outside_age_window || resident.lp_row_for_slot(slot).is_some() {
            continue;
        }
        scratch.cand_coef_block.extend_from_slice(coefficients);
        #[allow(clippy::cast_possible_truncation)]
        scratch.cand_slots.push(slot as u32);
    }

    // One batched product yields `coefficients · state` for all candidates.
    let k_rows = scratch.cand_slots.len();
    scratch.alpha.resize(k_rows, 0.0);
    gemm_block(
        &scratch.cand_coef_block,
        &scratch.unscaled_state,
        k_rows,
        n_state,
        1,
        &mut scratch.alpha,
    );

    // Keep only the candidates whose violation strictly exceeds the tolerance.
    for (&slot, &activity) in scratch.cand_slots.iter().zip(scratch.alpha.iter()) {
        let cut_value = pool.intercepts[slot as usize] + activity;
        let violation = cut_value - theta_raw;
        if violation > params.epsilon_viol {
            scratch.violations.push((violation, slot));
        }
    }

    let violated_count = scratch.violations.len();

    // Largest violation first; equal violations fall back to ascending slot.
    scratch
        .violations
        .sort_unstable_by(|lhs, rhs| rhs.0.total_cmp(&lhs.0).then_with(|| lhs.1.cmp(&rhs.1)));

    out_selected.extend(
        scratch
            .violations
            .iter()
            .take(params.nadic as usize)
            .map(|&(_, slot)| slot),
    );
    violated_count
}
/// Identifies the solve being performed and how it should start.
#[derive(Clone, Copy, Debug)]
pub struct DcsSolveContext {
    /// Stage index, reported in infeasibility errors.
    pub stage_index: usize,
    /// Scenario index, reported in infeasibility errors.
    pub scenario_index: usize,
    /// Current iteration, if any; consumers fall back to `0` when `None`.
    pub iteration: Option<u64>,
    /// When `true`, `lazy_solve_preloaded` keeps the cut rows and row map
    /// carried over from the previous call instead of rebuilding them.
    pub continue_carry: bool,
}
/// Reusable workspace and result storage for `lazy_solve_preloaded`.
pub struct DcsSolveScratch {
    /// Row batch handed to `append_slots_to_lp` when cut rows are appended.
    pub batch: RowBatch,
    /// Buffers used by the violation-scoring passes.
    pub scoring: DcsScoringScratch,
    /// Slots selected by the most recent scoring pass.
    pub out_selected: Vec<u32>,
    /// Basis rebuilt from a stored basis before a warm-started solve.
    pub recon_basis: Basis,
    /// Tracks which pool slots currently have a row in the LP.
    pub row_map: CutRowMap,
    /// Primal values copied from the last stored solution.
    pub res_primal: Vec<f64>,
    /// Dual values copied from the last stored solution.
    pub res_dual: Vec<f64>,
    /// Reduced costs copied from the last stored solution.
    pub res_reduced_costs: Vec<f64>,
    /// Objective value of the last stored solution.
    pub res_objective: f64,
    /// Iteration count reported with the last stored solution.
    pub res_iterations: u64,
    /// Solve time reported with the last stored solution, in seconds.
    pub res_solve_time_seconds: f64,
    /// Accumulated wall-clock time spent in scoring passes, in seconds.
    pub scoring_time_seconds: f64,
    /// Sum of the cut-row counts observed each time a result was stored.
    pub rows_in_lp_sum: u64,
    /// Number of results stored (denominator for `rows_in_lp_sum`).
    pub rows_in_lp_count: u64,
    /// Largest cut-row count observed when a result was stored.
    pub rows_in_lp_max: u64,
}
impl Default for DcsSolveScratch {
    /// Builds an empty workspace: no rows batched, empty buffers, a `0 x 0`
    /// basis and row map, and all counters and timers at zero.
    fn default() -> Self {
        let empty_batch = RowBatch {
            num_rows: 0,
            row_starts: Vec::new(),
            col_indices: Vec::new(),
            values: Vec::new(),
            row_lower: Vec::new(),
            row_upper: Vec::new(),
        };

        Self {
            // Working buffers.
            batch: empty_batch,
            scoring: DcsScoringScratch::default(),
            out_selected: Vec::new(),
            recon_basis: Basis::new(0, 0),
            row_map: CutRowMap::new(0, 0),
            // Last stored solution.
            res_primal: Vec::new(),
            res_dual: Vec::new(),
            res_reduced_costs: Vec::new(),
            res_objective: 0.0,
            res_iterations: 0,
            res_solve_time_seconds: 0.0,
            // Instrumentation.
            scoring_time_seconds: 0.0,
            rows_in_lp_sum: 0,
            rows_in_lp_count: 0,
            rows_in_lp_max: 0,
        }
    }
}
impl DcsSolveScratch {
    /// Pre-sizes the workspace for a pool of `pool_capacity` cuts over
    /// `n_state` state variables and resets the row map to
    /// `(pool_capacity, 0)`.
    ///
    /// `Vec::reserve(additional)` guarantees `capacity >= len + additional`,
    /// so each request is measured from the buffer's current *length*.
    /// Measuring it from the current capacity would under-reserve buffers that
    /// already own spare capacity. The call is a no-op for buffers that are
    /// already large enough.
    pub fn reserve(&mut self, n_state: usize, pool_capacity: usize) {
        self.scoring.reserve(n_state, pool_capacity);
        self.out_selected
            .reserve(pool_capacity.saturating_sub(self.out_selected.len()));
        self.row_map.reset(pool_capacity, 0);
        for buf in [
            &mut self.res_primal,
            &mut self.res_dual,
            &mut self.res_reduced_costs,
        ] {
            buf.reserve(pool_capacity.saturating_sub(buf.len()));
        }
    }

    /// Returns a borrowed view over the most recently stored solution.
    #[must_use]
    pub fn result_view(&self) -> SolutionView<'_> {
        SolutionView {
            objective: self.res_objective,
            primal: &self.res_primal,
            dual: &self.res_dual,
            reduced_costs: &self.res_reduced_costs,
            iterations: self.res_iterations,
            solve_time_seconds: self.res_solve_time_seconds,
        }
    }

    /// Copies `view` into the owned result buffers and updates the
    /// rows-in-LP statistics from the row map's current cut-row count.
    fn store_result(&mut self, view: &SolutionView<'_>) {
        self.res_objective = view.objective;
        self.res_iterations = view.iterations;
        self.res_solve_time_seconds = view.solve_time_seconds;

        // Clear-then-extend reuses the existing allocations.
        self.res_primal.clear();
        self.res_primal.extend_from_slice(view.primal);
        self.res_dual.clear();
        self.res_dual.extend_from_slice(view.dual);
        self.res_reduced_costs.clear();
        self.res_reduced_costs.extend_from_slice(view.reduced_costs);

        let rows_in_lp = self.row_map.total_cut_rows() as u64;
        self.rows_in_lp_sum += rows_in_lp;
        self.rows_in_lp_count += 1;
        self.rows_in_lp_max = self.rows_in_lp_max.max(rows_in_lp);
    }
}
/// Converts a solver failure into an [`SddpError`].
///
/// Infeasibility is reported with the stage, scenario and iteration taken
/// from `ctx` (iteration defaults to `0` when absent); every other error is
/// wrapped unchanged in `SddpError::Solver`.
fn map_solver_error(e: SolverError, ctx: DcsSolveContext) -> SddpError {
    if matches!(e, SolverError::Infeasible) {
        return SddpError::Infeasible {
            stage: ctx.stage_index,
            iteration: ctx.iteration.unwrap_or(0),
            scenario: ctx.scenario_index,
        };
    }
    SddpError::Solver(e)
}
/// Solves the stage LP with only a subset of the pool's cuts resident, adding
/// violated cuts on demand.
///
/// Start-up depends on `ctx.continue_carry`:
/// - `true`: the LP rows and `scratch.row_map` left by a previous call are
///   kept as they are and the LP is solved without a basis; `stored_basis`
///   and `initial_resident` are not used on this path.
/// - `false`: the row map is reset, the `initial_resident` slots are appended
///   to the LP, and the LP is solved, warm-started from a basis rebuilt out of
///   `stored_basis` when one is supplied.
///
/// The function then runs up to `params.max_inner_iterations` rounds of
/// "score non-resident cuts, append the selected ones, re-solve". As soon as a
/// scoring pass finds no violated cut, the current solution is stored in
/// `scratch` and the function returns. If the round limit is exhausted, one
/// last scoring pass runs with the `nadic` cap lifted, every violated cut it
/// finds is appended, and a final solve is stored without being re-scored.
///
/// # Errors
///
/// Any solver failure is converted by `map_solver_error`:
/// `SolverError::Infeasible` becomes `SddpError::Infeasible` carrying the
/// stage, scenario and iteration from `ctx`; other errors are wrapped in
/// `SddpError::Solver`.
#[allow(clippy::too_many_arguments)]
pub fn lazy_solve_preloaded<S: SolverInterface>(
solver: &mut ProfiledSolver<S>,
core: &StageTemplate,
pool: &CutPool,
indexer: &StageIndexer,
col_scale: &[f64],
stored_basis: Option<&CapturedBasis>,
initial_resident: &[u32],
params: &DcsParams,
scratch: &mut DcsSolveScratch,
ctx: DcsSolveContext,
) -> Result<(), SddpError> {
// A missing iteration is treated as iteration 0 for the k1 age filter.
let current_iteration = ctx.iteration.unwrap_or(0);
let mut view = if ctx.continue_carry {
// Carried LP: rows and row map are already in place, just re-solve.
solver.solve(None).map_err(|e| map_solver_error(e, ctx))?
} else {
// Fresh start: rebuild the row map and load the initial resident cuts.
scratch.row_map.reset(pool.populated_count, core.num_rows);
append_slots_to_lp(
solver,
pool,
initial_resident,
indexer,
col_scale,
&mut scratch.row_map,
&mut scratch.batch,
);
let cut_rows = scratch.row_map.total_cut_rows();
if let Some(stored) = stored_basis {
// Rebuild a basis sized for the template rows plus the appended cut rows,
// then warm-start the solve from it.
let target = ReconstructionTarget {
base_row_count: core.num_rows,
num_cols: core.num_cols,
};
reconstruct_basis_uniform_basic(stored, target, cut_rows, &mut scratch.recon_basis);
enforce_basic_count_invariant(
&mut scratch.recon_basis,
core.num_rows + cut_rows,
core.num_rows,
);
solver
.solve(Some(&scratch.recon_basis))
.map_err(|e| map_solver_error(e, ctx))?
} else {
solver.solve(None).map_err(|e| map_solver_error(e, ctx))?
}
};
for _ in 0..params.max_inner_iterations {
// Only the scoring pass is timed; solver time is reported separately.
let t0 = Instant::now();
let violated = score_violated_candidates(
pool,
indexer,
view.primal,
col_scale,
&scratch.row_map,
params,
current_iteration,
&mut scratch.scoring,
&mut scratch.out_selected,
);
scratch.scoring_time_seconds += t0.elapsed().as_secs_f64();
if violated == 0 {
// No non-resident candidate is violated: keep this solution, no extra solve.
scratch.store_result(&view);
return Ok(());
}
// Last use of `view` in this round; the solver is modified next.
let _ = view;
append_slots_to_lp(
solver,
pool,
&scratch.out_selected,
indexer,
col_scale,
&mut scratch.row_map,
&mut scratch.batch,
);
view = solver.solve(None).map_err(|e| map_solver_error(e, ctx))?;
}
// Round limit reached: lift the per-pass cap so that one pass selects every
// remaining violated cut.
let mut all_params = *params;
all_params.nadic = u32::MAX;
let t0 = Instant::now();
let remaining = score_violated_candidates(
pool,
indexer,
view.primal,
col_scale,
&scratch.row_map,
&all_params,
current_iteration,
&mut scratch.scoring,
&mut scratch.out_selected,
);
scratch.scoring_time_seconds += t0.elapsed().as_secs_f64();
let _ = view;
if remaining > 0 {
append_slots_to_lp(
solver,
pool,
&scratch.out_selected,
indexer,
col_scale,
&mut scratch.row_map,
&mut scratch.batch,
);
}
// The final solve runs whether or not cuts were appended, and its solution is
// stored without another scoring pass.
let view = solver.solve(None).map_err(|e| map_solver_error(e, ctx))?;
scratch.store_result(&view);
Ok(())
}
/// Fills `out` with the pool slots that should be resident in the LP before
/// the first solve of an iteration.
///
/// `out` is cleared first. An active slot among the first
/// `pool.populated_count` is kept when either
/// - it was last active at most `k2` iterations before `current_iteration`
///   (saturating difference, inclusive bound), or
/// - it was generated in `current_iteration` itself.
///
/// Slots are emitted in ascending order.
pub fn build_initial_resident_set(
    pool: &CutPool,
    current_iteration: u64,
    k2: u32,
    out: &mut Vec<u32>,
) {
    debug_assert!(
        pool.metadata.len() >= pool.populated_count,
        "build_initial_resident_set: metadata.len() {} < populated_count {}",
        pool.metadata.len(),
        pool.populated_count,
    );

    out.clear();
    let max_idle = u64::from(k2);

    let keepers = (0..pool.populated_count).filter(|&slot| {
        if !pool.active[slot] {
            return false;
        }
        let meta = &pool.metadata[slot];
        let idle_for = current_iteration.saturating_sub(meta.last_active_iter);
        idle_for <= max_idle || meta.iteration_generated == current_iteration
    });

    #[allow(clippy::cast_possible_truncation)]
    out.extend(keepers.map(|slot| slot as u32));
}
#[cfg(test)]
#[allow(clippy::doc_markdown)]
mod tests {
use cobre_solver::{
ActiveProfile, ActiveSolver, Basis, ProfiledSolver, RowBatch, SolverError, SolverInterface,
SolverStatistics, StageTemplate,
};
use super::{
DcsParams, DcsScoringScratch, DcsSolveContext, DcsSolveScratch, build_initial_resident_set,
lazy_solve_preloaded, score_violated_candidates,
};
use crate::cut::{CutPool, CutRowMap};
use crate::cut_selection::{CutMetadata, CutSelectionStrategy};
use crate::indexer::StageIndexer;
#[test]
fn default_matches_spec() {
let params = DcsParams::default();
assert_eq!(
params,
DcsParams {
k1: None,
k2: 5,
nadic: 10,
epsilon_viol: 1e-10,
start_iteration: 2,
max_inner_iterations: 50,
}
);
assert_eq!(params.k1, None);
}
#[test]
fn from_strategy_dynamic_copies_fields() {
let strategy = CutSelectionStrategy::Dynamic {
k1: Some(20),
k2: 7,
nadic: 3,
epsilon_viol: 1e-9,
start_iteration: 4,
};
let params = DcsParams::from_strategy(&strategy)
.expect("from_strategy must return Some for the Dynamic variant");
assert_eq!(
params,
DcsParams {
k1: Some(20),
k2: 7,
nadic: 3,
epsilon_viol: 1e-9,
start_iteration: 4,
max_inner_iterations: 50,
}
);
let strategy_inf = CutSelectionStrategy::Dynamic {
k1: None,
k2: 5,
nadic: 10,
epsilon_viol: 1e-10,
start_iteration: 2,
};
let params_inf = DcsParams::from_strategy(&strategy_inf)
.expect("from_strategy must return Some for the Dynamic variant");
assert_eq!(params_inf.k1, None);
}
#[test]
fn from_strategy_non_dynamic_is_none() {
let level1 = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
assert!(DcsParams::from_strategy(&level1).is_none());
let dominated = CutSelectionStrategy::Dominated {
threshold: 1e-6,
check_frequency: 10,
};
assert!(DcsParams::from_strategy(&dominated).is_none());
}
#[test]
fn is_active_threshold() {
let params = DcsParams {
start_iteration: 2,
..DcsParams::default()
};
assert_eq!(
[
params.is_active(0),
params.is_active(1),
params.is_active(2),
params.is_active(3),
],
[false, false, true, true]
);
}
#[test]
fn dcs_params_is_copy() {
fn assert_copy<T: Copy>() {}
assert_copy::<DcsParams>();
}
const N_STATE: usize = 2;
const THETA_COL: usize = 6;
const PRIMAL_LEN: usize = THETA_COL + 1;
fn indexer() -> StageIndexer {
StageIndexer::new(2, 0)
}
fn empty_pool() -> CutPool {
CutPool::new(16, N_STATE, 16, 0)
}
fn add(pool: &mut CutPool, slot: u32, intercept: f64, coeffs: &[f64], iter_generated: u64) {
pool.add_cut(0, slot, intercept, coeffs);
pool.metadata[slot as usize].iteration_generated = iter_generated;
}
fn empty_resident() -> CutRowMap {
CutRowMap::new(16, 0)
}
fn params(nadic: u32, epsilon_viol: f64, k1: Option<u32>) -> DcsParams {
DcsParams {
k1,
nadic,
epsilon_viol,
..DcsParams::default()
}
}
#[test]
fn scores_and_orders_two_violated_descending() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[0.0, 0.0], 1); add(&mut pool, 1, 2.0, &[0.0, 0.0], 1); let primal = vec![0.0_f64; PRIMAL_LEN]; let resident = empty_resident();
let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 2);
assert_eq!(
out,
vec![0, 1],
"descending violation: slot 0 (5) then slot 1 (2)"
);
}
#[test]
fn respects_nadic_cap_returns_full_violated_count() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[0.0, 0.0], 1);
add(&mut pool, 1, 2.0, &[0.0, 0.0], 1);
let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(1, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 2, "full violated count regardless of nadic");
assert_eq!(out, vec![0], "top-1 only");
}
#[test]
fn tie_break_ascending_slot_id() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 7, 3.0, &[0.0, 0.0], 1);
add(&mut pool, 4, 3.0, &[0.0, 0.0], 1);
let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(1, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 2);
assert_eq!(out, vec![4], "equal violation → ascending slot id wins");
}
#[test]
fn applies_col_scale_unscaling() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 0.5, &[1.0, 0.0], 1);
let mut primal = vec![0.0_f64; PRIMAL_LEN];
primal[0] = 2.0; primal[THETA_COL] = 1.0; let mut col_scale = vec![1.0_f64; PRIMAL_LEN];
col_scale[0] = 0.5; col_scale[THETA_COL] = 2.0; let resident = empty_resident();
let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&col_scale,
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(
count, 0,
"with correct col_scale unscaling, v = -0.5 → not violated"
);
assert!(out.is_empty());
}
#[test]
fn skips_resident_slots() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[0.0, 0.0], 1); add(&mut pool, 1, 2.0, &[0.0, 0.0], 1); let primal = vec![0.0_f64; PRIMAL_LEN];
let mut resident = empty_resident();
resident.insert(0); let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 1, "resident slot 0 is excluded from scoring");
assert_eq!(out, vec![1]);
}
#[test]
fn k1_window_filters_old_cuts() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[0.0, 0.0], 8); add(&mut pool, 1, 2.0, &[0.0, 0.0], 2); let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let p_finite = params(10, 1e-10, Some(5));
let count_finite = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p_finite,
10,
&mut scratch,
&mut out,
);
assert_eq!(count_finite, 1, "only the in-window cut is counted");
assert_eq!(out, vec![0]);
let p_inf = params(10, 1e-10, None);
let count_inf = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p_inf,
10,
&mut scratch,
&mut out,
);
assert_eq!(count_inf, 2, "k1 = None ⇒ all cuts eligible");
assert_eq!(out, vec![0, 1]);
}
#[test]
fn k1_window_keeps_warm_start_cuts() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[0.0, 0.0], u64::MAX); let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(10, 1e-10, Some(1)); let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 1, "warm-start cut (age 0) is always in window");
assert_eq!(out, vec![0]);
}
#[test]
fn epsilon_viol_is_strict() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 1.0, &[0.0, 0.0], 1);
let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(10, 1.0, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 0, "v == epsilon_viol is NOT strictly greater");
assert!(out.is_empty());
}
#[test]
fn no_violation_returns_empty() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, -1.0, &[0.0, 0.0], 1);
let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 0);
assert!(out.is_empty());
}
use crate::gemm::gemm_block;
/// Advances `state` by one SplitMix64 step and returns the mixed output.
///
/// Deterministic and dependency-free; used to build reproducible fixtures.
fn splitmix64(state: &mut u64) -> u64 {
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    const MIX_A: u64 = 0xBF58_476D_1CE4_E5B9;
    const MIX_B: u64 = 0x94D0_49BB_1331_11EB;

    *state = state.wrapping_add(GAMMA);
    let seed = *state;
    let stage_one = (seed ^ (seed >> 30)).wrapping_mul(MIX_A);
    let stage_two = (stage_one ^ (stage_one >> 27)).wrapping_mul(MIX_B);
    stage_two ^ (stage_two >> 31)
}
/// Draws a deterministic `f64` in `[-0.5, 0.5)` from the SplitMix64 stream.
///
/// 52 bits of the generator output (bits 12..=63) become the mantissa of a
/// float with exponent 0, i.e. a value in `[1.0, 2.0)`, which is then shifted
/// down by 1.5.
fn draw_f64(state: &mut u64) -> f64 {
    const MANTISSA_MASK: u64 = (1_u64 << 52) - 1;
    const EXPONENT_ONE: u64 = 1023_u64 << 52;

    let mantissa = (splitmix64(state) >> 12) & MANTISSA_MASK;
    f64::from_bits(EXPONENT_ONE | mantissa) - 1.5
}
#[allow(clippy::cast_possible_truncation)]
fn per_row_reference(
pool: &CutPool,
idx: &StageIndexer,
primal: &[f64],
col_scale: &[f64],
resident: &CutRowMap,
p: &DcsParams,
current_iteration: u64,
) -> (Vec<u64>, Vec<u32>) {
let n_state = idx.n_state;
let theta = idx.theta;
let theta_raw = if col_scale.is_empty() {
primal[theta]
} else {
col_scale[theta] * primal[theta]
};
let mut unscaled_state = Vec::with_capacity(n_state);
for j in 0..n_state {
let c = idx.state_to_lp_column(j);
let x_raw = if col_scale.is_empty() {
primal[c]
} else {
col_scale[c] * primal[c]
};
unscaled_state.push(x_raw);
}
let mut alpha_bits = Vec::new();
let mut violations: Vec<(f64, u32)> = Vec::new();
for (slot, intercept, coefficients) in pool.active_cuts() {
if let Some(k1) = p.k1 {
let age = current_iteration.saturating_sub(pool.metadata[slot].iteration_generated);
if age >= u64::from(k1) {
continue;
}
}
if resident.lp_row_for_slot(slot).is_some() {
continue;
}
let mut dot = [0.0_f64; 1];
gemm_block(coefficients, &unscaled_state, 1, n_state, 1, &mut dot);
alpha_bits.push(dot[0].to_bits());
let alpha = intercept + dot[0];
let v = alpha - theta_raw;
if v > p.epsilon_viol {
violations.push((v, slot as u32));
}
}
violations.sort_unstable_by(|a, b| b.0.total_cmp(&a.0).then(a.1.cmp(&b.1)));
let take = (p.nadic as usize).min(violations.len());
let out_selected = violations[..take].iter().map(|&(_, s)| s).collect();
(alpha_bits, out_selected)
}
#[allow(clippy::cast_possible_truncation)]
fn random_pool(cap: usize, seed: u64) -> CutPool {
let mut pool = CutPool::new(cap, N_STATE, cap as u32, 0);
let mut state = seed;
for slot in 0..cap {
let intercept = draw_f64(&mut state);
let coeffs = [draw_f64(&mut state), draw_f64(&mut state)];
pool.add_cut(0, slot as u32, intercept, &coeffs);
let gen_iter = 1 + (splitmix64(&mut state) % 12);
pool.metadata[slot].iteration_generated = gen_iter;
}
pool
}
#[test]
fn batched_scoring_bit_identical_to_per_row_reference() {
let idx = indexer();
let cap = 64;
let pool = random_pool(cap, 0x0BAD_F00D_C0FF_EE11);
let mut prng = 0xDEAD_BEEF_1234_5678_u64;
let mut primal = vec![0.0_f64; PRIMAL_LEN];
for v in &mut primal {
*v = draw_f64(&mut prng);
}
let mut col_scale = vec![1.0_f64; PRIMAL_LEN];
for v in &mut col_scale {
*v = 0.5 + (draw_f64(&mut prng) + 1.5) * 0.5; }
primal[THETA_COL] = -50.0;
let mut resident = CutRowMap::new(cap, 0);
for slot in (0..cap).step_by(3) {
resident.insert(slot);
}
let current_iteration = 10;
let p = params(7, 1e-9, Some(5));
let (ref_alpha_bits, ref_selected) = per_row_reference(
&pool,
&idx,
&primal,
&col_scale,
&resident,
&p,
current_iteration,
);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&col_scale,
&resident,
&p,
current_iteration,
&mut scratch,
&mut out,
);
assert_eq!(
scratch.alpha.len(),
ref_alpha_bits.len(),
"batched and per-row reference must gather the same candidate set"
);
let batched_alpha_bits: Vec<u64> = scratch.alpha.iter().map(|a| a.to_bits()).collect();
assert_eq!(
batched_alpha_bits, ref_alpha_bits,
"batched alpha must be bit-identical to the per-row reference"
);
assert_eq!(
out, ref_selected,
"batched out_selected must match the per-row reference exactly"
);
assert!(count >= out.len());
assert!(
scratch.alpha.len() > 1,
"fixture must gather >1 candidate (got {})",
scratch.alpha.len()
);
assert!(
!out.is_empty(),
"fixture must select at least one violation"
);
assert!(
count > out.len(),
"fixture must have more violations ({count}) than the nadic cap ({})",
out.len()
);
}
#[test]
fn batched_scoring_single_gemm_per_pass() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[1.0, 0.0], 10);
add(&mut pool, 1, 2.0, &[0.0, 1.0], 10);
add(&mut pool, 2, 9.0, &[0.5, 0.5], 10);
let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let _ = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
let k = 3;
assert_eq!(scratch.cand_slots.len(), k, "all k candidates gathered");
assert_eq!(
scratch.cand_slots,
vec![0, 1, 2],
"ascending-slot gather order"
);
assert_eq!(
scratch.cand_coef_block.len(),
k * N_STATE,
"coef block is exactly k_rows × n_state — one batched GEMM input"
);
assert_eq!(
scratch.alpha.len(),
k,
"alpha holds exactly k activities from the single batched GEMM"
);
}
#[test]
fn batched_scoring_gathers_only_eligible() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[0.0, 0.0], 8); add(&mut pool, 1, 2.0, &[0.0, 0.0], 2); add(&mut pool, 2, 9.0, &[0.0, 0.0], 9); add(&mut pool, 3, 4.0, &[0.0, 0.0], 9); let primal = vec![0.0_f64; PRIMAL_LEN];
let mut resident = empty_resident();
resident.insert(2); let p = params(10, 1e-10, Some(5)); let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let _ = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(
scratch.cand_slots,
vec![0, 3],
"out-of-window (1) and resident (2) slots excluded from the gather"
);
assert_eq!(scratch.alpha.len(), 2);
}
#[test]
fn batched_scoring_empty_candidates_no_op() {
let idx = indexer();
let pool = empty_pool(); let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 0);
assert!(scratch.cand_slots.is_empty());
assert!(scratch.cand_coef_block.is_empty());
assert!(scratch.alpha.is_empty());
assert!(out.is_empty());
}
#[test]
fn batched_scoring_all_resident_no_candidates() {
let idx = indexer();
let mut pool = empty_pool();
add(&mut pool, 0, 5.0, &[1.0, 0.0], 10);
add(&mut pool, 1, 2.0, &[0.0, 1.0], 10);
let primal = vec![0.0_f64; PRIMAL_LEN];
let mut resident = empty_resident();
resident.insert(0);
resident.insert(1);
let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
let count = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(count, 0, "all candidates resident → none scored");
assert!(scratch.cand_slots.is_empty(), "all-resident → empty gather");
assert!(scratch.alpha.is_empty());
assert!(out.is_empty());
}
#[test]
fn batched_scoring_scratch_is_growth_only() {
let idx = indexer();
let cap = 32;
let pool = random_pool(cap, 0x00C0_FFEE_0BAD_CAFE);
let primal = vec![0.0_f64; PRIMAL_LEN];
let resident = empty_resident();
let p = params(10, 1e-10, None);
let mut scratch = DcsScoringScratch::default();
let mut out = Vec::new();
scratch.reserve(N_STATE, cap);
let coef_cap = scratch.cand_coef_block.capacity();
let alpha_cap = scratch.alpha.capacity();
let slots_cap = scratch.cand_slots.capacity();
let viol_cap = scratch.violations.capacity();
let state_cap = scratch.unscaled_state.capacity();
assert!(
coef_cap >= cap * N_STATE,
"coef block reserved to cap*n_state"
);
assert!(alpha_cap >= cap);
assert!(slots_cap >= cap);
for _ in 0..5 {
let _ = score_violated_candidates(
&pool,
&idx,
&primal,
&[],
&resident,
&p,
10,
&mut scratch,
&mut out,
);
assert_eq!(scratch.cand_coef_block.capacity(), coef_cap);
assert_eq!(scratch.alpha.capacity(), alpha_cap);
assert_eq!(scratch.cand_slots.capacity(), slots_cap);
assert_eq!(scratch.violations.capacity(), viol_cap);
assert_eq!(scratch.unscaled_state.capacity(), state_cap);
}
}
const STATE_X0: f64 = 2.0;
const LAZY_THETA_COL: usize = 3;
fn lazy_indexer() -> StageIndexer {
StageIndexer::new(1, 0)
}
fn core_template() -> StageTemplate {
StageTemplate {
num_cols: 4,
num_rows: 0,
num_nz: 0,
col_starts: vec![0_i32; 5], row_indices: Vec::new(),
values: Vec::new(),
col_lower: vec![STATE_X0, 0.0, 0.0, -1.0e6],
col_upper: vec![STATE_X0, 0.0, 0.0, 1.0e6],
objective: vec![0.0, 0.0, 0.0, 1.0], row_lower: Vec::new(),
row_upper: Vec::new(),
n_state: 1,
n_transfer: 0,
n_dual_relevant: 1,
n_hydro: 1,
max_par_order: 0,
col_scale: Vec::new(),
row_scale: Vec::new(),
}
}
fn make_three_cut_pool() -> CutPool {
let mut pool = CutPool::new(16, 1, 16, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(0, 1, 0.0, &[1.0]);
pool.add_cut(0, 2, 5.0, &[0.0]);
for slot in 0..3 {
pool.metadata[slot].iteration_generated = 1;
}
pool
}
fn make_four_cut_pool() -> CutPool {
let mut pool = CutPool::new(16, 1, 16, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(0, 1, 0.0, &[1.0]);
pool.add_cut(0, 2, 5.0, &[0.0]);
pool.add_cut(0, 3, 0.0, &[2.0]);
for slot in 0..4 {
pool.metadata[slot].iteration_generated = 1;
}
pool
}
fn active_profiled() -> ProfiledSolver<ActiveSolver> {
ProfiledSolver::new(ActiveSolver::new().expect("ActiveSolver::new()"))
}
fn ctx() -> DcsSolveContext {
DcsSolveContext {
stage_index: 0,
scenario_index: 0,
iteration: Some(10),
continue_carry: false,
}
}
fn lazy_params(nadic: u32, max_inner: u32) -> DcsParams {
DcsParams {
k1: None,
nadic,
max_inner_iterations: max_inner,
..DcsParams::default()
}
}
fn solve_all_cuts(pool: &CutPool, indexer: &StageIndexer) -> (f64, f64) {
let mut solver = active_profiled();
let core = core_template();
solver.load_model(&core);
let mut row_map = CutRowMap::new(pool.populated_count, core.num_rows);
let mut batch = RowBatch {
num_rows: 0,
row_starts: Vec::new(),
col_indices: Vec::new(),
values: Vec::new(),
row_lower: Vec::new(),
row_upper: Vec::new(),
};
let all_slots: Vec<u32> = (0..pool.populated_count as u32).collect();
crate::cut::row::append_slots_to_lp(
&mut solver,
pool,
&all_slots,
indexer,
&[],
&mut row_map,
&mut batch,
);
let view = solver.solve(None).expect("all-cuts solve must succeed");
(view.objective, view.primal[LAZY_THETA_COL])
}
/// The lazy solve, started with only slots 0 and 1 resident, must reach the
/// same objective and theta as a solve with every cut loaded up front.
#[test]
fn lazy_solve_exact_matches_all_cuts() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let (all_obj, all_theta) = solve_all_cuts(&pool, &indexer);
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    solver.load_model(&core);
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        // Slot 2 is left out on purpose so the inner loop has to add it.
        &[0, 1],
        &params,
        &mut scratch,
        ctx(),
    )
    .expect("lazy_solve_preloaded must succeed");
    let view = scratch.result_view();
    assert!(
        (view.objective - all_obj).abs() < 1e-9,
        "objective {} != all-cuts {all_obj}",
        view.objective
    );
    assert!(
        (view.primal[LAZY_THETA_COL] - all_theta).abs() < 1e-9,
        "theta {} != all-cuts {all_theta}",
        view.primal[LAZY_THETA_COL]
    );
    assert_eq!(view.objective, all_obj);
    assert_eq!(view.primal[LAZY_THETA_COL], all_theta);
}
/// With every cut already resident, the first scoring pass finds nothing to
/// add, so exactly one solve is issued.
#[test]
fn lazy_solve_no_violation_stops_immediately() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    solver.load_model(&core);
    let solves_before = solver.statistics().solve_count;
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        &[0, 1, 2],
        &params,
        &mut scratch,
        ctx(),
    )
    .expect("lazy_solve_preloaded must succeed");
    let solve_delta = solver.statistics().solve_count - solves_before;
    let view = scratch.result_view();
    assert_eq!(view.primal[LAZY_THETA_COL], 5.0);
    assert!(scratch.out_selected.is_empty());
    assert_eq!(
        solve_delta, 1,
        "no-violation path must issue exactly 1 solve (no redundant exit re-solve)"
    );
}
/// Starting without slot 2, the inner loop must add it and re-solve once:
/// two solves in total, with theta reflecting the added cut.
#[test]
fn lazy_solve_grows_to_include_binding_cut() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    solver.load_model(&core);
    let solves_before = solver.statistics().solve_count;
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        // Slot 2 is left out on purpose so the inner loop has to add it.
        &[0, 1],
        &params,
        &mut scratch,
        ctx(),
    )
    .expect("lazy_solve_preloaded must succeed");
    let solve_delta = solver.statistics().solve_count - solves_before;
    let view = scratch.result_view();
    assert_eq!(
        view.primal[LAZY_THETA_COL], 5.0,
        "final theta must reflect the added binding cut"
    );
    assert_eq!(
        solve_delta, 2,
        "one-addition path must issue exactly 2 solves (no redundant exit re-solve)"
    );
}
/// Exercises the `continue_carry` path across three consecutive openings that
/// share one solver and one scratch:
/// - A (fresh, x0 = 2) starts from slots 0 and 1;
/// - B (carried, x0 = 10) must add slot 3 on top of the carried rows;
/// - C (carried, x0 = 2) needs no new cut and must cost exactly one solve.
#[test]
fn lazy_solve_continue_carry_exact_across_openings() {
    let indexer = lazy_indexer();
    let pool = make_four_cut_pool();
    let core = core_template();
    let params = lazy_params(10, 50);
    let mut solver = active_profiled();
    let mut scratch = DcsSolveScratch::default();
    solver.load_model(&core);

    // Opening A: fresh start.
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        &[0, 1],
        &params,
        &mut scratch,
        DcsSolveContext {
            continue_carry: false,
            ..ctx()
        },
    )
    .expect("opening A (fresh) must solve");
    assert_eq!(
        scratch.result_view().primal[LAZY_THETA_COL],
        5.0,
        "opening A theta (x0=2)"
    );

    // Opening B: move the fixed state to 10 and continue on the carried LP.
    solver.set_col_bounds(&[0], &[10.0], &[10.0]);
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        // The initial resident list is not used on the carry path.
        &[],
        &params,
        &mut scratch,
        DcsSolveContext {
            continue_carry: true,
            ..ctx()
        },
    )
    .expect("opening B (continue) must solve");
    assert_eq!(
        scratch.result_view().primal[LAZY_THETA_COL],
        20.0,
        "opening B theta (x0=10) must reach the all-cuts optimum via the \
         carried-LP continue path (slot 3 added on top of the carried set)"
    );

    // Opening C: back to x0 = 2; nothing new to add.
    solver.set_col_bounds(&[0], &[2.0], &[2.0]);
    let solves_before = solver.statistics().solve_count;
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        &[],
        &params,
        &mut scratch,
        DcsSolveContext {
            continue_carry: true,
            ..ctx()
        },
    )
    .expect("opening C (continue, no add) must solve");
    let solve_delta = solver.statistics().solve_count - solves_before;
    assert_eq!(
        scratch.result_view().primal[LAZY_THETA_COL],
        5.0,
        "opening C theta (x0=2) reverts to the carried binding cut (slot 2)"
    );
    assert_eq!(
        solve_delta, 1,
        "continue-carry no-add opening must issue exactly 1 solve (no redundant exit re-solve)"
    );
}
/// A default-constructed [`DcsSolveScratch`] starts with a zero scoring-time accumulator.
#[test]
fn scoring_time_default_is_zero() {
    let initial_scoring_time = DcsSolveScratch::default().scoring_time_seconds;
    assert_eq!(initial_scoring_time, 0.0);
}
/// After a lazy solve seeded with only slots 0 and 1 of the three-cut pool, the
/// `scoring_time_seconds` accumulator on the scratch must be strictly larger than
/// the value it held before the call.
#[test]
fn scoring_time_increases_when_inner_loop_runs() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    // Baseline taken before any solve touches the scratch.
    let before = scratch.scoring_time_seconds;
    solver.load_model(&core);
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        // Slot 2 is deliberately left out of the seed set.
        &[0, 1],
        &params,
        &mut scratch,
        ctx(),
    )
    .expect("lazy_solve_preloaded must succeed");
    assert!(
        scratch.scoring_time_seconds > before,
        "scoring accumulator {} must exceed its pre-call value {before}",
        scratch.scoring_time_seconds
    );
}
/// Checks the `rows_in_lp_*` counters on [`DcsSolveScratch`]: they start at zero and,
/// after two fully seeded lazy solves sharing one scratch, report a count of 2,
/// a sum of 6 and a maximum of 3.
#[test]
fn rows_in_lp_tracks_resident_set_size_per_solve() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    assert_eq!(scratch.rows_in_lp_sum, 0);
    assert_eq!(scratch.rows_in_lp_count, 0);
    assert_eq!(scratch.rows_in_lp_max, 0);
    // The model is reloaded before each call; the scratch is reused so the counters
    // accumulate across both calls.
    for _ in 0..2 {
        solver.load_model(&core);
        lazy_solve_preloaded(
            &mut solver,
            &core,
            &pool,
            &indexer,
            &[],
            None,
            &[0, 1, 2],
            &params,
            &mut scratch,
            ctx(),
        )
        .expect("lazy_solve_preloaded must succeed");
    }
    assert_eq!(scratch.rows_in_lp_count, 2, "one term folded per solve");
    assert_eq!(
        scratch.rows_in_lp_sum, 6,
        "3 resident cut rows per solve, summed over 2 solves"
    );
    assert_eq!(
        scratch.rows_in_lp_max, 3,
        "peak resident set is the 3 seeded cuts"
    );
}
/// With all three slots seeded, the lazy solve must report theta = objective = 5.0
/// and an empty `out_selected`, and the scoring-time accumulator must not go down.
#[test]
fn scoring_time_single_pass_result_unchanged() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    // Baseline taken before any solve touches the scratch.
    let before = scratch.scoring_time_seconds;
    solver.load_model(&core);
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        &[0, 1, 2],
        &params,
        &mut scratch,
        ctx(),
    )
    .expect("lazy_solve_preloaded must succeed");
    let view = scratch.result_view();
    // Non-strict on purpose: this test only requires the accumulator not to shrink.
    assert!(
        scratch.scoring_time_seconds >= before,
        "scoring accumulator must not decrease"
    );
    assert_eq!(
        view.primal[LAZY_THETA_COL], 5.0,
        "single-pass optimum must be the binding-cut floor"
    );
    assert_eq!(view.objective, 5.0);
    assert!(
        scratch.out_selected.is_empty(),
        "no violation → no slots selected for growth"
    );
}
/// Starting from an empty seed set with `lazy_params(1, 1)`, the lazy solve must
/// still match the all-cuts reference (objective and theta within 1e-9) and must
/// issue exactly three LP solves.
#[test]
fn lazy_solve_tc_fallback_terminates() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    // Reference optimum obtained by solving with every cut of the pool.
    let (all_obj, all_theta) = solve_all_cuts(&pool, &indexer);
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    // NOTE(review): presumably (1, 1) caps both the cuts added per round and the number
    // of inner rounds, forcing the fallback path — confirm against `lazy_params`.
    let params = lazy_params(1, 1);
    solver.load_model(&core);
    let solves_before = solver.statistics().solve_count;
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        // No seeds: the resident set starts empty.
        &[],
        &params,
        &mut scratch,
        ctx(),
    )
    .expect("lazy_solve_preloaded TC fallback must succeed");
    let solve_delta = solver.statistics().solve_count - solves_before;
    let view = scratch.result_view();
    assert!((view.objective - all_obj).abs() < 1e-9);
    assert!((view.primal[LAZY_THETA_COL] - all_theta).abs() < 1e-9);
    assert_eq!(
        solve_delta, 3,
        "TC fallback issues initial + one capped add/resolve + final TC solve = 3"
    );
}
/// Two independent runs (fresh solver and fresh scratch each time) over the same
/// pool, template and parameters must report bit-identical objectives.
#[test]
fn lazy_solve_is_deterministic() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let core = core_template();
    let params = lazy_params(10, 50);
    // Each invocation builds its own solver and scratch so no state leaks between runs.
    let run = || {
        let mut solver = active_profiled();
        let mut scratch = DcsSolveScratch::default();
        solver.load_model(&core);
        lazy_solve_preloaded(
            &mut solver,
            &core,
            &pool,
            &indexer,
            &[],
            None,
            &[0],
            &params,
            &mut scratch,
            ctx(),
        )
        .expect("lazy_solve_preloaded must succeed");
        scratch.result_view().objective
    };
    assert_eq!(run(), run(), "objective must be deterministic");
}
/// After two warm-up calls, three further lazy solves on the same scratch must leave
/// the capacities of `res_primal`, `res_dual` and `res_reduced_costs` unchanged while
/// still reporting theta = 5.0 every time.
#[test]
fn lazy_solve_result_buffers_growth_only() {
    let indexer = lazy_indexer();
    let pool = make_three_cut_pool();
    let mut solver = active_profiled();
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    // One full cycle: reload the core model, then lazy-solve seeded with slots 0 and 1.
    let call = |solver: &mut ProfiledSolver<ActiveSolver>, scratch: &mut DcsSolveScratch| {
        solver.load_model(&core);
        lazy_solve_preloaded(
            solver,
            &core,
            &pool,
            &indexer,
            &[],
            None,
            &[0, 1],
            &params,
            scratch,
            ctx(),
        )
        .expect("lazy_solve_preloaded must succeed");
    };
    // Warm-up: let the result buffers reach their steady-state size before sampling.
    call(&mut solver, &mut scratch);
    call(&mut solver, &mut scratch);
    let primal_cap = scratch.res_primal.capacity();
    let dual_cap = scratch.res_dual.capacity();
    let rc_cap = scratch.res_reduced_costs.capacity();
    assert!(
        !scratch.res_primal.is_empty(),
        "result primal must be filled"
    );
    assert_eq!(
        scratch.result_view().primal[LAZY_THETA_COL],
        5.0,
        "warmed result must be the binding optimum"
    );
    for _ in 0..3 {
        call(&mut solver, &mut scratch);
        assert_eq!(
            scratch.res_primal.capacity(),
            primal_cap,
            "res_primal capacity must be stable (growth-only)"
        );
        assert_eq!(
            scratch.res_dual.capacity(),
            dual_cap,
            "res_dual capacity must be stable (growth-only)"
        );
        assert_eq!(
            scratch.res_reduced_costs.capacity(),
            rc_cap,
            "res_reduced_costs capacity must be stable (growth-only)"
        );
        assert_eq!(scratch.result_view().primal[LAZY_THETA_COL], 5.0);
    }
}
/// Scripted solver double: its `solve` implementation reports `first` as the primal
/// vector on the initial call and `rest` on every later call.
struct TwoPhaseMock {
/// Primal vector returned by the first `solve` call.
first: Vec<f64>,
/// Primal vector returned by every `solve` call after the first.
rest: Vec<f64>,
/// Number of `solve` calls made so far; also reported as `solve_count` in statistics.
call_count: usize,
/// Backing storage for the primal slice handed out in the returned solution view.
buf: Vec<f64>,
/// Always-empty storage used for the dual and reduced-cost slices of the view.
empty: Vec<f64>,
}
impl TwoPhaseMock {
    /// Creates a mock that has not been solved yet.
    ///
    /// `first` is the primal vector for the initial solve, `rest` the one for all
    /// subsequent solves; the internal buffers start out empty.
    fn new(first: Vec<f64>, rest: Vec<f64>) -> Self {
        let call_count = 0;
        let buf = Vec::default();
        let empty = Vec::default();
        Self {
            first,
            rest,
            call_count,
            buf,
            empty,
        }
    }
}
/// Minimal [`SolverInterface`] for the scripted mock: every model-mutating method is
/// a no-op, and `solve` replays one of the two stored primal vectors.
impl SolverInterface for TwoPhaseMock {
    type Profile = ActiveProfile;

    fn apply_profile(&mut self, _profile: &ActiveProfile) {}

    fn solver_name_version(&self) -> String {
        String::from("TwoPhaseMock 0.0.0")
    }

    fn load_model(&mut self, _template: &StageTemplate) {}

    fn add_rows(&mut self, _rows: &RowBatch) {}

    fn set_row_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}

    fn set_col_bounds(&mut self, _indices: &[usize], _lower: &[f64], _upper: &[f64]) {}

    /// Copies `first` (initial call) or `rest` (later calls) into the internal buffer
    /// and returns a view over it; the objective is the buffer's theta entry, and the
    /// dual / reduced-cost slices are empty. The supplied basis is ignored.
    fn solve(
        &mut self,
        _basis: Option<&Basis>,
    ) -> Result<cobre_solver::SolutionView<'_>, SolverError> {
        // Decide which script applies before bumping the counter.
        let is_first_call = self.call_count == 0;
        self.call_count += 1;
        let scripted = if is_first_call {
            &self.first
        } else {
            &self.rest
        };
        self.buf.clone_from(scripted);
        let view = cobre_solver::SolutionView {
            objective: self.buf[LAZY_THETA_COL],
            primal: &self.buf,
            dual: &self.empty,
            reduced_costs: &self.empty,
            iterations: 0,
            solve_time_seconds: 0.0,
        };
        Ok(view)
    }

    fn get_basis(&mut self, _out: &mut Basis) {}

    /// Reports the number of `solve` calls as `solve_count`; all other fields default.
    fn statistics(&self) -> SolverStatistics {
        let solve_count = self.call_count as u64;
        SolverStatistics {
            solve_count,
            ..SolverStatistics::default()
        }
    }

    fn statistics_into(&self, out: &mut SolverStatistics) {
        let snapshot = self.statistics();
        out.copy_from(&snapshot);
    }

    fn name(&self) -> &'static str {
        "TwoPhaseMock"
    }
}
/// Drives the lazy solve with [`TwoPhaseMock`], whose `get_basis` is a no-op and whose
/// `solve` ignores any basis, and checks that the call succeeds and ends with
/// theta = 5.0 (the value the mock reports from its second solve onward).
#[test]
fn lazy_solve_tolerates_cold_mid_loop_resolve() {
    let indexer = lazy_indexer();
    // Single-cut pool built by hand.
    // NOTE(review): presumably `add_cut(_, _, 5.0, &[0.0])` means intercept 5.0 with a
    // zero coefficient — confirm against `CutPool::add_cut`.
    let mut pool = CutPool::new(16, 1, 16, 0);
    pool.add_cut(0, 0, 5.0, &[0.0]);
    pool.metadata[0].iteration_generated = 1;
    // The two scripted primal vectors differ only in their last entry (0.0 then 5.0).
    let first = vec![STATE_X0, 0.0, 0.0, 0.0];
    let rest = vec![STATE_X0, 0.0, 0.0, 5.0];
    let mut solver = ProfiledSolver::new(TwoPhaseMock::new(first, rest));
    let core = core_template();
    let mut scratch = DcsSolveScratch::default();
    let params = lazy_params(10, 50);
    solver.load_model(&core);
    lazy_solve_preloaded(
        &mut solver,
        &core,
        &pool,
        &indexer,
        &[],
        None,
        // No seeds: the resident set starts empty.
        &[],
        &params,
        &mut scratch,
        ctx(),
    )
    .expect("lazy_solve_preloaded must tolerate a cold mid-loop re-solve");
    let view = scratch.result_view();
    assert_eq!(
        view.primal[LAZY_THETA_COL], 5.0,
        "loop reaches the no-violation stop on the second (cold) solve"
    );
}
/// Builds a [`CutPool`] fixture with one cut per entry of `specs`.
///
/// Each spec is `(active, iteration_generated, last_active_iter)`. Slot `i` gets
/// metadata built from those values (with `forward_pass_index = i` and
/// `active_count = 0`), its `active` flag is overwritten with the spec's flag, and
/// `cached_active_count` is set to the number of specs flagged active.
// Test fixtures hold a handful of cuts, so the `usize -> u32` casts cannot truncate.
#[allow(clippy::cast_possible_truncation)]
fn seed_pool(specs: &[(bool, u64, u64)]) -> CutPool {
    // The pool is created with room for at least one cut even when `specs` is empty.
    let capacity = specs.len().max(1);
    let mut pool = CutPool::new(capacity, 1, capacity as u32, 0);
    for (slot, spec) in specs.iter().enumerate() {
        let (active, iteration_generated, last_active_iter) = *spec;
        let forward_pass_index = slot as u32;
        pool.add_cut(0, forward_pass_index, 0.0, &[0.0]);
        pool.metadata[slot] = CutMetadata {
            iteration_generated,
            forward_pass_index,
            active_count: 0,
            last_active_iter,
        };
        pool.active[slot] = active;
    }
    let active_total = specs.iter().filter(|spec| spec.0).count();
    pool.cached_active_count = active_total;
    pool
}
/// At iteration 10 with `k2 = 5`, every active slot except slot 2 (last active at
/// iteration 3) is expected in the initial resident set.
#[test]
fn seeds_within_k2_window() {
    // (active, iteration_generated, last_active_iter) per slot.
    let specs = [
        (true, 1, 10),
        (true, 1, 8),
        (true, 1, 3),
        (true, 1, 10),
        (true, 1, 6),
    ];
    let pool = seed_pool(&specs);
    let mut seeded = Vec::new();
    build_initial_resident_set(&pool, 10, 5, &mut seeded);
    let expected = vec![0, 1, 3, 4];
    assert_eq!(seeded, expected);
}
/// A cut generated in the current iteration (slot 0, generated at 10) must be seeded
/// even though its `last_active_iter` of 1 is stale.
#[test]
fn always_seeds_current_iteration_cuts() {
    // (active, iteration_generated, last_active_iter) per slot.
    let specs = [(true, 10, 1), (true, 1, 9)];
    let pool = seed_pool(&specs);
    let mut seeded = Vec::new();
    build_initial_resident_set(&pool, 10, 2, &mut seeded);
    let expected = vec![0, 1];
    assert_eq!(
        seeded, expected,
        "current-iteration slot 0 seeded despite stale last_active_iter"
    );
}
/// A slot whose `active` flag is false must not be seeded, even when its
/// `last_active_iter` equals the current iteration.
#[test]
fn excludes_inactive_slots() {
    // (active, iteration_generated, last_active_iter) per slot; slot 1 is inactive.
    let specs = [(true, 1, 10), (false, 1, 10), (true, 1, 10)];
    let pool = seed_pool(&specs);
    let mut seeded = Vec::new();
    build_initial_resident_set(&pool, 10, 5, &mut seeded);
    let expected = vec![0, 2];
    assert_eq!(seeded, expected, "inactive slot 1 excluded");
}
/// Two calls over the same pool must yield the same slot list, and that list must be
/// strictly ascending.
#[test]
fn result_is_ascending_and_deterministic() {
    // (active, iteration_generated, last_active_iter) per slot; slot 2 is inactive.
    let specs = [
        (true, 1, 10),
        (true, 1, 9),
        (false, 1, 10),
        (true, 1, 8),
        (true, 1, 10),
    ];
    let pool = seed_pool(&specs);
    let mut a = Vec::new();
    build_initial_resident_set(&pool, 10, 5, &mut a);
    let mut b = Vec::new();
    build_initial_resident_set(&pool, 10, 5, &mut b);
    assert_eq!(a, b, "two calls on identical metadata must match");
    let strictly_ascending = a.windows(2).all(|pair| pair[0] < pair[1]);
    assert!(
        strictly_ascending,
        "result must be strictly ascending, got {a:?}"
    );
    assert_eq!(a, vec![0, 1, 3, 4]);
}
/// At iteration 12 with `k2 = 3`, cuts generated long ago (iteration 1) are seeded
/// when their `last_active_iter` is recent (11 and 9) and skipped when it is not (1).
#[test]
fn seeds_old_generation_cut_that_bound_recently() {
    // (active, iteration_generated, last_active_iter) per slot.
    let specs = [(true, 1, 11), (true, 1, 1), (true, 1, 9)];
    let pool = seed_pool(&specs);
    let mut seeded = Vec::new();
    build_initial_resident_set(&pool, 12, 3, &mut seeded);
    let expected = vec![0, 2];
    assert_eq!(
        seeded, expected,
        "old-generation cuts that bound within the last k2 iterations are \
         seeded by binding recency; the never-re-bound cut is excluded"
    );
}
/// Boundary case `k2 = 0` at iteration 10: slots last active at 10 or 11 and the slot
/// generated at iteration 10 are expected; the slot last active at 9 is not.
#[test]
fn k2_zero_window_boundary() {
    // (active, iteration_generated, last_active_iter) per slot.
    let specs = [(true, 1, 10), (true, 1, 9), (true, 1, 11), (true, 10, 2)];
    let pool = seed_pool(&specs);
    let mut seeded = Vec::new();
    build_initial_resident_set(&pool, 10, 0, &mut seeded);
    let expected = vec![0, 2, 3];
    assert_eq!(seeded, expected);
}
}