1pub use cxx;
2
/// Low-level `cxx` bridge between Rust and the C++ code declared in
/// `ceres-solver-sys/src/lib.h`.
///
/// Every C++ name referenced by this module is resolved inside the `ceres` namespace (see the
/// `namespace` argument of the bridge attribute).
#[cxx::bridge(namespace = "ceres")]
pub mod ffi {
    // Clippy opt-outs that apply to the whole bridge module.
    // NOTE(review): presumably these silence lints triggered by the declarations below and/or by
    // the code cxx generates from them — confirm before removing any of them.
    #![allow(clippy::needless_lifetimes)]
    #![allow(clippy::needless_maybe_sized)]
    #![allow(clippy::missing_safety_doc)]

    // ---------------------------------------------------------------------------------------
    // Enums shared with C++.
    //
    // Each enum below is listed a second time as a bare `type` inside the `unsafe extern "C++"`
    // block further down. In cxx that combination declares an enum whose definition is owned by
    // the C++ header; the variant lists here have to agree with it. No variant carries an explicit
    // discriminant, so the declaration order is significant and must not be shuffled.
    // NOTE(review): the meaning of each variant is defined on the C++ side and is not visible
    // from this file — consult the Ceres Solver documentation.
    // ---------------------------------------------------------------------------------------

    /// Mirrors the C++ enum `ceres::MinimizerType`.
    #[repr(u32)]
    enum MinimizerType {
        LINE_SEARCH,
        TRUST_REGION,
    }

    /// Mirrors the C++ enum `ceres::LineSearchDirectionType`.
    #[repr(u32)]
    enum LineSearchDirectionType {
        STEEPEST_DESCENT,
        NONLINEAR_CONJUGATE_GRADIENT,
        LBFGS,
        BFGS,
    }

    /// Mirrors the C++ enum `ceres::LineSearchType`.
    #[repr(u32)]
    enum LineSearchType {
        ARMIJO,
        WOLFE,
    }

    /// Mirrors the C++ enum `ceres::NonlinearConjugateGradientType`.
    #[repr(u32)]
    enum NonlinearConjugateGradientType {
        FLETCHER_REEVES,
        POLAK_RIBIERE,
        HESTENES_STIEFEL,
    }

    /// Mirrors the C++ enum `ceres::LineSearchInterpolationType`.
    #[repr(u32)]
    enum LineSearchInterpolationType {
        BISECTION,
        QUADRATIC,
        CUBIC,
    }

    /// Mirrors the C++ enum `ceres::TrustRegionStrategyType`.
    #[repr(u32)]
    enum TrustRegionStrategyType {
        LEVENBERG_MARQUARDT,
        DOGLEG,
    }

    /// Mirrors the C++ enum `ceres::DoglegType`.
    #[repr(u32)]
    enum DoglegType {
        TRADITIONAL_DOGLEG,
        SUBSPACE_DOGLEG,
    }

    /// Mirrors the C++ enum `ceres::LinearSolverType`.
    #[repr(u32)]
    enum LinearSolverType {
        DENSE_NORMAL_CHOLESKY,
        DENSE_QR,
        SPARSE_NORMAL_CHOLESKY,
        DENSE_SCHUR,
        SPARSE_SCHUR,
        ITERATIVE_SCHUR,
        CGNR,
    }

    /// Mirrors the C++ enum `ceres::PreconditionerType`.
    #[repr(u32)]
    enum PreconditionerType {
        IDENTITY,
        JACOBI,
        SCHUR_JACOBI,
        SCHUR_POWER_SERIES_EXPANSION,
        CLUSTER_JACOBI,
        CLUSTER_TRIDIAGONAL,
        SUBSET,
    }

    /// Mirrors the C++ enum `ceres::VisibilityClusteringType`.
    #[repr(u32)]
    enum VisibilityClusteringType {
        CANONICAL_VIEWS,
        SINGLE_LINKAGE,
    }

    /// Mirrors the C++ enum `ceres::DenseLinearAlgebraLibraryType`.
    #[repr(u32)]
    enum DenseLinearAlgebraLibraryType {
        EIGEN,
        LAPACK,
        CUDA,
    }

    /// Mirrors the C++ enum `ceres::SparseLinearAlgebraLibraryType`.
    #[repr(u32)]
    enum SparseLinearAlgebraLibraryType {
        SUITE_SPARSE,
        EIGEN_SPARSE,
        ACCELERATE_SPARSE,
        CUDA_SPARSE,
        NO_SPARSE,
    }

