sklears_multioutput/
transfer_learning.rs

//! Transfer Learning for Multi-Task Learning
//!
//! This module provides transfer learning algorithms for multi-task scenarios,
//! including domain adaptation, progressive transfer, and continual learning methods.

// Math-style identifiers (X, y, source_X, ...) are used throughout this module.
#![allow(non_snake_case)]

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2, Axis};
use scirs2_core::random::thread_rng;
use scirs2_core::random::RandNormal;
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Untrained},
    types::Float,
};

/// Cross-Task Transfer Learning
///
/// Implements cross-task transfer learning for multi-task scenarios where
/// knowledge from source tasks is transferred to target tasks.
///
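/// Knowledge is transferred through a learned task-to-task transfer matrix
/// that maps source-task predictions onto the target tasks. Predictions can
/// be made from the target weights directly (`predict`) or routed through
/// the source model and the transfer matrix (`predict_from_source`).
///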
/// # Examples
///
/// ```
/// use sklears_multioutput::transfer_learning::CrossTaskTransferLearning;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let source_data = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let source_labels = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
/// let target_data = array![[1.1, 2.1], [2.1, 3.1]];
/// let target_labels = array![[1.0, 0.0], [0.0, 1.0]];
///
/// let transfer = CrossTaskTransferLearning::new()
///     .transfer_strength(0.5)
///     .learning_rate(0.01);
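///
/// // A minimal usage sketch: fit on the source/target pairs, then predict
/// // for the target tasks (training runs plain gradient descent internally).
/// let trained = transfer
///     .fit(
///         &source_data.view(),
///         &source_labels.view(),
///         &target_data.view(),
///         &target_labels.view(),
///     )
///     .unwrap();
/// let predictions = trained.predict(&target_data.view()).unwrap();
/// assert_eq!(predictions.dim(), (2, 2));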
/// ```
#[derive(Debug, Clone)]
pub struct CrossTaskTransferLearning<S = Untrained> {
    state: S,
    transfer_strength: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct CrossTaskTransferLearningTrained {
    source_weights: Array2<Float>,
    target_weights: Array2<Float>,
    transfer_matrix: Array2<Float>,
    n_features: usize,
    n_source_tasks: usize,
    n_target_tasks: usize,
}

impl CrossTaskTransferLearning<Untrained> {
    /// Create a new CrossTaskTransferLearning instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            transfer_strength: 0.5,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    /// Set the transfer strength (higher values = more transfer)
    pub fn transfer_strength(mut self, strength: Float) -> Self {
        self.transfer_strength = strength;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random state for reproducibility
    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

    /// Fit the transfer learning model
    pub fn fit(
        &self,
        source_X: &ArrayView2<Float>,
        source_y: &ArrayView2<Float>,
        target_X: &ArrayView2<Float>,
        target_y: &ArrayView2<Float>,
    ) -> SklResult<CrossTaskTransferLearning<CrossTaskTransferLearningTrained>> {
        let n_source_samples = source_X.nrows();
        let n_target_samples = target_X.nrows();
        let n_features = source_X.ncols();
        let n_source_tasks = source_y.ncols();
        let n_target_tasks = target_y.ncols();

        if source_X.ncols() != target_X.ncols() {
            return Err(SklearsError::InvalidInput(
                "Source and target data must have the same number of features".to_string(),
            ));
        }

        if n_source_samples != source_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of source samples must match source labels".to_string(),
            ));
        }

        if n_target_samples != target_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of target samples must match target labels".to_string(),
            ));
        }

        // NOTE: `random_state` is accepted for API compatibility but is not
        // yet wired into the sampler; initialization uses `thread_rng()`.
        let mut rng = thread_rng();

        // Initialize weights
        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();

        let mut source_weights = Array2::<Float>::zeros((n_features, n_source_tasks));
        for i in 0..n_features {
            for j in 0..n_source_tasks {
                source_weights[[i, j]] = rng.sample(normal_dist);
            }
        }

        let mut target_weights = Array2::<Float>::zeros((n_features, n_target_tasks));
        for i in 0..n_features {
            for j in 0..n_target_tasks {
                target_weights[[i, j]] = rng.sample(normal_dist);
            }
        }

        let mut transfer_matrix = Array2::<Float>::zeros((n_source_tasks, n_target_tasks));
        for i in 0..n_source_tasks {
            for j in 0..n_target_tasks {
                transfer_matrix[[i, j]] = rng.sample(normal_dist);
            }
        }

        // Training loop
        for _ in 0..self.max_iter {
            // Update source weights
            let source_pred = source_X.dot(&source_weights);
            let source_error = &source_pred - source_y;
            let source_grad = source_X.t().dot(&source_error) / n_source_samples as Float;
            source_weights -= &(source_grad * self.learning_rate);

            // Update target weights: the usual least-squares step plus a
            // heuristic coupling term that pulls the target model toward the
            // predictions transferred from the source tasks.
            let target_pred = target_X.dot(&target_weights);
            let transferred_pred = target_X.dot(&source_weights).dot(&transfer_matrix);
            let target_error = &target_pred - target_y;
            let transfer_error = &transferred_pred - target_y;

            let target_grad = target_X.t().dot(&target_error) / n_target_samples as Float;
            let transfer_grad = target_X.t().dot(&transfer_error) / n_target_samples as Float;

            target_weights -= &(target_grad * self.learning_rate);
            target_weights -= &(transfer_grad * self.learning_rate * self.transfer_strength);

            // Update the transfer matrix so it better maps source-task
            // predictions onto the target tasks.
            let transfer_matrix_grad =
                target_X.dot(&source_weights).t().dot(&transfer_error) / n_target_samples as Float;
            transfer_matrix -=
                &(transfer_matrix_grad * self.learning_rate * self.transfer_strength);
        }

        Ok(CrossTaskTransferLearning {
            state: CrossTaskTransferLearningTrained {
                source_weights,
                target_weights,
                transfer_matrix,
                n_features,
                n_source_tasks,
                n_target_tasks,
            },
            transfer_strength: self.transfer_strength,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl CrossTaskTransferLearning<CrossTaskTransferLearningTrained> {
    /// Predict using the trained transfer learning model
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let target_pred = X.dot(&self.state.target_weights);
        Ok(target_pred)
    }

    /// Predict using source task knowledge
    pub fn predict_from_source(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let source_pred = X.dot(&self.state.source_weights);
        let transferred_pred = source_pred.dot(&self.state.transfer_matrix);
        Ok(transferred_pred)
    }

    /// Get the transfer matrix
    pub fn transfer_matrix(&self) -> &Array2<Float> {
        &self.state.transfer_matrix
    }

    /// Get the source weights
    pub fn source_weights(&self) -> &Array2<Float> {
        &self.state.source_weights
    }

    /// Get the target weights
    pub fn target_weights(&self) -> &Array2<Float> {
        &self.state.target_weights
    }
}

impl Default for CrossTaskTransferLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for CrossTaskTransferLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for CrossTaskTransferLearning<CrossTaskTransferLearningTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

/// Domain Adaptation for Multi-Task Learning
///
/// Implements domain adaptation techniques to transfer knowledge
/// between different domains in multi-task settings.
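///
/// Training follows a DANN-style adversarial scheme in a simplified linear
/// form: the feature extractor is updated to reduce the source classification
/// error while *increasing* the domain discriminator's error, pushing source
/// and target features toward a domain-invariant representation.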
///
/// # Examples
///
/// ```
/// use sklears_multioutput::transfer_learning::DomainAdaptation;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let source_data = array![[1.0, 2.0], [2.0, 3.0]];
/// let source_labels = array![[1.0], [0.0]];
/// let target_data = array![[1.1, 2.1], [2.1, 3.1]];
/// let target_labels = array![[1.0], [0.0]];
///
/// let adaptation = DomainAdaptation::new()
///     .adaptation_strength(0.3)
///     .learning_rate(0.01);
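///
/// // A minimal usage sketch: fit, then classify target-domain samples.
/// let trained = adaptation
///     .fit(
///         &source_data.view(),
///         &source_labels.view(),
///         &target_data.view(),
///         &target_labels.view(),
///     )
///     .unwrap();
/// let predictions = trained.predict(&target_data.view()).unwrap();
/// assert_eq!(predictions.dim(), (2, 1));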
/// ```
#[derive(Debug, Clone)]
pub struct DomainAdaptation<S = Untrained> {
    state: S,
    adaptation_strength: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct DomainAdaptationTrained {
    feature_extractor: Array2<Float>,
    classifier: Array2<Float>,
    domain_discriminator: Array2<Float>,
    n_features: usize,
    n_tasks: usize,
}

impl DomainAdaptation<Untrained> {
    /// Create a new DomainAdaptation instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            adaptation_strength: 0.3,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    /// Set the adaptation strength
    pub fn adaptation_strength(mut self, strength: Float) -> Self {
        self.adaptation_strength = strength;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random state for reproducibility
    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

    /// Fit the domain adaptation model
    pub fn fit(
        &self,
        source_X: &ArrayView2<Float>,
        source_y: &ArrayView2<Float>,
        target_X: &ArrayView2<Float>,
        target_y: &ArrayView2<Float>,
    ) -> SklResult<DomainAdaptation<DomainAdaptationTrained>> {
        let n_source_samples = source_X.nrows();
        let n_target_samples = target_X.nrows();
        let n_features = source_X.ncols();
        let n_tasks = source_y.ncols();

        if source_X.ncols() != target_X.ncols() {
            return Err(SklearsError::InvalidInput(
                "Source and target data must have the same number of features".to_string(),
            ));
        }

        if n_source_samples != source_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of source samples must match source labels".to_string(),
            ));
        }

        if n_target_samples != target_y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of target samples must match target labels".to_string(),
            ));
        }

        let mut rng = thread_rng();

        // Initialize networks
        let hidden_dim = (n_features + n_tasks) / 2;
        let mut feature_extractor = Array2::<Float>::zeros((n_features, hidden_dim));
        let normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        for i in 0..n_features {
            for j in 0..hidden_dim {
                feature_extractor[[i, j]] = rng.sample(normal_dist);
            }
        }
        let mut classifier = Array2::<Float>::zeros((hidden_dim, n_tasks));
        let classifier_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        for i in 0..hidden_dim {
            for j in 0..n_tasks {
                classifier[[i, j]] = rng.sample(classifier_normal_dist);
            }
        }
        let mut domain_discriminator = Array2::<Float>::zeros((hidden_dim, 1));
        let discriminator_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        for i in 0..hidden_dim {
            domain_discriminator[[i, 0]] = rng.sample(discriminator_normal_dist);
        }

        // Create domain labels (0 for source, 1 for target)
        let mut domain_labels = Array2::<Float>::zeros((n_source_samples + n_target_samples, 1));
        for i in n_source_samples..(n_source_samples + n_target_samples) {
            domain_labels[(i, 0)] = 1.0;
        }

        // Combine data
        let mut combined_X =
            Array2::<Float>::zeros((n_source_samples + n_target_samples, n_features));
        combined_X
            .slice_mut(s![..n_source_samples, ..])
            .assign(source_X);
        combined_X
            .slice_mut(s![n_source_samples.., ..])
            .assign(target_X);

        // Training loop
        for _ in 0..self.max_iter {
            // Extract features with the current (linear) feature extractor
            let features = combined_X.dot(&feature_extractor);
            let source_features = features.slice(s![..n_source_samples, ..]);

            // Train classifier on source domain
            let source_pred = source_features.dot(&classifier);
            let classification_error = &source_pred - source_y;
            let classifier_grad =
                source_features.t().dot(&classification_error) / n_source_samples as Float;
            classifier -= &(&classifier_grad * self.learning_rate);

            // Train domain discriminator (distinguish source from target)
            let domain_pred = features.dot(&domain_discriminator);
            let domain_error = &domain_pred - &domain_labels;
            let discriminator_grad =
                features.t().dot(&domain_error) / (n_source_samples + n_target_samples) as Float;
            domain_discriminator -= &(&discriminator_grad * self.learning_rate);

            // Update the feature extractor via the chain rule: for
            // features = X . W_f, d(loss)/d(W_f) = X^T (error . W^T).
            let feat_class_grad = source_X.t().dot(&classification_error.dot(&classifier.t()))
                / n_source_samples as Float;
            let feat_domain_grad = combined_X
                .t()
                .dot(&domain_error.dot(&domain_discriminator.t()))
                / (n_source_samples + n_target_samples) as Float;

            // Adversarial step: descend on the classification loss but ascend
            // on the domain loss so the features become domain-invariant.
            feature_extractor -= &(feat_class_grad * self.learning_rate);
            feature_extractor +=
                &(feat_domain_grad * self.learning_rate * self.adaptation_strength);
        }

        Ok(DomainAdaptation {
            state: DomainAdaptationTrained {
                feature_extractor,
                classifier,
                domain_discriminator,
                n_features,
                n_tasks,
            },
            adaptation_strength: self.adaptation_strength,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl DomainAdaptation<DomainAdaptationTrained> {
    /// Predict using the trained domain adaptation model
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let features = X.dot(&self.state.feature_extractor);
        let predictions = features.dot(&self.state.classifier);
        Ok(predictions)
    }

    /// Extract domain-invariant features
    pub fn extract_features(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let features = X.dot(&self.state.feature_extractor);
        Ok(features)
    }

    /// Predict domain labels (0 for source-like, 1 for target-like)
    pub fn predict_domain(&self, X: &ArrayView2<Float>) -> SklResult<Array1<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let features = X.dot(&self.state.feature_extractor);
        let domain_pred = features.dot(&self.state.domain_discriminator);
        Ok(domain_pred.column(0).to_owned())
    }
}

impl Default for DomainAdaptation<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for DomainAdaptation<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for DomainAdaptation<DomainAdaptationTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

/// Progressive Transfer Learning
///
/// Implements progressive transfer learning where tasks are learned
/// sequentially, with knowledge from earlier tasks helping later ones.
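///
/// All tasks share a learned feature transformation. Tasks are visited in a
/// caller-supplied (or default) order; each task trains its own output
/// weights, and from the second task onward the shared transformation is
/// refined as well, so earlier tasks shape the representation used by later ones.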
///
/// # Examples
///
/// ```
/// use sklears_multioutput::transfer_learning::ProgressiveTransferLearning;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let transfer = ProgressiveTransferLearning::new()
///     .transfer_strength(0.4)
///     .learning_rate(0.01)
///     .max_iter(500);
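///
/// // A minimal usage sketch: three tasks learned in the default order 0, 1, 2.
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let y = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]];
/// let trained = transfer.fit(&X.view(), &y.view(), None).unwrap();
/// let predictions = trained.predict(&X.view()).unwrap();
/// assert_eq!(predictions.dim(), (3, 3));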
/// ```
#[derive(Debug, Clone)]
pub struct ProgressiveTransferLearning<S = Untrained> {
    state: S,
    transfer_strength: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct ProgressiveTransferLearningTrained {
    task_weights: Vec<Array2<Float>>,
    shared_weights: Array2<Float>,
    task_order: Vec<usize>,
    n_features: usize,
    n_tasks: usize,
}

impl ProgressiveTransferLearning<Untrained> {
    /// Create a new ProgressiveTransferLearning instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            transfer_strength: 0.4,
            learning_rate: 0.01,
            max_iter: 500,
            random_state: None,
        }
    }

    /// Set the transfer strength
    pub fn transfer_strength(mut self, strength: Float) -> Self {
        self.transfer_strength = strength;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random state for reproducibility
    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

    /// Fit the progressive transfer learning model
    pub fn fit(
        &self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<Float>,
        task_order: Option<Vec<usize>>,
    ) -> SklResult<ProgressiveTransferLearning<ProgressiveTransferLearningTrained>> {
        let n_samples = X.nrows();
        let n_features = X.ncols();
        let n_tasks = y.ncols();

        if n_samples != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples must match number of labels".to_string(),
            ));
        }

        let mut rng = thread_rng();

        // Determine task order
        let task_order = task_order.unwrap_or_else(|| (0..n_tasks).collect());

        // Initialize shared weights
        let mut shared_weights = Array2::<Float>::zeros((n_features, n_features));
        let shared_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        for i in 0..n_features {
            for j in 0..n_features {
                shared_weights[[i, j]] = rng.sample(shared_normal_dist);
            }
        }

        let mut task_weights = Vec::with_capacity(n_tasks);

        // Train tasks progressively
        for &task_idx in &task_order {
            let task_y = y.column(task_idx);

            // Initialize task-specific weights
            let mut task_weight = Array2::<Float>::zeros((n_features, 1));
            let task_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
            for i in 0..n_features {
                task_weight[[i, 0]] = rng.sample(task_normal_dist);
            }

            // Train this task
            for _ in 0..self.max_iter {
                // Compute shared features
                let shared_features = X.dot(&shared_weights);

                // Compute task prediction
                let task_pred = shared_features.dot(&task_weight);
                let task_error = &task_pred.column(0) - &task_y;

                // Update task weights
                let task_error_2d = task_error.insert_axis(Axis(1));
                let task_grad = shared_features.t().dot(&task_error_2d) / n_samples as Float;
                task_weight -= &(&task_grad * self.learning_rate);

                // Refine the shared weights once at least one task has been
                // learned; this is the transfer mechanism, letting later tasks
                // reshape the representation built by earlier ones.
                if !task_weights.is_empty() {
                    let shared_grad =
                        X.t().dot(&task_error_2d.dot(&task_weight.t())) / n_samples as Float;
                    shared_weights -= &(shared_grad * self.learning_rate * self.transfer_strength);
                }
            }

            task_weights.push(task_weight);
        }

        Ok(ProgressiveTransferLearning {
            state: ProgressiveTransferLearningTrained {
                task_weights,
                shared_weights,
                task_order,
                n_features,
                n_tasks,
            },
            transfer_strength: self.transfer_strength,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl ProgressiveTransferLearning<ProgressiveTransferLearningTrained> {
    /// Predict using the trained progressive transfer learning model
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        let n_samples = X.nrows();
        let shared_features = X.dot(&self.state.shared_weights);
        let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_tasks));

        for (i, &task_idx) in self.state.task_order.iter().enumerate() {
            let task_pred = shared_features.dot(&self.state.task_weights[i]);
            predictions
                .column_mut(task_idx)
                .assign(&task_pred.column(0));
        }

        Ok(predictions)
    }

    /// Get the shared weights
    pub fn shared_weights(&self) -> &Array2<Float> {
        &self.state.shared_weights
    }

    /// Get the task-specific weights
    pub fn task_weights(&self) -> &Vec<Array2<Float>> {
        &self.state.task_weights
    }

    /// Get the task order
    pub fn task_order(&self) -> &Vec<usize> {
        &self.state.task_order
    }
}

impl Default for ProgressiveTransferLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ProgressiveTransferLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for ProgressiveTransferLearning<ProgressiveTransferLearningTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

/// Continual Learning for Multi-Task Learning
///
/// Implements continual learning where new tasks are learned sequentially
/// without forgetting previously learned tasks using elastic weight consolidation.
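///
/// For each new task, the training loss is augmented with a quadratic penalty
/// anchored at the weights learned so far (a sketch of the EWC objective):
///
/// ```text
/// L(w) = L_task(w) + (lambda / 2) * sum_i F_i * (w_i - w*_i)^2
/// ```
///
/// where `F` is a diagonal Fisher-information estimate, `w*` are the weights
/// before the new task, and `lambda` is the importance weight.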
///
/// # Examples
///
/// ```
/// use sklears_multioutput::transfer_learning::ContinualLearning;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
///
/// let continual = ContinualLearning::new()
///     .importance_weight(1000.0)
///     .learning_rate(0.01);
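///
/// // A minimal usage sketch: each slice entry holds one task's data.
/// let tasks_X = vec![X.view()];
/// let tasks_y = vec![y.view()];
/// let trained = continual.fit(&tasks_X, &tasks_y).unwrap();
/// let predictions = trained.predict(&X.view()).unwrap();
/// assert_eq!(predictions.dim(), (3, 2));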
/// ```
#[derive(Debug, Clone)]
pub struct ContinualLearning<S = Untrained> {
    state: S,
    importance_weight: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct ContinualLearningTrained {
    task_weights: Vec<Array2<Float>>,
    fisher_information: Array2<Float>,
    optimal_weights: Array2<Float>,
    n_features: usize,
    n_tasks: usize,
}

impl Default for ContinualLearning<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl ContinualLearning<Untrained> {
    /// Create a new ContinualLearning instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            importance_weight: 1000.0,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    /// Set the importance weight for preventing forgetting
    pub fn importance_weight(mut self, weight: Float) -> Self {
        self.importance_weight = weight;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random state for reproducibility
    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

    /// Fit the continual learning model
    pub fn fit(
        &self,
        tasks_X: &[ArrayView2<Float>],
        tasks_y: &[ArrayView2<Float>],
    ) -> SklResult<ContinualLearning<ContinualLearningTrained>> {
        if tasks_X.len() != tasks_y.len() {
            return Err(SklearsError::InvalidInput(
                "Number of X and y task arrays must match".to_string(),
            ));
        }

        if tasks_X.is_empty() {
            return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
        }

        let n_features = tasks_X[0].ncols();
        let n_tasks = tasks_y[0].ncols();

        // Initialize with random weights
        let mut rng = thread_rng();

        let mut weights = Array2::<Float>::zeros((n_features, n_tasks));
        let weights_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        for i in 0..n_features {
            for j in 0..n_tasks {
                weights[[i, j]] = rng.sample(weights_normal_dist);
            }
        }
        let mut fisher_information = Array2::<Float>::zeros((n_features, n_tasks));
        let mut task_weights = Vec::new();

        // Learn tasks sequentially
        for (task_idx, (X, y)) in tasks_X.iter().zip(tasks_y.iter()).enumerate() {
            if X.nrows() != y.nrows() {
                return Err(SklearsError::InvalidInput(
                    "Number of samples in X and y must match".to_string(),
                ));
            }

            // Store weights before learning new task
            let old_weights = weights.clone();

            // Learn current task
            for _ in 0..self.max_iter {
                let predictions = X.dot(&weights);
                let errors = &predictions - y;
                let gradient = X.t().dot(&errors) / X.nrows() as Float;

                // Add elastic weight consolidation penalty for previous tasks
                if task_idx > 0 {
                    let penalty =
                        &fisher_information * (&weights - &old_weights) * self.importance_weight;
                    weights = &weights - self.learning_rate * (&gradient + penalty);
                } else {
                    weights = &weights - self.learning_rate * &gradient;
                }
            }

            // Update the diagonal Fisher-information estimate. This is a
            // crude surrogate built from squared residuals rather than squared
            // per-parameter gradients, but it plays the same role in the penalty.
            let predictions = X.dot(&weights);
            let errors = &predictions - y;
            let grad_squared = X.t().dot(&errors.mapv(|x| x * x)) / X.nrows() as Float;
            fisher_information = &fisher_information + grad_squared;

            task_weights.push(weights.clone());
        }

        Ok(ContinualLearning {
            state: ContinualLearningTrained {
                task_weights,
                fisher_information,
                optimal_weights: weights,
                n_features,
                n_tasks,
            },
            importance_weight: self.importance_weight,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl ContinualLearning<ContinualLearningTrained> {
    /// Predict using the continual learning model
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        Ok(X.dot(&self.state.optimal_weights))
    }

    /// Get the task weights
    pub fn task_weights(&self) -> &[Array2<Float>] {
        &self.state.task_weights
    }

    /// Get the Fisher information matrix
    pub fn fisher_information(&self) -> &Array2<Float> {
        &self.state.fisher_information
    }
}

impl Estimator for ContinualLearning<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for ContinualLearning<ContinualLearningTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

/// Knowledge Distillation for Multi-Task Learning
///
/// Implements knowledge distillation where a smaller student network learns
/// from a larger teacher network for improved efficiency and performance.
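///
/// The student is trained on a convex combination of hard- and soft-target
/// residuals (a simplified, linear-model variant of Hinton-style distillation):
///
/// ```text
/// loss = (1 - alpha) * (student(x) - y) + alpha * (student(x) - teacher(x)) / T
/// ```
///
/// where `T` is the temperature and `alpha` balances the two terms.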
///
/// # Examples
///
/// ```
/// use sklears_multioutput::transfer_learning::KnowledgeDistillation;
/// // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
/// use scirs2_core::ndarray::array;
///
/// let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
/// let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
///
/// let distillation = KnowledgeDistillation::new()
///     .temperature(3.0)
///     .alpha(0.7)
///     .learning_rate(0.01);
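///
/// // A minimal usage sketch: the teacher is represented by its predictions.
/// let teacher_predictions = array![[0.9, 0.1], [0.1, 0.9], [0.8, 0.8]];
/// let trained = distillation
///     .fit(&X.view(), &y.view(), &teacher_predictions.view())
///     .unwrap();
/// let predictions = trained.predict(&X.view()).unwrap();
/// assert_eq!(predictions.dim(), (3, 2));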
/// ```
#[derive(Debug, Clone)]
pub struct KnowledgeDistillation<S = Untrained> {
    state: S,
    temperature: Float,
    alpha: Float,
    learning_rate: Float,
    max_iter: usize,
    random_state: Option<u64>,
}

#[derive(Debug, Clone)]
pub struct KnowledgeDistillationTrained {
    student_weights: Array2<Float>,
    teacher_weights: Array2<Float>,
    n_features: usize,
    n_tasks: usize,
}

impl Default for KnowledgeDistillation<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl KnowledgeDistillation<Untrained> {
    /// Create a new KnowledgeDistillation instance
    pub fn new() -> Self {
        Self {
            state: Untrained,
            temperature: 3.0,
            alpha: 0.7,
            learning_rate: 0.01,
            max_iter: 1000,
            random_state: None,
        }
    }

    /// Set the temperature for softening teacher predictions
    pub fn temperature(mut self, temp: Float) -> Self {
        self.temperature = temp;
        self
    }

    /// Set the alpha parameter for balancing hard and soft targets
    pub fn alpha(mut self, alpha: Float) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set the learning rate
    pub fn learning_rate(mut self, lr: Float) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Set the maximum number of iterations
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Set the random state for reproducibility
    pub fn random_state(mut self, seed: Option<u64>) -> Self {
        self.random_state = seed;
        self
    }

    /// Fit the knowledge distillation model
    pub fn fit(
        &self,
        X: &ArrayView2<Float>,
        y: &ArrayView2<Float>,
        teacher_predictions: &ArrayView2<Float>,
    ) -> SklResult<KnowledgeDistillation<KnowledgeDistillationTrained>> {
        if X.nrows() != y.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and y must match".to_string(),
            ));
        }

        if X.nrows() != teacher_predictions.nrows() {
            return Err(SklearsError::InvalidInput(
                "Number of samples in X and teacher predictions must match".to_string(),
            ));
        }

        let n_features = X.ncols();
        let n_tasks = y.ncols();

        // Initialize with random weights
        let mut rng = thread_rng();

        let mut student_weights = Array2::<Float>::zeros((n_features, n_tasks));
        let student_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        for i in 0..n_features {
            for j in 0..n_tasks {
                student_weights[[i, j]] = rng.sample(student_normal_dist);
            }
        }
        // Teacher weights are random placeholders kept only for inspection;
        // the actual teacher is represented by the supplied `teacher_predictions`.
        let mut teacher_weights = Array2::<Float>::zeros((n_features, n_tasks));
        let teacher_normal_dist = RandNormal::new(0.0, 0.1).unwrap();
        for i in 0..n_features {
            for j in 0..n_tasks {
                teacher_weights[[i, j]] = rng.sample(teacher_normal_dist);
            }
        }

        // Train the student network. The teacher's soft targets are
        // temperature-scaled once, outside the loop (they are loop-invariant).
        let soft_targets = teacher_predictions / self.temperature;
        for _ in 0..self.max_iter {
            let student_predictions = X.dot(&student_weights);
            let student_soft = &student_predictions / self.temperature;

            // Combined residual: weighted sum of hard- and soft-target errors
            let hard_loss = &student_predictions - y;
            let soft_loss = &student_soft - &soft_targets;

            let combined_loss = (1.0 - self.alpha) * hard_loss + self.alpha * soft_loss;
            let gradient = X.t().dot(&combined_loss) / X.nrows() as Float;

            student_weights = &student_weights - self.learning_rate * &gradient;
        }

        Ok(KnowledgeDistillation {
            state: KnowledgeDistillationTrained {
                student_weights,
                teacher_weights,
                n_features,
                n_tasks,
            },
            temperature: self.temperature,
            alpha: self.alpha,
            learning_rate: self.learning_rate,
            max_iter: self.max_iter,
            random_state: self.random_state,
        })
    }
}

impl KnowledgeDistillation<KnowledgeDistillationTrained> {
    /// Predict using the student network
    pub fn predict(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
        if X.ncols() != self.state.n_features {
            return Err(SklearsError::InvalidInput(
                "Number of features must match training data".to_string(),
            ));
        }

        Ok(X.dot(&self.state.student_weights))
    }

    /// Get the student weights
    pub fn student_weights(&self) -> &Array2<Float> {
        &self.state.student_weights
    }

    /// Get the teacher weights
    pub fn teacher_weights(&self) -> &Array2<Float> {
        &self.state.teacher_weights
    }
}

impl Estimator for KnowledgeDistillation<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Estimator for KnowledgeDistillation<KnowledgeDistillationTrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;

    #[test]
    fn test_cross_task_transfer_learning_basic() {
        let source_X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 3.0]];
        let source_y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.0, 0.0]];
        let target_X = array![[1.1, 2.1], [2.1, 3.1]];
        let target_y = array![[1.0, 0.0], [0.0, 1.0]];

        let transfer = CrossTaskTransferLearning::new()
            .transfer_strength(0.5)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = transfer
            .fit(
                &source_X.view(),
                &source_y.view(),
                &target_X.view(),
                &target_y.view(),
            )
            .unwrap();

        let predictions = trained.predict(&target_X.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 2));

        let source_predictions = trained.predict_from_source(&target_X.view()).unwrap();
        assert_eq!(source_predictions.dim(), (2, 2));
    }

    #[test]
    fn test_cross_task_transfer_learning_validation() {
        let source_X = array![[1.0, 2.0], [2.0, 3.0]];
        let source_y = array![[1.0, 0.0], [0.0, 1.0]];
        let target_X = array![[1.1, 2.1, 3.1]]; // Different number of features
        let target_y = array![[1.0, 0.0]];

        let transfer = CrossTaskTransferLearning::new();

        // Should fail due to feature mismatch
        assert!(transfer
            .fit(
                &source_X.view(),
                &source_y.view(),
                &target_X.view(),
                &target_y.view()
            )
            .is_err());
    }

    #[test]
    fn test_domain_adaptation_basic() {
        let source_X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 3.0]];
        let source_y = array![[1.0], [0.0], [1.0], [0.0]];
        let target_X = array![[1.1, 2.1], [2.1, 3.1]];
        let target_y = array![[1.0], [0.0]];

        let adaptation = DomainAdaptation::new()
            .adaptation_strength(0.3)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = adaptation
            .fit(
                &source_X.view(),
                &source_y.view(),
                &target_X.view(),
                &target_y.view(),
            )
            .unwrap();

        let predictions = trained.predict(&target_X.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 1));

        let features = trained.extract_features(&target_X.view()).unwrap();
        assert_eq!(features.ncols(), 1); // Hidden dimension

        let domain_pred = trained.predict_domain(&target_X.view()).unwrap();
        assert_eq!(domain_pred.len(), 2);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_progressive_transfer_learning_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [1.0, 3.0]];
        let y = array![
            [1.0, 0.0, 1.0],
            [0.0, 1.0, 0.0],
            [1.0, 1.0, 1.0],
            [0.0, 0.0, 0.0]
        ];

        let transfer = ProgressiveTransferLearning::new()
            .transfer_strength(0.4)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = transfer.fit(&X.view(), &y.view(), None).unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert_eq!(predictions.dim(), (4, 3));

        // Check that we have weights for all tasks
        assert_eq!(trained.task_weights().len(), 3);
        assert_eq!(trained.task_order().len(), 3);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_progressive_transfer_learning_custom_order() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let y = array![[1.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]];

        let transfer = ProgressiveTransferLearning::new().random_state(Some(42));

        let custom_order = vec![2, 0, 1]; // Start with task 2, then 0, then 1
        let trained = transfer
            .fit(&X.view(), &y.view(), Some(custom_order.clone()))
            .unwrap();

        assert_eq!(trained.task_order(), &custom_order);
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_transfer_learning_error_handling() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]; // Mismatched samples

        let transfer = ProgressiveTransferLearning::new();
        assert!(transfer.fit(&X.view(), &y.view(), None).is_err());
    }

    #[test]
    fn test_continual_learning_basic() {
        let X1 = array![[1.0, 2.0], [2.0, 3.0]];
        let y1 = array![[1.0, 0.0], [0.0, 1.0]];
        let X2 = array![[3.0, 1.0], [1.0, 3.0]];
        let y2 = array![[1.0, 1.0], [0.0, 0.0]];

        let tasks_X = vec![X1.view(), X2.view()];
        let tasks_y = vec![y1.view(), y2.view()];

        let continual = ContinualLearning::new()
            .importance_weight(1000.0)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = continual.fit(&tasks_X, &tasks_y).unwrap();

        let predictions = trained.predict(&X1.view()).unwrap();
        assert_eq!(predictions.dim(), (2, 2));

        // Check that we have weights for both tasks
        assert_eq!(trained.task_weights().len(), 2);
        assert_eq!(trained.fisher_information().dim(), (2, 2));
    }

    #[test]
    fn test_continual_learning_error_handling() {
        let X1 = array![[1.0, 2.0], [2.0, 3.0]];
        let y1 = array![[1.0, 0.0], [0.0, 1.0]];
        let X2 = array![[3.0, 1.0]]; // Wrong number of samples
        let y2 = array![[1.0, 1.0], [0.0, 0.0]];

        let tasks_X = vec![X1.view(), X2.view()];
        let tasks_y = vec![y1.view(), y2.view()];

        let continual = ContinualLearning::new();
        assert!(continual.fit(&tasks_X, &tasks_y).is_err());
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_knowledge_distillation_basic() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0]];
        let y = array![[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let teacher_predictions = array![[0.9, 0.1], [0.1, 0.9], [0.8, 0.8]];

        let distillation = KnowledgeDistillation::new()
            .temperature(3.0)
            .alpha(0.7)
            .learning_rate(0.01)
            .max_iter(100)
            .random_state(Some(42));

        let trained = distillation
            .fit(&X.view(), &y.view(), &teacher_predictions.view())
            .unwrap();

        let predictions = trained.predict(&X.view()).unwrap();
        assert_eq!(predictions.dim(), (3, 2));

        // Check that we have student and teacher weights
        assert_eq!(trained.student_weights().dim(), (2, 2));
        assert_eq!(trained.teacher_weights().dim(), (2, 2));
    }

    #[test]
    #[allow(non_snake_case)]
    fn test_knowledge_distillation_error_handling() {
        let X = array![[1.0, 2.0], [2.0, 3.0]];
        let y = array![[1.0, 0.0], [0.0, 1.0]];
        let teacher_predictions = array![[0.9, 0.1], [0.1, 0.9], [0.8, 0.8]]; // Wrong number of samples

        let distillation = KnowledgeDistillation::new();
        assert!(distillation
            .fit(&X.view(), &y.view(), &teacher_predictions.view())
            .is_err());
    }

    #[test]
    fn test_knowledge_distillation_configuration() {
        let distillation = KnowledgeDistillation::new()
            .temperature(5.0)
            .alpha(0.5)
            .learning_rate(0.001)
            .max_iter(2000)
            .random_state(Some(123));

        // Test configuration parameters
        assert_eq!(distillation.temperature, 5.0);
        assert_eq!(distillation.alpha, 0.5);
        assert_eq!(distillation.learning_rate, 0.001);
        assert_eq!(distillation.max_iter, 2000);
        assert_eq!(distillation.random_state, Some(123));
    }
}