/// Selects which SVM formulation a set of parameters describes.
///
/// The explicit discriminants double as the integer codes used by the
/// optional serde implementations in this file, so they must stay stable.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum SvmType {
    /// C-SVC, code `0`. `SvmParameter::validate` requires `c > 0` for it.
    CSvc = 0,
    /// nu-SVC, code `1`. `SvmParameter::validate` requires `nu` in `(0, 1]`.
    NuSvc = 1,
    /// One-class SVM, code `2`. `SvmParameter::validate` requires `nu` in `(0, 1]`.
    OneClass = 2,
    /// epsilon-SVR, code `3`. `SvmParameter::validate` requires `c > 0` and `p >= 0`.
    EpsilonSvr = 3,
    /// nu-SVR, code `4`. `SvmParameter::validate` requires `c > 0` and `nu` in `(0, 1]`.
    NuSvr = 4,
}
28
/// Selects the kernel function used by a set of parameters.
///
/// The explicit discriminants double as the integer codes used by the
/// optional serde implementations in this file, so they must stay stable.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum KernelType {
    /// Linear kernel, code `0`.
    Linear = 0,
    /// Polynomial kernel, code `1`. Validation requires `gamma >= 0` and `degree >= 0`.
    Polynomial = 1,
    /// RBF kernel, code `2`. Validation requires `gamma >= 0`.
    Rbf = 2,
    /// Sigmoid kernel, code `3`. Validation requires `gamma >= 0`.
    Sigmoid = 3,
    /// Precomputed kernel, code `4`. `check_parameter` requires every instance
    /// to begin with a `0:sample_serial_number` node.
    Precomputed = 4,
}
47
/// A single `index:value` entry of an instance.
///
/// Instances are stored as `Vec<SvmNode>`. For precomputed kernels,
/// `check_parameter` expects the first node of every instance to have
/// `index == 0` and to carry the sample serial number in `value`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SvmNode {
    /// Feature index (or `0` for the serial-number node of a precomputed row).
    // NOTE(review): presumably 1-based for ordinary features — confirm against the loader.
    pub index: i32,
    /// Value associated with `index`.
    pub value: f64,
}
62
63#[derive(Debug, Clone, PartialEq)]
70pub struct SvmProblem {
71 pub labels: Vec<f64>,
73 pub instances: Vec<Vec<SvmNode>>,
75}
76
77#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
81#[derive(Debug, Clone, PartialEq)]
82pub struct SvmParameter {
83 pub svm_type: SvmType,
85 pub kernel_type: KernelType,
87 pub degree: i32,
89 pub gamma: f64,
92 pub coef0: f64,
94 pub cache_size: f64,
96 pub eps: f64,
98 pub c: f64,
100 pub weight: Vec<(i32, f64)>,
102 pub nu: f64,
104 pub p: f64,
106 pub shrinking: bool,
108 pub probability: bool,
110}
111
112impl Default for SvmParameter {
113 fn default() -> Self {
114 Self {
115 svm_type: SvmType::CSvc,
116 kernel_type: KernelType::Rbf,
117 degree: 3,
118 gamma: 0.0, coef0: 0.0,
120 cache_size: 100.0,
121 eps: 0.001,
122 c: 1.0,
123 weight: Vec::new(),
124 nu: 0.5,
125 p: 0.1,
126 shrinking: true,
127 probability: false,
128 }
129 }
130}
131
132impl SvmParameter {
133 pub fn validate(&self) -> Result<(), crate::error::SvmError> {
139 use crate::error::SvmError;
140
141 if matches!(
143 self.kernel_type,
144 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
145 ) && self.gamma < 0.0
146 {
147 return Err(SvmError::InvalidParameter("gamma < 0".into()));
148 }
149
150 if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
152 return Err(SvmError::InvalidParameter(
153 "degree of polynomial kernel < 0".into(),
154 ));
155 }
156
157 if self.cache_size <= 0.0 {
158 return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
159 }
160
161 if self.eps <= 0.0 {
162 return Err(SvmError::InvalidParameter("eps <= 0".into()));
163 }
164
165 if matches!(
167 self.svm_type,
168 SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
169 ) && self.c <= 0.0
170 {
171 return Err(SvmError::InvalidParameter("C <= 0".into()));
172 }
173
174 if matches!(
176 self.svm_type,
177 SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
178 ) && (self.nu <= 0.0 || self.nu > 1.0)
179 {
180 return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
181 }
182
183 if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
185 return Err(SvmError::InvalidParameter("p < 0".into()));
186 }
187
188 Ok(())
189 }
190}
191
192pub fn check_parameter(
196 problem: &SvmProblem,
197 param: &SvmParameter,
198) -> Result<(), crate::error::SvmError> {
199 use crate::error::SvmError;
200
201 param.validate()?;
203
204 if problem.labels.len() != problem.instances.len() {
205 return Err(SvmError::InvalidParameter(format!(
206 "labels length ({}) does not match instance length ({})",
207 problem.labels.len(),
208 problem.instances.len()
209 )));
210 }
211
212 if problem.labels.is_empty() {
213 return Err(SvmError::InvalidParameter(
214 "problem has no instances".into(),
215 ));
216 }
217
218 if param.kernel_type == KernelType::Precomputed {
219 let upper = problem.instances.len() as f64;
220 for (row, instance) in problem.instances.iter().enumerate() {
221 let first = instance.first().ok_or_else(|| {
222 SvmError::InvalidParameter(format!(
223 "precomputed kernel row {} is missing 0:sample_serial_number",
224 row + 1
225 ))
226 })?;
227 if first.index != 0
228 || !first.value.is_finite()
229 || first.value < 1.0
230 || first.value > upper
231 || first.value.fract() != 0.0
232 {
233 return Err(SvmError::InvalidParameter(format!(
234 "precomputed kernel row {} must start with 0:sample_serial_number in [1, {}]",
235 row + 1,
236 problem.instances.len()
237 )));
238 }
239 }
240 }
241
242 if param.svm_type == SvmType::NuSvc {
249 let mut class_counts: Vec<(i32, usize)> = Vec::new();
250 for &y in &problem.labels {
251 let label = y as i32;
252 if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
253 entry.1 += 1;
254 } else {
255 class_counts.push((label, 1));
256 }
257 }
258
259 for (i, &(_, n1)) in class_counts.iter().enumerate() {
260 for &(_, n2) in &class_counts[i + 1..] {
261 if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
262 return Err(SvmError::InvalidParameter(
263 "specified nu is infeasible".into(),
264 ));
265 }
266 }
267 }
268 }
269
270 Ok(())
271}
272
273#[cfg_attr(feature = "serde", derive(serde::Serialize))]
284#[derive(Debug, Clone, PartialEq)]
285pub struct SvmModel {
286 pub param: SvmParameter,
288 pub nr_class: usize,
290 pub sv: Vec<Vec<SvmNode>>,
292 pub sv_coef: Vec<Vec<f64>>,
295 pub rho: Vec<f64>,
297 pub prob_a: Vec<f64>,
300 pub prob_b: Vec<f64>,
303 pub prob_density_marks: Vec<f64>,
305 pub sv_indices: Vec<usize>,
307 pub label: Vec<i32>,
309 pub n_sv: Vec<usize>,
311}
312
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl serde::Serialize for SvmType {
    /// Serializes the variant as its `i32` discriminant.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let code = *self as i32;
        serializer.serialize_i32(code)
    }
}
323
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> serde::Deserialize<'de> for SvmType {
    /// Reads an `i32` code and maps it back to the matching variant.
    ///
    /// Codes outside `0..=4` produce a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let code = <i32 as serde::Deserialize>::deserialize(deserializer)?;
        let svm_type = match code {
            0 => SvmType::CSvc,
            1 => SvmType::NuSvc,
            2 => SvmType::OneClass,
            3 => SvmType::EpsilonSvr,
            4 => SvmType::NuSvr,
            _ => {
                return Err(serde::de::Error::custom(format!(
                    "unknown SvmType code {code}"
                )))
            }
        };
        Ok(svm_type)
    }
}
343
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl serde::Serialize for KernelType {
    /// Serializes the variant as its `i32` discriminant.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let code = *self as i32;
        serializer.serialize_i32(code)
    }
}
354
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> serde::Deserialize<'de> for KernelType {
    /// Reads an `i32` code and maps it back to the matching variant.
    ///
    /// Codes outside `0..=4` produce a custom deserialization error.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let code = <i32 as serde::Deserialize>::deserialize(deserializer)?;
        let kernel_type = match code {
            0 => KernelType::Linear,
            1 => KernelType::Polynomial,
            2 => KernelType::Rbf,
            3 => KernelType::Sigmoid,
            4 => KernelType::Precomputed,
            _ => {
                return Err(serde::de::Error::custom(format!(
                    "unknown KernelType code {code}"
                )))
            }
        };
        Ok(kernel_type)
    }
}
374
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
impl<'de> serde::Deserialize<'de> for SvmModel {
    /// Deserializes a model and rejects it unless `crate::io::validate_model`
    /// accepts it.
    ///
    /// The fields are first read into a private mirror struct so that the
    /// derived field handling can be reused while still running validation
    /// before an `SvmModel` is handed out.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error for malformed input, or a custom error
    /// wrapping the validation failure.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // The derived `Serialize` for `SvmModel` writes the container name
        // "SvmModel". Expect the same name here so that formats which record
        // and check struct names can round-trip a model.
        #[derive(serde::Deserialize)]
        #[serde(rename = "SvmModel")]
        struct RawSvmModel {
            param: SvmParameter,
            nr_class: usize,
            sv: Vec<Vec<SvmNode>>,
            sv_coef: Vec<Vec<f64>>,
            rho: Vec<f64>,
            prob_a: Vec<f64>,
            prob_b: Vec<f64>,
            prob_density_marks: Vec<f64>,
            sv_indices: Vec<usize>,
            label: Vec<i32>,
            n_sv: Vec<usize>,
        }

