1use ferray_core::Array;
6use ferray_core::dimension::Ix1;
7use ferray_core::dtype::Element;
8use ferray_core::error::{FerrayError, FerrayResult};
9
/// Selects which portion of the full linear convolution is returned.
///
/// Output lengths below are for inputs of lengths `n` and `m`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConvolveMode {
    /// Every point of (partial) overlap: `n + m - 1` values.
    Full,
    /// `max(n, m)` values taken from the centre of the full result.
    Same,
    /// Only the points where the shorter input lies entirely inside the longer
    /// one: `max(n, m) - min(n, m) + 1` values.
    Valid,
}
20
21pub fn convolve<T>(
35 a: &Array<T, Ix1>,
36 v: &Array<T, Ix1>,
37 mode: ConvolveMode,
38) -> FerrayResult<Array<T, Ix1>>
39where
40 T: Element + std::ops::Add<Output = T> + std::ops::Mul<Output = T> + Copy,
41{
42 let a_data: Vec<T> = a.iter().copied().collect();
43 let v_data: Vec<T> = v.iter().copied().collect();
44 let n = a_data.len();
45 let m = v_data.len();
46
47 if n == 0 || m == 0 {
48 return Err(FerrayError::invalid_value(
49 "convolve: input arrays must be non-empty",
50 ));
51 }
52
53 let full_len = n + m - 1;
56 let mut full = vec![<T as Element>::zero(); full_len];
57
58 for k in 0..full_len {
59 let i_lo = (k + 1).saturating_sub(m);
61 let i_hi = (k + 1).min(n);
62 let mut acc = <T as Element>::zero();
63 for i in i_lo..i_hi {
64 acc = acc + a_data[i] * v_data[k - i];
69 }
70 full[k] = acc;
71 }
72
73 match mode {
74 ConvolveMode::Full => Array::from_vec(Ix1::new([full_len]), full),
75 ConvolveMode::Same => {
76 let out_len = n.max(m);
77 let start = (full_len - out_len) / 2;
78 let result = full[start..start + out_len].to_vec();
79 Array::from_vec(Ix1::new([out_len]), result)
80 }
81 ConvolveMode::Valid => {
82 let out_len = if n >= m { n - m + 1 } else { m - n + 1 };
83 let start = m.min(n) - 1;
84 let result = full[start..start + out_len].to_vec();
85 Array::from_vec(Ix1::new([out_len]), result)
86 }
87 }
88}
89
#[cfg(feature = "fft-convolve")]
/// Linear convolution of two one-dimensional `f64` arrays computed via the FFT.
///
/// Uses the same `mode` semantics and output lengths as [`convolve`]. Results
/// agree with the direct method only up to floating-point round-off. The cost is
/// dominated by forward and inverse real FFTs of length `L`, where `L` is
/// `n + m - 1` rounded up to a power of two. That is typically `O(L log L)`,
/// which beats the direct `O(n * m)` method for long inputs.
///
/// # Errors
///
/// Returns an invalid-value error if either input is empty. Any error from the
/// forward or inverse transforms, or from array construction, is propagated.
pub fn fftconvolve(
    a: &Array<f64, Ix1>,
    v: &Array<f64, Ix1>,
    mode: ConvolveMode,
) -> FerrayResult<Array<f64, Ix1>> {
    use ferray_fft::{FftNorm, irfft, rfft};

    let n = a.size();
    let m = v.size();
    if n == 0 || m == 0 {
        return Err(FerrayError::invalid_value(
            "fftconvolve: input arrays must be non-empty",
        ));
    }

    let full_len = n + m - 1;
    // Any transform length >= full_len makes the circular convolution equal to
    // the linear one, because the samples past full_len are zero. FFT
    // implementations are usually fastest on power-of-two lengths and can be
    // much slower when the length has a large prime factor, so round up. If
    // rounding up would overflow, fall back to the exact length.
    let fft_len = full_len.checked_next_power_of_two().unwrap_or(full_len);

    // Copy `src` into a zero-filled buffer of length `fft_len`.
    let zero_padded = |src: &Array<f64, Ix1>| -> FerrayResult<Array<f64, Ix1>> {
        let mut buf = vec![0.0f64; fft_len];
        for (dst, &x) in buf.iter_mut().zip(src.iter()) {
            *dst = x;
        }
        Array::<f64, Ix1>::from_vec(Ix1::new([fft_len]), buf)
    };

    let a_fft = rfft(&zero_padded(a)?, None, None, FftNorm::Backward)?;
    let v_fft = rfft(&zero_padded(v)?, None, None, FftNorm::Backward)?;

    // Convolution theorem: multiply the half-spectra pointwise.
    let prod: Vec<num_complex::Complex<f64>> = a_fft
        .iter()
        .zip(v_fft.iter())
        .map(|(x, y)| x * y)
        .collect();
    let prod_arr =
        Array::<num_complex::Complex<f64>, Ix1>::from_vec(Ix1::new([prod.len()]), prod)?;

    // Backward normalisation applies the 1/fft_len scale on the inverse.
    let inv = irfft(&prod_arr, Some(fft_len), None, FftNorm::Backward)?;

    let longer = n.max(m);
    let shorter = n.min(m);
    let (start, out_len) = match mode {
        ConvolveMode::Full => (0, full_len),
        ConvolveMode::Same => ((full_len - longer) / 2, longer),
        ConvolveMode::Valid => (shorter - 1, longer - shorter + 1),
    };

    // start + out_len <= full_len <= fft_len, so the zero-padding tail is
    // always discarded here.
    let trimmed: Vec<f64> = inv.iter().copied().skip(start).take(out_len).collect();
    Array::from_vec(Ix1::new([out_len]), trimmed)
}
164
#[cfg(test)]
mod tests {
    use super::*;

