concision_neural/layers/attention/
fft.rs1use cnc::Forward;
6use ndarray::{Array1, Array2, ArrayBase, Data, Ix1, Ix2, ScalarOperand};
7use num_traits::{Float, FromPrimitive};
8use rustfft::num_complex::Complex;
9use rustfft::{FftNum, FftPlanner};
10
/// Frequency-domain attention layer with a soft (logistic) spectral gate.
///
/// The `Forward` implementations for this type transform the input with an FFT and
/// scale every frequency bin by `1 / (1 + exp(-(e - threshold) * steepness))`, where
/// `e` is that bin's magnitude divided by the total spectral magnitude.
///
/// The element type defaults to `f32`. The derived `Eq`, `Hash` and `Ord` impls only
/// exist when `A` implements those traits itself (so not for the default `f32`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct FftAttention<A = f32> {
    /// Slope of the logistic gate; larger values push the gate towards a hard cutoff.
    pub(crate) steepness: A,
    /// Normalized bin magnitude at which the gate weight is exactly `0.5`.
    pub(crate) threshold: A,
}
31
32impl<A> FftAttention<A> {
33 pub fn new() -> Self
35 where
36 A: FromPrimitive,
37 {
38 Self {
39 steepness: A::from_f32(10.0).unwrap(),
40 threshold: A::from_f32(0.1).unwrap(),
41 }
42 }
43 pub const fn steepness(&self) -> &A {
45 &self.steepness
46 }
47 #[inline]
50 pub fn steepness_mut(&mut self) -> &mut A {
51 &mut self.steepness
52 }
53 pub const fn threshold(&self) -> &A {
55 &self.threshold
56 }
57 #[inline]
60 pub fn threshold_mut(&mut self) -> &mut A {
61 &mut self.threshold
62 }
63 #[inline]
65 pub fn set_steepness(&mut self, steepness: A) {
66 self.steepness = steepness;
67 }
68 #[inline]
70 pub fn set_threshold(&mut self, threshold: A) {
71 self.threshold = threshold;
72 }
73 pub fn with_steepness(self, steepness: A) -> Self {
75 Self { steepness, ..self }
76 }
77 pub fn with_threshold(self, threshold: A) -> Self {
79 Self { threshold, ..self }
80 }
81 #[cfg_attr(
82 feature = "tracing",
83 tracing::instrument(
84 skip(self, input),
85 name = "forward",
86 target = "attention",
87 level = "trace",
88 )
89 )]
90 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
91 where
92 Self: Forward<X, Output = Y>,
93 {
94 <Self as Forward<X>>::forward(self, input)
95 }
96}
97
98impl<A> Default for FftAttention<A>
99where
100 A: FromPrimitive,
101{
102 fn default() -> Self {
103 Self::new()
104 }
105}
106
107impl<A, S> Forward<ArrayBase<S, Ix1>> for FftAttention<A>
108where
109 A: FftNum + Float + FromPrimitive + ScalarOperand,
110 S: Data<Elem = A>,
111{
112 type Output = Array1<A>;
113
114 fn forward(&self, input: &ArrayBase<S, Ix1>) -> cnc::Result<Self::Output> {
115 let seq_len = input.dim();
116 let n = A::from_usize(seq_len).unwrap();
117
118 if seq_len == 0 {
119 return Err(cnc::params::ParamsError::InvalidInputShape.into());
120 }
121
122 let mut planner = FftPlanner::new();
124 let fft = planner.plan_fft_forward(seq_len);
125
126 let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
128
129 let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
131 let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
132
133 for time_idx in 0..seq_len {
135 let window_factor = if seq_len > 1 {
137 let i_f = A::from_usize(time_idx).unwrap();
138 A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
139 } else {
140 A::one()
141 };
142
143 let val = input[time_idx] * window_factor;
145 windowed_input.push(Complex::new(val, A::zero()));
146 }
147
148 fft.process(&mut windowed_input);
150
151 let mut freq_energy = Array1::<A>::zeros(seq_len);
153 let mut total_energy = A::zero();
154 for (time_idx, &val) in windowed_input.iter().enumerate() {
155 let energy = (val.re * val.re + val.im * val.im).sqrt();
156 freq_energy[time_idx] = energy;
157 total_energy = total_energy + energy;
158 }
159
160 let epsilon = A::from_f32(1e-10).unwrap();
162 total_energy = total_energy.max(epsilon);
163
164 for time_idx in 0..seq_len {
166 let normalized_energy = freq_energy[time_idx] / total_energy;
168
169 let exp_term = (-(normalized_energy - self.threshold) * self.steepness).exp();
171 let attention_weight = if exp_term.is_finite() {
172 A::one() / (A::one() + exp_term)
173 } else if (normalized_energy - self.threshold) > A::zero() {
174 A::one() } else {
176 A::zero() };
178
179 windowed_input[time_idx] = Complex::new(
181 windowed_input[time_idx].re * attention_weight,
182 windowed_input[time_idx].im * attention_weight,
183 );
184 }
185
186 let ifft = planner.plan_fft_inverse(seq_len);
188 ifft.process(&mut windowed_input);
189
190 let mut result = Array1::zeros(seq_len);
192 if windowed_input
193 .iter()
194 .any(|&c| c.re.is_nan() || c.im.is_nan())
195 {
196 #[cfg(feature = "tracing")]
197 tracing::warn!("The FFT/IFFT process produced NaN values.");
198 }
199 for (idx, &complex) in windowed_input.iter().enumerate() {
201 let res = complex.re / n;
203 if res.is_nan() {
204 result[idx] = A::zero(); } else {
206 result[idx] = res;
207 }
208 }
209
210 Ok(result)
211 }
212}
213
214impl<A, S> Forward<ArrayBase<S, Ix2>> for FftAttention<A>
215where
216 A: FftNum + Float + FromPrimitive + ScalarOperand,
217 S: Data<Elem = A>,
218{
219 type Output = Array2<A>;
220
221 fn forward(&self, input: &ArrayBase<S, Ix2>) -> cnc::Result<Self::Output> {
222 use rustfft::FftPlanner;
223 use rustfft::num_complex::Complex;
224
225 let (seq_len, feature_dim) = input.dim();
226
227 if seq_len == 0 {
228 return Err(cnc::params::ParamsError::InvalidInputShape.into());
229 }
230
231 let mut planner = FftPlanner::new();
233 let fft = planner.plan_fft_forward(seq_len);
234 let mut frequency_domain = Array2::<Complex<A>>::zeros((feature_dim, seq_len));
235
236 let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
238 let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
239 for feature_idx in 0..feature_dim {
241 let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
243
244 for time_idx in 0..seq_len {
246 let window_factor = if seq_len > 1 {
248 let i_f = A::from_usize(time_idx).unwrap();
249 A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
250 } else {
251 A::one()
252 };
253
254 let val = input[[time_idx, feature_idx]] * window_factor;
256 windowed_input.push(Complex::new(val, A::zero()));
257 }
258
259 fft.process(&mut windowed_input);
261
262 for (time_idx, &val) in windowed_input.iter().enumerate() {
264 frequency_domain[[feature_idx, time_idx]] = val;
265 }
266 }
267
268 let mut attention_weights = Array2::<A>::zeros((feature_dim, seq_len));
270
271 for fdx in 0..feature_dim {
272 let mut total_energy = A::zero();
274 let mut freq_energy = Array1::<A>::zeros(seq_len);
275
276 for time_idx in 0..seq_len {
277 let val = frequency_domain[[fdx, time_idx]];
278 let energy = (val.re * val.re + val.im * val.im).sqrt();
279 freq_energy[time_idx] = energy;
280 total_energy = total_energy + energy;
281 }
282
283 if total_energy.is_positive() {
285 for time_idx in 0..seq_len {
286 let normalized_energy = freq_energy[time_idx] / total_energy;
288
289 attention_weights[[fdx, time_idx]] = A::one()
291 / (A::one()
292 + (-(normalized_energy - self.threshold) * self.steepness).exp());
293 }
294 }
295 }
296
297 frequency_domain
299 .iter_mut()
300 .zip(attention_weights.iter())
301 .for_each(|(val, &w)| {
302 *val = Complex::new(val.re * w, val.im * w);
303 });
304
305 let mut result = Array2::zeros((seq_len, feature_dim));
307
308 let ifft = planner.plan_fft_inverse(seq_len);
310
311 for fdx in 0..feature_dim {
312 let mut row = frequency_domain.row_mut(fdx);
313 let mut feature_slice = row.as_slice_mut().unwrap();
314
315 ifft.process(&mut feature_slice);
317
318 for (t_idx, complex) in feature_slice.iter().enumerate() {
320 result[[t_idx, fdx]] = complex.re / A::from_f32(seq_len as f32).unwrap();
322 }
323 }
324
325 Ok(result)
326 }
327}