use cobre_solver::{RowBatch, StageTemplate};
use crate::{
context::StageContext, cut::CutRowMap, lower_bound::LbEvalScratch, lp_builder::PatchBuffer,
trajectory::TrajectoryRecord,
};
/// Pre-allocated scratch buffers built once by [`IterationScratch::new`].
///
/// NOTE(review): the name suggests these buffers are reused across
/// iterations rather than reallocated — confirm at the call sites.
pub(crate) struct IterationScratch {
    /// Patch buffer sized by `PatchBuffer::new` from `hydro_count`,
    /// `max_par_order`, `n_anticipated` and `k_max` (its other two sizing
    /// arguments are passed as `0`).
    pub patch_buf: PatchBuffer,
    /// `max_local_fwd * num_stages` trajectory records. Each record starts
    /// with empty `primal`/`dual`, a zero `stage_cost`, and an all-zero
    /// `state` of length `n_state`.
    pub records: Vec<TrajectoryRecord>,
    /// One initially empty row batch per stage.
    pub cut_batches: Vec<RowBatch>,
    /// A single initially empty row batch (presumably for lower-bound cuts).
    pub lb_cut_batch: RowBatch,
    /// One template per stage, initialised by baking the matching entry of
    /// `bake_row_batches` (empty at construction) into `stage_ctx.templates[t]`.
    pub baked_templates: Vec<StageTemplate>,
    /// One initially empty row batch per stage; these are the rows handed to
    /// `cobre_solver::bake_rows_into_template`.
    pub bake_row_batches: Vec<RowBatch>,
    /// Cut-row map created from the stage-0 FCF pool capacity and the
    /// stage-0 template row count.
    pub lb_cut_row_map: CutRowMap,
    /// Scratch state from the `lower_bound` module.
    pub lb_scratch: LbEvalScratch,
    /// Reusable scratch passed to `cobre_solver::bake_rows_into_template`.
    pub(crate) baking_scratch: cobre_solver::BakingScratch,
}

/// Returns a `RowBatch` with no rows and no allocated storage.
fn empty_row_batch() -> RowBatch {
    RowBatch {
        num_rows: 0,
        row_starts: Vec::new(),
        col_indices: Vec::new(),
        values: Vec::new(),
        row_lower: Vec::new(),
        row_upper: Vec::new(),
    }
}

impl IterationScratch {
    /// Allocates every scratch buffer and pre-bakes one template per stage.
    ///
    /// # Arguments
    ///
    /// * `max_local_fwd`, `num_stages` — the record pool holds
    ///   `max_local_fwd * num_stages` entries. `num_stages` also sets the
    ///   number of per-stage row batches and baked templates.
    /// * `n_state` — length of each record's zero-initialised `state` vector.
    /// * `fcf_pool_0_capacity`, `template_0_num_rows` — forwarded to
    ///   `CutRowMap::new` to build `lb_cut_row_map`.
    /// * `hydro_count`, `max_par_order`, `n_anticipated`, `k_max` — forwarded
    ///   to `PatchBuffer::new`.
    /// * `stage_ctx` — supplies the per-stage templates that are baked into
    ///   `baked_templates`.
    ///
    /// # Panics
    ///
    /// Panics if `stage_ctx.templates` has fewer than `num_stages` entries.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn new(
        max_local_fwd: usize,
        num_stages: usize,
        n_state: usize,
        fcf_pool_0_capacity: usize,
        template_0_num_rows: usize,
        hydro_count: usize,
        max_par_order: usize,
        n_anticipated: usize,
        k_max: usize,
        stage_ctx: &StageContext<'_>,
    ) -> Self {
        // Fail with a descriptive message instead of an index-out-of-bounds
        // panic inside the bake loop (same failure condition as before).
        assert!(
            stage_ctx.templates.len() >= num_stages,
            "stage_ctx.templates has {} entries but num_stages is {num_stages}",
            stage_ctx.templates.len()
        );

        let records: Vec<TrajectoryRecord> = (0..max_local_fwd * num_stages)
            .map(|_| TrajectoryRecord {
                primal: Vec::new(),
                dual: Vec::new(),
                stage_cost: 0.0,
                state: vec![0.0; n_state],
            })
            .collect();

        let patch_buf = PatchBuffer::new(hydro_count, max_par_order, 0, 0, n_anticipated, k_max);

        let cut_batches: Vec<RowBatch> = (0..num_stages).map(|_| empty_row_batch()).collect();
        let lb_cut_batch = empty_row_batch();

        let mut baked_templates: Vec<StageTemplate> =
            (0..num_stages).map(|_| StageTemplate::empty()).collect();
        let bake_row_batches: Vec<RowBatch> =
            (0..num_stages).map(|_| empty_row_batch()).collect();

        // Pre-bake every stage so `baked_templates` is valid from the start.
        // The zip stops after `num_stages` entries; the assert above
        // guarantees that many templates exist.
        let mut baking_scratch = cobre_solver::BakingScratch::new();
        for ((template, rows), baked) in stage_ctx
            .templates
            .iter()
            .zip(&bake_row_batches)
            .zip(&mut baked_templates)
        {
            cobre_solver::bake_rows_into_template(template, rows, baked, &mut baking_scratch);
        }

        let lb_cut_row_map = CutRowMap::new(fcf_pool_0_capacity, template_0_num_rows);
        let lb_scratch = LbEvalScratch::new();

        Self {
            patch_buf,
            records,
            cut_batches,
            lb_cut_batch,
            baked_templates,
            bake_row_batches,
            lb_cut_row_map,
            lb_scratch,
            baking_scratch,
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::float_cmp,
    clippy::cast_precision_loss,
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::too_many_lines,
    clippy::needless_range_loop
)]
mod tests {
    use cobre_solver::StageTemplate;

