1use anyhow::{anyhow, Result};
5use scirs2_core::ndarray_ext::{s, Array1, Array2};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::time::{Duration, Instant};
9use tracing::info;
10use uuid::Uuid;
11
12use crate::memory_nets_controller::{
13 DNCConfig, DifferentiableNeuralComputer, NTMConfig, NeuralTuringMachine,
14};
15
/// Configuration for [`MemoryNetworks`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryNetworksConfig {
    /// Maximum number of stored memories; also the row count of the
    /// embedding matrix allocated by `MemoryNetworks::new`.
    pub memory_capacity: usize,
    /// Length of each embedding vector: the column count of the embedding
    /// matrix and the side length of the square encoder matrices.
    pub embedding_dim: usize,
    /// Number of attention hops `MemoryNetworks::query` performs.
    pub num_hops: usize,
    /// Learning rate. NOTE(review): not read anywhere in the visible part of
    /// this file — confirm where (or whether) it is consumed.
    pub learning_rate: f32,
}
24
25impl Default for MemoryNetworksConfig {
26 fn default() -> Self {
27 Self {
28 memory_capacity: 1000,
29 embedding_dim: 128,
30 num_hops: 3,
31 learning_rate: 0.01,
32 }
33 }
34}
35
/// Attention-addressed memory: a fixed-size matrix of embeddings with their
/// associated text content, plus three square projection matrices.
pub struct MemoryNetworks {
    /// Sizing and hop-count parameters.
    pub(crate) config: MemoryNetworksConfig,
    /// `memory_capacity x embedding_dim` matrix, zero-initialised; row `i`
    /// holds the embedding stored alongside `memory_content[i]`.
    pub(crate) memory_embeddings: Array2<f32>,
    /// Text payload of each stored memory; its length is the number of
    /// occupied rows in `memory_embeddings`.
    pub(crate) memory_content: Vec<String>,
    /// Randomly initialised `embedding_dim x embedding_dim` matrix.
    /// NOTE(review): not read in the visible part of this file.
    pub(crate) input_encoder: Array2<f32>,
    /// Randomly initialised `embedding_dim x embedding_dim` matrix; `query`
    /// applies it to each hop's response to form the next hop's query.
    pub(crate) output_encoder: Array2<f32>,
    /// Randomly initialised `embedding_dim x embedding_dim` matrix.
    /// NOTE(review): not read in the visible part of this file.
    pub(crate) query_encoder: Array2<f32>,
}
45
46impl MemoryNetworks {
47 pub fn new(config: MemoryNetworksConfig) -> Self {
48 use scirs2_core::random::Random;
49 let mut rng = Random::default();
50
51 let memory_embeddings = Array2::zeros((config.memory_capacity, config.embedding_dim));
52 let memory_content = Vec::new();
53
54 let input_encoder =
55 Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
56 rng.random_range(-0.1..0.1)
57 });
58 let output_encoder =
59 Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
60 rng.random_range(-0.1..0.1)
61 });
62 let query_encoder =
63 Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
64 rng.random_range(-0.1..0.1)
65 });
66
67 Self {
68 config,
69 memory_embeddings,
70 memory_content,
71 input_encoder,
72 output_encoder,
73 query_encoder,
74 }
75 }
76
77 pub fn store_memory(&mut self, content: String, embedding: Array1<f32>) -> Result<()> {
79 if self.memory_content.len() < self.config.memory_capacity {
80 let index = self.memory_content.len();
81 self.memory_content.push(content);
82 if embedding.len() == self.config.embedding_dim {
83 self.memory_embeddings.row_mut(index).assign(&embedding);
84 } else {
85 return Err(anyhow!("Embedding dimension mismatch"));
86 }
87 } else {
88 let index = 0;
89 self.memory_content[index] = content;
90 self.memory_embeddings.row_mut(index).assign(&embedding);
91 for i in 1..self.memory_content.len() {
92 self.memory_content.swap(i - 1, i);
93 let row1 = self.memory_embeddings.row(i - 1).to_owned();
94 let row2 = self.memory_embeddings.row(i).to_owned();
95 self.memory_embeddings.row_mut(i - 1).assign(&row2);
96 self.memory_embeddings.row_mut(i).assign(&row1);
97 }
98 }
99 Ok(())
100 }
101
102 pub fn query(&self, query_embedding: &Array1<f32>) -> Result<Array1<f32>> {
104 let num_memories = self.memory_content.len();
105 if num_memories == 0 {
106 return Ok(Array1::zeros(self.config.embedding_dim));
107 }
108
109 let mut response = Array1::zeros(self.config.embedding_dim);
110 let mut current_query = query_embedding.clone();
111
112 for _hop in 0..self.config.num_hops {
113 let attention_weights = self.compute_attention(¤t_query)?;
114 let active_embeddings = self
116 .memory_embeddings
117 .slice(scirs2_core::ndarray_ext::s![..num_memories, ..]);
118 let memory_response = active_embeddings.t().dot(&attention_weights);
119 current_query = self.output_encoder.dot(&memory_response);
120 response = memory_response;
121 }
122
123 Ok(response)
124 }
125
126 fn compute_attention(&self, query: &Array1<f32>) -> Result<Array1<f32>> {
127 let num_memories = self.memory_content.len();
128 if num_memories == 0 {
129 return Ok(Array1::zeros(0));
130 }
131
132 let mut attention_scores = Array1::zeros(num_memories);
133 for i in 0..num_memories {
134 let memory_embedding = self.memory_embeddings.row(i);
135 attention_scores[i] = query.dot(&memory_embedding);
136 }
137
138 let max_score = attention_scores.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
139 let exp_scores = attention_scores.map(|&x| (x - max_score).exp());
140 let sum_exp = exp_scores.sum();
141
142 if sum_exp > 0.0 {
143 Ok(exp_scores / sum_exp)
144 } else {
145 Ok(Array1::from_elem(num_memories, 1.0 / num_memories as f32))
146 }
147 }
148}
149
/// Configuration for [`EpisodicMemory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodicConfig {
    /// Number of completed episodes retained; `end_episode` drops the oldest
    /// one when this many are already stored.
    pub episode_capacity: usize,
    /// NOTE(review): not read in the visible part of this file — presumably a
    /// per-episode step limit; confirm against callers.
    pub episode_length: usize,
    /// NOTE(review): not read in the visible part of this file — presumably
    /// the expected state-vector length; confirm.
    pub embedding_dim: usize,
    /// NOTE(review): not read in the visible part of this file.
    pub decay_factor: f32,
}
158
159impl Default for EpisodicConfig {
160 fn default() -> Self {
161 Self {
162 episode_capacity: 100,
163 episode_length: 50,
164 embedding_dim: 128,
165 decay_factor: 0.95,
166 }
167 }
168}
169
/// Summary information attached to an [`Episode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodeMetadata {
    /// Caller-supplied label given to `EpisodicMemory::start_episode`.
    pub episode_type: String,
    /// Outcome flag; `false` until `EpisodicMemory::end_episode` sets it.
    pub success: bool,
    /// Number of recorded states, filled in by `end_episode`.
    pub length: usize,
    /// Mean of the recorded rewards (0.0 when there are none), filled in by
    /// `end_episode`.
    pub average_reward: f32,
    /// Free-form tags; created empty and not modified in the visible code.
    pub tags: Vec<String>,
}
179
/// A recorded sequence of states and rewards.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Random (v4) identifier assigned when the episode starts.
    pub id: Uuid,
    /// States in the order they were added.
    pub states: Vec<Array1<f32>>,
    /// Rewards, pushed together with the states (one per `add_state` call).
    pub rewards: Vec<f32>,
    /// Summary data; finalised by `EpisodicMemory::end_episode`.
    pub metadata: EpisodeMetadata,
    /// UTC time captured when the episode was started.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
