/// Per-cut bookkeeping read and updated by [`CutSelectionStrategy`].
///
/// Each strategy touches only the counter it needs: `Level1` uses
/// `active_count`, `Lml1` uses `last_active_iter`, and `Dominated` resets
/// `domination_count` whenever the cut is reported binding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CutMetadata {
    /// Iteration in which the cut was generated (by its name; not read by the
    /// selection logic in this module).
    pub iteration_generated: u64,
    /// Forward pass that produced the cut (by its name; not read by the
    /// selection logic in this module).
    pub forward_pass_index: u32,
    /// Incremented by `update_activity` under `Level1` each time the cut is
    /// binding; `Level1` deactivates cuts whose count is `<= threshold`.
    pub active_count: u64,
    /// Set to the current iteration by `update_activity` under `Lml1` when the
    /// cut is binding; `Lml1` deactivates cuts idle for longer than its window.
    pub last_active_iter: u64,
    /// Reset to `0` by `update_activity` under `Dominated` when the cut is
    /// binding.
    pub domination_count: u64,
}
/// Cuts chosen for deactivation by a single selection pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeactivationSet {
    /// Stage the selection was run for (`0` when produced by
    /// `CutSelectionStrategy::select`).
    pub stage_index: u32,
    /// Positions of the selected cuts within the metadata slice given to the
    /// selection call; the selection methods emit them in ascending order.
    pub indices: Vec<u32>,
}
/// Policy used to decide which cuts to deactivate.
///
/// Every variant carries a `check_frequency`: selection runs on positive
/// multiples of it. It should be non-zero (`parse_cut_selection_config`
/// rejects `0`).
#[derive(Debug, Clone, PartialEq)]
pub enum CutSelectionStrategy {
    /// Deactivates cuts whose `active_count` is `<= threshold`.
    Level1 {
        /// Maximum activity count at which a cut is still deactivated.
        threshold: u64,
        /// Run selection every `check_frequency` iterations.
        check_frequency: u64,
    },
    /// Deactivates cuts whose last binding iteration is more than
    /// `memory_window` iterations before the current one.
    Lml1 {
        /// Number of idle iterations tolerated before deactivation.
        memory_window: u64,
        /// Run selection every `check_frequency` iterations.
        check_frequency: u64,
    },
    /// Domination-based selection. Selection currently deactivates nothing;
    /// `threshold` is carried but not yet read.
    Dominated {
        /// Tolerance reserved for the domination test (currently unused).
        threshold: f64,
        /// Run selection every `check_frequency` iterations.
        check_frequency: u64,
    },
}
impl CutSelectionStrategy {
    /// Returns `true` when cut selection should run at `iteration`.
    ///
    /// Selection runs on every positive multiple of the variant's
    /// `check_frequency`. Iteration `0` never triggers a run. A
    /// `check_frequency` of `0` disables selection instead of panicking on a
    /// division by zero.
    #[must_use]
    pub fn should_run(&self, iteration: u64) -> bool {
        let freq = match self {
            Self::Level1 {
                check_frequency, ..
            }
            | Self::Lml1 {
                check_frequency, ..
            }
            | Self::Dominated {
                check_frequency, ..
            } => *check_frequency,
        };
        // `checked_rem` yields `None` for `freq == 0`, where `%` would panic.
        iteration > 0 && iteration.checked_rem(freq) == Some(0)
    }

    /// Selects cuts to deactivate, tagging the result with stage `0`.
    ///
    /// Equivalent to `self.select_for_stage(metadata, current_iteration, 0)`.
    ///
    /// # Panics
    ///
    /// Panics if a selected cut's position does not fit in `u32`.
    #[must_use]
    pub fn select(&self, metadata: &[CutMetadata], current_iteration: u64) -> DeactivationSet {
        self.select_for_stage(metadata, current_iteration, 0)
    }

