omega_brain/
memory_system.rs1use crate::{BrainConfig, Result};
4use crate::sleep_system::SleepOutput;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Memory {
11 pub pattern: Vec<f64>,
12 pub importance: f64,
13}
14
15#[derive(Debug, Clone)]
17pub struct ReplayBuffer {
18 memories: Vec<Memory>,
19 capacity: usize,
20}
21
22impl ReplayBuffer {
23 pub fn new(capacity: usize) -> Self { Self { memories: Vec::with_capacity(capacity), capacity } }
24 pub fn add(&mut self, pattern: Vec<f64>, importance: f64) {
25 if self.memories.len() >= self.capacity {
26 self.memories.sort_by(|a, b| a.importance.partial_cmp(&b.importance).unwrap());
27 self.memories.remove(0);
28 }
29 self.memories.push(Memory { pattern, importance });
30 }
31 pub fn sample(&self, n: usize) -> Vec<Memory> {
32 self.memories.iter().rev().take(n).cloned().collect()
33 }
34 pub fn clear(&mut self) { self.memories.clear(); }
35}
36
37#[derive(Debug, Clone)]
39pub struct Hippocampus {
40 patterns: HashMap<String, Memory>,
41 dim: usize,
42 threshold: f64,
43 next_id: usize,
44}
45
46impl Hippocampus {
47 pub fn new(dim: usize, _ca3_size: usize, threshold: f64) -> Self {
48 Self { patterns: HashMap::new(), dim, threshold, next_id: 0 }
49 }
50 pub fn dim(&self) -> usize { self.dim }
51 pub fn encode(&mut self, pattern: &[f64], importance: f64) -> String {
52 let id = format!("mem_{}", self.next_id);
53 self.next_id += 1;
54 self.patterns.insert(id.clone(), Memory { pattern: pattern.to_vec(), importance });
55 id
56 }
57 pub fn retrieve(&self, cue: &[f64]) -> Option<Memory> {
58 let mut best: Option<(&Memory, f64)> = None;
59 for mem in self.patterns.values() {
60 let sim = cosine_similarity(cue, &mem.pattern);
61 if sim > self.threshold && (best.is_none() || sim > best.unwrap().1) {
62 best = Some((mem, sim));
63 }
64 }
65 best.map(|(m, _)| m.clone())
66 }
67 pub fn strengthen(&mut self, pattern: &[f64], amount: f64) {
68 for mem in self.patterns.values_mut() {
69 let sim = cosine_similarity(pattern, &mem.pattern);
70 if sim > 0.8 { mem.importance = (mem.importance + amount).min(1.0); }
71 }
72 }
73 pub fn strengthen_association(&mut self, a: &[f64], b: &[f64], amount: f64) {
74 self.strengthen(a, amount * 0.5);
75 self.strengthen(b, amount * 0.5);
76 }
77 pub fn consolidate_all(&mut self) -> usize {
78 let count = self.patterns.len();
79 for mem in self.patterns.values_mut() { mem.importance = (mem.importance * 1.1).min(1.0); }
80 count
81 }
82}
83
/// Cosine similarity between `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Only the overlapping prefix (`min(a.len(), b.len())` elements) is compared,
/// as with [`Iterator::zip`]. Returns `0.0` when either vector has zero norm
/// over that prefix (including empty input).
///
/// Each squared norm is square-rooted separately. Multiplying the squared norms
/// first overflows or underflows for very large or very small magnitudes, which
/// would report identical vectors as dissimilar. The clamp absorbs rounding that
/// can push the ratio slightly past ±1.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0_f64, 0.0_f64, 0.0_f64), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom > 0.0 {
        (dot / denom).clamp(-1.0, 1.0)
    } else {
        0.0
    }
}

