1use crate::qmatrix::{OneClassQ, SvcQ, SvrQ};
7use crate::solver::{SolutionInfo, Solver, SolverVariant};
8use crate::types::*;
9use crate::util::group_classes;
10
/// One trained binary decision function: f(x) = sum_i alpha_i * K(x_i, x) - rho.
struct DecisionFunction {
    // Per-instance coefficient (already label-signed by the solve_* helpers);
    // zero for instances that are not support vectors.
    alpha: Vec<f64>,
    // Offset (bias) term of the decision function.
    rho: f64,
}
16
/// Encode real-valued class labels as the +1/-1 scheme the solver expects:
/// strictly positive labels map to +1, everything else (including 0) to -1.
fn sign_labels(labels: &[f64]) -> Vec<i8> {
    let mut signs = Vec::with_capacity(labels.len());
    for &label in labels {
        signs.push(if label > 0.0 { 1 } else { -1 });
    }
    signs
}
23
24fn solve_c_svc(
27 x: &[Vec<SvmNode>],
28 labels: &[f64],
29 param: &SvmParameter,
30 cp: f64,
31 cn: f64,
32) -> (Vec<f64>, SolutionInfo) {
33 let l = x.len();
34 let mut alpha = vec![0.0; l];
35 let p: Vec<f64> = vec![-1.0; l];
36 let y = sign_labels(labels);
37
38 let q = Box::new(SvcQ::new(x, param, &y));
39 let si = Solver::solve(
40 SolverVariant::Standard,
41 l,
42 q,
43 &p,
44 &y,
45 &mut alpha,
46 cp,
47 cn,
48 param.eps,
49 param.shrinking,
50 );
51
52 for i in 0..l {
54 alpha[i] *= y[i] as f64;
55 }
56
57 (alpha, si)
58}
59
60fn solve_nu_svc(
61 x: &[Vec<SvmNode>],
62 labels: &[f64],
63 param: &SvmParameter,
64) -> (Vec<f64>, SolutionInfo) {
65 let l = x.len();
66 let nu = param.nu;
67 let y = sign_labels(labels);
68
69 let mut alpha = vec![0.0; l];
71 let mut sum_pos = nu * l as f64 / 2.0;
72 let mut sum_neg = nu * l as f64 / 2.0;
73 for i in 0..l {
74 if y[i] == 1 {
75 alpha[i] = f64::min(1.0, sum_pos);
76 sum_pos -= alpha[i];
77 } else {
78 alpha[i] = f64::min(1.0, sum_neg);
79 sum_neg -= alpha[i];
80 }
81 }
82
83 let p = vec![0.0; l];
84 let q = Box::new(SvcQ::new(x, param, &y));
85 let mut si = Solver::solve(
86 SolverVariant::Nu,
87 l,
88 q,
89 &p,
90 &y,
91 &mut alpha,
92 1.0,
93 1.0,
94 param.eps,
95 param.shrinking,
96 );
97
98 let r = si.r;
99 for i in 0..l {
100 alpha[i] *= y[i] as f64 / r;
101 }
102 si.rho /= r;
103 si.obj /= r * r;
104 si.upper_bound_p = 1.0 / r;
105 si.upper_bound_n = 1.0 / r;
106
107 (alpha, si)
108}
109
110fn solve_one_class(x: &[Vec<SvmNode>], param: &SvmParameter) -> (Vec<f64>, SolutionInfo) {
111 let l = x.len();
112
113 let n = (param.nu * l as f64) as usize;
115 let mut alpha = vec![0.0; l];
116 for a in alpha.iter_mut().take(n.min(l)) {
117 *a = 1.0;
118 }
119 if n < l {
120 alpha[n] = param.nu * l as f64 - n as f64;
121 }
122
123 let p = vec![0.0; l];
124 let y = vec![1i8; l];
125 let q = Box::new(OneClassQ::new(x, param));
126 let si = Solver::solve(
127 SolverVariant::Standard,
128 l,
129 q,
130 &p,
131 &y,
132 &mut alpha,
133 1.0,
134 1.0,
135 param.eps,
136 param.shrinking,
137 );
138
139 (alpha, si)
140}
141
142fn solve_epsilon_svr(
143 x: &[Vec<SvmNode>],
144 labels: &[f64],
145 param: &SvmParameter,
146) -> (Vec<f64>, SolutionInfo) {
147 let l = x.len();
148 let mut alpha2 = vec![0.0; 2 * l];
149 let mut linear_term = vec![0.0; 2 * l];
150 let mut y = vec![0i8; 2 * l];
151
152 for i in 0..l {
153 linear_term[i] = param.p - labels[i];
154 y[i] = 1;
155 linear_term[i + l] = param.p + labels[i];
156 y[i + l] = -1;
157 }
158
159 let q = Box::new(SvrQ::new(x, param));
160 let si = Solver::solve(
161 SolverVariant::Standard,
162 2 * l,
163 q,
164 &linear_term,
165 &y,
166 &mut alpha2,
167 param.c,
168 param.c,
169 param.eps,
170 param.shrinking,
171 );
172
173 let mut alpha = vec![0.0; l];
174 for i in 0..l {
175 alpha[i] = alpha2[i] - alpha2[i + l];
176 }
177
178 (alpha, si)
179}
180
181fn solve_nu_svr(
182 x: &[Vec<SvmNode>],
183 labels: &[f64],
184 param: &SvmParameter,
185) -> (Vec<f64>, SolutionInfo) {
186 let l = x.len();
187 let c = param.c;
188 let mut alpha2 = vec![0.0; 2 * l];
189 let mut linear_term = vec![0.0; 2 * l];
190 let mut y = vec![0i8; 2 * l];
191
192 let mut sum = c * param.nu * l as f64 / 2.0;
193 for i in 0..l {
194 let a = f64::min(sum, c);
195 alpha2[i] = a;
196 alpha2[i + l] = a;
197 sum -= a;
198
199 linear_term[i] = -labels[i];
200 y[i] = 1;
201 linear_term[i + l] = labels[i];
202 y[i + l] = -1;
203 }
204
205 let q = Box::new(SvrQ::new(x, param));
206 let si = Solver::solve(
207 SolverVariant::Nu,
208 2 * l,
209 q,
210 &linear_term,
211 &y,
212 &mut alpha2,
213 c,
214 c,
215 param.eps,
216 param.shrinking,
217 );
218
219 let mut alpha = vec![0.0; l];
220 for i in 0..l {
221 alpha[i] = alpha2[i] - alpha2[i + l];
222 }
223
224 (alpha, si)
225}
226
227fn svm_train_one(
230 x: &[Vec<SvmNode>],
231 labels: &[f64],
232 param: &SvmParameter,
233 cp: f64,
234 cn: f64,
235) -> DecisionFunction {
236 let (alpha, si) = match param.svm_type {
237 SvmType::CSvc => solve_c_svc(x, labels, param, cp, cn),
238 SvmType::NuSvc => solve_nu_svc(x, labels, param),
239 SvmType::OneClass => solve_one_class(x, param),
240 SvmType::EpsilonSvr => solve_epsilon_svr(x, labels, param),
241 SvmType::NuSvr => solve_nu_svr(x, labels, param),
242 };
243
244 crate::info(&format!("obj = {:.6}, rho = {:.6}\n", si.obj, si.rho));
245
246 let n_sv = alpha.iter().filter(|a| a.abs() > 0.0).count();
248 let n_bsv = alpha
249 .iter()
250 .enumerate()
251 .filter(|&(i, a)| {
252 if a.abs() > 0.0 {
253 if labels[i] > 0.0 {
254 a.abs() >= si.upper_bound_p
255 } else {
256 a.abs() >= si.upper_bound_n
257 }
258 } else {
259 false
260 }
261 })
262 .count();
263 crate::info(&format!("nSV = {}, nBSV = {}\n", n_sv, n_bsv));
264
265 DecisionFunction { alpha, rho: si.rho }
266}
267
/// Set `nonzero[start + k]` for every `alphas[k]` with a non-zero value.
/// Already-set flags are left untouched.
fn mark_nonzero_indices(nonzero: &mut [bool], start: usize, alphas: &[f64]) {
    for (offset, &a) in alphas.iter().enumerate() {
        // abs() > 0.0 (rather than != 0.0) mirrors the solver convention
        // and treats NaN as zero.
        if a.abs() > 0.0 {
            nonzero[start + offset] = true;
        }
    }
}
276
/// Count how many of the `len` flags starting at `start` are set.
fn count_nonzero(nonzero: &[bool], start: usize, len: usize) -> usize {
    let mut count = 0;
    for &flag in &nonzero[start..start + len] {
        if flag {
            count += 1;
        }
    }
    count
}
283
284pub fn svm_train(problem: &SvmProblem, param: &SvmParameter) -> SvmModel {
291 let mut param = param.clone();
293 if param.gamma == 0.0 && !problem.instances.is_empty() {
294 let max_index = problem
295 .instances
296 .iter()
297 .flat_map(|inst| inst.iter())
298 .map(|n| n.index)
299 .max()
300 .unwrap_or(0);
301 if max_index > 0 {
302 param.gamma = 1.0 / max_index as f64;
303 }
304 }
305
306 match param.svm_type {
307 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
308 train_regression_or_one_class(problem, ¶m)
309 }
310 SvmType::CSvc | SvmType::NuSvc => train_classification(problem, ¶m),
311 }
312}
313
314fn train_regression_or_one_class(problem: &SvmProblem, param: &SvmParameter) -> SvmModel {
315 let f = svm_train_one(&problem.instances, &problem.labels, param, 0.0, 0.0);
316
317 let mut sv = Vec::new();
319 let mut sv_coef = Vec::new();
320 let mut sv_indices = Vec::new();
321
322 for i in 0..problem.instances.len() {
323 if f.alpha[i].abs() > 0.0 {
324 sv.push(problem.instances[i].clone());
325 sv_coef.push(f.alpha[i]);
326 sv_indices.push(i + 1); }
328 }
329
330 let mut model = SvmModel {
331 param: param.clone(),
332 nr_class: 2,
333 sv,
334 sv_coef: vec![sv_coef],
335 rho: vec![f.rho],
336 prob_a: Vec::new(),
337 prob_b: Vec::new(),
338 prob_density_marks: Vec::new(),
339 sv_indices,
340 label: Vec::new(),
341 n_sv: Vec::new(),
342 };
343
344 if param.probability {
346 match param.svm_type {
347 SvmType::EpsilonSvr | SvmType::NuSvr => {
348 model.prob_a = vec![crate::probability::svm_svr_probability(problem, param)];
349 }
350 SvmType::OneClass => {
351 if let Some(marks) = crate::probability::svm_one_class_probability(problem, &model)
352 {
353 model.prob_density_marks = marks;
354 }
355 }
356 _ => {}
357 }
358 }
359
360 model
361}
362
/// Train a (possibly multi-class) C-SVC/nu-SVC model using the one-vs-one
/// scheme: one binary decision function per pair of classes.
fn train_classification(problem: &SvmProblem, param: &SvmParameter) -> SvmModel {
    let l = problem.instances.len();
    // Group the data by class: `group.perm` maps grouped position to the
    // original instance index; `group.start`/`group.count` delimit each
    // class's contiguous block in the grouped ordering.
    let group = group_classes(&problem.labels);
    let nr_class = group.label.len();

    if nr_class == 1 {
        crate::info("WARNING: training data in only one class. See README for details.\n");
    }

    // Instances viewed in grouped (class-contiguous) order.
    let x: Vec<&Vec<SvmNode>> = (0..l).map(|i| &problem.instances[group.perm[i]]).collect();

    // Per-class penalty: base C scaled by any user-supplied class weights.
    let mut weighted_c = vec![param.c; nr_class];
    for &(wlabel, wval) in &param.weight {
        if let Some(j) = group.label.iter().position(|&lab| lab == wlabel) {
            weighted_c[j] *= wval;
        } else {
            crate::info(&format!(
                "WARNING: class label {} specified in weight is not found\n",
                wlabel
            ));
        }
    }

    // nonzero[i] becomes true once instance i (grouped order) is a support
    // vector in at least one pairwise problem.
    let mut nonzero = vec![false; l];
    let n_pairs = nr_class * (nr_class - 1) / 2;
    let mut decisions = Vec::with_capacity(n_pairs);

    let mut prob_a = Vec::new();
    let mut prob_b = Vec::new();
    if param.probability {
        prob_a.reserve(n_pairs);
        prob_b.reserve(n_pairs);
    }

    // Solve one binary problem per class pair (i, j), i < j.
    for i in 0..nr_class {
        for j in (i + 1)..nr_class {
            let si = group.start[i];
            let sj = group.start[j];
            let ci = group.count[i];
            let cj = group.count[j];

            // Build the sub-problem: class i relabelled +1, class j -> -1.
            let mut sub_x = Vec::with_capacity(ci + cj);
            let mut sub_labels = Vec::with_capacity(ci + cj);
            for k in 0..ci {
                sub_x.push(x[si + k].clone());
                sub_labels.push(1.0);
            }
            for k in 0..cj {
                sub_x.push(x[sj + k].clone());
                sub_labels.push(-1.0);
            }

            if param.probability {
                // Fit the per-pair probability (sigmoid) parameters.
                let sub_prob = SvmProblem {
                    labels: sub_labels.clone(),
                    instances: sub_x.clone(),
                };
                let (pa, pb) = crate::probability::svm_binary_svc_probability(
                    &sub_prob,
                    param,
                    weighted_c[i],
                    weighted_c[j],
                );
                prob_a.push(pa);
                prob_b.push(pb);
            }

            let f = svm_train_one(&sub_x, &sub_labels, param, weighted_c[i], weighted_c[j]);

            // The first ci alphas correspond to class i's block, the
            // remaining cj to class j's block.
            mark_nonzero_indices(&mut nonzero, si, &f.alpha[..ci]);
            mark_nonzero_indices(&mut nonzero, sj, &f.alpha[ci..(ci + cj)]);

            decisions.push(f);
        }
    }

    let labels: Vec<i32> = group.label.clone();
    let rho: Vec<f64> = decisions.iter().map(|d| d.rho).collect();

    // Count support vectors per class and in total.
    let mut total_sv = 0;
    let mut n_sv_per_class = vec![0usize; nr_class];
    for (i, n_sv) in n_sv_per_class.iter_mut().enumerate().take(nr_class) {
        let n = count_nonzero(&nonzero, group.start[i], group.count[i]);
        total_sv += n;
        *n_sv = n;
    }

    crate::info(&format!("Total nSV = {}\n", total_sv));

    // Pack the support vectors (grouped order) and record their original,
    // 1-based instance indices.
    let mut model_sv = Vec::with_capacity(total_sv);
    let mut model_sv_indices = Vec::with_capacity(total_sv);
    for i in 0..l {
        if nonzero[i] {
            model_sv.push(x[i].clone());
            model_sv_indices.push(group.perm[i] + 1);
        }
    }

    // nz_start[i]: offset of class i's first SV within the packed SV list.
    let mut nz_start = vec![0usize; nr_class];
    for i in 1..nr_class {
        nz_start[i] = nz_start[i - 1] + n_sv_per_class[i - 1];
    }

    // sv_coef[k][s]: coefficient of packed SV s in its k-th associated
    // decision function (each SV can appear in nr_class - 1 of them).
    let mut sv_coef = vec![vec![0.0; total_sv]; nr_class - 1];

    {
        let mut p = 0; // index of the (i, j) decision function in `decisions`
        for i in 0..nr_class {
            for j in (i + 1)..nr_class {
                let si = group.start[i];
                let sj = group.start[j];
                let ci = group.count[i];
                let cj = group.count[j];

                // Coefficients of class i's SVs for the (i, j) classifier
                // go in row j - 1, packed from class i's nz_start offset.
                let mut q = nz_start[i];
                for k in 0..ci {
                    if nonzero[si + k] {
                        sv_coef[j - 1][q] = decisions[p].alpha[k];
                        q += 1;
                    }
                }

                // Coefficients of class j's SVs go in row i, packed from
                // class j's nz_start offset.
                q = nz_start[j];
                for k in 0..cj {
                    if nonzero[sj + k] {
                        sv_coef[i][q] = decisions[p].alpha[ci + k];
                        q += 1;
                    }
                }

                p += 1;
            }
        }
    }

    SvmModel {
        param: param.clone(),
        nr_class,
        sv: model_sv,
        sv_coef,
        rho,
        prob_a,
        prob_b,
        prob_density_marks: Vec::new(),
        sv_indices: model_sv_indices,
        label: labels,
        n_sv: n_sv_per_class,
    }
}
526
#[cfg(test)]
mod tests {
    //! End-to-end training tests on the bundled datasets. Some tests
    //! compare against reference models produced by the C implementation
    //! (loaded from the shared `data` directory).
    use super::*;
    use crate::io::{load_model, load_problem};
    use crate::predict::predict;
    use std::path::PathBuf;

