Skip to main content

libsvm_rs/
io.rs

1//! I/O routines for LIBSVM problem and model files.
2//!
3//! File formats match the original LIBSVM exactly, ensuring cross-tool
4//! interoperability.
5
6use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12// ─── C-compatible %g formatting ─────────────────────────────────────
13//
14// C's printf `%.Pg` format strips trailing zeros and picks fixed vs.
15// scientific notation based on the exponent. Rust has no built-in
16// equivalent, so we replicate the POSIX specification:
17//   - Use scientific if exponent < -4 or exponent >= precision
18//   - Otherwise use fixed notation
19//   - Strip trailing zeros (and trailing decimal point)
20
21use std::fmt;
22
/// Formats an `f64` like C's `printf("%.<precision>g", value)`.
///
/// The `g` conversion rules replicated here:
///   - a precision of 0 is treated as 1;
///   - let `X` be the decimal exponent of the value *after* rounding to
///     `precision` significant digits; scientific notation is used when
///     `X < -4` or `X >= precision`, fixed notation otherwise;
///   - trailing zeros of the fractional part (and a dangling decimal point)
///     are removed;
///   - the exponent is written with a sign and at least two digits.
struct Gfmt {
    value: f64,
    precision: usize,
}

impl Gfmt {
    /// Wraps `value` so that `Display` renders it with `precision`
    /// significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Removes trailing zeros of the fractional part, and the decimal point
    /// itself if nothing remains after it.
    ///
    /// Text without a decimal point is returned unchanged: its trailing
    /// zeros are significant integer digits (`"100000"` must not become `"1"`).
    fn strip_fraction_zeros(digits: &str) -> &str {
        if digits.contains('.') {
            digits.trim_end_matches('0').trim_end_matches('.')
        } else {
            digits
        }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        // C treats a zero precision as 1 for the `g` conversion.
        let p = self.precision.max(1);

        if v.is_nan() {
            // C prints NaN in lowercase; glibc prefixes '-' when the sign bit is set.
            return f.write_str(if v.is_sign_negative() { "-nan" } else { "nan" });
        }
        if v.is_infinite() {
            return f.write_str(if v.is_sign_negative() { "-inf" } else { "inf" });
        }
        if v == 0.0 {
            // Preserve the sign of -0.0, as C does.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // Render in scientific form first. The exponent is taken from this
        // (already rounded) text rather than from log10(|v|): rounding can
        // carry into the next decade (999999.7 -> 1.00000e6), and that is the
        // exponent C uses to choose between fixed and scientific notation.
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exp) = match sci.split_once('e') {
            Some((m, e)) => match e.parse::<i64>() {
                Ok(x) => (m, x),
                Err(_) => return f.write_str(&sci),
            },
            None => return f.write_str(&sci),
        };

        if exp < -4 || exp >= p as i64 {
            // Scientific: C writes a sign and zero-pads the exponent to two
            // digits (e-05, not e-5).
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                Self::strip_fraction_zeros(mantissa),
                sign,
                exp.abs()
            )
        } else {
            // Fixed: precision - 1 - X digits after the decimal point.
            // Non-negative here because X < precision on this branch.
            let decimals = (p as i64 - 1 - exp) as usize;
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(Self::strip_fraction_zeros(&fixed))
        }
    }
}
87
88/// Format like C's `%.17g`
89fn fmt_17g(v: f64) -> Gfmt {
90    Gfmt::new(v, 17)
91}
92
93/// Format like C's `%.8g`
94fn fmt_8g(v: f64) -> Gfmt {
95    Gfmt::new(v, 8)
96}
97
98/// Format a float like C's `%g` (6 significant digits).
99pub fn format_g(v: f64) -> String {
100    format!("{}", Gfmt::new(v, 6))
101}
102
103/// Format a float like C's `%.17g` (17 significant digits).
104pub fn format_17g(v: f64) -> String {
105    format!("{}", Gfmt::new(v, 17))
106}
107
108// ─── String tables matching original LIBSVM ──────────────────────────
109
/// Model-file keywords for SVM types; `svm_type_to_str` indexes this table
/// with `SvmType as usize`.
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];

