omega_attention/
controller.rs

//! Attention Controller
//!
//! Implements brain-like attention control with:
//! - Top-down (goal-driven) attention
//! - Bottom-up (salience-driven) attention
//! - Priority map combining both signals
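//!
//! A minimal usage sketch (the `omega_attention::controller` path is assumed
//! from the file layout; only APIs defined in this module are used):
//!
//! ```ignore
//! use omega_attention::controller::{AttentionConfig, AttentionController};
//!
//! let mut controller = AttentionController::new(AttentionConfig::default());
//!
//! // Top-down: set a goal-driven focus target.
//! controller.set_focus(&[1.0, 0.0, 0.0, 0.0]);
//!
//! // Combine bottom-up salience with top-down relevance into a priority map.
//! let salience = vec![0.5, 0.3, 0.2];
//! let relevance = vec![0.2, 0.5, 0.3];
//! let priorities = controller.combine_priorities(&salience, &relevance);
//! let winner = priorities.argmax();
//! ```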

use serde::{Deserialize, Serialize};
use std::collections::VecDeque;

use crate::mechanisms::{
    AttentionMechanism, AttentionOutput, AttentionType,
    ScaledDotProductAttention, LinearAttention, SparseAttention,
};
use crate::Result;

/// Configuration for attention controller
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AttentionConfig {
    /// Dimension of attention vectors
    pub dim: usize,
    /// Weight for top-down attention (0.0 to 1.0)
    pub top_down_weight: f64,
    /// Weight for bottom-up attention (0.0 to 1.0)
    pub bottom_up_weight: f64,
    /// Number of attention heads
    pub num_heads: usize,
    /// Default attention mechanism type
    pub default_mechanism: AttentionType,
    /// Attention focus decay rate
    pub focus_decay: f64,
}

impl Default for AttentionConfig {
    fn default() -> Self {
        Self {
            dim: 64,
            top_down_weight: 0.6,
            bottom_up_weight: 0.4,
            num_heads: 4,
            default_mechanism: AttentionType::ScaledDotProduct,
            focus_decay: 0.1,
        }
    }
}

/// Priority map combining top-down and bottom-up signals
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PriorityMap {
    /// Priority values
    pub priorities: Vec<f64>,
    /// Top-down contribution
    pub top_down: Vec<f64>,
    /// Bottom-up contribution
    pub bottom_up: Vec<f64>,
    /// Combined priority
    pub combined: Vec<f64>,
}

impl PriorityMap {
    pub fn new(size: usize) -> Self {
        Self {
            priorities: vec![0.0; size],
            top_down: vec![0.0; size],
            bottom_up: vec![0.0; size],
            combined: vec![0.0; size],
        }
    }

    /// Get highest priority index (ties resolve to the last maximum)
    pub fn argmax(&self) -> usize {
        self.combined
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(i, _)| i)
            .unwrap_or(0)
    }

    /// Get top-k priority indices, sorted by descending priority
    pub fn top_k(&self, k: usize) -> Vec<usize> {
        let mut indexed: Vec<_> = self.combined.iter().enumerate().collect();
        indexed.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
        indexed.into_iter().take(k).map(|(i, _)| i).collect()
    }

    /// Normalize priorities so that `combined` sums to 1
    pub fn normalize(&mut self) {
        let sum: f64 = self.combined.iter().sum();
        if sum > 0.0 {
            for p in &mut self.combined {
                *p /= sum;
            }
        }
    }
}

/// The attention controller
pub struct AttentionController {
    config: AttentionConfig,
    /// Current focus target (for top-down)
    focus: Option<Vec<f64>>,
    /// Focus history
    focus_history: VecDeque<Vec<f64>>,
    /// Attention mechanism
    mechanism: Box<dyn AttentionMechanism>,
    /// Recent attention outputs for inhibition of return
    recent_attended: VecDeque<usize>,
}

impl AttentionController {
    pub fn new(config: AttentionConfig) -> Self {
        let mechanism: Box<dyn AttentionMechanism> = match config.default_mechanism {
            AttentionType::Linear => Box::new(LinearAttention::new(config.dim)),
            AttentionType::Sparse => Box::new(SparseAttention::new(config.dim, 16, 4)),
            _ => Box::new(ScaledDotProductAttention::new(config.dim)),
        };

        Self {
            config,
            focus: None,
            focus_history: VecDeque::with_capacity(10),
            mechanism,
            recent_attended: VecDeque::with_capacity(5),
        }
    }

    /// Set top-down focus target
    pub fn set_focus(&mut self, target: &[f64]) {
        if let Some(old_focus) = self.focus.take() {
            self.focus_history.push_back(old_focus);
            if self.focus_history.len() > 10 {
                self.focus_history.pop_front();
            }
        }
        self.focus = Some(target.to_vec());
    }

    /// Clear focus
    pub fn clear_focus(&mut self) {
        self.focus = None;
    }

    /// Get current focus
    pub fn current_focus(&self) -> Option<Vec<f64>> {
        self.focus.clone()
    }

    /// Compute top-down relevance based on current goals
    pub fn compute_relevance(&self, input: &[f64], goals: &[f64]) -> Vec<f64> {
        let n = input.len() / self.config.dim;
        let mut relevance = vec![0.0; n];

        // Compute similarity to goals
        for (i, rel) in relevance.iter_mut().enumerate().take(n) {
            let start = i * self.config.dim;
            let end = (start + self.config.dim).min(input.len());
            let item = &input[start..end];

            // Cosine similarity to goals
            let mut dot = 0.0;
            let mut norm_item = 0.0;
            let mut norm_goals = 0.0;

            for (j, &x) in item.iter().enumerate() {
                if let Some(&g) = goals.get(j) {
                    dot += x * g;
                }
                norm_item += x * x;
            }
            for &g in goals.iter().take(self.config.dim) {
                norm_goals += g * g;
            }

            norm_item = norm_item.sqrt();
            norm_goals = norm_goals.sqrt();

            if norm_item > 0.0 && norm_goals > 0.0 {
                *rel = (dot / (norm_item * norm_goals) + 1.0) / 2.0; // Normalize to [0,1]
            }

            // Boost if matches current focus
            if let Some(ref focus) = self.focus {
                let focus_sim = Self::cosine_similarity(item, focus);
                *rel = (*rel + focus_sim) / 2.0;
            }
        }

        relevance
    }

    /// Combine top-down and bottom-up priorities
    pub fn combine_priorities(&self, salience: &[f64], relevance: &[f64]) -> PriorityMap {
        let n = salience.len().min(relevance.len());
        let mut map = PriorityMap::new(n);

        map.bottom_up = salience[..n].to_vec();
        map.top_down = relevance[..n].to_vec();

        // Weighted combination
        for i in 0..n {
            let td = self.config.top_down_weight * relevance[i];
            let bu = self.config.bottom_up_weight * salience[i];
            map.combined[i] = td + bu;

            // Inhibition of return: reduce priority of recently attended
            if self.recent_attended.contains(&i) {
                map.combined[i] *= 0.5;
            }
        }

        map.normalize();
        map
    }

    /// Apply attention mechanism with priorities
    pub fn apply_attention(
        &mut self,
        input: &[f64],
        priority: &PriorityMap,
        context: &[f64],
    ) -> Result<AttentionOutput> {
        // Build a binary attention mask from the priorities: attend only to
        // items whose priority exceeds 10% of the uniform level (1/n)
        let n = priority.combined.len();
        let threshold = 0.1 / n as f64;
        let mask: Vec<bool> = priority.combined.iter().map(|&p| p > threshold).collect();

        // Query is a blend of focus and context
        let query = if let Some(ref focus) = self.focus {
            // Blend focus with context
            let mut q = vec![0.0; self.config.dim];
            for (i, q_val) in q.iter_mut().enumerate().take(self.config.dim) {
                let f = focus.get(i).copied().unwrap_or(0.0);
                let c = context.get(i).copied().unwrap_or(0.0);
                *q_val = 0.7 * f + 0.3 * c; // Focus-weighted query
            }
            q
        } else {
            context.to_vec()
        };

        // Apply attention
        let output = self.mechanism.compute(&query, input, input, Some(&mask));

        // Update inhibition of return
        self.recent_attended.push_back(output.max_index);
        if self.recent_attended.len() > 5 {
            self.recent_attended.pop_front();
        }

        Ok(output)
    }

    /// Decay focus over time
    pub fn decay_focus(&mut self) {
        if let Some(ref mut focus) = self.focus {
            for f in focus.iter_mut() {
                *f *= 1.0 - self.config.focus_decay;
            }

            // Clear if decayed to near-zero
            let norm: f64 = focus.iter().map(|x| x * x).sum::<f64>().sqrt();
            if norm < 0.01 {
                self.focus = None;
            }
        }
    }

    /// Switch attention mechanism
    pub fn set_mechanism(&mut self, mechanism_type: AttentionType) {
        self.mechanism = match mechanism_type {
            AttentionType::Linear => Box::new(LinearAttention::new(self.config.dim)),
            AttentionType::Sparse => Box::new(SparseAttention::new(self.config.dim, 16, 4)),
            _ => Box::new(ScaledDotProductAttention::new(self.config.dim)),
        };
    }

    /// Helper: cosine similarity, rescaled to [0, 1]
    fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
        let mut dot = 0.0;
        let mut norm_a = 0.0;
        let mut norm_b = 0.0;

        for (&x, &y) in a.iter().zip(b.iter()) {
            dot += x * y;
            norm_a += x * x;
            norm_b += y * y;
        }

        norm_a = norm_a.sqrt();
        norm_b = norm_b.sqrt();

        if norm_a > 0.0 && norm_b > 0.0 {
            (dot / (norm_a * norm_b) + 1.0) / 2.0
        } else {
            // Neutral similarity when either vector is all-zero
            0.5
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_attention_controller_creation() {
        let config = AttentionConfig::default();
        let controller = AttentionController::new(config);

        assert!(controller.current_focus().is_none());
    }

    #[test]
    fn test_set_focus() {
        let config = AttentionConfig::default();
        let mut controller = AttentionController::new(config);

        let target = vec![1.0, 0.0, 0.0, 0.0];
        controller.set_focus(&target);

        assert!(controller.current_focus().is_some());
    }
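
    // A small sketch exercising `clear_focus`: after clearing, no focus should
    // remain (only APIs defined in this module are used).
    #[test]
    fn test_clear_focus() {
        let config = AttentionConfig::default();
        let mut controller = AttentionController::new(config);

        controller.set_focus(&[1.0, 0.0, 0.0, 0.0]);
        controller.clear_focus();

        assert!(controller.current_focus().is_none());
    }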

    #[test]
    fn test_priority_map() {
        let mut map = PriorityMap::new(5);
        map.combined = vec![0.1, 0.3, 0.2, 0.1, 0.3];

        assert_eq!(map.argmax(), 4); // max_by returns last max in case of ties

        let top2 = map.top_k(2);
        assert_eq!(top2.len(), 2);
    }
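
    // Sketches of `PriorityMap` ordering and normalization: `top_k` should
    // return indices in descending priority order, and `normalize` should
    // rescale `combined` to sum to 1.
    #[test]
    fn test_top_k_ordering() {
        let mut map = PriorityMap::new(4);
        map.combined = vec![0.1, 0.4, 0.2, 0.3];

        let top2 = map.top_k(2);
        assert_eq!(top2[0], 1);
        assert_eq!(top2[1], 3);
    }

    #[test]
    fn test_normalize() {
        let mut map = PriorityMap::new(3);
        map.combined = vec![2.0, 1.0, 1.0];
        map.normalize();

        assert!((map.combined.iter().sum::<f64>() - 1.0).abs() < 1e-9);
        assert!((map.combined[0] - 0.5).abs() < 1e-9);
    }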

    #[test]
    fn test_combine_priorities() {
        let config = AttentionConfig::default();
        let controller = AttentionController::new(config);

        let salience = vec![0.5, 0.3, 0.2];
        let relevance = vec![0.2, 0.5, 0.3];

        let map = controller.combine_priorities(&salience, &relevance);

        assert_eq!(map.combined.len(), 3);
        assert!((map.combined.iter().sum::<f64>() - 1.0).abs() < 0.01);
    }
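
    // A sketch of `compute_relevance` under the default config (dim = 64, as
    // set in `AttentionConfig::default`): an item aligned with the goal vector
    // should score higher than one opposing it, and scores stay within [0, 1].
    #[test]
    fn test_compute_relevance() {
        let config = AttentionConfig::default();
        let controller = AttentionController::new(config);

        // Two items of dimension 64: the first matches the goal, the second opposes it.
        let mut input = vec![0.0; 128];
        input[0] = 1.0;
        input[64] = -1.0;
        let mut goals = vec![0.0; 64];
        goals[0] = 1.0;

        let relevance = controller.compute_relevance(&input, &goals);

        assert_eq!(relevance.len(), 2);
        assert!(relevance.iter().all(|&r| (0.0..=1.0).contains(&r)));
        assert!(relevance[0] > relevance[1]);
    }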

    #[test]
    fn test_focus_decay() {
        let mut config = AttentionConfig::default();
        config.focus_decay = 0.5;

        let mut controller = AttentionController::new(config);
        controller.set_focus(&[1.0, 1.0, 1.0, 1.0]);

        controller.decay_focus();

        let focus = controller.current_focus().unwrap();
        assert!(focus[0] < 1.0);
    }
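
    // A sketch of the decay threshold: once the decayed focus norm falls below
    // the 0.01 cutoff used in `decay_focus`, the focus should be cleared.
    #[test]
    fn test_focus_decay_clears_weak_focus() {
        let mut config = AttentionConfig::default();
        config.focus_decay = 0.99;

        let mut controller = AttentionController::new(config);
        controller.set_focus(&[0.1, 0.1, 0.1, 0.1]);

        controller.decay_focus();

        assert!(controller.current_focus().is_none());
    }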
}