use cobre_solver::{Basis, SolverInterface};
use crate::lp_builder::PatchBuffer;
/// Reusable per-workspace scratch vectors.
///
/// Where a size bound is known from the problem dimensions, capacity is reserved up front
/// by `ScratchBuffers::new`; buffers without a known bound start empty and grow on use.
// The pedantic `struct_field_names` lint objects to the shared naming pattern of these
// fields (most end in `_buf`); the repetition is intentional.
#[allow(clippy::struct_field_names)]
pub(crate) struct ScratchBuffers {
    pub(crate) noise_buf: Vec<f64>,
    pub(crate) inflow_m3s_buf: Vec<f64>,
    pub(crate) lag_matrix_buf: Vec<f64>,
    pub(crate) par_inflow_buf: Vec<f64>,
    pub(crate) eta_floor_buf: Vec<f64>,
    pub(crate) zero_targets_buf: Vec<f64>,
    pub(crate) ncs_col_upper_buf: Vec<f64>,
    pub(crate) ncs_col_lower_buf: Vec<f64>,
    pub(crate) ncs_col_indices_buf: Vec<usize>,
    pub(crate) load_rhs_buf: Vec<f64>,
    pub(crate) row_lower_buf: Vec<f64>,
    pub(crate) z_inflow_rhs_buf: Vec<f64>,
    pub(crate) effective_eta_buf: Vec<f64>,
    pub(crate) unscaled_primal: Vec<f64>,
    pub(crate) unscaled_dual: Vec<f64>,
}

impl ScratchBuffers {
    /// Allocates scratch buffers sized for the given problem dimensions.
    ///
    /// This is the single place that decides buffer sizing, shared by every workspace
    /// constructor so they cannot drift apart:
    /// - `zero_targets_buf` is filled with `hydro_count` zeros;
    /// - per-hydro buffers are empty with capacity `hydro_count`;
    /// - `lag_matrix_buf` is empty with capacity `max_par_order * hydro_count`;
    /// - `load_rhs_buf` is empty with capacity `n_load_buses * max_blocks`;
    /// - the NCS column buffers, `row_lower_buf` and the unscaled solution buffers start
    ///   empty with no reserved capacity.
    fn new(
        hydro_count: usize,
        max_par_order: usize,
        n_load_buses: usize,
        max_blocks: usize,
    ) -> Self {
        Self {
            noise_buf: Vec::with_capacity(hydro_count),
            inflow_m3s_buf: Vec::with_capacity(hydro_count),
            lag_matrix_buf: Vec::with_capacity(max_par_order * hydro_count),
            par_inflow_buf: Vec::with_capacity(hydro_count),
            eta_floor_buf: Vec::with_capacity(hydro_count),
            zero_targets_buf: vec![0.0_f64; hydro_count],
            ncs_col_upper_buf: Vec::new(),
            ncs_col_lower_buf: Vec::new(),
            ncs_col_indices_buf: Vec::new(),
            load_rhs_buf: Vec::with_capacity(n_load_buses * max_blocks),
            row_lower_buf: Vec::new(),
            z_inflow_rhs_buf: Vec::with_capacity(hydro_count),
            effective_eta_buf: Vec::with_capacity(hydro_count),
            unscaled_primal: Vec::new(),
            unscaled_dual: Vec::new(),
        }
    }
}

/// A solver instance bundled with the mutable buffers it operates on.
///
/// `WorkspacePool` creates one of these per thread.
pub struct SolverWorkspace<S: SolverInterface> {
    /// The LP solver owned by this workspace.
    pub solver: S,
    /// Patch buffer supplied by the caller (see `SolverWorkspace::new`).
    pub patch_buf: PatchBuffer,
    /// State vector; starts empty with capacity reserved for `n_state` entries.
    pub current_state: Vec<f64>,
    /// Crate-internal scratch storage.
    pub(crate) scratch: ScratchBuffers,
}

impl<S: SolverInterface> SolverWorkspace<S> {
    /// Creates a workspace around `solver` and `patch_buf`.
    ///
    /// * `n_state` — capacity reserved for `current_state` (which starts empty).
    /// * `hydro_count`, `max_par_order`, `n_load_buses`, `max_blocks` — dimensions used to
    ///   pre-size the internal scratch buffers (e.g. `max_par_order * hydro_count` for the
    ///   lag matrix, `n_load_buses * max_blocks` for the load RHS).
    #[must_use]
    pub fn new(
        solver: S,
        patch_buf: PatchBuffer,
        n_state: usize,
        hydro_count: usize,
        max_par_order: usize,
        n_load_buses: usize,
        max_blocks: usize,
    ) -> Self {
        Self {
            solver,
            patch_buf,
            current_state: Vec::with_capacity(n_state),
            scratch: ScratchBuffers::new(hydro_count, max_par_order, n_load_buses, max_blocks),
        }
    }
}

