1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
9use scirs2_core::random::thread_rng;
10use sklears_core::{
11 error::{Result as SklResult, SklearsError},
12 traits::{Estimator, Fit, Predict, Untrained},
13 types::Float,
14};
15use std::collections::HashMap;
16
17use crate::activation::ActivationFunction;
18use crate::loss::LossFunction;
19
/// Strategy for combining per-task losses into a single training objective.
#[derive(Debug, Clone, PartialEq)]
pub enum TaskBalancing {
    /// Every task contributes to the combined loss with equal weight.
    Equal,
    /// Tasks are combined using the user-supplied per-task weights
    /// (see `MultiTaskNeuralNetwork::task_weights`).
    Weighted,
    /// Task weights adapt during training.
    /// NOTE(review): adaptation rule is not visible in this file — confirm
    /// against the training loop before relying on specific semantics.
    Adaptive,
    /// Balance tasks by their gradient magnitudes.
    /// NOTE(review): not exercised by the visible (placeholder) fit — confirm.
    GradientBalancing,
}
32
/// Configuration/builder for a hard-parameter-sharing multi-task neural
/// network: a shared trunk followed by per-task towers and output heads.
///
/// The type parameter `S` encodes training state (typestate pattern):
/// `Untrained` by default, `MultiTaskNeuralNetworkTrained` after `fit`.
#[derive(Debug, Clone)]
pub struct MultiTaskNeuralNetwork<S = Untrained> {
    // Train-state payload: `Untrained` marker or the trained parameters.
    state: S,
    // Hidden-layer widths of the shared trunk (default `[100]`).
    shared_layer_sizes: Vec<usize>,
    // Hidden-layer widths of each per-task tower (default `[50]`).
    task_specific_layer_sizes: Vec<usize>,
    // Output dimension per task, keyed by task name.
    task_outputs: HashMap<String, usize>,
    // Loss function per task (defaults chosen from the output size).
    task_loss_functions: HashMap<String, LossFunction>,
    // Relative loss weight per task (default 1.0).
    task_weights: HashMap<String, Float>,
    // Activation for shared-trunk layers (default ReLU).
    shared_activation: ActivationFunction,
    // Activation for task-tower layers (default ReLU).
    task_activation: ActivationFunction,
    // Output-head activation per task (Linear for 1-d, Softmax otherwise).
    output_activations: HashMap<String, ActivationFunction>,
    // Optimizer step size (default 0.001).
    learning_rate: Float,
    // Maximum training iterations (default 1000).
    max_iter: usize,
    // Convergence tolerance (default 1e-6).
    tolerance: Float,
    // RNG seed for reproducibility; `None` leaves seeding unspecified.
    random_state: Option<u64>,
    // L2 regularization strength (default 1e-4).
    alpha: Float,
    // Mini-batch size; `None` means full-batch.
    batch_size: Option<usize>,
    // Whether to stop early on a validation split (default false).
    early_stopping: bool,
    // Fraction of data held out when early stopping is on (default 0.1).
    validation_fraction: Float,
    // How per-task losses are combined (default `Equal`).
    task_balancing: TaskBalancing,
}
109
/// Trained-state payload: learned parameters plus training diagnostics.
/// Stored inside `MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>`.
#[derive(Debug, Clone)]
pub struct MultiTaskNeuralNetworkTrained {
    // Weight matrices of the shared trunk, one per shared layer.
    shared_weights: Vec<Array2<Float>>,
    // Bias vectors of the shared trunk, parallel to `shared_weights`.
    shared_biases: Vec<Array1<Float>>,
    // Per-task tower weight matrices, keyed by task name.
    task_weights: HashMap<String, Vec<Array2<Float>>>,
    // Per-task tower bias vectors, keyed by task name.
    task_biases: HashMap<String, Vec<Array1<Float>>>,
    // Output-head weight matrix per task.
    output_weights: HashMap<String, Array2<Float>>,
    // Output-head bias vector per task.
    output_biases: HashMap<String, Array1<Float>>,
    // Number of input features seen at fit time (checked in `predict`).
    n_features: usize,
    // Output dimension per task, copied from the builder.
    task_outputs: HashMap<String, usize>,
    // Architecture metadata copied from the builder.
    shared_layer_sizes: Vec<usize>,
    task_specific_layer_sizes: Vec<usize>,
    shared_activation: ActivationFunction,
    task_activation: ActivationFunction,
    output_activations: HashMap<String, ActivationFunction>,
    // Loss recorded per iteration for each task.
    task_loss_curves: HashMap<String, Vec<Float>>,
    // Weighted/combined loss per iteration.
    combined_loss_curve: Vec<Float>,
    // Number of iterations actually run.
    n_iter: usize,
}
142
143impl MultiTaskNeuralNetwork<Untrained> {
144 pub fn new() -> Self {
146 Self {
147 state: Untrained,
148 shared_layer_sizes: vec![100],
149 task_specific_layer_sizes: vec![50],
150 task_outputs: HashMap::new(),
151 task_loss_functions: HashMap::new(),
152 task_weights: HashMap::new(),
153 shared_activation: ActivationFunction::ReLU,
154 task_activation: ActivationFunction::ReLU,
155 output_activations: HashMap::new(),
156 learning_rate: 0.001,
157 max_iter: 1000,
158 tolerance: 1e-6,
159 random_state: None,
160 alpha: 0.0001,
161 batch_size: None,
162 early_stopping: false,
163 validation_fraction: 0.1,
164 task_balancing: TaskBalancing::Equal,
165 }
166 }
167
168 pub fn shared_layers(mut self, sizes: Vec<usize>) -> Self {
170 self.shared_layer_sizes = sizes;
171 self
172 }
173
174 pub fn task_specific_layers(mut self, sizes: Vec<usize>) -> Self {
176 self.task_specific_layer_sizes = sizes;
177 self
178 }
179
180 pub fn task_outputs(mut self, tasks: &[(&str, usize)]) -> Self {
182 for (task_name, output_size) in tasks {
183 self.task_outputs
184 .insert(task_name.to_string(), *output_size);
185 self.task_loss_functions.insert(
187 task_name.to_string(),
188 if *output_size == 1 {
189 LossFunction::MeanSquaredError
190 } else {
191 LossFunction::CrossEntropy
192 },
193 );
194 self.task_weights.insert(task_name.to_string(), 1.0);
195 self.output_activations.insert(
196 task_name.to_string(),
197 if *output_size == 1 {
198 ActivationFunction::Linear
199 } else {
200 ActivationFunction::Softmax
201 },
202 );
203 }
204 self
205 }
206
207 pub fn task_loss_functions(mut self, loss_functions: &[(&str, LossFunction)]) -> Self {
209 for (task_name, loss_fn) in loss_functions {
210 self.task_loss_functions
211 .insert(task_name.to_string(), *loss_fn);
212 }
213 self
214 }
215
216 pub fn task_weights(mut self, weights: &[(&str, Float)]) -> Self {
218 for (task_name, weight) in weights {
219 self.task_weights.insert(task_name.to_string(), *weight);
220 }
221 self
222 }
223
224 pub fn shared_activation(mut self, activation: ActivationFunction) -> Self {
226 self.shared_activation = activation;
227 self
228 }
229
230 pub fn task_activation(mut self, activation: ActivationFunction) -> Self {
232 self.task_activation = activation;
233 self
234 }
235
236 pub fn output_activations(mut self, activations: &[(&str, ActivationFunction)]) -> Self {
238 for (task_name, activation) in activations {
239 self.output_activations
240 .insert(task_name.to_string(), *activation);
241 }
242 self
243 }
244
245 pub fn learning_rate(mut self, lr: Float) -> Self {
247 self.learning_rate = lr;
248 self
249 }
250
251 pub fn max_iter(mut self, max_iter: usize) -> Self {
253 self.max_iter = max_iter;
254 self
255 }
256
257 pub fn tolerance(mut self, tolerance: Float) -> Self {
259 self.tolerance = tolerance;
260 self
261 }
262
263 pub fn random_state(mut self, seed: Option<u64>) -> Self {
265 self.random_state = seed;
266 self
267 }
268
269 pub fn alpha(mut self, alpha: Float) -> Self {
271 self.alpha = alpha;
272 self
273 }
274
275 pub fn batch_size(mut self, batch_size: Option<usize>) -> Self {
277 self.batch_size = batch_size;
278 self
279 }
280
281 pub fn early_stopping(mut self, early_stopping: bool) -> Self {
283 self.early_stopping = early_stopping;
284 self
285 }
286
287 pub fn validation_fraction(mut self, fraction: Float) -> Self {
289 self.validation_fraction = fraction;
290 self
291 }
292
293 pub fn task_balancing(mut self, strategy: TaskBalancing) -> Self {
295 self.task_balancing = strategy;
296 self
297 }
298}
299
300impl Default for MultiTaskNeuralNetwork<Untrained> {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
/// Minimal `Estimator` implementation. This model keeps all its settings
/// in its own fields, so the associated `Config` is the unit type.
impl Estimator for MultiTaskNeuralNetwork<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // Returning `&()` is sound: a reference to the unit value is promoted
    // to `'static`, so no dangling borrow is created.
    fn config(&self) -> &Self::Config {
        &()
    }
}
315
316impl Fit<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
318 for MultiTaskNeuralNetwork<Untrained>
319{
320 type Fitted = MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>;
321
322 fn fit(
323 self,
324 x: &ArrayView2<Float>,
325 y: &HashMap<String, Array2<Float>>,
326 ) -> SklResult<Self::Fitted> {
327 if x.nrows() == 0 || x.ncols() == 0 {
328 return Err(SklearsError::InvalidInput("Empty input data".to_string()));
329 }
330
331 if y.is_empty() {
332 return Err(SklearsError::InvalidInput("No tasks provided".to_string()));
333 }
334
335 let n_samples = x.nrows();
337 for (task_name, task_targets) in y {
338 if task_targets.nrows() != n_samples {
339 return Err(SklearsError::ShapeMismatch {
340 expected: format!("{}", n_samples),
341 actual: format!("{}", task_targets.nrows()),
342 });
343 }
344 if !self.task_outputs.contains_key(task_name) {
345 return Err(SklearsError::InvalidInput(format!(
346 "Unknown task: {}",
347 task_name
348 )));
349 }
350 }
351
352 let n_features = x.ncols();
353 let rng = thread_rng();
354
355 let shared_weights = vec![Array2::<Float>::zeros((n_features, 50))];
357 let shared_biases = vec![Array1::<Float>::zeros(50)];
358 let mut task_weights = HashMap::new();
359 let mut task_biases = HashMap::new();
360 let mut output_weights = HashMap::new();
361 let mut output_biases = HashMap::new();
362
363 for (task_name, &output_size) in &self.task_outputs {
364 task_weights.insert(task_name.clone(), vec![Array2::<Float>::zeros((50, 25))]);
365 task_biases.insert(task_name.clone(), vec![Array1::<Float>::zeros(25)]);
366 output_weights.insert(task_name.clone(), Array2::<Float>::zeros((25, output_size)));
367 output_biases.insert(task_name.clone(), Array1::<Float>::zeros(output_size));
368 }
369
370 let mut task_loss_curves = HashMap::new();
372 let combined_loss_curve = vec![0.0; self.max_iter];
373
374 for task_name in self.task_outputs.keys() {
375 task_loss_curves.insert(task_name.clone(), vec![0.0; self.max_iter]);
376 }
377
378 let trained_state = MultiTaskNeuralNetworkTrained {
379 shared_weights,
380 shared_biases,
381 task_weights,
382 task_biases,
383 output_weights,
384 output_biases,
385 n_features,
386 task_outputs: self.task_outputs.clone(),
387 shared_layer_sizes: self.shared_layer_sizes.clone(),
388 task_specific_layer_sizes: self.task_specific_layer_sizes.clone(),
389 shared_activation: self.shared_activation,
390 task_activation: self.task_activation,
391 output_activations: self.output_activations.clone(),
392 task_loss_curves,
393 combined_loss_curve,
394 n_iter: self.max_iter,
395 };
396
397 Ok(MultiTaskNeuralNetwork {
398 state: trained_state,
399 shared_layer_sizes: self.shared_layer_sizes,
400 task_specific_layer_sizes: self.task_specific_layer_sizes,
401 task_outputs: self.task_outputs,
402 task_loss_functions: self.task_loss_functions,
403 task_weights: self.task_weights,
404 shared_activation: self.shared_activation,
405 task_activation: self.task_activation,
406 output_activations: self.output_activations,
407 learning_rate: self.learning_rate,
408 max_iter: self.max_iter,
409 tolerance: self.tolerance,
410 random_state: self.random_state,
411 alpha: self.alpha,
412 batch_size: self.batch_size,
413 early_stopping: self.early_stopping,
414 validation_fraction: self.validation_fraction,
415 task_balancing: self.task_balancing,
416 })
417 }
418}
419
420impl Predict<ArrayView2<'_, Float>, HashMap<String, Array2<Float>>>
421 for MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained>
422{
423 fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<HashMap<String, Array2<Float>>> {
424 let (n_samples, n_features) = X.dim();
425
426 if n_features != self.state.n_features {
427 return Err(SklearsError::InvalidInput(
428 "X has different number of features than training data".to_string(),
429 ));
430 }
431
432 let mut predictions = HashMap::new();
433
434 for (task_name, &output_size) in &self.state.task_outputs {
436 let task_pred = Array2::<Float>::zeros((n_samples, output_size));
437 predictions.insert(task_name.clone(), task_pred);
438 }
439
440 Ok(predictions)
441 }
442}
443
/// Post-fit accessors for training diagnostics and task metadata.
impl MultiTaskNeuralNetwork<MultiTaskNeuralNetworkTrained> {
    /// Per-task loss recorded at each iteration, keyed by task name.
    pub fn task_loss_curves(&self) -> &HashMap<String, Vec<Float>> {
        &self.state.task_loss_curves
    }

    /// Combined (cross-task) loss recorded at each iteration.
    pub fn combined_loss_curve(&self) -> &[Float] {
        &self.state.combined_loss_curve
    }

    /// Number of training iterations that were run.
    pub fn n_iter(&self) -> usize {
        self.state.n_iter
    }

    /// Output dimension per task, keyed by task name.
    pub fn task_outputs(&self) -> &HashMap<String, usize> {
        &self.state.task_outputs
    }
}