    // Path to the shared test-data directory two levels above this crate.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    // Binary C-SVC on heart_scale, checked against the C-trained model.
    #[test]
    fn train_c_svc_heart_scale() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert_eq!(model.label, vec![1, -1]);
        assert!(!model.sv.is_empty(), "model has no support vectors");

        let ref_model = load_model(&data_dir().join("heart_scale_ref.model")).unwrap();

        // SV counts may differ slightly between implementations due to
        // floating-point rounding in the solver.
        let sv_diff = (model.sv.len() as i64 - ref_model.sv.len() as i64).unsigned_abs();
        assert!(
            sv_diff <= 2,
            "SV count mismatch: Rust={}, C={}",
            model.sv.len(),
            ref_model.sv.len()
        );

        assert!(
            (model.rho[0] - ref_model.rho[0]).abs() < 1e-4,
            "rho mismatch: Rust={}, C={}",
            model.rho[0],
            ref_model.rho[0]
        );

        // Training-set accuracy sanity check.
        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.85,
            "training accuracy {:.2}% too low",
            accuracy * 100.0
        );

        // Rust- and C-trained models should agree on almost every point.
        let mut mismatches = 0;
        for instance in &problem.instances {
            let rust_pred = predict(&model, instance);
            let c_pred = predict(&ref_model, instance);
            if rust_pred != c_pred {
                mismatches += 1;
            }
        }
        assert!(
            mismatches <= 3,
            "{} prediction mismatches between Rust-trained and C-trained models",
            mismatches
        );
    }

    // Three-class C-SVC (one-vs-one) on the iris dataset.
    #[test]
    fn train_c_svc_iris_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        // 3 classes => 3 pairwise rho values and nr_class - 1 coef rows.
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.label.len(), 3);
        assert_eq!(model.rho.len(), 3);
        assert_eq!(model.sv_coef.len(), 2);
        assert_eq!(model.n_sv.len(), 3);

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.95,
            "iris accuracy {:.2}% too low (expected >95%)",
            accuracy * 100.0
        );
    }

    // C-SVC with a user-supplied (precomputed) kernel matrix.
    #[test]
    fn train_c_svc_precomputed_kernel() {
        let problem = load_problem(&data_dir().join("heart_scale.precomputed")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Precomputed,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty(), "model has no support vectors");

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "precomputed-kernel accuracy {:.2}% too low",
            accuracy * 100.0
        );
    }

    // One-class SVM: with nu = 0.5 a moderate fraction of the training
    // data should be classified as inliers.
    #[test]
    fn train_one_class() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            nu: 0.5,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty());
        assert_eq!(model.rho.len(), 1);

        // Positive decision values mark inliers.
        let mut inliers = 0;
        for instance in &problem.instances {
            let pred = predict(&model, instance);
            if pred > 0.0 {
                inliers += 1;
            }
        }
        let inlier_rate = inliers as f64 / problem.instances.len() as f64;
        assert!(
            inlier_rate > 0.3 && inlier_rate < 0.9,
            "unexpected inlier rate: {:.2}%",
            inlier_rate * 100.0
        );
    }

    // Epsilon-SVR on the housing data: training MSE should be bounded.
    #[test]
    fn train_epsilon_svr() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty());

        let mut mse = 0.0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            let err = pred - problem.labels[i];
            mse += err * err;
        }
        mse /= problem.instances.len() as f64;

        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 100.0, "MSE too high: {}", mse);
    }

    // Binary nu-SVC on heart_scale: basic accuracy sanity check.
    #[test]
    fn train_nu_svc() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            nu: 0.5,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty());

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "nu-SVC accuracy {:.2}% too low",
            accuracy * 100.0
        );
    }

    // C-SVC with probability estimates: a binary model should carry one
    // finite (probA, probB) pair.
    #[test]
    fn train_csvc_with_probability() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert_eq!(model.prob_a.len(), 1, "binary should have 1 probA");
        assert_eq!(model.prob_b.len(), 1, "binary should have 1 probB");
        assert!(model.prob_a[0].is_finite());
        assert!(model.prob_b[0].is_finite());
    }

    // nu-SVR on the housing data: training MSE should be bounded.
    #[test]
    fn train_nu_svr() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::NuSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            nu: 0.5,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty());

        let mut mse = 0.0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            let err = pred - problem.labels[i];
            mse += err * err;
        }
        mse /= problem.instances.len() as f64;

        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 200.0, "MSE too high: {}", mse);
    }
}