1#[allow(dead_code)]
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11use std::fmt::Debug;
12use std::time::{Duration, Instant};
13
14use super::OptimizerState;
15use crate::error::{OptimError, Result};
16
/// Top-level few-shot learning system.
///
/// Bundles a base [`FewShotOptimizer`] with the auxiliary components used when
/// learning a new task: a prototypical network (task encoding and per-task
/// prototypes), an episodic memory bank, a fast-adaptation engine and a
/// performance tracker. The support-set manager, adaptation-strategy list and
/// similarity calculator are constructed but not read by the methods visible
/// in this file.
pub struct FewShotLearningSystem<T: Float + Debug + Send + Sync + 'static> {
    /// Optimizer being adapted; handed mutably to the fast-adaptation engine.
    base_optimizer: Box<dyn FewShotOptimizer<T>>,

    /// Encodes tasks into vectors and maintains per-task prototypes.
    prototype_network: PrototypicalNetwork<T>,

    /// Manager for stored support sets.
    support_set_manager: SupportSetManager<T>,

    /// Registered adaptation strategies (starts empty).
    adaptation_strategies: Vec<Box<dyn AdaptationStrategy<T>>>,

    /// Computes similarity between tasks.
    similarity_calculator: TaskSimilarityCalculator<T>,

    /// Bounded store of past episodes (task + adaptation result).
    memory_bank: EpisodicMemoryBank<T>,

    /// Engine that performs the actual adaptation step.
    fast_adaptation: FastAdaptationEngine<T>,

    /// Records per-episode performance and aggregate statistics.
    performance_tracker: FewShotPerformanceTracker<T>,
}
43
/// Interface an optimizer implements to take part in few-shot learning.
///
/// Implementors must be `Send + Sync`.
pub trait FewShotOptimizer<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Performs few-shot adaptation given a support set and a query set,
    /// under the settings in `adaptation_config`.
    fn adapt_few_shot(
        &mut self,
        support_set: &SupportSet<T>,
        query_set: &QuerySet<T>,
        adaptation_config: &AdaptationConfig,
    ) -> Result<AdaptationResult<T>>;

    /// Returns a vector representation of `taskdata`.
    fn get_task_representation(&self, taskdata: &TaskData<T>) -> Result<Array1<T>>;

    /// Computes a scalar adaptation loss for a support/query pair.
    fn compute_adaptation_loss(
        &self,
        support_set: &SupportSet<T>,
        query_set: &QuerySet<T>,
    ) -> Result<T>;

    /// Applies meta-gradients to the optimizer's meta-parameters.
    fn update_meta_parameters(&mut self, metagradients: &MetaGradients<T>) -> Result<()>;

    /// Exports state intended for transfer (representations, meta-parameters,
    /// task embeddings, statistics).
    fn get_transfer_state(&self) -> TransferState<T>;

    /// Loads state previously produced by `get_transfer_state`.
    fn load_transfer_state(&mut self, state: TransferState<T>) -> Result<()>;
}
73
/// Labelled examples available for adapting to a single task.
#[derive(Debug, Clone)]
pub struct SupportSet<T: Float + Debug + Send + Sync + 'static> {
    /// The support examples. `PrototypicalNetwork::encode_task` averages their
    /// `features`.
    pub examples: Vec<SupportExample<T>>,

    /// Descriptive metadata of the task these examples belong to.
    pub task_metadata: TaskMetadata,

    /// Summary statistics of the support examples.
    pub statistics: SupportSetStatistics<T>,

    /// Optional ordering for temporally ordered data (presumably indices into
    /// `examples` — confirm with producers).
    pub temporal_order: Option<Vec<usize>>,
}

/// One labelled support example.
#[derive(Debug, Clone)]
pub struct SupportExample<T: Float + Debug + Send + Sync + 'static> {
    /// Input feature vector.
    pub features: Array1<T>,

    /// Target value for this example.
    pub target: T,

    /// Importance weight of this example.
    pub weight: T,

    /// Additional named context values.
    pub context: HashMap<String, T>,

    /// Provenance metadata.
    pub metadata: ExampleMetadata,
}

/// Examples used to evaluate an adapted model on a task.
#[derive(Debug, Clone)]
pub struct QuerySet<T: Float + Debug + Send + Sync + 'static> {
    /// The query examples.
    pub examples: Vec<QueryExample<T>>,

    /// Summary statistics of the query examples.
    pub statistics: QuerySetStatistics<T>,

    /// Metrics requested for evaluation.
    pub eval_metrics: Vec<EvaluationMetric>,
}

/// One query example; its target may be unknown.
#[derive(Debug, Clone)]
pub struct QueryExample<T: Float + Debug + Send + Sync + 'static> {
    /// Input feature vector.
    pub features: Array1<T>,

    /// Ground-truth target, if available.
    pub true_target: Option<T>,

    /// Importance weight of this example.
    pub weight: T,

    /// Additional named context values.
    pub context: HashMap<String, T>,
}

/// Everything known about one few-shot task.
#[derive(Debug, Clone)]
pub struct TaskData<T: Float + Debug + Send + Sync + 'static> {
    /// Task identifier; `PrototypicalNetwork::update_prototypes` uses it as the
    /// prototype key.
    pub task_id: String,

    /// Examples available for adaptation.
    pub support_set: SupportSet<T>,

    /// Examples for evaluation.
    pub query_set: QuerySet<T>,

    /// Named task-specific parameters.
    pub task_params: HashMap<String, T>,

    /// Domain description; its `difficulty_level` drives strategy selection
    /// in `FewShotLearningSystem`.
    pub domain_info: DomainInfo,
}
156
/// Description of the problem domain a task comes from.
#[derive(Debug, Clone)]
pub struct DomainInfo {
    /// Broad domain category.
    pub domain_type: DomainType,

    /// Quantitative characteristics of the domain.
    pub characteristics: DomainCharacteristics,

    /// Task difficulty; used by `FewShotLearningSystem` to pick an adaptation
    /// strategy.
    pub difficulty_level: DifficultyLevel,

    /// Constraints that apply to tasks in this domain.
    pub constraints: Vec<DomainConstraint>,
}

/// Broad application domain of a task.
#[derive(Debug, Clone, Copy)]
pub enum DomainType {
    ComputerVision,
    NaturalLanguageProcessing,
    ReinforcementLearning,
    TimeSeriesForecasting,
    ScientificComputing,
    Optimization,
    ControlSystems,
    GamePlaying,
    Robotics,
    Healthcare,
}

