Skip to main content

libsvm_rs/
io.rs

1//! I/O routines for LIBSVM problem and model files.
2//!
3//! File formats match the original LIBSVM exactly, ensuring cross-tool
4//! interoperability.
5
6use std::io::{BufRead, Write};
7use std::path::Path;
8
9use crate::error::SvmError;
10use crate::types::*;
11
12// ─── C-compatible %g formatting ─────────────────────────────────────
13//
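// C's printf `%.Pg` format picks fixed vs. scientific notation based on the
// decimal exponent and strips trailing zeros. Rust has no built-in
// equivalent, so we replicate the POSIX specification. Let P be the
// precision (a precision of 0 counts as 1) and X the exponent the value
// would have in `%.{P-1}e` style, i.e. *after* rounding to P significant
// digits:
//   - Use scientific notation (P - 1 fraction digits) if X < -4 or X >= P
//   - Otherwise use fixed notation (P - 1 - X fraction digits)
//   - Strip trailing zeros from the fraction, then a dangling decimal point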
20
21use std::fmt;
22
/// Formats an `f64` like C's `printf("%.Pg", value)` with `P = precision`.
///
/// Use it through its `Display` impl, e.g. `format!("{}", Gfmt::new(v, 17))`
/// reproduces `%.17g`. As in C, a precision of 0 is treated as 1.
struct Gfmt {
    /// The number to format.
    value: f64,
    /// Number of significant digits (C's `P`).
    precision: usize,
}

impl Gfmt {
    /// Creates a formatter for `value` with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }
}

/// Removes trailing zeros from the fractional part of a formatted number,
/// then a dangling decimal point.
///
/// Strings without a `.` are returned unchanged, so integral zeros such as
/// those in `"100000"` are preserved.
fn trim_fraction_zeros(s: &str) -> &str {
    if s.contains('.') {
        s.trim_end_matches('0').trim_end_matches('.')
    } else {
        s
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;

        if v.is_nan() {
            // POSIX spells NaN "[-]nan"; Rust's `{}` would print "NaN".
            return f.write_str(if v.is_sign_negative() { "-nan" } else { "nan" });
        }
        if v.is_infinite() {
            return f.write_str(if v < 0.0 { "-inf" } else { "inf" });
        }
        if v == 0.0 {
            // `%g` prints zero as "0", keeping the sign of -0.0.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // C treats a precision of zero as one significant digit.
        let p = self.precision.max(1);

        // The style decision must use the exponent *after* rounding to `p`
        // significant digits, which is exactly what `%.{p-1}e` yields. Using
        // floor(log10(|v|)) of the unrounded value gets values such as
        // 999999.7 (`%g` -> "1e+06") wrong.
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exp_digits) = sci
            .split_once('e')
            .expect("LowerExp output of a finite f64 contains an exponent");
        let exp: i32 = exp_digits
            .parse()
            .expect("LowerExp exponent is a decimal integer");

        let use_scientific = exp < -4 || (exp >= 0 && exp as usize >= p);
        if use_scientific {
            // C always signs the exponent and pads it to at least two digits
            // (e-05, e+20), whereas Rust writes e-5 / e20.
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                trim_fraction_zeros(mantissa),
                sign,
                exp.abs()
            )
        } else {
            // Fixed style with p - 1 - exp fraction digits keeps exactly p
            // significant digits. Here -4 <= exp <= p - 1, so this never
            // underflows.
            let decimals = if exp >= 0 {
                p - 1 - exp as usize
            } else {
                p - 1 + exp.unsigned_abs() as usize
            };
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(trim_fraction_zeros(&fixed))
        }
    }
}
87
88/// Format like C's `%.17g`
89fn fmt_17g(v: f64) -> Gfmt {
90    Gfmt::new(v, 17)
91}
92
93/// Format like C's `%.8g`
94fn fmt_8g(v: f64) -> Gfmt {
95    Gfmt::new(v, 8)
96}
97
98/// Format a float like C's `%g` (6 significant digits).
99pub fn format_g(v: f64) -> String {
100    format!("{}", Gfmt::new(v, 6))
101}
102
103/// Format a float like C's `%.17g` (17 significant digits).
104pub fn format_17g(v: f64) -> String {
105    format!("{}", Gfmt::new(v, 17))
106}
107
108// ─── String tables matching original LIBSVM ──────────────────────────
109
110const SVM_TYPE_TABLE: &[&str] = &["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"];
111const KERNEL_TYPE_TABLE: &[&str] = &["linear", "polynomial", "rbf", "sigmoid", "precomputed"];
112
113fn svm_type_to_str(t: SvmType) -> &'static str {
114    SVM_TYPE_TABLE[t as usize]
115}
116
117fn kernel_type_to_str(t: KernelType) -> &'static str {
118    KERNEL_TYPE_TABLE[t as usize]
119}
120
121fn str_to_svm_type(s: &str) -> Option<SvmType> {
122    match s {
123        "c_svc" => Some(SvmType::CSvc),
124        "nu_svc" => Some(SvmType::NuSvc),
125        "one_class" => Some(SvmType::OneClass),
126        "epsilon_svr" => Some(SvmType::EpsilonSvr),
127        "nu_svr" => Some(SvmType::NuSvr),
128        _ => None,
129    }
130}
131
132fn str_to_kernel_type(s: &str) -> Option<KernelType> {
133    match s {
134        "linear" => Some(KernelType::Linear),
135        "polynomial" => Some(KernelType::Polynomial),
136        "rbf" => Some(KernelType::Rbf),
137        "sigmoid" => Some(KernelType::Sigmoid),
138        "precomputed" => Some(KernelType::Precomputed),
139        _ => None,
140    }
141}
142
143// ─── Problem file I/O ────────────────────────────────────────────────
144
/// Largest feature index accepted by the readers in this module. A larger
/// index in a problem file or in a model's SV section is rejected with a
/// parse/format error instead of being stored.
const MAX_FEATURE_INDEX: i32 = 10_000_000;
146
147/// Load an SVM problem from a file in LIBSVM sparse format.
148///
149/// Format: `<label> <index1>:<value1> <index2>:<value2> ...`
150pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
151    let file = std::fs::File::open(path)?;
152    let reader = std::io::BufReader::new(file);
153    load_problem_from_reader(reader)
154}
155
156/// Load an SVM problem from any buffered reader.
157pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
158    let mut labels = Vec::new();
159    let mut instances = Vec::new();
160
161    for (line_idx, line_result) in reader.lines().enumerate() {
162        let line = line_result?;
163        let line = line.trim();
164        if line.is_empty() {
165            continue;
166        }
167
168        let line_num = line_idx + 1;
169        let mut parts = line.split_whitespace();
170
171        // Parse label
172        let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
173            line: line_num,
174            message: "missing label".into(),
175        })?;
176        let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
177            line: line_num,
178            message: format!("invalid label: {}", label_str),
179        })?;
180
181        // Parse features (must be in ascending index order)
182        let mut nodes = Vec::new();
183        let mut prev_index: i32 = 0;
184        for token in parts {
185            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
186                line: line_num,
187                message: format!("expected index:value, got: {}", token),
188            })?;
189            let index: i32 = idx_str.parse().map_err(|_| SvmError::ParseError {
190                line: line_num,
191                message: format!("invalid index: {}", idx_str),
192            })?;
193
194            if index > MAX_FEATURE_INDEX {
195                return Err(SvmError::ParseError {
196                    line: line_num,
197                    message: format!(
198                        "feature index {} exceeds limit ({})",
199                        index, MAX_FEATURE_INDEX
200                    ),
201                });
202            }
203
204            if !nodes.is_empty() && index <= prev_index {
205                return Err(SvmError::ParseError {
206                    line: line_num,
207                    message: format!(
208                        "feature indices must be ascending: {} follows {}",
209                        index, prev_index
210                    ),
211                });
212            }
213            let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
214                line: line_num,
215                message: format!("invalid value: {}", val_str),
216            })?;
217            prev_index = index;
218            nodes.push(SvmNode { index, value });
219        }
220
221        labels.push(label);
222        instances.push(nodes);
223    }
224
225    Ok(SvmProblem { labels, instances })
226}
227
228// ─── Model file I/O ──────────────────────────────────────────────────
229
/// Largest `nr_class` header value accepted by the model loader; a larger
/// value is rejected with a `ModelFormatError`.
const MAX_NR_CLASS: usize = 65535;
/// Largest `total_sv` header value accepted by the model loader; a larger
/// value is rejected with a `ModelFormatError`.
const MAX_TOTAL_SV: usize = 10_000_000;
232
233/// Save an SVM model to a file in the original LIBSVM format.
234pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
235    let file = std::fs::File::create(path)?;
236    let writer = std::io::BufWriter::new(file);
237    save_model_to_writer(writer, model)
238}
239
240/// Save an SVM model to any writer.
241pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
242    let param = &model.param;
243
244    writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
245    writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
246
247    if param.kernel_type == KernelType::Polynomial {
248        writeln!(w, "degree {}", param.degree)?;
249    }
250    if matches!(
251        param.kernel_type,
252        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
253    ) {
254        writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
255    }
256    if matches!(
257        param.kernel_type,
258        KernelType::Polynomial | KernelType::Sigmoid
259    ) {
260        writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
261    }
262
263    let nr_class = model.nr_class;
264    writeln!(w, "nr_class {}", nr_class)?;
265    writeln!(w, "total_sv {}", model.sv.len())?;
266
267    // rho
268    write!(w, "rho")?;
269    for r in &model.rho {
270        write!(w, " {}", fmt_17g(*r))?;
271    }
272    writeln!(w)?;
273
274    // label (classification only)
275    if !model.label.is_empty() {
276        write!(w, "label")?;
277        for l in &model.label {
278            write!(w, " {}", l)?;
279        }
280        writeln!(w)?;
281    }
282
283    // probA
284    if !model.prob_a.is_empty() {
285        write!(w, "probA")?;
286        for v in &model.prob_a {
287            write!(w, " {}", fmt_17g(*v))?;
288        }
289        writeln!(w)?;
290    }
291
292    // probB
293    if !model.prob_b.is_empty() {
294        write!(w, "probB")?;
295        for v in &model.prob_b {
296            write!(w, " {}", fmt_17g(*v))?;
297        }
298        writeln!(w)?;
299    }
300
301    // prob_density_marks (one-class)
302    if !model.prob_density_marks.is_empty() {
303        write!(w, "prob_density_marks")?;
304        for v in &model.prob_density_marks {
305            write!(w, " {}", fmt_17g(*v))?;
306        }
307        writeln!(w)?;
308    }
309
310    // nr_sv
311    if !model.n_sv.is_empty() {
312        write!(w, "nr_sv")?;
313        for n in &model.n_sv {
314            write!(w, " {}", n)?;
315        }
316        writeln!(w)?;
317    }
318
319    // SV section
320    writeln!(w, "SV")?;
321    let num_sv = model.sv.len();
322    let num_coef_rows = model.sv_coef.len(); // nr_class - 1
323
324    for i in 0..num_sv {
325        // sv_coef columns for this SV: %.17g
326        for j in 0..num_coef_rows {
327            write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
328        }
329        // sparse features: %.8g
330        if model.param.kernel_type == KernelType::Precomputed {
331            if let Some(node) = model.sv[i].first() {
332                write!(w, "0:{} ", node.value as i32)?;
333            }
334        } else {
335            for node in &model.sv[i] {
336                write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
337            }
338        }
339        writeln!(w)?;
340    }
341
342    Ok(())
343}
344
345/// Load an SVM model from a file in the original LIBSVM format.
346pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
347    let file = std::fs::File::open(path)?;
348    let reader = std::io::BufReader::new(file);
349    load_model_from_reader(reader)
350}
351
352/// Load an SVM model from any buffered reader.
353pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
354    let mut lines = reader.lines();
355
356    // Defaults
357    let mut param = SvmParameter::default();
358    let mut nr_class: usize = 0;
359    let mut total_sv: usize = 0;
360    let mut rho = Vec::new();
361    let mut label = Vec::new();
362    let mut prob_a = Vec::new();
363    let mut prob_b = Vec::new();
364    let mut prob_density_marks = Vec::new();
365    let mut n_sv = Vec::new();
366
367    // Read header
368    let mut line_num: usize = 0;
369    loop {
370        let line = lines.next().ok_or_else(|| {
371            SvmError::ModelFormatError("unexpected end of file in header".into())
372        })??;
373        line_num += 1;
374        let line = line.trim().to_string();
375        if line.is_empty() {
376            continue;
377        }
378
379        let mut parts = line.split_whitespace();
380        let cmd = parts.next().unwrap();
381
382        match cmd {
383            "svm_type" => {
384                let val = parts.next().ok_or_else(|| {
385                    SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
386                })?;
387                param.svm_type = str_to_svm_type(val).ok_or_else(|| {
388                    SvmError::ModelFormatError(format!(
389                        "line {}: unknown svm_type: {}",
390                        line_num, val
391                    ))
392                })?;
393            }
394            "kernel_type" => {
395                let val = parts.next().ok_or_else(|| {
396                    SvmError::ModelFormatError(format!(
397                        "line {}: missing kernel_type value",
398                        line_num
399                    ))
400                })?;
401                param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
402                    SvmError::ModelFormatError(format!(
403                        "line {}: unknown kernel_type: {}",
404                        line_num, val
405                    ))
406                })?;
407            }
408            "degree" => {
409                param.degree = parse_single(&mut parts, line_num, "degree")?;
410            }
411            "gamma" => {
412                param.gamma = parse_single(&mut parts, line_num, "gamma")?;
413            }
414            "coef0" => {
415                param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
416            }
417            "nr_class" => {
418                nr_class = parse_single(&mut parts, line_num, "nr_class")?;
419                if nr_class > MAX_NR_CLASS {
420                    return Err(SvmError::ModelFormatError(format!(
421                        "line {}: nr_class exceeds limit ({})",
422                        line_num, MAX_NR_CLASS
423                    )));
424                }
425            }
426            "total_sv" => {
427                total_sv = parse_single(&mut parts, line_num, "total_sv")?;
428                if total_sv > MAX_TOTAL_SV {
429                    return Err(SvmError::ModelFormatError(format!(
430                        "line {}: total_sv exceeds limit ({})",
431                        line_num, MAX_TOTAL_SV
432                    )));
433                }
434            }
435            "rho" => {
436                rho = parse_multiple_f64(&mut parts, line_num, "rho")?;
437            }
438            "label" => {
439                label = parse_multiple_i32(&mut parts, line_num, "label")?;
440            }
441            "probA" => {
442                prob_a = parse_multiple_f64(&mut parts, line_num, "probA")?;
443            }
444            "probB" => {
445                prob_b = parse_multiple_f64(&mut parts, line_num, "probB")?;
446            }
447            "prob_density_marks" => {
448                prob_density_marks =
449                    parse_multiple_f64(&mut parts, line_num, "prob_density_marks")?;
450            }
451            "nr_sv" => {
452                n_sv = parts
453                    .map(|s| {
454                        s.parse::<usize>().map_err(|_| {
455                            SvmError::ModelFormatError(format!(
456                                "line {}: invalid nr_sv value: {}",
457                                line_num, s
458                            ))
459                        })
460                    })
461                    .collect::<Result<Vec<_>, _>>()?;
462            }
463            "SV" => break,
464            _ => {
465                return Err(SvmError::ModelFormatError(format!(
466                    "line {}: unknown keyword: {}",
467                    line_num, cmd
468                )));
469            }
470        }
471    }
472
473    // Read SV section
474    let m = if nr_class > 1 { nr_class - 1 } else { 1 };
475    let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::with_capacity(total_sv)).collect();
476    let mut sv: Vec<Vec<SvmNode>> = Vec::with_capacity(total_sv);
477
478    for _ in 0..total_sv {
479        let line = lines.next().ok_or_else(|| {
480            SvmError::ModelFormatError("unexpected end of file in SV section".into())
481        })??;
482        line_num += 1;
483        let line = line.trim();
484        if line.is_empty() {
485            continue;
486        }
487
488        let mut parts = line.split_whitespace();
489
490        // First m tokens are sv_coef values
491        for (k, coef_row) in sv_coef.iter_mut().enumerate() {
492            let val_str = parts.next().ok_or_else(|| {
493                SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
494            })?;
495            let val: f64 = val_str.parse().map_err(|_| {
496                SvmError::ModelFormatError(format!(
497                    "line {}: invalid sv_coef: {}",
498                    line_num, val_str
499                ))
500            })?;
501            coef_row.push(val);
502        }
503
504        // Remaining tokens are index:value pairs
505        let mut nodes = Vec::new();
506        for token in parts {
507            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
508                SvmError::ModelFormatError(format!(
509                    "line {}: expected index:value, got: {}",
510                    line_num, token
511                ))
512            })?;
513            let index: i32 = idx_str.parse().map_err(|_| {
514                SvmError::ModelFormatError(format!("line {}: invalid index: {}", line_num, idx_str))
515            })?;
516
517            if index > MAX_FEATURE_INDEX {
518                return Err(SvmError::ModelFormatError(format!(
519                    "line {}: feature index {} exceeds limit ({})",
520                    line_num, index, MAX_FEATURE_INDEX
521                )));
522            }
523
524            let value: f64 = val_str.parse().map_err(|_| {
525                SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
526            })?;
527            nodes.push(SvmNode { index, value });
528        }
529        sv.push(nodes);
530    }
531
532    Ok(SvmModel {
533        param,
534        nr_class,
535        sv,
536        sv_coef,
537        rho,
538        prob_a,
539        prob_b,
540        prob_density_marks,
541        sv_indices: Vec::new(), // not stored in model file
542        label,
543        n_sv,
544    })
545}
546
547// ─── Helper parsers ──────────────────────────────────────────────────
548
549fn parse_single<T: std::str::FromStr>(
550    parts: &mut std::str::SplitWhitespace<'_>,
551    line_num: usize,
552    field: &str,
553) -> Result<T, SvmError> {
554    let val_str = parts.next().ok_or_else(|| {
555        SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
556    })?;
557    val_str.parse().map_err(|_| {
558        SvmError::ModelFormatError(format!(
559            "line {}: invalid {} value: {}",
560            line_num, field, val_str
561        ))
562    })
563}
564
565fn parse_multiple_f64(
566    parts: &mut std::str::SplitWhitespace<'_>,
567    line_num: usize,
568    field: &str,
569) -> Result<Vec<f64>, SvmError> {
570    parts
571        .map(|s| {
572            s.parse::<f64>().map_err(|_| {
573                SvmError::ModelFormatError(format!(
574                    "line {}: invalid {} value: {}",
575                    line_num, field, s
576                ))
577            })
578        })
579        .collect()
580}
581
582fn parse_multiple_i32(
583    parts: &mut std::str::SplitWhitespace<'_>,
584    line_num: usize,
585    field: &str,
586) -> Result<Vec<i32>, SvmError> {
587    parts
588        .map(|s| {
589            s.parse::<i32>().map_err(|_| {
590                SvmError::ModelFormatError(format!(
591                    "line {}: invalid {} value: {}",
592                    line_num, field, s
593                ))
594            })
595        })
596        .collect()
597}
598
599// ─── Tests ───────────────────────────────────────────────────────────
600
// Unit tests for this module. Some tests read fixture files (`heart_scale`,
// `iris.scale`, `housing_scale`, `heart_scale.model`) from the directory
// returned by `data_dir()`; the others use in-memory byte strings.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Fixture directory: `<CARGO_MANIFEST_DIR>/../../data`.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// Loads the `heart_scale` fixture and checks its size and first instance.
    #[test]
    fn parse_heart_scale() {
        let path = data_dir().join("heart_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 270);
        assert_eq!(problem.instances.len(), 270);
        // First instance: +1 label, 12 features (index 11 is missing/sparse)
        assert_eq!(problem.labels[0], 1.0);
        assert_eq!(
            problem.instances[0][0],
            SvmNode {
                index: 1,
                value: 0.708333
            }
        );
        assert_eq!(problem.instances[0].len(), 12);
    }

    /// Loads the `iris.scale` fixture and checks it has three distinct labels.
    #[test]
    fn parse_iris() {
        let path = data_dir().join("iris.scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 150);
        // 3 classes: 1, 2, 3
        let classes: std::collections::HashSet<i64> =
            problem.labels.iter().map(|&l| l as i64).collect();
        assert_eq!(classes.len(), 3);
    }

    /// Loads the `housing_scale` fixture and checks its size and first label.
    #[test]
    fn parse_housing() {
        let path = data_dir().join("housing_scale");
        let problem = load_problem(&path).unwrap();
        assert_eq!(problem.labels.len(), 506);
        // Regression: labels are continuous
        assert!((problem.labels[0] - 24.0).abs() < 1e-10);
    }

    /// Blank lines in a problem file are skipped, not turned into instances.
    #[test]
    fn parse_empty_lines() {
        let input = b"+1 1:0.5\n\n-1 2:0.3\n";
        let problem = load_problem_from_reader(&input[..]).unwrap();
        assert_eq!(problem.labels.len(), 2);
    }

    /// Descending feature indices are rejected with an "ascending" message.
    #[test]
    fn parse_error_unsorted_indices() {
        let input = b"+1 3:0.5 1:0.3\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("ascending"), "error: {}", msg);
    }

    /// A repeated feature index on one line is rejected.
    #[test]
    fn parse_error_duplicate_indices() {
        let input = b"+1 1:0.5 1:0.3\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
    }

    /// A feature token without `:` is rejected.
    #[test]
    fn parse_error_missing_colon() {
        let input = b"+1 1:0.5 bad_token\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
    }

    /// Loads a model written by C LIBSVM and checks header values and SV counts.
    #[test]
    // The expected values below are written with 17 significant digits to
    // mirror the `%.17g` text in the model file; clippy flags the extra digits.
    #[allow(clippy::excessive_precision)]
    fn load_c_trained_model() {
        // Load a model produced by the original C LIBSVM svm-train
        let path = data_dir().join("heart_scale.model");
        let model = load_model(&path).unwrap();
        assert_eq!(model.nr_class, 2);
        assert_eq!(model.param.svm_type, SvmType::CSvc);
        assert_eq!(model.param.kernel_type, KernelType::Rbf);
        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
        assert_eq!(model.sv.len(), 132);
        assert_eq!(model.label, vec![1, -1]);
        assert_eq!(model.n_sv, vec![64, 68]);
        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
        // sv_coef should have 1 row (nr_class - 1) with 132 entries
        assert_eq!(model.sv_coef.len(), 1);
        assert_eq!(model.sv_coef[0].len(), 132);
    }

    /// Re-saving a C-produced model reproduces the original text line by line.
    #[test]
    fn roundtrip_c_model() {
        // Load the C model, save it back, and verify that every line matches.
        // The comparison uses `str::lines`, so line terminators are not compared.
        let path = data_dir().join("heart_scale.model");
        let original_bytes = std::fs::read_to_string(&path).unwrap();
        let model = load_model(&path).unwrap();

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let rust_output = String::from_utf8(buf).unwrap();

        // Compare line by line for better diagnostics
        let orig_lines: Vec<&str> = original_bytes.lines().collect();
        let rust_lines: Vec<&str> = rust_output.lines().collect();
        assert_eq!(
            orig_lines.len(),
            rust_lines.len(),
            "line count mismatch: C={} Rust={}",
            orig_lines.len(),
            rust_lines.len()
        );
        for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
            assert_eq!(
                o,
                r,
                "line {} differs:\n  C:    {:?}\n  Rust: {:?}",
                i + 1,
                o,
                r
            );
        }
    }

    /// `%.17g` / `%.8g` emulation against reference strings from C's printf.
    #[test]
    // Inputs are spelled with more digits than an f64 holds so that they
    // match the C reference strings; clippy flags that.
    #[allow(clippy::excessive_precision)]
    fn gfmt_matches_c_printf() {
        // Reference values from C's printf("%.17g|%.8g\n", v, v)
        let cases: &[(f64, &str, &str)] = &[
            (0.5, "0.5", "0.5"),
            (-1.0, "-1", "-1"),
            (0.123456789012345, "0.123456789012345", "0.12345679"),
            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
            (0.42446200000000001, "0.42446200000000001", "0.424462"),
            (0.0, "0", "0"),
            (1e-5, "1.0000000000000001e-05", "1e-05"),
            (1e-4, "0.0001", "0.0001"),
            (1e20, "1e+20", "1e+20"),
            (-0.25, "-0.25", "-0.25"),
            (0.75, "0.75", "0.75"),
            (0.708333, "0.70833299999999999", "0.708333"),
            (1.0, "1", "1"),
        ];
        for &(v, expected_17g, expected_8g) in cases {
            let got_17 = format!("{}", fmt_17g(v));
            let got_8 = format!("{}", fmt_8g(v));
            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
        }
    }

    /// A hand-built RBF model survives save -> load (header fields exactly,
    /// reals within 1e-10).
    #[test]
    // The literals below carry more digits than an f64 holds; clippy flags that.
    #[allow(clippy::excessive_precision)]
    fn model_roundtrip() {
        // Create a minimal model and verify save → load roundtrip
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Rbf,
                gamma: 0.5,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![
                vec![
                    SvmNode {
                        index: 1,
                        value: 0.5,
                    },
                    SvmNode {
                        index: 3,
                        value: -1.0,
                    },
                ],
                vec![
                    SvmNode {
                        index: 1,
                        value: -0.25,
                    },
                    SvmNode {
                        index: 2,
                        value: 0.75,
                    },
                ],
            ],
            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
            rho: vec![0.42446200000000001],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();

        let loaded = load_model_from_reader(&buf[..]).unwrap();

        assert_eq!(loaded.nr_class, model.nr_class);
        assert_eq!(loaded.param.svm_type, model.param.svm_type);
        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
        assert_eq!(loaded.sv.len(), model.sv.len());
        assert_eq!(loaded.label, model.label);
        assert_eq!(loaded.n_sv, model.n_sv);
        assert_eq!(loaded.rho.len(), model.rho.len());
        // Check rho within tolerance (roundtrip through text)
        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
        }
        // Check sv_coef within tolerance
        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
            for (a, b) in row_a.iter().zip(row_b.iter()) {
                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
            }
        }
    }

    /// Header counts above MAX_NR_CLASS / MAX_TOTAL_SV are rejected.
    #[test]
    fn parse_error_excessive_counts() {
        let input =
            b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
        let result = load_model_from_reader(&input[..]);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("nr_class exceeds limit"));

        let input =
            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
        let result = load_model_from_reader(&input[..]);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("total_sv exceeds limit"));
    }

    /// Feature indices above MAX_FEATURE_INDEX are rejected by both loaders.
    #[test]
    fn parse_error_excessive_feature_index() {
        // Problem file
        let input = b"1 10000001:1\n";
        let result = load_problem_from_reader(&input[..]);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));

        // Model file
        let input = b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nSV\n0.1 10000001:1\n";
        let result = load_model_from_reader(&input[..]);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
    }

    /// An unrecognized header keyword is an error.
    #[test]
    fn parse_error_unknown_model_keyword() {
        let input = b"bad_key value\n";
        let result = load_model_from_reader(&input[..]);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("unknown keyword"));
    }

    /// `svm_type` with no value, or with an unknown name, is an error.
    #[test]
    fn parse_error_missing_or_unknown_model_values() {
        let missing = b"svm_type\n";
        let err = load_model_from_reader(&missing[..]).unwrap_err();
        assert!(format!("{}", err).contains("missing svm_type value"));

        let unknown = b"svm_type unknown_type\n";
        let err = load_model_from_reader(&unknown[..]).unwrap_err();
        assert!(format!("{}", err).contains("unknown svm_type"));
    }

    /// A non-numeric `nr_sv` entry is an error.
    #[test]
    fn parse_error_invalid_nr_sv_entry() {
        let input = b"svm_type c_svc\n\
kernel_type linear\n\
nr_class 2\n\
total_sv 1\n\
rho 0\n\
nr_sv a 1\n\
SV\n\
0.1 1:0.5\n";
        let err = load_model_from_reader(&input[..]).unwrap_err();
        assert!(format!("{}", err).contains("invalid nr_sv value"));
    }

    /// Malformed SV lines: a feature token where a coefficient is expected,
    /// and a feature token without `:`.
    #[test]
    fn parse_error_in_sv_section_tokens() {
        // "1:0.5" is read as the sv_coef token and fails to parse as f64.
        let missing_coef = b"svm_type c_svc\n\
kernel_type linear\n\
nr_class 2\n\
total_sv 1\n\
rho 0\n\
SV\n\
1:0.5\n";
        let err = load_model_from_reader(&missing_coef[..]).unwrap_err();
        assert!(format!("{}", err).contains("invalid sv_coef"));

        let bad_feature = b"svm_type c_svc\n\
kernel_type linear\n\
nr_class 2\n\
total_sv 1\n\
rho 0\n\
SV\n\
0.1 bad\n";
        let err = load_model_from_reader(&bad_feature[..]).unwrap_err();
        assert!(format!("{}", err).contains("expected index:value"));
    }

    /// Input ending before `SV`, or before `total_sv` SV lines, is an error.
    #[test]
    fn parse_error_unexpected_eof_in_header_and_sv_section() {
        let eof_header = b"svm_type c_svc\n";
        let err = load_model_from_reader(&eof_header[..]).unwrap_err();
        assert!(format!("{}", err).contains("unexpected end of file in header"));

        // total_sv says 2 but only one SV line follows.
        let eof_sv = b"svm_type c_svc\n\
kernel_type linear\n\
nr_class 2\n\
total_sv 2\n\
rho 0\n\
SV\n\
0.1 1:0.5\n";
        let err = load_model_from_reader(&eof_sv[..]).unwrap_err();
        assert!(format!("{}", err).contains("unexpected end of file in SV section"));
    }

    /// For a precomputed kernel, SV lines are written as `0:<integer value>`.
    #[test]
    fn save_precomputed_model_writes_zero_index() {
        let model = SvmModel {
            param: SvmParameter {
                svm_type: SvmType::CSvc,
                kernel_type: KernelType::Precomputed,
                ..Default::default()
            },
            nr_class: 2,
            sv: vec![vec![SvmNode {
                index: 0,
                value: 7.0,
            }]],
            sv_coef: vec![vec![0.25]],
            rho: vec![0.0],
            prob_a: vec![],
            prob_b: vec![],
            prob_density_marks: vec![],
            sv_indices: vec![],
            label: vec![1, -1],
            n_sv: vec![1, 0],
        };

        let mut buf = Vec::new();
        save_model_to_writer(&mut buf, &model).unwrap();
        let out = String::from_utf8(buf).unwrap();
        assert!(out.contains("kernel_type precomputed"));
        assert!(out.contains("0:7"));
    }
}