sklears_multioutput/regularization.rs

//! Multi-Task Regularization Methods
//!
//! This module provides regularization techniques designed for multi-task and
//! multi-output learning scenarios. These methods help learn structure that is
//! shared across tasks while preventing overfitting.
//!
//! The module has been refactored into smaller submodules to comply with the 2000-line limit:
//!
//! - [`simd_ops`] - SIMD-accelerated operations for high-performance regularization computations
//! - [`group_lasso`] - Group Lasso regularization for feature group selection
//! - [`nuclear_norm`] - Nuclear norm regularization for low-rank structure learning
//! - [`task_clustering`] - Task clustering regularization for grouping similar tasks
//! - [`task_relationship`] - Task relationship learning for explicit task relationships
//! - [`meta_learning`] - Meta-learning approach for quick adaptation to new tasks
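//!
//! # Example
//!
//! A minimal usage sketch for the re-exported [`GroupLasso`] estimator. The
//! doctest is marked `ignore` because the crate path shown here is assumed
//! from the workspace layout; the builder and trait calls mirror the tests at
//! the bottom of this file.
//!
//! ```ignore
//! use std::collections::HashMap;
//! use scirs2_core::ndarray::array;
//! use sklears_core::traits::{Fit, Predict};
//! use sklears_multioutput::regularization::GroupLasso;
//!
//! // Four samples with four features split into two feature groups.
//! let x = array![
//!     [1.0, 2.0, 3.0, 4.0],
//!     [2.0, 3.0, 4.0, 5.0],
//!     [3.0, 1.0, 2.0, 3.0],
//!     [4.0, 2.0, 1.0, 2.0]
//! ];
//!
//! // One target matrix (n_samples x n_outputs) per named task.
//! let mut y_tasks = HashMap::new();
//! y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
//!
//! let model = GroupLasso::new()
//!     .alpha(0.01)
//!     .feature_groups(vec![vec![0, 1], vec![2, 3]])
//!     .task_outputs(&[("task1", 1)])
//!     .max_iter(50);
//!
//! let trained = model.fit(&x.view(), &y_tasks).unwrap();
//! let predictions = trained.predict(&x.view()).unwrap();
//! assert!(predictions.contains_key("task1"));
//! ```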

// Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
use scirs2_core::ndarray::{Array1, Array2};
use sklears_core::{traits::Untrained, types::Float};
use std::collections::HashMap;

// Submodules
#[path = "regularization/simd_ops.rs"]
pub mod simd_ops;

#[path = "regularization/group_lasso.rs"]
pub mod group_lasso;

#[path = "regularization/nuclear_norm.rs"]
pub mod nuclear_norm;

#[path = "regularization/task_clustering.rs"]
pub mod task_clustering;

#[path = "regularization/task_relationship.rs"]
pub mod task_relationship;

#[path = "regularization/meta_learning.rs"]
pub mod meta_learning;

// Re-export the main types from submodules
pub use group_lasso::{GroupLasso, GroupLassoTrained};
pub use meta_learning::{MetaLearningMultiTask, MetaLearningMultiTaskTrained};
pub use nuclear_norm::{NuclearNormRegression, NuclearNormRegressionTrained};
pub use task_clustering::{TaskClusteringRegressionTrained, TaskClusteringRegularization};
pub use task_relationship::{
    TaskRelationshipLearning, TaskRelationshipLearningTrained, TaskSimilarityMethod,
};

/// Multi-Task Elastic Net with Group Structure
///
/// Combines L1 and L2 regularization with awareness of group structure.
/// Useful when both individual-feature sparsity and group-level selection are desired.
///
/// Note: This is a placeholder struct - implementation is not yet complete.
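///
/// The intended composite penalty is sketched below, following the usual
/// elastic-net parameterization (an assumption; the placeholder implementation
/// may settle on a different formulation):
///
/// ```text
/// alpha * (l1_ratio * ||W||_1 + 0.5 * (1 - l1_ratio) * ||W||_F^2)
///     + group_alpha * sum over groups g of ||W_g||_2
/// ```
///
/// where `W` stacks the per-task coefficients and `W_g` is the sub-matrix of
/// coefficients belonging to the features in group `g` of `feature_groups`.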
#[derive(Debug, Clone)]
pub struct MultiTaskElasticNet<S = Untrained> {
    state: S,
    /// Overall regularization strength (scales both the L1 and L2 terms)
    alpha: Float,
    /// L1 vs L2 balance (0 = Ridge, 1 = Lasso)
    l1_ratio: Float,
    /// Feature groups for group penalties
    feature_groups: Vec<Vec<usize>>,
    /// Group penalty strength
    group_alpha: Float,
    /// Maximum number of iterations
    max_iter: usize,
    /// Convergence tolerance
    tolerance: Float,
    /// Learning rate
    learning_rate: Float,
    /// Task configurations
    task_outputs: HashMap<String, usize>,
    /// Include intercept term
    fit_intercept: bool,
}

/// Trained state for MultiTaskElasticNet
///
/// Note: This is a placeholder struct - implementation is not yet complete.
#[derive(Debug, Clone)]
pub struct MultiTaskElasticNetTrained {
    /// Coefficients for each task
    coefficients: HashMap<String, Array2<Float>>,
    /// Intercepts for each task
    intercepts: HashMap<String, Array1<Float>>,
    /// Number of input features
    n_features: usize,
    /// Task configurations
    task_outputs: HashMap<String, usize>,
    /// Training parameters
    alpha: Float,
    l1_ratio: Float,
    group_alpha: Float,
    /// Training iterations performed
    n_iter: usize,
}

/// Regularization strategies for multi-task learning
#[derive(Debug, Clone, PartialEq, Default)]
pub enum RegularizationStrategy {
    /// No regularization
    #[default]
    None,
    /// L1 regularization (Lasso)
    L1(Float),
    /// L2 regularization (Ridge)
    L2(Float),
    /// Elastic Net (L1 + L2)
    ElasticNet { alpha: Float, l1_ratio: Float },
    /// Group Lasso
    GroupLasso { alpha: Float },
    /// Nuclear norm regularization
    NuclearNorm { alpha: Float },
    /// Task clustering regularization
    TaskClustering {
        n_clusters: usize,
        intra_cluster_alpha: Float,
        inter_cluster_alpha: Float,
    },
    /// Task relationship learning
    TaskRelationship {
        relationship_strength: Float,
        similarity_threshold: Float,
    },
    /// Meta-learning for multi-task
    MetaLearning {
        meta_learning_rate: Float,
        inner_learning_rate: Float,
        n_inner_steps: usize,
    },
}

// Keep the tests in the main module for backwards compatibility
#[allow(non_snake_case)]
#[cfg(test)]
mod regularization_tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    // Use SciRS2-Core for arrays and random number generation (SciRS2 Policy)
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Predict};
    use std::collections::HashMap;

    #[test]
    fn test_group_lasso_creation() {
        let group_lasso = GroupLasso::new()
            .alpha(0.1)
            .feature_groups(vec![vec![0, 1], vec![2, 3]])
            .max_iter(100)
            .tolerance(1e-6)
            .learning_rate(0.01);

        assert_eq!(group_lasso.alpha, 0.1);
        assert_eq!(group_lasso.feature_groups, vec![vec![0, 1], vec![2, 3]]);
        assert_eq!(group_lasso.max_iter, 100);
        assert_abs_diff_eq!(group_lasso.tolerance, 1e-6);
        assert_abs_diff_eq!(group_lasso.learning_rate, 0.01);
    }

    #[test]
    fn test_group_lasso_fit_predict() {
        let X = array![
            [1.0, 2.0, 3.0, 4.0],
            [2.0, 3.0, 4.0, 5.0],
            [3.0, 1.0, 2.0, 3.0],
            [4.0, 2.0, 1.0, 2.0]
        ];

        let mut y_tasks = HashMap::new();
        y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
        y_tasks.insert("task2".to_string(), array![[0.5], [1.0], [0.8], [1.2]]);

        let feature_groups = vec![vec![0, 1], vec![2, 3]];

        let group_lasso = GroupLasso::new()
            .alpha(0.01)
            .feature_groups(feature_groups)
            .task_outputs(&[("task1", 1), ("task2", 1)])
            .max_iter(50)
            .tolerance(1e-4)
            .learning_rate(0.01);

        let trained = group_lasso.fit(&X.view(), &y_tasks).unwrap();

        // Test predictions
        let predictions = trained.predict(&X.view()).unwrap();
        assert!(predictions.contains_key("task1"));
        assert!(predictions.contains_key("task2"));

        let task1_pred = &predictions["task1"];
        let task2_pred = &predictions["task2"];

        assert_eq!(task1_pred.shape(), &[4, 1]);
        assert_eq!(task2_pred.shape(), &[4, 1]);

        // Test group sparsity
        let sparsity = trained.group_sparsity();
        assert!(sparsity >= 0.0 && sparsity <= 1.0); // Should be a fraction in [0, 1]

        // Test accessors
        assert!(trained.task_coefficients("task1").is_some());
        assert!(trained.task_intercepts("task1").is_some());
        assert!(trained.n_iter() <= 50);
    }

    #[test]
    fn test_nuclear_norm_regression_creation() {
        let nuclear_norm = NuclearNormRegression::new()
            .alpha(0.1)
            .max_iter(100)
            .tolerance(1e-6)
            .learning_rate(0.01)
            .target_rank(Some(5));

        assert_eq!(nuclear_norm.alpha, 0.1);
        assert_eq!(nuclear_norm.max_iter, 100);
        assert_abs_diff_eq!(nuclear_norm.tolerance, 1e-6);
        assert_abs_diff_eq!(nuclear_norm.learning_rate, 0.01);
        assert_eq!(nuclear_norm.target_rank, Some(5));
    }

    #[test]
    fn test_nuclear_norm_regression_fit_predict() {
        let X = array![[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [4.0, 4.0]];

        let mut y_tasks = HashMap::new();
        y_tasks.insert("task1".to_string(), array![[1.0], [2.0], [1.5], [2.5]]);
        y_tasks.insert("task2".to_string(), array![[0.5], [1.0], [0.8], [1.2]]);

        let nuclear_norm = NuclearNormRegression::new()
            .alpha(0.01)
            .task_outputs(&[("task1", 1), ("task2", 1)])
            .max_iter(50)
            .tolerance(1e-4)
            .learning_rate(0.01);

        let trained = nuclear_norm.fit(&X.view(), &y_tasks).unwrap();

        // Test predictions
        let predictions = trained.predict(&X.view()).unwrap();
        assert!(predictions.contains_key("task1"));
        assert!(predictions.contains_key("task2"));

        let task1_pred = &predictions["task1"];
        let task2_pred = &predictions["task2"];

        assert_eq!(task1_pred.shape(), &[4, 1]);
        assert_eq!(task2_pred.shape(), &[4, 1]);

        // Test accessors
        assert!(trained.task_coefficient_matrix("task1").is_some());
        assert!(trained.effective_rank() >= 0);
        assert!(!trained.singular_values().is_empty());
        assert!(trained.n_iter() <= 50);
    }

    #[test]
    fn test_regularization_strategies() {
        let strategies = vec![
            RegularizationStrategy::None,
            RegularizationStrategy::L1(0.1),
            RegularizationStrategy::L2(0.1),
            RegularizationStrategy::ElasticNet {
                alpha: 0.1,
                l1_ratio: 0.5,
            },
            RegularizationStrategy::GroupLasso { alpha: 0.1 },
            RegularizationStrategy::NuclearNorm { alpha: 0.1 },
            RegularizationStrategy::TaskClustering {
                n_clusters: 2,
                intra_cluster_alpha: 0.1,
                inter_cluster_alpha: 0.01,
            },
            RegularizationStrategy::TaskRelationship {
                relationship_strength: 0.1,
                similarity_threshold: 0.5,
            },
            RegularizationStrategy::MetaLearning {
                meta_learning_rate: 0.01,
                inner_learning_rate: 0.1,
                n_inner_steps: 5,
            },
        ];

        assert_eq!(strategies.len(), 9);
        assert_eq!(strategies[0], RegularizationStrategy::None);
        assert_eq!(strategies[1], RegularizationStrategy::L1(0.1));
    }

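    // Illustrative sketch, not part of the original test suite: shows how a
    // caller might map the element-wise `RegularizationStrategy` variants to a
    // scalar penalty for a single coefficient vector. The `penalty` closure is
    // a local assumption for this test; the module itself does not expose such
    // a helper.
    #[test]
    fn test_regularization_strategy_penalty_sketch() {
        // Toy coefficient vector chosen so the penalties are exact in floating
        // point: L1 norm = 6, squared L2 norm = 14.
        let w: [Float; 4] = [1.0, -2.0, 0.0, 3.0];
        let l1: Float = w.iter().map(|x| x.abs()).sum();
        let l2_sq: Float = w.iter().map(|x| x * x).sum();

        // Standard penalty formulas; the structured strategies need task-level
        // information, so this sketch returns `None` for them.
        let penalty = |strategy: &RegularizationStrategy| -> Option<Float> {
            match strategy {
                RegularizationStrategy::None => Some(0.0),
                RegularizationStrategy::L1(alpha) => Some(alpha * l1),
                RegularizationStrategy::L2(alpha) => Some(alpha * l2_sq),
                RegularizationStrategy::ElasticNet { alpha, l1_ratio } => {
                    Some(alpha * (l1_ratio * l1 + 0.5 * (1.0 - l1_ratio) * l2_sq))
                }
                _ => None,
            }
        };

        assert_abs_diff_eq!(penalty(&RegularizationStrategy::None).unwrap(), 0.0);
        assert_abs_diff_eq!(penalty(&RegularizationStrategy::L1(0.5)).unwrap(), 3.0);
        assert_abs_diff_eq!(penalty(&RegularizationStrategy::L2(0.5)).unwrap(), 7.0);
        assert_abs_diff_eq!(
            penalty(&RegularizationStrategy::ElasticNet {
                alpha: 0.5,
                l1_ratio: 0.5
            })
            .unwrap(),
            3.25
        );
        assert!(penalty(&RegularizationStrategy::NuclearNorm { alpha: 0.1 }).is_none());
    }
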
    #[test]
    fn test_task_similarity_methods() {
        let methods = vec![
            TaskSimilarityMethod::Correlation,
            TaskSimilarityMethod::Cosine,
            TaskSimilarityMethod::Euclidean,
            TaskSimilarityMethod::MutualInformation,
        ];

        assert_eq!(methods.len(), 4);
        assert_eq!(methods[0], TaskSimilarityMethod::Correlation);
        assert_eq!(methods[1], TaskSimilarityMethod::Cosine);
    }
}