Skip to main content

oximedia_gpu/ops/
transform.rs

1//! Transform operations (DCT, FFT) for frequency domain processing
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
13#[repr(C)]
14#[derive(Copy, Clone, Pod, Zeroable)]
15struct TransformParams {
16    width: u32,
17    height: u32,
18    block_size: u32,
19    transform_type: u32,
20    stride: u32,
21    is_inverse: u32,
22    padding1: u32,
23    padding2: u32,
24}
25
/// Transform operations for frequency domain processing (GPU DCT/IDCT) and
/// CPU-side geometric transforms (rotate, flip, transpose).
///
/// This is a stateless namespace type; all operations are associated
/// functions.
#[derive(Debug, Clone, Copy)]
pub struct TransformOperation;
28
29impl TransformOperation {
30    /// Compute 2D DCT (Discrete Cosine Transform)
31    ///
32    /// Computes the forward DCT on 8x8 blocks. Input dimensions must be
33    /// multiples of 8.
34    ///
35    /// # Arguments
36    ///
37    /// * `device` - GPU device
38    /// * `input` - Input data (f32 values)
39    /// * `output` - Output DCT coefficients
40    /// * `width` - Data width (must be multiple of 8)
41    /// * `height` - Data height (must be multiple of 8)
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if dimensions are invalid or if the GPU operation fails.
46    pub fn dct_2d(
47        device: &GpuDevice,
48        input: &[f32],
49        output: &mut [f32],
50        width: u32,
51        height: u32,
52    ) -> Result<()> {
53        if width % 8 != 0 || height % 8 != 0 {
54            return Err(GpuError::InvalidDimensions { width, height });
55        }
56
57        utils::validate_dimensions(width, height)?;
58
59        let expected_size = (width * height) as usize;
60        if input.len() < expected_size || output.len() < expected_size {
61            return Err(GpuError::InvalidBufferSize {
62                expected: expected_size,
63                actual: input.len().min(output.len()),
64            });
65        }
66
67        let pipeline = Self::get_dct_8x8_pipeline(device)?;
68        let layout = Self::get_bind_group_layout(device)?;
69
70        Self::execute_transform(
71            device, pipeline, layout, input, output, width, height, 8, 0, // DCT
72        )
73    }
74
75    /// Compute 2D IDCT (Inverse Discrete Cosine Transform)
76    ///
77    /// Computes the inverse DCT on 8x8 blocks. Input dimensions must be
78    /// multiples of 8.
79    ///
80    /// # Arguments
81    ///
82    /// * `device` - GPU device
83    /// * `input` - Input DCT coefficients
84    /// * `output` - Output reconstructed data
85    /// * `width` - Data width (must be multiple of 8)
86    /// * `height` - Data height (must be multiple of 8)
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if dimensions are invalid or if the GPU operation fails.
91    pub fn idct_2d(
92        device: &GpuDevice,
93        input: &[f32],
94        output: &mut [f32],
95        width: u32,
96        height: u32,
97    ) -> Result<()> {
98        if width % 8 != 0 || height % 8 != 0 {
99            return Err(GpuError::InvalidDimensions { width, height });
100        }
101
102        utils::validate_dimensions(width, height)?;
103
104        let expected_size = (width * height) as usize;
105        if input.len() < expected_size || output.len() < expected_size {
106            return Err(GpuError::InvalidBufferSize {
107                expected: expected_size,
108                actual: input.len().min(output.len()),
109            });
110        }
111
112        let pipeline = Self::get_idct_8x8_pipeline(device)?;
113        let layout = Self::get_bind_group_layout(device)?;
114
115        Self::execute_transform(
116            device, pipeline, layout, input, output, width, height, 8, 1, // IDCT
117        )
118    }
119
120    /// Compute general 2D DCT using row-column decomposition
121    ///
122    /// This method works for any dimensions, not just multiples of 8.
123    ///
124    /// # Arguments
125    ///
126    /// * `device` - GPU device
127    /// * `input` - Input data (f32 values)
128    /// * `output` - Output DCT coefficients
129    /// * `width` - Data width
130    /// * `height` - Data height
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if dimensions are invalid or if the GPU operation fails.
135    pub fn dct_2d_general(
136        device: &GpuDevice,
137        input: &[f32],
138        output: &mut [f32],
139        width: u32,
140        height: u32,
141    ) -> Result<()> {
142        utils::validate_dimensions(width, height)?;
143
144        let expected_size = (width * height) as usize;
145        if input.len() < expected_size || output.len() < expected_size {
146            return Err(GpuError::InvalidBufferSize {
147                expected: expected_size,
148                actual: input.len().min(output.len()),
149            });
150        }
151
152        // Two-pass DCT: row then column
153        let mut temp = vec![0.0f32; expected_size];
154
155        // Row DCT
156        let row_pipeline = Self::get_dct_row_pipeline(device)?;
157        let layout = Self::get_bind_group_layout(device)?;
158
159        Self::execute_transform(
160            device,
161            row_pipeline,
162            layout,
163            input,
164            &mut temp,
165            width,
166            height,
167            width,
168            0,
169        )?;
170
171        // Column DCT
172        let col_pipeline = Self::get_dct_col_pipeline(device)?;
173
174        Self::execute_transform(
175            device,
176            col_pipeline,
177            layout,
178            &temp,
179            output,
180            width,
181            height,
182            height,
183            0,
184        )
185    }
186
187    #[allow(clippy::too_many_arguments)]
188    fn execute_transform(
189        device: &GpuDevice,
190        pipeline: &ComputePipeline,
191        layout: &BindGroupLayout,
192        input: &[f32],
193        output: &mut [f32],
194        width: u32,
195        height: u32,
196        block_size: u32,
197        transform_type: u32,
198    ) -> Result<()> {
199        let input_bytes = bytemuck::cast_slice(input);
200        let output_size = std::mem::size_of_val(output);
201
202        // Create buffers
203        let input_buffer = utils::create_storage_buffer(device, input_bytes.len() as u64)?;
204        let output_buffer = utils::create_storage_buffer(device, output_size as u64)?;
205
206        // Upload input data
207        device
208            .queue()
209            .write_buffer(input_buffer.buffer(), 0, input_bytes);
210
211        // Create uniform buffer for parameters
212        let params = TransformParams {
213            width,
214            height,
215            block_size,
216            transform_type,
217            stride: width,
218            is_inverse: 0,
219            padding1: 0,
220            padding2: 0,
221        };
222        let params_bytes = bytemuck::bytes_of(&params);
223        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
224
225        // Create bind group
226        let compiler = ShaderCompiler::new(device);
227        let bind_group = compiler.create_bind_group(
228            "Transform Bind Group",
229            layout,
230            &[
231                wgpu::BindGroupEntry {
232                    binding: 0,
233                    resource: input_buffer.buffer().as_entire_binding(),
234                },
235                wgpu::BindGroupEntry {
236                    binding: 1,
237                    resource: output_buffer.buffer().as_entire_binding(),
238                },
239                wgpu::BindGroupEntry {
240                    binding: 2,
241                    resource: params_buffer.buffer().as_entire_binding(),
242                },
243            ],
244        );
245
246        // Execute compute pass
247        Self::dispatch_compute(device, pipeline, &bind_group, width, height, block_size)?;
248
249        // Read back results
250        let readback_buffer = utils::create_readback_buffer(device, output_size as u64)?;
251        let mut encoder = device
252            .device()
253            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
254                label: Some("Transform Copy Encoder"),
255            });
256
257        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output_size as u64)?;
258
259        device.queue().submit(Some(encoder.finish()));
260        device.wait();
261
262        let result = readback_buffer.read(device, 0, output_size as u64)?;
263        let result_f32: &[f32] = bytemuck::cast_slice(&result);
264        output.copy_from_slice(result_f32);
265
266        Ok(())
267    }
268
269    fn dispatch_compute(
270        device: &GpuDevice,
271        pipeline: &ComputePipeline,
272        bind_group: &BindGroup,
273        width: u32,
274        height: u32,
275        block_size: u32,
276    ) -> Result<()> {
277        let mut encoder = device
278            .device()
279            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
280                label: Some("Transform Compute Encoder"),
281            });
282
283        {
284            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
285                label: Some("Transform Compute Pass"),
286                timestamp_writes: None,
287            });
288
289            compute_pass.set_pipeline(pipeline);
290            compute_pass.set_bind_group(0, bind_group, &[]);
291
292            if block_size == 8 {
293                // For 8x8 DCT, dispatch one workgroup per block
294                let dispatch_x = width / 8;
295                let dispatch_y = height / 8;
296                compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
297            } else {
298                // For row/column transforms
299                let total_elements = width * height;
300                let dispatch = total_elements.div_ceil(256);
301                compute_pass.dispatch_workgroups(dispatch, 1, 1);
302            }
303        }
304
305        device.queue().submit(Some(encoder.finish()));
306        Ok(())
307    }
308
309    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
310        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
311
312        Ok(LAYOUT.get_or_init(|| {
313            let compiler = ShaderCompiler::new(device);
314            let entries = BindGroupLayoutBuilder::new()
315                .add_storage_buffer_read_only(0) // input
316                .add_storage_buffer(1) // output
317                .add_uniform_buffer(2) // params
318                .build();
319
320            compiler.create_bind_group_layout("Transform Bind Group Layout", &entries)
321        }))
322    }
323
324    fn init_pipeline(
325        device: &GpuDevice,
326        name: &str,
327        entry_point: &str,
328    ) -> std::result::Result<ComputePipeline, String> {
329        let compiler = ShaderCompiler::new(device);
330        let shader = compiler
331            .compile(
332                "Transform Shader",
333                ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
334            )
335            .map_err(|e| format!("Failed to compile transform shader: {e}"))?;
336
337        let layout = Self::get_bind_group_layout(device)
338            .map_err(|e| format!("Failed to create bind group layout: {e}"))?;
339
340        compiler
341            .create_pipeline(name, &shader, entry_point, layout)
342            .map_err(|e| format!("Failed to create pipeline: {e}"))
343    }
344
345    fn get_dct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
346        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
347
348        PIPELINE
349            .get_or_init(|| {
350                TransformOperation::init_pipeline(device, "DCT 8x8 Pipeline", "dct_8x8")
351            })
352            .as_ref()
353            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
354    }
355
356    fn get_idct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
357        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
358
359        PIPELINE
360            .get_or_init(|| {
361                TransformOperation::init_pipeline(device, "IDCT 8x8 Pipeline", "idct_8x8")
362            })
363            .as_ref()
364            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
365    }
366
367    fn get_dct_row_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
368        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
369
370        PIPELINE
371            .get_or_init(|| {
372                TransformOperation::init_pipeline(device, "DCT Row Pipeline", "dct_row")
373            })
374            .as_ref()
375            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
376    }
377
378    fn get_dct_col_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
379        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
380
381        PIPELINE
382            .get_or_init(|| {
383                TransformOperation::init_pipeline(device, "DCT Column Pipeline", "dct_col")
384            })
385            .as_ref()
386            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
387    }
388}
389
390// =============================================================================
391// CPU-side perspective transform and lens distortion correction (Task 8)
392// =============================================================================
393
/// A 3×3 homogeneous perspective (projective) transform matrix stored in
/// row-major order.
///
/// The matrix maps homogeneous image coordinates `(x, y, 1)ᵀ` to new
/// coordinates via `(x', y', w')ᵀ = M · (x, y, 1)ᵀ`.  The Cartesian result
/// is `(x'/w', y'/w')`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct PerspectiveMatrix {
    /// Row-major 3×3 elements: `[[a,b,c],[d,e,f],[g,h,i]]`.
    pub data: [[f64; 3]; 3],
}

