1#![allow(dead_code)]
2use std::f64::consts::PI;
19
/// A 2D rigid-body transform: a rotation by `theta` radians about the origin
/// followed by a translation of (`tx`, `ty`).
///
/// A point `p` is mapped to `R(theta) * p + t`, where `R(theta)` is the
/// standard rotation matrix `[[cos, -sin], [sin, cos]]` (counter-clockwise
/// when the y axis points up).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RigidTransform {
    /// Rotation angle in radians.
    pub theta: f64,
    /// Translation along x, applied after the rotation.
    pub tx: f64,
    /// Translation along y, applied after the rotation.
    pub ty: f64,
}

impl RigidTransform {
    /// Creates a transform from a rotation angle (radians) and a translation.
    pub fn new(theta: f64, tx: f64, ty: f64) -> Self {
        Self { theta, tx, ty }
    }

    /// Returns the transform that maps every point to itself.
    pub fn identity() -> Self {
        Self::new(0.0, 0.0, 0.0)
    }

    /// Returns a pure translation by (`tx`, `ty`) with no rotation.
    pub fn translation(tx: f64, ty: f64) -> Self {
        Self::new(0.0, tx, ty)
    }

    /// Returns a pure rotation about the origin by `theta` radians.
    pub fn rotation(theta: f64) -> Self {
        Self::new(theta, 0.0, 0.0)
    }

    /// Maps the point `(x, y)` through this transform: rotate, then translate.
    pub fn apply(&self, x: f64, y: f64) -> (f64, f64) {
        let (sin_t, cos_t) = self.theta.sin_cos();
        let xp = cos_t * x - sin_t * y + self.tx;
        let yp = sin_t * x + cos_t * y + self.ty;
        (xp, yp)
    }

    /// Returns the transform that undoes `self`.
    ///
    /// `self` maps `p` to `R p + t`, so the inverse maps `q` to `R^T q - R^T t`:
    /// rotation `-theta` with translation `-R^T t`.
    pub fn inverse(&self) -> Self {
        let (sin_t, cos_t) = self.theta.sin_cos();
        // -R^T t, where R^T = [[cos, sin], [-sin, cos]].
        let tx_inv = -(cos_t * self.tx + sin_t * self.ty);
        let ty_inv = -(-sin_t * self.tx + cos_t * self.ty);
        Self::new(-self.theta, tx_inv, ty_inv)
    }

    /// Returns the transform equivalent to applying `self` first and then `other`,
    /// i.e. `compose(other)(p) == other.apply(self.apply(p))`.
    ///
    /// 2D rotations commute, so the angles add. `self`'s translation is rotated
    /// by `other` before `other`'s translation is added. The resulting angle is
    /// not normalized; call [`normalize_angle`](Self::normalize_angle) if needed.
    pub fn compose(&self, other: &Self) -> Self {
        let theta = self.theta + other.theta;
        let (sin_o, cos_o) = other.theta.sin_cos();
        let tx = cos_o * self.tx - sin_o * self.ty + other.tx;
        let ty = sin_o * self.tx + cos_o * self.ty + other.ty;
        Self { theta, tx, ty }
    }

    /// Wraps `theta` into `[-π, π)`.
    ///
    /// Floating-point rounding in `rem_euclid` can occasionally land exactly on
    /// `π` at the boundary, so callers should treat `±π` as equivalent.
    pub fn normalize_angle(&mut self) {
        self.theta = (self.theta + PI).rem_euclid(2.0 * PI) - PI;
    }

    /// Returns the rotation angle converted to degrees.
    pub fn angle_degrees(&self) -> f64 {
        self.theta.to_degrees()
    }

    /// Returns the Euclidean length of the translation vector.
    ///
    /// Uses `hypot`, so very large or very small components do not overflow to
    /// infinity or underflow to zero the way a naive `sqrt(tx² + ty²)` would.
    pub fn translation_magnitude(&self) -> f64 {
        self.tx.hypot(self.ty)
    }
}
114
/// One point correspondence: the source point `(src_x, src_y)` paired with the
/// destination point `(dst_x, dst_y)` it should map onto.
#[derive(Clone, Copy, Debug)]
pub struct PointPair {
    /// Source x coordinate.
    pub src_x: f64,
    /// Source y coordinate.
    pub src_y: f64,
    /// Destination x coordinate.
    pub dst_x: f64,
    /// Destination y coordinate.
    pub dst_y: f64,
}

