1use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, Axis};
7use sklears_core::{
8 error::Result as SklResult,
9 prelude::{Predict, SklearsError},
10 traits::{Estimator, Fit, Untrained},
11 types::{Float, FloatBounds},
12};
13use std::collections::HashMap;
14
15use crate::{PipelinePredictor, PipelineStep};
16
/// A pretrained model wrapper carrying transfer-learning annotations:
/// which layers are frozen vs. trainable, plus free-form metadata.
#[derive(Debug)]
pub struct PretrainedModel {
    /// The underlying predictor whose knowledge is being transferred.
    pub model: Box<dyn PipelinePredictor>,
    /// Names of layers that should stay fixed during adaptation.
    pub frozen_layers: Vec<String>,
    /// Names of layers that may be updated during adaptation.
    pub trainable_layers: Vec<String>,
    /// Arbitrary key/value metadata about the model (source task, version, …).
    pub metadata: HashMap<String, String>,
}
29
30impl PretrainedModel {
31 #[must_use]
33 pub fn new(model: Box<dyn PipelinePredictor>) -> Self {
34 Self {
35 model,
36 frozen_layers: Vec::new(),
37 trainable_layers: Vec::new(),
38 metadata: HashMap::new(),
39 }
40 }
41
42 #[must_use]
44 pub fn with_frozen_layers(mut self, layers: Vec<String>) -> Self {
45 self.frozen_layers = layers;
46 self
47 }
48
49 #[must_use]
51 pub fn with_trainable_layers(mut self, layers: Vec<String>) -> Self {
52 self.trainable_layers = layers;
53 self
54 }
55
56 #[must_use]
58 pub fn with_metadata(mut self, metadata: HashMap<String, String>) -> Self {
59 self.metadata = metadata;
60 self
61 }
62
63 pub fn extract_features(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
65 let features = self.model.predict(x)?;
68 Array2::from_shape_vec(
69 (x.nrows(), features.len() / x.nrows()),
70 features.into_raw_vec(),
71 )
72 .map_err(|e| SklearsError::InvalidData {
73 reason: format!("Feature extraction failed: {e}"),
74 })
75 }
76}
77
/// Strategy used to transfer knowledge from a pretrained model to a new task.
#[derive(Debug, Clone)]
pub enum TransferStrategy {
    /// Use the pretrained model as a frozen feature extractor.
    FeatureExtraction {
        /// When true, train a downstream classifier on the extracted features.
        add_classifier: bool,
    },
    /// Refit the target estimator for a number of epochs.
    FineTuning {
        /// Base learning rate for fine-tuning.
        learning_rate: f64,
        /// Number of refit passes.
        epochs: usize,
    },
    /// Unfreeze layer groups stage by stage, each with its own rate.
    ProgressiveUnfreezing {
        /// One learning rate per unfreezing stage.
        learning_rates: Vec<f64>,
        /// Layer names to unfreeze at each stage (paired with the rates).
        unfreeze_schedule: Vec<Vec<String>>,
    },
    /// Assign a separate learning rate to each named layer.
    LayerWiseAdaptive {
        /// Map from layer name to its learning rate.
        layer_rates: HashMap<String, f64>,
    },
    /// Train a student model against a teacher's softened outputs.
    KnowledgeDistillation {
        /// Softmax temperature applied to teacher logits.
        temperature: f64,
        /// Weight of the distillation (soft-target) loss term.
        distillation_weight: f64,
        /// Weight of the hard-label task loss term.
        task_weight: f64,
    },
}
115
/// Pipeline adapting a pretrained model to a new task with a configurable
/// transfer strategy. The typestate `S` distinguishes an untrained pipeline
/// (`Untrained`) from a trained one (`TransferLearningPipelineTrained`).
#[derive(Debug)]
pub struct TransferLearningPipeline<S = Untrained> {
    // Typestate marker (untrained) or trained-state payload.
    state: S,
    // Source model to transfer from; consumed during `fit`.
    pretrained_model: Option<PretrainedModel>,
    // Estimator trained on the target task; consumed during `fit`.
    target_estimator: Option<Box<dyn PipelinePredictor>>,
    // How knowledge is transferred (fine-tuning by default).
    transfer_strategy: TransferStrategy,
    // Hyperparameters controlling the adaptation loop.
    adaptation_config: AdaptationConfig,
}
125
/// Trained-state payload of a `TransferLearningPipeline`.
#[derive(Debug)]
pub struct TransferLearningPipelineTrained {
    // The estimator produced by the transfer strategy; serves predictions.
    adapted_model: Box<dyn PipelinePredictor>,
    // The pretrained model, retained so features can still be extracted.
    feature_extractor: Option<PretrainedModel>,
    // Strategy that was applied during fitting.
    transfer_strategy: TransferStrategy,
    // Bookkeeping metrics recorded during adaptation.
    adaptation_metrics: HashMap<String, f64>,
    // Number of input features seen at fit time (sklearn-style).
    n_features_in: usize,
    // Optional feature names seen at fit time; currently never populated.
    feature_names_in: Option<Vec<String>>,
}
136
/// Hyperparameters controlling the adaptation/training loop.
#[derive(Debug, Clone)]
pub struct AdaptationConfig {
    /// Maximum number of adaptation steps.
    pub max_steps: usize,
    /// Early-stopping patience (steps without improvement).
    pub patience: usize,
    /// Minimum improvement counted as progress for early stopping.
    pub min_improvement: f64,
    /// Fraction of data held out for validation (0.0–1.0).
    pub validation_split: f64,
    /// Mini-batch size.
    pub batch_size: usize,
    /// Learning-rate schedule applied over steps.
    pub lr_schedule: LearningRateSchedule,
}
153
impl Default for AdaptationConfig {
    /// Defaults: 1000 steps, patience 10, 1e-4 minimum improvement,
    /// 20% validation split, batch size 32, constant 0.001 learning rate.
    fn default() -> Self {
        Self {
            max_steps: 1000,
            patience: 10,
            min_improvement: 1e-4,
            validation_split: 0.2,
            batch_size: 32,
            lr_schedule: LearningRateSchedule::Constant { rate: 0.001 },
        }
    }
}
166
/// Schedule controlling how the learning rate evolves over training steps.
#[derive(Debug, Clone)]
pub enum LearningRateSchedule {
    /// Fixed rate for every step.
    Constant { rate: f64 },
    /// `initial_rate * decay_rate^(step / decay_steps)` (continuous exponent).
    ExponentialDecay {
        initial_rate: f64,
        decay_rate: f64,
        decay_steps: usize,
    },
    /// Rate multiplied by `drop_rate` once per `epochs_drop` steps (staircase).
    StepDecay {
        initial_rate: f64,
        drop_rate: f64,
        epochs_drop: usize,
    },
    /// Cosine oscillation between `min_rate` and `max_rate`, period `cycle_length`.
    CosineAnnealing {
        max_rate: f64,
        min_rate: f64,
        cycle_length: usize,
    },
}