189
/// Bounded store of completed episodes plus at most one episode in progress.
pub struct EpisodicMemory {
    /// Capacity and related settings.
    pub(crate) config: EpisodicConfig,
    /// Completed episodes, oldest at the front.
    pub(crate) episodes: VecDeque<Episode>,
    /// Episode currently being recorded, if any.
    pub(crate) current_episode: Option<Episode>,
}
196
197impl EpisodicMemory {
198 pub fn new(config: EpisodicConfig) -> Self {
199 Self {
200 config,
201 episodes: VecDeque::new(),
202 current_episode: None,
203 }
204 }
205
206 pub fn start_episode(&mut self, episode_type: String) {
207 let episode = Episode {
208 id: Uuid::new_v4(),
209 states: Vec::new(),
210 rewards: Vec::new(),
211 metadata: EpisodeMetadata {
212 episode_type,
213 success: false,
214 length: 0,
215 average_reward: 0.0,
216 tags: Vec::new(),
217 },
218 timestamp: chrono::Utc::now(),
219 };
220 self.current_episode = Some(episode);
221 }
222
223 pub fn add_state(&mut self, state: Array1<f32>, reward: f32) -> Result<()> {
224 if let Some(ref mut episode) = self.current_episode {
225 episode.states.push(state);
226 episode.rewards.push(reward);
227 Ok(())
228 } else {
229 Err(anyhow!("No active episode"))
230 }
231 }
232
233 pub fn end_episode(&mut self, success: bool) -> Result<()> {
234 if let Some(mut episode) = self.current_episode.take() {
235 episode.metadata.success = success;
236 episode.metadata.length = episode.states.len();
237 episode.metadata.average_reward = if episode.rewards.is_empty() {
238 0.0
239 } else {
240 episode.rewards.iter().sum::<f32>() / episode.rewards.len() as f32
241 };
242
243 if self.episodes.len() >= self.config.episode_capacity {
244 self.episodes.pop_front();
245 }
246 self.episodes.push_back(episode);
247 Ok(())
248 } else {
249 Err(anyhow!("No active episode"))
250 }
251 }
252
253 pub fn retrieve_similar_episodes(&self, query_state: &Array1<f32>, k: usize) -> Vec<&Episode> {
254 let mut similarities: Vec<(f32, &Episode)> = self
255 .episodes
256 .iter()
257 .map(|episode| {
258 let similarity = self.compute_episode_similarity(episode, query_state);
259 (similarity, episode)
260 })
261 .collect();
262
263 similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
264 similarities
265 .into_iter()
266 .take(k)
267 .map(|(_, episode)| episode)
268 .collect()
269 }
270
271 fn compute_episode_similarity(&self, episode: &Episode, query_state: &Array1<f32>) -> f32 {
272 if episode.states.is_empty() {
273 return 0.0;
274 }
275 let mut total_similarity = 0.0;
276 for state in &episode.states {
277 total_similarity += cosine_sim(query_state, state);
278 }
279 total_similarity / episode.states.len() as f32
280 }
281}
282
/// Configuration for [`RelationalMemoryCore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationalConfig {
    /// Number of rows (slots) in the relational memory matrix.
    pub memory_size: usize,
    /// Length of each memory row and side length of every relation matrix.
    pub embedding_dim: usize,
    /// Number of attention heads passed to `RelationalAttention::new`.
    pub num_heads: usize,
    /// Number of relation matrices created; valid relation types are
    /// `0..num_relation_types`.
    pub num_relation_types: usize,
}
291
292impl Default for RelationalConfig {
293 fn default() -> Self {
294 Self {
295 memory_size: 512,
296 embedding_dim: 256,
297 num_heads: 8,
298 num_relation_types: 10,
299 }
300 }
301}
302
/// Multi-head attention parameters used to read from a memory matrix.
pub struct RelationalAttention {
    /// `embed_dim x embed_dim` query projection; each head uses a contiguous
    /// band of its rows.
    pub(crate) query_weights: Array2<f32>,
    /// `embed_dim x embed_dim` key projection, sliced per head the same way.
    pub(crate) key_weights: Array2<f32>,
    /// `embed_dim x embed_dim` value projection, sliced per head the same way.
    pub(crate) value_weights: Array2<f32>,
    /// Number of heads; `forward` divides `embed_dim` by this value.
    pub(crate) num_heads: usize,
    /// Embedding length; also the length of the vector `forward` returns.
    pub(crate) embed_dim: usize,
}
311
312impl RelationalAttention {
313 pub fn new(embed_dim: usize, num_heads: usize) -> Self {
314 use scirs2_core::random::Random;
315 let mut rng = Random::default();
316
317 let query_weights =
318 Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
319 let key_weights =
320 Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
321 let value_weights =
322 Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
323
324 Self {
325 query_weights,
326 key_weights,
327 value_weights,
328 num_heads,
329 embed_dim,
330 }
331 }
332
333 pub fn forward(&self, memory: &Array2<f32>, query: &Array1<f32>) -> Array1<f32> {
334 let head_dim = self.embed_dim / self.num_heads;
335 let mut output = Array1::zeros(self.embed_dim);
336
337 for head in 0..self.num_heads {
338 let start_idx = head * head_dim;
339 let end_idx = (head + 1) * head_dim;
340
341 let q_head = self.query_weights.slice(s![start_idx..end_idx, ..]);
342 let k_head = self.key_weights.slice(s![start_idx..end_idx, ..]);
343 let v_head = self.value_weights.slice(s![start_idx..end_idx, ..]);
344
345 let q = q_head.dot(query);
346 let keys = memory.dot(&k_head.t());
347 let values = memory.dot(&v_head.t());
348
349 let mut scores = Array1::zeros(memory.nrows());
350 for i in 0..memory.nrows() {
351 scores[i] = q.dot(&keys.row(i));
352 }
353
354 let max_score = scores.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
355 let exp_scores = scores.map(|&x| (x - max_score).exp());
356 let sum_exp = exp_scores.sum();
357 let attention_weights = if sum_exp > 0.0 {
358 exp_scores / sum_exp
359 } else {
360 Array1::from_elem(memory.nrows(), 1.0 / memory.nrows() as f32)
361 };
362
363 let head_output = values.t().dot(&attention_weights);
364 output
365 .slice_mut(s![start_idx..end_idx])
366 .assign(&head_output);
367 }
368
369 output
370 }
371}
372
/// Slot-based memory of relation encodings, read through multi-head attention.
pub struct RelationalMemoryCore {
    /// Sizing parameters.
    pub(crate) config: RelationalConfig,
    /// `memory_size x embedding_dim` matrix, zero-initialised; `store_relation`
    /// writes one row per stored relation.
    pub(crate) memory: Array2<f32>,
    /// One `embedding_dim x embedding_dim` matrix per relation type, indexed
    /// by relation type.
    pub(crate) relation_matrices: Vec<Array2<f32>>,
    /// Attention used by `query_relations` to read from `memory`.
    pub(crate) attention_mechanism: RelationalAttention,
}
380
381impl RelationalMemoryCore {
382 pub fn new(config: RelationalConfig) -> Self {
383 use scirs2_core::random::Random;
384 let mut rng = Random::default();
385
386 let memory = Array2::zeros((config.memory_size, config.embedding_dim));
387 let mut relation_matrices = Vec::new();
388
389 for _ in 0..config.num_relation_types {
390 let relation_matrix =
391 Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
392 rng.random_range(-0.1..0.1)
393 });
394 relation_matrices.push(relation_matrix);
395 }
396
397 let attention_mechanism = RelationalAttention::new(config.embedding_dim, config.num_heads);
398 Self {
399 config,
400 memory,
401 relation_matrices,
402 attention_mechanism,
403 }
404 }
405
406 pub fn store_relation(
407 &mut self,
408 subject: &Array1<f32>,
409 relation_type: usize,
410 object: &Array1<f32>,
411 ) -> Result<()> {
412 if relation_type >= self.config.num_relation_types {
413 return Err(anyhow!("Invalid relation type"));
414 }
415 let relation_matrix = &self.relation_matrices[relation_type];
416 let transformed_subject = relation_matrix.dot(subject);
417 let transformed_object = relation_matrix.dot(object);
418
419 if let Some(slot) = self.find_empty_slot() {
420 let combined = &transformed_subject + &transformed_object;
421 self.memory.row_mut(slot).assign(&combined);
422 }
423 Ok(())
424 }
425
426 fn find_empty_slot(&self) -> Option<usize> {
427 (0..self.memory.nrows()).find(|&i| self.memory.row(i).sum() == 0.0)
428 }
429
430 pub fn query_relations(&self, query: &Array1<f32>) -> Array1<f32> {
431 self.attention_mechanism.forward(&self.memory, query)
432 }
433}
434
/// Configuration for [`SparseAccessMemory`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseConfig {
    /// Entry count at which `SparseAccessMemory::store` starts evicting.
    pub memory_capacity: usize,
    /// NOTE(review): not read in the visible part of this file — presumably
    /// the expected vector length; confirm.
    pub embedding_dim: usize,
    /// NOTE(review): not read in the visible part of this file.
    pub sparsity_factor: f32,
    /// NOTE(review): not read in the visible part of this file.
    pub update_threshold: f32,
}
443
444impl Default for SparseConfig {
445 fn default() -> Self {
446 Self {
447 memory_capacity: 10000,
448 embedding_dim: 512,
449 sparsity_factor: 0.1,
450 update_threshold: 0.01,
451 }
452 }
453}
454
/// Keyed vector store with per-key access bookkeeping used for eviction and
/// age-based cleanup.
pub struct SparseAccessMemory {
    /// Capacity and related settings.
    pub(crate) config: SparseConfig,
    /// Stored vectors by key.
    pub(crate) memory: HashMap<usize, Array1<f32>>,
    /// Access counter per key: set to 1 by `store`, incremented by `retrieve`.
    pub(crate) access_counts: HashMap<usize, usize>,
    /// Time of the most recent `store` or `retrieve` per key.
    pub(crate) last_access: HashMap<usize, Instant>,
}
462
463impl SparseAccessMemory {
464 pub fn new(config: SparseConfig) -> Self {
465 Self {
466 config,
467 memory: HashMap::new(),
468 access_counts: HashMap::new(),
469 last_access: HashMap::new(),
470 }
471 }
472
473 pub fn store(&mut self, key: usize, value: Array1<f32>) -> Result<()> {
474 if self.memory.len() >= self.config.memory_capacity {
475 self.evict_least_used()?;
476 }
477 self.memory.insert(key, value);
478 self.access_counts.insert(key, 1);
479 self.last_access.insert(key, Instant::now());
480 Ok(())
481 }
482
483 pub fn retrieve(&mut self, key: usize) -> Option<&Array1<f32>> {
484 if let Some(value) = self.memory.get(&key) {
485 *self.access_counts.entry(key).or_insert(0) += 1;
486 self.last_access.insert(key, Instant::now());
487 Some(value)
488 } else {
489 None
490 }
491 }
492
493 pub fn find_similar(&self, query: &Array1<f32>, k: usize) -> Vec<(usize, f32)> {
494 let mut similarities: Vec<(usize, f32)> = self
495 .memory
496 .iter()
497 .map(|(&key, value)| (key, cosine_sim(query, value)))
498 .collect();
499 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
500 similarities.into_iter().take(k).collect()
501 }
502
503 fn evict_least_used(&mut self) -> Result<()> {
504 let mut candidates: Vec<(usize, usize, Instant)> = self
505 .access_counts
506 .iter()
507 .map(|(&key, &count)| {
508 let last_access = self
509 .last_access
510 .get(&key)
511 .copied()
512 .unwrap_or(Instant::now());
513 (key, count, last_access)
514 })
515 .collect();
516 candidates.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2)));
517
518 if let Some((key_to_evict, _, _)) = candidates.first() {
519 let key = *key_to_evict;
520 self.memory.remove(&key);
521 self.access_counts.remove(&key);
522 self.last_access.remove(&key);
523 }
524 Ok(())
525 }
526
527 pub fn cleanup(&mut self, max_age: Duration) -> Result<usize> {
528 let now = Instant::now();
529 let mut keys_to_remove = Vec::new();
530
531 for (&key, &last_access) in &self.last_access {
532 if now.duration_since(last_access) > max_age {
533 keys_to_remove.push(key);
534 }
535 }
536
537 let removed_count = keys_to_remove.len();
538 for key in keys_to_remove {
539 self.memory.remove(&key);
540 self.access_counts.remove(&key);
541 self.last_access.remove(&key);
542 }
543 Ok(removed_count)
544 }
545}
546
/// Strategy label held by [`MemoryCoordinator`].
///
/// NOTE(review): the visible code only ever constructs `Adaptive` and never
/// branches on this enum; the per-variant behaviour is presumably implemented
/// elsewhere — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CoordinationStrategy {
    /// Round-robin selection (by name; no visible implementation).
    RoundRobin,
    /// Performance-based selection (by name; no visible implementation).
    PerformanceBased,
    /// Content-based selection (by name; no visible implementation).
    ContentBased,
    /// Value used by `MemoryAugmentedNetwork::new`.
    Adaptive,
}
555
/// Per-subsystem utilisation figures.
///
/// NOTE(review): every field is initialised to 0.0 by
/// `MemoryAugmentedNetwork::new` and never updated in the visible code; the
/// value ranges and units below are inferred from names only — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUsageStats {
    /// Utilisation of the DNC memory.
    pub dnc_utilization: f32,
    /// Utilisation of the NTM memory.
    pub ntm_utilization: f32,
    /// Utilisation of the memory-networks store.
    pub memory_networks_utilization: f32,
    /// Utilisation of the episodic memory.
    pub episodic_utilization: f32,
    /// Utilisation of the relational memory.
    pub relational_utilization: f32,
    /// Utilisation of the sparse-access memory.
    pub sparse_utilization: f32,
    /// Total memory footprint, presumably in megabytes.
    pub total_memory_mb: f32,
}
567
/// Rolling latency bookkeeping per memory type.
#[derive(Default)]
pub struct MemoryPerformanceTracker {
    /// Most recent latencies (at most 100) per memory type, oldest first.
    pub(crate) access_latencies: HashMap<String, VecDeque<f32>>,
    /// Hit rate per memory type; not written by this type's methods.
    pub(crate) hit_rates: HashMap<String, f32>,
    /// Throughput per memory type; not written by this type's methods.
    pub(crate) throughput_metrics: HashMap<String, f32>,
}