/// Numeric characteristics of a domain.
#[derive(Debug, Clone)]
pub struct DomainCharacteristics {
    /// Input dimensionality.
    pub input_dim: usize,

    /// Output dimensionality.
    pub output_dim: usize,

    /// Whether the domain has temporal structure.
    pub temporal: bool,

    /// Degree of stochasticity (scale not defined in this file).
    pub stochasticity: f64,

    /// Noise level (scale not defined in this file).
    pub noise_level: f64,

    /// Sparsity (scale not defined here; presumably a fraction — confirm).
    pub sparsity: f64,
}

/// Difficulty scale, listed from easiest (`Trivial`) to hardest (`Extreme`).
#[derive(Debug, Clone, Copy)]
pub enum DifficultyLevel {
    Trivial,
    Easy,
    Medium,
    Hard,
    Expert,
    Extreme,
}

/// A constraint attached to a domain.
#[derive(Debug, Clone)]
pub struct DomainConstraint {
    /// Kind of constraint.
    pub constraint_type: ConstraintType,

    /// Human-readable description.
    pub description: String,

    /// How strictly the constraint should be enforced.
    pub enforcement: ConstraintEnforcement,
}

/// Kinds of domain constraints.
#[derive(Debug, Clone, Copy)]
pub enum ConstraintType {
    ResourceLimit,
    TemporalConstraint,
    AccuracyRequirement,
    LatencyRequirement,
    MemoryConstraint,
    EnergyConstraint,
    SafetyConstraint,
}

/// Enforcement level of a constraint. Nothing in this file enforces
/// constraints; interpretation is left to consumers.
#[derive(Debug, Clone, Copy)]
pub enum ConstraintEnforcement {
    Hard,
    Soft,
    Advisory,
}
253
/// Settings for one adaptation run.
#[derive(Debug, Clone)]
pub struct AdaptationConfig {
    /// Number of adaptation steps.
    pub adaptation_steps: usize,

    /// Learning rate used during adaptation.
    pub adaptation_lr: f64,

    /// Requested adaptation strategy.
    /// NOTE: `FewShotLearningSystem::learn_few_shot` chooses its own strategy
    /// from task difficulty and does not read this field.
    pub strategy: AdaptationStrategyType,

    /// Optional early-stopping settings.
    pub early_stopping: Option<EarlyStoppingConfig>,

    /// Regularization settings.
    pub regularization: RegularizationConfig,

    /// Resource limits for the run.
    pub resource_constraints: ResourceConstraints,
}

/// Family of adaptation algorithm.
#[derive(Debug, Clone, Copy)]
pub enum AdaptationStrategyType {
    /// Model-Agnostic Meta-Learning.
    MAML,

    /// First-order MAML.
    FOMAML,

    /// Prototypical networks.
    Prototypical,

    /// Matching networks.
    Matching,

    /// Relation networks.
    Relation,

    /// Meta-SGD (learned per-parameter learning rates).
    MetaSGD,

    /// Learned optimizer.
    LearnedOptimizer,

    /// Generic gradient-based adaptation.
    GradientBased,

    /// Memory-augmented adaptation.
    MemoryAugmented,
}

/// Early-stopping settings.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Number of evaluations without improvement to tolerate.
    pub patience: usize,

    /// Minimum change that counts as an improvement.
    pub min_improvement: f64,

    /// How often validation is performed (presumably in steps — confirm).
    pub validation_frequency: usize,
}

/// Regularization settings.
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// L2 penalty strength.
    pub l2_strength: f64,

    /// Dropout rate.
    pub dropout_rate: f64,

    /// Optional gradient-clipping threshold.
    pub gradient_clip: Option<f64>,

    /// Named task-specific regularization strengths.
    pub task_regularization: HashMap<String, f64>,
}

/// Resource limits for an adaptation run.
#[derive(Debug, Clone)]
pub struct ResourceConstraints {
    /// Wall-clock time limit.
    pub max_time: Duration,

    /// Memory limit in megabytes (per the field name).
    pub max_memory_mb: usize,

    /// Compute budget (units not defined in this file).
    pub max_compute_budget: f64,
}
348
/// Outcome of adapting to one task.
#[derive(Debug, Clone)]
pub struct AdaptationResult<T: Float + Debug + Send + Sync + 'static> {
    /// Optimizer state after adaptation.
    pub adapted_state: OptimizerState<T>,

    /// Performance summary.
    pub performance: AdaptationPerformance<T>,

    /// Vector representation of the task. `learn_few_shot` overwrites it with
    /// the prototype network's task encoding.
    pub task_representation: Array1<T>,

    /// Per-step record of the adaptation.
    pub adaptation_trajectory: Vec<AdaptationStep<T>>,

    /// Resources consumed.
    pub resource_usage: ResourceUsage<T>,
}

/// Performance summary of an adaptation run.
#[derive(Debug, Clone)]
pub struct AdaptationPerformance<T: Float + Debug + Send + Sync + 'static> {
    /// Performance on the query set; this is the value recorded by
    /// `FewShotPerformanceTracker`.
    pub query_performance: T,

    /// Performance on the support set.
    pub support_performance: T,

    /// Adaptation speed (presumably a step count — confirm).
    pub adaptation_speed: usize,

    /// Loss at the end of adaptation.
    pub final_loss: T,

    /// Improvement achieved by adaptation.
    pub improvement: T,

    /// Stability score.
    pub stability: T,
}

/// One step of an adaptation trajectory.
#[derive(Debug, Clone)]
pub struct AdaptationStep<T: Float + Debug + Send + Sync + 'static> {
    /// Step index.
    pub step: usize,

    /// Loss at this step.
    pub loss: T,

    /// Performance at this step.
    pub performance: T,

    /// Gradient norm at this step.
    pub gradient_norm: T,

    /// Time taken by this step.
    pub step_time: Duration,
}

/// Resources used by an adaptation run.
#[derive(Debug, Clone)]
pub struct ResourceUsage<T: Float + Debug + Send + Sync + 'static> {
    /// Total elapsed time.
    pub total_time: Duration,

    /// Peak memory in megabytes (per the field name).
    pub peak_memory_mb: T,

    /// Compute cost (units not defined in this file).
    pub compute_cost: T,

    /// Energy consumption (units not defined in this file).
    pub energy_consumption: T,
}

/// Meta-gradients for updating an optimizer's meta-parameters.
#[derive(Debug, Clone)]
pub struct MetaGradients<T: Float + Debug + Send + Sync + 'static> {
    /// Gradients keyed by parameter name.
    pub param_gradients: HashMap<String, Array1<T>>,

    /// Learning-rate gradients keyed by name.
    pub lr_gradients: HashMap<String, T>,

    /// Architecture gradients keyed by name.
    pub arch_gradients: HashMap<String, Array1<T>>,

    /// Overall gradient norm.
    pub gradient_norm: T,
}

