Skip to main content

libsvm_rs/
predict.rs

1//! Prediction functions matching the original LIBSVM.
2//!
3//! Provides `predict`, `predict_values`, and `predict_probability` for
4//! all SVM types:
5//! - Classification (C-SVC, ν-SVC): one-vs-one voting + optional probability
6//! - One-class SVM: sign of decision value + optional density probability
7//! - Regression (ε-SVR, ν-SVR): continuous output
8
9use crate::kernel::k_function;
10use crate::probability::{multiclass_probability, predict_one_class_probability, sigmoid_predict};
11use crate::types::{SvmModel, SvmNode, SvmType};
12
13/// Compute decision values and return the predicted label/value.
14///
15/// For classification, `dec_values` receives `nr_class * (nr_class - 1) / 2`
16/// pairwise decision values. For regression/one-class, a single value.
17///
18/// Returns the predicted label (classification) or function value (regression).
19///
20/// Matches LIBSVM's `svm_predict_values`.
21pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
22    match model.param.svm_type {
23        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
24            let sv_coef = &model.sv_coef[0];
25            let mut sum = 0.0;
26            for (i, sv) in model.sv.iter().enumerate() {
27                sum += sv_coef[i] * k_function(x, sv, &model.param);
28            }
29            sum -= model.rho[0];
30            dec_values[0] = sum;
31
32            if model.param.svm_type == SvmType::OneClass {
33                if sum > 0.0 { 1.0 } else { -1.0 }
34            } else {
35                sum
36            }
37        }
38        SvmType::CSvc | SvmType::NuSvc => {
39            let nr_class = model.nr_class;
40            let l = model.sv.len();
41
42            // Compute kernel values for all SVs
43            let kvalue: Vec<f64> = model
44                .sv
45                .iter()
46                .map(|sv| k_function(x, sv, &model.param))
47                .collect();
48
49            // Compute start indices for each class's SVs
50            let mut start = vec![0usize; nr_class];
51            for i in 1..nr_class {
52                start[i] = start[i - 1] + model.n_sv[i - 1];
53            }
54
55            // One-vs-one voting
56            let mut vote = vec![0usize; nr_class];
57            let mut p = 0;
58            for i in 0..nr_class {
59                for j in (i + 1)..nr_class {
60                    let mut sum = 0.0;
61                    let si = start[i];
62                    let sj = start[j];
63                    let ci = model.n_sv[i];
64                    let cj = model.n_sv[j];
65
66                    let coef1 = &model.sv_coef[j - 1];
67                    let coef2 = &model.sv_coef[i];
68
69                    for k in 0..ci {
70                        sum += coef1[si + k] * kvalue[si + k];
71                    }
72                    for k in 0..cj {
73                        sum += coef2[sj + k] * kvalue[sj + k];
74                    }
75                    sum -= model.rho[p];
76                    dec_values[p] = sum;
77
78                    if sum > 0.0 {
79                        vote[i] += 1;
80                    } else {
81                        vote[j] += 1;
82                    }
83                    p += 1;
84                }
85            }
86
87            // Find class with most votes
88            let vote_max_idx = vote
89                .iter()
90                .enumerate()
91                .max_by_key(|&(_, &v)| v)
92                .map(|(i, _)| i)
93                .unwrap_or(0);
94
95            let _ = l; // suppress unused warning
96            model.label[vote_max_idx] as f64
97        }
98    }
99}
100
101/// Predict the label/value for a single instance.
102///
103/// Convenience wrapper around `predict_values` that allocates the
104/// decision values buffer internally. Matches LIBSVM's `svm_predict`.
105pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
106    let n = match model.param.svm_type {
107        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
108        SvmType::CSvc | SvmType::NuSvc => {
109            model.nr_class * (model.nr_class - 1) / 2
110        }
111    };
112    let mut dec_values = vec![0.0; n];
113    predict_values(model, x, &mut dec_values)
114}
115
116/// Predict with probability estimates.
117///
118/// Returns `Some((label, probs))` where `probs[i]` is the estimated
119/// probability of class `model.label[i]`. Returns `None` when the
120/// model was not trained with probability support.
121///
122/// - **C-SVC / ν-SVC**: requires `model.prob_a` and `model.prob_b`.
123///   Uses Platt scaling on pairwise decision values, then
124///   `multiclass_probability` for k > 2.
125/// - **One-class**: requires `model.prob_density_marks`.
126///   Returns `[p, 1-p]` via density-mark lookup.
127/// - **SVR**: probability prediction is not supported (returns `None`).
128///
129/// Matches LIBSVM's `svm_predict_probability`.
130#[allow(clippy::needless_range_loop)]
131pub fn predict_probability(model: &SvmModel, x: &[SvmNode]) -> Option<(f64, Vec<f64>)> {
132    match model.param.svm_type {
133        SvmType::CSvc | SvmType::NuSvc
134            if !model.prob_a.is_empty() && !model.prob_b.is_empty() =>
135        {
136            let nr_class = model.nr_class;
137            let n_pairs = nr_class * (nr_class - 1) / 2;
138            let mut dec_values = vec![0.0; n_pairs];
139            predict_values(model, x, &mut dec_values);
140
141            let min_prob = 1e-7;
142
143            // Build pairwise probability matrix
144            let mut pairwise = vec![vec![0.0; nr_class]; nr_class];
145            let mut k = 0;
146            for i in 0..nr_class {
147                for j in (i + 1)..nr_class {
148                    let p = sigmoid_predict(dec_values[k], model.prob_a[k], model.prob_b[k])
149                        .max(min_prob)
150                        .min(1.0 - min_prob);
151                    pairwise[i][j] = p;
152                    pairwise[j][i] = 1.0 - p;
153                    k += 1;
154                }
155            }
156
157            let mut prob_estimates = vec![0.0; nr_class];
158            if nr_class == 2 {
159                prob_estimates[0] = pairwise[0][1];
160                prob_estimates[1] = pairwise[1][0];
161            } else {
162                multiclass_probability(nr_class, &pairwise, &mut prob_estimates);
163            }
164
165            // Find class with highest probability
166            let mut best = 0;
167            for i in 1..nr_class {
168                if prob_estimates[i] > prob_estimates[best] {
169                    best = i;
170                }
171            }
172
173            Some((model.label[best] as f64, prob_estimates))
174        }
175
176        SvmType::OneClass if !model.prob_density_marks.is_empty() => {
177            let mut dec_value = [0.0];
178            let pred_result = predict_values(model, x, &mut dec_value);
179            let p = predict_one_class_probability(&model.prob_density_marks, dec_value[0]);
180            Some((pred_result, vec![p, 1.0 - p]))
181        }
182
183        _ => None,
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::{load_model, load_problem};
    use std::path::PathBuf;

    /// Directory two levels above this crate's manifest (presumably the
    /// workspace root, which holds `data/` and `vendor/`).
    fn repo_root() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("..").join("..")
    }

    /// Directory holding the test fixtures (`heart_scale`, `iris.scale`, models).
    fn data_dir() -> PathBuf {
        repo_root().join("data")
    }

    #[test]
    fn predict_heart_scale() {
        // Load model trained by C LIBSVM and predict on training data
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let correct = problem
            .instances
            .iter()
            .zip(&problem.labels)
            .filter(|&(instance, &label)| predict(&model, instance) == label)
            .count();

        let accuracy = correct as f64 / problem.labels.len() as f64;
        // C LIBSVM gets ~86.67% accuracy on training set with default params
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        // For binary classification, there's exactly 1 decision value
        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        // Decision value should be non-zero
        assert!(dec_values[0].abs() > 1e-10);
        // Label should match what predict returns
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    #[test]
    fn predict_probability_binary() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert!(!model.prob_a.is_empty());

        for instance in &problem.instances {
            let result = predict_probability(&model, instance);
            assert!(result.is_some(), "should return probability");
            let (label, probs) = result.unwrap();
            assert!(label == 1.0 || label == -1.0);
            assert_eq!(probs.len(), 2);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
            for &p in &probs {
                assert!((0.0..=1.0).contains(&p), "prob {} out of [0,1]", p);
            }
        }
    }

    #[test]
    fn predict_probability_multiclass() {
        let problem = load_problem(&data_dir().join("iris.scale")).unwrap();
        let param = crate::types::SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: crate::types::KernelType::Rbf,
            gamma: 0.25,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            probability: true,
            ..Default::default()
        };

        let model = crate::train::svm_train(&problem, &param);
        assert_eq!(model.nr_class, 3);
        assert_eq!(model.prob_a.len(), 3); // 3 pairs

        for instance in problem.instances.iter().take(10) {
            let result = predict_probability(&model, instance);
            assert!(result.is_some());
            let (_label, probs) = result.unwrap();
            assert_eq!(probs.len(), 3);
            let sum: f64 = probs.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-6,
                "probs sum to {}, expected 1.0",
                sum
            );
        }
    }

    #[test]
    fn predict_matches_c_svm_predict() {
        // Compare every prediction against the reference C `svm-predict`
        // binary. Skipped when that binary has not been built.
        let c_predict = repo_root().join("vendor").join("libsvm").join("svm-predict");
        if !c_predict.exists() {
            return;
        }

        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        // Per-process temp file rather than a fixed name inside the fixture
        // directory, so concurrent runs cannot collide and a failing
        // assertion never leaves stray files next to the fixtures.
        let output_path = std::env::temp_dir().join(format!(
            "libsvm_rs_heart_scale_{}.predict",
            std::process::id()
        ));

        // Paths are passed as OsStr, so non-UTF-8 locations work too.
        let succeeded = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output()
            .map(|out| out.status.success())
            .unwrap_or(false);

        // A binary that can't run or fails is treated like a missing one.
        if !succeeded {
            let _ = std::fs::remove_file(&output_path);
            return;
        }

        // Read, then remove the file *before* asserting, so cleanup always runs.
        let contents = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);
        let c_preds: Vec<f64> = contents
            .expect("svm-predict succeeded but its output file is unreadable")
            .lines()
            .map(str::trim)
            .filter(|l| !l.is_empty())
            .map(|l| l.parse().expect("svm-predict output line is not a number"))
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mismatches = problem
            .instances
            .iter()
            .zip(&c_preds)
            .filter(|&(instance, &c_pred)| predict(&model, instance) != c_pred)
            .count();

        assert_eq!(
            mismatches, 0,
            "{} predictions differ between Rust and C",
            mismatches
        );
    }
}