// libsvm_rs — train.rs

1//! SVM training pipeline.
2//!
3//! Provides `svm_train` which produces an `SvmModel` from an `SvmProblem`
4//! and `SvmParameter`. Matches the original LIBSVM's `svm_train` function.
5
6use crate::qmatrix::{OneClassQ, SvcQ, SvrQ};
7use crate::solver::{SolutionInfo, Solver, SolverVariant};
8use crate::types::*;
9use crate::util::group_classes;
10
/// Internal decision function result from one binary sub-problem.
///
/// Represents `f(x) = sum_i alpha_i K(x_i, x) - rho` for a single
/// two-class (or one-class/SVR) solve.
struct DecisionFunction {
    // Signed coefficients, one per instance of the sub-problem
    // (sign already folded in by the solve_* dispatchers where applicable).
    alpha: Vec<f64>,
    // Bias (offset) term of the decision function.
    rho: f64,
}
16
/// Map real-valued labels to ±1: strictly positive becomes `1`,
/// zero or negative becomes `-1`.
fn sign_labels(labels: &[f64]) -> Vec<i8> {
    let mut signs = Vec::with_capacity(labels.len());
    for &label in labels {
        signs.push(if label > 0.0 { 1i8 } else { -1i8 });
    }
    signs
}
23
24// ─── Solve dispatchers ──────────────────────────────────────────────
25
26fn solve_c_svc(
27    x: &[Vec<SvmNode>],
28    labels: &[f64],
29    param: &SvmParameter,
30    cp: f64,
31    cn: f64,
32) -> (Vec<f64>, SolutionInfo) {
33    let l = x.len();
34    let mut alpha = vec![0.0; l];
35    let p: Vec<f64> = vec![-1.0; l];
36    let y = sign_labels(labels);
37
38    let q = Box::new(SvcQ::new(x, param, &y));
39    let si = Solver::solve(
40        SolverVariant::Standard,
41        l,
42        q,
43        &p,
44        &y,
45        &mut alpha,
46        cp,
47        cn,
48        param.eps,
49        param.shrinking,
50    );
51
52    // Multiply alpha by y to get signed coefficients
53    for i in 0..l {
54        alpha[i] *= y[i] as f64;
55    }
56
57    (alpha, si)
58}
59
/// Solve the nu-SVC dual for one binary sub-problem.
///
/// After solving with the Nu solver variant, the solution is rescaled by
/// `r` so it matches an equivalent C-SVC solution with `C = 1/r`
/// (LIBSVM's standard nu-SVC post-processing).
fn solve_nu_svc(
    x: &[Vec<SvmNode>],
    labels: &[f64],
    param: &SvmParameter,
) -> (Vec<f64>, SolutionInfo) {
    let l = x.len();
    let nu = param.nu;
    let y = sign_labels(labels);

    // Initialize alpha: spread nu*l/2 among positive and negative samples,
    // capping each alpha at 1.0 (the box constraint for the Nu solver here).
    let mut alpha = vec![0.0; l];
    let mut sum_pos = nu * l as f64 / 2.0;
    let mut sum_neg = nu * l as f64 / 2.0;
    for i in 0..l {
        if y[i] == 1 {
            alpha[i] = f64::min(1.0, sum_pos);
            sum_pos -= alpha[i];
        } else {
            alpha[i] = f64::min(1.0, sum_neg);
            sum_neg -= alpha[i];
        }
    }

    // Zero linear term: nu-SVC has no -1 vector in the dual objective.
    let p = vec![0.0; l];
    let q = Box::new(SvcQ::new(x, param, &y));
    let mut si = Solver::solve(
        SolverVariant::Nu,
        l,
        q,
        &p,
        &y,
        &mut alpha,
        1.0,
        1.0,
        param.eps,
        param.shrinking,
    );

    // Rescale everything by r: signed alpha, rho, objective, and the
    // effective upper bounds become 1/r.
    let r = si.r;
    for i in 0..l {
        alpha[i] *= y[i] as f64 / r;
    }
    si.rho /= r;
    si.obj /= r * r;
    si.upper_bound_p = 1.0 / r;
    si.upper_bound_n = 1.0 / r;

    (alpha, si)
}
109
110fn solve_one_class(x: &[Vec<SvmNode>], param: &SvmParameter) -> (Vec<f64>, SolutionInfo) {
111    let l = x.len();
112
113    // Initialize alpha: first n=floor(nu*l) at 1, fractional remainder, rest 0
114    let n = (param.nu * l as f64) as usize;
115    let mut alpha = vec![0.0; l];
116    for a in alpha.iter_mut().take(n.min(l)) {
117        *a = 1.0;
118    }
119    if n < l {
120        alpha[n] = param.nu * l as f64 - n as f64;
121    }
122
123    let p = vec![0.0; l];
124    let y = vec![1i8; l];
125    let q = Box::new(OneClassQ::new(x, param));
126    let si = Solver::solve(
127        SolverVariant::Standard,
128        l,
129        q,
130        &p,
131        &y,
132        &mut alpha,
133        1.0,
134        1.0,
135        param.eps,
136        param.shrinking,
137    );
138
139    (alpha, si)
140}
141
142fn solve_epsilon_svr(
143    x: &[Vec<SvmNode>],
144    labels: &[f64],
145    param: &SvmParameter,
146) -> (Vec<f64>, SolutionInfo) {
147    let l = x.len();
148    let mut alpha2 = vec![0.0; 2 * l];
149    let mut linear_term = vec![0.0; 2 * l];
150    let mut y = vec![0i8; 2 * l];
151
152    for i in 0..l {
153        linear_term[i] = param.p - labels[i];
154        y[i] = 1;
155        linear_term[i + l] = param.p + labels[i];
156        y[i + l] = -1;
157    }
158
159    let q = Box::new(SvrQ::new(x, param));
160    let si = Solver::solve(
161        SolverVariant::Standard,
162        2 * l,
163        q,
164        &linear_term,
165        &y,
166        &mut alpha2,
167        param.c,
168        param.c,
169        param.eps,
170        param.shrinking,
171    );
172
173    let mut alpha = vec![0.0; l];
174    for i in 0..l {
175        alpha[i] = alpha2[i] - alpha2[i + l];
176    }
177
178    (alpha, si)
179}
180
181fn solve_nu_svr(
182    x: &[Vec<SvmNode>],
183    labels: &[f64],
184    param: &SvmParameter,
185) -> (Vec<f64>, SolutionInfo) {
186    let l = x.len();
187    let c = param.c;
188    let mut alpha2 = vec![0.0; 2 * l];
189    let mut linear_term = vec![0.0; 2 * l];
190    let mut y = vec![0i8; 2 * l];
191
192    let mut sum = c * param.nu * l as f64 / 2.0;
193    for i in 0..l {
194        let a = f64::min(sum, c);
195        alpha2[i] = a;
196        alpha2[i + l] = a;
197        sum -= a;
198
199        linear_term[i] = -labels[i];
200        y[i] = 1;
201        linear_term[i + l] = labels[i];
202        y[i + l] = -1;
203    }
204
205    let q = Box::new(SvrQ::new(x, param));
206    let si = Solver::solve(
207        SolverVariant::Nu,
208        2 * l,
209        q,
210        &linear_term,
211        &y,
212        &mut alpha2,
213        c,
214        c,
215        param.eps,
216        param.shrinking,
217    );
218
219    let mut alpha = vec![0.0; l];
220    for i in 0..l {
221        alpha[i] = alpha2[i] - alpha2[i + l];
222    }
223
224    (alpha, si)
225}
226
227// ─── svm_train_one ──────────────────────────────────────────────────
228
229fn svm_train_one(
230    x: &[Vec<SvmNode>],
231    labels: &[f64],
232    param: &SvmParameter,
233    cp: f64,
234    cn: f64,
235) -> DecisionFunction {
236    let (alpha, si) = match param.svm_type {
237        SvmType::CSvc => solve_c_svc(x, labels, param, cp, cn),
238        SvmType::NuSvc => solve_nu_svc(x, labels, param),
239        SvmType::OneClass => solve_one_class(x, param),
240        SvmType::EpsilonSvr => solve_epsilon_svr(x, labels, param),
241        SvmType::NuSvr => solve_nu_svr(x, labels, param),
242    };
243
244    crate::info(&format!("obj = {:.6}, rho = {:.6}\n", si.obj, si.rho));
245
246    // Count SVs
247    let n_sv = alpha.iter().filter(|a| a.abs() > 0.0).count();
248    let n_bsv = alpha
249        .iter()
250        .enumerate()
251        .filter(|&(i, a)| {
252            if a.abs() > 0.0 {
253                if labels[i] > 0.0 {
254                    a.abs() >= si.upper_bound_p
255                } else {
256                    a.abs() >= si.upper_bound_n
257                }
258            } else {
259                false
260            }
261        })
262        .count();
263    crate::info(&format!("nSV = {}, nBSV = {}\n", n_sv, n_bsv));
264
265    DecisionFunction { alpha, rho: si.rho }
266}
267
/// Set `nonzero[start + k] = true` for every `alphas[k]` that is non-zero.
/// Entries already `true` are left untouched; zero alphas never clear a flag.
fn mark_nonzero_indices(nonzero: &mut [bool], start: usize, alphas: &[f64]) {
    for (offset, &alpha) in alphas.iter().enumerate() {
        if alpha.abs() > 0.0 {
            nonzero[start + offset] = true;
        }
    }
}
276
/// Number of `true` entries in `nonzero[start..start + len]`.
fn count_nonzero(nonzero: &[bool], start: usize, len: usize) -> usize {
    let mut total = 0;
    for &flag in &nonzero[start..start + len] {
        if flag {
            total += 1;
        }
    }
    total
}
283
284// ─── svm_train ──────────────────────────────────────────────────────
285
286/// Train an SVM model from a problem and parameters.
287///
288/// Matches LIBSVM's `svm_train` function. Produces an `SvmModel` that
289/// can be used for prediction or saved to a file.
290pub fn svm_train(problem: &SvmProblem, param: &SvmParameter) -> SvmModel {
291    // Compute effective gamma if zero
292    let mut param = param.clone();
293    if param.gamma == 0.0 && !problem.instances.is_empty() {
294        let max_index = problem
295            .instances
296            .iter()
297            .flat_map(|inst| inst.iter())
298            .map(|n| n.index)
299            .max()
300            .unwrap_or(0);
301        if max_index > 0 {
302            param.gamma = 1.0 / max_index as f64;
303        }
304    }
305
306    match param.svm_type {
307        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
308            train_regression_or_one_class(problem, &param)
309        }
310        SvmType::CSvc | SvmType::NuSvc => train_classification(problem, &param),
311    }
312}
313
/// Train a one-class, epsilon-SVR, or nu-SVR model (a single decision
/// function, no class structure).
///
/// The returned model always reports `nr_class = 2` with empty
/// `label`/`n_sv` vectors, matching LIBSVM's convention for these types.
fn train_regression_or_one_class(problem: &SvmProblem, param: &SvmParameter) -> SvmModel {
    // cp/cn are unused by these solver paths; 0.0 are placeholders.
    let f = svm_train_one(&problem.instances, &problem.labels, param, 0.0, 0.0);

    // Extract support vectors: keep only instances with non-zero alpha.
    let mut sv = Vec::new();
    let mut sv_coef = Vec::new();
    let mut sv_indices = Vec::new();

    for i in 0..problem.instances.len() {
        if f.alpha[i].abs() > 0.0 {
            sv.push(problem.instances[i].clone());
            sv_coef.push(f.alpha[i]);
            sv_indices.push(i + 1); // 1-based
        }
    }

    let mut model = SvmModel {
        param: param.clone(),
        nr_class: 2,
        sv,
        // Single coefficient row: there is only one decision function.
        sv_coef: vec![sv_coef],
        rho: vec![f.rho],
        prob_a: Vec::new(),
        prob_b: Vec::new(),
        prob_density_marks: Vec::new(),
        sv_indices,
        label: Vec::new(),
        n_sv: Vec::new(),
    };

    // Probability estimates (optional, computed after the model is built
    // because the one-class path needs the trained model).
    if param.probability {
        match param.svm_type {
            SvmType::EpsilonSvr | SvmType::NuSvr => {
                model.prob_a = vec![crate::probability::svm_svr_probability(problem, param)];
            }
            SvmType::OneClass => {
                if let Some(marks) = crate::probability::svm_one_class_probability(problem, &model)
                {
                    model.prob_density_marks = marks;
                }
            }
            _ => {}
        }
    }

    model
}
362
/// Train a (possibly multi-class) C-SVC or nu-SVC model using the
/// one-vs-one scheme: one binary classifier per class pair.
///
/// Instances are first grouped/permuted by class; the pairwise alphas are
/// then packed into the `(nr_class - 1) x total_sv` `sv_coef` matrix in
/// LIBSVM's layout.
fn train_classification(problem: &SvmProblem, param: &SvmParameter) -> SvmModel {
    let l = problem.instances.len();
    let group = group_classes(&problem.labels);
    let nr_class = group.label.len();

    if nr_class == 1 {
        crate::info("WARNING: training data in only one class. See README for details.\n");
    }

    // Reorder instances by class (x[i] is the i-th instance in grouped order;
    // group.perm maps grouped position -> original index).
    let x: Vec<&Vec<SvmNode>> = (0..l).map(|i| &problem.instances[group.perm[i]]).collect();

    // Calculate weighted C: per-class penalty = param.c * weight for that label.
    let mut weighted_c = vec![param.c; nr_class];
    for &(wlabel, wval) in &param.weight {
        if let Some(j) = group.label.iter().position(|&lab| lab == wlabel) {
            weighted_c[j] *= wval;
        } else {
            crate::info(&format!(
                "WARNING: class label {} specified in weight is not found\n",
                wlabel
            ));
        }
    }

    // Train k*(k-1)/2 binary classifiers
    let mut nonzero = vec![false; l];
    let n_pairs = nr_class * (nr_class - 1) / 2;
    let mut decisions = Vec::with_capacity(n_pairs);

    // Probability arrays (filled only when param.probability is set)
    let mut prob_a = Vec::new();
    let mut prob_b = Vec::new();
    if param.probability {
        prob_a.reserve(n_pairs);
        prob_b.reserve(n_pairs);
    }

    for i in 0..nr_class {
        for j in (i + 1)..nr_class {
            let si = group.start[i];
            let sj = group.start[j];
            let ci = group.count[i];
            let cj = group.count[j];

            // Build sub-problem: class i relabeled +1, class j relabeled -1.
            let mut sub_x = Vec::with_capacity(ci + cj);
            let mut sub_labels = Vec::with_capacity(ci + cj);
            for k in 0..ci {
                sub_x.push(x[si + k].clone());
                sub_labels.push(1.0);
            }
            for k in 0..cj {
                sub_x.push(x[sj + k].clone());
                sub_labels.push(-1.0);
            }

            // Probability estimates via internal 5-fold CV (before final training)
            if param.probability {
                let sub_prob = SvmProblem {
                    labels: sub_labels.clone(),
                    instances: sub_x.clone(),
                };
                let (pa, pb) = crate::probability::svm_binary_svc_probability(
                    &sub_prob,
                    param,
                    weighted_c[i],
                    weighted_c[j],
                );
                prob_a.push(pa);
                prob_b.push(pb);
            }

            let f = svm_train_one(&sub_x, &sub_labels, param, weighted_c[i], weighted_c[j]);

            // Mark nonzero alphas: the first ci alphas belong to class i's
            // instances, the remaining cj to class j's.
            mark_nonzero_indices(&mut nonzero, si, &f.alpha[..ci]);
            mark_nonzero_indices(&mut nonzero, sj, &f.alpha[ci..(ci + cj)]);

            decisions.push(f);
        }
    }

    // Build model output
    let labels: Vec<i32> = group.label.clone();
    let rho: Vec<f64> = decisions.iter().map(|d| d.rho).collect();

    // Count SVs per class (an instance is an SV if any pairwise classifier
    // gave it a nonzero alpha)
    let mut total_sv = 0;
    let mut n_sv_per_class = vec![0usize; nr_class];
    for (i, n_sv) in n_sv_per_class.iter_mut().enumerate().take(nr_class) {
        let n = count_nonzero(&nonzero, group.start[i], group.count[i]);
        total_sv += n;
        *n_sv = n;
    }

    crate::info(&format!("Total nSV = {}\n", total_sv));

    // Collect SVs and indices
    let mut model_sv = Vec::with_capacity(total_sv);
    let mut model_sv_indices = Vec::with_capacity(total_sv);
    for i in 0..l {
        if nonzero[i] {
            model_sv.push(x[i].clone());
            model_sv_indices.push(group.perm[i] + 1); // 1-based original index
        }
    }

    // Build nz_start (cumulative start of nonzero SVs per class)
    let mut nz_start = vec![0usize; nr_class];
    for i in 1..nr_class {
        nz_start[i] = nz_start[i - 1] + n_sv_per_class[i - 1];
    }

    // Build sv_coef matrix: (nr_class - 1) rows × total_sv columns
    let mut sv_coef = vec![vec![0.0; total_sv]; nr_class - 1];

    {
        // p indexes the pairwise decision functions in (i, j) loop order.
        let mut p = 0;
        for i in 0..nr_class {
            for j in (i + 1)..nr_class {
                let si = group.start[i];
                let sj = group.start[j];
                let ci = group.count[i];
                let cj = group.count[j];

                // Coefficients for class i's SVs go in sv_coef[j-1]
                let mut q = nz_start[i];
                for k in 0..ci {
                    if nonzero[si + k] {
                        sv_coef[j - 1][q] = decisions[p].alpha[k];
                        q += 1;
                    }
                }

                // Coefficients for class j's SVs go in sv_coef[i]
                q = nz_start[j];
                for k in 0..cj {
                    if nonzero[sj + k] {
                        sv_coef[i][q] = decisions[p].alpha[ci + k];
                        q += 1;
                    }
                }

                p += 1;
            }
        }
    }

    SvmModel {
        param: param.clone(),
        nr_class,
        sv: model_sv,
        sv_coef,
        rho,
        prob_a,
        prob_b,
        prob_density_marks: Vec::new(),
        sv_indices: model_sv_indices,
        label: labels,
        n_sv: n_sv_per_class,
    }
}
526
// Integration tests: train on the bundled datasets and sanity-check the
// resulting models, including comparisons against C-LIBSVM reference
// models. They require the `data` directory at the workspace root.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{load_model, load_problem};
    use crate::predict::predict;
    use std::path::PathBuf;

    // Path to the shared test-data directory (two levels above this crate).
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    #[test]
    fn train_c_svc_heart_scale() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0, // 1/num_features for heart_scale
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        // Check basic model structure
        assert_eq!(model.nr_class, 2);
        assert_eq!(model.label, vec![1, -1]);
        assert!(!model.sv.is_empty(), "model has no support vectors");

        // Compare with C reference model
        let ref_model = load_model(&data_dir().join("heart_scale_ref.model")).unwrap();

        // Same number of SVs (within tolerance — solver iterations may vary slightly)
        let sv_diff = (model.sv.len() as i64 - ref_model.sv.len() as i64).unsigned_abs();
        assert!(
            sv_diff <= 2,
            "SV count mismatch: Rust={}, C={}",
            model.sv.len(),
            ref_model.sv.len()
        );

        // Same rho (within tolerance)
        assert!(
            (model.rho[0] - ref_model.rho[0]).abs() < 1e-4,
            "rho mismatch: Rust={}, C={}",
            model.rho[0],
            ref_model.rho[0]
        );

        // Predictions should match on training data
        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.85,
            "training accuracy {:.2}% too low",
            accuracy * 100.0
        );

        // Predictions from Rust-trained model should match C-trained model
        let mut mismatches = 0;
        for instance in &problem.instances {
            let rust_pred = predict(&model, instance);
            let c_pred = predict(&ref_model, instance);
            if rust_pred != c_pred {
                mismatches += 1;
            }
        }
        assert!(
            mismatches <= 3,
            "{} prediction mismatches between Rust-trained and C-trained models",
            mismatches
        );
    }

    #[test]
    fn train_c_svc_iris_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 0.25, // 1/num_features = 1/4
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        // Iris has 3 classes
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.label.len(), 3);
        // 3 class pairs = 3 rho values
        assert_eq!(model.rho.len(), 3);
        // sv_coef has nr_class-1 = 2 rows
        assert_eq!(model.sv_coef.len(), 2);
        // n_sv has 3 entries
        assert_eq!(model.n_sv.len(), 3);

        // Predict on training set — should be very accurate for iris
        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.95,
            "iris accuracy {:.2}% too low (expected >95%)",
            accuracy * 100.0
        );
    }

    #[test]
    fn train_c_svc_precomputed_kernel() {
        let problem = load_problem(&data_dir().join("heart_scale.precomputed")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Precomputed,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty(), "model has no support vectors");

        // Sanity-check predictions on training data.
        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "precomputed-kernel accuracy {:.2}% too low",
            accuracy * 100.0
        );
    }

    #[test]
    fn train_one_class() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::OneClass,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            nu: 0.5,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty());
        assert_eq!(model.rho.len(), 1);

        // Predict — most training points should be classified as +1 (inlier)
        let mut inliers = 0;
        for instance in &problem.instances {
            let pred = predict(&model, instance);
            if pred > 0.0 {
                inliers += 1;
            }
        }
        let inlier_rate = inliers as f64 / problem.instances.len() as f64;
        // With nu=0.5, roughly half should be inliers (nu is upper bound on fraction of outliers)
        assert!(
            inlier_rate > 0.3 && inlier_rate < 0.9,
            "unexpected inlier rate: {:.2}%",
            inlier_rate * 100.0
        );
    }

    #[test]
    fn train_epsilon_svr() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2); // SVR always has nr_class=2
        assert!(!model.sv.is_empty());

        // Compute MSE on training set — should be reasonable
        let mut mse = 0.0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            let err = pred - problem.labels[i];
            mse += err * err;
        }
        mse /= problem.instances.len() as f64;

        // MSE should be finite and reasonable
        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 100.0, "MSE too high: {}", mse);
    }

    #[test]
    fn train_nu_svc() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            nu: 0.5,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty());

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "nu-SVC accuracy {:.2}% too low",
            accuracy * 100.0
        );
    }

    #[test]
    fn train_csvc_with_probability() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert_eq!(model.prob_a.len(), 1, "binary should have 1 probA");
        assert_eq!(model.prob_b.len(), 1, "binary should have 1 probB");
        assert!(model.prob_a[0].is_finite());
        assert!(model.prob_b[0].is_finite());
    }

    #[test]
    fn train_nu_svr() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::NuSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            nu: 0.5,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let model = svm_train(&problem, &param);

        assert_eq!(model.nr_class, 2);
        assert!(!model.sv.is_empty());

        let mut mse = 0.0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            let err = pred - problem.labels[i];
            mse += err * err;
        }
        mse /= problem.instances.len() as f64;

        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 200.0, "MSE too high: {}", mse);
    }
}
849}