Skip to main content

oximedia_gpu/kernels/
transform.rs

1//! Transform operations (DCT, FFT, geometric transforms)
2
3use crate::{GpuDevice, GpuError, Result};
4use oxifft::Complex;
5
/// Transform operation type.
///
/// `DCT`, `IDCT`, `FFT` and `IFFT` are dispatched by `TransformKernel::execute`
/// on `f32` data. The right-angle rotations, flips and `Transpose` are
/// dispatched by `TransformKernel::execute_u8` on interleaved `u8` pixels.
/// `Affine` and `Perspective` need a matrix, so they have dedicated
/// matrix-taking entry points instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Discrete Cosine Transform (DCT)
    DCT,
    /// Inverse DCT
    IDCT,
    /// Fast Fourier Transform (FFT)
    FFT,
    /// Inverse FFT
    IFFT,
    /// Rotate 90 degrees
    Rotate90,
    /// Rotate 180 degrees
    Rotate180,
    /// Rotate 270 degrees
    Rotate270,
    /// Flip horizontal (mirror left/right)
    FlipHorizontal,
    /// Flip vertical (mirror top/bottom)
    FlipVertical,
    /// Transpose (swap rows and columns)
    Transpose,
    /// Affine transform (requires a 2×3 matrix)
    Affine,
    /// Perspective transform (requires a 3×3 homography)
    Perspective,
}
34
35/// Transform kernel for frequency domain and geometric operations
36pub struct TransformKernel {
37    transform_type: TransformType,
38}
39
40impl TransformKernel {
41    /// Create a new transform kernel
42    #[must_use]
43    pub fn new(transform_type: TransformType) -> Self {
44        Self { transform_type }
45    }
46
47    /// Create a DCT transform kernel
48    #[must_use]
49    pub fn dct() -> Self {
50        Self::new(TransformType::DCT)
51    }
52
53    /// Create an IDCT transform kernel
54    #[must_use]
55    pub fn idct() -> Self {
56        Self::new(TransformType::IDCT)
57    }
58
59    /// Create a rotate kernel
60    #[must_use]
61    pub fn rotate(degrees: i32) -> Self {
62        let transform_type = match degrees % 360 {
63            90 | -270 => TransformType::Rotate90,
64            180 | -180 => TransformType::Rotate180,
65            270 | -90 => TransformType::Rotate270,
66            _ => TransformType::Rotate90, // Default
67        };
68        Self::new(transform_type)
69    }
70
71    /// Create a flip kernel
72    #[must_use]
73    pub fn flip(horizontal: bool) -> Self {
74        let transform_type = if horizontal {
75            TransformType::FlipHorizontal
76        } else {
77            TransformType::FlipVertical
78        };
79        Self::new(transform_type)
80    }
81
82    /// Execute the transform operation (frequency-domain, f32 data).
83    ///
84    /// Handles DCT and IDCT which operate on `f32` frequency-domain data.
85    /// For pixel-level geometric transforms (rotate, flip, transpose) use
86    /// [`TransformKernel::execute_u8`] instead.
87    ///
88    /// # Arguments
89    ///
90    /// * `device` - GPU device
91    /// * `input` - Input data buffer
92    /// * `output` - Output data buffer
93    /// * `width` - Data width
94    /// * `height` - Data height
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the operation fails or is not supported for f32 data.
99    pub fn execute(
100        &self,
101        device: &GpuDevice,
102        input: &[f32],
103        output: &mut [f32],
104        width: u32,
105        height: u32,
106    ) -> Result<()> {
107        match self.transform_type {
108            TransformType::DCT => {
109                crate::ops::TransformOperation::dct_2d(device, input, output, width, height)
110            }
111            TransformType::IDCT => {
112                crate::ops::TransformOperation::idct_2d(device, input, output, width, height)
113            }
114            // FFT/IFFT: use execute_fft_f32 / execute_ifft_f32 directly.
115            // Affine/Perspective: matrix parameters cannot be passed through the
116            // unit enum variant — use execute_affine_f32 / execute_perspective_f32
117            // directly. These arms return NotSupported to preserve the API contract.
118            TransformType::FFT => self.execute_fft_f32(input, output, width, height),
119            TransformType::IFFT => self.execute_ifft_f32(input, output, width, height),
120            TransformType::Affine => Err(crate::GpuError::NotSupported(
121                "Affine requires a matrix — call execute_affine_f32() directly".to_string(),
122            )),
123            TransformType::Perspective => Err(crate::GpuError::NotSupported(
124                "Perspective requires a matrix — call execute_perspective_f32() directly"
125                    .to_string(),
126            )),
127            _ => Err(crate::GpuError::NotSupported(format!(
128                "Transform type {:?} requires u8 pixel data — use execute_u8()",
129                self.transform_type
130            ))),
131        }
132    }
133
134    /// Execute a geometric pixel transform on an interleaved `u8` image buffer.
135    ///
136    /// Handles `Rotate90`, `Rotate180`, `Rotate270`, `FlipHorizontal`,
137    /// `FlipVertical`, and `Transpose`.  `FFT`, `IFFT`, `Affine`, and
138    /// `Perspective` are deliberately left as `NotSupported`.
139    ///
140    /// The `_device` parameter is accepted for API symmetry but is not used
141    /// by the CPU-side implementations (the geometric ops are fully pure-Rust).
142    ///
143    /// # Arguments
144    ///
145    /// * `_device` - GPU device (unused; present for API consistency)
146    /// * `input` - Input pixel buffer (`width * height * channels` bytes)
147    /// * `width` - Image width in pixels
148    /// * `height` - Image height in pixels
149    /// * `channels` - Bytes per pixel (e.g. 3 for RGB, 4 for RGBA)
150    ///
151    /// # Errors
152    ///
153    /// Returns [`crate::GpuError::NotSupported`] for frequency-domain,
154    /// `Affine`, and `Perspective` transform types.
155    pub fn execute_u8(
156        &self,
157        _device: &GpuDevice,
158        input: &[u8],
159        width: u32,
160        height: u32,
161        channels: u32,
162    ) -> Result<Vec<u8>> {
163        match self.transform_type {
164            TransformType::Rotate90 => Ok(crate::ops::TransformOperation::rotate90(
165                input, width, height, channels,
166            )),
167            TransformType::Rotate180 => Ok(crate::ops::TransformOperation::rotate180(
168                input, width, height, channels,
169            )),
170            TransformType::Rotate270 => Ok(crate::ops::TransformOperation::rotate270(
171                input, width, height, channels,
172            )),
173            TransformType::FlipHorizontal => Ok(crate::ops::TransformOperation::flip_horizontal(
174                input, width, height, channels,
175            )),
176            TransformType::FlipVertical => Ok(crate::ops::TransformOperation::flip_vertical(
177                input, width, height, channels,
178            )),
179            TransformType::Transpose => Ok(crate::ops::TransformOperation::transpose(
180                input, width, height, channels,
181            )),
182            TransformType::FFT | TransformType::IFFT => Err(crate::GpuError::NotSupported(
183                "FFT/IFFT operates on f32 data — use execute()".to_string(),
184            )),
185            // Affine/Perspective u8: matrix parameters cannot be passed through the
186            // unit enum variant — use execute_affine_u8 / execute_perspective_u8
187            // directly. These arms return NotSupported to preserve the API contract.
188            TransformType::Affine => Err(crate::GpuError::NotSupported(
189                "Affine requires a matrix — call execute_affine_u8() directly".to_string(),
190            )),
191            TransformType::Perspective => Err(crate::GpuError::NotSupported(
192                "Perspective requires a matrix — call execute_perspective_u8() directly"
193                    .to_string(),
194            )),
195            TransformType::DCT | TransformType::IDCT => {
196                Err(crate::GpuError::NotSupported(format!(
197                    "Transform type {:?} operates on f32 data — use execute()",
198                    self.transform_type
199                )))
200            }
201        }
202    }
203
204    /// Get the transform type
205    #[must_use]
206    pub fn transform_type(&self) -> TransformType {
207        self.transform_type
208    }
209
210    /// Check if this is a frequency domain transform
211    #[must_use]
212    pub fn is_frequency_domain(&self) -> bool {
213        matches!(
214            self.transform_type,
215            TransformType::DCT | TransformType::IDCT | TransformType::FFT | TransformType::IFFT
216        )
217    }
218
219    /// Check if this is a geometric transform
220    #[must_use]
221    pub fn is_geometric(&self) -> bool {
222        matches!(
223            self.transform_type,
224            TransformType::Rotate90
225                | TransformType::Rotate180
226                | TransformType::Rotate270
227                | TransformType::FlipHorizontal
228                | TransformType::FlipVertical
229                | TransformType::Transpose
230                | TransformType::Affine
231                | TransformType::Perspective
232        )
233    }
234
235    /// Estimate FLOPS for the transform operation
236    #[must_use]
237    pub fn estimate_flops(width: u32, height: u32, transform_type: TransformType) -> u64 {
238        let n = u64::from(width) * u64::from(height);
239
240        match transform_type {
241            TransformType::DCT | TransformType::IDCT => {
242                // DCT complexity: O(N^2 log N) for 2D
243                let log_n = (n as f64).log2().ceil() as u64;
244                n * n * log_n
245            }
246            TransformType::FFT | TransformType::IFFT => {
247                // FFT complexity: O(N log N)
248                let log_n = (n as f64).log2().ceil() as u64;
249                n * log_n * 5 // 5 ops per butterfly
250            }
251            _ => {
252                // Geometric transforms: O(N)
253                n
254            }
255        }
256    }
257
258    // -------------------------------------------------------------------------
259    // Affine transform — f32 path (CPU fallback, inverse-mapped nearest-neighbour)
260    // -------------------------------------------------------------------------
261
262    /// Apply a 2D affine transform to a packed f32 scalar image.
263    ///
264    /// Matrix layout: `[a, b, c, d, tx, ty]` such that the *forward* mapping
265    /// `(x', y') = ([a b; c d] · [x; y]) + [tx; ty]` is inverted before use.
266    /// For each output pixel `(ox, oy)` the inverse transform finds the source
267    /// coordinate `(sx, sy)` and performs nearest-neighbour sampling (clamped to
268    /// border).
269    ///
270    /// # Arguments
271    ///
272    /// * `input`  – f32 buffer of `width * height` samples
273    /// * `output` – f32 buffer of `width * height` samples (same size as input)
274    /// * `width`  – image width in pixels
275    /// * `height` – image height in pixels
276    /// * `matrix` – forward affine matrix `[a, b, c, d, tx, ty]`
277    ///
278    /// # Errors
279    ///
280    /// Returns [`GpuError::InvalidBufferSize`] if buffers are too small, or
281    /// [`GpuError::Internal`] if the matrix is singular (det ≈ 0).
282    pub fn execute_affine_f32(
283        &self,
284        input: &[f32],
285        output: &mut [f32],
286        width: u32,
287        height: u32,
288        matrix: [f32; 6],
289    ) -> Result<()> {
290        let expected = (width * height) as usize;
291        if input.len() < expected {
292            return Err(GpuError::InvalidBufferSize {
293                expected,
294                actual: input.len(),
295            });
296        }
297        if output.len() < expected {
298            return Err(GpuError::InvalidBufferSize {
299                expected,
300                actual: output.len(),
301            });
302        }
303
304        // Forward matrix [a, b, c, d, tx, ty]
305        let a = matrix[0];
306        let b = matrix[1];
307        let c = matrix[2];
308        let d = matrix[3];
309        let tx = matrix[4];
310        let ty = matrix[5];
311
312        let det = a * d - b * c;
313        if det.abs() < f32::EPSILON {
314            return Err(GpuError::Internal("Affine matrix is singular".to_string()));
315        }
316
317        // Inverse 2×2 part
318        let inv_det = 1.0 / det;
319        let ia = d * inv_det;
320        let ib = -b * inv_det;
321        let ic = -c * inv_det;
322        let id = a * inv_det;
323        // Inverse translation: inv_t = -M_inv * t
324        let itx = -(ia * tx + ib * ty);
325        let ity = -(ic * tx + id * ty);
326
327        let w = width as i32;
328        let h = height as i32;
329
330        for oy in 0..height {
331            for ox in 0..width {
332                let fx = ox as f32;
333                let fy = oy as f32;
334                let sx = ia * fx + ib * fy + itx;
335                let sy = ic * fx + id * fy + ity;
336                let ix = (sx.floor() as i32).clamp(0, w - 1) as u32;
337                let iy = (sy.floor() as i32).clamp(0, h - 1) as u32;
338                let out_idx = (oy * width + ox) as usize;
339                let in_idx = (iy * width + ix) as usize;
340                output[out_idx] = input[in_idx];
341            }
342        }
343
344        Ok(())
345    }
346
347    // -------------------------------------------------------------------------
348    // Affine transform — u8 path (delegates to f32)
349    // -------------------------------------------------------------------------
350
351    /// Apply a 2D affine transform to a packed `u8` image (any number of
352    /// channels per pixel).
353    ///
354    /// Each pixel is treated as `channels` consecutive bytes. The geometric
355    /// mapping is computed in f32 (inverse affine, nearest-neighbour sampling).
356    ///
357    /// # Arguments
358    ///
359    /// * `input`    – u8 buffer of `width * height * channels` bytes
360    /// * `output`   – u8 buffer of `width * height * channels` bytes
361    /// * `width`    – image width in pixels
362    /// * `height`   – image height in pixels
363    /// * `channels` – bytes per pixel (e.g. 3 for RGB, 4 for RGBA)
364    /// * `matrix`   – forward affine matrix `[a, b, c, d, tx, ty]`
365    ///
366    /// # Errors
367    ///
368    /// Returns an error if buffers are too small or the matrix is singular.
369    pub fn execute_affine_u8(
370        &self,
371        input: &[u8],
372        output: &mut [u8],
373        width: u32,
374        height: u32,
375        channels: u32,
376        matrix: [f32; 6],
377    ) -> Result<()> {
378        let expected = (width * height * channels) as usize;
379        if input.len() < expected {
380            return Err(GpuError::InvalidBufferSize {
381                expected,
382                actual: input.len(),
383            });
384        }
385        if output.len() < expected {
386            return Err(GpuError::InvalidBufferSize {
387                expected,
388                actual: output.len(),
389            });
390        }
391
392        let a = matrix[0];
393        let b = matrix[1];
394        let c = matrix[2];
395        let d = matrix[3];
396        let tx = matrix[4];
397        let ty = matrix[5];
398
399        let det = a * d - b * c;
400        if det.abs() < f32::EPSILON {
401            return Err(GpuError::Internal("Affine matrix is singular".to_string()));
402        }
403
404        let inv_det = 1.0 / det;
405        let ia = d * inv_det;
406        let ib = -b * inv_det;
407        let ic = -c * inv_det;
408        let id = a * inv_det;
409        let itx = -(ia * tx + ib * ty);
410        let ity = -(ic * tx + id * ty);
411
412        let w = width as i32;
413        let h = height as i32;
414        let ch = channels as usize;
415
416        for oy in 0..height {
417            for ox in 0..width {
418                let fx = ox as f32;
419                let fy = oy as f32;
420                let sx = ia * fx + ib * fy + itx;
421                let sy = ic * fx + id * fy + ity;
422                let ix = (sx.floor() as i32).clamp(0, w - 1) as u32;
423                let iy = (sy.floor() as i32).clamp(0, h - 1) as u32;
424                let out_off = ((oy * width + ox) as usize) * ch;
425                let in_off = ((iy * width + ix) as usize) * ch;
426                output[out_off..out_off + ch].copy_from_slice(&input[in_off..in_off + ch]);
427            }
428        }
429
430        Ok(())
431    }
432
433    // -------------------------------------------------------------------------
434    // Perspective transform — f32 path (CPU, inverse-mapped nearest-neighbour)
435    // -------------------------------------------------------------------------
436
437    /// Apply a 2D perspective (homography) transform to a packed f32 scalar image.
438    ///
439    /// `matrix` is a row-major 3×3 homography `H` stored as 9 f32 values.  For
440    /// each output pixel `(ox, oy)` the inverse `H⁻¹` maps it back to the source
441    /// coordinate `(sx/w, sy/w)`, which is sampled with clamped nearest-neighbour.
442    ///
443    /// # Arguments
444    ///
445    /// * `input`  – f32 buffer of `width * height` samples
446    /// * `output` – f32 buffer of `width * height` samples
447    /// * `width`  – image width
448    /// * `height` – image height
449    /// * `matrix` – 3×3 row-major homography `[h00,h01,h02, h10,h11,h12, h20,h21,h22]`
450    ///
451    /// # Errors
452    ///
453    /// Returns [`GpuError::Internal`] if the matrix is singular.
454    pub fn execute_perspective_f32(
455        &self,
456        input: &[f32],
457        output: &mut [f32],
458        width: u32,
459        height: u32,
460        matrix: [f32; 9],
461    ) -> Result<()> {
462        let expected = (width * height) as usize;
463        if input.len() < expected {
464            return Err(GpuError::InvalidBufferSize {
465                expected,
466                actual: input.len(),
467            });
468        }
469        if output.len() < expected {
470            return Err(GpuError::InvalidBufferSize {
471                expected,
472                actual: output.len(),
473            });
474        }
475
476        // Compute 3×3 inverse using cofactor expansion (f64 for precision).
477        let m = matrix.map(|v| v as f64);
478        let det = m[0] * (m[4] * m[8] - m[5] * m[7]) - m[1] * (m[3] * m[8] - m[5] * m[6])
479            + m[2] * (m[3] * m[7] - m[4] * m[6]);
480
481        if det.abs() < 1e-12 {
482            return Err(GpuError::Internal(
483                "Perspective matrix is singular".to_string(),
484            ));
485        }
486
487        let inv_det = 1.0 / det;
488        let inv: [f64; 9] = [
489            (m[4] * m[8] - m[5] * m[7]) * inv_det,
490            (m[2] * m[7] - m[1] * m[8]) * inv_det,
491            (m[1] * m[5] - m[2] * m[4]) * inv_det,
492            (m[5] * m[6] - m[3] * m[8]) * inv_det,
493            (m[0] * m[8] - m[2] * m[6]) * inv_det,
494            (m[2] * m[3] - m[0] * m[5]) * inv_det,
495            (m[3] * m[7] - m[4] * m[6]) * inv_det,
496            (m[1] * m[6] - m[0] * m[7]) * inv_det,
497            (m[0] * m[4] - m[1] * m[3]) * inv_det,
498        ];
499
500        let w = width as i32;
501        let h = height as i32;
502
503        for oy in 0..height {
504            for ox in 0..width {
505                let x = ox as f64;
506                let y = oy as f64;
507                let xh = inv[0] * x + inv[1] * y + inv[2];
508                let yh = inv[3] * x + inv[4] * y + inv[5];
509                let wh = inv[6] * x + inv[7] * y + inv[8];
510                if wh.abs() < 1e-12 {
511                    // Maps to infinity; use border pixel.
512                    output[(oy * width + ox) as usize] = input[0];
513                    continue;
514                }
515                let sx = (xh / wh).round() as i32;
516                let sy = (yh / wh).round() as i32;
517                let ix = sx.clamp(0, w - 1) as u32;
518                let iy = sy.clamp(0, h - 1) as u32;
519                let out_idx = (oy * width + ox) as usize;
520                let in_idx = (iy * width + ix) as usize;
521                output[out_idx] = input[in_idx];
522            }
523        }
524
525        Ok(())
526    }
527
528    // -------------------------------------------------------------------------
529    // Perspective transform — u8 path
530    // -------------------------------------------------------------------------
531
532    /// Apply a 2D perspective transform to a packed u8 image.
533    ///
534    /// Same geometry as [`Self::execute_perspective_f32`] but operates on multi-channel
535    /// u8 pixel data. `channels` is the number of bytes per pixel.
536    ///
537    /// # Errors
538    ///
539    /// Returns an error if buffers are too small or the matrix is singular.
540    pub fn execute_perspective_u8(
541        &self,
542        input: &[u8],
543        output: &mut [u8],
544        width: u32,
545        height: u32,
546        channels: u32,
547        matrix: [f32; 9],
548    ) -> Result<()> {
549        let expected = (width * height * channels) as usize;
550        if input.len() < expected {
551            return Err(GpuError::InvalidBufferSize {
552                expected,
553                actual: input.len(),
554            });
555        }
556        if output.len() < expected {
557            return Err(GpuError::InvalidBufferSize {
558                expected,
559                actual: output.len(),
560            });
561        }
562
563        let m = matrix.map(|v| v as f64);
564        let det = m[0] * (m[4] * m[8] - m[5] * m[7]) - m[1] * (m[3] * m[8] - m[5] * m[6])
565            + m[2] * (m[3] * m[7] - m[4] * m[6]);
566
567        if det.abs() < 1e-12 {
568            return Err(GpuError::Internal(
569                "Perspective matrix is singular".to_string(),
570            ));
571        }
572
573        let inv_det = 1.0 / det;
574        let inv: [f64; 9] = [
575            (m[4] * m[8] - m[5] * m[7]) * inv_det,
576            (m[2] * m[7] - m[1] * m[8]) * inv_det,
577            (m[1] * m[5] - m[2] * m[4]) * inv_det,
578            (m[5] * m[6] - m[3] * m[8]) * inv_det,
579            (m[0] * m[8] - m[2] * m[6]) * inv_det,
580            (m[2] * m[3] - m[0] * m[5]) * inv_det,
581            (m[3] * m[7] - m[4] * m[6]) * inv_det,
582            (m[1] * m[6] - m[0] * m[7]) * inv_det,
583            (m[0] * m[4] - m[1] * m[3]) * inv_det,
584        ];
585
586        let iw = width as i32;
587        let ih = height as i32;
588        let ch = channels as usize;
589
590        for oy in 0..height {
591            for ox in 0..width {
592                let x = ox as f64;
593                let y = oy as f64;
594                let xh = inv[0] * x + inv[1] * y + inv[2];
595                let yh = inv[3] * x + inv[4] * y + inv[5];
596                let wh = inv[6] * x + inv[7] * y + inv[8];
597                let (ix, iy) = if wh.abs() < 1e-12 {
598                    (0u32, 0u32)
599                } else {
600                    let sx = (xh / wh).round() as i32;
601                    let sy = (yh / wh).round() as i32;
602                    (sx.clamp(0, iw - 1) as u32, sy.clamp(0, ih - 1) as u32)
603                };
604                let out_off = ((oy * width + ox) as usize) * ch;
605                let in_off = ((iy * width + ix) as usize) * ch;
606                output[out_off..out_off + ch].copy_from_slice(&input[in_off..in_off + ch]);
607            }
608        }
609
610        Ok(())
611    }
612
613    // -------------------------------------------------------------------------
614    // FFT / IFFT — f32 path (CPU, 2D separable via OxiFFT)
615    // -------------------------------------------------------------------------
616
617    /// Compute a 2D forward FFT of an f32 scalar image via row-column separation.
618    ///
619    /// Input samples are treated as real values; adjacent pairs `(input[2k],
620    /// input[2k+1])` are **not** used as complex pairs — instead each f32 sample
621    /// is promoted to a complex number with imaginary part 0 before the 1D FFT.
622    ///
623    /// The result is stored interleaved: `output[2*k] = re`, `output[2*k+1] = im`
624    /// for each complex output coefficient.  Therefore `output` must have at least
625    /// `2 * width * height` elements.
626    ///
627    /// The 2D FFT is computed as 1D FFT of every row followed by 1D FFT of every
628    /// column (separability property).
629    ///
630    /// # Errors
631    ///
632    /// Returns [`GpuError::InvalidBufferSize`] if buffers are undersized.
633    pub fn execute_fft_f32(
634        &self,
635        input: &[f32],
636        output: &mut [f32],
637        width: u32,
638        height: u32,
639    ) -> Result<()> {
640        let n = (width * height) as usize;
641        if input.len() < n {
642            return Err(GpuError::InvalidBufferSize {
643                expected: n,
644                actual: input.len(),
645            });
646        }
647        // Output is interleaved complex: need 2*n f32 slots.
648        let out_needed = 2 * n;
649        if output.len() < out_needed {
650            return Err(GpuError::InvalidBufferSize {
651                expected: out_needed,
652                actual: output.len(),
653            });
654        }
655
656        let w = width as usize;
657        let h = height as usize;
658
659        // Build complex working buffer: real input → complex with im=0.
660        let mut work: Vec<Complex<f64>> = input[..n]
661            .iter()
662            .map(|&v| Complex::new(v as f64, 0.0))
663            .collect();
664
665        // 1D FFT of each row.
666        for row in 0..h {
667            let start = row * w;
668            let row_slice: Vec<Complex<f64>> = work[start..start + w].to_vec();
669            let row_fft = oxifft::fft(&row_slice);
670            work[start..start + w].copy_from_slice(&row_fft);
671        }
672
673        // 1D FFT of each column.
674        let mut col_buf = vec![Complex::new(0.0f64, 0.0); h];
675        for col in 0..w {
676            for row in 0..h {
677                col_buf[row] = work[row * w + col];
678            }
679            let col_fft = oxifft::fft(&col_buf);
680            for row in 0..h {
681                work[row * w + col] = col_fft[row];
682            }
683        }
684
685        // Store interleaved (re, im) into output.
686        for (k, c) in work.iter().enumerate() {
687            output[2 * k] = c.re as f32;
688            output[2 * k + 1] = c.im as f32;
689        }
690
691        Ok(())
692    }
693
694    /// Compute a 2D inverse FFT of a packed complex f32 buffer.
695    ///
696    /// Input format: interleaved complex `input[2*k] = re`, `input[2*k+1] = im`.
697    /// Output format: interleaved complex (same layout as [`Self::execute_fft_f32`]).
698    ///
699    /// The IFFT is computed as: conjugate → forward FFT → conjugate → divide by N.
700    ///
701    /// # Errors
702    ///
703    /// Returns [`GpuError::InvalidBufferSize`] if buffers are undersized.
704    pub fn execute_ifft_f32(
705        &self,
706        input: &[f32],
707        output: &mut [f32],
708        width: u32,
709        height: u32,
710    ) -> Result<()> {
711        let n = (width * height) as usize;
712        let in_needed = 2 * n;
713        if input.len() < in_needed {
714            return Err(GpuError::InvalidBufferSize {
715                expected: in_needed,
716                actual: input.len(),
717            });
718        }
719        let out_needed = 2 * n;
720        if output.len() < out_needed {
721            return Err(GpuError::InvalidBufferSize {
722                expected: out_needed,
723                actual: output.len(),
724            });
725        }
726
727        let w = width as usize;
728        let h = height as usize;
729
730        // Read interleaved complex and conjugate.
731        let mut work: Vec<Complex<f64>> = (0..n)
732            .map(|k| Complex::new(input[2 * k] as f64, -(input[2 * k + 1] as f64)))
733            .collect();
734
735        // Row FFTs.
736        for row in 0..h {
737            let start = row * w;
738            let row_slice: Vec<Complex<f64>> = work[start..start + w].to_vec();
739            let row_fft = oxifft::fft(&row_slice);
740            work[start..start + w].copy_from_slice(&row_fft);
741        }
742
743        // Column FFTs.
744        let mut col_buf = vec![Complex::new(0.0f64, 0.0); h];
745        for col in 0..w {
746            for row in 0..h {
747                col_buf[row] = work[row * w + col];
748            }
749            let col_fft = oxifft::fft(&col_buf);
750            for row in 0..h {
751                work[row * w + col] = col_fft[row];
752            }
753        }
754
755        // Conjugate and divide by N.
756        let scale = 1.0 / n as f64;
757        for (k, c) in work.iter().enumerate() {
758            output[2 * k] = (c.re * scale) as f32;
759            output[2 * k + 1] = (-c.im * scale) as f32;
760        }
761
762        Ok(())
763    }
764}
765
/// Affine transformation matrix.
///
/// The elements are the top two rows, stored row-major, of a 3×3 matrix that
/// acts on column vectors `[x, y, 1]ᵀ`:
///
/// ```text
/// [ a  b  tx ]
/// [ c  d  ty ]
/// [ 0  0  1  ]
/// ```
///
/// In other words, `elements = [a, b, tx, c, d, ty]`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AffineMatrix {
    /// Matrix elements `[a, b, tx, c, d, ty]` (row-major, see type docs).
    pub elements: [f32; 6],
}

