Skip to main content

oximedia_align/
rigid_transform.rs

1#![allow(dead_code)]
2//! Rigid body (rotation + translation) transformations for alignment.
3//!
4//! A rigid transform preserves distances and angles. It is characterised by a
5//! 2-D rotation angle `theta` and a translation vector `(tx, ty)`. This is the
6//! simplest geometric model for aligning images that differ only by camera
7//! rotation and shift (no scaling or shearing).
8//!
9//! # Features
10//!
11//! - [`RigidTransform`] representation (angle + translation)
12//! - Application of the transform to 2-D points
13//! - Inverse transform
14//! - Estimation from matched point pairs via least-squares
15//! - Composition / chaining of transforms
16//! - Residual error computation
17
18use std::f64::consts::PI;
19
/// A rigid 2-D transform: rotation by `theta` about the origin followed by
/// translation `(tx, ty)`.
///
/// The transformation of a point `(x, y)` is:
///
/// ```text
/// x' = cos(theta) * x - sin(theta) * y + tx
/// y' = sin(theta) * x + cos(theta) * y + ty
/// ```
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RigidTransform {
    /// Rotation angle in radians
    pub theta: f64,
    /// Translation along the X axis
    pub tx: f64,
    /// Translation along the Y axis
    pub ty: f64,
}

impl RigidTransform {
    /// Create a new rigid transform from an angle (radians) and a translation.
    pub fn new(theta: f64, tx: f64, ty: f64) -> Self {
        Self { theta, tx, ty }
    }

    /// The identity transform (no rotation, no translation).
    pub fn identity() -> Self {
        Self {
            theta: 0.0,
            tx: 0.0,
            ty: 0.0,
        }
    }

    /// Create a pure translation (no rotation).
    pub fn translation(tx: f64, ty: f64) -> Self {
        Self { theta: 0.0, tx, ty }
    }

    /// Create a pure rotation about the origin (no translation).
    pub fn rotation(theta: f64) -> Self {
        Self {
            theta,
            tx: 0.0,
            ty: 0.0,
        }
    }

    /// Apply this transform to the point `(x, y)` and return the mapped point.
    pub fn apply(&self, x: f64, y: f64) -> (f64, f64) {
        let (sin_t, cos_t) = self.theta.sin_cos();
        let xp = cos_t * x - sin_t * y + self.tx;
        let yp = sin_t * x + cos_t * y + self.ty;
        (xp, yp)
    }

    /// Compute the inverse transform, so that `self.inverse().apply(self.apply(p))`
    /// returns `p` (up to floating-point rounding).
    ///
    /// For `p' = R p + t` the inverse is `p = R^T p' - R^T t`, i.e. rotation by
    /// `-theta` with the translation rotated back and negated.
    pub fn inverse(&self) -> Self {
        let (sin_t, cos_t) = self.theta.sin_cos();
        // -R^T t, written out component-wise.
        let tx_inv = -(cos_t * self.tx + sin_t * self.ty);
        let ty_inv = -(-sin_t * self.tx + cos_t * self.ty);
        Self {
            theta: -self.theta,
            tx: tx_inv,
            ty: ty_inv,
        }
    }

    /// Compose two rigid transforms: the result is equivalent to applying
    /// `self` first and then `other`.
    ///
    /// The resulting angle is the plain sum `self.theta + other.theta` and is
    /// not wrapped; call [`RigidTransform::normalize_angle`] if a canonical
    /// angle is required.
    pub fn compose(&self, other: &Self) -> Self {
        let theta = self.theta + other.theta;
        // other(self(p)) = R_o (R_s p + t_s) + t_o, so the new translation is R_o t_s + t_o.
        let (sin_o, cos_o) = other.theta.sin_cos();
        let tx = cos_o * self.tx - sin_o * self.ty + other.tx;
        let ty = sin_o * self.tx + cos_o * self.ty + other.ty;
        Self { theta, tx, ty }
    }

    /// Normalise the angle in place to the half-open range `[-PI, PI)`.
    ///
    /// Non-finite angles are left as NaN.
    pub fn normalize_angle(&mut self) {
        let two_pi = 2.0 * PI;
        let mut wrapped = (self.theta + PI).rem_euclid(two_pi) - PI;
        // `f64::rem_euclid` may round a tiny negative input up to exactly `2*PI`,
        // which would yield `+PI` here; fold that onto `-PI` to honour `[-PI, PI)`.
        if wrapped >= PI {
            wrapped -= two_pi;
        }
        self.theta = wrapped;
    }

    /// Return the rotation angle in degrees.
    pub fn angle_degrees(&self) -> f64 {
        self.theta.to_degrees()
    }

