ceres_solver/loss.rs
//! Loss functions for [NllsProblem](crate::nlls_problem::NllsProblem) and
//! [CurveFitProblem1D](crate::curve_fit::CurveFitProblem1D).
//!
//! A loss function is applied to the squared norm of the residuals. It reduces the influence of
//! outliers and so can improve convergence. There are two kinds of them: custom functions boxed
//! into [LossFunctionType], and Ceres stock functions having one or two scale parameters.
8
9use ceres_solver_sys::cxx::UniquePtr;
10use ceres_solver_sys::ffi;
11
/// Signature of a user-defined loss function accepted by [LossFunction::custom].
///
/// The function receives the non-negative squared residual norm `s` and must fill the output
/// array with: `[0]` the loss function value, `[1]` its first derivative, and `[2]` its second
/// derivative, all evaluated at `s`.
pub type LossFunctionType = Box<dyn Fn(f64, &mut [f64; 3])>;
13
14/// Loss function for [NllsProblem](crate::nlls_problem::NllsProblem) and
15/// [CurveFitProblem1D](crate::curve_fit::CurveFitProblem1D), it is a transformation of the squared
16/// residuals which is generally used to make the solver less sensitive to outliers. This enum has
17/// two flavours: user specified function and Ceres stock function.
18pub struct LossFunction(UniquePtr<ffi::LossFunction>);
19
20impl LossFunction {
21 /// Create a [LossFunction] to handle a custom loss function.
22 ///
23 /// # Arguments
24 /// - func - a boxed function which accepts two arguments: non-negative squared residual, and
25 /// an array of 0) loss function value, 1) its first, and 2) its second derivatives. See
26 /// details at
27 /// <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12LossFunctionE>.
28 pub fn custom(func: impl Into<LossFunctionType>) -> Self {
29 let safe_func = func.into();
30 let rust_func: Box<dyn Fn(f64, *mut f64)> = Box::new(move |sq_norm, out_ptr| {
31 let out = unsafe { &mut *(out_ptr as *mut [f64; 3]) };
32 safe_func(sq_norm, out);
33 });
34 let inner = ffi::new_callback_loss_function(Box::new(rust_func.into()));
35 Self(inner)
36 }
37
38 /// Huber loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres9HuberLossE>.
39 pub fn huber(a: f64) -> Self {
40 Self(ffi::new_huber_loss(a))
41 }
42
43 /// Soft L1 loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12SoftLOneLossE>.
44 pub fn soft_l1(a: f64) -> Self {
45 Self(ffi::new_soft_l_one_loss(a))
46 }
47
48 /// log(1+s) loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres10CauchyLossE>.
49 pub fn cauchy(a: f64) -> Self {
50 Self(ffi::new_cauchy_loss(a))
51 }
52
53 /// Arctangent loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres10ArctanLossE>.
54 pub fn arctan(a: f64) -> Self {
55 Self(ffi::new_arctan_loss(a))
56 }
57
58 /// Tolerant loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12TolerantLossE>.
59 pub fn tolerant(a: f64, b: f64) -> Self {
60 Self(ffi::new_tolerant_loss(a, b))
61 }
62
63 /// Tukey loss function
64 pub fn tukey(a: f64) -> Self {
65 Self(ffi::new_tukey_loss(a))
66 }
67
68 pub fn into_inner(self) -> UniquePtr<ffi::LossFunction> {
69 self.0
70 }
71}