Skip to main content

optirs_learned/
few_shot.rs

1// Few-Shot Learning Enhancement for Optimizer Meta-Learning
2//
3// This module implements advanced few-shot learning techniques specifically designed
4// for quickly adapting optimizers to new tasks with minimal data. It includes
5// prototypical networks, meta-learning approaches, and rapid adaptation mechanisms.
6
7#[allow(dead_code)]
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::numeric::Float;
10use std::collections::{HashMap, VecDeque};
11use std::fmt::Debug;
12use std::time::{Duration, Instant};
13
14use super::OptimizerState;
15use crate::error::{OptimError, Result};
16
/// Few-shot learning coordinator for optimizer adaptation.
///
/// Owns a base [`FewShotOptimizer`] plus the supporting components that
/// [`FewShotLearningSystem::learn_few_shot`] drives:
/// - task encoding;
/// - episodic memory retrieval and storage;
/// - fast adaptation;
/// - performance tracking;
/// - prototype updates.
///
/// All fields are private; construct instances with [`FewShotLearningSystem::new`].
pub struct FewShotLearningSystem<T: Float + Debug + Send + Sync + 'static> {
    /// Base meta-learned optimizer; handed mutably to the fast-adaptation
    /// engine during `learn_few_shot`.
    base_optimizer: Box<dyn FewShotOptimizer<T>>,

    /// Prototypical network used to encode tasks and maintain prototypes
    prototype_network: PrototypicalNetwork<T>,

    /// Support set manager
    support_set_manager: SupportSetManager<T>,

    /// Registered adaptation strategies (empty after `new`)
    adaptation_strategies: Vec<Box<dyn AdaptationStrategy<T>>>,

    /// Task similarity calculator
    similarity_calculator: TaskSimilarityCalculator<T>,

    /// Episodic memory of past tasks and their adaptation results
    memory_bank: EpisodicMemoryBank<T>,

    /// Engine that performs the actual adaptation of `base_optimizer`
    fast_adaptation: FastAdaptationEngine<T>,

    /// Records the outcome of each `learn_few_shot` call
    performance_tracker: FewShotPerformanceTracker<T>,
}
43
/// Interface for meta-learned optimizers that can adapt to a new task from a
/// handful of examples.
///
/// Implementors must also be `Send + Sync` (supertrait bound).
pub trait FewShotOptimizer<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Adapts to a new task using the few labelled examples in `support_set`.
    ///
    /// `query_set` is the task's evaluation set. `adaptation_config` supplies
    /// the following limits and settings:
    /// - step count and learning rate;
    /// - strategy;
    /// - early stopping;
    /// - regularization;
    /// - resource limits.
    ///
    /// On success, returns the adapted state with performance, trajectory and
    /// resource-usage information.
    fn adapt_few_shot(
        &mut self,
        support_set: &SupportSet<T>,
        query_set: &QuerySet<T>,
        adaptation_config: &AdaptationConfig,
    ) -> Result<AdaptationResult<T>>;

    /// Returns an embedding vector that represents `taskdata`.
    fn get_task_representation(&self, taskdata: &TaskData<T>) -> Result<Array1<T>>;

    /// Computes the adaptation loss for the given support/query split.
    fn compute_adaptation_loss(
        &self,
        support_set: &SupportSet<T>,
        query_set: &QuerySet<T>,
    ) -> Result<T>;

    /// Applies `metagradients` to the optimizer's meta-parameters.
    fn update_meta_parameters(&mut self, metagradients: &MetaGradients<T>) -> Result<()>;

    /// Returns a snapshot of the state used for cross-task transfer.
    fn get_transfer_state(&self) -> TransferState<T>;

    /// Loads a transfer state, presumably one produced by `get_transfer_state`.
    fn load_transfer_state(&mut self, state: TransferState<T>) -> Result<()>;
}
73
/// Support set for few-shot learning: the labelled examples a task is
/// adapted on.
#[derive(Debug, Clone)]
pub struct SupportSet<T: Float + Debug + Send + Sync + 'static> {
    /// Support examples
    pub examples: Vec<SupportExample<T>>,

    /// Task metadata
    pub task_metadata: TaskMetadata,

    /// Support set statistics
    pub statistics: SupportSetStatistics<T>,

    /// Temporal ordering of the examples, if applicable (presumably indices
    /// into `examples` — TODO confirm with producers of this type)
    pub temporal_order: Option<Vec<usize>>,
}
89
/// A single labelled example in a [`SupportSet`].
#[derive(Debug, Clone)]
pub struct SupportExample<T: Float + Debug + Send + Sync + 'static> {
    /// Input features
    pub features: Array1<T>,

    /// Target output (one scalar per example)
    pub target: T,

    /// Example weight/importance
    pub weight: T,

    /// Named context values attached to this example
    pub context: HashMap<String, T>,

    /// Example metadata
    pub metadata: ExampleMetadata,
}
108
/// Query set used to evaluate a task after adaptation.
#[derive(Debug, Clone)]
pub struct QuerySet<T: Float + Debug + Send + Sync + 'static> {
    /// Query examples
    pub examples: Vec<QueryExample<T>>,

    /// Query statistics
    pub statistics: QuerySetStatistics<T>,

    /// Evaluation metrics to report for this query set
    pub eval_metrics: Vec<EvaluationMetric>,
}
121
/// A single example in a [`QuerySet`].
#[derive(Debug, Clone)]
pub struct QueryExample<T: Float + Debug + Send + Sync + 'static> {
    /// Input features
    pub features: Array1<T>,

    /// True target used for evaluation, when one is available
    pub true_target: Option<T>,

    /// Query weight
    pub weight: T,

    /// Named context values attached to this query
    pub context: HashMap<String, T>,
}
137
/// Everything describing one few-shot task.
///
/// This covers the task identifier, the support and query sets, task-specific
/// parameters and domain information.
#[derive(Debug, Clone)]
pub struct TaskData<T: Float + Debug + Send + Sync + 'static> {
    /// Task identifier
    pub task_id: String,

    /// Support set (examples adapted on)
    pub support_set: SupportSet<T>,

    /// Query set (examples evaluated on)
    pub query_set: QuerySet<T>,

    /// Task-specific parameters, keyed by name
    pub task_params: HashMap<String, T>,

    /// Task domain information; its `difficulty_level` drives strategy
    /// selection in `FewShotLearningSystem`.
    pub domain_info: DomainInfo,
}
156
/// Domain information attached to a [`TaskData`].
#[derive(Debug, Clone)]
pub struct DomainInfo {
    /// Domain type
    pub domain_type: DomainType,

    /// Domain characteristics
    pub characteristics: DomainCharacteristics,

    /// Expected difficulty
    pub difficulty_level: DifficultyLevel,

    /// Domain-specific constraints
    pub constraints: Vec<DomainConstraint>,
}
172
/// Broad application domain a task comes from.
#[derive(Debug, Clone, Copy)]
pub enum DomainType {
    ComputerVision,
    NaturalLanguageProcessing,
    ReinforcementLearning,
    TimeSeriesForecasting,
    ScientificComputing,
    Optimization,
    ControlSystems,
    GamePlaying,
    Robotics,
    Healthcare,
}
187
/// Quantitative description of a task domain.
///
/// NOTE(review): this module does not define the scale of the `f64` fields
/// (presumably 0.0..=1.0 — TODO confirm).
#[derive(Debug, Clone)]
pub struct DomainCharacteristics {
    /// Input dimensionality
    pub input_dim: usize,

