Skip to main content

libsvm_rs/
io.rs

1//! I/O routines for LIBSVM problem and model files.
2//!
3//! The file formats match the original LIBSVM text formats for cross-tool
4//! interoperability, but the loaders are written for untrusted input by
5//! default. They enforce [`LoadOptions`] caps, reject embedded NUL bytes,
6//! reject non-ascending feature indices, and return structured [`SvmError`]
7//! values for malformed input within those caps.
8//!
9//! Model loading performs additional structural checks before allocating support
10//! vector storage. Header fields such as `nr_class`, `total_sv`, `rho`, `label`,
11//! `nr_sv`, `probA`, `probB`, and one-class probability-density marks must be
12//! shape-consistent when present. Precomputed-kernel support-vector rows must
13//! begin with `0:sample_serial_number`.
14//!
15//! These checks validate text structure and resource bounds. They do not prove
16//! that a model is statistically meaningful, was trained from a particular
17//! dataset, or is cryptographically authentic.
18
19use std::io::{BufRead, Write};
20use std::path::Path;
21
22use crate::error::SvmError;
23use crate::types::*;
24use crate::util::parse_feature_index;
25use crate::util::MAX_FEATURE_INDEX;
26
// ─── C-compatible %g formatting ─────────────────────────────────────
//
// C's printf `%.Pg` format strips trailing zeros and picks fixed vs.
// scientific notation based on the exponent. Rust has no built-in
// equivalent, so we replicate the POSIX specification, where X is the
// decimal exponent that `%.{P-1}e` would produce (i.e. *after* rounding
// to P significant digits):
//   - Use scientific notation if X < -4 or X >= P
//   - Otherwise use fixed notation with P - (X + 1) fractional digits
//   - Strip trailing zeros from the fractional part (and a trailing
//     decimal point); zeros in the integer part are significant and kept
35
36use std::fmt;
37
/// Formats an `f64` like C's `printf("%.Pg", value)` for a precision `P`.
///
/// Used by the model writer so saved files match the original LIBSVM text
/// output (`%.17g` for coefficients and parameters, `%.8g` for feature
/// values).
struct Gfmt {
    value: f64,
    precision: usize,
}

impl Gfmt {
    /// Wrap `value` for `%g`-style display with `precision` significant
    /// digits. As in C, a precision of 0 is treated as 1.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        // C treats a `%g` precision of 0 as 1.
        let p = self.precision.max(1);

        // Non-finite values use C's lowercase spellings.
        if v.is_nan() {
            return f.write_str("nan");
        }
        if v.is_infinite() {
            return f.write_str(if v.is_sign_negative() { "-inf" } else { "inf" });
        }
        if v == 0.0 {
            // Preserve the sign of -0.0, as C does.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // POSIX chooses the style from the exponent X that `%e` with
        // precision P-1 would produce, i.e. *after* rounding to P significant
        // digits (999999.6 with P = 6 rounds to 1.00000e+06, so X = 6).
        // Reading X from the formatted string also avoids `log10` rounding
        // error for values just below a power of ten.
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exp) = match sci.split_once('e') {
            Some((m, e)) => (m, e.parse::<i32>().unwrap_or(0)),
            // Rust's `{:e}` always emits an 'e'; fall back defensively.
            None => return f.write_str(&sci),
        };

        if exp < -4 || (exp >= 0 && exp as usize >= p) {
            // Scientific: C writes the exponent sign and pads the exponent
            // to at least two digits (e-05, e+06).
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                trim_fraction_zeros(mantissa),
                sign,
                exp.unsigned_abs()
            )
        } else {
            // Fixed: P - (X + 1) digits after the decimal point, X in [-4, P).
            let decimals = if exp >= 0 {
                p - (exp as usize + 1)
            } else {
                p + (-1 - exp) as usize
            };
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(trim_fraction_zeros(&fixed))
        }
    }
}

/// Remove trailing zeros after the decimal point, then a dangling `.`.
///
/// Strings without a decimal point are returned unchanged: trailing zeros in
/// the integer part (e.g. `"100000"`) are significant and must be kept.
fn trim_fraction_zeros(s: &str) -> &str {
    if s.contains('.') {
        s.trim_end_matches('0').trim_end_matches('.')
    } else {
        s
    }
}

/// Format like C's `%.17g` (enough significant digits to round-trip an `f64`).
fn fmt_17g(v: f64) -> Gfmt {
    Gfmt::new(v, 17)
}

/// Format like C's `%.8g`.
fn fmt_8g(v: f64) -> Gfmt {
    Gfmt::new(v, 8)
}

/// Format a float like C's `%g` (6 significant digits).
pub fn format_g(v: f64) -> String {
    Gfmt::new(v, 6).to_string()
}

