1use std::io::{BufRead, Write};
20use std::path::Path;
21
22use crate::error::SvmError;
23use crate::types::*;
24use crate::util::parse_feature_index;
25use crate::util::MAX_FEATURE_INDEX;
26
27use std::fmt;
37
/// Formats an `f64` the way C's `printf("%.<precision>g", value)` does.
///
/// * `value` - the number to render.
/// * `precision` - number of significant digits (`P` in the C description of
///   `%g`); `0` is treated as `1`, as C does.
struct Gfmt {
    value: f64,
    precision: usize,
}

impl Gfmt {
    /// Creates a formatter for `value` with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Removes trailing zeros of the fractional part, plus the decimal point
    /// when nothing is left after it.
    ///
    /// Strings without a `.` are returned untouched: their trailing zeros are
    /// integer digits (e.g. `"100000"`) and must be kept.
    fn trim_fraction(digits: &str) -> &str {
        if digits.contains('.') {
            digits.trim_end_matches('0').trim_end_matches('.')
        } else {
            digits
        }
    }
}

impl fmt::Display for Gfmt {
    /// Writes `self.value` using `%g` rules: scientific notation when the
    /// decimal exponent `X` of the value *rounded to `P` significant digits*
    /// satisfies `X < -4 || X >= P`, fixed notation otherwise; trailing
    /// fractional zeros are removed in both cases.
    ///
    /// Non-finite values are delegated to Rust's own `f64` formatting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        let p = self.precision.max(1);

        if !v.is_finite() {
            return write!(f, "{}", v);
        }

        if v == 0.0 {
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // Let the formatter round to P significant digits first and read the
        // exponent off the result. Deriving it from `log10` of the unrounded
        // value picks the wrong notation when rounding carries into the next
        // power of ten (e.g. 999999.7 with P = 6 must print as "1e+06").
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exponent) = match sci.split_once('e') {
            Some(pieces) => pieces,
            None => return f.write_str(&sci),
        };
        let exp: i32 = exponent.parse().unwrap_or(0);
        let p_limit = i32::try_from(p).unwrap_or(i32::MAX);