    /// Output dimensionality
    pub output_dim: usize,

    /// Whether the domain has temporal dependencies
    pub temporal: bool,

    /// Stochasticity level
    pub stochasticity: f64,

    /// Noise level
    pub noise_level: f64,

    /// Data sparsity
    pub sparsity: f64,
}
209
/// Expected task difficulty, with variants declared from `Trivial` to `Extreme`.
///
/// `FewShotLearningSystem` maps each level to an [`AdaptationStrategyType`].
/// No ordering traits are derived, so levels cannot be compared with `<`.
#[derive(Debug, Clone, Copy)]
pub enum DifficultyLevel {
    Trivial,
    Easy,
    Medium,
    Hard,
    Expert,
    Extreme,
}
220
/// A single constraint imposed by a task's domain.
#[derive(Debug, Clone)]
pub struct DomainConstraint {
    /// Constraint type
    pub constraint_type: ConstraintType,

    /// Human-readable constraint description
    pub description: String,

    /// Enforcement level
    pub enforcement: ConstraintEnforcement,
}
233
/// Category of a [`DomainConstraint`].
#[derive(Debug, Clone, Copy)]
pub enum ConstraintType {
    ResourceLimit,
    TemporalConstraint,
    AccuracyRequirement,
    LatencyRequirement,
    MemoryConstraint,
    EnergyConstraint,
    SafetyConstraint,
}
245
/// How strictly a [`DomainConstraint`] is meant to be enforced.
#[derive(Debug, Clone, Copy)]
pub enum ConstraintEnforcement {
    Hard,
    Soft,
    Advisory,
}
253
/// Settings for a single few-shot adaptation run.
#[derive(Debug, Clone)]
pub struct AdaptationConfig {
    /// Number of adaptation steps
    pub adaptation_steps: usize,

    /// Learning rate for adaptation
    pub adaptation_lr: f64,

    /// Adaptation strategy
    pub strategy: AdaptationStrategyType,

    /// Early stopping configuration (`None` when not configured)
    pub early_stopping: Option<EarlyStoppingConfig>,

    /// Regularization parameters
    pub regularization: RegularizationConfig,

    /// Resource constraints
    pub resource_constraints: ResourceConstraints,
}
275
/// Family of meta-learning algorithm used to adapt to a task.
///
/// `FewShotLearningSystem` chooses among `FOMAML`, `MAML`, `Prototypical`
/// and `MemoryAugmented` based on the task's difficulty level.
#[derive(Debug, Clone, Copy)]
pub enum AdaptationStrategyType {
    /// Model-Agnostic Meta-Learning (MAML)
    MAML,

    /// First-Order MAML (FOMAML)
    FOMAML,

    /// Prototypical Networks
    Prototypical,

    /// Matching Networks
    Matching,

    /// Relation Networks
    Relation,

    /// Meta-SGD
    MetaSGD,

    /// Learned optimizer approach
    LearnedOptimizer,

    /// Gradient-based meta-learning
    GradientBased,

    /// Memory-augmented networks
    MemoryAugmented,
}
306
/// Early stopping configuration for an adaptation run.
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Patience (steps without improvement)
    pub patience: usize,

    /// Minimum improvement that counts as progress
    pub min_improvement: f64,

    /// Validation frequency (presumably in steps — TODO confirm)
    pub validation_frequency: usize,
}
319
/// Regularization applied during adaptation.
#[derive(Debug, Clone)]
pub struct RegularizationConfig {
    /// L2 regularization strength
    pub l2_strength: f64,

    /// Dropout rate
    pub dropout_rate: f64,

    /// Gradient clipping threshold (`None` presumably disables clipping)
    pub gradient_clip: Option<f64>,

    /// Task-specific regularization strengths, keyed by name (presumably
    /// task id — TODO confirm)
    pub task_regularization: HashMap<String, f64>,
}
335
/// Resource limits for a single adaptation run.
#[derive(Debug, Clone)]
pub struct ResourceConstraints {
    /// Maximum adaptation time
    pub max_time: Duration,

    /// Maximum memory usage (MB)
    pub max_memory_mb: usize,

