use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyAny, PyDict};
use scirs2_numpy::{IntoPyArray, PyArray1, PyArray2};
use scirs2_core::ndarray::{Array1, Array2};
/// Short-time Fourier transform of a real-valued 1-D signal.
///
/// Thin wrapper over `scirs2_signal::stft`. `fs` and `window` are always passed as
/// `Some(..)`; `nperseg`, `noverlap` and `nfft` are forwarded unchanged; the three
/// trailing optional arguments of that call are passed as `None`.
///
/// Returns a Python `dict` with keys:
/// - `"freqs"`: 1-D array returned by the transform as its frequency axis,
/// - `"times"`: 1-D array returned by the transform as its time axis,
/// - `"stft_real"` / `"stft_imag"`: 2-D arrays of shape `(len(freqs), len(times))`
///   holding the real and imaginary parts of each complex coefficient.
///
/// # Errors
/// - `ValueError` if `x` is empty.
/// - `RuntimeError` if the underlying STFT fails, or if its coefficients cannot be
///   reshaped to `(len(freqs), len(times))`.
#[pyfunction]
#[pyo3(signature = (x, fs=1.0, window="hann", nperseg=None, noverlap=None, nfft=None))]
pub fn stft_py(
    py: Python,
    x: Vec<f64>,
    fs: f64,
    window: &str,
    nperseg: Option<usize>,
    noverlap: Option<usize>,
    nfft: Option<usize>,
) -> PyResult<Py<PyAny>> {
    if x.is_empty() {
        return Err(PyValueError::new_err("x must not be empty"));
    }

    let (freqs, times, spectrum) = scirs2_signal::stft(
        &x,
        Some(fs),
        Some(window),
        nperseg,
        noverlap,
        nfft,
        None,
        None,
        None,
    )
    .map_err(|e| PyRuntimeError::new_err(format!("STFT failed: {}", e)))?;

    let grid_shape = (freqs.len(), times.len());
    let cell_count = grid_shape.0 * grid_shape.1;

    // Walk the nested coefficient rows in order, splitting each complex value into
    // two parallel row-major buffers.
    // NOTE(review): this assumes each outer row of `spectrum` corresponds to one
    // frequency bin spanning all time frames — confirm against scirs2_signal::stft.
    let mut re_buf: Vec<f64> = Vec::with_capacity(cell_count);
    let mut im_buf: Vec<f64> = Vec::with_capacity(cell_count);
    for coeff in spectrum.iter().flatten() {
        re_buf.push(coeff.re);
        im_buf.push(coeff.im);
    }

    let real_grid = Array2::from_shape_vec(grid_shape, re_buf)
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape STFT real: {}", e)))?;
    let imag_grid = Array2::from_shape_vec(grid_shape, im_buf)
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape STFT imag: {}", e)))?;

    let out = PyDict::new(py);
    out.set_item("freqs", Array1::from_vec(freqs).into_pyarray(py))?;
    out.set_item("times", Array1::from_vec(times).into_pyarray(py))?;
    out.set_item("stft_real", real_grid.into_pyarray(py))?;
    out.set_item("stft_imag", imag_grid.into_pyarray(py))?;
    Ok(out.into())
}
/// Power spectral density estimate of `x` using Welch's method.
///
/// Thin wrapper over `scirs2_signal::welch`. `fs`, `window` and `scaling` are passed
/// as `Some(..)`; `nperseg`, `noverlap` and `nfft` are forwarded unchanged; the one
/// remaining optional argument of that call is passed as `None`. Accepted `window`
/// and `scaling` strings are whatever `scirs2_signal::welch` accepts.
///
/// Returns a Python `dict` with keys `"freqs"` and `"psd"`, each a 1-D array.
///
/// # Errors
/// - `ValueError` if `x` is empty.
/// - `RuntimeError` if the underlying Welch estimate fails.
#[pyfunction]
#[pyo3(signature = (x, fs=1.0, window="hann", nperseg=None, noverlap=None, nfft=None, scaling="density"))]
pub fn welch_py(
    py: Python,
    x: Vec<f64>,
    fs: f64,
    window: &str,
    nperseg: Option<usize>,
    noverlap: Option<usize>,
    nfft: Option<usize>,
    scaling: &str,
) -> PyResult<Py<PyAny>> {
    if x.is_empty() {
        return Err(PyValueError::new_err("x must not be empty"));
    }

    let (freq_axis, density) = scirs2_signal::welch(
        &x,
        Some(fs),
        Some(window),
        nperseg,
        noverlap,
        nfft,
        None,
        Some(scaling),
    )
    .map_err(|e| PyRuntimeError::new_err(format!("Welch PSD failed: {}", e)))?;

    let freq_array = Array1::from_vec(freq_axis);
    let psd_array = Array1::from_vec(density);

    let out = PyDict::new(py);
    out.set_item("freqs", freq_array.into_pyarray(py))?;
    out.set_item("psd", psd_array.into_pyarray(py))?;
    Ok(out.into())
}
/// Periodogram power spectral density estimate of `x`.
///
/// Thin wrapper over `scirs2_signal::periodogram`. `fs` and `scaling` are passed as
/// `Some(..)`; `window` (default `None`) and `nfft` are forwarded unchanged; the one
/// remaining optional argument of that call is passed as `None`.
///
/// Returns a Python `dict` with keys `"freqs"` and `"psd"`, each a 1-D array.
///
/// # Errors
/// - `ValueError` if `x` is empty.
/// - `RuntimeError` if the underlying periodogram computation fails.
#[pyfunction]
#[pyo3(signature = (x, fs=1.0, window=None, nfft=None, scaling="density"))]
pub fn periodogram_py(
    py: Python,
    x: Vec<f64>,
    fs: f64,
    window: Option<&str>,
    nfft: Option<usize>,
    scaling: &str,
) -> PyResult<Py<PyAny>> {
    if x.is_empty() {
        return Err(PyValueError::new_err("x must not be empty"));
    }

    let (freq_axis, density) =
        scirs2_signal::periodogram(&x, Some(fs), window, nfft, None, Some(scaling))
            .map_err(|e| PyRuntimeError::new_err(format!("Periodogram failed: {}", e)))?;

    let freq_array = Array1::from_vec(freq_axis);
    let psd_array = Array1::from_vec(density);

    let out = PyDict::new(py);
    out.set_item("freqs", freq_array.into_pyarray(py))?;
    out.set_item("psd", psd_array.into_pyarray(py))?;
    Ok(out.into())
}
/// Spectrogram of a real-valued 1-D signal.
///
/// Thin wrapper over `scirs2_signal::spectrogram`. `fs` and `window` are passed as
/// `Some(..)`; `nperseg`, `noverlap` and `nfft` are forwarded unchanged; the three
/// trailing optional arguments of that call are passed as `None`.
///
/// Returns a Python `dict` with keys:
/// - `"freqs"` and `"times"`: 1-D arrays as returned by the spectrogram,
/// - `"Sxx"`: 2-D array of shape `(len(freqs), len(times))` built by concatenating
///   the rows returned by the library in order.
///
/// # Errors
/// - `ValueError` if `x` is empty.
/// - `RuntimeError` if the spectrogram fails or its values cannot be reshaped to
///   `(len(freqs), len(times))`.
#[pyfunction]
#[pyo3(signature = (x, fs=1.0, window="hann", nperseg=None, noverlap=None, nfft=None))]
pub fn spectrogram_py(
    py: Python,
    x: Vec<f64>,
    fs: f64,
    window: &str,
    nperseg: Option<usize>,
    noverlap: Option<usize>,
    nfft: Option<usize>,
) -> PyResult<Py<PyAny>> {
    if x.is_empty() {
        return Err(PyValueError::new_err("x must not be empty"));
    }

    let (freq_axis, time_axis, rows) = scirs2_signal::spectrogram(
        &x,
        Some(fs),
        Some(window),
        nperseg,
        noverlap,
        nfft,
        None,
        None,
        None,
    )
    .map_err(|e| PyRuntimeError::new_err(format!("Spectrogram failed: {}", e)))?;

    let grid_shape = (freq_axis.len(), time_axis.len());
    // Row-major concatenation of the nested rows; from_shape_vec validates the total size.
    let power_grid = Array2::from_shape_vec(
        grid_shape,
        rows.into_iter().flatten().collect::<Vec<f64>>(),
    )
    .map_err(|e| PyRuntimeError::new_err(format!("Failed to reshape Sxx: {}", e)))?;

    let out = PyDict::new(py);
    out.set_item("freqs", Array1::from_vec(freq_axis).into_pyarray(py))?;
    out.set_item("times", Array1::from_vec(time_axis).into_pyarray(py))?;
    out.set_item("Sxx", power_grid.into_pyarray(py))?;
    Ok(out.into())
}
/// Generate a Ricker wavelet of `points` samples with width parameter `a`.
///
/// Thin wrapper over `scirs2_signal::ricker`; the samples are returned as a 1-D
/// NumPy array.
///
/// # Errors
/// - `ValueError` if `points == 0`.
/// - `ValueError` if `a` is not a finite, strictly positive number (NaN and ±inf are
///   rejected).
/// - `RuntimeError` if `scirs2_signal::ricker` fails.
#[pyfunction]
pub fn ricker_py(py: Python, points: usize, a: f64) -> PyResult<Py<PyArray1<f64>>> {
    if points == 0 {
        return Err(PyValueError::new_err("points must be > 0"));
    }
    // A plain `a <= 0.0` test is false for NaN, which previously let NaN (and +inf)
    // through to the generator. Require a finite positive width explicitly.
    if !a.is_finite() || a <= 0.0 {
        return Err(PyValueError::new_err("a must be a finite positive number"));
    }
    let wav = scirs2_signal::ricker(points, a)
        .map_err(|e| PyRuntimeError::new_err(format!("Ricker wavelet failed: {}", e)))?;
    Ok(Array1::from_vec(wav).into_pyarray(py).unbind())
}
/// Frequency-swept cosine (chirp) evaluated at the sample times `t`.
///
/// Thin wrapper over `scirs2_signal::chirp`; all arguments are forwarded unchanged.
/// `method` (default `"linear"`) is passed through verbatim, so the accepted sweep
/// names are whatever `scirs2_signal::chirp` accepts.
///
/// # Errors
/// - `ValueError` if `t` is empty.
/// - `RuntimeError` if chirp generation fails.
#[pyfunction]
#[pyo3(signature = (t, f0, t1, f1, method="linear", phi=0.0))]
pub fn chirp_py(
    py: Python,
    t: Vec<f64>,
    f0: f64,
    t1: f64,
    f1: f64,
    method: &str,
    phi: f64,
) -> PyResult<Py<PyArray1<f64>>> {
    if t.is_empty() {
        return Err(PyValueError::new_err("t must not be empty"));
    }
    scirs2_signal::chirp(&t, f0, t1, f1, method, phi)
        .map(|samples| Array1::from_vec(samples).into_pyarray(py).unbind())
        .map_err(|e| PyRuntimeError::new_err(format!("Chirp generation failed: {}", e)))
}
/// Square wave evaluated at the sample times `t`.
///
/// Thin wrapper over `scirs2_signal::square`. `duty` (default `0.5`) must lie in the
/// closed interval `[0, 1]`; NaN is rejected because it is not contained in that range.
///
/// # Errors
/// - `ValueError` if `t` is empty or `duty` is outside `[0, 1]`.
/// - `RuntimeError` if square-wave generation fails.
#[pyfunction]
#[pyo3(signature = (t, duty=0.5))]
pub fn square_py(py: Python, t: Vec<f64>, duty: f64) -> PyResult<Py<PyArray1<f64>>> {
    let duty_ok = (0.0..=1.0).contains(&duty);

    // The empty-input check takes precedence over the duty check.
    if t.is_empty() {
        return Err(PyValueError::new_err("t must not be empty"));
    }
    if !duty_ok {
        return Err(PyValueError::new_err("duty must be between 0 and 1"));
    }

    scirs2_signal::square(&t, duty)
        .map(|samples| Array1::from_vec(samples).into_pyarray(py).unbind())
        .map_err(|e| PyRuntimeError::new_err(format!("Square wave generation failed: {}", e)))
}
/// Sawtooth (or triangle) wave evaluated at the sample times `t`.
///
/// Thin wrapper over `scirs2_signal::sawtooth`. `width` (default `1.0`) is the
/// fraction of each period spent on the rising ramp, so it must lie in `[0, 1]`.
/// It is validated the same way `square_py` validates `duty`; NaN is rejected.
///
/// # Errors
/// - `ValueError` if `t` is empty or `width` is outside `[0, 1]`.
/// - `RuntimeError` if sawtooth generation fails.
#[pyfunction]
#[pyo3(signature = (t, width=1.0))]
pub fn sawtooth_py(py: Python, t: Vec<f64>, width: f64) -> PyResult<Py<PyArray1<f64>>> {
    if t.is_empty() {
        return Err(PyValueError::new_err("t must not be empty"));
    }
    // Previously `width` was passed through unchecked, unlike `duty` in square_py.
    // An out-of-range ramp fraction has no meaningful waveform, so reject it up front.
    if !(0.0..=1.0).contains(&width) {
        return Err(PyValueError::new_err("width must be between 0 and 1"));
    }
    let result = scirs2_signal::sawtooth(&t, width)
        .map_err(|e| PyRuntimeError::new_err(format!("Sawtooth wave generation failed: {}", e)))?;
    Ok(Array1::from_vec(result).into_pyarray(py).unbind())
}
pub fn register_signal_ext_module(m: &Bound<'_, pyo3::PyModule>) -> pyo3::PyResult<()> {
m.add_function(wrap_pyfunction!(stft_py, m)?)?;
m.add_function(wrap_pyfunction!(welch_py, m)?)?;
m.add_function(wrap_pyfunction!(periodogram_py, m)?)?;
m.add_function(wrap_pyfunction!(spectrogram_py, m)?)?;
m.add_function(wrap_pyfunction!(ricker_py, m)?)?;
m.add_function(wrap_pyfunction!(chirp_py, m)?)?;
m.add_function(wrap_pyfunction!(square_py, m)?)?;
m.add_function(wrap_pyfunction!(sawtooth_py, m)?)?;
Ok(())
}