/// State exported for transfer between tasks or optimizers.
#[derive(Debug, Clone)]
pub struct TransferState<T: Float + Debug + Send + Sync + 'static> {
    /// Learned representations keyed by name.
    pub representations: HashMap<String, Array1<T>>,

    /// Meta-parameters keyed by name.
    pub meta_parameters: HashMap<String, Array1<T>>,

    /// Task embeddings (presumably one row per task — confirm).
    pub task_embeddings: Array2<T>,

    /// Statistics about transfer quality.
    pub transfer_stats: TransferStatistics<T>,
}

/// Statistics describing a transfer.
#[derive(Debug, Clone)]
pub struct TransferStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Performance on the source task(s).
    pub source_performance: T,

    /// Performance on the target task.
    pub target_performance: T,

    /// Transfer efficiency score.
    pub transfer_efficiency: T,

    /// Number of steps saved by transferring.
    pub steps_saved: usize,
}
472
/// Prototypical network: encodes tasks and keeps one prototype per task id.
pub struct PrototypicalNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// Encoder layers.
    encoder: EncoderNetwork<T>,

    /// Prototypes keyed by task id (see `update_prototypes`).
    prototypes: HashMap<String, Prototype<T>>,

    /// Distance metric for comparing embeddings.
    distance_metric: DistanceMetric,

    /// Network hyper-parameters.
    parameters: PrototypicalNetworkParams<T>,
}

/// Stack of encoder layers with a shared activation.
#[derive(Debug)]
pub struct EncoderNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// Layers in order.
    layers: Vec<EncoderLayer<T>>,

    /// Activation applied between layers.
    activation: ActivationFunction,
}

/// One encoder layer.
#[derive(Debug)]
pub struct EncoderLayer<T: Float + Debug + Send + Sync + 'static> {
    /// Weight matrix; `PrototypicalNetwork::new` shapes it
    /// `(input_dim, embedding_dim)`.
    pub weights: Array2<T>,

    /// Bias vector of length `embedding_dim`.
    pub bias: Array1<T>,

    /// Kind of layer.
    pub layer_type: LayerType,
}

/// Kinds of encoder layer.
#[derive(Debug, Clone, Copy)]
pub enum LayerType {
    Linear,
    Convolutional,
    Recurrent,
    Attention,
    Residual,
}

/// Supported activation functions.
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    ReLU,
    Tanh,
    Sigmoid,
    Swish,
    GELU,
    Mish,
}

/// A task prototype.
#[derive(Debug, Clone)]
pub struct Prototype<T: Float + Debug + Send + Sync + 'static> {
    /// Prototype vector (exponential moving average of task encodings).
    pub vector: Array1<T>,

    /// Confidence in this prototype.
    pub confidence: T,

    /// Total number of support examples that contributed.
    pub example_count: usize,

    /// Time of the most recent update.
    pub last_updated: std::time::SystemTime,

    /// Descriptive metadata.
    pub metadata: PrototypeMetadata,
}

/// Metadata attached to a prototype.
#[derive(Debug, Clone)]
pub struct PrototypeMetadata {
    /// Category label (set to the task id on creation).
    pub task_category: String,

    /// Domain of the originating task.
    pub domain: DomainType,

    /// Creation time.
    pub created_at: std::time::SystemTime,

    /// Number of updates, including the initial insertion.
    pub update_count: usize,
}

/// Distance metrics for comparing embeddings.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    Euclidean,
    Cosine,
    Manhattan,
    Mahalanobis,
    Learned,
}

/// Hyper-parameters of a prototypical network.
#[derive(Debug, Clone)]
pub struct PrototypicalNetworkParams<T: Float + Debug + Send + Sync + 'static> {
    /// Length of the task encodings.
    pub embedding_dim: usize,

    /// Learning rate.
    pub learning_rate: T,

    /// Softmax temperature.
    pub temperature: T,

    /// Weight given to a new task encoding when blending it into an existing
    /// prototype.
    pub prototype_update_rate: T,
}
592
/// Stores support sets and holds the policy for selecting them.
pub struct SupportSetManager<T: Float + Debug + Send + Sync + 'static> {
    /// Stored support sets keyed by name (presumably task id — confirm).
    support_sets: HashMap<String, SupportSet<T>>,

    /// Strategy for choosing support examples.
    selection_strategy: SupportSetSelectionStrategy,

    /// Size and quality settings.
    config: SupportSetManagerConfig,
}

/// Strategies for selecting support examples.
#[derive(Debug, Clone, Copy)]
pub enum SupportSetSelectionStrategy {
    Random,
    DiversityBased,
    DifficultyBased,
    UncertaintyBased,
    PrototypeBased,
    Adaptive,
}

/// Settings for [`SupportSetManager`].
#[derive(Debug, Clone)]
pub struct SupportSetManagerConfig {
    /// Minimum support-set size.
    pub min_support_size: usize,

    /// Maximum support-set size; must be greater than 0.
    pub max_support_size: usize,

    /// Quality threshold for examples.
    pub quality_threshold: f64,

    /// Whether data augmentation is enabled.
    pub enable_augmentation: bool,

    /// Whether support sets are cached.
    pub cache_support_sets: bool,
}
634
/// A pluggable adaptation strategy that drives a [`FewShotOptimizer`].
pub trait AdaptationStrategy<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Adapts `optimizer` to `task_data` using `config`.
    fn adapt(
        &mut self,
        optimizer: &mut dyn FewShotOptimizer<T>,
        task_data: &TaskData<T>,
        config: &AdaptationConfig,
    ) -> Result<AdaptationResult<T>>;

    /// Strategy name.
    fn name(&self) -> &str;

    /// Current tunable parameters, keyed by name.
    fn parameters(&self) -> HashMap<String, f64>;

    /// Replaces tunable parameters.
    fn update_parameters(&mut self, params: HashMap<String, f64>) -> Result<()>;
}
654
/// Combines several similarity metrics to compare tasks.
pub struct TaskSimilarityCalculator<T: Float + Debug + Send + Sync + 'static> {
    /// Registered similarity metrics.
    similarity_metrics: Vec<Box<dyn SimilarityMetric<T>>>,

    /// Weight per metric (presumably keyed by metric name — confirm).
    metric_weights: HashMap<String, T>,

    /// Cached similarities (presumably keyed by a pair of task ids — confirm).
    similarity_cache: HashMap<(String, String), T>,

    /// Calculator settings.
    config: SimilarityCalculatorConfig<T>,
}

