/// SVM formulation selected through [`SvmParameter::svm_type`].
///
/// `#[repr(i32)]` together with the explicit discriminants pins every variant
/// to a fixed integer code (`0..=4`), so `variant as i32` is stable.
/// NOTE(review): the codes presumably mirror LIBSVM's `svm_type` values — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum SvmType {
    /// C-SVC (code 0); [`SvmParameter::validate`] requires `c > 0`.
    CSvc = 0,
    /// nu-SVC (code 1); `nu` must lie in `(0, 1]`, and [`check_parameter`]
    /// additionally tests that `nu` is feasible for the class sizes.
    NuSvc = 1,
    /// One-class SVM (code 2); `nu` must lie in `(0, 1]`.
    OneClass = 2,
    /// Epsilon-SVR (code 3); requires `c > 0` and `p >= 0`.
    EpsilonSvr = 3,
    /// nu-SVR (code 4); requires `c > 0` and `nu` in `(0, 1]`.
    NuSvr = 4,
}
28
/// Kernel function selected through [`SvmParameter::kernel_type`].
///
/// `#[repr(i32)]` together with the explicit discriminants pins every variant
/// to a fixed integer code (`0..=4`).
/// NOTE(review): the codes presumably mirror LIBSVM's `kernel_type` values — confirm.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(i32)]
pub enum KernelType {
    /// Linear kernel (code 0); neither `gamma` nor `degree` is validated for it.
    Linear = 0,
    /// Polynomial kernel (code 1); `gamma` and `degree` must not be negative.
    Polynomial = 1,
    /// RBF kernel (code 2); `gamma` must not be negative.
    Rbf = 2,
    /// Sigmoid kernel (code 3); `gamma` must not be negative.
    Sigmoid = 3,
    /// Precomputed kernel (code 4); [`check_parameter`] requires every row to
    /// start with a `0:sample_serial_number` node.
    Precomputed = 4,
}
47
/// One `index:value` entry of an instance row.
///
/// For precomputed kernels, [`check_parameter`] expects the first node of
/// every row to have `index == 0` and a `value` holding the 1-based sample
/// serial number.
/// NOTE(review): for other kernels this is presumably a sparse feature
/// (feature index, feature value) — confirm against the loader.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SvmNode {
    /// Position of the entry (`0` marks the serial-number node of a precomputed row).
    pub index: i32,
    /// Value stored at `index`.
    pub value: f64,
}
61
62#[derive(Debug, Clone, PartialEq)]
69pub struct SvmProblem {
70 pub labels: Vec<f64>,
72 pub instances: Vec<Vec<SvmNode>>,
74}
75
76#[derive(Debug, Clone, PartialEq)]
80pub struct SvmParameter {
81 pub svm_type: SvmType,
83 pub kernel_type: KernelType,
85 pub degree: i32,
87 pub gamma: f64,
90 pub coef0: f64,
92 pub cache_size: f64,
94 pub eps: f64,
96 pub c: f64,
98 pub weight: Vec<(i32, f64)>,
100 pub nu: f64,
102 pub p: f64,
104 pub shrinking: bool,
106 pub probability: bool,
108}
109
110impl Default for SvmParameter {
111 fn default() -> Self {
112 Self {
113 svm_type: SvmType::CSvc,
114 kernel_type: KernelType::Rbf,
115 degree: 3,
116 gamma: 0.0, coef0: 0.0,
118 cache_size: 100.0,
119 eps: 0.001,
120 c: 1.0,
121 weight: Vec::new(),
122 nu: 0.5,
123 p: 0.1,
124 shrinking: true,
125 probability: false,
126 }
127 }
128}
129
130impl SvmParameter {
131 pub fn validate(&self) -> Result<(), crate::error::SvmError> {
137 use crate::error::SvmError;
138
139 if matches!(
141 self.kernel_type,
142 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
143 ) && self.gamma < 0.0
144 {
145 return Err(SvmError::InvalidParameter("gamma < 0".into()));
146 }
147
148 if self.kernel_type == KernelType::Polynomial && self.degree < 0 {
150 return Err(SvmError::InvalidParameter(
151 "degree of polynomial kernel < 0".into(),
152 ));
153 }
154
155 if self.cache_size <= 0.0 {
156 return Err(SvmError::InvalidParameter("cache_size <= 0".into()));
157 }
158
159 if self.eps <= 0.0 {
160 return Err(SvmError::InvalidParameter("eps <= 0".into()));
161 }
162
163 if matches!(
165 self.svm_type,
166 SvmType::CSvc | SvmType::EpsilonSvr | SvmType::NuSvr
167 ) && self.c <= 0.0
168 {
169 return Err(SvmError::InvalidParameter("C <= 0".into()));
170 }
171
172 if matches!(
174 self.svm_type,
175 SvmType::NuSvc | SvmType::OneClass | SvmType::NuSvr
176 ) && (self.nu <= 0.0 || self.nu > 1.0)
177 {
178 return Err(SvmError::InvalidParameter("nu <= 0 or nu > 1".into()));
179 }
180
181 if self.svm_type == SvmType::EpsilonSvr && self.p < 0.0 {
183 return Err(SvmError::InvalidParameter("p < 0".into()));
184 }
185
186 Ok(())
187 }
188}
189
190pub fn check_parameter(
194 problem: &SvmProblem,
195 param: &SvmParameter,
196) -> Result<(), crate::error::SvmError> {
197 use crate::error::SvmError;
198
199 param.validate()?;
201
202 if problem.labels.len() != problem.instances.len() {
203 return Err(SvmError::InvalidParameter(format!(
204 "labels length ({}) does not match instance length ({})",
205 problem.labels.len(),
206 problem.instances.len()
207 )));
208 }
209
210 if problem.labels.is_empty() {
211 return Err(SvmError::InvalidParameter(
212 "problem has no instances".into(),
213 ));
214 }
215
216 if param.kernel_type == KernelType::Precomputed {
217 let upper = problem.instances.len() as f64;
218 for (row, instance) in problem.instances.iter().enumerate() {
219 let first = instance.first().ok_or_else(|| {
220 SvmError::InvalidParameter(format!(
221 "precomputed kernel row {} is missing 0:sample_serial_number",
222 row + 1
223 ))
224 })?;
225 if first.index != 0
226 || !first.value.is_finite()
227 || first.value < 1.0
228 || first.value > upper
229 || first.value.fract() != 0.0
230 {
231 return Err(SvmError::InvalidParameter(format!(
232 "precomputed kernel row {} must start with 0:sample_serial_number in [1, {}]",
233 row + 1,
234 problem.instances.len()
235 )));
236 }
237 }
238 }
239
240 if param.svm_type == SvmType::NuSvc {
247 let mut class_counts: Vec<(i32, usize)> = Vec::new();
248 for &y in &problem.labels {
249 let label = y as i32;
250 if let Some(entry) = class_counts.iter_mut().find(|(l, _)| *l == label) {
251 entry.1 += 1;
252 } else {
253 class_counts.push((label, 1));
254 }
255 }
256
257 for (i, &(_, n1)) in class_counts.iter().enumerate() {
258 for &(_, n2) in &class_counts[i + 1..] {
259 if param.nu * (n1 + n2) as f64 / 2.0 > n1.min(n2) as f64 {
260 return Err(SvmError::InvalidParameter(
261 "specified nu is infeasible".into(),
262 ));
263 }
264 }
265 }
266 }
267
268 Ok(())
269}
270
271#[derive(Debug, Clone, PartialEq)]
282pub struct SvmModel {
283 pub param: SvmParameter,
285 pub nr_class: usize,
287 pub sv: Vec<Vec<SvmNode>>,
289 pub sv_coef: Vec<Vec<f64>>,
292 pub rho: Vec<f64>,
294 pub prob_a: Vec<f64>,
297 pub prob_b: Vec<f64>,
300 pub prob_density_marks: Vec<f64>,
302 pub sv_indices: Vec<usize>,
304 pub label: Vec<i32>,
306 pub n_sv: Vec<usize>,
308}
309
310impl SvmModel {
311 pub fn svm_type(&self) -> SvmType {
313 self.param.svm_type
314 }
315
316 pub fn class_count(&self) -> usize {
318 self.nr_class
319 }
320
321 pub fn labels(&self) -> &[i32] {
323 &self.label
324 }
325
326 pub fn support_vector_indices(&self) -> &[usize] {
328 &self.sv_indices
329 }
330
331 pub fn support_vector_count(&self) -> usize {
333 self.sv.len()
334 }
335
336 pub fn svr_probability(&self) -> Option<f64> {
338 match self.param.svm_type {
339 SvmType::EpsilonSvr | SvmType::NuSvr => self.prob_a.first().copied(),
340 _ => None,
341 }
342 }
343
344 pub fn has_probability_model(&self) -> bool {
346 match self.param.svm_type {
347 SvmType::CSvc | SvmType::NuSvc => !self.prob_a.is_empty() && !self.prob_b.is_empty(),
348 SvmType::EpsilonSvr | SvmType::NuSvr => !self.prob_a.is_empty(),
349 SvmType::OneClass => !self.prob_density_marks.is_empty(),
350 }
351 }
352}
353
354pub fn svm_get_svm_type(model: &SvmModel) -> SvmType {
356 model.svm_type()
357}
358
359pub fn svm_get_nr_class(model: &SvmModel) -> usize {
361 model.class_count()
362}
363
364pub fn svm_get_labels(model: &SvmModel) -> &[i32] {
366 model.labels()
367}
368
369pub fn svm_get_sv_indices(model: &SvmModel) -> &[usize] {
371 model.support_vector_indices()
372}
373
374pub fn svm_get_nr_sv(model: &SvmModel) -> usize {
376 model.support_vector_count()
377}
378
379pub fn svm_get_svr_probability(model: &SvmModel) -> Option<f64> {
381 model.svr_probability()
382}
383
384pub fn svm_check_probability_model(model: &SvmModel) -> bool {
386 model.has_probability_model()
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::train::svm_train;
    use std::path::PathBuf;

    /// Path of the `data` directory located two levels above the crate manifest.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    #[test]
    fn default_params_are_valid() {
        SvmParameter::default().validate().unwrap();
    }

    #[test]
    fn negative_gamma_rejected() {
        let p = SvmParameter {
            gamma: -1.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_cache_rejected() {
        let p = SvmParameter {
            cache_size: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn zero_c_rejected() {
        let p = SvmParameter {
            c: 0.0,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn nu_out_of_range_rejected() {
        let p = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 1.5,
            ..Default::default()
        };
        assert!(p.validate().is_err());

        let p2 = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.0,
            ..Default::default()
        };
        assert!(p2.validate().is_err());
    }

    #[test]
    fn negative_p_rejected_for_svr() {
        let p = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            p: -0.1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn negative_poly_degree_rejected() {
        let p = SvmParameter {
            kernel_type: KernelType::Polynomial,
            degree: -1,
            ..Default::default()
        };
        assert!(p.validate().is_err());
    }

    #[test]
    fn check_parameter_rejects_empty_problem() {
        let problem = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("problem has no instances"));
    }

    #[test]
    fn check_parameter_rejects_label_instance_length_mismatch() {
        let problem = SvmProblem {
            labels: vec![1.0],
            instances: Vec::new(),
        };
        let err = check_parameter(&problem, &SvmParameter::default()).unwrap_err();
        assert!(err.to_string().contains("does not match instance length"));
    }

    #[test]
    fn check_parameter_rejects_precomputed_rows_without_sample_serial_number() {
        // The first row is empty, so it has no `0:sample_serial_number` node.
        let problem = SvmProblem {
            labels: vec![1.0, -1.0],
            instances: vec![
                vec![],
                vec![SvmNode {
                    index: 0,
                    value: 2.0,
                }],
            ],
        };
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(err.to_string().contains("missing 0:sample_serial_number"));
    }

    #[test]
    fn nu_svc_feasibility_check() {
        // Two balanced classes of three: nu * 6 / 2 must not exceed 3.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let ok_param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        check_parameter(&problem, &ok_param).unwrap();

        // 0.9 * 6 / 2 = 2.7, still not above min(3, 3) = 3.
        let borderline = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.9,
            ..Default::default()
        };
        check_parameter(&problem, &borderline).unwrap();
    }

    #[test]
    fn nu_svc_infeasible() {
        // Class sizes 5 and 1: 0.5 * 6 / 2 = 1.5 exceeds min(5, 1) = 1.
        let problem = SvmProblem {
            labels: vec![1.0, 1.0, 1.0, 1.0, 1.0, 2.0],
            instances: vec![vec![]; 6],
        };
        let param = SvmParameter {
            svm_type: SvmType::NuSvc,
            nu: 0.5,
            ..Default::default()
        };
        let err = check_parameter(&problem, &param).unwrap_err();
        assert!(err.to_string().contains("infeasible"));
    }

    #[test]
    fn c_api_style_model_helpers() {
        let problem = crate::io::load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            gamma: 1.0 / 13.0,
            ..Default::default()
        };
        let model = svm_train(&problem, &param);

        assert_eq!(svm_get_svm_type(&model), SvmType::CSvc);
        assert_eq!(svm_get_nr_class(&model), 2);
        assert_eq!(svm_get_nr_sv(&model), model.sv.len());
        assert_eq!(svm_get_labels(&model), model.label.as_slice());
        assert_eq!(svm_get_sv_indices(&model), model.sv_indices.as_slice());
        assert!(!svm_check_probability_model(&model));
        assert_eq!(svm_get_svr_probability(&model), None);
    }

    #[test]
    fn probability_helpers_by_svm_type() {
        let support_vector = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];

        let csvc_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![support_vector.clone()],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![1.0],
            prob_b: vec![-0.5],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };
        assert!(csvc_model.has_probability_model());
        assert!(svm_check_probability_model(&csvc_model));
        assert_eq!(svm_get_svr_probability(&csvc_model), None);

        let eps_svr_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::EpsilonSvr,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![support_vector.clone()],
            sv_coef: vec![vec![0.8]],
            rho: vec![0.0],
            prob_a: vec![0.123],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(eps_svr_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&eps_svr_model), Some(0.123));

        let one_class_model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::OneClass,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![support_vector],
            sv_coef: vec![vec![1.0]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![0.1; 10],
            sv_indices: vec![1],
            label: vec![],
            n_sv: vec![],
        };
        assert!(one_class_model.has_probability_model());
        assert_eq!(svm_get_svr_probability(&one_class_model), None);
    }
}