1use pyo3::prelude::*;
6use pyo3::types::PyDict;
7use scirs2_core::python::numpy_compat::{scirs_to_numpy_array1, Array1};
8use scirs2_numpy::{PyArray1, PyReadonlyArray1};
9
10use scirs2_signal::hilbert::hilbert;
12
13use scirs2_signal::filter::fir::firwin;
15use scirs2_signal::filter::iir::{butter, cheby1};
16use scirs2_signal::filter::FilterType;
17
18#[pyfunction]
29#[pyo3(signature = (a, v, mode="full"))]
30fn convolve_py(
31 py: Python,
32 a: PyReadonlyArray1<f64>,
33 v: PyReadonlyArray1<f64>,
34 mode: &str,
35) -> PyResult<Py<PyArray1<f64>>> {
36 let a_arr = a.as_array();
37 let v_arr = v.as_array();
38 let a_slice = a_arr.as_slice().ok_or_else(|| {
39 pyo3::exceptions::PyValueError::new_err("Array 'a' is not contiguous in memory")
40 })?;
41 let v_slice = v_arr.as_slice().ok_or_else(|| {
42 pyo3::exceptions::PyValueError::new_err("Array 'v' is not contiguous in memory")
43 })?;
44 let n = a_slice.len();
45 let m = v_slice.len();
46
47 if n == 0 || m == 0 {
48 return Err(pyo3::exceptions::PyValueError::new_err(
49 "Arrays must not be empty",
50 ));
51 }
52
53 let (out_len, offset) = match mode {
55 "full" => (n + m - 1, 0),
56 "same" => (n, (m - 1) / 2),
57 "valid" => {
58 if n < m {
59 return Err(pyo3::exceptions::PyValueError::new_err(
60 "For 'valid' mode, first array must be at least as long as second",
61 ));
62 }
63 (n - m + 1, m - 1)
64 }
65 _ => {
66 return Err(pyo3::exceptions::PyValueError::new_err(
67 "mode must be 'full', 'same', or 'valid'",
68 ))
69 }
70 };
71
72 let mut result = vec![0.0f64; out_len];
74
75 for (i, res) in result.iter_mut().enumerate() {
76 let full_idx = i + offset;
77 let mut sum = 0.0f64;
78 for (j, &vj) in v_slice.iter().enumerate() {
79 let ai = full_idx as isize - j as isize;
80 if ai >= 0 && (ai as usize) < n {
81 sum += a_slice[ai as usize] * vj;
82 }
83 }
84 *res = sum;
85 }
86
87 scirs_to_numpy_array1(Array1::from_vec(result), py)
88}
89
90#[pyfunction]
97#[pyo3(signature = (a, v, mode="full"))]
98fn correlate_py(
99 py: Python,
100 a: PyReadonlyArray1<f64>,
101 v: PyReadonlyArray1<f64>,
102 mode: &str,
103) -> PyResult<Py<PyArray1<f64>>> {
104 let a_arr = a.as_array();
105 let v_arr = v.as_array();
106 let a_slice = a_arr.as_slice().ok_or_else(|| {
107 pyo3::exceptions::PyValueError::new_err("Array 'a' is not contiguous in memory")
108 })?;
109 let v_slice = v_arr.as_slice().ok_or_else(|| {
110 pyo3::exceptions::PyValueError::new_err("Array 'v' is not contiguous in memory")
111 })?;
112 let n = a_slice.len();
113 let m = v_slice.len();
114
115 if n == 0 || m == 0 {
116 return Err(pyo3::exceptions::PyValueError::new_err(
117 "Arrays must not be empty",
118 ));
119 }
120
121 let (out_len, offset) = match mode {
124 "full" => (n + m - 1, 0),
125 "same" => (n, (m - 1) / 2),
126 "valid" => {
127 if n < m {
128 return Err(pyo3::exceptions::PyValueError::new_err(
129 "For 'valid' mode, first array must be at least as long as second",
130 ));
131 }
132 (n - m + 1, m - 1)
133 }
134 _ => {
135 return Err(pyo3::exceptions::PyValueError::new_err(
136 "mode must be 'full', 'same', or 'valid'",
137 ))
138 }
139 };
140
141 let mut result = vec![0.0f64; out_len];
143
144 for (i, res) in result.iter_mut().enumerate() {
145 let full_idx = i + offset;
146 let mut sum = 0.0f64;
147 for (j, &vj) in v_slice.iter().rev().enumerate() {
148 let ai = full_idx as isize - j as isize;
149 if ai >= 0 && (ai as usize) < n {
150 sum += a_slice[ai as usize] * vj;
151 }
152 }
153 *res = sum;
154 }
155
156 scirs_to_numpy_array1(Array1::from_vec(result), py)
157}
158
159#[pyfunction]
167fn hilbert_py(py: Python, x: PyReadonlyArray1<f64>) -> PyResult<Py<PyAny>> {
168 let x_slice = x.as_array().to_vec();
169
170 let result =
171 hilbert(&x_slice).map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
172
173 let real: Vec<f64> = result.iter().map(|c| c.re).collect();
175 let imag: Vec<f64> = result.iter().map(|c| c.im).collect();
176
177 let dict = PyDict::new(py);
178 dict.set_item("real", scirs_to_numpy_array1(Array1::from_vec(real), py)?)?;
179 dict.set_item("imag", scirs_to_numpy_array1(Array1::from_vec(imag), py)?)?;
180
181 Ok(dict.into())
182}
183
184#[pyfunction]
190fn hann_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
191 let mut window = Vec::with_capacity(n);
192 for i in 0..n {
193 let val = 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos());
194 window.push(val);
195 }
196 scirs_to_numpy_array1(Array1::from_vec(window), py)
197}
198
199#[pyfunction]
201fn hamming_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
202 let mut window = Vec::with_capacity(n);
203 for i in 0..n {
204 let val = 0.54 - 0.46 * (2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64).cos();
205 window.push(val);
206 }
207 scirs_to_numpy_array1(Array1::from_vec(window), py)
208}
209
210#[pyfunction]
212fn blackman_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
213 let mut window = Vec::with_capacity(n);
214 for i in 0..n {
215 let t = 2.0 * std::f64::consts::PI * i as f64 / (n - 1) as f64;
216 let val = 0.42 - 0.5 * t.cos() + 0.08 * (2.0 * t).cos();
217 window.push(val);
218 }
219 scirs_to_numpy_array1(Array1::from_vec(window), py)
220}
221
222#[pyfunction]
224fn bartlett_py(py: Python, n: usize) -> PyResult<Py<PyArray1<f64>>> {
225 let mut window = Vec::with_capacity(n);
226 let half = (n - 1) as f64 / 2.0;
227 for i in 0..n {
228 let val = 1.0 - ((i as f64 - half) / half).abs();
229 window.push(val);
230 }
231 scirs_to_numpy_array1(Array1::from_vec(window), py)
232}
233
234#[pyfunction]
236fn kaiser_py(py: Python, n: usize, beta: f64) -> PyResult<Py<PyArray1<f64>>> {
237 let mut window = Vec::with_capacity(n);
238
239 fn bessel_i0(x: f64) -> f64 {
241 let mut sum = 1.0;
242 let mut term = 1.0;
243 for k in 1..50 {
244 term *= (x / 2.0).powi(2) / (k as f64).powi(2);
245 sum += term;
246 if term < 1e-12 {
247 break;
248 }
249 }
250 sum
251 }
252
253 let denom = bessel_i0(beta);
254 for i in 0..n {
255 let t = 2.0 * i as f64 / (n - 1) as f64 - 1.0;
256 let arg = beta * (1.0 - t * t).sqrt();
257 let val = bessel_i0(arg) / denom;
258 window.push(val);
259 }
260
261 scirs_to_numpy_array1(Array1::from_vec(window), py)
262}
263
264#[pyfunction]
278#[pyo3(signature = (order, cutoff, filter_type="lowpass"))]
279fn butter_py(py: Python, order: usize, cutoff: f64, filter_type: &str) -> PyResult<Py<PyAny>> {
280 let ftype = match filter_type.to_lowercase().as_str() {
281 "lowpass" | "low" => FilterType::Lowpass,
282 "highpass" | "high" => FilterType::Highpass,
283 "bandpass" | "band" => FilterType::Bandpass,
284 "bandstop" | "stop" => FilterType::Bandstop,
285 _ => {
286 return Err(pyo3::exceptions::PyValueError::new_err(
287 "Invalid filter type. Use 'lowpass', 'highpass', 'bandpass', or 'bandstop'",
288 ));
289 }
290 };
291
292 let (b, a) = butter(order, cutoff, ftype)
293 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
294
295 let dict = PyDict::new(py);
296 dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
297 dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
298
299 Ok(dict.into())
300}
301
302#[pyfunction]
313#[pyo3(signature = (order, ripple, cutoff, filter_type="lowpass"))]
314fn cheby1_py(
315 py: Python,
316 order: usize,
317 ripple: f64,
318 cutoff: f64,
319 filter_type: &str,
320) -> PyResult<Py<PyAny>> {
321 let ftype = match filter_type.to_lowercase().as_str() {
322 "lowpass" | "low" => FilterType::Lowpass,
323 "highpass" | "high" => FilterType::Highpass,
324 _ => {
325 return Err(pyo3::exceptions::PyValueError::new_err(
326 "Invalid filter type for cheby1. Use 'lowpass' or 'highpass'",
327 ));
328 }
329 };
330
331 let (b, a) = cheby1(order, ripple, cutoff, ftype)
332 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
333
334 let dict = PyDict::new(py);
335 dict.set_item("b", scirs_to_numpy_array1(Array1::from_vec(b), py)?)?;
336 dict.set_item("a", scirs_to_numpy_array1(Array1::from_vec(a), py)?)?;
337
338 Ok(dict.into())
339}
340
341#[pyfunction]
352#[pyo3(signature = (numtaps, cutoff, window="hamming", pass_zero=true))]
353fn firwin_py(
354 py: Python,
355 numtaps: usize,
356 cutoff: f64,
357 window: &str,
358 pass_zero: bool,
359) -> PyResult<Py<PyArray1<f64>>> {
360 let coeffs = firwin(numtaps, cutoff, window, pass_zero)
361 .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(format!("{e}")))?;
362
363 scirs_to_numpy_array1(Array1::from_vec(coeffs), py)
364}
365
366#[pyfunction]
374#[pyo3(signature = (x, height=None, distance=None))]
375fn find_peaks_py(
376 py: Python,
377 x: PyReadonlyArray1<f64>,
378 height: Option<f64>,
379 distance: Option<usize>,
380) -> PyResult<Py<PyArray1<i64>>> {
381 let x_arr = x.as_array();
382 let n = x_arr.len();
383
384 if n < 3 {
385 return scirs_to_numpy_array1(Array1::from_vec(vec![]), py);
386 }
387
388 let mut peaks: Vec<i64> = Vec::new();
389
390 for i in 1..n - 1 {
392 if x_arr[i] > x_arr[i - 1] && x_arr[i] > x_arr[i + 1] {
393 if let Some(h) = height {
395 if x_arr[i] < h {
396 continue;
397 }
398 }
399 peaks.push(i as i64);
400 }
401 }
402
403 if let Some(dist) = distance {
405 let mut filtered = Vec::new();
406 for &peak in &peaks {
407 let keep = filtered
408 .iter()
409 .all(|&p: &i64| (peak - p).unsigned_abs() >= dist as u64);
410 if keep {
411 filtered.push(peak);
412 }
413 }
414 peaks = filtered;
415 }
416
417 scirs_to_numpy_array1(Array1::from_vec(peaks), py)
418}
419
420pub fn register_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
422 m.add_function(wrap_pyfunction!(convolve_py, m)?)?;
424 m.add_function(wrap_pyfunction!(correlate_py, m)?)?;
425
426 m.add_function(wrap_pyfunction!(hilbert_py, m)?)?;
428
429 m.add_function(wrap_pyfunction!(hann_py, m)?)?;
431 m.add_function(wrap_pyfunction!(hamming_py, m)?)?;
432 m.add_function(wrap_pyfunction!(blackman_py, m)?)?;
433 m.add_function(wrap_pyfunction!(bartlett_py, m)?)?;
434 m.add_function(wrap_pyfunction!(kaiser_py, m)?)?;
435
436 m.add_function(wrap_pyfunction!(butter_py, m)?)?;
438 m.add_function(wrap_pyfunction!(cheby1_py, m)?)?;
439 m.add_function(wrap_pyfunction!(firwin_py, m)?)?;
440
441 m.add_function(wrap_pyfunction!(find_peaks_py, m)?)?;
443
444 Ok(())
445}