eryon_surface/model/
attention.rs1use cnc::Forward;
7use ndarray::{Array1, Array2, ArrayBase, Data, ScalarOperand};
8use num_traits::{Float, FromPrimitive};
9use rustfft::num_complex::Complex;
10use rustfft::{FftNum, FftPlanner};
11
/// Parameters of a sigmoid gate that is applied to the frequency spectrum of a signal.
///
/// * `steepness` scales the sigmoid's argument (how sharply the gate switches).
/// * `threshold` is subtracted from each bin's normalized magnitude before scaling.
///
/// The element type `A` defaults to `f32`. Comparison and hashing traits are derived
/// generically, so they are only usable when `A` itself implements them.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct FftAttention<A = f32> {
    /// Multiplier applied to `(normalized_magnitude - threshold)` inside the sigmoid.
    pub(crate) steepness: A,
    /// Offset subtracted from the normalized magnitude of every frequency bin.
    pub(crate) threshold: A,
}
32
33impl<A> FftAttention<A> {
34 pub fn new() -> Self
36 where
37 A: FromPrimitive,
38 {
39 Self {
40 steepness: A::from_f32(10.0).unwrap(),
41 threshold: A::from_f32(0.1).unwrap(),
42 }
43 }
44 pub const fn steepness(&self) -> &A {
46 &self.steepness
47 }
48 #[inline]
51 pub fn steepness_mut(&mut self) -> &mut A {
52 &mut self.steepness
53 }
54 pub const fn threshold(&self) -> &A {
56 &self.threshold
57 }
58 #[inline]
61 pub fn threshold_mut(&mut self) -> &mut A {
62 &mut self.threshold
63 }
64 #[inline]
66 pub fn set_steepness(&mut self, steepness: A) {
67 self.steepness = steepness;
68 }
69 #[inline]
71 pub fn set_threshold(&mut self, threshold: A) {
72 self.threshold = threshold;
73 }
74 pub fn with_steepness(self, steepness: A) -> Self {
76 Self { steepness, ..self }
77 }
78 pub fn with_threshold(self, threshold: A) -> Self {
80 Self { threshold, ..self }
81 }
82 #[cfg_attr(
83 feature = "tracing",
84 tracing::instrument(
85 skip(self, input),
86 name = "forward",
87 target = "attention",
88 level = "trace",
89 )
90 )]
91 pub fn forward<X, Y>(&self, input: &X) -> cnc::Result<Y>
92 where
93 Self: Forward<X, Output = Y>,
94 {
95 <Self as Forward<X>>::forward(self, input)
96 }
97}
98
99impl<A> Default for FftAttention<A>
100where
101 A: FromPrimitive,
102{
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108impl<A, S> Forward<ArrayBase<S, ndarray::Ix1>> for FftAttention<A>
109where
110 A: FftNum + Float + FromPrimitive + ScalarOperand,
111 S: Data<Elem = A>,
112{
113 type Output = ndarray::Array1<A>;
114
115 fn forward(&self, input: &ArrayBase<S, ndarray::Ix1>) -> cnc::Result<Self::Output> {
116 let seq_len = input.dim();
117 let n = A::from_usize(seq_len).unwrap();
118
119 if seq_len == 0 {
120 return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
121 }
122
123 let mut planner = FftPlanner::new();
125 let fft = planner.plan_fft_forward(seq_len);
126
127 let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
129
130 let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
132 let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
133
134 for time_idx in 0..seq_len {
136 let window_factor = if seq_len > 1 {
138 let i_f = A::from_usize(time_idx).unwrap();
139 A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
140 } else {
141 A::one()
142 };
143
144 let val = input[time_idx] * window_factor;
146 windowed_input.push(Complex::new(val, A::zero()));
147 }
148
149 fft.process(&mut windowed_input);
151
152 let mut freq_energy = Array1::<A>::zeros(seq_len);
154 let mut total_energy = A::zero();
155 for (time_idx, &val) in windowed_input.iter().enumerate() {
156 let energy = (val.re * val.re + val.im * val.im).sqrt();
157 freq_energy[time_idx] = energy;
158 total_energy = total_energy + energy;
159 }
160
161 let epsilon = A::from_f32(1e-10).unwrap();
163 total_energy = total_energy.max(epsilon);
164
165 for time_idx in 0..seq_len {
167 let normalized_energy = freq_energy[time_idx] / total_energy;
169
170 let exp_term = (-(normalized_energy - self.threshold) * self.steepness).exp();
172 let attention_weight = if exp_term.is_finite() {
173 A::one() / (A::one() + exp_term)
174 } else if (normalized_energy - self.threshold) > A::zero() {
175 A::one() } else {
177 A::zero() };
179
180 windowed_input[time_idx] = Complex::new(
182 windowed_input[time_idx].re * attention_weight,
183 windowed_input[time_idx].im * attention_weight,
184 );
185 }
186
187 let ifft = planner.plan_fft_inverse(seq_len);
189 ifft.process(&mut windowed_input);
190
191 let mut result = Array1::zeros(seq_len);
193 if windowed_input
194 .iter()
195 .any(|&c| c.re.is_nan() || c.im.is_nan())
196 {
197 #[cfg(feature = "tracing")]
198 tracing::warn!("The FFT/IFFT process produced NaN values.");
199 }
200 for (idx, &complex) in windowed_input.iter().enumerate() {
202 let res = complex.re / n;
204 if res.is_nan() {
205 result[idx] = A::zero(); } else {
207 result[idx] = res;
208 }
209 }
210
211 Ok(result)
212 }
213}
214
215impl<A, S> Forward<ArrayBase<S, ndarray::Ix2>> for FftAttention<A>
216where
217 A: FftNum + Float + FromPrimitive + ScalarOperand,
218 S: Data<Elem = A>,
219{
220 type Output = ndarray::Array2<A>;
221
222 fn forward(&self, input: &ArrayBase<S, ndarray::Ix2>) -> cnc::Result<Self::Output> {
223 use rustfft::FftPlanner;
224 use rustfft::num_complex::Complex;
225
226 let (seq_len, feature_dim) = input.dim();
227
228 if seq_len == 0 {
229 return Err(cnc::params::error::ParamsError::InvalidInputShape.into());
230 }
231
232 let mut planner = FftPlanner::new();
234 let fft = planner.plan_fft_forward(seq_len);
235 let mut frequency_domain = Array2::<Complex<A>>::zeros((feature_dim, seq_len));
236
237 let n_minus_1 = A::from_usize(seq_len - 1).unwrap();
239 let pi2 = A::from_f32(2.0 * core::f32::consts::PI).unwrap();
240 for feature_idx in 0..feature_dim {
242 let mut windowed_input: Vec<Complex<A>> = Vec::with_capacity(seq_len);
244
245 for time_idx in 0..seq_len {
247 let window_factor = if seq_len > 1 {
249 let i_f = A::from_usize(time_idx).unwrap();
250 A::from_f32(0.5).unwrap() * (A::one() - (pi2 * i_f / n_minus_1).cos())
251 } else {
252 A::one()
253 };
254
255 let val = input[[time_idx, feature_idx]] * window_factor;
257 windowed_input.push(Complex::new(val, A::zero()));
258 }
259
260 fft.process(&mut windowed_input);
262
263 for (time_idx, &val) in windowed_input.iter().enumerate() {
265 frequency_domain[[feature_idx, time_idx]] = val;
266 }
267 }
268
269 let mut attention_weights = Array2::<A>::zeros((feature_dim, seq_len));
271
272 for fdx in 0..feature_dim {
273 let mut total_energy = A::zero();
275 let mut freq_energy = Array1::<A>::zeros(seq_len);
276
277 for time_idx in 0..seq_len {
278 let val = frequency_domain[[fdx, time_idx]];
279 let energy = (val.re * val.re + val.im * val.im).sqrt();
280 freq_energy[time_idx] = energy;
281 total_energy = total_energy + energy;
282 }
283
284 if total_energy.is_positive() {
286 for time_idx in 0..seq_len {
287 let normalized_energy = freq_energy[time_idx] / total_energy;
289
290 attention_weights[[fdx, time_idx]] = A::one()
292 / (A::one()
293 + (-(normalized_energy - self.threshold) * self.steepness).exp());
294 }
295 }
296 }
297
298 frequency_domain
300 .iter_mut()
301 .zip(attention_weights.iter())
302 .for_each(|(val, &w)| {
303 *val = Complex::new(val.re * w, val.im * w);
304 });
305
306 let mut result = Array2::zeros((seq_len, feature_dim));
308
309 let ifft = planner.plan_fft_inverse(seq_len);
311
312 for fdx in 0..feature_dim {
313 let mut row = frequency_domain.row_mut(fdx);
314 let feature_slice = row.as_slice_mut().unwrap();
315
316 ifft.process(feature_slice);
318
319 for (t_idx, complex) in feature_slice.iter().enumerate() {
321 result[[t_idx, fdx]] = complex.re / A::from_f32(seq_len as f32).unwrap();
323 }
324 }
325
326 Ok(result)
327 }
328}