    /// Mirrors the C++ enum `ceres::LoggingType`.
    #[repr(u32)]
    enum LoggingType {
        SILENT,
        PER_MINIMIZER_ITERATION,
    }

    /// Mirrors the C++ enum `ceres::DumpFormatType`.
    #[repr(u32)]
    enum DumpFormatType {
        CONSOLE,
        TEXTFILE,
    }

    // ---------------------------------------------------------------------------------------
    // Rust items exposed to C++. Both types are opaque on the C++ side; their definitions live
    // next to this module in the same file.
    // ---------------------------------------------------------------------------------------
    extern "Rust" {
        /// Rust-side cost callback. `'cost` is the lifetime the wrapped closure is bounded by.
        type RustCostFunction<'cost>;
        /// Entry point C++ uses to run the Rust cost callback; returns the callback's `bool`.
        ///
        /// Marked `unsafe` because all three arguments are raw pointers supplied by C++.
        /// NOTE(review): the pointer layout presumably follows the Ceres
        /// `CostFunction::Evaluate` convention (array of parameter blocks, residual array,
        /// optional array of Jacobian blocks) — confirm against `lib.h`.
        unsafe fn evaluate(
            self: &RustCostFunction,
            parameters: *const *const f64,
            residuals: *mut f64,
            jacobians: *mut *mut f64,
        ) -> bool;

        /// Rust-side loss callback.
        type RustLossFunction;
        /// Entry point C++ uses to run the Rust loss callback for the value `sq_norm`.
        ///
        /// NOTE(review): `out` presumably points to three writable `f64` slots (the test callback
        /// in this file writes exactly three) — confirm against `lib.h`.
        unsafe fn evaluate(self: &RustLossFunction, sq_norm: f64, out: *mut f64);
    }

    // ---------------------------------------------------------------------------------------
    // C++ items exposed to Rust.
    //
    // The block is `unsafe extern "C++"`: functions declared without `unsafe` are callable from
    // safe Rust, and the author of this block vouches for that. Functions that take raw pointers
    // are individually marked `unsafe fn`.
    // ---------------------------------------------------------------------------------------
    unsafe extern "C++" {
        // Header that provides every C++ declaration referenced below.
        include!("ceres-solver-sys/src/lib.h");

        // C++-owned definitions of the shared enums declared at the top of the module.
        type MinimizerType;
        type LineSearchDirectionType;
        type LineSearchType;
        type NonlinearConjugateGradientType;
        type LineSearchInterpolationType;
        type TrustRegionStrategyType;
        type DoglegType;
        type LinearSolverType;
        type PreconditionerType;
        type VisibilityClusteringType;
        type DenseLinearAlgebraLibraryType;
        type SparseLinearAlgebraLibraryType;
        type LoggingType;
        type DumpFormatType;

        /// Opaque C++ cost function that wraps a [`RustCostFunction`]; `'cost` ties it to the
        /// lifetime of that Rust callback.
        type CallbackCostFunction<'cost>;
        /// Creates a C++ cost function that takes ownership of the boxed Rust callback `inner`.
        ///
        /// NOTE(review): `num_residuals` and `parameter_block_sizes` presumably give the number
        /// of residuals and the length of each parameter block (the test in this file passes `1`
        /// and `[1]` for a scalar problem) — confirm against `lib.h`.
        fn new_callback_cost_function<'cost>(
            inner: Box<RustCostFunction<'cost>>,
            num_residuals: i32,
            parameter_block_sizes: &[i32],
        ) -> UniquePtr<CallbackCostFunction<'cost>>;

        /// Opaque C++ loss function handle.
        type LossFunction;
        /// Creates a C++ loss function that takes ownership of the boxed Rust callback `inner`.
        fn new_callback_loss_function(inner: Box<RustLossFunction>) -> UniquePtr<LossFunction>;
        // Factories for loss functions implemented on the C++ side.
        // NOTE(review): presumably each wraps the Ceres loss class of the matching name, with
        // `a` / `b` forwarded as its constructor parameters — confirm against `lib.h`.
        fn new_trivial_loss() -> UniquePtr<LossFunction>;
        fn new_huber_loss(a: f64) -> UniquePtr<LossFunction>;
        fn new_soft_l_one_loss(a: f64) -> UniquePtr<LossFunction>;
        fn new_cauchy_loss(a: f64) -> UniquePtr<LossFunction>;
        fn new_arctan_loss(a: f64) -> UniquePtr<LossFunction>;
        fn new_tolerant_loss(a: f64, b: f64) -> UniquePtr<LossFunction>;
        fn new_tukey_loss(a: f64) -> UniquePtr<LossFunction>;

