use cobre_solver::{Basis, ProfiledSolver, SolverInterface};
use crate::backward::StagedCut;
use crate::dcs::DcsSolveScratch;
use crate::lp_builder::PatchBuffer;
use crate::risk_measure::{BackwardOutcome, RiskMeasureScratch};
/// A solver basis together with the metadata recorded alongside it.
///
/// All four fields are written by `CapturedBasis::to_broadcast_payload` and
/// restored by `CapturedBasis::try_from_broadcast_payload`.
#[derive(Clone, Debug)]
pub struct CapturedBasis {
/// Column and row status codes of the captured basis.
pub basis: Basis,
/// Row count stored with the basis. Presumably the number of non-cut rows in
/// the LP at capture time — TODO confirm against the capture site.
pub base_row_count: usize,
/// Slot indices stored with the basis. Presumably one entry per cut row that
/// followed the base rows at capture time — TODO confirm.
pub cut_row_slots: Vec<u32>,
/// State vector stored with the basis.
pub state_at_capture: Vec<f64>,
}
/// Wire-format version written as the second `i32` of every populated entry by
/// `CapturedBasis::to_broadcast_payload`. The decoder,
/// `CapturedBasis::try_from_broadcast_payload`, rejects any other value.
pub const BASIS_BROADCAST_WIRE_VERSION: i32 = 1;
impl CapturedBasis {
/// Creates an empty capture for an LP with `num_cols` columns and `num_rows` rows.
///
/// `basis` is built with `Basis::new(num_cols, num_rows)` and `base_row_count`
/// is stored as given. `cut_row_slots` and `state_at_capture` start empty, with
/// capacity reserved for `cut_slot_capacity` and `n_state` elements respectively.
#[must_use]
pub fn new(
    num_cols: usize,
    num_rows: usize,
    base_row_count: usize,
    cut_slot_capacity: usize,
    n_state: usize,
) -> Self {
    let basis = Basis::new(num_cols, num_rows);
    let cut_row_slots: Vec<u32> = Vec::with_capacity(cut_slot_capacity);
    let state_at_capture: Vec<f64> = Vec::with_capacity(n_state);
    Self {
        basis,
        base_row_count,
        cut_row_slots,
        state_at_capture,
    }
}
pub fn clear(&mut self) {
self.cut_row_slots.clear();
self.state_at_capture.clear();
}
/// Appends this capture to a pair of broadcast buffers.
///
/// Values appended to `i32_buf`, in order:
///
/// 1. presence sentinel `1`
/// 2. [`BASIS_BROADCAST_WIRE_VERSION`]
/// 3. `col_status` length, `row_status` length, `base_row_count`,
///    `cut_row_slots` length, `state_at_capture` length
/// 4. the `col_status` values, then the `row_status` values
/// 5. each `cut_row_slots` entry cast to `i32`
///
/// `state_at_capture` is appended to `f64_buf`. Nothing already in either
/// buffer is modified.
///
/// All integer conversions use `as`, so a length or count that does not fit in
/// `i32` is truncated/wrapped rather than reported.
// The casts are intentional: the wire format is a flat `i32` stream.
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
pub fn to_broadcast_payload(&self, i32_buf: &mut Vec<i32>, f64_buf: &mut Vec<f64>) {
    let basis = &self.basis;
    let header = [
        1_i32,
        BASIS_BROADCAST_WIRE_VERSION,
        basis.col_status.len() as i32,
        basis.row_status.len() as i32,
        self.base_row_count as i32,
        self.cut_row_slots.len() as i32,
        self.state_at_capture.len() as i32,
    ];
    i32_buf.extend_from_slice(&header);
    i32_buf.extend_from_slice(&basis.col_status);
    i32_buf.extend_from_slice(&basis.row_status);
    i32_buf.extend(self.cut_row_slots.iter().map(|&slot| slot as i32));
    f64_buf.extend_from_slice(&self.state_at_capture);
}
/// Decodes one entry from a pair of broadcast buffers, advancing both cursors.
///
/// This reads the layout written by `CapturedBasis::to_broadcast_payload`:
/// a presence sentinel, the wire version, five header values, then the
/// `col_status`, `row_status` and `cut_row_slots` sections from `i32_buf`, and
/// `state_at_capture` from `f64_buf`.
///
/// * A sentinel of `0` means "no basis": only the sentinel is consumed and
///   `Ok(None)` is returned. Any non-zero sentinel is treated as "present".
/// * `stage` is used only to label error messages.
/// * `cut_row_slots` values are reinterpreted bit-for-bit from `i32` to `u32`,
///   mirroring the `u32 as i32` cast on the encoding side, so negative `i32`
///   slot values are accepted.
///
/// # Errors
///
/// Returns `SddpError::Validation` when
///
/// * either buffer ends before the data announced by the header,
/// * the wire version differs from [`BASIS_BROADCAST_WIRE_VERSION`], or
/// * a header length/count is negative (corrupt or mismatched payload).
///
/// On error the cursors may already have been advanced past the part of the
/// entry that was read successfully.
// `cast_sign_loss`: slot values are deliberately bit-reinterpreted (see above).
// The other two lints are kept from the original attribute.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss
)]
pub fn try_from_broadcast_payload(
    stage: usize,
    i32_buf: &[i32],
    i32_cursor: &mut usize,
    f64_buf: &[f64],
    f64_cursor: &mut usize,
) -> Result<Option<Self>, crate::SddpError> {
    /// Returns `start..start + len` when that range lies inside a buffer of
    /// `buf_len` elements. Arithmetic overflow counts as out of range.
    fn span(start: usize, len: usize, buf_len: usize) -> Option<std::ops::Range<usize>> {
        let end = start.checked_add(len)?;
        if end <= buf_len {
            Some(start..end)
        } else {
            None
        }
    }

    /// Converts a raw header value to `usize`, rejecting negative values
    /// instead of letting them wrap to a huge length.
    fn header_field(raw: i32, field: &str, stage: usize) -> Result<usize, crate::SddpError> {
        usize::try_from(raw).map_err(|_| {
            crate::SddpError::Validation(format!(
                "try_from_broadcast_payload: negative {field} ({raw}) at stage {stage}"
            ))
        })
    }

    /// Borrows the next `len` values of `buf` and moves `cursor` past them.
    fn take_i32<'a>(
        buf: &'a [i32],
        cursor: &mut usize,
        len: usize,
        what: &str,
        stage: usize,
    ) -> Result<&'a [i32], crate::SddpError> {
        match span(*cursor, len, buf.len()) {
            Some(range) => {
                *cursor = range.end;
                Ok(&buf[range])
            }
            None => Err(crate::SddpError::Validation(format!(
                "try_from_broadcast_payload: buffer truncated reading {what} at stage \
                 {stage} (need {len}, have {})",
                buf.len().saturating_sub(*cursor)
            ))),
        }
    }

    // Presence sentinel: 0 = absent, anything else = present.
    let sentinel = match i32_buf.get(*i32_cursor) {
        Some(&value) => value,
        None => {
            return Err(crate::SddpError::Validation(format!(
                "try_from_broadcast_payload: buffer truncated at stage {stage} \
                 (pos={}, len={})",
                *i32_cursor,
                i32_buf.len()
            )));
        }
    };
    *i32_cursor += 1;
    if sentinel == 0 {
        return Ok(None);
    }

    // Wire version.
    let version = match i32_buf.get(*i32_cursor) {
        Some(&value) => value,
        None => {
            return Err(crate::SddpError::Validation(format!(
                "try_from_broadcast_payload: buffer truncated reading version at stage {stage}"
            )));
        }
    };
    *i32_cursor += 1;
    if version != BASIS_BROADCAST_WIRE_VERSION {
        return Err(crate::SddpError::Validation(format!(
            "try_from_broadcast_payload: unsupported wire version {version} at stage \
             {stage} (expected {BASIS_BROADCAST_WIRE_VERSION})"
        )));
    }

    // Five header values; the cursor moves past them before they are validated.
    let header = match span(*i32_cursor, 5, i32_buf.len()) {
        Some(range) => &i32_buf[range],
        None => {
            return Err(crate::SddpError::Validation(format!(
                "try_from_broadcast_payload: buffer truncated reading lengths at stage {stage}"
            )));
        }
    };
    *i32_cursor += 5;
    let col_len = header_field(header[0], "col_status length", stage)?;
    let row_len = header_field(header[1], "row_status length", stage)?;
    let base_row_count = header_field(header[2], "base_row_count", stage)?;
    let cut_slot_count = header_field(header[3], "cut_row_slots length", stage)?;
    let state_len = header_field(header[4], "state_at_capture length", stage)?;

    // Variable-length i32 sections, in the order the encoder wrote them.
    let col_status = take_i32(i32_buf, i32_cursor, col_len, "col_status", stage)?.to_vec();
    let row_status = take_i32(i32_buf, i32_cursor, row_len, "row_status", stage)?.to_vec();
    let cut_row_slots: Vec<u32> =
        take_i32(i32_buf, i32_cursor, cut_slot_count, "cut_row_slots", stage)?
            .iter()
            .map(|&raw| raw as u32)
            .collect();

    // State vector from the f64 stream.
    let state_range = match span(*f64_cursor, state_len, f64_buf.len()) {
        Some(range) => range,
        None => {
            return Err(crate::SddpError::Validation(format!(
                "try_from_broadcast_payload: f64 buffer truncated reading state_at_capture at \
                 stage {stage} (need {state_len}, have {})",
                f64_buf.len().saturating_sub(*f64_cursor)
            )));
        }
    };
    *f64_cursor = state_range.end;
    let state_at_capture = f64_buf[state_range].to_vec();

    Ok(Some(Self {
        basis: cobre_solver::Basis {
            col_status,
            row_status,
        },
        base_row_count,
        cut_row_slots,
        state_at_capture,
    }))
}
}
/// Dimensions used to pre-size the buffers of a `SolverWorkspace`.
///
/// The per-field notes below describe how the constructors in this file use
/// each value; they do not describe what the values mean in the model.
#[derive(Clone, Copy, Debug, Default)]
pub struct WorkspaceSizing {
/// Sizes the per-hydro scratch buffers in `ScratchBuffers::new` and is passed
/// to `PatchBuffer::new` by `WorkspacePool`.
pub hydro_count: usize,
/// Multiplied by `hydro_count` for the `lag_matrix_buf` capacity; also passed
/// to `PatchBuffer::new`.
pub max_par_order: usize,
/// Multiplied by `max_blocks` for the `load_rhs_buf` capacity; also passed to
/// `PatchBuffer::new`.
pub n_load_buses: usize,
/// See `n_load_buses`.
pub max_blocks: usize,
/// When greater than zero, `downstream_accumulator` (`hydro_count` zeros) and
/// `downstream_completed_lags` (`hydro_count * downstream_par_order` zeros) are
/// allocated; otherwise both start empty.
pub downstream_par_order: usize,
/// Number of `BackwardOutcome` entries pre-built by `BackwardAccumulators::new`.
pub max_openings: usize,
/// Initial length of `slot_increments`, `metadata_sync_contribution` and
/// `recon_slot_lookup`, and initial capacity of `dcs_initial_resident`; also
/// passed to `DcsSolveScratch::reserve`.
pub initial_pool_capacity: usize,
/// Length of the coefficient vectors built by `BackwardAccumulators::new`.
/// The capacity of `SolverWorkspace::current_state` comes from a separate
/// `n_state` constructor argument, not from this field.
pub n_state: usize,
/// Capacity of `trajectory_costs_buf`.
pub max_local_fwd: usize,
/// Capacity of `perm_scratch` (at least 1 is reserved).
pub total_forward_passes: usize,
/// Capacity of `raw_noise_buf`.
pub noise_dim: usize,
/// Multiplied by `k_max` for the `anticipated_state_buf` capacity; also passed
/// to `PatchBuffer::new`.
pub n_anticipated: usize,
/// See `n_anticipated`.
pub k_max: usize,
}
/// Per-workspace accumulators and scratch storage, built by
/// `BackwardAccumulators::new`.
///
/// NOTE(review): judging by the names this is backward-pass state, but how each
/// buffer is consumed is not visible in this file.
#[derive(Default)]
pub(crate) struct BackwardAccumulators {
/// One pre-built entry per opening (`max_openings` entries after `new`).
pub(crate) outcomes: Vec<BackwardOutcome>,
/// Starts as `initial_pool_capacity` zeros.
pub(crate) slot_increments: Vec<u64>,
/// Starts as `n_state` zeros.
pub(crate) agg_coefficients: Vec<f64>,
pub(crate) agg_arena: Vec<f64>,
/// Starts as `initial_pool_capacity` zeros.
pub(crate) metadata_sync_contribution: Vec<u64>,
pub(crate) per_opening_stats: Vec<crate::solver_stats::SolverStatsDelta>,
pub(crate) state_duals_buf: Vec<f64>,
pub(crate) cut_duals_buf: Vec<f64>,
pub(crate) staged_cuts_buf: Vec<StagedCut>,
pub(crate) risk_scratch: RiskMeasureScratch,
/// Reserved with `(n_state, initial_pool_capacity)` in `new`.
pub(crate) dcs_solve: DcsSolveScratch,
/// Starts empty with capacity `initial_pool_capacity`.
pub(crate) dcs_initial_resident: Vec<u32>,
pub(crate) stats_before_buf: cobre_solver::SolverStatistics,
pub(crate) stats_after_buf: cobre_solver::SolverStatistics,
}
impl BackwardAccumulators {
pub(crate) fn new(max_openings: usize, initial_pool_capacity: usize, n_state: usize) -> Self {
let outcomes = (0..max_openings)
.map(|_| BackwardOutcome {
intercept: 0.0,
coefficients: vec![0.0_f64; n_state],
objective_value: 0.0,
})
.collect();
let mut dcs_solve = DcsSolveScratch::default();
dcs_solve.reserve(n_state, initial_pool_capacity);
Self {
outcomes,
slot_increments: vec![0u64; initial_pool_capacity],
agg_coefficients: vec![0.0_f64; n_state],
agg_arena: Vec::new(),
metadata_sync_contribution: vec![0u64; initial_pool_capacity],
per_opening_stats: Vec::new(),
state_duals_buf: Vec::new(),
cut_duals_buf: Vec::new(),
staged_cuts_buf: Vec::new(),
risk_scratch: RiskMeasureScratch::new(),
dcs_solve,
dcs_initial_resident: Vec::with_capacity(initial_pool_capacity),
stats_before_buf: cobre_solver::SolverStatistics::default(),
stats_after_buf: cobre_solver::SolverStatistics::default(),
}
}
}
/// Reusable scratch storage owned by one `SolverWorkspace`.
///
/// Initial lengths and capacities are set by `ScratchBuffers::new` from a
/// `WorkspaceSizing`; buffers that `new` creates with `Vec::new()` start with
/// no allocation at all.
// Many fields share the `_buf` suffix on purpose, hence the lint exemption.
#[allow(clippy::struct_field_names)]
pub(crate) struct ScratchBuffers {
pub(crate) noise_buf: Vec<f64>,
pub(crate) inflow_m3s_buf: Vec<f64>,
pub(crate) lag_matrix_buf: Vec<f64>,
pub(crate) par_inflow_buf: Vec<f64>,
pub(crate) eta_floor_buf: Vec<f64>,
/// Starts as `hydro_count` zeros.
pub(crate) zero_targets_buf: Vec<f64>,
pub(crate) ncs_col_upper_buf: Vec<f64>,
pub(crate) ncs_col_lower_buf: Vec<f64>,
pub(crate) ncs_col_indices_buf: Vec<usize>,
pub(crate) load_rhs_buf: Vec<f64>,
pub(crate) row_lower_buf: Vec<f64>,
pub(crate) z_inflow_rhs_buf: Vec<f64>,
pub(crate) effective_eta_buf: Vec<f64>,
pub(crate) unscaled_primal: Vec<f64>,
pub(crate) unscaled_dual: Vec<f64>,
/// Starts as `hydro_count` zeros.
pub(crate) lag_accumulator: Vec<f64>,
pub(crate) lag_weight_accum: f64,
/// `hydro_count` zeros when `downstream_par_order > 0`, otherwise empty.
pub(crate) downstream_accumulator: Vec<f64>,
pub(crate) downstream_weight_accum: f64,
/// `hydro_count * downstream_par_order` zeros when `downstream_par_order > 0`,
/// otherwise empty.
pub(crate) downstream_completed_lags: Vec<f64>,
pub(crate) downstream_n_completed: usize,
/// Starts as `initial_pool_capacity` entries, all `None`.
pub(crate) recon_slot_lookup: Vec<Option<u32>>,
pub(crate) trajectory_costs_buf: Vec<f64>,
pub(crate) raw_noise_buf: Vec<f64>,
pub(crate) perm_scratch: Vec<usize>,
pub(crate) anticipated_state_buf: Vec<f64>,
}
/// One solver instance bundled with the buffers that travel with it.
///
/// `WorkspacePool` creates one of these per thread.
pub struct SolverWorkspace<S: SolverInterface> {
/// Rank value supplied at construction; presumably the owning process's rank
/// — TODO confirm.
pub rank: i32,
/// Worker identifier supplied at construction; `WorkspacePool` assigns the
/// workspace's index within the pool.
pub worker_id: i32,
/// The solver, wrapped in `ProfiledSolver`.
pub solver: ProfiledSolver<S>,
pub patch_buf: PatchBuffer,
/// Starts empty with capacity for `n_state` values.
pub current_state: Vec<f64>,
pub(crate) scratch: ScratchBuffers,
/// Starts as `Basis::new(0, 0)`; `WorkspacePool::resize_scratch_bases`
/// replaces it with a larger one.
pub(crate) scratch_basis: Basis,
pub(crate) backward_accum: BackwardAccumulators,
pub worker_timing_buf: cobre_core::WorkerPhaseTimings,
}
impl<S: SolverInterface> SolverWorkspace<S> {
#[must_use]
pub fn new(
rank: i32,
worker_id: i32,
solver: S,
patch_buf: PatchBuffer,
n_state: usize,
sizing: WorkspaceSizing,
) -> Self {
Self {
rank,
worker_id,
solver: ProfiledSolver::new(solver),
patch_buf,
current_state: Vec::with_capacity(n_state),
scratch: ScratchBuffers::new(sizing),
scratch_basis: Basis::new(0, 0),
backward_accum: BackwardAccumulators::new(
sizing.max_openings,
sizing.initial_pool_capacity,
sizing.n_state,
),
worker_timing_buf: cobre_core::WorkerPhaseTimings::default(),
}
}
}
impl ScratchBuffers {
pub(crate) fn new(s: WorkspaceSizing) -> Self {
let WorkspaceSizing {
hydro_count,
max_par_order,
n_load_buses,
max_blocks,
downstream_par_order,
initial_pool_capacity,
max_local_fwd,
total_forward_passes,
noise_dim,
n_anticipated,
k_max,
..
} = s;
Self {
noise_buf: Vec::with_capacity(hydro_count),
inflow_m3s_buf: Vec::with_capacity(hydro_count),
lag_matrix_buf: Vec::with_capacity(max_par_order * hydro_count),
par_inflow_buf: Vec::with_capacity(hydro_count),
eta_floor_buf: Vec::with_capacity(hydro_count),
zero_targets_buf: vec![0.0_f64; hydro_count],
ncs_col_upper_buf: Vec::new(),
ncs_col_lower_buf: Vec::new(),
ncs_col_indices_buf: Vec::new(),
load_rhs_buf: Vec::with_capacity(n_load_buses * max_blocks),
row_lower_buf: Vec::new(),
z_inflow_rhs_buf: Vec::with_capacity(hydro_count),
effective_eta_buf: Vec::with_capacity(hydro_count),
unscaled_primal: Vec::new(),
unscaled_dual: Vec::new(),
lag_accumulator: vec![0.0_f64; hydro_count],
lag_weight_accum: 0.0,
downstream_accumulator: if downstream_par_order > 0 {
vec![0.0_f64; hydro_count]
} else {
Vec::new()
},
downstream_weight_accum: 0.0,
downstream_completed_lags: if downstream_par_order > 0 {
vec![0.0_f64; hydro_count * downstream_par_order]
} else {
Vec::new()
},
downstream_n_completed: 0,
recon_slot_lookup: vec![None; initial_pool_capacity],
trajectory_costs_buf: Vec::with_capacity(max_local_fwd),
raw_noise_buf: Vec::with_capacity(noise_dim),
perm_scratch: Vec::with_capacity(total_forward_passes.max(1)),
anticipated_state_buf: Vec::with_capacity(n_anticipated * k_max),
}
}
}
/// A set of `SolverWorkspace`s, one per thread requested at construction.
pub struct WorkspacePool<S: SolverInterface> {
/// The workspaces, in creation order; entry `i` was created with `worker_id == i`.
pub workspaces: Vec<SolverWorkspace<S>>,
}
impl<S: SolverInterface> WorkspacePool<S> {
#[must_use]
#[allow(clippy::expect_used)]
pub fn new(
rank: i32,
n_threads: usize,
n_state: usize,
sizing: WorkspaceSizing,
solver_factory: impl Fn() -> S,
) -> Self {
let workspaces = (0..n_threads)
.map(|idx| {
let worker_id =
i32::try_from(idx).expect("worker_id fits in i32 (rayon pools are small)");
SolverWorkspace {
rank,
worker_id,
solver: ProfiledSolver::new(solver_factory()),
patch_buf: PatchBuffer::new(
sizing.hydro_count,
sizing.max_par_order,
sizing.n_load_buses,
sizing.max_blocks,
sizing.n_anticipated,
sizing.k_max,
),
current_state: Vec::with_capacity(n_state),
scratch: ScratchBuffers::new(sizing),
scratch_basis: Basis::new(0, 0),
backward_accum: BackwardAccumulators::new(
sizing.max_openings,
sizing.initial_pool_capacity,
sizing.n_state,
),
worker_timing_buf: cobre_core::WorkerPhaseTimings::default(),
}
})
.collect();
Self { workspaces }
}
#[allow(clippy::expect_used)]
pub fn try_new<E>(
rank: i32,
n_threads: usize,
n_state: usize,
sizing: WorkspaceSizing,
solver_factory: impl Fn() -> Result<S, E>,
) -> Result<Self, E> {
let mut workspaces = Vec::with_capacity(n_threads);
for idx in 0..n_threads {
let worker_id =
i32::try_from(idx).expect("worker_id fits in i32 (rayon pools are small)");
workspaces.push(SolverWorkspace {
rank,
worker_id,
solver: ProfiledSolver::new(solver_factory()?),
patch_buf: PatchBuffer::new(
sizing.hydro_count,
sizing.max_par_order,
sizing.n_load_buses,
sizing.max_blocks,
sizing.n_anticipated,
sizing.k_max,
),
current_state: Vec::with_capacity(n_state),
scratch: ScratchBuffers::new(sizing),
scratch_basis: Basis::new(0, 0),
backward_accum: BackwardAccumulators::new(
sizing.max_openings,
sizing.initial_pool_capacity,
sizing.n_state,
),
worker_timing_buf: cobre_core::WorkerPhaseTimings::default(),
});
}
Ok(Self { workspaces })
}
/// Replaces every workspace's `scratch_basis` with a fresh
/// `Basis::new(max_cols, max_rows)`; the previous contents are discarded.
pub(crate) fn resize_scratch_bases(&mut self, max_cols: usize, max_rows: usize) {
    self.workspaces
        .iter_mut()
        .for_each(|workspace| workspace.scratch_basis = Basis::new(max_cols, max_rows));
}
}
/// Flat table of optional captured bases, addressed by `(scenario, stage)`.
pub struct BasisStore {
/// Scenario-major storage: slot `(scenario, stage)` lives at index
/// `scenario * num_stages + stage`.
bases: Vec<Option<CapturedBasis>>,
/// Number of stage slots per scenario.
num_stages: usize,
}
impl BasisStore {
    /// Creates a store with `num_scenarios * num_stages` slots, all `None`.
    #[must_use]
    pub fn new(num_scenarios: usize, num_stages: usize) -> Self {
        let slot_count = num_scenarios * num_stages;
        let bases = (0..slot_count).map(|_| None).collect();
        Self { bases, num_stages }
    }