    /// Selects cuts to deactivate for the given stage.
    ///
    /// - `Level1`: cuts with `active_count <= threshold`.
    /// - `Lml1`: cuts idle for more than `memory_window` iterations. The idle
    ///   time saturates at `0` when `last_active_iter > current_iteration`.
    /// - `Dominated`: nothing is selected yet.
    ///
    /// Returned indices are positions in `metadata`, in ascending order.
    ///
    /// # Panics
    ///
    /// Panics if a selected cut's position does not fit in `u32`. A silent
    /// truncation would deactivate the wrong cut.
    #[must_use]
    pub fn select_for_stage(
        &self,
        metadata: &[CutMetadata],
        current_iteration: u64,
        stage_index: u32,
    ) -> DeactivationSet {
        let indices = match self {
            Self::Level1 { threshold, .. } => {
                Self::collect_indices(metadata, |m| m.active_count <= *threshold)
            }
            Self::Lml1 { memory_window, .. } => Self::collect_indices(metadata, |m| {
                current_iteration.saturating_sub(m.last_active_iter) > *memory_window
            }),
            // Domination-based selection is not implemented: keep every cut.
            Self::Dominated { .. } => Vec::new(),
        };
        DeactivationSet {
            stage_index,
            indices,
        }
    }

    /// Collects the positions of the entries in `metadata` that satisfy `pred`.
    ///
    /// # Panics
    ///
    /// Panics if a matching position does not fit in `u32`.
    fn collect_indices(
        metadata: &[CutMetadata],
        pred: impl Fn(&CutMetadata) -> bool,
    ) -> Vec<u32> {
        metadata
            .iter()
            .enumerate()
            .filter(|&(_, m)| pred(m))
            .map(|(i, _)| u32::try_from(i).expect("cut index exceeds u32::MAX"))
            .collect()
    }

    /// Records that a cut was (or was not) binding at `current_iteration`.
    ///
    /// Non-binding observations leave `metadata` untouched. For a binding cut:
    /// - `Level1` increments `active_count`, saturating at `u64::MAX` instead
    ///   of overflowing.
    /// - `Lml1` sets `last_active_iter` to `current_iteration`.
    /// - `Dominated` resets `domination_count` to `0`.
    pub fn update_activity(
        &self,
        metadata: &mut CutMetadata,
        is_binding: bool,
        current_iteration: u64,
    ) {
        if !is_binding {
            return;
        }
        match self {
            Self::Level1 { .. } => {
                metadata.active_count = metadata.active_count.saturating_add(1);
            }
            Self::Lml1 { .. } => {
                metadata.last_active_iter = current_iteration;
            }
            Self::Dominated { .. } => {
                metadata.domination_count = 0;
            }
        }
    }
}
/// Turns the user-facing cut-selection configuration into a strategy.
///
/// Returns `Ok(None)` when `enabled` is absent or `false`; the other fields
/// are then ignored. When enabled, `threshold` defaults to `0` and
/// `check_frequency` defaults to `5`. `threshold` becomes the `Level1`
/// threshold, the `Lml1` memory window, or the `Dominated` tolerance,
/// depending on `method`.
///
/// # Errors
///
/// Returns a descriptive message when `enabled` is `true` but `method` is
/// missing, when `check_frequency` is `0`, or when `method` is not one of
/// `"level1"`, `"lml1"`, or `"domination"`.
pub fn parse_cut_selection_config(
    config: &cobre_io::config::CutSelectionConfig,
) -> Result<Option<CutSelectionStrategy>, String> {
    if !config.enabled.unwrap_or(false) {
        return Ok(None);
    }

    let method = match config.method.as_deref() {
        Some(name) => name,
        None => {
            return Err("cut_selection.enabled is true but method is not specified".to_string())
        }
    };

    let raw_threshold = config.threshold.unwrap_or(0);
    let raw_frequency = config.check_frequency.unwrap_or(5);
    if raw_frequency == 0 {
        return Err("cut_selection.check_frequency must be > 0".to_string());
    }
    let check_frequency = u64::from(raw_frequency);

    let strategy = match method {
        "level1" => CutSelectionStrategy::Level1 {
            threshold: u64::from(raw_threshold),
            check_frequency,
        },
        "lml1" => CutSelectionStrategy::Lml1 {
            memory_window: u64::from(raw_threshold),
            check_frequency,
        },
        "domination" => CutSelectionStrategy::Dominated {
            threshold: f64::from(raw_threshold),
            check_frequency,
        },
        other => return Err(format!("unknown cut_selection.method: \"{other}\"")),
    };
    Ok(Some(strategy))
}
/// Unit tests for cut selection strategies and config parsing.
#[cfg(test)]
mod tests {
    use super::{parse_cut_selection_config, CutMetadata, CutSelectionStrategy, DeactivationSet};
    use cobre_io::config::CutSelectionConfig;