92pub struct MemorySystem {
94 hippocampus: Hippocampus,
95 replay_buffer: ReplayBuffer,
96 consolidation_count: usize,
97 memory_count: usize,
98 dim: usize,
99}
100
101impl MemorySystem {
102 pub fn new(config: &BrainConfig) -> Self {
103 Self {
104 hippocampus: Hippocampus::new(config.pattern_dim, config.ca3_size, config.consolidation_threshold),
105 replay_buffer: ReplayBuffer::new(config.replay_buffer_size),
106 consolidation_count: 0,
107 memory_count: 0,
108 dim: config.pattern_dim,
109 }
110 }
111 pub fn process(&mut self, content: &[f64]) -> Result<Vec<f64>> {
112 let normalized: Vec<f64> = (0..self.dim).map(|i| content.get(i).copied().unwrap_or(0.0)).collect();
113 if let Some(retrieved) = self.hippocampus.retrieve(&normalized) {
114 Ok(normalized.iter().zip(retrieved.pattern.iter()).map(|(&a, &b)| 0.6 * a + 0.4 * b).collect())
115 } else {
116 let sig = normalized.iter().map(|x| x.abs()).sum::<f64>() / self.dim.max(1) as f64;
117 if sig > 0.3 { self.encode(&normalized, sig)?; }
118 Ok(normalized)
119 }
120 }
121 pub fn encode(&mut self, content: &[f64], importance: f64) -> Result<()> {
122 let normalized: Vec<f64> = (0..self.dim).map(|i| content.get(i).copied().unwrap_or(0.0)).collect();
123 self.hippocampus.encode(&normalized, importance);
124 self.memory_count += 1;
125 self.replay_buffer.add(normalized, importance);
126 Ok(())
127 }
128 pub fn retrieve(&self, cue: &[f64]) -> Result<Option<Vec<f64>>> {
129 let normalized: Vec<f64> = (0..self.dim).map(|i| cue.get(i).copied().unwrap_or(0.0)).collect();
130 Ok(self.hippocampus.retrieve(&normalized).map(|m| m.pattern))
131 }
132 pub fn consolidate_slow_wave(&mut self, output: &SleepOutput) -> Result<()> {
133 for mem in self.replay_buffer.sample(output.replay_count) {
134 self.hippocampus.strengthen(&mem.pattern, 0.1);
135 }
136 self.consolidation_count += 1;
137 Ok(())
138 }
139 pub fn consolidate_rem(&mut self, output: &SleepOutput) -> Result<()> {
140 let mems = self.replay_buffer.sample(output.replay_count);
141 for i in 0..mems.len().saturating_sub(1) {
142 self.hippocampus.strengthen_association(&mems[i].pattern, &mems[i + 1].pattern, 0.05);
143 }
144 Ok(())
145 }
146 pub fn force_consolidation(&mut self) -> Result<usize> {
147 let count = self.hippocampus.consolidate_all();
148 self.consolidation_count += 1;
149 Ok(count)
150 }
151 pub fn consolidation_ratio(&self) -> f64 {
152 if self.memory_count == 0 { 0.0 } else { self.consolidation_count as f64 / self.memory_count as f64 }
153 }
154 pub fn memory_count(&self) -> usize { self.memory_count }
155 pub fn reset(&mut self) {
156 self.hippocampus = Hippocampus::new(self.dim, 500, 0.7);
157 self.replay_buffer.clear();
158 self.consolidation_count = 0;
159 self.memory_count = 0;
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    /// Cueing with a stored pattern recalls that exact pattern and its importance.
    #[test]
    fn test_hippocampus() {
        let mut hc = Hippocampus::new(8, 100, 0.5);
        let pattern = [0.5; 8];
        hc.encode(&pattern, 1.0);
        let recalled = hc.retrieve(&pattern).expect("identical cue should be recalled");
        assert_eq!(recalled.pattern, pattern.to_vec());
        assert_eq!(recalled.importance, 1.0);
    }

    /// A cue orthogonal to every stored pattern stays below the recall threshold.
    #[test]
    fn test_hippocampus_orthogonal_cue_misses() {
        let mut hc = Hippocampus::new(2, 100, 0.5);
        hc.encode(&[1.0, 0.0], 1.0);
        assert!(hc.retrieve(&[0.0, 1.0]).is_none());
    }
}