Skip to main content

optirs_learned/meta_learning/
types.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#[allow(unused_imports)]
6use crate::error::Result;
7use optirs_core::optimizers::Optimizer;
8#[allow(dead_code)]
9use scirs2_core::ndarray::{Array1, Array2, Dimension};
10use scirs2_core::numeric::Float;
11use std::collections::{HashMap, VecDeque};
12use std::fmt::Debug;
13use std::time::Instant;
14
15use super::functions::MetaLearner;
16
/// Adaptation step
///
/// One entry in a task-adaptation trace: the loss, gradient and update
/// statistics observed at a single inner-loop optimization step.
#[derive(Debug, Clone)]
pub struct AdaptationStep<T: Float + Debug + Send + Sync + 'static> {
    /// Step number (index within the adaptation run)
    pub step: usize,
    /// Loss at this step
    pub loss: T,
    /// Gradient norm at this step
    pub gradient_norm: T,
    /// Parameter change norm (magnitude of the update applied at this step)
    pub parameter_change_norm: T,
    /// Learning rate used for this step
    pub learning_rate: T,
}
/// Meta-training metrics
///
/// Aggregate quality metrics summarizing one meta-training run.
#[derive(Debug, Clone)]
pub struct MetaTrainingMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Average adaptation speed across tasks
    pub avg_adaptation_speed: T,
    /// Generalization performance (held-out / query-set performance)
    pub generalization_performance: T,
    /// Task diversity handled during training
    pub task_diversity: T,
    /// Gradient alignment score (agreement between per-task gradients)
    pub gradient_alignment: T,
}
/// Hessian-vector product engine
///
/// Computes Hessian-vector products for second-order meta-learning.
/// The caches presumably memoize previously requested vectors/products —
/// TODO confirm against the engine's usage sites.
#[derive(Debug)]
pub struct HessianVectorProductEngine<T: Float + Debug + Send + Sync + 'static> {
    /// HVP computation method
    method: HVPComputationMethod,
    /// Vector cache (inputs of previous HVP requests)
    vector_cache: Vec<Array1<T>>,
    /// Product cache (results of previous HVP requests)
    product_cache: Vec<Array1<T>>,
}
53impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> HessianVectorProductEngine<T> {
54    /// Create a new HVP engine
55    pub fn new() -> Result<Self> {
56        Ok(Self {
57            method: HVPComputationMethod::FiniteDifference,
58            vector_cache: Vec::new(),
59            product_cache: Vec::new(),
60        })
61    }
62}
/// Task identification methods
///
/// Strategies for determining which task an incoming sample or data
/// stream belongs to (externally provided vs. inferred).
#[derive(Debug, Clone, Copy)]
pub enum TaskIdentificationMethod {
    /// Task identity supplied externally (ground truth)
    Oracle,
    /// Task identity predicted by a learned model
    Learned,
    /// Task identity inferred via clustering
    Clustering,
    /// Task identity inferred from prediction entropy
    EntropyBased,
    /// Task identity inferred from gradient signatures
    GradientBased,
}
/// Query evaluation metrics
///
/// Evaluation results on a task's query set. The `Option` fields are
/// populated depending on the task type (regression vs. classification).
#[derive(Debug, Clone)]
pub struct QueryEvaluationMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Mean squared error (for regression)
    pub mse: Option<T>,
    /// Classification accuracy (for classification)
    pub classification_accuracy: Option<T>,
    /// AUC score
    pub auc: Option<T>,
    /// Uncertainty estimation quality
    pub uncertainty_quality: T,
}
/// Meta-parameters for meta-learning
///
/// A named collection of parameter vectors plus free-form string metadata.
#[derive(Debug, Clone)]
pub struct MetaParameters<T: Float + Debug + Send + Sync + 'static> {
    /// Parameter values, keyed by parameter name
    pub parameters: HashMap<String, Array1<T>>,
    /// Parameter metadata (arbitrary key/value annotations)
    pub metadata: HashMap<String, String>,
}
/// MAML implementation
///
/// Model-Agnostic Meta-Learning: an inner optimizer adapts to each task
/// while an outer optimizer updates the shared meta-parameters.
pub struct MAMLLearner<T: Float + Debug + Send + Sync + 'static, D: Dimension> {
    /// MAML configuration
    pub(super) config: MAMLConfig<T>,
    /// Inner loop optimizer (per-task adaptation)
    inner_optimizer: Box<dyn Optimizer<T, D> + Send + Sync>,
    /// Outer loop optimizer (meta-parameter updates)
    outer_optimizer: Box<dyn Optimizer<T, D> + Send + Sync>,
    /// Gradient computation engine (first-order)
    gradient_engine: GradientComputationEngine<T>,
    /// Second-order gradient computation; present only when
    /// `config.second_order` is enabled
    second_order_engine: Option<SecondOrderGradientEngine<T>>,
    /// Task adaptation history (bounded ring of recent results)
    adaptation_history: VecDeque<TaskAdaptationResult<T>>,
}
107impl<
108        T: Float
109            + Default
110            + Clone
111            + Send
112            + Sync
113            + scirs2_core::ndarray::ScalarOperand
114            + std::fmt::Debug,
115        D: Dimension,
116    > MAMLLearner<T, D>
117{
118    pub fn new(config: MAMLConfig<T>) -> Result<Self> {
119        let inner_optimizer: Box<dyn Optimizer<T, D> + Send + Sync> =
120            Box::new(optirs_core::optimizers::SGD::new(config.inner_lr));
121        let outer_optimizer: Box<dyn Optimizer<T, D> + Send + Sync> =
122            Box::new(optirs_core::optimizers::SGD::new(config.outer_lr));
123        let gradient_engine = GradientComputationEngine::new()?;
124        let second_order_engine = if config.second_order {
125            Some(SecondOrderGradientEngine::new()?)
126        } else {
127            None
128        };
129        let adaptation_history = VecDeque::with_capacity(1000);
130        Ok(Self {
131            config,
132            inner_optimizer,
133            outer_optimizer,
134            gradient_engine,
135            second_order_engine,
136            adaptation_history,
137        })
138    }
139}
140impl<T: Float + Debug + Send + Sync + 'static + Default + Clone + std::iter::Sum, D: Dimension>
141    MAMLLearner<T, D>
142{
143    pub(super) fn compute_support_loss(
144        &self,
145        task: &MetaTask<T>,
146        _parameters: &HashMap<String, Array1<T>>,
147    ) -> Result<T> {
148        let mut total_loss = T::zero();
149        for (features, target) in task
150            .support_set
151            .features
152            .iter()
153            .zip(&task.support_set.targets)
154        {
155            let prediction = features.iter().copied().sum::<T>()
156                / T::from(features.len()).expect("unwrap failed");
157            let loss = (prediction - *target) * (prediction - *target);
158            total_loss = total_loss + loss;
159        }
160        Ok(total_loss / T::from(task.support_set.features.len()).expect("unwrap failed"))
161    }
162    pub(super) fn compute_gradients(
163        &self,
164        parameters: &HashMap<String, Array1<T>>,
165        _loss: T,
166    ) -> Result<HashMap<String, Array1<T>>> {
167        let mut gradients = HashMap::new();
168        for (name, param) in parameters {
169            let grad = Array1::zeros(param.len());
170            gradients.insert(name.clone(), grad);
171        }
172        Ok(gradients)
173    }
174}
/// Memory selection criteria
///
/// How samples are chosen for retention in a replay memory buffer.
#[derive(Debug, Clone, Copy)]
pub enum MemorySelectionCriteria {
    /// Uniform random selection
    Random,
    /// Prefer samples with large gradient magnitude
    GradientMagnitude,
    /// Prefer samples with high loss
    LossBased,
    /// Prefer samples with high predictive uncertainty
    Uncertainty,
    /// Prefer samples that maximize buffer diversity
    Diversity,
    /// Prefer samples close in time to the current step
    TemporalProximity,
}
/// Task distribution manager
///
/// Samples task batches for meta-training according to the shared
/// meta-learning configuration. `_phantom` ties the element type `T`
/// without storing any `T` values.
pub struct TaskDistributionManager<T: Float + Debug + Send + Sync + 'static> {
    // Owned copy of the framework-wide configuration.
    config: MetaLearningConfig,
    _phantom: std::marker::PhantomData<T>,
}
190impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> TaskDistributionManager<T> {
191    pub fn new(config: &MetaLearningConfig) -> Result<Self> {
192        Ok(Self {
193            config: config.clone(),
194            _phantom: std::marker::PhantomData,
195        })
196    }
197    pub fn sample_task_batch(
198        &self,
199        _tasks: &[MetaTask<T>],
200        batch_size: usize,
201    ) -> Result<Vec<MetaTask<T>>> {
202        Ok(vec![MetaTask::default(); batch_size.min(10)])
203    }
204}
/// Task metadata
///
/// Descriptive, non-numeric information attached to a meta-task.
#[derive(Debug, Clone)]
pub struct TaskMetadata {
    /// Task name
    pub name: String,
    /// Task description (free-form text)
    pub description: String,
    /// Task properties (arbitrary key/value annotations)
    pub properties: HashMap<String, String>,
    /// Creation timestamp (monotonic; `Instant` is not wall-clock time)
    pub created_at: Instant,
    /// Task source (where the task originated)
    pub source: String,
}
/// Computation graph for gradient computation
///
/// A DAG of computation nodes with explicit dependency edges. The index
/// vectors (`topological_order`, `input_nodes`, `output_nodes`) are
/// presumably indices into `nodes` — TODO confirm at the usage sites.
#[derive(Debug)]
pub struct ComputationGraph<T: Float + Debug + Send + Sync + 'static> {
    /// Graph nodes
    nodes: Vec<ComputationNode<T>>,
    /// Node dependencies (node index -> indices it depends on)
    dependencies: HashMap<usize, Vec<usize>>,
    /// Topological order for forward/backward traversal
    topological_order: Vec<usize>,
    /// Input nodes
    input_nodes: Vec<usize>,
    /// Output nodes
    output_nodes: Vec<usize>,
}
233impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> ComputationGraph<T> {
234    /// Create a new computation graph
235    pub fn new() -> Result<Self> {
236        Ok(Self {
237            nodes: Vec::new(),
238            dependencies: HashMap::new(),
239            topological_order: Vec::new(),
240            input_nodes: Vec::new(),
241            output_nodes: Vec::new(),
242        })
243    }
244}
/// Meta-training epoch
///
/// Complete record of one epoch: training result, validation result, and a
/// snapshot of the meta-parameters at the end of the epoch.
#[derive(Debug, Clone)]
pub struct MetaTrainingEpoch<T: Float + Debug + Send + Sync + 'static> {
    /// Epoch index
    pub epoch: usize,
    /// Result of this epoch's meta-training step
    pub training_result: MetaTrainingResult<T>,
    /// Result of this epoch's meta-validation pass
    pub validation_result: MetaValidationResult<T>,
    /// Snapshot of the meta-parameters after this epoch's update
    pub meta_parameters: HashMap<String, Array1<T>>,
}
/// Task dataset
///
/// Parallel vectors of samples: `features[i]` corresponds to `targets[i]`
/// and `weights[i]` — presumably kept the same length; confirm at
/// construction sites.
#[derive(Debug, Clone)]
pub struct TaskDataset<T: Float + Debug + Send + Sync + 'static> {
    /// Input features (one vector per sample)
    pub features: Vec<Array1<T>>,
    /// Target values (one per sample)
    pub targets: Vec<T>,
    /// Sample weights (one per sample)
    pub weights: Vec<T>,
    /// Dataset metadata
    pub metadata: DatasetMetadata,
}
/// MAML configuration
#[derive(Debug, Clone)]
pub struct MAMLConfig<T: Float + Debug + Send + Sync + 'static> {
    /// Enable second-order gradients (full MAML vs. first-order variant)
    pub second_order: bool,
    /// Inner learning rate (per-task adaptation)
    pub inner_lr: T,
    /// Outer learning rate (meta-parameter updates)
    pub outer_lr: T,
    /// Number of inner steps per task
    pub inner_steps: usize,
    /// Allow unused parameters during gradient computation
    pub allow_unused: bool,
    /// Gradient clipping threshold; `None` presumably disables clipping —
    /// TODO confirm at the consuming optimizer
    pub gradient_clip: Option<f64>,
}
/// Continual learning system
///
/// Learns a sequence of tasks while tracking forgetting. `_phantom` ties
/// the element type `T` without storing any `T` values.
pub struct ContinualLearningSystem<T: Float + Debug + Send + Sync + 'static> {
    // Owned copy of the continual-learning settings.
    settings: ContinualLearningSettings,
    _phantom: std::marker::PhantomData<T>,
}
286impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> ContinualLearningSystem<T> {
287    pub fn new(settings: &ContinualLearningSettings) -> Result<Self> {
288        Ok(Self {
289            settings: settings.clone(),
290            _phantom: std::marker::PhantomData,
291        })
292    }
293    pub fn learn_sequence(
294        &mut self,
295        sequence: &[MetaTask<T>],
296        _meta_parameters: &mut HashMap<String, Array1<T>>,
297    ) -> Result<ContinualLearningResult<T>> {
298        let mut sequence_results = Vec::new();
299        for task in sequence {
300            let task_result = TaskResult {
301                task_id: task.id.clone(),
302                loss: scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
303                metrics: HashMap::new(),
304            };
305            sequence_results.push(task_result);
306        }
307        Ok(ContinualLearningResult {
308            sequence_results,
309            forgetting_measure: scirs2_core::numeric::NumCast::from(0.05)
310                .unwrap_or_else(|| T::zero()),
311            adaptation_efficiency: scirs2_core::numeric::NumCast::from(0.95)
312                .unwrap_or_else(|| T::zero()),
313        })
314    }
315    pub fn forgetting_measure(&self) -> T {
316        T::from(0.05).unwrap_or_default()
317    }
318}
/// Meta-training result
///
/// Output of a single meta-training step over a task batch.
#[derive(Debug, Clone)]
pub struct MetaTrainingResult<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-loss (aggregate over the task batch)
    pub meta_loss: T,
    /// Per-task losses
    pub task_losses: Vec<T>,
    /// Meta-gradients, keyed by meta-parameter name
    pub meta_gradients: HashMap<String, Array1<T>>,
    /// Training metrics
    pub metrics: MetaTrainingMetrics<T>,
    /// Adaptation statistics
    pub adaptation_stats: AdaptationStatistics<T>,
}
/// Metric learning settings
#[derive(Debug, Clone)]
pub struct MetricLearningSettings {
    /// Distance metric used to compare embeddings
    pub distance_metric: DistanceMetric,
    /// Embedding dimension
    pub embedding_dim: usize,
    /// Whether the metric parameters themselves are learned
    pub learned_metric: bool,
}
/// Multi-task learning result
#[derive(Debug, Clone)]
pub struct MultiTaskResult<T: Float + Debug + Send + Sync + 'static> {
    /// Per-task learning results
    pub task_results: Vec<TaskResult<T>>,
    /// Cost attributed to coordinating the tasks
    pub coordination_overhead: T,
    /// Human-readable convergence status
    pub convergence_status: String,
}
/// Memory replay settings
#[derive(Debug, Clone)]
pub struct MemoryReplaySettings {
    /// Memory buffer size (number of stored samples)
    pub buffer_size: usize,
    /// Replay strategy
    pub replay_strategy: ReplayStrategy,
    /// Replay frequency (unit not specified here — presumably steps between
    /// replays; TODO confirm)
    pub replay_frequency: usize,
    /// Memory selection criteria
    pub selection_criteria: MemorySelectionCriteria,
}
/// Task types
///
/// Categories of meta-learning tasks; used to select appropriate
/// evaluation metrics (see `QueryEvaluationMetrics`).
#[derive(Debug, Clone, Copy)]
pub enum TaskType {
    Regression,
    Classification,
    Optimization,
    ReinforcementLearning,
    StructuredPrediction,
    Generative,
}
/// Forward mode automatic differentiation
///
/// Propagates derivatives alongside values using dual numbers.
#[derive(Debug)]
pub struct ForwardModeAD<T: Float + Debug + Send + Sync + 'static> {
    /// Dual numbers (value + tangent pairs)
    dual_numbers: Vec<DualNumber<T>>,
    /// Jacobian matrix accumulated by forward passes
    jacobian: Array2<T>,
}
380impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> ForwardModeAD<T> {
381    /// Create a new forward mode AD engine
382    pub fn new() -> Result<Self> {
383        Ok(Self {
384            dual_numbers: Vec::new(),
385            jacobian: Array2::zeros((1, 1)),
386        })
387    }
388}
/// Tape entry for reverse mode AD
///
/// One recorded operation on the computational tape; IDs index other tape
/// entries / variables.
#[derive(Debug, Clone)]
pub struct TapeEntry<T: Float + Debug + Send + Sync + 'static> {
    /// Operation ID
    pub op_id: usize,
    /// Input IDs (operands of this operation)
    pub inputs: Vec<usize>,
    /// Output ID (result variable of this operation)
    pub output: usize,
    /// Local gradients (partial derivative of output w.r.t. each input)
    pub local_gradients: Vec<T>,
}
/// Interference mitigation strategies
///
/// Techniques for reducing destructive interference between tasks in
/// multi-task learning.
#[derive(Debug, Clone, Copy)]
pub enum InterferenceMitigationStrategy {
    OrthogonalGradients,
    TaskSpecificLayers,
    AttentionMechanisms,
    MetaGradients,
}
/// Hessian computation methods
///
/// How (approximate) second-order curvature information is obtained.
#[derive(Debug, Clone, Copy)]
pub enum HessianComputationMethod {
    /// Exact Hessian
    Exact,
    /// Finite-difference approximation
    FiniteDifference,
    /// Gauss-Newton approximation
    GaussNewton,
    /// BFGS quasi-Newton approximation
    BFGS,
    /// Limited-memory BFGS approximation
    LBfgs,
}
/// Curvature estimator
///
/// Tracks curvature estimates of the loss landscape over time and per
/// named parameter.
#[derive(Debug)]
pub struct CurvatureEstimator<T: Float + Debug + Send + Sync + 'static> {
    /// Curvature estimation method
    method: CurvatureEstimationMethod,
    /// Curvature history (recent global estimates)
    curvature_history: VecDeque<T>,
    /// Local curvature estimates, keyed by parameter name
    local_curvature: HashMap<String, T>,
}
428impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> CurvatureEstimator<T> {
429    /// Create a new curvature estimator
430    pub fn new() -> Result<Self> {
431        Ok(Self {
432            method: CurvatureEstimationMethod::DiagonalHessian,
433            curvature_history: VecDeque::new(),
434            local_curvature: HashMap::new(),
435        })
436    }
437}
/// Automatic differentiation engine
///
/// Bundles the forward-, reverse-, and mixed-mode AD implementations.
#[derive(Debug)]
pub struct AutoDiffEngine<T: Float + Debug + Send + Sync + 'static> {
    /// Forward mode AD
    forward_mode: ForwardModeAD<T>,
    /// Reverse mode AD
    reverse_mode: ReverseModeAD<T>,
    /// Mixed mode AD
    mixed_mode: MixedModeAD<T>,
}
448impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> AutoDiffEngine<T> {
449    /// Create a new autodiff engine
450    pub fn new() -> Result<Self> {
451        Ok(Self {
452            forward_mode: ForwardModeAD::new()?,
453            reverse_mode: ReverseModeAD::new()?,
454            mixed_mode: MixedModeAD::new()?,
455        })
456    }
457}
/// Replay strategies
///
/// How samples are drawn from the memory buffer during replay.
#[derive(Debug, Clone, Copy)]
pub enum ReplayStrategy {
    Random,
    GradientBased,
    UncertaintyBased,
    DiversityBased,
    Temporal,
}
/// Transfer learning settings
#[derive(Debug, Clone)]
pub struct TransferLearningSettings {
    /// Enable domain adaptation
    pub domain_adaptation: bool,
    /// Source domain weights (relative importance of each source domain)
    pub source_domain_weights: Vec<f64>,
    /// Transfer learning strategies
    pub strategies: Vec<TransferStrategy>,
    /// Domain similarity measures
    pub similarity_measures: Vec<SimilarityMeasure>,
    /// Enable progressive transfer
    pub progressive_transfer: bool,
}
/// Validation result for meta-learning
///
/// Non-generic (f64) validation summary; lower `validation_loss` is better
/// (the framework negates it to obtain a performance score).
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether validation passed
    pub is_valid: bool,
    /// Validation loss
    pub validation_loss: f64,
    /// Additional validation metrics, keyed by metric name
    pub metrics: HashMap<String, f64>,
}
/// Transfer learning result
#[derive(Debug, Clone)]
pub struct TransferLearningResult<T: Float + Debug + Send + Sync + 'static> {
    /// How efficiently knowledge transferred to the target domain
    pub transfer_efficiency: T,
    /// Domain adaptation score
    pub domain_adaptation_score: T,
    /// How well source-task performance was retained
    pub source_task_retention: T,
    /// Performance achieved on the target tasks
    pub target_task_performance: T,
}
/// Adaptation strategies
///
/// How meta-learned parameters are specialized to a new task.
#[derive(Debug, Clone, Copy)]
pub enum AdaptationStrategy {
    /// Fine-tuning all parameters
    FullFineTuning,
    /// Fine-tuning only specific layers
    LayerWiseFineTuning,
    /// Parameter-efficient adaptation
    ParameterEfficient,
    /// Adaptation via learned learning rates
    LearnedLearningRates,
    /// Gradient-based adaptation
    GradientBased,
    /// Memory-based adaptation
    MemoryBased,
    /// Attention-based adaptation
    AttentionBased,
    /// Modular adaptation
    ModularAdaptation,
}
/// Meta-task representation
///
/// A single few-shot task: a support set to adapt on and a query set to
/// evaluate on, plus descriptive metadata.
#[derive(Debug, Clone)]
pub struct MetaTask<T: Float + Debug + Send + Sync + 'static> {
    /// Task identifier
    pub id: String,
    /// Support set (training data for adaptation)
    pub support_set: TaskDataset<T>,
    /// Query set (test data for evaluation)
    pub query_set: TaskDataset<T>,
    /// Task metadata
    pub metadata: TaskMetadata,
    /// Task difficulty (scale/range not specified here — TODO confirm)
    pub difficulty: T,
    /// Task domain
    pub domain: String,
    /// Task type
    pub task_type: TaskType,
}
/// Few-shot learner
///
/// Specializes meta-parameters from a small support set. `_phantom` ties
/// the element type `T` without storing any `T` values.
pub struct FewShotLearner<T: Float + Debug + Send + Sync + 'static> {
    // Owned copy of the few-shot settings.
    settings: FewShotSettings,
    _phantom: std::marker::PhantomData<T>,
}
542impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> FewShotLearner<T> {
543    pub fn new(settings: &FewShotSettings) -> Result<Self> {
544        Ok(Self {
545            settings: settings.clone(),
546            _phantom: std::marker::PhantomData,
547        })
548    }
549    pub fn learn(
550        &mut self,
551        _support_set: &TaskDataset<T>,
552        _query_set: &TaskDataset<T>,
553        _meta_parameters: &HashMap<String, Array1<T>>,
554    ) -> Result<FewShotResult<T>> {
555        Ok(FewShotResult {
556            accuracy: T::from(0.8).unwrap_or_default(),
557            confidence: T::from(0.9).unwrap_or_default(),
558            adaptation_steps: 5,
559            uncertainty_estimates: vec![T::from(0.1).unwrap_or_default(); 10],
560        })
561    }
562    pub fn average_performance(&self) -> T {
563        T::from(0.8).unwrap_or_default()
564    }
565}
/// Reverse mode automatic differentiation
///
/// Records operations on a tape during the forward pass, then accumulates
/// adjoints (gradients) in a backward sweep.
#[derive(Debug)]
pub struct ReverseModeAD<T: Float + Debug + Send + Sync + 'static> {
    /// Computational tape (recorded operations)
    tape: Vec<TapeEntry<T>>,
    /// Adjoint values, keyed by variable ID
    adjoints: HashMap<usize, T>,
    /// Gradient accumulator
    gradient_accumulator: Array1<T>,
}
576impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> ReverseModeAD<T> {
577    /// Create a new reverse mode AD engine
578    pub fn new() -> Result<Self> {
579        Ok(Self {
580            tape: Vec::new(),
581            adjoints: HashMap::new(),
582            gradient_accumulator: Array1::zeros(1),
583        })
584    }
585}
/// Meta-optimization tracker
///
/// Lightweight progress tracker; currently only counts recorded epochs.
/// `_phantom` ties the element type `T` without storing any `T` values.
pub struct MetaOptimizationTracker<T: Float + Debug + Send + Sync + 'static> {
    // Number of epochs recorded via `record_epoch`.
    step_count: usize,
    _phantom: std::marker::PhantomData<T>,
}
591impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> MetaOptimizationTracker<T> {
592    pub fn new() -> Self {
593        Self {
594            step_count: 0,
595            _phantom: std::marker::PhantomData,
596        }
597    }
598    pub fn record_epoch(
599        &mut self,
600        _epoch: usize,
601        _training_result: &TrainingResult,
602        _validation_result: &ValidationResult,
603    ) -> Result<()> {
604        self.step_count += 1;
605        Ok(())
606    }
607    pub fn update_best_parameters(&mut self, _metaparameters: &MetaParameters<T>) -> Result<()> {
608        Ok(())
609    }
610    pub fn total_tasks_seen(&self) -> usize {
611        self.step_count * 10
612    }
613    pub fn adaptation_efficiency(&self) -> T {
614        T::from(0.9).unwrap_or_default()
615    }
616}
/// Continual learning result
#[derive(Debug, Clone)]
pub struct ContinualLearningResult<T: Float + Debug + Send + Sync + 'static> {
    /// Per-task results, in the order the tasks were learned
    pub sequence_results: Vec<TaskResult<T>>,
    /// How much earlier-task performance was forgotten (lower is better)
    pub forgetting_measure: T,
    /// How efficiently the system adapted across the sequence
    pub adaptation_efficiency: T,
}
/// Anti-forgetting strategies
///
/// Techniques for mitigating catastrophic forgetting in continual learning.
#[derive(Debug, Clone, Copy)]
pub enum AntiForgettingStrategy {
    /// Elastic Weight Consolidation (EWC)
    ElasticWeightConsolidation,
    /// Synaptic Intelligence (SI)
    SynapticIntelligence,
    /// Experience/memory replay
    MemoryReplay,
    /// Progressive neural networks
    ProgressiveNetworks,
    PackNet,
    Piggyback,
    /// Hard Attention to the Task (HAT)
    HAT,
}
/// Training result for meta-learning
///
/// Non-generic (f64) training summary used by the tracker.
#[derive(Debug, Clone)]
pub struct TrainingResult {
    /// Training loss
    pub training_loss: f64,
    /// Training metrics, keyed by metric name
    pub metrics: HashMap<String, f64>,
    /// Number of training steps
    pub steps: usize,
}
/// Meta-training results
///
/// Final output of a full meta-training run (see
/// `MetaLearningFramework::meta_train`).
#[derive(Debug, Clone)]
pub struct MetaTrainingResults<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-parameters at the end of training
    pub final_parameters: HashMap<String, Array1<T>>,
    /// Per-epoch history, in order
    pub training_history: Vec<MetaTrainingEpoch<T>>,
    /// Best (highest) validation performance observed
    pub best_performance: T,
    /// Number of epochs actually run (may be fewer than requested if
    /// early stopping triggered)
    pub total_epochs: usize,
}
/// Meta-Learning Framework for Learned Optimizers
///
/// Top-level facade bundling all meta-learning subsystems behind one
/// configuration.
pub struct MetaLearningFramework<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-learning configuration
    config: MetaLearningConfig,
    /// Meta-learner implementation (algorithm selected from the config)
    meta_learner: Box<dyn MetaLearner<T> + Send + Sync>,
    /// Task distribution manager
    task_manager: TaskDistributionManager<T>,
    /// Meta-validation system
    meta_validator: MetaValidator<T>,
    /// Adaptation engine
    adaptation_engine: AdaptationEngine<T>,
    /// Transfer learning manager
    transfer_manager: TransferLearningManager<T>,
    /// Continual learning system
    continual_learner: ContinualLearningSystem<T>,
    /// Multi-task coordinator
    multitask_coordinator: MultiTaskCoordinator<T>,
    /// Meta-optimization tracker
    meta_tracker: MetaOptimizationTracker<T>,
    /// Few-shot learning specialist
    few_shot_learner: FewShotLearner<T>,
}
676impl<
677        T: Float
678            + Default
679            + Clone
680            + Send
681            + Sync
682            + std::iter::Sum
683            + for<'a> std::iter::Sum<&'a T>
684            + scirs2_core::ndarray::ScalarOperand
685            + std::fmt::Debug,
686    > MetaLearningFramework<T>
687{
688    /// Create a new meta-learning framework
689    pub fn new(config: MetaLearningConfig) -> Result<Self> {
690        let meta_learner = Self::create_meta_learner(&config)?;
691        let task_manager = TaskDistributionManager::new(&config)?;
692        let meta_validator = MetaValidator::new(&config)?;
693        let adaptation_engine = AdaptationEngine::new(&config)?;
694        let transfer_manager = TransferLearningManager::new(&config.transfer_settings)?;
695        let continual_learner = ContinualLearningSystem::new(&config.continual_settings)?;
696        let multitask_coordinator = MultiTaskCoordinator::new(&config.multitask_settings)?;
697        let meta_tracker = MetaOptimizationTracker::new();
698        let few_shot_learner = FewShotLearner::new(&config.few_shot_settings)?;
699        Ok(Self {
700            config,
701            meta_learner,
702            task_manager,
703            meta_validator,
704            adaptation_engine,
705            transfer_manager,
706            continual_learner,
707            multitask_coordinator,
708            meta_tracker,
709            few_shot_learner,
710        })
711    }
712    fn create_meta_learner(
713        config: &MetaLearningConfig,
714    ) -> Result<Box<dyn MetaLearner<T> + Send + Sync>> {
715        match config.algorithm {
716            MetaLearningAlgorithm::MAML => {
717                let maml_config = MAMLConfig {
718                    second_order: config.second_order,
719                    inner_lr: scirs2_core::numeric::NumCast::from(config.inner_learning_rate)
720                        .unwrap_or_else(|| T::zero()),
721                    outer_lr: scirs2_core::numeric::NumCast::from(config.meta_learning_rate)
722                        .unwrap_or_else(|| T::zero()),
723                    inner_steps: config.inner_steps,
724                    allow_unused: true,
725                    gradient_clip: Some(config.gradient_clip),
726                };
727                Ok(Box::new(MAMLLearner::<T, scirs2_core::ndarray::Ix1>::new(
728                    maml_config,
729                )?))
730            }
731            _ => {
732                let maml_config = MAMLConfig {
733                    second_order: false,
734                    inner_lr: scirs2_core::numeric::NumCast::from(config.inner_learning_rate)
735                        .unwrap_or_else(|| T::zero()),
736                    outer_lr: scirs2_core::numeric::NumCast::from(config.meta_learning_rate)
737                        .unwrap_or_else(|| T::zero()),
738                    inner_steps: config.inner_steps,
739                    allow_unused: true,
740                    gradient_clip: Some(config.gradient_clip),
741                };
742                Ok(Box::new(MAMLLearner::<T, scirs2_core::ndarray::Ix1>::new(
743                    maml_config,
744                )?))
745            }
746        }
747    }
748    /// Perform meta-training
749    pub async fn meta_train(
750        &mut self,
751        tasks: Vec<MetaTask<T>>,
752        num_epochs: usize,
753    ) -> Result<MetaTrainingResults<T>> {
754        let meta_params_raw = self.initialize_meta_parameters()?;
755        let mut meta_parameters = MetaParameters {
756            parameters: meta_params_raw,
757            metadata: HashMap::new(),
758        };
759        let mut training_history = Vec::new();
760        let mut best_performance = T::neg_infinity();
761        for epoch in 0..num_epochs {
762            let task_batch = self
763                .task_manager
764                .sample_task_batch(&tasks, self.config.task_batch_size)?;
765            let training_result = self
766                .meta_learner
767                .meta_train_step(&task_batch, &mut meta_parameters.parameters)?;
768            self.update_meta_parameters(
769                &mut meta_parameters.parameters,
770                &training_result.meta_gradients,
771            )?;
772            let validation_result = self.meta_validator.validate(&meta_parameters, &tasks)?;
773            let training_result_simple = TrainingResult {
774                training_loss: training_result.meta_loss.to_f64().unwrap_or(0.0),
775                metrics: HashMap::new(),
776                steps: epoch,
777            };
778            self.meta_tracker
779                .record_epoch(epoch, &training_result_simple, &validation_result)?;
780            let current_performance =
781                T::from(-validation_result.validation_loss).unwrap_or_default();
782            if current_performance > best_performance {
783                best_performance = current_performance;
784                self.meta_tracker.update_best_parameters(&meta_parameters)?;
785            }
786            let meta_validation_result = MetaValidationResult {
787                performance: current_performance,
788                adaptation_speed: T::from(0.0).unwrap_or_default(),
789                generalization_gap: T::from(validation_result.validation_loss).unwrap_or_default(),
790                task_specific_metrics: HashMap::new(),
791            };
792            training_history.push(MetaTrainingEpoch {
793                epoch,
794                training_result,
795                validation_result: meta_validation_result,
796                meta_parameters: meta_parameters.parameters.clone(),
797            });
798            if self.should_early_stop(&training_history) {
799                break;
800            }
801        }
802        let total_epochs = training_history.len();
803        Ok(MetaTrainingResults {
804            final_parameters: meta_parameters.parameters,
805            training_history,
806            best_performance,
807            total_epochs,
808        })
809    }
810    /// Adapt to new task
811    pub fn adapt_to_task(
812        &mut self,
813        task: &MetaTask<T>,
814        meta_parameters: &HashMap<String, Array1<T>>,
815    ) -> Result<TaskAdaptationResult<T>> {
816        self.adaptation_engine.adapt(
817            task,
818            meta_parameters,
819            &mut *self.meta_learner,
820            self.config.inner_steps,
821        )
822    }
823    /// Perform few-shot learning
824    pub fn few_shot_learning(
825        &mut self,
826        support_set: &TaskDataset<T>,
827        query_set: &TaskDataset<T>,
828        meta_parameters: &HashMap<String, Array1<T>>,
829    ) -> Result<FewShotResult<T>> {
830        self.few_shot_learner
831            .learn(support_set, query_set, meta_parameters)
832    }
833    /// Transfer learning to new domain
834    pub fn transfer_to_domain(
835        &mut self,
836        source_tasks: &[MetaTask<T>],
837        target_tasks: &[MetaTask<T>],
838        meta_parameters: &HashMap<String, Array1<T>>,
839    ) -> Result<TransferLearningResult<T>> {
840        self.transfer_manager
841            .transfer(source_tasks, target_tasks, meta_parameters)
842    }
843    /// Continual learning across task sequence
844    pub fn continual_learning(
845        &mut self,
846        task_sequence: &[MetaTask<T>],
847        meta_parameters: &mut HashMap<String, Array1<T>>,
848    ) -> Result<ContinualLearningResult<T>> {
849        self.continual_learner
850            .learn_sequence(task_sequence, meta_parameters)
851    }
852    /// Multi-task learning
853    pub fn multi_task_learning(
854        &mut self,
855        tasks: &[MetaTask<T>],
856        meta_parameters: &mut HashMap<String, Array1<T>>,
857    ) -> Result<MultiTaskResult<T>> {
858        self.multitask_coordinator
859            .learn_simultaneously(tasks, meta_parameters)
860    }
861    fn initialize_meta_parameters(&self) -> Result<HashMap<String, Array1<T>>> {
862        let mut parameters = HashMap::new();
863        parameters.insert("lstm_weights".to_string(), Array1::zeros(256 * 4));
864        parameters.insert("output_weights".to_string(), Array1::zeros(256));
865        Ok(parameters)
866    }
867    fn update_meta_parameters(
868        &self,
869        meta_parameters: &mut HashMap<String, Array1<T>>,
870        meta_gradients: &HashMap<String, Array1<T>>,
871    ) -> Result<()> {
872        let meta_lr = scirs2_core::numeric::NumCast::from(self.config.meta_learning_rate)
873            .unwrap_or_else(|| T::zero());
874        for (name, gradient) in meta_gradients {
875            if let Some(parameter) = meta_parameters.get_mut(name) {
876                for i in 0..parameter.len() {
877                    parameter[i] = parameter[i] - meta_lr * gradient[i];
878                }
879            }
880        }
881        Ok(())
882    }
883    fn should_early_stop(&self, history: &[MetaTrainingEpoch<T>]) -> bool {
884        if history.len() < 10 {
885            return false;
886        }
887        let recent_performances: Vec<_> = history
888            .iter()
889            .rev()
890            .take(5)
891            .map(|epoch| epoch.validation_result.performance)
892            .collect();
893        let max_recent = recent_performances
894            .iter()
895            .fold(T::neg_infinity(), |a, &b| a.max(b));
896        let min_recent = recent_performances
897            .iter()
898            .fold(T::infinity(), |a, &b| a.min(b));
899        let performance_range = max_recent - min_recent;
900        let threshold = scirs2_core::numeric::NumCast::from(1e-4).unwrap_or_else(|| T::zero());
901        performance_range < threshold
902    }
    /// Get meta-learning statistics
    ///
    /// Aggregates a point-in-time snapshot of metrics from each
    /// sub-component (meta tracker, transfer manager, continual learner,
    /// multi-task coordinator and few-shot learner) into one report struct.
    pub fn get_meta_learning_statistics(&self) -> MetaLearningStatistics<T> {
        MetaLearningStatistics {
            algorithm: self.config.algorithm,
            total_tasks_seen: self.meta_tracker.total_tasks_seen(),
            adaptation_efficiency: self.meta_tracker.adaptation_efficiency(),
            transfer_success_rate: self.transfer_manager.success_rate(),
            forgetting_measure: self.continual_learner.forgetting_measure(),
            multitask_interference: self.multitask_coordinator.interference_measure(),
            few_shot_performance: self.few_shot_learner.average_performance(),
        }
    }
915}
/// Gradient computation methods
#[derive(Debug, Clone, Copy)]
pub enum GradientComputationMethod {
    /// Numerical approximation via finite differencing.
    FiniteDifference,
    /// Exact gradients via automatic differentiation.
    AutomaticDifferentiation,
    /// Gradients derived symbolically from a closed-form expression.
    SymbolicDifferentiation,
    /// A combination of the above methods.
    Hybrid,
}
/// Dataset metadata
///
/// Descriptive properties attached to a task's dataset; carries no sample
/// data itself.
#[derive(Debug, Clone)]
pub struct DatasetMetadata {
    /// Number of samples
    pub num_samples: usize,
    /// Feature dimension
    pub feature_dim: usize,
    /// Data distribution type (free-form label, e.g. for logging)
    pub distribution_type: String,
    /// Noise level
    pub noise_level: f64,
}
/// Transfer learning manager
pub struct TransferLearningManager<T: Float + Debug + Send + Sync + 'static> {
    /// Transfer-learning configuration this manager was created with.
    settings: TransferLearningSettings,
    /// Marks the element type `T` without storing any `T` values.
    _phantom: std::marker::PhantomData<T>,
}
941impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> TransferLearningManager<T> {
942    pub fn new(settings: &TransferLearningSettings) -> Result<Self> {
943        Ok(Self {
944            settings: settings.clone(),
945            _phantom: std::marker::PhantomData,
946        })
947    }
948    pub fn transfer(
949        &mut self,
950        _source_tasks: &[MetaTask<T>],
951        _target_tasks: &[MetaTask<T>],
952        _meta_parameters: &HashMap<String, Array1<T>>,
953    ) -> Result<TransferLearningResult<T>> {
954        Ok(TransferLearningResult {
955            transfer_efficiency: T::from(0.85).unwrap_or_default(),
956            domain_adaptation_score: T::from(0.8).unwrap_or_default(),
957            source_task_retention: T::from(0.9).unwrap_or_default(),
958            target_task_performance: T::from(0.8).unwrap_or_default(),
959        })
960    }
961    pub fn success_rate(&self) -> T {
962        T::from(0.85).unwrap_or_default()
963    }
964}
/// Second-order gradient engine
///
/// Bundles the pieces needed for second-order meta-gradient computation:
/// a Hessian estimate, a Hessian-vector-product engine and a curvature
/// estimator.
#[derive(Debug)]
pub struct SecondOrderGradientEngine<T: Float + Debug + Send + Sync + 'static> {
    /// Hessian computation method
    hessian_method: HessianComputationMethod,
    /// Hessian matrix (starts as a 1x1 zero placeholder; see `new`)
    hessian: Array2<T>,
    /// Hessian-vector product engine
    hvp_engine: HessianVectorProductEngine<T>,
    /// Curvature estimation
    curvature_estimator: CurvatureEstimator<T>,
}
impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> SecondOrderGradientEngine<T> {
    /// Create a new second-order gradient engine
    ///
    /// Defaults to BFGS-style Hessian estimation. The Hessian starts as a
    /// 1x1 zero matrix — presumably resized on first real use; TODO confirm.
    pub fn new() -> Result<Self> {
        Ok(Self {
            hessian_method: HessianComputationMethod::BFGS,
            hessian: Array2::zeros((1, 1)),
            hvp_engine: HessianVectorProductEngine::new()?,
            curvature_estimator: CurvatureEstimator::new()?,
        })
    }
}
/// Meta-learning statistics
///
/// Snapshot report assembled by `get_meta_learning_statistics`.
#[derive(Debug, Clone)]
pub struct MetaLearningStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-learning algorithm in use.
    pub algorithm: MetaLearningAlgorithm,
    /// Total number of tasks processed so far.
    pub total_tasks_seen: usize,
    /// Adaptation efficiency reported by the meta tracker.
    pub adaptation_efficiency: T,
    /// Success rate reported by the transfer manager.
    pub transfer_success_rate: T,
    /// Forgetting measure reported by the continual learner.
    pub forgetting_measure: T,
    /// Interference measure reported by the multi-task coordinator.
    pub multitask_interference: T,
    /// Average performance reported by the few-shot learner.
    pub few_shot_performance: T,
}
/// Task adaptation result
///
/// Output of `AdaptationEngine::adapt` for a single task.
#[derive(Debug, Clone)]
pub struct TaskAdaptationResult<T: Float + Debug + Send + Sync + 'static> {
    /// Adapted parameters
    pub adapted_parameters: HashMap<String, Array1<T>>,
    /// Adaptation trajectory (one entry per inner-loop step)
    pub adaptation_trajectory: Vec<AdaptationStep<T>>,
    /// Final adaptation loss
    pub final_loss: T,
    /// Adaptation metrics
    pub metrics: TaskAdaptationMetrics<T>,
}
/// Transfer strategies
#[derive(Debug, Clone, Copy)]
pub enum TransferStrategy {
    /// Reuse learned features, training only a new head.
    FeatureExtraction,
    /// Continue training all parameters on the target tasks.
    FineTuning,
    /// Align source and target domains during transfer.
    DomainAdaptation,
    /// Train on source and target tasks jointly.
    MultiTask,
    /// Transfer via meta-learned initialization.
    MetaTransfer,
    /// Transfer progressively through intermediate stages.
    Progressive,
}
/// Augmentation strategies
#[derive(Debug, Clone, Copy)]
pub enum AugmentationStrategy {
    /// Geometric transforms (e.g. flips, rotations).
    Geometric,
    /// Color-space perturbations.
    Color,
    /// Additive noise injection.
    Noise,
    /// Mixup-style interpolation between sample pairs.
    Mixup,
    /// CutMix-style patch mixing between samples.
    CutMix,
    /// Augmentation policy learned from data.
    Learned,
}
/// Stability metrics
///
/// Stability-related measurements gathered during adaptation.
#[derive(Debug, Clone)]
pub struct StabilityMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Parameter stability
    pub parameter_stability: T,
    /// Performance stability
    pub performance_stability: T,
    /// Gradient stability
    pub gradient_stability: T,
    /// Catastrophic forgetting measure
    pub forgetting_measure: T,
}
/// Gradient computation engine
///
/// Owns the computation graph, a per-name gradient cache and an autodiff
/// backend used to produce first-order gradients.
#[derive(Debug)]
pub struct GradientComputationEngine<T: Float + Debug + Send + Sync + 'static> {
    /// Gradient computation method
    method: GradientComputationMethod,
    /// Computational graph
    computation_graph: ComputationGraph<T>,
    /// Gradient cache (keyed by parameter name)
    gradient_cache: HashMap<String, Array1<T>>,
    /// Automatic differentiation engine
    autodiff_engine: AutoDiffEngine<T>,
}
impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> GradientComputationEngine<T> {
    /// Create a new gradient computation engine
    ///
    /// Defaults to automatic differentiation with an empty graph and cache.
    pub fn new() -> Result<Self> {
        Ok(Self {
            method: GradientComputationMethod::AutomaticDifferentiation,
            computation_graph: ComputationGraph::new()?,
            gradient_cache: HashMap::new(),
            autodiff_engine: AutoDiffEngine::new()?,
        })
    }
}
/// Task result for meta-learning
#[derive(Debug, Clone)]
pub struct TaskResult<T: Float + Debug + Send + Sync + 'static> {
    /// Identifier of the task this result belongs to.
    pub task_id: String,
    /// Loss achieved on the task.
    pub loss: T,
    /// Additional named metrics for the task.
    pub metrics: HashMap<String, T>,
}
/// Multi-task coordinator
pub struct MultiTaskCoordinator<T: Float + Debug + Send + Sync + 'static> {
    /// Multi-task configuration this coordinator was created with.
    settings: MultiTaskSettings,
    /// Marks the element type `T` without storing any `T` values.
    _phantom: std::marker::PhantomData<T>,
}
1078impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> MultiTaskCoordinator<T> {
1079    pub fn new(settings: &MultiTaskSettings) -> Result<Self> {
1080        Ok(Self {
1081            settings: settings.clone(),
1082            _phantom: std::marker::PhantomData,
1083        })
1084    }
1085    pub fn learn_simultaneously(
1086        &mut self,
1087        tasks: &[MetaTask<T>],
1088        _meta_parameters: &mut HashMap<String, Array1<T>>,
1089    ) -> Result<MultiTaskResult<T>> {
1090        let mut task_results = Vec::new();
1091        for task in tasks {
1092            let task_result = TaskResult {
1093                task_id: task.id.clone(),
1094                loss: scirs2_core::numeric::NumCast::from(0.1).unwrap_or_else(|| T::zero()),
1095                metrics: HashMap::new(),
1096            };
1097            task_results.push(task_result);
1098        }
1099        Ok(MultiTaskResult {
1100            task_results,
1101            coordination_overhead: scirs2_core::numeric::NumCast::from(0.01)
1102                .unwrap_or_else(|| T::zero()),
1103            convergence_status: "converged".to_string(),
1104        })
1105    }
1106    pub fn interference_measure(&self) -> T {
1107        T::from(0.1).unwrap_or_default()
1108    }
1109}
/// HVP computation methods
#[derive(Debug, Clone, Copy)]
pub enum HVPComputationMethod {
    /// Approximate the product via finite differencing of gradients.
    FiniteDifference,
    /// Compute the product exactly with automatic differentiation.
    AutomaticDifferentiation,
    /// Conjugate-gradient based computation.
    ConjugateGradient,
}
/// Computation graph node
///
/// One node of the computation graph: an operation, the IDs of its input
/// nodes, and (once evaluated) its output and gradient.
#[derive(Debug, Clone)]
pub struct ComputationNode<T: Float + Debug + Send + Sync + 'static> {
    /// Node ID
    pub id: usize,
    /// Operation type
    pub operation: ComputationOperation<T>,
    /// Input connections (IDs of upstream nodes)
    pub inputs: Vec<usize>,
    /// Output value (`None` until the node has been evaluated)
    pub output: Option<Array1<T>>,
    /// Gradient w.r.t. this node (`None` until backpropagation runs)
    pub gradient: Option<Array1<T>>,
}
/// Meta-learning algorithms
#[derive(Debug, Clone, Copy)]
pub enum MetaLearningAlgorithm {
    /// Model-Agnostic Meta-Learning
    MAML,
    /// First-Order MAML (drops second-order terms for efficiency)
    FOMAML,
    /// Reptile algorithm
    Reptile,
    /// Meta-SGD (learns per-parameter inner learning rates)
    MetaSGD,
    /// Learning to Learn by Gradient Descent
    L2L,
    /// Gradient-Based Meta-Learning
    GBML,
    /// Meta-Learning with Implicit Gradients (iMAML)
    IMaml,
    /// Prototypical Networks
    ProtoNet,
    /// Matching Networks
    MatchingNet,
    /// Relation Networks
    RelationNet,
    /// Memory-Augmented Neural Networks
    MANN,
    /// Meta-Learning with Warped Gradient Descent
    WarpGrad,
    /// Learned Gradient Descent
    LearnedGD,
}
/// Gradient balancing methods
#[derive(Debug, Clone, Copy)]
pub enum GradientBalancingMethod {
    /// Equal weight for every task's gradient.
    Uniform,
    /// GradNorm-style balancing of gradient magnitudes.
    GradNorm,
    /// PCGrad-style projection of conflicting gradients.
    PCGrad,
    /// Conflict-averse gradient descent.
    CAGrad,
    /// Nash bargaining based multi-task balancing.
    NashMTL,
}
/// Meta-validation system for meta-learning
pub struct MetaValidator<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-learning configuration this validator was created with.
    config: MetaLearningConfig,
    /// Marks the element type `T` without storing any `T` values.
    _phantom: std::marker::PhantomData<T>,
}
1175impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> MetaValidator<T> {
1176    pub fn new(config: &MetaLearningConfig) -> Result<Self> {
1177        Ok(Self {
1178            config: config.clone(),
1179            _phantom: std::marker::PhantomData,
1180        })
1181    }
1182    pub fn validate(
1183        &self,
1184        _meta_parameters: &MetaParameters<T>,
1185        _tasks: &[MetaTask<T>],
1186    ) -> Result<ValidationResult> {
1187        Ok(ValidationResult {
1188            is_valid: true,
1189            validation_loss: 0.5,
1190            metrics: std::collections::HashMap::new(),
1191        })
1192    }
1193}
/// Query evaluation result
///
/// Outcome of evaluating an adapted model on a task's query set.
#[derive(Debug, Clone)]
pub struct QueryEvaluationResult<T: Float + Debug + Send + Sync + 'static> {
    /// Query set loss
    pub query_loss: T,
    /// Prediction accuracy
    pub accuracy: T,
    /// Per-sample predictions
    pub predictions: Vec<T>,
    /// Confidence scores (presumably parallel to `predictions` — confirm)
    pub confidence_scores: Vec<T>,
    /// Evaluation metrics
    pub metrics: QueryEvaluationMetrics<T>,
}
/// Few-shot learning algorithms
#[derive(Debug, Clone, Copy)]
pub enum FewShotAlgorithm {
    /// Prototypical networks (class prototypes in embedding space).
    Prototypical,
    /// Matching networks.
    Matching,
    /// Relation networks.
    Relation,
    /// MAML-style gradient adaptation.
    MAML,
    /// Reptile-style adaptation.
    Reptile,
    /// MetaOptNet (convex base learner).
    MetaOptNet,
}
/// Activation functions
#[derive(Debug, Clone, Copy)]
pub enum ActivationFunction {
    /// Rectified linear unit.
    ReLU,
    /// Logistic sigmoid.
    Sigmoid,
    /// Hyperbolic tangent.
    Tanh,
    /// Softmax normalization.
    Softmax,
    /// Gaussian error linear unit.
    GELU,
}
/// Curvature estimation methods
#[derive(Debug, Clone, Copy)]
pub enum CurvatureEstimationMethod {
    /// Diagonal approximation of the Hessian.
    DiagonalHessian,
    /// Block-diagonal approximation of the Hessian.
    BlockDiagonalHessian,
    /// Kronecker-factored approximation (K-FAC style).
    KroneckerFactored,
    /// Natural-gradient (Fisher information) based curvature.
    NaturalGradient,
}
/// Mixed mode automatic differentiation
///
/// Combines forward- and reverse-mode engines and picks between them
/// according to `mode_selection`.
#[derive(Debug)]
pub struct MixedModeAD<T: Float + Debug + Send + Sync + 'static> {
    /// Forward mode component
    forward_component: ForwardModeAD<T>,
    /// Reverse mode component
    reverse_component: ReverseModeAD<T>,
    /// Mode selection strategy
    mode_selection: ModeSelectionStrategy,
}
impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> MixedModeAD<T> {
    /// Create a new mixed mode AD engine
    ///
    /// Defaults to the adaptive mode-selection strategy.
    pub fn new() -> Result<Self> {
        Ok(Self {
            forward_component: ForwardModeAD::new()?,
            reverse_component: ReverseModeAD::new()?,
            mode_selection: ModeSelectionStrategy::Adaptive,
        })
    }
}
/// Adaptation statistics
///
/// Per-task adaptation outcomes collected across a batch of tasks; the two
/// vectors are presumably indexed by task — confirm against the producer.
#[derive(Debug, Clone)]
pub struct AdaptationStatistics<T: Float + Debug + Send + Sync + 'static> {
    /// Convergence steps per task
    pub convergence_steps: Vec<usize>,
    /// Final losses per task
    pub final_losses: Vec<T>,
    /// Adaptation efficiency
    pub adaptation_efficiency: T,
    /// Stability metrics
    pub stability_metrics: StabilityMetrics<T>,
}
/// Meta-validation result
#[derive(Debug, Clone)]
pub struct MetaValidationResult<T: Float + Debug + Send + Sync + 'static> {
    /// Overall validation performance.
    pub performance: T,
    /// How quickly adaptation converged during validation.
    pub adaptation_speed: T,
    /// Gap between training and validation performance.
    pub generalization_gap: T,
    /// Additional metrics keyed by task.
    pub task_specific_metrics: HashMap<String, T>,
}
/// Loss functions
#[derive(Debug, Clone, Copy)]
pub enum LossFunction {
    /// Mean squared error (regression).
    MeanSquaredError,
    /// Cross-entropy (classification).
    CrossEntropy,
    /// Hinge loss.
    Hinge,
    /// Huber loss (robust regression).
    Huber,
}
/// Adaptation engine for meta-learning
pub struct AdaptationEngine<T: Float + Debug + Send + Sync + 'static> {
    /// Meta-learning configuration this engine was created with.
    config: MetaLearningConfig,
    /// Marks the element type `T` without storing any `T` values.
    _phantom: std::marker::PhantomData<T>,
}
1288impl<T: Float + Debug + Send + Sync + 'static + Default + Clone> AdaptationEngine<T> {
1289    pub fn new(config: &MetaLearningConfig) -> Result<Self> {
1290        Ok(Self {
1291            config: config.clone(),
1292            _phantom: std::marker::PhantomData,
1293        })
1294    }
1295    pub fn adapt(
1296        &mut self,
1297        task: &MetaTask<T>,
1298        _meta_parameters: &HashMap<String, Array1<T>>,
1299        _meta_learner: &mut dyn MetaLearner<T>,
1300        _inner_steps: usize,
1301    ) -> Result<TaskAdaptationResult<T>> {
1302        Ok(TaskAdaptationResult {
1303            adapted_parameters: _meta_parameters.clone(),
1304            adaptation_trajectory: Vec::new(),
1305            final_loss: T::from(0.1).unwrap_or_default(),
1306            metrics: TaskAdaptationMetrics {
1307                convergence_speed: T::from(1.0).unwrap_or_default(),
1308                final_performance: T::from(0.9).unwrap_or_default(),
1309                efficiency: T::from(0.8).unwrap_or_default(),
1310                robustness: T::from(0.85).unwrap_or_default(),
1311            },
1312        })
1313    }
1314}
/// Continual learning settings
#[derive(Debug, Clone)]
pub struct ContinualLearningSettings {
    /// Catastrophic forgetting mitigation
    pub anti_forgetting_strategies: Vec<AntiForgettingStrategy>,
    /// Memory replay settings
    pub memory_replay: MemoryReplaySettings,
    /// Task identification method
    pub task_identification: TaskIdentificationMethod,
    /// Plasticity-stability trade-off (presumably a weight in [0, 1];
    /// confirm the expected range with consumers of this field)
    pub plasticity_stability_balance: f64,
}
/// Task weighting strategies
#[derive(Debug, Clone, Copy)]
pub enum TaskWeightingStrategy {
    /// Equal weight for all tasks.
    Uniform,
    /// Weight by task uncertainty.
    UncertaintyBased,
    /// Weight by gradient magnitude.
    GradientMagnitude,
    /// Weight by task performance.
    PerformanceBased,
    /// Weights adjusted adaptively during training.
    Adaptive,
    /// Weights learned from data.
    Learned,
}
/// Multi-task settings
///
/// Configuration consumed by `MultiTaskCoordinator`.
#[derive(Debug, Clone)]
pub struct MultiTaskSettings {
    /// Task weighting strategy
    pub task_weighting: TaskWeightingStrategy,
    /// Gradient balancing method
    pub gradient_balancing: GradientBalancingMethod,
    /// Task interference mitigation
    pub interference_mitigation: InterferenceMitigationStrategy,
    /// Shared representation learning
    pub shared_representation: SharedRepresentationStrategy,
}
/// Shared representation strategies
#[derive(Debug, Clone, Copy)]
pub enum SharedRepresentationStrategy {
    /// All tasks share the same backbone parameters.
    HardSharing,
    /// Tasks keep separate parameters with soft coupling.
    SoftSharing,
    /// Sharing organized in a hierarchy.
    HierarchicalSharing,
    /// Sharing mediated by attention.
    AttentionBased,
    /// Sharing via composable modules.
    Modular,
}
/// Distance metrics
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance.
    Euclidean,
    /// Cosine distance.
    Cosine,
    /// Mahalanobis distance.
    Mahalanobis,
    /// Metric learned from data.
    Learned,
}
/// Dual number for forward mode AD
///
/// Represents `real + dual * eps` with `eps^2 = 0`; propagating duals
/// through arithmetic yields derivatives alongside values.
#[derive(Debug, Clone)]
pub struct DualNumber<T: Float + Debug + Send + Sync + 'static> {
    /// Real part
    pub real: T,
    /// Infinitesimal part
    pub dual: T,
}
/// Task adaptation metrics
///
/// Metrics attached to a single `TaskAdaptationResult`.
#[derive(Debug, Clone)]
pub struct TaskAdaptationMetrics<T: Float + Debug + Send + Sync + 'static> {
    /// Convergence speed
    pub convergence_speed: T,
    /// Final performance
    pub final_performance: T,
    /// Adaptation efficiency
    pub efficiency: T,
    /// Robustness to noise
    pub robustness: T,
}
/// Mode selection strategies
#[derive(Debug, Clone, Copy)]
pub enum ModeSelectionStrategy {
    /// Always use forward-mode AD.
    ForwardOnly,
    /// Always use reverse-mode AD.
    ReverseOnly,
    /// Choose the mode adaptively.
    Adaptive,
    /// Mix both modes.
    Hybrid,
}
/// Few-shot learning settings
///
/// Describes an N-way K-shot episode setup plus the algorithm and
/// augmentation choices used for few-shot learning.
#[derive(Debug, Clone)]
pub struct FewShotSettings {
    /// Number of shots (examples per class)
    pub num_shots: usize,
    /// Number of ways (classes)
    pub num_ways: usize,
    /// Few-shot algorithm
    pub algorithm: FewShotAlgorithm,
    /// Metric learning settings
    pub metric_learning: MetricLearningSettings,
    /// Augmentation strategies
    pub augmentation_strategies: Vec<AugmentationStrategy>,
}
/// Domain similarity measures
#[derive(Debug, Clone, Copy)]
pub enum SimilarityMeasure {
    /// Cosine distance between representations.
    CosineDistance,
    /// Kullback-Leibler divergence between distributions.
    KLDivergence,
    /// Wasserstein (earth mover's) distance.
    WassersteinDistance,
    /// Central moment discrepancy.
    CentralMomentDiscrepancy,
    /// Maximum mean discrepancy.
    MaximumMeanDiscrepancy,
}
/// Computation operations
///
/// Operation carried by a `ComputationNode` in the computation graph.
#[derive(Debug, Clone)]
pub enum ComputationOperation<T: Float + Debug + Send + Sync + 'static> {
    /// Addition of the node's inputs.
    Add,
    /// Multiplication of the node's inputs.
    Multiply,
    /// Matrix multiplication by the stored weight matrix.
    MatMul(Array2<T>),
    /// Element-wise activation.
    Activation(ActivationFunction),
    /// Loss computation.
    Loss(LossFunction),
    /// Trainable parameter vector.
    Parameter(Array1<T>),
    /// Graph input placeholder.
    Input,
}
/// Meta-learning configuration
///
/// Top-level configuration consumed by the meta-learning system and its
/// sub-components (adaptation engine, validator, transfer manager, etc.).
#[derive(Debug, Clone)]
pub struct MetaLearningConfig {
    /// Meta-learning algorithm
    pub algorithm: MetaLearningAlgorithm,
    /// Number of inner loop steps (per-task adaptation)
    pub inner_steps: usize,
    /// Number of outer loop steps (meta-updates)
    pub outer_steps: usize,
    /// Meta-learning rate (outer loop)
    pub meta_learning_rate: f64,
    /// Inner learning rate (per-task adaptation)
    pub inner_learning_rate: f64,
    /// Task batch size
    pub task_batch_size: usize,
    /// Support set size per task
    pub support_set_size: usize,
    /// Query set size per task
    pub query_set_size: usize,
    /// Enable second-order gradients
    pub second_order: bool,
    /// Gradient clipping threshold
    pub gradient_clip: f64,
    /// Adaptation strategies
    pub adaptation_strategies: Vec<AdaptationStrategy>,
    /// Transfer learning settings
    pub transfer_settings: TransferLearningSettings,
    /// Continual learning settings
    pub continual_settings: ContinualLearningSettings,
    /// Multi-task settings
    pub multitask_settings: MultiTaskSettings,
    /// Few-shot learning settings
    pub few_shot_settings: FewShotSettings,
    /// Enable meta-regularization
    pub enable_meta_regularization: bool,
    /// Meta-regularization strength
    pub meta_regularization_strength: f64,
    /// Task sampling strategy
    pub task_sampling_strategy: TaskSamplingStrategy,
}
/// Task sampling strategies
#[derive(Debug, Clone, Copy)]
pub enum TaskSamplingStrategy {
    /// Sample tasks uniformly at random.
    Uniform,
    /// Sample easy-to-hard (curriculum).
    Curriculum,
    /// Sample according to task difficulty.
    DifficultyBased,
    /// Sample to maximize task diversity.
    DiversityBased,
    /// Sample where the learner is most uncertain.
    ActiveLearning,
    /// Sample adversarially hard tasks.
    Adversarial,
}
/// Few-shot learning result
#[derive(Debug, Clone)]
pub struct FewShotResult<T: Float + Debug + Send + Sync + 'static> {
    /// Accuracy achieved on the few-shot episode.
    pub accuracy: T,
    /// Overall confidence in the predictions.
    pub confidence: T,
    /// Number of adaptation steps taken.
    pub adaptation_steps: usize,
    /// Per-prediction uncertainty estimates.
    pub uncertainty_estimates: Vec<T>,
}