1#![allow(clippy::cast_precision_loss)]
27
28use crate::{AlignError, AlignResult};
29
/// Tuning parameters for the Levenberg–Marquardt loop run by the bundle adjuster.
#[derive(Debug, Clone)]
pub struct BundleAdjustConfig {
    /// Upper bound on the number of solver iterations.
    pub max_iterations: usize,
    /// Damping factor used for the first iteration.
    pub initial_lambda: f64,
    /// Multiplier applied to the damping factor after a rejected step.
    pub lambda_up_factor: f64,
    /// Multiplier applied to the damping factor after an accepted step.
    pub lambda_down_factor: f64,
    /// Convergence threshold on the step norm relative to the parameter norm.
    pub param_tolerance: f64,
    /// Convergence threshold on the relative error reduction of an accepted step.
    pub error_tolerance: f64,
}

impl Default for BundleAdjustConfig {
    /// Returns the stock settings: 50 iterations, damping starting at `1e-3`
    /// and scaled by `10` / `0.1`, tolerances of `1e-8` (parameters) and
    /// `1e-10` (error).
    fn default() -> Self {
        BundleAdjustConfig {
            // Iteration budget.
            max_iterations: 50,
            // Damping schedule.
            initial_lambda: 1e-3,
            lambda_up_factor: 10.0,
            lambda_down_factor: 0.1,
            // Stopping criteria.
            param_tolerance: 1e-8,
            error_tolerance: 1e-10,
        }
    }
}
59
/// A single 2D measurement of one 3D point as seen by one camera.
#[derive(Debug, Clone)]
pub struct Observation {
    /// Index of the observing camera in the camera slice given to the optimizer.
    pub camera_idx: usize,
    /// Index of the observed point in the point slice given to the optimizer.
    pub point_idx: usize,
    /// Measured horizontal image coordinate.
    pub pixel_x: f64,
    /// Measured vertical image coordinate.
    pub pixel_y: f64,
}

impl Observation {
    /// Builds an observation tying camera `camera_idx` to point `point_idx`
    /// at image position (`pixel_x`, `pixel_y`).
    #[must_use]
    pub fn new(camera_idx: usize, point_idx: usize, pixel_x: f64, pixel_y: f64) -> Self {
        Observation {
            pixel_x,
            pixel_y,
            camera_idx,
            point_idx,
        }
    }
}
85
/// Pose and intrinsics of one camera: seven scalars in total.
#[derive(Debug, Clone)]
pub struct CameraParams {
    /// Rotation as an axis-angle vector (direction = axis, length = angle in radians).
    pub rotation: [f64; 3],
    /// Translation added to a point after it has been rotated.
    pub translation: [f64; 3],
    /// Focal length that scales the perspective-divided coordinates.
    pub focal_length: f64,
}

impl CameraParams {
    /// Assembles a camera from its rotation, translation and focal length.
    #[must_use]
    pub fn new(rotation: [f64; 3], translation: [f64; 3], focal_length: f64) -> Self {
        CameraParams {
            rotation,
            translation,
            focal_length,
        }
    }

    /// Returns a camera with zero rotation, zero translation and unit focal length.
    #[must_use]
    pub fn identity() -> Self {
        let zero = [0.0; 3];
        Self::new(zero, zero, 1.0)
    }
}
119
/// A point in 3D space, stored as plain Cartesian coordinates.
#[derive(Debug, Clone)]
pub struct Point3D {
    /// X coordinate.
    pub x: f64,
    /// Y coordinate.
    pub y: f64,
    /// Z coordinate.
    pub z: f64,
}

