libsvm_rs/types.rs
/// Which SVM problem formulation to solve.
///
/// The discriminants are identical to the `svm_type` constants declared in
/// LIBSVM's `svm.h` (`C_SVC=0, NU_SVC=1, ONE_CLASS=2, EPSILON_SVR=3,
/// NU_SVR=4`), so `variant as i32` yields the C value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum SvmType {
    /// C-SVC: classification controlled by the cost parameter C.
    CSvc = 0,
    /// ν-SVC: classification controlled by ν instead of C.
    NuSvc = 1,
    /// One-class SVM, for distribution estimation / novelty detection.
    OneClass = 2,
    /// ε-SVR: regression with an ε-insensitive loss.
    EpsilonSvr = 3,
    /// ν-SVR: regression that uses ν (alongside C).
    NuSvr = 4,
}
19
/// Which kernel function maps pairs of instances to similarity values.
///
/// The discriminants are identical to the `kernel_type` constants declared in
/// LIBSVM's `svm.h` (`LINEAR=0, POLY=1, RBF=2, SIGMOID=3, PRECOMPUTED=4`), so
/// `variant as i32` yields the C value.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum KernelType {
    /// Linear kernel: `K(x,y) = x·y`
    Linear = 0,
    /// Polynomial kernel: `K(x,y) = (γ·x·y + coef0)^degree`
    Polynomial = 1,
    /// Radial basis function kernel: `K(x,y) = exp(-γ·‖x-y‖²)`
    Rbf = 2,
    /// Sigmoid kernel: `K(x,y) = tanh(γ·x·y + coef0)`
    Sigmoid = 3,
    /// The caller provides the kernel matrix directly instead of a formula.
    Precomputed = 4,
}
38
/// One entry of a sparse feature vector, written `index:value` in LIBSVM
/// data files.
///
/// The C library ends every instance with a sentinel node whose `index` is
/// -1. Here each instance is a `Vec`, and its length already tells where the
/// instance ends, so no sentinel is needed.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Feature index, counted from 1. Stored as `i32` (the C `int`) so the
    /// file format stays compatible with LIBSVM.
    pub index: i32,
    /// Value of the feature at `index`.
    pub value: f64,
}
52
53/// A training/test problem: a collection of labelled sparse instances.
54#[derive(Debug, Clone, PartialEq)]
55pub struct SvmProblem {
56 /// Label (class for classification, target for regression) per instance.
57 pub labels: Vec<f64>,
58 /// Sparse feature vectors, one per instance.
59 pub instances: Vec<Vec<SvmNode>>,
60}
61
62/// SVM parameters controlling the formulation, kernel, and solver.
63///
64/// Default values match the original LIBSVM defaults.
65#[derive(Debug, Clone, PartialEq)]
66pub struct SvmParameter {
67 /// SVM formulation type.
68 pub svm_type: SvmType,
69 /// Kernel function type.
70 pub kernel_type: KernelType,
71 /// Degree for polynomial kernel.
72 pub degree: i32,
73 /// γ parameter for RBF, polynomial, and sigmoid kernels.
74 /// Set to `1/num_features` when 0.
75 pub gamma: f64,
76 /// Independent term in polynomial and sigmoid kernels.
77 pub coef0: f64,
78 /// Cache memory size in MB.
79 pub cache_size: f64,
80 /// Stopping tolerance for the solver.
81 pub eps: f64,
82 /// Cost parameter C (for C-SVC, ε-SVR, ν-SVR).
83 pub c: f64,
84 /// Per-class weight overrides: `(class_label, weight)` pairs.
85 pub weight: Vec<(i32, f64)>,
86 /// ν parameter (for ν-SVC, one-class SVM, ν-SVR).
87 pub nu: f64,
88 /// ε in the ε-insensitive loss function (ε-SVR).
89 pub p: f64,
90 /// Whether to use the shrinking heuristic.
91 pub shrinking: bool,
92 /// Whether to train for probability estimates.
93 pub probability: bool,
94}
95
96impl Default for SvmParameter {
97 fn default() -> Self {
98 Self {
99 svm_type: SvmType::CSvc,
100 kernel_type: KernelType::Rbf,
101 degree: 3,
102 gamma: 0.0, // means 1/num_features
103 coef0: 0.0,
104 cache_size: 100.0,
105 eps: 0.001,
106 c: 1.0,
107 weight: Vec::new(),
108 nu: 0.5,
109 p: 0.1,
110 shrinking: true,
111 probability: false,
112 }
113 }
114}
115
116impl SvmParameter {
117 /// Validate parameter values (independent of training data).
118 ///
119 /// This checks the same constraints as the original LIBSVM's
120 /// `svm_check_parameter`, except for the ν-SVC feasibility check
121 /// which requires the problem. Use [`check_parameter`] for the full check.
122 pub fn validate(&self) -> Result<(), crate::error::SvmError> {
123 use crate::error::SvmError;
124
125 // gamma must be non-negative for kernels that use it
126 if matches!(
127 self.kernel_type,
128 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
129 ) && self.gamma < 0.0
130 {
131 return Err(SvmError::InvalidParameter("gamma < 0".into()));
132 }
133
134 // polynomial degree must be non-negative
135 if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
136 return Err(SvmError::InvalidParameter(
137 "degree of polynomial kernel < 0".into(),
138 ));
139 }
140
141 if self.cache_size <= 0.0 {
142 return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
143 }
144
145 if self.eps <= 0.0 {
146 return Err(SvmError::InvalidParameter("eps <= 0".into()));
147 }
148
149 // C > 0 for formulations that use it
150 if matches!(
151 self.svm_type,
152 SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
153 ) && self.c <= 0.0
154 {
155 return Err(SvmError::InvalidParameter("C <= 0".into()));
156 }
157
158 // nu ∈ (0, 1] for formulations that use it
159 if matches!(
160 self.svm_type,
161 SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
162 ) && (self.nu <= 0.0 || self.nu > 1.0)
163 {
164 return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
165 }
166
167 // p >= 0 for epsilon-SVR
168 if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
169 return Err(SvmError::InvalidParameter("p < 0".into()));
170 }
171
172 Ok(())
173 }
174}
175
176/// Full parameter check including ν-SVC feasibility against training data.
177///
178/// Matches the original LIBSVM `svm_check_parameter()`.
179pub fn check_parameter(
180 problem: &SvmProblem,
181 param: &SvmParameter,
182) -> Result<(), crate::error::SvmError> {
183 use crate::error::SvmError;
184
185 // Run the data-independent checks first
186 param.validate()?;
187
188 // ν-SVC feasibility: for every pair of classes (i, j),
189 // nu * (count_i + count_j) / 2 must be <= min(count_i, count_j)
190 //
191 // Note: LIBSVM casts labels to int for class grouping. We match this
192 // behavior. Classification labels must be integers (non-integer labels
193 // will be truncated, matching `(int)prob->y[i]` in the C code).
194 if param.svm_type == SvmType::NuSvc {
195 let mut class_counts: Vec<(i32, usize)> = Vec::new();
196 for &y in &problem.labels {
197 let label = y as i32;
198 if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
199 entry.1 += 1;
200 } else {
201 class_counts.push((label, 1));
202 }
203 }
204
205 for (i, &(_, n1)) in class_counts.iter().enumerate() {
206 for &(_, n2) in &class_counts[i + 1..] {
207 if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
208 return Err(SvmError::InvalidParameter(
209 "specified nu is infeasible".into(),
210 ));
211 }
212 }
213 }
214 }
215
216 Ok(())
217}
218
219/// A trained SVM model.
220///
221/// Produced by training, or loaded from a LIBSVM model file.
222#[derive(Debug, Clone, PartialEq)]
223pub struct SvmModel {
224 /// Parameters used during training.
225 pub param: SvmParameter,
226 /// Number of classes (2 for binary, >2 for multiclass, 2 for regression).
227 pub nr_class: usize,
228 /// Support vectors (sparse feature vectors).
229 pub sv: Vec<Vec<SvmNode>>,
230 /// Support vector coefficients. For k classes, this is a
231 /// `(k-1) × num_sv` matrix stored as `Vec<Vec<f64>>`.
232 pub sv_coef: Vec<Vec<f64>>,
233 /// Bias terms (rho). One per class pair: `k*(k-1)/2` values.
234 pub rho: Vec<f64>,
235 /// Pairwise probability parameter A (Platt scaling). Empty if not trained
236 /// with probability estimates.
237 pub prob_a: Vec<f64>,
238 /// Pairwise probability parameter B (Platt scaling). Empty if not trained
239 /// with probability estimates.
240 pub prob_b: Vec<f64>,
241 /// Probability density marks (for one-class SVM).
242 pub prob_density_marks: Vec<f64>,
243 /// Original indices of support vectors in the training set (1-based).
244 pub sv_indices: Vec<usize>,
245 /// Class labels (in the order used internally).
246 pub label: Vec<i32>,
247 /// Number of support vectors per class.
248 pub n_sv: Vec<usize>,
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn enum_discriminants_match_libsvm() {
        // Values must equal the constants in LIBSVM's svm.h for file compatibility.
        assert_eq!(SvmType::CSvc as i32, 0);
        assert_eq!(SvmType::NuSvc as i32, 1);
        assert_eq!(SvmType::OneClass as i32, 2);
        assert_eq!(SvmType::EpsilonSvr as i32, 3);
        assert_eq!(SvmType::NuSvr as i32, 4);

        assert_eq!(KernelType::Linear as i32, 0);
        assert_eq!(KernelType::Polynomial as i32, 1);
        assert_eq!(KernelType::Rbf as i32, 2);
        assert_eq!(KernelType::Sigmoid as i32, 3);
        assert_eq!(KernelType::Precomputed as i32, 4);
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_gamma_ignored_for_linear_kernel() {
        // The linear kernel never reads gamma, so its value is irrelevant.
        let p = SvmParameter {
            kernel_type: KernelType::Linear,
            gamma: -1.0,
            ..Default::default()
        };
        p.validate().unwrap();
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn unused_c_and_p_are_ignored() {
        // ν-SVC does not read C.
        let nu_svc = SvmParameter {
            svm_type: SvmType::NuSvc,
            c: 0.0,
            ..Default::default()
        };
        nu_svc.validate().unwrap();

        // C-SVC does not read p.
        let c_svc = SvmParameter {
            p: -1.0,
            ..Default::default()
        };
        c_svc.validate().unwrap();
    }

    #[test]
    fn nu_out_of_range_rejected() {
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(p.validate().is_err());

        let p2 = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(p2.validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // 2 classes with 3 samples each: nu * (3+3)/2 <= 3 → nu <= 1
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();

        // nu = 0.9: 0.9 * 6/2 = 2.7 <= 3 → feasible
        let high = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        check_parameter(&problem, &high).unwrap();

        // nu = 1.0: 1.0 * 6/2 = 3 == 3 → exactly on the boundary, still feasible
        let boundary = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.0,
            ..Default::default()
        };
        check_parameter(&problem, &boundary).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // 5 class-A, 1 class-B: nu*(5+1)/2 > min(5,1)=1 → nu > 1/3
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5, // 0.5 * 6/2 = 1.5 > 1
            ..Default::default()
        };
        let err = check_parameter(&problem, &param);
        assert!(err.is_err());
        assert!(format!("{}", err.unwrap_err()).contains("infeasible"));
    }

    #[test]
    fn nu_svc_checks_every_class_pair() {
        // Classes 1 and 2 have 3 samples each, class 3 has 1.
        // Pair (1,2): 0.6 * 6/2 = 1.8 <= 3 (fine); pair (1,3): 0.6 * 4/2 = 1.2 > 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 3.0],
            instances: vec![vec![]; 7],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.6,
            ..Default::default()
        };
        assert!(check_parameter(&problem, &param).is_err());

        // nu = 0.5: every pair involving class 3 gives 0.5 * 4/2 = 1 <= 1.
        let ok = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok).unwrap();
    }
}
363}