ndarray_conv/conv_fft/processor/mod.rs
1//! Provides FFT processor implementations for convolution operations.
2//!
//! This module contains traits and implementations for performing forward and backward FFTs
//! on real- and complex-valued arrays. These processors are used internally by the FFT-accelerated
5//! convolution methods.
6
7use std::marker::PhantomData;
8
9use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
10use num::Complex;
11use rustfft::FftNum;
12
13pub mod complex;
14pub mod real;
15
16/// Marker trait for numeric types that can be used with ConvFftNum.
17///
18/// This trait is implemented for both integer and floating-point types that implement `FftNum`.
19///
20/// # Important Note
21///
22/// While this trait is implemented for integer types (i8, i16, i32, i64, i128, isize),
23/// **integer FFT operations have known accuracy issues** and should NOT be used in production.
24/// Only lengths of 2 or 4 work correctly for 1D arrays; other lengths produce incorrect results.
25///
26/// **Always use f32 or f64 for FFT operations.**
27pub trait ConvFftNum: FftNum {}
28
29macro_rules! impl_conv_fft_num {
30 ($($t:ty),*) => {
31 $(impl ConvFftNum for $t {})*
32 };
33}
34
35impl_conv_fft_num!(i8, i16, i32, i64, i128, isize, f32, f64);
36
37/// Returns a processor instance for the given input element type.
38///
39/// This function is a convenience wrapper around `GetProcessor::get_processor()`.
40///
41/// # Type Parameters
42///
43/// * `T`: The FFT numeric type (typically f32 or f64)
44/// * `InElem`: The input element type (`T` for real, `Complex<T>` for complex)
45///
46/// # Example
47///
48/// ```rust
49/// use ndarray_conv::conv_fft::processor::get as get_processor;
50///
51/// // Get a processor for f32 real values
52/// let mut proc = get_processor::<f32, f32>();
53///
54/// // Get a processor for Complex<f32> values
55/// use num::Complex;
56/// let mut proc_complex = get_processor::<f32, Complex<f32>>();
57/// ```
58pub fn get<T: FftNum, InElem: GetProcessor<T, InElem>>() -> impl Processor<T, InElem> {
59 InElem::get_processor()
60}
61
62/// Trait for FFT processors that can perform forward and backward transforms.
63///
64/// This trait defines the interface for performing FFT operations on N-dimensional arrays.
65/// Implementations exist for both real-valued and complex-valued inputs.
66pub trait Processor<T: FftNum, InElem: GetProcessor<T, InElem>> {
67 /// Performs a forward FFT transform.
68 ///
69 /// Converts the input array from the spatial/time domain to the frequency domain.
70 ///
71 /// # Arguments
72 ///
73 /// * `input`: A mutable reference to the input array
74 ///
75 /// # Returns
76 ///
77 /// An array of complex values representing the frequency domain.
78 fn forward<S: DataMut<Elem = InElem>, const N: usize>(
79 &mut self,
80 input: &mut ArrayBase<S, Dim<[Ix; N]>>,
81 ) -> Array<Complex<T>, Dim<[Ix; N]>>
82 where
83 Dim<[Ix; N]>: RemoveAxis,
84 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
85
86 /// Performs a backward (inverse) FFT transform.
87 ///
88 /// Converts the input array from the frequency domain back to the spatial/time domain.
89 ///
90 /// # Arguments
91 ///
92 /// * `input`: A mutable reference to the frequency domain array
93 ///
94 /// # Returns
95 ///
96 /// An array in the spatial/time domain with the same element type as the original input.
97 fn backward<const N: usize>(
98 &mut self,
99 input: &mut Array<Complex<T>, Dim<[Ix; N]>>,
100 ) -> Array<InElem, Dim<[Ix; N]>>
101 where
102 Dim<[Ix; N]>: RemoveAxis,
103 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
104}
105
106/// Trait for types that can provide a processor instance.
107///
108/// This trait is implemented for real and complex numeric types, allowing them to
109/// automatically select the appropriate FFT processor implementation.
110pub trait GetProcessor<T: FftNum, InElem>
111where
112 InElem: GetProcessor<T, InElem>,
113{
114 /// Returns a processor instance appropriate for this type.
115 fn get_processor() -> impl Processor<T, InElem>;
116}
117
118impl<T: ConvFftNum> GetProcessor<T, T> for T {
119 fn get_processor() -> impl Processor<T, T> {
120 real::Processor::<T>::default()
121 }
122}
123
124impl<T: FftNum> GetProcessor<T, Complex<T>> for Complex<T> {
125 fn get_processor() -> impl Processor<T, Complex<T>> {
126 complex::Processor::<T>::default()
127 }
128}