ndarray_conv/conv_fft/processor/
mod.rs

//! Provides FFT processor implementations for convolution operations.
//!
//! This module contains traits and implementations for performing forward and backward FFT
//! transforms on real- and complex-valued arrays. These processors are used internally by the
//! FFT-accelerated convolution methods.

7use std::marker::PhantomData;
8
9use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
10use num::Complex;
11use rustfft::FftNum;
12
13pub mod complex;
14pub mod real;
15
16/// Marker trait for numeric types that can be used with ConvFftNum.
17///
18/// This trait is implemented for both integer and floating-point types that implement `FftNum`.
19///
20/// # Important Note
21///
22/// While this trait is implemented for integer types (i8, i16, i32, i64, i128, isize),
23/// **integer FFT operations have known accuracy issues** and should NOT be used in production.
24/// Only lengths of 2 or 4 work correctly for 1D arrays; other lengths produce incorrect results.
25///
26/// **Always use f32 or f64 for FFT operations.**
27pub trait ConvFftNum: FftNum {}
28
29macro_rules! impl_conv_fft_num {
30    ($($t:ty),*) => {
31        $(impl ConvFftNum for $t {})*
32    };
33}
34
35impl_conv_fft_num!(i8, i16, i32, i64, i128, isize, f32, f64);
36
37/// Returns a processor instance for the given input element type.
38///
39/// This function is a convenience wrapper around `GetProcessor::get_processor()`.
40///
41/// # Type Parameters
42///
43/// * `T`: The FFT numeric type (typically f32 or f64)
44/// * `InElem`: The input element type (`T` for real, `Complex<T>` for complex)
45pub fn get<T: FftNum, InElem: GetProcessor<T, InElem>>() -> impl Processor<T, InElem> {
46    InElem::get_processor()
47}
48
49/// Trait for FFT processors that can perform forward and backward transforms.
50///
51/// This trait defines the interface for performing FFT operations on N-dimensional arrays.
52/// Implementations exist for both real-valued and complex-valued inputs.
53pub trait Processor<T: FftNum, InElem: GetProcessor<T, InElem>> {
54    /// Performs a forward FFT transform.
55    ///
56    /// Converts the input array from the spatial/time domain to the frequency domain.
57    ///
58    /// # Arguments
59    ///
60    /// * `input`: A mutable reference to the input array
61    ///
62    /// # Returns
63    ///
64    /// An array of complex values representing the frequency domain.
65    fn forward<S: DataMut<Elem = InElem>, const N: usize>(
66        &mut self,
67        input: &mut ArrayBase<S, Dim<[Ix; N]>>,
68    ) -> Array<Complex<T>, Dim<[Ix; N]>>
69    where
70        Dim<[Ix; N]>: RemoveAxis,
71        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
72
73    /// Performs a backward (inverse) FFT transform.
74    ///
75    /// Converts the input array from the frequency domain back to the spatial/time domain.
76    ///
77    /// # Arguments
78    ///
79    /// * `input`: A mutable reference to the frequency domain array
80    ///
81    /// # Returns
82    ///
83    /// An array in the spatial/time domain with the same element type as the original input.
84    fn backward<const N: usize>(
85        &mut self,
86        input: &mut Array<Complex<T>, Dim<[Ix; N]>>,
87    ) -> Array<InElem, Dim<[Ix; N]>>
88    where
89        Dim<[Ix; N]>: RemoveAxis,
90        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
91}
92
93/// Trait for types that can provide a processor instance.
94///
95/// This trait is implemented for real and complex numeric types, allowing them to
96/// automatically select the appropriate FFT processor implementation.
97pub trait GetProcessor<T: FftNum, InElem>
98where
99    InElem: GetProcessor<T, InElem>,
100{
101    /// Returns a processor instance appropriate for this type.
102    fn get_processor() -> impl Processor<T, InElem>;
103}
104
105impl<T: ConvFftNum> GetProcessor<T, T> for T {
106    fn get_processor() -> impl Processor<T, T> {
107        real::Processor::<T>::default()
108    }
109}
110
111impl<T: FftNum> GetProcessor<T, Complex<T>> for Complex<T> {
112    fn get_processor() -> impl Processor<T, Complex<T>> {
113        complex::Processor::<T>::default()
114    }
115}