1use crate::error::{OptimError, Result};
7use scirs2_core::ndarray::{Array, Array1, Dimension, ScalarOperand};
8use scirs2_core::numeric::Float;
9use scirs2_core::random::{thread_rng, Random};
10use std::collections::{HashMap, VecDeque};
11use std::fmt::Debug;
12
/// Update rule used by [`OnlineOptimizer`] to process a stream of
/// gradient/loss observations.
#[derive(Debug, Clone)]
pub enum OnlineLearningStrategy {
    /// SGD whose step size is adjusted each step by `adaptation_method`.
    AdaptiveSGD {
        /// Starting learning rate.
        initial_lr: f64,
        /// Per-step learning-rate adaptation scheme (AdaGrad, Adam, ...).
        adaptation_method: LearningRateAdaptation,
    },
    /// Online Newton step using a diagonal Hessian approximation.
    OnlineNewton {
        /// Damping added to the curvature estimate to keep it invertible.
        damping: f64,
        /// Window size for Hessian history (currently unused by the update).
        hessian_window: usize,
    },
    /// Follow-The-Regularized-Leader with L1/L2 regularization.
    FTRL {
        /// L1 penalty; accumulated gradients below this magnitude are zeroed.
        l1_regularization: f64,
        /// L2 penalty in the update denominator.
        l2_regularization: f64,
        /// Exponent of the per-step learning-rate decay (t^power).
        learning_rate_power: f64,
    },
    /// Mirror descent under a chosen mirror (distance-generating) function.
    MirrorDescent {
        /// Geometry of the update (Euclidean, entropy, L1, nuclear).
        mirror_function: MirrorFunction,
        /// Regularization strength used by the mirror update.
        regularization: f64,
    },
    /// Multi-task variant; currently falls back to plain SGD per task.
    AdaptiveMultiTask {
        /// Similarity cutoff between tasks (not yet used by the update).
        similarity_threshold: f64,
        /// Whether per-task learning rates adapt (not yet used).
        task_lr_adaptation: bool,
    },
}
54
/// Learning-rate adaptation schemes for [`OnlineLearningStrategy::AdaptiveSGD`].
#[derive(Debug, Clone)]
pub enum LearningRateAdaptation {
    /// AdaGrad: divide by the running sum of squared gradients.
    AdaGrad {
        /// Small constant added for numerical stability.
        epsilon: f64,
    },
    /// RMSprop: exponentially decayed mean of squared gradients.
    RMSprop {
        /// Decay factor for the squared-gradient average.
        decay: f64,
        /// Small constant added for numerical stability.
        epsilon: f64,
    },
    /// Adam: bias-corrected first and second moment estimates.
    Adam {
        /// Decay for the first-moment (mean) estimate.
        beta1: f64,
        /// Decay for the second-moment (uncentered variance) estimate.
        beta2: f64,
        /// Small constant added for numerical stability.
        epsilon: f64,
    },
    /// Multiply the learning rate by `decay_rate` every step.
    ExponentialDecay {
        /// Per-step multiplicative decay factor.
        decay_rate: f64,
    },
    /// Scale the learning rate by 1 / t^power at step t.
    InverseScaling {
        /// Exponent of the inverse scaling.
        power: f64,
    },
}
90
/// Distance-generating function selecting the geometry of a mirror-descent
/// update.
#[derive(Debug, Clone)]
pub enum MirrorFunction {
    /// Squared Euclidean norm — reduces to ordinary gradient descent.
    Euclidean,
    /// Negative entropy — multiplicative updates on the simplex.
    Entropy,
    /// L1 geometry — soft-thresholding (sparsity-inducing) update.
    L1,
    /// Nuclear norm — currently implemented as a plain gradient step.
    Nuclear,
}
103
/// Continual-learning strategy applied by [`LifelongOptimizer`] across tasks.
#[derive(Debug, Clone)]
pub enum LifelongStrategy {
    /// EWC: penalize movement of parameters important to earlier tasks.
    ElasticWeightConsolidation {
        /// Weight of the quadratic importance penalty.
        importance_weight: f64,
        /// Number of samples used to estimate Fisher information.
        fisher_samples: usize,
    },
    /// Progressive networks: add columns with lateral connections per task.
    ProgressiveNetworks {
        /// Strength of lateral connections from frozen columns.
        lateral_strength: f64,
        /// Policy deciding when a new column is grown.
        growth_strategy: ColumnGrowthStrategy,
    },
    /// Replay-style learning backed by an example memory buffer.
    MemoryAugmented {
        /// Maximum number of stored examples.
        memory_size: usize,
        /// Eviction policy when the buffer is full.
        update_strategy: MemoryUpdateStrategy,
    },
    /// Meta-learning (MAML-style) across tasks.
    MetaLearning {
        /// Outer-loop (meta) learning rate.
        meta_lr: f64,
        /// Inner-loop adaptation steps per task.
        inner_steps: usize,
        /// Dimensionality of the learned task embedding.
        task_embedding_size: usize,
    },
    /// GEM: project gradients to avoid increasing loss on stored episodes.
    GradientEpisodicMemory {
        /// Episodic-memory budget per task.
        memory_per_task: usize,
        /// Allowed slack when checking constraint violations.
        violation_tolerance: f64,
    },
}
145
/// Policy for growing new columns in a progressive network.
#[derive(Debug, Clone)]
pub enum ColumnGrowthStrategy {
    /// Grow one new column for every new task.
    PerTask,
    /// Grow a column only when performance drops below `threshold`.
    PerformanceBased {
        /// Performance level below which a column is added.
        threshold: f64,
    },
    /// Grow a column every fixed number of tasks/steps.
    FixedInterval {
        /// Interval between column additions.
        interval: usize,
    },
}
162
/// Eviction policy used by [`MemoryBuffer`] when it is at capacity.
#[derive(Debug, Clone)]
pub enum MemoryUpdateStrategy {
    /// Evict the oldest example first.
    FIFO,
    /// Evict a uniformly random example.
    Random,
    /// Evict the example with the lowest importance score.
    ImportanceBased,
    /// Evict to maximize gradient diversity (currently falls back to FIFO).
    GradientDiversity,
}
175
/// Streaming optimizer that updates a single parameter array from one
/// gradient/loss observation at a time, according to an
/// [`OnlineLearningStrategy`].
#[derive(Debug)]
pub struct OnlineOptimizer<A: Float, D: Dimension> {
    /// Update rule applied by `online_update`.
    strategy: OnlineLearningStrategy,
    /// Current parameter values.
    parameters: Array<A, D>,
    /// Accumulator whose meaning depends on the strategy: sum of squared
    /// gradients (AdaGrad/RMSprop), first moment (Adam), or gradient sum (FTRL).
    gradient_accumulator: Array<A, D>,
    /// Second-moment accumulator; allocated only for Adam adaptation.
    second_moment_accumulator: Option<Array<A, D>>,
    /// Current (possibly decayed) learning rate.
    current_lr: A,
    /// Number of `online_update` calls so far.
    step_count: usize,
    /// Sliding window (capped at 1000) of observed losses.
    performance_history: VecDeque<A>,
    /// Cumulative regret relative to the best loss in the window.
    regret_bound: A,
}
196
/// Optimizer that manages one [`OnlineOptimizer`] per task plus shared
/// cross-task state (knowledge, task graph, replay memory) under a
/// [`LifelongStrategy`].
#[derive(Debug)]
pub struct LifelongOptimizer<A: Float, D: Dimension> {
    /// Continual-learning strategy applied on every update.
    strategy: LifelongStrategy,
    /// One online optimizer per task, keyed by task id.
    task_optimizers: HashMap<String, OnlineOptimizer<A, D>>,
    /// Cross-task knowledge store (currently unused — see `dead_code`).
    #[allow(dead_code)]
    shared_knowledge: SharedKnowledge<A, D>,
    /// Pairwise task relationships (similarities, dependencies, clusters).
    task_graph: TaskGraph,
    /// Replay buffer of stored examples for memory-based strategies.
    memory_buffer: MemoryBuffer<A, D>,
    /// Id of the task currently being trained, if any.
    current_task: Option<String>,
    /// Per-task loss history, keyed by task id.
    task_performance: HashMap<String, Vec<A>>,
}
216
/// Knowledge shared across tasks. All fields are placeholders at present
/// (`dead_code`): nothing in this file reads or writes them yet.
#[derive(Debug)]
pub struct SharedKnowledge<A: Float, D: Dimension> {
    /// Fisher-information estimate — presumably for EWC; currently unused.
    #[allow(dead_code)]
    fisher_information: Option<Array<A, D>>,
    /// Parameter-importance weights; currently unused.
    #[allow(dead_code)]
    important_parameters: Option<Array<A, D>>,
    /// Learned embedding per task id; currently unused.
    #[allow(dead_code)]
    task_embeddings: HashMap<String, Array1<A>>,
    /// Transfer weight for each ordered task pair; currently unused.
    #[allow(dead_code)]
    transfer_weights: HashMap<(String, String), A>,
    /// Meta-learned parameter vector; currently unused.
    #[allow(dead_code)]
    meta_parameters: Option<Array1<A>>,
}
236
/// Pairwise relationships between tasks.
#[derive(Debug)]
pub struct TaskGraph {
    /// Similarity score for each (task, task) pair; queried symmetrically by
    /// `compute_task_similarity`.
    task_similarities: HashMap<(String, String), f64>,
    /// Task prerequisite lists; currently unused.
    #[allow(dead_code)]
    task_dependencies: HashMap<String, Vec<String>>,
    /// Cluster assignment per task; currently unused.
    #[allow(dead_code)]
    task_clusters: HashMap<String, String>,
}
249
/// Bounded replay buffer of stored examples with parallel importance scores.
#[derive(Debug)]
pub struct MemoryBuffer<A: Float, D: Dimension> {
    /// Stored examples, oldest at the front.
    examples: VecDeque<MemoryExample<A, D>>,
    /// Capacity; eviction occurs once `examples.len()` reaches this.
    max_size: usize,
    /// Eviction policy applied when the buffer is full.
    update_strategy: MemoryUpdateStrategy,
    /// Importance score for each stored example (kept index-aligned with
    /// `examples`).
    importance_scores: VecDeque<A>,
}
262
/// One example stored in the replay memory.
#[derive(Debug, Clone)]
pub struct MemoryExample<A: Float, D: Dimension> {
    /// Input array (currently stored as zeros by `update_memory_buffer`).
    pub input: Array<A, D>,
    /// Target array (currently stored as zeros by `update_memory_buffer`).
    pub target: Array<A, D>,
    /// Task the example belongs to.
    pub task_id: String,
    /// Importance score (the loss at storage time).
    pub importance: A,
    /// Gradient observed for this example, if recorded.
    pub gradient: Option<Array<A, D>>,
}
277
/// Snapshot of an [`OnlineOptimizer`]'s performance.
///
/// NOTE(review): `lr_stability` and `memory_efficiency` are currently fixed
/// placeholder constants in `get_performance_metrics`, and `adaptation_speed`
/// is just the step count — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct OnlinePerformanceMetrics<A: Float> {
    /// Sum of per-step regret against the best loss in the window.
    pub cumulative_regret: A,
    /// Mean loss over the sliding history window.
    pub average_loss: A,
    /// Learning-rate stability indicator (placeholder: always 1.0).
    pub lr_stability: A,
    /// Adaptation-speed indicator (currently the raw step count).
    pub adaptation_speed: A,
    /// Memory-efficiency indicator (placeholder: always 0.8).
    pub memory_efficiency: A,
}
292
293impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
294 OnlineOptimizer<A, D>
295{
296 pub fn new(strategy: OnlineLearningStrategy, initial_parameters: Array<A, D>) -> Self {
298 let paramshape = initial_parameters.raw_dim();
299 let gradient_accumulator = Array::zeros(paramshape.clone());
300 let second_moment_accumulator = match &strategy {
301 OnlineLearningStrategy::AdaptiveSGD {
302 adaptation_method: LearningRateAdaptation::Adam { .. },
303 ..
304 } => Some(Array::zeros(paramshape)),
305 _ => None,
306 };
307
308 let current_lr = match &strategy {
309 OnlineLearningStrategy::AdaptiveSGD { initial_lr, .. } => {
310 A::from(*initial_lr).expect("unwrap failed")
311 }
312 OnlineLearningStrategy::OnlineNewton { .. } => A::from(0.01).expect("unwrap failed"),
313 OnlineLearningStrategy::FTRL { .. } => A::from(0.1).expect("unwrap failed"),
314 OnlineLearningStrategy::MirrorDescent { .. } => A::from(0.01).expect("unwrap failed"),
315 OnlineLearningStrategy::AdaptiveMultiTask { .. } => {
316 A::from(0.001).expect("unwrap failed")
317 }
318 };
319
320 Self {
321 strategy,
322 parameters: initial_parameters,
323 gradient_accumulator,
324 second_moment_accumulator,
325 current_lr,
326 step_count: 0,
327 performance_history: VecDeque::new(),
328 regret_bound: A::zero(),
329 }
330 }
331
332 pub fn online_update(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
334 self.step_count += 1;
335 self.performance_history.push_back(loss);
336
337 if self.performance_history.len() > 1000 {
339 self.performance_history.pop_front();
340 }
341
342 match self.strategy.clone() {
343 OnlineLearningStrategy::AdaptiveSGD {
344 adaptation_method, ..
345 } => {
346 self.adaptive_sgd_update(gradient, &adaptation_method)?;
347 }
348 OnlineLearningStrategy::OnlineNewton { damping, .. } => {
349 self.online_newton_update(gradient, damping)?;
350 }
351 OnlineLearningStrategy::FTRL {
352 l1_regularization,
353 l2_regularization,
354 learning_rate_power,
355 } => {
356 self.ftrl_update(
357 gradient,
358 l1_regularization,
359 l2_regularization,
360 learning_rate_power,
361 )?;
362 }
363 OnlineLearningStrategy::MirrorDescent {
364 mirror_function,
365 regularization,
366 } => {
367 self.mirror_descent_update(gradient, &mirror_function, regularization)?;
368 }
369 OnlineLearningStrategy::AdaptiveMultiTask { .. } => {
370 self.adaptive_multitask_update(gradient)?;
371 }
372 }
373
374 self.update_regret_bound(loss);
376
377 Ok(())
378 }
379
380 fn adaptive_sgd_update(
382 &mut self,
383 gradient: &Array<A, D>,
384 adaptation: &LearningRateAdaptation,
385 ) -> Result<()> {
386 match adaptation {
387 LearningRateAdaptation::AdaGrad { epsilon } => {
388 self.gradient_accumulator = &self.gradient_accumulator + &gradient.mapv(|g| g * g);
390
391 let adaptive_lr = self
393 .gradient_accumulator
394 .mapv(|acc| A::from(*epsilon).expect("unwrap failed") + A::sqrt(acc));
395
396 self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
398 }
399 LearningRateAdaptation::RMSprop { decay, epsilon } => {
400 let decay_factor = A::from(*decay).expect("unwrap failed");
401 let one_minus_decay = A::one() - decay_factor;
402
403 self.gradient_accumulator = &self.gradient_accumulator * decay_factor
405 + &gradient.mapv(|g| g * g * one_minus_decay);
406
407 let adaptive_lr = self
409 .gradient_accumulator
410 .mapv(|acc| A::sqrt(acc + A::from(*epsilon).expect("unwrap failed")));
411
412 self.parameters = &self.parameters - &(gradient / &adaptive_lr * self.current_lr);
414 }
415 LearningRateAdaptation::Adam {
416 beta1,
417 beta2,
418 epsilon,
419 } => {
420 let beta1_val = A::from(*beta1).expect("unwrap failed");
421 let beta2_val = A::from(*beta2).expect("unwrap failed");
422 let one_minus_beta1 = A::one() - beta1_val;
423 let one_minus_beta2 = A::one() - beta2_val;
424
425 self.gradient_accumulator =
427 &self.gradient_accumulator * beta1_val + gradient * one_minus_beta1;
428
429 if let Some(ref mut second_moment) = self.second_moment_accumulator {
431 *second_moment =
432 &*second_moment * beta2_val + &gradient.mapv(|g| g * g * one_minus_beta2);
433
434 let step_count_float = A::from(self.step_count).expect("unwrap failed");
436 let bias_correction1 = A::one() - A::powf(beta1_val, step_count_float);
437 let bias_correction2 = A::one() - A::powf(beta2_val, step_count_float);
438
439 let corrected_first = &self.gradient_accumulator / bias_correction1;
440 let corrected_second = &*second_moment / bias_correction2;
441
442 let adaptive_lr = corrected_second
444 .mapv(|v| A::sqrt(v) + A::from(*epsilon).expect("unwrap failed"));
445 self.parameters =
446 &self.parameters - &(corrected_first / adaptive_lr * self.current_lr);
447 }
448 }
449 LearningRateAdaptation::ExponentialDecay { decay_rate } => {
450 self.current_lr = self.current_lr * A::from(*decay_rate).expect("unwrap failed");
452 self.parameters = &self.parameters - gradient * self.current_lr;
453 }
454 LearningRateAdaptation::InverseScaling { power } => {
455 let step_power = A::powf(
457 A::from(self.step_count).expect("unwrap failed"),
458 A::from(*power).expect("unwrap failed"),
459 );
460 let decayed_lr = self.current_lr / step_power;
461 self.parameters = &self.parameters - gradient * decayed_lr;
462 }
463 }
464
465 Ok(())
466 }
467
468 fn online_newton_update(&mut self, gradient: &Array<A, D>, damping: f64) -> Result<()> {
470 let damping_val = A::from(damping).expect("unwrap failed");
472
473 let hessian_approx = gradient.mapv(|g| g * g + damping_val);
475
476 let newton_step = gradient / hessian_approx;
478 self.parameters = &self.parameters - &newton_step * self.current_lr;
479
480 Ok(())
481 }
482
483 fn ftrl_update(
485 &mut self,
486 gradient: &Array<A, D>,
487 l1_reg: f64,
488 l2_reg: f64,
489 lr_power: f64,
490 ) -> Result<()> {
491 self.gradient_accumulator = &self.gradient_accumulator + gradient;
493
494 let step_factor = A::powf(
496 A::from(self.step_count).expect("unwrap failed"),
497 A::from(lr_power).expect("unwrap failed"),
498 );
499 let learning_rate = self.current_lr / step_factor;
500
501 let l1_weight = A::from(l1_reg).expect("unwrap failed");
503 let l2_weight = A::from(l2_reg).expect("unwrap failed");
504
505 self.parameters = self.gradient_accumulator.mapv(|g| {
506 let abs_g = A::abs(g);
507 if abs_g <= l1_weight {
508 A::zero()
509 } else {
510 let sign = if g > A::zero() { A::one() } else { -A::one() };
511 -sign * (abs_g - l1_weight) / (l2_weight + A::sqrt(abs_g))
512 }
513 }) * learning_rate;
514
515 Ok(())
516 }
517
518 fn mirror_descent_update(
520 &mut self,
521 gradient: &Array<A, D>,
522 mirror_fn: &MirrorFunction,
523 regularization: f64,
524 ) -> Result<()> {
525 match mirror_fn {
526 MirrorFunction::Euclidean => {
527 self.parameters = &self.parameters - gradient * self.current_lr;
529 }
530 MirrorFunction::Entropy => {
531 let reg_val = A::from(regularization).expect("unwrap failed");
533 let updated = self
534 .parameters
535 .mapv(|p| A::exp(A::ln(p) - self.current_lr * reg_val));
536 let sum = updated.sum();
537 self.parameters = updated / sum; }
539 MirrorFunction::L1 => {
540 let threshold = self.current_lr * A::from(regularization).expect("unwrap failed");
542 self.parameters = (&self.parameters - gradient * self.current_lr).mapv(|p| {
543 if A::abs(p) <= threshold {
544 A::zero()
545 } else {
546 p - A::signum(p) * threshold
547 }
548 });
549 }
550 MirrorFunction::Nuclear => {
551 self.parameters = &self.parameters - gradient * self.current_lr;
553 }
554 }
555
556 Ok(())
557 }
558
559 fn adaptive_multitask_update(&mut self, gradient: &Array<A, D>) -> Result<()> {
561 self.parameters = &self.parameters - gradient * self.current_lr;
563 Ok(())
564 }
565
566 fn update_regret_bound(&mut self, loss: A) {
568 if let Some(&best_loss) = self
569 .performance_history
570 .iter()
571 .min_by(|a, b| a.partial_cmp(b).expect("unwrap failed"))
572 {
573 let regret = loss - best_loss;
574 self.regret_bound = self.regret_bound + regret.max(A::zero());
575 }
576 }
577
578 pub fn parameters(&self) -> &Array<A, D> {
580 &self.parameters
581 }
582
583 pub fn get_performance_metrics(&self) -> OnlinePerformanceMetrics<A> {
585 let average_loss = if self.performance_history.is_empty() {
586 A::zero()
587 } else {
588 self.performance_history.iter().copied().sum::<A>()
589 / A::from(self.performance_history.len()).expect("unwrap failed")
590 };
591
592 let lr_stability = A::from(1.0).expect("unwrap failed"); let adaptation_speed = A::from(self.step_count as f64).expect("unwrap failed"); let memory_efficiency = A::from(0.8).expect("unwrap failed"); OnlinePerformanceMetrics {
597 cumulative_regret: self.regret_bound,
598 average_loss,
599 lr_stability,
600 adaptation_speed,
601 memory_efficiency,
602 }
603 }
604}
605
impl<A: Float + ScalarOperand + Debug + std::iter::Sum, D: Dimension + Send + Sync>
    LifelongOptimizer<A, D>
{
    /// Creates a lifelong optimizer with empty shared state and a default
    /// FIFO memory buffer of capacity 1000.
    ///
    /// NOTE(review): `MemoryAugmented { memory_size, update_strategy }` from
    /// the given strategy is NOT copied into `memory_buffer` here — the
    /// buffer always starts with max_size 1000 / FIFO. Confirm whether that
    /// is intended.
    pub fn new(strategy: LifelongStrategy) -> Self {
        Self {
            strategy,
            task_optimizers: HashMap::new(),
            shared_knowledge: SharedKnowledge {
                fisher_information: None,
                important_parameters: None,
                task_embeddings: HashMap::new(),
                transfer_weights: HashMap::new(),
                meta_parameters: None,
            },
            task_graph: TaskGraph {
                task_similarities: HashMap::new(),
                task_dependencies: HashMap::new(),
                task_clusters: HashMap::new(),
            },
            memory_buffer: MemoryBuffer {
                examples: VecDeque::new(),
                max_size: 1000,
                update_strategy: MemoryUpdateStrategy::FIFO,
                importance_scores: VecDeque::new(),
            },
            current_task: None,
            task_performance: HashMap::new(),
        }
    }

    /// Switches to (and registers) a new task: creates a fresh per-task
    /// [`OnlineOptimizer`] (AdaptiveSGD + Adam with fixed hyperparameters)
    /// and an empty performance history for it.
    pub fn start_task(&mut self, task_id: String, initial_parameters: Array<A, D>) -> Result<()> {
        self.current_task = Some(task_id.clone());

        // Every task uses the same default optimizer configuration.
        let online_strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.001,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };

        let task_optimizer = OnlineOptimizer::new(online_strategy, initial_parameters);
        self.task_optimizers.insert(task_id.clone(), task_optimizer);

        self.task_performance.insert(task_id, Vec::new());

        Ok(())
    }

    /// Updates the current task's optimizer with `gradient`/`loss`, records
    /// the loss, then applies the lifelong strategy's cross-task step.
    ///
    /// # Errors
    /// Returns `InvalidConfig` if no task has been started.
    pub fn update_current_task(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        let task_id = self
            .current_task
            .as_ref()
            .ok_or_else(|| OptimError::InvalidConfig("No current task set".to_string()))?
            .clone();

        if let Some(optimizer) = self.task_optimizers.get_mut(&task_id) {
            optimizer.online_update(gradient, loss)?;
        }

        if let Some(performance) = self.task_performance.get_mut(&task_id) {
            performance.push(loss);
        }

        // Strategy-specific continual-learning step. All branches except
        // MemoryAugmented are currently stubs (see methods below).
        match &self.strategy {
            LifelongStrategy::ElasticWeightConsolidation {
                importance_weight, ..
            } => {
                self.apply_ewc_regularization(gradient, *importance_weight)?;
            }
            LifelongStrategy::ProgressiveNetworks { .. } => {
                self.apply_progressive_networks(gradient)?;
            }
            LifelongStrategy::MemoryAugmented { .. } => {
                self.update_memory_buffer(gradient, loss)?;
            }
            LifelongStrategy::MetaLearning { .. } => {
                self.apply_meta_learning(gradient)?;
            }
            LifelongStrategy::GradientEpisodicMemory { .. } => {
                self.apply_gem_constraints(gradient)?;
            }
        }

        Ok(())
    }

    /// EWC regularization step — TODO: not yet implemented (no-op).
    fn apply_ewc_regularization(
        &mut self,
        gradient: &Array<A, D>,
        _importance_weight: f64,
    ) -> Result<()> {
        Ok(())
    }

    /// Progressive-networks lateral-connection step — TODO: not yet
    /// implemented (no-op).
    fn apply_progressive_networks(&mut self, gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Stores the current observation in the replay buffer, evicting one
    /// entry per the buffer's update strategy when at capacity.
    ///
    /// The stored example uses zero arrays for input/target (only the
    /// gradient and loss are actually captured here); the loss doubles as
    /// the importance score.
    fn update_memory_buffer(&mut self, gradient: &Array<A, D>, loss: A) -> Result<()> {
        if let Some(task_id) = &self.current_task {
            let example = MemoryExample {
                input: Array::zeros(gradient.raw_dim()),
                target: Array::zeros(gradient.raw_dim()),
                task_id: task_id.clone(),
                importance: loss,
                gradient: Some(gradient.clone()),
            };

            if self.memory_buffer.examples.len() >= self.memory_buffer.max_size {
                // Evict one entry; `importance_scores` must stay aligned
                // with `examples`, so both are mutated together.
                match self.memory_buffer.update_strategy {
                    MemoryUpdateStrategy::FIFO => {
                        self.memory_buffer.examples.pop_front();
                        self.memory_buffer.importance_scores.pop_front();
                    }
                    MemoryUpdateStrategy::Random => {
                        let idx = thread_rng().gen_range(0..self.memory_buffer.examples.len());
                        self.memory_buffer.examples.remove(idx);
                        self.memory_buffer.importance_scores.remove(idx);
                    }
                    MemoryUpdateStrategy::ImportanceBased => {
                        // Drop the least-important stored example.
                        if let Some(min_idx) = self
                            .memory_buffer
                            .importance_scores
                            .iter()
                            .enumerate()
                            .min_by(|a, b| a.1.partial_cmp(b.1).expect("unwrap failed"))
                            .map(|(idx, _)| idx)
                        {
                            self.memory_buffer.examples.remove(min_idx);
                            self.memory_buffer.importance_scores.remove(min_idx);
                        }
                    }
                    MemoryUpdateStrategy::GradientDiversity => {
                        // Diversity-based eviction not implemented; FIFO
                        // fallback.
                        self.memory_buffer.examples.pop_front();
                        self.memory_buffer.importance_scores.pop_front();
                    }
                }
            }

            self.memory_buffer.examples.push_back(example);
            self.memory_buffer.importance_scores.push_back(loss);
        }

        Ok(())
    }

    /// Meta-learning outer-loop step — TODO: not yet implemented (no-op).
    fn apply_meta_learning(&mut self, gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// GEM gradient-projection step — TODO: not yet implemented (no-op).
    fn apply_gem_constraints(&mut self, gradient: &Array<A, D>) -> Result<()> {
        Ok(())
    }

    /// Looks up the similarity between two tasks, treating the pair as
    /// unordered (checks both key orders); 0.0 if unknown.
    pub fn compute_task_similarity(&self, task1: &str, task2: &str) -> f64 {
        self.task_graph
            .task_similarities
            .get(&(task1.to_string(), task2.to_string()))
            .or_else(|| {
                self.task_graph
                    .task_similarities
                    .get(&(task2.to_string(), task1.to_string()))
            })
            .copied()
            .unwrap_or(0.0)
    }

    /// Aggregates statistics across all tasks: task count, mean loss over
    /// every recorded sample, and memory usage.
    ///
    /// NOTE(review): `transfer_efficiency` (0.8) and
    /// `catastrophic_forgetting` (0.1) are placeholder constants.
    pub fn get_lifelong_stats(&self) -> LifelongStats<A> {
        let num_tasks = self.task_optimizers.len();
        let avg_performance = if self.task_performance.is_empty() {
            A::zero()
        } else {
            let total_performance: A = self.task_performance.values().flatten().copied().sum();
            let total_samples = self
                .task_performance
                .values()
                .map(|v| v.len())
                .sum::<usize>();
            if total_samples > 0 {
                total_performance / A::from(total_samples).expect("unwrap failed")
            } else {
                A::zero()
            }
        };

        LifelongStats {
            num_tasks,
            average_performance: avg_performance,
            memory_usage: self.memory_buffer.examples.len(),
            transfer_efficiency: A::from(0.8).expect("unwrap failed"),
            catastrophic_forgetting: A::from(0.1).expect("unwrap failed"),
        }
    }
}
828
/// Aggregate statistics for a [`LifelongOptimizer`].
///
/// NOTE(review): `transfer_efficiency` and `catastrophic_forgetting` are
/// currently fixed placeholder constants in `get_lifelong_stats`.
#[derive(Debug, Clone)]
pub struct LifelongStats<A: Float> {
    /// Number of registered tasks.
    pub num_tasks: usize,
    /// Mean loss over all recorded samples across all tasks.
    pub average_performance: A,
    /// Number of examples currently stored in the memory buffer.
    pub memory_usage: usize,
    /// Transfer-efficiency indicator (placeholder: always 0.8).
    pub transfer_efficiency: A,
    /// Forgetting indicator (placeholder: always 0.1).
    pub catastrophic_forgetting: A,
}
843
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Constructor wires up step count and the configured learning rate.
    #[test]
    fn test_online_optimizer_creation() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let optimizer = OnlineOptimizer::new(strategy, initial_params);

        assert_eq!(optimizer.step_count, 0);
        assert_relative_eq!(optimizer.current_lr, 0.01, epsilon = 1e-6);
    }

    /// A single update increments the step count and records the loss.
    #[test]
    fn test_online_update() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.1,
            adaptation_method: LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let loss = 0.5;

        optimizer
            .online_update(&gradient, loss)
            .expect("unwrap failed");

        assert_eq!(optimizer.step_count, 1);
        assert_eq!(optimizer.performance_history.len(), 1);
        assert_relative_eq!(optimizer.performance_history[0], 0.5, epsilon = 1e-6);
    }

    /// A fresh lifelong optimizer has no tasks and no current task.
    #[test]
    fn test_lifelong_optimizer_creation() {
        let strategy = LifelongStrategy::ElasticWeightConsolidation {
            importance_weight: 1000.0,
            fisher_samples: 100,
        };

        let optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        assert_eq!(optimizer.task_optimizers.len(), 0);
        assert!(optimizer.current_task.is_none());
    }

    /// `start_task` registers an optimizer, a history, and the current task.
    #[test]
    fn test_task_management() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 100,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        optimizer
            .start_task("task1".to_string(), initial_params)
            .expect("unwrap failed");

        assert_eq!(optimizer.current_task, Some("task1".to_string()));
        assert!(optimizer.task_optimizers.contains_key("task1"));
        assert!(optimizer.task_performance.contains_key("task1"));
    }

    /// The FIFO memory buffer grows to capacity, then stays there
    /// (eviction keeps the length at max_size).
    #[test]
    fn test_memory_buffer_update() {
        let strategy = LifelongStrategy::MemoryAugmented {
            memory_size: 2,
            update_strategy: MemoryUpdateStrategy::FIFO,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);
        // `new` ignores the strategy's memory_size, so set it explicitly.
        optimizer.memory_buffer.max_size = 2;

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        optimizer
            .start_task("task1".to_string(), initial_params)
            .expect("unwrap failed");

        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        optimizer
            .update_current_task(&gradient, 0.5)
            .expect("unwrap failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 1);

        optimizer
            .update_current_task(&gradient, 0.6)
            .expect("unwrap failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);

        // Third insert evicts the oldest entry; length stays at capacity.
        optimizer
            .update_current_task(&gradient, 0.7)
            .expect("unwrap failed");
        assert_eq!(optimizer.memory_buffer.examples.len(), 2);
    }

    /// Metrics reflect the regret bound and the mean of the loss window.
    #[test]
    fn test_performance_metrics() {
        let strategy = OnlineLearningStrategy::AdaptiveSGD {
            initial_lr: 0.01,
            adaptation_method: LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
        };

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

        // Seed the history and regret directly rather than running updates.
        optimizer.performance_history.push_back(0.8);
        optimizer.performance_history.push_back(0.6);
        optimizer.performance_history.push_back(0.4);
        optimizer.regret_bound = 0.5;

        let metrics = optimizer.get_performance_metrics();

        assert_relative_eq!(metrics.cumulative_regret, 0.5, epsilon = 1e-6);
        assert_relative_eq!(metrics.average_loss, 0.6, epsilon = 1e-6);
    }

    /// Stats average over every recorded sample across all tasks.
    #[test]
    fn test_lifelong_stats() {
        let strategy = LifelongStrategy::MetaLearning {
            meta_lr: 0.001,
            inner_steps: 5,
            task_embedding_size: 64,
        };

        let mut optimizer = LifelongOptimizer::<f64, scirs2_core::ndarray::Ix1>::new(strategy);

        let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        optimizer
            .start_task("task1".to_string(), initial_params.clone())
            .expect("unwrap failed");
        optimizer
            .start_task("task2".to_string(), initial_params)
            .expect("unwrap failed");

        optimizer
            .task_performance
            .get_mut("task1")
            .expect("unwrap failed")
            .extend(vec![0.8, 0.7]);
        optimizer
            .task_performance
            .get_mut("task2")
            .expect("unwrap failed")
            .extend(vec![0.9, 0.8]);

        let stats = optimizer.get_lifelong_stats();

        assert_eq!(stats.num_tasks, 2);
        // Mean of [0.8, 0.7, 0.9, 0.8] = 0.8.
        assert_relative_eq!(stats.average_performance, 0.8, epsilon = 1e-6);
    }

    /// Every adaptation scheme completes one update without error.
    #[test]
    fn test_learning_rate_adaptations() {
        let strategies = vec![
            LearningRateAdaptation::AdaGrad { epsilon: 1e-8 },
            LearningRateAdaptation::RMSprop {
                decay: 0.9,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::Adam {
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
            },
            LearningRateAdaptation::ExponentialDecay { decay_rate: 0.99 },
            LearningRateAdaptation::InverseScaling { power: 0.5 },
        ];

        for adaptation in strategies {
            let strategy = OnlineLearningStrategy::AdaptiveSGD {
                initial_lr: 0.01,
                adaptation_method: adaptation,
            };

            let initial_params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
            let mut optimizer = OnlineOptimizer::new(strategy, initial_params);

            let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
            let result = optimizer.online_update(&gradient, 0.5);

            assert!(result.is_ok());
            assert_eq!(optimizer.step_count, 1);
        }
    }
}