Skip to main content

scirs2_ndimage/
registration.rs

1//! Image Registration Module
2//!
3//! Provides algorithms for aligning images and point sets:
4//!
5//! - **Phase correlation**: FFT-based sub-pixel translation estimation
6//! - **ICP (Iterative Closest Point)**: Point set registration
7//! - **Affine registration**: Full affine transform fitting (6 DOF in 2D)
8//! - **Rigid registration**: Rotation + translation (3 DOF in 2D)
9//! - **Multi-resolution pyramid**: Coarse-to-fine registration
10//! - **Quality metrics**: TRE, mutual information estimate
11//! - **Deformable registration**: diffeomorphic demons, fluid model, B-spline FFD
12
13/// Deformable image registration (demons, fluid, B-spline FFD).
14pub mod deformable;
15
16use scirs2_core::ndarray::{Array1, Array2, Axis};
17use scirs2_core::numeric::Complex64;
18use scirs2_fft::{fft2, fftfreq, ifft2};
19use std::f64::consts::PI;
20
21use crate::error::{NdimageError, NdimageResult};
22
23// ---------------------------------------------------------------------------
24// Data types
25// ---------------------------------------------------------------------------
26
/// Outcome of a pure-translation registration such as [`phase_correlation`].
///
/// Shifts are in pixels and may be fractional (sub-pixel refinement).
#[derive(Debug, Clone)]
pub struct TranslationResult {
    /// Shift along the row (y) axis, in pixels.
    pub shift_y: f64,
    /// Shift along the column (x) axis, in pixels.
    pub shift_x: f64,
    /// Height of the correlation peak; a confidence indicator where values near 1
    /// indicate a strong match.
    pub peak_value: f64,
}
37
38/// A 2-D affine transform represented as a 3x3 homogeneous matrix.
39///
40/// The matrix maps source coordinates to target coordinates:
41///   \[x'\]   \[a00 a01 a02\] \[x\]
42///   \[y'\] = \[a10 a11 a12\] \[y\]
43///   \[ 1\]   \[ 0   0   1 \] \[1\]
44#[derive(Debug, Clone)]
45pub struct AffineTransform2D {
46    /// 3x3 homogeneous matrix (last row is \[0,0,1\])
47    pub matrix: Array2<f64>,
48    /// Residual (mean squared error) of the fit
49    pub residual: f64,
50}
51
/// A rigid 2-D transform (rotation followed by translation).
///
/// A point maps as `x' = cos(angle)*x - sin(angle)*y + tx`,
/// `y' = sin(angle)*x + cos(angle)*y + ty`.
#[derive(Debug, Clone)]
pub struct RigidTransform2D {
    /// Counter-clockwise rotation angle, in radians.
    pub angle: f64,
    /// Translation along the x axis.
    pub tx: f64,
    /// Translation along the y axis.
    pub ty: f64,
    /// Mean squared error of the fit.
    pub residual: f64,
}
64
65/// Result of ICP registration
66#[derive(Debug, Clone)]
67pub struct IcpResult {
68    /// Final rigid transform
69    pub transform: RigidTransform2D,
70    /// Number of iterations performed
71    pub iterations: usize,
72    /// History of mean squared errors per iteration
73    pub mse_history: Vec<f64>,
74    /// Whether the algorithm converged
75    pub converged: bool,
76}
77
/// Tuning parameters for [`icp_registration`].
#[derive(Debug, Clone)]
pub struct IcpConfig {
    /// Upper bound on the number of ICP iterations.
    pub max_iterations: usize,
    /// Stop once the MSE changes by less than this between iterations.
    pub tolerance: f64,
    /// Optional rejection radius: correspondences farther apart than this are ignored.
    pub max_distance: Option<f64>,
}

impl Default for IcpConfig {
    /// 100 iterations, a tolerance of `1e-8`, and no distance-based rejection.
    fn default() -> Self {
        IcpConfig {
            max_distance: None,
            tolerance: 1e-8,
            max_iterations: 100,
        }
    }
}
98
/// Tuning parameters for [`pyramid_registration`].
#[derive(Debug, Clone)]
pub struct PyramidConfig {
    /// Total number of pyramid levels, counting the full-resolution level.
    pub levels: usize,
    /// Ratio between the sizes of consecutive levels.
    pub scale_factor: f64,
}

