1use crate::kernel::k_function;
10use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
11use crate::types::{SvmModel, SvmNode, SvmType};
12
13pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
22 match model.param.svm_type {
23 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
24 let sv_coef = &model.sv_coef[0];
25 let mut sum = 0.0;
26 for (i, sv) in model.sv.iter().enumerate() {
27 sum += sv_coef[i] * k_function(x, sv, &model.param);
28 }
29 sum -= model.rho[0];
30 dec_values[0] = sum;
31
32 if model.param.svm_type == SvmType::OneClass {
33 if sum > 0.0 {
34 1.0
35 } else {
36 -1.0
37 }
38 } else {
39 sum
40 }
41 }
42 SvmType::CSvc | SvmType::NuSvc => {
43 let nr_class = model.nr_class;
44 let l = model.sv.len();
45
46 let kvalue: Vec<f64> = model
48 .sv
49 .iter()
50 .map(|sv| k_function(x, sv, &model.param))
51 .collect();
52
53 let mut start = vec![0usize; nr_class];
55 for i in 1..nr_class {
56 start[i] = start[i - 1] + model.n_sv[i - 1];
57 }
58
59 let mut vote = vec![0usize; nr_class];
61 let mut p = 0;
62 for i in 0..nr_class {
63 for j in (i + 1)..nr_class {
64 let mut sum = 0.0;
65 let si = start[i];
66 let sj = start[j];
67 let ci = model.n_sv[i];
68 let cj = model.n_sv[j];
69
70 let coef1 = &model.sv_coef[j - 1];
71 let coef2 = &model.sv_coef[i];
72
73 for k in 0..ci {
74 sum += coef1[si + k] * kvalue[si + k];
75 }
76 for k in 0..cj {
77 sum += coef2[sj + k] * kvalue[sj + k];
78 }
79 sum -= model.rho[p];
80 dec_values[p] = sum;
81
82 if sum > 0.0 {
83 vote[i] += 1;
84 } else {
85 vote[j] += 1;
86 }
87 p += 1;
88 }
89 }
90
91 let vote_max_idx = vote
93 .iter()
94 .enumerate()
95 .max_by_key(|&(_, &v)| v)
96 .map(|(i, _)| i)
97 .unwrap_or(0);
98
99 let _ = l; model.label[vote_max_idx] as f64
101 }
102 }
103}
104
105pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
134 let n = match model.param.svm_type {
135 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
136 SvmType::CSvc | SvmType::NuSvc => model.nr_class * (model.nr_class - 1) / 2,
137 };
138 let mut dec_values = vec![0.0; n];
139 predict_values(model, x, &mut dec_values)
140}
141
142#[allow(clippy::needless_range_loop)]
157pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
158 match model.param.svm_type {
159 SvmType::CSvc | SvmType::NuSvc if !model.prob_a.is_empty() && !model.prob_b.is_empty() => {
160 let nr_class = model.nr_class;
161 let n_pairs = nr_class * (nr_class - 1) / 2;
162 let mut dec_values = vec![0.0; n_pairs];
163 predict_values(model, x, &mut dec_values);
164
165 let min_prob = 1e-7;
166
167 let mut pairwise = vec![vec![0.0; nr_class]; nr_class];
169 let mut k = 0;
170 for i in 0..nr_class {
171 for j in (i + 1)..nr_class {
172 let p = sigmoid_predict(dec_values[k], model.prob_a[k], model.prob_b[k])
173 .max(min_prob)
174 .min(1.0 - min_prob);
175 pairwise[i][j] = p;
176 pairwise[j][i] = 1.0 - p;
177 k += 1;
178 }
179 }
180
181 let mut prob_estimates = vec![0.0; nr_class];
182 if nr_class == 2 {
183 prob_estimates[0] = pairwise[0][1];
184 prob_estimates[1] = pairwise[1][0];
185 } else {
186 multiclass_probability(nr_class, &pairwise, &mut prob_estimates);
187 }
188
189 let mut best = 0;
191 for i in 1..nr_class {
192 if prob_estimates[i] > prob_estimates[best] {
193 best = i;
194 }
195 }
196
197 Some((model.label[best] as f64, prob_estimates))
198 }
199
200 SvmType::OneClass if !model.prob_density_marks.is_empty() => {
201 let mut dec_value = [0.0];
202 let pred_result = predict_values(model, x, &mut dec_value);
203 let p = predict_one_class_probability(&model.prob_density_marks, dec_value[0]);
204 Some((pred_result, vec![p, 1.0 - p]))
205 }
206
207 _ => None,
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    // Fixture directory: `data/` two levels above this crate's manifest directory.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    // The stored heart_scale model must label more than 85% of the heart_scale
    // problem file correctly.
    #[test]
    fn predict_heart_scale() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    // `predict_values` must fill in a non-zero decision value and agree with `predict`.
    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        assert!(dec_values[0].abs() > 1e-10);
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    // A binary model trained with probability output must return two estimates in
    // [0, 1] that sum to one, and a label of +1 or -1.
    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let result = predict_probability(&model, instance);
            assert!(result.is_some(), "should return probability");
            let (label, probs) = result.unwrap();
            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!((0.0..=1.0).contains(&p), "prob {} out of [0,1]", p);
            }
        }
    }

    // A three-class model must carry one sigmoid per class pair and return three
    // estimates that sum to one.
    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.prob_a.len(), 3);

        for instance in problem.instances.iter().take(10) {
            let result = predict_probability(&model, instance);
            assert!(result.is_some());
            let (_label, probs) = result.unwrap();
            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    // Cross-checks every prediction against the vendored C `svm-predict` binary.
    // The test is skipped silently when the binary is missing or does not run successfully.
    #[test]
    fn predict_matches_c_svm_predict() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");

        if !c_predict.exists() {
            return;
        }

        let output_path = data_dir().join("heart_scale.predict_test");
        // Paths are handed over as OS strings, so non-UTF-8 directories do not panic.
        let run = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output();

        if let Ok(output) = run {
            if output.status.success() {
                // Read the predictions, then delete the file right away so that a
                // failing assertion below cannot leave it behind in the data directory.
                let raw = std::fs::read_to_string(&output_path);
                let _ = std::fs::remove_file(&output_path);

                let c_preds: Vec<f64> = raw
                    .unwrap()
                    .lines()
                    .filter(|l| !l.is_empty())
                    .map(|l| l.trim().parse().unwrap())
                    .collect();

                assert_eq!(c_preds.len(), problem.labels.len());

                let mut mismatches = 0;
                for (i, instance) in problem.instances.iter().enumerate() {
                    let rust_pred = predict(&model, instance);
                    if rust_pred != c_preds[i] {
                        mismatches += 1;
                    }
                }

                assert_eq!(
                    mismatches, 0,
                    "{} predictions differ between Rust and C",
                    mismatches
                );
            }
        }
    }
}