Skip to main content

oxirs_embed/
memory_nets_ops.rs

1//! Memory operation systems: MemoryNetworks, EpisodicMemory, RelationalMemoryCore,
2//! SparseAccessMemory, and the top-level MemoryAugmentedNetwork orchestrator.
3
4use anyhow::{anyhow, Result};
5use scirs2_core::ndarray_ext::{s, Array1, Array2};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, VecDeque};
8use std::time::{Duration, Instant};
9use tracing::info;
10use uuid::Uuid;
11
12use crate::memory_nets_controller::{
13    DNCConfig, DifferentiableNeuralComputer, NTMConfig, NeuralTuringMachine,
14};
15
16/// Memory Networks configuration
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct MemoryNetworksConfig {
19    pub memory_capacity: usize,
20    pub embedding_dim: usize,
21    pub num_hops: usize,
22    pub learning_rate: f32,
23}
24
25impl Default for MemoryNetworksConfig {
26    fn default() -> Self {
27        Self {
28            memory_capacity: 1000,
29            embedding_dim: 128,
30            num_hops: 3,
31            learning_rate: 0.01,
32        }
33    }
34}
35
36/// Memory Networks implementation (multi-hop reasoning)
37pub struct MemoryNetworks {
38    pub(crate) config: MemoryNetworksConfig,
39    pub(crate) memory_embeddings: Array2<f32>,
40    pub(crate) memory_content: Vec<String>,
41    pub(crate) input_encoder: Array2<f32>,
42    pub(crate) output_encoder: Array2<f32>,
43    pub(crate) query_encoder: Array2<f32>,
44}
45
46impl MemoryNetworks {
47    pub fn new(config: MemoryNetworksConfig) -> Self {
48        use scirs2_core::random::Random;
49        let mut rng = Random::default();
50
51        let memory_embeddings = Array2::zeros((config.memory_capacity, config.embedding_dim));
52        let memory_content = Vec::new();
53
54        let input_encoder =
55            Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
56                rng.random_range(-0.1..0.1)
57            });
58        let output_encoder =
59            Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
60                rng.random_range(-0.1..0.1)
61            });
62        let query_encoder =
63            Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
64                rng.random_range(-0.1..0.1)
65            });
66
67        Self {
68            config,
69            memory_embeddings,
70            memory_content,
71            input_encoder,
72            output_encoder,
73            query_encoder,
74        }
75    }
76
77    /// Store memory (FIFO eviction when full)
78    pub fn store_memory(&mut self, content: String, embedding: Array1<f32>) -> Result<()> {
79        if self.memory_content.len() < self.config.memory_capacity {
80            let index = self.memory_content.len();
81            self.memory_content.push(content);
82            if embedding.len() == self.config.embedding_dim {
83                self.memory_embeddings.row_mut(index).assign(&embedding);
84            } else {
85                return Err(anyhow!("Embedding dimension mismatch"));
86            }
87        } else {
88            let index = 0;
89            self.memory_content[index] = content;
90            self.memory_embeddings.row_mut(index).assign(&embedding);
91            for i in 1..self.memory_content.len() {
92                self.memory_content.swap(i - 1, i);
93                let row1 = self.memory_embeddings.row(i - 1).to_owned();
94                let row2 = self.memory_embeddings.row(i).to_owned();
95                self.memory_embeddings.row_mut(i - 1).assign(&row2);
96                self.memory_embeddings.row_mut(i).assign(&row1);
97            }
98        }
99        Ok(())
100    }
101
102    /// Query memory using multi-hop attention
103    pub fn query(&self, query_embedding: &Array1<f32>) -> Result<Array1<f32>> {
104        let num_memories = self.memory_content.len();
105        if num_memories == 0 {
106            return Ok(Array1::zeros(self.config.embedding_dim));
107        }
108
109        let mut response = Array1::zeros(self.config.embedding_dim);
110        let mut current_query = query_embedding.clone();
111
112        for _hop in 0..self.config.num_hops {
113            let attention_weights = self.compute_attention(&current_query)?;
114            // Only use the filled portion of the embeddings matrix
115            let active_embeddings = self
116                .memory_embeddings
117                .slice(scirs2_core::ndarray_ext::s![..num_memories, ..]);
118            let memory_response = active_embeddings.t().dot(&attention_weights);
119            current_query = self.output_encoder.dot(&memory_response);
120            response = memory_response;
121        }
122
123        Ok(response)
124    }
125
126    fn compute_attention(&self, query: &Array1<f32>) -> Result<Array1<f32>> {
127        let num_memories = self.memory_content.len();
128        if num_memories == 0 {
129            return Ok(Array1::zeros(0));
130        }
131
132        let mut attention_scores = Array1::zeros(num_memories);
133        for i in 0..num_memories {
134            let memory_embedding = self.memory_embeddings.row(i);
135            attention_scores[i] = query.dot(&memory_embedding);
136        }
137
138        let max_score = attention_scores.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
139        let exp_scores = attention_scores.map(|&x| (x - max_score).exp());
140        let sum_exp = exp_scores.sum();
141
142        if sum_exp > 0.0 {
143            Ok(exp_scores / sum_exp)
144        } else {
145            Ok(Array1::from_elem(num_memories, 1.0 / num_memories as f32))
146        }
147    }
148}
149
150/// Episodic Memory configuration
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct EpisodicConfig {
153    pub episode_capacity: usize,
154    pub episode_length: usize,
155    pub embedding_dim: usize,
156    pub decay_factor: f32,
157}
158
159impl Default for EpisodicConfig {
160    fn default() -> Self {
161        Self {
162            episode_capacity: 100,
163            episode_length: 50,
164            embedding_dim: 128,
165            decay_factor: 0.95,
166        }
167    }
168}
169
/// Summary data attached to an [`Episode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodeMetadata {
    /// Caller-supplied label given to `EpisodicMemory::start_episode`.
    pub episode_type: String,
    /// Outcome flag; set by `EpisodicMemory::end_episode` (false until then).
    pub success: bool,
    /// Number of recorded states; filled in by `EpisodicMemory::end_episode`.
    pub length: usize,
    /// Mean of the recorded rewards (0.0 if none); filled in by `EpisodicMemory::end_episode`.
    pub average_reward: f32,
    /// Free-form tags; initialised empty by `EpisodicMemory::start_episode`.
    pub tags: Vec<String>,
}
179
/// A recorded sequence of states with one reward per state.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Random (v4) identifier assigned when the episode is started.
    pub id: Uuid,
    /// States in the order they were added.
    pub states: Vec<Array1<f32>>,
    /// Rewards, parallel to `states` (one per `add_state` call).
    pub rewards: Vec<f32>,
    /// Summary statistics, finalised when the episode ends.
    pub metadata: EpisodeMetadata,
    /// UTC wall-clock time at which the episode was started.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
