omega_brain/
memory_system.rs

1//! Memory System - Self-contained hippocampal memory implementation
2
3use crate::{BrainConfig, Result};
4use crate::sleep_system::SleepOutput;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// Memory entry
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Memory {
11    pub pattern: Vec<f64>,
12    pub importance: f64,
13}
14
15/// Replay buffer
16#[derive(Debug, Clone)]
17pub struct ReplayBuffer {
18    memories: Vec<Memory>,
19    capacity: usize,
20}
21
22impl ReplayBuffer {
23    pub fn new(capacity: usize) -> Self { Self { memories: Vec::with_capacity(capacity), capacity } }
24    pub fn add(&mut self, pattern: Vec<f64>, importance: f64) {
25        if self.memories.len() >= self.capacity {
26            self.memories.sort_by(|a, b| a.importance.partial_cmp(&b.importance).unwrap());
27            self.memories.remove(0);
28        }
29        self.memories.push(Memory { pattern, importance });
30    }
31    pub fn sample(&self, n: usize) -> Vec<Memory> {
32        self.memories.iter().rev().take(n).cloned().collect()
33    }
34    pub fn clear(&mut self) { self.memories.clear(); }
35}
36
37/// Hippocampus - pattern separation and completion
38#[derive(Debug, Clone)]
39pub struct Hippocampus {
40    patterns: HashMap<String, Memory>,
41    dim: usize,
42    threshold: f64,
43    next_id: usize,
44}
45
46impl Hippocampus {
47    pub fn new(dim: usize, _ca3_size: usize, threshold: f64) -> Self {
48        Self { patterns: HashMap::new(), dim, threshold, next_id: 0 }
49    }
50    pub fn dim(&self) -> usize { self.dim }
51    pub fn encode(&mut self, pattern: &[f64], importance: f64) -> String {
52        let id = format!("mem_{}", self.next_id);
53        self.next_id += 1;
54        self.patterns.insert(id.clone(), Memory { pattern: pattern.to_vec(), importance });
55        id
56    }
57    pub fn retrieve(&self, cue: &[f64]) -> Option<Memory> {
58        let mut best: Option<(&Memory, f64)> = None;
59        for mem in self.patterns.values() {
60            let sim = cosine_similarity(cue, &mem.pattern);
61            if sim > self.threshold && (best.is_none() || sim > best.unwrap().1) {
62                best = Some((mem, sim));
63            }
64        }
65        best.map(|(m, _)| m.clone())
66    }
67    pub fn strengthen(&mut self, pattern: &[f64], amount: f64) {
68        for mem in self.patterns.values_mut() {
69            let sim = cosine_similarity(pattern, &mem.pattern);
70            if sim > 0.8 { mem.importance = (mem.importance + amount).min(1.0); }
71        }
72    }
73    pub fn strengthen_association(&mut self, a: &[f64], b: &[f64], amount: f64) {
74        self.strengthen(a, amount * 0.5);
75        self.strengthen(b, amount * 0.5);
76    }
77    pub fn consolidate_all(&mut self) -> usize {
78        let count = self.patterns.len();
79        for mem in self.patterns.values_mut() { mem.importance = (mem.importance * 1.1).min(1.0); }
80        count
81    }
82}
83
/// Cosine similarity of `a` and `b` (free function to avoid borrow conflicts).
///
/// Only the first `min(a.len(), b.len())` components are compared. Returns
/// `0.0` when the denominator is not positive, e.g. when either compared
/// slice has zero norm, or when it is `NaN`.
///
/// The two squared norms are square-rooted separately before multiplying.
/// This keeps their product from overflowing to infinity, or underflowing to
/// zero, for very large or very small inputs.
fn cosine_similarity(a: &[f64], b: &[f64]) -> f64 {
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0, 0.0, 0.0), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom > 0.0 { dot / denom } else { 0.0 }
}
91
92/// Memory system
93pub struct MemorySystem {
94    hippocampus: Hippocampus,
95    replay_buffer: ReplayBuffer,
96    consolidation_count: usize,
97    memory_count: usize,
98    dim: usize,
99}
100
101impl MemorySystem {
102    pub fn new(config: &BrainConfig) -> Self {
103        Self {
104            hippocampus: Hippocampus::new(config.pattern_dim, config.ca3_size, config.consolidation_threshold),
105            replay_buffer: ReplayBuffer::new(config.replay_buffer_size),
106            consolidation_count: 0,
107            memory_count: 0,
108            dim: config.pattern_dim,
109        }
110    }
111    pub fn process(&mut self, content: &[f64]) -> Result<Vec<f64>> {
112        let normalized: Vec<f64> = (0..self.dim).map(|i| content.get(i).copied().unwrap_or(0.0)).collect();
113        if let Some(retrieved) = self.hippocampus.retrieve(&normalized) {
114            Ok(normalized.iter().zip(retrieved.pattern.iter()).map(|(&a, &b)| 0.6 * a + 0.4 * b).collect())
115        } else {
116            let sig = normalized.iter().map(|x| x.abs()).sum::<f64>() / self.dim.max(1) as f64;
117            if sig > 0.3 { self.encode(&normalized, sig)?; }
118            Ok(normalized)
119        }
120    }
121    pub fn encode(&mut self, content: &[f64], importance: f64) -> Result<()> {
122        let normalized: Vec<f64> = (0..self.dim).map(|i| content.get(i).copied().unwrap_or(0.0)).collect();
123        self.hippocampus.encode(&normalized, importance);
124        self.memory_count += 1;
125        self.replay_buffer.add(normalized, importance);
126        Ok(())
127    }
128    pub fn retrieve(&self, cue: &[f64]) -> Result<Option<Vec<f64>>> {
129        let normalized: Vec<f64> = (0..self.dim).map(|i| cue.get(i).copied().unwrap_or(0.0)).collect();
130        Ok(self.hippocampus.retrieve(&normalized).map(|m| m.pattern))
131    }
132    pub fn consolidate_slow_wave(&mut self, output: &SleepOutput) -> Result<()> {
133        for mem in self.replay_buffer.sample(output.replay_count) {
134            self.hippocampus.strengthen(&mem.pattern, 0.1);
135        }
136        self.consolidation_count += 1;
137        Ok(())
138    }
139    pub fn consolidate_rem(&mut self, output: &SleepOutput) -> Result<()> {
140        let mems = self.replay_buffer.sample(output.replay_count);
141        for i in 0..mems.len().saturating_sub(1) {
142            self.hippocampus.strengthen_association(&mems[i].pattern, &mems[i + 1].pattern, 0.05);
143        }
144        Ok(())
145    }
146    pub fn force_consolidation(&mut self) -> Result<usize> {
147        let count = self.hippocampus.consolidate_all();
148        self.consolidation_count += 1;
149        Ok(count)
150    }
151    pub fn consolidation_ratio(&self) -> f64 {
152        if self.memory_count == 0 { 0.0 } else { self.consolidation_count as f64 / self.memory_count as f64 }
153    }
154    pub fn memory_count(&self) -> usize { self.memory_count }
155    pub fn reset(&mut self) {
156        self.hippocampus = Hippocampus::new(self.dim, 500, 0.7);
157        self.replay_buffer.clear();
158        self.consolidation_count = 0;
159        self.memory_count = 0;
160    }
161}
162
// Unit tests for the hippocampus, replay buffer and similarity helper.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hippocampus() {
        let mut hc = Hippocampus::new(8, 100, 0.5);
        hc.encode(&[0.5; 8], 1.0);
        assert!(hc.retrieve(&[0.5; 8]).is_some());
    }

    #[test]
    fn hippocampus_rejects_dissimilar_cue() {
        let mut hc = Hippocampus::new(2, 100, 0.5);
        hc.encode(&[1.0, 0.0], 1.0);
        // Orthogonal cue has similarity 0.0, below the 0.5 threshold.
        assert!(hc.retrieve(&[0.0, 1.0]).is_none());
    }

    #[test]
    fn replay_buffer_samples_most_recent_first() {
        let mut rb = ReplayBuffer::new(4);
        rb.add(vec![1.0], 0.5);
        rb.add(vec![2.0], 0.5);
        let sampled = rb.sample(1);
        assert_eq!(sampled.len(), 1);
        assert_eq!(sampled[0].pattern, vec![2.0]);
        rb.clear();
        assert!(rb.sample(4).is_empty());
    }

    #[test]
    fn cosine_similarity_edge_cases() {
        assert_eq!(cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }
}