Skip to main content

ferray_ufunc/ops/
convolution.rs

1// ferray-ufunc: Convolution
2//
3// convolve with modes: Full, Same, Valid
4
5use ferray_core::Array;
6use ferray_core::dimension::Ix1;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
/// Selects which window of the full linear convolution [`convolve`] returns.
///
/// In the descriptions below, `N` and `M` are the lengths of the two inputs.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvolveMode {
    /// Every point where the inputs overlap at all: `N + M - 1` samples.
    Full,
    /// `max(N, M)` samples taken from the middle of the full result.
    Same,
    /// Only the points where the inputs overlap completely:
    /// `max(N, M) - min(N, M) + 1` samples.
    Valid,
}
20
21/// Discrete, linear convolution of two 1-D arrays.
22///
23/// Computes `convolve(a, v, mode)` following NumPy semantics.
24pub fn convolve<T>(
25    a: &Array<T, Ix1>,
26    v: &Array<T, Ix1>,
27    mode: ConvolveMode,
28) -> FerrayResult<Array<T, Ix1>>
29where
30    T: Element + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Copy,
31{
32    let a_data: Vec<T> = a.iter().copied().collect();
33    let v_data: Vec<T> = v.iter().copied().collect();
34    let n = a_data.len();
35    let m = v_data.len();
36
37    if n == 0 || m == 0 {
38        return Err(FerrayError::invalid_value(
39            "convolve: input arrays must be non-empty",
40        ));
41    }
42
43    // Full convolution
44    let full_len = n + m - 1;
45    let mut full = vec![<T as Element>::zero(); full_len];
46
47    for i in 0..n {
48        for j in 0..m {
49            full[i + j] = full[i + j] + a_data[i] * v_data[j];
50        }
51    }
52
53    match mode {
54        ConvolveMode::Full => Array::from_vec(Ix1::new([full_len]), full),
55        ConvolveMode::Same => {
56            let out_len = n.max(m);
57            let start = (full_len - out_len) / 2;
58            let result = full[start..start + out_len].to_vec();
59            Array::from_vec(Ix1::new([out_len]), result)
60        }
61        ConvolveMode::Valid => {
62            let out_len = if n >= m { n - m + 1 } else { m - n + 1 };
63            let start = m.min(n) - 1;
64            let result = full[start..start + out_len].to_vec();
65            Array::from_vec(Ix1::new([out_len]), result)
66        }
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D `f64` array from `data`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    /// Builds a 1-D `i32` array from `data` (exact comparisons, no tolerance).
    fn arr1_i32(data: Vec<i32>) -> Array<i32, Ix1> {
        let n = data.len();
        Array::from_vec(Ix1::new([n]), data).unwrap()
    }

    #[test]
    fn test_convolve_full() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        let s = r.as_slice().unwrap();
        // [0*1, 1*1+0*2, 0.5*1+1*2+0*3, 0.5*2+1*3, 0.5*3]
        // = [0, 1, 2.5, 4, 1.5]
        assert_eq!(s.len(), 5);
        assert!((s[0] - 0.0).abs() < 1e-12);
        assert!((s[1] - 1.0).abs() < 1e-12);
        assert!((s[2] - 2.5).abs() < 1e-12);
        assert!((s[3] - 4.0).abs() < 1e-12);
        assert!((s[4] - 1.5).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_same() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(r.size(), 3);
        let s = r.as_slice().unwrap();
        // Full = [0, 1, 2.5, 4, 1.5], same takes middle 3 = [1, 2.5, 4]
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - 2.5).abs() < 1e-12);
        assert!((s[2] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_valid() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_eq!(r.size(), 1);
        let s = r.as_slice().unwrap();
        assert!((s[0] - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_simple() {
        let a = arr1(vec![1.0, 1.0, 1.0]);
        let v = arr1(vec![1.0, 1.0, 1.0]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s.len(), 5);
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - 2.0).abs() < 1e-12);
        assert!((s[2] - 3.0).abs() < 1e-12);
        assert!((s[3] - 2.0).abs() < 1e-12);
        assert!((s[4] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_i32() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let v = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 1]).unwrap();
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1, 3, 5, 3]);
    }

    #[test]
    fn test_convolve_kernel_longer_than_signal() {
        // Full = [1, 4, 7, 6]
        let a = arr1_i32(vec![1, 2]);
        let v = arr1_i32(vec![1, 2, 3]);

        let full = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_eq!(full.as_slice().unwrap(), &[1, 4, 7, 6]);

        let same = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(same.as_slice().unwrap(), &[1, 4, 7]);

        let valid = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_eq!(valid.as_slice().unwrap(), &[4, 7]);
    }

    #[test]
    fn test_convolve_is_commutative_in_every_mode() {
        let a = arr1_i32(vec![1, 2]);
        let v = arr1_i32(vec![1, 2, 3]);
        for mode in [ConvolveMode::Full, ConvolveMode::Same, ConvolveMode::Valid] {
            let av = convolve(&a, &v, mode).unwrap();
            let va = convolve(&v, &a, mode).unwrap();
            assert_eq!(av.as_slice().unwrap(), va.as_slice().unwrap());
        }
    }

    #[test]
    fn test_convolve_same_odd_excess_trims_right() {
        // Full = [1, 3, 5, 7, 4]; the single excess sample is dropped on the right.
        let a = arr1_i32(vec![1, 2, 3, 4]);
        let v = arr1_i32(vec![1, 1]);
        let same = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(same.as_slice().unwrap(), &[1, 3, 5, 7]);

        let valid = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_eq!(valid.as_slice().unwrap(), &[3, 5, 7]);
    }

    #[test]
    fn test_convolve_single_elements() {
        let a = arr1_i32(vec![2]);
        let v = arr1_i32(vec![3]);
        for mode in [ConvolveMode::Full, ConvolveMode::Same, ConvolveMode::Valid] {
            let r = convolve(&a, &v, mode).unwrap();
            assert_eq!(r.as_slice().unwrap(), &[6]);
        }
    }

    #[test]
    fn test_convolve_empty_input_is_error() {
        let v = arr1(vec![1.0, 2.0]);
        // Guarded in case the array type itself rejects zero-length construction.
        if let Ok(empty) = Array::<f64, Ix1>::from_vec(Ix1::new([0]), Vec::new()) {
            assert!(convolve(&empty, &v, ConvolveMode::Full).is_err());
            assert!(convolve(&v, &empty, ConvolveMode::Valid).is_err());
        }
    }
}