Skip to main content

ndarray_conv/conv_fft/processor/
mod.rs

//! Provides FFT processor implementations for convolution operations.
//!
//! This module contains traits and implementations for performing forward and backward FFT
//! transforms on real- and complex-valued arrays. These processors are used internally by the
//! FFT-accelerated convolution methods.
6
7use std::marker::PhantomData;
8
9use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
10use num::Complex;
11use rustfft::FftNum;
12
/// Conditional `Send` bound for FFT processors.
///
/// When the `rayon` feature is enabled, processors must be `Send` so they can be used across
/// thread boundaries. Without `rayon` there is no such requirement, and every sized type
/// implements this trait. Using this trait as a bound provides the conditional requirement
/// without duplicating any code.
#[cfg(feature = "rayon")]
pub trait MaybeSend: Send {}
#[cfg(feature = "rayon")]
impl<T: Send> MaybeSend for T {}

/// Conditional `Send` bound for FFT processors.
///
/// When the `rayon` feature is enabled, processors must be `Send` so they can be used across
/// thread boundaries. Without `rayon` there is no such requirement, and every sized type
/// implements this trait. Using this trait as a bound provides the conditional requirement
/// without duplicating any code.
// The doc comment is repeated on purpose: only one cfg variant exists in any given build, and
// each variant needs its own docs to render (and to satisfy `missing_docs`).
#[cfg(not(feature = "rayon"))]
pub trait MaybeSend {}
#[cfg(not(feature = "rayon"))]
impl<T> MaybeSend for T {}
25
/// Conditional `Sync` bound, following the same pattern as [`MaybeSend`].
///
/// With the `rayon` feature enabled this requires `Sync`, so that shared references to
/// data/kernel arrays can cross thread boundaries inside `rayon::join` closures. Without
/// `rayon` it imposes no requirement, and every sized type implements it.
#[cfg(feature = "rayon")]
pub trait MaybeSync: Sync {}
#[cfg(feature = "rayon")]
impl<T: Sync> MaybeSync for T {}

/// Conditional `Sync` bound, following the same pattern as [`MaybeSend`].
///
/// With the `rayon` feature enabled this requires `Sync`, so that shared references to
/// data/kernel arrays can cross thread boundaries inside `rayon::join` closures. Without
/// `rayon` it imposes no requirement, and every sized type implements it.
// The doc comment is repeated on purpose: only one cfg variant exists in any given build.
#[cfg(not(feature = "rayon"))]
pub trait MaybeSync {}
#[cfg(not(feature = "rayon"))]
impl<T> MaybeSync for T {}
38
39pub mod complex;
40pub mod real;
41
42/// Marker trait for numeric types that can be used with ConvFftNum.
43///
44/// This trait is implemented for both integer and floating-point types that implement `FftNum`.
45///
46/// # Important Note
47///
48/// While this trait is implemented for integer types (i8, i16, i32, i64, i128, isize),
49/// **integer FFT operations have known accuracy issues** and should NOT be used in production.
50/// Only lengths of 2 or 4 work correctly for 1D arrays; other lengths produce incorrect results.
51///
52/// **Always use f32 or f64 for FFT operations.**
53pub trait ConvFftNum: FftNum {}
54
55macro_rules! impl_conv_fft_num {
56    ($($t:ty),*) => {
57        $(impl ConvFftNum for $t {})*
58    };
59}
60
61impl_conv_fft_num!(i8, i16, i32, i64, i128, isize, f32, f64);
62
63/// Returns a processor instance for the given input element type.
64///
65/// This function is a convenience wrapper around `GetProcessor::get_processor()`.
66///
67/// # Type Parameters
68///
69/// * `T`: The FFT numeric type (typically f32 or f64)
70/// * `InElem`: The input element type (`T` for real, `Complex<T>` for complex)
71pub fn get<T: FftNum, InElem: GetProcessor<T, InElem>>() -> impl Processor<T, InElem> {
72    InElem::get_processor()
73}
74
75/// Trait for FFT processors that can perform forward and backward transforms.
76///
77/// This trait defines the interface for performing FFT operations on N-dimensional arrays.
78/// Implementations exist for both real-valued and complex-valued inputs.
79pub trait Processor<T: FftNum, InElem: GetProcessor<T, InElem>>: MaybeSend {
80    /// Performs a forward FFT transform.
81    ///
82    /// Converts the input array from the spatial/time domain to the frequency domain.
83    ///
84    /// # Arguments
85    ///
86    /// * `input`: A mutable reference to the input array
87    ///
88    /// # Returns
89    ///
90    /// An array of complex values representing the frequency domain.
91    fn forward<S: DataMut<Elem = InElem>, const N: usize>(
92        &mut self,
93        input: &mut ArrayBase<S, Dim<[Ix; N]>>,
94        parallel: bool,
95    ) -> Array<Complex<T>, Dim<[Ix; N]>>
96    where
97        Dim<[Ix; N]>: RemoveAxis,
98        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
99
100    /// Performs a backward (inverse) FFT transform.
101    ///
102    /// Converts the input array from the frequency domain back to the spatial/time domain.
103    ///
104    /// # Arguments
105    ///
106    /// * `input`: A mutable reference to the frequency domain array
107    ///
108    /// # Returns
109    ///
110    /// An array in the spatial/time domain with the same element type as the original input.
111    fn backward<const N: usize>(
112        &mut self,
113        input: &mut Array<Complex<T>, Dim<[Ix; N]>>,
114        parallel: bool,
115    ) -> Array<InElem, Dim<[Ix; N]>>
116    where
117        Dim<[Ix; N]>: RemoveAxis,
118        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
119}
120
121/// Trait for types that can provide a processor instance.
122///
123/// This trait is implemented for real and complex numeric types, allowing them to
124/// automatically select the appropriate FFT processor implementation.
125pub trait GetProcessor<T: FftNum, InElem>: MaybeSend
126where
127    InElem: GetProcessor<T, InElem>,
128{
129    /// Returns a serial processor instance appropriate for this type.
130    fn get_processor() -> impl Processor<T, InElem>;
131}
132
133impl<T: ConvFftNum> GetProcessor<T, T> for T {
134    fn get_processor() -> impl Processor<T, T> {
135        real::Processor::<T>::default()
136    }
137}
138
139impl<T: FftNum> GetProcessor<T, Complex<T>> for Complex<T> {
140    fn get_processor() -> impl Processor<T, Complex<T>> {
141        complex::Processor::<T>::default()
142    }
143}