use std::cell::Cell;
use std::rc::Rc;
use basin::{
BasicState, CostFunction, Executor, FiniteDiff, Gradient, GradientDescent, GradientTolerance,
MaxIter,
};
/// Test objective whose cost is the sum of squares of the parameter vector.
///
/// It defines only the plain `cost` (and, separately, `gradient`) methods and
/// does not override any fused evaluation.
struct Sphere;

impl CostFunction for Sphere {
    type Param = Vec<f64>;
    type Output = f64;
    type Error = std::convert::Infallible;

    /// Returns `Σ xᵢ²` over the components of `x`; this evaluation cannot fail.
    fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
        let sum_of_squares: f64 = x.iter().map(|&xi| xi * xi).sum();
        Ok(sum_of_squares)
    }
}
impl Gradient for Sphere {
    type Gradient = Vec<f64>;

    /// Returns the analytic gradient of the sum of squares: `2 * xᵢ` per component.
    fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, std::convert::Infallible> {
        let doubled: Vec<f64> = x.iter().map(|&xi| 2.0 * xi).collect();
        Ok(doubled)
    }
}
/// `Sphere` does not define `cost_and_gradient` itself, so this checks that the
/// version it gets agrees exactly with calling `cost` and `gradient` one by one.
#[test]
fn default_fused_matches_separate_calls() {
    let problem = Sphere;
    let point = vec![1.5, -2.25, 0.5];

    // Reference values from the two stand-alone evaluations.
    let separate_cost = problem.cost(&point).unwrap();
    let separate_grad = problem.gradient(&point).unwrap();

    let (fused_cost, fused_grad) = problem.cost_and_gradient(&point).unwrap();

    assert_eq!(fused_cost, separate_cost);
    assert_eq!(fused_grad, separate_grad);
}
/// Runs gradient descent on `Sphere`, which supplies no fused override, and
/// checks that the final cost drops below `1e-8`.
#[test]
fn gradient_descent_runs_with_no_opt_in() {
    // Step size 0.1, starting point (1, 1); stop on 200 iterations or a
    // gradient tolerance of 1e-10.
    let executor = Executor::new(
        Sphere,
        GradientDescent::new(0.1),
        BasicState::new(vec![1.0, 1.0]),
    )
    .terminate_on(MaxIter(200))
    .terminate_on(GradientTolerance(1e-10));

    let result = executor.run().unwrap();

    let final_cost = result.cost();
    assert!(final_cost < 1e-8, "got cost {}", final_cost);
}
/// Sum-of-squares objective that carries a shared counter so tests can observe
/// how often its fused `cost_and_gradient` override is invoked.
struct Counted {
    /// Shared call counter; the test keeps another handle to read it back.
    fused_calls: Rc<Cell<usize>>,
}

impl CostFunction for Counted {
    type Param = Vec<f64>;
    type Output = f64;
    type Error = std::convert::Infallible;

    /// Returns `Σ xᵢ²`; does not touch the call counter.
    fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
        let sum_of_squares: f64 = x.iter().map(|&xi| xi * xi).sum();
        Ok(sum_of_squares)
    }
}
impl Gradient for Counted {
    type Gradient = Vec<f64>;

    /// Returns `2 * xᵢ` per component; does not touch the call counter.
    fn gradient(&self, x: &Vec<f64>) -> Result<Vec<f64>, std::convert::Infallible> {
        let doubled: Vec<f64> = x.iter().map(|&xi| 2.0 * xi).collect();
        Ok(doubled)
    }

    /// Fused override: increments `fused_calls`, then returns the cost and the
    /// gradient computed with the same formulas as `cost` and `gradient`.
    fn cost_and_gradient(
        &self,
        x: &Vec<f64>,
    ) -> Result<(f64, Vec<f64>), std::convert::Infallible> {
        // Record the call before doing any work.
        let seen_so_far = self.fused_calls.get();
        self.fused_calls.set(seen_so_far + 1);

        let cost: f64 = x.iter().map(|&xi| xi * xi).sum();
        let grad: Vec<f64> = x.iter().map(|&xi| 2.0 * xi).collect();
        Ok((cost, grad))
    }
}
/// Runs gradient descent on `Counted` and checks that the fused override was
/// invoked at least `result.iter() + 1` times and that the run converged.
///
/// NOTE(review): the `+ 1` presumably accounts for one evaluation beyond the
/// per-iteration ones (e.g. at initialisation) — confirm against the solver.
#[test]
fn solver_calls_fused_override() {
    let counter = Rc::new(Cell::new(0_usize));
    let problem = Counted {
        fused_calls: Rc::clone(&counter),
    };

    let result = Executor::new(
        problem,
        GradientDescent::new(0.1),
        BasicState::new(vec![1.0, 1.0]),
    )
    .terminate_on(MaxIter(200))
    .terminate_on(GradientTolerance(1e-10))
    .run()
    .unwrap();

    // Lower bound on how many fused evaluations the run must have made.
    let expected_min = result.iter() as usize + 1;
    let calls = counter.get();
    assert!(
        calls >= expected_min,
        "expected at least {expected_min} fused calls, got {calls}"
    );

    assert!(result.cost() < 1e-8, "got cost {}", result.cost());
}
/// Tests for the fused cost/gradient/Hessian evaluation on nalgebra types.
/// Only compiled when the `nalgebra` feature is enabled.
#[cfg(feature = "nalgebra")]
mod hessian {
    use super::*;
    use basin::Hessian;
    use nalgebra::{DMatrix, DVector};