/// A fixed set of [`SolverWorkspace`]s, one per requested thread.
pub struct WorkspacePool<S: SolverInterface> {
    /// The `n_threads` workspaces, in creation order.
    pub workspaces: Vec<SolverWorkspace<S>>,
}

impl<S: SolverInterface> WorkspacePool<S> {
    /// Builds `n_threads` workspaces, calling `solver_factory` once per workspace.
    ///
    /// Each workspace gets its own `PatchBuffer::new(hydro_count, max_par_order,
    /// n_load_buses, max_blocks)` and scratch buffers sized as in
    /// [`SolverWorkspace::new`]. `n_threads == 0` yields an empty pool.
    ///
    /// Note: unlike [`SolverWorkspace::new`], `n_state` comes after `max_par_order` here.
    #[must_use]
    pub fn new(
        n_threads: usize,
        hydro_count: usize,
        max_par_order: usize,
        n_state: usize,
        n_load_buses: usize,
        max_blocks: usize,
        solver_factory: impl Fn() -> S,
    ) -> Self {
        let workspaces = (0..n_threads)
            .map(|_| {
                // Factory first, then the patch buffer: same order as the fallible variant.
                SolverWorkspace::new(
                    solver_factory(),
                    PatchBuffer::new(hydro_count, max_par_order, n_load_buses, max_blocks),
                    n_state,
                    hydro_count,
                    max_par_order,
                    n_load_buses,
                    max_blocks,
                )
            })
            .collect();
        Self { workspaces }
    }

    /// Fallible counterpart of [`WorkspacePool::new`] for factories that can fail.
    ///
    /// Workspaces are built one at a time; the factory is not called again after it
    /// first fails.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `solver_factory`. Workspaces built before the
    /// failure are dropped.
    pub fn try_new<E>(
        n_threads: usize,
        hydro_count: usize,
        max_par_order: usize,
        n_state: usize,
        n_load_buses: usize,
        max_blocks: usize,
        solver_factory: impl Fn() -> Result<S, E>,
    ) -> Result<Self, E> {
        let mut workspaces = Vec::with_capacity(n_threads);
        for _ in 0..n_threads {
            // Create the solver before allocating the patch buffer so a failing factory
            // does not pay for an allocation that is immediately discarded.
            let solver = solver_factory()?;
            workspaces.push(SolverWorkspace::new(
                solver,
                PatchBuffer::new(hydro_count, max_par_order, n_load_buses, max_blocks),
                n_state,
                hydro_count,
                max_par_order,
                n_load_buses,
                max_blocks,
            ));
        }
        Ok(Self { workspaces })
    }
}
/// Dense `scenario × stage` table of optional solver [`Basis`] values.
///
/// Slots are stored row-major by scenario: the entry for `(scenario, stage)` lives at
/// index `scenario * num_stages + stage`, so each scenario's stages are contiguous. That
/// layout is what lets [`BasisStore::split_workers_mut`] hand out disjoint mutable
/// scenario ranges.
pub struct BasisStore {
    /// `num_scenarios * num_stages` slots, all `None` after construction.
    bases: Vec<Option<Basis>>,
    /// Stages per scenario (row length of the table).
    num_stages: usize,
}

impl BasisStore {
    /// Creates a store with `num_scenarios * num_stages` empty (`None`) slots.
    #[must_use]
    pub fn new(num_scenarios: usize, num_stages: usize) -> Self {
        Self {
            bases: vec![None; num_scenarios * num_stages],
            num_stages,
        }
    }

    /// Returns the number of scenarios, derived from the slot count.
    ///
    /// A store built with `num_stages == 0` has no slots, so the scenario count cannot be
    /// recovered and this returns `0`.
    #[must_use]
    pub fn num_scenarios(&self) -> usize {
        self.bases.len().checked_div(self.num_stages).unwrap_or(0)
    }

    /// Returns the number of stages per scenario.
    #[must_use]
    pub fn num_stages(&self) -> usize {
        self.num_stages
    }

    /// Returns the basis stored for `(scenario, stage)`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the slot is out of bounds (e.g. `scenario >= num_scenarios()`). Debug
    /// builds also panic when `stage >= num_stages()`.
    #[must_use]
    pub fn get(&self, scenario: usize, stage: usize) -> Option<&Basis> {
        self.bases[self.slot_index(scenario, stage)].as_ref()
    }

