1use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12use std::fmt;
22
/// Display adapter that renders an `f64` the way C's `printf("%.Pg", value)`
/// does, where `P` is `precision`.
///
/// The rules implemented are those of the C standard's `%g` conversion:
/// the value is first rounded to `P` significant digits, which yields a
/// decimal exponent `X`; if `P > X >= -4` fixed notation with `P - 1 - X`
/// decimals is used, otherwise scientific notation with `P - 1` decimals.
/// Trailing zeros of the *fractional* part are then removed.
///
/// Non-finite values are printed with Rust's own `Display` for `f64`.
struct Gfmt {
    /// The number to render.
    value: f64,
    /// Number of significant digits (`P`); `0` is treated as `1`, as in C.
    precision: usize,
}

impl Gfmt {
    /// Pairs `value` with the number of significant digits to print.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Removes trailing zeros from the fractional part of `digits`, and the
    /// decimal point itself if nothing is left after it.
    ///
    /// Strings without a `.` are returned unchanged: their trailing zeros are
    /// significant integer digits (`"100000"` must not become `"1"`).
    fn trim_fraction(digits: &str) -> &str {
        if digits.contains('.') {
            digits.trim_end_matches('0').trim_end_matches('.')
        } else {
            digits
        }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        // C: "Let P equal the precision if nonzero, ... or 1 if the precision is zero."
        let p = self.precision.max(1);

        if !v.is_finite() {
            return write!(f, "{}", v);
        }

        if v == 0.0 {
            // printf keeps the sign of negative zero.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // Let the formatter round to P significant digits and read the decimal
        // exponent from the result. Deriving it from `log10` of the unrounded
        // value picks the wrong notation when rounding carries into the next
        // power of ten (e.g. 999999.5 with P = 6).
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let parsed = sci
            .split_once('e')
            .and_then(|(mantissa, exponent)| exponent.parse::<i32>().ok().map(|x| (mantissa, x)));
        let (mantissa, exp) = match parsed {
            Some(pair) => pair,
            // Not expected for finite input; emit the raw text rather than guess.
            None => return f.write_str(&sci),
        };

        let use_scientific = exp < -4 || (exp >= 0 && exp as usize >= p);
        if use_scientific {
            // C prints the exponent with an explicit sign and at least two digits.
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                Self::trim_fraction(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Here -4 <= exp < p, so `p - 1 - exp` is never negative.
            let decimals = if exp >= 0 {
                p - 1 - exp as usize
            } else {
                p - 1 + exp.unsigned_abs() as usize
            };
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(Self::trim_fraction(&fixed))
        }
    }
}
87
88fn fmt_17g(v: f64) -> Gfmt {
90 Gfmt::new(v, 17)
91}
92
93fn fmt_8g(v: f64) -> Gfmt {
95 Gfmt::new(v, 8)
96}
97
98pub fn format_g(v: f64) -> String {
100 format!("{}", Gfmt::new(v, 6))
101}
102
103pub fn format_17g(v: f64) -> String {
105 format!("{}", Gfmt::new(v, 17))
106}
107
/// Model-file spellings of the SVM types, indexed by `SvmType as usize`
/// (see `svm_type_to_str`).
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];

/// Model-file spellings of the kernel types, indexed by `KernelType as usize`
/// (see `kernel_type_to_str`).
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
112
113fn svm_type_to_str(t: SvmType) -> &'static str {
114 SVM_TYPE_TABLE[t as usize]
115}
116
117fn kernel_type_to_str(t: KernelType) -> &'static str {
118 KERNEL_TYPE_TABLE[t as usize]
119}
120
121fn str_to_svm_type(s: &str) -> Option<SvmType> {
122 match s {
123 "c_svc" => Some(SvmType::CSvc),
124 "nu_svc" => Some(SvmType::NuSvc),
125 "one_class" => Some(SvmType::OneClass),
126 "epsilon_svr" => Some(SvmType::EpsilonSvr),
127 "nu_svr" => Some(SvmType::NuSvr),
128 _ => None,
129 }
130}
131
132fn str_to_kernel_type(s: &str) -> Option<KernelType> {
133 match s {
134 "linear" => Some(KernelType::Linear),
135 "polynomial" => Some(KernelType::Polynomial),
136 "rbf" => Some(KernelType::Rbf),
137 "sigmoid" => Some(KernelType::Sigmoid),
138 "precomputed" => Some(KernelType::Precomputed),
139 _ => None,
140 }
141}
142
143pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
149 let file = std::fs::File::open(path)?;
150 let reader = std::io::BufReader::new(file);
151 load_problem_from_reader(reader)
152}
153
154pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
156 let mut labels = Vec::new();
157 let mut instances = Vec::new();
158
159 for (line_idx, line_result) in reader.lines().enumerate() {
160 let line = line_result?;
161 let line = line.trim();
162 if line.is_empty() {
163 continue;
164 }
165
166 let line_num = line_idx + 1;
167 let mut parts = line.split_whitespace();
168
169 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
171 line: line_num,
172 message: "missing label".into(),
173 })?;
174 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
175 line: line_num,
176 message: format!("invalid label: {}", label_str),
177 })?;
178
179 let mut nodes = Vec::new();
181 let mut prev_index: i32 = 0;
182 for token in parts {
183 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
184 line: line_num,
185 message: format!("expected index:value, got: {}", token),
186 })?;
187 let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
188 line: line_num,
189 message: format!("invalid index: {}", idx_str),
190 })?;
191 if !nodes.is_empty() && index <= prev_index {
192 return Err(SvmError::ParseError {
193 line: line_num,
194 message: format!(
195 "feature indices must be ascending: {} follows {}",
196 index, prev_index
197 ),
198 });
199 }
200 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
201 line: line_num,
202 message: format!("invalid value: {}", val_str),
203 })?;
204 prev_index = index;
205 nodes.push(SvmNode { index, value });
206 }
207
208 labels.push(label);
209 instances.push(nodes);
210 }
211
212 Ok(SvmProblem { labels, instances })
213}
214
215pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
219 let file = std::fs::File::create(path)?;
220 let writer = std::io::BufWriter::new(file);
221 save_model_to_writer(writer, model)
222}
223
224pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
226 let param = &model.param;
227
228 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
229 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
230
231 if param.kernel_type == KernelType::Polynomial {
232 writeln!(w, "degree {}", param.degree)?;
233 }
234 if matches!(
235 param.kernel_type,
236 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
237 ) {
238 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
239 }
240 if matches!(
241 param.kernel_type,
242 KernelType::Polynomial | KernelType::Sigmoid
243 ) {
244 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
245 }
246
247 let nr_class = model.nr_class;
248 writeln!(w, "nr_class {}", nr_class)?;
249 writeln!(w, "total_sv {}", model.sv.len())?;
250
251 write!(w, "rho")?;
253 for r in &model.rho {
254 write!(w, " {}", fmt_17g(*r))?;
255 }
256 writeln!(w)?;
257
258 if !model.label.is_empty() {
260 write!(w, "label")?;
261 for l in &model.label {
262 write!(w, " {}", l)?;
263 }
264 writeln!(w)?;
265 }
266
267 if !model.prob_a.is_empty() {
269 write!(w, "probA")?;
270 for v in &model.prob_a {
271 write!(w, " {}", fmt_17g(*v))?;
272 }
273 writeln!(w)?;
274 }
275
276 if !model.prob_b.is_empty() {
278 write!(w, "probB")?;
279 for v in &model.prob_b {
280 write!(w, " {}", fmt_17g(*v))?;
281 }
282 writeln!(w)?;
283 }
284
285 if !model.prob_density_marks.is_empty() {
287 write!(w, "prob_density_marks")?;
288 for v in &model.prob_density_marks {
289 write!(w, " {}", fmt_17g(*v))?;
290 }
291 writeln!(w)?;
292 }
293
294 if !model.n_sv.is_empty() {
296 write!(w, "nr_sv")?;
297 for n in &model.n_sv {
298 write!(w, " {}", n)?;
299 }
300 writeln!(w)?;
301 }
302
303 writeln!(w, "SV")?;
305 let num_sv = model.sv.len();
306 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
309 for j in 0..num_coef_rows {
311 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
312 }
313 if model.param.kernel_type == KernelType::Precomputed {
315 if let Some(node) = model.sv[i].first() {
316 write!(w, "0:{} ", node.value as i32)?;
317 }
318 } else {
319 for node in &model.sv[i] {
320 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
321 }
322 }
323 writeln!(w)?;
324 }
325
326 Ok(())
327}
328
329pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
331 let file = std::fs::File::open(path)?;
332 let reader = std::io::BufReader::new(file);
333 load_model_from_reader(reader)
334}
335
336pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
338 let mut lines = reader.lines();
339
340 let mut param = SvmParameter::default();
342 let mut nr_class: usize = 0;
343 let mut total_sv: usize = 0;
344 let mut rho = Vec::new();
345 let mut label = Vec::new();
346 let mut prob_a = Vec::new();
347 let mut prob_b = Vec::new();
348 let mut prob_density_marks = Vec::new();
349 let mut n_sv = Vec::new();
350
351 let mut line_num: usize = 0;
353 loop {
354 let line = lines
355 .next()
356 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))??;
357 line_num += 1;
358 let line = line.trim().to_string();
359 if line.is_empty() {
360 continue;
361 }
362
363 let mut parts = line.split_whitespace();
364 let cmd = parts.next().unwrap();
365
366 match cmd {
367 "svm_type" => {
368 let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
369 format!("line {}: missing svm_type value", line_num),
370 ))?;
371 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
372 SvmError::ModelFormatError(format!("line {}: unknown svm_type: {}", line_num, val))
373 })?;
374 }
375 "kernel_type" => {
376 let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
377 format!("line {}: missing kernel_type value", line_num),
378 ))?;
379 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
380 SvmError::ModelFormatError(format!("line {}: unknown kernel_type: {}", line_num, val))
381 })?;
382 }
383 "degree" => {
384 param.degree = parse_single(&mut parts, line_num, "degree")?;
385 }
386 "gamma" => {
387 param.gamma = parse_single(&mut parts, line_num, "gamma")?;
388 }
389 "coef0" => {
390 param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
391 }
392 "nr_class" => {
393 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
394 }
395 "total_sv" => {
396 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
397 }
398 "rho" => {
399 rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
400 }
401 "label" => {
402 label = parse_multiple_i32(&mut parts, line_num, "label")?;
403 }
404 "probA" => {
405 prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
406 }
407 "probB" => {
408 prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
409 }
410 "prob_density_marks" => {
411 prob_density_marks = parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
412 }
413 "nr_sv" => {
414 n_sv = parts
415 .map(|s| {
416 s.parse::<usize>().map_err(|_| {
417 SvmError::ModelFormatError(format!(
418 "line {}: invalid nr_sv value: {}",
419 line_num, s
420 ))
421 })
422 })
423 .collect::<Result<Vec<_>, _>>()?;
424 }
425 "SV" => break,
426 _ => {
427 return Err(SvmError::ModelFormatError(format!(
428 "line {}: unknown keyword: {}",
429 line_num, cmd
430 )));
431 }
432 }
433 }
434
435 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
437 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
438 let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
439
440 for _ in 0..total_sv {
441 let line = lines
442 .next()
443 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))??;
444 line_num += 1;
445 let line = line.trim();
446 if line.is_empty() {
447 continue;
448 }
449
450 let mut parts = line.split_whitespace();
451
452 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
454 let val_str = parts.next().ok_or_else(|| SvmError::ModelFormatError(
455 format!("line {}: missing sv_coef[{}]", line_num, k),
456 ))?;
457 let val: f64 = val_str.parse().map_err(|_| SvmError::ModelFormatError(
458 format!("line {}: invalid sv_coef: {}", line_num, val_str),
459 ))?;
460 coef_row.push(val);
461 }
462
463 let mut nodes = Vec::new();
465 for token in parts {
466 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
467 SvmError::ModelFormatError(format!(
468 "line {}: expected index:value, got: {}",
469 line_num, token
470 ))
471 })?;
472 let index: i32 = idx_str.parse().map_err(|_| {
473 SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
474 })?;
475 let value: f64 = val_str.parse().map_err(|_| {
476 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
477 })?;
478 nodes.push(SvmNode { index, value });
479 }
480 sv.push(nodes);
481 }
482
483 Ok(SvmModel {
484 param,
485 nr_class,
486 sv,
487 sv_coef,
488 rho,
489 prob_a,
490 prob_b,
491 prob_density_marks,
492 sv_indices: Vec::new(), label,
494 n_sv,
495 })
496}
497
498fn parse_single<T: std::str::FromStr>(
501 parts: &mut std::str::SplitWhitespace<'_>,
502 line_num: usize,
503 field: &str,
504) -> Result<T, SvmError> {
505 let val_str = parts.next().ok_or_else(|| {
506 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
507 })?;
508 val_str.parse().map_err(|_| {
509 SvmError::ModelFormatError(format!("line {}: invalid {} value: {}", line_num, field, val_str))
510 })
511}
512
513fn parse_multiple_f64(
514 parts: &mut std::str::SplitWhitespace<'_>,
515 line_num: usize,
516 field: &str,
517) -> Result<Vec<f64>, SvmError> {
518 parts
519 .map(|s| {
520 s.parse::<f64>().map_err(|_| {
521 SvmError::ModelFormatError(format!(
522 "line {}: invalid {} value: {}",
523 line_num, field, s
524 ))
525 })
526 })
527 .collect()
528}
529
530fn parse_multiple_i32(
531 parts: &mut std::str::SplitWhitespace<'_>,
532 line_num: usize,
533 field: &str,
534) -> Result<Vec<i32>, SvmError> {
535 parts
536 .map(|s| {
537 s.parse::<i32>().map_err(|_| {
538 SvmError::ModelFormatError(format!(
539 "line {}: invalid {} value: {}",
540 line_num, field, s
541 ))
542 })
543 })
544 .collect()
545}
546
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Location of the sample data files: `<manifest dir>/../../data`.
    fn data_dir() -> PathBuf {
        [env!("CARGO_MANIFEST_DIR"), "..", "..", "data"]
            .iter()
            .collect()
    }

    // ---- problem files ------------------------------------------------------

    #[test]
    fn parse_heart_scale() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        assert_eq!(problem.labels[0], 1.0);

        let first_row = &problem.instances[0];
        assert_eq!(
            first_row[0],
            SvmNode {
                index: 1,
                value: 0.708333
            }
        );
        assert_eq!(first_row.len(), 12);
    }

    #[test]
    fn parse_iris() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        assert_eq!(problem.labels.len(), 150);

        // Count the distinct (integer-valued) class labels.
        let mut classes = std::collections::HashSet::new();
        for &label in &problem.labels {
            classes.insert(label as i64);
        }
        assert_eq!(classes.len(), 3);
    }

    #[test]
    fn parse_housing() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        assert_eq!(problem.labels.len(), 506);

        let first_label = problem.labels[0];
        assert!((first_label - 24.0).abs() < 1e-10);
    }

    #[test]
    fn parse_empty_lines() {
        let input: &[u8] = b"+1 1:0.5\n\n-1 2:0.3\n";
        let problem = load_problem_from_reader(input).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    #[test]
    fn parse_error_unsorted_indices() {
        let input: &[u8] = b"+1 3:0.5 1:0.3\n";
        let result = load_problem_from_reader(input);
        assert!(result.is_err());

        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    #[test]
    fn parse_error_duplicate_indices() {
        let input: &[u8] = b"+1 1:0.5 1:0.3\n";
        assert!(load_problem_from_reader(input).is_err());
    }

    #[test]
    fn parse_error_missing_colon() {
        let input: &[u8] = b"+1 1:0.5 bad_token\n";
        assert!(load_problem_from_reader(input).is_err());
    }

    // ---- model files --------------------------------------------------------

    #[test]
    fn load_c_trained_model() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();

        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    #[test]
    fn roundtrip_c_model() {
        let path = data_dir().join("heart_scale.model");
        let reference_text = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut written = Vec::new();
        save_model_to_writer(&mut written, &model).unwrap();
        let written_text = String::from_utf8(written).unwrap();

        // Saving the loaded model must reproduce the reference file line by line.
        let orig_lines: Vec<&str> = reference_text.lines().collect();
        let rust_lines: Vec<&str> = written_text.lines().collect();
        assert_eq!(
            orig_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            orig_lines.len(),
            rust_lines.len()
        );
        for (idx, (expected, actual)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
            assert_eq!(
                expected,
                actual,
                "line {} differs:\n C: {:?}\n Rust: {:?}",
                idx + 1,
                expected,
                actual
            );
        }
    }

    #[test]
    fn gfmt_matches_c_printf() {
        // (value, expected `%.17g` output, expected `%.8g` output)
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];
        for &(value, want_17, want_8) in cases {
            let got_17 = fmt_17g(value).to_string();
            let got_8 = fmt_8g(value).to_string();
            assert_eq!(got_17, want_17, "%.17g mismatch for {}", value);
            assert_eq!(got_8, want_8, "%.8g mismatch for {}", value);
        }
    }

    #[test]
    fn model_roundtrip() {
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        let model = SvmModel {
            param,
            nr_class: 2,
            sv: vec![
                vec![
                    SvmNode {
                        index: 1,
                        value: 0.5,
                    },
                    SvmNode {
                        index: 3,
                        value: -1.0,
                    },
                ],
                vec![
                    SvmNode {
                        index: 1,
                        value: -0.25,
                    },
                    SvmNode {
                        index: 2,
                        value: 0.75,
                    },
                ],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut serialized = Vec::new();
        save_model_to_writer(&mut serialized, &model).unwrap();
        let loaded = load_model_from_reader(&serialized[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());

        for (a, b) in loaded.rho.iter().zip(&model.rho) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        for (loaded_row, model_row) in loaded.sv_coef.iter().zip(&model.sv_coef) {
            for (a, b) in loaded_row.iter().zip(model_row) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }
}