Skip to main content

oximedia_align/
elastic_align.rs

1#![allow(dead_code)]
2//! Elastic (non-rigid) alignment for deformable media registration.
3//!
4//! This module implements non-rigid alignment techniques that can handle local deformations
5//! such as lens distortion residuals, rolling shutter wobble, and object-level motion.
6//!
7//! # Features
8//!
9//! - **Thin-plate spline (TPS) warping** for smooth non-rigid transforms
10//! - **Control point management** with automatic correspondence
11//! - **Regularized alignment** to prevent overfitting
12//! - **Deformation field** representation and analysis
13
14use crate::{AlignError, AlignResult, Point2D};
15
16/// A control point pair used as a landmark for elastic alignment.
17#[derive(Debug, Clone, Copy, PartialEq)]
18pub struct ControlPoint {
19    /// Source position.
20    pub source: Point2D,
21    /// Target position (where the source should map to).
22    pub target: Point2D,
23    /// Weight for this control point (higher means stronger influence).
24    pub weight: f64,
25}
26
27impl ControlPoint {
28    /// Create a new control point pair.
29    #[must_use]
30    pub fn new(source: Point2D, target: Point2D, weight: f64) -> Self {
31        Self {
32            source,
33            target,
34            weight,
35        }
36    }
37
38    /// Create with default weight 1.0.
39    #[must_use]
40    pub fn with_unit_weight(source: Point2D, target: Point2D) -> Self {
41        Self {
42            source,
43            target,
44            weight: 1.0,
45        }
46    }
47
48    /// Compute the displacement vector (target - source).
49    #[must_use]
50    pub fn displacement(&self) -> (f64, f64) {
51        (self.target.x - self.source.x, self.target.y - self.source.y)
52    }
53
54    /// Compute the displacement magnitude.
55    #[must_use]
56    pub fn displacement_magnitude(&self) -> f64 {
57        let (dx, dy) = self.displacement();
58        (dx * dx + dy * dy).sqrt()
59    }
60}
61
/// Tuning knobs for [`ElasticAligner`].
#[derive(Debug, Clone)]
pub struct ElasticAlignConfig {
    /// Smoothing strength (lambda). `lambda / weight` is added to each control point's
    /// kernel diagonal entry, so larger values give smoother, less exact warps.
    pub regularization: f64,
    /// Fewest control points that [`ElasticAligner::align`] will accept.
    pub min_control_points: usize,
    /// Largest source-to-target distance, in pixels, allowed for any control point.
    pub max_displacement: f64,
    /// Cell size, in pixels, used when sampling a deformation field.
    pub grid_resolution: u32,
}

impl Default for ElasticAlignConfig {
    /// Lambda `0.01`, at least 4 control points, 100 px maximum displacement, 16 px cells.
    fn default() -> Self {
        let regularization = 0.01;
        let min_control_points = 4;
        let max_displacement = 100.0;
        let grid_resolution = 16;
        Self {
            regularization,
            min_control_points,
            max_displacement,
            grid_resolution,
        }
    }
}
85
/// Solved thin-plate spline parameters for a single output coordinate.
///
/// The spline evaluated at a point `p` is
/// `affine[0] + affine[1] * p.x + affine[2] * p.y + sum_i weights[i] * U(|p - c_i|)`,
/// where `c_i` is the source position of the i-th control point.
#[derive(Clone, Debug)]
pub struct TpsCoefficients {
    /// One radial-basis weight per control point, in control-point order.
    pub weights: Vec<f64>,
    /// Affine terms `[a0, a1, a2]` of `a0 + a1*x + a2*y`.
    pub affine: [f64; 3],
}
94
95/// Result of elastic alignment computation.
96#[derive(Debug, Clone)]
97pub struct ElasticAlignResult {
98    /// TPS coefficients for x-coordinate mapping.
99    pub tps_x: TpsCoefficients,
100    /// TPS coefficients for y-coordinate mapping.
101    pub tps_y: TpsCoefficients,
102    /// The control points used.
103    pub control_points: Vec<ControlPoint>,
104    /// Root-mean-square alignment error in pixels.
105    pub rms_error: f64,
106    /// Maximum alignment error in pixels.
107    pub max_error: f64,
108    /// Bending energy of the deformation (lower is smoother).
109    pub bending_energy: f64,
110}
111
/// A displacement field sampled on a regular grid of square cells.
///
/// Samples are stored row-major: cell `(cx, cy)` lives at index `cy * cols + cx`
/// of both `dx` and `dy`.
#[derive(Debug, Clone)]
pub struct DeformationField {
    /// Horizontal displacement for each grid cell.
    pub dx: Vec<f64>,
    /// Vertical displacement for each grid cell.
    pub dy: Vec<f64>,
    /// Number of grid columns.
    pub cols: u32,
    /// Number of grid rows.
    pub rows: u32,
    /// Cell size in pixels.
    pub cell_size: u32,
}

