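// optirs_learned/episodic_memory_impl.rs
//
// Episodic Memory Implementation
//
// Implements additional methods for the EpisodicMemoryBank and SupportSetManager
// types defined in crate::few_shot: lightweight episode storage, similarity-based
// retrieval, eviction and statistics for the memory bank, plus support-set
// selection, augmentation and quality scoring.
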
6use scirs2_core::ndarray::Array1;
7use scirs2_core::numeric::Float;
8use std::fmt::Debug;
9
10use crate::error::{OptimError, Result};
11use crate::few_shot::{EpisodicMemoryBank, MemoryBankStats, SupportSetManager};

// ---------------------------------------------------------------------------
// EpisodicMemoryBank additional impl
// ---------------------------------------------------------------------------

17impl<T: Float + Debug + Send + Sync + 'static> EpisodicMemoryBank<T> {
18    /// Store an episode keyed by task id with a representation vector and
19    /// performance score.
20    ///
21    /// If the bank is at capacity, the eviction policy is applied first.
22    pub fn store_lightweight_episode(
23        &mut self,
24        task_id: String,
25        representation: Array1<T>,
26        performance: T,
27    ) -> Result<()> {
28        use crate::few_shot::{
29            AdaptationPerformance, AdaptationResult, AdaptationStep, DifficultyLevel,
30            DomainCharacteristics, DomainInfo, DomainType, EpisodeMetadata, ExampleMetadata,
31            MemoryEpisode, QueryExample, QuerySet, QuerySetStatistics, ResourceUsage,
32            SupportExample, SupportSet, SupportSetStatistics, TaskData, TaskMetadata,
33        };
34        use std::collections::HashMap;
35        use std::time::Duration;
36
37        // Evict if at capacity
38        if self.episodes().len() >= self.capacity() {
39            self.evict()?;
40        }
41
42        let dim = representation.len();
43
44        // Build a minimal TaskData wrapping the representation
45        let support_example = SupportExample {
46            features: representation.clone(),
47            target: performance,
48            weight: T::one(),
49            context: HashMap::new(),
50            metadata: ExampleMetadata {
51                source: task_id.clone(),
52                quality_score: scirs2_core::numeric::NumCast::from(performance).unwrap_or(0.0),
53                created_at: std::time::SystemTime::now(),
54            },
55        };
56
57        let support_set = SupportSet {
58            examples: vec![support_example],
59            task_metadata: TaskMetadata {
60                task_name: task_id.clone(),
61                domain: DomainType::Optimization,
62                difficulty: DifficultyLevel::Medium,
63                created_at: std::time::SystemTime::now(),
64            },
65            statistics: SupportSetStatistics {
66                mean: representation.clone(),
67                variance: Array1::zeros(dim),
68                size: 1,
69                diversity_score: T::zero(),
70            },
71            temporal_order: None,
72        };
73
74        let query_set = QuerySet {
75            examples: Vec::<QueryExample<T>>::new(),
76            statistics: QuerySetStatistics {
77                mean: Array1::zeros(dim),
78                variance: Array1::zeros(dim),
79                size: 0,
80            },
81            eval_metrics: Vec::new(),
82        };
83
84        let task_data = TaskData {
85            task_id: task_id.clone(),
86            support_set,
87            query_set,
88            task_params: HashMap::new(),
89            domain_info: DomainInfo {
90                domain_type: DomainType::Optimization,
91                characteristics: DomainCharacteristics {
92                    input_dim: dim,
93                    output_dim: 1,
94                    temporal: false,
95                    stochasticity: 0.0,
96                    noise_level: 0.0,
97                    sparsity: 0.0,
98                },
99                difficulty_level: DifficultyLevel::Medium,
100                constraints: Vec::new(),
101            },
102        };
103
104        let adaptation_result = AdaptationResult {
105            adapted_state: crate::OptimizerState {
106                parameters: Array1::zeros(1),
107                gradients: Array1::zeros(1),
108                momentum: None,
109                hidden_states: HashMap::new(),
110                memory_buffers: HashMap::new(),
111                step: 0,
112                step_count: 0,
113                loss: None,
114                learning_rate: scirs2_core::numeric::NumCast::from(0.001)
115                    .unwrap_or_else(|| T::one()),
116                metadata: crate::StateMetadata {
117                    task_id: Some(task_id.clone()),
118                    optimizer_type: None,
119                    version: "1.0".to_string(),
120                    timestamp: std::time::SystemTime::now(),
121                    checksum: 0,
122                    compression_level: 0,
123                    custom_data: HashMap::new(),
124                },
125            },
126            performance: AdaptationPerformance {
127                query_performance: performance,
128                support_performance: performance,
129                adaptation_speed: 1,
130                final_loss: T::one() - performance,
131                improvement: performance,
132                stability: T::one(),
133            },
134            task_representation: representation,
135            adaptation_trajectory: Vec::<AdaptationStep<T>>::new(),
136            resource_usage: ResourceUsage {
137                total_time: Duration::from_secs(0),
138                peak_memory_mb: T::zero(),
139                compute_cost: T::zero(),
140                energy_consumption: T::zero(),
141            },
142        };
143
144        let episode = MemoryEpisode {
145            episode_id: format!("ep_{}", self.usage_stats().total_episodes),
146            task_data,
147            adaptation_result,
148            timestamp: std::time::SystemTime::now(),
149            metadata: EpisodeMetadata {
150                difficulty: DifficultyLevel::Medium,
151                domain: DomainType::Optimization,
152                success_rate: scirs2_core::numeric::NumCast::from(performance).unwrap_or(0.0),
153                tags: Vec::new(),
154            },
155            access_count: 0,
156        };
157
158        self.episodes_mut().push_back(episode);
159        self.usage_stats_mut().total_episodes += 1;
160        let len = self.episodes().len();
161        let cap = self.capacity();
162        self.usage_stats_mut().memory_utilization = len as f64 / cap as f64;
163        Ok(())
164    }
165
166    /// Retrieve the k nearest episodes to a query representation vector.
167    ///
168    /// Returns `Vec<(task_id, similarity)>` sorted by descending similarity
169    /// (cosine similarity).
170    pub fn retrieve_by_repr(&self, query: &Array1<T>, k: usize) -> Result<Vec<(String, T)>> {
171        if self.is_empty() {
172            return Ok(Vec::new());
173        }
174
175        let mut scored: Vec<(usize, T)> = Vec::with_capacity(self.len());
176        for (idx, ep) in self.episodes().iter().enumerate() {
177            let repr = &ep.adaptation_result.task_representation;
178            let sim = cosine_similarity(query, repr);
179            scored.push((idx, sim));
180        }
181
182        // Sort descending by similarity
183        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
184
185        let take = k.min(scored.len());
186        let result: Vec<(String, T)> = scored[..take]
187            .iter()
188            .map(|&(idx, sim)| {
189                let ep = &self.episodes()[idx];
190                (ep.episode_id.clone(), sim)
191            })
192            .collect();
193
194        Ok(result)
195    }
196
197    /// Evict the least useful episode according to the eviction policy.
198    ///
199    /// For the Performance policy, removes the episode with the lowest
200    /// query performance. For others, falls back to removing the oldest.
201    pub fn evict(&mut self) -> Result<()> {
202        if self.is_empty() {
203            return Ok(());
204        }
205
206        match self.eviction_policy() {
207            crate::few_shot::EvictionPolicy::Performance => {
208                // Find the episode with the worst performance
209                let mut worst_idx = 0;
210                let mut worst_perf = T::infinity();
211                for (i, ep) in self.episodes().iter().enumerate() {
212                    let perf = ep.adaptation_result.performance.query_performance;
213                    if perf < worst_perf {
214                        worst_perf = perf;
215                        worst_idx = i;
216                    }
217                }
218                self.episodes_mut().remove(worst_idx);
219            }
220            crate::few_shot::EvictionPolicy::LRU => {
221                // Remove least recently used (lowest access_count among oldest)
222                let mut lru_idx = 0;
223                let mut min_access = usize::MAX;
224                for (i, ep) in self.episodes().iter().enumerate() {
225                    if ep.access_count < min_access {
226                        min_access = ep.access_count;
227                        lru_idx = i;
228                    }
229                }
230                self.episodes_mut().remove(lru_idx);
231            }
232            crate::few_shot::EvictionPolicy::LFU => {
233                // Least frequently used
234                let mut lfu_idx = 0;
235                let mut min_access = usize::MAX;
236                for (i, ep) in self.episodes().iter().enumerate() {
237                    if ep.access_count < min_access {
238                        min_access = ep.access_count;
239                        lfu_idx = i;
240                    }
241                }
242                self.episodes_mut().remove(lfu_idx);
243            }
244            _ => {
245                // Default: remove oldest (front of deque)
246                self.episodes_mut().pop_front();
247            }
248        }
249
250        let len = self.episodes().len();
251        let cap = self.capacity();
252        self.usage_stats_mut().memory_utilization = len as f64 / cap as f64;
253        Ok(())
254    }
255
256    /// Get summary statistics about the memory bank.
257    pub fn get_stats(&self) -> MemoryBankStats<T> {
258        let count = self.len();
259        let cap = self.capacity();
260
261        let avg_performance = if count == 0 {
262            T::zero()
263        } else {
264            let mut sum = T::zero();
265            for ep in self.episodes() {
266                sum = sum + ep.adaptation_result.performance.query_performance;
267            }
268            let count_t: T = scirs2_core::numeric::NumCast::from(count).unwrap_or_else(|| T::one());
269            sum / count_t
270        };
271
272        MemoryBankStats {
273            count,
274            avg_performance,
275            capacity_used: if cap > 0 {
276                count as f64 / cap as f64
277            } else {
278                0.0
279            },
280            total_capacity: cap,
281        }
282    }
283
284    /// Remove all episodes.
285    pub fn clear(&mut self) {
286        self.episodes_mut().clear();
287        self.usage_stats_mut().memory_utilization = 0.0;
288    }
289
290    /// Return the number of stored episodes (alias for len).
291    pub fn size(&self) -> usize {
292        self.len()
293    }
294}

// ---------------------------------------------------------------------------
// SupportSetManager additional impl
// ---------------------------------------------------------------------------

300impl<T: Float + Debug + Send + Sync + 'static> SupportSetManager<T> {
301    /// Select a diverse subset of candidate indices for a support set.
302    ///
303    /// Uses a greedy farthest-point sampling strategy: the first point is
304    /// selected as the one with the largest norm, then each subsequent point
305    /// is the one that is most distant from all already-selected points
306    /// (measured by squared Euclidean distance).
307    pub fn select_support_set(
308        &self,
309        candidates: &[Array1<T>],
310        _labels: &[T],
311        budget: usize,
312    ) -> Result<Vec<usize>> {
313        if candidates.is_empty() {
314            return Err(OptimError::InsufficientData(
315                "No candidates to select from".to_string(),
316            ));
317        }
318        let n = candidates.len();
319        let take = budget.min(n).min(self.max_support_size());
320
321        if take >= n {
322            return Ok((0..n).collect());
323        }
324
325        // Greedy farthest-point sampling
326        let mut selected: Vec<usize> = Vec::with_capacity(take);
327
328        // Pick the candidate with the largest norm as the seed
329        let mut best_seed = 0;
330        let mut best_norm = T::neg_infinity();
331        for (i, c) in candidates.iter().enumerate() {
332            let norm = vec_norm_sq(c);
333            if norm > best_norm {
334                best_norm = norm;
335                best_seed = i;
336            }
337        }
338        selected.push(best_seed);
339
340        // Track min distance from each candidate to any selected point
341        let mut min_dist: Vec<T> = vec![T::infinity(); n];
342
343        while selected.len() < take {
344            // Update min_dist using the last selected point
345            let last = selected[selected.len() - 1];
346            for i in 0..n {
347                let d = squared_euclidean(&candidates[i], &candidates[last]);
348                if d < min_dist[i] {
349                    min_dist[i] = d;
350                }
351            }
352            // Zero out already-selected points
353            for &s in &selected {
354                min_dist[s] = T::neg_infinity();
355            }
356
357            // Pick the point with the maximum min_dist
358            let mut farthest_idx = 0;
359            let mut farthest_dist = T::neg_infinity();
360            for (i, &dist) in min_dist.iter().enumerate().take(n) {
361                if dist > farthest_dist {
362                    farthest_dist = dist;
363                    farthest_idx = i;
364                }
365            }
366            selected.push(farthest_idx);
367        }
368
369        Ok(selected)
370    }
371
372    /// Augment a support set by adding Gaussian noise to each example.
373    ///
374    /// For each input vector, produces a copy with noise ~ N(0, noise_scale^2)
375    /// added to each element (using a simple deterministic hash-based approach
376    /// for reproducibility without requiring rand).
377    pub fn augment_support_set(
378        &self,
379        support: &[Array1<T>],
380        noise_scale: T,
381    ) -> Result<Vec<Array1<T>>> {
382        if support.is_empty() {
383            return Err(OptimError::InsufficientData(
384                "Cannot augment empty support set".to_string(),
385            ));
386        }
387
388        let mut augmented = Vec::with_capacity(support.len() * 2);
389
390        // Keep originals
391        for s in support {
392            augmented.push(s.clone());
393        }
394
395        // Create augmented copies with deterministic pseudo-noise
396        for (ex_idx, s) in support.iter().enumerate() {
397            let mut noisy = s.clone();
398            for (i, val) in noisy.iter_mut().enumerate() {
399                // Simple deterministic hash-based noise for reproducibility
400                let seed = (ex_idx * 7919 + i * 104729 + 31) as f64;
401                let noise_val = ((seed * 0.6180339887).fract() - 0.5) * 2.0; // in [-1, 1]
402                let noise_t: T =
403                    scirs2_core::numeric::NumCast::from(noise_val).unwrap_or_else(|| T::zero());
404                *val = *val + noise_scale * noise_t;
405            }
406            augmented.push(noisy);
407        }
408
409        Ok(augmented)
410    }
411
412    /// Evaluate the quality/diversity of a support set.
413    ///
414    /// Returns the average pairwise squared Euclidean distance between all
415    /// support vectors (higher = more diverse = better quality).
416    pub fn evaluate_quality(&self, support: &[Array1<T>]) -> Result<T> {
417        if support.len() < 2 {
418            return Ok(T::zero());
419        }
420
421        let n = support.len();
422        let mut total_dist = T::zero();
423        let mut pair_count = 0usize;
424
425        for i in 0..n {
426            for j in (i + 1)..n {
427                total_dist = total_dist + squared_euclidean(&support[i], &support[j]);
428                pair_count += 1;
429            }
430        }
431
432        if pair_count == 0 {
433            return Ok(T::zero());
434        }
435
436        let pair_t: T = scirs2_core::numeric::NumCast::from(pair_count).unwrap_or_else(|| T::one());
437        Ok(total_dist / pair_t)
438    }
439}

// ---------------------------------------------------------------------------
// Utility functions
// ---------------------------------------------------------------------------

445/// Cosine similarity between two vectors.
446fn cosine_similarity<T: Float>(a: &Array1<T>, b: &Array1<T>) -> T {
447    let len = a.len().min(b.len());
448    let mut dot = T::zero();
449    let mut na = T::zero();
450    let mut nb = T::zero();
451    for i in 0..len {
452        dot = dot + a[i] * b[i];
453        na = na + a[i] * a[i];
454        nb = nb + b[i] * b[i];
455    }
456    let denom = na.sqrt() * nb.sqrt();
457    if denom == T::zero() {
458        T::zero()
459    } else {
460        dot / denom
461    }
462}
463
464/// Squared Euclidean distance.
465fn squared_euclidean<T: Float>(a: &Array1<T>, b: &Array1<T>) -> T {
466    let len = a.len().min(b.len());
467    let mut sum = T::zero();
468    for i in 0..len {
469        let d = a[i] - b[i];
470        sum = sum + d * d;
471    }
472    sum
473}
474
475/// Squared L2 norm.
476fn vec_norm_sq<T: Float>(v: &Array1<T>) -> T {
477    let mut sum = T::zero();
478    for &x in v.iter() {
479        sum = sum + x * x;
480    }
481    sum
482}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Build an owned 1-D `f64` array from a slice.
    fn array_of(values: &[f64]) -> Array1<f64> {
        Array1::from_vec(values.to_vec())
    }

    /// Store a lightweight episode, panicking with "store failed" on error.
    fn store_episode(bank: &mut EpisodicMemoryBank<f64>, task: &str, repr: &[f64], perf: f64) {
        bank.store_lightweight_episode(task.to_string(), array_of(repr), perf)
            .expect("store failed");
    }

    #[test]
    fn test_episodic_memory_store_retrieve() {
        let mut bank = EpisodicMemoryBank::<f64>::from_capacity(10)
            .expect("failed to create EpisodicMemoryBank");
        assert_eq!(bank.size(), 0);

        store_episode(&mut bank, "task_a", &[1.0, 0.0, 0.0], 0.9);
        store_episode(&mut bank, "task_b", &[0.0, 1.0, 0.0], 0.7);
        assert_eq!(bank.size(), 2);

        // The query is nearly parallel to task_a's representation, so its
        // episode (ep_0) must rank first with a strictly higher similarity.
        let ranked = bank
            .retrieve_by_repr(&array_of(&[1.0, 0.1, 0.0]), 2)
            .expect("retrieve failed");
        assert_eq!(ranked.len(), 2);
        let (top_id, top_sim) = &ranked[0];
        let (_, second_sim) = &ranked[1];
        assert_eq!(top_id.as_str(), "ep_0");
        assert!(top_sim > second_sim);
    }

    #[test]
    fn test_episodic_memory_eviction() {
        let mut bank = EpisodicMemoryBank::<f64>::from_capacity(3)
            .expect("failed to create EpisodicMemoryBank");

        let initial = [("t1", 1.0, 0.5), ("t2", 2.0, 0.9), ("t3", 3.0, 0.3)];
        for &(task, x, perf) in initial.iter() {
            store_episode(&mut bank, task, &[x], perf);
        }
        assert_eq!(bank.size(), 3);

        // A fourth insert at capacity triggers eviction; t3 (0.3) is the
        // worst performer.
        store_episode(&mut bank, "t4", &[4.0], 0.8);
        assert_eq!(bank.size(), 3);

        let low_perf_survived = bank
            .episodes()
            .iter()
            .map(|ep| ep.adaptation_result.performance.query_performance)
            .any(|perf| (perf - 0.3).abs() < 1e-12);
        assert!(
            !low_perf_survived,
            "lowest-performance episode should be evicted"
        );
    }

    #[test]
    fn test_memory_bank_stats() {
        let mut bank = EpisodicMemoryBank::<f64>::from_capacity(10)
            .expect("failed to create EpisodicMemoryBank");

        let empty_stats = bank.get_stats();
        assert_eq!(empty_stats.count, 0);
        assert!(empty_stats.avg_performance.abs() < 1e-12);
        assert!(empty_stats.capacity_used.abs() < 1e-12);
        assert_eq!(empty_stats.total_capacity, 10);

        store_episode(&mut bank, "a", &[1.0], 0.8);
        store_episode(&mut bank, "b", &[2.0], 0.6);

        let filled_stats = bank.get_stats();
        assert_eq!(filled_stats.count, 2);
        assert!((filled_stats.avg_performance - 0.7).abs() < 1e-12);
        assert!((filled_stats.capacity_used - 0.2).abs() < 1e-12);

        bank.clear();
        assert_eq!(bank.size(), 0);
    }

    #[test]
    fn test_support_set_selection() {
        let mgr = SupportSetManager::<f64>::from_max_size(10)
            .expect("failed to create SupportSetManager");
        let candidates: Vec<Array1<f64>> = [
            [0.0, 0.0],
            [10.0, 0.0],
            [0.0, 10.0],
            [5.0, 5.0],
            [10.0, 10.0],
        ]
        .iter()
        .map(|point| array_of(point))
        .collect();
        let labels = [0.0, 1.0, 2.0, 3.0, 4.0];

        let picked = mgr
            .select_support_set(&candidates, &labels, 3)
            .expect("select failed");
        assert_eq!(picked.len(), 3);

        // The seed is the largest-norm candidate: [10, 10] at index 4.
        assert!(picked.contains(&4));
        // No index may be selected twice.
        let distinct: std::collections::HashSet<usize> = picked.iter().copied().collect();
        assert_eq!(distinct.len(), picked.len());
    }

    #[test]
    fn test_support_set_augmentation() {
        let mgr = SupportSetManager::<f64>::from_max_size(10)
            .expect("failed to create SupportSetManager");
        let support = vec![array_of(&[1.0, 2.0, 3.0]), array_of(&[4.0, 5.0, 6.0])];

        let augmented = mgr
            .augment_support_set(&support, 0.1)
            .expect("augment failed");
        // Two originals plus one noisy copy of each.
        assert_eq!(augmented.len(), 4);

        // The leading entries are the untouched originals.
        for (original, kept) in support.iter().zip(augmented.iter()) {
            for i in 0..3 {
                assert!((kept[i] - original[i]).abs() < 1e-12);
            }
        }

        // The first noisy copy must differ from its source somewhere.
        let any_different = (0..3).any(|i| (augmented[2][i] - support[0][i]).abs() > 1e-15);
        assert!(any_different, "augmented copy should differ from original");
    }

    #[test]
    fn test_support_set_quality() {
        let mgr = SupportSetManager::<f64>::from_max_size(10)
            .expect("failed to create SupportSetManager");

        let spread_out = vec![
            array_of(&[0.0, 0.0]),
            array_of(&[100.0, 0.0]),
            array_of(&[0.0, 100.0]),
        ];
        let tight = vec![
            array_of(&[0.0, 0.0]),
            array_of(&[0.1, 0.0]),
            array_of(&[0.0, 0.1]),
        ];

        let quality_diverse = mgr.evaluate_quality(&spread_out).expect("quality failed");
        let quality_clustered = mgr.evaluate_quality(&tight).expect("quality failed");
        assert!(
            quality_diverse > quality_clustered,
            "diverse set should have higher quality than clustered set"
        );

        // A single vector has no pairs, so its quality is zero.
        let quality_single = mgr
            .evaluate_quality(&[array_of(&[1.0, 2.0])])
            .expect("quality failed");
        assert!(quality_single.abs() < 1e-12);
    }
}