use std::convert::TryFrom;

use crate::cut_selection::CutMetadata;
/// Fixed-capacity store of affine cuts addressed by deterministic slots.
///
/// Each slot holds one cut — the affine function `intercept + coefficients · state` —
/// together with its [`CutMetadata`] and an active flag. [`CutPool::new`] allocates
/// every per-slot vector with `capacity` entries up front; the methods in this file
/// never resize them, they only overwrite a slot (`add_cut`) or clear its active
/// flag (`deactivate`).
///
/// The slot for a cut generated at `iteration` by forward pass `forward_pass_index`
/// is `warm_start_count + iteration * forward_passes + forward_pass_index`, so for
/// `forward_pass_index < forward_passes` every `(iteration, forward_pass_index)`
/// pair maps to a distinct slot.
///
/// NOTE(review): all fields are `pub`; code that resizes or edits these vectors
/// directly can break the "every per-slot vector has `capacity` entries"
/// invariant the methods rely on.
#[derive(Debug, Clone)]
pub struct CutPool {
/// Per-slot cut coefficients; each inner vector has `state_dimension` entries.
pub coefficients: Vec<Vec<f64>>,
/// Per-slot cut intercepts.
pub intercepts: Vec<f64>,
/// Per-slot bookkeeping, (re)initialised by `add_cut`.
pub metadata: Vec<CutMetadata>,
/// Per-slot flag; only active slots are yielded by `active_cuts` and used by
/// `evaluate_at_state`.
pub active: Vec<bool>,
/// High-water mark: one past the highest slot ever written by `add_cut`
/// (0 for a fresh pool). Slots at or above it are never scanned.
pub populated_count: usize,
/// Number of preallocated slots.
pub capacity: usize,
/// Length of every coefficient vector and of states passed to `evaluate_at_state`.
pub state_dimension: usize,
/// Number of consecutive slots reserved for each iteration.
pub forward_passes: u32,
/// Number of leading slots reserved before the slots of iteration 0.
/// NOTE(review): presumably filled by warm-start cuts loaded elsewhere — not visible here.
pub warm_start_count: u32,
}
impl CutPool {
#[must_use]
pub fn new(
capacity: usize,
state_dimension: usize,
forward_passes: u32,
warm_start_count: u32,
) -> Self {
let default_meta = CutMetadata {
iteration_generated: 0,
forward_pass_index: 0,
active_count: 0,
last_active_iter: 0,
domination_count: 0,
};
Self {
coefficients: vec![vec![0.0; state_dimension]; capacity],
intercepts: vec![0.0; capacity],
metadata: vec![default_meta; capacity],
active: vec![false; capacity],
populated_count: 0,
capacity,
state_dimension,
forward_passes,
warm_start_count,
}
}
#[inline]
fn slot_index(&self, iteration: u64, forward_pass_index: u32) -> usize {
#[allow(clippy::cast_possible_truncation)]
let iter_usize = iteration as usize;
self.warm_start_count as usize
+ iter_usize * self.forward_passes as usize
+ forward_pass_index as usize
}
pub fn add_cut(
&mut self,
iteration: u64,
forward_pass_index: u32,
intercept: f64,
coefficients: &[f64],
) {
let slot = self.slot_index(iteration, forward_pass_index);
debug_assert!(
slot < self.capacity,
"cut slot {slot} is out of bounds (capacity = {})",
self.capacity
);
debug_assert!(
coefficients.len() == self.state_dimension,
"coefficients length {} != state_dimension {}",
coefficients.len(),
self.state_dimension
);
self.intercepts[slot] = intercept;
self.coefficients[slot].copy_from_slice(coefficients);
self.active[slot] = true;
self.metadata[slot] = CutMetadata {
iteration_generated: iteration,
forward_pass_index,
active_count: 0,
last_active_iter: iteration,
domination_count: 0,
};
if slot >= self.populated_count {
self.populated_count = slot + 1;
}
}
pub fn active_cuts(&self) -> impl Iterator<Item = (usize, f64, &[f64])> {
self.active[..self.populated_count]
.iter()
.enumerate()
.filter(|&(_, &is_active)| is_active)
.map(|(i, _)| (i, self.intercepts[i], self.coefficients[i].as_slice()))
}
#[must_use]
pub fn active_count(&self) -> usize {
self.active[..self.populated_count]
.iter()
.filter(|&&a| a)
.count()
}
pub fn deactivate(&mut self, indices: &[u32]) {
for &idx in indices {
let i = idx as usize;
debug_assert!(i < self.capacity, "deactivate index {i} out of bounds");
if i < self.capacity {
self.active[i] = false;
}
}
}
#[must_use]
pub fn evaluate_at_state(&self, state: &[f64]) -> f64 {
debug_assert!(
state.len() == self.state_dimension,
"state length {} != state_dimension {}",
state.len(),
self.state_dimension
);
self.active_cuts()
.map(|(_, intercept, coeffs)| {
let dot: f64 = coeffs.iter().zip(state).map(|(a, b)| a * b).sum();
intercept + dot
})
.fold(f64::NEG_INFINITY, f64::max)
}
}
#[cfg(test)]
mod tests {
    // Unit tests for `CutPool`: construction, the deterministic slot formula,
    // activity tracking, evaluation, and the panic contract of `add_cut`.
    use super::CutPool;