impl PerspectiveMatrix {
    /// Create from a flat row-major array of 9 elements.
    #[must_use]
    pub fn from_array(m: [f64; 9]) -> Self {
        Self {
            data: [[m[0], m[1], m[2]], [m[3], m[4], m[5]], [m[6], m[7], m[8]]],
        }
    }

    /// Identity perspective matrix (no transform).
    #[must_use]
    pub fn identity() -> Self {
        Self::from_array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0])
    }

    /// Apply this matrix to a point `(x, y)` and return the projected result.
    ///
    /// Returns `None` if the homogeneous weight `w` is not finite or has a
    /// magnitude below `1e-12` (the point maps to infinity). It also returns
    /// `None` if the projected coordinates are not finite, e.g. because an
    /// input or matrix entry is NaN.
    #[must_use]
    pub fn project(&self, x: f64, y: f64) -> Option<(f64, f64)> {
        let m = &self.data;
        let x_h = m[0][0] * x + m[0][1] * y + m[0][2];
        let y_h = m[1][0] * x + m[1][1] * y + m[1][2];
        let w = m[2][0] * x + m[2][1] * y + m[2][2];
        // `abs() < eps` alone lets NaN through (every NaN comparison is false).
        if !w.is_finite() || w.abs() < 1e-12 {
            return None;
        }
        let (px, py) = (x_h / w, y_h / w);
        (px.is_finite() && py.is_finite()).then_some((px, py))
    }

    /// Compute the inverse of this matrix as `adj(M) / det(M)` (the cofactor /
    /// adjugate method).
    ///
    /// Returns `None` in three cases:
    /// * the determinant has a magnitude below `1e-15` (treated as singular);
    /// * the determinant is not finite;
    /// * any entry of the inverse is not finite.
    #[must_use]
    pub fn inverse(&self) -> Option<Self> {
        let m = &self.data;
        let det = m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
        // Reject NaN/inf explicitly: `NaN.abs() < eps` is false and would
        // otherwise yield an all-NaN "inverse".
        if !det.is_finite() || det.abs() < 1e-15 {
            return None;
        }
        let inv_det = 1.0 / det;
        let inv = [
            [
                (m[1][1] * m[2][2] - m[1][2] * m[2][1]) * inv_det,
                (m[0][2] * m[2][1] - m[0][1] * m[2][2]) * inv_det,
                (m[0][1] * m[1][2] - m[0][2] * m[1][1]) * inv_det,
            ],
            [
                (m[1][2] * m[2][0] - m[1][0] * m[2][2]) * inv_det,
                (m[0][0] * m[2][2] - m[0][2] * m[2][0]) * inv_det,
                (m[0][2] * m[1][0] - m[0][0] * m[1][2]) * inv_det,
            ],
            [
                (m[1][0] * m[2][1] - m[1][1] * m[2][0]) * inv_det,
                (m[0][1] * m[2][0] - m[0][0] * m[2][1]) * inv_det,
                (m[0][0] * m[1][1] - m[0][1] * m[1][0]) * inv_det,
            ],
        ];
        inv.iter()
            .flatten()
            .all(|v| v.is_finite())
            .then_some(Self { data: inv })
    }
}

