use ferray_core::Array;
use ferray_core::dimension::Ix1;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
/// Output-size policy for [`convolve`] (and the feature-gated `fftconvolve`).
///
/// With input lengths `n` and `m`, the full linear convolution has `n + m - 1`
/// samples; each variant selects a contiguous window of it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvolveMode {
    /// Every point of (partial) overlap: `n + m - 1` samples.
    Full,
    /// A window of `max(n, m)` samples taken from the middle of the full result.
    Same,
    /// Only positions where the inputs overlap completely:
    /// `max(n, m) - min(n, m) + 1` samples.
    Valid,
}
/// Discrete linear convolution of two one-dimensional arrays.
///
/// Computes `(a * v)[k] = Σ_i a[i] · v[k - i]` directly in `O(len(a) · len(v))`
/// time. The full result is then cut down according to `mode` (see
/// [`ConvolveMode`]). Either input may be the longer one.
///
/// # Errors
///
/// Returns an invalid-value [`FerrayError`] if either input is empty, and
/// propagates any error from building the output array.
pub fn convolve<T>(
    a: &Array<T, Ix1>,
    v: &Array<T, Ix1>,
    mode: ConvolveMode,
) -> FerrayResult<Array<T, Ix1>>
where
    T: Element + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Copy,
{
    let signal: Vec<T> = a.iter().copied().collect();
    let kernel: Vec<T> = v.iter().copied().collect();
    let (len_a, len_v) = (signal.len(), kernel.len());
    if len_a == 0 || len_v == 0 {
        return Err(FerrayError::invalid_value(
            "convolve: input arrays must be non-empty",
        ));
    }

    let full_len = len_a + len_v - 1;
    let mut full = vec![<T as Element>::zero(); full_len];
    // Scatter each product a[i] * v[j] into output slot i + j. The outer loop
    // walks `i` upward, so every slot sums its terms in ascending-`i` order,
    // starting from zero.
    for (i, &x) in signal.iter().enumerate() {
        for (slot, &y) in full[i..i + len_v].iter_mut().zip(&kernel) {
            *slot = *slot + x * y;
        }
    }

    let shorter = len_a.min(len_v);
    let longer = len_a.max(len_v);
    // (offset into `full`, number of samples kept)
    let (offset, out_len) = match mode {
        ConvolveMode::Full => (0, full_len),
        // `shorter - 1` samples are dropped in total; integer division puts the
        // odd extra one (even-length shorter input) on the right end.
        ConvolveMode::Same => ((full_len - longer) / 2, longer),
        ConvolveMode::Valid => (shorter - 1, longer - shorter + 1),
    };
    let trimmed: Vec<T> = full.into_iter().skip(offset).take(out_len).collect();
    Array::from_vec(Ix1::new([out_len]), trimmed)
}
/// FFT-based linear convolution of two `f64` arrays.
///
/// Returns the same window of the full convolution as [`convolve`] for a given
/// `mode`. The full result is computed as `irfft(rfft(a_pad) · rfft(v_pad))`,
/// with both inputs zero-padded to `len(a) + len(v) - 1`. Values go through
/// floating-point transforms, so they may differ slightly from [`convolve`].
///
/// Only compiled with the `fft-convolve` feature.
///
/// # Errors
///
/// Returns an invalid-value [`FerrayError`] if either input is empty, and
/// propagates any error from array construction or the `ferray_fft` transforms.
#[cfg(feature = "fft-convolve")]
pub fn fftconvolve(
    a: &Array<f64, Ix1>,
    v: &Array<f64, Ix1>,
    mode: ConvolveMode,
) -> FerrayResult<Array<f64, Ix1>> {
    use ferray_fft::{FftNorm, irfft, rfft};

    let len_a = a.size();
    let len_v = v.size();
    if len_a == 0 || len_v == 0 {
        return Err(FerrayError::invalid_value(
            "fftconvolve: input arrays must be non-empty",
        ));
    }
    let full_len = len_a + len_v - 1;

    // Pad with trailing zeros to `full_len` so the circular convolution
    // performed by the FFT coincides with the linear one.
    let zero_pad = |src: &Array<f64, Ix1>| -> FerrayResult<Array<f64, Ix1>> {
        let padded: Vec<f64> = src
            .iter()
            .copied()
            .chain(std::iter::repeat(0.0))
            .take(full_len)
            .collect();
        Array::<f64, Ix1>::from_vec(Ix1::new([full_len]), padded)
    };
    let a_padded = zero_pad(a)?;
    let v_padded = zero_pad(v)?;

    let a_fft = rfft(&a_padded, None, None, FftNorm::Backward)?;
    let v_fft = rfft(&v_padded, None, None, FftNorm::Backward)?;
    // Pointwise product of the two half-spectra is the spectrum of the convolution.
    let spectrum: Vec<num_complex::Complex<f64>> =
        a_fft.iter().zip(v_fft.iter()).map(|(x, y)| x * y).collect();
    let spectrum_arr = Array::<num_complex::Complex<f64>, Ix1>::from_vec(
        Ix1::new([spectrum.len()]),
        spectrum,
    )?;
    let inv = irfft(&spectrum_arr, Some(full_len), None, FftNorm::Backward)?;
    let inv_data: Vec<f64> = inv.iter().copied().collect();

    let shorter = len_a.min(len_v);
    let longer = len_a.max(len_v);
    // (offset into the full result, number of samples kept)
    let (offset, out_len) = match mode {
        ConvolveMode::Full => return Array::from_vec(Ix1::new([full_len]), inv_data),
        ConvolveMode::Same => ((full_len - longer) / 2, longer),
        ConvolveMode::Valid => (shorter - 1, longer - shorter + 1),
    };
    Array::from_vec(
        Ix1::new([out_len]),
        inv_data[offset..offset + out_len].to_vec(),
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::arr1;

    const ALL_MODES: [ConvolveMode; 3] =
        [ConvolveMode::Full, ConvolveMode::Same, ConvolveMode::Valid];

    /// Asserts `actual` has the same length as `expected` and matches it
    /// element-wise to within 1e-12.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (i, (got, want)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (got - want).abs() < 1e-12,
                "index {i}: got {got}, expected {want}"
            );
        }
    }

    /// Builds a 1-D `i32` array holding `data`.
    fn ints(data: Vec<i32>) -> Array<i32, Ix1> {
        let len = data.len();
        Array::<i32, Ix1>::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Convolves two `i32` slices and returns the result as a `Vec`.
    fn conv_i32(a: &[i32], v: &[i32], mode: ConvolveMode) -> Vec<i32> {
        let r = convolve(&ints(a.to_vec()), &ints(v.to_vec()), mode).unwrap();
        r.as_slice().unwrap().to_vec()
    }

    #[test]
    fn test_convolve_full() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_close(r.as_slice().unwrap(), &[0.0, 1.0, 2.5, 4.0, 1.5]);
    }

    #[test]
    fn test_convolve_same() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(r.size(), 3);
        assert_close(r.as_slice().unwrap(), &[1.0, 2.5, 4.0]);
    }

    #[test]
    fn test_convolve_valid() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_eq!(r.size(), 1);
        assert_close(r.as_slice().unwrap(), &[2.5]);
    }

    #[test]
    fn test_convolve_simple() {
        let a = arr1(vec![1.0, 1.0, 1.0]);
        let v = arr1(vec![1.0, 1.0, 1.0]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_close(r.as_slice().unwrap(), &[1.0, 2.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_convolve_i32() {
        assert_eq!(conv_i32(&[1, 2, 3], &[1, 1], ConvolveMode::Full), vec![1, 3, 5, 3]);
    }

    /// Second input longer than the first exercises the `m > n` trimming branches.
    #[test]
    fn test_convolve_kernel_longer_than_signal() {
        let (a, v) = ([1, 1], [1, 2, 3]);
        assert_eq!(conv_i32(&a, &v, ConvolveMode::Full), vec![1, 3, 5, 3]);
        assert_eq!(conv_i32(&a, &v, ConvolveMode::Same), vec![1, 3, 5]);
        assert_eq!(conv_i32(&a, &v, ConvolveMode::Valid), vec![3, 5]);
    }

    /// Even-length shorter input: NumPy's `mode="same"` keeps the left edge.
    #[test]
    fn test_convolve_same_even_kernel() {
        assert_eq!(
            conv_i32(&[1, 2, 3, 4], &[1, 1], ConvolveMode::Same),
            vec![1, 3, 5, 7]
        );
    }

    #[test]
    fn test_convolve_is_commutative() {
        let a = [1, 2, 3, 4];
        let v = [1, -1, 2];
        for mode in ALL_MODES {
            assert_eq!(conv_i32(&a, &v, mode), conv_i32(&v, &a, mode), "{mode:?}");
        }
    }

    #[test]
    fn test_convolve_single_elements() {
        for mode in ALL_MODES {
            assert_eq!(conv_i32(&[5], &[3], mode), vec![15], "{mode:?}");
        }
    }

    #[test]
    fn test_convolve_empty_input_is_error() {
        let empty = ints(Vec::new());
        let v = ints(vec![1, 2]);
        for mode in ALL_MODES {
            assert!(convolve(&empty, &v, mode).is_err(), "{mode:?}");
            assert!(convolve(&v, &empty, mode).is_err(), "{mode:?}");
        }
    }
}