use crate::traits::Repair;
/// Repair operator that clips each decision coordinate into its own `(lo, hi)` interval.
#[derive(Debug, Clone)]
pub struct ClampToBounds {
    /// One `(lo, hi)` interval per coordinate; entry `j` applies to `decision[j]`.
    pub bounds: Vec<(f64, f64)>,
}
impl ClampToBounds {
    /// Builds the operator from one `(lo, hi)` pair per coordinate.
    ///
    /// # Panics
    /// Panics if any pair does not satisfy `lo <= hi` (a NaN endpoint also fails this),
    /// reporting the first offending index.
    pub fn new(bounds: Vec<(f64, f64)>) -> Self {
        // `partial_cmp` is `None` exactly when an endpoint is NaN, so this matches `!(lo <= hi)`.
        let first_invalid = bounds.iter().position(|&(lo, hi)| {
            matches!(
                lo.partial_cmp(&hi),
                Some(std::cmp::Ordering::Greater) | None
            )
        });
        if let Some(i) = first_invalid {
            let (lo, hi) = bounds[i];
            panic!(
                "ClampToBounds bound at index {i} has lo > hi: ({lo}, {hi})",
            );
        }
        Self { bounds }
    }
}
impl Repair<Vec<f64>> for ClampToBounds {
    /// Clamps each coordinate of `decision` into the interval stored at the same index.
    ///
    /// Coordinates past the end of `self.bounds` are left untouched, and surplus bounds
    /// past the end of `decision` are ignored.
    ///
    /// # Panics
    /// Panics (via `f64::clamp`) if a stored pair has `lo > hi` or a NaN endpoint. `new`
    /// rejects such pairs, but `bounds` is a public field and could be edited afterwards.
    fn repair(&mut self, decision: &mut Vec<f64>) {
        decision
            .iter_mut()
            .zip(self.bounds.iter())
            .for_each(|(value, &(lo, hi))| *value = (*value).clamp(lo, hi));
    }
}
/// Repair operator that replaces a decision vector with its projection onto the scaled
/// simplex `{ x : x_i >= 0, sum_i x_i = total }`.
#[derive(Debug, Clone)]
pub struct ProjectToSimplex {
    /// Target sum of the repaired vector; `new` requires it to be positive and finite.
    pub total: f64,
}
impl ProjectToSimplex {
    /// Creates the operator for simplex sum `total`.
    ///
    /// # Panics
    /// - If `total` is not strictly positive (this includes NaN).
    /// - If `total` is infinite. A finite vector cannot sum to an infinite target, and the
    ///   projection arithmetic then yields infinite coordinates. Rejecting it here is better
    ///   than producing garbage on every repair call.
    pub fn new(total: f64) -> Self {
        assert!(total > 0.0, "ProjectToSimplex total must be > 0");
        assert!(total.is_finite(), "ProjectToSimplex total must be finite");
        Self { total }
    }
}
impl Repair<Vec<f64>> for ProjectToSimplex {
    /// Replaces `decision` with its Euclidean projection onto
    /// `{ x : x_i >= 0, sum_i x_i = self.total }`.
    ///
    /// Uses the sort-based threshold method:
    /// 1. Sort the values in descending order.
    /// 2. Find the longest prefix whose smallest entry stays positive after subtracting
    ///    `tau = (prefix_sum - total) / prefix_len`.
    /// 3. Set `x_i = max(x_i - tau, 0)`.
    ///
    /// Special cases:
    /// - An empty vector is left unchanged.
    /// - NaN coordinates are treated as `-inf`: they never count as the maximum and end up 0.
    /// - If the largest coordinate is infinite, or its magnitude exceeds `total * 1e15`,
    ///   prefix sums can no longer resolve `total` in `f64`. In that case all mass is put on
    ///   the first index holding the maximum. Exact ties are not split in this case.
    ///
    /// Assumes `self.total` is positive and finite, as enforced by `new`. The field is public,
    /// so this is not re-checked here.
    fn repair(&mut self, decision: &mut Vec<f64>) {
        // Largest |max| / total ratio for which `cumsum - total` still keeps `total`.
        const MAX_DYNAMIC_RANGE: f64 = 1e15;

        if decision.is_empty() {
            return;
        }
        // NaN has no ordering. Mapping it to -inf keeps the sort comparator a total order
        // (`sort_by` may panic otherwise) and gives NaN entries zero mass.
        let key = |v: f64| if v.is_nan() { f64::NEG_INFINITY } else { v };

        // First index holding the largest key; strict `>` keeps the earliest on ties.
        let mut argmax = 0;
        let mut max_val = key(decision[0]);
        for (i, v) in decision.iter().map(|&v| key(v)).enumerate().skip(1) {
            if v > max_val {
                argmax = i;
                max_val = v;
            }
        }

        // Only the maximum's magnitude decides the fallback. Very negative coordinates sort to
        // the tail and fall outside the support, so they cannot corrupt `tau`. They therefore
        // must not force a one-hot result: [0.2, 0.3, -1e20] projects to [0.45, 0.55, 0].
        // The `is_finite` check also covers `total * MAX_DYNAMIC_RANGE` overflowing to inf.
        if !max_val.is_finite() || max_val.abs() > self.total * MAX_DYNAMIC_RANGE {
            for (i, x) in decision.iter_mut().enumerate() {
                *x = if i == argmax { self.total } else { 0.0 };
            }
            return;
        }

        let mut sorted: Vec<f64> = decision.iter().map(|&v| key(v)).collect();
        sorted.sort_unstable_by(|a, b| b.total_cmp(a));

        // Keep the `tau` of the last (longest) prefix satisfying the positivity condition.
        // In exact arithmetic the length-1 prefix always qualifies (its shifted value equals
        // `total`), so this initial value is only a placeholder.
        let mut tau_at_rho = sorted[0] - self.total;
        let mut cumsum = 0.0;
        for (j, &val) in sorted.iter().enumerate() {
            cumsum += val;
            let tau = (cumsum - self.total) / (j as f64 + 1.0);
            if val - tau > 0.0 {
                tau_at_rho = tau;
            }
        }

        // `key` sends NaN to -inf, which the shift-and-clip maps to 0.
        for x in decision.iter_mut() {
            *x = (key(*x) - tau_at_rho).max(0.0);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `a` and `b` differ by strictly less than `tol`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Sum of all coordinates; used to check the simplex equality constraint.
    fn total_of(values: &[f64]) -> f64 {
        values.iter().sum()
    }

    #[test]
    fn clamp_to_bounds_clips() {
        let mut repair = ClampToBounds::new(vec![(-1.0, 1.0); 3]);
        let mut point = vec![-2.5, 0.5, 5.0];
        repair.repair(&mut point);
        assert_eq!(point, vec![-1.0, 0.5, 1.0]);
    }

    #[test]
    fn clamp_passthrough_when_already_in_bounds() {
        let mut repair = ClampToBounds::new(vec![(-1.0, 1.0); 3]);
        let before = vec![-0.3, 0.0, 0.7];
        let mut after = before.clone();
        repair.repair(&mut after);
        assert_eq!(after, before);
    }

    #[test]
    #[should_panic(expected = "lo > hi")]
    fn clamp_invalid_bounds_panics() {
        let _ = ClampToBounds::new(vec![(1.0, -1.0)]);
    }

    #[test]
    fn project_to_unit_simplex_sums_to_total() {
        let mut repair = ProjectToSimplex::new(1.0);
        let mut point = vec![0.5, 0.3, 0.2, -0.5];
        repair.repair(&mut point);
        assert!(close(total_of(&point), 1.0, 1e-12));
        assert!(point.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn project_already_on_simplex_unchanged() {
        let mut repair = ProjectToSimplex::new(1.0);
        let start = [0.5, 0.3, 0.2];
        let mut point = start.to_vec();
        repair.repair(&mut point);
        assert!(close(total_of(&point), 1.0, 1e-12));
        for (&got, &want) in point.iter().zip(start.iter()) {
            assert!(close(got, want, 1e-12));
        }
    }

    #[test]
    fn project_arbitrary_total() {
        let mut repair = ProjectToSimplex::new(10.0);
        let mut point = vec![100.0, 50.0, -20.0, 30.0];
        repair.repair(&mut point);
        assert!(close(total_of(&point), 10.0, 1e-9));
        assert!(point.iter().all(|&v| v >= 0.0));
    }

    #[test]
    #[should_panic(expected = "total must be > 0")]
    fn project_non_positive_total_panics() {
        let _ = ProjectToSimplex::new(0.0);
    }

    #[test]
    fn project_extreme_magnitudes_concentrates_on_argmax() {
        let mut repair = ProjectToSimplex::new(1.0);
        let mut point = vec![1e20, 5e19, -1e20];
        repair.repair(&mut point);
        assert!(close(total_of(&point), 1.0, 1e-12));
        assert!(close(point[0], 1.0, 1e-12));
        assert_eq!(&point[1..], &[0.0, 0.0]);
    }

    #[test]
    fn project_all_zeros_distributes_evenly() {
        let mut repair = ProjectToSimplex::new(1.0);
        let mut point = vec![0.0; 4];
        repair.repair(&mut point);
        assert!(close(total_of(&point), 1.0, 1e-12));
        assert!(point.iter().all(|&v| close(v, 0.25, 1e-12)));
    }

    #[test]
    fn project_extreme_magnitudes_keeps_first_index_on_tie() {
        let mut repair = ProjectToSimplex::new(1.0);
        let mut point = vec![1e20, 1e20, -1e20];
        repair.repair(&mut point);
        assert_eq!(point, vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn project_extreme_magnitudes_finds_argmax_at_non_zero_index() {
        let mut repair = ProjectToSimplex::new(1.0);
        let mut point = vec![-1e20, 1e20, 5e19];
        repair.repair(&mut point);
        assert_eq!(point, vec![0.0, 1.0, 0.0]);
    }
}