1use crate::{GpuDevice, Result};
4
/// The kinds of transform a [`TransformKernel`] can represent.
///
/// `DCT` and `IDCT` operate on `f32` data through `TransformKernel::execute`; the fixed
/// geometric variants (rotations, flips, transpose) operate on `u8` pixel data through
/// `TransformKernel::execute_u8`. `FFT`, `IFFT`, `Affine` and `Perspective` are declared
/// but currently rejected by both methods as not yet implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Forward 2D discrete cosine transform (`f32` data).
    DCT,
    /// Inverse 2D discrete cosine transform (`f32` data).
    IDCT,
    /// Forward fast Fourier transform (not yet implemented).
    FFT,
    /// Inverse fast Fourier transform (not yet implemented).
    IFFT,
    /// Quarter-turn rotation by 90 degrees (`u8` pixel data).
    Rotate90,
    /// Half-turn rotation by 180 degrees (`u8` pixel data).
    Rotate180,
    /// Three-quarter-turn rotation by 270 degrees (`u8` pixel data).
    Rotate270,
    /// Mirror left-to-right (`u8` pixel data).
    FlipHorizontal,
    /// Mirror top-to-bottom (`u8` pixel data).
    FlipVertical,
    /// Swap rows and columns (`u8` pixel data).
    Transpose,
    /// General affine warp (not yet implemented).
    Affine,
    /// Perspective (projective) warp (not yet implemented).
    Perspective,
}
33
34pub struct TransformKernel {
36 transform_type: TransformType,
37}
38
39impl TransformKernel {
40 #[must_use]
42 pub fn new(transform_type: TransformType) -> Self {
43 Self { transform_type }
44 }
45
46 #[must_use]
48 pub fn dct() -> Self {
49 Self::new(TransformType::DCT)
50 }
51
52 #[must_use]
54 pub fn idct() -> Self {
55 Self::new(TransformType::IDCT)
56 }
57
58 #[must_use]
60 pub fn rotate(degrees: i32) -> Self {
61 let transform_type = match degrees % 360 {
62 90 | -270 => TransformType::Rotate90,
63 180 | -180 => TransformType::Rotate180,
64 270 | -90 => TransformType::Rotate270,
65 _ => TransformType::Rotate90, };
67 Self::new(transform_type)
68 }
69
70 #[must_use]
72 pub fn flip(horizontal: bool) -> Self {
73 let transform_type = if horizontal {
74 TransformType::FlipHorizontal
75 } else {
76 TransformType::FlipVertical
77 };
78 Self::new(transform_type)
79 }
80
81 pub fn execute(
99 &self,
100 device: &GpuDevice,
101 input: &[f32],
102 output: &mut [f32],
103 width: u32,
104 height: u32,
105 ) -> Result<()> {
106 match self.transform_type {
107 TransformType::DCT => {
108 crate::ops::TransformOperation::dct_2d(device, input, output, width, height)
109 }
110 TransformType::IDCT => {
111 crate::ops::TransformOperation::idct_2d(device, input, output, width, height)
112 }
113 TransformType::FFT
114 | TransformType::IFFT
115 | TransformType::Affine
116 | TransformType::Perspective => Err(crate::GpuError::NotSupported(format!(
117 "Transform type {:?} not yet implemented",
118 self.transform_type
119 ))),
120 _ => Err(crate::GpuError::NotSupported(format!(
121 "Transform type {:?} requires u8 pixel data — use execute_u8()",
122 self.transform_type
123 ))),
124 }
125 }
126
127 pub fn execute_u8(
149 &self,
150 _device: &GpuDevice,
151 input: &[u8],
152 width: u32,
153 height: u32,
154 channels: u32,
155 ) -> Result<Vec<u8>> {
156 match self.transform_type {
157 TransformType::Rotate90 => Ok(crate::ops::TransformOperation::rotate90(
158 input, width, height, channels,
159 )),
160 TransformType::Rotate180 => Ok(crate::ops::TransformOperation::rotate180(
161 input, width, height, channels,
162 )),
163 TransformType::Rotate270 => Ok(crate::ops::TransformOperation::rotate270(
164 input, width, height, channels,
165 )),
166 TransformType::FlipHorizontal => Ok(crate::ops::TransformOperation::flip_horizontal(
167 input, width, height, channels,
168 )),
169 TransformType::FlipVertical => Ok(crate::ops::TransformOperation::flip_vertical(
170 input, width, height, channels,
171 )),
172 TransformType::Transpose => Ok(crate::ops::TransformOperation::transpose(
173 input, width, height, channels,
174 )),
175 TransformType::FFT
176 | TransformType::IFFT
177 | TransformType::Affine
178 | TransformType::Perspective => Err(crate::GpuError::NotSupported(format!(
179 "Transform type {:?} not yet implemented",
180 self.transform_type
181 ))),
182 TransformType::DCT | TransformType::IDCT => {
183 Err(crate::GpuError::NotSupported(format!(
184 "Transform type {:?} operates on f32 data — use execute()",
185 self.transform_type
186 )))
187 }
188 }
189 }
190
191 #[must_use]
193 pub fn transform_type(&self) -> TransformType {
194 self.transform_type
195 }
196
197 #[must_use]
199 pub fn is_frequency_domain(&self) -> bool {
200 matches!(
201 self.transform_type,
202 TransformType::DCT | TransformType::IDCT | TransformType::FFT | TransformType::IFFT
203 )
204 }
205
206 #[must_use]
208 pub fn is_geometric(&self) -> bool {
209 matches!(
210 self.transform_type,
211 TransformType::Rotate90
212 | TransformType::Rotate180
213 | TransformType::Rotate270
214 | TransformType::FlipHorizontal
215 | TransformType::FlipVertical
216 | TransformType::Transpose
217 | TransformType::Affine
218 | TransformType::Perspective
219 )
220 }
221
222 #[must_use]
224 pub fn estimate_flops(width: u32, height: u32, transform_type: TransformType) -> u64 {
225 let n = u64::from(width) * u64::from(height);
226
227 match transform_type {
228 TransformType::DCT | TransformType::IDCT => {
229 let log_n = (n as f64).log2().ceil() as u64;
231 n * n * log_n
232 }
233 TransformType::FFT | TransformType::IFFT => {
234 let log_n = (n as f64).log2().ceil() as u64;
236 n * log_n * 5 }
238 _ => {
239 n
241 }
242 }
243 }
244}
245
/// A 2D affine transform stored as the top two rows of a 3×3 matrix.
///
/// `elements` is row-major `[a, b, tx, c, d, ty]`; the implicit third row is `[0, 0, 1]`.
/// A point `(x, y)` therefore maps to `(a·x + b·y + tx, c·x + d·y + ty)` — see
/// [`AffineMatrix::transform_point`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AffineMatrix {
    /// Row-major `[a, b, tx, c, d, ty]`.
    pub elements: [f32; 6],
}

