use super::{clip, DfOptResult, DerivativeFreeOptimizer};
use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::Array1;
/// Polling direction set used by the pattern-search solver.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PatternType {
    /// Coordinate axes only: +e_i and -e_i for each dimension (2n directions).
    Axes,
    /// Axes plus, for n <= 8, the scaled diagonal pairs +/-(e_i - e_j)/sqrt(2).
    CompassRose,
    /// n unit vectors plus a single negative-orthant vector (n + 1 directions).
    Simplex,
}
/// Configuration for the pattern-search solver.
#[derive(Debug, Clone)]
pub struct PatternSearchOptions {
    /// Per-coordinate lower bounds; `None` means unbounded below.
    pub lower: Option<Vec<f64>>,
    /// Per-coordinate upper bounds; `None` means unbounded above.
    pub upper: Option<Vec<f64>>,
    /// Initial mesh (step) size.
    pub delta_init: f64,
    /// Mesh size below which the search is declared converged.
    pub delta_min: f64,
    /// Factor applied to the mesh after a successful poll (typically > 1).
    pub expand_factor: f64,
    /// Factor applied to the mesh after an unsuccessful poll (typically < 1).
    pub contract_factor: f64,
    /// Maximum number of objective-function evaluations.
    pub max_fev: usize,
    /// Maximum number of outer iterations.
    pub max_iter: usize,
    /// Minimum decrease required for a trial point to count as an improvement.
    pub f_tol: f64,
    /// Which polling direction set to generate.
    pub pattern: PatternType,
    /// If true, accept the first improving direction instead of polling all of them.
    pub opportunistic: bool,
}
impl Default for PatternSearchOptions {
    /// General-purpose defaults: unbounded search, axes pattern,
    /// opportunistic polling, unit initial mesh.
    fn default() -> Self {
        Self {
            lower: None,
            upper: None,
            pattern: PatternType::Axes,
            opportunistic: true,
            delta_init: 1.0,
            delta_min: 1e-7,
            expand_factor: 2.0,
            contract_factor: 0.5,
            f_tol: 1e-10,
            max_fev: 50000,
            max_iter: 10000,
        }
    }
}
/// Derivative-free pattern (direct) search optimizer.
pub struct PatternSearchSolver {
    /// Solver configuration; see [`PatternSearchOptions`].
    pub options: PatternSearchOptions,
}
impl PatternSearchSolver {
    /// Create a solver with default options.
    pub fn new() -> Self {
        Self::with_options(PatternSearchOptions::default())
    }

    /// Create a solver with explicit options.
    pub fn with_options(options: PatternSearchOptions) -> Self {
        PatternSearchSolver { options }
    }

    /// Resolve per-coordinate bounds, substituting +/- infinity where unset.
    fn get_bounds(&self, n: usize) -> (Vec<f64>, Vec<f64>) {
        let lo = self
            .options
            .lower
            .clone()
            .unwrap_or_else(|| vec![f64::NEG_INFINITY; n]);
        let hi = self
            .options
            .upper
            .clone()
            .unwrap_or_else(|| vec![f64::INFINITY; n]);
        (lo, hi)
    }

    /// Clamp each coordinate of `x` into its `[lo_i, hi_i]` interval.
    fn project(&self, x: &[f64], lo: &[f64], hi: &[f64]) -> Vec<f64> {
        x.iter()
            .zip(lo)
            .zip(hi)
            .map(|((&xi, &l), &h)| clip(xi, l, h))
            .collect()
    }

