ceres_solver_sys/
lib.rs

1pub use cxx;
2
// Raw cxx bridge to Ceres Solver. Items in the `extern "C++"` block live in the
// C++ namespace `ceres` and are declared in (or reachable through)
// `ceres-solver-sys/src/lib.h`; the two types in the `extern "Rust"` block are the
// Rust structs defined right after this module.
#[cxx::bridge(namespace = "ceres")]
pub mod ffi {
    // The explicit lifetimes make some signatures more verbose.
    #![allow(clippy::needless_lifetimes)]
    // False positive https://github.com/rust-lang/rust-clippy/issues/13360
    #![allow(clippy::needless_maybe_sized)]
    // Believed to be a false positive: the `unsafe fn`s below document their
    // contracts, but the lint presumably fires on cxx-generated code.
    #![allow(clippy::missing_safety_doc)]

    // Shared enums. Each one is also redeclared as a `type` in the `extern "C++"`
    // block, which makes cxx bind it to the existing C++ enum of the same name in
    // namespace `ceres` and statically check that the variant values match. The
    // values here are implicit (0, 1, 2, ...), so variant order must follow the
    // Ceres headers exactly.

    /// Mirrors `ceres::MinimizerType`: which minimizer family to run
    /// (see `SolverOptions::set_minimizer_type`).
    #[repr(u32)]
    enum MinimizerType {
        LINE_SEARCH,
        TRUST_REGION,
    }

    /// Mirrors `ceres::LineSearchDirectionType`: how the line-search minimizer
    /// chooses its search direction (see `set_line_search_direction_type`).
    #[repr(u32)]
    enum LineSearchDirectionType {
        STEEPEST_DESCENT,
        NONLINEAR_CONJUGATE_GRADIENT,
        LBFGS,
        BFGS,
    }

    /// Mirrors `ceres::LineSearchType`: the step acceptance condition used by
    /// line search (see `set_line_search_type`).
    #[repr(u32)]
    enum LineSearchType {
        ARMIJO,
        WOLFE,
    }

    /// Mirrors `ceres::NonlinearConjugateGradientType`: the update formula used
    /// when the direction type is `NONLINEAR_CONJUGATE_GRADIENT`.
    #[repr(u32)]
    enum NonlinearConjugateGradientType {
        FLETCHER_REEVES,
        POLAK_RIBIERE,
        HESTENES_STIEFEL,
    }

    /// Mirrors `ceres::LineSearchInterpolationType`: how trial step sizes are
    /// interpolated (see `set_line_search_interpolation_type`).
    #[repr(u32)]
    enum LineSearchInterpolationType {
        BISECTION,
        QUADRATIC,
        CUBIC,
    }

    /// Mirrors `ceres::TrustRegionStrategyType`: how the trust-region minimizer
    /// computes its step (see `set_trust_region_strategy_type`).
    #[repr(u32)]
    enum TrustRegionStrategyType {
        LEVENBERG_MARQUARDT,
        DOGLEG,
    }

    /// Mirrors `ceres::DoglegType`: the dogleg variant used when the trust-region
    /// strategy is `DOGLEG` (see `set_dogleg_type`).
    #[repr(u32)]
    enum DoglegType {
        TRADITIONAL_DOGLEG,
        SUBSPACE_DOGLEG,
    }

    /// Mirrors `ceres::LinearSolverType` (see `set_linear_solver_type`).
    #[repr(u32)]
    enum LinearSolverType {
        DENSE_NORMAL_CHOLESKY,
        DENSE_QR,
        SPARSE_NORMAL_CHOLESKY,
        DENSE_SCHUR,
        SPARSE_SCHUR,
        ITERATIVE_SCHUR,
        CGNR,
    }

    /// Mirrors `ceres::PreconditionerType` (see `set_preconditioner_type`).
    /// `SUBSET` works with the residual blocks passed to
    /// `set_residual_blocks_for_subset_preconditioner`.
    #[repr(u32)]
    enum PreconditionerType {
        IDENTITY,
        JACOBI,
        SCHUR_JACOBI,
        SCHUR_POWER_SERIES_EXPANSION,
        CLUSTER_JACOBI,
        CLUSTER_TRIDIAGONAL,
        SUBSET,
    }

