use crate::error::{CvError, CvResult};
/// Lens-distortion coefficients in the Brown–Conrady form (radial `k1`..`k3`,
/// tangential `p1`/`p2`), plus an optional focal length.
///
/// The coefficients are applied to coordinates normalised by the focal length
/// relative to the image centre; `r` below is the radius in those normalised units.
/// An all-zero model (see [`DistortionModel::is_identity`]) leaves images unchanged.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DistortionModel {
    /// Radial coefficient multiplying `r²`.
    pub k1: f64,
    /// Radial coefficient multiplying `r⁴`.
    pub k2: f64,
    /// Radial coefficient multiplying `r⁶`.
    pub k3: f64,
    /// First tangential coefficient.
    pub p1: f64,
    /// Second tangential coefficient.
    pub p2: f64,
    /// Focal length in pixels used to normalise coordinates. A value that is not
    /// strictly positive makes [`LensDistortionCorrector`] derive one from the
    /// image size instead.
    pub focal_length: f64,
}
impl Default for DistortionModel {
    /// The identity model: every distortion coefficient is zero and the focal
    /// length is left at `0.0`, i.e. unset, so a corrector derives its own.
    fn default() -> Self {
        const ZERO: f64 = 0.0;
        Self {
            focal_length: ZERO,
            p2: ZERO,
            p1: ZERO,
            k3: ZERO,
            k2: ZERO,
            k1: ZERO,
        }
    }
}
impl DistortionModel {
    /// Absolute tolerance below which a coefficient counts as zero in
    /// [`DistortionModel::is_identity`].
    const IDENTITY_EPSILON: f64 = 1e-10;

    /// Creates a purely radial model from `k1` and `k2`.
    ///
    /// `k3`, both tangential coefficients and the focal length are zero. This is
    /// `const` (like [`DistortionModel::new`]) so models can be declared in
    /// `const`/`static` items.
    #[must_use]
    pub const fn new_radial(k1: f64, k2: f64) -> Self {
        Self::new(k1, k2, 0.0, 0.0, 0.0)
    }

    /// Creates a model from all five distortion coefficients.
    ///
    /// The focal length is left at `0.0` (unset); use
    /// [`DistortionModel::with_focal_length`] to provide one.
    #[must_use]
    pub const fn new(k1: f64, k2: f64, k3: f64, p1: f64, p2: f64) -> Self {
        Self {
            k1,
            k2,
            k3,
            p1,
            p2,
            focal_length: 0.0,
        }
    }

    /// Returns this model with its focal length replaced by `fl`.
    #[must_use]
    pub const fn with_focal_length(mut self, fl: f64) -> Self {
        self.focal_length = fl;
        self
    }

    /// Returns `true` when every distortion coefficient (`k1`, `k2`, `k3`, `p1`,
    /// `p2`) is within `1e-10` of zero.
    ///
    /// The focal length is not considered. A `NaN` coefficient makes the model
    /// non-identity.
    #[must_use]
    pub fn is_identity(&self) -> bool {
        [self.k1, self.k2, self.k3, self.p1, self.p2]
            .iter()
            .all(|c| c.abs() < Self::IDENTITY_EPSILON)
    }
}
/// Undistorts images of one fixed size using a [`DistortionModel`].
///
/// The per-pixel source coordinates are computed once at construction and reused
/// for every image, so a corrector is tied to a single `width` x `height`.
#[derive(Clone)]
pub struct LensDistortionCorrector {
    /// Distortion coefficients being corrected for.
    model: DistortionModel,
    /// Image width in pixels this corrector accepts.
    width: u32,
    /// Image height in pixels this corrector accepts.
    height: u32,
    /// Principal point x coordinate, in pixels.
    cx: f64,
    /// Principal point y coordinate, in pixels.
    cy: f64,
    /// Focal length in pixels used to normalise coordinates.
    focal: f64,
    /// Source sample position `(x, y)` for every destination pixel, row-major.
    remap: Vec<(f64, f64)>,
}
impl LensDistortionCorrector {
    /// Builds a corrector for `width` x `height` images and precomputes its remap
    /// table.
    ///
    /// The principal point is placed at `(width / 2, height / 2)`. When
    /// `model.focal_length` is not strictly positive, the focal length defaults to
    /// the distance from the principal point to the image corner,
    /// `sqrt(cx² + cy²)`.
    #[must_use]
    pub fn new(model: DistortionModel, width: u32, height: u32) -> Self {
        let cx = f64::from(width) / 2.0;
        let cy = f64::from(height) / 2.0;
        let focal = if model.focal_length > 0.0 {
            model.focal_length
        } else {
            (cx * cx + cy * cy).sqrt()
        };
        let mut corrector = Self {
            model,
            width,
            height,
            cx,
            cy,
            focal,
            remap: Vec::new(),
        };
        corrector.build_remap();
        corrector
    }

