Skip to main content

libsvm_rs/
io.rs

1//! I/O routines for LIBSVM problem and model files.
2//!
3//! File formats match the original LIBSVM exactly, ensuring cross-tool
4//! interoperability.
5
6use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12// ─── C-compatible %g formatting ─────────────────────────────────────
13//
14// C's printf `%.Pg` format strips trailing zeros and picks fixed vs.
15// scientific notation based on the exponent. Rust has no built-in
16// equivalent, so we replicate the POSIX specification:
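//   - Let X be the decimal exponent after rounding to `precision`
//     significant digits (the exponent `%e` would print); use scientific
//     notation if X < -4 or X >= precision
//   - Otherwise use fixed notation with `precision - 1 - X` decimals
//   - Strip trailing zeros from the fractional part only (and then a
//     trailing decimal point); integral zeros such as those in "100" stay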
20
21use std::fmt;
22
/// Formats `f64` like C's `%.17g` (or any precision).
///
/// Follows the POSIX `%g` rules:
/// - The value is first rounded to `precision` significant digits; a
///   precision of 0 is treated as 1, as in C.
/// - If the resulting decimal exponent `X` satisfies `-4 <= X < precision`,
///   fixed notation with `precision - 1 - X` decimals is used.
/// - Otherwise scientific notation is used, with an explicit exponent sign
///   and at least two exponent digits.
/// - Trailing zeros of the fractional part, and a then-dangling decimal
///   point, are removed.
///
/// Non-finite values print as `inf`, `-inf` and `nan`/`-nan`.
struct Gfmt {
    value: f64,
    precision: usize,
}

impl Gfmt {
    /// Wraps `value` for `%g`-style display with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Removes trailing fractional zeros (and a dangling '.') from a
        // decimal string. Strings without a '.' are returned unchanged so
        // that integral zeros (e.g. in "100000") are never stripped.
        fn strip_fraction_zeros(s: &str) -> &str {
            if s.contains('.') {
                s.trim_end_matches('0').trim_end_matches('.')
            } else {
                s
            }
        }

        let v = self.value;
        // C treats an explicit precision of 0 as 1.
        let p = self.precision.max(1);

        if v.is_nan() {
            // C prints NaN in lowercase; Rust's `Display` would give "NaN".
            return f.write_str(if v.is_sign_negative() { "-nan" } else { "nan" });
        }
        if v.is_infinite() {
            return f.write_str(if v > 0.0 { "inf" } else { "-inf" });
        }
        if v == 0.0 {
            // Preserve sign of -0.0, as C does.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // The notation choice depends on the exponent *after* rounding to
        // `p` significant digits (e.g. 999999.5 under %g rounds to
        // 1.00000e+06 and must print as "1e+06"). Take it from Rust's `{:e}`
        // output, which performs that rounding, instead of `log10().floor()`,
        // which ignores rounding.
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exp_str) = sci
            .split_once('e')
            .expect("`{:e}` output of a finite f64 always contains 'e'");
        let exp: i32 = exp_str
            .parse()
            .expect("`{:e}` output always has an integer exponent");

        if exp < -4 || exp >= p as i32 {
            // C zero-pads the exponent to at least 2 digits (e-05, not e-5).
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                strip_fraction_zeros(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Here -4 <= exp < p, so the decimal count is in 0..=p+3.
            let decimals = (p as i32 - 1 - exp) as usize;
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(strip_fraction_zeros(&fixed))
        }
    }
}

/// Format like C's `%.17g`
fn fmt_17g(v: f64) -> Gfmt {
    Gfmt::new(v, 17)
}

/// Format like C's `%.8g`
fn fmt_8g(v: f64) -> Gfmt {
    Gfmt::new(v, 8)
}

/// Format a float like C's `%g` (6 significant digits).
pub fn format_g(v: f64) -> String {
    format!("{}", Gfmt::new(v, 6))
}

/// Format a float like C's `%.17g` (17 significant digits).
pub fn format_17g(v: f64) -> String {
    format!("{}", Gfmt::new(v, 17))
}
107
108// ─── String tables matching original LIBSVM ──────────────────────────
109
/// `svm_type` keywords as written in model files. `svm_type_to_str` indexes
/// this table with `SvmType as usize`, so the order must match the enum's
/// discriminant order.
const SVM_TYPE_TABLE: &[&str] = &["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"];
/// `kernel_type` keywords as written in model files. `kernel_type_to_str`
/// indexes this table with `KernelType as usize`, so the order must match the
/// enum's discriminant order.
const KERNEL_TYPE_TABLE: &[&str] = &["linear", "polynomial", "rbf", "sigmoid", "precomputed"];
112
113fn svm_type_to_str(t: SvmType) -> &'static str {
114    SVM_TYPE_TABLE[t as usize]
115}
116
117fn kernel_type_to_str(t: KernelType) -> &'static str {
118    KERNEL_TYPE_TABLE[t as usize]
119}
120
121fn str_to_svm_type(s: &str) -> Option<SvmType> {
122    match s {
123        "c_svc" => Some(SvmType::CSvc),
124        "nu_svc" => Some(SvmType::NuSvc),
125        "one_class" => Some(SvmType::OneClass),
126        "epsilon_svr" => Some(SvmType::EpsilonSvr),
127        "nu_svr" => Some(SvmType::NuSvr),
128        _ => None,
129    }
130}
131
132fn str_to_kernel_type(s: &str) -> Option<KernelType> {
133    match s {
134        "linear" => Some(KernelType::Linear),
135        "polynomial" => Some(KernelType::Polynomial),
136        "rbf" => Some(KernelType::Rbf),
137        "sigmoid" => Some(KernelType::Sigmoid),
138        "precomputed" => Some(KernelType::Precomputed),
139        _ => None,
140    }
141}
142
143// ─── Problem file I/O ────────────────────────────────────────────────
144
145/// Load an SVM problem from a file in LIBSVM sparse format.
146///
147/// Format: `<label> <index1>:<value1> <index2>:<value2> ...`
148pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
149    let file = std::fs::File::open(path)?;
150    let reader = std::io::BufReader::new(file);
151    load_problem_from_reader(reader)
152}
153
154/// Load an SVM problem from any buffered reader.
155pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
156    let mut labels = Vec::new();
157    let mut instances = Vec::new();
158
159    for (line_idx, line_result) in reader.lines().enumerate() {
160        let line = line_result?;
161        let line = line.trim();
162        if line.is_empty() {
163            continue;
164        }
165
166        let line_num = line_idx + 1;
167        let mut parts = line.split_whitespace();
168
169        // Parse label
170        let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
171            line: line_num,
172            message: "missing label".into(),
173        })?;
174        let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
175            line: line_num,
176            message: format!("invalid label: {}", label_str),
177        })?;
178
179        // Parse features (must be in ascending index order)
180        let mut nodes = Vec::new();
181        let mut prev_index: i32 = 0;
182        for token in parts {
183            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
184                line: line_num,
185                message: format!("expected index:value, got: {}", token),
186            })?;
187            let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
188                line: line_num,
189                message: format!("invalid index: {}", idx_str),
190            })?;
191            if !nodes.is_empty() && index <= prev_index {
192                return Err(SvmError::ParseError {
193                    line: line_num,
194                    message: format!(
195                        "feature indices must be ascending: {} follows {}",
196                        index, prev_index
197                    ),
198                });
199            }
200            let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
201                line: line_num,
202                message: format!("invalid value: {}", val_str),
203            })?;
204            prev_index = index;
205            nodes.push(SvmNode { index, value });
206        }
207
208        labels.push(label);
209        instances.push(nodes);
210    }
211
212    Ok(SvmProblem { labels, instances })
213}
214
215// ─── Model file I/O ──────────────────────────────────────────────────
216
/// Largest `nr_class` the model loader accepts; a header value above this is
/// rejected with a `ModelFormatError`.
const MAX_NR_CLASS: usize = 65535;
/// Largest `total_sv` the model loader accepts; a header value above this is
/// rejected with a `ModelFormatError`.
const MAX_TOTAL_SV: usize = 10_000_000;
219
220/// Save an SVM model to a file in the original LIBSVM format.
221pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
222    let file = std::fs::File::create(path)?;
223    let writer = std::io::BufWriter::new(file);
224    save_model_to_writer(writer, model)
225}
226
227/// Save an SVM model to any writer.
228pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
229    let param = &model.param;
230
231    writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
232    writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
233
234    if param.kernel_type == KernelType::Polynomial {
235        writeln!(w, "degree {}", param.degree)?;
236    }
237    if matches!(
238        param.kernel_type,
239        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
240    ) {
241        writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
242    }
243    if matches!(
244        param.kernel_type,
245        KernelType::Polynomial | KernelType::Sigmoid
246    ) {
247        writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
248    }
249
250    let nr_class = model.nr_class;
251    writeln!(w, "nr_class {}", nr_class)?;
252    writeln!(w, "total_sv {}", model.sv.len())?;
253
254    // rho
255    write!(w, "rho")?;
256    for r in &model.rho {
257        write!(w, " {}", fmt_17g(*r))?;
258    }
259    writeln!(w)?;
260
261    // label (classification only)
262    if !model.label.is_empty() {
263        write!(w, "label")?;
264        for l in &model.label {
265            write!(w, " {}", l)?;
266        }
267        writeln!(w)?;
268    }
269
270    // probA
271    if !model.prob_a.is_empty() {
272        write!(w, "probA")?;
273        for v in &model.prob_a {
274            write!(w, " {}", fmt_17g(*v))?;
275        }
276        writeln!(w)?;
277    }
278
279    // probB
280    if !model.prob_b.is_empty() {
281        write!(w, "probB")?;
282        for v in &model.prob_b {
283            write!(w, " {}", fmt_17g(*v))?;
284        }
285        writeln!(w)?;
286    }
287
288    // prob_density_marks (one-class)
289    if !model.prob_density_marks.is_empty() {
290        write!(w, "prob_density_marks")?;
291        for v in &model.prob_density_marks {
292            write!(w, " {}", fmt_17g(*v))?;
293        }
294        writeln!(w)?;
295    }
296
297    // nr_sv
298    if !model.n_sv.is_empty() {
299        write!(w, "nr_sv")?;
300        for n in &model.n_sv {
301            write!(w, " {}", n)?;
302        }
303        writeln!(w)?;
304    }
305
306    // SV section
307    writeln!(w, "SV")?;
308    let num_sv = model.sv.len();
309    let num_coef_rows = model.sv_coef.len(); // nr_class - 1
310
311    for i in 0..num_sv {
312        // sv_coef columns for this SV: %.17g
313        for j in 0..num_coef_rows {
314            write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
315        }
316        // sparse features: %.8g
317        if model.param.kernel_type == KernelType::Precomputed {
318            if let Some(node) = model.sv[i].first() {
319                write!(w, "0:{} ", node.value as i32)?;
320            }
321        } else {
322            for node in &model.sv[i] {
323                write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
324            }
325        }
326        writeln!(w)?;
327    }
328
329    Ok(())
330}
331
332/// Load an SVM model from a file in the original LIBSVM format.
333pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
334    let file = std::fs::File::open(path)?;
335    let reader = std::io::BufReader::new(file);
336    load_model_from_reader(reader)
337}
338
339/// Load an SVM model from any buffered reader.
340pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
341    let mut lines = reader.lines();
342
343    // Defaults
344    let mut param = SvmParameter::default();
345    let mut nr_class: usize = 0;
346    let mut total_sv: usize = 0;
347    let mut rho = Vec::new();
348    let mut label = Vec::new();
349    let mut prob_a = Vec::new();
350    let mut prob_b = Vec::new();
351    let mut prob_density_marks = Vec::new();
352    let mut n_sv = Vec::new();
353
354    // Read header
355    let mut line_num: usize = 0;
356    loop {
357        let line = lines
358            .next()
359            .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))??;
360        line_num += 1;
361        let line = line.trim().to_string();
362        if line.is_empty() {
363            continue;
364        }
365
366        let mut parts = line.split_whitespace();
367        let cmd = parts.next().unwrap();
368
369        match cmd {
370            "svm_type" => {
371                let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
372                    format!("line {}: missing svm_type value", line_num),
373                ))?;
374                param.svm_type = str_to_svm_type(val).ok_or_else(|| {
375                    SvmError::ModelFormatError(format!("line {}: unknown svm_type: {}", line_num, val))
376                })?;
377            }
378            "kernel_type" => {
379                let val = parts.next().ok_or_else(|| SvmError::ModelFormatError(
380                    format!("line {}: missing kernel_type value", line_num),
381                ))?;
382                param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
383                    SvmError::ModelFormatError(format!("line {}: unknown kernel_type: {}", line_num, val))
384                })?;
385            }
386            "degree" => {
387                param.degree = parse_single(&mut parts, line_num, "degree")?;
388            }
389            "gamma" => {
390                param.gamma = parse_single(&mut parts, line_num, "gamma")?;
391            }
392            "coef0" => {
393                param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
394            }
395            "nr_class" => {
396                nr_class = parse_single(&mut parts, line_num, "nr_class")?;
397                if nr_class > MAX_NR_CLASS {
398                    return Err(SvmError::ModelFormatError(format!(
399                        "line {}: nr_class exceeds limit ({})",
400                        line_num, MAX_NR_CLASS
401                    )));
402                }
403            }
404            "total_sv" => {
405                total_sv = parse_single(&mut parts, line_num, "total_sv")?;
406                if total_sv > MAX_TOTAL_SV {
407                    return Err(SvmError::ModelFormatError(format!(
408                        "line {}: total_sv exceeds limit ({})",
409                        line_num, MAX_TOTAL_SV
410                    )));
411                }
412            }
413            "rho" => {
414                rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
415            }
416            "label" => {
417                label = parse_multiple_i32(&mut parts, line_num, "label")?;
418            }
419            "probA" => {
420                prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
421            }
422            "probB" => {
423                prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
424            }
425            "prob_density_marks" => {
426                prob_density_marks = parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
427            }
428            "nr_sv" => {
429                n_sv = parts
430                    .map(|s| {
431                        s.parse::<usize>().map_err(|_| {
432                            SvmError::ModelFormatError(format!(
433                                "line {}: invalid nr_sv value: {}",
434                                line_num, s
435                            ))
436                        })
437                    })
438                    .collect::<Result<Vec<_>, _>>()?;
439            }
440            "SV" => break,
441            _ => {
442                return Err(SvmError::ModelFormatError(format!(
443                    "line {}: unknown keyword: {}",
444                    line_num, cmd
445                )));
446            }
447        }
448    }
449
450    // Read SV section
451    let m = if nr_class > 1 { nr_class - 1 } else { 1 };
452    let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
453    let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
454
455    for _ in 0..total_sv {
456        let line = lines
457            .next()
458            .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))??;
459        line_num += 1;
460        let line = line.trim();
461        if line.is_empty() {
462            continue;
463        }
464
465        let mut parts = line.split_whitespace();
466
467        // First m tokens are sv_coef values
468        for (k, coef_row) in sv_coef.iter_mut().enumerate() {
469            let val_str = parts.next().ok_or_else(|| SvmError::ModelFormatError(
470                format!("line {}: missing sv_coef[{}]", line_num, k),
471            ))?;
472            let val: f64 = val_str.parse().map_err(|_| SvmError::ModelFormatError(
473                format!("line {}: invalid sv_coef: {}", line_num, val_str),
474            ))?;
475            coef_row.push(val);
476        }
477
478        // Remaining tokens are index:value pairs
479        let mut nodes = Vec::new();
480        for token in parts {
481            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
482                SvmError::ModelFormatError(format!(
483                    "line {}: expected index:value, got: {}",
484                    line_num, token
485                ))
486            })?;
487            let index: i32 = idx_str.parse().map_err(|_| {
488                SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
489            })?;
490            let value: f64 = val_str.parse().map_err(|_| {
491                SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
492            })?;
493            nodes.push(SvmNode { index, value });
494        }
495        sv.push(nodes);
496    }
497
498    Ok(SvmModel {
499        param,
500        nr_class,
501        sv,
502        sv_coef,
503        rho,
504        prob_a,
505        prob_b,
506        prob_density_marks,
507        sv_indices: Vec::new(), // not stored in model file
508        label,
509        n_sv,
510    })
511}
512
513// ─── Helper parsers ──────────────────────────────────────────────────
514
515fn parse_single<T: std::str::FromStr>(
516    parts: &mut std::str::SplitWhitespace<'_>,
517    line_num: usize,
518    field: &str,
519) -> Result<T, SvmError> {
520    let val_str = parts.next().ok_or_else(|| {
521        SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
522    })?;
523    val_str.parse().map_err(|_| {
524        SvmError::ModelFormatError(format!("line {}: invalid {} value: {}", line_num, field, val_str))
525    })
526}
527
528fn parse_multiple_f64(
529    parts: &mut std::str::SplitWhitespace<'_>,
530    line_num: usize,
531    field: &str,
532) -> Result<Vec<f64>, SvmError> {
533    parts
534        .map(|s| {
535            s.parse::<f64>().map_err(|_| {
536                SvmError::ModelFormatError(format!(
537                    "line {}: invalid {} value: {}",
538                    line_num, field, s
539                ))
540            })
541        })
542        .collect()
543}
544
545fn parse_multiple_i32(
546    parts: &mut std::str::SplitWhitespace<'_>,
547    line_num: usize,
548    field: &str,
549) -> Result<Vec<i32>, SvmError> {
550    parts
551        .map(|s| {
552            s.parse::<i32>().map_err(|_| {
553                SvmError::ModelFormatError(format!(
554                    "line {}: invalid {} value: {}",
555                    line_num, field, s
556                ))
557            })
558        })
559        .collect()
560}
561
562// ─── Tests ───────────────────────────────────────────────────────────
563
// Unit tests for problem/model I/O and the C-compatible `%g` formatter.
// Several tests read fixture files from the shared data directory located by
// `data_dir()`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Path to the fixture directory: `<CARGO_MANIFEST_DIR>/../../data`.
    /// NOTE(review): assumes the repository keeps its datasets two levels
    /// above this crate's manifest — confirm if the workspace layout changes.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    // Loads the heart_scale dataset and checks its size and first instance.
    #[test]
    fn parse_heart_scale() {
        let path = data_dir().join("heart_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        // First instance: +1 label, 12 features (index 11 is missing/sparse)
        assert_eq!(problem.labels[0], 1.0);
        assert_eq!(problem.instances[0][0], SvmNode { index: 1, value: 0.708333 });
        assert_eq!(problem.instances[0].len(), 12);
    }

    // Loads the multi-class iris dataset and checks the number of distinct labels.
    #[test]
    fn parse_iris() {
        let path = data_dir().join("iris.scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 150);
        // 3 classes: 1, 2, 3
        let classes: std::collections::HashSet<i64> =
            problem.labels.iter().map(|&l| l as i64).collect();
        assert_eq!(classes.len(), 3);
    }

    // Loads the housing regression dataset and checks the first target value.
    #[test]
    fn parse_housing() {
        let path = data_dir().join("housing_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 506);
        // Regression: labels are continuous
        assert!((problem.labels[0] - 24.0).abs() < 1e-10);
    }

    // Blank lines between instances are skipped, not counted as instances.
    #[test]
    fn parse_empty_lines() {
        let input = b"+1 1:0.5\n\n-1 2:0.3\n";
        let problem = load_problem_from_reader(&input[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    // Descending indices are rejected with an "ascending" error message.
    #[test]
    fn parse_error_unsorted_indices() {
        let input = b"+1 3:0.5 1:0.3\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    // Repeated indices violate the strictly-ascending rule.
    #[test]
    fn parse_error_duplicate_indices() {
        let input = b"+1 1:0.5 1:0.3\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
    }

    // A feature token without ':' is a parse error.
    #[test]
    fn parse_error_missing_colon() {
        let input = b"+1 1:0.5 bad_token\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
    }

    // The float literals below carry more digits than f64 can represent
    // (they are copied from C `%.17g` output), hence the clippy allow.
    #[test]
    #[allow(clippy::excessive_precision)]
    fn load_c_trained_model() {
        // Load a model produced by the original C LIBSVM svm-train
        let path = data_dir().join("heart_scale.model");
        let model = load_model(&path).unwrap();
        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        // sv_coef should have 1 row (nr_class - 1) with 132 entries
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    #[test]
    fn roundtrip_c_model() {
        // Load C model, save it back, and verify the text matches.
        // NOTE: the comparison goes through `str::lines`, so it does not
        // detect differences in line terminators or in whether the file
        // ends with a newline.
        let path = data_dir().join("heart_scale.model");
        let original_bytes = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let rust_output = String::from_utf8(buf).unwrap();

        // Compare line by line for better diagnostics
        let orig_lines: Vec<&str> = original_bytes.lines().collect();
        let rust_lines: Vec<&str> = rust_output.lines().collect();
        assert_eq!(
            orig_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            orig_lines.len(),
            rust_lines.len()
        );
        for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
            assert_eq!(o, r, "line {} differs:\n  C:    {:?}\n  Rust: {:?}", i + 1, o, r);
        }
    }

    // Expected strings were produced by C's printf; literals exceed f64
    // precision on purpose, hence the clippy allow.
    #[test]
    #[allow(clippy::excessive_precision)]
    fn gfmt_matches_c_printf() {
        // Reference values from C's printf("%.17g|%.8g\n", v, v)
        let cases: &[(f64, &str, &str)] = &[
            (0.5,                    "0.5",                      "0.5"),
            (-1.0,                   "-1",                       "-1"),
            (0.123456789012345,      "0.123456789012345",        "0.12345679"),
            (-0.987654321098765,     "-0.98765432109876505",     "-0.98765432"),
            (0.42446200000000001,    "0.42446200000000001",      "0.424462"),
            (0.0,                    "0",                        "0"),
            (1e-5,                   "1.0000000000000001e-05",   "1e-05"),
            (1e-4,                   "0.0001",                   "0.0001"),
            (1e20,                   "1e+20",                    "1e+20"),
            (-0.25,                  "-0.25",                    "-0.25"),
            (0.75,                   "0.75",                     "0.75"),
            (0.708333,               "0.70833299999999999",      "0.708333"),
            (1.0,                    "1",                        "1"),
        ];
        for &(v, expected_17g, expected_8g) in cases {
            let got_17 = format!("{}", fmt_17g(v));
            let got_8 = format!("{}", fmt_8g(v));
            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
        }
    }

    // Literal precision exceeds f64 on purpose, hence the clippy allow.
    #[test]
    #[allow(clippy::excessive_precision)]
    fn model_roundtrip() {
        // Create a minimal model and verify save → load roundtrip
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![
                vec![SvmNode { index: 1, value: 0.5 }, SvmNode { index: 3, value: -1.0 }],
                vec![SvmNode { index: 1, value: -0.25 }, SvmNode { index: 2, value: 0.75 }],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();

        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());
        // Check rho within tolerance (roundtrip through text)
        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        // Check sv_coef within tolerance
        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
            for (a, b) in row_a.iter().zip(row_b.iter()) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }

    // Header counts above MAX_NR_CLASS / MAX_TOTAL_SV are rejected before
    // the SV section is read.
    #[test]
    fn parse_error_excessive_counts() {
        let input = b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
        let result = load_model_from_reader(&input[..]);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("nr_class exceeds limit"));

        let input = b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
        let result = load_model_from_reader(&input[..]);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("total_sv exceeds limit"));
    }
}