
optirs_core/online_learning/mod.rs

//! Online learning and lifelong optimization
//!
//! This module provides optimization strategies for continuous-learning scenarios,
//! including online learning, continual learning, and lifelong optimization systems.
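//!
//! A minimal usage sketch (hypothetical values; assumes an `Array1<f64>`
//! parameter vector and the `scirs2_core` re-exports used below):
//!
//! ```rust,ignore
//! use scirs2_core::ndarray::Array1;
//!
//! let strategy = OnlineLearningStrategy::AdaptiveSGD {
//!     initial_lr: 0.01,
//!     adaptation_method: LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
//! };
//! let mut optimizer = OnlineOptimizer::new(strategy, Array1::from_vec(vec![0.0; 3]));
//!
//! // Stream (gradient, loss) pairs as they arrive.
//! let gradient = Array1::from_vec(vec![0.1, -0.2, 0.3]);
//! optimizer.online_update(&gradient, 0.5)?;
//! ```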

use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Random};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;

/// Online learning strategy
#[derive(Debug, Clone)]
pub enum OnlineLearningStrategy {
    /// Stochastic Gradient Descent with adaptive learning rate
    AdaptiveSGD {
        /// Initial learning rate
        initial_lr: f64,
        /// Learning rate adaptation method
        adaptation_method: LearningRateAdaptation,
    },
    /// Online Newton's method (second-order)
    OnlineNewton {
        /// Damping parameter for stability
        damping: f64,
        /// Window size for Hessian estimation
        hessian_window: usize,
    },
    /// Follow The Regularized Leader (FTRL)
    FTRL {
        /// L1 regularization strength
        l1_regularization: f64,
        /// L2 regularization strength
        l2_regularization: f64,
        /// Learning rate power
        learning_rate_power: f64,
    },
    /// Online Mirror Descent
    MirrorDescent {
        /// Mirror function type
        mirror_function: MirrorFunction,
        /// Regularization strength
        regularization: f64,
    },
    /// Adaptive Multi-Task Learning
    AdaptiveMultiTask {
        /// Task similarity threshold
        similarity_threshold: f64,
        /// Task-specific learning rates
        task_lr_adaptation: bool,
    },
}

/// Learning rate adaptation methods for online learning
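///
/// Each variant follows the usual textbook step rule. For gradient `g_t`,
/// accumulator `G_t`, and learning rate `lr`:
///
/// - AdaGrad: `theta -= lr * g_t / (sqrt(sum_s g_s^2) + eps)`
/// - RMSprop: `G_t = decay * G_{t-1} + (1 - decay) * g_t^2`, then
///   `theta -= lr * g_t / sqrt(G_t + eps)`
/// - Adam: bias-corrected moments `m_t` and `v_t`, then
///   `theta -= lr * m_t / (sqrt(v_t) + eps)`
/// - ExponentialDecay: `lr_t = lr_{t-1} * decay_rate`
/// - InverseScaling: `lr_t = lr_0 / t^power`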
#[derive(Debug, Clone)]
pub enum LearningRateAdaptation {
    /// AdaGrad-style adaptation
    AdaGrad {
        /// Small constant for numerical stability
        epsilon: f64,
    },
    /// RMSprop-style adaptation
    RMSprop {
        /// Decay rate
        decay: f64,
        /// Small constant for numerical stability
        epsilon: f64,
    },
    /// Adam-style adaptation
    Adam {
        /// Exponential decay rate for first moment
        beta1: f64,
        /// Exponential decay rate for second moment
        beta2: f64,
        /// Small constant for numerical stability
        epsilon: f64,
    },
    /// Exponential decay
    ExponentialDecay {
        /// Decay rate
        decay_rate: f64,
    },
    /// Inverse scaling
    InverseScaling {
        /// Scaling power
        power: f64,
    },
}

/// Mirror functions for mirror descent
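///
/// Mirror descent generalizes gradient descent by taking the proximal step
/// under a Bregman divergence induced by the mirror function. `Entropy`
/// yields the exponentiated-gradient update `p_i <- p_i * exp(-lr * g_i)`
/// followed by renormalization onto the probability simplex; `L1` yields
/// soft thresholding.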
#[derive(Debug, Clone)]
pub enum MirrorFunction {
    /// Euclidean (L2) regularization
    Euclidean,
    /// Entropy regularization (for the probability simplex)
    Entropy,
    /// L1 regularization
    L1,
    /// Nuclear norm (for matrices)
    Nuclear,
}

/// Lifelong learning strategy
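///
/// For reference, EWC augments the task loss with a quadratic penalty
/// anchored at the previous task's parameters `theta*` and weighted by the
/// diagonal Fisher information `F`:
/// `L(theta) = L_task(theta) + (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2`,
/// where `lambda` is `importance_weight`.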
#[derive(Debug, Clone)]
pub enum LifelongStrategy {
    /// Elastic Weight Consolidation (EWC)
    ElasticWeightConsolidation {
        /// Importance weight for previous tasks
        importance_weight: f64,
        /// Number of samples used to estimate the Fisher information
        fisher_samples: usize,
    },
    /// Progressive Neural Networks
    ProgressiveNetworks {
        /// Lateral connection strength
        lateral_strength: f64,
        /// Column growth strategy
        growth_strategy: ColumnGrowthStrategy,
    },
    /// Memory-Augmented Networks
    MemoryAugmented {
        /// Memory size
        memory_size: usize,
        /// Memory update strategy
        update_strategy: MemoryUpdateStrategy,
    },
    /// Meta-Learning Based Continual Learning
    MetaLearning {
        /// Meta-learning rate
        meta_lr: f64,
        /// Inner loop steps
        inner_steps: usize,
        /// Task embedding size
        task_embedding_size: usize,
    },
    /// Gradient Episodic Memory (GEM)
    GradientEpisodicMemory {
        /// Memory buffer size per task
        memory_per_task: usize,
        /// Constraint violation tolerance
        violation_tolerance: f64,
    },
}

/// Column growth strategies for progressive networks
#[derive(Debug, Clone)]
pub enum ColumnGrowthStrategy {
    /// Add a new column for each task
    PerTask,
    /// Add a new column when performance drops
    PerformanceBased {
        /// Performance threshold
        threshold: f64,
    },
    /// Add a new column after fixed intervals
    FixedInterval {
        /// Fixed interval
        interval: usize,
    },
}

/// Memory update strategies
#[derive(Debug, Clone)]
pub enum MemoryUpdateStrategy {
    /// First In, First Out
    FIFO,
    /// Random replacement
    Random,
    /// Importance-based replacement
    ImportanceBased,
    /// Gradient diversity
    GradientDiversity,
}

/// Online optimizer that adapts to streaming data
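///
/// A minimal streaming-loop sketch (the data source is hypothetical):
///
/// ```rust,ignore
/// let mut optimizer = OnlineOptimizer::new(strategy, initial_parameters);
/// for (gradient, loss) in data_stream {
///     optimizer.online_update(&gradient, loss)?;
/// }
/// let metrics = optimizer.get_performance_metrics();
/// println!("cumulative regret: {:?}", metrics.cumulative_regret);
/// ```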
#[derive(Debug)]
pub struct OnlineOptimizer<A: Float, D: Dimension> {
    /// Online learning strategy
    strategy: OnlineLearningStrategy,
    /// Current parameters
    parameters: Array<A, D>,
    /// Accumulated gradients for adaptation
    gradient_accumulator: Array<A, D>,
    /// Second moment accumulator (for Adam-like methods)
    second_moment_accumulator: Option<Array<A, D>>,
    /// Current learning rate
    current_lr: A,
    /// Step counter
    step_count: usize,
    /// Performance history
    performance_history: VecDeque<A>,
    /// Regret bound tracking
    regret_bound: A,
}

/// Lifelong optimizer that learns continuously across tasks
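///
/// A minimal task-lifecycle sketch (task names and data are hypothetical):
///
/// ```rust,ignore
/// let mut optimizer = LifelongOptimizer::new(strategy);
/// optimizer.start_task("task_a".to_string(), initial_parameters)?;
/// optimizer.update_current_task(&gradient, loss)?;
/// let stats = optimizer.get_lifelong_stats();
/// ```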
#[derive(Debug)]
pub struct LifelongOptimizer<A: Float, D: Dimension> {
    /// Lifelong learning strategy
    strategy: LifelongStrategy,
    /// Task-specific optimizers
    task_optimizers: HashMap<String, OnlineOptimizer<A, D>>,
    /// Shared knowledge across tasks
    #[allow(dead_code)]
    shared_knowledge: SharedKnowledge<A, D>,
    /// Task sequence and relationships
    task_graph: TaskGraph,
    /// Memory buffer for important examples
    memory_buffer: MemoryBuffer<A, D>,
    /// Current active task
    current_task: Option<String>,
    /// Performance tracking across tasks
    task_performance: HashMap<String, Vec<A>>,
}

/// Shared knowledge representation for lifelong learning
#[derive(Debug)]
pub struct SharedKnowledge<A: Float, D: Dimension> {
    /// Fisher information matrix (for EWC)
    #[allow(dead_code)]
    fisher_information: Option<Array<A, D>>,
    /// Important parameters (for EWC)
    #[allow(dead_code)]
    important_parameters: Option<Array<A, D>>,
    /// Task embeddings
    #[allow(dead_code)]
    task_embeddings: HashMap<String, Array1<A>>,
    /// Cross-task transfer weights
    #[allow(dead_code)]
    transfer_weights: HashMap<(String, String), A>,
    /// Meta-parameters learned across tasks
    #[allow(dead_code)]
    meta_parameters: Option<Array1<A>>,
}

/// Task relationship graph
#[derive(Debug)]
pub struct TaskGraph {
    /// Task relationships (similarity scores)
    task_similarities: HashMap<(String, String), f64>,
    /// Task dependencies
    #[allow(dead_code)]
    task_dependencies: HashMap<String, Vec<String>>,
    /// Task categories/clusters
    #[allow(dead_code)]
    task_clusters: HashMap<String, String>,
}

/// Memory buffer for important examples
#[derive(Debug)]
pub struct MemoryBuffer<A: Float, D: Dimension> {
    /// Stored examples
    examples: VecDeque<MemoryExample<A, D>>,
    /// Maximum buffer size
    max_size: usize,
    /// Update strategy
    update_strategy: MemoryUpdateStrategy,
    /// Importance scores
    importance_scores: VecDeque<A>,
}

/// Single memory example
#[derive(Debug, Clone)]
pub struct MemoryExample<A: Float, D: Dimension> {
    /// Input data
    pub input: Array<A, D>,
    /// Target output
    pub target: Array<A, D>,
    /// Task identifier
    pub task_id: String,
    /// Importance score
    pub importance: A,
    /// Gradient information
    pub gradient: Option<Array<A, D>>,
}

/// Online learning performance metrics
#[derive(Debug, Clone)]
pub struct OnlinePerformanceMetrics<A: Float> {
    /// Cumulative regret
    pub cumulative_regret: A,
    /// Average loss over the history window
    pub average_loss: A,
    /// Learning rate stability
    pub lr_stability: A,
    /// Adaptation speed
    pub adaptation_speed: A,
    /// Memory efficiency
    pub memory_efficiency: A,
}

impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
    OnlineOptimizer<A, D>
{
    /// Create a new online optimizer
    pub fn new(strategy: OnlineLearningStrategy, initial_parameters: Array<A, D>) -> Self {
        let param_shape = initial_parameters.raw_dim();
        let gradient_accumulator = Array::zeros(param_shape.clone());
        // Only Adam-style adaptation needs a second-moment accumulator.
        let second_moment_accumulator = match &strategy {
            OnlineLearningStrategy::AdaptiveSGD {
                adaptation_method: LearningRateAdaptation::Adam { .. },
                ..
            } => Some(Array::zeros(param_shape)),
            _ => None,
        };

        // Strategy-specific default learning rates.
        let current_lr = match &strategy {
            OnlineLearningStrategy::AdaptiveSGD { initial_lr, .. } => {
                A::from(*initial_lr).expect("float conversion failed")
            }
            OnlineLearningStrategy::OnlineNewton { .. } => A::from(0.01).expect("float conversion failed"),
            OnlineLearningStrategy::FTRL { .. } => A::from(0.1).expect("float conversion failed"),
            OnlineLearningStrategy::MirrorDescent { .. } => A::from(0.01).expect("float conversion failed"),
            OnlineLearningStrategy::AdaptiveMultiTask { .. } => {
                A::from(0.001).expect("float conversion failed")
            }
        };

        Self {
            strategy,
            parameters: initial_parameters,
            gradient_accumulator,
            second_moment_accumulator,
            current_lr,
            step_count: 0,
            performance_history: VecDeque::new(),
            regret_bound: A::zero(),
        }
    }

    /// Perform an online update with a new gradient and its observed loss
    pub fn online_update(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        self.step_count += 1;
        self.performance_history.push_back(loss);

        // Keep the performance history bounded to the last 1000 steps
        if self.performance_history.len() > 1000 {
            self.performance_history.pop_front();
        }

        match self.strategy.clone() {
            OnlineLearningStrategy::AdaptiveSGD {
                adaptation_method, ..
            } => {
                self.adaptive_sgd_update(gradient, &adaptation_method)?;
            }
            OnlineLearningStrategy::OnlineNewton { damping, .. } => {
                self.online_newton_update(gradient, damping)?;
            }
            OnlineLearningStrategy::FTRL {
                l1_regularization,
                l2_regularization,
                learning_rate_power,
            } => {
                self.ftrl_update(
                    gradient,
                    l1_regularization,
                    l2_regularization,
                    learning_rate_power,
                )?;
            }
            OnlineLearningStrategy::MirrorDescent {
                mirror_function,
                regularization,
            } => {
                self.mirror_descent_update(gradient, &mirror_function, regularization)?;
            }
            OnlineLearningStrategy::AdaptiveMultiTask { .. } => {
                self.adaptive_multitask_update(gradient)?;
            }
        }

        // Update the regret bound estimate
        self.update_regret_bound(loss);

        Ok(())
    }

    /// Adaptive SGD update
    fn adaptive_sgd_update(
        &mut self,
        gradient: &Array<A, D>,
        adaptation: &LearningRateAdaptation,
    ) -> Result<()> {
        match adaptation {
            LearningRateAdaptation::AdaGrad { epsilon } => {
                // Accumulate squared gradients
                self.gradient_accumulator = &self.gradient_accumulator + &gradient.mapv(|g| g * g);

                // Per-coordinate denominator: sqrt(accumulated squared gradients) + epsilon
                let adaptive_lr = self
                    .gradient_accumulator
                    .mapv(|acc| A::from(*epsilon).expect("float conversion failed") + A::sqrt(acc));

                // Update parameters
                self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
            }
            LearningRateAdaptation::RMSprop { decay, epsilon } => {
                let decay_factor = A::from(*decay).expect("float conversion failed");
                let one_minus_decay = A::one() - decay_factor;

                // Update the moving average of squared gradients
                self.gradient_accumulator = &self.gradient_accumulator * decay_factor
                    + &gradient.mapv(|g| g * g * one_minus_decay);

                // Per-coordinate denominator: sqrt(moving average + epsilon)
                let adaptive_lr = self
                    .gradient_accumulator
                    .mapv(|acc| A::sqrt(acc + A::from(*epsilon).expect("float conversion failed")));

                // Update parameters
                self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
            }
            LearningRateAdaptation::Adam {
                beta1,
                beta2,
                epsilon,
            } => {
                let beta1_val = A::from(*beta1).expect("float conversion failed");
                let beta2_val = A::from(*beta2).expect("float conversion failed");
                let one_minus_beta1 = A::one() - beta1_val;
                let one_minus_beta2 = A::one() - beta2_val;

                // Update the first moment (gradient accumulator)
                self.gradient_accumulator =
                    &self.gradient_accumulator * beta1_val + gradient * one_minus_beta1;

                // Update the second moment
                if let Some(ref mut second_moment) = self.second_moment_accumulator {
                    *second_moment =
                        &*second_moment * beta2_val + &gradient.mapv(|g| g * g * one_minus_beta2);

                    // Bias correction
                    let step_count_float = A::from(self.step_count).expect("float conversion failed");
                    let bias_correction1 = A::one() - A::powf(beta1_val, step_count_float);
                    let bias_correction2 = A::one() - A::powf(beta2_val, step_count_float);

                    let corrected_first = &self.gradient_accumulator / bias_correction1;
                    let corrected_second = &*second_moment / bias_correction2;

                    // Update parameters
                    let adaptive_lr = corrected_second
                        .mapv(|v| A::sqrt(v) + A::from(*epsilon).expect("float conversion failed"));
                    self.parameters =
                        &self.parameters - &(corrected_first / adaptive_lr * self.current_lr);
                }
            }
            LearningRateAdaptation::ExponentialDecay { decay_rate } => {
                // Simple exponential decay: lr_t = lr_{t-1} * decay_rate
                self.current_lr = self.current_lr * A::from(*decay_rate).expect("float conversion failed");
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
            LearningRateAdaptation::InverseScaling { power } => {
                // Inverse scaling: lr_t = initial_lr / step^power
                let step_power = A::powf(
                    A::from(self.step_count).expect("float conversion failed"),
                    A::from(*power).expect("float conversion failed"),
                );
                let decayed_lr = self.current_lr / step_power;
                self.parameters = &self.parameters - gradient * decayed_lr;
            }
        }

        Ok(())
    }

    /// Online Newton's method update
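    ///
    /// Simplified diagonal variant: the Hessian diagonal is approximated by
    /// the squared gradient plus damping, giving the per-coordinate step
    /// `theta_i -= lr * g_i / (g_i^2 + damping)`. A full online Newton step
    /// would maintain a windowed estimate of the Hessian (or its inverse)
    /// instead, as `hessian_window` suggests.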
    fn online_newton_update(&mut self, gradient: &Array<A, D>, damping: f64) -> Result<()> {
        let damping_val = A::from(damping).expect("float conversion failed");

        // Approximate the Hessian diagonal with squared gradients plus damping
        let hessian_approx = gradient.mapv(|g| g * g + damping_val);

        // Damped Newton step, scaled by the learning rate
        let newton_step = gradient / hessian_approx;
        self.parameters = &self.parameters - &newton_step * self.current_lr;

        Ok(())
    }

    /// FTRL update
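    ///
    /// For reference, FTRL-Proximal keeps accumulated gradients `z` and
    /// accumulated squared gradients `n`, with the per-coordinate closed form
    /// `theta_i = 0` if `|z_i| <= l1`, else roughly
    /// `theta_i = -(z_i - sign(z_i) * l1) / (l2 + sqrt(n_i))` up to the
    /// learning-rate schedule. The variant below is simplified: it reuses
    /// `sqrt(|z_i|)` in place of `sqrt(n_i)` and rescales by a polynomially
    /// decaying learning rate.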
    fn ftrl_update(
        &mut self,
        gradient: &Array<A, D>,
        l1_reg: f64,
        l2_reg: f64,
        lr_power: f64,
    ) -> Result<()> {
        // Accumulate gradients
        self.gradient_accumulator = &self.gradient_accumulator + gradient;

        // Simplified FTRL update rule
        let step_factor = A::powf(
            A::from(self.step_count).expect("float conversion failed"),
            A::from(lr_power).expect("float conversion failed"),
        );
        let learning_rate = self.current_lr / step_factor;

        // Apply L1 and L2 regularization
        let l1_weight = A::from(l1_reg).expect("float conversion failed");
        let l2_weight = A::from(l2_reg).expect("float conversion failed");

        // Per-coordinate closed form: zero out small accumulated gradients
        // (L1 sparsity), otherwise shrink toward zero.
        self.parameters = self.gradient_accumulator.mapv(|g| {
            let abs_g = A::abs(g);
            if abs_g <= l1_weight {
                A::zero()
            } else {
                let sign = if g > A::zero() { A::one() } else { -A::one() };
                -sign * (abs_g - l1_weight) / (l2_weight + A::sqrt(abs_g))
            }
        }) * learning_rate;

        Ok(())
    }

    /// Mirror descent update
    fn mirror_descent_update(
        &mut self,
        gradient: &Array<A, D>,
        mirror_fn: &MirrorFunction,
        regularization: f64,
    ) -> Result<()> {
        match mirror_fn {
            MirrorFunction::Euclidean => {
                // Standard gradient descent
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
            MirrorFunction::Entropy => {
                // Entropic mirror descent (exponentiated gradient) on the
                // probability simplex: p_i <- p_i * exp(-lr * g_i), then
                // renormalize. The explicit entropy penalty is omitted in
                // this simplified variant.
                let _reg = A::from(regularization).expect("float conversion failed");
                let lr = self.current_lr;
                let mut updated = self.parameters.clone();
                updated.zip_mut_with(gradient, |p, &g| *p = *p * A::exp(-lr * g));
                let sum = updated.sum();
                self.parameters = updated / sum; // Normalize back onto the probability simplex
            }
            MirrorFunction::L1 => {
                // L1 regularized update with soft thresholding
                let threshold = self.current_lr * A::from(regularization).expect("float conversion failed");
                self.parameters = (&self.parameters - gradient * self.current_lr).mapv(|p| {
                    if A::abs(p) <= threshold {
                        A::zero()
                    } else {
                        p - A::signum(p) * threshold
                    }
                });
            }
            MirrorFunction::Nuclear => {
                // Simplified nuclear norm update (a proper proximal step
                // requires matrix structure and a singular value decomposition)
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
        }

        Ok(())
    }

    /// Adaptive multi-task update (simplified to a plain gradient step)
    fn adaptive_multitask_update(&mut self, gradient: &Array<A, D>) -> Result<()> {
        self.parameters = &self.parameters - gradient * self.current_lr;
        Ok(())
    }

    /// Update regret bound estimation
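    ///
    /// Crude estimate: per-step regret is taken as
    /// `loss_t - min_{s <= t} loss_s`, clamped at zero and accumulated. The
    /// best observed loss stands in for the best fixed comparator of the
    /// formal regret `R_T = sum_t l_t(x_t) - min_x sum_t l_t(x)`.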
    fn update_regret_bound(&mut self, loss: A) {
        if let Some(&best_loss) = self
            .performance_history
            .iter()
            .min_by(|a, b| a.partial_cmp(b).expect("NaN encountered in loss comparison"))
        {
            let regret = loss - best_loss;
            self.regret_bound = self.regret_bound + regret.max(A::zero());
        }
    }

    /// Get current parameters
    pub fn parameters(&self) -> &Array<A, D> {
        &self.parameters
    }

    /// Get performance metrics
    pub fn get_performance_metrics(&self) -> OnlinePerformanceMetrics<A> {
        let average_loss = if self.performance_history.is_empty() {
            A::zero()
        } else {
            self.performance_history.iter().copied().sum::<A>()
                / A::from(self.performance_history.len()).expect("float conversion failed")
        };

        // Placeholder values; proper estimators are not implemented yet.
        let lr_stability = A::from(1.0).expect("float conversion failed");
        let adaptation_speed = A::from(self.step_count as f64).expect("float conversion failed");
        let memory_efficiency = A::from(0.8).expect("float conversion failed");

        OnlinePerformanceMetrics {
            cumulative_regret: self.regret_bound,
            average_loss,
            lr_stability,
            adaptation_speed,
            memory_efficiency,
        }
    }
}

impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
    LifelongOptimizer<A, D>
{
    /// Create a new lifelong optimizer
    ///
    /// For `LifelongStrategy::MemoryAugmented`, the memory buffer is sized
    /// and configured from the strategy; other strategies fall back to a
    /// FIFO buffer of 1000 examples.
    pub fn new(strategy: LifelongStrategy) -> Self {
        let (max_size, update_strategy) = match &strategy {
            LifelongStrategy::MemoryAugmented {
                memory_size,
                update_strategy,
            } => (*memory_size, update_strategy.clone()),
            _ => (1000, MemoryUpdateStrategy::FIFO),
        };

        Self {
            strategy,
            task_optimizers: HashMap::new(),
            shared_knowledge: SharedKnowledge {
                fisher_information: None,
                important_parameters: None,
                task_embeddings: HashMap::new(),
                transfer_weights: HashMap::new(),
                meta_parameters: None,
            },
            task_graph: TaskGraph {
                task_similarities: HashMap::new(),
                task_dependencies: HashMap::new(),
                task_clusters: HashMap::new(),
            },
            memory_buffer: MemoryBuffer {
                examples: VecDeque::new(),
                max_size,
                update_strategy,
                importance_scores: VecDeque::new(),
            },
            current_task: None,
            task_performance: HashMap::new(),
        }
    }

    /// Start learning a new task
    pub fn start_task(&mut self, task_id: String, initial_parameters: Array<A, D>) -> Result<()> {
        self.current_task = Some(task_id.clone());

        // Create a task-specific optimizer with Adam-style adaptation
        let online_strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.001,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };

        let task_optimizer = OnlineOptimizer::new(online_strategy, initial_parameters);
        self.task_optimizers.insert(task_id.clone(), task_optimizer);

        // Initialize task performance tracking
        self.task_performance.insert(task_id, Vec::new());

        Ok(())
    }

    /// Update the current task with new data
    pub fn update_current_task(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        let task_id = self
            .current_task
            .as_ref()
            .ok_or_else(|| OptimError::InvalidConfig("No current task set".to_string()))?
            .clone();

        // Update the task-specific optimizer
        if let Some(optimizer) = self.task_optimizers.get_mut(&task_id) {
            optimizer.online_update(gradient, loss)?;
        }

        // Track performance
        if let Some(performance) = self.task_performance.get_mut(&task_id) {
            performance.push(loss);
        }

        // Apply the lifelong learning strategy
        match &self.strategy {
            LifelongStrategy::ElasticWeightConsolidation {
                importance_weight, ..
            } => {
                self.apply_ewc_regularization(gradient, *importance_weight)?;
            }
            LifelongStrategy::ProgressiveNetworks { .. } => {
                self.apply_progressive_networks(gradient)?;
            }
            LifelongStrategy::MemoryAugmented { .. } => {
                self.update_memory_buffer(gradient, loss)?;
            }
            LifelongStrategy::MetaLearning { .. } => {
                self.apply_meta_learning(gradient)?;
            }
            LifelongStrategy::GradientEpisodicMemory { .. } => {
                self.apply_gem_constraints(gradient)?;
            }
        }

        Ok(())
    }

    /// Apply Elastic Weight Consolidation regularization
    ///
    /// Simplified stub: a full implementation would estimate the Fisher
    /// information matrix and add the quadratic EWC penalty to the gradient.
    fn apply_ewc_regularization(
        &mut self,
        _gradient: &Array<A, D>,
        _importance_weight: f64,
    ) -> Result<()> {
        Ok(())
    }

    /// Apply the Progressive Networks strategy
    ///
    /// Simplified stub: a full implementation would manage lateral
    /// connections between task columns.
    fn apply_progressive_networks(&mut self, _gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Update the memory buffer with an important example
    fn update_memory_buffer(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        if let Some(task_id) = &self.current_task {
            let example = MemoryExample {
                input: Array::zeros(gradient.raw_dim()),  // Placeholder
                target: Array::zeros(gradient.raw_dim()), // Placeholder
                task_id: task_id.clone(),
                importance: loss,
                gradient: Some(gradient.clone()),
            };

            // Evict an example first if the buffer is full
            if self.memory_buffer.examples.len() >= self.memory_buffer.max_size {
                match self.memory_buffer.update_strategy {
                    MemoryUpdateStrategy::FIFO => {
                        self.memory_buffer.examples.pop_front();
                        self.memory_buffer.importance_scores.pop_front();
                    }
                    MemoryUpdateStrategy::Random => {
                        let idx = thread_rng().gen_range(0..self.memory_buffer.examples.len());
                        self.memory_buffer.examples.remove(idx);
                        self.memory_buffer.importance_scores.remove(idx);
                    }
                    MemoryUpdateStrategy::ImportanceBased => {
                        // Remove the least important example
                        if let Some(min_idx) = self
                            .memory_buffer
                            .importance_scores
                            .iter()
                            .enumerate()
                            .min_by(|a, b| {
                                a.1.partial_cmp(b.1).expect("NaN encountered in importance comparison")
                            })
                            .map(|(idx, _)| idx)
                        {
                            self.memory_buffer.examples.remove(min_idx);
                            self.memory_buffer.importance_scores.remove(min_idx);
                        }
                    }
                    MemoryUpdateStrategy::GradientDiversity => {
                        // Simplified: evict the oldest example rather than
                        // the most redundant gradient
                        self.memory_buffer.examples.pop_front();
                        self.memory_buffer.importance_scores.pop_front();
                    }
                }
            }

            self.memory_buffer.examples.push_back(example);
            self.memory_buffer.importance_scores.push_back(loss);
        }

        Ok(())
    }

    /// Apply the meta-learning strategy
    ///
    /// Simplified stub: a full implementation would update meta-parameters
    /// based on task performance.
    fn apply_meta_learning(&mut self, _gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Apply Gradient Episodic Memory constraints
    ///
    /// Simplified stub: a full implementation would project the gradient so
    /// it does not increase the loss on stored memory examples.
    fn apply_gem_constraints(&mut self, _gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Compute task similarity
    ///
    /// Similarities are stored under unordered task pairs, so both key
    /// orders are checked; unknown pairs default to 0.0.
    pub fn compute_task_similarity(&self, task1: &str, task2: &str) -> f64 {
        self.task_graph
            .task_similarities
            .get(&(task1.to_string(), task2.to_string()))
            .or_else(|| {
                self.task_graph
                    .task_similarities
                    .get(&(task2.to_string(), task1.to_string()))
            })
            .copied()
            .unwrap_or(0.0)
    }

    /// Get lifelong learning statistics
    pub fn get_lifelong_stats(&self) -> LifelongStats<A> {
        let num_tasks = self.task_optimizers.len();
        let avg_performance = if self.task_performance.is_empty() {
            A::zero()
        } else {
            let total_performance: A = self.task_performance.values().flatten().copied().sum();
            let total_samples = self
                .task_performance
                .values()
                .map(|v| v.len())
                .sum::<usize>();
            if total_samples > 0 {
                total_performance / A::from(total_samples).expect("float conversion failed")
            } else {
                A::zero()
            }
        };

        LifelongStats {
            num_tasks,
            average_performance: avg_performance,
            memory_usage: self.memory_buffer.examples.len(),
            transfer_efficiency: A::from(0.8).expect("float conversion failed"), // Placeholder
            catastrophic_forgetting: A::from(0.1).expect("float conversion failed"), // Placeholder
        }
    }
}

/// Lifelong learning statistics
#[derive(Debug, Clone)]
pub struct LifelongStats<A: Float> {
    /// Number of tasks learned
    pub num_tasks: usize,
    /// Average performance across all tasks
    pub average_performance: A,
    /// Current memory usage
    pub memory_usage: usize,
    /// Transfer learning efficiency
    pub transfer_efficiency: A,
    /// Catastrophic forgetting measure
    pub catastrophic_forgetting: A,
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_online_optimizer_creation() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let optimizer = OnlineOptimizer::new(strategy, initial_params);

        assert_eq!(optimizer.step_count, 0);
        assert_relative_eq!(optimizer.current_lr, 0.01, epsilon = 1e-6);
    }

    #[test]
    fn test_online_update() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.1,
            adaptation_method: LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let loss = 0.5;

        optimizer
            .online_update(&gradient, loss)
            .expect("online update failed");

        assert_eq!(optimizer.step_count, 1);
        assert_eq!(optimizer.performance_history.len(), 1);
        assert_relative_eq!(optimizer.performance_history[0], 0.5, epsilon = 1e-6);
    }

    #[test]
    fn test_lifelong_optimizer_creation() {
        let strategy = LifelongStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        assert_eq!(optimizer.task_optimizers.len(), 0);
        assert!(optimizer.current_task.is_none());
    }

    #[test]
    fn test_task_management() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 100,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        optimizer
            .start_task("task1".to_string(), initial_params)
            .expect("start_task failed");

        assert_eq!(optimizer.current_task, Some("task1".to_string()));
        assert!(optimizer.task_optimizers.contains_key("task1"));
        assert!(optimizer.task_performance.contains_key("task1"));
    }

    #[test]
    fn test_memory_buffer_update() {
        // The buffer is sized from the strategy's memory_size.
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 2,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        optimizer
            .start_task("task1".to_string(), initial_params)
            .expect("start_task failed");

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        // Add first example
        optimizer
            .update_current_task(&gradient, 0.5)
            .expect("update_current_task failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 1);

        // Add second example
        optimizer
            .update_current_task(&gradient, 0.6)
            .expect("update_current_task failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);

        // Add third example (FIFO evicts the first)
        optimizer
            .update_current_task(&gradient, 0.7)
            .expect("update_current_task failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);
    }

    #[test]
    fn test_performance_metrics() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        // Add some performance data
        optimizer.performance_history.push_back(0.8);
        optimizer.performance_history.push_back(0.6);
        optimizer.performance_history.push_back(0.4);
        optimizer.regret_bound = 0.5;

        let metrics = optimizer.get_performance_metrics();

        assert_relative_eq!(metrics.cumulative_regret, 0.5, epsilon = 1e-6);
        assert_relative_eq!(metrics.average_loss, 0.6, epsilon = 1e-6);
    }

    #[test]
    fn test_lifelong_stats() {
        let strategy = LifelongStrategy::MetaLearning {
            meta_lr: 0.001,
            inner_steps: 5,
            task_embedding_size: 64,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        // Add some tasks
        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        optimizer
            .start_task("task1".to_string(), initial_params.clone())
            .expect("start_task failed");
        optimizer
            .start_task("task2".to_string(), initial_params)
            .expect("start_task failed");

        // Add some performance data
        optimizer
            .task_performance
            .get_mut("task1")
            .expect("missing task1 entry")
            .extend(vec![0.8, 0.7]);
        optimizer
            .task_performance
            .get_mut("task2")
            .expect("missing task2 entry")
            .extend(vec![0.9, 0.8]);

        let stats = optimizer.get_lifelong_stats();

        assert_eq!(stats.num_tasks, 2);
        assert_relative_eq!(stats.average_performance, 0.8, epsilon = 1e-6);
    }

    #[test]
    fn test_learning_rate_adaptations() {
        let strategies = vec![
            LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
            LearningRateAdaptation::RMSprop {
                decay: 0.9,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
            LearningRateAdaptation::InverseScaling { power: 0.5 },
        ];

        for adaptation in strategies {
            let strategy = OnlineLearningStrategy::AdaptiveSGD {
                initial_lr: 0.01,
                adaptation_method: adaptation,
            };

            let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
            let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

            let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
            let result = optimizer.online_update(&gradient, 0.5);

            assert!(result.is_ok());
            assert_eq!(optimizer.step_count, 1);
        }
    }
}