use crate::cut::WARM_START_ITERATION;
use crate::cut_selection::CutMetadata;
/// Fixed-capacity, slot-addressed store of cuts: each slot holds an intercept, a
/// coefficient vector of length `state_dimension`, an active flag and a `CutMetadata`.
///
/// Training cuts are placed deterministically by `(iteration, forward_pass_index)` after the
/// first `warm_start_count` slots (see `add_cut`).
#[derive(Debug, Clone)]
pub struct CutPool {
/// Flat coefficient storage; slot `i` occupies `[i * state_dimension, (i + 1) * state_dimension)`.
pub coefficients: Vec<f64>,
/// Intercept of each slot (one entry per slot).
pub intercepts: Vec<f64>,
/// Per-slot metadata; `last_active_iter` and `active_count` are the eviction keys used by
/// `enforce_budget`.
pub metadata: Vec<CutMetadata>,
/// Whether each slot is currently active.
pub active: Vec<bool>,
/// One past the highest slot written so far (high-water mark); only
/// `active[..populated_count]` is ever scanned.
pub populated_count: usize,
/// Number of allocated slots.
pub capacity: usize,
/// Length of every cut's coefficient vector.
pub state_dimension: usize,
/// Number of slots consumed per iteration by the slot formula in `add_cut`.
pub forward_passes: u32,
/// Number of leading slots reserved ahead of training cuts (the slot-formula offset).
pub warm_start_count: u32,
/// Iteration mapped to the first training slot; subtracted in the slot formula.
pub iteration_base: u64,
/// Cached number of `true` entries in `active[..populated_count]`; `active_count`
/// cross-checks it in debug builds.
pub cached_active_count: usize,
/// Running count of generated cuts: initialised by the constructors (e.g. to the warm-start
/// count in `new`) and incremented by every `add_cut`.
pub generated_count: usize,
/// Scratch buffer of candidate slots, reused across `enforce_budget` calls.
pub(crate) candidates_buf: Vec<u32>,
}
impl CutPool {
    /// Creates an empty pool with `capacity` pre-allocated cut slots.
    ///
    /// All storage is zero-initialised and every slot starts inactive. The first
    /// `warm_start_count` slots are reserved: `add_cut` places training cuts after them, and
    /// `generated_count` starts at `warm_start_count`.
    ///
    /// * `capacity` – total number of cut slots.
    /// * `state_dimension` – length of every cut's coefficient vector.
    /// * `forward_passes` – number of slots consumed per iteration by `add_cut`.
    /// * `warm_start_count` – number of leading slots reserved ahead of training cuts.
    #[must_use]
    pub fn new(
        capacity: usize,
        state_dimension: usize,
        forward_passes: u32,
        warm_start_count: u32,
    ) -> Self {
        let default_meta = CutMetadata {
            iteration_generated: 0,
            forward_pass_index: 0,
            active_count: 0,
            last_active_iter: 0,
        };
        Self {
            coefficients: vec![0.0; capacity * state_dimension],
            intercepts: vec![0.0; capacity],
            metadata: vec![default_meta; capacity],
            active: vec![false; capacity],
            populated_count: 0,
            capacity,
            state_dimension,
            forward_passes,
            warm_start_count,
            iteration_base: 0,
            cached_active_count: 0,
            generated_count: warm_start_count as usize,
            candidates_buf: Vec::new(),
        }
    }

    /// Maps `(iteration, forward_pass_index)` to its deterministic slot:
    /// `warm_start_count + (iteration - iteration_base) * forward_passes + forward_pass_index`.
    ///
    /// `iteration` must not be smaller than `iteration_base` (checked in debug builds only).
    #[inline]
    fn slot_index(&self, iteration: u64, forward_pass_index: u32) -> usize {
        debug_assert!(
            iteration >= self.iteration_base,
            "slot_index: iteration {iteration} < iteration_base {}",
            self.iteration_base
        );
        #[allow(clippy::cast_possible_truncation)]
        let iter_usize = (iteration - self.iteration_base) as usize;
        self.warm_start_count as usize
            + iter_usize * self.forward_passes as usize
            + forward_pass_index as usize
    }

    /// Sets the iteration that maps to the first training slot (right after the warm-start
    /// slots), so cuts from later iterations are packed without a leading gap.
    pub fn set_iteration_base(&mut self, iteration_base: u64) {
        self.iteration_base = iteration_base;
    }

    /// Stores a cut in the slot derived from `(iteration, forward_pass_index)` and marks it
    /// active.
    ///
    /// The slot's metadata is reset (`iteration_generated` and `last_active_iter` become
    /// `iteration`, `active_count` becomes 0), `populated_count` is raised to cover the slot,
    /// and `generated_count` is incremented.
    ///
    /// # Panics
    ///
    /// Panics if the computed slot is `>= capacity` (out-of-bounds index) or if
    /// `coefficients.len() != state_dimension` (slice copy length mismatch). Debug builds also
    /// panic when the slot is already active. In release builds a re-insert overwrites the
    /// slot without counting it twice in the active count.
    pub fn add_cut(
        &mut self,
        iteration: u64,
        forward_pass_index: u32,
        intercept: f64,
        coefficients: &[f64],
    ) {
        let slot = self.slot_index(iteration, forward_pass_index);
        debug_assert!(
            slot < self.capacity,
            "cut slot {slot} is out of bounds (capacity = {})",
            self.capacity
        );
        debug_assert!(
            coefficients.len() == self.state_dimension,
            "coefficients length {} != state_dimension {}",
            coefficients.len(),
            self.state_dimension
        );
        debug_assert!(
            !self.active[slot],
            "add_cut: slot {slot} is already active (double-insert)"
        );
        self.intercepts[slot] = intercept;
        let start = slot * self.state_dimension;
        self.coefficients[start..start + self.state_dimension].copy_from_slice(coefficients);
        // Only an inactive -> active transition changes the cached count. Re-inserting into an
        // already-active slot (possible in release builds) must not inflate it, because
        // `active_cuts`, `active_count` and `enforce_budget` all trust this value.
        if !self.active[slot] {
            self.active[slot] = true;
            self.cached_active_count += 1;
        }
        self.metadata[slot] = CutMetadata {
            iteration_generated: iteration,
            forward_pass_index,
            active_count: 0,
            last_active_iter: iteration,
        };
        if slot >= self.populated_count {
            self.populated_count = slot + 1;
        }
        self.generated_count += 1;
    }

