/// The SVM formulation to train.
///
/// Discriminants are pinned with `#[repr(i32)]`, so `variant as i32` yields the
/// value written next to each variant. NOTE(review): these presumably mirror
/// libsvm's integer `svm_type` codes; confirm before relying on that for interop.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum SvmType {
    /// C-support vector classification (`SvmParameter::validate` requires `c > 0`).
    CSvc = 0,
    /// nu-support vector classification (`validate` requires `0 < nu <= 1`).
    NuSvc = 1,
    /// One-class SVM (`validate` requires `0 < nu <= 1`).
    OneClass = 2,
    /// epsilon-support vector regression (`validate` requires `c > 0` and `p >= 0`).
    EpsilonSvr = 3,
    /// nu-support vector regression (`validate` requires `c > 0` and `0 < nu <= 1`).
    NuSvr = 4,
}
19
/// The kernel function used by the SVM.
///
/// Discriminants are pinned with `#[repr(i32)]`, so `variant as i32` yields the
/// value written next to each variant.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[repr(i32)]
pub enum KernelType {
    /// Linear kernel.
    Linear = 0,
    /// Polynomial kernel (`SvmParameter::validate` checks `gamma >= 0` and `degree >= 0`).
    Polynomial = 1,
    /// RBF kernel (`validate` checks `gamma >= 0`).
    Rbf = 2,
    /// Sigmoid kernel (`validate` checks `gamma >= 0`).
    Sigmoid = 3,
    /// Precomputed kernel. NOTE(review): presumably the caller supplies the kernel
    /// matrix in the instances; confirm against the trainer.
    Precomputed = 4,
}
38
/// One `(index, value)` feature entry; an instance is a `Vec<SvmNode>`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Feature index. NOTE(review): presumably 1-based as in libsvm's sparse
    /// format; confirm with the loader.
    pub index: i32,
    /// Value of the feature at `index`.
    pub value: f64,
}
52
53#[derive(Debug, Clone, PartialEq)]
55pub struct SvmProblem {
56 pub labels: Vec<f64>,
58 pub instances: Vec<Vec<SvmNode>>,
60}
61
62#[derive(Debug, Clone, PartialEq)]
66pub struct SvmParameter {
67 pub svm_type: SvmType,
69 pub kernel_type: KernelType,
71 pub degree: i32,
73 pub gamma: f64,
76 pub coef0: f64,
78 pub cache_size: f64,
80 pub eps: f64,
82 pub c: f64,
84 pub weight: Vec<(i32, f64)>,
86 pub nu: f64,
88 pub p: f64,
90 pub shrinking: bool,
92 pub probability: bool,
94}
95
96impl Default for SvmParameter {
97 fn default() -> Self {
98 Self {
99 svm_type: SvmType::CSvc,
100 kernel_type: KernelType::Rbf,
101 degree: 3,
102 gamma: 0.0, coef0: 0.0,
104 cache_size: 100.0,
105 eps: 0.001,
106 c: 1.0,
107 weight: Vec::new(),
108 nu: 0.5,
109 p: 0.1,
110 shrinking: true,
111 probability: false,
112 }
113 }
114}
115
116impl SvmParameter {
117 pub fn validate(&self) -> Result<(), crate::error::SvmError> {
123 use crate::error::SvmError;
124
125 if matches!(
127 self.kernel_type,
128 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
129 ) && self.gamma < 0.0
130 {
131 return Err(SvmError::InvalidParameter("gamma < 0".into()));
132 }
133
134 if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
136 return Err(SvmError::InvalidParameter(
137 "degree of polynomial kernel < 0".into(),
138 ));
139 }
140
141 if self.cache_size <= 0.0 {
142 return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
143 }
144
145 if self.eps <= 0.0 {
146 return Err(SvmError::InvalidParameter("eps <= 0".into()));
147 }
148
149 if matches!(
151 self.svm_type,
152 SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
153 ) && self.c <= 0.0
154 {
155 return Err(SvmError::InvalidParameter("C <= 0".into()));
156 }
157
158 if matches!(
160 self.svm_type,
161 SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
162 ) && (self.nu <= 0.0 || self.nu > 1.0)
163 {
164 return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
165 }
166
167 if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
169 return Err(SvmError::InvalidParameter("p < 0".into()));
170 }
171
172 Ok(())
173 }
174}
175
176pub fn check_parameter(
180 problem: &SvmProblem,
181 param: &SvmParameter,
182) -> Result<(), crate::error::SvmError> {
183 use crate::error::SvmError;
184
185 param.validate()?;
187
188 if param.svm_type == SvmType::NuSvc {
195 let mut class_counts: Vec<(i32, usize)> = Vec::new();
196 for &y in &problem.labels {
197 let label = y as i32;
198 if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
199 entry.1 += 1;
200 } else {
201 class_counts.push((label, 1));
202 }
203 }
204
205 for (i, &(_, n1)) in class_counts.iter().enumerate() {
206 for &(_, n2) in &class_counts[i + 1..] {
207 if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
208 return Err(SvmError::InvalidParameter(
209 "specified nu is infeasible".into(),
210 ));
211 }
212 }
213 }
214 }
215
216 Ok(())
217}
218
219#[derive(Debug, Clone, PartialEq)]
223pub struct SvmModel {
224 pub param: SvmParameter,
226 pub nr_class: usize,
228 pub sv: Vec<Vec<SvmNode>>,
230 pub sv_coef: Vec<Vec<f64>>,
233 pub rho: Vec<f64>,
235 pub prob_a: Vec<f64>,
238 pub prob_b: Vec<f64>,
241 pub prob_density_marks: Vec<f64>,
243 pub sv_indices: Vec<usize>,
245 pub label: Vec<i32>,
247 pub n_sv: Vec<usize>,
249}
250
251impl SvmModel {
252 pub fn svm_type(&self) -> SvmType {
254 self.param.svm_type
255 }
256
257 pub fn class_count(&self) -> usize {
259 self.nr_class
260 }
261
262 pub fn labels(&self) -> &[i32] {
264 &self.label
265 }
266
267 pub fn support_vector_indices(&self) -> &[usize] {
269 &self.sv_indices
270 }
271
272 pub fn support_vector_count(&self) -> usize {
274 self.sv.len()
275 }
276
277 pub fn svr_probability(&self) -> Option<f64> {
279 match self.param.svm_type {
280 SvmType::EpsilonSvr | SvmType::NuSvr => self.prob_a.first().copied(),
281 _ => None,
282 }
283 }
284
285 pub fn has_probability_model(&self) -> bool {
287 match self.param.svm_type {
288 SvmType::CSvc | SvmType::NuSvc => !self.prob_a.is_empty() && !self.prob_b.is_empty(),
289 SvmType::EpsilonSvr | SvmType::NuSvr => !self.prob_a.is_empty(),
290 SvmType::OneClass => !self.prob_density_marks.is_empty(),
291 }
292 }
293}
294
295pub fn svm_get_svm_type(model: &SvmModel) -> SvmType {
297 model.svm_type()
298}
299
300pub fn svm_get_nr_class(model: &SvmModel) -> usize {
302 model.class_count()
303}
304
305pub fn svm_get_labels(model: &SvmModel) -> &[i32] {
307 model.labels()
308}
309
310pub fn svm_get_sv_indices(model: &SvmModel) -> &[usize] {
312 model.support_vector_indices()
313}
314
315pub fn svm_get_nr_sv(model: &SvmModel) -> usize {
317 model.support_vector_count()
318}
319
320pub fn svm_get_svr_probability(model: &SvmModel) -> Option<f64> {
322 model.svr_probability()
323}
324
325pub fn svm_check_probability_model(model: &SvmModel) -> bool {
327 model.has_probability_model()
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::svm_train;
    use std::path::PathBuf;

    /// Directory holding shared test fixtures, two levels above this crate's manifest.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(p.validate().is_err());

        let p2 = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(p2.validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // Two balanced classes of 3: nu * (3 + 3) / 2 <= 3 holds for any nu <= 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();

        // 0.9 * 6 / 2 = 2.7 <= 3, still feasible.
        let borderline = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        check_parameter(&problem, &borderline).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // Classes of 5 and 1: 0.5 * 6 / 2 = 1.5 > min(5, 1) = 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param);
        assert!(err.is_err());
        assert!(err.unwrap_err().to_string().contains("infeasible"));
    }

    #[test]
    fn c_api_style_model_helpers() {
        let problem = crate::io::load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            gamma: 1.0 / 13.0,
            ..Default::default()
        };
        let model = svm_train(&problem, &param);

        assert_eq!(svm_get_svm_type(&model), SvmType::CSvc);
        assert_eq!(svm_get_nr_class(&model), 2);
        assert_eq!(svm_get_nr_sv(&model), model.sv.len());
        assert_eq!(svm_get_labels(&model), model.label.as_slice());
        assert_eq!(svm_get_sv_indices(&model), model.sv_indices.as_slice());
        assert!(!svm_check_probability_model(&model));
        assert_eq!(svm_get_svr_probability(&model), None);
    }

    #[test]
    fn probability_helpers_by_svm_type() {
        let svm = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];

        let csvc_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![svm.clone()],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![1.0],
            prob_b: vec![-0.5],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };
        assert!(csvc_model.has_probability_model());
        assert!(svm_check_probability_model(&csvc_model));
        assert_eq!(svm_get_svr_probability(&csvc_model), None);

        let eps_svr_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::EpsilonSvr,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![svm.clone()],
            sv_coef: vec![vec![0.8]],
            rho: vec![0.0],
            prob_a: vec![0.123],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(eps_svr_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&eps_svr_model), Some(0.123));

        let one_class_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::OneClass,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![svm],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![0.1; 10],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(one_class_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&one_class_model), None);
    }
}