1use crate::kernel::k_function;
10use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
11use crate::types::{SvmModel, SvmNode, SvmType};
12
13pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
22 match model.param.svm_type {
23 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
24 let sv_coef = &model.sv_coef[0];
25 let mut sum = 0.0;
26 for (i, sv) in model.sv.iter().enumerate() {
27 sum += sv_coef[i] * k_function(x, sv, &model.param);
28 }
29 sum -= model.rho[0];
30 dec_values[0] = sum;
31
32 if model.param.svm_type == SvmType::OneClass {
33 if sum > 0.0 {
34 1.0
35 } else {
36 -1.0
37 }
38 } else {
39 sum
40 }
41 }
42 SvmType::CSvc | SvmType::NuSvc => {
43 let nr_class = model.nr_class;
44 let l = model.sv.len();
45
46 let kvalue: Vec<f64> = model
48 .sv
49 .iter()
50 .map(|sv| k_function(x, sv, &model.param))
51 .collect();
52
53 let mut start = vec![0usize; nr_class];
55 for i in 1..nr_class {
56 start[i] = start[i - 1] + model.n_sv[i - 1];
57 }
58
59 let mut vote = vec![0usize; nr_class];
61 let mut p = 0;
62 for i in 0..nr_class {
63 for j in (i + 1)..nr_class {
64 let mut sum = 0.0;
65 let si = start[i];
66 let sj = start[j];
67 let ci = model.n_sv[i];
68 let cj = model.n_sv[j];
69
70 let coef1 = &model.sv_coef[j - 1];
71 let coef2 = &model.sv_coef[i];
72
73 for k in 0..ci {
74 sum += coef1[si + k] * kvalue[si + k];
75 }
76 for k in 0..cj {
77 sum += coef2[sj + k] * kvalue[sj + k];
78 }
79 sum -= model.rho[p];
80 dec_values[p] = sum;
81
82 if sum > 0.0 {
83 vote[i] += 1;
84 } else {
85 vote[j] += 1;
86 }
87 p += 1;
88 }
89 }
90
91 let vote_max_idx = vote
93 .iter()
94 .enumerate()
95 .max_by_key(|&(_, &v)| v)
96 .map(|(i, _)| i)
97 .unwrap_or(0);
98
99 let _ = l; model.label[vote_max_idx] as f64
101 }
102 }
103}
104
105pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
110 let n = match model.param.svm_type {
111 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
112 SvmType::CSvc | SvmType::NuSvc => model.nr_class * (model.nr_class - 1) / 2,
113 };
114 let mut dec_values = vec![0.0; n];
115 predict_values(model, x, &mut dec_values)
116}
117
118#[allow(clippy::needless_range_loop)]
133pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
134 match model.param.svm_type {
135 SvmType::CSvc | SvmType::NuSvc if !model.prob_a.is_empty() && !model.prob_b.is_empty() => {
136 let nr_class = model.nr_class;
137 let n_pairs = nr_class * (nr_class - 1) / 2;
138 let mut dec_values = vec![0.0; n_pairs];
139 predict_values(model, x, &mut dec_values);
140
141 let min_prob = 1e-7;
142
143 let mut pairwise = vec![vec![0.0; nr_class]; nr_class];
145 let mut k = 0;
146 for i in 0..nr_class {
147 for j in (i + 1)..nr_class {
148 let p = sigmoid_predict(dec_values[k], model.prob_a[k], model.prob_b[k])
149 .max(min_prob)
150 .min(1.0 - min_prob);
151 pairwise[i][j] = p;
152 pairwise[j][i] = 1.0 - p;
153 k += 1;
154 }
155 }
156
157 let mut prob_estimates = vec![0.0; nr_class];
158 if nr_class == 2 {
159 prob_estimates[0] = pairwise[0][1];
160 prob_estimates[1] = pairwise[1][0];
161 } else {
162 multiclass_probability(nr_class, &pairwise, &mut prob_estimates);
163 }
164
165 let mut best = 0;
167 for i in 1..nr_class {
168 if prob_estimates[i] > prob_estimates[best] {
169 best = i;
170 }
171 }
172
173 Some((model.label[best] as f64, prob_estimates))
174 }
175
176 SvmType::OneClass if !model.prob_density_marks.is_empty() => {
177 let mut dec_value = [0.0];
178 let pred_result = predict_values(model, x, &mut dec_value);
179 let p = predict_one_class_probability(&model.prob_density_marks, dec_value[0]);
180 Some((pred_result, vec![p, 1.0 - p]))
181 }
182
183 _ => None,
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    /// Path to the repository-level `data/` fixture directory, two levels above
    /// this crate's manifest.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// Checks the accuracy of the pre-trained heart_scale model on its own data.
    #[test]
    fn predict_heart_scale() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    /// A two-class model writes exactly one decision value, and `predict_values`
    /// agrees with `predict` on the returned label.
    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        assert!(dec_values[0].abs() > 1e-10);
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let (label, probs) =
                predict_probability(&model, instance).expect("should return probability");
            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!((0.0..=1.0).contains(&p), "prob {} out of [0,1]", p);
            }
        }
    }

    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        // Three classes give 3 * 2 / 2 = 3 pairwise sigmoid parameter sets.
        assert_eq!(model.prob_a.len(), 3);

        for instance in problem.instances.iter().take(10) {
            let (_label, probs) =
                predict_probability(&model, instance).expect("should return probability");
            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    /// Best-effort cross-check against the reference C `svm-predict` binary.
    /// The check is skipped when that binary has not been built.
    #[test]
    fn predict_matches_c_svm_predict() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");

        if !c_predict.exists() {
            eprintln!("skipping: {} not found", c_predict.display());
            return;
        }

        // Write to the temp dir, not the shared fixture directory. The PID suffix
        // keeps concurrent test runs from clobbering each other's output.
        let output_path = std::env::temp_dir()
            .join(format!("heart_scale.predict_test.{}", std::process::id()));

        let run = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output();

        let output = match run {
            Ok(output) => output,
            Err(err) => {
                eprintln!("skipping: failed to run {}: {}", c_predict.display(), err);
                return;
            }
        };
        if !output.status.success() {
            eprintln!("skipping: svm-predict exited with {}", output.status);
            let _ = std::fs::remove_file(&output_path);
            return;
        }

        let contents = std::fs::read_to_string(&output_path);
        // Clean up before asserting so a failing comparison leaves no file behind.
        let _ = std::fs::remove_file(&output_path);

        let c_preds: Vec<f64> = contents
            .expect("svm-predict reported success but its output file is unreadable")
            .lines()
            .map(str::trim)
            .filter(|l| !l.is_empty())
            .map(|l| l.parse().expect("svm-predict output should be numeric"))
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mut mismatches = 0;
        for (instance, &c_pred) in problem.instances.iter().zip(&c_preds) {
            if predict(&model, instance) != c_pred {
                mismatches += 1;
            }
        }

        assert_eq!(
            mismatches, 0,
            "{} predictions differ between Rust and C",
            mismatches
        );
    }
}