impl DeformationField {
    /// Creates an all-zero field covering a `width` x `height` pixel area.
    ///
    /// The grid has `ceil(width / cell_size)` columns and `ceil(height / cell_size)` rows,
    /// so partially covered edge cells still get a sample. A `cell_size` of `0` is treated
    /// as `1` (one cell per pixel) instead of panicking on a division by zero.
    #[must_use]
    pub fn new(width: u32, height: u32, cell_size: u32) -> Self {
        let cell_size = cell_size.max(1);
        let cols = width.div_ceil(cell_size);
        let rows = height.div_ceil(cell_size);
        // Multiply in `usize`: `cols * rows` can overflow `u32` for large images with
        // small cells.
        let count = cols as usize * rows as usize;
        Self {
            dx: vec![0.0; count],
            dy: vec![0.0; count],
            cols,
            rows,
            cell_size,
        }
    }

    /// Flat storage index of cell `(cx, cy)`, or `None` if it lies outside the grid.
    ///
    /// Computed in `usize` so large grids cannot overflow the `u32` arithmetic.
    fn index(&self, cx: u32, cy: u32) -> Option<usize> {
        (cx < self.cols && cy < self.rows)
            .then(|| cy as usize * self.cols as usize + cx as usize)
    }

    /// Returns the `(dx, dy)` displacement stored for cell `(cx, cy)`.
    ///
    /// Returns `None` if the cell is outside the grid (or if the public `dx`/`dy`
    /// vectors have been shrunk below the grid size).
    #[must_use]
    pub fn get(&self, cx: u32, cy: u32) -> Option<(f64, f64)> {
        let idx = self.index(cx, cy)?;
        Some((*self.dx.get(idx)?, *self.dy.get(idx)?))
    }

    /// Stores displacement `(dx, dy)` for cell `(cx, cy)`.
    ///
    /// Cells outside the grid are silently ignored.
    pub fn set(&mut self, cx: u32, cy: u32, dx: f64, dy: f64) {
        if let Some(idx) = self.index(cx, cy) {
            if let (Some(slot_x), Some(slot_y)) = (self.dx.get_mut(idx), self.dy.get_mut(idx)) {
                *slot_x = dx;
                *slot_y = dy;
            }
        }
    }

    /// Mean displacement magnitude over all cells; `0.0` for an empty field.
    #[must_use]
    #[allow(clippy::cast_precision_loss)] // usize -> f64 is exact for fewer than 2^53 cells
    pub fn average_displacement(&self) -> f64 {
        if self.dx.is_empty() {
            return 0.0;
        }
        let total: f64 = self.magnitudes().sum();
        total / self.dx.len() as f64
    }

    /// Largest displacement magnitude over all cells; `0.0` for an empty field.
    #[must_use]
    pub fn max_displacement(&self) -> f64 {
        self.magnitudes().fold(0.0_f64, f64::max)
    }