    /// Returns a mutable handle to the `(scenario, stage)` slot, for storing or clearing
    /// a basis.
    ///
    /// # Panics
    ///
    /// Same conditions as [`BasisStore::get`].
    pub fn get_mut(&mut self, scenario: usize, stage: usize) -> &mut Option<Basis> {
        let idx = self.slot_index(scenario, stage);
        &mut self.bases[idx]
    }

    /// Splits the store into `n_workers` disjoint mutable views, returned in worker order.
    ///
    /// Worker `w` receives the scenarios `start..end` reported by
    /// `crate::forward::partition(num_scenarios, n_workers, w)`. Each view takes
    /// *global* scenario indices. A worker with an empty range gets an empty view.
    ///
    /// # Panics
    ///
    /// Debug builds panic if `n_workers == 0`, or if `partition` does not return
    /// consecutive ranges starting at scenario 0. Any build panics (in `split_at_mut`) if
    /// the ranges cover more scenarios than the store holds.
    #[must_use]
    pub fn split_workers_mut(&mut self, n_workers: usize) -> Vec<BasisStoreSliceMut<'_>> {
        debug_assert!(n_workers > 0, "n_workers must be > 0");
        let total_scenarios = self.num_scenarios();
        let num_stages = self.num_stages;
        let mut slices = Vec::with_capacity(n_workers);
        let mut remainder = self.bases.as_mut_slice();
        let mut offset = 0usize;
        for w in 0..n_workers {
            let (start, end) = crate::forward::partition(total_scenarios, n_workers, w);
            // Slices are carved sequentially from the front of `remainder`, so the
            // running `offset` is only a correct `scenario_offset` if the ranges are
            // consecutive.
            debug_assert_eq!(
                start, offset,
                "partition must yield consecutive scenario ranges"
            );
            let count = end - start;
            let (left, rest) = remainder.split_at_mut(count * num_stages);
            remainder = rest;
            slices.push(BasisStoreSliceMut {
                bases: left,
                scenario_offset: offset,
                num_stages,
            });
            offset += count;
        }
        slices
    }

    /// Flat index of `(scenario, stage)`.
    ///
    /// Without the stage check, an out-of-range `stage` would silently address a slot
    /// belonging to the next scenario instead of failing.
    fn slot_index(&self, scenario: usize, stage: usize) -> usize {
        debug_assert!(
            stage < self.num_stages,
            "stage {stage} out of range for {} stages",
            self.num_stages
        );
        scenario * self.num_stages + stage
    }
}
/// Mutable view over the contiguous range of scenarios assigned to one worker by
/// [`BasisStore::split_workers_mut`].
///
/// Methods take *global* scenario indices; `scenario_offset` is subtracted internally.
pub struct BasisStoreSliceMut<'a> {
    /// This worker's slots, laid out as `local_scenario * num_stages + stage`.
    bases: &'a mut [Option<Basis>],
    /// Global index of the first scenario covered by `bases`.
    scenario_offset: usize,
    /// Stages per scenario (same as the parent store).
    num_stages: usize,
}

impl BasisStoreSliceMut<'_> {
    /// Returns the basis stored for global `(scenario, stage)`, if any.
    ///
    /// # Panics
    ///
    /// Panics if `scenario` lies past this worker's range. Debug builds also panic if
    /// `scenario` precedes the range or `stage` is out of range.
    #[must_use]
    pub fn get(&self, scenario: usize, stage: usize) -> Option<&Basis> {
        self.bases[self.slot_index(scenario, stage)].as_ref()
    }

    /// Returns a mutable handle to the global `(scenario, stage)` slot.
    ///
    /// # Panics
    ///
    /// Same conditions as [`BasisStoreSliceMut::get`].
    pub fn get_mut(&mut self, scenario: usize, stage: usize) -> &mut Option<Basis> {
        let idx = self.slot_index(scenario, stage);
        &mut self.bases[idx]
    }

    /// Local flat index of global `(scenario, stage)`.
    ///
    /// The assertions give a clear message where the bare arithmetic would otherwise
    /// underflow (`scenario` before this range) or silently alias the next scenario's
    /// slot (`stage` too large).
    fn slot_index(&self, scenario: usize, stage: usize) -> usize {
        debug_assert!(
            scenario >= self.scenario_offset,
            "scenario {scenario} precedes this worker's range (offset {})",
            self.scenario_offset
        );
        debug_assert!(
            stage < self.num_stages,
            "stage {stage} out of range for {} stages",
            self.num_stages
        );
        (scenario - self.scenario_offset) * self.num_stages + stage
    }
}
#[cfg(test)]
mod tests {
    use super::{BasisStore, SolverWorkspace, WorkspacePool};
    use crate::lp_builder::PatchBuffer;
    use cobre_solver::{
        Basis, SolutionView, SolverError, SolverInterface, SolverStatistics,
        types::{RowBatch, StageTemplate},
    };
    use std::cell::Cell;
    use std::collections::HashSet;