    /// Mirrors `ceres::VisibilityClusteringType` (see
    /// `set_visibility_clustering_type`).
    #[repr(u32)]
    enum VisibilityClusteringType {
        CANONICAL_VIEWS,
        SINGLE_LINKAGE,
    }

    /// Mirrors `ceres::DenseLinearAlgebraLibraryType`: backend for dense linear
    /// algebra (see `set_dense_linear_algebra_library_type`).
    #[repr(u32)]
    enum DenseLinearAlgebraLibraryType {
        EIGEN,
        LAPACK,
        CUDA,
    }

    /// Mirrors `ceres::SparseLinearAlgebraLibraryType`: backend for sparse linear
    /// algebra (see `set_sparse_linear_algebra_library_type`).
    #[repr(u32)]
    enum SparseLinearAlgebraLibraryType {
        SUITE_SPARSE,
        EIGEN_SPARSE,
        ACCELERATE_SPARSE,
        CUDA_SPARSE,
        NO_SPARSE,
    }

    /// Mirrors `ceres::LoggingType` (see `set_logging_type`).
    #[repr(u32)]
    enum LoggingType {
        SILENT,
        PER_MINIMIZER_ITERATION,
    }

    /// Mirrors `ceres::DumpFormatType`: output format for dumped trust-region
    /// problems (see `set_trust_region_problem_dump_format_type`).
    #[repr(u32)]
    enum DumpFormatType {
        CONSOLE,
        TEXTFILE,
    }

    // Rust types and methods callable from C++.
    extern "Rust" {
        /// Boxed Rust cost closure; `'cost` is the lifetime of data it borrows.
        type RustCostFunction<'cost>;
        /// Called from C++ to evaluate the wrapped Rust cost closure.
        ///
        /// # Safety
        /// The raw pointers are forwarded unchanged to the closure, which is trusted
        /// to interpret them; the caller must pass buffers matching the sizes given
        /// to `new_callback_cost_function`.
        unsafe fn evaluate(
            self: &RustCostFunction,
            parameters: *const *const f64,
            residuals: *mut f64,
            jacobians: *mut *mut f64,
        ) -> bool;

        /// Boxed Rust loss closure.
        type RustLossFunction;
        /// Called from C++ to evaluate the wrapped Rust loss closure.
        ///
        /// # Safety
        /// `out` is forwarded unchanged to the closure, which is trusted to write
        /// through it; presumably it must address three `f64`s as in Ceres'
        /// `LossFunction::Evaluate` (confirm in lib.h).
        unsafe fn evaluate(self: &RustLossFunction, sq_norm: f64, out: *mut f64);
    }

    unsafe extern "C++" {
        include!("ceres-solver-sys/src/lib.h");

        // Bind the shared enums declared above to the existing C++ enums.
        type MinimizerType;
        type LineSearchDirectionType;
        type LineSearchType;
        type NonlinearConjugateGradientType;
        type LineSearchInterpolationType;
        type TrustRegionStrategyType;
        type DoglegType;
        type LinearSolverType;
        type PreconditionerType;
        type VisibilityClusteringType;
        type DenseLinearAlgebraLibraryType;
        type SparseLinearAlgebraLibraryType;
        type LoggingType;
        type DumpFormatType;

        /// C++ cost function built by `new_callback_cost_function` around a boxed
        /// `RustCostFunction`; `'cost` carries that closure's borrow lifetime.
        type CallbackCostFunction<'cost>;
        /// Creates a new C++ cost function from a Rust cost function.
        ///
        /// `num_residuals` is the residual count and `parameter_block_sizes` holds one
        /// size per parameter block that the closure will receive.
        fn new_callback_cost_function<'cost>(
            inner: Box<RustCostFunction<'cost>>,
            num_residuals: i32,
            parameter_block_sizes: &[i32],
        ) -> UniquePtr<CallbackCostFunction<'cost>>;

