use rayon::prelude::*;
use crate::gemm::gemm_block;
/// Number of trial points (value-block columns) evaluated together by the
/// cut-selection value-evaluation kernel behind
/// `CutSelectionStrategy::select_for_stage`.
///
/// Each parallel task in that kernel allocates a scratch buffer of
/// `populated * M_BLOCK` cut values, so this bounds per-task scratch memory.
pub(crate) const M_BLOCK: usize = 8;
/// Per-slot bookkeeping stored alongside each cut in the pool.
///
/// Within this module only `iteration_generated` is read: cut selection treats a
/// cut as eligible for (de)activation only when it was generated strictly before
/// the iteration at which selection runs. The other fields are maintained
/// outside this module.
#[derive(Debug, Clone)]
pub struct CutMetadata {
/// Iteration in which the cut was generated; compared against
/// `current_iteration` when deciding eligibility for selection.
pub iteration_generated: u64,
/// NOTE(review): not read in this module — presumably the index of the forward
/// pass that produced the cut; confirm at the producer.
pub forward_pass_index: u32,
/// NOTE(review): not read in this module — presumably a count of how often the
/// cut was active; confirm at the producer.
pub active_count: u64,
/// NOTE(review): not read in this module — presumably the last iteration at
/// which the cut was active; confirm at the producer.
pub last_active_iter: u64,
}
/// Activity changes produced by one cut-selection pass over a stage's cut pool.
///
/// Both lists hold pool slot indices. As produced by
/// `CutSelectionStrategy::select_for_stage`, each list is in ascending slot order
/// and a slot appears in at most one of them.
#[derive(Debug, Clone, PartialEq)]
pub struct CutActivityUpdates {
/// Stage whose cut pool these updates apply to.
pub stage_index: u32,
/// Slots of currently active cuts to deactivate (also returned by
/// `deactivation_indices`).
pub updates: Vec<u32>,
/// Slots of currently inactive cuts to reactivate.
pub reactivations: Vec<u32>,
}
impl CutActivityUpdates {
    /// Builds an update set that only deactivates cuts.
    ///
    /// `indices` is stored as-is in `updates`; the reactivation list is empty.
    #[must_use]
    pub fn deactivations_only(stage_index: u32, indices: Vec<u32>) -> Self {
        let reactivations = Vec::new();
        Self {
            stage_index,
            updates: indices,
            reactivations,
        }
    }

    /// Returns an owned copy of the slots scheduled for deactivation.
    #[must_use]
    pub fn deactivation_indices(&self) -> Vec<u32> {
        self.updates.to_vec()
    }

