use argmin::core::observers::{ObserverMode, SlogLogger};
use argmin::core::{CostFunction, Error, Executor, Gradient};
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::SR1;
use argmin_testfunctions::styblinski_tang;
use finitediff::FiniteDiff;
use ndarray::{array, Array1, Array2};
/// Stateless problem type for minimizing the Styblinski-Tang test function
/// with argmin.
///
/// It carries no data: cost and gradient evaluations are delegated to
/// `argmin_testfunctions::styblinski_tang`.
struct StyblinskiTang;
/// Objective value of the Styblinski-Tang function for argmin solvers.
impl CostFunction for StyblinskiTang {
    type Param = Array1<f64>;
    type Output = f64;

    /// Evaluates the Styblinski-Tang function at `p`.
    ///
    /// The parameter vector is copied into a contiguous `Vec<f64>` because the
    /// test function takes a plain slice. This method never returns `Err`.
    fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
        let coords: Vec<f64> = p.iter().copied().collect();
        let value = styblinski_tang(&coords);
        Ok(value)
    }
}
/// Numerically approximated gradient of the Styblinski-Tang function.
impl Gradient for StyblinskiTang {
    type Param = Array1<f64>;
    type Gradient = Array1<f64>;

    /// Approximates the gradient at `p` with forward finite differences
    /// (`FiniteDiff::forward_diff`) of the Styblinski-Tang function.
    ///
    /// This method never returns `Err`.
    fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
        // Scalar objective handed to the finite-difference routine; each
        // evaluation copies the probe point into a slice for the test function.
        let objective = |x: &Array1<f64>| -> f64 { styblinski_tang(&x.to_vec()) };
        let grad = p.forward_diff(&objective);
        Ok(grad)
    }
}
/// Minimizes the 8-dimensional Styblinski-Tang function with argmin's SR1
/// quasi-Newton solver, then prints the final optimization result.
///
/// # Errors
///
/// Returns any error raised while configuring the More-Thuente line search
/// (`with_c`) or while running the executor.
fn run() -> Result<(), Error> {
    // Problem definition.
    let cost = StyblinskiTang {};

    // Starting point: every coordinate at 5.0.
    let init_param: Array1<f64> = array![5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0];

    // Initial inverse-Hessian approximation: the identity, sized from the
    // starting point so the two cannot drift apart if the dimension changes.
    let init_hessian: Array2<f64> = Array2::eye(init_param.len());

    // NOTE(review): (1e-4, 0.9) are presumably the Wolfe-condition constants
    // (c1, c2) of the More-Thuente line search; confirm against argmin docs.
    let linesearch = MoreThuenteLineSearch::new().with_c(1e-4, 0.9)?;
    let solver = SR1::new(linesearch);

    let res = Executor::new(cost, solver)
        .configure(|state| {
            state
                .param(init_param)
                .inv_hessian(init_hessian)
                .max_iters(1000)
        })
        // Terminal logger, observing with `ObserverMode::Always`.
        .add_observer(SlogLogger::term(), ObserverMode::Always)
        .run()?;

    // NOTE(review): presumably gives the terminal logger time to flush its
    // output before the final result is printed; confirm it is still needed.
    std::thread::sleep(std::time::Duration::from_secs(1));

    println!("{res}");
    Ok(())
}
/// Program entry point: runs the optimization and exits with status 1 if it
/// fails.
fn main() {
    if let Err(e) = run() {
        // Report failures on stderr so they are not mixed into the result
        // output that `run` prints on stdout.
        eprintln!("{e}");
        std::process::exit(1);
    }
}