        /// Opaque C++ loss function: either a stock Ceres loss or a Rust callback.
        type LossFunction;
        /// Creates a new C++ loss function from a Rust loss function.
        fn new_callback_loss_function(inner: Box<RustLossFunction>) -> UniquePtr<LossFunction>;
        // The stock constructors below forward their `a` (and `b`) arguments to the
        // corresponding Ceres loss class; see the Ceres `LossFunction` documentation
        // for their meaning.
        /// Creates stock TrivialLoss.
        fn new_trivial_loss() -> UniquePtr<LossFunction>;
        /// Creates stock HuberLoss.
        fn new_huber_loss(a: f64) -> UniquePtr<LossFunction>;
        /// Creates stock SoftLOneLoss.
        fn new_soft_l_one_loss(a: f64) -> UniquePtr<LossFunction>;
        /// Creates stock CauchyLoss.
        fn new_cauchy_loss(a: f64) -> UniquePtr<LossFunction>;
        /// Creates stock ArctanLoss.
        fn new_arctan_loss(a: f64) -> UniquePtr<LossFunction>;
        /// Creates stock TolerantLoss.
        fn new_tolerant_loss(a: f64, b: f64) -> UniquePtr<LossFunction>;
        /// Creates stock TukeyLoss.
        fn new_tukey_loss(a: f64) -> UniquePtr<LossFunction>;

        /// Handle to a residual block, returned by `add_residual_block` and consumed
        /// by `set_residual_blocks_for_subset_preconditioner`.
        type ResidualBlockId;

