use crate::error::{OptimizeError, OptimizeResult};
/// Aggregate description of how strongly a candidate violates its constraints.
///
/// The aggregate fields are derived from `violations` once, inside
/// [`ViolationSummary::new`]. Because every field is public they are not kept
/// in sync if a caller mutates them afterwards.
#[derive(Debug, Clone, PartialEq)]
pub struct ViolationSummary {
    /// Sum of every entry in `violations`.
    pub total_violation: f64,
    /// Largest entry in `violations`, never below `0.0`.
    pub max_violation: f64,
    /// Count of entries that are strictly positive.
    pub n_violated: usize,
    /// Per-constraint violation values in the order they were supplied.
    pub violations: Vec<f64>,
}

impl ViolationSummary {
    /// Builds a summary from per-constraint violation values.
    ///
    /// The values are stored exactly as given; no clamping or validation is
    /// applied to them.
    pub fn new(violations: Vec<f64>) -> Self {
        let n_violated = violations.iter().filter(|v| **v > 0.0).count();
        // Seeding the fold with 0.0 keeps the maximum non-negative, including
        // for an empty vector.
        let max_violation = violations
            .iter()
            .copied()
            .fold(0.0_f64, |worst, v| worst.max(v));
        let total_violation = violations.iter().sum::<f64>();
        Self {
            total_violation,
            max_violation,
            n_violated,
            violations,
        }
    }

    /// Returns `true` when the total violation is exactly zero.
    pub fn is_feasible(&self) -> bool {
        self.total_violation == 0.0
    }

    /// Returns `true` when the total violation does not exceed `tol`.
    pub fn is_approximately_feasible(&self, tol: f64) -> bool {
        tol >= self.total_violation
    }

    /// Euclidean norm of the per-constraint violation vector.
    pub fn l2_norm(&self) -> f64 {
        let sum_of_squares: f64 = self.violations.iter().map(|v| v.powi(2)).sum();
        sum_of_squares.sqrt()
    }
}
/// Feasibility-first comparison of candidate solutions (minimisation).
///
/// A candidate counts as feasible when its total violation does not exceed
/// `feasibility_tol`.
#[derive(Debug, Clone, Copy, Default)]
pub struct FeasibilityRule {
    /// Largest total violation that is still treated as feasible.
    pub feasibility_tol: f64,
}

impl FeasibilityRule {
    /// Creates a rule with the given feasibility tolerance.
    pub fn new(feasibility_tol: f64) -> Self {
        Self { feasibility_tol }
    }

    /// Compares candidate `a` against candidate `b`; `Less` means `a` is better.
    ///
    /// * a feasible candidate always beats an infeasible one;
    /// * two feasible candidates are ordered by objective value (smaller wins);
    /// * two infeasible candidates are ordered by total violation (smaller wins).
    ///
    /// NaN objective or violation values are ordered after every real number,
    /// so the result is a consistent ordering that is safe to hand to a sort.
    pub fn compare(
        &self,
        f_a: f64,
        viol_a: &ViolationSummary,
        f_b: f64,
        viol_b: &ViolationSummary,
    ) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let feasible_a = viol_a.total_violation <= self.feasibility_tol;
        let feasible_b = viol_b.total_violation <= self.feasibility_tol;
        match (feasible_a, feasible_b) {
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            (true, true) => Self::cmp_nan_last(f_a, f_b),
            (false, false) => {
                Self::cmp_nan_last(viol_a.total_violation, viol_b.total_violation)
            }
        }
    }

    /// Returns the indices of `population` ordered from best to worst
    /// according to [`FeasibilityRule::compare`].
    ///
    /// Each entry is an `(objective, violations)` pair. The sort is stable, so
    /// candidates that compare equal keep their original relative order.
    pub fn sort_population(
        &self,
        population: &[(f64, ViolationSummary)],
    ) -> Vec<usize> {
        let mut indices: Vec<usize> = (0..population.len()).collect();
        indices.sort_by(|&a, &b| {
            let (f_a, viol_a) = &population[a];
            let (f_b, viol_b) = &population[b];
            self.compare(*f_a, viol_a, *f_b, viol_b)
        });
        indices
    }

    /// Orders two floats ascending, placing NaN after every real number.
    ///
    /// `partial_cmp` only fails when at least one operand is NaN; in that case
    /// the "is NaN" flags are compared instead (`false < true`), which makes a
    /// NaN operand the larger one and two NaNs equal.
    fn cmp_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}
