Skip to main content

libsvm_rs/
predict.rs

1//! Prediction functions matching the original LIBSVM.
2//!
3//! Provides `predict`, `predict_values`, and `predict_probability` for
4//! all SVM types:
5//! - Classification (C-SVC, ν-SVC): one-vs-one voting + optional probability
6//! - One-class SVM: sign of decision value + optional density probability
7//! - Regression (ε-SVR, ν-SVR): continuous output
8
9use crate::kernel::k_function;
10use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
11use crate::types::{SvmModel, SvmNode, SvmType};
12
13/// Compute decision values and return the predicted label/value.
14///
15/// For classification, `dec_values` receives `nr_class * (nr_class - 1) / 2`
16/// pairwise decision values. For regression/one-class, a single value.
17///
18/// Returns the predicted label (classification) or function value (regression).
19///
20/// Matches LIBSVM's `svm_predict_values`.
21pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
22    match model.param.svm_type {
23        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
24            let sv_coef = &model.sv_coef[0];
25            let mut sum = 0.0;
26            for (i, sv) in model.sv.iter().enumerate() {
27                sum += sv_coef[i] * k_function(x, sv, &model.param);
28            }
29            sum -= model.rho[0];
30            dec_values[0] = sum;
31
32            if model.param.svm_type == SvmType::OneClass {
33                if sum > 0.0 { 1.0 } else { -1.0 }
34            } else {
35                sum
36            }
37        }
38        SvmType::CSvc | SvmType::NuSvc => {
39            let nr_class = model.nr_class;
40            let l = model.sv.len();
41
42            // Compute kernel values for all SVs
43            let kvalue: Vec<f64> = model
44                .sv
45                .iter()
46                .map(|sv| k_function(x, sv, &model.param))
47                .collect();
48
49            // Compute start indices for each class's SVs
50            let mut start = vec![0usize; nr_class];
51            for i in 1..nr_class {
52                start[i] = start[i - 1] + model.n_sv[i - 1];
53            }
54
55            // One-vs-one voting
56            let mut vote = vec![0usize; nr_class];
57            let mut p = 0;
58            for i in 0..nr_class {
59                for j in (i + 1)..nr_class {
60                    let mut sum = 0.0;
61                    let si = start[i];
62                    let sj = start[j];
63                    let ci = model.n_sv[i];
64                    let cj = model.n_sv[j];
65
66                    let coef1 = &model.sv_coef[j - 1];
67                    let coef2 = &model.sv_coef[i];
68
69                    for k in 0..ci {
70                        sum += coef1[si + k] * kvalue[si + k];
71                    }
72                    for k in 0..cj {
73                        sum += coef2[sj + k] * kvalue[sj + k];
74                    }
75                    sum -= model.rho[p];
76                    dec_values[p] = sum;
77
78                    if sum > 0.0 {
79                        vote[i] += 1;
80                    } else {
81                        vote[j] += 1;
82                    }
83                    p += 1;
84                }
85            }
86
87            // Find class with most votes
88            let vote_max_idx = vote
89                .iter()
90                .enumerate()
91                .max_by_key(|&(_, &v)| v)
92                .map(|(i, _)| i)
93                .unwrap_or(0);
94
95            let _ = l; // suppress unused warning
96            model.label[vote_max_idx] as f64
97        }
98    }
99}
100
101/// Predict the label/value for a single instance.
102///
103/// Convenience wrapper around `predict_values` that allocates the
104/// decision values buffer internally. Matches LIBSVM's `svm_predict`.
105pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
106    let n = match model.param.svm_type {
107        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
108        SvmType::CSvc | SvmType::NuSvc => {
109            model.nr_class * (model.nr_class - 1) / 2
110        }
111    };
112    let mut dec_values = vec![0.0; n];
113    predict_values(model, x, &mut dec_values)
114}
115
116/// Predict with probability estimates.
117///
118/// Returns `Some((label, probs))` where `probs[i]` is the estimated
119/// probability of class `model.label[i]`. Returns `None` when the
120/// model was not trained with probability support.
121///
122/// - **C-SVC / ν-SVC**: requires `model.prob_a` and `model.prob_b`.
123///   Uses Platt scaling on pairwise decision values, then
124///   `multiclass_probability` for k > 2.
125/// - **One-class**: requires `model.prob_density_marks`.
126///   Returns `[p, 1-p]` via density-mark lookup.
127/// - **SVR**: probability prediction is not supported (returns `None`).
128///
129/// Matches LIBSVM's `svm_predict_probability`.
130pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
131    match model.param.svm_type {
132        SvmType::CSvc | SvmType::NuSvc
133            if !model.prob_a.is_empty() && !model.prob_b.is_empty() =>
134        {
135            let nr_class = model.nr_class;
136            let n_pairs = nr_class * (nr_class - 1) / 2;
137            let mut dec_values = vec![0.0; n_pairs];
138            predict_values(model, x, &mut dec_values);
139
140            let min_prob = 1e-7;
141
142            // Build pairwise probability matrix
143            let mut pairwise = vec![vec![0.0; nr_class]; nr_class];
144            let mut k = 0;
145            for i in 0..nr_class {
146                for j in (i + 1)..nr_class {
147                    let p = sigmoid_predict(dec_values[k], model.prob_a[k], model.prob_b[k])
148                        .max(min_prob)
149                        .min(1.0 - min_prob);
150                    pairwise[i][j] = p;
151                    pairwise[j][i] = 1.0 - p;
152                    k += 1;
153                }
154            }
155
156            let mut prob_estimates = vec![0.0; nr_class];
157            if nr_class == 2 {
158                prob_estimates[0] = pairwise[0][1];
159                prob_estimates[1] = pairwise[1][0];
160            } else {
161                multiclass_probability(nr_class, &pairwise, &mut prob_estimates);
162            }
163
164            // Find class with highest probability
165            let mut best = 0;
166            for i in 1..nr_class {
167                if prob_estimates[i] > prob_estimates[best] {
168                    best = i;
169                }
170            }
171
172            Some((model.label[best] as f64, prob_estimates))
173        }
174
175        SvmType::OneClass if !model.prob_density_marks.is_empty() => {
176            let mut dec_value = [0.0];
177            let pred_result = predict_values(model, x, &mut dec_value);
178            let p = predict_one_class_probability(&model.prob_density_marks, dec_value[0]);
179            Some((pred_result, vec![p, 1.0 - p]))
180        }
181
182        _ => None,
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{load_model, load_problem};
    use crate::types::{KernelType, SvmParameter};
    use std::path::PathBuf;

    /// Directory two levels above the crate manifest; it holds `data/` and
    /// `vendor/`.
    fn workspace_root() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("..").join("..")
    }

    /// Directory with the reference datasets and pre-trained models.
    fn data_dir() -> PathBuf {
        workspace_root().join("data")
    }

    /// C-SVC / RBF parameters with probability estimates enabled; only
    /// `gamma` differs between the probability tests.
    fn probability_params(gamma: f64) -> SvmParameter {
        SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        }
    }

    #[test]
    fn predict_heart_scale() {
        // Load model trained by C LIBSVM and predict on training data
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        assert_eq!(problem.instances.len(), problem.labels.len());

        let mut correct = 0;
        for (instance, &label) in problem.instances.iter().zip(&problem.labels) {
            if predict(&model, instance) == label {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / problem.labels.len() as f64;
        // C LIBSVM gets ~86.67% accuracy on training set with default params
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        // For binary classification, there's exactly 1 decision value
        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        // Decision value should be non-zero
        assert!(dec_values[0].abs() > 1e-10);
        // Label should match what predict returns
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = probability_params(1.0 / 13.0);

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let (label, probs) =
                predict_probability(&model, instance).expect("should return probability");
            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!((0.0..=1.0).contains(&p), "prob {} out of [0,1]", p);
            }
        }
    }

    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = probability_params(0.25);

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.prob_a.len(), 3); // 3 pairs

        for instance in problem.instances.iter().take(10) {
            let (_label, probs) =
                predict_probability(&model, instance).expect("should return probability");
            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    #[test]
    fn predict_matches_c_svm_predict() {
        // Compare our predictions against the reference C `svm-predict`
        // binary, when it has been built.
        let c_predict = workspace_root()
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");
        if !c_predict.exists() {
            // Skip if C binary not compiled
            return;
        }

        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        // Scratch output goes to the OS temp dir under a per-process name so
        // the checked-in data directory is never written to.
        let output_path = std::env::temp_dir().join(format!(
            "libsvm_rs_heart_scale_{}.predict_test",
            std::process::id()
        ));

        // Paths are passed as `OsStr`, so non-UTF-8 directories do not panic.
        let run = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output();

        let ran_ok = matches!(&run, Ok(output) if output.status.success());
        if !ran_ok {
            // Best effort: a reference binary that cannot run here is treated
            // like a missing one, but leave a trace in the test output.
            let _ = std::fs::remove_file(&output_path);
            eprintln!("skipping comparison: {} did not run successfully", c_predict.display());
            return;
        }

        // Read and delete the scratch file before any assertion can fail, so
        // it is cleaned up on every path.
        let contents = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);
        let contents = contents.expect("svm-predict succeeded but its output is unreadable");

        let c_preds: Vec<f64> = contents
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(|line| line.parse().expect("svm-predict output line is not a number"))
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mismatched: Vec<usize> = problem
            .instances
            .iter()
            .zip(&c_preds)
            .enumerate()
            .filter(|(_, (instance, &c_pred))| predict(&model, instance.as_slice()) != c_pred)
            .map(|(i, _)| i)
            .collect();

        assert!(
            mismatched.is_empty(),
            "{} predictions differ between Rust and C (instances {:?})",
            mismatched.len(),
            mismatched
        );
    }
}