ffsvm/parser.rs
1use crate::errors::Error;
2use std::{convert::TryFrom, str};
3
4/// Parsing result of a model file used to instantiate a [`DenseSVM`](`crate::DenseSVM`) or [`SparseSVM`](`crate::SparseSVM`).
5///
6/// # Obtaining Models
7/// A model file is produced by [libSVM](https://github.com/cjlin1/libsvm). For details
8/// how to produce a model see the top-level [FFSVM](index.html#creating-a-libsvm-model)
9/// documentation.
10///
11/// # Loading Models
12///
13/// Models are generally produced by parsing a [`&str`] using the [`ModelFile::try_from`] function:
14///
15/// ```rust
16/// use ffsvm::ModelFile;
17/// # use ffsvm::SAMPLE_MODEL;
18///
19/// let model_result = ModelFile::try_from(SAMPLE_MODEL);
20/// ```
21///
22/// Should anything be wrong with the model format, an [`Error`] will be returned. Once you have
23/// your model, you can use it to create an SVM, for example by invoking `DenseSVM::try_from(model)`.
24///
25/// # Model Format
26///
27/// For FFSVM to load a model, it needs to look approximately like below. Note that you cannot
28/// reasonably create this model by hand, it needs to come from [libSVM](https://github.com/cjlin1/libsvm).
29///
30/// ```text
31/// svm_type c_svc
32/// kernel_type rbf
33/// gamma 1
34/// nr_class 2
35/// total_sv 3012
36/// rho -2.90877
37/// label 0 1
38/// probA -1.55583
39/// probB 0.0976659
40/// nr_sv 1513 1499
41/// SV
42/// 256 0:0.5106233 1:0.1584117 2:0.1689098 3:0.1664358 4:0.2327561 5:0 6:0 7:0 8:1 9:0.1989241
43/// 256 0:0.5018305 1:0.0945542 2:0.09242307 3:0.09439687 4:0.1398575 5:0 6:0 7:0 8:1 9:1
44/// 256 0:0.5020829 1:0 2:0 3:0 4:0.1393665 5:1 6:0 7:0 8:1 9:0
45/// 256 0:0.4933203 1:0.1098869 2:0.1048947 3:0.1069601 4:0.2152338 5:0 6:0 7:0 8:1 9:1
46/// ```
47///
48/// Apart from "one-class SVM" (`-s 2` in libSVM) and "precomputed kernel" (`-t 4`) all
49/// generated libSVM models should be supported.
50///
51/// However, note that for the [`DenseSVM`](`crate::DenseSVM`) to work, all support vectors
52/// (past the `SV` line) must have **strictly** increasing attribute identifiers starting at `0`,
53/// without skipping an attribute. In other words, your attributes have to be named `0:`, `1:`,
54/// `2:`, ... `n:` and not, say, `0:`, `1:`, `4:`, ... `n:`.
55#[derive(Clone, Debug, Default)]
56pub struct ModelFile<'a> {
57 header: Header<'a>,
58 vectors: Vec<SupportVector>,
59}
60
61impl<'a> ModelFile<'a> {
62 #[doc(hidden)]
63 #[must_use]
64 pub const fn new(header: Header<'a>, vectors: Vec<SupportVector>) -> Self {
65 Self { header, vectors }
66 }
67
68 #[doc(hidden)]
69 #[must_use]
70 pub const fn header(&self) -> &Header {
71 &self.header
72 }
73
74 #[doc(hidden)]
75 #[must_use]
76 pub fn vectors(&self) -> &[SupportVector] {
77 self.vectors.as_slice()
78 }
79}
80
/// Header values of a libSVM model file, one field per recognized header key.
///
/// String-valued keys borrow from the model text for the lifetime `'a`.
#[doc(hidden)]
#[derive(Debug, Default, Clone)]
pub struct Header<'a> {
    /// Value of the `svm_type` line (e.g. `c_svc`).
    pub svm_type: &'a str,
    /// Value of the `kernel_type` line (e.g. `rbf`).
    pub kernel_type: &'a str,
    /// Value of the `gamma` line, if present.
    pub gamma: Option<f32>,
    /// Value of the `coef0` line, if present.
    pub coef0: Option<f32>,
    /// Value of the `degree` line, if present.
    pub degree: Option<u32>,
    /// Value of the `nr_class` line.
    pub nr_class: u32,
    /// Value of the `total_sv` line.
    pub total_sv: u32,
    /// Values of the `rho` line.
    pub rho: Vec<f64>,
    /// Values of the `label` line.
    pub label: Vec<i32>,
    /// Values of the `probA` line, if present.
    pub prob_a: Option<Vec<f64>>,
    /// Values of the `probB` line, if present.
    pub prob_b: Option<Vec<f64>>,
    /// Values of the `nr_sv` line.
    pub nr_sv: Vec<u32>,
}
97
/// A single `index:value` entry of a support vector line.
#[doc(hidden)]
#[derive(Default, Debug, Clone, Copy)]
pub struct Attribute {
    /// The number after the `:`.
    pub value: f32,
    /// The number before the `:`.
    pub index: u32,
}
104
105#[doc(hidden)]
106#[derive(Clone, Debug, Default)]
107pub struct SupportVector {
108 pub coefs: Vec<f32>,
109 pub features: Vec<Attribute>,
110}
111
112impl<'a> TryFrom<&'a str> for ModelFile<'a> {
113 type Error = Error;
114
115 /// Parses a string into an SVM model
116 #[allow(clippy::similar_names)]
117 fn try_from(input: &str) -> Result<ModelFile<'_>, Error> {
118 let mut svm_type = Option::None;
119 let mut kernel_type = Option::None;
120 let mut gamma = Option::None;
121 let mut coef0 = Option::None;
122 let mut degree = Option::None;
123 let mut nr_class = Option::None;
124 let mut total_sv = Option::None;
125 let mut rho = Vec::new();
126 let mut label = Vec::new();
127 let mut prob_a = Option::None;
128 let mut prob_b = Option::None;
129 let mut nr_sv = Vec::new();
130
131 let mut vectors = Vec::new();
132
133 for line in input.lines() {
134 let tokens = line.split_whitespace().collect::<Vec<_>>();
135
136 match tokens.first() {
137 // Single value headers
138 //
139 // svm_type c_svc
140 // kernel_type rbf
141 // gamma 0.5
142 // nr_class 6
143 // total_sv 153
144 // rho 2.37333 -0.579888 0.535784 0.0701838 0.609329 -0.932983 -0.427481 -1.15801 -0.108324 0.486988 -0.0642337 0.52711 -0.292071 0.214309 0.880031
145 // label 1 2 3 5 6 7
146 // probA -1.26241 -2.09056 -3.04781 -2.49489 -2.79378 -2.55612 -1.80921 -1.90492 -2.6911 -2.67778 -2.15836 -2.53895 -2.21813 -2.03491 -1.91923
147 // probB 0.135634 0.570051 -0.114691 -0.397667 0.0687938 0.839527 -0.310816 -0.787629 0.0335196 0.15079 -0.389211 0.288416 0.186429 0.46585 0.547398
148 // nr_sv 50 56 17 11 7 12
149 // SV
150 Some(x) if *x == "svm_type" => {
151 svm_type = Some(tokens[1]);
152 }
153 Some(x) if *x == "kernel_type" => {
154 kernel_type = Some(tokens[1]);
155 }
156 Some(x) if *x == "gamma" => {
157 gamma = tokens[1].parse::<f32>().ok();
158 }
159 Some(x) if *x == "coef0" => {
160 coef0 = tokens[1].parse::<f32>().ok();
161 }
162 Some(x) if *x == "degree" => {
163 degree = tokens[1].parse::<u32>().ok();
164 }
165 Some(x) if *x == "nr_class" => {
166 nr_class = tokens[1].parse::<u32>().ok();
167 }
168 Some(x) if *x == "total_sv" => {
169 total_sv = tokens[1].parse::<u32>().ok();
170 }
171 // Multi value headers
172 Some(x) if *x == "rho" => rho = tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect(),
173 Some(x) if *x == "label" => label = tokens.iter().skip(1).filter_map(|x| x.parse::<i32>().ok()).collect(),
174 Some(x) if *x == "nr_sv" => nr_sv = tokens.iter().skip(1).filter_map(|x| x.parse::<u32>().ok()).collect(),
175 Some(x) if *x == "probA" => prob_a = Some(tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect()),
176 Some(x) if *x == "probB" => prob_b = Some(tokens.iter().skip(1).filter_map(|x| x.parse::<f64>().ok()).collect()),
177 // Header separator
178 Some(x) if *x == "SV" => {}
179 // These are all regular lines without a clear header (after SV) ...
180 //
181 // 0.0625 0:0.6619648 1:0.8464851 2:0.4801146 3:0 4:0 5:0.02131653 6:0 7:0 8:0 9:0 10:0 11:0 12:0 13:0 14:0 15:0.5579834 16:0.1106567 17:0 18:0 19:0 20:0
182 // 0.0625 0:0.5861949 1:0.5556895 2:0.619291 3:0 4:0 5:0 6:0 7:0 8:0 9:0 10:0 11:0.5977631 12:0 13:0 14:0 15:0.6203156 16:0 17:0 18:0 19:0.1964417 20:0
183 // 0.0625 0:0.44675 1:0.4914977 2:0.4227562 3:0.2904663 4:0.2904663 5:0.268158 6:0 7:0 8:0 9:0 10:0 11:0.6202393 12:0.0224762 13:0 14:0 15:0.6427917 16:0.0224762 17:0 18:0 19:0.1739655 20:0
184 Some(_) => {
185 let mut sv = SupportVector {
186 coefs: Vec::new(),
187 features: Vec::new(),
188 };
189
190 let (features, coefs): (Vec<&str>, Vec<&str>) = tokens.iter().partition(|x| x.contains(':'));
191
192 sv.coefs = coefs.iter().filter_map(|x| x.parse::<f32>().ok()).collect();
193 sv.features = features
194 .iter()
195 .filter_map(|x| {
196 let split = x.split(':').collect::<Vec<&str>>();
197
198 Some(Attribute {
199 index: split.first()?.parse::<u32>().ok()?,
200 value: split.get(1)?.parse::<f32>().ok()?,
201 })
202 })
203 .collect();
204
205 vectors.push(sv);
206 }
207
208 // Empty end of file
209 None => break,
210 }
211 }
212
213 Ok(ModelFile {
214 header: Header {
215 svm_type: svm_type.ok_or(Error::MissingRequiredAttribute)?,
216 kernel_type: kernel_type.ok_or(Error::MissingRequiredAttribute)?,
217 gamma,
218 coef0,
219 degree,
220 nr_class: nr_class.ok_or(Error::MissingRequiredAttribute)?,
221 total_sv: total_sv.ok_or(Error::MissingRequiredAttribute)?,
222 rho,
223 label,
224 prob_a,
225 prob_b,
226 nr_sv,
227 },
228 vectors,
229 })
230 }
231}