use crate::cut::WARM_START_ITERATION;
use crate::cut_selection::CutMetadata;
/// Preallocated, slot-indexed storage for cuts.
///
/// Every slot owns one intercept, one run of `state_dimension` coefficients,
/// one [`CutMetadata`] record and one activity flag. All per-slot vectors are
/// allocated up front by the constructors and indexed by slot number.
#[derive(Debug, Clone)]
pub struct CutPool {
/// Flat coefficient storage: slot `i` occupies
/// `coefficients[i * state_dimension..(i + 1) * state_dimension]`.
pub coefficients: Vec<f64>,
/// One intercept per slot.
pub intercepts: Vec<f64>,
/// One bookkeeping record per slot.
pub metadata: Vec<CutMetadata>,
/// Per-slot flag; `true` while the cut stored in that slot is active.
pub active: Vec<bool>,
/// One past the highest slot written so far; scans stop at this index.
pub populated_count: usize,
/// Number of slots the pool was created with.
pub capacity: usize,
/// Number of coefficients stored for each cut.
pub state_dimension: usize,
/// Number of slots consumed per iteration when computing a slot index.
pub forward_passes: u32,
/// Offset added to every computed slot index; the warm-start constructors
/// set it to the number of records they loaded.
pub warm_start_count: u32,
/// Running count of `true` flags in `active[..populated_count]`, kept in
/// step by `add_cut` and `deactivate`.
pub cached_active_count: usize,
/// Scratch list of candidate slot indices reused by `enforce_budget`.
pub(crate) candidates_buf: Vec<u32>,
}
impl CutPool {
/// Creates an empty pool with `capacity` zero-filled, inactive slots.
///
/// `forward_passes` and `warm_start_count` are stored unchanged and later
/// used to compute the slot of each added cut.
#[must_use]
pub fn new(
    capacity: usize,
    state_dimension: usize,
    forward_passes: u32,
    warm_start_count: u32,
) -> Self {
    // Placeholder record replicated into every slot until a cut is written.
    let blank_metadata = CutMetadata {
        iteration_generated: 0,
        forward_pass_index: 0,
        active_count: 0,
        last_active_iter: 0,
        active_window: 0,
    };

    let coefficients = vec![0.0; capacity * state_dimension];
    let intercepts = vec![0.0; capacity];
    let metadata = vec![blank_metadata; capacity];
    let active = vec![false; capacity];

    Self {
        coefficients,
        intercepts,
        metadata,
        active,
        populated_count: 0,
        capacity,
        state_dimension,
        forward_passes,
        warm_start_count,
        cached_active_count: 0,
        candidates_buf: Vec::new(),
    }
}
/// Computes the slot for a cut:
/// `warm_start_count + iteration * forward_passes + forward_pass_index`.
#[inline]
fn slot_index(&self, iteration: u64, forward_pass_index: u32) -> usize {
    let base = self.warm_start_count as usize;
    let stride = self.forward_passes as usize;
    #[allow(clippy::cast_possible_truncation)]
    let iteration_offset = iteration as usize * stride;
    base + iteration_offset + forward_pass_index as usize
}
/// Stores a cut in the slot derived from `iteration` and
/// `forward_pass_index`, marks it active and resets its metadata.
///
/// `populated_count` is raised to cover the slot if needed.
///
/// Debug builds assert that the slot is in range, that `coefficients` has
/// `state_dimension` entries, and that the slot is not already active.
/// In release builds a repeated insert overwrites the stored cut but does
/// not count the slot twice, so `cached_active_count` keeps matching the
/// `active` flags (mirroring the guard in `deactivate`).
///
/// # Panics
///
/// Panics if the slot lies outside the allocated storage or if
/// `coefficients.len()` differs from `state_dimension`.
pub fn add_cut(
    &mut self,
    iteration: u64,
    forward_pass_index: u32,
    intercept: f64,
    coefficients: &[f64],
) {
    let slot = self.slot_index(iteration, forward_pass_index);
    debug_assert!(
        slot < self.capacity,
        "cut slot {slot} is out of bounds (capacity = {})",
        self.capacity
    );
    debug_assert!(
        coefficients.len() == self.state_dimension,
        "coefficients length {} != state_dimension {}",
        coefficients.len(),
        self.state_dimension
    );
    // Checked before any write so a failing debug build leaves the pool untouched.
    debug_assert!(
        !self.active[slot],
        "add_cut: slot {slot} is already active (double-insert)"
    );

    self.intercepts[slot] = intercept;
    let start = slot * self.state_dimension;
    self.coefficients[start..start + self.state_dimension].copy_from_slice(coefficients);

    // Count the slot only on an inactive -> active transition; an
    // unconditional increment would inflate the cached count on a
    // double-insert in release builds.
    if !self.active[slot] {
        self.active[slot] = true;
        self.cached_active_count += 1;
    }

    self.metadata[slot] = CutMetadata {
        iteration_generated: iteration,
        forward_pass_index,
        active_count: 0,
        last_active_iter: iteration,
        active_window: crate::basis_reconstruct::SEED_BIT,
    };

    // High-water mark: never lowered by writes to earlier slots.
    self.populated_count = self.populated_count.max(slot + 1);
}
/// Iterates over the active cuts as `(slot, intercept, coefficients)` in
/// ascending slot order.
///
/// The scan covers `active[..populated_count]` and stops as soon as
/// `cached_active_count` cuts have been produced, so trailing inactive
/// slots are not visited.
pub fn active_cuts(&self) -> impl Iterator<Item = (usize, f64, &[f64])> {
    let dim = self.state_dimension;
    self.active[..self.populated_count]
        .iter()
        .enumerate()
        .filter(|&(_, &flag)| flag)
        // Early exit once every cached-active cut has been seen.
        .take(self.cached_active_count)
        .map(move |(slot, _)| {
            let offset = slot * dim;
            (
                slot,
                self.intercepts[slot],
                &self.coefficients[offset..offset + dim],
            )
        })
}
/// Iterates over the active cuts whose `iteration_generated` equals
/// `current_iteration`, as `(slot, intercept, coefficients)`.
///
/// Cuts tagged with `WARM_START_ITERATION` are never yielded. As in
/// `active_cuts`, the scan stops after `cached_active_count` active slots
/// have been examined.
pub(crate) fn active_delta_cuts(
    &self,
    current_iteration: u64,
) -> impl Iterator<Item = (usize, f64, &[f64])> {
    let dim = self.state_dimension;
    self.active[..self.populated_count]
        .iter()
        .enumerate()
        .filter(|&(_, &flag)| flag)
        // The early exit counts every active slot, not only matching ones.
        .take(self.cached_active_count)
        .filter_map(move |(slot, _)| {
            let generated = self.metadata[slot].iteration_generated;
            if generated != current_iteration || generated == WARM_START_ITERATION {
                return None;
            }
            let offset = slot * dim;
            Some((
                slot,
                self.intercepts[slot],
                &self.coefficients[offset..offset + dim],
            ))
        })
}
/// Returns the cached number of active cuts.
///
/// # Panics
///
/// In builds with debug assertions enabled, panics if the cached value
/// differs from a fresh count of the `true` flags in
/// `active[..populated_count]`.
#[must_use]
pub fn active_count(&self) -> usize {
    if cfg!(debug_assertions) {
        let recount = self.active[..self.populated_count]
            .iter()
            .filter(|&&flag| flag)
            .count();
        assert_eq!(
            self.cached_active_count, recount,
            "cached active_count {} != computed {}",
            self.cached_active_count, recount,
        );
    }
    self.cached_active_count
}
/// Clears the active flag of every listed slot and decrements the cached
/// active count once per slot actually cleared.
///
/// Debug builds assert that each index is in range and currently active.
/// Release builds skip indices that are out of range or already inactive.
pub fn deactivate(&mut self, indices: &[u32]) {
    for i in indices.iter().map(|&raw| raw as usize) {
        debug_assert!(i < self.capacity, "deactivate index {i} out of bounds");
        debug_assert!(
            self.active[i],
            "deactivate called with index {i} that is already inactive"
        );
        // Skipping here keeps the cached count from being decremented twice.
        if i >= self.capacity || !self.active[i] {
            continue;
        }
        self.active[i] = false;
        self.cached_active_count -= 1;
    }
}
/// Returns the largest value of `intercept + coefficients · state` over all
/// active cuts, or `f64::NEG_INFINITY` when no cut is active.
///
/// Debug builds assert that `state` has `state_dimension` entries.
#[must_use]
pub fn evaluate_at_state(&self, state: &[f64]) -> f64 {
    debug_assert!(
        state.len() == self.state_dimension,
        "state length {} != state_dimension {}",
        state.len(),
        self.state_dimension
    );
    let mut best = f64::NEG_INFINITY;
    for (_, intercept, coeffs) in self.active_cuts() {
        let inner: f64 = coeffs.iter().zip(state).map(|(c, x)| c * x).sum();
        best = f64::max(best, intercept + inner);
    }
    best
}
/// Counts coefficients that are exactly zero across all active cuts.
///
/// The report holds the total number of active coefficients, the number of
/// exact zeros, their ratio (`0.0` when there are no active coefficients)
/// and the zero count for each state dimension.
#[must_use]
pub fn sparsity_report(&self) -> SparsityReport {
    let active_count = self.active_count();

    let mut per_dimension_zeros = vec![0usize; self.state_dimension];
    for (_, _, coeffs) in self.active_cuts() {
        for (zeros, &c) in per_dimension_zeros.iter_mut().zip(coeffs) {
            if c == 0.0 {
                *zeros += 1;
            }
        }
    }
    // Every zero is tallied in exactly one dimension, so the total is the sum.
    let exact_zero_count: usize = per_dimension_zeros.iter().sum();

    let total_coefficients = active_count * self.state_dimension;
    #[allow(clippy::cast_precision_loss)]
    let sparsity_fraction = if total_coefficients == 0 {
        0.0
    } else {
        exact_zero_count as f64 / total_coefficients as f64
    };

    SparsityReport {
        total_coefficients,
        exact_zero_count,
        sparsity_fraction,
        per_dimension_zeros,
    }
}
#[must_use]
pub fn from_deserialized(
state_dimension: usize,
records: &[cobre_io::OwnedPolicyCutRecord],
) -> Self {
let capacity = records.len();
let mut coefficients = Vec::with_capacity(capacity * state_dimension);
let mut intercepts = Vec::with_capacity(capacity);
let mut active = Vec::with_capacity(capacity);
let mut metadata = Vec::with_capacity(capacity);
let mut cached_active_count = 0usize;
for record in records {
debug_assert!(
record.coefficients.len() == state_dimension,
"from_deserialized: coefficients length {} != state_dimension {}",
record.coefficients.len(),
state_dimension
);
coefficients.extend_from_slice(&record.coefficients);
intercepts.push(record.intercept);
active.push(record.is_active);
if record.is_active {
cached_active_count += 1;
}
metadata.push(CutMetadata {
iteration_generated: u64::from(record.iteration),
forward_pass_index: record.forward_pass_index,
active_count: 0,
last_active_iter: u64::from(record.iteration),
active_window: 0,
});
}
#[allow(clippy::cast_possible_truncation)]
Self {
coefficients,
intercepts,
metadata,
active,
populated_count: capacity,
capacity,
state_dimension,
forward_passes: 0,
warm_start_count: capacity as u32,
cached_active_count,
candidates_buf: Vec::new(),
}
}
/// Builds a pool preloaded with `records` in its first slots and with room
/// for `max_iterations * forward_passes` further cuts.
///
/// Preloaded cuts are tagged with `WARM_START_ITERATION` as their
/// `iteration_generated`; the record's own `iteration` is kept as
/// `last_active_iter`. `warm_start_count` and `populated_count` are both
/// set to `records.len()`.
///
/// Debug builds assert that every record carries `state_dimension`
/// coefficients.
#[must_use]
pub fn new_with_warm_start(
    state_dimension: usize,
    forward_passes: u32,
    max_iterations: u64,
    records: &[cobre_io::OwnedPolicyCutRecord],
) -> Self {
    let preloaded = records.len();
    #[allow(clippy::cast_possible_truncation)]
    let capacity = preloaded + (max_iterations as usize) * (forward_passes as usize);

    let mut coefficients = vec![0.0_f64; capacity * state_dimension];
    let mut intercepts = vec![0.0; capacity];
    let mut active = vec![false; capacity];
    // Slots beyond the preloaded prefix keep this placeholder record.
    let mut metadata = vec![
        CutMetadata {
            iteration_generated: 0,
            forward_pass_index: 0,
            active_count: 0,
            last_active_iter: 0,
            active_window: 0,
        };
        capacity
    ];
    let mut cached_active_count = 0usize;

    for (slot, record) in records.iter().enumerate() {
        debug_assert!(
            record.coefficients.len() == state_dimension,
            "new_with_warm_start: coefficients length {} != state_dimension {}",
            record.coefficients.len(),
            state_dimension
        );
        let offset = slot * state_dimension;
        coefficients[offset..offset + state_dimension].copy_from_slice(&record.coefficients);
        intercepts[slot] = record.intercept;
        active[slot] = record.is_active;
        cached_active_count += usize::from(record.is_active);
        metadata[slot] = CutMetadata {
            iteration_generated: WARM_START_ITERATION,
            forward_pass_index: record.forward_pass_index,
            active_count: 0,
            last_active_iter: u64::from(record.iteration),
            active_window: 0,
        };
    }

    #[allow(clippy::cast_possible_truncation)]
    let warm_start_count = preloaded as u32;

    Self {
        coefficients,
        intercepts,
        metadata,
        active,
        populated_count: preloaded,
        capacity,
        state_dimension,
        forward_passes,
        warm_start_count,
        cached_active_count,
        candidates_buf: Vec::new(),
    }
}
}
/// Summary of exactly-zero coefficients among a pool's active cuts, as
/// produced by `CutPool::sparsity_report`.
#[derive(Debug, Clone)]
pub struct SparsityReport {
/// Number of coefficients examined: active cuts times state dimension.
pub total_coefficients: usize,
/// Number of examined coefficients equal to `0.0`.
pub exact_zero_count: usize,
/// `exact_zero_count / total_coefficients`, or `0.0` when nothing was examined.
pub sparsity_fraction: f64,
/// Zero count per state dimension; entry `j` counts zeros at coefficient index `j`.
pub per_dimension_zeros: Vec<usize>,
}
/// Outcome of one `CutPool::enforce_budget` call.
#[derive(Debug, Clone)]
pub struct BudgetEnforcementResult {
/// Number of cuts deactivated by the call.
pub evicted_count: u32,
/// Active cut count observed before any eviction.
pub active_before: u32,
/// Active cut count observed after eviction.
pub active_after: u32,
}
impl CutPool {
/// Deactivates cuts until at most `budget` remain active, where possible.
///
/// Only active cuts whose `iteration_generated` differs from
/// `current_iteration` are eligible, so the pool can stay over budget when
/// too few eligible cuts exist. Eligible cuts are evicted in ascending
/// order of `(last_active_iter, active_count)`; the order among cuts with
/// equal keys is whatever the unstable sort/selection produces.
///
/// `_forward_passes` is accepted for interface compatibility and is not
/// used.
///
/// The candidate list lives in the reusable `candidates_buf` scratch
/// buffer, and the eviction step borrows that buffer directly instead of
/// copying the evicted prefix into a fresh vector.
pub fn enforce_budget(
    &mut self,
    budget: u32,
    current_iteration: u64,
    _forward_passes: u32,
) -> BudgetEnforcementResult {
    #[allow(clippy::cast_possible_truncation)]
    let active_before = self.active_count() as u32;
    let budget_usize = budget as usize;
    if self.cached_active_count <= budget_usize {
        return BudgetEnforcementResult {
            evicted_count: 0,
            active_before,
            active_after: active_before,
        };
    }
    let excess = self.cached_active_count - budget_usize;

    // Collect every active slot that was not generated in the current iteration.
    self.candidates_buf.clear();
    #[allow(clippy::cast_possible_truncation)]
    self.candidates_buf.extend(
        self.active[..self.populated_count]
            .iter()
            .enumerate()
            .filter(|&(slot, &is_active)| {
                is_active && self.metadata[slot].iteration_generated != current_iteration
            })
            .map(|(slot, _)| slot as u32),
    );
    if self.candidates_buf.is_empty() {
        return BudgetEnforcementResult {
            evicted_count: 0,
            active_before,
            active_after: active_before,
        };
    }

    let evict_count = excess.min(self.candidates_buf.len());
    let key = |&slot: &u32| {
        let meta = &self.metadata[slot as usize];
        (meta.last_active_iter, meta.active_count)
    };
    // A partial selection is enough when only a small prefix is evicted;
    // otherwise sort the whole candidate list.
    if evict_count < self.candidates_buf.len() / 2 {
        self.candidates_buf
            .select_nth_unstable_by(evict_count, |a, b| key(a).cmp(&key(b)));
    } else {
        self.candidates_buf.sort_unstable_by_key(|a| key(a));
    }

    // Lend the scratch buffer to `deactivate` and put it back afterwards:
    // this avoids a per-call allocation and keeps the buffer's capacity.
    let candidates = std::mem::take(&mut self.candidates_buf);
    self.deactivate(&candidates[..evict_count]);
    self.candidates_buf = candidates;

    #[allow(clippy::cast_possible_truncation)]
    let evicted_count = evict_count as u32;
    #[allow(clippy::cast_possible_truncation)]
    let active_after = self.active_count() as u32;
    BudgetEnforcementResult {
        evicted_count,
        active_before,
        active_after,
    }
}
}
// Unit tests for `CutPool`: construction, slot placement, activation
// bookkeeping, evaluation, sparsity reporting, warm start and budget eviction.
#[cfg(test)]
mod tests {
use super::CutPool;
// A fresh pool reports the requested sizes, is zero-filled and has no active slots.
#[test]
fn new_creates_pool_with_correct_capacity_and_all_inactive() {
let pool = CutPool::new(100, 9, 10, 0);
assert_eq!(pool.capacity, 100);
assert_eq!(pool.state_dimension, 9);
assert_eq!(pool.forward_passes, 10);
assert_eq!(pool.warm_start_count, 0);
assert_eq!(pool.populated_count, 0);
assert_eq!(pool.active_count(), 0);
assert!(pool.active.iter().all(|&a| !a));
assert_eq!(pool.coefficients.len(), 100 * 9);
assert!(pool.coefficients.iter().all(|&v| v == 0.0));
assert!(pool.intercepts.iter().all(|&v| v == 0.0));
}
// A pool with zero capacity can be constructed and queried.
#[test]
fn new_zero_capacity_is_valid() {
let pool = CutPool::new(0, 4, 5, 0);
assert_eq!(pool.capacity, 0);
assert_eq!(pool.active_count(), 0);
}
// Iteration 0 / pass 0 lands in slot 0 with its data, flag and metadata recorded.
#[test]
fn add_cut_at_slot_zero_stores_intercept_coefficients_and_active_flag() {
let mut pool = CutPool::new(100, 9, 10, 0);
let coeffs = vec![1.0; 9];
pool.add_cut(0, 0, 5.0, &coeffs);
assert_eq!(pool.active_count(), 1);
assert!(pool.active[0]);
assert_eq!(pool.intercepts[0], 5.0);
assert_eq!(&pool.coefficients[0..9], vec![1.0; 9].as_slice());
assert_eq!(pool.metadata[0].iteration_generated, 0);
assert_eq!(pool.metadata[0].forward_pass_index, 0);
assert_eq!(pool.populated_count, 1);
}
// With 10 forward passes and no warm start, slot = iteration * 10 + pass index.
#[test]
fn add_cut_deterministic_slot_formula_no_warmstart() {
let mut pool = CutPool::new(200, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0]); pool.add_cut(0, 3, 2.0, &[3.0, 4.0]); pool.add_cut(1, 0, 3.0, &[5.0, 6.0]); pool.add_cut(2, 5, 4.0, &[7.0, 8.0]);
assert!(pool.active[0]);
assert_eq!(pool.intercepts[0], 1.0);
assert!(pool.active[3]);
assert_eq!(pool.intercepts[3], 2.0);
assert!(pool.active[10]);
assert_eq!(pool.intercepts[10], 3.0);
assert!(pool.active[25]);
assert_eq!(pool.intercepts[25], 4.0);
}
// A warm-start count of 5 shifts the first generated cut to slot 5.
#[test]
fn add_cut_warm_start_count_offsets_slot() {
let mut pool = CutPool::new(100, 9, 10, 5);
let coeffs = vec![0.0; 9];
pool.add_cut(0, 0, 42.0, &coeffs);
assert!(pool.active[5]);
assert_eq!(pool.intercepts[5], 42.0);
assert_eq!(pool.populated_count, 6);
}
// Iteration 3 / pass 2 with 5 forward passes is slot 17; its metadata mirrors the call.
#[test]
fn add_cut_metadata_initialized_correctly() {
let mut pool = CutPool::new(50, 3, 5, 0);
pool.add_cut(3, 2, 7.0, &[1.0, 2.0, 3.0]);
let meta = &pool.metadata[17];
assert_eq!(meta.iteration_generated, 3);
assert_eq!(meta.forward_pass_index, 2);
assert_eq!(meta.active_count, 0);
assert_eq!(meta.last_active_iter, 3);
}
// `populated_count` only grows: writing an earlier slot leaves it unchanged.
#[test]
fn populated_count_tracks_high_water_mark() {
let mut pool = CutPool::new(50, 1, 5, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); assert_eq!(pool.populated_count, 1);
pool.add_cut(1, 0, 2.0, &[2.0]); assert_eq!(pool.populated_count, 6);
pool.add_cut(0, 2, 3.0, &[3.0]); assert_eq!(pool.populated_count, 6);
}
// Deactivated slots are skipped by `active_cuts`.
#[test]
fn active_cuts_returns_only_active_cuts() {
let mut pool = CutPool::new(20, 2, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0]); pool.add_cut(1, 0, 2.0, &[3.0, 4.0]); pool.add_cut(2, 0, 3.0, &[5.0, 6.0]);
pool.deactivate(&[1]);
let active: Vec<_> = pool.active_cuts().collect();
assert_eq!(active.len(), 2);
let slots: Vec<usize> = active.iter().map(|(s, _, _)| *s).collect();
assert!(slots.contains(&0));
assert!(slots.contains(&2));
assert!(!slots.contains(&1));
}
// An empty pool yields no active cuts.
#[test]
fn active_cuts_empty_pool_returns_empty_iterator() {
let pool = CutPool::new(10, 3, 5, 0);
let active: Vec<_> = pool.active_cuts().collect();
assert!(active.is_empty());
}
// The active count follows additions and deactivations.
#[test]
fn active_count_is_correct_after_add_and_deactivate() {
let mut pool = CutPool::new(20, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); pool.add_cut(1, 0, 2.0, &[2.0]); pool.add_cut(2, 0, 3.0, &[3.0]);
assert_eq!(pool.active_count(), 3);
pool.deactivate(&[1]);
assert_eq!(pool.active_count(), 2);
}
// Deactivating one slot clears only that slot's flag.
#[test]
fn deactivate_sets_flags_correctly() {
let mut pool = CutPool::new(20, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); pool.add_cut(1, 0, 2.0, &[2.0]); pool.add_cut(2, 0, 3.0, &[3.0]);
pool.deactivate(&[1]);
assert!(pool.active[0]);
assert!(!pool.active[1]);
assert!(pool.active[2]);
assert_eq!(pool.active_count(), 2);
}
// Several slots can be deactivated in one call.
#[test]
fn deactivate_multiple_indices() {
let mut pool = CutPool::new(20, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]); pool.add_cut(1, 0, 2.0, &[2.0]); pool.add_cut(2, 0, 3.0, &[3.0]);
pool.deactivate(&[0, 2]);
assert!(!pool.active[0]);
assert!(pool.active[1]);
assert!(!pool.active[2]);
assert_eq!(pool.active_count(), 1);
}
// An empty index list changes nothing.
#[test]
fn deactivate_empty_slice_is_noop() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.deactivate(&[]);
assert_eq!(pool.active_count(), 1);
}
// Release-only: repeated indices trip a debug assertion, so this runs only
// without debug assertions and checks the count is decremented once.
#[test]
#[cfg(not(debug_assertions))]
fn deactivate_duplicate_index_is_silently_skipped() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
assert_eq!(pool.active_count(), 2);
pool.deactivate(&[0, 0, 0]);
assert_eq!(
pool.active_count(),
1,
"duplicate indices must not double-decrement"
);
assert!(!pool.active[0]);
assert!(pool.active[1]);
}
// Both cuts evaluate to 13 at state (3, 4): 10 + 1*3 and 5 + 2*4.
#[test]
fn evaluate_at_state_returns_max_cut_value() {
let mut pool = CutPool::new(10, 2, 1, 0);
pool.add_cut(0, 0, 10.0, &[1.0, 0.0]);
pool.add_cut(1, 0, 5.0, &[0.0, 2.0]);
let result = pool.evaluate_at_state(&[3.0, 4.0]);
assert_eq!(result, 13.0);
}
// At state 10 the cuts give 12 and 25; the larger value is returned.
#[test]
fn evaluate_at_state_selects_correct_max() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 2.0, &[1.0]);
pool.add_cut(1, 0, 5.0, &[2.0]);
let result = pool.evaluate_at_state(&[10.0]);
assert_eq!(result, 25.0);
}
// With no cuts the evaluation is negative infinity.
#[test]
fn evaluate_at_state_empty_pool_returns_neg_infinity() {
let pool = CutPool::new(10, 3, 5, 0);
assert_eq!(pool.evaluate_at_state(&[1.0, 2.0, 3.0]), f64::NEG_INFINITY);
}
// With every cut deactivated the evaluation is negative infinity.
#[test]
fn evaluate_at_state_all_deactivated_returns_neg_infinity() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 100.0, &[1.0]);
pool.deactivate(&[0]);
assert_eq!(pool.evaluate_at_state(&[5.0]), f64::NEG_INFINITY);
}
// A deactivated cut with a larger value does not affect the result.
#[test]
fn evaluate_at_state_ignores_deactivated_cuts() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 10.0, &[1.0]);
pool.add_cut(1, 0, 100.0, &[1.0]);
pool.deactivate(&[1]);
assert_eq!(pool.evaluate_at_state(&[3.0]), 13.0);
}
// First cut goes to slot 0 and the active count becomes one.
#[test]
fn ac_add_cut_stores_at_slot_zero_and_active_count_is_one() {
let mut pool = CutPool::new(100, 9, 10, 0);
let coeffs = vec![0.0; 9];
pool.add_cut(0, 0, 5.0, &coeffs);
assert!(pool.active[0]);
assert_eq!(pool.active_count(), 1);
}
// Deactivating one of three cuts leaves two active.
#[test]
fn ac_deactivate_reduces_active_count_correctly() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0]);
pool.add_cut(1, 0, 2.0, &[2.0]);
pool.add_cut(2, 0, 3.0, &[3.0]);
pool.deactivate(&[1]);
assert_eq!(pool.active_count(), 2);
assert!(!pool.active[1]);
}
// Evaluation returns the maximum over the two cuts (13 at state (3, 4)).
#[test]
fn ac_evaluate_at_state_returns_correct_max() {
let mut pool = CutPool::new(10, 2, 1, 0);
pool.add_cut(0, 0, 10.0, &[1.0, 0.0]);
pool.add_cut(1, 0, 5.0, &[0.0, 2.0]);
assert_eq!(pool.evaluate_at_state(&[3.0, 4.0]), 13.0);
}
// With a warm-start count of 5 the first generated cut uses slot 5, not slot 0.
#[test]
fn ac_warm_start_count_offsets_slot() {
let mut pool = CutPool::new(100, 9, 10, 5);
let coeffs = vec![0.0; 9];
pool.add_cut(0, 0, 1.0, &coeffs);
assert!(pool.active[5]);
assert!(!pool.active[0]);
}
// Evaluating an empty pool returns negative infinity.
#[test]
fn ac_empty_pool_evaluate_returns_neg_infinity() {
let pool = CutPool::new(10, 2, 1, 0);
assert_eq!(pool.evaluate_at_state(&[1.0, 2.0]), f64::NEG_INFINITY);
}
// `Clone` copies the stored data and `Debug` produces non-empty output.
#[test]
fn cut_pool_derives_debug_and_clone() {
let mut pool = CutPool::new(5, 2, 1, 0);
pool.add_cut(0, 0, 3.0, &[1.0, 2.0]);
let cloned = pool.clone();
assert_eq!(cloned.active_count(), 1);
assert_eq!(cloned.intercepts[0], 3.0);
let debug_str = format!("{pool:?}");
assert!(!debug_str.is_empty());
}
// An empty pool reports zero coefficients and a zero fraction.
#[test]
fn sparsity_report_empty_pool() {
let pool = CutPool::new(10, 3, 1, 0);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 0);
assert_eq!(report.exact_zero_count, 0);
assert!((report.sparsity_fraction - 0.0).abs() < f64::EPSILON);
assert_eq!(report.per_dimension_zeros, vec![0, 0, 0]);
}
// No zero coefficients: fraction 0 and all per-dimension counts 0.
#[test]
fn sparsity_report_all_nonzero() {
let mut pool = CutPool::new(10, 3, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0, 3.0]);
pool.add_cut(1, 0, 2.0, &[4.0, 5.0, 6.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 6);
assert_eq!(report.exact_zero_count, 0);
assert!((report.sparsity_fraction - 0.0).abs() < f64::EPSILON);
assert_eq!(report.per_dimension_zeros, vec![0, 0, 0]);
}
// All coefficients zero: fraction 1 and every dimension counts both cuts.
#[test]
fn sparsity_report_all_zero() {
let mut pool = CutPool::new(10, 3, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0, 0.0, 0.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 0.0, 0.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 6);
assert_eq!(report.exact_zero_count, 6);
assert!((report.sparsity_fraction - 1.0).abs() < f64::EPSILON);
assert_eq!(report.per_dimension_zeros, vec![2, 2, 2]);
}
// Three zeros out of six coefficients give a fraction of 0.5.
#[test]
fn sparsity_report_mixed() {
let mut pool = CutPool::new(10, 3, 1, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0, 2.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 0.0, 3.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 6);
assert_eq!(report.exact_zero_count, 3);
assert!((report.sparsity_fraction - 0.5).abs() < 1e-10);
assert_eq!(report.per_dimension_zeros, vec![1, 2, 0]);
}
// Zeros in a deactivated cut are not counted.
#[test]
fn sparsity_report_excludes_inactive_cuts() {
let mut pool = CutPool::new(10, 2, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0, 0.0]); pool.add_cut(1, 0, 2.0, &[1.0, 2.0]); pool.deactivate(&[0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 2);
assert_eq!(report.exact_zero_count, 0);
assert!((report.sparsity_fraction - 0.0).abs() < f64::EPSILON);
}
// Per-dimension zero counts are tallied by coefficient position.
#[test]
fn sparsity_report_per_dimension_zeros_correct() {
let mut pool = CutPool::new(10, 4, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0, 1.0, 0.0, 3.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 2.0, 4.0, 0.0]);
pool.add_cut(2, 0, 3.0, &[5.0, 6.0, 7.0, 8.0]);
let report = pool.sparsity_report();
assert_eq!(report.total_coefficients, 12);
assert_eq!(report.exact_zero_count, 4);
assert_eq!(report.per_dimension_zeros, vec![2, 0, 1, 1]);
}
// Warm-start records are tagged with `WARM_START_ITERATION` while their own
// iteration is kept as `last_active_iter`.
#[test]
fn warm_start_cuts_have_sentinel_iteration() {
use crate::cut::WARM_START_ITERATION;
use cobre_io::OwnedPolicyCutRecord;
let records = vec![
OwnedPolicyCutRecord {
cut_id: 0,
slot_index: 0,
coefficients: vec![1.0, 2.0],
intercept: 10.0,
is_active: true,
iteration: 5,
forward_pass_index: 0,
},
OwnedPolicyCutRecord {
cut_id: 1,
slot_index: 1,
coefficients: vec![3.0, 4.0],
intercept: 20.0,
is_active: true,
iteration: 7,
forward_pass_index: 1,
},
];
let pool = CutPool::new_with_warm_start(2, 4, 100, &records);
assert_eq!(pool.warm_start_count, 2);
assert_eq!(pool.populated_count, 2);
assert_eq!(pool.metadata[0].iteration_generated, WARM_START_ITERATION);
assert_eq!(pool.metadata[1].iteration_generated, WARM_START_ITERATION);
assert_eq!(pool.metadata[0].last_active_iter, 5);
assert_eq!(pool.metadata[1].last_active_iter, 7);
}
// A pool built from one warm-start record has a positive warm-start count.
#[test]
fn terminal_has_boundary_cuts_when_warm_start_count_positive() {
use cobre_io::OwnedPolicyCutRecord;
let records = vec![OwnedPolicyCutRecord {
cut_id: 0,
slot_index: 0,
coefficients: vec![1.0],
intercept: 5.0,
is_active: true,
iteration: 0,
forward_pass_index: 0,
}];
let pool = CutPool::new_with_warm_start(1, 4, 100, &records);
assert!(pool.warm_start_count > 0, "terminal pool has boundary cuts");
}
// A pool created with `new(.., 0)` has a warm-start count of zero.
#[test]
fn no_boundary_cuts_when_warm_start_count_zero() {
let pool = CutPool::new(100, 2, 10, 0);
assert_eq!(pool.warm_start_count, 0, "no boundary cuts");
}
// Nothing is evicted when the active count is within budget.
#[test]
fn enforce_budget_noop_when_under_budget() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 2.0]);
pool.add_cut(0, 1, 2.0, &[3.0, 4.0]);
assert_eq!(pool.active_count(), 2);
let result = pool.enforce_budget(5, 1, 10);
assert_eq!(result.evicted_count, 0);
assert_eq!(result.active_after, 2);
assert_eq!(pool.active_count(), 2);
}
// The two cuts with the smallest `last_active_iter` (slots 0 and 10) are evicted.
#[test]
fn enforce_budget_evicts_oldest_last_active_iter() {
let mut pool = CutPool::new(100, 2, 10, 0);
for iter in 0..5_u64 {
pool.add_cut(iter, 0, 1.0, &[1.0, 0.0]);
pool.metadata[pool.populated_count - 1].last_active_iter = iter;
}
assert_eq!(pool.active_count(), 5);
let result = pool.enforce_budget(3, 5, 10);
assert_eq!(result.evicted_count, 2);
assert_eq!(result.active_after, 3);
assert_eq!(pool.active_count(), 3);
assert!(!pool.active[0], "oldest cut should be evicted");
assert!(!pool.active[10], "second oldest should be evicted");
}
// With equal `last_active_iter`, the cut with the lower `active_count` is evicted.
#[test]
fn enforce_budget_tiebreaks_by_active_count() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0]);
pool.metadata[0].last_active_iter = 1;
pool.metadata[0].active_count = 5;
pool.add_cut(0, 1, 2.0, &[0.0, 1.0]);
pool.metadata[1].last_active_iter = 1;
pool.metadata[1].active_count = 2;
assert_eq!(pool.active_count(), 2);
let result = pool.enforce_budget(1, 1, 10);
assert_eq!(result.evicted_count, 1);
assert!(pool.active[0], "higher active_count survives");
assert!(!pool.active[1], "lower active_count evicted");
}
// The cut generated in the current iteration (slot 10) is never evicted.
#[test]
fn enforce_budget_protects_current_iteration() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0]);
pool.metadata[0].last_active_iter = 0;
pool.add_cut(0, 1, 2.0, &[0.0, 1.0]);
pool.metadata[1].last_active_iter = 0;
pool.add_cut(1, 0, 3.0, &[1.0, 1.0]);
pool.metadata[10].last_active_iter = 1;
assert_eq!(pool.active_count(), 3);
let result = pool.enforce_budget(1, 1, 10);
assert_eq!(result.evicted_count, 2);
assert!(pool.active[10], "current iteration cut preserved");
}
// When every active cut belongs to the current iteration, nothing is evicted
// even though the pool is over budget.
#[test]
fn enforce_budget_all_current_iteration_no_eviction() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(5, 0, 1.0, &[1.0, 0.0]);
pool.add_cut(5, 1, 2.0, &[0.0, 1.0]);
pool.add_cut(5, 2, 3.0, &[1.0, 1.0]);
assert_eq!(pool.active_count(), 3);
let result = pool.enforce_budget(1, 5, 10);
assert_eq!(result.evicted_count, 0);
assert_eq!(pool.active_count(), 3);
}
// The result reports the counts before and after plus the number evicted.
#[test]
fn enforce_budget_result_fields() {
let mut pool = CutPool::new(100, 2, 10, 0);
pool.add_cut(0, 0, 1.0, &[1.0, 0.0]);
pool.add_cut(1, 0, 2.0, &[0.0, 1.0]);
pool.add_cut(2, 0, 3.0, &[1.0, 1.0]);
assert_eq!(pool.active_count(), 3);
let result = pool.enforce_budget(1, 3, 10);
assert_eq!(result.active_before, 3);
assert_eq!(result.evicted_count, 2);
assert_eq!(result.active_after, 1);
}
// With `populated_count` forced to 100 and a single active cut, the iterator
// still yields exactly that one cut.
#[test]
fn active_cuts_early_exit_stops_at_cached_count() {
let mut pool = CutPool::new(100, 2, 1, 0);
pool.add_cut(0, 0, 5.0, &[1.0, 2.0]);
pool.populated_count = 100;
assert_eq!(pool.cached_active_count, 1);
assert_eq!(pool.populated_count, 100);
let result: Vec<_> = pool.active_cuts().collect();
assert_eq!(result.len(), 1);
assert_eq!(result[0].0, 0, "yielded slot must be 0");
assert_eq!(result[0].1, 5.0, "intercept must match");
assert_eq!(result[0].2, &[1.0, 2.0], "coefficients must match");
}
// The scratch buffer gains capacity on first use and does not shrink on a
// second `enforce_budget` call.
#[test]
fn enforce_budget_candidates_buf_is_reused() {
let mut pool = CutPool::new(100, 2, 10, 0);
for iter in 0..5_u64 {
pool.add_cut(iter, 0, 1.0, &[1.0, 0.0]);
}
assert_eq!(pool.active_count(), 5);
pool.enforce_budget(3, 5, 10);
let cap_after_first = pool.candidates_buf.capacity();
assert!(
cap_after_first >= 1,
"candidates_buf must have acquired capacity after first enforce_budget"
);
let mut pool2 = CutPool::new(100, 2, 10, 0);
for iter in 0..5_u64 {
pool2.add_cut(iter, 0, 1.0, &[1.0, 0.0]);
}
pool2.enforce_budget(3, 5, 10);
let cap_after_first2 = pool2.candidates_buf.capacity();
pool2.add_cut(6, 0, 2.0, &[0.0, 1.0]);
pool2.add_cut(7, 0, 2.0, &[0.0, 1.0]);
pool2.add_cut(8, 0, 2.0, &[0.0, 1.0]);
pool2.enforce_budget(2, 9, 10);
let cap_after_second2 = pool2.candidates_buf.capacity();
assert!(
cap_after_second2 >= cap_after_first2,
"candidates_buf capacity must not shrink across calls (was {cap_after_first2}, now {cap_after_second2})"
);
}
}