/// A single task-similarity measure.
pub trait SimilarityMetric<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Similarity between two tasks.
    fn calculate_similarity(&self, task1: &TaskData<T>, task2: &TaskData<T>) -> Result<T>;

    /// Metric name.
    fn name(&self) -> &str;

    /// Weight of this metric in a combined score.
    fn weight(&self) -> T;
}

/// Settings for [`TaskSimilarityCalculator`].
#[derive(Debug, Clone)]
pub struct SimilarityCalculatorConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Whether similarity results are cached.
    pub enable_caching: bool,

    /// Maximum number of cache entries.
    pub cache_size_limit: usize,

    /// Similarity threshold.
    pub similarity_threshold: T,

    /// Whether task metadata is used in similarity computation.
    pub use_metadata: bool,
}
697
/// Bounded store of past learning episodes.
pub struct EpisodicMemoryBank<T: Float + Debug + Send + Sync + 'static> {
    /// Stored episodes, oldest at the front.
    episodes: VecDeque<MemoryEpisode<T>>,

    /// Capacity and eviction settings.
    config: MemoryBankConfig<T>,

    /// Usage counters.
    usage_stats: MemoryUsageStats,
}

/// One stored learning episode.
#[derive(Debug, Clone)]
pub struct MemoryEpisode<T: Float + Debug + Send + Sync + 'static> {
    /// Episode identifier (`ep_<n>`).
    pub episode_id: String,

    /// The task that was learned.
    pub task_data: TaskData<T>,

    /// Result of adapting to that task.
    pub adaptation_result: AdaptationResult<T>,

    /// Time the episode was stored.
    pub timestamp: std::time::SystemTime,

    /// Descriptive metadata.
    pub metadata: EpisodeMetadata,

    /// Number of times the episode has been accessed.
    pub access_count: usize,
}

/// Metadata describing an episode.
#[derive(Debug, Clone)]
pub struct EpisodeMetadata {
    /// Task difficulty.
    pub difficulty: DifficultyLevel,

    /// Task domain.
    pub domain: DomainType,

    /// Success rate.
    pub success_rate: f64,

    /// Free-form tags.
    pub tags: Vec<String>,
}

/// Settings for [`EpisodicMemoryBank`].
#[derive(Debug, Clone)]
pub struct MemoryBankConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Maximum number of stored episodes; must be greater than 0.
    pub max_memory_size: usize,

    /// Policy for choosing which episode to evict.
    pub eviction_policy: EvictionPolicy,

    /// Similarity threshold.
    pub similarity_threshold: T,

    /// Whether stored episodes are compressed.
    pub enable_compression: bool,
}

/// Eviction policies for the memory bank.
#[derive(Debug, Clone, Copy)]
pub enum EvictionPolicy {
    /// Least recently used.
    LRU,
    /// Least frequently used.
    LFU,
    /// Lowest-performing episode.
    Performance,
    /// Oldest episode.
    Age,
    /// Random episode.
    Random,
}

/// Summary statistics of a memory bank.
#[derive(Debug, Clone)]
pub struct MemoryBankStats<T: Float + Debug + Send + Sync + 'static> {
    /// Number of stored episodes.
    pub count: usize,

    /// Average performance of stored episodes.
    pub avg_performance: T,

    /// Fraction of capacity in use.
    pub capacity_used: f64,

    /// Total capacity.
    pub total_capacity: usize,
}

/// Usage counters of a memory bank.
#[derive(Debug, Clone)]
pub struct MemoryUsageStats {
    /// Total episodes ever stored (including evicted ones).
    pub total_episodes: usize,

    /// Current `len / max_memory_size`.
    pub memory_utilization: f64,

    /// Retrieval hit rate.
    pub hit_rate: f64,

    /// Average retrieval time.
    pub avg_retrieval_time: Duration,
}
805
/// Engine that performs fast task adaptation.
pub struct FastAdaptationEngine<T: Float + Debug + Send + Sync + 'static> {
    /// Registered fast-adaptation algorithms (starts empty).
    algorithms: Vec<Box<dyn FastAdaptationAlgorithm<T>>>,

    /// Engine settings.
    config: FastAdaptationConfig,
}

/// A pluggable fast-adaptation algorithm.
pub trait FastAdaptationAlgorithm<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Adapts `optimizer` to `task_data`, optionally aiming for
    /// `target_performance`.
    fn adapt_fast(
        &mut self,
        optimizer: &mut dyn FewShotOptimizer<T>,
        task_data: &TaskData<T>,
        target_performance: Option<T>,
    ) -> Result<AdaptationResult<T>>;

    /// Estimated time to adapt to `taskdata`.
    fn estimate_adaptation_time(&self, taskdata: &TaskData<T>) -> Duration;

    /// Algorithm name.
    fn name(&self) -> &str;
}

/// Settings for [`FastAdaptationEngine`].
#[derive(Debug, Clone)]
pub struct FastAdaptationConfig {
    /// Whether adaptation results are cached.
    pub enable_caching: bool,

    /// Whether adaptation outcomes are predicted.
    pub enable_prediction: bool,

    /// Upper bound on adaptation time.
    pub max_adaptation_time: Duration,

    /// Performance threshold (underscore-prefixed; not read in this file).
    pub _performance_threshold: f64,
}
847
/// Tracks few-shot learning performance over a bounded history window.
pub struct FewShotPerformanceTracker<T: Float + Debug + Send + Sync + 'static> {
    /// Recent performance records, oldest at the front.
    performance_history: VecDeque<PerformanceRecord<T>>,

    /// Registered custom metrics.
    metrics: Vec<Box<dyn PerformanceMetric<T>>>,

    /// Tracking settings.
    config: TrackingConfig,

    /// Aggregate statistics.
    stats: PerformanceStats<T>,
}

/// One recorded adaptation outcome.
#[derive(Debug, Clone)]
pub struct PerformanceRecord<T: Float + Debug + Send + Sync + 'static> {
    /// Task identifier.
    pub task_id: String,

    /// Recorded performance (the result's query performance).
    pub performance: T,

    /// Time spent adapting.
    pub adaptation_time: Duration,

    /// Name of the strategy used.
    pub strategy_used: String,

    /// Time the record was created.
    pub timestamp: std::time::SystemTime,

    /// Additional named metrics.
    pub additional_metrics: HashMap<String, T>,
}

