Skip to main content

libsvm_rs/
predict.rs

1//! Prediction functions matching the original LIBSVM.
2//!
3//! Provides `predict`, `predict_values`, and `predict_probability` for
4//! all SVM types:
5//! - Classification (C-SVC, ν-SVC): one-vs-one voting + optional probability
6//! - One-class SVM: sign of decision value + optional density probability
7//! - Regression (ε-SVR, ν-SVR): continuous output
8
9use crate::kernel::k_function;
10use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
11use crate::types::{SvmModel, SvmNode, SvmType};
12
13/// Compute decision values and return the predicted label/value.
14///
15/// For classification, `dec_values` receives `nr_class * (nr_class - 1) / 2`
16/// pairwise decision values. For regression/one-class, a single value.
17///
18/// Returns the predicted label (classification) or function value (regression).
19///
20/// Matches LIBSVM's `svm_predict_values`.
21pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
22    match model.param.svm_type {
23        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
24            let sv_coef = &model.sv_coef[0];
25            let mut sum = 0.0;
26            for (i, sv) in model.sv.iter().enumerate() {
27                sum += sv_coef[i] * k_function(x, sv, &model.param);
28            }
29            sum -= model.rho[0];
30            dec_values[0] = sum;
31
32            if model.param.svm_type == SvmType::OneClass {
33                if sum > 0.0 {
34                    1.0
35                } else {
36                    -1.0
37                }
38            } else {
39                sum
40            }
41        }
42        SvmType::CSvc | SvmType::NuSvc => {
43            let nr_class = model.nr_class;
44            let l = model.sv.len();
45
46            // Compute kernel values for all SVs
47            let kvalue: Vec<f64> = model
48                .sv
49                .iter()
50                .map(|sv| k_function(x, sv, &model.param))
51                .collect();
52
53            // Compute start indices for each class's SVs
54            let mut start = vec![0usize; nr_class];
55            for i in 1..nr_class {
56                start[i] = start[i - 1] + model.n_sv[i - 1];
57            }
58
59            // One-vs-one voting
60            let mut vote = vec![0usize; nr_class];
61            let mut p = 0;
62            for i in 0..nr_class {
63                for j in (i + 1)..nr_class {
64                    let mut sum = 0.0;
65                    let si = start[i];
66                    let sj = start[j];
67                    let ci = model.n_sv[i];
68                    let cj = model.n_sv[j];
69
70                    let coef1 = &model.sv_coef[j - 1];
71                    let coef2 = &model.sv_coef[i];
72
73                    for k in 0..ci {
74                        sum += coef1[si + k] * kvalue[si + k];
75                    }
76                    for k in 0..cj {
77                        sum += coef2[sj + k] * kvalue[sj + k];
78                    }
79                    sum -= model.rho[p];
80                    dec_values[p] = sum;
81
82                    if sum > 0.0 {
83                        vote[i] += 1;
84                    } else {
85                        vote[j] += 1;
86                    }
87                    p += 1;
88                }
89            }
90
91            // Find class with most votes
92            let vote_max_idx = vote
93                .iter()
94                .enumerate()
95                .max_by_key(|&(_, &v)| v)
96                .map(|(i, _)| i)
97                .unwrap_or(0);
98
99            let _ = l; // suppress unused warning
100            model.label[vote_max_idx] as f64
101        }
102    }
103}
104
105/// Predict the label/value for a single instance.
106///
107/// Convenience wrapper around `predict_values` that allocates the
108/// decision values buffer internally. Matches LIBSVM's `svm_predict`.
109pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
110    let n = match model.param.svm_type {
111        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
112        SvmType::CSvc | SvmType::NuSvc => model.nr_class * (model.nr_class - 1) / 2,
113    };
114    let mut dec_values = vec![0.0; n];
115    predict_values(model, x, &mut dec_values)
116}
117
118/// Predict with probability estimates.
119///
120/// Returns `Some((label, probs))` where `probs[i]` is the estimated
121/// probability of class `model.label[i]`. Returns `None` when the
122/// model was not trained with probability support.
123///
124/// - **C-SVC / ν-SVC**: requires `model.prob_a` and `model.prob_b`.
125///   Uses Platt scaling on pairwise decision values, then
126///   `multiclass_probability` for k > 2.
127/// - **One-class**: requires `model.prob_density_marks`.
128///   Returns `[p, 1-p]` via density-mark lookup.
129/// - **SVR**: probability prediction is not supported (returns `None`).
130///
131/// Matches LIBSVM's `svm_predict_probability`.
132#[allow(clippy::needless_range_loop)]
133pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
134    match model.param.svm_type {
135        SvmType::CSvc | SvmType::NuSvc if !model.prob_a.is_empty() && !model.prob_b.is_empty() => {
136            let nr_class = model.nr_class;
137            let n_pairs = nr_class * (nr_class - 1) / 2;
138            let mut dec_values = vec![0.0; n_pairs];
139            predict_values(model, x, &mut dec_values);
140
141            let min_prob = 1e-7;
142
143            // Build pairwise probability matrix
144            let mut pairwise = vec![vec![0.0; nr_class]; nr_class];
145            let mut k = 0;
146            for i in 0..nr_class {
147                for j in (i + 1)..nr_class {
148                    let p = sigmoid_predict(dec_values[k], model.prob_a[k], model.prob_b[k])
149                        .max(min_prob)
150                        .min(1.0 - min_prob);
151                    pairwise[i][j] = p;
152                    pairwise[j][i] = 1.0 - p;
153                    k += 1;
154                }
155            }
156
157            let mut prob_estimates = vec![0.0; nr_class];
158            if nr_class == 2 {
159                prob_estimates[0] = pairwise[0][1];
160                prob_estimates[1] = pairwise[1][0];
161            } else {
162                multiclass_probability(nr_class, &pairwise, &mut prob_estimates);
163            }
164
165            // Find class with highest probability
166            let mut best = 0;
167            for i in 1..nr_class {
168                if prob_estimates[i] > prob_estimates[best] {
169                    best = i;
170                }
171            }
172
173            Some((model.label[best] as f64, prob_estimates))
174        }
175
176        SvmType::OneClass if !model.prob_density_marks.is_empty() => {
177            let mut dec_value = [0.0];
178            let pred_result = predict_values(model, x, &mut dec_value);
179            let p = predict_one_class_probability(&model.prob_density_marks, dec_value[0]);
180            Some((pred_result, vec![p, 1.0 - p]))
181        }
182
183        _ => None,
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    /// Directory holding the shared fixture files: `data/`, two levels above
    /// this crate's manifest directory.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// A model trained by C LIBSVM must classify its own training set with
    /// better than 85% accuracy.
    #[test]
    fn predict_heart_scale() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let correct = problem
            .instances
            .iter()
            .enumerate()
            .filter(|&(i, instance)| predict(&model, instance) == problem.labels[i])
            .count();

        let accuracy = correct as f64 / problem.labels.len() as f64;
        // C LIBSVM gets ~86.67% accuracy on training set with default params
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    /// A binary model yields exactly one decision value, and `predict_values`
    /// must agree with `predict` on the label.
    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let first = &problem.instances[0];

        // For binary classification, there's exactly 1 decision value
        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, first, &mut dec_values);

        // Decision value should be non-zero
        assert!(dec_values[0].abs() > 1e-10);
        // Label should match what predict returns
        assert_eq!(label, predict(&model, first));
    }

    /// Binary probability output: two estimates in `[0, 1]` summing to 1,
    /// with a label of `+1` or `-1`.
    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let result = predict_probability(&model, instance);
            assert!(result.is_some(), "should return probability");
            let (label, probs) = result.unwrap();
            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!((0.0..=1.0).contains(&p), "prob {} out of [0,1]", p);
            }
        }
    }

    /// Three-class probability output: one estimate per class, summing to 1.
    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.prob_a.len(), 3); // 3 pairs

        for instance in problem.instances.iter().take(10) {
            let result = predict_probability(&model, instance);
            assert!(result.is_some());
            let (_label, probs) = result.unwrap();
            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    /// Cross-check against the reference C `svm-predict` binary.
    ///
    /// The test is skipped (passes vacuously) when the binary has not been
    /// built, cannot be launched, or exits unsuccessfully.
    #[test]
    fn predict_matches_c_svm_predict() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");

        if !c_predict.exists() {
            // Skip if C binary not compiled
            return;
        }

        // The C tool writes one prediction per line to this scratch file.
        let output_path = data_dir().join("heart_scale.predict_test");
        // Paths are passed as `PathBuf`s so non-UTF-8 paths do not panic.
        let run = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output();
        let ran_ok = matches!(&run, Ok(output) if output.status.success());

        // Read the scratch file and delete it BEFORE any assertion or unwrap
        // can panic, so a failing run never leaves it in the data directory.
        let contents = ran_ok.then(|| std::fs::read_to_string(&output_path));
        let _ = std::fs::remove_file(&output_path);
        let contents = match contents {
            Some(read) => read.unwrap(),
            None => return,
        };

        let c_preds: Vec<f64> = contents
            .lines()
            .filter(|l| !l.is_empty())
            .map(|l| l.trim().parse().unwrap())
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mismatches = problem
            .instances
            .iter()
            .enumerate()
            .filter(|&(i, instance)| predict(&model, instance) != c_preds[i])
            .count();

        assert_eq!(
            mismatches, 0,
            "{} predictions differ between Rust and C",
            mismatches
        );
    }
}