use std::ops::Mul;
use faer::sparse::{SparseColMat, Triplet};
use faer_ext::IntoNalgebra;
use super::{BaseOptParams, OptError, OptObserverVec, OptParams, OptResult, Optimizer};
use crate::{
containers::{Graph, GraphOrder, Values, ValuesOrder},
dtype,
linalg::DiffResult,
linear::{LinearSolver, LinearValues},
};
/// Tuning parameters for the Levenberg-Marquardt optimizer.
#[derive(Clone, Debug)]
pub struct LevenParams {
/// Floor for the damping parameter lambda; lambda is clamped up to this
/// after a successful step shrinks it.
pub lambda_min: dtype,
/// Ceiling for lambda; if damping grows past this without an acceptable
/// step, the step fails with `OptError::FailedToStep`.
pub lambda_max: dtype,
/// Multiplicative factor: lambda is multiplied by this on a rejected step
/// and divided by it on an accepted one.
pub lambda_factor: dtype,
/// Minimum gain ratio (actual error reduction / linearized-model
/// reduction) required to accept a step.
pub min_model_fidelity: dtype,
/// If true, damp with the diagonal of J^T J (Marquardt-style scaling);
/// if false, damp with the identity matrix.
pub diagonal_damping: bool,
/// Parameters common to all optimizers.
pub base: BaseOptParams,
}
impl Default for LevenParams {
fn default() -> Self {
Self {
lambda_min: 1e-20,
lambda_max: 1e5,
min_model_fidelity: 1e-3,
lambda_factor: 10.0,
diagonal_damping: true,
base: Default::default(),
}
}
}
impl OptParams for LevenParams {
/// Exposes the shared optimizer parameters embedded in this struct.
fn base_params(&self) -> &BaseOptParams {
&self.base
}
}
/// Levenberg-Marquardt nonlinear least-squares optimizer.
///
/// Each step solves the damped normal equations
/// `(J^T J + lambda * D) dx = J^T r`, adapting the damping `lambda`
/// until a step passes the model-fidelity (gain ratio) test.
pub struct LevenMarquardt {
/// Factor graph whose total error is being minimized.
graph: Graph,
/// Linear solver used for the damped normal equations (swappable via
/// `set_solver`).
solver: Box<dyn LinearSolver>,
/// Tuning parameters (lambda bounds/factor, fidelity threshold, damping mode).
params: LevenParams,
/// Observers notified during optimization.
observers: OptObserverVec,
/// Current damping parameter; persists across steps so damping adapts
/// over the whole optimization, not per-step.
lambda: dtype,
/// Sparsity pattern / variable ordering cached by `init`; `None` until then.
graph_order: Option<GraphOrder>,
}
impl LevenMarquardt {
/// Replaces the linear solver used to solve the damped normal equations.
pub fn set_solver(&mut self, solver: impl LinearSolver + 'static) {
self.solver = Box::new(solver);
}
}
impl Optimizer for LevenMarquardt {
    type Params = LevenParams;

    /// Builds an optimizer over `graph` with the default linear solver.
    fn new(params: Self::Params, graph: Graph) -> Self {
        Self {
            graph,
            solver: Default::default(),
            observers: OptObserverVec::default(),
            params,
            // Start almost undamped (close to Gauss-Newton); `step` adapts
            // lambda from here.
            lambda: 1e-10,
            graph_order: None,
        }
    }

    fn observers(&self) -> &OptObserverVec {
        &self.observers
    }

    fn observers_mut(&mut self) -> &mut OptObserverVec {
        &mut self.observers
    }

    fn graph(&self) -> &Graph {
        &self.graph
    }

    fn graph_mut(&mut self) -> &mut Graph {
        &mut self.graph
    }

    fn params(&self) -> &BaseOptParams {
        &self.params.base
    }

    /// Total graph error at `values`.
    fn error(&self, values: &Values) -> crate::dtype {
        self.graph.error(values)
    }