189
190/// Episodic Memory for sequential experiences
191pub struct EpisodicMemory {
192    pub(crate) config: EpisodicConfig,
193    pub(crate) episodes: VecDeque<Episode>,
194    pub(crate) current_episode: Option<Episode>,
195}
196
197impl EpisodicMemory {
198    pub fn new(config: EpisodicConfig) -> Self {
199        Self {
200            config,
201            episodes: VecDeque::new(),
202            current_episode: None,
203        }
204    }
205
206    pub fn start_episode(&mut self, episode_type: String) {
207        let episode = Episode {
208            id: Uuid::new_v4(),
209            states: Vec::new(),
210            rewards: Vec::new(),
211            metadata: EpisodeMetadata {
212                episode_type,
213                success: false,
214                length: 0,
215                average_reward: 0.0,
216                tags: Vec::new(),
217            },
218            timestamp: chrono::Utc::now(),
219        };
220        self.current_episode = Some(episode);
221    }
222
223    pub fn add_state(&mut self, state: Array1<f32>, reward: f32) -> Result<()> {
224        if let Some(ref mut episode) = self.current_episode {
225            episode.states.push(state);
226            episode.rewards.push(reward);
227            Ok(())
228        } else {
229            Err(anyhow!("No active episode"))
230        }
231    }
232
233    pub fn end_episode(&mut self, success: bool) -> Result<()> {
234        if let Some(mut episode) = self.current_episode.take() {
235            episode.metadata.success = success;
236            episode.metadata.length = episode.states.len();
237            episode.metadata.average_reward = if episode.rewards.is_empty() {
238                0.0
239            } else {
240                episode.rewards.iter().sum::<f32>() / episode.rewards.len() as f32
241            };
242
243            if self.episodes.len() >= self.config.episode_capacity {
244                self.episodes.pop_front();
245            }
246            self.episodes.push_back(episode);
247            Ok(())
248        } else {
249            Err(anyhow!("No active episode"))
250        }
251    }
252
253    pub fn retrieve_similar_episodes(&self, query_state: &Array1<f32>, k: usize) -> Vec<&Episode> {
254        let mut similarities: Vec<(f32, &Episode)> = self
255            .episodes
256            .iter()
257            .map(|episode| {
258                let similarity = self.compute_episode_similarity(episode, query_state);
259                (similarity, episode)
260            })
261            .collect();
262
263        similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
264        similarities
265            .into_iter()
266            .take(k)
267            .map(|(_, episode)| episode)
268            .collect()
269    }
270
271    fn compute_episode_similarity(&self, episode: &Episode, query_state: &Array1<f32>) -> f32 {
272        if episode.states.is_empty() {
273            return 0.0;
274        }
275        let mut total_similarity = 0.0;
276        for state in &episode.states {
277            total_similarity += cosine_sim(query_state, state);
278        }
279        total_similarity / episode.states.len() as f32
280    }
281}
282
283/// Relational Memory configuration
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub struct RelationalConfig {
286    pub memory_size: usize,
287    pub embedding_dim: usize,
288    pub num_heads: usize,
289    pub num_relation_types: usize,
290}
291
292impl Default for RelationalConfig {
293    fn default() -> Self {
294        Self {
295            memory_size: 512,
296            embedding_dim: 256,
297            num_heads: 8,
298            num_relation_types: 10,
299        }
300    }
301}
302
303/// Relational attention mechanism
304pub struct RelationalAttention {
305    pub(crate) query_weights: Array2<f32>,
306    pub(crate) key_weights: Array2<f32>,
307    pub(crate) value_weights: Array2<f32>,
308    pub(crate) num_heads: usize,
309    pub(crate) embed_dim: usize,
310}
311
312impl RelationalAttention {
313    pub fn new(embed_dim: usize, num_heads: usize) -> Self {
314        use scirs2_core::random::Random;
315        let mut rng = Random::default();
316
317        let query_weights =
318            Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
319        let key_weights =
320            Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
321        let value_weights =
322            Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
323
324        Self {
325            query_weights,
326            key_weights,
327            value_weights,
328            num_heads,
329            embed_dim,
330        }
331    }
332
333    pub fn forward(&self, memory: &Array2<f32>, query: &Array1<f32>) -> Array1<f32> {
334        let head_dim = self.embed_dim / self.num_heads;
335        let mut output = Array1::zeros(self.embed_dim);
336
337        for head in 0..self.num_heads {
338            let start_idx = head * head_dim;
339            let end_idx = (head + 1) * head_dim;
340
341            let q_head = self.query_weights.slice(s![start_idx..end_idx, ..]);
342            let k_head = self.key_weights.slice(s![start_idx..end_idx, ..]);
343            let v_head = self.value_weights.slice(s![start_idx..end_idx, ..]);
344
345            let q = q_head.dot(query);
346            let keys = memory.dot(&k_head.t());
347            let values = memory.dot(&v_head.t());
348
349            let mut scores = Array1::zeros(memory.nrows());
350            for i in 0..memory.nrows() {
351                scores[i] = q.dot(&keys.row(i));
352            }
353
354            let max_score = scores.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
355            let exp_scores = scores.map(|&x| (x - max_score).exp());
356            let sum_exp = exp_scores.sum();
357            let attention_weights = if sum_exp > 0.0 {
358                exp_scores / sum_exp
359            } else {
360                Array1::from_elem(memory.nrows(), 1.0 / memory.nrows() as f32)
361            };
362
363            let head_output = values.t().dot(&attention_weights);
364            output
365                .slice_mut(s![start_idx..end_idx])
366                .assign(&head_output);
367        }
368
369        output
370    }
371}
372
373/// Relational Memory Core for structured knowledge
374pub struct RelationalMemoryCore {
375    pub(crate) config: RelationalConfig,
376    pub(crate) memory: Array2<f32>,
377    pub(crate) relation_matrices: Vec<Array2<f32>>,
378    pub(crate) attention_mechanism: RelationalAttention,
379}
380
381impl RelationalMemoryCore {
382    pub fn new(config: RelationalConfig) -> Self {
383        use scirs2_core::random::Random;
384        let mut rng = Random::default();
385
386        let memory = Array2::zeros((config.memory_size, config.embedding_dim));
387        let mut relation_matrices = Vec::new();
388
389        for _ in 0..config.num_relation_types {
390            let relation_matrix =
391                Array2::from_shape_fn((config.embedding_dim, config.embedding_dim), |_| {
392                    rng.random_range(-0.1..0.1)
393                });
394            relation_matrices.push(relation_matrix);
395        }
396
397        let attention_mechanism = RelationalAttention::new(config.embedding_dim, config.num_heads);
398        Self {
399            config,
400            memory,
401            relation_matrices,
402            attention_mechanism,
403        }
404    }
405
406    pub fn store_relation(
407        &mut self,
408        subject: &Array1<f32>,
409        relation_type: usize,
410        object: &Array1<f32>,
411    ) -> Result<()> {
412        if relation_type >= self.config.num_relation_types {
413            return Err(anyhow!("Invalid relation type"));
414        }
415        let relation_matrix = &self.relation_matrices[relation_type];
416        let transformed_subject = relation_matrix.dot(subject);
417        let transformed_object = relation_matrix.dot(object);
418
419        if let Some(slot) = self.find_empty_slot() {
420            let combined = &transformed_subject + &transformed_object;
421            self.memory.row_mut(slot).assign(&combined);
422        }
423        Ok(())
424    }
425
426    fn find_empty_slot(&self) -> Option<usize> {
427        (0..self.memory.nrows()).find(|&i| self.memory.row(i).sum() == 0.0)
428    }
429
430    pub fn query_relations(&self, query: &Array1<f32>) -> Array1<f32> {
431        self.attention_mechanism.forward(&self.memory, query)
432    }
433}
434
435/// Sparse Access Memory configuration
436#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct SparseConfig {
438    pub memory_capacity: usize,
439    pub embedding_dim: usize,
440    pub sparsity_factor: f32,
441    pub update_threshold: f32,
442}
443
444impl Default for SparseConfig {
445    fn default() -> Self {
446        Self {
447            memory_capacity: 10000,
448            embedding_dim: 512,
449            sparsity_factor: 0.1,
450            update_threshold: 0.01,
451        }
452    }
453}
454
455/// Sparse Access Memory for large-scale memory
456pub struct SparseAccessMemory {
457    pub(crate) config: SparseConfig,
458    pub(crate) memory: HashMap<usize, Array1<f32>>,
459    pub(crate) access_counts: HashMap<usize, usize>,
460    pub(crate) last_access: HashMap<usize, Instant>,
461}
462
463impl SparseAccessMemory {
464    pub fn new(config: SparseConfig) -> Self {
465        Self {
466            config,
467            memory: HashMap::new(),
468            access_counts: HashMap::new(),
469            last_access: HashMap::new(),
470        }
471    }
472
473    pub fn store(&mut self, key: usize, value: Array1<f32>) -> Result<()> {
474        if self.memory.len() >= self.config.memory_capacity {
475            self.evict_least_used()?;
476        }
477        self.memory.insert(key, value);
478        self.access_counts.insert(key, 1);
479        self.last_access.insert(key, Instant::now());
480        Ok(())
481    }
482
483    pub fn retrieve(&mut self, key: usize) -> Option<&Array1<f32>> {
484        if let Some(value) = self.memory.get(&key) {
485            *self.access_counts.entry(key).or_insert(0) += 1;
486            self.last_access.insert(key, Instant::now());
487            Some(value)
488        } else {
489            None
490        }
491    }
492
493    pub fn find_similar(&self, query: &Array1<f32>, k: usize) -> Vec<(usize, f32)> {
494        let mut similarities: Vec<(usize, f32)> = self
495            .memory
496            .iter()
497            .map(|(&key, value)| (key, cosine_sim(query, value)))
498            .collect();
499        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
500        similarities.into_iter().take(k).collect()
501    }
502
503    fn evict_least_used(&mut self) -> Result<()> {
504        let mut candidates: Vec<(usize, usize, Instant)> = self
505            .access_counts
506            .iter()
507            .map(|(&key, &count)| {
508                let last_access = self
509                    .last_access
510                    .get(&key)
511                    .copied()
512                    .unwrap_or(Instant::now());
513                (key, count, last_access)
514            })
515            .collect();
516        candidates.sort_by(|a, b| a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2)));
517
518        if let Some((key_to_evict, _, _)) = candidates.first() {
519            let key = *key_to_evict;
520            self.memory.remove(&key);
521            self.access_counts.remove(&key);
522            self.last_access.remove(&key);
523        }
524        Ok(())
525    }
526
527    pub fn cleanup(&mut self, max_age: Duration) -> Result<usize> {
528        let now = Instant::now();
529        let mut keys_to_remove = Vec::new();
530
531        for (&key, &last_access) in &self.last_access {
532            if now.duration_since(last_access) > max_age {
533                keys_to_remove.push(key);
534            }
535        }
536
537        let removed_count = keys_to_remove.len();
538        for key in keys_to_remove {
539            self.memory.remove(&key);
540            self.access_counts.remove(&key);
541            self.last_access.remove(&key);
542        }
543        Ok(removed_count)
544    }
545}
546
/// Coordination strategy for multiple memory systems.
///
/// `MemoryAugmentedNetwork::new` selects `Adaptive`; no code in this file branches on
/// the chosen variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CoordinationStrategy {
    RoundRobin,
    PerformanceBased,
    ContentBased,
    Adaptive,
}
555
/// Per-subsystem usage snapshot returned by `MemoryAugmentedNetwork::get_memory_stats`.
///
/// NOTE(review): in this file every field is initialised to 0.0 and never updated.
/// The scale of the `*_utilization` fields is not established here; presumably they
/// are fractions in `[0, 1]`. Confirm against their producers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUsageStats {
    pub dnc_utilization: f32,
    pub ntm_utilization: f32,
    pub memory_networks_utilization: f32,
    pub episodic_utilization: f32,
    pub relational_utilization: f32,
    pub sparse_utilization: f32,
    /// Total memory footprint, in megabytes per the field name.
    pub total_memory_mb: f32,
}
567
/// Rolling latency tracker keyed by memory-type name.
///
/// Only the 100 most recent latency samples are kept per memory type.
#[derive(Default)]
pub struct MemoryPerformanceTracker {
    /// Most recent latency samples (milliseconds) per memory type, oldest first.
    pub(crate) access_latencies: HashMap<String, VecDeque<f32>>,
    /// Hit rate per memory type (not written anywhere in this file).
    pub(crate) hit_rates: HashMap<String, f32>,
    /// Throughput per memory type (not written anywhere in this file).
    pub(crate) throughput_metrics: HashMap<String, f32>,
}