    #[test]
    fn new_creates_pool_with_correct_capacity_and_all_inactive() {
        let pool = CutPool::new(100, 9, 10, 0);
        assert_eq!(pool.capacity, 100);
        assert_eq!(pool.state_dimension, 9);
        assert_eq!(pool.forward_passes, 10);
        assert_eq!(pool.warm_start_count, 0);
        assert_eq!(pool.populated_count, 0);
        assert_eq!(pool.active_count(), 0);
        assert!(pool.active.iter().all(|&a| !a));
        assert_eq!(pool.coefficients.len(), 100);
        assert!(
            pool.coefficients
                .iter()
                .all(|c| c.iter().all(|&v| v == 0.0))
        );
        assert!(pool.intercepts.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn new_zero_capacity_is_valid() {
        let pool = CutPool::new(0, 4, 5, 0);
        assert_eq!(pool.capacity, 0);
        assert_eq!(pool.active_count(), 0);
    }

    #[test]
    fn add_cut_at_slot_zero_stores_intercept_coefficients_and_active_flag() {
        let mut pool = CutPool::new(100, 9, 10, 0);
        let coeffs = vec![1.0; 9];
        pool.add_cut(0, 0, 5.0, &coeffs);
        assert_eq!(pool.active_count(), 1);
        assert!(pool.active[0]);
        assert_eq!(pool.intercepts[0], 5.0);
        assert_eq!(pool.coefficients[0], vec![1.0; 9]);
        assert_eq!(pool.metadata[0].iteration_generated, 0);
        assert_eq!(pool.metadata[0].forward_pass_index, 0);
        assert_eq!(pool.populated_count, 1);
    }

    #[test]
    fn add_cut_deterministic_slot_formula_no_warmstart() {
        let mut pool = CutPool::new(200, 2, 10, 0);
        pool.add_cut(0, 0, 1.0, &[1.0, 2.0]);
        pool.add_cut(0, 3, 2.0, &[3.0, 4.0]);
        pool.add_cut(1, 0, 3.0, &[5.0, 6.0]);
        pool.add_cut(2, 5, 4.0, &[7.0, 8.0]);
        assert!(pool.active[0]);
        assert_eq!(pool.intercepts[0], 1.0);
        assert!(pool.active[3]);
        assert_eq!(pool.intercepts[3], 2.0);
        assert!(pool.active[10]);
        assert_eq!(pool.intercepts[10], 3.0);
        assert!(pool.active[25]);
        assert_eq!(pool.intercepts[25], 4.0);
    }

    #[test]
    fn add_cut_warm_start_count_offsets_slot() {
        let mut pool = CutPool::new(100, 9, 10, 5);
        let coeffs = vec![0.0; 9];
        pool.add_cut(0, 0, 42.0, &coeffs);
        assert!(pool.active[5]);
        assert_eq!(pool.intercepts[5], 42.0);
        assert_eq!(pool.populated_count, 6);
    }

    #[test]
    fn add_cut_metadata_initialized_correctly() {
        let mut pool = CutPool::new(50, 3, 5, 0);
        pool.add_cut(3, 2, 7.0, &[1.0, 2.0, 3.0]);
        let meta = &pool.metadata[17];
        assert_eq!(meta.iteration_generated, 3);
        assert_eq!(meta.forward_pass_index, 2);
        assert_eq!(meta.active_count, 0);
        assert_eq!(meta.last_active_iter, 3);
        assert_eq!(meta.domination_count, 0);
    }

    #[test]
    fn add_cut_overwrites_previous_cut_in_same_slot() {
        let mut pool = CutPool::new(10, 1, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        pool.add_cut(0, 0, 2.0, &[3.0]);
        assert_eq!(pool.active_count(), 1);
        assert_eq!(pool.intercepts[0], 2.0);
        assert_eq!(pool.coefficients[0], vec![3.0]);
        assert_eq!(pool.populated_count, 1);
    }

    #[test]
    #[should_panic(expected = "out of bounds")]
    fn add_cut_beyond_capacity_panics() {
        let mut pool = CutPool::new(5, 1, 10, 0);
        // Slot = 0 + 1 * 10 + 0 = 10, past the 5 allocated slots.
        pool.add_cut(1, 0, 1.0, &[1.0]);
    }

    #[test]
    #[should_panic]
    fn add_cut_with_wrong_coefficient_length_panics() {
        let mut pool = CutPool::new(5, 2, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
    }

    #[test]
    #[should_panic]
    fn add_cut_with_overflowing_slot_panics() {
        let mut pool = CutPool::new(4, 1, u32::MAX, 0);
        pool.add_cut(u64::MAX, 0, 1.0, &[1.0]);
    }

    #[test]
    fn populated_count_tracks_high_water_mark() {
        let mut pool = CutPool::new(50, 1, 5, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        assert_eq!(pool.populated_count, 1);
        pool.add_cut(1, 0, 2.0, &[2.0]);
        assert_eq!(pool.populated_count, 6);
        pool.add_cut(0, 2, 3.0, &[3.0]);
        assert_eq!(pool.populated_count, 6);
    }

    #[test]
    fn active_cuts_returns_only_active_cuts() {
        let mut pool = CutPool::new(20, 2, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0, 2.0]);
        pool.add_cut(1, 0, 2.0, &[3.0, 4.0]);
        pool.add_cut(2, 0, 3.0, &[5.0, 6.0]);
        pool.deactivate(&[1]);
        let active: Vec<_> = pool.active_cuts().collect();
        assert_eq!(active.len(), 2);
        let slots: Vec<usize> = active.iter().map(|(s, _, _)| *s).collect();
        assert!(slots.contains(&0));
        assert!(slots.contains(&2));
        assert!(!slots.contains(&1));
    }

    #[test]
    fn active_cuts_empty_pool_returns_empty_iterator() {
        let pool = CutPool::new(10, 3, 5, 0);
        let active: Vec<_> = pool.active_cuts().collect();
        assert!(active.is_empty());
    }

    #[test]
    fn active_count_is_correct_after_add_and_deactivate() {
        let mut pool = CutPool::new(20, 1, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        pool.add_cut(1, 0, 2.0, &[2.0]);
        pool.add_cut(2, 0, 3.0, &[3.0]);
        assert_eq!(pool.active_count(), 3);
        pool.deactivate(&[1]);
        assert_eq!(pool.active_count(), 2);
    }

    #[test]
    fn deactivate_sets_flags_correctly() {
        let mut pool = CutPool::new(20, 1, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        pool.add_cut(1, 0, 2.0, &[2.0]);
        pool.add_cut(2, 0, 3.0, &[3.0]);
        pool.deactivate(&[1]);
        assert!(pool.active[0]);
        assert!(!pool.active[1]);
        assert!(pool.active[2]);
        assert_eq!(pool.active_count(), 2);
    }

    #[test]
    fn deactivate_multiple_indices() {
        let mut pool = CutPool::new(20, 1, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        pool.add_cut(1, 0, 2.0, &[2.0]);
        pool.add_cut(2, 0, 3.0, &[3.0]);
        pool.deactivate(&[0, 2]);
        assert!(!pool.active[0]);
        assert!(pool.active[1]);
        assert!(!pool.active[2]);
        assert_eq!(pool.active_count(), 1);
    }

    #[test]
    fn deactivate_empty_slice_is_noop() {
        let mut pool = CutPool::new(10, 1, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        pool.deactivate(&[]);
        assert_eq!(pool.active_count(), 1);
    }

    #[test]
    fn deactivate_does_not_change_populated_count() {
        let mut pool = CutPool::new(10, 1, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        pool.add_cut(4, 0, 2.0, &[2.0]);
        pool.deactivate(&[4]);
        assert_eq!(pool.populated_count, 5);
        assert_eq!(pool.active_count(), 1);
    }

    #[test]
    fn evaluate_at_state_returns_max_cut_value() {
        let mut pool = CutPool::new(10, 2, 1, 0);
        pool.add_cut(0, 0, 10.0, &[1.0, 0.0]);
        pool.add_cut(1, 0, 5.0, &[0.0, 2.0]);
        let result = pool.evaluate_at_state(&[3.0, 4.0]);
        assert_eq!(result, 13.0);
    }

    #[test]
    fn evaluate_at_state_selects_correct_max() {
        let mut pool = CutPool::new(10, 1, 1, 0);
        pool.add_cut(0, 0, 2.0, &[1.0]);
        pool.add_cut(1, 0, 5.0, &[2.0]);
        let result = pool.evaluate_at_state(&[10.0]);
        assert_eq!(result, 25.0);
    }

    #[test]
    fn evaluate_at_state_empty_pool_returns_neg_infinity() {
        let pool = CutPool::new(10, 3, 5, 0);
        assert_eq!(pool.evaluate_at_state(&[1.0, 2.0, 3.0]), f64::NEG_INFINITY);
    }

    #[test]
    fn evaluate_at_state_all_deactivated_returns_neg_infinity() {
        let mut pool = CutPool::new(10, 1, 1, 0);
        pool.add_cut(0, 0, 100.0, &[1.0]);
        pool.deactivate(&[0]);
        assert_eq!(pool.evaluate_at_state(&[5.0]), f64::NEG_INFINITY);
    }

    #[test]
    fn evaluate_at_state_ignores_deactivated_cuts() {
        let mut pool = CutPool::new(10, 1, 1, 0);
        pool.add_cut(0, 0, 10.0, &[1.0]);
        pool.add_cut(1, 0, 100.0, &[1.0]);
        pool.deactivate(&[1]);
        assert_eq!(pool.evaluate_at_state(&[3.0]), 13.0);
    }

    #[test]
    fn ac_add_cut_stores_at_slot_zero_and_active_count_is_one() {
        let mut pool = CutPool::new(100, 9, 10, 0);
        let coeffs = vec![0.0; 9];
        pool.add_cut(0, 0, 5.0, &coeffs);
        assert!(pool.active[0]);
        assert_eq!(pool.active_count(), 1);
    }

    #[test]
    fn ac_deactivate_reduces_active_count_correctly() {
        let mut pool = CutPool::new(10, 1, 1, 0);
        pool.add_cut(0, 0, 1.0, &[1.0]);
        pool.add_cut(1, 0, 2.0, &[2.0]);
        pool.add_cut(2, 0, 3.0, &[3.0]);
        pool.deactivate(&[1]);
        assert_eq!(pool.active_count(), 2);
        assert!(!pool.active[1]);
    }

    #[test]
    fn ac_evaluate_at_state_returns_correct_max() {
        let mut pool = CutPool::new(10, 2, 1, 0);
        pool.add_cut(0, 0, 10.0, &[1.0, 0.0]);
        pool.add_cut(1, 0, 5.0, &[0.0, 2.0]);
        assert_eq!(pool.evaluate_at_state(&[3.0, 4.0]), 13.0);
    }

    #[test]
    fn ac_warm_start_count_offsets_slot() {
        let mut pool = CutPool::new(100, 9, 10, 5);
        let coeffs = vec![0.0; 9];
        pool.add_cut(0, 0, 1.0, &coeffs);
        assert!(pool.active[5]);
        assert!(!pool.active[0]);
    }

    #[test]
    fn ac_empty_pool_evaluate_returns_neg_infinity() {
        let pool = CutPool::new(10, 2, 1, 0);
        assert_eq!(pool.evaluate_at_state(&[1.0, 2.0]), f64::NEG_INFINITY);
    }

    #[test]
    fn cut_pool_derives_debug_and_clone() {
        let mut pool = CutPool::new(5, 2, 1, 0);
        pool.add_cut(0, 0, 3.0, &[1.0, 2.0]);
        let cloned = pool.clone();
        assert_eq!(cloned.active_count(), 1);
        assert_eq!(cloned.intercepts[0], 3.0);
        let debug_str = format!("{pool:?}");
        assert!(!debug_str.is_empty());
    }
}