1use crate::{
4 shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5 GpuDevice, GpuError, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Uniform block bound at binding 2 of the transform bind group.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets `bytemuck::bytes_of` upload the struct
/// as raw bytes, so field order and types must match the shader-side declaration (not
/// visible in this file — keep the two in sync). Eight `u32`s make it 32 bytes;
/// `padding1`/`padding2` presumably round the size up to a multiple of 16 bytes for
/// uniform-buffer layout rules — confirm against the shader.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct TransformParams {
    /// Plane width in `f32` samples.
    width: u32,
    /// Plane height in `f32` samples.
    height: u32,
    /// Transform length: 8 for the tiled 8×8 kernels, the plane width (row pass) or
    /// height (column pass) for the general DCT.
    block_size: u32,
    /// 0 for the forward transforms, 1 when issued by `idct_2d`.
    transform_type: u32,
    /// Row stride in samples; always filled in equal to `width`.
    stride: u32,
    /// Inverse flag. NOTE(review): always written as 0 by `execute_transform`, even for the
    /// IDCT — confirm whether the shader reads it.
    is_inverse: u32,
    /// Unused; zero-filled.
    padding1: u32,
    /// Unused; zero-filled.
    padding2: u32,
}
25
/// Namespace for the GPU 2-D DCT entry points (`dct_2d`, `idct_2d`, `dct_2d_general`).
///
/// Carries no state itself: compiled pipelines and the bind-group layout are cached in
/// process-wide statics inside its `impl`.
pub struct TransformOperation;
28
29impl TransformOperation {
30 pub fn dct_2d(
47 device: &GpuDevice,
48 input: &[f32],
49 output: &mut [f32],
50 width: u32,
51 height: u32,
52 ) -> Result<()> {
53 if width % 8 != 0 || height % 8 != 0 {
54 return Err(GpuError::InvalidDimensions { width, height });
55 }
56
57 utils::validate_dimensions(width, height)?;
58
59 let expected_size = (width * height) as usize;
60 if input.len() < expected_size || output.len() < expected_size {
61 return Err(GpuError::InvalidBufferSize {
62 expected: expected_size,
63 actual: input.len().min(output.len()),
64 });
65 }
66
67 let pipeline = Self::get_dct_8x8_pipeline(device)?;
68 let layout = Self::get_bind_group_layout(device)?;
69
70 Self::execute_transform(
71 device, pipeline, layout, input, output, width, height, 8, 0, )
73 }
74
75 pub fn idct_2d(
92 device: &GpuDevice,
93 input: &[f32],
94 output: &mut [f32],
95 width: u32,
96 height: u32,
97 ) -> Result<()> {
98 if width % 8 != 0 || height % 8 != 0 {
99 return Err(GpuError::InvalidDimensions { width, height });
100 }
101
102 utils::validate_dimensions(width, height)?;
103
104 let expected_size = (width * height) as usize;
105 if input.len() < expected_size || output.len() < expected_size {
106 return Err(GpuError::InvalidBufferSize {
107 expected: expected_size,
108 actual: input.len().min(output.len()),
109 });
110 }
111
112 let pipeline = Self::get_idct_8x8_pipeline(device)?;
113 let layout = Self::get_bind_group_layout(device)?;
114
115 Self::execute_transform(
116 device, pipeline, layout, input, output, width, height, 8, 1, )
118 }
119
120 pub fn dct_2d_general(
136 device: &GpuDevice,
137 input: &[f32],
138 output: &mut [f32],
139 width: u32,
140 height: u32,
141 ) -> Result<()> {
142 utils::validate_dimensions(width, height)?;
143
144 let expected_size = (width * height) as usize;
145 if input.len() < expected_size || output.len() < expected_size {
146 return Err(GpuError::InvalidBufferSize {
147 expected: expected_size,
148 actual: input.len().min(output.len()),
149 });
150 }
151
152 let mut temp = vec![0.0f32; expected_size];
154
155 let row_pipeline = Self::get_dct_row_pipeline(device)?;
157 let layout = Self::get_bind_group_layout(device)?;
158
159 Self::execute_transform(
160 device,
161 row_pipeline,
162 layout,
163 input,
164 &mut temp,
165 width,
166 height,
167 width,
168 0,
169 )?;
170
171 let col_pipeline = Self::get_dct_col_pipeline(device)?;
173
174 Self::execute_transform(
175 device,
176 col_pipeline,
177 layout,
178 &temp,
179 output,
180 width,
181 height,
182 height,
183 0,
184 )
185 }
186
187 #[allow(clippy::too_many_arguments)]
188 fn execute_transform(
189 device: &GpuDevice,
190 pipeline: &ComputePipeline,
191 layout: &BindGroupLayout,
192 input: &[f32],
193 output: &mut [f32],
194 width: u32,
195 height: u32,
196 block_size: u32,
197 transform_type: u32,
198 ) -> Result<()> {
199 let input_bytes = bytemuck::cast_slice(input);
200 let output_size = std::mem::size_of_val(output);
201
202 let input_buffer = utils::create_storage_buffer(device, input_bytes.len() as u64)?;
204 let output_buffer = utils::create_storage_buffer(device, output_size as u64)?;
205
206 device
208 .queue()
209 .write_buffer(input_buffer.buffer(), 0, input_bytes);
210
211 let params = TransformParams {
213 width,
214 height,
215 block_size,
216 transform_type,
217 stride: width,
218 is_inverse: 0,
219 padding1: 0,
220 padding2: 0,
221 };
222 let params_bytes = bytemuck::bytes_of(¶ms);
223 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
224
225 let compiler = ShaderCompiler::new(device);
227 let bind_group = compiler.create_bind_group(
228 "Transform Bind Group",
229 layout,
230 &[
231 wgpu::BindGroupEntry {
232 binding: 0,
233 resource: input_buffer.buffer().as_entire_binding(),
234 },
235 wgpu::BindGroupEntry {
236 binding: 1,
237 resource: output_buffer.buffer().as_entire_binding(),
238 },
239 wgpu::BindGroupEntry {
240 binding: 2,
241 resource: params_buffer.buffer().as_entire_binding(),
242 },
243 ],
244 );
245
246 Self::dispatch_compute(device, pipeline, &bind_group, width, height, block_size)?;
248
249 let readback_buffer = utils::create_readback_buffer(device, output_size as u64)?;
251 let mut encoder = device
252 .device()
253 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
254 label: Some("Transform Copy Encoder"),
255 });
256
257 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output_size as u64)?;
258
259 device.queue().submit(Some(encoder.finish()));
260 device.wait();
261
262 let result = readback_buffer.read(device, 0, output_size as u64)?;
263 let result_f32: &[f32] = bytemuck::cast_slice(&result);
264 output.copy_from_slice(result_f32);
265
266 Ok(())
267 }
268
269 fn dispatch_compute(
270 device: &GpuDevice,
271 pipeline: &ComputePipeline,
272 bind_group: &BindGroup,
273 width: u32,
274 height: u32,
275 block_size: u32,
276 ) -> Result<()> {
277 let mut encoder = device
278 .device()
279 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
280 label: Some("Transform Compute Encoder"),
281 });
282
283 {
284 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
285 label: Some("Transform Compute Pass"),
286 timestamp_writes: None,
287 });
288
289 compute_pass.set_pipeline(pipeline);
290 compute_pass.set_bind_group(0, bind_group, &[]);
291
292 if block_size == 8 {
293 let dispatch_x = width / 8;
295 let dispatch_y = height / 8;
296 compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
297 } else {
298 let total_elements = width * height;
300 let dispatch = total_elements.div_ceil(256);
301 compute_pass.dispatch_workgroups(dispatch, 1, 1);
302 }
303 }
304
305 device.queue().submit(Some(encoder.finish()));
306 Ok(())
307 }
308
309 fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
310 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
311
312 Ok(LAYOUT.get_or_init(|| {
313 let compiler = ShaderCompiler::new(device);
314 let entries = BindGroupLayoutBuilder::new()
315 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .build();
319
320 compiler.create_bind_group_layout("Transform Bind Group Layout", &entries)
321 }))
322 }
323
324 fn init_pipeline(
325 device: &GpuDevice,
326 name: &str,
327 entry_point: &str,
328 ) -> std::result::Result<ComputePipeline, String> {
329 let compiler = ShaderCompiler::new(device);
330 let shader = compiler
331 .compile(
332 "Transform Shader",
333 ShaderSource::Embedded(crate::shader::embedded::TRANSFORM_SHADER),
334 )
335 .map_err(|e| format!("Failed to compile transform shader: {e}"))?;
336
337 let layout = Self::get_bind_group_layout(device)
338 .map_err(|e| format!("Failed to create bind group layout: {e}"))?;
339
340 compiler
341 .create_pipeline(name, &shader, entry_point, layout)
342 .map_err(|e| format!("Failed to create pipeline: {e}"))
343 }
344
345 fn get_dct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
346 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
347
348 PIPELINE
349 .get_or_init(|| {
350 TransformOperation::init_pipeline(device, "DCT 8x8 Pipeline", "dct_8x8")
351 })
352 .as_ref()
353 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
354 }
355
356 fn get_idct_8x8_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
357 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
358
359 PIPELINE
360 .get_or_init(|| {
361 TransformOperation::init_pipeline(device, "IDCT 8x8 Pipeline", "idct_8x8")
362 })
363 .as_ref()
364 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
365 }
366
367 fn get_dct_row_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
368 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
369
370 PIPELINE
371 .get_or_init(|| {
372 TransformOperation::init_pipeline(device, "DCT Row Pipeline", "dct_row")
373 })
374 .as_ref()
375 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
376 }
377
378 fn get_dct_col_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
379 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
380
381 PIPELINE
382 .get_or_init(|| {
383 TransformOperation::init_pipeline(device, "DCT Column Pipeline", "dct_col")
384 })
385 .as_ref()
386 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
387 }
388}
389
/// A 3×3 projective transform (homography), stored row-major.
///
/// Points are treated as homogeneous column vectors: `(x, y)` maps to
/// `data · [x, y, 1]ᵀ = [x', y', w]ᵀ`, and the projected point is `(x' / w, y' / w)`.
/// Because of that final division, multiplying every entry by the same non-zero factor
/// describes the same transform.
#[derive(Debug, Clone, Copy)]
pub struct PerspectiveMatrix {
    /// Matrix entries indexed as `data[row][column]`.
    pub data: [[f64; 3]; 3],
}

