use crate::error::{OptimError, Result};
use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
use scirs2_core::numeric::Float;
use scirs2_core::random::{thread_rng, Random};
use std::collections::{HashMap, VecDeque};
use std::fmt::Debug;

/// Strategy for online (streaming) optimization.
#[derive(Debug, Clone)]
pub enum OnlineLearningStrategy {
    /// SGD with an adaptive learning-rate schedule.
    AdaptiveSGD {
        /// Initial learning rate.
        initial_lr: f64,
        /// Method used to adapt the learning rate.
        adaptation_method: LearningRateAdaptation,
    },
    /// Online Newton step with a damped diagonal curvature estimate.
    OnlineNewton {
        /// Damping term added to the curvature estimate.
        damping: f64,
        /// Window of recent steps used for the Hessian approximation.
        hessian_window: usize,
    },
    /// Follow-The-Regularized-Leader.
    FTRL {
        /// L1 regularization strength.
        l1_regularization: f64,
        /// L2 regularization strength.
        l2_regularization: f64,
        /// Power of the per-step learning-rate schedule.
        learning_rate_power: f64,
    },
    /// Mirror descent with a configurable mirror map.
    MirrorDescent {
        /// Mirror (distance-generating) function.
        mirror_function: MirrorFunction,
        /// Regularization strength.
        regularization: f64,
    },
    /// Adaptive optimization shared across related tasks.
    AdaptiveMultiTask {
        /// Similarity threshold for sharing information between tasks.
        similarity_threshold: f64,
        /// Whether learning rates are adapted per task.
        task_lr_adaptation: bool,
    },
}

/// Learning-rate adaptation methods for adaptive SGD.
#[derive(Debug, Clone)]
pub enum LearningRateAdaptation {
    /// AdaGrad: scale by the accumulated squared gradients.
    AdaGrad {
        /// Numerical stability constant.
        epsilon: f64,
    },
    /// RMSprop: exponential moving average of squared gradients.
    RMSprop {
        /// Decay rate of the moving average.
        decay: f64,
        /// Numerical stability constant.
        epsilon: f64,
    },
    /// Adam: bias-corrected first and second moment estimates.
    Adam {
        /// Decay rate of the first-moment estimate.
        beta1: f64,
        /// Decay rate of the second-moment estimate.
        beta2: f64,
        /// Numerical stability constant.
        epsilon: f64,
    },
    /// Multiplicative decay of the learning rate at every step.
    ExponentialDecay {
        /// Per-step decay factor.
        decay_rate: f64,
    },
    /// Learning rate scaled by `step_count^(-power)`.
    InverseScaling {
        /// Scaling power.
        power: f64,
    },
}

/// Mirror maps available for mirror descent.
#[derive(Debug, Clone)]
pub enum MirrorFunction {
    /// Squared Euclidean distance (recovers plain gradient descent).
    Euclidean,
    /// Negative entropy (exponentiated-gradient updates on the simplex).
    Entropy,
    /// L1-based mirror map with soft-thresholding.
    L1,
    /// Nuclear-norm mirror map (currently falls back to a Euclidean step).
    Nuclear,
}

/// Strategy for lifelong (continual) learning across tasks.
#[derive(Debug, Clone)]
pub enum LifelongStrategy {
    /// Elastic Weight Consolidation: penalize changes to important parameters.
    ElasticWeightConsolidation {
        /// Weight of the consolidation penalty.
        importance_weight: f64,
        /// Number of samples used to estimate the Fisher information.
        fisher_samples: usize,
    },
    /// Progressive networks: grow new columns with lateral connections.
    ProgressiveNetworks {
        /// Strength of lateral connections to previous columns.
        lateral_strength: f64,
        /// When new columns are added.
        growth_strategy: ColumnGrowthStrategy,
    },
    /// Memory-augmented learning with an episodic replay buffer.
    MemoryAugmented {
        /// Maximum number of stored examples.
        memory_size: usize,
        /// Policy for evicting examples when the buffer is full.
        update_strategy: MemoryUpdateStrategy,
    },
    /// Meta-learning of shared initialization and task embeddings.
    MetaLearning {
        /// Meta (outer-loop) learning rate.
        meta_lr: f64,
        /// Number of inner-loop adaptation steps.
        inner_steps: usize,
        /// Dimensionality of task embeddings.
        task_embedding_size: usize,
    },
    /// Gradient Episodic Memory: constrain updates against stored gradients.
    GradientEpisodicMemory {
        /// Stored examples per task.
        memory_per_task: usize,
        /// Tolerance for constraint violations.
        violation_tolerance: f64,
    },
}

/// When progressive networks add a new column.
#[derive(Debug, Clone)]
pub enum ColumnGrowthStrategy {
    /// Add a new column for every task.
    PerTask,
    /// Add a column when performance drops below a threshold.
    PerformanceBased {
        /// Performance threshold that triggers growth.
        threshold: f64,
    },
    /// Add a column at a fixed task interval.
    FixedInterval {
        /// Number of tasks between growth events.
        interval: usize,
    },
}

/// Eviction policy for the episodic memory buffer.
#[derive(Debug, Clone)]
pub enum MemoryUpdateStrategy {
    /// Evict the oldest example first.
    FIFO,
    /// Evict a uniformly random example.
    Random,
    /// Evict the example with the lowest importance score.
    ImportanceBased,
    /// Evict to maximize gradient diversity (currently falls back to FIFO).
    GradientDiversity,
}

/// Online optimizer that updates parameters from a stream of gradients.
#[derive(Debug)]
pub struct OnlineOptimizer<A: Float, D: Dimension> {
    /// Online learning strategy.
    strategy: OnlineLearningStrategy,
    /// Current parameter values.
    parameters: Array<A, D>,
    /// First-moment / accumulated-gradient statistics.
    gradient_accumulator: Array<A, D>,
    /// Second-moment statistics (used by Adam).
    second_moment_accumulator: Option<Array<A, D>>,
    /// Current learning rate.
    current_lr: A,
    /// Number of updates performed so far.
    step_count: usize,
    /// Recent per-step losses (bounded history).
    performance_history: VecDeque<A>,
    /// Running estimate of cumulative regret.
    regret_bound: A,
}

/// Lifelong optimizer that manages per-task online optimizers and shared knowledge.
#[derive(Debug)]
pub struct LifelongOptimizer<A: Float, D: Dimension> {
    /// Lifelong learning strategy.
    strategy: LifelongStrategy,
    /// One online optimizer per task.
    task_optimizers: HashMap<String, OnlineOptimizer<A, D>>,
    /// Knowledge shared across tasks.
    #[allow(dead_code)]
    shared_knowledge: SharedKnowledge<A, D>,
    /// Graph of task similarities and dependencies.
    task_graph: TaskGraph,
    /// Episodic memory buffer.
    memory_buffer: MemoryBuffer<A, D>,
    /// Identifier of the task currently being trained.
    current_task: Option<String>,
    /// Loss history per task.
    task_performance: HashMap<String, Vec<A>>,
}

/// Knowledge shared between tasks.
#[derive(Debug)]
pub struct SharedKnowledge<A: Float, D: Dimension> {
    /// Estimated Fisher information (for EWC-style penalties).
    #[allow(dead_code)]
    fisher_information: Option<Array<A, D>>,
    /// Parameters marked as important for previous tasks.
    #[allow(dead_code)]
    important_parameters: Option<Array<A, D>>,
    /// Learned embedding per task.
    #[allow(dead_code)]
    task_embeddings: HashMap<String, Array1<A>>,
    /// Pairwise transfer weights between tasks.
    #[allow(dead_code)]
    transfer_weights: HashMap<(String, String), A>,
    /// Meta-learned shared parameters.
    #[allow(dead_code)]
    meta_parameters: Option<Array1<A>>,
}

/// Graph describing relationships between tasks.
#[derive(Debug)]
pub struct TaskGraph {
    /// Pairwise task similarities.
    task_similarities: HashMap<(String, String), f64>,
    /// Task dependency lists.
    #[allow(dead_code)]
    task_dependencies: HashMap<String, Vec<String>>,
    /// Cluster assignment per task.
    #[allow(dead_code)]
    task_clusters: HashMap<String, String>,
}

/// Bounded episodic memory buffer.
#[derive(Debug)]
pub struct MemoryBuffer<A: Float, D: Dimension> {
    /// Stored examples.
    examples: VecDeque<MemoryExample<A, D>>,
    /// Maximum number of stored examples.
    max_size: usize,
    /// Eviction policy used when the buffer is full.
    update_strategy: MemoryUpdateStrategy,
    /// Importance score per stored example (parallel to `examples`).
    importance_scores: VecDeque<A>,
}

/// Single example stored in the memory buffer.
#[derive(Debug, Clone)]
pub struct MemoryExample<A: Float, D: Dimension> {
    /// Input data.
    pub input: Array<A, D>,
    /// Target data.
    pub target: Array<A, D>,
    /// Task the example belongs to.
    pub task_id: String,
    /// Importance score of the example.
    pub importance: A,
    /// Gradient observed for the example, if stored.
    pub gradient: Option<Array<A, D>>,
}

/// Summary metrics of online optimization performance.
#[derive(Debug, Clone)]
pub struct OnlinePerformanceMetrics<A: Float> {
    /// Cumulative regret accumulated so far.
    pub cumulative_regret: A,
    /// Average loss over the recent history.
    pub average_loss: A,
    /// Stability of the learning rate (placeholder metric).
    pub lr_stability: A,
    /// Speed of adaptation (placeholder metric).
    pub adaptation_speed: A,
    /// Memory efficiency (placeholder metric).
    pub memory_efficiency: A,
}

impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
    OnlineOptimizer<A, D>
{
    /// Create a new online optimizer from a strategy and initial parameters.
    pub fn new(strategy: OnlineLearningStrategy, initial_parameters: Array<A, D>) -> Self {
        let param_shape = initial_parameters.raw_dim();
        let gradient_accumulator = Array::zeros(param_shape.clone());
        // Adam is the only adaptation method that needs a second-moment buffer.
        let second_moment_accumulator = match &strategy {
            OnlineLearningStrategy::AdaptiveSGD {
                adaptation_method: LearningRateAdaptation::Adam { .. },
                ..
            } => Some(Array::zeros(param_shape)),
            _ => None,
        };

        // Default learning rates per strategy.
        let current_lr = match &strategy {
            OnlineLearningStrategy::AdaptiveSGD { initial_lr, .. } => A::from(*initial_lr).unwrap(),
            OnlineLearningStrategy::OnlineNewton { .. } => A::from(0.01).unwrap(),
            OnlineLearningStrategy::FTRL { .. } => A::from(0.1).unwrap(),
            OnlineLearningStrategy::MirrorDescent { .. } => A::from(0.01).unwrap(),
            OnlineLearningStrategy::AdaptiveMultiTask { .. } => A::from(0.001).unwrap(),
        };

        Self {
            strategy,
            parameters: initial_parameters,
            gradient_accumulator,
            second_moment_accumulator,
            current_lr,
            step_count: 0,
            performance_history: VecDeque::new(),
            regret_bound: A::zero(),
        }
    }

    /// Perform a single online update from a gradient and the observed loss.
    pub fn online_update(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        self.step_count += 1;
        self.performance_history.push_back(loss);

        // Keep only the most recent 1000 losses.
        if self.performance_history.len() > 1000 {
            self.performance_history.pop_front();
        }

        match self.strategy.clone() {
            OnlineLearningStrategy::AdaptiveSGD {
                adaptation_method, ..
            } => {
                self.adaptive_sgd_update(gradient, &adaptation_method)?;
            }
            OnlineLearningStrategy::OnlineNewton { damping, .. } => {
                self.online_newton_update(gradient, damping)?;
            }
            OnlineLearningStrategy::FTRL {
                l1_regularization,
                l2_regularization,
                learning_rate_power,
            } => {
                self.ftrl_update(
                    gradient,
                    l1_regularization,
                    l2_regularization,
                    learning_rate_power,
                )?;
            }
            OnlineLearningStrategy::MirrorDescent {
                mirror_function,
                regularization,
            } => {
                self.mirror_descent_update(gradient, &mirror_function, regularization)?;
            }
            OnlineLearningStrategy::AdaptiveMultiTask { .. } => {
                self.adaptive_multitask_update(gradient)?;
            }
        }

        self.update_regret_bound(loss);

        Ok(())
    }

    /// Adaptive SGD update with the configured learning-rate adaptation.
    fn adaptive_sgd_update(
        &mut self,
        gradient: &Array<A, D>,
        adaptation: &LearningRateAdaptation,
    ) -> Result<()> {
        match adaptation {
            LearningRateAdaptation::AdaGrad { epsilon } => {
                // Accumulate squared gradients and scale the step by their square root.
                self.gradient_accumulator = &self.gradient_accumulator + &gradient.mapv(|g| g * g);

                let adaptive_lr = self
                    .gradient_accumulator
                    .mapv(|acc| A::from(*epsilon).unwrap() + A::sqrt(acc));

                self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
            }
            LearningRateAdaptation::RMSprop { decay, epsilon } => {
                let decay_factor = A::from(*decay).unwrap();
                let one_minus_decay = A::one() - decay_factor;

                // Exponential moving average of squared gradients.
                self.gradient_accumulator = &self.gradient_accumulator * decay_factor
                    + &gradient.mapv(|g| g * g * one_minus_decay);

                let adaptive_lr = self
                    .gradient_accumulator
                    .mapv(|acc| A::sqrt(acc + A::from(*epsilon).unwrap()));

                self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
            }
            LearningRateAdaptation::Adam {
                beta1,
                beta2,
                epsilon,
            } => {
                let beta1_val = A::from(*beta1).unwrap();
                let beta2_val = A::from(*beta2).unwrap();
                let one_minus_beta1 = A::one() - beta1_val;
                let one_minus_beta2 = A::one() - beta2_val;

                // First-moment estimate.
                self.gradient_accumulator =
                    &self.gradient_accumulator * beta1_val + gradient * one_minus_beta1;

                if let Some(ref mut second_moment) = self.second_moment_accumulator {
                    // Second-moment estimate.
                    *second_moment =
                        &*second_moment * beta2_val + &gradient.mapv(|g| g * g * one_minus_beta2);

                    // Bias corrections: 1 - beta^t.
                    let step_count_float = A::from(self.step_count).unwrap();
                    let bias_correction1 = A::one() - A::powf(beta1_val, step_count_float);
                    let bias_correction2 = A::one() - A::powf(beta2_val, step_count_float);

                    let corrected_first = &self.gradient_accumulator / bias_correction1;
                    let corrected_second = &*second_moment / bias_correction2;

                    let adaptive_lr =
                        corrected_second.mapv(|v| A::sqrt(v) + A::from(*epsilon).unwrap());
                    self.parameters =
                        &self.parameters - &(corrected_first / adaptive_lr * self.current_lr);
                }
            }
            LearningRateAdaptation::ExponentialDecay { decay_rate } => {
                // Decay the learning rate and take a plain gradient step.
                self.current_lr = self.current_lr * A::from(*decay_rate).unwrap();
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
            LearningRateAdaptation::InverseScaling { power } => {
                // Scale the learning rate by step_count^(-power).
                let step_power =
                    A::powf(A::from(self.step_count).unwrap(), A::from(*power).unwrap());
                let decayed_lr = self.current_lr / step_power;
                self.parameters = &self.parameters - gradient * decayed_lr;
            }
        }

        Ok(())
    }

    /// Online Newton step using a damped, diagonal curvature approximation.
    fn online_newton_update(&mut self, gradient: &Array<A, D>, damping: f64) -> Result<()> {
        let damping_val = A::from(damping).unwrap();

        // Diagonal Hessian approximation: g^2 + damping.
        let hessian_approx = gradient.mapv(|g| g * g + damping_val);

        let newton_step = gradient / hessian_approx;
        self.parameters = &self.parameters - &newton_step * self.current_lr;

        Ok(())
    }

    /// Simplified FTRL-style update driven by the accumulated gradient.
    fn ftrl_update(
        &mut self,
        gradient: &Array<A, D>,
        l1_reg: f64,
        l2_reg: f64,
        lr_power: f64,
    ) -> Result<()> {
        // Accumulate gradients (the "z" statistics in FTRL).
        self.gradient_accumulator = &self.gradient_accumulator + gradient;

        // Per-step learning rate: current_lr / step_count^lr_power.
        let step_factor = A::powf(
            A::from(self.step_count).unwrap(),
            A::from(lr_power).unwrap(),
        );
        let learning_rate = self.current_lr / step_factor;

        let l1_weight = A::from(l1_reg).unwrap();
        let l2_weight = A::from(l2_reg).unwrap();

        // Per-coordinate closed-form solution with L1 soft-thresholding.
        self.parameters = self.gradient_accumulator.mapv(|g| {
            let abs_g = A::abs(g);
            if abs_g <= l1_weight {
                A::zero()
            } else {
                let sign = if g > A::zero() { A::one() } else { -A::one() };
                -sign * (abs_g - l1_weight) / (l2_weight + A::sqrt(abs_g))
            }
        }) * learning_rate;

        Ok(())
    }

    /// Mirror descent update under the configured mirror map.
    fn mirror_descent_update(
        &mut self,
        gradient: &Array<A, D>,
        mirror_fn: &MirrorFunction,
        regularization: f64,
    ) -> Result<()> {
        match mirror_fn {
            MirrorFunction::Euclidean => {
                // Euclidean mirror map reduces to plain gradient descent.
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
            MirrorFunction::Entropy => {
                // Exponentiated-gradient update: p_i <- p_i * exp(-lr * g_i), renormalized.
                let scale = gradient.mapv(|g| A::exp(-(self.current_lr * g)));
                let updated = &self.parameters * &scale;
                let sum = updated.sum();
                self.parameters = updated / sum;
            }
            MirrorFunction::L1 => {
                // Gradient step followed by soft-thresholding.
                let threshold = self.current_lr * A::from(regularization).unwrap();
                self.parameters = (&self.parameters - gradient * self.current_lr).mapv(|p| {
                    if A::abs(p) <= threshold {
                        A::zero()
                    } else {
                        p - A::signum(p) * threshold
                    }
                });
            }
            MirrorFunction::Nuclear => {
                // Nuclear-norm mirror map is not implemented; fall back to a Euclidean step.
                self.parameters = &self.parameters - gradient * self.current_lr;
            }
        }

        Ok(())
    }

    /// Multi-task update; currently a plain gradient step on the shared parameters.
    fn adaptive_multitask_update(&mut self, gradient: &Array<A, D>) -> Result<()> {
        self.parameters = &self.parameters - gradient * self.current_lr;
        Ok(())
    }

    /// Accumulate instantaneous regret relative to the best loss seen so far.
    fn update_regret_bound(&mut self, loss: A) {
        if let Some(&best_loss) = self
            .performance_history
            .iter()
            .min_by(|a, b| a.partial_cmp(b).unwrap())
        {
            let regret = loss - best_loss;
            self.regret_bound = self.regret_bound + regret.max(A::zero());
        }
    }

    /// Current parameter values.
    pub fn parameters(&self) -> &Array<A, D> {
        &self.parameters
    }

    /// Summarize online performance (regret, average loss, and placeholder diagnostics).
    pub fn get_performance_metrics(&self) -> OnlinePerformanceMetrics<A> {
        let average_loss = if self.performance_history.is_empty() {
            A::zero()
        } else {
            self.performance_history.iter().copied().sum::<A>()
                / A::from(self.performance_history.len()).unwrap()
        };

        // Placeholder diagnostics; proper estimators are not implemented yet.
        let lr_stability = A::from(1.0).unwrap();
        let adaptation_speed = A::from(self.step_count as f64).unwrap();
        let memory_efficiency = A::from(0.8).unwrap();

        OnlinePerformanceMetrics {
            cumulative_regret: self.regret_bound,
            average_loss,
            lr_stability,
            adaptation_speed,
            memory_efficiency,
        }
    }
}

impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
    LifelongOptimizer<A, D>
{
    /// Create a new lifelong optimizer for the given strategy.
    pub fn new(strategy: LifelongStrategy) -> Self {
        // Honor the memory configuration when a memory-augmented strategy is used;
        // otherwise fall back to a FIFO buffer of 1000 examples.
        let (max_size, update_strategy) = match &strategy {
            LifelongStrategy::MemoryAugmented {
                memory_size,
                update_strategy,
            } => (*memory_size, update_strategy.clone()),
            _ => (1000, MemoryUpdateStrategy::FIFO),
        };

        Self {
            strategy,
            task_optimizers: HashMap::new(),
            shared_knowledge: SharedKnowledge {
                fisher_information: None,
                important_parameters: None,
                task_embeddings: HashMap::new(),
                transfer_weights: HashMap::new(),
                meta_parameters: None,
            },
            task_graph: TaskGraph {
                task_similarities: HashMap::new(),
                task_dependencies: HashMap::new(),
                task_clusters: HashMap::new(),
            },
            memory_buffer: MemoryBuffer {
                examples: VecDeque::new(),
                max_size,
                update_strategy,
                importance_scores: VecDeque::new(),
            },
            current_task: None,
            task_performance: HashMap::new(),
        }
    }

    /// Begin (or switch to) a task, creating a per-task online optimizer for it.
    pub fn start_task(&mut self, task_id: String, initial_parameters: Array<A, D>) -> Result<()> {
        self.current_task = Some(task_id.clone());

        // Each task gets an Adam-based adaptive SGD optimizer by default.
        let online_strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.001,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };

        let task_optimizer = OnlineOptimizer::new(online_strategy, initial_parameters);
        self.task_optimizers.insert(task_id.clone(), task_optimizer);

        self.task_performance.insert(task_id, Vec::new());

        Ok(())
    }

    /// Update the currently active task with a gradient and observed loss.
    pub fn update_current_task(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        let task_id = self
            .current_task
            .as_ref()
            .ok_or_else(|| OptimError::InvalidConfig("No current task set".to_string()))?
            .clone();

        if let Some(optimizer) = self.task_optimizers.get_mut(&task_id) {
            optimizer.online_update(gradient, loss)?;
        }

        if let Some(performance) = self.task_performance.get_mut(&task_id) {
            performance.push(loss);
        }

        // Apply the lifelong-learning mechanism on top of the per-task update.
        match &self.strategy {
            LifelongStrategy::ElasticWeightConsolidation {
                importance_weight, ..
            } => {
                self.apply_ewc_regularization(gradient, *importance_weight)?;
            }
            LifelongStrategy::ProgressiveNetworks { .. } => {
                self.apply_progressive_networks(gradient)?;
            }
            LifelongStrategy::MemoryAugmented { .. } => {
                self.update_memory_buffer(gradient, loss)?;
            }
            LifelongStrategy::MetaLearning { .. } => {
                self.apply_meta_learning(gradient)?;
            }
            LifelongStrategy::GradientEpisodicMemory { .. } => {
                self.apply_gem_constraints(gradient)?;
            }
        }

        Ok(())
    }

    /// Placeholder for the EWC penalty; the Fisher-weighted regularization is not implemented yet.
    fn apply_ewc_regularization(
        &mut self,
        _gradient: &Array<A, D>,
        _importance_weight: f64,
    ) -> Result<()> {
        Ok(())
    }

    /// Placeholder for progressive-network lateral updates.
    fn apply_progressive_networks(&mut self, _gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Store the current observation in the episodic memory buffer, evicting if full.
    fn update_memory_buffer(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        if let Some(task_id) = &self.current_task {
            // Inputs/targets are not available here, so only the gradient and loss are stored.
            let example = MemoryExample {
                input: Array::zeros(gradient.raw_dim()),
                target: Array::zeros(gradient.raw_dim()),
                task_id: task_id.clone(),
                importance: loss,
                gradient: Some(gradient.clone()),
            };

            if self.memory_buffer.examples.len() >= self.memory_buffer.max_size {
                match self.memory_buffer.update_strategy {
                    MemoryUpdateStrategy::FIFO => {
                        self.memory_buffer.examples.pop_front();
                        self.memory_buffer.importance_scores.pop_front();
                    }
                    MemoryUpdateStrategy::Random => {
                        let idx = thread_rng().gen_range(0..self.memory_buffer.examples.len());
                        self.memory_buffer.examples.remove(idx);
                        self.memory_buffer.importance_scores.remove(idx);
                    }
                    MemoryUpdateStrategy::ImportanceBased => {
                        // Evict the example with the lowest importance score.
                        if let Some(min_idx) = self
                            .memory_buffer
                            .importance_scores
                            .iter()
                            .enumerate()
                            .min_by(|a, b| a.1.partial_cmp(b.1).unwrap())
                            .map(|(idx, _)| idx)
                        {
                            self.memory_buffer.examples.remove(min_idx);
                            self.memory_buffer.importance_scores.remove(min_idx);
                        }
                    }
                    MemoryUpdateStrategy::GradientDiversity => {
                        // Diversity-based eviction is not implemented; fall back to FIFO.
                        self.memory_buffer.examples.pop_front();
                        self.memory_buffer.importance_scores.pop_front();
                    }
                }
            }

            self.memory_buffer.examples.push_back(example);
            self.memory_buffer.importance_scores.push_back(loss);
        }

        Ok(())
    }

    /// Placeholder for the meta-learning outer-loop update.
    fn apply_meta_learning(&mut self, _gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Placeholder for GEM gradient-projection constraints.
    fn apply_gem_constraints(&mut self, _gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Look up the similarity between two tasks (0.0 if unknown), checking both key orders.
    pub fn compute_task_similarity(&self, task1: &str, task2: &str) -> f64 {
        self.task_graph
            .task_similarities
            .get(&(task1.to_string(), task2.to_string()))
            .or_else(|| {
                self.task_graph
                    .task_similarities
                    .get(&(task2.to_string(), task1.to_string()))
            })
            .copied()
            .unwrap_or(0.0)
    }

    /// Aggregate statistics across all tasks seen so far.
    pub fn get_lifelong_stats(&self) -> LifelongStats<A> {
        let num_tasks = self.task_optimizers.len();
        let avg_performance = if self.task_performance.is_empty() {
            A::zero()
        } else {
            let total_performance: A = self.task_performance.values().flatten().copied().sum();
            let total_samples = self
                .task_performance
                .values()
                .map(|v| v.len())
                .sum::<usize>();
            if total_samples > 0 {
                total_performance / A::from(total_samples).unwrap()
            } else {
                A::zero()
            }
        };

        LifelongStats {
            num_tasks,
            average_performance: avg_performance,
            memory_usage: self.memory_buffer.examples.len(),
            // Placeholder estimates; proper transfer/forgetting metrics are not implemented yet.
            transfer_efficiency: A::from(0.8).unwrap(),
            catastrophic_forgetting: A::from(0.1).unwrap(),
        }
    }
}

/// Summary statistics for lifelong learning.
#[derive(Debug, Clone)]
pub struct LifelongStats<A: Float> {
    /// Number of tasks seen so far.
    pub num_tasks: usize,
    /// Average loss across all tasks and updates.
    pub average_performance: A,
    /// Number of examples currently held in the memory buffer.
    pub memory_usage: usize,
    /// Estimated transfer efficiency (placeholder).
    pub transfer_efficiency: A,
    /// Estimated catastrophic forgetting (placeholder).
    pub catastrophic_forgetting: A,
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    #[test]
    fn test_online_optimizer_creation() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let optimizer = OnlineOptimizer::new(strategy, initial_params);

        assert_eq!(optimizer.step_count, 0);
        assert_relative_eq!(optimizer.current_lr, 0.01, epsilon = 1e-6);
    }

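    // Minimal usage sketch for the FTRL strategy; the hyperparameters below are
    // arbitrary illustrative values, not tuned defaults.
    #[test]
    fn test_ftrl_online_update() {
        let strategy = OnlineLearningStrategy::FTRL {
            l1_regularization: 0.01,
            l2_regularization: 0.1,
            learning_rate_power: 0.5,
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let result = optimizer.online_update(&gradient, 0.5);

        assert!(result.is_ok());
        assert_eq!(optimizer.step_count, 1);
        assert!(optimizer.parameters().iter().all(|p| p.is_finite()));
    }
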
    #[test]
    fn test_online_update() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.1,
            adaptation_method: LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let loss = 0.5;

        optimizer.online_update(&gradient, loss).unwrap();

        assert_eq!(optimizer.step_count, 1);
        assert_eq!(optimizer.performance_history.len(), 1);
        assert_relative_eq!(optimizer.performance_history[0], 0.5, epsilon = 1e-6);
    }

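    // Minimal usage sketch for the online Newton strategy; damping and window size
    // are arbitrary illustrative values.
    #[test]
    fn test_online_newton_update() {
        let strategy = OnlineLearningStrategy::OnlineNewton {
            damping: 1.0,
            hessian_window: 10,
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        optimizer.online_update(&gradient, 0.5).unwrap();

        assert_eq!(optimizer.step_count, 1);
        // Parameters should move opposite to the gradient direction.
        assert!(optimizer.parameters()[0] < 1.0);
    }
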
    #[test]
    fn test_lifelong_optimizer_creation() {
        let strategy = LifelongStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        assert_eq!(optimizer.task_optimizers.len(), 0);
        assert!(optimizer.current_task.is_none());
    }

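    // Sketch of the similarity lookup: unknown task pairs are expected to fall back to 0.0.
    #[test]
    fn test_unknown_task_similarity_defaults_to_zero() {
        let strategy = LifelongStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        assert_eq!(optimizer.compute_task_similarity("task1", "task2"), 0.0);
    }
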
    #[test]
    fn test_task_management() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 100,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        optimizer
            .start_task("task1".to_string(), initial_params)
            .unwrap();

        assert_eq!(optimizer.current_task, Some("task1".to_string()));
        assert!(optimizer.task_optimizers.contains_key("task1"));
        assert!(optimizer.task_performance.contains_key("task1"));
    }

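    // Sketch of switching between tasks: starting a second task should make it current
    // while keeping the first task's optimizer around.
    #[test]
    fn test_task_switching() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 100,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        optimizer
            .start_task("task1".to_string(), initial_params.clone())
            .unwrap();
        optimizer
            .start_task("task2".to_string(), initial_params)
            .unwrap();

        assert_eq!(optimizer.current_task, Some("task2".to_string()));
        assert_eq!(optimizer.task_optimizers.len(), 2);
    }
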
    #[test]
    fn test_memory_buffer_update() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 2,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        optimizer.memory_buffer.max_size = 2;

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        optimizer
            .start_task("task1".to_string(), initial_params)
            .unwrap();

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        optimizer.update_current_task(&gradient, 0.5).unwrap();
        assert_eq!(optimizer.memory_buffer.examples.len(), 1);

        optimizer.update_current_task(&gradient, 0.6).unwrap();
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);

        optimizer.update_current_task(&gradient, 0.7).unwrap();
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);
    }

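    // Sketch of importance-based eviction: when the buffer is full, the example with the
    // lowest importance (here, the lowest loss) should be evicted first. The buffer fields
    // are set directly, mirroring the FIFO test above.
    #[test]
    fn test_importance_based_eviction() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 2,
            update_strategy: MemoryUpdateStrategy::ImportanceBased,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        optimizer.memory_buffer.max_size = 2;
        optimizer.memory_buffer.update_strategy = MemoryUpdateStrategy::ImportanceBased;

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        optimizer
            .start_task("task1".to_string(), initial_params)
            .unwrap();

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        optimizer.update_current_task(&gradient, 0.9).unwrap();
        optimizer.update_current_task(&gradient, 0.1).unwrap();
        optimizer.update_current_task(&gradient, 0.5).unwrap();

        assert_eq!(optimizer.memory_buffer.examples.len(), 2);
        // The lowest-importance example (loss 0.1) should have been evicted.
        assert!(optimizer
            .memory_buffer
            .importance_scores
            .iter()
            .all(|&s| s >= 0.5));
    }
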
    #[test]
    fn test_performance_metrics() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        optimizer.performance_history.push_back(0.8);
        optimizer.performance_history.push_back(0.6);
        optimizer.performance_history.push_back(0.4);
        optimizer.regret_bound = 0.5;

        let metrics = optimizer.get_performance_metrics();

        assert_relative_eq!(metrics.cumulative_regret, 0.5, epsilon = 1e-6);
        assert_relative_eq!(metrics.average_loss, 0.6, epsilon = 1e-6);
    }

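    // Sketch of the regret bookkeeping: regret only accumulates when the observed loss
    // exceeds the best loss seen so far.
    #[test]
    fn test_regret_bound_accumulation() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.1,
            adaptation_method: LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        optimizer.online_update(&gradient, 0.5).unwrap();
        optimizer.online_update(&gradient, 0.3).unwrap();
        optimizer.online_update(&gradient, 0.8).unwrap();

        // Only the last step is worse than the running best (0.3), adding 0.5 regret.
        assert_relative_eq!(optimizer.regret_bound, 0.5, epsilon = 1e-6);
    }
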
    #[test]
    fn test_lifelong_stats() {
        let strategy = LifelongStrategy::MetaLearning {
            meta_lr: 0.001,
            inner_steps: 5,
            task_embedding_size: 64,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        optimizer
            .start_task("task1".to_string(), initial_params.clone())
            .unwrap();
        optimizer
            .start_task("task2".to_string(), initial_params)
            .unwrap();

        optimizer
            .task_performance
            .get_mut("task1")
            .unwrap()
            .extend(vec![0.8, 0.7]);
        optimizer
            .task_performance
            .get_mut("task2")
            .unwrap()
            .extend(vec![0.9, 0.8]);

        let stats = optimizer.get_lifelong_stats();

        assert_eq!(stats.num_tasks, 2);
        assert_relative_eq!(stats.average_performance, 0.8, epsilon = 1e-6);
    }

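    // Sketch of the entropy mirror map: the multiplicative update renormalizes,
    // so the parameters should sum to one and stay strictly positive afterwards.
    #[test]
    fn test_entropy_mirror_descent_normalizes() {
        let strategy = OnlineLearningStrategy::MirrorDescent {
            mirror_function: MirrorFunction::Entropy,
            regularization: 0.01,
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        optimizer.online_update(&gradient, 0.5).unwrap();

        assert_relative_eq!(optimizer.parameters().sum(), 1.0, epsilon = 1e-6);
        assert!(optimizer.parameters().iter().all(|&p| p > 0.0));
    }
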
    #[test]
    fn test_learning_rate_adaptations() {
        let strategies = vec![
            LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
            LearningRateAdaptation::RMSprop {
                decay: 0.9,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
            LearningRateAdaptation::InverseScaling { power: 0.5 },
        ];

        for adaptation in strategies {
            let strategy = OnlineLearningStrategy::AdaptiveSGD {
                initial_lr: 0.01,
                adaptation_method: adaptation,
            };

            let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
            let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

            let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
            let result = optimizer.online_update(&gradient, 0.5);

            assert!(result.is_ok());
            assert_eq!(optimizer.step_count, 1);
        }
    }
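
    // Sketch of the L1 mirror map: a single step applies soft-thresholding on top of the
    // gradient step, so parameters shrink toward zero and remain finite.
    #[test]
    fn test_l1_mirror_descent_soft_threshold() {
        let strategy = OnlineLearningStrategy::MirrorDescent {
            mirror_function: MirrorFunction::L1,
            regularization: 0.1,
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        optimizer.online_update(&gradient, 0.5).unwrap();

        assert!(optimizer.parameters()[0] < 1.0);
        assert!(optimizer.parameters().iter().all(|p| p.is_finite()));
    }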
}