1#![allow(dead_code)]
2use crate::{AlignError, AlignResult, Point2D};
15
/// A correspondence between a position before the warp (`source`) and where it should land
/// after the warp (`target`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ControlPoint {
    /// Position before the warp; also used as the TPS kernel centre for this point.
    pub source: Point2D,
    /// Desired position of `source` after the warp.
    pub target: Point2D,
    /// Relative importance of this point. The TPS solver divides the regularization term by
    /// this value, so larger weights pull the fit closer to `target`.
    pub weight: f64,
}
26
27impl ControlPoint {
28 #[must_use]
30 pub fn new(source: Point2D, target: Point2D, weight: f64) -> Self {
31 Self {
32 source,
33 target,
34 weight,
35 }
36 }
37
38 #[must_use]
40 pub fn with_unit_weight(source: Point2D, target: Point2D) -> Self {
41 Self {
42 source,
43 target,
44 weight: 1.0,
45 }
46 }
47
48 #[must_use]
50 pub fn displacement(&self) -> (f64, f64) {
51 (self.target.x - self.source.x, self.target.y - self.source.y)
52 }
53
54 #[must_use]
56 pub fn displacement_magnitude(&self) -> f64 {
57 let (dx, dy) = self.displacement();
58 (dx * dx + dy * dy).sqrt()
59 }
60}
61
/// Tuning parameters for the elastic (thin-plate-spline) aligner.
#[derive(Debug, Clone)]
pub struct ElasticAlignConfig {
    /// Smoothing term added to the kernel diagonal (divided by each point's weight);
    /// larger values trade exact interpolation for a smoother warp.
    pub regularization: f64,
    /// Minimum number of control points an alignment accepts.
    pub min_control_points: usize,
    /// Largest source-to-target displacement allowed for any single control point.
    pub max_displacement: f64,
    /// Edge length of the square cells used when sampling a deformation field.
    pub grid_resolution: u32,
}

impl Default for ElasticAlignConfig {
    /// Light regularization, a 4-point minimum, a 100-unit displacement cap and 16-unit cells.
    fn default() -> Self {
        let regularization = 0.01;
        let min_control_points = 4;
        let max_displacement = 100.0;
        let grid_resolution = 16;
        Self {
            regularization,
            min_control_points,
            max_displacement,
            grid_resolution,
        }
    }
}
85
/// Coefficients of one output coordinate of a thin-plate spline.
///
/// The value at a point `p` is
/// `affine[0] + affine[1] * p.x + affine[2] * p.y + Σ weights[i] * U(|p - c_i|)`,
/// where `c_i` is the `source` of the i-th control point and `U(r) = r² ln r`.
#[derive(Debug, Clone)]
pub struct TpsCoefficients {
    /// One kernel weight per control point, in control-point order.
    pub weights: Vec<f64>,
    /// Affine terms: `[constant, x coefficient, y coefficient]`.
    pub affine: [f64; 3],
}
94
/// Output of an elastic alignment: the fitted spline plus fit-quality statistics.
#[derive(Debug, Clone)]
pub struct ElasticAlignResult {
    /// Spline producing the transformed x coordinate.
    pub tps_x: TpsCoefficients,
    /// Spline producing the transformed y coordinate.
    pub tps_y: TpsCoefficients,
    /// Copy of the control points the spline was fitted to; their `source` positions are the
    /// kernel centres used whenever the spline is evaluated.
    pub control_points: Vec<ControlPoint>,
    /// Root-mean-square residual distance between fitted and target positions at the control
    /// points.
    pub rms_error: f64,
    /// Largest residual distance at any control point.
    pub max_error: f64,
    /// Magnitude of `w_xᵀ K w_x + w_yᵀ K w_y`, where `K` is the kernel matrix over the control
    /// point sources and `w_x`/`w_y` are the kernel weights.
    pub bending_energy: f64,
}
111
/// A coarse displacement field sampled on a regular grid of square cells.
///
/// Cell `(cx, cy)` is stored at index `cy * cols + cx` of `dx` and `dy`.
#[derive(Debug, Clone)]
pub struct DeformationField {
    /// Horizontal displacement of each cell, row-major.
    pub dx: Vec<f64>,
    /// Vertical displacement of each cell, row-major.
    pub dy: Vec<f64>,
    /// Number of cells per row.
    pub cols: u32,
    /// Number of cell rows.
    pub rows: u32,
    /// Edge length of one square cell.
    pub cell_size: u32,
}

