use super::pool::CutPool;
/// Future cost function approximation stored as one [`CutPool`] per stage.
///
/// Per-stage methods on this type index `pools` by stage and forward the call
/// to that stage's pool.
#[derive(Debug, Clone)]
pub struct FutureCostFunction {
    /// One cut pool per stage, indexed by stage.
    pub pools: Vec<CutPool>,
    /// Number of state variables shared by every stage's pool (presumably the
    /// length of each cut's coefficient vector — confirm against `CutPool`).
    pub state_dimension: usize,
    /// Forward passes per iteration that the pools were sized for; `0` for a
    /// function rebuilt with `from_deserialized`.
    pub forward_passes: u32,
}
impl FutureCostFunction {
    /// Builds an empty future cost function with one [`CutPool`] per stage.
    ///
    /// The pool for stage `t` is created with capacity
    /// `warm_start_counts[t] + max_iterations * forward_passes`: the reserved
    /// warm-start slots plus room for `max_iterations * forward_passes` cuts.
    ///
    /// # Arguments
    ///
    /// * `num_stages` - expected number of stages; only checked against
    ///   `warm_start_counts.len()` in debug builds.
    /// * `state_dimension` - forwarded to every pool.
    /// * `forward_passes` - forward passes per iteration; used for sizing and
    ///   forwarded to every pool.
    /// * `max_iterations` - iteration budget used for sizing.
    /// * `warm_start_counts` - per-stage number of warm-start slots; its length
    ///   determines how many pools are created.
    ///
    /// # Panics
    ///
    /// * In debug builds, if `warm_start_counts.len() != num_stages`.
    /// * If a stage's capacity overflows `u64` arithmetic or does not fit in
    ///   `usize` on the target platform.
    #[must_use]
    pub fn new(
        num_stages: usize,
        state_dimension: usize,
        forward_passes: u32,
        max_iterations: u64,
        warm_start_counts: &[u32],
    ) -> Self {
        debug_assert_eq!(
            warm_start_counts.len(),
            num_stages,
            "warm_start_counts.len() ({}) != num_stages ({})",
            warm_start_counts.len(),
            num_stages
        );
        // Identical for every stage. `None` records an overflow, which is only
        // reported if at least one pool actually needs the value.
        let training_slots = max_iterations.checked_mul(u64::from(forward_passes));
        let pools = warm_start_counts
            .iter()
            .map(|&wsc| {
                // Checked arithmetic and `try_from` instead of a wrapping `as`
                // cast, which could silently yield a smaller capacity than the
                // formula asks for (e.g. on 32-bit targets).
                let capacity = training_slots
                    .and_then(|slots| slots.checked_add(u64::from(wsc)))
                    .and_then(|total| usize::try_from(total).ok())
                    .unwrap_or_else(|| {
                        panic!(
                            "cut pool capacity overflow: warm_start_count = {wsc}, \
                             max_iterations = {max_iterations}, \
                             forward_passes = {forward_passes}"
                        )
                    });
                CutPool::new(capacity, state_dimension, forward_passes, wsc)
            })
            .collect();
        Self {
            pools,
            state_dimension,
            forward_passes,
        }
    }

    /// Rebuilds a future cost function from per-stage cut records read back
    /// through `cobre_io`, creating one pool per entry of `stage_results` in
    /// order.
    ///
    /// Each pool is built with [`CutPool::from_deserialized`], which keeps each
    /// record's active flag. The returned value has `forward_passes == 0`.
    ///
    /// # Errors
    ///
    /// Returns [`crate::SddpError::Validation`] if `stage_results` is empty or
    /// if the stages disagree on `state_dimension`.
    pub fn from_deserialized(
        stage_results: &[cobre_io::StageCutsReadResult],
    ) -> Result<Self, crate::SddpError> {
        let state_dimension = Self::common_state_dimension(stage_results, "from_deserialized")?;
        let pools = stage_results
            .iter()
            .map(|sr| CutPool::from_deserialized(state_dimension, &sr.cuts))
            .collect();
        Ok(Self {
            pools,
            state_dimension,
            // The input carries no forward-pass count, so none is recorded.
            // NOTE(review): presumably marks an evaluation-only FCF — confirm.
            forward_passes: 0,
        })
    }