impl MemoryPerformanceTracker {
    /// Creates a tracker with no recorded data.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one latency sample for `memory_type`, keeping only the 100
    /// most recent samples.
    pub fn record_access(&mut self, memory_type: &str, latency_ms: f32) {
        let window = self
            .access_latencies
            .entry(memory_type.to_string())
            .or_default();
        window.push_back(latency_ms);

        // Drop the oldest samples beyond the 100-sample window.
        let overflow = window.len().saturating_sub(100);
        window.drain(..overflow);
    }

    /// Mean of the retained samples for `memory_type`, or 0.0 when none have
    /// been recorded.
    pub fn get_average_latency(&self, memory_type: &str) -> f32 {
        match self.access_latencies.get(memory_type) {
            Some(window) if !window.is_empty() => {
                window.iter().sum::<f32>() / window.len() as f32
            }
            _ => 0.0,
        }
    }
}
601
/// Coordination state shared across the memory subsystems.
pub struct MemoryCoordinator {
    /// Selected coordination strategy.
    pub(crate) strategy: CoordinationStrategy,
    /// Utilisation snapshot returned by
    /// `MemoryAugmentedNetwork::get_memory_stats`.
    pub(crate) usage_stats: MemoryUsageStats,
    /// Per-memory-type latency tracker fed by `MemoryAugmentedNetwork::process`.
    pub(crate) performance_tracker: MemoryPerformanceTracker,
}
608
609#[derive(Debug, Clone, Serialize, Deserialize)]
611pub struct MemoryPerformanceMetrics {
612 pub total_operations: u64,
613 pub average_latency_ms: f32,
614 pub hit_rate: f32,
615 pub utilization: f32,
616 pub ops_per_second: f32,
617 pub error_rate: f32,
618}
619
620impl Default for MemoryPerformanceMetrics {
621 fn default() -> Self {
622 Self {
623 total_operations: 0,
624 average_latency_ms: 0.0,
625 hit_rate: 0.0,
626 utilization: 0.0,
627 ops_per_second: 0.0,
628 error_rate: 0.0,
629 }
630 }
631}
632
/// Settings that apply to the memory system as a whole.
///
/// NOTE(review): none of these fields is read in the visible part of this
/// file; the descriptions below restate the field names and should be
/// confirmed against the code that consumes them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalMemorySettings {
    /// Whether compression is enabled.
    pub enable_compression: bool,
    /// Overall memory budget, presumably in megabytes.
    pub memory_capacity_mb: f32,
    /// Threshold that triggers cleanup (presumably a utilisation fraction).
    pub cleanup_threshold: f32,
    /// Whether persistence is enabled.
    pub enable_persistence: bool,
    /// Update interval, presumably in milliseconds.
    pub update_frequency_ms: u64,
    /// Whether cross-memory coordination is enabled.
    pub enable_coordination: bool,
}
643
644impl Default for GlobalMemorySettings {
645 fn default() -> Self {
646 Self {
647 enable_compression: true,
648 memory_capacity_mb: 1024.0,
649 cleanup_threshold: 0.85,
650 enable_persistence: true,
651 update_frequency_ms: 100,
652 enable_coordination: true,
653 }
654 }
655}
656
/// Top-level configuration bundling one section per memory subsystem.
///
/// `Default` is derived, so each section uses its own `Default`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryConfig {
    /// Passed to `DifferentiableNeuralComputer::new`.
    pub dnc_config: DNCConfig,
    /// Passed to `NeuralTuringMachine::new`.
    pub ntm_config: NTMConfig,
    /// Passed to `MemoryNetworks::new`.
    pub memory_networks_config: MemoryNetworksConfig,
    /// Passed to `EpisodicMemory::new`.
    pub episodic_config: EpisodicConfig,
    /// Passed to `RelationalMemoryCore::new`.
    pub relational_config: RelationalConfig,
    /// Passed to `SparseAccessMemory::new`.
    pub sparse_config: SparseConfig,
    /// System-wide settings.
    pub global_settings: GlobalMemorySettings,
}
668
/// Facade that owns every memory subsystem and routes requests between them.
pub struct MemoryAugmentedNetwork {
    /// Configuration the subsystems were built from.
    pub(crate) config: MemoryConfig,
    /// Differentiable neural computer (selected with `"dnc"`).
    pub(crate) dnc: DifferentiableNeuralComputer,
    /// Neural Turing machine (selected with `"ntm"`).
    pub(crate) ntm: NeuralTuringMachine,
    /// Attention-addressed memory (selected with `"memory_networks"`).
    pub(crate) memory_networks: MemoryNetworks,
    /// Episode store driven by the `*_episode*` methods.
    pub(crate) episodic_memory: EpisodicMemory,
    /// Relational memory (selected with `"relational"`).
    pub(crate) relational_memory: RelationalMemoryCore,
    /// Keyed sparse store (selected with `"sparse"`).
    pub(crate) sparse_memory: SparseAccessMemory,
    /// Strategy, usage statistics and latency tracker.
    pub(crate) memory_coordinator: MemoryCoordinator,
    /// Aggregate counters updated by `process`.
    pub(crate) performance_metrics: MemoryPerformanceMetrics,
}
681
682impl MemoryAugmentedNetwork {
683 pub fn new(config: MemoryConfig) -> Result<Self> {
684 let dnc = DifferentiableNeuralComputer::new(config.dnc_config.clone());
685 let ntm = NeuralTuringMachine::new(config.ntm_config.clone());
686 let memory_networks = MemoryNetworks::new(config.memory_networks_config.clone());
687 let episodic_memory = EpisodicMemory::new(config.episodic_config.clone());
688 let relational_memory = RelationalMemoryCore::new(config.relational_config.clone());
689 let sparse_memory = SparseAccessMemory::new(config.sparse_config.clone());
690
691 let memory_coordinator = MemoryCoordinator {
692 strategy: CoordinationStrategy::Adaptive,
693 usage_stats: MemoryUsageStats {
694 dnc_utilization: 0.0,
695 ntm_utilization: 0.0,
696 memory_networks_utilization: 0.0,
697 episodic_utilization: 0.0,
698 relational_utilization: 0.0,
699 sparse_utilization: 0.0,
700 total_memory_mb: 0.0,
701 },
702 performance_tracker: MemoryPerformanceTracker::new(),
703 };
704
705 Ok(Self {
706 config,
707 dnc,
708 ntm,
709 memory_networks,
710 episodic_memory,
711 relational_memory,
712 sparse_memory,
713 memory_coordinator,
714 performance_metrics: MemoryPerformanceMetrics::default(),
715 })
716 }
717
718 pub async fn process(
719 &mut self,
720 input: &Array1<f32>,
721 memory_type: Option<&str>,
722 ) -> Result<Array1<f32>> {
723 let start_time = Instant::now();
724
725 let result = match memory_type {
726 Some("dnc") => self.dnc.forward(input),
727 Some("ntm") => self.ntm.forward(input),
728 Some("memory_networks") => Ok(self.memory_networks.query(input)?),
729 Some("relational") => Ok(self.relational_memory.query_relations(input)),
730 Some("sparse") => {
731 let similar = self.sparse_memory.find_similar(input, 1);
732 if let Some((key, _)) = similar.first() {
733 Ok(self.sparse_memory.retrieve(*key).unwrap_or(input).clone())
734 } else {
735 Ok(input.clone())
736 }
737 }
738 _ => self.adaptive_routing(input).await,
739 };
740
741 let latency = start_time.elapsed().as_millis() as f32;
742 if let Some(mem_type) = memory_type {
743 self.memory_coordinator
744 .performance_tracker
745 .record_access(mem_type, latency);
746 }
747
748 self.performance_metrics.total_operations += 1;
749 self.update_performance_metrics(latency);
750
751 result
752 }
753
754 async fn adaptive_routing(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
755 let input_norm = input.mapv(|x| x * x).sum().sqrt();
756 let input_sparsity =
757 input.iter().filter(|&&x| x.abs() < 0.01).count() as f32 / input.len() as f32;
758
759 match (input_norm, input_sparsity) {
760 (norm, sparsity) if norm > 10.0 && sparsity < 0.3 => self.dnc.forward(input),
761 (norm, sparsity) if norm < 5.0 && sparsity > 0.7 => {
762 let similar = self.sparse_memory.find_similar(input, 1);
763 if let Some((key, _)) = similar.first() {
764 Ok(self.sparse_memory.retrieve(*key).unwrap_or(input).clone())
765 } else {
766 Ok(input.clone())
767 }
768 }
769 _ => Ok(self.memory_networks.query(input)?),
770 }
771 }
772
773 pub async fn store(
774 &mut self,
775 content: String,
776 embedding: Array1<f32>,
777 memory_type: Option<&str>,
778 ) -> Result<()> {
779 match memory_type {
780 Some("memory_networks") => {
781 self.memory_networks.store_memory(content, embedding)?;
782 }
783 Some("sparse") => {
784 let key = self.hash_content(&content);
785 self.sparse_memory.store(key, embedding)?;
786 }
787 Some("relational") => {
788 let zero_vector = Array1::zeros(embedding.len());
789 self.relational_memory
790 .store_relation(&embedding, 0, &zero_vector)?;
791 }
792 _ => {
793 self.memory_networks.store_memory(content, embedding)?;
794 }
795 }
796 Ok(())
797 }
798
799 fn hash_content(&self, content: &str) -> usize {
800 use std::collections::hash_map::DefaultHasher;
801 use std::hash::{Hash, Hasher};
802 let mut hasher = DefaultHasher::new();
803 content.hash(&mut hasher);
804 hasher.finish() as usize
805 }
806
807 pub fn start_episode(&mut self, episode_type: String) {
808 self.episodic_memory.start_episode(episode_type);
809 }
810
811 pub fn add_episode_state(&mut self, state: Array1<f32>, reward: f32) -> Result<()> {
812 self.episodic_memory.add_state(state, reward)
813 }
814
815 pub fn end_episode(&mut self, success: bool) -> Result<()> {
816 self.episodic_memory.end_episode(success)
817 }
818
819 pub fn get_memory_stats(&self) -> MemoryUsageStats {
820 self.memory_coordinator.usage_stats.clone()
821 }
822
823 pub fn get_performance_metrics(&self) -> &MemoryPerformanceMetrics {
824 &self.performance_metrics
825 }
826
827 fn update_performance_metrics(&mut self, latency: f32) {
828 let alpha = 0.1;
829 self.performance_metrics.average_latency_ms =
830 alpha * latency + (1.0 - alpha) * self.performance_metrics.average_latency_ms;
831 }
832
833 pub async fn cleanup(&mut self) -> Result<()> {
834 if self.dnc.get_memory_utilization() > 0.9 {
835 self.dnc.reset();
836 }
837 let cleanup_duration = Duration::from_secs(3600);
838 let removed = self.sparse_memory.cleanup(cleanup_duration)?;
839 if removed > 0 {
840 info!("Cleaned up {} entries from sparse memory", removed);
841 }
842 Ok(())
843 }
844}
845
846fn cosine_sim(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
847 let dot_product = a.dot(b);
848 let norm_a = a.mapv(|x| x * x).sum().sqrt();
849 let norm_b = b.mapv(|x| x * x).sum().sqrt();
850 if norm_a > 0.0 && norm_b > 0.0 {
851 dot_product / (norm_a * norm_b)
852 } else {
853 0.0
854 }
855}