burn_vision/transform/transform2d.rs

use burn_tensor::{
    Tensor,
    backend::Backend,
    grid::affine_grid_2d,
    ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
};

/// 2D point transformation
///
/// Useful for resampling image tensors: rotating, scaling, translating, etc.
pub struct Transform2D {
    // 2x3 affine transformation matrix in augmented form [A | b], applied to
    // homogeneous column vectors: T(x) = A * [x, y, 1]^T
    transform: [[f32; 3]; 2],
}

impl Transform2D {
    /// Transforms an image
    ///
    /// * `img` - Images tensor with shape (batch_size, channels, height, width)
    ///
    /// # Returns
    ///
    /// A tensor with the same shape as the input
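    ///
    /// # Examples
    ///
    /// A minimal usage sketch, marked `ignore` because it assumes a concrete
    /// backend (here `burn_ndarray::NdArray`) is available:
    ///
    /// ```ignore
    /// use burn_ndarray::NdArray;
    /// use burn_tensor::Tensor;
    ///
    /// let t = Transform2D::translation(0.5, 0.0);
    /// let img = Tensor::<NdArray, 4>::zeros([1, 3, 8, 8], &Default::default());
    /// let out = t.transform(img);
    /// // Resampling never changes the tensor's shape.
    /// let [b, c, h, w] = out.shape().dims();
    /// assert_eq!([b, c, h, w], [1, 3, 8, 8]);
    /// ```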
    pub fn transform<B: Backend>(self, img: Tensor<B, 4>) -> Tensor<B, 4> {
        let [batch_size, channels, height, width] = img.shape().dims();
        let transform = Tensor::<B, 2>::from(self.transform);
        let transform = transform.reshape([1, 2, 3]).expand([batch_size, 2, 3]);
        let grid = affine_grid_2d(transform, [batch_size, channels, height, width]);

        let options = GridSampleOptions::new(InterpolateMode::Bilinear)
            .with_padding_mode(GridSamplePaddingMode::Border)
            .with_align_corners(true);
        img.grid_sample_2d(grid, options)
    }

    /// Makes a 2D transformation by composing other transformations
    ///
    /// The matrices are multiplied together in iteration order.
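    ///
    /// A small sketch of composing a rotation with a translation (marked
    /// `ignore`; the composed matrix is applied via [`Transform2D::transform`]):
    ///
    /// ```ignore
    /// let rotate = Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0);
    /// let shift = Transform2D::translation(0.5, 0.0);
    /// let both = Transform2D::composed([rotate, shift]);
    /// ```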
    pub fn composed<I: IntoIterator<Item = Self>>(transforms: I) -> Self {
        let mut result = Self::identity();
        for t in transforms {
            result = result.mul(t);
        }
        result
    }

    /// Multiply two affine transforms represented as 2x3 matrices
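    ///
    /// Writing each transform in augmented form `[A | b]`, the product is
    /// `[A1 | b1] * [A2 | b2] = [A1 * A2 | A1 * b2 + b1]`, which the
    /// element-wise sums below spell out.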
    fn mul(self, other: Transform2D) -> Transform2D {
        let mut result = [[0.0f32; 3]; 2];

        // Row 0
        result[0][0] = self.transform[0][0] * other.transform[0][0]
            + self.transform[0][1] * other.transform[1][0];
        result[0][1] = self.transform[0][0] * other.transform[0][1]
            + self.transform[0][1] * other.transform[1][1];
        result[0][2] = self.transform[0][0] * other.transform[0][2]
            + self.transform[0][1] * other.transform[1][2]
            + self.transform[0][2];

        // Row 1
        result[1][0] = self.transform[1][0] * other.transform[0][0]
            + self.transform[1][1] * other.transform[1][0];
        result[1][1] = self.transform[1][0] * other.transform[0][1]
            + self.transform[1][1] * other.transform[1][1];
        result[1][2] = self.transform[1][0] * other.transform[0][2]
            + self.transform[1][1] * other.transform[1][2]
            + self.transform[1][2];

        Transform2D { transform: result }
    }

    /// Makes an identity transform (T(x) = x)
    pub fn identity() -> Self {
        Self {
            transform: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        }
    }

    /// Makes a [`Transform2D`] for rotating an image tensor
    ///
    /// * `theta` - Rotation angle, in radians
    /// * `cx` - Center of rotation, x
    /// * `cy` - Center of rotation, y
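    ///
    /// Rotation about a center `c` is `T(x) = R(x - c) + c`, so the linear
    /// part is the rotation matrix `R` and the translation column is
    /// `c - R * c`, which is what the third column below encodes.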
    pub fn rotation(theta: f32, cx: f32, cy: f32) -> Self {
        let cos_theta = theta.cos();
        let sin_theta = theta.sin();

        let transform = [
            [cos_theta, -sin_theta, cx - cos_theta * cx + sin_theta * cy],
            [sin_theta, cos_theta, cy - sin_theta * cx - cos_theta * cy],
        ];

        Self { transform }
    }

    /// Makes a [`Transform2D`] for scaling an image tensor
    ///
    /// * `sx` - Scale factor in the x direction
    /// * `sy` - Scale factor in the y direction
    /// * `cx` - Center of scaling, x
    /// * `cy` - Center of scaling, y
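    ///
    /// Scaling about a center follows the same pattern as rotation:
    /// `T(x) = c + s * (x - c)`, giving the translation column `c - s * c`.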
    pub fn scale(sx: f32, sy: f32, cx: f32, cy: f32) -> Self {
        let transform = [[sx, 0.0, cx - sx * cx], [0.0, sy, cy - sy * cy]];

        Self { transform }
    }

    /// Makes a [`Transform2D`] for translating an image tensor
    ///
    /// * `tx` - Translation in the x direction
    /// * `ty` - Translation in the y direction
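    ///
    /// Note: judging by the tests below, the sampling grid appears to use the
    /// normalized [-1, 1] coordinate convention (as in `affine_grid`-style
    /// APIs), so `tx = 1.0` shifts by half the image width rather than by
    /// one pixel.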
    pub fn translation(tx: f32, ty: f32) -> Self {
        let transform = [[1.0, 0.0, tx], [0.0, 1.0, ty]];

        Self { transform }
    }

    /// Makes a [`Transform2D`] for shearing an image tensor around a center
    /// point, combining both X and Y shear.
    ///
    /// # Arguments
    /// * `shx` - Shear factor along the X-axis.
    /// * `shy` - Shear factor along the Y-axis.
    /// * `cx`, `cy` - Coordinates of the shear center.
    ///
    /// # Returns
    /// * `Self` with a combined shear transform matrix.
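    ///
    /// The matrix encodes `x' = x + shx * (y - cy)` and
    /// `y' = y + shy * (x - cx)`, i.e. each axis is offset in proportion to
    /// the other axis' distance from the center.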
    pub fn shear(shx: f32, shy: f32, cx: f32, cy: f32) -> Self {
        let transform = [[1.0, shx, -shx * cy], [shy, 1.0, -shy * cx]];

        Self { transform }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;
    use burn_tensor::Tolerance;
    type B = NdArray;

    #[test]
    fn transform_identity_translation() {
        let t = Transform2D::translation(0.0, 0.0);
        let image_original = Tensor::<B, 4>::from([[[[1., 0.], [0., 2.]]]]);
        let image_transformed = t.transform(image_original.clone());
        image_original
            .to_data()
            .assert_approx_eq(&image_transformed.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_translation() {
        let t = Transform2D::translation(1., 1.);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        // This result would change if the padding mode were different
        let image_expected = Tensor::<B, 4>::from([[[[2.5, 3.], [3.5, 4.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_rotation_90_degrees() {
        let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[2., 4.], [1., 3.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_rotation_around_corner() {
        let cx = 1.;
        let cy = -1.;
        let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, cx, cy);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        // This result would change if the padding mode were different
        let image_expected = Tensor::<B, 4>::from([[[[2., 2.], [1., 1.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_scale() {
        let cx = 0.0;
        let cy = 0.0;
        let t = Transform2D::scale(0.5, 0.5, cx, cy);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[1.75, 2.25], [2.75, 3.25]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_scale_around_corner() {
        let cx = 1.;
        let cy = -1.;
        let t = Transform2D::scale(0.5, 0.5, cx, cy);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[1.5, 2.], [2.5, 3.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_combined() {
        let t1 = Transform2D::translation(0.2, -0.5);
        let t2 = Transform2D::rotation(std::f32::consts::FRAC_PI_3, 0., 0.);
        let t = Transform2D::composed([t1, t2]);

        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        // This result would change if the padding mode were different
        let image_expected =
            Tensor::<B, 4>::from([[[[1.7830127, 2.8660254], [1.1339746, 3.2830124]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }
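
    // A sketch of an additional check: shear is otherwise untested here, and
    // since it only resamples the image, the output shape must match the
    // input. The shear parameters are arbitrary, not a verified reference.
    #[test]
    fn transform_shear_preserves_shape() {
        let t = Transform2D::shear(0.5, 0.0, 0.0, 0.0);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image = t.transform(image);
        let [b, c, h, w] = image.shape().dims();
        assert_eq!([b, c, h, w], [1, 1, 2, 2]);
    }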
}