Skip to main content

libsvm_rs/
io.rs

1//! I/O routines for LIBSVM problem and model files.
2//!
3//! The file formats match the original LIBSVM text formats for cross-tool
4//! interoperability, but the loaders are written for untrusted input by
5//! default. They enforce [`LoadOptions`] caps, reject embedded NUL bytes,
6//! reject non-ascending feature indices, and return structured [`SvmError`]
7//! values for malformed input within those caps.
8//!
9//! Model loading performs additional structural checks before allocating support
10//! vector storage. Header fields such as `nr_class`, `total_sv`, `rho`, `label`,
11//! `nr_sv`, `probA`, `probB`, and one-class probability-density marks must be
12//! shape-consistent when present. Precomputed-kernel support-vector rows must
13//! begin with `0:sample_serial_number`.
14//!
15//! These checks validate text structure and resource bounds. They do not prove
16//! that a model is statistically meaningful, was trained from a particular
17//! dataset, or is cryptographically authentic.
18
19use std::io::{BufRead, Write};
20use std::path::Path;
21
22use crate::error::SvmError;
23use crate::types::*;
24use crate::util::parse_feature_index;
25use crate::util::MAX_FEATURE_INDEX;
26
27// ─── C-compatible %g formatting ─────────────────────────────────────
28//
29// C's printf `%.Pg` format strips trailing zeros and picks fixed vs.
30// scientific notation based on the exponent. Rust has no built-in
31// equivalent, so we replicate the POSIX specification:
32//   - Use scientific if exponent < -4 or exponent >= precision
33//   - Otherwise use fixed notation
34//   - Strip trailing zeros (and trailing decimal point)
35
36use std::fmt;
37
/// Formats an `f64` the way C's `printf("%.{precision}g", value)` does.
///
/// The value is rendered with `precision` significant digits, in fixed or
/// scientific notation depending on the decimal exponent, with trailing
/// fractional zeros removed. A precision of `0` is treated as `1`, as in C.
struct Gfmt {
    value: f64,
    precision: usize,
}

impl Gfmt {
    /// Wrap `value` for display with `precision` significant digits.
    fn new(value: f64, precision: usize) -> Self {
        Self { value, precision }
    }

    /// Remove trailing zeros of the fractional part, and the decimal point
    /// itself if nothing follows it.
    ///
    /// Strings without a decimal point are returned untouched so that the
    /// significant zeros of an integer such as `100000` are never stripped.
    fn strip_fraction_zeros(digits: &str) -> &str {
        if digits.contains('.') {
            digits.trim_end_matches('0').trim_end_matches('.')
        } else {
            digits
        }
    }
}

impl fmt::Display for Gfmt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let v = self.value;
        // C treats a `%g` precision of zero as one.
        let p = self.precision.max(1);

        if !v.is_finite() {
            return write!(f, "{}", v); // inf, -inf, NaN
        }

        if v == 0.0 {
            // Preserve the sign of -0.0.
            return f.write_str(if v.is_sign_negative() { "-0" } else { "0" });
        }

        // `%g` chooses its notation from the exponent X that the `%e`
        // conversion would produce, i.e. the exponent *after* rounding to
        // `p` significant digits. Deriving X from the formatted string
        // (instead of `log10().floor()`) keeps values such as 999999.7 at
        // precision 6 on the scientific side (`1e+06`) and is immune to
        // `log10` rounding error near powers of ten.
        let sci = format!("{:.prec$e}", v, prec = p - 1);
        let (mantissa, exp) = match sci.split_once('e') {
            Some((m, e)) => match e.parse::<i64>() {
                Ok(x) => (m, x),
                Err(_) => return f.write_str(&sci),
            },
            None => return f.write_str(&sci),
        };
        let p_signed = i64::try_from(p).unwrap_or(i64::MAX);

        if exp < -4 || exp >= p_signed {
            // Scientific notation. C prints an explicit exponent sign and
            // pads the exponent to at least two digits (e-05, not e-5).
            let sign = if exp < 0 { '-' } else { '+' };
            write!(
                f,
                "{}e{}{:02}",
                Self::strip_fraction_zeros(mantissa),
                sign,
                exp.abs()
            )
        } else {
            // Fixed notation with `p - 1 - X` digits after the decimal point,
            // which is non-negative because -4 <= X < p on this branch.
            let decimals = usize::try_from(p_signed - 1 - exp).unwrap_or(0);
            let fixed = format!("{:.prec$}", v, prec = decimals);
            f.write_str(Self::strip_fraction_zeros(&fixed))
        }
    }
}
102
103/// Format like C's `%.17g`
104fn fmt_17g(v: f64) -> Gfmt {
105    Gfmt::new(v, 17)
106}
107
108/// Format like C's `%.8g`
109fn fmt_8g(v: f64) -> Gfmt {
110    Gfmt::new(v, 8)
111}
112
113/// Format a float like C's `%g` (6 significant digits).
114pub fn format_g(v: f64) -> String {
115    format!("{}", Gfmt::new(v, 6))
116}
117
118/// Format a float like C's `%.17g` (17 significant digits).
119pub fn format_17g(v: f64) -> String {
120    format!("{}", Gfmt::new(v, 17))
121}
122
123// ─── String tables matching original LIBSVM ──────────────────────────
124
/// Names written after the `svm_type` keyword in model files.
///
/// `svm_type_to_str` indexes this table with `t as usize`, so the order of
/// the entries matters.
const SVM_TYPE_TABLE: &[&str] = &[
    "c_svc",
    "nu_svc",
    "one_class",
    "epsilon_svr",
    "nu_svr",
];

