1use scirs2_core::ndarray::Array1;
7use scirs2_core::numeric::Float;
8use std::fmt::Debug;
9
10use crate::error::{OptimError, Result};
11use crate::few_shot::{EpisodicMemoryBank, MemoryBankStats, SupportSetManager};
12
13impl<T: Float + Debug + Send + Sync + 'static> EpisodicMemoryBank<T> {
18 pub fn store_lightweight_episode(
23 &mut self,
24 task_id: String,
25 representation: Array1<T>,
26 performance: T,
27 ) -> Result<()> {
28 use crate::few_shot::{
29 AdaptationPerformance, AdaptationResult, AdaptationStep, DifficultyLevel,
30 DomainCharacteristics, DomainInfo, DomainType, EpisodeMetadata, ExampleMetadata,
31 MemoryEpisode, QueryExample, QuerySet, QuerySetStatistics, ResourceUsage,
32 SupportExample, SupportSet, SupportSetStatistics, TaskData, TaskMetadata,
33 };
34 use std::collections::HashMap;
35 use std::time::Duration;
36
37 if self.episodes().len() >= self.capacity() {
39 self.evict()?;
40 }
41
42 let dim = representation.len();
43
44 let support_example = SupportExample {
46 features: representation.clone(),
47 target: performance,
48 weight: T::one(),
49 context: HashMap::new(),
50 metadata: ExampleMetadata {
51 source: task_id.clone(),
52 quality_score: scirs2_core::numeric::NumCast::from(performance).unwrap_or(0.0),
53 created_at: std::time::SystemTime::now(),
54 },
55 };
56
57 let support_set = SupportSet {
58 examples: vec![support_example],
59 task_metadata: TaskMetadata {
60 task_name: task_id.clone(),
61 domain: DomainType::Optimization,
62 difficulty: DifficultyLevel::Medium,
63 created_at: std::time::SystemTime::now(),
64 },
65 statistics: SupportSetStatistics {
66 mean: representation.clone(),
67 variance: Array1::zeros(dim),
68 size: 1,
69 diversity_score: T::zero(),
70 },
71 temporal_order: None,
72 };
73
74 let query_set = QuerySet {
75 examples: Vec::<QueryExample<T>>::new(),
76 statistics: QuerySetStatistics {
77 mean: Array1::zeros(dim),
78 variance: Array1::zeros(dim),
79 size: 0,
80 },
81 eval_metrics: Vec::new(),
82 };
83
84 let task_data = TaskData {
85 task_id: task_id.clone(),
86 support_set,
87 query_set,
88 task_params: HashMap::new(),
89 domain_info: DomainInfo {
90 domain_type: DomainType::Optimization,
91 characteristics: DomainCharacteristics {
92 input_dim: dim,
93 output_dim: 1,
94 temporal: false,
95 stochasticity: 0.0,
96 noise_level: 0.0,
97 sparsity: 0.0,
98 },
99 difficulty_level: DifficultyLevel::Medium,
100 constraints: Vec::new(),
101 },
102 };
103
104 let adaptation_result = AdaptationResult {
105 adapted_state: crate::OptimizerState {
106 parameters: Array1::zeros(1),
107 gradients: Array1::zeros(1),
108 momentum: None,
109 hidden_states: HashMap::new(),
110 memory_buffers: HashMap::new(),
111 step: 0,
112 step_count: 0,
113 loss: None,
114 learning_rate: scirs2_core::numeric::NumCast::from(0.001)
115 .unwrap_or_else(|| T::one()),
116 metadata: crate::StateMetadata {
117 task_id: Some(task_id.clone()),
118 optimizer_type: None,
119 version: "1.0".to_string(),
120 timestamp: std::time::SystemTime::now(),
121 checksum: 0,
122 compression_level: 0,
123 custom_data: HashMap::new(),
124 },
125 },
126 performance: AdaptationPerformance {
127 query_performance: performance,
128 support_performance: performance,
129 adaptation_speed: 1,
130 final_loss: T::one() - performance,
131 improvement: performance,
132 stability: T::one(),
133 },
134 task_representation: representation,
135 adaptation_trajectory: Vec::<AdaptationStep<T>>::new(),
136 resource_usage: ResourceUsage {
137 total_time: Duration::from_secs(0),
138 peak_memory_mb: T::zero(),
139 compute_cost: T::zero(),
140 energy_consumption: T::zero(),
141 },
142 };
143
144 let episode = MemoryEpisode {
145 episode_id: format!("ep_{}", self.usage_stats().total_episodes),
146 task_data,
147 adaptation_result,
148 timestamp: std::time::SystemTime::now(),
149 metadata: EpisodeMetadata {
150 difficulty: DifficultyLevel::Medium,
151 domain: DomainType::Optimization,
152 success_rate: scirs2_core::numeric::NumCast::from(performance).unwrap_or(0.0),
153 tags: Vec::new(),
154 },
155 access_count: 0,
156 };
157
158 self.episodes_mut().push_back(episode);
159 self.usage_stats_mut().total_episodes += 1;
160 let len = self.episodes().len();
161 let cap = self.capacity();
162 self.usage_stats_mut().memory_utilization = len as f64 / cap as f64;
163 Ok(())
164 }
165
166 pub fn retrieve_by_repr(&self, query: &Array1<T>, k: usize) -> Result<Vec<(String, T)>> {
171 if self.is_empty() {
172 return Ok(Vec::new());
173 }
174
175 let mut scored: Vec<(usize, T)> = Vec::with_capacity(self.len());
176 for (idx, ep) in self.episodes().iter().enumerate() {
177 let repr = &ep.adaptation_result.task_representation;
178 let sim = cosine_similarity(query, repr);
179 scored.push((idx, sim));
180 }
181
182 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
184
185 let take = k.min(scored.len());
186 let result: Vec<(String, T)> = scored[..take]
187 .iter()
188 .map(|&(idx, sim)| {
189 let ep = &self.episodes()[idx];
190 (ep.episode_id.clone(), sim)
191 })
192 .collect();
193
194 Ok(result)
195 }
196
197 pub fn evict(&mut self) -> Result<()> {
202 if self.is_empty() {
203 return Ok(());
204 }
205
206 match self.eviction_policy() {
207 crate::few_shot::EvictionPolicy::Performance => {
208 let mut worst_idx = 0;
210 let mut worst_perf = T::infinity();
211 for (i, ep) in self.episodes().iter().enumerate() {
212 let perf = ep.adaptation_result.performance.query_performance;
213 if perf < worst_perf {
214 worst_perf = perf;
215 worst_idx = i;
216 }
217 }
218 self.episodes_mut().remove(worst_idx);
219 }
220 crate::few_shot::EvictionPolicy::LRU => {
221 let mut lru_idx = 0;
223 let mut min_access = usize::MAX;
224 for (i, ep) in self.episodes().iter().enumerate() {
225 if ep.access_count < min_access {
226 min_access = ep.access_count;
227 lru_idx = i;
228 }
229 }
230 self.episodes_mut().remove(lru_idx);
231 }
232 crate::few_shot::EvictionPolicy::LFU => {
233 let mut lfu_idx = 0;
235 let mut min_access = usize::MAX;
236 for (i, ep) in self.episodes().iter().enumerate() {
237 if ep.access_count < min_access {
238 min_access = ep.access_count;
239 lfu_idx = i;
240 }
241 }
242 self.episodes_mut().remove(lfu_idx);
243 }
244 _ => {
245 self.episodes_mut().pop_front();
247 }
248 }
249
250 let len = self.episodes().len();
251 let cap = self.capacity();
252 self.usage_stats_mut().memory_utilization = len as f64 / cap as f64;
253 Ok(())
254 }
255
256 pub fn get_stats(&self) -> MemoryBankStats<T> {
258 let count = self.len();
259 let cap = self.capacity();
260
261 let avg_performance = if count == 0 {
262 T::zero()
263 } else {
264 let mut sum = T::zero();
265 for ep in self.episodes() {
266 sum = sum + ep.adaptation_result.performance.query_performance;
267 }
268 let count_t: T = scirs2_core::numeric::NumCast::from(count).unwrap_or_else(|| T::one());
269 sum / count_t
270 };
271
272 MemoryBankStats {
273 count,
274 avg_performance,
275 capacity_used: if cap > 0 {
276 count as f64 / cap as f64
277 } else {
278 0.0
279 },
280 total_capacity: cap,
281 }
282 }
283
284 pub fn clear(&mut self) {
286 self.episodes_mut().clear();
287 self.usage_stats_mut().memory_utilization = 0.0;
288 }
289
290 pub fn size(&self) -> usize {
292 self.len()
293 }
294}
295
296impl<T: Float + Debug + Send + Sync + 'static> SupportSetManager<T> {
301 pub fn select_support_set(
308 &self,
309 candidates: &[Array1<T>],
310 _labels: &[T],
311 budget: usize,
312 ) -> Result<Vec<usize>> {
313 if candidates.is_empty() {
314 return Err(OptimError::InsufficientData(
315 "No candidates to select from".to_string(),
316 ));
317 }
318 let n = candidates.len();
319 let take = budget.min(n).min(self.max_support_size());
320
321 if take >= n {
322 return Ok((0..n).collect());
323 }
324
325 let mut selected: Vec<usize> = Vec::with_capacity(take);
327
328 let mut best_seed = 0;
330 let mut best_norm = T::neg_infinity();
331 for (i, c) in candidates.iter().enumerate() {
332 let norm = vec_norm_sq(c);
333 if norm > best_norm {
334 best_norm = norm;
335 best_seed = i;
336 }
337 }
338 selected.push(best_seed);
339
340 let mut min_dist: Vec<T> = vec![T::infinity(); n];
342
343 while selected.len() < take {
344 let last = selected[selected.len() - 1];
346 for i in 0..n {
347 let d = squared_euclidean(&candidates[i], &candidates[last]);
348 if d < min_dist[i] {
349 min_dist[i] = d;
350 }
351 }
352 for &s in &selected {
354 min_dist[s] = T::neg_infinity();
355 }
356
357 let mut farthest_idx = 0;
359 let mut farthest_dist = T::neg_infinity();
360 for (i, &dist) in min_dist.iter().enumerate().take(n) {
361 if dist > farthest_dist {
362 farthest_dist = dist;
363 farthest_idx = i;
364 }
365 }
366 selected.push(farthest_idx);
367 }
368
369 Ok(selected)
370 }
371
372 pub fn augment_support_set(
378 &self,
379 support: &[Array1<T>],
380 noise_scale: T,
381 ) -> Result<Vec<Array1<T>>> {
382 if support.is_empty() {
383 return Err(OptimError::InsufficientData(
384 "Cannot augment empty support set".to_string(),
385 ));
386 }
387
388 let mut augmented = Vec::with_capacity(support.len() * 2);
389
390 for s in support {
392 augmented.push(s.clone());
393 }
394
395 for (ex_idx, s) in support.iter().enumerate() {
397 let mut noisy = s.clone();
398 for (i, val) in noisy.iter_mut().enumerate() {
399 let seed = (ex_idx * 7919 + i * 104729 + 31) as f64;
401 let noise_val = ((seed * 0.6180339887).fract() - 0.5) * 2.0; let noise_t: T =
403 scirs2_core::numeric::NumCast::from(noise_val).unwrap_or_else(|| T::zero());
404 *val = *val + noise_scale * noise_t;
405 }
406 augmented.push(noisy);
407 }
408
409 Ok(augmented)
410 }
411
412 pub fn evaluate_quality(&self, support: &[Array1<T>]) -> Result<T> {
417 if support.len() < 2 {
418 return Ok(T::zero());
419 }
420
421 let n = support.len();
422 let mut total_dist = T::zero();
423 let mut pair_count = 0usize;
424
425 for i in 0..n {
426 for j in (i + 1)..n {
427 total_dist = total_dist + squared_euclidean(&support[i], &support[j]);
428 pair_count += 1;
429 }
430 }
431
432 if pair_count == 0 {
433 return Ok(T::zero());
434 }
435
436 let pair_t: T = scirs2_core::numeric::NumCast::from(pair_count).unwrap_or_else(|| T::one());
437 Ok(total_dist / pair_t)
438 }
439}
440
441fn cosine_similarity<T: Float>(a: &Array1<T>, b: &Array1<T>) -> T {
447 let len = a.len().min(b.len());
448 let mut dot = T::zero();
449 let mut na = T::zero();
450 let mut nb = T::zero();
451 for i in 0..len {
452 dot = dot + a[i] * b[i];
453 na = na + a[i] * a[i];
454 nb = nb + b[i] * b[i];
455 }
456 let denom = na.sqrt() * nb.sqrt();
457 if denom == T::zero() {
458 T::zero()
459 } else {
460 dot / denom
461 }
462}
463
464fn squared_euclidean<T: Float>(a: &Array1<T>, b: &Array1<T>) -> T {
466 let len = a.len().min(b.len());
467 let mut sum = T::zero();
468 for i in 0..len {
469 let d = a[i] - b[i];
470 sum = sum + d * d;
471 }
472 sum
473}
474
475fn vec_norm_sq<T: Float>(v: &Array1<T>) -> T {
477 let mut sum = T::zero();
478 for &x in v.iter() {
479 sum = sum + x * x;
480 }
481 sum
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// Converts a list of 2-D points into owned `Array1` vectors.
    fn array_set_2d(coords: &[[f64; 2]]) -> Vec<Array1<f64>> {
        coords.iter().map(|c| Array1::from_vec(c.to_vec())).collect()
    }

    #[test]
    fn test_episodic_memory_store_retrieve() {
        let mut bank = EpisodicMemoryBank::<f64>::from_capacity(10)
            .expect("failed to create EpisodicMemoryBank");
        assert_eq!(bank.size(), 0);

        let inputs = vec![
            ("task_a", vec![1.0, 0.0, 0.0], 0.9),
            ("task_b", vec![0.0, 1.0, 0.0], 0.7),
        ];
        for (task, repr, perf) in inputs.into_iter() {
            bank.store_lightweight_episode(task.to_string(), Array1::from_vec(repr), perf)
                .expect("store failed");
        }
        assert_eq!(bank.size(), 2);

        let query = Array1::from_vec(vec![1.0, 0.1, 0.0]);
        let hits = bank.retrieve_by_repr(&query, 2).expect("retrieve failed");
        assert_eq!(hits.len(), 2);
        // The query is nearly parallel to task_a's representation, stored first as "ep_0".
        assert_eq!(hits[0].0, "ep_0");
        assert!(hits[0].1 > hits[1].1);
    }

    #[test]
    fn test_episodic_memory_eviction() {
        let mut bank = EpisodicMemoryBank::<f64>::from_capacity(3)
            .expect("failed to create EpisodicMemoryBank");

        let initial = vec![("t1", 1.0, 0.5), ("t2", 2.0, 0.9), ("t3", 3.0, 0.3)];
        for (task, feature, perf) in initial.into_iter() {
            bank.store_lightweight_episode(task.into(), Array1::from_vec(vec![feature]), perf)
                .expect("store failed");
        }
        assert_eq!(bank.size(), 3);

        // Storing into a full bank must evict exactly one episode first.
        bank.store_lightweight_episode("t4".into(), Array1::from_vec(vec![4.0]), 0.8)
            .expect("store failed");
        assert_eq!(bank.size(), 3);

        let low_perf_survived = bank
            .episodes()
            .iter()
            .map(|ep| ep.adaptation_result.performance.query_performance)
            .any(|perf| (perf - 0.3).abs() < 1e-12);
        assert!(
            !low_perf_survived,
            "lowest-performance episode should be evicted"
        );
    }

    #[test]
    fn test_memory_bank_stats() {
        let mut bank = EpisodicMemoryBank::<f64>::from_capacity(10)
            .expect("failed to create EpisodicMemoryBank");

        let empty = bank.get_stats();
        assert_eq!(empty.count, 0);
        assert!((empty.avg_performance - 0.0).abs() < 1e-12);
        assert!((empty.capacity_used - 0.0).abs() < 1e-12);
        assert_eq!(empty.total_capacity, 10);

        for (task, feature, perf) in vec![("a", 1.0, 0.8), ("b", 2.0, 0.6)].into_iter() {
            bank.store_lightweight_episode(task.into(), Array1::from_vec(vec![feature]), perf)
                .expect("store failed");
        }

        let filled = bank.get_stats();
        assert_eq!(filled.count, 2);
        assert!((filled.avg_performance - 0.7).abs() < 1e-12);
        assert!((filled.capacity_used - 0.2).abs() < 1e-12);

        bank.clear();
        assert_eq!(bank.size(), 0);
    }

    #[test]
    fn test_support_set_selection() {
        use std::collections::HashSet;

        let mgr = SupportSetManager::<f64>::from_max_size(10)
            .expect("failed to create SupportSetManager");
        let candidates = array_set_2d(&[
            [0.0, 0.0],
            [10.0, 0.0],
            [0.0, 10.0],
            [5.0, 5.0],
            [10.0, 10.0],
        ]);
        let labels: Vec<f64> = (0..candidates.len()).map(|i| i as f64).collect();

        let picked = mgr
            .select_support_set(&candidates, &labels, 3)
            .expect("select failed");
        assert_eq!(picked.len(), 3);

        // [10, 10] has the largest squared norm, so it seeds the selection.
        assert!(picked.contains(&4));
        let distinct: HashSet<&usize> = picked.iter().collect();
        assert_eq!(distinct.len(), picked.len());
    }

    #[test]
    fn test_support_set_augmentation() {
        let mgr = SupportSetManager::<f64>::from_max_size(10)
            .expect("failed to create SupportSetManager");
        let support = vec![
            Array1::from_vec(vec![1.0, 2.0, 3.0]),
            Array1::from_vec(vec![4.0, 5.0, 6.0]),
        ];
        let augmented = mgr
            .augment_support_set(&support, 0.1)
            .expect("augment failed");
        assert_eq!(augmented.len(), 4);

        // The leading half of the output is an exact copy of the input.
        for (copy, original) in augmented.iter().zip(support.iter()) {
            for i in 0..3 {
                assert!((copy[i] - original[i]).abs() < 1e-12);
            }
        }

        let any_different = (0..3).any(|i| (augmented[2][i] - support[0][i]).abs() > 1e-15);
        assert!(any_different, "augmented copy should differ from original");
    }

    #[test]
    fn test_support_set_quality() {
        let mgr = SupportSetManager::<f64>::from_max_size(10)
            .expect("failed to create SupportSetManager");

        let diverse = array_set_2d(&[[0.0, 0.0], [100.0, 0.0], [0.0, 100.0]]);
        let clustered = array_set_2d(&[[0.0, 0.0], [0.1, 0.0], [0.0, 0.1]]);

        let quality_diverse = mgr.evaluate_quality(&diverse).expect("quality failed");
        let quality_clustered = mgr.evaluate_quality(&clustered).expect("quality failed");
        assert!(
            quality_diverse > quality_clustered,
            "diverse set should have higher quality than clustered set"
        );

        // Fewer than two examples means no pairs, so the score is zero.
        let single = vec![Array1::from_vec(vec![1.0, 2.0])];
        let quality_single = mgr.evaluate_quality(&single).expect("quality failed");
        assert!((quality_single - 0.0).abs() < 1e-12);
    }
}