        /// Opaque C++ handle returned by [`add_residual_block`].
        type ResidualBlockId;

        /// Opaque C++ problem object; `'cost` bounds the cost functions added to it.
        type Problem<'cost>;
        // The CamelCase methods keep their C++ spelling.
        // NOTE(review): presumably they bind directly to the `ceres::Problem` member functions of
        // the same name, and `values` is the pointer that identifies a parameter block (the same
        // pointer handed over in `parameter_blocks` when the block was added) — confirm.
        // They are `unsafe` wherever a raw pointer crosses the boundary.
        unsafe fn SetParameterBlockConstant(self: Pin<&mut Problem>, values: *const f64);
        unsafe fn SetParameterBlockVariable(self: Pin<&mut Problem>, values: *mut f64);
        unsafe fn IsParameterBlockConstant(self: &Problem, values: *const f64) -> bool;
        unsafe fn SetParameterLowerBound(
            self: Pin<&mut Problem>,
            values: *mut f64,
            index: i32,
            lower_bound: f64,
        );
        unsafe fn SetParameterUpperBound(
            self: Pin<&mut Problem>,
            values: *mut f64,
            index: i32,
            upper_bound: f64,
        );
        // Read-only size queries; none of them takes a pointer, so they are safe to call.
        fn NumParameterBlocks(self: &Problem) -> i32;
        fn NumParameters(self: &Problem) -> i32;
        fn NumResidualBlocks(self: &Problem) -> i32;
        fn NumResiduals(self: &Problem) -> i32;
        unsafe fn ParameterBlockSize(self: &Problem, values: *const f64) -> i32;
        unsafe fn HasParameterBlock(self: &Problem, values: *const f64) -> bool;
        /// Creates a new C++ problem owned by the returned `UniquePtr`.
        fn new_problem<'cost>() -> UniquePtr<Problem<'cost>>;
        /// Adds a residual block to `problem`, handing ownership of `cost_function` and
        /// `loss_function` to the C++ side, and returns a shared handle to the new block.
        ///
        /// `parameter_blocks` is read as an array of `num_parameter_blocks` pointers (the test in
        /// this file passes an array pointer together with the array length). The test also
        /// passes a null `loss_function`.
        /// NOTE(review): the pointed-to parameter storage presumably has to stay valid for as long
        /// as the problem uses it — confirm and document the exact contract.
        unsafe fn add_residual_block<'cost>(
            problem: Pin<&mut Problem<'cost>>,
            cost_function: UniquePtr<CallbackCostFunction<'cost>>,
            loss_function: UniquePtr<LossFunction>,
            parameter_blocks: *const *mut f64,
            num_parameter_blocks: i32,
        ) -> SharedPtr<ResidualBlockId>;

        /// Opaque C++ object holding solver configuration; mutated through the setters below.
        /// NOTE(review): presumably a wrapper around `ceres::Solver::Options`, with each setter
        /// assigning the option named after it — confirm against `lib.h`.
        type SolverOptions;
        /// Reports whether the options are valid.
        /// NOTE(review): `error` presumably receives a description when the result is `false`.
        fn is_valid(self: &SolverOptions, error: Pin<&mut CxxString>) -> bool;
        fn set_minimizer_type(self: Pin<&mut SolverOptions>, minimizer_type: MinimizerType);
        // Setters whose names refer to the line search.
        fn set_line_search_direction_type(
            self: Pin<&mut SolverOptions>,
            line_search_direction_type: LineSearchDirectionType,
        );
        fn set_line_search_type(self: Pin<&mut SolverOptions>, line_search_type: LineSearchType);
        fn set_nonlinear_conjugate_gradient_type(
            self: Pin<&mut SolverOptions>,
            nonlinear_conjugate_gradient_type: NonlinearConjugateGradientType,
        );
        fn set_max_lbfgs_rank(self: Pin<&mut SolverOptions>, max_rank: i32);
        fn set_use_approximate_eigenvalue_bfgs_scaling(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_line_search_interpolation_type(
            self: Pin<&mut SolverOptions>,
            line_search_interpolation_type: LineSearchInterpolationType,
        );
        fn set_min_line_search_step_size(self: Pin<&mut SolverOptions>, step_size: f64);
        fn set_line_search_sufficient_function_decrease(
            self: Pin<&mut SolverOptions>,
            sufficient_decrease: f64,
        );
        fn set_max_line_search_step_contraction(
            self: Pin<&mut SolverOptions>,
            max_step_contraction: f64,
        );
        fn set_min_line_search_step_contraction(
            self: Pin<&mut SolverOptions>,
            min_step_contraction: f64,
        );
        fn set_max_num_line_search_direction_restarts(
            self: Pin<&mut SolverOptions>,
            max_num_restarts: i32,
        );
        fn set_line_search_sufficient_curvature_decrease(
            self: Pin<&mut SolverOptions>,
            sufficient_curvature_decrease: f64,
        );
        fn set_max_line_search_step_expansion(
            self: Pin<&mut SolverOptions>,
            max_step_expansion: f64,
        );
        // Setters whose names refer to the trust region, plus general iteration/time/thread limits.
        fn set_trust_region_strategy_type(
            self: Pin<&mut SolverOptions>,
            trust_region_strategy_type: TrustRegionStrategyType,
        );
        fn set_dogleg_type(self: Pin<&mut SolverOptions>, dogleg_type: DoglegType);
        fn set_use_nonmonotonic_steps(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_max_consecutive_nonmonotonic_steps(
            self: Pin<&mut SolverOptions>,
            max_consecutive_nonmonotonic_steps: i32,
        );
        fn set_max_num_iterations(self: Pin<&mut SolverOptions>, max_num_iterations: i32);
        fn set_max_solver_time_in_seconds(
            self: Pin<&mut SolverOptions>,
            max_solver_time_in_seconds: f64,
        );
        fn set_num_threads(self: Pin<&mut SolverOptions>, num_threads: i32);
        fn set_initial_trust_region_radius(
            self: Pin<&mut SolverOptions>,
            initial_trust_region_radius: f64,
        );
        fn set_max_trust_region_radius(self: Pin<&mut SolverOptions>, max_trust_region_radius: f64);
        fn set_min_trust_region_radius(self: Pin<&mut SolverOptions>, min_trust_region_radius: f64);
        fn set_min_relative_decrease(self: Pin<&mut SolverOptions>, min_relative_decrease: f64);
        fn set_min_lm_diagonal(self: Pin<&mut SolverOptions>, min_lm_diagonal: f64);
        fn set_max_lm_diagonal(self: Pin<&mut SolverOptions>, max_lm_diagonal: f64);
        fn set_max_num_consecutive_invalid_steps(
            self: Pin<&mut SolverOptions>,
            max_num_consecutive_invalid_steps: i32,
        );
        // Tolerance setters.
        fn set_function_tolerance(self: Pin<&mut SolverOptions>, function_tolerance: f64);
        fn set_gradient_tolerance(self: Pin<&mut SolverOptions>, gradient_tolerance: f64);
        fn set_parameter_tolerance(self: Pin<&mut SolverOptions>, parameter_tolerance: f64);
        // Setters selecting the linear solver, preconditioner and linear algebra backends.
        fn set_linear_solver_type(
            self: Pin<&mut SolverOptions>,
            linear_solver_type: LinearSolverType,
        );
        fn set_preconditioner_type(
            self: Pin<&mut SolverOptions>,
            preconditioner_type: PreconditionerType,
        );
        fn set_visibility_clustering_type(
            self: Pin<&mut SolverOptions>,
            visibility_clustering_type: VisibilityClusteringType,
        );
        // Takes the handles returned by `add_residual_block`.
        fn set_residual_blocks_for_subset_preconditioner(
            self: Pin<&mut SolverOptions>,
            residual_blocks: &[SharedPtr<ResidualBlockId>],
        );
        fn set_dense_linear_algebra_library_type(
            self: Pin<&mut SolverOptions>,
            dense_linear_algebra_library_type: DenseLinearAlgebraLibraryType,
        );
        fn set_sparse_linear_algebra_library_type(
            self: Pin<&mut SolverOptions>,
            sparse_linear_algebra_library_type: SparseLinearAlgebraLibraryType,
        );
        // Logging and problem-dump setters.
        fn set_logging_type(self: Pin<&mut SolverOptions>, logging_type: LoggingType);
        fn set_minimizer_progress_to_stdout(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_trust_region_minimizer_iterations_to_dump(
            self: Pin<&mut SolverOptions>,
            iterations_to_dump: &[i32],
        );
        fn set_trust_region_problem_dump_directory(
            self: Pin<&mut SolverOptions>,
            directory: Pin<&CxxString>,
        );
        fn set_trust_region_problem_dump_format_type(
            self: Pin<&mut SolverOptions>,
            trust_region_problem_dump_format_type: DumpFormatType,
        );
        // Gradient-checking setters.
        fn set_check_gradients(self: Pin<&mut SolverOptions>, yes: bool);
        fn set_gradient_check_relative_precision(
            self: Pin<&mut SolverOptions>,
            gradient_check_relative_precision: f64,
        );
        fn set_gradient_check_numeric_derivative_relative_step_size(
            self: Pin<&mut SolverOptions>,
            gradient_check_numeric_derivative_relative_step_size: f64,
        );
        fn set_update_state_every_iteration(self: Pin<&mut SolverOptions>, yes: bool);

        /// Creates a new C++ options object owned by the returned `UniquePtr`.
        /// NOTE(review): initial option values are decided on the C++ side — not visible here.
        fn new_solver_options() -> UniquePtr<SolverOptions>;

        /// Opaque C++ object that [`solve`] fills in; read through the accessors below.
        /// NOTE(review): presumably a wrapper around `ceres::Solver::Summary` — confirm.
        type SolverSummary;
        // Reports are returned as C++ strings owned by the caller.
        fn brief_report(self: &SolverSummary) -> UniquePtr<CxxString>;
        fn full_report(self: &SolverSummary) -> UniquePtr<CxxString>;
        fn is_solution_usable(self: &SolverSummary) -> bool;
        fn initial_cost(self: &SolverSummary) -> f64;
        fn final_cost(self: &SolverSummary) -> f64;
        fn fixed_cost(self: &SolverSummary) -> f64;
        fn num_successful_steps(self: &SolverSummary) -> i32;
        fn num_unsuccessful_steps(self: &SolverSummary) -> i32;
        fn num_inner_iteration_steps(self: &SolverSummary) -> i32;
        fn num_line_search_steps(self: &SolverSummary) -> i32;
        /// Creates a new C++ summary object owned by the returned `UniquePtr`.
        fn new_solver_summary() -> UniquePtr<SolverSummary>;

        /// Runs the solver on `problem` with `options`, writing its outcome into `summary`.
        ///
        /// The result is written into the parameter storage registered via
        /// [`add_residual_block`]: the test in this file reads the optimized value from its own
        /// array after this call returns.
        fn solve(
            options: &SolverOptions,
            problem: Pin<&mut Problem>,
            summary: Pin<&mut SolverSummary>,
        );
    }
}
386
/// Owning wrapper around a boxed cost callback.
///
/// This is the type registered with the bridge as the `extern "Rust"` type
/// `RustCostFunction`. The wrapped closure receives the three raw pointers passed to
/// [`RustCostFunction::evaluate`] and returns a `bool`; it may borrow data for at most `'cost`.
pub struct RustCostFunction<'cost>(
    pub Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost>,
);