    /// Number of scenarios, derived as slot count divided by `num_stages`.
    ///
    /// Returns `0` when `num_stages` is `0`, whatever scenario count was passed
    /// to [`BasisStore::new`].
    #[must_use]
    pub fn num_scenarios(&self) -> usize {
        match self.num_stages {
            0 => 0,
            stages => self.bases.len() / stages,
        }
    }

    /// Number of stage slots per scenario.
    #[must_use]
    pub fn num_stages(&self) -> usize {
        self.num_stages
    }

    /// Returns the basis stored for `(scenario, stage)`, if any.
    ///
    /// # Panics
    ///
    /// Panics if `scenario * num_stages + stage` is outside the store. `stage`
    /// is not checked against `num_stages` on its own, so an over-large `stage`
    /// whose flat index is still in range addresses a later scenario's slot.
    #[must_use]
    pub fn get(&self, scenario: usize, stage: usize) -> Option<&CapturedBasis> {
        let row_start = scenario * self.num_stages;
        self.bases[row_start + stage].as_ref()
    }

    /// Returns the slot for `(scenario, stage)` so the caller can fill or clear it.
    ///
    /// # Panics
    ///
    /// Same indexing rules as [`BasisStore::get`].
    pub fn get_mut(&mut self, scenario: usize, stage: usize) -> &mut Option<CapturedBasis> {
        let row_start = scenario * self.num_stages;
        &mut self.bases[row_start + stage]
    }

