scirs2-fft 0.4.2

Fast Fourier Transform module for SciRS2 (scirs2-fft)
Documentation
//! SIMD-accelerated Real-valued Fast Fourier Transform (RFFT) operations
//!
//! This module provides SIMD-accelerated implementations of FFT operations
//! for real-valued inputs, using the unified SIMD abstraction layer from scirs2-core.

use crate::error::{FFTError, FFTResult};
use crate::rfft::{irfft as irfft_basic, rfft as rfft_basic};
use scirs2_core::numeric::Complex64;
use scirs2_core::numeric::NumCast;
use scirs2_core::simd_ops::{AutoOptimizer, PlatformCapabilities};
use std::fmt::Debug;

/// Compute the 1-dimensional discrete Fourier Transform for real input with SIMD acceleration.
///
/// This function is optimized using SIMD instructions for improved performance on
/// modern CPUs. For real-valued inputs, this uses a specialized algorithm that is
/// more efficient than a general complex FFT.
///
/// # Arguments
///
/// * `input` - Input real-valued array
/// * `n` - Length of the transformed axis (optional)
/// * `norm` - Normalization mode (optional)
///
/// # Returns
///
/// * The Fourier transform of the real input array
///
/// # Examples
///
/// ```
/// use scirs2_fft::simd_rfft::{rfft_simd};
/// use scirs2_fft::simd_fft::NormMode;
///
/// // Generate a simple signal
/// let signal = vec![1.0, 2.0, 3.0, 4.0];
///
/// // Compute RFFT of the signal with SIMD acceleration
/// let spectrum = rfft_simd(&signal, None, None).expect("Operation failed");
///
/// // RFFT produces n//2 + 1 complex values
/// assert_eq!(spectrum.len(), signal.len() / 2 + 1);
/// ```
#[allow(dead_code)]
pub fn rfft_simd<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Use the basic rfft implementation which already handles the logic
    let result = rfft_basic(input, n)?;

    // Apply normalization if requested
    if let Some(norm_str) = norm {
        let mut result_mut = result;
        let n = input.len();
        match norm_str {
            "backward" => {
                let scale = 1.0 / (n as f64);
                result_mut.iter_mut().for_each(|c| *c *= scale);
            }
            "ortho" => {
                let scale = 1.0 / (n as f64).sqrt();
                result_mut.iter_mut().for_each(|c| *c *= scale);
            }
            "forward" => {
                let scale = 1.0 / (n as f64);
                result_mut.iter_mut().for_each(|c| *c *= scale);
            }
            _ => {} // No normalization for unrecognized mode
        }
        return Ok(result_mut);
    }

    Ok(result)
}

/// Compute the inverse of the 1-dimensional discrete Fourier Transform for real input with SIMD acceleration.
///
/// This function is optimized using SIMD instructions for improved performance on
/// modern CPUs.
///
/// # Arguments
///
/// * `input` - Input complex-valued array representing the Fourier transform of real data
/// * `n` - Length of the output array (optional)
/// * `norm` - Normalization mode (optional)
///
/// # Returns
///
/// * The inverse Fourier transform, yielding a real-valued array
///
/// # Examples
///
/// ```
/// use scirs2_fft::simd_rfft::{rfft_simd, irfft_simd};
///
/// // Generate a simple signal
/// let signal = vec![1.0, 2.0, 3.0, 4.0];
///
/// // Forward transform
/// let spectrum = rfft_simd(&signal, None, None).expect("Operation failed");
///
/// // Inverse transform
/// let recovered = irfft_simd(&spectrum, Some(signal.len()), None).expect("Operation failed");
///
/// // Check recovery accuracy
/// for (x, y) in signal.iter().zip(recovered.iter()) {
///     assert!((x - y).abs() < 1e-10);
/// }
/// ```
#[allow(dead_code)]
pub fn irfft_simd<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // Use the basic irfft implementation
    let result = irfft_basic(input, n)?;

    // Apply normalization if requested
    if let Some(norm_str) = norm {
        let mut result_mut = result;
        let n = input.len();
        match norm_str {
            "backward" => {
                let scale = 1.0 / (n as f64);
                result_mut.iter_mut().for_each(|c| *c *= scale);
            }
            "ortho" => {
                let scale = 1.0 / (n as f64).sqrt();
                result_mut.iter_mut().for_each(|c| *c *= scale);
            }
            "forward" => {
                let scale = 1.0 / (n as f64);
                result_mut.iter_mut().for_each(|c| *c *= scale);
            }
            _ => {} // No normalization for unrecognized mode
        }
        return Ok(result_mut);
    }

    Ok(result)
}

