1use crate::types::JacobianType;
7
8use ceres_solver_sys::cxx;
9use ceres_solver_sys::ffi;
10use std::slice;
11
12pub type CostFunctionType<'a> = Box<dyn Fn(&[&[f64]], &mut [f64], JacobianType<'_>) -> bool + 'a>;
13
14pub(crate) struct CostFunction<'cost>(cxx::UniquePtr<ffi::CallbackCostFunction<'cost>>);
16
17impl<'cost> CostFunction<'cost> {
18 pub fn new(
41 func: impl Into<CostFunctionType<'cost>>,
42 parameter_sizes: impl Into<Vec<usize>>,
43 num_residuals: usize,
44 ) -> Self {
45 let parameter_sizes = parameter_sizes.into();
46 let parameter_block_sizes: Vec<_> =
47 parameter_sizes.iter().map(|&size| size as i32).collect();
48
49 let safe_func = func.into();
50 let rust_func: Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost> =
51 Box::new(move |parameters_ptr, residuals_ptr, jacobians_ptr| {
52 let parameter_pointers =
53 unsafe { slice::from_raw_parts(parameters_ptr, parameter_sizes.len()) };
54 let parameters = parameter_pointers
55 .iter()
56 .zip(parameter_sizes.iter())
57 .map(|(&p, &size)| unsafe { slice::from_raw_parts(p, size) })
58 .collect::<Vec<_>>();
59 let residuals = unsafe { slice::from_raw_parts_mut(residuals_ptr, num_residuals) };
60 let mut jacobians_owned =
61 OwnedJacobian::from_pointer(jacobians_ptr, ¶meter_sizes, num_residuals);
62 let mut jacobian_references = jacobians_owned.references();
63 safe_func(
64 ¶meters,
65 residuals,
66 jacobian_references.as_mut().map(|v| &mut v[..]),
67 )
68 });
69 let inner = ffi::new_callback_cost_function(
70 Box::new(rust_func.into()),
71 num_residuals as i32,
72 ¶meter_block_sizes,
73 );
74 Self(inner)
75 }
76
77 pub fn into_inner(self) -> cxx::UniquePtr<ffi::CallbackCostFunction<'cost>> {
78 self.0
79 }
80}
81
82struct OwnedJacobian<'a>(Option<Vec<Option<Vec<&'a mut [f64]>>>>);
83
84impl<'a> OwnedJacobian<'a> {
85 fn from_pointer(
86 pointer: *mut *mut f64,
87 parameter_sizes: &[usize],
88 num_residuals: usize,
89 ) -> Self {
90 if pointer.is_null() {
91 return Self(None);
92 }
93 let per_parameter = unsafe { slice::from_raw_parts_mut(pointer, parameter_sizes.len()) };
94 let vec = per_parameter
95 .iter()
96 .zip(parameter_sizes)
97 .map(|(&p, &size)| OwnedDerivative::from_pointer(p, size, num_residuals).0)
98 .collect();
99 Self(Some(vec))
100 }
101
102 fn references(&'a mut self) -> Option<Vec<Option<&'a mut [&'a mut [f64]]>>> {
103 let v = self
104 .0
105 .as_mut()?
106 .iter_mut()
107 .map(|der| der.as_mut().map(|v| &mut v[..]))
108 .collect();
109 Some(v)
110 }
111}
112
/// Row view of one parameter block's Jacobian.
///
/// Holds `num_residuals` rows of `parameter_size` values each, or `None` when the caller
/// passed a null pointer for this block.
struct OwnedDerivative<'a>(Option<Vec<&'a mut [f64]>>);

impl OwnedDerivative<'_> {
    /// Splits the flat buffer at `pointer` into `num_residuals` rows of `parameter_size` values.
    ///
    /// A null `pointer` yields `None`. A zero `parameter_size` yields `num_residuals` empty rows
    /// without dereferencing `pointer`, because `chunks_exact_mut(0)` would panic.
    fn from_pointer(pointer: *mut f64, parameter_size: usize, num_residuals: usize) -> Self {
        if pointer.is_null() {
            return Self(None);
        }
        if parameter_size == 0 {
            return Self(Some((0..num_residuals).map(|_| <&mut [f64]>::default()).collect()));
        }
        // SAFETY: `pointer` is non-null and assumed (per the Ceres `CostFunction::Evaluate`
        // contract) to reference `num_residuals * parameter_size` writable values, stored one
        // residual row after another. No other live reference is assumed to exist during this
        // evaluation.
        let flat = unsafe { slice::from_raw_parts_mut(pointer, parameter_size * num_residuals) };
        Self(Some(flat.chunks_exact_mut(parameter_size).collect()))
    }
}