    /// Sum-of-squares objective on `DVector<f64>` that defines only the
    /// separate `cost`, `gradient` and `hessian` methods.
    struct SphereN;

    impl CostFunction for SphereN {
        type Param = DVector<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        /// Returns the dot product of `x` with itself.
        fn cost(&self, x: &DVector<f64>) -> Result<f64, std::convert::Infallible> {
            let squared_norm = x.dot(x);
            Ok(squared_norm)
        }
    }

    impl Gradient for SphereN {
        type Gradient = DVector<f64>;

        /// Returns `2 * x`.
        fn gradient(&self, x: &DVector<f64>) -> Result<DVector<f64>, std::convert::Infallible> {
            let grad = 2.0 * x;
            Ok(grad)
        }
    }

    impl Hessian for SphereN {
        type Hessian = DMatrix<f64>;

        /// Returns twice the identity matrix, sized by the length of `x`.
        fn hessian(&self, x: &DVector<f64>) -> Result<DMatrix<f64>, std::convert::Infallible> {
            let dim = x.len();
            Ok(2.0 * DMatrix::<f64>::identity(dim, dim))
        }
    }

    /// `SphereN` does not define `cost_and_gradient_and_hessian` itself; the
    /// version it gets must agree exactly with the three separate evaluations.
    #[test]
    fn default_triple_matches_separate_calls() {
        let problem = SphereN;
        let point = DVector::from_vec(vec![1.0, -2.0, 3.0]);

        let (fused_cost, fused_grad, fused_hess) =
            problem.cost_and_gradient_and_hessian(&point).unwrap();

        assert_eq!(fused_cost, problem.cost(&point).unwrap());
        assert_eq!(fused_grad, problem.gradient(&point).unwrap());
        assert_eq!(fused_hess, problem.hessian(&point).unwrap());
    }

    /// Same objective as `SphereN`, plus a shared counter incremented by its
    /// `cost_and_gradient_and_hessian` override.
    struct CountedTriple {
        /// Shared call counter; the test keeps another handle to read it back.
        fused_calls: Rc<Cell<usize>>,
    }

    impl CostFunction for CountedTriple {
        type Param = DVector<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        /// Returns the dot product of `x` with itself.
        fn cost(&self, x: &DVector<f64>) -> Result<f64, std::convert::Infallible> {
            let squared_norm = x.dot(x);
            Ok(squared_norm)
        }
    }

    impl Gradient for CountedTriple {
        type Gradient = DVector<f64>;

        /// Returns `2 * x`.
        fn gradient(&self, x: &DVector<f64>) -> Result<DVector<f64>, std::convert::Infallible> {
            let grad = 2.0 * x;
            Ok(grad)
        }
    }

    impl Hessian for CountedTriple {
        type Hessian = DMatrix<f64>;

        /// Returns twice the identity matrix, sized by the length of `x`.
        fn hessian(&self, x: &DVector<f64>) -> Result<DMatrix<f64>, std::convert::Infallible> {
            let dim = x.len();
            Ok(2.0 * DMatrix::<f64>::identity(dim, dim))
        }

        /// Fused override: increments `fused_calls`, then returns cost,
        /// gradient and Hessian computed with the same formulas as the
        /// separate methods.
        fn cost_and_gradient_and_hessian(
            &self,
            x: &DVector<f64>,
        ) -> Result<(f64, DVector<f64>, DMatrix<f64>), std::convert::Infallible> {
            // Record the call before doing any work.
            let seen_so_far = self.fused_calls.get();
            self.fused_calls.set(seen_so_far + 1);

            let dim = x.len();
            let cost = x.dot(x);
            let grad = 2.0 * x;
            let hess = 2.0 * DMatrix::<f64>::identity(dim, dim);
            Ok((cost, grad, hess))
        }
    }

    /// Calling the fused method twice must bump the counter exactly twice,
    /// i.e. the override (not some other implementation) is what runs.
    #[test]
    fn override_triple_is_called() {
        let counter = Rc::new(Cell::new(0_usize));
        let problem = CountedTriple {
            fused_calls: Rc::clone(&counter),
        };
        let point = DVector::from_vec(vec![1.0, 2.0, 3.0]);

        for _ in 0..2 {
            let _ = problem.cost_and_gradient_and_hessian(&point).unwrap();
        }

        assert_eq!(counter.get(), 2);
    }
}
/// Least-squares test: Levenberg–Marquardt on a problem with a fused
/// `residual_and_jacobian` override. Only compiled when the `nalgebra`
/// feature is enabled.
#[cfg(feature = "nalgebra")]
mod lsq {
    use super::*;
    use basin::{Jacobian, LevenbergMarquardt, NllsState, Residual};
    use nalgebra::{DMatrix, DVector};