        /// The C++ `ceres::Problem`; `'cost` bounds the Rust cost closures it holds.
        type Problem<'cost>;
        /// Set parameter to be constant.
        ///
        /// # Safety
        /// `values` must point to already added parameter block.
        unsafe fn SetParameterBlockConstant(self: Pin<&mut Problem>, values: *const f64);
        /// Set parameter to vary.
        ///
        /// # Safety
        /// `values` must point to already added parameter block.
        unsafe fn SetParameterBlockVariable(self: Pin<&mut Problem>, values: *mut f64);
        /// Check if parameter is constant.
        ///
        /// # Safety
        /// `values` must point to already added parameter block.
        unsafe fn IsParameterBlockConstant(self: &Problem, values: *const f64) -> bool;
        /// Set lower bound for a component of a parameter block.
        ///
        /// # Safety
        /// `values` must point to already added parameter block.
        unsafe fn SetParameterLowerBound(
            self: Pin<&mut Problem>,
            values: *mut f64,
            index: i32,
            lower_bound: f64,
        );
        /// Set upper bound for a component of a parameter block.
        ///
        /// # Safety
        /// `values` must point to already added parameter block.
        unsafe fn SetParameterUpperBound(
            self: Pin<&mut Problem>,
            values: *mut f64,
            index: i32,
            upper_bound: f64,
        );
        /// Number of parameter blocks (binds C++ `Problem::NumParameterBlocks`).
        fn NumParameterBlocks(self: &Problem) -> i32;
        /// Total number of parameters over all blocks (binds C++
        /// `Problem::NumParameters`).
        fn NumParameters(self: &Problem) -> i32;
        /// Number of residual blocks (binds C++ `Problem::NumResidualBlocks`).
        fn NumResidualBlocks(self: &Problem) -> i32;
        /// Total number of residuals over all blocks (binds C++
        /// `Problem::NumResiduals`).
        fn NumResiduals(self: &Problem) -> i32;
        /// Number of components of the parameter.
        ///
        /// # Safety
        /// `values` must point to already added parameter block.
        unsafe fn ParameterBlockSize(self: &Problem, values: *const f64) -> i32;
        /// Checks if problem has a given parameter.
        ///
        /// # Safety
        /// It should be safe to call this function with any pointer.
        unsafe fn HasParameterBlock(self: &Problem, values: *const f64) -> bool;
        /// Creates a new `Problem`.
        fn new_problem<'cost>() -> UniquePtr<Problem<'cost>>;
        /// Adds a residual block to the problem and returns a handle to it.
        ///
        /// `loss_function` may be null (the crate's tests pass `UniquePtr::null()`),
        /// presumably meaning "no robust loss"; confirm in lib.h.
        ///
        /// # Safety
        /// `parameter_blocks` must point to `num_parameter_blocks` valid pointers, and
        /// both that array and every block it points to must outlive `problem`.
        /// NOTE(review): the count presumably has to match the number of
        /// `parameter_block_sizes` given to `new_callback_cost_function`.
        unsafe fn add_residual_block<'cost>(
            problem: Pin<&mut Problem<'cost>>,
            cost_function: UniquePtr<CallbackCostFunction<'cost>>,
            loss_function: UniquePtr<LossFunction>,
            parameter_blocks: *const *mut f64,
            num_parameter_blocks: i32,
        ) -> SharedPtr<ResidualBlockId>;

        /// Wrapper around Ceres' `Solver::Options`, created by `new_solver_options`.
        type SolverOptions;
        /// Returns whether the options are valid. On failure a description is
        /// presumably written to `error`, as `Solver::Options::IsValid` does.
        fn is_valid(self: &SolverOptions, error: Pin<&mut CxxString>) -> bool;
        // Each setter below is named `set_<field>` after the `Solver::Options` field
        // it presumably assigns (implemented in lib.h). See the Ceres
        // `Solver::Options` documentation for each field's meaning and default.
        fn set_minimizer_type(self: Pin<&mut SolverOptions>, minimizer_type: MinimizerType);
        fn set_line_search_direction_type(
            self: Pin<&mut SolverOptions>,
            line_search_direction_type: LineSearchDirectionType,
        );
        fn set_line_search_type(self: Pin<&mut SolverOptions>, line_search_type: LineSearchType);
        fn set_nonlinear_conjugate_gradient_type(
            self: Pin<&mut SolverOptions>,
            nonlinear_conjugate_gradient_type: NonlinearConjugateGradientType,
        );
        fn set_max_lbfgs_rank(self: Pin<&mut SolverOptions>, max_rank: i32);
        fn set_use_approximate_eigenvalue_bfgs_scaling(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_line_search_interpolation_type(
            self: Pin<&mut SolverOptions>,
            line_search_interpolation_type: LineSearchInterpolationType,
        );
        fn set_min_line_search_step_size(self: Pin<&mut SolverOptions>, step_size: f64);
        fn set_line_search_sufficient_function_decrease(
            self: Pin<&mut SolverOptions>,
            sufficient_decrease: f64,
        );
        fn set_max_line_search_step_contraction(
            self: Pin<&mut SolverOptions>,
            max_step_contraction: f64,
        );
        fn set_min_line_search_step_contraction(
            self: Pin<&mut SolverOptions>,
            min_step_contraction: f64,
        );
        fn set_max_num_line_search_direction_restarts(
            self: Pin<&mut SolverOptions>,
            max_num_restarts: i32,
        );
        fn set_line_search_sufficient_curvature_decrease(
            self: Pin<&mut SolverOptions>,
            sufficient_curvature_decrease: f64,
        );
        fn set_max_line_search_step_expansion(
            self: Pin<&mut SolverOptions>,
            max_step_expansion: f64,
        );
        fn set_trust_region_strategy_type(
            self: Pin<&mut SolverOptions>,
            trust_region_strategy_type: TrustRegionStrategyType,
        );
        fn set_dogleg_type(self: Pin<&mut SolverOptions>, dogleg_type: DoglegType);
        fn set_use_nonmonotonic_steps(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_max_consecutive_nonmonotonic_steps(
            self: Pin<&mut SolverOptions>,
            max_consecutive_nonmonotonic_steps: i32,
        );
        fn set_max_num_iterations(self: Pin<&mut SolverOptions>, max_num_iterations: i32);
        fn set_max_solver_time_in_seconds(
            self: Pin<&mut SolverOptions>,
            max_solver_time_in_seconds: f64,
        );
        fn set_num_threads(self: Pin<&mut SolverOptions>, num_threads: i32);
        fn set_initial_trust_region_radius(
            self: Pin<&mut SolverOptions>,
            initial_trust_region_radius: f64,
        );
        fn set_max_trust_region_radius(self: Pin<&mut SolverOptions>, max_trust_region_radius: f64);
        fn set_min_trust_region_radius(self: Pin<&mut SolverOptions>, min_trust_region_radius: f64);
        fn set_min_relative_decrease(self: Pin<&mut SolverOptions>, min_relative_decrease: f64);
        fn set_min_lm_diagonal(self: Pin<&mut SolverOptions>, min_lm_diagonal: f64);
        fn set_max_lm_diagonal(self: Pin<&mut SolverOptions>, max_lm_diagonal: f64);
        fn set_max_num_consecutive_invalid_steps(
            self: Pin<&mut SolverOptions>,
            max_num_consecutive_invalid_steps: i32,
        );
        fn set_function_tolerance(self: Pin<&mut SolverOptions>, function_tolerance: f64);
        fn set_gradient_tolerance(self: Pin<&mut SolverOptions>, gradient_tolerance: f64);
        fn set_parameter_tolerance(self: Pin<&mut SolverOptions>, parameter_tolerance: f64);
        fn set_linear_solver_type(
            self: Pin<&mut SolverOptions>,
            linear_solver_type: LinearSolverType,
        );
        fn set_preconditioner_type(
            self: Pin<&mut SolverOptions>,
            preconditioner_type: PreconditionerType,
        );
        fn set_visibility_clustering_type(
            self: Pin<&mut SolverOptions>,
            visibility_clustering_type: VisibilityClusteringType,
        );
        // Takes handles returned by `add_residual_block`.
        fn set_residual_blocks_for_subset_preconditioner(
            self: Pin<&mut SolverOptions>,
            residual_blocks: &[SharedPtr<ResidualBlockId>],
        );
        fn set_dense_linear_algebra_library_type(
            self: Pin<&mut SolverOptions>,
            dense_linear_algebra_library_type: DenseLinearAlgebraLibraryType,
        );
        fn set_sparse_linear_algebra_library_type(
            self: Pin<&mut SolverOptions>,
            sparse_linear_algebra_library_type: SparseLinearAlgebraLibraryType,
        );
        fn set_logging_type(self: Pin<&mut SolverOptions>, logging_type: LoggingType);
        fn set_minimizer_progress_to_stdout(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_trust_region_minimizer_iterations_to_dump(
            self: Pin<&mut SolverOptions>,
            iterations_to_dump: &[i32],
        );
        fn set_trust_region_problem_dump_directory(
            self: Pin<&mut SolverOptions>,
            directory: Pin<&CxxString>,
        );
        fn set_trust_region_problem_dump_format_type(
            self: Pin<&mut SolverOptions>,
            trust_region_problem_dump_format_type: DumpFormatType,
        );
        fn set_check_gradients(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_gradient_check_relative_precision(
            self: Pin<&mut SolverOptions>,
            gradient_check_relative_precision: f64,
        );
        fn set_gradient_check_numeric_derivative_relative_step_size(
            self: Pin<&mut SolverOptions>,
            gradient_check_numeric_derivative_relative_step_size: f64,
        );
        fn set_update_state_every_iteration(self: Pin<&mut SolverOptions>, yes: bool);

        /// Create an instance wrapping Solver::Options.
        fn new_solver_options() -> UniquePtr<SolverOptions>;

        /// Wrapper around Ceres' `Solver::Summary`, created by `new_solver_summary`
        /// and filled in by `solve`.
        type SolverSummary;
        // The accessors below presumably read the same-named `Solver::Summary`
        // fields, or call its `BriefReport`, `FullReport` and `IsSolutionUsable`
        // methods; confirm in lib.h.
        fn brief_report(self: &SolverSummary) -> UniquePtr<CxxString>;
        fn full_report(self: &SolverSummary) -> UniquePtr<CxxString>;
        fn is_solution_usable(self: &SolverSummary) -> bool;
        fn initial_cost(self: &SolverSummary) -> f64;
        fn final_cost(self: &SolverSummary) -> f64;
        fn fixed_cost(self: &SolverSummary) -> f64;
        fn num_successful_steps(self: &SolverSummary) -> i32;
        fn num_unsuccessful_steps(self: &SolverSummary) -> i32;
        fn num_inner_iteration_steps(self: &SolverSummary) -> i32;
        fn num_line_search_steps(self: &SolverSummary) -> i32;
        /// Create an instance wrapping Solver::Summary.
        fn new_solver_summary() -> UniquePtr<SolverSummary>;

        /// Wrapper for Ceres' `Solve()` function.
        ///
        /// Minimizes `problem` according to `options`. It updates the registered
        /// parameter blocks in place and records the outcome in `summary`.
        fn solve(
            options: &SolverOptions,
            problem: Pin<&mut Problem>,
            summary: Pin<&mut SolverSummary>,
        );
    }
}
386
/// Rust cost function handed to C++ through `ffi::new_callback_cost_function`.
///
/// The wrapped closure receives `(parameters, residuals, jacobians)` as raw
/// pointers and returns a `bool` that is passed straight back to C++.
/// Presumably this is Ceres' evaluation-success flag, since the signature matches
/// `ceres::CostFunction::Evaluate`. The crate's tests treat `jacobians`, and each of
/// its entries, as possibly null. The closure may borrow data for `'cost`.
pub struct RustCostFunction<'cost>(
    /// Callback invoked by [`RustCostFunction::evaluate`].
    pub Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost>,
);

impl RustCostFunction<'_> {
    /// Calls the wrapped closure with the given pointers, unchanged, and returns
    /// its result.
    ///
    /// This is the method the `ffi` bridge exposes to C++ as `evaluate`.
    pub fn evaluate(
        &self,
        parameters: *const *const f64,
        residuals: *mut f64,
        jacobians: *mut *mut f64,
    ) -> bool {
        let Self(callback) = self;
        callback(parameters, residuals, jacobians)
    }
}

