1use crate::sparse::{SparseMatrix, SparseVector};
2
3use simd_aligned::traits::Simd;
4use std::convert::TryFrom;
5
6use crate::{
7 errors::Error,
8 parser::ModelFile,
9 svm::{
10 class::Class,
11 features::{FeatureVector, Label},
12 kernel::{KernelSparse, Linear, Poly, Rbf, Sigmoid},
13 predict::Predict,
14 Probabilities, SVMType,
15 },
16 util::{find_max_index, set_all, sigmoid_predict},
17 vectors::Triangular,
18};
19
/// An SVM classifying sparse feature vectors, backed by [`SparseMatrix`] support vectors.
pub struct SparseSVM {
    /// Total number of support vectors across all classes.
    pub(crate) num_total_sv: usize,

    /// Number of attributes (features) each vector is expected to have.
    pub(crate) num_attributes: usize,

    /// Constant terms of the decision functions, stored triangularly
    /// (presumably one per class pair — see `Triangular`).
    pub(crate) rho: Triangular<f64>,

    /// Probability parameters; `None` if the model was trained without
    /// probability estimates.
    pub(crate) probabilities: Option<Probabilities>,

    /// Which SVM flavor this model is (C-SVC, nu-SVC, epsilon-SVR, nu-SVR).
    pub(crate) svm_type: SVMType,

    /// Kernel used to compare vectors (e.g. linear, polynomial, RBF, sigmoid).
    pub(crate) kernel: Box<dyn KernelSparse>,

    /// Per-class data: label, support vectors and coefficients.
    pub(crate) classes: Vec<Class<SparseMatrix<f32>>>,
}
51
52impl SparseSVM {
53 #[must_use]
67 pub fn class_index_for_label(&self, label: i32) -> Option<usize> {
68 for (i, class) in self.classes.iter().enumerate() {
69 if class.label != label {
70 continue;
71 }
72
73 return Some(i);
74 }
75
76 None
77 }
78
79 #[must_use]
92 pub fn class_label_for_index(&self, index: usize) -> Option<i32> {
93 if index >= self.classes.len() {
94 None
95 } else {
96 Some(self.classes[index].label)
97 }
98 }
99
100 pub(crate) fn compute_kernel_values(&self, problem: &mut FeatureVector<SparseVector<f32>>) {
102 let features = &problem.features;
104 let kernel_values = &mut problem.kernel_values;
105
106 for (i, class) in self.classes.iter().enumerate() {
108 let kvalues = kernel_values.row_as_flat_mut(i);
109
110 self.kernel.compute(&class.support_vectors, features, kvalues);
111 }
112 }
113
114 pub(crate) fn compute_multiclass_probabilities(&self, problem: &mut FeatureVector<SparseVector<f32>>) -> Result<(), Error> {
120 compute_multiclass_probabilities_impl!(self, problem)
121 }
122
123 pub(crate) fn compute_classification_values(&self, problem: &mut FeatureVector<SparseVector<f32>>) {
125 compute_classification_values_impl!(self, problem);
126 }
127
128 pub(crate) fn compute_regression_values(&self, problem: &mut FeatureVector<SparseVector<f32>>) {
130 let class = &self.classes[0];
131 let coef = class.coefficients.row(0);
132 let kvalues = problem.kernel_values.row(0);
133
134 let mut sum = coef.iter().zip(kvalues).map(|(a, b)| (*a * *b).sum()).sum::<f64>();
135
136 sum -= self.rho[0];
137
138 problem.result = Label::Value(sum as f32);
139 }
140
141 #[must_use]
143 pub const fn attributes(&self) -> usize {
144 self.num_attributes
145 }
146
147 #[must_use]
149 pub fn classes(&self) -> usize {
150 self.classes.len()
151 }
152}
153
154impl Predict<SparseVector<f32>> for SparseSVM {
155 fn predict_value(&self, problem: &mut FeatureVector<SparseVector<f32>>) -> Result<(), Error> {
157 match self.svm_type {
158 SVMType::CSvc | SVMType::NuSvc => {
159 self.compute_kernel_values(problem);
161 self.compute_classification_values(problem);
162
163 let highest_vote = find_max_index(&problem.vote);
165 problem.result = Label::Class(self.classes[highest_vote].label);
166
167 Ok(())
168 }
169 SVMType::ESvr | SVMType::NuSvr => {
170 self.compute_kernel_values(problem);
171 self.compute_regression_values(problem);
172 Ok(())
173 }
174 }
175 }
176
177 fn predict_probability(&self, problem: &mut FeatureVector<SparseVector<f32>>) -> Result<(), Error> {
178 predict_probability_impl!(self, problem)
179 }
180}
181
182impl<'a> TryFrom<&'a str> for SparseSVM {
183 type Error = Error;
184
185 fn try_from(input: &'a str) -> Result<Self, Error> {
186 let raw_model = ModelFile::try_from(input)?;
187 Self::try_from(&raw_model)
188 }
189}
190
191impl<'a> TryFrom<&'a ModelFile<'_>> for SparseSVM {
192 type Error = Error;
193
194 fn try_from(raw_model: &'a ModelFile<'_>) -> Result<Self, Error> {
195 let (mut svm, nr_sv) = prepare_svm!(raw_model, dyn KernelSparse, SparseMatrix<f32>, Self);
196
197 let vectors = &raw_model.vectors();
198
199 let mut start_offset = 0;
202
203 for (i, num_sv_per_class) in nr_sv.iter().enumerate() {
205 let stop_offset = start_offset + *num_sv_per_class as usize;
206
207 for (i_vector, vector) in vectors[start_offset..stop_offset].iter().enumerate() {
209 for attribute in &vector.features {
211 let support_vectors = &mut svm.classes[i].support_vectors;
212 support_vectors[(i_vector, attribute.index as usize)] = attribute.value;
213 }
214
215 for (i_coefficient, coefficient) in vector.coefs.iter().enumerate() {
217 let mut coefficients = svm.classes[i].coefficients.flat_mut();
218 coefficients[(i_coefficient, i_vector)] = f64::from(*coefficient);
219 }
220 }
221
222 start_offset = stop_offset;
224 }
225
226 Ok(svm)
228 }
229}