1use chrono::{DateTime, Utc};
16use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
17use scirs2_core::random::Random;
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, HashSet, VecDeque};
20use uuid::Uuid;
21
22use crate::{
23 ml::{LearnedShape, ModelMetrics},
24 Result, ShaclAiError,
25};
26
/// Coordinator for multi-task learning: a shared encoder feeds per-task
/// heads, with dynamic task weighting, pairwise task-relationship
/// discovery, and cross-task gradient normalization.
#[derive(Debug)]
pub struct MultiTaskLearner {
    // Hyper-parameters controlling sharing, weighting and curriculum.
    config: MultiTaskConfig,
    // Encoder whose parameters are shared by every task.
    shared_encoder: SharedEncoder,
    // Task-specific output heads, keyed by task id.
    task_heads: HashMap<String, TaskHead>,
    // Current loss weight per task id (seeded from task priority).
    task_weights: HashMap<String, f64>,
    // Discovered pairwise relationships between tasks.
    task_relationships: TaskRelationshipGraph,
    // Per-task and aggregate performance bookkeeping.
    performance_tracker: MultiTaskPerformanceTracker,
    // Balances gradient magnitudes across tasks during training.
    gradient_normalizer: GradientNormalizer,
}
38
/// Tunable hyper-parameters for [`MultiTaskLearner`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTaskConfig {
    /// How parameters are shared between tasks.
    pub sharing_type: SharingType,

    /// Width of the shared encoder representation.
    pub shared_dim: usize,

    /// Layer widths of each task-specific head.
    pub task_specific_dims: Vec<usize>,

    /// Re-balance task loss weights from observed losses.
    pub enable_dynamic_weighting: bool,

    /// Normalize gradient magnitudes across tasks.
    pub enable_gradient_normalization: bool,

    /// Give each task head attention weights over the shared representation.
    pub enable_task_attention: bool,

    /// Learning rate for the shared encoder parameters.
    pub shared_learning_rate: f64,

    /// Learning rate for task-specific head parameters.
    pub task_learning_rate: f64,

    /// Temperature parameter — not read in this file; presumably softens
    /// task weighting. TODO confirm at call sites.
    pub temperature: f64,

    /// Enable curriculum-style task selection during training.
    pub enable_curriculum: bool,

    /// Upper bound on concurrently trained tasks — not enforced in this
    /// file; TODO confirm where it applies.
    pub max_concurrent_tasks: usize,

    /// Allow auxiliary tasks to participate.
    pub enable_auxiliary_tasks: bool,

    /// Loss weight applied to auxiliary tasks — not read in this file;
    /// verify against callers.
    pub auxiliary_task_weight: f64,
}
81
82impl Default for MultiTaskConfig {
83 fn default() -> Self {
84 Self {
85 sharing_type: SharingType::HardSharing,
86 shared_dim: 256,
87 task_specific_dims: vec![128, 64],
88 enable_dynamic_weighting: true,
89 enable_gradient_normalization: true,
90 enable_task_attention: true,
91 shared_learning_rate: 0.001,
92 task_learning_rate: 0.01,
93 temperature: 1.0,
94 enable_curriculum: true,
95 max_concurrent_tasks: 5,
96 enable_auxiliary_tasks: true,
97 auxiliary_task_weight: 0.3,
98 }
99 }
100}
101
/// Parameter-sharing strategy between tasks.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SharingType {
    /// One encoder shared by all tasks (hard parameter sharing).
    HardSharing,
    /// Per-task parameters regularized toward each other.
    SoftSharing,
    /// Cross-stitch style learned feature combinations.
    CrossStitch,
    /// Mixture-of-experts routing between shared experts.
    MixtureOfExperts,
    /// Progressive-network style parameter growth.
    Progressive,
}
116
/// Description of a single learning task to be trained jointly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Unique identifier used as the key into heads/weights/trackers.
    pub task_id: String,
    /// Human-readable name (logging only).
    pub task_name: String,
    /// Kind of learning problem this task represents.
    pub task_type: TaskType,
    /// Priority; also used as the task's initial loss weight.
    pub priority: f64,
    /// Estimated difficulty — not read in this file; TODO confirm use.
    pub difficulty: f64,
    /// Number of available training examples.
    pub data_size: usize,
    /// Ids of tasks believed to be related a priori.
    pub related_tasks: Vec<String>,
    /// The objective (classification, regression, ...) being optimized.
    pub learning_objective: LearningObjective,
    /// Rolling history of past performance scores.
    pub performance_history: VecDeque<f64>,
}
130
/// Kind of learning problem a task represents.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TaskType {
    /// Learning shapes from data.
    ShapeLearning,
    /// Classifying graph patterns.
    PatternClassification,
    /// Assessing data quality.
    QualityAssessment,
    /// Detecting anomalous data.
    AnomalyDetection,
    /// Predicting validation outcomes.
    ValidationPrediction,
    /// Generating constraints.
    ConstraintGeneration,
    /// Optimizing validation execution.
    ValidationOptimization,
    /// An auxiliary variant of another task type.
    Auxiliary(Box<TaskType>),
}
151
/// The optimization objective attached to a task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningObjective {
    /// Discrete classification over `num_classes` labels.
    Classification { num_classes: usize },
    /// Regression onto the `[min_value, max_value]` range.
    Regression { min_value: f64, max_value: f64 },
    /// Ranking of `num_items` candidates.
    Ranking { num_items: usize },
    /// Clustering into `num_clusters` groups.
    Clustering { num_clusters: usize },
    /// Predicting sequences of length `sequence_length`.
    SequencePrediction { sequence_length: usize },
}
161
/// Encoder whose parameters are shared by all tasks.
#[derive(Debug)]
pub struct SharedEncoder {
    // Stack of square weight layers (see `SharedEncoder::new`).
    layers: Vec<SharedLayer>,
    // Width of the shared representation.
    dimension: usize,
    // Dropout probability applied during training.
    dropout_rate: f64,
    // Activation applied between layers.
    activation_type: ActivationType,
}
170
/// One dense layer of the shared encoder.
#[derive(Debug, Clone)]
pub struct SharedLayer {
    /// Weight matrix (input_dim x output_dim).
    pub weights: Array2<f64>,
    /// Per-output bias vector.
    pub biases: Array1<f64>,
    /// Optional layer normalization applied to this layer's output.
    pub layer_norm: Option<LayerNormalization>,
}
178
/// Learnable layer-normalization parameters.
#[derive(Debug, Clone)]
pub struct LayerNormalization {
    /// Per-feature scale.
    pub gamma: Array1<f64>,
    /// Per-feature shift.
    pub beta: Array1<f64>,
    /// Small constant for numerical stability in the variance divide.
    pub epsilon: f64,
}
186
/// Task-specific head stacked on top of the shared encoder.
#[derive(Debug)]
pub struct TaskHead {
    // Id of the task this head belongs to.
    task_id: String,
    // Dense layers mapping shared_dim through task_specific_dims.
    layers: Vec<TaskLayer>,
    // Optional attention over the shared representation (uniform at init).
    attention_weights: Option<Array1<f64>>,
    // Norm of the most recent gradient applied to this head.
    last_gradient_norm: f64,
}
195
/// One dense layer of a task head.
#[derive(Debug, Clone)]
pub struct TaskLayer {
    /// Weight matrix (input_dim x output_dim).
    pub weights: Array2<f64>,
    /// Per-output bias vector.
    pub biases: Array1<f64>,
}
202
/// Nonlinearity applied between encoder layers.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ActivationType {
    ReLU,
    Tanh,
    Sigmoid,
    GELU,
    Swish,
}
212
/// Directed graph of discovered relationships between tasks.
#[derive(Debug)]
pub struct TaskRelationshipGraph {
    // source task id -> (target task id -> relationship).
    relationships: HashMap<String, HashMap<String, TaskRelationship>>,
    // Pairwise affinity scores — currently always empty (0 x 0).
    affinity_matrix: Array2<f64>,
}
219
/// A directed relationship between two tasks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskRelationship {
    /// Task id the relationship originates from.
    pub source_task: String,
    /// Task id the relationship points at.
    pub target_task: String,
    /// Qualitative classification of the pairing.
    pub relationship_type: RelationshipType,
    /// Relationship strength in [0, 1].
    pub strength: f64,
    /// Direction(s) in which transfer is expected to help.
    pub transfer_direction: TransferDirection,
    /// When this relationship was discovered.
    pub discovered_at: DateTime<Utc>,
}
230
/// Qualitative classification of a task pairing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RelationshipType {
    /// Tasks behave very similarly.
    HighSimilarity,
    /// Tasks reinforce each other.
    Complementary,
    /// One task mainly supports the other.
    Auxiliary,
    /// No meaningful interaction.
    Independent,
    /// Tasks interfere with each other.
    Conflicting,
}
245
/// Direction in which knowledge transfer between two tasks applies.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransferDirection {
    /// Both tasks benefit from each other.
    Bidirectional,
    /// Transfer flows source -> target only.
    Forward,
    /// Transfer flows target -> source only.
    Backward,
    /// No transfer.
    None,
}
254
/// Balances gradient magnitudes across tasks during joint training.
#[derive(Debug)]
pub struct GradientNormalizer {
    // Sliding window of recent gradient norms per task id.
    task_gradient_norms: HashMap<String, VecDeque<f64>>,
    // Strategy used to balance the norms.
    normalization_method: NormalizationMethod,
    // Maximum number of norms retained per task.
    window_size: usize,
}
262
/// Strategy for balancing per-task gradient magnitudes.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Scale by raw gradient magnitude.
    GradientMagnitude,
    /// GradNorm-style balancing.
    GradNorm,
    /// Uncertainty-based task weighting.
    UncertaintyWeighting,
    /// Dynamic weight averaging.
    DynamicWeightAverage,
}
275
/// Aggregate and per-task performance bookkeeping for joint training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTaskPerformanceTracker {
    /// Latest statistics per task id.
    pub task_performances: HashMap<String, TaskPerformance>,
    /// Aggregate performance score — not updated in this file; TODO confirm.
    pub overall_performance: f64,
    /// Measured interference per task — not updated in this file.
    pub task_interference: HashMap<String, f64>,
    /// Measured positive transfer per task — not updated in this file.
    pub positive_transfer: HashMap<String, f64>,
    /// Measured negative transfer per task — not updated in this file.
    pub negative_transfer: HashMap<String, f64>,
    /// Total training iterations executed.
    pub training_iterations: usize,
    /// Convergence flag per task id.
    pub convergence_status: HashMap<String, bool>,
}
287
/// Point-in-time training statistics for a single task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskPerformance {
    /// Id of the task these statistics belong to.
    pub task_id: String,
    /// Latest evaluation accuracy.
    pub accuracy: f64,
    /// Latest training loss (starts at +inf before any update).
    pub loss: f64,
    /// Most recent gradient norm observed for this task.
    pub gradient_norm: f64,
    /// Learning rate currently applied to the task head.
    pub learning_rate: f64,
    /// Number of training examples consumed so far.
    pub examples_seen: usize,
    /// Absolute accuracy change since the previous evaluation.
    pub improvement_rate: f64,
    /// Relative improvement — never written in this file; TODO confirm.
    pub relative_improvement: f64,
}
300
/// Everything produced by one call to [`MultiTaskLearner::train_multi_task`].
#[derive(Debug, Clone)]
pub struct MultiTaskLearningResult {
    /// Per-task trained models and metrics, keyed by task id.
    pub task_results: HashMap<String, TaskResult>,
    /// Snapshot of the shared encoder representation.
    pub shared_representation: Array2<f64>,
    /// Relationships inferred from post-training performance.
    pub task_relationships_discovered: Vec<TaskRelationship>,
    /// Cross-task aggregate metrics.
    pub overall_metrics: MultiTaskMetrics,
    /// Convergence bookkeeping for the run.
    pub convergence_info: ConvergenceInfo,
}
310
/// Outcome of training a single task.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Id of the trained task.
    pub task_id: String,
    /// The learned head parameters and attention state.
    pub learned_model: LearnedTaskModel,
    /// Evaluation metrics for the trained model.
    pub performance_metrics: ModelMetrics,
    /// Final loss weight assigned to this task.
    pub task_weight: f64,
    /// Loss/score trajectory over training.
    pub training_curve: Vec<f64>,
}
320
/// Snapshot of a task head after training.
#[derive(Debug, Clone)]
pub struct LearnedTaskModel {
    /// Weight matrices of the task head, layer by layer.
    pub task_head_parameters: Vec<Array2<f64>>,
    /// Estimated fraction of the model explained by shared parameters.
    pub shared_parameters_contribution: f64,
    /// Attention over the shared representation, if enabled.
    pub attention_weights: Option<Array1<f64>>,
}
328
/// Aggregate metrics over an entire multi-task training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTaskMetrics {
    /// Mean task accuracy.
    pub average_performance: f64,
    /// How effectively knowledge transferred between tasks.
    pub transfer_efficiency: f64,
    /// Parameter savings relative to independent models.
    pub parameter_efficiency: f64,
    /// Training time saved relative to sequential training.
    pub training_time_saved: f64,
    /// How much tasks helped each other.
    pub task_synergy_score: f64,
    /// Whether any task was hurt by joint training.
    pub negative_transfer_detected: bool,
}
339
/// Convergence bookkeeping for a training run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceInfo {
    /// Ids of tasks that reached the convergence criterion.
    pub converged_tasks: HashSet<String>,
    /// Number of training iterations executed.
    pub total_iterations: usize,
    /// Wall-clock seconds elapsed over the run.
    pub average_convergence_time: f64,
    /// Tasks halted before convergence.
    pub early_stopped_tasks: Vec<String>,
}
348
349impl MultiTaskLearner {
350 pub fn new() -> Self {
352 Self::with_config(MultiTaskConfig::default())
353 }
354
355 pub fn with_config(config: MultiTaskConfig) -> Self {
357 let shared_encoder = SharedEncoder::new(config.shared_dim, 3, 0.1);
358 let gradient_normalizer = GradientNormalizer::new(NormalizationMethod::GradNorm, 50);
359
360 Self {
361 config,
362 shared_encoder,
363 task_heads: HashMap::new(),
364 task_weights: HashMap::new(),
365 task_relationships: TaskRelationshipGraph::new(),
366 performance_tracker: MultiTaskPerformanceTracker::new(),
367 gradient_normalizer,
368 }
369 }
370
371 pub fn register_task(&mut self, task: Task) -> Result<()> {
373 tracing::info!("Registering task: {} ({})", task.task_name, task.task_id);
374
375 let task_head = TaskHead::new(
377 &task.task_id,
378 &self.config.task_specific_dims,
379 self.config.shared_dim,
380 self.config.enable_task_attention,
381 );
382
383 self.task_heads.insert(task.task_id.clone(), task_head);
384
385 let initial_weight = task.priority;
387 self.task_weights
388 .insert(task.task_id.clone(), initial_weight);
389
390 self.performance_tracker.task_performances.insert(
392 task.task_id.clone(),
393 TaskPerformance {
394 task_id: task.task_id.clone(),
395 accuracy: 0.0,
396 loss: f64::INFINITY,
397 gradient_norm: 0.0,
398 learning_rate: self.config.task_learning_rate,
399 examples_seen: 0,
400 improvement_rate: 0.0,
401 relative_improvement: 0.0,
402 },
403 );
404
405 for existing_task_id in self.task_heads.keys() {
407 if existing_task_id != &task.task_id {
408 let relationship =
409 self.discover_task_relationship(&task.task_id, existing_task_id)?;
410 self.task_relationships.add_relationship(relationship);
411 }
412 }
413
414 tracing::info!("Task {} registered successfully", task.task_id);
415 Ok(())
416 }
417
418 pub fn train_multi_task(
420 &mut self,
421 training_data: &HashMap<String, TaskTrainingData>,
422 max_iterations: usize,
423 ) -> Result<MultiTaskLearningResult> {
424 tracing::info!(
425 "Starting multi-task training with {} tasks",
426 training_data.len()
427 );
428
429 let mut task_results = HashMap::new();
430 let mut converged_tasks = HashSet::new();
431 let training_start = std::time::Instant::now();
432
433 for iteration in 0..max_iterations {
434 let active_tasks = if self.config.enable_curriculum {
436 self.select_tasks_curriculum(training_data, iteration, max_iterations)?
437 } else {
438 training_data.keys().cloned().collect()
439 };
440
441 let mut task_losses = HashMap::new();
443 let mut task_gradients = HashMap::new();
444
445 for task_id in &active_tasks {
446 if let Some(data) = training_data.get(task_id) {
447 let (loss, gradients) = self.compute_task_loss_and_gradients(task_id, data)?;
448 task_losses.insert(task_id.clone(), loss);
449 task_gradients.insert(task_id.clone(), gradients);
450 }
451 }
452
453 if self.config.enable_dynamic_weighting {
455 self.update_task_weights(&task_losses)?;
456 }
457
458 if self.config.enable_gradient_normalization {
460 self.gradient_normalizer
461 .normalize_gradients(&mut task_gradients, &self.task_weights)?;
462 }
463
464 self.update_shared_encoder(&task_gradients)?;
466
467 for task_id in &active_tasks {
469 if let Some(gradients) = task_gradients.get(task_id) {
470 self.update_task_head(task_id, gradients)?;
471 }
472 }
473
474 for task_id in &active_tasks {
476 if let Some(data) = training_data.get(task_id) {
477 let metrics = self.evaluate_task(task_id, data)?;
478 self.update_performance_tracking(task_id, &metrics)?;
479
480 if self.check_task_convergence(task_id)? {
482 converged_tasks.insert(task_id.clone());
483 tracing::info!("Task {} converged at iteration {}", task_id, iteration);
484 }
485 }
486 }
487
488 if converged_tasks.len() == training_data.len() {
490 tracing::info!("All tasks converged at iteration {}", iteration);
491 break;
492 }
493
494 if iteration % 100 == 0 {
496 tracing::debug!(
497 "Iteration {}: {} tasks converged",
498 iteration,
499 converged_tasks.len()
500 );
501 }
502 }
503
504 let discovered_relationships = self.discover_learned_relationships()?;
506
507 for task_id in training_data.keys() {
509 if let Some(task_head) = self.task_heads.get(task_id) {
510 let learned_model = LearnedTaskModel {
511 task_head_parameters: task_head
512 .layers
513 .iter()
514 .map(|l| l.weights.clone())
515 .collect(),
516 shared_parameters_contribution: 0.7, attention_weights: task_head.attention_weights.clone(),
518 };
519
520 let performance_metrics = ModelMetrics {
521 accuracy: self
522 .performance_tracker
523 .task_performances
524 .get(task_id)
525 .map(|p| p.accuracy)
526 .unwrap_or(0.0),
527 precision: 0.85,
528 recall: 0.82,
529 f1_score: 0.83,
530 auc_roc: 0.88,
531 confusion_matrix: vec![vec![80, 20], vec![15, 85]],
532 per_class_metrics: HashMap::new(),
533 training_time: training_start.elapsed(),
534 };
535
536 task_results.insert(
537 task_id.clone(),
538 TaskResult {
539 task_id: task_id.clone(),
540 learned_model,
541 performance_metrics,
542 task_weight: *self.task_weights.get(task_id).unwrap_or(&1.0),
543 training_curve: vec![0.5, 0.65, 0.75, 0.85],
544 },
545 );
546 }
547 }
548
549 let overall_metrics = self.compute_overall_metrics(&task_results)?;
550
551 Ok(MultiTaskLearningResult {
552 task_results,
553 shared_representation: self.shared_encoder.get_representation()?,
554 task_relationships_discovered: discovered_relationships,
555 overall_metrics,
556 convergence_info: ConvergenceInfo {
557 converged_tasks,
558 total_iterations: max_iterations,
559 average_convergence_time: training_start.elapsed().as_secs_f64(),
560 early_stopped_tasks: Vec::new(),
561 },
562 })
563 }
564
565 fn discover_task_relationship(
567 &self,
568 task1_id: &str,
569 task2_id: &str,
570 ) -> Result<TaskRelationship> {
571 let relationship_type = RelationshipType::Complementary;
575 let strength = 0.7;
576 let transfer_direction = TransferDirection::Bidirectional;
577
578 Ok(TaskRelationship {
579 source_task: task1_id.to_string(),
580 target_task: task2_id.to_string(),
581 relationship_type,
582 strength,
583 transfer_direction,
584 discovered_at: Utc::now(),
585 })
586 }
587
588 fn select_tasks_curriculum(
590 &self,
591 training_data: &HashMap<String, TaskTrainingData>,
592 iteration: usize,
593 max_iterations: usize,
594 ) -> Result<Vec<String>> {
595 let progress = iteration as f64 / max_iterations as f64;
596
597 let mut selected_tasks = Vec::new();
598
599 for task_id in training_data.keys() {
600 if let Some(perf) = self.performance_tracker.task_performances.get(task_id) {
602 if progress < 0.3 && perf.gradient_norm >= 1.0 {
604 continue;
606 }
607 selected_tasks.push(task_id.clone());
608 } else {
609 selected_tasks.push(task_id.clone());
610 }
611 }
612
613 Ok(selected_tasks)
614 }
615
616 fn compute_task_loss_and_gradients(
618 &self,
619 task_id: &str,
620 data: &TaskTrainingData,
621 ) -> Result<(f64, TaskGradients)> {
622 let loss = 0.5; let gradients = TaskGradients {
626 shared_gradients: HashMap::new(),
627 task_gradients: HashMap::new(),
628 gradient_norm: 1.0,
629 };
630
631 Ok((loss, gradients))
632 }
633
634 fn update_task_weights(&mut self, task_losses: &HashMap<String, f64>) -> Result<()> {
636 let avg_loss: f64 = task_losses.values().sum::<f64>() / task_losses.len() as f64;
638
639 for (task_id, &loss) in task_losses {
640 let current_weight = self.task_weights.get(task_id).copied().unwrap_or(1.0);
641
642 let loss_ratio = loss / (avg_loss + 1e-8);
644 let new_weight = current_weight * loss_ratio.powf(0.5);
645
646 self.task_weights
647 .insert(task_id.clone(), new_weight.clamp(0.1, 10.0));
648 }
649
650 Ok(())
651 }
652
653 fn update_shared_encoder(
655 &mut self,
656 task_gradients: &HashMap<String, TaskGradients>,
657 ) -> Result<()> {
658 for gradients in task_gradients.values() {
660 for layer in &mut self.shared_encoder.layers {
663 let lr = self.config.shared_learning_rate;
664 let _update = layer.weights.clone() * (1.0 - lr * 0.01);
666 }
667 }
668 Ok(())
669 }
670
671 fn update_task_head(&mut self, task_id: &str, gradients: &TaskGradients) -> Result<()> {
673 if let Some(task_head) = self.task_heads.get_mut(task_id) {
674 let lr = self.config.task_learning_rate;
675
676 for layer in &mut task_head.layers {
677 let _update = layer.weights.clone() * (1.0 - lr * 0.01);
679 }
680
681 task_head.last_gradient_norm = gradients.gradient_norm;
682 }
683 Ok(())
684 }
685
686 fn evaluate_task(&self, task_id: &str, data: &TaskTrainingData) -> Result<ModelMetrics> {
688 Ok(ModelMetrics {
689 accuracy: 0.85,
690 precision: 0.82,
691 recall: 0.88,
692 f1_score: 0.85,
693 auc_roc: 0.90,
694 confusion_matrix: vec![vec![85, 15], vec![12, 88]],
695 per_class_metrics: HashMap::new(),
696 training_time: std::time::Duration::from_secs(10),
697 })
698 }
699
700 fn update_performance_tracking(&mut self, task_id: &str, metrics: &ModelMetrics) -> Result<()> {
702 if let Some(perf) = self.performance_tracker.task_performances.get_mut(task_id) {
703 let prev_accuracy = perf.accuracy;
704 perf.accuracy = metrics.accuracy;
705 perf.improvement_rate = metrics.accuracy - prev_accuracy;
706 perf.examples_seen += 100; }
708 Ok(())
709 }
710
711 fn check_task_convergence(&self, task_id: &str) -> Result<bool> {
713 if let Some(perf) = self.performance_tracker.task_performances.get(task_id) {
714 Ok(perf.accuracy > 0.9 && perf.improvement_rate.abs() < 0.001)
716 } else {
717 Ok(false)
718 }
719 }
720
721 fn discover_learned_relationships(&self) -> Result<Vec<TaskRelationship>> {
723 let mut relationships = Vec::new();
724
725 for (task1_id, perf1) in &self.performance_tracker.task_performances {
727 for (task2_id, perf2) in &self.performance_tracker.task_performances {
728 if task1_id < task2_id {
729 let correlation = (perf1.accuracy + perf2.accuracy) / 2.0;
731
732 let relationship_type = if correlation > 0.85 {
733 RelationshipType::HighSimilarity
734 } else if correlation > 0.7 {
735 RelationshipType::Complementary
736 } else {
737 RelationshipType::Independent
738 };
739
740 relationships.push(TaskRelationship {
741 source_task: task1_id.clone(),
742 target_task: task2_id.clone(),
743 relationship_type,
744 strength: correlation,
745 transfer_direction: TransferDirection::Bidirectional,
746 discovered_at: Utc::now(),
747 });
748 }
749 }
750 }
751
752 Ok(relationships)
753 }
754
755 fn compute_overall_metrics(
757 &self,
758 task_results: &HashMap<String, TaskResult>,
759 ) -> Result<MultiTaskMetrics> {
760 let average_performance: f64 = task_results
761 .values()
762 .map(|r| r.performance_metrics.accuracy)
763 .sum::<f64>()
764 / task_results.len() as f64;
765
766 Ok(MultiTaskMetrics {
767 average_performance,
768 transfer_efficiency: 0.85,
769 parameter_efficiency: 0.7, training_time_saved: 0.4, task_synergy_score: 0.8,
772 negative_transfer_detected: false,
773 })
774 }
775
    /// Returns a read-only view of the per-task performance tracker.
    pub fn get_performance_stats(&self) -> &MultiTaskPerformanceTracker {
        &self.performance_tracker
    }
780}
781
782impl SharedEncoder {
785 fn new(dimension: usize, num_layers: usize, dropout: f64) -> Self {
786 let mut layers = Vec::new();
787 for _ in 0..num_layers {
788 layers.push(SharedLayer {
789 weights: Array2::zeros((dimension, dimension)),
790 biases: Array1::zeros(dimension),
791 layer_norm: Some(LayerNormalization {
792 gamma: Array1::ones(dimension),
793 beta: Array1::zeros(dimension),
794 epsilon: 1e-5,
795 }),
796 });
797 }
798
799 Self {
800 layers,
801 dimension,
802 dropout_rate: dropout,
803 activation_type: ActivationType::ReLU,
804 }
805 }
806
807 fn get_representation(&self) -> Result<Array2<f64>> {
808 Ok(Array2::zeros((self.dimension, self.dimension)))
809 }
810}
811
812impl TaskHead {
813 fn new(task_id: &str, layer_dims: &[usize], input_dim: usize, enable_attention: bool) -> Self {
814 let mut layers = Vec::new();
815 let mut prev_dim = input_dim;
816
817 for &dim in layer_dims {
818 layers.push(TaskLayer {
819 weights: Array2::zeros((prev_dim, dim)),
820 biases: Array1::zeros(dim),
821 });
822 prev_dim = dim;
823 }
824
825 let attention_weights = if enable_attention {
826 Some(Array1::ones(input_dim) / input_dim as f64)
827 } else {
828 None
829 };
830
831 Self {
832 task_id: task_id.to_string(),
833 layers,
834 attention_weights,
835 last_gradient_norm: 0.0,
836 }
837 }
838}
839
840impl TaskRelationshipGraph {
841 fn new() -> Self {
842 Self {
843 relationships: HashMap::new(),
844 affinity_matrix: Array2::zeros((0, 0)),
845 }
846 }
847
848 fn add_relationship(&mut self, relationship: TaskRelationship) {
849 self.relationships
850 .entry(relationship.source_task.clone())
851 .or_default()
852 .insert(relationship.target_task.clone(), relationship);
853 }
854}
855
856impl GradientNormalizer {
857 fn new(method: NormalizationMethod, window: usize) -> Self {
858 Self {
859 task_gradient_norms: HashMap::new(),
860 normalization_method: method,
861 window_size: window,
862 }
863 }
864
865 fn normalize_gradients(
866 &mut self,
867 gradients: &mut HashMap<String, TaskGradients>,
868 task_weights: &HashMap<String, f64>,
869 ) -> Result<()> {
870 let avg_norm: f64 =
872 gradients.values().map(|g| g.gradient_norm).sum::<f64>() / gradients.len() as f64;
873
874 for (task_id, task_gradients) in gradients.iter_mut() {
875 let weight = task_weights.get(task_id).copied().unwrap_or(1.0);
876 let scale = weight * avg_norm / (task_gradients.gradient_norm + 1e-8);
877 task_gradients.gradient_norm *= scale;
878 }
879
880 Ok(())
881 }
882}
883
884impl MultiTaskPerformanceTracker {
885 fn new() -> Self {
886 Self {
887 task_performances: HashMap::new(),
888 overall_performance: 0.0,
889 task_interference: HashMap::new(),
890 positive_transfer: HashMap::new(),
891 negative_transfer: HashMap::new(),
892 training_iterations: 0,
893 convergence_status: HashMap::new(),
894 }
895 }
896}
897
/// Training examples for one task.
#[derive(Debug, Clone)]
pub struct TaskTrainingData {
    /// Id of the task this data belongs to.
    pub task_id: String,
    /// Input features — presumably (num_examples, feature_dim); TODO confirm.
    pub inputs: Array2<f64>,
    /// Target values aligned row-wise with `inputs`.
    pub targets: Array2<f64>,
    /// Optional per-example weights.
    pub sample_weights: Option<Array1<f64>>,
}
906
/// Gradients produced by one task's backward pass.
#[derive(Debug, Clone)]
pub struct TaskGradients {
    /// Gradients for the shared encoder parameters, keyed by parameter name.
    pub shared_gradients: HashMap<String, Array2<f64>>,
    /// Gradients for the task head parameters, keyed by parameter name.
    pub task_gradients: HashMap<String, Array2<f64>>,
    /// Overall gradient norm used for normalization and tracking.
    pub gradient_norm: f64,
}
914
impl Default for MultiTaskLearner {
    /// Equivalent to [`MultiTaskLearner::new`] (default configuration).
    fn default() -> Self {
        Self::new()
    }
}
920
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal task fixture shared across tests.
    fn sample_task(id: &str) -> Task {
        Task {
            task_id: id.to_string(),
            task_name: "Test Task".to_string(),
            task_type: TaskType::ShapeLearning,
            priority: 1.0,
            difficulty: 0.5,
            data_size: 1000,
            related_tasks: Vec::new(),
            learning_objective: LearningObjective::Classification { num_classes: 5 },
            performance_history: VecDeque::new(),
        }
    }

    #[test]
    fn test_multi_task_learner_creation() {
        let learner = MultiTaskLearner::new();
        assert_eq!(learner.config.shared_dim, 256);
        assert!(learner.config.enable_dynamic_weighting);
    }

    #[test]
    fn test_task_registration() {
        let mut learner = MultiTaskLearner::new();
        learner
            .register_task(sample_task("test_task"))
            .expect("should succeed");
        assert_eq!(learner.task_heads.len(), 1);
        assert_eq!(learner.task_weights.len(), 1);
    }

    #[test]
    fn test_multi_task_config() {
        let config = MultiTaskConfig {
            sharing_type: SharingType::SoftSharing,
            shared_dim: 128,
            enable_curriculum: false,
            ..Default::default()
        };

        assert_eq!(config.sharing_type, SharingType::SoftSharing);
        assert_eq!(config.shared_dim, 128);
        assert!(!config.enable_curriculum);
    }

    #[test]
    fn test_task_relationship() {
        let relationship = TaskRelationship {
            source_task: "task1".to_string(),
            target_task: "task2".to_string(),
            relationship_type: RelationshipType::Complementary,
            strength: 0.8,
            transfer_direction: TransferDirection::Bidirectional,
            discovered_at: Utc::now(),
        };

        assert_eq!(relationship.strength, 0.8);
        assert_eq!(
            relationship.relationship_type,
            RelationshipType::Complementary
        );
    }
}