impl PerspectiveMatrix {
    /// Relative singularity tolerance for [`Self::inverse`]: the matrix is rejected when
    /// `|det|` is at most this fraction of the product of its row lengths (Hadamard's bound,
    /// the largest `|det|` rows of those lengths can produce).
    const SINGULAR_RELATIVE_TOLERANCE: f64 = 1e-12;

    /// Builds a matrix from nine entries given in row-major order.
    #[must_use]
    pub fn from_array(m: [f64; 9]) -> Self {
        Self {
            data: [[m[0], m[1], m[2]], [m[3], m[4], m[5]], [m[6], m[7], m[8]]],
        }
    }

    /// The identity transform, which maps every point to itself.
    #[must_use]
    pub fn identity() -> Self {
        Self::from_array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0])
    }

    /// Applies the transform to `(x, y)`.
    ///
    /// Returns `None` when the homogeneous coordinate `w` is within `1e-12` of zero, i.e. the
    /// point maps to (or very near) infinity.
    #[must_use]
    pub fn project(&self, x: f64, y: f64) -> Option<(f64, f64)> {
        let m = &self.data;
        let x_h = m[0][0] * x + m[0][1] * y + m[0][2];
        let y_h = m[1][0] * x + m[1][1] * y + m[1][2];
        let w = m[2][0] * x + m[2][1] * y + m[2][2];
        if w.abs() < 1e-12 {
            return None;
        }
        Some((x_h / w, y_h / w))
    }

    /// Returns the inverse transform via the adjugate, or `None` when the matrix is
    /// (numerically) singular or its determinant is not finite.
    ///
    /// Singularity is judged relative to the magnitude of the rows rather than with an
    /// absolute threshold, so uniformly rescaling a homography (which leaves the transform
    /// unchanged) does not change whether it is considered invertible.
    #[must_use]
    pub fn inverse(&self) -> Option<Self> {
        let m = &self.data;
        // Cofactors along the first row; they also form the first column of the adjugate.
        let c00 = m[1][1] * m[2][2] - m[1][2] * m[2][1];
        let c01 = m[1][2] * m[2][0] - m[1][0] * m[2][2];
        let c02 = m[1][0] * m[2][1] - m[1][1] * m[2][0];
        let det = m[0][0] * c00 + m[0][1] * c01 + m[0][2] * c02;

        // `hypot` avoids overflow/underflow of the intermediate squares.
        let row_len = |r: &[f64; 3]| r[0].hypot(r[1]).hypot(r[2]);
        let hadamard_bound = row_len(&m[0]) * row_len(&m[1]) * row_len(&m[2]);
        // `<=` so an all-zero matrix (det == bound == 0) is rejected; NaN/inf fail `is_finite`.
        if !det.is_finite() || det.abs() <= Self::SINGULAR_RELATIVE_TOLERANCE * hadamard_bound {
            return None;
        }

        let inv_det = 1.0 / det;
        let inv = [
            [
                c00 * inv_det,
                (m[0][2] * m[2][1] - m[0][1] * m[2][2]) * inv_det,
                (m[0][1] * m[1][2] - m[0][2] * m[1][1]) * inv_det,
            ],
            [
                c01 * inv_det,
                (m[0][0] * m[2][2] - m[0][2] * m[2][0]) * inv_det,
                (m[0][2] * m[1][0] - m[0][0] * m[1][2]) * inv_det,
            ],
            [
                c02 * inv_det,
                (m[0][1] * m[2][0] - m[0][0] * m[2][1]) * inv_det,
                (m[0][0] * m[1][1] - m[0][1] * m[1][0]) * inv_det,
            ],
        ];
        Some(Self { data: inv })
    }
}

