1use crate::{GpuDevice, Result};
4
/// The kind of transform a [`TransformKernel`] performs.
///
/// `DCT`, `IDCT`, `FFT` and `IFFT` are frequency-domain transforms; the remaining
/// variants are geometric transforms.
// The acronym-style variant names are part of the public API; renaming them to
// `Dct`/`Fft`/... would break callers, so the lint is silenced instead.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Forward discrete cosine transform.
    DCT,
    /// Inverse discrete cosine transform.
    IDCT,
    /// Forward fast Fourier transform.
    FFT,
    /// Inverse fast Fourier transform.
    IFFT,
    /// Rotation by a quarter turn (90 degrees).
    Rotate90,
    /// Rotation by a half turn (180 degrees).
    Rotate180,
    /// Rotation by three quarter turns (270 degrees).
    Rotate270,
    /// Horizontal flip.
    FlipHorizontal,
    /// Vertical flip.
    FlipVertical,
    /// Transpose (swap rows and columns).
    Transpose,
    /// General affine warp.
    Affine,
    /// Perspective (projective) warp.
    Perspective,
}
33
34pub struct TransformKernel {
36 transform_type: TransformType,
37}
38
39impl TransformKernel {
40 #[must_use]
42 pub fn new(transform_type: TransformType) -> Self {
43 Self { transform_type }
44 }
45
46 #[must_use]
48 pub fn dct() -> Self {
49 Self::new(TransformType::DCT)
50 }
51
52 #[must_use]
54 pub fn idct() -> Self {
55 Self::new(TransformType::IDCT)
56 }
57
58 #[must_use]
60 pub fn rotate(degrees: i32) -> Self {
61 let transform_type = match degrees % 360 {
62 90 | -270 => TransformType::Rotate90,
63 180 | -180 => TransformType::Rotate180,
64 270 | -90 => TransformType::Rotate270,
65 _ => TransformType::Rotate90, };
67 Self::new(transform_type)
68 }
69
70 #[must_use]
72 pub fn flip(horizontal: bool) -> Self {
73 let transform_type = if horizontal {
74 TransformType::FlipHorizontal
75 } else {
76 TransformType::FlipVertical
77 };
78 Self::new(transform_type)
79 }
80
81 pub fn execute(
95 &self,
96 device: &GpuDevice,
97 input: &[f32],
98 output: &mut [f32],
99 width: u32,
100 height: u32,
101 ) -> Result<()> {
102 match self.transform_type {
103 TransformType::DCT => {
104 crate::ops::TransformOperation::dct_2d(device, input, output, width, height)
105 }
106 TransformType::IDCT => {
107 crate::ops::TransformOperation::idct_2d(device, input, output, width, height)
108 }
109 _ => Err(crate::GpuError::NotSupported(format!(
110 "Transform type {:?} not yet implemented",
111 self.transform_type
112 ))),
113 }
114 }
115
116 #[must_use]
118 pub fn transform_type(&self) -> TransformType {
119 self.transform_type
120 }
121
122 #[must_use]
124 pub fn is_frequency_domain(&self) -> bool {
125 matches!(
126 self.transform_type,
127 TransformType::DCT | TransformType::IDCT | TransformType::FFT | TransformType::IFFT
128 )
129 }
130
131 #[must_use]
133 pub fn is_geometric(&self) -> bool {
134 matches!(
135 self.transform_type,
136 TransformType::Rotate90
137 | TransformType::Rotate180
138 | TransformType::Rotate270
139 | TransformType::FlipHorizontal
140 | TransformType::FlipVertical
141 | TransformType::Transpose
142 | TransformType::Affine
143 | TransformType::Perspective
144 )
145 }
146
147 #[must_use]
149 pub fn estimate_flops(width: u32, height: u32, transform_type: TransformType) -> u64 {
150 let n = u64::from(width) * u64::from(height);
151
152 match transform_type {
153 TransformType::DCT | TransformType::IDCT => {
154 let log_n = (n as f64).log2().ceil() as u64;
156 n * n * log_n
157 }
158 TransformType::FFT | TransformType::IFFT => {
159 let log_n = (n as f64).log2().ceil() as u64;
161 n * log_n * 5 }
163 _ => {
164 n
166 }
167 }
168 }
169}
170
/// A 2-D affine transform, stored as the top two rows of a 3×3 matrix.
///
/// `elements` is row-major `[a, b, tx, c, d, ty]`, i.e. the matrix
///
/// ```text
/// | a  b  tx |
/// | c  d  ty |
/// | 0  0  1  |
/// ```
///
/// so a point `(x, y)` maps to `(a*x + b*y + tx, c*x + d*y + ty)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AffineMatrix {
    /// Row-major coefficients `[a, b, tx, c, d, ty]`.
    pub elements: [f32; 6],
}
180
181impl AffineMatrix {
182 #[must_use]
184 pub fn identity() -> Self {
185 Self {
186 elements: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
187 }
188 }
189
190 #[must_use]
192 pub fn translation(tx: f32, ty: f32) -> Self {
193 Self {
194 elements: [1.0, 0.0, tx, 0.0, 1.0, ty],
195 }
196 }
197
198 #[must_use]
200 pub fn rotation(angle_radians: f32) -> Self {
201 let cos = angle_radians.cos();
202 let sin = angle_radians.sin();
203 Self {
204 elements: [cos, -sin, 0.0, sin, cos, 0.0],
205 }
206 }
207
208 #[must_use]
210 pub fn scaling(sx: f32, sy: f32) -> Self {
211 Self {
212 elements: [sx, 0.0, 0.0, 0.0, sy, 0.0],
213 }
214 }
215
216 #[must_use]
218 pub fn combine(&self, other: &Self) -> Self {
219 let a1 = self.elements;
220 let a2 = other.elements;
221
222 Self {
223 elements: [
224 a1[0] * a2[0] + a1[1] * a2[3],
225 a1[0] * a2[1] + a1[1] * a2[4],
226 a1[0] * a2[2] + a1[1] * a2[5] + a1[2],
227 a1[3] * a2[0] + a1[4] * a2[3],
228 a1[3] * a2[1] + a1[4] * a2[4],
229 a1[3] * a2[2] + a1[4] * a2[5] + a1[5],
230 ],
231 }
232 }
233
234 #[must_use]
236 pub fn as_array(&self) -> [f32; 6] {
237 self.elements
238 }
239}
240
241impl Default for AffineMatrix {
242 fn default() -> Self {
243 Self::identity()
244 }
245}
246
247pub struct WarpKernel {
249 matrix: AffineMatrix,
250}
251
252impl WarpKernel {
253 #[must_use]
255 pub fn new(matrix: AffineMatrix) -> Self {
256 Self { matrix }
257 }
258
259 #[must_use]
261 pub fn rotation(angle_degrees: f32, center_x: f32, center_y: f32) -> Self {
262 let angle_radians = angle_degrees.to_radians();
263
264 let t1 = AffineMatrix::translation(-center_x, -center_y);
266 let r = AffineMatrix::rotation(angle_radians);
267 let t2 = AffineMatrix::translation(center_x, center_y);
268
269 let matrix = t1.combine(&r).combine(&t2);
270
271 Self::new(matrix)
272 }
273
274 #[must_use]
276 pub fn scaling(sx: f32, sy: f32, center_x: f32, center_y: f32) -> Self {
277 let t1 = AffineMatrix::translation(-center_x, -center_y);
278 let s = AffineMatrix::scaling(sx, sy);
279 let t2 = AffineMatrix::translation(center_x, center_y);
280
281 let matrix = t1.combine(&s).combine(&t2);
282
283 Self::new(matrix)
284 }
285
286 #[must_use]
288 pub fn matrix(&self) -> &AffineMatrix {
289 &self.matrix
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transform_kernel_creation() {
        let kernel = TransformKernel::dct();
        assert_eq!(kernel.transform_type(), TransformType::DCT);
        assert!(kernel.is_frequency_domain());
        assert!(!kernel.is_geometric());

        let kernel = TransformKernel::rotate(90);
        assert_eq!(kernel.transform_type(), TransformType::Rotate90);
        assert!(!kernel.is_frequency_domain());
        assert!(kernel.is_geometric());
    }

    #[test]
    fn test_rotate_normalizes_angles() {
        // Equivalent angles (mod 360, including negatives) must select the same variant.
        let cases = [
            (90, TransformType::Rotate90),
            (-270, TransformType::Rotate90),
            (450, TransformType::Rotate90),
            (180, TransformType::Rotate180),
            (-180, TransformType::Rotate180),
            (270, TransformType::Rotate270),
            (-90, TransformType::Rotate270),
        ];
        for (degrees, expected) in cases.iter().copied() {
            assert_eq!(
                TransformKernel::rotate(degrees).transform_type(),
                expected,
                "degrees = {}",
                degrees
            );
        }
    }

    #[test]
    fn test_flip_selects_axis() {
        assert_eq!(
            TransformKernel::flip(true).transform_type(),
            TransformType::FlipHorizontal
        );
        assert_eq!(
            TransformKernel::flip(false).transform_type(),
            TransformType::FlipVertical
        );
    }

    #[test]
    fn test_affine_matrix_identity() {
        let identity = AffineMatrix::identity();
        let elements = identity.as_array();
        assert_eq!(elements, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
        assert_eq!(AffineMatrix::default().as_array(), elements);
    }

    #[test]
    fn test_affine_matrix_translation() {
        let trans = AffineMatrix::translation(10.0, 20.0);
        let elements = trans.as_array();
        assert_eq!(elements[2], 10.0);
        assert_eq!(elements[5], 20.0);
    }

    #[test]
    fn test_affine_matrix_scaling() {
        let scale = AffineMatrix::scaling(2.0, 3.0);
        let elements = scale.as_array();
        assert_eq!(elements[0], 2.0);
        assert_eq!(elements[4], 3.0);
    }

    #[test]
    fn test_affine_matrix_combination() {
        // `combine` is the matrix product `self * other`: `other` is applied first.
        let t1 = AffineMatrix::translation(10.0, 20.0);
        let s = AffineMatrix::scaling(2.0, 2.0);
        assert_eq!(t1.combine(&s).as_array(), [2.0, 0.0, 10.0, 0.0, 2.0, 20.0]);
        assert_eq!(s.combine(&t1).as_array(), [2.0, 0.0, 20.0, 0.0, 2.0, 40.0]);
    }

    #[test]
    fn test_affine_identity_is_neutral() {
        let m = AffineMatrix {
            elements: [1.5, -2.0, 3.0, 0.5, 4.0, -1.0],
        };
        let id = AffineMatrix::identity();
        assert_eq!(id.combine(&m).as_array(), m.as_array());
        assert_eq!(m.combine(&id).as_array(), m.as_array());
    }

    #[test]
    fn test_flops_estimation() {
        let flops_dct = TransformKernel::estimate_flops(64, 64, TransformType::DCT);
        let flops_fft = TransformKernel::estimate_flops(64, 64, TransformType::FFT);
        let flops_rotate = TransformKernel::estimate_flops(64, 64, TransformType::Rotate90);

        // n = 4096, ceil(log2 n) = 12.
        assert_eq!(flops_dct, 4096 * 4096 * 12);
        assert_eq!(flops_fft, 4096 * 12 * 5);
        assert_eq!(flops_rotate, 4096);
        assert!(flops_dct > flops_rotate);
    }
}