use burn_tensor::{
    Tensor,
    backend::Backend,
    grid::affine_grid_2d,
    ops::{GridSampleOptions, GridSamplePaddingMode, InterpolateMode},
};

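/// A 2D affine transform, stored as a 2x3 matrix and applied to batched image
/// tensors via grid sampling.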
pub struct Transform2D {
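    /// Row-major 2x3 affine matrix: the first two columns hold the linear
    /// part, the last column the translation.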
    transform: [[f32; 3]; 2],
}

impl Transform2D {
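    /// Applies the transform to a batch of images of shape
    /// `[batch, channels, height, width]` by building a sampling grid with
    /// `affine_grid_2d` and bilinearly resampling the input.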
    pub fn transform<B: Backend>(self, img: Tensor<B, 4>) -> Tensor<B, 4> {
        let [batch_size, channels, height, width] = img.shape().dims();
        let transform = Tensor::<B, 2>::from(self.transform);
        let transform = transform.reshape([1, 2, 3]).expand([batch_size, 2, 3]);
        let grid = affine_grid_2d(transform, [batch_size, channels, height, width]);

        let options = GridSampleOptions::new(InterpolateMode::Bilinear)
            .with_padding_mode(GridSamplePaddingMode::Border)
            .with_align_corners(true);
        img.grid_sample_2d(grid, options)
    }

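    /// Composes a sequence of transforms into a single transform by
    /// multiplying their matrices: `composed([a, b])` yields the matrix
    /// product `a * b`.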
    pub fn composed<I: IntoIterator<Item = Self>>(transforms: I) -> Self {
        let mut result = Self::identity();
        for t in transforms.into_iter() {
            result = result.mul(t);
        }
        result
    }

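    /// Multiplies two 2x3 affine matrices, treating each as a 3x3 matrix
    /// whose implicit last row is `[0, 0, 1]`.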
    fn mul(self, other: Transform2D) -> Transform2D {
        let mut result = [[0.0f32; 3]; 2];

        result[0][0] = self.transform[0][0] * other.transform[0][0]
            + self.transform[0][1] * other.transform[1][0];
        result[0][1] = self.transform[0][0] * other.transform[0][1]
            + self.transform[0][1] * other.transform[1][1];
        result[0][2] = self.transform[0][0] * other.transform[0][2]
            + self.transform[0][1] * other.transform[1][2]
            + self.transform[0][2];

        result[1][0] = self.transform[1][0] * other.transform[0][0]
            + self.transform[1][1] * other.transform[1][0];
        result[1][1] = self.transform[1][0] * other.transform[0][1]
            + self.transform[1][1] * other.transform[1][1];
        result[1][2] = self.transform[1][0] * other.transform[0][2]
            + self.transform[1][1] * other.transform[1][2]
            + self.transform[1][2];

        Transform2D { transform: result }
    }

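    /// Returns the identity transform, which leaves images unchanged.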
    pub fn identity() -> Self {
        Self {
            transform: [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
        }
    }

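    /// Rotation by `theta` radians around the point `(cx, cy)`, given in the
    /// same normalized coordinate space as the sampling grid.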
    pub fn rotation(theta: f32, cx: f32, cy: f32) -> Self {
        let cos_theta = theta.cos();
        let sin_theta = theta.sin();

        let transform = [
            [cos_theta, -sin_theta, cx - cos_theta * cx + sin_theta * cy],
            [sin_theta, cos_theta, cy - sin_theta * cx - cos_theta * cy],
        ];

        Self { transform }
    }

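    /// Scaling by factors `(sx, sy)` around the point `(cx, cy)`.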
    pub fn scale(sx: f32, sy: f32, cx: f32, cy: f32) -> Self {
        let transform = [[sx, 0.0, cx - sx * cx], [0.0, sy, cy - sy * cy]];

        Self { transform }
    }

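    /// Translation by `(tx, ty)`.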
    pub fn translation(tx: f32, ty: f32) -> Self {
        let transform = [[1.0, 0.0, tx], [0.0, 1.0, ty]];

        Self { transform }
    }

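    /// Shear by factors `(shx, shy)`, keeping the point `(cx, cy)` fixed.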
    pub fn shear(shx: f32, shy: f32, cx: f32, cy: f32) -> Self {
        let transform = [[1.0, shx, -shx * cy], [shy, 1.0, -shy * cx]];

        Self { transform }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_ndarray::NdArray;
    use burn_tensor::Tolerance;
    type B = NdArray;

    #[test]
    fn transform_identity_translation() {
        let t = Transform2D::translation(0.0, 0.0);
        let image_original = Tensor::<B, 4>::from([[[[1., 0.], [0., 2.]]]]);
        let image_transformed = t.transform(image_original.clone());
        image_original
            .to_data()
            .assert_approx_eq(&image_transformed.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_translation() {
        let t = Transform2D::translation(1., 1.);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[2.5, 3.], [3.5, 4.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_rotation_90_degrees() {
        let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, 0.0, 0.0);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[2., 4.], [1., 3.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_rotation_around_corner() {
        let cx = 1.;
        let cy = -1.;
        let t = Transform2D::rotation(std::f32::consts::FRAC_PI_2, cx, cy);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[2., 2.], [1., 1.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_scale() {
        let cx = 0.0;
        let cy = 0.0;
        let t = Transform2D::scale(0.5, 0.5, cx, cy);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[1.75, 2.25], [2.75, 3.25]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_scale_around_corner() {
        let cx = 1.;
        let cy = -1.;
        let t = Transform2D::scale(0.5, 0.5, cx, cy);
        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected = Tensor::<B, 4>::from([[[[1.5, 2.], [2.5, 3.]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }

    #[test]
    fn transform_combined() {
        let t1 = Transform2D::translation(0.2, -0.5);
        let t2 = Transform2D::rotation(std::f32::consts::FRAC_PI_3, 0., 0.);
        let t = Transform2D::composed([t1, t2]);

        let image = Tensor::<B, 4>::from([[[[1., 2.], [3., 4.]]]]);
        let image_expected =
            Tensor::<B, 4>::from([[[[1.7830127, 2.8660254], [1.1339746, 3.2830124]]]]);
        let image = t.transform(image);
        image_expected
            .to_data()
            .assert_approx_eq(&image.to_data(), Tolerance::<f32>::balanced());
    }
}