impl RustCostFunction<'_> {
    /// Invokes the wrapped closure with the given pointers and returns its result unchanged.
    ///
    /// The pointers are forwarded as-is; whatever validity requirements the closure places on
    /// them are the caller's responsibility.
    pub fn evaluate(
        &self,
        parameters: *const *const f64,
        residuals: *mut f64,
        jacobians: *mut *mut f64,
    ) -> bool {
        let Self(callback) = self;
        callback(parameters, residuals, jacobians)
    }
}

impl<'cost> From<Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost>>
    for RustCostFunction<'cost>
{
    /// Wraps an already boxed callback without re-allocating.
    fn from(
        callback: Box<dyn Fn(*const *const f64, *mut f64, *mut *mut f64) -> bool + 'cost>,
    ) -> Self {
        RustCostFunction(callback)
    }
}
411
/// Owning wrapper around a boxed loss callback.
///
/// This is the type registered with the bridge as the `extern "Rust"` type
/// `RustLossFunction`. The wrapped closure receives an `f64` and a raw output pointer and
/// returns nothing.
pub struct RustLossFunction(pub Box<dyn Fn(f64, *mut f64)>);

impl RustLossFunction {
    /// Invokes the wrapped closure with `sq_norm` and `out`, forwarding both unchanged.
    pub fn evaluate(&self, sq_norm: f64, out: *mut f64) {
        let Self(callback) = self;
        callback(sq_norm, out);
    }
}

