ceres_solver/
loss.rs

//! Loss functions for [NllsProblem](crate::nlls_problem::NllsProblem) and
//! [CurveFitProblem1D](crate::curve_fit::CurveFitProblem1D).
//!
//! A loss function is applied to the squared norm of the residuals; it reduces the influence of
//! outliers, which can improve convergence. There are two kinds: custom functions boxed into
//! [LossFunctionType], and Ceres stock functions parameterized by one or two scale parameters.
8
9use ceres_solver_sys::cxx::UniquePtr;
10use ceres_solver_sys::ffi;
11
/// Signature of a user-supplied loss function, see [LossFunction::custom].
///
/// The closure receives the non-negative squared residual norm `s` and a mutable output array
/// that it must fill with:
/// - `[0]` - the loss value `rho(s)`,
/// - `[1]` - its first derivative `rho'(s)`,
/// - `[2]` - its second derivative `rho''(s)`.
pub type LossFunctionType = Box<dyn Fn(f64, &mut [f64; 3])>;
13
14/// Loss function for [NllsProblem](crate::nlls_problem::NllsProblem) and
15/// [CurveFitProblem1D](crate::curve_fit::CurveFitProblem1D), it is a transformation of the squared
16/// residuals which is generally used to make the solver less sensitive to outliers. This enum has
17/// two flavours: user specified function and Ceres stock function.
18pub struct LossFunction(UniquePtr<ffi::LossFunction>);
19
20impl LossFunction {
21    /// Create a [LossFunction] to handle a custom loss function.
22    ///
23    /// # Arguments
24    /// - func - a boxed function which accepts two arguments: non-negative squared residual, and
25    ///   an array of 0) loss function value, 1) its first, and 2) its second derivatives. See
26    ///   details at
27    ///   <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12LossFunctionE>.
28    pub fn custom(func: impl Into<LossFunctionType>) -> Self {
29        let safe_func = func.into();
30        let rust_func: Box<dyn Fn(f64, *mut f64)> = Box::new(move |sq_norm, out_ptr| {
31            let out = unsafe { &mut *(out_ptr as *mut [f64; 3]) };
32            safe_func(sq_norm, out);
33        });
34        let inner = ffi::new_callback_loss_function(Box::new(rust_func.into()));
35        Self(inner)
36    }
37
38    /// Huber loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres9HuberLossE>.
39    pub fn huber(a: f64) -> Self {
40        Self(ffi::new_huber_loss(a))
41    }
42
43    /// Soft L1 loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12SoftLOneLossE>.
44    pub fn soft_l1(a: f64) -> Self {
45        Self(ffi::new_soft_l_one_loss(a))
46    }
47
48    /// log(1+s) loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres10CauchyLossE>.
49    pub fn cauchy(a: f64) -> Self {
50        Self(ffi::new_cauchy_loss(a))
51    }
52
53    /// Arctangent loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres10ArctanLossE>.
54    pub fn arctan(a: f64) -> Self {
55        Self(ffi::new_arctan_loss(a))
56    }
57
58    /// Tolerant loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12TolerantLossE>.
59    pub fn tolerant(a: f64, b: f64) -> Self {
60        Self(ffi::new_tolerant_loss(a, b))
61    }
62
63    /// Tukey loss function
64    pub fn tukey(a: f64) -> Self {
65        Self(ffi::new_tukey_loss(a))
66    }
67
68    pub fn into_inner(self) -> UniquePtr<ffi::LossFunction> {
69        self.0
70    }
71}