impl AffineMatrix {
    /// Create an identity matrix.
    #[must_use]
    pub fn identity() -> Self {
        Self {
            elements: [1.0, 0.0, 0.0, 0.0, 1.0, 0.0],
        }
    }

    /// Create a translation matrix moving points by `(tx, ty)`.
    #[must_use]
    pub fn translation(tx: f32, ty: f32) -> Self {
        Self {
            elements: [1.0, 0.0, tx, 0.0, 1.0, ty],
        }
    }

    /// Create a rotation matrix `[cos -sin; sin cos]` about the origin.
    #[must_use]
    pub fn rotation(angle_radians: f32) -> Self {
        let cos = angle_radians.cos();
        let sin = angle_radians.sin();
        Self {
            elements: [cos, -sin, 0.0, sin, cos, 0.0],
        }
    }

    /// Create a scaling matrix about the origin.
    #[must_use]
    pub fn scaling(sx: f32, sy: f32) -> Self {
        Self {
            elements: [sx, 0.0, 0.0, 0.0, sy, 0.0],
        }
    }

    /// Combine two affine transformations as the matrix product `self · other`.
    ///
    /// Applying the result to a point applies `other` first, then `self`.
    #[must_use]
    pub fn combine(&self, other: &Self) -> Self {
        let [a1, b1, tx1, c1, d1, ty1] = self.elements;
        let [a2, b2, tx2, c2, d2, ty2] = other.elements;

        Self {
            elements: [
                a1 * a2 + b1 * c2,
                a1 * b2 + b1 * d2,
                a1 * tx2 + b1 * ty2 + tx1,
                c1 * a2 + d1 * c2,
                c1 * b2 + d1 * d2,
                c1 * tx2 + d1 * ty2 + ty1,
            ],
        }
    }

