use std::fmt::Debug;
use statrs::distribution::{ChiSquared, ContinuousCDF};
use super::{
BaseOptParams, LevenMarquardt, OptError, OptObserverVec, OptParams, OptResult, Optimizer,
};
use crate::{
core::{Graph, Values},
dtype,
robust::RobustCost,
};
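/// A robust kernel that can be "convexified" for graduated non-convexity (GNC).
///
/// Implementations carry a convexity parameter `mu` that starts large (making the
/// surrogate loss close to quadratic) and is annealed towards 1, at which point the
/// original non-convex kernel is recovered.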
pub trait ConvexableKernel: RobustCost + Clone {
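/// Picks the initial `mu` from the per-factor errors `d2` and their inlier thresholds `thresh`.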
fn init_mu(d2: &[dtype], thresh: &[dtype]) -> dtype;
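/// Creates a kernel with the given `mu` and squared inlier threshold.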
fn new(mu: dtype, thresh: dtype) -> Self;
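/// Anneals `mu` one step towards 1 (full non-convexity).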
fn step_mu(&mut self, step_size: dtype);
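/// Boxes a clone of this kernel as a type-erased [`RobustCost`].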
fn upcast(&self) -> Box<dyn RobustCost>
where
Self: Sized + 'static,
{
dyn_clone::clone_box(self)
}
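/// Returns the current convexity parameter.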
fn mu(&self) -> dtype;
}
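/// GNC surrogate of the Geman-McClure robust kernel.
///
/// `mu` is the convexity parameter and `c2` the squared inlier threshold; as `mu`
/// approaches 1 the surrogate approaches the standard Geman-McClure loss.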
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GncGemanMcClure {
mu: dtype,
c2: dtype,
}
#[factrs::mark]
impl RobustCost for GncGemanMcClure {
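// rho(d2) = (mu * c^2 * d2) / (2 * (mu * c^2 + d2))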
fn loss(&self, d2: dtype) -> dtype {
let p = self.mu * self.c2;
0.5 * p * d2 / (p + d2)
}
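// w(d2) = (mu * c^2 / (mu * c^2 + d2))^2, i.e. rho'(r) / r evaluated at r^2 = d2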
fn weight(&self, d2: dtype) -> dtype {
let p = self.mu * self.c2;
let frac = p / (p + d2);
frac * frac
}
}
impl ConvexableKernel for GncGemanMcClure {
fn init_mu(d2: &[dtype], thresh: &[dtype]) -> dtype {
2.0 * d2
.iter()
.zip(thresh)
.fold(0.0, |mu, (d, t)| dtype::max(mu, d / t))
}
fn new(mu: dtype, thresh: dtype) -> Self {
Self { mu, c2: thresh }
}
fn step_mu(&mut self, step_size: dtype) {
self.mu = dtype::max(1.0, self.mu / step_size);
}
fn mu(&self) -> dtype {
self.mu
}
}
impl Debug for GncGemanMcClure {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"GncGemanMcClure {{ mu: {}, c: {} }}",
self.mu,
self.c2.sqrt()
)
}
}
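/// Parameters for [`GraduatedNonConvexity`].
///
/// `inner` holds the parameters of the wrapped optimizer, `mu_step_size` is the factor
/// by which `mu` is divided each outer iteration, and `percentile` is the chi-squared
/// percentile used to derive per-factor inlier thresholds.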
#[derive(Debug)]
pub struct GncParams<O: Optimizer = LevenMarquardt>
where
O::Params: Clone,
{
pub base: BaseOptParams,
pub inner: O::Params,
pub mu_step_size: dtype,
pub percentile: dtype,
}
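// Implemented by hand: a derived Clone would also require `O: Clone`, which is unnecessary.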
impl<O: Optimizer> Clone for GncParams<O> {
fn clone(&self) -> Self {
Self {
base: self.base.clone(),
inner: self.inner.clone(),
mu_step_size: self.mu_step_size,
percentile: self.percentile,
}
}
}
impl<O: Optimizer> Default for GncParams<O> {
fn default() -> Self {
Self {
base: Default::default(),
inner: Default::default(),
mu_step_size: 1.4,
percentile: 0.95,
}
}
}
impl<O: Optimizer> OptParams for GncParams<O> {
fn base_params(&self) -> &BaseOptParams {
self.base.base_params()
}
}
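/// Graduated non-convexity (GNC) optimizer.
///
/// Wraps an inner optimizer `O` (Levenberg-Marquardt by default) and repeatedly
/// re-solves the graph while annealing the convexity parameter `mu` of the robust
/// kernel `K` attached to each factor.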
pub struct GraduatedNonConvexity<K = GncGemanMcClure, O: Optimizer = LevenMarquardt> {
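/// One kernel per factor; `None` for factors that are left un-robustified.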
kernels: Vec<Option<K>>,
params: GncParams<O>,
graph: Graph,
observers: OptObserverVec,
}
impl<K: ConvexableKernel + 'static, O: Optimizer> Optimizer for GraduatedNonConvexity<K, O> {
type Params = GncParams<O>;
fn new(params: Self::Params, graph: Graph) -> Self {
Self {
observers: OptObserverVec::default(),
kernels: Vec::new(),
graph,
params,
}
}
fn observers(&self) -> &OptObserverVec {
&self.observers
}
fn observers_mut(&mut self) -> &mut OptObserverVec {
&mut self.observers
}
fn graph(&self) -> &Graph {
&self.graph
}
fn graph_mut(&mut self) -> &mut Graph {
&mut self.graph
}
fn params(&self) -> &BaseOptParams {
&self.params.base
}
fn error(&self, values: &Values) -> dtype {
self.graph.error(values)
}
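// Builds one surrogate kernel per factor and returns the extra column header for the log table.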
fn init(&mut self, values: &Values) -> Vec<&'static str> {
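// Per-factor errors; together with the thresholds below they determine the initial mu.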
let e: Vec<_> = self.graph().iter().map(|f| f.error(values)).collect();
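// Inlier threshold per factor: the `percentile` quantile of a chi-squared distribution
// with `dim_out` degrees of freedom.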
#[allow(clippy::unnecessary_cast)]
let thresholds: Vec<_> = self
.graph()
.iter()
.map(|f| {
ChiSquared::new(f.dim_out() as f64)
.expect("")
.inverse_cdf(self.params.percentile as f64) as dtype
})
.collect();
let mu = K::init_mu(&e, &thresholds);
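// Heuristic: a binary factor connecting consecutive keys is assumed to be odometry
// and is not robustified.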
let is_odometry = self
.graph()
.iter()
.map(|f| f.keys().len() == 2 && f.keys()[0].0 + 1 == f.keys()[1].0)
.collect::<Vec<_>>();
if is_odometry.iter().all(|&x| x) {
log::warn!("All factors are odometry, no kernels will be created");
}
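// One kernel per factor, all starting at the shared initial mu; odometry factors get `None`.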
self.kernels = thresholds
.iter()
.zip(is_odometry)
.map(|(t, inlier)| if inlier { None } else { Some(K::new(mu, *t)) })
.collect();
vec![" Mu "]
}
fn step(&mut self, mut values: Values, _idx: usize) -> OptResult<(Values, String)> {
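// Anneal mu one step for every active kernel.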
self.kernels
.iter_mut()
.filter_map(|k| k.as_mut())
.for_each(|k| k.step_mu(self.params.mu_step_size));
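// All kernels are created and stepped with the same mu; grab it from any of them for logging.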
let mu = self
.kernels
.iter()
.find_map(|k| k.as_ref().map(|k| k.mu()))
.unwrap_or(0.0);
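// Push the updated kernels into their corresponding factors.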
self.graph
.iter_mut()
.zip(self.kernels.iter())
.for_each(|(f, k)| {
if let Some(k) = k {
f.robust = k.upcast();
}
});
// Re-solve the graph with the current surrogate kernels using the inner optimizer.
let mut opt = O::new(self.params.inner.clone(), self.graph().clone());
match opt.optimize(values) {
// A capped inner solve still yields a usable intermediate estimate for GNC.
Ok(v) | Err(OptError::MaxIterations(v)) => values = v,
Err(e) => {
log::warn!("Inner optimizer failed");
return Err(e);
}
}
Ok((values, format!("{mu:^12.4e} |")))
}
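// Outer GNC loop: initialize the kernels, then alternate mu annealing with inner solves
// until the error (or its decrease) drops below tolerance.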
fn optimize(&mut self, mut values: Values) -> OptResult<Values> {
let append = self.init(&values);
let mut error_old = self.error(&values);
if error_old <= self.params().error_tol {
log::info!("Error is already below tolerance, skipping optimization");
return Ok(values);
}
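// Fixed-width header rows for the iteration table below.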
let extra = if append.is_empty() { "" } else { " |" };
log::info!(
"{:^5} | {:^12} | {:^12} | {:^12} | {}",
"Iter",
"Error",
"ErrorAbs",
"ErrorRel",
append.join(" | ") + extra,
);
log::info!(
"{:^5} | {:^12} | {:^12} | {:^12} | {}",
"-----",
"------------",
"------------",
"------------",
append
.iter()
.map(|s| "-".repeat(s.len()))
.collect::<Vec<_>>()
.join(" | ")
+ extra
);
log::info!(
"{:^5} | {:^12.4e} | {:^12} | {:^12} | {}",
0,
error_old,
"-",
"-",
append
.iter()
.map(|s| format!("{:^width$}", "-", width = s.len()))
.collect::<Vec<_>>()
.join(" | ")
+ extra
);
let mut error_new = error_old;
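// Outer iterations: each step anneals mu, re-runs the inner optimizer, and logs the error statistics.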
for i in 1..=self.params().max_iterations {
error_old = error_new;
let (temp, info) = self.step(values, i)?;
values = temp;
self.observers().notify(&values, i);
error_new = self.error(&values);
let error_decrease_abs = dtype::abs(error_old - error_new);
let error_decrease_rel = error_decrease_abs / error_old;
log::info!(
"{i:^5} | {error_new:^12.4e} | {error_decrease_abs:^12.4e} | {error_decrease_rel:^12.4e} | {info}"
);
if error_new <= self.params().error_tol {
log::info!("Error is below tolerance, stopping optimization");
return Ok(values);
}
if error_decrease_abs <= self.params().error_tol_absolute {
log::info!("Error decrease is below absolute tolerance, stopping optimization");
return Ok(values);
}
if error_decrease_rel <= self.params().error_tol_relative {
log::info!("Error decrease is below relative tolerance, stopping optimization");
return Ok(values);
}
}
Err(OptError::MaxIterations(values))
}
}