/// A metric computed over a sequence of performance records.
pub trait PerformanceMetric<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Computes the metric over `records`.
    fn calculate(&self, records: &[PerformanceRecord<T>]) -> Result<T>;

    /// Metric name.
    fn name(&self) -> &str;

    /// Whether larger values are better.
    fn higher_is_better(&self) -> bool;
}

/// Settings for [`FewShotPerformanceTracker`].
#[derive(Debug, Clone)]
pub struct TrackingConfig {
    /// Maximum number of records kept in the history window.
    pub max_history_size: usize,

    /// Update interval.
    pub update_frequency: Duration,

    /// Whether detailed tracking is enabled.
    pub detailed_tracking: bool,

    /// Whether results are exported.
    pub export_results: bool,
}

/// Aggregate performance statistics.
#[derive(Debug)]
pub struct PerformanceStats<T: Float + Debug + Send + Sync + 'static> {
    /// Highest performance in the retained history window.
    pub best_performance: T,

    /// Mean performance over the retained history window.
    pub average_performance: T,

    /// Performance variance statistic.
    pub performance_variance: T,

    /// Improvement-rate statistic.
    pub improvement_rate: T,

    /// Success-rate statistic.
    pub success_rate: T,
}
931
932impl<T: Float + Debug + Send + Sync + 'static> FewShotLearningSystem<T> {
934 pub fn new(
936 base_optimizer: Box<dyn FewShotOptimizer<T>>,
937 config: FewShotConfig<T>,
938 ) -> Result<Self> {
939 Ok(Self {
940 base_optimizer,
941 prototype_network: PrototypicalNetwork::new(config.prototype_config)?,
942 support_set_manager: SupportSetManager::new(config.support_set_config)?,
943 adaptation_strategies: Vec::new(),
944 similarity_calculator: TaskSimilarityCalculator::new(config.similarity_config)?,
945 memory_bank: EpisodicMemoryBank::new(config.memory_config)?,
946 fast_adaptation: FastAdaptationEngine::new(config.adaptation_config)?,
947 performance_tracker: FewShotPerformanceTracker::new(config.tracking_config)?,
948 })
949 }
950
951 pub fn learn_few_shot(
953 &mut self,
954 task_data: TaskData<T>,
955 adaptation_config: AdaptationConfig,
956 ) -> Result<AdaptationResult<T>> {
957 let _start_time = Instant::now();
959
960 let task_representation = self.prototype_network.encode_task(&task_data)?;
962
963 let similar_tasks = self.memory_bank.retrieve_similar(&task_data, 5)?;
965
966 let strategy = self.select_adaptation_strategy(&task_data, &similar_tasks)?;
968
969 let mut adaptation_result = self.fast_adaptation.adapt_fast(
971 &mut *self.base_optimizer,
972 &task_data,
973 strategy,
974 &adaptation_config,
975 )?;
976
977 adaptation_result.task_representation = task_representation;
979
980 self.memory_bank
982 .store_episode(task_data.clone(), adaptation_result.clone())?;
983
984 self.performance_tracker
986 .record_performance(&adaptation_result)?;
987
988 self.prototype_network
990 .update_prototypes(&task_data, &adaptation_result)?;
991
992 Ok(adaptation_result)
993 }
994
995 fn select_adaptation_strategy(
996 &self,
997 task_data: &TaskData<T>,
998 _similar_tasks: &[MemoryEpisode<T>],
999 ) -> Result<AdaptationStrategyType> {
1000 match task_data.domain_info.difficulty_level {
1002 DifficultyLevel::Trivial | DifficultyLevel::Easy => Ok(AdaptationStrategyType::FOMAML),
1003 DifficultyLevel::Medium => Ok(AdaptationStrategyType::MAML),
1004 DifficultyLevel::Hard | DifficultyLevel::Expert => {
1005 Ok(AdaptationStrategyType::Prototypical)
1006 }
1007 DifficultyLevel::Extreme => Ok(AdaptationStrategyType::MemoryAugmented),
1008 }
1009 }
1010}
1011
/// Configuration for every component of a `FewShotLearningSystem`.
#[derive(Debug, Clone)]
pub struct FewShotConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Prototypical network settings.
    pub prototype_config: PrototypicalNetworkConfig<T>,

    /// Support-set manager settings.
    pub support_set_config: SupportSetManagerConfig,

    /// Task-similarity calculator settings.
    pub similarity_config: SimilarityCalculatorConfig<T>,

    /// Episodic memory bank settings.
    pub memory_config: MemoryBankConfig<T>,

    /// Fast-adaptation engine settings.
    pub adaptation_config: FastAdaptationConfig,

    /// Performance-tracker settings.
    pub tracking_config: TrackingConfig,
}