    /// Builds a future cost function seeded with previously saved cuts and
    /// sized for further training.
    ///
    /// Each stage's pool is built with [`CutPool::new_with_warm_start`] from that
    /// stage's records, `forward_passes` and `max_iterations`. The loaded
    /// records (active or not) become the pool's warm-start cuts.
    ///
    /// # Errors
    ///
    /// Returns [`crate::SddpError::Validation`] if `stage_results` is empty or
    /// if the stages disagree on `state_dimension`.
    pub fn new_with_warm_start(
        stage_results: &[cobre_io::StageCutsReadResult],
        forward_passes: u32,
        max_iterations: u64,
    ) -> Result<Self, crate::SddpError> {
        let state_dimension =
            Self::common_state_dimension(stage_results, "new_with_warm_start")?;
        let pools = stage_results
            .iter()
            .map(|sr| {
                CutPool::new_with_warm_start(
                    state_dimension,
                    forward_passes,
                    max_iterations,
                    &sr.cuts,
                )
            })
            .collect();
        Ok(Self {
            pools,
            state_dimension,
            forward_passes,
        })
    }

    /// Returns the `state_dimension` shared by every entry of `stage_results`.
    ///
    /// `caller` prefixes the error message, so the message names the public
    /// constructor that rejected the input.
    ///
    /// # Errors
    ///
    /// [`crate::SddpError::Validation`] if `stage_results` is empty, or if an
    /// entry's `state_dimension` differs from the first entry's.
    fn common_state_dimension(
        stage_results: &[cobre_io::StageCutsReadResult],
        caller: &str,
    ) -> Result<usize, crate::SddpError> {
        let (first, rest) = stage_results.split_first().ok_or_else(|| {
            crate::SddpError::Validation(format!("{caller}: stage_results is empty"))
        })?;
        let state_dimension = first.state_dimension as usize;
        match rest
            .iter()
            .find(|sr| sr.state_dimension as usize != state_dimension)
        {
            Some(mismatch) => Err(crate::SddpError::Validation(format!(
                "{caller}: inconsistent state_dimension: stage {} has {}, \
                 expected {} (from stage {})",
                mismatch.stage_id, mismatch.state_dimension, state_dimension, first.stage_id
            ))),
            None => Ok(state_dimension),
        }
    }

    /// Adds a cut with the given `intercept` and `coefficients` to the pool of
    /// `stage`.
    ///
    /// `iteration` and `forward_pass_index` are forwarded to [`CutPool::add_cut`]
    /// (presumably so the pool can place the cut after its warm-start slots —
    /// see the pool for the exact slot rule).
    ///
    /// # Panics
    ///
    /// If `stage` is not a valid stage index.
    pub fn add_cut(
        &mut self,
        stage: usize,
        iteration: u64,
        forward_pass_index: u32,
        intercept: f64,
        coefficients: &[f64],
    ) {
        debug_assert!(
            stage < self.pools.len(),
            "stage index {stage} is out of bounds (num_stages = {})",
            self.pools.len()
        );
        self.pools[stage].add_cut(iteration, forward_pass_index, intercept, coefficients);
    }

    /// Iterates over the active cuts of `stage` as
    /// `(index, intercept, coefficients)` tuples.
    ///
    /// NOTE(review): `index` is presumably the cut's slot in the pool — confirm
    /// against [`CutPool::active_cuts`].
    ///
    /// # Panics
    ///
    /// If `stage` is not a valid stage index.
    pub fn active_cuts(&self, stage: usize) -> impl Iterator<Item = (usize, f64, &[f64])> {
        debug_assert!(
            stage < self.pools.len(),
            "stage index {stage} is out of bounds (num_stages = {})",
            self.pools.len()
        );
        self.pools[stage].active_cuts()
    }