impl LearningRateSchedule {
    /// Returns the learning rate for the given zero-based `step`.
    ///
    /// Degenerate schedules with a zero divisor (`decay_steps`, `epochs_drop`,
    /// or `cycle_length` equal to 0) fall back to the schedule's
    /// initial/maximum rate. Previously `StepDecay` and `CosineAnnealing`
    /// panicked on integer division/remainder by zero, and `ExponentialDecay`
    /// silently produced 0 or infinity via a float division by zero.
    #[must_use]
    pub fn get_rate(&self, step: usize) -> f64 {
        match self {
            LearningRateSchedule::Constant { rate } => *rate,
            LearningRateSchedule::ExponentialDecay {
                initial_rate,
                decay_rate,
                decay_steps,
            } => {
                if *decay_steps == 0 {
                    *initial_rate
                } else {
                    initial_rate * decay_rate.powf(step as f64 / *decay_steps as f64)
                }
            }
            LearningRateSchedule::StepDecay {
                initial_rate,
                drop_rate,
                epochs_drop,
            } => {
                if *epochs_drop == 0 {
                    *initial_rate
                } else {
                    initial_rate * drop_rate.powf((step / epochs_drop) as f64)
                }
            }
            LearningRateSchedule::CosineAnnealing {
                max_rate,
                min_rate,
                cycle_length,
            } => {
                if *cycle_length == 0 {
                    return *max_rate;
                }
                // Position within the current cycle, in [0, 1).
                let cycle_position = (step % cycle_length) as f64 / *cycle_length as f64;
                // Cosine annealing: max at cycle start, min at mid-cycle.
                min_rate
                    + (max_rate - min_rate) * (1.0 + (std::f64::consts::PI * cycle_position).cos())
                        / 2.0
            }
        }
    }
}
221
222impl TransferLearningPipeline<Untrained> {
223 #[must_use]
225 pub fn new(
226 pretrained_model: PretrainedModel,
227 target_estimator: Box<dyn PipelinePredictor>,
228 ) -> Self {
229 Self {
230 state: Untrained,
231 pretrained_model: Some(pretrained_model),
232 target_estimator: Some(target_estimator),
233 transfer_strategy: TransferStrategy::FineTuning {
234 learning_rate: 0.001,
235 epochs: 10,
236 },
237 adaptation_config: AdaptationConfig::default(),
238 }
239 }
240
241 #[must_use]
243 pub fn transfer_strategy(mut self, strategy: TransferStrategy) -> Self {
244 self.transfer_strategy = strategy;
245 self
246 }
247
248 #[must_use]
250 pub fn adaptation_config(mut self, config: AdaptationConfig) -> Self {
251 self.adaptation_config = config;
252 self
253 }
254
255 #[must_use]
257 pub fn feature_extraction(pretrained_model: PretrainedModel) -> Self {
258 let strategy = TransferStrategy::FeatureExtraction {
259 add_classifier: true,
260 };
261 Self {
262 state: Untrained,
263 pretrained_model: Some(pretrained_model),
264 target_estimator: None,
265 transfer_strategy: strategy,
266 adaptation_config: AdaptationConfig::default(),
267 }
268 }
269
270 #[must_use]
272 pub fn fine_tuning(
273 pretrained_model: PretrainedModel,
274 target_estimator: Box<dyn PipelinePredictor>,
275 learning_rate: f64,
276 epochs: usize,
277 ) -> Self {
278 let strategy = TransferStrategy::FineTuning {
279 learning_rate,
280 epochs,
281 };
282 Self {
283 state: Untrained,
284 pretrained_model: Some(pretrained_model),
285 target_estimator: Some(target_estimator),
286 transfer_strategy: strategy,
287 adaptation_config: AdaptationConfig::default(),
288 }
289 }
290
291 #[must_use]
293 pub fn knowledge_distillation(
294 teacher_model: PretrainedModel,
295 student_estimator: Box<dyn PipelinePredictor>,
296 temperature: f64,
297 distillation_weight: f64,
298 task_weight: f64,
299 ) -> Self {
300 let strategy = TransferStrategy::KnowledgeDistillation {
301 temperature,
302 distillation_weight,
303 task_weight,
304 };
305 Self {
306 state: Untrained,
307 pretrained_model: Some(teacher_model),
308 target_estimator: Some(student_estimator),
309 transfer_strategy: strategy,
310 adaptation_config: AdaptationConfig::default(),
311 }
312 }
313}
314
// Minimal `Estimator` implementation: this pipeline carries no separate
// estimator-level configuration, so `Config` is the unit type.
impl Estimator for TransferLearningPipeline<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is valid for any lifetime via 'static promotion.
        &()
    }
}
324
impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
    for TransferLearningPipeline<Untrained>
{
    type Fitted = TransferLearningPipeline<TransferLearningPipelineTrained>;

    /// Applies the configured transfer strategy to `(x, y)` and returns the
    /// trained pipeline.
    ///
    /// # Errors
    /// Fails when no pretrained model was supplied, when the chosen strategy
    /// requires a target estimator or labels that are missing, or when the
    /// underlying estimator's `fit` fails.
    fn fit(
        mut self,
        x: &ArrayView2<'_, Float>,
        y: &Option<&ArrayView1<'_, Float>>,
    ) -> SklResult<Self::Fitted> {
        // Take ownership of the pretrained model; it is retained below as the
        // trained pipeline's feature extractor.
        let pretrained_model = self.pretrained_model.take().ok_or_else(|| {
            SklearsError::InvalidInput("No pretrained model provided".to_string())
        })?;

        // Clone the strategy so the `apply_*` helpers can borrow `self`
        // mutably while we match on the strategy's fields.
        let transfer_strategy = self.transfer_strategy.clone();
        let adapted_model = match &transfer_strategy {
            TransferStrategy::FeatureExtraction { add_classifier } => {
                self.apply_feature_extraction(&pretrained_model, x, y, *add_classifier)?
            }
            TransferStrategy::FineTuning {
                learning_rate,
                epochs,
            } => self.apply_fine_tuning(&pretrained_model, x, y, *learning_rate, *epochs)?,
            TransferStrategy::ProgressiveUnfreezing {
                learning_rates,
                unfreeze_schedule,
            } => self.apply_progressive_unfreezing(
                &pretrained_model,
                x,
                y,
                learning_rates,
                unfreeze_schedule,
            )?,
            TransferStrategy::LayerWiseAdaptive { layer_rates } => {
                self.apply_layer_wise_adaptive(&pretrained_model, x, y, layer_rates)?
            }
            TransferStrategy::KnowledgeDistillation {
                temperature,
                distillation_weight,
                task_weight,
            } => self.apply_knowledge_distillation(
                &pretrained_model,
                x,
                y,
                *temperature,
                *distillation_weight,
                *task_weight,
            )?,
        };

        // Minimal bookkeeping: records only the configured step budget, not
        // the number of steps actually executed.
        let mut adaptation_metrics = HashMap::new();
        adaptation_metrics.insert(
            "adaptation_steps".to_string(),
            self.adaptation_config.max_steps as f64,
        );

        Ok(TransferLearningPipeline {
            state: TransferLearningPipelineTrained {
                adapted_model,
                feature_extractor: Some(pretrained_model),
                transfer_strategy: self.transfer_strategy,
                adaptation_metrics,
                n_features_in: x.ncols(),
                feature_names_in: None,
            },
            // On a trained pipeline the outer fields are inert placeholders;
            // all meaningful state lives in `state`.
            pretrained_model: None,
            target_estimator: None,
            transfer_strategy: TransferStrategy::FeatureExtraction {
                add_classifier: false,
            },
            adaptation_config: AdaptationConfig::default(),
        })
    }
}
399
400impl TransferLearningPipeline<Untrained> {
401 fn apply_feature_extraction(
403 &mut self,
404 pretrained_model: &PretrainedModel,
405 x: &ArrayView2<'_, Float>,
406 y: &Option<&ArrayView1<'_, Float>>,
407 add_classifier: bool,
408 ) -> SklResult<Box<dyn PipelinePredictor>> {
409 if add_classifier {
410 if let Some(mut target_estimator) = self.target_estimator.take() {
411 let features = pretrained_model.extract_features(x)?;
413 let y_ref = y.as_ref().ok_or_else(|| {
414 SklearsError::InvalidInput("No target values provided".to_string())
415 })?;
416 target_estimator.fit(&features.view(), y_ref)?;
417 Ok(target_estimator)
418 } else {
419 Ok(Box::new(FeatureExtractorWrapper::new(pretrained_model)))
421 }
422 } else {
423 Ok(Box::new(FeatureExtractorWrapper::new(pretrained_model)))
425 }
426 }
427
428 fn apply_fine_tuning(
430 &mut self,
431 pretrained_model: &PretrainedModel,
432 x: &ArrayView2<'_, Float>,
433 y: &Option<&ArrayView1<'_, Float>>,
434 learning_rate: f64,
435 epochs: usize,
436 ) -> SklResult<Box<dyn PipelinePredictor>> {
437 if let Some(mut target_estimator) = self.target_estimator.take() {
438 for epoch in 0..epochs {
440 let current_lr = learning_rate * (0.95_f64).powi(epoch as i32); let y_ref = y.as_ref().ok_or_else(|| {
442 SklearsError::InvalidInput("No target values provided".to_string())
443 })?;
444 target_estimator.fit(x, y_ref)?;
445 }
446 Ok(target_estimator)
447 } else {
448 Err(SklearsError::InvalidInput(
450 "Target estimator required for fine-tuning".to_string(),
451 ))
452 }
453 }
454
455 fn apply_progressive_unfreezing(
457 &mut self,
458 pretrained_model: &PretrainedModel,
459 x: &ArrayView2<'_, Float>,
460 y: &Option<&ArrayView1<'_, Float>>,
461 learning_rates: &[f64],
462 unfreeze_schedule: &[Vec<String>],
463 ) -> SklResult<Box<dyn PipelinePredictor>> {
464 if let Some(mut target_estimator) = self.target_estimator.take() {
465 for (step, (lr, layers)) in learning_rates
467 .iter()
468 .zip(unfreeze_schedule.iter())
469 .enumerate()
470 {
471 let y_ref = y.as_ref().ok_or_else(|| {
474 SklearsError::InvalidInput("No target values provided".to_string())
475 })?;
476 target_estimator.fit(x, y_ref)?;
477 }
478 Ok(target_estimator)
479 } else {
480 Err(SklearsError::InvalidInput(
481 "Target estimator required for progressive unfreezing".to_string(),
482 ))
483 }
484 }
485
486 fn apply_layer_wise_adaptive(
488 &mut self,
489 pretrained_model: &PretrainedModel,
490 x: &ArrayView2<'_, Float>,
491 y: &Option<&ArrayView1<'_, Float>>,
492 layer_rates: &HashMap<String, f64>,
493 ) -> SklResult<Box<dyn PipelinePredictor>> {
494 if let Some(mut target_estimator) = self.target_estimator.take() {
495 if let Some(y_ref) = y.as_ref() {
497 target_estimator.fit(x, y_ref)?;
498 } else {
499 return Err(SklearsError::InvalidInput(
500 "Target y is required for fitting".to_string(),
501 ));
502 }
503 Ok(target_estimator)
504 } else {
505 Err(SklearsError::InvalidInput(
506 "Target estimator required for layer-wise adaptive rates".to_string(),
507 ))
508 }
509 }
510
511 fn apply_knowledge_distillation(
513 &mut self,
514 teacher_model: &PretrainedModel,
515 x: &ArrayView2<'_, Float>,
516 y: &Option<&ArrayView1<'_, Float>>,
517 temperature: f64,
518 distillation_weight: f64,
519 task_weight: f64,
520 ) -> SklResult<Box<dyn PipelinePredictor>> {
521 if let Some(mut student_estimator) = self.target_estimator.take() {
522 let teacher_predictions = teacher_model.model.predict(x)?;
524
525 let soft_targets = self.apply_temperature_scaling(&teacher_predictions, temperature);
527
528 if let Some(y_ref) = y.as_ref() {
531 student_estimator.fit(x, y_ref)?;
532 } else {
533 return Err(SklearsError::InvalidInput(
534 "Target y is required for fitting student model".to_string(),
535 ));
536 }
537
538 Ok(student_estimator)
539 } else {
540 Err(SklearsError::InvalidInput(
541 "Student estimator required for knowledge distillation".to_string(),
542 ))
543 }
544 }
545
546 fn apply_temperature_scaling(
548 &self,
549 predictions: &Array1<f64>,
550 temperature: f64,
551 ) -> Array1<f64> {
552 if temperature == 1.0 {
553 return predictions.clone();
554 }
555
556 let scaled_logits = predictions.mapv(|x| x / temperature);
558 let max_logit = scaled_logits.fold(f64::NEG_INFINITY, |a, &b| a.max(b));
559 let exp_logits = scaled_logits.mapv(|x| (x - max_logit).exp());
560 let sum_exp = exp_logits.sum();
561
562 exp_logits.mapv(|x| x / sum_exp)
563 }
564}
565
566impl TransferLearningPipeline<TransferLearningPipelineTrained> {
567 pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
569 self.state.adapted_model.predict(x)
570 }
571
572 #[must_use]
574 pub fn adaptation_metrics(&self) -> &HashMap<String, f64> {
575 &self.state.adaptation_metrics
576 }
577
578 #[must_use]
580 pub fn feature_extractor(&self) -> Option<&PretrainedModel> {
581 self.state.feature_extractor.as_ref()
582 }
583
584 pub fn extract_features(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<f64>> {
586 if let Some(ref extractor) = self.state.feature_extractor {
587 extractor.extract_features(x)
588 } else {
589 Err(SklearsError::InvalidInput(
590 "No feature extractor available".to_string(),
591 ))
592 }
593 }
594
595 pub fn fine_tune(
597 &mut self,
598 x: &ArrayView2<'_, Float>,
599 y: &ArrayView1<'_, Float>,
600 learning_rate: f64,
601 epochs: usize,
602 ) -> SklResult<()> {
603 for _ in 0..epochs {
605 self.state.adapted_model.fit(x, y)?;
606 }
607 Ok(())
608 }
609}
610
611#[derive(Debug)]
613pub struct FeatureExtractorWrapper {
614 extractor: PretrainedModel,
615}
616
617impl FeatureExtractorWrapper {
618 #[must_use]
619 pub fn new(extractor: &PretrainedModel) -> Self {
620 Self {
622 extractor: PretrainedModel {
623 model: Box::new(MockExtractor::new()), frozen_layers: extractor.frozen_layers.clone(),
625 trainable_layers: extractor.trainable_layers.clone(),
626 metadata: extractor.metadata.clone(),
627 },
628 }
629 }
630}
631
impl PipelinePredictor for FeatureExtractorWrapper {
    // Fitting is a no-op: the wrapped extractor is treated as frozen.
    fn fit(&mut self, _x: &ArrayView2<'_, Float>, _y: &ArrayView1<'_, Float>) -> SklResult<()> {
        Ok(())
    }

