omega_brain/attention_system.rs

//! Attention System - Self-contained attention mechanisms

use crate::{BrainConfig, Result};
use serde::{Deserialize, Serialize};

/// Attention mechanism types
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionType { ScaledDotProduct, Flash, Linear, Sparse, Hyperbolic, Graph }

/// Working memory
#[derive(Debug, Clone)]
pub struct WorkingMemory {
    items: Vec<(Vec<f64>, f64)>,
    capacity: usize,
    dim: usize,
}

impl WorkingMemory {
    pub fn new(capacity: usize, dim: usize) -> Self {
        Self { items: Vec::with_capacity(capacity), capacity, dim }
    }
    pub fn dim(&self) -> usize { self.dim }
    pub fn store(&mut self, content: &[f64]) {
        // At capacity: evict the weakest item before storing the new one.
        // total_cmp avoids a panic if a strength is ever NaN.
        if self.items.len() >= self.capacity {
            self.items.sort_by(|a, b| a.1.total_cmp(&b.1));
            self.items.remove(0);
        }
        self.items.push((content.to_vec(), 1.0));
    }
    pub fn gate(&self, input: &[f64]) -> Vec<f64> {
        // Blend a small fraction of each stored item into the input.
        let mut result = input.to_vec();
        for (item, strength) in &self.items {
            for (i, &v) in item.iter().enumerate() {
                if i < result.len() { result[i] += v * strength * 0.1; }
            }
        }
        result
    }
    pub fn contents(&self) -> Vec<Vec<f64>> { self.items.iter().map(|(v, _)| v.clone()).collect() }
    pub fn clear(&mut self) { self.items.clear(); }
}

/// Attention controller
#[derive(Debug, Clone)]
pub struct AttentionController {
    num_heads: usize,
    dim: usize,
    top_down: f64,
    bottom_up: f64,
    bias: Vec<f64>,
    mechanism: AttentionType,
}

impl AttentionController {
    pub fn new(num_heads: usize, dim: usize, top_down: f64, bottom_up: f64) -> Self {
        Self { num_heads, dim, top_down, bottom_up, bias: vec![0.0; dim], mechanism: AttentionType::ScaledDotProduct }
    }
    pub fn attend(&self, input: &[f64]) -> Vec<f64> {
        // Mix bottom-up input with the top-down bias, squashed through tanh.
        (0..self.dim).map(|i| {
            let v = input.get(i).copied().unwrap_or(0.0);
            let b = self.bias.get(i).copied().unwrap_or(0.0);
            (v * self.bottom_up + b * self.top_down).tanh()
        }).collect()
    }
    pub fn set_top_down_bias(&mut self, bias: &[f64]) {
        for (i, &b) in bias.iter().enumerate() { if i < self.dim { self.bias[i] = b; } }
    }
    pub fn set_mechanism(&mut self, m: AttentionType) { self.mechanism = m; }
    pub fn num_heads(&self) -> usize { self.num_heads }
    pub fn top_down_strength(&self) -> f64 { self.top_down }
    pub fn bottom_up_strength(&self) -> f64 { self.bottom_up }
}

/// Attention system
pub struct AttentionSystem {
    controller: AttentionController,
    working_memory: WorkingMemory,
    current_focus: Vec<f64>,
    dim: usize,
}

impl AttentionSystem {
    pub fn new(config: &BrainConfig) -> Self {
        Self {
            controller: AttentionController::new(config.attention_heads, config.attention_dim, config.top_down_strength, config.bottom_up_strength),
            working_memory: WorkingMemory::new(config.workspace_capacity, config.attention_dim),
            current_focus: vec![0.0; config.attention_dim],
            dim: config.attention_dim,
        }
    }
    pub fn attend(&mut self, input: &[f64]) -> Result<Vec<f64>> {
        // Pad or truncate the input to the attention dimension.
        let normalized: Vec<f64> = (0..self.dim).map(|i| input.get(i).copied().unwrap_or(0.0)).collect();
        let attended = self.controller.attend(&normalized);
        let gated = self.working_memory.gate(&attended);
        // Track focus as an exponential moving average of the gated output.
        for (i, &v) in gated.iter().enumerate() {
            if i < self.current_focus.len() { self.current_focus[i] = 0.8 * self.current_focus[i] + 0.2 * v; }
        }
        Ok(gated)
    }
    pub fn focus_on(&mut self, target: &[f64]) -> Result<()> {
        let normalized: Vec<f64> = (0..self.dim).map(|i| target.get(i).copied().unwrap_or(0.0)).collect();
        self.controller.set_top_down_bias(&normalized);
        self.current_focus = normalized;
        Ok(())
    }
    pub fn current_focus(&self) -> Vec<f64> { self.current_focus.clone() }
    pub fn remember(&mut self, item: &[f64]) { self.working_memory.store(item); }
    pub fn working_memory_contents(&self) -> Vec<Vec<f64>> { self.working_memory.contents() }
    pub fn clear_working_memory(&mut self) { self.working_memory.clear(); }
    pub fn attention_strength(&self) -> f64 {
        // Peak-to-mean ratio of the focus vector, scaled into [0, 1].
        let max = self.current_focus.iter().map(|x| x.abs()).fold(0.0, f64::max);
        let mean = self.current_focus.iter().map(|x| x.abs()).sum::<f64>() / self.current_focus.len().max(1) as f64;
        if mean > 0.0 { (max / mean).min(2.0) / 2.0 } else { 0.0 }
    }
    pub fn switch_mechanism(&mut self, m: AttentionType) { self.controller.set_mechanism(m); }
    pub fn reset(&mut self) {
        self.controller = AttentionController::new(self.controller.num_heads(), self.dim, self.controller.top_down_strength(), self.controller.bottom_up_strength());
        self.working_memory.clear();
        self.current_focus = vec![0.0; self.dim];
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_working_memory() {
        let mut wm = WorkingMemory::new(5, 8);
        wm.store(&[0.5; 8]);
        assert_eq!(wm.contents().len(), 1);
    }
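    // Extra checks sketching expected behaviour. They rely only on items
    // defined in this module, so no BrainConfig value is needed here; the
    // AttentionSystem test builds the struct directly, which works because
    // child modules can see the parent module's private fields.
    #[test]
    fn test_working_memory_eviction() {
        let mut wm = WorkingMemory::new(2, 4);
        wm.store(&[0.1; 4]);
        wm.store(&[0.2; 4]);
        wm.store(&[0.3; 4]);
        // Capacity is 2, so the weakest item is evicted before the third store.
        assert_eq!(wm.contents().len(), 2);
    }
    #[test]
    fn test_controller_attend_is_bounded() {
        let ctrl = AttentionController::new(4, 8, 0.5, 0.5);
        let out = ctrl.attend(&[10.0; 8]);
        assert_eq!(out.len(), 8);
        // tanh keeps every component within [-1, 1].
        assert!(out.iter().all(|v| v.abs() <= 1.0));
    }
    #[test]
    fn test_attention_strength_in_unit_range() {
        let mut system = AttentionSystem {
            controller: AttentionController::new(2, 4, 0.5, 0.5),
            working_memory: WorkingMemory::new(4, 4),
            current_focus: vec![0.0; 4],
            dim: 4,
        };
        system.focus_on(&[1.0, 0.0, 0.0, 0.0]).unwrap();
        assert!((0.0..=1.0).contains(&system.attention_strength()));
    }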
}