    /// Maximum computational budget (units not defined in this module)
    pub max_compute_budget: f64,
}
348
/// Outcome of adapting an optimizer to one task.
#[derive(Debug, Clone)]
pub struct AdaptationResult<T: Float + Debug + Send + Sync + 'static> {
    /// Adapted optimizer state
    pub adapted_state: OptimizerState<T>,

    /// Adaptation performance
    pub performance: AdaptationPerformance<T>,

    /// Task representation; `FewShotLearningSystem::learn_few_shot`
    /// overwrites this with the prototypical-network encoding of the task.
    pub task_representation: Array1<T>,

    /// Per-step record of the adaptation
    pub adaptation_trajectory: Vec<AdaptationStep<T>>,

    /// Resource usage
    pub resource_usage: ResourceUsage<T>,
}
367
/// Summary performance metrics for an adaptation run.
#[derive(Debug, Clone)]
pub struct AdaptationPerformance<T: Float + Debug + Send + Sync + 'static> {
    /// Query set performance
    pub query_performance: T,

    /// Support set performance
    pub support_performance: T,

    /// Adaptation speed (steps to convergence)
    pub adaptation_speed: usize,

    /// Final loss
    pub final_loss: T,

    /// Performance improvement
    pub improvement: T,

    /// Stability measure
    pub stability: T,
}
389
/// One entry in an adaptation trajectory.
#[derive(Debug, Clone)]
pub struct AdaptationStep<T: Float + Debug + Send + Sync + 'static> {
    /// Step number
    pub step: usize,

    /// Loss at this step
    pub loss: T,

    /// Performance at this step
    pub performance: T,

    /// Gradient norm
    pub gradient_norm: T,

    /// Wall-clock time spent on this step
    pub step_time: Duration,
}
408
/// Resources consumed by an adaptation run.
#[derive(Debug, Clone)]
pub struct ResourceUsage<T: Float + Debug + Send + Sync + 'static> {
    /// Total time taken
    pub total_time: Duration,

    /// Peak memory usage (MB), stored as `T`
    pub peak_memory_mb: T,

    /// Computational cost
    pub compute_cost: T,

    /// Energy consumption
    pub energy_consumption: T,
}
424
/// Meta-gradients consumed by [`FewShotOptimizer::update_meta_parameters`].
///
/// The maps are keyed by name (presumably parameter / component names —
/// TODO confirm).
#[derive(Debug, Clone)]
pub struct MetaGradients<T: Float + Debug + Send + Sync + 'static> {
    /// Parameter gradients
    pub param_gradients: HashMap<String, Array1<T>>,

    /// Learning rate gradients
    pub lr_gradients: HashMap<String, T>,

    /// Architecture gradients
    pub arch_gradients: HashMap<String, Array1<T>>,

    /// Meta-gradient norm
    pub gradient_norm: T,
}
440
/// Optimizer state exchanged through
/// [`FewShotOptimizer::get_transfer_state`] / `load_transfer_state`
/// for cross-task transfer.
#[derive(Debug, Clone)]
pub struct TransferState<T: Float + Debug + Send + Sync + 'static> {
    /// Learned representations, keyed by name
    pub representations: HashMap<String, Array1<T>>,

    /// Meta-parameters, keyed by name
    pub meta_parameters: HashMap<String, Array1<T>>,

    /// Task embeddings (presumably one row per task — TODO confirm)
    pub task_embeddings: Array2<T>,

    /// Transfer statistics
    pub transfer_stats: TransferStatistics<T>,
}
456
/// Statistics describing how well a transfer between tasks worked.
#[derive(Debug, Clone)]
pub struct TransferStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Source task performance
    pub source_performance: T,

    /// Target task performance
    pub target_performance: T,

    /// Transfer efficiency
    pub transfer_efficiency: T,

    /// Adaptation steps saved
    pub steps_saved: usize,
}
472
/// Prototypical network for task representation.
///
/// Holds the following:
/// - an encoder;
/// - a map of named task prototypes;
/// - the distance metric used to compare embeddings against those prototypes;
/// - the network hyper-parameters.
///
/// The constructors in this module build a single zero-initialised `Linear`
/// encoder layer with ReLU activation and select `DistanceMetric::Euclidean`.
pub struct PrototypicalNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// Encoder network.
    /// NOTE(review): `encode_task` currently averages raw support features
    /// and does not run them through this encoder.
    encoder: EncoderNetwork<T>,

    /// Prototype storage, keyed by name (presumably task category or id —
    /// TODO confirm)
    prototypes: HashMap<String, Prototype<T>>,

    /// Distance metric
    distance_metric: DistanceMetric,

    /// Network parameters
    parameters: PrototypicalNetworkParams<T>,
}
487
/// Encoder network for feature extraction: a stack of layers sharing one
/// activation function.
#[derive(Debug)]
pub struct EncoderNetwork<T: Float + Debug + Send + Sync + 'static> {
    /// Network layers
    layers: Vec<EncoderLayer<T>>,

    /// Activation function
    activation: ActivationFunction,
}
497
/// A single encoder layer: a weight matrix, a bias vector and a layer kind.
#[derive(Debug)]
pub struct EncoderLayer<T: Float + Debug + Send + Sync + 'static> {
    /// Weight matrix. For layers built by `PrototypicalNetwork::new` the
    /// shape is `(input_width, embedding_dim)`.
    pub weights: Array2<T>,

    /// Bias vector (length `embedding_dim` in the constructors)
    pub bias: Array1<T>,

    /// Layer type
    pub layer_type: LayerType,
}
510
/// Kind of computation an [`EncoderLayer`] performs.
#[derive(Debug, Clone, Copy)]
pub enum LayerType {
    Linear,
    Convolutional,
    Recurrent,
    Attention,
    Residual,
}
520
/// Activation function applied by an [`EncoderNetwork`].
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    ReLU,
    Tanh,
    Sigmoid,
    Swish,
    GELU,
    Mish,
}
531
/// A stored task prototype: an embedding vector plus bookkeeping.
#[derive(Debug, Clone)]
pub struct Prototype<T: Float + Debug + Send + Sync + 'static> {
    /// Prototype vector
    pub vector: Array1<T>,

    /// Prototype confidence
    pub confidence: T,

    /// Number of examples used to build this prototype
    pub example_count: usize,

    /// Last update time
    pub last_updated: std::time::SystemTime,

    /// Prototype metadata
    pub metadata: PrototypeMetadata,
}
550
/// Descriptive metadata attached to a [`Prototype`].
#[derive(Debug, Clone)]
pub struct PrototypeMetadata {
    /// Task category
    pub task_category: String,

    /// Domain type
    pub domain: DomainType,

    /// Creation timestamp
    pub created_at: std::time::SystemTime,

    /// Number of times the prototype has been updated
    pub update_count: usize,
}
566
/// Distance metric used to compare embeddings with prototypes.
///
/// The `PrototypicalNetwork` constructors select `Euclidean`.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    Euclidean,
    Cosine,
    Manhattan,
    Mahalanobis,
    Learned,
}
576
/// Hyper-parameters held by a [`PrototypicalNetwork`].
#[derive(Debug, Clone)]
pub struct PrototypicalNetworkParams<T: Float + Debug + Send + Sync + 'static> {
    /// Embedding dimension (width of task encodings)
    pub embedding_dim: usize,

    /// Learning rate
    pub learning_rate: T,

    /// Temperature parameter (the constructors set this to one)
    pub temperature: T,

    /// Prototype update rate. The constructors set this to 0.1, or to one if
    /// that value cannot be converted to `T`.
    pub prototype_update_rate: T,
}
592
/// Keeps the current support sets and the policy for selecting them.
pub struct SupportSetManager<T: Float + Debug + Send + Sync + 'static> {
    /// Current support sets, keyed by name (presumably task id — TODO confirm)
    support_sets: HashMap<String, SupportSet<T>>,

    /// Support set selection strategy
    selection_strategy: SupportSetSelectionStrategy,

    /// Manager configuration
    config: SupportSetManagerConfig,
}
604
/// Policy used by [`SupportSetManager`] to choose support examples.
#[derive(Debug, Clone, Copy)]
pub enum SupportSetSelectionStrategy {
    Random,
    DiversityBased,
    DifficultyBased,
    UncertaintyBased,
    PrototypeBased,
    Adaptive,
}
615
/// Configuration for a [`SupportSetManager`].
#[derive(Debug, Clone)]
pub struct SupportSetManagerConfig {
    /// Minimum support set size
    pub min_support_size: usize,

    /// Maximum support set size
    pub max_support_size: usize,

    /// Quality threshold
    pub quality_threshold: f64,

    /// Enable augmentation
    pub enable_augmentation: bool,