impl AffineMatrix {
    /// The identity transform: every point maps to itself.
    #[must_use]
    pub fn identity() -> Self {
        Self {
            elements: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        }
    }

    /// Translation by `(tx, ty)`.
    #[must_use]
    pub fn translation(tx: f32, ty: f32) -> Self {
        Self {
            elements: [1.0, 0.0, tx, 0.0, 1.0, ty],
        }
    }

    /// Rotation by `angle_radians` about the origin.
    ///
    /// Uses `[[cos, -sin], [sin, cos]]`, which turns counter-clockwise in a y-up
    /// coordinate system.
    #[must_use]
    pub fn rotation(angle_radians: f32) -> Self {
        let cos = angle_radians.cos();
        let sin = angle_radians.sin();
        Self {
            elements: [cos, -sin, 0.0, sin, cos, 0.0],
        }
    }

    /// Axis-aligned scaling by `(sx, sy)` about the origin.
    #[must_use]
    pub fn scaling(sx: f32, sy: f32) -> Self {
        Self {
            elements: [sx, 0.0, 0.0, 0.0, sy, 0.0],
        }
    }

    /// Returns the matrix product `self × other`.
    ///
    /// Points are multiplied as column vectors, so the result applies `other` FIRST and
    /// `self` second. Up to rounding,
    /// `a.combine(&b).transform_point(x, y)` equals `a` applied to `b.transform_point(x, y)`.
    #[must_use]
    pub fn combine(&self, other: &Self) -> Self {
        let a1 = self.elements;
        let a2 = other.elements;

        Self {
            elements: [
                a1[0] * a2[0] + a1[1] * a2[3],
                a1[0] * a2[1] + a1[1] * a2[4],
                a1[0] * a2[2] + a1[1] * a2[5] + a1[2],
                a1[3] * a2[0] + a1[4] * a2[3],
                a1[3] * a2[1] + a1[4] * a2[4],
                a1[3] * a2[2] + a1[4] * a2[5] + a1[5],
            ],
        }
    }

    /// Applies this transform to the point `(x, y)`.
    #[must_use]
    pub fn transform_point(&self, x: f32, y: f32) -> (f32, f32) {
        let [a, b, tx, c, d, ty] = self.elements;
        (a * x + b * y + tx, c * x + d * y + ty)
    }

    /// Returns a copy of the six row-major elements `[a, b, tx, c, d, ty]`.
    #[must_use]
    pub fn as_array(&self) -> [f32; 6] {
        self.elements
    }
}

impl Default for AffineMatrix {
    /// The identity transform.
    fn default() -> Self {
        Self::identity()
    }
}

/// A warp described by a single [`AffineMatrix`].
#[derive(Debug, Clone, PartialEq)]
pub struct WarpKernel {
    /// The transform this warp applies.
    matrix: AffineMatrix,
}

impl WarpKernel {
    /// Wraps an existing affine matrix.
    #[must_use]
    pub fn new(matrix: AffineMatrix) -> Self {
        Self { matrix }
    }

