1pub mod deformable;
15
16use scirs2_core::ndarray::{Array1, Array2, Axis};
17use scirs2_core::numeric::Complex64;
18use scirs2_fft::{fft2, fftfreq, ifft2};
19use std::f64::consts::PI;
20
21use crate::error::{NdimageError, NdimageResult};
22
/// Result of a pure-translation registration such as [`phase_correlation`].
///
/// For [`phase_correlation`] the shifts are wrapped into roughly `(-N/2, N/2]` pixels per axis,
/// and a positive value means the content of the moving image sits at higher indices than in
/// the reference image.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TranslationResult {
    /// Shift along axis 0 (rows), in pixels; may be fractional after sub-pixel refinement.
    pub shift_y: f64,
    /// Shift along axis 1 (columns), in pixels; may be fractional after sub-pixel refinement.
    pub shift_x: f64,
    /// Real part of the phase-correlation surface at its maximum.
    pub peak_value: f64,
}
37
38#[derive(Debug, Clone)]
45pub struct AffineTransform2D {
46 pub matrix: Array2<f64>,
48 pub residual: f64,
50}
51
/// A 2-D rigid transform applied as
/// `x' = cos(angle) x - sin(angle) y + tx`, `y' = sin(angle) x + cos(angle) y + ty`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RigidTransform2D {
    /// Rotation angle in radians.
    pub angle: f64,
    /// Translation along x (column 0 of the point arrays).
    pub tx: f64,
    /// Translation along y (column 1 of the point arrays).
    pub ty: f64,
    /// Mean squared point-to-point distance of the fit that produced this transform.
    pub residual: f64,
}

/// Outcome of [`icp_registration`].
#[derive(Debug, Clone, PartialEq)]
pub struct IcpResult {
    /// Accumulated rigid transform mapping the source points onto the target points.
    pub transform: RigidTransform2D,
    /// Number of correspondence evaluations performed (equals `mse_history.len()`).
    pub iterations: usize,
    /// Mean squared nearest-neighbour distance recorded at the start of every iteration.
    pub mse_history: Vec<f64>,
    /// `true` when the MSE change dropped below the configured tolerance.
    pub converged: bool,
}
77
/// Tuning parameters for [`icp_registration`].
#[derive(Debug, Clone, PartialEq)]
pub struct IcpConfig {
    /// Upper bound on the number of ICP iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the absolute change of the mean squared correspondence
    /// distance between consecutive iterations.
    pub tolerance: f64,
    /// Optional cut-off: nearest-neighbour pairs farther apart than this are ignored.
    pub max_distance: Option<f64>,
}

impl Default for IcpConfig {
    /// 100 iterations, tolerance `1e-8`, no distance cut-off.
    fn default() -> Self {
        Self {
            max_iterations: 100,
            tolerance: 1e-8,
            max_distance: None,
        }
    }
}
98
/// Tuning parameters for [`pyramid_registration`].
#[derive(Debug, Clone, PartialEq)]
pub struct PyramidConfig {
    /// Number of pyramid levels, including the full-resolution image; must be at least 1.
    pub levels: usize,
    /// Size ratio between consecutive pyramid levels; must be greater than 1.
    pub scale_factor: f64,
}

