omega_hippocampus/
lib.rs

1//! Omega Hippocampus
2//!
3//! Biologically-inspired hippocampal memory system implementing:
4//! - Dentate Gyrus (DG): Pattern separation via sparse coding
5//! - CA3: Autoassociative memory for pattern completion
6//! - CA1: Output layer for memory consolidation
7//! - Entorhinal Cortex: Input/output interface
8//! - Place cells: Spatial memory and navigation
9//! - Sharp-wave ripples: Memory replay and consolidation
10//!
11//! Based on computational neuroscience models of hippocampal function.
12
13pub mod ca1;
14pub mod ca3;
15pub mod dentate_gyrus;
16pub mod entorhinal;
17pub mod place_cells;
18pub mod replay;
19
20pub use ca1::{CA1Layer, CA1Neuron, CA1Output};
21pub use ca3::{CA3Network, CA3Neuron, PatternCompletion};
22pub use dentate_gyrus::{DentateGyrus, GranuleCell, MossyFiber};
23pub use entorhinal::{EntorhinalCortex, GridCell, PerforantPath};
24pub use place_cells::{PlaceCell, PlaceField, SpatialMap};
25pub use replay::{ReplayBuffer, ReplayEvent, SharpWaveRipple};
26
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use thiserror::Error;
30
31/// Hippocampus errors
32#[derive(Debug, Error)]
33pub enum HippocampusError {
34    #[error("Pattern size mismatch: expected {expected}, got {got}")]
35    PatternSizeMismatch { expected: usize, got: usize },
36
37    #[error("Memory capacity exceeded: {0}")]
38    CapacityExceeded(String),
39
40    #[error("Pattern not found: {0}")]
41    PatternNotFound(String),
42
43    #[error("Encoding failed: {0}")]
44    EncodingFailed(String),
45
46    #[error("Replay failed: {0}")]
47    ReplayFailed(String),
48}
49
50pub type Result<T> = std::result::Result<T, HippocampusError>;
51
52/// Configuration for the hippocampal system
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct HippocampusConfig {
55    /// Dimension of input patterns
56    pub input_dim: usize,
57    /// Number of granule cells in DG (typically 10x input)
58    pub dg_size: usize,
59    /// Number of CA3 neurons
60    pub ca3_size: usize,
61    /// Number of CA1 neurons
62    pub ca1_size: usize,
63    /// Sparsity of DG representation (fraction active)
64    pub dg_sparsity: f64,
65    /// CA3 recurrent connection probability
66    pub ca3_recurrence: f64,
67    /// Learning rate for synaptic plasticity
68    pub learning_rate: f64,
69    /// Replay buffer size
70    pub replay_buffer_size: usize,
71    /// Sharp-wave ripple threshold
72    pub ripple_threshold: f64,
73}
74
75impl Default for HippocampusConfig {
76    fn default() -> Self {
77        Self {
78            input_dim: 256,
79            dg_size: 2560, // 10x expansion
80            ca3_size: 512,
81            ca1_size: 256,
82            dg_sparsity: 0.02, // 2% active (sparse coding)
83            ca3_recurrence: 0.04,
84            learning_rate: 0.01,
85            replay_buffer_size: 1000,
86            ripple_threshold: 0.7,
87        }
88    }
89}
90
91/// Memory trace stored in hippocampus
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MemoryTrace {
94    /// Unique identifier
95    pub id: String,
96    /// Original input pattern
97    pub input: Vec<f64>,
98    /// DG sparse representation
99    pub dg_code: Vec<f64>,
100    /// CA3 representation
101    pub ca3_code: Vec<f64>,
102    /// CA1 output
103    pub ca1_output: Vec<f64>,
104    /// Timestamp of encoding
105    pub timestamp: u64,
106    /// Strength of memory (consolidation level)
107    pub strength: f64,
108    /// Number of times replayed
109    pub replay_count: u32,
110    /// Associated context/location
111    pub context: Option<Vec<f64>>,
112}
113
114impl MemoryTrace {
115    pub fn new(id: String, input: Vec<f64>) -> Self {
116        Self {
117            id,
118            input,
119            dg_code: Vec::new(),
120            ca3_code: Vec::new(),
121            ca1_output: Vec::new(),
122            timestamp: std::time::SystemTime::now()
123                .duration_since(std::time::UNIX_EPOCH)
124                .unwrap_or_default()
125                .as_millis() as u64,
126            strength: 1.0,
127            replay_count: 0,
128            context: None,
129        }
130    }
131}
132
133/// The complete hippocampal formation
134pub struct Hippocampus {
135    config: HippocampusConfig,
136    /// Entorhinal cortex (input/output interface)
137    entorhinal: EntorhinalCortex,
138    /// Dentate gyrus (pattern separation)
139    dentate_gyrus: DentateGyrus,
140    /// CA3 (autoassociative memory)
141    ca3: CA3Network,
142    /// CA1 (output layer)
143    ca1: CA1Layer,
144    /// Place cell system
145    place_cells: SpatialMap,
146    /// Replay buffer
147    replay: ReplayBuffer,
148    /// Stored memory traces
149    memories: HashMap<String, MemoryTrace>,
150    /// Current theta phase (0 to 2π)
151    theta_phase: f64,
152    /// Theta frequency (Hz)
153    theta_frequency: f64,
154}
155
156impl Hippocampus {
157    /// Create new hippocampus with default configuration
158    pub fn new() -> Self {
159        Self::with_config(HippocampusConfig::default())
160    }
161
162    /// Create hippocampus with custom configuration
163    pub fn with_config(config: HippocampusConfig) -> Self {
164        let entorhinal = EntorhinalCortex::new(config.input_dim, config.dg_size);
165        let dentate_gyrus = DentateGyrus::new(config.dg_size, config.dg_sparsity);
166        let ca3 = CA3Network::new(config.ca3_size, config.dg_size, config.ca3_recurrence);
167        let ca1 = CA1Layer::new(config.ca1_size, config.ca3_size);
168        let place_cells = SpatialMap::new(64.0, 64.0, 100); // 64x64 environment, 100 place cells
169        let replay = ReplayBuffer::new(config.replay_buffer_size);
170
171        Self {
172            config,
173            entorhinal,
174            dentate_gyrus,
175            ca3,
176            ca1,
177            place_cells,
178            replay,
179            memories: HashMap::new(),
180            theta_phase: 0.0,
181            theta_frequency: 8.0, // 8 Hz theta
182        }
183    }
184
185    /// Encode a new memory
186    pub fn encode(&mut self, input: &[f64], context: Option<&[f64]>) -> Result<String> {
187        if input.len() != self.config.input_dim {
188            return Err(HippocampusError::PatternSizeMismatch {
189                expected: self.config.input_dim,
190                got: input.len(),
191            });
192        }
193
194        // Generate memory ID
195        let id = uuid::Uuid::now_v7().to_string();
196        let mut trace = MemoryTrace::new(id.clone(), input.to_vec());
197
198        // Step 1: Entorhinal cortex preprocessing
199        let ec_output = self.entorhinal.process(input);
200
201        // Step 2: Dentate gyrus pattern separation
202        let dg_output = self.dentate_gyrus.separate(&ec_output);
203        trace.dg_code = dg_output.clone();
204
205        // Step 3: CA3 encoding (sparse DG → CA3)
206        let ca3_output = self.ca3.encode(&dg_output);
207        trace.ca3_code = ca3_output.clone();
208
209        // Step 4: CA3 → CA1 transfer
210        let ca1_output = self.ca1.process(&ca3_output);
211        trace.ca1_output = ca1_output;
212
213        // Store context if provided
214        if let Some(ctx) = context {
215            trace.context = Some(ctx.to_vec());
216        }
217
218        // Add to replay buffer
219        self.replay.add(ReplayEvent {
220            memory_id: id.clone(),
221            pattern: trace.ca3_code.clone(),
222            timestamp: trace.timestamp,
223            priority: 1.0,
224        });
225
226        // Store memory trace
227        self.memories.insert(id.clone(), trace);
228
229        Ok(id)
230    }
231
232    /// Retrieve/complete a memory pattern
233    pub fn retrieve(&mut self, cue: &[f64]) -> Result<Vec<f64>> {
234        if cue.len() != self.config.input_dim {
235            return Err(HippocampusError::PatternSizeMismatch {
236                expected: self.config.input_dim,
237                got: cue.len(),
238            });
239        }
240
241        // Entorhinal preprocessing
242        let ec_output = self.entorhinal.process(cue);
243
244        // DG separation (creates retrieval cue)
245        let dg_output = self.dentate_gyrus.separate(&ec_output);
246
247        // CA3 pattern completion
248        let completed = self.ca3.complete(&dg_output);
249
250        // CA1 output
251        let output = self.ca1.process(&completed);
252
253        // Entorhinal decoding (back to input space)
254        let retrieved = self.entorhinal.decode(&output);
255
256        Ok(retrieved)
257    }
258
259    /// Retrieve by memory ID
260    pub fn retrieve_by_id(&self, id: &str) -> Option<&MemoryTrace> {
261        self.memories.get(id)
262    }
263
264    /// Perform replay during "offline" periods (sleep)
265    pub fn replay(&mut self, num_events: usize) -> Vec<String> {
266        let events = self.replay.sample(num_events);
267        let mut replayed_ids = Vec::new();
268
269        for event in events {
270            if let Some(trace) = self.memories.get_mut(&event.memory_id) {
271                // Replay strengthens memory
272                trace.strength *= 1.1;
273                trace.strength = trace.strength.min(10.0);
274                trace.replay_count += 1;
275
276                // Reactivate CA3 pattern
277                self.ca3.reactivate(&event.pattern);
278
279                // Update CA1
280                let ca1_out = self.ca1.process(&event.pattern);
281                trace.ca1_output = ca1_out;
282
283                replayed_ids.push(event.memory_id.clone());
284            }
285        }
286
287        replayed_ids
288    }
289
290    /// Generate sharp-wave ripple replay
291    pub fn sharp_wave_ripple(&mut self) -> Option<SharpWaveRipple> {
292        // Check if conditions are right for SWR
293        let activity_level = self.ca3.get_activity_level();
294
295        if activity_level > self.config.ripple_threshold {
296            // Sample memories for replay
297            let events = self.replay.sample_prioritized(5);
298
299            if !events.is_empty() {
300                let patterns: Vec<Vec<f64>> = events.iter().map(|e| e.pattern.clone()).collect();
301
302                let ripple = SharpWaveRipple {
303                    timestamp: std::time::SystemTime::now()
304                        .duration_since(std::time::UNIX_EPOCH)
305                        .unwrap_or_default()
306                        .as_millis() as u64,
307                    patterns,
308                    duration_ms: 100,
309                    frequency_hz: 150.0, // Ripple frequency
310                };
311
312                // Process replay
313                for event in events {
314                    if let Some(trace) = self.memories.get_mut(&event.memory_id) {
315                        trace.strength *= 1.2;
316                        trace.replay_count += 1;
317                    }
318                }
319
320                return Some(ripple);
321            }
322        }
323
324        None
325    }
326
327    /// Update spatial representation (place cells)
328    pub fn update_location(&mut self, x: f64, y: f64) {
329        self.place_cells.update_position(x, y);
330    }
331
332    /// Get current place cell activity
333    pub fn get_place_activity(&self) -> Vec<f64> {
334        self.place_cells.get_activity()
335    }
336
337    /// Advance theta oscillation
338    pub fn step_theta(&mut self, dt: f64) {
339        self.theta_phase += 2.0 * std::f64::consts::PI * self.theta_frequency * dt;
340        if self.theta_phase > 2.0 * std::f64::consts::PI {
341            self.theta_phase -= 2.0 * std::f64::consts::PI;
342        }
343    }
344
345    /// Get current theta phase
346    pub fn theta_phase(&self) -> f64 {
347        self.theta_phase
348    }
349
350    /// Get memory count
351    pub fn memory_count(&self) -> usize {
352        self.memories.len()
353    }
354
355    /// Get all memory IDs
356    pub fn memory_ids(&self) -> Vec<String> {
357        self.memories.keys().cloned().collect()
358    }
359
360    /// Decay old memories
361    pub fn decay(&mut self, factor: f64) {
362        for trace in self.memories.values_mut() {
363            trace.strength *= factor;
364        }
365
366        // Remove very weak memories
367        self.memories.retain(|_, trace| trace.strength > 0.01);
368    }
369
370    /// Get statistics
371    pub fn stats(&self) -> HippocampusStats {
372        HippocampusStats {
373            memory_count: self.memories.len(),
374            replay_buffer_size: self.replay.len(),
375            average_strength: if self.memories.is_empty() {
376                0.0
377            } else {
378                self.memories.values().map(|m| m.strength).sum::<f64>()
379                    / self.memories.len() as f64
380            },
381            ca3_activity: self.ca3.get_activity_level(),
382            theta_phase: self.theta_phase,
383        }
384    }
385
386    /// Clear all memories
387    pub fn clear(&mut self) {
388        self.memories.clear();
389        self.replay.clear();
390        self.ca3.reset();
391        self.ca1.reset();
392    }
393}
394
395impl Default for Hippocampus {
396    fn default() -> Self {
397        Self::new()
398    }
399}
400
401/// Statistics about hippocampus state
402#[derive(Debug, Clone, Serialize, Deserialize)]
403pub struct HippocampusStats {
404    pub memory_count: usize,
405    pub replay_buffer_size: usize,
406    pub average_strength: f64,
407    pub ca3_activity: f64,
408    pub theta_phase: f64,
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a small hippocampus whose layer sizes scale with `input_dim`.
    /// DG is 4x `input_dim`, CA3 is 2x, and CA1 equals `input_dim`.
    fn small_hippocampus(input_dim: usize) -> Hippocampus {
        Hippocampus::with_config(HippocampusConfig {
            input_dim,
            dg_size: input_dim * 4,
            ca3_size: input_dim * 2,
            ca1_size: input_dim,
            ..Default::default()
        })
    }

    #[test]
    fn test_hippocampus_creation() {
        let hippo = Hippocampus::new();
        assert_eq!(hippo.memory_count(), 0);
    }

    #[test]
    fn test_encode_retrieve() {
        let mut hippo = small_hippocampus(64);

        let input: Vec<f64> = (0..64).map(|i| i as f64 / 64.0).collect();
        let context = [1.0, 2.0];
        let id = hippo.encode(&input, Some(&context)).unwrap();

        assert_eq!(hippo.memory_count(), 1);
        let trace = hippo.retrieve_by_id(&id).expect("encoded trace must be stored");
        assert_eq!(trace.input, input);
        assert_eq!(trace.context.as_deref(), Some(&context[..]));

        // Partial cue: zero out the second half.
        let mut cue = input.clone();
        cue[32..].fill(0.0);

        let retrieved = hippo.retrieve(&cue).unwrap();
        assert_eq!(retrieved.len(), 64);
    }

    #[test]
    fn test_size_mismatch_is_rejected() {
        let mut hippo = small_hippocampus(16);

        let err = hippo.encode(&[0.0; 3], None).unwrap_err();
        assert!(matches!(
            err,
            HippocampusError::PatternSizeMismatch { expected: 16, got: 3 }
        ));
        assert!(hippo.retrieve(&[0.0; 17]).is_err());
        assert_eq!(hippo.memory_count(), 0);
    }

    #[test]
    fn test_replay() {
        let mut hippo = small_hippocampus(32);

        for i in 0..5 {
            let input: Vec<f64> = (0..32).map(|j| (i + j) as f64 / 32.0).collect();
            hippo.encode(&input, None).unwrap();
        }

        let replayed = hippo.replay(3);
        assert!(replayed.len() <= 3);

        let known = hippo.memory_ids();
        assert!(replayed.iter().all(|id| known.contains(id)));
    }

    #[test]
    fn test_theta_oscillation() {
        let mut hippo = Hippocampus::new();

        hippo.step_theta(0.01);
        assert!(hippo.theta_phase() > 0.0);

        // Several full cycles' worth of small steps.
        for _ in 0..1000 {
            hippo.step_theta(0.001);
        }
        assert!(hippo.theta_phase() < 2.0 * std::f64::consts::PI);
    }

    #[test]
    fn test_decay() {
        let mut hippo = small_hippocampus(16);

        let input = vec![0.5; 16];
        hippo.encode(&input, None).unwrap();

        let initial_strength = hippo.memories.values().next().unwrap().strength;

        hippo.decay(0.9);

        let new_strength = hippo.memories.values().next().unwrap().strength;
        assert!(new_strength < initial_strength);

        // A zero factor pushes every trace below the retention threshold.
        hippo.decay(0.0);
        assert_eq!(hippo.memory_count(), 0);
    }
}