1pub mod ca1;
14pub mod ca3;
15pub mod dentate_gyrus;
16pub mod entorhinal;
17pub mod place_cells;
18pub mod replay;
19
20pub use ca1::{CA1Layer, CA1Neuron, CA1Output};
21pub use ca3::{CA3Network, CA3Neuron, PatternCompletion};
22pub use dentate_gyrus::{DentateGyrus, GranuleCell, MossyFiber};
23pub use entorhinal::{EntorhinalCortex, GridCell, PerforantPath};
24pub use place_cells::{PlaceCell, PlaceField, SpatialMap};
25pub use replay::{ReplayBuffer, ReplayEvent, SharpWaveRipple};
26
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29use thiserror::Error;
30
31#[derive(Debug, Error)]
33pub enum HippocampusError {
34 #[error("Pattern size mismatch: expected {expected}, got {got}")]
35 PatternSizeMismatch { expected: usize, got: usize },
36
37 #[error("Memory capacity exceeded: {0}")]
38 CapacityExceeded(String),
39
40 #[error("Pattern not found: {0}")]
41 PatternNotFound(String),
42
43 #[error("Encoding failed: {0}")]
44 EncodingFailed(String),
45
46 #[error("Replay failed: {0}")]
47 ReplayFailed(String),
48}
49
50pub type Result<T> = std::result::Result<T, HippocampusError>;
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct HippocampusConfig {
55 pub input_dim: usize,
57 pub dg_size: usize,
59 pub ca3_size: usize,
61 pub ca1_size: usize,
63 pub dg_sparsity: f64,
65 pub ca3_recurrence: f64,
67 pub learning_rate: f64,
69 pub replay_buffer_size: usize,
71 pub ripple_threshold: f64,
73}
74
75impl Default for HippocampusConfig {
76 fn default() -> Self {
77 Self {
78 input_dim: 256,
79 dg_size: 2560, ca3_size: 512,
81 ca1_size: 256,
82 dg_sparsity: 0.02, ca3_recurrence: 0.04,
84 learning_rate: 0.01,
85 replay_buffer_size: 1000,
86 ripple_threshold: 0.7,
87 }
88 }
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct MemoryTrace {
94 pub id: String,
96 pub input: Vec<f64>,
98 pub dg_code: Vec<f64>,
100 pub ca3_code: Vec<f64>,
102 pub ca1_output: Vec<f64>,
104 pub timestamp: u64,
106 pub strength: f64,
108 pub replay_count: u32,
110 pub context: Option<Vec<f64>>,
112}
113
114impl MemoryTrace {
115 pub fn new(id: String, input: Vec<f64>) -> Self {
116 Self {
117 id,
118 input,
119 dg_code: Vec::new(),
120 ca3_code: Vec::new(),
121 ca1_output: Vec::new(),
122 timestamp: std::time::SystemTime::now()
123 .duration_since(std::time::UNIX_EPOCH)
124 .unwrap_or_default()
125 .as_millis() as u64,
126 strength: 1.0,
127 replay_count: 0,
128 context: None,
129 }
130 }
131}
132
133pub struct Hippocampus {
135 config: HippocampusConfig,
136 entorhinal: EntorhinalCortex,
138 dentate_gyrus: DentateGyrus,
140 ca3: CA3Network,
142 ca1: CA1Layer,
144 place_cells: SpatialMap,
146 replay: ReplayBuffer,
148 memories: HashMap<String, MemoryTrace>,
150 theta_phase: f64,
152 theta_frequency: f64,
154}
155
156impl Hippocampus {
157 pub fn new() -> Self {
159 Self::with_config(HippocampusConfig::default())
160 }
161
162 pub fn with_config(config: HippocampusConfig) -> Self {
164 let entorhinal = EntorhinalCortex::new(config.input_dim, config.dg_size);
165 let dentate_gyrus = DentateGyrus::new(config.dg_size, config.dg_sparsity);
166 let ca3 = CA3Network::new(config.ca3_size, config.dg_size, config.ca3_recurrence);
167 let ca1 = CA1Layer::new(config.ca1_size, config.ca3_size);
168 let place_cells = SpatialMap::new(64.0, 64.0, 100); let replay = ReplayBuffer::new(config.replay_buffer_size);
170
171 Self {
172 config,
173 entorhinal,
174 dentate_gyrus,
175 ca3,
176 ca1,
177 place_cells,
178 replay,
179 memories: HashMap::new(),
180 theta_phase: 0.0,
181 theta_frequency: 8.0, }
183 }
184
185 pub fn encode(&mut self, input: &[f64], context: Option<&[f64]>) -> Result<String> {
187 if input.len() != self.config.input_dim {
188 return Err(HippocampusError::PatternSizeMismatch {
189 expected: self.config.input_dim,
190 got: input.len(),
191 });
192 }
193
194 let id = uuid::Uuid::now_v7().to_string();
196 let mut trace = MemoryTrace::new(id.clone(), input.to_vec());
197
198 let ec_output = self.entorhinal.process(input);
200
201 let dg_output = self.dentate_gyrus.separate(&ec_output);
203 trace.dg_code = dg_output.clone();
204
205 let ca3_output = self.ca3.encode(&dg_output);
207 trace.ca3_code = ca3_output.clone();
208
209 let ca1_output = self.ca1.process(&ca3_output);
211 trace.ca1_output = ca1_output;
212
213 if let Some(ctx) = context {
215 trace.context = Some(ctx.to_vec());
216 }
217
218 self.replay.add(ReplayEvent {
220 memory_id: id.clone(),
221 pattern: trace.ca3_code.clone(),
222 timestamp: trace.timestamp,
223 priority: 1.0,
224 });
225
226 self.memories.insert(id.clone(), trace);
228
229 Ok(id)
230 }
231
232 pub fn retrieve(&mut self, cue: &[f64]) -> Result<Vec<f64>> {
234 if cue.len() != self.config.input_dim {
235 return Err(HippocampusError::PatternSizeMismatch {
236 expected: self.config.input_dim,
237 got: cue.len(),
238 });
239 }
240
241 let ec_output = self.entorhinal.process(cue);
243
244 let dg_output = self.dentate_gyrus.separate(&ec_output);
246
247 let completed = self.ca3.complete(&dg_output);
249
250 let output = self.ca1.process(&completed);
252
253 let retrieved = self.entorhinal.decode(&output);
255
256 Ok(retrieved)
257 }
258
259 pub fn retrieve_by_id(&self, id: &str) -> Option<&MemoryTrace> {
261 self.memories.get(id)
262 }
263
264 pub fn replay(&mut self, num_events: usize) -> Vec<String> {
266 let events = self.replay.sample(num_events);
267 let mut replayed_ids = Vec::new();
268
269 for event in events {
270 if let Some(trace) = self.memories.get_mut(&event.memory_id) {
271 trace.strength *= 1.1;
273 trace.strength = trace.strength.min(10.0);
274 trace.replay_count += 1;
275
276 self.ca3.reactivate(&event.pattern);
278
279 let ca1_out = self.ca1.process(&event.pattern);
281 trace.ca1_output = ca1_out;
282
283 replayed_ids.push(event.memory_id.clone());
284 }
285 }
286
287 replayed_ids
288 }
289
290 pub fn sharp_wave_ripple(&mut self) -> Option<SharpWaveRipple> {
292 let activity_level = self.ca3.get_activity_level();
294
295 if activity_level > self.config.ripple_threshold {
296 let events = self.replay.sample_prioritized(5);
298
299 if !events.is_empty() {
300 let patterns: Vec<Vec<f64>> = events.iter().map(|e| e.pattern.clone()).collect();
301
302 let ripple = SharpWaveRipple {
303 timestamp: std::time::SystemTime::now()
304 .duration_since(std::time::UNIX_EPOCH)
305 .unwrap_or_default()
306 .as_millis() as u64,
307 patterns,
308 duration_ms: 100,
309 frequency_hz: 150.0, };
311
312 for event in events {
314 if let Some(trace) = self.memories.get_mut(&event.memory_id) {
315 trace.strength *= 1.2;
316 trace.replay_count += 1;
317 }
318 }
319
320 return Some(ripple);
321 }
322 }
323
324 None
325 }
326
327 pub fn update_location(&mut self, x: f64, y: f64) {
329 self.place_cells.update_position(x, y);
330 }
331
332 pub fn get_place_activity(&self) -> Vec<f64> {
334 self.place_cells.get_activity()
335 }
336
337 pub fn step_theta(&mut self, dt: f64) {
339 self.theta_phase += 2.0 * std::f64::consts::PI * self.theta_frequency * dt;
340 if self.theta_phase > 2.0 * std::f64::consts::PI {
341 self.theta_phase -= 2.0 * std::f64::consts::PI;
342 }
343 }
344
345 pub fn theta_phase(&self) -> f64 {
347 self.theta_phase
348 }
349
350 pub fn memory_count(&self) -> usize {
352 self.memories.len()
353 }
354
355 pub fn memory_ids(&self) -> Vec<String> {
357 self.memories.keys().cloned().collect()
358 }
359
360 pub fn decay(&mut self, factor: f64) {
362 for trace in self.memories.values_mut() {
363 trace.strength *= factor;
364 }
365
366 self.memories.retain(|_, trace| trace.strength > 0.01);
368 }
369
370 pub fn stats(&self) -> HippocampusStats {
372 HippocampusStats {
373 memory_count: self.memories.len(),
374 replay_buffer_size: self.replay.len(),
375 average_strength: if self.memories.is_empty() {
376 0.0
377 } else {
378 self.memories.values().map(|m| m.strength).sum::<f64>()
379 / self.memories.len() as f64
380 },
381 ca3_activity: self.ca3.get_activity_level(),
382 theta_phase: self.theta_phase,
383 }
384 }
385
386 pub fn clear(&mut self) {
388 self.memories.clear();
389 self.replay.clear();
390 self.ca3.reset();
391 self.ca1.reset();
392 }
393}
394
395impl Default for Hippocampus {
396 fn default() -> Self {
397 Self::new()
398 }
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
403pub struct HippocampusStats {
404 pub memory_count: usize,
405 pub replay_buffer_size: usize,
406 pub average_strength: f64,
407 pub ca3_activity: f64,
408 pub theta_phase: f64,
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, fast configuration: DG is 4x, CA3 2x and CA1 1x `input_dim`;
    /// everything else uses the defaults.
    fn small_config(input_dim: usize) -> HippocampusConfig {
        HippocampusConfig {
            input_dim,
            dg_size: input_dim * 4,
            ca3_size: input_dim * 2,
            ca1_size: input_dim,
            ..Default::default()
        }
    }

    #[test]
    fn test_hippocampus_creation() {
        let hippo = Hippocampus::new();
        assert_eq!(hippo.memory_count(), 0);
    }

    #[test]
    fn test_encode_retrieve() {
        let mut hippo = Hippocampus::with_config(small_config(64));

        let input: Vec<f64> = (0..64).map(|i| i as f64 / 64.0).collect();
        let id = hippo.encode(&input, None).unwrap();

        assert_eq!(hippo.memory_count(), 1);
        let trace = hippo
            .retrieve_by_id(&id)
            .expect("encoded trace is stored under its id");
        assert_eq!(trace.input, input);
        assert!(trace.context.is_none());

        // Blank the second half to form a partial cue.
        let mut cue = input.clone();
        cue[32..].fill(0.0);

        let retrieved = hippo.retrieve(&cue).unwrap();
        assert_eq!(retrieved.len(), 64);
    }

    #[test]
    fn test_encode_stores_context() {
        let mut hippo = Hippocampus::with_config(small_config(16));
        let context = vec![1.0, 2.0, 3.0];

        let id = hippo.encode(&[0.5; 16], Some(context.as_slice())).unwrap();

        let trace = hippo.retrieve_by_id(&id).unwrap();
        assert_eq!(trace.context.as_deref(), Some(context.as_slice()));
    }

    #[test]
    fn test_encode_rejects_wrong_size() {
        let mut hippo = Hippocampus::with_config(small_config(16));

        let err = hippo.encode(&[0.0; 3], None).unwrap_err();

        assert!(matches!(
            err,
            HippocampusError::PatternSizeMismatch { expected: 16, got: 3 }
        ));
        assert_eq!(hippo.memory_count(), 0);
    }

    #[test]
    fn test_replay() {
        let mut hippo = Hippocampus::with_config(small_config(32));

        for i in 0..5 {
            let input: Vec<f64> = (0..32).map(|j| (i + j) as f64 / 32.0).collect();
            hippo.encode(&input, None).unwrap();
        }

        let replayed = hippo.replay(3);
        assert!(replayed.len() <= 3);
        // Only ids of stored memories are ever reported as replayed.
        assert!(replayed.iter().all(|id| hippo.retrieve_by_id(id).is_some()));
    }

    #[test]
    fn test_theta_oscillation() {
        let mut hippo = Hippocampus::new();

        hippo.step_theta(0.01);
        assert!(hippo.theta_phase() > 0.0);

        for _ in 0..1000 {
            hippo.step_theta(0.001);
        }
        assert!(hippo.theta_phase() < 2.0 * std::f64::consts::PI);
    }

    #[test]
    fn test_decay() {
        let mut hippo = Hippocampus::with_config(small_config(16));

        let id = hippo.encode(&[0.5; 16], None).unwrap();
        let initial_strength = hippo.retrieve_by_id(&id).unwrap().strength;

        hippo.decay(0.9);

        let new_strength = hippo.retrieve_by_id(&id).unwrap().strength;
        assert!(new_strength < initial_strength);
    }

    #[test]
    fn test_decay_forgets_weak_traces() {
        let mut hippo = Hippocampus::with_config(small_config(16));
        hippo.encode(&[0.5; 16], None).unwrap();

        // A zero factor drives every strength to 0.0, below the forget threshold.
        hippo.decay(0.0);

        assert_eq!(hippo.memory_count(), 0);
    }
}