impl From<Box<dyn Fn(f64, *mut f64)>> for RustLossFunction {
    /// Wraps an already boxed callback without re-allocating.
    fn from(callback: Box<dyn Fn(f64, *mut f64)>) -> Self {
        RustLossFunction(callback)
    }
}
425
#[cfg(test)]
mod tests {
    use super::*;
    use std::ptr::slice_from_raw_parts_mut;

    use approx::assert_abs_diff_eq;
    use cxx::UniquePtr;

    /// Cost callback for the scalar residual `r(x) = x - 3`, whose derivative is `1`.
    ///
    /// Reads one value from the first parameter block, writes one residual and, when both the
    /// Jacobian array and its first entry are non-null, writes one derivative.
    fn cost_evaluate(
        parameters: *const *const f64,
        residuals: *mut f64,
        jacobians: *mut *mut f64,
    ) -> bool {
        // SAFETY: relies on the caller passing at least one parameter block with at least one
        // value, as registered by `end_to_end` (one block of size 1).
        let x = unsafe { **parameters };
        // SAFETY: relies on the caller providing room for the single residual registered by
        // `end_to_end`.
        unsafe {
            *residuals = x - 3.0;
        }
        // The Jacobian array and each entry in it may be null; only write when both are present.
        if !jacobians.is_null() {
            // SAFETY: `jacobians` is non-null and is assumed to hold one entry per parameter
            // block, of which there is exactly one.
            let d_dx = unsafe { *jacobians };
            if !d_dx.is_null() {
                // SAFETY: `d_dx` is non-null and is assumed to have room for the 1x1 Jacobian.
                unsafe {
                    *d_dx = 1.0;
                }
            }
        }
        true
    }