/// Wraps an already-boxed callback; no extra allocation takes place.
impl<'cost> From<Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost>>
    for RustCostFunction<'cost>
{
    fn from(
        callback: Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost>,
    ) -> Self {
        RustCostFunction(callback)
    }
}
411
/// Rust loss function handed to C++ through `ffi::new_callback_loss_function`.
///
/// The wrapped closure receives the squared residual norm and a raw output
/// pointer. The crate's `loss_evaluate` test helper suggests that `out` addresses
/// three values (the loss and its first two derivatives); NOTE(review): confirm
/// this in lib.h. Unlike the cost-function wrapper, the closure must be `'static`.
pub struct RustLossFunction(
    /// Callback invoked by [`RustLossFunction::evaluate`].
    pub Box<dyn Fn(f64, *mut f64)>,
);

impl RustLossFunction {
    /// Calls the wrapped closure with `sq_norm` and `out`, unchanged.
    ///
    /// This is the method the `ffi` bridge exposes to C++ as `evaluate`.
    pub fn evaluate(&self, sq_norm: f64, out: *mut f64) {
        let Self(callback) = self;
        callback(sq_norm, out);
    }
}

/// Wraps an already-boxed callback; no extra allocation takes place.
impl From<Box<dyn Fn(f64, *mut f64)>> for RustLossFunction {
    fn from(callback: Box<dyn Fn(f64, *mut f64)>) -> Self {
        RustLossFunction(callback)
    }
}
425
#[cfg(test)]
mod tests {
    use super::*;
    use std::slice;

