use alloc::vec;
use crate::cfa::{CfaPattern, Channel};
pub fn demosaic(input: &[f32], width: usize, height: usize, cfa: &CfaPattern, output: &mut [f32]) {
let npix = width * height;
let w = width as isize;
let h = height as isize;
let mut green_h = vec![0.0f32; npix];
let mut green_v = vec![0.0f32; npix];
for y in 0..height {
for x in 0..width {
let idx = y * width + x;
let val = input[idx];
if cfa.color_at(y, x) == Channel::Green {
green_h[idx] = val;
green_v[idx] = val;
continue;
}
let mut sum_h = 0.0f32;
let mut wt_h = 0.0f32;
for kx in -3isize..=3 {
let nx = x as isize + kx;
if nx >= 0 && nx < w {
let nx = nx as usize;
if cfa.color_at(y, nx) == Channel::Green {
let d = kx.unsigned_abs() as f32;
let weight = 1.0 / (1.0 + d * d);
sum_h += input[y * width + nx] * weight;
wt_h += weight;
}
}
}
green_h[idx] = if wt_h > 0.0 { sum_h / wt_h } else { val };
let mut sum_v = 0.0f32;
let mut wt_v = 0.0f32;
for ky in -3isize..=3 {
let ny = y as isize + ky;
if ny >= 0 && ny < h {
let ny = ny as usize;
if cfa.color_at(ny, x) == Channel::Green {
let d = ky.unsigned_abs() as f32;
let weight = 1.0 / (1.0 + d * d);
sum_v += input[ny * width + x] * weight;
wt_v += weight;
}
}
}
green_v[idx] = if wt_v > 0.0 { sum_v / wt_v } else { val };
}
}
for y in 0..height {
for x in 0..width {
let idx = y * width + x;
let ch = cfa.color_at(y, x);
let gh = green_h[idx];
let gv = green_v[idx];
let mut hom_h = 0.0f32;
let mut hom_v = 0.0f32;
for ky in -1isize..=1 {
for kx in -1isize..=1 {
if ky == 0 && kx == 0 {
continue;
}
let ny = (y as isize + ky).clamp(0, h - 1) as usize;
let nx = (x as isize + kx).clamp(0, w - 1) as usize;
let nidx = ny * width + nx;
hom_h += (green_h[nidx] - gh).abs();
hom_v += (green_v[nidx] - gv).abs();
}
}
let alpha = hom_v / (hom_h + hom_v + 1e-6);
let green = alpha * gh + (1.0 - alpha) * gv;
let mut rgb = [0.0f32; 3];
if ch == Channel::Green {
rgb[1] = input[idx];
rgb[0] = interp_color_diff(input, &green_h, &green_v, width, height, cfa, y, x, Channel::Red);
rgb[2] = interp_color_diff(input, &green_h, &green_v, width, height, cfa, y, x, Channel::Blue);
} else {
rgb[ch as usize] = input[idx];
rgb[1] = green;
let other = if ch == Channel::Red { Channel::Blue } else { Channel::Red };
rgb[other as usize] = interp_color_diff(input, &green_h, &green_v, width, height, cfa, y, x, other);
}
output[idx] = rgb[0];
output[npix + idx] = rgb[1];
output[2 * npix + idx] = rgb[2];
}
}
}
/// Estimates the `target_ch` value at `(y, x)` by colour-difference
/// interpolation.
///
/// The green value at any site is taken as the plain mean of `green_h` and
/// `green_v` there. Every `target_ch` site inside the 5x5 window centred on
/// `(y, x)` (clipped to the image) contributes its `sample - green`
/// difference with weight `1 / (1 + d)`, where `d` is the Manhattan distance
/// to the centre. The weighted mean difference is added to the green value at
/// `(y, x)`; if the window holds no `target_ch` site, that green value is
/// returned unchanged.
#[inline]
fn interp_color_diff(
    input: &[f32],
    green_h: &[f32],
    green_v: &[f32],
    width: usize,
    height: usize,
    cfa: &CfaPattern,
    y: usize,
    x: usize,
    target_ch: Channel,
) -> f32 {
    // Mean of the two directional green estimates at a flat pixel index.
    let mean_green = |site: usize| (green_h[site] + green_v[site]) * 0.5;
    // `pos + offset` as an index, provided it lands inside `0..len`.
    let shifted = |pos: usize, offset: isize, len: usize| -> Option<usize> {
        let moved = pos as isize + offset;
        if (0..len as isize).contains(&moved) {
            Some(moved as usize)
        } else {
            None
        }
    };

    let base = mean_green(y * width + x);
    let mut diff_acc = 0.0f32;
    let mut weight_acc = 0.0f32;

    for dy in -2isize..=2 {
        let row = match shifted(y, dy, height) {
            Some(row) => row,
            None => continue,
        };
        for dx in -2isize..=2 {
            let col = match shifted(x, dx, width) {
                Some(col) => col,
                None => continue,
            };
            if cfa.color_at(row, col) != target_ch {
                continue;
            }
            let site = row * width + col;
            let diff = input[site] - mean_green(site);
            let manhattan = (dy.unsigned_abs() + dx.unsigned_abs()) as f32;
            let weight = 1.0 / (1.0 + manhattan);
            diff_acc += diff * weight;
            weight_acc += weight;
        }
    }

    if weight_acc > 0.0 {
        base + diff_acc / weight_acc
    } else {
        base
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A mosaic sampled from one flat colour must come back as roughly that
    /// colour. A three-pixel border is left out of the check.
    #[test]
    fn uniform_color_preserved() {
        let (w, h) = (24usize, 24usize);
        let plane = w * h;
        let cfa = CfaPattern::xtrans_default();

        // Each site carries the flat colour's value for its own channel.
        let mut input = vec![0.0f32; plane];
        for (i, sample) in input.iter_mut().enumerate() {
            *sample = match cfa.color_at(i / w, i % w) {
                Channel::Red => 0.8,
                Channel::Green => 0.5,
                Channel::Blue => 0.3,
            };
        }

        let mut output = vec![0.0f32; 3 * plane];
        demosaic(&input, w, h, &cfa, &mut output);

        for y in 3..h - 3 {
            for x in 3..w - 3 {
                let idx = y * w + x;
                let (r, g, b) = (output[idx], output[plane + idx], output[2 * plane + idx]);
                assert!((r - 0.8).abs() < 0.1, "R at ({y},{x}) = {r}");
                assert!((g - 0.5).abs() < 0.1, "G at ({y},{x}) = {g}");
                assert!((b - 0.3).abs() < 0.1, "B at ({y},{x}) = {b}");
            }
        }
    }

    /// The channel measured at a site must be copied to the output bit for
    /// bit, everywhere in the image including the border.
    #[test]
    fn known_channel_exact() {
        let (w, h) = (18usize, 18usize);
        let plane = w * h;
        let cfa = CfaPattern::xtrans_default();

        // A ramp gives every site a distinct value.
        let mut input = vec![0.0f32; plane];
        for (i, sample) in input.iter_mut().enumerate() {
            *sample = (i as f32) / plane as f32;
        }

        let mut output = vec![0.0f32; 3 * plane];
        demosaic(&input, w, h, &cfa, &mut output);

        for (idx, &in_val) in input.iter().enumerate() {
            let (y, x) = (idx / w, idx % w);
            let ch = cfa.color_at(y, x) as usize;
            let out_val = output[ch * plane + idx];
            assert_eq!(out_val, in_val, "known channel at ({y},{x}) ch={ch}");
        }
    }
}