Skip to main content

oximedia_gpu/kernels/
transform.rs

1//! Transform operations (DCT, FFT, geometric transforms)
2
3use crate::{GpuDevice, Result};
4
/// Transform operation type
///
/// Selects the operation a [`TransformKernel`] performs. `DCT`, `IDCT`, `FFT`
/// and `IFFT` are classified as frequency-domain transforms by
/// [`TransformKernel::is_frequency_domain`]; every other variant is classified
/// as geometric by [`TransformKernel::is_geometric`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransformType {
    /// Discrete Cosine Transform (DCT); executed via [`TransformKernel::execute`]
    DCT,
    /// Inverse DCT; executed via [`TransformKernel::execute`]
    IDCT,
    /// Fast Fourier Transform (FFT); currently rejected as not supported by
    /// both `execute` and `execute_u8`
    FFT,
    /// Inverse FFT; currently rejected as not supported by both `execute`
    /// and `execute_u8`
    IFFT,
    /// Rotate 90 degrees; executed via [`TransformKernel::execute_u8`]
    Rotate90,
    /// Rotate 180 degrees; executed via [`TransformKernel::execute_u8`]
    Rotate180,
    /// Rotate 270 degrees; executed via [`TransformKernel::execute_u8`]
    Rotate270,
    /// Flip horizontal; executed via [`TransformKernel::execute_u8`]
    FlipHorizontal,
    /// Flip vertical; executed via [`TransformKernel::execute_u8`]
    FlipVertical,
    /// Transpose; executed via [`TransformKernel::execute_u8`]
    Transpose,
    /// Affine transform; currently rejected as not supported by both
    /// `execute` and `execute_u8`
    Affine,
    /// Perspective transform; currently rejected as not supported by both
    /// `execute` and `execute_u8`
    Perspective,
}
33
34/// Transform kernel for frequency domain and geometric operations
35pub struct TransformKernel {
36    transform_type: TransformType,
37}
38
39impl TransformKernel {
40    /// Create a new transform kernel
41    #[must_use]
42    pub fn new(transform_type: TransformType) -> Self {
43        Self { transform_type }
44    }
45
46    /// Create a DCT transform kernel
47    #[must_use]
48    pub fn dct() -> Self {
49        Self::new(TransformType::DCT)
50    }
51
52    /// Create an IDCT transform kernel
53    #[must_use]
54    pub fn idct() -> Self {
55        Self::new(TransformType::IDCT)
56    }
57
58    /// Create a rotate kernel
59    #[must_use]
60    pub fn rotate(degrees: i32) -> Self {
61        let transform_type = match degrees % 360 {
62            90 | -270 => TransformType::Rotate90,
63            180 | -180 => TransformType::Rotate180,
64            270 | -90 => TransformType::Rotate270,
65            _ => TransformType::Rotate90, // Default
66        };
67        Self::new(transform_type)
68    }
69
70    /// Create a flip kernel
71    #[must_use]
72    pub fn flip(horizontal: bool) -> Self {
73        let transform_type = if horizontal {
74            TransformType::FlipHorizontal
75        } else {
76            TransformType::FlipVertical
77        };
78        Self::new(transform_type)
79    }
80
81    /// Execute the transform operation (frequency-domain, f32 data).
82    ///
83    /// Handles DCT and IDCT which operate on `f32` frequency-domain data.
84    /// For pixel-level geometric transforms (rotate, flip, transpose) use
85    /// [`TransformKernel::execute_u8`] instead.
86    ///
87    /// # Arguments
88    ///
89    /// * `device` - GPU device
90    /// * `input` - Input data buffer
91    /// * `output` - Output data buffer
92    /// * `width` - Data width
93    /// * `height` - Data height
94    ///
95    /// # Errors
96    ///
97    /// Returns an error if the operation fails or is not supported for f32 data.
98    pub fn execute(
99        &self,
100        device: &GpuDevice,
101        input: &[f32],
102        output: &mut [f32],
103        width: u32,
104        height: u32,
105    ) -> Result<()> {
106        match self.transform_type {
107            TransformType::DCT => {
108                crate::ops::TransformOperation::dct_2d(device, input, output, width, height)
109            }
110            TransformType::IDCT => {
111                crate::ops::TransformOperation::idct_2d(device, input, output, width, height)
112            }
113            TransformType::FFT
114            | TransformType::IFFT
115            | TransformType::Affine
116            | TransformType::Perspective => Err(crate::GpuError::NotSupported(format!(
117                "Transform type {:?} not yet implemented",
118                self.transform_type
119            ))),
120            _ => Err(crate::GpuError::NotSupported(format!(
121                "Transform type {:?} requires u8 pixel data — use execute_u8()",
122                self.transform_type
123            ))),
124        }
125    }
126
127    /// Execute a geometric pixel transform on an interleaved `u8` image buffer.
128    ///
129    /// Handles `Rotate90`, `Rotate180`, `Rotate270`, `FlipHorizontal`,
130    /// `FlipVertical`, and `Transpose`.  `FFT`, `IFFT`, `Affine`, and
131    /// `Perspective` are deliberately left as `NotSupported`.
132    ///
133    /// The `_device` parameter is accepted for API symmetry but is not used
134    /// by the CPU-side implementations (the geometric ops are fully pure-Rust).
135    ///
136    /// # Arguments
137    ///
138    /// * `_device` - GPU device (unused; present for API consistency)
139    /// * `input` - Input pixel buffer (`width * height * channels` bytes)
140    /// * `width` - Image width in pixels
141    /// * `height` - Image height in pixels
142    /// * `channels` - Bytes per pixel (e.g. 3 for RGB, 4 for RGBA)
143    ///
144    /// # Errors
145    ///
146    /// Returns [`crate::GpuError::NotSupported`] for frequency-domain,
147    /// `Affine`, and `Perspective` transform types.
148    pub fn execute_u8(
149        &self,
150        _device: &GpuDevice,
151        input: &[u8],
152        width: u32,
153        height: u32,
154        channels: u32,
155    ) -> Result<Vec<u8>> {
156        match self.transform_type {
157            TransformType::Rotate90 => Ok(crate::ops::TransformOperation::rotate90(
158                input, width, height, channels,
159            )),
160            TransformType::Rotate180 => Ok(crate::ops::TransformOperation::rotate180(
161                input, width, height, channels,
162            )),
163            TransformType::Rotate270 => Ok(crate::ops::TransformOperation::rotate270(
164                input, width, height, channels,
165            )),
166            TransformType::FlipHorizontal => Ok(crate::ops::TransformOperation::flip_horizontal(
167                input, width, height, channels,
168            )),
169            TransformType::FlipVertical => Ok(crate::ops::TransformOperation::flip_vertical(
170                input, width, height, channels,
171            )),
172            TransformType::Transpose => Ok(crate::ops::TransformOperation::transpose(
173                input, width, height, channels,
174            )),
175            TransformType::FFT
176            | TransformType::IFFT
177            | TransformType::Affine
178            | TransformType::Perspective => Err(crate::GpuError::NotSupported(format!(
179                "Transform type {:?} not yet implemented",
180                self.transform_type
181            ))),
182            TransformType::DCT | TransformType::IDCT => {
183                Err(crate::GpuError::NotSupported(format!(
184                    "Transform type {:?} operates on f32 data — use execute()",
185                    self.transform_type
186                )))
187            }
188        }
189    }
190
191    /// Get the transform type
192    #[must_use]
193    pub fn transform_type(&self) -> TransformType {
194        self.transform_type
195    }
196
197    /// Check if this is a frequency domain transform
198    #[must_use]
199    pub fn is_frequency_domain(&self) -> bool {
200        matches!(
201            self.transform_type,
202            TransformType::DCT | TransformType::IDCT | TransformType::FFT | TransformType::IFFT
203        )
204    }
205
206    /// Check if this is a geometric transform
207    #[must_use]
208    pub fn is_geometric(&self) -> bool {
209        matches!(
210            self.transform_type,
211            TransformType::Rotate90
212                | TransformType::Rotate180
213                | TransformType::Rotate270
214                | TransformType::FlipHorizontal
215                | TransformType::FlipVertical
216                | TransformType::Transpose
217                | TransformType::Affine
218                | TransformType::Perspective
219        )
220    }
221
222    /// Estimate FLOPS for the transform operation
223    #[must_use]
224    pub fn estimate_flops(width: u32, height: u32, transform_type: TransformType) -> u64 {
225        let n = u64::from(width) * u64::from(height);
226
227        match transform_type {
228            TransformType::DCT | TransformType::IDCT => {
229                // DCT complexity: O(N^2 log N) for 2D
230                let log_n = (n as f64).log2().ceil() as u64;
231                n * n * log_n
232            }
233            TransformType::FFT | TransformType::IFFT => {
234                // FFT complexity: O(N log N)
235                let log_n = (n as f64).log2().ceil() as u64;
236                n * log_n * 5 // 5 ops per butterfly
237            }
238            _ => {
239                // Geometric transforms: O(N)
240                n
241            }
242        }
243    }
244}
245
/// Affine transformation matrix
///
/// Stores the top two rows of a 3x3 homogeneous matrix; the bottom row is
/// implicitly `[0, 0, 1]`.
#[derive(Debug, Clone, Copy)]
pub struct AffineMatrix {
    /// Matrix elements in row-major order: `[a, b, tx, c, d, ty]`
    ///
    /// ```text
    /// [ a  b  tx ]
    /// [ c  d  ty ]
    /// [ 0  0  1  ]
    /// ```
    ///
    /// Indices 2 and 5 hold the translation, as used by the constructors
    /// and by `combine`.
    pub elements: [f32; 6],
}
255
256impl AffineMatrix {
257    /// Create an identity matrix
258    #[must_use]
259    pub fn identity() -> Self {
260        Self {
261            elements: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
262        }
263    }
264
265    /// Create a translation matrix
266    #[must_use]
267    pub fn translation(tx: f32, ty: f32) -> Self {
268        Self {
269            elements: [1.0, 0.0, tx, 0.0, 1.0, ty],
270        }
271    }
272
273    /// Create a rotation matrix
274    #[must_use]
275    pub fn rotation(angle_radians: f32) -> Self {
276        let cos = angle_radians.cos();
277        let sin = angle_radians.sin();
278        Self {
279            elements: [cos, -sin, 0.0, sin, cos, 0.0],
280        }
281    }
282
283    /// Create a scaling matrix
284    #[must_use]
285    pub fn scaling(sx: f32, sy: f32) -> Self {
286        Self {
287            elements: [sx, 0.0, 0.0, 0.0, sy, 0.0],
288        }
289    }
290
291    /// Combine two affine transformations
292    #[must_use]
293    pub fn combine(&self, other: &Self) -> Self {
294        let a1 = self.elements;
295        let a2 = other.elements;
296
297        Self {
298            elements: [
299                a1[0] * a2[0] + a1[1] * a2[3],
300                a1[0] * a2[1] + a1[1] * a2[4],
301                a1[0] * a2[2] + a1[1] * a2[5] + a1[2],
302                a1[3] * a2[0] + a1[4] * a2[3],
303                a1[3] * a2[1] + a1[4] * a2[4],
304                a1[3] * a2[2] + a1[4] * a2[5] + a1[5],
305            ],
306        }
307    }
308
309    /// Get matrix elements
310    #[must_use]
311    pub fn as_array(&self) -> [f32; 6] {
312        self.elements
313    }
314}
315
316impl Default for AffineMatrix {
317    fn default() -> Self {
318        Self::identity()
319    }
320}
321
322/// Warp kernel for geometric transformations
323pub struct WarpKernel {
324    matrix: AffineMatrix,
325}
326
327impl WarpKernel {
328    /// Create a new warp kernel
329    #[must_use]
330    pub fn new(matrix: AffineMatrix) -> Self {
331        Self { matrix }
332    }
333
334    /// Create a rotation warp
335    #[must_use]
336    pub fn rotation(angle_degrees: f32, center_x: f32, center_y: f32) -> Self {
337        let angle_radians = angle_degrees.to_radians();
338
339        // Translate to origin, rotate, translate back
340        let t1 = AffineMatrix::translation(-center_x, -center_y);
341        let r = AffineMatrix::rotation(angle_radians);
342        let t2 = AffineMatrix::translation(center_x, center_y);
343
344        let matrix = t1.combine(&r).combine(&t2);
345
346        Self::new(matrix)
347    }
348
349    /// Create a scaling warp
350    #[must_use]
351    pub fn scaling(sx: f32, sy: f32, center_x: f32, center_y: f32) -> Self {
352        let t1 = AffineMatrix::translation(-center_x, -center_y);
353        let s = AffineMatrix::scaling(sx, sy);
354        let t2 = AffineMatrix::translation(center_x, center_y);
355
356        let matrix = t1.combine(&s).combine(&t2);
357
358        Self::new(matrix)
359    }
360
361    /// Get the transformation matrix
362    #[must_use]
363    pub fn matrix(&self) -> &AffineMatrix {
364        &self.matrix
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructors select the expected transform type and classification.
    #[test]
    fn test_transform_kernel_creation() {
        let kernel = TransformKernel::dct();
        assert_eq!(kernel.transform_type(), TransformType::DCT);
        assert!(kernel.is_frequency_domain());
        assert!(!kernel.is_geometric());

        let kernel = TransformKernel::rotate(90);
        assert_eq!(kernel.transform_type(), TransformType::Rotate90);
        assert!(!kernel.is_frequency_domain());
        assert!(kernel.is_geometric());
    }

    /// Negative and over-rotated angles map onto the matching quarter turn.
    #[test]
    fn test_rotate_kernel_angle_normalisation() {
        assert_eq!(
            TransformKernel::rotate(-90).transform_type(),
            TransformType::Rotate270
        );
        assert_eq!(
            TransformKernel::rotate(-180).transform_type(),
            TransformType::Rotate180
        );
        assert_eq!(
            TransformKernel::rotate(450).transform_type(),
            TransformType::Rotate90
        );
    }

    /// `flip(true)` is horizontal, `flip(false)` is vertical.
    #[test]
    fn test_flip_kernel() {
        assert_eq!(
            TransformKernel::flip(true).transform_type(),
            TransformType::FlipHorizontal
        );
        assert_eq!(
            TransformKernel::flip(false).transform_type(),
            TransformType::FlipVertical
        );
    }

    #[test]
    fn test_affine_matrix_identity() {
        let identity = AffineMatrix::identity();
        let elements = identity.as_array();
        assert_eq!(elements, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
        assert_eq!(AffineMatrix::default().as_array(), elements);
    }

    #[test]
    fn test_affine_matrix_translation() {
        let trans = AffineMatrix::translation(10.0, 20.0);
        let elements = trans.as_array();
        assert_eq!(elements[2], 10.0);
        assert_eq!(elements[5], 20.0);
    }

    #[test]
    fn test_affine_matrix_scaling() {
        let scale = AffineMatrix::scaling(2.0, 3.0);
        let elements = scale.as_array();
        assert_eq!(elements[0], 2.0);
        assert_eq!(elements[4], 3.0);
    }

    /// `combine` is the matrix product `self * other`, so the operand order
    /// matters: translating after scaling differs from scaling after
    /// translating.
    #[test]
    fn test_affine_matrix_combination() {
        let t1 = AffineMatrix::translation(10.0, 20.0);
        let s = AffineMatrix::scaling(2.0, 2.0);

        // t1 * s: scale first, then translate — translation stays (10, 20).
        let combined = t1.combine(&s);
        assert_eq!(combined.as_array(), [2.0, 0.0, 10.0, 0.0, 2.0, 20.0]);

        // s * t1: translate first, then scale — translation is scaled too.
        let reversed = s.combine(&t1);
        assert_eq!(reversed.as_array(), [2.0, 0.0, 20.0, 0.0, 2.0, 40.0]);
    }

    #[test]
    fn test_flops_estimation() {
        let flops_dct = TransformKernel::estimate_flops(64, 64, TransformType::DCT);
        let flops_rotate = TransformKernel::estimate_flops(64, 64, TransformType::Rotate90);

        assert!(flops_dct > 0);
        assert!(flops_rotate > 0);
        assert!(flops_dct > flops_rotate); // DCT should be more expensive
    }
}