use crate::cut::pool::CutPool;
/// Outcome of one angular-pruning pass as produced by `select_angular_dominated`.
#[derive(Clone, Debug, Default)]
pub struct AngularPruningResult {
    /// Pool slot indices selected for deactivation, sorted in ascending order.
    pub deactivate: Vec<u32>,
    /// Number of direction clusters formed, plus one when at least one
    /// eligible cut had a (near-)zero coefficient norm.
    pub clusters_formed: usize,
    /// Sum of `len * (len - 1)` over every cluster with two or more members.
    pub dominance_checks: usize,
}
/// Tolerance handed to the dominance helpers by `select_angular_dominated`:
/// a cut whose value is within this distance of the best value in its
/// cluster is treated as attaining the maximum and is therefore kept.
const DOMINANCE_EPSILON: f64 = 1e-8;
/// Selects active cuts that never attain the cluster maximum at any visited state.
///
/// Eligible cuts are grouped by `cluster_by_angular_similarity`; every group
/// with at least two members is then screened for dominated cuts. Cuts with a
/// (near-)zero coefficient norm form their own group and are compared by
/// intercept only.
///
/// # Parameters
/// - `pool`: cut pool to inspect (read-only).
/// - `visited_states`: flat buffer of state vectors, each `pool.state_dimension`
///   values long.
/// - `cosine_threshold`: similarity threshold forwarded to the clustering step.
/// - `current_iteration`: forwarded to the clustering step, which only considers
///   cuts generated before this iteration.
///
/// # Returns
/// An [`AngularPruningResult`] whose `deactivate` list is sorted ascending.
/// The default (empty) result is returned when the pool has fewer than two
/// active cuts or `visited_states` is empty.
///
/// # Panics
/// In debug builds, panics if `visited_states.len()` is not a multiple of the
/// state dimension.
#[must_use]
// Slot indices are narrowed from `usize` to `u32` with `as`.
// NOTE(review): assumes pool slot indices always fit in `u32` — confirm against `CutPool`.
#[allow(clippy::cast_possible_truncation)]
pub fn select_angular_dominated(
    pool: &CutPool,
    visited_states: &[f64],
    cosine_threshold: f64,
    current_iteration: u64,
) -> AngularPruningResult {
    let state_dim = pool.state_dimension;
    debug_assert!(
        visited_states.len() % state_dim.max(1) == 0,
        "visited_states length {} is not a multiple of state_dimension {}",
        visited_states.len(),
        state_dim
    );

    let nothing_to_compare = pool.active_count() < 2 || visited_states.is_empty();
    if nothing_to_compare {
        return AngularPruningResult::default();
    }

    let (directional, flat) =
        cluster_by_angular_similarity(pool, cosine_threshold, current_iteration);

    let mut result = AngularPruningResult {
        deactivate: Vec::new(),
        // The zero-norm group counts as one extra cluster when it is non-empty.
        clusters_formed: directional.len() + usize::from(!flat.is_empty()),
        dominance_checks: 0,
    };

    // Singleton clusters have nothing to be dominated by.
    for members in directional.iter().filter(|members| members.len() >= 2) {
        result.dominance_checks += members.len() * (members.len() - 1);
        let dominated = dominated_within_cluster(pool, members, visited_states, DOMINANCE_EPSILON);
        result
            .deactivate
            .extend(dominated.into_iter().map(|slot| slot as u32));
    }

    if flat.len() > 1 {
        result.dominance_checks += flat.len() * (flat.len() - 1);
        let dominated =
            dominated_zero_norm_cluster(pool, &flat, visited_states, DOMINANCE_EPSILON);
        result
            .deactivate
            .extend(dominated.into_iter().map(|slot| slot as u32));
    }

    result.deactivate.sort_unstable();
    result
}
/// Groups eligible cuts by the direction of their coefficient vectors.
///
/// A populated slot is eligible when it is active and was generated strictly
/// before `current_iteration`. Each eligible cut is normalised to unit length
/// and joins the first existing cluster whose representative (the unit normal
/// of the cut that founded the cluster) has a cosine strictly greater than
/// `cosine_threshold`; otherwise it founds a new cluster.
///
/// Returns `(clusters, zero_norm_cluster)`: the direction clusters in order of
/// creation, and the eligible slots whose coefficient norm is below `1e-12`
/// (these have no direction and are never placed in a direction cluster).
fn cluster_by_angular_similarity(
    pool: &CutPool,
    cosine_threshold: f64,
    current_iteration: u64,
) -> (Vec<Vec<usize>>, Vec<usize>) {
    let dim = pool.state_dimension;

    let eligible: Vec<usize> = (0..pool.populated_count)
        .filter(|&slot| {
            pool.active[slot] && pool.metadata[slot].iteration_generated < current_iteration
        })
        .collect();

    let mut clusters: Vec<Vec<usize>> = Vec::new();
    let mut representatives: Vec<Vec<f64>> = Vec::new();
    let mut zero_norm_cluster: Vec<usize> = Vec::new();

    for slot in eligible {
        let offset = slot * dim;
        let coeffs = &pool.coefficients[offset..offset + dim];
        let norm = coeffs.iter().map(|&c| c * c).sum::<f64>().sqrt();

        // A vanishing gradient has no direction to compare.
        if norm < 1e-12 {
            zero_norm_cluster.push(slot);
            continue;
        }

        let scale = 1.0 / norm;
        let direction: Vec<f64> = coeffs.iter().map(|&c| c * scale).collect();

        // First matching representative wins.
        let home = representatives.iter().position(|rep| {
            let cosine: f64 = direction.iter().zip(rep).map(|(a, b)| a * b).sum();
            cosine > cosine_threshold
        });

        match home {
            Some(idx) => clusters[idx].push(slot),
            None => {
                representatives.push(direction);
                clusters.push(vec![slot]);
            }
        }
    }

    (clusters, zero_norm_cluster)
}
/// Returns the slots in `cluster` that never come within `epsilon` of the
/// cluster's maximum cut value at any visited state.
///
/// `visited_states` is read as consecutive state vectors of
/// `pool.state_dimension` values each; a trailing partial vector is ignored.
/// For every state, each cut is evaluated as `intercept + coefficients · x`,
/// and any cut whose value is at least `max - epsilon` is cleared from the
/// candidate set. Whatever remains after all states is reported as dominated.
///
/// When there is no complete state to evaluate (the buffer is shorter than one
/// state vector, or the state dimension is zero) nothing can be shown to be
/// dominated, so an empty vector is returned instead of flagging the whole
/// cluster.
fn dominated_within_cluster(
    pool: &CutPool,
    cluster: &[usize],
    visited_states: &[f64],
    epsilon: f64,
) -> Vec<usize> {
    let n_state = pool.state_dimension;

    // Without at least one full state every cut would vacuously stay a
    // candidate; a zero dimension would additionally make `chunks_exact` panic.
    if n_state == 0 || visited_states.len() < n_state {
        return Vec::new();
    }

    let mut is_candidate = vec![true; cluster.len()];
    let mut n_candidates = cluster.len();
    let mut values = vec![0.0_f64; cluster.len()];

    for x_hat in visited_states.chunks_exact(n_state) {
        for (value, &slot) in values.iter_mut().zip(cluster) {
            let start = slot * n_state;
            *value = pool.intercepts[slot]
                + pool.coefficients[start..start + n_state]
                    .iter()
                    .zip(x_hat)
                    .map(|(c, x)| c * x)
                    .sum::<f64>();
        }

        let max_val = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let cutoff = max_val - epsilon;

        for (candidate, &value) in is_candidate.iter_mut().zip(&values) {
            if *candidate && value >= cutoff {
                *candidate = false;
                n_candidates -= 1;
            }
        }

        // Every cut has been (near-)maximal somewhere; no need to look further.
        if n_candidates == 0 {
            break;
        }
    }

    cluster
        .iter()
        .zip(&is_candidate)
        .filter(|&(_, &still_candidate)| still_candidate)
        .map(|(&slot, _)| slot)
        .collect()
}
/// Returns the slots in `cluster` whose intercept is strictly below the
/// largest intercept in the cluster minus `epsilon`.
///
/// Only intercepts are compared; `_visited_states` is accepted but not read.
/// An empty cluster yields an empty result.
fn dominated_zero_norm_cluster(
    pool: &CutPool,
    cluster: &[usize],
    _visited_states: &[f64],
    epsilon: f64,
) -> Vec<usize> {
    let mut best = f64::NEG_INFINITY;
    for &slot in cluster {
        best = best.max(pool.intercepts[slot]);
    }
    let cutoff = best - epsilon;

    let mut dominated = Vec::new();
    for &slot in cluster {
        if pool.intercepts[slot] < cutoff {
            dominated.push(slot);
        }
    }
    dominated
}
/// Validated runtime settings for angular pruning.
#[derive(Clone, Copy, Debug, serde::Deserialize, serde::Serialize)]
pub struct AngularPruningParams {
    /// Cosine-similarity threshold.
    /// NOTE(review): presumably forwarded as the `cosine_threshold` argument of
    /// `select_angular_dominated` — confirm at the call site.
    pub cosine_threshold: f64,
    /// Pruning cadence in iterations; `should_run` fires on positive multiples
    /// of this value.
    pub check_frequency: u64,
}
impl AngularPruningParams {
    /// Reports whether angular pruning is due at `iteration`.
    ///
    /// Returns `true` exactly when `iteration` is a positive multiple of
    /// `check_frequency`. Iteration `0` never triggers a run.
    ///
    /// A `check_frequency` of `0` (possible because the field is public and
    /// the type can be deserialized directly) yields `false` rather than
    /// panicking on a remainder by zero.
    #[must_use]
    pub fn should_run(&self, iteration: u64) -> bool {
        iteration > 0
            && self.check_frequency > 0
            && iteration % self.check_frequency == 0
    }
}
/// Converts the user-facing angular-pruning configuration into validated
/// runtime parameters.
///
/// # Parameters
/// - `config`: raw configuration; every field is optional.
/// - `parent_check_frequency`: fallback cadence used when `config` does not
///   set `check_frequency`.
///
/// # Returns
/// - `Ok(None)` when pruning is disabled (`enabled` is unset or `false`).
/// - `Ok(Some(params))` otherwise, with `cosine_threshold` defaulting to
///   `0.999` and `check_frequency` taken from `config`, then from
///   `parent_check_frequency`, then defaulting to `5`.
///
/// # Errors
/// Returns a message when `cosine_threshold` is NaN or outside `(0.0, 1.0]`,
/// or when the resolved check frequency is `0`.
pub fn parse_angular_pruning_config(
    config: &cobre_io::config::AngularPruningConfig,
    parent_check_frequency: Option<u32>,
) -> Result<Option<AngularPruningParams>, String> {
    if !config.enabled.unwrap_or(false) {
        return Ok(None);
    }

    let cosine_threshold = config.cosine_threshold.unwrap_or(0.999);
    // NaN compares false against every bound, so it must be rejected
    // explicitly or it would slip through the range check.
    if cosine_threshold.is_nan() || cosine_threshold <= 0.0 || cosine_threshold > 1.0 {
        return Err(format!(
            "angular_pruning.cosine_threshold must be in (0.0, 1.0], got {cosine_threshold}"
        ));
    }

    let check_frequency = config
        .check_frequency
        .or(parent_check_frequency)
        .unwrap_or(5);
    if check_frequency == 0 {
        return Err("angular_pruning.check_frequency must be > 0".to_string());
    }

    Ok(Some(AngularPruningParams {
        cosine_threshold,
        check_frequency: u64::from(check_frequency),
    }))
}
#[cfg(test)]
mod tests {
    use super::{
        AngularPruningParams, AngularPruningResult, parse_angular_pruning_config,
        select_angular_dominated,
    };
    use crate::cut::pool::CutPool;
    use cobre_io::config::AngularPruningConfig;