    /// Get the matrix elements (`[a, b, tx, c, d, ty]`).
    #[must_use]
    pub fn as_array(&self) -> [f32; 6] {
        self.elements
    }

    /// Map the point `(x, y)` through this transform.
    #[must_use]
    pub fn transform_point(&self, x: f32, y: f32) -> (f32, f32) {
        let [a, b, tx, c, d, ty] = self.elements;
        (a * x + b * y + tx, c * x + d * y + ty)
    }

    /// Return the elements reordered to `[a, b, c, d, tx, ty]`.
    ///
    /// `TransformKernel::execute_affine_f32` / `execute_affine_u8` take their
    /// matrix in this order, which differs from [`Self::as_array`].
    #[must_use]
    pub fn to_kernel_array(&self) -> [f32; 6] {
        let [a, b, tx, c, d, ty] = self.elements;
        [a, b, c, d, tx, ty]
    }
}

impl Default for AffineMatrix {
    fn default() -> Self {
        Self::identity()
    }
}

/// Warp kernel for geometric transformations, described by an [`AffineMatrix`].
#[derive(Debug, Clone, Copy)]
pub struct WarpKernel {
    /// Forward transform applied by this warp.
    matrix: AffineMatrix,
}

impl WarpKernel {
    /// Create a new warp kernel from an explicit matrix.
    #[must_use]
    pub fn new(matrix: AffineMatrix) -> Self {
        Self { matrix }
    }