    /// Two-parameter problem with residuals `(x0 - 1, x1 - 2)` and a shared
    /// counter incremented by its fused `residual_and_jacobian` override.
    struct Affine {
        /// Shared call counter; the test keeps another handle to read it back.
        fused_calls: Rc<Cell<usize>>,
    }

    impl CostFunction for Affine {
        type Param = DVector<f64>;
        type Output = f64;
        type Error = std::convert::Infallible;

        /// Returns half the sum of the squared residuals.
        fn cost(&self, x: &DVector<f64>) -> Result<f64, std::convert::Infallible> {
            let r0 = x[0] - 1.0;
            let r1 = x[1] - 2.0;
            Ok(0.5 * (r0.powi(2) + r1.powi(2)))
        }
    }

    impl Residual for Affine {
        type Param = DVector<f64>;
        type Output = DVector<f64>;
        type Error = std::convert::Infallible;

        /// Returns the residual vector `(x0 - 1, x1 - 2)`.
        fn residual(&self, x: &DVector<f64>) -> Result<DVector<f64>, std::convert::Infallible> {
            let residual = DVector::from_vec(vec![x[0] - 1.0, x[1] - 2.0]);
            Ok(residual)
        }
    }

    impl Jacobian for Affine {
        type Jacobian = DMatrix<f64>;

        /// Returns the 2×2 identity matrix; the argument is unused.
        fn jacobian(&self, _x: &DVector<f64>) -> Result<DMatrix<f64>, std::convert::Infallible> {
            Ok(DMatrix::identity(2, 2))
        }

        /// Fused override: increments `fused_calls`, then returns the same
        /// residual and Jacobian as the separate methods.
        fn residual_and_jacobian(
            &self,
            x: &DVector<f64>,
        ) -> Result<(DVector<f64>, DMatrix<f64>), std::convert::Infallible> {
            // Record the call before doing any work.
            let seen_so_far = self.fused_calls.get();
            self.fused_calls.set(seen_so_far + 1);

            let residual = DVector::from_vec(vec![x[0] - 1.0, x[1] - 2.0]);
            let jacobian = DMatrix::identity(2, 2);
            Ok((residual, jacobian))
        }
    }

    /// Runs Levenberg–Marquardt from the origin for at most 10 iterations and
    /// checks that the fused override was used at least once and that the
    /// result is within `1e-8` of `(1, 2)` in each coordinate.
    #[test]
    fn lm_calls_fused_residual_and_jacobian() {
        let counter = Rc::new(Cell::new(0_usize));
        let problem = Affine {
            fused_calls: Rc::clone(&counter),
        };
        let state = NllsState::new(DVector::from_vec(vec![0.0, 0.0]));
        let solver: LevenbergMarquardt<DVector<f64>, DMatrix<f64>> = LevenbergMarquardt::new();

        let result = Executor::new(problem, solver, state)
            .terminate_on(MaxIter(10))
            .run()
            .unwrap();

        // Report the observed values on failure so a regression is diagnosable
        // from the test output alone.
        let calls = counter.get();
        assert!(calls >= 1, "expected at least 1 fused call, got {calls}");

        let x = result.param();
        assert!((x[0] - 1.0).abs() < 1e-8, "expected x[0] near 1, got {}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-8, "expected x[1] near 2, got {}", x[1]);
    }
}
/// Objective that implements `CostFunction` only — no gradient is provided.
struct CostOnly;

impl CostFunction for CostOnly {
    type Param = Vec<f64>;
    type Output = f64;
    type Error = std::convert::Infallible;

    /// Returns `Σ (xᵢ - 1)²` over the components of `x`.
    fn cost(&self, x: &Vec<f64>) -> Result<f64, std::convert::Infallible> {
        let total: f64 = x
            .iter()
            .map(|&xi| {
                let offset = xi - 1.0;
                offset.powi(2)
            })
            .sum();
        Ok(total)
    }
}
/// Wraps the gradient-free `CostOnly` problem in `FiniteDiff`, runs gradient
/// descent from the origin, and checks that both coordinates end up within
/// `1e-3` of 1.
///
/// NOTE(review): `FiniteDiff` presumably supplies a numerically approximated
/// gradient, which would explain the looser tolerance — confirm.
#[test]
fn finite_diff_runs_through_solver() {
    let problem = FiniteDiff::new(CostOnly);
    let solver = GradientDescent::new(0.1);
    let state = BasicState::new(vec![0.0, 0.0]);

    let result = Executor::new(problem, solver, state)
        .terminate_on(MaxIter(200))
        .terminate_on(GradientTolerance(1e-8))
        .run()
        .unwrap();

    // Print the offending coordinate on failure so the test output shows how
    // far off the solver ended up.
    let x = result.param();
    assert!((x[0] - 1.0).abs() < 1e-3, "expected x[0] near 1, got {}", x[0]);
    assert!((x[1] - 1.0).abs() < 1e-3, "expected x[1] near 1, got {}", x[1]);
}