    // ---------------------------------------------------------------------
    // Helpers
    // ---------------------------------------------------------------------

    /// Builds a raw configuration from its three optional fields.
    fn config(
        enabled: Option<bool>,
        cosine_threshold: Option<f64>,
        check_frequency: Option<u32>,
    ) -> AngularPruningConfig {
        AngularPruningConfig {
            enabled,
            cosine_threshold,
            check_frequency,
        }
    }

    /// Builds a configuration with `enabled: Some(true)`.
    fn enabled(cosine_threshold: Option<f64>, check_frequency: Option<u32>) -> AngularPruningConfig {
        config(Some(true), cosine_threshold, check_frequency)
    }

    /// Parses a configuration that is expected to be valid and enabled.
    fn parse_enabled(
        config: &AngularPruningConfig,
        parent_check_frequency: Option<u32>,
    ) -> AngularPruningParams {
        parse_angular_pruning_config(config, parent_check_frequency)
            .unwrap()
            .unwrap()
    }

    /// Asserts that an enabled configuration with this threshold is rejected.
    fn assert_threshold_rejected(threshold: f64) {
        let err = parse_angular_pruning_config(&enabled(Some(threshold), None), None).unwrap_err();
        assert!(
            err.contains("cosine_threshold must be in (0.0, 1.0]"),
            "unexpected error: {err}"
        );
    }

