use argmin::core::observers::{ObserverMode, SlogLogger};
use argmin::core::{CostFunction, Error, Executor, Gradient};
use argmin::solver::linesearch::MoreThuenteLineSearch;
use argmin::solver::quasinewton::LBFGS;
use argmin_testfunctions::{rosenbrock_2d, rosenbrock_2d_derivative};
use nalgebra::DVector;
/// Rosenbrock optimisation problem for argmin.
///
/// `a` and `b` are the two shape parameters forwarded unchanged to
/// `rosenbrock_2d` (cost) and `rosenbrock_2d_derivative` (gradient).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Rosenbrock {
    /// First shape parameter passed to the test function.
    a: f64,
    /// Second shape parameter passed to the test function.
    b: f64,
}
/// Cost evaluation: the Rosenbrock value at a parameter vector, using this
/// instance's `a` and `b`.
impl CostFunction for Rosenbrock {
    type Param = DVector<f64>;
    type Output = f64;

    /// Returns `rosenbrock_2d(p, a, b)`.
    ///
    /// NOTE(review): `rosenbrock_2d` presumably expects `p` to hold exactly two
    /// elements — confirm against argmin_testfunctions.
    ///
    /// # Errors
    /// This implementation always returns `Ok`; the `Result` is dictated by the
    /// `CostFunction` trait.
    fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
        // `as_slice` is nalgebra's public view of the contiguous data; there is
        // no need to reach into the storage field (`p.data`).
        Ok(rosenbrock_2d(p.as_slice(), self.a, self.b))
    }
}
/// Gradient evaluation: the analytic Rosenbrock gradient at a parameter
/// vector, using this instance's `a` and `b`.
impl Gradient for Rosenbrock {
    type Param = DVector<f64>;
    type Gradient = DVector<f64>;

    /// Returns `rosenbrock_2d_derivative(p, a, b)` converted into a `DVector`.
    ///
    /// # Errors
    /// This implementation always returns `Ok`; the `Result` is dictated by the
    /// `Gradient` trait.
    fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
        // Borrow the vector's data through the public slice accessor instead
        // of nalgebra's internal storage field.
        let grad = rosenbrock_2d_derivative(p.as_slice(), self.a, self.b);
        Ok(DVector::from(grad))
    }
}
/// Minimises the Rosenbrock function (a = 1, b = 100) with L-BFGS, starting
/// from (-1.2, 1.0) and capped at 100 iterations, then prints the final state.
///
/// # Errors
/// Propagates any error from configuring the line search or from running the
/// executor.
fn run() -> Result<(), Error> {
    // Problem definition and starting point.
    let problem = Rosenbrock { a: 1.0, b: 100.0 };
    let start: DVector<f64> = DVector::from_vec(vec![-1.2, 1.0]);

    // More–Thuente line search configured via `with_c` (presumably the
    // sufficient-decrease / curvature constants), feeding an L-BFGS solver
    // whose second argument (7) is presumably the history length — confirm.
    let line_search = MoreThuenteLineSearch::new().with_c(1e-4, 0.9)?;
    let lbfgs = LBFGS::new(line_search, 7);

    // Run the optimisation, logging every iteration to the terminal.
    let outcome = Executor::new(problem, lbfgs)
        .configure(|state| state.param(start).max_iters(100))
        .add_observer(SlogLogger::term(), ObserverMode::Always)
        .run()?;

    // NOTE(review): presumably gives the terminal logger time to flush its
    // output before the result is printed — confirm.
    std::thread::sleep(std::time::Duration::from_secs(1));

    println!("{outcome}");
    Ok(())
}
/// Program entry point: runs the optimisation and, on failure, reports the
/// error and exits with status 1.
fn main() {
    if let Err(e) = run() {
        // Diagnostics go to stderr so they are not mixed into the result that
        // the optimisation prints to stdout.
        eprintln!("{e}");
        std::process::exit(1);
    }
}