1use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12use std::fmt;
22
/// Formats an `f64` like C's `printf("%.<precision>g", value)`.
///
/// The value is rendered with at most `precision` significant digits, in
/// fixed notation when the decimal exponent (after rounding) lies in
/// `-4..precision` and in scientific notation (`e+XX` / `e-XX`, at least two
/// exponent digits) otherwise. Trailing fractional zeros are removed.
struct Gfmt {
    /// Value to render.
    value: f64,
    /// Maximum number of significant digits; `0` is treated as `1`.
    precision: usize,
}

impl Gfmt {
    /// Pairs `value` with the number of significant digits to print.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Removes trailing fractional zeros and a then-dangling decimal point.
    ///
    /// Text without a decimal point is returned unchanged, so integral output
    /// such as `100000` keeps its significant zeros.
    fn strip_fraction_zeros(digits: &str) -> &str {
        if digits.contains('.') {
            digits.trim_end_matches('0').trim_end_matches('.')
        } else {
            digits
        }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        // C treats a `%g` precision of zero as one.
        let p = self.precision.max(1);

        // Infinities and NaN are written with Rust's own spelling.
        if !v.is_finite() {
            return write!(f, "{}", v);
        }

        if v == 0.0 {
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // Render in scientific form first: the exponent that selects the
        // notation must be the one obtained *after* rounding to `p` digits
        // (e.g. 999999.7 at six digits rounds up to 1e+06).
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exponent) = match sci.split_once('e') {
            Some(pair) => pair,
            None => return f.write_str(&sci),
        };
        let exp: i32 = match exponent.parse() {
            Ok(e) => e,
            Err(_) => return f.write_str(&sci),
        };

        if exp < -4 || (exp >= 0 && exp as usize >= p) {
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                Self::strip_fraction_zeros(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Here -4 <= exp < p, so `p - 1 - exp` cannot underflow.
            let decimals = if exp >= 0 {
                p - 1 - exp as usize
            } else {
                p - 1 + exp.unsigned_abs() as usize
            };
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(Self::strip_fraction_zeros(&fixed))
        }
    }
}
87
88fn fmt_17g(v: f64) -> Gfmt {
90 Gfmt::new(v, 17)
91}
92
93fn fmt_8g(v: f64) -> Gfmt {
95 Gfmt::new(v, 8)
96}
97
98pub fn format_g(v: f64) -> String {
100 format!("{}", Gfmt::new(v, 6))
101}
102
103pub fn format_17g(v: f64) -> String {
105 format!("{}", Gfmt::new(v, 17))
106}
107
/// Model-file names of the SVM types; `svm_type_to_str` indexes this table
/// with `SvmType as usize`, so the order must follow the enum's discriminants.
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];

/// Model-file names of the kernel types; `kernel_type_to_str` indexes this
/// table with `KernelType as usize`.
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
112
113fn svm_type_to_str(t: SvmType) -> &'static str {
114 SVM_TYPE_TABLE[t as usize]
115}
116
117fn kernel_type_to_str(t: KernelType) -> &'static str {
118 KERNEL_TYPE_TABLE[t as usize]
119}
120
121fn str_to_svm_type(s: &str) -> Option<SvmType> {
122 match s {
123 "c_svc" => Some(SvmType::CSvc),
124 "nu_svc" => Some(SvmType::NuSvc),
125 "one_class" => Some(SvmType::OneClass),
126 "epsilon_svr" => Some(SvmType::EpsilonSvr),
127 "nu_svr" => Some(SvmType::NuSvr),
128 _ => None,
129 }
130}
131
132fn str_to_kernel_type(s: &str) -> Option<KernelType> {
133 match s {
134 "linear" => Some(KernelType::Linear),
135 "polynomial" => Some(KernelType::Polynomial),
136 "rbf" => Some(KernelType::Rbf),
137 "sigmoid" => Some(KernelType::Sigmoid),
138 "precomputed" => Some(KernelType::Precomputed),
139 _ => None,
140 }
141}
142
/// Largest feature index accepted by the problem and model parsers; tokens
/// with a larger index are rejected with an error.
const MAX_FEATURE_INDEX: i32 = 10 * 1_000_000;
146
147pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
151 let file = std::fs::File::open(path)?;
152 let reader = std::io::BufReader::new(file);
153 load_problem_from_reader(reader)
154}
155
156pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
158 let mut labels = Vec::new();
159 let mut instances = Vec::new();
160
161 for (line_idx, line_result) in reader.lines().enumerate() {
162 let line = line_result?;
163 let line = line.trim();
164 if line.is_empty() {
165 continue;
166 }
167
168 let line_num = line_idx + 1;
169 let mut parts = line.split_whitespace();
170
171 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
173 line: line_num,
174 message: "missing label".into(),
175 })?;
176 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
177 line: line_num,
178 message: format!("invalid label: {}", label_str),
179 })?;
180
181 let mut nodes = Vec::new();
183 let mut prev_index: i32 = 0;
184 for token in parts {
185 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
186 line: line_num,
187 message: format!("expected index:value, got: {}", token),
188 })?;
189 let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
190 line: line_num,
191 message: format!("invalid index: {}", idx_str),
192 })?;
193
194 if index > MAX_FEATURE_INDEX {
195 return Err(SvmError::ParseError {
196 line: line_num,
197 message: format!(
198 "feature index {} exceeds limit ({})",
199 index, MAX_FEATURE_INDEX
200 ),
201 });
202 }
203
204 if !nodes.is_empty() && index <= prev_index {
205 return Err(SvmError::ParseError {
206 line: line_num,
207 message: format!(
208 "feature indices must be ascending: {} follows {}",
209 index, prev_index
210 ),
211 });
212 }
213 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
214 line: line_num,
215 message: format!("invalid value: {}", val_str),
216 })?;
217 prev_index = index;
218 nodes.push(SvmNode { index, value });
219 }
220
221 labels.push(label);
222 instances.push(nodes);
223 }
224
225 Ok(SvmProblem { labels, instances })
226}
227
/// Upper bound on the `nr_class` header value accepted by the model loader.
const MAX_NR_CLASS: usize = u16::MAX as usize;

