Skip to main content

libsvm_rs/
predict.rs

1//! Prediction functions matching the original LIBSVM.
2//!
3//! Provides `predict` and `predict_values` for all SVM types:
4//! - Classification (C-SVC, ν-SVC): one-vs-one voting
5//! - One-class SVM: sign of decision value
6//! - Regression (ε-SVR, ν-SVR): continuous output
7
8use crate::kernel::k_function;
9use crate::types::{SvmModel, SvmNode, SvmType};
10
11/// Compute decision values and return the predicted label/value.
12///
13/// For classification, `dec_values` receives `nr_class * (nr_class - 1) / 2`
14/// pairwise decision values. For regression/one-class, a single value.
15///
16/// Returns the predicted label (classification) or function value (regression).
17///
18/// Matches LIBSVM's `svm_predict_values`.
19pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
20    match model.param.svm_type {
21        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
22            let sv_coef = &model.sv_coef[0];
23            let mut sum = 0.0;
24            for (i, sv) in model.sv.iter().enumerate() {
25                sum += sv_coef[i] * k_function(x, sv, &model.param);
26            }
27            sum -= model.rho[0];
28            dec_values[0] = sum;
29
30            if model.param.svm_type == SvmType::OneClass {
31                if sum > 0.0 { 1.0 } else { -1.0 }
32            } else {
33                sum
34            }
35        }
36        SvmType::CSvc | SvmType::NuSvc => {
37            let nr_class = model.nr_class;
38            let l = model.sv.len();
39
40            // Compute kernel values for all SVs
41            let kvalue: Vec<f64> = model
42                .sv
43                .iter()
44                .map(|sv| k_function(x, sv, &model.param))
45                .collect();
46
47            // Compute start indices for each class's SVs
48            let mut start = vec![0usize; nr_class];
49            for i in 1..nr_class {
50                start[i] = start[i - 1] + model.n_sv[i - 1];
51            }
52
53            // One-vs-one voting
54            let mut vote = vec![0usize; nr_class];
55            let mut p = 0;
56            for i in 0..nr_class {
57                for j in (i + 1)..nr_class {
58                    let mut sum = 0.0;
59                    let si = start[i];
60                    let sj = start[j];
61                    let ci = model.n_sv[i];
62                    let cj = model.n_sv[j];
63
64                    let coef1 = &model.sv_coef[j - 1];
65                    let coef2 = &model.sv_coef[i];
66
67                    for k in 0..ci {
68                        sum += coef1[si + k] * kvalue[si + k];
69                    }
70                    for k in 0..cj {
71                        sum += coef2[sj + k] * kvalue[sj + k];
72                    }
73                    sum -= model.rho[p];
74                    dec_values[p] = sum;
75
76                    if sum > 0.0 {
77                        vote[i] += 1;
78                    } else {
79                        vote[j] += 1;
80                    }
81                    p += 1;
82                }
83            }
84
85            // Find class with most votes
86            let vote_max_idx = vote
87                .iter()
88                .enumerate()
89                .max_by_key(|&(_, &v)| v)
90                .map(|(i, _)| i)
91                .unwrap_or(0);
92
93            let _ = l; // suppress unused warning
94            model.label[vote_max_idx] as f64
95        }
96    }
97}
98
99/// Predict the label/value for a single instance.
100///
101/// Convenience wrapper around `predict_values` that allocates the
102/// decision values buffer internally. Matches LIBSVM's `svm_predict`.
103pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
104    let n = match model.param.svm_type {
105        SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
106        SvmType::CSvc | SvmType::NuSvc => {
107            model.nr_class * (model.nr_class - 1) / 2
108        }
109    };
110    let mut dec_values = vec![0.0; n];
111    predict_values(model, x, &mut dec_values)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    /// Directory containing the `heart_scale` fixtures, two levels above
    /// this crate's manifest directory.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// Predicting the training set with a model trained by C LIBSVM must
    /// reach roughly the accuracy the C implementation reports.
    #[test]
    fn predict_heart_scale() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let correct = problem
            .instances
            .iter()
            .zip(problem.labels.iter())
            .filter(|&(instance, &label)| predict(&model, instance) == label)
            .count();

        let accuracy = correct as f64 / problem.labels.len() as f64;
        // C LIBSVM gets ~86.67% accuracy on training set with default params
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    /// `predict_values` fills the single binary decision value and returns
    /// the same label as `predict`.
    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        // For binary classification, there's exactly 1 decision value
        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        // Decision value should be non-zero
        assert!(dec_values[0].abs() > 1e-10);
        // Label should match what predict returns
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    /// Compare every prediction against the reference C `svm-predict`
    /// binary. The test is skipped (passes trivially) when the binary has
    /// not been built or fails to run.
    #[test]
    fn predict_matches_c_svm_predict() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("vendor")
            .join("libsvm")
            .join("svm-predict");

        if !c_predict.exists() {
            // Skip if C binary not compiled
            return;
        }

        // Let svm-predict write into the OS temp directory rather than the
        // repository's data directory; the process id keeps concurrent test
        // runs from clobbering each other's file.
        let output_path = std::env::temp_dir().join(format!(
            "heart_scale.predict_test.{}",
            std::process::id()
        ));
        let run = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output();

        // Read and delete the scratch file BEFORE asserting anything, so a
        // failing assertion (or a failed run) never leaves it behind.
        let raw = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);

        let succeeded = matches!(&run, Ok(out) if out.status.success());
        if !succeeded {
            // Skip if the C binary could not be run successfully
            return;
        }

        let c_preds: Vec<f64> = raw
            .expect("svm-predict succeeded but its output file could not be read")
            .lines()
            .filter(|l| !l.is_empty())
            .map(|l| l.trim().parse().unwrap())
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mismatches = problem
            .instances
            .iter()
            .zip(c_preds.iter())
            .filter(|&(instance, &c_pred)| predict(&model, instance) != c_pred)
            .count();

        assert_eq!(
            mismatches, 0,
            "{} predictions differ between Rust and C",
            mismatches
        );
    }
}