/// Model-file keywords for kernel types; `kernel_type_to_str` indexes this
/// table with `KernelType as usize`.
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
112
113fn svm_type_to_str(t: SvmType) -> &'static str {
114    SVM_TYPE_TABLE[t as usize]
115}
116
117fn kernel_type_to_str(t: KernelType) -> &'static str {
118    KERNEL_TYPE_TABLE[t as usize]
119}
120
121fn str_to_svm_type(s: &str) -> Option<SvmType> {
122    match s {
123        "c_svc" => Some(SvmType::CSvc),
124        "nu_svc" => Some(SvmType::NuSvc),
125        "one_class" => Some(SvmType::OneClass),
126        "epsilon_svr" => Some(SvmType::EpsilonSvr),
127        "nu_svr" => Some(SvmType::NuSvr),
128        _ => None,
129    }
130}
131
132fn str_to_kernel_type(s: &str) -> Option<KernelType> {
133    match s {
134        "linear" => Some(KernelType::Linear),
135        "polynomial" => Some(KernelType::Polynomial),
136        "rbf" => Some(KernelType::Rbf),
137        "sigmoid" => Some(KernelType::Sigmoid),
138        "precomputed" => Some(KernelType::Precomputed),
139        _ => None,
140    }
141}
142
143// ─── Problem file I/O ────────────────────────────────────────────────
144
145/// Load an SVM problem from a file in LIBSVM sparse format.
146///
147/// Format: `<label> <index1>:<value1> <index2>:<value2> ...`
148pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
149    let file = std::fs::File::open(path)?;
150    let reader = std::io::BufReader::new(file);
151    load_problem_from_reader(reader)
152}
153
154/// Load an SVM problem from any buffered reader.
155pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
156    let mut labels = Vec::new();
157    let mut instances = Vec::new();
158
159    for (line_idx, line_result) in reader.lines().enumerate() {
160        let line = line_result?;
161        let line = line.trim();
162        if line.is_empty() {
163            continue;
164        }
165
166        let line_num = line_idx + 1;
167        let mut parts = line.split_whitespace();
168
169        // Parse label
170        let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
171            line: line_num,
172            message: "missing label".into(),
173        })?;
174        let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
175            line: line_num,
176            message: format!("invalid label: {}", label_str),
177        })?;
178
179        // Parse features (must be in ascending index order)
180        let mut nodes = Vec::new();
181        let mut prev_index: i32 = 0;
182        for token in parts {
183            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
184                line: line_num,
185                message: format!("expected index:value, got: {}", token),
186            })?;
187            let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
188                line: line_num,
189                message: format!("invalid index: {}", idx_str),
190            })?;
191            if !nodes.is_empty() && index <= prev_index {
192                return Err(SvmError::ParseError {
193                    line: line_num,
194                    message: format!(
195                        "feature indices must be ascending: {} follows {}",
196                        index, prev_index
197                    ),
198                });
199            }
200            let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
201                line: line_num,
202                message: format!("invalid value: {}", val_str),
203            })?;
204            prev_index = index;
205            nodes.push(SvmNode { index, value });
206        }
207
208        labels.push(label);
209        instances.push(nodes);
210    }
211
212    Ok(SvmProblem { labels, instances })
213}
214
215// ─── Model file I/O ──────────────────────────────────────────────────
216
217/// Save an SVM model to a file in the original LIBSVM format.
218pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
219    let file = std::fs::File::create(path)?;
220    let writer = std::io::BufWriter::new(file);
221    save_model_to_writer(writer, model)
222}
223
224/// Save an SVM model to any writer.
225pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
226    let param = &model.param;
227
228    writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
229    writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
230
231    if param.kernel_type == KernelType::Polynomial {
232        writeln!(w, "degree {}", param.degree)?;
233    }
234    if matches!(
235        param.kernel_type,
236        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
237    ) {
238        writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
239    }
240    if matches!(
241        param.kernel_type,
242        KernelType::Polynomial | KernelType::Sigmoid
243    ) {
244        writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
245    }
246
247    let nr_class = model.nr_class;
248    writeln!(w, "nr_class {}", nr_class)?;
249    writeln!(w, "total_sv {}", model.sv.len())?;
250
251    // rho
252    write!(w, "rho")?;
253    for r in &model.rho {
254        write!(w, " {}", fmt_17g(*r))?;
255    }
256    writeln!(w)?;
257
258    // label (classification only)
259    if !model.label.is_empty() {
260        write!(w, "label")?;
261        for l in &model.label {
262            write!(w, " {}", l)?;
263        }
264        writeln!(w)?;
265    }
266
267    // probA
268    if !model.prob_a.is_empty() {
269        write!(w, "probA")?;
270        for v in &model.prob_a {
271            write!(w, " {}", fmt_17g(*v))?;
272        }
273        writeln!(w)?;
274    }
275
276    // probB
277    if !model.prob_b.is_empty() {
278        write!(w, "probB")?;
279        for v in &model.prob_b {
280            write!(w, " {}", fmt_17g(*v))?;
281        }
282        writeln!(w)?;
283    }
284
285    // prob_density_marks (one-class)
286    if !model.prob_density_marks.is_empty() {
287        write!(w, "prob_density_marks")?;
288        for v in &model.prob_density_marks {
289            write!(w, " {}", fmt_17g(*v))?;
290        }
291        writeln!(w)?;
292    }
293
294    // nr_sv
295    if !model.n_sv.is_empty() {
296        write!(w, "nr_sv")?;
297        for n in &model.n_sv {
298            write!(w, " {}", n)?;
299        }
300        writeln!(w)?;
301    }
302
303    // SV section
304    writeln!(w, "SV")?;
305    let num_sv = model.sv.len();
306    let num_coef_rows = model.sv_coef.len(); // nr_class - 1
307
308    for i in 0..num_sv {
309        // sv_coef columns for this SV: %.17g
310        for j in 0..num_coef_rows {
311            write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
312        }
313        // sparse features: %.8g
314        if model.param.kernel_type == KernelType::Precomputed {
315            if let Some(node) = model.sv[i].first() {
316                write!(w, "0:{} ", node.value as i32)?;
317            }
318        } else {
319            for node in &model.sv[i] {
320                write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
321            }
322        }
323        writeln!(w)?;
324    }
325
326    Ok(())
327}
328
329/// Load an SVM model from a file in the original LIBSVM format.
330pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
331    let file = std::fs::File::open(path)?;
332    let reader = std::io::BufReader::new(file);
333    load_model_from_reader(reader)
334}
335
336/// Load an SVM model from any buffered reader.
337pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
338    let mut lines = reader.lines();
339
340    // Defaults
341    let mut param = SvmParameter::default();
342    let mut nr_class: usize = 0;
343    let mut total_sv: usize = 0;
344    let mut rho = Vec::new();
345    let mut label = Vec::new();
346    let mut prob_a = Vec::new();
347    let mut prob_b = Vec::new();
348    let mut prob_density_marks = Vec::new();
349    let mut n_sv = Vec::new();
350
351    // Read header
352    let mut line_num: usize = 0;
353    loop {
354        let line = lines
355            .next()
356            .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))??;
357        line_num += 1;
358        let line = line.trim().to_string();
359        if line.is_empty() {
360            continue;
361        }
362
363        let mut parts = line.split_whitespace();
364        let cmd = parts.next().unwrap();
365
366        match cmd {
367            "svm_type" => {
368                let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
369                    format!("line {}: missing svm_type value", line_num),
370                ))?;
371                param.svm_type = str_to_svm_type(val).ok_or_else(|| {
372                    SvmError::ModelFormatError(format!("line {}: unknown svm_type: {}", line_num, val))
373                })?;
374            }
375            "kernel_type" => {
376                let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
377                    format!("line {}: missing kernel_type value", line_num),
378                ))?;
379                param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
380                    SvmError::ModelFormatError(format!("line {}: unknown kernel_type: {}", line_num, val))
381                })?;
382            }
383            "degree" => {
384                param.degree = parse_single(&mut parts, line_num, "degree")?;
385            }
386            "gamma" => {
387                param.gamma = parse_single(&mut parts, line_num, "gamma")?;
388            }
389            "coef0" => {
390                param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
391            }
392            "nr_class" => {
393                nr_class = parse_single(&mut parts, line_num, "nr_class")?;
394            }
395            "total_sv" => {
396                total_sv = parse_single(&mut parts, line_num, "total_sv")?;
397            }
398            "rho" => {
399                rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
400            }
401            "label" => {
402                label = parse_multiple_i32(&mut parts, line_num, "label")?;
403            }
404            "probA" => {
405                prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
406            }
407            "probB" => {
408                prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
409            }
410            "prob_density_marks" => {
411                prob_density_marks = parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
412            }
413            "nr_sv" => {
414                n_sv = parts
415                    .map(|s| {
416                        s.parse::<usize>().map_err(|_| {
417                            SvmError::ModelFormatError(format!(
418                                "line {}: invalid nr_sv value: {}",
419                                line_num, s
420                            ))
421                        })
422                    })
423                    .collect::<Result<Vec<_>, _>>()?;
424            }
425            "SV" => break,
426            _ => {
427                return Err(SvmError::ModelFormatError(format!(
428                    "line {}: unknown keyword: {}",
429                    line_num, cmd
430                )));
431            }
432        }
433    }
434
435    // Read SV section
436    let m = if nr_class > 1 { nr_class - 1 } else { 1 };
437    let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
438    let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
439
440    for _ in 0..total_sv {
441        let line = lines
442            .next()
443            .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))??;
444        line_num += 1;
445        let line = line.trim();
446        if line.is_empty() {
447            continue;
448        }
449
450        let mut parts = line.split_whitespace();
451
452        // First m tokens are sv_coef values
453        for (k, coef_row) in sv_coef.iter_mut().enumerate() {
454            let val_str = parts.next().ok_or_else(|| SvmError::ModelFormatError(
455                format!("line {}: missing sv_coef[{}]", line_num, k),
456            ))?;
457            let val: f64 = val_str.parse().map_err(|_| SvmError::ModelFormatError(
458                format!("line {}: invalid sv_coef: {}", line_num, val_str),
459            ))?;
460            coef_row.push(val);
461        }
462
463        // Remaining tokens are index:value pairs
464        let mut nodes = Vec::new();
465        for token in parts {
466            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
467                SvmError::ModelFormatError(format!(
468                    "line {}: expected index:value, got: {}",
469                    line_num, token
470                ))
471            })?;
472            let index: i32 = idx_str.parse().map_err(|_| {
473                SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
474            })?;
475            let value: f64 = val_str.parse().map_err(|_| {
476                SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
477            })?;
478            nodes.push(SvmNode { index, value });
479        }
480        sv.push(nodes);
481    }
482
483    Ok(SvmModel {
484        param,
485        nr_class,
486        sv,
487        sv_coef,
488        rho,
489        prob_a,
490        prob_b,
491        prob_density_marks,
492        sv_indices: Vec::new(), // not stored in model file
493        label,
494        n_sv,
495    })
496}
497
498// ─── Helper parsers ──────────────────────────────────────────────────
499
500fn parse_single<T: std::str::FromStr>(
501    parts: &mut std::str::SplitWhitespace<'_>,
502    line_num: usize,
503    field: &str,
504) -> Result<T, SvmError> {
505    let val_str = parts.next().ok_or_else(|| {
506        SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
507    })?;
508    val_str.parse().map_err(|_| {
509        SvmError::ModelFormatError(format!("line {}: invalid {} value: {}", line_num, field, val_str))
510    })
511}
512
513fn parse_multiple_f64(
514    parts: &mut std::str::SplitWhitespace<'_>,
515    line_num: usize,
516    field: &str,
517) -> Result<Vec<f64>, SvmError> {
518    parts
519        .map(|s| {
520            s.parse::<f64>().map_err(|_| {
521                SvmError::ModelFormatError(format!(
522                    "line {}: invalid {} value: {}",
523                    line_num, field, s
524                ))
525            })
526        })
527        .collect()
528}
529
530fn parse_multiple_i32(
531    parts: &mut std::str::SplitWhitespace<'_>,
532    line_num: usize,
533    field: &str,
534) -> Result<Vec<i32>, SvmError> {
535    parts
536        .map(|s| {
537            s.parse::<i32>().map_err(|_| {
538                SvmError::ModelFormatError(format!(
539                    "line {}: invalid {} value: {}",
540                    line_num, field, s
541                ))
542            })
543        })
544        .collect()
545}
546
547// ─── Tests ───────────────────────────────────────────────────────────
548
// Unit tests for problem parsing, model loading/saving and `%g` formatting.
// Several tests read fixture files from the directory returned by `data_dir()`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Path of the `data` directory located two levels above the crate's
    /// manifest directory.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// Loads the `heart_scale` fixture and checks the instance count and the
    /// first parsed instance.
    #[test]
    fn parse_heart_scale() {
        let path = data_dir().join("heart_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        // First instance: +1 label, 12 features (index 11 is missing/sparse)
        assert_eq!(problem.labels[0], 1.0);
        assert_eq!(problem.instances[0][0], SvmNode { index: 1, value: 0.708333 });
        assert_eq!(problem.instances[0].len(), 12);
    }

    /// Loads the `iris.scale` fixture and checks that exactly three distinct
    /// labels occur.
    #[test]
    fn parse_iris() {
        let path = data_dir().join("iris.scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 150);
        // 3 classes: 1, 2, 3
        let classes: std::collections::HashSet<i64> =
            problem.labels.iter().map(|&l| l as i64).collect();
        assert_eq!(classes.len(), 3);
    }

    /// Loads the `housing_scale` fixture and checks the count and the first
    /// (non-integer-class) label.
    #[test]
    fn parse_housing() {
        let path = data_dir().join("housing_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 506);
        // Regression: labels are continuous
        assert!((problem.labels[0] - 24.0).abs() < 1e-10);
    }

    /// Blank lines between instances must be skipped, not parsed.
    #[test]
    fn parse_empty_lines() {
        let input = b"+1 1:0.5\n\n-1 2:0.3\n";
        let problem = load_problem_from_reader(&input[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    /// Descending feature indices are rejected with an "ascending" message.
    #[test]
    fn parse_error_unsorted_indices() {
        let input = b"+1 3:0.5 1:0.3\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    /// A repeated feature index on one line is rejected.
    #[test]
    fn parse_error_duplicate_indices() {
        let input = b"+1 1:0.5 1:0.3\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
    }

    /// A feature token without `:` is rejected.
    #[test]
    fn parse_error_missing_colon() {
        let input = b"+1 1:0.5 bad_token\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
    }

    /// Loads the `heart_scale.model` fixture and checks the parsed header
    /// values and the shape of the support-vector data.
    #[test]
    fn load_c_trained_model() {
        // Load a model produced by the original C LIBSVM svm-train
        let path = data_dir().join("heart_scale.model");
        let model = load_model(&path).unwrap();
        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        // sv_coef should have 1 row (nr_class - 1) with 132 entries
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    /// Loading the fixture model and saving it again must reproduce the
    /// original file line for line.
    #[test]
    fn roundtrip_c_model() {
        // Load C model, save it back, and verify byte-exact match
        let path = data_dir().join("heart_scale.model");
        let original_bytes = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let rust_output = String::from_utf8(buf).unwrap();

        // Compare line by line for better diagnostics
        let orig_lines: Vec<&str> = original_bytes.lines().collect();
        let rust_lines: Vec<&str> = rust_output.lines().collect();
        assert_eq!(
            orig_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            orig_lines.len(),
            rust_lines.len()
        );
        for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
            assert_eq!(o, r, "line {} differs:\n  C:    {:?}\n  Rust: {:?}", i + 1, o, r);
        }
    }

    /// Checks `fmt_17g` and `fmt_8g` against a table of expected strings.
    #[test]
    fn gfmt_matches_c_printf() {
        // Reference values from C's printf("%.17g|%.8g\n", v, v)
        let cases: &[(f64, &str, &str)] = &[
            (0.5,                    "0.5",                      "0.5"),
            (-1.0,                   "-1",                       "-1"),
            (0.123456789012345,      "0.123456789012345",        "0.12345679"),
            (-0.987654321098765,     "-0.98765432109876505",     "-0.98765432"),
            (0.42446200000000001,    "0.42446200000000001",      "0.424462"),
            (0.0,                    "0",                        "0"),
            (1e-5,                   "1.0000000000000001e-05",   "1e-05"),
            (1e-4,                   "0.0001",                   "0.0001"),
            (1e20,                   "1e+20",                    "1e+20"),
            (-0.25,                  "-0.25",                    "-0.25"),
            (0.75,                   "0.75",                     "0.75"),
            (0.708333,               "0.70833299999999999",      "0.708333"),
            (1.0,                    "1",                        "1"),
        ];
        for &(v, expected_17g, expected_8g) in cases {
            let got_17 = format!("{}", fmt_17g(v));
            let got_8 = format!("{}", fmt_8g(v));
            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
        }
    }

    /// Saves a hand-built two-class model to memory, loads it back and
    /// compares the fields that survive the text round trip.
    #[test]
    fn model_roundtrip() {
        // Create a minimal model and verify save → load roundtrip
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![
                vec![SvmNode { index: 1, value: 0.5 }, SvmNode { index: 3, value: -1.0 }],
                vec![SvmNode { index: 1, value: -0.25 }, SvmNode { index: 2, value: 0.75 }],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();

        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());
        // Check rho within tolerance (roundtrip through text)
        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        // Check sv_coef within tolerance
        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
            for (a, b) in row_a.iter().zip(row_b.iter()) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }
}