//! 2x image upscaling of 8-bit planes, interleaved RGB and interleaved RGBA
//! buffers, with covariance-steered diagonal interpolation and an optional
//! unsharp-mask pass.

// Pixel arithmetic constantly moves between `usize` indices, `i32` clamped
// coordinates, `f64` samples, `f32` config values and `u8` output, so these
// numeric-cast lints are silenced module-wide rather than at every cast.
#![allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
/// Tuning parameters for [`SuperResolution`].
#[derive(Debug, Clone, PartialEq)]
pub struct SuperResolutionConfig {
    /// Half-width, in source pixels, of the square neighbourhood scanned to
    /// estimate the local gradient covariance that steers the diagonal output
    /// samples (odd row, odd column). A value of `w` covers a
    /// `(2w + 1) x (2w + 1)` window.
    pub covariance_window: usize,
    /// Threshold on the absolute determinant of the averaged gradient
    /// covariance. Below it, a diagonal sample falls back to the plain mean of
    /// its four source neighbours.
    pub flat_threshold: f64,
    /// Strength of the unsharp-mask pass run on the upscaled plane. Values that
    /// are not greater than `0.0` skip sharpening entirely.
    pub sharpen_amount: f32,
}

impl Default for SuperResolutionConfig {
    /// A 5x5 covariance window (`covariance_window = 2`), a flatness threshold
    /// of `1e-4`, and light sharpening (`0.2`).
    fn default() -> Self {
        Self {
            covariance_window: 2,
            flat_threshold: 1e-4,
            sharpen_amount: 0.2,
        }
    }
}
#[derive(Debug, Clone)]
pub struct SuperResolution {
    config: SuperResolutionConfig,
}

impl SuperResolution {
    /// Largest accepted input width or height.
    ///
    /// The sampling (`get_f64`) and blur helpers address pixels with `i32`
    /// coordinates, including on the 2x output grid, so twice a dimension must
    /// still fit in an `i32`.
    const MAX_DIMENSION: usize = (i32::MAX / 2) as usize;

    /// Creates an upscaler with the given tuning parameters.
    #[must_use]
    pub fn new(config: SuperResolutionConfig) -> Self {
        Self { config }
    }

    /// Upscales one 8-bit plane by a factor of two in each dimension.
    ///
    /// `input` is read row-major as `height` rows of `width` samples. The result
    /// is row-major with `2 * width` columns and `2 * height` rows.
    ///
    /// Returns an empty `Vec` when `width` or `height` is zero, when
    /// `input.len() != width * height`, or when the dimensions cannot be
    /// processed (buffer sizes overflow `usize`, or a dimension exceeds
    /// `i32::MAX / 2`).
    #[must_use]
    pub fn upscale_channel(&self, input: &[u8], width: usize, height: usize) -> Vec<u8> {
        match Self::checked_buffer_lens(width, height, 1) {
            Some((input_len, _)) if input_len == input.len() => {}
            _ => return Vec::new(),
        }
        let float_src = u8_to_f64(input);
        let float_dst = self.upscale_f64(&float_src, width, height);
        f64_to_u8(&float_dst)
    }

    /// Upscales an interleaved RGB image (3 samples per pixel) by 2x.
    ///
    /// Each channel is processed independently with
    /// [`Self::upscale_channel`]. Returns an empty `Vec` under the same
    /// conditions, with the expected input length being `width * height * 3`.
    #[must_use]
    pub fn upscale_rgb(&self, input: &[u8], width: usize, height: usize) -> Vec<u8> {
        self.upscale_interleaved(input, width, height, 3)
    }

    /// Upscales an interleaved RGBA image (4 samples per pixel) by 2x.
    ///
    /// All four channels, alpha included, go through
    /// [`Self::upscale_channel`]. Returns an empty `Vec` under the same
    /// conditions, with the expected input length being `width * height * 4`.
    #[must_use]
    pub fn upscale_rgba(&self, input: &[u8], width: usize, height: usize) -> Vec<u8> {
        self.upscale_interleaved(input, width, height, 4)
    }