    /// Loss callback that writes `sq_norm`, `1` and `0` into the three output slots, i.e. the
    /// identity function followed by its first and second derivative.
    fn loss_evaluate(sq_norm: f64, out: *mut f64) {
        let out = slice_from_raw_parts_mut(out, 3);
        // SAFETY: relies on the caller passing a pointer to three writable `f64` values.
        unsafe {
            (*out)[0] = sq_norm;
            (*out)[1] = 1.0;
            (*out)[2] = 0.0;
        }
    }

    /// Builds a one-parameter problem with residual `x - 3`, solves it from `x = 0` with the
    /// given loss function and checks that the solver moved `x` to `3`.
    fn end_to_end(loss: UniquePtr<ffi::LossFunction>) {
        let parameter_block_sizes = [1];
        let mut x_init = [0.0_f64];
        let parameter_blocks = [x_init.as_mut_ptr()];

        let rust_cost_function = RustCostFunction(Box::new(cost_evaluate));
        let cost_function = ffi::new_callback_cost_function(
            Box::new(rust_cost_function),
            1,
            &parameter_block_sizes,
        );

        let mut problem = ffi::new_problem();
        // SAFETY: `parameter_blocks` holds `parameter_blocks.len()` pointers, each pointing into
        // `x_init`, which outlives `problem` and the `solve` call below.
        unsafe {
            ffi::add_residual_block(
                problem.as_mut().unwrap(),
                cost_function,
                loss,
                parameter_blocks.as_ptr(),
                // The array has a single element, so the cast cannot truncate.
                parameter_blocks.len() as i32,
            );
        }

        let mut options = ffi::new_solver_options();
        options
            .as_mut()
            .unwrap()
            .set_logging_type(ffi::LoggingType::SILENT);

        let mut summary = ffi::new_solver_summary();
        ffi::solve(
            options.as_ref().unwrap(),
            problem.as_mut().unwrap(),
            summary.as_mut().unwrap(),
        );

        assert_abs_diff_eq!(x_init[0], 3.0, epsilon = 1e-8);
    }

    /// A null loss function pointer is passed straight through to `add_residual_block`.
    #[test]
    fn end_to_end_no_loss() {
        end_to_end(UniquePtr::null());
    }

    /// Loss function implemented in Rust and called back from C++.
    #[test]
    fn end_to_end_custom_loss() {
        let rust_loss_function = RustLossFunction(Box::new(loss_evaluate));
        let loss_function = ffi::new_callback_loss_function(Box::new(rust_loss_function));
        end_to_end(loss_function);
    }

    /// Loss function created by one of the C++-side factories.
    #[test]
    fn end_to_end_stock_loss() {
        end_to_end(ffi::new_arctan_loss(1.0));
    }
}