// optirs_core/online_learning/mod.rs

//! Online learning and lifelong optimization.
//!
//! This module provides optimization strategies for continuous learning scenarios,
//! including online learning, continual learning, and lifelong optimization systems.

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Random};
use std::cmp::Ordering;
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;

13/// Online learning strategy
14#[derive(Debug, Clone)]
15pub enum OnlineLearningStrategy {
16    /// Stochastic Gradient Descent with adaptive learning rate
17    AdaptiveSGD {
18        /// Initial learning rate
19        initial_lr: f64,
20        /// Learning rate adaptation method
21        adaptation_method: LearningRateAdaptation,
22    },
23    /// Online Newton's method (second-order)
24    OnlineNewton {
25        /// Damping parameter for stability
26        damping: f64,
27        /// Window size for Hessian estimation
28        hessian_window: usize,
29    },
30    /// Follow The Regularized Leader (FTRL)
31    FTRL {
32        /// L1 regularization strength
33        l1_regularization: f64,
34        /// L2 regularization strength
35        l2_regularization: f64,
36        /// Learning rate power
37        learning_rate_power: f64,
38    },
39    /// Online Mirror Descent
40    MirrorDescent {
41        /// Mirror function type
42        mirror_function: MirrorFunction,
43        /// Regularization strength
44        regularization: f64,
45    },
46    /// Adaptive Multi-Task Learning
47    AdaptiveMultiTask {
48        /// Task similarity threshold
49        similarity_threshold: f64,
50        /// Task-specific learning rates
51        task_lr_adaptation: bool,
52    },
53}
54
/// Learning-rate adaptation methods for online learning.
///
/// All hyper-parameters are stored as `f64` and converted to the optimizer's scalar type
/// when an update is applied.
#[derive(Debug, Clone, PartialEq)]
pub enum LearningRateAdaptation {
    /// AdaGrad-style adaptation: scale each step by the accumulated squared gradients.
    AdaGrad {
        /// Small constant for numerical stability.
        epsilon: f64,
    },
    /// RMSprop-style adaptation: scale each step by a moving average of squared gradients.
    RMSprop {
        /// Decay rate of the moving average.
        decay: f64,
        /// Small constant for numerical stability.
        epsilon: f64,
    },
    /// Adam-style adaptation with bias-corrected first and second moments.
    Adam {
        /// Exponential decay rate for the first moment.
        beta1: f64,
        /// Exponential decay rate for the second moment.
        beta2: f64,
        /// Small constant for numerical stability.
        epsilon: f64,
    },
    /// Exponential decay of the learning rate.
    ExponentialDecay {
        /// Multiplicative decay factor.
        decay_rate: f64,
    },
    /// Inverse scaling of the learning rate with the step count.
    InverseScaling {
        /// Scaling power.
        power: f64,
    },
}

/// Mirror maps (distance-generating functions) for online mirror descent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MirrorFunction {
    /// Euclidean (L2) mirror map; reduces to plain gradient descent.
    Euclidean,
    /// Negative-entropy mirror map, for parameters on the probability simplex.
    Entropy,
    /// L1 regularization (soft thresholding).
    L1,
    /// Nuclear norm (for matrix-shaped parameters).
    Nuclear,
}

104/// Lifelong learning strategy
105#[derive(Debug, Clone)]
106pub enum LifelongStrategy {
107    /// Elastic Weight Consolidation (EWC)
108    ElasticWeightConsolidation {
109        /// Importance weight for previous tasks
110        importance_weight: f64,
111        /// Fisher information estimation samples
112        fisher_samples: usize,
113    },
114    /// Progressive Neural Networks
115    ProgressiveNetworks {
116        /// Lateral connection strength
117        lateral_strength: f64,
118        /// Column growth strategy
119        growth_strategy: ColumnGrowthStrategy,
120    },
121    /// Memory-Augmented Networks
122    MemoryAugmented {
123        /// Memory size
124        memory_size: usize,
125        /// Memory update strategy
126        update_strategy: MemoryUpdateStrategy,
127    },
128    /// Meta-Learning Based Continual Learning
129    MetaLearning {
130        /// Meta-learning rate
131        meta_lr: f64,
132        /// Inner loop steps
133        inner_steps: usize,
134        /// Task embedding size
135        task_embedding_size: usize,
136    },
137    /// Gradient Episodic Memory (GEM)
138    GradientEpisodicMemory {
139        /// Memory buffer size per task
140        memory_per_task: usize,
141        /// Constraint violation tolerance
142        violation_tolerance: f64,
143    },
144}
145
/// Column growth strategies for progressive networks.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnGrowthStrategy {
    /// Add a new column for each task.
    PerTask,
    /// Add a new column when performance drops.
    PerformanceBased {
        /// Performance threshold.
        threshold: f64,
    },
    /// Add a new column after a fixed number of steps.
    FixedInterval {
        /// Interval length.
        interval: usize,
    },
}

