use scirs2_core::ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
use scirs2_core::numeric::Complex;
use scirs2_core::numeric::{Float, FloatConst};
use std::f64::consts::PI;
use crate::error::{LinalgError, LinalgResult};
pub type Complex64 = Complex<f64>;
pub type Complex32 = Complex<f32>;
#[derive(Debug, Clone, Copy)]
pub enum FFTAlgorithm {
CooleyTukey,
MixedRadix,
PrimeFactor,
Auto,
}
#[derive(Debug, Clone, Copy)]
pub enum WindowFunction {
Rectangular,
Hann,
Hamming,
Blackman,
Kaiser(f64),
Tukey(f64),
Gaussian(f64),
}
#[derive(Debug)]
pub struct FFTPlan<F> {
pub size: usize,
pub algorithm: FFTAlgorithm,
pub twiddle_factors: Vec<Complex<F>>,
pub bit_reversal: Vec<usize>,
pub real_input: bool,
}
impl<F> FFTPlan<F>
where
F: Float + FloatConst,
{
pub fn new(size: usize, algorithm: FFTAlgorithm, realinput: bool) -> LinalgResult<Self> {
if size == 0 {
return Err(LinalgError::ShapeError(
"FFT size must be positive".to_string(),
));
}
let selected_algorithm = match algorithm {
FFTAlgorithm::Auto => Self::select_algorithm(size),
_ => algorithm,
};
if let FFTAlgorithm::CooleyTukey = selected_algorithm {
if !size.is_power_of_two() {
return Err(LinalgError::ShapeError(
"Cooley-Tukey algorithm requires power-of-2 size".to_string(),
));
}
}
let twiddle_factors = Self::compute_twiddle_factors(size);
let bit_reversal = Self::compute_bit_reversal(size);
Ok(FFTPlan {
size,
algorithm: selected_algorithm,
twiddle_factors,
bit_reversal,
real_input: realinput,
})
}
fn select_algorithm(size: usize) -> FFTAlgorithm {
if size.is_power_of_two() {
FFTAlgorithm::CooleyTukey
} else {
FFTAlgorithm::MixedRadix
}
}
fn compute_twiddle_factors(size: usize) -> Vec<Complex<F>> {
let mut twiddles = Vec::with_capacity(size);
let two_pi = F::from(2.0).expect("Failed to convert constant to float") * F::PI();
for k in 0..size {
let angle = -two_pi * F::from(k).expect("Failed to convert to float")
/ F::from(size).expect("Failed to convert to float");
twiddles.push(Complex::new(angle.cos(), angle.sin()));
}
twiddles
}
fn compute_bit_reversal(size: usize) -> Vec<usize> {
let mut reversal = vec![0; size];
let logsize = (size as f64).log2() as usize;
for (i, item) in reversal.iter_mut().enumerate().take(size) {
*item = Self::reverse_bits(i, logsize);
}
reversal
}
fn reverse_bits(mut num: usize, bits: usize) -> usize {
let mut result = 0;
for _ in 0..bits {
result = (result << 1) | (num & 1);
num >>= 1;
}
result
}
}
#[allow(dead_code)]
pub fn fft_1d(input: &ArrayView1<Complex64>, inverse: bool) -> LinalgResult<Array1<Complex64>> {
let size = input.len();
if !size.is_power_of_two() {
return Err(LinalgError::ShapeError(
"Input size must be power of 2 for Cooley-Tukey FFT".to_string(),
));
}
let plan = FFTPlan::new(size, FFTAlgorithm::CooleyTukey, false)?;
fft_1d_with_plan(input, &plan, inverse)
}
#[allow(dead_code)]
pub fn fft_1d_with_plan(
input: &ArrayView1<Complex64>,
plan: &FFTPlan<f64>,
inverse: bool,
) -> LinalgResult<Array1<Complex64>> {
if input.len() != plan.size {
return Err(LinalgError::ShapeError(format!(
"Input size {} doesn't match plan size {}",
input.len(),
plan.size
)));
}
match plan.algorithm {
FFTAlgorithm::CooleyTukey => cooley_tukey_fft(input, plan, inverse),
FFTAlgorithm::MixedRadix => mixed_radix_fft(input, plan, inverse),
_ => Err(LinalgError::ComputationError(
"Unsupported FFT algorithm".to_string(),
)),
}
}
#[allow(dead_code)]
fn cooley_tukey_fft(
input: &ArrayView1<Complex64>,
plan: &FFTPlan<f64>,
inverse: bool,
) -> LinalgResult<Array1<Complex64>> {
let size = input.len();
let mut data = input.to_owned();
for i in 0..size {
let j = plan.bit_reversal[i];
if i < j {
data.swap(i, j);
}
}
let mut length = 2;
while length <= size {
let half_length = length / 2;
let step = size / length;
for start in (0..size).step_by(length) {
for i in 0..half_length {
let u = data[start + i];
let twiddle_index = i * step;
let mut twiddle = plan.twiddle_factors[twiddle_index];
if inverse {
twiddle = twiddle.conj();
}
let v = data[start + i + half_length] * twiddle;
data[start + i] = u + v;
data[start + i + half_length] = u - v;
}
}
length *= 2;
}
if inverse {
let scale = Complex64::new(1.0 / size as f64, 0.0);
for elem in data.iter_mut() {
*elem *= scale;
}
}
Ok(data)
}
#[allow(dead_code)]
fn mixed_radix_fft(
input: &ArrayView1<Complex64>,
_plan: &FFTPlan<f64>,
inverse: bool,
) -> LinalgResult<Array1<Complex64>> {
let size = input.len();
bluestein_fft(input, inverse)
}
#[allow(dead_code)]
pub fn bluestein_fft(
input: &ArrayView1<Complex64>,
inverse: bool,
) -> LinalgResult<Array1<Complex64>> {
let n = input.len();
if n <= 1 {
return Ok(input.to_owned());
}
let m = (2 * n - 1).next_power_of_two();
let mut chirp = Array1::zeros(m);
let pi_over_n = if inverse {
PI / n as f64
} else {
-PI / n as f64
};
for k in 0..n {
let arg = pi_over_n * (k * k) as f64;
chirp[k] = Complex64::new(arg.cos(), arg.sin());
if k > 0 && k < n {
chirp[m - k] = chirp[k];
}
}
let mut a = Array1::zeros(m);
for k in 0..n {
a[k] = input[k] * chirp[k];
}
let a_fft = fft_power_of_2(&a.view())?;
let chirp_fft = fft_power_of_2(&chirp.view())?;
let mut product = Array1::zeros(m);
for k in 0..m {
product[k] = a_fft[k] * chirp_fft[k];
}
let ifft_result = ifft_power_of_2(&product.view())?;
let mut result = Array1::zeros(n);
for k in 0..n {
result[k] = ifft_result[k] * chirp[k];
}
if inverse {
let scale = Complex64::new(1.0 / n as f64, 0.0);
for elem in result.iter_mut() {
*elem *= scale;
}
}
Ok(result)
}
#[allow(dead_code)]
fn fft_power_of_2(input: &ArrayView1<Complex64>) -> LinalgResult<Array1<Complex64>> {
if input.len().is_power_of_two() {
fft_1d(input, false)
} else {
Err(LinalgError::ShapeError(
"Input size must be power of 2".to_string(),
))
}
}
#[allow(dead_code)]
fn ifft_power_of_2(input: &ArrayView1<Complex64>) -> LinalgResult<Array1<Complex64>> {
if input.len().is_power_of_two() {
fft_1d(input, true)
} else {
Err(LinalgError::ShapeError(
"Input size must be power of 2".to_string(),
))
}
}
#[allow(dead_code)]
pub fn rfft_1d(input: &ArrayView1<f64>) -> LinalgResult<Array1<Complex64>> {
let n = input.len();
let mut complex_input = Array1::zeros(n);
for i in 0..n {
complex_input[i] = Complex64::new(input[i], 0.0);
}
let full_fft = if n.is_power_of_two() {
fft_1d(&complex_input.view(), false)?
} else {
bluestein_fft(&complex_input.view(), false)?
};
let outputsize = n / 2 + 1;
let mut result = Array1::zeros(outputsize);
for i in 0..outputsize {
result[i] = full_fft[i];
}
Ok(result)
}
#[allow(dead_code)]
pub fn irfft_1d(input: &ArrayView1<Complex64>, outputsize: usize) -> LinalgResult<Array1<f64>> {
if !outputsize.is_multiple_of(2) {
return Err(LinalgError::ShapeError(
"Output size must be even for IRFFT".to_string(),
));
}
let expected_inputsize = outputsize / 2 + 1;
if input.len() != expected_inputsize {
return Err(LinalgError::ShapeError(format!(
"Input size {} doesn't match expected size {} for output size {}",
input.len(),
expected_inputsize,
outputsize
)));
}
let mut full_spectrum = Array1::zeros(outputsize);
for i in 0..input.len() {
full_spectrum[i] = input[i];
}
for i in 1..outputsize / 2 {
full_spectrum[outputsize - i] = input[i].conj();
}
let ifft_result = if outputsize.is_power_of_two() {
fft_1d(&full_spectrum.view(), true)?
} else {
bluestein_fft(&full_spectrum.view(), true)?
};
let mut result = Array1::zeros(outputsize);
for i in 0..outputsize {
result[i] = ifft_result[i].re;
}
Ok(result)
}
#[allow(dead_code)]
pub fn fft_2d(input: &ArrayView2<Complex64>, inverse: bool) -> LinalgResult<Array2<Complex64>> {
let (rows, cols) = input.dim();
let mut result = input.to_owned();
for i in 0..rows {
let row = result.row(i).to_owned();
let row_fft = if cols.is_power_of_two() {
fft_1d(&row.view(), inverse)?
} else {
bluestein_fft(&row.view(), inverse)?
};
for j in 0..cols {
result[[i, j]] = row_fft[j];
}
}
for j in 0..cols {
let col = result.column(j).to_owned();
let col_fft = if rows.is_power_of_two() {
fft_1d(&col.view(), inverse)?
} else {
bluestein_fft(&col.view(), inverse)?
};
for i in 0..rows {
result[[i, j]] = col_fft[i];
}
}
Ok(result)
}
#[allow(dead_code)]
pub fn fft_3d(input: &ArrayView3<Complex64>, inverse: bool) -> LinalgResult<Array3<Complex64>> {
let (depth, rows, cols) = input.dim();
let mut result = input.to_owned();
for i in 0..rows {
for j in 0..cols {
let mut line = Array1::zeros(depth);
for k in 0..depth {
line[k] = result[[k, i, j]];
}
let line_fft = if depth.is_power_of_two() {
fft_1d(&line.view(), inverse)?
} else {
bluestein_fft(&line.view(), inverse)?
};
for k in 0..depth {
result[[k, i, j]] = line_fft[k];
}
}
}
for k in 0..depth {
for j in 0..cols {
let mut line = Array1::zeros(rows);
for i in 0..rows {
line[i] = result[[k, i, j]];
}
let line_fft = if rows.is_power_of_two() {
fft_1d(&line.view(), inverse)?
} else {
bluestein_fft(&line.view(), inverse)?
};
for i in 0..rows {
result[[k, i, j]] = line_fft[i];
}
}
}
for k in 0..depth {
for i in 0..rows {
let mut line = Array1::zeros(cols);
for j in 0..cols {
line[j] = result[[k, i, j]];
}
let line_fft = if cols.is_power_of_two() {
fft_1d(&line.view(), inverse)?
} else {
bluestein_fft(&line.view(), inverse)?
};
for j in 0..cols {
result[[k, i, j]] = line_fft[j];
}
}
}
Ok(result)
}
#[allow(dead_code)]
pub fn apply_window(signal: &ArrayView1<f64>, window: WindowFunction) -> LinalgResult<Array1<f64>> {
let n = signal.len();
let mut windowed = signal.to_owned();
match window {
WindowFunction::Rectangular => {
}
WindowFunction::Hann => {
for i in 0..n {
let factor = 0.5 * (1.0 - (2.0 * PI * i as f64 / (n - 1) as f64).cos());
windowed[i] *= factor;
}
}
WindowFunction::Hamming => {
for i in 0..n {
let factor = 0.54 - 0.46 * (2.0 * PI * i as f64 / (n - 1) as f64).cos();
windowed[i] *= factor;
}
}
WindowFunction::Blackman => {
for i in 0..n {
let x = 2.0 * PI * i as f64 / (n - 1) as f64;
let factor = 0.42 - 0.5 * x.cos() + 0.08 * (2.0 * x).cos();
windowed[i] *= factor;
}
}
WindowFunction::Kaiser(beta) => {
let i0_beta = modified_bessel_i0(beta);
for i in 0..n {
let x = 2.0 * i as f64 / (n - 1) as f64 - 1.0;
let factor = modified_bessel_i0(beta * (1.0 - x * x).sqrt()) / i0_beta;
windowed[i] *= factor;
}
}
WindowFunction::Tukey(alpha) => {
let taper_len = ((alpha * n as f64) / 2.0) as usize;
for i in 0..n {
let factor = if i < taper_len {
0.5 * (1.0 + (PI * i as f64 / taper_len as f64 - PI).cos())
} else if i >= n - taper_len {
0.5 * (1.0 + (PI * (n - 1 - i) as f64 / taper_len as f64 - PI).cos())
} else {
1.0
};
windowed[i] *= factor;
}
}
WindowFunction::Gaussian(sigma) => {
let center = (n - 1) as f64 / 2.0;
for i in 0..n {
let x = (i as f64 - center) / sigma;
let factor = (-0.5 * x * x).exp();
windowed[i] *= factor;
}
}
}
Ok(windowed)
}
#[allow(dead_code)]
fn modified_bessel_i0(x: f64) -> f64 {
let mut result = 1.0;
let mut term = 1.0;
let mut k = 1.0;
while term.abs() > 1e-12 * result.abs() {
term *= (x / 2.0) * (x / 2.0) / (k * k);
result += term;
k += 1.0;
}
result
}
#[allow(dead_code)]
pub fn dct_1d(input: &ArrayView1<f64>) -> LinalgResult<Array1<f64>> {
let n = input.len();
let mut result = Array1::zeros(n);
for k in 0..n {
let mut sum = 0.0;
for i in 0..n {
let angle = PI * k as f64 * (2.0 * i as f64 + 1.0) / (2.0 * n as f64);
sum += input[i] * angle.cos();
}
let normalization = if k == 0 {
(1.0 / n as f64).sqrt()
} else {
(2.0 / n as f64).sqrt()
};
result[k] = sum * normalization;
}
Ok(result)
}
#[allow(dead_code)]
pub fn idct_1d(input: &ArrayView1<f64>) -> LinalgResult<Array1<f64>> {
let n = input.len();
let mut result = Array1::zeros(n);
for i in 0..n {
let mut sum = 0.0;
sum += input[0] * (1.0 / n as f64).sqrt();
for k in 1..n {
let angle = PI * k as f64 * (2.0 * i as f64 + 1.0) / (2.0 * n as f64);
sum += input[k] * (2.0 / n as f64).sqrt() * angle.cos();
}
result[i] = sum;
}
Ok(result)
}
#[allow(dead_code)]
pub fn dst_1d(input: &ArrayView1<f64>) -> LinalgResult<Array1<f64>> {
let n = input.len();
let mut result = Array1::zeros(n);
for k in 0..n {
let mut sum = 0.0;
for i in 0..n {
let angle = PI * (k + 1) as f64 * (i + 1) as f64 / (n + 1) as f64;
sum += input[i] * angle.sin();
}
result[k] = sum * (2.0 / (n + 1) as f64).sqrt();
}
Ok(result)
}
#[allow(dead_code)]
pub fn fft_convolve(
signal1: &ArrayView1<f64>,
signal2: &ArrayView1<f64>,
) -> LinalgResult<Array1<f64>> {
let n1 = signal1.len();
let n2 = signal2.len();
let outputsize = n1 + n2 - 1;
let fftsize = outputsize.next_power_of_two();
let mut padded1 = Array1::zeros(fftsize);
let mut padded2 = Array1::zeros(fftsize);
for i in 0..n1 {
padded1[i] = signal1[i];
}
for i in 0..n2 {
padded2[i] = signal2[i];
}
let complex1 = rfft_1d(&padded1.view())?;
let complex2 = rfft_1d(&padded2.view())?;
let mut product = Array1::zeros(complex1.len());
for i in 0..complex1.len() {
product[i] = complex1[i] * complex2[i];
}
let result_full = irfft_1d(&product.view(), fftsize)?;
let mut result = Array1::zeros(outputsize);
for i in 0..outputsize {
result[i] = result_full[i];
}
Ok(result)
}
#[allow(dead_code)]
pub fn periodogram_psd(
signal: &ArrayView1<f64>,
window: WindowFunction,
nfft: Option<usize>,
) -> LinalgResult<Array1<f64>> {
let n = signal.len();
let fftsize = nfft.unwrap_or(n);
let windowed = apply_window(signal, window)?;
let mut padded = Array1::zeros(fftsize);
for i in 0..n.min(fftsize) {
padded[i] = windowed[i];
}
let fft_result = rfft_1d(&padded.view())?;
let mut psd = Array1::zeros(fft_result.len());
let normalization = 1.0 / (fftsize as f64);
for i in 0..fft_result.len() {
psd[i] = fft_result[i].norm_sqr() * normalization;
}
if fftsize.is_multiple_of(2) && fft_result.len() > 1 {
for i in 1..fft_result.len() - 1 {
psd[i] *= 2.0;
}
} else {
for i in 1..fft_result.len() {
psd[i] *= 2.0;
}
}
Ok(psd)
}
#[allow(dead_code)]
pub fn welch_psd(
signal: &ArrayView1<f64>,
nperseg: usize,
overlap: f64,
window: WindowFunction,
) -> LinalgResult<Array1<f64>> {
if !(0.0..1.0).contains(&overlap) {
return Err(LinalgError::ShapeError(
"Overlap must be between 0.0 and 1.0".to_string(),
));
}
let n = signal.len();
let step = ((1.0 - overlap) * nperseg as f64) as usize;
let num_segments = if n >= nperseg {
(n - nperseg) / step + 1
} else {
0
};
if num_segments == 0 {
return Err(LinalgError::ShapeError(
"Signal too short for given segment length".to_string(),
));
}
let fftsize = nperseg.next_power_of_two();
let outputsize = fftsize / 2 + 1;
let mut psd_sum = Array1::zeros(outputsize);
for seg in 0..num_segments {
let start = seg * step;
let end = (start + nperseg).min(n);
let mut segment = Array1::zeros(nperseg);
for i in 0..(end - start) {
segment[i] = signal[start + i];
}
let segment_psd = periodogram_psd(&segment.view(), window, Some(fftsize))?;
for i in 0..outputsize {
psd_sum[i] += segment_psd[i];
}
}
for i in 0..outputsize {
psd_sum[i] /= num_segments as f64;
}
Ok(psd_sum)
}
#[allow(dead_code)]
pub fn hadamard_transform(input: &ArrayView1<f64>, inverse: bool) -> LinalgResult<Array1<f64>> {
let n = input.len();
if !n.is_power_of_two() {
return Err(LinalgError::ShapeError(
"Input length must be a power of 2 for Hadamard transform".to_string(),
));
}
if n == 0 {
return Ok(Array1::zeros(0));
}
let mut data = input.to_owned();
let mut size = 1;
while size < n {
for start in (0..n).step_by(size * 2) {
for i in 0..size {
let left = start + i;
let right = start + i + size;
if right < n {
let a = data[left];
let b = data[right];
data[left] = a + b;
data[right] = a - b;
}
}
}
size *= 2;
}
if inverse {
let scale = 1.0 / n as f64;
data.mapv_inplace(|x| x * scale);
}
Ok(data)
}
#[allow(dead_code)]
pub fn walsh_hadamard_transform(
input: &ArrayView1<f64>,
inverse: bool,
) -> LinalgResult<Array1<f64>> {
let n = input.len();
if !n.is_power_of_two() {
return Err(LinalgError::ShapeError(
"Input length must be a power of 2 for Walsh-Hadamard transform".to_string(),
));
}
let mut data = input.to_owned();
let log_n = (n as f64).log2() as usize;
for i in 0..n {
let j = bit_reverse(i, log_n);
if i < j {
data.swap(i, j);
}
}
hadamard_transform(&data.view(), inverse)
}
#[allow(dead_code)]
fn bit_reverse(mut n: usize, bits: usize) -> usize {
let mut result = 0;
for _ in 0..bits {
result = (result << 1) | (n & 1);
n >>= 1;
}
result
}
#[allow(dead_code)]
pub fn fast_walsh_transform(input: &ArrayView1<f64>, inverse: bool) -> LinalgResult<Array1<f64>> {
let n = input.len();
if !n.is_power_of_two() {
return Err(LinalgError::ShapeError(
"Input length must be a power of 2 for Fast Walsh Transform".to_string(),
));
}
let mut data = input.to_owned();
let mut h = 1;
while h < n {
for i in (0..n).step_by(h * 2) {
for j in i..i + h {
let u = data[j];
let v = data[j + h];
data[j] = u + v;
data[j + h] = u - v;
}
}
h *= 2;
}
if inverse {
let scale = 1.0 / n as f64;
data.mapv_inplace(|x| x * scale);
}
Ok(data)
}
#[allow(dead_code)]
pub fn fft_frequencies(n: usize, sample_rate: f64, realfft: bool) -> Array1<f64> {
let outputsize = if realfft { n / 2 + 1 } else { n };
let mut freqs = Array1::zeros(outputsize);
let df = sample_rate / n as f64;
if realfft {
for i in 0..outputsize {
freqs[i] = i as f64 * df;
}
} else {
for i in 0..outputsize {
freqs[i] = if i <= n / 2 {
i as f64 * df
} else {
(i as i64 - n as i64) as f64 * df
};
}
}
freqs
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use scirs2_core::ndarray::array;
#[test]
fn test_fft_plan_creation() {
let plan =
FFTPlan::<f64>::new(8, FFTAlgorithm::CooleyTukey, false).expect("Operation failed");
assert_eq!(plan.size, 8);
assert_eq!(plan.twiddle_factors.len(), 8);
assert_eq!(plan.bit_reversal.len(), 8);
}
#[test]
fn test_fft_1d_basic() {
let input = array![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
];
let result = fft_1d(&input.view(), false).expect("Operation failed");
assert_eq!(result.len(), 4);
assert_relative_eq!(result[0].re, 1.0, epsilon = 1e-10);
assert_relative_eq!(result[0].im, 0.0, epsilon = 1e-10);
}
#[test]
fn test_fft_inverse_property() {
let input = array![
Complex64::new(1.0, 0.5),
Complex64::new(0.2, -0.3),
Complex64::new(-0.1, 0.8),
Complex64::new(0.7, -0.2),
];
let fft_result = fft_1d(&input.view(), false).expect("Operation failed");
let ifft_result = fft_1d(&fft_result.view(), true).expect("Operation failed");
for i in 0..input.len() {
assert_relative_eq!(input[i].re, ifft_result[i].re, epsilon = 1e-12);
assert_relative_eq!(input[i].im, ifft_result[i].im, epsilon = 1e-12);
}
}
#[test]
fn test_rfft_1d() {
let input = array![1.0, 0.0, 0.0, 0.0];
let result = rfft_1d(&input.view()).expect("Operation failed");
assert_eq!(result.len(), 3); assert_relative_eq!(result[0].re, 1.0, epsilon = 1e-10);
assert_relative_eq!(result[0].im, 0.0, epsilon = 1e-10);
}
#[test]
fn test_irfft_1d() {
let input = array![1.0, 0.0, 0.0, 0.0];
let fft_result = rfft_1d(&input.view()).expect("Operation failed");
let reconstructed = irfft_1d(&fft_result.view(), 4).expect("Operation failed");
for i in 0..input.len() {
assert_relative_eq!(input[i], reconstructed[i], epsilon = 1e-12);
}
}
#[test]
fn test_fft_2d() {
let input = Array2::from_shape_fn((4, 4), |(i, j)| Complex64::new((i + j) as f64, 0.0));
let result = fft_2d(&input.view(), false).expect("Operation failed");
assert_eq!(result.shape(), &[4, 4]);
let reconstructed = fft_2d(&result.view(), true).expect("Operation failed");
for i in 0..4 {
for j in 0..4 {
assert_relative_eq!(input[[i, j]].re, reconstructed[[i, j]].re, epsilon = 1e-12);
assert_relative_eq!(input[[i, j]].im, reconstructed[[i, j]].im, epsilon = 1e-12);
}
}
}
#[test]
fn test_window_functions() {
let signal = array![1.0, 1.0, 1.0, 1.0];
let rect =
apply_window(&signal.view(), WindowFunction::Rectangular).expect("Operation failed");
for i in 0..signal.len() {
assert_relative_eq!(signal[i], rect[i], epsilon = 1e-10);
}
let hann = apply_window(&signal.view(), WindowFunction::Hann).expect("Operation failed");
assert_relative_eq!(hann[0], 0.0, epsilon = 1e-10);
assert_relative_eq!(hann[3], 0.0, epsilon = 1e-10);
assert!(hann[1] > 0.0);
assert!(hann[2] > 0.0);
}
#[test]
fn test_dct_1d() {
let input = array![1.0, 0.0, 0.0, 0.0];
let dct_result = dct_1d(&input.view()).expect("Operation failed");
let idct_result = idct_1d(&dct_result.view()).expect("Operation failed");
for i in 0..input.len() {
assert_relative_eq!(input[i], idct_result[i], epsilon = 1e-12);
}
}
#[test]
fn test_dst_1d() {
let input = array![1.0, 2.0, 3.0, 4.0];
let dst_result = dst_1d(&input.view()).expect("Operation failed");
assert_eq!(dst_result.len(), 4);
assert!(!dst_result.iter().all(|&x| x == 0.0));
}
#[test]
fn test_fft_convolve() {
let signal1 = array![1.0, 2.0, 3.0];
let signal2 = array![0.5, 1.5];
let result = fft_convolve(&signal1.view(), &signal2.view()).expect("Operation failed");
assert_eq!(result.len(), 4);
assert_relative_eq!(result[0], 1.0 * 0.5, epsilon = 1e-10);
}
#[test]
fn test_periodogram_psd() {
let signal = Array1::from_shape_fn(16, |i| (2.0 * PI * i as f64 / 16.0).sin());
let psd = periodogram_psd(&signal.view(), WindowFunction::Rectangular, None)
.expect("Operation failed");
assert_eq!(psd.len(), 9); assert!(psd.iter().all(|&x| x >= 0.0));
}
#[test]
fn test_welch_psd() {
let signal = Array1::from_shape_fn(64, |i| (2.0 * PI * i as f64 / 8.0).sin());
let psd =
welch_psd(&signal.view(), 16, 0.5, WindowFunction::Hann).expect("Operation failed");
assert!(!psd.is_empty());
assert!(psd.iter().all(|&x| x >= 0.0));
}
#[test]
fn test_fft_frequencies() {
let freqs = fft_frequencies(8, 1000.0, true);
assert_eq!(freqs.len(), 5);
assert_relative_eq!(freqs[0], 0.0, epsilon = 1e-10);
assert_relative_eq!(freqs[1], 125.0, epsilon = 1e-10); assert_relative_eq!(freqs[4], 500.0, epsilon = 1e-10); }
#[test]
fn test_bluestein_arbitrarysize() {
let input = array![
Complex64::new(1.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
Complex64::new(0.0, 0.0),
];
let result = bluestein_fft(&input.view(), false).expect("Operation failed");
let reconstructed = bluestein_fft(&result.view(), true).expect("Operation failed");
for i in 0..input.len() {
assert_relative_eq!(input[i].re, reconstructed[i].re, epsilon = 1e-10);
assert_relative_eq!(input[i].im, reconstructed[i].im, epsilon = 1e-10);
}
}
#[test]
fn test_fft_3d() {
let input = Array3::from_shape_fn((2, 2, 2), |(i, j, k)| {
Complex64::new((i + j + k) as f64, 0.0)
});
let result = fft_3d(&input.view(), false).expect("Operation failed");
assert_eq!(result.shape(), &[2, 2, 2]);
let reconstructed = fft_3d(&result.view(), true).expect("Operation failed");
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
assert_relative_eq!(
input[[i, j, k]].re,
reconstructed[[i, j, k]].re,
epsilon = 1e-12
);
assert_relative_eq!(
input[[i, j, k]].im,
reconstructed[[i, j, k]].im,
epsilon = 1e-12
);
}
}
}
}
#[test]
fn test_kaiser_window() {
let signal = array![1.0, 1.0, 1.0, 1.0, 1.0];
let windowed =
apply_window(&signal.view(), WindowFunction::Kaiser(2.0)).expect("Operation failed");
assert!(windowed[0] < windowed[2]); assert!(windowed[4] < windowed[2]); assert!(windowed[2] > 0.0); }
#[test]
fn test_tukey_window() {
let signal = Array1::ones(10);
let windowed =
apply_window(&signal.view(), WindowFunction::Tukey(0.5)).expect("Operation failed");
assert!(windowed[0] < windowed[5]); assert!(windowed[9] < windowed[5]); }
#[test]
fn test_hadamard_transform() {
let input = array![1.0, 1.0, 1.0, 1.0];
let result = hadamard_transform(&input.view(), false).expect("Operation failed");
assert_relative_eq!(result[0], 4.0, epsilon = 1e-12);
assert_relative_eq!(result[1], 0.0, epsilon = 1e-12);
assert_relative_eq!(result[2], 0.0, epsilon = 1e-12);
assert_relative_eq!(result[3], 0.0, epsilon = 1e-12);
let reconstructed = hadamard_transform(&result.view(), true).expect("Operation failed");
for i in 0..4 {
assert_relative_eq!(input[i], reconstructed[i], epsilon = 1e-12);
}
}
#[test]
fn test_walsh_hadamard_transform() {
let input = array![1.0, 0.0, 1.0, 0.0];
let result = walsh_hadamard_transform(&input.view(), false).expect("Operation failed");
assert_eq!(result.len(), 4);
let reconstructed =
walsh_hadamard_transform(&result.view(), true).expect("Operation failed");
for i in 0..4 {
assert_relative_eq!(input[i], reconstructed[i], epsilon = 1e-12);
}
}
#[test]
fn test_fast_walsh_transform() {
let input = array![1.0, -1.0, 1.0, -1.0];
let result = fast_walsh_transform(&input.view(), false).expect("Operation failed");
assert_eq!(result.len(), 4);
let reconstructed = fast_walsh_transform(&result.view(), true).expect("Operation failed");
for i in 0..4 {
assert_relative_eq!(input[i], reconstructed[i], epsilon = 1e-12);
}
}
#[test]
fn test_hadamard_transform_properties() {
let input = array![1.0, 2.0, 3.0, 4.0];
let transformed = hadamard_transform(&input.view(), false).expect("Operation failed");
let twice_transformed =
hadamard_transform(&transformed.view(), false).expect("Operation failed");
let n = input.len() as f64;
for i in 0..4 {
assert_relative_eq!(twice_transformed[i], n * input[i], epsilon = 1e-12);
}
}
#[test]
fn test_bit_reverse() {
assert_eq!(bit_reverse(0, 3), 0); assert_eq!(bit_reverse(1, 3), 4); assert_eq!(bit_reverse(2, 3), 2); assert_eq!(bit_reverse(3, 3), 6); assert_eq!(bit_reverse(4, 3), 1); }
}