    /// Cache support sets
    pub cache_support_sets: bool,
}
634
/// A pluggable algorithm that adapts a [`FewShotOptimizer`] to a task.
///
/// Implementors must also be `Send + Sync` (supertrait bound).
pub trait AdaptationStrategy<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Adapts `optimizer` to `task_data` according to `config`.
    fn adapt(
        &mut self,
        optimizer: &mut dyn FewShotOptimizer<T>,
        task_data: &TaskData<T>,
        config: &AdaptationConfig,
    ) -> Result<AdaptationResult<T>>;

    /// Returns the strategy's name.
    fn name(&self) -> &str;

    /// Returns the strategy's tunable parameters, keyed by name.
    fn parameters(&self) -> HashMap<String, f64>;

    /// Replaces or updates the strategy's tunable parameters.
    fn update_parameters(&mut self, params: HashMap<String, f64>) -> Result<()>;
}
654
/// Combines several [`SimilarityMetric`]s into a task-to-task similarity.
pub struct TaskSimilarityCalculator<T: Float + Debug + Send + Sync + 'static> {
    /// Similarity metrics
    similarity_metrics: Vec<Box<dyn SimilarityMetric<T>>>,

    /// Metric weights, keyed by metric name
    metric_weights: HashMap<String, T>,

    /// Similarity cache keyed by a pair of strings (presumably two task ids —
    /// TODO confirm)
    similarity_cache: HashMap<(String, String), T>,

    /// Calculator configuration
    config: SimilarityCalculatorConfig<T>,
}
669
/// A single way of measuring how similar two tasks are.
///
/// Implementors must also be `Send + Sync` (supertrait bound).
pub trait SimilarityMetric<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Calculates the similarity between `task1` and `task2`.
    fn calculate_similarity(&self, task1: &TaskData<T>, task2: &TaskData<T>) -> Result<T>;

    /// Returns the metric's name.
    fn name(&self) -> &str;

    /// Returns the metric's weight.
    fn weight(&self) -> T;
}
681
/// Configuration for a [`TaskSimilarityCalculator`].
#[derive(Debug, Clone)]
pub struct SimilarityCalculatorConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Enable caching
    pub enable_caching: bool,

    /// Cache size limit (number of entries, presumably — TODO confirm)
    pub cache_size_limit: usize,

    /// Similarity threshold
    pub similarity_threshold: T,

    /// Use task metadata
    pub use_metadata: bool,
}
697
/// Episodic memory bank for storing task experiences.
pub struct EpisodicMemoryBank<T: Float + Debug + Send + Sync + 'static> {
    /// Stored memory episodes
    episodes: VecDeque<MemoryEpisode<T>>,

    /// Memory configuration
    config: MemoryBankConfig<T>,

    /// Usage statistics
    usage_stats: MemoryUsageStats,
}
709
/// One remembered task together with the result of adapting to it.
#[derive(Debug, Clone)]
pub struct MemoryEpisode<T: Float + Debug + Send + Sync + 'static> {
    /// Episode ID
    pub episode_id: String,

    /// Task data
    pub task_data: TaskData<T>,

    /// Adaptation result
    pub adaptation_result: AdaptationResult<T>,

    /// Episode timestamp
    pub timestamp: std::time::SystemTime,

    /// Episode metadata
    pub metadata: EpisodeMetadata,

    /// Number of times this episode has been accessed
    pub access_count: usize,
}
731
/// Descriptive metadata attached to a [`MemoryEpisode`].
#[derive(Debug, Clone)]
pub struct EpisodeMetadata {
    /// Task difficulty
    pub difficulty: DifficultyLevel,

    /// Domain type
    pub domain: DomainType,

    /// Success rate
    pub success_rate: f64,

    /// Free-form tags
    pub tags: Vec<String>,
}
747
/// Configuration for an [`EpisodicMemoryBank`].
#[derive(Debug, Clone)]
pub struct MemoryBankConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Maximum memory size (presumably number of episodes — TODO confirm)
    pub max_memory_size: usize,

    /// Eviction policy
    pub eviction_policy: EvictionPolicy,

    /// Similarity threshold for retrieval
    pub similarity_threshold: T,

    /// Enable compression
    pub enable_compression: bool,
}
763
/// Policy for choosing which episode to evict when the memory bank is full.
#[derive(Debug, Clone, Copy)]
pub enum EvictionPolicy {
    /// Least Recently Used
    LRU,
    /// Least Frequently Used
    LFU,
    /// Worst Performing
    Performance,
    /// Oldest First
    Age,
    /// Random choice
    Random,
}
773
/// Memory bank statistics (summary view).
#[derive(Debug, Clone)]
pub struct MemoryBankStats<T: Float + Debug + Send + Sync + 'static> {
    /// Number of stored episodes
    pub count: usize,

    /// Average performance across episodes
    pub avg_performance: T,

    /// Capacity used, as a fraction of `total_capacity`
    pub capacity_used: f64,

    /// Total capacity
    pub total_capacity: usize,
}
789
/// Running usage statistics for an [`EpisodicMemoryBank`].
#[derive(Debug, Clone)]
pub struct MemoryUsageStats {
    /// Total episodes stored
    pub total_episodes: usize,

    /// Memory utilization
    pub memory_utilization: f64,

    /// Retrieval hit rate
    pub hit_rate: f64,

    /// Average retrieval time
    pub avg_retrieval_time: Duration,
}
805
/// Fast adaptation engine: a set of [`FastAdaptationAlgorithm`]s plus their
/// shared configuration.
pub struct FastAdaptationEngine<T: Float + Debug + Send + Sync + 'static> {
    /// Adaptation algorithms
    algorithms: Vec<Box<dyn FastAdaptationAlgorithm<T>>>,

    /// Engine configuration
    config: FastAdaptationConfig,
}
814
/// An algorithm that the [`FastAdaptationEngine`] can use to adapt an optimizer.
///
/// Implementors must also be `Send + Sync` (supertrait bound).
pub trait FastAdaptationAlgorithm<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Adapts `optimizer` to `task_data`.
    ///
    /// `target_performance` is an optional goal for the adaptation; how it is
    /// used (e.g. early stopping) is up to the implementor.
    fn adapt_fast(
        &mut self,
        optimizer: &mut dyn FewShotOptimizer<T>,
        task_data: &TaskData<T>,
        target_performance: Option<T>,
    ) -> Result<AdaptationResult<T>>;

    /// Estimates how long adapting to `taskdata` would take.
    fn estimate_adaptation_time(&self, taskdata: &TaskData<T>) -> Duration;

    /// Returns the algorithm's name.
    fn name(&self) -> &str;
}
831
/// Configuration for a [`FastAdaptationEngine`].
#[derive(Debug, Clone)]
pub struct FastAdaptationConfig {
    /// Enable caching
    pub enable_caching: bool,

    /// Enable prediction
    pub enable_prediction: bool,

    /// Maximum adaptation time
    pub max_adaptation_time: Duration,

    /// Performance threshold.
    /// NOTE(review): the leading underscore is part of this public field's
    /// name; renaming it would break struct literals elsewhere.
    pub _performance_threshold: f64,
}
847
/// Performance tracker for few-shot learning runs.
pub struct FewShotPerformanceTracker<T: Float + Debug + Send + Sync + 'static> {
    /// Performance history (bounded by `TrackingConfig::max_history_size`,
    /// presumably — TODO confirm)
    performance_history: VecDeque<PerformanceRecord<T>>,

    /// Performance metrics
    metrics: Vec<Box<dyn PerformanceMetric<T>>>,

    /// Tracking configuration
    config: TrackingConfig,

    /// Performance statistics
    stats: PerformanceStats<T>,
}
862
/// One recorded few-shot learning outcome.
#[derive(Debug, Clone)]
pub struct PerformanceRecord<T: Float + Debug + Send + Sync + 'static> {
    /// Task ID
    pub task_id: String,

    /// Performance value
    pub performance: T,

    /// Adaptation time
    pub adaptation_time: Duration,

    /// Name of the strategy used
    pub strategy_used: String,

    /// Timestamp
    pub timestamp: std::time::SystemTime,

    /// Additional named metrics
    pub additional_metrics: HashMap<String, T>,
}
884
/// An aggregate metric computed over a history of [`PerformanceRecord`]s.
///
/// Implementors must also be `Send + Sync` (supertrait bound).
pub trait PerformanceMetric<T: Float + Debug + Send + Sync + 'static>: Send + Sync {
    /// Calculates the metric over `records`.
    fn calculate(&self, records: &[PerformanceRecord<T>]) -> Result<T>;

    /// Returns the metric's name.
    fn name(&self) -> &str;

    /// Returns `true` if larger values of this metric are better.
    fn higher_is_better(&self) -> bool;
}
896
/// Configuration for a [`FewShotPerformanceTracker`].
#[derive(Debug, Clone)]
pub struct TrackingConfig {
    /// Maximum history size
    pub max_history_size: usize,

