use super::pool::CutPool;
/// Collection of per-stage [`CutPool`]s forming a future cost function.
///
/// Stage `s` is served by `pools[s]`. When built with `FutureCostFunction::new`, every pool
/// is created with the same `state_dimension` and `forward_passes` stored here.
#[derive(Debug, Clone)]
pub struct FutureCostFunction {
    /// One [`CutPool`] per stage, indexed by stage number.
    pub pools: Vec<CutPool>,
    /// State-space dimension passed to every pool at construction.
    pub state_dimension: usize,
    /// Forward passes per iteration, used (together with the iteration limit) to size
    /// each pool at construction.
    pub forward_passes: u32,
}
impl FutureCostFunction {
#[must_use]
pub fn new(
num_stages: usize,
state_dimension: usize,
forward_passes: u32,
max_iterations: u64,
warm_start_count: u32,
) -> Self {
#[allow(clippy::cast_possible_truncation)]
let capacity: usize =
(u64::from(warm_start_count) + max_iterations * u64::from(forward_passes)) as usize;
let pools = (0..num_stages)
.map(|_| CutPool::new(capacity, state_dimension, forward_passes, warm_start_count))
.collect();
Self {
pools,
state_dimension,
forward_passes,
}
}
pub fn add_cut(
&mut self,
stage: usize,
iteration: u64,
forward_pass_index: u32,
intercept: f64,
coefficients: &[f64],
) {
debug_assert!(
stage < self.pools.len(),
"stage index {stage} is out of bounds (num_stages = {})",
self.pools.len()
);
self.pools[stage].add_cut(iteration, forward_pass_index, intercept, coefficients);
}
pub fn active_cuts(&self, stage: usize) -> impl Iterator<Item = (usize, f64, &[f64])> {
debug_assert!(
stage < self.pools.len(),
"stage index {stage} is out of bounds (num_stages = {})",
self.pools.len()
);
self.pools[stage].active_cuts()
}
#[must_use]
pub fn evaluate_at_state(&self, stage: usize, values: &[f64]) -> f64 {
debug_assert!(
stage < self.pools.len(),
"stage index {stage} is out of bounds (num_stages = {})",
self.pools.len()
);
self.pools[stage].evaluate_at_state(values)
}
#[must_use]
pub fn total_active_cuts(&self) -> usize {
self.pools.iter().map(CutPool::active_count).sum()
}
pub fn deactivate(&mut self, stage: usize, indices: &[u32]) {
debug_assert!(
stage < self.pools.len(),
"stage index {stage} is out of bounds (num_stages = {})",
self.pools.len()
);
self.pools[stage].deactivate(indices);
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `FutureCostFunction`: construction and pool sizing, per-stage
    // isolation, delegation to the stage pools, and cross-stage aggregation.
    use super::FutureCostFunction;

    #[test]
    fn new_creates_correct_number_of_pools() {
        let fcf = FutureCostFunction::new(5, 9, 10, 100, 0);
        assert_eq!(fcf.pools.len(), 5);
    }

    #[test]
    fn new_each_pool_has_correct_capacity_no_warmstart() {
        let fcf = FutureCostFunction::new(5, 9, 10, 100, 0);
        for pool in &fcf.pools {
            assert_eq!(pool.capacity, 1000);
            assert_eq!(pool.state_dimension, 9);
            assert_eq!(pool.forward_passes, 10);
            assert_eq!(pool.warm_start_count, 0);
        }
    }

    #[test]
    fn new_each_pool_has_correct_capacity_with_warmstart() {
        let fcf = FutureCostFunction::new(3, 4, 10, 100, 5);
        for pool in &fcf.pools {
            assert_eq!(pool.capacity, 1005);
            assert_eq!(pool.warm_start_count, 5);
        }
    }

    #[test]
    fn new_all_pools_start_with_zero_active_cuts() {
        let fcf = FutureCostFunction::new(4, 3, 5, 20, 0);
        assert_eq!(fcf.total_active_cuts(), 0);
    }

    #[test]
    fn new_zero_stages_is_valid() {
        let fcf = FutureCostFunction::new(0, 4, 5, 10, 0);
        assert_eq!(fcf.pools.len(), 0);
        assert_eq!(fcf.total_active_cuts(), 0);
    }

    #[test]
    fn add_cut_and_active_cuts_round_trip_at_specific_stage() {
        let mut fcf = FutureCostFunction::new(5, 2, 1, 10, 0);
        let coeffs = [3.0, 7.0];
        fcf.add_cut(2, 0, 0, 42.0, &coeffs);
        let active: Vec<_> = fcf.active_cuts(2).collect();
        assert_eq!(active.len(), 1);
        let (_, intercept, c) = active[0];
        assert_eq!(intercept, 42.0);
        assert_eq!(c, &[3.0, 7.0]);
    }

    #[test]
    fn active_cuts_at_other_stage_returns_empty() {
        let mut fcf = FutureCostFunction::new(5, 2, 1, 10, 0);
        fcf.add_cut(2, 0, 0, 42.0, &[1.0, 2.0]);
        let active: Vec<_> = fcf.active_cuts(3).collect();
        assert!(active.is_empty());
    }

    #[test]
    fn add_cut_multiple_stages_are_independent() {
        let mut fcf = FutureCostFunction::new(4, 1, 1, 10, 0);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(1, 0, 0, 2.0, &[2.0]);
        fcf.add_cut(3, 0, 0, 4.0, &[4.0]);
        assert_eq!(fcf.active_cuts(0).count(), 1);
        assert_eq!(fcf.active_cuts(1).count(), 1);
        assert_eq!(fcf.active_cuts(2).count(), 0);
        assert_eq!(fcf.active_cuts(3).count(), 1);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn add_cut_at_out_of_range_stage_panics() {
        let mut fcf = FutureCostFunction::new(3, 1, 1, 10, 0);
        fcf.add_cut(3, 0, 0, 1.0, &[1.0]);
    }

    #[test]
    fn evaluate_at_state_delegates_to_correct_pool() {
        let mut fcf = FutureCostFunction::new(3, 2, 1, 10, 0);
        fcf.add_cut(1, 0, 0, 10.0, &[1.0, 0.0]);
        fcf.add_cut(2, 0, 0, 5.0, &[0.0, 2.0]);
        assert_eq!(fcf.evaluate_at_state(1, &[3.0, 4.0]), 13.0);
        assert_eq!(fcf.evaluate_at_state(2, &[3.0, 4.0]), 13.0);
        assert_eq!(fcf.evaluate_at_state(0, &[3.0, 4.0]), f64::NEG_INFINITY);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn evaluate_at_state_at_out_of_range_stage_panics() {
        let fcf = FutureCostFunction::new(2, 1, 1, 10, 0);
        let _ = fcf.evaluate_at_state(2, &[0.0]);
    }

    #[test]
    fn total_active_cuts_sums_across_stages() {
        let mut fcf = FutureCostFunction::new(4, 1, 1, 20, 0);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(1, 0, 0, 2.0, &[2.0]);
        fcf.add_cut(1, 1, 0, 3.0, &[3.0]);
        fcf.add_cut(3, 0, 0, 4.0, &[4.0]);
        assert_eq!(fcf.total_active_cuts(), 4);
    }

    #[test]
    fn total_active_cuts_reflects_deactivation() {
        let mut fcf = FutureCostFunction::new(2, 1, 1, 10, 0);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(0, 1, 0, 2.0, &[2.0]);
        fcf.add_cut(1, 0, 0, 3.0, &[3.0]);
        assert_eq!(fcf.total_active_cuts(), 3);
        fcf.deactivate(0, &[0]);
        assert_eq!(fcf.total_active_cuts(), 2);
    }

    #[test]
    fn deactivate_delegates_to_correct_pool() {
        let mut fcf = FutureCostFunction::new(3, 1, 1, 10, 0);
        fcf.add_cut(1, 0, 0, 10.0, &[1.0]);
        fcf.add_cut(1, 1, 0, 20.0, &[2.0]);
        fcf.add_cut(2, 0, 0, 30.0, &[3.0]);
        fcf.deactivate(1, &[0]);
        assert_eq!(fcf.active_cuts(1).count(), 1);
        assert_eq!(fcf.active_cuts(2).count(), 1);
    }

    // `ac_`-prefixed tests below largely overlap the unit tests above.

    #[test]
    fn ac_new_5_stages_pools_len_is_5() {
        let fcf = FutureCostFunction::new(5, 9, 10, 100, 0);
        assert_eq!(fcf.pools.len(), 5);
    }

    #[test]
    fn ac_active_cuts_at_stage_with_cut_yields_it() {
        let mut fcf = FutureCostFunction::new(5, 3, 1, 10, 0);
        let coeffs = [1.0, 2.0, 3.0];
        fcf.add_cut(2, 0, 0, 99.0, &coeffs);
        let active: Vec<_> = fcf.active_cuts(2).collect();
        assert_eq!(active.len(), 1);
    }

    #[test]
    fn ac_active_cuts_at_different_stage_yields_none() {
        let mut fcf = FutureCostFunction::new(5, 3, 1, 10, 0);
        fcf.add_cut(2, 0, 0, 99.0, &[1.0, 2.0, 3.0]);
        let active: Vec<_> = fcf.active_cuts(3).collect();
        assert!(active.is_empty());
    }

    #[test]
    fn ac_total_active_cuts_is_sum_across_stages() {
        let mut fcf = FutureCostFunction::new(5, 1, 1, 10, 0);
        fcf.add_cut(0, 0, 0, 1.0, &[1.0]);
        fcf.add_cut(1, 0, 0, 2.0, &[2.0]);
        fcf.add_cut(1, 1, 0, 3.0, &[3.0]);
        fcf.add_cut(4, 0, 0, 4.0, &[4.0]);
        assert_eq!(fcf.total_active_cuts(), 4);
    }

    #[test]
    fn fcf_derives_debug_and_clone() {
        let mut fcf = FutureCostFunction::new(2, 2, 1, 5, 0);
        fcf.add_cut(0, 0, 0, 7.0, &[1.0, 2.0]);
        let cloned = fcf.clone();
        assert_eq!(cloned.total_active_cuts(), 1);
        assert_eq!(cloned.evaluate_at_state(0, &[0.0, 0.0]), 7.0);
        let debug_str = format!("{fcf:?}");
        assert!(!debug_str.is_empty());
    }
}