Skip to main content

libsvm_rs/
io.rs

1//! I/O routines for LIBSVM problem and model files.
2//!
3//! File formats match the original LIBSVM exactly, ensuring cross-tool
4//! interoperability.
5
6use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::util::MAX_FEATURE_INDEX;
11use crate::util::parse_feature_index;
12use crate::types::*;
13
14// ─── C-compatible %g formatting ─────────────────────────────────────
15//
16// C's printf `%.Pg` format strips trailing zeros and picks fixed vs.
17// scientific notation based on the exponent. Rust has no built-in
18// equivalent, so we replicate the POSIX specification:
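//   - Let P be the precision (0 is treated as 1) and X the decimal exponent
//     that `%e` would print after rounding the value to P significant digits
//   - Use scientific notation if X < -4 or X >= P
//   - Otherwise use fixed notation with P - 1 - X fractional digits
//   - Strip trailing zeros from the fractional part (and a trailing decimal
//     point); zeros in the integer part are significant and are kept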
22
23use std::fmt;
24
/// Formats an `f64` the way C's `printf("%.Pg", v)` does, for any precision `P`.
///
/// Model files use `%.17g` for doubles in the header and SV coefficients and
/// `%.8g` for feature values, so this type is what makes the output match the
/// reference C implementation byte-for-byte.
struct Gfmt {
    value: f64,
    precision: usize,
}

impl Gfmt {
    /// Wraps `value` for `%g`-style display with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Removes trailing zeros from the fractional part of a decimal string,
    /// then a dangling decimal point. Strings without a `.` are returned
    /// untouched so that zeros in the integer part (e.g. "100000") survive.
    fn trim_fraction(s: &str) -> &str {
        if s.contains('.') {
            s.trim_end_matches('0').trim_end_matches('.')
        } else {
            s
        }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        // C treats a `%g` precision of 0 as 1.
        let p = self.precision.max(1);

        // Non-finite values use C's lowercase spellings, including the sign.
        if v.is_nan() {
            return f.write_str(if v.is_sign_negative() { "-nan" } else { "nan" });
        }
        if v.is_infinite() {
            return f.write_str(if v < 0.0 { "-inf" } else { "inf" });
        }
        if v == 0.0 {
            // Preserve the sign of -0.0.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // POSIX defines the exponent X as the one `%e` prints *after* rounding
        // to P significant digits. Taking it from the rounded scientific form
        // (rather than from `log10`) picks the right notation when rounding
        // carries into the next decade, e.g. 9.9999999e-5 -> 1.00000e-4.
        let sci = format!("{:.*e}", p - 1, v);
        let (mantissa, exp_digits) = sci
            .split_once('e')
            .expect("LowerExp output always contains an exponent");
        let exp: i64 = exp_digits
            .parse()
            .expect("LowerExp exponent is always a valid integer");

        let p_wide = p as i64;
        if exp < -4 || exp >= p_wide {
            // Scientific: C always prints the exponent sign and pads the
            // exponent to at least two digits (e-05, e+20).
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                Self::trim_fraction(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Fixed: P - 1 - X fractional digits; -4 <= X < P keeps this >= 0.
            let decimals = (p_wide - 1 - exp) as usize;
            let fixed = format!("{:.*}", decimals, v);
            f.write_str(Self::trim_fraction(&fixed))
        }
    }
}
89
90/// Format like C's `%.17g`
91fn fmt_17g(v: f64) -> Gfmt {
92    Gfmt::new(v, 17)
93}
94
95/// Format like C's `%.8g`
96fn fmt_8g(v: f64) -> Gfmt {
97    Gfmt::new(v, 8)
98}
99
100/// Format a float like C's `%g` (6 significant digits).
101pub fn format_g(v: f64) -> String {
102    format!("{}", Gfmt::new(v, 6))
103}
104
105/// Format a float like C's `%.17g` (17 significant digits).
106pub fn format_17g(v: f64) -> String {
107    format!("{}", Gfmt::new(v, 17))
108}
109
110// ─── String tables matching original LIBSVM ──────────────────────────
111
/// Model-file spelling of each SVM type, indexed by `SvmType as usize`.
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];

/// Model-file spelling of each kernel type, indexed by `KernelType as usize`.
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
114
115fn svm_type_to_str(t: SvmType) -> &'static str {
116    SVM_TYPE_TABLE[t as usize]
117}
118
119fn kernel_type_to_str(t: KernelType) -> &'static str {
120    KERNEL_TYPE_TABLE[t as usize]
121}
122
123fn str_to_svm_type(s: &str) -> Option<SvmType> {
124    match s {
125        "c_svc" => Some(SvmType::CSvc),
126        "nu_svc" => Some(SvmType::NuSvc),
127        "one_class" => Some(SvmType::OneClass),
128        "epsilon_svr" => Some(SvmType::EpsilonSvr),
129        "nu_svr" => Some(SvmType::NuSvr),
130        _ => None,
131    }
132}
133
134fn str_to_kernel_type(s: &str) -> Option<KernelType> {
135    match s {
136        "linear" => Some(KernelType::Linear),
137        "polynomial" => Some(KernelType::Polynomial),
138        "rbf" => Some(KernelType::Rbf),
139        "sigmoid" => Some(KernelType::Sigmoid),
140        "precomputed" => Some(KernelType::Precomputed),
141        _ => None,
142    }
143}
144
145// ─── Problem file I/O ────────────────────────────────────────────────
146
147/// Load an SVM problem from a file in LIBSVM sparse format.
148///
149/// Format: `<label> <index1>:<value1> <index2>:<value2> ...`
150pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
151    let file = std::fs::File::open(path)?;
152    let reader = std::io::BufReader::new(file);
153    load_problem_from_reader(reader)
154}
155
156/// Load an SVM problem from any buffered reader.
157pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
158    let mut labels = Vec::new();
159    let mut instances = Vec::new();
160
161    for (line_idx, line_result) in reader.lines().enumerate() {
162        let line = line_result?;
163        let line = line.trim();
164        if line.is_empty() {
165            continue;
166        }
167
168        let line_num = line_idx + 1;
169        let mut parts = line.split_whitespace();
170
171        // Parse label
172        let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
173            line: line_num,
174            message: "missing label".into(),
175        })?;
176        let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
177            line: line_num,
178            message: format!("invalid label: {}", label_str),
179        })?;
180
181        // Parse features (must be in ascending index order)
182        let mut nodes = Vec::new();
183        let mut prev_index: i32 = 0;
184        for token in parts {
185            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
186                line: line_num,
187                message: format!("expected index:value, got: {}", token),
188            })?;
189            let index: i32 = parse_feature_index_problem_line(line_num, idx_str)?;
190
191            if !nodes.is_empty() && index <= prev_index {
192                return Err(SvmError::ParseError {
193                    line: line_num,
194                    message: format!(
195                        "feature indices must be ascending: {} follows {}",
196                        index, prev_index
197                    ),
198                });
199            }
200            let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
201                line: line_num,
202                message: format!("invalid value: {}", val_str),
203            })?;
204            prev_index = index;
205            nodes.push(SvmNode { index, value });
206        }
207
208        labels.push(label);
209        instances.push(nodes);
210    }
211
212    Ok(SvmProblem { labels, instances })
213}
214
215// ─── Model file I/O ──────────────────────────────────────────────────
216
/// Largest `nr_class` accepted when loading a model file.
const MAX_NR_CLASS: usize = 65_535;
/// Largest `total_sv` accepted when loading a model file.
const MAX_TOTAL_SV: usize = 10_000_000;
219
220/// Save an SVM model to a file in the original LIBSVM format.
221pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
222    let file = std::fs::File::create(path)?;
223    let writer = std::io::BufWriter::new(file);
224    save_model_to_writer(writer, model)
225}
226
227/// Save an SVM model to any writer.
228pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
229    let param = &model.param;
230
231    writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
232    writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
233
234    if param.kernel_type == KernelType::Polynomial {
235        writeln!(w, "degree {}", param.degree)?;
236    }
237    if matches!(
238        param.kernel_type,
239        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
240    ) {
241        writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
242    }
243    if matches!(
244        param.kernel_type,
245        KernelType::Polynomial | KernelType::Sigmoid
246    ) {
247        writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
248    }
249
250    let nr_class = model.nr_class;
251    writeln!(w, "nr_class {}", nr_class)?;
252    writeln!(w, "total_sv {}", model.sv.len())?;
253
254    // rho
255    write!(w, "rho")?;
256    for r in &model.rho {
257        write!(w, " {}", fmt_17g(*r))?;
258    }
259    writeln!(w)?;
260
261    // label (classification only)
262    if !model.label.is_empty() {
263        write!(w, "label")?;
264        for l in &model.label {
265            write!(w, " {}", l)?;
266        }
267        writeln!(w)?;
268    }
269
270    // probA
271    if !model.prob_a.is_empty() {
272        write!(w, "probA")?;
273        for v in &model.prob_a {
274            write!(w, " {}", fmt_17g(*v))?;
275        }
276        writeln!(w)?;
277    }
278
279    // probB
280    if !model.prob_b.is_empty() {
281        write!(w, "probB")?;
282        for v in &model.prob_b {
283            write!(w, " {}", fmt_17g(*v))?;
284        }
285        writeln!(w)?;
286    }
287
288    // prob_density_marks (one-class)
289    if !model.prob_density_marks.is_empty() {
290        write!(w, "prob_density_marks")?;
291        for v in &model.prob_density_marks {
292            write!(w, " {}", fmt_17g(*v))?;
293        }
294        writeln!(w)?;
295    }
296
297    // nr_sv
298    if !model.n_sv.is_empty() {
299        write!(w, "nr_sv")?;
300        for n in &model.n_sv {
301            write!(w, " {}", n)?;
302        }
303        writeln!(w)?;
304    }
305
306    // SV section
307    writeln!(w, "SV")?;
308    let num_sv = model.sv.len();
309    let num_coef_rows = model.sv_coef.len(); // nr_class - 1
310
311    for i in 0..num_sv {
312        // sv_coef columns for this SV: %.17g
313        for j in 0..num_coef_rows {
314            write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
315        }
316        // sparse features: %.8g
317        if model.param.kernel_type == KernelType::Precomputed {
318            if let Some(node) = model.sv[i].first() {
319                write!(w, "0:{} ", node.value as i32)?;
320            }
321        } else {
322            for node in &model.sv[i] {
323                write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
324            }
325        }
326        writeln!(w)?;
327    }
328
329    Ok(())
330}
331
332/// Load an SVM model from a file in the original LIBSVM format.
333pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
334    let file = std::fs::File::open(path)?;
335    let reader = std::io::BufReader::new(file);
336    load_model_from_reader(reader)
337}
338
339/// Load an SVM model from any buffered reader.
340pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
341    let mut lines = reader.lines();
342
343    // Defaults
344    let mut param = SvmParameter::default();
345    let mut nr_class: usize = 0;
346    let mut total_sv: usize = 0;
347    let mut rho = Vec::new();
348    let mut label = Vec::new();
349    let mut prob_a = Vec::new();
350    let mut prob_b = Vec::new();
351    let mut prob_density_marks = Vec::new();
352    let mut n_sv = Vec::new();
353
354    // Read header
355    let mut line_num: usize = 0;
356    loop {
357        let line = lines.next().ok_or_else(|| {
358            SvmError::ModelFormatError("unexpected end of file in header".into())
359        })??;
360        line_num += 1;
361        let line = line.trim().to_string();
362        if line.is_empty() {
363            continue;
364        }
365
366        let mut parts = line.split_whitespace();
367        let cmd = parts.next().unwrap();
368
369        match cmd {
370            "svm_type" => {
371                let val = parts.next().ok_or_else(|| {
372                    SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
373                })?;
374                param.svm_type = str_to_svm_type(val).ok_or_else(|| {
375                    SvmError::ModelFormatError(format!(
376                        "line {}: unknown svm_type: {}",
377                        line_num, val
378                    ))
379                })?;
380            }
381            "kernel_type" => {
382                let val = parts.next().ok_or_else(|| {
383                    SvmError::ModelFormatError(format!(
384                        "line {}: missing kernel_type value",
385                        line_num
386                    ))
387                })?;
388                param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
389                    SvmError::ModelFormatError(format!(
390                        "line {}: unknown kernel_type: {}",
391                        line_num, val
392                    ))
393                })?;
394            }
395            "degree" => {
396                param.degree = parse_single(&mut parts, line_num, "degree")?;
397            }
398            "gamma" => {
399                param.gamma = parse_single(&mut parts, line_num, "gamma")?;
400            }
401            "coef0" => {
402                param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
403            }
404            "nr_class" => {
405                nr_class = parse_single(&mut parts, line_num, "nr_class")?;
406                if nr_class > MAX_NR_CLASS {
407                    return Err(SvmError::ModelFormatError(format!(
408                        "line {}: nr_class exceeds limit ({})",
409                        line_num, MAX_NR_CLASS
410                    )));
411                }
412            }
413            "total_sv" => {
414                total_sv = parse_single(&mut parts, line_num, "total_sv")?;
415                if total_sv > MAX_TOTAL_SV {
416                    return Err(SvmError::ModelFormatError(format!(
417                        "line {}: total_sv exceeds limit ({})",
418                        line_num, MAX_TOTAL_SV
419                    )));
420                }
421            }
422            "rho" => {
423                rho = parse_multiple(&mut parts, line_num, "rho")?;
424            }
425            "label" => {
426                label = parse_multiple(&mut parts, line_num, "label")?;
427            }
428            "probA" => {
429                prob_a = parse_multiple(&mut parts, line_num, "probA")?;
430            }
431            "probB" => {
432                prob_b = parse_multiple(&mut parts, line_num, "probB")?;
433            }
434            "prob_density_marks" => {
435                prob_density_marks =
436                    parse_multiple(&mut parts, line_num, "prob_density_marks")?;
437            }
438            "nr_sv" => {
439                n_sv = parts
440                    .map(|s| {
441                        s.parse::<usize>().map_err(|_| {
442                            SvmError::ModelFormatError(format!(
443                                "line {}: invalid nr_sv value: {}",
444                                line_num, s
445                            ))
446                        })
447                    })
448                    .collect::<Result<Vec<_>, _>>()?;
449            }
450            "SV" => break,
451            _ => {
452                return Err(SvmError::ModelFormatError(format!(
453                    "line {}: unknown keyword: {}",
454                    line_num, cmd
455                )));
456            }
457        }
458    }
459
460    // Read SV section
461    let m = if nr_class > 1 { nr_class - 1 } else { 1 };
462    let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
463    let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
464
465    for _ in 0..total_sv {
466        let line = lines.next().ok_or_else(|| {
467            SvmError::ModelFormatError("unexpected end of file in SV section".into())
468        })??;
469        line_num += 1;
470        let line = line.trim();
471        if line.is_empty() {
472            continue;
473        }
474
475        let mut parts = line.split_whitespace();
476
477        // First m tokens are sv_coef values
478        for (k, coef_row) in sv_coef.iter_mut().enumerate() {
479            let val_str = parts.next().ok_or_else(|| {
480                SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
481            })?;
482            let val: f64 = val_str.parse().map_err(|_| {
483                SvmError::ModelFormatError(format!(
484                    "line {}: invalid sv_coef: {}",
485                    line_num, val_str
486                ))
487            })?;
488            coef_row.push(val);
489        }
490
491        // Remaining tokens are index:value pairs
492        let mut nodes = Vec::new();
493        for token in parts {
494            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
495                SvmError::ModelFormatError(format!(
496                    "line {}: expected index:value, got: {}",
497                    line_num, token
498                ))
499            })?;
500            let index: i32 = parse_feature_index_model_line(line_num, idx_str)?;
501
502            let value: f64 = val_str.parse().map_err(|_| {
503                SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
504            })?;
505            nodes.push(SvmNode { index, value });
506        }
507        sv.push(nodes);
508    }
509
510    Ok(SvmModel {
511        param,
512        nr_class,
513        sv,
514        sv_coef,
515        rho,
516        prob_a,
517        prob_b,
518        prob_density_marks,
519        sv_indices: Vec::new(), // not stored in model file
520        label,
521        n_sv,
522    })
523}
524
525// ─── Helper parsers ──────────────────────────────────────────────────
526
527fn parse_feature_index_problem_line(line_num: usize, idx_str: &str) -> Result<i32, SvmError> {
528    parse_feature_index(idx_str, MAX_FEATURE_INDEX).map_err(|msg| SvmError::ParseError {
529        line: line_num,
530        message: msg,
531    })
532}
533
534fn parse_feature_index_model_line(line_num: usize, idx_str: &str) -> Result<i32, SvmError> {
535    parse_feature_index(idx_str, MAX_FEATURE_INDEX)
536        .map_err(|msg| SvmError::ModelFormatError(format!("line {}: {}", line_num, msg)))
537}
538
539fn parse_single<T: std::str::FromStr>(
540    parts: &mut std::str::SplitWhitespace<'_>,
541    line_num: usize,
542    field: &str,
543) -> Result<T, SvmError> {
544    let val_str = parts.next().ok_or_else(|| {
545        SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
546    })?;
547    val_str.parse().map_err(|_| {
548        SvmError::ModelFormatError(format!(
549            "line {}: invalid {} value: {}",
550            line_num, field, val_str
551        ))
552    })
553}
554
555fn parse_multiple<T: std::str::FromStr>(
556    parts: &mut std::str::SplitWhitespace<'_>,
557    line_num: usize,
558    field: &str,
559) -> Result<Vec<T>, SvmError> {
560    parts
561        .map(|s| {
562            s.parse::<T>().map_err(|_| {
563                SvmError::ModelFormatError(format!(
564                    "line {}: invalid {} value: {}",
565                    line_num, field, s
566                ))
567            })
568        })
569        .collect()
570}
571
572// ─── Tests ───────────────────────────────────────────────────────────
573
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Shared LIBSVM fixture directory, two levels above this crate.
    fn data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("..");
        dir.push("..");
        dir.push("data");
        dir
    }

    /// Shorthand for a sparse feature node.
    fn node(index: i32, value: f64) -> SvmNode {
        SvmNode { index, value }
    }

    /// Unwraps the error of `result` (panicking on `Ok`) and renders it.
    fn err_text<T: std::fmt::Debug>(result: Result<T, SvmError>) -> String {
        result.unwrap_err().to_string()
    }

    /// Builds the text of a two-class linear c_svc model declaring `total_sv`
    /// support vectors. `extra_header` is inserted just before the `SV` marker.
    fn linear_model(total_sv: usize, extra_header: &str, sv_body: &str) -> String {
        format!(
            "svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv {}\nrho 0\n{}SV\n{}",
            total_sv, extra_header, sv_body
        )
    }

    #[test]
    fn parse_heart_scale() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        // First instance: label +1 with 12 features (index 11 is absent).
        assert_eq!(problem.labels[0], 1.0);
        assert_eq!(problem.instances[0][0], node(1, 0.708333));
        assert_eq!(problem.instances[0].len(), 12);
    }

    #[test]
    fn parse_iris() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        assert_eq!(problem.labels.len(), 150);
        // Three distinct classes: 1, 2, 3.
        let mut classes: Vec<i64> = problem.labels.iter().map(|&l| l as i64).collect();
        classes.sort_unstable();
        classes.dedup();
        assert_eq!(classes.len(), 3);
    }

    #[test]
    fn parse_housing() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        assert_eq!(problem.labels.len(), 506);
        // Regression targets are continuous.
        assert!((problem.labels[0] - 24.0).abs() < 1e-10);
    }

    #[test]
    fn parse_empty_lines() {
        let problem = load_problem_from_reader("+1 1:0.5\n\n-1 2:0.3\n".as_bytes()).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    #[test]
    fn parse_error_unsorted_indices() {
        let msg = err_text(load_problem_from_reader("+1 3:0.5 1:0.3\n".as_bytes()));
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    #[test]
    fn parse_error_duplicate_indices() {
        assert!(load_problem_from_reader("+1 1:0.5 1:0.3\n".as_bytes()).is_err());
    }

    #[test]
    fn parse_error_missing_colon() {
        assert!(load_problem_from_reader("+1 1:0.5 bad_token\n".as_bytes()).is_err());
    }

    #[test]
    // Reference constants are C's full %.17g output, beyond f64 precision.
    #[allow(clippy::excessive_precision)]
    fn load_c_trained_model() {
        // A model produced by the original C LIBSVM svm-train.
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        // One coefficient row (nr_class - 1) holding one entry per SV.
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    #[test]
    fn roundtrip_c_model() {
        // A C-written model must survive load + save byte-for-byte.
        let path = data_dir().join("heart_scale.model");
        let original = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let rewritten = String::from_utf8(buf).unwrap();

        // Compare line by line for precise diagnostics.
        let c_lines: Vec<&str> = original.lines().collect();
        let rust_lines: Vec<&str> = rewritten.lines().collect();
        assert_eq!(
            c_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            c_lines.len(),
            rust_lines.len()
        );
        for (line_no, (c, r)) in (1..).zip(c_lines.iter().zip(&rust_lines)) {
            assert_eq!(
                c, r,
                "line {} differs:\n  C:    {:?}\n  Rust: {:?}",
                line_no, c, r
            );
        }
    }

    #[test]
    // Inputs mirror the literals fed to C's printf, beyond f64 precision.
    #[allow(clippy::excessive_precision)]
    fn gfmt_matches_c_printf() {
        // Reference values from C's printf("%.17g|%.8g\n", v, v).
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];
        for &(v, want_17g, want_8g) in cases {
            assert_eq!(fmt_17g(v).to_string(), want_17g, "%.17g mismatch for {}", v);
            assert_eq!(fmt_8g(v).to_string(), want_8g, "%.8g mismatch for {}", v);
        }
    }

    #[test]
    // rho literal matches C's %.17g output, beyond f64 precision.
    #[allow(clippy::excessive_precision)]
    fn model_roundtrip() {
        // A minimal model must survive save -> load.
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![
                vec![node(1, 0.5), node(3, -1.0)],
                vec![node(1, -0.25), node(2, 0.75)],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());
        // Text round-trips are compared within a tolerance.
        for (a, b) in loaded.rho.iter().zip(&model.rho) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        for (row_a, row_b) in loaded.sv_coef.iter().zip(&model.sv_coef) {
            for (a, b) in row_a.iter().zip(row_b) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }

    #[test]
    fn parse_error_excessive_counts() {
        let too_many_classes =
            "svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
        assert!(err_text(load_model_from_reader(too_many_classes.as_bytes()))
            .contains("nr_class exceeds limit"));

        let too_many_svs =
            "svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
        assert!(err_text(load_model_from_reader(too_many_svs.as_bytes()))
            .contains("total_sv exceeds limit"));
    }

    #[test]
    fn parse_error_excessive_feature_index() {
        let problem_err = err_text(load_problem_from_reader("1 10000001:1\n".as_bytes()));
        assert!(problem_err.contains("feature index 10000001 exceeds limit"));

        let model_text = linear_model(1, "", "0.1 10000001:1\n");
        let model_err = err_text(load_model_from_reader(model_text.as_bytes()));
        assert!(model_err.contains("feature index 10000001 exceeds limit"));
    }

    #[test]
    fn parse_error_unknown_model_keyword() {
        let msg = err_text(load_model_from_reader("bad_key value\n".as_bytes()));
        assert!(msg.contains("unknown keyword"));
    }

    #[test]
    fn parse_error_missing_or_unknown_model_values() {
        let missing = err_text(load_model_from_reader("svm_type\n".as_bytes()));
        assert!(missing.contains("missing svm_type value"));

        let unknown = err_text(load_model_from_reader("svm_type unknown_type\n".as_bytes()));
        assert!(unknown.contains("unknown svm_type"));
    }

    #[test]
    fn parse_error_invalid_nr_sv_entry() {
        let text = linear_model(1, "nr_sv a 1\n", "0.1 1:0.5\n");
        let msg = err_text(load_model_from_reader(text.as_bytes()));
        assert!(msg.contains("invalid nr_sv value"));
    }

    #[test]
    fn parse_error_in_sv_section_tokens() {
        let missing_coef = linear_model(1, "", "1:0.5\n");
        assert!(err_text(load_model_from_reader(missing_coef.as_bytes()))
            .contains("invalid sv_coef"));

        let bad_feature = linear_model(1, "", "0.1 bad\n");
        assert!(err_text(load_model_from_reader(bad_feature.as_bytes()))
            .contains("expected index:value"));
    }

    #[test]
    fn parse_error_unexpected_eof_in_header_and_sv_section() {
        let header_only = err_text(load_model_from_reader("svm_type c_svc\n".as_bytes()));
        assert!(header_only.contains("unexpected end of file in header"));

        let short_sv = linear_model(2, "", "0.1 1:0.5\n");
        assert!(err_text(load_model_from_reader(short_sv.as_bytes()))
            .contains("unexpected end of file in SV section"));
    }

    #[test]
    fn save_precomputed_model_writes_zero_index() {
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Precomputed,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![vec![node(0, 7.0)]],
            sv_coef: vec![vec![0.25]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let out = String::from_utf8(buf).unwrap();
        assert!(out.contains("kernel_type precomputed"));
        assert!(out.contains("0:7"));
    }
}
933}