/// Construction settings for `PrototypicalNetwork`.
#[derive(Debug, Clone)]
pub struct PrototypicalNetworkConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Length of task encodings; must be greater than 0.
    pub embedding_dim: usize,

    /// Learning rate.
    pub learning_rate: T,

    /// Number of encoder layers (not read by `PrototypicalNetwork::new`).
    pub num_layers: usize,

    /// Hidden size; the encoder's input size is
    /// `max(hidden_dim, embedding_dim)`.
    pub hidden_dim: usize,
}
1049
1050impl<T: Float + Debug + Send + Sync + 'static> PrototypicalNetwork<T> {
1052 pub fn new(config: PrototypicalNetworkConfig<T>) -> Result<Self> {
1054 if config.embedding_dim == 0 {
1055 return Err(OptimError::InvalidConfig(
1056 "embedding_dim must be > 0".to_string(),
1057 ));
1058 }
1059 let input_dim = config.hidden_dim.max(config.embedding_dim);
1061 let layer = EncoderLayer {
1062 weights: Array2::zeros((input_dim, config.embedding_dim)),
1063 bias: Array1::zeros(config.embedding_dim),
1064 layer_type: LayerType::Linear,
1065 };
1066 Ok(Self {
1067 encoder: EncoderNetwork {
1068 layers: vec![layer],
1069 activation: ActivationFunction::ReLU,
1070 },
1071 prototypes: HashMap::new(),
1072 distance_metric: DistanceMetric::Euclidean,
1073 parameters: PrototypicalNetworkParams {
1074 embedding_dim: config.embedding_dim,
1075 learning_rate: config.learning_rate,
1076 temperature: T::one(),
1077 prototype_update_rate: scirs2_core::numeric::NumCast::from(0.1)
1078 .unwrap_or_else(|| T::one()),
1079 },
1080 })
1081 }
1082
1083 pub fn from_dims(embedding_dim: usize, _num_classes: usize) -> Result<Self> {
1085 if embedding_dim == 0 {
1086 return Err(OptimError::InvalidConfig(
1087 "embedding_dim must be > 0".to_string(),
1088 ));
1089 }
1090 let layer = EncoderLayer {
1091 weights: Array2::zeros((embedding_dim, embedding_dim)),
1092 bias: Array1::zeros(embedding_dim),
1093 layer_type: LayerType::Linear,
1094 };
1095 Ok(Self {
1096 encoder: EncoderNetwork {
1097 layers: vec![layer],
1098 activation: ActivationFunction::ReLU,
1099 },
1100 prototypes: HashMap::new(),
1101 distance_metric: DistanceMetric::Euclidean,
1102 parameters: PrototypicalNetworkParams {
1103 embedding_dim,
1104 learning_rate: scirs2_core::numeric::NumCast::from(0.01)
1105 .unwrap_or_else(|| T::one()),
1106 temperature: T::one(),
1107 prototype_update_rate: scirs2_core::numeric::NumCast::from(0.1)
1108 .unwrap_or_else(|| T::one()),
1109 },
1110 })
1111 }
1112
1113 pub fn embedding_dim(&self) -> usize {
1115 self.parameters.embedding_dim
1116 }
1117
1118 pub fn encoder_layers(&self) -> &[EncoderLayer<T>] {
1120 &self.encoder.layers
1121 }
1122
1123 pub fn prototypes_mut(&mut self) -> &mut HashMap<String, Prototype<T>> {
1125 &mut self.prototypes
1126 }
1127
1128 pub fn prototypes(&self) -> &HashMap<String, Prototype<T>> {
1130 &self.prototypes
1131 }
1132
1133 pub fn distance_metric(&self) -> &DistanceMetric {
1135 &self.distance_metric
1136 }
1137
1138 pub fn encode_task(&self, task_data: &TaskData<T>) -> Result<Array1<T>> {
1140 if task_data.support_set.examples.is_empty() {
1142 return Err(OptimError::InsufficientData(
1143 "No support examples for encoding".to_string(),
1144 ));
1145 }
1146 let dim = self.parameters.embedding_dim;
1147 let mut sum = Array1::<T>::zeros(dim);
1148 let count = task_data.support_set.examples.len();
1149 for ex in &task_data.support_set.examples {
1150 let feat = &ex.features;
1151 let len = feat.len().min(dim);
1152 for i in 0..len {
1153 sum[i] = sum[i] + feat[i];
1154 }
1155 }
1156 let count_t = scirs2_core::numeric::NumCast::from(count).unwrap_or_else(|| T::one());
1157 for i in 0..dim {
1158 sum[i] = sum[i] / count_t;
1159 }
1160 Ok(sum)
1161 }
1162
1163 pub fn update_prototypes(
1165 &mut self,
1166 task_data: &TaskData<T>,
1167 _result: &AdaptationResult<T>,
1168 ) -> Result<()> {
1169 let task_repr = self.encode_task(task_data)?;
1170 let task_id = task_data.task_id.clone();
1171 let update_rate = self.parameters.prototype_update_rate;
1172 let one_minus = T::one() - update_rate;
1173
1174 if let Some(proto) = self.prototypes.get_mut(&task_id) {
1175 let dim = proto.vector.len().min(task_repr.len());
1177 for i in 0..dim {
1178 proto.vector[i] = one_minus * proto.vector[i] + update_rate * task_repr[i];
1179 }
1180 proto.example_count += task_data.support_set.examples.len();
1181 proto.last_updated = std::time::SystemTime::now();
1182 proto.metadata.update_count += 1;
1183 } else {
1184 let proto = Prototype {
1185 vector: task_repr,
1186 confidence: T::one(),
1187 example_count: task_data.support_set.examples.len(),
1188 last_updated: std::time::SystemTime::now(),
1189 metadata: PrototypeMetadata {
1190 task_category: task_id.clone(),
1191 domain: task_data.domain_info.domain_type,
1192 created_at: std::time::SystemTime::now(),
1193 update_count: 1,
1194 },
1195 };
1196 self.prototypes.insert(task_id, proto);
1197 }
1198 Ok(())
1199 }
1200}
1201
1202impl<T: Float + Debug + Send + Sync + 'static> SupportSetManager<T> {
1203 pub fn new(config: SupportSetManagerConfig) -> Result<Self> {
1205 if config.max_support_size == 0 {
1206 return Err(OptimError::InvalidConfig(
1207 "max_support_size must be > 0".to_string(),
1208 ));
1209 }
1210 Ok(Self {
1211 support_sets: HashMap::new(),
1212 selection_strategy: SupportSetSelectionStrategy::DiversityBased,
1213 config,
1214 })
1215 }
1216
1217 pub fn from_max_size(max_support_size: usize) -> Result<Self> {
1219 Self::new(SupportSetManagerConfig {
1220 min_support_size: 1,
1221 max_support_size,
1222 quality_threshold: 0.5,
1223 enable_augmentation: true,
1224 cache_support_sets: true,
1225 })
1226 }
1227
1228 pub fn max_support_size(&self) -> usize {
1230 self.config.max_support_size
1231 }
1232
1233 pub fn config(&self) -> &SupportSetManagerConfig {
1235 &self.config
1236 }
1237}
1238
1239impl<T: Float + Debug + Send + Sync + 'static> TaskSimilarityCalculator<T> {
1240 pub fn new(config: SimilarityCalculatorConfig<T>) -> Result<Self> {
1242 Ok(Self {
1243 similarity_metrics: Vec::new(),
1244 metric_weights: HashMap::new(),
1245 similarity_cache: HashMap::new(),
1246 config,
1247 })
1248 }
1249
1250 pub fn default_new() -> Result<Self> {
1252 Self::new(SimilarityCalculatorConfig {
1253 enable_caching: true,
1254 cache_size_limit: 1000,
1255 similarity_threshold: scirs2_core::numeric::NumCast::from(0.5)
1256 .unwrap_or_else(|| T::zero()),
1257 use_metadata: true,
1258 })
1259 }
1260
1261 pub fn similarity_cache_mut(&mut self) -> &mut HashMap<(String, String), T> {
1263 &mut self.similarity_cache
1264 }
1265
1266 pub fn caching_enabled(&self) -> bool {
1268 self.config.enable_caching
1269 }
1270}
1271
1272impl<T: Float + Debug + Send + Sync + 'static> EpisodicMemoryBank<T> {
1273 pub fn new(config: MemoryBankConfig<T>) -> Result<Self> {
1275 if config.max_memory_size == 0 {
1276 return Err(OptimError::InvalidConfig(
1277 "max_memory_size must be > 0".to_string(),
1278 ));
1279 }
1280 Ok(Self {
1281 episodes: VecDeque::new(),
1282 config,
1283 usage_stats: MemoryUsageStats {
1284 total_episodes: 0,
1285 memory_utilization: 0.0,
1286 hit_rate: 0.0,
1287 avg_retrieval_time: Duration::from_secs(0),
1288 },
1289 })
1290 }
1291
1292 pub fn from_capacity(capacity: usize) -> Result<Self> {
1294 Self::new(MemoryBankConfig {
1295 max_memory_size: capacity,
1296 eviction_policy: EvictionPolicy::Performance,
1297 similarity_threshold: scirs2_core::numeric::NumCast::from(0.3)
1298 .unwrap_or_else(|| T::zero()),
1299 enable_compression: false,
1300 })
1301 }
1302
1303 pub fn capacity(&self) -> usize {
1305 self.config.max_memory_size
1306 }
1307
1308 pub fn len(&self) -> usize {
1310 self.episodes.len()
1311 }
1312
1313 pub fn is_empty(&self) -> bool {
1315 self.episodes.is_empty()
1316 }
1317
1318 pub fn episodes_mut(&mut self) -> &mut VecDeque<MemoryEpisode<T>> {
1320 &mut self.episodes
1321 }
1322
1323 pub fn episodes(&self) -> &VecDeque<MemoryEpisode<T>> {
1325 &self.episodes
1326 }
1327
1328 pub fn usage_stats_mut(&mut self) -> &mut MemoryUsageStats {
1330 &mut self.usage_stats
1331 }
1332
1333 pub fn usage_stats(&self) -> &MemoryUsageStats {
1335 &self.usage_stats
1336 }
1337
1338 pub fn eviction_policy(&self) -> EvictionPolicy {
1340 self.config.eviction_policy
1341 }
1342
1343 pub fn retrieve_similar(
1345 &self,
1346 task_data: &TaskData<T>,
1347 k: usize,
1348 ) -> Result<Vec<MemoryEpisode<T>>> {
1349 if self.episodes.is_empty() {
1350 return Ok(Vec::new());
1351 }
1352 let count = k.min(self.episodes.len());
1355 let result: Vec<MemoryEpisode<T>> =
1356 self.episodes.iter().rev().take(count).cloned().collect();
1357 let _ = task_data; Ok(result)
1359 }
1360
1361 pub fn store_episode(
1363 &mut self,
1364 task_data: TaskData<T>,
1365 result: AdaptationResult<T>,
1366 ) -> Result<()> {
1367 let episode = MemoryEpisode {
1368 episode_id: format!("ep_{}", self.usage_stats.total_episodes),
1369 task_data,
1370 adaptation_result: result,
1371 timestamp: std::time::SystemTime::now(),
1372 metadata: EpisodeMetadata {
1373 difficulty: DifficultyLevel::Medium,
1374 domain: DomainType::Optimization,
1375 success_rate: 0.0,
1376 tags: Vec::new(),
1377 },
1378 access_count: 0,
1379 };
1380
1381 if self.episodes.len() >= self.config.max_memory_size {
1382 self.episodes.pop_front();
1383 }
1384 self.episodes.push_back(episode);
1385 self.usage_stats.total_episodes += 1;
1386 self.usage_stats.memory_utilization =
1387 self.episodes.len() as f64 / self.config.max_memory_size as f64;
1388 Ok(())
1389 }
1390}
1391
1392impl<T: Float + Debug + Send + Sync + 'static> FastAdaptationEngine<T> {
1393 pub fn new(config: FastAdaptationConfig) -> Result<Self> {
1395 Ok(Self {
1396 algorithms: Vec::new(),
1397 config,
1398 })
1399 }
1400
1401 pub fn from_params(inner_lr: T, adaptation_steps: usize) -> Result<Self> {
1403 let _ = (inner_lr, adaptation_steps);
1404 Self::new(FastAdaptationConfig {
1405 enable_caching: true,
1406 enable_prediction: true,
1407 max_adaptation_time: Duration::from_secs(60),
1408 _performance_threshold: 0.8,
1409 })
1410 }
1411
1412 pub fn config(&self) -> &FastAdaptationConfig {
1414 &self.config
1415 }
1416
1417 pub fn adapt_fast(
1419 &mut self,
1420 _optimizer: &mut dyn FewShotOptimizer<T>,
1421 _task_data: &TaskData<T>,
1422 _strategy: AdaptationStrategyType,
1423 _config: &AdaptationConfig,
1424 ) -> Result<AdaptationResult<T>> {
1425 Ok(AdaptationResult {
1426 adapted_state: OptimizerState {
1427 parameters: Array1::zeros(1),
1428 gradients: Array1::zeros(1),
1429 momentum: None,
1430 hidden_states: HashMap::new(),
1431 memory_buffers: HashMap::new(),
1432 step: 0,
1433 step_count: 0,
1434 loss: None,
1435 learning_rate: scirs2_core::numeric::NumCast::from(0.001)
1436 .unwrap_or_else(|| T::one()),
1437 metadata: super::StateMetadata {
1438 task_id: None,
1439 optimizer_type: None,
1440 version: "1.0".to_string(),
1441 timestamp: std::time::SystemTime::now(),
1442 checksum: 0,
1443 compression_level: 0,
1444 custom_data: HashMap::new(),
1445 },
1446 },
1447 performance: AdaptationPerformance {
1448 query_performance: scirs2_core::numeric::NumCast::from(0.85)
1449 .unwrap_or_else(|| T::zero()),
1450 support_performance: scirs2_core::numeric::NumCast::from(0.90)
1451 .unwrap_or_else(|| T::zero()),
1452 adaptation_speed: 5,
1453 final_loss: scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
1454 improvement: scirs2_core::numeric::NumCast::from(0.25).unwrap_or_else(|| T::zero()),
1455 stability: scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()),
1456 },
1457 task_representation: Array1::zeros(128),
1458 adaptation_trajectory: Vec::new(),
1459 resource_usage: ResourceUsage {
1460 total_time: Duration::from_secs(15),
1461 peak_memory_mb: scirs2_core::numeric::NumCast::from(256.0)
1462 .unwrap_or_else(|| T::zero()),
1463 compute_cost: scirs2_core::numeric::NumCast::from(5.0).unwrap_or_else(|| T::zero()),
1464 energy_consumption: scirs2_core::numeric::NumCast::from(0.05)
1465 .unwrap_or_else(|| T::zero()),
1466 },
1467 })
1468 }
1469}
1470
1471impl<T: Float + Debug + Send + Sync + 'static> FewShotPerformanceTracker<T> {
1472 pub fn new(config: TrackingConfig) -> Result<Self> {
1474 Ok(Self {
1475 performance_history: VecDeque::new(),
1476 metrics: Vec::new(),
1477 config,
1478 stats: PerformanceStats {
1479 best_performance: T::zero(),
1480 average_performance: T::zero(),
1481 performance_variance: T::zero(),
1482 improvement_rate: T::zero(),
1483 success_rate: T::zero(),
1484 },
1485 })
1486 }
1487
1488 pub fn record_performance(&mut self, result: &AdaptationResult<T>) -> Result<()> {
1490 let record = PerformanceRecord {
1491 task_id: String::new(),
1492 performance: result.performance.query_performance,
1493 adaptation_time: result.resource_usage.total_time,
1494 strategy_used: String::new(),
1495 timestamp: std::time::SystemTime::now(),
1496 additional_metrics: HashMap::new(),
1497 };
1498
1499 if self.performance_history.len() >= self.config.max_history_size {
1500 self.performance_history.pop_front();
1501 }
1502 self.performance_history.push_back(record);
1503
1504 if !self.performance_history.is_empty() {
1506 let mut sum = T::zero();
1507 let mut best = T::neg_infinity();
1508 for r in &self.performance_history {
1509 sum = sum + r.performance;
1510 if r.performance > best {
1511 best = r.performance;
1512 }
1513 }
1514 let count_t = scirs2_core::numeric::NumCast::from(self.performance_history.len())
1515 .unwrap_or_else(|| T::one());
1516 self.stats.average_performance = sum / count_t;
1517 self.stats.best_performance = best;
1518 }
1519 Ok(())
1520 }
1521}
1522
/// Descriptive metadata of a task.
#[derive(Debug, Clone)]
pub struct TaskMetadata {
    /// Human-readable task name.
    pub task_name: String,
    /// Task domain.
    pub domain: DomainType,
    /// Task difficulty.
    pub difficulty: DifficultyLevel,
    /// Creation time.
    pub created_at: std::time::SystemTime,
}

