mod cpu;
pub mod filter;
pub mod impl_generic;
pub mod traits;
pub mod wavelet;
#[cfg(feature = "cuda")]
mod cuda;
#[cfg(feature = "wgpu")]
mod wgpu;
use numr::dtype::DType;
use numr::error::{Error, Result};
pub use traits::analysis::{
DecimateFilterImpl, DecimateParams, HilbertResult, PeakParams, PeakResult,
SignalAnalysisAlgorithms,
};
pub use traits::convolution::ConvMode;
pub use traits::edge::EdgeDetectionAlgorithms;
pub use traits::filter_apply::{
FilterApplicationAlgorithms, LfilterResult, PadType, SosfiltResult,
};
pub use traits::frequency_response::{FrequencyResponseAlgorithms, FreqzResult, FreqzSpec};
pub use traits::nd_filters::{BoundaryMode, NdFilterAlgorithms};
pub use traits::spectral::{
CoherenceResult, CsdResult, Detrend, PeriodogramParams, PeriodogramResult, PsdScaling,
SpectralAnalysisAlgorithms, SpectralWindow, WelchParams, WelchResult,
};
pub use traits::{ConvolutionAlgorithms, SpectrogramAlgorithms, StftAlgorithms};
pub use wavelet::{
CwtAlgorithms, CwtResult, DwtAlgorithms, DwtResult, WavedecResult, Wavelet, WaveletFamily,
};
pub use filter::{
FilterConversions, FilterOutput, FilterType, FirDesignAlgorithms, FirWindow,
IirDesignAlgorithms, IirDesignResult, SosFilter, SosPairing, TransferFunction, ZpkFilter,
};
pub fn validate_signal_dtype(dtype: DType, op: &'static str) -> Result<()> {
match dtype {
DType::F32 | DType::F64 => Ok(()),
_ => Err(Error::UnsupportedDType { dtype, op }),
}
}
pub fn validate_kernel_1d(kernel: &[usize], op: &'static str) -> Result<()> {
if kernel.len() != 1 {
return Err(Error::InvalidArgument {
arg: "kernel",
reason: format!("{op} requires 1D kernel, got {}-D", kernel.len()),
});
}
Ok(())
}
pub fn validate_kernel_2d(kernel: &[usize], op: &'static str) -> Result<()> {
if kernel.len() != 2 {
return Err(Error::InvalidArgument {
arg: "kernel",
reason: format!("{op} requires 2D kernel, got {}-D", kernel.len()),
});
}
Ok(())
}
pub fn validate_stft_params(n_fft: usize, hop_length: usize, op: &'static str) -> Result<()> {
if n_fft == 0 || !n_fft.is_power_of_two() {
return Err(Error::InvalidArgument {
arg: "n_fft",
reason: format!("{op} requires n_fft to be a positive power of 2, got {n_fft}"),
});
}
if hop_length == 0 {
return Err(Error::InvalidArgument {
arg: "hop_length",
reason: format!("{op} requires hop_length > 0, got {hop_length}"),
});
}
Ok(())
}
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// This is a thin wrapper over [`usize::next_power_of_two`] and inherits its
/// behavior, including for `n == 0` (which yields 1) and for results that do
/// not fit in a `usize`.
#[inline]
pub fn next_power_of_two(n: usize) -> usize {
    usize::next_power_of_two(n)
}
/// Computes the number of STFT frames produced for a signal.
///
/// # Arguments
///
/// * `signal_len` - Number of samples in the input signal.
/// * `n_fft` - Frame (FFT window) length in samples.
/// * `hop_length` - Number of samples between the starts of consecutive frames.
/// * `center` - When `true`, the signal length is extended by `n_fft` samples
///   before frames are counted.
///
/// # Returns
///
/// `(padded_len - n_fft) / hop_length + 1`, where `padded_len` is the signal
/// length after the optional centering extension. Returns 0 when the
/// (extended) signal is shorter than one frame, or when `hop_length` is 0.
pub fn stft_num_frames(signal_len: usize, n_fft: usize, hop_length: usize, center: bool) -> usize {
    // A zero hop cannot advance through the signal; report "no frames" rather
    // than panicking on the division below.
    if hop_length == 0 {
        return 0;
    }

    let padded_len = if center {
        signal_len + n_fft
    } else {
        signal_len
    };

    // `checked_sub` yields `None` when the signal cannot hold a single frame.
    padded_len
        .checked_sub(n_fft)
        .map_or(0, |span| span / hop_length + 1)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Centered framing of a 1000-sample signal with 256-sample frames and a
    /// hop of 64 samples.
    #[test]
    fn test_stft_num_frames() {
        let (signal_len, n_fft, hop_length) = (1000, 256, 64);

        // With `center = true` the signal length is extended by `n_fft`
        // before frames are counted.
        let padded_len = signal_len + n_fft;
        let expected = (padded_len - n_fft) / hop_length + 1;

        assert_eq!(
            stft_num_frames(signal_len, n_fft, hop_length, true),
            expected
        );
    }
}