/// Feasibility comparison with a relaxation threshold `epsilon` that shrinks
/// as generations advance.
///
/// A candidate is treated as feasible while its total violation does not
/// exceed the current `epsilon`. The threshold starts at `epsilon_0` and is
/// forced to `0.0` once `t_c` generations have elapsed.
#[derive(Debug, Clone)]
pub struct EpsilonFeasibility {
    /// Threshold in effect for the current generation.
    epsilon: f64,
    /// Threshold at generation zero.
    epsilon_0: f64,
    /// Generation at (and after) which the threshold is zero.
    t_c: usize,
    /// Shape parameter: decay rate for `Exponential`, exponent for `PowerLaw`;
    /// ignored by `Linear`.
    cp: f64,
    /// Decay law applied before `t_c` is reached.
    schedule: EpsilonSchedule,
    /// Number of completed calls to [`EpsilonFeasibility::step`].
    generation: usize,
}

/// Decay law used by [`EpsilonFeasibility`] for generations `t < t_c`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EpsilonSchedule {
    /// `epsilon_0 * (1 - t / t_c)`
    Linear,
    /// `epsilon_0 * exp(-cp * t)`
    Exponential,
    /// `epsilon_0 * (1 - t / t_c)^cp`
    PowerLaw,
}

impl EpsilonFeasibility {
    /// Creates a schedule starting at `epsilon_0`.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizeError::InvalidInput`] when `epsilon_0` is negative or
    /// NaN, or when `t_c` is zero. `cp` is not validated.
    pub fn new(
        epsilon_0: f64,
        t_c: usize,
        cp: f64,
        schedule: EpsilonSchedule,
    ) -> OptimizeResult<Self> {
        // `NaN < 0.0` is false, so NaN has to be rejected explicitly; a NaN
        // threshold would make every candidate look infeasible.
        if epsilon_0.is_nan() || epsilon_0 < 0.0 {
            return Err(OptimizeError::InvalidInput(
                "epsilon_0 must be >= 0".to_string(),
            ));
        }
        if t_c == 0 {
            return Err(OptimizeError::InvalidInput("t_c must be > 0".to_string()));
        }
        Ok(Self {
            epsilon: epsilon_0,
            epsilon_0,
            t_c,
            cp,
            schedule,
            generation: 0,
        })
    }

    /// Threshold in effect for the current generation.
    pub fn epsilon(&self) -> f64 {
        self.epsilon
    }

    /// Advances to the next generation and recomputes the threshold.
    pub fn step(&mut self) {
        // Saturating: once `t_c` has been passed the threshold stays at zero,
        // so the counter never needs to wrap or panic on overflow.
        self.generation = self.generation.saturating_add(1);
        self.epsilon = self.compute_epsilon(self.generation);
    }

    /// Threshold prescribed by the schedule for generation `t`.
    fn compute_epsilon(&self, t: usize) -> f64 {
        if t >= self.t_c {
            return 0.0;
        }
        let ratio = t as f64 / self.t_c as f64;
        match self.schedule {
            EpsilonSchedule::Linear => self.epsilon_0 * (1.0 - ratio),
            // Uses the absolute generation count, not the ratio to `t_c`.
            EpsilonSchedule::Exponential => self.epsilon_0 * (-self.cp * t as f64).exp(),
            EpsilonSchedule::PowerLaw => self.epsilon_0 * (1.0 - ratio).powf(self.cp),
        }
    }

