Skip to main content

libsvm_rs/
types.rs

1//! Core LIBSVM-compatible data structures.
2//!
3//! These types intentionally mirror LIBSVM concepts: sparse nodes, problem
4//! rows, solver parameters, and trained model fields. Values produced by
5//! [`crate::io::load_problem`] and [`crate::io::load_model`] have passed the
6//! loader's text-format and resource-bound checks. Values constructed manually
7//! by a caller are not automatically checked; call [`check_parameter`] before
8//! training, and prefer the loader APIs for external model/problem files.
9
/// The SVM formulation to train or evaluate.
///
/// Each discriminant equals the corresponding `svm_type` constant in
/// LIBSVM's `svm.h` (`C_SVC=0, NU_SVC=1, ONE_CLASS=2, EPSILON_SVR=3,
/// NU_SVR=4`), so `ty as i32` yields the value the C library uses.
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SvmType {
    /// C-SVC: soft-margin classification regularized by `C`.
    CSvc = 0,
    /// ν-SVC: classification parameterized by `ν`.
    NuSvc = 1,
    /// One-class SVM: distribution estimation / novelty detection.
    OneClass = 2,
    /// ε-SVR: regression with an ε-insensitive loss.
    EpsilonSvr = 3,
    /// ν-SVR: regression where `ν` takes the place of `ε`.
    NuSvr = 4,
}
28
/// The kernel function `K(x, y)` used to compare two instances.
///
/// Each discriminant equals the corresponding `kernel_type` constant in
/// LIBSVM's `svm.h` (`LINEAR=0, POLY=1, RBF=2, SIGMOID=3, PRECOMPUTED=4`).
#[repr(i32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Plain dot product: `K(x,y) = x·y`
    Linear = 0,
    /// Polynomial: `K(x,y) = (γ·x·y + coef0)^degree`
    Polynomial = 1,
    /// Gaussian radial basis function: `K(x,y) = exp(-γ·‖x-y‖²)`
    Rbf = 2,
    /// Hyperbolic tangent: `K(x,y) = tanh(γ·x·y + coef0)`
    Sigmoid = 3,
    /// The caller supplies kernel values directly as a precomputed matrix.
    Precomputed = 4,
}
47
/// One `index:value` entry of a sparse feature vector.
///
/// LIBSVM's C API ends every instance with a sentinel node whose `index` is
/// `-1`. This port stores each instance in a `Vec`, whose length already
/// marks the end, so no sentinel is stored.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SvmNode {
    /// Feature index, 1-based for ordinary features (precomputed-kernel rows
    /// use index 0 for the sample serial number). Kept as `i32`, the C `int`,
    /// for file-format compatibility with LIBSVM.
    pub index: i32,
    /// Value of the feature at `index`.
    pub value: f64,
}
61
62/// A training/test problem: a collection of labelled sparse instances.
63///
64/// `load_problem` validates that sparse feature indices are ascending and
65/// within the configured [`crate::io::LoadOptions`] bounds. When constructing a
66/// problem manually, keep `labels.len() == instances.len()` and use ascending
67/// feature indices to match LIBSVM input assumptions.
68#[derive(Debug, Clone, PartialEq)]
69pub struct SvmProblem {
70    /// Label (class for classification, target for regression) per instance.
71    pub labels: Vec<f64>,
72    /// Sparse feature vectors, one per instance.
73    pub instances: Vec<Vec<SvmNode>>,
74}
75
76/// SVM parameters controlling the formulation, kernel, and solver.
77///
78/// Default values match the original LIBSVM defaults.
79#[derive(Debug, Clone, PartialEq)]
80pub struct SvmParameter {
81    /// SVM formulation type.
82    pub svm_type: SvmType,
83    /// Kernel function type.
84    pub kernel_type: KernelType,
85    /// Degree for polynomial kernel.
86    pub degree: i32,
87    /// γ parameter for RBF, polynomial, and sigmoid kernels.
88    /// Set to `1/num_features` when 0.
89    pub gamma: f64,
90    /// Independent term in polynomial and sigmoid kernels.
91    pub coef0: f64,
92    /// Cache memory size in MB.
93    pub cache_size: f64,
94    /// Stopping tolerance for the solver.
95    pub eps: f64,
96    /// Cost parameter C (for C-SVC, ε-SVR, ν-SVR).
97    pub c: f64,
98    /// Per-class weight overrides: `(class_label, weight)` pairs.
99    pub weight: Vec<(i32, f64)>,
100    /// ν parameter (for ν-SVC, one-class SVM, ν-SVR).
101    pub nu: f64,
102    /// ε in the ε-insensitive loss function (ε-SVR).
103    pub p: f64,
104    /// Whether to use the shrinking heuristic.
105    pub shrinking: bool,
106    /// Whether to train for probability estimates.
107    pub probability: bool,
108}
109
110impl Default for SvmParameter {
111    fn default() -> Self {
112        Self {
113            svm_type: SvmType::CSvc,
114            kernel_type: KernelType::Rbf,
115            degree: 3,
116            gamma: 0.0, // means 1/num_features
117            coef0: 0.0,
118            cache_size: 100.0,
119            eps: 0.001,
120            c: 1.0,
121            weight: Vec::new(),
122            nu: 0.5,
123            p: 0.1,
124            shrinking: true,
125            probability: false,
126        }
127    }
128}
129
130impl SvmParameter {
131    /// Validate parameter values (independent of training data).
132    ///
133    /// This checks the same constraints as the original LIBSVM's
134    /// `svm_check_parameter`, except for the ν-SVC feasibility check
135    /// which requires the problem. Use [`check_parameter`] for the full check.
136    pub fn validate(&self) -> Result<(), crate::error::SvmError> {
137        use crate::error::SvmError;
138
139        // gamma must be non-negative for kernels that use it
140        if matches!(
141            self.kernel_type,
142            KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
143        ) && self.gamma < 0.0
144        {
145            return Err(SvmError::InvalidParameter("gamma < 0".into()));
146        }
147
148        // polynomial degree must be non-negative
149        if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
150            return Err(SvmError::InvalidParameter(
151                "degree of polynomial kernel < 0".into(),
152            ));
153        }
154
155        if self.cache_size <= 0.0 {
156            return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
157        }
158
159        if self.eps <= 0.0 {
160            return Err(SvmError::InvalidParameter("eps <= 0".into()));
161        }
162
163        // C > 0 for formulations that use it
164        if matches!(
165            self.svm_type,
166            SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
167        ) && self.c <= 0.0
168        {
169            return Err(SvmError::InvalidParameter("C <= 0".into()));
170        }
171
172        // nu ∈ (0, 1] for formulations that use it
173        if matches!(
174            self.svm_type,
175            SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
176        ) && (self.nu <= 0.0 || self.nu > 1.0)
177        {
178            return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
179        }
180
181        // p >= 0 for epsilon-SVR
182        if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
183            return Err(SvmError::InvalidParameter("p < 0".into()));
184        }
185
186        Ok(())
187    }
188}
189
190/// Full parameter check including ν-SVC feasibility against training data.
191///
192/// Matches the original LIBSVM `svm_check_parameter()`.
193pub fn check_parameter(
194    problem: &SvmProblem,
195    param: &SvmParameter,
196) -> Result<(), crate::error::SvmError> {
197    use crate::error::SvmError;
198
199    // Run the data-independent checks first
200    param.validate()?;
201
202    if problem.labels.len() != problem.instances.len() {
203        return Err(SvmError::InvalidParameter(format!(
204            "labels length ({}) does not match instance length ({})",
205            problem.labels.len(),
206            problem.instances.len()
207        )));
208    }
209
210    if problem.labels.is_empty() {
211        return Err(SvmError::InvalidParameter(
212            "problem has no instances".into(),
213        ));
214    }
215
216    if param.kernel_type == KernelType::Precomputed {
217        let upper = problem.instances.len() as f64;
218        for (row, instance) in problem.instances.iter().enumerate() {
219            let first = instance.first().ok_or_else(|| {
220                SvmError::InvalidParameter(format!(
221                    "precomputed kernel row {} is missing 0:sample_serial_number",
222                    row + 1
223                ))
224            })?;
225            if first.index != 0
226                || !first.value.is_finite()
227                || first.value < 1.0
228                || first.value > upper
229                || first.value.fract() != 0.0
230            {
231                return Err(SvmError::InvalidParameter(format!(
232                    "precomputed kernel row {} must start with 0:sample_serial_number in [1, {}]",
233                    row + 1,
234                    problem.instances.len()
235                )));
236            }
237        }
238    }
239
240    // ν-SVC feasibility: for every pair of classes (i, j),
241    // nu * (count_i + count_j) / 2 must be <= min(count_i, count_j)
242    //
243    // Note: LIBSVM casts labels to int for class grouping. We match this
244    // behavior. Classification labels must be integers (non-integer labels
245    // will be truncated, matching `(int)prob->y[i]` in the C code).
246    if param.svm_type == SvmType::NuSvc {
247        let mut class_counts: Vec<(i32, usize)> = Vec::new();
248        for &y in &problem.labels {
249            let label = y as i32;
250            if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
251                entry.1 += 1;
252            } else {
253                class_counts.push((label, 1));
254            }
255        }
256
257        for (i, &(_, n1)) in class_counts.iter().enumerate() {
258            for &(_, n2) in &class_counts[i + 1..] {
259                if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
260                    return Err(SvmError::InvalidParameter(
261                        "specified nu is infeasible".into(),
262                    ));
263                }
264            }
265        }
266    }
267
268    Ok(())
269}
270
271/// A trained SVM model.
272///
273/// Produced by training, or loaded from a LIBSVM model file.
274///
275/// `load_model` validates model-file shape contracts before returning this
276/// type: class counts, support-vector counts, decision-function arrays, optional
277/// probability metadata, and sparse support-vector rows must be internally
278/// consistent. Manually constructed values bypass those checks, so callers that
279/// accept external model text should prefer [`crate::io::load_model`] or
280/// [`crate::io::load_model_from_reader`].
281#[derive(Debug, Clone, PartialEq)]
282pub struct SvmModel {
283    /// Parameters used during training.
284    pub param: SvmParameter,
285    /// Number of classes (2 for binary, >2 for multiclass, 2 for regression).
286    pub nr_class: usize,
287    /// Support vectors (sparse feature vectors).
288    pub sv: Vec<Vec<SvmNode>>,
289    /// Support vector coefficients. For k classes, this is a
290    /// `(k-1) × num_sv` matrix stored as `Vec<Vec<f64>>`.
291    pub sv_coef: Vec<Vec<f64>>,
292    /// Bias terms (rho). One per class pair: `k*(k-1)/2` values.
293    pub rho: Vec<f64>,
294    /// Pairwise probability parameter A (Platt scaling). Empty if not trained
295    /// with probability estimates.
296    pub prob_a: Vec<f64>,
297    /// Pairwise probability parameter B (Platt scaling). Empty if not trained
298    /// with probability estimates.
299    pub prob_b: Vec<f64>,
300    /// Probability density marks (for one-class SVM).
301    pub prob_density_marks: Vec<f64>,
302    /// Original indices of support vectors in the training set (1-based).
303    pub sv_indices: Vec<usize>,
304    /// Class labels (in the order used internally).
305    pub label: Vec<i32>,
306    /// Number of support vectors per class.
307    pub n_sv: Vec<usize>,
308}
309
310impl SvmModel {
311    /// Return the SVM type used by the model.
312    pub fn svm_type(&self) -> SvmType {
313        self.param.svm_type
314    }
315
316    /// Return number of classes.
317    pub fn class_count(&self) -> usize {
318        self.nr_class
319    }
320
321    /// Return class labels in internal one-vs-one order.
322    pub fn labels(&self) -> &[i32] {
323        &self.label
324    }
325
326    /// Return original 1-based support-vector indices.
327    pub fn support_vector_indices(&self) -> &[usize] {
328        &self.sv_indices
329    }
330
331    /// Return total number of support vectors.
332    pub fn support_vector_count(&self) -> usize {
333        self.sv.len()
334    }
335
336    /// Return SVR sigma when a probability-capable SVR model is available.
337    pub fn svr_probability(&self) -> Option<f64> {
338        match self.param.svm_type {
339            SvmType::EpsilonSvr | SvmType::NuSvr => self.prob_a.first().copied(),
340            _ => None,
341        }
342    }
343
344    /// Check whether the model contains probability metadata.
345    pub fn has_probability_model(&self) -> bool {
346        match self.param.svm_type {
347            SvmType::CSvc | SvmType::NuSvc => !self.prob_a.is_empty() && !self.prob_b.is_empty(),
348            SvmType::EpsilonSvr | SvmType::NuSvr => !self.prob_a.is_empty(),
349            SvmType::OneClass => !self.prob_density_marks.is_empty(),
350        }
351    }
352}
353
354/// C-API style helper matching LIBSVM's `svm_get_svm_type`.
355pub fn svm_get_svm_type(model: &SvmModel) -> SvmType {
356    model.svm_type()
357}
358
359/// C-API style helper matching LIBSVM's `svm_get_nr_class`.
360pub fn svm_get_nr_class(model: &SvmModel) -> usize {
361    model.class_count()
362}
363
364/// C-API style helper matching LIBSVM's `svm_get_labels`.
365pub fn svm_get_labels(model: &SvmModel) -> &[i32] {
366    model.labels()
367}
368
369/// C-API style helper matching LIBSVM's `svm_get_sv_indices`.
370pub fn svm_get_sv_indices(model: &SvmModel) -> &[usize] {
371    model.support_vector_indices()
372}
373
374/// C-API style helper matching LIBSVM's `svm_get_nr_sv`.
375pub fn svm_get_nr_sv(model: &SvmModel) -> usize {
376    model.support_vector_count()
377}
378
379/// C-API style helper matching LIBSVM's `svm_get_svr_probability`.
380pub fn svm_get_svr_probability(model: &SvmModel) -> Option<f64> {
381    model.svr_probability()
382}
383
384/// C-API style helper matching LIBSVM's `svm_check_probability_model`.
385pub fn svm_check_probability_model(model: &SvmModel) -> bool {
386    model.has_probability_model()
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::svm_train;
    use std::path::PathBuf;

    /// Shared `data/` directory, two levels above this crate's manifest.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// ν-SVC parameters with the given `nu`; everything else at its default.
    fn nu_svc(nu: f64) -> SvmParameter {
        SvmParameter {
            svm_type: SvmType::NuSvc,
            nu,
            ..Default::default()
        }
    }

    /// Default parameters with the precomputed kernel selected.
    fn precomputed_param() -> SvmParameter {
        SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        }
    }

    /// A problem whose rows each hold a single `0:serial` node, with
    /// alternating +1/-1 labels.
    fn precomputed_problem(serials: &[f64]) -> SvmProblem {
        SvmProblem {
            labels: (0..serials.len())
                .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
                .collect(),
            instances: serials
                .iter()
                .map(|&value| vec![SvmNode { index: 0, value }])
                .collect(),
        }
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_gamma_ignored_for_linear_kernel() {
        // gamma is only checked for kernels that use it.
        let p = SvmParameter {
            kernel_type: KernelType::Linear,
            gamma: -1.0,
            ..Default::default()
        };
        p.validate().unwrap();
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        assert!(nu_svc(1.5).validate().is_err());
        assert!(nu_svc(0.0).validate().is_err());
    }

    #[test]
    fn nu_ignored_for_c_svc() {
        // C-SVC does not use nu, so an out-of-range value is not an error.
        let p = SvmParameter {
            nu: 5.0,
            ..Default::default()
        };
        p.validate().unwrap();
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn check_parameter_rejects_empty_problem() {
        let problem = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("problem has no instances"));
    }

    #[test]
    fn check_parameter_rejects_label_instance_length_mismatch() {
        let problem = SvmProblem {
            labels: vec![1.0],
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("does not match instance length"));
    }

    #[test]
    fn check_parameter_rejects_precomputed_rows_without_sample_serial_number() {
        let problem = SvmProblem {
            labels: vec![1.0, -1.0],
            instances: vec![
                vec![],
                vec![SvmNode {
                    index: 0,
                    value: 2.0,
                }],
            ],
        };
        let err = check_parameter(&problem, &precomputed_param()).unwrap_err();
        assert!(err.to_string().contains("missing 0:sample_serial_number"));
    }

    #[test]
    fn check_parameter_accepts_valid_precomputed_serial_numbers() {
        let problem = precomputed_problem(&[1.0, 2.0]);
        check_parameter(&problem, &precomputed_param()).unwrap();
    }

    #[test]
    fn check_parameter_rejects_out_of_range_precomputed_serial_number() {
        // Serial numbers must lie in [1, number_of_rows]; 3 is past the end.
        let problem = precomputed_problem(&[1.0, 3.0]);
        let err = check_parameter(&problem, &precomputed_param()).unwrap_err();
        assert!(err
            .to_string()
            .contains("must start with 0:sample_serial_number in [1, 2]"));

        let problem = precomputed_problem(&[0.0, 1.0]);
        assert!(check_parameter(&problem, &precomputed_param()).is_err());
    }

    #[test]
    fn check_parameter_rejects_fractional_precomputed_serial_number() {
        let problem = precomputed_problem(&[1.0, 1.5]);
        let err = check_parameter(&problem, &precomputed_param()).unwrap_err();
        assert!(err
            .to_string()
            .contains("must start with 0:sample_serial_number"));
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // 2 classes with 3 samples each: nu * (3+3)/2 <= 3  →  nu <= 1
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        check_parameter(&problem, &nu_svc(0.5)).unwrap();

        // nu = 0.9: 0.9 * 6/2 = 2.7 <= 3 → feasible
        check_parameter(&problem, &nu_svc(0.9)).unwrap();

        // nu = 1.0 is the exact boundary: 1.0 * 6/2 = 3 is not > 3 → feasible
        check_parameter(&problem, &nu_svc(1.0)).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // 5 class-A, 1 class-B: nu*(5+1)/2 > min(5,1)=1  →  nu > 1/3
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        // 0.5 * 6/2 = 1.5 > 1
        let err = check_parameter(&problem, &nu_svc(0.5));
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("infeasible"));
    }

    #[test]
    fn c_api_style_model_helpers() {
        let problem = crate::io::load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            gamma: 1.0 / 13.0,
            ..Default::default()
        };
        let model = svm_train(&problem, &param);

        assert_eq!(svm_get_svm_type(&model), SvmType::CSvc);
        assert_eq!(svm_get_nr_class(&model), 2);
        assert_eq!(svm_get_nr_sv(&model), model.sv.len());
        assert_eq!(svm_get_labels(&model), model.label.as_slice());
        assert_eq!(svm_get_sv_indices(&model), model.sv_indices.as_slice());
        assert!(!svm_check_probability_model(&model));
        assert_eq!(svm_get_svr_probability(&model), None);
    }

    #[test]
    fn probability_helpers_by_svm_type() {
        let sv_row = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];

        let csvc_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv_row.clone()],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![1.0],
            prob_b: vec![-0.5],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };
        assert!(csvc_model.has_probability_model());
        assert!(svm_check_probability_model(&csvc_model));
        assert_eq!(svm_get_svr_probability(&csvc_model), None);

        let eps_svr_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::EpsilonSvr,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv_row.clone()],
            sv_coef: vec![vec![0.8]],
            rho: vec![0.0],
            prob_a: vec![0.123],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(eps_svr_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&eps_svr_model), Some(0.123));

        let one_class_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::OneClass,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![sv_row],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![0.1; 10],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(one_class_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&one_class_model), None);
    }
}