    /// Returns an owned copy of the slots scheduled for reactivation.
    #[must_use]
    pub fn reactivation_indices(&self) -> Vec<u32> {
        self.reactivations.to_vec()
    }
}
/// Cut-selection strategy applied to a stage's cut pool.
///
/// `Level1`, `Lml1` and `Dominated` evaluate every cut at the visited trial
/// points and keep cuts that are (near-)maximal somewhere; they run every
/// `check_frequency` iterations. `Dynamic` never runs that value-evaluation pass
/// (`should_run` returns `false` and `select_for_stage` returns no changes for it);
/// NOTE(review): its selection logic presumably lives elsewhere — confirm.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum CutSelectionStrategy {
/// Keep every eligible cut whose value at some trial point is within
/// `tie_tolerance` of the maximum cut value at that point.
Level1 {
check_frequency: u64,
tie_tolerance: f64,
},
/// Like `Level1`, but at each trial point keep only the lowest-index eligible
/// cut within `tie_tolerance` of the maximum.
Lml1 {
check_frequency: u64,
tie_tolerance: f64,
},
/// Same keep rule as `Level1`, using `threshold` as the tolerance (parsed from
/// the `domination` method's `domination_tolerance`).
Dominated {
threshold: f64,
check_frequency: u64,
},
/// Parameters for the dynamic method, mapped from the config by
/// `parse_cut_selection_config`.
Dynamic {
/// From `candidate_recency`; `None` means unbounded.
k1: Option<u32>,
/// From `seed_window`.
k2: u32,
/// From `max_added_per_round` (validated `>= 1`).
nadic: u32,
/// From `violation_tolerance` (validated `> 0`).
epsilon_viol: f64,
/// From `start_iteration` (validated `>= 1`).
start_iteration: u64,
},
}
impl CutSelectionStrategy {
#[must_use]
pub fn should_run(&self, iteration: u64) -> bool {
let freq = match self {
Self::Level1 {
check_frequency, ..
}
| Self::Lml1 {
check_frequency, ..
}
| Self::Dominated {
check_frequency, ..
} => *check_frequency,
Self::Dynamic { .. } => return false,
};
iteration > 0 && iteration.is_multiple_of(freq)
}
#[must_use]
pub fn select(
&self,
pool: &crate::cut::CutPool,
visited_states: &[f64],
current_iteration: u64,
) -> CutActivityUpdates {
self.select_for_stage(pool, visited_states, current_iteration, 0)
}
#[must_use]
pub fn select_for_stage(
&self,
pool: &crate::cut::CutPool,
visited_states: &[f64],
current_iteration: u64,
stage_index: u32,
) -> CutActivityUpdates {
if let CutSelectionStrategy::Dynamic { .. } = self {
return CutActivityUpdates {
stage_index,
updates: vec![],
reactivations: vec![],
};
}
let populated = pool.populated_count;
let n_state = pool.state_dimension;
let warm_start = pool.warm_start_count as usize;
if populated == 0 || visited_states.is_empty() || n_state == 0 {
return CutActivityUpdates {
stage_index,
updates: vec![],
reactivations: vec![],
};
}
let eligible: Vec<bool> = (0..populated)
.map(|k| k >= warm_start && pool.metadata[k].iteration_generated < current_iteration)
.collect();
let n_eligible = eligible.iter().filter(|&&e| e).count();
if n_eligible < 2 {
return CutActivityUpdates {
stage_index,
updates: vec![],
reactivations: vec![],
};
}
let n_states = visited_states.len() / n_state;
let n_blocks = n_states.div_ceil(M_BLOCK);
let m_block_starts: Vec<usize> = (0..n_blocks).map(|i| i * M_BLOCK).collect();
let coef_slice = &pool.coefficients[..populated * n_state];
let intercepts: &[f64] = &pool.intercepts[..populated];
let is_selected: Vec<bool> = m_block_starts
.par_iter()
.fold(
|| {
(
vec![0.0_f64; populated * M_BLOCK],
vec![false; populated],
)
},
|(mut v_block_local, mut bitmap_local), &m_start| {
let m_end = (m_start + M_BLOCK).min(n_states);
let m_len = m_end - m_start;
let state_block = &visited_states[m_start * n_state..m_end * n_state];
let v_block_active = &mut v_block_local[..populated * m_len];
gemm_block(
coef_slice,
state_block,
populated,
n_state,
m_len,
v_block_active,
);
for (k, &intercept) in intercepts.iter().enumerate().take(populated) {
let row = k * m_len;
for col in 0..m_len {
v_block_active[row + col] += intercept;
}
}
for col in 0..m_len {
apply_column_rule(
self,
v_block_active,
populated,
m_len,
col,
warm_start,
&eligible,
&mut bitmap_local,
);
}
(v_block_local, bitmap_local)
},
)
.map(|(_, bitmap)| bitmap)
.reduce(
|| vec![false; populated],
|mut a, b| {
for (ai, bi) in a.iter_mut().zip(b.iter()) {
*ai |= *bi;
}
a
},
);
let mut deactivations: Vec<u32> = Vec::new();
let mut reactivations: Vec<u32> = Vec::new();
#[allow(clippy::cast_possible_truncation)]
for k in warm_start..populated {
if eligible[k] {
let currently_active = pool.active[k];
if is_selected[k] && !currently_active {
reactivations.push(k as u32);
} else if !is_selected[k] && currently_active {
deactivations.push(k as u32);
}
}
}
CutActivityUpdates {
stage_index,
updates: deactivations,
reactivations,
}
}
}
/// Applies the strategy's keep rule to one trial point (column `col`) of a value
/// block and ORs the outcome into `bitmap`.
///
/// `v_block` is row-major by cut: the value of cut `k` at this column is
/// `v_block[k * m_len + col]`. The reference maximum is taken over **all**
/// `populated` rows, including warm-start and ineligible cuts. Only slots
/// `k >= warm_start` with `eligible[k]` can be marked.
///
/// * `Level1` / `Dominated`: mark every eligible cut with value
///   `>= max - tolerance`.
/// * `Lml1`: mark only the lowest-index eligible cut meeting that cutoff.
///
/// Bits already set in `bitmap` are never cleared. NaN values never raise the
/// maximum and never meet the cutoff.
///
/// # Panics
///
/// Panics for [`CutSelectionStrategy::Dynamic`], which does not use this kernel.
#[allow(clippy::too_many_arguments)]
#[inline]
fn apply_column_rule(
    method: &CutSelectionStrategy,
    v_block: &[f64],
    populated: usize,
    m_len: usize,
    col: usize,
    warm_start: usize,
    eligible: &[bool],
    bitmap: &mut [bool],
) {
    let value_at = |k: usize| v_block[k * m_len + col];
    // Strict `>` so a NaN entry never replaces the running maximum.
    let column_max = (0..populated)
        .map(&value_at)
        .fold(f64::NEG_INFINITY, |best, v| if v > best { v } else { best });

    match *method {
        CutSelectionStrategy::Level1 {
            tie_tolerance: tolerance,
            ..
        }
        | CutSelectionStrategy::Dominated {
            threshold: tolerance,
            ..
        } => {
            let cutoff = column_max - tolerance;
            (warm_start..populated)
                .filter(|&k| eligible[k] && value_at(k) >= cutoff)
                .for_each(|k| bitmap[k] = true);
        }
        CutSelectionStrategy::Lml1 { tie_tolerance, .. } => {
            let cutoff = column_max - tie_tolerance;
            let first_hit =
                (warm_start..populated).find(|&k| eligible[k] && value_at(k) >= cutoff);
            if let Some(k) = first_hit {
                bitmap[k] = true;
            }
        }
        CutSelectionStrategy::Dynamic { .. } => {
            unreachable!("Dynamic cut selection does not run the value-evaluation kernel")
        }
    }
}
/// Passes a positive `check_frequency` through unchanged.
///
/// # Errors
///
/// Returns a message naming `cut_selection.check_frequency` when the value is `0`.
fn validate_check_frequency(check_frequency: u32) -> Result<u32, String> {
    match check_frequency {
        0 => Err(String::from("cut_selection.check_frequency must be > 0")),
        positive => Ok(positive),
    }
}
/// Converts the I/O-layer row-selection config into a [`CutSelectionStrategy`].
///
/// Returns `Ok(None)` when no selection method is configured (selection disabled).
///
/// # Errors
///
/// Returns a human-readable message when a parameter is out of range:
/// - `check_frequency` is `0` (`level1`, `lml1`, `domination`; checked first);
/// - `tie_tolerance` / `domination_tolerance` is negative or NaN;
/// - for `dynamic`: `start_iteration` is `0`, `candidate_recency` is `Some(0)`,
///   `max_added_per_round` is `0`, or `violation_tolerance` is not `> 0`
///   (NaN included).
pub fn parse_cut_selection_config(
    config: &cobre_io::config::RowSelectionConfig,
) -> Result<Option<CutSelectionStrategy>, String> {
    use cobre_io::config::SelectionMethod;

    // Selection keeps cuts whose value is `>= max - tolerance`. A negative
    // tolerance puts that cutoff above the maximum, and NaN makes every
    // comparison false, so either would keep no cut at all and deactivate every
    // eligible active cut.
    fn validate_tolerance(name: &str, value: f64) -> Result<f64, String> {
        if value.is_nan() || value < 0.0 {
            return Err(format!("cut_selection.{name} must be a non-negative number"));
        }
        Ok(value)
    }

    let Some(selection) = config.selection.as_ref() else {
        return Ok(None);
    };
    match *selection {
        SelectionMethod::Level1 {
            tie_tolerance,
            check_frequency,
        } => {
            let check_frequency = u64::from(validate_check_frequency(check_frequency)?);
            let tie_tolerance = validate_tolerance("tie_tolerance", tie_tolerance)?;
            Ok(Some(CutSelectionStrategy::Level1 {
                check_frequency,
                tie_tolerance,
            }))
        }
        SelectionMethod::Lml1 {
            tie_tolerance,
            check_frequency,
        } => {
            let check_frequency = u64::from(validate_check_frequency(check_frequency)?);
            let tie_tolerance = validate_tolerance("tie_tolerance", tie_tolerance)?;
            Ok(Some(CutSelectionStrategy::Lml1 {
                check_frequency,
                tie_tolerance,
            }))
        }
        SelectionMethod::Domination {
            domination_tolerance,
            check_frequency,
        } => {
            let check_frequency = u64::from(validate_check_frequency(check_frequency)?);
            let threshold = validate_tolerance("domination_tolerance", domination_tolerance)?;
            Ok(Some(CutSelectionStrategy::Dominated {
                threshold,
                check_frequency,
            }))
        }
        SelectionMethod::Dynamic {
            start_iteration,
            seed_window,
            candidate_recency,
            max_added_per_round,
            violation_tolerance,
        } => {
            if start_iteration == 0 {
                return Err(
                    "cut_selection.start_iteration must be >= 1 for method='dynamic'".to_string(),
                );
            }
            if candidate_recency == Some(0) {
                return Err(
                    "cut_selection.candidate_recency must be >= 1 for method='dynamic' \
                     (omit it for the unbounded default)"
                        .to_string(),
                );
            }
            if max_added_per_round == 0 {
                return Err(
                    "cut_selection.max_added_per_round must be >= 1 for method='dynamic'"
                        .to_string(),
                );
            }
            // `NaN <= 0.0` is false, so NaN must be rejected explicitly to honour
            // the "> 0" contract.
            if violation_tolerance.is_nan() || violation_tolerance <= 0.0 {
                return Err(
                    "cut_selection.violation_tolerance must be > 0 for method='dynamic'"
                        .to_string(),
                );
            }
            Ok(Some(CutSelectionStrategy::Dynamic {
                k1: candidate_recency,
                k2: seed_window,
                nadic: max_added_per_round,
                epsilon_viol: violation_tolerance,
                start_iteration: u64::from(start_iteration),
            }))
        }
    }
}
#[cfg(test)]
mod tests {
use super::parse_cut_selection_config;
use super::{CutActivityUpdates, CutMetadata, CutSelectionStrategy};
use crate::cut::CutPool;
use cobre_io::config::{RowSelectionConfig, SelectionMethod};
fn make_meta(active_count: u64, last_active_iter: u64) -> CutMetadata {
CutMetadata {
iteration_generated: 1,
forward_pass_index: 0,
active_count,
last_active_iter,
}
}
#[allow(clippy::cast_possible_truncation)]
fn make_pool(metadata: &[CutMetadata], active: &[bool]) -> CutPool {
let n = metadata.len();
let mut pool = CutPool::new(n, 1, 1, 0);
for i in 0..n {
pool.add_cut(0, i as u32, 0.0, &[0.0]);
}
pool.metadata[..n].clone_from_slice(metadata);
pool.active[..n].clone_from_slice(active);
pool.cached_active_count = active.iter().filter(|&&a| a).count();
pool
}
#[test]
fn should_run_false_at_zero() {
let s = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
assert!(!s.should_run(0));
}
#[test]
fn should_run_false_between_multiples() {
let s = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
assert!(!s.should_run(3));
assert!(!s.should_run(7));
}
#[test]
fn should_run_true_at_multiples() {
let s = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
assert!(s.should_run(5));
assert!(s.should_run(10));
assert!(s.should_run(15));
}
#[test]
fn should_run_lml1_respects_check_frequency() {
let s = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
assert!(!s.should_run(0));
assert!(!s.should_run(3));
assert!(s.should_run(5));
assert!(s.should_run(10));
}
#[test]
fn should_run_dominated_respects_check_frequency() {
let s = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 10,
};
assert!(!s.should_run(5));
assert!(s.should_run(10));
}
#[test]
fn level1_deactivates_dominated_cuts_at_state() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 0.0,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 5.0, &[0.0]);
pool.add_cut(2, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 10);
let mut deact_idx = deact.deactivation_indices();
deact_idx.sort_unstable();
assert_eq!(deact_idx, vec![0, 2], "cuts 0 and 2 must be deactivated");
assert!(
deact.reactivation_indices().is_empty(),
"no reactivations expected"
);
}
#[test]
fn level1_retains_tied_cuts() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 5.0, &[0.0]);
pool.add_cut(1, 0, 5.0, &[0.0]);
pool.add_cut(2, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 10);
assert_eq!(
deact.deactivation_indices(),
vec![2],
"only cut 2 (value 3) is deactivated; ties 0 and 1 kept"
);
assert!(deact.reactivation_indices().is_empty());
}
#[test]
fn level1_retains_positive_activity_cuts() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 0.0,
};
let mut pool = CutPool::new(2, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[2.0]);
pool.add_cut(1, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[1.0], 10);
assert!(
deact.deactivation_indices().is_empty(),
"no cuts deactivated when all tied at max"
);
}
#[test]
fn level1_threshold_1_deactivates_cuts_with_count_at_most_1() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 0.5,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 3.0, &[0.0]);
pool.add_cut(2, 0, 2.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 10);
let mut deact_idx = deact.deactivation_indices();
deact_idx.sort_unstable();
assert_eq!(deact_idx, vec![0, 2]);
}
#[test]
fn level1_empty_metadata_returns_empty_set() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let pool = CutPool::new(0, 1, 1, 0);
let deact = strategy.select(&pool, &[], 10);
assert!(deact.deactivation_indices().is_empty());
}
#[test]
fn level1_empty_states_returns_empty() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 5.0, &[0.0]);
pool.add_cut(2, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[], 10);
assert!(deact.deactivation_indices().is_empty());
assert!(deact.reactivation_indices().is_empty());
}
#[test]
fn lml1_only_oldest_survives_at_each_state() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 5.0, &[0.0]);
pool.add_cut(1, 0, 5.0, &[0.0]);
pool.add_cut(2, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 10);
let mut deact_idx = deact.deactivation_indices();
deact_idx.sort_unstable();
assert_eq!(
deact_idx,
vec![1, 2],
"cuts 1 and 2 deactivated; only oldest (cut 0) at max survives"
);
assert!(deact.reactivation_indices().is_empty());
}
#[test]
fn lml1_union_across_trial_points() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 2.0, &[0.0]); pool.add_cut(1, 0, 0.0, &[3.0]); pool.add_cut(2, 0, 0.5, &[0.0]); let deact = strategy.select(&pool, &[0.0, 1.0], 10);
assert_eq!(
deact.deactivation_indices(),
vec![2],
"only cut 2 (never at max) deactivated"
);
assert!(deact.reactivation_indices().is_empty());
}
#[test]
fn lml1_deactivates_cuts_outside_memory_window() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let pool = make_pool(&[make_meta(0, 5)], &[true]);
let deact = strategy.select(&pool, &[0.0], 20);
assert!(deact.deactivation_indices().is_empty());
}
#[test]
fn lml1_retains_cuts_within_memory_window() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(2, 1, 1, 0);
pool.add_cut(0, 0, 3.0, &[0.0]);
pool.add_cut(1, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 10);
assert_eq!(
deact.deactivation_indices(),
vec![1],
"cut 1 deactivated; cut 0 (oldest) retained"
);
}
#[test]
fn lml1_retains_cuts_exactly_at_boundary() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 5.0, &[0.0]);
pool.add_cut(1, 0, 4.0, &[0.0]);
pool.add_cut(2, 0, 5.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 10);
let mut deact_idx = deact.deactivation_indices();
deact_idx.sort_unstable();
assert_eq!(
deact_idx,
vec![1, 2],
"cuts 1 and 2 deactivated; cut 0 (oldest at max) retained"
);
}
#[test]
fn lml1_mixed_cuts_deactivates_correct_indices() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 0.0, &[2.0]);
pool.add_cut(2, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0, 2.0], 10);
assert_eq!(
deact.deactivation_indices(),
vec![0],
"only cut 0 (never at max) deactivated"
);
}
#[test]
fn lml1_empty_states_returns_empty() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(2, 1, 1, 0);
pool.add_cut(0, 0, 5.0, &[0.0]);
pool.add_cut(1, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[], 10);
assert!(deact.deactivation_indices().is_empty());
assert!(deact.reactivation_indices().is_empty());
}
#[test]
fn ac_level1_threshold_0_deactivates_zero_activity_cut() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 0.0,
};
let mut pool = CutPool::new(2, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 10);
assert!(deact.deactivation_indices().contains(&0));
}
#[test]
fn ac_lml1_deactivates_cut_outside_memory_window() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(2, 1, 1, 0);
pool.add_cut(0, 0, 5.0, &[0.0]);
pool.add_cut(1, 0, 3.0, &[0.0]);
let deact = strategy.select(&pool, &[0.0], 20);
assert!(deact.deactivation_indices().contains(&1));
}
#[test]
fn select_for_stage_sets_stage_index() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let pool = make_pool(&[make_meta(0, 1)], &[true]);
let deact = strategy.select_for_stage(&pool, &[], 10, 7);
assert_eq!(deact.stage_index, 7);
}
#[test]
fn select_sets_stage_index_to_zero() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance: 1e-10,
};
let pool = CutPool::new(0, 1, 1, 0);
let deact = strategy.select(&pool, &[], 10);
assert_eq!(deact.stage_index, 0);
}
#[test]
fn deactivation_set_derives_debug_and_clone() {
let deact = CutActivityUpdates::deactivations_only(2, vec![0, 3, 7]);
let cloned = deact.clone();
assert_eq!(cloned.stage_index, 2);
assert_eq!(cloned.deactivation_indices(), vec![0, 3, 7]);
assert!(!format!("{deact:?}").is_empty());
}
#[test]
fn cut_metadata_derives_debug_and_clone() {
let meta = make_meta(5, 10);
let cloned = meta.clone();
assert_eq!(cloned.active_count, 5);
assert!(!format!("{meta:?}").is_empty());
}
#[test]
fn level1_spares_cuts_from_current_iteration() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 1.0, &[0.0]);
pool.add_cut(2, 0, 5.0, &[0.0]);
pool.metadata[0].iteration_generated = 10; pool.metadata[1].iteration_generated = 5;
pool.metadata[2].iteration_generated = 5;
let deact = strategy.select(&pool, &[0.0], 10);
assert_eq!(
deact.deactivation_indices(),
vec![1],
"only the older cut (slot 1) should be deactivated; \
the current-iteration cut (slot 0) must be spared"
);
}
#[test]
fn lml1_spares_cuts_from_current_iteration() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let pool = make_pool(
&[CutMetadata {
iteration_generated: 10,
forward_pass_index: 0,
active_count: 0,
last_active_iter: 10,
}],
&[true],
);
let deact = strategy.select(&pool, &[0.0], 10);
assert!(
deact.deactivation_indices().is_empty(),
"current-iteration cut must not be deactivated"
);
}
#[test]
fn lml1_memory_window_boundary_behavior() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(5, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 2.0, &[0.0]);
pool.add_cut(2, 0, 0.0, &[2.0]);
pool.add_cut(3, 0, 3.0, &[0.0]);
pool.add_cut(4, 0, 1.0, &[1.0]);
let deact = strategy.select_for_stage(&pool, &[0.0, 1.0], 10, 0);
let mut deact_idx = deact.deactivation_indices();
deact_idx.sort_unstable();
assert_eq!(
deact_idx,
vec![0, 1, 2, 4],
"only cut 3 (oldest at max at both states) survives"
);
}
#[test]
fn level1_reactivates_inactive_cut_at_max() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 5.0, &[0.0]);
pool.add_cut(1, 0, 3.0, &[0.0]);
pool.add_cut(2, 0, 1.0, &[0.0]);
pool.set_active(0, false);
assert_eq!(pool.active_count(), 2);
let result = strategy.select(&pool, &[0.0], 10);
assert_eq!(
result.reactivation_indices(),
vec![0],
"inactive cut 0 (at max) must be reactivated"
);
let mut deact_idx = result.deactivation_indices();
deact_idx.sort_unstable();
assert_eq!(
deact_idx,
vec![1, 2],
"active cuts 1 and 2 (below max) must be deactivated"
);
}
#[test]
fn lml1_reactivates_inactive_oldest_at_max() {
let strategy = CutSelectionStrategy::Lml1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(2, 1, 1, 0);
pool.add_cut(0, 0, 5.0, &[0.0]);
pool.add_cut(1, 0, 5.0, &[0.0]);
pool.set_active(0, false);
let result = strategy.select(&pool, &[0.0], 10);
assert_eq!(
result.reactivation_indices(),
vec![0],
"inactive cut 0 (oldest at max) must be reactivated"
);
assert_eq!(
result.deactivation_indices(),
vec![1],
"active cut 1 (younger at max) must be deactivated"
);
}
#[test]
fn select_returns_empty_when_all_cuts_from_current_iteration() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]);
pool.add_cut(1, 0, 5.0, &[0.0]);
pool.add_cut(2, 0, 3.0, &[0.0]);
pool.metadata[0].iteration_generated = 10;
pool.metadata[1].iteration_generated = 10;
pool.metadata[2].iteration_generated = 10;
let result = strategy.select(&pool, &[0.0], 10);
assert!(
result.deactivation_indices().is_empty(),
"no activity changes when all cuts from current iteration"
);
assert!(result.reactivation_indices().is_empty());
}
#[allow(clippy::cast_possible_truncation)]
#[test]
fn level1_warm_start_cuts_not_deactivated() {
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 0.0,
};
let n = 3usize;
let mut pool = CutPool::new(n, 1, 1, 1); let intercepts = [10.0f64, 1.0, 3.0];
#[allow(clippy::needless_range_loop)]
for i in 0..n {
pool.intercepts[i] = intercepts[i];
pool.coefficients[i] = 0.0;
pool.active[i] = true;
pool.metadata[i] = CutMetadata {
iteration_generated: 1,
forward_pass_index: i as u32,
active_count: 0,
last_active_iter: 1,
};
}
pool.populated_count = n;
pool.cached_active_count = n;
let result = strategy.select(&pool, &[0.0], 10);
assert!(
!result.deactivation_indices().contains(&0),
"warm-start slot 0 must not be deactivated"
);
let mut deact_idx = result.deactivation_indices();
deact_idx.sort_unstable();
assert_eq!(deact_idx, vec![1, 2]);
}
#[test]
fn select_skips_already_inactive_slots() {
let mut pool = CutPool::new(10, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]); pool.add_cut(1, 0, 5.0, &[0.0]); pool.add_cut(2, 0, 3.0, &[0.0]); assert_eq!(pool.active_count(), 3);
pool.set_active(0, false);
assert_eq!(pool.active_count(), 2);
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 0.0,
};
let deact = strategy.select_for_stage(&pool, &[0.0], 5, 0);
assert_eq!(
deact.deactivation_indices(),
vec![2],
"only slot 2 (active, below max) deactivated"
);
assert!(
deact.reactivation_indices().is_empty(),
"slot 0 (inactive, below max) must not be reactivated"
);
}
#[test]
fn select_for_stage_returns_cut_activity_updates_with_deactivations() {
let mut pool = CutPool::new(3, 1, 1, 0);
pool.add_cut(0, 0, 1.0, &[0.0]); pool.add_cut(1, 0, 2.0, &[0.0]);
let strategy = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 0.0,
};
let result = strategy.select_for_stage(&pool, &[0.0], 10, 0);
assert!(
!result.updates.is_empty(),
"deactivations must be non-empty"
);
assert!(
result.updates.contains(&0),
"slot 0 (below max) must be deactivated"
);
assert!(
!result.updates.contains(&1),
"slot 1 (at max) must not be deactivated"
);
}
#[test]
fn aggressiveness_ordering_level1_leq_lml1_leq_dominated() {
let meta = [
CutMetadata {
iteration_generated: 1,
forward_pass_index: 0,
active_count: 0,
last_active_iter: 1,
},
CutMetadata {
iteration_generated: 1,
forward_pass_index: 1,
active_count: 0,
last_active_iter: 2,
},
CutMetadata {
iteration_generated: 1,
forward_pass_index: 2,
active_count: 3,
last_active_iter: 3,
},
CutMetadata {
iteration_generated: 1,
forward_pass_index: 3,
active_count: 5,
last_active_iter: 10,
},
CutMetadata {
iteration_generated: 1,
forward_pass_index: 4,
active_count: 5,
last_active_iter: 10,
},
];
let pool = make_dominated_pool(
&[0.0, 0.0, 1.0, 0.0, 5.0],
&[vec![0.0], vec![0.1], vec![0.0], vec![2.0], vec![-1.0]],
&[true; 5],
&meta,
);
let states: Vec<f64> = vec![0.0, 1.0, 3.0, 5.0];
let l1 = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let deact_l1 = l1.select(&pool, &states, 11);
let lml1 = CutSelectionStrategy::Lml1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let deact_lml1 = lml1.select(&pool, &states, 11);
let dom = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 1,
};
let deact_dom = dom.select(&pool, &states, 11);
assert!(
deact_l1.deactivation_indices().len() <= deact_lml1.deactivation_indices().len(),
"Level1 ({}) should deactivate <= LML1 ({})",
deact_l1.deactivation_indices().len(),
deact_lml1.deactivation_indices().len()
);
assert!(
deact_lml1.deactivation_indices().len() <= deact_dom.deactivation_indices().len(),
"LML1 ({}) should deactivate <= Dominated ({})",
deact_lml1.deactivation_indices().len(),
deact_dom.deactivation_indices().len()
);
}
#[test]
fn cut_activity_updates_deactivations_only_constructor() {
let updates = CutActivityUpdates::deactivations_only(7, vec![0, 1, 2]);
assert_eq!(updates.stage_index, 7);
assert_eq!(updates.updates.len(), 3);
assert_eq!(updates.updates, vec![0, 1, 2]);
assert!(updates.reactivations.is_empty());
}
#[test]
fn cut_activity_updates_deactivation_indices_returns_updates() {
let updates = CutActivityUpdates {
stage_index: 0,
updates: vec![0, 2],
reactivations: vec![],
};
assert_eq!(updates.deactivation_indices(), vec![0, 2]);
assert!(updates.reactivation_indices().is_empty());
}
#[test]
fn test_parse_disabled_default() {
let cfg = RowSelectionConfig::default();
let result = parse_cut_selection_config(&cfg);
assert!(result.is_ok());
assert!(
result.unwrap().is_none(),
"default config (no selection) must produce None (disabled)"
);
}
#[test]
fn test_parse_level1() {
let cfg = RowSelectionConfig {
selection: Some(SelectionMethod::Level1 {
tie_tolerance: 1e-8,
check_frequency: 5,
}),
..RowSelectionConfig::default()
};
let strategy = parse_cut_selection_config(&cfg)
.expect("level1 must parse")
.expect("must produce Some for level1");
assert!(
matches!(
strategy,
CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance,
} if (tie_tolerance - 1e-8).abs() < f64::EPSILON
),
"unexpected variant: {strategy:?}"
);
}
#[test]
fn test_parse_level1_default_tie_tolerance() {
let json = r#"{"selection": {"method": "level1"}}"#;
let cfg: RowSelectionConfig = serde_json::from_str(json).expect("level1 defaults parse");
let strategy = parse_cut_selection_config(&cfg)
.expect("level1 must parse")
.expect("must produce Some for level1");
assert!(
matches!(
strategy,
CutSelectionStrategy::Level1 {
check_frequency: 5,
tie_tolerance,
} if (tie_tolerance - 1e-10).abs() < 1e-20
),
"unexpected variant or wrong default tie_tolerance: {strategy:?}"
);
}
#[test]
fn test_parse_lml1() {
let cfg = RowSelectionConfig {
selection: Some(SelectionMethod::Lml1 {
tie_tolerance: 1e-8,
check_frequency: 5,
}),
..RowSelectionConfig::default()
};
let strategy = parse_cut_selection_config(&cfg)
.expect("lml1 must parse")
.expect("must produce Some for lml1");
assert!(
matches!(
strategy,
CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance,
} if (tie_tolerance - 1e-8).abs() < f64::EPSILON
),
"unexpected variant: {strategy:?}"
);
}
#[test]
fn test_parse_lml1_defaults() {
let json = r#"{"selection": {"method": "lml1"}}"#;
let cfg: RowSelectionConfig = serde_json::from_str(json).expect("lml1 defaults parse");
let strategy = parse_cut_selection_config(&cfg)
.expect("lml1 must parse")
.expect("must produce Some for lml1");
assert!(
matches!(
strategy,
CutSelectionStrategy::Lml1 {
check_frequency: 5,
tie_tolerance,
} if (tie_tolerance - 1e-10).abs() < 1e-20
),
"unexpected variant or wrong default tie_tolerance: {strategy:?}"
);
}
#[test]
fn test_parse_domination() {
let cfg = RowSelectionConfig {
selection: Some(SelectionMethod::Domination {
domination_tolerance: 1e-6,
check_frequency: 10,
}),
..RowSelectionConfig::default()
};
let strategy = parse_cut_selection_config(&cfg)
.expect("domination must parse")
.expect("must produce Some for domination");
assert!(
matches!(
strategy,
CutSelectionStrategy::Dominated {
threshold,
check_frequency: 10,
} if (threshold - 1e-6).abs() < f64::EPSILON
),
"unexpected variant: {strategy:?}"
);
}
#[test]
fn test_parse_zero_check_frequency() {
let cfg = RowSelectionConfig {
selection: Some(SelectionMethod::Level1 {
tie_tolerance: 1e-10,
check_frequency: 0,
}),
..RowSelectionConfig::default()
};
let result = parse_cut_selection_config(&cfg);
assert!(result.is_err());
let msg = result.unwrap_err();
assert!(
msg.contains("check_frequency"),
"error message must mention check_frequency, got: {msg}"
);
}
#[allow(clippy::cast_possible_truncation)]
fn make_dominated_pool(
intercepts: &[f64],
coefficients: &[Vec<f64>],
active: &[bool],
metadata: &[CutMetadata],
) -> CutPool {
let n = intercepts.len();
let state_dim = coefficients[0].len();
let mut pool = CutPool::new(n, state_dim, 1, 0);
for i in 0..n {
pool.add_cut(0, i as u32, intercepts[i], &coefficients[i]);
pool.metadata[i] = metadata[i].clone();
pool.active[i] = active[i];
}
pool.cached_active_count = active.iter().filter(|&&a| a).count();
pool
}
fn default_meta_at(iter: u64) -> CutMetadata {
CutMetadata {
iteration_generated: iter,
forward_pass_index: 0,
active_count: 0,
last_active_iter: iter,
}
}
fn default_meta_vec(n: usize, iter: u64) -> Vec<CutMetadata> {
(0..n).map(|_| default_meta_at(iter)).collect()
}
#[test]
fn dominated_select_deactivate_dominated() {
let strategy = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 1,
};
let pool = make_dominated_pool(
&[1.0, 0.0, 3.0, 0.5, 0.0],
&[
vec![0.0], vec![2.0], vec![-1.0], vec![0.0], vec![0.5], ],
&[true; 5],
&default_meta_vec(5, 1),
);
let states: Vec<f64> = vec![0.0, 1.0, 3.0];
let deact = strategy.select(&pool, &states, 10);
assert_eq!(deact.deactivation_indices(), vec![0, 3, 4]);
}
#[test]
fn dominated_select_partial_domination_retained() {
let strategy = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 1,
};
let pool = make_dominated_pool(
&[2.0, 0.0],
&[vec![0.0], vec![2.0]],
&[true, true],
&default_meta_vec(2, 1),
);
let states: Vec<f64> = vec![0.0, 1.0, 3.0];
let deact = strategy.select(&pool, &states, 10);
assert!(
deact.deactivation_indices().is_empty(),
"cut 0 achieves max at x=0 and x=1, must not be deactivated"
);
}
#[test]
fn dominated_select_none_dominated_when_all_achieve_max() {
let strategy = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 1,
};
let pool = make_dominated_pool(
&[5.0, 0.0, 2.0],
&[vec![-2.0], vec![3.0], vec![0.0]],
&[true; 3],
&default_meta_vec(3, 1),
);
let states: Vec<f64> = vec![0.0, 1.0, 3.0];
let deact = strategy.select(&pool, &states, 10);
assert_eq!(
deact.deactivation_indices(),
vec![2],
"only cut 2 (constant 2) should be dominated"
);
}
#[test]
fn dominated_select_empty_states() {
let strategy = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 1,
};
let pool = make_dominated_pool(
&[1.0, 2.0],
&[vec![0.0], vec![0.0]],
&[true, true],
&default_meta_vec(2, 1),
);
let deact = strategy.select(&pool, &[], 10);
assert!(
deact.deactivation_indices().is_empty(),
"empty visited_states must produce empty deactivation set"
);
}
#[test]
fn dominated_select_single_active_cut() {
let strategy = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 1,
};
let pool = make_dominated_pool(
&[1.0, 2.0, 3.0],
&[vec![0.0], vec![0.0], vec![0.0]],
&[true, false, false],
&default_meta_vec(3, 1),
);
let states: Vec<f64> = vec![0.0, 1.0];
let deact = strategy.select(&pool, &states, 10);
assert_eq!(
deact.deactivation_indices(),
vec![0],
"active cut 0 (below max) must be deactivated"
);
assert_eq!(
deact.reactivation_indices(),
vec![2],
"inactive cut 2 (at max) must be reactivated"
);
}
#[test]
fn dominated_select_current_iteration_excluded() {
let strategy = CutSelectionStrategy::Dominated {
threshold: 0.0,
check_frequency: 1,
};
let pool = make_dominated_pool(
&[1.0, 5.0],
&[vec![0.0], vec![0.0]],
&[true, true],
&[default_meta_at(10), default_meta_at(1)],
);
let states: Vec<f64> = vec![0.0, 1.0];
let deact = strategy.select(&pool, &states, 10);
assert!(
deact.deactivation_indices().is_empty(),
"cut from current iteration must not be deactivated even if dominated"
);
}
#[test]
fn level1_selected_is_superset_of_lml1_selected() {
let meta = default_meta_vec(4, 1);
let pool = make_dominated_pool(
&[1.0, 3.0, 0.0, 0.0],
&[vec![0.0], vec![0.0], vec![2.0], vec![0.0]],
&[true; 4],
&meta,
);
let states: Vec<f64> = vec![0.0, 2.0];
let l1 = CutSelectionStrategy::Level1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let lml1 = CutSelectionStrategy::Lml1 {
check_frequency: 1,
tie_tolerance: 1e-10,
};
let deact_l1 = l1.select(&pool, &states, 10);
let deact_lml1 = lml1.select(&pool, &states, 10);
for slot in deact_lml1.deactivation_indices() {
assert!(
deact_l1.deactivation_indices().contains(&slot),
"slot {slot} deactivated by Lml1 but not by Level1; \
Level1_selected must be a superset of Lml1_selected"
);
}
assert!(
deact_l1.deactivation_indices().len() <= deact_lml1.deactivation_indices().len(),
"Level1 must deactivate <= Lml1"
);
}
#[test]
fn dominated_epsilon_tolerance_cut_barely_below_max() {
    // Two flat cuts 1e-7 apart, probed at the single state x = 0.
    let pool = make_dominated_pool(
        &[5.0, 4.999_999_9],
        &[vec![0.0], vec![0.0]],
        &[true; 2],
        &default_meta_vec(2, 1),
    );
    let visited = vec![0.0_f64];
    // Runs a Dominated pass with the given threshold against the shared pool.
    let run_with = |threshold: f64| {
        CutSelectionStrategy::Dominated {
            threshold,
            check_frequency: 1,
        }
        .select(&pool, &visited, 10)
    };

    let loose = run_with(1e-6);
    assert!(
        loose.deactivation_indices().is_empty(),
        "cut1 (1e-7 below max) must survive when epsilon=1e-6"
    );

    let strict = run_with(0.0);
    assert_eq!(
        strict.deactivation_indices(),
        vec![1],
        "cut1 (1e-7 below max) must be deactivated when epsilon=0"
    );
}
#[test]
fn level1_single_eligible_cut_returns_empty() {
    let mut pool = CutPool::new(2, 1, 1, 0);
    pool.add_cut(0, 0, 10.0, &[0.0]);
    pool.add_cut(1, 0, 1.0, &[0.0]);
    // Slot 0 is stamped with the current iteration (10) and slot 1 with
    // iteration 5; the test expects only one of them to be eligible.
    pool.metadata[0].iteration_generated = 10;
    pool.metadata[1].iteration_generated = 5;

    let level1 = CutSelectionStrategy::Level1 {
        check_frequency: 1,
        tie_tolerance: 1e-10,
    };
    let outcome = level1.select(&pool, &[0.0], 10);

    assert!(
        outcome.deactivation_indices().is_empty(),
        "single eligible cut must not trigger any deactivations"
    );
    assert!(outcome.reactivation_indices().is_empty());
}
/// Builds a 100-cut, 1-D pool used by the thread-count determinism tests.
///
/// Intercepts cycle with period 7 (`0..=6`) and slopes with period 5
/// (`-2..=2`). The (intercept, slope) pair therefore repeats every 35 cuts,
/// so the pool contains exact duplicates as well as crossing lines. Every cut
/// is stamped as generated in iteration 1.
#[allow(clippy::cast_precision_loss)]
fn make_determinism_pool() -> CutPool {
    const N: usize = 100;
    let mut cuts = CutPool::new(N, 1, 1, 0);
    for slot in 0..N {
        let slope = ((slot + 3) % 5) as f64 - 2.0;
        let intercept = (slot % 7) as f64;
        // N is far below u32::MAX, so this cast cannot truncate.
        #[allow(clippy::cast_possible_truncation)]
        let pass_index = slot as u32;
        cuts.add_cut(0, pass_index, intercept, &[slope]);
        cuts.metadata[slot].iteration_generated = 1;
    }
    cuts
}
/// Returns `count` scalar states `k * 0.01 - 5.0` for `k = 0..count`.
#[allow(clippy::cast_precision_loss)]
fn make_determinism_states(count: usize) -> Vec<f64> {
    const STEP: f64 = 0.01;
    const OFFSET: f64 = 5.0;
    (0..count)
        .map(|k| k as f64)
        .map(|x| x * STEP - OFFSET)
        .collect()
}
/// Runs `strategy.select_for_stage` for stage 0 inside a dedicated rayon
/// thread pool with exactly `num_threads` workers and returns its result.
///
/// # Panics
/// Panics if the rayon thread pool cannot be built.
fn run_in_pool(
    strategy: &CutSelectionStrategy,
    pool: &CutPool,
    states: &[f64],
    current_iteration: u64,
    num_threads: usize,
) -> CutActivityUpdates {
    let mut builder = rayon::ThreadPoolBuilder::new();
    builder = builder.num_threads(num_threads);
    let workers = builder
        .build()
        .expect("rayon pool must build for determinism test");
    let selection = || strategy.select_for_stage(pool, states, current_iteration, 0);
    workers.install(selection)
}
#[test]
fn select_for_stage_deterministic_across_thread_counts_level1() {
    let strategy = CutSelectionStrategy::Level1 {
        check_frequency: 1,
        tie_tolerance: 1e-10,
    };
    let pool = make_determinism_pool();
    let states = make_determinism_states(1024);
    // Same input, evaluated under 1, 4 and 8 worker threads (in that order).
    let [r1, r4, r8] =
        [1, 4, 8].map(|threads| run_in_pool(&strategy, &pool, &states, 10, threads));
    assert_eq!(
        r1, r4,
        "Level1: 1-thread vs 4-thread results must be bit-identical"
    );
    assert_eq!(
        r4, r8,
        "Level1: 4-thread vs 8-thread results must be bit-identical"
    );
}
#[test]
fn select_for_stage_deterministic_across_thread_counts_lml1() {
    let strategy = CutSelectionStrategy::Lml1 {
        check_frequency: 1,
        tie_tolerance: 1e-10,
    };
    let pool = make_determinism_pool();
    let states = make_determinism_states(1024);
    // Same input, evaluated under 1, 4 and 8 worker threads (in that order).
    let [r1, r4, r8] =
        [1, 4, 8].map(|threads| run_in_pool(&strategy, &pool, &states, 10, threads));
    assert_eq!(
        r1, r4,
        "Lml1: 1-thread vs 4-thread results must be bit-identical"
    );
    assert_eq!(
        r4, r8,
        "Lml1: 4-thread vs 8-thread results must be bit-identical"
    );
}
#[test]
fn select_for_stage_deterministic_across_thread_counts_dominated() {
    let strategy = CutSelectionStrategy::Dominated {
        threshold: 0.0,
        check_frequency: 1,
    };
    let pool = make_determinism_pool();
    let states = make_determinism_states(1024);
    // Same input, evaluated under 1, 4 and 8 worker threads (in that order).
    let [r1, r4, r8] =
        [1, 4, 8].map(|threads| run_in_pool(&strategy, &pool, &states, 10, threads));
    assert_eq!(
        r1, r4,
        "Dominated: 1-thread vs 4-thread results must be bit-identical"
    );
    assert_eq!(
        r4, r8,
        "Dominated: 4-thread vs 8-thread results must be bit-identical"
    );
}
#[test]
fn select_for_stage_parallel_matches_sequential_multiple_m_blocks() {
    // 263 is prime, so with M_BLOCK = 8 the states split into 32 full blocks
    // plus a trailing partial block of 7.
    let n_states = 263;
    let pool = make_determinism_pool();
    let states = make_determinism_states(n_states);
    let strategies = [
        CutSelectionStrategy::Level1 {
            check_frequency: 1,
            tie_tolerance: 1e-10,
        },
        CutSelectionStrategy::Lml1 {
            check_frequency: 1,
            tie_tolerance: 1e-10,
        },
        CutSelectionStrategy::Dominated {
            threshold: 0.0,
            check_frequency: 1,
        },
    ];
    for strategy in strategies {
        let (seq, par) = (
            run_in_pool(&strategy, &pool, &states, 10, 1),
            run_in_pool(&strategy, &pool, &states, 10, 4),
        );
        assert_eq!(
            seq, par,
            "strategy {strategy:?}: parallel must equal sequential \
             across multiple m-blocks with partial last block"
        );
    }
}
#[test]
fn select_for_stage_deterministic_across_thread_counts() {
    // Small Lml1 case: 64 states form exactly 8 whole blocks of M_BLOCK (= 8),
    // with no partial trailing block.
    let pool = make_determinism_pool();
    let states = make_determinism_states(64);
    let strategy = CutSelectionStrategy::Lml1 {
        check_frequency: 1,
        tie_tolerance: 1e-10,
    };
    // Use the shared helper rather than hand-building a rayon pool per run.
    let r1 = run_in_pool(&strategy, &pool, &states, 10, 1);
    let r8 = run_in_pool(&strategy, &pool, &states, 10, 8);
    assert_eq!(
        r1, r8,
        "select_for_stage must be byte-identical across thread counts"
    );
}
/// Deserialises a `RowSelectionConfig` from JSON text.
///
/// # Panics
/// Panics if `body` is not a valid serialised `RowSelectionConfig`.
fn dynamic_from_json(body: &str) -> RowSelectionConfig {
    let parsed: Result<RowSelectionConfig, _> = serde_json::from_str(body);
    parsed.expect("dynamic selection block must parse")
}
#[test]
fn parse_dynamic_defaults() {
    // Only the method is specified; every other knob takes its default.
    let cfg = dynamic_from_json(r#"{"selection": {"method": "dynamic"}}"#);
    let parsed = parse_cut_selection_config(&cfg);
    let strategy = parsed
        .expect("dynamic with defaults must parse")
        .expect("must produce Some for dynamic");
    let has_defaults = matches!(
        strategy,
        CutSelectionStrategy::Dynamic {
            k1: None,
            k2: 5,
            nadic: 10,
            epsilon_viol,
            start_iteration: 2,
        } if (epsilon_viol - 1e-10).abs() < 1e-20
    );
    assert!(has_defaults, "unexpected variant or defaults: {strategy:?}");
}
#[test]
fn parse_dynamic_overrides() {
    // Config field -> strategy field: candidate_recency -> k1,
    // seed_window -> k2, max_added_per_round -> nadic,
    // violation_tolerance -> epsilon_viol.
    let method = SelectionMethod::Dynamic {
        start_iteration: 2,
        seed_window: 7,
        candidate_recency: Some(20),
        max_added_per_round: 3,
        violation_tolerance: 1e-9,
    };
    let cfg = RowSelectionConfig {
        selection: Some(method),
        ..RowSelectionConfig::default()
    };
    let parsed = parse_cut_selection_config(&cfg);
    let strategy = parsed
        .expect("dynamic with overrides must parse")
        .expect("must produce Some for dynamic");
    let overridden = matches!(
        strategy,
        CutSelectionStrategy::Dynamic {
            k1: Some(20),
            k2: 7,
            nadic: 3,
            epsilon_viol,
            start_iteration: 2,
        } if (epsilon_viol - 1e-9).abs() < f64::EPSILON
    );
    assert!(overridden, "unexpected variant or overrides: {strategy:?}");
}
#[test]
fn parse_dynamic_rejects_nonpositive_violation_tolerance() {
    // The error text demands a strictly positive tolerance, so both the
    // boundary value 0.0 and a negative value must be rejected.
    for violation_tolerance in [0.0, -1.0] {
        let cfg = RowSelectionConfig {
            selection: Some(SelectionMethod::Dynamic {
                start_iteration: 2,
                seed_window: 5,
                candidate_recency: None,
                max_added_per_round: 10,
                violation_tolerance,
            }),
            ..RowSelectionConfig::default()
        };
        let msg = parse_cut_selection_config(&cfg)
            .expect_err("non-positive violation_tolerance must be rejected for dynamic");
        assert!(
            msg.contains("violation_tolerance must be > 0"),
            "error must mention the violation_tolerance constraint \
             (violation_tolerance = {violation_tolerance}), got: {msg}"
        );
    }
}
#[test]
fn parse_dynamic_rejects_zero_candidate_recency() {
    // A recency window of zero would admit no candidates; it must be refused.
    let method = SelectionMethod::Dynamic {
        start_iteration: 2,
        seed_window: 5,
        candidate_recency: Some(0),
        max_added_per_round: 10,
        violation_tolerance: 1e-10,
    };
    let cfg = RowSelectionConfig {
        selection: Some(method),
        ..RowSelectionConfig::default()
    };
    let err = parse_cut_selection_config(&cfg)
        .expect_err("zero candidate_recency must be rejected for dynamic");
    let names_constraint = err.contains("candidate_recency must be >= 1");
    assert!(
        names_constraint,
        "error must mention the candidate_recency constraint, got: {err}"
    );
}
#[test]
fn parse_dynamic_explicit_fields() {
    // Every Dynamic field is set to a non-default value and must round-trip.
    let method = SelectionMethod::Dynamic {
        start_iteration: 5,
        seed_window: 7,
        candidate_recency: Some(20),
        max_added_per_round: 3,
        violation_tolerance: 1e-9,
    };
    let cfg = RowSelectionConfig {
        selection: Some(method),
        ..RowSelectionConfig::default()
    };
    let parsed = parse_cut_selection_config(&cfg);
    let strategy = parsed
        .expect("dynamic with explicit fields must parse")
        .expect("must produce Some for dynamic");
    let carried_through = matches!(
        strategy,
        CutSelectionStrategy::Dynamic {
            k1: Some(20),
            k2: 7,
            nadic: 3,
            epsilon_viol,
            start_iteration: 5,
        } if (epsilon_viol - 1e-9).abs() < f64::EPSILON
    );
    assert!(
        carried_through,
        "unexpected variant or explicit fields: {strategy:?}"
    );
}
#[test]
fn parse_dynamic_rejects_zero_start_iteration() {
    let method = SelectionMethod::Dynamic {
        start_iteration: 0,
        seed_window: 5,
        candidate_recency: None,
        max_added_per_round: 10,
        violation_tolerance: 1e-10,
    };
    let cfg = RowSelectionConfig {
        selection: Some(method),
        ..RowSelectionConfig::default()
    };
    let err = parse_cut_selection_config(&cfg)
        .expect_err("zero start_iteration must be rejected for dynamic");
    let names_field = err.contains("start_iteration");
    let names_bound = err.contains(">= 1");
    assert!(
        names_field && names_bound,
        "error must mention start_iteration and '>= 1', got: {err}"
    );
}
#[test]
fn parse_dynamic_rejects_zero_max_added_per_round() {
    let method = SelectionMethod::Dynamic {
        start_iteration: 2,
        seed_window: 5,
        candidate_recency: None,
        max_added_per_round: 0,
        violation_tolerance: 1e-10,
    };
    let cfg = RowSelectionConfig {
        selection: Some(method),
        ..RowSelectionConfig::default()
    };
    let err = parse_cut_selection_config(&cfg)
        .expect_err("zero max_added_per_round must be rejected for dynamic");
    let names_field = err.contains("max_added_per_round");
    let names_bound = err.contains(">= 1");
    assert!(
        names_field && names_bound,
        "error must name max_added_per_round and '>= 1', got: {err}"
    );
}
#[test]
fn parse_dynamic_seed_window_sets_k2() {
    // Parses a JSON selection block straight into a strategy, panicking with
    // `parse_msg` if the config is rejected.
    let parse = |json: &str, parse_msg: &str| {
        parse_cut_selection_config(&dynamic_from_json(json))
            .expect(parse_msg)
            .expect("must produce Some for dynamic")
    };

    let zero = parse(
        r#"{"selection": {"method": "dynamic", "seed_window": 0}}"#,
        "seed_window = 0 must parse for dynamic (0 is valid)",
    );
    assert!(
        matches!(zero, CutSelectionStrategy::Dynamic { k2: 0, .. }),
        "seed_window = 0 must set k2 = 0: {zero:?}"
    );

    let absent = parse(
        r#"{"selection": {"method": "dynamic"}}"#,
        "absent seed_window must parse for dynamic",
    );
    assert!(
        matches!(absent, CutSelectionStrategy::Dynamic { k2: 5, .. }),
        "absent seed_window must default k2 to 5: {absent:?}"
    );
}
#[test]
fn dynamic_should_run_always_false() {
    let dynamic = CutSelectionStrategy::Dynamic {
        k1: None,
        k2: 5,
        nadic: 10,
        epsilon_viol: 1e-10,
        start_iteration: 2,
    };
    // Probe iteration 0, early iterations, and a large one.
    let probes: [u64; 5] = [0, 1, 2, 5, 100];
    for &iteration in &probes {
        let would_run = dynamic.should_run(iteration);
        assert!(
            !would_run,
            "Dynamic must never run as a pool-level pass (iteration {iteration})"
        );
    }
}
#[test]
fn dynamic_select_returns_empty() {
    let dynamic = CutSelectionStrategy::Dynamic {
        k1: None,
        k2: 5,
        nadic: 10,
        epsilon_viol: 1e-10,
        start_iteration: 2,
    };
    let mut pool = CutPool::new(3, 1, 1, 0);
    for (iteration, intercept) in [(0, 1.0), (1, 5.0), (2, 3.0)] {
        pool.add_cut(iteration, 0, intercept, &[0.0]);
    }

    // Explicit stage entry point.
    let staged = dynamic.select_for_stage(&pool, &[0.0], 10, 4);
    assert!(
        staged.updates.is_empty(),
        "Dynamic must produce no deactivations"
    );
    assert!(
        staged.reactivations.is_empty(),
        "Dynamic must produce no reactivations"
    );

    // Convenience entry point (stage 0).
    let plain = dynamic.select(&pool, &[0.0], 10);
    assert!(plain.updates.is_empty());
    assert!(plain.reactivations.is_empty());
}
}