    /// Compares candidate `a` against candidate `b` under the current
    /// threshold; `Less` means `a` is better.
    ///
    /// Candidates within the threshold beat those outside it. Two candidates
    /// within it are ordered by objective value, two outside it by total
    /// violation. NaN values are ordered after every real number so the
    /// comparison stays consistent when used for sorting.
    pub fn compare(
        &self,
        f_a: f64,
        viol_a: &ViolationSummary,
        f_b: f64,
        viol_b: &ViolationSummary,
    ) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let eps_feasible_a = viol_a.total_violation <= self.epsilon;
        let eps_feasible_b = viol_b.total_violation <= self.epsilon;
        match (eps_feasible_a, eps_feasible_b) {
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            (true, true) => Self::cmp_nan_last(f_a, f_b),
            (false, false) => {
                Self::cmp_nan_last(viol_a.total_violation, viol_b.total_violation)
            }
        }
    }

    /// Orders two floats ascending, placing NaN after every real number.
    fn cmp_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
        // `partial_cmp` is `None` only when a NaN is involved; comparing the
        // NaN flags (`false < true`) then ranks the NaN operand last.
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}
/// Stochastic ranking of a population by repeated randomised bubble-sort
/// sweeps over adjacent pairs.
///
/// Every call to [`StochasticRanking::rank`] restarts its internal generator
/// from `seed`, so the same inputs always give the same ranking.
#[derive(Debug, Clone)]
pub struct StochasticRanking {
    /// Probability of comparing an adjacent pair by objective value alone
    /// when at least one of the two candidates is infeasible.
    pub p_f: f64,
    /// Initial state of the internal pseudo-random generator.
    seed: u64,
}

