1use serde::{Deserialize, Serialize};
9use std::collections::VecDeque;
10
11use crate::mechanisms::{
12 AttentionMechanism, AttentionOutput, AttentionType,
13 ScaledDotProductAttention, LinearAttention, SparseAttention,
14};
15use crate::Result;
16
/// Tunable parameters for the attention controller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Dimensionality of each item/embedding vector in the input buffer.
    pub dim: usize,
    /// Weight of goal-driven (top-down) relevance when combining priorities.
    pub top_down_weight: f64,
    /// Weight of stimulus-driven (bottom-up) salience when combining priorities.
    pub bottom_up_weight: f64,
    /// Number of attention heads.
    /// NOTE(review): not read anywhere in this file — confirm it is consumed elsewhere.
    pub num_heads: usize,
    /// Mechanism the controller instantiates on construction.
    pub default_mechanism: AttentionType,
    /// Per-step multiplicative decay of the focus vector (focus *= 1 - focus_decay).
    pub focus_decay: f64,
}
33
34impl Default for AttentionConfig {
35 fn default() -> Self {
36 Self {
37 dim: 64,
38 top_down_weight: 0.6,
39 bottom_up_weight: 0.4,
40 num_heads: 4,
41 default_mechanism: AttentionType::ScaledDotProduct,
42 focus_decay: 0.1,
43 }
44 }
45}
46
/// Per-item attention priorities, keeping the goal-driven and
/// stimulus-driven components alongside their weighted combination.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriorityMap {
    /// Raw priority slots.
    /// NOTE(review): never written by the code in this file — confirm purpose.
    pub priorities: Vec<f64>,
    /// Goal-driven (relevance) component, one entry per item.
    pub top_down: Vec<f64>,
    /// Stimulus-driven (salience) component, one entry per item.
    pub bottom_up: Vec<f64>,
    /// Weighted blend of `top_down` and `bottom_up`; normalized to sum to 1
    /// when the sum is positive.
    pub combined: Vec<f64>,
}
59
60impl PriorityMap {
61 pub fn new(size: usize) -> Self {
62 Self {
63 priorities: vec![0.0; size],
64 top_down: vec![0.0; size],
65 bottom_up: vec![0.0; size],
66 combined: vec![0.0; size],
67 }
68 }
69
70 pub fn argmax(&self) -> usize {
72 self.combined
73 .iter()
74 .enumerate()
75 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
76 .map(|(i, _)| i)
77 .unwrap_or(0)
78 }
79
80 pub fn top_k(&self, k: usize) -> Vec<usize> {
82 let mut indexed: Vec<_> = self.combined.iter().enumerate().collect();
83 indexed.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
84 indexed.into_iter().take(k).map(|(i, _)| i).collect()
85 }
86
87 pub fn normalize(&mut self) {
89 let sum: f64 = self.combined.iter().sum();
90 if sum > 0.0 {
91 for p in &mut self.combined {
92 *p /= sum;
93 }
94 }
95 }
96}
97
/// Coordinates top-down (goal/focus-driven) and bottom-up (salience-driven)
/// attention over a flat buffer of `dim`-sized items.
pub struct AttentionController {
    // Static configuration: dimensions, combination weights, mechanism choice.
    config: AttentionConfig,
    // Current focus vector, if any; biases relevance scores and the query
    // built in `apply_attention`.
    focus: Option<Vec<f64>>,
    // Previously held focus vectors, most recent last, capped at 10 entries.
    focus_history: VecDeque<Vec<f64>>,
    // Pluggable attention mechanism invoked by `apply_attention`.
    mechanism: Box<dyn AttentionMechanism>,
    // Indices attended recently (capped at 5); `combine_priorities` halves
    // their priority so attention can move on.
    recent_attended: VecDeque<usize>,
}
110
111impl AttentionController {
112 pub fn new(config: AttentionConfig) -> Self {
113 let mechanism: Box<dyn AttentionMechanism> = match config.default_mechanism {
114 AttentionType::Linear => Box::new(LinearAttention::new(config.dim)),
115 AttentionType::Sparse => Box::new(SparseAttention::new(config.dim, 16, 4)),
116 _ => Box::new(ScaledDotProductAttention::new(config.dim)),
117 };
118
119 Self {
120 config,
121 focus: None,
122 focus_history: VecDeque::with_capacity(10),
123 mechanism,
124 recent_attended: VecDeque::with_capacity(5),
125 }
126 }
127
128 pub fn set_focus(&mut self, target: &[f64]) {
130 if let Some(old_focus) = self.focus.take() {
131 self.focus_history.push_back(old_focus);
132 if self.focus_history.len() > 10 {
133 self.focus_history.pop_front();
134 }
135 }
136 self.focus = Some(target.to_vec());
137 }
138
139 pub fn clear_focus(&mut self) {
141 self.focus = None;
142 }
143
144 pub fn current_focus(&self) -> Option<Vec<f64>> {
146 self.focus.clone()
147 }
148
149 pub fn compute_relevance(&self, input: &[f64], goals: &[f64]) -> Vec<f64> {
151 let n = input.len() / self.config.dim;
152 let mut relevance = vec![0.0; n];
153
154 for (i, rel) in relevance.iter_mut().enumerate().take(n) {
156 let start = i * self.config.dim;
157 let end = (start + self.config.dim).min(input.len());
158 let item = &input[start..end];
159
160 let mut dot = 0.0;
162 let mut norm_item = 0.0;
163 let mut norm_goals = 0.0;
164
165 for (j, &x) in item.iter().enumerate() {
166 if let Some(&g) = goals.get(j) {
167 dot += x * g;
168 }
169 norm_item += x * x;
170 }
171 for &g in goals.iter().take(self.config.dim) {
172 norm_goals += g * g;
173 }
174
175 norm_item = norm_item.sqrt();
176 norm_goals = norm_goals.sqrt();
177
178 if norm_item > 0.0 && norm_goals > 0.0 {
179 *rel = (dot / (norm_item * norm_goals) + 1.0) / 2.0; }
181
182 if let Some(ref focus) = self.focus {
184 let focus_sim = Self::cosine_similarity(item, focus);
185 *rel = (*rel + focus_sim) / 2.0;
186 }
187 }
188
189 relevance
190 }
191
192 pub fn combine_priorities(&self, salience: &[f64], relevance: &[f64]) -> PriorityMap {
194 let n = salience.len().min(relevance.len());
195 let mut map = PriorityMap::new(n);
196
197 map.bottom_up = salience[..n].to_vec();
198 map.top_down = relevance[..n].to_vec();
199
200 for i in 0..n {
202 let td = self.config.top_down_weight * relevance[i];
203 let bu = self.config.bottom_up_weight * salience[i];
204 map.combined[i] = td + bu;
205
206 if self.recent_attended.contains(&i) {
208 map.combined[i] *= 0.5;
209 }
210 }
211
212 map.normalize();
213 map
214 }
215
216 pub fn apply_attention(
218 &mut self,
219 input: &[f64],
220 priority: &PriorityMap,
221 context: &[f64],
222 ) -> Result<AttentionOutput> {
223 let n = priority.combined.len();
225 let threshold = 0.1 / n as f64; let mask: Vec<bool> = priority.combined.iter().map(|&p| p > threshold).collect();
227
228 let query = if let Some(ref focus) = self.focus {
230 let mut q = vec![0.0; self.config.dim];
232 for (i, q_val) in q.iter_mut().enumerate().take(self.config.dim) {
233 let f = focus.get(i).copied().unwrap_or(0.0);
234 let c = context.get(i).copied().unwrap_or(0.0);
235 *q_val = 0.7 * f + 0.3 * c; }
237 q
238 } else {
239 context.to_vec()
240 };
241
242 let output = self.mechanism.compute(&query, input, input, Some(&mask));
244
245 self.recent_attended.push_back(output.max_index);
247 if self.recent_attended.len() > 5 {
248 self.recent_attended.pop_front();
249 }
250
251 Ok(output)
252 }
253
254 pub fn decay_focus(&mut self) {
256 if let Some(ref mut focus) = self.focus {
257 for f in focus.iter_mut() {
258 *f *= 1.0 - self.config.focus_decay;
259 }
260
261 let norm: f64 = focus.iter().map(|x| x * x).sum::<f64>().sqrt();
263 if norm < 0.01 {
264 self.focus = None;
265 }
266 }
267 }
268
269 pub fn set_mechanism(&mut self, mechanism_type: AttentionType) {
271 self.mechanism = match mechanism_type {
272 AttentionType::Linear => Box::new(LinearAttention::new(self.config.dim)),
273 AttentionType::Sparse => Box::new(SparseAttention::new(self.config.dim, 16, 4)),
274 _ => Box::new(ScaledDotProductAttention::new(self.config.dim)),
275 };
276 }
277
278 fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
280 let mut dot = 0.0;
281 let mut norm_a = 0.0;
282 let mut norm_b = 0.0;
283
284 for (&x, &y) in a.iter().zip(b.iter()) {
285 dot += x * y;
286 norm_a += x * x;
287 norm_b += y * y;
288 }
289
290 norm_a = norm_a.sqrt();
291 norm_b = norm_b.sqrt();
292
293 if norm_a > 0.0 && norm_b > 0.0 {
294 (dot / (norm_a * norm_b) + 1.0) / 2.0
295 } else {
296 0.5
297 }
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_controller_creation() {
        // A fresh controller starts with no focus set.
        let controller = AttentionController::new(AttentionConfig::default());
        assert!(controller.current_focus().is_none());
    }

    #[test]
    fn test_set_focus() {
        let mut controller = AttentionController::new(AttentionConfig::default());

        // Slice literal instead of `&vec![..]` (clippy: useless_vec).
        let target = [1.0, 0.0, 0.0, 0.0];
        controller.set_focus(&target);

        // set_focus stores a copy of the target verbatim.
        assert_eq!(controller.current_focus(), Some(target.to_vec()));
    }

    #[test]
    fn test_priority_map() {
        let mut map = PriorityMap::new(5);
        map.combined = vec![0.1, 0.3, 0.2, 0.1, 0.3];

        // Ties resolve to the last maximal index (Iterator::max_by semantics).
        assert_eq!(map.argmax(), 4);

        let top2 = map.top_k(2);
        assert_eq!(top2.len(), 2);
        // The two 0.3 entries must be the two best.
        assert!(top2.contains(&1) && top2.contains(&4));
    }

    #[test]
    fn test_combine_priorities() {
        let controller = AttentionController::new(AttentionConfig::default());

        let salience = [0.5, 0.3, 0.2];
        let relevance = [0.2, 0.5, 0.3];

        let map = controller.combine_priorities(&salience, &relevance);

        assert_eq!(map.combined.len(), 3);
        // combine_priorities normalizes, so the result sums to ~1.
        assert!((map.combined.iter().sum::<f64>() - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_focus_decay() {
        // Struct-update syntax instead of mutating a Default in place.
        let config = AttentionConfig {
            focus_decay: 0.5,
            ..AttentionConfig::default()
        };

        let mut controller = AttentionController::new(config);
        controller.set_focus(&[1.0, 1.0, 1.0, 1.0]);

        controller.decay_focus();

        // One decay step at rate 0.5 halves every component exactly.
        let focus = controller.current_focus().unwrap();
        assert!(focus.iter().all(|&f| (f - 0.5).abs() < 1e-12));
    }
}