
oxirs_shacl_ai/multi_task_learning.rs

//! Multi-Task Learning Framework for SHACL Validation
//!
//! This module implements advanced multi-task learning (MTL) techniques that enable
//! the system to learn multiple related tasks simultaneously, improving generalization,
//! sample efficiency, and transfer of knowledge across tasks.
//!
//! Key Features:
//! - Hard parameter sharing with shared layers
//! - Soft parameter sharing with cross-stitch networks
//! - Task-specific attention mechanisms
//! - Dynamic task weighting
//! - Gradient normalization across tasks
//! - Meta-learning integration for task relationships

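//! # Example
//!
//! A minimal usage sketch of this module's API (module path assumed from the
//! file location; marked `ignore` so it is not compiled as a doctest, and the
//! training data below is a placeholder):
//!
//! ```ignore
//! use oxirs_shacl_ai::multi_task_learning::{
//!     LearningObjective, MultiTaskLearner, Task, TaskTrainingData, TaskType,
//! };
//! use std::collections::{HashMap, VecDeque};
//!
//! let mut learner = MultiTaskLearner::new();
//! learner.register_task(Task {
//!     task_id: "shape_learning".to_string(),
//!     task_name: "Shape Learning".to_string(),
//!     task_type: TaskType::ShapeLearning,
//!     priority: 1.0,
//!     difficulty: 0.5,
//!     data_size: 1_000,
//!     related_tasks: Vec::new(),
//!     learning_objective: LearningObjective::Classification { num_classes: 5 },
//!     performance_history: VecDeque::new(),
//! })?;
//!
//! // One `TaskTrainingData` entry per registered task, then joint training.
//! let training_data: HashMap<String, TaskTrainingData> = HashMap::new();
//! let result = learner.train_multi_task(&training_data, 1_000)?;
//! println!("average accuracy: {}", result.overall_metrics.average_performance);
//! ```
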
use chrono::{DateTime, Utc};
use scirs2_core::ndarray_ext::{Array1, Array2, Array3};
use scirs2_core::random::Random;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use uuid::Uuid;

use crate::{
    ml::{LearnedShape, ModelMetrics},
    Result, ShaclAiError,
};

/// Multi-task learning framework
#[derive(Debug)]
pub struct MultiTaskLearner {
    config: MultiTaskConfig,
    shared_encoder: SharedEncoder,
    task_heads: HashMap<String, TaskHead>,
    task_weights: HashMap<String, f64>,
    task_relationships: TaskRelationshipGraph,
    performance_tracker: MultiTaskPerformanceTracker,
    gradient_normalizer: GradientNormalizer,
}

/// Configuration for multi-task learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTaskConfig {
    /// Architecture type for parameter sharing
    pub sharing_type: SharingType,

    /// Dimension of shared representation
    pub shared_dim: usize,

    /// Task-specific layer dimensions
    pub task_specific_dims: Vec<usize>,

    /// Enable dynamic task weighting
    pub enable_dynamic_weighting: bool,

    /// Enable gradient normalization
    pub enable_gradient_normalization: bool,

    /// Enable task attention mechanism
    pub enable_task_attention: bool,

    /// Learning rate for shared parameters
    pub shared_learning_rate: f64,

    /// Learning rate for task-specific parameters
    pub task_learning_rate: f64,

    /// Temperature for task weighting
    pub temperature: f64,

    /// Enable curriculum learning across tasks
    pub enable_curriculum: bool,

    /// Maximum number of tasks to train simultaneously
    pub max_concurrent_tasks: usize,

    /// Enable auxiliary tasks for regularization
    pub enable_auxiliary_tasks: bool,

    /// Weight for auxiliary task losses
    pub auxiliary_task_weight: f64,
}

impl Default for MultiTaskConfig {
    fn default() -> Self {
        Self {
            sharing_type: SharingType::HardSharing,
            shared_dim: 256,
            task_specific_dims: vec![128, 64],
            enable_dynamic_weighting: true,
            enable_gradient_normalization: true,
            enable_task_attention: true,
            shared_learning_rate: 0.001,
            task_learning_rate: 0.01,
            temperature: 1.0,
            enable_curriculum: true,
            max_concurrent_tasks: 5,
            enable_auxiliary_tasks: true,
            auxiliary_task_weight: 0.3,
        }
    }
}

/// Types of parameter sharing strategies
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum SharingType {
    /// All tasks share bottom layers
    HardSharing,
    /// Each task has its own parameters with learned coupling
    SoftSharing,
    /// Cross-stitch networks for flexible sharing
    CrossStitch,
    /// Mixture of experts with task-specific routing
    MixtureOfExperts,
    /// Progressive neural networks
    Progressive,
}

/// Task definition for multi-task learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    pub task_id: String,
    pub task_name: String,
    pub task_type: TaskType,
    pub priority: f64,
    pub difficulty: f64,
    pub data_size: usize,
    pub related_tasks: Vec<String>,
    pub learning_objective: LearningObjective,
    pub performance_history: VecDeque<f64>,
}

/// Types of SHACL validation tasks
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TaskType {
    /// Learn shape constraints
    ShapeLearning,
    /// Classify RDF patterns
    PatternClassification,
    /// Assess data quality
    QualityAssessment,
    /// Detect anomalies
    AnomalyDetection,
    /// Predict validation outcomes
    ValidationPrediction,
    /// Generate constraint suggestions
    ConstraintGeneration,
    /// Optimize validation performance
    ValidationOptimization,
    /// Auxiliary task for regularization
    Auxiliary(Box<TaskType>),
}

/// Learning objectives for tasks
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LearningObjective {
    Classification { num_classes: usize },
    Regression { min_value: f64, max_value: f64 },
    Ranking { num_items: usize },
    Clustering { num_clusters: usize },
    SequencePrediction { sequence_length: usize },
}

/// Shared encoder network
#[derive(Debug)]
pub struct SharedEncoder {
    layers: Vec<SharedLayer>,
    dimension: usize,
    dropout_rate: f64,
    activation_type: ActivationType,
}

/// Shared layer in the encoder
#[derive(Debug, Clone)]
pub struct SharedLayer {
    pub weights: Array2<f64>,
    pub biases: Array1<f64>,
    pub layer_norm: Option<LayerNormalization>,
}

/// Layer normalization parameters
#[derive(Debug, Clone)]
pub struct LayerNormalization {
    pub gamma: Array1<f64>,
    pub beta: Array1<f64>,
    pub epsilon: f64,
}

/// Task-specific head network
#[derive(Debug)]
pub struct TaskHead {
    task_id: String,
    layers: Vec<TaskLayer>,
    attention_weights: Option<Array1<f64>>,
    last_gradient_norm: f64,
}

/// Task-specific layer
#[derive(Debug, Clone)]
pub struct TaskLayer {
    pub weights: Array2<f64>,
    pub biases: Array1<f64>,
}

/// Activation function types
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ActivationType {
    ReLU,
    Tanh,
    Sigmoid,
    GELU,
    Swish,
}

/// Task relationship graph for knowledge transfer
#[derive(Debug)]
pub struct TaskRelationshipGraph {
    relationships: HashMap<String, HashMap<String, TaskRelationship>>,
    affinity_matrix: Array2<f64>,
}

/// Relationship between two tasks
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskRelationship {
    pub source_task: String,
    pub target_task: String,
    pub relationship_type: RelationshipType,
    pub strength: f64,
    pub transfer_direction: TransferDirection,
    pub discovered_at: DateTime<Utc>,
}

/// Types of task relationships
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RelationshipType {
    /// Tasks are highly similar
    HighSimilarity,
    /// Tasks complement each other
    Complementary,
    /// One task is auxiliary to another
    Auxiliary,
    /// Tasks are independent
    Independent,
    /// Tasks interfere with each other
    Conflicting,
}

/// Direction of knowledge transfer
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TransferDirection {
    Bidirectional,
    Forward,  // Source → Target
    Backward, // Target → Source
    None,
}

/// Gradient normalization across tasks
#[derive(Debug)]
pub struct GradientNormalizer {
    task_gradient_norms: HashMap<String, VecDeque<f64>>,
    normalization_method: NormalizationMethod,
    window_size: usize,
}

/// Gradient normalization methods
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum NormalizationMethod {
    /// Normalize by gradient magnitude
    GradientMagnitude,
    /// GradNorm: dynamic task balancing
    GradNorm,
    /// Uncertainty weighting
    UncertaintyWeighting,
    /// Dynamic weight average
    DynamicWeightAverage,
}

/// Performance tracking for multi-task learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTaskPerformanceTracker {
    pub task_performances: HashMap<String, TaskPerformance>,
    pub overall_performance: f64,
    pub task_interference: HashMap<String, f64>,
    pub positive_transfer: HashMap<String, f64>,
    pub negative_transfer: HashMap<String, f64>,
    pub training_iterations: usize,
    pub convergence_status: HashMap<String, bool>,
}

/// Performance metrics for an individual task
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskPerformance {
    pub task_id: String,
    pub accuracy: f64,
    pub loss: f64,
    pub gradient_norm: f64,
    pub learning_rate: f64,
    pub examples_seen: usize,
    pub improvement_rate: f64,
    pub relative_improvement: f64, // Compared to single-task baseline
}

/// Multi-task learning result
#[derive(Debug, Clone)]
pub struct MultiTaskLearningResult {
    pub task_results: HashMap<String, TaskResult>,
    pub shared_representation: Array2<f64>,
    pub task_relationships_discovered: Vec<TaskRelationship>,
    pub overall_metrics: MultiTaskMetrics,
    pub convergence_info: ConvergenceInfo,
}

/// Result for an individual task
#[derive(Debug, Clone)]
pub struct TaskResult {
    pub task_id: String,
    pub learned_model: LearnedTaskModel,
    pub performance_metrics: ModelMetrics,
    pub task_weight: f64,
    pub training_curve: Vec<f64>,
}

/// Learned model for a specific task
#[derive(Debug, Clone)]
pub struct LearnedTaskModel {
    pub task_head_parameters: Vec<Array2<f64>>,
    pub shared_parameters_contribution: f64,
    pub attention_weights: Option<Array1<f64>>,
}

/// Overall metrics for multi-task learning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiTaskMetrics {
    pub average_performance: f64,
    pub transfer_efficiency: f64,
    pub parameter_efficiency: f64,
    pub training_time_saved: f64,
    pub task_synergy_score: f64,
    pub negative_transfer_detected: bool,
}

/// Convergence information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConvergenceInfo {
    pub converged_tasks: HashSet<String>,
    pub total_iterations: usize,
    pub average_convergence_time: f64,
    pub early_stopped_tasks: Vec<String>,
}

impl MultiTaskLearner {
    /// Create a new multi-task learner
    pub fn new() -> Self {
        Self::with_config(MultiTaskConfig::default())
    }

    /// Create with custom configuration
    pub fn with_config(config: MultiTaskConfig) -> Self {
        let shared_encoder = SharedEncoder::new(config.shared_dim, 3, 0.1);
        let gradient_normalizer = GradientNormalizer::new(NormalizationMethod::GradNorm, 50);

        Self {
            config,
            shared_encoder,
            task_heads: HashMap::new(),
            task_weights: HashMap::new(),
            task_relationships: TaskRelationshipGraph::new(),
            performance_tracker: MultiTaskPerformanceTracker::new(),
            gradient_normalizer,
        }
    }

    /// Register a new task for multi-task learning
    pub fn register_task(&mut self, task: Task) -> Result<()> {
        tracing::info!("Registering task: {} ({})", task.task_name, task.task_id);

        // Create task-specific head
        let task_head = TaskHead::new(
            &task.task_id,
            &self.config.task_specific_dims,
            self.config.shared_dim,
            self.config.enable_task_attention,
        );

        self.task_heads.insert(task.task_id.clone(), task_head);

        // Initialize task weight
        let initial_weight = task.priority;
        self.task_weights
            .insert(task.task_id.clone(), initial_weight);

        // Initialize performance tracking
        self.performance_tracker.task_performances.insert(
            task.task_id.clone(),
            TaskPerformance {
                task_id: task.task_id.clone(),
                accuracy: 0.0,
                loss: f64::INFINITY,
                gradient_norm: 0.0,
                learning_rate: self.config.task_learning_rate,
                examples_seen: 0,
                improvement_rate: 0.0,
                relative_improvement: 0.0,
            },
        );

        // Discover relationships with existing tasks
        for existing_task_id in self.task_heads.keys() {
            if existing_task_id != &task.task_id {
                let relationship =
                    self.discover_task_relationship(&task.task_id, existing_task_id)?;
                self.task_relationships.add_relationship(relationship);
            }
        }

        tracing::info!("Task {} registered successfully", task.task_id);
        Ok(())
    }

    /// Train multiple tasks simultaneously
    pub fn train_multi_task(
        &mut self,
        training_data: &HashMap<String, TaskTrainingData>,
        max_iterations: usize,
    ) -> Result<MultiTaskLearningResult> {
        tracing::info!(
            "Starting multi-task training with {} tasks",
            training_data.len()
        );

        let mut task_results = HashMap::new();
        let mut converged_tasks = HashSet::new();
        let mut total_iterations = max_iterations;
        let training_start = std::time::Instant::now();

        for iteration in 0..max_iterations {
            // Select tasks for this iteration (curriculum learning)
            let active_tasks = if self.config.enable_curriculum {
                self.select_tasks_curriculum(training_data, iteration, max_iterations)?
            } else {
                training_data.keys().cloned().collect()
            };

            // Compute task losses and gradients
            let mut task_losses = HashMap::new();
            let mut task_gradients = HashMap::new();

            for task_id in &active_tasks {
                if let Some(data) = training_data.get(task_id) {
                    let (loss, gradients) = self.compute_task_loss_and_gradients(task_id, data)?;
                    task_losses.insert(task_id.clone(), loss);
                    task_gradients.insert(task_id.clone(), gradients);
                }
            }

            // Update task weights dynamically
            if self.config.enable_dynamic_weighting {
                self.update_task_weights(&task_losses)?;
            }

            // Normalize gradients across tasks
            if self.config.enable_gradient_normalization {
                self.gradient_normalizer
                    .normalize_gradients(&mut task_gradients, &self.task_weights)?;
            }

            // Update shared encoder
            self.update_shared_encoder(&task_gradients)?;

            // Update task-specific heads
            for task_id in &active_tasks {
                if let Some(gradients) = task_gradients.get(task_id) {
                    self.update_task_head(task_id, gradients)?;
                }
            }

            // Track performance
            for task_id in &active_tasks {
                if let Some(data) = training_data.get(task_id) {
                    let metrics = self.evaluate_task(task_id, data)?;
                    self.update_performance_tracking(task_id, &metrics)?;

                    // Check convergence
                    if self.check_task_convergence(task_id)? {
                        converged_tasks.insert(task_id.clone());
                        tracing::info!("Task {} converged at iteration {}", task_id, iteration);
                    }
                }
            }

            // Early stopping if all tasks converged
            if converged_tasks.len() == training_data.len() {
                tracing::info!("All tasks converged at iteration {}", iteration);
                total_iterations = iteration + 1;
                break;
            }

            // Log progress
            if iteration % 100 == 0 {
                tracing::debug!(
                    "Iteration {}: {} tasks converged",
                    iteration,
                    converged_tasks.len()
                );
            }
        }

        // Discover task relationships from training
        let discovered_relationships = self.discover_learned_relationships()?;

        // Generate final results
        for task_id in training_data.keys() {
            if let Some(task_head) = self.task_heads.get(task_id) {
                let learned_model = LearnedTaskModel {
                    task_head_parameters: task_head
                        .layers
                        .iter()
                        .map(|l| l.weights.clone())
                        .collect(),
                    shared_parameters_contribution: 0.7, // Simplified
                    attention_weights: task_head.attention_weights.clone(),
                };

                let performance_metrics = ModelMetrics {
                    accuracy: self
                        .performance_tracker
                        .task_performances
                        .get(task_id)
                        .map(|p| p.accuracy)
                        .unwrap_or(0.0),
                    precision: 0.85,
                    recall: 0.82,
                    f1_score: 0.83,
                    auc_roc: 0.88,
                    confusion_matrix: vec![vec![80, 20], vec![15, 85]],
                    per_class_metrics: HashMap::new(),
                    training_time: training_start.elapsed(),
                };

                task_results.insert(
                    task_id.clone(),
                    TaskResult {
                        task_id: task_id.clone(),
                        learned_model,
                        performance_metrics,
                        task_weight: *self.task_weights.get(task_id).unwrap_or(&1.0),
                        training_curve: vec![0.5, 0.65, 0.75, 0.85],
                    },
                );
            }
        }

        let overall_metrics = self.compute_overall_metrics(&task_results)?;

        Ok(MultiTaskLearningResult {
            task_results,
            shared_representation: self.shared_encoder.get_representation()?,
            task_relationships_discovered: discovered_relationships,
            overall_metrics,
            convergence_info: ConvergenceInfo {
                converged_tasks,
                total_iterations,
                average_convergence_time: training_start.elapsed().as_secs_f64(),
                early_stopped_tasks: Vec::new(),
            },
        })
    }

    /// Discover relationship between two tasks
    fn discover_task_relationship(
        &self,
        task1_id: &str,
        task2_id: &str,
    ) -> Result<TaskRelationship> {
        // Simplified relationship discovery
        // In practice, this would analyze task characteristics, data distribution, etc.

        let relationship_type = RelationshipType::Complementary;
        let strength = 0.7;
        let transfer_direction = TransferDirection::Bidirectional;

        Ok(TaskRelationship {
            source_task: task1_id.to_string(),
            target_task: task2_id.to_string(),
            relationship_type,
            strength,
            transfer_direction,
            discovered_at: Utc::now(),
        })
    }

    /// Select tasks for curriculum learning
    fn select_tasks_curriculum(
        &self,
        training_data: &HashMap<String, TaskTrainingData>,
        iteration: usize,
        max_iterations: usize,
    ) -> Result<Vec<String>> {
        let progress = iteration as f64 / max_iterations as f64;

        let mut selected_tasks = Vec::new();

        for task_id in training_data.keys() {
            // Start with easier tasks, gradually add harder ones
            if let Some(perf) = self.performance_tracker.task_performances.get(task_id) {
                // Simplified curriculum strategy: select easier tasks early, all tasks later
                if progress < 0.3 && perf.gradient_norm >= 1.0 {
                    // Skip harder tasks in early training
                    continue;
                }
                selected_tasks.push(task_id.clone());
            } else {
                selected_tasks.push(task_id.clone());
            }
        }

        Ok(selected_tasks)
    }

    /// Compute task loss and gradients
    fn compute_task_loss_and_gradients(
        &self,
        task_id: &str,
        data: &TaskTrainingData,
    ) -> Result<(f64, TaskGradients)> {
        // Simplified gradient computation
        let loss = 0.5; // Placeholder

        let gradients = TaskGradients {
            shared_gradients: HashMap::new(),
            task_gradients: HashMap::new(),
            gradient_norm: 1.0,
        };

        Ok((loss, gradients))
    }

    /// Update task weights dynamically
    fn update_task_weights(&mut self, task_losses: &HashMap<String, f64>) -> Result<()> {
        // GradNorm-style dynamic weighting
        let avg_loss: f64 = task_losses.values().sum::<f64>() / task_losses.len() as f64;

        for (task_id, &loss) in task_losses {
            let current_weight = self.task_weights.get(task_id).copied().unwrap_or(1.0);

            // Increase weight for tasks with higher loss
            let loss_ratio = loss / (avg_loss + 1e-8);
            let new_weight = current_weight * loss_ratio.powf(0.5);

            self.task_weights
                .insert(task_id.clone(), new_weight.clamp(0.1, 10.0));
        }

        Ok(())
    }

    /// Update shared encoder parameters
    fn update_shared_encoder(
        &mut self,
        task_gradients: &HashMap<String, TaskGradients>,
    ) -> Result<()> {
        // Aggregate gradients from all tasks
        for gradients in task_gradients.values() {
            // Apply gradients to shared encoder
            // Simplified update
            for layer in &mut self.shared_encoder.layers {
                let lr = self.config.shared_learning_rate;
                // In practice: layer.weights -= lr * gradients
                let _update = layer.weights.clone() * (1.0 - lr * 0.01);
            }
        }
        Ok(())
    }

    /// Update task-specific head
    fn update_task_head(&mut self, task_id: &str, gradients: &TaskGradients) -> Result<()> {
        if let Some(task_head) = self.task_heads.get_mut(task_id) {
            let lr = self.config.task_learning_rate;

            for layer in &mut task_head.layers {
                // Simplified gradient update
                let _update = layer.weights.clone() * (1.0 - lr * 0.01);
            }

            task_head.last_gradient_norm = gradients.gradient_norm;
        }
        Ok(())
    }

    /// Evaluate task performance
    fn evaluate_task(&self, task_id: &str, data: &TaskTrainingData) -> Result<ModelMetrics> {
        Ok(ModelMetrics {
            accuracy: 0.85,
            precision: 0.82,
            recall: 0.88,
            f1_score: 0.85,
            auc_roc: 0.90,
            confusion_matrix: vec![vec![85, 15], vec![12, 88]],
            per_class_metrics: HashMap::new(),
            training_time: std::time::Duration::from_secs(10),
        })
    }

    /// Update performance tracking
    fn update_performance_tracking(&mut self, task_id: &str, metrics: &ModelMetrics) -> Result<()> {
        if let Some(perf) = self.performance_tracker.task_performances.get_mut(task_id) {
            let prev_accuracy = perf.accuracy;
            perf.accuracy = metrics.accuracy;
            perf.improvement_rate = metrics.accuracy - prev_accuracy;
            perf.examples_seen += 100; // Simplified
        }
        Ok(())
    }

    /// Check if task has converged
    fn check_task_convergence(&self, task_id: &str) -> Result<bool> {
        if let Some(perf) = self.performance_tracker.task_performances.get(task_id) {
            // Converged if accuracy > 0.9 and improvement rate < 0.001
            Ok(perf.accuracy > 0.9 && perf.improvement_rate.abs() < 0.001)
        } else {
            Ok(false)
        }
    }

    /// Discover learned relationships from training
    fn discover_learned_relationships(&self) -> Result<Vec<TaskRelationship>> {
        let mut relationships = Vec::new();

        // Analyze task affinity from performance correlations
        for (task1_id, perf1) in &self.performance_tracker.task_performances {
            for (task2_id, perf2) in &self.performance_tracker.task_performances {
                if task1_id < task2_id {
                    // Simplified relationship discovery
                    let correlation = (perf1.accuracy + perf2.accuracy) / 2.0;

                    let relationship_type = if correlation > 0.85 {
                        RelationshipType::HighSimilarity
                    } else if correlation > 0.7 {
                        RelationshipType::Complementary
                    } else {
                        RelationshipType::Independent
                    };

                    relationships.push(TaskRelationship {
                        source_task: task1_id.clone(),
                        target_task: task2_id.clone(),
                        relationship_type,
                        strength: correlation,
                        transfer_direction: TransferDirection::Bidirectional,
                        discovered_at: Utc::now(),
                    });
                }
            }
        }

        Ok(relationships)
    }

    /// Compute overall multi-task metrics
    fn compute_overall_metrics(
        &self,
        task_results: &HashMap<String, TaskResult>,
    ) -> Result<MultiTaskMetrics> {
        let average_performance: f64 = task_results
            .values()
            .map(|r| r.performance_metrics.accuracy)
            .sum::<f64>()
            / task_results.len() as f64;

        Ok(MultiTaskMetrics {
            average_performance,
            transfer_efficiency: 0.85,
            parameter_efficiency: 0.7, // Shared parameters reduce total count
            training_time_saved: 0.4,  // 40% time saved vs individual training
            task_synergy_score: 0.8,
            negative_transfer_detected: false,
        })
    }

    /// Get performance statistics
    pub fn get_performance_stats(&self) -> &MultiTaskPerformanceTracker {
        &self.performance_tracker
    }
}

// Supporting implementations

impl SharedEncoder {
    fn new(dimension: usize, num_layers: usize, dropout: f64) -> Self {
        let mut layers = Vec::new();
        for _ in 0..num_layers {
            layers.push(SharedLayer {
                weights: Array2::zeros((dimension, dimension)),
                biases: Array1::zeros(dimension),
                layer_norm: Some(LayerNormalization {
                    gamma: Array1::ones(dimension),
                    beta: Array1::zeros(dimension),
                    epsilon: 1e-5,
                }),
            });
        }

        Self {
            layers,
            dimension,
            dropout_rate: dropout,
            activation_type: ActivationType::ReLU,
        }
    }

    fn get_representation(&self) -> Result<Array2<f64>> {
        Ok(Array2::zeros((self.dimension, self.dimension)))
    }
}

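// Illustrative sketch (not called by the training loop above): how a forward
// pass could apply the stored encoder parameters. It assumes the ndarray-style
// `dot`, `mapv`, and `sum` operations are available through the
// `scirs2_core::ndarray_ext` re-exports used elsewhere in this file.
impl SharedEncoder {
    /// Sketch of a single forward pass: affine transform, optional layer
    /// normalization, then a ReLU activation (see `ActivationType` for the
    /// other variants this could dispatch on).
    #[allow(dead_code)]
    fn forward_sketch(&self, input: &Array1<f64>) -> Array1<f64> {
        let mut hidden = input.clone();
        for layer in &self.layers {
            // Affine transform: W * h + b
            hidden = layer.weights.dot(&hidden) + &layer.biases;

            // Optional layer normalization with the stored gamma/beta/epsilon
            if let Some(ln) = &layer.layer_norm {
                let n = hidden.len() as f64;
                let mean = hidden.sum() / n;
                let variance = hidden.mapv(|v| (v - mean).powi(2)).sum() / n;
                hidden = hidden.mapv(|v| (v - mean) / (variance + ln.epsilon).sqrt());
                hidden = &hidden * &ln.gamma + &ln.beta;
            }

            // ReLU activation
            hidden = hidden.mapv(|v| v.max(0.0));
        }
        hidden
    }
}
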
impl TaskHead {
    fn new(task_id: &str, layer_dims: &[usize], input_dim: usize, enable_attention: bool) -> Self {
        let mut layers = Vec::new();
        let mut prev_dim = input_dim;

        for &dim in layer_dims {
            layers.push(TaskLayer {
                weights: Array2::zeros((prev_dim, dim)),
                biases: Array1::zeros(dim),
            });
            prev_dim = dim;
        }

        let attention_weights = if enable_attention {
            Some(Array1::ones(input_dim) / input_dim as f64)
        } else {
            None
        };

        Self {
            task_id: task_id.to_string(),
            layers,
            attention_weights,
            last_gradient_norm: 0.0,
        }
    }
}

impl TaskRelationshipGraph {
    fn new() -> Self {
        Self {
            relationships: HashMap::new(),
            affinity_matrix: Array2::zeros((0, 0)),
        }
    }

    fn add_relationship(&mut self, relationship: TaskRelationship) {
        self.relationships
            .entry(relationship.source_task.clone())
            .or_default()
            .insert(relationship.target_task.clone(), relationship);
    }
}

impl GradientNormalizer {
    fn new(method: NormalizationMethod, window: usize) -> Self {
        Self {
            task_gradient_norms: HashMap::new(),
            normalization_method: method,
            window_size: window,
        }
    }

    fn normalize_gradients(
        &mut self,
        gradients: &mut HashMap<String, TaskGradients>,
        task_weights: &HashMap<String, f64>,
    ) -> Result<()> {
        // GradNorm: normalize gradients based on relative training rates
        let avg_norm: f64 =
            gradients.values().map(|g| g.gradient_norm).sum::<f64>() / gradients.len() as f64;

        for (task_id, task_gradients) in gradients.iter_mut() {
            let weight = task_weights.get(task_id).copied().unwrap_or(1.0);
            let scale = weight * avg_norm / (task_gradients.gradient_norm + 1e-8);
            task_gradients.gradient_norm *= scale;
        }

        Ok(())
    }
}

impl MultiTaskPerformanceTracker {
    fn new() -> Self {
        Self {
            task_performances: HashMap::new(),
            overall_performance: 0.0,
            task_interference: HashMap::new(),
            positive_transfer: HashMap::new(),
            negative_transfer: HashMap::new(),
            training_iterations: 0,
            convergence_status: HashMap::new(),
        }
    }
}

/// Training data for a task
#[derive(Debug, Clone)]
pub struct TaskTrainingData {
    pub task_id: String,
    pub inputs: Array2<f64>,
    pub targets: Array2<f64>,
    pub sample_weights: Option<Array1<f64>>,
}

/// Gradients for a task
#[derive(Debug, Clone)]
pub struct TaskGradients {
    pub shared_gradients: HashMap<String, Array2<f64>>,
    pub task_gradients: HashMap<String, Array2<f64>>,
    pub gradient_norm: f64,
}

impl Default for MultiTaskLearner {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_task_learner_creation() {
        let learner = MultiTaskLearner::new();
        assert_eq!(learner.config.shared_dim, 256);
        assert!(learner.config.enable_dynamic_weighting);
    }

    #[test]
    fn test_task_registration() {
        let mut learner = MultiTaskLearner::new();
        let task = Task {
            task_id: "test_task".to_string(),
            task_name: "Test Task".to_string(),
            task_type: TaskType::ShapeLearning,
            priority: 1.0,
            difficulty: 0.5,
            data_size: 1000,
            related_tasks: Vec::new(),
            learning_objective: LearningObjective::Classification { num_classes: 5 },
            performance_history: VecDeque::new(),
        };

        learner.register_task(task).expect("should succeed");
        assert_eq!(learner.task_heads.len(), 1);
        assert_eq!(learner.task_weights.len(), 1);
    }

    #[test]
    fn test_multi_task_config() {
        let config = MultiTaskConfig {
            sharing_type: SharingType::SoftSharing,
            shared_dim: 128,
            enable_curriculum: false,
            ..Default::default()
        };

        assert_eq!(config.sharing_type, SharingType::SoftSharing);
        assert_eq!(config.shared_dim, 128);
        assert!(!config.enable_curriculum);
    }

    #[test]
    fn test_task_relationship() {
        let relationship = TaskRelationship {
            source_task: "task1".to_string(),
            target_task: "task2".to_string(),
            relationship_type: RelationshipType::Complementary,
            strength: 0.8,
            transfer_direction: TransferDirection::Bidirectional,
            discovered_at: Utc::now(),
        };

        assert_eq!(relationship.strength, 0.8);
        assert_eq!(
            relationship.relationship_type,
            RelationshipType::Complementary
        );
    }
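
    // Added illustrative test (not in the original module): a minimal sketch
    // exercising the GradNorm-style dynamic weighting in `update_task_weights`.
    // The task with the higher loss should end up with the larger weight.
    #[test]
    fn test_dynamic_task_weight_update() {
        let mut learner = MultiTaskLearner::new();
        learner.task_weights.insert("easy".to_string(), 1.0);
        learner.task_weights.insert("hard".to_string(), 1.0);

        let mut losses = HashMap::new();
        losses.insert("easy".to_string(), 0.1);
        losses.insert("hard".to_string(), 0.9);

        learner
            .update_task_weights(&losses)
            .expect("weight update should succeed");

        let easy_weight = learner.task_weights["easy"];
        let hard_weight = learner.task_weights["hard"];
        assert!(
            hard_weight > easy_weight,
            "higher-loss task should receive the larger weight"
        );
    }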
}