impl Default for PyramidConfig {
    /// Three levels, each half the size of the previous one.
    fn default() -> Self {
        Self {
            levels: 3,
            scale_factor: 2.0,
        }
    }
}
116
/// Quality measures produced by [`registration_metrics`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RegistrationMetrics {
    /// Target registration error: root-mean-square distance between corresponding landmarks
    /// (0 when landmarks were not supplied).
    pub tre: f64,
    /// Histogram-based mutual information (natural log) between the two images
    /// (0 when images were not supplied or one of them is constant).
    pub mutual_information: f64,
    /// Normalised cross-correlation between the two images
    /// (0 when images were not supplied or one of them is constant).
    pub ncc: f64,
}
128
129pub fn phase_correlation(
148 reference: &Array2<f64>,
149 moving: &Array2<f64>,
150) -> NdimageResult<TranslationResult> {
151 let (ny, nx) = reference.dim();
152 if moving.dim() != (ny, nx) {
153 return Err(NdimageError::DimensionError(format!(
154 "Image shapes must match: reference ({},{}) vs moving ({},{})",
155 ny,
156 nx,
157 moving.nrows(),
158 moving.ncols()
159 )));
160 }
161 if ny == 0 || nx == 0 {
162 return Err(NdimageError::InvalidInput(
163 "Images must be non-empty".into(),
164 ));
165 }
166
167 let spec_ref = fft2(reference, None, None, None)
169 .map_err(|e| NdimageError::ComputationError(format!("FFT of reference failed: {}", e)))?;
170 let spec_mov = fft2(moving, None, None, None).map_err(|e| {
171 NdimageError::ComputationError(format!("FFT of moving image failed: {}", e))
172 })?;
173
174 let mut cross_power = Array2::<Complex64>::zeros((ny, nx));
176 for i in 0..ny {
177 for j in 0..nx {
178 let prod = spec_ref[[i, j]].conj() * spec_mov[[i, j]];
179 let mag = prod.norm();
180 cross_power[[i, j]] = if mag > 1e-15 {
181 prod / mag
182 } else {
183 Complex64::new(0.0, 0.0)
184 };
185 }
186 }
187
188 let corr_complex = ifft2(&cross_power, None, None, None).map_err(|e| {
190 NdimageError::ComputationError(format!("IFFT of cross-power failed: {}", e))
191 })?;
192
193 let mut best_val = f64::NEG_INFINITY;
195 let mut best_i = 0usize;
196 let mut best_j = 0usize;
197 for i in 0..ny {
198 for j in 0..nx {
199 let v = corr_complex[[i, j]].re;
200 if v > best_val {
201 best_val = v;
202 best_i = i;
203 best_j = j;
204 }
205 }
206 }
207
208 let sub_y = subpixel_1d(
210 corr_complex[[(best_i + ny - 1) % ny, best_j]].re,
211 best_val,
212 corr_complex[[(best_i + 1) % ny, best_j]].re,
213 );
214 let sub_x = subpixel_1d(
215 corr_complex[[best_i, (best_j + nx - 1) % nx]].re,
216 best_val,
217 corr_complex[[best_i, (best_j + 1) % nx]].re,
218 );
219
220 let shift_y = if best_i as f64 + sub_y > ny as f64 / 2.0 {
222 best_i as f64 + sub_y - ny as f64
223 } else {
224 best_i as f64 + sub_y
225 };
226 let shift_x = if best_j as f64 + sub_x > nx as f64 / 2.0 {
227 best_j as f64 + sub_x - nx as f64
228 } else {
229 best_j as f64 + sub_x
230 };
231
232 Ok(TranslationResult {
233 shift_y,
234 shift_x,
235 peak_value: best_val,
236 })
237}
238
/// Offset of the vertex of the parabola through three equally spaced samples.
///
/// `y_minus`, `y_center` and `y_plus` are the values at positions -1, 0 and +1. The returned
/// offset is relative to the centre sample, and is positive when the true peak lies towards
/// `y_plus`. Returns 0 when the three samples are (numerically) collinear. The result is
/// clamped to `[-0.5, 0.5]`; this range is always respected when the centre sample is a local
/// maximum, so the clamp only guards against misuse.
fn subpixel_1d(y_minus: f64, y_center: f64, y_plus: f64) -> f64 {
    // Fitting y(t) = a t^2 + b t + c through t = -1, 0, 1 gives
    //   2a = y_minus - 2 y_center + y_plus,   b = (y_plus - y_minus) / 2,
    // with the vertex at t = -b / (2a) = (y_minus - y_plus) / (2 * 2a).
    let curvature = y_minus - 2.0 * y_center + y_plus;
    if curvature.abs() < 1e-15 {
        return 0.0;
    }
    (0.5 * (y_minus - y_plus) / curvature).clamp(-0.5, 0.5)
}
249
250pub fn affine_registration(
264 source: &Array2<f64>,
265 target: &Array2<f64>,
266) -> NdimageResult<AffineTransform2D> {
267 let n = source.nrows();
268 if n < 3 {
269 return Err(NdimageError::InvalidInput(
270 "Need at least 3 point pairs for affine registration".into(),
271 ));
272 }
273 if source.ncols() != 2 || target.ncols() != 2 {
274 return Err(NdimageError::InvalidInput(
275 "Point arrays must have 2 columns (x, y)".into(),
276 ));
277 }
278 if target.nrows() != n {
279 return Err(NdimageError::DimensionError(
280 "source and target must have the same number of rows".into(),
281 ));
282 }
283
284 let m = 2 * n;
291 let mut a_mat = Array2::<f64>::zeros((m, 6));
292 let mut b_vec = Array1::<f64>::zeros(m);
293
294 for k in 0..n {
295 let sx = source[[k, 0]];
296 let sy = source[[k, 1]];
297 let r0 = 2 * k;
299 a_mat[[r0, 0]] = sx;
300 a_mat[[r0, 1]] = sy;
301 a_mat[[r0, 2]] = 1.0;
302 b_vec[r0] = target[[k, 0]];
303 let r1 = 2 * k + 1;
305 a_mat[[r1, 3]] = sx;
306 a_mat[[r1, 4]] = sy;
307 a_mat[[r1, 5]] = 1.0;
308 b_vec[r1] = target[[k, 1]];
309 }
310
311 let ata = a_mat.t().dot(&a_mat);
313 let atb = a_mat.t().dot(&b_vec);
314
315 let params = solve_6x6(&ata, &atb)?;
316
317 let mut matrix = Array2::<f64>::zeros((3, 3));
319 matrix[[0, 0]] = params[0];
320 matrix[[0, 1]] = params[1];
321 matrix[[0, 2]] = params[2];
322 matrix[[1, 0]] = params[3];
323 matrix[[1, 1]] = params[4];
324 matrix[[1, 2]] = params[5];
325 matrix[[2, 2]] = 1.0;
326
327 let predicted = a_mat.dot(¶ms);
329 let diff = &predicted - &b_vec;
330 let residual = diff.dot(&diff) / n as f64;
331
332 Ok(AffineTransform2D { matrix, residual })
333}
334
335fn solve_6x6(ata: &Array2<f64>, atb: &Array1<f64>) -> NdimageResult<Array1<f64>> {
337 let n = 6;
338 let mut l_mat = Array2::<f64>::zeros((n, n));
340 for i in 0..n {
341 for j in 0..=i {
342 let mut s = 0.0;
343 for k in 0..j {
344 s += l_mat[[i, k]] * l_mat[[j, k]];
345 }
346 if i == j {
347 let diag = ata[[i, i]] - s;
348 if diag <= 0.0 {
349 return Err(NdimageError::ComputationError(
350 "Matrix is not positive-definite (collinear points?)".into(),
351 ));
352 }
353 l_mat[[i, j]] = diag.sqrt();
354 } else {
355 l_mat[[i, j]] = (ata[[i, j]] - s) / l_mat[[j, j]];
356 }
357 }
358 }
359
360 let mut y = Array1::<f64>::zeros(n);
362 for i in 0..n {
363 let mut s = 0.0;
364 for k in 0..i {
365 s += l_mat[[i, k]] * y[k];
366 }
367 y[i] = (atb[i] - s) / l_mat[[i, i]];
368 }
369
370 let mut x = Array1::<f64>::zeros(n);
372 for i in (0..n).rev() {
373 let mut s = 0.0;
374 for k in (i + 1)..n {
375 s += l_mat[[k, i]] * x[k];
376 }
377 x[i] = (y[i] - s) / l_mat[[i, i]];
378 }
379
380 Ok(x)
381}
382
383pub fn rigid_registration(
394 source: &Array2<f64>,
395 target: &Array2<f64>,
396) -> NdimageResult<RigidTransform2D> {
397 let n = source.nrows();
398 if n < 2 {
399 return Err(NdimageError::InvalidInput(
400 "Need at least 2 point pairs for rigid registration".into(),
401 ));
402 }
403 if source.ncols() != 2 || target.ncols() != 2 {
404 return Err(NdimageError::InvalidInput(
405 "Point arrays must have 2 columns (x, y)".into(),
406 ));
407 }
408 if target.nrows() != n {
409 return Err(NdimageError::DimensionError(
410 "source and target must have the same number of rows".into(),
411 ));
412 }
413
414 let src_mean = source.mean_axis(Axis(0)).ok_or_else(|| {
416 NdimageError::ComputationError("Failed to compute source centroid".into())
417 })?;
418 let tgt_mean = target.mean_axis(Axis(0)).ok_or_else(|| {
419 NdimageError::ComputationError("Failed to compute target centroid".into())
420 })?;
421
422 let src_centered = source - &src_mean.view().insert_axis(Axis(0));
424 let tgt_centered = target - &tgt_mean.view().insert_axis(Axis(0));
425
426 let h = src_centered.t().dot(&tgt_centered);
428
429 let (u, _s, vt) = svd_2x2(h[[0, 0]], h[[0, 1]], h[[1, 0]], h[[1, 1]]);
431
432 let det = (u[[0, 0]] * u[[1, 1]] - u[[0, 1]] * u[[1, 0]])
435 * (vt[[0, 0]] * vt[[1, 1]] - vt[[0, 1]] * vt[[1, 0]]);
436 let sign = if det < 0.0 { -1.0 } else { 1.0 };
437
438 let mut d_mat = Array2::<f64>::zeros((2, 2));
439 d_mat[[0, 0]] = 1.0;
440 d_mat[[1, 1]] = sign;
441
442 let rot = vt.t().dot(&d_mat).dot(&u.t());
443 let angle = rot[[1, 0]].atan2(rot[[0, 0]]);
444
445 let rotated_mean = rot.dot(&src_mean);
447 let tx = tgt_mean[0] - rotated_mean[0];
448 let ty = tgt_mean[1] - rotated_mean[1];
449
450 let transformed = src_centered.dot(&rot.t());
452 let diff = &transformed - &tgt_centered;
453 let mse = diff.mapv(|v| v * v).sum() / n as f64;
454
455 Ok(RigidTransform2D {
456 angle,
457 tx,
458 ty,
459 residual: mse,
460 })
461}
462
463fn svd_2x2(a: f64, b: f64, c: f64, d: f64) -> (Array2<f64>, [f64; 2], Array2<f64>) {
466 let s1_sq = (a * a + b * b + c * c + d * d) / 2.0;
468 let det = a * d - b * c;
469 let tmp =
470 ((a * a + b * b - c * c - d * d).powi(2) + 4.0 * (a * c + b * d).powi(2)).sqrt() / 2.0;
471
472 let sigma1 = (s1_sq + tmp).sqrt();
473 let sigma2 = (s1_sq - tmp).max(0.0).sqrt();
474
475 let ata_00 = a * a + c * c;
477 let ata_01 = a * b + c * d;
478 let ata_11 = b * b + d * d;
479
480 let theta_v = if ata_01.abs() < 1e-15 {
482 0.0
483 } else {
484 0.5 * (2.0 * ata_01).atan2(ata_00 - ata_11)
485 };
486
487 let mut vt = Array2::<f64>::zeros((2, 2));
488 vt[[0, 0]] = theta_v.cos();
489 vt[[0, 1]] = theta_v.sin();
490 vt[[1, 0]] = -theta_v.sin();
491 vt[[1, 1]] = theta_v.cos();
492
493 let mut u = Array2::<f64>::zeros((2, 2));
495 if sigma1 > 1e-15 {
496 u[[0, 0]] = (a * vt[[0, 0]] + b * vt[[0, 1]]) / sigma1;
497 u[[1, 0]] = (c * vt[[0, 0]] + d * vt[[0, 1]]) / sigma1;
498 } else {
499 u[[0, 0]] = 1.0;
500 }
501 if sigma2 > 1e-15 {
502 u[[0, 1]] = (a * vt[[1, 0]] + b * vt[[1, 1]]) / sigma2;
503 u[[1, 1]] = (c * vt[[1, 0]] + d * vt[[1, 1]]) / sigma2;
504 } else {
505 u[[0, 1]] = -u[[1, 0]];
507 u[[1, 1]] = u[[0, 0]];
508 }
509
510 (u, [sigma1, sigma2], vt)
511}
512
513pub fn icp_registration(
526 source: &Array2<f64>,
527 target: &Array2<f64>,
528 config: Option<IcpConfig>,
529) -> NdimageResult<IcpResult> {
530 let cfg = config.unwrap_or_default();
531
532 if source.ncols() != 2 || target.ncols() != 2 {
533 return Err(NdimageError::InvalidInput(
534 "Point arrays must have 2 columns".into(),
535 ));
536 }
537 if source.nrows() < 2 || target.nrows() < 2 {
538 return Err(NdimageError::InvalidInput(
539 "Need at least 2 points in each set".into(),
540 ));
541 }
542
543 let n_src = source.nrows();
544 let mut current = source.to_owned();
545 let mut cum_angle: f64 = 0.0;
546 let mut cum_tx: f64 = 0.0;
547 let mut cum_ty: f64 = 0.0;
548 let mut mse_history = Vec::new();
549 let mut converged = false;
550
551 for iter in 0..cfg.max_iterations {
552 let (correspondences, mse) = find_correspondences(¤t, target, cfg.max_distance)?;
554
555 mse_history.push(mse);
556
557 if iter > 0 {
559 let prev = mse_history[iter - 1];
560 if (prev - mse).abs() < cfg.tolerance {
561 converged = true;
562 break;
563 }
564 }
565
566 if correspondences.is_empty() {
567 return Err(NdimageError::ComputationError(
568 "No valid correspondences found".into(),
569 ));
570 }
571
572 let n_match = correspondences.len();
574 let mut src_matched = Array2::<f64>::zeros((n_match, 2));
575 let mut tgt_matched = Array2::<f64>::zeros((n_match, 2));
576 for (k, &(si, ti)) in correspondences.iter().enumerate() {
577 src_matched[[k, 0]] = current[[si, 0]];
578 src_matched[[k, 1]] = current[[si, 1]];
579 tgt_matched[[k, 0]] = target[[ti, 0]];
580 tgt_matched[[k, 1]] = target[[ti, 1]];
581 }
582
583 let rigid = rigid_registration(&src_matched, &tgt_matched)?;
585
586 let cos_a = rigid.angle.cos();
588 let sin_a = rigid.angle.sin();
589 for k in 0..n_src {
590 let x = current[[k, 0]];
591 let y = current[[k, 1]];
592 current[[k, 0]] = cos_a * x - sin_a * y + rigid.tx;
593 current[[k, 1]] = sin_a * x + cos_a * y + rigid.ty;
594 }
595
596 let old_tx = cum_tx;
598 let old_ty = cum_ty;
599 let old_cos = cum_angle.cos();
600 let old_sin = cum_angle.sin();
601 cum_tx = cos_a * old_tx - sin_a * old_ty + rigid.tx;
602 cum_ty = sin_a * old_tx + cos_a * old_ty + rigid.ty;
603 cum_angle += rigid.angle;
604 }
605
606 let final_iters = mse_history.len();
607
608 Ok(IcpResult {
609 transform: RigidTransform2D {
610 angle: cum_angle,
611 tx: cum_tx,
612 ty: cum_ty,
613 residual: mse_history.last().copied().unwrap_or(f64::INFINITY),
614 },
615 iterations: final_iters,
616 mse_history,
617 converged,
618 })
619}
620
621fn find_correspondences(
624 source: &Array2<f64>,
625 target: &Array2<f64>,
626 max_dist: Option<f64>,
627) -> NdimageResult<(Vec<(usize, usize)>, f64)> {
628 let n_src = source.nrows();
629 let n_tgt = target.nrows();
630 let max_dist_sq = max_dist.map(|d| d * d);
631
632 let mut pairs = Vec::with_capacity(n_src);
633 let mut total_dist_sq = 0.0;
634
635 for si in 0..n_src {
636 let sx = source[[si, 0]];
637 let sy = source[[si, 1]];
638
639 let mut best_dist_sq = f64::INFINITY;
640 let mut best_ti = 0usize;
641
642 for ti in 0..n_tgt {
643 let dx = sx - target[[ti, 0]];
644 let dy = sy - target[[ti, 1]];
645 let d2 = dx * dx + dy * dy;
646 if d2 < best_dist_sq {
647 best_dist_sq = d2;
648 best_ti = ti;
649 }
650 }
651
652 let accept = match max_dist_sq {
653 Some(md2) => best_dist_sq <= md2,
654 None => true,
655 };
656
657 if accept {
658 pairs.push((si, best_ti));
659 total_dist_sq += best_dist_sq;
660 }
661 }
662
663 let mse = if pairs.is_empty() {
664 f64::INFINITY
665 } else {
666 total_dist_sq / pairs.len() as f64
667 };
668
669 Ok((pairs, mse))
670}
671
672pub fn pyramid_registration(
684 reference: &Array2<f64>,
685 moving: &Array2<f64>,
686 config: Option<PyramidConfig>,
687) -> NdimageResult<TranslationResult> {
688 let cfg = config.unwrap_or_default();
689 let (ny, nx) = reference.dim();
690 if moving.dim() != (ny, nx) {
691 return Err(NdimageError::DimensionError(
692 "Images must have the same shape for pyramid registration".into(),
693 ));
694 }
695 if cfg.levels == 0 {
696 return Err(NdimageError::InvalidInput(
697 "Number of pyramid levels must be >= 1".into(),
698 ));
699 }
700 if cfg.scale_factor <= 1.0 {
701 return Err(NdimageError::InvalidInput(
702 "Scale factor must be > 1.0".into(),
703 ));
704 }
705
706 let mut ref_pyramid = vec![reference.clone()];
708 let mut mov_pyramid = vec![moving.clone()];
709 for _ in 1..cfg.levels {
710 let ref_prev = ref_pyramid
711 .last()
712 .ok_or_else(|| NdimageError::ComputationError("Empty pyramid".into()))?;
713 let mov_prev = mov_pyramid
714 .last()
715 .ok_or_else(|| NdimageError::ComputationError("Empty pyramid".into()))?;
716 ref_pyramid.push(downsample_2x(ref_prev));
717 mov_pyramid.push(downsample_2x(mov_prev));
718 }
719
720 let mut cum_shift_y = 0.0;
722 let mut cum_shift_x = 0.0;
723 let mut best_peak = 0.0;
724
725 for level in (0..cfg.levels).rev() {
726 let ref_level = &ref_pyramid[level];
727 let mov_level = &mov_pyramid[level];
728
729 if ref_level.nrows() < 4 || ref_level.ncols() < 4 {
731 continue;
732 }
733
734 let result = phase_correlation(ref_level, mov_level)?;
735
736 if level == cfg.levels - 1 {
737 cum_shift_y = result.shift_y;
739 cum_shift_x = result.shift_x;
740 } else {
741 cum_shift_y = cum_shift_y * 2.0 + result.shift_y;
743 cum_shift_x = cum_shift_x * 2.0 + result.shift_x;
744 }
745 best_peak = result.peak_value;
746 }
747
748 Ok(TranslationResult {
749 shift_y: cum_shift_y,
750 shift_x: cum_shift_x,
751 peak_value: best_peak,
752 })
753}
754
755fn downsample_2x(image: &Array2<f64>) -> Array2<f64> {
757 let (ny, nx) = image.dim();
758 let out_ny = ny / 2;
759 let out_nx = nx / 2;
760 if out_ny == 0 || out_nx == 0 {
761 return Array2::zeros((1.max(out_ny), 1.max(out_nx)));
762 }
763
764 let mut out = Array2::zeros((out_ny, out_nx));
765 for i in 0..out_ny {
766 for j in 0..out_nx {
767 let ii = 2 * i;
768 let jj = 2 * j;
769 out[[i, j]] = (image[[ii, jj]]
770 + image[[ii + 1, jj]]
771 + image[[ii, jj + 1]]
772 + image[[ii + 1, jj + 1]])
773 / 4.0;
774 }
775 }
776 out
777}
778
779pub fn registration_metrics(
793 source_landmarks: Option<&Array2<f64>>,
794 target_landmarks: Option<&Array2<f64>>,
795 reference: Option<&Array2<f64>>,
796 registered: Option<&Array2<f64>>,
797) -> NdimageResult<RegistrationMetrics> {
798 let tre = match (source_landmarks, target_landmarks) {
800 (Some(src), Some(tgt)) => {
801 if src.nrows() != tgt.nrows() {
802 return Err(NdimageError::DimensionError(
803 "Landmark arrays must have the same number of rows".into(),
804 ));
805 }
806 compute_tre(src, tgt)
807 }
808 _ => 0.0,
809 };
810
811 let (ncc, mi) = match (reference, registered) {
813 (Some(ref_img), Some(reg_img)) => {
814 if ref_img.dim() != reg_img.dim() {
815 return Err(NdimageError::DimensionError(
816 "Images must have the same shape for metric computation".into(),
817 ));
818 }
819 let n = compute_ncc(ref_img, reg_img);
820 let m = compute_mutual_information(ref_img, reg_img);
821 (n, m)
822 }
823 _ => (0.0, 0.0),
824 };
825
826 Ok(RegistrationMetrics {
827 tre,
828 mutual_information: mi,
829 ncc,
830 })
831}
832
833fn compute_tre(transformed_src: &Array2<f64>, target: &Array2<f64>) -> f64 {
835 let n = transformed_src.nrows();
836 if n == 0 {
837 return 0.0;
838 }
839 let mut sum_sq = 0.0;
840 for i in 0..n {
841 let dx = transformed_src[[i, 0]] - target[[i, 0]];
842 let dy = transformed_src[[i, 1]] - target[[i, 1]];
843 sum_sq += dx * dx + dy * dy;
844 }
845 (sum_sq / n as f64).sqrt()
846}
847
848fn compute_ncc(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
850 let n = a.len() as f64;
851 if n < 1.0 {
852 return 0.0;
853 }
854 let mean_a = a.sum() / n;
855 let mean_b = b.sum() / n;
856
857 let mut num = 0.0;
858 let mut denom_a = 0.0;
859 let mut denom_b = 0.0;
860
861 for (va, vb) in a.iter().zip(b.iter()) {
862 let da = va - mean_a;
863 let db = vb - mean_b;
864 num += da * db;
865 denom_a += da * da;
866 denom_b += db * db;
867 }
868
869 let denom = (denom_a * denom_b).sqrt();
870 if denom < 1e-15 {
871 0.0
872 } else {
873 num / denom
874 }
875}
876
877fn compute_mutual_information(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
879 let n_bins = 64usize;
880
881 let (mut a_min, mut a_max) = (f64::INFINITY, f64::NEG_INFINITY);
883 let (mut b_min, mut b_max) = (f64::INFINITY, f64::NEG_INFINITY);
884 for (&va, &vb) in a.iter().zip(b.iter()) {
885 if va < a_min {
886 a_min = va;
887 }
888 if va > a_max {
889 a_max = va;
890 }
891 if vb < b_min {
892 b_min = vb;
893 }
894 if vb > b_max {
895 b_max = vb;
896 }
897 }
898
899 let a_range = a_max - a_min;
900 let b_range = b_max - b_min;
901 if a_range < 1e-15 || b_range < 1e-15 {
902 return 0.0;
903 }
904
905 let mut joint = vec![0usize; n_bins * n_bins];
907 let n_total = a.len();
908 let a_scale = (n_bins as f64 - 1e-10) / a_range;
909 let b_scale = (n_bins as f64 - 1e-10) / b_range;
910
911 for (&va, &vb) in a.iter().zip(b.iter()) {
912 let ai = ((va - a_min) * a_scale) as usize;
913 let bi = ((vb - b_min) * b_scale) as usize;
914 let ai = ai.min(n_bins - 1);
915 let bi = bi.min(n_bins - 1);
916 joint[ai * n_bins + bi] += 1;
917 }
918
919 let mut hist_a = vec![0usize; n_bins];
921 let mut hist_b = vec![0usize; n_bins];
922 for ai in 0..n_bins {
923 for bi in 0..n_bins {
924 let c = joint[ai * n_bins + bi];
925 hist_a[ai] += c;
926 hist_b[bi] += c;
927 }
928 }
929
930 let n_f = n_total as f64;
932 let mut mi = 0.0;
933 for ai in 0..n_bins {
934 for bi in 0..n_bins {
935 let pab = joint[ai * n_bins + bi] as f64 / n_f;
936 let pa = hist_a[ai] as f64 / n_f;
937 let pb = hist_b[bi] as f64 / n_f;
938 if pab > 1e-15 && pa > 1e-15 && pb > 1e-15 {
939 mi += pab * (pab / (pa * pb)).ln();
940 }
941 }
942 }
943 mi
944}
945
946pub fn apply_affine_to_points(
953 points: &Array2<f64>,
954 transform: &AffineTransform2D,
955) -> NdimageResult<Array2<f64>> {
956 if points.ncols() != 2 {
957 return Err(NdimageError::InvalidInput(
958 "Points must have 2 columns".into(),
959 ));
960 }
961 let n = points.nrows();
962 let m = &transform.matrix;
963 let mut out = Array2::<f64>::zeros((n, 2));
964 for i in 0..n {
965 let x = points[[i, 0]];
966 let y = points[[i, 1]];
967 out[[i, 0]] = m[[0, 0]] * x + m[[0, 1]] * y + m[[0, 2]];
968 out[[i, 1]] = m[[1, 0]] * x + m[[1, 1]] * y + m[[1, 2]];
969 }
970 Ok(out)
971}
972
973pub fn apply_rigid_to_points(
975 points: &Array2<f64>,
976 transform: &RigidTransform2D,
977) -> NdimageResult<Array2<f64>> {
978 if points.ncols() != 2 {
979 return Err(NdimageError::InvalidInput(
980 "Points must have 2 columns".into(),
981 ));
982 }
983 let n = points.nrows();
984 let cos_a = transform.angle.cos();
985 let sin_a = transform.angle.sin();
986 let mut out = Array2::<f64>::zeros((n, 2));
987 for i in 0..n {
988 let x = points[[i, 0]];
989 let y = points[[i, 1]];
990 out[[i, 0]] = cos_a * x - sin_a * y + transform.tx;
991 out[[i, 1]] = sin_a * x + cos_a * y + transform.ty;
992 }
993 Ok(out)
994}
995
#[cfg(test)]
mod tests {
    //! Unit tests for the registration routines, including validation and error paths.
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_phase_correlation_no_shift() {
        let img = Array2::from_shape_fn((32, 32), |(i, j)| {
            ((i as f64 * 0.3).sin() + (j as f64 * 0.5).cos()) * 10.0
        });
        let result = phase_correlation(&img, &img).expect("phase_correlation failed");
        assert!(
            result.shift_y.abs() < 1.0,
            "shift_y should be ~0, got {}",
            result.shift_y
        );
        assert!(
            result.shift_x.abs() < 1.0,
            "shift_x should be ~0, got {}",
            result.shift_x
        );
    }

