Skip to main content

libsvm_rs/
types.rs

1//! Core LIBSVM-compatible data structures.
2//!
3//! These types intentionally mirror LIBSVM concepts: sparse nodes, problem
4//! rows, solver parameters, and trained model fields. Values produced by
5//! [`crate::io::load_problem`] and [`crate::io::load_model`] have passed the
6//! loader's text-format and resource-bound checks. Values constructed manually
7//! by a caller are not automatically checked; call [`check_parameter`] before
8//! training, and prefer the loader APIs for external model/problem files.
9
/// Type of SVM formulation.
///
/// Matches the integer constants in the original LIBSVM (`svm.h`):
/// `C_SVC=0, NU_SVC=1, ONE_CLASS=2, EPSILON_SVR=3, NU_SVR=4`.
///
/// The discriminants are pinned by `#[repr(i32)]`, so `ty as i32` yields the
/// LIBSVM code directly. With the `serde` feature, values are (de)serialized
/// as that integer code, and unknown codes are rejected on deserialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum SvmType {
    /// C-Support Vector Classification (LIBSVM code 0).
    CSvc = 0,
    /// ν-Support Vector Classification (LIBSVM code 1).
    NuSvc = 1,
    /// One-class SVM (distribution estimation / novelty detection), code 2.
    OneClass = 2,
    /// ε-Support Vector Regression (LIBSVM code 3).
    EpsilonSvr = 3,
    /// ν-Support Vector Regression (LIBSVM code 4).
    NuSvr = 4,
}
28
/// Type of kernel function.
///
/// Matches the integer constants in the original LIBSVM (`svm.h`):
/// `LINEAR=0, POLY=1, RBF=2, SIGMOID=3, PRECOMPUTED=4`.
///
/// The discriminants are pinned by `#[repr(i32)]`, so `kernel as i32` yields
/// the LIBSVM code directly. With the `serde` feature, values are
/// (de)serialized as that integer code, and unknown codes are rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum KernelType {
    /// `K(x,y) = x·y`
    Linear = 0,
    /// `K(x,y) = (γ·x·y + coef0)^degree`
    Polynomial = 1,
    /// `K(x,y) = exp(-γ·‖x-y‖²)`
    Rbf = 2,
    /// `K(x,y) = tanh(γ·x·y + coef0)`
    Sigmoid = 3,
    /// Kernel values supplied as a precomputed matrix. Each problem row must
    /// then start with a `0:sample_serial_number` node (see [`check_parameter`]).
    Precomputed = 4,
}
47
/// A single sparse feature: `index:value`.
///
/// In the original LIBSVM, a sentinel node with `index = -1` marks the end
/// of each instance. In this Rust port, instance length is tracked by
/// `Vec::len()` instead, so no sentinel is needed.
///
/// For [`KernelType::Precomputed`] problems, [`check_parameter`] requires the
/// first node of every row to be `0:serial`, where `serial` is an integer row
/// number in `[1, n]`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SvmNode {
    /// Feature index. Uses `i32` to match the original C `int` and
    /// preserve file-format compatibility. Regular features are 1-based;
    /// index `0` is reserved for the precomputed-kernel serial number.
    pub index: i32,
    /// Feature value.
    pub value: f64,
}
62
/// A training/test problem: a collection of labelled sparse instances.
///
/// `load_problem` validates that sparse feature indices are ascending and
/// within the configured [`crate::io::LoadOptions`] bounds. When constructing a
/// problem manually, keep `labels.len() == instances.len()` and use ascending
/// feature indices to match LIBSVM input assumptions.
///
/// [`check_parameter`] rejects problems whose `labels` and `instances` lengths
/// differ or that contain no instances. It does not check feature-index
/// ordering.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmProblem {
    /// Label (class for classification, target for regression) per instance.
    /// For the ν-SVC feasibility check, labels are grouped after truncation to
    /// `i32`, so classification labels should be integral.
    pub labels: Vec<f64>,
    /// Sparse feature vectors, one per instance, in the same order as `labels`.
    pub instances: Vec<Vec<SvmNode>>,
}
76
/// SVM parameters controlling the formulation, kernel, and solver.
///
/// Default values match the original LIBSVM defaults (see the `Default`
/// impl). Which fields matter depends on `svm_type` and `kernel_type`.
/// [`SvmParameter::validate`] range-checks only the fields that the selected
/// formulation and kernel use.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SvmParameter {
    /// SVM formulation type.
    pub svm_type: SvmType,
    /// Kernel function type.
    pub kernel_type: KernelType,
    /// Degree for polynomial kernel. Must be `>= 0` when `kernel_type` is
    /// [`KernelType::Polynomial`].
    pub degree: i32,
    /// γ parameter for RBF, polynomial, and sigmoid kernels. Must be `>= 0`
    /// for those kernels.
    ///
    /// `0.0` (the default) follows the LIBSVM convention of meaning
    /// `1/num_features`. NOTE(review): that substitution is not performed in
    /// this file; confirm which caller (trainer or CLI) applies it.
    pub gamma: f64,
    /// Independent term in polynomial and sigmoid kernels.
    pub coef0: f64,
    /// Cache memory size in MB. Must be `> 0`.
    pub cache_size: f64,
    /// Stopping tolerance for the solver. Must be `> 0`.
    pub eps: f64,
    /// Cost parameter C (for C-SVC, ε-SVR, ν-SVR). Must be `> 0` for those
    /// formulations.
    pub c: f64,
    /// Per-class weight overrides: `(class_label, weight)` pairs. Empty by
    /// default and not range-checked by `validate`.
    pub weight: Vec<(i32, f64)>,
    /// ν parameter (for ν-SVC, one-class SVM, ν-SVR). Must lie in `(0, 1]`
    /// for those formulations.
    pub nu: f64,
    /// ε in the ε-insensitive loss function (ε-SVR). Must be `>= 0` for ε-SVR.
    pub p: f64,
    /// Whether to use the shrinking heuristic.
    pub shrinking: bool,
    /// Whether to train for probability estimates.
    pub probability: bool,
}
111
112impl Default for SvmParameter {
113    fn default() -> Self {
114        Self {
115            svm_type: SvmType::CSvc,
116            kernel_type: KernelType::Rbf,
117            degree: 3,
118            gamma: 0.0, // means 1/num_features
119            coef0: 0.0,
120            cache_size: 100.0,
121            eps: 0.001,
122            c: 1.0,
123            weight: Vec::new(),
124            nu: 0.5,
125            p: 0.1,
126            shrinking: true,
127            probability: false,
128        }
129    }
130}
131
132impl SvmParameter {
133    /// Validate parameter values (independent of training data).
134    ///
135    /// This checks the same constraints as the original LIBSVM's
136    /// `svm_check_parameter`, except for the ν-SVC feasibility check
137    /// which requires the problem. Use [`check_parameter`] for the full check.
138    pub fn validate(&self) -> Result<(), crate::error::SvmError> {
139        use crate::error::SvmError;
140
141        // gamma must be non-negative for kernels that use it
142        if matches!(
143            self.kernel_type,
144            KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
145        ) && self.gamma < 0.0
146        {
147            return Err(SvmError::InvalidParameter("gamma < 0".into()));
148        }
149
150        // polynomial degree must be non-negative
151        if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
152            return Err(SvmError::InvalidParameter(
153                "degree of polynomial kernel < 0".into(),
154            ));
155        }
156
157        if self.cache_size <= 0.0 {
158            return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
159        }
160
161        if self.eps <= 0.0 {
162            return Err(SvmError::InvalidParameter("eps <= 0".into()));
163        }
164
165        // C > 0 for formulations that use it
166        if matches!(
167            self.svm_type,
168            SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
169        ) && self.c <= 0.0
170        {
171            return Err(SvmError::InvalidParameter("C <= 0".into()));
172        }
173
174        // nu ∈ (0, 1] for formulations that use it
175        if matches!(
176            self.svm_type,
177            SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
178        ) && (self.nu <= 0.0 || self.nu > 1.0)
179        {
180            return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
181        }
182
183        // p >= 0 for epsilon-SVR
184        if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
185            return Err(SvmError::InvalidParameter("p < 0".into()));
186        }
187
188        Ok(())
189    }
190}
191
192/// Full parameter check including ν-SVC feasibility against training data.
193///
194/// Matches the original LIBSVM `svm_check_parameter()`.
195pub fn check_parameter(
196    problem: &SvmProblem,
197    param: &SvmParameter,
198) -> Result<(), crate::error::SvmError> {
199    use crate::error::SvmError;
200
201    // Run the data-independent checks first
202    param.validate()?;
203
204    if problem.labels.len() != problem.instances.len() {
205        return Err(SvmError::InvalidParameter(format!(
206            "labels length ({}) does not match instance length ({})",
207            problem.labels.len(),
208            problem.instances.len()
209        )));
210    }
211
212    if problem.labels.is_empty() {
213        return Err(SvmError::InvalidParameter(
214            "problem has no instances".into(),
215        ));
216    }
217
218    if param.kernel_type == KernelType::Precomputed {
219        let upper = problem.instances.len() as f64;
220        for (row, instance) in problem.instances.iter().enumerate() {
221            let first = instance.first().ok_or_else(|| {
222                SvmError::InvalidParameter(format!(
223                    "precomputed kernel row {} is missing 0:sample_serial_number",
224                    row + 1
225                ))
226            })?;
227            if first.index != 0
228                || !first.value.is_finite()
229                || first.value < 1.0
230                || first.value > upper
231                || first.value.fract() != 0.0
232            {
233                return Err(SvmError::InvalidParameter(format!(
234                    "precomputed kernel row {} must start with 0:sample_serial_number in [1, {}]",
235                    row + 1,
236                    problem.instances.len()
237                )));
238            }
239        }
240    }
241
242    // ν-SVC feasibility: for every pair of classes (i, j),
243    // nu * (count_i + count_j) / 2 must be <= min(count_i, count_j)
244    //
245    // Note: LIBSVM casts labels to int for class grouping. We match this
246    // behavior. Classification labels must be integers (non-integer labels
247    // will be truncated, matching `(int)prob->y[i]` in the C code).
248    if param.svm_type == SvmType::NuSvc {
249        let mut class_counts: Vec<(i32, usize)> = Vec::new();
250        for &y in &problem.labels {
251            let label = y as i32;
252            if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
253                entry.1 += 1;
254            } else {
255                class_counts.push((label, 1));
256            }
257        }
258
259        for (i, &(_, n1)) in class_counts.iter().enumerate() {
260            for &(_, n2) in &class_counts[i + 1..] {
261                if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
262                    return Err(SvmError::InvalidParameter(
263                        "specified nu is infeasible".into(),
264                    ));
265                }
266            }
267        }
268    }
269
270    Ok(())
271}
272
/// A trained SVM model.
///
/// Produced by training, or loaded from a LIBSVM model file.
///
/// `load_model` validates model-file shape contracts before returning this
/// type: class counts, support-vector counts, decision-function arrays,
/// optional probability metadata, and sparse support-vector rows must be
/// internally consistent. Manually constructed values bypass those checks, so
/// callers that accept external model text should prefer
/// [`crate::io::load_model`] or [`crate::io::load_model_from_reader`].
///
/// With the `serde` feature enabled, `Serialize` is derived. `Deserialize` is
/// implemented by hand so that deserialized values also pass through
/// `crate::io::validate_model` before being returned.
#[cfg_attr(feature = "serde", derive(serde::Serialize))]
#[derive(Debug, Clone, PartialEq)]
pub struct SvmModel {
    /// Parameters used during training.
    pub param: SvmParameter,
    /// Number of classes (2 for binary, >2 for multiclass, 2 for regression).
    pub nr_class: usize,
    /// Support vectors (sparse feature vectors). Its length is the total
    /// support-vector count reported by `support_vector_count()`.
    pub sv: Vec<Vec<SvmNode>>,
    /// Support vector coefficients. For k classes, this is a
    /// `(k-1) × num_sv` matrix stored as `Vec<Vec<f64>>`.
    pub sv_coef: Vec<Vec<f64>>,
    /// Bias terms (rho). One per class pair: `k*(k-1)/2` values.
    pub rho: Vec<f64>,
    /// Pairwise probability parameter A (Platt scaling). Empty if not trained
    /// with probability estimates. For ε-SVR/ν-SVR models, `prob_a[0]` is the
    /// value returned by `svr_probability()`.
    pub prob_a: Vec<f64>,
    /// Pairwise probability parameter B (Platt scaling). Empty if not trained
    /// with probability estimates.
    pub prob_b: Vec<f64>,
    /// Probability density marks (for one-class SVM). Empty if absent.
    /// `has_probability_model()` consults this for one-class models.
    pub prob_density_marks: Vec<f64>,
    /// Original indices of support vectors in the training set (1-based).
    pub sv_indices: Vec<usize>,
    /// Class labels (in the order used internally).
    /// NOTE(review): likely empty for regression/one-class models, as in
    /// LIBSVM; confirm against the trainer.
    pub label: Vec<i32>,
    /// Number of support vectors per class (same order as `label`).
    pub n_sv: Vec<usize>,
}
312
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl serde::Serialize for SvmType {
    /// Encode the formulation as its LIBSVM integer code (`C_SVC = 0`, …).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let code: i32 = *self as i32;
        serde::Serialize::serialize(&code, serializer)
    }
}
323
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> serde::Deserialize<'de> for SvmType {
    /// Decode a LIBSVM integer code (0..=4), rejecting any other value.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let code = <i32 as serde::Deserialize>::deserialize(deserializer)?;
        let svm_type = match code {
            0 => SvmType::CSvc,
            1 => SvmType::NuSvc,
            2 => SvmType::OneClass,
            3 => SvmType::EpsilonSvr,
            4 => SvmType::NuSvr,
            other => {
                return Err(serde::de::Error::custom(format!(
                    "unknown SvmType code {other}"
                )));
            }
        };
        Ok(svm_type)
    }
}
343
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl serde::Serialize for KernelType {
    /// Encode the kernel as its LIBSVM integer code (`LINEAR = 0`, …).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let code: i32 = *self as i32;
        serde::Serialize::serialize(&code, serializer)
    }
}
354
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> serde::Deserialize<'de> for KernelType {
    /// Decode a LIBSVM kernel code (0..=4), rejecting any other value.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let code = <i32 as serde::Deserialize>::deserialize(deserializer)?;
        let kernel = match code {
            0 => KernelType::Linear,
            1 => KernelType::Polynomial,
            2 => KernelType::Rbf,
            3 => KernelType::Sigmoid,
            4 => KernelType::Precomputed,
            other => {
                return Err(serde::de::Error::custom(format!(
                    "unknown KernelType code {other}"
                )));
            }
        };
        Ok(kernel)
    }
}
374
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> serde::Deserialize<'de> for SvmModel {
    /// Deserialize a model, then run `crate::io::validate_model` on it.
    ///
    /// The raw fields are read into an unvalidated mirror struct first. The
    /// assembled `SvmModel` is returned only if it passes the same shape
    /// checks used by the model loader; otherwise the validation error is
    /// reported as a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Field-for-field mirror of `SvmModel`, deserialized without validation.
        //
        // `rename` makes the container name match the one emitted by the derived
        // `Serialize` on `SvmModel`. Without it, formats that record or check
        // struct names (e.g. RON with struct names enabled) cannot round-trip a
        // model, and error messages would mention this private helper's name.
        #[derive(serde::Deserialize)]
        #[serde(rename = "SvmModel")]
        struct RawSvmModel {
            param: SvmParameter,
            nr_class: usize,
            sv: Vec<Vec<SvmNode>>,
            sv_coef: Vec<Vec<f64>>,
            rho: Vec<f64>,
            prob_a: Vec<f64>,
            prob_b: Vec<f64>,
            prob_density_marks: Vec<f64>,
            sv_indices: Vec<usize>,
            label: Vec<i32>,
            n_sv: Vec<usize>,
        }

        let RawSvmModel {
            param,
            nr_class,
            sv,
            sv_coef,
            rho,
            prob_a,
            prob_b,
            prob_density_marks,
            sv_indices,
            label,
            n_sv,
        } = <RawSvmModel as serde::Deserialize>::deserialize(deserializer)?;

        let model = SvmModel {
            param,
            nr_class,
            sv,
            sv_coef,
            rho,
            prob_a,
            prob_b,
            prob_density_marks,
            sv_indices,
            label,
            n_sv,
        };
        crate::io::validate_model(&model).map_err(serde::de::Error::custom)?;
        Ok(model)
    }
}
415
416impl SvmModel {
417    /// Return the SVM type used by the model.
418    pub fn svm_type(&self) -> SvmType {
419        self.param.svm_type
420    }
421
422    /// Return number of classes.
423    pub fn class_count(&self) -> usize {
424        self.nr_class
425    }
426
427    /// Return class labels in internal one-vs-one order.
428    pub fn labels(&self) -> &[i32] {
429        &self.label
430    }
431
432    /// Return original 1-based support-vector indices.
433    pub fn support_vector_indices(&self) -> &[usize] {
434        &self.sv_indices
435    }
436
437    /// Return total number of support vectors.
438    pub fn support_vector_count(&self) -> usize {
439        self.sv.len()
440    }
441
442    /// Return SVR sigma when a probability-capable SVR model is available.
443    pub fn svr_probability(&self) -> Option<f64> {
444        match self.param.svm_type {
445            SvmType::EpsilonSvr | SvmType::NuSvr => self.prob_a.first().copied(),
446            _ => None,
447        }
448    }
449
450    /// Check whether the model contains probability metadata.
451    pub fn has_probability_model(&self) -> bool {
452        match self.param.svm_type {
453            SvmType::CSvc | SvmType::NuSvc => !self.prob_a.is_empty() && !self.prob_b.is_empty(),
454            SvmType::EpsilonSvr | SvmType::NuSvr => !self.prob_a.is_empty(),
455            SvmType::OneClass => !self.prob_density_marks.is_empty(),
456        }
457    }
458}
459
460/// C-API style helper matching LIBSVM's `svm_get_svm_type`.
461pub fn svm_get_svm_type(model: &SvmModel) -> SvmType {
462    model.svm_type()
463}
464
465/// C-API style helper matching LIBSVM's `svm_get_nr_class`.
466pub fn svm_get_nr_class(model: &SvmModel) -> usize {
467    model.class_count()
468}
469
470/// C-API style helper matching LIBSVM's `svm_get_labels`.
471pub fn svm_get_labels(model: &SvmModel) -> &[i32] {
472    model.labels()
473}
474
475/// C-API style helper matching LIBSVM's `svm_get_sv_indices`.
476pub fn svm_get_sv_indices(model: &SvmModel) -> &[usize] {
477    model.support_vector_indices()
478}
479
480/// C-API style helper matching LIBSVM's `svm_get_nr_sv`.
481pub fn svm_get_nr_sv(model: &SvmModel) -> usize {
482    model.support_vector_count()
483}
484
485/// C-API style helper matching LIBSVM's `svm_get_svr_probability`.
486pub fn svm_get_svr_probability(model: &SvmModel) -> Option<f64> {
487    model.svr_probability()
488}
489
490/// C-API style helper matching LIBSVM's `svm_check_probability_model`.
491pub fn svm_check_probability_model(model: &SvmModel) -> bool {
492    model.has_probability_model()
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::svm_train;
    use std::path::PathBuf;

    /// Location of the shared LIBSVM sample data (e.g. `heart_scale`),
    /// two directories above this crate's manifest.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// Shorthand for building a single sparse node.
    fn node(index: i32, value: f64) -> SvmNode {
        SvmNode { index, value }
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(p.validate().is_err());

        let p2 = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(p2.validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn check_parameter_rejects_empty_problem() {
        let problem = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(format!("{}", err).contains("problem has no instances"));
    }

    #[test]
    fn check_parameter_rejects_label_instance_length_mismatch() {
        let problem = SvmProblem {
            labels: vec![1.0],
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(format!("{}", err).contains("does not match instance length"));
    }

    #[test]
    fn check_parameter_rejects_precomputed_rows_without_sample_serial_number() {
        let problem = SvmProblem {
            labels: vec![1.0, -1.0],
            instances: vec![vec![], vec![node(0, 2.0)]],
        };
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(format!("{}", err).contains("missing 0:sample_serial_number"));
    }

    #[test]
    fn check_parameter_rejects_invalid_precomputed_serial_numbers() {
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        // Each case corrupts only the first row's head; the second row is valid.
        let bad_heads = [
            node(0, 3.0),      // above the number of rows
            node(0, 0.0),      // below 1
            node(0, 1.5),      // not an integer
            node(0, f64::NAN), // not finite
            node(1, 1.0),      // serial stored under the wrong index
        ];
        for &head in &bad_heads {
            let problem = SvmProblem {
                labels: vec![1.0, -1.0],
                instances: vec![vec![head], vec![node(0, 2.0)]],
            };
            let err = check_parameter(&problem, &param).unwrap_err();
            let msg = format!("{}", err);
            assert!(
                msg.contains("row 1 must start with 0:sample_serial_number in [1, 2]"),
                "unexpected error for {:?}: {}",
                head,
                msg
            );
        }
    }

    #[test]
    fn check_parameter_accepts_valid_precomputed_rows() {
        let problem = SvmProblem {
            labels: vec![1.0, -1.0],
            instances: vec![
                vec![node(0, 1.0), node(1, 4.0), node(2, 1.0)],
                vec![node(0, 2.0), node(1, 1.0), node(2, 9.0)],
            ],
        };
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        check_parameter(&problem, &param).unwrap();
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // 2 classes with 3 samples each: nu * (3+3)/2 <= 3  →  nu <= 1
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();

        // nu = 0.9: 0.9 * 6/2 = 2.7 <= 3 → feasible
        let near_boundary = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        check_parameter(&problem, &near_boundary).unwrap();

        // nu = 1.0: 1.0 * 6/2 = 3 <= 3 → exactly on the boundary, still feasible
        // (the check is a strict `>`).
        let exact_boundary = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.0,
            ..Default::default()
        };
        check_parameter(&problem, &exact_boundary).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // 5 class-A, 1 class-B: nu*(5+1)/2 > min(5,1)=1  →  nu > 1/3
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5, // 0.5 * 6/2 = 1.5 > 1
            ..Default::default()
        };
        let err = check_parameter(&problem, &param);
        assert!(err.is_err());
        assert!(format!("{}", err.unwrap_err()).contains("infeasible"));
    }

    #[test]
    fn nu_svc_feasibility_groups_labels_by_integer_truncation() {
        // 1.2 and 1.7 both truncate to class 1, so the classes are {1: 2, 2: 1}:
        // 0.9 * (2+1)/2 = 1.35 > min(2,1) = 1 → infeasible. If the raw f64
        // labels stayed distinct, every pair would be 1-vs-1 and 0.9 would pass.
        let problem = SvmProblem {
            labels: vec![1.2, 1.7, 2.0],
            instances: vec![vec![]; 3],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(format!("{}", err).contains("infeasible"));
    }

    #[test]
    fn c_api_style_model_helpers() {
        let problem = crate::io::load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            gamma: 1.0 / 13.0,
            ..Default::default()
        };
        let model = svm_train(&problem, &param);

        assert_eq!(svm_get_svm_type(&model), SvmType::CSvc);
        assert_eq!(svm_get_nr_class(&model), 2);
        assert_eq!(svm_get_nr_sv(&model), model.sv.len());
        assert_eq!(svm_get_labels(&model), model.label.as_slice());
        assert_eq!(svm_get_sv_indices(&model), model.sv_indices.as_slice());
        assert!(!svm_check_probability_model(&model));
        assert_eq!(svm_get_svr_probability(&model), None);
    }

    #[test]
    fn probability_helpers_by_svm_type() {
        // A single one-feature support vector shared by all fixtures below.
        let sv_row = vec![node(1, 1.0)];

        let csvc_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv_row.clone()],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![1.0],
            prob_b: vec![-0.5],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };
        assert!(csvc_model.has_probability_model());
        assert!(svm_check_probability_model(&csvc_model));
        assert_eq!(svm_get_svr_probability(&csvc_model), None);

        let eps_svr_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::EpsilonSvr,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv_row.clone()],
            sv_coef: vec![vec![0.8]],
            rho: vec![0.0],
            prob_a: vec![0.123],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(eps_svr_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&eps_svr_model), Some(0.123));

        let one_class_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::OneClass,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv_row],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![0.1; 10],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(one_class_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&one_class_model), None);
    }
}