use minuit2::{MnMigrad, MnSimplex};
/// Checks that `MnMigrad` still reaches the minimum of `x²` (at `x = 0`) when the
/// objective returns NaN for every `x > 5`, starting from `x = 4` with step `1`.
#[test]
fn nan_resilience() {
    // Quadratic bowl that is "undefined" (NaN) beyond x = 5.
    let objective = |p: &[f64]| {
        let x = p[0];
        if x > 5.0 {
            f64::NAN
        } else {
            x * x
        }
    };

    let minimum = MnMigrad::new().add("x", 4.0, 1.0).minimize(&objective);

    assert!(minimum.is_valid());
    assert!(minimum.fval() < 1e-4);
}
/// Checks that `MnMigrad` still reaches the minimum of `x²` (at `x = 0`) when the
/// objective returns +infinity for every `x > 5`, starting from `x = 4` with step `1`.
#[test]
fn inf_resilience() {
    // Quadratic bowl with an infinite "wall" beyond x = 5.
    let objective = |p: &[f64]| {
        let x = p[0];
        if x > 5.0 {
            f64::INFINITY
        } else {
            x * x
        }
    };

    let minimum = MnMigrad::new().add("x", 4.0, 1.0).minimize(&objective);

    assert!(minimum.is_valid());
    assert!(minimum.fval() < 1e-4);
}
/// Minimizes the 50-dimensional sum of squares with `MnMigrad` (call budget 10000).
/// Parameter `xi` starts at `i` with step `0.1`. Every fitted value must end up
/// within `1e-2` of zero, and the function value must be below `1e-4`.
#[test]
fn high_dim_stress() {
    const DIM: usize = 50;

    // Parameter names are "x0", "x1", ..., shared by registration and lookup.
    let param_name = |i: usize| format!("x{}", i);

    let configured = (0..DIM).fold(MnMigrad::new(), |acc, i| {
        acc.add(param_name(i), i as f64, 0.1)
    });

    let sphere = |p: &[f64]| -> f64 { p.iter().map(|x| x * x).sum() };
    let result = configured.max_fcn(10000).minimize(&sphere);

    assert!(result.is_valid());
    assert!(result.fval() < 1e-4);

    let state = result.user_state();
    for i in 0..DIM {
        let fitted = state.value(&param_name(i)).unwrap();
        assert!(fitted.abs() < 1e-2);
    }
}
/// Minimizes the Goldstein–Price function with `MnSimplex`, starting from `(0.5, -0.5)`.
/// The fitted function value must be within `1e-4` of `3`, and the fitted point must
/// be within `0.1` of `(0, -1)` in each coordinate.
#[test]
fn goldstein_price() {
    let gp = |p: &[f64]| {
        let (x, y) = (p[0], p[1]);
        // First factor: 1 + (x + y + 1)^2 * q1(x, y)
        let s1 = (x + y + 1.0).powi(2);
        let q1 = 19.0 - 14.0 * x + 3.0 * x * x - 14.0 * y + 6.0 * x * y + 3.0 * y * y;
        // Second factor: 30 + (2x - 3y)^2 * q2(x, y)
        let s2 = (2.0 * x - 3.0 * y).powi(2);
        let q2 = 18.0 - 32.0 * x + 12.0 * x * x + 48.0 * y - 36.0 * x * y + 27.0 * y * y;
        (1.0 + s1 * q1) * (30.0 + s2 * q2)
    };

    let result = MnSimplex::new()
        .add("x", 0.5, 0.5)
        .add("y", -0.5, 0.5)
        .tolerance(0.0001)
        .max_fcn(5000)
        .minimize(&gp);

    assert!(result.is_valid());
    assert!((result.fval() - 3.0).abs() < 1e-4);

    let params = result.params();
    assert!(params[0].abs() < 0.1);
    assert!((params[1] + 1.0).abs() < 0.1);
}
/// Fits `(x - 5)²` with `x` limited to `[0, 5]`, so the minimum sits exactly on the
/// upper limit. The fit is run twice, starting on the limit (`5.0`) and just inside it
/// (`4.999999`). Each result must be valid and land within `1e-4` of `5`.
#[test]
fn boundary_edge_case() {
    let at_upper_limit = |p: &[f64]| (p[0] - 5.0).powi(2);

    // Runs in the same order as the original: on the bound first, then just inside.
    for start in [5.0, 4.999999] {
        let result = MnMigrad::new()
            .add_limited("x", start, 0.1, 0.0, 5.0)
            .minimize(&at_upper_limit);

        assert!(result.is_valid());
        assert!((result.params()[0] - 5.0).abs() < 1e-4);
    }
}
/// Minimizes the Rosenbrock function `(1 - x)² + 100 (y - x²)²` with `MnMigrad`.
/// The fit starts from `(-1.2, 1.0)` with tolerance `0.1`. The function value must be
/// below `1e-4`, and both coordinates must be within `1e-2` of `1`.
#[test]
fn rosenbrock_hard_start() {
    let rosenbrock = |p: &[f64]| {
        let (x, y) = (p[0], p[1]);
        (1.0 - x).powi(2) + 100.0 * (y - x * x).powi(2)
    };

    let result = MnMigrad::new()
        .add("x", -1.2, 0.1)
        .add("y", 1.0, 0.1)
        .tolerance(0.1)
        .minimize(&rosenbrock);

    assert!(result.is_valid());
    assert!(result.fval() < 1e-4);

    let best = result.params();
    assert!((best[0] - 1.0).abs() < 1e-2);
    assert!((best[1] - 1.0).abs() < 1e-2);
}