1use std::io::{BufRead, Write};
20use std::path::Path;
21
22use crate::error::SvmError;
23use crate::types::*;
24use crate::util::parse_feature_index;
25use crate::util::MAX_FEATURE_INDEX;
26
27use std::fmt;
37
/// Display adapter that renders an `f64` the way C's `printf("%.<precision>g", v)` does.
///
/// `precision` is the number of significant digits. As in C, trailing zeros of
/// the fractional part are removed, scientific notation is used when the
/// decimal exponent is below `-4` or not smaller than the precision, and the
/// exponent is printed with a sign and at least two digits.
struct Gfmt {
    value: f64,
    precision: usize,
}

impl Gfmt {
    /// Wraps `value` for display with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Drops trailing zeros of the fractional part, and the `.` itself when
    /// nothing remains after it.
    ///
    /// Strings without a decimal point are returned untouched: their trailing
    /// zeros are significant integer digits (`"100"` must not become `"1"`).
    fn strip_fraction_zeros(digits: &str) -> &str {
        if digits.contains('.') {
            digits.trim_end_matches('0').trim_end_matches('.')
        } else {
            digits
        }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;

        // Non-finite values are delegated to Rust's own float formatting.
        if !v.is_finite() {
            return write!(f, "{}", v);
        }

        // Zero has no meaningful exponent; keep the sign like `%g` does.
        if v == 0.0 {
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // C treats a precision of 0 as 1.
        let p = self.precision.max(1);

        // Round to `p` significant digits first and read the exponent from the
        // rounded result. `%g` chooses its notation from this post-rounding
        // exponent, which can be one larger than floor(log10(|v|)).
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exponent) = match sci.split_once('e') {
            Some(pair) => pair,
            None => return f.write_str(&sci),
        };
        let exp: i32 = exponent.parse().unwrap_or(0);

        // `Some(n)` selects fixed notation with `n` fractional digits
        // (n = p - 1 - exp); `None` selects scientific notation.
        let fixed_decimals = if exp < -4 {
            None
        } else if exp < 0 {
            Some(p - 1 + exp.unsigned_abs() as usize)
        } else {
            // `exp` is non-negative here, so the cast is lossless. The
            // subtraction fails exactly when exp >= p.
            (p - 1).checked_sub(exp as usize)
        };

        match fixed_decimals {
            Some(decimals) => {
                let fixed = format!("{:.prec$}", v, prec = decimals);
                f.write_str(Self::strip_fraction_zeros(&fixed))
            }
            None => {
                let sign = if exp < 0 { '-' } else { '+' };
                write!(
                    f,
                    "{}e{}{:02}",
                    Self::strip_fraction_zeros(mantissa),
                    sign,
                    exp.unsigned_abs()
                )
            }
        }
    }
}
102
103fn fmt_17g(v: f64) -> Gfmt {
105 Gfmt::new(v, 17)
106}
107
108fn fmt_8g(v: f64) -> Gfmt {
110 Gfmt::new(v, 8)
111}
112
113pub fn format_g(v: f64) -> String {
115 format!("{}", Gfmt::new(v, 6))
116}
117
118pub fn format_17g(v: f64) -> String {
120 format!("{}", Gfmt::new(v, 17))
121}
122
/// Model-file names of the SVM types, looked up by `SvmType as usize` in
/// `svm_type_to_str`. The names are the same ones `str_to_svm_type` accepts.
// NOTE(review): the order must match the discriminants of `SvmType`, which is
// declared outside this file — confirm when adding or reordering variants.
const SVM_TYPE_TABLE: &[&str] = &["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"];
/// Model-file names of the kernel types, looked up by `KernelType as usize` in
/// `kernel_type_to_str`. The names are the same ones `str_to_kernel_type` accepts.
// NOTE(review): the order must match the discriminants of `KernelType`, which
// is declared outside this file — confirm when adding or reordering variants.
const KERNEL_TYPE_TABLE: &[&str] = &["linear", "polynomial", "rbf", "sigmoid", "precomputed"];
127
128fn svm_type_to_str(t: SvmType) -> &'static str {
129 SVM_TYPE_TABLE[t as usize]
130}
131
132fn kernel_type_to_str(t: KernelType) -> &'static str {
133 KERNEL_TYPE_TABLE[t as usize]
134}
135
136fn str_to_svm_type(s: &str) -> Option<SvmType> {
137 match s {
138 "c_svc" => Some(SvmType::CSvc),
139 "nu_svc" => Some(SvmType::NuSvc),
140 "one_class" => Some(SvmType::OneClass),
141 "epsilon_svr" => Some(SvmType::EpsilonSvr),
142 "nu_svr" => Some(SvmType::NuSvr),
143 _ => None,
144 }
145}
146
147fn str_to_kernel_type(s: &str) -> Option<KernelType> {
148 match s {
149 "linear" => Some(KernelType::Linear),
150 "polynomial" => Some(KernelType::Polynomial),
151 "rbf" => Some(KernelType::Rbf),
152 "sigmoid" => Some(KernelType::Sigmoid),
153 "precomputed" => Some(KernelType::Precomputed),
154 _ => None,
155 }
156}
157
/// Resource limits applied while reading problem and model text.
///
/// The `*_with_options` loaders in this file take a reference to this struct;
/// the plain loaders use `LoadOptions::default()`.
#[derive(Debug, Clone, Copy)]
pub struct LoadOptions {
    /// Maximum total number of bytes that may be consumed from the input;
    /// reading past it fails with an "exceeds max_bytes limit" error.
    pub max_bytes: u64,
    /// Limit on the length of a single input line; exceeding it fails with an
    /// "exceeds max_line_len limit" error.
    pub max_line_len: usize,
    /// Upper bound on the `total_sv` header value of a model. The model loader
    /// additionally clamps this to the built-in `MAX_TOTAL_SV`.
    pub max_sv: usize,
    /// Upper bound on the `nr_class` header value of a model. The model loader
    /// additionally clamps this to the built-in `MAX_NR_CLASS`.
    pub max_nr_class: usize,
    /// Largest feature index accepted; passed through to `parse_feature_index`.
    pub max_feature_index: i32,
}
211
212impl Default for LoadOptions {
213 fn default() -> Self {
214 Self {
215 max_bytes: 64 * 1024 * 1024,
216 max_line_len: 1024 * 1024,
217 max_sv: MAX_TOTAL_SV,
218 max_nr_class: MAX_NR_CLASS,
219 max_feature_index: MAX_FEATURE_INDEX,
220 }
221 }
222}
223
224impl LoadOptions {
225 pub fn trusted_input() -> Self {
232 Self {
233 max_bytes: u64::MAX,
234 max_line_len: usize::MAX,
235 max_sv: usize::MAX,
236 max_nr_class: usize::MAX,
237 max_feature_index: i32::MAX,
238 }
239 }
240}
241
/// Reads one line (including its trailing `\n`, if any) while enforcing limits.
///
/// * `bytes_read` — running total of bytes consumed so far; advanced by the
///   bytes this call consumes.
/// * `max_bytes` — the running total may not exceed this value.
/// * `max_line_len` — the line content (newline excluded) may be at most
///   `max_line_len + 1` bytes.
///
/// Returns `Ok(None)` at end of input when no bytes were read, otherwise
/// `Ok(Some(line))`.
///
/// # Errors
///
/// I/O errors from the reader are propagated. `InvalidData` is returned when a
/// limit is exceeded, when the line contains a NUL byte, or when the line is
/// not valid UTF-8. Limits and the NUL check are applied to each buffered
/// chunk before it is consumed.
fn read_line_capped(
    reader: &mut dyn BufRead,
    bytes_read: &mut u64,
    max_bytes: u64,
    max_line_len: usize,
) -> std::io::Result<Option<String>> {
    let invalid_data =
        |message: String| std::io::Error::new(std::io::ErrorKind::InvalidData, message);

    let line_cap: u64 = (max_line_len as u64).saturating_add(1);
    let mut line_bytes: Vec<u8> = Vec::new();
    let mut terminated = false;

    while !terminated {
        let chunk = reader.fill_buf()?;
        if chunk.is_empty() {
            // End of input.
            break;
        }

        // Take everything up to and including the first newline, or the whole
        // chunk when it holds no newline.
        let (piece, has_newline) = match chunk.iter().position(|&byte| byte == b'\n') {
            Some(pos) => (&chunk[..=pos], true),
            None => (chunk, false),
        };
        let piece_len = piece.len();

        let total_after = bytes_read.saturating_add(piece_len as u64);
        if total_after > max_bytes {
            return Err(invalid_data(format!(
                "input exceeds max_bytes limit ({})",
                max_bytes
            )));
        }

        // The newline itself does not count towards the line length.
        let raw_len_after = (line_bytes.len() as u64).saturating_add(piece_len as u64);
        let content_len = raw_len_after - u64::from(has_newline);
        if content_len > line_cap {
            return Err(invalid_data(format!(
                "line length exceeds max_line_len limit ({})",
                max_line_len
            )));
        }

        if piece.contains(&0) {
            return Err(invalid_data(
                "unexpected NUL byte in text input".to_string(),
            ));
        }

        line_bytes.extend_from_slice(piece);
        reader.consume(piece_len);
        *bytes_read = total_after;
        terminated = has_newline;
    }

    // A terminated line always holds at least the newline byte, so an empty
    // buffer means end of input was hit before anything was read.
    if line_bytes.is_empty() {
        return Ok(None);
    }

    match String::from_utf8(line_bytes) {
        Ok(line) => Ok(Some(line)),
        Err(e) => Err(invalid_data(e.to_string())),
    }
}
330
331pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
361 let file = std::fs::File::open(path)?;
362 let reader = std::io::BufReader::new(file);
363 load_problem_from_reader(reader)
364}
365
366pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
371 load_problem_from_reader_with_options(reader, &LoadOptions::default())
372}
373
374pub fn load_problem_from_reader_with_options(
380 mut reader: impl BufRead,
381 options: &LoadOptions,
382) -> Result<SvmProblem, SvmError> {
383 let mut labels = Vec::new();
384 let mut instances = Vec::new();
385 let mut bytes_read: u64 = 0;
386 let mut line_idx: usize = 0;
387
388 while let Some(raw) = read_line_capped(
389 &mut reader,
390 &mut bytes_read,
391 options.max_bytes,
392 options.max_line_len,
393 )? {
394 let line_num = line_idx + 1;
395 line_idx += 1;
396 let line = raw.trim();
397 if line.is_empty() {
398 continue;
399 }
400
401 let mut parts = line.split_whitespace();
402
403 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
405 line: line_num,
406 message: "missing label".into(),
407 })?;
408 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
409 line: line_num,
410 message: format!("invalid label: {}", label_str),
411 })?;
412
413 let mut nodes = Vec::new();
415 let mut prev_index: i32 = 0;
416 for token in parts {
417 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
418 line: line_num,
419 message: format!("expected index:value, got: {}", token),
420 })?;
421 let index: i32 =
422 parse_feature_index_problem_line(line_num, idx_str, options.max_feature_index)?;
423
424 if !nodes.is_empty() && index <= prev_index {
425 return Err(SvmError::ParseError {
426 line: line_num,
427 message: format!(
428 "feature indices must be ascending: {} follows {}",
429 index, prev_index
430 ),
431 });
432 }
433 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
434 line: line_num,
435 message: format!("invalid value: {}", val_str),
436 })?;
437 prev_index = index;
438 nodes.push(SvmNode { index, value });
439 }
440
441 labels.push(label);
442 instances.push(nodes);
443 }
444
445 Ok(SvmProblem { labels, instances })
446}
447
/// Built-in ceiling for the `nr_class` model header value. It is the default
/// for `LoadOptions::max_nr_class`, and the model loader never accepts more
/// than this regardless of the options passed.
const MAX_NR_CLASS: usize = 65535;
/// Built-in ceiling for the `total_sv` model header value. It is the default
/// for `LoadOptions::max_sv`, and the model loader never accepts more than
/// this regardless of the options passed.
const MAX_TOTAL_SV: usize = 10_000_000;
452
453pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
455 let file = std::fs::File::create(path)?;
456 let writer = std::io::BufWriter::new(file);
457 save_model_to_writer(writer, model)
458}
459
460pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
462 let param = &model.param;
463
464 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
465 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
466
467 if param.kernel_type == KernelType::Polynomial {
468 writeln!(w, "degree {}", param.degree)?;
469 }
470 if matches!(
471 param.kernel_type,
472 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
473 ) {
474 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
475 }
476 if matches!(
477 param.kernel_type,
478 KernelType::Polynomial | KernelType::Sigmoid
479 ) {
480 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
481 }
482
483 let nr_class = model.nr_class;
484 writeln!(w, "nr_class {}", nr_class)?;
485 writeln!(w, "total_sv {}", model.sv.len())?;
486
487 write!(w, "rho")?;
489 for r in &model.rho {
490 write!(w, " {}", fmt_17g(*r))?;
491 }
492 writeln!(w)?;
493
494 if !model.label.is_empty() {
496 write!(w, "label")?;
497 for l in &model.label {
498 write!(w, " {}", l)?;
499 }
500 writeln!(w)?;
501 }
502
503 if !model.prob_a.is_empty() {
505 write!(w, "probA")?;
506 for v in &model.prob_a {
507 write!(w, " {}", fmt_17g(*v))?;
508 }
509 writeln!(w)?;
510 }
511
512 if !model.prob_b.is_empty() {
514 write!(w, "probB")?;
515 for v in &model.prob_b {
516 write!(w, " {}", fmt_17g(*v))?;
517 }
518 writeln!(w)?;
519 }
520
521 if !model.prob_density_marks.is_empty() {
523 write!(w, "prob_density_marks")?;
524 for v in &model.prob_density_marks {
525 write!(w, " {}", fmt_17g(*v))?;
526 }
527 writeln!(w)?;
528 }
529
530 if !model.n_sv.is_empty() {
532 write!(w, "nr_sv")?;
533 for n in &model.n_sv {
534 write!(w, " {}", n)?;
535 }
536 writeln!(w)?;
537 }
538
539 writeln!(w, "SV")?;
541 let num_sv = model.sv.len();
542 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
545 for j in 0..num_coef_rows {
547 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
548 }
549 if model.param.kernel_type == KernelType::Precomputed {
551 if let Some(node) = model.sv[i].first() {
552 write!(w, "0:{} ", node.value as i32)?;
553 }
554 } else {
555 for node in &model.sv[i] {
556 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
557 }
558 }
559 writeln!(w)?;
560 }
561
562 Ok(())
563}
564
565pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
598 let file = std::fs::File::open(path)?;
599 let reader = std::io::BufReader::new(file);
600 load_model_from_reader(reader)
601}
602
603pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
608 load_model_from_reader_with_options(reader, &LoadOptions::default())
609}
610
611pub fn load_model_from_reader_with_options(
617 mut reader: impl BufRead,
618 options: &LoadOptions,
619) -> Result<SvmModel, SvmError> {
620 let mut bytes_read: u64 = 0;
621
622 let nr_class_cap = options.max_nr_class.min(MAX_NR_CLASS);
628 let total_sv_cap = options.max_sv.min(MAX_TOTAL_SV);
629
630 let mut param = SvmParameter::default();
632 let mut nr_class: usize = 0;
633 let mut total_sv: usize = 0;
634 let mut rho = Vec::new();
635 let mut label = Vec::new();
636 let mut prob_a = Vec::new();
637 let mut prob_b = Vec::new();
638 let mut prob_density_marks = Vec::new();
639 let mut n_sv = Vec::new();
640
641 let mut line_num: usize = 0;
643 loop {
644 let raw = read_line_capped(
645 &mut reader,
646 &mut bytes_read,
647 options.max_bytes,
648 options.max_line_len,
649 )
650 .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
651 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))?;
652 line_num += 1;
653 let line = raw.trim().to_string();
654 if line.is_empty() {
655 continue;
656 }
657
658 let mut parts = line.split_whitespace();
659 let cmd = parts.next().ok_or_else(|| {
660 SvmError::ModelFormatError(format!("line {}: empty model header line", line_num))
661 })?;
662
663 match cmd {
664 "svm_type" => {
665 let val = parts.next().ok_or_else(|| {
666 SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
667 })?;
668 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
669 SvmError::ModelFormatError(format!(
670 "line {}: unknown svm_type: {}",
671 line_num, val
672 ))
673 })?;
674 }
675 "kernel_type" => {
676 let val = parts.next().ok_or_else(|| {
677 SvmError::ModelFormatError(format!(
678 "line {}: missing kernel_type value",
679 line_num
680 ))
681 })?;
682 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
683 SvmError::ModelFormatError(format!(
684 "line {}: unknown kernel_type: {}",
685 line_num, val
686 ))
687 })?;
688 }
689 "degree" => {
690 param.degree = parse_single(&mut parts, line_num, "degree")?;
691 }
692 "gamma" => {
693 param.gamma = parse_single(&mut parts, line_num, "gamma")?;
694 }
695 "coef0" => {
696 param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
697 }
698 "nr_class" => {
699 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
700 if nr_class > nr_class_cap {
701 return Err(SvmError::ModelFormatError(format!(
702 "line {}: nr_class exceeds limit ({})",
703 line_num, nr_class_cap
704 )));
705 }
706 }
707 "total_sv" => {
708 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
709 if total_sv > total_sv_cap {
710 return Err(SvmError::ModelFormatError(format!(
711 "line {}: total_sv exceeds limit ({})",
712 line_num, total_sv_cap
713 )));
714 }
715 }
716 "rho" => {
717 rho = parse_multiple(&mut parts, line_num, "rho")?;
718 }
719 "label" => {
720 label = parse_multiple(&mut parts, line_num, "label")?;
721 }
722 "probA" => {
723 prob_a = parse_multiple(&mut parts, line_num, "probA")?;
724 }
725 "probB" => {
726 prob_b = parse_multiple(&mut parts, line_num, "probB")?;
727 }
728 "prob_density_marks" => {
729 prob_density_marks = parse_multiple(&mut parts, line_num, "prob_density_marks")?;
730 }
731 "nr_sv" => {
732 n_sv = parts
733 .map(|s| {
734 s.parse::<usize>().map_err(|_| {
735 SvmError::ModelFormatError(format!(
736 "line {}: invalid nr_sv value: {}",
737 line_num, s
738 ))
739 })
740 })
741 .collect::<Result<Vec<_>, _>>()?;
742 }
743 "SV" => break,
744 _ => {
745 return Err(SvmError::ModelFormatError(format!(
746 "line {}: unknown keyword: {}",
747 line_num, cmd
748 )));
749 }
750 }
751 }
752
753 validate_model_header(
759 param.svm_type,
760 nr_class,
761 total_sv,
762 &rho,
763 &label,
764 &prob_a,
765 &prob_b,
766 &prob_density_marks,
767 &n_sv,
768 )?;
769
770 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
778 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::new()).collect();
779 let mut sv: Vec<Vec<SvmNode>> = Vec::new();
780
781 while sv.len() < total_sv {
782 let raw = read_line_capped(
783 &mut reader,
784 &mut bytes_read,
785 options.max_bytes,
786 options.max_line_len,
787 )
788 .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
789 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))?;
790 line_num += 1;
791 let line = raw.trim();
792 if line.is_empty() {
793 continue;
796 }
797
798 let mut parts = line.split_whitespace();
799
800 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
802 let val_str = parts.next().ok_or_else(|| {
803 SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
804 })?;
805 let val: f64 = val_str.parse().map_err(|_| {
806 SvmError::ModelFormatError(format!(
807 "line {}: invalid sv_coef: {}",
808 line_num, val_str
809 ))
810 })?;
811 coef_row.push(val);
812 }
813
814 let mut nodes = Vec::new();
817 let mut prev_index: i32 = 0;
818 for token in parts {
819 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
820 SvmError::ModelFormatError(format!(
821 "line {}: expected index:value, got: {}",
822 line_num, token
823 ))
824 })?;
825 let index: i32 =
826 parse_feature_index_model_line(line_num, idx_str, options.max_feature_index)?;
827
828 if !nodes.is_empty() && index <= prev_index {
829 return Err(SvmError::ModelFormatError(format!(
830 "line {}: feature indices must be ascending: {} follows {}",
831 line_num, index, prev_index
832 )));
833 }
834
835 let value: f64 = val_str.parse().map_err(|_| {
836 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
837 })?;
838 prev_index = index;
839 nodes.push(SvmNode { index, value });
840 }
841
842 if param.kernel_type == KernelType::Precomputed {
843 validate_precomputed_row(&nodes, line_num, "support vector")?;
844 }
845 sv.push(nodes);
846 }
847
848 Ok(SvmModel {
849 param,
850 nr_class,
851 sv,
852 sv_coef,
853 rho,
854 prob_a,
855 prob_b,
856 prob_density_marks,
857 sv_indices: Vec::new(), label,
859 n_sv,
860 })
861}
862
863fn validate_precomputed_row(
864 nodes: &[SvmNode],
865 line_num: usize,
866 context: &str,
867) -> Result<(), SvmError> {
868 let first = nodes.first().ok_or_else(|| {
869 SvmError::ModelFormatError(format!(
870 "line {}: precomputed kernel {} is missing 0:sample_serial_number",
871 line_num, context
872 ))
873 })?;
874
875 if first.index != 0
876 || !first.value.is_finite()
877 || first.value < 1.0
878 || first.value.fract() != 0.0
879 {
880 return Err(SvmError::ModelFormatError(format!(
881 "line {}: precomputed kernel {} must start with 0:sample_serial_number",
882 line_num, context
883 )));
884 }
885
886 Ok(())
887}
888
889#[allow(clippy::too_many_arguments)]
911fn validate_model_header(
912 svm_type: SvmType,
913 nr_class: usize,
914 total_sv: usize,
915 rho: &[f64],
916 label: &[i32],
917 prob_a: &[f64],
918 prob_b: &[f64],
919 prob_density_marks: &[f64],
920 n_sv: &[usize],
921) -> Result<(), SvmError> {
922 let is_classification = matches!(svm_type, SvmType::CSvc | SvmType::NuSvc);
923 let is_regression = matches!(svm_type, SvmType::EpsilonSvr | SvmType::NuSvr);
924 let is_one_class = matches!(svm_type, SvmType::OneClass);
925
926 if nr_class < 2 {
931 return Err(SvmError::ModelFormatError(format!(
932 "nr_class must be >= 2, got {}",
933 nr_class
934 )));
935 }
936
937 let expected_rho = if is_classification {
939 nr_class * (nr_class - 1) / 2
940 } else {
941 1
942 };
943 if rho.len() != expected_rho {
944 return Err(SvmError::ModelFormatError(format!(
945 "rho has {} entries, expected {} for svm_type {}",
946 rho.len(),
947 expected_rho,
948 svm_type_to_str(svm_type)
949 )));
950 }
951
952 if !label.is_empty() {
954 if !is_classification {
955 return Err(SvmError::ModelFormatError(format!(
956 "label is only valid for classification, got {} entries on svm_type {}",
957 label.len(),
958 svm_type_to_str(svm_type)
959 )));
960 }
961 if label.len() != nr_class {
962 return Err(SvmError::ModelFormatError(format!(
963 "label has {} entries, expected nr_class ({})",
964 label.len(),
965 nr_class
966 )));
967 }
968 }
969
970 if !n_sv.is_empty() {
972 if !is_classification {
973 return Err(SvmError::ModelFormatError(format!(
974 "nr_sv is only valid for classification, got {} entries on svm_type {}",
975 n_sv.len(),
976 svm_type_to_str(svm_type)
977 )));
978 }
979 if n_sv.len() != nr_class {
980 return Err(SvmError::ModelFormatError(format!(
981 "nr_sv has {} entries, expected nr_class ({})",
982 n_sv.len(),
983 nr_class
984 )));
985 }
986 let mut sum: usize = 0;
990 for &n in n_sv {
991 sum = sum.checked_add(n).ok_or_else(|| {
992 SvmError::ModelFormatError("nr_sv entries overflow usize when summed".into())
993 })?;
994 }
995 if sum != total_sv {
996 return Err(SvmError::ModelFormatError(format!(
997 "sum of nr_sv entries ({}) does not match total_sv ({})",
998 sum, total_sv
999 )));
1000 }
1001 }
1002
1003 if !prob_a.is_empty() && prob_a.len() != expected_rho {
1005 return Err(SvmError::ModelFormatError(format!(
1006 "probA has {} entries, expected {}",
1007 prob_a.len(),
1008 expected_rho
1009 )));
1010 }
1011 if !prob_b.is_empty() && prob_b.len() != expected_rho {
1012 return Err(SvmError::ModelFormatError(format!(
1013 "probB has {} entries, expected {}",
1014 prob_b.len(),
1015 expected_rho
1016 )));
1017 }
1018
1019 if !prob_density_marks.is_empty() && !is_one_class {
1021 return Err(SvmError::ModelFormatError(format!(
1022 "prob_density_marks is only valid for one-class SVM, got {} entries on svm_type {}",
1023 prob_density_marks.len(),
1024 svm_type_to_str(svm_type)
1025 )));
1026 }
1027
1028 let _ = is_regression;
1032
1033 Ok(())
1034}
1035
1036fn parse_feature_index_problem_line(
1039 line_num: usize,
1040 idx_str: &str,
1041 max_feature_index: i32,
1042) -> Result<i32, SvmError> {
1043 parse_feature_index(idx_str, max_feature_index).map_err(|msg| SvmError::ParseError {
1044 line: line_num,
1045 message: msg,
1046 })
1047}
1048
1049fn parse_feature_index_model_line(
1050 line_num: usize,
1051 idx_str: &str,
1052 max_feature_index: i32,
1053) -> Result<i32, SvmError> {
1054 parse_feature_index(idx_str, max_feature_index)
1055 .map_err(|msg| SvmError::ModelFormatError(format!("line {}: {}", line_num, msg)))
1056}
1057
1058fn parse_single<T: std::str::FromStr>(
1059 parts: &mut std::str::SplitWhitespace<'_>,
1060 line_num: usize,
1061 field: &str,
1062) -> Result<T, SvmError> {
1063 let val_str = parts.next().ok_or_else(|| {
1064 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
1065 })?;
1066 val_str.parse().map_err(|_| {
1067 SvmError::ModelFormatError(format!(
1068 "line {}: invalid {} value: {}",
1069 line_num, field, val_str
1070 ))
1071 })
1072}
1073
1074fn parse_multiple<T: std::str::FromStr>(
1075 parts: &mut std::str::SplitWhitespace<'_>,
1076 line_num: usize,
1077 field: &str,
1078) -> Result<Vec<T>, SvmError> {
1079 parts
1080 .map(|s| {
1081 s.parse::<T>().map_err(|_| {
1082 SvmError::ModelFormatError(format!(
1083 "line {}: invalid {} value: {}",
1084 line_num, field, s
1085 ))
1086 })
1087 })
1088 .collect()
1089}
1090
1091#[cfg(test)]
1094mod tests {
1095 use super::*;
1096 use std::path::PathBuf;
1097
1098 fn data_dir() -> PathBuf {
1099 PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1100 .join("..")
1101 .join("..")
1102 .join("data")
1103 }
1104
1105 #[test]
1106 fn parse_heart_scale() {
1107 let path = data_dir().join("heart_scale");
1108 let problem = load_problem(&path).unwrap();
1109 assert_eq!(problem.labels.len(), 270);
1110 assert_eq!(problem.instances.len(), 270);
1111 assert_eq!(problem.labels[0], 1.0);
1113 assert_eq!(
1114 problem.instances[0][0],
1115 SvmNode {
1116 index: 1,
1117 value: 0.708333
1118 }
1119 );
1120 assert_eq!(problem.instances[0].len(), 12);
1121 }
1122
1123 #[test]
1124 fn parse_iris() {
1125 let path = data_dir().join("iris.scale");
1126 let problem = load_problem(&path).unwrap();
1127 assert_eq!(problem.labels.len(), 150);
1128 let classes: std::collections::HashSet<i64> =
1130 problem.labels.iter().map(|&l| l as i64).collect();
1131 assert_eq!(classes.len(), 3);
1132 }
1133
1134 #[test]
1135 fn parse_housing() {
1136 let path = data_dir().join("housing_scale");
1137 let problem = load_problem(&path).unwrap();
1138 assert_eq!(problem.labels.len(), 506);
1139 assert!((problem.labels[0] - 24.0).abs() < 1e-10);
1141 }
1142
1143 #[test]
1144 fn parse_empty_lines() {
1145 let input = b"+1 1:0.5\n\n-1 2:0.3\n";
1146 let problem = load_problem_from_reader(&input[..]).unwrap();
1147 assert_eq!(problem.labels.len(), 2);
1148 }
1149
1150 #[test]
1151 fn parse_error_unsorted_indices() {
1152 let input = b"+1 3:0.5 1:0.3\n";
1153 let result = load_problem_from_reader(&input[..]);
1154 assert!(result.is_err());
1155 let msg = format!("{}", result.unwrap_err());
1156 assert!(msg.contains("ascending"), "error: {}", msg);
1157 }
1158
1159 #[test]
1160 fn parse_error_duplicate_indices() {
1161 let input = b"+1 1:0.5 1:0.3\n";
1162 let result = load_problem_from_reader(&input[..]);
1163 assert!(result.is_err());
1164 }
1165
1166 #[test]
1167 fn parse_error_missing_colon() {
1168 let input = b"+1 1:0.5 bad_token\n";
1169 let result = load_problem_from_reader(&input[..]);
1170 assert!(result.is_err());
1171 }
1172
1173 #[test]
1174 #[allow(clippy::excessive_precision)]
1175 fn load_c_trained_model() {
1176 let path = data_dir().join("heart_scale.model");
1178 let model = load_model(&path).unwrap();
1179 assert_eq!(model.nr_class, 2);
1180 assert_eq!(model.param.svm_type, SvmType::CSvc);
1181 assert_eq!(model.param.kernel_type, KernelType::Rbf);
1182 assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
1183 assert_eq!(model.sv.len(), 132);
1184 assert_eq!(model.label, vec![1, -1]);
1185 assert_eq!(model.n_sv, vec![64, 68]);
1186 assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
1187 assert_eq!(model.sv_coef.len(), 1);
1189 assert_eq!(model.sv_coef[0].len(), 132);
1190 }
1191
1192 #[test]
1193 fn roundtrip_c_model() {
1194 let path = data_dir().join("heart_scale.model");
1196 let original_bytes = std::fs::read_to_string(&path).unwrap();
1197 let model = load_model(&path).unwrap();
1198
1199 let mut buf = Vec::new();
1200 save_model_to_writer(&mut buf, &model).unwrap();
1201 let rust_output = String::from_utf8(buf).unwrap();
1202
1203 let orig_lines: Vec<&str> = original_bytes.lines().collect();
1205 let rust_lines: Vec<&str> = rust_output.lines().collect();
1206 assert_eq!(
1207 orig_lines.len(),
1208 rust_lines.len(),
1209 "line count mismatch: C={} Rust={}",
1210 orig_lines.len(),
1211 rust_lines.len()
1212 );
1213 for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
1214 assert_eq!(
1215 o,
1216 r,
1217 "line {} differs:\n C: {:?}\n Rust: {:?}",
1218 i + 1,
1219 o,
1220 r
1221 );
1222 }
1223 }
1224
1225 #[test]
1226 #[allow(clippy::excessive_precision)]
1227 fn gfmt_matches_c_printf() {
1228 let cases: &[(f64, &str, &str)] = &[
1230 (0.5, "0.5", "0.5"),
1231 (-1.0, "-1", "-1"),
1232 (0.123456789012345, "0.123456789012345", "0.12345679"),
1233 (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
1234 (0.42446200000000001, "0.42446200000000001", "0.424462"),
1235 (0.0, "0", "0"),
1236 (1e-5, "1.0000000000000001e-05", "1e-05"),
1237 (1e-4, "0.0001", "0.0001"),
1238 (1e20, "1e+20", "1e+20"),
1239 (-0.25, "-0.25", "-0.25"),
1240 (0.75, "0.75", "0.75"),
1241 (0.708333, "0.70833299999999999", "0.708333"),
1242 (1.0, "1", "1"),
1243 ];
1244 for &(v, expected_17g, expected_8g) in cases {
1245 let got_17 = format!("{}", fmt_17g(v));
1246 let got_8 = format!("{}", fmt_8g(v));
1247 assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
1248 assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
1249 }
1250 }
1251
1252 #[test]
1253 #[allow(clippy::excessive_precision)]
1254 fn model_roundtrip() {
1255 let model = SvmModel {
1257 param: SvmParameter {
1258 svm_type: SvmType::CSvc,
1259 kernel_type: KernelType::Rbf,
1260 gamma: 0.5,
1261 ..Default::default()
1262 },
1263 nr_class: 2,
1264 sv: vec![
1265 vec![
1266 SvmNode {
1267 index: 1,
1268 value: 0.5,
1269 },
1270 SvmNode {
1271 index: 3,
1272 value: -1.0,
1273 },
1274 ],
1275 vec![
1276 SvmNode {
1277 index: 1,
1278 value: -0.25,
1279 },
1280 SvmNode {
1281 index: 2,
1282 value: 0.75,
1283 },
1284 ],
1285 ],
1286 sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
1287 rho: vec![0.42446200000000001],
1288 prob_a: vec![],
1289 prob_b: vec![],
1290 prob_density_marks: vec![],
1291 sv_indices: vec![],
1292 label: vec![1, -1],
1293 n_sv: vec![1, 1],
1294 };
1295
1296 let mut buf = Vec::new();
1297 save_model_to_writer(&mut buf, &model).unwrap();
1298
1299 let loaded = load_model_from_reader(&buf[..]).unwrap();
1300
1301 assert_eq!(loaded.nr_class, model.nr_class);
1302 assert_eq!(loaded.param.svm_type, model.param.svm_type);
1303 assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
1304 assert_eq!(loaded.sv.len(), model.sv.len());
1305 assert_eq!(loaded.label, model.label);
1306 assert_eq!(loaded.n_sv, model.n_sv);
1307 assert_eq!(loaded.rho.len(), model.rho.len());
1308 for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
1310 assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
1311 }
1312 for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
1314 for (a, b) in row_a.iter().zip(row_b.iter()) {
1315 assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
1316 }
1317 }
1318 }
1319
1320 #[test]
1321 fn parse_error_excessive_counts() {
1322 let input =
1323 b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
1324 let result = load_model_from_reader(&input[..]);
1325 assert!(result.is_err());
1326 assert!(format!("{}", result.unwrap_err()).contains("nr_class exceeds limit"));
1327
1328 let input =
1329 b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
1330 let result = load_model_from_reader(&input[..]);
1331 assert!(result.is_err());
1332 assert!(format!("{}", result.unwrap_err()).contains("total_sv exceeds limit"));
1333 }
1334
1335 #[test]
1336 fn parse_error_excessive_feature_index() {
1337 let input = b"1 10000001:1\n";
1339 let result = load_problem_from_reader(&input[..]);
1340 assert!(result.is_err());
1341 assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1342
1343 let input = b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nSV\n0.1 10000001:1\n";
1345 let result = load_model_from_reader(&input[..]);
1346 assert!(result.is_err());
1347 assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1348 }
1349
1350 #[test]
1351 fn parse_error_unknown_model_keyword() {
1352 let input = b"bad_key value\n";
1353 let result = load_model_from_reader(&input[..]);
1354 assert!(result.is_err());
1355 assert!(format!("{}", result.unwrap_err()).contains("unknown keyword"));
1356 }
1357
1358 #[test]
1359 fn parse_error_missing_or_unknown_model_values() {
1360 let missing = b"svm_type\n";
1361 let err = load_model_from_reader(&missing[..]).unwrap_err();
1362 assert!(format!("{}", err).contains("missing svm_type value"));
1363
1364 let unknown = b"svm_type unknown_type\n";
1365 let err = load_model_from_reader(&unknown[..]).unwrap_err();
1366 assert!(format!("{}", err).contains("unknown svm_type"));
1367 }
1368
1369 #[test]
1370 fn parse_error_invalid_nr_sv_entry() {
1371 let input = b"svm_type c_svc\n\
1372kernel_type linear\n\
1373nr_class 2\n\
1374total_sv 1\n\
1375rho 0\n\
1376nr_sv a 1\n\
1377SV\n\
13780.1 1:0.5\n";
1379 let err = load_model_from_reader(&input[..]).unwrap_err();
1380 assert!(format!("{}", err).contains("invalid nr_sv value"));
1381 }
1382
1383 #[test]
1384 fn parse_error_in_sv_section_tokens() {
1385 let missing_coef = b"svm_type c_svc\n\
1386kernel_type linear\n\
1387nr_class 2\n\
1388total_sv 1\n\
1389rho 0\n\
1390SV\n\
13911:0.5\n";
1392 let err = load_model_from_reader(&missing_coef[..]).unwrap_err();
1393 assert!(format!("{}", err).contains("invalid sv_coef"));
1394
1395 let bad_feature = b"svm_type c_svc\n\
1396kernel_type linear\n\
1397nr_class 2\n\
1398total_sv 1\n\
1399rho 0\n\
1400SV\n\
14010.1 bad\n";
1402 let err = load_model_from_reader(&bad_feature[..]).unwrap_err();
1403 assert!(format!("{}", err).contains("expected index:value"));
1404 }
1405
1406 #[test]
1407 fn parse_error_unexpected_eof_in_header_and_sv_section() {
1408 let eof_header = b"svm_type c_svc\n";
1409 let err = load_model_from_reader(&eof_header[..]).unwrap_err();
1410 assert!(format!("{}", err).contains("unexpected end of file in header"));
1411
1412 let eof_sv = b"svm_type c_svc\n\
1413kernel_type linear\n\
1414nr_class 2\n\
1415total_sv 2\n\
1416rho 0\n\
1417SV\n\
14180.1 1:0.5\n";
1419 let err = load_model_from_reader(&eof_sv[..]).unwrap_err();
1420 assert!(format!("{}", err).contains("unexpected end of file in SV section"));
1421 }
1422
1423 #[test]
1424 fn reject_rho_length_mismatch_for_classification() {
1425 let input = b"svm_type c_svc\n\
1428kernel_type linear\n\
1429nr_class 3\n\
1430total_sv 3\n\
1431rho 0\n\
1432SV\n";
1433 let err = load_model_from_reader(&input[..]).unwrap_err();
1434 assert!(
1435 format!("{}", err).contains("rho has 1 entries, expected 3"),
1436 "unexpected error: {}",
1437 err
1438 );
1439 }
1440
1441 #[test]
1442 fn reject_rho_length_mismatch_for_regression() {
1443 let input = b"svm_type epsilon_svr\n\
1445kernel_type linear\n\
1446nr_class 2\n\
1447total_sv 0\n\
1448rho 0 1\n\
1449SV\n";
1450 let err = load_model_from_reader(&input[..]).unwrap_err();
1451 assert!(
1452 format!("{}", err).contains("rho has 2 entries, expected 1"),
1453 "unexpected error: {}",
1454 err
1455 );
1456 }
1457
1458 #[test]
1459 fn reject_label_on_regression() {
1460 let input = b"svm_type epsilon_svr\n\
1462kernel_type linear\n\
1463nr_class 2\n\
1464total_sv 0\n\
1465rho 0\n\
1466label 1 -1\n\
1467SV\n";
1468 let err = load_model_from_reader(&input[..]).unwrap_err();
1469 assert!(
1470 format!("{}", err).contains("label is only valid for classification"),
1471 "unexpected error: {}",
1472 err
1473 );
1474 }
1475
1476 #[test]
1477 fn reject_label_length_mismatch() {
1478 let input = b"svm_type c_svc\n\
1479kernel_type linear\n\
1480nr_class 3\n\
1481total_sv 0\n\
1482rho 0 0 0\n\
1483label 1 -1\n\
1484SV\n";
1485 let err = load_model_from_reader(&input[..]).unwrap_err();
1486 assert!(
1487 format!("{}", err).contains("label has 2 entries, expected nr_class (3)"),
1488 "unexpected error: {}",
1489 err
1490 );
1491 }
1492
1493 #[test]
1494 fn reject_nr_sv_sum_mismatch() {
1495 let input = b"svm_type c_svc\n\
1498kernel_type linear\n\
1499nr_class 2\n\
1500total_sv 5\n\
1501rho 0\n\
1502label 1 -1\n\
1503nr_sv 1 2\n\
1504SV\n";
1505 let err = load_model_from_reader(&input[..]).unwrap_err();
1506 assert!(
1507 format!("{}", err).contains("sum of nr_sv entries (3) does not match total_sv (5)"),
1508 "unexpected error: {}",
1509 err
1510 );
1511 }
1512
1513 #[test]
1514 fn reject_nr_sv_length_mismatch() {
1515 let input = b"svm_type c_svc\n\
1516kernel_type linear\n\
1517nr_class 3\n\
1518total_sv 3\n\
1519rho 0 0 0\n\
1520label 1 -1 0\n\
1521nr_sv 1 2\n\
1522SV\n";
1523 let err = load_model_from_reader(&input[..]).unwrap_err();
1524 assert!(
1525 format!("{}", err).contains("nr_sv has 2 entries, expected nr_class (3)"),
1526 "unexpected error: {}",
1527 err
1528 );
1529 }
1530
1531 #[test]
1532 fn reject_proba_length_mismatch() {
1533 let input = b"svm_type c_svc\n\
1534kernel_type linear\n\
1535nr_class 3\n\
1536total_sv 0\n\
1537rho 0 0 0\n\
1538probA 0.1 0.2\n\
1539SV\n";
1540 let err = load_model_from_reader(&input[..]).unwrap_err();
1541 assert!(
1542 format!("{}", err).contains("probA has 2 entries, expected 3"),
1543 "unexpected error: {}",
1544 err
1545 );
1546 }
1547
1548 #[test]
1549 fn reject_prob_density_marks_on_csvc() {
1550 let input = b"svm_type c_svc\n\
1551kernel_type linear\n\
1552nr_class 2\n\
1553total_sv 0\n\
1554rho 0\n\
1555prob_density_marks 0.1 0.2\n\
1556SV\n";
1557 let err = load_model_from_reader(&input[..]).unwrap_err();
1558 assert!(
1559 format!("{}", err).contains("prob_density_marks is only valid for one-class SVM"),
1560 "unexpected error: {}",
1561 err
1562 );
1563 }
1564
1565 #[test]
1566 fn reject_nr_class_below_two() {
1567 let input = b"svm_type c_svc\n\
1568kernel_type linear\n\
1569nr_class 1\n\
1570total_sv 0\n\
1571rho\n\
1572SV\n";
1573 let err = load_model_from_reader(&input[..]).unwrap_err();
1574 assert!(
1575 format!("{}", err).contains("nr_class must be >= 2, got 1"),
1576 "unexpected error: {}",
1577 err
1578 );
1579 }
1580
1581 #[test]
1582 fn reject_sv_feature_indices_not_ascending() {
1583 let input = b"svm_type c_svc\n\
1584kernel_type linear\n\
1585nr_class 2\n\
1586total_sv 1\n\
1587rho 0\n\
1588SV\n\
15890.1 3:0.5 1:0.3\n";
1590 let err = load_model_from_reader(&input[..]).unwrap_err();
1591 assert!(
1592 format!("{}", err).contains("feature indices must be ascending"),
1593 "unexpected error: {}",
1594 err
1595 );
1596 }
1597
1598 #[test]
1599 fn reject_precomputed_model_sv_without_sample_serial_number() {
1600 let input = b"svm_type c_svc\n\
1601kernel_type precomputed\n\
1602nr_class 2\n\
1603total_sv 1\n\
1604rho 0\n\
1605SV\n\
16060.1\n";
1607 let err = load_model_from_reader(&input[..]).unwrap_err();
1608 assert!(
1609 format!("{}", err).contains("missing 0:sample_serial_number"),
1610 "unexpected error: {}",
1611 err
1612 );
1613 }
1614
1615 #[test]
1616 fn reject_sv_feature_index_duplicated() {
1617 let input = b"svm_type c_svc\n\
1618kernel_type linear\n\
1619nr_class 2\n\
1620total_sv 1\n\
1621rho 0\n\
1622SV\n\
16230.1 1:0.5 1:0.3\n";
1624 let err = load_model_from_reader(&input[..]).unwrap_err();
1625 assert!(
1626 format!("{}", err).contains("feature indices must be ascending"),
1627 "unexpected error: {}",
1628 err
1629 );
1630 }
1631
1632 #[test]
1633 fn load_options_default_caps_match_documented_values() {
1634 let opts = LoadOptions::default();
1635 assert_eq!(opts.max_bytes, 64 * 1024 * 1024);
1636 assert_eq!(opts.max_line_len, 1024 * 1024);
1637 assert_eq!(opts.max_sv, MAX_TOTAL_SV);
1638 assert_eq!(opts.max_nr_class, MAX_NR_CLASS);
1639 assert_eq!(opts.max_feature_index, MAX_FEATURE_INDEX);
1640 }
1641
1642 #[test]
1643 fn load_options_trusted_input_sets_type_maxes() {
1644 let opts = LoadOptions::trusted_input();
1645 assert_eq!(opts.max_bytes, u64::MAX);
1646 assert_eq!(opts.max_line_len, usize::MAX);
1647 assert_eq!(opts.max_sv, usize::MAX);
1648 assert_eq!(opts.max_nr_class, usize::MAX);
1649 assert_eq!(opts.max_feature_index, i32::MAX);
1650 }
1651
1652 #[test]
1653 fn problem_reader_rejects_file_over_max_bytes() {
1654 let input = b"+1 1:0.5\n+1 2:0.5\n";
1656 let opts = LoadOptions {
1657 max_bytes: 10,
1658 ..LoadOptions::default()
1659 };
1660 let err = load_problem_from_reader_with_options(&input[..], &opts).unwrap_err();
1661 assert!(
1662 format!("{}", err).contains("max_bytes"),
1663 "unexpected error: {}",
1664 err
1665 );
1666 }
1667
1668 #[test]
1669 fn problem_reader_rejects_line_over_max_line_len() {
1670 let mut payload = String::from("+1 ");
1672 for i in 1..=50 {
1673 payload.push_str(&format!("{}:0.1 ", i));
1674 }
1675 payload.push('\n');
1676 let opts = LoadOptions {
1677 max_line_len: 50,
1678 ..LoadOptions::default()
1679 };
1680 let err = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap_err();
1681 assert!(
1682 format!("{}", err).contains("max_line_len"),
1683 "unexpected error: {}",
1684 err
1685 );
1686 }
1687
1688 #[test]
1689 fn problem_reader_accepts_line_at_max_line_len() {
1690 let line_content = "+1 1:0.5";
1693 let payload = format!("{}\n", line_content);
1694 let opts = LoadOptions {
1695 max_line_len: line_content.len(),
1696 ..LoadOptions::default()
1697 };
1698 let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1699 assert_eq!(problem.labels.len(), 1);
1700 }
1701
1702 #[test]
1703 fn problem_reader_tolerates_crlf_at_cap() {
1704 let line_content = "+1 1:0.5";
1707 let payload = format!("{}\r\n", line_content);
1708 let opts = LoadOptions {
1709 max_line_len: line_content.len(),
1710 ..LoadOptions::default()
1711 };
1712 let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1713 assert_eq!(problem.labels.len(), 1);
1714 }
1715
1716 #[test]
1717 fn problem_reader_rejects_nul_byte() {
1718 let mut payload: Vec<u8> = b"+1 1:0.5".to_vec();
1721 payload.push(0);
1722 payload.extend_from_slice(b"\n");
1723 let err = load_problem_from_reader(payload.as_slice()).unwrap_err();
1724 assert!(
1725 format!("{}", err).contains("NUL byte"),
1726 "unexpected error: {}",
1727 err
1728 );
1729 }
1730
1731 #[test]
1732 fn model_reader_honors_max_nr_class_cap() {
1733 let input = b"svm_type c_svc\n\
1736kernel_type linear\n\
1737nr_class 100\n\
1738total_sv 1\n\
1739rho 0\n\
1740SV\n";
1741 let opts = LoadOptions {
1742 max_nr_class: 50,
1743 ..LoadOptions::default()
1744 };
1745 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1746 assert!(
1747 format!("{}", err).contains("nr_class exceeds limit (50)"),
1748 "unexpected error: {}",
1749 err
1750 );
1751 }
1752
1753 #[test]
1754 fn model_reader_honors_max_sv_cap() {
1755 let input = b"svm_type c_svc\n\
1756kernel_type linear\n\
1757nr_class 2\n\
1758total_sv 1000\n\
1759rho 0\n\
1760SV\n";
1761 let opts = LoadOptions {
1762 max_sv: 100,
1763 ..LoadOptions::default()
1764 };
1765 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1766 assert!(
1767 format!("{}", err).contains("total_sv exceeds limit (100)"),
1768 "unexpected error: {}",
1769 err
1770 );
1771 }
1772
1773 #[test]
1774 fn trusted_input_cannot_exceed_hard_module_caps() {
1775 let huge_nr_class = format!(
1779 "svm_type c_svc\n\
1780kernel_type linear\n\
1781nr_class {}\n\
1782total_sv 1\n\
1783rho 0\n\
1784SV\n",
1785 MAX_NR_CLASS + 1
1786 );
1787 let opts = LoadOptions::trusted_input();
1788 let err = load_model_from_reader_with_options(huge_nr_class.as_bytes(), &opts).unwrap_err();
1789 assert!(
1790 format!("{}", err).contains("nr_class exceeds limit"),
1791 "unexpected error: {}",
1792 err
1793 );
1794 }
1795
1796 #[test]
1797 fn model_reader_honors_max_feature_index_cap() {
1798 let input = b"svm_type c_svc\n\
1799kernel_type linear\n\
1800nr_class 2\n\
1801total_sv 1\n\
1802rho 0\n\
1803SV\n\
18040.1 50:0.5\n";
1805 let opts = LoadOptions {
1806 max_feature_index: 10,
1807 ..LoadOptions::default()
1808 };
1809 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1810 assert!(
1811 format!("{}", err).contains("feature index 50 exceeds limit (10)"),
1812 "unexpected error: {}",
1813 err
1814 );
1815 }
1816
1817 #[test]
1818 fn sv_count_loop_counts_nonblank_lines_only() {
1819 let input = b"svm_type c_svc\n\
1825kernel_type linear\n\
1826nr_class 2\n\
1827total_sv 2\n\
1828rho 0\n\
1829label 1 -1\n\
1830nr_sv 1 1\n\
1831SV\n\
1832\n\
18330.1 1:0.5\n\
1834\n\
1835-0.1 2:0.5\n";
1836 let model = load_model_from_reader(&input[..]).unwrap();
1837 assert_eq!(model.sv.len(), 2);
1838 assert_eq!(model.sv_coef[0].len(), 2);
1839 }
1840
1841 #[test]
1842 fn save_precomputed_model_writes_zero_index() {
1843 let model = SvmModel {
1844 param: SvmParameter {
1845 svm_type: SvmType::CSvc,
1846 kernel_type: KernelType::Precomputed,
1847 ..Default::default()
1848 },
1849 nr_class: 2,
1850 sv: vec![vec![SvmNode {
1851 index: 0,
1852 value: 7.0,
1853 }]],
1854 sv_coef: vec![vec![0.25]],
1855 rho: vec![0.0],
1856 prob_a: vec![],
1857 prob_b: vec![],
1858 prob_density_marks: vec![],
1859 sv_indices: vec![],
1860 label: vec![1, -1],
1861 n_sv: vec![1, 0],
1862 };
1863
1864 let mut buf = Vec::new();
1865 save_model_to_writer(&mut buf, &model).unwrap();
1866 let out = String::from_utf8(buf).unwrap();
1867 assert!(out.contains("kernel_type precomputed"));
1868 assert!(out.contains("0:7"));
1869 }
1870}