Skip to main content

libsvm_rs/
io.rs

1//! I/O routines for LIBSVM problem and model files.
2//!
3//! File formats match the original LIBSVM exactly, ensuring cross-tool
4//! interoperability.
5
6use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12// ─── C-compatible %g formatting ─────────────────────────────────────
13//
// C's printf `%.Pg` format strips trailing zeros and picks fixed vs.
// scientific notation based on the exponent. Rust has no built-in
// equivalent, so we replicate the POSIX specification, where X is the
// decimal exponent of the value *after* rounding to P significant digits
// (a precision of 0 counts as 1):
//   - Use scientific notation if X < -4 or X >= P
//   - Otherwise use fixed notation with P - 1 - X fractional digits
//   - Strip trailing fractional zeros (and a trailing decimal point)
20
21use std::fmt;
22
/// Formats an `f64` the way C's `printf("%.Pg", v)` does, for any precision `P`.
///
/// Used when writing model files so numeric text matches the original C
/// LIBSVM output (`%.17g` for coefficients, `%.8g` for feature values).
#[derive(Debug, Clone, Copy)]
struct Gfmt {
    /// The value to format.
    value: f64,
    /// Significant digits (`P` in `%.Pg`); 0 is treated as 1, as in C.
    precision: usize,
}

impl Gfmt {
    /// Creates a formatter for `value` with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }
}

impl fmt::Display for Gfmt {
    /// Writes the value following the POSIX rules for the `%g` conversion:
    ///
    /// 1. Render in style `e` with `P - 1` fractional digits and take the
    ///    exponent `X` of that *rounded* value (rounding can bump it, e.g.
    ///    `9.9999999999` → `1.0000000e1`).
    /// 2. If `P > X >= -4`, use style `f` with `P - 1 - X` fractional digits;
    ///    otherwise keep style `e`, with the exponent padded to two digits.
    /// 3. Remove trailing fractional zeros and a dangling decimal point. Digits
    ///    of an integer part are never stripped.
    ///
    /// Non-finite values print as `inf`, `-inf`, `nan`, or `-nan` (sign bit
    /// set), matching glibc's lowercase output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Trims zeros only after a decimal point, so "10000000" stays intact.
        fn strip_fraction_zeros(s: &str) -> &str {
            if s.contains('.') {
                s.trim_end_matches('0').trim_end_matches('.')
            } else {
                s
            }
        }

        let v = self.value;
        // POSIX: "if the precision is zero, it shall be taken as 1".
        let p = self.precision.max(1);

        if !v.is_finite() {
            let text = match (v.is_nan(), v.is_sign_negative()) {
                (true, false) => "nan",
                (true, true) => "-nan",
                (false, false) => "inf",
                (false, true) => "-inf",
            };
            return f.write_str(text);
        }

        if v == 0.0 {
            // Preserve the sign of -0.0, as C does.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // Style-e rendering rounded to P significant digits; its exponent is
        // the X that POSIX uses to choose between the two styles.
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exp_text) = sci
            .split_once('e')
            .expect("LowerExp output of a finite f64 always contains 'e'");
        let exp: i32 = exp_text
            .parse()
            .expect("LowerExp exponent of a finite f64 is a valid i32");

        let use_scientific = exp < -4 || (exp >= 0 && exp as usize >= p);
        if use_scientific {
            // C zero-pads the exponent to at least 2 digits (e-05, not e-5).
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                strip_fraction_zeros(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Here -4 <= X < P, so P - 1 - X is non-negative.
            let decimals = if exp >= 0 {
                p - 1 - exp as usize
            } else {
                p - 1 + exp.unsigned_abs() as usize
            };
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(strip_fraction_zeros(&fixed))
        }
    }
}

/// Formats like C's `%.17g` (enough digits to round-trip an `f64`).
fn fmt_17g(v: f64) -> Gfmt {
    Gfmt::new(v, 17)
}

/// Formats like C's `%.8g`.
fn fmt_8g(v: f64) -> Gfmt {
    Gfmt::new(v, 8)
}
97
98// ─── String tables matching original LIBSVM ──────────────────────────
99
/// Model-file names of the SVM types.
///
/// `svm_type_to_str` indexes this with `t as usize`, so the entries must follow
/// the discriminant order of `SvmType`.
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];