impl PointPair {
    /// Builds a correspondence from its source and destination coordinates.
    pub fn new(src_x: f64, src_y: f64, dst_x: f64, dst_y: f64) -> Self {
        Self { src_x, src_y, dst_x, dst_y }
    }
}
139
140#[allow(clippy::cast_precision_loss)]
145pub fn estimate_rigid(pairs: &[PointPair]) -> Option<RigidTransform> {
146 if pairs.len() < 2 {
147 return None;
148 }
149 let n = pairs.len() as f64;
150
151 let (cx_s, cy_s) = pairs
153 .iter()
154 .fold((0.0, 0.0), |(sx, sy), p| (sx + p.src_x, sy + p.src_y));
155 let (cx_d, cy_d) = pairs
156 .iter()
157 .fold((0.0, 0.0), |(sx, sy), p| (sx + p.dst_x, sy + p.dst_y));
158 let (cx_s, cy_s) = (cx_s / n, cy_s / n);
159 let (cx_d, cy_d) = (cx_d / n, cy_d / n);
160
161 let mut sum_sin = 0.0;
163 let mut sum_cos = 0.0;
164 for p in pairs {
165 let sx = p.src_x - cx_s;
166 let sy = p.src_y - cy_s;
167 let dx = p.dst_x - cx_d;
168 let dy = p.dst_y - cy_d;
169 sum_cos += sx * dx + sy * dy;
170 sum_sin += sx * dy - sy * dx;
171 }
172
173 let theta = sum_sin.atan2(sum_cos);
174 let (sin_t, cos_t) = theta.sin_cos();
175 let tx = cx_d - (cos_t * cx_s - sin_t * cy_s);
176 let ty = cy_d - (sin_t * cx_s + cos_t * cy_s);
177
178 Some(RigidTransform { theta, tx, ty })
179}
180
181#[allow(clippy::cast_precision_loss)]
183pub fn rmse(transform: &RigidTransform, pairs: &[PointPair]) -> f64 {
184 if pairs.is_empty() {
185 return 0.0;
186 }
187 let sum: f64 = pairs
188 .iter()
189 .map(|p| {
190 let (xp, yp) = transform.apply(p.src_x, p.src_y);
191 let dx = xp - p.dst_x;
192 let dy = yp - p.dst_y;
193 dx * dx + dy * dy
194 })
195 .sum();
196 (sum / pairs.len() as f64).sqrt()
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    const EPS: f64 = 1e-9;

    #[test]
    fn test_identity_apply() {
        let t = RigidTransform::identity();
        let (x, y) = t.apply(3.0, 4.0);
        assert!((x - 3.0).abs() < EPS);
        assert!((y - 4.0).abs() < EPS);
    }

    #[test]
    fn test_translation_apply() {
        let t = RigidTransform::translation(5.0, -3.0);
        let (x, y) = t.apply(1.0, 2.0);
        assert!((x - 6.0).abs() < EPS);
        assert!((y + 1.0).abs() < EPS);
    }

    #[test]
    fn test_rotation_90_degrees() {
        let t = RigidTransform::rotation(PI / 2.0);
        let (x, y) = t.apply(1.0, 0.0);
        assert!(x.abs() < EPS);
        assert!((y - 1.0).abs() < EPS);
    }

    #[test]
    fn test_inverse_roundtrip() {
        let t = RigidTransform::new(0.3, 5.0, -2.0);
        let inv = t.inverse();
        let (x, y) = t.apply(7.0, 3.0);
        let (xb, yb) = inv.apply(x, y);
        assert!((xb - 7.0).abs() < EPS);
        assert!((yb - 3.0).abs() < EPS);
    }

    #[test]
    fn test_compose_with_identity() {
        let t = RigidTransform::new(0.5, 1.0, 2.0);
        let id = RigidTransform::identity();
        let c = t.compose(&id);
        assert!((c.theta - t.theta).abs() < EPS);
        assert!((c.tx - t.tx).abs() < EPS);
        assert!((c.ty - t.ty).abs() < EPS);
    }

    #[test]
    fn test_compose_two_translations() {
        let t1 = RigidTransform::translation(1.0, 2.0);
        let t2 = RigidTransform::translation(3.0, 4.0);
        let c = t1.compose(&t2);
        assert!(c.theta.abs() < EPS);
        assert!((c.tx - 4.0).abs() < EPS);
        assert!((c.ty - 6.0).abs() < EPS);
    }

    #[test]
    fn test_compose_applies_self_then_other() {
        let rot = RigidTransform::rotation(PI / 2.0);
        let shift = RigidTransform::translation(1.0, 0.0);
        let c = rot.compose(&shift);
        // (1, 0) --rotate--> (0, 1) --shift--> (1, 1)
        let (x, y) = c.apply(1.0, 0.0);
        assert!((x - 1.0).abs() < EPS);
        assert!((y - 1.0).abs() < EPS);
        // Must agree with applying the two transforms one after the other.
        let (mx, my) = rot.apply(2.0, -3.0);
        let (ex, ey) = shift.apply(mx, my);
        let (cx, cy) = c.apply(2.0, -3.0);
        assert!((cx - ex).abs() < EPS);
        assert!((cy - ey).abs() < EPS);
    }

    #[test]
    fn test_compose_with_inverse_is_identity() {
        let t = RigidTransform::new(0.7, -2.0, 3.5);
        let c = t.compose(&t.inverse());
        assert!(c.theta.abs() < EPS);
        assert!(c.tx.abs() < EPS);
        assert!(c.ty.abs() < EPS);
    }

    #[test]
    fn test_angle_degrees() {
        let t = RigidTransform::rotation(PI / 4.0);
        assert!((t.angle_degrees() - 45.0).abs() < 1e-6);
    }

    #[test]
    fn test_translation_magnitude() {
        let t = RigidTransform::translation(3.0, 4.0);
        assert!((t.translation_magnitude() - 5.0).abs() < EPS);
    }

    #[test]
    fn test_normalize_angle() {
        let mut t = RigidTransform::rotation(3.0 * PI);
        t.normalize_angle();
        // 3π wraps to the boundary; rounding may land on either +π or -π.
        assert!(
            (t.theta - PI).abs() < 1e-6 || (t.theta + PI).abs() < 1e-6,
            "expected ±PI, got {}",
            t.theta
        );
    }

    #[test]
    fn test_estimate_pure_translation() {
        let pairs = vec![
            PointPair::new(0.0, 0.0, 1.0, 2.0),
            PointPair::new(1.0, 0.0, 2.0, 2.0),
            PointPair::new(0.0, 1.0, 1.0, 3.0),
        ];
        let t = estimate_rigid(&pairs).expect("t should be valid");
        assert!(t.theta.abs() < 1e-6);
        assert!((t.tx - 1.0).abs() < 1e-6);
        assert!((t.ty - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_estimate_rotation_and_translation() {
        let truth = RigidTransform::new(0.4, 2.5, -1.0);
        let sources = [(0.0, 0.0), (3.0, 0.0), (0.0, 2.0), (-1.5, 4.0)];
        let pairs: Vec<PointPair> = sources
            .iter()
            .map(|&(x, y)| {
                let (dx, dy) = truth.apply(x, y);
                PointPair::new(x, y, dx, dy)
            })
            .collect();
        let t = estimate_rigid(&pairs).expect("four distinct pairs are enough");
        assert!((t.theta - truth.theta).abs() < EPS);
        assert!((t.tx - truth.tx).abs() < EPS);
        assert!((t.ty - truth.ty).abs() < EPS);
        assert!(rmse(&t, &pairs) < EPS);
    }

    #[test]
    fn test_estimate_insufficient_points() {
        let pairs = vec![PointPair::new(0.0, 0.0, 1.0, 1.0)];
        assert!(estimate_rigid(&pairs).is_none());
    }

    #[test]
    fn test_rmse_perfect() {
        let t = RigidTransform::translation(1.0, 0.0);
        let pairs = vec![
            PointPair::new(0.0, 0.0, 1.0, 0.0),
            PointPair::new(1.0, 0.0, 2.0, 0.0),
        ];
        assert!(rmse(&t, &pairs) < EPS);
    }

    #[test]
    fn test_rmse_known_residual() {
        // Identity leaves residuals of length 3 and 4: sqrt((9 + 16) / 2).
        let t = RigidTransform::identity();
        let pairs = [
            PointPair::new(0.0, 0.0, 3.0, 0.0),
            PointPair::new(1.0, 1.0, 1.0, 5.0),
        ];
        assert!((rmse(&t, &pairs) - 12.5_f64.sqrt()).abs() < EPS);
    }

    #[test]
    fn test_rmse_empty() {
        let t = RigidTransform::identity();
        assert!(rmse(&t, &[]).abs() < EPS);
    }
}