#![doc = include_str!("dehaze.md")]
use std::collections::VecDeque;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
/// Largest meaningful [`DehazeParams::amount`]: strongest haze removal.
pub const DEHAZE_AMOUNT_MAX: f32 = 100.0;
/// Smallest meaningful [`DehazeParams::amount`]: strongest haze addition.
pub const DEHAZE_AMOUNT_MIN: f32 = -DEHAZE_AMOUNT_MAX;
/// User-facing parameters for [`apply_dehaze`].
#[cfg_attr(feature = "docgen", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DehazeParams {
    /// Haze adjustment strength, nominally in `[DEHAZE_AMOUNT_MIN, DEHAZE_AMOUNT_MAX]`
    /// (`-100.0..=100.0`).
    ///
    /// Positive values remove haze, negative values add it, and `0.0` leaves the image
    /// unchanged. `apply_dehaze` saturates the effect at `±100`. When the field is missing
    /// from serialized input it deserializes to `0.0` (`f32::default()`).
    #[serde(default)]
    #[cfg_attr(feature = "docgen", schemars(range(min = -100.0, max = 100.0)))]
    pub amount: f32,
}
impl Default for DehazeParams {
fn default() -> Self {
Self { amount: 0.0 }
}
}
impl DehazeParams {
    /// Reports whether these parameters leave an image unchanged.
    ///
    /// Only an `amount` that compares equal to zero counts as neutral; that includes `-0.0`
    /// but not NaN.
    pub fn is_neutral(&self) -> bool {
        let Self { amount, .. } = self;
        *amount == 0.0
    }
}
/// Side length, in pixels, of the square minimum-filter window used for the dark channel:
/// the centre pixel plus 7 pixels on either side.
const PATCH_SIZE: usize = 2 * 7 + 1;
/// Fraction of all pixels (0.1 %) with the largest dark-channel values that are examined
/// when estimating the airlight; at least one pixel is always examined.
const AIRLIGHT_PERCENTILE: f64 = 1e-3;
/// A raw `*mut f32` that may be moved into, and shared between, rayon tasks.
///
/// Used by the column passes of `dark_channel` and `box_filter_2d` so that many parallel
/// tasks can write into one output `Vec<f32>`, each task writing a different column.
///
/// Soundness is entirely the caller's responsibility. The pointee must stay alive and must
/// not be otherwise accessed while tasks write through the pointer. No two tasks may touch
/// the same element.
#[derive(Clone, Copy)]
struct UnsafeSlicePtr(*mut f32);
// SAFETY: the wrapper only carries an address. Race freedom of the writes made through it is
// upheld by callers writing to disjoint indices (see the type-level docs).
unsafe impl Send for UnsafeSlicePtr {}
// SAFETY: a shared reference only yields copies of the address; see the `Send` note above.
unsafe impl Sync for UnsafeSlicePtr {}
impl UnsafeSlicePtr {
    /// Returns the wrapped pointer.
    ///
    /// Closures should call this instead of reading `.0`. Under Rust 2021 disjoint closure
    /// captures, `.0` would capture only the bare `*mut f32`, which is neither `Send` nor
    /// `Sync`. A method call captures the whole `Send + Sync` wrapper instead.
    fn ptr(self) -> *mut f32 {
        self.0
    }
}
/// Sliding-window minimum over a 1-D signal.
///
/// Output `j` is the minimum of `data[j - half ..= j + half]`, clipped to the slice, where
/// `half = window_size / 2`. The effective window is therefore `2 * half + 1` samples wide.
/// That equals `window_size` only for odd sizes; an even size behaves like the next odd one.
///
/// Runs in O(n) using a monotonic deque of candidate indices. For NaN-free input, their
/// values increase from front to back. Returns an empty vector for empty input.
fn min_filter_1d(data: &[f32], window_size: usize) -> Vec<f32> {
    if data.is_empty() {
        return Vec::new();
    }
    let last = data.len() - 1;
    let radius = window_size / 2;
    let mut minima = vec![0.0_f32; data.len()];
    let mut window: VecDeque<usize> = VecDeque::new();
    // Index of the next sample that has not yet entered the window.
    let mut next = 0;
    for (center, out) in minima.iter_mut().enumerate() {
        // Admit samples up to the window's right edge. A queued candidate that is not smaller
        // than the newcomer can never be a minimum again, so it is dropped first.
        let right = (center + radius).min(last);
        while next <= right {
            let incoming = data[next];
            while window.back().map_or(false, |&tail| data[tail] >= incoming) {
                window.pop_back();
            }
            window.push_back(next);
            next += 1;
        }
        // Expire candidates that have slid past the window's left edge.
        let left = center.saturating_sub(radius);
        while window.front().map_or(false, |&head| head < left) {
            window.pop_front();
        }
        *out = data[window[0]];
    }
    minima
}
fn dark_channel(buf: &[[f32; 3]], width: usize, height: usize) -> Vec<f32> {
let n = width * height;
let mut pixel_min = vec![0.0_f32; n];
pixel_min
.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, val) in chunk.iter_mut().enumerate() {
let [r, g, b] = buf[base + i];
*val = r.min(g).min(b);
}
});
let mut h_filtered = vec![0.0_f32; n];
h_filtered
.par_chunks_mut(width)
.enumerate()
.for_each(|(y, row_out)| {
let row_start = y * width;
let row = &pixel_min[row_start..row_start + width];
let filtered = min_filter_1d(row, PATCH_SIZE);
row_out.copy_from_slice(&filtered);
});
let mut result = vec![0.0_f32; n];
let result_send = UnsafeSlicePtr(result.as_mut_ptr());
(0..width).into_par_iter().for_each(|x| {
let mut col_buf = vec![0.0_f32; height];
for y in 0..height {
col_buf[y] = h_filtered[y * width + x];
}
let filtered = min_filter_1d(&col_buf, PATCH_SIZE);
let ptr = result_send.ptr();
for (y, &val) in filtered.iter().enumerate() {
unsafe { *ptr.add(y * width + x) = val };
}
});
result
}
/// Estimates the atmospheric light (airlight) colour of a hazy image.
///
/// The haziest `AIRLIGHT_PERCENTILE` fraction of pixels (largest `dark_ch` values, at least
/// one pixel) is selected. The colour of the brightest of those, by `r + g + b`, is returned.
/// Ties go to the first candidate found. An empty image yields white (`[1.0; 3]`).
///
/// NaN dark-channel values are ranked as least hazy instead of panicking the selection.
///
/// # Panics
/// Panics if `dark_ch` is shorter than `buf`.
fn estimate_airlight(buf: &[[f32; 3]], dark_ch: &[f32]) -> [f32; 3] {
    let n = buf.len();
    if n == 0 {
        return [1.0, 1.0, 1.0];
    }
    let top_count = ((n as f64 * AIRLIGHT_PERCENTILE).ceil() as usize).clamp(1, n);
    // Haze key with NaN demoted below every real value, so `total_cmp` never ranks it first.
    let haze = |i: usize| {
        let v = dark_ch[i];
        if v.is_nan() {
            f32::NEG_INFINITY
        } else {
            v
        }
    };
    let mut indices: Vec<usize> = (0..n).collect();
    // Partial selection in descending haze order. Afterwards the first `top_count` entries are
    // the haziest pixels (in unspecified order), without paying for a full sort.
    indices.select_nth_unstable_by(top_count - 1, |&a, &b| haze(b).total_cmp(&haze(a)));
    let candidates = &indices[..top_count];

    let intensity = |i: usize| {
        let [r, g, b] = buf[i];
        r + g + b
    };
    // Seeding with -inf lets any real intensity (even <= 0) win. If every candidate is NaN,
    // the first candidate is kept.
    let (best, _) = candidates
        .iter()
        .fold((candidates[0], f32::NEG_INFINITY), |(best, best_val), &i| {
            let val = intensity(i);
            if val > best_val {
                (i, val)
            } else {
                (best, best_val)
            }
        });
    buf[best]
}
/// Radius, in pixels, of the box windows inside `guided_filter`. Each window spans
/// `2 * 40 + 1 = 81` pixels per axis, clipped at the image borders.
const GUIDED_FILTER_RADIUS: usize = 40;
/// Regulariser added to the local guide variance in `guided_filter`. Larger values shrink the
/// per-window slope, so the output smooths more and follows guide edges less.
const GUIDED_FILTER_EPSILON: f32 = 1e-3;
/// One-dimensional mean filter with a `2 * radius + 1` window clipped to the slice.
///
/// Each output is the average of the inputs that actually fall inside its window. Values
/// near the edges are therefore normalised by the smaller clipped count rather than
/// zero-padded. Returns an empty vector for empty input.
///
/// Prefix sums are accumulated in `f64`. With `f32` sums, the running total on long rows or
/// columns grows large enough that the difference of two prefixes no longer resolves the
/// window sum accurately. A long run of `0.7`s, for example, yields means off in the third
/// decimal.
fn box_filter_1d(data: &[f32], radius: usize) -> Vec<f32> {
    let n = data.len();
    if n == 0 {
        return Vec::new();
    }
    let prefix: Vec<f64> = std::iter::once(0.0)
        .chain(data.iter().scan(0.0_f64, |acc, &v| {
            *acc += f64::from(v);
            Some(*acc)
        }))
        .collect();
    (0..n)
        .map(|i| {
            let left = i.saturating_sub(radius);
            // Saturating so a huge radius simply means "the whole slice".
            let right = i.saturating_add(radius).min(n - 1);
            let count = (right - left + 1) as f64;
            ((prefix[right + 1] - prefix[left]) / count) as f32
        })
        .collect()
}
fn box_filter_2d(data: &[f32], width: usize, height: usize, radius: usize) -> Vec<f32> {
let n = width * height;
let mut h_filtered = vec![0.0_f32; n];
h_filtered
.par_chunks_mut(width)
.enumerate()
.for_each(|(y, row_out)| {
let row_start = y * width;
let row = &data[row_start..row_start + width];
let filtered = box_filter_1d(row, radius);
row_out.copy_from_slice(&filtered);
});
let mut result = vec![0.0_f32; n];
let result_send = UnsafeSlicePtr(result.as_mut_ptr());
(0..width).into_par_iter().for_each(|x| {
let mut col = vec![0.0_f32; height];
for y in 0..height {
col[y] = h_filtered[y * width + x];
}
let filtered = box_filter_1d(&col, radius);
let ptr = result_send.ptr();
for (y, &val) in filtered.iter().enumerate() {
unsafe { *ptr.add(y * width + x) = val };
}
});
result
}
fn guided_filter(guide: &[f32], input: &[f32], width: usize, height: usize) -> Vec<f32> {
let r = GUIDED_FILTER_RADIUS;
let eps = GUIDED_FILTER_EPSILON;
let n = width * height;
let mean_g = box_filter_2d(guide, width, height, r);
let mean_p = box_filter_2d(input, width, height, r);
let mut gp = vec![0.0_f32; n];
gp.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, val) in chunk.iter_mut().enumerate() {
*val = guide[base + i] * input[base + i];
}
});
let mut gg = vec![0.0_f32; n];
gg.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, val) in chunk.iter_mut().enumerate() {
*val = guide[base + i] * guide[base + i];
}
});
let mean_gp = box_filter_2d(&gp, width, height, r);
let mean_gg = box_filter_2d(&gg, width, height, r);
let mut a = vec![0.0_f32; n];
let mut b = vec![0.0_f32; n];
a.par_chunks_mut(1024)
.zip(b.par_chunks_mut(1024))
.enumerate()
.for_each(|(chunk_idx, (a_chunk, b_chunk))| {
let base = chunk_idx * 1024;
for i in 0..a_chunk.len() {
let idx = base + i;
let cov_gp = mean_gp[idx] - mean_g[idx] * mean_p[idx];
let var_g = mean_gg[idx] - mean_g[idx] * mean_g[idx];
a_chunk[i] = cov_gp / (var_g + eps);
b_chunk[i] = mean_p[idx] - a_chunk[i] * mean_g[idx];
}
});
let mean_a = box_filter_2d(&a, width, height, r);
let mean_b = box_filter_2d(&b, width, height, r);
let mut result = vec![0.0_f32; n];
result
.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, val) in chunk.iter_mut().enumerate() {
let idx = base + i;
*val = mean_a[idx] * guide[idx] + mean_b[idx];
}
});
result
}
/// Floor applied to the refined transmission in `apply_dehaze`. Haze removal divides by the
/// transmission, so this caps that `1 / t` amplification at 10x.
const T_MIN: f32 = 0.1_f32;
pub fn apply_dehaze(
buf: &[[f32; 3]],
width: usize,
height: usize,
params: &DehazeParams,
) -> Vec<[f32; 3]> {
if params.is_neutral() {
return buf.to_vec();
}
let n = width * height;
let amount = params.amount;
let dc = dark_channel(buf, width, height);
let a = estimate_airlight(buf, &dc);
if amount < 0.0 {
let strength = (-amount / 100.0).min(1.0);
let mut result = vec![[0.0_f32; 3]; n];
result
.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, px) in chunk.iter_mut().enumerate() {
let src = buf[base + i];
for c in 0..3 {
px[c] = (src[c] * (1.0 - strength) + a[c] * strength).clamp(0.0, 1.0);
}
}
});
return result;
}
let omega = (amount / 100.0).min(1.0);
let a_safe = [a[0].max(0.01), a[1].max(0.01), a[2].max(0.01)];
let mut normalized = vec![[0.0_f32; 3]; n];
normalized
.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, px) in chunk.iter_mut().enumerate() {
let src = buf[base + i];
*px = [src[0] / a_safe[0], src[1] / a_safe[1], src[2] / a_safe[2]];
}
});
let dc_norm = dark_channel(&normalized, width, height);
let mut t_raw = vec![0.0_f32; n];
t_raw
.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, val) in chunk.iter_mut().enumerate() {
*val = 1.0 - omega * dc_norm[base + i];
}
});
let mut guide = vec![0.0_f32; n];
guide
.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, val) in chunk.iter_mut().enumerate() {
let [r, g, b] = buf[base + i];
*val = super::LUMA_R * r + super::LUMA_G * g + super::LUMA_B * b;
}
});
let t_refined = guided_filter(&guide, &t_raw, width, height);
let mut result = vec![[0.0_f32; 3]; n];
result
.par_chunks_mut(1024)
.enumerate()
.for_each(|(chunk_idx, chunk)| {
let base = chunk_idx * 1024;
for (i, px) in chunk.iter_mut().enumerate() {
let idx = base + i;
let t = t_refined[idx].max(T_MIN);
for c in 0..3 {
let recovered = (buf[idx][c] - a[c]) / t + a[c];
px[c] = recovered.clamp(0.0, 1.0);
}
}
});
result
}
#[cfg(test)]
mod tests {
    //! Unit tests for the dehaze building blocks and the public `apply_dehaze` entry point.
    use super::*;