    /// Evaluates the cut approximation of `stage` at the state `values`.
    ///
    /// Delegates to [`CutPool::evaluate_at_state`]. A stage without cuts yields
    /// `f64::NEG_INFINITY`; otherwise the result is presumably the maximum of
    /// `intercept + coefficients · values` over active cuts.
    ///
    /// # Panics
    ///
    /// If `stage` is not a valid stage index.
    #[must_use]
    pub fn evaluate_at_state(&self, stage: usize, values: &[f64]) -> f64 {
        debug_assert!(
            stage < self.pools.len(),
            "stage index {stage} is out of bounds (num_stages = {})",
            self.pools.len()
        );
        self.pools[stage].evaluate_at_state(values)
    }

    /// Total number of active cuts summed over all stages.
    #[must_use]
    pub fn total_active_cuts(&self) -> usize {
        self.pools.iter().map(CutPool::active_count).sum()
    }

    /// Deactivates the cuts at `indices` in the pool of `stage`.
    ///
    /// NOTE(review): `indices` are presumably pool slot indices — confirm
    /// against [`CutPool::deactivate`].
    ///
    /// # Panics
    ///
    /// If `stage` is not a valid stage index.
    pub fn deactivate(&mut self, stage: usize, indices: &[u32]) {
        debug_assert!(
            stage < self.pools.len(),
            "stage index {stage} is out of bounds (num_stages = {})",
            self.pools.len()
        );
        self.pools[stage].deactivate(indices);
    }

    /// Returns one sparsity report per stage, in stage order.
    #[must_use]
    pub fn sparsity_reports(&self) -> Vec<super::pool::SparsityReport> {
        self.pools.iter().map(CutPool::sparsity_report).collect()
    }
}
#[cfg(test)]
mod tests {
    use super::FutureCostFunction;

    // ----- `new`: pool count and sizing -----------------------------------

    #[test]
    fn new_creates_correct_number_of_pools() {
        let fcf = FutureCostFunction::new(5, 9, 10, 100, &[0; 5]);
        assert_eq!(fcf.pools.len(), 5);
    }

    #[test]
    fn new_each_pool_has_correct_capacity_no_warmstart() {
        // capacity = 0 warm-start slots + 100 iterations * 10 forward passes
        let fcf = FutureCostFunction::new(5, 9, 10, 100, &[0; 5]);
        for pool in &fcf.pools {
            assert_eq!(pool.capacity, 1000);
            assert_eq!(pool.state_dimension, 9);
            assert_eq!(pool.forward_passes, 10);
            assert_eq!(pool.warm_start_count, 0);
        }
    }

    #[test]
    fn new_each_pool_has_correct_capacity_with_warmstart() {
        // capacity = 5 warm-start slots + 100 * 10
        let fcf = FutureCostFunction::new(3, 4, 10, 100, &[5; 3]);
        for pool in &fcf.pools {
            assert_eq!(pool.capacity, 1005);
            assert_eq!(pool.warm_start_count, 5);
        }
    }

    #[test]
    fn new_all_pools_start_with_zero_active_cuts() {
        let fcf = FutureCostFunction::new(4, 3, 5, 20, &[0; 4]);
        assert_eq!(fcf.total_active_cuts(), 0);
    }

    #[test]
    fn new_zero_stages_is_valid() {
        let fcf = FutureCostFunction::new(0, 4, 5, 10, &[]);
        assert_eq!(fcf.pools.len(), 0);
        assert_eq!(fcf.total_active_cuts(), 0);
    }

