1use std::io::{BufRead, Write};
20use std::path::Path;
21
22use crate::error::SvmError;
23use crate::types::*;
24use crate::util::parse_feature_index;
25use crate::util::MAX_FEATURE_INDEX;
26
27use std::fmt;
37
/// `Display` adapter that renders an `f64` the way C's `printf("%.*g")` does.
///
/// libsvm writes model files with `%.17g` / `%.8g`; reproducing that format
/// lets Rust-written model files match C-written ones line for line.
struct Gfmt {
    /// Value to render.
    value: f64,
    /// Number of significant digits (C's `%g` precision; 0 is treated as 1).
    precision: usize,
}

impl Gfmt {
    /// Wraps `value` for `%g`-style output with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }
}

impl fmt::Display for Gfmt {
    /// Follows the C `%g` algorithm:
    ///
    /// 1. Let `P` be the precision (0 becomes 1).
    /// 2. Let `X` be the decimal exponent the value has when printed with `%e`
    ///    at precision `P - 1`, i.e. *after* rounding to `P` digits.
    /// 3. If `P > X >= -4`, print fixed-point with `P - 1 - X` decimals;
    ///    otherwise print `%e` with `P - 1` decimals.
    /// 4. Remove trailing zeros from the fractional part, and the decimal point
    ///    if nothing is left after it.
    ///
    /// Non-finite values use Rust's spellings (`NaN`, `inf`, `-inf`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Strips trailing fractional zeros (and a dangling '.'). Strings without
        /// a decimal point are integers and are returned untouched, so "100000"
        /// keeps its zeros.
        fn strip_fraction_zeros(s: &str) -> &str {
            if s.contains('.') {
                s.trim_end_matches('0').trim_end_matches('.')
            } else {
                s
            }
        }

        let v = self.value;
        let p = self.precision.max(1);

        if !v.is_finite() {
            return write!(f, "{}", v);
        }
        if v == 0.0 {
            // Preserve the sign of negative zero, as C does.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // Take the exponent from the correctly rounded scientific rendering:
        // rounding can carry into a new decade (999999.7 -> 1.00000e6 at P=6),
        // and `log10` misjudges values just below a power of ten.
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exponent) = match sci.split_once('e') {
            Some(parts) => parts,
            None => return f.write_str(&sci),
        };
        let exp: i32 = exponent.parse().unwrap_or(0);

