use crate::errors::Error;
use std::{convert::TryFrom, str};
/// In-memory form of a model file, as built by this module's
/// `TryFrom<&str>` implementation.
///
/// The string fields of the header borrow from the parsed text, so a
/// `ModelFile` cannot outlive the `&str` it was parsed from.
#[derive(Debug, Clone, Default)]
pub struct ModelFile<'a> {
    /// Values taken from the recognized keyword lines (`svm_type`, `rho`, ...).
    pub header: Header<'a>,
    /// One entry per support-vector line, in input order.
    pub vectors: Vec<SupportVector>,
}
/// Header values read from a model file's keyword lines.
///
/// `Option` fields are `None` when their keyword line is missing (or, for
/// numeric fields, when its value does not parse). `Vec` fields are empty when
/// their line is missing. The parser treats `svm_type`, `kernel_type`,
/// `nr_class` and `total_sv` as required.
#[doc(hidden)]
#[derive(Debug, Clone, Default)]
pub struct Header<'a> {
    /// Value of the `svm_type` line, borrowed from the input (required).
    pub svm_type: &'a str,
    /// Value of the `kernel_type` line, borrowed from the input (required).
    pub kernel_type: &'a str,
    /// Value of the `gamma` line, if present and parsable as `f32`.
    pub gamma: Option<f32>,
    /// Value of the `coef0` line, if present and parsable as `f32`.
    pub coef0: Option<f32>,
    /// Value of the `degree` line, if present and parsable as `u32`.
    pub degree: Option<u32>,
    /// Value of the `nr_class` line (required).
    pub nr_class: u32,
    /// Value of the `total_sv` line (required).
    pub total_sv: u32,
    /// Every parsable value on the `rho` line.
    pub rho: Vec<f64>,
    /// Every parsable value on the `label` line.
    pub label: Vec<i32>,
    /// Every parsable value on the `probA` line; `None` if the line is absent.
    pub prob_a: Option<Vec<f64>>,
    /// Every parsable value on the `probB` line; `None` if the line is absent.
    pub prob_b: Option<Vec<f64>>,
    /// Every parsable value on the `nr_sv` line.
    pub nr_sv: Vec<u32>,
}
/// One `index:value` feature token from a support-vector line.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Default)]
pub struct Attribute {
    /// Second `:`-separated part of the token, parsed as `f32`.
    pub value: f32,
    /// First `:`-separated part of the token, parsed as `u32`.
    pub index: u32,
}
/// One support-vector line of a model file.
#[doc(hidden)]
#[derive(Debug, Clone, Default)]
pub struct SupportVector {
    /// Tokens without a `:` that parse as `f32`, in line order.
    pub coefs: Vec<f32>,
    /// `index:value` tokens that parse successfully, in line order.
    pub features: Vec<Attribute>,
}
impl<'a> TryFrom<&'a str> for ModelFile<'a> {
type Error = Error;
#[allow(clippy::similar_names)]
fn try_from(input: &str) -> Result<ModelFile<'_>, Error> {
let mut svm_type = Option::None;
let mut kernel_type = Option::None;
let mut gamma = Option::None;
let mut coef0 = Option::None;
let mut degree = Option::None;
let mut nr_class = Option::None;
let mut total_sv = Option::None;
let mut rho = Vec::new();
let mut label = Vec::new();
let mut prob_a = Option::None;
let mut prob_b = Option::None;
let mut nr_sv = Vec::new();
let mut vectors = Vec::new();
for line in input.lines() {
let tokens = line.split_whitespace().collect::<Vec<_>>();
match tokens.get(0) {
Some(x) if *x == "svm_type" => {
svm_type = Some(tokens[1]);
}
Some(x) if *x == "kernel_type" => {
kernel_type = Some(tokens[1]);
}
Some(x) if *x == "gamma" => {
gamma = tokens[1].parse::<f32>().ok();
}
Some(x) if *x == "coef0" => {
coef0 = tokens[1].parse::<f32>().ok();
}
Some(x) if *x == "degree" => {
degree = tokens[1].parse::<u32>().ok();
}
Some(x) if *x == "nr_class" => {
nr_class = tokens[1].parse::<u32>().ok();
}
Some(x) if *x == "total_sv" => {
total_sv = tokens[1].parse::<u32>().ok();
}
Some(x) if *x == "rho" => rho = tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect(),
Some(x) if *x == "label" => label = tokens.iter().skip(1).filter_map(|x| x.parse::<i32>().ok()).collect(),
Some(x) if *x == "nr_sv" => nr_sv = tokens.iter().skip(1).filter_map(|x| x.parse::<u32>().ok()).collect(),
Some(x) if *x == "probA" => prob_a = Some(tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect()),
Some(x) if *x == "probB" => prob_b = Some(tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect()),
Some(x) if *x == "SV" => {}
Some(_) => {
let mut sv = SupportVector {
coefs: Vec::new(),
features: Vec::new(),
};
let (features, coefs): (Vec<&str>, Vec<&str>) = tokens.iter().partition(|x| x.contains(':'));
sv.coefs = coefs.iter().filter_map(|x| x.parse::<f32>().ok()).collect();
sv.features = features
.iter()
.filter_map(|x| {
let split = x.split(':').collect::<Vec<&str>>();
Some(Attribute {
index: split.get(0)?.parse::<u32>().ok()?,
value: split.get(1)?.parse::<f32>().ok()?,
})
})
.collect();
vectors.push(sv);
}
None => break,
}
}
Ok(ModelFile {
header: Header {
svm_type: svm_type?,
kernel_type: kernel_type?,
gamma,
coef0,
degree,
nr_class: nr_class?,
total_sv: total_sv?,
rho,
label,
prob_a,
prob_b,
nr_sv,
},
vectors,
})
}
}