/// Eviction strategies for the lifelong-learning memory buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemoryUpdateStrategy {
    /// First in, first out.
    FIFO,
    /// Evict a uniformly random example.
    Random,
    /// Evict the example with the lowest importance score.
    ImportanceBased,
    /// Keep a diverse set of gradients.
    GradientDiversity,
}

176/// Online optimizer that adapts to streaming data
177#[derive(Debug)]
178pub struct OnlineOptimizer<A: Float, D: Dimension> {
179    /// Online learning strategy
180    strategy: OnlineLearningStrategy,
181    /// Current parameters
182    parameters: Array<A, D>,
183    /// Accumulated gradients for adaptation
184    gradient_accumulator: Array<A, D>,
185    /// Second moment accumulator (for Adam-like methods)
186    second_moment_accumulator: Option<Array<A, D>>,
187    /// Current learning rate
188    current_lr: A,
189    /// Step counter
190    step_count: usize,
191    /// Performance history
192    performance_history: VecDeque<A>,
193    /// Regret bounds tracking
194    regret_bound: A,
195}
196
197/// Lifelong optimizer that learns continuously across tasks
198#[derive(Debug)]
199pub struct LifelongOptimizer<A: Float, D: Dimension> {
200    /// Lifelong learning strategy
201    strategy: LifelongStrategy,
202    /// Task-specific optimizers
203    task_optimizers: HashMap<String, OnlineOptimizer<A, D>>,
204    /// Shared knowledge across tasks
205    #[allow(dead_code)]
206    shared_knowledge: SharedKnowledge<A, D>,
207    /// Task sequence and relationships
208    task_graph: TaskGraph,
209    /// Memory buffer for important examples
210    memory_buffer: MemoryBuffer<A, D>,
211    /// Current active task
212    current_task: Option<String>,
213    /// Performance tracking across tasks
214    task_performance: HashMap<String, Vec<A>>,
215}
216
217/// Shared knowledge representation for lifelong learning
218#[derive(Debug)]
219pub struct SharedKnowledge<A: Float, D: Dimension> {
220    /// Fisher Information Matrix (for EWC)
221    #[allow(dead_code)]
222    fisher_information: Option<Array<A, D>>,
223    /// Important parameters (for EWC)
224    #[allow(dead_code)]
225    important_parameters: Option<Array<A, D>>,
226    /// Task embeddings
227    #[allow(dead_code)]
228    task_embeddings: HashMap<String, Array1<A>>,
229    /// Cross-task transfer weights
230    #[allow(dead_code)]
231    transfer_weights: HashMap<(String, String), A>,
232    /// Meta-parameters learned across tasks
233    #[allow(dead_code)]
234    meta_parameters: Option<Array1<A>>,
235}
236
/// Graph of relationships between tasks.
#[derive(Debug)]
pub struct TaskGraph {
    /// Pairwise task similarity scores keyed by `(task_a, task_b)`.
    task_similarities: HashMap<(String, String), f64>,
    /// Tasks each task depends on.
    // Not read anywhere yet; kept for future dependency tracking.
    #[allow(dead_code)]
    task_dependencies: HashMap<String, Vec<String>>,
    /// Cluster/category label for each task.
    // Not read anywhere yet; kept for future task clustering.
    #[allow(dead_code)]
    task_clusters: HashMap<String, String>,
}