    use approx::assert_abs_diff_eq;
    use cxx::UniquePtr;

    /// Cost for one scalar parameter: residual `x - 3`, Jacobian `1`.
    ///
    /// Relies on the C++ side passing buffers sized as declared in `end_to_end`:
    /// one parameter block of size 1 and one residual.
    fn cost_evaluate(
        parameters: *const *const f64,
        residuals: *mut f64,
        jacobians: *mut *mut f64,
    ) -> bool {
        // SAFETY: one parameter block of size 1 was declared, so `parameters[0][0]`
        // is readable.
        let x = unsafe { **parameters };
        // SAFETY: one residual was declared, so `residuals[0]` is writable.
        unsafe {
            *residuals = x - 3.0;
        }
        // The Jacobian array, or any single entry of it, may be null when it is not
        // requested; only fill what was asked for.
        if !jacobians.is_null() {
            // SAFETY: a non-null `jacobians` holds one entry per parameter block.
            let d_dx = unsafe { *jacobians };
            if !d_dx.is_null() {
                // SAFETY: a non-null entry has room for 1 residual x 1 parameter.
                unsafe {
                    *d_dx = 1.0;
                }
            }
        }
        true
    }

    /// The trivial loss `rho(s) = s`, written as `[rho, rho', rho'']`.
    fn loss_evaluate(sq_norm: f64, out: *mut f64) {
        // SAFETY: assumes the C++ side passes a writable buffer of three `f64`s, as
        // Ceres' `LossFunction::Evaluate` does. NOTE(review): confirm in lib.h.
        let out = unsafe { slice::from_raw_parts_mut(out, 3) };
        out.copy_from_slice(&[sq_norm, 1.0, 0.0]);
    }