        let RawSvmModel {
            param,
            nr_class,
            sv,
            sv_coef,
            rho,
            prob_a,
            prob_b,
            prob_density_marks,
            sv_indices,
            label,
            n_sv,
        } = <RawSvmModel as serde::Deserialize>::deserialize(deserializer)?;

        let model = SvmModel {
            param,
            nr_class,
            sv,
            sv_coef,
            rho,
            prob_a,
            prob_b,
            prob_density_marks,
            sv_indices,
            label,
            n_sv,
        };
        crate::io::validate_model(&model).map_err(serde::de::Error::custom)?;
        Ok(model)
    }
}
415
416impl SvmModel {
417 pub fn svm_type(&self) -> SvmType {
419 self.param.svm_type
420 }
421
422 pub fn class_count(&self) -> usize {
424 self.nr_class
425 }
426
427 pub fn labels(&self) -> &[i32] {
429 &self.label
430 }
431
432 pub fn support_vector_indices(&self) -> &[usize] {
434 &self.sv_indices
435 }
436
437 pub fn support_vector_count(&self) -> usize {
439 self.sv.len()
440 }
441
442 pub fn svr_probability(&self) -> Option<f64> {
444 match self.param.svm_type {
445 SvmType::EpsilonSvr | SvmType::NuSvr => self.prob_a.first().copied(),
446 _ => None,
447 }
448 }
449
450 pub fn has_probability_model(&self) -> bool {
452 match self.param.svm_type {
453 SvmType::CSvc | SvmType::NuSvc => !self.prob_a.is_empty() && !self.prob_b.is_empty(),
454 SvmType::EpsilonSvr | SvmType::NuSvr => !self.prob_a.is_empty(),
455 SvmType::OneClass => !self.prob_density_marks.is_empty(),
456 }
457 }
458}
459
460pub fn svm_get_svm_type(model: &SvmModel) -> SvmType {
462 model.svm_type()
463}
464
465pub fn svm_get_nr_class(model: &SvmModel) -> usize {
467 model.class_count()
468}
469
470pub fn svm_get_labels(model: &SvmModel) -> &[i32] {
472 model.labels()
473}
474
475pub fn svm_get_sv_indices(model: &SvmModel) -> &[usize] {
477 model.support_vector_indices()
478}
479
480pub fn svm_get_nr_sv(model: &SvmModel) -> usize {
482 model.support_vector_count()
483}
484
485pub fn svm_get_svr_probability(model: &SvmModel) -> Option<f64> {
487 model.svr_probability()
488}
489
490pub fn svm_check_probability_model(model: &SvmModel) -> bool {
492 model.has_probability_model()
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::svm_train;
    use std::path::PathBuf;

    /// Directory holding the sample data sets, resolved as `../../data`
    /// relative to this crate's manifest directory.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        // Above the upper bound.
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(p.validate().is_err());

        // The lower bound is exclusive.
        let p2 = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(p2.validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn check_parameter_rejects_empty_problem() {
        let problem = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("problem has no instances"));
    }

    #[test]
    fn check_parameter_rejects_label_instance_length_mismatch() {
        let problem = SvmProblem {
            labels: vec![1.0],
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("does not match instance length"));
    }

    #[test]
    fn check_parameter_rejects_precomputed_rows_without_sample_serial_number() {
        // The first row is empty, so it has no `0:serial` node at all.
        let problem = SvmProblem {
            labels: vec![1.0, -1.0],
            instances: vec![
                vec![],
                vec![SvmNode {
                    index: 0,
                    value: 2.0,
                }],
            ],
        };
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(err.to_string().contains("missing 0:sample_serial_number"));
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // Two balanced classes of three samples each.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        // 0.5 * 6 / 2 = 1.5 <= 3, so this is feasible.
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();

        // 0.9 * 6 / 2 = 2.7 <= 3, so this is still feasible.
        let borderline = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        check_parameter(&problem, &borderline).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // Five samples against one: 0.5 * 6 / 2 = 1.5 > min(5, 1) = 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        let result = check_parameter(&problem, &param);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("infeasible"));
    }

    #[test]
    fn c_api_style_model_helpers() {
        let problem = crate::io::load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            gamma: 1.0 / 13.0,
            ..Default::default()
        };
        let model = svm_train(&problem, &param);

        assert_eq!(svm_get_svm_type(&model), SvmType::CSvc);
        assert_eq!(svm_get_nr_class(&model), 2);
        assert_eq!(svm_get_nr_sv(&model), model.sv.len());
        assert_eq!(svm_get_labels(&model), model.label.as_slice());
        assert_eq!(svm_get_sv_indices(&model), model.sv_indices.as_slice());
        assert!(!svm_check_probability_model(&model));
        assert_eq!(svm_get_svr_probability(&model), None);
    }

    #[test]
    fn probability_helpers_by_svm_type() {
        // A single support vector shared by all three hand-built models.
        let svm = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];

        // Classification: needs both prob_a and prob_b.
        let csvc_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![svm.clone()],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![1.0],
            prob_b: vec![-0.5],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };
        assert!(csvc_model.has_probability_model());
        assert!(svm_check_probability_model(&csvc_model));
        assert_eq!(svm_get_svr_probability(&csvc_model), None);

        // Regression: prob_a alone is enough, and its first element is reported.
        let eps_svr_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::EpsilonSvr,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![svm.clone()],
            sv_coef: vec![vec![0.8]],
            rho: vec![0.0],
            prob_a: vec![0.123],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(eps_svr_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&eps_svr_model), Some(0.123));

        // One-class: the density marks carry the probability information.
        let one_class_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::OneClass,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![svm],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![0.1; 10],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(one_class_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&one_class_model), None);
    }
}