eryon_surface/model/
attention.rs

1/*
2    Appellation: attention <module>
3    Contrib: @FL03
4*/
5
6use cnc::Forward;
7use ndarray::{Array1, Array2, ArrayBase, Data, ScalarOperand};
8use num_traits::{Float, FromPrimitive};
9use rustfft::num_complex::Complex;
10use rustfft::{FftNum, FftPlanner};
11
/// Spectral attention layer for picking out temporal patterns in a sequence.
///
/// The design follows "The FFT Strikes Back: Fast and Accurate
/// Spectral-Pruning Neural Networks" (https://arxiv.org/pdf/2502.18394).
///
/// A forward pass performs four steps:
///
/// 1. move the input into the frequency domain with an FFT;
/// 2. soft-threshold every frequency bin according to its share of the total spectral
///    magnitude;
/// 3. scale each bin by the resulting weight so that dominant patterns are kept;
/// 4. move back to the time domain with the inverse FFT.
///
/// Two scalars configure the soft threshold: `threshold` is the share a bin is compared
/// against and `steepness` multiplies the distance to that threshold before the sigmoid is
/// applied, i.e. it controls how sharp the gate is.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct FftAttention<A = f32> {
    /// slope of the sigmoid gate applied to each frequency bin
    pub(crate) steepness: A,
    /// normalized magnitude at which the sigmoid gate is centred
    pub(crate) threshold: A,
}
32
33impl<A> FftAttention<A> {
34    /// Create a new attention module with the given parameters
35    pub fn new() -> Self
36    where
37        A: FromPrimitive,
38    {
39        Self {
40            steepness: A::from_f32(10.0).unwrap(),
41            threshold: A::from_f32(0.1).unwrap(),
42        }
43    }
44    /// returns an immutable reference to the steepness of the attention module
45    pub const fn steepness(&self) -> &A {
46        &self.steepness
47    }
48    /// returns a mutable reference to the steepness of the attention module to allow for
49    /// gradient descent
50    #[inline]
51    pub fn steepness_mut(&mut self) -> &mut A {
52        &mut self.steepness
53    }
54    /// returns an immutable reference to the threshold of the attention module
55    pub const fn threshold(&self) -> &A {
56        &self.threshold
57    }
58    /// returns a mutable reference to the threshold of the attention module to allow for
59    /// gradient descent
60    #[inline]
61    pub fn threshold_mut(&mut self) -> &mut A {
62        &mut self.threshold
63    }
64    /// set the steepness of the attention mechanism
65    #[inline]
66    pub fn set_steepness(&mut self, steepness: A) {
67        self.steepness = steepness;
68    }
69    /// set the threshold of the attention mechanism
70    #[inline]
71    pub fn set_threshold(&mut self, threshold: A) {
72        self.threshold = threshold;
73    }
74    /// consumes the current instance and returns another with the given steepness
75    pub fn with_steepness(self, steepness: A) -> Self {
76        Self { steepness, ..self }
77    }
78    /// consumes the current instance and returns another with the given threshold
79    pub fn with_threshold(self, threshold: A) -> Self {
80        Self { threshold, ..self }
81    }
82    #[cfg_attr(
83        feature = "tracing",
84        tracing::instrument(
85            skip(self, input),
86            name = "forward",
87            target = "attention",
88            level = "trace",
89        )
90    )]
91    pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
92    where
93        Self: Forward<X, Output = Y>,
94    {
95        <Self as Forward<X>>::forward(self, input)
96    }
97}
98
99impl<A> Default for FftAttention<A>
100where
101    A: FromPrimitive,
102{
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108impl<A, S> Forward<ArrayBase<S, ndarray::Ix1>> for FftAttention<A>
109where
110    A: FftNum + Float + FromPrimitive + ScalarOperand,
111    S: Data<Elem = A>,
112{
113    type Output = ndarray::Array1<A>;
114
115    fn forward(&self, input: &ArrayBase<S, ndarray::Ix1>) -> cnc::Result<Self::Output> {
116        let seq_len = input.dim();
117        let n = A::from_usize(seq_len).unwrap();
118
119        if seq_len == 0 {
120            return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
121        }
122
123        // Create FFT planner
124        let mut planner = FftPlanner::new();
125        let fft = planner.plan_fft_forward(seq_len);
126
127        // Simplified: directly use a 1D Vec for frequency domain
128        let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
129
130        // declare constants for windowing
131        let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
132        let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
133
134        // Apply windowing function while extracting data
135        for time_idx in 0..seq_len {
136            // Apply Hann window: 0.5 * (1 - cos(2π*i/(N-1)))
137            let window_factor = if seq_len > 1 {
138                let i_f = A::from_usize(time_idx).unwrap();
139                A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
140            } else {
141                A::one()
142            };
143
144            // Get value and apply window
145            let val = input[time_idx] * window_factor;
146            windowed_input.push(Complex::new(val, A::zero()));
147        }
148
149        // Perform FFT in-place
150        fft.process(&mut windowed_input);
151
152        // Calculate energy and total energy directly
153        let mut freq_energy = Array1::<A>::zeros(seq_len);
154        let mut total_energy = A::zero();
155        for (time_idx, &val) in windowed_input.iter().enumerate() {
156            let energy = (val.re * val.re + val.im * val.im).sqrt();
157            freq_energy[time_idx] = energy;
158            total_energy = total_energy + energy;
159        }
160
161        // Add epsilon to prevent division by zero
162        let epsilon = A::from_f32(1e-10).unwrap();
163        total_energy = total_energy.max(epsilon);
164
165        // Clip normalized energy values to prevent sigmoid explosion
166        for time_idx in 0..seq_len {
167            // normalize energy
168            let normalized_energy = freq_energy[time_idx] / total_energy;
169
170            // Use a more stable sigmoid implementation
171            let exp_term = (-(normalized_energy - self.threshold) * self.steepness).exp();
172            let attention_weight = if exp_term.is_finite() {
173                A::one() / (A::one() + exp_term)
174            } else if (normalized_energy - self.threshold) > A::zero() {
175                A::one() // Sigmoid approaches 1 for large positive inputs
176            } else {
177                A::zero() // Sigmoid approaches 0 for large negative inputs
178            };
179
180            // Apply weight
181            windowed_input[time_idx] = Complex::new(
182                windowed_input[time_idx].re * attention_weight,
183                windowed_input[time_idx].im * attention_weight,
184            );
185        }
186
187        // Inverse FFT in-place
188        let ifft = planner.plan_fft_inverse(seq_len);
189        ifft.process(&mut windowed_input);
190
191        // Create a result array with same dimensions as input
192        let mut result = Array1::zeros(seq_len);
193        if windowed_input
194            .iter()
195            .any(|&c| c.re.is_nan() || c.im.is_nan())
196        {
197            #[cfg(feature = "tracing")]
198            tracing::warn!("The FFT/IFFT process produced NaN values.");
199        }
200        // Transfer back the processed values from frequency domain
201        for (idx, &complex) in windowed_input.iter().enumerate() {
202            // Normalize by sequence length (standard for IFFT)
203            let res = complex.re / n;
204            if res.is_nan() {
205                result[idx] = A::zero(); // Replace NaN with zero
206            } else {
207                result[idx] = res;
208            }
209        }
210
211        Ok(result)
212    }
213}
214
215impl<A, S> Forward<ArrayBase<S, ndarray::Ix2>> for FftAttention<A>
216where
217    A: FftNum + Float + FromPrimitive + ScalarOperand,
218    S: Data<Elem = A>,
219{
220    type Output = ndarray::Array2<A>;
221
222    fn forward(&self, input: &ArrayBase<S, ndarray::Ix2>) -> cnc::Result<Self::Output> {
223        use rustfft::FftPlanner;
224        use rustfft::num_complex::Complex;
225
226        let (seq_len, feature_dim) = input.dim();
227
228        if seq_len == 0 {
229            return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
230        }
231
232        // Create FFT planner
233        let mut planner = FftPlanner::new();
234        let fft = planner.plan_fft_forward(seq_len);
235        let mut frequency_domain = Array2::<Complex<A>>::zeros((feature_dim, seq_len));
236
237        // declare constants for windowing
238        let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
239        let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
240        // Process each feature dimension
241        for feature_idx in 0..feature_dim {
242            // Extract this feature across all timesteps
243            let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
244
245            // Apply windowing function while extracting data
246            for time_idx in 0..seq_len {
247                // Apply Hann window: 0.5 * (1 - cos(2π*i/(N-1)))
248                let window_factor = if seq_len > 1 {
249                    let i_f = A::from_usize(time_idx).unwrap();
250                    A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
251                } else {
252                    A::one()
253                };
254
255                // Get value and apply window
256                let val = input[[time_idx, feature_idx]] * window_factor;
257                windowed_input.push(Complex::new(val, A::zero()));
258            }
259
260            // Perform FFT
261            fft.process(&mut windowed_input);
262
263            // Store in frequency domain
264            for (time_idx, &val) in windowed_input.iter().enumerate() {
265                frequency_domain[[feature_idx, time_idx]] = val;
266            }
267        }
268
269        // Calculate frequency domain attention weights
270        let mut attention_weights = Array2::<A>::zeros((feature_dim, seq_len));
271
272        for fdx in 0..feature_dim {
273            // Calculate energy at each frequency
274            let mut total_energy = A::zero();
275            let mut freq_energy = Array1::<A>::zeros(seq_len);
276
277            for time_idx in 0..seq_len {
278                let val = frequency_domain[[fdx, time_idx]];
279                let energy = (val.re * val.re + val.im * val.im).sqrt();
280                freq_energy[time_idx] = energy;
281                total_energy = total_energy + energy;
282            }
283
284            // Normalize to create attention weights
285            if total_energy.is_positive() {
286                for time_idx in 0..seq_len {
287                    // Apply soft-thresholding as described in paper
288                    let normalized_energy = freq_energy[time_idx] / total_energy;
289
290                    // Using sigmoid to create attention weight with soft threshold
291                    attention_weights[[fdx, time_idx]] = A::one()
292                        / (A::one()
293                            + (-(normalized_energy - self.threshold) * self.steepness).exp());
294                }
295            }
296        }
297
298        // Apply attention weights in frequency domain
299        frequency_domain
300            .iter_mut()
301            .zip(attention_weights.iter())
302            .for_each(|(val, &w)| {
303                *val = Complex::new(val.re * w, val.im * w);
304            });
305
306        // Create a result array with same dimensions as input
307        let mut result = Array2::zeros((seq_len, feature_dim));
308
309        // Inverse FFT to get back to time domain with enhanced patterns
310        let ifft = planner.plan_fft_inverse(seq_len);
311
312        for fdx in 0..feature_dim {
313            let mut row = frequency_domain.row_mut(fdx);
314            let feature_slice = row.as_slice_mut().unwrap();
315
316            // Perform inverse FFT
317            ifft.process(feature_slice);
318
319            // Transfer back to result array, preserving time dimension
320            for (t_idx, complex) in feature_slice.iter().enumerate() {
321                // Normalize by sequence length (standard for IFFT)
322                result[[t_idx, fdx]] = complex.re / A::from_f32(seq_len as f32).unwrap();
323            }
324        }
325
326        Ok(result)
327    }
328}