use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use std::fmt::Debug;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterType {
LowPass,
HighPass,
BandPass,
BandStop,
Custom,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterWindow {
Rectangular,
Hamming,
Hanning,
Blackman,
Kaiser,
}
#[derive(Debug, Clone)]
pub struct FilterSpec {
pub filter_type: FilterType,
pub order: usize,
pub cutoff: f64,
pub cutoff_high: Option<f64>,
pub window: FilterWindow,
pub kaiser_beta: Option<f64>,
pub custom_coeffs: Option<Vec<f64>>,
}
impl Default for FilterSpec {
fn default() -> Self {
Self {
filter_type: FilterType::LowPass,
order: 64,
cutoff: 0.25,
cutoff_high: None,
window: FilterWindow::Hamming,
kaiser_beta: None,
custom_coeffs: None,
}
}
}
#[allow(dead_code)]
pub fn frequency_filter<T>(signal: &[T], filterspec: &FilterSpec) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
{
let max_size = 1024;
let limit = signal.len().min(max_size);
let input: Vec<f64> = signal
.iter()
.take(limit)
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
})
.collect::<FFTResult<Vec<_>>>()?;
let spectrum = fft(&input, None)?;
let freq_response = design_frequency_response(filterspec, spectrum.len())?;
let filtered_spectrum: Vec<Complex64> = spectrum
.iter()
.zip(&freq_response)
.map(|(&s, &r)| s * r)
.collect();
let result = ifft(&filtered_spectrum, None)?;
let filtered: Vec<f64> = result.iter().map(|c| c.re).collect();
Ok(filtered)
}
#[allow(dead_code)]
fn design_frequency_response(filter_spec: &FilterSpec, size: usize) -> FFTResult<Vec<f64>> {
if let Some(ref coeffs) = filter_spec.custom_coeffs {
if filter_spec.filter_type == FilterType::Custom {
return Ok(coeffs.clone());
}
}
let mut response = vec![0.0; size];
let cutoff_idx = (filter_spec.cutoff * size as f64) as usize;
let cutoff_high_idx = filter_spec
.cutoff_high
.map(|c| (c * size as f64) as usize)
.unwrap_or(cutoff_idx);
match filter_spec.filter_type {
FilterType::LowPass => {
for i in 0..=cutoff_idx.min(size / 2) {
response[i] = 1.0;
if i > 0 && i < size / 2 {
response[size - i] = 1.0;
}
}
}
FilterType::HighPass => {
for i in cutoff_idx..=size / 2 {
response[i] = 1.0;
if i > 0 && i < size / 2 {
response[size - i] = 1.0;
}
}
}
FilterType::BandPass => {
for i in cutoff_idx..=cutoff_high_idx.min(size / 2) {
response[i] = 1.0;
if i > 0 && i < size / 2 {
response[size - i] = 1.0;
}
}
}
FilterType::BandStop => {
for i in 0..=size / 2 {
if i <= cutoff_idx || i >= cutoff_high_idx {
response[i] = 1.0;
if i > 0 && i < size / 2 {
response[size - i] = 1.0;
}
}
}
}
FilterType::Custom => {
return Err(FFTError::ValueError(
"Custom filter type requires custom_coeffs to be provided".to_string(),
));
}
}
apply_window_to_response(&mut response, filter_spec);
Ok(response)
}
#[allow(dead_code)]
fn apply_window_to_response(response: &mut [f64], filterspec: &FilterSpec) {
let size = response.len();
match filterspec.window {
FilterWindow::Rectangular => {
}
FilterWindow::Hamming => {
for i in 0..size {
if response[i] > 0.0 {
let window_val =
0.54 - 0.46 * (2.0 * std::f64::consts::PI * i as f64 / size as f64).cos();
response[i] *= window_val;
}
}
}
FilterWindow::Hanning => {
for i in 0..size {
if response[i] > 0.0 {
let window_val =
0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / size as f64).cos());
response[i] *= window_val;
}
}
}
FilterWindow::Blackman => {
for i in 0..size {
if response[i] > 0.0 {
let x = 2.0 * std::f64::consts::PI * i as f64 / size as f64;
let window_val = 0.42 - 0.5 * x.cos() + 0.08 * (2.0 * x).cos();
response[i] *= window_val;
}
}
}
FilterWindow::Kaiser => {
let beta = filterspec.kaiser_beta.unwrap_or(3.0);
for i in 0..size {
if response[i] > 0.0 {
let x = 2.0 * i as f64 / size as f64 - 1.0;
let window_val = bessel_i0(beta * (1.0 - x * x).sqrt()) / bessel_i0(beta);
response[i] *= window_val;
}
}
}
}
}
#[allow(dead_code)]
fn bessel_i0(x: f64) -> f64 {
let ax = x.abs();
if ax < 3.75 {
let y = (x / 3.75).powi(2);
1.0 + y
* (3.5156229
+ y * (3.0899424
+ y * (1.2067492 + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))))
} else {
let y = 3.75 / ax;
(ax.exp() / ax.sqrt())
* (0.39894228
+ y * (0.01328592
+ y * (0.00225319
+ y * (-0.00157565
+ y * (0.00916281
+ y * (-0.02057706
+ y * (0.02635537 + y * (-0.01647633 + y * 0.00392377))))))))
}
}
#[allow(dead_code)]
pub fn convolve<T, U>(signal: &[T], kernel: &[U]) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
U: NumCast + Copy + Debug,
{
let max_size = 512;
let signal_len = signal.len().min(max_size);
let kernel_len = kernel.len().min(max_size);
let result_len = signal_len + kernel_len - 1;
let signal_f64: Vec<f64> = signal
.iter()
.take(signal_len)
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
})
.collect::<FFTResult<Vec<_>>>()?;
let kernel_f64: Vec<f64> = kernel
.iter()
.take(kernel_len)
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
})
.collect::<FFTResult<Vec<_>>>()?;
let mut signal_padded = signal_f64;
signal_padded.resize(result_len, 0.0);
let mut kernel_padded = kernel_f64;
kernel_padded.resize(result_len, 0.0);
let signal_fft = fft(&signal_padded, None)?;
let kernel_fft = fft(&kernel_padded, None)?;
let mut result_fft = Vec::with_capacity(result_len);
for i in 0..result_len {
result_fft.push(signal_fft[i] * kernel_fft[i]);
}
let result_complex = ifft(&result_fft, None)?;
let result: Vec<f64> = result_complex.iter().map(|c| c.re).collect();
Ok(result)
}
#[allow(dead_code)]
pub fn cross_correlate<T, U>(signal1: &[T], signal2: &[U]) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
U: NumCast + Copy + Debug,
{
let signal1_len = signal1.len();
let signal2_len = signal2.len();
let result_len = signal1_len + signal2_len - 1;
let signal1_f64: Vec<f64> = signal1
.iter()
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
})
.collect::<FFTResult<Vec<_>>>()?;
let signal2_f64: Vec<f64> = signal2
.iter()
.map(|&val| {
NumCast::from(val)
.ok_or_else(|| FFTError::ValueError(format!("Could not convert {:?} to f64", val)))
})
.collect::<FFTResult<Vec<_>>>()?;
let mut signal1_padded = signal1_f64.clone();
signal1_padded.resize(result_len, 0.0);
let mut signal2_padded = signal2_f64.clone();
signal2_padded.resize(result_len, 0.0);
let signal1_fft = fft(&signal1_padded, None)?;
let signal2_fft = fft(&signal2_padded, None)?;
let mut result_fft = Vec::with_capacity(result_len);
for i in 0..result_len {
result_fft.push(signal1_fft[i] * signal2_fft[i].conj());
}
let result_complex = ifft(&result_fft, None)?;
let result: Vec<f64> = result_complex.iter().map(|c| c.re).collect();
Ok(result)
}
#[allow(dead_code)]
pub fn design_fir_filter(filter_spec: &FilterSpec) -> FFTResult<Vec<f64>> {
let order = filter_spec.order;
let adjusted_order = if order % 2 == 0 { order + 1 } else { order };
let n_freqs = 2048; let freq_response = design_frequency_response(filter_spec, n_freqs)?;
let mut complex_response = vec![Complex64::new(0.0, 0.0); n_freqs];
for i in 0..n_freqs {
complex_response[i] = Complex64::new(freq_response[i], 0.0);
}
let impulse_response = ifft(&complex_response, None)?;
let half_order = adjusted_order / 2;
let mut coeffs = vec![0.0; adjusted_order];
for i in 0..adjusted_order {
let idx = (i + n_freqs - half_order) % n_freqs;
coeffs[i] = impulse_response[idx].re;
}
let mut window = vec![0.0; adjusted_order];
match filter_spec.window {
FilterWindow::Rectangular => {
window.iter_mut().for_each(|w| *w = 1.0);
}
FilterWindow::Hamming => {
for i in 0..adjusted_order {
window[i] = 0.54
- 0.46
* (2.0 * std::f64::consts::PI * i as f64 / (adjusted_order - 1) as f64)
.cos();
}
}
FilterWindow::Hanning => {
for i in 0..adjusted_order {
window[i] = 0.5
* (1.0
- (2.0 * std::f64::consts::PI * i as f64 / (adjusted_order - 1) as f64)
.cos());
}
}
FilterWindow::Blackman => {
for i in 0..adjusted_order {
let x = 2.0 * std::f64::consts::PI * i as f64 / (adjusted_order - 1) as f64;
window[i] = 0.42 - 0.5 * x.cos() + 0.08 * (2.0 * x).cos();
}
}
FilterWindow::Kaiser => {
let beta = filter_spec.kaiser_beta.unwrap_or(3.0);
for i in 0..adjusted_order {
let x = 2.0 * i as f64 / (adjusted_order - 1) as f64 - 1.0;
window[i] = bessel_i0(beta * (1.0 - x * x).sqrt()) / bessel_i0(beta);
}
}
}
for i in 0..adjusted_order {
coeffs[i] *= window[i];
}
let dc_gain: f64 = coeffs.iter().sum();
if dc_gain != 0.0 {
for coeff in &mut coeffs {
*coeff /= dc_gain;
}
}
Ok(coeffs)
}
#[allow(dead_code)]
pub fn fir_filter<T>(signal: &[T], filtercoeffs: &[f64]) -> FFTResult<Vec<f64>>
where
T: NumCast + Copy + Debug,
{
convolve(signal, filtercoeffs)
}
#[cfg(test)]
#[cfg(feature = "never")] mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_frequency_filter_lowpass() {
let n = 128;
let mut signal = vec![0.0; n];
for i in 0..n {
signal[i] += (2.0 * std::f64::consts::PI * 2.0 * i as f64 / n as f64).sin();
}
for i in 0..n {
signal[i] += 0.5 * (2.0 * std::f64::consts::PI * 10.0 * i as f64 / n as f64).sin();
}
let filter_spec = FilterSpec {
filter_type: FilterType::LowPass,
order: 32,
cutoff: 0.25, window: FilterWindow::Hamming,
..Default::default()
};
let filtered = frequency_filter(&signal, &filter_spec).expect("Operation failed");
assert_eq!(filtered.len(), signal.len());
}
#[test]
fn test_convolution() {
let signal = vec![1.0, 2.0, 3.0];
let kernel = vec![0.5, 0.5];
let result = convolve(&signal, &kernel).expect("Operation failed");
assert_eq!(result.len(), signal.len() + kernel.len() - 1);
assert_relative_eq!(result[0], 0.5, epsilon = 1e-10);
assert_relative_eq!(result[1], 1.5, epsilon = 1e-10);
assert_relative_eq!(result[2], 2.5, epsilon = 1e-10);
assert_relative_eq!(result[3], 1.5, epsilon = 1e-10);
}
}