/// Model-file names of the kernel types.
///
/// `kernel_type_to_str` indexes this with `t as usize`, so the entries must
/// follow the discriminant order of `KernelType`.
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
102
103fn svm_type_to_str(t: SvmType) -> &'static str {
104    SVM_TYPE_TABLE[t as usize]
105}
106
107fn kernel_type_to_str(t: KernelType) -> &'static str {
108    KERNEL_TYPE_TABLE[t as usize]
109}
110
111fn str_to_svm_type(s: &str) -> Option<SvmType> {
112    match s {
113        "c_svc" => Some(SvmType::CSvc),
114        "nu_svc" => Some(SvmType::NuSvc),
115        "one_class" => Some(SvmType::OneClass),
116        "epsilon_svr" => Some(SvmType::EpsilonSvr),
117        "nu_svr" => Some(SvmType::NuSvr),
118        _ => None,
119    }
120}
121
122fn str_to_kernel_type(s: &str) -> Option<KernelType> {
123    match s {
124        "linear" => Some(KernelType::Linear),
125        "polynomial" => Some(KernelType::Polynomial),
126        "rbf" => Some(KernelType::Rbf),
127        "sigmoid" => Some(KernelType::Sigmoid),
128        "precomputed" => Some(KernelType::Precomputed),
129        _ => None,
130    }
131}
132
133// ─── Problem file I/O ────────────────────────────────────────────────
134
135/// Load an SVM problem from a file in LIBSVM sparse format.
136///
137/// Format: `<label> <index1>:<value1> <index2>:<value2> ...`
138pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
139    let file = std::fs::File::open(path)?;
140    let reader = std::io::BufReader::new(file);
141    load_problem_from_reader(reader)
142}
143
144/// Load an SVM problem from any buffered reader.
145pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
146    let mut labels = Vec::new();
147    let mut instances = Vec::new();
148
149    for (line_idx, line_result) in reader.lines().enumerate() {
150        let line = line_result?;
151        let line = line.trim();
152        if line.is_empty() {
153            continue;
154        }
155
156        let line_num = line_idx + 1;
157        let mut parts = line.split_whitespace();
158
159        // Parse label
160        let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
161            line: line_num,
162            message: "missing label".into(),
163        })?;
164        let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
165            line: line_num,
166            message: format!("invalid label: {}", label_str),
167        })?;
168
169        // Parse features (must be in ascending index order)
170        let mut nodes = Vec::new();
171        let mut prev_index: i32 = 0;
172        for token in parts {
173            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
174                line: line_num,
175                message: format!("expected index:value, got: {}", token),
176            })?;
177            let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
178                line: line_num,
179                message: format!("invalid index: {}", idx_str),
180            })?;
181            if !nodes.is_empty() && index <= prev_index {
182                return Err(SvmError::ParseError {
183                    line: line_num,
184                    message: format!(
185                        "feature indices must be ascending: {} follows {}",
186                        index, prev_index
187                    ),
188                });
189            }
190            let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
191                line: line_num,
192                message: format!("invalid value: {}", val_str),
193            })?;
194            prev_index = index;
195            nodes.push(SvmNode { index, value });
196        }
197
198        labels.push(label);
199        instances.push(nodes);
200    }
201
202    Ok(SvmProblem { labels, instances })
203}
204
205// ─── Model file I/O ──────────────────────────────────────────────────
206
207/// Save an SVM model to a file in the original LIBSVM format.
208pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
209    let file = std::fs::File::create(path)?;
210    let writer = std::io::BufWriter::new(file);
211    save_model_to_writer(writer, model)
212}
213
214/// Save an SVM model to any writer.
215pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
216    let param = &model.param;
217
218    writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
219    writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
220
221    if param.kernel_type == KernelType::Polynomial {
222        writeln!(w, "degree {}", param.degree)?;
223    }
224    if matches!(
225        param.kernel_type,
226        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
227    ) {
228        writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
229    }
230    if matches!(
231        param.kernel_type,
232        KernelType::Polynomial | KernelType::Sigmoid
233    ) {
234        writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
235    }
236
237    let nr_class = model.nr_class;
238    writeln!(w, "nr_class {}", nr_class)?;
239    writeln!(w, "total_sv {}", model.sv.len())?;
240
241    // rho
242    write!(w, "rho")?;
243    for r in &model.rho {
244        write!(w, " {}", fmt_17g(*r))?;
245    }
246    writeln!(w)?;
247
248    // label (classification only)
249    if !model.label.is_empty() {
250        write!(w, "label")?;
251        for l in &model.label {
252            write!(w, " {}", l)?;
253        }
254        writeln!(w)?;
255    }
256
257    // probA
258    if !model.prob_a.is_empty() {
259        write!(w, "probA")?;
260        for v in &model.prob_a {
261            write!(w, " {}", fmt_17g(*v))?;
262        }
263        writeln!(w)?;
264    }
265
266    // probB
267    if !model.prob_b.is_empty() {
268        write!(w, "probB")?;
269        for v in &model.prob_b {
270            write!(w, " {}", fmt_17g(*v))?;
271        }
272        writeln!(w)?;
273    }
274
275    // prob_density_marks (one-class)
276    if !model.prob_density_marks.is_empty() {
277        write!(w, "prob_density_marks")?;
278        for v in &model.prob_density_marks {
279            write!(w, " {}", fmt_17g(*v))?;
280        }
281        writeln!(w)?;
282    }
283
284    // nr_sv
285    if !model.n_sv.is_empty() {
286        write!(w, "nr_sv")?;
287        for n in &model.n_sv {
288            write!(w, " {}", n)?;
289        }
290        writeln!(w)?;
291    }
292
293    // SV section
294    writeln!(w, "SV")?;
295    let num_sv = model.sv.len();
296    let num_coef_rows = model.sv_coef.len(); // nr_class - 1
297
298    for i in 0..num_sv {
299        // sv_coef columns for this SV: %.17g
300        for j in 0..num_coef_rows {
301            write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
302        }
303        // sparse features: %.8g
304        if model.param.kernel_type == KernelType::Precomputed {
305            if let Some(node) = model.sv[i].first() {
306                write!(w, "0:{} ", node.value as i32)?;
307            }
308        } else {
309            for node in &model.sv[i] {
310                write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
311            }
312        }
313        writeln!(w)?;
314    }
315
316    Ok(())
317}
318
319/// Load an SVM model from a file in the original LIBSVM format.
320pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
321    let file = std::fs::File::open(path)?;
322    let reader = std::io::BufReader::new(file);
323    load_model_from_reader(reader)
324}
325
326/// Load an SVM model from any buffered reader.
327pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
328    let mut lines = reader.lines();
329
330    // Defaults
331    let mut param = SvmParameter::default();
332    let mut nr_class: usize = 0;
333    let mut total_sv: usize = 0;
334    let mut rho = Vec::new();
335    let mut label = Vec::new();
336    let mut prob_a = Vec::new();
337    let mut prob_b = Vec::new();
338    let mut prob_density_marks = Vec::new();
339    let mut n_sv = Vec::new();
340
341    // Read header
342    let mut line_num: usize = 0;
343    loop {
344        let line = lines
345            .next()
346            .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))??;
347        line_num += 1;
348        let line = line.trim().to_string();
349        if line.is_empty() {
350            continue;
351        }
352
353        let mut parts = line.split_whitespace();
354        let cmd = parts.next().unwrap();
355
356        match cmd {
357            "svm_type" => {
358                let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
359                    format!("line {}: missing svm_type value", line_num),
360                ))?;
361                param.svm_type = str_to_svm_type(val).ok_or_else(|| {
362                    SvmError::ModelFormatError(format!("line {}: unknown svm_type: {}", line_num, val))
363                })?;
364            }
365            "kernel_type" => {
366                let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
367                    format!("line {}: missing kernel_type value", line_num),
368                ))?;
369                param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
370                    SvmError::ModelFormatError(format!("line {}: unknown kernel_type: {}", line_num, val))
371                })?;
372            }
373            "degree" => {
374                param.degree = parse_single(&mut parts, line_num, "degree")?;
375            }
376            "gamma" => {
377                param.gamma = parse_single(&mut parts, line_num, "gamma")?;
378            }
379            "coef0" => {
380                param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
381            }
382            "nr_class" => {
383                nr_class = parse_single(&mut parts, line_num, "nr_class")?;
384            }
385            "total_sv" => {
386                total_sv = parse_single(&mut parts, line_num, "total_sv")?;
387            }
388            "rho" => {
389                rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
390            }
391            "label" => {
392                label = parse_multiple_i32(&mut parts, line_num, "label")?;
393            }
394            "probA" => {
395                prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
396            }
397            "probB" => {
398                prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
399            }
400            "prob_density_marks" => {
401                prob_density_marks = parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
402            }
403            "nr_sv" => {
404                n_sv = parts
405                    .map(|s| {
406                        s.parse::<usize>().map_err(|_| {
407                            SvmError::ModelFormatError(format!(
408                                "line {}: invalid nr_sv value: {}",
409                                line_num, s
410                            ))
411                        })
412                    })
413                    .collect::<Result<Vec<_>, _>>()?;
414            }
415            "SV" => break,
416            _ => {
417                return Err(SvmError::ModelFormatError(format!(
418                    "line {}: unknown keyword: {}",
419                    line_num, cmd
420                )));
421            }
422        }
423    }
424
425    // Read SV section
426    let m = if nr_class > 1 { nr_class - 1 } else { 1 };
427    let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
428    let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
429
430    for _ in 0..total_sv {
431        let line = lines
432            .next()
433            .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))??;
434        line_num += 1;
435        let line = line.trim();
436        if line.is_empty() {
437            continue;
438        }
439
440        let mut parts = line.split_whitespace();
441
442        // First m tokens are sv_coef values
443        for (k, coef_row) in sv_coef.iter_mut().enumerate() {
444            let val_str = parts.next().ok_or_else(|| SvmError::ModelFormatError(
445                format!("line {}: missing sv_coef[{}]", line_num, k),
446            ))?;
447            let val: f64 = val_str.parse().map_err(|_| SvmError::ModelFormatError(
448                format!("line {}: invalid sv_coef: {}", line_num, val_str),
449            ))?;
450            coef_row.push(val);
451        }
452
453        // Remaining tokens are index:value pairs
454        let mut nodes = Vec::new();
455        for token in parts {
456            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
457                SvmError::ModelFormatError(format!(
458                    "line {}: expected index:value, got: {}",
459                    line_num, token
460                ))
461            })?;
462            let index: i32 = idx_str.parse().map_err(|_| {
463                SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
464            })?;
465            let value: f64 = val_str.parse().map_err(|_| {
466                SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
467            })?;
468            nodes.push(SvmNode { index, value });
469        }
470        sv.push(nodes);
471    }
472
473    Ok(SvmModel {
474        param,
475        nr_class,
476        sv,
477        sv_coef,
478        rho,
479        prob_a,
480        prob_b,
481        prob_density_marks,
482        sv_indices: Vec::new(), // not stored in model file
483        label,
484        n_sv,
485    })
486}
487
488// ─── Helper parsers ──────────────────────────────────────────────────
489
490fn parse_single<T: std::str::FromStr>(
491    parts: &mut std::str::SplitWhitespace<'_>,
492    line_num: usize,
493    field: &str,
494) -> Result<T, SvmError> {
495    let val_str = parts.next().ok_or_else(|| {
496        SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
497    })?;
498    val_str.parse().map_err(|_| {
499        SvmError::ModelFormatError(format!("line {}: invalid {} value: {}", line_num, field, val_str))
500    })
501}
502
503fn parse_multiple_f64(
504    parts: &mut std::str::SplitWhitespace<'_>,
505    line_num: usize,
506    field: &str,
507) -> Result<Vec<f64>, SvmError> {
508    parts
509        .map(|s| {
510            s.parse::<f64>().map_err(|_| {
511                SvmError::ModelFormatError(format!(
512                    "line {}: invalid {} value: {}",
513                    line_num, field, s
514                ))
515            })
516        })
517        .collect()
518}
519
520fn parse_multiple_i32(
521    parts: &mut std::str::SplitWhitespace<'_>,
522    line_num: usize,
523    field: &str,
524) -> Result<Vec<i32>, SvmError> {
525    parts
526        .map(|s| {
527            s.parse::<i32>().map_err(|_| {
528                SvmError::ModelFormatError(format!(
529                    "line {}: invalid {} value: {}",
530                    line_num, field, s
531                ))
532            })
533        })
534        .collect()
535}
536
537// ─── Tests ───────────────────────────────────────────────────────────
538
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Location of the shared LIBSVM fixture files: `<manifest>/../../data`.
    fn data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        for component in ["..", "..", "data"].iter() {
            dir.push(component);
        }
        dir
    }

    #[test]
    fn parse_heart_scale() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        // Row 1 is labelled +1 and carries 12 features (index 11 is absent).
        let first_row = &problem.instances[0];
        assert_eq!(problem.labels[0], 1.0);
        assert_eq!(first_row[0], SvmNode { index: 1, value: 0.708333 });
        assert_eq!(first_row.len(), 12);
    }

    #[test]
    fn parse_iris() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        assert_eq!(problem.labels.len(), 150);
        // Three classes: 1, 2, 3.
        let distinct: std::collections::BTreeSet<i64> =
            problem.labels.iter().map(|&l| l as i64).collect();
        assert_eq!(distinct.len(), 3);
    }

    #[test]
    fn parse_housing() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        assert_eq!(problem.labels.len(), 506);
        // Regression targets are continuous; the first one is 24.
        let first = problem.labels[0];
        assert!((first - 24.0).abs() < 1e-10);
    }

    #[test]
    fn parse_empty_lines() {
        let problem = load_problem_from_reader(&b"+1 1:0.5\n\n-1 2:0.3\n"[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    #[test]
    fn parse_error_unsorted_indices() {
        let outcome = load_problem_from_reader(&b"+1 3:0.5 1:0.3\n"[..]);
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err().to_string();
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    #[test]
    fn parse_error_duplicate_indices() {
        assert!(load_problem_from_reader(&b"+1 1:0.5 1:0.3\n"[..]).is_err());
    }

    #[test]
    fn parse_error_missing_colon() {
        assert!(load_problem_from_reader(&b"+1 1:0.5 bad_token\n"[..]).is_err());
    }

    #[test]
    fn load_c_trained_model() {
        // heart_scale.model was written by the original C `svm-train`.
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        // Two classes give one sv_coef row (nr_class - 1) with one entry per SV.
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    #[test]
    fn roundtrip_c_model() {
        // Load the C-written model, write it back, and require identical lines.
        let path = data_dir().join("heart_scale.model");
        let c_text = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut out = Vec::new();
        save_model_to_writer(&mut out, &model).unwrap();
        let rust_text = String::from_utf8(out).unwrap();

        // Line-by-line comparison gives clearer failure messages.
        let c_lines = c_text.lines().collect::<Vec<_>>();
        let rust_lines = rust_text.lines().collect::<Vec<_>>();
        assert_eq!(
            c_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            c_lines.len(),
            rust_lines.len()
        );
        for (n, (c, r)) in c_lines.iter().zip(&rust_lines).enumerate() {
            assert_eq!(c, r, "line {} differs:\n  C:    {:?}\n  Rust: {:?}", n + 1, c, r);
        }
    }

    #[test]
    fn gfmt_matches_c_printf() {
        // Expected strings come from C's printf("%.17g|%.8g\n", v, v).
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];
        for &(v, want_17g, want_8g) in cases.iter() {
            assert_eq!(fmt_17g(v).to_string(), want_17g, "%.17g mismatch for {}", v);
            assert_eq!(fmt_8g(v).to_string(), want_8g, "%.8g mismatch for {}", v);
        }
    }

    #[test]
    fn model_roundtrip() {
        // A minimal two-class RBF model must survive save → load.
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        let model = SvmModel {
            param,
            nr_class: 2,
            sv: vec![
                vec![SvmNode { index: 1, value: 0.5 }, SvmNode { index: 3, value: -1.0 }],
                vec![SvmNode { index: 1, value: -0.25 }, SvmNode { index: 2, value: 0.75 }],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());
        // Text round-trips may lose a last bit, so compare with a tolerance.
        for (a, b) in loaded.rho.iter().zip(&model.rho) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        for (row_a, row_b) in loaded.sv_coef.iter().zip(&model.sv_coef) {
            for (a, b) in row_a.iter().zip(row_b) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }
}