    /// Compute the Euclidean length of the translation vector `(tx, ty)`.
    ///
    /// Uses `f64::hypot`, which avoids the intermediate overflow/underflow of
    /// squaring the components by hand.
    pub fn translation_magnitude(&self) -> f64 {
        self.tx.hypot(self.ty)
    }
}
114
/// One correspondence between a source point and the destination point it
/// should map onto.
#[derive(Debug, Clone, Copy)]
pub struct PointPair {
    /// X coordinate of the source point
    pub src_x: f64,
    /// Y coordinate of the source point
    pub src_y: f64,
    /// X coordinate of the destination point
    pub dst_x: f64,
    /// Y coordinate of the destination point
    pub dst_y: f64,
}

impl PointPair {
    /// Build a correspondence from `(src_x, src_y)` to `(dst_x, dst_y)`.
    pub fn new(src_x: f64, src_y: f64, dst_x: f64, dst_y: f64) -> Self {
        PointPair { src_x, src_y, dst_x, dst_y }
    }
}
139
140/// Estimate a rigid transform from 2 or more point correspondences using
141/// least-squares (the Procrustes solution without scaling).
142///
143/// Returns `None` if fewer than 2 correspondences are provided.
144#[allow(clippy::cast_precision_loss)]
145pub fn estimate_rigid(pairs: &[PointPair]) -> Option<RigidTransform> {
146    if pairs.len() < 2 {
147        return None;
148    }
149    let n = pairs.len() as f64;
150
151    // Compute centroids
152    let (cx_s, cy_s) = pairs
153        .iter()
154        .fold((0.0, 0.0), |(sx, sy), p| (sx + p.src_x, sy + p.src_y));
155    let (cx_d, cy_d) = pairs
156        .iter()
157        .fold((0.0, 0.0), |(sx, sy), p| (sx + p.dst_x, sy + p.dst_y));
158    let (cx_s, cy_s) = (cx_s / n, cy_s / n);
159    let (cx_d, cy_d) = (cx_d / n, cy_d / n);
160
161    // Compute cross-covariance sums for rotation
162    let mut sum_sin = 0.0;
163    let mut sum_cos = 0.0;
164    for p in pairs {
165        let sx = p.src_x - cx_s;
166        let sy = p.src_y - cy_s;
167        let dx = p.dst_x - cx_d;
168        let dy = p.dst_y - cy_d;
169        sum_cos += sx * dx + sy * dy;
170        sum_sin += sx * dy - sy * dx;
171    }
172
173    let theta = sum_sin.atan2(sum_cos);
174    let (sin_t, cos_t) = theta.sin_cos();
175    let tx = cx_d - (cos_t * cx_s - sin_t * cy_s);
176    let ty = cy_d - (sin_t * cx_s + cos_t * cy_s);
177
178    Some(RigidTransform { theta, tx, ty })
179}
180
181/// Compute the root mean square error of a rigid transform given correspondences.
182#[allow(clippy::cast_precision_loss)]
183pub fn rmse(transform: &RigidTransform, pairs: &[PointPair]) -> f64 {
184    if pairs.is_empty() {
185        return 0.0;
186    }
187    let sum: f64 = pairs
188        .iter()
189        .map(|p| {
190            let (xp, yp) = transform.apply(p.src_x, p.src_y);
191            let dx = xp - p.dst_x;
192            let dy = yp - p.dst_y;
193            dx * dx + dy * dy
194        })
195        .sum();
196    (sum / pairs.len() as f64).sqrt()
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-9;

    #[test]
    fn test_identity_apply() {
        let t = RigidTransform::identity();
        let (x, y) = t.apply(3.0, 4.0);
        assert!((x - 3.0).abs() < EPS);
        assert!((y - 4.0).abs() < EPS);
    }

    #[test]
    fn test_translation_apply() {
        let t = RigidTransform::translation(5.0, -3.0);
        let (x, y) = t.apply(1.0, 2.0);
        assert!((x - 6.0).abs() < EPS);
        assert!((y + 1.0).abs() < EPS);
    }

    #[test]
    fn test_rotation_90_degrees() {
        let t = RigidTransform::rotation(PI / 2.0);
        let (x, y) = t.apply(1.0, 0.0);
        assert!(x.abs() < EPS);
        assert!((y - 1.0).abs() < EPS);
    }

    #[test]
    fn test_inverse_roundtrip() {
        let t = RigidTransform::new(0.3, 5.0, -2.0);
        let inv = t.inverse();
        let (x, y) = t.apply(7.0, 3.0);
        let (xb, yb) = inv.apply(x, y);
        assert!((xb - 7.0).abs() < EPS);
        assert!((yb - 3.0).abs() < EPS);
    }

    #[test]
    fn test_inverse_compose_is_identity() {
        let t = RigidTransform::new(1.2, -4.0, 7.5);
        let c = t.compose(&t.inverse());
        assert!(c.theta.abs() < EPS);
        assert!(c.tx.abs() < EPS);
        assert!(c.ty.abs() < EPS);
    }

    #[test]
    fn test_compose_with_identity() {
        let t = RigidTransform::new(0.5, 1.0, 2.0);
        let id = RigidTransform::identity();
        let c = t.compose(&id);
        assert!((c.theta - t.theta).abs() < EPS);
        assert!((c.tx - t.tx).abs() < EPS);
        assert!((c.ty - t.ty).abs() < EPS);
    }

    #[test]
    fn test_compose_two_translations() {
        let t1 = RigidTransform::translation(1.0, 2.0);
        let t2 = RigidTransform::translation(3.0, 4.0);
        let c = t1.compose(&t2);
        assert!(c.theta.abs() < EPS);
        assert!((c.tx - 4.0).abs() < EPS);
        assert!((c.ty - 6.0).abs() < EPS);
    }

    #[test]
    fn test_compose_matches_sequential_apply() {
        let first = RigidTransform::new(0.4, 1.5, -2.0);
        let second = RigidTransform::new(-1.1, 0.25, 3.0);
        let composed = first.compose(&second);
        for &(x, y) in &[(0.0, 0.0), (1.0, -2.0), (-3.5, 4.25)] {
            let (ax, ay) = first.apply(x, y);
            let (ex, ey) = second.apply(ax, ay);
            let (cx, cy) = composed.apply(x, y);
            assert!((cx - ex).abs() < EPS);
            assert!((cy - ey).abs() < EPS);
        }
    }

    #[test]
    fn test_angle_degrees() {
        let t = RigidTransform::rotation(PI / 4.0);
        assert!((t.angle_degrees() - 45.0).abs() < 1e-6);
    }

    #[test]
    fn test_translation_magnitude() {
        let t = RigidTransform::translation(3.0, 4.0);
        assert!((t.translation_magnitude() - 5.0).abs() < EPS);
    }

    #[test]
    fn test_normalize_angle() {
        let mut t = RigidTransform::rotation(3.0 * PI);
        t.normalize_angle();
        // 3*PI normalises to either PI or -PI (both represent the same angle).
        assert!(
            (t.theta - PI).abs() < 1e-6 || (t.theta + PI).abs() < 1e-6,
            "expected ±PI, got {}",
            t.theta
        );
    }

    #[test]
    fn test_estimate_pure_translation() {
        let pairs = vec![
            PointPair::new(0.0, 0.0, 1.0, 2.0),
            PointPair::new(1.0, 0.0, 2.0, 2.0),
            PointPair::new(0.0, 1.0, 1.0, 3.0),
        ];
        let t = estimate_rigid(&pairs).expect("t should be valid");
        assert!(t.theta.abs() < 1e-6);
        assert!((t.tx - 1.0).abs() < 1e-6);
        assert!((t.ty - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_estimate_rotation_and_translation() {
        let truth = RigidTransform::new(0.7, 2.0, -1.5);
        let pairs: Vec<PointPair> = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (2.0, 3.0)]
            .iter()
            .map(|&(x, y)| {
                let (u, v) = truth.apply(x, y);
                PointPair::new(x, y, u, v)
            })
            .collect();
        let est = estimate_rigid(&pairs).expect("four non-degenerate pairs");
        assert!((est.theta - truth.theta).abs() < EPS);
        assert!((est.tx - truth.tx).abs() < EPS);
        assert!((est.ty - truth.ty).abs() < EPS);
        assert!(rmse(&est, &pairs) < EPS);
    }

    #[test]
    fn test_estimate_insufficient_points() {
        let pairs = vec![PointPair::new(0.0, 0.0, 1.0, 1.0)];
        assert!(estimate_rigid(&pairs).is_none());
        assert!(estimate_rigid(&[]).is_none());
    }

    #[test]
    fn test_rmse_perfect() {
        let t = RigidTransform::translation(1.0, 0.0);
        let pairs = vec![
            PointPair::new(0.0, 0.0, 1.0, 0.0),
            PointPair::new(1.0, 0.0, 2.0, 0.0),
        ];
        assert!(rmse(&t, &pairs) < EPS);
    }

    #[test]
    fn test_rmse_known_error() {
        // Squared errors are 25 and 0, so RMSE = sqrt(12.5).
        let t = RigidTransform::identity();
        let pairs = vec![
            PointPair::new(0.0, 0.0, 3.0, 4.0),
            PointPair::new(1.0, 1.0, 1.0, 1.0),
        ];
        assert!((rmse(&t, &pairs) - 12.5_f64.sqrt()).abs() < EPS);
    }

    #[test]
    fn test_rmse_empty() {
        let t = RigidTransform::identity();
        assert!(rmse(&t, &[]).abs() < EPS);
    }
}