ceres_solver/
cost.rs

//! Cost function wrapper for [NllsProblem](crate::nlls_problem::NllsProblem).
//!
//! Box your cost function into [CostFunctionType] to add it to a problem using
//! [crate::nlls_problem::ResidualBlockBuilder::set_cost].
5
use crate::types::JacobianType;

use ceres_solver_sys::cxx;
use ceres_solver_sys::ffi;
use std::convert::TryFrom;
use std::slice;
11
12pub type CostFunctionType<'a> = Box<dyn Fn(&[&[f64]], &mut [f64], JacobianType<'_>) -> bool + 'a>;
13
14/// A cost function for [NllsProblem](crate::nlls_problem::NllsProblem).
15pub(crate) struct CostFunction<'cost>(cxx::UniquePtr<ffi::CallbackCostFunction<'cost>>);
16
17impl<'cost> CostFunction<'cost> {
18    /// Create a new cost function from a Rust function.
19    ///
20    /// # Arguments
21    /// - func - function to find residuals and Jacobian for the problem block. The function itself
22    ///   must return [false] if it cannot compute Jacobian, [true] otherwise, and accept following
23    ///   arguments:
24    ///   - parameters - slice of [f64] slices representing the current values of the parameters.
25    ///     Each parameter is represented as a slice, the slice sizes are specified by
26    ///     `parameter_sizes`.
27    ///   - residuals - mutable slice of [f64] for residuals outputs, the size is specified by
28    ///     `num_residuals`.
29    ///   - jacobians: [JacobianType](crate::types::JacobianType) - represents a mutable
30    ///     structure to output the Jacobian. Sometimes the solver doesn't need the Jacobian or
31    ///     some of its components, in this case the corresponding value is [None]. For the required
32    ///     components it has a 3-D shape: top index is for the parameter index, middle index is for
33    ///     the residual index, and the most inner dimension is for the given parameter component
34    ///     index. So the size of top-level [Some] is defined by `parameter_sizes.len()`,
35    ///     second-level [Some]'s slice length is `num_residuals`, and the bottom-level slice has
36    ///     length of `parameter_sizes[i]`, where `i` is the top-level index.
37    /// - parameter_sizes - sizes of the parameter vectors.
38    /// - num_residuals - length of the residual vector, usually corresponds to the number of
39    ///   data points.
40    pub fn new(
41        func: impl Into<CostFunctionType<'cost>>,
42        parameter_sizes: impl Into<Vec<usize>>,
43        num_residuals: usize,
44    ) -> Self {
45        let parameter_sizes = parameter_sizes.into();
46        let parameter_block_sizes: Vec<_> =
47            parameter_sizes.iter().map(|&size| size as i32).collect();
48
49        let safe_func = func.into();
50        let rust_func: Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost> =
51            Box::new(move |parameters_ptr, residuals_ptr, jacobians_ptr| {
52                let parameter_pointers =
53                    unsafe { slice::from_raw_parts(parameters_ptr, parameter_sizes.len()) };
54                let parameters = parameter_pointers
55                    .iter()
56                    .zip(parameter_sizes.iter())
57                    .map(|(&p, &size)| unsafe { slice::from_raw_parts(p, size) })
58                    .collect::<Vec<_>>();
59                let residuals = unsafe { slice::from_raw_parts_mut(residuals_ptr, num_residuals) };
60                let mut jacobians_owned =
61                    OwnedJacobian::from_pointer(jacobians_ptr, &parameter_sizes, num_residuals);
62                let mut jacobian_references = jacobians_owned.references();
63                safe_func(
64                    &parameters,
65                    residuals,
66                    jacobian_references.as_mut().map(|v| &mut v[..]),
67                )
68            });
69        let inner = ffi::new_callback_cost_function(
70            Box::new(rust_func.into()),
71            num_residuals as i32,
72            &parameter_block_sizes,
73        );
74        Self(inner)
75    }
76
77    pub fn into_inner(self) -> cxx::UniquePtr<ffi::CallbackCostFunction<'cost>> {
78        self.0
79    }
80}
81
82struct OwnedJacobian<'a>(Option<Vec<Option<Vec<&'a mut [f64]>>>>);
83
84impl<'a> OwnedJacobian<'a> {
85    fn from_pointer(
86        pointer: *mut *mut f64,
87        parameter_sizes: &[usize],
88        num_residuals: usize,
89    ) -> Self {
90        if pointer.is_null() {
91            return Self(None);
92        }
93        let per_parameter = unsafe { slice::from_raw_parts_mut(pointer, parameter_sizes.len()) };
94        let vec = per_parameter
95            .iter()
96            .zip(parameter_sizes)
97            .map(|(&p, &size)| OwnedDerivative::from_pointer(p, size, num_residuals).0)
98            .collect();
99        Self(Some(vec))
100    }
101
102    fn references(&'a mut self) -> Option<Vec<Option<&'a mut [&'a mut [f64]]>>> {
103        let v = self
104            .0
105            .as_mut()?
106            .iter_mut()
107            .map(|der| der.as_mut().map(|v| &mut v[..]))
108            .collect();
109        Some(v)
110    }
111}
112
/// Per-residual rows of one parameter block's Jacobian, or `None` when it is not requested.
struct OwnedDerivative<'a>(Option<Vec<&'a mut [f64]>>);

impl OwnedDerivative<'_> {
    /// Split a row-major `num_residuals x parameter_size` buffer into per-residual rows.
    ///
    /// A null `pointer` yields `None`. A zero `parameter_size` yields `num_residuals` empty rows
    /// without dereferencing `pointer`. Otherwise `pointer` must address
    /// `parameter_size * num_residuals` writable `f64`s that stay valid and unaliased while the
    /// returned value is alive.
    ///
    /// # Panics
    /// Panics if `parameter_size * num_residuals` overflows `usize`.
    fn from_pointer(pointer: *mut f64, parameter_size: usize, num_residuals: usize) -> Self {
        if pointer.is_null() {
            return Self(None);
        }
        if parameter_size == 0 {
            // `chunks_exact_mut(0)` panics; an empty parameter block still has one (empty) row
            // per residual, as the documented Jacobian shape requires.
            let rows = (0..num_residuals)
                .map(|_| <&mut [f64]>::default())
                .collect();
            return Self(Some(rows));
        }
        let len = parameter_size
            .checked_mul(num_residuals)
            .expect("Jacobian block size overflows usize");
        // SAFETY: `pointer` is non-null and, per the caller's contract, addresses `len` writable
        // values that nothing else accesses while these rows are alive.
        let buffer = unsafe { slice::from_raw_parts_mut(pointer, len) };
        Self(Some(buffer.chunks_exact_mut(parameter_size).collect()))
    }
}