/// Adaptive RFFT that automatically chooses the best implementation
#[allow(dead_code)]
pub fn rfft_adaptive<T>(
    input: &[T],
    n: Option<usize>,
    norm: Option<&str>,
) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let optimizer = AutoOptimizer::new();
    let caps = PlatformCapabilities::detect();
    let size = n.unwrap_or(input.len());

    if caps.gpu_available && optimizer.should_use_gpu(size) {
        // Use GPU implementation when available
        match rfft_gpu(input, n, norm) {
            Ok(result) => Ok(result),
            Err(_) => {
                // Fall back to SIMD implementation if GPU fails
                rfft_simd(input, n, norm)
            }
        }
    } else {
        rfft_simd(input, n, norm)
    }
}

/// Adaptive IRFFT that automatically chooses the best implementation
#[allow(dead_code)]
pub fn irfft_adaptive<T>(input: &[T], n: Option<usize>, norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    let optimizer = AutoOptimizer::new();
    let caps = PlatformCapabilities::detect();
    let size = n.unwrap_or_else(|| input.len() * 2 - 2);

    if caps.gpu_available && optimizer.should_use_gpu(size) {
        // Use GPU implementation when available
        match irfft_gpu(input, n, norm) {
            Ok(result) => Ok(result),
            Err(_) => {
                // Fall back to SIMD implementation if GPU fails
                irfft_simd(input, n, norm)
            }
        }
    } else {
        irfft_simd(input, n, norm)
    }
}

/// GPU-accelerated RFFT implementation
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn rfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // GPU implementation is simplified for now due to API incompatibilities
    // Will be properly implemented when GPU support is fully integrated
    Err(FFTError::NotImplementedError(
        "GPU-accelerated RFFT is not yet fully implemented".to_string(),
    ))
}

/// GPU-accelerated IRFFT implementation
#[cfg(feature = "cuda")]
#[allow(dead_code)]
fn irfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    // GPU implementation is simplified for now due to API incompatibilities
    // Will be properly implemented when GPU support is fully integrated
    Err(FFTError::NotImplementedError(
        "GPU-accelerated IRFFT is not yet fully implemented".to_string(),
    ))
}

/// Fallback implementations when GPU feature is not enabled
#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
fn rfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<Complex64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    Err(crate::error::FFTError::NotImplementedError(
        "GPU FFT not compiled".to_string(),
    ))
}

#[cfg(not(feature = "cuda"))]
#[allow(dead_code)]
fn irfft_gpu<T>(_input: &[T], _n: Option<usize>, _norm: Option<&str>) -> FFTResult<Vec<f64>>
where
    T: NumCast + Copy + Debug + 'static,
{
    Err(crate::error::FFTError::NotImplementedError(
        "GPU FFT not compiled".to_string(),
    ))
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_rfft_simd_simple() {
        let signal = vec![1.0, 2.0, 3.0, 4.0];

        // Forward transform
        let spectrum = rfft_simd(&signal, None, None).expect("Operation failed");

        // Check size
        assert_eq!(spectrum.len(), signal.len() / 2 + 1);

        // First element should be sum of all values
        assert_abs_diff_eq!(spectrum[0].re, 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(spectrum[0].im, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rfft_irfft_roundtrip() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        // Forward transform
        let spectrum = rfft_simd(&signal, None, None).expect("Operation failed");

        // Inverse transform
        let recovered = irfft_simd(&spectrum, Some(signal.len()), None).expect("Operation failed");

        // Check recovery
        for (i, (&orig, &rec)) in signal.iter().zip(recovered.iter()).enumerate() {
            if (orig - rec).abs() > 1e-10 {
                panic!("Mismatch at index {i}: {orig} != {rec}");
            }
        }
    }

    #[test]
    fn test_adaptive_selection() {
        let signal = vec![1.0; 1000];

        // Test adaptive functions (should work regardless of GPU availability)
        let spectrum = rfft_adaptive(&signal, None, None).expect("Operation failed");
        assert_eq!(spectrum.len(), signal.len() / 2 + 1);

        let recovered =
            irfft_adaptive(&spectrum, Some(signal.len()), None).expect("Operation failed");
        assert_eq!(recovered.len(), signal.len());
    }
}