1use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12use std::fmt;
22
/// A floating-point value paired with a significant-digit count.
///
/// The pairing exists so the value can be rendered through this type's
/// `Display` implementation, which chooses between fixed and scientific
/// notation based on `precision`.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Gfmt {
    /// The number to be formatted.
    value: f64,
    /// How many significant digits the formatted text may carry.
    precision: usize,
}
28
29impl Gfmt {
30 fn new(value: f64, precision: usize) -> Self {
31 Self { value, precision }
32 }
33}
34
35impl fmt::Display for Gfmt {
36 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37 let v = self.value;
38 let p = self.precision;
39
40 if !v.is_finite() {
41 return write!(f, "{}", v); }
43
44 if v == 0.0 {
45 if v.is_sign_negative() {
47 return write!(f, "-0");
48 }
49 return write!(f, "0");
50 }
51
52 let abs_v = v.abs();
54 let exp = abs_v.log10().floor() as i32;
55
56 if exp < -4 || exp >= p as i32 {
57 let s = format!("{:.prec$e}", v, prec = p.saturating_sub(1));
59 if let Some((mantissa, exponent)) = s.split_once('e') {
62 let mantissa = mantissa.trim_end_matches('0').trim_end_matches('.');
63 let exp_val: i32 = exponent.parse().unwrap_or(0);
65 let exp_str = if exp_val < 0 {
66 format!("-{:02}", -exp_val)
67 } else {
68 format!("+{:02}", exp_val)
69 };
70 write!(f, "{}e{}", mantissa, exp_str)
71 } else {
72 write!(f, "{}", s)
73 }
74 } else {
75 let decimal_places = if exp >= 0 {
77 p.saturating_sub((exp + 1) as usize)
78 } else {
79 p + (-1 - exp) as usize
80 };
81 let s = format!("{:.prec$}", v, prec = decimal_places);
82 let s = s.trim_end_matches('0').trim_end_matches('.');
83 write!(f, "{}", s)
84 }
85 }
86}
87
88fn fmt_17g(v: f64) -> Gfmt {
90 Gfmt::new(v, 17)
91}
92
93fn fmt_8g(v: f64) -> Gfmt {
95 Gfmt::new(v, 8)
96}
97
/// Text names for the `svm_type` model header field; `svm_type_to_str`
/// indexes this table with the enum value cast to `usize`.
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];

