Skip to main content

libsvm_rs/
predict.rs

//! Prediction functions matching the original LIBSVM.
//!
//! Provides `predict`, `predict_values`, and `predict_probability` for
//! all SVM types:
//! - Classification (C-SVC, ν-SVC): one-vs-one voting + optional probability
//! - One-class SVM: sign of decision value + optional density probability
//! - Regression (ε-SVR, ν-SVR): continuous output

9use crate::kernel::k_function;
10use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
11use crate::types::{SvmModel, SvmNode, SvmType};
12
13/// Compute decision values and return the predicted label/value.
14///
15/// For classification, `dec_values` receives `nr_class * (nr_class - 1) / 2`
16/// pairwise decision values. For regression/one-class, a single value.
17///
18/// Returns the predicted label (classification) or function value (regression).
19///
20/// Matches LIBSVM's `svm_predict_values`.
21pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
22    match model.param.svm_type {
23        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
24            let sv_coef = &model.sv_coef[0];
25            let mut sum = 0.0;
26            for (i, sv) in model.sv.iter().enumerate() {
27                sum += sv_coef[i] * k_function(x, sv, &model.param);
28            }
29            sum -= model.rho[0];
30            dec_values[0] = sum;
31
32            if model.param.svm_type == SvmType::OneClass {
33                if sum > 0.0 {
34                    1.0
35                } else {
36                    -1.0
37                }
38            } else {
39                sum
40            }
41        }
42        SvmType::CSvc | SvmType::NuSvc => {
43            let nr_class = model.nr_class;
44            let l = model.sv.len();
45
46            // Compute kernel values for all SVs
47            let kvalue: Vec<f64> = model
48                .sv
49                .iter()
50                .map(|sv| k_function(x, sv, &model.param))
51                .collect();
52
53            // Compute start indices for each class's SVs
54            let mut start = vec![0usize; nr_class];
55            for i in 1..nr_class {
56                start[i] = start[i - 1] + model.n_sv[i - 1];
57            }
58
59            // One-vs-one voting
60            let mut vote = vec![0usize; nr_class];
61            let mut p = 0;
62            for i in 0..nr_class {
63                for j in (i + 1)..nr_class {
64                    let mut sum = 0.0;
65                    let si = start[i];
66                    let sj = start[j];
67                    let ci = model.n_sv[i];
68                    let cj = model.n_sv[j];
69
70                    let coef1 = &model.sv_coef[j - 1];
71                    let coef2 = &model.sv_coef[i];
72
73                    for k in 0..ci {
74                        sum += coef1[si + k] * kvalue[si + k];
75                    }
76                    for k in 0..cj {
77                        sum += coef2[sj + k] * kvalue[sj + k];
78                    }
79                    sum -= model.rho[p];
80                    dec_values[p] = sum;
81
82                    if sum > 0.0 {
83                        vote[i] += 1;
84                    } else {
85                        vote[j] += 1;
86                    }
87                    p += 1;
88                }
89            }
90
91            // Find class with most votes
92            let vote_max_idx = vote
93                .iter()
94                .enumerate()
95                .max_by_key(|&(_, &v)| v)
96                .map(|(i, _)| i)
97                .unwrap_or(0);
98
99            let _ = l; // suppress unused warning
100            model.label[vote_max_idx] as f64
101        }
102    }
103}
104
105/// Predict the label/value for a single instance.
106///
107/// Convenience wrapper around `predict_values` that allocates the
108/// decision values buffer internally. Matches LIBSVM's `svm_predict`.
109///
110/// ```
111/// use libsvm_rs::predict::predict;
112/// use libsvm_rs::train::svm_train;
113/// use libsvm_rs::{KernelType, SvmNode, SvmParameterBuilder, SvmProblem, SvmType};
114///
115/// let problem = SvmProblem {
116///     labels: vec![-1.0, -1.0, 1.0, 1.0],
117///     instances: vec![
118///         vec![SvmNode { index: 1, value: -2.0 }],
119///         vec![SvmNode { index: 1, value: -1.0 }],
120///         vec![SvmNode { index: 1, value: 1.0 }],
121///         vec![SvmNode { index: 1, value: 2.0 }],
122///     ],
123/// };
124/// let param = SvmParameterBuilder::new()
125///     .svm_type(SvmType::CSvc)
126///     .kernel_type(KernelType::Linear)
127///     .build()?;
128/// let model = svm_train(&problem, &param);
129///
130/// assert_eq!(predict(&model, &[SvmNode { index: 1, value: 1.5 }]), 1.0);
131/// # Ok::<(), libsvm_rs::SvmError>(())
132/// ```
133pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
134    let n = match model.param.svm_type {
135        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
136        SvmType::CSvc | SvmType::NuSvc => model.nr_class * (model.nr_class - 1) / 2,
137    };
138    let mut dec_values = vec![0.0; n];
139    predict_values(model, x, &mut dec_values)
140}
141
142/// Predict with probability estimates.
143///
144/// Returns `Some((label, probs))` where `probs[i]` is the estimated
145/// probability of class `model.label[i]`. Returns `None` when the
146/// model was not trained with probability support.
147///
148/// - **C-SVC / ν-SVC**: requires `model.prob_a` and `model.prob_b`.
149///   Uses Platt scaling on pairwise decision values, then
150///   `multiclass_probability` for k > 2.
151/// - **One-class**: requires `model.prob_density_marks`.
152///   Returns `[p, 1-p]` via density-mark lookup.
153/// - **SVR**: probability prediction is not supported (returns `None`).
154///
155/// Matches LIBSVM's `svm_predict_probability`.
156#[allow(clippy::needless_range_loop)]
157pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
158    match model.param.svm_type {
159        SvmType::CSvc | SvmType::NuSvc if !model.prob_a.is_empty() && !model.prob_b.is_empty() => {
160            let nr_class = model.nr_class;
161            let n_pairs = nr_class * (nr_class - 1) / 2;
162            let mut dec_values = vec![0.0; n_pairs];
163            predict_values(model, x, &mut dec_values);
164
165            let min_prob = 1e-7;
166
167            // Build pairwise probability matrix
168            let mut pairwise = vec![vec![0.0; nr_class]; nr_class];
169            let mut k = 0;
170            for i in 0..nr_class {
171                for j in (i + 1)..nr_class {
172                    let p = sigmoid_predict(dec_values[k], model.prob_a[k], model.prob_b[k])
173                        .max(min_prob)
174                        .min(1.0 - min_prob);
175                    pairwise[i][j] = p;
176                    pairwise[j][i] = 1.0 - p;
177                    k += 1;
178                }
179            }
180
181            let mut prob_estimates = vec![0.0; nr_class];
182            if nr_class == 2 {
183                prob_estimates[0] = pairwise[0][1];
184                prob_estimates[1] = pairwise[1][0];
185            } else {
186                multiclass_probability(nr_class, &pairwise, &mut prob_estimates);
187            }
188
189            // Find class with highest probability
190            let mut best = 0;
191            for i in 1..nr_class {
192                if prob_estimates[i] > prob_estimates[best] {
193                    best = i;
194                }
195            }
196
197            Some((model.label[best] as f64, prob_estimates))
198        }
199
200        SvmType::OneClass if !model.prob_density_marks.is_empty() => {
201            let mut dec_value = [0.0];
202            let pred_result = predict_values(model, x, &mut dec_value);
203            let p = predict_one_class_probability(&model.prob_density_marks, dec_value[0]);
204            Some((pred_result, vec![p, 1.0 - p]))
205        }
206
207        _ => None,
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    /// Directory holding the shared test fixtures (`heart_scale`, `iris.scale`, ...).
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    #[test]
    fn predict_heart_scale() {
        // Load model trained by C LIBSVM and predict on training data
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / problem.labels.len() as f64;
        // C LIBSVM gets ~86.67% accuracy on training set with default params
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        // For binary classification, there's exactly 1 decision value
        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        // Decision value should be non-zero
        assert!(dec_values[0].abs() > 1e-10);
        // Label should match what predict returns
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let (label, probs) =
                predict_probability(&model, instance).expect("should return probability");
            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!((0.0..=1.0).contains(&p), "prob {} out of [0,1]", p);
            }
        }
    }

    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.prob_a.len(), 3); // 3 pairs

        for instance in problem.instances.iter().take(10) {
            let (_label, probs) =
                predict_probability(&model, instance).expect("should return probability");
            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    #[test]
    fn predict_matches_c_svm_predict() {
        // Compare every Rust prediction on the training set with the output
        // of the reference C `svm-predict` binary (skipped if not built).
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");

        if !c_predict.exists() {
            // Skip if C binary not compiled
            return;
        }

        // Write the C predictions to the system temp dir, not the shared
        // `data/` fixture directory. The process id keeps concurrent runs
        // from clobbering each other's output.
        let output_path = std::env::temp_dir().join(format!(
            "libsvm_rs_heart_scale_{}.predict",
            std::process::id()
        ));
        let ran_ok = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output()
            .map(|output| output.status.success())
            .unwrap_or(false);

        // Read and delete the file BEFORE any assertion, so a failing
        // assertion does not leave it behind.
        let contents = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);

        if !ran_ok {
            // Best effort, as before: skip when the C binary could not run.
            return;
        }

        let c_preds: Vec<f64> = contents
            .expect("svm-predict succeeded but its output file is unreadable")
            .lines()
            .map(str::trim)
            .filter(|l| !l.is_empty())
            .map(|l| {
                l.parse::<f64>()
                    .unwrap_or_else(|e| panic!("bad svm-predict output line {:?}: {}", l, e))
            })
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mut mismatches = Vec::new();
        for (i, (instance, &c_pred)) in problem.instances.iter().zip(&c_preds).enumerate() {
            if predict(&model, instance) != c_pred {
                mismatches.push(i);
            }
        }

        assert!(
            mismatches.is_empty(),
            "{} predictions differ between Rust and C (first at instance {:?})",
            mismatches.len(),
            mismatches.first()
        );
    }
}