impl DeformationField {
    /// Creates an all-zero field covering a `width` x `height` area with cells of `cell_size`.
    ///
    /// The grid is rounded up, so partially covered cells on the right/bottom edges are
    /// included. A `cell_size` of 0 is treated as 1 instead of dividing by zero.
    #[must_use]
    pub fn new(width: u32, height: u32, cell_size: u32) -> Self {
        let cell_size = cell_size.max(1);
        let cols = width.div_ceil(cell_size);
        let rows = height.div_ceil(cell_size);
        // Multiply in usize: cols * rows can exceed u32::MAX for large areas with small cells.
        let count = cols as usize * rows as usize;
        Self {
            dx: vec![0.0; count],
            dy: vec![0.0; count],
            cols,
            rows,
            cell_size,
        }
    }

    /// Row-major index of cell `(cx, cy)`, or `None` when the cell is outside the grid.
    fn index(&self, cx: u32, cy: u32) -> Option<usize> {
        // usize arithmetic so the index cannot overflow u32 on large grids.
        (cx < self.cols && cy < self.rows)
            .then(|| cy as usize * self.cols as usize + cx as usize)
    }

    /// Displacement `(dx, dy)` stored for cell `(cx, cy)`.
    ///
    /// Returns `None` for cells outside the grid, or if the public `dx`/`dy` vectors have been
    /// shrunk below the grid size.
    #[must_use]
    pub fn get(&self, cx: u32, cy: u32) -> Option<(f64, f64)> {
        let idx = self.index(cx, cy)?;
        Some((*self.dx.get(idx)?, *self.dy.get(idx)?))
    }

    /// Stores displacement `(dx, dy)` for cell `(cx, cy)`.
    ///
    /// Writes to cells outside the grid (or beyond the current `dx`/`dy` lengths) are ignored.
    pub fn set(&mut self, cx: u32, cy: u32, dx: f64, dy: f64) {
        let Some(idx) = self.index(cx, cy) else {
            return;
        };
        if let (Some(slot_x), Some(slot_y)) = (self.dx.get_mut(idx), self.dy.get_mut(idx)) {
            *slot_x = dx;
            *slot_y = dy;
        }
    }

    /// Mean displacement magnitude over all cells; `0.0` for an empty field.
    #[must_use]
    #[allow(clippy::cast_precision_loss)] // cell counts are far below 2^52, so the cast is exact
    pub fn average_displacement(&self) -> f64 {
        if self.dx.is_empty() {
            return 0.0;
        }
        let total: f64 = self
            .dx
            .iter()
            .zip(self.dy.iter())
            .map(|(x, y)| (x * x + y * y).sqrt())
            .sum();
        total / self.dx.len() as f64
    }

