omega_brain/attention_system.rs

use crate::{BrainConfig, Result};
use serde::{Deserialize, Serialize};

/// Attention mechanisms the controller can be switched between.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AttentionType { ScaledDotProduct, Flash, Linear, Sparse, Hyperbolic, Graph }

/// Fixed-capacity store of recent items, each tagged with a retention strength.
#[derive(Debug, Clone)]
pub struct WorkingMemory {
    items: Vec<(Vec<f64>, f64)>,
    capacity: usize,
    dim: usize,
}

impl WorkingMemory {
    pub fn new(capacity: usize, dim: usize) -> Self {
        Self { items: Vec::with_capacity(capacity), capacity, dim }
    }
    pub fn dim(&self) -> usize { self.dim }
    /// Stores an item at full strength, evicting the weakest item when at capacity.
    pub fn store(&mut self, content: &[f64]) {
        if self.items.len() >= self.capacity {
            // Sort ascending by strength so index 0 is the weakest entry.
            // `total_cmp` avoids the panic that `partial_cmp(..).unwrap()` would hit on NaN.
            self.items.sort_by(|a, b| a.1.total_cmp(&b.1));
            self.items.remove(0);
        }
        self.items.push((content.to_vec(), 1.0));
    }
    /// Blends stored items into `input`, weighted by their retention strength.
    pub fn gate(&self, input: &[f64]) -> Vec<f64> {
        let mut result = input.to_vec();
        for (item, strength) in &self.items {
            for (i, &v) in item.iter().enumerate() {
                if i < result.len() { result[i] += v * strength * 0.1; }
            }
        }
        result
    }
    pub fn contents(&self) -> Vec<Vec<f64>> { self.items.iter().map(|(v, _)| v.clone()).collect() }
    pub fn clear(&mut self) { self.items.clear(); }
}

/// Mixes bottom-up (stimulus-driven) input with a top-down bias vector.
#[derive(Debug, Clone)]
pub struct AttentionController {
    num_heads: usize,
    dim: usize,
    top_down: f64,
    bottom_up: f64,
    bias: Vec<f64>,
    mechanism: AttentionType,
}

impl AttentionController {
    pub fn new(num_heads: usize, dim: usize, top_down: f64, bottom_up: f64) -> Self {
        Self { num_heads, dim, top_down, bottom_up, bias: vec![0.0; dim], mechanism: AttentionType::ScaledDotProduct }
    }
    /// Combines bottom-up input with the top-down bias, squashing each component with tanh.
    pub fn attend(&self, input: &[f64]) -> Vec<f64> {
        (0..self.dim).map(|i| {
            let v = input.get(i).copied().unwrap_or(0.0);
            let b = self.bias.get(i).copied().unwrap_or(0.0);
            (v * self.bottom_up + b * self.top_down).tanh()
        }).collect()
    }
    pub fn set_top_down_bias(&mut self, bias: &[f64]) {
        for (i, &b) in bias.iter().enumerate() { if i < self.dim { self.bias[i] = b; } }
    }
    pub fn set_mechanism(&mut self, m: AttentionType) { self.mechanism = m; }
    pub fn num_heads(&self) -> usize { self.num_heads }
    pub fn top_down_strength(&self) -> f64 { self.top_down }
    pub fn bottom_up_strength(&self) -> f64 { self.bottom_up }
}

/// Top-level attention system: a controller plus working memory and a smoothed focus vector.
pub struct AttentionSystem {
    controller: AttentionController,
    working_memory: WorkingMemory,
    current_focus: Vec<f64>,
    dim: usize,
}

impl AttentionSystem {
    pub fn new(config: &BrainConfig) -> Self {
        Self {
            controller: AttentionController::new(config.attention_heads, config.attention_dim, config.top_down_strength, config.bottom_up_strength),
            working_memory: WorkingMemory::new(config.workspace_capacity, config.attention_dim),
            current_focus: vec![0.0; config.attention_dim],
            dim: config.attention_dim,
        }
    }
    /// Runs input through the controller, gates it against working memory,
    /// and updates the running focus with an exponential moving average.
    pub fn attend(&mut self, input: &[f64]) -> Result<Vec<f64>> {
        let normalized: Vec<f64> = (0..self.dim).map(|i| input.get(i).copied().unwrap_or(0.0)).collect();
        let attended = self.controller.attend(&normalized);
        let gated = self.working_memory.gate(&attended);
        for (i, &v) in gated.iter().enumerate() {
            if i < self.current_focus.len() { self.current_focus[i] = 0.8 * self.current_focus[i] + 0.2 * v; }
        }
        Ok(gated)
    }
    /// Sets the top-down bias and snaps the current focus to the target.
    pub fn focus_on(&mut self, target: &[f64]) -> Result<()> {
        let normalized: Vec<f64> = (0..self.dim).map(|i| target.get(i).copied().unwrap_or(0.0)).collect();
        self.controller.set_top_down_bias(&normalized);
        self.current_focus = normalized;
        Ok(())
    }
    pub fn current_focus(&self) -> Vec<f64> { self.current_focus.clone() }
    pub fn remember(&mut self, item: &[f64]) { self.working_memory.store(item); }
    pub fn working_memory_contents(&self) -> Vec<Vec<f64>> { self.working_memory.contents() }
    pub fn clear_working_memory(&mut self) { self.working_memory.clear(); }
    /// Sharpness of the current focus as a peak-to-mean ratio, capped and scaled to at most 1.0.
    pub fn attention_strength(&self) -> f64 {
        let max = self.current_focus.iter().map(|x| x.abs()).fold(0.0, f64::max);
        let mean = self.current_focus.iter().map(|x| x.abs()).sum::<f64>() / self.current_focus.len().max(1) as f64;
        if mean > 0.0 { (max / mean).min(2.0) / 2.0 } else { 0.0 }
    }
    pub fn switch_mechanism(&mut self, m: AttentionType) { self.controller.set_mechanism(m); }
    /// Rebuilds the controller with its original parameters and clears all attention state.
    pub fn reset(&mut self) {
        self.controller = AttentionController::new(self.controller.num_heads(), self.dim, self.controller.top_down_strength(), self.controller.bottom_up_strength());
        self.working_memory.clear();
        self.current_focus = vec![0.0; self.dim];
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_working_memory() {
        let mut wm = WorkingMemory::new(5, 8);
        wm.store(&[0.5; 8]);
        assert_eq!(wm.contents().len(), 1);
    }
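    // Additional illustrative tests (not in the original module): a minimal sketch
    // exercising capacity eviction and the bottom-up/top-down mixing, using only
    // the types defined above.
    #[test]
    fn test_working_memory_eviction_respects_capacity() {
        let mut wm = WorkingMemory::new(2, 4);
        wm.store(&[1.0; 4]);
        wm.store(&[2.0; 4]);
        wm.store(&[3.0; 4]); // third store exceeds capacity, so one item is evicted
        assert_eq!(wm.contents().len(), 2);
    }
    #[test]
    fn test_attention_controller_output_is_bounded() {
        let ctrl = AttentionController::new(4, 4, 0.5, 1.0);
        let out = ctrl.attend(&[10.0, -10.0, 0.0, 1.0]);
        assert_eq!(out.len(), 4);
        // tanh keeps every component strictly inside (-1.0, 1.0).
        assert!(out.iter().all(|v| v.abs() < 1.0));
    }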
}