impl Default for PerspectiveMatrix {
    /// The default matrix is [`PerspectiveMatrix::identity`].
    fn default() -> Self {
        Self::identity()
    }
}
476
/// Brown-Conrady lens model parameters: radial terms `k1..k3`, tangential
/// terms `p1`/`p2`, plus the camera intrinsics needed to move between pixel
/// and normalised coordinates.
///
/// The model matches the one used by OpenCV, so coefficients produced by
/// common camera-calibration / photogrammetry tools can be used directly.
#[derive(Debug, Clone, Copy)]
pub struct LensDistortionParams {
    /// First radial coefficient k₁ (usually small, roughly -0.3 to +0.5).
    pub k1: f64,
    /// Second radial coefficient k₂.
    pub k2: f64,
    /// Third radial coefficient k₃.
    pub k3: f64,
    /// First tangential coefficient p₁.
    pub p1: f64,
    /// Second tangential coefficient p₂.
    pub p2: f64,
    /// Horizontal focal length, in pixels.
    pub fx: f64,
    /// Vertical focal length, in pixels.
    pub fy: f64,
    /// Horizontal principal point, in pixels (usually `width / 2`).
    pub cx: f64,
    /// Vertical principal point, in pixels (usually `height / 2`).
    pub cy: f64,
}