    /// Minimal solver: every solve fails; all other calls are no-ops.
    struct MockSolver;
    impl SolverInterface for MockSolver {
        fn load_model(&mut self, _t: &StageTemplate) {}
        fn add_rows(&mut self, _r: &RowBatch) {}
        fn set_row_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
        fn set_col_bounds(&mut self, _i: &[usize], _l: &[f64], _u: &[f64]) {}
        fn solve(&mut self) -> Result<SolutionView<'_>, SolverError> {
            Err(SolverError::InternalError {
                message: "mock".into(),
                error_code: None,
            })
        }
        fn reset(&mut self) {}
        fn get_basis(&mut self, _out: &mut Basis) {}
        fn solve_with_basis(&mut self, _b: &Basis) -> Result<SolutionView<'_>, SolverError> {
            Err(SolverError::InternalError {
                message: "mock".into(),
                error_code: None,
            })
        }
        fn statistics(&self) -> SolverStatistics {
            SolverStatistics::default()
        }
        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    fn assert_send<T: Send>() {}

    #[test]
    fn test_workspace_send_bound() {
        assert_send::<SolverWorkspace<MockSolver>>();
    }

    #[test]
    fn test_workspace_pool_size() {
        let pool = WorkspacePool::new(4, 3, 2, 9, 0, 0, || MockSolver);
        assert_eq!(pool.workspaces.len(), 4);
    }

    #[test]
    fn test_workspace_buffer_dimensions() {
        let pool = WorkspacePool::new(4, 3, 2, 9, 0, 0, || MockSolver);
        for ws in &pool.workspaces {
            assert_eq!(ws.patch_buf.indices.len(), 15, "patch_buf length");
            assert_eq!(ws.current_state.capacity(), 9, "current_state capacity");
            assert_eq!(ws.current_state.len(), 0, "current_state starts empty");
        }
    }

    #[test]
    fn test_workspace_pool_zero_threads() {
        let pool = WorkspacePool::new(0, 3, 2, 9, 0, 0, || MockSolver);
        assert_eq!(pool.workspaces.len(), 0);
    }

    #[test]
    fn test_workspace_pool_single_thread() {
        let pool = WorkspacePool::new(1, 0, 0, 0, 0, 0, || MockSolver);
        assert_eq!(pool.workspaces.len(), 1);
        assert_eq!(pool.workspaces[0].patch_buf.indices.len(), 0);
    }

    #[test]
    fn test_workspace_pool_each_solver_independent() {
        let n = 6;
        let pool = WorkspacePool::new(n, 1, 0, 1, 0, 0, || MockSolver);
        assert_eq!(pool.workspaces.len(), n);
        // With hydro_count = 1 every `zero_targets_buf` owns a live heap allocation, so
        // distinct pointers prove the workspaces do not share scratch storage.
        let ptrs: HashSet<*const f64> = pool
            .workspaces
            .iter()
            .map(|ws| ws.scratch.zero_targets_buf.as_ptr())
            .collect();
        assert_eq!(ptrs.len(), n, "workspaces must not share scratch allocations");
    }

    #[test]
    fn test_workspace_pool_calls_factory_once_per_thread() {
        let calls = Cell::new(0usize);
        let pool = WorkspacePool::new(5, 1, 0, 1, 0, 0, || {
            calls.set(calls.get() + 1);
            MockSolver
        });
        assert_eq!(pool.workspaces.len(), 5);
        assert_eq!(calls.get(), 5);
    }

    #[test]
    fn test_workspace_pool_try_new_ok() {
        let result = WorkspacePool::try_new(3, 3, 2, 9, 0, 0, || Ok::<_, &str>(MockSolver));
        let pool = match result {
            Ok(pool) => pool,
            Err(e) => panic!("infallible factory reported an error: {e}"),
        };
        assert_eq!(pool.workspaces.len(), 3);
        for ws in &pool.workspaces {
            assert_eq!(ws.patch_buf.indices.len(), 15, "patch_buf length");
            assert_eq!(ws.current_state.len(), 0, "current_state starts empty");
        }
    }

    #[test]
    fn test_workspace_pool_try_new_propagates_first_error() {
        let calls = Cell::new(0usize);
        let result = WorkspacePool::try_new(4, 1, 0, 1, 0, 0, || {
            let call = calls.get();
            calls.set(call + 1);
            if call < 2 {
                Ok(MockSolver)
            } else {
                Err("factory failed")
            }
        });
        assert_eq!(result.err(), Some("factory failed"));
        assert_eq!(calls.get(), 3, "construction must stop at the first error");
    }

    #[test]
    fn test_solver_workspace_new_dimensions() {
        let ws = SolverWorkspace::new(MockSolver, PatchBuffer::new(3, 2, 0, 0), 9, 3, 2, 0, 0);
        assert_eq!(ws.patch_buf.indices.len(), 15, "patch_buf length");
        assert_eq!(ws.current_state.len(), 0, "current_state starts empty");
        assert!(ws.current_state.capacity() >= 9, "current_state capacity");
        assert_eq!(ws.scratch.zero_targets_buf, vec![0.0_f64; 3]);
        assert!(ws.scratch.lag_matrix_buf.is_empty());
        assert!(ws.scratch.lag_matrix_buf.capacity() >= 6);
    }

    #[test]
    fn basis_store_new_all_none() {
        let store = BasisStore::new(3, 5);
        assert_eq!(store.num_scenarios(), 3);
        assert_eq!(store.num_stages(), 5);
        for s in 0..3 {
            for t in 0..5 {
                assert!(
                    store.get(s, t).is_none(),
                    "slot [{s}][{t}] must start as None"
                );
            }
        }
    }

    #[test]
    fn basis_store_get_mut_set_and_retrieve() {
        let mut store = BasisStore::new(2, 3);
        let basis = Basis::new(4, 2);
        *store.get_mut(1, 2) = Some(basis);
        assert!(store.get(1, 2).is_some());
        assert!(store.get(0, 0).is_none());
        assert!(store.get(1, 0).is_none());
    }

    #[test]
    fn basis_store_zero_scenarios() {
        let store = BasisStore::new(0, 5);
        assert_eq!(store.num_scenarios(), 0);
        assert_eq!(store.num_stages(), 5);
    }

    #[test]
    fn basis_store_zero_stages() {
        let store = BasisStore::new(3, 0);
        assert_eq!(store.num_scenarios(), 0);
        assert_eq!(store.num_stages(), 0);
    }

    #[test]
    fn basis_store_split_workers_mut_disjoint_writes() {
        let mut store = BasisStore::new(4, 3);
        let mut slices = store.split_workers_mut(2);
        *slices[0].get_mut(0, 1) = Some(Basis::new(2, 1));
        *slices[1].get_mut(3, 2) = Some(Basis::new(2, 1));
        drop(slices);
        assert!(
            store.get(0, 1).is_some(),
            "scenario 0 stage 1 must be populated"
        );
        assert!(
            store.get(3, 2).is_some(),
            "scenario 3 stage 2 must be populated"
        );
        assert!(store.get(0, 0).is_none());
        assert!(store.get(3, 0).is_none());
    }

    #[test]
    fn basis_store_split_single_worker() {
        let mut store = BasisStore::new(3, 2);
        let mut slices = store.split_workers_mut(1);
        *slices[0].get_mut(2, 1) = Some(Basis::new(1, 0));
        drop(slices);
        assert!(store.get(2, 1).is_some());
    }

    #[test]
    fn basis_store_split_more_workers_than_scenarios() {
        let mut store = BasisStore::new(2, 3);
        let slices = store.split_workers_mut(4);
        assert_eq!(slices.len(), 4);
        assert_eq!(slices[0].bases.len(), 3);
        assert_eq!(slices[1].bases.len(), 3);
        assert_eq!(slices[2].bases.len(), 0);
        assert_eq!(slices[3].bases.len(), 0);
    }

    #[test]
    fn basis_store_slice_offset_correct() {
        let mut store = BasisStore::new(6, 2);
        let mut slices = store.split_workers_mut(3);
        *slices[1].get_mut(2, 0) = Some(Basis::new(1, 0));
        *slices[1].get_mut(3, 1) = Some(Basis::new(1, 0));
        drop(slices);
        assert!(store.get(2, 0).is_some());
        assert!(store.get(3, 1).is_some());
        assert!(store.get(0, 0).is_none());
        assert!(store.get(4, 0).is_none());
    }
}