/// Text names for the `kernel_type` model header field; `kernel_type_to_str`
/// indexes this table with the enum value cast to `usize`.
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
102
103fn svm_type_to_str(t: SvmType) -> &'static str {
104 SVM_TYPE_TABLE[t as usize]
105}
106
107fn kernel_type_to_str(t: KernelType) -> &'static str {
108 KERNEL_TYPE_TABLE[t as usize]
109}
110
111fn str_to_svm_type(s: &str) -> Option<SvmType> {
112 match s {
113 "c_svc" => Some(SvmType::CSvc),
114 "nu_svc" => Some(SvmType::NuSvc),
115 "one_class" => Some(SvmType::OneClass),
116 "epsilon_svr" => Some(SvmType::EpsilonSvr),
117 "nu_svr" => Some(SvmType::NuSvr),
118 _ => None,
119 }
120}
121
122fn str_to_kernel_type(s: &str) -> Option<KernelType> {
123 match s {
124 "linear" => Some(KernelType::Linear),
125 "polynomial" => Some(KernelType::Polynomial),
126 "rbf" => Some(KernelType::Rbf),
127 "sigmoid" => Some(KernelType::Sigmoid),
128 "precomputed" => Some(KernelType::Precomputed),
129 _ => None,
130 }
131}
132
133pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
139 let file = std::fs::File::open(path)?;
140 let reader = std::io::BufReader::new(file);
141 load_problem_from_reader(reader)
142}
143
144pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
146 let mut labels = Vec::new();
147 let mut instances = Vec::new();
148
149 for (line_idx, line_result) in reader.lines().enumerate() {
150 let line = line_result?;
151 let line = line.trim();
152 if line.is_empty() {
153 continue;
154 }
155
156 let line_num = line_idx + 1;
157 let mut parts = line.split_whitespace();
158
159 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
161 line: line_num,
162 message: "missing label".into(),
163 })?;
164 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
165 line: line_num,
166 message: format!("invalid label: {}", label_str),
167 })?;
168
169 let mut nodes = Vec::new();
171 let mut prev_index: i32 = 0;
172 for token in parts {
173 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
174 line: line_num,
175 message: format!("expected index:value, got: {}", token),
176 })?;
177 let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
178 line: line_num,
179 message: format!("invalid index: {}", idx_str),
180 })?;
181 if !nodes.is_empty() && index <= prev_index {
182 return Err(SvmError::ParseError {
183 line: line_num,
184 message: format!(
185 "feature indices must be ascending: {} follows {}",
186 index, prev_index
187 ),
188 });
189 }
190 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
191 line: line_num,
192 message: format!("invalid value: {}", val_str),
193 })?;
194 prev_index = index;
195 nodes.push(SvmNode { index, value });
196 }
197
198 labels.push(label);
199 instances.push(nodes);
200 }
201
202 Ok(SvmProblem { labels, instances })
203}
204
205pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
209 let file = std::fs::File::create(path)?;
210 let writer = std::io::BufWriter::new(file);
211 save_model_to_writer(writer, model)
212}
213
214pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
216 let param = &model.param;
217
218 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
219 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
220
221 if param.kernel_type == KernelType::Polynomial {
222 writeln!(w, "degree {}", param.degree)?;
223 }
224 if matches!(
225 param.kernel_type,
226 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
227 ) {
228 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
229 }
230 if matches!(
231 param.kernel_type,
232 KernelType::Polynomial | KernelType::Sigmoid
233 ) {
234 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
235 }
236
237 let nr_class = model.nr_class;
238 writeln!(w, "nr_class {}", nr_class)?;
239 writeln!(w, "total_sv {}", model.sv.len())?;
240
241 write!(w, "rho")?;
243 for r in &model.rho {
244 write!(w, " {}", fmt_17g(*r))?;
245 }
246 writeln!(w)?;
247
248 if !model.label.is_empty() {
250 write!(w, "label")?;
251 for l in &model.label {
252 write!(w, " {}", l)?;
253 }
254 writeln!(w)?;
255 }
256
257 if !model.prob_a.is_empty() {
259 write!(w, "probA")?;
260 for v in &model.prob_a {
261 write!(w, " {}", fmt_17g(*v))?;
262 }
263 writeln!(w)?;
264 }
265
266 if !model.prob_b.is_empty() {
268 write!(w, "probB")?;
269 for v in &model.prob_b {
270 write!(w, " {}", fmt_17g(*v))?;
271 }
272 writeln!(w)?;
273 }
274
275 if !model.prob_density_marks.is_empty() {
277 write!(w, "prob_density_marks")?;
278 for v in &model.prob_density_marks {
279 write!(w, " {}", fmt_17g(*v))?;
280 }
281 writeln!(w)?;
282 }
283
284 if !model.n_sv.is_empty() {
286 write!(w, "nr_sv")?;
287 for n in &model.n_sv {
288 write!(w, " {}", n)?;
289 }
290 writeln!(w)?;
291 }
292
293 writeln!(w, "SV")?;
295 let num_sv = model.sv.len();
296 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
299 for j in 0..num_coef_rows {
301 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
302 }
303 if model.param.kernel_type == KernelType::Precomputed {
305 if let Some(node) = model.sv[i].first() {
306 write!(w, "0:{} ", node.value as i32)?;
307 }
308 } else {
309 for node in &model.sv[i] {
310 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
311 }
312 }
313 writeln!(w)?;
314 }
315
316 Ok(())
317}
318
319pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
321 let file = std::fs::File::open(path)?;
322 let reader = std::io::BufReader::new(file);
323 load_model_from_reader(reader)
324}
325
326pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
328 let mut lines = reader.lines();
329
330 let mut param = SvmParameter::default();
332 let mut nr_class: usize = 0;
333 let mut total_sv: usize = 0;
334 let mut rho = Vec::new();
335 let mut label = Vec::new();
336 let mut prob_a = Vec::new();
337 let mut prob_b = Vec::new();
338 let mut prob_density_marks = Vec::new();
339 let mut n_sv = Vec::new();
340
341 let mut line_num: usize = 0;
343 loop {
344 let line = lines
345 .next()
346 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))??;
347 line_num += 1;
348 let line = line.trim().to_string();
349 if line.is_empty() {
350 continue;
351 }
352
353 let mut parts = line.split_whitespace();
354 let cmd = parts.next().unwrap();
355
356 match cmd {
357 "svm_type" => {
358 let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
359 format!("line {}: missing svm_type value", line_num),
360 ))?;
361 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
362 SvmError::ModelFormatError(format!("line {}: unknown svm_type: {}", line_num, val))
363 })?;
364 }
365 "kernel_type" => {
366 let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
367 format!("line {}: missing kernel_type value", line_num),
368 ))?;
369 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
370 SvmError::ModelFormatError(format!("line {}: unknown kernel_type: {}", line_num, val))
371 })?;
372 }
373 "degree" => {
374 param.degree = parse_single(&mut parts, line_num, "degree")?;
375 }
376 "gamma" => {
377 param.gamma = parse_single(&mut parts, line_num, "gamma")?;
378 }
379 "coef0" => {
380 param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
381 }
382 "nr_class" => {
383 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
384 }
385 "total_sv" => {
386 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
387 }
388 "rho" => {
389 rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
390 }
391 "label" => {
392 label = parse_multiple_i32(&mut parts, line_num, "label")?;
393 }
394 "probA" => {
395 prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
396 }
397 "probB" => {
398 prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
399 }
400 "prob_density_marks" => {
401 prob_density_marks = parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
402 }
403 "nr_sv" => {
404 n_sv = parts
405 .map(|s| {
406 s.parse::<usize>().map_err(|_| {
407 SvmError::ModelFormatError(format!(
408 "line {}: invalid nr_sv value: {}",
409 line_num, s
410 ))
411 })
412 })
413 .collect::<Result<Vec<_>, _>>()?;
414 }
415 "SV" => break,
416 _ => {
417 return Err(SvmError::ModelFormatError(format!(
418 "line {}: unknown keyword: {}",
419 line_num, cmd
420 )));
421 }
422 }
423 }
424
425 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
427 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
428 let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
429
430 for _ in 0..total_sv {
431 let line = lines
432 .next()
433 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))??;
434 line_num += 1;
435 let line = line.trim();
436 if line.is_empty() {
437 continue;
438 }
439
440 let mut parts = line.split_whitespace();
441
442 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
444 let val_str = parts.next().ok_or_else(|| SvmError::ModelFormatError(
445 format!("line {}: missing sv_coef[{}]", line_num, k),
446 ))?;
447 let val: f64 = val_str.parse().map_err(|_| SvmError::ModelFormatError(
448 format!("line {}: invalid sv_coef: {}", line_num, val_str),
449 ))?;
450 coef_row.push(val);
451 }
452
453 let mut nodes = Vec::new();
455 for token in parts {
456 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
457 SvmError::ModelFormatError(format!(
458 "line {}: expected index:value, got: {}",
459 line_num, token
460 ))
461 })?;
462 let index: i32 = idx_str.parse().map_err(|_| {
463 SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
464 })?;
465 let value: f64 = val_str.parse().map_err(|_| {
466 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
467 })?;
468 nodes.push(SvmNode { index, value });
469 }
470 sv.push(nodes);
471 }
472
473 Ok(SvmModel {
474 param,
475 nr_class,
476 sv,
477 sv_coef,
478 rho,
479 prob_a,
480 prob_b,
481 prob_density_marks,
482 sv_indices: Vec::new(), label,
484 n_sv,
485 })
486}
487
488fn parse_single<T: std::str::FromStr>(
491 parts: &mut std::str::SplitWhitespace<'_>,
492 line_num: usize,
493 field: &str,
494) -> Result<T, SvmError> {
495 let val_str = parts.next().ok_or_else(|| {
496 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
497 })?;
498 val_str.parse().map_err(|_| {
499 SvmError::ModelFormatError(format!("line {}: invalid {} value: {}", line_num, field, val_str))
500 })
501}
502
503fn parse_multiple_f64(
504 parts: &mut std::str::SplitWhitespace<'_>,
505 line_num: usize,
506 field: &str,
507) -> Result<Vec<f64>, SvmError> {
508 parts
509 .map(|s| {
510 s.parse::<f64>().map_err(|_| {
511 SvmError::ModelFormatError(format!(
512 "line {}: invalid {} value: {}",
513 line_num, field, s
514 ))
515 })
516 })
517 .collect()
518}
519
520fn parse_multiple_i32(
521 parts: &mut std::str::SplitWhitespace<'_>,
522 line_num: usize,
523 field: &str,
524) -> Result<Vec<i32>, SvmError> {
525 parts
526 .map(|s| {
527 s.parse::<i32>().map_err(|_| {
528 SvmError::ModelFormatError(format!(
529 "line {}: invalid {} value: {}",
530 line_num, field, s
531 ))
532 })
533 })
534 .collect()
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::path::PathBuf;

    /// Fixture directory: `data`, two levels above the crate manifest.
    fn data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("..");
        dir.push("..");
        dir.push("data");
        dir
    }

    /// Loads the named problem fixture, panicking when it cannot be parsed.
    fn fixture_problem(name: &str) -> SvmProblem {
        load_problem(&data_dir().join(name)).unwrap()
    }

    #[test]
    fn parse_heart_scale() {
        let problem = fixture_problem("heart_scale");
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        assert_eq!(problem.labels[0], 1.0);

        let first = &problem.instances[0];
        assert_eq!(first[0], SvmNode { index: 1, value: 0.708333 });
        assert_eq!(first.len(), 12);
    }

    #[test]
    fn parse_iris() {
        let problem = fixture_problem("iris.scale");
        assert_eq!(problem.labels.len(), 150);

        // Count the distinct labels after truncating each to an integer.
        let mut classes: HashSet<i64> = HashSet::new();
        for &label in &problem.labels {
            classes.insert(label as i64);
        }
        assert_eq!(classes.len(), 3);
    }

    #[test]
    fn parse_housing() {
        let problem = fixture_problem("housing_scale");
        assert_eq!(problem.labels.len(), 506);
        let first_label = problem.labels[0];
        assert!((first_label - 24.0).abs() < 1e-10);
    }

    #[test]
    fn parse_empty_lines() {
        let problem = load_problem_from_reader(&b"+1 1:0.5\n\n-1 2:0.3\n"[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    #[test]
    fn parse_error_unsorted_indices() {
        let result = load_problem_from_reader(&b"+1 3:0.5 1:0.3\n"[..]);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    #[test]
    fn parse_error_duplicate_indices() {
        let outcome = load_problem_from_reader(&b"+1 1:0.5 1:0.3\n"[..]);
        assert!(outcome.is_err());
    }

    #[test]
    fn parse_error_missing_colon() {
        let outcome = load_problem_from_reader(&b"+1 1:0.5 bad_token\n"[..]);
        assert!(outcome.is_err());
    }

    #[test]
    fn load_c_trained_model() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let param = &model.param;

        assert_eq!(model.nr_class, 2);
        assert_eq!(param.svm_type, SvmType::CSvc);
        assert_eq!(param.kernel_type, KernelType::Rbf);
        assert!((param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    #[test]
    fn roundtrip_c_model() {
        let path = data_dir().join("heart_scale.model");
        let original_text = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut written = Vec::new();
        save_model_to_writer(&mut written, &model).unwrap();
        let rust_text = String::from_utf8(written).unwrap();

        // Compare the reference file with the re-serialised model line by line.
        let orig_lines: Vec<&str> = original_text.lines().collect();
        let rust_lines: Vec<&str> = rust_text.lines().collect();
        assert_eq!(
            orig_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            orig_lines.len(),
            rust_lines.len()
        );
        for (line_no, (o, r)) in (1..).zip(orig_lines.iter().zip(rust_lines.iter())) {
            assert_eq!(o, r, "line {} differs:\n C: {:?}\n Rust: {:?}", line_no, o, r);
        }
    }

    #[test]
    fn gfmt_matches_c_printf() {
        // (value, expected 17-digit rendering, expected 8-digit rendering)
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];
        for &(v, expected_17g, expected_8g) in cases {
            let got_17 = fmt_17g(v).to_string();
            let got_8 = fmt_8g(v).to_string();
            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
        }
    }

    #[test]
    fn model_roundtrip() {
        let node = |index: i32, value: f64| SvmNode { index, value };
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![
                vec![node(1, 0.5), node(3, -1.0)],
                vec![node(1, -0.25), node(2, 0.75)],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());

        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }

        let coef_pairs = loaded
            .sv_coef
            .iter()
            .zip(model.sv_coef.iter())
            .flat_map(|(row_a, row_b)| row_a.iter().zip(row_b.iter()));
        for (a, b) in coef_pairs {
            assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
        }
    }
}