    /// Minimizes `(x - 3)^2` from `x = 0` with the given loss (null means none).
    /// Checks that the solver reports a usable solution and converges to `x = 3`.
    fn end_to_end(loss: UniquePtr<ffi::LossFunction>) {
        let parameter_block_sizes = [1];
        // Declared before `problem` so it is dropped after it: the problem keeps a
        // raw pointer to this block.
        let mut x_init = [0.0_f64];
        let parameter_blocks = [x_init.as_mut_ptr()];

        let rust_cost_function = RustCostFunction(Box::new(cost_evaluate));
        let cost_function = ffi::new_callback_cost_function(
            Box::new(rust_cost_function),
            1,
            &parameter_block_sizes,
        );

        let mut problem = ffi::new_problem();
        // SAFETY: `parameter_blocks` holds exactly one valid pointer into `x_init`,
        // and both are declared before `problem`, so they outlive it.
        unsafe {
            ffi::add_residual_block(
                problem.as_mut().unwrap(),
                cost_function,
                loss,
                parameter_blocks.as_ptr(),
                parameter_blocks.len() as i32,
            );
        }

        let mut options = ffi::new_solver_options();
        options
            .as_mut()
            .unwrap()
            .set_logging_type(ffi::LoggingType::SILENT);

        let mut summary = ffi::new_solver_summary();
        ffi::solve(
            options.as_ref().unwrap(),
            problem.as_mut().unwrap(),
            summary.as_mut().unwrap(),
        );

        // Check the solver's own verdict as well as the value, so a failed solve is
        // reported with its brief report instead of just a numeric mismatch.
        let summary = summary.as_ref().unwrap();
        assert!(
            summary.is_solution_usable(),
            "solver did not produce a usable solution: {}",
            summary.brief_report().to_string_lossy(),
        );
        assert_abs_diff_eq!(x_init[0], 3.0, epsilon = 1e-8);
    }

    #[test]
    fn end_to_end_no_loss() {
        end_to_end(UniquePtr::null());
    }

    #[test]
    fn end_to_end_custom_loss() {
        let rust_loss_function = RustLossFunction(Box::new(loss_evaluate));
        let loss_function = ffi::new_callback_loss_function(Box::new(rust_loss_function));
        end_to_end(loss_function);
    }

    #[test]
    fn end_to_end_stock_loss() {
        end_to_end(ffi::new_arctan_loss(1.0));
    }
}