/// Upper bound on the `total_sv` header value accepted by the model loader.
const MAX_TOTAL_SV: usize = 10 * 1_000_000;
232
233pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
235 let file = std::fs::File::create(path)?;
236 let writer = std::io::BufWriter::new(file);
237 save_model_to_writer(writer, model)
238}
239
240pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
242 let param = &model.param;
243
244 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
245 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
246
247 if param.kernel_type == KernelType::Polynomial {
248 writeln!(w, "degree {}", param.degree)?;
249 }
250 if matches!(
251 param.kernel_type,
252 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
253 ) {
254 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
255 }
256 if matches!(
257 param.kernel_type,
258 KernelType::Polynomial | KernelType::Sigmoid
259 ) {
260 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
261 }
262
263 let nr_class = model.nr_class;
264 writeln!(w, "nr_class {}", nr_class)?;
265 writeln!(w, "total_sv {}", model.sv.len())?;
266
267 write!(w, "rho")?;
269 for r in &model.rho {
270 write!(w, " {}", fmt_17g(*r))?;
271 }
272 writeln!(w)?;
273
274 if !model.label.is_empty() {
276 write!(w, "label")?;
277 for l in &model.label {
278 write!(w, " {}", l)?;
279 }
280 writeln!(w)?;
281 }
282
283 if !model.prob_a.is_empty() {
285 write!(w, "probA")?;
286 for v in &model.prob_a {
287 write!(w, " {}", fmt_17g(*v))?;
288 }
289 writeln!(w)?;
290 }
291
292 if !model.prob_b.is_empty() {
294 write!(w, "probB")?;
295 for v in &model.prob_b {
296 write!(w, " {}", fmt_17g(*v))?;
297 }
298 writeln!(w)?;
299 }
300
301 if !model.prob_density_marks.is_empty() {
303 write!(w, "prob_density_marks")?;
304 for v in &model.prob_density_marks {
305 write!(w, " {}", fmt_17g(*v))?;
306 }
307 writeln!(w)?;
308 }
309
310 if !model.n_sv.is_empty() {
312 write!(w, "nr_sv")?;
313 for n in &model.n_sv {
314 write!(w, " {}", n)?;
315 }
316 writeln!(w)?;
317 }
318
319 writeln!(w, "SV")?;
321 let num_sv = model.sv.len();
322 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
325 for j in 0..num_coef_rows {
327 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
328 }
329 if model.param.kernel_type == KernelType::Precomputed {
331 if let Some(node) = model.sv[i].first() {
332 write!(w, "0:{} ", node.value as i32)?;
333 }
334 } else {
335 for node in &model.sv[i] {
336 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
337 }
338 }
339 writeln!(w)?;
340 }
341
342 Ok(())
343}
344
345pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
347 let file = std::fs::File::open(path)?;
348 let reader = std::io::BufReader::new(file);
349 load_model_from_reader(reader)
350}
351
352pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
354 let mut lines = reader.lines();
355
356 let mut param = SvmParameter::default();
358 let mut nr_class: usize = 0;
359 let mut total_sv: usize = 0;
360 let mut rho = Vec::new();
361 let mut label = Vec::new();
362 let mut prob_a = Vec::new();
363 let mut prob_b = Vec::new();
364 let mut prob_density_marks = Vec::new();
365 let mut n_sv = Vec::new();
366
367 let mut line_num: usize = 0;
369 loop {
370 let line = lines.next().ok_or_else(|| {
371 SvmError::ModelFormatError("unexpected end of file in header".into())
372 })??;
373 line_num += 1;
374 let line = line.trim().to_string();
375 if line.is_empty() {
376 continue;
377 }
378
379 let mut parts = line.split_whitespace();
380 let cmd = parts.next().unwrap();
381
382 match cmd {
383 "svm_type" => {
384 let val = parts.next().ok_or_else(|| {
385 SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
386 })?;
387 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
388 SvmError::ModelFormatError(format!(
389 "line {}: unknown svm_type: {}",
390 line_num, val
391 ))
392 })?;
393 }
394 "kernel_type" => {
395 let val = parts.next().ok_or_else(|| {
396 SvmError::ModelFormatError(format!(
397 "line {}: missing kernel_type value",
398 line_num
399 ))
400 })?;
401 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
402 SvmError::ModelFormatError(format!(
403 "line {}: unknown kernel_type: {}",
404 line_num, val
405 ))
406 })?;
407 }
408 "degree" => {
409 param.degree = parse_single(&mut parts, line_num, "degree")?;
410 }
411 "gamma" => {
412 param.gamma = parse_single(&mut parts, line_num, "gamma")?;
413 }
414 "coef0" => {
415 param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
416 }
417 "nr_class" => {
418 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
419 if nr_class > MAX_NR_CLASS {
420 return Err(SvmError::ModelFormatError(format!(
421 "line {}: nr_class exceeds limit ({})",
422 line_num, MAX_NR_CLASS
423 )));
424 }
425 }
426 "total_sv" => {
427 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
428 if total_sv > MAX_TOTAL_SV {
429 return Err(SvmError::ModelFormatError(format!(
430 "line {}: total_sv exceeds limit ({})",
431 line_num, MAX_TOTAL_SV
432 )));
433 }
434 }
435 "rho" => {
436 rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
437 }
438 "label" => {
439 label = parse_multiple_i32(&mut parts, line_num, "label")?;
440 }
441 "probA" => {
442 prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
443 }
444 "probB" => {
445 prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
446 }
447 "prob_density_marks" => {
448 prob_density_marks =
449 parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
450 }
451 "nr_sv" => {
452 n_sv = parts
453 .map(|s| {
454 s.parse::<usize>().map_err(|_| {
455 SvmError::ModelFormatError(format!(
456 "line {}: invalid nr_sv value: {}",
457 line_num, s
458 ))
459 })
460 })
461 .collect::<Result<Vec<_>, _>>()?;
462 }
463 "SV" => break,
464 _ => {
465 return Err(SvmError::ModelFormatError(format!(
466 "line {}: unknown keyword: {}",
467 line_num, cmd
468 )));
469 }
470 }
471 }
472
473 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
475 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
476 let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
477
478 for _ in 0..total_sv {
479 let line = lines.next().ok_or_else(|| {
480 SvmError::ModelFormatError("unexpected end of file in SV section".into())
481 })??;
482 line_num += 1;
483 let line = line.trim();
484 if line.is_empty() {
485 continue;
486 }
487
488 let mut parts = line.split_whitespace();
489
490 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
492 let val_str = parts.next().ok_or_else(|| {
493 SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
494 })?;
495 let val: f64 = val_str.parse().map_err(|_| {
496 SvmError::ModelFormatError(format!(
497 "line {}: invalid sv_coef: {}",
498 line_num, val_str
499 ))
500 })?;
501 coef_row.push(val);
502 }
503
504 let mut nodes = Vec::new();
506 for token in parts {
507 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
508 SvmError::ModelFormatError(format!(
509 "line {}: expected index:value, got: {}",
510 line_num, token
511 ))
512 })?;
513 let index: i32 = idx_str.parse().map_err(|_| {
514 SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
515 })?;
516
517 if index > MAX_FEATURE_INDEX {
518 return Err(SvmError::ModelFormatError(format!(
519 "line {}: feature index {} exceeds limit ({})",
520 line_num, index, MAX_FEATURE_INDEX
521 )));
522 }
523
524 let value: f64 = val_str.parse().map_err(|_| {
525 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
526 })?;
527 nodes.push(SvmNode { index, value });
528 }
529 sv.push(nodes);
530 }
531
532 Ok(SvmModel {
533 param,
534 nr_class,
535 sv,
536 sv_coef,
537 rho,
538 prob_a,
539 prob_b,
540 prob_density_marks,
541 sv_indices: Vec::new(), label,
543 n_sv,
544 })
545}
546
547fn parse_single<T: std::str::FromStr>(
550 parts: &mut std::str::SplitWhitespace<'_>,
551 line_num: usize,
552 field: &str,
553) -> Result<T, SvmError> {
554 let val_str = parts.next().ok_or_else(|| {
555 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
556 })?;
557 val_str.parse().map_err(|_| {
558 SvmError::ModelFormatError(format!(
559 "line {}: invalid {} value: {}",
560 line_num, field, val_str
561 ))
562 })
563}
564
565fn parse_multiple_f64(
566 parts: &mut std::str::SplitWhitespace<'_>,
567 line_num: usize,
568 field: &str,
569) -> Result<Vec<f64>, SvmError> {
570 parts
571 .map(|s| {
572 s.parse::<f64>().map_err(|_| {
573 SvmError::ModelFormatError(format!(
574 "line {}: invalid {} value: {}",
575 line_num, field, s
576 ))
577 })
578 })
579 .collect()
580}
581
582fn parse_multiple_i32(
583 parts: &mut std::str::SplitWhitespace<'_>,
584 line_num: usize,
585 field: &str,
586) -> Result<Vec<i32>, SvmError> {
587 parts
588 .map(|s| {
589 s.parse::<i32>().map_err(|_| {
590 SvmError::ModelFormatError(format!(
591 "line {}: invalid {} value: {}",
592 line_num, field, s
593 ))
594 })
595 })
596 .collect()
597}
598
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Fixture directory: `data`, two levels above this crate's manifest.
    fn data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("..");
        dir.push("..");
        dir.push("data");
        dir
    }

    /// Loads a problem fixture by file name, panicking on failure.
    fn problem_fixture(name: &str) -> SvmProblem {
        load_problem(&data_dir().join(name)).unwrap()
    }

    /// Parses `input` as a model, expects failure, and returns the error text.
    fn model_error(input: &[u8]) -> String {
        load_model_from_reader(input).unwrap_err().to_string()
    }

    /// Parses `input` as a problem, expects failure, and returns the error text.
    fn problem_error(input: &[u8]) -> String {
        load_problem_from_reader(input).unwrap_err().to_string()
    }

    /// Builds a two-class linear model file with the given `total_sv`, extra
    /// header lines (placed after `rho`) and SV-section text.
    fn two_class_linear_model(total_sv: usize, extra_header: &str, sv_lines: &str) -> Vec<u8> {
        format!(
            "svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv {}\nrho 0\n{}SV\n{}",
            total_sv, extra_header, sv_lines
        )
        .into_bytes()
    }

    /// Serialises `model` into an in-memory buffer.
    fn model_bytes(model: &SvmModel) -> Vec<u8> {
        let mut out = Vec::new();
        save_model_to_writer(&mut out, model).unwrap();
        out
    }

    #[test]
    fn parse_heart_scale() {
        let problem = problem_fixture("heart_scale");
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        assert_eq!(problem.labels[0], 1.0);

        let first_row = &problem.instances[0];
        assert_eq!(
            first_row[0],
            SvmNode {
                index: 1,
                value: 0.708333
            }
        );
        assert_eq!(first_row.len(), 12);
    }

    #[test]
    fn parse_iris() {
        let problem = problem_fixture("iris.scale");
        assert_eq!(problem.labels.len(), 150);

        let mut classes = std::collections::HashSet::<i64>::new();
        for &l in &problem.labels {
            classes.insert(l as i64);
        }
        assert_eq!(classes.len(), 3);
    }

    #[test]
    fn parse_housing() {
        let problem = problem_fixture("housing_scale");
        assert_eq!(problem.labels.len(), 506);
        let delta = (problem.labels[0] - 24.0).abs();
        assert!(delta < 1e-10);
    }

    #[test]
    fn parse_empty_lines() {
        let problem = load_problem_from_reader(&b"+1 1:0.5\n\n-1 2:0.3\n"[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    #[test]
    fn parse_error_unsorted_indices() {
        let msg = problem_error(b"+1 3:0.5 1:0.3\n");
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    #[test]
    fn parse_error_duplicate_indices() {
        assert!(load_problem_from_reader(&b"+1 1:0.5 1:0.3\n"[..]).is_err());
    }

    #[test]
    fn parse_error_missing_colon() {
        assert!(load_problem_from_reader(&b"+1 1:0.5 bad_token\n"[..]).is_err());
    }

    #[test]
    #[allow(clippy::excessive_precision)]
    fn load_c_trained_model() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();

        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    #[test]
    fn roundtrip_c_model() {
        let path = data_dir().join("heart_scale.model");
        let c_text = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();
        let rust_text = String::from_utf8(model_bytes(&model)).unwrap();

        let c_lines: Vec<&str> = c_text.lines().collect();
        let rust_lines: Vec<&str> = rust_text.lines().collect();
        assert_eq!(
            c_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            c_lines.len(),
            rust_lines.len()
        );

        for (idx, (c_line, rust_line)) in c_lines.iter().zip(rust_lines.iter()).enumerate() {
            assert_eq!(
                c_line,
                rust_line,
                "line {} differs:\n C: {:?}\n Rust: {:?}",
                idx + 1,
                c_line,
                rust_line
            );
        }
    }

    #[test]
    #[allow(clippy::excessive_precision)]
    fn gfmt_matches_c_printf() {
        // (value, expected "%.17g" output, expected "%.8g" output)
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];

        for &(v, expected_17g, expected_8g) in cases {
            assert_eq!(
                fmt_17g(v).to_string(),
                expected_17g,
                "%.17g mismatch for {}",
                v
            );
            assert_eq!(fmt_8g(v).to_string(), expected_8g, "%.8g mismatch for {}", v);
        }
    }

    #[test]
    #[allow(clippy::excessive_precision)]
    fn model_roundtrip() {
        let first_sv = vec![
            SvmNode {
                index: 1,
                value: 0.5,
            },
            SvmNode {
                index: 3,
                value: -1.0,
            },
        ];
        let second_sv = vec![
            SvmNode {
                index: 1,
                value: -0.25,
            },
            SvmNode {
                index: 2,
                value: 0.75,
            },
        ];
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![first_sv, second_sv],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let bytes = model_bytes(&model);
        let loaded = load_model_from_reader(&bytes[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());

        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
            for (a, b) in row_a.iter().zip(row_b.iter()) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }

    #[test]
    fn parse_error_excessive_counts() {
        let too_many_classes =
            b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
        assert!(model_error(too_many_classes).contains("nr_class exceeds limit"));

        let too_many_sv = two_class_linear_model(100_000_000, "", "");
        assert!(model_error(&too_many_sv).contains("total_sv exceeds limit"));
    }

    #[test]
    fn parse_error_excessive_feature_index() {
        let expected = "feature index 10000001 exceeds limit";

        assert!(problem_error(b"1 10000001:1\n").contains(expected));

        let model = two_class_linear_model(1, "", "0.1 10000001:1\n");
        assert!(model_error(&model).contains(expected));
    }

    #[test]
    fn parse_error_unknown_model_keyword() {
        assert!(model_error(b"bad_key value\n").contains("unknown keyword"));
    }

    #[test]
    fn parse_error_missing_or_unknown_model_values() {
        assert!(model_error(b"svm_type\n").contains("missing svm_type value"));
        assert!(model_error(b"svm_type unknown_type\n").contains("unknown svm_type"));
    }

    #[test]
    fn parse_error_invalid_nr_sv_entry() {
        let input = two_class_linear_model(1, "nr_sv a 1\n", "0.1 1:0.5\n");
        assert!(model_error(&input).contains("invalid nr_sv value"));
    }

    #[test]
    fn parse_error_in_sv_section_tokens() {
        let missing_coef = two_class_linear_model(1, "", "1:0.5\n");
        assert!(model_error(&missing_coef).contains("invalid sv_coef"));

        let bad_feature = two_class_linear_model(1, "", "0.1 bad\n");
        assert!(model_error(&bad_feature).contains("expected index:value"));
    }

    #[test]
    fn parse_error_unexpected_eof_in_header_and_sv_section() {
        assert!(model_error(b"svm_type c_svc\n").contains("unexpected end of file in header"));

        let eof_sv = two_class_linear_model(2, "", "0.1 1:0.5\n");
        assert!(model_error(&eof_sv).contains("unexpected end of file in SV section"));
    }

    #[test]
    fn save_precomputed_model_writes_zero_index() {
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Precomputed,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![vec![SvmNode {
                index: 0,
                value: 7.0,
            }]],
            sv_coef: vec![vec![0.25]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };

        let out = String::from_utf8(model_bytes(&model)).unwrap();
        assert!(out.contains("kernel_type precomputed"));
        assert!(out.contains("0:7"));
    }
}