        if exp < -4 || exp >= p_limit {
            // C prints the exponent with an explicit sign and at least two digits.
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                Self::trim_fraction(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Here -4 <= exp <= P - 1, so P - 1 - exp never underflows.
            let decimal_places = if exp >= 0 {
                p - 1 - exp as usize
            } else {
                p - 1 + exp.unsigned_abs() as usize
            };
            let fixed = format!("{:.prec$}", v, prec = decimal_places);
            f.write_str(Self::trim_fraction(&fixed))
        }
    }
}
102
103fn fmt_17g(v: f64) -> Gfmt {
105 Gfmt::new(v, 17)
106}
107
108fn fmt_8g(v: f64) -> Gfmt {
110 Gfmt::new(v, 8)
111}
112
113pub fn format_g(v: f64) -> String {
115 format!("{}", Gfmt::new(v, 6))
116}
117
118pub fn format_17g(v: f64) -> String {
120 format!("{}", Gfmt::new(v, 17))
121}
122
/// Model-file spellings of the SVM types, indexed by `SvmType as usize`
/// (see `svm_type_to_str`).
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];
/// Model-file spellings of the kernel types, indexed by `KernelType as usize`
/// (see `kernel_type_to_str`).
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
127
128fn svm_type_to_str(t: SvmType) -> &'static str {
129 SVM_TYPE_TABLE[t as usize]
130}
131
132fn kernel_type_to_str(t: KernelType) -> &'static str {
133 KERNEL_TYPE_TABLE[t as usize]
134}
135
136fn str_to_svm_type(s: &str) -> Option<SvmType> {
137 match s {
138 "c_svc" => Some(SvmType::CSvc),
139 "nu_svc" => Some(SvmType::NuSvc),
140 "one_class" => Some(SvmType::OneClass),
141 "epsilon_svr" => Some(SvmType::EpsilonSvr),
142 "nu_svr" => Some(SvmType::NuSvr),
143 _ => None,
144 }
145}
146
147fn str_to_kernel_type(s: &str) -> Option<KernelType> {
148 match s {
149 "linear" => Some(KernelType::Linear),
150 "polynomial" => Some(KernelType::Polynomial),
151 "rbf" => Some(KernelType::Rbf),
152 "sigmoid" => Some(KernelType::Sigmoid),
153 "precomputed" => Some(KernelType::Precomputed),
154 _ => None,
155 }
156}
157
/// Resource limits applied while reading problem and model text.
///
/// `Default` gives bounded limits; [`LoadOptions::trusted_input`] sets every
/// field to its type's maximum.
#[derive(Debug, Clone, Copy)]
pub struct LoadOptions {
    /// Maximum total number of bytes that may be consumed from the reader.
    pub max_bytes: u64,
    /// Per-line length limit handed to the line reader.
    pub max_line_len: usize,
    /// Upper bound accepted for the `total_sv` model header value
    /// (the model loader additionally clamps it to `MAX_TOTAL_SV`).
    pub max_sv: usize,
    /// Upper bound accepted for the `nr_class` model header value
    /// (the model loader additionally clamps it to `MAX_NR_CLASS`).
    pub max_nr_class: usize,
    /// Limit passed to `parse_feature_index` for every `index:value` token.
    pub max_feature_index: i32,
}
211
212impl Default for LoadOptions {
213 fn default() -> Self {
214 Self {
215 max_bytes: 64 * 1024 * 1024,
216 max_line_len: 1024 * 1024,
217 max_sv: MAX_TOTAL_SV,
218 max_nr_class: MAX_NR_CLASS,
219 max_feature_index: MAX_FEATURE_INDEX,
220 }
221 }
222}
223
224impl LoadOptions {
225 pub fn trusted_input() -> Self {
232 Self {
233 max_bytes: u64::MAX,
234 max_line_len: usize::MAX,
235 max_sv: usize::MAX,
236 max_nr_class: usize::MAX,
237 max_feature_index: i32::MAX,
238 }
239 }
240}
241
/// Reads one line (including its trailing `\n`, if any) from `reader` while
/// enforcing size limits.
///
/// * `reader` - buffered source; bytes are consumed only once they passed all checks.
/// * `bytes_read` - running total of consumed bytes, updated in place.
/// * `max_bytes` - limit for `*bytes_read`; exceeding it is an error.
/// * `max_line_len` - limit for the line content (the `\n` is not counted);
///   one extra byte of slack is tolerated before the error triggers.
///
/// Returns `Ok(None)` at end of input when nothing was read, otherwise
/// `Ok(Some(line))`.
///
/// # Errors
///
/// `InvalidData` when a limit is exceeded, a NUL byte is found, or the line is
/// not valid UTF-8. Other I/O errors from the reader are returned as-is,
/// except `Interrupted`, which is retried like `BufRead::read_until` does.
fn read_line_capped(
    reader: &mut dyn BufRead,
    bytes_read: &mut u64,
    max_bytes: u64,
    max_line_len: usize,
) -> std::io::Result<Option<String>> {
    // NOTE(review): the +1 presumably leaves room for a '\r' before '\n' — confirm.
    let per_line_raw_cap: u64 = (max_line_len as u64).saturating_add(1);

    let mut buf: Vec<u8> = Vec::new();
    let mut found_newline = false;

    loop {
        let available = match reader.fill_buf() {
            Ok(bytes) => bytes,
            // A signal interrupting the read is not a failure of the input.
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        if available.is_empty() {
            break;
        }

        // Take up to and including the first newline of the buffered chunk.
        let take_n = match available.iter().position(|&b| b == b'\n') {
            Some(pos) => pos + 1,
            None => available.len(),
        };
        let ends_with_newline = take_n > 0 && available[take_n - 1] == b'\n';

        let new_bytes_read = bytes_read.saturating_add(take_n as u64);
        if new_bytes_read > max_bytes {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("input exceeds max_bytes limit ({})", max_bytes),
            ));
        }

        // Check the length before copying so an over-long line never grows `buf`.
        let prospective_len = (buf.len() as u64).saturating_add(take_n as u64);
        let content_bytes = if ends_with_newline {
            prospective_len - 1
        } else {
            prospective_len
        };
        if content_bytes > per_line_raw_cap {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("line length exceeds max_line_len limit ({})", max_line_len),
            ));
        }

        if available[..take_n].contains(&0) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "unexpected NUL byte in text input".to_string(),
            ));
        }

        buf.extend_from_slice(&available[..take_n]);
        reader.consume(take_n);
        *bytes_read = new_bytes_read;

        if ends_with_newline {
            found_newline = true;
            break;
        }
    }

    if buf.is_empty() && !found_newline {
        return Ok(None);
    }

    let line = String::from_utf8(buf)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
    Ok(Some(line))
}
330
331pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
361 let file = std::fs::File::open(path)?;
362 let reader = std::io::BufReader::new(file);
363 load_problem_from_reader(reader)
364}
365
366pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
371 load_problem_from_reader_with_options(reader, &LoadOptions::default())
372}
373
374pub fn load_problem_from_reader_with_options(
380 mut reader: impl BufRead,
381 options: &LoadOptions,
382) -> Result<SvmProblem, SvmError> {
383 let mut labels = Vec::new();
384 let mut instances = Vec::new();
385 let mut bytes_read: u64 = 0;
386 let mut line_idx: usize = 0;
387
388 while let Some(raw) = read_line_capped(
389 &mut reader,
390 &mut bytes_read,
391 options.max_bytes,
392 options.max_line_len,
393 )? {
394 let line_num = line_idx + 1;
395 line_idx += 1;
396 let line = raw.trim();
397 if line.is_empty() {
398 continue;
399 }
400
401 let mut parts = line.split_whitespace();
402
403 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
405 line: line_num,
406 message: "missing label".into(),
407 })?;
408 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
409 line: line_num,
410 message: format!("invalid label: {}", label_str),
411 })?;
412
413 let mut nodes = Vec::new();
415 let mut prev_index: i32 = 0;
416 for token in parts {
417 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
418 line: line_num,
419 message: format!("expected index:value, got: {}", token),
420 })?;
421 let index: i32 =
422 parse_feature_index_problem_line(line_num, idx_str, options.max_feature_index)?;
423
424 if !nodes.is_empty() && index <= prev_index {
425 return Err(SvmError::ParseError {
426 line: line_num,
427 message: format!(
428 "feature indices must be ascending: {} follows {}",
429 index, prev_index
430 ),
431 });
432 }
433 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
434 line: line_num,
435 message: format!("invalid value: {}", val_str),
436 })?;
437 prev_index = index;
438 nodes.push(SvmNode { index, value });
439 }
440
441 labels.push(label);
442 instances.push(nodes);
443 }
444
445 Ok(SvmProblem { labels, instances })
446}
447
/// Hard ceiling for the `nr_class` model header value; also the default of
/// `LoadOptions::max_nr_class`.
const MAX_NR_CLASS: usize = u16::MAX as usize;
/// Hard ceiling for the `total_sv` model header value; also the default of
/// `LoadOptions::max_sv`.
const MAX_TOTAL_SV: usize = 10 * 1_000_000;
452
453pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
455 let file = std::fs::File::create(path)?;
456 let writer = std::io::BufWriter::new(file);
457 save_model_to_writer(writer, model)
458}
459
460pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
492 let param = &model.param;
493
494 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
495 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
496
497 if param.kernel_type == KernelType::Polynomial {
498 writeln!(w, "degree {}", param.degree)?;
499 }
500 if matches!(
501 param.kernel_type,
502 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
503 ) {
504 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
505 }
506 if matches!(
507 param.kernel_type,
508 KernelType::Polynomial | KernelType::Sigmoid
509 ) {
510 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
511 }
512
513 let nr_class = model.nr_class;
514 writeln!(w, "nr_class {}", nr_class)?;
515 writeln!(w, "total_sv {}", model.sv.len())?;
516
517 write!(w, "rho")?;
519 for r in &model.rho {
520 write!(w, " {}", fmt_17g(*r))?;
521 }
522 writeln!(w)?;
523
524 if !model.label.is_empty() {
526 write!(w, "label")?;
527 for l in &model.label {
528 write!(w, " {}", l)?;
529 }
530 writeln!(w)?;
531 }
532
533 if !model.prob_a.is_empty() {
535 write!(w, "probA")?;
536 for v in &model.prob_a {
537 write!(w, " {}", fmt_17g(*v))?;
538 }
539 writeln!(w)?;
540 }
541
542 if !model.prob_b.is_empty() {
544 write!(w, "probB")?;
545 for v in &model.prob_b {
546 write!(w, " {}", fmt_17g(*v))?;
547 }
548 writeln!(w)?;
549 }
550
551 if !model.prob_density_marks.is_empty() {
553 write!(w, "prob_density_marks")?;
554 for v in &model.prob_density_marks {
555 write!(w, " {}", fmt_17g(*v))?;
556 }
557 writeln!(w)?;
558 }
559
560 if !model.n_sv.is_empty() {
562 write!(w, "nr_sv")?;
563 for n in &model.n_sv {
564 write!(w, " {}", n)?;
565 }
566 writeln!(w)?;
567 }
568
569 writeln!(w, "SV")?;
571 let num_sv = model.sv.len();
572 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
575 for j in 0..num_coef_rows {
577 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
578 }
579 if model.param.kernel_type == KernelType::Precomputed {
581 if let Some(node) = model.sv[i].first() {
582 write!(w, "0:{} ", node.value as i32)?;
583 }
584 } else {
585 for node in &model.sv[i] {
586 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
587 }
588 }
589 writeln!(w)?;
590 }
591
592 Ok(())
593}
594
595pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
628 let file = std::fs::File::open(path)?;
629 let reader = std::io::BufReader::new(file);
630 load_model_from_reader(reader)
631}
632
633pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
666 load_model_from_reader_with_options(reader, &LoadOptions::default())
667}
668
669pub fn load_model_from_reader_with_options(
675 mut reader: impl BufRead,
676 options: &LoadOptions,
677) -> Result<SvmModel, SvmError> {
678 let mut bytes_read: u64 = 0;
679
680 let nr_class_cap = options.max_nr_class.min(MAX_NR_CLASS);
686 let total_sv_cap = options.max_sv.min(MAX_TOTAL_SV);
687
688 let mut param = SvmParameter::default();
690 let mut nr_class: usize = 0;
691 let mut total_sv: usize = 0;
692 let mut rho: Vec<f64> = Vec::new();
693 let mut label: Vec<i32> = Vec::new();
694 let mut prob_a: Vec<f64> = Vec::new();
695 let mut prob_b: Vec<f64> = Vec::new();
696 let mut prob_density_marks: Vec<f64> = Vec::new();
697 let mut n_sv: Vec<usize> = Vec::new();
698
699 let mut line_num: usize = 0;
701 loop {
702 let raw = read_line_capped(
703 &mut reader,
704 &mut bytes_read,
705 options.max_bytes,
706 options.max_line_len,
707 )
708 .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
709 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))?;
710 line_num += 1;
711 let line = raw.trim().to_string();
712 if line.is_empty() {
713 continue;
714 }
715
716 let mut parts = line.split_whitespace();
717 let cmd = parts.next().ok_or_else(|| {
718 SvmError::ModelFormatError(format!("line {}: empty model header line", line_num))
719 })?;
720
721 match cmd {
722 "svm_type" => {
723 let val = parts.next().ok_or_else(|| {
724 SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
725 })?;
726 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
727 SvmError::ModelFormatError(format!(
728 "line {}: unknown svm_type: {}",
729 line_num, val
730 ))
731 })?;
732 }
733 "kernel_type" => {
734 let val = parts.next().ok_or_else(|| {
735 SvmError::ModelFormatError(format!(
736 "line {}: missing kernel_type value",
737 line_num
738 ))
739 })?;
740 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
741 SvmError::ModelFormatError(format!(
742 "line {}: unknown kernel_type: {}",
743 line_num, val
744 ))
745 })?;
746 }
747 "degree" => {
748 let d: i32 = parse_single(&mut parts, line_num, "degree")?;
749 if d < 0 {
750 return Err(SvmError::ModelFormatError(format!(
751 "line {}: degree must be >= 0, got {}",
752 line_num, d
753 )));
754 }
755 param.degree = d;
756 }
757 "gamma" => {
758 let v: f64 = parse_single(&mut parts, line_num, "gamma")?;
759 if !v.is_finite() {
760 return Err(SvmError::ModelFormatError(format!(
761 "line {}: gamma must be finite, got {}",
762 line_num, v
763 )));
764 }
765 param.gamma = v;
766 }
767 "coef0" => {
768 let v: f64 = parse_single(&mut parts, line_num, "coef0")?;
769 if !v.is_finite() {
770 return Err(SvmError::ModelFormatError(format!(
771 "line {}: coef0 must be finite, got {}",
772 line_num, v
773 )));
774 }
775 param.coef0 = v;
776 }
777 "nr_class" => {
778 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
779 if nr_class > nr_class_cap {
780 return Err(SvmError::ModelFormatError(format!(
781 "line {}: nr_class exceeds limit ({})",
782 line_num, nr_class_cap
783 )));
784 }
785 }
786 "total_sv" => {
787 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
788 if total_sv > total_sv_cap {
789 return Err(SvmError::ModelFormatError(format!(
790 "line {}: total_sv exceeds limit ({})",
791 line_num, total_sv_cap
792 )));
793 }
794 }
795 "rho" => {
796 rho = parse_multiple(&mut parts, line_num, "rho")?;
797 for &r in &rho {
798 if !r.is_finite() {
799 return Err(SvmError::ModelFormatError(format!(
800 "line {}: rho must be finite, got {}",
801 line_num, r
802 )));
803 }
804 }
805 }
806 "label" => {
807 label = parse_multiple(&mut parts, line_num, "label")?;
808 }
809 "probA" => {
810 prob_a = parse_multiple(&mut parts, line_num, "probA")?;
811 }
812 "probB" => {
813 prob_b = parse_multiple(&mut parts, line_num, "probB")?;
814 }
815 "prob_density_marks" => {
816 prob_density_marks = parse_multiple(&mut parts, line_num, "prob_density_marks")?;
817 }
818 "nr_sv" => {
819 n_sv = parts
820 .map(|s| {
821 s.parse::<usize>().map_err(|_| {
822 SvmError::ModelFormatError(format!(
823 "line {}: invalid nr_sv value: {}",
824 line_num, s
825 ))
826 })
827 })
828 .collect::<Result<Vec<_>, _>>()?;
829 }
830 "SV" => break,
831 _ => {
832 return Err(SvmError::ModelFormatError(format!(
833 "line {}: unknown keyword: {}",
834 line_num, cmd
835 )));
836 }
837 }
838 }
839
840 validate_model_header(
846 param.svm_type,
847 nr_class,
848 total_sv,
849 &rho,
850 &label,
851 &prob_a,
852 &prob_b,
853 &prob_density_marks,
854 &n_sv,
855 )?;
856
857 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
865 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::new()).collect();
866 let mut sv: Vec<Vec<SvmNode>> = Vec::new();
867
868 while sv.len() < total_sv {
869 let raw = read_line_capped(
870 &mut reader,
871 &mut bytes_read,
872 options.max_bytes,
873 options.max_line_len,
874 )
875 .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
876 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))?;
877 line_num += 1;
878 let line = raw.trim();
879 if line.is_empty() {
880 continue;
883 }
884
885 let mut parts = line.split_whitespace();
886
887 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
889 let val_str = parts.next().ok_or_else(|| {
890 SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
891 })?;
892 let val: f64 = val_str.parse().map_err(|_| {
893 SvmError::ModelFormatError(format!(
894 "line {}: invalid sv_coef: {}",
895 line_num, val_str
896 ))
897 })?;
898 if !val.is_finite() {
899 return Err(SvmError::ModelFormatError(format!(
900 "line {}: sv_coef must be finite, got {}",
901 line_num, val_str
902 )));
903 }
904 coef_row.push(val);
905 }
906
907 let mut nodes = Vec::new();
910 let mut prev_index: i32 = 0;
911 for token in parts {
912 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
913 SvmError::ModelFormatError(format!(
914 "line {}: expected index:value, got: {}",
915 line_num, token
916 ))
917 })?;
918 let index: i32 =
919 parse_feature_index_model_line(line_num, idx_str, options.max_feature_index)?;
920
921 if !nodes.is_empty() && index <= prev_index {
922 return Err(SvmError::ModelFormatError(format!(
923 "line {}: feature indices must be ascending: {} follows {}",
924 line_num, index, prev_index
925 )));
926 }
927
928 let value: f64 = val_str.parse().map_err(|_| {
929 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
930 })?;
931 if !value.is_finite() {
932 return Err(SvmError::ModelFormatError(format!(
933 "line {}: feature value must be finite, got {}",
934 line_num, val_str
935 )));
936 }
937 prev_index = index;
938 nodes.push(SvmNode { index, value });
939 }
940
941 if param.kernel_type == KernelType::Precomputed {
942 validate_precomputed_row(&nodes, line_num, "support vector")?;
943 }
944 sv.push(nodes);
945 }
946
947 let model = SvmModel {
948 param,
949 nr_class,
950 sv,
951 sv_coef,
952 rho,
953 prob_a,
954 prob_b,
955 prob_density_marks,
956 sv_indices: Vec::new(), label,
958 n_sv,
959 };
960 validate_model(&model)?;
961 Ok(model)
962}
963
964pub(crate) fn validate_model(model: &SvmModel) -> Result<(), SvmError> {
966 if model.param.degree < 0 {
967 return Err(SvmError::ModelFormatError(format!(
968 "degree must be >= 0, got {}",
969 model.param.degree
970 )));
971 }
972 if !model.param.gamma.is_finite() {
973 return Err(SvmError::ModelFormatError(format!(
974 "gamma must be finite, got {}",
975 model.param.gamma
976 )));
977 }
978 if matches!(
979 model.param.kernel_type,
980 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
981 ) && model.param.gamma < 0.0
982 {
983 return Err(SvmError::ModelFormatError(format!(
984 "gamma must be >= 0, got {}",
985 model.param.gamma
986 )));
987 }
988 if !model.param.coef0.is_finite() {
989 return Err(SvmError::ModelFormatError(format!(
990 "coef0 must be finite, got {}",
991 model.param.coef0
992 )));
993 }
994
995 validate_model_header(
996 model.param.svm_type,
997 model.nr_class,
998 model.sv.len(),
999 &model.rho,
1000 &model.label,
1001 &model.prob_a,
1002 &model.prob_b,
1003 &model.prob_density_marks,
1004 &model.n_sv,
1005 )?;
1006
1007 let expected_rows = model.nr_class.saturating_sub(1).max(1);
1008 if model.sv_coef.len() != expected_rows {
1009 return Err(SvmError::ModelFormatError(format!(
1010 "sv_coef has {} rows, expected {}",
1011 model.sv_coef.len(),
1012 expected_rows
1013 )));
1014 }
1015 for (row_idx, row) in model.sv_coef.iter().enumerate() {
1016 if row.len() != model.sv.len() {
1017 return Err(SvmError::ModelFormatError(format!(
1018 "sv_coef row {} has {} entries, expected {}",
1019 row_idx,
1020 row.len(),
1021 model.sv.len()
1022 )));
1023 }
1024 for &coef in row {
1025 if !coef.is_finite() {
1026 return Err(SvmError::ModelFormatError(format!(
1027 "sv_coef must be finite, got {}",
1028 coef
1029 )));
1030 }
1031 }
1032 }
1033
1034 for &r in &model.rho {
1035 if !r.is_finite() {
1036 return Err(SvmError::ModelFormatError(format!(
1037 "rho must be finite, got {}",
1038 r
1039 )));
1040 }
1041 }
1042 for (name, values) in [("probA", &model.prob_a), ("probB", &model.prob_b)] {
1043 for &value in values {
1044 if !value.is_finite() {
1045 return Err(SvmError::ModelFormatError(format!(
1046 "{} must be finite, got {}",
1047 name, value
1048 )));
1049 }
1050 }
1051 }
1052 for &value in &model.prob_density_marks {
1053 if !value.is_finite() {
1054 return Err(SvmError::ModelFormatError(format!(
1055 "prob_density_marks must be finite, got {}",
1056 value
1057 )));
1058 }
1059 }
1060
1061 for (row_idx, nodes) in model.sv.iter().enumerate() {
1062 let mut prev_index = 0;
1063 for (node_idx, node) in nodes.iter().enumerate() {
1064 if node_idx > 0 && node.index <= prev_index {
1065 return Err(SvmError::ModelFormatError(format!(
1066 "support vector {} feature indices must be ascending: {} follows {}",
1067 row_idx + 1,
1068 node.index,
1069 prev_index
1070 )));
1071 }
1072 if !node.value.is_finite() {
1073 return Err(SvmError::ModelFormatError(format!(
1074 "feature value must be finite, got {}",
1075 node.value
1076 )));
1077 }
1078 prev_index = node.index;
1079 }
1080 if model.param.kernel_type == KernelType::Precomputed {
1081 validate_precomputed_row(nodes, row_idx + 1, "support vector")?;
1082 }
1083 }
1084
1085 Ok(())
1086}
1087
1088fn validate_precomputed_row(
1089 nodes: &[SvmNode],
1090 line_num: usize,
1091 context: &str,
1092) -> Result<(), SvmError> {
1093 let first = nodes.first().ok_or_else(|| {
1094 SvmError::ModelFormatError(format!(
1095 "line {}: precomputed kernel {} is missing 0:sample_serial_number",
1096 line_num, context
1097 ))
1098 })?;
1099
1100 if first.index != 0
1101 || !first.value.is_finite()
1102 || first.value < 1.0
1103 || first.value.fract() != 0.0
1104 {
1105 return Err(SvmError::ModelFormatError(format!(
1106 "line {}: precomputed kernel {} must start with 0:sample_serial_number",
1107 line_num, context
1108 )));
1109 }
1110
1111 Ok(())
1112}
1113
1114#[allow(clippy::too_many_arguments)]
1136fn validate_model_header(
1137 svm_type: SvmType,
1138 nr_class: usize,
1139 total_sv: usize,
1140 rho: &[f64],
1141 label: &[i32],
1142 prob_a: &[f64],
1143 prob_b: &[f64],
1144 prob_density_marks: &[f64],
1145 n_sv: &[usize],
1146) -> Result<(), SvmError> {
1147 let is_classification = matches!(svm_type, SvmType::CSvc | SvmType::NuSvc);
1148 let is_regression = matches!(svm_type, SvmType::EpsilonSvr | SvmType::NuSvr);
1149 let is_one_class = matches!(svm_type, SvmType::OneClass);
1150
1151 if nr_class < 2 {
1156 return Err(SvmError::ModelFormatError(format!(
1157 "nr_class must be >= 2, got {}",
1158 nr_class
1159 )));
1160 }
1161
1162 let expected_rho = if is_classification {
1164 nr_class * (nr_class - 1) / 2
1165 } else {
1166 1
1167 };
1168 if rho.len() != expected_rho {
1169 return Err(SvmError::ModelFormatError(format!(
1170 "rho has {} entries, expected {} for svm_type {}",
1171 rho.len(),
1172 expected_rho,
1173 svm_type_to_str(svm_type)
1174 )));
1175 }
1176
1177 if is_classification {
1180 if label.len() != nr_class {
1184 return Err(SvmError::ModelFormatError(format!(
1185 "label has {} entries, expected nr_class ({}) for svm_type {}",
1186 label.len(),
1187 nr_class,
1188 svm_type_to_str(svm_type)
1189 )));
1190 }
1191 } else if !label.is_empty() {
1192 return Err(SvmError::ModelFormatError(format!(
1193 "label is only valid for classification, got {} entries on svm_type {}",
1194 label.len(),
1195 svm_type_to_str(svm_type)
1196 )));
1197 }
1198
1199 if is_classification {
1203 if n_sv.len() != nr_class {
1204 return Err(SvmError::ModelFormatError(format!(
1205 "nr_sv has {} entries, expected nr_class ({}) for svm_type {}",
1206 n_sv.len(),
1207 nr_class,
1208 svm_type_to_str(svm_type)
1209 )));
1210 }
1211 let mut sum: usize = 0;
1215 for &n in n_sv {
1216 sum = sum.checked_add(n).ok_or_else(|| {
1217 SvmError::ModelFormatError("nr_sv entries overflow usize when summed".into())
1218 })?;
1219 }
1220 if sum != total_sv {
1221 return Err(SvmError::ModelFormatError(format!(
1222 "sum of nr_sv entries ({}) does not match total_sv ({})",
1223 sum, total_sv
1224 )));
1225 }
1226 } else if !n_sv.is_empty() {
1227 return Err(SvmError::ModelFormatError(format!(
1228 "nr_sv is only valid for classification, got {} entries on svm_type {}",
1229 n_sv.len(),
1230 svm_type_to_str(svm_type)
1231 )));
1232 }
1233
1234 if !prob_a.is_empty() && prob_a.len() != expected_rho {
1236 return Err(SvmError::ModelFormatError(format!(
1237 "probA has {} entries, expected {}",
1238 prob_a.len(),
1239 expected_rho
1240 )));
1241 }
1242 if !prob_b.is_empty() && prob_b.len() != expected_rho {
1243 return Err(SvmError::ModelFormatError(format!(
1244 "probB has {} entries, expected {}",
1245 prob_b.len(),
1246 expected_rho
1247 )));
1248 }
1249
1250 if !prob_density_marks.is_empty() && !is_one_class {
1252 return Err(SvmError::ModelFormatError(format!(
1253 "prob_density_marks is only valid for one-class SVM, got {} entries on svm_type {}",
1254 prob_density_marks.len(),
1255 svm_type_to_str(svm_type)
1256 )));
1257 }
1258
1259 let _ = is_regression;
1263
1264 Ok(())
1265}
1266
1267fn parse_feature_index_problem_line(
1270 line_num: usize,
1271 idx_str: &str,
1272 max_feature_index: i32,
1273) -> Result<i32, SvmError> {
1274 parse_feature_index(idx_str, max_feature_index).map_err(|msg| SvmError::ParseError {
1275 line: line_num,
1276 message: msg,
1277 })
1278}
1279
1280fn parse_feature_index_model_line(
1281 line_num: usize,
1282 idx_str: &str,
1283 max_feature_index: i32,
1284) -> Result<i32, SvmError> {
1285 parse_feature_index(idx_str, max_feature_index)
1286 .map_err(|msg| SvmError::ModelFormatError(format!("line {}: {}", line_num, msg)))
1287}
1288
1289fn parse_single<T: std::str::FromStr>(
1290 parts: &mut std::str::SplitWhitespace<'_>,
1291 line_num: usize,
1292 field: &str,
1293) -> Result<T, SvmError> {
1294 let val_str = parts.next().ok_or_else(|| {
1295 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
1296 })?;
1297 val_str.parse().map_err(|_| {
1298 SvmError::ModelFormatError(format!(
1299 "line {}: invalid {} value: {}",
1300 line_num, field, val_str
1301 ))
1302 })
1303}
1304
1305fn parse_multiple<T: std::str::FromStr>(
1306 parts: &mut std::str::SplitWhitespace<'_>,
1307 line_num: usize,
1308 field: &str,
1309) -> Result<Vec<T>, SvmError> {
1310 parts
1311 .map(|s| {
1312 s.parse::<T>().map_err(|_| {
1313 SvmError::ModelFormatError(format!(
1314 "line {}: invalid {} value: {}",
1315 line_num, field, s
1316 ))
1317 })
1318 })
1319 .collect()
1320}
1321
1322#[cfg(test)]
1325mod tests {
1326 use super::*;
1327 use std::path::PathBuf;
1328
1329 fn data_dir() -> PathBuf {
1330 PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1331 .join("..")
1332 .join("..")
1333 .join("data")
1334 }
1335
1336 #[test]
1337 fn parse_heart_scale() {
1338 let path = data_dir().join("heart_scale");
1339 let problem = load_problem(&path).unwrap();
1340 assert_eq!(problem.labels.len(), 270);
1341 assert_eq!(problem.instances.len(), 270);
1342 assert_eq!(problem.labels[0], 1.0);
1344 assert_eq!(
1345 problem.instances[0][0],
1346 SvmNode {
1347 index: 1,
1348 value: 0.708333
1349 }
1350 );
1351 assert_eq!(problem.instances[0].len(), 12);
1352 }
1353
1354 #[test]
1355 fn parse_iris() {
1356 let path = data_dir().join("iris.scale");
1357 let problem = load_problem(&path).unwrap();
1358 assert_eq!(problem.labels.len(), 150);
1359 let classes: std::collections::HashSet<i64> =
1361 problem.labels.iter().map(|&l| l as i64).collect();
1362 assert_eq!(classes.len(), 3);
1363 }
1364
1365 #[test]
1366 fn parse_housing() {
1367 let path = data_dir().join("housing_scale");
1368 let problem = load_problem(&path).unwrap();
1369 assert_eq!(problem.labels.len(), 506);
1370 assert!((problem.labels[0] - 24.0).abs() < 1e-10);
1372 }
1373
1374 #[test]
1375 fn parse_empty_lines() {
1376 let input = b"+1 1:0.5\n\n-1 2:0.3\n";
1377 let problem = load_problem_from_reader(&input[..]).unwrap();
1378 assert_eq!(problem.labels.len(), 2);
1379 }
1380
1381 #[test]
1382 fn parse_error_unsorted_indices() {
1383 let input = b"+1 3:0.5 1:0.3\n";
1384 let result = load_problem_from_reader(&input[..]);
1385 assert!(result.is_err());
1386 let msg = format!("{}", result.unwrap_err());
1387 assert!(msg.contains("ascending"), "error: {}", msg);
1388 }
1389
1390 #[test]
1391 fn parse_error_duplicate_indices() {
1392 let input = b"+1 1:0.5 1:0.3\n";
1393 let result = load_problem_from_reader(&input[..]);
1394 assert!(result.is_err());
1395 }
1396
1397 #[test]
1398 fn parse_error_missing_colon() {
1399 let input = b"+1 1:0.5 bad_token\n";
1400 let result = load_problem_from_reader(&input[..]);
1401 assert!(result.is_err());
1402 }
1403
1404 #[test]
1405 #[allow(clippy::excessive_precision)]
1406 fn load_c_trained_model() {
1407 let path = data_dir().join("heart_scale.model");
1409 let model = load_model(&path).unwrap();
1410 assert_eq!(model.nr_class, 2);
1411 assert_eq!(model.param.svm_type, SvmType::CSvc);
1412 assert_eq!(model.param.kernel_type, KernelType::Rbf);
1413 assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
1414 assert_eq!(model.sv.len(), 132);
1415 assert_eq!(model.label, vec![1, -1]);
1416 assert_eq!(model.n_sv, vec![64, 68]);
1417 assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
1418 assert_eq!(model.sv_coef.len(), 1);
1420 assert_eq!(model.sv_coef[0].len(), 132);
1421 }
1422
1423 #[test]
1424 fn roundtrip_c_model() {
1425 let path = data_dir().join("heart_scale.model");
1427 let original_bytes = std::fs::read_to_string(&path).unwrap();
1428 let model = load_model(&path).unwrap();
1429
1430 let mut buf = Vec::new();
1431 save_model_to_writer(&mut buf, &model).unwrap();
1432 let rust_output = String::from_utf8(buf).unwrap();
1433
1434 let orig_lines: Vec<&str> = original_bytes.lines().collect();
1436 let rust_lines: Vec<&str> = rust_output.lines().collect();
1437 assert_eq!(
1438 orig_lines.len(),
1439 rust_lines.len(),
1440 "line count mismatch: C={} Rust={}",
1441 orig_lines.len(),
1442 rust_lines.len()
1443 );
1444 for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
1445 assert_eq!(
1446 o,
1447 r,
1448 "line {} differs:\n C: {:?}\n Rust: {:?}",
1449 i + 1,
1450 o,
1451 r
1452 );
1453 }
1454 }
1455
1456 #[test]
1457 #[allow(clippy::excessive_precision)]
1458 fn gfmt_matches_c_printf() {
1459 let cases: &[(f64, &str, &str)] = &[
1461 (0.5, "0.5", "0.5"),
1462 (-1.0, "-1", "-1"),
1463 (0.123456789012345, "0.123456789012345", "0.12345679"),
1464 (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
1465 (0.42446200000000001, "0.42446200000000001", "0.424462"),
1466 (0.0, "0", "0"),
1467 (1e-5, "1.0000000000000001e-05", "1e-05"),
1468 (1e-4, "0.0001", "0.0001"),
1469 (1e20, "1e+20", "1e+20"),
1470 (-0.25, "-0.25", "-0.25"),
1471 (0.75, "0.75", "0.75"),
1472 (0.708333, "0.70833299999999999", "0.708333"),
1473 (1.0, "1", "1"),
1474 ];
1475 for &(v, expected_17g, expected_8g) in cases {
1476 let got_17 = format!("{}", fmt_17g(v));
1477 let got_8 = format!("{}", fmt_8g(v));
1478 assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
1479 assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
1480 }
1481 }
1482
1483 #[test]
1484 #[allow(clippy::excessive_precision)]
1485 fn model_roundtrip() {
1486 let model = SvmModel {
1488 param: SvmParameter {
1489 svm_type: SvmType::CSvc,
1490 kernel_type: KernelType::Rbf,
1491 gamma: 0.5,
1492 ..Default::default()
1493 },
1494 nr_class: 2,
1495 sv: vec![
1496 vec![
1497 SvmNode {
1498 index: 1,
1499 value: 0.5,
1500 },
1501 SvmNode {
1502 index: 3,
1503 value: -1.0,
1504 },
1505 ],
1506 vec![
1507 SvmNode {
1508 index: 1,
1509 value: -0.25,
1510 },
1511 SvmNode {
1512 index: 2,
1513 value: 0.75,
1514 },
1515 ],
1516 ],
1517 sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
1518 rho: vec![0.42446200000000001],
1519 prob_a: vec![],
1520 prob_b: vec![],
1521 prob_density_marks: vec![],
1522 sv_indices: vec![],
1523 label: vec![1, -1],
1524 n_sv: vec![1, 1],
1525 };
1526
1527 let mut buf = Vec::new();
1528 save_model_to_writer(&mut buf, &model).unwrap();
1529
1530 let loaded = load_model_from_reader(&buf[..]).unwrap();
1531
1532 assert_eq!(loaded.nr_class, model.nr_class);
1533 assert_eq!(loaded.param.svm_type, model.param.svm_type);
1534 assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
1535 assert_eq!(loaded.sv.len(), model.sv.len());
1536 assert_eq!(loaded.label, model.label);
1537 assert_eq!(loaded.n_sv, model.n_sv);
1538 assert_eq!(loaded.rho.len(), model.rho.len());
1539 for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
1541 assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
1542 }
1543 for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
1545 for (a, b) in row_a.iter().zip(row_b.iter()) {
1546 assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
1547 }
1548 }
1549 }
1550
1551 #[test]
1552 fn parse_error_excessive_counts() {
1553 let input =
1554 b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
1555 let result = load_model_from_reader(&input[..]);
1556 assert!(result.is_err());
1557 assert!(format!("{}", result.unwrap_err()).contains("nr_class exceeds limit"));
1558
1559 let input =
1560 b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
1561 let result = load_model_from_reader(&input[..]);
1562 assert!(result.is_err());
1563 assert!(format!("{}", result.unwrap_err()).contains("total_sv exceeds limit"));
1564 }
1565
1566 #[test]
1567 fn parse_error_excessive_feature_index() {
1568 let input = b"1 10000001:1\n";
1570 let result = load_problem_from_reader(&input[..]);
1571 assert!(result.is_err());
1572 assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1573
1574 let input = b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nlabel 1 -1\nnr_sv 1 0\nSV\n0.1 10000001:1\n";
1576 let result = load_model_from_reader(&input[..]);
1577 assert!(result.is_err());
1578 assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1579 }
1580
1581 #[test]
1582 fn parse_error_unknown_model_keyword() {
1583 let input = b"bad_key value\n";
1584 let result = load_model_from_reader(&input[..]);
1585 assert!(result.is_err());
1586 assert!(format!("{}", result.unwrap_err()).contains("unknown keyword"));
1587 }
1588
1589 #[test]
1590 fn parse_error_missing_or_unknown_model_values() {
1591 let missing = b"svm_type\n";
1592 let err = load_model_from_reader(&missing[..]).unwrap_err();
1593 assert!(format!("{}", err).contains("missing svm_type value"));
1594
1595 let unknown = b"svm_type unknown_type\n";
1596 let err = load_model_from_reader(&unknown[..]).unwrap_err();
1597 assert!(format!("{}", err).contains("unknown svm_type"));
1598 }
1599
1600 #[test]
1601 fn parse_error_invalid_nr_sv_entry() {
1602 let input = b"svm_type c_svc\n\
1603kernel_type linear\n\
1604nr_class 2\n\
1605total_sv 1\n\
1606rho 0\n\
1607label 1 -1\n\
1608nr_sv a 1\n\
1609SV\n\
16100.1 1:0.5\n";
1611 let err = load_model_from_reader(&input[..]).unwrap_err();
1612 assert!(format!("{}", err).contains("invalid nr_sv value"));
1613 }
1614
1615 #[test]
1616 fn parse_error_in_sv_section_tokens() {
1617 let missing_coef = b"svm_type c_svc\n\
1618kernel_type linear\n\
1619nr_class 2\n\
1620total_sv 1\n\
1621rho 0\n\
1622label 1 -1\n\
1623nr_sv 1 0\n\
1624SV\n\
16251:0.5\n";
1626 let err = load_model_from_reader(&missing_coef[..]).unwrap_err();
1627 assert!(format!("{}", err).contains("invalid sv_coef"));
1628
1629 let bad_feature = b"svm_type c_svc\n\
1630kernel_type linear\n\
1631nr_class 2\n\
1632total_sv 1\n\
1633rho 0\n\
1634label 1 -1\n\
1635nr_sv 1 0\n\
1636SV\n\
16370.1 bad\n";
1638 let err = load_model_from_reader(&bad_feature[..]).unwrap_err();
1639 assert!(format!("{}", err).contains("expected index:value"));
1640 }
1641
1642 #[test]
1643 fn parse_error_unexpected_eof_in_header_and_sv_section() {
1644 let eof_header = b"svm_type c_svc\n";
1645 let err = load_model_from_reader(&eof_header[..]).unwrap_err();
1646 assert!(format!("{}", err).contains("unexpected end of file in header"));
1647
1648 let eof_sv = b"svm_type c_svc\n\
1649kernel_type linear\n\
1650nr_class 2\n\
1651total_sv 2\n\
1652rho 0\n\
1653label 1 -1\n\
1654nr_sv 1 1\n\
1655SV\n\
16560.1 1:0.5\n";
1657 let err = load_model_from_reader(&eof_sv[..]).unwrap_err();
1658 assert!(format!("{}", err).contains("unexpected end of file in SV section"));
1659 }
1660
1661 #[test]
1662 fn reject_rho_length_mismatch_for_classification() {
1663 let input = b"svm_type c_svc\n\
1666kernel_type linear\n\
1667nr_class 3\n\
1668total_sv 3\n\
1669rho 0\n\
1670label 1 -1 0\n\
1671nr_sv 1 1 1\n\
1672SV\n";
1673 let err = load_model_from_reader(&input[..]).unwrap_err();
1674 assert!(
1675 format!("{}", err).contains("rho has 1 entries, expected 3"),
1676 "unexpected error: {}",
1677 err
1678 );
1679 }
1680
1681 #[test]
1682 fn reject_rho_length_mismatch_for_regression() {
1683 let input = b"svm_type epsilon_svr\n\
1685kernel_type linear\n\
1686nr_class 2\n\
1687total_sv 0\n\
1688rho 0 1\n\
1689SV\n";
1690 let err = load_model_from_reader(&input[..]).unwrap_err();
1691 assert!(
1692 format!("{}", err).contains("rho has 2 entries, expected 1"),
1693 "unexpected error: {}",
1694 err
1695 );
1696 }
1697
1698 #[test]
1699 fn reject_label_on_regression() {
1700 let input = b"svm_type epsilon_svr\n\
1702kernel_type linear\n\
1703nr_class 2\n\
1704total_sv 0\n\
1705rho 0\n\
1706label 1 -1\n\
1707SV\n";
1708 let err = load_model_from_reader(&input[..]).unwrap_err();
1709 assert!(
1710 format!("{}", err).contains("label is only valid for classification"),
1711 "unexpected error: {}",
1712 err
1713 );
1714 }
1715
1716 #[test]
1717 fn reject_label_length_mismatch() {
1718 let input = b"svm_type c_svc\n\
1719kernel_type linear\n\
1720nr_class 3\n\
1721total_sv 0\n\
1722rho 0 0 0\n\
1723label 1 -1\n\
1724SV\n";
1725 let err = load_model_from_reader(&input[..]).unwrap_err();
1726 assert!(
1727 format!("{}", err).contains("label has 2 entries, expected nr_class (3)"),
1728 "unexpected error: {}",
1729 err
1730 );
1731 }
1732
1733 #[test]
1734 fn reject_nr_sv_sum_mismatch() {
1735 let input = b"svm_type c_svc\n\
1738kernel_type linear\n\
1739nr_class 2\n\
1740total_sv 5\n\
1741rho 0\n\
1742label 1 -1\n\
1743nr_sv 1 2\n\
1744SV\n";
1745 let err = load_model_from_reader(&input[..]).unwrap_err();
1746 assert!(
1747 format!("{}", err).contains("sum of nr_sv entries (3) does not match total_sv (5)"),
1748 "unexpected error: {}",
1749 err
1750 );
1751 }
1752
1753 #[test]
1754 fn reject_nr_sv_length_mismatch() {
1755 let input = b"svm_type c_svc\n\
1756kernel_type linear\n\
1757nr_class 3\n\
1758total_sv 3\n\
1759rho 0 0 0\n\
1760label 1 -1 0\n\
1761nr_sv 1 2\n\
1762SV\n";
1763 let err = load_model_from_reader(&input[..]).unwrap_err();
1764 assert!(
1765 format!("{}", err).contains("nr_sv has 2 entries, expected nr_class (3)"),
1766 "unexpected error: {}",
1767 err
1768 );
1769 }
1770
1771 #[test]
1772 fn reject_proba_length_mismatch() {
1773 let input = b"svm_type c_svc\n\
1774kernel_type linear\n\
1775nr_class 3\n\
1776total_sv 0\n\
1777rho 0 0 0\n\
1778label 1 -1 0\n\
1779nr_sv 0 0 0\n\
1780probA 0.1 0.2\n\
1781SV\n";
1782 let err = load_model_from_reader(&input[..]).unwrap_err();
1783 assert!(
1784 format!("{}", err).contains("probA has 2 entries, expected 3"),
1785 "unexpected error: {}",
1786 err
1787 );
1788 }
1789
1790 #[test]
1791 fn reject_prob_density_marks_on_csvc() {
1792 let input = b"svm_type c_svc\n\
1793kernel_type linear\n\
1794nr_class 2\n\
1795total_sv 0\n\
1796rho 0\n\
1797label 1 -1\n\
1798nr_sv 0 0\n\
1799prob_density_marks 0.1 0.2\n\
1800SV\n";
1801 let err = load_model_from_reader(&input[..]).unwrap_err();
1802 assert!(
1803 format!("{}", err).contains("prob_density_marks is only valid for one-class SVM"),
1804 "unexpected error: {}",
1805 err
1806 );
1807 }
1808
1809 #[test]
1810 fn reject_nr_class_below_two() {
1811 let input = b"svm_type c_svc\n\
1812kernel_type linear\n\
1813nr_class 1\n\
1814total_sv 0\n\
1815rho\n\
1816SV\n";
1817 let err = load_model_from_reader(&input[..]).unwrap_err();
1818 assert!(
1819 format!("{}", err).contains("nr_class must be >= 2, got 1"),
1820 "unexpected error: {}",
1821 err
1822 );
1823 }
1824
1825 #[test]
1826 fn reject_sv_feature_indices_not_ascending() {
1827 let input = b"svm_type c_svc\n\
1828kernel_type linear\n\
1829nr_class 2\n\
1830total_sv 1\n\
1831rho 0\n\
1832label 1 -1\n\
1833nr_sv 1 0\n\
1834SV\n\
18350.1 3:0.5 1:0.3\n";
1836 let err = load_model_from_reader(&input[..]).unwrap_err();
1837 assert!(
1838 format!("{}", err).contains("feature indices must be ascending"),
1839 "unexpected error: {}",
1840 err
1841 );
1842 }
1843
1844 #[test]
1845 fn reject_precomputed_model_sv_without_sample_serial_number() {
1846 let input = b"svm_type c_svc\n\
1847kernel_type precomputed\n\
1848nr_class 2\n\
1849total_sv 1\n\
1850rho 0\n\
1851label 1 -1\n\
1852nr_sv 1 0\n\
1853SV\n\
18540.1\n";
1855 let err = load_model_from_reader(&input[..]).unwrap_err();
1856 assert!(
1857 format!("{}", err).contains("missing 0:sample_serial_number"),
1858 "unexpected error: {}",
1859 err
1860 );
1861 }
1862
1863 #[test]
1864 fn reject_sv_feature_index_duplicated() {
1865 let input = b"svm_type c_svc\n\
1866kernel_type linear\n\
1867nr_class 2\n\
1868total_sv 1\n\
1869rho 0\n\
1870label 1 -1\n\
1871nr_sv 1 0\n\
1872SV\n\
18730.1 1:0.5 1:0.3\n";
1874 let err = load_model_from_reader(&input[..]).unwrap_err();
1875 assert!(
1876 format!("{}", err).contains("feature indices must be ascending"),
1877 "unexpected error: {}",
1878 err
1879 );
1880 }
1881
1882 #[test]
1883 fn load_options_default_caps_match_documented_values() {
1884 let opts = LoadOptions::default();
1885 assert_eq!(opts.max_bytes, 64 * 1024 * 1024);
1886 assert_eq!(opts.max_line_len, 1024 * 1024);
1887 assert_eq!(opts.max_sv, MAX_TOTAL_SV);
1888 assert_eq!(opts.max_nr_class, MAX_NR_CLASS);
1889 assert_eq!(opts.max_feature_index, MAX_FEATURE_INDEX);
1890 }
1891
1892 #[test]
1893 fn load_options_trusted_input_sets_type_maxes() {
1894 let opts = LoadOptions::trusted_input();
1895 assert_eq!(opts.max_bytes, u64::MAX);
1896 assert_eq!(opts.max_line_len, usize::MAX);
1897 assert_eq!(opts.max_sv, usize::MAX);
1898 assert_eq!(opts.max_nr_class, usize::MAX);
1899 assert_eq!(opts.max_feature_index, i32::MAX);
1900 }
1901
1902 #[test]
1903 fn problem_reader_rejects_file_over_max_bytes() {
1904 let input = b"+1 1:0.5\n+1 2:0.5\n";
1906 let opts = LoadOptions {
1907 max_bytes: 10,
1908 ..LoadOptions::default()
1909 };
1910 let err = load_problem_from_reader_with_options(&input[..], &opts).unwrap_err();
1911 assert!(
1912 format!("{}", err).contains("max_bytes"),
1913 "unexpected error: {}",
1914 err
1915 );
1916 }
1917
1918 #[test]
1919 fn problem_reader_rejects_line_over_max_line_len() {
1920 let mut payload = String::from("+1 ");
1922 for i in 1..=50 {
1923 payload.push_str(&format!("{}:0.1 ", i));
1924 }
1925 payload.push('\n');
1926 let opts = LoadOptions {
1927 max_line_len: 50,
1928 ..LoadOptions::default()
1929 };
1930 let err = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap_err();
1931 assert!(
1932 format!("{}", err).contains("max_line_len"),
1933 "unexpected error: {}",
1934 err
1935 );
1936 }
1937
1938 #[test]
1939 fn problem_reader_accepts_line_at_max_line_len() {
1940 let line_content = "+1 1:0.5";
1943 let payload = format!("{}\n", line_content);
1944 let opts = LoadOptions {
1945 max_line_len: line_content.len(),
1946 ..LoadOptions::default()
1947 };
1948 let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1949 assert_eq!(problem.labels.len(), 1);
1950 }
1951
1952 #[test]
1953 fn problem_reader_tolerates_crlf_at_cap() {
1954 let line_content = "+1 1:0.5";
1957 let payload = format!("{}\r\n", line_content);
1958 let opts = LoadOptions {
1959 max_line_len: line_content.len(),
1960 ..LoadOptions::default()
1961 };
1962 let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1963 assert_eq!(problem.labels.len(), 1);
1964 }
1965
1966 #[test]
1967 fn problem_reader_rejects_nul_byte() {
1968 let mut payload: Vec<u8> = b"+1 1:0.5".to_vec();
1971 payload.push(0);
1972 payload.extend_from_slice(b"\n");
1973 let err = load_problem_from_reader(payload.as_slice()).unwrap_err();
1974 assert!(
1975 format!("{}", err).contains("NUL byte"),
1976 "unexpected error: {}",
1977 err
1978 );
1979 }
1980
1981 #[test]
1982 fn model_reader_honors_max_nr_class_cap() {
1983 let input = b"svm_type c_svc\n\
1986kernel_type linear\n\
1987nr_class 100\n\
1988total_sv 1\n\
1989rho 0\n\
1990SV\n";
1991 let opts = LoadOptions {
1992 max_nr_class: 50,
1993 ..LoadOptions::default()
1994 };
1995 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1996 assert!(
1997 format!("{}", err).contains("nr_class exceeds limit (50)"),
1998 "unexpected error: {}",
1999 err
2000 );
2001 }
2002
2003 #[test]
2004 fn model_reader_honors_max_sv_cap() {
2005 let input = b"svm_type c_svc\n\
2006kernel_type linear\n\
2007nr_class 2\n\
2008total_sv 1000\n\
2009rho 0\n\
2010SV\n";
2011 let opts = LoadOptions {
2012 max_sv: 100,
2013 ..LoadOptions::default()
2014 };
2015 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
2016 assert!(
2017 format!("{}", err).contains("total_sv exceeds limit (100)"),
2018 "unexpected error: {}",
2019 err
2020 );
2021 }
2022
2023 #[test]
2024 fn trusted_input_cannot_exceed_hard_module_caps() {
2025 let huge_nr_class = format!(
2029 "svm_type c_svc\n\
2030kernel_type linear\n\
2031nr_class {}\n\
2032total_sv 1\n\
2033rho 0\n\
2034SV\n",
2035 MAX_NR_CLASS + 1
2036 );
2037 let opts = LoadOptions::trusted_input();
2038 let err = load_model_from_reader_with_options(huge_nr_class.as_bytes(), &opts).unwrap_err();
2039 assert!(
2040 format!("{}", err).contains("nr_class exceeds limit"),
2041 "unexpected error: {}",
2042 err
2043 );
2044 }
2045
2046 #[test]
2047 fn model_reader_honors_max_feature_index_cap() {
2048 let input = b"svm_type c_svc\n\
2049kernel_type linear\n\
2050nr_class 2\n\
2051total_sv 1\n\
2052rho 0\n\
2053label 1 -1\n\
2054nr_sv 1 0\n\
2055SV\n\
20560.1 50:0.5\n";
2057 let opts = LoadOptions {
2058 max_feature_index: 10,
2059 ..LoadOptions::default()
2060 };
2061 let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
2062 assert!(
2063 format!("{}", err).contains("feature index 50 exceeds limit (10)"),
2064 "unexpected error: {}",
2065 err
2066 );
2067 }
2068
2069 #[test]
2070 fn sv_count_loop_counts_nonblank_lines_only() {
2071 let input = b"svm_type c_svc\n\
2077kernel_type linear\n\
2078nr_class 2\n\
2079total_sv 2\n\
2080rho 0\n\
2081label 1 -1\n\
2082nr_sv 1 1\n\
2083SV\n\
2084\n\
20850.1 1:0.5\n\
2086\n\
2087-0.1 2:0.5\n";
2088 let model = load_model_from_reader(&input[..]).unwrap();
2089 assert_eq!(model.sv.len(), 2);
2090 assert_eq!(model.sv_coef[0].len(), 2);
2091 }
2092
2093 #[test]
2094 fn save_precomputed_model_writes_zero_index() {
2095 let model = SvmModel {
2096 param: SvmParameter {
2097 svm_type: SvmType::CSvc,
2098 kernel_type: KernelType::Precomputed,
2099 ..Default::default()
2100 },
2101 nr_class: 2,
2102 sv: vec![vec![SvmNode {
2103 index: 0,
2104 value: 7.0,
2105 }]],
2106 sv_coef: vec![vec![0.25]],
2107 rho: vec![0.0],
2108 prob_a: vec![],
2109 prob_b: vec![],
2110 prob_density_marks: vec![],
2111 sv_indices: vec![],
2112 label: vec![1, -1],
2113 n_sv: vec![1, 0],
2114 };
2115
2116 let mut buf = Vec::new();
2117 save_model_to_writer(&mut buf, &model).unwrap();
2118 let out = String::from_utf8(buf).unwrap();
2119 assert!(out.contains("kernel_type precomputed"));
2120 assert!(out.contains("0:7"));
2121 }
2122
2123 #[test]
2128 fn csvc_missing_label_and_nr_sv_returns_error_not_panic() {
2129 let input = b"svm_type c_svc\n\
2130kernel_type linear\n\
2131nr_class 2\n\
2132total_sv 1\n\
2133rho 0\n\
2134SV\n\
21351 1:1\n";
2136 let err = load_model_from_reader(&input[..]).unwrap_err();
2137 let msg = format!("{}", err);
2138 assert!(
2139 msg.contains("label has 0 entries, expected nr_class (2)"),
2140 "unexpected error: {}",
2141 msg
2142 );
2143 }
2144
2145 #[test]
2147 fn csvc_label_present_nr_sv_absent_returns_error() {
2148 let input = b"svm_type c_svc\n\
2149kernel_type linear\n\
2150nr_class 2\n\
2151total_sv 1\n\
2152rho 0\n\
2153label 1 -1\n\
2154SV\n\
21551 1:1\n";
2156 let err = load_model_from_reader(&input[..]).unwrap_err();
2157 let msg = format!("{}", err);
2158 assert!(
2159 msg.contains("nr_sv has 0 entries, expected nr_class (2)"),
2160 "unexpected error: {}",
2161 msg
2162 );
2163 }
2164
2165 #[test]
2169 fn one_class_without_label_and_nr_sv_loads_ok() {
2170 let input = b"svm_type one_class\n\
2172kernel_type rbf\n\
2173gamma 0.5\n\
2174nr_class 2\n\
2175total_sv 1\n\
2176rho -0.5\n\
2177SV\n\
21781 1:0.5\n";
2179 let model = load_model_from_reader(&input[..]).unwrap();
2180 assert_eq!(model.sv.len(), 1);
2181 assert_eq!(model.label, Vec::<i32>::new());
2182 assert_eq!(model.n_sv, Vec::<usize>::new());
2183 }
2184
2185 #[test]
2189 fn model_rho_inf_returns_error() {
2190 let input = b"svm_type c_svc\n\
2191kernel_type linear\n\
2192nr_class 2\n\
2193total_sv 1\n\
2194rho inf\n\
2195label 1 -1\n\
2196nr_sv 1 0\n\
2197SV\n\
21981 1:0.5\n";
2199 let err = load_model_from_reader(&input[..]).unwrap_err();
2200 let msg = format!("{}", err);
2201 assert!(
2202 msg.contains("rho must be finite"),
2203 "unexpected error: {}",
2204 msg
2205 );
2206 }
2207
2208 #[test]
2210 fn model_sv_feature_nan_returns_error() {
2211 let input = b"svm_type c_svc\n\
2212kernel_type linear\n\
2213nr_class 2\n\
2214total_sv 1\n\
2215rho 0\n\
2216label 1 -1\n\
2217nr_sv 1 0\n\
2218SV\n\
22191 1:nan\n";
2220 let err = load_model_from_reader(&input[..]).unwrap_err();
2221 let msg = format!("{}", err);
2222 assert!(
2223 msg.contains("feature value must be finite"),
2224 "unexpected error: {}",
2225 msg
2226 );
2227 }
2228}