Skip to main content

libsvm_rs/
types.rs

/// Type of SVM formulation.
///
/// The explicit discriminants match the integer constants in the original
/// LIBSVM (`svm.h`):
/// `C_SVC=0, NU_SVC=1, ONE_CLASS=2, EPSILON_SVR=3, NU_SVR=4`.
/// `#[repr(i32)]` makes the discriminant an `i32`, the same width as the C
/// `int` constants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum SvmType {
    /// C-Support Vector Classification (uses `c`).
    CSvc = 0,
    /// ν-Support Vector Classification (uses `nu`).
    NuSvc = 1,
    /// One-class SVM (distribution estimation / novelty detection; uses `nu`).
    OneClass = 2,
    /// ε-Support Vector Regression (uses `c` and `p`).
    EpsilonSvr = 3,
    /// ν-Support Vector Regression (uses `c` and `nu`).
    NuSvr = 4,
}
19
/// Type of kernel function.
///
/// Matches the integer constants in the original LIBSVM (`svm.h`):
/// `LINEAR=0, POLY=1, RBF=2, SIGMOID=3, PRECOMPUTED=4`.
///
/// The kernel hyper-parameters (`gamma`, `coef0`, `degree`) live on
/// `SvmParameter`. `SvmParameter::validate` checks `gamma` only for the
/// polynomial, RBF and sigmoid kernels, and `degree` only for the
/// polynomial kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum KernelType {
    /// `K(x,y) = x·y`
    Linear = 0,
    /// `K(x,y) = (γ·x·y + coef0)^degree`
    Polynomial = 1,
    /// `K(x,y) = exp(-γ·‖x-y‖²)`
    Rbf = 2,
    /// `K(x,y) = tanh(γ·x·y + coef0)`
    Sigmoid = 3,
    /// Kernel values supplied as a precomputed matrix.
    Precomputed = 4,
}
38
/// A single sparse feature: `index:value`.
///
/// In the original LIBSVM, a sentinel node with `index = -1` marks the end
/// of each instance. In this Rust port, instance length is tracked by
/// `Vec::len()` instead, so no sentinel is needed.
///
/// Only `PartialEq` (not `Eq`/`Hash`) can be derived because `value` is an
/// `f64`.
/// NOTE(review): LIBSVM's data format lists indices in ascending order and
/// omits zero-valued features; presumably this port relies on the same
/// convention — confirm against the kernel code.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SvmNode {
    /// 1-based feature index. Uses `i32` to match the original C `int` and
    /// preserve file-format compatibility.
    pub index: i32,
    /// Feature value.
    pub value: f64,
}
52
/// A training/test problem: a collection of labelled sparse instances.
///
/// `labels` and `instances` are parallel: `labels[i]` belongs to
/// `instances[i]`. This type does not itself check that the two vectors
/// have the same length.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmProblem {
    /// Label (class for classification, target for regression) per instance.
    /// `check_parameter` truncates these to `i32` when grouping samples into
    /// classes for the ν-SVC feasibility check, so classification labels
    /// should be integral.
    pub labels: Vec<f64>,
    /// Sparse feature vectors, one per instance.
    pub instances: Vec<Vec<SvmNode>>,
}
61
/// SVM parameters controlling the formulation, kernel, and solver.
///
/// Default values match the original LIBSVM defaults (see the `Default`
/// impl). Which fields matter depends on `svm_type` and `kernel_type`;
/// `SvmParameter::validate` only range-checks the fields that the selected
/// formulation / kernel uses.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmParameter {
    /// SVM formulation type.
    pub svm_type: SvmType,
    /// Kernel function type.
    pub kernel_type: KernelType,
    /// Degree for polynomial kernel. Must be `>= 0` when `kernel_type` is
    /// `Polynomial`.
    pub degree: i32,
    /// γ parameter for RBF, polynomial, and sigmoid kernels; must be `>= 0`
    /// for those kernels.
    /// A value of 0 is a sentinel meaning `1/num_features`.
    /// NOTE(review): the substitution is not performed in this file —
    /// presumably by the trainer or CLI; confirm.
    pub gamma: f64,
    /// Independent term in polynomial and sigmoid kernels.
    pub coef0: f64,
    /// Cache memory size in MB. Must be `> 0`.
    pub cache_size: f64,
    /// Stopping tolerance for the solver. Must be `> 0`.
    pub eps: f64,
    /// Cost parameter C (for C-SVC, ε-SVR, ν-SVR). Must be `> 0` for those
    /// formulations.
    pub c: f64,
    /// Per-class weight overrides: `(class_label, weight)` pairs.
    /// Not checked by `validate`.
    /// NOTE(review): in LIBSVM each weight multiplies C for that class;
    /// presumably the same here — confirm against the trainer.
    pub weight: Vec<(i32, f64)>,
    /// ν parameter (for ν-SVC, one-class SVM, ν-SVR). Must lie in `(0, 1]`
    /// for those formulations.
    pub nu: f64,
    /// ε in the ε-insensitive loss function (ε-SVR). Must be `>= 0` for
    /// ε-SVR.
    pub p: f64,
    /// Whether to use the shrinking heuristic.
    pub shrinking: bool,
    /// Whether to train for probability estimates.
    pub probability: bool,
}
95
96impl Default for SvmParameter {
97    fn default() -> Self {
98        Self {
99            svm_type: SvmType::CSvc,
100            kernel_type: KernelType::Rbf,
101            degree: 3,
102            gamma: 0.0, // means 1/num_features
103            coef0: 0.0,
104            cache_size: 100.0,
105            eps: 0.001,
106            c: 1.0,
107            weight: Vec::new(),
108            nu: 0.5,
109            p: 0.1,
110            shrinking: true,
111            probability: false,
112        }
113    }
114}
115
116impl SvmParameter {
117    /// Validate parameter values (independent of training data).
118    ///
119    /// This checks the same constraints as the original LIBSVM's
120    /// `svm_check_parameter`, except for the ν-SVC feasibility check
121    /// which requires the problem. Use [`check_parameter`] for the full check.
122    pub fn validate(&self) -> Result<(), crate::error::SvmError> {
123        use crate::error::SvmError;
124
125        // gamma must be non-negative for kernels that use it
126        if matches!(
127            self.kernel_type,
128            KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
129        ) && self.gamma < 0.0
130        {
131            return Err(SvmError::InvalidParameter("gamma < 0".into()));
132        }
133
134        // polynomial degree must be non-negative
135        if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
136            return Err(SvmError::InvalidParameter(
137                "degree of polynomial kernel < 0".into(),
138            ));
139        }
140
141        if self.cache_size <= 0.0 {
142            return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
143        }
144
145        if self.eps <= 0.0 {
146            return Err(SvmError::InvalidParameter("eps <= 0".into()));
147        }
148
149        // C > 0 for formulations that use it
150        if matches!(
151            self.svm_type,
152            SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
153        ) && self.c <= 0.0
154        {
155            return Err(SvmError::InvalidParameter("C <= 0".into()));
156        }
157
158        // nu ∈ (0, 1] for formulations that use it
159        if matches!(
160            self.svm_type,
161            SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
162        ) && (self.nu <= 0.0 || self.nu > 1.0)
163        {
164            return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
165        }
166
167        // p >= 0 for epsilon-SVR
168        if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
169            return Err(SvmError::InvalidParameter("p < 0".into()));
170        }
171
172        Ok(())
173    }
174}
175
176/// Full parameter check including ν-SVC feasibility against training data.
177///
178/// Matches the original LIBSVM `svm_check_parameter()`.
179pub fn check_parameter(
180    problem: &SvmProblem,
181    param: &SvmParameter,
182) -> Result<(), crate::error::SvmError> {
183    use crate::error::SvmError;
184
185    // Run the data-independent checks first
186    param.validate()?;
187
188    // ν-SVC feasibility: for every pair of classes (i, j),
189    // nu * (count_i + count_j) / 2 must be <= min(count_i, count_j)
190    //
191    // Note: LIBSVM casts labels to int for class grouping. We match this
192    // behavior. Classification labels must be integers (non-integer labels
193    // will be truncated, matching `(int)prob->y[i]` in the C code).
194    if param.svm_type == SvmType::NuSvc {
195        let mut class_counts: Vec<(i32, usize)> = Vec::new();
196        for &y in &problem.labels {
197            let label = y as i32;
198            if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
199                entry.1 += 1;
200            } else {
201                class_counts.push((label, 1));
202            }
203        }
204
205        for (i, &(_, n1)) in class_counts.iter().enumerate() {
206            for &(_, n2) in &class_counts[i + 1..] {
207                if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
208                    return Err(SvmError::InvalidParameter(
209                        "specified nu is infeasible".into(),
210                    ));
211                }
212            }
213        }
214    }
215
216    Ok(())
217}
218
/// A trained SVM model.
///
/// Produced by training, or loaded from a LIBSVM model file.
/// All fields are public and there is no constructor, so nothing in this
/// type enforces consistency between the lengths of its vectors.
#[derive(Debug, Clone, PartialEq)]
pub struct SvmModel {
    /// Parameters used during training.
    pub param: SvmParameter,
    /// Number of classes (2 for binary, >2 for multiclass, 2 for regression).
    pub nr_class: usize,
    /// Support vectors (sparse feature vectors).
    pub sv: Vec<Vec<SvmNode>>,
    /// Support vector coefficients. For k classes, this is a
    /// `(k-1) × num_sv` matrix stored as `Vec<Vec<f64>>`.
    /// NOTE(review): presumably indexed as `sv_coef[row][sv]` — confirm
    /// against the trainer / model-file reader.
    pub sv_coef: Vec<Vec<f64>>,
    /// Bias terms (rho). One per class pair: `k*(k-1)/2` values (a single
    /// value when `nr_class == 2`).
    pub rho: Vec<f64>,
    /// Pairwise probability parameter A (Platt scaling). Empty if not trained
    /// with probability estimates.
    /// NOTE(review): LIBSVM regression models store a single value here
    /// instead of pairwise parameters; confirm how this port uses it.
    pub prob_a: Vec<f64>,
    /// Pairwise probability parameter B (Platt scaling). Empty if not trained
    /// with probability estimates.
    pub prob_b: Vec<f64>,
    /// Probability density marks (for one-class SVM).
    /// NOTE(review): presumably empty unless a one-class model was trained
    /// with probability estimates — confirm.
    pub prob_density_marks: Vec<f64>,
    /// Original indices of support vectors in the training set (1-based).
    pub sv_indices: Vec<usize>,
    /// Class labels (in the order used internally).
    /// NOTE(review): LIBSVM only fills this for classification models;
    /// presumably empty for regression / one-class here — confirm.
    pub label: Vec<i32>,
    /// Number of support vectors per class (presumably in the same order as
    /// `label`, and likewise only populated for classification — confirm).
    pub n_sv: Vec<usize>,
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// A problem with the given labels and empty feature vectors; the ν-SVC
    /// feasibility check only looks at the labels.
    fn problem_with_labels(labels: &[f64]) -> SvmProblem {
        SvmProblem {
            labels: labels.to_vec(),
            instances: vec![vec![]; labels.len()],
        }
    }