    /// Yields `(slot, intercept, coefficients)` for every active slot, in ascending slot order.
    ///
    /// The scan stops as soon as `cached_active_count` active slots have been yielded, so
    /// trailing inactive slots below `populated_count` are not visited.
    pub fn active_cuts(&self) -> impl Iterator<Item = (usize, f64, &[f64])> {
        let mut remaining = self.cached_active_count;
        self.active[..self.populated_count]
            .iter()
            .enumerate()
            .scan((), move |(), (i, &is_active)| {
                if remaining == 0 {
                    return None;
                }
                if is_active {
                    remaining -= 1;
                    let start = i * self.state_dimension;
                    Some(Some((
                        i,
                        self.intercepts[i],
                        &self.coefficients[start..start + self.state_dimension],
                    )))
                } else {
                    Some(None)
                }
            })
            .flatten()
    }

    /// Yields the active cuts whose `iteration_generated` equals `current_iteration`.
    ///
    /// Slots tagged with `WARM_START_ITERATION` are never yielded, even if `current_iteration`
    /// happens to equal that value.
    pub(crate) fn active_delta_cuts(
        &self,
        current_iteration: u64,
    ) -> impl Iterator<Item = (usize, f64, &[f64])> {
        self.active_cuts().filter(move |&(slot, _, _)| {
            let generated = self.metadata[slot].iteration_generated;
            generated == current_iteration && generated != WARM_START_ITERATION
        })
    }

    /// Returns the number of active cuts (the cached count).
    ///
    /// Debug builds verify the cache against a full recount of `active[..populated_count]`.
    #[must_use]
    pub fn active_count(&self) -> usize {
        #[cfg(debug_assertions)]
        {
            let computed = self.active[..self.populated_count]
                .iter()
                .filter(|&&a| a)
                .count();
            assert_eq!(
                self.cached_active_count, computed,
                "cached active_count {} != computed {}",
                self.cached_active_count, computed,
            );
        }
        self.cached_active_count
    }

    /// Returns `populated_count`, i.e. the number of slots written so far, counting inactive
    /// slots as well.
    #[must_use]
    #[inline]
    pub fn cuts_in_lp(&self) -> usize {
        self.populated_count
    }

    /// Deactivates every slot in `indices`. Already-inactive or repeated slots are no-ops.
    pub fn deactivate(&mut self, indices: &[u32]) {
        for &idx in indices {
            self.set_active(idx, false);
        }
    }

    /// Applies a batch of activity changes: first deactivates `updates.updates`, then
    /// reactivates `updates.reactivations`. A slot listed in both therefore ends up active.
    pub fn apply_updates(&mut self, updates: &crate::cut_selection::CutActivityUpdates) {
        for &slot in &updates.updates {
            self.set_active(slot, false);
        }
        for &slot in &updates.reactivations {
            self.set_active(slot, true);
        }
    }

    /// Sets the active flag of `slot`, keeping `cached_active_count` in sync.
    ///
    /// Setting a flag to its current value is a no-op. `slot` is expected to be below
    /// `populated_count` (checked in debug builds only).
    pub fn set_active(&mut self, slot: u32, active: bool) {
        let i = slot as usize;
        debug_assert!(
            i < self.populated_count,
            "set_active slot {i} out of populated range"
        );
        if self.active[i] == active {
            return;
        }
        self.active[i] = active;
        if active {
            self.cached_active_count += 1;
        } else {
            self.cached_active_count -= 1;
        }
    }

    /// Evaluates the maximum over all active cuts of `intercept + coefficients · state`.
    ///
    /// Returns `f64::NEG_INFINITY` when no cut is active.
    #[must_use]
    pub fn evaluate_at_state(&self, state: &[f64]) -> f64 {
        debug_assert!(
            state.len() == self.state_dimension,
            "state length {} != state_dimension {}",
            state.len(),
            self.state_dimension
        );
        self.active_cuts()
            .map(|(_, intercept, coeffs)| {
                let dot: f64 = coeffs.iter().zip(state).map(|(a, b)| a * b).sum();
                intercept + dot
            })
            .fold(f64::NEG_INFINITY, f64::max)
    }

    /// Counts coefficients that compare equal to `0.0` across all active cuts, overall and
    /// per state dimension.
    #[must_use]
    pub fn sparsity_report(&self) -> SparsityReport {
        let active_count = self.active_count();
        let mut exact_zero_count = 0usize;
        let mut per_dimension_zeros = vec![0usize; self.state_dimension];
        for (_slot, _intercept, coeffs) in self.active_cuts() {
            for (j, &c) in coeffs.iter().enumerate() {
                if c == 0.0 {
                    exact_zero_count += 1;
                    per_dimension_zeros[j] += 1;
                }
            }
        }
        let total = active_count * self.state_dimension;
        #[allow(clippy::cast_precision_loss)]
        let fraction = if total > 0 {
            exact_zero_count as f64 / total as f64
        } else {
            0.0
        };
        SparsityReport {
            total_coefficients: total,
            exact_zero_count,
            sparsity_fraction: fraction,
            per_dimension_zeros,
        }
    }

    /// Builds a pool holding exactly `records`, one slot per record in input order.
    ///
    /// Every slot is populated (`populated_count == capacity == records.len()`) and activity
    /// flags are copied from the records. Each slot's `iteration_generated` and
    /// `last_active_iter` are both taken from `record.iteration`. `forward_passes` is 0 and
    /// `warm_start_count` equals `records.len()`, so the slot formula used by `add_cut` points
    /// past the end of this pool.
    /// NOTE(review): presumably meant for read-only use of a deserialized policy; confirm
    /// before adding cuts to it.
    #[must_use]
    pub fn from_deserialized(
        state_dimension: usize,
        records: &[cobre_io::OwnedPolicyCutRecord],
    ) -> Self {
        let capacity = records.len();
        let mut coefficients = Vec::with_capacity(capacity * state_dimension);
        let mut intercepts = Vec::with_capacity(capacity);
        let mut active = Vec::with_capacity(capacity);
        let mut metadata = Vec::with_capacity(capacity);
        let mut cached_active_count = 0usize;
        for record in records {
            debug_assert!(
                record.coefficients.len() == state_dimension,
                "from_deserialized: coefficients length {} != state_dimension {}",
                record.coefficients.len(),
                state_dimension
            );
            coefficients.extend_from_slice(&record.coefficients);
            intercepts.push(record.intercept);
            active.push(record.is_active);
            if record.is_active {
                cached_active_count += 1;
            }
            metadata.push(CutMetadata {
                iteration_generated: u64::from(record.iteration),
                forward_pass_index: record.forward_pass_index,
                active_count: 0,
                last_active_iter: u64::from(record.iteration),
            });
        }
        #[allow(clippy::cast_possible_truncation)]
        Self {
            coefficients,
            intercepts,
            metadata,
            active,
            populated_count: capacity,
            capacity,
            state_dimension,
            forward_passes: 0,
            warm_start_count: capacity as u32,
            iteration_base: 0,
            cached_active_count,
            generated_count: capacity,
            candidates_buf: Vec::new(),
        }
    }