    #[test]
    fn test_phase_correlation_known_shift() {
        let ny = 64;
        let nx = 64;
        let reference = Array2::from_shape_fn((ny, nx), |(i, j)| {
            ((i as f64 / 8.0).sin() * (j as f64 / 8.0).cos()) * 100.0
        });
        // Circularly move the content by +3 rows and +5 columns.
        let mut moved = Array2::zeros((ny, nx));
        for i in 0..ny {
            for j in 0..nx {
                moved[[(i + 3) % ny, (j + 5) % nx]] = reference[[i, j]];
            }
        }
        let result = phase_correlation(&reference, &moved).expect("phase_correlation failed");
        assert!(
            (result.shift_y - 3.0).abs() < 1.5,
            "shift_y ~ 3, got {}",
            result.shift_y
        );
        assert!(
            (result.shift_x - 5.0).abs() < 1.5,
            "shift_x ~ 5, got {}",
            result.shift_x
        );
    }

    #[test]
    fn test_affine_registration_identity() {
        let pts = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let result = affine_registration(&pts, &pts).expect("affine_registration failed");
        assert!((result.matrix[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((result.matrix[[1, 1]] - 1.0).abs() < 1e-10);
        assert!(result.residual < 1e-10);
    }

    #[test]
    fn test_affine_registration_translation() {
        let src = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let tgt = Array2::from_shape_vec((4, 2), vec![3.0, 2.0, 4.0, 2.0, 3.0, 3.0, 4.0, 3.0])
            .expect("shape error");
        let result = affine_registration(&src, &tgt).expect("affine_registration failed");
        assert!((result.matrix[[0, 2]] - 3.0).abs() < 1e-8, "tx ~ 3");
        assert!((result.matrix[[1, 2]] - 2.0).abs() < 1e-8, "ty ~ 2");
    }

    #[test]
    fn test_affine_registration_scale_and_shift() {
        // x' = 2x + 1, y' = 3y - 1
        let src = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let tgt =
            Array2::from_shape_vec((4, 2), vec![1.0, -1.0, 3.0, -1.0, 1.0, 2.0, 3.0, 2.0])
                .expect("shape error");
        let result = affine_registration(&src, &tgt).expect("affine_registration failed");
        let expected = [[2.0, 0.0, 1.0], [0.0, 3.0, -1.0]];
        for (r, row) in expected.iter().enumerate() {
            for (c, &value) in row.iter().enumerate() {
                assert!(
                    (result.matrix[[r, c]] - value).abs() < 1e-8,
                    "matrix[{},{}] ~ {}, got {}",
                    r,
                    c,
                    value,
                    result.matrix[[r, c]]
                );
            }
        }
        assert!(result.residual < 1e-10);
    }

    #[test]
    fn test_rigid_registration_identity() {
        let pts = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let result = rigid_registration(&pts, &pts).expect("rigid_registration failed");
        assert!(result.angle.abs() < 1e-8);
        assert!(result.tx.abs() < 1e-8);
        assert!(result.ty.abs() < 1e-8);
    }

    #[test]
    fn test_rigid_registration_translation() {
        let src = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let tgt = Array2::from_shape_vec((4, 2), vec![5.0, 3.0, 6.0, 3.0, 5.0, 4.0, 6.0, 4.0])
            .expect("shape error");
        let result = rigid_registration(&src, &tgt).expect("rigid_registration failed");
        assert!(
            result.angle.abs() < 1e-8,
            "no rotation expected, got {}",
            result.angle
        );
        assert!((result.tx - 5.0).abs() < 1e-6, "tx ~ 5, got {}", result.tx);
        assert!((result.ty - 3.0).abs() < 1e-6, "ty ~ 3, got {}", result.ty);
    }

    #[test]
    fn test_rigid_registration_rotation() {
        let angle = PI / 6.0;
        let cos_a = angle.cos();
        let sin_a = angle.sin();
        let src = Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0])
            .expect("shape error");
        let mut tgt = Array2::zeros((4, 2));
        for i in 0..4 {
            let x = src[[i, 0]];
            let y = src[[i, 1]];
            tgt[[i, 0]] = cos_a * x - sin_a * y;
            tgt[[i, 1]] = sin_a * x + cos_a * y;
        }
        let result = rigid_registration(&src, &tgt).expect("rigid_registration failed");
        assert!(
            (result.angle - angle).abs() < 1e-6,
            "angle ~ pi/6, got {}",
            result.angle
        );
    }

    #[test]
    fn test_icp_registration() {
        // 3x3 grid with spacing 10, shifted by (1.5, 2.0).
        let src = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 10.0, 0.0, 20.0, 0.0, 0.0, 10.0, 10.0, 10.0, 20.0, 10.0, 0.0, 20.0, 10.0,
                20.0, 20.0, 20.0,
            ],
        )
        .expect("shape error");
        let mut tgt = src.clone();
        let shift_x = 1.5;
        let shift_y = 2.0;
        for i in 0..tgt.nrows() {
            tgt[[i, 0]] += shift_x;
            tgt[[i, 1]] += shift_y;
        }