/// Names written after the `kernel_type` keyword in model files.
///
/// `kernel_type_to_str` indexes this table with `t as usize`, so the order
/// of the entries matters.
const KERNEL_TYPE_TABLE: &[&str] = &[
    "linear",
    "polynomial",
    "rbf",
    "sigmoid",
    "precomputed",
];
127
128fn svm_type_to_str(t: SvmType) -> &'static str {
129    SVM_TYPE_TABLE[t as usize]
130}
131
132fn kernel_type_to_str(t: KernelType) -> &'static str {
133    KERNEL_TYPE_TABLE[t as usize]
134}
135
136fn str_to_svm_type(s: &str) -> Option<SvmType> {
137    match s {
138        "c_svc" => Some(SvmType::CSvc),
139        "nu_svc" => Some(SvmType::NuSvc),
140        "one_class" => Some(SvmType::OneClass),
141        "epsilon_svr" => Some(SvmType::EpsilonSvr),
142        "nu_svr" => Some(SvmType::NuSvr),
143        _ => None,
144    }
145}
146
147fn str_to_kernel_type(s: &str) -> Option<KernelType> {
148    match s {
149        "linear" => Some(KernelType::Linear),
150        "polynomial" => Some(KernelType::Polynomial),
151        "rbf" => Some(KernelType::Rbf),
152        "sigmoid" => Some(KernelType::Sigmoid),
153        "precomputed" => Some(KernelType::Precomputed),
154        _ => None,
155    }
156}
157
158// ─── Load options ────────────────────────────────────────────────────
159
160/// Resource caps applied while reading LIBSVM problem and model files.
161///
162/// LIBSVM text formats are linear in file size, but individual fields (e.g.
163/// `total_sv`, feature indices, per-line token counts) can be used by a
164/// malicious file to trigger disproportionate allocation or CPU work. The
165/// [`Default`] impl returns defaults tuned for **untrusted input**:
166///
167/// | Field               | Default        |
168/// |---------------------|----------------|
169/// | `max_bytes`         | 64 MiB         |
170/// | `max_line_len`      | 1 MiB          |
171/// | `max_sv`            | 10 000 000     |
172/// | `max_nr_class`      | 65 535         |
173/// | `max_feature_index` | 10 000 000     |
174///
175/// If you know the input is trusted (produced locally, read from an
176/// attestation-protected location, etc.) and you need to handle inputs
177/// larger than the defaults, call [`LoadOptions::trusted_input`] for an
178/// "unlimited" profile, or override specific fields:
179///
180/// ```ignore
181/// use libsvm_rs::io::{load_problem_from_reader_with_options, LoadOptions};
182///
183/// let opts = LoadOptions {
184///     max_bytes: 1 << 30, // 1 GiB for a trusted bulk file
185///     ..LoadOptions::default()
186/// };
187/// let problem = load_problem_from_reader_with_options(reader, &opts)?;
188/// ```
189///
190/// Exceeding any cap returns an [`SvmError::ParseError`] (problem files) or
191/// [`SvmError::ModelFormatError`] (model files) with a message identifying
192/// the field that tripped the limit.
193///
194/// The default profile is intended to be panic-free for malformed text input
195/// within these caps. It is still a parser contract, not a semantic model
196/// audit: it does not verify that a loaded model came from a specific training
197/// set or that its predictions are appropriate for a deployment.
#[derive(Clone, Copy, Debug)]
pub struct LoadOptions {
    /// Upper bound on the total number of bytes consumed from the source.
    pub max_bytes: u64,
    /// Upper bound on the length in bytes of one line, not counting its
    /// terminator.
    pub max_line_len: usize,
    /// Largest `total_sv` header value a model file may declare.
    pub max_sv: usize,
    /// Largest `nr_class` header value a model file may declare.
    pub max_nr_class: usize,
    /// Largest feature index allowed in a problem or model feature list.
    pub max_feature_index: i32,
}
211
212impl Default for LoadOptions {
213    fn default() -> Self {
214        Self {
215            max_bytes: 64 * 1024 * 1024,
216            max_line_len: 1024 * 1024,
217            max_sv: MAX_TOTAL_SV,
218            max_nr_class: MAX_NR_CLASS,
219            max_feature_index: MAX_FEATURE_INDEX,
220        }
221    }
222}
223
224impl LoadOptions {
225    /// Options with all caps set to the type maximum.
226    ///
227    /// Use **only** for input that is fully trusted. Operating with
228    /// unlimited caps removes the first line of defense against malformed
229    /// or adversarial model / problem files. Module-level hard caps still
230    /// apply to `nr_class` and `total_sv`.
231    pub fn trusted_input() -> Self {
232        Self {
233            max_bytes: u64::MAX,
234            max_line_len: usize::MAX,
235            max_sv: usize::MAX,
236            max_nr_class: usize::MAX,
237            max_feature_index: i32::MAX,
238        }
239    }
240}
241
/// Read one line from `reader`, enforcing the total-byte and per-line caps.
///
/// Returns `Ok(None)` on clean end of input. Otherwise returns
/// `Ok(Some(line))`, where `line` still carries its terminator (`\n` or
/// `\r\n`) if the input had one. `bytes_read` is advanced by the number of
/// bytes consumed from `reader`, terminator included.
///
/// The line is pulled through [`BufRead::fill_buf`] / [`BufRead::consume`]
/// one buffer slice at a time, and every cap is checked *before* a slice is
/// copied, so an unbounded line cannot grow the output buffer beyond
/// `max_line_len` plus one byte of CRLF slack.
///
/// # Errors
///
/// Returns an error of kind [`std::io::ErrorKind::InvalidData`] when
///
/// - consuming the next slice would push `bytes_read` past `max_bytes`,
/// - the line content (excluding `\n` or `\r\n`) exceeds `max_line_len`,
/// - the line contains a NUL byte, or
/// - the line is not valid UTF-8.
///
/// [`std::io::ErrorKind::Interrupted`] from the reader is retried, matching
/// the standard library's own line readers; every other reader error is
/// returned unchanged.
fn read_line_capped(
    reader: &mut dyn BufRead,
    bytes_read: &mut u64,
    max_bytes: u64,
    max_line_len: usize,
) -> std::io::Result<Option<String>> {
    fn invalid_data(message: String) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, message)
    }

    // While a line is still being accumulated we cannot know whether a
    // trailing `\r` is content or the first half of `\r\n`, so the streaming
    // check allows one extra byte. The exact cap is enforced once the whole
    // line is known (see below).
    let streaming_cap: u64 = (max_line_len as u64).saturating_add(1);

    let mut buf: Vec<u8> = Vec::new();

    loop {
        let available = match reader.fill_buf() {
            Ok(chunk) => chunk,
            // A signal-interrupted read is not a failure; try again.
            Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(e),
        };
        if available.is_empty() {
            break; // end of input
        }

        // Take up to and including the first newline, or the whole slice.
        let (take_n, ends_with_newline) = match available.iter().position(|&b| b == b'\n') {
            Some(pos) => (pos + 1, true),
            None => (available.len(), false),
        };

        // Total-byte cap.
        let new_bytes_read = bytes_read.saturating_add(take_n as u64);
        if new_bytes_read > max_bytes {
            return Err(invalid_data(format!(
                "input exceeds max_bytes limit ({})",
                max_bytes
            )));
        }

        // Streaming per-line cap; the trailing `\n` never counts as content.
        let prospective_len = (buf.len() as u64).saturating_add(take_n as u64);
        let content_bytes = if ends_with_newline {
            prospective_len - 1
        } else {
            prospective_len
        };
        if content_bytes > streaming_cap {
            return Err(invalid_data(format!(
                "line length exceeds max_line_len limit ({})",
                max_line_len
            )));
        }

        // NUL bytes have no legal use in LIBSVM text files.
        if available[..take_n].contains(&0) {
            return Err(invalid_data(
                "unexpected NUL byte in text input".to_string(),
            ));
        }

        buf.extend_from_slice(&available[..take_n]);
        reader.consume(take_n);
        *bytes_read = new_bytes_read;

        if ends_with_newline {
            break;
        }
    }

    // Nothing was consumed at all: clean end of input.
    if buf.is_empty() {
        return Ok(None);
    }

    // Exact per-line cap: the slack byte above is only legitimate when it is
    // the `\r` of a line terminator, so measure the content without it.
    let mut content_len = buf.len();
    if buf.ends_with(b"\n") {
        content_len -= 1;
    }
    if buf[..content_len].ends_with(b"\r") {
        content_len -= 1;
    }
    if content_len > max_line_len {
        return Err(invalid_data(format!(
            "line length exceeds max_line_len limit ({})",
            max_line_len
        )));
    }

    String::from_utf8(buf)
        .map(Some)
        .map_err(|e| invalid_data(e.to_string()))
}
330
331// ─── Problem file I/O ────────────────────────────────────────────────
332
333/// Load an SVM problem from a file in LIBSVM sparse format.
334///
335/// Format: `<label> <index1>:<value1> <index2>:<value2> ...`
336///
337/// Uses [`LoadOptions::default`] — appropriate for untrusted input. For
338/// trusted bulk files exceeding the defaults, use
339/// [`load_problem_from_reader_with_options`] with a custom [`LoadOptions`].
340///
341/// Validates:
342///
343/// - total file size and per-line length,
344/// - embedded NUL byte absence,
345/// - parseable labels and feature values,
346/// - `index:value` token shape,
347/// - ascending feature indices,
348/// - [`LoadOptions::max_feature_index`].
349///
350/// This loader does not validate statistical quality, feature normalization, or
351/// whether labels are appropriate for a particular SVM formulation. Malformed
352/// text input within the configured caps is returned as [`SvmError`] rather
353/// than panicking.
354///
355/// ### Complexity
356///
357/// Linear in file size for parsing and `O(n · d)` memory where `n` is the
358/// number of instances and `d` is the average number of non-zero features
359/// per instance. No per-row allocation is driven by an untrusted header.
360pub fn load_problem(path: &Path) -> Result<SvmProblem, SvmError> {
361    let file = std::fs::File::open(path)?;
362    let reader = std::io::BufReader::new(file);
363    load_problem_from_reader(reader)
364}
365
366/// Load an SVM problem from any buffered reader.
367///
368/// Uses [`LoadOptions::default`]. See [`load_problem`] for the validation
369/// contract and non-goals.
370pub fn load_problem_from_reader(reader: impl BufRead) -> Result<SvmProblem, SvmError> {
371    load_problem_from_reader_with_options(reader, &LoadOptions::default())
372}
373
374/// Load an SVM problem from any buffered reader, with explicit resource caps.
375///
376/// See [`LoadOptions`] for the meaning of each cap and for defaults tuned for
377/// untrusted input. This function has the same validation contract as
378/// [`load_problem`], with caller-supplied caps.
379pub fn load_problem_from_reader_with_options(
380    mut reader: impl BufRead,
381    options: &LoadOptions,
382) -> Result<SvmProblem, SvmError> {
383    let mut labels = Vec::new();
384    let mut instances = Vec::new();
385    let mut bytes_read: u64 = 0;
386    let mut line_idx: usize = 0;
387
388    while let Some(raw) = read_line_capped(
389        &mut reader,
390        &mut bytes_read,
391        options.max_bytes,
392        options.max_line_len,
393    )? {
394        let line_num = line_idx + 1;
395        line_idx += 1;
396        let line = raw.trim();
397        if line.is_empty() {
398            continue;
399        }
400
401        let mut parts = line.split_whitespace();
402
403        // Parse label
404        let label_str = parts.next().ok_or_else(|| SvmError::ParseError {
405            line: line_num,
406            message: "missing label".into(),
407        })?;
408        let label: f64 = label_str.parse().map_err(|_| SvmError::ParseError {
409            line: line_num,
410            message: format!("invalid label: {}", label_str),
411        })?;
412
413        // Parse features (must be in ascending index order).
414        let mut nodes = Vec::new();
415        let mut prev_index: i32 = 0;
416        for token in parts {
417            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| SvmError::ParseError {
418                line: line_num,
419                message: format!("expected index:value, got: {}", token),
420            })?;
421            let index: i32 =
422                parse_feature_index_problem_line(line_num, idx_str, options.max_feature_index)?;
423
424            if !nodes.is_empty() && index <= prev_index {
425                return Err(SvmError::ParseError {
426                    line: line_num,
427                    message: format!(
428                        "feature indices must be ascending: {} follows {}",
429                        index, prev_index
430                    ),
431                });
432            }
433            let value: f64 = val_str.parse().map_err(|_| SvmError::ParseError {
434                line: line_num,
435                message: format!("invalid value: {}", val_str),
436            })?;
437            prev_index = index;
438            nodes.push(SvmNode { index, value });
439        }
440
441        labels.push(label);
442        instances.push(nodes);
443    }
444
445    Ok(SvmProblem { labels, instances })
446}
447
448// ─── Model file I/O ──────────────────────────────────────────────────
449
/// Module-level hard limit for the `nr_class` header field (65 535).
const MAX_NR_CLASS: usize = u16::MAX as usize;
/// Module-level hard limit for the `total_sv` header field (ten million).
const MAX_TOTAL_SV: usize = 10 * 1_000_000;
452
453/// Save an SVM model to a file in the original LIBSVM format.
454pub fn save_model(path: &Path, model: &SvmModel) -> Result<(), SvmError> {
455    let file = std::fs::File::create(path)?;
456    let writer = std::io::BufWriter::new(file);
457    save_model_to_writer(writer, model)
458}
459
460/// Save an SVM model to any writer.
461///
462/// This is the writer-backed equivalent of LIBSVM's `svm_save_model`, useful
463/// for in-memory buffers and tests.
464///
465/// ```
466/// use libsvm_rs::io::{load_model_from_reader, save_model_to_writer};
467/// use libsvm_rs::train::svm_train;
468/// use libsvm_rs::{KernelType, SvmNode, SvmParameterBuilder, SvmProblem, SvmType};
469///
470/// let problem = SvmProblem {
471///     labels: vec![-1.0, -1.0, 1.0, 1.0],
472///     instances: vec![
473///         vec![SvmNode { index: 1, value: -2.0 }],
474///         vec![SvmNode { index: 1, value: -1.0 }],
475///         vec![SvmNode { index: 1, value: 1.0 }],
476///         vec![SvmNode { index: 1, value: 2.0 }],
477///     ],
478/// };
479/// let param = SvmParameterBuilder::new()
480///     .svm_type(SvmType::CSvc)
481///     .kernel_type(KernelType::Linear)
482///     .build()?;
483/// let model = svm_train(&problem, &param);
484///
485/// let mut bytes = Vec::new();
486/// save_model_to_writer(&mut bytes, &model)?;
487/// let loaded = load_model_from_reader(std::io::Cursor::new(bytes))?;
488/// assert_eq!(loaded.svm_type(), SvmType::CSvc);
489/// # Ok::<(), libsvm_rs::SvmError>(())
490/// ```
491pub fn save_model_to_writer(mut w: impl Write, model: &SvmModel) -> Result<(), SvmError> {
492    let param = &model.param;
493
494    writeln!(w, "svm_type {}", svm_type_to_str(param.svm_type))?;
495    writeln!(w, "kernel_type {}", kernel_type_to_str(param.kernel_type))?;
496
497    if param.kernel_type == KernelType::Polynomial {
498        writeln!(w, "degree {}", param.degree)?;
499    }
500    if matches!(
501        param.kernel_type,
502        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
503    ) {
504        writeln!(w, "gamma {}", fmt_17g(param.gamma))?;
505    }
506    if matches!(
507        param.kernel_type,
508        KernelType::Polynomial | KernelType::Sigmoid
509    ) {
510        writeln!(w, "coef0 {}", fmt_17g(param.coef0))?;
511    }
512
513    let nr_class = model.nr_class;
514    writeln!(w, "nr_class {}", nr_class)?;
515    writeln!(w, "total_sv {}", model.sv.len())?;
516
517    // rho
518    write!(w, "rho")?;
519    for r in &model.rho {
520        write!(w, " {}", fmt_17g(*r))?;
521    }
522    writeln!(w)?;
523
524    // label (classification only)
525    if !model.label.is_empty() {
526        write!(w, "label")?;
527        for l in &model.label {
528            write!(w, " {}", l)?;
529        }
530        writeln!(w)?;
531    }
532
533    // probA
534    if !model.prob_a.is_empty() {
535        write!(w, "probA")?;
536        for v in &model.prob_a {
537            write!(w, " {}", fmt_17g(*v))?;
538        }
539        writeln!(w)?;
540    }
541
542    // probB
543    if !model.prob_b.is_empty() {
544        write!(w, "probB")?;
545        for v in &model.prob_b {
546            write!(w, " {}", fmt_17g(*v))?;
547        }
548        writeln!(w)?;
549    }
550
551    // prob_density_marks (one-class)
552    if !model.prob_density_marks.is_empty() {
553        write!(w, "prob_density_marks")?;
554        for v in &model.prob_density_marks {
555            write!(w, " {}", fmt_17g(*v))?;
556        }
557        writeln!(w)?;
558    }
559
560    // nr_sv
561    if !model.n_sv.is_empty() {
562        write!(w, "nr_sv")?;
563        for n in &model.n_sv {
564            write!(w, " {}", n)?;
565        }
566        writeln!(w)?;
567    }
568
569    // SV section
570    writeln!(w, "SV")?;
571    let num_sv = model.sv.len();
572    let num_coef_rows = model.sv_coef.len(); // nr_class - 1
573
574    for i in 0..num_sv {
575        // sv_coef columns for this SV: %.17g
576        for j in 0..num_coef_rows {
577            write!(w, "{} ", fmt_17g(model.sv_coef[j][i]))?;
578        }
579        // sparse features: %.8g
580        if model.param.kernel_type == KernelType::Precomputed {
581            if let Some(node) = model.sv[i].first() {
582                write!(w, "0:{} ", node.value as i32)?;
583            }
584        } else {
585            for node in &model.sv[i] {
586                write!(w, "{}:{} ", node.index, fmt_8g(node.value))?;
587            }
588        }
589        writeln!(w)?;
590    }
591
592    Ok(())
593}
594
595/// Load an SVM model from a file in the original LIBSVM format.
596///
597/// Uses [`LoadOptions::default`] — appropriate for untrusted input.
598///
599/// Validates:
600///
601/// - total file size and per-line length,
602/// - embedded NUL byte absence,
603/// - known `svm_type` and `kernel_type` values,
604/// - `nr_class` and `total_sv` caps,
605/// - `nr_class >= 2`,
606/// - `rho` length for classification, one-class, and regression models,
607/// - `label` and `nr_sv` length when present,
608/// - `sum(nr_sv) == total_sv` when `nr_sv` is present,
609/// - `probA` / `probB` decision-function counts when present,
610/// - one-class-only `prob_density_marks`,
611/// - support-vector feature token shape, ascending feature indices, and
612///   [`LoadOptions::max_feature_index`],
613/// - precomputed-kernel support-vector rows starting with
614///   `0:sample_serial_number`.
615///
616/// This loader does not prove model provenance, semantic correctness relative
617/// to a training set, or suitability for a deployment. Malformed text input
618/// within the configured caps is returned as [`SvmError`] rather than panicking.
619///
620/// ### Complexity
621///
622/// The header parse is `O(nr_class)` in the worst case (due to `rho` /
623/// `label` / `nr_sv` array reads). The SV section is linear in the file
624/// size. Downstream consumers of the returned [`SvmModel`] — notably
625/// `group_classes` and probability estimation — are `O(k²)` on `k =
626/// nr_class`, bounded by [`LoadOptions::max_nr_class`].
627pub fn load_model(path: &Path) -> Result<SvmModel, SvmError> {
628    let file = std::fs::File::open(path)?;
629    let reader = std::io::BufReader::new(file);
630    load_model_from_reader(reader)
631}
632
633/// Load an SVM model from any buffered reader.
634///
635/// Uses [`LoadOptions::default`]. See [`load_model`] for the validation
636/// contract and non-goals. This is the reader-backed equivalent of LIBSVM's
637/// `svm_load_model`.
638///
639/// ```
640/// use libsvm_rs::io::{load_model_from_reader, save_model_to_writer};
641/// use libsvm_rs::train::svm_train;
642/// use libsvm_rs::{KernelType, SvmNode, SvmParameterBuilder, SvmProblem, SvmType};
643///
644/// let problem = SvmProblem {
645///     labels: vec![-1.0, -1.0, 1.0, 1.0],
646///     instances: vec![
647///         vec![SvmNode { index: 1, value: -2.0 }],
648///         vec![SvmNode { index: 1, value: -1.0 }],
649///         vec![SvmNode { index: 1, value: 1.0 }],
650///         vec![SvmNode { index: 1, value: 2.0 }],
651///     ],
652/// };
653/// let param = SvmParameterBuilder::new()
654///     .svm_type(SvmType::CSvc)
655///     .kernel_type(KernelType::Linear)
656///     .build()?;
657/// let model = svm_train(&problem, &param);
658///
659/// let mut bytes = Vec::new();
660/// save_model_to_writer(&mut bytes, &model)?;
661/// let loaded = load_model_from_reader(std::io::Cursor::new(bytes))?;
662/// assert_eq!(loaded.sv.len(), model.sv.len());
663/// # Ok::<(), libsvm_rs::SvmError>(())
664/// ```
665pub fn load_model_from_reader(reader: impl BufRead) -> Result<SvmModel, SvmError> {
666    load_model_from_reader_with_options(reader, &LoadOptions::default())
667}
668
669/// Load an SVM model from any buffered reader, with explicit resource caps.
670///
671/// See [`LoadOptions`] for the meaning of each cap and for defaults tuned for
672/// untrusted input. This function has the same validation contract as
673/// [`load_model`], with caller-supplied caps.
674pub fn load_model_from_reader_with_options(
675    mut reader: impl BufRead,
676    options: &LoadOptions,
677) -> Result<SvmModel, SvmError> {
678    let mut bytes_read: u64 = 0;
679
680    // The `nr_class` / `total_sv` caps are the intersection of the
681    // module-level hard limits (`MAX_NR_CLASS`, `MAX_TOTAL_SV`) and the
682    // per-call `LoadOptions` overrides. A caller using
683    // `LoadOptions::trusted_input()` relaxes only down to the hard caps;
684    // it cannot exceed them. This is defense in depth.
685    let nr_class_cap = options.max_nr_class.min(MAX_NR_CLASS);
686    let total_sv_cap = options.max_sv.min(MAX_TOTAL_SV);
687
688    // Defaults
689    let mut param = SvmParameter::default();
690    let mut nr_class: usize = 0;
691    let mut total_sv: usize = 0;
692    let mut rho: Vec<f64> = Vec::new();
693    let mut label: Vec<i32> = Vec::new();
694    let mut prob_a: Vec<f64> = Vec::new();
695    let mut prob_b: Vec<f64> = Vec::new();
696    let mut prob_density_marks: Vec<f64> = Vec::new();
697    let mut n_sv: Vec<usize> = Vec::new();
698
699    // Read header.
700    let mut line_num: usize = 0;
701    loop {
702        let raw = read_line_capped(
703            &mut reader,
704            &mut bytes_read,
705            options.max_bytes,
706            options.max_line_len,
707        )
708        .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
709        .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in header".into()))?;
710        line_num += 1;
711        let line = raw.trim().to_string();
712        if line.is_empty() {
713            continue;
714        }
715
716        let mut parts = line.split_whitespace();
717        let cmd = parts.next().ok_or_else(|| {
718            SvmError::ModelFormatError(format!("line {}: empty model header line", line_num))
719        })?;
720
721        match cmd {
722            "svm_type" => {
723                let val = parts.next().ok_or_else(|| {
724                    SvmError::ModelFormatError(format!("line {}: missing svm_type value", line_num))
725                })?;
726                param.svm_type = str_to_svm_type(val).ok_or_else(|| {
727                    SvmError::ModelFormatError(format!(
728                        "line {}: unknown svm_type: {}",
729                        line_num, val
730                    ))
731                })?;
732            }
733            "kernel_type" => {
734                let val = parts.next().ok_or_else(|| {
735                    SvmError::ModelFormatError(format!(
736                        "line {}: missing kernel_type value",
737                        line_num
738                    ))
739                })?;
740                param.kernel_type = str_to_kernel_type(val).ok_or_else(|| {
741                    SvmError::ModelFormatError(format!(
742                        "line {}: unknown kernel_type: {}",
743                        line_num, val
744                    ))
745                })?;
746            }
747            "degree" => {
748                let d: i32 = parse_single(&mut parts, line_num, "degree")?;
749                if d < 0 {
750                    return Err(SvmError::ModelFormatError(format!(
751                        "line {}: degree must be >= 0, got {}",
752                        line_num, d
753                    )));
754                }
755                param.degree = d;
756            }
757            "gamma" => {
758                let v: f64 = parse_single(&mut parts, line_num, "gamma")?;
759                if !v.is_finite() {
760                    return Err(SvmError::ModelFormatError(format!(
761                        "line {}: gamma must be finite, got {}",
762                        line_num, v
763                    )));
764                }
765                param.gamma = v;
766            }
767            "coef0" => {
768                let v: f64 = parse_single(&mut parts, line_num, "coef0")?;
769                if !v.is_finite() {
770                    return Err(SvmError::ModelFormatError(format!(
771                        "line {}: coef0 must be finite, got {}",
772                        line_num, v
773                    )));
774                }
775                param.coef0 = v;
776            }
777            "nr_class" => {
778                nr_class = parse_single(&mut parts, line_num, "nr_class")?;
779                if nr_class > nr_class_cap {
780                    return Err(SvmError::ModelFormatError(format!(
781                        "line {}: nr_class exceeds limit ({})",
782                        line_num, nr_class_cap
783                    )));
784                }
785            }
786            "total_sv" => {
787                total_sv = parse_single(&mut parts, line_num, "total_sv")?;
788                if total_sv > total_sv_cap {
789                    return Err(SvmError::ModelFormatError(format!(
790                        "line {}: total_sv exceeds limit ({})",
791                        line_num, total_sv_cap
792                    )));
793                }
794            }
795            "rho" => {
796                rho = parse_multiple(&mut parts, line_num, "rho")?;
797                for &r in &rho {
798                    if !r.is_finite() {
799                        return Err(SvmError::ModelFormatError(format!(
800                            "line {}: rho must be finite, got {}",
801                            line_num, r
802                        )));
803                    }
804                }
805            }
806            "label" => {
807                label = parse_multiple(&mut parts, line_num, "label")?;
808            }
809            "probA" => {
810                prob_a = parse_multiple(&mut parts, line_num, "probA")?;
811            }
812            "probB" => {
813                prob_b = parse_multiple(&mut parts, line_num, "probB")?;
814            }
815            "prob_density_marks" => {
816                prob_density_marks = parse_multiple(&mut parts, line_num, "prob_density_marks")?;
817            }
818            "nr_sv" => {
819                n_sv = parts
820                    .map(|s| {
821                        s.parse::<usize>().map_err(|_| {
822                            SvmError::ModelFormatError(format!(
823                                "line {}: invalid nr_sv value: {}",
824                                line_num, s
825                            ))
826                        })
827                    })
828                    .collect::<Result<Vec<_>, _>>()?;
829            }
830            "SV" => break,
831            _ => {
832                return Err(SvmError::ModelFormatError(format!(
833                    "line {}: unknown keyword: {}",
834                    line_num, cmd
835                )));
836            }
837        }
838    }
839
840    // Cross-consistency checks on the header.
841    //
842    // These run before any per-SV allocation so malformed files are rejected
843    // early, and so downstream code can rely on structural invariants (e.g.
844    // `rho.len() == nr_class * (nr_class - 1) / 2` for multiclass).
845    validate_model_header(
846        param.svm_type,
847        nr_class,
848        total_sv,
849        &rho,
850        &label,
851        &prob_a,
852        &prob_b,
853        &prob_density_marks,
854        &n_sv,
855    )?;
856
857    // Read SV section.
858    //
859    // SECURITY: we do NOT preallocate with `total_sv` capacity. A malicious
860    // file could claim up to `MAX_TOTAL_SV` support vectors in its header,
861    // which would trigger terabyte-scale reservations before any real data is
862    // read. Amortized `Vec` growth caps peak memory at the actually-parsed
863    // payload.
864    let m = if nr_class > 1 { nr_class - 1 } else { 1 };
865    let mut sv_coef: Vec<Vec<f64>> = (0..m).map(|_| Vec::new()).collect();
866    let mut sv: Vec<Vec<SvmNode>> = Vec::new();
867
868    while sv.len() < total_sv {
869        let raw = read_line_capped(
870            &mut reader,
871            &mut bytes_read,
872            options.max_bytes,
873            options.max_line_len,
874        )
875        .map_err(|e| SvmError::ModelFormatError(e.to_string()))?
876        .ok_or_else(|| SvmError::ModelFormatError("unexpected end of file in SV section".into()))?;
877        line_num += 1;
878        let line = raw.trim();
879        if line.is_empty() {
880            // Skip blank lines without consuming an SV slot — the header's
881            // `total_sv` must match the number of SV rows we actually collect.
882            continue;
883        }
884
885        let mut parts = line.split_whitespace();
886
887        // First m tokens are sv_coef values
888        for (k, coef_row) in sv_coef.iter_mut().enumerate() {
889            let val_str = parts.next().ok_or_else(|| {
890                SvmError::ModelFormatError(format!("line {}: missing sv_coef[{}]", line_num, k))
891            })?;
892            let val: f64 = val_str.parse().map_err(|_| {
893                SvmError::ModelFormatError(format!(
894                    "line {}: invalid sv_coef: {}",
895                    line_num, val_str
896                ))
897            })?;
898            if !val.is_finite() {
899                return Err(SvmError::ModelFormatError(format!(
900                    "line {}: sv_coef must be finite, got {}",
901                    line_num, val_str
902                )));
903            }
904            coef_row.push(val);
905        }
906
907        // Remaining tokens are index:value pairs (ascending index order, same
908        // invariant as the problem-file parser).
909        let mut nodes = Vec::new();
910        let mut prev_index: i32 = 0;
911        for token in parts {
912            let (idx_str, val_str) = token.split_once(':').ok_or_else(|| {
913                SvmError::ModelFormatError(format!(
914                    "line {}: expected index:value, got: {}",
915                    line_num, token
916                ))
917            })?;
918            let index: i32 =
919                parse_feature_index_model_line(line_num, idx_str, options.max_feature_index)?;
920
921            if !nodes.is_empty() && index <= prev_index {
922                return Err(SvmError::ModelFormatError(format!(
923                    "line {}: feature indices must be ascending: {} follows {}",
924                    line_num, index, prev_index
925                )));
926            }
927
928            let value: f64 = val_str.parse().map_err(|_| {
929                SvmError::ModelFormatError(format!("line {}: invalid value: {}", line_num, val_str))
930            })?;
931            if !value.is_finite() {
932                return Err(SvmError::ModelFormatError(format!(
933                    "line {}: feature value must be finite, got {}",
934                    line_num, val_str
935                )));
936            }
937            prev_index = index;
938            nodes.push(SvmNode { index, value });
939        }
940
941        if param.kernel_type == KernelType::Precomputed {
942            validate_precomputed_row(&nodes, line_num, "support vector")?;
943        }
944        sv.push(nodes);
945    }
946
947    let model = SvmModel {
948        param,
949        nr_class,
950        sv,
951        sv_coef,
952        rho,
953        prob_a,
954        prob_b,
955        prob_density_marks,
956        sv_indices: Vec::new(), // not stored in model file
957        label,
958        n_sv,
959    };
960    validate_model(&model)?;
961    Ok(model)
962}
963
964/// Validate the shared structural invariants required of a loaded/deserialized model.
965pub(crate) fn validate_model(model: &SvmModel) -> Result<(), SvmError> {
966    if model.param.degree < 0 {
967        return Err(SvmError::ModelFormatError(format!(
968            "degree must be >= 0, got {}",
969            model.param.degree
970        )));
971    }
972    if !model.param.gamma.is_finite() {
973        return Err(SvmError::ModelFormatError(format!(
974            "gamma must be finite, got {}",
975            model.param.gamma
976        )));
977    }
978    if matches!(
979        model.param.kernel_type,
980        KernelType::Polynomial | KernelType::Rbf | KernelType::Sigmoid
981    ) && model.param.gamma < 0.0
982    {
983        return Err(SvmError::ModelFormatError(format!(
984            "gamma must be >= 0, got {}",
985            model.param.gamma
986        )));
987    }
988    if !model.param.coef0.is_finite() {
989        return Err(SvmError::ModelFormatError(format!(
990            "coef0 must be finite, got {}",
991            model.param.coef0
992        )));
993    }
994
995    validate_model_header(
996        model.param.svm_type,
997        model.nr_class,
998        model.sv.len(),
999        &model.rho,
1000        &model.label,
1001        &model.prob_a,
1002        &model.prob_b,
1003        &model.prob_density_marks,
1004        &model.n_sv,
1005    )?;
1006
1007    let expected_rows = model.nr_class.saturating_sub(1).max(1);
1008    if model.sv_coef.len() != expected_rows {
1009        return Err(SvmError::ModelFormatError(format!(
1010            "sv_coef has {} rows, expected {}",
1011            model.sv_coef.len(),
1012            expected_rows
1013        )));
1014    }
1015    for (row_idx, row) in model.sv_coef.iter().enumerate() {
1016        if row.len() != model.sv.len() {
1017            return Err(SvmError::ModelFormatError(format!(
1018                "sv_coef row {} has {} entries, expected {}",
1019                row_idx,
1020                row.len(),
1021                model.sv.len()
1022            )));
1023        }
1024        for &coef in row {
1025            if !coef.is_finite() {
1026                return Err(SvmError::ModelFormatError(format!(
1027                    "sv_coef must be finite, got {}",
1028                    coef
1029                )));
1030            }
1031        }
1032    }
1033
1034    for &r in &model.rho {
1035        if !r.is_finite() {
1036            return Err(SvmError::ModelFormatError(format!(
1037                "rho must be finite, got {}",
1038                r
1039            )));
1040        }
1041    }
1042    for (name, values) in [("probA", &model.prob_a), ("probB", &model.prob_b)] {
1043        for &value in values {
1044            if !value.is_finite() {
1045                return Err(SvmError::ModelFormatError(format!(
1046                    "{} must be finite, got {}",
1047                    name, value
1048                )));
1049            }
1050        }
1051    }
1052    for &value in &model.prob_density_marks {
1053        if !value.is_finite() {
1054            return Err(SvmError::ModelFormatError(format!(
1055                "prob_density_marks must be finite, got {}",
1056                value
1057            )));
1058        }
1059    }
1060
1061    for (row_idx, nodes) in model.sv.iter().enumerate() {
1062        let mut prev_index = 0;
1063        for (node_idx, node) in nodes.iter().enumerate() {
1064            if node_idx > 0 && node.index <= prev_index {
1065                return Err(SvmError::ModelFormatError(format!(
1066                    "support vector {} feature indices must be ascending: {} follows {}",
1067                    row_idx + 1,
1068                    node.index,
1069                    prev_index
1070                )));
1071            }
1072            if !node.value.is_finite() {
1073                return Err(SvmError::ModelFormatError(format!(
1074                    "feature value must be finite, got {}",
1075                    node.value
1076                )));
1077            }
1078            prev_index = node.index;
1079        }
1080        if model.param.kernel_type == KernelType::Precomputed {
1081            validate_precomputed_row(nodes, row_idx + 1, "support vector")?;
1082        }
1083    }
1084
1085    Ok(())
1086}
1087
1088fn validate_precomputed_row(
1089    nodes: &[SvmNode],
1090    line_num: usize,
1091    context: &str,
1092) -> Result<(), SvmError> {
1093    let first = nodes.first().ok_or_else(|| {
1094        SvmError::ModelFormatError(format!(
1095            "line {}: precomputed kernel {} is missing 0:sample_serial_number",
1096            line_num, context
1097        ))
1098    })?;
1099
1100    if first.index != 0
1101        || !first.value.is_finite()
1102        || first.value < 1.0
1103        || first.value.fract() != 0.0
1104    {
1105        return Err(SvmError::ModelFormatError(format!(
1106            "line {}: precomputed kernel {} must start with 0:sample_serial_number",
1107            line_num, context
1108        )));
1109    }
1110
1111    Ok(())
1112}
1113
1114// ─── Cross-consistency validation ────────────────────────────────────
1115
1116/// Validate model-header invariants before reading the SV section.
1117///
1118/// A malformed or adversarial model file can pass individual field parses
1119/// and still describe a structurally impossible SVM. This gate rejects the
1120/// mismatch early, before any allocation keyed on `total_sv`, so downstream
1121/// code (prediction, probability estimation) can rely on the usual LIBSVM
1122/// shape contracts:
1123///
1124/// * `rho.len() == k * (k - 1) / 2` for `k = nr_class` on classification,
1125///   `rho.len() == 1` for one-class / regression (where `nr_class == 2`).
1126/// * `label.len() == nr_class` and `n_sv.len() == nr_class` if supplied.
1127/// * `sum(n_sv) == total_sv` if `n_sv` is supplied.
1128/// * `prob_a` / `prob_b` (if supplied) match the expected decision-function
1129///   count, and `prob_density_marks` only appears on one-class models.
1130///
1131/// Optional fields (e.g. `label`, `n_sv`, `probA`) are only validated when
1132/// present, because minimal hand-written fixtures and some legacy writers
1133/// omit them; the invariant "if present, must be consistent" is what matters
1134/// for safety.
1135#[allow(clippy::too_many_arguments)]
1136fn validate_model_header(
1137    svm_type: SvmType,
1138    nr_class: usize,
1139    total_sv: usize,
1140    rho: &[f64],
1141    label: &[i32],
1142    prob_a: &[f64],
1143    prob_b: &[f64],
1144    prob_density_marks: &[f64],
1145    n_sv: &[usize],
1146) -> Result<(), SvmError> {
1147    let is_classification = matches!(svm_type, SvmType::CSvc | SvmType::NuSvc);
1148    let is_regression = matches!(svm_type, SvmType::EpsilonSvr | SvmType::NuSvr);
1149    let is_one_class = matches!(svm_type, SvmType::OneClass);
1150
1151    // nr_class must be at least 2 under the LIBSVM convention (regression and
1152    // one-class store nr_class=2 as well, because the one-vs-one scaffolding
1153    // is reused). nr_class==0 or 1 would yield `m = nr_class - 1 = 0` or
1154    // underflow-prone arithmetic elsewhere.
1155    if nr_class < 2 {
1156        return Err(SvmError::ModelFormatError(format!(
1157            "nr_class must be >= 2, got {}",
1158            nr_class
1159        )));
1160    }
1161
1162    // Expected rho length depends on svm_type.
1163    let expected_rho = if is_classification {
1164        nr_class * (nr_class - 1) / 2
1165    } else {
1166        1
1167    };
1168    if rho.len() != expected_rho {
1169        return Err(SvmError::ModelFormatError(format!(
1170            "rho has {} entries, expected {} for svm_type {}",
1171            rho.len(),
1172            expected_rho,
1173            svm_type_to_str(svm_type)
1174        )));
1175    }
1176
1177    // label is mandatory for classification types; must equal nr_class when present.
1178    // For one_class / epsilon_svr / nu_svr, label is absent and empty is legitimate.
1179    if is_classification {
1180        // Require label to be present and correctly sized; an empty label vec on a
1181        // classification model means the file omitted the `label` line entirely, which
1182        // would leave predict.rs indexing `model.label[i]` out of bounds.
1183        if label.len() != nr_class {
1184            return Err(SvmError::ModelFormatError(format!(
1185                "label has {} entries, expected nr_class ({}) for svm_type {}",
1186                label.len(),
1187                nr_class,
1188                svm_type_to_str(svm_type)
1189            )));
1190        }
1191    } else if !label.is_empty() {
1192        return Err(SvmError::ModelFormatError(format!(
1193            "label is only valid for classification, got {} entries on svm_type {}",
1194            label.len(),
1195            svm_type_to_str(svm_type)
1196        )));
1197    }
1198
1199    // n_sv: mandatory for classification; absent on regression/one-class.
1200    // An empty n_sv on a classification model means the file omitted `nr_sv`,
1201    // which would leave predict.rs indexing `model.n_sv[i-1]` out of bounds.
1202    if is_classification {
1203        if n_sv.len() != nr_class {
1204            return Err(SvmError::ModelFormatError(format!(
1205                "nr_sv has {} entries, expected nr_class ({}) for svm_type {}",
1206                n_sv.len(),
1207                nr_class,
1208                svm_type_to_str(svm_type)
1209            )));
1210        }
1211        // Use checked_add to prevent silent overflow on malicious huge values.
1212        // MAX_TOTAL_SV bounds total_sv already; n_sv values are parsed as
1213        // `usize` and otherwise unbounded until this sum-check.
1214        let mut sum: usize = 0;
1215        for &n in n_sv {
1216            sum = sum.checked_add(n).ok_or_else(|| {
1217                SvmError::ModelFormatError("nr_sv entries overflow usize when summed".into())
1218            })?;
1219        }
1220        if sum != total_sv {
1221            return Err(SvmError::ModelFormatError(format!(
1222                "sum of nr_sv entries ({}) does not match total_sv ({})",
1223                sum, total_sv
1224            )));
1225        }
1226    } else if !n_sv.is_empty() {
1227        return Err(SvmError::ModelFormatError(format!(
1228            "nr_sv is only valid for classification, got {} entries on svm_type {}",
1229            n_sv.len(),
1230            svm_type_to_str(svm_type)
1231        )));
1232    }
1233
1234    // Probability arrays: must either be absent or length-match rho.
1235    if !prob_a.is_empty() && prob_a.len() != expected_rho {
1236        return Err(SvmError::ModelFormatError(format!(
1237            "probA has {} entries, expected {}",
1238            prob_a.len(),
1239            expected_rho
1240        )));
1241    }
1242    if !prob_b.is_empty() && prob_b.len() != expected_rho {
1243        return Err(SvmError::ModelFormatError(format!(
1244            "probB has {} entries, expected {}",
1245            prob_b.len(),
1246            expected_rho
1247        )));
1248    }
1249
1250    // prob_density_marks is only meaningful for one-class.
1251    if !prob_density_marks.is_empty() && !is_one_class {
1252        return Err(SvmError::ModelFormatError(format!(
1253            "prob_density_marks is only valid for one-class SVM, got {} entries on svm_type {}",
1254            prob_density_marks.len(),
1255            svm_type_to_str(svm_type)
1256        )));
1257    }
1258
1259    // Regression/one-class should not carry classification-only artifacts.
1260    // (Already caught above via `label` / `n_sv` branches; this assertion
1261    // keeps the intent self-documenting for future maintainers.)
1262    let _ = is_regression;
1263
1264    Ok(())
1265}
1266
1267// ─── Helper parsers ──────────────────────────────────────────────────
1268
1269fn parse_feature_index_problem_line(
1270    line_num: usize,
1271    idx_str: &str,
1272    max_feature_index: i32,
1273) -> Result<i32, SvmError> {
1274    parse_feature_index(idx_str, max_feature_index).map_err(|msg| SvmError::ParseError {
1275        line: line_num,
1276        message: msg,
1277    })
1278}
1279
1280fn parse_feature_index_model_line(
1281    line_num: usize,
1282    idx_str: &str,
1283    max_feature_index: i32,
1284) -> Result<i32, SvmError> {
1285    parse_feature_index(idx_str, max_feature_index)
1286        .map_err(|msg| SvmError::ModelFormatError(format!("line {}: {}", line_num, msg)))
1287}
1288
1289fn parse_single<T: std::str::FromStr>(
1290    parts: &mut std::str::SplitWhitespace<'_>,
1291    line_num: usize,
1292    field: &str,
1293) -> Result<T, SvmError> {
1294    let val_str = parts.next().ok_or_else(|| {
1295        SvmError::ModelFormatError(format!("line {}: missing {} value", line_num, field))
1296    })?;
1297    val_str.parse().map_err(|_| {
1298        SvmError::ModelFormatError(format!(
1299            "line {}: invalid {} value: {}",
1300            line_num, field, val_str
1301        ))
1302    })
1303}
1304
1305fn parse_multiple<T: std::str::FromStr>(
1306    parts: &mut std::str::SplitWhitespace<'_>,
1307    line_num: usize,
1308    field: &str,
1309) -> Result<Vec<T>, SvmError> {
1310    parts
1311        .map(|s| {
1312            s.parse::<T>().map_err(|_| {
1313                SvmError::ModelFormatError(format!(
1314                    "line {}: invalid {} value: {}",
1315                    line_num, field, s
1316                ))
1317            })
1318        })
1319        .collect()
1320}
1321
1322// ─── Tests ───────────────────────────────────────────────────────────
1323
1324#[cfg(test)]
1325mod tests {
1326    use super::*;
1327    use std::path::PathBuf;
1328
1329    fn data_dir() -> PathBuf {
1330        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1331            .join("..")
1332            .join("..")
1333            .join("data")
1334    }
1335
1336    #[test]
1337    fn parse_heart_scale() {
1338        let path = data_dir().join("heart_scale");
1339        let problem = load_problem(&path).unwrap();
1340        assert_eq!(problem.labels.len(), 270);
1341        assert_eq!(problem.instances.len(), 270);
1342        // First instance: +1 label, 12 features (index 11 is missing/sparse)
1343        assert_eq!(problem.labels[0], 1.0);
1344        assert_eq!(
1345            problem.instances[0][0],
1346            SvmNode {
1347                index: 1,
1348                value: 0.708333
1349            }
1350        );
1351        assert_eq!(problem.instances[0].len(), 12);
1352    }
1353
1354    #[test]
1355    fn parse_iris() {
1356        let path = data_dir().join("iris.scale");
1357        let problem = load_problem(&path).unwrap();
1358        assert_eq!(problem.labels.len(), 150);
1359        // 3 classes: 1, 2, 3
1360        let classes: std::collections::HashSet<i64> =
1361            problem.labels.iter().map(|&l| l as i64).collect();
1362        assert_eq!(classes.len(), 3);
1363    }
1364
1365    #[test]
1366    fn parse_housing() {
1367        let path = data_dir().join("housing_scale");
1368        let problem = load_problem(&path).unwrap();
1369        assert_eq!(problem.labels.len(), 506);
1370        // Regression: labels are continuous
1371        assert!((problem.labels[0] - 24.0).abs() < 1e-10);
1372    }
1373
1374    #[test]
1375    fn parse_empty_lines() {
1376        let input = b"+1 1:0.5\n\n-1 2:0.3\n";
1377        let problem = load_problem_from_reader(&input[..]).unwrap();
1378        assert_eq!(problem.labels.len(), 2);
1379    }
1380
1381    #[test]
1382    fn parse_error_unsorted_indices() {
1383        let input = b"+1 3:0.5 1:0.3\n";
1384        let result = load_problem_from_reader(&input[..]);
1385        assert!(result.is_err());
1386        let msg = format!("{}", result.unwrap_err());
1387        assert!(msg.contains("ascending"), "error: {}", msg);
1388    }
1389
1390    #[test]
1391    fn parse_error_duplicate_indices() {
1392        let input = b"+1 1:0.5 1:0.3\n";
1393        let result = load_problem_from_reader(&input[..]);
1394        assert!(result.is_err());
1395    }
1396
1397    #[test]
1398    fn parse_error_missing_colon() {
1399        let input = b"+1 1:0.5 bad_token\n";
1400        let result = load_problem_from_reader(&input[..]);
1401        assert!(result.is_err());
1402    }
1403
1404    #[test]
1405    #[allow(clippy::excessive_precision)]
1406    fn load_c_trained_model() {
1407        // Load a model produced by the original C LIBSVM svm-train
1408        let path = data_dir().join("heart_scale.model");
1409        let model = load_model(&path).unwrap();
1410        assert_eq!(model.nr_class, 2);
1411        assert_eq!(model.param.svm_type, SvmType::CSvc);
1412        assert_eq!(model.param.kernel_type, KernelType::Rbf);
1413        assert!((model.param.gamma - 0.076923076923076927).abs() < 1e-15);
1414        assert_eq!(model.sv.len(), 132);
1415        assert_eq!(model.label, vec![1, -1]);
1416        assert_eq!(model.n_sv, vec![64, 68]);
1417        assert!((model.rho[0] - 0.42446205176771573).abs() < 1e-15);
1418        // sv_coef should have 1 row (nr_class - 1) with 132 entries
1419        assert_eq!(model.sv_coef.len(), 1);
1420        assert_eq!(model.sv_coef[0].len(), 132);
1421    }
1422
1423    #[test]
1424    fn roundtrip_c_model() {
1425        // Load C model, save it back, and verify byte-exact match
1426        let path = data_dir().join("heart_scale.model");
1427        let original_bytes = std::fs::read_to_string(&path).unwrap();
1428        let model = load_model(&path).unwrap();
1429
1430        let mut buf = Vec::new();
1431        save_model_to_writer(&mut buf, &model).unwrap();
1432        let rust_output = String::from_utf8(buf).unwrap();
1433
1434        // Compare line by line for better diagnostics
1435        let orig_lines: Vec<&str> = original_bytes.lines().collect();
1436        let rust_lines: Vec<&str> = rust_output.lines().collect();
1437        assert_eq!(
1438            orig_lines.len(),
1439            rust_lines.len(),
1440            "line count mismatch: C={} Rust={}",
1441            orig_lines.len(),
1442            rust_lines.len()
1443        );
1444        for (i, (o, r)) in orig_lines.iter().zip(rust_lines.iter()).enumerate() {
1445            assert_eq!(
1446                o,
1447                r,
1448                "line {} differs:\n  C:    {:?}\n  Rust: {:?}",
1449                i + 1,
1450                o,
1451                r
1452            );
1453        }
1454    }
1455
1456    #[test]
1457    #[allow(clippy::excessive_precision)]
1458    fn gfmt_matches_c_printf() {
1459        // Reference values from C's printf("%.17g|%.8g\n", v, v)
1460        let cases: &[(f64, &str, &str)] = &[
1461            (0.5, "0.5", "0.5"),
1462            (-1.0, "-1", "-1"),
1463            (0.123456789012345, "0.123456789012345", "0.12345679"),
1464            (-0.987654321098765, "-0.98765432109876505", "-0.98765432"),
1465            (0.42446200000000001, "0.42446200000000001", "0.424462"),
1466            (0.0, "0", "0"),
1467            (1e-5, "1.0000000000000001e-05", "1e-05"),
1468            (1e-4, "0.0001", "0.0001"),
1469            (1e20, "1e+20", "1e+20"),
1470            (-0.25, "-0.25", "-0.25"),
1471            (0.75, "0.75", "0.75"),
1472            (0.708333, "0.70833299999999999", "0.708333"),
1473            (1.0, "1", "1"),
1474        ];
1475        for &(v, expected_17g, expected_8g) in cases {
1476            let got_17 = format!("{}", fmt_17g(v));
1477            let got_8 = format!("{}", fmt_8g(v));
1478            assert_eq!(got_17, expected_17g, "%.17g mismatch for {}", v);
1479            assert_eq!(got_8, expected_8g, "%.8g mismatch for {}", v);
1480        }
1481    }
1482
1483    #[test]
1484    #[allow(clippy::excessive_precision)]
1485    fn model_roundtrip() {
1486        // Create a minimal model and verify save → load roundtrip
1487        let model = SvmModel {
1488            param: SvmParameter {
1489                svm_type: SvmType::CSvc,
1490                kernel_type: KernelType::Rbf,
1491                gamma: 0.5,
1492                ..Default::default()
1493            },
1494            nr_class: 2,
1495            sv: vec![
1496                vec![
1497                    SvmNode {
1498                        index: 1,
1499                        value: 0.5,
1500                    },
1501                    SvmNode {
1502                        index: 3,
1503                        value: -1.0,
1504                    },
1505                ],
1506                vec![
1507                    SvmNode {
1508                        index: 1,
1509                        value: -0.25,
1510                    },
1511                    SvmNode {
1512                        index: 2,
1513                        value: 0.75,
1514                    },
1515                ],
1516            ],
1517            sv_coef: vec![vec![0.123456789012345, -0.987654321098765]],
1518            rho: vec![0.42446200000000001],
1519            prob_a: vec![],
1520            prob_b: vec![],
1521            prob_density_marks: vec![],
1522            sv_indices: vec![],
1523            label: vec![1, -1],
1524            n_sv: vec![1, 1],
1525        };
1526
1527        let mut buf = Vec::new();
1528        save_model_to_writer(&mut buf, &model).unwrap();
1529
1530        let loaded = load_model_from_reader(&buf[..]).unwrap();
1531
1532        assert_eq!(loaded.nr_class, model.nr_class);
1533        assert_eq!(loaded.param.svm_type, model.param.svm_type);
1534        assert_eq!(loaded.param.kernel_type, model.param.kernel_type);
1535        assert_eq!(loaded.sv.len(), model.sv.len());
1536        assert_eq!(loaded.label, model.label);
1537        assert_eq!(loaded.n_sv, model.n_sv);
1538        assert_eq!(loaded.rho.len(), model.rho.len());
1539        // Check rho within tolerance (roundtrip through text)
1540        for (a, b) in loaded.rho.iter().zip(model.rho.iter()) {
1541            assert!((a - b).abs() < 1e-10, "rho mismatch: {} vs {}", a, b);
1542        }
1543        // Check sv_coef within tolerance
1544        for (row_a, row_b) in loaded.sv_coef.iter().zip(model.sv_coef.iter()) {
1545            for (a, b) in row_a.iter().zip(row_b.iter()) {
1546                assert!((a - b).abs() < 1e-10, "sv_coef mismatch: {} vs {}", a, b);
1547            }
1548        }
1549    }
1550
1551    #[test]
1552    fn parse_error_excessive_counts() {
1553        let input =
1554            b"svm_type c_svc\nkernel_type linear\nnr_class 1000000\ntotal_sv 100\nrho 0\nSV\n";
1555        let result = load_model_from_reader(&input[..]);
1556        assert!(result.is_err());
1557        assert!(format!("{}", result.unwrap_err()).contains("nr_class exceeds limit"));
1558
1559        let input =
1560            b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 100000000\nrho 0\nSV\n";
1561        let result = load_model_from_reader(&input[..]);
1562        assert!(result.is_err());
1563        assert!(format!("{}", result.unwrap_err()).contains("total_sv exceeds limit"));
1564    }
1565
1566    #[test]
1567    fn parse_error_excessive_feature_index() {
1568        // Problem file
1569        let input = b"1 10000001:1\n";
1570        let result = load_problem_from_reader(&input[..]);
1571        assert!(result.is_err());
1572        assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1573
1574        // Model file
1575        let input = b"svm_type c_svc\nkernel_type linear\nnr_class 2\ntotal_sv 1\nrho 0\nlabel 1 -1\nnr_sv 1 0\nSV\n0.1 10000001:1\n";
1576        let result = load_model_from_reader(&input[..]);
1577        assert!(result.is_err());
1578        assert!(format!("{}", result.unwrap_err()).contains("feature index 10000001 exceeds limit"));
1579    }
1580
1581    #[test]
1582    fn parse_error_unknown_model_keyword() {
1583        let input = b"bad_key value\n";
1584        let result = load_model_from_reader(&input[..]);
1585        assert!(result.is_err());
1586        assert!(format!("{}", result.unwrap_err()).contains("unknown keyword"));
1587    }
1588
1589    #[test]
1590    fn parse_error_missing_or_unknown_model_values() {
1591        let missing = b"svm_type\n";
1592        let err = load_model_from_reader(&missing[..]).unwrap_err();
1593        assert!(format!("{}", err).contains("missing svm_type value"));
1594
1595        let unknown = b"svm_type unknown_type\n";
1596        let err = load_model_from_reader(&unknown[..]).unwrap_err();
1597        assert!(format!("{}", err).contains("unknown svm_type"));
1598    }
1599
1600    #[test]
1601    fn parse_error_invalid_nr_sv_entry() {
1602        let input = b"svm_type c_svc\n\
1603kernel_type linear\n\
1604nr_class 2\n\
1605total_sv 1\n\
1606rho 0\n\
1607label 1 -1\n\
1608nr_sv a 1\n\
1609SV\n\
16100.1 1:0.5\n";
1611        let err = load_model_from_reader(&input[..]).unwrap_err();
1612        assert!(format!("{}", err).contains("invalid nr_sv value"));
1613    }
1614
1615    #[test]
1616    fn parse_error_in_sv_section_tokens() {
1617        let missing_coef = b"svm_type c_svc\n\
1618kernel_type linear\n\
1619nr_class 2\n\
1620total_sv 1\n\
1621rho 0\n\
1622label 1 -1\n\
1623nr_sv 1 0\n\
1624SV\n\
16251:0.5\n";
1626        let err = load_model_from_reader(&missing_coef[..]).unwrap_err();
1627        assert!(format!("{}", err).contains("invalid sv_coef"));
1628
1629        let bad_feature = b"svm_type c_svc\n\
1630kernel_type linear\n\
1631nr_class 2\n\
1632total_sv 1\n\
1633rho 0\n\
1634label 1 -1\n\
1635nr_sv 1 0\n\
1636SV\n\
16370.1 bad\n";
1638        let err = load_model_from_reader(&bad_feature[..]).unwrap_err();
1639        assert!(format!("{}", err).contains("expected index:value"));
1640    }
1641
1642    #[test]
1643    fn parse_error_unexpected_eof_in_header_and_sv_section() {
1644        let eof_header = b"svm_type c_svc\n";
1645        let err = load_model_from_reader(&eof_header[..]).unwrap_err();
1646        assert!(format!("{}", err).contains("unexpected end of file in header"));
1647
1648        let eof_sv = b"svm_type c_svc\n\
1649kernel_type linear\n\
1650nr_class 2\n\
1651total_sv 2\n\
1652rho 0\n\
1653label 1 -1\n\
1654nr_sv 1 1\n\
1655SV\n\
16560.1 1:0.5\n";
1657        let err = load_model_from_reader(&eof_sv[..]).unwrap_err();
1658        assert!(format!("{}", err).contains("unexpected end of file in SV section"));
1659    }
1660
1661    #[test]
1662    fn reject_rho_length_mismatch_for_classification() {
1663        // CSvc with nr_class=3 expects rho.len() == 3. Supplying 1 entry
1664        // reveals either a malformed file or an intentional substitution.
1665        let input = b"svm_type c_svc\n\
1666kernel_type linear\n\
1667nr_class 3\n\
1668total_sv 3\n\
1669rho 0\n\
1670label 1 -1 0\n\
1671nr_sv 1 1 1\n\
1672SV\n";
1673        let err = load_model_from_reader(&input[..]).unwrap_err();
1674        assert!(
1675            format!("{}", err).contains("rho has 1 entries, expected 3"),
1676            "unexpected error: {}",
1677            err
1678        );
1679    }
1680
1681    #[test]
1682    fn reject_rho_length_mismatch_for_regression() {
1683        // SVR expects exactly 1 rho entry; 2 entries is inconsistent.
1684        let input = b"svm_type epsilon_svr\n\
1685kernel_type linear\n\
1686nr_class 2\n\
1687total_sv 0\n\
1688rho 0 1\n\
1689SV\n";
1690        let err = load_model_from_reader(&input[..]).unwrap_err();
1691        assert!(
1692            format!("{}", err).contains("rho has 2 entries, expected 1"),
1693            "unexpected error: {}",
1694            err
1695        );
1696    }
1697
1698    #[test]
1699    fn reject_label_on_regression() {
1700        // label is classification-only; carrying it on SVR is a red flag.
1701        let input = b"svm_type epsilon_svr\n\
1702kernel_type linear\n\
1703nr_class 2\n\
1704total_sv 0\n\
1705rho 0\n\
1706label 1 -1\n\
1707SV\n";
1708        let err = load_model_from_reader(&input[..]).unwrap_err();
1709        assert!(
1710            format!("{}", err).contains("label is only valid for classification"),
1711            "unexpected error: {}",
1712            err
1713        );
1714    }
1715
1716    #[test]
1717    fn reject_label_length_mismatch() {
1718        let input = b"svm_type c_svc\n\
1719kernel_type linear\n\
1720nr_class 3\n\
1721total_sv 0\n\
1722rho 0 0 0\n\
1723label 1 -1\n\
1724SV\n";
1725        let err = load_model_from_reader(&input[..]).unwrap_err();
1726        assert!(
1727            format!("{}", err).contains("label has 2 entries, expected nr_class (3)"),
1728            "unexpected error: {}",
1729            err
1730        );
1731    }
1732
1733    #[test]
1734    fn reject_nr_sv_sum_mismatch() {
1735        // total_sv=5 but nr_sv sums to 3. An inconsistent header like this
1736        // could previously pass and leave downstream code with stale assumptions.
1737        let input = b"svm_type c_svc\n\
1738kernel_type linear\n\
1739nr_class 2\n\
1740total_sv 5\n\
1741rho 0\n\
1742label 1 -1\n\
1743nr_sv 1 2\n\
1744SV\n";
1745        let err = load_model_from_reader(&input[..]).unwrap_err();
1746        assert!(
1747            format!("{}", err).contains("sum of nr_sv entries (3) does not match total_sv (5)"),
1748            "unexpected error: {}",
1749            err
1750        );
1751    }
1752
1753    #[test]
1754    fn reject_nr_sv_length_mismatch() {
1755        let input = b"svm_type c_svc\n\
1756kernel_type linear\n\
1757nr_class 3\n\
1758total_sv 3\n\
1759rho 0 0 0\n\
1760label 1 -1 0\n\
1761nr_sv 1 2\n\
1762SV\n";
1763        let err = load_model_from_reader(&input[..]).unwrap_err();
1764        assert!(
1765            format!("{}", err).contains("nr_sv has 2 entries, expected nr_class (3)"),
1766            "unexpected error: {}",
1767            err
1768        );
1769    }
1770
1771    #[test]
1772    fn reject_proba_length_mismatch() {
1773        let input = b"svm_type c_svc\n\
1774kernel_type linear\n\
1775nr_class 3\n\
1776total_sv 0\n\
1777rho 0 0 0\n\
1778label 1 -1 0\n\
1779nr_sv 0 0 0\n\
1780probA 0.1 0.2\n\
1781SV\n";
1782        let err = load_model_from_reader(&input[..]).unwrap_err();
1783        assert!(
1784            format!("{}", err).contains("probA has 2 entries, expected 3"),
1785            "unexpected error: {}",
1786            err
1787        );
1788    }
1789
1790    #[test]
1791    fn reject_prob_density_marks_on_csvc() {
1792        let input = b"svm_type c_svc\n\
1793kernel_type linear\n\
1794nr_class 2\n\
1795total_sv 0\n\
1796rho 0\n\
1797label 1 -1\n\
1798nr_sv 0 0\n\
1799prob_density_marks 0.1 0.2\n\
1800SV\n";
1801        let err = load_model_from_reader(&input[..]).unwrap_err();
1802        assert!(
1803            format!("{}", err).contains("prob_density_marks is only valid for one-class SVM"),
1804            "unexpected error: {}",
1805            err
1806        );
1807    }
1808
1809    #[test]
1810    fn reject_nr_class_below_two() {
1811        let input = b"svm_type c_svc\n\
1812kernel_type linear\n\
1813nr_class 1\n\
1814total_sv 0\n\
1815rho\n\
1816SV\n";
1817        let err = load_model_from_reader(&input[..]).unwrap_err();
1818        assert!(
1819            format!("{}", err).contains("nr_class must be >= 2, got 1"),
1820            "unexpected error: {}",
1821            err
1822        );
1823    }
1824
1825    #[test]
1826    fn reject_sv_feature_indices_not_ascending() {
1827        let input = b"svm_type c_svc\n\
1828kernel_type linear\n\
1829nr_class 2\n\
1830total_sv 1\n\
1831rho 0\n\
1832label 1 -1\n\
1833nr_sv 1 0\n\
1834SV\n\
18350.1 3:0.5 1:0.3\n";
1836        let err = load_model_from_reader(&input[..]).unwrap_err();
1837        assert!(
1838            format!("{}", err).contains("feature indices must be ascending"),
1839            "unexpected error: {}",
1840            err
1841        );
1842    }
1843
1844    #[test]
1845    fn reject_precomputed_model_sv_without_sample_serial_number() {
1846        let input = b"svm_type c_svc\n\
1847kernel_type precomputed\n\
1848nr_class 2\n\
1849total_sv 1\n\
1850rho 0\n\
1851label 1 -1\n\
1852nr_sv 1 0\n\
1853SV\n\
18540.1\n";
1855        let err = load_model_from_reader(&input[..]).unwrap_err();
1856        assert!(
1857            format!("{}", err).contains("missing 0:sample_serial_number"),
1858            "unexpected error: {}",
1859            err
1860        );
1861    }
1862
1863    #[test]
1864    fn reject_sv_feature_index_duplicated() {
1865        let input = b"svm_type c_svc\n\
1866kernel_type linear\n\
1867nr_class 2\n\
1868total_sv 1\n\
1869rho 0\n\
1870label 1 -1\n\
1871nr_sv 1 0\n\
1872SV\n\
18730.1 1:0.5 1:0.3\n";
1874        let err = load_model_from_reader(&input[..]).unwrap_err();
1875        assert!(
1876            format!("{}", err).contains("feature indices must be ascending"),
1877            "unexpected error: {}",
1878            err
1879        );
1880    }
1881
1882    #[test]
1883    fn load_options_default_caps_match_documented_values() {
1884        let opts = LoadOptions::default();
1885        assert_eq!(opts.max_bytes, 64 * 1024 * 1024);
1886        assert_eq!(opts.max_line_len, 1024 * 1024);
1887        assert_eq!(opts.max_sv, MAX_TOTAL_SV);
1888        assert_eq!(opts.max_nr_class, MAX_NR_CLASS);
1889        assert_eq!(opts.max_feature_index, MAX_FEATURE_INDEX);
1890    }
1891
1892    #[test]
1893    fn load_options_trusted_input_sets_type_maxes() {
1894        let opts = LoadOptions::trusted_input();
1895        assert_eq!(opts.max_bytes, u64::MAX);
1896        assert_eq!(opts.max_line_len, usize::MAX);
1897        assert_eq!(opts.max_sv, usize::MAX);
1898        assert_eq!(opts.max_nr_class, usize::MAX);
1899        assert_eq!(opts.max_feature_index, i32::MAX);
1900    }
1901
1902    #[test]
1903    fn problem_reader_rejects_file_over_max_bytes() {
1904        // 20 bytes of content, cap at 10.
1905        let input = b"+1 1:0.5\n+1 2:0.5\n";
1906        let opts = LoadOptions {
1907            max_bytes: 10,
1908            ..LoadOptions::default()
1909        };
1910        let err = load_problem_from_reader_with_options(&input[..], &opts).unwrap_err();
1911        assert!(
1912            format!("{}", err).contains("max_bytes"),
1913            "unexpected error: {}",
1914            err
1915        );
1916    }
1917
1918    #[test]
1919    fn problem_reader_rejects_line_over_max_line_len() {
1920        // Build a line of 200 chars; cap max_line_len at 50.
1921        let mut payload = String::from("+1 ");
1922        for i in 1..=50 {
1923            payload.push_str(&format!("{}:0.1 ", i));
1924        }
1925        payload.push('\n');
1926        let opts = LoadOptions {
1927            max_line_len: 50,
1928            ..LoadOptions::default()
1929        };
1930        let err = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap_err();
1931        assert!(
1932            format!("{}", err).contains("max_line_len"),
1933            "unexpected error: {}",
1934            err
1935        );
1936    }
1937
1938    #[test]
1939    fn problem_reader_accepts_line_at_max_line_len() {
1940        // A line whose content (excluding trailing newline) is exactly
1941        // max_line_len bytes must be accepted — the cap is inclusive.
1942        let line_content = "+1 1:0.5";
1943        let payload = format!("{}\n", line_content);
1944        let opts = LoadOptions {
1945            max_line_len: line_content.len(),
1946            ..LoadOptions::default()
1947        };
1948        let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1949        assert_eq!(problem.labels.len(), 1);
1950    }
1951
1952    #[test]
1953    fn problem_reader_tolerates_crlf_at_cap() {
1954        // Same content length, but with \r\n. Total bytes on disk is
1955        // content_len + 2; the helper allows one byte of slack for the \r.
1956        let line_content = "+1 1:0.5";
1957        let payload = format!("{}\r\n", line_content);
1958        let opts = LoadOptions {
1959            max_line_len: line_content.len(),
1960            ..LoadOptions::default()
1961        };
1962        let problem = load_problem_from_reader_with_options(payload.as_bytes(), &opts).unwrap();
1963        assert_eq!(problem.labels.len(), 1);
1964    }
1965
1966    #[test]
1967    fn problem_reader_rejects_nul_byte() {
1968        // A stray NUL byte inside a line is rejected early — LIBSVM text
1969        // files are ASCII, and a NUL is most likely truncated binary data.
1970        let mut payload: Vec<u8> = b"+1 1:0.5".to_vec();
1971        payload.push(0);
1972        payload.extend_from_slice(b"\n");
1973        let err = load_problem_from_reader(payload.as_slice()).unwrap_err();
1974        assert!(
1975            format!("{}", err).contains("NUL byte"),
1976            "unexpected error: {}",
1977            err
1978        );
1979    }
1980
1981    #[test]
1982    fn model_reader_honors_max_nr_class_cap() {
1983        // 100 is below both MAX_NR_CLASS and the 50 cap; the 50 cap should
1984        // fire first.
1985        let input = b"svm_type c_svc\n\
1986kernel_type linear\n\
1987nr_class 100\n\
1988total_sv 1\n\
1989rho 0\n\
1990SV\n";
1991        let opts = LoadOptions {
1992            max_nr_class: 50,
1993            ..LoadOptions::default()
1994        };
1995        let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
1996        assert!(
1997            format!("{}", err).contains("nr_class exceeds limit (50)"),
1998            "unexpected error: {}",
1999            err
2000        );
2001    }
2002
2003    #[test]
2004    fn model_reader_honors_max_sv_cap() {
2005        let input = b"svm_type c_svc\n\
2006kernel_type linear\n\
2007nr_class 2\n\
2008total_sv 1000\n\
2009rho 0\n\
2010SV\n";
2011        let opts = LoadOptions {
2012            max_sv: 100,
2013            ..LoadOptions::default()
2014        };
2015        let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
2016        assert!(
2017            format!("{}", err).contains("total_sv exceeds limit (100)"),
2018            "unexpected error: {}",
2019            err
2020        );
2021    }
2022
2023    #[test]
2024    fn trusted_input_cannot_exceed_hard_module_caps() {
2025        // `LoadOptions::trusted_input()` sets caps to usize::MAX / u64::MAX,
2026        // but the module-level hard caps (MAX_NR_CLASS, MAX_TOTAL_SV) still
2027        // apply as an upper bound. This is defense-in-depth.
2028        let huge_nr_class = format!(
2029            "svm_type c_svc\n\
2030kernel_type linear\n\
2031nr_class {}\n\
2032total_sv 1\n\
2033rho 0\n\
2034SV\n",
2035            MAX_NR_CLASS + 1
2036        );
2037        let opts = LoadOptions::trusted_input();
2038        let err = load_model_from_reader_with_options(huge_nr_class.as_bytes(), &opts).unwrap_err();
2039        assert!(
2040            format!("{}", err).contains("nr_class exceeds limit"),
2041            "unexpected error: {}",
2042            err
2043        );
2044    }
2045
2046    #[test]
2047    fn model_reader_honors_max_feature_index_cap() {
2048        let input = b"svm_type c_svc\n\
2049kernel_type linear\n\
2050nr_class 2\n\
2051total_sv 1\n\
2052rho 0\n\
2053label 1 -1\n\
2054nr_sv 1 0\n\
2055SV\n\
20560.1 50:0.5\n";
2057        let opts = LoadOptions {
2058            max_feature_index: 10,
2059            ..LoadOptions::default()
2060        };
2061        let err = load_model_from_reader_with_options(&input[..], &opts).unwrap_err();
2062        assert!(
2063            format!("{}", err).contains("feature index 50 exceeds limit (10)"),
2064            "unexpected error: {}",
2065            err
2066        );
2067    }
2068
2069    #[test]
2070    fn sv_count_loop_counts_nonblank_lines_only() {
2071        // A blank line inside the SV section must NOT be billed against
2072        // total_sv — parsing must terminate only after `total_sv` real SV
2073        // rows have been collected. This guards the loop rewrite from
2074        // `for _ in 0..total_sv { if empty { continue } ... }`, which would
2075        // silently accept a model with fewer SVs than the header claims.
2076        let input = b"svm_type c_svc\n\
2077kernel_type linear\n\
2078nr_class 2\n\
2079total_sv 2\n\
2080rho 0\n\
2081label 1 -1\n\
2082nr_sv 1 1\n\
2083SV\n\
2084\n\
20850.1 1:0.5\n\
2086\n\
2087-0.1 2:0.5\n";
2088        let model = load_model_from_reader(&input[..]).unwrap();
2089        assert_eq!(model.sv.len(), 2);
2090        assert_eq!(model.sv_coef[0].len(), 2);
2091    }
2092
2093    #[test]
2094    fn save_precomputed_model_writes_zero_index() {
2095        let model = SvmModel {
2096            param: SvmParameter {
2097                svm_type: SvmType::CSvc,
2098                kernel_type: KernelType::Precomputed,
2099                ..Default::default()
2100            },
2101            nr_class: 2,
2102            sv: vec![vec![SvmNode {
2103                index: 0,
2104                value: 7.0,
2105            }]],
2106            sv_coef: vec![vec![0.25]],
2107            rho: vec![0.0],
2108            prob_a: vec![],
2109            prob_b: vec![],
2110            prob_density_marks: vec![],
2111            sv_indices: vec![],
2112            label: vec![1, -1],
2113            n_sv: vec![1, 0],
2114        };
2115
2116        let mut buf = Vec::new();
2117        save_model_to_writer(&mut buf, &model).unwrap();
2118        let out = String::from_utf8(buf).unwrap();
2119        assert!(out.contains("kernel_type precomputed"));
2120        assert!(out.contains("0:7"));
2121    }
2122
2123    // ─── F1: classification models must carry label + nr_sv ─────────────
2124
2125    /// A c_svc model that omits both `label` and `nr_sv` must return a
2126    /// ParseError rather than panic in predict.rs.
2127    #[test]
2128    fn csvc_missing_label_and_nr_sv_returns_error_not_panic() {
2129        let input = b"svm_type c_svc\n\
2130kernel_type linear\n\
2131nr_class 2\n\
2132total_sv 1\n\
2133rho 0\n\
2134SV\n\
21351 1:1\n";
2136        let err = load_model_from_reader(&input[..]).unwrap_err();
2137        let msg = format!("{}", err);
2138        assert!(
2139            msg.contains("label has 0 entries, expected nr_class (2)"),
2140            "unexpected error: {}",
2141            msg
2142        );
2143    }
2144
2145    /// A c_svc model with `label` present but `nr_sv` absent must also fail.
2146    #[test]
2147    fn csvc_label_present_nr_sv_absent_returns_error() {
2148        let input = b"svm_type c_svc\n\
2149kernel_type linear\n\
2150nr_class 2\n\
2151total_sv 1\n\
2152rho 0\n\
2153label 1 -1\n\
2154SV\n\
21551 1:1\n";
2156        let err = load_model_from_reader(&input[..]).unwrap_err();
2157        let msg = format!("{}", err);
2158        assert!(
2159            msg.contains("nr_sv has 0 entries, expected nr_class (2)"),
2160            "unexpected error: {}",
2161            msg
2162        );
2163    }
2164
2165    /// A valid one_class model with no label/nr_sv lines must still load Ok
2166    /// (regression guard: the new classification-only requirement must not
2167    /// affect one-class models).
2168    #[test]
2169    fn one_class_without_label_and_nr_sv_loads_ok() {
2170        // one_class uses m = nr_class - 1 = 1 sv_coef column per SV row.
2171        let input = b"svm_type one_class\n\
2172kernel_type rbf\n\
2173gamma 0.5\n\
2174nr_class 2\n\
2175total_sv 1\n\
2176rho -0.5\n\
2177SV\n\
21781 1:0.5\n";
2179        let model = load_model_from_reader(&input[..]).unwrap();
2180        assert_eq!(model.sv.len(), 1);
2181        assert_eq!(model.label, Vec::<i32>::new());
2182        assert_eq!(model.n_sv, Vec::<usize>::new());
2183    }
2184
2185    // ─── F3: non-finite numeric fields must be rejected ──────────────────
2186
2187    /// A model with `rho inf` must fail with a ParseError.
2188    #[test]
2189    fn model_rho_inf_returns_error() {
2190        let input = b"svm_type c_svc\n\
2191kernel_type linear\n\
2192nr_class 2\n\
2193total_sv 1\n\
2194rho inf\n\
2195label 1 -1\n\
2196nr_sv 1 0\n\
2197SV\n\
21981 1:0.5\n";
2199        let err = load_model_from_reader(&input[..]).unwrap_err();
2200        let msg = format!("{}", err);
2201        assert!(
2202            msg.contains("rho must be finite"),
2203            "unexpected error: {}",
2204            msg
2205        );
2206    }
2207
2208    /// A model with `1:nan` in an SV feature must fail with a ParseError.
2209    #[test]
2210    fn model_sv_feature_nan_returns_error() {
2211        let input = b"svm_type c_svc\n\
2212kernel_type linear\n\
2213nr_class 2\n\
2214total_sv 1\n\
2215rho 0\n\
2216label 1 -1\n\
2217nr_sv 1 0\n\
2218SV\n\
22191 1:nan\n";
2220        let err = load_model_from_reader(&input[..]).unwrap_err();
2221        let msg = format!("{}", err);
2222        assert!(
2223            msg.contains("feature value must be finite"),
2224            "unexpected error: {}",
2225            msg
2226        );
2227    }
2228}