/// Format a float like C's `%.17g` (17 significant digits).
pub fn format_17g(v: f64) -> String {
    Gfmt::new(v, 17).to_string()
}
122
123// ─── String tables matching original LIBSVM ──────────────────────────
124
/// `svm_type` keywords as written in model files, indexed by the `SvmType`
/// discriminant in `svm_type_to_str`.
const SVM_TYPE_TABLE: &[&str] = &["c_svc", "nu_svc", "one_class", "epsilon_svr", "nu_svr"];
/// `kernel_type` keywords as written in model files, indexed by the
/// `KernelType` discriminant in `kernel_type_to_str`.
const KERNEL_TYPE_TABLE: &[&str] = &["linear", "polynomial", "rbf", "sigmoid", "precomputed"];
127
128fn svm_type_to_str(t: SvmType) -> &'static str {
129    SVM_TYPE_TABLE[t as usize]
130}
131
132fn kernel_type_to_str(t: KernelType) -> &'static str {
133    KERNEL_TYPE_TABLE[t as usize]
134}
135
136fn str_to_svm_type(s: &str) -> Option<SvmType> {
137    match s {
138        "c_svc" => Some(SvmType::CSvc),
139        "nu_svc" => Some(SvmType::NuSvc),
140        "one_class" => Some(SvmType::OneClass),
141        "epsilon_svr" => Some(SvmType::EpsilonSvr),
142        "nu_svr" => Some(SvmType::NuSvr),
143        _ => None,
144    }
145}
146
147fn str_to_kernel_type(s: &str) -> Option<KernelType> {
148    match s {
149        "linear" => Some(KernelType::Linear),
150        "polynomial" => Some(KernelType::Polynomial),
151        "rbf" => Some(KernelType::Rbf),
152        "sigmoid" => Some(KernelType::Sigmoid),
153        "precomputed" => Some(KernelType::Precomputed),
154        _ => None,
155    }
156}
157
158// ─── Load options ────────────────────────────────────────────────────
159
/// Resource caps applied while reading LIBSVM problem and model files.
///
/// LIBSVM text formats are linear in file size, but individual fields (e.g.
/// `total_sv`, feature indices, per-line token counts) can be used by a
/// malicious file to trigger disproportionate allocation or CPU work. The
/// [`Default`] impl returns defaults tuned for **untrusted input**:
///
/// | Field               | Default        |
/// |---------------------|----------------|
/// | `max_bytes`         | 64 MiB         |
/// | `max_line_len`      | 1 MiB          |
/// | `max_sv`            | 10 000 000     |
/// | `max_nr_class`      | 65 535         |
/// | `max_feature_index` | 10 000 000     |
///
/// If you know the input is trusted (produced locally, read from an
/// attestation-protected location, etc.) and you need to handle inputs
/// larger than the defaults, call [`LoadOptions::trusted_input`] for an
/// "unlimited" profile, or override specific fields:
///
/// ```ignore
/// use libsvm_rs::io::{load_problem_from_reader_with_options, LoadOptions};
///
/// let opts = LoadOptions {
///     max_bytes: 1 << 30, // 1 GiB for a trusted bulk file
///     ..LoadOptions::default()
/// };
/// let problem = load_problem_from_reader_with_options(reader, &opts)?;
/// ```
///
/// Exceeding any cap returns an [`SvmError::ParseError`] (problem files) or
/// [`SvmError::ModelFormatError`] (model files) with a message identifying
/// the field that tripped the limit.
///
/// The default profile is intended to be panic-free for malformed text input
/// within these caps. It is still a parser contract, not a semantic model
/// audit: it does not verify that a loaded model came from a specific training
/// set or that its predictions are appropriate for a deployment.
///
/// Options compare by value (`PartialEq`/`Eq`), which makes it easy to check
/// whether a caller is using one of the built-in profiles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LoadOptions {
    /// Maximum total number of bytes read from the source.
    pub max_bytes: u64,
    /// Maximum number of bytes in a single line (excluding the terminator).
    pub max_line_len: usize,
    /// Maximum value accepted for the `total_sv` header field in model files.
    ///
    /// The model loader clamps this to its module-level hard limit
    /// (10 000 000), so larger values have no further effect.
    pub max_sv: usize,
    /// Maximum value accepted for the `nr_class` header field in model files.
    ///
    /// The model loader clamps this to its module-level hard limit (65 535),
    /// so larger values have no further effect.
    pub max_nr_class: usize,
    /// Maximum feature index accepted in problem or model feature lists.
    pub max_feature_index: i32,
}
211
212impl Default for LoadOptions {
213    fn default() -> Self {
214        Self {
215            max_bytes: 64 * 1024 * 1024,
216            max_line_len: 1024 * 1024,
217            max_sv: MAX_TOTAL_SV,
218            max_nr_class: MAX_NR_CLASS,
219            max_feature_index: MAX_FEATURE_INDEX,
220        }
221    }
222}
223
224impl LoadOptions {
225    /// Options with all caps set to the type maximum.
226    ///
227    /// Use **only** for input that is fully trusted. Operating with
228    /// unlimited caps removes the first line of defense against malformed
229    /// or adversarial model / problem files. Module-level hard caps still
230    /// apply to `nr_class` and `total_sv`.
231    pub fn trusted_input() -> Self {
232        Self {
233            max_bytes: u64::MAX,
234            max_line_len: usize::MAX,
235            max_sv: usize::MAX,
236            max_nr_class: usize::MAX,
237            max_feature_index: i32::MAX,
238        }
239    }
240}
241
/// Read one line from `reader`, enforcing byte and line-length caps.
///
/// Returns `Ok(None)` on clean EOF; otherwise returns the line including its
/// terminator (`"\n"` or `"\r\n"`; a final line may have none). Updates
/// `bytes_read` by the number of bytes actually consumed (including the line
/// terminator).
///
/// Uses [`BufRead::fill_buf`] / [`BufRead::consume`] to pull at most a
/// buffer's worth at a time and check both caps before committing bytes
/// to the output buffer. A pathological unbounded line cannot grow the
/// output buffer past `max_line_len` plus a few bytes, because the error
/// fires before the next slice copy would push past the cap.
///
/// # Errors
///
/// Returns a [`std::io::ErrorKind::InvalidData`] error if:
/// - the cumulative byte count would exceed `max_bytes`;
/// - the line content (excluding its `"\n"` / `"\r\n"` terminator) exceeds
///   `max_line_len`;
/// - a NUL byte is encountered;
/// - the line is not valid UTF-8.
///
/// Other errors from `reader` are propagated unchanged, except
/// `Interrupted`, which is retried as [`BufRead::read_until`] does.
fn read_line_capped(
    reader: &mut dyn BufRead,
    bytes_read: &mut u64,
    max_bytes: u64,
    max_line_len: usize,
) -> std::io::Result<Option<String>> {
    use std::io::{Error, ErrorKind};

    fn line_too_long(max_line_len: usize) -> Error {
        Error::new(
            ErrorKind::InvalidData,
            format!("line length exceeds max_line_len limit ({})", max_line_len),
        )
    }

    // While streaming we cannot yet tell whether the last content byte is the
    // `\r` of a CRLF terminator, so allow one byte of slack here. The exact
    // check runs below once the whole line (and its terminator) is known.
    let streaming_cap: u64 = (max_line_len as u64).saturating_add(1);

    let mut buf: Vec<u8> = Vec::new();

    loop {
        let available = match reader.fill_buf() {
            Ok(chunk) => chunk,
            // Match `BufRead::read_until`: a signal-interrupted read is retried.
            Err(e) if e.kind() == ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        if available.is_empty() {
            break; // EOF
        }

        let newline_pos = available.iter().position(|&b| b == b'\n');
        let take_n = newline_pos.map_or(available.len(), |pos| pos + 1);
        let ends_with_newline = newline_pos.is_some();

        // Byte cap.
        let new_bytes_read = bytes_read.saturating_add(take_n as u64);
        if new_bytes_read > max_bytes {
            return Err(Error::new(
                ErrorKind::InvalidData,
                format!("input exceeds max_bytes limit ({})", max_bytes),
            ));
        }

        // Streaming per-line cap on content bytes (the trailing `\n`, if any,
        // is not content).
        let prospective_len = (buf.len() as u64).saturating_add(take_n as u64);
        let content_bytes = prospective_len - u64::from(ends_with_newline);
        if content_bytes > streaming_cap {
            return Err(line_too_long(max_line_len));
        }

        // NUL bytes have no legal use in LIBSVM text files.
        if available[..take_n].contains(&0) {
            return Err(Error::new(
                ErrorKind::InvalidData,
                "unexpected NUL byte in text input".to_string(),
            ));
        }

        buf.extend_from_slice(&available[..take_n]);
        reader.consume(take_n);
        *bytes_read = new_bytes_read;

        if ends_with_newline {
            break;
        }
    }

    // A line that ended in a newline always holds at least that byte, so an
    // empty buffer means nothing was read before EOF.
    if buf.is_empty() {
        return Ok(None);
    }

    // Exact per-line cap: strip the terminator and compare what remains.
    let mut content_len = buf.len();
    if buf.ends_with(b"\r\n") {
        content_len -= 2;
    } else if buf.ends_with(b"\n") {
        content_len -= 1;
    }
    if content_len > max_line_len {
        return Err(line_too_long(max_line_len));
    }

    String::from_utf8(buf)
        .map(Some)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e.to_string()))
}
330
331// ─── Problem file I/O ────────────────────────────────────────────────
332
333/// Load an SVM problem from a file in LIBSVM sparse format.
334///
335/// Format: `<label> <index1>:<value1> <index2>:<value2> ...`
336///
337/// Uses [`LoadOptions::default`] — appropriate for untrusted input. For
338/// trusted bulk files exceeding the defaults, use
339/// [`load_problem_from_reader_with_options`] with a custom [`LoadOptions`].
340///
341/// Validates:
342///
343/// - total file size and per-line length,
344/// - embedded NUL byte absence,
345/// - parseable labels and feature values,
346/// - `index:value` token shape,
347/// - ascending feature indices,
348/// - [`LoadOptions::max_feature_index`].
349///
350/// This loader does not validate statistical quality, feature normalization, or
351/// whether labels are appropriate for a particular SVM formulation. Malformed
352/// text input within the configured caps is returned as [`SvmError`] rather
353/// than panicking.
354///
355/// ### Complexity
356///
357/// Linear in file size for parsing and `O(n · d)` memory where `n` is the
358/// number of instances and `d` is the average number of non-zero features
359/// per instance. No per-row allocation is driven by an untrusted header.
360pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
361    let file = std::fs::File::open(path)?;
362    let reader = std::io::BufReader::new(file);
363    load_problem_from_reader(reader)
364}
365
366/// Load an SVM problem from any buffered reader.
367///
368/// Uses [`LoadOptions::default`]. See [`load_problem`] for the validation
369/// contract and non-goals.
370pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
371    load_problem_from_reader_with_options(reader, &LoadOptions::default())
372}
373
374/// Load an SVM problem from any buffered reader, with explicit resource caps.
375///
376/// See [`LoadOptions`] for the meaning of each cap and for defaults tuned for
377/// untrusted input. This function has the same validation contract as
378/// [`load_problem`], with caller-supplied caps.
379pub fn load_problem_from_reader_with_options(
380    mut reader: impl BufRead,
381    options: &LoadOptions,
382) -> Result<SvmProblem, SvmError> {
383    let mut labels = Vec::new();
384    let mut instances = Vec::new();
385    let mut bytes_read: u64 = 0;
386    let mut line_idx: usize = 0;
387
388    while let Some(raw) = read_line_capped(
389        &mut reader,
390        &mut bytes_read,
391        options.max_bytes,
392        options.max_line_len,
393    )? {
394        let line_num = line_idx + 1;
395        line_idx += 1;
396        let line = raw.trim();
397        if line.is_empty() {
398            continue;
399        }
400
401        let mut parts = line.split_whitespace();
402
403        // Parse label
404        let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
405            line: line_num,
406            message: "missing label".into(),
407        })?;
408        let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
409            line: line_num,
410            message: format!("invalid label: {}", label_str),
411        })?;
412
413        // Parse features (must be in ascending index order).
414        let mut nodes = Vec::new();
415        let mut prev_index: i32 = 0;
416        for token in parts {
417            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
418                line: line_num,
419                message: format!("expected index:value, got: {}", token),
420            })?;
421            let index: i32 =
422                parse_feature_index_problem_line(line_num, idx_str, options.max_feature_index)?;
423
424            if !nodes.is_empty() && index <= prev_index {
425                return Err(SvmError::ParseError {
426                    line: line_num,
427                    message: format!(
428                        "feature indices must be ascending: {} follows {}",
429                        index, prev_index
430                    ),
431                });
432            }
433            let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
434                line: line_num,
435                message: format!("invalid value: {}", val_str),
436            })?;
437            prev_index = index;
438            nodes.push(SvmNode { index, value });
439        }
440
441        labels.push(label);
442        instances.push(nodes);
443    }
444
445    Ok(SvmProblem { labels, instances })
446}
447
448// ─── Model file I/O ──────────────────────────────────────────────────
449
/// Hard upper bound on `nr_class` accepted by the model loader. The effective
/// cap is the smaller of this and [`LoadOptions::max_nr_class`]; it is also
/// the default for that field.
const MAX_NR_CLASS: usize = 65535;
/// Hard upper bound on `total_sv` accepted by the model loader. The effective
/// cap is the smaller of this and [`LoadOptions::max_sv`]; it is also the
/// default for that field.
const MAX_TOTAL_SV: usize = 10_000_000;
452
453/// Save an SVM model to a file in the original LIBSVM format.
454pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
455    let file = std::fs::File::create(path)?;
456    let writer = std::io::BufWriter::new(file);
457    save_model_to_writer(writer, model)
458}
459
460/// Save an SVM model to any writer.
461pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
462    let param = &model.param;
463
464    writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
465    writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
466
467    if param.kernel_type == KernelType::Polynomial {
468        writeln!(w, "degree {}", param.degree)?;
469    }
470    if matches!(
471        param.kernel_type,
472        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
473    ) {
474        writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
475    }
476    if matches!(
477        param.kernel_type,
478        KernelType::Polynomial | KernelType::Sigmoid
479    ) {
480        writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
481    }
482
483    let nr_class = model.nr_class;
484    writeln!(w, "nr_class {}", nr_class)?;
485    writeln!(w, "total_sv {}", model.sv.len())?;
486
487    // rho
488    write!(w, "rho")?;
489    for r in &model.rho {
490        write!(w, " {}", fmt_17g(*r))?;
491    }
492    writeln!(w)?;
493
494    // label (classification only)
495    if !model.label.is_empty() {
496        write!(w, "label")?;
497        for l in &model.label {
498            write!(w, " {}", l)?;
499        }
500        writeln!(w)?;
501    }
502
503    // probA
504    if !model.prob_a.is_empty() {
505        write!(w, "probA")?;
506        for v in &model.prob_a {
507            write!(w, " {}", fmt_17g(*v))?;
508        }
509        writeln!(w)?;
510    }
511
512    // probB
513    if !model.prob_b.is_empty() {
514        write!(w, "probB")?;
515        for v in &model.prob_b {
516            write!(w, " {}", fmt_17g(*v))?;
517        }
518        writeln!(w)?;
519    }
520
521    // prob_density_marks (one-class)
522    if !model.prob_density_marks.is_empty() {
523        write!(w, "prob_density_marks")?;
524        for v in &model.prob_density_marks {
525            write!(w, " {}", fmt_17g(*v))?;
526        }
527        writeln!(w)?;
528    }
529
530    // nr_sv
531    if !model.n_sv.is_empty() {
532        write!(w, "nr_sv")?;
533        for n in &model.n_sv {
534            write!(w, " {}", n)?;
535        }
536        writeln!(w)?;
537    }
538
539    // SV section
540    writeln!(w, "SV")?;
541    let num_sv = model.sv.len();
542    let num_coef_rows = model.sv_coef.len(); // nr_class - 1
543
544    for i in 0..num_sv {
545        // sv_coef columns for this SV: %.17g
546        for j in 0..num_coef_rows {
547            write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
548        }
549        // sparse features: %.8g
550        if model.param.kernel_type == KernelType::Precomputed {
551            if let Some(node) = model.sv[i].first() {
552                write!(w, "0:{} ", node.value as i32)?;
553            }
554        } else {
555            for node in &model.sv[i] {
556                write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
557            }
558        }
559        writeln!(w)?;
560    }
561
562    Ok(())
563}
564
565/// Load an SVM model from a file in the original LIBSVM format.
566///
567/// Uses [`LoadOptions::default`] — appropriate for untrusted input.
568///
569/// Validates:
570///
571/// - total file size and per-line length,
572/// - embedded NUL byte absence,
573/// - known `svm_type` and `kernel_type` values,
574/// - `nr_class` and `total_sv` caps,
575/// - `nr_class >= 2`,
576/// - `rho` length for classification, one-class, and regression models,
577/// - `label` and `nr_sv` length when present,
578/// - `sum(nr_sv) == total_sv` when `nr_sv` is present,
579/// - `probA` / `probB` decision-function counts when present,
580/// - one-class-only `prob_density_marks`,
581/// - support-vector feature token shape, ascending feature indices, and
582///   [`LoadOptions::max_feature_index`],
583/// - precomputed-kernel support-vector rows starting with
584///   `0:sample_serial_number`.
585///
586/// This loader does not prove model provenance, semantic correctness relative
587/// to a training set, or suitability for a deployment. Malformed text input
588/// within the configured caps is returned as [`SvmError`] rather than panicking.
589///
590/// ### Complexity
591///
592/// The header parse is `O(nr_class)` in the worst case (due to `rho` /
593/// `label` / `nr_sv` array reads). The SV section is linear in the file
594/// size. Downstream consumers of the returned [`SvmModel`] — notably
595/// `group_classes` and probability estimation — are `O(k²)` on `k =
596/// nr_class`, bounded by [`LoadOptions::max_nr_class`].
597pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
598    let file = std::fs::File::open(path)?;
599    let reader = std::io::BufReader::new(file);
600    load_model_from_reader(reader)
601}
602
603/// Load an SVM model from any buffered reader.
604///
605/// Uses [`LoadOptions::default`]. See [`load_model`] for the validation
606/// contract and non-goals.
607pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
608    load_model_from_reader_with_options(reader, &LoadOptions::default())
609}
610
611/// Load an SVM model from any buffered reader, with explicit resource caps.
612///
613/// See [`LoadOptions`] for the meaning of each cap and for defaults tuned for
614/// untrusted input. This function has the same validation contract as
615/// [`load_model`], with caller-supplied caps.
616pub fn load_model_from_reader_with_options(
617    mut reader: impl BufRead,
618    options: &LoadOptions,
619) -> Result<SvmModel, SvmError> {
620    let mut bytes_read: u64 = 0;
621
622    // The `nr_class` / `total_sv` caps are the intersection of the
623    // module-level hard limits (`MAX_NR_CLASS`, `MAX_TOTAL_SV`) and the
624    // per-call `LoadOptions` overrides. A caller using
625    // `LoadOptions::trusted_input()` relaxes only down to the hard caps;
626    // it cannot exceed them. This is defense in depth.
627    let nr_class_cap = options.max_nr_class.min(MAX_NR_CLASS);
628    let total_sv_cap = options.max_sv.min(MAX_TOTAL_SV);
629
630    // Defaults
631    let mut param = SvmParameter::default();
632    let mut nr_class: usize = 0;
633    let mut total_sv: usize = 0;
634    let mut rho = Vec::new();
635    let mut label = Vec::new();
636    let mut prob_a = Vec::new();
637    let mut prob_b = Vec::new();
638    let mut prob_density_marks = Vec::new();
639    let mut n_sv = Vec::new();
640
641    // Read header.
642    let mut line_num: usize = 0;
643    loop {
644        let raw = read_line_capped(
645            &mut reader,
646            &mut bytes_read,
647            options.max_bytes,
648            options.max_line_len,
649        )
650        .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
651        .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))?;
652        line_num += 1;
653        let line = raw.trim().to_string();
654        if line.is_empty() {
655            continue;
656        }
657
658        let mut parts = line.split_whitespace();
659        let cmd = parts.next().ok_or_else(|| {
660            SvmError::ModelFormatError(format!("line {}: empty model header line", line_num))
661        })?;
662
663        match cmd {
664            "svm_type" => {
665                let val = parts.next().ok_or_else(|| {
666                    SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
667                })?;
668                param.svm_type = str_to_svm_type(val).ok_or_else(|| {
669                    SvmError::ModelFormatError(format!(
670                        "line {}: unknown svm_type: {}",
671                        line_num, val
672                    ))
673                })?;
674            }
675            "kernel_type" => {
676                let val = parts.next().ok_or_else(|| {
677                    SvmError::ModelFormatError(format!(
678                        "line {}: missing kernel_type value",
679                        line_num
680                    ))
681                })?;
682                param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
683                    SvmError::ModelFormatError(format!(
684                        "line {}: unknown kernel_type: {}",
685                        line_num, val
686                    ))
687                })?;
688            }
689            "degree" => {
690                param.degree = parse_single(&mut parts, line_num, "degree")?;
691            }
692            "gamma" => {
693                param.gamma = parse_single(&mut parts, line_num, "gamma")?;
694            }
695            "coef0" => {
696                param.coef0 = parse_single(&mut parts, line_num, "coef0")?;
697            }
698            "nr_class" => {
699                nr_class = parse_single(&mut parts, line_num, "nr_class")?;
700                if nr_class > nr_class_cap {
701                    return Err(SvmError::ModelFormatError(format!(
702                        "line {}: nr_class exceeds limit ({})",
703                        line_num, nr_class_cap
704                    )));
705                }
706            }
707            "total_sv" => {
708                total_sv = parse_single(&mut parts, line_num, "total_sv")?;
709                if total_sv > total_sv_cap {
710                    return Err(SvmError::ModelFormatError(format!(
711                        "line {}: total_sv exceeds limit ({})",
712                        line_num, total_sv_cap
713                    )));
714                }
715            }
716            "rho" => {
717                rho = parse_multiple(&mut parts, line_num, "rho")?;
718            }
719            "label" => {
720                label = parse_multiple(&mut parts, line_num, "label")?;
721            }
722            "probA" => {
723                prob_a = parse_multiple(&mut parts, line_num, "probA")?;
724            }
725            "probB" => {
726                prob_b = parse_multiple(&mut parts, line_num, "probB")?;
727            }
728            "prob_density_marks" => {
729                prob_density_marks = parse_multiple(&mut parts, line_num, "prob_density_marks")?;
730            }
731            "nr_sv" => {
732                n_sv = parts
733                    .map(|s| {
734                        s.parse::<usize>().map_err(|_| {
735                            SvmError::ModelFormatError(format!(
736                                "line {}: invalid nr_sv value: {}",
737                                line_num, s
738                            ))
739                        })
740                    })
741                    .collect::<Result<Vec<_>, _>>()?;
742            }
743            "SV" => break,
744            _ => {
745                return Err(SvmError::ModelFormatError(format!(
746                    "line {}: unknown keyword: {}",
747                    line_num, cmd
748                )));
749            }
750        }
751    }
752
753    // Cross-consistency checks on the header.
754    //
755    // These run before any per-SV allocation so malformed files are rejected
756    // early, and so downstream code can rely on structural invariants (e.g.
757    // `rho.len() == nr_class * (nr_class - 1) / 2` for multiclass).
758    validate_model_header(
759        param.svm_type,
760        nr_class,
761        total_sv,
762        &rho,
763        &label,
764        &prob_a,
765        &prob_b,
766        &prob_density_marks,
767        &n_sv,
768    )?;
769
770    // Read SV section.
771    //
772    // SECURITY: we do NOT preallocate with `total_sv` capacity. A malicious
773    // file could claim up to `MAX_TOTAL_SV` support vectors in its header,
774    // which would trigger terabyte-scale reservations before any real data is
775    // read. Amortized `Vec` growth caps peak memory at the actually-parsed
776    // payload.
777    let m = if nr_class > 1 { nr_class - 1 } else { 1 };
778    let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::new()).collect();
779    let mut sv: Vec<Vec<SvmNode>> = Vec::new();
780
781    while sv.len() < total_sv {
782        let raw = read_line_capped(
783            &mut reader,
784            &mut bytes_read,
785            options.max_bytes,
786            options.max_line_len,
787        )
788        .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
789        .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))?;
790        line_num += 1;
791        let line = raw.trim();
792        if line.is_empty() {
793            // Skip blank lines without consuming an SV slot — the header's
794            // `total_sv` must match the number of SV rows we actually collect.
795            continue;
796        }
797
798        let mut parts = line.split_whitespace();
799
800        // First m tokens are sv_coef values
801        for (k, coef_row) in sv_coef.iter_mut().enumerate() {
802            let val_str = parts.next().ok_or_else(|| {
803                SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
804            })?;
805            let val: f64 = val_str.parse().map_err(|_| {
806                SvmError::ModelFormatError(format!(
807                    "line {}: invalid sv_coef: {}",
808                    line_num, val_str
809                ))
810            })?;
811            coef_row.push(val);
812        }
813
814        // Remaining tokens are index:value pairs (ascending index order, same
815        // invariant as the problem-file parser).
816        let mut nodes = Vec::new();
817        let mut prev_index: i32 = 0;
818        for token in parts {
819            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
820                SvmError::ModelFormatError(format!(
821                    "line {}: expected index:value, got: {}",
822                    line_num, token
823                ))
824            })?;
825            let index: i32 =
826                parse_feature_index_model_line(line_num, idx_str, options.max_feature_index)?;
827
828            if !nodes.is_empty() && index <= prev_index {
829                return Err(SvmError::ModelFormatError(format!(
830                    "line {}: feature indices must be ascending: {} follows {}",
831                    line_num, index, prev_index
832                )));
833            }
834
835            let value: f64 = val_str.parse().map_err(|_| {
836                SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
837            })?;
838            prev_index = index;
839            nodes.push(SvmNode { index, value });
840        }
841
842        if param.kernel_type == KernelType::Precomputed {
843            validate_precomputed_row(&nodes, line_num, "support vector")?;
844        }
845        sv.push(nodes);
846    }
847
848    Ok(SvmModel {
849        param,
850        nr_class,
851        sv,
852        sv_coef,
853        rho,
854        prob_a,
855        prob_b,
856        prob_density_marks,
857        sv_indices: Vec::new(), // not stored in model file
858        label,
859        n_sv,
860    })
861}
862
863fn validate_precomputed_row(
864    nodes: &[SvmNode],
865    line_num: usize,
866    context: &str,
867) -> Result<(), SvmError> {
868    let first = nodes.first().ok_or_else(|| {
869        SvmError::ModelFormatError(format!(
870            "line {}: precomputed kernel {} is missing 0:sample_serial_number",
871            line_num, context
872        ))
873    })?;
874
875    if first.index != 0
876        || !first.value.is_finite()
877        || first.value < 1.0
878        || first.value.fract() != 0.0
879    {
880        return Err(SvmError::ModelFormatError(format!(
881            "line {}: precomputed kernel {} must start with 0:sample_serial_number",
882            line_num, context
883        )));
884    }
885
886    Ok(())
887}
888
889// ─── Cross-consistency validation ────────────────────────────────────
890
891/// Validate model-header invariants before reading the SV section.
892///
893/// A malformed or adversarial model file can pass individual field parses
894/// and still describe a structurally impossible SVM. This gate rejects the
895/// mismatch early, before any allocation keyed on `total_sv`, so downstream
896/// code (prediction, probability estimation) can rely on the usual LIBSVM
897/// shape contracts:
898///
899/// * `rho.len() == k * (k - 1) / 2` for `k = nr_class` on classification,
900///   `rho.len() == 1` for one-class / regression (where `nr_class == 2`).
901/// * `label.len() == nr_class` and `n_sv.len() == nr_class` if supplied.
902/// * `sum(n_sv) == total_sv` if `n_sv` is supplied.
903/// * `prob_a` / `prob_b` (if supplied) match the expected decision-function
904///   count, and `prob_density_marks` only appears on one-class models.
905///
906/// Optional fields (e.g. `label`, `n_sv`, `probA`) are only validated when
907/// present, because minimal hand-written fixtures and some legacy writers
908/// omit them; the invariant "if present, must be consistent" is what matters
909/// for safety.
910#[allow(clippy::too_many_arguments)]
911fn validate_model_header(
912    svm_type: SvmType,
913    nr_class: usize,
914    total_sv: usize,
915    rho: &[f64],
916    label: &[i32],
917    prob_a: &[f64],
918    prob_b: &[f64],
919    prob_density_marks: &[f64],
920    n_sv: &[usize],
921) -> Result<(), SvmError> {
922    let is_classification = matches!(svm_type, SvmType::CSvc | SvmType::NuSvc);
923    let is_regression = matches!(svm_type, SvmType::EpsilonSvr | SvmType::NuSvr);
924    let is_one_class = matches!(svm_type, SvmType::OneClass);
925
926    // nr_class must be at least 2 under the LIBSVM convention (regression and
927    // one-class store nr_class=2 as well, because the one-vs-one scaffolding
928    // is reused). nr_class==0 or 1 would yield `m = nr_class - 1 = 0` or
929    // underflow-prone arithmetic elsewhere.
930    if nr_class < 2 {
931        return Err(SvmError::ModelFormatError(format!(
932            "nr_class must be >= 2, got {}",
933            nr_class
934        )));
935    }
936
937    // Expected rho length depends on svm_type.
938    let expected_rho = if is_classification {
939        nr_class * (nr_class - 1) / 2
940    } else {
941        1
942    };
943    if rho.len() != expected_rho {
944        return Err(SvmError::ModelFormatError(format!(
945            "rho has {} entries, expected {} for svm_type {}",
946            rho.len(),
947            expected_rho,
948            svm_type_to_str(svm_type)
949        )));
950    }
951
952    // label is mandatory shape on classification, absent on regression/one-class.
953    if !label.is_empty() {
954        if !is_classification {
955            return Err(SvmError::ModelFormatError(format!(
956                "label is only valid for classification, got {} entries on svm_type {}",
957                label.len(),
958                svm_type_to_str(svm_type)
959            )));
960        }
961        if label.len() != nr_class {
962            return Err(SvmError::ModelFormatError(format!(
963                "label has {} entries, expected nr_class ({})",
964                label.len(),
965                nr_class
966            )));
967        }
968    }
969
970    // n_sv: same shape rule; if present on classification, sum must equal total_sv.
971    if !n_sv.is_empty() {
972        if !is_classification {
973            return Err(SvmError::ModelFormatError(format!(
974                "nr_sv is only valid for classification, got {} entries on svm_type {}",
975                n_sv.len(),
976                svm_type_to_str(svm_type)
977            )));
978        }
979        if n_sv.len() != nr_class {
980            return Err(SvmError::ModelFormatError(format!(
981                "nr_sv has {} entries, expected nr_class ({})",
982                n_sv.len(),
983                nr_class
984            )));
985        }
986        // Use checked_add to prevent silent overflow on malicious huge values.
987        // MAX_TOTAL_SV bounds total_sv already; n_sv values are parsed as
988        // `usize` and otherwise unbounded until this sum-check.
989        let mut sum: usize = 0;
990        for &n in n_sv {
991            sum = sum.checked_add(n).ok_or_else(|| {
992                SvmError::ModelFormatError("nr_sv entries overflow usize when summed".into())
993            })?;
994        }
995        if sum != total_sv {
996            return Err(SvmError::ModelFormatError(format!(
997                "sum of nr_sv entries ({}) does not match total_sv ({})",
998                sum, total_sv
999            )));
1000        }
1001    }
1002
1003    // Probability arrays: must either be absent or length-match rho.
1004    if !prob_a.is_empty() && prob_a.len() != expected_rho {
1005        return Err(SvmError::ModelFormatError(format!(
1006            "probA has {} entries, expected {}",
1007            prob_a.len(),
1008            expected_rho
1009        )));
1010    }
1011    if !prob_b.is_empty() && prob_b.len() != expected_rho {
1012        return Err(SvmError::ModelFormatError(format!(
1013            "probB has {} entries, expected {}",
1014            prob_b.len(),
1015            expected_rho
1016        )));
1017    }
1018
1019    // prob_density_marks is only meaningful for one-class.
1020    if !prob_density_marks.is_empty() && !is_one_class {
1021        return Err(SvmError::ModelFormatError(format!(
1022            "prob_density_marks is only valid for one-class SVM, got {} entries on svm_type {}",
1023            prob_density_marks.len(),
1024            svm_type_to_str(svm_type)
1025        )));
1026    }
1027
1028    // Regression/one-class should not carry classification-only artifacts.
1029    // (Already caught above via `label` / `n_sv` branches; this assertion
1030    // keeps the intent self-documenting for future maintainers.)
1031    let _ = is_regression;
1032
1033    Ok(())
1034}
1035
1036// ─── Helper parsers ──────────────────────────────────────────────────
1037
1038fn parse_feature_index_problem_line(
1039    line_num: usize,
1040    idx_str: &str,
1041    max_feature_index: i32,
1042) -> Result<i32, SvmError> {
1043    parse_feature_index(idx_str, max_feature_index).map_err(|msg| SvmError::ParseError {
1044        line: line_num,
1045        message: msg,
1046    })
1047}
1048
1049fn parse_feature_index_model_line(
1050    line_num: usize,
1051    idx_str: &str,
1052    max_feature_index: i32,
1053) -> Result<i32, SvmError> {
1054    parse_feature_index(idx_str, max_feature_index)
1055        .map_err(|msg| SvmError::ModelFormatError(format!("line {}: {}", line_num, msg)))
1056}
1057
1058fn parse_single<T: std::str::FromStr>(
1059    parts: &mut std::str::SplitWhitespace<'_>,
1060    line_num: usize,
1061    field: &str,
1062) -> Result<T, SvmError> {
1063    let val_str = parts.next().ok_or_else(|| {
1064        SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
1065    })?;
1066    val_str.parse().map_err(|_| {
1067        SvmError::ModelFormatError(format!(
1068            "line {}: invalid {} value: {}",
1069            line_num, field, val_str
1070        ))
1071    })
1072}
1073
1074fn parse_multiple<T: std::str::FromStr>(
1075    parts: &mut std::str::SplitWhitespace<'_>,
1076    line_num: usize,
1077    field: &str,
1078) -> Result<Vec<T>, SvmError> {
1079    parts
1080        .map(|s| {
1081            s.parse::<T>().map_err(|_| {
1082                SvmError::ModelFormatError(format!(
1083                    "line {}: invalid {} value: {}",
1084                    line_num, field, s
1085                ))
1086            })
1087        })
1088        .collect()
1089}
1090
1091// ─── Tests ───────────────────────────────────────────────────────────
1092
1093#[cfg(test)]
1094mod tests {
1095    use super::*;
1096    use std::path::PathBuf;
1097
1098    fn data_dir() -> PathBuf {
1099        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1100            .join("..")
1101            .join("..")
1102            .join("data")
1103    }
1104
1105    #[test]
1106    fn parse_heart_scale() {
1107        let path = data_dir().join("heart_scale");
1108        let problem = load_problem(&path).unwrap();
1109        assert_eq!(problem.labels.len(), 270);
1110        assert_eq!(problem.instances.len(), 270);
1111        // First instance: +1 label, 12 features (index 11 is missing/sparse)
1112        assert_eq!(problem.labels[0], 1.0);
1113        assert_eq!(
1114            problem.instances[0][0],
1115            SvmNode {
1116                index: 1,
1117                value: 0.708333
1118            }
1119        );
1120        assert_eq!(problem.instances[0].len(), 12);
1121    }
1122
1123    #[test]
1124    fn parse_iris() {
1125        let path = data_dir().join("iris.scale");
1126        let problem = load_problem(&path).unwrap();
1127        assert_eq!(problem.labels.len(), 150);
1128        // 3 classes: 1, 2, 3
1129        let classes: std::collections::HashSet<i64> =
1130            problem.labels.iter().map(|&l| l as i64).collect();
1131        assert_eq!(classes.len(), 3);
1132    }
1133
1134    #[test]
1135    fn parse_housing() {
1136        let path = data_dir().join("housing_scale");
1137        let problem = load_problem(&path).unwrap();
1138        assert_eq!(problem.labels.len(), 506);
1139        // Regression: labels are continuous
1140        assert!((problem.labels[0] - 24.0).abs() < 1e-10);
1141    }
1142
1143    #[test]
1144    fn parse_empty_lines() {
1145        let input = b"+1 1:0.5\n\n-1 2:0.3\n";
1146        let problem = load_problem_from_reader(&input[..]).unwrap();
1147        assert_eq!(problem.labels.len(), 2);
1148    }
1149
1150    #[test]
1151    fn parse_error_unsorted_indices() {
1152        let input = b"+1 3:0.5 1:0.3\n";
1153        let result = load_problem_from_reader(&input[..]);
1154        assert!(result.is_err());
1155        let msg = format!("{}", result.unwrap_err());
1156        assert!(msg.contains("ascending"), "error: {}", msg);
1157    }
1158
1159    #[test]
1160    fn parse_error_duplicate_indices() {
1161        let input = b"+1 1:0.5 1:0.3\n";
1162        let result = load_problem_from_reader(&input[..]);
1163        assert!(result.is_err());
1164    }
1165
1166    #[test]
1167    fn parse_error_missing_colon() {
1168        let input = b"+1 1:0.5 bad_token\n";
1169        let result = load_problem_from_reader(&input[..]);
1170        assert!(result.is_err());
1171    }
1172
1173    #[test]
1174    #[allow(clippy::excessive_precision)]
1175    fn load_c_trained_model() {
1176        // Load a model produced by the original C LIBSVM svm-train
1177        let path = data_dir().join("heart_scale.model");
1178        let model = load_model(&path).unwrap();
1179        assert_eq!(model.nr_class, 2);
1180        assert_eq!(model.param.svm_type, SvmType::CSvc);
1181        assert_eq!(model.param.kernel_type, KernelType::Rbf);
1182        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
1183        assert_eq!(model.sv.len(), 132);
1184        assert_eq!(model.label, vec![1, -1]);
1185        assert_eq!(model.n_sv, vec![64, 68]);
1186        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
1187        // sv_coef should have 1 row (nr_class - 1) with 132 entries
1188        assert_eq!(model.sv_coef.len(), 1);
1189        assert_eq!(model.sv_coef[0].len(), 132);
1190    }
1191
1192    #[test]
1193    fn roundtrip_c_model() {
1194        // Load C model, save it back, and verify byte-exact match
1195        let path = data_dir().join("heart_scale.model");
1196        let original_bytes = std::fs::read_to_string(&path).unwrap();
1197        let model = load_model(&path).unwrap();
1198
1199        let mut buf = Vec::new();
1200        save_model_to_writer(&mut buf, &model).unwrap();
1201        let rust_output = String::from_utf8(buf).unwrap();
1202
1203        // Compare line by line for better diagnostics
1204        let orig_lines: Vec<&str> = original_bytes.lines().collect();
1205        let rust_lines: Vec<&str> = rust_output.lines().collect();
1206        assert_eq!(
1207            orig_lines.len(),
1208            rust_lines.len(),
1209            "line count mismatch: C={} Rust={}",
1210            orig_lines.len(),
1211            rust_lines.len()
1212        );
1213        for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
1214            assert_eq!(
1215                o,
1216                r,
1217                "line {} differs:\n  C:    {:?}\n  Rust: {:?}",
1218                i + 1,
1219                o,
1220                r
1221            );
1222        }
1223    }
1224
1225    #[test]
1226    #[allow(clippy::excessive_precision)]
1227    fn gfmt_matches_c_printf() {
1228        // Reference values from C's printf("%.17g|%.8g\n", v, v)
1229        let cases: &[(f64, &str, &str)] = &[
1230            (0.5, "0.5", "0.5"),
1231            (-1.0, "-1", "-1"),
1232            (0.123456789012345, "0.123456789012345", "0.12345679"),
1233            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
1234            (0.42446200000000001, "0.42446200000000001", "0.424462"),
1235            (0.0, "0", "0"),
1236            (1e-5, "1.0000000000000001e-05", "1e-05"),
1237            (1e-4, "0.0001", "0.0001"),
1238            (1e20, "1e+20", "1e+20"),
1239            (-0.25, "-0.25", "-0.25"),
1240            (0.75, "0.75", "0.75"),
1241            (0.708333, "0.70833299999999999", "0.708333"),
1242            (1.0, "1", "1"),
1243        ];
1244        for &(v, expected_17g, expected_8g) in cases {
1245            let got_17 = format!("{}", fmt_17g(v));
1246            let got_8 = format!("{}", fmt_8g(v));
1247            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
1248            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
1249        }
1250    }
1251
1252    #[test]
1253    #[allow(clippy::excessive_precision)]
1254    fn model_roundtrip() {
1255        // Create a minimal model and verify save → load roundtrip
1256        let model = SvmModel {
1257            param: SvmParameter {
1258                svm_type: SvmType::CSvc,
1259                kernel_type: KernelType::Rbf,
1260                gamma: 0.5,
1261                ..Default::default()
1262            },
1263            nr_class: 2,
1264            sv: vec![
1265                vec![
1266                    SvmNode {
1267                        index: 1,
1268                        value: 0.5,
1269                    },
1270                    SvmNode {
1271                        index: 3,
1272                        value: -1.0,
1273                    },
1274                ],
1275                vec![
1276                    SvmNode {
1277                        index: 1,
1278                        value: -0.25,
1279                    },
1280                    SvmNode {
1281                        index: 2,
1282                        value: 0.75,
1283                    },
1284                ],
1285            ],
1286            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
1287            rho: vec![0.42446200000000001],
1288            prob_a: vec![],
1289            prob_b: vec![],
1290            prob_density_marks: vec![],
1291            sv_indices: vec![],
1292            label: vec![1, -1],
1293            n_sv: vec![1, 1],
1294        };
1295
1296        let mut buf = Vec::new();
1297        save_model_to_writer(&mut buf, &model).unwrap();
1298
1299        let loaded = load_model_from_reader(&buf[..]).unwrap();
1300
1301        assert_eq!(loaded.nr_class, model.nr_class);
1302        assert_eq!(loaded.param.svm_type, model.param.svm_type);
1303        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
1304        assert_eq!(loaded.sv.len(), model.sv.len());
1305        assert_eq!(loaded.label, model.label);
1306        assert_eq!(loaded.n_sv, model.n_sv);
1307        assert_eq!(loaded.rho.len(), model.rho.len());
1308        // Check rho within tolerance (roundtrip through text)
1309        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
1310            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
1311        }
1312        // Check sv_coef within tolerance
1313        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
1314            for (a, b) in row_a.iter().zip(row_b.iter()) {
1315                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
1316            }
1317        }
1318    }
1319
1320    #[test]
1321    fn parse_error_excessive_counts() {
1322        let input =
1323            b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
1324        let result = load_model_from_reader(&input[..]);
1325        assert!(result.is_err());
1326        assert!(format!("{}", result.unwrap_err()).contains("nr_class exceeds limit"));
1327
1328        let input =
1329            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
1330        let result = load_model_from_reader(&input[..]);
1331        assert!(result.is_err());
1332        assert!(format!("{}", result.unwrap_err()).contains("total_sv exceeds limit"));
1333    }
1334
1335    #[test]
1336    fn parse_error_excessive_feature_index() {
1337        // Problem file
1338        let input = b"1 10000001:1\n";
1339        let result = load_problem_from_reader(&input[..]);
1340        assert!(result.is_err());
1341        assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1342
1343        // Model file
1344        let input = b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nSV\n0.1 10000001:1\n";
1345        let result = load_model_from_reader(&input[..]);
1346        assert!(result.is_err());
1347        assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1348    }
1349
1350    #[test]
1351    fn parse_error_unknown_model_keyword() {
1352        let input = b"bad_key value\n";
1353        let result = load_model_from_reader(&input[..]);
1354        assert!(result.is_err());
1355        assert!(format!("{}", result.unwrap_err()).contains("unknown keyword"));
1356    }
1357
1358    #[test]
1359    fn parse_error_missing_or_unknown_model_values() {
1360        let missing = b"svm_type\n";
1361        let err = load_model_from_reader(&missing[..]).unwrap_err();
1362        assert!(format!("{}", err).contains("missing svm_type value"));
1363
1364        let unknown = b"svm_type unknown_type\n";
1365        let err = load_model_from_reader(&unknown[..]).unwrap_err();
1366        assert!(format!("{}", err).contains("unknown svm_type"));
1367    }
1368
1369    #[test]
1370    fn parse_error_invalid_nr_sv_entry() {
1371        let input = b"svm_type c_svc\n\
1372kernel_type linear\n\
1373nr_class 2\n\
1374total_sv 1\n\
1375rho 0\n\
1376nr_sv a 1\n\
1377SV\n\
13780.1 1:0.5\n";
1379        let err = load_model_from_reader(&input[..]).unwrap_err();
1380        assert!(format!("{}", err).contains("invalid nr_sv value"));
1381    }
1382
1383    #[test]
1384    fn parse_error_in_sv_section_tokens() {
1385        let missing_coef = b"svm_type c_svc\n\
1386kernel_type linear\n\
1387nr_class 2\n\
1388total_sv 1\n\
1389rho 0\n\
1390SV\n\
13911:0.5\n";
1392        let err = load_model_from_reader(&missing_coef[..]).unwrap_err();
1393        assert!(format!("{}", err).contains("invalid sv_coef"));
1394
1395        let bad_feature = b"svm_type c_svc\n\
1396kernel_type linear\n\
1397nr_class 2\n\
1398total_sv 1\n\
1399rho 0\n\
1400SV\n\
14010.1 bad\n";
1402        let err = load_model_from_reader(&bad_feature[..]).unwrap_err();
1403        assert!(format!("{}", err).contains("expected index:value"));
1404    }
1405
1406    #[test]
1407    fn parse_error_unexpected_eof_in_header_and_sv_section() {
1408        let eof_header = b"svm_type c_svc\n";
1409        let err = load_model_from_reader(&eof_header[..]).unwrap_err();
1410        assert!(format!("{}", err).contains("unexpected end of file in header"));
1411
1412        let eof_sv = b"svm_type c_svc\n\
1413kernel_type linear\n\
1414nr_class 2\n\
1415total_sv 2\n\
1416rho 0\n\
1417SV\n\
14180.1 1:0.5\n";
1419        let err = load_model_from_reader(&eof_sv[..]).unwrap_err();
1420        assert!(format!("{}", err).contains("unexpected end of file in SV section"));
1421    }
1422
1423    #[test]
1424    fn reject_rho_length_mismatch_for_classification() {
1425        // CSvc with nr_class=3 expects rho.len() == 3. Supplying 1 entry
1426        // reveals either a malformed file or an intentional substitution.
1427        let input = b"svm_type c_svc\n\
1428kernel_type linear\n\
1429nr_class 3\n\
1430total_sv 3\n\
1431rho 0\n\
1432SV\n";
1433        let err = load_model_from_reader(&input[..]).unwrap_err();
1434        assert!(
1435            format!("{}", err).contains("rho has 1 entries, expected 3"),
1436            "unexpected error: {}",
1437            err
1438        );
1439    }
1440
1441    #[test]
1442    fn reject_rho_length_mismatch_for_regression() {
1443        // SVR expects exactly 1 rho entry; 2 entries is inconsistent.
1444        let input = b"svm_type epsilon_svr\n\
1445kernel_type linear\n\
1446nr_class 2\n\
1447total_sv 0\n\
1448rho 0 1\n\
1449SV\n";
1450        let err = load_model_from_reader(&input[..]).unwrap_err();
1451        assert!(
1452            format!("{}", err).contains("rho has 2 entries, expected 1"),
1453            "unexpected error: {}",
1454            err
1455        );
1456    }
1457
1458    #[test]
1459    fn reject_label_on_regression() {
1460        // label is classification-only; carrying it on SVR is a red flag.
1461        let input = b"svm_type epsilon_svr\n\
1462kernel_type linear\n\
1463nr_class 2\n\
1464total_sv 0\n\
1465rho 0\n\
1466label 1 -1\n\
1467SV\n";
1468        let err = load_model_from_reader(&input[..]).unwrap_err();
1469        assert!(
1470            format!("{}", err).contains("label is only valid for classification"),
1471            "unexpected error: {}",
1472            err
1473        );
1474    }
1475
1476    #[test]
1477    fn reject_label_length_mismatch() {
1478        let input = b"svm_type c_svc\n\
1479kernel_type linear\n\
1480nr_class 3\n\
1481total_sv 0\n\
1482rho 0 0 0\n\
1483label 1 -1\n\
1484SV\n";
1485        let err = load_model_from_reader(&input[..]).unwrap_err();
1486        assert!(
1487            format!("{}", err).contains("label has 2 entries, expected nr_class (3)"),
1488            "unexpected error: {}",
1489            err
1490        );
1491    }
1492
1493    #[test]
1494    fn reject_nr_sv_sum_mismatch() {
1495        // total_sv=5 but nr_sv sums to 3. An inconsistent header like this
1496        // could previously pass and leave downstream code with stale assumptions.
1497        let input = b"svm_type c_svc\n\
1498kernel_type linear\n\
1499nr_class 2\n\
1500total_sv 5\n\
1501rho 0\n\
1502label 1 -1\n\
1503nr_sv 1 2\n\
1504SV\n";
1505        let err = load_model_from_reader(&input[..]).unwrap_err();
1506        assert!(
1507            format!("{}", err).contains("sum of nr_sv entries (3) does not match total_sv (5)"),
1508            "unexpected error: {}",
1509            err
1510        );
1511    }
1512
1513    #[test]
1514    fn reject_nr_sv_length_mismatch() {
1515        let input = b"svm_type c_svc\n\
1516kernel_type linear\n\
1517nr_class 3\n\
1518total_sv 3\n\
1519rho 0 0 0\n\
1520label 1 -1 0\n\
1521nr_sv 1 2\n\
1522SV\n";
1523        let err = load_model_from_reader(&input[..]).unwrap_err();
1524        assert!(
1525            format!("{}", err).contains("nr_sv has 2 entries, expected nr_class (3)"),
1526            "unexpected error: {}",
1527            err
1528        );
1529    }
1530
1531    #[test]
1532    fn reject_proba_length_mismatch() {
1533        let input = b"svm_type c_svc\n\
1534kernel_type linear\n\
1535nr_class 3\n\
1536total_sv 0\n\
1537rho 0 0 0\n\
1538probA 0.1 0.2\n\
1539SV\n";
1540        let err = load_model_from_reader(&input[..]).unwrap_err();
1541        assert!(
1542            format!("{}", err).contains("probA has 2 entries, expected 3"),
1543            "unexpected error: {}",
1544            err
1545        );
1546    }
1547
1548    #[test]
1549    fn reject_prob_density_marks_on_csvc() {
1550        let input = b"svm_type c_svc\n\
1551kernel_type linear\n\
1552nr_class 2\n\
1553total_sv 0\n\
1554rho 0\n\
1555prob_density_marks 0.1 0.2\n\
1556SV\n";
1557        let err = load_model_from_reader(&input[..]).unwrap_err();
1558        assert!(
1559            format!("{}", err).contains("prob_density_marks is only valid for one-class SVM"),
1560            "unexpected error: {}",
1561            err
1562        );
1563    }
1564
1565    #[test]
1566    fn reject_nr_class_below_two() {
1567        let input = b"svm_type c_svc\n\
1568kernel_type linear\n\
1569nr_class 1\n\
1570total_sv 0\n\
1571rho\n\
1572SV\n";
1573        let err = load_model_from_reader(&input[..]).unwrap_err();
1574        assert!(
1575            format!("{}", err).contains("nr_class must be >= 2, got 1"),
1576            "unexpected error: {}",
1577            err
1578        );
1579    }
1580
1581    #[test]
1582    fn reject_sv_feature_indices_not_ascending() {
1583        let input = b"svm_type c_svc\n\
1584kernel_type linear\n\
1585nr_class 2\n\
1586total_sv 1\n\
1587rho 0\n\
1588SV\n\
15890.1 3:0.5 1:0.3\n";
1590        let err = load_model_from_reader(&input[..]).unwrap_err();
1591        assert!(
1592            format!("{}", err).contains("feature indices must be ascending"),
1593            "unexpected error: {}",
1594            err
1595        );
1596    }
1597
1598    #[test]
1599    fn reject_precomputed_model_sv_without_sample_serial_number() {
1600        let input = b"svm_type c_svc\n\
1601kernel_type precomputed\n\
1602nr_class 2\n\
1603total_sv 1\n\
1604rho 0\n\
1605SV\n\
16060.1\n";
1607        let err = load_model_from_reader(&input[..]).unwrap_err();
1608        assert!(
1609            format!("{}", err).contains("missing 0:sample_serial_number"),
1610            "unexpected error: {}",
1611            err
1612        );
1613    }
1614
1615    #[test]
1616    fn reject_sv_feature_index_duplicated() {
1617        let input = b"svm_type c_svc\n\
1618kernel_type linear\n\
1619nr_class 2\n\
1620total_sv 1\n\
1621rho 0\n\
1622SV\n\
16230.1 1:0.5 1:0.3\n";
1624        let err = load_model_from_reader(&input[..]).unwrap_err();
1625        assert!(
1626            format!("{}", err).contains("feature indices must be ascending"),
1627            "unexpected error: {}",
1628            err
1629        );
1630    }
1631
1632    #[test]
1633    fn load_options_default_caps_match_documented_values() {
1634        let opts = LoadOptions::default();
1635        assert_eq!(opts.max_bytes, 64 * 1024 * 1024);
1636        assert_eq!(opts.max_line_len, 1024 * 1024);
1637        assert_eq!(opts.max_sv, MAX_TOTAL_SV);
1638        assert_eq!(opts.max_nr_class, MAX_NR_CLASS);
1639        assert_eq!(opts.max_feature_index, MAX_FEATURE_INDEX);
1640    }
1641
1642    #[test]
1643    fn load_options_trusted_input_sets_type_maxes() {
1644        let opts = LoadOptions::trusted_input();
1645        assert_eq!(opts.max_bytes, u64::MAX);
1646        assert_eq!(opts.max_line_len, usize::MAX);
1647        assert_eq!(opts.max_sv, usize::MAX);
1648        assert_eq!(opts.max_nr_class, usize::MAX);
1649        assert_eq!(opts.max_feature_index, i32::MAX);
1650    }
1651
    #[test]
    fn problem_reader_rejects_file_over_max_bytes() {
        // 18 bytes of content (two 9-byte lines), cap at 10, so the reader
        // must stop with a max_bytes error.
        let input = b"+1 1:0.5\n+1 2:0.5\n";
        let opts = LoadOptions {
            max_bytes: 10,
            ..LoadOptions::default()
        };
        let err = load_problem_from_reader_with_options(&input[..], &opts).unwrap_err();
        assert!(
            format!("{}", err).contains("max_bytes"),
            "unexpected error: {}",
            err
        );
    }
