concision_neural/layers/attention/
fft.rs

1/*
2    Appellation: fft <module>
3    Contrib: @FL03
4*/
5use cnc::Forward;
6use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand};
7use num_traits::{Float, FromPrimitive};
8use rustfft::num_complex::Complex;
9use rustfft::{FftNum, FftPlanner};
10
/// FFT-based attention mechanism for temporal pattern recognition.
///
/// This implementation is based on "The FFT Strikes Back: Fast and Accurate
/// Spectral-Pruning Neural Networks" (<https://arxiv.org/pdf/2502.18394>).
///
/// The mechanism:
///
/// 1. Applies a Hann window and transforms the input to the frequency domain using an FFT
/// 2. Applies soft thresholding to each frequency component based on its share of the total
///    spectral magnitude
/// 3. Scales each frequency component by the resulting attention weight
/// 4. Returns to the time domain with an inverse FFT
///
/// The attention mechanism is parameterized by `steepness` and `threshold`, which control the
/// sensitivity of the attention to frequency components.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct FftAttention<A = f32> {
    /// Slope of the sigmoid that maps a bin's normalized magnitude to its weight; larger
    /// values make the transition around `threshold` sharper.
    pub(crate) steepness: A,
    /// Normalized-magnitude level at which a frequency bin receives a weight of `0.5`.
    pub(crate) threshold: A,
}
31
32impl<A> FftAttention<A> {
33    /// Create a new attention module with the given parameters
34    pub fn new() -> Self
35    where
36        A: FromPrimitive,
37    {
38        Self {
39            steepness: A::from_f32(10.0).unwrap(),
40            threshold: A::from_f32(0.1).unwrap(),
41        }
42    }
43    /// returns an immutable reference to the steepness of the attention module
44    pub const fn steepness(&self) -> &A {
45        &self.steepness
46    }
47    /// returns a mutable reference to the steepness of the attention module to allow for
48    /// gradient descent
49    #[inline]
50    pub fn steepness_mut(&mut self) -> &mut A {
51        &mut self.steepness
52    }
53    /// returns an immutable reference to the threshold of the attention module
54    pub const fn threshold(&self) -> &A {
55        &self.threshold
56    }
57    /// returns a mutable reference to the threshold of the attention module to allow for
58    /// gradient descent
59    #[inline]
60    pub fn threshold_mut(&mut self) -> &mut A {
61        &mut self.threshold
62    }
63    /// set the steepness of the attention mechanism
64    #[inline]
65    pub fn set_steepness(&mut self, steepness: A) {
66        self.steepness = steepness;
67    }
68    /// set the threshold of the attention mechanism
69    #[inline]
70    pub fn set_threshold(&mut self, threshold: A) {
71        self.threshold = threshold;
72    }
73    /// consumes the current instance and returns another with the given steepness
74    pub fn with_steepness(self, steepness: A) -> Self {
75        Self { steepness, ..self }
76    }
77    /// consumes the current instance and returns another with the given threshold
78    pub fn with_threshold(self, threshold: A) -> Self {
79        Self { threshold, ..self }
80    }
81    #[cfg_attr(
82        feature = "tracing",
83        tracing::instrument(
84            skip(self, input),
85            name = "forward",
86            target = "attention",
87            level = "trace",
88        )
89    )]
90    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
91    where
92        Self: Forward<X, Output = Y>,
93    {
94        <Self as Forward<X>>::forward(self, input)
95    }
96}
97
98impl<A> Default for FftAttention<A>
99where
100    A: FromPrimitive,
101{
102    fn default() -> Self {
103        Self::new()
104    }
105}
106
107impl<A, S> Forward<ArrayBase<S, Ix1>> for FftAttention<A>
108where
109    A: FftNum + Float + FromPrimitive + ScalarOperand,
110    S: Data<Elem = A>,
111{
112    type Output = Array1<A>;
113
114    fn forward(&self, input: &ArrayBase<S, Ix1>) -> cnc::Result<Self::Output> {
115        let seq_len = input.dim();
116        let n = A::from_usize(seq_len).unwrap();
117
118        if seq_len == 0 {
119            return Err(cnc::params::ParamsError::InvalidInputShape.into());
120        }
121
122        // Create FFT planner
123        let mut planner = FftPlanner::new();
124        let fft = planner.plan_fft_forward(seq_len);
125
126        // Simplified: directly use a 1D Vec for frequency domain
127        let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
128
129        // declare constants for windowing
130        let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
131        let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
132
133        // Apply windowing function while extracting data
134        for time_idx in 0..seq_len {
135            // Apply Hann window: 0.5 * (1 - cos(2π*i/(N-1)))
136            let window_factor = if seq_len > 1 {
137                let i_f = A::from_usize(time_idx).unwrap();
138                A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
139            } else {
140                A::one()
141            };
142
143            // Get value and apply window
144            let val = input[time_idx] * window_factor;
145            windowed_input.push(Complex::new(val, A::zero()));
146        }
147
148        // Perform FFT in-place
149        fft.process(&mut windowed_input);
150
151        // Calculate energy and total energy directly
152        let mut freq_energy = Array1::<A>::zeros(seq_len);
153        let mut total_energy = A::zero();
154        for (time_idx, &val) in windowed_input.iter().enumerate() {
155            let energy = (val.re * val.re + val.im * val.im).sqrt();
156            freq_energy[time_idx] = energy;
157            total_energy = total_energy + energy;
158        }
159
160        // Add epsilon to prevent division by zero
161        let epsilon = A::from_f32(1e-10).unwrap();
162        total_energy = total_energy.max(epsilon);
163
164        // Clip normalized energy values to prevent sigmoid explosion
165        for time_idx in 0..seq_len {
166            // normalize energy
167            let normalized_energy = freq_energy[time_idx] / total_energy;
168
169            // Use a more stable sigmoid implementation
170            let exp_term = (-(normalized_energy - self.threshold) * self.steepness).exp();
171            let attention_weight = if exp_term.is_finite() {
172                A::one() / (A::one() + exp_term)
173            } else if (normalized_energy - self.threshold) > A::zero() {
174                A::one() // Sigmoid approaches 1 for large positive inputs
175            } else {
176                A::zero() // Sigmoid approaches 0 for large negative inputs
177            };
178
179            // Apply weight
180            windowed_input[time_idx] = Complex::new(
181                windowed_input[time_idx].re * attention_weight,
182                windowed_input[time_idx].im * attention_weight,
183            );
184        }
185
186        // Inverse FFT in-place
187        let ifft = planner.plan_fft_inverse(seq_len);
188        ifft.process(&mut windowed_input);
189
190        // Create a result array with same dimensions as input
191        let mut result = Array1::zeros(seq_len);
192        if windowed_input
193            .iter()
194            .any(|&c| c.re.is_nan() || c.im.is_nan())
195        {
196            #[cfg(feature = "tracing")]
197            tracing::warn!("The FFT/IFFT process produced NaN values.");
198        }
199        // Transfer back the processed values from frequency domain
200        for (idx, &complex) in windowed_input.iter().enumerate() {
201            // Normalize by sequence length (standard for IFFT)
202            let res = complex.re / n;
203            if res.is_nan() {
204                result[idx] = A::zero(); // Replace NaN with zero
205            } else {
206                result[idx] = res;
207            }
208        }
209
210        Ok(result)
211    }
212}
213
214impl<A, S> Forward<ArrayBase<S, Ix2>> for FftAttention<A>
215where
216    A: FftNum + Float + FromPrimitive + ScalarOperand,
217    S: Data<Elem = A>,
218{
219    type Output = Array2<A>;
220
221    fn forward(&self, input: &ArrayBase<S, Ix2>) -> cnc::Result<Self::Output> {
222        use rustfft::FftPlanner;
223        use rustfft::num_complex::Complex;
224
225        let (seq_len, feature_dim) = input.dim();
226
227        if seq_len == 0 {
228            return Err(cnc::params::ParamsError::InvalidInputShape.into());
229        }
230
231        // Create FFT planner
232        let mut planner = FftPlanner::new();
233        let fft = planner.plan_fft_forward(seq_len);
234        let mut frequency_domain = Array2::<Complex<A>>::zeros((feature_dim, seq_len));
235
236        // declare constants for windowing
237        let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
238        let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
239        // Process each feature dimension
240        for feature_idx in 0..feature_dim {
241            // Extract this feature across all timesteps
242            let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
243
244            // Apply windowing function while extracting data
245            for time_idx in 0..seq_len {
246                // Apply Hann window: 0.5 * (1 - cos(2π*i/(N-1)))
247                let window_factor = if seq_len > 1 {
248                    let i_f = A::from_usize(time_idx).unwrap();
249                    A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
250                } else {
251                    A::one()
252                };
253
254                // Get value and apply window
255                let val = input[[time_idx, feature_idx]] * window_factor;
256                windowed_input.push(Complex::new(val, A::zero()));
257            }
258
259            // Perform FFT
260            fft.process(&mut windowed_input);
261
262            // Store in frequency domain
263            for (time_idx, &val) in windowed_input.iter().enumerate() {
264                frequency_domain[[feature_idx, time_idx]] = val;
265            }
266        }
267
268        // Calculate frequency domain attention weights
269        let mut attention_weights = Array2::<A>::zeros((feature_dim, seq_len));
270
271        for fdx in 0..feature_dim {
272            // Calculate energy at each frequency
273            let mut total_energy = A::zero();
274            let mut freq_energy = Array1::<A>::zeros(seq_len);
275
276            for time_idx in 0..seq_len {
277                let val = frequency_domain[[fdx, time_idx]];
278                let energy = (val.re * val.re + val.im * val.im).sqrt();
279                freq_energy[time_idx] = energy;
280                total_energy = total_energy + energy;
281            }
282
283            // Normalize to create attention weights
284            if total_energy.is_positive() {
285                for time_idx in 0..seq_len {
286                    // Apply soft-thresholding as described in paper
287                    let normalized_energy = freq_energy[time_idx] / total_energy;
288
289                    // Using sigmoid to create attention weight with soft threshold
290                    attention_weights[[fdx, time_idx]] = A::one()
291                        / (A::one()
292                            + (-(normalized_energy - self.threshold) * self.steepness).exp());
293                }
294            }
295        }
296
297        // Apply attention weights in frequency domain
298        frequency_domain
299            .iter_mut()
300            .zip(attention_weights.iter())
301            .for_each(|(val, &w)| {
302                *val = Complex::new(val.re * w, val.im * w);
303            });
304
305        // Create a result array with same dimensions as input
306        let mut result = Array2::zeros((seq_len, feature_dim));
307
308        // Inverse FFT to get back to time domain with enhanced patterns
309        let ifft = planner.plan_fft_inverse(seq_len);
310
311        for fdx in 0..feature_dim {
312            let mut row = frequency_domain.row_mut(fdx);
313            let mut feature_slice = row.as_slice_mut().unwrap();
314
315            // Perform inverse FFT
316            ifft.process(&mut feature_slice);
317
318            // Transfer back to result array, preserving time dimension
319            for (t_idx, complex) in feature_slice.iter().enumerate() {
320                // Normalize by sequence length (standard for IFFT)
321                result[[t_idx, fdx]] = complex.re / A::from_f32(seq_len as f32).unwrap();
322            }
323        }
324
325        Ok(result)
326    }
327}