    #[test]
    fn new_non_uniform_warm_start_counts_per_stage_capacity() {
        // Training slots: 10 iterations * 2 forward passes = 20 per stage.
        let fcf = FutureCostFunction::new(3, 4, 2, 10, &[5, 3, 0]);
        assert_eq!(fcf.pools[0].capacity, 25);
        assert_eq!(fcf.pools[0].warm_start_count, 5);
        assert_eq!(fcf.pools[1].capacity, 23);
        assert_eq!(fcf.pools[1].warm_start_count, 3);
        assert_eq!(fcf.pools[2].capacity, 20);
        assert_eq!(fcf.pools[2].warm_start_count, 0);
    }

    #[test]
    fn new_uniform_zero_counts_matches_old_scalar_zero_behavior() {
        let fcf = FutureCostFunction::new(3, 4, 2, 10, &[0, 0, 0]);
        for pool in &fcf.pools {
            assert_eq!(pool.capacity, 20);
            assert_eq!(pool.warm_start_count, 0);
        }
    }

    // The length check in `new` is a `debug_assert`, so it only fires in
    // debug builds.
    #[cfg(debug_assertions)]
    #[test]
    #[should_panic(expected = "warm_start_counts.len()")]
    fn new_mismatched_length_panics_in_debug() {
        let _ = FutureCostFunction::new(3, 4, 2, 10, &[0, 0]);
    }

    // ----- cut insertion, lookup, evaluation, deactivation ----------------

    #[test]
    fn add_cut_and_active_cuts_round_trip_at_specific_stage() {
        let mut fcf = FutureCostFunction::new(5, 2, 1, 10, &[0; 5]);
        let coeffs = [3.0, 7.0];
        fcf.add_cut(2, 0, 0, 42.0, &coeffs);
        let active: Vec<_> = fcf.active_cuts(2).collect();
        assert_eq!(active.len(), 1);
        let (_, intercept, c) = active[0];
        assert_eq!(intercept, 42.0);
        assert_eq!(c, &[3.0, 7.0]);
    }

    #[test]
    fn active_cuts_at_other_stage_returns_empty() {
        let mut fcf = FutureCostFunction::new(5, 2, 1, 10, &[0; 5]);
        fcf.add_cut(2, 0, 0, 42.0, &[1.0, 2.0]);
        let active: Vec<_> = fcf.active_cuts(3).collect();
        assert!(active.is_empty());
    }

    #[test]
    fn add_cut_multiple_stages_are_independent() {
        let mut fcf = FutureCostFunction::new(4, 1, 1, 10, &[0; 4]);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(1, 0, 0, 2.0, &[2.0]);
        fcf.add_cut(3, 0, 0, 4.0, &[4.0]);
        assert_eq!(fcf.active_cuts(0).count(), 1);
        assert_eq!(fcf.active_cuts(1).count(), 1);
        assert_eq!(fcf.active_cuts(2).count(), 0);
        assert_eq!(fcf.active_cuts(3).count(), 1);
    }

    #[test]
    fn evaluate_at_state_delegates_to_correct_pool() {
        let mut fcf = FutureCostFunction::new(3, 2, 1, 10, &[0; 3]);
        fcf.add_cut(1, 0, 0, 10.0, &[1.0, 0.0]);
        fcf.add_cut(2, 0, 0, 5.0, &[0.0, 2.0]);
        // 10 + 1*3 + 0*4 = 13 and 5 + 0*3 + 2*4 = 13.
        assert_eq!(fcf.evaluate_at_state(1, &[3.0, 4.0]), 13.0);
        assert_eq!(fcf.evaluate_at_state(2, &[3.0, 4.0]), 13.0);
        // Stage 0 has no cuts.
        assert_eq!(fcf.evaluate_at_state(0, &[3.0, 4.0]), f64::NEG_INFINITY);
    }

    #[test]
    fn total_active_cuts_sums_across_stages() {
        let mut fcf = FutureCostFunction::new(4, 1, 1, 20, &[0; 4]);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(1, 0, 0, 2.0, &[2.0]);
        fcf.add_cut(1, 1, 0, 3.0, &[3.0]);
        fcf.add_cut(3, 0, 0, 4.0, &[4.0]);
        assert_eq!(fcf.total_active_cuts(), 4);
    }