1667
    #[test]
    fn problem_reader_rejects_line_over_max_line_len() {
        // Build a single 344-byte line (label plus 50 `i:0.1` features, before
        // the newline); cap max_line_len at 50 so it must be rejected.
        let mut payload = String::from("+1 ");
        for i in 1..=50 {
            payload.push_str(&format!("{}:0.1 ", i));
        }
        payload.push('\n');
        let opts = LoadOptions {
            max_line_len: 50,
            ..LoadOptions::default()
        };
        let err = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap_err();
        assert!(
            format!("{}", err).contains("max_line_len"),
            "unexpected error: {}",
            err
        );
    }
1687
1688    #[test]
1689    fn problem_reader_accepts_line_at_max_line_len() {
1690        // A line whose content (excluding trailing newline) is exactly
1691        // max_line_len bytes must be accepted — the cap is inclusive.
1692        let line_content = "+1 1:0.5";
1693        let payload = format!("{}\n", line_content);
1694        let opts = LoadOptions {
1695            max_line_len: line_content.len(),
1696            ..LoadOptions::default()
1697        };
1698        let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1699        assert_eq!(problem.labels.len(), 1);
1700    }
1701
1702    #[test]
1703    fn problem_reader_tolerates_crlf_at_cap() {
1704        // Same content length, but with \r\n. Total bytes on disk is
1705        // content_len + 2; the helper allows one byte of slack for the \r.
1706        let line_content = "+1 1:0.5";
1707        let payload = format!("{}\r\n", line_content);
1708        let opts = LoadOptions {
1709            max_line_len: line_content.len(),
1710            ..LoadOptions::default()
1711        };
1712        let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1713        assert_eq!(problem.labels.len(), 1);
1714    }
1715
1716    #[test]
1717    fn problem_reader_rejects_nul_byte() {
1718        // A stray NUL byte inside a line is rejected early — LIBSVM text
1719        // files are ASCII, and a NUL is most likely truncated binary data.
1720        let mut payload: Vec<u8> = b"+1 1:0.5".to_vec();
1721        payload.push(0);
1722        payload.extend_from_slice(b"\n");
1723        let err = load_problem_from_reader(payload.as_slice()).unwrap_err();
1724        assert!(
1725            format!("{}", err).contains("NUL byte"),
1726            "unexpected error: {}",
1727            err
1728        );
1729    }
1730
1731    #[test]
1732    fn model_reader_honors_max_nr_class_cap() {
1733        // 100 is below both MAX_NR_CLASS and the 50 cap; the 50 cap should
1734        // fire first.
1735        let input = b"svm_type c_svc\n\
1736kernel_type linear\n\
1737nr_class 100\n\
1738total_sv 1\n\
1739rho 0\n\
1740SV\n";
1741        let opts = LoadOptions {
1742            max_nr_class: 50,
1743            ..LoadOptions::default()
1744        };
1745        let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1746        assert!(
1747            format!("{}", err).contains("nr_class exceeds limit (50)"),
1748            "unexpected error: {}",
1749            err
1750        );
1751    }
1752
1753    #[test]
1754    fn model_reader_honors_max_sv_cap() {
1755        let input = b"svm_type c_svc\n\
1756kernel_type linear\n\
1757nr_class 2\n\
1758total_sv 1000\n\
1759rho 0\n\
1760SV\n";
1761        let opts = LoadOptions {
1762            max_sv: 100,
1763            ..LoadOptions::default()
1764        };
1765        let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1766        assert!(
1767            format!("{}", err).contains("total_sv exceeds limit (100)"),
1768            "unexpected error: {}",
1769            err
1770        );
1771    }
1772
1773    #[test]
1774    fn trusted_input_cannot_exceed_hard_module_caps() {
1775        // `LoadOptions::trusted_input()` sets caps to usize::MAX / u64::MAX,
1776        // but the module-level hard caps (MAX_NR_CLASS, MAX_TOTAL_SV) still
1777        // apply as an upper bound. This is defense-in-depth.
1778        let huge_nr_class = format!(
1779            "svm_type c_svc\n\
1780kernel_type linear\n\
1781nr_class {}\n\
1782total_sv 1\n\
1783rho 0\n\
1784SV\n",
1785            MAX_NR_CLASS + 1
1786        );
1787        let opts = LoadOptions::trusted_input();
1788        let err = load_model_from_reader_with_options(huge_nr_class.as_bytes(), &opts).unwrap_err();
1789        assert!(
1790            format!("{}", err).contains("nr_class exceeds limit"),
1791            "unexpected error: {}",
1792            err
1793        );
1794    }
1795
1796    #[test]
1797    fn model_reader_honors_max_feature_index_cap() {
1798        let input = b"svm_type c_svc\n\
1799kernel_type linear\n\
1800nr_class 2\n\
1801total_sv 1\n\
1802rho 0\n\
1803SV\n\
18040.1 50:0.5\n";
1805        let opts = LoadOptions {
1806            max_feature_index: 10,
1807            ..LoadOptions::default()
1808        };
1809        let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1810        assert!(
1811            format!("{}", err).contains("feature index 50 exceeds limit (10)"),
1812            "unexpected error: {}",
1813            err
1814        );
1815    }
1816
1817    #[test]
1818    fn sv_count_loop_counts_nonblank_lines_only() {
1819        // A blank line inside the SV section must NOT be billed against
1820        // total_sv — parsing must terminate only after `total_sv` real SV
1821        // rows have been collected. This guards the loop rewrite from
1822        // `for _ in 0..total_sv { if empty { continue } ... }`, which would
1823        // silently accept a model with fewer SVs than the header claims.
1824        let input = b"svm_type c_svc\n\
1825kernel_type linear\n\
1826nr_class 2\n\
1827total_sv 2\n\
1828rho 0\n\
1829label 1 -1\n\
1830nr_sv 1 1\n\
1831SV\n\
1832\n\
18330.1 1:0.5\n\
1834\n\
1835-0.1 2:0.5\n";
1836        let model = load_model_from_reader(&input[..]).unwrap();
1837        assert_eq!(model.sv.len(), 2);
1838        assert_eq!(model.sv_coef[0].len(), 2);
1839    }
1840
1841    #[test]
1842    fn save_precomputed_model_writes_zero_index() {
1843        let model = SvmModel {
1844            param: SvmParameter {
1845                svm_type: SvmType::CSvc,
1846                kernel_type: KernelType::Precomputed,
1847                ..Default::default()
1848            },
1849            nr_class: 2,
1850            sv: vec![vec![SvmNode {
1851                index: 0,
1852                value: 7.0,
1853            }]],
1854            sv_coef: vec![vec![0.25]],
1855            rho: vec![0.0],
1856            prob_a: vec![],
1857            prob_b: vec![],
1858            prob_density_marks: vec![],
1859            sv_indices: vec![],
1860            label: vec![1, -1],
1861            n_sv: vec![1, 0],
1862        };
1863
1864        let mut buf = Vec::new();
1865        save_model_to_writer(&mut buf, &model).unwrap();
1866        let out = String::from_utf8(buf).unwrap();
1867        assert!(out.contains("kernel_type precomputed"));
1868        assert!(out.contains("0:7"));
1869    }
1870}