    /// ν-SVC parameters with the given `nu` and defaults elsewhere.
    fn nu_svc(nu: f64) -> SvmParameter {
        SvmParameter {
            svm_type: SvmType::NuSvc,
            nu,
            ..Default::default()
        }
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn non_positive_eps_rejected() {
        for eps in [0.0, -1e-3] {
            let p = SvmParameter {
                eps,
                ..Default::default()
            };
            assert!(p.validate().is_err(), "eps = {eps} should be rejected");
        }
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        assert!(nu_svc(1.5).validate().is_err());
        assert!(nu_svc(0.0).validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn checks_only_apply_to_parameters_in_use() {
        // C is not used by ν-SVC.
        SvmParameter {
            c: 0.0,
            ..nu_svc(0.5)
        }
        .validate()
        .unwrap();
        // nu is not used by C-SVC (the default formulation).
        SvmParameter {
            nu: 2.0,
            ..Default::default()
        }
        .validate()
        .unwrap();
        // p is only used by ε-SVR.
        SvmParameter {
            p: -1.0,
            ..Default::default()
        }
        .validate()
        .unwrap();
        // gamma is not used by the linear kernel.
        SvmParameter {
            kernel_type: KernelType::Linear,
            gamma: -1.0,
            ..Default::default()
        }
        .validate()
        .unwrap();
        // degree is only checked for the polynomial kernel (default is RBF).
        SvmParameter {
            degree: -1,
            ..Default::default()
        }
        .validate()
        .unwrap();
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // 2 classes with 3 samples each: nu * (3+3)/2 <= 3  →  nu <= 1
        let problem = problem_with_labels(&[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]);
        check_parameter(&problem, &nu_svc(0.5)).unwrap();

        // nu = 0.9: 0.9 * 6/2 = 2.7 <= 3 → feasible
        check_parameter(&problem, &nu_svc(0.9)).unwrap();

        // nu = 1.0 is the exact boundary: 1.0 * 6/2 = 3 <= 3 → still feasible
        check_parameter(&problem, &nu_svc(1.0)).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // 5 class-A, 1 class-B: nu*(5+1)/2 > min(5,1)=1  →  nu > 1/3
        let problem = problem_with_labels(&[1.0, 1.0, 1.0, 1.0, 1.0, 2.0]);
        // 0.5 * 6/2 = 1.5 > 1
        let err = check_parameter(&problem, &nu_svc(0.5));
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("infeasible"));
    }

    #[test]
    fn nu_svc_checks_every_class_pair() {
        // Classes of sizes 4, 4 and 1. The pairs involving the singleton
        // class are the binding ones: nu * (4+1)/2 <= 1  →  nu <= 0.4.
        let problem =
            problem_with_labels(&[1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 3.0]);
        check_parameter(&problem, &nu_svc(0.25)).unwrap();
        assert!(check_parameter(&problem, &nu_svc(0.5)).is_err());
    }

    #[test]
    fn nu_svc_groups_labels_by_integer_truncation() {
        // 1.2 and 1.9 truncate to class 1 and 2.5 to class 2, giving two
        // classes of three samples each, which is feasible even at nu = 1.
        // Grouping by the raw f64 values would create singleton classes and
        // fail the check.
        let problem = problem_with_labels(&[1.2, 1.9, 1.0, 2.0, 2.5, 2.0]);
        check_parameter(&problem, &nu_svc(1.0)).unwrap();
    }

    #[test]
    fn feasibility_only_checked_for_nu_svc() {
        // Imbalanced data that would be infeasible for ν-SVC at nu = 0.5.
        let problem = problem_with_labels(&[1.0, 1.0, 1.0, 1.0, 1.0, 2.0]);
        check_parameter(&problem, &SvmParameter::default()).unwrap();
        let one_class = SvmParameter {
            svm_type: SvmType::OneClass,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &one_class).unwrap();
    }

    #[test]
    fn check_parameter_runs_validate_first() {
        let problem = problem_with_labels(&[1.0, 2.0]);
        assert!(check_parameter(&problem, &nu_svc(0.0)).is_err());
    }

    #[test]
    fn nu_svc_empty_problem_is_feasible() {
        check_parameter(&problem_with_labels(&[]), &nu_svc(0.5)).unwrap();
    }
}