ffsvm/
parser.rs

1use crate::errors::Error;
2use std::{convert::TryFrom, str};
3
4/// Parsing result of a model file used to instantiate a [`DenseSVM`](`crate::DenseSVM`) or [`SparseSVM`](`crate::SparseSVM`).
5///
6/// # Obtaining Models
7/// A model file is produced by [libSVM](https://github.com/cjlin1/libsvm). For details
8/// how to produce a model see the top-level [FFSVM](index.html#creating-a-libsvm-model)
9/// documentation.
10///
11/// # Loading Models
12///
13/// Models are generally produced by parsing a [`&str`] using the [`ModelFile::try_from`] function:
14///
15/// ```rust
16/// use ffsvm::ModelFile;
17/// # use ffsvm::SAMPLE_MODEL;
18///
19/// let model_result = ModelFile::try_from(SAMPLE_MODEL);
20/// ```
21///
22/// Should anything be wrong with the model format, an [`Error`] will be returned. Once you have
23/// your model, you can use it to create an SVM, for example by invoking `DenseSVM::try_from(model)`.
24///
25/// # Model Format
26///
27/// For FFSVM to load a model, it needs to look approximately like below. Note that you cannot
28/// reasonably create this model by hand, it needs to come from [libSVM](https://github.com/cjlin1/libsvm).
29///
30/// ```text
31/// svm_type c_svc
32/// kernel_type rbf
33/// gamma 1
34/// nr_class 2
35/// total_sv 3012
36/// rho -2.90877
37/// label 0 1
38/// probA -1.55583
39/// probB 0.0976659
40/// nr_sv 1513 1499
41/// SV
42/// 256 0:0.5106233 1:0.1584117 2:0.1689098 3:0.1664358 4:0.2327561 5:0 6:0 7:0 8:1 9:0.1989241
43/// 256 0:0.5018305 1:0.0945542 2:0.09242307 3:0.09439687 4:0.1398575 5:0 6:0 7:0 8:1 9:1
44/// 256 0:0.5020829 1:0 2:0 3:0 4:0.1393665 5:1 6:0 7:0 8:1 9:0
45/// 256 0:0.4933203 1:0.1098869 2:0.1048947 3:0.1069601 4:0.2152338 5:0 6:0 7:0 8:1 9:1
46/// ```
47///
48/// Apart from "one-class SVM" (`-s 2` in libSVM) and "precomputed kernel" (`-t 4`) all
49/// generated libSVM models should be supported.
50///
51/// However, note that for the [`DenseSVM`](`crate::DenseSVM`) to work, all support vectors
52/// (past the `SV` line) must have **strictly** increasing attribute identifiers starting at `0`,
53/// without skipping an attribute. In other words, your attributes have to be named `0:`, `1:`,
54/// `2:`, ... `n:` and not, say, `0:`, `1:`, `4:`, ... `n:`.
55#[derive(Clone, Debug, Default)]
56pub struct ModelFile<'a> {
57    header: Header<'a>,
58    vectors: Vec<SupportVector>,
59}
60
61impl<'a> ModelFile<'a> {
62    #[doc(hidden)]
63    #[must_use]
64    pub const fn new(header: Header<'a>, vectors: Vec<SupportVector>) -> Self {
65        Self { header, vectors }
66    }
67
68    #[doc(hidden)]
69    #[must_use]
70    pub const fn header(&self) -> &Header {
71        &self.header
72    }
73
74    #[doc(hidden)]
75    #[must_use]
76    pub fn vectors(&self) -> &[SupportVector] {
77        self.vectors.as_slice()
78    }
79}
80
/// Values read from the header lines of a libSVM model file.
#[doc(hidden)]
#[derive(Default, Debug, Clone)]
pub struct Header<'a> {
    /// Value of the `svm_type` line, e.g. `c_svc`.
    pub svm_type: &'a str,
    /// Value of the `kernel_type` line, e.g. `rbf`.
    pub kernel_type: &'a str,
    /// Value of the `gamma` line, if one was present and parseable.
    pub gamma: Option<f32>,
    /// Value of the `coef0` line, if one was present and parseable.
    pub coef0: Option<f32>,
    /// Value of the `degree` line, if one was present and parseable.
    pub degree: Option<u32>,
    /// Value of the `nr_class` line.
    pub nr_class: u32,
    /// Value of the `total_sv` line.
    pub total_sv: u32,
    /// All values of the `rho` line.
    pub rho: Vec<f64>,
    /// All values of the `label` line.
    pub label: Vec<i32>,
    /// All values of the `probA` line, if that line was present.
    pub prob_a: Option<Vec<f64>>,
    /// All values of the `probB` line, if that line was present.
    pub prob_b: Option<Vec<f64>>,
    /// All values of the `nr_sv` line.
    pub nr_sv: Vec<u32>,
}
97
/// One `index:value` feature of a support vector.
#[doc(hidden)]
#[derive(Default, Debug, Clone, Copy)]
pub struct Attribute {
    /// The number after the `:`.
    pub value: f32,
    /// The number before the `:`.
    pub index: u32,
}
104
105#[doc(hidden)]
106#[derive(Clone, Debug, Default)]
107pub struct SupportVector {
108    pub coefs: Vec<f32>,
109    pub features: Vec<Attribute>,
110}
111
112impl<'a> TryFrom<&'a str> for ModelFile<'a> {
113    type Error = Error;
114
115    /// Parses a string into an SVM model
116    #[allow(clippy::similar_names)]
117    fn try_from(input: &str) -> Result<ModelFile<'_>, Error> {
118        let mut svm_type = Option::None;
119        let mut kernel_type = Option::None;
120        let mut gamma = Option::None;
121        let mut coef0 = Option::None;
122        let mut degree = Option::None;
123        let mut nr_class = Option::None;
124        let mut total_sv = Option::None;
125        let mut rho = Vec::new();
126        let mut label = Vec::new();
127        let mut prob_a = Option::None;
128        let mut prob_b = Option::None;
129        let mut nr_sv = Vec::new();
130
131        let mut vectors = Vec::new();
132
133        for line in input.lines() {
134            let tokens = line.split_whitespace().collect::<Vec<_>>();
135
136            match tokens.first() {
137                // Single value headers
138                //
139                // svm_type c_svc
140                // kernel_type rbf
141                // gamma 0.5
142                // nr_class 6
143                // total_sv 153
144                // rho 2.37333 -0.579888 0.535784 0.0701838 0.609329 -0.932983 -0.427481 -1.15801 -0.108324 0.486988 -0.0642337 0.52711 -0.292071 0.214309 0.880031
145                // label 1 2 3 5 6 7
146                // probA -1.26241 -2.09056 -3.04781 -2.49489 -2.79378 -2.55612 -1.80921 -1.90492 -2.6911 -2.67778 -2.15836 -2.53895 -2.21813 -2.03491 -1.91923
147                // probB 0.135634 0.570051 -0.114691 -0.397667 0.0687938 0.839527 -0.310816 -0.787629 0.0335196 0.15079 -0.389211 0.288416 0.186429 0.46585 0.547398
148                // nr_sv 50 56 17 11 7 12
149                // SV
150                Some(x) if *x == "svm_type" => {
151                    svm_type = Some(tokens[1]);
152                }
153                Some(x) if *x == "kernel_type" => {
154                    kernel_type = Some(tokens[1]);
155                }
156                Some(x) if *x == "gamma" => {
157                    gamma = tokens[1].parse::<f32>().ok();
158                }
159                Some(x) if *x == "coef0" => {
160                    coef0 = tokens[1].parse::<f32>().ok();
161                }
162                Some(x) if *x == "degree" => {
163                    degree = tokens[1].parse::<u32>().ok();
164                }
165                Some(x) if *x == "nr_class" => {
166                    nr_class = tokens[1].parse::<u32>().ok();
167                }
168                Some(x) if *x == "total_sv" => {
169                    total_sv = tokens[1].parse::<u32>().ok();
170                }
171                // Multi value headers
172                Some(x) if *x == "rho" => rho = tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect(),
173                Some(x) if *x == "label" => label = tokens.iter().skip(1).filter_map(|x| x.parse::<i32>().ok()).collect(),
174                Some(x) if *x == "nr_sv" => nr_sv = tokens.iter().skip(1).filter_map(|x| x.parse::<u32>().ok()).collect(),
175                Some(x) if *x == "probA" => prob_a = Some(tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect()),
176                Some(x) if *x == "probB" => prob_b = Some(tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect()),
177                // Header separator
178                Some(x) if *x == "SV" => {}
179                // These are all regular lines without a clear header (after SV) ...
180                //
181                // 0.0625 0:0.6619648 1:0.8464851 2:0.4801146 3:0 4:0 5:0.02131653 6:0 7:0 8:0 9:0 10:0 11:0 12:0 13:0 14:0 15:0.5579834 16:0.1106567 17:0 18:0 19:0 20:0
182                // 0.0625 0:0.5861949 1:0.5556895 2:0.619291 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 11:0.5977631 12:0 13:0 14:0 15:0.6203156 16:0 17:0 18:0 19:0.1964417 20:0
183                // 0.0625 0:0.44675 1:0.4914977 2:0.4227562 3:0.2904663 4:0.2904663 5:0.268158 6:0 7:0 8:0 9:0 10:0 11:0.6202393 12:0.0224762 13:0 14:0 15:0.6427917 16:0.0224762 17:0 18:0 19:0.1739655 20:0
184                Some(_) => {
185                    let mut sv = SupportVector {
186                        coefs: Vec::new(),
187                        features: Vec::new(),
188                    };
189
190                    let (features, coefs): (Vec<&str>, Vec<&str>) = tokens.iter().partition(|x| x.contains(':'));
191
192                    sv.coefs = coefs.iter().filter_map(|x| x.parse::<f32>().ok()).collect();
193                    sv.features = features
194                        .iter()
195                        .filter_map(|x| {
196                            let split = x.split(':').collect::<Vec<&str>>();
197
198                            Some(Attribute {
199                                index: split.first()?.parse::<u32>().ok()?,
200                                value: split.get(1)?.parse::<f32>().ok()?,
201                            })
202                        })
203                        .collect();
204
205                    vectors.push(sv);
206                }
207
208                // Empty end of file
209                None => break,
210            }
211        }
212
213        Ok(ModelFile {
214            header: Header {
215                svm_type: svm_type.ok_or(Error::MissingRequiredAttribute)?,
216                kernel_type: kernel_type.ok_or(Error::MissingRequiredAttribute)?,
217                gamma,
218                coef0,
219                degree,
220                nr_class: nr_class.ok_or(Error::MissingRequiredAttribute)?,
221                total_sv: total_sv.ok_or(Error::MissingRequiredAttribute)?,
222                rho,
223                label,
224                prob_a,
225                prob_b,
226                nr_sv,
227            },
228            vectors,
229        })
230    }
231}