1use crate::kernel::k_function;
9use crate::types::{SvmModel, SvmNode, SvmType};
10
11pub fn predict_values(model: &SvmModel, x: &[SvmNode], dec_values: &mut [f64]) -> f64 {
20 match model.param.svm_type {
21 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => {
22 let sv_coef = &model.sv_coef[0];
23 let mut sum = 0.0;
24 for (i, sv) in model.sv.iter().enumerate() {
25 sum += sv_coef[i] * k_function(x, sv, &model.param);
26 }
27 sum -= model.rho[0];
28 dec_values[0] = sum;
29
30 if model.param.svm_type == SvmType::OneClass {
31 if sum > 0.0 { 1.0 } else { -1.0 }
32 } else {
33 sum
34 }
35 }
36 SvmType::CSvc | SvmType::NuSvc => {
37 let nr_class = model.nr_class;
38 let l = model.sv.len();
39
40 let kvalue: Vec<f64> = model
42 .sv
43 .iter()
44 .map(|sv| k_function(x, sv, &model.param))
45 .collect();
46
47 let mut start = vec![0usize; nr_class];
49 for i in 1..nr_class {
50 start[i] = start[i - 1] + model.n_sv[i - 1];
51 }
52
53 let mut vote = vec![0usize; nr_class];
55 let mut p = 0;
56 for i in 0..nr_class {
57 for j in (i + 1)..nr_class {
58 let mut sum = 0.0;
59 let si = start[i];
60 let sj = start[j];
61 let ci = model.n_sv[i];
62 let cj = model.n_sv[j];
63
64 let coef1 = &model.sv_coef[j - 1];
65 let coef2 = &model.sv_coef[i];
66
67 for k in 0..ci {
68 sum += coef1[si + k] * kvalue[si + k];
69 }
70 for k in 0..cj {
71 sum += coef2[sj + k] * kvalue[sj + k];
72 }
73 sum -= model.rho[p];
74 dec_values[p] = sum;
75
76 if sum > 0.0 {
77 vote[i] += 1;
78 } else {
79 vote[j] += 1;
80 }
81 p += 1;
82 }
83 }
84
85 let vote_max_idx = vote
87 .iter()
88 .enumerate()
89 .max_by_key(|&(_, &v)| v)
90 .map(|(i, _)| i)
91 .unwrap_or(0);
92
93 let _ = l; model.label[vote_max_idx] as f64
95 }
96 }
97}
98
99pub fn predict(model: &SvmModel, x: &[SvmNode]) -> f64 {
104 let n = match model.param.svm_type {
105 SvmType::OneClass | SvmType::EpsilonSvr | SvmType::NuSvr => 1,
106 SvmType::CSvc | SvmType::NuSvc => {
107 model.nr_class * (model.nr_class - 1) / 2
108 }
109 };
110 let mut dec_values = vec![0.0; n];
111 predict_values(model, x, &mut dec_values)
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_model;
    use crate::io::load_problem;
    use std::path::PathBuf;

    /// Repository root: two levels above this crate's manifest directory.
    fn repo_root() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
    }

    /// Directory holding the `heart_scale` fixtures.
    fn data_dir() -> PathBuf {
        repo_root().join("data")
    }

    /// The bundled heart_scale model should classify its own data with more
    /// than 85% accuracy.
    #[test]
    fn predict_heart_scale() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let mut correct = 0;
        for (i, instance) in problem.instances.iter().enumerate() {
            let pred = predict(&model, instance);
            if pred == problem.labels[i] {
                correct += 1;
            }
        }

        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.85,
            "accuracy {:.2}% too low (expected >85%)",
            accuracy * 100.0
        );
    }

    /// For a two-class model, `predict_values` fills exactly one decision value
    /// and returns the same label as `predict`.
    #[test]
    fn predict_values_binary() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        // A one-slot buffer is only sufficient for a binary classifier.
        assert_eq!(model.nr_class, 2);
        let mut dec_values = vec![0.0; 1];
        let label = predict_values(&model, &problem.instances[0], &mut dec_values);

        assert!(dec_values[0].abs() > 1e-10);
        assert_eq!(label, predict(&model, &problem.instances[0]));
    }

    /// Cross-checks every Rust prediction against the reference C `svm-predict`
    /// binary, when one has been built under `vendor/libsvm`.
    ///
    /// The test is skipped, with a note on stderr, if the binary is missing or
    /// does not run successfully.
    #[test]
    fn predict_matches_c_svm_predict() {
        let model = load_model(&data_dir().join("heart_scale.model")).unwrap();
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();

        let c_predict = repo_root().join("vendor").join("libsvm").join("svm-predict");
        if !c_predict.exists() {
            eprintln!("skipping: {} not found", c_predict.display());
            return;
        }

        // Write to the system temp dir rather than the shared fixture directory.
        // The PID suffix keeps concurrent runs from clobbering each other.
        let output_path = std::env::temp_dir()
            .join(format!("heart_scale.predict_test.{}", std::process::id()));

        // Paths are passed as OsStr, so non-UTF-8 paths don't panic.
        let ran_ok = std::process::Command::new(&c_predict)
            .arg(data_dir().join("heart_scale"))
            .arg(data_dir().join("heart_scale.model"))
            .arg(&output_path)
            .output()
            .map(|out| out.status.success())
            .unwrap_or(false);
        if !ran_ok {
            let _ = std::fs::remove_file(&output_path);
            eprintln!("skipping: {} did not run successfully", c_predict.display());
            return;
        }

        // Read and delete the output *before* asserting, so a failing assertion
        // doesn't leave the file behind.
        let contents = std::fs::read_to_string(&output_path);
        let _ = std::fs::remove_file(&output_path);
        let c_preds: Vec<f64> = contents
            .expect("svm-predict reported success but its output is unreadable")
            .lines()
            .filter(|l| !l.is_empty())
            .map(|l| l.trim().parse().unwrap())
            .collect();

        assert_eq!(c_preds.len(), problem.labels.len());

        let mut mismatches = 0;
        for (instance, &c_pred) in problem.instances.iter().zip(&c_preds) {
            if predict(&model, instance) != c_pred {
                mismatches += 1;
            }
        }

        assert_eq!(
            mismatches, 0,
            "{} predictions differ between Rust and C",
            mismatches
        );
    }
}