250/// Memory buffer for important examples
251#[derive(Debug)]
252pub struct MemoryBuffer<A: Float, D: Dimension> {
253    /// Stored examples
254    examples: VecDeque<MemoryExample<A, D>>,
255    /// Maximum buffer size
256    max_size: usize,
257    /// Update strategy
258    update_strategy: MemoryUpdateStrategy,
259    /// Importance scores
260    importance_scores: VecDeque<A>,
261}
262
263/// Single memory example
264#[derive(Debug, Clone)]
265pub struct MemoryExample<A: Float, D: Dimension> {
266    /// Input data
267    pub input: Array<A, D>,
268    /// Target output
269    pub target: Array<A, D>,
270    /// Task identifier
271    pub task_id: String,
272    /// Importance score
273    pub importance: A,
274    /// Gradient information
275    pub gradient: Option<Array<A, D>>,
276}
277
278/// Online learning performance metrics
279#[derive(Debug, Clone)]
280pub struct OnlinePerformanceMetrics<A: Float> {
281    /// Cumulative regret
282    pub cumulative_regret: A,
283    /// Average loss over window
284    pub average_loss: A,
285    /// Learning rate stability
286    pub lr_stability: A,
287    /// Adaptation speed
288    pub adaptation_speed: A,
289    /// Memory efficiency
290    pub memory_efficiency: A,
291}
292
293impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
294    OnlineOptimizer<A, D>
295{
296    /// Create a new online optimizer
297    pub fn new(strategy: OnlineLearningStrategy, initial_parameters: Array<A, D>) -> Self {
298        let paramshape = initial_parameters.raw_dim();
299        let gradient_accumulator = Array::zeros(paramshape.clone());
300        let second_moment_accumulator = match &strategy {
301            OnlineLearningStrategy::AdaptiveSGD {
302                adaptation_method: LearningRateAdaptation::Adam { .. },
303                ..
304            } => Some(Array::zeros(paramshape)),
305            _ => None,
306        };
307
308        let current_lr = match &strategy {
309            OnlineLearningStrategy::AdaptiveSGD { initial_lr, .. } => A::from(*initial_lr).unwrap(),
310            OnlineLearningStrategy::OnlineNewton { .. } => A::from(0.01).unwrap(),
311            OnlineLearningStrategy::FTRL { .. } => A::from(0.1).unwrap(),
312            OnlineLearningStrategy::MirrorDescent { .. } => A::from(0.01).unwrap(),
313            OnlineLearningStrategy::AdaptiveMultiTask { .. } => A::from(0.001).unwrap(),
314        };
315
316        Self {
317            strategy,
318            parameters: initial_parameters,
319            gradient_accumulator,
320            second_moment_accumulator,
321            current_lr,
322            step_count: 0,
323            performance_history: VecDeque::new(),
324            regret_bound: A::zero(),
325        }
326    }
327
328    /// Perform online update with new gradient
329    pub fn online_update(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
330        self.step_count += 1;
331        self.performance_history.push_back(loss);
332
333        // Keep performance history bounded
334        if self.performance_history.len() > 1000 {
335            self.performance_history.pop_front();
336        }
337
338        match self.strategy.clone() {
339            OnlineLearningStrategy::AdaptiveSGD {
340                adaptation_method, ..
341            } => {
342                self.adaptive_sgd_update(gradient, &adaptation_method)?;
343            }
344            OnlineLearningStrategy::OnlineNewton { damping, .. } => {
345                self.online_newton_update(gradient, damping)?;
346            }
347            OnlineLearningStrategy::FTRL {
348                l1_regularization,
349                l2_regularization,
350                learning_rate_power,
351            } => {
352                self.ftrl_update(
353                    gradient,
354                    l1_regularization,
355                    l2_regularization,
356                    learning_rate_power,
357                )?;
358            }
359            OnlineLearningStrategy::MirrorDescent {
360                mirror_function,
361                regularization,
362            } => {
363                self.mirror_descent_update(gradient, &mirror_function, regularization)?;
364            }
365            OnlineLearningStrategy::AdaptiveMultiTask { .. } => {
366                self.adaptive_multitask_update(gradient)?;
367            }
368        }
369
370        // Update regret bound
371        self.update_regret_bound(loss);
372
373        Ok(())
374    }
375
376    /// Adaptive SGD update
377    fn adaptive_sgd_update(
378        &mut self,
379        gradient: &Array<A, D>,
380        adaptation: &LearningRateAdaptation,
381    ) -> Result<()> {
382        match adaptation {
383            LearningRateAdaptation::AdaGrad { epsilon } => {
384                // Accumulate squared gradients
385                self.gradient_accumulator = &self.gradient_accumulator + &gradient.mapv(|g| g * g);
386
387                // Compute adaptive learning rate
388                let adaptive_lr = self
389                    .gradient_accumulator
390                    .mapv(|acc| A::from(*epsilon).unwrap() + A::sqrt(acc));
391
392                // Update parameters
393                self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
394            }
395            LearningRateAdaptation::RMSprop { decay, epsilon } => {
396                let decay_factor = A::from(*decay).unwrap();
397                let one_minus_decay = A::one() - decay_factor;
398
399                // Update moving average of squared gradients
400                self.gradient_accumulator = &self.gradient_accumulator * decay_factor
401                    + &gradient.mapv(|g| g * g * one_minus_decay);
402
403                // Compute adaptive learning rate
404                let adaptive_lr = self
405                    .gradient_accumulator
406                    .mapv(|acc| A::sqrt(acc + A::from(*epsilon).unwrap()));
407
408                // Update parameters
409                self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
410            }
411            LearningRateAdaptation::Adam {
412                beta1,
413                beta2,
414                epsilon,
415            } => {
416                let beta1_val = A::from(*beta1).unwrap();
417                let beta2_val = A::from(*beta2).unwrap();
418                let one_minus_beta1 = A::one() - beta1_val;
419                let one_minus_beta2 = A::one() - beta2_val;
420
421                // Update first moment (gradient accumulator)
422                self.gradient_accumulator =
423                    &self.gradient_accumulator * beta1_val + gradient * one_minus_beta1;
424
425                // Update second moment
426                if let Some(ref mut second_moment) = self.second_moment_accumulator {
427                    *second_moment =
428                        &*second_moment * beta2_val + &gradient.mapv(|g| g * g * one_minus_beta2);
429
430                    // Bias correction
431                    let step_count_float = A::from(self.step_count).unwrap();
432                    let bias_correction1 = A::one() - A::powf(beta1_val, step_count_float);
433                    let bias_correction2 = A::one() - A::powf(beta2_val, step_count_float);
434
435                    let corrected_first = &self.gradient_accumulator / bias_correction1;
436                    let corrected_second = &*second_moment / bias_correction2;
437
438                    // Update parameters
439                    let adaptive_lr =
440                        corrected_second.mapv(|v| A::sqrt(v) + A::from(*epsilon).unwrap());
441                    self.parameters =
442                        &self.parameters - &(corrected_first / adaptive_lr * self.current_lr);
443                }
444            }
445            LearningRateAdaptation::ExponentialDecay { decay_rate } => {
446                // Simple exponential decay
447                self.current_lr = self.current_lr * A::from(*decay_rate).unwrap();
448                self.parameters = &self.parameters - gradient * self.current_lr;
449            }
450            LearningRateAdaptation::InverseScaling { power } => {
451                // Inverse scaling: lr = initial_lr / (step^power)
452                let step_power =
453                    A::powf(A::from(self.step_count).unwrap(), A::from(*power).unwrap());
454                let decayed_lr = self.current_lr / step_power;
455                self.parameters = &self.parameters - gradient * decayed_lr;
456            }
457        }
458
459        Ok(())
460    }
461
462    /// Online Newton's method update
463    fn online_newton_update(&mut self, gradient: &Array<A, D>, damping: f64) -> Result<()> {
464        // Simplified online Newton update with damping
465        let damping_val = A::from(damping).unwrap();
466
467        // Approximate Hessian diagonal with gradient squares (simplified)
468        let hessian_approx = gradient.mapv(|g| g * g + damping_val);
469
470        // Newton step
471        let newton_step = gradient / hessian_approx;
472        self.parameters = &self.parameters - &newton_step * self.current_lr;
473
474        Ok(())
475    }
476
477    /// FTRL update
478    fn ftrl_update(
479        &mut self,
480        gradient: &Array<A, D>,
481        l1_reg: f64,
482        l2_reg: f64,
483        lr_power: f64,
484    ) -> Result<()> {
485        // Accumulate gradients
486        self.gradient_accumulator = &self.gradient_accumulator + gradient;
487
488        // FTRL update rule (simplified)
489        let step_factor = A::powf(
490            A::from(self.step_count).unwrap(),
491            A::from(lr_power).unwrap(),
492        );
493        let learning_rate = self.current_lr / step_factor;
494
495        // Apply L1 and L2 regularization
496        let l1_weight = A::from(l1_reg).unwrap();
497        let l2_weight = A::from(l2_reg).unwrap();
498
499        self.parameters = self.gradient_accumulator.mapv(|g| {
500            let abs_g = A::abs(g);
501            if abs_g <= l1_weight {
502                A::zero()
503            } else {
504                let sign = if g > A::zero() { A::one() } else { -A::one() };
505                -sign * (abs_g - l1_weight) / (l2_weight + A::sqrt(abs_g))
506            }
507        }) * learning_rate;
508
509        Ok(())
510    }
511
512    /// Mirror descent update
513    fn mirror_descent_update(
514        &mut self,
515        gradient: &Array<A, D>,
516        mirror_fn: &MirrorFunction,
517        regularization: f64,
518    ) -> Result<()> {
519        match mirror_fn {
520            MirrorFunction::Euclidean => {
521                // Standard gradient descent
522                self.parameters = &self.parameters - gradient * self.current_lr;
523            }
524            MirrorFunction::Entropy => {
525                // Entropy regularized update (for probability simplex)
526                let reg_val = A::from(regularization).unwrap();
527                let updated = self
528                    .parameters
529                    .mapv(|p| A::exp(A::ln(p) - self.current_lr * reg_val));
530                let sum = updated.sum();
531                self.parameters = updated / sum; // Normalize to probability simplex
532            }
533            MirrorFunction::L1 => {
534                // L1 regularized update with soft thresholding
535                let threshold = self.current_lr * A::from(regularization).unwrap();
536                self.parameters = (&self.parameters - gradient * self.current_lr).mapv(|p| {
537                    if A::abs(p) <= threshold {
538                        A::zero()
539                    } else {
540                        p - A::signum(p) * threshold
541                    }
542                });
543            }
544            MirrorFunction::Nuclear => {
545                // Simplified nuclear norm update (requires matrix structure)
546                self.parameters = &self.parameters - gradient * self.current_lr;
547            }
548        }
549
550        Ok(())
551    }
552
553    /// Adaptive multi-task update
554    fn adaptive_multitask_update(&mut self, gradient: &Array<A, D>) -> Result<()> {
555        // Simplified multi-task update
556        self.parameters = &self.parameters - gradient * self.current_lr;
557        Ok(())
558    }
559
560    /// Update regret bound estimation
561    fn update_regret_bound(&mut self, loss: A) {
562        if let Some(&best_loss) = self
563            .performance_history
564            .iter()
565            .min_by(|a, b| a.partial_cmp(b).unwrap())
566        {
567            let regret = loss - best_loss;
568            self.regret_bound = self.regret_bound + regret.max(A::zero());
569        }
570    }
571
572    /// Get current parameters
573    pub fn parameters(&self) -> &Array<A, D> {
574        &self.parameters
575    }
576
577    /// Get performance metrics
578    pub fn get_performance_metrics(&self) -> OnlinePerformanceMetrics<A> {
579        let average_loss = if self.performance_history.is_empty() {
580            A::zero()
581        } else {
582            self.performance_history.iter().copied().sum::<A>()
583                / A::from(self.performance_history.len()).unwrap()
584        };
585
586        let lr_stability = A::from(1.0).unwrap(); // Simplified
587        let adaptation_speed = A::from(self.step_count as f64).unwrap(); // Simplified
588        let memory_efficiency = A::from(0.8).unwrap(); // Simplified
589
590        OnlinePerformanceMetrics {
591            cumulative_regret: self.regret_bound,
592            average_loss,
593            lr_stability,
594            adaptation_speed,
595            memory_efficiency,
596        }
597    }
598}
599
600impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
601    LifelongOptimizer<A, D>
602{
603    /// Create a new lifelong optimizer
604    pub fn new(strategy: LifelongStrategy) -> Self {
605        Self {
606            strategy,
607            task_optimizers: HashMap::new(),
608            shared_knowledge: SharedKnowledge {
609                fisher_information: None,
610                important_parameters: None,
611                task_embeddings: HashMap::new(),
612                transfer_weights: HashMap::new(),
613                meta_parameters: None,
614            },
615            task_graph: TaskGraph {
616                task_similarities: HashMap::new(),
617                task_dependencies: HashMap::new(),
618                task_clusters: HashMap::new(),
619            },
620            memory_buffer: MemoryBuffer {
621                examples: VecDeque::new(),
622                max_size: 1000,
623                update_strategy: MemoryUpdateStrategy::FIFO,
624                importance_scores: VecDeque::new(),
625            },
626            current_task: None,
627            task_performance: HashMap::new(),
628        }
629    }
630
631    /// Start learning a new task
632    pub fn start_task(&mut self, task_id: String, initial_parameters: Array<A, D>) -> Result<()> {
633        self.current_task = Some(task_id.clone());
634
635        // Create task-specific optimizer
636        let online_strategy = OnlineLearningStrategy::AdaptiveSGD {
637            initial_lr: 0.001,
638            adaptation_method: LearningRateAdaptation::Adam {
639                beta1: 0.9,
640                beta2: 0.999,
641                epsilon: 1e-8,
642            },
643        };
644
645        let task_optimizer = OnlineOptimizer::new(online_strategy, initial_parameters);
646        self.task_optimizers.insert(task_id.clone(), task_optimizer);
647
648        // Initialize task performance tracking
649        self.task_performance.insert(task_id, Vec::new());
650
651        Ok(())
652    }
653
654    /// Update current task with new data
655    pub fn update_current_task(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
656        let task_id = self
657            .current_task
658            .as_ref()
659            .ok_or_else(|| OptimError::InvalidConfig("No current task set".to_string()))?
660            .clone();
661
662        // Update task-specific optimizer
663        if let Some(optimizer) = self.task_optimizers.get_mut(&task_id) {
664            optimizer.online_update(gradient, loss)?;
665        }
666
667        // Track performance
668        if let Some(performance) = self.task_performance.get_mut(&task_id) {
669            performance.push(loss);
670        }
671
672        // Apply lifelong learning strategy
673        match &self.strategy {
674            LifelongStrategy::ElasticWeightConsolidation {
675                importance_weight, ..
676            } => {
677                self.apply_ewc_regularization(gradient, *importance_weight)?;
678            }
679            LifelongStrategy::ProgressiveNetworks { .. } => {
680                self.apply_progressive_networks(gradient)?;
681            }
682            LifelongStrategy::MemoryAugmented { .. } => {
683                self.update_memory_buffer(gradient, loss)?;
684            }
685            LifelongStrategy::MetaLearning { .. } => {
686                self.apply_meta_learning(gradient)?;
687            }
688            LifelongStrategy::GradientEpisodicMemory { .. } => {
689                self.apply_gem_constraints(gradient)?;
690            }
691        }
692
693        Ok(())
694    }
695
696    /// Apply Elastic Weight Consolidation regularization
697    fn apply_ewc_regularization(
698        &mut self,
699        gradient: &Array<A, D>,
700        _importance_weight: f64,
701    ) -> Result<()> {
702        // Simplified EWC implementation
703        // In practice, this would compute Fisher Information Matrix and apply regularization
704        Ok(())
705    }
706
707    /// Apply Progressive Networks strategy
708    fn apply_progressive_networks(&mut self, gradient: &Array<A, D>) -> Result<()> {
709        // Simplified Progressive Networks implementation
710        // In practice, this would manage lateral connections between task columns
711        Ok(())
712    }
713
714    /// Update memory buffer with important examples
715    fn update_memory_buffer(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
716        if let Some(task_id) = &self.current_task {
717            let example = MemoryExample {
718                input: Array::zeros(gradient.raw_dim()),  // Placeholder
719                target: Array::zeros(gradient.raw_dim()), // Placeholder
720                task_id: task_id.clone(),
721                importance: loss,
722                gradient: Some(gradient.clone()),
723            };
724
725            // Add to buffer
726            if self.memory_buffer.examples.len() >= self.memory_buffer.max_size {
727                match self.memory_buffer.update_strategy {
728                    MemoryUpdateStrategy::FIFO => {
729                        self.memory_buffer.examples.pop_front();
730                        self.memory_buffer.importance_scores.pop_front();
731                    }
732                    MemoryUpdateStrategy::Random => {
733                        let idx = thread_rng().gen_range(0..self.memory_buffer.examples.len());
734                        self.memory_buffer.examples.remove(idx);
735                        self.memory_buffer.importance_scores.remove(idx);
736                    }
737                    MemoryUpdateStrategy::ImportanceBased => {
738                        // Remove least important example
739                        if let Some(min_idx) = self
740                            .memory_buffer
741                            .importance_scores
742                            .iter()
743                            .enumerate()
744                            .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
745                            .map(|(idx, _)| idx)
746                        {
747                            self.memory_buffer.examples.remove(min_idx);
748                            self.memory_buffer.importance_scores.remove(min_idx);
749                        }
750                    }
751                    MemoryUpdateStrategy::GradientDiversity => {
752                        // Remove most similar gradient (simplified)
753                        self.memory_buffer.examples.pop_front();
754                        self.memory_buffer.importance_scores.pop_front();
755                    }
756                }
757            }
758
759            self.memory_buffer.examples.push_back(example);
760            self.memory_buffer.importance_scores.push_back(loss);
761        }
762
763        Ok(())
764    }
765
766    /// Apply meta-learning strategy
767    fn apply_meta_learning(&mut self, gradient: &Array<A, D>) -> Result<()> {
768        // Simplified meta-learning implementation
769        // In practice, this would update meta-parameters based on task performance
770        Ok(())
771    }
772
773    /// Apply Gradient Episodic Memory constraints
774    fn apply_gem_constraints(&mut self, gradient: &Array<A, D>) -> Result<()> {
775        // Simplified GEM implementation
776        // In practice, this would project gradients to satisfy memory constraints
777        Ok(())
778    }
779
780    /// Compute task similarity
781    pub fn compute_task_similarity(&self, task1: &str, task2: &str) -> f64 {
782        self.task_graph
783            .task_similarities
784            .get(&(task1.to_string(), task2.to_string()))
785            .or_else(|| {
786                self.task_graph
787                    .task_similarities
788                    .get(&(task2.to_string(), task1.to_string()))
789            })
790            .copied()
791            .unwrap_or(0.0)
792    }
793
794    /// Get lifelong learning statistics
795    pub fn get_lifelong_stats(&self) -> LifelongStats<A> {
796        let num_tasks = self.task_optimizers.len();
797        let avg_performance = if self.task_performance.is_empty() {
798            A::zero()
799        } else {
800            let total_performance: A = self.task_performance.values().flatten().copied().sum();
801            let total_samples = self
802                .task_performance
803                .values()
804                .map(|v| v.len())
805                .sum::<usize>();
806            if total_samples > 0 {
807                total_performance / A::from(total_samples).unwrap()
808            } else {
809                A::zero()
810            }
811        };
812
813        LifelongStats {
814            num_tasks,
815            average_performance: avg_performance,
816            memory_usage: self.memory_buffer.examples.len(),
817            transfer_efficiency: A::from(0.8).unwrap(), // Placeholder
818            catastrophic_forgetting: A::from(0.1).unwrap(), // Placeholder
819        }
820    }
821}
822
823/// Lifelong learning statistics
824#[derive(Debug, Clone)]
825pub struct LifelongStats<A: Float> {
826    /// Number of tasks learned
827    pub num_tasks: usize,
828    /// Average performance across all tasks
829    pub average_performance: A,
830    /// Current memory usage
831    pub memory_usage: usize,
832    /// Transfer learning efficiency
833    pub transfer_efficiency: A,
834    /// Catastrophic forgetting measure
835    pub catastrophic_forgetting: A,
836}
837
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::Ix1;

    /// Three-element parameter vector shared by most tests.
    fn sample_params() -> Array1<f64> {
        Array1::from_vec(vec![1.0, 2.0, 3.0])
    }

    /// Gradient with the same shape as [`sample_params`].
    fn sample_gradient() -> Array1<f64> {
        Array1::from_vec(vec![0.1, 0.2, 0.3])
    }

    #[test]
    fn test_online_optimizer_creation() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
        };

        let optimizer = OnlineOptimizer::new(strategy, sample_params());

        assert_eq!(optimizer.step_count, 0);
        assert_relative_eq!(optimizer.current_lr, 0.01, epsilon = 1e-6);
    }

    #[test]
    fn test_online_update() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.1,
            adaptation_method: LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
        };

        let mut optimizer = OnlineOptimizer::new(strategy, sample_params());
        let loss = 0.5;

        optimizer.online_update(&sample_gradient(), loss).unwrap();

        assert_eq!(optimizer.step_count, 1);
        assert_eq!(optimizer.performance_history.len(), 1);
        assert_relative_eq!(optimizer.performance_history[0], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_lifelong_optimizer_creation() {
        let strategy = LifelongStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let optimizer = LifelongOptimizer::<f64, Ix1>::new(strategy);

        assert_eq!(optimizer.task_optimizers.len(), 0);
        assert!(optimizer.current_task.is_none());
    }

    #[test]
    fn test_task_management() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 100,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, Ix1>::new(strategy);

        optimizer
            .start_task("task1".to_string(), sample_params())
            .unwrap();

        assert_eq!(optimizer.current_task, Some("task1".to_string()));
        assert!(optimizer.task_optimizers.contains_key("task1"));
        assert!(optimizer.task_performance.contains_key("task1"));
    }

    #[test]
    fn test_memory_buffer_update() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 2,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, Ix1>::new(strategy);
        optimizer.memory_buffer.max_size = 2;

        optimizer
            .start_task("task1".to_string(), sample_params())
            .unwrap();

        let gradient = sample_gradient();

        // Add first example
        optimizer.update_current_task(&gradient, 0.5).unwrap();
        assert_eq!(optimizer.memory_buffer.examples.len(), 1);

        // Add second example
        optimizer.update_current_task(&gradient, 0.6).unwrap();
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);

        // Add third example: the buffer is full, so FIFO eviction must keep
        // the size capped at `max_size`.
        optimizer.update_current_task(&gradient, 0.7).unwrap();
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);
    }

    #[test]
    fn test_performance_metrics() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };

        let mut optimizer = OnlineOptimizer::new(strategy, sample_params());

        // Seed the history directly: mean of 0.8, 0.6, 0.4 is 0.6.
        optimizer.performance_history.push_back(0.8);
        optimizer.performance_history.push_back(0.6);
        optimizer.performance_history.push_back(0.4);
        optimizer.regret_bound = 0.5;

        let metrics = optimizer.get_performance_metrics();

        assert_relative_eq!(metrics.cumulative_regret, 0.5, epsilon = 1e-6);
        assert_relative_eq!(metrics.average_loss, 0.6, epsilon = 1e-6);
    }

    #[test]
    fn test_lifelong_stats() {
        let strategy = LifelongStrategy::MetaLearning {
            meta_lr: 0.001,
            inner_steps: 5,
            task_embedding_size: 64,
        };

        let mut optimizer = LifelongOptimizer::<f64, Ix1>::new(strategy);

        // Add some tasks
        optimizer
            .start_task("task1".to_string(), sample_params())
            .unwrap();
        optimizer
            .start_task("task2".to_string(), sample_params())
            .unwrap();

        // Pooled mean of 0.8, 0.7, 0.9, 0.8 is 0.8.
        optimizer
            .task_performance
            .get_mut("task1")
            .unwrap()
            .extend([0.8, 0.7]);
        optimizer
            .task_performance
            .get_mut("task2")
            .unwrap()
            .extend([0.9, 0.8]);

        let stats = optimizer.get_lifelong_stats();

        assert_eq!(stats.num_tasks, 2);
        assert_relative_eq!(stats.average_performance, 0.8, epsilon = 1e-6);
    }

    #[test]
    fn test_lifelong_stats_without_samples() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 100,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let optimizer = LifelongOptimizer::<f64, Ix1>::new(strategy);

        let stats = optimizer.get_lifelong_stats();

        // With no recorded performance samples the average falls back to zero
        // instead of dividing by zero.
        assert_relative_eq!(stats.average_performance, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_learning_rate_adaptations() {
        let adaptations = [
            LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
            LearningRateAdaptation::RMSprop {
                decay: 0.9,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
            LearningRateAdaptation::InverseScaling { power: 0.5 },
        ];

        for adaptation in adaptations {
            // Keep a printable label so a failure names the offending variant.
            let label = format!("{:?}", adaptation);
            let strategy = OnlineLearningStrategy::AdaptiveSGD {
                initial_lr: 0.01,
                adaptation_method: adaptation,
            };

            let mut optimizer = OnlineOptimizer::new(strategy, sample_params());

            if let Err(err) = optimizer.online_update(&sample_gradient(), 0.5) {
                panic!("online_update failed for {}: {:?}", label, err);
            }
            assert_eq!(
                optimizer.step_count, 1,
                "unexpected step count for {}",
                label
            );
        }
    }
}