use alloc::vec;
use alloc::vec::Vec;
use crate::dynmatrix::DynMatrix;
use crate::traits::{FloatScalar, MatrixMut, MatrixRef};
/// Resizes `src` to `new_rows` x `new_cols` using bilinear interpolation.
///
/// Output element `(i, j)` samples the source at half-pixel-aligned coordinates
/// `y = (i + 0.5) * (src_rows / new_rows) - 0.5` and
/// `x = (j + 0.5) * (src_cols / new_cols) - 0.5`.
/// Coordinates outside `[0, src_rows - 1]` / `[0, src_cols - 1]` are clamped,
/// so samples near the border repeat the edge values of `src`.
///
/// If `src` has no rows or columns, or either requested dimension is zero,
/// a zero-filled `new_rows` x `new_cols` matrix is returned unchanged.
///
/// # Panics
///
/// Panics if a dimension, index, or the constant `0.5` cannot be converted
/// to `T` through `T::from`.
pub fn resize_bilinear<T: FloatScalar>(
    src: &DynMatrix<T>,
    new_rows: usize,
    new_cols: usize,
) -> DynMatrix<T> {
    let (src_rows, src_cols) = (src.nrows(), src.ncols());
    let mut out = DynMatrix::<T>::zeros(new_rows, new_cols);

    let nothing_to_do = src_rows == 0 || src_cols == 0 || new_rows == 0 || new_cols == 0;
    if nothing_to_do {
        return out;
    }

    let half = T::from(0.5_f64).unwrap();
    let scale_y = T::from(src_rows).unwrap() / T::from(new_rows).unwrap();
    let scale_x = T::from(src_cols).unwrap() / T::from(new_cols).unwrap();

    // Per output position along one axis: (lower source index, upper source
    // index, fractional weight toward the upper index). These depend only on
    // the axis, so they are computed once instead of per output element.
    let build_taps = |count: usize, scale: T, last: usize| -> Vec<(usize, usize, T)> {
        let mut taps = vec![(0, 0, T::zero()); count];
        for (k, slot) in taps.iter_mut().enumerate() {
            let centre = (T::from(k).unwrap() + half) * scale - half;
            *slot = map_axis(centre, last);
        }
        taps
    };
    let row_taps = build_taps(new_rows, scale_y, src_rows - 1);
    let col_taps = build_taps(new_cols, scale_x, src_cols - 1);

    // Walk the output one column at a time so each step reads two whole source
    // columns (presumably contiguous; confirm against DynMatrix's layout).
    for (col, &(left, right, wx)) in col_taps.iter().enumerate() {
        let left_col = src.col_as_slice(left, 0);
        let right_col = src.col_as_slice(right, 0);
        let out_col = &mut out.col_as_mut_slice(col, 0)[..new_rows];

        for (cell, &(upper, lower, wy)) in out_col.iter_mut().zip(row_taps.iter()) {
            let (top_left, top_right) = (left_col[upper], right_col[upper]);
            let (bottom_left, bottom_right) = (left_col[lower], right_col[lower]);

            // Interpolate horizontally along both rows, then vertically between them.
            let top = top_left + (top_right - top_left) * wx;
            let bottom = bottom_left + (bottom_right - bottom_left) * wx;
            *cell = top + (bottom - top) * wy;
        }
    }

    out
}
/// Maps a continuous source coordinate `u` onto two neighbouring sample
/// indices and the interpolation weight between them.
///
/// `u` is clamped to `[0, max]` first. The result is `(lo, hi, frac)`:
/// - `lo = floor(u)`, or `0` if that value cannot be converted to `usize`,
///   and never more than `max`;
/// - `hi = lo + 1`, except at the last sample, where it stays at `max`;
/// - `frac = u - floor(u)`, the weight given to `hi`.
///
/// The clamp uses plain comparisons, so a NaN `u` is not clamped. It flows
/// through `floor` and the subtraction unchanged.
#[inline]
fn map_axis<T: FloatScalar>(u: T, max: usize) -> (usize, usize, T) {
    let upper_bound = T::from(max).unwrap();
    let clamped = match u {
        v if v < T::zero() => T::zero(),
        v if v > upper_bound => upper_bound,
        v => v,
    };

    let base = clamped.floor();
    let lo = base.to_usize().unwrap_or(0).min(max);
    // `lo <= max` always holds here, so this matches `(lo + 1).min(max)`.
    let hi = if lo < max { lo + 1 } else { max };
    let frac = clamped - base;

    (lo, hi, frac)
}