1use burn_tensor::{Tensor, backend::Backend, grid::affine_grid_2d, ops::InterpolateMode};
2
/// A 2D affine transform stored as a 2x3 row-major matrix: the top two rows
/// of an implicit 3x3 homogeneous matrix whose last row is `[0, 0, 1]`
/// (see `mul`, which composes transforms under exactly that convention).
pub struct Transform2D {
    // Row-major layout: [[a, b, tx], [c, d, ty]], where the last column is
    // the translation component.
    transform: [[f32; 3]; 2],
}
11
12impl Transform2D {
13 pub fn transform<B: Backend>(self, img: Tensor<B, 4>) -> Tensor<B, 4> {
21 let [batch_size, channels, height, width] = img.shape().dims();
22 let transform = Tensor::<B, 2>::from(self.transform);
23 let transform = transform.reshape([1, 2, 3]).expand([batch_size, 2, 3]);
24 let grid = affine_grid_2d(transform, [batch_size, channels, height, width]);
25
26 img.grid_sample_2d(grid, InterpolateMode::Bilinear)
27 }
28
29 pub fn composed<I: IntoIterator<Item = Self>>(transforms: I) -> Self {
31 let mut result = Self::identity();
32 for t in transforms.into_iter() {
33 result = result.mul(t);
34 }
35 result
36 }
37
38 fn mul(self, other: Transform2D) -> Transform2D {
40 let mut result = [[0.0f32; 3]; 2];
41
42 result[0][0] = self.transform[0][0] * other.transform[0][0]
44 + self.transform[0][1] * other.transform[1][0];
45 result[0][1] = self.transform[0][0] * other.transform[0][1]
46 + self.transform[0][1] * other.transform[1][1];
47 result[0][2] = self.transform[0][0] * other.transform[0][2]
48 + self.transform[0][1] * other.transform[1][2]
49 + self.transform[0][2];
50
51 result[1][0] = self.transform[1][0] * other.transform[0][0]
53 + self.transform[1][1] * other.transform[1][0];
54 result[1][1] = self.transform[1][0] * other.transform[0][1]
55 + self.transform[1][1] * other.transform[1][1];
56 result[1][2] = self.transform[1][0] * other.transform[0][2]
57 + self.transform[1][1] * other.transform[1][2]
58 + self.transform[1][2];
59
60 Transform2D { transform: result }
61 }
62
63 pub fn identity() -> Self {
65 Self {
66 transform: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
67 }
68 }
69
70 pub fn rotation(theta: f32, cx: f32, cy: f32) -> Self {
76 let cos_theta = theta.cos();
77 let sin_theta = theta.sin();
78
79 let transform = [
80 [cos_theta, -sin_theta, cx - cos_theta * cx + sin_theta * cy],
81 [sin_theta, cos_theta, cy - sin_theta * cx - cos_theta * cy],
82 ];
83
84 Self { transform }
85 }
86
87 pub fn scale(sx: f32, sy: f32, cx: f32, cy: f32) -> Self {
94 let transform = [[sx, 0.0, cx - sx * cx], [0.0, sy, cy - sy * cy]];
95
96 Self { transform }
97 }
98
99 pub fn translation(tx: f32, ty: f32) -> Self {
104 let transform = [[1.0, 0.0, tx], [0.0, 1.0, ty]];
105
106 Self { transform }
107 }
108
109 pub fn shear(shx: f32, shy: f32, cx: f32, cy: f32) -> Self {
120 let transform = [[1.0, shx, -shx * cy], [shy, 1.0, -shy * cx]];
121
122 Self { transform }
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;
    use burn_tensor::Tolerance;
    type B = NdArray;

    /// Asserts two images match element-wise within balanced f32 tolerance.
    fn assert_images_close(expected: Tensor<B, 4>, actual: Tensor<B, 4>) {
        expected
            .to_data()
            .assert_approx_eq(&actual.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_identity_translation() {
        // A zero translation must leave the image untouched.
        let identity = Transform2D::translation(0.0, 0.0);
        let input = Tensor::<B, 4>::from([[[[1., 0.], [0., 2.]]]]);
        let output = identity.transform(input.clone());
        assert_images_close(input, output);
    }

    #[test]
    fn transform_translation() {
        let shift = Transform2D::translation(1., 1.);
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected = Tensor::<B, 4>::from([[[[2.5, 3.], [3.5, 4.]]]]);
        assert_images_close(expected, shift.transform(input));
    }

    #[test]
    fn transform_rotation_90_degrees() {
        // Quarter turn about the grid origin.
        let quarter_turn = Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0);
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected = Tensor::<B, 4>::from([[[[2., 4.], [1., 3.]]]]);
        assert_images_close(expected, quarter_turn.transform(input));
    }

    #[test]
    fn transform_rotation_around_corner() {
        // Quarter turn pivoting on a grid corner instead of the center.
        let cx = 1.;
        let cy = -1.;
        let quarter_turn = Transform2D::rotation(std::f32::consts::FRAC_PI_2, cx, cy);
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected = Tensor::<B, 4>::from([[[[2., 2.], [1., 1.]]]]);
        assert_images_close(expected, quarter_turn.transform(input));
    }

    #[test]
    fn transform_scale() {
        let cx = 0.0;
        let cy = 0.0;
        let half_scale = Transform2D::scale(0.5, 0.5, cx, cy);
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected = Tensor::<B, 4>::from([[[[1.75, 2.25], [2.75, 3.25]]]]);
        assert_images_close(expected, half_scale.transform(input));
    }

    #[test]
    fn transform_scale_around_corner() {
        // Same scale factor, but anchored at a grid corner.
        let cx = 1.;
        let cy = -1.;
        let half_scale = Transform2D::scale(0.5, 0.5, cx, cy);
        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected = Tensor::<B, 4>::from([[[[1.5, 2.], [2.5, 3.]]]]);
        assert_images_close(expected, half_scale.transform(input));
    }

    #[test]
    fn transform_combined() {
        // Compose a translation with a rotation and apply the product.
        let translation = Transform2D::translation(0.2, -0.5);
        let rotation = Transform2D::rotation(std::f32::consts::FRAC_PI_3, 0., 0.);
        let combined = Transform2D::composed([translation, rotation]);

        let input = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let expected =
            Tensor::<B, 4>::from([[[[1.7830127, 2.8660254], [1.1339746, 3.2830124]]]]);
        assert_images_close(expected, combined.transform(input));
    }
}