/// Provenance metadata of a single example.
#[derive(Debug, Clone)]
pub struct ExampleMetadata {
    /// Where the example came from.
    pub source: String,
    /// Quality score (scale not defined in this file).
    pub quality_score: f64,
    /// Creation time.
    pub created_at: std::time::SystemTime,
}

/// Summary statistics of a support set.
#[derive(Debug, Clone)]
pub struct SupportSetStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Per-feature mean.
    pub mean: Array1<T>,
    /// Per-feature variance.
    pub variance: Array1<T>,
    /// Number of examples.
    pub size: usize,
    /// Diversity score.
    pub diversity_score: T,
}

/// Summary statistics of a query set.
#[derive(Debug, Clone)]
pub struct QuerySetStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Per-feature mean.
    pub mean: Array1<T>,
    /// Per-feature variance.
    pub variance: Array1<T>,
    /// Number of examples.
    pub size: usize,
}

/// Metrics that can be requested when evaluating on a query set.
#[derive(Debug, Clone, Copy)]
pub enum EvaluationMetric {
    Accuracy,
    Loss,
    F1Score,
    Precision,
    Recall,
    AUC,
    MSE,
    MAE,
    FinalPerformance,
    Efficiency,
    TrainingTime,
}
1569
#[cfg(test)]
mod tests {
    use super::*;

