use faer_ext::IntoNalgebra;
use super::{BaseOptParams, OptObserverVec, OptResult, Optimizer};
use crate::{
containers::{Graph, GraphOrder, Values, ValuesOrder},
dtype,
linalg::DiffResult,
linear::{LinearSolver, LinearValues},
};
/// Gauss-Newton optimizer over a [`Graph`].
///
/// Holds the graph, the linear solver used for each least-squares subproblem,
/// user parameters, registered observers, and the cached graph ordering.
pub struct GaussNewton {
    /// The graph being optimized; `Optimizer::error` evaluates this graph.
    graph: Graph,
    /// Solver for the linearized least-squares system of each step
    /// (replaceable through `set_solver`).
    solver: Box<dyn LinearSolver>,
    /// Optimizer settings supplied at construction and exposed via `params()`.
    params: BaseOptParams,
    /// Observers exposed through `observers()` / `observers_mut()`.
    observers: OptObserverVec,
    /// Cached sparsity pattern / variable ordering; `None` until `init` runs.
    graph_order: Option<GraphOrder>,
}
impl GaussNewton {
    /// Replaces the linear solver used by subsequent optimization steps.
    ///
    /// The concrete solver is taken by value and stored behind a trait object;
    /// the previously configured solver is dropped.
    pub fn set_solver(&mut self, solver: impl LinearSolver + 'static) {
        let boxed: Box<dyn LinearSolver> = Box::new(solver);
        self.solver = boxed;
    }
}
/// [`Optimizer`] implementation for [`GaussNewton`].
///
/// Each step linearizes the whole graph at the current values, solves the
/// resulting least-squares system with the configured [`LinearSolver`], and
/// applies the full solution via `oplus_mut` (no damping or line search).
impl Optimizer for GaussNewton {
    type Params = BaseOptParams;

    /// Builds an optimizer over `graph` with the given parameters.
    ///
    /// The solver starts as `Default::default()` (swap it with
    /// [`GaussNewton::set_solver`]), no observers are registered, and the graph
    /// ordering stays unset until [`Optimizer::init`] is called.
    fn new(params: Self::Params, graph: Graph) -> Self {
        Self {
            graph,
            solver: Default::default(),
            observers: OptObserverVec::default(),
            params,
            graph_order: None,
        }
    }

    /// Observers registered on this optimizer.
    fn observers(&self) -> &OptObserverVec {
        &self.observers
    }

    /// Mutable access to the registered observers.
    fn observers_mut(&mut self) -> &mut OptObserverVec {
        &mut self.observers
    }

    /// The graph being optimized.
    fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Mutable access to the graph being optimized.
    fn graph_mut(&mut self) -> &mut Graph {
        &mut self.graph
    }

    /// Error of the graph evaluated at `values` (delegates to [`Graph::error`]).
    fn error(&self, values: &Values) -> dtype {
        self.graph.error(values)
    }

    /// Parameters supplied at construction.
    fn params(&self) -> &BaseOptParams {
        &self.params
    }

    /// Computes and caches the graph's sparsity pattern for the variable ordering
    /// derived from `values`; every later [`Optimizer::step`] reuses it.
    ///
    /// Always returns an empty list.
    /// NOTE(review): the meaning of the returned names is defined by the
    /// `Optimizer` trait, which is not visible here — confirm.
    fn init(&mut self, values: &Values) -> Vec<&'static str> {
        let order = ValuesOrder::from_values(values);
        self.graph_order = Some(self.graph.sparsity_pattern(order));
        Vec::new()
    }

    /// Performs one Gauss-Newton iteration and returns the updated values.
    ///
    /// The graph is linearized at `values`, and the Jacobian `j` and vector `r`
    /// from `residual_jacobian` are passed to the solver's `solve_lst_sq`. The
    /// first column of the solution is then applied to `values` with
    /// `oplus_mut`. The returned message string is always empty.
    /// NOTE(review): this assumes any sign convention on `r` (e.g. `-residual`)
    /// is already applied by `linearize` / `residual_jacobian` — confirm.
    ///
    /// # Panics
    ///
    /// Panics with `"Missing graph order"` if [`Optimizer::init`] has not been
    /// called first.
    fn step(&mut self, mut values: Values, _idx: usize) -> OptResult<(Values, String)> {
        // Borrow the cached ordering once. It is needed both to assemble the
        // linear system and to map the solution back onto the variables, and
        // checking it first avoids linearizing when `init` was never run.
        let order = self.graph_order.as_ref().expect("Missing graph order");

        let linear_graph = self.graph.linearize(&values);
        let DiffResult { value: r, diff: j } = linear_graph.residual_jacobian(order);

        // The solver's result is viewed through nalgebra, and its single
        // solution column is kept as an owned vector.
        let delta = self
            .solver
            .solve_lst_sq(j.as_ref(), r.as_ref())
            .as_ref()
            .into_nalgebra()
            .column(0)
            .clone_owned();

        // `from_order_and_vector` takes the ordering by value, but the cached
        // one must survive for the next step, so it is cloned.
        let dx = LinearValues::from_order_and_vector(order.order.clone(), delta);
        values.oplus_mut(&dx);

        Ok((values, String::new()))
    }
}
#[cfg(test)]
mod test {
    //! Instantiates the crate's `test_optimizer!` test cases for [`GaussNewton`].
    use super::*;

    crate::test_optimizer!(GaussNewton);
}