ndarray_conv/conv_fft/processor/mod.rs
1//! Provides FFT processor implementations for convolution operations.
2//!
3//! This module contains traits and implementations for performing forward and backward FFT transforms
4//! on real and complex-valued arrays. These processors are used internally by the FFT-accelerated
5//! convolution methods.
6
7use std::marker::PhantomData;
8
9use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
10use num::Complex;
11use rustfft::FftNum;
12
/// Conditional `Send` requirement for processors.
///
/// With the `rayon` feature enabled, `Send` is a supertrait and the blanket
/// implementation covers exactly the `Send` types, so processors can be moved
/// across thread boundaries. Declaring the bound through this trait lets the
/// rest of the crate write a single bound for both feature configurations.
#[cfg(feature = "rayon")]
pub trait MaybeSend
where
    Self: Send,
{
}
#[cfg(feature = "rayon")]
impl<T> MaybeSend for T where T: Send {}

/// Conditional `Send` requirement for processors.
///
/// Without the `rayon` feature there is nothing to enforce: the trait has no
/// supertrait and is implemented for every sized type.
#[cfg(not(feature = "rayon"))]
pub trait MaybeSend {}
#[cfg(not(feature = "rayon"))]
impl<U> MaybeSend for U {}
25
/// Conditional `Sync` requirement, the counterpart of [`MaybeSend`].
///
/// With the `rayon` feature enabled, `Sync` is a supertrait and the blanket
/// implementation covers exactly the `Sync` types. This is what allows shared
/// references to data/kernel arrays to cross thread boundaries inside
/// `rayon::join` closures.
#[cfg(feature = "rayon")]
pub trait MaybeSync
where
    Self: Sync,
{
}
#[cfg(feature = "rayon")]
impl<T> MaybeSync for T where T: Sync {}

/// Conditional `Sync` requirement, the counterpart of [`MaybeSend`].
///
/// Without the `rayon` feature the trait has no supertrait and is implemented
/// for every sized type, so the bound costs nothing.
#[cfg(not(feature = "rayon"))]
pub trait MaybeSync {}
#[cfg(not(feature = "rayon"))]
impl<U> MaybeSync for U {}
38
39pub mod complex;
40pub mod real;
41
42/// Marker trait for numeric types that can be used with ConvFftNum.
43///
44/// This trait is implemented for both integer and floating-point types that implement `FftNum`.
45///
46/// # Important Note
47///
48/// While this trait is implemented for integer types (i8, i16, i32, i64, i128, isize),
49/// **integer FFT operations have known accuracy issues** and should NOT be used in production.
50/// Only lengths of 2 or 4 work correctly for 1D arrays; other lengths produce incorrect results.
51///
52/// **Always use f32 or f64 for FFT operations.**
53pub trait ConvFftNum: FftNum {}
54
55macro_rules! impl_conv_fft_num {
56 ($($t:ty),*) => {
57 $(impl ConvFftNum for $t {})*
58 };
59}
60
61impl_conv_fft_num!(i8, i16, i32, i64, i128, isize, f32, f64);
62
63/// Returns a processor instance for the given input element type.
64///
65/// This function is a convenience wrapper around `GetProcessor::get_processor()`.
66///
67/// # Type Parameters
68///
69/// * `T`: The FFT numeric type (typically f32 or f64)
70/// * `InElem`: The input element type (`T` for real, `Complex<T>` for complex)
71pub fn get<T: FftNum, InElem: GetProcessor<T, InElem>>() -> impl Processor<T, InElem> {
72 InElem::get_processor()
73}
74
75/// Trait for FFT processors that can perform forward and backward transforms.
76///
77/// This trait defines the interface for performing FFT operations on N-dimensional arrays.
78/// Implementations exist for both real-valued and complex-valued inputs.
79pub trait Processor<T: FftNum, InElem: GetProcessor<T, InElem>>: MaybeSend {
80 /// Performs a forward FFT transform.
81 ///
82 /// Converts the input array from the spatial/time domain to the frequency domain.
83 ///
84 /// # Arguments
85 ///
86 /// * `input`: A mutable reference to the input array
87 ///
88 /// # Returns
89 ///
90 /// An array of complex values representing the frequency domain.
91 fn forward<S: DataMut<Elem = InElem>, const N: usize>(
92 &mut self,
93 input: &mut ArrayBase<S, Dim<[Ix; N]>>,
94 parallel: bool,
95 ) -> Array<Complex<T>, Dim<[Ix; N]>>
96 where
97 Dim<[Ix; N]>: RemoveAxis,
98 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
99
100 /// Performs a backward (inverse) FFT transform.
101 ///
102 /// Converts the input array from the frequency domain back to the spatial/time domain.
103 ///
104 /// # Arguments
105 ///
106 /// * `input`: A mutable reference to the frequency domain array
107 ///
108 /// # Returns
109 ///
110 /// An array in the spatial/time domain with the same element type as the original input.
111 fn backward<const N: usize>(
112 &mut self,
113 input: &mut Array<Complex<T>, Dim<[Ix; N]>>,
114 parallel: bool,
115 ) -> Array<InElem, Dim<[Ix; N]>>
116 where
117 Dim<[Ix; N]>: RemoveAxis,
118 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
119}
120
121/// Trait for types that can provide a processor instance.
122///
123/// This trait is implemented for real and complex numeric types, allowing them to
124/// automatically select the appropriate FFT processor implementation.
125pub trait GetProcessor<T: FftNum, InElem>: MaybeSend
126where
127 InElem: GetProcessor<T, InElem>,
128{
129 /// Returns a serial processor instance appropriate for this type.
130 fn get_processor() -> impl Processor<T, InElem>;
131}
132
133impl<T: ConvFftNum> GetProcessor<T, T> for T {
134 fn get_processor() -> impl Processor<T, T> {
135 real::Processor::<T>::default()
136 }
137}
138
139impl<T: FftNum> GetProcessor<T, Complex<T>> for Complex<T> {
140 fn get_processor() -> impl Processor<T, Complex<T>> {
141 complex::Processor::<T>::default()
142 }
143}