    /// Parameters with threshold 0.999 that fire every fifth iteration.
    fn every_five() -> AngularPruningParams {
        AngularPruningParams {
            cosine_threshold: 0.999,
            check_frequency: 5,
        }
    }

    /// Builds a pool and inserts `cuts` as `(intercept, coefficients)` pairs,
    /// all at iteration 1 with consecutive forward-pass indices starting at 0.
    ///
    /// NOTE(review): the slot numbers asserted below (32, 33, 34) presumably
    /// follow from the `32` passed to `CutPool::new` plus the forward-pass
    /// index — confirm against `CutPool`.
    fn pool_with_cuts(n_state: usize, cuts: &[(f64, Vec<f64>)]) -> CutPool {
        let mut pool = CutPool::new(64, n_state, 32, 0);
        for (fp_idx, (intercept, coefficients)) in (0_u32..).zip(cuts) {
            pool.add_cut(1, fp_idx, *intercept, coefficients.as_slice());
        }
        pool
    }

    /// Runs the selection with the cosine threshold used throughout (0.999).
    fn prune(pool: &CutPool, visited: &[f64], current_iteration: u64) -> AngularPruningResult {
        select_angular_dominated(pool, visited, 0.999, current_iteration)
    }

    // ---------------------------------------------------------------------
    // Configuration parsing
    // ---------------------------------------------------------------------