/// The default transform is [`PerspectiveMatrix::identity`].
impl Default for PerspectiveMatrix {
    fn default() -> Self {
        Self::identity()
    }
}
476
/// Pinhole intrinsics plus radial/tangential distortion coefficients used by `lens_undistort`.
#[derive(Debug, Clone, Copy)]
pub struct LensDistortionParams {
    /// Radial coefficient applied to `r²`.
    pub k1: f64,
    /// Radial coefficient applied to `r⁴`.
    pub k2: f64,
    /// Radial coefficient applied to `r⁶`.
    pub k3: f64,
    /// First tangential coefficient.
    pub p1: f64,
    /// Second tangential coefficient.
    pub p2: f64,
    /// Horizontal focal length, in pixels.
    pub fx: f64,
    /// Vertical focal length, in pixels.
    pub fy: f64,
    /// Principal point, x coordinate in pixels.
    pub cx: f64,
    /// Principal point, y coordinate in pixels.
    pub cy: f64,
}

impl LensDistortionParams {
    /// Parameters for a distortion-free lens on a `width`×`height` image.
    ///
    /// Every distortion coefficient is zero, the focal lengths equal the image dimensions
    /// and the principal point is the image centre, so undistortion maps each pixel (up to
    /// floating-point rounding) onto itself.
    #[must_use]
    pub fn no_distortion(width: u32, height: u32) -> Self {
        let (w, h) = (f64::from(width), f64::from(height));
        Self {
            fx: w,
            fy: h,
            cx: w / 2.0,
            cy: h / 2.0,
            k1: 0.0,
            k2: 0.0,
            k3: 0.0,
            p1: 0.0,
            p2: 0.0,
        }
    }
}
521
522pub fn perspective_warp(
535 input: &[u8],
536 src_width: u32,
537 src_height: u32,
538 output: &mut [u8],
539 dst_width: u32,
540 dst_height: u32,
541 matrix: &PerspectiveMatrix,
542 fill_rgba: [u8; 4],
543) -> crate::Result<()> {
544 use super::utils;
545 use crate::GpuError;
546
547 if src_width == 0 || src_height == 0 {
548 return Err(GpuError::InvalidDimensions {
549 width: src_width,
550 height: src_height,
551 });
552 }
553 if dst_width == 0 || dst_height == 0 {
554 return Err(GpuError::InvalidDimensions {
555 width: dst_width,
556 height: dst_height,
557 });
558 }
559 utils::validate_buffer_size(input, src_width, src_height, 4)?;
560 utils::validate_buffer_size(output, dst_width, dst_height, 4)?;
561
562 let inv = matrix
563 .inverse()
564 .ok_or_else(|| GpuError::Internal("Perspective matrix is singular".to_string()))?;
565
566 let sw = src_width as usize;
567 let sh = src_height as usize;
568 let dw = dst_width as usize;
569 let dh = dst_height as usize;
570
571 for dy in 0..dh {
572 for dx in 0..dw {
573 let dst_idx = (dy * dw + dx) * 4;
574 let Some((sx_f, sy_f)) = inv.project(dx as f64, dy as f64) else {
575 output[dst_idx..dst_idx + 4].copy_from_slice(&fill_rgba);
576 continue;
577 };
578
579 let x0 = sx_f.floor() as isize;
581 let y0 = sy_f.floor() as isize;
582 let x1 = x0 + 1;
583 let y1 = y0 + 1;
584 let fx = sx_f - sx_f.floor();
585 let fy = sy_f - sy_f.floor();
586
587 let sample = |cx: isize, cy: isize| -> [f64; 4] {
588 if cx < 0 || cy < 0 || cx >= sw as isize || cy >= sh as isize {
589 [
590 fill_rgba[0] as f64,
591 fill_rgba[1] as f64,
592 fill_rgba[2] as f64,
593 fill_rgba[3] as f64,
594 ]
595 } else {
596 let idx = (cy as usize * sw + cx as usize) * 4;
597 [
598 input[idx] as f64,
599 input[idx + 1] as f64,
600 input[idx + 2] as f64,
601 input[idx + 3] as f64,
602 ]
603 }
604 };
605
606 let p00 = sample(x0, y0);
607 let p10 = sample(x1, y0);
608 let p01 = sample(x0, y1);
609 let p11 = sample(x1, y1);
610
611 for c in 0..4 {
612 let v = p00[c] * (1.0 - fx) * (1.0 - fy)
613 + p10[c] * fx * (1.0 - fy)
614 + p01[c] * (1.0 - fx) * fy
615 + p11[c] * fx * fy;
616 output[dst_idx + c] = v.round().clamp(0.0, 255.0) as u8;
617 }
618 }
619 }
620
621 Ok(())
622}
623
624pub fn lens_undistort(
634 input: &[u8],
635 width: u32,
636 height: u32,
637 output: &mut [u8],
638 params: &LensDistortionParams,
639 fill_rgba: [u8; 4],
640) -> crate::Result<()> {
641 use super::utils;
642 use crate::GpuError;
643
644 if width == 0 || height == 0 {
645 return Err(GpuError::InvalidDimensions { width, height });
646 }
647 utils::validate_buffer_size(input, width, height, 4)?;
648 utils::validate_buffer_size(output, width, height, 4)?;
649
650 let w = width as usize;
651 let h = height as usize;
652 let inv_fx = 1.0 / params.fx;
653 let inv_fy = 1.0 / params.fy;
654
655 for dy in 0..h {
656 for dx in 0..w {
657 let x_u = (dx as f64 - params.cx) * inv_fx;
659 let y_u = (dy as f64 - params.cy) * inv_fy;
660
661 let r2 = x_u * x_u + y_u * y_u;
664 let r4 = r2 * r2;
665 let r6 = r4 * r2;
666 let radial = 1.0 + params.k1 * r2 + params.k2 * r4 + params.k3 * r6;
667 let x_d =
668 x_u * radial + 2.0 * params.p1 * x_u * y_u + params.p2 * (r2 + 2.0 * x_u * x_u);
669 let y_d =
670 y_u * radial + params.p1 * (r2 + 2.0 * y_u * y_u) + 2.0 * params.p2 * x_u * y_u;
671
672 let sx_f = x_d * params.fx + params.cx;
674 let sy_f = y_d * params.fy + params.cy;
675
676 let dst_idx = (dy * w + dx) * 4;
677
678 let x0 = sx_f.floor() as isize;
679 let y0 = sy_f.floor() as isize;
680 let x1 = x0 + 1;
681 let y1 = y0 + 1;
682 let fx = sx_f - sx_f.floor();
683 let fy = sy_f - sy_f.floor();
684
685 let sample = |cx: isize, cy: isize| -> [f64; 4] {
686 if cx < 0 || cy < 0 || cx >= w as isize || cy >= h as isize {
687 [
688 fill_rgba[0] as f64,
689 fill_rgba[1] as f64,
690 fill_rgba[2] as f64,
691 fill_rgba[3] as f64,
692 ]
693 } else {
694 let idx = (cy as usize * w + cx as usize) * 4;
695 [
696 input[idx] as f64,
697 input[idx + 1] as f64,
698 input[idx + 2] as f64,
699 input[idx + 3] as f64,
700 ]
701 }
702 };
703
704 let p00 = sample(x0, y0);
705 let p10 = sample(x1, y0);
706 let p01 = sample(x0, y1);
707 let p11 = sample(x1, y1);
708
709 for c in 0..4 {
710 let v = p00[c] * (1.0 - fx) * (1.0 - fy)
711 + p10[c] * fx * (1.0 - fy)
712 + p01[c] * (1.0 - fx) * fy
713 + p11[c] * fx * fy;
714 output[dst_idx + c] = v.round().clamp(0.0, 255.0) as u8;
715 }
716 }
717 }
718
719 Ok(())
720}
721
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w`×`h` RGBA8 image in which every pixel is `[r, g, b, a]`.
    fn solid_rgba(w: u32, h: u32, r: u8, g: u8, b: u8, a: u8) -> Vec<u8> {
        [r, g, b, a].repeat((w * h) as usize)
    }

    /// Asserts that every channel of `got` is within `tol` of `want`.
    fn assert_all_pixels_near(got: &[u8], want: [u8; 4], tol: u32, what: &str) {
        for (i, px) in got.chunks_exact(4).enumerate() {
            for (c, (&g, &w)) in px.iter().zip(want.iter()).enumerate() {
                assert!(
                    (i32::from(g) - i32::from(w)).unsigned_abs() <= tol,
                    "{what}: pixel {i} channel {c} is {g}, expected {w}"
                );
            }
        }
    }

    // The identity matrix leaves points untouched.
    #[test]
    fn test_perspective_identity_project() {
        let m = PerspectiveMatrix::identity();
        let (x, y) = m
            .project(100.0, 200.0)
            .expect("identity must not return None");
        assert!((x - 100.0).abs() < 1e-10, "x={x}");
        assert!((y - 200.0).abs() < 1e-10, "y={y}");
    }

    #[test]
    fn test_perspective_translation() {
        let m = PerspectiveMatrix::from_array([1.0, 0.0, 10.0, 0.0, 1.0, 20.0, 0.0, 0.0, 1.0]);
        let (x, y) = m.project(5.0, 5.0).expect("no infinity");
        assert!((x - 15.0).abs() < 1e-10, "x={x}");
        assert!((y - 25.0).abs() < 1e-10, "y={y}");
    }

    #[test]
    fn test_perspective_inverse_is_correct() {
        let m = PerspectiveMatrix::from_array([1.0, 0.5, 10.0, -0.2, 1.0, 5.0, 0.001, 0.0, 1.0]);
        let inv = m.inverse().expect("non-singular matrix must have inverse");
        let (x_orig, y_orig) = (50.0_f64, 30.0_f64);
        let (x_proj, y_proj) = m.project(x_orig, y_orig).expect("forward project");
        let (x_back, y_back) = inv.project(x_proj, y_proj).expect("inverse project");
        assert!(
            (x_back - x_orig).abs() < 1e-6,
            "x roundtrip: {x_back} ≠ {x_orig}"
        );
        assert!(
            (y_back - y_orig).abs() < 1e-6,
            "y roundtrip: {y_back} ≠ {y_orig}"
        );
    }

    #[test]
    fn test_perspective_singular_returns_none_inverse() {
        let m = PerspectiveMatrix::from_array([0.0; 9]);
        assert!(m.inverse().is_none(), "singular matrix must return None");
    }

    // Warping through the identity must reproduce the source image.
    #[test]
    fn test_perspective_warp_identity_preserves_image() {
        let w = 8u32;
        let h = 8u32;
        let src = solid_rgba(w, h, 100, 150, 200, 255);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        perspective_warp(
            &src,
            w,
            h,
            &mut dst,
            w,
            h,
            &PerspectiveMatrix::identity(),
            [0, 0, 0, 0],
        )
        .expect("identity warp must succeed");
        for (s, d) in src.iter().zip(dst.iter()) {
            assert!(
                (i32::from(*s) - i32::from(*d)).unsigned_abs() <= 1,
                "identity warp mismatch"
            );
        }
    }

    // An integer translation moves a marked pixel by exactly that offset and exposes fill.
    #[test]
    fn test_perspective_warp_integer_translation_moves_pixel() {
        let (w, h) = (4u32, 4u32);
        let mut src = solid_rgba(w, h, 0, 0, 0, 255);
        let marked = (w as usize + 1) * 4; // pixel (1, 1)
        src[marked..marked + 4].copy_from_slice(&[200, 10, 20, 255]);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        let m = PerspectiveMatrix::from_array([1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
        perspective_warp(&src, w, h, &mut dst, w, h, &m, [0; 4]).expect("warp must succeed");

        let moved = (2 * w as usize + 2) * 4; // pixel (2, 2)
        assert_eq!(&dst[moved..moved + 4], [200u8, 10, 20, 255]);
        // Destination (0, 0) maps to source (-1, -1), outside the image.
        assert_eq!(&dst[0..4], [0u8, 0, 0, 0]);
    }

    // A translation far outside the source leaves every destination pixel at the fill colour.
    #[test]
    fn test_perspective_warp_out_of_bounds_uses_fill() {
        let w = 4u32;
        let h = 4u32;
        let src = solid_rgba(w, h, 255, 0, 0, 255);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        let fill = [0u8, 255, 0, 255];
        let m =
            PerspectiveMatrix::from_array([1.0, 0.0, 10000.0, 0.0, 1.0, 10000.0, 0.0, 0.0, 1.0]);
        perspective_warp(&src, w, h, &mut dst, w, h, &m, fill).expect("warp must succeed");
        assert_all_pixels_near(&dst, fill, 0, "out-of-bounds fill");
    }

    #[test]
    fn test_perspective_warp_invalid_dims_return_error() {
        let src = solid_rgba(4, 4, 0, 0, 0, 255);
        let mut dst = vec![0u8; 16 * 4];
        let result = perspective_warp(
            &src,
            0,
            4,
            &mut dst,
            4,
            4,
            &PerspectiveMatrix::identity(),
            [0; 4],
        );
        assert!(result.is_err());
    }

    // With zero distortion every output pixel samples its own source pixel.
    #[test]
    fn test_lens_undistort_no_distortion_identity() {
        let w = 8u32;
        let h = 8u32;
        let src = solid_rgba(w, h, 50, 100, 150, 255);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        let params = LensDistortionParams::no_distortion(w, h);
        lens_undistort(&src, w, h, &mut dst, &params, [0; 4]).expect("no distortion must succeed");
        assert_all_pixels_near(&dst, [50, 100, 150, 255], 2, "no-distortion undistort");
    }

    #[test]
    fn test_lens_undistort_preserves_centre_pixel() {
        let w = 9u32;
        let h = 9u32;
        let mut src = vec![0u8; (w * h * 4) as usize];
        let cx = (w / 2) as usize;
        let cy = (h / 2) as usize;
        let center_idx = (cy * w as usize + cx) * 4;
        src[center_idx..center_idx + 4].copy_from_slice(&[255, 128, 64, 255]);
        let mut dst = vec![0u8; (w * h * 4) as usize];
        let params = LensDistortionParams {
            k1: 0.1,
            k2: 0.0,
            k3: 0.0,
            p1: 0.0,
            p2: 0.0,
            fx: f64::from(w),
            fy: f64::from(h),
            cx: f64::from(w) / 2.0,
            cy: f64::from(h) / 2.0,
        };
        lens_undistort(&src, w, h, &mut dst, &params, [0; 4]).expect("undistort must succeed");
        let out_r = dst[center_idx];
        assert!(
            out_r > 128,
            "centre R should reflect the marked pixel, got {out_r}"
        );
    }

    #[test]
    fn test_lens_undistort_invalid_dims_return_error() {
        let src = vec![0u8; 64];
        let mut dst = vec![0u8; 64];
        let params = LensDistortionParams::no_distortion(4, 4);
        let result = lens_undistort(&src, 0, 4, &mut dst, &params, [0; 4]);
        assert!(result.is_err());
    }
}