impl StochasticRanking {
    /// Creates a ranker.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizeError::InvalidInput`] when `p_f` lies outside
    /// `[0, 1]` (NaN included).
    pub fn new(p_f: f64, seed: u64) -> OptimizeResult<Self> {
        if !(0.0..=1.0).contains(&p_f) {
            return Err(OptimizeError::InvalidInput(
                "p_f must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self { p_f, seed })
    }

    /// Ranks candidates from best to worst and returns their indices.
    ///
    /// `objectives[i]` and `violations[i]` describe candidate `i`. At least
    /// one sweep is performed even when `n_passes` is zero. For each adjacent
    /// pair:
    ///
    /// * if both candidates are feasible, or with probability `p_f`
    ///   otherwise, the pair is ordered by objective value (smaller first);
    /// * in the remaining cases the pair is ordered by total violation
    ///   (smaller first).
    ///
    /// # Errors
    ///
    /// Returns [`OptimizeError::InvalidInput`] when the two slices differ in
    /// length.
    pub fn rank(
        &self,
        objectives: &[f64],
        violations: &[ViolationSummary],
        n_passes: usize,
    ) -> OptimizeResult<Vec<usize>> {
        let n = objectives.len();
        if n != violations.len() {
            return Err(OptimizeError::InvalidInput(format!(
                "objectives.len()={n} must equal violations.len()={}",
                violations.len()
            )));
        }
        if n == 0 {
            return Ok(Vec::new());
        }

        let mut indices: Vec<usize> = (0..n).collect();
        let mut rng_state = self.seed;
        for _ in 0..n_passes.max(1) {
            for j in 0..n - 1 {
                let (a, b) = (indices[j], indices[j + 1]);
                let (va, vb) = (&violations[a], &violations[b]);

                rng_state = Self::next_state(rng_state);
                let u = Self::unit_interval(rng_state);

                let both_feasible = va.is_feasible() && vb.is_feasible();
                let should_swap = if both_feasible || u < self.p_f {
                    objectives[a] > objectives[b]
                } else {
                    va.total_violation > vb.total_violation
                };
                if should_swap {
                    indices.swap(j, j + 1);
                }
            }
        }
        Ok(indices)
    }

    /// Advances the 64-bit linear congruential generator by one step.
    fn next_state(state: u64) -> u64 {
        state
            .wrapping_mul(6364136223846793005)
            .wrapping_add(1442695040888963407)
    }

    /// Maps a generator state to a float in `[0, 1)`.
    ///
    /// The top 53 bits are used (an `f64` mantissa holds exactly that many)
    /// and scaled by 2^-53, so the whole unit interval is covered.
    fn unit_interval(state: u64) -> f64 {
        (state >> 11) as f64 / (1u64 << 53) as f64
    }
}
/// Feasibility comparison whose penalty weight adapts to the observed share
/// of feasible candidates.
///
/// The share is tracked as an exponential moving average that starts at
/// `target_ratio` and is refreshed by [`AdaptiveFeasibility::update`].
#[derive(Debug, Clone)]
pub struct AdaptiveFeasibility {
    /// Smoothed share of feasible candidates seen so far.
    feasibility_ratio: f64,
    /// Desired share of feasible candidates, in `(0, 1]`.
    target_ratio: f64,
    /// Smoothing factor of the moving average, in `[0, 1]`.
    alpha: f64,
    /// Largest total violation that is still treated as feasible.
    feasibility_tol: f64,
}

impl AdaptiveFeasibility {
    /// Creates an adaptive rule.
    ///
    /// # Errors
    ///
    /// Returns [`OptimizeError::InvalidInput`] when `target_ratio` is outside
    /// `(0, 1]` or `alpha` is outside `[0, 1]` (NaN is rejected in both
    /// cases). `feasibility_tol` is not validated.
    pub fn new(target_ratio: f64, alpha: f64, feasibility_tol: f64) -> OptimizeResult<Self> {
        // Zero must be excluded, as the message states: a zero target makes
        // `effective_penalty_weight` divide by zero.
        if !(target_ratio > 0.0 && target_ratio <= 1.0) {
            return Err(OptimizeError::InvalidInput(
                "target_ratio must be in (0, 1]".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&alpha) {
            return Err(OptimizeError::InvalidInput(
                "alpha must be in [0, 1]".to_string(),
            ));
        }
        Ok(Self {
            feasibility_ratio: target_ratio,
            target_ratio,
            alpha,
            feasibility_tol,
        })
    }

    /// Folds the feasible share of `violations` into the moving average.
    ///
    /// An empty slice leaves the state untouched.
    pub fn update(&mut self, violations: &[ViolationSummary]) {
        if violations.is_empty() {
            return;
        }
        let n_feasible = violations
            .iter()
            .filter(|v| v.total_violation <= self.feasibility_tol)
            .count();
        let observed_ratio = n_feasible as f64 / violations.len() as f64;
        self.feasibility_ratio =
            (1.0 - self.alpha) * self.feasibility_ratio + self.alpha * observed_ratio;
    }

    /// Current smoothed share of feasible candidates.
    pub fn feasibility_ratio(&self) -> f64 {
        self.feasibility_ratio
    }

    /// Scales `base_penalty` according to how the smoothed feasible share
    /// compares with the target.
    ///
    /// * share `<= 0`: `base_penalty * 10`;
    /// * share below target: `base_penalty * (target / share)^2`;
    /// * otherwise: `base_penalty / (target / share)`.
    pub fn effective_penalty_weight(&self, base_penalty: f64) -> f64 {
        if self.feasibility_ratio <= 0.0 {
            return base_penalty * 10.0;
        }
        let ratio = self.target_ratio / self.feasibility_ratio.max(1e-10);
        if self.feasibility_ratio < self.target_ratio {
            base_penalty * ratio.powi(2)
        } else {
            // NOTE(review): here `ratio <= 1`, so dividing by it yields a
            // weight >= `base_penalty`; the penalty therefore also grows when
            // the population is more feasible than the target. Confirm this
            // is intended (a relaxing rule would be `base_penalty * ratio`).
            base_penalty / ratio
        }
    }

    /// Compares candidate `a` against candidate `b`; `Less` means `a` is better.
    ///
    /// A feasible candidate beats an infeasible one and two feasible
    /// candidates are ordered by objective value. Two infeasible candidates
    /// are ordered by `objective + weight * total_violation`, where `weight`
    /// is `effective_penalty_weight(1.0)`. NaN scores are ordered after every
    /// real number so the comparison stays consistent when used for sorting.
    pub fn compare(
        &self,
        f_a: f64,
        viol_a: &ViolationSummary,
        f_b: f64,
        viol_b: &ViolationSummary,
    ) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        let feasible_a = viol_a.total_violation <= self.feasibility_tol;
        let feasible_b = viol_b.total_violation <= self.feasibility_tol;
        match (feasible_a, feasible_b) {
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            (true, true) => Self::cmp_nan_last(f_a, f_b),
            (false, false) => {
                let w = self.effective_penalty_weight(1.0);
                let score_a = f_a + w * viol_a.total_violation;
                let score_b = f_b + w * viol_b.total_violation;
                Self::cmp_nan_last(score_a, score_b)
            }
        }
    }

    /// Orders two floats ascending, placing NaN after every real number.
    fn cmp_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
        // `partial_cmp` is `None` only when a NaN is involved; comparing the
        // NaN flags (`false < true`) then ranks the NaN operand last.
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    }
}
/// Evaluates inequality constraints at `x` and summarises their violations.
///
/// Each function in `g_fns` is called with `x`; only positive values count as
/// violation, non-positive values contribute `0.0`. A NaN result is kept as
/// NaN rather than being clamped to zero, so a constraint that failed to
/// evaluate cannot make the point look feasible.
pub fn ineq_violations<F>(x: &[f64], g_fns: &[F]) -> ViolationSummary
where
    F: Fn(&[f64]) -> f64,
{
    let violations: Vec<f64> = g_fns
        .iter()
        .map(|g| {
            let value = g(x);
            // `f64::max` returns the non-NaN operand, which would silently
            // turn NaN into 0.0; pass NaN through instead.
            if value.is_nan() {
                value
            } else {
                value.max(0.0)
            }
        })
        .collect();
    ViolationSummary::new(violations)
}
/// Evaluates equality constraints at `x` and summarises their violations.
///
/// Each function in `h_fns` is called with `x`; the absolute value of the
/// result is recorded as that constraint's violation.
pub fn eq_violations<F>(x: &[f64], h_fns: &[F]) -> ViolationSummary
where
    F: Fn(&[f64]) -> f64,
{
    let mut magnitudes = Vec::with_capacity(h_fns.len());
    for h in h_fns {
        magnitudes.push(h(x).abs());
    }
    ViolationSummary::new(magnitudes)
}
/// Evaluates inequality and equality constraints at `x` and summarises all
/// violations in one [`ViolationSummary`].
///
/// Inequality entries come first, in the order of `g_fns`: positive values
/// count as violation and non-positive values contribute `0.0`. Equality
/// entries follow, in the order of `h_fns`, as absolute values. A NaN result
/// from either kind of constraint is kept as NaN, so a constraint that failed
/// to evaluate cannot make the point look feasible.
pub fn combined_violations<G, H>(
    x: &[f64],
    g_fns: &[G],
    h_fns: &[H],
) -> ViolationSummary
where
    G: Fn(&[f64]) -> f64,
    H: Fn(&[f64]) -> f64,
{
    let mut violations: Vec<f64> = Vec::with_capacity(g_fns.len() + h_fns.len());
    violations.extend(g_fns.iter().map(|g| {
        let value = g(x);
        // `f64::max` returns the non-NaN operand, which would silently turn
        // NaN into 0.0; pass NaN through instead.
        if value.is_nan() {
            value
        } else {
            value.max(0.0)
        }
    }));
    violations.extend(h_fns.iter().map(|h| h(x).abs()));
    ViolationSummary::new(violations)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::Ordering;

    /// Summary holding a single constraint violation.
    fn single(v: f64) -> ViolationSummary {
        ViolationSummary::new(vec![v])
    }

    /// Absolute-difference check used for floating-point expectations.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 1e-10
    }

    // ---- ViolationSummary ------------------------------------------------

    #[test]
    fn test_violation_summary_feasible() {
        let vs = ViolationSummary::new(vec![0.0, 0.0]);
        assert!(vs.is_feasible());
        assert_eq!(vs.total_violation, 0.0);
        assert_eq!(vs.n_violated, 0);
    }

    #[test]
    fn test_violation_summary_infeasible() {
        let vs = ViolationSummary::new(vec![0.3, 0.0, 0.5]);
        assert!(!vs.is_feasible());
        assert!(close(vs.total_violation, 0.8));
        assert!(close(vs.max_violation, 0.5));
        assert_eq!(vs.n_violated, 2);
    }

    #[test]
    fn test_violation_summary_l2_norm() {
        let vs = ViolationSummary::new(vec![3.0, 4.0]);
        assert!(close(vs.l2_norm(), 5.0));
    }

    #[test]
    fn test_violation_summary_approximately_feasible() {
        let vs = single(0.05);
        assert!(vs.is_approximately_feasible(0.1));
        assert!(!vs.is_approximately_feasible(0.01));
    }

    // ---- FeasibilityRule -------------------------------------------------

    #[test]
    fn test_feasibility_rule_feasible_beats_infeasible() {
        let rule = FeasibilityRule::new(1e-8);
        let feas = single(0.0);
        let infeas = single(0.5);
        // The feasible candidate wins despite its far worse objective.
        assert_eq!(rule.compare(100.0, &feas, 0.0, &infeas), Ordering::Less);
    }

    #[test]
    fn test_feasibility_rule_both_feasible_compare_objective() {
        let rule = FeasibilityRule::new(1e-8);
        let feas = single(0.0);
        let cases = [
            (1.0, 2.0, Ordering::Less),
            (2.0, 1.0, Ordering::Greater),
            (1.5, 1.5, Ordering::Equal),
        ];
        for (f_a, f_b, expected) in cases {
            assert_eq!(rule.compare(f_a, &feas, f_b, &feas), expected);
        }
    }

    #[test]
    fn test_feasibility_rule_both_infeasible_compare_violation() {
        let rule = FeasibilityRule::new(1e-8);
        let smaller = single(0.3);
        let larger = single(0.8);
        assert_eq!(rule.compare(0.0, &smaller, 0.0, &larger), Ordering::Less);
    }

    #[test]
    fn test_feasibility_rule_sort_population() {
        let rule = FeasibilityRule::new(1e-8);
        let pop = vec![
            (5.0, single(0.0)), // feasible, worse objective
            (1.0, single(1.0)), // infeasible
            (2.0, single(0.0)), // feasible, better objective
        ];
        let sorted = rule.sort_population(&pop);
        assert_eq!(sorted, vec![2, 0, 1]);
    }

    // ---- EpsilonFeasibility ----------------------------------------------

    #[test]
    fn test_epsilon_feasibility_linear_decay() {
        let mut ef = EpsilonFeasibility::new(1.0, 10, 1.0, EpsilonSchedule::Linear)
            .expect("failed to create ef");
        assert!(close(ef.epsilon(), 1.0));
        ef.step();
        assert!(close(ef.epsilon(), 0.9));
        // Nine further steps reach generation t_c, where epsilon is zero.
        (0..9).for_each(|_| ef.step());
        assert_eq!(ef.epsilon(), 0.0);
    }

    #[test]
    fn test_epsilon_feasibility_power_law_decay() {
        let mut ef = EpsilonFeasibility::new(1.0, 100, 2.0, EpsilonSchedule::PowerLaw)
            .expect("failed to create ef");
        ef.step();
        let eps = ef.epsilon();
        assert!(eps > 0.97 && eps < 1.0);
    }

    #[test]
    fn test_epsilon_feasibility_invalid_epsilon() {
        assert!(EpsilonFeasibility::new(-1.0, 10, 1.0, EpsilonSchedule::Linear).is_err());
    }

    #[test]
    fn test_epsilon_feasibility_invalid_tc() {
        assert!(EpsilonFeasibility::new(1.0, 0, 1.0, EpsilonSchedule::Linear).is_err());
    }

    #[test]
    fn test_epsilon_feasibility_compare_relaxed_feasibility() {
        let ef = EpsilonFeasibility::new(0.5, 10, 1.0, EpsilonSchedule::Linear)
            .expect("failed to create ef");
        // Both violations are within epsilon, so the objective decides.
        let v1 = single(0.3);
        let v2 = single(0.4);
        assert_eq!(ef.compare(1.0, &v1, 2.0, &v2), Ordering::Less);
    }

    // ---- StochasticRanking -----------------------------------------------

    #[test]
    fn test_stochastic_ranking_correct_length() {
        let sr = StochasticRanking::new(0.45, 42).expect("failed to create sr");
        let objectives = vec![1.0, 2.0, 3.0, 4.0];
        let violations: Vec<ViolationSummary> =
            [0.0, 0.5, 0.0, 0.2].into_iter().map(single).collect();
        let ranked = sr
            .rank(&objectives, &violations, 5)
            .expect("failed to create ranked");
        assert_eq!(ranked.len(), 4);
        // The result must be a permutation of the input indices.
        let mut sorted_ranked = ranked.clone();
        sorted_ranked.sort_unstable();
        assert_eq!(sorted_ranked, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_stochastic_ranking_feasible_prefers_better_obj() {
        let sr = StochasticRanking::new(0.45, 42).expect("failed to create sr");
        let objectives = vec![2.0, 1.0];
        let violations = vec![single(0.0), single(0.0)];
        let ranked = sr
            .rank(&objectives, &violations, 50)
            .expect("failed to create ranked");
        assert_eq!(ranked[0], 1, "better objective should rank first");
    }

    #[test]
    fn test_stochastic_ranking_invalid_p_f() {
        assert!(StochasticRanking::new(1.5, 42).is_err());
    }

    #[test]
    fn test_stochastic_ranking_mismatch_lengths() {
        let sr = StochasticRanking::new(0.45, 42).expect("failed to create sr");
        let result = sr.rank(&[1.0, 2.0], &[single(0.0)], 5);
        assert!(result.is_err());
    }

    // ---- AdaptiveFeasibility ---------------------------------------------

    #[test]
    fn test_adaptive_feasibility_initial_ratio() {
        let af = AdaptiveFeasibility::new(0.5, 0.1, 1e-8).expect("failed to create af");
        assert!(close(af.feasibility_ratio(), 0.5));
    }

    #[test]
    fn test_adaptive_feasibility_update_all_feasible() {
        let mut af = AdaptiveFeasibility::new(0.5, 0.5, 1e-8).expect("failed to create af");
        let violations = vec![single(0.0), single(0.0)];
        af.update(&violations);
        assert!(af.feasibility_ratio() > 0.5);
    }

    #[test]
    fn test_adaptive_feasibility_penalty_amplified_when_low() {
        let mut af = AdaptiveFeasibility::new(0.5, 0.9, 1e-8).expect("failed to create af");
        let violations: Vec<ViolationSummary> = vec![single(1.0), single(2.0)];
        af.update(&violations);
        let weight = af.effective_penalty_weight(1.0);
        assert!(weight > 1.0, "penalty should be amplified when feasibility is low");
    }

    #[test]
    fn test_adaptive_feasibility_compare() {
        let af = AdaptiveFeasibility::new(0.5, 0.1, 1e-8).expect("failed to create af");
        let feas = single(0.0);
        let infeas = single(1.0);
        assert_eq!(af.compare(100.0, &feas, 0.0, &infeas), Ordering::Less);
    }

    // ---- free helper functions -------------------------------------------

    #[test]
    fn test_ineq_violations_satisfied() {
        let x = &[0.5_f64];
        let g = |x: &[f64]| x[0] - 1.0;
        let v = ineq_violations(x, &[g]);
        assert_eq!(v.total_violation, 0.0);
    }

    #[test]
    fn test_ineq_violations_violated() {
        let x = &[1.5_f64];
        let g = |x: &[f64]| x[0] - 1.0;
        let v = ineq_violations(x, &[g]);
        assert!(close(v.total_violation, 0.5));
    }

    #[test]
    fn test_eq_violations() {
        let x = &[2.0_f64];
        let h = |x: &[f64]| x[0] - 3.0;
        let v = eq_violations(x, &[h]);
        assert!(close(v.total_violation, 1.0));
    }

    #[test]
    fn test_combined_violations() {
        let x = &[1.5_f64, 2.0_f64];
        let g = |x: &[f64]| x[0] - 1.0;
        let h = |x: &[f64]| x[1] - 2.0;
        let v = combined_violations(x, &[g], &[h]);
        assert!(close(v.total_violation, 0.5));
        assert_eq!(v.n_violated, 1);
    }
}