    // Extracts features and returns the first feature column as the
    // prediction vector, or zeros when extraction yields no columns.
    // NOTE(review): collapsing to column 0 discards every other feature —
    // confirm this is the intended prediction semantics.
    fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
        let features = self.extractor.extract_features(x)?;
        if features.ncols() > 0 {
            Ok(features.column(0).to_owned())
        } else {
            Ok(Array1::zeros(x.nrows()))
        }
    }

    // Clones by re-wrapping the inner extractor.
    fn clone_predictor(&self) -> Box<dyn PipelinePredictor> {
        Box::new(FeatureExtractorWrapper::new(&self.extractor))
    }
}
652
653#[derive(Debug)]
655pub struct MockExtractor {}
656
657impl Default for MockExtractor {
658 fn default() -> Self {
659 Self::new()
660 }
661}
662
663impl MockExtractor {
664 #[must_use]
665 pub fn new() -> Self {
666 Self {}
667 }
668}
669
670impl PipelinePredictor for MockExtractor {
671 fn fit(&mut self, _x: &ArrayView2<'_, Float>, _y: &ArrayView1<'_, Float>) -> SklResult<()> {
672 Ok(())
673 }
674
675 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
676 Ok(Array1::zeros(x.nrows()))
677 }
678
679 fn clone_predictor(&self) -> Box<dyn PipelinePredictor> {
680 Box::new(MockExtractor::new())
681 }
682}
683
684pub mod domain_adaptation {
686 use super::{
687 Array1, Array2, ArrayView1, ArrayView2, Axis, Estimator, Fit, Float, FloatBounds, HashMap,
688 PipelinePredictor, PipelineStep, Predict, SklResult, SklearsError, Untrained,
689 };
690
    /// Strategy used to align a source domain with a target domain.
    #[derive(Debug, Clone)]
    pub enum DomainAdaptationStrategy {
        /// Maximum Mean Discrepancy alignment with kernel `bandwidth` and
        /// regularization weight `lambda`.
        MMD { bandwidth: f64, lambda: f64 },
        /// Adversarial (DANN-style) alignment parameters.
        Adversarial {
            /// Learning rate for the domain discriminator.
            discriminator_lr: f64,
            /// Learning rate for the feature generator.
            generator_lr: f64,
            /// Weight of the adversarial loss term.
            adversarial_weight: f64,
        },
        /// Correlation alignment (CORAL) with weight `lambda`.
        CORAL { lambda: f64 },
        /// Deep domain confusion parameters.
        DeepDomainConfusion {
            adaptation_factor: f64,
            confusion_weight: f64,
        },
    }
710
    /// Pipeline adapting an estimator trained on a labelled source domain to
    /// an unlabelled (or sparsely labelled) target domain. Typestate `S`
    /// distinguishes untrained from trained pipelines.
    #[derive(Debug)]
    pub struct DomainAdaptationPipeline<S = Untrained> {
        // Typestate marker or trained-state payload.
        state: S,
        // Labelled source-domain data `(X, y)`; consumed during `fit`.
        source_data: Option<(Array2<f64>, Array1<f64>)>,
        // How the two domains are aligned.
        adaptation_strategy: DomainAdaptationStrategy,
        // Estimator to adapt; consumed during `fit`.
        base_estimator: Option<Box<dyn PipelinePredictor>>,
    }
719
    /// Trained-state payload of a `DomainAdaptationPipeline`.
    #[derive(Debug)]
    pub struct DomainAdaptationPipelineTrained {
        // Estimator fitted on the (aligned) source data; serves predictions.
        adapted_estimator: Box<dyn PipelinePredictor>,
        // Alignment metrics produced by the adaptation strategy.
        domain_alignment_metrics: HashMap<String, f64>,
        // Strategy that was applied during fitting.
        adaptation_strategy: DomainAdaptationStrategy,
        // Number of target-domain features seen at fit time.
        n_features_in: usize,
        // Optional feature names; currently never populated.
        feature_names_in: Option<Vec<String>>,
    }
729
730 impl DomainAdaptationPipeline<Untrained> {
731 #[must_use]
733 pub fn new(
734 source_data: (Array2<f64>, Array1<f64>),
735 adaptation_strategy: DomainAdaptationStrategy,
736 base_estimator: Box<dyn PipelinePredictor>,
737 ) -> Self {
738 Self {
739 state: Untrained,
740 source_data: Some(source_data),
741 adaptation_strategy,
742 base_estimator: Some(base_estimator),
743 }
744 }
745
746 #[must_use]
748 pub fn mmd(
749 source_data: (Array2<f64>, Array1<f64>),
750 base_estimator: Box<dyn PipelinePredictor>,
751 bandwidth: f64,
752 lambda: f64,
753 ) -> Self {
754 Self::new(
755 source_data,
756 DomainAdaptationStrategy::MMD { bandwidth, lambda },
757 base_estimator,
758 )
759 }
760
761 #[must_use]
763 pub fn adversarial(
764 source_data: (Array2<f64>, Array1<f64>),
765 base_estimator: Box<dyn PipelinePredictor>,
766 discriminator_lr: f64,
767 generator_lr: f64,
768 adversarial_weight: f64,
769 ) -> Self {
770 Self::new(
771 source_data,
772 DomainAdaptationStrategy::Adversarial {
773 discriminator_lr,
774 generator_lr,
775 adversarial_weight,
776 },
777 base_estimator,
778 )
779 }
780 }
781
    // Minimal `Estimator` implementation: no estimator-level configuration.
    impl Estimator for DomainAdaptationPipeline<Untrained> {
        type Config = ();
        type Error = SklearsError;
        type Float = Float;

        fn config(&self) -> &Self::Config {
            // `&()` is valid for any lifetime via 'static promotion.
            &()
        }
    }
791
    impl Fit<ArrayView2<'_, Float>, Option<&ArrayView1<'_, Float>>>
        for DomainAdaptationPipeline<Untrained>
    {
        type Fitted = DomainAdaptationPipeline<DomainAdaptationPipelineTrained>;

        /// Computes domain-alignment metrics between the stored source data
        /// and `target_x`, then fits the base estimator on the source data.
        ///
        /// NOTE(review): `target_y` is currently unused — the estimator is
        /// trained purely on the source domain; confirm this is intended.
        ///
        /// # Errors
        /// Fails when source data or the base estimator is missing, or when
        /// the estimator's `fit` fails.
        fn fit(
            mut self,
            target_x: &ArrayView2<'_, Float>,
            target_y: &Option<&ArrayView1<'_, Float>>,
        ) -> SklResult<Self::Fitted> {
            let (source_x, source_y) = self
                .source_data
                .as_ref()
                .ok_or_else(|| SklearsError::InvalidInput("No source data provided".to_string()))?;

            let mut base_estimator = self.base_estimator.take().ok_or_else(|| {
                SklearsError::InvalidInput("No base estimator provided".to_string())
            })?;

            // Owned f64 copy of the target view. The identity `mapv` assumes
            // `Float` is `f64` — TODO(review): confirm.
            let target_x_f64 = target_x.mapv(|v| v);

            // Dispatch to the strategy-specific alignment-metric computation.
            let alignment_metrics = match &self.adaptation_strategy {
                DomainAdaptationStrategy::MMD { bandwidth, lambda } => {
                    self.apply_mmd_adaptation(source_x, &target_x_f64, *bandwidth, *lambda)?
                }
                DomainAdaptationStrategy::Adversarial {
                    discriminator_lr,
                    generator_lr,
                    adversarial_weight,
                } => self.apply_adversarial_adaptation(
                    source_x,
                    &target_x_f64,
                    *discriminator_lr,
                    *generator_lr,
                    *adversarial_weight,
                )?,
                DomainAdaptationStrategy::CORAL { lambda } => {
                    self.apply_coral_adaptation(source_x, &target_x_f64, *lambda)?
                }
                DomainAdaptationStrategy::DeepDomainConfusion {
                    adaptation_factor,
                    confusion_weight,
                } => self.apply_deep_domain_confusion(
                    source_x,
                    &target_x_f64,
                    *adaptation_factor,
                    *confusion_weight,
                )?,
            };

            // Fit on the source domain, converting f64 -> Float for the
            // estimator interface (a no-op if Float == f64).
            let source_x_float = source_x.mapv(|v| v as Float);
            let source_y_float = source_y.mapv(|v| v as Float);
            base_estimator.fit(&source_x_float.view(), &source_y_float.view())?;

            Ok(DomainAdaptationPipeline {
                state: DomainAdaptationPipelineTrained {
                    adapted_estimator: base_estimator,
                    domain_alignment_metrics: alignment_metrics,
                    adaptation_strategy: self.adaptation_strategy,
                    n_features_in: target_x.ncols(),
                    feature_names_in: None,
                },
                // Outer fields are inert placeholders on a trained pipeline.
                source_data: None,
                adaptation_strategy: DomainAdaptationStrategy::MMD {
                    bandwidth: 1.0,
                    lambda: 1.0,
                },
                base_estimator: None,
            })
        }
    }
865
    impl DomainAdaptationPipeline<Untrained> {
        // Computes MMD alignment metrics between the two domains.
        fn apply_mmd_adaptation(
            &self,
            source_x: &Array2<f64>,
            target_x: &Array2<f64>,
            bandwidth: f64,
            lambda: f64,
        ) -> SklResult<HashMap<String, f64>> {
            let mmd_distance = self.compute_mmd_distance(source_x, target_x, bandwidth);

            let mut metrics = HashMap::new();
            metrics.insert("mmd_distance".to_string(), mmd_distance);
            metrics.insert("bandwidth".to_string(), bandwidth);
            metrics.insert("lambda".to_string(), lambda);

            Ok(metrics)
        }

        // Placeholder adversarial adaptation: no adversarial training runs;
        // fixed illustrative metrics are returned. NOTE(review): `source_x`,
        // `target_x`, and both learning rates are unused — confirm this is a
        // deliberate stub.
        fn apply_adversarial_adaptation(
            &self,
            source_x: &Array2<f64>,
            target_x: &Array2<f64>,
            discriminator_lr: f64,
            generator_lr: f64,
            adversarial_weight: f64,
        ) -> SklResult<HashMap<String, f64>> {
            let mut metrics = HashMap::new();
            // Placeholder values, not measured quantities.
            metrics.insert("discriminator_accuracy".to_string(), 0.6);
            metrics.insert("generator_loss".to_string(), 1.2);
            metrics.insert("adversarial_weight".to_string(), adversarial_weight);

            Ok(metrics)
        }

        // CORAL adaptation metrics: variance-alignment loss plus lambda.
        fn apply_coral_adaptation(
            &self,
            source_x: &Array2<f64>,
            target_x: &Array2<f64>,
            lambda: f64,
        ) -> SklResult<HashMap<String, f64>> {
            let coral_loss = self.compute_coral_loss(source_x, target_x);

            let mut metrics = HashMap::new();
            metrics.insert("coral_loss".to_string(), coral_loss);
            metrics.insert("lambda".to_string(), lambda);

            Ok(metrics)
        }

        // Deep-domain-confusion metrics: std-dev-based confusion loss plus
        // the configured weights.
        fn apply_deep_domain_confusion(
            &self,
            source_x: &Array2<f64>,
            target_x: &Array2<f64>,
            adaptation_factor: f64,
            confusion_weight: f64,
        ) -> SklResult<HashMap<String, f64>> {
            let confusion_loss = self.compute_confusion_loss(source_x, target_x);

            let mut metrics = HashMap::new();
            metrics.insert("confusion_loss".to_string(), confusion_loss);
            metrics.insert("adaptation_factor".to_string(), adaptation_factor);
            metrics.insert("confusion_weight".to_string(), confusion_weight);

            Ok(metrics)
        }

        // Simplified MMD: Euclidean distance between the per-feature domain
        // means, scaled by 1/bandwidth. NOTE(review): `mean_axis(...).unwrap()`
        // panics when either array has zero rows — consider guarding.
        fn compute_mmd_distance(
            &self,
            source_x: &Array2<f64>,
            target_x: &Array2<f64>,
            bandwidth: f64,
        ) -> f64 {
            let source_mean = source_x.mean_axis(Axis(0)).unwrap();
            let target_mean = target_x.mean_axis(Axis(0)).unwrap();
            let diff = &source_mean - &target_mean;
            (diff.mapv(|x| x * x).sum() / bandwidth).sqrt()
        }

        // Simplified CORAL loss: sum of squared differences between the
        // per-feature variances (ddof = 1). Returns +inf when the feature
        // counts of the two domains differ.
        fn compute_coral_loss(&self, source_x: &Array2<f64>, target_x: &Array2<f64>) -> f64 {
            if source_x.ncols() != target_x.ncols() {
                return f64::INFINITY;
            }

            // NOTE(review): the means below are computed but unused; only the
            // variances enter the loss.
            let source_mean = source_x.mean_axis(Axis(0)).unwrap();
            let target_mean = target_x.mean_axis(Axis(0)).unwrap();

            let source_var = source_x.var_axis(Axis(0), 1.0);
            let target_var = target_x.var_axis(Axis(0), 1.0);

            (&source_var - &target_var).mapv(|x| x * x).sum()
        }

        // Sum of squared differences between the per-feature standard
        // deviations of the two domains (ddof = 1).
        fn compute_confusion_loss(&self, source_x: &Array2<f64>, target_x: &Array2<f64>) -> f64 {
            let source_std = source_x.std_axis(Axis(0), 1.0);
            let target_std = target_x.std_axis(Axis(0), 1.0);
            (&source_std - &target_std).mapv(|x| x * x).sum()
        }
    }
977
    impl DomainAdaptationPipeline<DomainAdaptationPipelineTrained> {
        /// Runs the adapted estimator on `x` and returns its predictions.
        pub fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array1<f64>> {
            self.state.adapted_estimator.predict(x)
        }

        /// Domain-alignment metrics recorded at fit time.
        #[must_use]
        pub fn alignment_metrics(&self) -> &HashMap<String, f64> {
            &self.state.domain_alignment_metrics
        }

        /// Measures the discrepancy between two domains using the same
        /// strategy the pipeline was trained with. Strategies without a
        /// discrepancy measure (Adversarial, DeepDomainConfusion) yield 0.0.
        pub fn measure_domain_discrepancy(
            &self,
            source_x: &ArrayView2<'_, Float>,
            target_x: &ArrayView2<'_, Float>,
        ) -> SklResult<f64> {
            // Identity `mapv` to get owned f64 arrays; assumes Float == f64
            // — TODO(review): confirm.
            let source_x_f64 = source_x.mapv(|v| v);
            let target_x_f64 = target_x.mapv(|v| v);

            match &self.state.adaptation_strategy {
                DomainAdaptationStrategy::MMD { bandwidth, .. } => {
                    Ok(self.compute_mmd_distance(&source_x_f64, &target_x_f64, *bandwidth))
                }
                DomainAdaptationStrategy::CORAL { .. } => {
                    Ok(self.compute_coral_loss(&source_x_f64, &target_x_f64))
                }
                // No discrepancy measure defined for the other strategies.
                _ => Ok(0.0),
            }
        }

        // Simplified MMD: mean-difference distance scaled by 1/bandwidth.
        // NOTE(review): duplicates the untrained impl's helper — consider
        // extracting a shared free function.
        fn compute_mmd_distance(
            &self,
            source_x: &Array2<f64>,
            target_x: &Array2<f64>,
            bandwidth: f64,
        ) -> f64 {
            let source_mean = source_x.mean_axis(Axis(0)).unwrap();
            let target_mean = target_x.mean_axis(Axis(0)).unwrap();
            let diff = &source_mean - &target_mean;
            (diff.mapv(|x| x * x).sum() / bandwidth).sqrt()
        }

        // Simplified CORAL loss over per-feature variances (ddof = 1);
        // +inf when feature counts differ. Duplicates the untrained impl.
        fn compute_coral_loss(&self, source_x: &Array2<f64>, target_x: &Array2<f64>) -> f64 {
            if source_x.ncols() != target_x.ncols() {
                return f64::INFINITY;
            }

            let source_var = source_x.var_axis(Axis(0), 1.0);
            let target_var = target_x.var_axis(Axis(0), 1.0);

            (&source_var - &target_var).mapv(|x| x * x).sum()
        }
    }
1035}
1036
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MockPredictor;
    use scirs2_core::ndarray::array;

    // Builder methods should record frozen and trainable layer names.
    #[test]
    fn test_pretrained_model() {
        let base_model = Box::new(MockPredictor::new());
        let pretrained = PretrainedModel::new(base_model)
            .with_frozen_layers(vec!["layer1".to_string(), "layer2".to_string()])
            .with_trainable_layers(vec!["layer3".to_string()]);

        assert_eq!(pretrained.frozen_layers.len(), 2);
        assert_eq!(pretrained.trainable_layers.len(), 1);
    }

    // Exponential decay: rate at step 0 equals the initial rate and
    // decreases monotonically with steps.
    #[test]
    fn test_learning_rate_schedule() {
        let schedule = LearningRateSchedule::ExponentialDecay {
            initial_rate: 0.1,
            decay_rate: 0.9,
            decay_steps: 10,
        };

        let rate_0 = schedule.get_rate(0);
        let rate_10 = schedule.get_rate(10);

        assert_eq!(rate_0, 0.1);
        assert!(rate_10 < rate_0);
    }

    // Fine-tuning end-to-end: fit then predict, one prediction per row.
    #[test]
    fn test_transfer_learning_pipeline() {
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![1.0, 0.0];

        let pretrained_model = PretrainedModel::new(Box::new(MockPredictor::new()));
        let target_estimator = Box::new(MockPredictor::new());

        let pipeline =
            TransferLearningPipeline::fine_tuning(pretrained_model, target_estimator, 0.001, 5);

        let fitted_pipeline = pipeline.fit(&x.view(), &Some(&y.view())).unwrap();
        let predictions = fitted_pipeline.predict(&x.view()).unwrap();

        assert_eq!(predictions.len(), x.nrows());
    }

    // MMD domain adaptation: fitting on unlabeled target data succeeds and
    // records an "mmd_distance" alignment metric.
    #[test]
    fn test_domain_adaptation_pipeline() {
        use domain_adaptation::*;

        let source_x = array![[1.0, 2.0], [3.0, 4.0]];
        let source_y = array![1.0, 0.0];
        let target_x = array![[2.0, 3.0], [4.0, 5.0]];

        let base_estimator = Box::new(MockPredictor::new());
        let pipeline =
            DomainAdaptationPipeline::mmd((source_x, source_y), base_estimator, 1.0, 0.1);

        let fitted_pipeline = pipeline.fit(&target_x.view(), &None).unwrap();
        let predictions = fitted_pipeline.predict(&target_x.view()).unwrap();

        assert_eq!(predictions.len(), target_x.nrows());
        assert!(fitted_pipeline
            .alignment_metrics()
            .contains_key("mmd_distance"));
    }
}