    /// Validates dimensions and returns `(input_len, output_len)` for a buffer
    /// holding `channels` interleaved samples per pixel.
    ///
    /// Returns `None` for zero or over-large dimensions, or if any length
    /// overflows `usize`.
    fn checked_buffer_lens(
        width: usize,
        height: usize,
        channels: usize,
    ) -> Option<(usize, usize)> {
        if width == 0
            || height == 0
            || width > Self::MAX_DIMENSION
            || height > Self::MAX_DIMENSION
        {
            return None;
        }
        let input_len = width.checked_mul(height)?.checked_mul(channels)?;
        // 2x in each dimension means four times as many samples.
        let output_len = input_len.checked_mul(4)?;
        Some((input_len, output_len))
    }

    /// Shared body of [`Self::upscale_rgb`] and [`Self::upscale_rgba`].
    ///
    /// De-interleaves one channel at a time, upscales it, and re-interleaves
    /// the result into the output buffer.
    fn upscale_interleaved(
        &self,
        input: &[u8],
        width: usize,
        height: usize,
        channels: usize,
    ) -> Vec<u8> {
        let out_len = match Self::checked_buffer_lens(width, height, channels) {
            Some((input_len, out_len)) if input_len == input.len() => out_len,
            _ => return Vec::new(),
        };
        let mut out = vec![0u8; out_len];
        // One plane buffer reused for every channel.
        let mut plane = Vec::with_capacity(width * height);
        for ch in 0..channels {
            plane.clear();
            plane.extend(input.iter().skip(ch).step_by(channels).copied());
            let upscaled = self.upscale_channel(&plane, width, height);
            for (slot, &v) in out.iter_mut().skip(ch).step_by(channels).zip(&upscaled) {
                *slot = v;
            }
        }
        out
    }

    /// Doubles a row-major `sw` x `sh` plane of `f64` samples.
    ///
    /// Each source pixel `(sx, sy)` owns the 2x2 output cell whose top-left
    /// corner is `(2 * sx, 2 * sy)`:
    /// - `(2sx, 2sy)`: the source sample itself;
    /// - `(2sx + 1, 2sy)`: mean of the sample and its right neighbour;
    /// - `(2sx, 2sy + 1)`: mean of the sample and the one below it;
    /// - `(2sx + 1, 2sy + 1)`: [`Self::nedi_diagonal`].
    ///
    /// Neighbours past the right or bottom edge are clamped by `get_f64`, so
    /// the last output column and row repeat the edge samples. Sharpening, if
    /// enabled, runs last over the whole output plane.
    fn upscale_f64(&self, src: &[f64], sw: usize, sh: usize) -> Vec<f64> {
        let dw = sw * 2;
        let dh = sh * 2;
        let mut dst = vec![0.0_f64; dw * dh];
        // Every output sample depends only on `src`, never on other `dst`
        // samples, so all four positions of a cell are filled in one sweep.
        for sy in 0..sh {
            let even_row = 2 * sy * dw;
            let odd_row = even_row + dw;
            let y = sy as i32;
            for sx in 0..sw {
                let x = sx as i32;
                let here = src[sy * sw + sx];
                let right = get_f64(src, sw, sh, x + 1, y);
                let below = get_f64(src, sw, sh, x, y + 1);
                dst[even_row + 2 * sx] = here;
                dst[even_row + 2 * sx + 1] = (here + right) * 0.5;
                dst[odd_row + 2 * sx] = (here + below) * 0.5;
                dst[odd_row + 2 * sx + 1] = self.nedi_diagonal(src, sw, sh, sx, sy);
            }
        }
        if self.config.sharpen_amount > 0.0 {
            sharpen(&mut dst, dw, dh, f64::from(self.config.sharpen_amount));
        }
        dst
    }