    /// Largest displacement magnitude over all cells; `0.0` for an empty field.
    #[must_use]
    pub fn max_displacement(&self) -> f64 {
        self.dx
            .iter()
            .zip(self.dy.iter())
            .map(|(x, y)| (x * x + y * y).sqrt())
            .fold(0.0_f64, f64::max)
    }
}
189
/// Thin-plate-spline radial basis function `U(r) = r² ln r`.
///
/// Radii below `1e-15` map to `0.0`, the limit of `U` as `r → 0`, which also avoids
/// evaluating `ln 0 = -inf`.
fn tps_kernel(r: f64) -> f64 {
    const NEAR_ZERO: f64 = 1e-15;
    match r {
        tiny if tiny < NEAR_ZERO => 0.0,
        radius => radius * radius * radius.ln(),
    }
}
199
/// Fits thin-plate-spline (elastic) warps from control-point correspondences and applies them
/// to points or to a sampled deformation grid.
#[derive(Debug, Clone)]
pub struct ElasticAligner {
    /// Parameters used for every fit and for deformation-field sampling.
    config: ElasticAlignConfig,
}
206
207impl ElasticAligner {
208 #[must_use]
210 pub fn new(config: ElasticAlignConfig) -> Self {
211 Self { config }
212 }
213
214 #[must_use]
216 pub fn with_defaults() -> Self {
217 Self {
218 config: ElasticAlignConfig::default(),
219 }
220 }
221
222 pub fn align(&self, control_points: &[ControlPoint]) -> AlignResult<ElasticAlignResult> {
224 let n = control_points.len();
225 if n < self.config.min_control_points {
226 return Err(AlignError::InsufficientData(format!(
227 "Need at least {} control points, got {}",
228 self.config.min_control_points, n
229 )));
230 }
231
232 for cp in control_points {
234 if cp.displacement_magnitude() > self.config.max_displacement {
235 return Err(AlignError::InvalidConfig(format!(
236 "Control point displacement {:.1} exceeds max {:.1}",
237 cp.displacement_magnitude(),
238 self.config.max_displacement
239 )));
240 }
241 }
242
243 let tps_x = self.solve_tps(control_points, true)?;
245 let tps_y = self.solve_tps(control_points, false)?;
246
247 let (rms_error, max_error) = self.compute_errors(control_points, &tps_x, &tps_y);
249
250 let bending_energy = self.compute_bending_energy(control_points, &tps_x, &tps_y);
252
253 Ok(ElasticAlignResult {
254 tps_x,
255 tps_y,
256 control_points: control_points.to_vec(),
257 rms_error,
258 max_error,
259 bending_energy,
260 })
261 }
262
263 #[must_use]
265 pub fn transform_point(&self, point: &Point2D, result: &ElasticAlignResult) -> Point2D {
266 let new_x = self.evaluate_tps(point, &result.tps_x, &result.control_points);
267 let new_y = self.evaluate_tps(point, &result.tps_y, &result.control_points);
268 Point2D::new(new_x, new_y)
269 }
270
271 #[must_use]
273 #[allow(clippy::cast_precision_loss)]
274 pub fn generate_deformation_field(
275 &self,
276 result: &ElasticAlignResult,
277 width: u32,
278 height: u32,
279 ) -> DeformationField {
280 let cell_size = self.config.grid_resolution;
281 let mut field = DeformationField::new(width, height, cell_size);
282
283 for cy in 0..field.rows {
284 for cx in 0..field.cols {
285 let px = f64::from(cx * cell_size + cell_size / 2);
286 let py = f64::from(cy * cell_size + cell_size / 2);
287 let src = Point2D::new(px, py);
288 let dst = self.transform_point(&src, result);
289 field.set(cx, cy, dst.x - px, dst.y - py);
290 }
291 }
292
293 field
294 }
295
296 #[allow(clippy::cast_precision_loss)]
298 fn solve_tps(&self, points: &[ControlPoint], for_x: bool) -> AlignResult<TpsCoefficients> {
299 let n = points.len();
300 let size = n + 3;
302
303 let mut l_matrix = vec![0.0f64; size * size];
307 let mut rhs = vec![0.0f64; size];
308
309 for i in 0..n {
311 for j in 0..n {
312 let r = points[i].source.distance(&points[j].source);
313 l_matrix[i * size + j] = tps_kernel(r);
314 }
315 l_matrix[i * size + i] += self.config.regularization / points[i].weight;
317 }
318
319 for i in 0..n {
321 l_matrix[i * size + n] = 1.0;
322 l_matrix[i * size + n + 1] = points[i].source.x;
323 l_matrix[i * size + n + 2] = points[i].source.y;
324
325 l_matrix[(n) * size + i] = 1.0;
326 l_matrix[(n + 1) * size + i] = points[i].source.x;
327 l_matrix[(n + 2) * size + i] = points[i].source.y;
328 }
329
330 for i in 0..n {
332 rhs[i] = if for_x {
333 points[i].target.x
334 } else {
335 points[i].target.y
336 };
337 }
338
339 let solution = Self::gauss_solve(&mut l_matrix, &mut rhs, size)?;
341
342 let weights = solution[..n].to_vec();
343 let affine = [solution[n], solution[n + 1], solution[n + 2]];
344
345 Ok(TpsCoefficients { weights, affine })
346 }
347
348 fn evaluate_tps(
350 &self,
351 point: &Point2D,
352 tps: &TpsCoefficients,
353 control_points: &[ControlPoint],
354 ) -> f64 {
355 let mut val = tps.affine[0] + tps.affine[1] * point.x + tps.affine[2] * point.y;
356
357 for (i, cp) in control_points.iter().enumerate() {
358 let r = point.distance(&cp.source);
359 val += tps.weights[i] * tps_kernel(r);
360 }
361
362 val
363 }
364
365 #[allow(clippy::cast_precision_loss)]
367 fn compute_errors(
368 &self,
369 points: &[ControlPoint],
370 tps_x: &TpsCoefficients,
371 tps_y: &TpsCoefficients,
372 ) -> (f64, f64) {
373 let mut sum_sq = 0.0;
374 let mut max_e = 0.0_f64;
375
376 for cp in points {
377 let px = self.evaluate_tps(&cp.source, tps_x, points);
378 let py = self.evaluate_tps(&cp.source, tps_y, points);
379 let err = ((px - cp.target.x).powi(2) + (py - cp.target.y).powi(2)).sqrt();
380 sum_sq += err * err;
381 max_e = max_e.max(err);
382 }
383
384 let rms = (sum_sq / points.len() as f64).sqrt();
385 (rms, max_e)
386 }
387
388 fn compute_bending_energy(
390 &self,
391 points: &[ControlPoint],
392 tps_x: &TpsCoefficients,
393 tps_y: &TpsCoefficients,
394 ) -> f64 {
395 let n = points.len();
396 let mut energy = 0.0;
397
398 for i in 0..n {
399 for j in 0..n {
400 let r = points[i].source.distance(&points[j].source);
401 let k = tps_kernel(r);
402 energy += tps_x.weights[i] * tps_x.weights[j] * k;
403 energy += tps_y.weights[i] * tps_y.weights[j] * k;
404 }
405 }
406
407 energy.abs()
408 }
409
410 fn gauss_solve(a: &mut [f64], b: &mut [f64], n: usize) -> AlignResult<Vec<f64>> {
412 for col in 0..n {
414 let mut max_val = a[col * n + col].abs();
416 let mut max_row = col;
417 for row in (col + 1)..n {
418 let val = a[row * n + col].abs();
419 if val > max_val {
420 max_val = val;
421 max_row = row;
422 }
423 }
424
425 if max_val < 1e-15 {
426 return Err(AlignError::NumericalError(
427 "Singular matrix in TPS solve".to_string(),
428 ));
429 }
430
431 if max_row != col {
433 for k in 0..n {
434 a.swap(col * n + k, max_row * n + k);
435 }
436 b.swap(col, max_row);
437 }
438
439 let pivot = a[col * n + col];
441 for row in (col + 1)..n {
442 let factor = a[row * n + col] / pivot;
443 for k in col..n {
444 a[row * n + k] -= factor * a[col * n + k];
445 }
446 b[row] -= factor * b[col];
447 }
448 }
449
450 let mut x = vec![0.0f64; n];
452 for col in (0..n).rev() {
453 let mut sum = b[col];
454 for k in (col + 1)..n {
455 sum -= a[col * n + k] * x[k];
456 }
457 x[col] = sum / a[col * n + col];
458 }
459
460 Ok(x)
461 }
462
463 #[must_use]
465 pub fn config(&self) -> &ElasticAlignConfig {
466 &self.config
467 }
468}
469
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `actual` lies strictly within `tol` of `expected`.
    fn close(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Unit-weight control point built from raw source/target coordinates.
    fn cp(sx: f64, sy: f64, tx: f64, ty: f64) -> ControlPoint {
        ControlPoint::with_unit_weight(Point2D::new(sx, sy), Point2D::new(tx, ty))
    }

    /// Aligner with light regularization, shared by the fitting tests.
    fn lightly_regularized() -> ElasticAligner {
        ElasticAligner::new(ElasticAlignConfig {
            regularization: 0.001,
            min_control_points: 4,
            max_displacement: 100.0,
            grid_resolution: 16,
        })
    }

    #[test]
    fn test_control_point_creation() {
        let point = ControlPoint::new(Point2D::new(10.0, 20.0), Point2D::new(12.0, 22.0), 1.0);
        assert!(close(point.source.x, 10.0, f64::EPSILON));
        assert!(close(point.target.x, 12.0, f64::EPSILON));
    }

    #[test]
    fn test_control_point_displacement() {
        let point = ControlPoint::new(Point2D::new(0.0, 0.0), Point2D::new(3.0, 4.0), 1.0);
        let (dx, dy) = point.displacement();
        assert!(close(dx, 3.0, f64::EPSILON));
        assert!(close(dy, 4.0, f64::EPSILON));
        assert!(close(point.displacement_magnitude(), 5.0, 1e-10));
    }

    #[test]
    fn test_control_point_unit_weight() {
        assert!(close(cp(0.0, 0.0, 1.0, 1.0).weight, 1.0, f64::EPSILON));
    }

    #[test]
    fn test_config_default() {
        let config = ElasticAlignConfig::default();
        assert!(close(config.regularization, 0.01, f64::EPSILON));
        assert_eq!(config.min_control_points, 4);
    }

    #[test]
    fn test_deformation_field_creation() {
        let field = DeformationField::new(320, 240, 16);
        assert_eq!((field.cols, field.rows, field.dx.len()), (20, 15, 300));
    }

    #[test]
    fn test_deformation_field_get_set() {
        let mut field = DeformationField::new(64, 64, 16);
        field.set(1, 2, 3.5, -1.5);
        let (dx, dy) = field.get(1, 2).expect("get should succeed");
        assert!(close(dx, 3.5, f64::EPSILON));
        assert!(close(dy, -1.5, f64::EPSILON));
    }

    #[test]
    fn test_deformation_field_average() {
        let mut field = DeformationField::new(32, 32, 16);
        field.set(0, 0, 3.0, 4.0);
        for (cx, cy) in [(1, 0), (0, 1), (1, 1)] {
            field.set(cx, cy, 0.0, 0.0);
        }
        assert!(close(field.average_displacement(), 1.25, 1e-10));
    }

    #[test]
    fn test_deformation_field_max() {
        let mut field = DeformationField::new(32, 32, 16);
        field.set(0, 0, 3.0, 4.0);
        field.set(1, 0, 1.0, 0.0);
        assert!(close(field.max_displacement(), 5.0, 1e-10));
    }

    #[test]
    fn test_tps_kernel() {
        assert!(close(tps_kernel(0.0), 0.0, f64::EPSILON));
        // ln(1) = 0, so the kernel vanishes at unit radius.
        assert!(close(tps_kernel(1.0), 0.0, f64::EPSILON));
        // ln(e) = 1, so U(e) = e².
        let e = std::f64::consts::E;
        assert!(close(tps_kernel(e), e * e, 1e-10));
    }

    #[test]
    fn test_elastic_align_insufficient_points() {
        let aligner = ElasticAligner::with_defaults();
        let points = [cp(0.0, 0.0, 1.0, 1.0)];
        assert!(aligner.align(&points).is_err());
    }

    #[test]
    fn test_elastic_align_identity() {
        let aligner = lightly_regularized();
        // Every corner maps onto itself.
        let points = [
            cp(0.0, 0.0, 0.0, 0.0),
            cp(100.0, 0.0, 100.0, 0.0),
            cp(0.0, 100.0, 0.0, 100.0),
            cp(100.0, 100.0, 100.0, 100.0),
        ];
        let result = aligner.align(&points).expect("result should be valid");
        assert!(result.rms_error < 1.0);
    }

    #[test]
    fn test_elastic_align_translation() {
        let aligner = lightly_regularized();
        // Every corner shifts by (+5, +3).
        let points = [
            cp(0.0, 0.0, 5.0, 3.0),
            cp(100.0, 0.0, 105.0, 3.0),
            cp(0.0, 100.0, 5.0, 103.0),
            cp(100.0, 100.0, 105.0, 103.0),
        ];
        let result = aligner.align(&points).expect("result should be valid");
        let moved = aligner.transform_point(&Point2D::new(50.0, 50.0), &result);
        assert!(close(moved.x, 55.0, 2.0));
        assert!(close(moved.y, 53.0, 2.0));
    }

    #[test]
    fn test_elastic_align_max_displacement_exceeded() {
        let aligner = ElasticAligner::new(ElasticAlignConfig {
            max_displacement: 5.0,
            ..ElasticAlignConfig::default()
        });
        // Each point moves by roughly 141 units, far beyond the 5-unit cap.
        let points = [
            cp(0.0, 0.0, 100.0, 100.0),
            cp(10.0, 0.0, 110.0, 100.0),
            cp(0.0, 10.0, 100.0, 110.0),
            cp(10.0, 10.0, 110.0, 110.0),
        ];
        assert!(aligner.align(&points).is_err());
    }
}