    /// Fills `remap` with the source position of every destination pixel, in
    /// row-major order.
    fn build_remap(&mut self) {
        let n = (self.width as usize) * (self.height as usize);
        let mut remap = Vec::with_capacity(n);
        for y in 0..self.height {
            let dst_y = f64::from(y);
            remap.extend((0..self.width).map(|x| self.apply_distortion(f64::from(x), dst_y)));
        }
        self.remap = remap;
    }

    /// Maps a destination pixel to the source position it should be sampled from.
    ///
    /// This applies the forward distortion model (radial polynomial
    /// `1 + k1·r² + k2·r⁴ + k3·r⁶` plus tangential `p1`/`p2` terms) to coordinates
    /// normalised by `focal` around `(cx, cy)`. The result is then mapped back to
    /// pixel units.
    fn apply_distortion(&self, dst_x: f64, dst_y: f64) -> (f64, f64) {
        let xn = (dst_x - self.cx) / self.focal;
        let yn = (dst_y - self.cy) / self.focal;
        let r2 = xn * xn + yn * yn;
        let r4 = r2 * r2;
        let r6 = r4 * r2;
        let radial = 1.0 + self.model.k1 * r2 + self.model.k2 * r4 + self.model.k3 * r6;
        let tx = 2.0 * self.model.p1 * xn * yn + self.model.p2 * (r2 + 2.0 * xn * xn);
        let ty = self.model.p1 * (r2 + 2.0 * yn * yn) + 2.0 * self.model.p2 * xn * yn;
        let xd = xn * radial + tx;
        let yd = yn * radial + ty;
        (xd * self.focal + self.cx, yd * self.focal + self.cy)
    }

    /// Undistorts a single-channel, row-major image of one byte per pixel.
    ///
    /// Bytes beyond `width * height` are ignored. Destination pixels whose source
    /// position falls outside the image are set to `0`.
    ///
    /// # Errors
    ///
    /// Returns `CvError::invalid_dimensions` if `width`/`height` differ from the
    /// corrector's. Returns `CvError::insufficient_data` if `image` is shorter than
    /// `width * height`.
    pub fn correct(&self, image: &[u8], width: u32, height: u32) -> CvResult<Vec<u8>> {
        self.correct_channels(image, width, height, 1)
    }

    /// Undistorts a row-major image with three interleaved bytes per pixel.
    ///
    /// Bytes beyond `width * height * 3` are ignored. Destination pixels whose
    /// source position falls outside the image are set to `0` in every channel.
    ///
    /// # Errors
    ///
    /// Returns `CvError::invalid_dimensions` if `width`/`height` differ from the
    /// corrector's. Returns `CvError::insufficient_data` if `image` is shorter than
    /// `width * height * 3`.
    pub fn correct_rgb(&self, image: &[u8], width: u32, height: u32) -> CvResult<Vec<u8>> {
        self.correct_channels(image, width, height, 3)
    }

    /// Shared implementation of [`Self::correct`] / [`Self::correct_rgb`] for
    /// `channels` interleaved bytes per pixel, using bilinear interpolation.
    fn correct_channels(
        &self,
        image: &[u8],
        width: u32,
        height: u32,
        channels: usize,
    ) -> CvResult<Vec<u8>> {
        debug_assert!(channels > 0, "channel count must be non-zero");
        if width != self.width || height != self.height {
            return Err(CvError::invalid_dimensions(width, height));
        }
        let w = width as usize;
        let h = height as usize;
        let expected = w * h * channels;
        if image.len() < expected {
            return Err(CvError::insufficient_data(expected, image.len()));
        }
        let image = &image[..expected];
        if self.model.is_identity() {
            return Ok(image.to_vec());
        }

        // Valid bilinear sample positions form the closed range [0, w-1] x [0, h-1].
        // saturating_sub keeps this well-defined for empty images, whose remap is
        // empty anyway.
        let max_x = w.saturating_sub(1);
        let max_y = h.saturating_sub(1);
        let (max_xf, max_yf) = (max_x as f64, max_y as f64);

        let mut output = vec![0u8; expected];
        for (dst, &(src_x, src_y)) in output.chunks_exact_mut(channels).zip(&self.remap) {
            // Out-of-range (or NaN) positions leave the destination pixel black.
            if !((0.0..=max_xf).contains(&src_x) && (0.0..=max_yf).contains(&src_y)) {
                continue;
            }
            let x0 = src_x.floor() as usize;
            let y0 = src_y.floor() as usize;
            // On the last column/row the far neighbour would be out of bounds; its
            // weight (fx or fy) is exactly 0 there, so clamping is lossless.
            let x1 = (x0 + 1).min(max_x);
            let y1 = (y0 + 1).min(max_y);
            let fx = src_x - x0 as f64;
            let fy = src_y - y0 as f64;
            for (c, out) in dst.iter_mut().enumerate() {
                let at = |x: usize, y: usize| f64::from(image[(y * w + x) * channels + c]);
                let v = at(x0, y0) * (1.0 - fx) * (1.0 - fy)
                    + at(x1, y0) * fx * (1.0 - fy)
                    + at(x0, y1) * (1.0 - fx) * fy
                    + at(x1, y1) * fx * fy;
                *out = v.round().clamp(0.0, 255.0) as u8;
            }
        }
        Ok(output)
    }

