use crate::{Filter, PaddingType};
use photon_rs::transform::padding_uniform as uniform;
use photon_rs::PhotonImage;
use photon_rs::Rgba;
use rayon::prelude::*;
/// Applies a full 2-D kernel to an already padded RGBA image.
///
/// Each output pixel `(xc, yc)` is the kernel-weighted sum of the
/// `filter.height` x `filter.width` window of `img_padded` whose top-left
/// corner is at `(xc * stride, yc * stride)`. Only the first three channels of
/// every pixel are filtered; the fourth (alpha) output byte is always 255.
/// Each sum is clamped to `0.0..=255.0` and then cast to `u8`.
/// Output rows are distributed over rayon's parallel iterator.
///
/// `width_conv` / `height_conv` are the output dimensions. The caller must pick
/// them so that every window lies inside `img_padded`; otherwise slice indexing
/// panics. If either dimension is zero an empty image is returned.
fn convolve(img_padded: &PhotonImage, filter: &Filter, width_conv: u32, height_conv: u32, stride: u32) -> PhotonImage {
    let wc = width_conv as usize;
    let hc = height_conv as usize;
    // `par_chunks_mut` rejects a chunk length of zero, so an empty output has
    // to be handled before any chunking happens.
    if wc == 0 || hc == 0 {
        return PhotonImage::new(Vec::new(), width_conv, height_conv);
    }

    let raw = img_padded.get_raw_pixels();
    let wp = img_padded.get_width() as usize;
    let fw = filter.width;
    let fh = filter.height;
    let kernel = &filter.kernel;
    let stride = stride as usize;

    let out_size = wc * hc * 4;
    let mut out = vec![0u8; out_size];
    // One chunk == one output row of `wc` RGBA pixels.
    out.par_chunks_mut(wc * 4).enumerate().for_each(|(yc, row_out)| {
        let row_base = yc * stride;
        for (xc, pixel) in row_out.chunks_exact_mut(4).enumerate() {
            let col_base = xc * stride;
            let mut r: f32 = 0.0;
            let mut g: f32 = 0.0;
            let mut b: f32 = 0.0;
            for fy in 0..fh {
                // Pixel index of the first column of this window row in the padded image.
                let row_offset = (row_base + fy) * wp;
                let k_row = fy * fw;
                for fx in 0..fw {
                    // 4 bytes per input pixel; the alpha byte at `px + 3` is not read.
                    let px = (row_offset + col_base + fx) * 4;
                    let k = kernel[k_row + fx];
                    r += raw[px] as f32 * k;
                    g += raw[px + 1] as f32 * k;
                    b += raw[px + 2] as f32 * k;
                }
            }
            pixel[0] = r.clamp(0.0, 255.0) as u8;
            pixel[1] = g.clamp(0.0, 255.0) as u8;
            pixel[2] = b.clamp(0.0, 255.0) as u8;
            pixel[3] = 255_u8;
        }
    });
    debug_assert_eq!(out.len(), out_size);
    #[cfg(debug_assertions)]
    println!("Convolution done (rayon)...");
    PhotonImage::new(out, width_conv, height_conv)
}
/// Applies a separable kernel to an already padded RGBA image in two 1-D passes.
///
/// Pass 1 (horizontal) slides `row_vec` along each input row with the given
/// `stride` and stores the un-rounded `f32` sums for the first three channels
/// in an intermediate buffer that is `width_conv` pixels wide.
/// Pass 2 (vertical) slides `col_vec` down that buffer with the same `stride`,
/// clamps each sum to `0.0..=255.0`, casts it to `u8`, and writes an RGBA pixel
/// whose fourth (alpha) byte is always 255.
/// Both passes distribute rows over rayon's parallel iterator.
///
/// `width_conv` / `height_conv` are the output dimensions. The caller must pick
/// them so that every window lies inside `img_padded`; otherwise slice indexing
/// panics. If either dimension is zero an empty image is returned.
fn separable_convolve(
    img_padded: &PhotonImage,
    row_vec: &[f32],
    col_vec: &[f32],
    width_conv: u32,
    height_conv: u32,
    stride: u32,
) -> PhotonImage {
    let wc = width_conv as usize;
    let hc = height_conv as usize;
    // `par_chunks_mut` rejects a chunk length of zero, so an empty output has
    // to be handled before any chunking happens.
    if wc == 0 || hc == 0 {
        return PhotonImage::new(Vec::new(), width_conv, height_conv);
    }

    let raw = img_padded.get_raw_pixels();
    let wp = img_padded.get_width() as usize;
    let hp = img_padded.get_height() as usize;
    let fw = row_vec.len();
    let fh = col_vec.len();
    let stride = stride as usize;

    // The vertical pass reads intermediate rows `yc * stride + fy` only, i.e.
    // nothing at or beyond `(hc - 1) * stride + fh`, so rows past that point do
    // not need to be filtered. Never exceed the rows the input actually has.
    let temp_rows = ((hc - 1) * stride + fh).min(hp);
    let mut temp: Vec<f32> = vec![0.0; temp_rows * wc * 3];

    // Pass 1: horizontal filter; one chunk == one intermediate row of `wc` RGB triples.
    temp.par_chunks_mut(wc * 3).enumerate().for_each(|(y, row_temp)| {
        let row_input = y * wp;
        for (x, acc) in row_temp.chunks_exact_mut(3).enumerate() {
            let col_input = x * stride;
            let mut r: f32 = 0.0;
            let mut g: f32 = 0.0;
            let mut b: f32 = 0.0;
            for (fx, &k) in row_vec.iter().enumerate() {
                // 4 bytes per input pixel; the alpha byte at `px + 3` is not read.
                let px = (row_input + col_input + fx) * 4;
                r += raw[px] as f32 * k;
                g += raw[px + 1] as f32 * k;
                b += raw[px + 2] as f32 * k;
            }
            acc[0] = r;
            acc[1] = g;
            acc[2] = b;
        }
    });

    let out_size = wc * hc * 4;
    let mut out = vec![0u8; out_size];
    // Pass 2: vertical filter over the intermediate buffer; one chunk == one output row.
    out.par_chunks_mut(wc * 4).enumerate().for_each(|(yc, row_out)| {
        let row_base = yc * stride;
        for (xc, pixel) in row_out.chunks_exact_mut(4).enumerate() {
            let mut r: f32 = 0.0;
            let mut g: f32 = 0.0;
            let mut b: f32 = 0.0;
            for (fy, &k) in col_vec.iter().enumerate() {
                let t = ((row_base + fy) * wc + xc) * 3;
                r += temp[t] * k;
                g += temp[t + 1] * k;
                b += temp[t + 2] * k;
            }
            pixel[0] = r.clamp(0.0, 255.0) as u8;
            pixel[1] = g.clamp(0.0, 255.0) as u8;
            pixel[2] = b.clamp(0.0, 255.0) as u8;
            pixel[3] = 255_u8;
        }
    });
    debug_assert_eq!(out.len(), out_size);
    #[cfg(debug_assertions)]
    println!("Separable convolution done (rayon)...");
    PhotonImage::new(out, width_conv, height_conv)
}
/// Computes how many filter placements fit along one axis.
///
/// Returns `(input_size + 2 * pad - filter_size) / stride + 1` using integer
/// (floor) division, or `0` when the filter is larger than the padded input and
/// therefore fits nowhere. A warning is written to stderr when `stride` does
/// not evenly divide the available span, and when the filter does not fit.
///
/// # Panics
///
/// Panics if `stride` is zero (division by zero).
#[inline]
fn output_dim(input_size: u32, filter_size: u32, pad: u32, stride: u32) -> u32 {
    // Add the padding *before* subtracting, and widen to u64, so that a filter
    // larger than the unpadded input (but not the padded one) does not
    // underflow and very large padding does not overflow.
    let padded = u64::from(input_size) + 2 * u64::from(pad);
    let dim = match padded.checked_sub(u64::from(filter_size)) {
        Some(dim) => dim,
        None => {
            eprintln!("[WARNING]: filter larger than padded input. Output will be empty.");
            return 0;
        }
    };
    let stride = u64::from(stride);
    if dim % stride != 0 {
        eprintln!("[WARNING]: stride value not suitable. Convolution may fail.");
    }
    // Saturate instead of wrapping if the count does not fit the return type.
    (dim / stride + 1).min(u64::from(u32::MAX)) as u32
}
/// Convolves `img` with `filter` using the requested `stride` and `padding`.
///
/// * `PaddingType::UNIFORM(n)` first pads the image by `n` via
///   `padding_uniform` with `Rgba::new(0, 0, 0, 255)` as the fill colour.
/// * `PaddingType::NONE` filters the image as-is.
///
/// If `filter.try_separable()` yields a `(col, row)` pair, the two-pass
/// separable routine is used; otherwise the full 2-D kernel is applied.
///
/// A `stride` of zero prints an error to stderr and terminates the process
/// with exit code 1.
pub fn convolution(img: &PhotonImage, filter: Filter, stride: u32, padding: PaddingType) -> PhotonImage {
    if stride == 0 {
        eprintln!("[ERROR]: Stride provided = 0");
        std::process::exit(1);
    }

    let separable = filter.try_separable();

    // Pick the image that is actually filtered and the pad amount that feeds
    // the output-size computation. The padded copy (if any) is kept alive in
    // `padded_storage`; the unpadded case borrows `img` directly.
    let padded_storage;
    let (source, pad): (&PhotonImage, u32) = match &padding {
        PaddingType::UNIFORM(pad_amt) => {
            padded_storage = uniform(img, *pad_amt, Rgba::new(0, 0, 0, 255));
            (&padded_storage, *pad_amt)
        }
        PaddingType::NONE => (img, 0),
    };

    // Output size is derived from the original (unpadded) dimensions plus `pad`.
    let wc = output_dim(img.get_width(), filter.width as u32, pad, stride);
    let hc = output_dim(img.get_height(), filter.height as u32, pad, stride);

    match separable {
        Some((col, row)) => separable_convolve(source, &row, &col, wc, hc, stride),
        None => convolve(source, &filter, wc, hc, stride),
    }
}