use ferray_core::Array;
use ferray_core::dimension::Ix1;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
/// Output-size policy for [`convolve`], named after NumPy's `mode` argument.
///
/// With input lengths `n` and `m`, the full convolution has `n + m - 1` samples;
/// each variant selects a different window of that full result.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvolveMode {
    /// Every sample of the full convolution: length `n + m - 1`.
    Full,
    /// A window of length `max(n, m)` taken from the middle of the full result.
    Same,
    /// Only samples where the two inputs overlap completely: length `max(n, m) - min(n, m) + 1`.
    Valid,
}
/// Discrete linear convolution of two 1-D arrays, with NumPy-style `mode` handling.
///
/// Computes `full[k] = Σ_{i + j == k} a[i] * v[j]` for `k` in `0..n + m - 1`, where
/// `n = a.len()` and `m = v.len()`. It then returns the window of `full` selected by `mode`:
///
/// * [`ConvolveMode::Full`]: all `n + m - 1` samples.
/// * [`ConvolveMode::Same`]: `max(n, m)` samples from the middle of `full`. When the
///   trimmed excess (`min(n, m) - 1`) is odd, one more sample is dropped from the end
///   than from the start.
/// * [`ConvolveMode::Valid`]: the `max(n, m) - min(n, m) + 1` samples where the inputs
///   overlap completely.
///
/// # Errors
///
/// Returns an invalid-value error if either input is empty. Any error from
/// `Array::from_vec` while building the result is propagated.
pub fn convolve<T>(
    a: &Array<T, Ix1>,
    v: &Array<T, Ix1>,
    mode: ConvolveMode,
) -> FerrayResult<Array<T, Ix1>>
where
    T: Element + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Copy,
{
    let a_data: Vec<T> = a.iter().copied().collect();
    let v_data: Vec<T> = v.iter().copied().collect();
    let n = a_data.len();
    let m = v_data.len();
    if n == 0 || m == 0 {
        return Err(FerrayError::invalid_value(
            "convolve: input arrays must be non-empty",
        ));
    }

    // Full convolution. Both lengths are >= 1 here, so `n + m - 1` cannot underflow.
    let full_len = n + m - 1;
    let mut full = vec![<T as Element>::zero(); full_len];
    for (i, &ai) in a_data.iter().enumerate() {
        // a[i] contributes to full[i..i + m]; i + m <= full_len because i <= n - 1.
        // Zipping the pre-sliced window with `v_data` avoids per-element bounds checks
        // and keeps the original accumulation order (i outer, j inner).
        for (acc, &vj) in full[i..i + m].iter_mut().zip(&v_data) {
            *acc = *acc + ai * vj;
        }
    }

    // Pick the window of `full` to return for each mode.
    let (start, out_len) = match mode {
        ConvolveMode::Full => (0, full_len),
        ConvolveMode::Same => {
            let out_len = n.max(m);
            // Centre the window; floor division puts the odd leftover at the end.
            ((full_len - out_len) / 2, out_len)
        }
        ConvolveMode::Valid => (n.min(m) - 1, n.max(m) - n.min(m) + 1),
    };

    // Shift the window to the front and drop the tail, reusing the allocation.
    if start > 0 {
        full.copy_within(start..start + out_len, 0);
    }
    full.truncate(out_len);
    Array::from_vec(Ix1::new([out_len]), full)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 1-D `f64` array from `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Build a 1-D `i32` array from `data`; integer results can be compared exactly.
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Assert that `actual` matches `expected` element-wise within 1e-12.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            actual,
            expected
        );
        for (k, (x, y)) in actual.iter().zip(expected).enumerate() {
            assert!((x - y).abs() < 1e-12, "index {}: {} != {}", k, x, y);
        }
    }

    #[test]
    fn test_convolve_full() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_close(r.as_slice().unwrap(), &[0.0, 1.0, 2.5, 4.0, 1.5]);
    }

    #[test]
    fn test_convolve_same() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(r.size(), 3);
        assert_close(r.as_slice().unwrap(), &[1.0, 2.5, 4.0]);
    }

    #[test]
    fn test_convolve_valid() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_eq!(r.size(), 1);
        assert_close(r.as_slice().unwrap(), &[2.5]);
    }

    #[test]
    fn test_convolve_simple() {
        let a = arr1(vec![1.0, 1.0, 1.0]);
        let v = arr1(vec![1.0, 1.0, 1.0]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_close(r.as_slice().unwrap(), &[1.0, 2.0, 3.0, 2.0, 1.0]);
    }

    #[test]
    fn test_convolve_i32() {
        let a = arr1_i32(vec![1, 2, 3]);
        let v = arr1_i32(vec![1, 1]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1, 3, 5, 3]);
    }

    #[test]
    fn test_convolve_empty_is_error() {
        let empty = arr1(vec![]);
        let v = arr1(vec![1.0]);
        assert!(convolve(&empty, &v, ConvolveMode::Full).is_err());
        assert!(convolve(&v, &empty, ConvolveMode::Same).is_err());
        assert!(convolve(&empty, &empty, ConvolveMode::Valid).is_err());
    }

    #[test]
    fn test_convolve_single_element() {
        let a = arr1_i32(vec![2]);
        let v = arr1_i32(vec![3]);
        for mode in [ConvolveMode::Full, ConvolveMode::Same, ConvolveMode::Valid] {
            let r = convolve(&a, &v, mode).unwrap();
            assert_eq!(r.as_slice().unwrap(), &[6]);
        }
    }

    #[test]
    fn test_convolve_same_even_kernel() {
        // Matches numpy.convolve([1, 2, 3, 4], [1, 1], 'same') == [1, 3, 5, 7].
        let a = arr1_i32(vec![1, 2, 3, 4]);
        let v = arr1_i32(vec![1, 1]);
        let r = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1, 3, 5, 7]);
    }

    #[test]
    fn test_convolve_kernel_longer_than_input() {
        // full = [1, 4, 7, 10, 8]; every mode must give the same result when the
        // arguments are swapped, because convolution is commutative.
        let short = arr1_i32(vec![1, 2]);
        let long = arr1_i32(vec![1, 2, 3, 4]);
        let cases: [(ConvolveMode, &[i32]); 3] = [
            (ConvolveMode::Full, &[1, 4, 7, 10, 8]),
            (ConvolveMode::Same, &[1, 4, 7, 10]),
            (ConvolveMode::Valid, &[4, 7, 10]),
        ];
        for (mode, expected) in cases {
            let r1 = convolve(&short, &long, mode).unwrap();
            let r2 = convolve(&long, &short, mode).unwrap();
            assert_eq!(r1.as_slice().unwrap(), expected);
            assert_eq!(r2.as_slice().unwrap(), expected);
        }
    }
}