    /// Update frequency
    pub update_frequency: Duration,

    /// Enable detailed tracking
    pub detailed_tracking: bool,

    /// Export results
    pub export_results: bool,
}
912
/// Aggregate performance statistics kept by a [`FewShotPerformanceTracker`].
///
/// Unlike most types in this module, this derives only `Debug` (not `Clone`).
#[derive(Debug)]
pub struct PerformanceStats<T: Float + Debug + Send + Sync + 'static> {
    /// Best performance
    pub best_performance: T,

    /// Average performance
    pub average_performance: T,

    /// Performance variance
    pub performance_variance: T,

    /// Improvement rate
    pub improvement_rate: T,

    /// Success rate
    pub success_rate: T,
}
931
932// Implementation stubs for key structures
933impl<T: Float + Debug + Send + Sync + 'static> FewShotLearningSystem<T> {
934    /// Create new few-shot learning system
935    pub fn new(
936        base_optimizer: Box<dyn FewShotOptimizer<T>>,
937        config: FewShotConfig<T>,
938    ) -> Result<Self> {
939        Ok(Self {
940            base_optimizer,
941            prototype_network: PrototypicalNetwork::new(config.prototype_config)?,
942            support_set_manager: SupportSetManager::new(config.support_set_config)?,
943            adaptation_strategies: Vec::new(),
944            similarity_calculator: TaskSimilarityCalculator::new(config.similarity_config)?,
945            memory_bank: EpisodicMemoryBank::new(config.memory_config)?,
946            fast_adaptation: FastAdaptationEngine::new(config.adaptation_config)?,
947            performance_tracker: FewShotPerformanceTracker::new(config.tracking_config)?,
948        })
949    }
950
951    /// Learn from few examples
952    pub fn learn_few_shot(
953        &mut self,
954        task_data: TaskData<T>,
955        adaptation_config: AdaptationConfig,
956    ) -> Result<AdaptationResult<T>> {
957        // Comprehensive few-shot learning implementation
958        let _start_time = Instant::now();
959
960        // 1. Extract task representation using prototypical network
961        let task_representation = self.prototype_network.encode_task(&task_data)?;
962
963        // 2. Retrieve similar tasks from memory
964        let similar_tasks = self.memory_bank.retrieve_similar(&task_data, 5)?;
965
966        // 3. Select best adaptation strategy based on task characteristics
967        let strategy = self.select_adaptation_strategy(&task_data, &similar_tasks)?;
968
969        // 4. Perform fast adaptation
970        let mut adaptation_result = self.fast_adaptation.adapt_fast(
971            &mut *self.base_optimizer,
972            &task_data,
973            strategy,
974            &adaptation_config,
975        )?;
976
977        // 5. Update task representation
978        adaptation_result.task_representation = task_representation;
979
980        // 6. Store experience in memory bank
981        self.memory_bank
982            .store_episode(task_data.clone(), adaptation_result.clone())?;
983
984        // 7. Update performance tracker
985        self.performance_tracker
986            .record_performance(&adaptation_result)?;
987
988        // 8. Update prototypical network with new experience
989        self.prototype_network
990            .update_prototypes(&task_data, &adaptation_result)?;
991
992        Ok(adaptation_result)
993    }
994
995    fn select_adaptation_strategy(
996        &self,
997        task_data: &TaskData<T>,
998        _similar_tasks: &[MemoryEpisode<T>],
999    ) -> Result<AdaptationStrategyType> {
1000        // Strategy selection based on task characteristics and historical performance
1001        match task_data.domain_info.difficulty_level {
1002            DifficultyLevel::Trivial | DifficultyLevel::Easy => Ok(AdaptationStrategyType::FOMAML),
1003            DifficultyLevel::Medium => Ok(AdaptationStrategyType::MAML),
1004            DifficultyLevel::Hard | DifficultyLevel::Expert => {
1005                Ok(AdaptationStrategyType::Prototypical)
1006            }
1007            DifficultyLevel::Extreme => Ok(AdaptationStrategyType::MemoryAugmented),
1008        }
1009    }
1010}
1011
/// Configuration for [`FewShotLearningSystem::new`].
///
/// Each field is handed to the constructor of the corresponding component.
#[derive(Debug, Clone)]
pub struct FewShotConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Prototypical network configuration (`PrototypicalNetwork::new`)
    pub prototype_config: PrototypicalNetworkConfig<T>,

    /// Support set management configuration (`SupportSetManager::new`)
    pub support_set_config: SupportSetManagerConfig,

    /// Similarity calculation configuration (`TaskSimilarityCalculator::new`)
    pub similarity_config: SimilarityCalculatorConfig<T>,

    /// Memory bank configuration (`EpisodicMemoryBank::new`)
    pub memory_config: MemoryBankConfig<T>,

    /// Fast adaptation configuration (`FastAdaptationEngine::new`)
    pub adaptation_config: FastAdaptationConfig,

    /// Performance tracking configuration (`FewShotPerformanceTracker::new`)
    pub tracking_config: TrackingConfig,
}
1033
/// Configuration for `PrototypicalNetwork::new`.
#[derive(Debug, Clone)]
pub struct PrototypicalNetworkConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Embedding dimension; `PrototypicalNetwork::new` rejects zero
    pub embedding_dim: usize,

    /// Learning rate
    pub learning_rate: T,

    /// Number of encoder layers.
    /// NOTE(review): not read by `PrototypicalNetwork::new`, which always
    /// builds a single layer.
    pub num_layers: usize,

    /// Hidden dimension; `PrototypicalNetwork::new` uses
    /// `max(hidden_dim, embedding_dim)` as the encoder's input width.
    pub hidden_dim: usize,
}
1049
1050// Constructors and core methods for major components
1051impl<T: Float + Debug + Send + Sync + 'static> PrototypicalNetwork<T> {
1052    /// Create a new prototypical network from configuration
1053    pub fn new(config: PrototypicalNetworkConfig<T>) -> Result<Self> {
1054        if config.embedding_dim == 0 {
1055            return Err(OptimError::InvalidConfig(
1056                "embedding_dim must be > 0".to_string(),
1057            ));
1058        }
1059        // Build a single-layer encoder: input_dim -> embedding_dim
1060        let input_dim = config.hidden_dim.max(config.embedding_dim);
1061        let layer = EncoderLayer {
1062            weights: Array2::zeros((input_dim, config.embedding_dim)),
1063            bias: Array1::zeros(config.embedding_dim),
1064            layer_type: LayerType::Linear,
1065        };
1066        Ok(Self {
1067            encoder: EncoderNetwork {
1068                layers: vec![layer],
1069                activation: ActivationFunction::ReLU,
1070            },
1071            prototypes: HashMap::new(),
1072            distance_metric: DistanceMetric::Euclidean,
1073            parameters: PrototypicalNetworkParams {
1074                embedding_dim: config.embedding_dim,
1075                learning_rate: config.learning_rate,
1076                temperature: T::one(),
1077                prototype_update_rate: scirs2_core::numeric::NumCast::from(0.1)
1078                    .unwrap_or_else(|| T::one()),
1079            },
1080        })
1081    }
1082
1083    /// Create from embedding dimension and number of classes (convenience)
1084    pub fn from_dims(embedding_dim: usize, _num_classes: usize) -> Result<Self> {
1085        if embedding_dim == 0 {
1086            return Err(OptimError::InvalidConfig(
1087                "embedding_dim must be > 0".to_string(),
1088            ));
1089        }
1090        let layer = EncoderLayer {
1091            weights: Array2::zeros((embedding_dim, embedding_dim)),
1092            bias: Array1::zeros(embedding_dim),
1093            layer_type: LayerType::Linear,
1094        };
1095        Ok(Self {
1096            encoder: EncoderNetwork {
1097                layers: vec![layer],
1098                activation: ActivationFunction::ReLU,
1099            },
1100            prototypes: HashMap::new(),
1101            distance_metric: DistanceMetric::Euclidean,
1102            parameters: PrototypicalNetworkParams {
1103                embedding_dim,
1104                learning_rate: scirs2_core::numeric::NumCast::from(0.01)
1105                    .unwrap_or_else(|| T::one()),
1106                temperature: T::one(),
1107                prototype_update_rate: scirs2_core::numeric::NumCast::from(0.1)
1108                    .unwrap_or_else(|| T::one()),
1109            },
1110        })
1111    }
1112
1113    /// Get the embedding dimension
1114    pub fn embedding_dim(&self) -> usize {
1115        self.parameters.embedding_dim
1116    }
1117
1118    /// Get encoder layers (for projection)
1119    pub fn encoder_layers(&self) -> &[EncoderLayer<T>] {
1120        &self.encoder.layers
1121    }
1122
1123    /// Get mutable access to prototypes
1124    pub fn prototypes_mut(&mut self) -> &mut HashMap<String, Prototype<T>> {
1125        &mut self.prototypes
1126    }
1127
1128    /// Get read access to prototypes
1129    pub fn prototypes(&self) -> &HashMap<String, Prototype<T>> {
1130        &self.prototypes
1131    }
1132
1133    /// Get the distance metric
1134    pub fn distance_metric(&self) -> &DistanceMetric {
1135        &self.distance_metric
1136    }
1137
1138    /// Encode a task into an embedding vector
1139    pub fn encode_task(&self, task_data: &TaskData<T>) -> Result<Array1<T>> {
1140        // Compute mean of support set features as task representation
1141        if task_data.support_set.examples.is_empty() {
1142            return Err(OptimError::InsufficientData(
1143                "No support examples for encoding".to_string(),
1144            ));
1145        }
1146        let dim = self.parameters.embedding_dim;
1147        let mut sum = Array1::<T>::zeros(dim);
1148        let count = task_data.support_set.examples.len();
1149        for ex in &task_data.support_set.examples {
1150            let feat = &ex.features;
1151            let len = feat.len().min(dim);
1152            for i in 0..len {
1153                sum[i] = sum[i] + feat[i];
1154            }
1155        }
1156        let count_t = scirs2_core::numeric::NumCast::from(count).unwrap_or_else(|| T::one());
1157        for i in 0..dim {
1158            sum[i] = sum[i] / count_t;
1159        }
1160        Ok(sum)
1161    }
1162
1163    /// Update prototypes with new experience
1164    pub fn update_prototypes(
1165        &mut self,
1166        task_data: &TaskData<T>,
1167        _result: &AdaptationResult<T>,
1168    ) -> Result<()> {
1169        let task_repr = self.encode_task(task_data)?;
1170        let task_id = task_data.task_id.clone();
1171        let update_rate = self.parameters.prototype_update_rate;
1172        let one_minus = T::one() - update_rate;
1173
1174        if let Some(proto) = self.prototypes.get_mut(&task_id) {
1175            // Exponential moving average update
1176            let dim = proto.vector.len().min(task_repr.len());
1177            for i in 0..dim {
1178                proto.vector[i] = one_minus * proto.vector[i] + update_rate * task_repr[i];
1179            }
1180            proto.example_count += task_data.support_set.examples.len();
1181            proto.last_updated = std::time::SystemTime::now();
1182            proto.metadata.update_count += 1;
1183        } else {
1184            let proto = Prototype {
1185                vector: task_repr,
1186                confidence: T::one(),
1187                example_count: task_data.support_set.examples.len(),
1188                last_updated: std::time::SystemTime::now(),
1189                metadata: PrototypeMetadata {
1190                    task_category: task_id.clone(),
1191                    domain: task_data.domain_info.domain_type,
1192                    created_at: std::time::SystemTime::now(),
1193                    update_count: 1,
1194                },
1195            };
1196            self.prototypes.insert(task_id, proto);
1197        }
1198        Ok(())
1199    }
1200}
1201
1202impl<T: Float + Debug + Send + Sync + 'static> SupportSetManager<T> {
1203    /// Create a new support set manager
1204    pub fn new(config: SupportSetManagerConfig) -> Result<Self> {
1205        if config.max_support_size == 0 {
1206            return Err(OptimError::InvalidConfig(
1207                "max_support_size must be > 0".to_string(),
1208            ));
1209        }
1210        Ok(Self {
1211            support_sets: HashMap::new(),
1212            selection_strategy: SupportSetSelectionStrategy::DiversityBased,
1213            config,
1214        })
1215    }
1216
1217    /// Create from max support size (convenience)
1218    pub fn from_max_size(max_support_size: usize) -> Result<Self> {
1219        Self::new(SupportSetManagerConfig {
1220            min_support_size: 1,
1221            max_support_size,
1222            quality_threshold: 0.5,
1223            enable_augmentation: true,
1224            cache_support_sets: true,
1225        })
1226    }
1227
1228    /// Get the max support size
1229    pub fn max_support_size(&self) -> usize {
1230        self.config.max_support_size
1231    }
1232
1233    /// Get the config
1234    pub fn config(&self) -> &SupportSetManagerConfig {
1235        &self.config
1236    }
1237}
1238
1239impl<T: Float + Debug + Send + Sync + 'static> TaskSimilarityCalculator<T> {
1240    /// Create a new task similarity calculator
1241    pub fn new(config: SimilarityCalculatorConfig<T>) -> Result<Self> {
1242        Ok(Self {
1243            similarity_metrics: Vec::new(),
1244            metric_weights: HashMap::new(),
1245            similarity_cache: HashMap::new(),
1246            config,
1247        })
1248    }
1249
1250    /// Create with default configuration
1251    pub fn default_new() -> Result<Self> {
1252        Self::new(SimilarityCalculatorConfig {
1253            enable_caching: true,
1254            cache_size_limit: 1000,
1255            similarity_threshold: scirs2_core::numeric::NumCast::from(0.5)
1256                .unwrap_or_else(|| T::zero()),
1257            use_metadata: true,
1258        })
1259    }
1260
1261    /// Get mutable access to the similarity cache
1262    pub fn similarity_cache_mut(&mut self) -> &mut HashMap<(String, String), T> {
1263        &mut self.similarity_cache
1264    }
1265
1266    /// Get whether caching is enabled
1267    pub fn caching_enabled(&self) -> bool {
1268        self.config.enable_caching
1269    }
1270}
1271
1272impl<T: Float + Debug + Send + Sync + 'static> EpisodicMemoryBank<T> {
1273    /// Create a new episodic memory bank
1274    pub fn new(config: MemoryBankConfig<T>) -> Result<Self> {
1275        if config.max_memory_size == 0 {
1276            return Err(OptimError::InvalidConfig(
1277                "max_memory_size must be > 0".to_string(),
1278            ));
1279        }
1280        Ok(Self {
1281            episodes: VecDeque::new(),
1282            config,
1283            usage_stats: MemoryUsageStats {
1284                total_episodes: 0,
1285                memory_utilization: 0.0,
1286                hit_rate: 0.0,
1287                avg_retrieval_time: Duration::from_secs(0),
1288            },
1289        })
1290    }
1291
1292    /// Create from capacity (convenience)
1293    pub fn from_capacity(capacity: usize) -> Result<Self> {
1294        Self::new(MemoryBankConfig {
1295            max_memory_size: capacity,
1296            eviction_policy: EvictionPolicy::Performance,
1297            similarity_threshold: scirs2_core::numeric::NumCast::from(0.3)
1298                .unwrap_or_else(|| T::zero()),
1299            enable_compression: false,
1300        })
1301    }
1302
1303    /// Get the capacity
1304    pub fn capacity(&self) -> usize {
1305        self.config.max_memory_size
1306    }
1307
1308    /// Get episode count
1309    pub fn len(&self) -> usize {
1310        self.episodes.len()
1311    }
1312
1313    /// Check if empty
1314    pub fn is_empty(&self) -> bool {
1315        self.episodes.is_empty()
1316    }
1317
1318    /// Get mutable access to episodes
1319    pub fn episodes_mut(&mut self) -> &mut VecDeque<MemoryEpisode<T>> {
1320        &mut self.episodes
1321    }
1322
1323    /// Get read access to episodes
1324    pub fn episodes(&self) -> &VecDeque<MemoryEpisode<T>> {
1325        &self.episodes
1326    }
1327
1328    /// Get mutable access to usage stats
1329    pub fn usage_stats_mut(&mut self) -> &mut MemoryUsageStats {
1330        &mut self.usage_stats
1331    }
1332
1333    /// Get usage stats
1334    pub fn usage_stats(&self) -> &MemoryUsageStats {
1335        &self.usage_stats
1336    }
1337
1338    /// Get eviction policy
1339    pub fn eviction_policy(&self) -> EvictionPolicy {
1340        self.config.eviction_policy
1341    }
1342
1343    /// Retrieve similar episodes to a task
1344    pub fn retrieve_similar(
1345        &self,
1346        task_data: &TaskData<T>,
1347        k: usize,
1348    ) -> Result<Vec<MemoryEpisode<T>>> {
1349        if self.episodes.is_empty() {
1350            return Ok(Vec::new());
1351        }
1352        // Simple retrieval: return up to k most recent episodes
1353        // (full similarity-based retrieval is in episodic_memory_impl)
1354        let count = k.min(self.episodes.len());
1355        let result: Vec<MemoryEpisode<T>> =
1356            self.episodes.iter().rev().take(count).cloned().collect();
1357        let _ = task_data; // suppress unused warning
1358        Ok(result)
1359    }
1360
1361    /// Store a new episode
1362    pub fn store_episode(
1363        &mut self,
1364        task_data: TaskData<T>,
1365        result: AdaptationResult<T>,
1366    ) -> Result<()> {
1367        let episode = MemoryEpisode {
1368            episode_id: format!("ep_{}", self.usage_stats.total_episodes),
1369            task_data,
1370            adaptation_result: result,
1371            timestamp: std::time::SystemTime::now(),
1372            metadata: EpisodeMetadata {
1373                difficulty: DifficultyLevel::Medium,
1374                domain: DomainType::Optimization,
1375                success_rate: 0.0,
1376                tags: Vec::new(),
1377            },
1378            access_count: 0,
1379        };
1380
1381        if self.episodes.len() >= self.config.max_memory_size {
1382            self.episodes.pop_front();
1383        }
1384        self.episodes.push_back(episode);
1385        self.usage_stats.total_episodes += 1;
1386        self.usage_stats.memory_utilization =
1387            self.episodes.len() as f64 / self.config.max_memory_size as f64;
1388        Ok(())
1389    }
1390}
1391
1392impl<T: Float + Debug + Send + Sync + 'static> FastAdaptationEngine<T> {
1393    /// Create a new fast adaptation engine
1394    pub fn new(config: FastAdaptationConfig) -> Result<Self> {
1395        Ok(Self {
1396            algorithms: Vec::new(),
1397            config,
1398        })
1399    }
1400
1401    /// Create from inner learning rate and adaptation steps (convenience)
1402    pub fn from_params(inner_lr: T, adaptation_steps: usize) -> Result<Self> {
1403        let _ = (inner_lr, adaptation_steps);
1404        Self::new(FastAdaptationConfig {
1405            enable_caching: true,
1406            enable_prediction: true,
1407            max_adaptation_time: Duration::from_secs(60),
1408            _performance_threshold: 0.8,
1409        })
1410    }
1411
1412    /// Get the configuration
1413    pub fn config(&self) -> &FastAdaptationConfig {
1414        &self.config
1415    }
1416
1417    /// Perform fast adaptation
1418    pub fn adapt_fast(
1419        &mut self,
1420        _optimizer: &mut dyn FewShotOptimizer<T>,
1421        _task_data: &TaskData<T>,
1422        _strategy: AdaptationStrategyType,
1423        _config: &AdaptationConfig,
1424    ) -> Result<AdaptationResult<T>> {
1425        Ok(AdaptationResult {
1426            adapted_state: OptimizerState {
1427                parameters: Array1::zeros(1),
1428                gradients: Array1::zeros(1),
1429                momentum: None,
1430                hidden_states: HashMap::new(),
1431                memory_buffers: HashMap::new(),
1432                step: 0,
1433                step_count: 0,
1434                loss: None,
1435                learning_rate: scirs2_core::numeric::NumCast::from(0.001)
1436                    .unwrap_or_else(|| T::one()),
1437                metadata: super::StateMetadata {
1438                    task_id: None,
1439                    optimizer_type: None,
1440                    version: "1.0".to_string(),
1441                    timestamp: std::time::SystemTime::now(),
1442                    checksum: 0,
1443                    compression_level: 0,
1444                    custom_data: HashMap::new(),
1445                },
1446            },
1447            performance: AdaptationPerformance {
1448                query_performance: scirs2_core::numeric::NumCast::from(0.85)
1449                    .unwrap_or_else(|| T::zero()),
1450                support_performance: scirs2_core::numeric::NumCast::from(0.90)
1451                    .unwrap_or_else(|| T::zero()),
1452                adaptation_speed: 5,
1453                final_loss: scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
1454                improvement: scirs2_core::numeric::NumCast::from(0.25).unwrap_or_else(|| T::zero()),
1455                stability: scirs2_core::numeric::NumCast::from(0.95).unwrap_or_else(|| T::zero()),
1456            },
1457            task_representation: Array1::zeros(128),
1458            adaptation_trajectory: Vec::new(),
1459            resource_usage: ResourceUsage {
1460                total_time: Duration::from_secs(15),
1461                peak_memory_mb: scirs2_core::numeric::NumCast::from(256.0)
1462                    .unwrap_or_else(|| T::zero()),
1463                compute_cost: scirs2_core::numeric::NumCast::from(5.0).unwrap_or_else(|| T::zero()),
1464                energy_consumption: scirs2_core::numeric::NumCast::from(0.05)
1465                    .unwrap_or_else(|| T::zero()),
1466            },
1467        })
1468    }
1469}
1470
1471impl<T: Float + Debug + Send + Sync + 'static> FewShotPerformanceTracker<T> {
1472    /// Create a new performance tracker
1473    pub fn new(config: TrackingConfig) -> Result<Self> {
1474        Ok(Self {
1475            performance_history: VecDeque::new(),
1476            metrics: Vec::new(),
1477            config,
1478            stats: PerformanceStats {
1479                best_performance: T::zero(),
1480                average_performance: T::zero(),
1481                performance_variance: T::zero(),
1482                improvement_rate: T::zero(),
1483                success_rate: T::zero(),
1484            },
1485        })
1486    }
1487
1488    /// Record a performance result
1489    pub fn record_performance(&mut self, result: &AdaptationResult<T>) -> Result<()> {
1490        let record = PerformanceRecord {
1491            task_id: String::new(),
1492            performance: result.performance.query_performance,
1493            adaptation_time: result.resource_usage.total_time,
1494            strategy_used: String::new(),
1495            timestamp: std::time::SystemTime::now(),
1496            additional_metrics: HashMap::new(),
1497        };
1498
1499        if self.performance_history.len() >= self.config.max_history_size {
1500            self.performance_history.pop_front();
1501        }
1502        self.performance_history.push_back(record);
1503
1504        // Update stats
1505        if !self.performance_history.is_empty() {
1506            let mut sum = T::zero();
1507            let mut best = T::neg_infinity();
1508            for r in &self.performance_history {
1509                sum = sum + r.performance;
1510                if r.performance > best {
1511                    best = r.performance;
1512                }
1513            }
1514            let count_t = scirs2_core::numeric::NumCast::from(self.performance_history.len())
1515                .unwrap_or_else(|| T::one());
1516            self.stats.average_performance = sum / count_t;
1517            self.stats.best_performance = best;
1518        }
1519        Ok(())
1520    }
1521}
1522
/// Descriptive metadata for a task (stored on a support set as `task_metadata`).
#[derive(Debug, Clone)]
pub struct TaskMetadata {
    /// Human-readable task name.
    pub task_name: String,
    /// Domain the task belongs to.
    pub domain: DomainType,
    /// Difficulty classification of the task.
    pub difficulty: DifficultyLevel,
    /// Creation timestamp of this metadata.
    pub created_at: std::time::SystemTime,
}
1531
/// Provenance metadata attached to an individual support example.
#[derive(Debug, Clone)]
pub struct ExampleMetadata {
    /// Free-form description of where the example came from.
    pub source: String,
    /// Quality score for the example.
    /// NOTE(review): presumably in [0, 1]; the range is not enforced here, so confirm.
    pub quality_score: f64,
    /// Creation timestamp of the example.
    pub created_at: std::time::SystemTime,
}
1538
/// Summary statistics describing a support set.
#[derive(Debug, Clone)]
pub struct SupportSetStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Mean vector of the examples (presumably per feature; confirm with producers).
    pub mean: Array1<T>,
    /// Variance vector of the examples (presumably per feature; confirm with producers).
    pub variance: Array1<T>,
    /// Number of examples summarised.
    pub size: usize,
    /// Diversity score of the set; how it is computed is defined by its producer.
    pub diversity_score: T,
}
1546
/// Summary statistics describing a query set.
#[derive(Debug, Clone)]
pub struct QuerySetStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Mean vector of the queries (presumably per feature; confirm with producers).
    pub mean: Array1<T>,
    /// Variance vector of the queries (presumably per feature; confirm with producers).
    pub variance: Array1<T>,
    /// Number of queries summarised.
    pub size: usize,
}
1553
/// Evaluation metrics for few-shot learning.
///
/// Derives `PartialEq`, `Eq` and `Hash` so metrics can be compared directly and used as
/// keys in `HashMap`/`HashSet` (e.g. per-metric results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EvaluationMetric {
    Accuracy,
    Loss,
    F1Score,
    Precision,
    Recall,
    AUC,
    MSE,
    MAE,
    FinalPerformance,
    Efficiency,
    TrainingTime,
}
1569
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built one-example support set exposes that example and its recorded size.
    #[test]
    fn test_support_set_creation() {
        let example = SupportExample {
            features: Array1::from_vec(vec![1.0, 2.0, 3.0]),
            target: 0.5,
            weight: 1.0,
            context: HashMap::new(),
            metadata: ExampleMetadata {
                source: "test".to_string(),
                quality_score: 0.9,
                created_at: std::time::SystemTime::now(),
            },
        };
        let task_metadata = TaskMetadata {
            task_name: "test_task".to_string(),
            domain: DomainType::Optimization,
            difficulty: DifficultyLevel::Easy,
            created_at: std::time::SystemTime::now(),
        };
        let statistics = SupportSetStatistics {
            mean: Array1::from_vec(vec![1.0, 2.0, 3.0]),
            variance: Array1::from_vec(vec![0.1, 0.1, 0.1]),
            size: 1,
            diversity_score: 0.8,
        };

        let support_set = SupportSet::<f64> {
            examples: vec![example],
            task_metadata,
            statistics,
            temporal_order: None,
        };

        assert_eq!(support_set.examples.len(), 1);
        assert_eq!(support_set.statistics.size, 1);
    }

    /// Fields of an AdaptationConfig round-trip unchanged.
    #[test]
    fn test_adaptation_config() {
        let regularization = RegularizationConfig {
            l2_strength: 0.001,
            dropout_rate: 0.1,
            gradient_clip: Some(1.0),
            task_regularization: HashMap::new(),
        };
        let resource_constraints = ResourceConstraints {
            max_time: Duration::from_secs(60),
            max_memory_mb: 1024,
            max_compute_budget: 100.0,
        };

        let config = AdaptationConfig {
            adaptation_steps: 10,
            adaptation_lr: 0.01,
            strategy: AdaptationStrategyType::MAML,
            early_stopping: None,
            regularization,
            resource_constraints,
        };

        assert_eq!(config.adaptation_steps, 10);
        assert_eq!(config.adaptation_lr, 0.01);
        assert!(matches!(config.strategy, AdaptationStrategyType::MAML));
    }

    /// DomainInfo keeps its domain type, characteristics and constraint list.
    #[test]
    fn test_domain_info() {
        let characteristics = DomainCharacteristics {
            input_dim: 784,
            output_dim: 10,
            temporal: false,
            stochasticity: 0.1,
            noise_level: 0.05,
            sparsity: 0.0,
        };
        let latency_limit = DomainConstraint {
            constraint_type: ConstraintType::LatencyRequirement,
            description: "Max 100ms inference time".to_string(),
            enforcement: ConstraintEnforcement::Hard,
        };

        let domain_info = DomainInfo {
            domain_type: DomainType::ComputerVision,
            characteristics,
            difficulty_level: DifficultyLevel::Medium,
            constraints: vec![latency_limit],
        };

        assert!(matches!(
            domain_info.domain_type,
            DomainType::ComputerVision
        ));
        assert_eq!(domain_info.characteristics.input_dim, 784);
        assert_eq!(domain_info.constraints.len(), 1);
    }
}