    #[test]
    fn total_active_cuts_reflects_deactivation() {
        let mut fcf = FutureCostFunction::new(2, 1, 1, 10, &[0; 2]);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(0, 1, 0, 2.0, &[2.0]);
        fcf.add_cut(1, 0, 0, 3.0, &[3.0]);
        assert_eq!(fcf.total_active_cuts(), 3);
        fcf.deactivate(0, &[0]);
        assert_eq!(fcf.total_active_cuts(), 2);
    }

    #[test]
    fn deactivate_delegates_to_correct_pool() {
        let mut fcf = FutureCostFunction::new(3, 1, 1, 10, &[0; 3]);
        fcf.add_cut(1, 0, 0, 10.0, &[1.0]);
        fcf.add_cut(1, 1, 0, 20.0, &[2.0]);
        fcf.add_cut(2, 0, 0, 30.0, &[3.0]);
        fcf.deactivate(1, &[0]);
        assert_eq!(fcf.active_cuts(1).count(), 1);
        assert_eq!(fcf.active_cuts(2).count(), 1);
    }

    // ----- `ac_*` checks (overlap with the tests above) -------------------

    #[test]
    fn ac_new_5_stages_pools_len_is_5() {
        let fcf = FutureCostFunction::new(5, 9, 10, 100, &[0; 5]);
        assert_eq!(fcf.pools.len(), 5);
    }

    #[test]
    fn ac_active_cuts_at_stage_with_cut_yields_it() {
        let mut fcf = FutureCostFunction::new(5, 3, 1, 10, &[0; 5]);
        let coeffs = [1.0, 2.0, 3.0];
        fcf.add_cut(2, 0, 0, 99.0, &coeffs);
        let active: Vec<_> = fcf.active_cuts(2).collect();
        assert_eq!(active.len(), 1);
    }

    #[test]
    fn ac_active_cuts_at_different_stage_yields_none() {
        let mut fcf = FutureCostFunction::new(5, 3, 1, 10, &[0; 5]);
        fcf.add_cut(2, 0, 0, 99.0, &[1.0, 2.0, 3.0]);
        let active: Vec<_> = fcf.active_cuts(3).collect();
        assert!(active.is_empty());
    }

    #[test]
    fn ac_total_active_cuts_is_sum_across_stages() {
        let mut fcf = FutureCostFunction::new(5, 1, 1, 10, &[0; 5]);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(1, 0, 0, 2.0, &[2.0]);
        fcf.add_cut(1, 1, 0, 3.0, &[3.0]);
        fcf.add_cut(4, 0, 0, 4.0, &[4.0]);
        assert_eq!(fcf.total_active_cuts(), 4);
    }

    #[test]
    fn fcf_derives_debug_and_clone() {
        let mut fcf = FutureCostFunction::new(2, 2, 1, 5, &[0; 2]);
        fcf.add_cut(0, 0, 0, 7.0, &[1.0, 2.0]);
        let cloned = fcf.clone();
        assert_eq!(cloned.total_active_cuts(), 1);
        assert_eq!(cloned.evaluate_at_state(0, &[0.0, 0.0]), 7.0);
        let debug_str = format!("{fcf:?}");
        assert!(!debug_str.is_empty());
    }

    // ----- fixtures for deserialized input --------------------------------

    /// Builds a cut record with zeroed id/slot/iteration metadata.
    fn make_record(
        intercept: f64,
        coefficients: Vec<f64>,
        is_active: bool,
    ) -> cobre_io::OwnedPolicyCutRecord {
        cobre_io::OwnedPolicyCutRecord {
            cut_id: 0,
            slot_index: 0,
            iteration: 0,
            forward_pass_index: 0,
            intercept,
            coefficients,
            is_active,
        }
    }

