1use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::util::MAX_FEATURE_INDEX;
11use crate::util::parse_feature_index;
12use crate::types::*;
13
14use std::fmt;
24
/// Display adapter that renders an `f64` like C's `printf("%.<precision>g", value)`.
///
/// Finite values follow the `%g` rules: the decimal exponent `X` of the value
/// *after rounding to `precision` significant digits* selects the notation
/// (fixed when `-4 <= X < precision`, scientific otherwise), trailing zeros of
/// the fractional part are removed, and the exponent is printed with a sign and
/// at least two digits. Non-finite values are delegated to Rust's default `f64`
/// formatting.
struct Gfmt {
    /// Value to render.
    value: f64,
    /// Number of significant digits (the `P` in `%.Pg`); zero is treated as one.
    precision: usize,
}

impl Gfmt {
    /// Wraps `value` so that it displays with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Removes trailing fractional zeros (and a then-dangling `.`).
    ///
    /// Strings without a decimal point are returned untouched so that the
    /// significant zeros of an integer such as `100000` are never stripped.
    fn strip_trailing_zeros(digits: &str) -> &str {
        if digits.contains('.') {
            digits.trim_end_matches('0').trim_end_matches('.')
        } else {
            digits
        }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        // C treats a precision of zero as one.
        let p = self.precision.max(1);

        if !v.is_finite() {
            return write!(f, "{}", v);
        }

        if v == 0.0 {
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // Format in scientific notation first: this yields the exponent of the
        // value *after* rounding to `p` significant digits, which is what `%g`
        // uses to choose between fixed and scientific output. Deriving it from
        // `log10` of the unrounded value picks the wrong branch whenever
        // rounding carries into the next power of ten (e.g. 999999.7 at p = 6).
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let split = sci
            .split_once('e')
            .and_then(|(mantissa, exp)| exp.parse::<i32>().ok().map(|exp| (mantissa, exp)));
        let (mantissa, exp) = match split {
            Some(parts) => parts,
            // Not expected for finite input; fall back to the raw rendering.
            None => return f.write_str(&sci),
        };

        // Number of decimals for fixed notation, or `None` for scientific.
        let fixed_decimals = if exp < -4 {
            None
        } else if exp < 0 {
            Some(p - 1 + exp.unsigned_abs() as usize)
        } else {
            // `None` exactly when `exp >= p`.
            (p - 1).checked_sub(exp as usize)
        };

        match fixed_decimals {
            Some(decimals) => {
                let fixed = format!("{:.prec$}", v, prec = decimals);
                f.write_str(Self::strip_trailing_zeros(&fixed))
            }
            None => {
                let sign = if exp < 0 { '-' } else { '+' };
                write!(
                    f,
                    "{}e{}{:02}",
                    Self::strip_trailing_zeros(mantissa),
                    sign,
                    exp.unsigned_abs()
                )
            }
        }
    }
}
89
90fn fmt_17g(v: f64) -> Gfmt {
92 Gfmt::new(v, 17)
93}
94
95fn fmt_8g(v: f64) -> Gfmt {
97 Gfmt::new(v, 8)
98}
99
100pub fn format_g(v: f64) -> String {
102 format!("{}", Gfmt::new(v, 6))
103}
104
105pub fn format_17g(v: f64) -> String {
107 format!("{}", Gfmt::new(v, 17))
108}
109
/// Model-file names of the SVM types, looked up by `SvmType as usize` in
/// `svm_type_to_str`. The entries are the same strings `str_to_svm_type`
/// accepts; the order must match the enum's discriminants (declared outside
/// this file — verify when adding variants).
const SVM_TYPE_TABLE: &[&str] = &["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"];
/// Model-file names of the kernel types, looked up by `KernelType as usize` in
/// `kernel_type_to_str`. The entries are the same strings `str_to_kernel_type`
/// accepts; the order must match the enum's discriminants (declared outside
/// this file — verify when adding variants).
const KERNEL_TYPE_TABLE: &[&str] = &["linear", "polynomial", "rbf", "sigmoid", "precomputed"];
114
115fn svm_type_to_str(t: SvmType) -> &'static str {
116 SVM_TYPE_TABLE[t as usize]
117}
118
119fn kernel_type_to_str(t: KernelType) -> &'static str {
120 KERNEL_TYPE_TABLE[t as usize]
121}
122
123fn str_to_svm_type(s: &str) -> Option<SvmType> {
124 match s {
125 "c_svc" => Some(SvmType::CSvc),
126 "nu_svc" => Some(SvmType::NuSvc),
127 "one_class" => Some(SvmType::OneClass),
128 "epsilon_svr" => Some(SvmType::EpsilonSvr),
129 "nu_svr" => Some(SvmType::NuSvr),
130 _ => None,
131 }
132}
133
134fn str_to_kernel_type(s: &str) -> Option<KernelType> {
135 match s {
136 "linear" => Some(KernelType::Linear),
137 "polynomial" => Some(KernelType::Polynomial),
138 "rbf" => Some(KernelType::Rbf),
139 "sigmoid" => Some(KernelType::Sigmoid),
140 "precomputed" => Some(KernelType::Precomputed),
141 _ => None,
142 }
143}
144
145pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
151 let file = std::fs::File::open(path)?;
152 let reader = std::io::BufReader::new(file);
153 load_problem_from_reader(reader)
154}
155
156pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
158 let mut labels = Vec::new();
159 let mut instances = Vec::new();
160
161 for (line_idx, line_result) in reader.lines().enumerate() {
162 let line = line_result?;
163 let line = line.trim();
164 if line.is_empty() {
165 continue;
166 }
167
168 let line_num = line_idx + 1;
169 let mut parts = line.split_whitespace();
170
171 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
173 line: line_num,
174 message: "missing label".into(),
175 })?;
176 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
177 line: line_num,
178 message: format!("invalid label: {}", label_str),
179 })?;
180
181 let mut nodes = Vec::new();
183 let mut prev_index: i32 = 0;
184 for token in parts {
185 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
186 line: line_num,
187 message: format!("expected index:value, got: {}", token),
188 })?;
189 let index: i32 = parse_feature_index_problem_line(line_num, idx_str)?;
190
191 if !nodes.is_empty() && index <= prev_index {
192 return Err(SvmError::ParseError {
193 line: line_num,
194 message: format!(
195 "feature indices must be ascending: {} follows {}",
196 index, prev_index
197 ),
198 });
199 }
200 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
201 line: line_num,
202 message: format!("invalid value: {}", val_str),
203 })?;
204 prev_index = index;
205 nodes.push(SvmNode { index, value });
206 }
207
208 labels.push(label);
209 instances.push(nodes);
210 }
211
212 Ok(SvmProblem { labels, instances })
213}
214
/// Largest `nr_class` header value `load_model_from_reader` accepts; larger
/// values are rejected with a model-format error.
const MAX_NR_CLASS: usize = 65535;
/// Largest `total_sv` header value `load_model_from_reader` accepts; larger
/// values are rejected with a model-format error.
const MAX_TOTAL_SV: usize = 10_000_000;
219
220pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
222 let file = std::fs::File::create(path)?;
223 let writer = std::io::BufWriter::new(file);
224 save_model_to_writer(writer, model)
225}
226
227pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
229 let param = &model.param;
230
231 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
232 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
233
234 if param.kernel_type == KernelType::Polynomial {
235 writeln!(w, "degree {}", param.degree)?;
236 }
237 if matches!(
238 param.kernel_type,
239 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
240 ) {
241 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
242 }
243 if matches!(
244 param.kernel_type,
245 KernelType::Polynomial | KernelType::Sigmoid
246 ) {
247 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
248 }
249
250 let nr_class = model.nr_class;
251 writeln!(w, "nr_class {}", nr_class)?;
252 writeln!(w, "total_sv {}", model.sv.len())?;
253
254 write!(w, "rho")?;
256 for r in &model.rho {
257 write!(w, " {}", fmt_17g(*r))?;
258 }
259 writeln!(w)?;
260
261 if !model.label.is_empty() {
263 write!(w, "label")?;
264 for l in &model.label {
265 write!(w, " {}", l)?;
266 }
267 writeln!(w)?;
268 }
269
270 if !model.prob_a.is_empty() {
272 write!(w, "probA")?;
273 for v in &model.prob_a {
274 write!(w, " {}", fmt_17g(*v))?;
275 }
276 writeln!(w)?;
277 }
278
279 if !model.prob_b.is_empty() {
281 write!(w, "probB")?;
282 for v in &model.prob_b {
283 write!(w, " {}", fmt_17g(*v))?;
284 }
285 writeln!(w)?;
286 }
287
288 if !model.prob_density_marks.is_empty() {
290 write!(w, "prob_density_marks")?;
291 for v in &model.prob_density_marks {
292 write!(w, " {}", fmt_17g(*v))?;
293 }
294 writeln!(w)?;
295 }
296
297 if !model.n_sv.is_empty() {
299 write!(w, "nr_sv")?;
300 for n in &model.n_sv {
301 write!(w, " {}", n)?;
302 }
303 writeln!(w)?;
304 }
305
306 writeln!(w, "SV")?;
308 let num_sv = model.sv.len();
309 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
312 for j in 0..num_coef_rows {
314 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
315 }
316 if model.param.kernel_type == KernelType::Precomputed {
318 if let Some(node) = model.sv[i].first() {
319 write!(w, "0:{} ", node.value as i32)?;
320 }
321 } else {
322 for node in &model.sv[i] {
323 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
324 }
325 }
326 writeln!(w)?;
327 }
328
329 Ok(())
330}
331
332pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
334 let file = std::fs::File::open(path)?;
335 let reader = std::io::BufReader::new(file);
336 load_model_from_reader(reader)
337}
338
339pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
341 let mut lines = reader.lines();
342
343 let mut param = SvmParameter::default();
345 let mut nr_class: usize = 0;
346 let mut total_sv: usize = 0;
347 let mut rho = Vec::new();
348 let mut label = Vec::new();
349 let mut prob_a = Vec::new();
350 let mut prob_b = Vec::new();
351 let mut prob_density_marks = Vec::new();
352 let mut n_sv = Vec::new();
353
354 let mut line_num: usize = 0;
356 loop {
357 let line = lines.next().ok_or_else(|| {
358 SvmError::ModelFormatError("unexpected end of file in header".into())
359 })??;
360 line_num += 1;
361 let line = line.trim().to_string();
362 if line.is_empty() {
363 continue;
364 }
365
366 let mut parts = line.split_whitespace();
367 let cmd = parts.next().unwrap();
368
369 match cmd {
370 "svm_type" => {
371 let val = parts.next().ok_or_else(|| {
372 SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
373 })?;
374 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
375 SvmError::ModelFormatError(format!(
376 "line {}: unknown svm_type: {}",
377 line_num, val
378 ))
379 })?;
380 }
381 "kernel_type" => {
382 let val = parts.next().ok_or_else(|| {
383 SvmError::ModelFormatError(format!(
384 "line {}: missing kernel_type value",
385 line_num
386 ))
387 })?;
388 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
389 SvmError::ModelFormatError(format!(
390 "line {}: unknown kernel_type: {}",
391 line_num, val
392 ))
393 })?;
394 }
395 "degree" => {
396 param.degree = parse_single(&mut parts, line_num, "degree")?;
397 }
398 "gamma" => {
399 param.gamma = parse_single(&mut parts, line_num, "gamma")?;
400 }
401 "coef0" => {
402 param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
403 }
404 "nr_class" => {
405 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
406 if nr_class > MAX_NR_CLASS {
407 return Err(SvmError::ModelFormatError(format!(
408 "line {}: nr_class exceeds limit ({})",
409 line_num, MAX_NR_CLASS
410 )));
411 }
412 }
413 "total_sv" => {
414 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
415 if total_sv > MAX_TOTAL_SV {
416 return Err(SvmError::ModelFormatError(format!(
417 "line {}: total_sv exceeds limit ({})",
418 line_num, MAX_TOTAL_SV
419 )));
420 }
421 }
422 "rho" => {
423 rho = parse_multiple(&mut parts, line_num, "rho")?;
424 }
425 "label" => {
426 label = parse_multiple(&mut parts, line_num, "label")?;
427 }
428 "probA" => {
429 prob_a = parse_multiple(&mut parts, line_num, "probA")?;
430 }
431 "probB" => {
432 prob_b = parse_multiple(&mut parts, line_num, "probB")?;
433 }
434 "prob_density_marks" => {
435 prob_density_marks =
436 parse_multiple(&mut parts, line_num, "prob_density_marks")?;
437 }
438 "nr_sv" => {
439 n_sv = parts
440 .map(|s| {
441 s.parse::<usize>().map_err(|_| {
442 SvmError::ModelFormatError(format!(
443 "line {}: invalid nr_sv value: {}",
444 line_num, s
445 ))
446 })
447 })
448 .collect::<Result<Vec<_>, _>>()?;
449 }
450 "SV" => break,
451 _ => {
452 return Err(SvmError::ModelFormatError(format!(
453 "line {}: unknown keyword: {}",
454 line_num, cmd
455 )));
456 }
457 }
458 }
459
460 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
462 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
463 let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
464
465 for _ in 0..total_sv {
466 let line = lines.next().ok_or_else(|| {
467 SvmError::ModelFormatError("unexpected end of file in SV section".into())
468 })??;
469 line_num += 1;
470 let line = line.trim();
471 if line.is_empty() {
472 continue;
473 }
474
475 let mut parts = line.split_whitespace();
476
477 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
479 let val_str = parts.next().ok_or_else(|| {
480 SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
481 })?;
482 let val: f64 = val_str.parse().map_err(|_| {
483 SvmError::ModelFormatError(format!(
484 "line {}: invalid sv_coef: {}",
485 line_num, val_str
486 ))
487 })?;
488 coef_row.push(val);
489 }
490
491 let mut nodes = Vec::new();
493 for token in parts {
494 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
495 SvmError::ModelFormatError(format!(
496 "line {}: expected index:value, got: {}",
497 line_num, token
498 ))
499 })?;
500 let index: i32 = parse_feature_index_model_line(line_num, idx_str)?;
501
502 let value: f64 = val_str.parse().map_err(|_| {
503 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
504 })?;
505 nodes.push(SvmNode { index, value });
506 }
507 sv.push(nodes);
508 }
509
510 Ok(SvmModel {
511 param,
512 nr_class,
513 sv,
514 sv_coef,
515 rho,
516 prob_a,
517 prob_b,
518 prob_density_marks,
519 sv_indices: Vec::new(), label,
521 n_sv,
522 })
523}
524
525fn parse_feature_index_problem_line(line_num: usize, idx_str: &str) -> Result<i32, SvmError> {
528 parse_feature_index(idx_str, MAX_FEATURE_INDEX).map_err(|msg| SvmError::ParseError {
529 line: line_num,
530 message: msg,
531 })
532}
533
534fn parse_feature_index_model_line(line_num: usize, idx_str: &str) -> Result<i32, SvmError> {
535 parse_feature_index(idx_str, MAX_FEATURE_INDEX)
536 .map_err(|msg| SvmError::ModelFormatError(format!("line {}: {}", line_num, msg)))
537}
538
539fn parse_single<T: std::str::FromStr>(
540 parts: &mut std::str::SplitWhitespace<'_>,
541 line_num: usize,
542 field: &str,
543) -> Result<T, SvmError> {
544 let val_str = parts.next().ok_or_else(|| {
545 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
546 })?;
547 val_str.parse().map_err(|_| {
548 SvmError::ModelFormatError(format!(
549 "line {}: invalid {} value: {}",
550 line_num, field, val_str
551 ))
552 })
553}
554
555fn parse_multiple<T: std::str::FromStr>(
556 parts: &mut std::str::SplitWhitespace<'_>,
557 line_num: usize,
558 field: &str,
559) -> Result<Vec<T>, SvmError> {
560 parts
561 .map(|s| {
562 s.parse::<T>().map_err(|_| {
563 SvmError::ModelFormatError(format!(
564 "line {}: invalid {} value: {}",
565 line_num, field, s
566 ))
567 })
568 })
569 .collect()
570}
571
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Fixture directory: `data/` two levels above the crate manifest.
    fn data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("..");
        dir.push("..");
        dir.push("data");
        dir
    }

    /// Parses `input` as a problem, expecting failure; returns the error text.
    fn problem_error(input: &[u8]) -> String {
        load_problem_from_reader(input).unwrap_err().to_string()
    }

    /// Parses `input` as a model, expecting failure; returns the error text.
    fn model_error(input: &[u8]) -> String {
        load_model_from_reader(input).unwrap_err().to_string()
    }

    #[test]
    fn parse_heart_scale() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        assert_eq!(problem.labels[0], 1.0);

        let first_row = &problem.instances[0];
        assert_eq!(
            first_row[0],
            SvmNode {
                index: 1,
                value: 0.708333
            }
        );
        assert_eq!(first_row.len(), 12);
    }

    #[test]
    fn parse_iris() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        assert_eq!(problem.labels.len(), 150);

        let mut classes = std::collections::HashSet::new();
        for &label in &problem.labels {
            classes.insert(label as i64);
        }
        assert_eq!(classes.len(), 3);
    }

    #[test]
    fn parse_housing() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        assert_eq!(problem.labels.len(), 506);
        let first_label = problem.labels[0];
        assert!((first_label - 24.0).abs() < 1e-10);
    }

    #[test]
    fn parse_empty_lines() {
        let problem = load_problem_from_reader(&b"+1 1:0.5\n\n-1 2:0.3\n"[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    #[test]
    fn parse_error_unsorted_indices() {
        let msg = problem_error(b"+1 3:0.5 1:0.3\n");
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    #[test]
    fn parse_error_duplicate_indices() {
        assert!(load_problem_from_reader(&b"+1 1:0.5 1:0.3\n"[..]).is_err());
    }

    #[test]
    fn parse_error_missing_colon() {
        assert!(load_problem_from_reader(&b"+1 1:0.5 bad_token\n"[..]).is_err());
    }

    #[test]
    // The literals are the full 17-digit values as they appear in the fixture.
    #[allow(clippy::excessive_precision)]
    fn load_c_trained_model() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();

        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    #[test]
    fn roundtrip_c_model() {
        let path = data_dir().join("heart_scale.model");
        let original_bytes = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let rust_output = String::from_utf8(buf).unwrap();

        // The re-saved model must match the reference file line for line.
        let orig_lines: Vec<&str> = original_bytes.lines().collect();
        let rust_lines: Vec<&str> = rust_output.lines().collect();
        assert_eq!(
            orig_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            orig_lines.len(),
            rust_lines.len()
        );
        for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
            assert_eq!(
                o,
                r,
                "line {} differs:\n C: {:?}\n Rust: {:?}",
                i + 1,
                o,
                r
            );
        }
    }

    #[test]
    // The inputs deliberately carry more digits than an f64 can distinguish.
    #[allow(clippy::excessive_precision)]
    fn gfmt_matches_c_printf() {
        // (value, expected `%.17g`, expected `%.8g`)
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];
        for (v, expected_17g, expected_8g) in cases.iter().copied() {
            let got_17 = fmt_17g(v).to_string();
            let got_8 = fmt_8g(v).to_string();
            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
        }
    }

    #[test]
    // The literals spell out all 17 significant digits on purpose.
    #[allow(clippy::excessive_precision)]
    fn model_roundtrip() {
        let first_sv = vec![
            SvmNode {
                index: 1,
                value: 0.5,
            },
            SvmNode {
                index: 3,
                value: -1.0,
            },
        ];
        let second_sv = vec![
            SvmNode {
                index: 1,
                value: -0.25,
            },
            SvmNode {
                index: 2,
                value: 0.75,
            },
        ];
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![first_sv, second_sv],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();

        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());
        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
            for (a, b) in row_a.iter().zip(row_b.iter()) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }

    #[test]
    fn parse_error_excessive_counts() {
        let too_many_classes = model_error(
            b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n",
        );
        assert!(too_many_classes.contains("nr_class exceeds limit"));

        let too_many_svs = model_error(
            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n",
        );
        assert!(too_many_svs.contains("total_sv exceeds limit"));
    }

    #[test]
    fn parse_error_excessive_feature_index() {
        let from_problem = problem_error(b"1 10000001:1\n");
        assert!(from_problem.contains("feature index 10000001 exceeds limit"));

        let from_model = model_error(
            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nSV\n0.1 10000001:1\n",
        );
        assert!(from_model.contains("feature index 10000001 exceeds limit"));
    }

    #[test]
    fn parse_error_unknown_model_keyword() {
        assert!(model_error(b"bad_key value\n").contains("unknown keyword"));
    }

    #[test]
    fn parse_error_missing_or_unknown_model_values() {
        assert!(model_error(b"svm_type\n").contains("missing svm_type value"));
        assert!(model_error(b"svm_type unknown_type\n").contains("unknown svm_type"));
    }

    #[test]
    fn parse_error_invalid_nr_sv_entry() {
        let err = model_error(
            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nnr_sv a 1\nSV\n0.1 1:0.5\n",
        );
        assert!(err.contains("invalid nr_sv value"));
    }

    #[test]
    fn parse_error_in_sv_section_tokens() {
        // The first token is a feature, so it fails to parse as a coefficient.
        let missing_coef = model_error(
            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nSV\n1:0.5\n",
        );
        assert!(missing_coef.contains("invalid sv_coef"));

        let bad_feature = model_error(
            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nSV\n0.1 bad\n",
        );
        assert!(bad_feature.contains("expected index:value"));
    }

    #[test]
    fn parse_error_unexpected_eof_in_header_and_sv_section() {
        let eof_header = model_error(b"svm_type c_svc\n");
        assert!(eof_header.contains("unexpected end of file in header"));

        // Declares two support vectors but supplies only one.
        let eof_sv = model_error(
            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 2\nrho 0\nSV\n0.1 1:0.5\n",
        );
        assert!(eof_sv.contains("unexpected end of file in SV section"));
    }

    #[test]
    fn save_precomputed_model_writes_zero_index() {
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Precomputed,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![vec![SvmNode {
                index: 0,
                value: 7.0,
            }]],
            sv_coef: vec![vec![0.25]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let out = String::from_utf8(buf).unwrap();
        assert!(out.contains("kernel_type precomputed"));
        assert!(out.contains("0:7"));
    }
}