Skip to main content

rai_core/memory/
manager.rs

1use crate::embedding::bridge::EmbeddingBridge;
2use crate::memory::persistence::MemorySnapshot;
3use crate::memory::training::TrainingOrchestrator;
4use crate::reasoning::basins::BasinAnalyzer;
5use crate::reasoning::composition::Compositor;
6use crate::reasoning::confidence::ConfidenceGate;
7use crate::reasoning::interference::InterferenceDetector;
8use crate::reasoning::surprise::SurpriseDetector;
9use crate::types::*;
10use crate::RaiError;
11use rem_nra::nra::{NRAConfig, NonlinearResonanceMemory};
12use rem_nra::rem::{REMConfig, ResidualEquilibriumMemory};
13use rem_nra::Vec64;
14use std::path::Path;
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
18/// Central memory manager that orchestrates NRA, REM, embedding, and reasoning.
19pub struct MemoryManager {
20    /// NRA memory for nonlinear addressing.
21    nra: Arc<Mutex<NonlinearResonanceMemory>>,
22    /// REM memory for structure-aware storage.
23    rem: Arc<Mutex<ResidualEquilibriumMemory>>,
24    /// Embedding bridge for text <-> vector conversion.
25    bridge: Arc<EmbeddingBridge>,
26    /// Confidence gating module.
27    confidence_gate: ConfidenceGate,
28    /// Interference detector.
29    interference_detector: InterferenceDetector,
30    /// Basin analyzer.
31    basin_analyzer: BasinAnalyzer,
32    /// Surprise detector.
33    surprise_detector: SurpriseDetector,
34    /// Training orchestrator.
35    training: Arc<Mutex<TrainingOrchestrator>>,
36    /// Text labels for stored memories (parallel to NRA items).
37    texts: Arc<Mutex<Vec<String>>>,
38    /// Next memory ID.
39    next_id: Arc<Mutex<usize>>,
40}
41
42impl MemoryManager {
43    /// Create a new MemoryManager with default configurations.
44    pub fn new(bridge: Arc<EmbeddingBridge>) -> Self {
45        let nra_config = NRAConfig {
46            dim_state: 64,
47            dim_omega: 32,
48            dim_value: 64,
49            num_units: 512,
50            train_epochs: 300,
51            ..Default::default()
52        };
53        let rem_config = REMConfig {
54            dim_memory: 256,
55            dim_key: 32,
56            dim_value: 64,
57            ..Default::default()
58        };
59
60        let mut rng = rand::thread_rng();
61        let nra = NonlinearResonanceMemory::new(nra_config, &mut rng);
62        let rem = ResidualEquilibriumMemory::new(rem_config, &mut rng);
63
64        Self {
65            nra: Arc::new(Mutex::new(nra)),
66            rem: Arc::new(Mutex::new(rem)),
67            bridge,
68            confidence_gate: ConfidenceGate::default(),
69            interference_detector: InterferenceDetector::default(),
70            basin_analyzer: BasinAnalyzer::default(),
71            surprise_detector: SurpriseDetector::default(),
72            training: Arc::new(Mutex::new(TrainingOrchestrator::default())),
73            texts: Arc::new(Mutex::new(Vec::new())),
74            next_id: Arc::new(Mutex::new(0)),
75        }
76    }
77
78    /// Create with custom NRA/REM configs.
79    pub fn with_configs(
80        bridge: Arc<EmbeddingBridge>,
81        nra_config: NRAConfig,
82        rem_config: REMConfig,
83    ) -> Self {
84        let mut rng = rand::thread_rng();
85        let nra = NonlinearResonanceMemory::new(nra_config, &mut rng);
86        let rem = ResidualEquilibriumMemory::new(rem_config, &mut rng);
87
88        Self {
89            nra: Arc::new(Mutex::new(nra)),
90            rem: Arc::new(Mutex::new(rem)),
91            bridge,
92            confidence_gate: ConfidenceGate::default(),
93            interference_detector: InterferenceDetector::default(),
94            basin_analyzer: BasinAnalyzer::default(),
95            surprise_detector: SurpriseDetector::default(),
96            training: Arc::new(Mutex::new(TrainingOrchestrator::default())),
97            texts: Arc::new(Mutex::new(Vec::new())),
98            next_id: Arc::new(Mutex::new(0)),
99        }
100    }
101
102    /// Store a fact and return an interference report.
103    pub async fn store(&self, content: &str) -> Result<InterferenceReport, RaiError> {
104        let (omega, key, value) = self.bridge.embed_text(content).await?;
105
106        // Take energy snapshot before
107        let energy_before = {
108            let nra = self.nra.lock().await;
109            nra.energy_snapshot()
110        };
111
112        // Store in NRA
113        {
114            let mut nra = self.nra.lock().await;
115            nra.store(&omega, &value)
116                .map_err(|e| RaiError::MemoryError(format!("NRA store: {e}")))?;
117        }
118
119        // Store in REM
120        {
121            let mut rem = self.rem.lock().await;
122            rem.store(&key, &value)
123                .map_err(|e| RaiError::MemoryError(format!("REM store: {e}")))?;
124        }
125
126        // Record text
127        {
128            let mut texts = self.texts.lock().await;
129            texts.push(content.to_string());
130        }
131
132        // Increment ID
133        {
134            let mut id = self.next_id.lock().await;
135            *id += 1;
136        }
137
138        // Notify training orchestrator
139        {
140            let mut training = self.training.lock().await;
141            training.item_stored();
142        }
143
144        // Take energy snapshot after
145        let energy_after = {
146            let nra = self.nra.lock().await;
147            nra.energy_snapshot()
148        };
149
150        // Detect interference
151        let texts = self.texts.lock().await;
152        // The before snapshot has one fewer item, so only compare overlapping items
153        let len = energy_before.len();
154        let report =
155            self.interference_detector
156                .detect(&energy_before, &energy_after[..len], &texts[..len]);
157
158        Ok(report)
159    }
160
161    /// Recall a memory with confidence diagnostics.
162    pub async fn recall(&self, query: &str) -> Result<RetrievalResult, RaiError> {
163        let omega = self.bridge.text_to_omega(query).await?;
164
165        let diagnostics = {
166            let nra = self.nra.lock().await;
167            nra.retrieve_with_diagnostics(&omega)
168                .map_err(|e| RaiError::MemoryError(format!("NRA retrieve: {e}")))?
169        };
170
171        let confidence = self
172            .confidence_gate
173            .classify(diagnostics.energy, diagnostics.grad_norm);
174        let explanation =
175            self.confidence_gate
176                .explain(diagnostics.energy, diagnostics.grad_norm, confidence);
177
178        // Find nearest text
179        let content = self
180            .bridge
181            .nearest_text(&diagnostics.value)
182            .await
183            .unwrap_or_else(|| "(no matching text found)".to_string());
184
185        Ok(RetrievalResult {
186            content,
187            confidence,
188            energy: diagnostics.energy,
189            steps: diagnostics.steps,
190            grad_norm: diagnostics.grad_norm,
191            explanation,
192        })
193    }
194
195    /// Query at concept intersection using compositional addressing.
196    pub async fn intersect(&self, concepts: &[String]) -> Result<IntersectionResult, RaiError> {
197        if concepts.is_empty() {
198            return Err(RaiError::InvalidInput("no concepts provided".into()));
199        }
200
201        // Embed each concept to omega
202        let mut omegas = Vec::with_capacity(concepts.len());
203        for concept in concepts {
204            let omega = self.bridge.text_to_omega(concept).await?;
205            omegas.push(omega);
206        }
207
208        // Compose omegas
209        let combined = Compositor::intersect(&omegas);
210
211        // Retrieve at intersection
212        let diagnostics = {
213            let nra = self.nra.lock().await;
214            nra.retrieve_with_diagnostics(&combined)
215                .map_err(|e| RaiError::MemoryError(format!("NRA intersect: {e}")))?
216        };
217
218        let confidence = self
219            .confidence_gate
220            .classify(diagnostics.energy, diagnostics.grad_norm);
221
222        let content = self
223            .bridge
224            .nearest_text(&diagnostics.value)
225            .await
226            .unwrap_or_else(|| "(no matching text at intersection)".to_string());
227
228        Ok(IntersectionResult {
229            content,
230            confidence,
231            energy: diagnostics.energy,
232            concepts: concepts.to_vec(),
233        })
234    }
235
236    /// Check if a new fact contradicts existing memory.
237    pub async fn check_contradiction(&self, fact: &str) -> Result<InterferenceReport, RaiError> {
238        let (omega, _key, value) = self.bridge.embed_text(fact).await?;
239
240        // Snapshot before
241        let energy_before = {
242            let nra = self.nra.lock().await;
243            nra.energy_snapshot()
244        };
245
246        // Temporarily store (we'll revert if needed)
247        {
248            let mut nra = self.nra.lock().await;
249            nra.store(&omega, &value)
250                .map_err(|e| RaiError::MemoryError(format!("NRA contradict: {e}")))?;
251        }
252
253        // Snapshot after
254        let energy_after = {
255            let nra = self.nra.lock().await;
256            nra.energy_snapshot()
257        };
258
259        let texts = self.texts.lock().await;
260        let len = energy_before.len();
261        let report =
262            self.interference_detector
263                .detect(&energy_before, &energy_after[..len], &texts[..len]);
264
265        Ok(report)
266    }
267
268    /// Measure novelty/surprise of a fact using REM prior.
269    pub async fn measure_surprise(&self, content: &str) -> Result<SurpriseResult, RaiError> {
270        let (_omega, _key, _value) = self.bridge.embed_text(content).await?;
271
272        // Get REM prior prediction
273        let rem = self.rem.lock().await;
274        let residual_norm = rem.mean_residual_norm();
275
276        Ok(self.surprise_detector.score(residual_norm))
277    }
278
279    /// Explain the confidence of a retrieval in detail.
280    pub async fn explain_confidence(&self, query: &str) -> Result<ConfidenceExplanation, RaiError> {
281        let omega = self.bridge.text_to_omega(query).await?;
282
283        let nra = self.nra.lock().await;
284        let diagnostics = nra
285            .retrieve_with_diagnostics(&omega)
286            .map_err(|e| RaiError::MemoryError(format!("NRA explain: {e}")))?;
287
288        let mut confidence = self
289            .confidence_gate
290            .classify(diagnostics.energy, diagnostics.grad_norm);
291
292        // Basin analysis
293        let mut rng = rand::thread_rng();
294        let basin_result = self
295            .basin_analyzer
296            .analyze(&nra.params, &omega, &nra.config, &mut rng);
297
298        if basin_result.is_ambiguous {
299            confidence = ConfidenceLevel::Ambiguous;
300        }
301
302        let explanation =
303            self.confidence_gate
304                .explain(diagnostics.energy, diagnostics.grad_norm, confidence);
305
306        Ok(ConfidenceExplanation {
307            confidence,
308            energy: diagnostics.energy,
309            grad_norm: diagnostics.grad_norm,
310            num_attractors: basin_result.attractors.len(),
311            basin_spread: basin_result.energy_spread,
312            explanation,
313        })
314    }
315
316    /// Get system health diagnostics.
317    pub async fn health(&self) -> Result<HealthReport, RaiError> {
318        let nra = self.nra.lock().await;
319        let rem = self.rem.lock().await;
320
321        let nra_mse = nra.mse().ok();
322        let rem_mse = rem.mse().ok();
323        let rem_residual_norm = if rem.is_empty() {
324            None
325        } else {
326            Some(rem.mean_residual_norm())
327        };
328
329        let num_memories = nra.len();
330        let nra_capacity_ratio = num_memories as f64 / nra.config.num_units as f64;
331
332        let training = self.training.lock().await;
333        let needs_training = training.needs_nra_retrain() || training.needs_rem_retrain();
334
335        Ok(HealthReport {
336            num_memories,
337            nra_mse,
338            rem_mse,
339            rem_residual_norm,
340            nra_capacity_ratio,
341            needs_training,
342        })
343    }
344
345    /// Trigger NRA retraining.
346    pub async fn train_nra(&self) -> Result<Vec<f64>, RaiError> {
347        let nra = self.nra.clone();
348        let handle = TrainingOrchestrator::spawn_nra_retrain(nra);
349        let result = handle
350            .await
351            .map_err(|e| RaiError::TrainingError(format!("join: {e}")))?;
352        {
353            let mut training = self.training.lock().await;
354            training.nra_trained();
355        }
356        result
357    }
358
359    /// Trigger REM retraining.
360    pub async fn train_rem(&self) -> Result<Vec<f64>, RaiError> {
361        let rem = self.rem.clone();
362        let handle = TrainingOrchestrator::spawn_rem_retrain(rem);
363        let result = handle
364            .await
365            .map_err(|e| RaiError::TrainingError(format!("join: {e}")))?;
366        {
367            let mut training = self.training.lock().await;
368            training.rem_trained();
369        }
370        result
371    }
372
373    /// Save full state to disk.
374    pub async fn save(&self, path: &Path) -> Result<(), RaiError> {
375        let nra = self.nra.lock().await;
376        let rem = self.rem.lock().await;
377
378        let snapshot = MemorySnapshot {
379            nra_params: nra.params.clone(),
380            nra_config: nra.config.clone(),
381            nra_items: nra.items().to_vec(),
382            rem_config: rem.config.clone(),
383            rem_encoder: rem.encoder.clone(),
384            rem_decoder: rem.decoder.clone(),
385            rem_memory_state: rem.memory_state.clone(),
386            rem_items: rem.items().to_vec(),
387            text_index: self.bridge.text_index().await,
388            omega_proj: self.bridge.omega_proj.clone(),
389            key_proj: self.bridge.key_proj.clone(),
390            value_proj: self.bridge.value_proj.clone(),
391            total_items: nra.len(),
392        };
393
394        snapshot.save(path)
395    }
396
397    /// Load state from disk. Returns a new MemoryManager.
398    pub async fn load(path: &Path, bridge: Arc<EmbeddingBridge>) -> Result<Self, RaiError> {
399        let snapshot = MemorySnapshot::load(path)?;
400
401        // Reconstruct NRA
402        let mut nra =
403            NonlinearResonanceMemory::from_params(snapshot.nra_params, snapshot.nra_config);
404        for (omega, value) in &snapshot.nra_items {
405            nra.store(omega, value)
406                .map_err(|e| RaiError::MemoryError(format!("restore NRA: {e}")))?;
407        }
408
409        // Reconstruct REM
410        let mut rng = rand::thread_rng();
411        let mut rem = ResidualEquilibriumMemory::new(snapshot.rem_config, &mut rng);
412        // Restore encoder/decoder params
413        rem.encoder = snapshot.rem_encoder;
414        rem.decoder = snapshot.rem_decoder;
415        rem.memory_state = snapshot.rem_memory_state;
416        for (key, value) in &snapshot.rem_items {
417            // We only want to restore items, not re-encode (since we have the memory state)
418            rem.store(key, value)
419                .map_err(|e| RaiError::MemoryError(format!("restore REM: {e}")))?;
420        }
421
422        // Restore text index
423        bridge.restore_text_index(snapshot.text_index.clone()).await;
424
425        let texts: Vec<String> = snapshot
426            .text_index
427            .entries
428            .iter()
429            .map(|e| e.text.clone())
430            .collect();
431
432        Ok(Self {
433            nra: Arc::new(Mutex::new(nra)),
434            rem: Arc::new(Mutex::new(rem)),
435            bridge,
436            confidence_gate: ConfidenceGate::default(),
437            interference_detector: InterferenceDetector::default(),
438            basin_analyzer: BasinAnalyzer::default(),
439            surprise_detector: SurpriseDetector::default(),
440            training: Arc::new(Mutex::new(TrainingOrchestrator::default())),
441            texts: Arc::new(Mutex::new(texts)),
442            next_id: Arc::new(Mutex::new(snapshot.total_items)),
443        })
444    }
445
446    /// Get number of stored memories.
447    pub async fn len(&self) -> usize {
448        let nra = self.nra.lock().await;
449        nra.len()
450    }
451
452    /// Take an NRA energy snapshot for external use.
453    pub async fn energy_snapshot(&self) -> Vec<(Vec64, f64)> {
454        let nra = self.nra.lock().await;
455        nra.energy_snapshot()
456    }
457}