// oximedia_gpu/ops/transform.rs

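//! Frequency-domain transforms on the GPU (8×8 block DCT/IDCT and a separable
//! row/column DCT), plus CPU-side perspective warping and lens-distortion
//! correction.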
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
13#[repr(C)]
14#[derive(Copy, Clone, Pod, Zeroable)]
15struct TransformParams {
16    width: u32,
17    height: u32,
18    block_size: u32,
19    transform_type: u32,
20    stride: u32,
21    is_inverse: u32,
22    padding1: u32,
23    padding2: u32,
24}
25
26/// Transform operations for frequency domain processing
27pub struct TransformOperation;
28
29impl TransformOperation {
30    /// Compute 2D DCT (Discrete Cosine Transform)
31    ///
32    /// Computes the forward DCT on 8x8 blocks. Input dimensions must be
33    /// multiples of 8.
34    ///
35    /// # Arguments
36    ///
37    /// * `device` - GPU device
38    /// * `input` - Input data (f32 values)
39    /// * `output` - Output DCT coefficients
40    /// * `width` - Data width (must be multiple of 8)
41    /// * `height` - Data height (must be multiple of 8)
42    ///
43    /// # Errors
44    ///
45    /// Returns an error if dimensions are invalid or if the GPU operation fails.
46    pub fn dct_2d(
47        device: &GpuDevice,
48        input: &[f32],
49        output: &mut [f32],
50        width: u32,
51        height: u32,
52    ) -> Result<()> {
53        if width % 8 != 0 || height % 8 != 0 {
54            return Err(GpuError::InvalidDimensions { width, height });
55        }
56
57        utils::validate_dimensions(width, height)?;
58
59        let expected_size = (width * height) as usize;
60        if input.len() < expected_size || output.len() < expected_size {
61            return Err(GpuError::InvalidBufferSize {
62                expected: expected_size,
63                actual: input.len().min(output.len()),
64            });
65        }
66
67        let pipeline = Self::get_dct_8x8_pipeline(device)?;
68        let layout = Self::get_bind_group_layout(device)?;
69
70        Self::execute_transform(
71            device, pipeline, layout, input, output, width, height, 8, 0, // DCT
72        )
73    }
74
75    /// Compute 2D IDCT (Inverse Discrete Cosine Transform)
76    ///
77    /// Computes the inverse DCT on 8x8 blocks. Input dimensions must be
78    /// multiples of 8.
79    ///
80    /// # Arguments
81    ///
82    /// * `device` - GPU device
83    /// * `input` - Input DCT coefficients
84    /// * `output` - Output reconstructed data
85    /// * `width` - Data width (must be multiple of 8)
86    /// * `height` - Data height (must be multiple of 8)
87    ///
88    /// # Errors
89    ///
90    /// Returns an error if dimensions are invalid or if the GPU operation fails.
91    pub fn idct_2d(
92        device: &GpuDevice,
93        input: &[f32],
94        output: &mut [f32],
95        width: u32,
96        height: u32,
97    ) -> Result<()> {
98        if width % 8 != 0 || height % 8 != 0 {
99            return Err(GpuError::InvalidDimensions { width, height });
100        }
101
102        utils::validate_dimensions(width, height)?;
103
104        let expected_size = (width * height) as usize;
105        if input.len() < expected_size || output.len() < expected_size {
106            return Err(GpuError::InvalidBufferSize {
107                expected: expected_size,
108                actual: input.len().min(output.len()),
109            });
110        }
111
112        let pipeline = Self::get_idct_8x8_pipeline(device)?;
113        let layout = Self::get_bind_group_layout(device)?;
114
115        Self::execute_transform(
116            device, pipeline, layout, input, output, width, height, 8, 1, // IDCT
117        )
118    }
119
120    /// Compute general 2D DCT using row-column decomposition
121    ///
122    /// This method works for any dimensions, not just multiples of 8.
123    ///
124    /// # Arguments
125    ///
126    /// * `device` - GPU device
127    /// * `input` - Input data (f32 values)
128    /// * `output` - Output DCT coefficients
129    /// * `width` - Data width
130    /// * `height` - Data height
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if dimensions are invalid or if the GPU operation fails.
135    pub fn dct_2d_general(
136        device: &GpuDevice,
137        input: &[f32],
138        output: &mut [f32],
139        width: u32,
140        height: u32,
141    ) -> Result<()> {
142        utils::validate_dimensions(width, height)?;
143
144        let expected_size = (width * height) as usize;
145        if input.len() < expected_size || output.len() < expected_size {
146            return Err(GpuError::InvalidBufferSize {
147                expected: expected_size,
148                actual: input.len().min(output.len()),
149            });
150        }
151
152        // Two-pass DCT: row then column
153        let mut temp = vec![0.0f32; expected_size];
154
155        // Row DCT
156        let row_pipeline = Self::get_dct_row_pipeline(device)?;
157        let layout = Self::get_bind_group_layout(device)?;
158
159        Self::execute_transform(
160            device,
161            row_pipeline,
162            layout,
163            input,
164            &mut temp,
165            width,
166            height,
167            width,
168            0,
169        )?;
170
171        // Column DCT
172        let col_pipeline = Self::get_dct_col_pipeline(device)?;
173
174        Self::execute_transform(
175            device,
176            col_pipeline,
177            layout,
178            &temp,
179            output,
180            width,
181            height,
182            height,
183            0,
184        )
185    }
186
187    #[allow(clippy::too_many_arguments)]
188    fn execute_transform(
189        device: &GpuDevice,
190        pipeline: &ComputePipeline,
191        layout: &BindGroupLayout,
192        input: &[f32],
193        output: &mut [f32],
194        width: u32,
195        height: u32,
196        block_size: u32,
197        transform_type: u32,
198    ) -> Result<()> {
199        let input_bytes = bytemuck::cast_slice(input);
200        let output_size = std::mem::size_of_val(output);
201
202        // Create buffers
203        let input_buffer = utils::create_storage_buffer(device, input_bytes.len() as u64)?;
204        let output_buffer = utils::create_storage_buffer(device, output_size as u64)?;
205
206        // Upload input data
207        device
208            .queue()
209            .write_buffer(input_buffer.buffer(), 0, input_bytes);
210
211        // Create uniform buffer for parameters
212        let params = TransformParams {
213            width,
214            height,
215            block_size,
216            transform_type,
217            stride: width,
218            is_inverse: 0,
219            padding1: 0,
220            padding2: 0,
221        };
222        let params_bytes = bytemuck::bytes_of(&params);
223        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
224
225        // Create bind group
226        let compiler = ShaderCompiler::new(device);
227        let bind_group = compiler.create_bind_group(
228            "Transform Bind Group",
229            layout,
230            &[
231                wgpu::BindGroupEntry {
232                    binding: 0,
233                    resource: input_buffer.buffer().as_entire_binding(),
234                },
235                wgpu::BindGroupEntry {
236                    binding: 1,
237                    resource: output_buffer.buffer().as_entire_binding(),
238                },
239                wgpu::BindGroupEntry {
240                    binding: 2,
241                    resource: params_buffer.buffer().as_entire_binding(),
242                },
243            ],
244        );
245
246        // Execute compute pass
247        Self::dispatch_compute(device, pipeline, &bind_group, width, height, block_size)?;
248
249        // Read back results
250        let readback_buffer = utils::create_readback_buffer(device, output_size as u64)?;
251        let mut encoder = device
252            .device()
253            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
254                label: Some("Transform Copy Encoder"),
255            });
256
257        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output_size as u64)?;
258
259        device.queue().submit(Some(encoder.finish()));
260        device.wait();
261
262        let result = readback_buffer.read(device, 0, output_size as u64)?;
263        let result_f32: &[f32] = bytemuck::cast_slice(&result);
264        output.copy_from_slice(result_f32);
265
266        Ok(())
267    }
268
269    fn dispatch_compute(
270        device: &GpuDevice,
271        pipeline: &ComputePipeline,
272        bind_group: &BindGroup,
273        width: u32,
274        height: u32,
275        block_size: u32,
276    ) -> Result<()> {
277        let mut encoder = device
278            .device()
279            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
280                label: Some("Transform Compute Encoder"),
281            });
282
283        {
284            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
285                label: Some("Transform Compute Pass"),
286                timestamp_writes: None,
287            });
288
289            compute_pass.set_pipeline(pipeline);
290            compute_pass.set_bind_group(0, bind_group, &[]);
291
292            if block_size == 8 {
293                // For 8x8 DCT, dispatch one workgroup per block
294                let dispatch_x = width / 8;
295                let dispatch_y = height / 8;
296                compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
297            } else {
298                // For row/column transforms
299                let total_elements = width * height;
300                let dispatch = total_elements.div_ceil(256);
301                compute_pass.dispatch_workgroups(dispatch, 1, 1);
302            }
303        }
304
305        device.queue().submit(Some(encoder.finish()));
306        Ok(())
307    }
308
309    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
310        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
311
312        Ok(LAYOUT.get_or_init(|| {
313            let compiler = ShaderCompiler::new(device);
314            let entries = BindGroupLayoutBuilder::new()
315                .add_storage_buffer_read_only(0) // input
316                .add_storage_buffer(1) // output
317                .add_uniform_buffer(2) // params
318                .build();
319
320            compiler.create_bind_group_layout("Transform Bind Group Layout", &entries)
321        }))
322    }
323
324    fn init_pipeline(
325        device: &GpuDevice,
326        name: &str,
327        entry_point: &str,
328    ) -> std::result::Result<ComputePipeline, String> {
329        let compiler = ShaderCompiler::new(device);
330        let shader = compiler
331            .compile(
332                "Transform Shader",
333                ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
334            )
335            .map_err(|e| format!("Failed to compile transform shader: {e}"))?;
336
337        let layout = Self::get_bind_group_layout(device)
338            .map_err(|e| format!("Failed to create bind group layout: {e}"))?;
339
340        compiler
341            .create_pipeline(name, &shader, entry_point, layout)
342            .map_err(|e| format!("Failed to create pipeline: {e}"))
343    }
344
345    fn get_dct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
346        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
347
348        PIPELINE
349            .get_or_init(|| {
350                TransformOperation::init_pipeline(device, "DCT 8x8 Pipeline", "dct_8x8")
351            })
352            .as_ref()
353            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
354    }
355
356    fn get_idct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
357        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
358
359        PIPELINE
360            .get_or_init(|| {
361                TransformOperation::init_pipeline(device, "IDCT 8x8 Pipeline", "idct_8x8")
362            })
363            .as_ref()
364            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
365    }
366
367    fn get_dct_row_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
368        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
369
370        PIPELINE
371            .get_or_init(|| {
372                TransformOperation::init_pipeline(device, "DCT Row Pipeline", "dct_row")
373            })
374            .as_ref()
375            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
376    }
377
378    fn get_dct_col_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
379        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
380
381        PIPELINE
382            .get_or_init(|| {
383                TransformOperation::init_pipeline(device, "DCT Column Pipeline", "dct_col")
384            })
385            .as_ref()
386            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
387    }
388}
389
// =============================================================================
// CPU-side perspective transform and lens distortion correction
// =============================================================================
393
/// A 3×3 homogeneous perspective (projective) transform matrix stored in
/// row-major order.
///
/// The matrix maps homogeneous image coordinates `(x, y, 1)ᵀ` to new
/// coordinates via `(x', y', w')ᵀ = M · (x, y, 1)ᵀ`.  The Cartesian result
/// is `(x'/w', y'/w')`.
#[derive(Debug, Clone, Copy)]
pub struct PerspectiveMatrix {
    /// Row-major 3×3 elements: `[[a,b,c],[d,e,f],[g,h,i]]`.
    pub data: [[f64; 3]; 3],
}

impl PerspectiveMatrix {
    /// Smallest `|w'|` accepted by [`Self::project`]; below it the point is
    /// treated as mapping to infinity.
    const MIN_W: f64 = 1e-12;

    /// Smallest `|det|` accepted by [`Self::inverse`].
    ///
    /// This threshold is absolute, so it is not invariant to uniformly scaling
    /// the matrix (which leaves the homography itself unchanged): a matrix
    /// scaled down far enough is reported as singular.
    const MIN_DET: f64 = 1e-15;

    /// Create from a flat row-major array of 9 elements.
    #[must_use]
    pub fn from_array(m: [f64; 9]) -> Self {
        Self {
            data: [[m[0], m[1], m[2]], [m[3], m[4], m[5]], [m[6], m[7], m[8]]],
        }
    }

    /// Identity perspective matrix (no transform).
    #[must_use]
    pub fn identity() -> Self {
        Self::from_array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0])
    }

    /// Apply this matrix to a point `(x, y)` and return the projected result.
    ///
    /// Returns `None` if the homogeneous weight `w` is not finite or is too
    /// close to zero (the point maps to infinity).
    #[must_use]
    pub fn project(&self, x: f64, y: f64) -> Option<(f64, f64)> {
        let m = &self.data;
        let x_h = m[0][0] * x + m[0][1] * y + m[0][2];
        let y_h = m[1][0] * x + m[1][1] * y + m[1][2];
        let w = m[2][0] * x + m[2][1] * y + m[2][2];
        // A NaN/infinite weight cannot be divided out meaningfully.
        if !w.is_finite() || w.abs() < Self::MIN_W {
            return None;
        }
        Some((x_h / w, y_h / w))
    }

    /// Compute the inverse of this matrix using the adjugate (Cramer's rule).
    ///
    /// Returns `None` if the matrix is singular, i.e. its determinant is not
    /// finite or its magnitude is below an absolute threshold.
    #[must_use]
    pub fn inverse(&self) -> Option<Self> {
        let m = &self.data;
        let det = m[0][0] * (m[1][1] * m[2][2] - m[1][2] * m[2][1])
            - m[0][1] * (m[1][0] * m[2][2] - m[1][2] * m[2][0])
            + m[0][2] * (m[1][0] * m[2][1] - m[1][1] * m[2][0]);
        // `!is_finite` catches NaN/inf entries, which would otherwise slip past
        // the magnitude test (every comparison with NaN is false).
        if !det.is_finite() || det.abs() < Self::MIN_DET {
            return None;
        }
        let inv_det = 1.0 / det;
        let inv = [
            [
                (m[1][1] * m[2][2] - m[1][2] * m[2][1]) * inv_det,
                (m[0][2] * m[2][1] - m[0][1] * m[2][2]) * inv_det,
                (m[0][1] * m[1][2] - m[0][2] * m[1][1]) * inv_det,
            ],
            [
                (m[1][2] * m[2][0] - m[1][0] * m[2][2]) * inv_det,
                (m[0][0] * m[2][2] - m[0][2] * m[2][0]) * inv_det,
                (m[0][2] * m[1][0] - m[0][0] * m[1][2]) * inv_det,
            ],
            [
                (m[1][0] * m[2][1] - m[1][1] * m[2][0]) * inv_det,
                (m[0][1] * m[2][0] - m[0][0] * m[2][1]) * inv_det,
                (m[0][0] * m[1][1] - m[0][1] * m[1][0]) * inv_det,
            ],
        ];
        Some(Self { data: inv })
    }
}

impl Default for PerspectiveMatrix {
    /// The identity transform.
    fn default() -> Self {
        Self::identity()
    }
}
476
/// Brown-Conrady lens model: radial (`k1`–`k3`) and tangential (`p1`, `p2`)
/// distortion coefficients plus pinhole intrinsics.
///
/// The coefficient layout matches the model OpenCV uses, so values produced
/// by standard camera-calibration / photogrammetry tools can be used as-is.
#[derive(Debug, Clone, Copy)]
pub struct LensDistortionParams {
    /// First radial coefficient k₁ (usually small, roughly -0.3 to +0.5).
    pub k1: f64,
    /// Second radial coefficient k₂.
    pub k2: f64,
    /// Third radial coefficient k₃.
    pub k3: f64,
    /// First tangential coefficient p₁.
    pub p1: f64,
    /// Second tangential coefficient p₂.
    pub p2: f64,
    /// Horizontal focal length, in pixels.
    pub fx: f64,
    /// Vertical focal length, in pixels.
    pub fy: f64,
    /// Principal-point x coordinate in pixels (commonly `width / 2`).
    pub cx: f64,
    /// Principal-point y coordinate in pixels (commonly `height / 2`).
    pub cy: f64,
}

impl LensDistortionParams {
    /// Distortion-free model for a `width × height` image.
    ///
    /// All coefficients are zero, the focal lengths equal the image width and
    /// height, and the principal point sits at the image centre.
    #[must_use]
    pub fn no_distortion(width: u32, height: u32) -> Self {
        let w = f64::from(width);
        let h = f64::from(height);
        let (k1, k2, k3, p1, p2) = (0.0, 0.0, 0.0, 0.0, 0.0);
        Self {
            k1,
            k2,
            k3,
            p1,
            p2,
            fx: w,
            fy: h,
            cx: w / 2.0,
            cy: h / 2.0,
        }
    }
}
521
522/// CPU-parallel perspective warp of a packed RGBA image.
523///
524/// Uses inverse mapping with bilinear interpolation: for each destination
525/// pixel `(dx, dy)` the inverse homography maps it back to the source
526/// coordinates `(sx, sy)`, which are bilinearly sampled.
527///
528/// Pixels that map outside the source image are filled with `fill_rgba`.
529///
530/// # Errors
531///
532/// Returns [`crate::GpuError::InvalidDimensions`] for zero dimensions or
533/// [`crate::GpuError::InvalidBufferSize`] for buffer/dimension mismatches.
534pub fn perspective_warp(
535    input: &[u8],
536    src_width: u32,
537    src_height: u32,
538    output: &mut [u8],
539    dst_width: u32,
540    dst_height: u32,
541    matrix: &PerspectiveMatrix,
542    fill_rgba: [u8; 4],
543) -> crate::Result<()> {
544    use super::utils;
545    use crate::GpuError;
546
547    if src_width == 0 || src_height == 0 {
548        return Err(GpuError::InvalidDimensions {
549            width: src_width,
550            height: src_height,
551        });
552    }
553    if dst_width == 0 || dst_height == 0 {
554        return Err(GpuError::InvalidDimensions {
555            width: dst_width,
556            height: dst_height,
557        });
558    }
559    utils::validate_buffer_size(input, src_width, src_height, 4)?;
560    utils::validate_buffer_size(output, dst_width, dst_height, 4)?;
561
562    let inv = matrix
563        .inverse()
564        .ok_or_else(|| GpuError::Internal("Perspective matrix is singular".to_string()))?;
565
566    let sw = src_width as usize;
567    let sh = src_height as usize;
568    let dw = dst_width as usize;
569    let dh = dst_height as usize;
570
571    for dy in 0..dh {
572        for dx in 0..dw {
573            let dst_idx = (dy * dw + dx) * 4;
574            let Some((sx_f, sy_f)) = inv.project(dx as f64, dy as f64) else {
575                output[dst_idx..dst_idx + 4].copy_from_slice(&fill_rgba);
576                continue;
577            };
578
579            // Bilinear interpolation
580            let x0 = sx_f.floor() as isize;
581            let y0 = sy_f.floor() as isize;
582            let x1 = x0 + 1;
583            let y1 = y0 + 1;
584            let fx = sx_f - sx_f.floor();
585            let fy = sy_f - sy_f.floor();
586
587            let sample = |cx: isize, cy: isize| -> [f64; 4] {
588                if cx < 0 || cy < 0 || cx >= sw as isize || cy >= sh as isize {
589                    [
590                        fill_rgba[0] as f64,
591                        fill_rgba[1] as f64,
592                        fill_rgba[2] as f64,
593                        fill_rgba[3] as f64,
594                    ]
595                } else {
596                    let idx = (cy as usize * sw + cx as usize) * 4;
597                    [
598                        input[idx] as f64,
599                        input[idx + 1] as f64,
600                        input[idx + 2] as f64,
601                        input[idx + 3] as f64,
602                    ]
603                }
604            };
605
606            let p00 = sample(x0, y0);
607            let p10 = sample(x1, y0);
608            let p01 = sample(x0, y1);
609            let p11 = sample(x1, y1);
610
611            for c in 0..4 {
612                let v = p00[c] * (1.0 - fx) * (1.0 - fy)
613                    + p10[c] * fx * (1.0 - fy)
614                    + p01[c] * (1.0 - fx) * fy
615                    + p11[c] * fx * fy;
616                output[dst_idx + c] = v.round().clamp(0.0, 255.0) as u8;
617            }
618        }
619    }
620
621    Ok(())
622}
623
624/// CPU-side lens distortion correction using the Brown-Conrady model.
625///
626/// For each destination pixel `(x, y)` the distortion model computes the
627/// corresponding distorted source coordinate and bilinearly samples the input.
628/// Pixels that map outside the source image are filled with `fill_rgba`.
629///
630/// # Errors
631///
632/// Returns an error if dimensions are zero or buffers are the wrong size.
633pub fn lens_undistort(
634    input: &[u8],
635    width: u32,
636    height: u32,
637    output: &mut [u8],
638    params: &LensDistortionParams,
639    fill_rgba: [u8; 4],
640) -> crate::Result<()> {
641    use super::utils;
642    use crate::GpuError;
643
644    if width == 0 || height == 0 {
645        return Err(GpuError::InvalidDimensions { width, height });
646    }
647    utils::validate_buffer_size(input, width, height, 4)?;
648    utils::validate_buffer_size(output, width, height, 4)?;
649
650    let w = width as usize;
651    let h = height as usize;
652    let inv_fx = 1.0 / params.fx;
653    let inv_fy = 1.0 / params.fy;
654
655    for dy in 0..h {
656        for dx in 0..w {
657            // Normalised coordinates (undistorted space).
658            let x_u = (dx as f64 - params.cx) * inv_fx;
659            let y_u = (dy as f64 - params.cy) * inv_fy;
660
661            // Apply Brown-Conrady radial + tangential distortion to map from
662            // undistorted → distorted (where the actual sensor data lives).
663            let r2 = x_u * x_u + y_u * y_u;
664            let r4 = r2 * r2;
665            let r6 = r4 * r2;
666            let radial = 1.0 + params.k1 * r2 + params.k2 * r4 + params.k3 * r6;
667            let x_d =
668                x_u * radial + 2.0 * params.p1 * x_u * y_u + params.p2 * (r2 + 2.0 * x_u * x_u);
669            let y_d =
670                y_u * radial + params.p1 * (r2 + 2.0 * y_u * y_u) + 2.0 * params.p2 * x_u * y_u;
671
672            // Back to pixel coordinates in the distorted (source) image.
673            let sx_f = x_d * params.fx + params.cx;
674            let sy_f = y_d * params.fy + params.cy;
675
676            let dst_idx = (dy * w + dx) * 4;
677
678            let x0 = sx_f.floor() as isize;
679            let y0 = sy_f.floor() as isize;
680            let x1 = x0 + 1;
681            let y1 = y0 + 1;
682            let fx = sx_f - sx_f.floor();
683            let fy = sy_f - sy_f.floor();
684
685            let sample = |cx: isize, cy: isize| -> [f64; 4] {
686                if cx < 0 || cy < 0 || cx >= w as isize || cy >= h as isize {
687                    [
688                        fill_rgba[0] as f64,
689                        fill_rgba[1] as f64,
690                        fill_rgba[2] as f64,
691                        fill_rgba[3] as f64,
692                    ]
693                } else {
694                    let idx = (cy as usize * w + cx as usize) * 4;
695                    [
696                        input[idx] as f64,
697                        input[idx + 1] as f64,
698                        input[idx + 2] as f64,
699                        input[idx + 3] as f64,
700                    ]
701                }
702            };
703
704            let p00 = sample(x0, y0);
705            let p10 = sample(x1, y0);
706            let p01 = sample(x0, y1);
707            let p11 = sample(x1, y1);
708
709            for c in 0..4 {
710                let v = p00[c] * (1.0 - fx) * (1.0 - fy)
711                    + p10[c] * fx * (1.0 - fy)
712                    + p01[c] * (1.0 - fx) * fy
713                    + p11[c] * fx * fy;
714                output[dst_idx + c] = v.round().clamp(0.0, 255.0) as u8;
715            }
716        }
717    }
718
719    Ok(())
720}
721
// =============================================================================
// Tests for perspective transform and lens distortion
// =============================================================================
725
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w × h` RGBA8 image in which every pixel is `[r, g, b, a]`.
    fn solid_rgba(w: u32, h: u32, r: u8, g: u8, b: u8, a: u8) -> Vec<u8> {
        [r, g, b, a].repeat((w * h) as usize)
    }

    /// Builds a `w × h` RGBA8 image whose colour encodes each pixel's
    /// position, so a misplaced sample is detectable.
    fn gradient_rgba(w: u32, h: u32) -> Vec<u8> {
        (0..h)
            .flat_map(|y| {
                (0..w).flat_map(move |x| {
                    [(x * 16) as u8, (y * 16) as u8, ((x + y) * 8) as u8, 255]
                })
            })
            .collect()
    }

    /// Returns the 4 bytes of pixel `(x, y)` in a packed RGBA8 image of width `w`.
    fn pixel(img: &[u8], w: u32, x: u32, y: u32) -> &[u8] {
        let i = ((y * w + x) * 4) as usize;
        &img[i..i + 4]
    }

    // ── PerspectiveMatrix ─────────────────────────────────────────────────────

    #[test]
    fn test_perspective_identity_project() {
        let m = PerspectiveMatrix::identity();
        let (x, y) = m
            .project(100.0, 200.0)
            .expect("identity must not return None");
        assert!((x - 100.0).abs() < 1e-10, "x={x}");
        assert!((y - 200.0).abs() < 1e-10, "y={y}");
    }

    #[test]
    fn test_perspective_translation() {
        // Pure translation: shift by (10, 20).
        let m = PerspectiveMatrix::from_array([1.0, 0.0, 10.0, 0.0, 1.0, 20.0, 0.0, 0.0, 1.0]);
        let (x, y) = m.project(5.0, 5.0).expect("no infinity");
        assert!((x - 15.0).abs() < 1e-10, "x={x}");
        assert!((y - 25.0).abs() < 1e-10, "y={y}");
    }

    #[test]
    fn test_perspective_inverse_is_correct() {
        let m = PerspectiveMatrix::from_array([1.0, 0.5, 10.0, -0.2, 1.0, 5.0, 0.001, 0.0, 1.0]);
        let inv = m.inverse().expect("non-singular matrix must have inverse");
        // Projecting forward through `m` and back through `inv` must round-trip.
        let (x_orig, y_orig) = (50.0_f64, 30.0_f64);
        let (x_proj, y_proj) = m.project(x_orig, y_orig).expect("forward project");
        let (x_back, y_back) = inv.project(x_proj, y_proj).expect("inverse project");
        assert!(
            (x_back - x_orig).abs() < 1e-6,
            "x roundtrip: {x_back} ≠ {x_orig}"
        );
        assert!(
            (y_back - y_orig).abs() < 1e-6,
            "y roundtrip: {y_back} ≠ {y_orig}"
        );
    }

    #[test]
    fn test_perspective_singular_returns_none_inverse() {
        // All-zero matrix is singular.
        let m = PerspectiveMatrix::from_array([0.0; 9]);
        assert!(m.inverse().is_none(), "singular matrix must return None");
    }

    // ── perspective_warp ──────────────────────────────────────────────────────

    #[test]
    fn test_perspective_warp_identity_preserves_image() {
        let (w, h) = (8u32, 8u32);
        let src = gradient_rgba(w, h);
        let mut dst = vec![0u8; src.len()];
        perspective_warp(
            &src,
            w,
            h,
            &mut dst,
            w,
            h,
            &PerspectiveMatrix::identity(),
            [0, 0, 0, 0],
        )
        .expect("identity warp must succeed");
        // The identity lands every pixel exactly on itself: no interpolation.
        assert_eq!(dst, src, "identity warp mismatch");
    }

    #[test]
    fn test_perspective_warp_translation_shifts_pixels() {
        let (w, h) = (4u32, 3u32);
        let src = gradient_rgba(w, h);
        let mut dst = vec![0u8; src.len()];
        // Shift right by one pixel: dst(x, y) = src(x - 1, y).
        let m = PerspectiveMatrix::from_array([1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]);
        let fill = [9u8, 8, 7, 6];
        perspective_warp(&src, w, h, &mut dst, w, h, &m, fill).expect("warp must succeed");
        for y in 0..h {
            // Column 0 samples x = -1, which lies outside the source.
            assert_eq!(pixel(&dst, w, 0, y), &fill[..], "pixel (0, {y})");
            for x in 1..w {
                assert_eq!(
                    pixel(&dst, w, x, y),
                    pixel(&src, w, x - 1, y),
                    "pixel ({x}, {y})"
                );
            }
        }
    }

    #[test]
    fn test_perspective_warp_out_of_bounds_uses_fill() {
        let w = 4u32;
        let h = 4u32;
        let src = solid_rgba(w, h, 255, 0, 0, 255);
        let mut dst = vec![0u8; src.len()];
        // Large translation sends all destination pixels outside the source.
        let m =
            PerspectiveMatrix::from_array([1.0, 0.0, 10000.0, 0.0, 1.0, 10000.0, 0.0, 0.0, 1.0]);
        let fill = [0u8, 255, 0, 255];
        perspective_warp(&src, w, h, &mut dst, w, h, &m, fill).expect("warp must succeed");
        for px in dst.chunks_exact(4) {
            assert_eq!(px, &fill[..], "every pixel must be the fill colour");
        }
    }

    #[test]
    fn test_perspective_warp_invalid_dims_return_error() {
        let src = solid_rgba(4, 4, 0, 0, 0, 255);
        let mut dst = vec![0u8; 16 * 4];
        let result = perspective_warp(
            &src,
            0,
            4,
            &mut dst,
            4,
            4,
            &PerspectiveMatrix::identity(),
            [0; 4],
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_perspective_warp_singular_matrix_returns_error() {
        let src = solid_rgba(4, 4, 0, 0, 0, 255);
        let mut dst = vec![0u8; src.len()];
        let singular = PerspectiveMatrix::from_array([0.0; 9]);
        let result = perspective_warp(&src, 4, 4, &mut dst, 4, 4, &singular, [0; 4]);
        assert!(result.is_err(), "singular matrix must be rejected");
    }

    // ── lens_undistort ────────────────────────────────────────────────────────

    #[test]
    fn test_lens_undistort_no_distortion_identity() {
        let (w, h) = (8u32, 8u32);
        let src = gradient_rgba(w, h);
        let mut dst = vec![0u8; src.len()];
        let params = LensDistortionParams::no_distortion(w, h);
        lens_undistort(&src, w, h, &mut dst, &params, [0; 4]).expect("no distortion must succeed");
        // With every coefficient zero the model is the identity map, so every
        // pixel (including borders) must reproduce the source.
        for (s, d) in src.iter().zip(&dst) {
            assert!(
                (i32::from(*s) - i32::from(*d)).unsigned_abs() <= 1,
                "identity undistort mismatch"
            );
        }
    }

    #[test]
    fn test_lens_undistort_preserves_centre_pixel() {
        // Put the principal point exactly on pixel (4, 4) of a 9×9 image: there
        // r² = 0, so radial distortion maps that pixel onto itself.
        let (w, h) = (9u32, 9u32);
        let (cx, cy) = (w / 2, h / 2);
        let marked = [255u8, 128, 64, 255];
        let mut src = vec![0u8; (w * h * 4) as usize];
        let centre_idx = ((cy * w + cx) * 4) as usize;
        src[centre_idx..centre_idx + 4].copy_from_slice(&marked);
        let mut dst = vec![0u8; src.len()];
        let params = LensDistortionParams {
            k1: 0.1,
            k2: 0.0,
            k3: 0.0,
            p1: 0.0,
            p2: 0.0,
            fx: f64::from(w),
            fy: f64::from(h),
            cx: f64::from(cx),
            cy: f64::from(cy),
        };
        lens_undistort(&src, w, h, &mut dst, &params, [0; 4]).expect("undistort must succeed");
        assert_eq!(
            &dst[centre_idx..centre_idx + 4],
            &marked[..],
            "centre pixel must map onto itself"
        );
    }

    #[test]
    fn test_lens_undistort_invalid_dims_return_error() {
        let src = vec![0u8; 64];
        let mut dst = vec![0u8; 64];
        let params = LensDistortionParams::no_distortion(4, 4);
        let result = lens_undistort(&src, 0, 4, &mut dst, &params, [0; 4]);
        assert!(result.is_err());
    }
}