1use crate::embedding::bridge::EmbeddingBridge;
2use crate::memory::persistence::MemorySnapshot;
3use crate::memory::training::TrainingOrchestrator;
4use crate::reasoning::basins::BasinAnalyzer;
5use crate::reasoning::composition::Compositor;
6use crate::reasoning::confidence::ConfidenceGate;
7use crate::reasoning::interference::InterferenceDetector;
8use crate::reasoning::surprise::SurpriseDetector;
9use crate::types::*;
10use crate::RaiError;
11use rem_nra::nra::{NRAConfig, NonlinearResonanceMemory};
12use rem_nra::rem::{REMConfig, ResidualEquilibriumMemory};
13use rem_nra::Vec64;
14use std::path::Path;
15use std::sync::Arc;
16use tokio::sync::Mutex;
17
18pub struct MemoryManager {
20 nra: Arc<Mutex<NonlinearResonanceMemory>>,
22 rem: Arc<Mutex<ResidualEquilibriumMemory>>,
24 bridge: Arc<EmbeddingBridge>,
26 confidence_gate: ConfidenceGate,
28 interference_detector: InterferenceDetector,
30 basin_analyzer: BasinAnalyzer,
32 surprise_detector: SurpriseDetector,
34 training: Arc<Mutex<TrainingOrchestrator>>,
36 texts: Arc<Mutex<Vec<String>>>,
38 next_id: Arc<Mutex<usize>>,
40}
41
42impl MemoryManager {
43 pub fn new(bridge: Arc<EmbeddingBridge>) -> Self {
45 let nra_config = NRAConfig {
46 dim_state: 64,
47 dim_omega: 32,
48 dim_value: 64,
49 num_units: 512,
50 train_epochs: 300,
51 ..Default::default()
52 };
53 let rem_config = REMConfig {
54 dim_memory: 256,
55 dim_key: 32,
56 dim_value: 64,
57 ..Default::default()
58 };
59
60 let mut rng = rand::thread_rng();
61 let nra = NonlinearResonanceMemory::new(nra_config, &mut rng);
62 let rem = ResidualEquilibriumMemory::new(rem_config, &mut rng);
63
64 Self {
65 nra: Arc::new(Mutex::new(nra)),
66 rem: Arc::new(Mutex::new(rem)),
67 bridge,
68 confidence_gate: ConfidenceGate::default(),
69 interference_detector: InterferenceDetector::default(),
70 basin_analyzer: BasinAnalyzer::default(),
71 surprise_detector: SurpriseDetector::default(),
72 training: Arc::new(Mutex::new(TrainingOrchestrator::default())),
73 texts: Arc::new(Mutex::new(Vec::new())),
74 next_id: Arc::new(Mutex::new(0)),
75 }
76 }
77
78 pub fn with_configs(
80 bridge: Arc<EmbeddingBridge>,
81 nra_config: NRAConfig,
82 rem_config: REMConfig,
83 ) -> Self {
84 let mut rng = rand::thread_rng();
85 let nra = NonlinearResonanceMemory::new(nra_config, &mut rng);
86 let rem = ResidualEquilibriumMemory::new(rem_config, &mut rng);
87
88 Self {
89 nra: Arc::new(Mutex::new(nra)),
90 rem: Arc::new(Mutex::new(rem)),
91 bridge,
92 confidence_gate: ConfidenceGate::default(),
93 interference_detector: InterferenceDetector::default(),
94 basin_analyzer: BasinAnalyzer::default(),
95 surprise_detector: SurpriseDetector::default(),
96 training: Arc::new(Mutex::new(TrainingOrchestrator::default())),
97 texts: Arc::new(Mutex::new(Vec::new())),
98 next_id: Arc::new(Mutex::new(0)),
99 }
100 }
101
102 pub async fn store(&self, content: &str) -> Result<InterferenceReport, RaiError> {
104 let (omega, key, value) = self.bridge.embed_text(content).await?;
105
106 let energy_before = {
108 let nra = self.nra.lock().await;
109 nra.energy_snapshot()
110 };
111
112 {
114 let mut nra = self.nra.lock().await;
115 nra.store(&omega, &value)
116 .map_err(|e| RaiError::MemoryError(format!("NRA store: {e}")))?;
117 }
118
119 {
121 let mut rem = self.rem.lock().await;
122 rem.store(&key, &value)
123 .map_err(|e| RaiError::MemoryError(format!("REM store: {e}")))?;
124 }
125
126 {
128 let mut texts = self.texts.lock().await;
129 texts.push(content.to_string());
130 }
131
132 {
134 let mut id = self.next_id.lock().await;
135 *id += 1;
136 }
137
138 {
140 let mut training = self.training.lock().await;
141 training.item_stored();
142 }
143
144 let energy_after = {
146 let nra = self.nra.lock().await;
147 nra.energy_snapshot()
148 };
149
150 let texts = self.texts.lock().await;
152 let len = energy_before.len();
154 let report =
155 self.interference_detector
156 .detect(&energy_before, &energy_after[..len], &texts[..len]);
157
158 Ok(report)
159 }
160
161 pub async fn recall(&self, query: &str) -> Result<RetrievalResult, RaiError> {
163 let omega = self.bridge.text_to_omega(query).await?;
164
165 let diagnostics = {
166 let nra = self.nra.lock().await;
167 nra.retrieve_with_diagnostics(&omega)
168 .map_err(|e| RaiError::MemoryError(format!("NRA retrieve: {e}")))?
169 };
170
171 let confidence = self
172 .confidence_gate
173 .classify(diagnostics.energy, diagnostics.grad_norm);
174 let explanation =
175 self.confidence_gate
176 .explain(diagnostics.energy, diagnostics.grad_norm, confidence);
177
178 let content = self
180 .bridge
181 .nearest_text(&diagnostics.value)
182 .await
183 .unwrap_or_else(|| "(no matching text found)".to_string());
184
185 Ok(RetrievalResult {
186 content,
187 confidence,
188 energy: diagnostics.energy,
189 steps: diagnostics.steps,
190 grad_norm: diagnostics.grad_norm,
191 explanation,
192 })
193 }
194
195 pub async fn intersect(&self, concepts: &[String]) -> Result<IntersectionResult, RaiError> {
197 if concepts.is_empty() {
198 return Err(RaiError::InvalidInput("no concepts provided".into()));
199 }
200
201 let mut omegas = Vec::with_capacity(concepts.len());
203 for concept in concepts {
204 let omega = self.bridge.text_to_omega(concept).await?;
205 omegas.push(omega);
206 }
207
208 let combined = Compositor::intersect(&omegas);
210
211 let diagnostics = {
213 let nra = self.nra.lock().await;
214 nra.retrieve_with_diagnostics(&combined)
215 .map_err(|e| RaiError::MemoryError(format!("NRA intersect: {e}")))?
216 };
217
218 let confidence = self
219 .confidence_gate
220 .classify(diagnostics.energy, diagnostics.grad_norm);
221
222 let content = self
223 .bridge
224 .nearest_text(&diagnostics.value)
225 .await
226 .unwrap_or_else(|| "(no matching text at intersection)".to_string());
227
228 Ok(IntersectionResult {
229 content,
230 confidence,
231 energy: diagnostics.energy,
232 concepts: concepts.to_vec(),
233 })
234 }
235
236 pub async fn check_contradiction(&self, fact: &str) -> Result<InterferenceReport, RaiError> {
238 let (omega, _key, value) = self.bridge.embed_text(fact).await?;
239
240 let energy_before = {
242 let nra = self.nra.lock().await;
243 nra.energy_snapshot()
244 };
245
246 {
248 let mut nra = self.nra.lock().await;
249 nra.store(&omega, &value)
250 .map_err(|e| RaiError::MemoryError(format!("NRA contradict: {e}")))?;
251 }
252
253 let energy_after = {
255 let nra = self.nra.lock().await;
256 nra.energy_snapshot()
257 };
258
259 let texts = self.texts.lock().await;
260 let len = energy_before.len();
261 let report =
262 self.interference_detector
263 .detect(&energy_before, &energy_after[..len], &texts[..len]);
264
265 Ok(report)
266 }
267
268 pub async fn measure_surprise(&self, content: &str) -> Result<SurpriseResult, RaiError> {
270 let (_omega, _key, _value) = self.bridge.embed_text(content).await?;
271
272 let rem = self.rem.lock().await;
274 let residual_norm = rem.mean_residual_norm();
275
276 Ok(self.surprise_detector.score(residual_norm))
277 }
278
279 pub async fn explain_confidence(&self, query: &str) -> Result<ConfidenceExplanation, RaiError> {
281 let omega = self.bridge.text_to_omega(query).await?;
282
283 let nra = self.nra.lock().await;
284 let diagnostics = nra
285 .retrieve_with_diagnostics(&omega)
286 .map_err(|e| RaiError::MemoryError(format!("NRA explain: {e}")))?;
287
288 let mut confidence = self
289 .confidence_gate
290 .classify(diagnostics.energy, diagnostics.grad_norm);
291
292 let mut rng = rand::thread_rng();
294 let basin_result = self
295 .basin_analyzer
296 .analyze(&nra.params, &omega, &nra.config, &mut rng);
297
298 if basin_result.is_ambiguous {
299 confidence = ConfidenceLevel::Ambiguous;
300 }
301
302 let explanation =
303 self.confidence_gate
304 .explain(diagnostics.energy, diagnostics.grad_norm, confidence);
305
306 Ok(ConfidenceExplanation {
307 confidence,
308 energy: diagnostics.energy,
309 grad_norm: diagnostics.grad_norm,
310 num_attractors: basin_result.attractors.len(),
311 basin_spread: basin_result.energy_spread,
312 explanation,
313 })
314 }
315
316 pub async fn health(&self) -> Result<HealthReport, RaiError> {
318 let nra = self.nra.lock().await;
319 let rem = self.rem.lock().await;
320
321 let nra_mse = nra.mse().ok();
322 let rem_mse = rem.mse().ok();
323 let rem_residual_norm = if rem.is_empty() {
324 None
325 } else {
326 Some(rem.mean_residual_norm())
327 };
328
329 let num_memories = nra.len();
330 let nra_capacity_ratio = num_memories as f64 / nra.config.num_units as f64;
331
332 let training = self.training.lock().await;
333 let needs_training = training.needs_nra_retrain() || training.needs_rem_retrain();
334
335 Ok(HealthReport {
336 num_memories,
337 nra_mse,
338 rem_mse,
339 rem_residual_norm,
340 nra_capacity_ratio,
341 needs_training,
342 })
343 }
344
345 pub async fn train_nra(&self) -> Result<Vec<f64>, RaiError> {
347 let nra = self.nra.clone();
348 let handle = TrainingOrchestrator::spawn_nra_retrain(nra);
349 let result = handle
350 .await
351 .map_err(|e| RaiError::TrainingError(format!("join: {e}")))?;
352 {
353 let mut training = self.training.lock().await;
354 training.nra_trained();
355 }
356 result
357 }
358
359 pub async fn train_rem(&self) -> Result<Vec<f64>, RaiError> {
361 let rem = self.rem.clone();
362 let handle = TrainingOrchestrator::spawn_rem_retrain(rem);
363 let result = handle
364 .await
365 .map_err(|e| RaiError::TrainingError(format!("join: {e}")))?;
366 {
367 let mut training = self.training.lock().await;
368 training.rem_trained();
369 }
370 result
371 }
372
373 pub async fn save(&self, path: &Path) -> Result<(), RaiError> {
375 let nra = self.nra.lock().await;
376 let rem = self.rem.lock().await;
377
378 let snapshot = MemorySnapshot {
379 nra_params: nra.params.clone(),
380 nra_config: nra.config.clone(),
381 nra_items: nra.items().to_vec(),
382 rem_config: rem.config.clone(),
383 rem_encoder: rem.encoder.clone(),
384 rem_decoder: rem.decoder.clone(),
385 rem_memory_state: rem.memory_state.clone(),
386 rem_items: rem.items().to_vec(),
387 text_index: self.bridge.text_index().await,
388 omega_proj: self.bridge.omega_proj.clone(),
389 key_proj: self.bridge.key_proj.clone(),
390 value_proj: self.bridge.value_proj.clone(),
391 total_items: nra.len(),
392 };
393
394 snapshot.save(path)
395 }
396
397 pub async fn load(path: &Path, bridge: Arc<EmbeddingBridge>) -> Result<Self, RaiError> {
399 let snapshot = MemorySnapshot::load(path)?;
400
401 let mut nra =
403 NonlinearResonanceMemory::from_params(snapshot.nra_params, snapshot.nra_config);
404 for (omega, value) in &snapshot.nra_items {
405 nra.store(omega, value)
406 .map_err(|e| RaiError::MemoryError(format!("restore NRA: {e}")))?;
407 }
408
409 let mut rng = rand::thread_rng();
411 let mut rem = ResidualEquilibriumMemory::new(snapshot.rem_config, &mut rng);
412 rem.encoder = snapshot.rem_encoder;
414 rem.decoder = snapshot.rem_decoder;
415 rem.memory_state = snapshot.rem_memory_state;
416 for (key, value) in &snapshot.rem_items {
417 rem.store(key, value)
419 .map_err(|e| RaiError::MemoryError(format!("restore REM: {e}")))?;
420 }
421
422 bridge.restore_text_index(snapshot.text_index.clone()).await;
424
425 let texts: Vec<String> = snapshot
426 .text_index
427 .entries
428 .iter()
429 .map(|e| e.text.clone())
430 .collect();
431
432 Ok(Self {
433 nra: Arc::new(Mutex::new(nra)),
434 rem: Arc::new(Mutex::new(rem)),
435 bridge,
436 confidence_gate: ConfidenceGate::default(),
437 interference_detector: InterferenceDetector::default(),
438 basin_analyzer: BasinAnalyzer::default(),
439 surprise_detector: SurpriseDetector::default(),
440 training: Arc::new(Mutex::new(TrainingOrchestrator::default())),
441 texts: Arc::new(Mutex::new(texts)),
442 next_id: Arc::new(Mutex::new(snapshot.total_items)),
443 })
444 }
445
446 pub async fn len(&self) -> usize {
448 let nra = self.nra.lock().await;
449 nra.len()
450 }
451
452 pub async fn energy_snapshot(&self) -> Vec<(Vec64, f64)> {
454 let nra = self.nra.lock().await;
455 nra.energy_snapshot()
456 }
457}