Skip to main content

oximedia_gpu/kernels/
transform.rs

1//! Transform operations (DCT, FFT, geometric transforms)
2
3use crate::{GpuDevice, Result};
4
/// Kind of operation performed by a transform kernel.
///
/// The variants fall into two families: frequency-domain transforms
/// (`DCT`, `IDCT`, `FFT`, `IFFT`) and geometric transforms (all the others).
/// The type is a plain `Copy` value; `Hash` is derived so it can be used as a
/// key in hash-based collections (e.g. per-transform statistics or caches).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Forward discrete cosine transform (DCT)
    DCT,
    /// Inverse discrete cosine transform
    IDCT,
    /// Forward fast Fourier transform (FFT)
    FFT,
    /// Inverse fast Fourier transform
    IFFT,
    /// Rotate by a quarter turn (90 degrees)
    Rotate90,
    /// Rotate by a half turn (180 degrees)
    Rotate180,
    /// Rotate by three quarter turns (270 degrees)
    Rotate270,
    /// Mirror left-to-right (flip horizontally)
    FlipHorizontal,
    /// Mirror top-to-bottom (flip vertically)
    FlipVertical,
    /// Swap rows and columns
    Transpose,
    /// General affine transform
    Affine,
    /// Perspective (projective) transform
    Perspective,
}
33
34/// Transform kernel for frequency domain and geometric operations
35pub struct TransformKernel {
36    transform_type: TransformType,
37}
38
39impl TransformKernel {
40    /// Create a new transform kernel
41    #[must_use]
42    pub fn new(transform_type: TransformType) -> Self {
43        Self { transform_type }
44    }
45
46    /// Create a DCT transform kernel
47    #[must_use]
48    pub fn dct() -> Self {
49        Self::new(TransformType::DCT)
50    }
51
52    /// Create an IDCT transform kernel
53    #[must_use]
54    pub fn idct() -> Self {
55        Self::new(TransformType::IDCT)
56    }
57
58    /// Create a rotate kernel
59    #[must_use]
60    pub fn rotate(degrees: i32) -> Self {
61        let transform_type = match degrees % 360 {
62            90 | -270 => TransformType::Rotate90,
63            180 | -180 => TransformType::Rotate180,
64            270 | -90 => TransformType::Rotate270,
65            _ => TransformType::Rotate90, // Default
66        };
67        Self::new(transform_type)
68    }
69
70    /// Create a flip kernel
71    #[must_use]
72    pub fn flip(horizontal: bool) -> Self {
73        let transform_type = if horizontal {
74            TransformType::FlipHorizontal
75        } else {
76            TransformType::FlipVertical
77        };
78        Self::new(transform_type)
79    }
80
81    /// Execute the transform operation
82    ///
83    /// # Arguments
84    ///
85    /// * `device` - GPU device
86    /// * `input` - Input data buffer
87    /// * `output` - Output data buffer
88    /// * `width` - Data width
89    /// * `height` - Data height
90    ///
91    /// # Errors
92    ///
93    /// Returns an error if the operation fails.
94    pub fn execute(
95        &self,
96        device: &GpuDevice,
97        input: &[f32],
98        output: &mut [f32],
99        width: u32,
100        height: u32,
101    ) -> Result<()> {
102        match self.transform_type {
103            TransformType::DCT => {
104                crate::ops::TransformOperation::dct_2d(device, input, output, width, height)
105            }
106            TransformType::IDCT => {
107                crate::ops::TransformOperation::idct_2d(device, input, output, width, height)
108            }
109            _ => Err(crate::GpuError::NotSupported(format!(
110                "Transform type {:?} not yet implemented",
111                self.transform_type
112            ))),
113        }
114    }
115
116    /// Get the transform type
117    #[must_use]
118    pub fn transform_type(&self) -> TransformType {
119        self.transform_type
120    }
121
122    /// Check if this is a frequency domain transform
123    #[must_use]
124    pub fn is_frequency_domain(&self) -> bool {
125        matches!(
126            self.transform_type,
127            TransformType::DCT | TransformType::IDCT | TransformType::FFT | TransformType::IFFT
128        )
129    }
130
131    /// Check if this is a geometric transform
132    #[must_use]
133    pub fn is_geometric(&self) -> bool {
134        matches!(
135            self.transform_type,
136            TransformType::Rotate90
137                | TransformType::Rotate180
138                | TransformType::Rotate270
139                | TransformType::FlipHorizontal
140                | TransformType::FlipVertical
141                | TransformType::Transpose
142                | TransformType::Affine
143                | TransformType::Perspective
144        )
145    }
146
147    /// Estimate FLOPS for the transform operation
148    #[must_use]
149    pub fn estimate_flops(width: u32, height: u32, transform_type: TransformType) -> u64 {
150        let n = u64::from(width) * u64::from(height);
151
152        match transform_type {
153            TransformType::DCT | TransformType::IDCT => {
154                // DCT complexity: O(N^2 log N) for 2D
155                let log_n = (n as f64).log2().ceil() as u64;
156                n * n * log_n
157            }
158            TransformType::FFT | TransformType::IFFT => {
159                // FFT complexity: O(N log N)
160                let log_n = (n as f64).log2().ceil() as u64;
161                n * log_n * 5 // 5 ops per butterfly
162            }
163            _ => {
164                // Geometric transforms: O(N)
165                n
166            }
167        }
168    }
169}
170
/// Affine transformation matrix in row-major 2x3 form.
///
/// The six `elements` are laid out as `[a, b, tx, c, d, ty]` and stand for the
/// homogeneous 3x3 matrix
///
/// ```text
/// [ a  b  tx ]
/// [ c  d  ty ]
/// [ 0  0  1  ]
/// ```
///
/// so a point `(x, y)` (as a column vector) maps to
/// `(a*x + b*y + tx, c*x + d*y + ty)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AffineMatrix {
    /// Matrix elements in the order `[a, b, tx, c, d, ty]`; the bottom row
    /// `[0, 0, 1]` is implied.
    pub elements: [f32; 6],
}

