// math_audio_optimisation/impl_helpers.rs

1use crate::{DEReport, DifferentialEvolution};
2use ndarray::{Array1, Array2, Zip};
3use oxiblas_ndarray::blas::matvec;
4
// ------------------------------ Internal helpers ------------------------------
6
7impl<'a, F> DifferentialEvolution<'a, F>
8where
9    F: Fn(&Array1<f64>) -> f64 + Sync,
10{
11    pub(crate) fn energy(&self, x: &Array1<f64>) -> f64 {
12        let base = (self.func)(x);
13        let energy = base + self.penalty(x);
14        if energy.is_finite() {
15            energy
16        } else {
17            f64::INFINITY
18        }
19    }
20
21    pub(crate) fn penalty(&self, x: &Array1<f64>) -> f64 {
22        let mut p = 0.0;
23        // Nonlinear ineq: fc(x) <= 0 feasible
24        for (f, w) in &self.config.penalty_ineq {
25            let v = f(x);
26            let viol = v.max(0.0);
27            p += w * viol * viol;
28        }
29        // Nonlinear eq: h(x) = 0
30        for (h, w) in &self.config.penalty_eq {
31            let v = h(x);
32            p += w * v * v;
33        }
34        // Linear penalties: lb <= A x <= ub
35        if let Some(lp) = &self.config.linear_penalty {
36            let ax = matvec(&lp.a, &x.to_owned());
37            Zip::from(&ax)
38                .and(&lp.lb)
39                .and(&lp.ub)
40                .for_each(|&v, &lo, &hi| {
41                    if v < lo {
42                        let d = lo - v;
43                        p += lp.weight * d * d;
44                    } else if v > hi {
45                        let d = v - hi;
46                        p += lp.weight * d * d;
47                    }
48                });
49        }
50        p
51    }
52
53    #[allow(clippy::too_many_arguments)]
54    pub(crate) fn finish_report(
55        &self,
56        pop: Array2<f64>,
57        energies: Array1<f64>,
58        x: Array1<f64>,
59        fun: f64,
60        success: bool,
61        message: String,
62        nit: usize,
63        nfev: usize,
64    ) -> DEReport {
65        DEReport {
66            x,
67            fun,
68            success,
69            message,
70            nit,
71            nfev,
72            population: pop,
73            population_energies: energies,
74        }
75    }
76
77    pub(crate) fn polish(&self, x0: &Array1<f64>) -> (Array1<f64>, f64, usize) {
78        let polish_cfg = match &self.config.polish {
79            Some(cfg) if cfg.enabled => cfg,
80            _ => {
81                let f = self.energy(x0);
82                return (x0.clone(), f, 1);
83            }
84        };
85
86        let n = x0.len();
87        let mut x = x0.clone();
88        let mut best_f = self.energy(&x);
89        let mut nfev = 1;
90
91        let initial_step = 0.1;
92        let min_step = 1e-8;
93        let mut step = initial_step;
94
95        let max_eval = polish_cfg.maxeval.min(200 * n);
96
97        while nfev < max_eval && step > min_step {
98            let mut improved = false;
99
100            for i in 0..n {
101                if nfev >= max_eval {
102                    break;
103                }
104
105                let original = x[i];
106                let bounds_span = self.upper[i] - self.lower[i];
107                let dim_step = step * bounds_span.max(1.0);
108
109                for delta in [dim_step, -dim_step] {
110                    if nfev >= max_eval {
111                        break;
112                    }
113                    x[i] = (original + delta).clamp(self.lower[i], self.upper[i]);
114                    let f = self.energy(&x);
115                    nfev += 1;
116
117                    if f < best_f {
118                        best_f = f;
119                        improved = true;
120                        break;
121                    }
122                    x[i] = original;
123                }
124            }
125
126            if !improved {
127                step *= 0.5;
128            }
129        }
130
131        (x, best_f, nfev)
132    }
133}