    #[test]
    fn default_params_are_neutral() {
        let p = DehazeParams::default();
        assert_eq!(p.amount, 0.0);
        assert!(p.is_neutral());
    }

    #[test]
    fn non_zero_amount_is_not_neutral() {
        let p = DehazeParams { amount: 50.0 };
        assert!(!p.is_neutral());
    }

    #[test]
    fn negative_amount_is_not_neutral() {
        let p = DehazeParams { amount: -30.0 };
        assert!(!p.is_neutral());
    }

    #[test]
    fn min_filter_1d_uses_centered_clamped_window() {
        assert_eq!(
            min_filter_1d(&[5.0, 3.0, 8.0, 7.0, 9.0, 6.0], 3),
            vec![3.0, 3.0, 3.0, 7.0, 6.0, 6.0]
        );
        assert!(min_filter_1d(&[], PATCH_SIZE).is_empty());
    }

    #[test]
    fn box_filter_1d_normalizes_by_clipped_window() {
        let out = box_filter_1d(&[1.0, 2.0, 3.0, 4.0, 5.0], 1);
        let expected = [1.5_f32, 2.0, 3.0, 4.0, 4.5];
        assert_eq!(out.len(), expected.len());
        for (got, want) in out.iter().zip(expected) {
            assert!((got - want).abs() < 1e-6, "Expected {want}, got {got}");
        }
    }