    use crate::test_util::arr1;

    /// Asserts that two `f64` slices have equal length and agree element-wise
    /// within `1e-12`.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {actual:?} vs {expected:?}"
        );
        for (i, (x, y)) in actual.iter().zip(expected).enumerate() {
            assert!((x - y).abs() < 1e-12, "index {i}: {x} != {y}");
        }
    }

    const ALL_MODES: [ConvolveMode; 3] =
        [ConvolveMode::Full, ConvolveMode::Same, ConvolveMode::Valid];

    #[test]
    fn test_convolve_full() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s.len(), 5);
        assert!((s[0] - 0.0).abs() < 1e-12);
        assert!((s[1] - 1.0).abs() < 1e-12);
        assert!((s[2] - 2.5).abs() < 1e-12);
        assert!((s[3] - 4.0).abs() < 1e-12);
        assert!((s[4] - 1.5).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_same() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_eq!(r.size(), 3);
        let s = r.as_slice().unwrap();
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - 2.5).abs() < 1e-12);
        assert!((s[2] - 4.0).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_valid() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5]);
        let r = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_eq!(r.size(), 1);
        let s = r.as_slice().unwrap();
        assert!((s[0] - 2.5).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_simple() {
        let a = arr1(vec![1.0, 1.0, 1.0]);
        let v = arr1(vec![1.0, 1.0, 1.0]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        let s = r.as_slice().unwrap();
        assert_eq!(s.len(), 5);
        assert!((s[0] - 1.0).abs() < 1e-12);
        assert!((s[1] - 2.0).abs() < 1e-12);
        assert!((s[2] - 3.0).abs() < 1e-12);
        assert!((s[3] - 2.0).abs() < 1e-12);
        assert!((s[4] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_convolve_i32() {
        let a = Array::<i32, Ix1>::from_vec(Ix1::new([3]), vec![1, 2, 3]).unwrap();
        let v = Array::<i32, Ix1>::from_vec(Ix1::new([2]), vec![1, 1]).unwrap();
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_eq!(r.as_slice().unwrap(), &[1, 3, 5, 3]);
    }

    // The kernel is longer than the signal (n < m), which exercises the other
    // branch of the Same/Valid length and offset computations. Expected values
    // match numpy.convolve([1, 2, 3], [0, 1, 0.5, 2], mode).

    #[test]
    fn test_convolve_longer_kernel_full() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5, 2.0]);
        let r = convolve(&a, &v, ConvolveMode::Full).unwrap();
        assert_close(r.as_slice().unwrap(), &[0.0, 1.0, 2.5, 6.0, 5.5, 6.0]);
    }

    #[test]
    fn test_convolve_longer_kernel_same() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5, 2.0]);
        let r = convolve(&a, &v, ConvolveMode::Same).unwrap();
        assert_close(r.as_slice().unwrap(), &[1.0, 2.5, 6.0, 5.5]);
    }

    #[test]
    fn test_convolve_longer_kernel_valid() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5, 2.0]);
        let r = convolve(&a, &v, ConvolveMode::Valid).unwrap();
        assert_close(r.as_slice().unwrap(), &[2.5, 6.0]);
    }

    #[test]
    fn test_convolve_is_symmetric() {
        let a = arr1(vec![1.0, 2.0, 3.0]);
        let v = arr1(vec![0.0, 1.0, 0.5, 2.0]);
        for mode in ALL_MODES {
            let av = convolve(&a, &v, mode).unwrap();
            let va = convolve(&v, &a, mode).unwrap();
            assert_close(av.as_slice().unwrap(), va.as_slice().unwrap());
        }
    }

    #[test]
    fn test_convolve_single_elements() {
        for mode in ALL_MODES {
            let r = convolve(&arr1(vec![2.0]), &arr1(vec![3.0]), mode).unwrap();
            assert_close(r.as_slice().unwrap(), &[6.0]);
        }
    }

    #[test]
    fn test_convolve_empty_input_errors() {
        let empty = arr1(Vec::<f64>::new());
        let v = arr1(vec![1.0]);
        assert!(convolve(&empty, &v, ConvolveMode::Full).is_err());
        assert!(convolve(&v, &empty, ConvolveMode::Full).is_err());
    }

    #[cfg(feature = "fft-convolve")]
    #[test]
    fn test_fftconvolve_matches_direct() {
        // full_len = 7 is neither a power of two nor either input length.
        let a = arr1(vec![1.0, -2.0, 3.5, 0.25, 4.0]);
        let v = arr1(vec![0.5, 1.0, -1.5]);
        for mode in ALL_MODES {
            for (x, y) in [(&a, &v), (&v, &a)] {
                let direct = convolve(x, y, mode).unwrap();
                let via_fft = fftconvolve(x, y, mode).unwrap();
                let d = direct.as_slice().unwrap();
                let f = via_fft.as_slice().unwrap();
                assert_eq!(d.len(), f.len());
                for (p, q) in d.iter().zip(f) {
                    assert!((p - q).abs() < 1e-9, "{mode:?}: {p} vs {q}");
                }
            }
        }
    }
}