    /// Create a warp rotating by `angle_degrees` about `(center_x, center_y)`.
    ///
    /// The centre point is left fixed.
    #[must_use]
    pub fn rotation(angle_degrees: f32, center_x: f32, center_y: f32) -> Self {
        let to_origin = AffineMatrix::translation(-center_x, -center_y);
        let rotate = AffineMatrix::rotation(angle_degrees.to_radians());
        let back = AffineMatrix::translation(center_x, center_y);

        // `combine` is `self · other`, so the right-most factor applies first:
        // move the centre to the origin, rotate, then move it back.
        Self::new(back.combine(&rotate).combine(&to_origin))
    }

    /// Create a warp scaling by `(sx, sy)` about `(center_x, center_y)`.
    ///
    /// The centre point is left fixed.
    #[must_use]
    pub fn scaling(sx: f32, sy: f32, center_x: f32, center_y: f32) -> Self {
        let to_origin = AffineMatrix::translation(-center_x, -center_y);
        let scale = AffineMatrix::scaling(sx, sy);
        let back = AffineMatrix::translation(center_x, center_y);

        // Right-most factor applies first (see `rotation`).
        Self::new(back.combine(&scale).combine(&to_origin))
    }

    /// Get the transformation matrix.
    #[must_use]
    pub fn matrix(&self) -> &AffineMatrix {
        &self.matrix
    }
}
887
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transform_kernel_creation() {
        let kernel = TransformKernel::dct();
        assert_eq!(kernel.transform_type(), TransformType::DCT);
        assert!(kernel.is_frequency_domain());
        assert!(!kernel.is_geometric());

        let kernel = TransformKernel::rotate(90);
        assert_eq!(kernel.transform_type(), TransformType::Rotate90);
        assert!(!kernel.is_frequency_domain());
        assert!(kernel.is_geometric());
    }

    #[test]
    fn test_affine_matrix_identity() {
        let identity = AffineMatrix::identity();
        let elements = identity.as_array();
        assert_eq!(elements, [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]);
    }

    #[test]
    fn test_affine_matrix_default_is_identity() {
        assert_eq!(
            AffineMatrix::default().as_array(),
            AffineMatrix::identity().as_array()
        );
    }

    #[test]
    fn test_affine_matrix_translation() {
        let trans = AffineMatrix::translation(10.0, 20.0);
        let elements = trans.as_array();
        assert_eq!(elements[2], 10.0);
        assert_eq!(elements[5], 20.0);
    }

    #[test]
    fn test_affine_matrix_scaling() {
        let scale = AffineMatrix::scaling(2.0, 3.0);
        let elements = scale.as_array();
        assert_eq!(elements[0], 2.0);
        assert_eq!(elements[4], 3.0);
    }

    /// `a.combine(&b)` is the matrix product `a × b`, so `b` is applied first.
    /// Pin both orders so a silent change in composition order is caught.
    #[test]
    fn test_affine_matrix_combination() {
        let t1 = AffineMatrix::translation(10.0, 20.0);
        let s = AffineMatrix::scaling(2.0, 2.0);

        // Scale first, then translate: the translation is left untouched.
        let combined = t1.combine(&s);
        assert_eq!(combined.as_array(), [2.0, 0.0, 10.0, 0.0, 2.0, 20.0]);

        // Translate first, then scale: the translation is scaled as well.
        let reversed = s.combine(&t1);
        assert_eq!(reversed.as_array(), [2.0, 0.0, 20.0, 0.0, 2.0, 40.0]);
    }

    #[test]
    fn test_warp_kernel_new_keeps_matrix() {
        let m = AffineMatrix::translation(3.0, -4.0);
        let warp = WarpKernel::new(AffineMatrix::translation(3.0, -4.0));
        assert_eq!(warp.matrix().as_array(), m.as_array());
    }

    #[test]
    fn test_flops_estimation() {
        let flops_dct = TransformKernel::estimate_flops(64, 64, TransformType::DCT);
        let flops_rotate = TransformKernel::estimate_flops(64, 64, TransformType::Rotate90);

        assert!(flops_dct > 0);
        assert!(flops_rotate > 0);
        assert!(flops_dct > flops_rotate); // DCT should be more expensive
    }

    // -------------------------------------------------------------------------
    // New tests for affine, perspective, FFT, IFFT
    // -------------------------------------------------------------------------

    /// Identity affine: output must equal input.
    #[test]
    fn test_affine_identity() {
        let kernel = TransformKernel::new(TransformType::Affine);
        let width = 4u32;
        let height = 4u32;
        let input: Vec<f32> = (0..(width * height)).map(|i| i as f32).collect();
        let mut output = vec![0.0f32; (width * height) as usize];
        // Identity forward matrix [a=1, b=0, c=0, d=1, tx=0, ty=0]
        let identity = [1.0f32, 0.0, 0.0, 1.0, 0.0, 0.0];
        kernel
            .execute_affine_f32(&input, &mut output, width, height, identity)
            .expect("affine identity must succeed");
        assert_eq!(input, output, "identity affine must preserve all values");
    }

    /// Affine with 2× uniform scale: inverse maps output → input by dividing
    /// coordinates by 2, so `output[oy][ox]` comes from `input[oy/2][ox/2]`.
    #[test]
    fn test_affine_scale() {
        let kernel = TransformKernel::new(TransformType::Affine);
        let width = 8u32;
        let height = 8u32;
        // Assign unique values: pixel (x, y) = y * width + x as f32.
        let input: Vec<f32> = (0..(width * height)).map(|i| i as f32).collect();
        let mut output = vec![0.0f32; (width * height) as usize];
        // Forward 2× scale: [a=2, b=0, c=0, d=2, tx=0, ty=0]
        // Inverse: maps (ox, oy) → (ox/2, oy/2)
        let mat = [2.0f32, 0.0, 0.0, 2.0, 0.0, 0.0];
        kernel
            .execute_affine_f32(&input, &mut output, width, height, mat)
            .expect("affine 2x scale must succeed");
        // Check several output pixels: output[oy*w+ox] == input[(oy/2)*w + (ox/2)]
        for oy in 0..height {
            for ox in 0..width {
                let expected = input[((oy / 2) * width + (ox / 2)) as usize];
                let got = output[(oy * width + ox) as usize];
                assert!(
                    (got - expected).abs() < 1e-5,
                    "output[{oy}][{ox}]={got}, expected {expected}"
                );
            }
        }
    }

    /// Singular affine matrix must return an error.
    #[test]
    fn test_affine_singular_returns_error() {
        let kernel = TransformKernel::new(TransformType::Affine);
        let width = 4u32;
        let height = 4u32;
        let input = vec![1.0f32; (width * height) as usize];
        let mut output = vec![0.0f32; (width * height) as usize];
        // Singular: det = 0*0 - 0*0 = 0
        let singular = [0.0f32, 0.0, 0.0, 0.0, 0.0, 0.0];
        let result = kernel.execute_affine_f32(&input, &mut output, width, height, singular);
        assert!(result.is_err(), "singular affine must return error");
    }

    /// FFT of a unit impulse (1, 0, 0, …) should have all output magnitudes = 1.
    #[test]
    fn test_fft_impulse() {
        let kernel = TransformKernel::new(TransformType::FFT);
        let width = 4u32;
        let height = 4u32;
        let n = (width * height) as usize;
        let mut input = vec![0.0f32; n];
        input[0] = 1.0; // unit impulse

        // Output needs 2*n slots (interleaved complex)
        let mut output = vec![0.0f32; 2 * n];
        kernel
            .execute_fft_f32(&input, &mut output, width, height)
            .expect("FFT of impulse must succeed");
        // All magnitudes should be 1.0 (within floating-point tolerance)
        for k in 0..n {
            let re = output[2 * k] as f64;
            let im = output[2 * k + 1] as f64;
            let mag = re.hypot(im);
            assert!(
                (mag - 1.0).abs() < 1e-4,
                "FFT[{k}] magnitude={mag:.6}, expected 1.0"
            );
        }
    }

    /// FFT followed by IFFT must recover the original signal.
    #[test]
    fn test_fft_ifft_roundtrip() {
        let kernel = TransformKernel::new(TransformType::FFT);
        let width = 4u32;
        let height = 4u32;
        let n = (width * height) as usize;
        let input: Vec<f32> = (0..n).map(|i| i as f32 * 0.1).collect();
        let mut freq = vec![0.0f32; 2 * n];
        kernel
            .execute_fft_f32(&input, &mut freq, width, height)
            .expect("FFT must succeed");
        let mut recovered = vec![0.0f32; 2 * n];
        kernel
            .execute_ifft_f32(&freq, &mut recovered, width, height)
            .expect("IFFT must succeed");
        // Real parts of recovered should match input; imaginary parts ≈ 0.
        for k in 0..n {
            let diff = (recovered[2 * k] - input[k]).abs();
            assert!(
                diff < 1e-4,
                "IFFT roundtrip: idx={k} expected={:.4}, got={:.4}",
                input[k],
                recovered[2 * k]
            );
        }
    }

    /// Perspective with identity homography should be a no-op.
    #[test]
    fn test_perspective_identity_f32() {
        let kernel = TransformKernel::new(TransformType::Perspective);
        let width = 4u32;
        let height = 4u32;
        let input: Vec<f32> = (0..(width * height)).map(|i| i as f32).collect();
        let mut output = vec![0.0f32; (width * height) as usize];
        // Identity 3×3 homography
        let identity = [1.0f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        kernel
            .execute_perspective_f32(&input, &mut output, width, height, identity)
            .expect("perspective identity must succeed");
        assert_eq!(input, output, "identity perspective must preserve values");
    }

    /// Singular perspective matrix must return an error.
    #[test]
    fn test_perspective_singular() {
        let kernel = TransformKernel::new(TransformType::Perspective);
        let width = 4u32;
        let height = 4u32;
        let input = vec![1.0f32; (width * height) as usize];
        let mut output = vec![0.0f32; (width * height) as usize];
        // All-zero matrix is singular.
        let singular = [0.0f32; 9];
        let result = kernel.execute_perspective_f32(&input, &mut output, width, height, singular);
        assert!(result.is_err(), "singular perspective must return error");
    }

    /// Affine u8 identity: output must equal input.
    #[test]
    fn test_affine_u8_identity() {
        let kernel = TransformKernel::new(TransformType::Affine);
        let width = 4u32;
        let height = 4u32;
        let channels = 3u32;
        let input: Vec<u8> = (0..(width * height * channels) as usize)
            .map(|i| (i % 256) as u8)
            .collect();
        let mut output = vec![0u8; input.len()];
        let identity = [1.0f32, 0.0, 0.0, 1.0, 0.0, 0.0];
        kernel
            .execute_affine_u8(&input, &mut output, width, height, channels, identity)
            .expect("affine u8 identity must succeed");
        assert_eq!(input, output, "identity affine u8 must preserve all bytes");
    }
}