use basin::problems::Rosenbrock;
use basin::{
CostFunction, DenseMatrix, Executor, Gradient, GradientTolerance, QuasiNewtonState,
TerminationReason, BFGS,
};
/// Runs BFGS on the default `Rosenbrock` problem from the start point
/// `(-1.2, 1.0)` with a cap of 100 iterations, then checks that the final
/// cost is below `1e-8` and that both coordinates lie within `1e-4` of `1.0`.
#[test]
fn bfgs_converges_on_rosenbrock() {
    let start = vec![-1.2, 1.0];
    let state = QuasiNewtonState::<Vec<f64>, DenseMatrix>::new(start);

    // Keep the builder chain as one expression so it behaves exactly like a
    // run on a temporary executor.
    let outcome = Executor::new(Rosenbrock::<Vec<f64>>::default(), BFGS::new(), state)
        .max_iter(100)
        .run();

    let final_cost = outcome.cost();
    assert!(
        final_cost < 1e-8,
        "expected near-zero cost, got {}",
        final_cost
    );

    // Each coordinate is read only once the preceding assertion has passed.
    let x0 = outcome.param()[0];
    assert!((x0 - 1.0).abs() < 1e-4, "x[0] = {}", x0);

    let x1 = outcome.param()[1];
    assert!((x1 - 1.0).abs() < 1e-4, "x[1] = {}", x1);
}
/// Runs BFGS on the default `Rosenbrock` problem with a
/// `GradientTolerance(1e-6)` stopping criterion and a cap of 200 iterations.
///
/// The run must report `TerminationReason::GradientTolerance` as its stop
/// reason, and the final cost must be below `1e-10`.
#[test]
fn bfgs_terminates_on_gradient_tolerance() {
    let start = vec![-1.2, 1.0];
    let state = QuasiNewtonState::<Vec<f64>, DenseMatrix>::new(start);

    let outcome = Executor::new(Rosenbrock::<Vec<f64>>::default(), BFGS::new(), state)
        .max_iter(200)
        .terminate_on(GradientTolerance(1e-6))
        .run();

    // The stop reason is checked first: a run that only hit the iteration cap
    // should fail here rather than on the cost bound.
    assert_eq!(outcome.reason, TerminationReason::GradientTolerance);

    let final_cost = outcome.cost();
    assert!(final_cost < 1e-10, "cost = {}", final_cost);
}
/// Separable quadratic test objective used by the tests in this file.
///
/// With `d = diag`, the trait impls in this file evaluate
/// `f(x) = sum_i (0.5 * d[i] * x[i]^2 - x[i])` and its gradient
/// `g(x)[i] = d[i] * x[i] - 1`.
struct Quadratic {
    /// Per-coordinate curvature coefficients `d[i]`, indexed by the position
    /// of the matching entry in the parameter vector.
    diag: Vec<f64>,
}
impl CostFunction for Quadratic {
    type Param = Vec<f64>;
    type Output = f64;

    /// Evaluates `sum_i (0.5 * diag[i] * x[i]^2 - x[i])` over the entries of `x`.
    ///
    /// Panics if `x` has more entries than `diag`; surplus `diag` entries are
    /// ignored.
    fn cost(&self, x: &Vec<f64>) -> f64 {
        // Contribution of a single coordinate to the objective.
        let term = |idx: usize| {
            let xi = x[idx];
            0.5 * self.diag[idx] * xi * xi - xi
        };
        (0..x.len()).map(term).sum()
    }
}
impl Gradient for Quadratic {
    type Param = Vec<f64>;
    type Gradient = Vec<f64>;

    /// Returns the vector whose `i`-th entry is `diag[i] * x[i] - 1`, one
    /// entry per element of `x`.
    ///
    /// Panics if `x` has more entries than `diag`; surplus `diag` entries are
    /// ignored.
    fn gradient(&self, x: &Vec<f64>) -> Vec<f64> {
        let mut grad = Vec::with_capacity(x.len());
        for (idx, xi) in x.iter().enumerate() {
            grad.push(self.diag[idx] * xi - 1.0);
        }
        grad
    }
}
/// Runs BFGS from the origin on a 5-dimensional `Quadratic` with diagonal
/// `[1, 2, 3, 4, 5]`, using a `GradientTolerance(1e-8)` stopping criterion
/// and a cap of 50 iterations.
///
/// Setting the gradient `diag[i] * x[i] - 1` to zero gives
/// `x[i] = 1 / diag[i]`, so the minimum cost is `-0.5 * sum_i (1 / diag[i])`.
/// The test checks the stop reason, that the final cost is within `1e-10` of
/// that value, and that at most 15 iterations were used.
#[test]
fn bfgs_on_5d_quadratic_converges_quickly() {
    let diag = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    // Closed-form minimum, computed before `diag` is moved into the problem.
    let expected_cost = -0.5 * diag.iter().map(|d| 1.0 / d).sum::<f64>();

    let state = QuasiNewtonState::<Vec<f64>, DenseMatrix>::new(vec![0.0; 5]);
    let outcome = Executor::new(Quadratic { diag }, BFGS::new(), state)
        .max_iter(50)
        .terminate_on(GradientTolerance(1e-8))
        .run();

    assert_eq!(outcome.reason, TerminationReason::GradientTolerance);

    let final_cost = outcome.cost();
    assert!(
        (final_cost - expected_cost).abs() < 1e-10,
        "cost = {}, expected {}",
        final_cost,
        expected_cost
    );

    let iterations = outcome.iter();
    assert!(
        iterations <= 15,
        "expected convergence in ≤ 15 iters, got {}",
        iterations
    );
}