1use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12use std::fmt;
22
/// `Display` adapter that renders an `f64` exactly like C's `printf("%.<precision>g", value)`.
///
/// LIBSVM writes model files with `%.17g` and `%.8g`; matching that output lets
/// models saved here be compared byte-for-byte with files written by the C tools.
struct Gfmt {
    /// The number to render.
    value: f64,
    /// Significant digits (the `%g` precision). As in C, 0 is treated as 1.
    precision: usize,
}

impl Gfmt {
    /// Wraps `value` for `%g`-style output with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Drops trailing fractional zeros and then a dangling `.`, as `%g` does.
    ///
    /// Only applied when `s` contains a decimal point, so integral digits
    /// (e.g. the zeros of `"12345670"`) are never removed.
    fn strip_trailing_zeros(s: &str) -> &str {
        if s.contains('.') {
            s.trim_end_matches('0').trim_end_matches('.')
        } else {
            s
        }
    }
}

impl fmt::Display for Gfmt {
    /// Implements the C rule for `%g` with precision `P` (0 treated as 1).
    ///
    /// Let `X` be the decimal exponent of the value *after* rounding to `P`
    /// significant digits. If `P > X >= -4`, use fixed notation with
    /// `P - 1 - X` decimals. Otherwise use scientific notation with `P - 1`
    /// decimals and a signed exponent of at least two digits. In both cases,
    /// remove trailing fractional zeros.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        let p = self.precision.max(1);

        // Non-finite values use C's lowercase spellings.
        if v.is_nan() {
            return f.write_str("nan");
        }
        if v.is_infinite() {
            return f.write_str(if v < 0.0 { "-inf" } else { "inf" });
        }
        // Zero has no exponent; C keeps the sign of negative zero.
        if v == 0.0 {
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // `{:e}` rounds to the requested number of digits, so its exponent is
        // already X from the rule above (e.g. 99999999.7 at P = 8 yields
        // `1.0000000e8`, so X = 8 and scientific style is used).
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exp) = sci
            .split_once('e')
            .and_then(|(m, e)| e.parse::<i32>().ok().map(|x| (m, x)))
            .expect("`{:e}` output always has the form <mantissa>e<exponent>");

