math_audio_differential_evolution/
differential_evolution.rs

1use crate::{DEConfig, DEError, DEReport, DifferentialEvolution, Result};
2use ndarray::Array1;
3
4/// Runs Differential Evolution optimization on a function.
5///
6/// This is a convenience function that mirrors SciPy's `differential_evolution` API.
7/// It creates a DE optimizer with the given bounds and configuration, then runs
8/// the optimization to find the global minimum.
9///
10/// # Arguments
11///
12/// * `func` - The objective function to minimize, mapping `&Array1<f64>` to `f64`
13/// * `bounds` - Vector of (lower, upper) bound pairs for each dimension
14/// * `config` - DE configuration (use `DEConfigBuilder` to construct)
15///
16/// # Returns
17///
18/// Returns `Ok(DEReport)` containing the optimization result on success.
19///
20/// # Errors
21///
22/// Returns `DEError::InvalidBounds` if any bound pair has upper < lower.
23///
24/// # Example
25///
26/// ```rust
27/// use math_audio_differential_evolution::{differential_evolution, DEConfigBuilder};
28///
29/// let result = differential_evolution(
30///     &|x| x[0].powi(2) + x[1].powi(2),
31///     &[(-5.0, 5.0), (-5.0, 5.0)],
32///     DEConfigBuilder::new().maxiter(50).seed(42).build(),
33/// ).expect("optimization failed");
34///
35/// assert!(result.fun < 0.01);
36/// ```
37pub fn differential_evolution<F>(
38    func: &F,
39    bounds: &[(f64, f64)],
40    config: DEConfig,
41) -> Result<DEReport>
42where
43    F: Fn(&Array1<f64>) -> f64 + Sync,
44{
45    let n = bounds.len();
46    let mut lower = Array1::<f64>::zeros(n);
47    let mut upper = Array1::<f64>::zeros(n);
48    for (i, (lo, hi)) in bounds.iter().enumerate() {
49        lower[i] = *lo;
50        upper[i] = *hi;
51        if hi < lo {
52            return Err(DEError::InvalidBounds {
53                index: i,
54                lower: *lo,
55                upper: *hi,
56            });
57        }
58    }
59    let mut de = DifferentialEvolution::new(func, lower, upper)?;
60    *de.config_mut() = config;
61    Ok(de.solve())
62}