ferray_ufunc/ops/
convolution.rs1use ferray_core::Array;
6use ferray_core::dimension::Ix1;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
/// Selects which portion of the full discrete convolution `convolve` returns.
///
/// Lengths below use `n` and `m` for the lengths of the two input arrays.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvolveMode {
    /// The complete convolution, `n + m - 1` elements long.
    Full,
    /// The `max(n, m)` elements taken from the centre of the full result.
    Same,
    /// The `max(n, m) - min(n, m) + 1` elements of the full result that start
    /// at index `min(n, m) - 1`.
    Valid,
}
20
21pub fn convolve<T>(
25 a: &Array<T, Ix1>,
26 v: &Array<T, Ix1>,
27 mode: ConvolveMode,
28) -> FerrayResult<Array<T, Ix1>>
29where
30 T: Element + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Copy,
31{
32 let a_data: Vec<T> = a.iter().copied().collect();
33 let v_data: Vec<T> = v.iter().copied().collect();
34 let n = a_data.len();
35 let m = v_data.len();
36
37 if n == 0 || m == 0 {
38 return Err(FerrayError::invalid_value(
39 "convolve: input arrays must be non-empty",
40 ));
41 }
42
43 let full_len = n + m - 1;
45 let mut full = vec![<T as Element>::zero(); full_len];
46
47 for i in 0..n {
48 for j in 0..m {
49 full[i + j] = full[i + j] + a_data[i] * v_data[j];
50 }
51 }
52
53 match mode {
54 ConvolveMode::Full => Array::from_vec(Ix1::new([full_len]), full),
55 ConvolveMode::Same => {
56 let out_len = n.max(m);
57 let start = (full_len - out_len) / 2;
58 let result = full[start..start + out_len].to_vec();
59 Array::from_vec(Ix1::new([out_len]), result)
60 }
61 ConvolveMode::Valid => {
62 let out_len = if n >= m { n - m + 1 } else { m - n + 1 };
63 let start = m.min(n) - 1;
64 let result = full[start..start + out_len].to_vec();
65 Array::from_vec(Ix1::new([out_len]), result)
66 }
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 1-D `f64` array whose length equals `data.len()`.
    fn arr1(data: Vec<f64>) -> Array<f64, Ix1> {
        let len = data.len();
        Array::from_vec(Ix1::new([len]), data).unwrap()
    }

    /// Asserts that the leading elements of `actual` match `expected` to
    /// within an absolute tolerance of `1e-12`.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        for (idx, &want) in expected.iter().enumerate() {
            assert!((actual[idx] - want).abs() < 1e-12);
        }
    }

    // Full mode keeps every output sample.
    #[test]
    fn test_convolve_full() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let out = convolve(&a, &v, ConvolveMode::Full).unwrap();
        let values = out.as_slice().unwrap();
        assert_eq!(values.len(), 5);
        assert_close(values, &[0.0, 1.0, 2.5, 4.0, 1.5]);
    }

    // Same mode keeps the three central samples of the five-sample result.
    #[test]
    fn test_convolve_same() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let out = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(out.size(), 3);
        assert_close(out.as_slice().unwrap(), &[1.0, 2.5, 4.0]);
    }

    // Equal-length inputs leave a single sample in valid mode.
    #[test]
    fn test_convolve_valid() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let out = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_eq!(out.size(), 1);
        assert_close(out.as_slice().unwrap(), &[2.5]);
    }

    // Convolving two all-ones arrays of length three gives 1, 2, 3, 2, 1.
    #[test]
    fn test_convolve_simple() {
        let a = arr1(vec![1.0, 1.0, 1.0]);
        let v = arr1(vec![1.0, 1.0, 1.0]);
        let out = convolve(&a, &v, ConvolveMode::Full).unwrap();
        let values = out.as_slice().unwrap();
        assert_eq!(values.len(), 5);
        assert_close(values, &[1.0, 2.0, 3.0, 2.0, 1.0]);
    }

    // Integer element types are supported and compared exactly.
    #[test]
    fn test_convolve_i32() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let v = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 1]).unwrap();
        let out = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_eq!(out.as_slice().unwrap(), &[1, 3, 5, 3]);
    }
}