Skip to main content

libsvm_rs/
types.rs

/// The SVM formulation to train or evaluate.
///
/// Every discriminant equals the corresponding integer constant of the
/// original LIBSVM header (`svm.h`):
/// `C_SVC=0, NU_SVC=1, ONE_CLASS=2, EPSILON_SVR=3, NU_SVR=4`.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SvmType {
    /// Support vector classification with cost parameter C (C-SVC).
    CSvc = 0,
    /// Support vector classification parameterised by ν (ν-SVC).
    NuSvc = 1,
    /// One-class SVM, used for distribution estimation / novelty detection.
    OneClass = 2,
    /// Support vector regression with ε-insensitive loss (ε-SVR).
    EpsilonSvr = 3,
    /// Support vector regression parameterised by ν (ν-SVR).
    NuSvr = 4,
}
19
/// The kernel function used to compare two feature vectors.
///
/// Every discriminant equals the corresponding integer constant of the
/// original LIBSVM header (`svm.h`):
/// `LINEAR=0, POLY=1, RBF=2, SIGMOID=3, PRECOMPUTED=4`.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Linear kernel: `K(x,y) = x·y`
    Linear = 0,
    /// Polynomial kernel: `K(x,y) = (γ·x·y + coef0)^degree`
    Polynomial = 1,
    /// Radial basis function kernel: `K(x,y) = exp(-γ·‖x-y‖²)`
    Rbf = 2,
    /// Sigmoid kernel: `K(x,y) = tanh(γ·x·y + coef0)`
    Sigmoid = 3,
    /// The caller supplies the kernel values as a precomputed matrix.
    Precomputed = 4,
}
38
/// One entry of a sparse feature vector, i.e. an `index:value` pair.
///
/// The original LIBSVM terminates every instance with a sentinel node whose
/// `index` is `-1`. This Rust port needs no sentinel: the length of an
/// instance is given by `Vec::len()`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Feature index, counted from 1. Kept as `i32` so that it matches the
    /// C `int` of the original and stays compatible with the file format.
    pub index: i32,
    /// Value of the feature at `index`.
    pub value: f64,
}
52
53/// A training/test problem: a collection of labelled sparse instances.
54#[derive(Debug, Clone, PartialEq)]
55pub struct SvmProblem {
56    /// Label (class for classification, target for regression) per instance.
57    pub labels: Vec<f64>,
58    /// Sparse feature vectors, one per instance.
59    pub instances: Vec<Vec<SvmNode>>,
60}
61
62/// SVM parameters controlling the formulation, kernel, and solver.
63///
64/// Default values match the original LIBSVM defaults.
65#[derive(Debug, Clone, PartialEq)]
66pub struct SvmParameter {
67    /// SVM formulation type.
68    pub svm_type: SvmType,
69    /// Kernel function type.
70    pub kernel_type: KernelType,
71    /// Degree for polynomial kernel.
72    pub degree: i32,
73    /// γ parameter for RBF, polynomial, and sigmoid kernels.
74    /// Set to `1/num_features` when 0.
75    pub gamma: f64,
76    /// Independent term in polynomial and sigmoid kernels.
77    pub coef0: f64,
78    /// Cache memory size in MB.
79    pub cache_size: f64,
80    /// Stopping tolerance for the solver.
81    pub eps: f64,
82    /// Cost parameter C (for C-SVC, ε-SVR, ν-SVR).
83    pub c: f64,
84    /// Per-class weight overrides: `(class_label, weight)` pairs.
85    pub weight: Vec<(i32, f64)>,
86    /// ν parameter (for ν-SVC, one-class SVM, ν-SVR).
87    pub nu: f64,
88    /// ε in the ε-insensitive loss function (ε-SVR).
89    pub p: f64,
90    /// Whether to use the shrinking heuristic.
91    pub shrinking: bool,
92    /// Whether to train for probability estimates.
93    pub probability: bool,
94}
95
96impl Default for SvmParameter {
97    fn default() -> Self {
98        Self {
99            svm_type: SvmType::CSvc,
100            kernel_type: KernelType::Rbf,
101            degree: 3,
102            gamma: 0.0, // means 1/num_features
103            coef0: 0.0,
104            cache_size: 100.0,
105            eps: 0.001,
106            c: 1.0,
107            weight: Vec::new(),
108            nu: 0.5,
109            p: 0.1,
110            shrinking: true,
111            probability: false,
112        }
113    }
114}
115
116impl SvmParameter {
117    /// Validate parameter values (independent of training data).
118    ///
119    /// This checks the same constraints as the original LIBSVM's
120    /// `svm_check_parameter`, except for the ν-SVC feasibility check
121    /// which requires the problem. Use [`check_parameter`] for the full check.
122    pub fn validate(&self) -> Result<(), crate::error::SvmError> {
123        use crate::error::SvmError;
124
125        // gamma must be non-negative for kernels that use it
126        if matches!(
127            self.kernel_type,
128            KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
129        ) && self.gamma < 0.0
130        {
131            return Err(SvmError::InvalidParameter("gamma < 0".into()));
132        }
133
134        // polynomial degree must be non-negative
135        if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
136            return Err(SvmError::InvalidParameter(
137                "degree of polynomial kernel < 0".into(),
138            ));
139        }
140
141        if self.cache_size <= 0.0 {
142            return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
143        }
144
145        if self.eps <= 0.0 {
146            return Err(SvmError::InvalidParameter("eps <= 0".into()));
147        }
148
149        // C > 0 for formulations that use it
150        if matches!(
151            self.svm_type,
152            SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
153        ) && self.c <= 0.0
154        {
155            return Err(SvmError::InvalidParameter("C <= 0".into()));
156        }
157
158        // nu ∈ (0, 1] for formulations that use it
159        if matches!(
160            self.svm_type,
161            SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
162        ) && (self.nu <= 0.0 || self.nu > 1.0)
163        {
164            return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
165        }
166
167        // p >= 0 for epsilon-SVR
168        if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
169            return Err(SvmError::InvalidParameter("p < 0".into()));
170        }
171
172        Ok(())
173    }
174}
175
176/// Full parameter check including ν-SVC feasibility against training data.
177///
178/// Matches the original LIBSVM `svm_check_parameter()`.
179pub fn check_parameter(
180    problem: &SvmProblem,
181    param: &SvmParameter,
182) -> Result<(), crate::error::SvmError> {
183    use crate::error::SvmError;
184
185    // Run the data-independent checks first
186    param.validate()?;
187
188    // ν-SVC feasibility: for every pair of classes (i, j),
189    // nu * (count_i + count_j) / 2 must be <= min(count_i, count_j)
190    if param.svm_type == SvmType::NuSvc {
191        let mut class_counts: Vec<(i32, usize)> = Vec::new();
192        for &y in &problem.labels {
193            let label = y as i32;
194            if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
195                entry.1 += 1;
196            } else {
197                class_counts.push((label, 1));
198            }
199        }
200
201        for (i, &(_, n1)) in class_counts.iter().enumerate() {
202            for &(_, n2) in &class_counts[i + 1..] {
203                if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
204                    return Err(SvmError::InvalidParameter(
205                        "specified nu is infeasible".into(),
206                    ));
207                }
208            }
209        }
210    }
211
212    Ok(())
213}
214
215/// A trained SVM model.
216///
217/// Produced by training, or loaded from a LIBSVM model file.
218#[derive(Debug, Clone, PartialEq)]
219pub struct SvmModel {
220    /// Parameters used during training.
221    pub param: SvmParameter,
222    /// Number of classes (2 for binary, >2 for multiclass, 2 for regression).
223    pub nr_class: usize,
224    /// Support vectors (sparse feature vectors).
225    pub sv: Vec<Vec<SvmNode>>,
226    /// Support vector coefficients. For k classes, this is a
227    /// `(k-1) × num_sv` matrix stored as `Vec<Vec<f64>>`.
228    pub sv_coef: Vec<Vec<f64>>,
229    /// Bias terms (rho). One per class pair: `k*(k-1)/2` values.
230    pub rho: Vec<f64>,
231    /// Pairwise probability parameter A (Platt scaling). Empty if not trained
232    /// with probability estimates.
233    pub prob_a: Vec<f64>,
234    /// Pairwise probability parameter B (Platt scaling). Empty if not trained
235    /// with probability estimates.
236    pub prob_b: Vec<f64>,
237    /// Probability density marks (for one-class SVM).
238    pub prob_density_marks: Vec<f64>,
239    /// Original indices of support vectors in the training set (1-based).
240    pub sv_indices: Vec<usize>,
241    /// Class labels (in the order used internally).
242    pub label: Vec<i32>,
243    /// Number of support vectors per class.
244    pub n_sv: Vec<usize>,
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// ν-SVC parameters with the given ν; every other field keeps its default.
    fn nu_svc(nu: f64) -> SvmParameter {
        SvmParameter {
            svm_type: SvmType::NuSvc,
            nu,
            ..SvmParameter::default()
        }
    }

    /// A problem with one empty instance per entry of `labels`.
    fn labelled(labels: &[f64]) -> SvmProblem {
        SvmProblem {
            labels: labels.to_vec(),
            instances: vec![Vec::new(); labels.len()],
        }
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default()
            .validate()
            .expect("default parameters must validate");
    }

    #[test]
    fn negative_gamma_rejected() {
        let mut param = SvmParameter::default();
        param.gamma = -1.0;
        assert!(param.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let mut param = SvmParameter::default();
        param.cache_size = 0.0;
        assert!(param.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let mut param = SvmParameter::default();
        param.c = 0.0;
        assert!(param.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        // Above the upper bound of (0, 1].
        assert!(nu_svc(1.5).validate().is_err());
        // The lower bound itself is excluded.
        assert!(nu_svc(0.0).validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let mut param = SvmParameter::default();
        param.svm_type = SvmType::EpsilonSvr;
        param.p = -0.1;
        assert!(param.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let mut param = SvmParameter::default();
        param.kernel_type = KernelType::Polynomial;
        param.degree = -1;
        assert!(param.validate().is_err());
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // Two classes of 3 samples each: nu * (3+3)/2 <= 3, i.e. nu <= 1.
        let problem = labelled(&[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);

        check_parameter(&problem, &nu_svc(0.5)).unwrap();

        // nu = 0.9 gives 0.9 * 6/2 = 2.7 <= 3, which is still feasible.
        check_parameter(&problem, &nu_svc(0.9)).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // Class sizes 5 and 1: infeasible once nu*(5+1)/2 > min(5,1) = 1,
        // i.e. for nu > 1/3. Here 0.5 * 6/2 = 1.5 > 1.
        let problem = labelled(&[1.0, 1.0, 1.0, 1.0, 1.0, 2.0]);

        let outcome = check_parameter(&problem, &nu_svc(0.5));
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("infeasible"));
    }
}