use alloc::vec;
use alloc::vec::Vec;
use crate::dynmatrix::DynMatrix;
use crate::traits::{FloatScalar, MatrixMut, MatrixRef};
use super::border::{fetch_border, BorderMode};
/// Replaces every element of `src` with the maximum over the square window of
/// side `2 * radius + 1` centred on it.
///
/// Samples that fall outside the matrix are supplied according to `border`.
/// A `radius` of 0 returns a copy of `src`; an empty matrix yields an empty
/// matrix of the same shape.
pub fn max_filter<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
) -> DynMatrix<T> {
    let reduce: fn(T, T) -> T = max_of;
    filter_1d_then_1d(src, radius, border, reduce)
}
/// Replaces every element of `src` with the minimum over the square window of
/// side `2 * radius + 1` centred on it.
///
/// Samples that fall outside the matrix are supplied according to `border`.
/// A `radius` of 0 returns a copy of `src`; an empty matrix yields an empty
/// matrix of the same shape.
pub fn min_filter<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
) -> DynMatrix<T> {
    let reduce: fn(T, T) -> T = min_of;
    filter_1d_then_1d(src, radius, border, reduce)
}
pub fn dilate<T: FloatScalar>(
src: &DynMatrix<T>,
radius: usize,
border: BorderMode<T>,
) -> DynMatrix<T> {
max_filter(src, radius, border)
}
pub fn erode<T: FloatScalar>(
src: &DynMatrix<T>,
radius: usize,
border: BorderMode<T>,
) -> DynMatrix<T> {
min_filter(src, radius, border)
}
/// Morphological opening: an [`erode`] followed by a [`dilate`].
///
/// Both steps use the same `radius` and the same `border` handling.
pub fn opening<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
) -> DynMatrix<T> {
    dilate(&erode(src, radius, border), radius, border)
}
/// Morphological closing: a [`dilate`] followed by an [`erode`].
///
/// Both steps use the same `radius` and the same `border` handling.
pub fn closing<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
) -> DynMatrix<T> {
    erode(&dilate(src, radius, border), radius, border)
}
/// Morphological gradient: `dilate(src) - erode(src)`, element-wise.
///
/// Both operations use the same `radius` and `border`. The dilation result
/// has the same shape as `src` and is reused as the output buffer. This
/// avoids allocating and zero-filling a separate result matrix.
pub fn morphology_gradient<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
) -> DynMatrix<T> {
    let mut out = dilate(src, radius, border);
    let e = erode(src, radius, border);
    let nrows = src.nrows();
    let ncols = src.ncols();
    for j in 0..ncols {
        for i in 0..nrows {
            // The right-hand side is evaluated (and copied out) before the
            // mutable index, so updating `out` in place is sound.
            out[(i, j)] = out[(i, j)] - e[(i, j)];
        }
    }
    out
}
/// White top-hat transform: `src - opening(src)`, element-wise.
///
/// The opening result has the same shape as `src` and is overwritten in
/// place with the difference. This avoids a separate zero-filled output
/// matrix.
pub fn top_hat<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
) -> DynMatrix<T> {
    let mut out = opening(src, radius, border);
    let nrows = src.nrows();
    let ncols = src.ncols();
    for j in 0..ncols {
        for i in 0..nrows {
            out[(i, j)] = src[(i, j)] - out[(i, j)];
        }
    }
    out
}
/// Black top-hat transform: `closing(src) - src`, element-wise.
///
/// The closing result has the same shape as `src` and is overwritten in
/// place with the difference. This avoids a separate zero-filled output
/// matrix.
pub fn black_hat<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
) -> DynMatrix<T> {
    let mut out = closing(src, radius, border);
    let nrows = src.nrows();
    let ncols = src.ncols();
    for j in 0..ncols {
        for i in 0..nrows {
            out[(i, j)] = out[(i, j)] - src[(i, j)];
        }
    }
    out
}
/// Returns `a` when `a >= b`, otherwise `b`.
///
/// This differs from a symmetric max when an operand is unordered. If either
/// argument is NaN the comparison fails and `b` is returned. On ties
/// (including `-0.0` vs `0.0`) `a` is kept.
#[inline]
fn max_of<T: FloatScalar>(a: T, b: T) -> T {
    match a.partial_cmp(&b) {
        Some(ord) if ord.is_ge() => a,
        _ => b,
    }
}
/// Returns `a` when `a <= b`, otherwise `b`.
///
/// This differs from a symmetric min when an operand is unordered. If either
/// argument is NaN the comparison fails and `b` is returned. On ties
/// (including `-0.0` vs `0.0`) `a` is kept.
#[inline]
fn min_of<T: FloatScalar>(a: T, b: T) -> T {
    match a.partial_cmp(&b) {
        Some(ord) if ord.is_le() => a,
        _ => b,
    }
}
/// Applies a sliding-window reduction over a square window of side
/// `2 * radius + 1`, as two 1-D passes.
///
/// First, every column of `src` is reduced with [`van_herk_1d`] into a
/// temporary matrix. Then every row of that temporary is reduced the same way
/// into the result. Out-of-range samples in each pass are produced by
/// `fetch_border` according to `border`.
///
/// `combine` is expected to be an associative, commutative reduction; the
/// callers in this module pass `max_of` / `min_of`.
///
/// Returns a zero-filled matrix of the same shape for empty input, and a
/// copy of `src` when `radius == 0`.
fn filter_1d_then_1d<T: FloatScalar>(
    src: &DynMatrix<T>,
    radius: usize,
    border: BorderMode<T>,
    combine: fn(T, T) -> T,
) -> DynMatrix<T> {
    let nrows = src.nrows();
    let ncols = src.ncols();
    if nrows == 0 || ncols == 0 {
        return DynMatrix::<T>::zeros(nrows, ncols);
    }
    if radius == 0 {
        // A 1x1 window is the identity. Return before allocating any output
        // or scratch storage.
        return src.clone();
    }
    let k = 2 * radius + 1;

    // Scratch buffers are sized for the longer axis so both passes can share
    // them without reallocating.
    let scratch_len = nrows.max(ncols) + 2 * radius;
    let mut padded: Vec<T> = Vec::with_capacity(scratch_len);
    let mut g: Vec<T> = vec![T::zero(); scratch_len];
    let mut h: Vec<T> = vec![T::zero(); scratch_len];

    // Pass 1: reduce along each column into `tmp`.
    let mut tmp = DynMatrix::<T>::zeros(nrows, ncols);
    for j in 0..ncols {
        let src_col = src.col_as_slice(j, 0);
        let tmp_col = tmp.col_as_mut_slice(j, 0);
        van_herk_1d(src_col, tmp_col, radius, k, border, combine, &mut padded, &mut g, &mut h);
    }

    // Pass 2: reduce along each row of `tmp`. Rows are gathered into a dense
    // buffer first because the 1-D kernel works on slices.
    let mut out = DynMatrix::<T>::zeros(nrows, ncols);
    let mut row_in: Vec<T> = vec![T::zero(); ncols];
    let mut row_out: Vec<T> = vec![T::zero(); ncols];
    for i in 0..nrows {
        for (j, slot) in row_in.iter_mut().enumerate() {
            *slot = tmp[(i, j)];
        }
        van_herk_1d(&row_in, &mut row_out, radius, k, border, combine, &mut padded, &mut g, &mut h);
        for (j, &value) in row_out.iter().enumerate() {
            out[(i, j)] = value;
        }
    }
    out
}
/// Sliding-window reduction of a 1-D signal (van Herk / Gil-Werman scheme).
///
/// For each `p` in `0..src.len()`, `dst[p]` is `combine` folded over the
/// samples at indices `p - radius ..= p + radius`. Indices outside `src` are
/// resolved by `fetch_border` with `border`. `k` must equal `2 * radius + 1`.
///
/// The padded signal is split into blocks of length `k`:
/// * `g[i]` holds the reduction from the start of `i`'s block up to `i`.
/// * `h[i]` holds the reduction from `i` to the end of its block.
///
/// Every window spans at most two adjacent blocks. So each output needs a
/// single `combine(h[p], g[p + 2 * radius])`, independent of `radius`.
///
/// `padded`, `g` and `h` are caller-owned scratch buffers. `padded` is
/// rebuilt from scratch; `g` and `h` are only grown, never shrunk. An empty
/// `src` returns immediately without touching any buffer.
// Scratch buffers and the reducer are threaded through explicitly so that
// both passes of `filter_1d_then_1d` share one set of allocations.
#[allow(clippy::too_many_arguments)]
fn van_herk_1d<T: FloatScalar>(
    src: &[T],
    dst: &mut [T],
    radius: usize,
    k: usize,
    border: BorderMode<T>,
    combine: fn(T, T) -> T,
    padded: &mut Vec<T>,
    g: &mut Vec<T>,
    h: &mut Vec<T>,
) {
    let n = src.len();
    if n == 0 {
        return;
    }
    let total = n + 2 * radius;
    let shift = radius as isize;

    // padded[i] is the (possibly border-supplied) sample at index i - radius.
    padded.clear();
    padded.extend((0..total).map(|i| fetch_border(src, i as isize - shift, border)));

    if total > g.len() {
        g.resize(total, T::zero());
    }
    if total > h.len() {
        h.resize(total, T::zero());
    }

    // Forward pass: restart the accumulator at every block start.
    let mut running = T::zero();
    for (i, &sample) in padded.iter().enumerate() {
        running = if i % k == 0 {
            sample
        } else {
            combine(running, sample)
        };
        g[i] = running;
    }

    // Backward pass: restart at every block end (and at the final, possibly
    // short, block).
    let mut running = T::zero();
    for (i, &sample) in padded.iter().enumerate().rev() {
        let at_block_end = (i + 1) % k == 0 || i + 1 == total;
        running = if at_block_end {
            sample
        } else {
            combine(running, sample)
        };
        h[i] = running;
    }

    let span = 2 * radius;
    (0..n).for_each(|p| dst[p] = combine(h[p], g[p + span]));
}