burn_vision/transform/
transform2d.rs

1use burn_tensor::{Tensor, backend::Backend, grid::affine_grid_2d, ops::InterpolateMode};
2
/// 2D point transformation
///
/// Useful for resampling: rotating, scaling, translating, etc image tensors
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Transform2D {
    // 2x3 affine transformation matrix, to be used with column vectors:
    // T(x) = Ax (the implicit third row of the homogeneous matrix is [0, 0, 1])
    transform: [[f32; 3]; 2],
}
11
12impl Transform2D {
13    /// Transforms an image
14    ///
15    /// * `img` - Images tensor with shape (batch_size, channels, height, width)
16    ///
17    /// # Returns
18    ///
19    /// A tensor with the same as the input
20    pub fn transform<B: Backend>(self, img: Tensor<B, 4>) -> Tensor<B, 4> {
21        let [batch_size, channels, height, width] = img.shape().dims();
22        let transform = Tensor::<B, 2>::from(self.transform);
23        let transform = transform.reshape([1, 2, 3]).expand([batch_size, 2, 3]);
24        let grid = affine_grid_2d(transform, [batch_size, channels, height, width]);
25
26        img.grid_sample_2d(grid, InterpolateMode::Bilinear)
27    }
28
29    /// Makes a 2d transformation composed of other transformations
30    pub fn composed<I: IntoIterator<Item = Self>>(transforms: I) -> Self {
31        let mut result = Self::identity();
32        for t in transforms.into_iter() {
33            result = result.mul(t);
34        }
35        result
36    }
37
38    /// Multiply two affine transforms represented as 2x3 matrices
39    fn mul(self, other: Transform2D) -> Transform2D {
40        let mut result = [[0.0f32; 3]; 2];
41
42        // Row 0
43        result[0][0] = self.transform[0][0] * other.transform[0][0]
44            + self.transform[0][1] * other.transform[1][0];
45        result[0][1] = self.transform[0][0] * other.transform[0][1]
46            + self.transform[0][1] * other.transform[1][1];
47        result[0][2] = self.transform[0][0] * other.transform[0][2]
48            + self.transform[0][1] * other.transform[1][2]
49            + self.transform[0][2];
50
51        // Row 1
52        result[1][0] = self.transform[1][0] * other.transform[0][0]
53            + self.transform[1][1] * other.transform[1][0];
54        result[1][1] = self.transform[1][0] * other.transform[0][1]
55            + self.transform[1][1] * other.transform[1][1];
56        result[1][2] = self.transform[1][0] * other.transform[0][2]
57            + self.transform[1][1] * other.transform[1][2]
58            + self.transform[1][2];
59
60        Transform2D { transform: result }
61    }
62
63    /// Makes an identity transform (x = Ax)
64    pub fn identity() -> Self {
65        Self {
66            transform: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
67        }
68    }
69
70    /// Makes a [`Transform2D`] for rotating a tensor
71    ///
72    /// * `theta` - In radians, the rotation
73    /// * `cx` - Center of rotation, x
74    /// * `cy` - Center of rotation, y
75    pub fn rotation(theta: f32, cx: f32, cy: f32) -> Self {
76        let cos_theta = theta.cos();
77        let sin_theta = theta.sin();
78
79        let transform = [
80            [cos_theta, -sin_theta, cx - cos_theta * cx + sin_theta * cy],
81            [sin_theta, cos_theta, cy - sin_theta * cx - cos_theta * cy],
82        ];
83
84        Self { transform }
85    }
86
87    /// Makes a [`Transform2D`] for scaling an image tensor
88    ///
89    /// * `sx` - Scale factor in the x direction
90    /// * `sy` - Scale factor in the y direction
91    /// * `cx` - Center of scaling, x
92    /// * `cy` - Center of scaling, y
93    pub fn scale(sx: f32, sy: f32, cx: f32, cy: f32) -> Self {
94        let transform = [[sx, 0.0, cx - sx * cx], [0.0, sy, cy - sy * cy]];
95
96        Self { transform }
97    }
98
99    /// Makes a [`Transform2D`] for translating an image tensor
100    ///
101    /// * `tx` - Translation in the x direction
102    /// * `ty` - Translation in the y direction
103    pub fn translation(tx: f32, ty: f32) -> Self {
104        let transform = [[1.0, 0.0, tx], [0.0, 1.0, ty]];
105
106        Self { transform }
107    }
108
109    /// Applies a general shear transformation around the image center,
110    /// combining both X and Y shear.
111    ///
112    /// # Arguments
113    /// * `shx` - Shear factor along the X-axis.
114    /// * `shy` - Shear factor along the Y-axis.
115    /// * `cx`, `cy` - Coordinates of the image center.
116    ///
117    /// # Returns
118    /// * `Self` with a combined shear transform matrix.
119    pub fn shear(shx: f32, shy: f32, cx: f32, cy: f32) -> Self {
120        let transform = [[1.0, shx, -shx * cy], [shy, 1.0, -shy * cx]];
121
122        Self { transform }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;
    use burn_tensor::Tolerance;
    type B = NdArray;

    /// Applies `t` to `input` and checks the result against `expected`
    /// within the balanced f32 tolerance.
    fn assert_transformed(t: Transform2D, input: Tensor<B, 4>, expected: Tensor<B, 4>) {
        let actual = t.transform(input);
        expected
            .to_data()
            .assert_approx_eq(&actual.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_identity_translation() {
        // A zero translation must leave the image untouched.
        let input = Tensor::<B, 4>::from([[[[1., 0.], [0., 2.]]]]);
        assert_transformed(Transform2D::translation(0.0, 0.0), input.clone(), input);
    }

    #[test]
    fn transform_translation() {
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        // This result would change if the padding method is different
        let expected = Tensor::<B, 4>::from([[[[2.5, 3.], [3.5, 4.]]]]);
        assert_transformed(Transform2D::translation(1., 1.), input, expected);
    }

    #[test]
    fn transform_rotation_90_degrees() {
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected = Tensor::<B, 4>::from([[[[2., 4.], [1., 3.]]]]);
        assert_transformed(
            Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0),
            input,
            expected,
        );
    }

    #[test]
    fn transform_rotation_around_corner() {
        // Rotate about the corner (1, -1) instead of the origin.
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        // This result would change if the padding method is different
        let expected = Tensor::<B, 4>::from([[[[2., 2.], [1., 1.]]]]);
        assert_transformed(
            Transform2D::rotation(std::f32::consts::FRAC_PI_2, 1., -1.),
            input,
            expected,
        );
    }

    #[test]
    fn transform_scale() {
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected = Tensor::<B, 4>::from([[[[1.75, 2.25], [2.75, 3.25]]]]);
        assert_transformed(Transform2D::scale(0.5, 0.5, 0.0, 0.0), input, expected);
    }

    #[test]
    fn transform_scale_around_corner() {
        // Scale about the corner (1, -1) instead of the origin.
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        // This result would change if the padding method is different
        let expected = Tensor::<B, 4>::from([[[[1.5, 2.], [2.5, 3.]]]]);
        assert_transformed(Transform2D::scale(0.5, 0.5, 1., -1.), input, expected);
    }

    #[test]
    fn transform_combined() {
        // Translation followed by a 60-degree rotation about the origin.
        let t = Transform2D::composed([
            Transform2D::translation(0.2, -0.5),
            Transform2D::rotation(std::f32::consts::FRAC_PI_3, 0., 0.),
        ]);

        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        // This result would change if the padding method is different
        let expected =
            Tensor::<B, 4>::from([[[[1.7830127, 2.8660254], [1.1339746, 3.2830124]]]]);
        assert_transformed(t, input, expected);
    }
}