impl AffineMatrix {
    /// Create the identity transform (maps every point to itself).
    #[must_use]
    pub fn identity() -> Self {
        Self {
            elements: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        }
    }

    /// Create a translation by `(tx, ty)`.
    #[must_use]
    pub fn translation(tx: f32, ty: f32) -> Self {
        Self {
            elements: [1.0, 0.0, tx, 0.0, 1.0, ty],
        }
    }

    /// Create a rotation by `angle_radians` about the origin.
    ///
    /// Uses the standard matrix `[cos -sin; sin cos]`, which rotates
    /// counter-clockwise when the y axis points up.
    #[must_use]
    pub fn rotation(angle_radians: f32) -> Self {
        let (sin, cos) = angle_radians.sin_cos();
        Self {
            elements: [cos, -sin, 0.0, sin, cos, 0.0],
        }
    }

    /// Create a scaling by `sx` along x and `sy` along y, about the origin.
    #[must_use]
    pub fn scaling(sx: f32, sy: f32) -> Self {
        Self {
            elements: [sx, 0.0, 0.0, 0.0, sy, 0.0],
        }
    }

    /// Compose two transforms, returning the matrix product `self · other`.
    ///
    /// Because points are column vectors, the result applies `other` FIRST and
    /// `self` second: `a.combine(&b).transform_point(x, y)` equals `a` applied
    /// to `b.transform_point(x, y)`.
    #[must_use]
    pub fn combine(&self, other: &Self) -> Self {
        let [a1, b1, tx1, c1, d1, ty1] = self.elements;
        let [a2, b2, tx2, c2, d2, ty2] = other.elements;

        Self {
            elements: [
                a1 * a2 + b1 * c2,
                a1 * b2 + b1 * d2,
                a1 * tx2 + b1 * ty2 + tx1,
                c1 * a2 + d1 * c2,
                c1 * b2 + d1 * d2,
                c1 * tx2 + d1 * ty2 + ty1,
            ],
        }
    }

    /// Apply this transform to the point `(x, y)`.
    #[must_use]
    pub fn transform_point(&self, x: f32, y: f32) -> (f32, f32) {
        let [a, b, tx, c, d, ty] = self.elements;
        (a * x + b * y + tx, c * x + d * y + ty)
    }

    /// Get a copy of the matrix elements (`[a, b, tx, c, d, ty]`).
    #[must_use]
    pub fn as_array(&self) -> [f32; 6] {
        self.elements
    }
}

