1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
//! Loss functions for [NllsProblem](crate::nlls_problem::NllsProblem) and
//! [CurveFitProblem1D](crate::curve_fit::CurveFitProblem1D).
//!
//! A loss function is a function applied to the squared norm of the residuals; it can reduce the
//! influence of outliers in the data and thereby improve convergence. There are two kinds of loss
//! functions: those built from custom functions boxed into [LossFunctionType], and Ceres stock
//! functions that take one or two scale parameters.
use ceres_solver_sys::cxx::UniquePtr;
use ceres_solver_sys::ffi;
/// Boxed user-supplied loss function.
///
/// The callable receives the squared residual as its first argument and writes its results into
/// the three-element array passed as the second argument: 0) the loss function value, 1) its
/// first derivative, and 2) its second derivative.
pub type LossFunctionType = Box<dyn Fn(f64, &mut [f64; 3])>;
/// Loss function for [NllsProblem](crate::nlls_problem::NllsProblem) and
/// [CurveFitProblem1D](crate::curve_fit::CurveFitProblem1D).
///
/// It is a transformation of the squared residuals which is generally used to make the solver
/// less sensitive to outliers. This type comes in two flavours: a user-specified function and a
/// Ceres stock function. In both cases it wraps a [UniquePtr] to the FFI loss function object.
pub struct LossFunction(UniquePtr<ffi::LossFunction>);
impl LossFunction {
/// Create a [LossFunction] to handle a custom loss function.
///
/// # Arguments
/// - func - a boxed function which accepts two arguments: non-negative squared residual, and
/// an array of 0) loss function value, 1) its first, and 2) its second derivatives. See
/// details at
/// <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12LossFunctionE>.
pub fn custom(func: impl Into<LossFunctionType>) -> Self {
let safe_func = func.into();
let rust_func: Box<dyn Fn(f64, *mut f64)> = Box::new(move |sq_norm, out_ptr| {
let out = unsafe { &mut *(out_ptr as *mut [f64; 3]) };
safe_func(sq_norm, out);
});
let inner = ffi::new_callback_loss_function(Box::new(rust_func.into()));
Self(inner)
}
/// Huber loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres9HuberLossE>.
pub fn huber(a: f64) -> Self {
Self(ffi::new_huber_loss(a))
}
/// Soft L1 loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12SoftLOneLossE>.
pub fn soft_l1(a: f64) -> Self {
Self(ffi::new_soft_l_one_loss(a))
}
/// log(1+s) loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres10CauchyLossE>.
pub fn cauchy(a: f64) -> Self {
Self(ffi::new_cauchy_loss(a))
}
/// Arctangent loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres10ArctanLossE>.
pub fn arctan(a: f64) -> Self {
Self(ffi::new_arctan_loss(a))
}
/// Tolerant loss function, see details at <http://ceres-solver.org/nnls_modeling.html#_CPPv4N5ceres12TolerantLossE>.
pub fn tolerant(a: f64, b: f64) -> Self {
Self(ffi::new_tolerant_loss(a, b))
}
/// Tukey loss function
pub fn tukey(a: f64) -> Self {
Self(ffi::new_tukey_loss(a))
}
pub fn into_inner(self) -> UniquePtr<ffi::LossFunction> {
self.0
}
}