Skip to main content

libsvm_rs/
types.rs

/// The SVM formulation to solve.
///
/// Every discriminant equals the corresponding integer constant from the
/// original LIBSVM header `svm.h`
/// (`C_SVC=0, NU_SVC=1, ONE_CLASS=2, EPSILON_SVR=3, NU_SVR=4`), so
/// `variant as i32` yields the LIBSVM code.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SvmType {
    /// C-Support Vector Classification (LIBSVM `C_SVC`).
    CSvc = 0,
    /// ν-Support Vector Classification (LIBSVM `NU_SVC`).
    NuSvc = 1,
    /// One-class SVM, used for distribution estimation / novelty detection
    /// (LIBSVM `ONE_CLASS`).
    OneClass = 2,
    /// ε-Support Vector Regression (LIBSVM `EPSILON_SVR`).
    EpsilonSvr = 3,
    /// ν-Support Vector Regression (LIBSVM `NU_SVR`).
    NuSvr = 4,
}
19
/// The kernel function `K(x, y)` used by the SVM.
///
/// Every discriminant equals the corresponding integer constant from the
/// original LIBSVM header `svm.h`
/// (`LINEAR=0, POLY=1, RBF=2, SIGMOID=3, PRECOMPUTED=4`), so
/// `variant as i32` yields the LIBSVM code.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Linear kernel: `K(x,y) = x·y`.
    Linear = 0,
    /// Polynomial kernel: `K(x,y) = (γ·x·y + coef0)^degree`.
    Polynomial = 1,
    /// Radial basis function kernel: `K(x,y) = exp(-γ·‖x-y‖²)`.
    Rbf = 2,
    /// Sigmoid kernel: `K(x,y) = tanh(γ·x·y + coef0)`.
    Sigmoid = 3,
    /// The kernel values are supplied as a precomputed matrix instead of
    /// being evaluated from feature vectors.
    Precomputed = 4,
}
38
/// One entry of a sparse feature vector, written `index:value` in LIBSVM's
/// text format.
///
/// The original LIBSVM terminates every instance with a sentinel node whose
/// `index` is `-1`. This Rust port stores instances in a `Vec`, whose
/// `Vec::len()` already gives the instance length, so no sentinel node is
/// stored.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Feature index, counted from 1. Kept as `i32` to mirror the C `int`
    /// of the original and to stay compatible with the file format.
    pub index: i32,
    /// Value of the feature at `index`.
    pub value: f64,
}
52
53/// A training/test problem: a collection of labelled sparse instances.
54#[derive(Debug, Clone, PartialEq)]
55pub struct SvmProblem {
56    /// Label (class for classification, target for regression) per instance.
57    pub labels: Vec<f64>,
58    /// Sparse feature vectors, one per instance.
59    pub instances: Vec<Vec<SvmNode>>,
60}
61
62/// SVM parameters controlling the formulation, kernel, and solver.
63///
64/// Default values match the original LIBSVM defaults.
65#[derive(Debug, Clone, PartialEq)]
66pub struct SvmParameter {
67    /// SVM formulation type.
68    pub svm_type: SvmType,
69    /// Kernel function type.
70    pub kernel_type: KernelType,
71    /// Degree for polynomial kernel.
72    pub degree: i32,
73    /// γ parameter for RBF, polynomial, and sigmoid kernels.
74    /// Set to `1/num_features` when 0.
75    pub gamma: f64,
76    /// Independent term in polynomial and sigmoid kernels.
77    pub coef0: f64,
78    /// Cache memory size in MB.
79    pub cache_size: f64,
80    /// Stopping tolerance for the solver.
81    pub eps: f64,
82    /// Cost parameter C (for C-SVC, ε-SVR, ν-SVR).
83    pub c: f64,
84    /// Per-class weight overrides: `(class_label, weight)` pairs.
85    pub weight: Vec<(i32, f64)>,
86    /// ν parameter (for ν-SVC, one-class SVM, ν-SVR).
87    pub nu: f64,
88    /// ε in the ε-insensitive loss function (ε-SVR).
89    pub p: f64,
90    /// Whether to use the shrinking heuristic.
91    pub shrinking: bool,
92    /// Whether to train for probability estimates.
93    pub probability: bool,
94}
95
96impl Default for SvmParameter {
97    fn default() -> Self {
98        Self {
99            svm_type: SvmType::CSvc,
100            kernel_type: KernelType::Rbf,
101            degree: 3,
102            gamma: 0.0, // means 1/num_features
103            coef0: 0.0,
104            cache_size: 100.0,
105            eps: 0.001,
106            c: 1.0,
107            weight: Vec::new(),
108            nu: 0.5,
109            p: 0.1,
110            shrinking: true,
111            probability: false,
112        }
113    }
114}
115
116impl SvmParameter {
117    /// Validate parameter values (independent of training data).
118    ///
119    /// This checks the same constraints as the original LIBSVM's
120    /// `svm_check_parameter`, except for the ν-SVC feasibility check
121    /// which requires the problem. Use [`check_parameter`] for the full check.
122    pub fn validate(&self) -> Result<(), crate::error::SvmError> {
123        use crate::error::SvmError;
124
125        // gamma must be non-negative for kernels that use it
126        if matches!(
127            self.kernel_type,
128            KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
129        ) && self.gamma < 0.0
130        {
131            return Err(SvmError::InvalidParameter("gamma < 0".into()));
132        }
133
134        // polynomial degree must be non-negative
135        if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
136            return Err(SvmError::InvalidParameter(
137                "degree of polynomial kernel < 0".into(),
138            ));
139        }
140
141        if self.cache_size <= 0.0 {
142            return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
143        }
144
145        if self.eps <= 0.0 {
146            return Err(SvmError::InvalidParameter("eps <= 0".into()));
147        }
148
149        // C > 0 for formulations that use it
150        if matches!(
151            self.svm_type,
152            SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
153        ) && self.c <= 0.0
154        {
155            return Err(SvmError::InvalidParameter("C <= 0".into()));
156        }
157
158        // nu ∈ (0, 1] for formulations that use it
159        if matches!(
160            self.svm_type,
161            SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
162        ) && (self.nu <= 0.0 || self.nu > 1.0)
163        {
164            return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
165        }
166
167        // p >= 0 for epsilon-SVR
168        if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
169            return Err(SvmError::InvalidParameter("p < 0".into()));
170        }
171
172        Ok(())
173    }
174}
175
176/// Full parameter check including ν-SVC feasibility against training data.
177///
178/// Matches the original LIBSVM `svm_check_parameter()`.
179pub fn check_parameter(
180    problem: &SvmProblem,
181    param: &SvmParameter,
182) -> Result<(), crate::error::SvmError> {
183    use crate::error::SvmError;
184
185    // Run the data-independent checks first
186    param.validate()?;
187
188    // ν-SVC feasibility: for every pair of classes (i, j),
189    // nu * (count_i + count_j) / 2 must be <= min(count_i, count_j)
190    //
191    // Note: LIBSVM casts labels to int for class grouping. We match this
192    // behavior. Classification labels must be integers (non-integer labels
193    // will be truncated, matching `(int)prob->y[i]` in the C code).
194    if param.svm_type == SvmType::NuSvc {
195        let mut class_counts: Vec<(i32, usize)> = Vec::new();
196        for &y in &problem.labels {
197            let label = y as i32;
198            if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
199                entry.1 += 1;
200            } else {
201                class_counts.push((label, 1));
202            }
203        }
204
205        for (i, &(_, n1)) in class_counts.iter().enumerate() {
206            for &(_, n2) in &class_counts[i + 1..] {
207                if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
208                    return Err(SvmError::InvalidParameter(
209                        "specified nu is infeasible".into(),
210                    ));
211                }
212            }
213        }
214    }
215
216    Ok(())
217}
218
219/// A trained SVM model.
220///
221/// Produced by training, or loaded from a LIBSVM model file.
222#[derive(Debug, Clone, PartialEq)]
223pub struct SvmModel {
224    /// Parameters used during training.
225    pub param: SvmParameter,
226    /// Number of classes (2 for binary, >2 for multiclass, 2 for regression).
227    pub nr_class: usize,
228    /// Support vectors (sparse feature vectors).
229    pub sv: Vec<Vec<SvmNode>>,
230    /// Support vector coefficients. For k classes, this is a
231    /// `(k-1) × num_sv` matrix stored as `Vec<Vec<f64>>`.
232    pub sv_coef: Vec<Vec<f64>>,
233    /// Bias terms (rho). One per class pair: `k*(k-1)/2` values.
234    pub rho: Vec<f64>,
235    /// Pairwise probability parameter A (Platt scaling). Empty if not trained
236    /// with probability estimates.
237    pub prob_a: Vec<f64>,
238    /// Pairwise probability parameter B (Platt scaling). Empty if not trained
239    /// with probability estimates.
240    pub prob_b: Vec<f64>,
241    /// Probability density marks (for one-class SVM).
242    pub prob_density_marks: Vec<f64>,
243    /// Original indices of support vectors in the training set (1-based).
244    pub sv_indices: Vec<usize>,
245    /// Class labels (in the order used internally).
246    pub label: Vec<i32>,
247    /// Number of support vectors per class.
248    pub n_sv: Vec<usize>,
249}
250
251impl SvmModel {
252    /// Return the SVM type used by the model.
253    pub fn svm_type(&self) -> SvmType {
254        self.param.svm_type
255    }
256
257    /// Return number of classes.
258    pub fn class_count(&self) -> usize {
259        self.nr_class
260    }
261
262    /// Return class labels in internal one-vs-one order.
263    pub fn labels(&self) -> &[i32] {
264        &self.label
265    }
266
267    /// Return original 1-based support-vector indices.
268    pub fn support_vector_indices(&self) -> &[usize] {
269        &self.sv_indices
270    }
271
272    /// Return total number of support vectors.
273    pub fn support_vector_count(&self) -> usize {
274        self.sv.len()
275    }
276
277    /// Return SVR sigma when a probability-capable SVR model is available.
278    pub fn svr_probability(&self) -> Option<f64> {
279        match self.param.svm_type {
280            SvmType::EpsilonSvr | SvmType::NuSvr => self.prob_a.first().copied(),
281            _ => None,
282        }
283    }
284
285    /// Check whether the model contains probability metadata.
286    pub fn has_probability_model(&self) -> bool {
287        match self.param.svm_type {
288            SvmType::CSvc | SvmType::NuSvc => !self.prob_a.is_empty() && !self.prob_b.is_empty(),
289            SvmType::EpsilonSvr | SvmType::NuSvr => !self.prob_a.is_empty(),
290            SvmType::OneClass => !self.prob_density_marks.is_empty(),
291        }
292    }
293}
294
295/// C-API style helper matching LIBSVM's `svm_get_svm_type`.
296pub fn svm_get_svm_type(model: &SvmModel) -> SvmType {
297    model.svm_type()
298}
299
300/// C-API style helper matching LIBSVM's `svm_get_nr_class`.
301pub fn svm_get_nr_class(model: &SvmModel) -> usize {
302    model.class_count()
303}
304
305/// C-API style helper matching LIBSVM's `svm_get_labels`.
306pub fn svm_get_labels(model: &SvmModel) -> &[i32] {
307    model.labels()
308}
309
310/// C-API style helper matching LIBSVM's `svm_get_sv_indices`.
311pub fn svm_get_sv_indices(model: &SvmModel) -> &[usize] {
312    model.support_vector_indices()
313}
314
315/// C-API style helper matching LIBSVM's `svm_get_nr_sv`.
316pub fn svm_get_nr_sv(model: &SvmModel) -> usize {
317    model.support_vector_count()
318}
319
320/// C-API style helper matching LIBSVM's `svm_get_svr_probability`.
321pub fn svm_get_svr_probability(model: &SvmModel) -> Option<f64> {
322    model.svr_probability()
323}
324
325/// C-API style helper matching LIBSVM's `svm_check_probability_model`.
326pub fn svm_check_probability_model(model: &SvmModel) -> bool {
327    model.has_probability_model()
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::svm_train;
    use std::path::PathBuf;

    /// Directory holding the sample data sets: `../../data` relative to the
    /// crate manifest directory.
    fn data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("..");
        dir.push("..");
        dir.push("data");
        dir
    }

    /// Default parameters switched to ν-SVC with the given `nu`.
    fn nu_svc(nu: f64) -> SvmParameter {
        SvmParameter {
            svm_type: SvmType::NuSvc,
            nu,
            ..Default::default()
        }
    }

    /// Default parameters switched to the given formulation.
    fn params_for(svm_type: SvmType) -> SvmParameter {
        SvmParameter {
            svm_type,
            ..Default::default()
        }
    }

    /// A problem with the given labels and one empty instance per label.
    fn labelled_problem(labels: Vec<f64>) -> SvmProblem {
        let instances = vec![Vec::new(); labels.len()];
        SvmProblem { labels, instances }
    }

    #[test]
    fn default_params_are_valid() {
        let defaults = SvmParameter::default();
        defaults.validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let params = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let params = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let params = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        // Above the upper bound of (0, 1].
        assert!(nu_svc(1.5).validate().is_err());
        // The lower bound itself is excluded.
        assert!(nu_svc(0.0).validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let params = SvmParameter {
            p: -0.1,
            ..params_for(SvmType::EpsilonSvr)
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let params = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(params.validate().is_err());
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // 2 classes with 3 samples each: nu * (3+3)/2 <= 3  →  nu <= 1
        let problem = labelled_problem(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        check_parameter(&problem, &nu_svc(0.5)).unwrap();

        // nu = 0.9: 0.9 * 6/2 = 2.7 <= 3 → feasible
        check_parameter(&problem, &nu_svc(0.9)).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // 5 class-A, 1 class-B: nu*(5+1)/2 > min(5,1)=1  →  nu > 1/3
        let problem = labelled_problem(vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0]);

        // 0.5 * 6/2 = 1.5 > 1
        let outcome = check_parameter(&problem, &nu_svc(0.5));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("infeasible"));
    }

    #[test]
    fn c_api_style_model_helpers() {
        let data_file = data_dir().join("heart_scale");
        let problem = crate::io::load_problem(&data_file).unwrap();
        let param = SvmParameter {
            gamma: 1.0 / 13.0,
            ..Default::default()
        };
        let model = svm_train(&problem, &param);

        // Each free function must agree with the model it wraps.
        assert_eq!(svm_get_svm_type(&model), SvmType::CSvc);
        assert_eq!(svm_get_nr_class(&model), 2);
        assert_eq!(svm_get_nr_sv(&model), model.sv.len());
        assert_eq!(svm_get_labels(&model), model.label.as_slice());
        assert_eq!(svm_get_sv_indices(&model), model.sv_indices.as_slice());

        // Default parameters do not request probability estimates.
        assert!(!svm_check_probability_model(&model));
        assert_eq!(svm_get_svr_probability(&model), None);
    }

    #[test]
    fn probability_helpers_by_svm_type() {
        let support_vector = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];

        // Template without probability data; each model overrides the fields
        // that matter for its formulation.
        let template = SvmModel {
            param: SvmParameter::default(),
            nr_class: 2,
            sv: vec![support_vector],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: Vec::new(),
            prob_b: Vec::new(),
            prob_density_marks: Vec::new(),
            sv_indices: vec![1],
            label: Vec::new(),
            n_sv: Vec::new(),
        };

        // C-SVC needs both Platt parameters.
        let csvc_model = SvmModel {
            param: params_for(SvmType::CSvc),
            prob_a: vec![1.0],
            prob_b: vec![-0.5],
            label: vec![1, -1],
            n_sv: vec![1, 0],
            ..template.clone()
        };
        assert!(csvc_model.has_probability_model());
        assert!(svm_check_probability_model(&csvc_model));
        assert_eq!(svm_get_svr_probability(&csvc_model), None);

        // ε-SVR only needs `prob_a`, whose first entry is the SVR sigma.
        let eps_svr_model = SvmModel {
            param: params_for(SvmType::EpsilonSvr),
            sv_coef: vec![vec![0.8]],
            prob_a: vec![0.123],
            ..template.clone()
        };
        assert!(eps_svr_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&eps_svr_model), Some(0.123));

        // One-class SVM relies on the density marks.
        let one_class_model = SvmModel {
            param: params_for(SvmType::OneClass),
            prob_density_marks: vec![0.1; 10],
            ..template
        };
        assert!(one_class_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&one_class_model), None);
    }
}