//! Continual-learning strategies: Elastic Weight Consolidation (EWC) and
//! progressive neural networks.

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::Float;
use std::collections::HashMap;
use std::fmt::Debug;

use crate::error::{OptimError, Result};

/// Elastic Weight Consolidation: a quadratic penalty that discourages
/// parameters from drifting away from values important for earlier tasks.
pub struct ElasticWeightConsolidation<T: Float + Debug + Send + Sync + 'static> {
    /// Strength of the EWC penalty.
    lambda: T,
    /// Diagonal Fisher information estimated for the most recent task.
    fisher_diagonal: HashMap<String, Array1<T>>,
    /// Anchor parameters from the most recently consolidated task.
    anchor_parameters: HashMap<String, Array1<T>>,
    /// Fisher diagonals recorded per task (a single running estimate in online mode).
    task_fisher_diagonals: Vec<HashMap<String, Array1<T>>>,
    /// Anchor parameters recorded per task.
    task_anchor_parameters: Vec<HashMap<String, Array1<T>>>,
    /// Number of gradient samples used to estimate the Fisher diagonal.
    num_samples_fisher: usize,
    /// Whether to use online EWC (one decayed running Fisher instead of one per task).
    online: bool,
    /// Decay applied to the previous Fisher estimate in online mode.
    gamma: T,
}

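// The penalty implemented below is the standard EWC objective
//
//     L_ewc(theta) = (lambda / 2) * sum_t sum_i F_{t,i} * (theta_i - theta*_{t,i})^2
//
// where F_{t,i} is the diagonal Fisher information after task t and
// theta*_{t,i} the anchored parameters. In online mode a single running
// Fisher is kept instead, merged as `gamma * F_old + F_new`.
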
impl<T: Float + Debug + Send + Sync + 'static> ElasticWeightConsolidation<T> {
    /// Creates a new EWC regularizer with penalty strength `lambda`.
    pub fn new(lambda: T) -> Self {
        Self {
            lambda,
            fisher_diagonal: HashMap::new(),
            anchor_parameters: HashMap::new(),
            task_fisher_diagonals: Vec::new(),
            task_anchor_parameters: Vec::new(),
            num_samples_fisher: 100,
            online: false,
            gamma: T::from(0.95).unwrap_or_else(|| T::one()),
        }
    }

    /// Sets the number of gradient samples used for the Fisher estimate.
    pub fn with_num_samples(mut self, n: usize) -> Self {
        self.num_samples_fisher = n;
        self
    }

    /// Enables or disables online EWC.
    pub fn with_online(mut self, online: bool) -> Self {
        self.online = online;
        self
    }

    /// Sets the decay factor `gamma` used to merge Fisher estimates in online mode.
    pub fn with_gamma(mut self, gamma: T) -> Self {
        self.gamma = gamma;
        self
    }

    /// Estimates the diagonal of the Fisher information matrix by averaging
    /// squared gradients over `num_samples_fisher` calls to `gradients_fn`.
    pub fn compute_fisher_diagonal(
        &mut self,
        parameters: &HashMap<String, Array1<T>>,
        gradients_fn: impl Fn(&HashMap<String, Array1<T>>) -> Result<HashMap<String, Array1<T>>>,
    ) -> Result<()> {
        if parameters.is_empty() {
            return Err(OptimError::InsufficientData(
                "parameters map is empty".to_string(),
            ));
        }

        if self.num_samples_fisher == 0 {
            return Err(OptimError::InvalidConfig(
                "num_samples_fisher must be > 0".to_string(),
            ));
        }

        let mut fisher_accum: HashMap<String, Array1<T>> = HashMap::new();
        for (name, param) in parameters {
            fisher_accum.insert(name.clone(), Array1::from_elem(param.len(), T::zero()));
        }

        let n_samples = T::from(self.num_samples_fisher).unwrap_or_else(|| T::one());

        for _ in 0..self.num_samples_fisher {
            let grads = gradients_fn(parameters)?;
            for (name, grad) in &grads {
                if let Some(accum) = fisher_accum.get_mut(name) {
                    if accum.len() != grad.len() {
                        return Err(OptimError::ComputationError(format!(
                            "gradient dimension mismatch for '{}': expected {}, got {}",
                            name,
                            accum.len(),
                            grad.len()
                        )));
                    }
                    for (a, g) in accum.iter_mut().zip(grad.iter()) {
                        *a = *a + (*g) * (*g);
                    }
                }
            }
        }

        let mut fisher = HashMap::new();
        for (name, mut accum) in fisher_accum {
            for a in accum.iter_mut() {
                *a = *a / n_samples;
            }
            fisher.insert(name, accum);
        }

        self.fisher_diagonal = fisher;
        Ok(())
    }

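    // The loop above is a Monte Carlo estimate of the diagonal Fisher,
    //
    //     F_i ~= (1 / N) * sum_n g_i(n)^2,
    //
    // which assumes `gradients_fn` returns a fresh stochastic gradient on each
    // call (e.g. from a resampled minibatch). With a deterministic closure, as
    // in the tests below, the estimate reduces to the squared gradient itself.
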
    /// Anchors the current parameters for the just-finished task and records
    /// the Fisher diagonal computed by `compute_fisher_diagonal`. In online
    /// mode the new Fisher is merged into the running estimate as
    /// `gamma * old + new`.
    pub fn consolidate(&mut self, parameters: &HashMap<String, Array1<T>>) -> Result<()> {
        if self.fisher_diagonal.is_empty() {
            return Err(OptimError::InvalidState(
                "Fisher diagonal not computed; call compute_fisher_diagonal first".to_string(),
            ));
        }

        self.anchor_parameters = parameters.clone();
        self.task_anchor_parameters.push(parameters.clone());

        if self.online {
            let mut merged = HashMap::new();
            for (name, new_fisher) in &self.fisher_diagonal {
                let updated = if let Some(old_fisher) = self
                    .task_fisher_diagonals
                    .last()
                    .and_then(|map| map.get(name))
                {
                    let mut result = Array1::from_elem(new_fisher.len(), T::zero());
                    for i in 0..result.len() {
                        result[i] = self.gamma * old_fisher[i] + new_fisher[i];
                    }
                    result
                } else {
                    new_fisher.clone()
                };
                merged.insert(name.clone(), updated);
            }
            self.task_fisher_diagonals.push(merged);
        } else {
            self.task_fisher_diagonals
                .push(self.fisher_diagonal.clone());
        }

        Ok(())
    }

    /// Computes the EWC penalty `lambda / 2 * sum F * (theta - theta*)^2`,
    /// summed over all consolidated tasks (offline) or evaluated against the
    /// single running Fisher (online).
    pub fn ewc_penalty(&self, parameters: &HashMap<String, Array1<T>>) -> Result<T> {
        if self.anchor_parameters.is_empty() {
            return Err(OptimError::InvalidState(
                "no anchor parameters stored; call consolidate first".to_string(),
            ));
        }

        let half = T::from(0.5).unwrap_or_else(|| T::one());
        let mut total_penalty = T::zero();

        if self.online {
            let fisher_map = self.task_fisher_diagonals.last().ok_or_else(|| {
                OptimError::InvalidState("no consolidated Fisher available".to_string())
            })?;

            for (name, anchor) in &self.anchor_parameters {
                let current = parameters.get(name).ok_or_else(|| {
                    OptimError::InvalidState(format!("parameter '{}' not found in input", name))
                })?;
                let fisher = fisher_map.get(name).ok_or_else(|| {
                    OptimError::InvalidState(format!("Fisher information for '{}' not found", name))
                })?;

                for i in 0..anchor.len() {
                    let diff = current[i] - anchor[i];
                    total_penalty = total_penalty + fisher[i] * diff * diff;
                }
            }
        } else {
            for (task_idx, task_fisher) in self.task_fisher_diagonals.iter().enumerate() {
                let task_anchor = &self.task_anchor_parameters[task_idx];
                for (name, anchor) in task_anchor {
                    let current = parameters.get(name).ok_or_else(|| {
                        OptimError::InvalidState(format!("parameter '{}' not found in input", name))
                    })?;
                    if let Some(fisher) = task_fisher.get(name) {
                        for i in 0..anchor.len() {
                            let diff = current[i] - anchor[i];
                            total_penalty = total_penalty + fisher[i] * diff * diff;
                        }
                    }
                }
            }
        }

        Ok(self.lambda * half * total_penalty)
    }

    /// Computes the gradient of the EWC penalty with respect to `parameters`:
    /// `lambda * F * (theta - theta*)`, accumulated over tasks in offline mode.
    pub fn ewc_gradient(
        &self,
        parameters: &HashMap<String, Array1<T>>,
    ) -> Result<HashMap<String, Array1<T>>> {
        if self.anchor_parameters.is_empty() {
            return Err(OptimError::InvalidState(
                "no anchor parameters stored; call consolidate first".to_string(),
            ));
        }

        let mut gradients: HashMap<String, Array1<T>> = HashMap::new();
        for (name, param) in parameters {
            gradients.insert(name.clone(), Array1::from_elem(param.len(), T::zero()));
        }

        if self.online {
            let fisher_map = self.task_fisher_diagonals.last().ok_or_else(|| {
                OptimError::InvalidState("no consolidated Fisher available".to_string())
            })?;

            for (name, anchor) in &self.anchor_parameters {
                let current = parameters.get(name).ok_or_else(|| {
                    OptimError::InvalidState(format!("parameter '{}' not found in input", name))
                })?;
                let fisher = fisher_map.get(name).ok_or_else(|| {
                    OptimError::InvalidState(format!("Fisher information for '{}' not found", name))
                })?;
                let grad = gradients.get_mut(name).ok_or_else(|| {
                    OptimError::InvalidState(format!("gradient entry for '{}' not found", name))
                })?;

                for i in 0..anchor.len() {
                    grad[i] = grad[i] + self.lambda * fisher[i] * (current[i] - anchor[i]);
                }
            }
        } else {
            for (task_idx, task_fisher) in self.task_fisher_diagonals.iter().enumerate() {
                let task_anchor = &self.task_anchor_parameters[task_idx];
                for (name, anchor) in task_anchor {
                    let current = parameters.get(name).ok_or_else(|| {
                        OptimError::InvalidState(format!("parameter '{}' not found in input", name))
                    })?;
                    if let Some(fisher) = task_fisher.get(name) {
                        let grad = gradients.get_mut(name).ok_or_else(|| {
                            OptimError::InvalidState(format!(
                                "gradient entry for '{}' not found",
                                name
                            ))
                        })?;

                        for i in 0..anchor.len() {
                            grad[i] = grad[i] + self.lambda * fisher[i] * (current[i] - anchor[i]);
                        }
                    }
                }
            }
        }

        Ok(gradients)
    }

    /// Returns the number of tasks consolidated so far.
    pub fn num_tasks(&self) -> usize {
        self.task_fisher_diagonals.len()
    }
}

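// Minimal usage sketch. The `loss_grads` closure is a stand-in for any
// gradient source with the signature expected by `compute_fisher_diagonal`:
//
//     let mut ewc = ElasticWeightConsolidation::<f64>::new(100.0);
//     ewc.compute_fisher_diagonal(&params, |p| loss_grads(p))?;
//     ewc.consolidate(&params)?;                    // anchor after task A
//     let penalty = ewc.ewc_penalty(&new_params)?;  // add to task B's loss
//     let reg = ewc.ewc_gradient(&new_params)?;     // add to task B's gradients
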
/// A single column (sub-network) in a progressive network. Columns trained on
/// earlier tasks are frozen and feed later columns through lateral connections.
#[derive(Debug, Clone)]
pub struct NetworkColumn<T: Float + Debug + Send + Sync + 'static> {
    /// Weight matrices, one per layer, each of shape `(fan_out, fan_in)`.
    pub weights: Vec<Array2<T>>,
    /// Bias vectors, one per layer.
    pub biases: Vec<Array1<T>>,
    /// Whether this column is frozen (no longer trainable).
    pub frozen: bool,
}

impl<T: Float + Debug + Send + Sync + 'static> NetworkColumn<T> {
    /// Returns the number of layers in this column.
    pub fn num_layers(&self) -> usize {
        self.weights.len()
    }

    /// Computes one layer: `z = W * input + b`, plus an optional lateral
    /// contribution added to the pre-activation, followed by ReLU on hidden
    /// layers (identity on the output layer). Returns `(pre_activation, activation)`.
    fn forward_layer(
        &self,
        input: &Array1<T>,
        layer_idx: usize,
        lateral: Option<&Array1<T>>,
    ) -> Result<(Array1<T>, Array1<T>)> {
        if layer_idx >= self.weights.len() {
            return Err(OptimError::NetworkError(format!(
                "layer index {} out of range (column has {} layers)",
                layer_idx,
                self.weights.len()
            )));
        }
        let w = &self.weights[layer_idx];
        let b = &self.biases[layer_idx];

        let out_dim = w.nrows();
        let in_dim = w.ncols();
        if input.len() != in_dim {
            return Err(OptimError::NetworkError(format!(
                "input dimension mismatch at layer {}: weight expects {}, got {}",
                layer_idx,
                in_dim,
                input.len()
            )));
        }

        let mut z = b.clone();
        for i in 0..out_dim {
            let mut sum = T::zero();
            for j in 0..in_dim {
                sum = sum + w[[i, j]] * input[j];
            }
            z[i] = z[i] + sum;
        }

        // Lateral input enters the pre-activation so that it passes through
        // the same nonlinearity as the column's own transformation.
        if let Some(lat) = lateral {
            if lat.len() != out_dim {
                return Err(OptimError::NetworkError(format!(
                    "lateral dimension mismatch at layer {}: expected {}, got {}",
                    layer_idx,
                    out_dim,
                    lat.len()
                )));
            }
            for i in 0..out_dim {
                z[i] = z[i] + lat[i];
            }
        }

        let is_last_layer = layer_idx == self.weights.len() - 1;
        let a = if is_last_layer {
            z.clone()
        } else {
            z.mapv(|v| if v > T::zero() { v } else { T::zero() })
        };

        Ok((z, a))
    }
}

/// Progressive neural networks: one column per task; earlier columns are
/// frozen and feed later columns through lateral connections.
pub struct ProgressiveNetworks<T: Float + Debug + Send + Sync + 'static> {
    /// One column per task.
    columns: Vec<NetworkColumn<T>>,
    /// Per-column lateral weights indexed as `[col][layer]`.
    lateral_connections: Vec<Vec<Array2<T>>>,
    /// Index of the most recently added (trainable) column.
    active_column: usize,
    /// Hidden layer sizes shared by all columns.
    hidden_sizes: Vec<usize>,
}

impl<T: Float + Debug + Send + Sync + 'static> ProgressiveNetworks<T> {
    /// Creates an empty network; all columns share `hidden_sizes`.
    pub fn new(hidden_sizes: Vec<usize>) -> Self {
        Self {
            columns: Vec::new(),
            lateral_connections: Vec::new(),
            active_column: 0,
            hidden_sizes,
        }
    }

    /// Adds a fresh column for a new task, freezing all existing columns, and
    /// returns the new column's index. Weights are initialized with a
    /// deterministic, Glorot-scaled pattern so results are reproducible.
    pub fn add_task_column(&mut self, input_size: usize, output_size: usize) -> Result<usize> {
        if input_size == 0 || output_size == 0 {
            return Err(OptimError::InvalidConfig(
                "input_size and output_size must be > 0".to_string(),
            ));
        }

        let col_id = self.columns.len();

        // Earlier columns are frozen as soon as a new task begins.
        for col in &mut self.columns {
            col.frozen = true;
        }

        let mut layer_sizes = Vec::with_capacity(self.hidden_sizes.len() + 2);
        layer_sizes.push(input_size);
        layer_sizes.extend_from_slice(&self.hidden_sizes);
        layer_sizes.push(output_size);

        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for l in 0..layer_sizes.len() - 1 {
            let fan_in = layer_sizes[l];
            let fan_out = layer_sizes[l + 1];

            // Glorot/Xavier scale: sqrt(2 / (fan_in + fan_out)).
            let scale_val = T::from(2.0 / ((fan_in + fan_out) as f64)).unwrap_or_else(|| T::one());
            let scale = Float::sqrt(scale_val);

            let mut w = Array2::from_elem((fan_out, fan_in), T::zero());
            for i in 0..fan_out {
                for j in 0..fan_in {
                    let idx = (i * fan_in + j) as f64;
                    let sign = if (i + j) % 2 == 0 {
                        T::one()
                    } else {
                        -T::one()
                    };
                    let magnitude =
                        scale * T::from((idx + 1.0).recip()).unwrap_or_else(|| T::one());
                    w[[i, j]] = sign * magnitude;
                }
            }

            let b = Array1::from_elem(fan_out, T::zero());

            weights.push(w);
            biases.push(b);
        }

        let column = NetworkColumn {
            weights,
            biases,
            frozen: false,
        };

        self.columns.push(column);

        // Lateral weights for every layer of the new column. Layer 0 has no
        // lateral input (all columns share the raw input), so it gets a
        // zero placeholder.
        let mut laterals_for_col = Vec::new();
        if col_id > 0 {
            for l in 0..layer_sizes.len() - 1 {
                let lateral_in = if l == 0 { 0 } else { col_id * layer_sizes[l] };
                let lateral_out = layer_sizes[l + 1];

                if lateral_in == 0 {
                    laterals_for_col.push(Array2::from_elem((lateral_out, 1), T::zero()));
                } else {
                    // Small scale keeps lateral contributions modest relative
                    // to the column's own weights.
                    let scale_val = T::from(0.1 / (lateral_in as f64)).unwrap_or_else(|| T::one());
                    let scale = Float::sqrt(scale_val);
                    let mut lat_w = Array2::from_elem((lateral_out, lateral_in), T::zero());
                    for i in 0..lateral_out {
                        for j in 0..lateral_in {
                            let sign = if (i + j) % 3 == 0 {
                                T::one()
                            } else if (i + j) % 3 == 1 {
                                -T::one()
                            } else {
                                T::zero()
                            };
                            lat_w[[i, j]] = sign * scale;
                        }
                    }
                    laterals_for_col.push(lat_w);
                }
            }
        }

        self.lateral_connections.push(laterals_for_col);
        self.active_column = col_id;

        Ok(col_id)
    }

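    // Shape bookkeeping example for the configuration used in the tests below:
    // `hidden_sizes = [16, 8]`, `input_size = 4`, `output_size = 2` yields
    // `layer_sizes = [4, 16, 8, 2]` and weights of shape (16, 4), (8, 16),
    // (2, 8). For column c > 0, the lateral matrix at layer l has shape
    // (layer_sizes[l + 1], c * layer_sizes[l]), consuming the concatenated
    // layer-(l - 1) activations of all earlier columns.
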
    /// Runs a forward pass for `task_id`, returning only the column's output.
    pub fn forward(&self, input: &Array1<T>, task_id: usize) -> Result<Array1<T>> {
        let (output, _) = self.forward_with_laterals(input, task_id)?;
        Ok(output)
    }

    /// Runs a forward pass through all columns up to `task_id`, applying
    /// lateral connections, and returns the target column's output together
    /// with every intermediate activation (flattened across columns).
    pub fn forward_with_laterals(
        &self,
        input: &Array1<T>,
        task_id: usize,
    ) -> Result<(Array1<T>, Vec<Array1<T>>)> {
        if task_id >= self.columns.len() {
            return Err(OptimError::InvalidState(format!(
                "task_id {} out of range; only {} columns exist",
                task_id,
                self.columns.len()
            )));
        }

        let mut all_activations: Vec<Vec<Array1<T>>> = Vec::new();

        for col in 0..=task_id {
            let column = &self.columns[col];
            let num_layers = column.num_layers();
            let mut col_activations: Vec<Array1<T>> = Vec::new();
            let mut h = input.clone();

            for l in 0..num_layers {
                // Gather the lateral input for layer l: the concatenated
                // layer-(l-1) activations of all earlier columns. Layer 0 has
                // no lateral input because every column sees the raw input.
                let mut lateral_contribution: Option<Array1<T>> = None;
                if col > 0 && l > 0 {
                    let laterals = &self.lateral_connections[col];
                    if l < laterals.len() {
                        let lat_w = &laterals[l];
                        let mut lateral_input_parts: Vec<T> = Vec::new();
                        for prev_col_acts in all_activations.iter().take(col) {
                            if l - 1 < prev_col_acts.len() {
                                lateral_input_parts
                                    .extend(prev_col_acts[l - 1].iter().copied());
                            }
                        }

                        if !lateral_input_parts.is_empty()
                            && lat_w.ncols() == lateral_input_parts.len()
                        {
                            let lateral_input = Array1::from_vec(lateral_input_parts);
                            let mut contribution = Array1::from_elem(lat_w.nrows(), T::zero());
                            for i in 0..lat_w.nrows() {
                                let mut sum = T::zero();
                                for j in 0..lat_w.ncols() {
                                    sum = sum + lat_w[[i, j]] * lateral_input[j];
                                }
                                contribution[i] = sum;
                            }
                            lateral_contribution = Some(contribution);
                        }
                    }
                }

                let (_pre, post) = column.forward_layer(&h, l, lateral_contribution.as_ref())?;
                col_activations.push(post.clone());
                h = post;
            }

            all_activations.push(col_activations);
        }

        let output = all_activations[task_id]
            .last()
            .ok_or_else(|| OptimError::NetworkError("column has no layers".to_string()))?
            .clone();

        let mut intermediates: Vec<Array1<T>> = Vec::new();
        for col_acts in &all_activations {
            for act in col_acts {
                intermediates.push(act.clone());
            }
        }

        Ok((output, intermediates))
    }

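    // With laterals applied to the pre-activation, column c follows the
    // progressive-networks recurrence
    //
    //     h_c^(l) = f( W_c^(l) * h_c^(l-1) + U_c^(l) * [h_0^(l-1); ...; h_{c-1}^(l-1)] )
    //
    // where [;] concatenates activations from all earlier (frozen) columns.
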
    /// Marks the given column as frozen so it is no longer trainable.
    pub fn freeze_column(&mut self, column_id: usize) -> Result<()> {
        if column_id >= self.columns.len() {
            return Err(OptimError::InvalidState(format!(
                "column_id {} out of range; only {} columns exist",
                column_id,
                self.columns.len()
            )));
        }
        self.columns[column_id].frozen = true;
        Ok(())
    }

    /// Returns the number of task columns added so far.
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    /// Returns a reference to the given column's parameters.
    pub fn get_column_parameters(&self, column_id: usize) -> Result<&NetworkColumn<T>> {
        self.columns.get(column_id).ok_or_else(|| {
            OptimError::InvalidState(format!(
                "column_id {} out of range; only {} columns exist",
                column_id,
                self.columns.len()
            ))
        })
    }
}

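// Minimal usage sketch (`x` is any input of the declared size):
//
//     let mut pn = ProgressiveNetworks::<f64>::new(vec![16, 8]);
//     let task0 = pn.add_task_column(4, 2)?;  // column 0, trainable
//     let task1 = pn.add_task_column(4, 2)?;  // column 0 is frozen automatically
//     let y = pn.forward(&x, task1)?;         // uses laterals from column 0
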
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    type F = f64;

    /// Builds a parameter map with `size`-element arrays filled with `value`.
    fn make_params(names: &[&str], size: usize, value: f64) -> HashMap<String, Array1<F>> {
        let mut map = HashMap::new();
        for name in names {
            map.insert(name.to_string(), Array1::from_elem(size, value));
        }
        map
    }

    #[test]
    fn test_ewc_creation_and_configuration() {
        let ewc: ElasticWeightConsolidation<F> = ElasticWeightConsolidation::new(1000.0)
            .with_num_samples(50)
            .with_online(true)
            .with_gamma(0.9);

        assert_eq!(ewc.num_tasks(), 0);
        assert!(ewc.online);
        assert!((ewc.gamma - 0.9).abs() < 1e-12);
        assert!((ewc.lambda - 1000.0).abs() < 1e-12);
        assert_eq!(ewc.num_samples_fisher, 50);
    }

    #[test]
    fn test_fisher_diagonal_computation() {
        let mut ewc: ElasticWeightConsolidation<F> =
            ElasticWeightConsolidation::new(1.0).with_num_samples(10);

        let params = make_params(&["w1", "w2"], 3, 1.0);

        // Deterministic gradients: the Fisher estimate equals the squared gradient.
        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
            let mut g = HashMap::new();
            g.insert("w1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
            g.insert("w2".to_string(), Array1::from_vec(vec![0.5, 0.5, 0.5]));
            Ok(g)
        };

        ewc.compute_fisher_diagonal(&params, grad_fn)
            .expect("compute_fisher_diagonal should succeed");

        let f1 = ewc
            .fisher_diagonal
            .get("w1")
            .expect("w1 Fisher should exist");
        assert!((f1[0] - 1.0).abs() < 1e-12);
        assert!((f1[1] - 4.0).abs() < 1e-12);
        assert!((f1[2] - 9.0).abs() < 1e-12);

        let f2 = ewc
            .fisher_diagonal
            .get("w2")
            .expect("w2 Fisher should exist");
        assert!((f2[0] - 0.25).abs() < 1e-12);
    }

    #[test]
    fn test_ewc_penalty_computation() {
        let mut ewc: ElasticWeightConsolidation<F> =
            ElasticWeightConsolidation::new(2.0).with_num_samples(5);

        let anchor = make_params(&["w"], 2, 1.0);

        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
            let mut g = HashMap::new();
            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));
            Ok(g)
        };

        ewc.compute_fisher_diagonal(&anchor, grad_fn)
            .expect("compute Fisher should succeed");
        ewc.consolidate(&anchor)
            .expect("consolidate should succeed");

        let mut current = HashMap::new();
        current.insert("w".to_string(), Array1::from_vec(vec![2.0, 3.0]));

        // Fisher = [1, 1], diffs = [1, 2], so
        // penalty = lambda/2 * (1*1^2 + 1*2^2) = 2/2 * 5 = 5.0.
        let penalty = ewc
            .ewc_penalty(&current)
            .expect("ewc_penalty should succeed");
        assert!(
            (penalty - 5.0).abs() < 1e-12,
            "expected penalty 5.0, got {}",
            penalty
        );
    }

    #[test]
    fn test_ewc_gradient_computation() {
        let mut ewc: ElasticWeightConsolidation<F> =
            ElasticWeightConsolidation::new(2.0).with_num_samples(5);

        let anchor = make_params(&["w"], 2, 1.0);

        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
            let mut g = HashMap::new();
            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));
            Ok(g)
        };

        ewc.compute_fisher_diagonal(&anchor, grad_fn)
            .expect("compute Fisher should succeed");
        ewc.consolidate(&anchor)
            .expect("consolidate should succeed");

        let mut current = HashMap::new();
        current.insert("w".to_string(), Array1::from_vec(vec![2.0, 3.0]));

        // grad_i = lambda * F_i * (theta_i - theta*_i) = 2*1*1 and 2*1*2.
        let grads = ewc
            .ewc_gradient(&current)
            .expect("ewc_gradient should succeed");
        let gw = grads.get("w").expect("gradient for w should exist");
        assert!((gw[0] - 2.0).abs() < 1e-12, "expected 2.0, got {}", gw[0]);
        assert!((gw[1] - 4.0).abs() < 1e-12, "expected 4.0, got {}", gw[1]);
    }

    #[test]
    fn test_consolidation_workflow() {
        let mut ewc: ElasticWeightConsolidation<F> =
            ElasticWeightConsolidation::new(1.0).with_num_samples(5);

        let params_task1 = make_params(&["w"], 2, 1.0);

        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
            let mut g = HashMap::new();
            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));
            Ok(g)
        };

        ewc.compute_fisher_diagonal(&params_task1, grad_fn)
            .expect("Fisher computation should succeed");
        ewc.consolidate(&params_task1)
            .expect("consolidation should succeed");
        assert_eq!(ewc.num_tasks(), 1);

        let params_task2 = make_params(&["w"], 2, 2.0);
        ewc.compute_fisher_diagonal(&params_task2, grad_fn)
            .expect("Fisher computation for task 2 should succeed");
        ewc.consolidate(&params_task2)
            .expect("consolidation for task 2 should succeed");
        assert_eq!(ewc.num_tasks(), 2);

        // params_task2 matches task 2's anchor exactly but differs from task
        // 1's anchor, so the offline penalty must be strictly positive.
        let penalty = ewc
            .ewc_penalty(&params_task2)
            .expect("penalty should succeed");
        assert!(
            penalty > 0.0,
            "penalty should be positive due to task1 Fisher, got {}",
            penalty
        );
    }

    #[test]
    fn test_online_ewc_multiple_tasks() {
        let mut ewc: ElasticWeightConsolidation<F> = ElasticWeightConsolidation::new(1.0)
            .with_num_samples(5)
            .with_online(true)
            .with_gamma(0.5);

        let params1 = make_params(&["w"], 2, 0.0);
        let grad_fn1 = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
            let mut g = HashMap::new();
            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));
            Ok(g)
        };
        ewc.compute_fisher_diagonal(&params1, grad_fn1)
            .expect("Fisher 1 should succeed");
        ewc.consolidate(&params1)
            .expect("consolidate 1 should succeed");
        assert_eq!(ewc.num_tasks(), 1);

        let fisher1 = ewc.task_fisher_diagonals[0]
            .get("w")
            .expect("Fisher should exist");
        assert!((fisher1[0] - 1.0).abs() < 1e-12);

        let params2 = make_params(&["w"], 2, 1.0);
        let grad_fn2 = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
            let mut g = HashMap::new();
            g.insert("w".to_string(), Array1::from_vec(vec![2.0, 2.0]));
            Ok(g)
        };
        ewc.compute_fisher_diagonal(&params2, grad_fn2)
            .expect("Fisher 2 should succeed");
        ewc.consolidate(&params2)
            .expect("consolidate 2 should succeed");
        assert_eq!(ewc.num_tasks(), 2);

        // Online merge: gamma * F_old + F_new = 0.5 * 1.0 + 4.0 = 4.5.
        let fisher2 = ewc.task_fisher_diagonals[1]
            .get("w")
            .expect("Fisher should exist");
        assert!(
            (fisher2[0] - 4.5).abs() < 1e-12,
            "expected 4.5, got {}",
            fisher2[0]
        );

        let mut test_params = HashMap::new();
        test_params.insert("w".to_string(), Array1::from_vec(vec![2.0, 2.0]));

        // Online penalty uses only the last anchor ([1, 1]) and merged Fisher:
        // lambda/2 * (4.5 * 1^2 + 4.5 * 1^2) = 4.5.
        let penalty = ewc
            .ewc_penalty(&test_params)
            .expect("penalty should succeed");
        assert!(
            (penalty - 4.5).abs() < 1e-12,
            "expected 4.5, got {}",
            penalty
        );
    }

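    // Sanity check on the error paths: penalty and gradient queries must fail
    // before any task has been consolidated, rather than returning zero.
    #[test]
    fn test_ewc_errors_before_consolidation() {
        let ewc: ElasticWeightConsolidation<F> = ElasticWeightConsolidation::new(1.0);
        let params = make_params(&["w"], 2, 1.0);
        assert!(ewc.ewc_penalty(&params).is_err());
        assert!(ewc.ewc_gradient(&params).is_err());
    }
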
    #[test]
    fn test_progressive_network_creation_and_columns() {
        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![16, 8]);

        assert_eq!(pn.num_columns(), 0);

        let col0 = pn
            .add_task_column(4, 2)
            .expect("add column 0 should succeed");
        assert_eq!(col0, 0);
        assert_eq!(pn.num_columns(), 1);

        let params0 = pn
            .get_column_parameters(0)
            .expect("get column 0 should succeed");
        assert_eq!(params0.num_layers(), 3);
        assert!(!params0.frozen);

        let col1 = pn
            .add_task_column(4, 2)
            .expect("add column 1 should succeed");
        assert_eq!(col1, 1);
        assert_eq!(pn.num_columns(), 2);

        let params0_after = pn
            .get_column_parameters(0)
            .expect("get column 0 should succeed");
        assert!(params0_after.frozen, "column 0 should be frozen");

        let params1 = pn
            .get_column_parameters(1)
            .expect("get column 1 should succeed");
        assert!(!params1.frozen, "column 1 should not be frozen");
    }

    #[test]
    fn test_progressive_network_forward_single_column() {
        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8]);

        pn.add_task_column(4, 2).expect("add column should succeed");

        let input = Array1::from_vec(vec![1.0, 0.5, -0.3, 0.8]);
        let output = pn.forward(&input, 0).expect("forward should succeed");

        assert_eq!(output.len(), 2, "output should have 2 elements");

        for val in output.iter() {
            assert!(val.is_finite(), "output should be finite, got {}", val);
        }
    }

    #[test]
    fn test_progressive_network_forward_with_laterals() {
        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8]);

        pn.add_task_column(4, 2)
            .expect("add column 0 should succeed");
        pn.add_task_column(4, 2)
            .expect("add column 1 should succeed");

        let input = Array1::from_vec(vec![1.0, 0.5, -0.3, 0.8]);

        let output0 = pn
            .forward(&input, 0)
            .expect("forward task 0 should succeed");
        assert_eq!(output0.len(), 2);

        let (output1, intermediates) = pn
            .forward_with_laterals(&input, 1)
            .expect("forward_with_laterals task 1 should succeed");
        assert_eq!(output1.len(), 2);

        assert!(
            !intermediates.is_empty(),
            "should have intermediate activations"
        );

        // Both columns share the same deterministic initialization, so any
        // divergence between their outputs comes from the lateral connections.
        // The init makes no guarantee about the magnitude of that divergence,
        // so only finiteness is asserted here.
        let diff: f64 = output0
            .iter()
            .zip(output1.iter())
            .map(|(a, b)| (a - b).abs())
            .sum();
        for val in output1.iter() {
            assert!(val.is_finite(), "output should be finite, got {}", val);
        }
        let _ = diff;
    }

    #[test]
    fn test_progressive_network_freeze_column() {
        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8]);

        pn.add_task_column(4, 2).expect("add column should succeed");

        assert!(
            !pn.get_column_parameters(0).expect("get params").frozen,
            "column should not be frozen initially"
        );

        pn.freeze_column(0).expect("freeze should succeed");

        assert!(
            pn.get_column_parameters(0).expect("get params").frozen,
            "column should be frozen after freeze_column"
        );

        let err = pn.freeze_column(99);
        assert!(err.is_err(), "freezing out-of-range column should fail");
    }

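    // Sanity check that a mismatched input length surfaces as an error from
    // `forward_layer` rather than panicking.
    #[test]
    fn test_progressive_network_input_dim_mismatch() {
        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8]);
        pn.add_task_column(4, 2).expect("add column should succeed");

        let bad_input = Array1::from_vec(vec![1.0, 2.0]); // column expects length 4
        assert!(pn.forward(&bad_input, 0).is_err());
    }
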
    #[test]
    fn test_progressive_network_multiple_tasks_forward() {
        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8, 4]);

        for _ in 0..3 {
            pn.add_task_column(4, 2).expect("add column should succeed");
        }
        assert_eq!(pn.num_columns(), 3);

        let input = Array1::from_vec(vec![0.5, -0.5, 1.0, -1.0]);

        for task_id in 0..3 {
            let output = pn
                .forward(&input, task_id)
                .unwrap_or_else(|_| panic!("forward task {} should succeed", task_id));
            assert_eq!(
                output.len(),
                2,
                "output for task {} should have 2 elements",
                task_id
            );
            for val in output.iter() {
                assert!(
                    val.is_finite(),
                    "output for task {} should be finite",
                    task_id
                );
            }
        }

        let err = pn.forward(&input, 10);
        assert!(err.is_err(), "forward with invalid task_id should fail");
    }
}