    /// Per-cell displacement magnitudes `sqrt(dx^2 + dy^2)`, in storage order.
    fn magnitudes(&self) -> impl Iterator<Item = f64> + '_ {
        self.dx
            .iter()
            .zip(&self.dy)
            .map(|(x, y)| (x * x + y * y).sqrt())
    }
}
189
/// Thin-plate spline radial basis function `U(r) = r^2 * ln(r)`.
///
/// `U(r)` tends to `0` as `r -> 0`, but evaluating `r * r * ln(r)` at exactly `0`
/// yields `0 * -inf = NaN`. Distances below `1e-15` are therefore mapped to `0`
/// explicitly; this includes every point's distance to itself on the kernel-matrix
/// diagonal.
fn tps_kernel(r: f64) -> f64 {
    if r < 1e-15 {
        return 0.0;
    }
    r * r * r.ln()
}
199
200/// Elastic aligner using thin-plate spline interpolation.
201#[derive(Debug, Clone)]
202pub struct ElasticAligner {
203    /// Configuration.
204    config: ElasticAlignConfig,
205}
206
207impl ElasticAligner {
208    /// Create a new elastic aligner with the given configuration.
209    #[must_use]
210    pub fn new(config: ElasticAlignConfig) -> Self {
211        Self { config }
212    }
213
214    /// Create with default configuration.
215    #[must_use]
216    pub fn with_defaults() -> Self {
217        Self {
218            config: ElasticAlignConfig::default(),
219        }
220    }
221
222    /// Compute TPS alignment from control point correspondences.
223    pub fn align(&self, control_points: &[ControlPoint]) -> AlignResult<ElasticAlignResult> {
224        let n = control_points.len();
225        if n < self.config.min_control_points {
226            return Err(AlignError::InsufficientData(format!(
227                "Need at least {} control points, got {}",
228                self.config.min_control_points, n
229            )));
230        }
231
232        // Check max displacement constraint
233        for cp in control_points {
234            if cp.displacement_magnitude() > self.config.max_displacement {
235                return Err(AlignError::InvalidConfig(format!(
236                    "Control point displacement {:.1} exceeds max {:.1}",
237                    cp.displacement_magnitude(),
238                    self.config.max_displacement
239                )));
240            }
241        }
242
243        // Solve TPS for x and y independently
244        let tps_x = self.solve_tps(control_points, true)?;
245        let tps_y = self.solve_tps(control_points, false)?;
246
247        // Compute errors
248        let (rms_error, max_error) = self.compute_errors(control_points, &tps_x, &tps_y);
249
250        // Compute bending energy
251        let bending_energy = self.compute_bending_energy(control_points, &tps_x, &tps_y);
252
253        Ok(ElasticAlignResult {
254            tps_x,
255            tps_y,
256            control_points: control_points.to_vec(),
257            rms_error,
258            max_error,
259            bending_energy,
260        })
261    }
262
263    /// Transform a point using TPS coefficients.
264    #[must_use]
265    pub fn transform_point(&self, point: &Point2D, result: &ElasticAlignResult) -> Point2D {
266        let new_x = self.evaluate_tps(point, &result.tps_x, &result.control_points);
267        let new_y = self.evaluate_tps(point, &result.tps_y, &result.control_points);
268        Point2D::new(new_x, new_y)
269    }
270
271    /// Generate a sampled deformation field from an alignment result.
272    #[must_use]
273    #[allow(clippy::cast_precision_loss)]
274    pub fn generate_deformation_field(
275        &self,
276        result: &ElasticAlignResult,
277        width: u32,
278        height: u32,
279    ) -> DeformationField {
280        let cell_size = self.config.grid_resolution;
281        let mut field = DeformationField::new(width, height, cell_size);
282
283        for cy in 0..field.rows {
284            for cx in 0..field.cols {
285                let px = f64::from(cx * cell_size + cell_size / 2);
286                let py = f64::from(cy * cell_size + cell_size / 2);
287                let src = Point2D::new(px, py);
288                let dst = self.transform_point(&src, result);
289                field.set(cx, cy, dst.x - px, dst.y - py);
290            }
291        }
292
293        field
294    }
295
296    /// Solve TPS for one coordinate (x if `for_x` is true, y otherwise).
297    #[allow(clippy::cast_precision_loss)]
298    fn solve_tps(&self, points: &[ControlPoint], for_x: bool) -> AlignResult<TpsCoefficients> {
299        let n = points.len();
300        // System size: n + 3 (n weights + 3 affine params)
301        let size = n + 3;
302
303        // Build the system matrix L and right-hand side v
304        // L = | K  P |   v = | target_coords |
305        //     | P' 0 |       |       0       |
306        let mut l_matrix = vec![0.0f64; size * size];
307        let mut rhs = vec![0.0f64; size];
308
309        // Fill K (n x n kernel matrix) + regularization on diagonal
310        for i in 0..n {
311            for j in 0..n {
312                let r = points[i].source.distance(&points[j].source);
313                l_matrix[i * size + j] = tps_kernel(r);
314            }
315            // Regularization
316            l_matrix[i * size + i] += self.config.regularization / points[i].weight;
317        }
318
319        // Fill P (n x 3) and P^T (3 x n)
320        for i in 0..n {
321            l_matrix[i * size + n] = 1.0;
322            l_matrix[i * size + n + 1] = points[i].source.x;
323            l_matrix[i * size + n + 2] = points[i].source.y;
324
325            l_matrix[(n) * size + i] = 1.0;
326            l_matrix[(n + 1) * size + i] = points[i].source.x;
327            l_matrix[(n + 2) * size + i] = points[i].source.y;
328        }
329
330        // Fill rhs
331        for i in 0..n {
332            rhs[i] = if for_x {
333                points[i].target.x
334            } else {
335                points[i].target.y
336            };
337        }
338
339        // Solve using Gauss elimination with partial pivoting
340        let solution = Self::gauss_solve(&mut l_matrix, &mut rhs, size)?;
341
342        let weights = solution[..n].to_vec();
343        let affine = [solution[n], solution[n + 1], solution[n + 2]];
344
345        Ok(TpsCoefficients { weights, affine })
346    }
347
348    /// Evaluate TPS at a point.
349    fn evaluate_tps(
350        &self,
351        point: &Point2D,
352        tps: &TpsCoefficients,
353        control_points: &[ControlPoint],
354    ) -> f64 {
355        let mut val = tps.affine[0] + tps.affine[1] * point.x + tps.affine[2] * point.y;
356
357        for (i, cp) in control_points.iter().enumerate() {
358            let r = point.distance(&cp.source);
359            val += tps.weights[i] * tps_kernel(r);
360        }
361
362        val
363    }
364
365    /// Compute RMS and max error.
366    #[allow(clippy::cast_precision_loss)]
367    fn compute_errors(
368        &self,
369        points: &[ControlPoint],
370        tps_x: &TpsCoefficients,
371        tps_y: &TpsCoefficients,
372    ) -> (f64, f64) {
373        let mut sum_sq = 0.0;
374        let mut max_e = 0.0_f64;
375
376        for cp in points {
377            let px = self.evaluate_tps(&cp.source, tps_x, points);
378            let py = self.evaluate_tps(&cp.source, tps_y, points);
379            let err = ((px - cp.target.x).powi(2) + (py - cp.target.y).powi(2)).sqrt();
380            sum_sq += err * err;
381            max_e = max_e.max(err);
382        }
383
384        let rms = (sum_sq / points.len() as f64).sqrt();
385        (rms, max_e)
386    }
387
388    /// Compute bending energy.
389    fn compute_bending_energy(
390        &self,
391        points: &[ControlPoint],
392        tps_x: &TpsCoefficients,
393        tps_y: &TpsCoefficients,
394    ) -> f64 {
395        let n = points.len();
396        let mut energy = 0.0;
397
398        for i in 0..n {
399            for j in 0..n {
400                let r = points[i].source.distance(&points[j].source);
401                let k = tps_kernel(r);
402                energy += tps_x.weights[i] * tps_x.weights[j] * k;
403                energy += tps_y.weights[i] * tps_y.weights[j] * k;
404            }
405        }
406
407        energy.abs()
408    }
409
410    /// Gaussian elimination with partial pivoting.
411    fn gauss_solve(a: &mut [f64], b: &mut [f64], n: usize) -> AlignResult<Vec<f64>> {
412        // Forward elimination
413        for col in 0..n {
414            // Partial pivoting
415            let mut max_val = a[col * n + col].abs();
416            let mut max_row = col;
417            for row in (col + 1)..n {
418                let val = a[row * n + col].abs();
419                if val > max_val {
420                    max_val = val;
421                    max_row = row;
422                }
423            }
424
425            if max_val < 1e-15 {
426                return Err(AlignError::NumericalError(
427                    "Singular matrix in TPS solve".to_string(),
428                ));
429            }
430
431            // Swap rows
432            if max_row != col {
433                for k in 0..n {
434                    a.swap(col * n + k, max_row * n + k);
435                }
436                b.swap(col, max_row);
437            }
438
439            // Eliminate below
440            let pivot = a[col * n + col];
441            for row in (col + 1)..n {
442                let factor = a[row * n + col] / pivot;
443                for k in col..n {
444                    a[row * n + k] -= factor * a[col * n + k];
445                }
446                b[row] -= factor * b[col];
447            }
448        }
449
450        // Back substitution
451        let mut x = vec![0.0f64; n];
452        for col in (0..n).rev() {
453            let mut sum = b[col];
454            for k in (col + 1)..n {
455                sum -= a[col * n + k] * x[k];
456            }
457            x[col] = sum / a[col * n + col];
458        }
459
460        Ok(x)
461    }
462
463    /// Get the current configuration.
464    #[must_use]
465    pub fn config(&self) -> &ElasticAlignConfig {
466        &self.config
467    }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Builds a unit-weight control point from raw coordinates.
    fn unit(sx: f64, sy: f64, tx: f64, ty: f64) -> ControlPoint {
        ControlPoint::with_unit_weight(Point2D::new(sx, sy), Point2D::new(tx, ty))
    }

    /// Lightly regularised configuration shared by the fitting tests.
    fn light_config() -> ElasticAlignConfig {
        ElasticAlignConfig {
            regularization: 0.001,
            min_control_points: 4,
            max_displacement: 100.0,
            grid_resolution: 16,
        }
    }

    /// Corners of the 100 x 100 square at the origin, each mapped to itself + `(dx, dy)`.
    fn shifted_square(dx: f64, dy: f64) -> Vec<ControlPoint> {
        [(0.0, 0.0), (100.0, 0.0), (0.0, 100.0), (100.0, 100.0)]
            .iter()
            .map(|&(x, y)| unit(x, y, x + dx, y + dy))
            .collect()
    }

    #[test]
    fn test_control_point_creation() {
        let cp = ControlPoint::new(Point2D::new(10.0, 20.0), Point2D::new(12.0, 22.0), 1.0);
        assert!(close(cp.source.x, 10.0, f64::EPSILON));
        assert!(close(cp.target.x, 12.0, f64::EPSILON));
    }

    #[test]
    fn test_control_point_displacement() {
        let cp = ControlPoint::new(Point2D::new(0.0, 0.0), Point2D::new(3.0, 4.0), 1.0);
        let (dx, dy) = cp.displacement();
        assert!(close(dx, 3.0, f64::EPSILON));
        assert!(close(dy, 4.0, f64::EPSILON));
        assert!(close(cp.displacement_magnitude(), 5.0, 1e-10));
    }

    #[test]
    fn test_control_point_unit_weight() {
        let cp = unit(0.0, 0.0, 1.0, 1.0);
        assert!(close(cp.weight, 1.0, f64::EPSILON));
    }

    #[test]
    fn test_config_default() {
        let config = ElasticAlignConfig::default();
        assert!(close(config.regularization, 0.01, f64::EPSILON));
        assert_eq!(config.min_control_points, 4);
    }

    #[test]
    fn test_deformation_field_creation() {
        let field = DeformationField::new(320, 240, 16);
        assert_eq!((field.cols, field.rows), (20, 15));
        assert_eq!(field.dx.len(), 300);
    }

    #[test]
    fn test_deformation_field_get_set() {
        let mut field = DeformationField::new(64, 64, 16);
        field.set(1, 2, 3.5, -1.5);
        let (dx, dy) = field.get(1, 2).expect("get should succeed");
        assert!(close(dx, 3.5, f64::EPSILON));
        assert!(close(dy, -1.5, f64::EPSILON));
    }

    #[test]
    fn test_deformation_field_average() {
        let mut field = DeformationField::new(32, 32, 16);
        // Only cell (0, 0) is non-zero: magnitude 5 averaged over 4 cells.
        field.set(0, 0, 3.0, 4.0);
        for (cx, cy) in [(1, 0), (0, 1), (1, 1)] {
            field.set(cx, cy, 0.0, 0.0);
        }
        assert!(close(field.average_displacement(), 1.25, 1e-10));
    }

    #[test]
    fn test_deformation_field_max() {
        let mut field = DeformationField::new(32, 32, 16);
        field.set(0, 0, 3.0, 4.0);
        field.set(1, 0, 1.0, 0.0);
        assert!(close(field.max_displacement(), 5.0, 1e-10));
    }

    #[test]
    fn test_tps_kernel() {
        // U(0) is defined as 0, and U(1) = 1^2 * ln(1) = 0.
        assert!(close(tps_kernel(0.0), 0.0, f64::EPSILON));
        assert!(close(tps_kernel(1.0), 0.0, f64::EPSILON));
        // U(e) = e^2 * ln(e) = e^2.
        let e = std::f64::consts::E;
        assert!(close(tps_kernel(e), e * e, 1e-10));
    }

    #[test]
    fn test_elastic_align_insufficient_points() {
        let aligner = ElasticAligner::with_defaults();
        assert!(aligner.align(&[unit(0.0, 0.0, 1.0, 1.0)]).is_err());
    }

    #[test]
    fn test_elastic_align_identity() {
        let aligner = ElasticAligner::new(light_config());
        let result = aligner
            .align(&shifted_square(0.0, 0.0))
            .expect("result should be valid");
        // An identity map should be reproduced almost exactly.
        assert!(result.rms_error < 1.0);
    }

    #[test]
    fn test_elastic_align_translation() {
        let aligner = ElasticAligner::new(light_config());
        let result = aligner
            .align(&shifted_square(5.0, 3.0))
            .expect("result should be valid");
        // The square's centre should move by the same (5, 3) shift.
        let moved = aligner.transform_point(&Point2D::new(50.0, 50.0), &result);
        assert!(close(moved.x, 55.0, 2.0));
        assert!(close(moved.y, 53.0, 2.0));
    }

    #[test]
    fn test_elastic_align_max_displacement_exceeded() {
        let aligner = ElasticAligner::new(ElasticAlignConfig {
            max_displacement: 5.0,
            ..ElasticAlignConfig::default()
        });
        let points = [
            unit(0.0, 0.0, 100.0, 100.0),
            unit(10.0, 0.0, 110.0, 100.0),
            unit(0.0, 10.0, 100.0, 110.0),
            unit(10.0, 10.0, 110.0, 110.0),
        ];
        assert!(aligner.align(&points).is_err());
    }
}