Skip to main content

oximedia_align/
transform.rs

1//! Geometric transformations for image alignment.
2//!
3//! This module provides tools for applying geometric transformations:
4//!
5//! - Image warping
6//! - Bilinear and bicubic interpolation
7//! - Transformation composition
8//! - Region-based transformations
9
10use crate::spatial::{AffineTransform, Homography};
11use crate::{AlignError, AlignResult, Point2D};
12use nalgebra::Matrix3;
13
/// Resampling strategy used when a mapped position falls between input pixels.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InterpolationMethod {
    /// Take the closest input pixel; the cheapest option.
    Nearest,
    /// Blend the surrounding 2x2 pixel neighbourhood linearly along each axis.
    Bilinear,
    /// Blend the surrounding 4x4 pixel neighbourhood with a cubic kernel;
    /// the highest quality and the most expensive option.
    Bicubic,
}
24
25/// Image warper for applying transformations
26pub struct ImageWarper {
27    /// Interpolation method
28    pub interpolation: InterpolationMethod,
29    /// Border handling mode
30    pub border_mode: BorderMode,
31}
32
/// Policy for sampling positions that fall outside the source image.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BorderMode {
    /// Treat every outside pixel as having this value in all three channels.
    Constant(u8),
    /// Clamp to the nearest edge pixel.
    Replicate,
    /// Mirror about the image edge, repeating the edge pixel (`cba|abc`).
    Reflect,
    /// Tile the image periodically.
    Wrap,
}
45
46impl Default for ImageWarper {
47    fn default() -> Self {
48        Self {
49            interpolation: InterpolationMethod::Bilinear,
50            border_mode: BorderMode::Constant(0),
51        }
52    }
53}
54
55impl ImageWarper {
56    /// Create a new image warper
57    #[must_use]
58    pub fn new(interpolation: InterpolationMethod, border_mode: BorderMode) -> Self {
59        Self {
60            interpolation,
61            border_mode,
62        }
63    }
64
65    /// Warp image using homography
66    ///
67    /// # Errors
68    /// Returns error if warping fails
69    pub fn warp_homography(
70        &self,
71        input: &[u8],
72        width: usize,
73        height: usize,
74        homography: &Homography,
75        output_width: usize,
76        output_height: usize,
77    ) -> AlignResult<Vec<u8>> {
78        if input.len() != width * height * 3 {
79            return Err(AlignError::InvalidConfig("Input size mismatch".to_string()));
80        }
81
82        let mut output = vec![0u8; output_width * output_height * 3];
83
84        // Invert homography for backward mapping
85        let inv_h = homography.inverse()?;
86
87        for y in 0..output_height {
88            for x in 0..output_width {
89                let dst = Point2D::new(x as f64, y as f64);
90                let src = inv_h.transform(&dst);
91
92                let pixel = self.sample_pixel(input, width, height, src.x as f32, src.y as f32);
93
94                let idx = (y * output_width + x) * 3;
95                output[idx..idx + 3].copy_from_slice(&pixel);
96            }
97        }
98
99        Ok(output)
100    }
101
102    /// Warp image using affine transform
103    ///
104    /// # Errors
105    /// Returns error if warping fails
106    pub fn warp_affine(
107        &self,
108        input: &[u8],
109        width: usize,
110        height: usize,
111        transform: &AffineTransform,
112        output_width: usize,
113        output_height: usize,
114    ) -> AlignResult<Vec<u8>> {
115        if input.len() != width * height * 3 {
116            return Err(AlignError::InvalidConfig("Input size mismatch".to_string()));
117        }
118
119        let mut output = vec![0u8; output_width * output_height * 3];
120
121        // Compute inverse transform for backward mapping
122        let inv = self.invert_affine(transform)?;
123
124        for y in 0..output_height {
125            for x in 0..output_width {
126                let dst = Point2D::new(x as f64, y as f64);
127                let src = inv.transform(&dst);
128
129                let pixel = self.sample_pixel(input, width, height, src.x as f32, src.y as f32);
130
131                let idx = (y * output_width + x) * 3;
132                output[idx..idx + 3].copy_from_slice(&pixel);
133            }
134        }
135
136        Ok(output)
137    }
138
139    /// Sample pixel with interpolation
140    fn sample_pixel(&self, image: &[u8], width: usize, height: usize, x: f32, y: f32) -> [u8; 3] {
141        match self.interpolation {
142            InterpolationMethod::Nearest => self.sample_nearest(image, width, height, x, y),
143            InterpolationMethod::Bilinear => self.sample_bilinear(image, width, height, x, y),
144            InterpolationMethod::Bicubic => self.sample_bicubic(image, width, height, x, y),
145        }
146    }
147
148    /// Nearest neighbor sampling
149    fn sample_nearest(&self, image: &[u8], width: usize, height: usize, x: f32, y: f32) -> [u8; 3] {
150        let xi = x.round() as isize;
151        let yi = y.round() as isize;
152
153        if xi >= 0 && xi < width as isize && yi >= 0 && yi < height as isize {
154            let idx = (yi as usize * width + xi as usize) * 3;
155            if idx + 2 < image.len() {
156                return [image[idx], image[idx + 1], image[idx + 2]];
157            }
158        }
159
160        self.get_border_value()
161    }
162
163    /// Bilinear interpolation
164    fn sample_bilinear(
165        &self,
166        image: &[u8],
167        width: usize,
168        height: usize,
169        x: f32,
170        y: f32,
171    ) -> [u8; 3] {
172        let x0 = x.floor() as isize;
173        let y0 = y.floor() as isize;
174        let x1 = x0 + 1;
175        let y1 = y0 + 1;
176
177        let dx = x - x0 as f32;
178        let dy = y - y0 as f32;
179
180        let p00 = self.get_pixel(image, width, height, x0, y0);
181        let p10 = self.get_pixel(image, width, height, x1, y0);
182        let p01 = self.get_pixel(image, width, height, x0, y1);
183        let p11 = self.get_pixel(image, width, height, x1, y1);
184
185        let mut result = [0u8; 3];
186        for c in 0..3 {
187            let v0 = f32::from(p00[c]) * (1.0 - dx) + f32::from(p10[c]) * dx;
188            let v1 = f32::from(p01[c]) * (1.0 - dx) + f32::from(p11[c]) * dx;
189            let v = v0 * (1.0 - dy) + v1 * dy;
190            result[c] = v.round().clamp(0.0, 255.0) as u8;
191        }
192
193        result
194    }
195
196    /// Bicubic interpolation
197    fn sample_bicubic(&self, image: &[u8], width: usize, height: usize, x: f32, y: f32) -> [u8; 3] {
198        let x0 = x.floor() as isize;
199        let y0 = y.floor() as isize;
200
201        let dx = x - x0 as f32;
202        let dy = y - y0 as f32;
203
204        let mut result = [0u8; 3];
205
206        for c in 0..3 {
207            let mut value = 0.0f32;
208
209            // Bicubic kernel is 4x4
210            for j in -1..=2 {
211                for i in -1..=2 {
212                    let pixel = self.get_pixel(image, width, height, x0 + i, y0 + j);
213                    let wx = Self::cubic_weight(i as f32 - dx);
214                    let wy = Self::cubic_weight(j as f32 - dy);
215                    value += f32::from(pixel[c]) * wx * wy;
216                }
217            }
218
219            result[c] = value.round().clamp(0.0, 255.0) as u8;
220        }
221
222        result
223    }
224
225    /// Cubic interpolation weight (Mitchell-Netravali filter)
226    fn cubic_weight(x: f32) -> f32 {
227        let x = x.abs();
228        if x < 1.0 {
229            (1.5 * x - 2.5) * x * x + 1.0
230        } else if x < 2.0 {
231            ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0
232        } else {
233            0.0
234        }
235    }
236
237    /// Get pixel with border handling
238    fn get_pixel(&self, image: &[u8], width: usize, height: usize, x: isize, y: isize) -> [u8; 3] {
239        let (x_clamped, y_clamped) = self.apply_border_mode(x, y, width, height);
240
241        if x_clamped >= 0
242            && x_clamped < width as isize
243            && y_clamped >= 0
244            && y_clamped < height as isize
245        {
246            let idx = (y_clamped as usize * width + x_clamped as usize) * 3;
247            if idx + 2 < image.len() {
248                return [image[idx], image[idx + 1], image[idx + 2]];
249            }
250        }
251
252        self.get_border_value()
253    }
254
255    /// Apply border mode
256    fn apply_border_mode(&self, x: isize, y: isize, width: usize, height: usize) -> (isize, isize) {
257        match self.border_mode {
258            BorderMode::Constant(_) => (x, y),
259            BorderMode::Replicate => (
260                x.clamp(0, width as isize - 1),
261                y.clamp(0, height as isize - 1),
262            ),
263            BorderMode::Reflect => (
264                Self::reflect_coord(x, width),
265                Self::reflect_coord(y, height),
266            ),
267            BorderMode::Wrap => (
268                ((x % width as isize + width as isize) % width as isize),
269                ((y % height as isize + height as isize) % height as isize),
270            ),
271        }
272    }
273
274    /// Reflect coordinate
275    fn reflect_coord(x: isize, size: usize) -> isize {
276        let size = size as isize;
277        if x < 0 {
278            -x - 1
279        } else if x >= size {
280            2 * size - x - 1
281        } else {
282            x
283        }
284    }
285
286    /// Get border value
287    fn get_border_value(&self) -> [u8; 3] {
288        match self.border_mode {
289            BorderMode::Constant(v) => [v, v, v],
290            _ => [0, 0, 0],
291        }
292    }
293
294    /// Invert affine transform
295    fn invert_affine(&self, transform: &AffineTransform) -> AlignResult<AffineTransform> {
296        let a = transform.matrix[(0, 0)];
297        let b = transform.matrix[(0, 1)];
298        let c = transform.matrix[(1, 0)];
299        let d = transform.matrix[(1, 1)];
300        let tx = transform.matrix[(0, 2)];
301        let ty = transform.matrix[(1, 2)];
302
303        let det = a * d - b * c;
304
305        if det.abs() < 1e-10 {
306            return Err(AlignError::NumericalError("Singular matrix".to_string()));
307        }
308
309        let inv_det = 1.0 / det;
310
311        let inv_matrix = nalgebra::Matrix2x3::new(
312            d * inv_det,
313            -b * inv_det,
314            (b * ty - d * tx) * inv_det,
315            -c * inv_det,
316            a * inv_det,
317            (c * tx - a * ty) * inv_det,
318        );
319
320        Ok(AffineTransform::new(inv_matrix))
321    }
322}
323
324/// Transformation builder for composing multiple transforms
325pub struct TransformBuilder {
326    /// Accumulated transformation matrix
327    matrix: Matrix3<f64>,
328}
329
330impl Default for TransformBuilder {
331    fn default() -> Self {
332        Self::new()
333    }
334}
335
336impl TransformBuilder {
337    /// Create a new transform builder
338    #[must_use]
339    pub fn new() -> Self {
340        Self {
341            matrix: Matrix3::identity(),
342        }
343    }
344
345    /// Add translation
346    #[must_use]
347    pub fn translate(mut self, tx: f64, ty: f64) -> Self {
348        let t = Matrix3::new(1.0, 0.0, tx, 0.0, 1.0, ty, 0.0, 0.0, 1.0);
349        self.matrix = t * self.matrix;
350        self
351    }
352
353    /// Add rotation (angle in radians)
354    #[must_use]
355    pub fn rotate(mut self, angle: f64) -> Self {
356        let c = angle.cos();
357        let s = angle.sin();
358        let r = Matrix3::new(c, -s, 0.0, s, c, 0.0, 0.0, 0.0, 1.0);
359        self.matrix = r * self.matrix;
360        self
361    }
362
363    /// Add scale
364    #[must_use]
365    pub fn scale(mut self, sx: f64, sy: f64) -> Self {
366        let s = Matrix3::new(sx, 0.0, 0.0, 0.0, sy, 0.0, 0.0, 0.0, 1.0);
367        self.matrix = s * self.matrix;
368        self
369    }
370
371    /// Add shear
372    #[must_use]
373    pub fn shear(mut self, shx: f64, shy: f64) -> Self {
374        let sh = Matrix3::new(1.0, shx, 0.0, shy, 1.0, 0.0, 0.0, 0.0, 1.0);
375        self.matrix = sh * self.matrix;
376        self
377    }
378
379    /// Build homography
380    #[must_use]
381    pub fn build(self) -> Homography {
382        Homography::new(self.matrix)
383    }
384}
385
386/// Mesh warper for non-rigid transformations
387pub struct MeshWarper {
388    /// Grid width
389    pub grid_width: usize,
390    /// Grid height
391    pub grid_height: usize,
392    /// Control points
393    control_points: Vec<Vec<Point2D>>,
394}
395
396impl MeshWarper {
397    /// Create a new mesh warper
398    #[must_use]
399    pub fn new(grid_width: usize, grid_height: usize) -> Self {
400        let mut control_points = Vec::new();
401
402        for _y in 0..=grid_height {
403            let mut row = Vec::new();
404            for _x in 0..=grid_width {
405                row.push(Point2D::new(0.0, 0.0));
406            }
407            control_points.push(row);
408        }
409
410        Self {
411            grid_width,
412            grid_height,
413            control_points,
414        }
415    }
416
417    /// Set control point
418    pub fn set_control_point(&mut self, x: usize, y: usize, point: Point2D) {
419        if y < self.control_points.len() && x < self.control_points[y].len() {
420            self.control_points[y][x] = point;
421        }
422    }
423
424    /// Initialize regular grid
425    pub fn init_regular_grid(&mut self, width: usize, height: usize) {
426        let dx = width as f64 / self.grid_width as f64;
427        let dy = height as f64 / self.grid_height as f64;
428
429        for y in 0..=self.grid_height {
430            for x in 0..=self.grid_width {
431                self.control_points[y][x] = Point2D::new(x as f64 * dx, y as f64 * dy);
432            }
433        }
434    }
435
436    /// Warp image using mesh
437    ///
438    /// # Errors
439    /// Returns error if warping fails
440    pub fn warp(&self, input: &[u8], width: usize, height: usize) -> AlignResult<Vec<u8>> {
441        if input.len() != width * height * 3 {
442            return Err(AlignError::InvalidConfig("Input size mismatch".to_string()));
443        }
444
445        let mut output = vec![0u8; width * height * 3];
446        let warper = ImageWarper::default();
447
448        let dx = width as f64 / self.grid_width as f64;
449        let dy = height as f64 / self.grid_height as f64;
450
451        for y in 0..height {
452            for x in 0..width {
453                // Find grid cell
454                let gx = (x as f64 / dx).floor() as usize;
455                let gy = (y as f64 / dy).floor() as usize;
456
457                if gx < self.grid_width && gy < self.grid_height {
458                    // Bilinear interpolation within cell
459                    let tx = (x as f64 - gx as f64 * dx) / dx;
460                    let ty = (y as f64 - gy as f64 * dy) / dy;
461
462                    let p00 = &self.control_points[gy][gx];
463                    let p10 = &self.control_points[gy][gx + 1];
464                    let p01 = &self.control_points[gy + 1][gx];
465                    let p11 = &self.control_points[gy + 1][gx + 1];
466
467                    let src_x = p00.x * (1.0 - tx) * (1.0 - ty)
468                        + p10.x * tx * (1.0 - ty)
469                        + p01.x * (1.0 - tx) * ty
470                        + p11.x * tx * ty;
471
472                    let src_y = p00.y * (1.0 - tx) * (1.0 - ty)
473                        + p10.y * tx * (1.0 - ty)
474                        + p01.y * (1.0 - tx) * ty
475                        + p11.y * tx * ty;
476
477                    let pixel =
478                        warper.sample_pixel(input, width, height, src_x as f32, src_y as f32);
479
480                    let idx = (y * width + x) * 3;
481                    output[idx..idx + 3].copy_from_slice(&pixel);
482                }
483            }
484        }
485
486        Ok(output)
487    }
488}
489
/// Warps an arbitrary quadrilateral image region onto an axis-aligned
/// rectangle through a perspective (homography) transform.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct QuadWarper;
492
493impl QuadWarper {
494    /// Warp a quadrilateral region to rectangle
495    ///
496    /// # Errors
497    /// Returns error if warping fails
498    pub fn warp_quad(
499        input: &[u8],
500        width: usize,
501        height: usize,
502        src_quad: &[Point2D; 4],
503        dst_width: usize,
504        dst_height: usize,
505    ) -> AlignResult<Vec<u8>> {
506        // Build homography from quad to rectangle
507        let dst_quad = [
508            Point2D::new(0.0, 0.0),
509            Point2D::new(dst_width as f64, 0.0),
510            Point2D::new(dst_width as f64, dst_height as f64),
511            Point2D::new(0.0, dst_height as f64),
512        ];
513
514        let homography = Self::compute_quad_to_quad_homography(src_quad, &dst_quad)?;
515
516        let warper = ImageWarper::default();
517        warper.warp_homography(input, width, height, &homography, dst_width, dst_height)
518    }
519
520    /// Compute homography from quad to quad
521    fn compute_quad_to_quad_homography(
522        src: &[Point2D; 4],
523        dst: &[Point2D; 4],
524    ) -> AlignResult<Homography> {
525        // Build system of equations for DLT
526        let mut a = nalgebra::DMatrix::zeros(8, 9);
527
528        for i in 0..4 {
529            let x = src[i].x;
530            let y = src[i].y;
531            let xp = dst[i].x;
532            let yp = dst[i].y;
533
534            a[(i * 2, 0)] = -x;
535            a[(i * 2, 1)] = -y;
536            a[(i * 2, 2)] = -1.0;
537            a[(i * 2, 6)] = xp * x;
538            a[(i * 2, 7)] = xp * y;
539            a[(i * 2, 8)] = xp;
540
541            a[(i * 2 + 1, 3)] = -x;
542            a[(i * 2 + 1, 4)] = -y;
543            a[(i * 2 + 1, 5)] = -1.0;
544            a[(i * 2 + 1, 6)] = yp * x;
545            a[(i * 2 + 1, 7)] = yp * y;
546            a[(i * 2 + 1, 8)] = yp;
547        }
548
549        let svd = a.svd(true, true);
550        let v = svd
551            .v_t
552            .ok_or_else(|| AlignError::NumericalError("SVD failed".to_string()))?;
553
554        let h_vec = v.row(8);
555
556        if h_vec[8].abs() < 1e-10 {
557            return Err(AlignError::NumericalError(
558                "Degenerate solution".to_string(),
559            ));
560        }
561
562        let scale = h_vec[8];
563        let matrix = Matrix3::new(
564            h_vec[0] / scale,
565            h_vec[1] / scale,
566            h_vec[2] / scale,
567            h_vec[3] / scale,
568            h_vec[4] / scale,
569            h_vec[5] / scale,
570            h_vec[6] / scale,
571            h_vec[7] / scale,
572            1.0,
573        );
574
575        Ok(Homography::new(matrix))
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-major, 3-bytes-per-pixel test image whose bytes count up from 0.
    fn ramp_image(width: usize, height: usize) -> Vec<u8> {
        (0..width * height * 3).map(|v| (v % 256) as u8).collect()
    }

    /// Affine transform that leaves every point unchanged.
    fn identity_affine() -> AffineTransform {
        AffineTransform::new(nalgebra::Matrix2x3::new(1.0, 0.0, 0.0, 0.0, 1.0, 0.0))
    }

    #[test]
    fn test_interpolation_method() {
        assert_eq!(InterpolationMethod::Nearest, InterpolationMethod::Nearest);
        assert_ne!(InterpolationMethod::Nearest, InterpolationMethod::Bilinear);
    }

    #[test]
    fn test_border_mode() {
        let mode = BorderMode::Constant(128);
        match mode {
            BorderMode::Constant(v) => assert_eq!(v, 128),
            _ => panic!("Wrong border mode"),
        }
    }

    #[test]
    fn test_image_warper_creation() {
        let warper = ImageWarper::default();
        assert_eq!(warper.interpolation, InterpolationMethod::Bilinear);
    }

    #[test]
    fn test_cubic_weight() {
        let w = ImageWarper::cubic_weight(0.0);
        assert!((w - 1.0).abs() < 1e-6);

        let w = ImageWarper::cubic_weight(2.0);
        assert!(w.abs() < 1e-6);
    }

    #[test]
    fn test_cubic_weight_partition_of_unity() {
        // The four taps of the cubic kernel must sum to 1 at every sub-pixel offset.
        for &frac in &[0.0f32, 0.25, 0.5, 0.75] {
            let sum: f32 = (-1..=2)
                .map(|i| ImageWarper::cubic_weight(i as f32 - frac))
                .sum();
            assert!((sum - 1.0).abs() < 1e-5, "sum {sum} at offset {frac}");
        }
    }

    #[test]
    fn test_reflect_coord_edges() {
        assert_eq!(ImageWarper::reflect_coord(-1, 5), 0);
        assert_eq!(ImageWarper::reflect_coord(-5, 5), 4);
        assert_eq!(ImageWarper::reflect_coord(2, 5), 2);
        assert_eq!(ImageWarper::reflect_coord(5, 5), 4);
        assert_eq!(ImageWarper::reflect_coord(9, 5), 0);
    }

    #[test]
    fn test_wrap_border_mode() {
        let img = ramp_image(2, 1);
        let warper = ImageWarper::new(InterpolationMethod::Bilinear, BorderMode::Wrap);
        assert_eq!(warper.get_pixel(&img, 2, 1, -1, 0), [3, 4, 5]);
        assert_eq!(warper.get_pixel(&img, 2, 1, 2, 0), [0, 1, 2]);
    }

    #[test]
    fn test_sample_nearest_in_bounds() {
        let img = ramp_image(2, 1);
        let warper = ImageWarper::new(InterpolationMethod::Nearest, BorderMode::Constant(9));
        assert_eq!(warper.sample_nearest(&img, 2, 1, 1.2, 0.4), [3, 4, 5]);
        assert_eq!(warper.sample_nearest(&img, 2, 1, 5.0, 0.0), [9, 9, 9]);
    }

    #[test]
    fn test_warp_affine_identity() {
        let img = ramp_image(2, 2);
        let warper = ImageWarper::default();
        let out = match warper.warp_affine(&img, 2, 2, &identity_affine(), 2, 2) {
            Ok(out) => out,
            Err(_) => panic!("identity warp must succeed"),
        };
        assert_eq!(out, img);
    }

    #[test]
    fn test_warp_affine_size_mismatch() {
        let warper = ImageWarper::default();
        assert!(warper
            .warp_affine(&[0u8; 5], 2, 2, &identity_affine(), 2, 2)
            .is_err());
    }

    #[test]
    fn test_invert_affine_roundtrip() {
        let warper = ImageWarper::default();
        let forward =
            AffineTransform::new(nalgebra::Matrix2x3::new(2.0, 0.5, 4.0, -0.3, 1.5, -7.0));
        let inverse = match warper.invert_affine(&forward) {
            Ok(inv) => inv,
            Err(_) => panic!("matrix is invertible"),
        };
        let p = Point2D::new(3.0, -2.5);
        let back = inverse.transform(&forward.transform(&p));
        assert!((back.x - p.x).abs() < 1e-9);
        assert!((back.y - p.y).abs() < 1e-9);
    }

    #[test]
    fn test_invert_affine_singular() {
        let warper = ImageWarper::default();
        let singular =
            AffineTransform::new(nalgebra::Matrix2x3::new(1.0, 2.0, 0.0, 2.0, 4.0, 0.0));
        assert!(warper.invert_affine(&singular).is_err());
    }

    #[test]
    fn test_transform_builder() {
        let transform = TransformBuilder::new()
            .translate(10.0, 20.0)
            .rotate(std::f64::consts::PI / 4.0)
            .scale(2.0, 2.0)
            .build();

        // Operations apply in call order:
        // (0, 0) -> translate (10, 20) -> rotate 45° -> scale x2.
        let transformed = transform.transform(&Point2D::new(0.0, 0.0));
        let h = std::f64::consts::FRAC_1_SQRT_2;
        let expected_x = 2.0 * (10.0 * h - 20.0 * h);
        let expected_y = 2.0 * (10.0 * h + 20.0 * h);
        assert!((transformed.x - expected_x).abs() < 1e-9);
        assert!((transformed.y - expected_y).abs() < 1e-9);
    }

    #[test]
    fn test_mesh_warper_creation() {
        let warper = MeshWarper::new(10, 10);
        assert_eq!(warper.grid_width, 10);
        assert_eq!(warper.grid_height, 10);
    }

    #[test]
    fn test_mesh_warper_control_points() {
        let mut warper = MeshWarper::new(2, 2);
        warper.set_control_point(1, 1, Point2D::new(100.0, 100.0));
        assert_eq!(warper.control_points[1][1].x, 100.0);
        assert_eq!(warper.control_points[1][1].y, 100.0);
    }

    #[test]
    fn test_mesh_warper_regular_grid() {
        let mut warper = MeshWarper::new(4, 4);
        warper.init_regular_grid(400, 400);
        assert_eq!(warper.control_points[0][0].x, 0.0);
        assert_eq!(warper.control_points[4][4].x, 400.0);
        assert_eq!(warper.control_points[4][4].y, 400.0);
    }

    #[test]
    fn test_mesh_warper_identity() {
        // A regular lattice is an identity mapping, so the warp reproduces the input.
        let img = ramp_image(4, 4);
        let mut warper = MeshWarper::new(2, 2);
        warper.init_regular_grid(4, 4);
        let out = match warper.warp(&img, 4, 4) {
            Ok(out) => out,
            Err(_) => panic!("mesh warp must succeed"),
        };
        assert_eq!(out, img);
    }
}