Skip to main content

optirs_learned/
continual_learning.rs

1// Continual Learning for Optimizers
2//
3// This module implements continual learning techniques that allow neural networks
4// and optimizers to learn new tasks sequentially without catastrophically forgetting
5// previous ones. It includes:
6//
7// - Elastic Weight Consolidation (EWC): Regularization-based approach using Fisher
8//   information to protect important parameters for previous tasks.
9// - Progressive Networks: Architecture-based approach that adds new columns for new
10//   tasks while preserving previously learned knowledge via frozen columns and lateral
11//   connections.
12//
13// References:
14// - Kirkpatrick et al., "Overcoming catastrophic forgetting in neural networks" (2017)
15// - Rusu et al., "Progressive Neural Networks" (2016)
16
17use scirs2_core::ndarray::{Array1, Array2};
18use scirs2_core::numeric::Float;
19use std::collections::HashMap;
20use std::fmt::Debug;
21
22use crate::error::{OptimError, Result};
23
24// ---------------------------------------------------------------------------
25// Elastic Weight Consolidation (EWC)
26// ---------------------------------------------------------------------------
27
/// Elastic Weight Consolidation (EWC) regularizer for continual learning.
///
/// EWC prevents catastrophic forgetting by adding a quadratic penalty that
/// discourages changes to parameters that were important for previous tasks.
/// The importance of each parameter is estimated via the diagonal of the
/// Fisher information matrix.
///
/// Supports both standard (multi-task) EWC and online EWC with exponential
/// decay of older Fisher information.
pub struct ElasticWeightConsolidation<T: Float + Debug + Send + Sync + 'static> {
    /// Regularization strength (lambda); scales the quadratic penalty.
    lambda: T,
    /// Most recently computed diagonal Fisher information per parameter
    /// group (populated by `compute_fisher_diagonal`).
    fisher_diagonal: HashMap<String, Array1<T>>,
    /// Anchor parameters (theta*) from the most recent `consolidate` call.
    anchor_parameters: HashMap<String, Array1<T>>,
    /// Per-task Fisher diagonals. In standard EWC each entry is one task's
    /// Fisher; in online EWC each entry is the running (gamma-decayed)
    /// Fisher at consolidation time, and only the last entry is consulted.
    task_fisher_diagonals: Vec<HashMap<String, Array1<T>>>,
    /// Per-task anchor parameters; standard EWC sums the penalty over all
    /// of these, while online EWC uses only `anchor_parameters`.
    task_anchor_parameters: Vec<HashMap<String, Array1<T>>>,
    /// Number of gradient samples drawn to estimate the Fisher diagonal.
    num_samples_fisher: usize,
    /// Whether to use online EWC (running average of Fisher) instead of
    /// per-task accumulation.
    online: bool,
    /// Decay factor applied to the previous Fisher in online EWC.
    gamma: T,
}
55
56impl<T: Float + Debug + Send + Sync + 'static> ElasticWeightConsolidation<T> {
57    /// Create a new EWC regularizer with the given lambda (regularization strength).
58    ///
59    /// Defaults: num_samples_fisher = 100, online = false, gamma = 0.95
60    pub fn new(lambda: T) -> Self {
61        Self {
62            lambda,
63            fisher_diagonal: HashMap::new(),
64            anchor_parameters: HashMap::new(),
65            task_fisher_diagonals: Vec::new(),
66            task_anchor_parameters: Vec::new(),
67            num_samples_fisher: 100,
68            online: false,
69            gamma: T::from(0.95).unwrap_or_else(|| T::one()),
70        }
71    }
72
73    /// Set the number of gradient samples used to estimate the Fisher information.
74    pub fn with_num_samples(mut self, n: usize) -> Self {
75        self.num_samples_fisher = n;
76        self
77    }
78
79    /// Enable or disable online EWC mode.
80    pub fn with_online(mut self, online: bool) -> Self {
81        self.online = online;
82        self
83    }
84
85    /// Set the decay factor gamma for online EWC.
86    pub fn with_gamma(mut self, gamma: T) -> Self {
87        self.gamma = gamma;
88        self
89    }
90
91    /// Compute the diagonal Fisher information matrix by sampling gradients.
92    ///
93    /// `parameters` - current model parameters keyed by name.
94    /// `gradients_fn` - callable that returns stochastic gradients for a single sample.
95    ///
96    /// The Fisher diagonal is approximated as E[g_i^2] where g_i is the gradient
97    /// of the log-likelihood with respect to parameter i.
98    pub fn compute_fisher_diagonal(
99        &mut self,
100        parameters: &HashMap<String, Array1<T>>,
101        gradients_fn: impl Fn(&HashMap<String, Array1<T>>) -> Result<HashMap<String, Array1<T>>>,
102    ) -> Result<()> {
103        if parameters.is_empty() {
104            return Err(OptimError::InsufficientData(
105                "parameters map is empty".to_string(),
106            ));
107        }
108
109        if self.num_samples_fisher == 0 {
110            return Err(OptimError::InvalidConfig(
111                "num_samples_fisher must be > 0".to_string(),
112            ));
113        }
114
115        // Accumulator for squared gradients
116        let mut fisher_accum: HashMap<String, Array1<T>> = HashMap::new();
117        for (name, param) in parameters {
118            fisher_accum.insert(name.clone(), Array1::from_elem(param.len(), T::zero()));
119        }
120
121        let n_samples = T::from(self.num_samples_fisher).unwrap_or_else(|| T::one());
122
123        for _ in 0..self.num_samples_fisher {
124            let grads = gradients_fn(parameters)?;
125            for (name, grad) in &grads {
126                if let Some(accum) = fisher_accum.get_mut(name) {
127                    if accum.len() != grad.len() {
128                        return Err(OptimError::ComputationError(format!(
129                            "gradient dimension mismatch for '{}': expected {}, got {}",
130                            name,
131                            accum.len(),
132                            grad.len()
133                        )));
134                    }
135                    // Accumulate g_i^2
136                    for (a, g) in accum.iter_mut().zip(grad.iter()) {
137                        *a = *a + (*g) * (*g);
138                    }
139                }
140            }
141        }
142
143        // Average: F_i = (1/N) * sum(g_i^2)
144        let mut fisher = HashMap::new();
145        for (name, mut accum) in fisher_accum {
146            for a in accum.iter_mut() {
147                *a = *a / n_samples;
148            }
149            fisher.insert(name, accum);
150        }
151
152        self.fisher_diagonal = fisher;
153        Ok(())
154    }
155
156    /// Consolidate the current task by storing anchor parameters and Fisher information.
157    ///
158    /// This should be called after training on a task completes.
159    /// For online EWC, the existing Fisher is decayed by gamma and the new Fisher is added.
160    /// For standard EWC, the new Fisher is appended to the list of per-task Fishers.
161    pub fn consolidate(&mut self, parameters: &HashMap<String, Array1<T>>) -> Result<()> {
162        if self.fisher_diagonal.is_empty() {
163            return Err(OptimError::InvalidState(
164                "Fisher diagonal not computed; call compute_fisher_diagonal first".to_string(),
165            ));
166        }
167
168        // Store anchor parameters
169        self.anchor_parameters = parameters.clone();
170        self.task_anchor_parameters.push(parameters.clone());
171
172        if self.online {
173            // Online EWC: running Fisher = gamma * old_Fisher + new_Fisher
174            let mut merged = HashMap::new();
175            for (name, new_fisher) in &self.fisher_diagonal {
176                let updated = if let Some(old_fisher) = self
177                    .task_fisher_diagonals
178                    .last()
179                    .and_then(|map| map.get(name))
180                {
181                    // gamma * old + new
182                    let mut result = Array1::from_elem(new_fisher.len(), T::zero());
183                    for i in 0..result.len() {
184                        result[i] = self.gamma * old_fisher[i] + new_fisher[i];
185                    }
186                    result
187                } else {
188                    new_fisher.clone()
189                };
190                merged.insert(name.clone(), updated);
191            }
192            self.task_fisher_diagonals.push(merged);
193        } else {
194            // Standard EWC: store per-task Fisher
195            self.task_fisher_diagonals
196                .push(self.fisher_diagonal.clone());
197        }
198
199        Ok(())
200    }
201
202    /// Compute the EWC penalty for the current parameters.
203    ///
204    /// penalty = (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2
205    ///
206    /// For standard EWC, the penalty sums over all stored tasks.
207    /// For online EWC, only the most recent consolidated Fisher is used.
208    pub fn ewc_penalty(&self, parameters: &HashMap<String, Array1<T>>) -> Result<T> {
209        if self.anchor_parameters.is_empty() {
210            return Err(OptimError::InvalidState(
211                "no anchor parameters stored; call consolidate first".to_string(),
212            ));
213        }
214
215        let half = T::from(0.5).unwrap_or_else(|| T::one());
216        let mut total_penalty = T::zero();
217
218        if self.online {
219            // Online EWC: use the latest consolidated Fisher
220            let fisher_map = self.task_fisher_diagonals.last().ok_or_else(|| {
221                OptimError::InvalidState("no consolidated Fisher available".to_string())
222            })?;
223
224            for (name, anchor) in &self.anchor_parameters {
225                let current = parameters.get(name).ok_or_else(|| {
226                    OptimError::InvalidState(format!("parameter '{}' not found in input", name))
227                })?;
228                let fisher = fisher_map.get(name).ok_or_else(|| {
229                    OptimError::InvalidState(format!("Fisher information for '{}' not found", name))
230                })?;
231
232                for i in 0..anchor.len() {
233                    let diff = current[i] - anchor[i];
234                    total_penalty = total_penalty + fisher[i] * diff * diff;
235                }
236            }
237        } else {
238            // Standard EWC: sum penalty over all tasks, each with its own anchor
239            for (task_idx, task_fisher) in self.task_fisher_diagonals.iter().enumerate() {
240                let task_anchor = &self.task_anchor_parameters[task_idx];
241                for (name, anchor) in task_anchor {
242                    let current = parameters.get(name).ok_or_else(|| {
243                        OptimError::InvalidState(format!("parameter '{}' not found in input", name))
244                    })?;
245                    if let Some(fisher) = task_fisher.get(name) {
246                        for i in 0..anchor.len() {
247                            let diff = current[i] - anchor[i];
248                            total_penalty = total_penalty + fisher[i] * diff * diff;
249                        }
250                    }
251                }
252            }
253        }
254
255        Ok(self.lambda * half * total_penalty)
256    }
257
258    /// Compute the gradient of the EWC penalty with respect to current parameters.
259    ///
260    /// grad_i = lambda * F_i * (theta_i - theta*_i)
261    pub fn ewc_gradient(
262        &self,
263        parameters: &HashMap<String, Array1<T>>,
264    ) -> Result<HashMap<String, Array1<T>>> {
265        if self.anchor_parameters.is_empty() {
266            return Err(OptimError::InvalidState(
267                "no anchor parameters stored; call consolidate first".to_string(),
268            ));
269        }
270
271        let mut gradients: HashMap<String, Array1<T>> = HashMap::new();
272
273        // Initialize gradient arrays to zero
274        for (name, param) in parameters {
275            gradients.insert(name.clone(), Array1::from_elem(param.len(), T::zero()));
276        }
277
278        if self.online {
279            let fisher_map = self.task_fisher_diagonals.last().ok_or_else(|| {
280                OptimError::InvalidState("no consolidated Fisher available".to_string())
281            })?;
282
283            for (name, anchor) in &self.anchor_parameters {
284                let current = parameters.get(name).ok_or_else(|| {
285                    OptimError::InvalidState(format!("parameter '{}' not found in input", name))
286                })?;
287                let fisher = fisher_map.get(name).ok_or_else(|| {
288                    OptimError::InvalidState(format!("Fisher information for '{}' not found", name))
289                })?;
290                let grad = gradients.get_mut(name).ok_or_else(|| {
291                    OptimError::InvalidState(format!("gradient entry for '{}' not found", name))
292                })?;
293
294                for i in 0..anchor.len() {
295                    grad[i] = grad[i] + self.lambda * fisher[i] * (current[i] - anchor[i]);
296                }
297            }
298        } else {
299            for (task_idx, task_fisher) in self.task_fisher_diagonals.iter().enumerate() {
300                let task_anchor = &self.task_anchor_parameters[task_idx];
301                for (name, anchor) in task_anchor {
302                    let current = parameters.get(name).ok_or_else(|| {
303                        OptimError::InvalidState(format!("parameter '{}' not found in input", name))
304                    })?;
305                    if let Some(fisher) = task_fisher.get(name) {
306                        let grad = gradients.get_mut(name).ok_or_else(|| {
307                            OptimError::InvalidState(format!(
308                                "gradient entry for '{}' not found",
309                                name
310                            ))
311                        })?;
312
313                        for i in 0..anchor.len() {
314                            grad[i] = grad[i] + self.lambda * fisher[i] * (current[i] - anchor[i]);
315                        }
316                    }
317                }
318            }
319        }
320
321        Ok(gradients)
322    }
323
324    /// Return the number of tasks that have been consolidated so far.
325    pub fn num_tasks(&self) -> usize {
326        self.task_fisher_diagonals.len()
327    }
328}
329
330// ---------------------------------------------------------------------------
331// Progressive Networks
332// ---------------------------------------------------------------------------
333
/// A single column in a progressive network, representing one task's parameters.
#[derive(Debug, Clone)]
pub struct NetworkColumn<T: Float + Debug + Send + Sync + 'static> {
    /// Weight matrices per layer; layer `l` has shape `[output_l, input_l]`,
    /// i.e. `weights[l]` maps an `input_l`-vector to an `output_l`-vector.
    pub weights: Vec<Array2<T>>,
    /// Bias vectors per layer (length equals the layer's output dimension).
    pub biases: Vec<Array1<T>>,
    /// Whether the column is frozen; columns for previously learned tasks
    /// are frozen so training a new task cannot overwrite them.
    pub frozen: bool,
}
344
345impl<T: Float + Debug + Send + Sync + 'static> NetworkColumn<T> {
346    /// Number of layers in this column.
347    pub fn num_layers(&self) -> usize {
348        self.weights.len()
349    }
350
351    /// Forward pass through a single layer with ReLU activation.
352    /// Returns pre-activation and post-activation values.
353    fn forward_layer(&self, input: &Array1<T>, layer_idx: usize) -> Result<(Array1<T>, Array1<T>)> {
354        if layer_idx >= self.weights.len() {
355            return Err(OptimError::NetworkError(format!(
356                "layer index {} out of range (column has {} layers)",
357                layer_idx,
358                self.weights.len()
359            )));
360        }
361        let w = &self.weights[layer_idx];
362        let b = &self.biases[layer_idx];
363
364        // z = W * x + b  (matrix-vector product)
365        let out_dim = w.nrows();
366        let in_dim = w.ncols();
367        if input.len() != in_dim {
368            return Err(OptimError::NetworkError(format!(
369                "input dimension mismatch at layer {}: weight expects {}, got {}",
370                layer_idx,
371                in_dim,
372                input.len()
373            )));
374        }
375
376        let mut z = b.clone();
377        for i in 0..out_dim {
378            let mut sum = T::zero();
379            for j in 0..in_dim {
380                sum = sum + w[[i, j]] * input[j];
381            }
382            z[i] = z[i] + sum;
383        }
384
385        // ReLU activation (except for the last layer which is linear)
386        let is_last_layer = layer_idx == self.weights.len() - 1;
387        let a = if is_last_layer {
388            z.clone()
389        } else {
390            z.mapv(|v| if v > T::zero() { v } else { T::zero() })
391        };
392
393        Ok((z, a))
394    }
395}
396
/// Progressive Networks for continual learning.
///
/// Each new task gets its own column (set of layers). Previous columns are frozen
/// and lateral connections from frozen columns feed into the new column, enabling
/// knowledge transfer without forgetting.
pub struct ProgressiveNetworks<T: Float + Debug + Send + Sync + 'static> {
    /// One column per task, in the order tasks were added.
    columns: Vec<NetworkColumn<T>>,
    /// Lateral connection weights: `lateral_connections[col][layer]` maps the
    /// concatenated activations of all previous columns at the preceding layer
    /// to a contribution for column `col`. The first column has an empty list,
    /// and layer 0 holds a 1-column placeholder matrix that is never applied.
    lateral_connections: Vec<Vec<Array2<T>>>,
    /// Index of the currently active (trainable) column; set by
    /// `add_task_column` to the most recently added column.
    active_column: usize,
    /// Hidden layer sizes used as the architecture template for every column.
    hidden_sizes: Vec<usize>,
}
414
impl<T: Float + Debug + Send + Sync + 'static> ProgressiveNetworks<T> {
    /// Create a new progressive network with the given hidden layer sizes.
    ///
    /// No columns are created initially; call `add_task_column` for each task.
    pub fn new(hidden_sizes: Vec<usize>) -> Self {
        Self {
            columns: Vec::new(),
            lateral_connections: Vec::new(),
            active_column: 0,
            hidden_sizes,
        }
    }

    /// Add a new column for a new task.
    ///
    /// `input_size`  - dimensionality of the input features.
    /// `output_size` - dimensionality of the output.
    ///
    /// Returns the column index (task id).
    /// All previously existing columns are frozen, and the new column becomes
    /// the active one. Weights use a deterministic Xavier-scaled pattern
    /// rather than random initialization.
    pub fn add_task_column(&mut self, input_size: usize, output_size: usize) -> Result<usize> {
        if input_size == 0 || output_size == 0 {
            return Err(OptimError::InvalidConfig(
                "input_size and output_size must be > 0".to_string(),
            ));
        }

        let col_id = self.columns.len();

        // Freeze all existing columns
        for col in &mut self.columns {
            col.frozen = true;
        }

        // Build layer sizes: input -> hidden_1 -> ... -> hidden_n -> output
        let mut layer_sizes = Vec::with_capacity(self.hidden_sizes.len() + 2);
        layer_sizes.push(input_size);
        layer_sizes.extend_from_slice(&self.hidden_sizes);
        layer_sizes.push(output_size);

        // Initialize weights with small values (Xavier-like initialization)
        let mut weights = Vec::new();
        let mut biases = Vec::new();

        for l in 0..layer_sizes.len() - 1 {
            let fan_in = layer_sizes[l];
            let fan_out = layer_sizes[l + 1];

            // Xavier init scale: sqrt(2 / (fan_in + fan_out))
            let scale_val = T::from(2.0 / ((fan_in + fan_out) as f64)).unwrap_or_else(|| T::one());
            let scale = Float::sqrt(scale_val);

            // Deterministic initialization: alternate small positive/negative values
            let mut w = Array2::from_elem((fan_out, fan_in), T::zero());
            for i in 0..fan_out {
                for j in 0..fan_in {
                    // Simple deterministic pattern: sign alternates with (i + j),
                    // magnitude decays as 1 / (flat index + 1).
                    let idx = (i * fan_in + j) as f64;
                    let sign = if (i + j) % 2 == 0 {
                        T::one()
                    } else {
                        -T::one()
                    };
                    let magnitude =
                        scale * T::from((idx + 1.0).recip()).unwrap_or_else(|| T::one());
                    w[[i, j]] = sign * magnitude;
                }
            }

            // Biases start at zero.
            let b = Array1::from_elem(fan_out, T::zero());

            weights.push(w);
            biases.push(b);
        }

        let column = NetworkColumn {
            weights,
            biases,
            frozen: false,
        };

        self.columns.push(column);

        // Create lateral connections from all previous columns to this new column
        let mut laterals_for_col = Vec::new();
        if col_id > 0 {
            for l in 0..layer_sizes.len() - 1 {
                // At layer l, each previous column produces an activation of size layer_sizes[l+1]
                // (except for the input layer which doesn't have lateral connections).
                // Lateral input dimension = sum of previous columns' hidden sizes at layer l
                // For simplicity, each previous column contributes layer_sizes[l] activations
                // at layer l, assuming all columns share this column's architecture.
                // NOTE(review): if earlier columns were created with different
                // input/output sizes, this sizing is wrong for them; forward
                // skips mismatched laterals at runtime — confirm intended.
                let lateral_in = if l == 0 {
                    // No lateral connections at the first layer (input layer)
                    0
                } else {
                    // Each previous column contributes its hidden size at layer l
                    col_id * layer_sizes[l]
                };
                let lateral_out = layer_sizes[l + 1];

                if lateral_in == 0 {
                    // Placeholder (unused) for consistency so indices line up with layers
                    laterals_for_col.push(Array2::from_elem((lateral_out, 1), T::zero()));
                } else {
                    // Small deterministic initialization for lateral weights:
                    // sign cycles +1, -1, 0 with (i + j) mod 3.
                    let scale_val = T::from(0.1 / (lateral_in as f64)).unwrap_or_else(|| T::one());
                    let scale = Float::sqrt(scale_val);
                    let mut lat_w = Array2::from_elem((lateral_out, lateral_in), T::zero());
                    for i in 0..lateral_out {
                        for j in 0..lateral_in {
                            let sign = if (i + j) % 3 == 0 {
                                T::one()
                            } else if (i + j) % 3 == 1 {
                                -T::one()
                            } else {
                                T::zero()
                            };
                            lat_w[[i, j]] = sign * scale;
                        }
                    }
                    laterals_for_col.push(lat_w);
                }
            }
        }

        self.lateral_connections.push(laterals_for_col);
        self.active_column = col_id;

        Ok(col_id)
    }

    /// Forward pass through the network for a specific task.
    ///
    /// For columns beyond the first, lateral connections from all previous columns
    /// are added at each hidden layer (when dimensions line up; see
    /// `forward_with_laterals`).
    pub fn forward(&self, input: &Array1<T>, task_id: usize) -> Result<Array1<T>> {
        let (output, _) = self.forward_with_laterals(input, task_id)?;
        Ok(output)
    }

    /// Forward pass that also returns intermediate activations from all columns.
    ///
    /// Returns `(final_output, intermediate_activations)`. The intermediate
    /// activations are a flat list: all layer activations of column 0, then
    /// column 1, and so on up to `task_id` (not a nested `[col][layer]`
    /// structure).
    pub fn forward_with_laterals(
        &self,
        input: &Array1<T>,
        task_id: usize,
    ) -> Result<(Array1<T>, Vec<Array1<T>>)> {
        if task_id >= self.columns.len() {
            return Err(OptimError::InvalidState(format!(
                "task_id {} out of range; only {} columns exist",
                task_id,
                self.columns.len()
            )));
        }

        // Compute activations for all columns up to and including task_id
        // activations[col][layer] holds post-activation output of layer l in column col
        let mut all_activations: Vec<Vec<Array1<T>>> = Vec::new();

        for col in 0..=task_id {
            let column = &self.columns[col];
            let num_layers = column.num_layers();
            let mut col_activations: Vec<Array1<T>> = Vec::new();
            let mut h = input.clone();

            for l in 0..num_layers {
                // Lateral contribution from previous columns (only for col > 0 and l > 0)
                if col > 0 && l > 0 {
                    let laterals = &self.lateral_connections[col];
                    if !laterals.is_empty() && l < laterals.len() {
                        let lat_w = &laterals[l];
                        // Concatenate activations from all previous columns at layer l-1
                        // (previous layer's output for each previous column)
                        let mut lateral_input_parts: Vec<T> = Vec::new();
                        for prev_col_acts in all_activations.iter().take(col) {
                            if l - 1 < prev_col_acts.len() {
                                let prev_act = &prev_col_acts[l - 1];
                                lateral_input_parts.extend(prev_act.iter().copied());
                            }
                        }

                        if !lateral_input_parts.is_empty() {
                            let lateral_input = Array1::from_vec(lateral_input_parts);

                            // Dimension gate: silently skip the lateral term when the
                            // concatenated input doesn't match the lateral matrix.
                            if lat_w.ncols() == lateral_input.len() {
                                let lat_out_dim = lat_w.nrows();
                                let lat_in_dim = lat_w.ncols();
                                // lateral_contribution = lat_w * lateral_input
                                let mut lateral_contribution =
                                    Array1::from_elem(lat_out_dim, T::zero());
                                for i in 0..lat_out_dim {
                                    let mut sum = T::zero();
                                    for j in 0..lat_in_dim {
                                        sum = sum + lat_w[[i, j]] * lateral_input[j];
                                    }
                                    lateral_contribution[i] = sum;
                                }

                                // Add lateral contribution to input before this layer.
                                // NOTE(review): h has the layer's *input* size
                                // (layer_sizes[l]) while the contribution is sized to
                                // the layer's *output* (layer_sizes[l+1]); the add
                                // only happens when these coincide (e.g. equal-width
                                // hidden layers), otherwise it is silently dropped —
                                // confirm this gating is intended.
                                if h.len() == lateral_contribution.len() {
                                    for i in 0..h.len() {
                                        h[i] = h[i] + lateral_contribution[i];
                                    }
                                }
                            }
                        }
                    }
                }

                let (_pre, post) = column.forward_layer(&h, l)?;
                col_activations.push(post.clone());
                h = post;
            }

            all_activations.push(col_activations);
        }

        // Final output is the last activation of the target column
        let target_activations = &all_activations[task_id];
        let output = target_activations
            .last()
            .ok_or_else(|| OptimError::NetworkError("column has no layers".to_string()))?
            .clone();

        // Collect intermediate activations (flattened across all columns)
        let mut intermediates: Vec<Array1<T>> = Vec::new();
        for col_acts in &all_activations {
            for act in col_acts {
                intermediates.push(act.clone());
            }
        }

        Ok((output, intermediates))
    }

    /// Freeze a specific column (mark as non-trainable).
    pub fn freeze_column(&mut self, column_id: usize) -> Result<()> {
        if column_id >= self.columns.len() {
            return Err(OptimError::InvalidState(format!(
                "column_id {} out of range; only {} columns exist",
                column_id,
                self.columns.len()
            )));
        }
        self.columns[column_id].frozen = true;
        Ok(())
    }

    /// Return the number of columns (one per task).
    pub fn num_columns(&self) -> usize {
        self.columns.len()
    }

    /// Get a reference to a specific column's parameters.
    pub fn get_column_parameters(&self, column_id: usize) -> Result<&NetworkColumn<T>> {
        self.columns.get(column_id).ok_or_else(|| {
            OptimError::InvalidState(format!(
                "column_id {} out of range; only {} columns exist",
                column_id,
                self.columns.len()
            ))
        })
    }
}
684
685// ---------------------------------------------------------------------------
686// Tests
687// ---------------------------------------------------------------------------
688
689#[cfg(test)]
690mod tests {
691    use super::*;
692    use scirs2_core::ndarray::Array1;
693
694    type F = f64;
695
696    // Helper: create simple parameters map
697    fn make_params(names: &[&str], size: usize, value: f64) -> HashMap<String, Array1<F>> {
698        let mut map = HashMap::new();
699        for name in names {
700            map.insert(name.to_string(), Array1::from_elem(size, value));
701        }
702        map
703    }
704
705    // -----------------------------------------------------------------------
706    // EWC Tests
707    // -----------------------------------------------------------------------
708
709    #[test]
710    fn test_ewc_creation_and_configuration() {
711        let ewc: ElasticWeightConsolidation<F> = ElasticWeightConsolidation::new(1000.0)
712            .with_num_samples(50)
713            .with_online(true)
714            .with_gamma(0.9);
715
716        assert_eq!(ewc.num_tasks(), 0);
717        assert!(ewc.online);
718        assert!((ewc.gamma - 0.9).abs() < 1e-12);
719        assert!((ewc.lambda - 1000.0).abs() < 1e-12);
720        assert_eq!(ewc.num_samples_fisher, 50);
721    }
722
723    #[test]
724    fn test_fisher_diagonal_computation() {
725        let mut ewc: ElasticWeightConsolidation<F> =
726            ElasticWeightConsolidation::new(1.0).with_num_samples(10);
727
728        let params = make_params(&["w1", "w2"], 3, 1.0);
729
730        // Gradients function: returns constant gradients (simulating quadratic loss)
731        // For a quadratic loss f(x) = 0.5 * x^2, grad = x = [1, 1, 1]
732        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
733            let mut g = HashMap::new();
734            g.insert("w1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
735            g.insert("w2".to_string(), Array1::from_vec(vec![0.5, 0.5, 0.5]));
736            Ok(g)
737        };
738
739        ewc.compute_fisher_diagonal(&params, grad_fn)
740            .expect("compute_fisher_diagonal should succeed");
741
742        // Fisher = E[g^2] = g^2 since gradients are constant
743        let f1 = ewc
744            .fisher_diagonal
745            .get("w1")
746            .expect("w1 Fisher should exist");
747        assert!((f1[0] - 1.0).abs() < 1e-12);
748        assert!((f1[1] - 4.0).abs() < 1e-12);
749        assert!((f1[2] - 9.0).abs() < 1e-12);
750
751        let f2 = ewc
752            .fisher_diagonal
753            .get("w2")
754            .expect("w2 Fisher should exist");
755        assert!((f2[0] - 0.25).abs() < 1e-12);
756    }
757
758    #[test]
759    fn test_ewc_penalty_computation() {
760        let mut ewc: ElasticWeightConsolidation<F> =
761            ElasticWeightConsolidation::new(2.0).with_num_samples(5);
762
763        let anchor = make_params(&["w"], 2, 1.0);
764
765        // Constant Fisher = [1.0, 1.0]
766        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
767            let mut g = HashMap::new();
768            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));
769            Ok(g)
770        };
771
772        ewc.compute_fisher_diagonal(&anchor, grad_fn)
773            .expect("compute Fisher should succeed");
774        ewc.consolidate(&anchor)
775            .expect("consolidate should succeed");
776
777        // Current parameters differ from anchor
778        let mut current = HashMap::new();
779        current.insert("w".to_string(), Array1::from_vec(vec![2.0, 3.0]));
780
781        // Penalty = (lambda/2) * sum F_i * (theta_i - theta*_i)^2
782        // = (2/2) * (1*(2-1)^2 + 1*(3-1)^2) = 1 * (1 + 4) = 5.0
783        let penalty = ewc
784            .ewc_penalty(&current)
785            .expect("ewc_penalty should succeed");
786        assert!(
787            (penalty - 5.0).abs() < 1e-12,
788            "expected penalty 5.0, got {}",
789            penalty
790        );
791    }
792
793    #[test]
794    fn test_ewc_gradient_computation() {
795        let mut ewc: ElasticWeightConsolidation<F> =
796            ElasticWeightConsolidation::new(2.0).with_num_samples(5);
797
798        let anchor = make_params(&["w"], 2, 1.0);
799
800        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
801            let mut g = HashMap::new();
802            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));
803            Ok(g)
804        };
805
806        ewc.compute_fisher_diagonal(&anchor, grad_fn)
807            .expect("compute Fisher should succeed");
808        ewc.consolidate(&anchor)
809            .expect("consolidate should succeed");
810
811        let mut current = HashMap::new();
812        current.insert("w".to_string(), Array1::from_vec(vec![2.0, 3.0]));
813
814        // Gradient = lambda * F_i * (theta_i - theta*_i)
815        // = 2 * 1 * (2-1, 3-1) = (2, 4)
816        let grads = ewc
817            .ewc_gradient(&current)
818            .expect("ewc_gradient should succeed");
819        let gw = grads.get("w").expect("gradient for w should exist");
820        assert!((gw[0] - 2.0).abs() < 1e-12, "expected 2.0, got {}", gw[0]);
821        assert!((gw[1] - 4.0).abs() < 1e-12, "expected 4.0, got {}", gw[1]);
822    }
823
824    #[test]
825    fn test_consolidation_workflow() {
826        let mut ewc: ElasticWeightConsolidation<F> =
827            ElasticWeightConsolidation::new(1.0).with_num_samples(5);
828
829        let params_task1 = make_params(&["w"], 2, 1.0);
830
831        let grad_fn = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
832            let mut g = HashMap::new();
833            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));
834            Ok(g)
835        };
836
837        ewc.compute_fisher_diagonal(&params_task1, grad_fn)
838            .expect("Fisher computation should succeed");
839        ewc.consolidate(&params_task1)
840            .expect("consolidation should succeed");
841        assert_eq!(ewc.num_tasks(), 1);
842
843        // Second task
844        let params_task2 = make_params(&["w"], 2, 2.0);
845        ewc.compute_fisher_diagonal(&params_task2, grad_fn)
846            .expect("Fisher computation for task 2 should succeed");
847        ewc.consolidate(&params_task2)
848            .expect("consolidation for task 2 should succeed");
849        assert_eq!(ewc.num_tasks(), 2);
850
851        // Penalty at anchor should be zero (for the latest anchor)
852        let penalty = ewc
853            .ewc_penalty(&params_task2)
854            .expect("penalty should succeed");
855        // The latest anchor is params_task2, so the second task's Fisher penalty is 0.
856        // But the first task's Fisher penalty is non-zero because params_task2 != params_task1.
857        // penalty = 0.5 * sum_task1 F_i*(2-1)^2 = 0.5 * (1*1 + 4*1) = 2.5
858        assert!(
859            penalty > 0.0,
860            "penalty should be positive due to task1 Fisher, got {}",
861            penalty
862        );
863    }
864
865    #[test]
866    fn test_online_ewc_multiple_tasks() {
867        let mut ewc: ElasticWeightConsolidation<F> = ElasticWeightConsolidation::new(1.0)
868            .with_num_samples(5)
869            .with_online(true)
870            .with_gamma(0.5);
871
872        // Task 1
873        let params1 = make_params(&["w"], 2, 0.0);
874        let grad_fn1 = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
875            let mut g = HashMap::new();
876            g.insert("w".to_string(), Array1::from_vec(vec![1.0, 1.0]));
877            Ok(g)
878        };
879        ewc.compute_fisher_diagonal(&params1, grad_fn1)
880            .expect("Fisher 1 should succeed");
881        ewc.consolidate(&params1)
882            .expect("consolidate 1 should succeed");
883        assert_eq!(ewc.num_tasks(), 1);
884
885        // Task 1 Fisher = [1.0, 1.0] (since gradients are constant = 1.0)
886        let fisher1 = ewc.task_fisher_diagonals[0]
887            .get("w")
888            .expect("Fisher should exist");
889        assert!((fisher1[0] - 1.0).abs() < 1e-12);
890
891        // Task 2
892        let params2 = make_params(&["w"], 2, 1.0);
893        let grad_fn2 = |_p: &HashMap<String, Array1<F>>| -> Result<HashMap<String, Array1<F>>> {
894            let mut g = HashMap::new();
895            g.insert("w".to_string(), Array1::from_vec(vec![2.0, 2.0]));
896            Ok(g)
897        };
898        ewc.compute_fisher_diagonal(&params2, grad_fn2)
899            .expect("Fisher 2 should succeed");
900        ewc.consolidate(&params2)
901            .expect("consolidate 2 should succeed");
902        assert_eq!(ewc.num_tasks(), 2);
903
904        // Online Fisher for task 2 = gamma * old + new = 0.5 * [1,1] + [4,4] = [4.5, 4.5]
905        let fisher2 = ewc.task_fisher_diagonals[1]
906            .get("w")
907            .expect("Fisher should exist");
908        assert!(
909            (fisher2[0] - 4.5).abs() < 1e-12,
910            "expected 4.5, got {}",
911            fisher2[0]
912        );
913
914        // Penalty at params different from anchor (params2)
915        let mut test_params = HashMap::new();
916        test_params.insert("w".to_string(), Array1::from_vec(vec![2.0, 2.0]));
917
918        let penalty = ewc
919            .ewc_penalty(&test_params)
920            .expect("penalty should succeed");
921        // penalty = 0.5 * sum 4.5 * (2-1)^2 = 0.5 * (4.5 + 4.5) = 4.5
922        assert!(
923            (penalty - 4.5).abs() < 1e-12,
924            "expected 4.5, got {}",
925            penalty
926        );
927    }
928
929    // -----------------------------------------------------------------------
930    // Progressive Networks Tests
931    // -----------------------------------------------------------------------
932
933    #[test]
934    fn test_progressive_network_creation_and_columns() {
935        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![16, 8]);
936
937        assert_eq!(pn.num_columns(), 0);
938
939        let col0 = pn
940            .add_task_column(4, 2)
941            .expect("add column 0 should succeed");
942        assert_eq!(col0, 0);
943        assert_eq!(pn.num_columns(), 1);
944
945        // Check column structure: 4 -> 16 -> 8 -> 2 = 3 layers
946        let params0 = pn
947            .get_column_parameters(0)
948            .expect("get column 0 should succeed");
949        assert_eq!(params0.num_layers(), 3);
950        assert!(!params0.frozen);
951
952        // Add second column - first should now be frozen
953        let col1 = pn
954            .add_task_column(4, 2)
955            .expect("add column 1 should succeed");
956        assert_eq!(col1, 1);
957        assert_eq!(pn.num_columns(), 2);
958
959        let params0_after = pn
960            .get_column_parameters(0)
961            .expect("get column 0 should succeed");
962        assert!(params0_after.frozen, "column 0 should be frozen");
963
964        let params1 = pn
965            .get_column_parameters(1)
966            .expect("get column 1 should succeed");
967        assert!(!params1.frozen, "column 1 should not be frozen");
968    }
969
970    #[test]
971    fn test_progressive_network_forward_single_column() {
972        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8]);
973
974        pn.add_task_column(4, 2).expect("add column should succeed");
975
976        let input = Array1::from_vec(vec![1.0, 0.5, -0.3, 0.8]);
977        let output = pn.forward(&input, 0).expect("forward should succeed");
978
979        // Output should have 2 elements (output_size = 2)
980        assert_eq!(output.len(), 2, "output should have 2 elements");
981
982        // Output should be finite
983        for val in output.iter() {
984            assert!(val.is_finite(), "output should be finite, got {}", val);
985        }
986    }
987
988    #[test]
989    fn test_progressive_network_forward_with_laterals() {
990        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8]);
991
992        pn.add_task_column(4, 2)
993            .expect("add column 0 should succeed");
994        pn.add_task_column(4, 2)
995            .expect("add column 1 should succeed");
996
997        let input = Array1::from_vec(vec![1.0, 0.5, -0.3, 0.8]);
998
999        // Forward through task 0 (no laterals)
1000        let output0 = pn
1001            .forward(&input, 0)
1002            .expect("forward task 0 should succeed");
1003        assert_eq!(output0.len(), 2);
1004
1005        // Forward through task 1 (with lateral connections from task 0)
1006        let (output1, intermediates) = pn
1007            .forward_with_laterals(&input, 1)
1008            .expect("forward_with_laterals task 1 should succeed");
1009        assert_eq!(output1.len(), 2);
1010
1011        // Should have intermediate activations from both columns
1012        assert!(
1013            !intermediates.is_empty(),
1014            "should have intermediate activations"
1015        );
1016
1017        // The outputs for task 0 and task 1 should differ due to
1018        // different weights and lateral connections
1019        let diff: f64 = output0
1020            .iter()
1021            .zip(output1.iter())
1022            .map(|(a, b)| (a - b).abs())
1023            .sum();
1024        // They are almost certainly different given different random inits
1025        // but we don't require it -- just check both are finite
1026        for val in output1.iter() {
1027            assert!(val.is_finite(), "output should be finite, got {}", val);
1028        }
1029        let _ = diff; // suppress unused warning
1030    }
1031
1032    #[test]
1033    fn test_progressive_network_freeze_column() {
1034        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8]);
1035
1036        pn.add_task_column(4, 2).expect("add column should succeed");
1037
1038        assert!(
1039            !pn.get_column_parameters(0).expect("get params").frozen,
1040            "column should not be frozen initially"
1041        );
1042
1043        pn.freeze_column(0).expect("freeze should succeed");
1044
1045        assert!(
1046            pn.get_column_parameters(0).expect("get params").frozen,
1047            "column should be frozen after freeze_column"
1048        );
1049
1050        // Freeze out-of-range column should error
1051        let err = pn.freeze_column(99);
1052        assert!(err.is_err(), "freezing out-of-range column should fail");
1053    }
1054
1055    #[test]
1056    fn test_progressive_network_multiple_tasks_forward() {
1057        let mut pn: ProgressiveNetworks<F> = ProgressiveNetworks::new(vec![8, 4]);
1058
1059        // Add 3 task columns
1060        for _ in 0..3 {
1061            pn.add_task_column(4, 2).expect("add column should succeed");
1062        }
1063        assert_eq!(pn.num_columns(), 3);
1064
1065        let input = Array1::from_vec(vec![0.5, -0.5, 1.0, -1.0]);
1066
1067        // Forward through each task should work
1068        for task_id in 0..3 {
1069            let output = pn
1070                .forward(&input, task_id)
1071                .unwrap_or_else(|_| panic!("forward task {} should succeed", task_id));
1072            assert_eq!(
1073                output.len(),
1074                2,
1075                "output for task {} should have 2 elements",
1076                task_id
1077            );
1078            for val in output.iter() {
1079                assert!(
1080                    val.is_finite(),
1081                    "output for task {} should be finite",
1082                    task_id
1083                );
1084            }
1085        }
1086
1087        // Forward with invalid task_id should fail
1088        let err = pn.forward(&input, 10);
1089        assert!(err.is_err(), "forward with invalid task_id should fail");
1090    }
1091}