impl Default for PyramidConfig {
    /// Three levels, each half the size of the previous one.
    fn default() -> Self {
        PyramidConfig {
            scale_factor: 2.0,
            levels: 3,
        }
    }
}
116
/// Quality measures produced by [`registration_metrics`].
#[derive(Debug, Clone)]
pub struct RegistrationMetrics {
    /// Target Registration Error: root-mean-square distance between the transformed
    /// source landmarks and their corresponding target landmarks.
    pub tre: f64,
    /// Mutual information estimated from a discrete joint intensity histogram.
    pub mutual_information: f64,
    /// Normalized cross-correlation of the two images.
    pub ncc: f64,
}
128
129// ---------------------------------------------------------------------------
130// Phase correlation
131// ---------------------------------------------------------------------------
132
133/// Estimate translation between two images using phase correlation.
134///
135/// Computes the cross-power spectrum of the two images, then finds the peak
136/// of its inverse FFT.  The location of the peak gives the integer shift;
137/// sub-pixel refinement is performed via parabolic interpolation.
138///
139/// Both images must have the same shape.
140///
141/// # Arguments
142/// * `reference` - Reference (fixed) image
143/// * `moving`    - Moving (to-be-registered) image
144///
145/// # Returns
146/// A `TranslationResult` with the estimated shift and confidence.
147pub fn phase_correlation(
148    reference: &Array2<f64>,
149    moving: &Array2<f64>,
150) -> NdimageResult<TranslationResult> {
151    let (ny, nx) = reference.dim();
152    if moving.dim() != (ny, nx) {
153        return Err(NdimageError::DimensionError(format!(
154            "Image shapes must match: reference ({},{}) vs moving ({},{})",
155            ny,
156            nx,
157            moving.nrows(),
158            moving.ncols()
159        )));
160    }
161    if ny == 0 || nx == 0 {
162        return Err(NdimageError::InvalidInput(
163            "Images must be non-empty".into(),
164        ));
165    }
166
167    // Forward FFT of both images
168    let spec_ref = fft2(reference, None, None, None)
169        .map_err(|e| NdimageError::ComputationError(format!("FFT of reference failed: {}", e)))?;
170    let spec_mov = fft2(moving, None, None, None).map_err(|e| {
171        NdimageError::ComputationError(format!("FFT of moving image failed: {}", e))
172    })?;
173
174    // Cross-power spectrum:  R = F1* . F2 / |F1* . F2|
175    let mut cross_power = Array2::<Complex64>::zeros((ny, nx));
176    for i in 0..ny {
177        for j in 0..nx {
178            let prod = spec_ref[[i, j]].conj() * spec_mov[[i, j]];
179            let mag = prod.norm();
180            cross_power[[i, j]] = if mag > 1e-15 {
181                prod / mag
182            } else {
183                Complex64::new(0.0, 0.0)
184            };
185        }
186    }
187
188    // Inverse FFT to get the correlation surface
189    let corr_complex = ifft2(&cross_power, None, None, None).map_err(|e| {
190        NdimageError::ComputationError(format!("IFFT of cross-power failed: {}", e))
191    })?;
192
193    // Find peak in the real part
194    let mut best_val = f64::NEG_INFINITY;
195    let mut best_i = 0usize;
196    let mut best_j = 0usize;
197    for i in 0..ny {
198        for j in 0..nx {
199            let v = corr_complex[[i, j]].re;
200            if v > best_val {
201                best_val = v;
202                best_i = i;
203                best_j = j;
204            }
205        }
206    }
207
208    // Sub-pixel refinement via parabolic interpolation along each axis
209    let sub_y = subpixel_1d(
210        corr_complex[[(best_i + ny - 1) % ny, best_j]].re,
211        best_val,
212        corr_complex[[(best_i + 1) % ny, best_j]].re,
213    );
214    let sub_x = subpixel_1d(
215        corr_complex[[best_i, (best_j + nx - 1) % nx]].re,
216        best_val,
217        corr_complex[[best_i, (best_j + 1) % nx]].re,
218    );
219
220    // Convert from FFT index to shift (wrap around center)
221    let shift_y = if best_i as f64 + sub_y > ny as f64 / 2.0 {
222        best_i as f64 + sub_y - ny as f64
223    } else {
224        best_i as f64 + sub_y
225    };
226    let shift_x = if best_j as f64 + sub_x > nx as f64 / 2.0 {
227        best_j as f64 + sub_x - nx as f64
228    } else {
229        best_j as f64 + sub_x
230    };
231
232    Ok(TranslationResult {
233        shift_y,
234        shift_x,
235        peak_value: best_val,
236    })
237}
238
/// Parabolic sub-pixel refinement around a discrete peak.
///
/// Given samples `(y_minus, y_center, y_plus)` taken at offsets -1, 0 and +1, returns
/// the offset of the vertex of the parabola through them:
///
/// `delta = (y_minus - y_plus) / (2 * (y_minus - 2*y_center + y_plus))`
///
/// A positive result means the true peak lies towards `y_plus`. When `y_center` is the
/// maximum of the three samples the result lies in `[-0.5, 0.5]`. For flat, non-concave
/// or NaN input, returns 0, since no meaningful vertex exists.
fn subpixel_1d(y_minus: f64, y_center: f64, y_plus: f64) -> f64 {
    // Second difference; strictly negative for a proper maximum.
    let curvature = y_minus - 2.0 * y_center + y_plus;
    if curvature.is_nan() || curvature >= -1e-15 {
        return 0.0;
    }
    // The clamp only matters for misuse (centre not the maximum); it also documents the range.
    ((y_minus - y_plus) / (2.0 * curvature)).clamp(-0.5, 0.5)
}
249
250// ---------------------------------------------------------------------------
251// Affine registration (least-squares)
252// ---------------------------------------------------------------------------
253
254/// Compute a 2-D affine transform that maps `source` points to `target` points
255/// in the least-squares sense.
256///
257/// Each row of `source` / `target` is a point `[x, y]`.
258/// At least 3 non-collinear point pairs are required.
259///
260/// The affine transform is:
261///   x' = a00*x + a01*y + a02
262///   y' = a10*x + a11*y + a12
263pub fn affine_registration(
264    source: &Array2<f64>,
265    target: &Array2<f64>,
266) -> NdimageResult<AffineTransform2D> {
267    let n = source.nrows();
268    if n < 3 {
269        return Err(NdimageError::InvalidInput(
270            "Need at least 3 point pairs for affine registration".into(),
271        ));
272    }
273    if source.ncols() != 2 || target.ncols() != 2 {
274        return Err(NdimageError::InvalidInput(
275            "Point arrays must have 2 columns (x, y)".into(),
276        ));
277    }
278    if target.nrows() != n {
279        return Err(NdimageError::DimensionError(
280            "source and target must have the same number of rows".into(),
281        ));
282    }
283
284    // Build the design matrix A (n*2  x  6) and observation vector b (n*2)
285    // For each point pair (sx, sy) -> (tx, ty):
286    //   tx = a00*sx + a01*sy + a02
287    //   ty = a10*sx + a11*sy + a12
288    //
289    // We solve  A * p = b  with p = [a00 a01 a02 a10 a11 a12]^T
290    let m = 2 * n;
291    let mut a_mat = Array2::<f64>::zeros((m, 6));
292    let mut b_vec = Array1::<f64>::zeros(m);
293
294    for k in 0..n {
295        let sx = source[[k, 0]];
296        let sy = source[[k, 1]];
297        // row for x'
298        let r0 = 2 * k;
299        a_mat[[r0, 0]] = sx;
300        a_mat[[r0, 1]] = sy;
301        a_mat[[r0, 2]] = 1.0;
302        b_vec[r0] = target[[k, 0]];
303        // row for y'
304        let r1 = 2 * k + 1;
305        a_mat[[r1, 3]] = sx;
306        a_mat[[r1, 4]] = sy;
307        a_mat[[r1, 5]] = 1.0;
308        b_vec[r1] = target[[k, 1]];
309    }
310
311    // Solve via normal equations:  A^T A p = A^T b
312    let ata = a_mat.t().dot(&a_mat);
313    let atb = a_mat.t().dot(&b_vec);
314
315    let params = solve_6x6(&ata, &atb)?;
316
317    // Build homogeneous 3x3 matrix
318    let mut matrix = Array2::<f64>::zeros((3, 3));
319    matrix[[0, 0]] = params[0];
320    matrix[[0, 1]] = params[1];
321    matrix[[0, 2]] = params[2];
322    matrix[[1, 0]] = params[3];
323    matrix[[1, 1]] = params[4];
324    matrix[[1, 2]] = params[5];
325    matrix[[2, 2]] = 1.0;
326
327    // Compute residual
328    let predicted = a_mat.dot(&params);
329    let diff = &predicted - &b_vec;
330    let residual = diff.dot(&diff) / n as f64;
331
332    Ok(AffineTransform2D { matrix, residual })
333}
334
335/// Solve a 6x6 symmetric positive-definite system via Cholesky decomposition.
336fn solve_6x6(ata: &Array2<f64>, atb: &Array1<f64>) -> NdimageResult<Array1<f64>> {
337    let n = 6;
338    // Cholesky L such that ata = L * L^T
339    let mut l_mat = Array2::<f64>::zeros((n, n));
340    for i in 0..n {
341        for j in 0..=i {
342            let mut s = 0.0;
343            for k in 0..j {
344                s += l_mat[[i, k]] * l_mat[[j, k]];
345            }
346            if i == j {
347                let diag = ata[[i, i]] - s;
348                if diag <= 0.0 {
349                    return Err(NdimageError::ComputationError(
350                        "Matrix is not positive-definite (collinear points?)".into(),
351                    ));
352                }
353                l_mat[[i, j]] = diag.sqrt();
354            } else {
355                l_mat[[i, j]] = (ata[[i, j]] - s) / l_mat[[j, j]];
356            }
357        }
358    }
359
360    // Forward substitution: L y = atb
361    let mut y = Array1::<f64>::zeros(n);
362    for i in 0..n {
363        let mut s = 0.0;
364        for k in 0..i {
365            s += l_mat[[i, k]] * y[k];
366        }
367        y[i] = (atb[i] - s) / l_mat[[i, i]];
368    }
369
370    // Back substitution: L^T x = y
371    let mut x = Array1::<f64>::zeros(n);
372    for i in (0..n).rev() {
373        let mut s = 0.0;
374        for k in (i + 1)..n {
375            s += l_mat[[k, i]] * x[k];
376        }
377        x[i] = (y[i] - s) / l_mat[[i, i]];
378    }
379
380    Ok(x)
381}
382
383// ---------------------------------------------------------------------------
384// Rigid registration (SVD-based, Umeyama / Procrustes)
385// ---------------------------------------------------------------------------
386
387/// Compute the rigid (rotation + translation) transform that best maps `source`
388/// to `target` in the least-squares sense.
389///
390/// Uses the SVD-based method (Umeyama 1991).
391/// Each row is a 2-D point `[x, y]`.  At least 2 non-coincident point pairs
392/// are needed.
393pub fn rigid_registration(
394    source: &Array2<f64>,
395    target: &Array2<f64>,
396) -> NdimageResult<RigidTransform2D> {
397    let n = source.nrows();
398    if n < 2 {
399        return Err(NdimageError::InvalidInput(
400            "Need at least 2 point pairs for rigid registration".into(),
401        ));
402    }
403    if source.ncols() != 2 || target.ncols() != 2 {
404        return Err(NdimageError::InvalidInput(
405            "Point arrays must have 2 columns (x, y)".into(),
406        ));
407    }
408    if target.nrows() != n {
409        return Err(NdimageError::DimensionError(
410            "source and target must have the same number of rows".into(),
411        ));
412    }
413
414    // Centroids
415    let src_mean = source.mean_axis(Axis(0)).ok_or_else(|| {
416        NdimageError::ComputationError("Failed to compute source centroid".into())
417    })?;
418    let tgt_mean = target.mean_axis(Axis(0)).ok_or_else(|| {
419        NdimageError::ComputationError("Failed to compute target centroid".into())
420    })?;
421
422    // Center the points
423    let src_centered = source - &src_mean.view().insert_axis(Axis(0));
424    let tgt_centered = target - &tgt_mean.view().insert_axis(Axis(0));
425
426    // Cross-covariance matrix H = src_centered^T * tgt_centered  (2x2)
427    let h = src_centered.t().dot(&tgt_centered);
428
429    // SVD of H via closed-form for 2x2
430    let (u, _s, vt) = svd_2x2(h[[0, 0]], h[[0, 1]], h[[1, 0]], h[[1, 1]]);
431
432    // Rotation matrix R = V * U^T
433    // Ensure proper rotation (det > 0)
434    let det = (u[[0, 0]] * u[[1, 1]] - u[[0, 1]] * u[[1, 0]])
435        * (vt[[0, 0]] * vt[[1, 1]] - vt[[0, 1]] * vt[[1, 0]]);
436    let sign = if det < 0.0 { -1.0 } else { 1.0 };
437
438    let mut d_mat = Array2::<f64>::zeros((2, 2));
439    d_mat[[0, 0]] = 1.0;
440    d_mat[[1, 1]] = sign;
441
442    let rot = vt.t().dot(&d_mat).dot(&u.t());
443    let angle = rot[[1, 0]].atan2(rot[[0, 0]]);
444
445    // Translation  t = tgt_mean - R * src_mean
446    let rotated_mean = rot.dot(&src_mean);
447    let tx = tgt_mean[0] - rotated_mean[0];
448    let ty = tgt_mean[1] - rotated_mean[1];
449
450    // Residual
451    let transformed = src_centered.dot(&rot.t());
452    let diff = &transformed - &tgt_centered;
453    let mse = diff.mapv(|v| v * v).sum() / n as f64;
454
455    Ok(RigidTransform2D {
456        angle,
457        tx,
458        ty,
459        residual: mse,
460    })
461}
462
463/// Closed-form 2x2 SVD.
464/// Returns (U, [s1, s2], V^T) such that A = U diag(s) V^T.
465fn svd_2x2(a: f64, b: f64, c: f64, d: f64) -> (Array2<f64>, [f64; 2], Array2<f64>) {
466    // Using the analytical formula for 2x2 SVD
467    let s1_sq = (a * a + b * b + c * c + d * d) / 2.0;
468    let det = a * d - b * c;
469    let tmp =
470        ((a * a + b * b - c * c - d * d).powi(2) + 4.0 * (a * c + b * d).powi(2)).sqrt() / 2.0;
471
472    let sigma1 = (s1_sq + tmp).sqrt();
473    let sigma2 = (s1_sq - tmp).max(0.0).sqrt();
474
475    // A^T A eigenvalues are sigma^2
476    let ata_00 = a * a + c * c;
477    let ata_01 = a * b + c * d;
478    let ata_11 = b * b + d * d;
479
480    // Eigenvectors of A^T A -> columns of V
481    let theta_v = if ata_01.abs() < 1e-15 {
482        0.0
483    } else {
484        0.5 * (2.0 * ata_01).atan2(ata_00 - ata_11)
485    };
486
487    let mut vt = Array2::<f64>::zeros((2, 2));
488    vt[[0, 0]] = theta_v.cos();
489    vt[[0, 1]] = theta_v.sin();
490    vt[[1, 0]] = -theta_v.sin();
491    vt[[1, 1]] = theta_v.cos();
492
493    // U columns from A V / sigma
494    let mut u = Array2::<f64>::zeros((2, 2));
495    if sigma1 > 1e-15 {
496        u[[0, 0]] = (a * vt[[0, 0]] + b * vt[[0, 1]]) / sigma1;
497        u[[1, 0]] = (c * vt[[0, 0]] + d * vt[[0, 1]]) / sigma1;
498    } else {
499        u[[0, 0]] = 1.0;
500    }
501    if sigma2 > 1e-15 {
502        u[[0, 1]] = (a * vt[[1, 0]] + b * vt[[1, 1]]) / sigma2;
503        u[[1, 1]] = (c * vt[[1, 0]] + d * vt[[1, 1]]) / sigma2;
504    } else {
505        // Choose orthogonal column
506        u[[0, 1]] = -u[[1, 0]];
507        u[[1, 1]] = u[[0, 0]];
508    }
509
510    (u, [sigma1, sigma2], vt)
511}
512
513// ---------------------------------------------------------------------------
514// Iterative Closest Point (ICP)
515// ---------------------------------------------------------------------------
516
517/// Register `source` point set to `target` point set using ICP.
518///
519/// Both arrays have shape (N, 2) where each row is `[x, y]`.
520/// The algorithm iteratively:
521///   1. Finds closest target point for each source point
522///   2. Computes the best rigid transform
523///   3. Applies the transform
524///   4. Checks convergence
525pub fn icp_registration(
526    source: &Array2<f64>,
527    target: &Array2<f64>,
528    config: Option<IcpConfig>,
529) -> NdimageResult<IcpResult> {
530    let cfg = config.unwrap_or_default();
531
532    if source.ncols() != 2 || target.ncols() != 2 {
533        return Err(NdimageError::InvalidInput(
534            "Point arrays must have 2 columns".into(),
535        ));
536    }
537    if source.nrows() < 2 || target.nrows() < 2 {
538        return Err(NdimageError::InvalidInput(
539            "Need at least 2 points in each set".into(),
540        ));
541    }
542
543    let n_src = source.nrows();
544    let mut current = source.to_owned();
545    let mut cum_angle: f64 = 0.0;
546    let mut cum_tx: f64 = 0.0;
547    let mut cum_ty: f64 = 0.0;
548    let mut mse_history = Vec::new();
549    let mut converged = false;
550
551    for iter in 0..cfg.max_iterations {
552        // 1. Find correspondences (nearest target for each source)
553        let (correspondences, mse) = find_correspondences(&current, target, cfg.max_distance)?;
554
555        mse_history.push(mse);
556
557        // Check convergence
558        if iter > 0 {
559            let prev = mse_history[iter - 1];
560            if (prev - mse).abs() < cfg.tolerance {
561                converged = true;
562                break;
563            }
564        }
565
566        if correspondences.is_empty() {
567            return Err(NdimageError::ComputationError(
568                "No valid correspondences found".into(),
569            ));
570        }
571
572        // 2. Build matched point sets
573        let n_match = correspondences.len();
574        let mut src_matched = Array2::<f64>::zeros((n_match, 2));
575        let mut tgt_matched = Array2::<f64>::zeros((n_match, 2));
576        for (k, &(si, ti)) in correspondences.iter().enumerate() {
577            src_matched[[k, 0]] = current[[si, 0]];
578            src_matched[[k, 1]] = current[[si, 1]];
579            tgt_matched[[k, 0]] = target[[ti, 0]];
580            tgt_matched[[k, 1]] = target[[ti, 1]];
581        }
582
583        // 3. Compute best rigid transform
584        let rigid = rigid_registration(&src_matched, &tgt_matched)?;
585
586        // 4. Apply transform to all source points
587        let cos_a = rigid.angle.cos();
588        let sin_a = rigid.angle.sin();
589        for k in 0..n_src {
590            let x = current[[k, 0]];
591            let y = current[[k, 1]];
592            current[[k, 0]] = cos_a * x - sin_a * y + rigid.tx;
593            current[[k, 1]] = sin_a * x + cos_a * y + rigid.ty;
594        }
595
596        // Accumulate transform
597        let old_tx = cum_tx;
598        let old_ty = cum_ty;
599        let old_cos = cum_angle.cos();
600        let old_sin = cum_angle.sin();
601        cum_tx = cos_a * old_tx - sin_a * old_ty + rigid.tx;
602        cum_ty = sin_a * old_tx + cos_a * old_ty + rigid.ty;
603        cum_angle += rigid.angle;
604    }
605
606    let final_iters = mse_history.len();
607
608    Ok(IcpResult {
609        transform: RigidTransform2D {
610            angle: cum_angle,
611            tx: cum_tx,
612            ty: cum_ty,
613            residual: mse_history.last().copied().unwrap_or(f64::INFINITY),
614        },
615        iterations: final_iters,
616        mse_history,
617        converged,
618    })
619}
620
621/// Find nearest-neighbor correspondences from `source` to `target`.
622/// Returns pairs of (source_idx, target_idx) and the mean squared distance.
623fn find_correspondences(
624    source: &Array2<f64>,
625    target: &Array2<f64>,
626    max_dist: Option<f64>,
627) -> NdimageResult<(Vec<(usize, usize)>, f64)> {
628    let n_src = source.nrows();
629    let n_tgt = target.nrows();
630    let max_dist_sq = max_dist.map(|d| d * d);
631
632    let mut pairs = Vec::with_capacity(n_src);
633    let mut total_dist_sq = 0.0;
634
635    for si in 0..n_src {
636        let sx = source[[si, 0]];
637        let sy = source[[si, 1]];
638
639        let mut best_dist_sq = f64::INFINITY;
640        let mut best_ti = 0usize;
641
642        for ti in 0..n_tgt {
643            let dx = sx - target[[ti, 0]];
644            let dy = sy - target[[ti, 1]];
645            let d2 = dx * dx + dy * dy;
646            if d2 < best_dist_sq {
647                best_dist_sq = d2;
648                best_ti = ti;
649            }
650        }
651
652        let accept = match max_dist_sq {
653            Some(md2) => best_dist_sq <= md2,
654            None => true,
655        };
656
657        if accept {
658            pairs.push((si, best_ti));
659            total_dist_sq += best_dist_sq;
660        }
661    }
662
663    let mse = if pairs.is_empty() {
664        f64::INFINITY
665    } else {
666        total_dist_sq / pairs.len() as f64
667    };
668
669    Ok((pairs, mse))
670}
671
672// ---------------------------------------------------------------------------
673// Multi-resolution pyramid registration
674// ---------------------------------------------------------------------------
675
676/// Perform multi-resolution pyramid registration using phase correlation at
677/// each level, refining from coarse to fine.
678///
679/// At the coarsest level the shift is estimated on heavily down-sampled images;
680/// that estimate is propagated to the next finer level as an initial guess.
681///
682/// Returns the final sub-pixel translation estimate.
683pub fn pyramid_registration(
684    reference: &Array2<f64>,
685    moving: &Array2<f64>,
686    config: Option<PyramidConfig>,
687) -> NdimageResult<TranslationResult> {
688    let cfg = config.unwrap_or_default();
689    let (ny, nx) = reference.dim();
690    if moving.dim() != (ny, nx) {
691        return Err(NdimageError::DimensionError(
692            "Images must have the same shape for pyramid registration".into(),
693        ));
694    }
695    if cfg.levels == 0 {
696        return Err(NdimageError::InvalidInput(
697            "Number of pyramid levels must be >= 1".into(),
698        ));
699    }
700    if cfg.scale_factor <= 1.0 {
701        return Err(NdimageError::InvalidInput(
702            "Scale factor must be > 1.0".into(),
703        ));
704    }
705
706    // Build pyramid by successive down-sampling
707    let mut ref_pyramid = vec![reference.clone()];
708    let mut mov_pyramid = vec![moving.clone()];
709    for _ in 1..cfg.levels {
710        let ref_prev = ref_pyramid
711            .last()
712            .ok_or_else(|| NdimageError::ComputationError("Empty pyramid".into()))?;
713        let mov_prev = mov_pyramid
714            .last()
715            .ok_or_else(|| NdimageError::ComputationError("Empty pyramid".into()))?;
716        ref_pyramid.push(downsample_2x(ref_prev));
717        mov_pyramid.push(downsample_2x(mov_prev));
718    }
719
720    // Register coarse-to-fine (last element = coarsest)
721    let mut cum_shift_y = 0.0;
722    let mut cum_shift_x = 0.0;
723    let mut best_peak = 0.0;
724
725    for level in (0..cfg.levels).rev() {
726        let ref_level = &ref_pyramid[level];
727        let mov_level = &mov_pyramid[level];
728
729        // If the image is too small, skip
730        if ref_level.nrows() < 4 || ref_level.ncols() < 4 {
731            continue;
732        }
733
734        let result = phase_correlation(ref_level, mov_level)?;
735
736        if level == cfg.levels - 1 {
737            // Coarsest level: use directly
738            cum_shift_y = result.shift_y;
739            cum_shift_x = result.shift_x;
740        } else {
741            // Refine: the coarser estimate is scaled up by 2
742            cum_shift_y = cum_shift_y * 2.0 + result.shift_y;
743            cum_shift_x = cum_shift_x * 2.0 + result.shift_x;
744        }
745        best_peak = result.peak_value;
746    }
747
748    Ok(TranslationResult {
749        shift_y: cum_shift_y,
750        shift_x: cum_shift_x,
751        peak_value: best_peak,
752    })
753}
754
755/// Simple 2x down-sampling by averaging 2x2 blocks.
756fn downsample_2x(image: &Array2<f64>) -> Array2<f64> {
757    let (ny, nx) = image.dim();
758    let out_ny = ny / 2;
759    let out_nx = nx / 2;
760    if out_ny == 0 || out_nx == 0 {
761        return Array2::zeros((1.max(out_ny), 1.max(out_nx)));
762    }
763
764    let mut out = Array2::zeros((out_ny, out_nx));
765    for i in 0..out_ny {
766        for j in 0..out_nx {
767            let ii = 2 * i;
768            let jj = 2 * j;
769            out[[i, j]] = (image[[ii, jj]]
770                + image[[ii + 1, jj]]
771                + image[[ii, jj + 1]]
772                + image[[ii + 1, jj + 1]])
773                / 4.0;
774        }
775    }
776    out
777}
778
779// ---------------------------------------------------------------------------
780// Registration quality metrics
781// ---------------------------------------------------------------------------
782
783/// Compute registration quality metrics.
784///
785/// * `source_landmarks` / `target_landmarks` are Nx2 arrays of corresponding
786///   landmark points *before* and *after* registration of the source image.
787/// * `reference` / `registered` are the reference and the source-after-
788///   registration images (used for NCC and MI).
789///
790/// If landmark arrays are empty, TRE is returned as 0.
791/// If image arrays are empty, NCC and MI are returned as 0.
792pub fn registration_metrics(
793    source_landmarks: Option<&Array2<f64>>,
794    target_landmarks: Option<&Array2<f64>>,
795    reference: Option<&Array2<f64>>,
796    registered: Option<&Array2<f64>>,
797) -> NdimageResult<RegistrationMetrics> {
798    // TRE
799    let tre = match (source_landmarks, target_landmarks) {
800        (Some(src), Some(tgt)) => {
801            if src.nrows() != tgt.nrows() {
802                return Err(NdimageError::DimensionError(
803                    "Landmark arrays must have the same number of rows".into(),
804                ));
805            }
806            compute_tre(src, tgt)
807        }
808        _ => 0.0,
809    };
810
811    // NCC and MI
812    let (ncc, mi) = match (reference, registered) {
813        (Some(ref_img), Some(reg_img)) => {
814            if ref_img.dim() != reg_img.dim() {
815                return Err(NdimageError::DimensionError(
816                    "Images must have the same shape for metric computation".into(),
817                ));
818            }
819            let n = compute_ncc(ref_img, reg_img);
820            let m = compute_mutual_information(ref_img, reg_img);
821            (n, m)
822        }
823        _ => (0.0, 0.0),
824    };
825
826    Ok(RegistrationMetrics {
827        tre,
828        mutual_information: mi,
829        ncc,
830    })
831}
832
833/// Target Registration Error: RMS distance between corresponding landmarks.
834fn compute_tre(transformed_src: &Array2<f64>, target: &Array2<f64>) -> f64 {
835    let n = transformed_src.nrows();
836    if n == 0 {
837        return 0.0;
838    }
839    let mut sum_sq = 0.0;
840    for i in 0..n {
841        let dx = transformed_src[[i, 0]] - target[[i, 0]];
842        let dy = transformed_src[[i, 1]] - target[[i, 1]];
843        sum_sq += dx * dx + dy * dy;
844    }
845    (sum_sq / n as f64).sqrt()
846}
847
848/// Normalized Cross-Correlation between two images.
849fn compute_ncc(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
850    let n = a.len() as f64;
851    if n < 1.0 {
852        return 0.0;
853    }
854    let mean_a = a.sum() / n;
855    let mean_b = b.sum() / n;
856
857    let mut num = 0.0;
858    let mut denom_a = 0.0;
859    let mut denom_b = 0.0;
860
861    for (va, vb) in a.iter().zip(b.iter()) {
862        let da = va - mean_a;
863        let db = vb - mean_b;
864        num += da * db;
865        denom_a += da * da;
866        denom_b += db * db;
867    }
868
869    let denom = (denom_a * denom_b).sqrt();
870    if denom < 1e-15 {
871        0.0
872    } else {
873        num / denom
874    }
875}
876
877/// Estimate mutual information using a joint histogram with 64 bins.
878fn compute_mutual_information(a: &Array2<f64>, b: &Array2<f64>) -> f64 {
879    let n_bins = 64usize;
880
881    // Find intensity ranges
882    let (mut a_min, mut a_max) = (f64::INFINITY, f64::NEG_INFINITY);
883    let (mut b_min, mut b_max) = (f64::INFINITY, f64::NEG_INFINITY);
884    for (&va, &vb) in a.iter().zip(b.iter()) {
885        if va < a_min {
886            a_min = va;
887        }
888        if va > a_max {
889            a_max = va;
890        }
891        if vb < b_min {
892            b_min = vb;
893        }
894        if vb > b_max {
895            b_max = vb;
896        }
897    }
898
899    let a_range = a_max - a_min;
900    let b_range = b_max - b_min;
901    if a_range < 1e-15 || b_range < 1e-15 {
902        return 0.0;
903    }
904
905    // Build joint histogram
906    let mut joint = vec![0usize; n_bins * n_bins];
907    let n_total = a.len();
908    let a_scale = (n_bins as f64 - 1e-10) / a_range;
909    let b_scale = (n_bins as f64 - 1e-10) / b_range;
910
911    for (&va, &vb) in a.iter().zip(b.iter()) {
912        let ai = ((va - a_min) * a_scale) as usize;
913        let bi = ((vb - b_min) * b_scale) as usize;
914        let ai = ai.min(n_bins - 1);
915        let bi = bi.min(n_bins - 1);
916        joint[ai * n_bins + bi] += 1;
917    }
918
919    // Marginal histograms
920    let mut hist_a = vec![0usize; n_bins];
921    let mut hist_b = vec![0usize; n_bins];
922    for ai in 0..n_bins {
923        for bi in 0..n_bins {
924            let c = joint[ai * n_bins + bi];
925            hist_a[ai] += c;
926            hist_b[bi] += c;
927        }
928    }
929
930    // MI = sum p(a,b) * log(p(a,b) / (p(a)*p(b)))
931    let n_f = n_total as f64;
932    let mut mi = 0.0;
933    for ai in 0..n_bins {
934        for bi in 0..n_bins {
935            let pab = joint[ai * n_bins + bi] as f64 / n_f;
936            let pa = hist_a[ai] as f64 / n_f;
937            let pb = hist_b[bi] as f64 / n_f;
938            if pab > 1e-15 && pa > 1e-15 && pb > 1e-15 {
939                mi += pab * (pab / (pa * pb)).ln();
940            }
941        }
942    }
943    mi
944}
945
946// ---------------------------------------------------------------------------
947// Apply transform helpers
948// ---------------------------------------------------------------------------
949
950/// Apply an affine transform to a set of 2-D points.
951/// Each row of `points` is [x, y].
952pub fn apply_affine_to_points(
953    points: &Array2<f64>,
954    transform: &AffineTransform2D,
955) -> NdimageResult<Array2<f64>> {
956    if points.ncols() != 2 {
957        return Err(NdimageError::InvalidInput(
958            "Points must have 2 columns".into(),
959        ));
960    }
961    let n = points.nrows();
962    let m = &transform.matrix;
963    let mut out = Array2::<f64>::zeros((n, 2));
964    for i in 0..n {
965        let x = points[[i, 0]];
966        let y = points[[i, 1]];
967        out[[i, 0]] = m[[0, 0]] * x + m[[0, 1]] * y + m[[0, 2]];
968        out[[i, 1]] = m[[1, 0]] * x + m[[1, 1]] * y + m[[1, 2]];
969    }
970    Ok(out)
971}
972
973/// Apply a rigid transform to a set of 2-D points.
974pub fn apply_rigid_to_points(
975    points: &Array2<f64>,
976    transform: &RigidTransform2D,
977) -> NdimageResult<Array2<f64>> {
978    if points.ncols() != 2 {
979        return Err(NdimageError::InvalidInput(
980            "Points must have 2 columns".into(),
981        ));
982    }
983    let n = points.nrows();
984    let cos_a = transform.angle.cos();
985    let sin_a = transform.angle.sin();
986    let mut out = Array2::<f64>::zeros((n, 2));
987    for i in 0..n {
988        let x = points[[i, 0]];
989        let y = points[[i, 1]];
990        out[[i, 0]] = cos_a * x - sin_a * y + transform.tx;
991        out[[i, 1]] = sin_a * x + cos_a * y + transform.ty;
992    }
993    Ok(out)
994}
995
996// ---------------------------------------------------------------------------
997// Tests
998// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    // Unit tests for the registration API. Most cases use small synthetic
    // inputs with a known answer (identity, pure translation, pure rotation)
    // or check that invalid inputs are rejected with an error.
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Registering an image against itself should yield a shift of
    /// (approximately) zero on both axes.
    #[test]
    fn test_phase_correlation_no_shift() {
        let img = Array2::from_shape_fn((32, 32), |(i, j)| {
            ((i as f64 * 0.3).sin() + (j as f64 * 0.5).cos()) * 10.0
        });
        let result = phase_correlation(&img, &img).expect("phase_correlation failed");
        assert!(
            result.shift_y.abs() < 1.0,
            "shift_y should be ~0, got {}",
            result.shift_y
        );
        assert!(
            result.shift_x.abs() < 1.0,
            "shift_x should be ~0, got {}",
            result.shift_x
        );
    }

    /// A circular shift of +3 rows / +5 columns should be recovered to within
    /// 1.5 pixels. This also pins the sign convention: `moved` displaced by a
    /// positive amount from `reference` must give positive shifts.
    #[test]
    fn test_phase_correlation_known_shift() {
        // Create reference and shifted version
        let ny = 64;
        let nx = 64;
        let reference = Array2::from_shape_fn((ny, nx), |(i, j)| {
            ((i as f64 / 8.0).sin() * (j as f64 / 8.0).cos()) * 100.0
        });
        // Shift by (3, 5) via circular shift
        let mut moved = Array2::zeros((ny, nx));
        for i in 0..ny {
            for j in 0..nx {
                moved[[(i + 3) % ny, (j + 5) % nx]] = reference[[i, j]];
            }
        }
        let result = phase_correlation(&reference, &moved).expect("phase_correlation failed");
        assert!(
            (result.shift_y - 3.0).abs() < 1.5,
            "shift_y ~ 3, got {}",
            result.shift_y
        );
        assert!(
            (result.shift_x - 5.0).abs() < 1.5,
            "shift_x ~ 5, got {}",
            result.shift_x
        );
    }

    /// Fitting the unit square onto itself should give unit diagonal entries
    /// and a (near-)zero residual.
    #[test]
    fn test_affine_registration_identity() {
        let pts = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let result = affine_registration(&pts, &pts).expect("affine_registration failed");
        // Should be close to identity
        assert!((result.matrix[[0, 0]] - 1.0).abs() < 1e-10);
        assert!((result.matrix[[1, 1]] - 1.0).abs() < 1e-10);
        assert!(result.residual < 1e-10);
    }

    /// A pure translation of the unit square by (3, 2) should appear in the
    /// translation column of the fitted matrix (entries [0, 2] and [1, 2]).
    #[test]
    fn test_affine_registration_translation() {
        let src = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let tgt = Array2::from_shape_vec((4, 2), vec![3.0, 2.0, 4.0, 2.0, 3.0, 3.0, 4.0, 3.0])
            .expect("shape error");
        let result = affine_registration(&src, &tgt).expect("affine_registration failed");
        assert!((result.matrix[[0, 2]] - 3.0).abs() < 1e-8, "tx ~ 3");
        assert!((result.matrix[[1, 2]] - 2.0).abs() < 1e-8, "ty ~ 2");
    }

    /// Rigidly registering a point set onto itself should give zero rotation
    /// and zero translation.
    #[test]
    fn test_rigid_registration_identity() {
        let pts = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let result = rigid_registration(&pts, &pts).expect("rigid_registration failed");
        assert!(result.angle.abs() < 1e-8);
        assert!(result.tx.abs() < 1e-8);
        assert!(result.ty.abs() < 1e-8);
    }

    /// A pure translation by (5, 3) should be recovered with zero rotation.
    #[test]
    fn test_rigid_registration_translation() {
        let src = Array2::from_shape_vec((4, 2), vec![0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0])
            .expect("shape error");
        let tgt = Array2::from_shape_vec((4, 2), vec![5.0, 3.0, 6.0, 3.0, 5.0, 4.0, 6.0, 4.0])
            .expect("shape error");
        let result = rigid_registration(&src, &tgt).expect("rigid_registration failed");
        assert!(
            result.angle.abs() < 1e-8,
            "no rotation expected, got {}",
            result.angle
        );
        assert!((result.tx - 5.0).abs() < 1e-6, "tx ~ 5, got {}", result.tx);
        assert!((result.ty - 3.0).abs() < 1e-6, "ty ~ 3, got {}", result.ty);
    }

    /// A 30-degree rotation about the origin (using the same rotation matrix
    /// as `apply_rigid_to_points`) should be recovered as `angle ~ pi/6`.
    #[test]
    fn test_rigid_registration_rotation() {
        let angle = PI / 6.0; // 30 degrees
        let cos_a = angle.cos();
        let sin_a = angle.sin();
        let src = Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0])
            .expect("shape error");
        // Rotate source by 30 degrees around origin
        let mut tgt = Array2::zeros((4, 2));
        for i in 0..4 {
            let x = src[[i, 0]];
            let y = src[[i, 1]];
            tgt[[i, 0]] = cos_a * x - sin_a * y;
            tgt[[i, 1]] = sin_a * x + cos_a * y;
        }
        let result = rigid_registration(&src, &tgt).expect("rigid_registration failed");
        assert!(
            (result.angle - angle).abs() < 1e-6,
            "angle ~ pi/6, got {}",
            result.angle
        );
    }

    /// ICP on a 3x3 grid shifted by (1.5, 2.0) should recover the translation
    /// to within 0.5 and report convergence.
    #[test]
    fn test_icp_registration() {
        // Use well-spaced points with a SMALL shift relative to inter-point distance
        // so that nearest-neighbor correspondences are correct from the start.
        let src = Array2::from_shape_vec(
            (9, 2),
            vec![
                0.0, 0.0, 10.0, 0.0, 20.0, 0.0, 0.0, 10.0, 10.0, 10.0, 20.0, 10.0, 0.0, 20.0, 10.0,
                20.0, 20.0, 20.0,
            ],
        )
        .expect("shape error");
        let mut tgt = src.clone();
        // Small translation (well below half the inter-point distance of 10)
        let shift_x = 1.5;
        let shift_y = 2.0;
        for i in 0..tgt.nrows() {
            tgt[[i, 0]] += shift_x;
            tgt[[i, 1]] += shift_y;
        }

        let result = icp_registration(&src, &tgt, None).expect("icp failed");
        assert!(
            (result.transform.tx - shift_x).abs() < 0.5,
            "tx ~ {}, got {}",
            shift_x,
            result.transform.tx
        );
        assert!(
            (result.transform.ty - shift_y).abs() < 0.5,
            "ty ~ {}, got {}",
            shift_y,
            result.transform.ty
        );
        assert!(result.converged, "ICP should converge");
    }

    /// Pyramid registration of an image against itself (default options)
    /// should yield a shift within 2 pixels of zero.
    #[test]
    fn test_pyramid_registration_no_shift() {
        let img = Array2::from_shape_fn((64, 64), |(i, j)| {
            ((i as f64 / 10.0).sin() + (j as f64 / 10.0).cos()) * 50.0
        });
        let result = pyramid_registration(&img, &img, None).expect("pyramid failed");
        assert!(
            result.shift_y.abs() < 2.0,
            "shift_y ~ 0, got {}",
            result.shift_y
        );
        assert!(
            result.shift_x.abs() < 2.0,
            "shift_x ~ 0, got {}",
            result.shift_x
        );
    }

    /// Identical landmark sets must give a target registration error of ~0.
    #[test]
    fn test_registration_metrics_perfect() {
        let pts = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("shape error");
        let metrics =
            registration_metrics(Some(&pts), Some(&pts), None, None).expect("metrics failed");
        assert!(
            metrics.tre < 1e-10,
            "TRE should be 0 for identical landmarks"
        );
    }

    /// Identical non-constant images must have a normalized cross-correlation of 1.
    #[test]
    fn test_registration_metrics_ncc() {
        let img = Array2::from_shape_fn((16, 16), |(i, j)| (i + j) as f64);
        let metrics =
            registration_metrics(None, None, Some(&img), Some(&img)).expect("metrics failed");
        assert!(
            (metrics.ncc - 1.0).abs() < 1e-10,
            "NCC should be 1 for identical images"
        );
    }

    /// Identical non-constant images must have strictly positive mutual information.
    #[test]
    fn test_registration_metrics_mi() {
        let img = Array2::from_shape_fn((32, 32), |(i, j)| (i * j) as f64);
        let metrics =
            registration_metrics(None, None, Some(&img), Some(&img)).expect("metrics failed");
        // MI should be positive for identical images
        assert!(metrics.mutual_information > 0.0, "MI should be positive");
    }

    /// A pure-translation matrix (tx = 10, ty = 20) should offset each point
    /// by exactly that amount.
    #[test]
    fn test_apply_affine_to_points() {
        let pts = Array2::from_shape_vec((2, 2), vec![1.0, 0.0, 0.0, 1.0]).expect("shape error");
        let mut mat = Array2::<f64>::zeros((3, 3));
        mat[[0, 0]] = 1.0;
        mat[[1, 1]] = 1.0;
        mat[[0, 2]] = 10.0; // translate x by 10
        mat[[1, 2]] = 20.0; // translate y by 20
        mat[[2, 2]] = 1.0;
        let tf = AffineTransform2D {
            matrix: mat,
            residual: 0.0,
        };
        let result = apply_affine_to_points(&pts, &tf).expect("apply affine failed");
        assert!((result[[0, 0]] - 11.0).abs() < 1e-10);
        assert!((result[[0, 1]] - 20.0).abs() < 1e-10);
        assert!((result[[1, 0]] - 10.0).abs() < 1e-10);
        assert!((result[[1, 1]] - 21.0).abs() < 1e-10);
    }

    /// Rotating (1, 0) by pi/2 with no translation must give (0, 1).
    #[test]
    fn test_apply_rigid_to_points() {
        let pts = Array2::from_shape_vec((1, 2), vec![1.0, 0.0]).expect("shape error");
        let tf = RigidTransform2D {
            angle: PI / 2.0,
            tx: 0.0,
            ty: 0.0,
            residual: 0.0,
        };
        let result = apply_rigid_to_points(&pts, &tf).expect("apply rigid failed");
        assert!(result[[0, 0]].abs() < 1e-10, "x ~ 0 after 90-deg rotation");
        assert!(
            (result[[0, 1]] - 1.0).abs() < 1e-10,
            "y ~ 1 after 90-deg rotation"
        );
    }

    /// An 8x8 input should downsample to 4x4. The top-left output value must
    /// equal the mean of the top-left 2x2 input block.
    #[test]
    fn test_downsample_2x() {
        let img = Array2::from_shape_fn((8, 8), |(i, j)| (i * 8 + j) as f64);
        let ds = downsample_2x(&img);
        assert_eq!(ds.dim(), (4, 4));
        // Top-left 2x2 block: 0, 1, 8, 9 -> avg = 4.5
        assert!((ds[[0, 0]] - 4.5).abs() < 1e-10);
    }

    /// Phase correlation must reject images of different shapes.
    #[test]
    fn test_phase_correlation_dimension_mismatch() {
        let a = Array2::zeros((10, 10));
        let b = Array2::zeros((10, 12));
        assert!(phase_correlation(&a, &b).is_err());
    }

    /// Affine registration must reject a fit from only two correspondences.
    #[test]
    fn test_affine_too_few_points() {
        let src = Array2::from_shape_vec((2, 2), vec![0.0, 0.0, 1.0, 1.0]).expect("shape");
        let tgt = src.clone();
        assert!(affine_registration(&src, &tgt).is_err());
    }

    /// Rigid registration must reject a fit from a single correspondence.
    #[test]
    fn test_rigid_too_few_points() {
        let src = Array2::from_shape_vec((1, 2), vec![0.0, 0.0]).expect("shape");
        let tgt = src.clone();
        assert!(rigid_registration(&src, &tgt).is_err());
    }
}