impl MemoryPerformanceTracker {
    /// Number of latency samples retained per memory type.
    const LATENCY_WINDOW: usize = 100;

    /// Creates an empty tracker.
    pub fn new() -> Self {
        Default::default()
    }

    /// Appends `latency_ms` to `memory_type`'s history and discards the oldest
    /// samples beyond the window.
    pub fn record_access(&mut self, memory_type: &str, latency_ms: f32) {
        let history = self
            .access_latencies
            .entry(memory_type.to_owned())
            .or_default();
        history.push_back(latency_ms);
        let excess = history.len().saturating_sub(Self::LATENCY_WINDOW);
        history.drain(..excess);
    }

    /// Mean of the retained samples for `memory_type`; 0.0 when there are none.
    pub fn get_average_latency(&self, memory_type: &str) -> f32 {
        match self.access_latencies.get(memory_type) {
            Some(history) if !history.is_empty() => {
                history.iter().sum::<f32>() / history.len() as f32
            }
            _ => 0.0,
        }
    }
}
601
/// Memory coordination state: the strategy, usage statistics and latency tracker
/// shared across the memory subsystems.
pub struct MemoryCoordinator {
    /// Set to `Adaptive` by `MemoryAugmentedNetwork::new`; not consulted in this file.
    pub(crate) strategy: CoordinationStrategy,
    /// Snapshot returned by `MemoryAugmentedNetwork::get_memory_stats`.
    pub(crate) usage_stats: MemoryUsageStats,
    /// Receives per-memory-type latencies from `MemoryAugmentedNetwork::process`.
    pub(crate) performance_tracker: MemoryPerformanceTracker,
}
608
609/// Memory performance metrics
610#[derive(Debug, Clone, Serialize, Deserialize)]
611pub struct MemoryPerformanceMetrics {
612    pub total_operations: u64,
613    pub average_latency_ms: f32,
614    pub hit_rate: f32,
615    pub utilization: f32,
616    pub ops_per_second: f32,
617    pub error_rate: f32,
618}
619
620impl Default for MemoryPerformanceMetrics {
621    fn default() -> Self {
622        Self {
623            total_operations: 0,
624            average_latency_ms: 0.0,
625            hit_rate: 0.0,
626            utilization: 0.0,
627            ops_per_second: 0.0,
628            error_rate: 0.0,
629        }
630    }
631}
632
633/// Global memory system settings
634#[derive(Debug, Clone, Serialize, Deserialize)]
635pub struct GlobalMemorySettings {
636    pub enable_compression: bool,
637    pub memory_capacity_mb: f32,
638    pub cleanup_threshold: f32,
639    pub enable_persistence: bool,
640    pub update_frequency_ms: u64,
641    pub enable_coordination: bool,
642}
643
644impl Default for GlobalMemorySettings {
645    fn default() -> Self {
646        Self {
647            enable_compression: true,
648            memory_capacity_mb: 1024.0,
649            cleanup_threshold: 0.85,
650            enable_persistence: true,
651            update_frequency_ms: 100,
652            enable_coordination: true,
653        }
654    }
655}
656
/// Configuration for every subsystem of a [`MemoryAugmentedNetwork`].
///
/// `Default` is derived, so each section uses its own type's default.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryConfig {
    /// Passed to `DifferentiableNeuralComputer::new`.
    pub dnc_config: DNCConfig,
    /// Passed to `NeuralTuringMachine::new`.
    pub ntm_config: NTMConfig,
    /// Passed to `MemoryNetworks::new`.
    pub memory_networks_config: MemoryNetworksConfig,
    /// Passed to `EpisodicMemory::new`.
    pub episodic_config: EpisodicConfig,
    /// Passed to `RelationalMemoryCore::new`.
    pub relational_config: RelationalConfig,
    /// Passed to `SparseAccessMemory::new`.
    pub sparse_config: SparseConfig,
    /// Global settings (not read anywhere in this file).
    pub global_settings: GlobalMemorySettings,
}
668
669/// Memory-Augmented Network Engine
670pub struct MemoryAugmentedNetwork {
671    pub(crate) config: MemoryConfig,
672    pub(crate) dnc: DifferentiableNeuralComputer,
673    pub(crate) ntm: NeuralTuringMachine,
674    pub(crate) memory_networks: MemoryNetworks,
675    pub(crate) episodic_memory: EpisodicMemory,
676    pub(crate) relational_memory: RelationalMemoryCore,
677    pub(crate) sparse_memory: SparseAccessMemory,
678    pub(crate) memory_coordinator: MemoryCoordinator,
679    pub(crate) performance_metrics: MemoryPerformanceMetrics,
680}
681
682impl MemoryAugmentedNetwork {
683    pub fn new(config: MemoryConfig) -> Result<Self> {
684        let dnc = DifferentiableNeuralComputer::new(config.dnc_config.clone());
685        let ntm = NeuralTuringMachine::new(config.ntm_config.clone());
686        let memory_networks = MemoryNetworks::new(config.memory_networks_config.clone());
687        let episodic_memory = EpisodicMemory::new(config.episodic_config.clone());
688        let relational_memory = RelationalMemoryCore::new(config.relational_config.clone());
689        let sparse_memory = SparseAccessMemory::new(config.sparse_config.clone());
690
691        let memory_coordinator = MemoryCoordinator {
692            strategy: CoordinationStrategy::Adaptive,
693            usage_stats: MemoryUsageStats {
694                dnc_utilization: 0.0,
695                ntm_utilization: 0.0,
696                memory_networks_utilization: 0.0,
697                episodic_utilization: 0.0,
698                relational_utilization: 0.0,
699                sparse_utilization: 0.0,
700                total_memory_mb: 0.0,
701            },
702            performance_tracker: MemoryPerformanceTracker::new(),
703        };
704
705        Ok(Self {
706            config,
707            dnc,
708            ntm,
709            memory_networks,
710            episodic_memory,
711            relational_memory,
712            sparse_memory,
713            memory_coordinator,
714            performance_metrics: MemoryPerformanceMetrics::default(),
715        })
716    }
717
718    pub async fn process(
719        &mut self,
720        input: &Array1<f32>,
721        memory_type: Option<&str>,
722    ) -> Result<Array1<f32>> {
723        let start_time = Instant::now();
724
725        let result = match memory_type {
726            Some("dnc") => self.dnc.forward(input),
727            Some("ntm") => self.ntm.forward(input),
728            Some("memory_networks") => Ok(self.memory_networks.query(input)?),
729            Some("relational") => Ok(self.relational_memory.query_relations(input)),
730            Some("sparse") => {
731                let similar = self.sparse_memory.find_similar(input, 1);
732                if let Some((key, _)) = similar.first() {
733                    Ok(self.sparse_memory.retrieve(*key).unwrap_or(input).clone())
734                } else {
735                    Ok(input.clone())
736                }
737            }
738            _ => self.adaptive_routing(input).await,
739        };
740
741        let latency = start_time.elapsed().as_millis() as f32;
742        if let Some(mem_type) = memory_type {
743            self.memory_coordinator
744                .performance_tracker
745                .record_access(mem_type, latency);
746        }
747
748        self.performance_metrics.total_operations += 1;
749        self.update_performance_metrics(latency);
750
751        result
752    }
753
754    async fn adaptive_routing(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
755        let input_norm = input.mapv(|x| x * x).sum().sqrt();
756        let input_sparsity =
757            input.iter().filter(|&&x| x.abs() < 0.01).count() as f32 / input.len() as f32;
758
759        match (input_norm, input_sparsity) {
760            (norm, sparsity) if norm > 10.0 && sparsity < 0.3 => self.dnc.forward(input),
761            (norm, sparsity) if norm < 5.0 && sparsity > 0.7 => {
762                let similar = self.sparse_memory.find_similar(input, 1);
763                if let Some((key, _)) = similar.first() {
764                    Ok(self.sparse_memory.retrieve(*key).unwrap_or(input).clone())
765                } else {
766                    Ok(input.clone())
767                }
768            }
769            _ => Ok(self.memory_networks.query(input)?),
770        }
771    }
772
773    pub async fn store(
774        &mut self,
775        content: String,
776        embedding: Array1<f32>,
777        memory_type: Option<&str>,
778    ) -> Result<()> {
779        match memory_type {
780            Some("memory_networks") => {
781                self.memory_networks.store_memory(content, embedding)?;
782            }
783            Some("sparse") => {
784                let key = self.hash_content(&content);
785                self.sparse_memory.store(key, embedding)?;
786            }
787            Some("relational") => {
788                let zero_vector = Array1::zeros(embedding.len());
789                self.relational_memory
790                    .store_relation(&embedding, 0, &zero_vector)?;
791            }
792            _ => {
793                self.memory_networks.store_memory(content, embedding)?;
794            }
795        }
796        Ok(())
797    }
798
799    fn hash_content(&self, content: &str) -> usize {
800        use std::collections::hash_map::DefaultHasher;
801        use std::hash::{Hash, Hasher};
802        let mut hasher = DefaultHasher::new();
803        content.hash(&mut hasher);
804        hasher.finish() as usize
805    }
806
807    pub fn start_episode(&mut self, episode_type: String) {
808        self.episodic_memory.start_episode(episode_type);
809    }
810
811    pub fn add_episode_state(&mut self, state: Array1<f32>, reward: f32) -> Result<()> {
812        self.episodic_memory.add_state(state, reward)
813    }
814
815    pub fn end_episode(&mut self, success: bool) -> Result<()> {
816        self.episodic_memory.end_episode(success)
817    }
818
819    pub fn get_memory_stats(&self) -> MemoryUsageStats {
820        self.memory_coordinator.usage_stats.clone()
821    }
822
823    pub fn get_performance_metrics(&self) -> &MemoryPerformanceMetrics {
824        &self.performance_metrics
825    }
826
827    fn update_performance_metrics(&mut self, latency: f32) {
828        let alpha = 0.1;
829        self.performance_metrics.average_latency_ms =
830            alpha * latency + (1.0 - alpha) * self.performance_metrics.average_latency_ms;
831    }
832
833    pub async fn cleanup(&mut self) -> Result<()> {
834        if self.dnc.get_memory_utilization() > 0.9 {
835            self.dnc.reset();
836        }
837        let cleanup_duration = Duration::from_secs(3600);
838        let removed = self.sparse_memory.cleanup(cleanup_duration)?;
839        if removed > 0 {
840            info!("Cleaned up {} entries from sparse memory", removed);
841        }
842        Ok(())
843    }
844}
845
846fn cosine_sim(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
847    let dot_product = a.dot(b);
848    let norm_a = a.mapv(|x| x * x).sum().sqrt();
849    let norm_b = b.mapv(|x| x * x).sum().sqrt();
850    if norm_a > 0.0 && norm_b > 0.0 {
851        dot_product / (norm_a * norm_b)
852    } else {
853        0.0
854    }
855}