    /// Returns the distortion model this corrector was built with.
    #[must_use]
    pub const fn model(&self) -> &DistortionModel {
        &self.model
    }

    /// Returns the `(width, height)` this corrector accepts.
    #[must_use]
    pub const fn dimensions(&self) -> (u32, u32) {
        (self.width, self.height)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A flat mid-grey buffer of `len` bytes, the input used throughout.
    fn grey(len: usize) -> Vec<u8> {
        vec![128u8; len]
    }

    /// Corrects a 100x100 grey image with a radial-only model (`k1`, `k2 = 0`) and
    /// checks that the output still holds one byte per pixel.
    fn assert_radial_keeps_size(k1: f64) {
        const SIDE: u32 = 100;
        let corrector =
            LensDistortionCorrector::new(DistortionModel::new_radial(k1, 0.0), SIDE, SIDE);
        let input = grey((SIDE * SIDE) as usize);
        let output = corrector.correct(&input, SIDE, SIDE).expect("should succeed");
        assert_eq!(output.len(), 100 * 100);
    }

    #[test]
    fn test_distortion_model_default_is_identity() {
        assert!(DistortionModel::default().is_identity());
    }

    #[test]
    fn test_distortion_model_radial() {
        let radial = DistortionModel::new_radial(-0.2, 0.05);
        assert!(!radial.is_identity());
        assert_eq!(radial.k1, -0.2);
        assert_eq!(radial.k2, 0.05);
    }

    #[test]
    fn test_corrector_identity() {
        let input = grey(100 * 100);
        let output = LensDistortionCorrector::new(DistortionModel::default(), 100, 100)
            .correct(&input, 100, 100)
            .expect("should succeed");
        assert_eq!(output, input);
    }

    #[test]
    fn test_corrector_barrel() {
        assert_radial_keeps_size(0.1);
    }

    #[test]
    fn test_corrector_pincushion() {
        assert_radial_keeps_size(-0.1);
    }

    #[test]
    fn test_corrector_dimension_mismatch() {
        let corrector = LensDistortionCorrector::new(DistortionModel::default(), 100, 100);
        let oversized = grey(200 * 200);
        assert!(corrector.correct(&oversized, 200, 200).is_err());
    }

    #[test]
    fn test_corrector_rgb() {
        let corrector =
            LensDistortionCorrector::new(DistortionModel::new_radial(0.05, 0.0), 50, 50);
        let rgb = corrector
            .correct_rgb(&grey(50 * 50 * 3), 50, 50)
            .expect("should succeed");
        assert_eq!(rgb.len(), 50 * 50 * 3);
    }

    #[test]
    fn test_center_pixel_unchanged() {
        let corrector =
            LensDistortionCorrector::new(DistortionModel::new_radial(0.1, 0.0), 100, 100);
        let (sx, sy) = corrector.apply_distortion(50.0, 50.0);
        for coord in [sx, sy] {
            assert!((coord - 50.0).abs() < 1.0);
        }
    }

    #[test]
    fn test_with_focal_length() {
        let model = DistortionModel::new_radial(0.1, 0.0).with_focal_length(500.0);
        assert_eq!(model.focal_length, 500.0);
    }

    #[test]
    fn test_corrector_dimensions() {
        let corrector = LensDistortionCorrector::new(DistortionModel::default(), 640, 480);
        assert_eq!(corrector.dimensions(), (640, 480));
    }
}