    /// Builds metadata with fixed generation info and the given counters.
    fn make_meta(active_count: u64, last_active_iter: u64, domination_count: u64) -> CutMetadata {
        CutMetadata {
            iteration_generated: 1,
            forward_pass_index: 0,
            active_count,
            last_active_iter,
            domination_count,
        }
    }

    #[test]
    fn should_run_false_at_zero() {
        let s = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        assert!(!s.should_run(0));
    }

    #[test]
    fn should_run_false_between_multiples() {
        let s = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        assert!(!s.should_run(3));
        assert!(!s.should_run(7));
    }

    #[test]
    fn should_run_true_at_multiples() {
        let s = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        assert!(s.should_run(5));
        assert!(s.should_run(10));
        assert!(s.should_run(15));
    }

    #[test]
    fn should_run_every_positive_iteration_with_frequency_one() {
        let s = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 1,
        };
        assert!(!s.should_run(0));
        assert!(s.should_run(1));
        assert!(s.should_run(2));
        assert!(s.should_run(3));
    }

    #[test]
    fn should_run_lml1_respects_check_frequency() {
        let s = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        assert!(!s.should_run(0));
        assert!(!s.should_run(3));
        assert!(s.should_run(5));
        assert!(s.should_run(10));
    }

    #[test]
    fn should_run_dominated_respects_check_frequency() {
        let s = CutSelectionStrategy::Dominated {
            threshold: 0.0,
            check_frequency: 10,
        };
        assert!(!s.should_run(5));
        assert!(s.should_run(10));
    }

    #[test]
    fn level1_deactivates_zero_activity_cuts() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(0, 1, 0), make_meta(1, 5, 0)];
        let deact = strategy.select(&metadata, 10);
        assert_eq!(
            deact.indices,
            vec![0],
            "only the inactive cut is deactivated"
        );
    }

    #[test]
    fn level1_retains_positive_activity_cuts() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(3, 1, 0), make_meta(7, 5, 0)];
        let deact = strategy.select(&metadata, 10);
        assert!(
            deact.indices.is_empty(),
            "no cuts should be deactivated when all have activity"
        );
    }

    #[test]
    fn level1_threshold_1_deactivates_cuts_with_count_at_most_1() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 1,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(0, 1, 0), make_meta(1, 5, 0), make_meta(2, 8, 0)];
        let deact = strategy.select(&metadata, 10);
        assert_eq!(deact.indices, vec![0, 1]);
    }

    #[test]
    fn level1_empty_metadata_returns_empty_set() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let deact = strategy.select(&[], 10);
        assert!(deact.indices.is_empty());
    }

    #[test]
    fn lml1_deactivates_cuts_outside_memory_window() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(0, 5, 0)];
        let deact = strategy.select(&metadata, 20);
        assert_eq!(deact.indices, vec![0]);
    }

    #[test]
    fn lml1_retains_cuts_within_memory_window() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(0, 12, 0)];
        let deact = strategy.select(&metadata, 20);
        assert!(deact.indices.is_empty());
    }

    #[test]
    fn lml1_retains_cuts_exactly_at_boundary() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(0, 10, 0)];
        let deact = strategy.select(&metadata, 20);
        assert!(
            deact.indices.is_empty(),
            "boundary case: exactly at window edge, retained"
        );
    }

    #[test]
    fn lml1_retains_cut_whose_last_activity_is_after_current_iteration() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 0,
            check_frequency: 5,
        };
        // The idle time saturates at 0 instead of underflowing, so the cut is kept.
        let metadata = vec![make_meta(0, 25, 0)];
        let deact = strategy.select(&metadata, 20);
        assert!(deact.indices.is_empty());
    }

    #[test]
    fn lml1_mixed_cuts_deactivates_correct_indices() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(0, 5, 0), make_meta(0, 12, 0), make_meta(0, 1, 0)];
        let deact = strategy.select(&metadata, 20);
        assert_eq!(deact.indices, vec![0, 2]);
    }

    #[test]
    fn dominated_select_always_returns_empty_set() {
        let strategy = CutSelectionStrategy::Dominated {
            threshold: 0.001,
            check_frequency: 10,
        };
        let metadata = vec![make_meta(0, 1, 5), make_meta(0, 1, 10)];
        let deact = strategy.select(&metadata, 20);
        assert!(deact.indices.is_empty());
    }

    #[test]
    fn level1_update_activity_increments_active_count_when_binding() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let mut meta = make_meta(0, 1, 0);
        strategy.update_activity(&mut meta, true, 5);
        assert_eq!(meta.active_count, 1);
        strategy.update_activity(&mut meta, true, 6);
        assert_eq!(meta.active_count, 2);
    }

    #[test]
    fn level1_update_activity_does_nothing_when_not_binding() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let mut meta = make_meta(3, 1, 0);
        strategy.update_activity(&mut meta, false, 5);
        assert_eq!(meta.active_count, 3, "must not modify when not binding");
    }

    #[test]
    fn lml1_update_activity_sets_last_active_iter_when_binding() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        let mut meta = make_meta(0, 1, 0);
        strategy.update_activity(&mut meta, true, 15);
        assert_eq!(meta.last_active_iter, 15);
    }

    #[test]
    fn lml1_update_activity_does_nothing_when_not_binding() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        let mut meta = make_meta(0, 7, 0);
        strategy.update_activity(&mut meta, false, 15);
        assert_eq!(meta.last_active_iter, 7, "must not modify when not binding");
    }

    #[test]
    fn dominated_update_activity_resets_domination_count_when_binding() {
        let strategy = CutSelectionStrategy::Dominated {
            threshold: 0.001,
            check_frequency: 10,
        };
        let mut meta = make_meta(0, 1, 42);
        strategy.update_activity(&mut meta, true, 10);
        assert_eq!(
            meta.domination_count, 0,
            "domination_count must be reset when cut is binding"
        );
    }

    #[test]
    fn dominated_update_activity_does_nothing_when_not_binding() {
        let strategy = CutSelectionStrategy::Dominated {
            threshold: 0.001,
            check_frequency: 10,
        };
        let mut meta = make_meta(0, 1, 42);
        strategy.update_activity(&mut meta, false, 10);
        assert_eq!(
            meta.domination_count, 42,
            "must not modify when not binding"
        );
    }

    #[test]
    fn ac_level1_threshold_0_deactivates_zero_activity_cut() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let metadata = vec![CutMetadata {
            iteration_generated: 1,
            forward_pass_index: 0,
            active_count: 0,
            last_active_iter: 1,
            domination_count: 0,
        }];
        let deact = strategy.select(&metadata, 10);
        assert!(deact.indices.contains(&0));
    }

    #[test]
    fn ac_lml1_deactivates_cut_outside_memory_window() {
        let strategy = CutSelectionStrategy::Lml1 {
            memory_window: 10,
            check_frequency: 5,
        };
        let metadata = vec![CutMetadata {
            iteration_generated: 1,
            forward_pass_index: 0,
            active_count: 0,
            last_active_iter: 5,
            domination_count: 0,
        }];
        let deact = strategy.select(&metadata, 20);
        assert!(deact.indices.contains(&0));
    }

    #[test]
    fn select_for_stage_sets_stage_index() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let metadata = vec![make_meta(0, 1, 0)];
        let deact = strategy.select_for_stage(&metadata, 10, 7);
        assert_eq!(deact.stage_index, 7);
    }

    #[test]
    fn select_sets_stage_index_to_zero() {
        let strategy = CutSelectionStrategy::Level1 {
            threshold: 0,
            check_frequency: 5,
        };
        let metadata: Vec<CutMetadata> = vec![];
        let deact = strategy.select(&metadata, 10);
        assert_eq!(deact.stage_index, 0);
    }

    #[test]
    fn deactivation_set_derives_debug_and_clone() {
        let deact = DeactivationSet {
            stage_index: 2,
            indices: vec![0, 3, 7],
        };
        let cloned = deact.clone();
        assert_eq!(cloned.stage_index, 2);
        assert_eq!(cloned.indices, vec![0, 3, 7]);
        assert!(!format!("{deact:?}").is_empty());
    }

    #[test]
    fn cut_metadata_derives_debug_and_clone() {
        let meta = make_meta(5, 10, 2);
        let cloned = meta.clone();
        assert_eq!(cloned.active_count, 5);
        assert!(!format!("{meta:?}").is_empty());
    }

    #[test]
    fn test_parse_disabled_default() {
        let cfg = CutSelectionConfig::default();
        let result = parse_cut_selection_config(&cfg);
        assert!(result.is_ok());
        assert!(
            result.unwrap().is_none(),
            "default config must produce None (disabled)"
        );
    }

    #[test]
    fn test_parse_level1() {
        let cfg = CutSelectionConfig {
            enabled: Some(true),
            method: Some("level1".to_string()),
            threshold: Some(0),
            check_frequency: Some(5),
            cut_activity_tolerance: None,
        };
        let result = parse_cut_selection_config(&cfg);
        assert!(result.is_ok());
        let strategy = result
            .unwrap()
            .expect("must produce Some for enabled level1");
        assert!(
            matches!(
                strategy,
                CutSelectionStrategy::Level1 {
                    threshold: 0,
                    check_frequency: 5,
                }
            ),
            "unexpected variant: {strategy:?}"
        );
    }

    #[test]
    fn test_parse_lml1() {
        let cfg = CutSelectionConfig {
            enabled: Some(true),
            method: Some("lml1".to_string()),
            threshold: None,
            check_frequency: None,
            cut_activity_tolerance: None,
        };
        let result = parse_cut_selection_config(&cfg);
        assert!(result.is_ok());
        let strategy = result.unwrap().expect("must produce Some for enabled lml1");
        assert!(
            matches!(
                strategy,
                CutSelectionStrategy::Lml1 {
                    memory_window: 0,
                    check_frequency: 5,
                }
            ),
            "unexpected variant: {strategy:?}"
        );
    }

    #[test]
    fn test_parse_domination() {
        let cfg = CutSelectionConfig {
            enabled: Some(true),
            method: Some("domination".to_string()),
            threshold: Some(0),
            check_frequency: Some(10),
            cut_activity_tolerance: None,
        };
        let result = parse_cut_selection_config(&cfg);
        assert!(result.is_ok());
        let strategy = result
            .unwrap()
            .expect("must produce Some for enabled domination");
        assert!(
            matches!(
                strategy,
                CutSelectionStrategy::Dominated {
                    threshold,
                    check_frequency: 10,
                } if threshold == 0.0
            ),
            "unexpected variant: {strategy:?}"
        );
    }

    #[test]
    fn test_parse_unknown_method() {
        let cfg = CutSelectionConfig {
            enabled: Some(true),
            method: Some("bogus".to_string()),
            threshold: None,
            check_frequency: None,
            cut_activity_tolerance: None,
        };
        let result = parse_cut_selection_config(&cfg);
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("bogus"),
            "error message must contain the unrecognized method name, got: {msg}"
        );
    }

    #[test]
    fn test_parse_enabled_without_method() {
        let cfg = CutSelectionConfig {
            enabled: Some(true),
            method: None,
            threshold: None,
            check_frequency: None,
            cut_activity_tolerance: None,
        };
        let result = parse_cut_selection_config(&cfg);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_enabled_false_with_method_returns_none() {
        let cfg = CutSelectionConfig {
            enabled: Some(false),
            method: Some("level1".to_string()),
            threshold: None,
            check_frequency: None,
            cut_activity_tolerance: None,
        };
        let result = parse_cut_selection_config(&cfg).unwrap();
        assert!(
            result.is_none(),
            "enabled=false must return None even when method is set"
        );
    }

    #[test]
    fn test_parse_zero_check_frequency() {
        let cfg = CutSelectionConfig {
            enabled: Some(true),
            method: Some("level1".to_string()),
            threshold: None,
            check_frequency: Some(0),
            cut_activity_tolerance: None,
        };
        let result = parse_cut_selection_config(&cfg);
        assert!(result.is_err());
        let msg = result.unwrap_err();
        assert!(
            msg.contains("check_frequency"),
            "error message must mention check_frequency, got: {msg}"
        );
    }
}