    use super::IterationScratch;
    use crate::context::StageContext;

    /// A small 4-column, 2-row template with a single non-zero (row 0,
    /// column 2) and one state variable.
    fn minimal_template() -> StageTemplate {
        StageTemplate {
            num_cols: 4,
            num_rows: 2,
            num_nz: 1,
            col_starts: vec![0_i32, 0, 0, 1, 1],
            row_indices: vec![0_i32],
            values: vec![1.0],
            col_lower: vec![0.0, f64::NEG_INFINITY, 0.0, 0.0],
            col_upper: vec![f64::INFINITY; 4],
            objective: vec![0.0, 0.0, 0.0, 1.0],
            row_lower: vec![0.0, 0.0],
            row_upper: vec![0.0, 0.0],
            n_state: 1,
            n_transfer: 0,
            n_dual_relevant: 1,
            n_hydro: 1,
            max_par_order: 0,
            col_scale: Vec::new(),
            row_scale: Vec::new(),
        }
    }

    /// Wraps `templates` in a `StageContext` whose other fields are empty or zero.
    fn make_stage_ctx(templates: &[StageTemplate]) -> StageContext<'_> {
        StageContext {
            templates,
            base_rows: &[],
            noise_scale: &[],
            n_hydros: 0,
            n_load_buses: 0,
            load_balance_row_starts: &[],
            load_bus_indices: &[],
            block_counts_per_stage: &[],
            ncs_max_gen: &[],
            ncs_allow_curtailment: &[],
            discount_factors: &[],
            cumulative_discount_factors: &[],
            stage_lag_transitions: &[],
            noise_group_ids: &[],
            downstream_par_order: 0,
        }
    }

    /// Builds an `IterationScratch` over `num_stages` copies of
    /// `minimal_template()`, varying only the patch-buffer sizing arguments.
    fn scratch_with_patch_dims(
        num_stages: usize,
        hydro_count: usize,
        max_par_order: usize,
        n_anticipated: usize,
        k_max: usize,
    ) -> IterationScratch {
        let templates = vec![minimal_template(); num_stages];
        let stage_ctx = make_stage_ctx(&templates);
        IterationScratch::new(
            1,
            num_stages,
            4,
            4,
            4,
            hydro_count,
            max_par_order,
            n_anticipated,
            k_max,
            &stage_ctx,
        )
    }

    #[test]
    fn iteration_scratch_new_sizes_vecs_correctly() {
        let max_local_fwd = 2;
        let num_stages = 3;
        let n_state = 4;
        let fcf_pool_0_capacity = 10;
        let template_0_num_rows = 5;
        let hydro_count = 1;
        let max_par_order = 1;
        let templates = vec![minimal_template(); num_stages];
        let stage_ctx = make_stage_ctx(&templates);
        let scratch = IterationScratch::new(
            max_local_fwd,
            num_stages,
            n_state,
            fcf_pool_0_capacity,
            template_0_num_rows,
            hydro_count,
            max_par_order,
            0,
            0,
            &stage_ctx,
        );
        assert_eq!(
            scratch.records.len(),
            max_local_fwd * num_stages,
            "records must be pre-sized to max_local_fwd * num_stages"
        );
        // Every record (not just the first) must carry a zeroed state vector.
        for record in &scratch.records {
            assert_eq!(
                record.state.len(),
                n_state,
                "each record state must have n_state elements"
            );
            assert!(
                record.state.iter().all(|&x| x == 0.0),
                "each record state must start zeroed"
            );
        }
        assert_eq!(
            scratch.cut_batches.len(),
            num_stages,
            "cut_batches must have one RowBatch per stage"
        );
        assert_eq!(
            scratch.bake_row_batches.len(),
            num_stages,
            "bake_row_batches must have one RowBatch per stage"
        );
        assert_eq!(
            scratch.baked_templates.len(),
            num_stages,
            "baked_templates must have one StageTemplate per stage"
        );
    }

    #[test]
    fn iteration_scratch_new_pre_bakes_templates() {
        let max_local_fwd = 2;
        let num_stages = 3;
        let n_state = 4;
        let fcf_pool_0_capacity = 10;
        let template_0_num_rows = 5;
        let hydro_count = 1;
        let max_par_order = 1;
        let templates = vec![minimal_template(); num_stages];
        let stage_ctx = make_stage_ctx(&templates);
        let scratch = IterationScratch::new(
            max_local_fwd,
            num_stages,
            n_state,
            fcf_pool_0_capacity,
            template_0_num_rows,
            hydro_count,
            max_par_order,
            0,
            0,
            &stage_ctx,
        );
        for t in 0..num_stages {
            assert_eq!(
                scratch.baked_templates[t].num_rows, stage_ctx.templates[t].num_rows,
                "baked_templates[{t}].num_rows must match stage_ctx template"
            );
            assert_eq!(
                scratch.baked_templates[t].num_cols, stage_ctx.templates[t].num_cols,
                "baked_templates[{t}].num_cols must match stage_ctx template"
            );
        }
    }

    #[test]
    fn iteration_scratch_new_sizes_patch_buffer_for_anticipated_thermals() {
        let num_stages = 2;
        let hydro_count = 2;
        let max_par_order = 1;
        let n_anticipated = 3;
        let k_max = 2;
        let category_6_slots = n_anticipated * k_max;

        // Same sizing without anticipated thermals: the expected layout is this
        // baseline plus exactly A*K extra Category 6 slots, so the test does not
        // need to hard-code the base layout.
        let baseline = scratch_with_patch_dims(num_stages, hydro_count, max_par_order, 0, 0);
        let scratch =
            scratch_with_patch_dims(num_stages, hydro_count, max_par_order, n_anticipated, k_max);

        assert_eq!(
            scratch.patch_buf.indices.len(),
            baseline.patch_buf.indices.len() + category_6_slots,
            "patch_buf indices length must include A*K slots for Category 6",
        );
        assert_eq!(
            scratch.patch_buf.lower.len(),
            baseline.patch_buf.lower.len() + category_6_slots,
            "patch_buf lower length must include A*K slots for Category 6",
        );
        assert_eq!(
            scratch.patch_buf.upper.len(),
            baseline.patch_buf.upper.len() + category_6_slots,
            "patch_buf upper length must include A*K slots for Category 6",
        );
        assert_eq!(
            scratch.patch_buf.forward_patch_count(),
            baseline.patch_buf.forward_patch_count() + category_6_slots,
            "forward_patch_count must include the A*K Category 6 slots",
        );
    }

    #[test]
    fn iteration_scratch_new_patch_buffer_zero_anticipated_unchanged() {
        let num_stages = 2;
        let hydro_count = 2;
        let max_par_order = 1;

        // With zero anticipated thermals, A*K is zero regardless of k_max,
        // so a non-zero k_max must not change the layout.
        let without = scratch_with_patch_dims(num_stages, hydro_count, max_par_order, 0, 0);
        let zero_with_k = scratch_with_patch_dims(num_stages, hydro_count, max_par_order, 0, 2);

        assert_eq!(
            zero_with_k.patch_buf.indices.len(),
            without.patch_buf.indices.len(),
            "zero-anticipated patch_buf must match the pre-anticipated layout",
        );
        assert_eq!(
            zero_with_k.patch_buf.forward_patch_count(),
            without.patch_buf.forward_patch_count(),
            "zero-anticipated forward_patch_count must match the pre-anticipated layout",
        );
        assert_eq!(
            without.patch_buf.lower.len(),
            without.patch_buf.indices.len(),
            "patch_buf lower must be sized like indices",
        );
        assert_eq!(
            without.patch_buf.upper.len(),
            without.patch_buf.indices.len(),
            "patch_buf upper must be sized like indices",
        );
    }
}