    /// Builds a stage result whose capacity and populated count both equal
    /// `cuts.len()`, with no warm-start slots.
    fn make_stage(
        stage_id: u32,
        state_dimension: u32,
        cuts: Vec<cobre_io::OwnedPolicyCutRecord>,
    ) -> cobre_io::StageCutsReadResult {
        let populated_count = u32::try_from(cuts.len()).expect("cuts count fits in u32");
        cobre_io::StageCutsReadResult {
            stage_id,
            state_dimension,
            capacity: populated_count,
            warm_start_count: 0,
            populated_count,
            cuts,
        }
    }

    // ----- `from_deserialized` --------------------------------------------

    #[test]
    fn from_deserialized_empty_input_returns_err() {
        let result = FutureCostFunction::from_deserialized(&[]);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("empty"), "{msg}");
    }

    #[test]
    fn from_deserialized_inconsistent_dimensions_returns_err() {
        let stages = vec![
            make_stage(0, 2, vec![make_record(1.0, vec![1.0, 0.0], true)]),
            make_stage(1, 3, vec![make_record(2.0, vec![1.0, 0.0, 0.0], true)]),
        ];
        let result = FutureCostFunction::from_deserialized(&stages);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("inconsistent"), "{msg}");
    }

    #[test]
    fn from_deserialized_preserves_active_flags() {
        let stages = vec![make_stage(
            0,
            2,
            vec![
                make_record(1.0, vec![1.0, 0.0], true),
                make_record(2.0, vec![0.0, 1.0], false),
                make_record(3.0, vec![1.0, 1.0], true),
            ],
        )];
        let fcf = FutureCostFunction::from_deserialized(&stages).unwrap();
        assert_eq!(fcf.pools.len(), 1);
        assert_eq!(fcf.total_active_cuts(), 2);
        assert_eq!(fcf.pools[0].populated_count, 3);
    }

    #[test]
    fn from_deserialized_evaluate_at_state_matches_original() {
        let mut original = FutureCostFunction::new(2, 2, 1, 10, &[0; 2]);
        original.add_cut(0, 0, 0, 10.0, &[1.0, 0.0]);
        original.add_cut(0, 1, 0, 5.0, &[0.0, 2.0]);
        original.add_cut(1, 0, 0, 3.0, &[1.0, 1.0]);
        let state = [3.0, 4.0];
        let orig_val_s0 = original.evaluate_at_state(0, &state);
        let orig_val_s1 = original.evaluate_at_state(1, &state);

        let stages = vec![
            make_stage(
                0,
                2,
                vec![
                    make_record(10.0, vec![1.0, 0.0], true),
                    make_record(5.0, vec![0.0, 2.0], true),
                ],
            ),
            make_stage(1, 2, vec![make_record(3.0, vec![1.0, 1.0], true)]),
        ];
        let reconstructed = FutureCostFunction::from_deserialized(&stages).unwrap();
        assert_eq!(reconstructed.evaluate_at_state(0, &state), orig_val_s0);
        assert_eq!(reconstructed.evaluate_at_state(1, &state), orig_val_s1);
    }

    #[test]
    fn from_deserialized_empty_stage_is_valid() {
        let stages = vec![
            make_stage(0, 2, vec![make_record(1.0, vec![1.0, 0.0], true)]),
            make_stage(1, 2, vec![]),
        ];
        let fcf = FutureCostFunction::from_deserialized(&stages).unwrap();
        assert_eq!(fcf.pools.len(), 2);
        assert_eq!(fcf.pools[1].capacity, 0);
        assert_eq!(fcf.pools[1].active_count(), 0);
        assert_eq!(fcf.evaluate_at_state(1, &[1.0, 1.0]), f64::NEG_INFINITY);
    }

    #[test]
    fn from_deserialized_single_cut_stage() {
        let stages = vec![make_stage(
            0,
            3,
            vec![make_record(7.0, vec![1.0, 2.0, 3.0], true)],
        )];
        let fcf = FutureCostFunction::from_deserialized(&stages).unwrap();
        assert_eq!(fcf.state_dimension, 3);
        assert_eq!(fcf.total_active_cuts(), 1);
        // 7 + 1*1 + 2*2 + 3*3 = 21
        assert_eq!(fcf.evaluate_at_state(0, &[1.0, 2.0, 3.0]), 21.0);
    }

    // ----- `new_with_warm_start` ------------------------------------------

    #[test]
    fn warm_start_capacity_includes_training_slots() {
        let stages = vec![make_stage(
            0,
            2,
            vec![
                make_record(1.0, vec![1.0, 0.0], true),
                make_record(2.0, vec![0.0, 1.0], true),
            ],
        )];
        // capacity = 2 loaded cuts + 10 iterations * 4 forward passes
        let fcf = FutureCostFunction::new_with_warm_start(&stages, 4, 10).unwrap();
        assert_eq!(fcf.pools.len(), 1);
        assert_eq!(fcf.pools[0].capacity, 42);
        assert_eq!(fcf.pools[0].warm_start_count, 2);
        assert_eq!(fcf.pools[0].forward_passes, 4);
        assert_eq!(fcf.pools[0].populated_count, 2);
        assert_eq!(fcf.total_active_cuts(), 2);
    }

    #[test]
    fn warm_start_training_cuts_at_correct_offset() {
        let stages = vec![make_stage(0, 1, vec![make_record(10.0, vec![1.0], true)])];
        let mut fcf = FutureCostFunction::new_with_warm_start(&stages, 2, 5).unwrap();
        fcf.add_cut(0, 0, 0, 20.0, &[2.0]);
        fcf.add_cut(0, 0, 1, 30.0, &[3.0]);
        assert_eq!(fcf.total_active_cuts(), 3);
        assert_eq!(fcf.pools[0].populated_count, 3);
        // The warm-start cut occupies slot 0; training cuts follow it.
        assert_eq!(fcf.pools[0].intercepts[0], 10.0);
        assert_eq!(fcf.pools[0].intercepts[1], 20.0);
        assert_eq!(fcf.pools[0].intercepts[2], 30.0);
    }

    #[test]
    fn warm_start_empty_stage_has_training_capacity() {
        let stages = vec![
            make_stage(0, 2, vec![make_record(1.0, vec![1.0, 0.0], true)]),
            make_stage(1, 2, vec![]),
        ];
        let fcf = FutureCostFunction::new_with_warm_start(&stages, 3, 5).unwrap();
        assert_eq!(fcf.pools[0].capacity, 16);
        assert_eq!(fcf.pools[0].warm_start_count, 1);
        assert_eq!(fcf.pools[1].capacity, 15);
        assert_eq!(fcf.pools[1].warm_start_count, 0);
    }

    #[test]
    fn warm_start_preserves_inactive_flags() {
        let stages = vec![make_stage(
            0,
            2,
            vec![
                make_record(1.0, vec![1.0, 0.0], true),
                make_record(2.0, vec![0.0, 1.0], false),
            ],
        )];
        let fcf = FutureCostFunction::new_with_warm_start(&stages, 1, 5).unwrap();
        assert_eq!(fcf.pools[0].warm_start_count, 2);
        assert_eq!(fcf.total_active_cuts(), 1);
    }

    #[test]
    fn new_with_warm_start_empty_input_returns_err() {
        let result = FutureCostFunction::new_with_warm_start(&[], 1, 5);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("empty"), "{msg}");
    }

    #[test]
    fn new_with_warm_start_inconsistent_dimensions_returns_err() {
        let stages = vec![
            make_stage(0, 2, vec![make_record(1.0, vec![1.0, 0.0], true)]),
            make_stage(1, 3, vec![make_record(2.0, vec![1.0, 0.0, 0.0], true)]),
        ];
        let result = FutureCostFunction::new_with_warm_start(&stages, 1, 5);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("inconsistent"), "{msg}");
    }
}