    /// Rotation by `angle_degrees` about `(center_x, center_y)`.
    ///
    /// The center is a fixed point of the resulting matrix. The direction follows
    /// [`AffineMatrix::rotation`].
    #[must_use]
    pub fn rotation(angle_degrees: f32, center_x: f32, center_y: f32) -> Self {
        let to_origin = AffineMatrix::translation(-center_x, -center_y);
        let rotate = AffineMatrix::rotation(angle_degrees.to_radians());
        let back = AffineMatrix::translation(center_x, center_y);

        // `combine` applies its argument first, so this reads right-to-left:
        // move the center to the origin, rotate, then move it back.
        Self::new(back.combine(&rotate).combine(&to_origin))
    }

    /// Scaling by `(sx, sy)` about `(center_x, center_y)`.
    ///
    /// The center is a fixed point of the resulting matrix.
    #[must_use]
    pub fn scaling(sx: f32, sy: f32, center_x: f32, center_y: f32) -> Self {
        let to_origin = AffineMatrix::translation(-center_x, -center_y);
        let scale = AffineMatrix::scaling(sx, sy);
        let back = AffineMatrix::translation(center_x, center_y);

        // Right-to-left, as in `rotation`: to origin, scale, back.
        Self::new(back.combine(&scale).combine(&to_origin))
    }

    /// Returns the underlying affine matrix.
    #[must_use]
    pub fn matrix(&self) -> &AffineMatrix {
        &self.matrix
    }
}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transform_kernel_creation() {
        let kernel = TransformKernel::dct();
        assert_eq!(kernel.transform_type(), TransformType::DCT);
        assert!(kernel.is_frequency_domain());
        assert!(!kernel.is_geometric());

        let kernel = TransformKernel::rotate(90);
        assert_eq!(kernel.transform_type(), TransformType::Rotate90);
        assert!(!kernel.is_frequency_domain());
        assert!(kernel.is_geometric());
    }

    #[test]
    fn test_rotate_normalizes_degrees() {
        let cases = [
            (90, TransformType::Rotate90),
            (-270, TransformType::Rotate90),
            (450, TransformType::Rotate90),
            (180, TransformType::Rotate180),
            (-180, TransformType::Rotate180),
            (270, TransformType::Rotate270),
            (-90, TransformType::Rotate270),
            (630, TransformType::Rotate270),
        ];
        for &(degrees, expected) in cases.iter() {
            assert_eq!(
                TransformKernel::rotate(degrees).transform_type(),
                expected,
                "rotate({})",
                degrees
            );
        }
    }

    #[test]
    fn test_flip() {
        let horizontal = TransformKernel::flip(true);
        assert_eq!(horizontal.transform_type(), TransformType::FlipHorizontal);
        assert!(horizontal.is_geometric());
        assert!(!horizontal.is_frequency_domain());

        let vertical = TransformKernel::flip(false);
        assert_eq!(vertical.transform_type(), TransformType::FlipVertical);
    }

    #[test]
    fn test_affine_matrix_identity() {
        let identity = AffineMatrix::identity();
        let elements = identity.as_array();
        assert_eq!(elements, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_affine_matrix_default_is_identity() {
        assert_eq!(
            AffineMatrix::default().as_array(),
            AffineMatrix::identity().as_array()
        );
    }

    #[test]
    fn test_affine_matrix_translation() {
        let trans = AffineMatrix::translation(10.0, 20.0);
        let elements = trans.as_array();
        assert_eq!(elements[2], 10.0);
        assert_eq!(elements[5], 20.0);
    }

    #[test]
    fn test_affine_matrix_scaling() {
        let scale = AffineMatrix::scaling(2.0, 3.0);
        let elements = scale.as_array();
        assert_eq!(elements[0], 2.0);
        assert_eq!(elements[4], 3.0);
    }

    #[test]
    fn test_affine_matrix_combination() {
        // Expected value is the full matrix product translation × scaling.
        let t1 = AffineMatrix::translation(10.0, 20.0);
        let s = AffineMatrix::scaling(2.0, 2.0);
        let combined = t1.combine(&s);

        assert_eq!(combined.as_array(), [2.0, 0.0, 10.0, 0.0, 2.0, 20.0]);
    }

    #[test]
    fn test_flops_estimation() {
        let flops_dct = TransformKernel::estimate_flops(64, 64, TransformType::DCT);
        let flops_fft = TransformKernel::estimate_flops(64, 64, TransformType::FFT);
        let flops_rotate = TransformKernel::estimate_flops(64, 64, TransformType::Rotate90);

        assert!(flops_dct > 0);
        assert!(flops_rotate > 0);
        assert!(flops_dct > flops_rotate);
        assert!(flops_fft > flops_rotate);

        // An empty image needs no work.
        assert_eq!(TransformKernel::estimate_flops(0, 0, TransformType::DCT), 0);
        assert_eq!(TransformKernel::estimate_flops(0, 0, TransformType::FFT), 0);
    }
}