1use crate::kernel::k_function;
10use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
11use crate::types::{SvmModel, SvmNode, SvmType};
12
13pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
22 match model.param.svm_type {
23 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
24 let sv_coef = &model.sv_coef[0];
25 let mut sum = 0.0;
26 for (i, sv) in model.sv.iter().enumerate() {
27 sum += sv_coef[i] * k_function(x, sv, &model.param);
28 }
29 sum -= model.rho[0];
30 dec_values[0] = sum;
31
32 if model.param.svm_type == SvmType::OneClass {
33 if sum > 0.0 { 1.0 } else { -1.0 }
34 } else {
35 sum
36 }
37 }
38 SvmType::CSvc | SvmType::NuSvc => {
39 let nr_class = model.nr_class;
40 let l = model.sv.len();
41
42 let kvalue: Vec<f64> = model
44 .sv
45 .iter()
46 .map(|sv| k_function(x, sv, &model.param))
47 .collect();
48
49 let mut start = vec![0usize; nr_class];
51 for i in 1..nr_class {
52 start[i] = start[i - 1] + model.n_sv[i - 1];
53 }
54
55 let mut vote = vec![0usize; nr_class];
57 let mut p = 0;
58 for i in 0..nr_class {
59 for j in (i + 1)..nr_class {
60 let mut sum = 0.0;
61 let si = start[i];
62 let sj = start[j];
63 let ci = model.n_sv[i];
64 let cj = model.n_sv[j];
65
66 let coef1 = &model.sv_coef[j - 1];
67 let coef2 = &model.sv_coef[i];
68
69 for k in 0..ci {
70 sum += coef1[si + k] * kvalue[si + k];
71 }
72 for k in 0..cj {
73 sum += coef2[sj + k] * kvalue[sj + k];
74 }
75 sum -= model.rho[p];
76 dec_values[p] = sum;
77
78 if sum > 0.0 {
79 vote[i] += 1;
80 } else {
81 vote[j] += 1;
82 }
83 p += 1;
84 }
85 }
86
87 let vote_max_idx = vote
89 .iter()
90 .enumerate()
91 .max_by_key(|&(_, &v)| v)
92 .map(|(i, _)| i)
93 .unwrap_or(0);
94
95 let _ = l; model.label[vote_max_idx] as f64
97 }
98 }
99}
100
101pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
106 let n = match model.param.svm_type {
107 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
108 SvmType::CSvc | SvmType::NuSvc => {
109 model.nr_class * (model.nr_class - 1) / 2
110 }
111 };
112 let mut dec_values = vec![0.0; n];
113 predict_values(model, x, &mut dec_values)
114}
115
116pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
131 match model.param.svm_type {
132 SvmType::CSvc | SvmType::NuSvc
133 if !model.prob_a.is_empty() && !model.prob_b.is_empty() =>
134 {
135 let nr_class = model.nr_class;
136 let n_pairs = nr_class * (nr_class - 1) / 2;
137 let mut dec_values = vec![0.0; n_pairs];
138 predict_values(model, x, &mut dec_values);
139
140 let min_prob = 1e-7;
141
142 let mut pairwise = vec![vec![0.0; nr_class]; nr_class];
144 let mut k = 0;
145 for i in 0..nr_class {
146 for j in (i + 1)..nr_class {
147 let p = sigmoid_predict(dec_values[k], model.prob_a[k], model.prob_b[k])
148 .max(min_prob)
149 .min(1.0 - min_prob);
150 pairwise[i][j] = p;
151 pairwise[j][i] = 1.0 - p;
152 k += 1;
153 }
154 }
155
156 let mut prob_estimates = vec![0.0; nr_class];
157 if nr_class == 2 {
158 prob_estimates[0] = pairwise[0][1];
159 prob_estimates[1] = pairwise[1][0];
160 } else {
161 multiclass_probability(nr_class, &pairwise, &mut prob_estimates);
162 }
163
164 let mut best = 0;
166 for i in 1..nr_class {
167 if prob_estimates[i] > prob_estimates[best] {
168 best = i;
169 }
170 }
171
172 Some((model.label[best] as f64, prob_estimates))
173 }
174
175 SvmType::OneClass if !model.prob_density_marks.is_empty() => {
176 let mut dec_value = [0.0];
177 let pred_result = predict_values(model, x, &mut dec_value);
178 let p = predict_one_class_probability(&model.prob_density_marks, dec_value[0]);
179 Some((pred_result, vec![p, 1.0 - p]))
180 }
181
182 _ => None,
183 }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    /// Directory holding the fixture files: `<CARGO_MANIFEST_DIR>/../../data`.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// The stored heart_scale model must classify its own data set with >85% accuracy.
    #[test]
    fn predict_heart_scale() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    /// `predict_values` fills the single binary decision value and agrees with `predict`.
    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        // The decision value must actually have been written.
        assert!(dec_values[0].abs() > 1e-10);
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    /// Binary probability output: a valid label plus two probabilities summing to 1.
    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let result = predict_probability(&model, instance);
            assert!(result.is_some(), "should return probability");
            let (label, probs) = result.unwrap();
            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!(p >= 0.0 && p <= 1.0, "prob {} out of [0,1]", p);
            }
        }
    }

    /// Three-class probability output: three probabilities summing to 1.
    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        // Three classes give three class pairs, hence three sigmoid parameters.
        assert_eq!(model.prob_a.len(), 3);

        for instance in problem.instances.iter().take(10) {
            let result = predict_probability(&model, instance);
            assert!(result.is_some());
            let (_label, probs) = result.unwrap();
            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    /// Cross-checks every prediction against the vendored C `svm-predict` binary.
    ///
    /// The test is a silent no-op when the binary is missing or does not run
    /// successfully, so it never fails on machines without the C build.
    #[test]
    fn predict_matches_c_svm_predict() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");

        if !c_predict.exists() {
            return;
        }

        let output_path = data_dir().join("heart_scale.predict_test");
        let run = std::process::Command::new(&c_predict)
            .args([
                data_dir().join("heart_scale").to_str().unwrap(),
                data_dir().join("heart_scale.model").to_str().unwrap(),
                output_path.to_str().unwrap(),
            ])
            .output();
        let ran_ok = run.map(|output| output.status.success()).unwrap_or(false);

        // Read and delete the scratch file before asserting anything, so a failing
        // assertion or a failed run cannot leave it behind in the data directory.
        let c_output = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);

        if !ran_ok {
            return;
        }

        let c_preds: Vec<f64> = c_output
            .unwrap()
            .lines()
            .filter(|l| !l.is_empty())
            .map(|l| l.trim().parse().unwrap())
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mut mismatches = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let rust_pred = predict(&model, instance);
            if rust_pred != c_preds[i] {
                mismatches += 1;
            }
        }

        assert_eq!(
            mismatches, 0,
            "{} predictions differ between Rust and C",
            mismatches
        );
    }
}