quadratic_with_plots/
quadratic_with_plots.rs

1use nalgebra::{DMatrix, DVector};
2use optimization_solvers::{
3    BackTracking, FuncEvalMultivariate, GradientDescent, LineSearchSolver, Plotter3d, Tracer,
4};
5
6fn main() {
7    // Setting up log verbosity and _.
8    std::env::set_var("RUST_LOG", "debug");
9    let _ = Tracer::default().with_normal_stdout_layer().build();
10    // Setting up the oracle
11    let matrix = DMatrix::from_vec(2, 2, vec![100., 0., 0., 100.]);
12    let mut f_and_g = |x: &DVector<f64>| -> FuncEvalMultivariate {
13        let f = x.dot(&(&matrix * x));
14        let g = 2. * &matrix * x;
15        FuncEvalMultivariate::new(f, g)
16    };
17    // Setting up the line search
18    let armijo_factr = 1e-4;
19    let beta = 0.5; // (beta in (0, 1), ntice that beta = 0.5 corresponds to bisection)
20    let mut ls = BackTracking::new(armijo_factr, beta);
21    // Setting up the main solver, with its parameters and the initial guess
22    let tol = 1e-6;
23    let x0 = DVector::from_vec(vec![10., 10.]);
24    let mut solver = GradientDescent::new(tol, x0);
25    // We define a callback to store iterates and function evaluations
26    let mut iterates = vec![];
27    let mut solver_callback = |s: &GradientDescent| {
28        iterates.push(s.x().clone());
29    };
30    // Running the solver
31    let max_iter_solver = 100;
32    let max_iter_line_search = 10;
33
34    solver
35        .minimize(
36            &mut ls,
37            f_and_g,
38            max_iter_solver,
39            max_iter_line_search,
40            Some(&mut solver_callback),
41        )
42        .unwrap();
43    // Printing the result
44    let x = solver.x();
45    let eval = f_and_g(x);
46    println!("x: {:?}", x);
47    println!("f(x): {}", eval.f());
48    println!("g(x): {:?}", eval.g());
49
50    // Plotting the iterates
51    let n = 50;
52    let start = -5.0;
53    let end = 5.0;
54    let plotter = Plotter3d::new(start, end, start, end, n)
55        .append_plot(&mut f_and_g, "Objective function", 0.5)
56        .append_scatter_points(&mut f_and_g, &iterates, "Iterates")
57        .set_layout_size(1600, 1000);
58    plotter.build("quadratic.html");
59}