    /// Caches the graph's sparsity pattern for the ordering derived from
    /// `values`, and returns the extra column headers (`Lambda`, `Fidelity`)
    /// this optimizer reports each iteration.
    fn init(&mut self, values: &Values) -> Vec<&'static str> {
        self.graph_order = Some(
            self.graph
                .sparsity_pattern(ValuesOrder::from_values(values)),
        );
        vec![" Lambda ", " Fidelity "]
    }

    /// Performs one Levenberg-Marquardt step.
    ///
    /// Linearizes the graph at `values`, then repeatedly solves the damped
    /// normal equations `(J^T J + lambda * D) dx = J^T r`, growing `lambda`
    /// by `lambda_factor` until the step both reduces the linearized error
    /// and achieves a model fidelity (actual reduction / predicted
    /// reduction) above `min_model_fidelity`. On acceptance, `lambda` is
    /// relaxed by one factor (floored at `lambda_min`) and the step string
    /// reports the new lambda and the achieved fidelity.
    ///
    /// # Errors
    /// Returns `OptError::FailedToStep` when `lambda` exceeds `lambda_max`
    /// without an acceptable step being found.
    fn step(&mut self, mut values: Values, _idx: usize) -> OptResult<(Values, String)> {
        let order = ValuesOrder::from_values(&values);

        // Linearize and assemble the normal-equation pieces J^T J and J^T r.
        let linear_graph = self.graph.linearize(&values);
        let DiffResult { value: r, diff: j } =
            linear_graph.residual_jacobian(self.graph_order.as_ref().expect("Missing graph order"));
        let jtj = j
            .as_ref()
            .transpose()
            .to_col_major()
            .expect("J failed to transpose")
            .mul(j.as_ref());

        // Damping matrix D: diag(J^T J) for scale-aware Marquardt damping,
        // or the identity for classic Levenberg damping.
        let triplets_i = if self.params.diagonal_damping {
            (0..jtj.ncols())
                .map(|i| Triplet::new(i as isize, i as isize, jtj[(i, i)]))
                .collect::<Vec<_>>()
        } else {
            (0..jtj.ncols())
                .map(|i| Triplet::new(i as isize, i as isize, 1.0))
                .collect::<Vec<_>>()
        };
        let i = SparseColMat::<usize, dtype>::try_new_from_nonnegative_triplets(
            jtj.ncols(),
            jtj.ncols(),
            &triplets_i,
        )
        .expect("Failed to make damping terms");
        let b = j.as_ref().transpose().mul(&r);

        // Baseline errors at dx = 0 for the gain-ratio (fidelity) check.
        // (`order` is moved here — it is not needed again.)
        let mut dx = LinearValues::zero_from_order(order);
        let old_lin_error = linear_graph.error(&dx);
        let old_error = self.graph.error(&values);

        let mut model_fidelity;
        loop {
            // Solve the damped system (J^T J + lambda * D) dx = b.
            #[allow(clippy::unnecessary_cast)]
            let a = &jtj + (&i * self.lambda as f64);
            let delta = self
                .solver
                .solve_symmetric(a.as_ref(), b.as_ref())
                .as_ref()
                .into_nalgebra()
                .column(0)
                .clone_owned();
            dx = LinearValues::from_order_and_vector(
                self.graph_order
                    .as_ref()
                    .expect("Missing graph order")
                    .order
                    .clone(),
                delta,
            );

            // Accept only if the linearized model predicts a decrease AND the
            // true nonlinear decrease agrees well enough with the prediction.
            let curr_lin_error = linear_graph.error(&dx);
            if curr_lin_error < old_lin_error {
                let mut new_values = values.clone();
                new_values.oplus_mut(&dx);
                let curr_error = self.graph.error(&new_values);
                // Both differences are negative on improvement, so the ratio
                // is positive; strict inequality above rules out divide-by-zero.
                model_fidelity = (curr_error - old_error) / (curr_lin_error - old_lin_error);
                if model_fidelity > self.params.min_model_fidelity {
                    break;
                }
            }

            // Step rejected: increase damping and retry, bailing out once
            // lambda passes the configured ceiling.
            self.lambda *= self.params.lambda_factor;
            if self.lambda > self.params.lambda_max {
                return Err(OptError::FailedToStep);
            }
        }

        // Apply the accepted step and relax the damping, floored at lambda_min.
        values.oplus_mut(&dx);
        self.lambda /= self.params.lambda_factor;
        if self.lambda < self.params.lambda_min {
            self.lambda = self.params.lambda_min;
        }

        Ok((
            values,
            format!("{:^12.4e} | {:^12.4e} |", self.lambda, model_fidelity),
        ))
    }
}
#[cfg(test)]
mod test {
use super::*;
use crate::test_optimizer;
// Instantiates the crate's shared optimizer test suite for LevenMarquardt.
test_optimizer!(LevenMarquardt);
}