        let use_scientific = exp < -4 || (exp >= 0 && exp as usize >= p);
        if use_scientific {
            // C always prints an explicit sign and at least two exponent digits.
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                strip_fraction_zeros(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Here -4 <= exp < p, so P - 1 - X cannot underflow.
            let decimals = if exp >= 0 {
                p - 1 - exp as usize
            } else {
                (p - 1).saturating_add(exp.unsigned_abs() as usize)
            };
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(strip_fraction_zeros(&fixed))
        }
    }
}
102
103fn fmt_17g(v: f64) -> Gfmt {
105 Gfmt::new(v, 17)
106}
107
108fn fmt_8g(v: f64) -> Gfmt {
110 Gfmt::new(v, 8)
111}
112
113pub fn format_g(v: f64) -> String {
115 format!("{}", Gfmt::new(v, 6))
116}
117
118pub fn format_17g(v: f64) -> String {
120 format!("{}", Gfmt::new(v, 17))
121}
122
/// libsvm `svm_type` keywords, indexed by the `SvmType` discriminant.
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];
/// libsvm `kernel_type` keywords, indexed by the `KernelType` discriminant.
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
127
128fn svm_type_to_str(t: SvmType) -> &'static str {
129 SVM_TYPE_TABLE[t as usize]
130}
131
132fn kernel_type_to_str(t: KernelType) -> &'static str {
133 KERNEL_TYPE_TABLE[t as usize]
134}
135
136fn str_to_svm_type(s: &str) -> Option<SvmType> {
137 match s {
138 "c_svc" => Some(SvmType::CSvc),
139 "nu_svc" => Some(SvmType::NuSvc),
140 "one_class" => Some(SvmType::OneClass),
141 "epsilon_svr" => Some(SvmType::EpsilonSvr),
142 "nu_svr" => Some(SvmType::NuSvr),
143 _ => None,
144 }
145}
146
147fn str_to_kernel_type(s: &str) -> Option<KernelType> {
148 match s {
149 "linear" => Some(KernelType::Linear),
150 "polynomial" => Some(KernelType::Polynomial),
151 "rbf" => Some(KernelType::Rbf),
152 "sigmoid" => Some(KernelType::Sigmoid),
153 "precomputed" => Some(KernelType::Precomputed),
154 _ => None,
155 }
156}
157
/// Resource limits applied while parsing problem and model files.
///
/// `LoadOptions::default()` is intended for untrusted input;
/// `LoadOptions::trusted_input()` sets every field to its type's maximum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LoadOptions {
    /// Maximum total number of bytes consumed from the reader.
    pub max_bytes: u64,
    /// Maximum length of a single line, not counting the trailing `\n`
    /// (one extra byte beyond this is tolerated).
    pub max_line_len: usize,
    /// Maximum `total_sv` accepted in a model header (model files only; the
    /// loader additionally clamps this to a built-in hard cap).
    pub max_sv: usize,
    /// Maximum `nr_class` accepted in a model header (model files only; the
    /// loader additionally clamps this to a built-in hard cap).
    pub max_nr_class: usize,
    /// Largest feature index accepted in problem rows and support vectors.
    pub max_feature_index: i32,
}
211
212impl Default for LoadOptions {
213 fn default() -> Self {
214 Self {
215 max_bytes: 64 * 1024 * 1024,
216 max_line_len: 1024 * 1024,
217 max_sv: MAX_TOTAL_SV,
218 max_nr_class: MAX_NR_CLASS,
219 max_feature_index: MAX_FEATURE_INDEX,
220 }
221 }
222}
223
224impl LoadOptions {
225 pub fn trusted_input() -> Self {
232 Self {
233 max_bytes: u64::MAX,
234 max_line_len: usize::MAX,
235 max_sv: usize::MAX,
236 max_nr_class: usize::MAX,
237 max_feature_index: i32::MAX,
238 }
239 }
240}
241
/// Reads one line (including its trailing `\n`, if present) from `reader`,
/// enforcing resource limits.
///
/// * `bytes_read` is the running total of bytes consumed across calls; it is
///   advanced only by bytes this call actually consumes.
/// * `max_bytes` bounds that running total.
/// * `max_line_len` bounds the line content excluding the `\n`; one extra byte
///   is tolerated (NOTE(review): presumably room for the `\r` of a CRLF ending).
///
/// Returns `Ok(None)` at end of input, otherwise the line as a `String`.
///
/// # Errors
///
/// `ErrorKind::InvalidData` when a limit would be exceeded, when the line
/// contains a NUL byte, or when it is not valid UTF-8. Other reader errors are
/// propagated, except `ErrorKind::Interrupted`, which is retried the same way
/// `BufRead::read_until` retries it.
fn read_line_capped(
    reader: &mut dyn BufRead,
    bytes_read: &mut u64,
    max_bytes: u64,
    max_line_len: usize,
) -> std::io::Result<Option<String>> {
    use std::io::{Error, ErrorKind};

    let per_line_raw_cap: u64 = (max_line_len as u64).saturating_add(1);

    let mut buf: Vec<u8> = Vec::new();
    let mut found_newline = false;

    loop {
        let available = match reader.fill_buf() {
            Ok(chunk) => chunk,
            // A signal interrupted the read; nothing was consumed, so retry.
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        if available.is_empty() {
            break; // end of input
        }

        // Take bytes up to and including the first '\n' in this chunk, if any.
        let take_n = available
            .iter()
            .position(|&b| b == b'\n')
            .map_or(available.len(), |pos| pos + 1);
        let ends_with_newline = take_n > 0 && available[take_n - 1] == b'\n';

        // All limits are checked before consuming, so a rejected chunk is left
        // in the reader.
        let new_bytes_read = bytes_read.saturating_add(take_n as u64);
        if new_bytes_read > max_bytes {
            return Err(Error::new(
                ErrorKind::InvalidData,
                format!("input exceeds max_bytes limit ({})", max_bytes),
            ));
        }

        // Line length excludes the terminating '\n'.
        let prospective_len = (buf.len() as u64).saturating_add(take_n as u64);
        let content_bytes = if ends_with_newline {
            prospective_len - 1
        } else {
            prospective_len
        };
        if content_bytes > per_line_raw_cap {
            return Err(Error::new(
                ErrorKind::InvalidData,
                format!("line length exceeds max_line_len limit ({})", max_line_len),
            ));
        }

        let chunk = &available[..take_n];
        if chunk.contains(&0) {
            return Err(Error::new(
                ErrorKind::InvalidData,
                "unexpected NUL byte in text input".to_string(),
            ));
        }

        buf.extend_from_slice(chunk);
        reader.consume(take_n);
        *bytes_read = new_bytes_read;

        if ends_with_newline {
            found_newline = true;
            break;
        }
    }

    if buf.is_empty() && !found_newline {
        return Ok(None);
    }

    String::from_utf8(buf)
        .map(Some)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e.to_string()))
}
330
331pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
361 let file = std::fs::File::open(path)?;
362 let reader = std::io::BufReader::new(file);
363 load_problem_from_reader(reader)
364}
365
366pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
371 load_problem_from_reader_with_options(reader, &LoadOptions::default())
372}
373
374pub fn load_problem_from_reader_with_options(
380 mut reader: impl BufRead,
381 options: &LoadOptions,
382) -> Result<SvmProblem, SvmError> {
383 let mut labels = Vec::new();
384 let mut instances = Vec::new();
385 let mut bytes_read: u64 = 0;
386 let mut line_idx: usize = 0;
387
388 while let Some(raw) = read_line_capped(
389 &mut reader,
390 &mut bytes_read,
391 options.max_bytes,
392 options.max_line_len,
393 )? {
394 let line_num = line_idx + 1;
395 line_idx += 1;
396 let line = raw.trim();
397 if line.is_empty() {
398 continue;
399 }
400
401 let mut parts = line.split_whitespace();
402
403 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
405 line: line_num,
406 message: "missing label".into(),
407 })?;
408 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
409 line: line_num,
410 message: format!("invalid label: {}", label_str),
411 })?;
412
413 let mut nodes = Vec::new();
415 let mut prev_index: i32 = 0;
416 for token in parts {
417 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
418 line: line_num,
419 message: format!("expected index:value, got: {}", token),
420 })?;
421 let index: i32 =
422 parse_feature_index_problem_line(line_num, idx_str, options.max_feature_index)?;
423
424 if !nodes.is_empty() && index <= prev_index {
425 return Err(SvmError::ParseError {
426 line: line_num,
427 message: format!(
428 "feature indices must be ascending: {} follows {}",
429 index, prev_index
430 ),
431 });
432 }
433 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
434 line: line_num,
435 message: format!("invalid value: {}", val_str),
436 })?;
437 prev_index = index;
438 nodes.push(SvmNode { index, value });
439 }
440
441 labels.push(label);
442 instances.push(nodes);
443 }
444
445 Ok(SvmProblem { labels, instances })
446}
447
/// Hard upper bound on `nr_class` in a model header; the model loader clamps
/// `LoadOptions::max_nr_class` to this value.
const MAX_NR_CLASS: usize = u16::MAX as usize;
/// Hard upper bound on `total_sv` in a model header; the model loader clamps
/// `LoadOptions::max_sv` to this value.
const MAX_TOTAL_SV: usize = 10_000_000;
452
453pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
455 let file = std::fs::File::create(path)?;
456 let writer = std::io::BufWriter::new(file);
457 save_model_to_writer(writer, model)
458}
459
460pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
462 let param = &model.param;
463
464 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
465 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
466
467 if param.kernel_type == KernelType::Polynomial {
468 writeln!(w, "degree {}", param.degree)?;
469 }
470 if matches!(
471 param.kernel_type,
472 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
473 ) {
474 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
475 }
476 if matches!(
477 param.kernel_type,
478 KernelType::Polynomial | KernelType::Sigmoid
479 ) {
480 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
481 }
482
483 let nr_class = model.nr_class;
484 writeln!(w, "nr_class {}", nr_class)?;
485 writeln!(w, "total_sv {}", model.sv.len())?;
486
487 write!(w, "rho")?;
489 for r in &model.rho {
490 write!(w, " {}", fmt_17g(*r))?;
491 }
492 writeln!(w)?;
493
494 if !model.label.is_empty() {
496 write!(w, "label")?;
497 for l in &model.label {
498 write!(w, " {}", l)?;
499 }
500 writeln!(w)?;
501 }
502
503 if !model.prob_a.is_empty() {
505 write!(w, "probA")?;
506 for v in &model.prob_a {
507 write!(w, " {}", fmt_17g(*v))?;
508 }
509 writeln!(w)?;
510 }
511
512 if !model.prob_b.is_empty() {
514 write!(w, "probB")?;
515 for v in &model.prob_b {
516 write!(w, " {}", fmt_17g(*v))?;
517 }
518 writeln!(w)?;
519 }
520
521 if !model.prob_density_marks.is_empty() {
523 write!(w, "prob_density_marks")?;
524 for v in &model.prob_density_marks {
525 write!(w, " {}", fmt_17g(*v))?;
526 }
527 writeln!(w)?;
528 }
529
530 if !model.n_sv.is_empty() {
532 write!(w, "nr_sv")?;
533 for n in &model.n_sv {
534 write!(w, " {}", n)?;
535 }
536 writeln!(w)?;
537 }
538
539 writeln!(w, "SV")?;
541 let num_sv = model.sv.len();
542 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
545 for j in 0..num_coef_rows {
547 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
548 }
549 if model.param.kernel_type == KernelType::Precomputed {
551 if let Some(node) = model.sv[i].first() {
552 write!(w, "0:{} ", node.value as i32)?;
553 }
554 } else {
555 for node in &model.sv[i] {
556 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
557 }
558 }
559 writeln!(w)?;
560 }
561
562 Ok(())
563}
564
565pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
598 let file = std::fs::File::open(path)?;
599 let reader = std::io::BufReader::new(file);
600 load_model_from_reader(reader)
601}
602
603pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
608 load_model_from_reader_with_options(reader, &LoadOptions::default())
609}
610
611pub fn load_model_from_reader_with_options(
617 mut reader: impl BufRead,
618 options: &LoadOptions,
619) -> Result<SvmModel, SvmError> {
620 let mut bytes_read: u64 = 0;
621
622 let nr_class_cap = options.max_nr_class.min(MAX_NR_CLASS);
628 let total_sv_cap = options.max_sv.min(MAX_TOTAL_SV);
629
630 let mut param = SvmParameter::default();
632 let mut nr_class: usize = 0;
633 let mut total_sv: usize = 0;
634 let mut rho: Vec<f64> = Vec::new();
635 let mut label: Vec<i32> = Vec::new();
636 let mut prob_a: Vec<f64> = Vec::new();
637 let mut prob_b: Vec<f64> = Vec::new();
638 let mut prob_density_marks: Vec<f64> = Vec::new();
639 let mut n_sv: Vec<usize> = Vec::new();
640
641 let mut line_num: usize = 0;
643 loop {
644 let raw = read_line_capped(
645 &mut reader,
646 &mut bytes_read,
647 options.max_bytes,
648 options.max_line_len,
649 )
650 .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
651 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))?;
652 line_num += 1;
653 let line = raw.trim().to_string();
654 if line.is_empty() {
655 continue;
656 }
657
658 let mut parts = line.split_whitespace();
659 let cmd = parts.next().ok_or_else(|| {
660 SvmError::ModelFormatError(format!("line {}: empty model header line", line_num))
661 })?;
662
663 match cmd {
664 "svm_type" => {
665 let val = parts.next().ok_or_else(|| {
666 SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
667 })?;
668 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
669 SvmError::ModelFormatError(format!(
670 "line {}: unknown svm_type: {}",
671 line_num, val
672 ))
673 })?;
674 }
675 "kernel_type" => {
676 let val = parts.next().ok_or_else(|| {
677 SvmError::ModelFormatError(format!(
678 "line {}: missing kernel_type value",
679 line_num
680 ))
681 })?;
682 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
683 SvmError::ModelFormatError(format!(
684 "line {}: unknown kernel_type: {}",
685 line_num, val
686 ))
687 })?;
688 }
689 "degree" => {
690 let d: i32 = parse_single(&mut parts, line_num, "degree")?;
691 if d < 0 {
692 return Err(SvmError::ModelFormatError(format!(
693 "line {}: degree must be >= 0, got {}",
694 line_num, d
695 )));
696 }
697 param.degree = d;
698 }
699 "gamma" => {
700 let v: f64 = parse_single(&mut parts, line_num, "gamma")?;
701 if !v.is_finite() {
702 return Err(SvmError::ModelFormatError(format!(
703 "line {}: gamma must be finite, got {}",
704 line_num, v
705 )));
706 }
707 param.gamma = v;
708 }
709 "coef0" => {
710 let v: f64 = parse_single(&mut parts, line_num, "coef0")?;
711 if !v.is_finite() {
712 return Err(SvmError::ModelFormatError(format!(
713 "line {}: coef0 must be finite, got {}",
714 line_num, v
715 )));
716 }
717 param.coef0 = v;
718 }
719 "nr_class" => {
720 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
721 if nr_class > nr_class_cap {
722 return Err(SvmError::ModelFormatError(format!(
723 "line {}: nr_class exceeds limit ({})",
724 line_num, nr_class_cap
725 )));
726 }
727 }
728 "total_sv" => {
729 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
730 if total_sv > total_sv_cap {
731 return Err(SvmError::ModelFormatError(format!(
732 "line {}: total_sv exceeds limit ({})",
733 line_num, total_sv_cap
734 )));
735 }
736 }
737 "rho" => {
738 rho = parse_multiple(&mut parts, line_num, "rho")?;
739 for &r in &rho {
740 if !r.is_finite() {
741 return Err(SvmError::ModelFormatError(format!(
742 "line {}: rho must be finite, got {}",
743 line_num, r
744 )));
745 }
746 }
747 }
748 "label" => {
749 label = parse_multiple(&mut parts, line_num, "label")?;
750 }
751 "probA" => {
752 prob_a = parse_multiple(&mut parts, line_num, "probA")?;
753 }
754 "probB" => {
755 prob_b = parse_multiple(&mut parts, line_num, "probB")?;
756 }
757 "prob_density_marks" => {
758 prob_density_marks = parse_multiple(&mut parts, line_num, "prob_density_marks")?;
759 }
760 "nr_sv" => {
761 n_sv = parts
762 .map(|s| {
763 s.parse::<usize>().map_err(|_| {
764 SvmError::ModelFormatError(format!(
765 "line {}: invalid nr_sv value: {}",
766 line_num, s
767 ))
768 })
769 })
770 .collect::<Result<Vec<_>, _>>()?;
771 }
772 "SV" => break,
773 _ => {
774 return Err(SvmError::ModelFormatError(format!(
775 "line {}: unknown keyword: {}",
776 line_num, cmd
777 )));
778 }
779 }
780 }
781
782 validate_model_header(
788 param.svm_type,
789 nr_class,
790 total_sv,
791 &rho,
792 &label,
793 &prob_a,
794 &prob_b,
795 &prob_density_marks,
796 &n_sv,
797 )?;
798
799 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
807 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::new()).collect();
808 let mut sv: Vec<Vec<SvmNode>> = Vec::new();
809
810 while sv.len() < total_sv {
811 let raw = read_line_capped(
812 &mut reader,
813 &mut bytes_read,
814 options.max_bytes,
815 options.max_line_len,
816 )
817 .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
818 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))?;
819 line_num += 1;
820 let line = raw.trim();
821 if line.is_empty() {
822 continue;
825 }
826
827 let mut parts = line.split_whitespace();
828
829 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
831 let val_str = parts.next().ok_or_else(|| {
832 SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
833 })?;
834 let val: f64 = val_str.parse().map_err(|_| {
835 SvmError::ModelFormatError(format!(
836 "line {}: invalid sv_coef: {}",
837 line_num, val_str
838 ))
839 })?;
840 if !val.is_finite() {
841 return Err(SvmError::ModelFormatError(format!(
842 "line {}: sv_coef must be finite, got {}",
843 line_num, val_str
844 )));
845 }
846 coef_row.push(val);
847 }
848
849 let mut nodes = Vec::new();
852 let mut prev_index: i32 = 0;
853 for token in parts {
854 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
855 SvmError::ModelFormatError(format!(
856 "line {}: expected index:value, got: {}",
857 line_num, token
858 ))
859 })?;
860 let index: i32 =
861 parse_feature_index_model_line(line_num, idx_str, options.max_feature_index)?;
862
863 if !nodes.is_empty() && index <= prev_index {
864 return Err(SvmError::ModelFormatError(format!(
865 "line {}: feature indices must be ascending: {} follows {}",
866 line_num, index, prev_index
867 )));
868 }
869
870 let value: f64 = val_str.parse().map_err(|_| {
871 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
872 })?;
873 if !value.is_finite() {
874 return Err(SvmError::ModelFormatError(format!(
875 "line {}: feature value must be finite, got {}",
876 line_num, val_str
877 )));
878 }
879 prev_index = index;
880 nodes.push(SvmNode { index, value });
881 }
882
883 if param.kernel_type == KernelType::Precomputed {
884 validate_precomputed_row(&nodes, line_num, "support vector")?;
885 }
886 sv.push(nodes);
887 }
888
889 Ok(SvmModel {
890 param,
891 nr_class,
892 sv,
893 sv_coef,
894 rho,
895 prob_a,
896 prob_b,
897 prob_density_marks,
898 sv_indices: Vec::new(), label,
900 n_sv,
901 })
902}
903
904fn validate_precomputed_row(
905 nodes: &[SvmNode],
906 line_num: usize,
907 context: &str,
908) -> Result<(), SvmError> {
909 let first = nodes.first().ok_or_else(|| {
910 SvmError::ModelFormatError(format!(
911 "line {}: precomputed kernel {} is missing 0:sample_serial_number",
912 line_num, context
913 ))
914 })?;
915
916 if first.index != 0
917 || !first.value.is_finite()
918 || first.value < 1.0
919 || first.value.fract() != 0.0
920 {
921 return Err(SvmError::ModelFormatError(format!(
922 "line {}: precomputed kernel {} must start with 0:sample_serial_number",
923 line_num, context
924 )));
925 }
926
927 Ok(())
928}
929
930#[allow(clippy::too_many_arguments)]
952fn validate_model_header(
953 svm_type: SvmType,
954 nr_class: usize,
955 total_sv: usize,
956 rho: &[f64],
957 label: &[i32],
958 prob_a: &[f64],
959 prob_b: &[f64],
960 prob_density_marks: &[f64],
961 n_sv: &[usize],
962) -> Result<(), SvmError> {
963 let is_classification = matches!(svm_type, SvmType::CSvc | SvmType::NuSvc);
964 let is_regression = matches!(svm_type, SvmType::EpsilonSvr | SvmType::NuSvr);
965 let is_one_class = matches!(svm_type, SvmType::OneClass);
966
967 if nr_class < 2 {
972 return Err(SvmError::ModelFormatError(format!(
973 "nr_class must be >= 2, got {}",
974 nr_class
975 )));
976 }
977
978 let expected_rho = if is_classification {
980 nr_class * (nr_class - 1) / 2
981 } else {
982 1
983 };
984 if rho.len() != expected_rho {
985 return Err(SvmError::ModelFormatError(format!(
986 "rho has {} entries, expected {} for svm_type {}",
987 rho.len(),
988 expected_rho,
989 svm_type_to_str(svm_type)
990 )));
991 }
992
993 if is_classification {
996 if label.len() != nr_class {
1000 return Err(SvmError::ModelFormatError(format!(
1001 "label has {} entries, expected nr_class ({}) for svm_type {}",
1002 label.len(),
1003 nr_class,
1004 svm_type_to_str(svm_type)
1005 )));
1006 }
1007 } else if !label.is_empty() {
1008 return Err(SvmError::ModelFormatError(format!(
1009 "label is only valid for classification, got {} entries on svm_type {}",
1010 label.len(),
1011 svm_type_to_str(svm_type)
1012 )));
1013 }
1014
1015 if is_classification {
1019 if n_sv.len() != nr_class {
1020 return Err(SvmError::ModelFormatError(format!(
1021 "nr_sv has {} entries, expected nr_class ({}) for svm_type {}",
1022 n_sv.len(),
1023 nr_class,
1024 svm_type_to_str(svm_type)
1025 )));
1026 }
1027 let mut sum: usize = 0;
1031 for &n in n_sv {
1032 sum = sum.checked_add(n).ok_or_else(|| {
1033 SvmError::ModelFormatError("nr_sv entries overflow usize when summed".into())
1034 })?;
1035 }
1036 if sum != total_sv {
1037 return Err(SvmError::ModelFormatError(format!(
1038 "sum of nr_sv entries ({}) does not match total_sv ({})",
1039 sum, total_sv
1040 )));
1041 }
1042 } else if !n_sv.is_empty() {
1043 return Err(SvmError::ModelFormatError(format!(
1044 "nr_sv is only valid for classification, got {} entries on svm_type {}",
1045 n_sv.len(),
1046 svm_type_to_str(svm_type)
1047 )));
1048 }
1049
1050 if !prob_a.is_empty() && prob_a.len() != expected_rho {
1052 return Err(SvmError::ModelFormatError(format!(
1053 "probA has {} entries, expected {}",
1054 prob_a.len(),
1055 expected_rho
1056 )));
1057 }
1058 if !prob_b.is_empty() && prob_b.len() != expected_rho {
1059 return Err(SvmError::ModelFormatError(format!(
1060 "probB has {} entries, expected {}",
1061 prob_b.len(),
1062 expected_rho
1063 )));
1064 }
1065
1066 if !prob_density_marks.is_empty() && !is_one_class {
1068 return Err(SvmError::ModelFormatError(format!(
1069 "prob_density_marks is only valid for one-class SVM, got {} entries on svm_type {}",
1070 prob_density_marks.len(),
1071 svm_type_to_str(svm_type)
1072 )));
1073 }
1074
1075 let _ = is_regression;
1079
1080 Ok(())
1081}
1082
1083fn parse_feature_index_problem_line(
1086 line_num: usize,
1087 idx_str: &str,
1088 max_feature_index: i32,
1089) -> Result<i32, SvmError> {
1090 parse_feature_index(idx_str, max_feature_index).map_err(|msg| SvmError::ParseError {
1091 line: line_num,
1092 message: msg,
1093 })
1094}
1095
1096fn parse_feature_index_model_line(
1097 line_num: usize,
1098 idx_str: &str,
1099 max_feature_index: i32,
1100) -> Result<i32, SvmError> {
1101 parse_feature_index(idx_str, max_feature_index)
1102 .map_err(|msg| SvmError::ModelFormatError(format!("line {}: {}", line_num, msg)))
1103}
1104
1105fn parse_single<T: std::str::FromStr>(
1106 parts: &mut std::str::SplitWhitespace<'_>,
1107 line_num: usize,
1108 field: &str,
1109) -> Result<T, SvmError> {
1110 let val_str = parts.next().ok_or_else(|| {
1111 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
1112 })?;
1113 val_str.parse().map_err(|_| {
1114 SvmError::ModelFormatError(format!(
1115 "line {}: invalid {} value: {}",
1116 line_num, field, val_str
1117 ))
1118 })
1119}
1120
1121fn parse_multiple<T: std::str::FromStr>(
1122 parts: &mut std::str::SplitWhitespace<'_>,
1123 line_num: usize,
1124 field: &str,
1125) -> Result<Vec<T>, SvmError> {
1126 parts
1127 .map(|s| {
1128 s.parse::<T>().map_err(|_| {
1129 SvmError::ModelFormatError(format!(
1130 "line {}: invalid {} value: {}",
1131 line_num, field, s
1132 ))
1133 })
1134 })
1135 .collect()
1136}
1137
1138#[cfg(test)]
1141mod tests {
1142 use super::*;
1143 use std::path::PathBuf;
1144
1145 fn data_dir() -> PathBuf {
1146 PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1147 .join("..")
1148 .join("..")
1149 .join("data")
1150 }
1151
1152 #[test]
1153 fn parse_heart_scale() {
1154 let path = data_dir().join("heart_scale");
1155 let problem = load_problem(&path).unwrap();
1156 assert_eq!(problem.labels.len(), 270);
1157 assert_eq!(problem.instances.len(), 270);
1158 assert_eq!(problem.labels[0], 1.0);
1160 assert_eq!(
1161 problem.instances[0][0],
1162 SvmNode {
1163 index: 1,
1164 value: 0.708333
1165 }
1166 );
1167 assert_eq!(problem.instances[0].len(), 12);
1168 }
1169
1170 #[test]
1171 fn parse_iris() {
1172 let path = data_dir().join("iris.scale");
1173 let problem = load_problem(&path).unwrap();
1174 assert_eq!(problem.labels.len(), 150);
1175 let classes: std::collections::HashSet<i64> =
1177 problem.labels.iter().map(|&l| l as i64).collect();
1178 assert_eq!(classes.len(), 3);
1179 }
1180
1181 #[test]
1182 fn parse_housing() {
1183 let path = data_dir().join("housing_scale");
1184 let problem = load_problem(&path).unwrap();
1185 assert_eq!(problem.labels.len(), 506);
1186 assert!((problem.labels[0] - 24.0).abs() < 1e-10);
1188 }
1189
1190 #[test]
1191 fn parse_empty_lines() {
1192 let input = b"+1 1:0.5\n\n-1 2:0.3\n";
1193 let problem = load_problem_from_reader(&input[..]).unwrap();
1194 assert_eq!(problem.labels.len(), 2);
1195 }
1196
1197 #[test]
1198 fn parse_error_unsorted_indices() {
1199 let input = b"+1 3:0.5 1:0.3\n";
1200 let result = load_problem_from_reader(&input[..]);
1201 assert!(result.is_err());
1202 let msg = format!("{}", result.unwrap_err());
1203 assert!(msg.contains("ascending"), "error: {}", msg);
1204 }
1205
1206 #[test]
1207 fn parse_error_duplicate_indices() {
1208 let input = b"+1 1:0.5 1:0.3\n";
1209 let result = load_problem_from_reader(&input[..]);
1210 assert!(result.is_err());
1211 }
1212
1213 #[test]
1214 fn parse_error_missing_colon() {
1215 let input = b"+1 1:0.5 bad_token\n";
1216 let result = load_problem_from_reader(&input[..]);
1217 assert!(result.is_err());
1218 }
1219
1220 #[test]
1221 #[allow(clippy::excessive_precision)]
1222 fn load_c_trained_model() {
1223 let path = data_dir().join("heart_scale.model");
1225 let model = load_model(&path).unwrap();
1226 assert_eq!(model.nr_class, 2);
1227 assert_eq!(model.param.svm_type, SvmType::CSvc);
1228 assert_eq!(model.param.kernel_type, KernelType::Rbf);
1229 assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
1230 assert_eq!(model.sv.len(), 132);
1231 assert_eq!(model.label, vec![1, -1]);
1232 assert_eq!(model.n_sv, vec![64, 68]);
1233 assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
1234 assert_eq!(model.sv_coef.len(), 1);
1236 assert_eq!(model.sv_coef[0].len(), 132);
1237 }
1238
1239 #[test]
1240 fn roundtrip_c_model() {
1241 let path = data_dir().join("heart_scale.model");
1243 let original_bytes = std::fs::read_to_string(&path).unwrap();
1244 let model = load_model(&path).unwrap();
1245
1246 let mut buf = Vec::new();
1247 save_model_to_writer(&mut buf, &model).unwrap();
1248 let rust_output = String::from_utf8(buf).unwrap();
1249
1250 let orig_lines: Vec<&str> = original_bytes.lines().collect();
1252 let rust_lines: Vec<&str> = rust_output.lines().collect();
1253 assert_eq!(
1254 orig_lines.len(),
1255 rust_lines.len(),
1256 "line count mismatch: C={} Rust={}",
1257 orig_lines.len(),
1258 rust_lines.len()
1259 );
1260 for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
1261 assert_eq!(
1262 o,
1263 r,
1264 "line {} differs:\n C: {:?}\n Rust: {:?}",
1265 i + 1,
1266 o,
1267 r
1268 );
1269 }
1270 }
1271
1272 #[test]
1273 #[allow(clippy::excessive_precision)]
1274 fn gfmt_matches_c_printf() {
1275 let cases: &[(f64, &str, &str)] = &[
1277 (0.5, "0.5", "0.5"),
1278 (-1.0, "-1", "-1"),
1279 (0.123456789012345, "0.123456789012345", "0.12345679"),
1280 (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
1281 (0.42446200000000001, "0.42446200000000001", "0.424462"),
1282 (0.0, "0", "0"),
1283 (1e-5, "1.0000000000000001e-05", "1e-05"),
1284 (1e-4, "0.0001", "0.0001"),
1285 (1e20, "1e+20", "1e+20"),
1286 (-0.25, "-0.25", "-0.25"),
1287 (0.75, "0.75", "0.75"),
1288 (0.708333, "0.70833299999999999", "0.708333"),
1289 (1.0, "1", "1"),
1290 ];
1291 for &(v, expected_17g, expected_8g) in cases {
1292 let got_17 = format!("{}", fmt_17g(v));
1293 let got_8 = format!("{}", fmt_8g(v));
1294 assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
1295 assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
1296 }
1297 }
1298
1299 #[test]
1300 #[allow(clippy::excessive_precision)]
1301 fn model_roundtrip() {
1302 let model = SvmModel {
1304 param: SvmParameter {
1305 svm_type: SvmType::CSvc,
1306 kernel_type: KernelType::Rbf,
1307 gamma: 0.5,
1308 ..Default::default()
1309 },
1310 nr_class: 2,
1311 sv: vec![
1312 vec![
1313 SvmNode {
1314 index: 1,
1315 value: 0.5,
1316 },
1317 SvmNode {
1318 index: 3,
1319 value: -1.0,
1320 },
1321 ],
1322 vec![
1323 SvmNode {
1324 index: 1,
1325 value: -0.25,
1326 },
1327 SvmNode {
1328 index: 2,
1329 value: 0.75,
1330 },
1331 ],
1332 ],
1333 sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
1334 rho: vec![0.42446200000000001],
1335 prob_a: vec![],
1336 prob_b: vec![],
1337 prob_density_marks: vec![],
1338 sv_indices: vec![],
1339 label: vec![1, -1],
1340 n_sv: vec![1, 1],
1341 };
1342
1343 let mut buf = Vec::new();
1344 save_model_to_writer(&mut buf, &model).unwrap();
1345
1346 let loaded = load_model_from_reader(&buf[..]).unwrap();
1347
1348 assert_eq!(loaded.nr_class, model.nr_class);
1349 assert_eq!(loaded.param.svm_type, model.param.svm_type);
1350 assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
1351 assert_eq!(loaded.sv.len(), model.sv.len());
1352 assert_eq!(loaded.label, model.label);
1353 assert_eq!(loaded.n_sv, model.n_sv);
1354 assert_eq!(loaded.rho.len(), model.rho.len());
1355 for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
1357 assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
1358 }
1359 for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
1361 for (a, b) in row_a.iter().zip(row_b.iter()) {
1362 assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
1363 }
1364 }
1365 }
1366
1367 #[test]
1368 fn parse_error_excessive_counts() {
1369 let input =
1370 b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
1371 let result = load_model_from_reader(&input[..]);
1372 assert!(result.is_err());
1373 assert!(format!("{}", result.unwrap_err()).contains("nr_class exceeds limit"));
1374
1375 let input =
1376 b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
1377 let result = load_model_from_reader(&input[..]);
1378 assert!(result.is_err());
1379 assert!(format!("{}", result.unwrap_err()).contains("total_sv exceeds limit"));
1380 }
1381
1382 #[test]
1383 fn parse_error_excessive_feature_index() {
1384 let input = b"1 10000001:1\n";
1386 let result = load_problem_from_reader(&input[..]);
1387 assert!(result.is_err());
1388 assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1389
1390 let input = b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nlabel 1 -1\nnr_sv 1 0\nSV\n0.1 10000001:1\n";
1392 let result = load_model_from_reader(&input[..]);
1393 assert!(result.is_err());
1394 assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1395 }
1396
1397 #[test]
1398 fn parse_error_unknown_model_keyword() {
1399 let input = b"bad_key value\n";
1400 let result = load_model_from_reader(&input[..]);
1401 assert!(result.is_err());
1402 assert!(format!("{}", result.unwrap_err()).contains("unknown keyword"));
1403 }
1404
1405 #[test]
1406 fn parse_error_missing_or_unknown_model_values() {
1407 let missing = b"svm_type\n";
1408 let err = load_model_from_reader(&missing[..]).unwrap_err();
1409 assert!(format!("{}", err).contains("missing svm_type value"));
1410
1411 let unknown = b"svm_type unknown_type\n";
1412 let err = load_model_from_reader(&unknown[..]).unwrap_err();
1413 assert!(format!("{}", err).contains("unknown svm_type"));
1414 }
1415
1416 #[test]
1417 fn parse_error_invalid_nr_sv_entry() {
1418 let input = b"svm_type c_svc\n\
1419kernel_type linear\n\
1420nr_class 2\n\
1421total_sv 1\n\
1422rho 0\n\
1423label 1 -1\n\
1424nr_sv a 1\n\
1425SV\n\
14260.1 1:0.5\n";
1427 let err = load_model_from_reader(&input[..]).unwrap_err();
1428 assert!(format!("{}", err).contains("invalid nr_sv value"));
1429 }
1430
1431 #[test]
1432 fn parse_error_in_sv_section_tokens() {
1433 let missing_coef = b"svm_type c_svc\n\
1434kernel_type linear\n\
1435nr_class 2\n\
1436total_sv 1\n\
1437rho 0\n\
1438label 1 -1\n\
1439nr_sv 1 0\n\
1440SV\n\
14411:0.5\n";
1442 let err = load_model_from_reader(&missing_coef[..]).unwrap_err();
1443 assert!(format!("{}", err).contains("invalid sv_coef"));
1444
1445 let bad_feature = b"svm_type c_svc\n\
1446kernel_type linear\n\
1447nr_class 2\n\
1448total_sv 1\n\
1449rho 0\n\
1450label 1 -1\n\
1451nr_sv 1 0\n\
1452SV\n\
14530.1 bad\n";
1454 let err = load_model_from_reader(&bad_feature[..]).unwrap_err();
1455 assert!(format!("{}", err).contains("expected index:value"));
1456 }
1457
1458 #[test]
1459 fn parse_error_unexpected_eof_in_header_and_sv_section() {
1460 let eof_header = b"svm_type c_svc\n";
1461 let err = load_model_from_reader(&eof_header[..]).unwrap_err();
1462 assert!(format!("{}", err).contains("unexpected end of file in header"));
1463
1464 let eof_sv = b"svm_type c_svc\n\
1465kernel_type linear\n\
1466nr_class 2\n\
1467total_sv 2\n\
1468rho 0\n\
1469label 1 -1\n\
1470nr_sv 1 1\n\
1471SV\n\
14720.1 1:0.5\n";
1473 let err = load_model_from_reader(&eof_sv[..]).unwrap_err();
1474 assert!(format!("{}", err).contains("unexpected end of file in SV section"));
1475 }
1476
1477 #[test]
1478 fn reject_rho_length_mismatch_for_classification() {
1479 let input = b"svm_type c_svc\n\
1482kernel_type linear\n\
1483nr_class 3\n\
1484total_sv 3\n\
1485rho 0\n\
1486label 1 -1 0\n\
1487nr_sv 1 1 1\n\
1488SV\n";
1489 let err = load_model_from_reader(&input[..]).unwrap_err();
1490 assert!(
1491 format!("{}", err).contains("rho has 1 entries, expected 3"),
1492 "unexpected error: {}",
1493 err
1494 );
1495 }
1496
1497 #[test]
1498 fn reject_rho_length_mismatch_for_regression() {
1499 let input = b"svm_type epsilon_svr\n\
1501kernel_type linear\n\
1502nr_class 2\n\
1503total_sv 0\n\
1504rho 0 1\n\
1505SV\n";
1506 let err = load_model_from_reader(&input[..]).unwrap_err();
1507 assert!(
1508 format!("{}", err).contains("rho has 2 entries, expected 1"),
1509 "unexpected error: {}",
1510 err
1511 );
1512 }
1513
1514 #[test]
1515 fn reject_label_on_regression() {
1516 let input = b"svm_type epsilon_svr\n\
1518kernel_type linear\n\
1519nr_class 2\n\
1520total_sv 0\n\
1521rho 0\n\
1522label 1 -1\n\
1523SV\n";
1524 let err = load_model_from_reader(&input[..]).unwrap_err();
1525 assert!(
1526 format!("{}", err).contains("label is only valid for classification"),
1527 "unexpected error: {}",
1528 err
1529 );
1530 }
1531
1532 #[test]
1533 fn reject_label_length_mismatch() {
1534 let input = b"svm_type c_svc\n\
1535kernel_type linear\n\
1536nr_class 3\n\
1537total_sv 0\n\
1538rho 0 0 0\n\
1539label 1 -1\n\
1540SV\n";
1541 let err = load_model_from_reader(&input[..]).unwrap_err();
1542 assert!(
1543 format!("{}", err).contains("label has 2 entries, expected nr_class (3)"),
1544 "unexpected error: {}",
1545 err
1546 );
1547 }
1548
1549 #[test]
1550 fn reject_nr_sv_sum_mismatch() {
1551 let input = b"svm_type c_svc\n\
1554kernel_type linear\n\
1555nr_class 2\n\
1556total_sv 5\n\
1557rho 0\n\
1558label 1 -1\n\
1559nr_sv 1 2\n\
1560SV\n";
1561 let err = load_model_from_reader(&input[..]).unwrap_err();
1562 assert!(
1563 format!("{}", err).contains("sum of nr_sv entries (3) does not match total_sv (5)"),
1564 "unexpected error: {}",
1565 err
1566 );
1567 }
1568
1569 #[test]
1570 fn reject_nr_sv_length_mismatch() {
1571 let input = b"svm_type c_svc\n\
1572kernel_type linear\n\
1573nr_class 3\n\
1574total_sv 3\n\
1575rho 0 0 0\n\
1576label 1 -1 0\n\
1577nr_sv 1 2\n\
1578SV\n";
1579 let err = load_model_from_reader(&input[..]).unwrap_err();
1580 assert!(
1581 format!("{}", err).contains("nr_sv has 2 entries, expected nr_class (3)"),
1582 "unexpected error: {}",
1583 err
1584 );
1585 }
1586
1587 #[test]
1588 fn reject_proba_length_mismatch() {
1589 let input = b"svm_type c_svc\n\
1590kernel_type linear\n\
1591nr_class 3\n\
1592total_sv 0\n\
1593rho 0 0 0\n\
1594label 1 -1 0\n\
1595nr_sv 0 0 0\n\
1596probA 0.1 0.2\n\
1597SV\n";
1598 let err = load_model_from_reader(&input[..]).unwrap_err();
1599 assert!(
1600 format!("{}", err).contains("probA has 2 entries, expected 3"),
1601 "unexpected error: {}",
1602 err
1603 );
1604 }
1605
1606 #[test]
1607 fn reject_prob_density_marks_on_csvc() {
1608 let input = b"svm_type c_svc\n\
1609kernel_type linear\n\
1610nr_class 2\n\
1611total_sv 0\n\
1612rho 0\n\
1613label 1 -1\n\
1614nr_sv 0 0\n\
1615prob_density_marks 0.1 0.2\n\
1616SV\n";
1617 let err = load_model_from_reader(&input[..]).unwrap_err();
1618 assert!(
1619 format!("{}", err).contains("prob_density_marks is only valid for one-class SVM"),
1620 "unexpected error: {}",
1621 err
1622 );
1623 }
1624
1625 #[test]
1626 fn reject_nr_class_below_two() {
1627 let input = b"svm_type c_svc\n\
1628kernel_type linear\n\
1629nr_class 1\n\
1630total_sv 0\n\
1631rho\n\
1632SV\n";
1633 let err = load_model_from_reader(&input[..]).unwrap_err();
1634 assert!(
1635 format!("{}", err).contains("nr_class must be >= 2, got 1"),
1636 "unexpected error: {}",
1637 err
1638 );
1639 }
1640
1641 #[test]
1642 fn reject_sv_feature_indices_not_ascending() {
1643 let input = b"svm_type c_svc\n\
1644kernel_type linear\n\
1645nr_class 2\n\
1646total_sv 1\n\
1647rho 0\n\
1648label 1 -1\n\
1649nr_sv 1 0\n\
1650SV\n\
16510.1 3:0.5 1:0.3\n";
1652 let err = load_model_from_reader(&input[..]).unwrap_err();
1653 assert!(
1654 format!("{}", err).contains("feature indices must be ascending"),
1655 "unexpected error: {}",
1656 err
1657 );
1658 }
1659
1660 #[test]
1661 fn reject_precomputed_model_sv_without_sample_serial_number() {
1662 let input = b"svm_type c_svc\n\
1663kernel_type precomputed\n\
1664nr_class 2\n\
1665total_sv 1\n\
1666rho 0\n\
1667label 1 -1\n\
1668nr_sv 1 0\n\
1669SV\n\
16700.1\n";
1671 let err = load_model_from_reader(&input[..]).unwrap_err();
1672 assert!(
1673 format!("{}", err).contains("missing 0:sample_serial_number"),
1674 "unexpected error: {}",
1675 err
1676 );
1677 }
1678
1679 #[test]
1680 fn reject_sv_feature_index_duplicated() {
1681 let input = b"svm_type c_svc\n\
1682kernel_type linear\n\
1683nr_class 2\n\
1684total_sv 1\n\
1685rho 0\n\
1686label 1 -1\n\
1687nr_sv 1 0\n\
1688SV\n\
16890.1 1:0.5 1:0.3\n";
1690 let err = load_model_from_reader(&input[..]).unwrap_err();
1691 assert!(
1692 format!("{}", err).contains("feature indices must be ascending"),
1693 "unexpected error: {}",
1694 err
1695 );
1696 }
1697
1698 #[test]
1699 fn load_options_default_caps_match_documented_values() {
1700 let opts = LoadOptions::default();
1701 assert_eq!(opts.max_bytes, 64 * 1024 * 1024);
1702 assert_eq!(opts.max_line_len, 1024 * 1024);
1703 assert_eq!(opts.max_sv, MAX_TOTAL_SV);
1704 assert_eq!(opts.max_nr_class, MAX_NR_CLASS);
1705 assert_eq!(opts.max_feature_index, MAX_FEATURE_INDEX);
1706 }
1707
1708 #[test]
1709 fn load_options_trusted_input_sets_type_maxes() {
1710 let opts = LoadOptions::trusted_input();
1711 assert_eq!(opts.max_bytes, u64::MAX);
1712 assert_eq!(opts.max_line_len, usize::MAX);
1713 assert_eq!(opts.max_sv, usize::MAX);
1714 assert_eq!(opts.max_nr_class, usize::MAX);
1715 assert_eq!(opts.max_feature_index, i32::MAX);
1716 }
1717
1718 #[test]
1719 fn problem_reader_rejects_file_over_max_bytes() {
1720 let input = b"+1 1:0.5\n+1 2:0.5\n";
1722 let opts = LoadOptions {
1723 max_bytes: 10,
1724 ..LoadOptions::default()
1725 };
1726 let err = load_problem_from_reader_with_options(&input[..], &opts).unwrap_err();
1727 assert!(
1728 format!("{}", err).contains("max_bytes"),
1729 "unexpected error: {}",
1730 err
1731 );
1732 }
1733
1734 #[test]
1735 fn problem_reader_rejects_line_over_max_line_len() {
1736 let mut payload = String::from("+1 ");
1738 for i in 1..=50 {
1739 payload.push_str(&format!("{}:0.1 ", i));
1740 }
1741 payload.push('\n');
1742 let opts = LoadOptions {
1743 max_line_len: 50,
1744 ..LoadOptions::default()
1745 };
1746 let err = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap_err();
1747 assert!(
1748 format!("{}", err).contains("max_line_len"),
1749 "unexpected error: {}",
1750 err
1751 );
1752 }
1753
1754 #[test]
1755 fn problem_reader_accepts_line_at_max_line_len() {
1756 let line_content = "+1 1:0.5";
1759 let payload = format!("{}\n", line_content);
1760 let opts = LoadOptions {
1761 max_line_len: line_content.len(),
1762 ..LoadOptions::default()
1763 };
1764 let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1765 assert_eq!(problem.labels.len(), 1);
1766 }
1767
1768 #[test]
1769 fn problem_reader_tolerates_crlf_at_cap() {
1770 let line_content = "+1 1:0.5";
1773 let payload = format!("{}\r\n", line_content);
1774 let opts = LoadOptions {
1775 max_line_len: line_content.len(),
1776 ..LoadOptions::default()
1777 };
1778 let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1779 assert_eq!(problem.labels.len(), 1);
1780 }
1781
1782 #[test]
1783 fn problem_reader_rejects_nul_byte() {
1784 let mut payload: Vec<u8> = b"+1 1:0.5".to_vec();
1787 payload.push(0);
1788 payload.extend_from_slice(b"\n");
1789 let err = load_problem_from_reader(payload.as_slice()).unwrap_err();
1790 assert!(
1791 format!("{}", err).contains("NUL byte"),
1792 "unexpected error: {}",
1793 err
1794 );
1795 }
1796
1797 #[test]
1798 fn model_reader_honors_max_nr_class_cap() {
1799 let input = b"svm_type c_svc\n\
1802kernel_type linear\n\
1803nr_class 100\n\
1804total_sv 1\n\
1805rho 0\n\
1806SV\n";
1807 let opts = LoadOptions {
1808 max_nr_class: 50,
1809 ..LoadOptions::default()
1810 };
1811 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1812 assert!(
1813 format!("{}", err).contains("nr_class exceeds limit (50)"),
1814 "unexpected error: {}",
1815 err
1816 );
1817 }
1818
1819 #[test]
1820 fn model_reader_honors_max_sv_cap() {
1821 let input = b"svm_type c_svc\n\
1822kernel_type linear\n\
1823nr_class 2\n\
1824total_sv 1000\n\
1825rho 0\n\
1826SV\n";
1827 let opts = LoadOptions {
1828 max_sv: 100,
1829 ..LoadOptions::default()
1830 };
1831 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1832 assert!(
1833 format!("{}", err).contains("total_sv exceeds limit (100)"),
1834 "unexpected error: {}",
1835 err
1836 );
1837 }
1838
1839 #[test]
1840 fn trusted_input_cannot_exceed_hard_module_caps() {
1841 let huge_nr_class = format!(
1845 "svm_type c_svc\n\
1846kernel_type linear\n\
1847nr_class {}\n\
1848total_sv 1\n\
1849rho 0\n\
1850SV\n",
1851 MAX_NR_CLASS + 1
1852 );
1853 let opts = LoadOptions::trusted_input();
1854 let err = load_model_from_reader_with_options(huge_nr_class.as_bytes(), &opts).unwrap_err();
1855 assert!(
1856 format!("{}", err).contains("nr_class exceeds limit"),
1857 "unexpected error: {}",
1858 err
1859 );
1860 }
1861
1862 #[test]
1863 fn model_reader_honors_max_feature_index_cap() {
1864 let input = b"svm_type c_svc\n\
1865kernel_type linear\n\
1866nr_class 2\n\
1867total_sv 1\n\
1868rho 0\n\
1869label 1 -1\n\
1870nr_sv 1 0\n\
1871SV\n\
18720.1 50:0.5\n";
1873 let opts = LoadOptions {
1874 max_feature_index: 10,
1875 ..LoadOptions::default()
1876 };
1877 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1878 assert!(
1879 format!("{}", err).contains("feature index 50 exceeds limit (10)"),
1880 "unexpected error: {}",
1881 err
1882 );
1883 }
1884
1885 #[test]
1886 fn sv_count_loop_counts_nonblank_lines_only() {
1887 let input = b"svm_type c_svc\n\
1893kernel_type linear\n\
1894nr_class 2\n\
1895total_sv 2\n\
1896rho 0\n\
1897label 1 -1\n\
1898nr_sv 1 1\n\
1899SV\n\
1900\n\
19010.1 1:0.5\n\
1902\n\
1903-0.1 2:0.5\n";
1904 let model = load_model_from_reader(&input[..]).unwrap();
1905 assert_eq!(model.sv.len(), 2);
1906 assert_eq!(model.sv_coef[0].len(), 2);
1907 }
1908
1909 #[test]
1910 fn save_precomputed_model_writes_zero_index() {
1911 let model = SvmModel {
1912 param: SvmParameter {
1913 svm_type: SvmType::CSvc,
1914 kernel_type: KernelType::Precomputed,
1915 ..Default::default()
1916 },
1917 nr_class: 2,
1918 sv: vec![vec![SvmNode {
1919 index: 0,
1920 value: 7.0,
1921 }]],
1922 sv_coef: vec![vec![0.25]],
1923 rho: vec![0.0],
1924 prob_a: vec![],
1925 prob_b: vec![],
1926 prob_density_marks: vec![],
1927 sv_indices: vec![],
1928 label: vec![1, -1],
1929 n_sv: vec![1, 0],
1930 };
1931
1932 let mut buf = Vec::new();
1933 save_model_to_writer(&mut buf, &model).unwrap();
1934 let out = String::from_utf8(buf).unwrap();
1935 assert!(out.contains("kernel_type precomputed"));
1936 assert!(out.contains("0:7"));
1937 }
1938
1939 #[test]
1944 fn csvc_missing_label_and_nr_sv_returns_error_not_panic() {
1945 let input = b"svm_type c_svc\n\
1946kernel_type linear\n\
1947nr_class 2\n\
1948total_sv 1\n\
1949rho 0\n\
1950SV\n\
19511 1:1\n";
1952 let err = load_model_from_reader(&input[..]).unwrap_err();
1953 let msg = format!("{}", err);
1954 assert!(
1955 msg.contains("label has 0 entries, expected nr_class (2)"),
1956 "unexpected error: {}",
1957 msg
1958 );
1959 }
1960
1961 #[test]
1963 fn csvc_label_present_nr_sv_absent_returns_error() {
1964 let input = b"svm_type c_svc\n\
1965kernel_type linear\n\
1966nr_class 2\n\
1967total_sv 1\n\
1968rho 0\n\
1969label 1 -1\n\
1970SV\n\
19711 1:1\n";
1972 let err = load_model_from_reader(&input[..]).unwrap_err();
1973 let msg = format!("{}", err);
1974 assert!(
1975 msg.contains("nr_sv has 0 entries, expected nr_class (2)"),
1976 "unexpected error: {}",
1977 msg
1978 );
1979 }
1980
1981 #[test]
1985 fn one_class_without_label_and_nr_sv_loads_ok() {
1986 let input = b"svm_type one_class\n\
1988kernel_type rbf\n\
1989gamma 0.5\n\
1990nr_class 2\n\
1991total_sv 1\n\
1992rho -0.5\n\
1993SV\n\
19941 1:0.5\n";
1995 let model = load_model_from_reader(&input[..]).unwrap();
1996 assert_eq!(model.sv.len(), 1);
1997 assert_eq!(model.label, Vec::<i32>::new());
1998 assert_eq!(model.n_sv, Vec::<usize>::new());
1999 }
2000
2001 #[test]
2005 fn model_rho_inf_returns_error() {
2006 let input = b"svm_type c_svc\n\
2007kernel_type linear\n\
2008nr_class 2\n\
2009total_sv 1\n\
2010rho inf\n\
2011label 1 -1\n\
2012nr_sv 1 0\n\
2013SV\n\
20141 1:0.5\n";
2015 let err = load_model_from_reader(&input[..]).unwrap_err();
2016 let msg = format!("{}", err);
2017 assert!(
2018 msg.contains("rho must be finite"),
2019 "unexpected error: {}",
2020 msg
2021 );
2022 }
2023
2024 #[test]
2026 fn model_sv_feature_nan_returns_error() {
2027 let input = b"svm_type c_svc\n\
2028kernel_type linear\n\
2029nr_class 2\n\
2030total_sv 1\n\
2031rho 0\n\
2032label 1 -1\n\
2033nr_sv 1 0\n\
2034SV\n\
20351 1:nan\n";
2036 let err = load_model_from_reader(&input[..]).unwrap_err();
2037 let msg = format!("{}", err);
2038 assert!(
2039 msg.contains("feature value must be finite"),
2040 "unexpected error: {}",
2041 msg
2042 );
2043 }
2044}