impl LensDistortionParams {
    /// Describe a distortion-free lens for a `width × height` image.
    ///
    /// All five distortion coefficients are zero. The focal lengths are the
    /// image width (`fx`) and height (`fy`), and the principal point is the
    /// image centre.
    #[must_use]
    pub fn no_distortion(width: u32, height: u32) -> Self {
        let w = f64::from(width);
        let h = f64::from(height);
        Self {
            fx: w,
            fy: h,
            cx: w / 2.0,
            cy: h / 2.0,
            k1: 0.0,
            k2: 0.0,
            k3: 0.0,
            p1: 0.0,
            p2: 0.0,
        }
    }
}
521
522/// CPU-parallel perspective warp of a packed RGBA image.
523///
524/// Uses inverse mapping with bilinear interpolation: for each destination
525/// pixel `(dx, dy)` the inverse homography maps it back to the source
526/// coordinates `(sx, sy)`, which are bilinearly sampled.
527///
528/// Pixels that map outside the source image are filled with `fill_rgba`.
529///
530/// # Errors
531///
532/// Returns [`crate::GpuError::InvalidDimensions`] for zero dimensions or
533/// [`crate::GpuError::InvalidBufferSize`] for buffer/dimension mismatches.
534pub fn perspective_warp(
535    input: &[u8],
536    src_width: u32,
537    src_height: u32,
538    output: &mut [u8],
539    dst_width: u32,
540    dst_height: u32,
541    matrix: &PerspectiveMatrix,
542    fill_rgba: [u8; 4],
543) -> crate::Result<()> {
544    use super::utils;
545    use crate::GpuError;
546
547    if src_width == 0 || src_height == 0 {
548        return Err(GpuError::InvalidDimensions {
549            width: src_width,
550            height: src_height,
551        });
552    }
553    if dst_width == 0 || dst_height == 0 {
554        return Err(GpuError::InvalidDimensions {
555            width: dst_width,
556            height: dst_height,
557        });
558    }
559    utils::validate_buffer_size(input, src_width, src_height, 4)?;
560    utils::validate_buffer_size(output, dst_width, dst_height, 4)?;
561
562    let inv = matrix
563        .inverse()
564        .ok_or_else(|| GpuError::Internal("Perspective matrix is singular".to_string()))?;
565
566    let sw = src_width as usize;
567    let sh = src_height as usize;
568    let dw = dst_width as usize;
569    let dh = dst_height as usize;
570
571    for dy in 0..dh {
572        for dx in 0..dw {
573            let dst_idx = (dy * dw + dx) * 4;
574            let Some((sx_f, sy_f)) = inv.project(dx as f64, dy as f64) else {
575                output[dst_idx..dst_idx + 4].copy_from_slice(&fill_rgba);
576                continue;
577            };
578
579            // Bilinear interpolation
580            let x0 = sx_f.floor() as isize;
581            let y0 = sy_f.floor() as isize;
582            let x1 = x0 + 1;
583            let y1 = y0 + 1;
584            let fx = sx_f - sx_f.floor();
585            let fy = sy_f - sy_f.floor();
586
587            let sample = |cx: isize, cy: isize| -> [f64; 4] {
588                if cx < 0 || cy < 0 || cx >= sw as isize || cy >= sh as isize {
589                    [
590                        fill_rgba[0] as f64,
591                        fill_rgba[1] as f64,
592                        fill_rgba[2] as f64,
593                        fill_rgba[3] as f64,
594                    ]
595                } else {
596                    let idx = (cy as usize * sw + cx as usize) * 4;
597                    [
598                        input[idx] as f64,
599                        input[idx + 1] as f64,
600                        input[idx + 2] as f64,
601                        input[idx + 3] as f64,
602                    ]
603                }
604            };
605
606            let p00 = sample(x0, y0);
607            let p10 = sample(x1, y0);
608            let p01 = sample(x0, y1);
609            let p11 = sample(x1, y1);
610
611            for c in 0..4 {
612                let v = p00[c] * (1.0 - fx) * (1.0 - fy)
613                    + p10[c] * fx * (1.0 - fy)
614                    + p01[c] * (1.0 - fx) * fy
615                    + p11[c] * fx * fy;
616                output[dst_idx + c] = v.round().clamp(0.0, 255.0) as u8;
617            }
618        }
619    }
620
621    Ok(())
622}
623
624/// CPU-side lens distortion correction using the Brown-Conrady model.
625///
626/// For each destination pixel `(x, y)` the distortion model computes the
627/// corresponding distorted source coordinate and bilinearly samples the input.
628/// Pixels that map outside the source image are filled with `fill_rgba`.
629///
630/// # Errors
631///
632/// Returns an error if dimensions are zero or buffers are the wrong size.
633pub fn lens_undistort(
634    input: &[u8],
635    width: u32,
636    height: u32,
637    output: &mut [u8],
638    params: &LensDistortionParams,
639    fill_rgba: [u8; 4],
640) -> crate::Result<()> {
641    use super::utils;
642    use crate::GpuError;
643
644    if width == 0 || height == 0 {
645        return Err(GpuError::InvalidDimensions { width, height });
646    }
647    utils::validate_buffer_size(input, width, height, 4)?;
648    utils::validate_buffer_size(output, width, height, 4)?;
649
650    let w = width as usize;
651    let h = height as usize;
652    let inv_fx = 1.0 / params.fx;
653    let inv_fy = 1.0 / params.fy;
654
655    for dy in 0..h {
656        for dx in 0..w {
657            // Normalised coordinates (undistorted space).
658            let x_u = (dx as f64 - params.cx) * inv_fx;
659            let y_u = (dy as f64 - params.cy) * inv_fy;
660
661            // Apply Brown-Conrady radial + tangential distortion to map from
662            // undistorted → distorted (where the actual sensor data lives).
663            let r2 = x_u * x_u + y_u * y_u;
664            let r4 = r2 * r2;
665            let r6 = r4 * r2;
666            let radial = 1.0 + params.k1 * r2 + params.k2 * r4 + params.k3 * r6;
667            let x_d =
668                x_u * radial + 2.0 * params.p1 * x_u * y_u + params.p2 * (r2 + 2.0 * x_u * x_u);
669            let y_d =
670                y_u * radial + params.p1 * (r2 + 2.0 * y_u * y_u) + 2.0 * params.p2 * x_u * y_u;
671
672            // Back to pixel coordinates in the distorted (source) image.
673            let sx_f = x_d * params.fx + params.cx;
674            let sy_f = y_d * params.fy + params.cy;
675
676            let dst_idx = (dy * w + dx) * 4;
677
678            let x0 = sx_f.floor() as isize;
679            let y0 = sy_f.floor() as isize;
680            let x1 = x0 + 1;
681            let y1 = y0 + 1;
682            let fx = sx_f - sx_f.floor();
683            let fy = sy_f - sy_f.floor();
684
685            let sample = |cx: isize, cy: isize| -> [f64; 4] {
686                if cx < 0 || cy < 0 || cx >= w as isize || cy >= h as isize {
687                    [
688                        fill_rgba[0] as f64,
689                        fill_rgba[1] as f64,
690                        fill_rgba[2] as f64,
691                        fill_rgba[3] as f64,
692                    ]
693                } else {
694                    let idx = (cy as usize * w + cx as usize) * 4;
695                    [
696                        input[idx] as f64,
697                        input[idx + 1] as f64,
698                        input[idx + 2] as f64,
699                        input[idx + 3] as f64,
700                    ]
701                }
702            };
703
704            let p00 = sample(x0, y0);
705            let p10 = sample(x1, y0);
706            let p01 = sample(x0, y1);
707            let p11 = sample(x1, y1);
708
709            for c in 0..4 {
710                let v = p00[c] * (1.0 - fx) * (1.0 - fy)
711                    + p10[c] * fx * (1.0 - fy)
712                    + p01[c] * (1.0 - fx) * fy
713                    + p11[c] * fx * fy;
714                output[dst_idx + c] = v.round().clamp(0.0, 255.0) as u8;
715            }
716        }
717    }
718
719    Ok(())
720}
721
722// =============================================================================
723// CPU-side geometric (pixel-level) transforms (rotate, flip, transpose)
724// =============================================================================
725
726impl TransformOperation {
727    /// Copy one pixel from `src` to `dst`, using interleaved layout.
728    ///
729    /// All coordinates are 0-indexed.  `src_w` is the *source* image width and
730    /// `dst_w` is the *destination* image width (both in pixels).  `ch` is the
731    /// number of bytes per pixel.
732    #[inline]
733    fn copy_pixel(
734        src: &[u8],
735        dst: &mut [u8],
736        src_x: u32,
737        src_y: u32,
738        dst_x: u32,
739        dst_y: u32,
740        src_w: u32,
741        dst_w: u32,
742        ch: u32,
743    ) {
744        let src_off = ((src_y * src_w + src_x) * ch) as usize;
745        let dst_off = ((dst_y * dst_w + dst_x) * ch) as usize;
746        dst[dst_off..dst_off + ch as usize].copy_from_slice(&src[src_off..src_off + ch as usize]);
747    }
748
749    /// Rotate an interleaved pixel image 90° clockwise.
750    ///
751    /// Output dimensions are swapped: `out_width = height`, `out_height = width`.
752    ///
753    /// Pixel mapping: `output(x, y) = input(y, width_out - 1 - x)` where
754    /// `width_out = height`.
755    ///
756    /// # Arguments
757    ///
758    /// * `data` – packed pixel buffer (interleaved, `channels` bytes per pixel)
759    /// * `width` – source image width in pixels
760    /// * `height` – source image height in pixels
761    /// * `channels` – bytes per pixel (e.g. 3 for RGB, 4 for RGBA)
762    ///
763    /// # Panics
764    ///
765    /// Panics in debug mode if `data.len() != width * height * channels`.
766    #[must_use]
767    pub fn rotate90(data: &[u8], width: u32, height: u32, channels: u32) -> Vec<u8> {
768        // After 90° CW: out_width = in_height, out_height = in_width
769        let out_width = height;
770        let out_height = width;
771        let mut out = vec![0u8; (out_width * out_height * channels) as usize];
772
773        for src_y in 0..height {
774            for src_x in 0..width {
775                // 90° CW: dst_x = height - 1 - src_y ... wait, let's derive carefully.
776                // Clockwise 90°: new_x = (in_height - 1 - src_y) is wrong.
777                // The standard derivation for CW 90°:
778                //   src (col=x, row=y) → dst (col=height-1-y, row=x)
779                //   i.e. dst_x = src_y, dst_y = (width - 1 - src_x)
780                // Verify: src(0,0) → dst_x=0, dst_y=width-1  ← top-left goes to bottom-left of output
781                // That matches CW rotation where (0,0) ends at bottom-left of the output.
782                // Actually let's verify with a 3x1 image rotated 90° CW:
783                //   Input (width=3, height=1):  [A B C] (row 0)
784                //   Output (width=1, height=3):  col 0: row0=C, row1=B, row2=A
785                //   So output pixel at (x=0, y=0) = input(width-1-0, 0) = input(2,0)=C ✓
786                //   using: dst_x=src_y, dst_y=(in_width-1-src_x):
787                //     src(0,0): dst_x=0, dst_y=2 → output(0,2)=A ✓ (A at row2)
788                //     src(1,0): dst_x=0, dst_y=1 → output(0,1)=B ✓
789                //     src(2,0): dst_x=0, dst_y=0 → output(0,0)=C ✓
790                let dst_x = src_y;
791                let dst_y = width - 1 - src_x;
792                Self::copy_pixel(
793                    data, &mut out, src_x, src_y, dst_x, dst_y, width, out_width, channels,
794                );
795            }
796        }
797
798        out
799    }
800
801    /// Rotate an interleaved pixel image 180°.
802    ///
803    /// Output dimensions are the same as input.
804    ///
805    /// Pixel mapping: `output(x, y) = input(width-1-x, height-1-y)`.
806    #[must_use]
807    pub fn rotate180(data: &[u8], width: u32, height: u32, channels: u32) -> Vec<u8> {
808        let mut out = vec![0u8; (width * height * channels) as usize];
809
810        for src_y in 0..height {
811            for src_x in 0..width {
812                let dst_x = width - 1 - src_x;
813                let dst_y = height - 1 - src_y;
814                Self::copy_pixel(
815                    data, &mut out, src_x, src_y, dst_x, dst_y, width, width, channels,
816                );
817            }
818        }
819
820        out
821    }
822
823    /// Rotate an interleaved pixel image 270° clockwise (= 90° counter-clockwise).
824    ///
825    /// Output dimensions are swapped: `out_width = height`, `out_height = width`.
826    ///
827    /// Pixel mapping: `output(x, y) = input(height-1-y, x)`.
828    #[must_use]
829    pub fn rotate270(data: &[u8], width: u32, height: u32, channels: u32) -> Vec<u8> {
830        // After 270° CW (= 90° CCW): out_width = in_height, out_height = in_width
831        let out_width = height;
832        let out_height = width;
833        let mut out = vec![0u8; (out_width * out_height * channels) as usize];
834
835        for src_y in 0..height {
836            for src_x in 0..width {
837                // 270° CW derivation:
838                //   src (col=x, row=y) → dst (col=in_height-1-src_y, row=src_x)
839                //   i.e. dst_x = height-1-src_y, dst_y = src_x
840                // Verify with 3x1 (width=3, height=1) rotated 270° CW:
841                //   Output (width=1, height=3): row0=A, row1=B, row2=C
842                //   src(0,0): dst_x=0, dst_y=0 → output(0,0)=A ✓
843                //   src(1,0): dst_x=0, dst_y=1 → output(0,1)=B ✓
844                //   src(2,0): dst_x=0, dst_y=2 → output(0,2)=C ✓
845                let dst_x = height - 1 - src_y;
846                let dst_y = src_x;
847                Self::copy_pixel(
848                    data, &mut out, src_x, src_y, dst_x, dst_y, width, out_width, channels,
849                );
850            }
851        }
852
853        out
854    }
855
856    /// Flip an interleaved pixel image horizontally (mirror left-right).
857    ///
858    /// Output dimensions are the same as input.
859    ///
860    /// Pixel mapping: `output(x, y) = input(width-1-x, y)`.
861    #[must_use]
862    pub fn flip_horizontal(data: &[u8], width: u32, height: u32, channels: u32) -> Vec<u8> {
863        let mut out = vec![0u8; (width * height * channels) as usize];
864
865        for src_y in 0..height {
866            for src_x in 0..width {
867                let dst_x = width - 1 - src_x;
868                Self::copy_pixel(
869                    data, &mut out, src_x, src_y, dst_x, src_y, width, width, channels,
870                );
871            }
872        }
873
874        out
875    }
876
877    /// Flip an interleaved pixel image vertically (mirror top-bottom).
878    ///
879    /// Output dimensions are the same as input.
880    ///
881    /// Pixel mapping: `output(x, y) = input(x, height-1-y)`.
882    #[must_use]
883    pub fn flip_vertical(data: &[u8], width: u32, height: u32, channels: u32) -> Vec<u8> {
884        let mut out = vec![0u8; (width * height * channels) as usize];
885
886        for src_y in 0..height {
887            for src_x in 0..width {
888                let dst_y = height - 1 - src_y;
889                Self::copy_pixel(
890                    data, &mut out, src_x, src_y, src_x, dst_y, width, width, channels,
891                );
892            }
893        }
894
895        out
896    }
897
898    /// Transpose an interleaved pixel image (swap x and y axes).
899    ///
900    /// Output dimensions are swapped: `out_width = height`, `out_height = width`.
901    ///
902    /// Pixel mapping: `output(x, y) = input(y, x)`.
903    #[must_use]
904    pub fn transpose(data: &[u8], width: u32, height: u32, channels: u32) -> Vec<u8> {
905        // After transpose: out_width = in_height, out_height = in_width
906        let out_width = height;
907        let out_height = width;
908        let mut out = vec![0u8; (out_width * out_height * channels) as usize];
909
910        for src_y in 0..height {
911            for src_x in 0..width {
912                // output(x, y) = input(y, x)
913                // dst_x = src_y, dst_y = src_x
914                let dst_x = src_y;
915                let dst_y = src_x;
916                Self::copy_pixel(
917                    data, &mut out, src_x, src_y, dst_x, dst_y, width, out_width, channels,
918                );
919            }
920        }
921
922        out
923    }
924}
925
926// =============================================================================
// Tests for perspective transform, lens distortion and geometric transforms (Task 8)
928// =============================================================================
929
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `w × h` RGBA buffer in which every pixel is `(r, g, b, a)`.
    fn solid_rgba(w: u32, h: u32, r: u8, g: u8, b: u8, a: u8) -> Vec<u8> {
        let n = (w * h * 4) as usize;
        let mut v = vec![0u8; n];
        for px in v.chunks_exact_mut(4) {
            px.copy_from_slice(&[r, g, b, a]);
        }
        v
    }

    /// Assert that the first three (RGB) channels of `px` are within `tol`
    /// of `expected`.
    fn assert_rgb_close(px: &[u8], expected: [u8; 3], tol: u32) {
        for (c, (&got, &want)) in px.iter().zip(expected.iter()).enumerate() {
            assert!(
                (i32::from(got) - i32::from(want)).unsigned_abs() <= tol,
                "channel {c}: got {got}, expected {want}"
            );
        }
    }

    // ── PerspectiveMatrix ─────────────────────────────────────────────────────

    #[test]
    fn test_perspective_identity_project() {
        let m = PerspectiveMatrix::identity();
        let (x, y) = m
            .project(100.0, 200.0)
            .expect("identity must not return None");
        assert!((x - 100.0).abs() < 1e-10, "x={x}");
        assert!((y - 200.0).abs() < 1e-10, "y={y}");
    }

    #[test]
    fn test_perspective_translation() {
        // Pure translation: shift by (10, 20).
        let m = PerspectiveMatrix::from_array([1.0, 0.0, 10.0, 0.0, 1.0, 20.0, 0.0, 0.0, 1.0]);
        let (x, y) = m.project(5.0, 5.0).expect("no infinity");
        assert!((x - 15.0).abs() < 1e-10, "x={x}");
        assert!((y - 25.0).abs() < 1e-10, "y={y}");
    }

    #[test]
    fn test_perspective_inverse_is_correct() {
        let m = PerspectiveMatrix::from_array([1.0, 0.5, 10.0, -0.2, 1.0, 5.0, 0.001, 0.0, 1.0]);
        let inv = m.inverse().expect("non-singular matrix must have inverse");
        // Projecting forward through `m` and back through `inv` must round-trip.
        let (x_orig, y_orig) = (50.0_f64, 30.0_f64);
        let (x_proj, y_proj) = m.project(x_orig, y_orig).expect("forward project");
        let (x_back, y_back) = inv.project(x_proj, y_proj).expect("inverse project");
        assert!(
            (x_back - x_orig).abs() < 1e-6,
            "x roundtrip: {x_back} ≠ {x_orig}"
        );
        assert!(
            (y_back - y_orig).abs() < 1e-6,
            "y roundtrip: {y_back} ≠ {y_orig}"
        );
    }

    #[test]
    fn test_perspective_singular_returns_none_inverse() {
        // All-zero matrix is singular.
        let m = PerspectiveMatrix::from_array([0.0; 9]);
        assert!(m.inverse().is_none(), "singular matrix must return None");
    }

    // ── perspective_warp ──────────────────────────────────────────────────────

    #[test]
    fn test_perspective_warp_identity_preserves_image() {
        let w = 8u32;
        let h = 8u32;
        let src = solid_rgba(w, h, 100, 150, 200, 255);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        perspective_warp(
            &src,
            w,
            h,
            &mut dst,
            w,
            h,
            &PerspectiveMatrix::identity(),
            [0, 0, 0, 0],
        )
        .expect("identity warp must succeed");
        // Every pixel must match the source (within bilinear rounding).
        for (s, d) in src.iter().zip(dst.iter()) {
            assert!(
                (*s as i32 - *d as i32).unsigned_abs() <= 1,
                "identity warp mismatch"
            );
        }
    }

    #[test]
    fn test_perspective_warp_out_of_bounds_uses_fill() {
        let w = 4u32;
        let h = 4u32;
        let src = solid_rgba(w, h, 255, 0, 0, 255);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        // Large translation sends all destination pixels outside the source.
        let m =
            PerspectiveMatrix::from_array([1.0, 0.0, 10000.0, 0.0, 1.0, 10000.0, 0.0, 0.0, 1.0]);
        perspective_warp(&src, w, h, &mut dst, w, h, &m, [0, 255, 0, 255])
            .expect("warp must succeed");
        // Every pixel must be exactly the fill colour (opaque green), not the red source.
        for px in dst.chunks_exact(4) {
            assert_eq!(px, &[0, 255, 0, 255], "fill colour mismatch");
        }
    }

    #[test]
    fn test_perspective_warp_invalid_dims_return_error() {
        let src = solid_rgba(4, 4, 0, 0, 0, 255);
        let mut dst = vec![0u8; 16 * 4];
        let result = perspective_warp(
            &src,
            0,
            4,
            &mut dst,
            4,
            4,
            &PerspectiveMatrix::identity(),
            [0; 4],
        );
        assert!(result.is_err());
    }

    // ── lens_undistort ────────────────────────────────────────────────────────

    #[test]
    fn test_lens_undistort_no_distortion_identity() {
        let w = 8u32;
        let h = 8u32;
        let src = solid_rgba(w, h, 50, 100, 150, 255);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        let params = LensDistortionParams::no_distortion(w, h);
        lens_undistort(&src, w, h, &mut dst, &params, [0; 4]).expect("no distortion must succeed");
        // The first pixels of the top row should match the source colour.
        for px in dst.chunks_exact(4).take(4) {
            assert_rgb_close(px, [50, 100, 150], 2);
        }
        // Pixels away from the border sample well inside the source and must
        // match as well.
        for y in 1..h - 1 {
            for x in 1..w - 1 {
                let off = ((y * w + x) * 4) as usize;
                assert_rgb_close(&dst[off..off + 4], [50, 100, 150], 2);
            }
        }
    }

    #[test]
    fn test_lens_undistort_preserves_centre_pixel() {
        // Centre pixel should be (almost) unaffected by distortion.
        let w = 9u32; // odd size so a single pixel sits on the principal point
        let h = 9u32;
        let mut src = vec![0u8; (w * h * 4) as usize];
        // Mark the centre pixel distinctively.
        let cx = (w / 2) as usize;
        let cy = (h / 2) as usize;
        let center_idx = (cy * w as usize + cx) * 4;
        src[center_idx] = 255;
        src[center_idx + 1] = 128;
        src[center_idx + 2] = 64;
        src[center_idx + 3] = 255;
        let mut dst = vec![0u8; (w * h * 4) as usize];
        let params = LensDistortionParams {
            k1: 0.1,
            k2: 0.0,
            k3: 0.0,
            p1: 0.0,
            p2: 0.0,
            fx: f64::from(w),
            fy: f64::from(h),
            cx: f64::from(w) / 2.0,
            cy: f64::from(h) / 2.0,
        };
        lens_undistort(&src, w, h, &mut dst, &params, [0; 4]).expect("undistort must succeed");
        // At (or next to) the principal point the radial term is negligible,
        // so the centre pixel should sample (close to) itself.
        let out_r = dst[center_idx];
        assert!(
            out_r > 128,
            "centre R should reflect the marked pixel, got {out_r}"
        );
    }

    #[test]
    fn test_lens_undistort_invalid_dims_return_error() {
        let src = vec![0u8; 64];
        let mut dst = vec![0u8; 64];
        let params = LensDistortionParams::no_distortion(4, 4);
        let result = lens_undistort(&src, 0, 4, &mut dst, &params, [0; 4]);
        assert!(result.is_err());
    }

    // ── Geometric transforms ───────────────────────────────────────────────────

    /// Build a test image where every pixel has a unique value based on its
    /// (x, y) coordinates.  Pixel at (x, y) in an image of width `w` gets
    /// value `[y as u8, x as u8, (y*w+x) as u8]` using 3 channels.
    fn make_test_image_3ch(w: u32, h: u32) -> Vec<u8> {
        let mut buf = vec![0u8; (w * h * 3) as usize];
        for y in 0..h {
            for x in 0..w {
                let off = ((y * w + x) * 3) as usize;
                buf[off] = y as u8;
                buf[off + 1] = x as u8;
                buf[off + 2] = (y * w + x) as u8;
            }
        }
        buf
    }

    #[test]
    fn test_rotate90_dimensions() {
        // A 3×5 image rotated by 90° should produce a 5×3 image.
        let img = make_test_image_3ch(3, 5);
        let out = TransformOperation::rotate90(&img, 3, 5, 3);
        // out_width = in_height = 5, out_height = in_width = 3
        assert_eq!(
            out.len(),
            (5 * 3 * 3) as usize,
            "output buffer size mismatch"
        );
    }

    #[test]
    fn test_rotate90_corner() {
        // Source image 4×2 (width=4, height=2), 3-channel.
        // After rotate90: out_width=2, out_height=4.
        // src(0,0) → dst_x=src_y=0, dst_y=width-1-src_x=3 → dst(0,3)
        let w: u32 = 4;
        let h: u32 = 2;
        let ch: u32 = 3;
        let mut img = vec![0u8; (w * h * ch) as usize];
        // Mark src pixel (0,0) distinctively.
        img[..3].copy_from_slice(&[1, 2, 3]);

        let out = TransformOperation::rotate90(&img, w, h, ch);
        let out_width = h; // 2
        // Expected: dst(0, 3) holds [1, 2, 3].
        let (dst_x, dst_y) = (0u32, w - 1);
        let dst_off = ((dst_y * out_width + dst_x) * ch) as usize;
        assert_eq!(
            &out[dst_off..dst_off + 3],
            &[1, 2, 3],
            "rotate90 corner pixel wrong"
        );
    }

    #[test]
    fn test_rotate180_roundtrip() {
        // Rotating 180° twice must reproduce the original image.
        let w: u32 = 4;
        let h: u32 = 3;
        let img = make_test_image_3ch(w, h);
        let once = TransformOperation::rotate180(&img, w, h, 3);
        // A single rotation must actually move pixels: the bottom-right source
        // pixel ends up in the top-left corner.
        let last = ((w * h - 1) * 3) as usize;
        assert_eq!(
            &once[..3],
            &img[last..last + 3],
            "rotate180 corner pixel wrong"
        );
        let twice = TransformOperation::rotate180(&once, w, h, 3);
        assert_eq!(img, twice, "rotate180 twice must equal original");
    }

    #[test]
    fn test_rotate270_corner() {
        // Source image 4×2 (width=4, height=2), 3-channel.
        // After rotate270: out_width=2, out_height=4.
        // src(0,0) → dst_x=height-1-src_y=1, dst_y=src_x=0 → dst(1,0)
        let w: u32 = 4;
        let h: u32 = 2;
        let ch: u32 = 3;
        let mut img = vec![0u8; (w * h * ch) as usize];
        img[..3].copy_from_slice(&[1, 2, 3]);

        let out = TransformOperation::rotate270(&img, w, h, ch);
        assert_eq!(out.len(), (w * h * ch) as usize, "rotate270 buffer size mismatch");
        let out_width = h;
        let (dst_x, dst_y) = (h - 1, 0u32);
        let dst_off = ((dst_y * out_width + dst_x) * ch) as usize;
        assert_eq!(
            &out[dst_off..dst_off + 3],
            &[1, 2, 3],
            "rotate270 corner pixel wrong"
        );
    }

    #[test]
    fn test_rotate90_then_rotate270_roundtrip() {
        let w: u32 = 4;
        let h: u32 = 3;
        let img = make_test_image_3ch(w, h);
        let rotated = TransformOperation::rotate90(&img, w, h, 3);
        // rotate90 swaps the dimensions, so undo it with (h, w).
        let back = TransformOperation::rotate270(&rotated, h, w, 3);
        assert_eq!(img, back, "rotate270 must undo rotate90");
    }

    #[test]
    fn test_flip_horizontal_reverses_row() {
        // Flip a 4×2 image horizontally; the first row of the output should be
        // the reverse of the first row of the input.
        let w: u32 = 4;
        let h: u32 = 2;
        let ch: u32 = 3;
        let img = make_test_image_3ch(w, h);
        let out = TransformOperation::flip_horizontal(&img, w, h, ch);

        // Row 0 of input: pixels at x=0,1,2,3
        // Row 0 of output: pixels at dst_x=3,2,1,0 (reversed)
        for x in 0..w {
            let src_off = (x * ch) as usize;
            let dst_off = ((w - 1 - x) * ch) as usize;
            assert_eq!(
                &img[src_off..src_off + ch as usize],
                &out[dst_off..dst_off + ch as usize],
                "flip_horizontal row-reversal wrong at x={x}"
            );
        }
    }

    #[test]
    fn test_flip_vertical_reverses_rows() {
        // Row y of the input must appear as row h-1-y of the output.
        let w: u32 = 3;
        let h: u32 = 4;
        let ch: u32 = 3;
        let img = make_test_image_3ch(w, h);
        let out = TransformOperation::flip_vertical(&img, w, h, ch);
        let row_len = (w * ch) as usize;
        for y in 0..h as usize {
            let dst_y = h as usize - 1 - y;
            assert_eq!(
                &img[y * row_len..(y + 1) * row_len],
                &out[dst_y * row_len..(dst_y + 1) * row_len],
                "flip_vertical row {y} misplaced"
            );
        }
    }

    #[test]
    fn test_flips_and_transpose_are_involutions() {
        // Applying each of these transforms twice must restore the input.
        let w: u32 = 5;
        let h: u32 = 3;
        let img = make_test_image_3ch(w, h);

        let fh = TransformOperation::flip_horizontal(&img, w, h, 3);
        assert_eq!(TransformOperation::flip_horizontal(&fh, w, h, 3), img);

        let fv = TransformOperation::flip_vertical(&img, w, h, 3);
        assert_eq!(TransformOperation::flip_vertical(&fv, w, h, 3), img);

        // Transpose swaps the dimensions, so transpose back with (h, w).
        let t = TransformOperation::transpose(&img, w, h, 3);
        assert_eq!(TransformOperation::transpose(&t, h, w, 3), img);
    }

    #[test]
    fn test_transpose_swaps_dimensions() {
        // A 2×4 image (width=2, height=4) transposed should be 4×2.
        let w: u32 = 2;
        let h: u32 = 4;
        let ch: u32 = 3;
        let img = make_test_image_3ch(w, h);
        let out = TransformOperation::transpose(&img, w, h, ch);
        // out_width = in_height = 4, out_height = in_width = 2
        assert_eq!(
            out.len(),
            (4 * 2 * ch) as usize,
            "transpose buffer size mismatch"
        );
        // Verify that output(x=src_y, y=src_x) == input(src_x, src_y)
        let out_width: u32 = h; // 4
        for src_y in 0..h {
            for src_x in 0..w {
                let src_off = ((src_y * w + src_x) * ch) as usize;
                let dst_off = ((src_x * out_width + src_y) * ch) as usize;
                assert_eq!(
                    &img[src_off..src_off + ch as usize],
                    &out[dst_off..dst_off + ch as usize],
                    "transpose pixel mismatch at ({src_x},{src_y})"
                );
            }
        }
    }
}