impl Point3D {
    /// Creates a point from its three coordinates.
    #[must_use]
    pub fn new(x: f64, y: f64, z: f64) -> Self {
        Point3D { x, y, z }
    }
}
138
/// Outcome of a bundle-adjustment run.
#[derive(Debug, Clone)]
pub struct BundleAdjustResult {
    /// Refined camera parameters, in the same order as the input cameras.
    pub cameras: Vec<CameraParams>,
    /// Refined 3D points, in the same order as the input points.
    pub points: Vec<Point3D>,
    /// Sum of squared reprojection residuals at the returned parameters.
    pub final_error: f64,
    /// Number of solver iterations that were started.
    pub iterations: usize,
    /// `true` when a tolerance was met before the iteration budget ran out.
    pub converged: bool,
}
153
154pub struct BundleAdjuster {
156 pub config: BundleAdjustConfig,
158}
159
160impl Default for BundleAdjuster {
161 fn default() -> Self {
162 Self {
163 config: BundleAdjustConfig::default(),
164 }
165 }
166}
167
168impl BundleAdjuster {
169 #[must_use]
171 pub fn new(config: BundleAdjustConfig) -> Self {
172 Self { config }
173 }
174
175 pub fn optimize(
188 &self,
189 cameras: &[CameraParams],
190 points: &[Point3D],
191 observations: &[Observation],
192 ) -> AlignResult<BundleAdjustResult> {
193 if cameras.is_empty() {
194 return Err(AlignError::InsufficientData(
195 "Need at least one camera".to_string(),
196 ));
197 }
198 if points.is_empty() {
199 return Err(AlignError::InsufficientData(
200 "Need at least one 3D point".to_string(),
201 ));
202 }
203 if observations.is_empty() {
204 return Err(AlignError::InsufficientData(
205 "Need at least one observation".to_string(),
206 ));
207 }
208
209 let num_cam_params = cameras.len() * 7;
213 let num_point_params = points.len() * 3;
214 let total_params = num_cam_params + num_point_params;
215 let num_residuals = observations.len() * 2;
216
217 let mut params = vec![0.0_f64; total_params];
218
219 for (i, cam) in cameras.iter().enumerate() {
221 let base = i * 7;
222 params[base] = cam.rotation[0];
223 params[base + 1] = cam.rotation[1];
224 params[base + 2] = cam.rotation[2];
225 params[base + 3] = cam.translation[0];
226 params[base + 4] = cam.translation[1];
227 params[base + 5] = cam.translation[2];
228 params[base + 6] = cam.focal_length;
229 }
230
231 for (i, pt) in points.iter().enumerate() {
233 let base = num_cam_params + i * 3;
234 params[base] = pt.x;
235 params[base + 1] = pt.y;
236 params[base + 2] = pt.z;
237 }
238
239 let mut lambda = self.config.initial_lambda;
240 let mut current_error = self.compute_total_error(¶ms, cameras.len(), observations)?;
241 let mut converged = false;
242 let mut iter = 0;
243
244 for iteration in 0..self.config.max_iterations {
245 iter = iteration + 1;
246
247 let (jacobian, residuals) = self.compute_jacobian_and_residuals(
249 ¶ms,
250 cameras.len(),
251 points.len(),
252 observations,
253 )?;
254
255 let jtj = self.compute_jtj(&jacobian, total_params, num_residuals);
257 let jtr = self.compute_jtr(&jacobian, &residuals, total_params, num_residuals);
258
259 let delta = self.solve_normal_equations(&jtj, &jtr, total_params, lambda)?;
261
262 let delta_norm: f64 = delta.iter().map(|d| d * d).sum::<f64>().sqrt();
264 let param_norm: f64 = params.iter().map(|p| p * p).sum::<f64>().sqrt().max(1.0);
265
266 if delta_norm / param_norm < self.config.param_tolerance {
267 converged = true;
268 break;
269 }
270
271 let trial_params: Vec<f64> = params.iter().zip(&delta).map(|(p, d)| p + d).collect();
273
274 let trial_error =
275 self.compute_total_error(&trial_params, cameras.len(), observations)?;
276
277 if trial_error < current_error {
278 let error_reduction = (current_error - trial_error) / current_error.max(1e-15);
280 params = trial_params;
281 current_error = trial_error;
282 lambda *= self.config.lambda_down_factor;
283 lambda = lambda.max(1e-12);
284
285 if error_reduction < self.config.error_tolerance {
286 converged = true;
287 break;
288 }
289 } else {
290 lambda *= self.config.lambda_up_factor;
292 lambda = lambda.min(1e10);
293 }
294 }
295
296 let mut opt_cameras = Vec::with_capacity(cameras.len());
298 for i in 0..cameras.len() {
299 let base = i * 7;
300 opt_cameras.push(CameraParams::new(
301 [params[base], params[base + 1], params[base + 2]],
302 [params[base + 3], params[base + 4], params[base + 5]],
303 params[base + 6],
304 ));
305 }
306
307 let mut opt_points = Vec::with_capacity(points.len());
308 for i in 0..points.len() {
309 let base = num_cam_params + i * 3;
310 opt_points.push(Point3D::new(
311 params[base],
312 params[base + 1],
313 params[base + 2],
314 ));
315 }
316
317 Ok(BundleAdjustResult {
318 cameras: opt_cameras,
319 points: opt_points,
320 final_error: current_error,
321 iterations: iter,
322 converged,
323 })
324 }
325
326 fn project(params: &[f64], cam_idx: usize, point_params: &[f64; 3]) -> (f64, f64) {
328 let base = cam_idx * 7;
329 let rx = params[base];
330 let ry = params[base + 1];
331 let rz = params[base + 2];
332 let tx = params[base + 3];
333 let ty = params[base + 4];
334 let tz = params[base + 5];
335 let f = params[base + 6];
336
337 let px = point_params[0];
338 let py = point_params[1];
339 let pz = point_params[2];
340
341 let theta = (rx * rx + ry * ry + rz * rz).sqrt();
343 let (r00, r01, r02, r10, r11, r12, r20, r21, r22) = if theta < 1e-10 {
344 (1.0, -rz, ry, rz, 1.0, -rx, -ry, rx, 1.0)
346 } else {
347 let c = theta.cos();
348 let s = theta.sin();
349 let t = 1.0 - c;
350 let kx = rx / theta;
351 let ky = ry / theta;
352 let kz = rz / theta;
353 (
354 t * kx * kx + c,
355 t * kx * ky - s * kz,
356 t * kx * kz + s * ky,
357 t * kx * ky + s * kz,
358 t * ky * ky + c,
359 t * ky * kz - s * kx,
360 t * kx * kz - s * ky,
361 t * ky * kz + s * kx,
362 t * kz * kz + c,
363 )
364 };
365
366 let cx = r00 * px + r01 * py + r02 * pz + tx;
368 let cy = r10 * px + r11 * py + r12 * pz + ty;
369 let cz = r20 * px + r21 * py + r22 * pz + tz;
370
371 if cz.abs() < 1e-10 {
373 return (0.0, 0.0);
374 }
375
376 let proj_x = f * cx / cz;
377 let proj_y = f * cy / cz;
378
379 (proj_x, proj_y)
380 }
381
382 fn compute_total_error(
384 &self,
385 params: &[f64],
386 num_cameras: usize,
387 observations: &[Observation],
388 ) -> AlignResult<f64> {
389 let num_cam_params = num_cameras * 7;
390 let mut total = 0.0_f64;
391
392 for obs in observations {
393 let pt_base = num_cam_params + obs.point_idx * 3;
394 if pt_base + 2 >= params.len() {
395 return Err(AlignError::InvalidConfig(
396 "Point index out of range".to_string(),
397 ));
398 }
399
400 let point = [params[pt_base], params[pt_base + 1], params[pt_base + 2]];
401 let (px, py) = Self::project(params, obs.camera_idx, &point);
402
403 let rx = px - obs.pixel_x;
404 let ry = py - obs.pixel_y;
405 total += rx * rx + ry * ry;
406 }
407
408 Ok(total)
409 }
410
411 fn compute_jacobian_and_residuals(
413 &self,
414 params: &[f64],
415 num_cameras: usize,
416 num_points: usize,
417 observations: &[Observation],
418 ) -> AlignResult<(Vec<f64>, Vec<f64>)> {
419 let num_cam_params = num_cameras * 7;
420 let total_params = num_cam_params + num_points * 3;
421 let num_residuals = observations.len() * 2;
422
423 let mut jacobian = vec![0.0_f64; num_residuals * total_params];
424 let mut residuals = vec![0.0_f64; num_residuals];
425
426 let epsilon = 1e-7;
427
428 for (obs_idx, obs) in observations.iter().enumerate() {
429 let pt_base = num_cam_params + obs.point_idx * 3;
430 let point = [params[pt_base], params[pt_base + 1], params[pt_base + 2]];
431 let (px, py) = Self::project(params, obs.camera_idx, &point);
432
433 let res_base = obs_idx * 2;
434 residuals[res_base] = px - obs.pixel_x;
435 residuals[res_base + 1] = py - obs.pixel_y;
436
437 let cam_base = obs.camera_idx * 7;
439 for p in 0..7 {
440 let param_idx = cam_base + p;
441 let mut params_plus = params.to_vec();
442 params_plus[param_idx] += epsilon;
443
444 let pt_p = [
445 params_plus[pt_base],
446 params_plus[pt_base + 1],
447 params_plus[pt_base + 2],
448 ];
449 let (px_plus, py_plus) = Self::project(¶ms_plus, obs.camera_idx, &pt_p);
450
451 jacobian[res_base * total_params + param_idx] = (px_plus - px) / epsilon;
452 jacobian[(res_base + 1) * total_params + param_idx] = (py_plus - py) / epsilon;
453 }
454
455 for p in 0..3 {
457 let param_idx = pt_base + p;
458 let mut point_plus = point;
459 point_plus[p] += epsilon;
460
461 let (px_plus, py_plus) = Self::project(params, obs.camera_idx, &point_plus);
462
463 jacobian[res_base * total_params + param_idx] = (px_plus - px) / epsilon;
464 jacobian[(res_base + 1) * total_params + param_idx] = (py_plus - py) / epsilon;
465 }
466 }
467
468 Ok((jacobian, residuals))
469 }
470
471 fn compute_jtj(&self, j: &[f64], n_params: usize, n_residuals: usize) -> Vec<f64> {
473 let mut jtj = vec![0.0_f64; n_params * n_params];
474
475 for r in 0..n_residuals {
476 for i in 0..n_params {
477 let ji = j[r * n_params + i];
478 if ji.abs() < 1e-15 {
479 continue;
480 }
481 for k in i..n_params {
482 let jk = j[r * n_params + k];
483 if jk.abs() < 1e-15 {
484 continue;
485 }
486 let val = ji * jk;
487 jtj[i * n_params + k] += val;
488 if i != k {
489 jtj[k * n_params + i] += val;
490 }
491 }
492 }
493 }
494
495 jtj
496 }
497
498 fn compute_jtr(&self, j: &[f64], r: &[f64], n_params: usize, n_residuals: usize) -> Vec<f64> {
500 let mut jtr = vec![0.0_f64; n_params];
501
502 for res in 0..n_residuals {
503 let rv = r[res];
504 if rv.abs() < 1e-15 {
505 continue;
506 }
507 for p in 0..n_params {
508 jtr[p] -= j[res * n_params + p] * rv;
509 }
510 }
511
512 jtr
513 }
514
515 fn solve_normal_equations(
517 &self,
518 jtj: &[f64],
519 jtr: &[f64],
520 n: usize,
521 lambda: f64,
522 ) -> AlignResult<Vec<f64>> {
523 let mut a = jtj.to_vec();
525 for i in 0..n {
526 a[i * n + i] += lambda * a[i * n + i].max(1e-6);
527 }
528
529 let mut b = jtr.to_vec();
531
532 for col in 0..n {
533 let mut max_row = col;
535 let mut max_val = a[col * n + col].abs();
536 for row in (col + 1)..n {
537 let val = a[row * n + col].abs();
538 if val > max_val {
539 max_val = val;
540 max_row = row;
541 }
542 }
543
544 if max_val < 1e-14 {
545 continue;
547 }
548
549 if max_row != col {
551 for j in 0..n {
552 a.swap(col * n + j, max_row * n + j);
553 }
554 b.swap(col, max_row);
555 }
556
557 let pivot = a[col * n + col];
559 for row in (col + 1)..n {
560 let factor = a[row * n + col] / pivot;
561 for j in col..n {
562 a[row * n + j] -= factor * a[col * n + j];
563 }
564 b[row] -= factor * b[col];
565 }
566 }
567
568 let mut x = vec![0.0_f64; n];
570 for col in (0..n).rev() {
571 if a[col * n + col].abs() < 1e-14 {
572 continue;
573 }
574 let mut sum = b[col];
575 for j in (col + 1)..n {
576 sum -= a[col * n + j] * x[j];
577 }
578 x[col] = sum / a[col * n + col];
579 }
580
581 Ok(x)
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the documented iteration budget and damping.
    #[test]
    fn test_config_default() {
        let config = BundleAdjustConfig::default();
        assert_eq!(config.max_iterations, 50);
        assert!((config.initial_lambda - 1e-3).abs() < 1e-10);
    }

    /// The identity camera has no rotation, no translation and unit focal length.
    #[test]
    fn test_camera_params_identity() {
        let cam = CameraParams::identity();
        assert_eq!(cam.rotation, [0.0; 3]);
        assert_eq!(cam.translation, [0.0; 3]);
        assert!((cam.focal_length - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_point3d_creation() {
        let pt = Point3D::new(1.0, 2.0, 3.0);
        assert!((pt.x - 1.0).abs() < 1e-10);
        assert!((pt.y - 2.0).abs() < 1e-10);
        assert!((pt.z - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_observation_creation() {
        let obs = Observation::new(0, 1, 100.0, 200.0);
        assert_eq!(obs.camera_idx, 0);
        assert_eq!(obs.point_idx, 1);
    }

    #[test]
    fn test_empty_cameras_error() {
        let ba = BundleAdjuster::default();
        let result = ba.optimize(&[], &[Point3D::new(0.0, 0.0, 1.0)], &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_points_error() {
        let ba = BundleAdjuster::default();
        let result = ba.optimize(&[CameraParams::identity()], &[], &[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_observations_error() {
        let ba = BundleAdjuster::default();
        let result = ba.optimize(
            &[CameraParams::identity()],
            &[Point3D::new(0.0, 0.0, 1.0)],
            &[],
        );
        assert!(result.is_err());
    }

    /// An observation pointing past the end of the point list is rejected.
    #[test]
    fn test_point_index_out_of_range_error() {
        let ba = BundleAdjuster::default();
        let result = ba.optimize(
            &[CameraParams::identity()],
            &[Point3D::new(0.0, 0.0, 1.0)],
            &[Observation::new(0, 5, 0.0, 0.0)],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_projection_identity_camera() {
        // Packed camera block: [rx, ry, rz, tx, ty, tz, f].
        let params = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0];
        let point = [1.0, 2.0, 5.0];
        let (px, py) = BundleAdjuster::project(&params, 0, &point);
        // f * X / Z = 1 / 5.
        assert!((px - 0.2).abs() < 1e-6, "px={px}");
        // f * Y / Z = 2 / 5.
        assert!((py - 0.4).abs() < 1e-6, "py={py}");
    }

    #[test]
    fn test_projection_with_translation() {
        // Translation of +1 along x.
        let params = vec![0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0];
        let point = [1.0, 0.0, 5.0];
        let (px, _py) = BundleAdjuster::project(&params, 0, &point);
        // (1 + 1) / 5.
        assert!((px - 0.4).abs() < 1e-6, "px={px}");
    }

    #[test]
    fn test_simple_optimization() {
        let cameras = vec![CameraParams::new([0.0; 3], [0.0; 3], 100.0)];

        let points = vec![
            Point3D::new(0.5, 0.5, 5.0),
            Point3D::new(-0.5, 0.3, 4.0),
            Point3D::new(0.0, -0.5, 6.0),
        ];

        // Generate noise-free observations from the same camera.
        let params: Vec<f64> = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.0];
        let mut observations = Vec::new();
        for (i, pt) in points.iter().enumerate() {
            let point = [pt.x, pt.y, pt.z];
            let (px, py) = BundleAdjuster::project(&params, 0, &point);
            observations.push(Observation::new(0, i, px, py));
        }

        let ba = BundleAdjuster::new(BundleAdjustConfig {
            max_iterations: 20,
            ..BundleAdjustConfig::default()
        });

        let result = ba
            .optimize(&cameras, &points, &observations)
            .expect("should succeed");

        // Starting at the exact solution, the error must stay small.
        assert!(
            result.final_error < 1.0,
            "final_error={}",
            result.final_error
        );
    }

    #[test]
    fn test_optimization_with_perturbation() {
        let true_cameras = vec![CameraParams::new([0.0; 3], [0.0; 3], 100.0)];
        let true_points = vec![
            Point3D::new(1.0, 0.0, 5.0),
            Point3D::new(-1.0, 0.0, 5.0),
            Point3D::new(0.0, 1.0, 5.0),
            Point3D::new(0.0, -1.0, 5.0),
        ];

        // Observations come from the unperturbed geometry.
        let true_params: Vec<f64> = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.0];
        let mut observations = Vec::new();
        for (i, pt) in true_points.iter().enumerate() {
            let point = [pt.x, pt.y, pt.z];
            let (px, py) = BundleAdjuster::project(&true_params, 0, &point);
            observations.push(Observation::new(0, i, px, py));
        }

        // Start the optimizer from slightly displaced points.
        let perturbed_points = vec![
            Point3D::new(1.1, 0.1, 5.1),
            Point3D::new(-0.9, 0.1, 4.9),
            Point3D::new(0.1, 1.1, 5.1),
            Point3D::new(0.1, -0.9, 4.9),
        ];

        let ba = BundleAdjuster::new(BundleAdjustConfig {
            max_iterations: 30,
            ..BundleAdjustConfig::default()
        });

        let result = ba
            .optimize(&true_cameras, &perturbed_points, &observations)
            .expect("should succeed");

        // The optimizer must reduce the reprojection error.
        assert!(
            result.final_error < 10.0,
            "final_error={}",
            result.final_error
        );
    }

    #[test]
    fn test_two_camera_optimization() {
        // Two cameras separated by a baseline along x.
        let cameras = vec![
            CameraParams::new([0.0; 3], [0.0, 0.0, 0.0], 100.0),
            CameraParams::new([0.0; 3], [2.0, 0.0, 0.0], 100.0),
        ];
        let points = vec![Point3D::new(1.0, 0.0, 5.0)];

        // Both camera blocks packed back to back.
        let params1: Vec<f64> = vec![
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 100.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 100.0,
        ];
        let point = [1.0, 0.0, 5.0];

        let (px0, py0) = BundleAdjuster::project(&params1, 0, &point);
        let (px1, py1) = BundleAdjuster::project(&params1, 1, &point);

        let observations = vec![
            Observation::new(0, 0, px0, py0),
            Observation::new(1, 0, px1, py1),
        ];

        let ba = BundleAdjuster::default();
        let result = ba
            .optimize(&cameras, &points, &observations)
            .expect("should succeed");

        assert!(
            result.final_error < 1.0,
            "final_error={}",
            result.final_error
        );
        assert_eq!(result.cameras.len(), 2);
        assert_eq!(result.points.len(), 1);
    }

    #[test]
    fn test_bundle_adjust_result_fields() {
        let result = BundleAdjustResult {
            cameras: vec![CameraParams::identity()],
            points: vec![Point3D::new(0.0, 0.0, 1.0)],
            final_error: 0.1,
            iterations: 5,
            converged: true,
        };
        assert!(result.converged);
        assert_eq!(result.iterations, 5);
    }
}