        if exp < -4 || exp >= p as i32 {
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                Self::strip_trailing_zeros(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Here -4 <= exp < p, so the decimal count is non-negative.
            let decimals = (p as i32 - 1 - exp) as usize;
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(Self::strip_trailing_zeros(&fixed))
        }
    }
}
87
88fn fmt_17g(v: f64) -> Gfmt {
90 Gfmt::new(v, 17)
91}
92
93fn fmt_8g(v: f64) -> Gfmt {
95 Gfmt::new(v, 8)
96}
97
98const SVM_TYPE_TABLE: &[&str] = &["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"];
101const KERNEL_TYPE_TABLE: &[&str] = &["linear", "polynomial", "rbf", "sigmoid", "precomputed"];
102
103fn svm_type_to_str(t: SvmType) -> &'static str {
104 SVM_TYPE_TABLE[t as usize]
105}
106
107fn kernel_type_to_str(t: KernelType) -> &'static str {
108 KERNEL_TYPE_TABLE[t as usize]
109}
110
111fn str_to_svm_type(s: &str) -> Option<SvmType> {
112 match s {
113 "c_svc" => Some(SvmType::CSvc),
114 "nu_svc" => Some(SvmType::NuSvc),
115 "one_class" => Some(SvmType::OneClass),
116 "epsilon_svr" => Some(SvmType::EpsilonSvr),
117 "nu_svr" => Some(SvmType::NuSvr),
118 _ => None,
119 }
120}
121
122fn str_to_kernel_type(s: &str) -> Option<KernelType> {
123 match s {
124 "linear" => Some(KernelType::Linear),
125 "polynomial" => Some(KernelType::Polynomial),
126 "rbf" => Some(KernelType::Rbf),
127 "sigmoid" => Some(KernelType::Sigmoid),
128 "precomputed" => Some(KernelType::Precomputed),
129 _ => None,
130 }
131}
132
133pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
139 let file = std::fs::File::open(path)?;
140 let reader = std::io::BufReader::new(file);
141 load_problem_from_reader(reader)
142}
143
144pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
146 let mut labels = Vec::new();
147 let mut instances = Vec::new();
148
149 for (line_idx, line_result) in reader.lines().enumerate() {
150 let line = line_result?;
151 let line = line.trim();
152 if line.is_empty() {
153 continue;
154 }
155
156 let line_num = line_idx + 1;
157 let mut parts = line.split_whitespace();
158
159 let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
161 line: line_num,
162 message: "missing label".into(),
163 })?;
164 let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
165 line: line_num,
166 message: format!("invalid label: {}", label_str),
167 })?;
168
169 let mut nodes = Vec::new();
171 for token in parts {
172 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
173 line: line_num,
174 message: format!("expected index:value, got: {}", token),
175 })?;
176 let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
177 line: line_num,
178 message: format!("invalid index: {}", idx_str),
179 })?;
180 let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
181 line: line_num,
182 message: format!("invalid value: {}", val_str),
183 })?;
184 nodes.push(SvmNode { index, value });
185 }
186
187 labels.push(label);
188 instances.push(nodes);
189 }
190
191 Ok(SvmProblem { labels, instances })
192}
193
194pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
198 let file = std::fs::File::create(path)?;
199 let writer = std::io::BufWriter::new(file);
200 save_model_to_writer(writer, model)
201}
202
203pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
205 let param = &model.param;
206
207 writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
208 writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
209
210 if param.kernel_type == KernelType::Polynomial {
211 writeln!(w, "degree {}", param.degree)?;
212 }
213 if matches!(
214 param.kernel_type,
215 KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
216 ) {
217 writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
218 }
219 if matches!(
220 param.kernel_type,
221 KernelType::Polynomial | KernelType::Sigmoid
222 ) {
223 writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
224 }
225
226 let nr_class = model.nr_class;
227 writeln!(w, "nr_class {}", nr_class)?;
228 writeln!(w, "total_sv {}", model.sv.len())?;
229
230 write!(w, "rho")?;
232 for r in &model.rho {
233 write!(w, " {}", fmt_17g(*r))?;
234 }
235 writeln!(w)?;
236
237 if !model.label.is_empty() {
239 write!(w, "label")?;
240 for l in &model.label {
241 write!(w, " {}", l)?;
242 }
243 writeln!(w)?;
244 }
245
246 if !model.prob_a.is_empty() {
248 write!(w, "probA")?;
249 for v in &model.prob_a {
250 write!(w, " {}", fmt_17g(*v))?;
251 }
252 writeln!(w)?;
253 }
254
255 if !model.prob_b.is_empty() {
257 write!(w, "probB")?;
258 for v in &model.prob_b {
259 write!(w, " {}", fmt_17g(*v))?;
260 }
261 writeln!(w)?;
262 }
263
264 if !model.prob_density_marks.is_empty() {
266 write!(w, "prob_density_marks")?;
267 for v in &model.prob_density_marks {
268 write!(w, " {}", fmt_17g(*v))?;
269 }
270 writeln!(w)?;
271 }
272
273 if !model.n_sv.is_empty() {
275 write!(w, "nr_sv")?;
276 for n in &model.n_sv {
277 write!(w, " {}", n)?;
278 }
279 writeln!(w)?;
280 }
281
282 writeln!(w, "SV")?;
284 let num_sv = model.sv.len();
285 let num_coef_rows = model.sv_coef.len(); for i in 0..num_sv {
288 for j in 0..num_coef_rows {
290 write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
291 }
292 if model.param.kernel_type == KernelType::Precomputed {
294 if let Some(node) = model.sv[i].first() {
295 write!(w, "0:{} ", node.value as i32)?;
296 }
297 } else {
298 for node in &model.sv[i] {
299 write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
300 }
301 }
302 writeln!(w)?;
303 }
304
305 Ok(())
306}
307
308pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
310 let file = std::fs::File::open(path)?;
311 let reader = std::io::BufReader::new(file);
312 load_model_from_reader(reader)
313}
314
315pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
317 let mut lines = reader.lines();
318
319 let mut param = SvmParameter::default();
321 let mut nr_class: usize = 0;
322 let mut total_sv: usize = 0;
323 let mut rho = Vec::new();
324 let mut label = Vec::new();
325 let mut prob_a = Vec::new();
326 let mut prob_b = Vec::new();
327 let mut prob_density_marks = Vec::new();
328 let mut n_sv = Vec::new();
329
330 let mut line_num: usize = 0;
332 loop {
333 let line = lines
334 .next()
335 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))??;
336 line_num += 1;
337 let line = line.trim().to_string();
338 if line.is_empty() {
339 continue;
340 }
341
342 let mut parts = line.split_whitespace();
343 let cmd = parts.next().unwrap();
344
345 match cmd {
346 "svm_type" => {
347 let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
348 format!("line {}: missing svm_type value", line_num),
349 ))?;
350 param.svm_type = str_to_svm_type(val).ok_or_else(|| {
351 SvmError::ModelFormatError(format!("line {}: unknown svm_type: {}", line_num, val))
352 })?;
353 }
354 "kernel_type" => {
355 let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
356 format!("line {}: missing kernel_type value", line_num),
357 ))?;
358 param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
359 SvmError::ModelFormatError(format!("line {}: unknown kernel_type: {}", line_num, val))
360 })?;
361 }
362 "degree" => {
363 param.degree = parse_single(&mut parts, line_num, "degree")?;
364 }
365 "gamma" => {
366 param.gamma = parse_single(&mut parts, line_num, "gamma")?;
367 }
368 "coef0" => {
369 param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
370 }
371 "nr_class" => {
372 nr_class = parse_single(&mut parts, line_num, "nr_class")?;
373 }
374 "total_sv" => {
375 total_sv = parse_single(&mut parts, line_num, "total_sv")?;
376 }
377 "rho" => {
378 rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
379 }
380 "label" => {
381 label = parse_multiple_i32(&mut parts, line_num, "label")?;
382 }
383 "probA" => {
384 prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
385 }
386 "probB" => {
387 prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
388 }
389 "prob_density_marks" => {
390 prob_density_marks = parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
391 }
392 "nr_sv" => {
393 n_sv = parts
394 .map(|s| {
395 s.parse::<usize>().map_err(|_| {
396 SvmError::ModelFormatError(format!(
397 "line {}: invalid nr_sv value: {}",
398 line_num, s
399 ))
400 })
401 })
402 .collect::<Result<Vec<_>, _>>()?;
403 }
404 "SV" => break,
405 _ => {
406 return Err(SvmError::ModelFormatError(format!(
407 "line {}: unknown keyword: {}",
408 line_num, cmd
409 )));
410 }
411 }
412 }
413
414 let m = if nr_class > 1 { nr_class - 1 } else { 1 };
416 let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
417 let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
418
419 for _ in 0..total_sv {
420 let line = lines
421 .next()
422 .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))??;
423 line_num += 1;
424 let line = line.trim();
425 if line.is_empty() {
426 continue;
427 }
428
429 let mut parts = line.split_whitespace();
430
431 for (k, coef_row) in sv_coef.iter_mut().enumerate() {
433 let val_str = parts.next().ok_or_else(|| SvmError::ModelFormatError(
434 format!("line {}: missing sv_coef[{}]", line_num, k),
435 ))?;
436 let val: f64 = val_str.parse().map_err(|_| SvmError::ModelFormatError(
437 format!("line {}: invalid sv_coef: {}", line_num, val_str),
438 ))?;
439 coef_row.push(val);
440 }
441
442 let mut nodes = Vec::new();
444 for token in parts {
445 let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
446 SvmError::ModelFormatError(format!(
447 "line {}: expected index:value, got: {}",
448 line_num, token
449 ))
450 })?;
451 let index: i32 = idx_str.parse().map_err(|_| {
452 SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
453 })?;
454 let value: f64 = val_str.parse().map_err(|_| {
455 SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
456 })?;
457 nodes.push(SvmNode { index, value });
458 }
459 sv.push(nodes);
460 }
461
462 Ok(SvmModel {
463 param,
464 nr_class,
465 sv,
466 sv_coef,
467 rho,
468 prob_a,
469 prob_b,
470 prob_density_marks,
471 sv_indices: Vec::new(), label,
473 n_sv,
474 })
475}
476
477fn parse_single<T: std::str::FromStr>(
480 parts: &mut std::str::SplitWhitespace<'_>,
481 line_num: usize,
482 field: &str,
483) -> Result<T, SvmError> {
484 let val_str = parts.next().ok_or_else(|| {
485 SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
486 })?;
487 val_str.parse().map_err(|_| {
488 SvmError::ModelFormatError(format!("line {}: invalid {} value: {}", line_num, field, val_str))
489 })
490}
491
492fn parse_multiple_f64(
493 parts: &mut std::str::SplitWhitespace<'_>,
494 line_num: usize,
495 field: &str,
496) -> Result<Vec<f64>, SvmError> {
497 parts
498 .map(|s| {
499 s.parse::<f64>().map_err(|_| {
500 SvmError::ModelFormatError(format!(
501 "line {}: invalid {} value: {}",
502 line_num, field, s
503 ))
504 })
505 })
506 .collect()
507}
508
509fn parse_multiple_i32(
510 parts: &mut std::str::SplitWhitespace<'_>,
511 line_num: usize,
512 field: &str,
513) -> Result<Vec<i32>, SvmError> {
514 parts
515 .map(|s| {
516 s.parse::<i32>().map_err(|_| {
517 SvmError::ModelFormatError(format!(
518 "line {}: invalid {} value: {}",
519 line_num, field, s
520 ))
521 })
522 })
523 .collect()
524}
525
// Tests for this module's readers and writers. Fixture-based tests read sample
// files from the directory returned by `data_dir()`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Directory containing the sample data sets and models:
    /// `$CARGO_MANIFEST_DIR/../../data`.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// `heart_scale`: 270 instances. The first has label 1.0 and 12 features,
    /// beginning with `1:0.708333`.
    #[test]
    fn parse_heart_scale() {
        let path = data_dir().join("heart_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        assert_eq!(problem.labels[0], 1.0);
        assert_eq!(problem.instances[0][0], SvmNode { index: 1, value: 0.708333 });
        assert_eq!(problem.instances[0].len(), 12);
    }

    /// `iris.scale`: 150 instances with 3 distinct labels (compared after
    /// casting each label to `i64`).
    #[test]
    fn parse_iris() {
        let path = data_dir().join("iris.scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 150);
        let classes: std::collections::HashSet<i64> =
            problem.labels.iter().map(|&l| l as i64).collect();
        assert_eq!(classes.len(), 3);
    }

    /// `housing_scale`: 506 instances; the first label is 24.0.
    #[test]
    fn parse_housing() {
        let path = data_dir().join("housing_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 506);
        assert!((problem.labels[0] - 24.0).abs() < 1e-10);
    }

    /// Blank lines between instances are skipped rather than rejected.
    #[test]
    fn parse_empty_lines() {
        let input = b"+1 1:0.5\n\n-1 2:0.3\n";
        let problem = load_problem_from_reader(&input[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    /// A feature token that is not of the form `index:value` is an error.
    #[test]
    fn parse_error_missing_colon() {
        let input = b"+1 1:0.5 bad_token\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
    }

    /// Loads the reference `heart_scale.model` (C-SVC, RBF kernel, 2 classes,
    /// 132 SVs) and checks the header fields and the size of the SV section.
    #[test]
    fn load_c_trained_model() {
        let path = data_dir().join("heart_scale.model");
        let model = load_model(&path).unwrap();
        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        // Two classes => a single row of coefficients, one per SV.
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    /// Loading and re-saving the reference model must reproduce the original
    /// file line for line.
    #[test]
    fn roundtrip_c_model() {
        let path = data_dir().join("heart_scale.model");
        let original_bytes = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let rust_output = String::from_utf8(buf).unwrap();

        // Compare line by line so a failure reports the first differing line.
        let orig_lines: Vec<&str> = original_bytes.lines().collect();
        let rust_lines: Vec<&str> = rust_output.lines().collect();
        assert_eq!(
            orig_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            orig_lines.len(),
            rust_lines.len()
        );
        for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
            assert_eq!(o, r, "line {} differs:\n  C:    {:?}\n  Rust: {:?}", i + 1, o, r);
        }
    }

    /// Each case is `(value, expected %.17g output, expected %.8g output)`.
    #[test]
    fn gfmt_matches_c_printf() {
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];
        for &(v, expected_17g, expected_8g) in cases {
            let got_17 = format!("{}", fmt_17g(v));
            let got_8 = format!("{}", fmt_8g(v));
            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
        }
    }

    /// Saves an in-memory model and reads it back. Structural fields must match
    /// exactly; `rho` and `sv_coef` values must agree within 1e-10.
    #[test]
    fn model_roundtrip() {
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![
                vec![SvmNode { index: 1, value: 0.5 }, SvmNode { index: 3, value: -1.0 }],
                vec![SvmNode { index: 1, value: -0.25 }, SvmNode { index: 2, value: 0.75 }],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();

        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());
        // Reals go through text formatting, so compare with a tolerance.
        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
            for (a, b) in row_a.iter().zip(row_b.iter()) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }
}