use std::path::PathBuf;
use libsvm_rs::io::{
load_model, load_model_from_reader_with_options, load_problem,
load_problem_from_reader_with_options, LoadOptions,
};
/// Returns the path of the malicious-input fixture `name`, located at
/// `<CARGO_MANIFEST_DIR>/tests/malicious/<name>`.
fn fixture(name: &str) -> PathBuf {
    let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    for component in ["tests", "malicious", name] {
        path.push(component);
    }
    path
}
/// Asserts that `result` is an `Err` whose `Display` text contains `needle`.
///
/// # Panics
///
/// - If `result` is `Ok`: the message includes the `Debug` form of the
///   successfully parsed value.
/// - If the error text does not contain `needle`: the message includes the
///   full error text.
///
/// `#[track_caller]` makes a failure report the calling test's location
/// instead of this helper's.
#[track_caller]
fn assert_err_contains<E: std::fmt::Display, T: std::fmt::Debug>(
    result: Result<T, E>,
    needle: &str,
) {
    match result {
        Ok(v) => panic!(
            "expected loader error containing {:?}, got successful parse: {:?}",
            needle, v
        ),
        Err(e) => {
            let msg = e.to_string();
            assert!(
                msg.contains(needle),
                "expected error to contain {:?}, got: {}",
                needle,
                msg
            );
        }
    }
}
/// Loads `huge_total_sv.model` and expects a "total_sv exceeds limit" error.
#[test]
fn rejects_huge_total_sv_preallocation_vector() {
    let outcome = load_model(&fixture("huge_total_sv.model"));
    assert_err_contains(outcome, "total_sv exceeds limit");
}
/// Loads `huge_nr_class.model` and expects an "nr_class exceeds limit" error.
#[test]
fn rejects_huge_nr_class() {
    let model_path = fixture("huge_nr_class.model");
    assert_err_contains(load_model(&model_path), "nr_class exceeds limit");
}
/// Loads `mismatched_n_sv_sum.model` and expects the error to mention the
/// "sum of nr_sv entries".
#[test]
fn rejects_mismatched_n_sv_sum() {
    let expected = "sum of nr_sv entries";
    let outcome = load_model(&fixture("mismatched_n_sv_sum.model"));
    assert_err_contains(outcome, expected);
}
/// Loads `rho_length_mismatch.model` and expects the error to report one `rho`
/// entry where three were expected.
#[test]
fn rejects_rho_length_mismatch() {
    let model_path = fixture("rho_length_mismatch.model");
    let outcome = load_model(&model_path);
    assert_err_contains(outcome, "rho has 1 entries, expected 3");
}
/// Loads `label_on_regression.model` and expects the loader to reject its
/// `label` line as classification-only.
#[test]
fn rejects_label_on_regression() {
    let outcome = load_model(&fixture("label_on_regression.model"));
    assert_err_contains(outcome, "label is only valid for classification");
}
/// Loads `prob_density_marks_on_csvc.model` and expects the loader to reject
/// `prob_density_marks` as valid only for one-class SVM.
#[test]
fn rejects_prob_density_marks_on_csvc() {
    let expected = "prob_density_marks is only valid for one-class SVM";
    let model_path = fixture("prob_density_marks_on_csvc.model");
    assert_err_contains(load_model(&model_path), expected);
}
/// Loads `sv_feature_indices_not_ascending.model` and expects an error saying
/// feature indices must be ascending.
#[test]
fn rejects_sv_feature_indices_not_ascending() {
    let outcome = load_model(&fixture("sv_feature_indices_not_ascending.model"));
    assert_err_contains(outcome, "feature indices must be ascending");
}
/// Loads the problem file `feature_index_out_of_range.libsvm` and expects an
/// error naming feature index 10000001 as exceeding the limit.
#[test]
fn rejects_feature_index_out_of_range() {
    let problem_path = fixture("feature_index_out_of_range.libsvm");
    let outcome = load_problem(&problem_path);
    assert_err_contains(outcome, "feature index 10000001 exceeds limit");
}
/// Feeds `long_line.libsvm` through the reader-based problem loader with
/// `max_line_len` lowered to 100 and expects the error to mention
/// `max_line_len`. The fixture presumably contains a line longer than 100
/// bytes — confirm if the fixture changes.
#[test]
fn rejects_line_over_max_line_len() {
    let path = fixture("long_line.libsvm");
    // Name the fixture in the panic instead of a bare `unwrap` on an io::Error.
    let bytes = std::fs::read(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {}", path.display(), e));
    let opts = LoadOptions {
        max_line_len: 100,
        ..LoadOptions::default()
    };
    assert_err_contains(
        load_problem_from_reader_with_options(bytes.as_slice(), &opts),
        "max_line_len",
    );
}
/// Feeds `long_line.libsvm` through the reader-based problem loader with
/// `max_bytes` lowered to 64 and expects the error to mention `max_bytes`.
/// This assumes the fixture is larger than 64 bytes — confirm if the fixture
/// changes.
#[test]
fn rejects_file_over_max_bytes() {
    let path = fixture("long_line.libsvm");
    // Name the fixture in the panic instead of a bare `unwrap` on an io::Error.
    let bytes = std::fs::read(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {}", path.display(), e));
    let opts = LoadOptions {
        max_bytes: 64,
        ..LoadOptions::default()
    };
    assert_err_contains(
        load_problem_from_reader_with_options(bytes.as_slice(), &opts),
        "max_bytes",
    );
}
/// A single problem line (label `+1`, feature `1:0.5`) with a NUL byte before
/// the newline must produce an error mentioning "NUL byte".
#[test]
fn rejects_nul_byte_in_problem_line() {
    let opts = LoadOptions::default();
    let outcome = load_problem_from_reader_with_options(&b"+1 1:0.5\0\n"[..], &opts);
    assert_err_contains(outcome, "NUL byte");
}
/// A two-class linear `c_svc` model header followed by one SV line that ends
/// in a NUL byte must produce an error mentioning "NUL byte".
#[test]
fn rejects_nul_byte_in_model_sv_section() {
    // The `\` line continuations drop the newline and leading indentation, so
    // these bytes are exactly the header lines joined by `\n`.
    let payload: &[u8] = b"svm_type c_svc\n\
        kernel_type linear\n\
        nr_class 2\n\
        total_sv 1\n\
        rho 0\n\
        SV\n\
        0.1 1:0.5\0\n";
    let outcome = load_model_from_reader_with_options(payload, &LoadOptions::default());
    assert_err_contains(outcome, "NUL byte");
}