    /// Builds a pool whose leading slots hold the warm-start `records` and which has room for
    /// `max_iterations * forward_passes` further training cuts.
    ///
    /// Warm-start slots copy coefficients, intercept and activity from their record. Their
    /// `iteration_generated` is set to `WARM_START_ITERATION`, so `active_delta_cuts` never
    /// reports them, while `last_active_iter` keeps the record's original iteration.
    #[must_use]
    pub fn new_with_warm_start(
        state_dimension: usize,
        forward_passes: u32,
        max_iterations: u64,
        records: &[cobre_io::OwnedPolicyCutRecord],
    ) -> Self {
        let warm_start_count = records.len();
        #[allow(clippy::cast_possible_truncation)]
        let capacity = warm_start_count + (max_iterations as usize) * (forward_passes as usize);
        let default_meta = CutMetadata {
            iteration_generated: 0,
            forward_pass_index: 0,
            active_count: 0,
            last_active_iter: 0,
        };
        let mut coefficients = vec![0.0_f64; capacity * state_dimension];
        let mut intercepts = vec![0.0; capacity];
        let mut active = vec![false; capacity];
        let mut metadata = vec![default_meta; capacity];
        let mut cached_active_count = 0usize;
        for (i, record) in records.iter().enumerate() {
            debug_assert!(
                record.coefficients.len() == state_dimension,
                "new_with_warm_start: coefficients length {} != state_dimension {}",
                record.coefficients.len(),
                state_dimension
            );
            let start = i * state_dimension;
            coefficients[start..start + state_dimension].copy_from_slice(&record.coefficients);
            intercepts[i] = record.intercept;
            active[i] = record.is_active;
            if record.is_active {
                cached_active_count += 1;
            }
            metadata[i] = CutMetadata {
                iteration_generated: WARM_START_ITERATION,
                forward_pass_index: record.forward_pass_index,
                active_count: 0,
                last_active_iter: u64::from(record.iteration),
            };
        }
        #[allow(clippy::cast_possible_truncation)]
        Self {
            coefficients,
            intercepts,
            metadata,
            active,
            populated_count: warm_start_count,
            capacity,
            state_dimension,
            forward_passes,
            warm_start_count: warm_start_count as u32,
            iteration_base: 0,
            cached_active_count,
            generated_count: warm_start_count,
            candidates_buf: Vec::new(),
        }
    }
}
/// Zero-coefficient statistics over the active cuts of a `CutPool`, as produced by
/// `CutPool::sparsity_report`.
#[derive(Debug, Clone)]
pub struct SparsityReport {
/// Number of coefficients inspected: active cut count times `state_dimension`.
pub total_coefficients: usize,
/// How many of those coefficients compare equal to `0.0`.
pub exact_zero_count: usize,
/// `exact_zero_count / total_coefficients`, or `0.0` when no coefficients were inspected.
pub sparsity_fraction: f64,
/// For each state dimension `j`, the number of active cuts whose `j`-th coefficient is `0.0`.
pub per_dimension_zeros: Vec<usize>,
}
/// Outcome of a `CutPool::enforce_budget` call.
#[derive(Debug, Clone)]
pub struct BudgetEnforcementResult {
/// Number of cuts deactivated by the call.
pub evicted_count: u32,
/// Active cut count when the call started.
pub active_before: u32,
/// Active cut count when the call returned.
pub active_after: u32,
}
impl CutPool {
    /// Deactivates the lowest-ranked cuts so that at most `budget` cuts remain active.
    ///
    /// Cuts whose `iteration_generated` equals `current_iteration` are never evicted, so the
    /// pool may stay above budget when there are not enough older candidates. Candidates are
    /// ranked by `(last_active_iter, active_count)` ascending. The slot index is a final
    /// tie-breaker, so the evicted set is fully determined by the pool state and does not
    /// depend on which selection strategy runs.
    ///
    /// `_forward_passes` is not used.
    ///
    /// Returns the number of evicted cuts together with the active counts before and after.
    /// The internal candidate buffer is reused across calls without reallocating.
    pub fn enforce_budget(
        &mut self,
        budget: u32,
        current_iteration: u64,
        _forward_passes: u32,
    ) -> BudgetEnforcementResult {
        #[allow(clippy::cast_possible_truncation)]
        let active_before = self.active_count() as u32;
        let unchanged = BudgetEnforcementResult {
            evicted_count: 0,
            active_before,
            active_after: active_before,
        };
        let budget_usize = budget as usize;
        if self.cached_active_count <= budget_usize {
            return unchanged;
        }
        let excess = self.cached_active_count - budget_usize;

        // Move the scratch buffer out of `self` so it can be filled and ordered while other
        // fields are borrowed. It is handed back (capacity intact) before returning, which
        // avoids the per-call copy of the evicted prefix.
        let mut candidates = std::mem::take(&mut self.candidates_buf);
        candidates.clear();
        let metadata = &self.metadata;
        #[allow(clippy::cast_possible_truncation)]
        candidates.extend(
            self.active[..self.populated_count]
                .iter()
                .enumerate()
                .filter(|&(slot, &is_active)| {
                    is_active && metadata[slot].iteration_generated != current_iteration
                })
                .map(|(slot, _)| slot as u32),
        );

        let evict_count = excess.min(candidates.len());
        if evict_count == 0 {
            self.candidates_buf = candidates;
            return unchanged;
        }

        let key = |&slot: &u32| {
            let meta = &metadata[slot as usize];
            (meta.last_active_iter, meta.active_count, slot)
        };
        // Partial selection is cheaper when only a small prefix is needed. Either way the
        // first `evict_count` entries end up being the lowest-ranked candidates.
        if evict_count < candidates.len() / 2 {
            candidates.select_nth_unstable_by_key(evict_count, key);
        } else {
            candidates.sort_unstable_by_key(key);
        }

        self.deactivate(&candidates[..evict_count]);
        self.candidates_buf = candidates;

        #[allow(clippy::cast_possible_truncation)]
        let evicted_count = evict_count as u32;
        #[allow(clippy::cast_possible_truncation)]
        let active_after = self.active_count() as u32;
        BudgetEnforcementResult {
            evicted_count,
            active_before,
            active_after,
        }
    }
}
// Unit tests for `CutPool`: construction, deterministic slot addressing, activation
// bookkeeping, evaluation, sparsity reporting, warm start, budget enforcement and batched
// activity updates.
#[cfg(test)]
mod tests {
use super::CutPool;
// `new` sizes every buffer from `capacity` / `state_dimension` and starts fully inactive and zeroed.
#[test]
fn new_creates_pool_with_correct_capacity_and_all_inactive() {
let pool = CutPool::new(100, 9, 10, 0);
assert_eq!(pool.capacity, 100);
assert_eq!(pool.state_dimension, 9);
assert_eq!(pool.forward_passes, 10);
assert_eq!(pool.warm_start_count, 0);
assert_eq!(pool.populated_count, 0);
assert_eq!(pool.active_count(), 0);
assert!(pool.active.iter().all(|&a| !a));
assert_eq!(pool.coefficients.len(), 100 * 9);
assert!(pool.coefficients.iter().all(|&v| v == 0.0));
assert!(pool.intercepts.iter().all(|&v| v == 0.0));
}
// A zero-capacity pool is constructible and reports no active cuts.
#[test]
fn new_zero_capacity_is_valid() {
let pool = CutPool::new(0, 4, 5, 0);
assert_eq!(pool.capacity, 0);
assert_eq!(pool.active_count(), 0);
}
// Without warm start, (iteration 0, forward pass 0) maps to slot 0; data and metadata land there.
#[test]
fn add_cut_at_slot_zero_stores_intercept_coefficients_and_active_flag() {
let mut pool = CutPool::new(100, 9, 10, 0);
let coeffs = vec![1.0; 9];
pool.add_cut(0, 0, 5.0, &coeffs);
assert_eq!(pool.active_count(), 1);
assert!(pool.active[0]);
assert_eq!(pool.intercepts[0], 5.0);
assert_eq!(&pool.coefficients[0..9], vec![1.0; 9].as_slice());
assert_eq!(pool.metadata[0].iteration_generated, 0);
assert_eq!(pool.metadata[0].forward_pass_index, 0);
assert_eq!(pool.populated_count, 1);
}
// With 10 forward passes and no warm start, slot = iteration * 10 + forward_pass_index:
// (0,0) -> 0, (0,3) -> 3, (1,0) -> 10, (2,5) -> 25.
#[test]
fn add_cut_deterministic_slot_formula_no_warmstart() {
let mut pool = CutPool::new(200, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0]); pool.add_cut(0, 3, 2.0, &[3.0, 4.0]); pool.add_cut(1, 0, 3.0, &[5.0, 6.0]); pool.add_cut(2, 5, 4.0, &[7.0, 8.0]);
assert!(pool.active[0]);
assert_eq!(pool.intercepts[0], 1.0);
assert!(pool.active[3]);
assert_eq!(pool.intercepts[3], 2.0);
assert!(pool.active[10]);
assert_eq!(pool.intercepts[10], 3.0);
assert!(pool.active[25]);
assert_eq!(pool.intercepts[25], 4.0);
}
// Training slots are shifted by warm_start_count (5 here), so (0,0) lands in slot 5.
#[test]
fn add_cut_warm_start_count_offsets_slot() {
let mut pool = CutPool::new(100, 9, 10, 5);
let coeffs = vec![0.0; 9];
pool.add_cut(0, 0, 42.0, &coeffs);
assert!(pool.active[5]);
assert_eq!(pool.intercepts[5], 42.0);
assert_eq!(pool.populated_count, 6);
}
// (iteration 3, forward pass 2) with 5 forward passes lands in slot 3 * 5 + 2 = 17; its
// metadata records the iteration, pass index and a zero activity counter.
#[test]
fn add_cut_metadata_initialized_correctly() {
let mut pool = CutPool::new(50, 3, 5, 0);
pool.add_cut(3, 2, 7.0, &[1.0, 2.0, 3.0]);
let meta = &pool.metadata[17];
assert_eq!(meta.iteration_generated, 3);
assert_eq!(meta.forward_pass_index, 2);
assert_eq!(meta.active_count, 0);
assert_eq!(meta.last_active_iter, 3);
}
// populated_count is a high-water mark: later filling a lower slot (2) does not shrink it.
#[test]
fn populated_count_tracks_high_water_mark() {
let mut pool = CutPool::new(50, 1, 5, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); assert_eq!(pool.populated_count, 1);
pool.add_cut(1, 0, 2.0, &[2.0]); assert_eq!(pool.populated_count, 6);
pool.add_cut(0, 2, 3.0, &[3.0]); assert_eq!(pool.populated_count, 6);
}
// With iteration_base = 1, iteration 1 maps to slot 0, so cuts from iterations 1 and 2
// occupy slots 0..=3 with no unused prefix.
#[test]
fn set_iteration_base_packs_training_cuts_densely() {
let mut pool = CutPool::new(30, 1, 3, 0);
pool.set_iteration_base(1);
pool.add_cut(1, 0, 1.0, &[1.0]); pool.add_cut(1, 1, 2.0, &[1.0]); pool.add_cut(1, 2, 3.0, &[1.0]); pool.add_cut(2, 0, 4.0, &[1.0]); assert!(pool.active[0] && pool.active[1] && pool.active[2] && pool.active[3]);
assert_eq!(pool.populated_count, 4, "dense packing leaves no gap");
assert_eq!(pool.generated_count, 4);
assert_eq!(pool.metadata[0].iteration_generated, 1);
assert_eq!(pool.metadata[3].iteration_generated, 2);
}
// active_cuts skips deactivated slots.
#[test]
fn active_cuts_returns_only_active_cuts() {
let mut pool = CutPool::new(20, 2, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0]); pool.add_cut(1, 0, 2.0, &[3.0, 4.0]); pool.add_cut(2, 0, 3.0, &[5.0, 6.0]);
pool.deactivate(&[1]);
let active: Vec<_> = pool.active_cuts().collect();
assert_eq!(active.len(), 2);
let slots: Vec<usize> = active.iter().map(|(s, _, _)| *s).collect();
assert!(slots.contains(&0));
assert!(slots.contains(&2));
assert!(!slots.contains(&1));
}
// A fresh pool yields no active cuts.
#[test]
fn active_cuts_empty_pool_returns_empty_iterator() {
let pool = CutPool::new(10, 3, 5, 0);
let active: Vec<_> = pool.active_cuts().collect();
assert!(active.is_empty());
}
// active_count tracks both insertions and deactivations.
#[test]
fn active_count_is_correct_after_add_and_deactivate() {
let mut pool = CutPool::new(20, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); pool.add_cut(1, 0, 2.0, &[2.0]); pool.add_cut(2, 0, 3.0, &[3.0]);
assert_eq!(pool.active_count(), 3);
pool.deactivate(&[1]);
assert_eq!(pool.active_count(), 2);
}
// deactivate clears only the listed slot's flag.
#[test]
fn deactivate_sets_flags_correctly() {
let mut pool = CutPool::new(20, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); pool.add_cut(1, 0, 2.0, &[2.0]); pool.add_cut(2, 0, 3.0, &[3.0]);
pool.deactivate(&[1]);
assert!(pool.active[0]);
assert!(!pool.active[1]);
assert!(pool.active[2]);
assert_eq!(pool.active_count(), 2);
}
// deactivate handles several indices in one call.
#[test]
fn deactivate_multiple_indices() {
let mut pool = CutPool::new(20, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); pool.add_cut(1, 0, 2.0, &[2.0]); pool.add_cut(2, 0, 3.0, &[3.0]);
pool.deactivate(&[0, 2]);
assert!(!pool.active[0]);
assert!(pool.active[1]);
assert!(!pool.active[2]);
assert_eq!(pool.active_count(), 1);
}
// An empty index list changes nothing.
#[test]
fn deactivate_empty_slice_is_noop() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.deactivate(&[]);
assert_eq!(pool.active_count(), 1);
}
// Repeating an index in deactivate decrements the active count only once.
// NOTE(review): compiled only without debug assertions; no visible debug_assert in
// `set_active` fires on a repeated index, so confirm whether this cfg gate is still needed.
#[test]
#[cfg(not(debug_assertions))]
fn deactivate_duplicate_index_is_silently_skipped() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
assert_eq!(pool.active_count(), 2);
pool.deactivate(&[0, 0, 0]);
assert_eq!(
pool.active_count(),
1,
"duplicate indices must not double-decrement"
);
assert!(!pool.active[0]);
assert!(pool.active[1]);
}
// Max over cuts of intercept + coeffs . state: max(10 + 1*3, 5 + 2*4) = 13.
#[test]
fn evaluate_at_state_returns_max_cut_value() {
let mut pool = CutPool::new(10, 2, 1, 0);
pool.add_cut(0, 0, 10.0, &[1.0, 0.0]);
pool.add_cut(1, 0, 5.0, &[0.0, 2.0]);
let result = pool.evaluate_at_state(&[3.0, 4.0]);
assert_eq!(result, 13.0);
}
// At state 10: max(2 + 1*10, 5 + 2*10) = 25, i.e. the second cut dominates.
#[test]
fn evaluate_at_state_selects_correct_max() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 2.0, &[1.0]);
pool.add_cut(1, 0, 5.0, &[2.0]);
let result = pool.evaluate_at_state(&[10.0]);
assert_eq!(result, 25.0);
}
// With no cuts the fold's initial value, NEG_INFINITY, is returned.
#[test]
fn evaluate_at_state_empty_pool_returns_neg_infinity() {
let pool = CutPool::new(10, 3, 5, 0);
assert_eq!(pool.evaluate_at_state(&[1.0, 2.0, 3.0]), f64::NEG_INFINITY);
}
// Deactivating every cut also yields NEG_INFINITY.
#[test]
fn evaluate_at_state_all_deactivated_returns_neg_infinity() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 100.0, &[1.0]);
pool.deactivate(&[0]);
assert_eq!(pool.evaluate_at_state(&[5.0]), f64::NEG_INFINITY);
}
// The deactivated cut (intercept 100) would dominate but is ignored: 10 + 3 = 13.
#[test]
fn evaluate_at_state_ignores_deactivated_cuts() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 10.0, &[1.0]);
pool.add_cut(1, 0, 100.0, &[1.0]);
pool.deactivate(&[1]);
assert_eq!(pool.evaluate_at_state(&[3.0]), 13.0);
}
// First cut goes to slot 0 and the active count becomes 1.
#[test]
fn ac_add_cut_stores_at_slot_zero_and_active_count_is_one() {
let mut pool = CutPool::new(100, 9, 10, 0);
let coeffs = vec![0.0; 9];
pool.add_cut(0, 0, 5.0, &coeffs);
assert!(pool.active[0]);
assert_eq!(pool.active_count(), 1);
}
// Deactivating one of three cuts leaves two active.
#[test]
fn ac_deactivate_reduces_active_count_correctly() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
pool.add_cut(2, 0, 3.0, &[3.0]);
pool.deactivate(&[1]);
assert_eq!(pool.active_count(), 2);
assert!(!pool.active[1]);
}
// Same maximum evaluation as above: max(13, 13) = 13.
#[test]
fn ac_evaluate_at_state_returns_correct_max() {
let mut pool = CutPool::new(10, 2, 1, 0);
pool.add_cut(0, 0, 10.0, &[1.0, 0.0]);
pool.add_cut(1, 0, 5.0, &[0.0, 2.0]);
assert_eq!(pool.evaluate_at_state(&[3.0, 4.0]), 13.0);
}
// With 5 reserved warm-start slots, the first training cut lands in slot 5, not slot 0.
#[test]
fn ac_warm_start_count_offsets_slot() {
let mut pool = CutPool::new(100, 9, 10, 5);
let coeffs = vec![0.0; 9];
pool.add_cut(0, 0, 1.0, &coeffs);
assert!(pool.active[5]);
assert!(!pool.active[0]);
}
// Empty pool evaluates to NEG_INFINITY.
#[test]
fn ac_empty_pool_evaluate_returns_neg_infinity() {
let pool = CutPool::new(10, 2, 1, 0);
assert_eq!(pool.evaluate_at_state(&[1.0, 2.0]), f64::NEG_INFINITY);
}
// Clone copies contents and the active count; Debug formatting is non-empty.
#[test]
fn cut_pool_derives_debug_and_clone() {
let mut pool = CutPool::new(5, 2, 1, 0);
pool.add_cut(0, 0, 3.0, &[1.0, 2.0]);
let cloned = pool.clone();
assert_eq!(cloned.active_count(), 1);
assert_eq!(cloned.intercepts[0], 3.0);
let debug_str = format!("{pool:?}");
assert!(!debug_str.is_empty());
}
// No active cuts: zero totals, fraction 0.0, all per-dimension counters zero.
#[test]
fn sparsity_report_empty_pool() {
let pool = CutPool::new(10, 3, 1, 0);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 0);
assert_eq!(report.exact_zero_count, 0);
assert!((report.sparsity_fraction - 0.0).abs() < f64::EPSILON);
assert_eq!(report.per_dimension_zeros, vec![0, 0, 0]);
}
// Two dense cuts: 6 coefficients, none zero.
#[test]
fn sparsity_report_all_nonzero() {
let mut pool = CutPool::new(10, 3, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0, 3.0]);
pool.add_cut(1, 0, 2.0, &[4.0, 5.0, 6.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 6);
assert_eq!(report.exact_zero_count, 0);
assert!((report.sparsity_fraction - 0.0).abs() < f64::EPSILON);
assert_eq!(report.per_dimension_zeros, vec![0, 0, 0]);
}
// Two all-zero cuts: fraction 1.0 and 2 zeros in every dimension.
#[test]
fn sparsity_report_all_zero() {
let mut pool = CutPool::new(10, 3, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0, 0.0, 0.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 0.0, 0.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 6);
assert_eq!(report.exact_zero_count, 6);
assert!((report.sparsity_fraction - 1.0).abs() < f64::EPSILON);
assert_eq!(report.per_dimension_zeros, vec![2, 2, 2]);
}
// 3 zeros out of 6 coefficients -> fraction 0.5; per dimension [1, 2, 0].
#[test]
fn sparsity_report_mixed() {
let mut pool = CutPool::new(10, 3, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0, 2.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 0.0, 3.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 6);
assert_eq!(report.exact_zero_count, 3);
assert!((report.sparsity_fraction - 0.5).abs() < 1e-10);
assert_eq!(report.per_dimension_zeros, vec![1, 2, 0]);
}
// The all-zero cut is deactivated, so only the dense cut is counted.
#[test]
fn sparsity_report_excludes_inactive_cuts() {
let mut pool = CutPool::new(10, 2, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0, 0.0]); pool.add_cut(1, 0, 2.0, &[1.0, 2.0]); pool.deactivate(&[0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 2);
assert_eq!(report.exact_zero_count, 0);
assert!((report.sparsity_fraction - 0.0).abs() < f64::EPSILON);
}
// Per-dimension counters over three cuts: zeros at dims {0,2}, {0,3} and none -> [2, 0, 1, 1].
#[test]
fn sparsity_report_per_dimension_zeros_correct() {
let mut pool = CutPool::new(10, 4, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0, 1.0, 0.0, 3.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 2.0, 4.0, 0.0]);
pool.add_cut(2, 0, 3.0, &[5.0, 6.0, 7.0, 8.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 12);
assert_eq!(report.exact_zero_count, 4);
assert_eq!(report.per_dimension_zeros, vec![2, 0, 1, 1]);
}
// new_with_warm_start tags warm-start slots with WARM_START_ITERATION and keeps each
// record's original iteration in last_active_iter.
#[test]
fn warm_start_cuts_have_sentinel_iteration() {
use crate::cut::WARM_START_ITERATION;
use cobre_io::OwnedPolicyCutRecord;
let records = vec![
OwnedPolicyCutRecord {
cut_id: 0,
slot_index: 0,
coefficients: vec![1.0, 2.0],
intercept: 10.0,
is_active: true,
iteration: 5,
forward_pass_index: 0,
},
OwnedPolicyCutRecord {
cut_id: 1,
slot_index: 1,
coefficients: vec![3.0, 4.0],
intercept: 20.0,
is_active: true,
iteration: 7,
forward_pass_index: 1,
},
];
let pool = CutPool::new_with_warm_start(2, 4, 100, &records);
assert_eq!(pool.warm_start_count, 2);
assert_eq!(pool.populated_count, 2);
assert_eq!(pool.metadata[0].iteration_generated, WARM_START_ITERATION);
assert_eq!(pool.metadata[1].iteration_generated, WARM_START_ITERATION);
assert_eq!(pool.metadata[0].last_active_iter, 5);
assert_eq!(pool.metadata[1].last_active_iter, 7);
}
// A pool built from a non-empty record list reports a positive warm_start_count.
#[test]
fn terminal_has_boundary_cuts_when_warm_start_count_positive() {
use cobre_io::OwnedPolicyCutRecord;
let records = vec![OwnedPolicyCutRecord {
cut_id: 0,
slot_index: 0,
coefficients: vec![1.0],
intercept: 5.0,
is_active: true,
iteration: 0,
forward_pass_index: 0,
}];
let pool = CutPool::new_with_warm_start(1, 4, 100, &records);
assert!(pool.warm_start_count > 0, "terminal pool has boundary cuts");
}
// A pool created with warm_start_count = 0 keeps it at 0.
#[test]
fn no_boundary_cuts_when_warm_start_count_zero() {
let pool = CutPool::new(100, 2, 10, 0);
assert_eq!(pool.warm_start_count, 0, "no boundary cuts");
}
// Under budget (2 active <= 5): nothing is evicted.
#[test]
fn enforce_budget_noop_when_under_budget() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0]);
pool.add_cut(0, 1, 2.0, &[3.0, 4.0]);
assert_eq!(pool.active_count(), 2);
let result = pool.enforce_budget(5, 1, 10);
assert_eq!(result.evicted_count, 0);
assert_eq!(result.active_after, 2);
assert_eq!(pool.active_count(), 2);
}
// Cuts in slots 0, 10, 20, 30, 40 have last_active_iter 0..=4; with budget 3 the two
// least recently active (slots 0 and 10) are evicted.
#[test]
fn enforce_budget_evicts_oldest_last_active_iter() {
let mut pool = CutPool::new(100, 2, 10, 0);
for iter in 0..5_u64 {
pool.add_cut(iter, 0, 1.0, &[1.0, 0.0]);
pool.metadata[pool.populated_count - 1].last_active_iter = iter;
}
assert_eq!(pool.active_count(), 5);
let result = pool.enforce_budget(3, 5, 10);
assert_eq!(result.evicted_count, 2);
assert_eq!(result.active_after, 3);
assert_eq!(pool.active_count(), 3);
assert!(!pool.active[0], "oldest cut should be evicted");
assert!(!pool.active[10], "second oldest should be evicted");
}
// Equal last_active_iter: the cut with the smaller active_count (slot 1) is evicted.
#[test]
fn enforce_budget_tiebreaks_by_active_count() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0]);
pool.metadata[0].last_active_iter = 1;
pool.metadata[0].active_count = 5;
pool.add_cut(0, 1, 2.0, &[0.0, 1.0]);
pool.metadata[1].last_active_iter = 1;
pool.metadata[1].active_count = 2;
assert_eq!(pool.active_count(), 2);
let result = pool.enforce_budget(1, 1, 10);
assert_eq!(result.evicted_count, 1);
assert!(pool.active[0], "higher active_count survives");
assert!(!pool.active[1], "lower active_count evicted");
}
// The cut generated in the current iteration (slot 10, iteration 1) is never a candidate;
// both older cuts are evicted instead.
#[test]
fn enforce_budget_protects_current_iteration() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0]);
pool.metadata[0].last_active_iter = 0;
pool.add_cut(0, 1, 2.0, &[0.0, 1.0]);
pool.metadata[1].last_active_iter = 0;
pool.add_cut(1, 0, 3.0, &[1.0, 1.0]);
pool.metadata[10].last_active_iter = 1;
assert_eq!(pool.active_count(), 3);
let result = pool.enforce_budget(1, 1, 10);
assert_eq!(result.evicted_count, 2);
assert!(pool.active[10], "current iteration cut preserved");
}
// Every active cut belongs to the current iteration, so the pool stays over budget.
#[test]
fn enforce_budget_all_current_iteration_no_eviction() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(5, 0, 1.0, &[1.0, 0.0]);
pool.add_cut(5, 1, 2.0, &[0.0, 1.0]);
pool.add_cut(5, 2, 3.0, &[1.0, 1.0]);
assert_eq!(pool.active_count(), 3);
let result = pool.enforce_budget(1, 5, 10);
assert_eq!(result.evicted_count, 0);
assert_eq!(pool.active_count(), 3);
}
// Result fields report before/after counts and the number evicted (3 -> 1).
#[test]
fn enforce_budget_result_fields() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 1.0]);
pool.add_cut(2, 0, 3.0, &[1.0, 1.0]);
assert_eq!(pool.active_count(), 3);
let result = pool.enforce_budget(1, 3, 10);
assert_eq!(result.active_before, 3);
assert_eq!(result.evicted_count, 2);
assert_eq!(result.active_after, 1);
}
// populated_count is forced to 100 with a single active cut; active_cuts still yields
// exactly that cut (its scan is bounded by cached_active_count).
#[test]
fn active_cuts_early_exit_stops_at_cached_count() {
let mut pool = CutPool::new(100, 2, 1, 0);
pool.add_cut(0, 0, 5.0, &[1.0, 2.0]);
pool.populated_count = 100;
assert_eq!(pool.cached_active_count, 1);
assert_eq!(pool.populated_count, 100);
let result: Vec<_> = pool.active_cuts().collect();
assert_eq!(result.len(), 1);
assert_eq!(result[0].0, 0, "yielded slot must be 0");
assert_eq!(result[0].1, 5.0, "intercept must match");
assert_eq!(result[0].2, &[1.0, 2.0], "coefficients must match");
}
// enforce_budget leaves capacity in candidates_buf, and that capacity does not shrink
// across repeated calls.
#[test]
fn enforce_budget_candidates_buf_is_reused() {
let mut pool = CutPool::new(100, 2, 10, 0);
for iter in 0..5_u64 {
pool.add_cut(iter, 0, 1.0, &[1.0, 0.0]);
}
assert_eq!(pool.active_count(), 5);
pool.enforce_budget(3, 5, 10);
let cap_after_first = pool.candidates_buf.capacity();
assert!(
cap_after_first >= 1,
"candidates_buf must have acquired capacity after first enforce_budget"
);
let mut pool2 = CutPool::new(100, 2, 10, 0);
for iter in 0..5_u64 {
pool2.add_cut(iter, 0, 1.0, &[1.0, 0.0]);
}
pool2.enforce_budget(3, 5, 10);
let cap_after_first2 = pool2.candidates_buf.capacity();
pool2.add_cut(6, 0, 2.0, &[0.0, 1.0]);
pool2.add_cut(7, 0, 2.0, &[0.0, 1.0]);
pool2.add_cut(8, 0, 2.0, &[0.0, 1.0]);
pool2.enforce_budget(2, 9, 10);
let cap_after_second2 = pool2.candidates_buf.capacity();
assert!(
cap_after_second2 >= cap_after_first2,
"candidates_buf capacity must not shrink across calls (was {cap_after_first2}, now {cap_after_second2})"
);
}
// set_active(false) decrements the active count but not cuts_in_lp (populated_count).
#[test]
fn set_active_false_decrements_active_count() {
let mut pool = CutPool::new(10, 4, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0, 0.0, 0.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 1.0, 0.0, 0.0]);
pool.add_cut(2, 0, 3.0, &[0.0, 0.0, 1.0, 0.0]);
pool.set_active(1, false);
assert!(!pool.active[1]);
assert_eq!(pool.active_count(), 2);
assert_eq!(pool.cuts_in_lp(), 3);
}
// set_active(true) restores a previously deactivated slot.
#[test]
fn set_active_true_reactivates_deactivated_slot() {
let mut pool = CutPool::new(10, 4, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0, 0.0, 0.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 1.0, 0.0, 0.0]);
pool.add_cut(2, 0, 3.0, &[0.0, 0.0, 1.0, 0.0]);
pool.deactivate(&[1]);
pool.set_active(1, true);
assert!(pool.active[1]);
assert_eq!(pool.active_count(), 3);
assert_eq!(pool.cuts_in_lp(), 3);
}
// Re-activating an already-active slot does not change the count.
#[test]
fn set_active_idempotent_when_state_unchanged() {
let mut pool = CutPool::new(10, 4, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0, 0.0, 0.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 1.0, 0.0, 0.0]);
pool.add_cut(2, 0, 3.0, &[0.0, 0.0, 1.0, 0.0]);
pool.set_active(1, true);
pool.set_active(1, true);
assert_eq!(pool.active_count(), 3);
}
// deactivate(&[1, 2]) yields the same flags and count as two set_active(_, false) calls.
#[test]
fn deactivate_delegates_to_set_active() {
let mut pool_a = CutPool::new(10, 4, 1, 0);
pool_a.add_cut(0, 0, 1.0, &[1.0, 0.0, 0.0, 0.0]);
pool_a.add_cut(1, 0, 2.0, &[0.0, 1.0, 0.0, 0.0]);
pool_a.add_cut(2, 0, 3.0, &[0.0, 0.0, 1.0, 0.0]);
pool_a.deactivate(&[1, 2]);
let mut pool_b = CutPool::new(10, 4, 1, 0);
pool_b.add_cut(0, 0, 1.0, &[1.0, 0.0, 0.0, 0.0]);
pool_b.add_cut(1, 0, 2.0, &[0.0, 1.0, 0.0, 0.0]);
pool_b.add_cut(2, 0, 3.0, &[0.0, 0.0, 1.0, 0.0]);
pool_b.set_active(1, false);
pool_b.set_active(2, false);
assert_eq!(pool_a.active, pool_b.active);
assert_eq!(pool_a.cached_active_count, pool_b.cached_active_count);
}
// cuts_in_lp equals populated_count and is unaffected by deactivation.
#[test]
fn cuts_in_lp_returns_populated_count() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
assert_eq!(pool.cuts_in_lp(), 2);
assert_eq!(pool.cuts_in_lp(), pool.populated_count);
pool.deactivate(&[0]);
assert_eq!(pool.cuts_in_lp(), 2);
assert_eq!(pool.active_count(), 1);
}
// apply_updates deactivates `updates` (0, 3) and reactivates `reactivations` (2) in one call.
#[test]
fn apply_updates_applies_mixed_deactivate_and_reactivate() {
use crate::cut_selection::CutActivityUpdates;
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); pool.add_cut(1, 0, 2.0, &[2.0]); pool.add_cut(2, 0, 3.0, &[3.0]); pool.add_cut(3, 0, 4.0, &[4.0]);
pool.deactivate(&[2]);
assert_eq!(pool.active_count(), 3);
assert!(!pool.active[2]);
let updates = CutActivityUpdates {
stage_index: 0,
updates: vec![0, 3], reactivations: vec![2], };
pool.apply_updates(&updates);
assert!(!pool.active[0], "slot 0 must be deactivated");
assert!(pool.active[1], "slot 1 must remain active");
assert!(pool.active[2], "slot 2 must be reactivated");
assert!(!pool.active[3], "slot 3 must be deactivated");
assert_eq!(pool.active_count(), 2);
assert_eq!(pool.cuts_in_lp(), 4);
}
// Applying the same update batch twice leaves flags and count unchanged after the first time.
#[test]
fn apply_updates_is_idempotent() {
use crate::cut_selection::CutActivityUpdates;
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
pool.add_cut(2, 0, 3.0, &[3.0]);
pool.deactivate(&[1]);
assert_eq!(pool.active_count(), 2);
let updates = CutActivityUpdates {
stage_index: 0,
updates: vec![0],
reactivations: vec![1],
};
pool.apply_updates(&updates);
let active_snapshot = pool.active.clone();
let count_snapshot = pool.active_count();
pool.apply_updates(&updates);
assert_eq!(
pool.active, active_snapshot,
"active bitmap must not change"
);
assert_eq!(
pool.active_count(),
count_snapshot,
"active count must not change"
);
assert_eq!(pool.active_count(), 2);
assert!(!pool.active[0]);
assert!(pool.active[1]);
assert!(pool.active[2]);
}
// An empty update batch changes nothing.
#[test]
fn apply_updates_empty_is_noop() {
use crate::cut_selection::CutActivityUpdates;
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
let active_before = pool.active.clone();
let count_before = pool.active_count();
let updates = CutActivityUpdates {
stage_index: 0,
updates: vec![],
reactivations: vec![],
};
pool.apply_updates(&updates);
assert_eq!(pool.active, active_before);
assert_eq!(pool.active_count(), count_before);
}
// apply_updates matches a manual loop: set_active(false) over `updates`, then
// set_active(true) over `reactivations`.
#[test]
fn apply_updates_matches_manual_set_active_loop() {
use crate::cut_selection::CutActivityUpdates;
let build_pool = || {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
pool.add_cut(2, 0, 3.0, &[3.0]);
pool.add_cut(3, 0, 4.0, &[4.0]);
pool.deactivate(&[2]); pool
};
let updates = CutActivityUpdates {
stage_index: 0,
updates: vec![0, 3],
reactivations: vec![2],
};
let mut pool_a = build_pool();
pool_a.apply_updates(&updates);
let mut pool_b = build_pool();
for &slot in &updates.updates {
pool_b.set_active(slot, false);
}
for &slot in &updates.reactivations {
pool_b.set_active(slot, true);
}
assert_eq!(pool_a.active, pool_b.active);
assert_eq!(pool_a.cached_active_count, pool_b.cached_active_count);
}
}