    /// A single-example support set keeps its example and its statistics.
    #[test]
    fn test_support_set_creation() {
        let example_meta = ExampleMetadata {
            source: "test".to_string(),
            quality_score: 0.9,
            created_at: std::time::SystemTime::now(),
        };
        let only_example = SupportExample {
            features: Array1::from_vec(vec![1.0, 2.0, 3.0]),
            target: 0.5,
            weight: 1.0,
            context: HashMap::new(),
            metadata: example_meta,
        };
        let task_metadata = TaskMetadata {
            task_name: "test_task".to_string(),
            domain: DomainType::Optimization,
            difficulty: DifficultyLevel::Easy,
            created_at: std::time::SystemTime::now(),
        };
        let statistics = SupportSetStatistics {
            mean: Array1::from_vec(vec![1.0, 2.0, 3.0]),
            variance: Array1::from_vec(vec![0.1, 0.1, 0.1]),
            size: 1,
            diversity_score: 0.8,
        };

        let support_set = SupportSet::<f64> {
            examples: vec![only_example],
            task_metadata,
            statistics,
            temporal_order: None,
        };

        assert_eq!(support_set.examples.len(), 1);
        assert_eq!(support_set.statistics.size, 1);
    }

    /// An adaptation config stores the values it was built with.
    #[test]
    fn test_adaptation_config() {
        let regularization = RegularizationConfig {
            l2_strength: 0.001,
            dropout_rate: 0.1,
            gradient_clip: Some(1.0),
            task_regularization: HashMap::new(),
        };
        let resource_constraints = ResourceConstraints {
            max_time: Duration::from_secs(60),
            max_memory_mb: 1024,
            max_compute_budget: 100.0,
        };

        let config = AdaptationConfig {
            adaptation_steps: 10,
            adaptation_lr: 0.01,
            strategy: AdaptationStrategyType::MAML,
            early_stopping: None,
            regularization,
            resource_constraints,
        };

        assert_eq!(config.adaptation_steps, 10);
        assert_eq!(config.adaptation_lr, 0.01);
        assert!(matches!(config.strategy, AdaptationStrategyType::MAML));
    }

    /// Domain info keeps its type, characteristics and constraints.
    #[test]
    fn test_domain_info() {
        let characteristics = DomainCharacteristics {
            input_dim: 784,
            output_dim: 10,
            temporal: false,
            stochasticity: 0.1,
            noise_level: 0.05,
            sparsity: 0.0,
        };
        let latency_limit = DomainConstraint {
            constraint_type: ConstraintType::LatencyRequirement,
            description: "Max 100ms inference time".to_string(),
            enforcement: ConstraintEnforcement::Hard,
        };

        let domain_info = DomainInfo {
            domain_type: DomainType::ComputerVision,
            characteristics,
            difficulty_level: DifficultyLevel::Medium,
            constraints: vec![latency_limit],
        };

        assert!(matches!(
            domain_info.domain_type,
            DomainType::ComputerVision
        ));
        assert_eq!(domain_info.characteristics.input_dim, 784);
        assert_eq!(domain_info.constraints.len(), 1);
    }
}