    #[test]
    fn parse_disabled_returns_none() {
        let result = parse_angular_pruning_config(&config(None, None, None), None).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn parse_enabled_false_returns_none() {
        let result = parse_angular_pruning_config(&config(Some(false), None, None), None).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn parse_enabled_default_threshold() {
        let params = parse_enabled(&enabled(None, None), None);
        assert!(
            (params.cosine_threshold - 0.999).abs() < f64::EPSILON,
            "expected default threshold 0.999, got {}",
            params.cosine_threshold
        );
    }

    #[test]
    fn parse_enabled_explicit_threshold() {
        let params = parse_enabled(&enabled(Some(0.99), None), None);
        assert!(
            (params.cosine_threshold - 0.99).abs() < f64::EPSILON,
            "expected 0.99, got {}",
            params.cosine_threshold
        );
    }

    #[test]
    fn parse_enabled_inherits_parent_frequency() {
        let params = parse_enabled(&enabled(None, None), Some(3));
        assert_eq!(params.check_frequency, 3);
    }

    #[test]
    fn parse_enabled_explicit_frequency_overrides_parent() {
        let params = parse_enabled(&enabled(None, Some(2)), Some(3));
        assert_eq!(params.check_frequency, 2);
    }

    #[test]
    fn parse_enabled_no_parent_frequency_defaults_to_five() {
        let params = parse_enabled(&enabled(None, None), None);
        assert_eq!(params.check_frequency, 5);
    }

    #[test]
    fn parse_invalid_threshold_zero() {
        assert_threshold_rejected(0.0);
    }

    #[test]
    fn parse_invalid_threshold_negative() {
        assert_threshold_rejected(-0.5);
    }

    #[test]
    fn parse_invalid_threshold_above_one() {
        assert_threshold_rejected(1.1);
    }

    #[test]
    fn parse_invalid_frequency_zero() {
        let err = parse_angular_pruning_config(&enabled(None, Some(0)), None).unwrap_err();
        assert!(
            err.contains("check_frequency must be > 0"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn parse_threshold_exactly_one_is_valid() {
        let params = parse_enabled(&enabled(Some(1.0), None), None);
        assert!(
            (params.cosine_threshold - 1.0).abs() < f64::EPSILON,
            "threshold 1.0 should be accepted"
        );
    }

    // ---------------------------------------------------------------------
    // Scheduling
    // ---------------------------------------------------------------------

    #[test]
    fn should_run_false_at_zero() {
        assert!(!every_five().should_run(0));
    }

    #[test]
    fn should_run_true_at_multiples() {
        let params = every_five();
        for iteration in [5, 10, 15] {
            assert!(params.should_run(iteration));
        }
    }

    #[test]
    fn should_run_false_between_multiples() {
        let params = every_five();
        for iteration in [1, 2, 3, 4, 6] {
            assert!(!params.should_run(iteration));
        }
    }

    // ---------------------------------------------------------------------
    // Dominance selection
    // ---------------------------------------------------------------------

    #[test]
    fn h2_preservation_crossing_cuts_both_survive() {
        let pool = pool_with_cuts(1, &[(0.0, vec![2.0]), (1.0, vec![0.5])]);
        let result = prune(&pool, &[0.0, 1.0], 2);
        assert!(
            result.deactivate.is_empty(),
            "neither cut should be deactivated; got {:?}",
            result.deactivate
        );
    }

    #[test]
    fn dominated_cut_deactivated() {
        let pool = pool_with_cuts(1, &[(5.0, vec![1.0]), (2.0, vec![1.0])]);
        let result = prune(&pool, &[0.0, 10.0, 100.0], 2);
        assert_eq!(
            result.deactivate,
            vec![33],
            "B (slot 33) should be deactivated"
        );
    }

    // Identical cuts tie at every state, so neither one is flagged.
    #[test]
    fn equal_valued_cuts_one_survives() {
        let pool = pool_with_cuts(1, &[(5.0, vec![1.0]), (5.0, vec![1.0])]);
        let result = prune(&pool, &[0.0, 10.0], 2);
        assert!(
            result.deactivate.is_empty(),
            "equal-valued cuts must not be deactivated: got {:?}",
            result.deactivate
        );
    }

    #[test]
    fn equal_intercept_zero_norm_both_survive() {
        let pool = pool_with_cuts(1, &[(5.0, vec![0.0]), (5.0, vec![0.0])]);
        let result = prune(&pool, &[0.0], 2);
        assert!(
            result.deactivate.is_empty(),
            "equal-intercept zero-norm cuts must not be deactivated"
        );
    }

    #[test]
    fn orthogonal_cuts_no_deactivation() {
        let pool = pool_with_cuts(2, &[(0.0, vec![1.0, 0.0]), (0.0, vec![0.0, 1.0])]);
        let result = prune(&pool, &[1.0, 0.0, 0.0, 1.0], 2);
        assert!(
            result.deactivate.is_empty(),
            "orthogonal cuts must not be deactivated; got {:?}",
            result.deactivate
        );
    }

    #[test]
    fn zero_norm_dominated() {
        let pool = pool_with_cuts(2, &[(5.0, vec![0.0, 0.0]), (10.0, vec![0.0, 0.0])]);
        let result = prune(&pool, &[1.0, 2.0, 3.0, 4.0], 2);
        assert_eq!(
            result.deactivate,
            vec![32],
            "cut with intercept 5.0 (slot 32) should be deactivated"
        );
    }

    // Cuts are added at iteration 1, so running at iteration 1 must skip them.
    #[test]
    fn current_iteration_excluded() {
        let pool = pool_with_cuts(1, &[(5.0, vec![1.0]), (2.0, vec![1.0])]);
        let result = prune(&pool, &[0.0, 1.0], 1);
        assert!(
            result.deactivate.is_empty(),
            "cuts from current iteration must not be deactivated; got {:?}",
            result.deactivate
        );
    }

    #[test]
    fn empty_visited_states_returns_empty() {
        let pool = pool_with_cuts(1, &[(5.0, vec![1.0]), (2.0, vec![1.0])]);
        let result = prune(&pool, &[], 2);
        assert!(
            result.deactivate.is_empty(),
            "empty visited states must yield empty result"
        );
    }

    #[test]
    fn single_active_cut_returns_empty() {
        let pool = pool_with_cuts(1, &[(5.0, vec![1.0])]);
        let result = prune(&pool, &[0.0, 1.0, 2.0], 2);
        assert!(
            result.deactivate.is_empty(),
            "single cut cannot be dominated"
        );
    }

    #[test]
    fn three_cut_cluster_partial_dominance() {
        let pool = pool_with_cuts(
            1,
            &[(10.0, vec![1.0]), (7.0, vec![1.0]), (3.0, vec![1.0])],
        );
        let result = prune(&pool, &[0.0, 5.0, 10.0], 2);
        let mut got = result.deactivate.clone();
        got.sort_unstable();
        assert_eq!(got, vec![33, 34], "B and C should be deactivated");
    }

    #[test]
    fn determinism() {
        let pool = pool_with_cuts(
            2,
            &[
                (10.0, vec![1.0, 0.0]),
                (7.0, vec![1.0, 0.0]),
                (0.0, vec![0.0, 1.0]),
            ],
        );
        let visited = [1.0, 0.0, 0.0, 1.0, 2.0, 3.0];
        let r1 = prune(&pool, &visited, 2);
        let r2 = prune(&pool, &visited, 2);
        assert_eq!(
            r1.deactivate, r2.deactivate,
            "results must be deterministic"
        );
        assert_eq!(r1.clusters_formed, r2.clusters_formed);
        assert_eq!(r1.dominance_checks, r2.dominance_checks);
    }

    #[test]
    fn high_dimensional_correctness() {
        let coeffs = vec![1.0_f64; 10];
        let pool = pool_with_cuts(10, &[(100.0, coeffs.clone()), (50.0, coeffs)]);
        let visited: Vec<f64> = vec![
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
            0.5, 0.5, 0.5, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0,
        ];
        let result = prune(&pool, &visited, 2);
        assert_eq!(
            result.deactivate,
            vec![33],
            "B should be deactivated in 10-d case"
        );
    }
}