        let result = icp_registration(&src, &tgt, None).expect("icp failed");
        assert!(
            (result.transform.tx - shift_x).abs() < 0.5,
            "tx ~ {}, got {}",
            shift_x,
            result.transform.tx
        );
        assert!(
            (result.transform.ty - shift_y).abs() < 0.5,
            "ty ~ {}, got {}",
            shift_y,
            result.transform.ty
        );
        assert!(result.converged, "ICP should converge");
    }

    #[test]
    fn test_icp_config_default() {
        let cfg = IcpConfig::default();
        assert_eq!(cfg.max_iterations, 100);
        assert!((cfg.tolerance - 1e-8).abs() < 1e-20);
        assert!(cfg.max_distance.is_none());
    }

    #[test]
    fn test_icp_too_few_points() {
        let one = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape");
        let many = Array2::from_shape_vec((3, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0])
            .expect("shape");
        assert!(icp_registration(&one, &many, None).is_err());
        assert!(icp_registration(&many, &one, None).is_err());
    }

    #[test]
    fn test_pyramid_registration_no_shift() {
        let img = Array2::from_shape_fn((64, 64), |(i, j)| {
            ((i as f64 / 10.0).sin() + (j as f64 / 10.0).cos()) * 50.0
        });
        let result = pyramid_registration(&img, &img, None).expect("pyramid failed");
        assert!(
            result.shift_y.abs() < 2.0,
            "shift_y ~ 0, got {}",
            result.shift_y
        );
        assert!(
            result.shift_x.abs() < 2.0,
            "shift_x ~ 0, got {}",
            result.shift_x
        );
    }

    #[test]
    fn test_pyramid_rejects_invalid_config() {
        let img = Array2::<f64>::zeros((16, 16));
        let zero_levels = PyramidConfig {
            levels: 0,
            scale_factor: 2.0,
        };
        assert!(pyramid_registration(&img, &img, Some(zero_levels)).is_err());
        let bad_scale = PyramidConfig {
            levels: 2,
            scale_factor: 1.0,
        };
        assert!(pyramid_registration(&img, &img, Some(bad_scale)).is_err());
    }

    #[test]
    fn test_pyramid_dimension_mismatch() {
        let a = Array2::<f64>::zeros((16, 16));
        let b = Array2::<f64>::zeros((16, 8));
        assert!(pyramid_registration(&a, &b, None).is_err());
    }

    #[test]
    fn test_registration_metrics_perfect() {
        let pts = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape error");
        let metrics =
            registration_metrics(Some(&pts), Some(&pts), None, None).expect("metrics failed");
        assert!(
            metrics.tre < 1e-10,
            "TRE should be 0 for identical landmarks"
        );
    }

    #[test]
    fn test_registration_metrics_ncc() {
        let img = Array2::from_shape_fn((16, 16), |(i, j)| (i + j) as f64);
        let metrics =
            registration_metrics(None, None, Some(&img), Some(&img)).expect("metrics failed");
        assert!(
            (metrics.ncc - 1.0).abs() < 1e-10,
            "NCC should be 1 for identical images"
        );
    }

    #[test]
    fn test_registration_metrics_mi() {
        let img = Array2::from_shape_fn((32, 32), |(i, j)| (i * j) as f64);
        let metrics =
            registration_metrics(None, None, Some(&img), Some(&img)).expect("metrics failed");
        assert!(metrics.mutual_information > 0.0, "MI should be positive");
    }

    #[test]
    fn test_registration_metrics_shape_mismatch() {
        let a = Array2::<f64>::zeros((3, 2));
        let b = Array2::<f64>::zeros((4, 2));
        assert!(registration_metrics(Some(&a), Some(&b), None, None).is_err());
        let img_a = Array2::<f64>::zeros((8, 8));
        let img_b = Array2::<f64>::zeros((8, 9));
        assert!(registration_metrics(None, None, Some(&img_a), Some(&img_b)).is_err());
    }

    #[test]
    fn test_apply_affine_to_points() {
        let pts = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).expect("shape error");
        let mut mat = Array2::<f64>::zeros((3, 3));
        mat[[0, 0]] = 1.0;
        mat[[1, 1]] = 1.0;
        mat[[0, 2]] = 10.0;
        mat[[1, 2]] = 20.0;
        mat[[2, 2]] = 1.0;
        let tf = AffineTransform2D {
            matrix: mat,
            residual: 0.0,
        };
        let result = apply_affine_to_points(&pts, &tf).expect("apply affine failed");
        assert!((result[[0, 0]] - 11.0).abs() < 1e-10);
        assert!((result[[0, 1]] - 20.0).abs() < 1e-10);
        assert!((result[[1, 0]] - 10.0).abs() < 1e-10);
        assert!((result[[1, 1]] - 21.0).abs() < 1e-10);
    }

    #[test]
    fn test_apply_rigid_to_points() {
        let pts = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).expect("shape error");
        let tf = RigidTransform2D {
            angle: PI / 2.0,
            tx: 0.0,
            ty: 0.0,
            residual: 0.0,
        };
        let result = apply_rigid_to_points(&pts, &tf).expect("apply rigid failed");
        assert!(result[[0, 0]].abs() < 1e-10, "x ~ 0 after 90-deg rotation");
        assert!(
            (result[[0, 1]] - 1.0).abs() < 1e-10,
            "y ~ 1 after 90-deg rotation"
        );
    }

    #[test]
    fn test_apply_points_reject_wrong_columns() {
        let pts = Array2::<f64>::zeros((2, 3));
        let rigid = RigidTransform2D {
            angle: 0.0,
            tx: 0.0,
            ty: 0.0,
            residual: 0.0,
        };
        assert!(apply_rigid_to_points(&pts, &rigid).is_err());

        let mut mat = Array2::<f64>::zeros((3, 3));
        mat[[0, 0]] = 1.0;
        mat[[1, 1]] = 1.0;
        mat[[2, 2]] = 1.0;
        let affine = AffineTransform2D {
            matrix: mat,
            residual: 0.0,
        };
        assert!(apply_affine_to_points(&pts, &affine).is_err());
    }

    #[test]
    fn test_downsample_2x() {
        let img = Array2::from_shape_fn((8, 8), |(i, j)| (i * 8 + j) as f64);
        let ds = downsample_2x(&img);
        assert_eq!(ds.dim(), (4, 4));
        // (0 + 1 + 8 + 9) / 4
        assert!((ds[[0, 0]] - 4.5).abs() < 1e-10);
    }

    #[test]
    fn test_phase_correlation_dimension_mismatch() {
        let a = Array2::zeros((10, 10));
        let b = Array2::zeros((10, 12));
        assert!(phase_correlation(&a, &b).is_err());
    }

    #[test]
    fn test_affine_too_few_points() {
        let src = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).expect("shape");
        let tgt = src.clone();
        assert!(affine_registration(&src, &tgt).is_err());
    }

    #[test]
    fn test_rigid_too_few_points() {
        let src = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape");
        let tgt = src.clone();
        assert!(rigid_registration(&src, &tgt).is_err());
    }
}