    #[test]
    fn dark_channel_uniform_buffer() {
        let buf = vec![[0.5_f32, 0.5, 0.5]; 4];
        let dc = dark_channel(&buf, 2, 2);
        for &v in &dc {
            assert!((v - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn dark_channel_picks_min_rgb() {
        let buf = vec![[0.8, 0.3, 0.6]; 1];
        let dc = dark_channel(&buf, 1, 1);
        assert!((dc[0] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn dark_channel_spreads_minimum_across_patch() {
        let mut buf = vec![[0.9, 0.9, 0.9]; 9];
        buf[4] = [0.1, 0.1, 0.1];
        let dc = dark_channel(&buf, 3, 3);
        for &v in &dc {
            assert!((v - 0.1).abs() < 1e-6, "Expected 0.1, got {v}");
        }
    }

    #[test]
    fn airlight_selects_brightest_in_haziest_region() {
        let mut buf = vec![[0.1, 0.1, 0.1]; 16];
        buf[0] = [0.9, 0.85, 0.8];
        let dc = dark_channel(&buf, 4, 4);
        let a = estimate_airlight(&buf, &dc);
        assert!(a[0] > 0.5, "Expected bright airlight R, got {}", a[0]);
    }

    #[test]
    fn guided_filter_uniform_input_is_identity() {
        let guide = vec![0.5_f32; 9];
        let input = vec![0.7_f32; 9];
        let result = guided_filter(&guide, &input, 3, 3);
        for &v in &result {
            assert!((v - 0.7).abs() < 1e-4, "Expected ~0.7, got {v}");
        }
    }

    #[test]
    fn guided_filter_preserves_step_edge() {
        let width = 20;
        let height = 1;
        let mut guide = vec![0.0_f32; width];
        let mut input = vec![0.0_f32; width];
        for i in width / 2..width {
            guide[i] = 1.0;
            input[i] = 1.0;
        }
        let result = guided_filter(&guide, &input, width, height);
        assert!(result[0] < 0.3, "Left should be dark, got {}", result[0]);
        assert!(
            result[width - 1] > 0.7,
            "Right should be bright, got {}",
            result[width - 1]
        );
    }

    #[test]
    fn apply_dehaze_zero_amount_is_identity() {
        let buf = vec![[0.5, 0.3, 0.7]; 4];
        let params = DehazeParams { amount: 0.0 };
        let result = apply_dehaze(&buf, 2, 2, &params);
        for (i, px) in result.iter().enumerate() {
            for c in 0..3 {
                assert!(
                    (px[c] - buf[i][c]).abs() < 1e-6,
                    "Pixel {i} channel {c}: expected {}, got {}",
                    buf[i][c],
                    px[c]
                );
            }
        }
    }

    #[test]
    fn apply_dehaze_positive_changes_output() {
        let mut buf = Vec::with_capacity(100);
        for i in 0..100 {
            let base = 0.5 + 0.3 * (i as f32 / 100.0);
            buf.push([base, base * 0.9, base * 0.85]);
        }
        let params = DehazeParams { amount: 50.0 };
        let result = apply_dehaze(&buf, 10, 10, &params);
        let differs = result
            .iter()
            .zip(buf.iter())
            .any(|(r, b)| (r[0] - b[0]).abs() > 1e-4);
        assert!(differs, "Dehaze should change hazy image");
    }

    #[test]
    fn apply_dehaze_negative_adds_haze() {
        let mut buf = Vec::with_capacity(100);
        for i in 0..100 {
            let t = i as f32 / 100.0;
            buf.push([0.2 + 0.5 * t, 0.3 + 0.3 * t, 0.1 + 0.4 * t]);
        }
        let params = DehazeParams { amount: -30.0 };
        let result = apply_dehaze(&buf, 10, 10, &params);
        let differs = result
            .iter()
            .zip(buf.iter())
            .any(|(r, b)| (r[0] - b[0]).abs() > 1e-4);
        assert!(differs, "Negative dehaze should add haze");
    }

    #[test]
    fn apply_dehaze_output_clamped_to_0_1() {
        let buf = vec![[0.95, 0.95, 0.95]; 100];
        let params = DehazeParams { amount: 100.0 };
        let result = apply_dehaze(&buf, 10, 10, &params);
        for px in &result {
            for &v in px {
                assert!((0.0..=1.0).contains(&v), "Output {v:.4} out of [0,1]");
            }
        }
    }

    #[test]
    fn apply_dehaze_t_min_prevents_extreme_values() {
        let buf = vec![[0.8, 0.8, 0.8]; 100];
        let params = DehazeParams { amount: 100.0 };
        let result = apply_dehaze(&buf, 10, 10, &params);
        for px in &result {
            for &v in px {
                assert!(
                    (0.0..=1.0).contains(&v),
                    "T_MIN should prevent extreme values, got {v:.4}"
                );
            }
        }
    }
}