impl Default for AffineMatrix {
    /// The default transform is the identity.
    fn default() -> Self {
        Self::identity()
    }
}
246
247/// Warp kernel for geometric transformations
248pub struct WarpKernel {
249    matrix: AffineMatrix,
250}
251
252impl WarpKernel {
253    /// Create a new warp kernel
254    #[must_use]
255    pub fn new(matrix: AffineMatrix) -> Self {
256        Self { matrix }
257    }
258
259    /// Create a rotation warp
260    #[must_use]
261    pub fn rotation(angle_degrees: f32, center_x: f32, center_y: f32) -> Self {
262        let angle_radians = angle_degrees.to_radians();
263
264        // Translate to origin, rotate, translate back
265        let t1 = AffineMatrix::translation(-center_x, -center_y);
266        let r = AffineMatrix::rotation(angle_radians);
267        let t2 = AffineMatrix::translation(center_x, center_y);
268
269        let matrix = t1.combine(&r).combine(&t2);
270
271        Self::new(matrix)
272    }
273
274    /// Create a scaling warp
275    #[must_use]
276    pub fn scaling(sx: f32, sy: f32, center_x: f32, center_y: f32) -> Self {
277        let t1 = AffineMatrix::translation(-center_x, -center_y);
278        let s = AffineMatrix::scaling(sx, sy);
279        let t2 = AffineMatrix::translation(center_x, center_y);
280
281        let matrix = t1.combine(&s).combine(&t2);
282
283        Self::new(matrix)
284    }
285
286    /// Get the transformation matrix
287    #[must_use]
288    pub fn matrix(&self) -> &AffineMatrix {
289        &self.matrix
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transform_kernel_creation() {
        let kernel = TransformKernel::dct();
        assert_eq!(kernel.transform_type(), TransformType::DCT);
        assert!(kernel.is_frequency_domain());
        assert!(!kernel.is_geometric());

        let kernel = TransformKernel::idct();
        assert_eq!(kernel.transform_type(), TransformType::IDCT);
        assert!(kernel.is_frequency_domain());

        let kernel = TransformKernel::rotate(90);
        assert_eq!(kernel.transform_type(), TransformType::Rotate90);
        assert!(!kernel.is_frequency_domain());
        assert!(kernel.is_geometric());
    }

    #[test]
    fn test_rotate_normalizes_angle() {
        // Negative and >360 angles are reduced modulo 360 before selection.
        assert_eq!(TransformKernel::rotate(180).transform_type(), TransformType::Rotate180);
        assert_eq!(TransformKernel::rotate(-90).transform_type(), TransformType::Rotate270);
        assert_eq!(TransformKernel::rotate(-270).transform_type(), TransformType::Rotate90);
        assert_eq!(TransformKernel::rotate(450).transform_type(), TransformType::Rotate90);
        assert_eq!(TransformKernel::rotate(630).transform_type(), TransformType::Rotate270);
    }

    #[test]
    fn test_flip_kernel() {
        assert_eq!(
            TransformKernel::flip(true).transform_type(),
            TransformType::FlipHorizontal
        );
        assert_eq!(
            TransformKernel::flip(false).transform_type(),
            TransformType::FlipVertical
        );
        assert!(TransformKernel::flip(true).is_geometric());
        assert!(!TransformKernel::flip(false).is_frequency_domain());
    }

    #[test]
    fn test_affine_matrix_identity() {
        let identity = AffineMatrix::identity();
        let elements = identity.as_array();
        assert_eq!(elements, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_affine_matrix_translation() {
        let trans = AffineMatrix::translation(10.0, 20.0);
        let elements = trans.as_array();
        assert_eq!(elements[2], 10.0);
        assert_eq!(elements[5], 20.0);
    }

    #[test]
    fn test_affine_matrix_scaling() {
        let scale = AffineMatrix::scaling(2.0, 3.0);
        let elements = scale.as_array();
        assert_eq!(elements[0], 2.0);
        assert_eq!(elements[4], 3.0);
    }

    #[test]
    fn test_affine_matrix_combination() {
        let t1 = AffineMatrix::translation(10.0, 20.0);
        let s = AffineMatrix::scaling(2.0, 2.0);

        // `combine` is the matrix product self * other: the right-hand
        // operand is applied first. T * S scales, then translates...
        assert_eq!(t1.combine(&s).as_array(), [2.0, 0.0, 10.0, 0.0, 2.0, 20.0]);
        // ...while S * T translates, then scales (so the offset is scaled too).
        assert_eq!(s.combine(&t1).as_array(), [2.0, 0.0, 20.0, 0.0, 2.0, 40.0]);
    }

    #[test]
    fn test_flops_estimation() {
        let flops_dct = TransformKernel::estimate_flops(64, 64, TransformType::DCT);
        let flops_fft = TransformKernel::estimate_flops(64, 64, TransformType::FFT);
        let flops_rotate = TransformKernel::estimate_flops(64, 64, TransformType::Rotate90);

        // n = 4096 pixels, ceil(log2 n) = 12.
        assert_eq!(flops_dct, 4096 * 4096 * 12);
        assert_eq!(flops_fft, 4096 * 12 * 5);
        assert_eq!(flops_rotate, 4096);
        assert!(flops_dct > flops_rotate); // DCT should be more expensive
    }
}