    /// Splits the store into `n_workers` disjoint mutable views.
    ///
    /// Worker `w` receives as many scenarios as the range returned by
    /// `crate::solve::partition(num_scenarios, n_workers, w)` spans. The views
    /// are carved off the front of the storage in worker order, and each view's
    /// scenario offset is the running sum of the preceding counts. This
    /// presumably relies on `partition` returning contiguous, ordered ranges —
    /// TODO confirm.
    ///
    /// # Panics
    ///
    /// Panics in debug builds if `n_workers == 0`, and in any build if the
    /// partition asks for more slots than remain.
    #[must_use]
    pub fn split_workers_mut(&mut self, n_workers: usize) -> Vec<BasisStoreSliceMut<'_>> {
        debug_assert!(n_workers > 0, "n_workers must be > 0");
        let scenario_total = self.num_scenarios();
        let stages = self.num_stages;
        let mut views = Vec::with_capacity(n_workers);
        let mut remaining: &mut [Option<CapturedBasis>] = self.bases.as_mut_slice();
        let mut first_scenario = 0usize;
        for worker in 0..n_workers {
            let (lo, hi) = crate::solve::partition(scenario_total, n_workers, worker);
            let owned = hi - lo;
            // Take the tail out of `remaining` so both halves keep the full borrow.
            let (mine, rest) = std::mem::take(&mut remaining).split_at_mut(owned * stages);
            remaining = rest;
            views.push(BasisStoreSliceMut {
                bases: mine,
                scenario_offset: first_scenario,
                num_stages: stages,
            });
            first_scenario += owned;
        }
        views
    }
}
/// Mutable view over a contiguous run of scenarios of a `BasisStore`, as
/// produced by `BasisStore::split_workers_mut`.
pub struct BasisStoreSliceMut<'a> {
/// The slots of this view, scenario-major like the parent store.
bases: &'a mut [Option<CapturedBasis>],
/// Global index of the first scenario in this view; subtracted from the
/// `scenario` argument of the accessors.
scenario_offset: usize,
/// Number of stage slots per scenario.
num_stages: usize,
}
impl BasisStoreSliceMut<'_> {
    /// Returns the basis stored for the global `(scenario, stage)`, if any.
    ///
    /// `scenario` is a global scenario index; the view's offset is subtracted
    /// before indexing.
    ///
    /// # Panics
    ///
    /// Panics if the resulting index is outside this view. For `scenario` below
    /// the view's offset, the subtraction overflows (a panic in builds with
    /// overflow checks). `stage` is not checked against `num_stages` on its own.
    #[must_use]
    pub fn get(&self, scenario: usize, stage: usize) -> Option<&CapturedBasis> {
        let local_scenario = scenario - self.scenario_offset;
        let row_start = local_scenario * self.num_stages;
        self.bases[row_start + stage].as_ref()
    }

    /// Returns the slot for the global `(scenario, stage)` so it can be filled
    /// or cleared.
    ///
    /// # Panics
    ///
    /// Same indexing rules as [`BasisStoreSliceMut::get`].
    pub fn get_mut(&mut self, scenario: usize, stage: usize) -> &mut Option<CapturedBasis> {
        let local_scenario = scenario - self.scenario_offset;
        let row_start = local_scenario * self.num_stages;
        &mut self.bases[row_start + stage]
    }
}
#[cfg(test)]
mod tests {
use super::{
BasisStore, CapturedBasis, ScratchBuffers, SolverWorkspace, WorkspacePool, WorkspaceSizing,
};
use cobre_solver::{
Basis, SolutionView, SolverError, SolverInterface, SolverStatistics,
types::{RowBatch, StageTemplate},
};
struct MockSolver;
impl SolverInterface for MockSolver {
type Profile = cobre_solver::ActiveProfile;
fn apply_profile(&mut self, _profile: &cobre_solver::ActiveProfile) {}
fn solver_name_version(&self) -> String {
"MockSolver 0.0.0".to_string()
}
fn load_model(&mut self, _t: &StageTemplate) {}
fn add_rows(&mut self, _r: &RowBatch) {}
fn set_row_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
fn set_col_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
fn solve(&mut self, _basis: Option<&Basis>) -> Result<SolutionView<'_>, SolverError> {
Err(SolverError::InternalError {
message: "mock".into(),
error_code: None,
})
}
fn get_basis(&mut self, _out: &mut Basis) {}
fn statistics(&self) -> SolverStatistics {
SolverStatistics::default()
}
fn statistics_into(&self, out: &mut SolverStatistics) {
out.copy_from(&SolverStatistics::default());
}
fn name(&self) -> &'static str {
"Mock"
}
}
fn assert_send<T: Send>() {}
#[test]
fn test_workspace_send_bound() {
assert_send::<SolverWorkspace<MockSolver>>();
}
fn sizing(
hydro_count: usize,
max_par_order: usize,
downstream_par_order: usize,
) -> WorkspaceSizing {
WorkspaceSizing {
hydro_count,
max_par_order,
n_load_buses: 0,
max_blocks: 0,
downstream_par_order,
..WorkspaceSizing::default()
}
}
#[test]
fn test_workspace_pool_size() {
let pool = WorkspacePool::new(0, 4, 9, sizing(3, 2, 0), || MockSolver);
assert_eq!(pool.workspaces.len(), 4);
}
#[test]
fn test_workspace_buffer_dimensions() {
let pool = WorkspacePool::new(0, 4, 9, sizing(3, 2, 0), || MockSolver);
for ws in &pool.workspaces {
assert_eq!(ws.patch_buf.indices.len(), 6, "patch_buf length");
assert_eq!(ws.current_state.capacity(), 9, "current_state capacity");
assert_eq!(ws.current_state.len(), 0, "current_state starts empty");
}
}
#[test]
fn test_workspace_pool_zero_threads() {
let pool = WorkspacePool::new(0, 0, 9, sizing(3, 2, 0), || MockSolver);
assert_eq!(pool.workspaces.len(), 0);
}
#[test]
fn test_workspace_pool_single_thread() {
let pool = WorkspacePool::new(0, 1, 0, sizing(0, 0, 0), || MockSolver);
assert_eq!(pool.workspaces.len(), 1);
assert_eq!(pool.workspaces[0].patch_buf.indices.len(), 0);
}
#[test]
fn test_workspace_pool_each_solver_independent() {
let n = 6;
let pool = WorkspacePool::new(0, n, 1, sizing(1, 0, 0), || MockSolver);
assert_eq!(pool.workspaces.len(), n);
}
#[test]
fn test_scratch_buffers_zero_downstream_par_order_empty_buffers() {
let scratch = ScratchBuffers::new(WorkspaceSizing {
hydro_count: 5,
max_par_order: 2,
n_load_buses: 0,
max_blocks: 1,
downstream_par_order: 0,
..WorkspaceSizing::default()
});
assert!(
scratch.downstream_accumulator.is_empty(),
"downstream_accumulator must be empty when downstream_par_order=0"
);
assert!(
scratch.downstream_completed_lags.is_empty(),
"downstream_completed_lags must be empty when downstream_par_order=0"
);
assert_eq!(
scratch.downstream_weight_accum, 0.0,
"downstream_weight_accum must be 0.0"
);
assert_eq!(
scratch.downstream_n_completed, 0,
"downstream_n_completed must be 0"
);
}
#[test]
fn test_scratch_buffers_nonzero_downstream_par_order_allocates_correctly() {
let scratch = ScratchBuffers::new(WorkspaceSizing {
hydro_count: 3,
max_par_order: 2,
n_load_buses: 0,
max_blocks: 1,
downstream_par_order: 2,
..WorkspaceSizing::default()
});
assert_eq!(
scratch.downstream_accumulator.len(),
3,
"downstream_accumulator.len() must equal hydro_count"
);
assert_eq!(
scratch.downstream_completed_lags.len(),
6,
"downstream_completed_lags.len() must equal hydro_count * downstream_par_order"
);
assert!(
scratch.downstream_accumulator.iter().all(|&v| v == 0.0),
"downstream_accumulator must be initialized to 0.0"
);
assert!(
scratch.downstream_completed_lags.iter().all(|&v| v == 0.0),
"downstream_completed_lags must be initialized to 0.0"
);
assert_eq!(scratch.downstream_weight_accum, 0.0);
assert_eq!(scratch.downstream_n_completed, 0);
}
#[test]
fn test_workspace_pool_propagates_downstream_par_order() {
let pool = WorkspacePool::new(
0,
2,
6,
WorkspaceSizing {
hydro_count: 3,
max_par_order: 2,
n_load_buses: 0,
max_blocks: 1,
downstream_par_order: 2,
..WorkspaceSizing::default()
},
|| MockSolver,
);
for ws in &pool.workspaces {
assert_eq!(
ws.scratch.downstream_accumulator.len(),
3,
"downstream_accumulator.len() per workspace"
);
assert_eq!(
ws.scratch.downstream_completed_lags.len(),
6,
"downstream_completed_lags.len() per workspace"
);
assert_eq!(ws.scratch.downstream_weight_accum, 0.0);
assert_eq!(ws.scratch.downstream_n_completed, 0);
}
}
#[test]
fn basis_store_new_all_none() {
let store = BasisStore::new(3, 5);
assert_eq!(store.num_scenarios(), 3);
assert_eq!(store.num_stages(), 5);
for s in 0..3 {
for t in 0..5 {
assert!(
store.get(s, t).is_none(),
"slot [{s}][{t}] must start as None"
);
}
}
}
#[test]
fn basis_store_get_mut_set_and_retrieve() {
let mut store = BasisStore::new(2, 3);
*store.get_mut(1, 2) = Some(CapturedBasis::new(4, 2, 0, 0, 0));
assert!(store.get(1, 2).is_some());
assert!(store.get(0, 0).is_none());
assert!(store.get(1, 0).is_none());
}
#[test]
fn basis_store_zero_scenarios() {
let store = BasisStore::new(0, 5);
assert_eq!(store.num_scenarios(), 0);
assert_eq!(store.num_stages(), 5);
}
#[test]
fn basis_store_zero_stages() {
let store = BasisStore::new(3, 0);
assert_eq!(store.num_scenarios(), 0);
assert_eq!(store.num_stages(), 0);
}
#[test]
fn basis_store_split_workers_mut_disjoint_writes() {
let mut store = BasisStore::new(4, 3);
let mut slices = store.split_workers_mut(2);
*slices[0].get_mut(0, 1) = Some(CapturedBasis::new(2, 1, 0, 0, 0));
*slices[1].get_mut(3, 2) = Some(CapturedBasis::new(2, 1, 0, 0, 0));
drop(slices);
assert!(
store.get(0, 1).is_some(),
"scenario 0 stage 1 must be populated"
);
assert!(
store.get(3, 2).is_some(),
"scenario 3 stage 2 must be populated"
);
assert!(store.get(0, 0).is_none());
assert!(store.get(3, 0).is_none());
}
#[test]
fn basis_store_split_single_worker() {
let mut store = BasisStore::new(3, 2);
let mut slices = store.split_workers_mut(1);
*slices[0].get_mut(2, 1) = Some(CapturedBasis::new(1, 0, 0, 0, 0));
drop(slices);
assert!(store.get(2, 1).is_some());
}
#[test]
fn basis_store_split_more_workers_than_scenarios() {
let mut store = BasisStore::new(2, 3);
let slices = store.split_workers_mut(4);
assert_eq!(slices.len(), 4);
assert_eq!(slices[0].bases.len(), 3); assert_eq!(slices[1].bases.len(), 3);
assert_eq!(slices[2].bases.len(), 0);
assert_eq!(slices[3].bases.len(), 0);
}
#[test]
fn basis_store_slice_offset_correct() {
let mut store = BasisStore::new(6, 2);
let mut slices = store.split_workers_mut(3);
*slices[1].get_mut(2, 0) = Some(CapturedBasis::new(1, 0, 0, 0, 0));
*slices[1].get_mut(3, 1) = Some(CapturedBasis::new(1, 0, 0, 0, 0));
drop(slices);
assert!(store.get(2, 0).is_some());
assert!(store.get(3, 1).is_some());
assert!(store.get(0, 0).is_none());
assert!(store.get(4, 0).is_none());
}
#[test]
fn test_captured_basis_new_capacities() {
let cb = CapturedBasis::new(4, 6, 3, 10, 2);
assert_eq!(cb.basis.row_status.len(), 6, "row_status length");
assert_eq!(cb.base_row_count, 3, "base_row_count");
assert!(
cb.cut_row_slots.capacity() >= 10,
"cut_row_slots capacity must be >= 10 (got {})",
cb.cut_row_slots.capacity()
);
assert_eq!(cb.cut_row_slots.len(), 0, "cut_row_slots starts empty");
assert!(
cb.state_at_capture.capacity() >= 2,
"state_at_capture capacity must be >= 2 (got {})",
cb.state_at_capture.capacity()
);
assert_eq!(
cb.state_at_capture.len(),
0,
"state_at_capture starts empty"
);
}
#[test]
fn test_basis_store_holds_captured_basis() {
let mut store = BasisStore::new(3, 5);
for s in 0..3 {
for t in 0..5 {
assert!(
store.get(s, t).is_none(),
"slot [{s}][{t}] must be None before any write"
);
}
}
*store.get_mut(1, 3) = Some(CapturedBasis::new(4, 6, 3, 10, 2));
let retrieved = store.get(1, 3);
assert!(retrieved.is_some(), "slot [1][3] must be Some after write");
let cb = retrieved.expect("just checked is_some");
assert_eq!(cb.base_row_count, 3);
for s in 0..3 {
for t in 0..5 {
if s == 1 && t == 3 {
continue;
}
assert!(
store.get(s, t).is_none(),
"slot [{s}][{t}] must remain None"
);
}
}
}
#[test]
fn test_recon_slot_lookup_presized() {
let pool = WorkspacePool::new(
0,
4,
0,
WorkspaceSizing {
initial_pool_capacity: 50,
..WorkspaceSizing::default()
},
|| MockSolver,
);
for (i, ws) in pool.workspaces.iter().enumerate() {
assert_eq!(
ws.scratch.recon_slot_lookup.len(),
50,
"workspace {i}: recon_slot_lookup.len() must equal initial_pool_capacity (50)"
);
assert!(
ws.scratch.recon_slot_lookup.iter().all(Option::is_none),
"workspace {i}: all recon_slot_lookup entries must be None"
);
}
let pool_empty = WorkspacePool::new(0, 1, 0, WorkspaceSizing::default(), || MockSolver);
assert_eq!(
pool_empty.workspaces[0].scratch.recon_slot_lookup.len(),
0,
"initial_pool_capacity=0 must produce empty recon_slot_lookup"
);
}
#[test]
fn test_workspace_pool_assigns_sequential_worker_ids() {
let pool = WorkspacePool::new(
3,
5,
0,
WorkspaceSizing::default(),
|| MockSolver,
);
let ws_slice = &pool.workspaces;
assert_eq!(ws_slice.len(), 5);
let mut seen: std::collections::HashSet<i32> = std::collections::HashSet::new();
for ws in ws_slice {
assert_eq!(ws.rank, 3);
assert!(ws.worker_id >= 0 && ws.worker_id < 5);
assert!(seen.insert(ws.worker_id), "worker_id duplicated");
}
}
#[test]
fn test_captured_basis_round_trip_populated() {
let original = CapturedBasis {
basis: Basis {
col_status: vec![1_i32, 2, 3],
row_status: vec![4_i32, 5],
},
base_row_count: 1,
cut_row_slots: vec![10_u32, 20],
state_at_capture: vec![1.5_f64, 2.5, 3.5],
};
let mut i32_buf: Vec<i32> = Vec::new();
let mut f64_buf: Vec<f64> = Vec::new();
original.to_broadcast_payload(&mut i32_buf, &mut f64_buf);
let mut i32_cursor = 0_usize;
let mut f64_cursor = 0_usize;
let result = CapturedBasis::try_from_broadcast_payload(
0,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect("round-trip must not fail");
let recovered = result.expect("sentinel is 1; must return Some");
assert_eq!(
recovered.basis.col_status, original.basis.col_status,
"col_status"
);
assert_eq!(
recovered.basis.row_status, original.basis.row_status,
"row_status"
);
assert_eq!(
recovered.base_row_count, original.base_row_count,
"base_row_count"
);
assert_eq!(
recovered.cut_row_slots, original.cut_row_slots,
"cut_row_slots"
);
assert_eq!(
recovered.state_at_capture, original.state_at_capture,
"state_at_capture"
);
assert_eq!(i32_cursor, i32_buf.len(), "i32_cursor must be at end");
assert_eq!(f64_cursor, f64_buf.len(), "f64_cursor must be at end");
}
#[test]
fn test_captured_basis_round_trip_empty_metadata() {
let original = CapturedBasis {
basis: Basis {
col_status: vec![7_i32, 8],
row_status: vec![9_i32],
},
base_row_count: 1,
cut_row_slots: vec![],
state_at_capture: vec![],
};
let mut i32_buf: Vec<i32> = Vec::new();
let mut f64_buf: Vec<f64> = Vec::new();
original.to_broadcast_payload(&mut i32_buf, &mut f64_buf);
let mut i32_cursor = 0_usize;
let mut f64_cursor = 0_usize;
let result = CapturedBasis::try_from_broadcast_payload(
0,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect("round-trip must not fail");
let recovered = result.expect("sentinel is 1; must return Some");
assert_eq!(recovered.basis.col_status, original.basis.col_status);
assert_eq!(recovered.basis.row_status, original.basis.row_status);
assert_eq!(recovered.base_row_count, original.base_row_count);
assert!(
recovered.cut_row_slots.is_empty(),
"cut_row_slots must be empty"
);
assert!(
recovered.state_at_capture.is_empty(),
"state_at_capture must be empty"
);
assert_eq!(i32_cursor, i32_buf.len(), "i32_cursor must be at end");
assert_eq!(f64_cursor, f64_buf.len(), "f64_cursor must be at end");
}
#[test]
fn test_captured_basis_round_trip_multi_stage() {
let populated = CapturedBasis {
basis: Basis {
col_status: vec![11_i32, 22, 33],
row_status: vec![44_i32, 55, 66],
},
base_row_count: 2,
cut_row_slots: vec![100_u32, 200, 300],
state_at_capture: vec![0.1_f64, 0.2],
};
let mut i32_buf: Vec<i32> = Vec::new();
let mut f64_buf: Vec<f64> = Vec::new();
populated.to_broadcast_payload(&mut i32_buf, &mut f64_buf);
i32_buf.push(0_i32);
let mut i32_cursor = 0_usize;
let mut f64_cursor = 0_usize;
let stage0 = CapturedBasis::try_from_broadcast_payload(
0,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect("stage 0 must not fail")
.expect("stage 0 sentinel is 1; must return Some");
assert_eq!(stage0.basis.col_status, populated.basis.col_status);
assert_eq!(stage0.basis.row_status, populated.basis.row_status);
assert_eq!(stage0.base_row_count, populated.base_row_count);
assert_eq!(stage0.cut_row_slots, populated.cut_row_slots);
assert_eq!(stage0.state_at_capture, populated.state_at_capture);
let stage1 = CapturedBasis::try_from_broadcast_payload(
1,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect("stage 1 must not fail");
assert!(stage1.is_none(), "stage 1 sentinel is 0; must return None");
assert_eq!(
i32_cursor,
i32_buf.len(),
"i32_cursor must be at end after both stages"
);
assert_eq!(
f64_cursor,
f64_buf.len(),
"f64_cursor must be at end after both stages"
);
}
#[test]
fn test_captured_basis_truncated_i32_buffer() {
use crate::SddpError;
let cb = CapturedBasis {
basis: Basis {
col_status: vec![1_i32, 2],
row_status: vec![3_i32],
},
base_row_count: 1,
cut_row_slots: vec![5_u32],
state_at_capture: vec![9.9_f64],
};
let mut i32_buf: Vec<i32> = Vec::new();
let mut f64_buf: Vec<f64> = Vec::new();
cb.to_broadcast_payload(&mut i32_buf, &mut f64_buf);
i32_buf.pop();
let mut i32_cursor = 0_usize;
let mut f64_cursor = 0_usize;
let err = CapturedBasis::try_from_broadcast_payload(
7,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect_err("truncated buffer must return Err");
match err {
SddpError::Validation(ref msg) => {
assert!(
msg.contains("truncated"),
"error message must contain 'truncated', got: {msg}"
);
assert!(
msg.contains('7'),
"error message must contain stage index 7, got: {msg}"
);
}
other => panic!("expected SddpError::Validation, got {other:?}"),
}
}
#[test]
fn test_captured_basis_truncated_f64_buffer() {
use crate::SddpError;
let cb = CapturedBasis {
basis: Basis {
col_status: vec![1_i32],
row_status: vec![2_i32],
},
base_row_count: 1,
cut_row_slots: vec![],
state_at_capture: vec![1.0_f64, 2.0, 3.0],
};
let mut i32_buf: Vec<i32> = Vec::new();
let mut f64_buf: Vec<f64> = Vec::new();
cb.to_broadcast_payload(&mut i32_buf, &mut f64_buf);
f64_buf.pop();
let mut i32_cursor = 0_usize;
let mut f64_cursor = 0_usize;
let err = CapturedBasis::try_from_broadcast_payload(
3,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect_err("truncated f64 buffer must return Err");
match err {
SddpError::Validation(ref msg) => {
assert!(
msg.contains("state_at_capture"),
"error message must contain 'state_at_capture', got: {msg}"
);
assert!(
msg.contains('3'),
"error message must contain stage index 3, got: {msg}"
);
}
other => panic!("expected SddpError::Validation, got {other:?}"),
}
}
#[test]
fn to_broadcast_payload_emits_version_byte() {
use super::BASIS_BROADCAST_WIRE_VERSION;
let original = CapturedBasis {
basis: Basis {
col_status: vec![1_i32, 2, 3, 4],
row_status: vec![5_i32, 6, 7, 8],
},
base_row_count: 2,
cut_row_slots: vec![10_u32, 20],
state_at_capture: vec![0.5_f64, 1.5],
};
let mut i32_buf: Vec<i32> = Vec::new();
let mut f64_buf: Vec<f64> = Vec::new();
original.to_broadcast_payload(&mut i32_buf, &mut f64_buf);
assert_eq!(i32_buf[0], 1_i32, "offset 0 must be the presence sentinel");
assert_eq!(
i32_buf[1], BASIS_BROADCAST_WIRE_VERSION,
"offset 1 must be BASIS_BROADCAST_WIRE_VERSION"
);
let mut i32_cursor = 0_usize;
let mut f64_cursor = 0_usize;
let recovered = CapturedBasis::try_from_broadcast_payload(
0,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect("round-trip must not fail")
.expect("sentinel is 1; must return Some");
assert_eq!(recovered.basis.col_status, original.basis.col_status);
assert_eq!(recovered.basis.row_status, original.basis.row_status);
assert_eq!(recovered.base_row_count, original.base_row_count);
assert_eq!(recovered.cut_row_slots, original.cut_row_slots);
assert_eq!(recovered.state_at_capture, original.state_at_capture);
assert_eq!(i32_cursor, i32_buf.len(), "i32_cursor must be at end");
assert_eq!(f64_cursor, f64_buf.len(), "f64_cursor must be at end");
}
#[test]
fn try_from_broadcast_payload_rejects_wrong_version() {
use crate::SddpError;
let cb = CapturedBasis {
basis: Basis {
col_status: vec![1_i32, 2, 3, 4],
row_status: vec![5_i32, 6, 7, 8],
},
base_row_count: 2,
cut_row_slots: vec![10_u32, 20],
state_at_capture: vec![0.5_f64, 1.5],
};
let mut i32_buf: Vec<i32> = Vec::new();
let mut f64_buf: Vec<f64> = Vec::new();
cb.to_broadcast_payload(&mut i32_buf, &mut f64_buf);
i32_buf[1] = 2_i32;
let mut i32_cursor = 0_usize;
let mut f64_cursor = 0_usize;
let err = CapturedBasis::try_from_broadcast_payload(
0,
&i32_buf,
&mut i32_cursor,
&f64_buf,
&mut f64_cursor,
)
.expect_err("mismatched version must return Err");
match err {
SddpError::Validation(ref msg) => {
assert!(
msg.contains("unsupported wire version 2"),
"error must contain 'unsupported wire version 2', got: {msg}"
);
}
other => panic!("expected SddpError::Validation, got {other:?}"),
}
}
/// Decodes a buffer made of a lone `0` sentinel followed by one populated
/// payload. The first call must return `None` and move the i32 cursor by
/// exactly one; the second call must decode the populated entry and leave both
/// cursors at the end of their buffers.
#[test]
fn try_from_broadcast_payload_none_does_not_consume_version_byte() {
    let populated = CapturedBasis {
        basis: Basis {
            col_status: vec![7_i32],
            row_status: vec![8_i32],
        },
        base_row_count: 1,
        cut_row_slots: Vec::new(),
        state_at_capture: Vec::new(),
    };

    // The first entry is just the `0` sentinel; the populated entry follows it.
    let mut ints: Vec<i32> = vec![0_i32];
    let mut floats: Vec<f64> = Vec::new();
    populated.to_broadcast_payload(&mut ints, &mut floats);

    let (mut int_pos, mut float_pos) = (0_usize, 0_usize);

    let first = CapturedBasis::try_from_broadcast_payload(
        0,
        &ints,
        &mut int_pos,
        &floats,
        &mut float_pos,
    )
    .expect("stage 0 must not fail");
    assert!(first.is_none(), "stage 0 sentinel is 0; must return None");
    assert_eq!(
        int_pos, 1,
        "None path must advance cursor by exactly 1 (only the sentinel)"
    );

    // The same cursors are reused, so decoding resumes right after the sentinel.
    let second = CapturedBasis::try_from_broadcast_payload(
        1,
        &ints,
        &mut int_pos,
        &floats,
        &mut float_pos,
    )
    .expect("stage 1 must not fail")
    .expect("stage 1 sentinel is 1; must return Some");
    assert_eq!(second.basis.col_status, populated.basis.col_status);
    assert_eq!(second.basis.row_status, populated.basis.row_status);
    assert_eq!(int_pos, ints.len(), "i32_cursor must be at end");
    assert_eq!(float_pos, floats.len(), "f64_cursor must be at end");
}
/// Round-trips a basis carrying a six-entry `state_at_capture` and checks the
/// decoded vector both as a whole and in three two-entry windows, then checks
/// that both cursors reached the end of their buffers.
#[test]
fn test_captured_basis_round_trip_includes_anticipated_state() {
    let expected_state = vec![1.0_f64, 2.0, 100.0, 200.0, 1000.0, 2000.0];
    let source = CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32, 2, 3, 4],
            row_status: vec![5_i32, 6, 7],
        },
        base_row_count: 2,
        cut_row_slots: vec![10_u32, 20],
        state_at_capture: expected_state.clone(),
    };

    let (mut ints, mut floats) = (Vec::<i32>::new(), Vec::<f64>::new());
    source.to_broadcast_payload(&mut ints, &mut floats);

    let (mut int_pos, mut float_pos) = (0_usize, 0_usize);
    let decoded = CapturedBasis::try_from_broadcast_payload(
        0,
        &ints,
        &mut int_pos,
        &floats,
        &mut float_pos,
    )
    .expect("round-trip must not fail")
    .expect("sentinel is 1; must return Some");

    assert_eq!(
        decoded.state_at_capture, expected_state,
        "full state must roundtrip bit-exactly"
    );

    // Window checks, in the same order as the failure messages are expected to surface.
    let windows = [
        (
            4_usize..6,
            [1000.0_f64, 2000.0],
            "anticipated slice must roundtrip bit-exactly",
        ),
        (
            2_usize..4,
            [100.0_f64, 200.0],
            "lag slice must roundtrip bit-exactly",
        ),
        (
            0_usize..2,
            [1.0_f64, 2.0],
            "storage slice must roundtrip bit-exactly",
        ),
    ];
    for (window, expected, label) in windows {
        assert_eq!(&decoded.state_at_capture[window], &expected, "{label}");
    }

    assert_eq!(int_pos, ints.len(), "i32_cursor at end");
    assert_eq!(float_pos, floats.len(), "f64_cursor at end");
}
/// Encodes three bases whose `state_at_capture` has 6, 0 and 15 entries and
/// checks that offset 6 of the i32 payload records each length. The 15-entry
/// payload is also decoded to confirm that its last six state values survive
/// the round trip.
#[test]
fn test_captured_basis_state_at_capture_length_is_recorded_correctly() {
    // Serialises one basis into a fresh pair of buffers.
    let encode = |captured: &CapturedBasis| -> (Vec<i32>, Vec<f64>) {
        let mut ints: Vec<i32> = Vec::new();
        let mut floats: Vec<f64> = Vec::new();
        captured.to_broadcast_payload(&mut ints, &mut floats);
        (ints, floats)
    };

    let six_entry = CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32; 4],
            row_status: vec![1_i32; 3],
        },
        base_row_count: 2,
        cut_row_slots: Vec::new(),
        state_at_capture: vec![0.0_f64; 6],
    };
    let (six_ints, _six_floats) = encode(&six_entry);
    assert_eq!(
        six_ints[6], 6_i32,
        "state_at_capture length field must be 6 for N=2 L=1 A=1 K=2"
    );

    let no_state = CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32],
            row_status: vec![1_i32],
        },
        base_row_count: 1,
        cut_row_slots: Vec::new(),
        state_at_capture: Vec::new(),
    };
    let (empty_ints, _empty_floats) = encode(&no_state);
    assert_eq!(
        empty_ints[6], 0_i32,
        "state_at_capture length field must be 0 for empty state"
    );

    // Fifteen values: 0.0, 10.0, ..., 140.0.
    let fifteen_values: Vec<f64> = (0_i32..15).map(|k| f64::from(k) * 10.0).collect();
    let fifteen_entry = CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32; 5],
            row_status: vec![1_i32; 5],
        },
        base_row_count: 3,
        cut_row_slots: vec![1_u32, 2],
        state_at_capture: fifteen_values.clone(),
    };
    let (big_ints, big_floats) = encode(&fifteen_entry);
    assert_eq!(
        big_ints[6], 15_i32,
        "state_at_capture length field must be 15 for N=3 L=2 A=2 K=3"
    );

    let (mut int_pos, mut float_pos) = (0_usize, 0_usize);
    let decoded = CapturedBasis::try_from_broadcast_payload(
        0,
        &big_ints,
        &mut int_pos,
        &big_floats,
        &mut float_pos,
    )
    .expect("round-trip must not fail")
    .expect("sentinel is 1; must return Some");
    assert_eq!(
        &decoded.state_at_capture[9..15],
        &fifteen_values[9..15],
        "anticipated slice (last 6 entries) must roundtrip bit-exactly"
    );
}
/// Round-trips a basis whose `state_at_capture[4]` holds `12345.5` and checks
/// that this value, the row-status count, the whole state vector and both
/// cursor end positions come back as expected.
#[test]
fn test_captured_basis_round_trip_with_pre_horizon_seed_in_slot_zero() {
    let expected_state = vec![1.0_f64, 2.0, 100.0, 200.0, 12345.5, 0.0];
    let source = CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32, 2, 3, 4],
            row_status: vec![5_i32, 6, 7, 8],
        },
        base_row_count: 3,
        cut_row_slots: vec![10_u32, 20],
        state_at_capture: expected_state.clone(),
    };

    let mut ints: Vec<i32> = Vec::new();
    let mut floats: Vec<f64> = Vec::new();
    source.to_broadcast_payload(&mut ints, &mut floats);

    let (mut int_pos, mut float_pos) = (0_usize, 0_usize);
    let decoded = CapturedBasis::try_from_broadcast_payload(
        0,
        &ints,
        &mut int_pos,
        &floats,
        &mut float_pos,
    )
    .expect("round-trip must not fail")
    .expect("sentinel is 1; must return Some");

    assert_eq!(decoded.state_at_capture[4], 12345.5);
    assert_eq!(decoded.basis.row_status.len(), 4);
    assert_eq!(decoded.state_at_capture, expected_state);
    assert_eq!(int_pos, ints.len());
    assert_eq!(float_pos, floats.len());
}
/// Round-trips a basis with 12 column statuses and a six-entry state vector,
/// checking that the decoded state length and contents, the column-status
/// length and the cut row slots all match what was encoded.
#[test]
fn test_state_at_capture_length_equals_n_state_after_layout_change() {
    let expected_state = vec![1.0_f64, 2.0, 100.0, 200.0, 1000.0, 2000.0];
    let source = CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32; 12],
            row_status: vec![5_i32; 8],
        },
        base_row_count: 6,
        cut_row_slots: vec![10_u32, 20],
        state_at_capture: expected_state.clone(),
    };

    let (mut ints, mut floats) = (Vec::<i32>::new(), Vec::<f64>::new());
    source.to_broadcast_payload(&mut ints, &mut floats);

    let (mut int_pos, mut float_pos) = (0_usize, 0_usize);
    let decoded = CapturedBasis::try_from_broadcast_payload(
        0,
        &ints,
        &mut int_pos,
        &floats,
        &mut float_pos,
    )
    .expect("round-trip must not fail")
    .expect("sentinel is 1; must return Some");

    assert_eq!(
        decoded.state_at_capture.len(),
        6,
        "state_at_capture.len() must equal n_state regardless of new LP columns"
    );
    assert_eq!(
        decoded.state_at_capture, expected_state,
        "state_at_capture round-trips bit-identically"
    );
    assert_eq!(
        decoded.basis.col_status.len(),
        12,
        "col_status length round-trips bit-identically (including the new column slot)"
    );
    assert_eq!(
        decoded.cut_row_slots, source.cut_row_slots,
        "cut_row_slots round-trips bit-identically"
    );
}
/// Encodes a basis with 16 column and 10 row statuses, then checks the header
/// sentinel, that the version field equals `BASIS_BROADCAST_WIRE_VERSION`, and
/// that the constant itself is `1`. Finally decodes the payload and compares
/// every field with the original.
#[test]
fn test_basis_broadcast_wire_version_stays_one_with_state_out_column() {
    use super::BASIS_BROADCAST_WIRE_VERSION;

    let source = CapturedBasis {
        basis: Basis {
            col_status: vec![1_i32; 16],
            row_status: vec![1_i32; 10],
        },
        base_row_count: 8,
        cut_row_slots: Vec::new(),
        state_at_capture: vec![0.0_f64; 6],
    };

    let (mut ints, mut floats) = (Vec::<i32>::new(), Vec::<f64>::new());
    source.to_broadcast_payload(&mut ints, &mut floats);

    // Header layout: offset 0 is the sentinel, offset 1 the wire version.
    assert_eq!(ints[0], 1, "sentinel must be 1");
    assert_eq!(
        ints[1], BASIS_BROADCAST_WIRE_VERSION,
        "wire version field must equal BASIS_BROADCAST_WIRE_VERSION (= 1)"
    );
    assert_eq!(
        BASIS_BROADCAST_WIRE_VERSION, 1,
        "broadcast wire-format version constant must remain stable across releases"
    );

    let (mut int_pos, mut float_pos) = (0_usize, 0_usize);
    let decoded = CapturedBasis::try_from_broadcast_payload(
        0,
        &ints,
        &mut int_pos,
        &floats,
        &mut float_pos,
    )
    .expect("round-trip must not fail")
    .expect("sentinel is 1; must return Some");

    assert_eq!(decoded.basis.col_status, source.basis.col_status);
    assert_eq!(decoded.basis.row_status, source.basis.row_status);
    assert_eq!(decoded.base_row_count, source.base_row_count);
    assert_eq!(decoded.cut_row_slots, source.cut_row_slots);
    assert_eq!(decoded.state_at_capture, source.state_at_capture);
}
/// Checks that every workspace in a pool built by `WorkspacePool::new`, and a
/// workspace built directly with `SolverWorkspace::new`, reports
/// `HighsProfile::default()` from `solver.current_profile()`.
///
/// Only compiled when the `highs` feature is enabled.
#[cfg(all(test, feature = "highs"))]
#[test]
fn workspace_solver_initialised_with_default_profile() {
    use cobre_solver::HighsProfile;

    let default_profile = HighsProfile::default();

    let workspace_pool = WorkspacePool::new(0, 2, 0, WorkspaceSizing::default(), || MockSolver);
    for pooled in &workspace_pool.workspaces {
        assert_eq!(
            pooled.solver.current_profile(),
            &default_profile,
            "solver.current_profile() must equal HighsProfile::default() after construction"
        );
    }

    let patch_buffer = crate::lp_builder::PatchBuffer::new(0, 0, 0, 0, 0, 0);
    let standalone = SolverWorkspace::new(
        0,
        0,
        MockSolver,
        patch_buffer,
        0,
        WorkspaceSizing::default(),
    );
    assert_eq!(
        standalone.solver.current_profile(),
        &default_profile,
        "SolverWorkspace::new must initialise solver with default profile"
    );
}
}