    /// Interpolates the sample at the centre of the source quad whose top-left
    /// corner is `(sx, sy)`.
    ///
    /// Forward-difference gradients are accumulated over a
    /// `(2w + 1) x (2w + 1)` window (`w = covariance_window`) into a 2x2
    /// covariance. If its determinant is below `flat_threshold`, the plain
    /// four-sample mean is returned. Otherwise the inverse-covariance diagonal
    /// sets horizontal and vertical weights `w_h + w_v = 1`.
    ///
    /// NOTE(review): algebraically the weighted branch equals
    /// `p00/4 + p11/4 + (w_v * p10 + w_h * p01) / 2`. The main diagonal
    /// therefore always gets a fixed half share, and only the anti-diagonal
    /// pair is reweighted. That makes the blend asymmetric under mirroring;
    /// confirm it is the intended edge-directed rule.
    fn nedi_diagonal(&self, src: &[f64], sw: usize, sh: usize, sx: usize, sy: usize) -> f64 {
        let (x, y) = (sx as i32, sy as i32);
        let p00 = get_f64(src, sw, sh, x, y);
        let p10 = get_f64(src, sw, sh, x + 1, y);
        let p01 = get_f64(src, sw, sh, x, y + 1);
        let p11 = get_f64(src, sw, sh, x + 1, y + 1);
        let mean = (p00 + p10 + p01 + p11) * 0.25;

        let w = self.config.covariance_window as i32;
        let (mut c00, mut c01, mut c11) = (0.0_f64, 0.0_f64, 0.0_f64);
        let mut n = 0_usize;
        for dy in -w..=w {
            for dx in -w..=w {
                let (ix, iy) = (x + dx, y + dy);
                let v = get_f64(src, sw, sh, ix, iy);
                let gx = get_f64(src, sw, sh, ix + 1, iy) - v;
                let gy = get_f64(src, sw, sh, ix, iy + 1) - v;
                c00 += gx * gx;
                c01 += gx * gy;
                c11 += gy * gy;
                n += 1;
            }
        }
        // The window is empty only if `covariance_window` wrapped to a negative
        // `i32` in the cast above.
        if n == 0 {
            return mean;
        }
        let n_f = n as f64;
        c00 /= n_f;
        c01 /= n_f;
        c11 /= n_f;
        let det = c00 * c11 - c01 * c01;
        if det.abs() < self.config.flat_threshold {
            return mean;
        }
        let inv00 = c11 / det;
        let inv11 = c00 / det;
        let trace = inv00 + inv11;
        let w_h = if trace > 0.0 { inv11 / trace } else { 0.5 };
        let w_v = 1.0 - w_h;
        let horiz = (p00 + p10) * 0.5 * w_v + (p01 + p11) * 0.5 * (1.0 - w_v);
        let vert = (p00 + p01) * 0.5 * w_h + (p10 + p11) * 0.5 * (1.0 - w_h);
        (horiz + vert) * 0.5
    }
}
/// Unsharp-masks `buf` in place.
///
/// Each sample is pushed away from its `gaussian_blur` value by `amount` times
/// the difference, then clamped to `[0, 255]`. `buf` is handed to the blur as a
/// row-major `width` x `height` plane.
fn sharpen(buf: &mut [f64], width: usize, height: usize, amount: f64) {
    let smoothed = gaussian_blur(buf, width, height);
    buf.iter_mut().zip(&smoothed).for_each(|(sample, &blurred)| {
        let detail = *sample - blurred;
        *sample = (*sample + amount * detail).clamp(0.0, 255.0);
    });
}
/// Averages every sample with its 3x3 neighbourhood, replicating edge samples
/// for neighbours that fall outside the `width` x `height` plane.
///
/// NOTE(review): despite the name, this is a uniform box filter (all nine taps
/// weighted 1/9), not a Gaussian kernel.
///
/// `src` is indexed as a row-major plane and must hold at least
/// `width * height` samples. The result has `src.len()` samples, and any
/// positions beyond `width * height` stay `0.0`.
fn gaussian_blur(src: &[f64], width: usize, height: usize) -> Vec<f64> {
    let max_x = width as i32 - 1;
    let max_y = height as i32 - 1;
    let mut blurred = vec![0.0_f64; src.len()];
    for y in 0..height {
        let row_start = y * width;
        let cy = y as i32;
        for x in 0..width {
            let cx = x as i32;
            let mut total = 0.0;
            for oy in -1_i32..=1 {
                let ny = (cy + oy).clamp(0, max_y) as usize;
                for ox in -1_i32..=1 {
                    let nx = (cx + ox).clamp(0, max_x) as usize;
                    total += src[ny * width + nx];
                }
            }
            // Nine taps are always summed: out-of-range neighbours are clamped,
            // never skipped.
            blurred[row_start + x] = total / 9.0;
        }
    }
    blurred
}
/// Reads the sample at `(x, y)` from a row-major `width`-wide plane, clamping
/// both coordinates to the plane so out-of-range reads repeat the edge.
///
/// # Panics
/// Panics if `width` or `height` is zero (`i32::clamp` panics when its lower
/// bound exceeds its upper bound), or if `buf` is too short for the clamped
/// index.
#[inline]
fn get_f64(buf: &[f64], width: usize, height: usize, x: i32, y: i32) -> f64 {
    let clamp_axis = |coord: i32, extent: usize| coord.clamp(0, extent as i32 - 1) as usize;
    let col = clamp_axis(x, width);
    let row = clamp_axis(y, height);
    buf[width * row + col]
}
/// Widens 8-bit samples to `f64` losslessly, one output per input.
#[inline]
fn u8_to_f64(src: &[u8]) -> Vec<f64> {
    src.iter().copied().map(f64::from).collect()
}
/// Quantizes `f64` samples to 8 bits.
///
/// Each value is clamped to `[0, 255]` and then rounded half away from zero.
/// NaN passes through the clamp and the cast saturates it to `0`.
#[inline]
fn f64_to_u8(src: &[f64]) -> Vec<u8> {
    fn quantize(sample: f64) -> u8 {
        let bounded = sample.clamp(0.0, 255.0);
        bounded.round() as u8
    }
    src.iter().copied().map(quantize).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Upscaler with the default configuration (light sharpening enabled).
    fn make_sr() -> SuperResolution {
        SuperResolution::new(SuperResolutionConfig::default())
    }

    /// Upscaler with sharpening disabled, so every interpolated sample is a
    /// blend of source samples.
    fn make_sr_unsharpened() -> SuperResolution {
        SuperResolution::new(SuperResolutionConfig {
            sharpen_amount: 0.0,
            ..SuperResolutionConfig::default()
        })
    }

    fn solid_image(value: u8, w: usize, h: usize) -> Vec<u8> {
        vec![value; w * h]
    }

    /// Horizontal ramp from 0 (left column) to 255 (right column), repeated on
    /// every row. Requires `w >= 2`.
    fn gradient_image(w: usize, h: usize) -> Vec<u8> {
        (0..h)
            .flat_map(|_| (0..w).map(|x| (x * 255 / (w - 1)) as u8))
            .collect()
    }

    #[test]
    fn test_output_size_doubles() {
        let sr = make_sr();
        let input = solid_image(128, 4, 4);
        let output = sr.upscale_channel(&input, 4, 4);
        assert_eq!(output.len(), 8 * 8);
    }

    #[test]
    fn test_solid_colour_preserved() {
        // Any blend of a constant is the constant, and unsharp-masking a
        // constant plane is a no-op, so the value must survive exactly with or
        // without sharpening.
        let input = solid_image(200, 8, 8);
        for sr in [make_sr_unsharpened(), make_sr()] {
            let output = sr.upscale_channel(&input, 8, 8);
            assert_eq!(output.len(), 16 * 16);
            assert!(
                output.iter().all(|&v| v == 200),
                "solid colour should be preserved, got {output:?}"
            );
        }
    }

    #[test]
    fn test_upscale_returns_empty_for_zero_width() {
        let sr = make_sr();
        assert!(sr.upscale_channel(&[], 0, 0).is_empty());
        assert!(sr.upscale_channel(&[], 0, 4).is_empty());
        assert!(sr.upscale_rgb(&[], 4, 0).is_empty());
    }

    #[test]
    fn test_upscale_channel_wrong_length_returns_empty() {
        let sr = make_sr();
        assert!(sr.upscale_channel(&[0u8; 10], 4, 4).is_empty());
        assert!(sr.upscale_rgb(&[0u8; 4 * 4], 4, 4).is_empty());
    }

    #[test]
    fn test_upscale_rgb_output_size() {
        let sr = make_sr();
        let input: Vec<u8> = (0..4 * 4 * 3).map(|i| (i % 256) as u8).collect();
        let output = sr.upscale_rgb(&input, 4, 4);
        assert_eq!(output.len(), 8 * 8 * 3);
    }

    #[test]
    fn test_upscale_rgba_output_size() {
        let sr = make_sr();
        let input: Vec<u8> = (0..4 * 4 * 4).map(|i| (i % 256) as u8).collect();
        let output = sr.upscale_rgba(&input, 4, 4);
        assert_eq!(output.len(), 8 * 8 * 4);
    }

    #[test]
    fn test_upscale_rgb_matches_per_channel() {
        let sr = make_sr();
        let (w, h) = (4_usize, 3_usize);
        let input: Vec<u8> = (0..w * h * 3).map(|i| (i * 29 % 256) as u8).collect();
        let output = sr.upscale_rgb(&input, w, h);
        assert_eq!(output.len(), w * h * 4 * 3);
        for ch in 0..3 {
            let plane: Vec<u8> = input.iter().skip(ch).step_by(3).copied().collect();
            let expected = sr.upscale_channel(&plane, w, h);
            let actual: Vec<u8> = output.iter().skip(ch).step_by(3).copied().collect();
            assert_eq!(actual, expected, "channel {ch} differs from a standalone upscale");
        }
    }

    #[test]
    fn test_source_pixels_copied_exactly() {
        let sr = make_sr_unsharpened();
        let w = 4_usize;
        let h = 4_usize;
        let input: Vec<u8> = (0..(w * h)).map(|i| (i * 7 % 256) as u8).collect();
        let output = sr.upscale_channel(&input, w, h);
        let out_w = w * 2;
        for sy in 0..h {
            for sx in 0..w {
                let src_val = input[sy * w + sx];
                let dst_val = output[(sy * 2) * out_w + sx * 2];
                assert_eq!(
                    src_val, dst_val,
                    "source pixel at ({sx},{sy}) not preserved: src={src_val} dst={dst_val}"
                );
            }
        }
    }

    #[test]
    fn test_gradient_rows_monotonic_and_uniform() {
        let sr = make_sr_unsharpened();
        let (w, h) = (8_usize, 8_usize);
        let input = gradient_image(w, h);
        let output = sr.upscale_channel(&input, w, h);
        let out_w = w * 2;
        assert_eq!(output.len(), out_w * h * 2);
        let first_row = &output[..out_w];
        assert_eq!(first_row[0], 0);
        assert_eq!(first_row[out_w - 1], 255);
        assert!(
            first_row.windows(2).all(|pair| pair[0] <= pair[1]),
            "a horizontal ramp should stay monotonic: {first_row:?}"
        );
        for row in output.chunks_exact(out_w) {
            assert_eq!(row, first_row, "rows of a horizontal ramp should be identical");
        }
    }

    #[test]
    fn test_unsharpened_output_stays_within_source_range() {
        // Without sharpening every output sample is a convex blend of source
        // samples, so it can never leave the source's [min, max] range.
        let sr = make_sr_unsharpened();
        let input: Vec<u8> = (0..16 * 16).map(|i| (50 + i * 37 % 151) as u8).collect();
        let lo = *input.iter().min().expect("input is non-empty");
        let hi = *input.iter().max().expect("input is non-empty");
        let output = sr.upscale_channel(&input, 16, 16);
        assert_eq!(output.len(), 32 * 32);
        assert!(
            output.iter().all(|v| (lo..=hi).contains(v)),
            "output left the source range [{lo}, {hi}]"
        );
    }

    #[test]
    fn test_upscale_rgba_wrong_buffer_returns_empty() {
        let sr = make_sr();
        let input = vec![0u8; 10];
        let output = sr.upscale_rgba(&input, 4, 4);
        assert!(output.is_empty());
    }

    #[test]
    fn test_no_sharpen_vs_sharpen_differ_on_step_edge() {
        let w = 8_usize;
        let h = 8_usize;
        let input: Vec<u8> = (0..h)
            .flat_map(|_| (0..w).map(|x| if x < w / 2 { 50u8 } else { 200u8 }))
            .collect();
        let sr_no = make_sr_unsharpened();
        let sr_yes = SuperResolution::new(SuperResolutionConfig {
            sharpen_amount: 1.0,
            ..Default::default()
        });
        let out_no = sr_no.upscale_channel(&input, w, h);
        let out_yes = sr_yes.upscale_channel(&input, w, h);
        let differs = out_no.iter().zip(out_yes.iter()).any(|(a, b)| a != b);
        assert!(differs, "sharpening should change at least some pixels on a step edge");
    }
}