    /// Build the polling direction set for dimension `n`.
    fn generate_pattern(&self, n: usize) -> Vec<Vec<f64>> {
        // Axis pairs in the fixed order +e_0, -e_0, +e_1, -e_1, ...
        let axis_dirs = |n: usize| -> Vec<Vec<f64>> {
            (0..n)
                .flat_map(|i| {
                    [1.0f64, -1.0].iter().map(move |&sign| {
                        let mut d = vec![0.0; n];
                        d[i] = sign;
                        d
                    })
                })
                .collect()
        };
        match self.options.pattern {
            PatternType::Axes => axis_dirs(n),
            PatternType::CompassRose => {
                let mut dirs = axis_dirs(n);
                // Diagonal pairs are only added for small n; otherwise the
                // O(n^2) direction count would dominate the poll cost.
                if n <= 8 {
                    let scale = 1.0 / (2.0_f64).sqrt();
                    for i in 0..n {
                        for j in (i + 1)..n {
                            let mut d = vec![0.0; n];
                            d[i] = scale;
                            d[j] = -scale;
                            let neg: Vec<f64> = d.iter().map(|v| -v).collect();
                            dirs.push(d);
                            dirs.push(neg);
                        }
                    }
                }
                dirs
            }
            PatternType::Simplex => {
                // n unit vectors followed by one unit-length vector pointing
                // into the negative orthant.
                let mut dirs: Vec<Vec<f64>> = (0..n)
                    .map(|i| {
                        let mut d = vec![0.0; n];
                        d[i] = 1.0;
                        d
                    })
                    .collect();
                dirs.push(vec![-1.0 / (n as f64).sqrt(); n]);
                dirs
            }
        }
    }
}
impl Default for PatternSearchSolver {
fn default() -> Self {
PatternSearchSolver::new()
}
}
impl DerivativeFreeOptimizer for PatternSearchSolver {
    /// Minimize `func` starting from `x0` using generalized pattern search.
    ///
    /// Each iteration polls the current point along a fixed direction set
    /// scaled by the mesh size `delta`; the mesh expands after a successful
    /// poll and contracts after a failed one.  The search stops when `delta`
    /// falls below `delta_min` (converged) or when the iteration/evaluation
    /// budget is exhausted.
    ///
    /// # Errors
    /// Returns `OptimizeError::InvalidInput` when `x0` is empty.
    fn minimize<F>(&self, func: F, x0: &[f64]) -> OptimizeResult<DfOptResult>
    where
        F: Fn(&[f64]) -> f64,
    {
        let n = x0.len();
        if n == 0 {
            return Err(OptimizeError::InvalidInput(
                "x0 must be non-empty".to_string(),
            ));
        }
        let (lo, hi) = self.get_bounds(n);
        // Project the start point so the search begins feasible.
        let mut x = self.project(x0, &lo, &hi);
        let mut delta = self.options.delta_init;
        let mut nfev = 0usize;
        let mut nit = 0usize;
        let mut fx = {
            nfev += 1;
            func(&x)
        };
        let dirs = self.generate_pattern(n);
        loop {
            if nit >= self.options.max_iter || nfev >= self.options.max_fev {
                break;
            }
            if delta < self.options.delta_min {
                return Ok(DfOptResult {
                    x: Array1::from_vec(x),
                    fun: fx,
                    nfev,
                    nit,
                    success: true,
                    message: "Converged: mesh size below tolerance".to_string(),
                });
            }
            let mut improved = false;
            let mut best_x = x.clone();
            let mut best_f = fx;
            'poll: for dir in &dirs {
                if nfev >= self.options.max_fev {
                    break 'poll;
                }
                let xtrial: Vec<f64> = x
                    .iter()
                    .zip(dir.iter())
                    .map(|(&xi, &di)| xi + delta * di)
                    .collect();
                let xtrial = self.project(&xtrial, &lo, &hi);
                // Fix: when a direction points out of an active bound, the
                // projection collapses the trial back onto the current point.
                // Evaluating there just recomputes fx and burns budget, and
                // (for non-negative f_tol) can never satisfy the acceptance
                // test below -- so skip it without spending an evaluation.
                if xtrial == x {
                    continue;
                }
                nfev += 1;
                let ftrial = func(&xtrial);
                // Require a decrease of at least f_tol to accept.
                if ftrial < best_f - self.options.f_tol {
                    best_f = ftrial;
                    best_x = xtrial;
                    improved = true;
                    // Opportunistic polling moves on the first improvement;
                    // otherwise keep the best over the full direction set.
                    if self.options.opportunistic {
                        break 'poll;
                    }
                }
            }
            if improved {
                x = best_x;
                fx = best_f;
                // Expand after success, capped to avoid mesh blow-up.
                delta = (delta * self.options.expand_factor).min(self.options.delta_init * 100.0);
            } else {
                delta *= self.options.contract_factor;
            }
            nit += 1;
        }
        // Budget exhausted: report success only if the mesh was already
        // within an order of magnitude of the convergence tolerance.
        let success = delta < self.options.delta_min * 10.0;
        Ok(DfOptResult {
            x: Array1::from_vec(x),
            fun: fx,
            nfev,
            nit,
            success,
            message: if success {
                "Converged".to_string()
            } else {
                "Maximum iterations/evaluations reached".to_string()
            },
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Run `solver` on `f` from `x0`, panicking on solver error.
    fn run<F: Fn(&[f64]) -> f64>(
        solver: &PatternSearchSolver,
        f: F,
        x0: &[f64],
    ) -> DfOptResult {
        solver.minimize(f, x0).expect("optimization failed")
    }

    #[test]
    fn test_pattern_search_quadratic() {
        // Unconstrained 2-D quadratic with minimum at (2, 3).
        let solver = PatternSearchSolver::new();
        let result = run(
            &solver,
            |x| (x[0] - 2.0).powi(2) + (x[1] - 3.0).powi(2),
            &[0.0, 0.0],
        );
        assert_abs_diff_eq!(result.x[0], 2.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 3.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_bounded() {
        // Unconstrained minimum at (-2, -2) lies outside the box,
        // so the solver should converge to the corner (0, 0).
        let solver = PatternSearchSolver::with_options(PatternSearchOptions {
            lower: Some(vec![0.0, 0.0]),
            upper: Some(vec![5.0, 5.0]),
            delta_min: 1e-8,
            ..Default::default()
        });
        let result = run(
            &solver,
            |x| (x[0] + 2.0).powi(2) + (x[1] + 2.0).powi(2),
            &[1.0, 1.0],
        );
        assert_abs_diff_eq!(result.x[0], 0.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 0.0, epsilon = 1e-3);
    }

    #[test]
    fn test_pattern_search_compass_rose() {
        // Compass-rose pattern adds diagonal directions; same quadratic target.
        let solver = PatternSearchSolver::with_options(PatternSearchOptions {
            pattern: PatternType::CompassRose,
            delta_min: 1e-7,
            max_fev: 100000,
            ..Default::default()
        });
        let result = run(
            &solver,
            |x| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2),
            &[0.0, 0.0],
        );
        assert_abs_diff_eq!(result.x[0], 1.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], 2.0, epsilon = 1e-3);
    }

    #[test]
    fn test_pattern_search_non_opportunistic() {
        // Full polling (best-of-all directions) on a simple sphere function.
        let solver = PatternSearchSolver::with_options(PatternSearchOptions {
            opportunistic: false,
            delta_min: 1e-6,
            max_fev: 200000,
            ..Default::default()
        });
        let result = run(&solver, |x| x[0].powi(2) + x[1].powi(2), &[5.0, 5.0]);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_1d() {
        // One-dimensional quadratic with minimum at 7.
        let solver = PatternSearchSolver::new();
        let result = run(&solver, |x| (x[0] - 7.0).powi(2), &[0.0]);
        assert_abs_diff_eq!(result.x[0], 7.0, epsilon = 1e-3);
        assert_abs_diff_eq!(result.fun, 0.0, epsilon = 1e-4);
    }

    #[test]
    fn test_pattern_search_simplex_pattern() {
        // Simplex (n + 1 direction) pattern on a shifted quadratic.
        let solver = PatternSearchSolver::with_options(PatternSearchOptions {
            pattern: PatternType::Simplex,
            delta_min: 1e-7,
            ..Default::default()
        });
        let result = run(
            &solver,
            |x| (x[0] - 1.5).powi(2) + (x[1] + 0.5).powi(2),
            &[0.0, 0.0],
        );
        assert_abs_diff_eq!(result.x[0], 1.5, epsilon = 1e-3);
        assert_abs_diff_eq!(result.x[1], -0.5, epsilon = 1e-3);
    }
}