1use scirs2_core::ndarray::{s, Array1, Array2, ArrayView1, ArrayView2, Axis};
5use sklears_core::{
6 error::{Result as SklResult, SklearsError},
7 traits::{Estimator, Fit, Predict, Untrained},
8 types::Float,
9};
10
11#[derive(Debug, Clone)]
34pub struct MLTSVM<S = Untrained> {
35 state: S,
36 c1: Float, c2: Float, epsilon: Float, max_iter: usize, }
41
42#[derive(Debug, Clone)]
44pub struct MLTSVMTrained {
45 models: Vec<TwinSVMModel>, n_labels: usize,
47 feature_means: Array1<Float>,
48 feature_stds: Array1<Float>,
49}
50
51#[derive(Debug, Clone)]
53pub struct TwinSVMModel {
54 w1: Array1<Float>, b1: Float, w2: Array1<Float>, b2: Float, }
59
impl MLTSVM<Untrained> {
    /// Creates an untrained MLTSVM with default hyperparameters
    /// (c1 = c2 = 1.0, epsilon = 1e-3, max_iter = 1000).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            c1: 1.0,
            c2: 1.0,
            epsilon: 1e-3,
            max_iter: 1000,
        }
    }

    /// Sets the penalty parameter for the first twin problem (builder style).
    pub fn c1(mut self, c1: Float) -> Self {
        self.c1 = c1;
        self
    }

    /// Sets the penalty parameter for the second twin problem (builder style).
    pub fn c2(mut self, c2: Float) -> Self {
        self.c2 = c2;
        self
    }

    /// Sets the gradient-norm convergence tolerance (builder style).
    pub fn epsilon(mut self, epsilon: Float) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Sets the maximum number of optimization iterations (builder style).
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }
}
96
/// Default is equivalent to [`MLTSVM::new`].
impl Default for MLTSVM<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}
102
impl Estimator for MLTSVM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // No separate configuration object; hyperparameters live on the struct.
    fn config(&self) -> &Self::Config {
        &()
    }
}
112
113impl Fit<ArrayView2<'_, Float>, Array2<i32>> for MLTSVM<Untrained> {
114 type Fitted = MLTSVM<MLTSVMTrained>;
115
116 fn fit(self, x: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
117 let (n_samples, n_features) = x.dim();
118 let (y_samples, n_labels) = y.dim();
119
120 if n_samples != y_samples {
121 return Err(SklearsError::InvalidInput(
122 "Number of samples in X and y must match".to_string(),
123 ));
124 }
125
126 if n_samples < 2 {
127 return Err(SklearsError::InvalidInput(
128 "Need at least 2 samples for SVM training".to_string(),
129 ));
130 }
131
132 for sample_idx in 0..y_samples {
134 for label_idx in 0..n_labels {
135 let value = y[[sample_idx, label_idx]];
136 if value != 0 && value != 1 {
137 return Err(SklearsError::InvalidInput(format!(
138 "All label values must be 0 or 1, found: {}",
139 value
140 )));
141 }
142 }
143 }
144
145 let feature_means = x.mean_axis(Axis(0)).unwrap();
147 let feature_stds = x.mapv(|val| val * val).mean_axis(Axis(0)).unwrap()
148 - &feature_means.mapv(|mean| mean * mean);
149 let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());
150
151 let mut models = Vec::new();
153 for label_idx in 0..n_labels {
154 let y_label = y.column(label_idx);
155 let model = self.train_twin_svm(x, &y_label, &feature_means, &feature_stds)?;
156 models.push(model);
157 }
158
159 Ok(MLTSVM {
160 state: MLTSVMTrained {
161 models,
162 n_labels,
163 feature_means,
164 feature_stds,
165 },
166 c1: self.c1,
167 c2: self.c2,
168 epsilon: self.epsilon,
169 max_iter: self.max_iter,
170 })
171 }
172}
173
174impl MLTSVM<Untrained> {
175 fn train_twin_svm(
176 &self,
177 x: &ArrayView2<'_, Float>,
178 y: &ArrayView1<'_, i32>,
179 feature_means: &Array1<Float>,
180 feature_stds: &Array1<Float>,
181 ) -> SklResult<TwinSVMModel> {
182 let (n_samples, n_features) = x.dim();
183
184 let mut x_normalized = x.to_owned();
186 for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
187 row -= feature_means;
188 row /= feature_stds;
189 }
190
191 let mut pos_samples = Vec::new();
193 let mut neg_samples = Vec::new();
194
195 for i in 0..n_samples {
196 if y[i] == 1 {
197 pos_samples.push(x_normalized.row(i).to_owned());
198 } else {
199 neg_samples.push(x_normalized.row(i).to_owned());
200 }
201 }
202
203 if pos_samples.is_empty() || neg_samples.is_empty() {
204 return Err(SklearsError::InvalidInput(
205 "Need both positive and negative samples for Twin SVM".to_string(),
206 ));
207 }
208
209 let pos_matrix = Array2::from_shape_vec(
211 (pos_samples.len(), n_features),
212 pos_samples.into_iter().flatten().collect(),
213 )
214 .map_err(|_| SklearsError::InvalidInput("Failed to create positive matrix".to_string()))?;
215
216 let neg_matrix = Array2::from_shape_vec(
217 (neg_samples.len(), n_features),
218 neg_samples.into_iter().flatten().collect(),
219 )
220 .map_err(|_| SklearsError::InvalidInput("Failed to create negative matrix".to_string()))?;
221
222 let (w1, b1) = self.solve_twin_svm_problem(&pos_matrix, &neg_matrix, self.c1)?;
224 let (w2, b2) = self.solve_twin_svm_problem(&neg_matrix, &pos_matrix, self.c2)?;
225
226 Ok(TwinSVMModel { w1, b1, w2, b2 })
227 }
228
229 fn solve_twin_svm_problem(
230 &self,
231 target_matrix: &Array2<Float>,
232 other_matrix: &Array2<Float>,
233 c: Float,
234 ) -> SklResult<(Array1<Float>, Float)> {
235 let n_target = target_matrix.nrows();
236 let n_other = other_matrix.nrows();
237 let n_features = target_matrix.ncols();
238
239 let mut w = Array1::<Float>::zeros(n_features + 1); let learning_rate = 0.01;
244
245 for _iter in 0..self.max_iter {
246 let mut gradient = Array1::<Float>::zeros(n_features + 1);
247
248 for i in 0..n_target {
250 let x_aug = {
251 let mut x = Array1::ones(n_features + 1);
252 x.slice_mut(s![..n_features]).assign(&target_matrix.row(i));
253 x
254 };
255 let loss = x_aug.dot(&w);
256 gradient += &(x_aug * loss);
257 }
258
259 for i in 0..n_other {
260 let x_aug = {
261 let mut x = Array1::ones(n_features + 1);
262 x.slice_mut(s![..n_features]).assign(&other_matrix.row(i));
263 x
264 };
265 let margin = 1.0 - x_aug.dot(&w);
266 if margin > 0.0 {
267 gradient -= &(x_aug * c);
268 }
269 }
270
271 let gradient_norm = gradient.mapv(|x| x.abs()).sum();
273
274 w -= &(gradient * learning_rate);
276
277 if gradient_norm < self.epsilon {
278 break;
279 }
280 }
281
282 let weights = w.slice(s![..n_features]).to_owned();
283 let bias = w[n_features];
284
285 Ok((weights, bias))
286 }
287}
288
289impl Predict<ArrayView2<'_, Float>, Array2<i32>> for MLTSVM<MLTSVMTrained> {
290 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
291 let (n_samples, n_features) = x.dim();
292 let expected_features = self.state.feature_means.len();
293
294 if n_features != expected_features {
295 return Err(SklearsError::InvalidInput(format!(
296 "Number of features in X ({}) does not match training data ({})",
297 n_features, expected_features
298 )));
299 }
300
301 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
302
303 let mut x_normalized = x.to_owned();
305 for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
306 row -= &self.state.feature_means;
307 row /= &self.state.feature_stds;
308 }
309
310 for label_idx in 0..self.state.n_labels {
311 let model = &self.state.models[label_idx];
312
313 for sample_idx in 0..n_samples {
314 let x_sample = x_normalized.row(sample_idx);
315
316 let dist1 = (x_sample.dot(&model.w1) + model.b1).abs();
318 let dist2 = (x_sample.dot(&model.w2) + model.b2).abs();
319
320 predictions[[sample_idx, label_idx]] = if dist1 < dist2 { 1 } else { 0 };
322 }
323 }
324
325 Ok(predictions)
326 }
327}
328
329impl MLTSVM<MLTSVMTrained> {
330 pub fn n_labels(&self) -> usize {
332 self.state.n_labels
333 }
334
335 pub fn decision_function(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
337 let (n_samples, _n_features) = x.dim();
338 let mut decision_values = Array2::<Float>::zeros((n_samples, self.state.n_labels));
339
340 let mut x_normalized = x.to_owned();
342 for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
343 row -= &self.state.feature_means;
344 row /= &self.state.feature_stds;
345 }
346
347 for label_idx in 0..self.state.n_labels {
348 let model = &self.state.models[label_idx];
349
350 for sample_idx in 0..n_samples {
351 let x_sample = x_normalized.row(sample_idx);
352
353 let dist1 = x_sample.dot(&model.w1) + model.b1;
355 let dist2 = x_sample.dot(&model.w2) + model.b2;
356
357 decision_values[[sample_idx, label_idx]] = dist1 - dist2;
359 }
360 }
361
362 Ok(decision_values)
363 }
364}
365
366#[derive(Debug, Clone)]
372pub struct RankSVM<S = Untrained> {
373 state: S,
374 c: Float, epsilon: Float, max_iter: usize, threshold_strategy: ThresholdStrategy, }
379
/// Strategy for turning continuous ranking scores into binary labels.
#[derive(Debug, Clone)]
pub enum ThresholdStrategy {
    /// Use the same fixed score threshold for every label.
    Fixed(Float),
    /// Choose each label's threshold by maximizing F1 on the training data.
    OptimizeF1,
    /// Assign exactly the K highest-scoring labels per sample.
    TopK(usize),
}
390
391#[derive(Debug, Clone)]
393pub struct RankSVMTrained {
394 models: Vec<RankingSVMModel>, thresholds: Vec<Float>, n_labels: usize,
397 feature_means: Array1<Float>,
398 feature_stds: Array1<Float>,
399}
400
/// Linear scoring model for one label: `score = weights · x + bias`.
#[derive(Debug, Clone)]
pub struct RankingSVMModel {
    /// Weight vector over normalized features.
    weights: Array1<Float>,
    /// Intercept term.
    bias: Float,
}
407
impl RankSVM<Untrained> {
    /// Creates an untrained RankSVM with default hyperparameters
    /// (c = 1.0, epsilon = 1e-3, max_iter = 1000, fixed threshold 0.0).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            c: 1.0,
            epsilon: 1e-3,
            max_iter: 1000,
            threshold_strategy: ThresholdStrategy::Fixed(0.0),
        }
    }

    /// Sets the hinge-loss penalty weight (builder style).
    pub fn c(mut self, c: Float) -> Self {
        self.c = c;
        self
    }

    /// Sets the gradient-norm convergence tolerance (builder style).
    pub fn epsilon(mut self, epsilon: Float) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Sets the maximum number of optimization iterations (builder style).
    pub fn max_iter(mut self, max_iter: usize) -> Self {
        self.max_iter = max_iter;
        self
    }

    /// Sets the score-to-label threshold strategy (builder style).
    pub fn threshold_strategy(mut self, strategy: ThresholdStrategy) -> Self {
        self.threshold_strategy = strategy;
        self
    }
}
444
/// Default is equivalent to [`RankSVM::new`].
impl Default for RankSVM<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}
450
impl Estimator for RankSVM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // No separate configuration object; hyperparameters live on the struct.
    fn config(&self) -> &Self::Config {
        &()
    }
}
460
461impl Fit<ArrayView2<'_, Float>, Array2<i32>> for RankSVM<Untrained> {
462 type Fitted = RankSVM<RankSVMTrained>;
463
464 fn fit(self, x: &ArrayView2<'_, Float>, y: &Array2<i32>) -> SklResult<Self::Fitted> {
465 let (n_samples, n_features) = x.dim();
466 let (y_samples, n_labels) = y.dim();
467
468 if n_samples != y_samples {
469 return Err(SklearsError::InvalidInput(
470 "Number of samples in X and y must match".to_string(),
471 ));
472 }
473
474 for sample_idx in 0..y_samples {
476 for label_idx in 0..n_labels {
477 let value = y[[sample_idx, label_idx]];
478 if value != 0 && value != 1 {
479 return Err(SklearsError::InvalidInput(format!(
480 "All label values must be 0 or 1, found: {}",
481 value
482 )));
483 }
484 }
485 }
486
487 let feature_means = x.mean_axis(Axis(0)).ok_or_else(|| {
489 SklearsError::InvalidInput("Cannot compute feature means from input data".to_string())
490 })?;
491
492 let squared_means = x.mapv(|val| val * val).mean_axis(Axis(0)).ok_or_else(|| {
493 SklearsError::InvalidInput("Cannot compute squared means from input data".to_string())
494 })?;
495
496 let feature_stds = squared_means - &feature_means.mapv(|mean| mean * mean);
497 let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());
498
499 let mut models = Vec::new();
501 for label_idx in 0..n_labels {
502 let y_label = y.column(label_idx);
503 let model = self.train_ranking_svm(x, &y_label, &feature_means, &feature_stds)?;
504 models.push(model);
505 }
506
507 let thresholds = match &self.threshold_strategy {
509 ThresholdStrategy::Fixed(threshold) => vec![*threshold; n_labels],
510 ThresholdStrategy::OptimizeF1 => {
511 self.optimize_f1_thresholds(x, y, &models, &feature_means, &feature_stds)?
512 }
513 ThresholdStrategy::TopK(_) => vec![0.0; n_labels], };
515
516 Ok(RankSVM {
517 state: RankSVMTrained {
518 models,
519 thresholds,
520 n_labels,
521 feature_means,
522 feature_stds,
523 },
524 c: self.c,
525 epsilon: self.epsilon,
526 max_iter: self.max_iter,
527 threshold_strategy: self.threshold_strategy,
528 })
529 }
530}
531
532impl RankSVM<Untrained> {
533 fn train_ranking_svm(
534 &self,
535 x: &ArrayView2<'_, Float>,
536 y: &ArrayView1<'_, i32>,
537 feature_means: &Array1<Float>,
538 feature_stds: &Array1<Float>,
539 ) -> SklResult<RankingSVMModel> {
540 let (n_samples, n_features) = x.dim();
541
542 let mut x_normalized = x.to_owned();
544 for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
545 row -= feature_means;
546 row /= feature_stds;
547 }
548
549 let mut weights = Array1::<Float>::zeros(n_features);
551 let mut bias = 0.0;
552
553 let learning_rate = 0.01;
554
555 for _iter in 0..self.max_iter {
557 let mut weight_gradient = Array1::<Float>::zeros(n_features);
558 let mut bias_gradient = 0.0;
559
560 for i in 0..n_samples {
562 for j in 0..n_samples {
563 if y[i] > y[j] {
564 let x_i = x_normalized.row(i);
566 let x_j = x_normalized.row(j);
567 let x_diff = &x_i.to_owned() - &x_j.to_owned();
568
569 let score_diff = x_diff.dot(&weights) + bias;
570 let margin = 1.0 - score_diff;
571
572 if margin > 0.0 {
573 weight_gradient -= &(x_diff * self.c);
575 bias_gradient -= self.c;
576 }
577 }
578 }
579 }
580
581 weight_gradient += &(&weights * 2.0);
583
584 let gradient_norm = weight_gradient.mapv(|x| x.abs()).sum();
586
587 weights -= &(weight_gradient * learning_rate);
589 bias -= bias_gradient * learning_rate;
590
591 if gradient_norm < self.epsilon {
592 break;
593 }
594 }
595
596 Ok(RankingSVMModel { weights, bias })
597 }
598
599 fn optimize_f1_thresholds(
600 &self,
601 x: &ArrayView2<'_, Float>,
602 y: &Array2<i32>,
603 models: &[RankingSVMModel],
604 feature_means: &Array1<Float>,
605 feature_stds: &Array1<Float>,
606 ) -> SklResult<Vec<Float>> {
607 let mut thresholds = Vec::new();
608
609 for label_idx in 0..y.ncols() {
610 let y_true = y.column(label_idx);
611 let scores = self.predict_scores_single_label(
612 x,
613 &models[label_idx],
614 feature_means,
615 feature_stds,
616 )?;
617
618 let threshold = self.find_optimal_f1_threshold(&y_true, &scores)?;
619 thresholds.push(threshold);
620 }
621
622 Ok(thresholds)
623 }
624
625 fn predict_scores_single_label(
626 &self,
627 x: &ArrayView2<'_, Float>,
628 model: &RankingSVMModel,
629 feature_means: &Array1<Float>,
630 feature_stds: &Array1<Float>,
631 ) -> SklResult<Array1<Float>> {
632 let (n_samples, _) = x.dim();
633 let mut scores = Array1::<Float>::zeros(n_samples);
634
635 for i in 0..n_samples {
636 let x_sample = x.row(i);
637 let x_normalized = (&x_sample.to_owned() - feature_means) / feature_stds;
638 scores[i] = x_normalized.dot(&model.weights) + model.bias;
639 }
640
641 Ok(scores)
642 }
643
644 fn find_optimal_f1_threshold(
645 &self,
646 y_true: &ArrayView1<'_, i32>,
647 scores: &Array1<Float>,
648 ) -> SklResult<Float> {
649 let mut score_threshold_pairs: Vec<(Float, i32)> = scores
650 .iter()
651 .zip(y_true.iter())
652 .map(|(&score, &label)| (score, label))
653 .collect();
654
655 score_threshold_pairs.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
656
657 let mut best_f1 = 0.0;
658 let mut best_threshold = 0.0;
659
660 for &(threshold, _) in &score_threshold_pairs {
662 let mut tp = 0;
663 let mut fp = 0;
664 let mut fn_count = 0;
665
666 for (&score, &true_label) in scores.iter().zip(y_true.iter()) {
667 let predicted = if score >= threshold { 1 } else { 0 };
668
669 match (true_label, predicted) {
670 (1, 1) => tp += 1,
671 (0, 1) => fp += 1,
672 (1, 0) => fn_count += 1,
673 _ => {}
674 }
675 }
676
677 let precision = if tp + fp > 0 {
678 tp as Float / (tp + fp) as Float
679 } else {
680 0.0
681 };
682 let recall = if tp + fn_count > 0 {
683 tp as Float / (tp + fn_count) as Float
684 } else {
685 0.0
686 };
687 let f1 = if precision + recall > 0.0 {
688 2.0 * precision * recall / (precision + recall)
689 } else {
690 0.0
691 };
692
693 if f1 > best_f1 {
694 best_f1 = f1;
695 best_threshold = threshold;
696 }
697 }
698
699 Ok(best_threshold)
700 }
701}
702
703impl Predict<ArrayView2<'_, Float>, Array2<i32>> for RankSVM<RankSVMTrained> {
704 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<i32>> {
705 let (n_samples, n_features) = x.dim();
706 let expected_features = self.state.feature_means.len();
707
708 if n_features != expected_features {
709 return Err(SklearsError::InvalidInput(format!(
710 "Number of features in X ({}) does not match training data ({})",
711 n_features, expected_features
712 )));
713 }
714
715 let mut predictions = Array2::<i32>::zeros((n_samples, self.state.n_labels));
716
717 match &self.threshold_strategy {
718 ThresholdStrategy::TopK(k) => {
719 for sample_idx in 0..n_samples {
721 let mut scores = Vec::new();
722 for label_idx in 0..self.state.n_labels {
723 let x_sample = x.row(sample_idx);
724 let x_normalized = (&x_sample.to_owned() - &self.state.feature_means)
725 / &self.state.feature_stds;
726 let score = x_normalized.dot(&self.state.models[label_idx].weights)
727 + self.state.models[label_idx].bias;
728 scores.push((score, label_idx));
729 }
730
731 scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
732
733 for (i, &(_, label_idx)) in scores.iter().take(*k).enumerate() {
734 predictions[[sample_idx, label_idx]] = 1;
735 }
736 }
737 }
738 _ => {
739 for label_idx in 0..self.state.n_labels {
741 let threshold = self.state.thresholds[label_idx];
742
743 for sample_idx in 0..n_samples {
744 let x_sample = x.row(sample_idx);
745 let x_normalized = (&x_sample.to_owned() - &self.state.feature_means)
746 / &self.state.feature_stds;
747 let score = x_normalized.dot(&self.state.models[label_idx].weights)
748 + self.state.models[label_idx].bias;
749
750 predictions[[sample_idx, label_idx]] =
751 if score >= threshold { 1 } else { 0 };
752 }
753 }
754 }
755 }
756
757 Ok(predictions)
758 }
759}
760
761impl RankSVM<RankSVMTrained> {
762 pub fn n_labels(&self) -> usize {
764 self.state.n_labels
765 }
766
767 pub fn decision_function(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
769 let (n_samples, n_features) = x.dim();
770 let expected_features = self.state.feature_means.len();
771
772 if n_features != expected_features {
773 return Err(SklearsError::InvalidInput(format!(
774 "Number of features in X ({}) does not match training data ({})",
775 n_features, expected_features
776 )));
777 }
778
779 let mut decision_values = Array2::<Float>::zeros((n_samples, self.state.n_labels));
780
781 for sample_idx in 0..n_samples {
782 for label_idx in 0..self.state.n_labels {
783 let x_sample = x.row(sample_idx);
784 let x_normalized =
785 (&x_sample.to_owned() - &self.state.feature_means) / &self.state.feature_stds;
786 let score = x_normalized.dot(&self.state.models[label_idx].weights)
787 + self.state.models[label_idx].bias;
788 decision_values[[sample_idx, label_idx]] = score;
789 }
790 }
791
792 Ok(decision_values)
793 }
794
795 pub fn predict_ranking(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<usize>> {
797 let (n_samples, n_features) = x.dim();
798 let expected_features = self.state.feature_means.len();
799
800 if n_features != expected_features {
801 return Err(SklearsError::InvalidInput(format!(
802 "Number of features in X ({}) does not match training data ({})",
803 n_features, expected_features
804 )));
805 }
806
807 let mut rankings = Array2::<usize>::zeros((n_samples, self.state.n_labels));
808
809 for sample_idx in 0..n_samples {
810 let mut scores = Vec::new();
811 for label_idx in 0..self.state.n_labels {
812 let x_sample = x.row(sample_idx);
813 let x_normalized =
814 (&x_sample.to_owned() - &self.state.feature_means) / &self.state.feature_stds;
815 let score = x_normalized.dot(&self.state.models[label_idx].weights)
816 + self.state.models[label_idx].bias;
817 scores.push((score, label_idx));
818 }
819
820 scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
822
823 for (rank, &(_score, label_idx)) in scores.iter().enumerate() {
825 rankings[[sample_idx, rank]] = label_idx;
826 }
827 }
828
829 Ok(rankings)
830 }
831
832 pub fn thresholds(&self) -> &Vec<Float> {
834 &self.state.thresholds
835 }
836}
837
/// Multi-output kernel SVM: fits one single-output model per output column.
#[derive(Debug, Clone)]
pub struct MultiOutputSVM<S = Untrained> {
    /// Training state marker / fitted parameters.
    state: S,
    /// Kernel used for similarity computation.
    kernel: SVMKernel,
    /// Regularization parameter C.
    c: Float,
    // NOTE(review): `epsilon` is stored and copied through fit but is not
    // referenced by the visible training/prediction code — confirm intent.
    epsilon: Float,
    /// Optional gamma override applied to the kernel at fit time.
    gamma: Option<Float>,
}
850
/// Kernel functions supported by [`MultiOutputSVM`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SVMKernel {
    /// K(x, y) = x · y
    Linear,
    /// K(x, y) = (gamma * x · y + coef0)^degree
    Polynomial {
        degree: i32,
        gamma: Float,
        coef0: Float,
    },
    /// K(x, y) = exp(-gamma * ||x - y||^2)
    Rbf { gamma: Float },
    /// K(x, y) = tanh(gamma * x · y + coef0)
    Sigmoid { gamma: Float, coef0: Float },
}
867
/// Fitted state of a [`MultiOutputSVM`].
#[derive(Debug, Clone)]
pub struct MultiOutputSVMTrained {
    /// One fitted kernel model per output column.
    models: Vec<SVMModel>,
    /// Number of output columns seen during fit.
    n_outputs: usize,
    /// Per-feature means used for z-score normalization.
    feature_means: Array1<Float>,
    /// Per-feature standard deviations used for z-score normalization.
    feature_stds: Array1<Float>,
}
876
/// Kernel expansion model for one output:
/// `y(x) = bias + sum_i coeff_i * K(x, sv_i)`.
#[derive(Debug, Clone)]
pub struct SVMModel {
    /// Normalized training samples retained as support vectors.
    support_vectors: Array2<Float>,
    /// One coefficient per support vector.
    support_coefficients: Array1<Float>,
    /// Intercept term.
    bias: Float,
    /// Kernel used for similarity computation.
    kernel: SVMKernel,
}
885
impl MultiOutputSVM<Untrained> {
    /// Creates an untrained MultiOutputSVM with default hyperparameters
    /// (RBF kernel with gamma = 1.0, c = 1.0, epsilon = 1e-3).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            kernel: SVMKernel::Rbf { gamma: 1.0 },
            c: 1.0,
            epsilon: 1e-3,
            gamma: None,
        }
    }

    /// Sets the kernel function (builder style).
    pub fn kernel(mut self, kernel: SVMKernel) -> Self {
        self.kernel = kernel;
        self
    }

    /// Sets the regularization parameter C (builder style).
    pub fn c(mut self, c: Float) -> Self {
        self.c = c;
        self
    }

    /// Sets the tolerance parameter (builder style).
    pub fn epsilon(mut self, epsilon: Float) -> Self {
        self.epsilon = epsilon;
        self
    }

    /// Overrides the kernel's gamma at fit time (builder style).
    pub fn gamma(mut self, gamma: Float) -> Self {
        self.gamma = Some(gamma);
        self
    }
}
922
/// Default is equivalent to [`MultiOutputSVM::new`].
impl Default for MultiOutputSVM<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}
928
impl Estimator for MultiOutputSVM<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    // No separate configuration object; hyperparameters live on the struct.
    fn config(&self) -> &Self::Config {
        &()
    }
}
938
939impl Fit<ArrayView2<'_, Float>, ArrayView2<'_, Float>> for MultiOutputSVM<Untrained> {
940 type Fitted = MultiOutputSVM<MultiOutputSVMTrained>;
941
942 fn fit(self, x: &ArrayView2<'_, Float>, y: &ArrayView2<'_, Float>) -> SklResult<Self::Fitted> {
943 let (n_samples, n_features) = x.dim();
944 let (y_samples, n_outputs) = y.dim();
945
946 if n_samples != y_samples {
947 return Err(SklearsError::InvalidInput(
948 "Number of samples in X and y must match".to_string(),
949 ));
950 }
951
952 let feature_means = x.mean_axis(Axis(0)).unwrap();
954 let feature_stds = x.mapv(|val| val * val).mean_axis(Axis(0)).unwrap()
955 - &feature_means.mapv(|mean| mean * mean);
956 let feature_stds = feature_stds.mapv(|var| (var.max(1e-10)).sqrt());
957
958 let kernel = if let Some(gamma) = self.gamma {
960 match self.kernel {
961 SVMKernel::Rbf { .. } => SVMKernel::Rbf { gamma },
962 SVMKernel::Polynomial { degree, coef0, .. } => SVMKernel::Polynomial {
963 degree,
964 gamma,
965 coef0,
966 },
967 SVMKernel::Sigmoid { coef0, .. } => SVMKernel::Sigmoid { gamma, coef0 },
968 other => other,
969 }
970 } else {
971 self.kernel
972 };
973
974 let mut models = Vec::new();
976 for output_idx in 0..n_outputs {
977 let y_output = y.column(output_idx);
978 let model =
979 self.train_single_svm(x, &y_output, &feature_means, &feature_stds, kernel)?;
980 models.push(model);
981 }
982
983 Ok(MultiOutputSVM {
984 state: MultiOutputSVMTrained {
985 models,
986 n_outputs,
987 feature_means,
988 feature_stds,
989 },
990 kernel,
991 c: self.c,
992 epsilon: self.epsilon,
993 gamma: self.gamma,
994 })
995 }
996}
997
998impl MultiOutputSVM<Untrained> {
999 fn train_single_svm(
1000 &self,
1001 x: &ArrayView2<'_, Float>,
1002 y: &ArrayView1<'_, Float>,
1003 feature_means: &Array1<Float>,
1004 feature_stds: &Array1<Float>,
1005 kernel: SVMKernel,
1006 ) -> SklResult<SVMModel> {
1007 let (n_samples, n_features) = x.dim();
1008
1009 let mut x_normalized = x.to_owned();
1011 for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
1012 row -= feature_means;
1013 row /= feature_stds;
1014 }
1015
1016 let support_vectors = x_normalized.clone();
1019 let mut support_coefficients = Array1::<Float>::zeros(n_samples);
1020
1021 let y_mean = y.mean().unwrap();
1023 for i in 0..n_samples {
1024 support_coefficients[i] = (y[i] - y_mean) / self.c;
1025 }
1026
1027 let bias = y_mean;
1028
1029 Ok(SVMModel {
1030 support_vectors,
1031 support_coefficients,
1032 bias,
1033 kernel,
1034 })
1035 }
1036}
1037
1038impl Predict<ArrayView2<'_, Float>, Array2<Float>> for MultiOutputSVM<MultiOutputSVMTrained> {
1039 fn predict(&self, x: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
1040 let (n_samples, _) = x.dim();
1041 let mut predictions = Array2::<Float>::zeros((n_samples, self.state.n_outputs));
1042
1043 let mut x_normalized = x.to_owned();
1045 for (i, mut row) in x_normalized.rows_mut().into_iter().enumerate() {
1046 row -= &self.state.feature_means;
1047 row /= &self.state.feature_stds;
1048 }
1049
1050 for output_idx in 0..self.state.n_outputs {
1051 let model = &self.state.models[output_idx];
1052
1053 for sample_idx in 0..n_samples {
1054 let x_sample = x_normalized.row(sample_idx);
1055 let mut prediction = model.bias;
1056
1057 for (sv_idx, support_vector) in model.support_vectors.rows().into_iter().enumerate()
1059 {
1060 let kernel_value =
1061 compute_kernel_value(&x_sample, &support_vector, model.kernel);
1062 prediction += model.support_coefficients[sv_idx] * kernel_value;
1063 }
1064
1065 predictions[[sample_idx, output_idx]] = prediction;
1066 }
1067 }
1068
1069 Ok(predictions)
1070 }
1071}
1072
1073fn compute_kernel_value(
1075 x1: &ArrayView1<Float>,
1076 x2: &ArrayView1<Float>,
1077 kernel: SVMKernel,
1078) -> Float {
1079 match kernel {
1080 SVMKernel::Linear => x1.dot(x2),
1081 SVMKernel::Polynomial {
1082 degree,
1083 gamma,
1084 coef0,
1085 } => (gamma * x1.dot(x2) + coef0).powi(degree),
1086 SVMKernel::Rbf { gamma } => {
1087 let dist_sq = x1
1088 .iter()
1089 .zip(x2.iter())
1090 .map(|(a, b)| (a - b).powi(2))
1091 .sum::<Float>();
1092 (-gamma * dist_sq).exp()
1093 }
1094 SVMKernel::Sigmoid { gamma, coef0 } => (gamma * x1.dot(x2) + coef0).tanh(),
1095 }
1096}