1use crate::predict::{predict, predict_probability};
8use crate::train::svm_train;
9use crate::types::{SvmModel, SvmNode, SvmParameter, SvmProblem, SvmType};
10
/// Advances `state` by one step of a 64-bit linear congruential generator
/// (Knuth's MMIX multiplier/increment) and returns the high bits as `usize`.
///
/// Deterministic by construction: the same seed always yields the same
/// sequence, which keeps cross-validation fold assignment reproducible.
fn rng_next(state: &mut u64) -> usize {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;
    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;
    // Discard the weak low bits of the LCG; the top 31 bits are returned.
    (next >> 33) as usize
}
19
20fn predict_cv_target(model: &SvmModel, param: &SvmParameter, x: &[SvmNode]) -> f64 {
21 if param.probability && matches!(param.svm_type, SvmType::CSvc | SvmType::NuSvc) {
22 predict_probability(model, x)
23 .map(|(label, _)| label)
24 .unwrap_or_else(|| predict(model, x))
25 } else {
26 predict(model, x)
27 }
28}
29
30pub fn svm_cross_validation(
43 prob: &SvmProblem,
44 param: &SvmParameter,
45 mut nr_fold: usize,
46) -> Vec<f64> {
47 let l = prob.labels.len();
48
49 if l == 0 {
50 return Vec::new();
51 }
52
53 if nr_fold == 0 {
54 crate::info(
55 "WARNING: # folds (0) <= 0. Will use # folds = # data instead \
56 (i.e., leave-one-out cross validation)\n",
57 );
58 nr_fold = l;
59 }
60
61 if nr_fold > l {
62 crate::info(&format!(
63 "WARNING: # folds ({}) > # data ({}). Will use # folds = # data instead \
64 (i.e., leave-one-out cross validation)\n",
65 nr_fold, l
66 ));
67 nr_fold = l;
68 }
69
70 let mut rng: u64 = 1;
71 let mut perm: Vec<usize> = (0..l).collect();
72 let mut fold_start = vec![0usize; nr_fold + 1];
73
74 if matches!(param.svm_type, SvmType::CSvc | SvmType::NuSvc) && nr_fold < l {
76 let (label_list, start, count, group_perm) = group_classes(&prob.labels);
78 let nr_class = label_list.len();
79
80 let mut index = group_perm;
81
82 for c in 0..nr_class {
84 let s = start[c];
85 let n = count[c];
86 for i in 0..n {
87 let j = i + rng_next(&mut rng) % (n - i);
88 index.swap(s + i, s + j);
89 }
90 }
91
92 let mut fold_count = vec![0usize; nr_fold];
94 for (i, fc) in fold_count.iter_mut().enumerate() {
95 for &cnt in &count {
96 *fc += ((i + 1) * cnt) / nr_fold - (i * cnt) / nr_fold;
97 }
98 }
99
100 fold_start[0] = 0;
101 for i in 0..nr_fold {
102 fold_start[i + 1] = fold_start[i] + fold_count[i];
103 }
104
105 let mut offset = vec![0usize; nr_fold];
109 for c in 0..nr_class {
110 for i in 0..nr_fold {
111 let begin = start[c] + (i * count[c]) / nr_fold;
112 let end = start[c] + ((i + 1) * count[c]) / nr_fold;
113 for &idx in &index[begin..end] {
114 perm[fold_start[i] + offset[i]] = idx;
115 offset[i] += 1;
116 }
117 }
118 }
119
120 fold_start[0] = 0;
122 for i in 0..nr_fold {
123 fold_start[i + 1] = fold_start[i] + fold_count[i];
124 }
125 } else {
126 for i in 0..l {
128 let j = i + rng_next(&mut rng) % (l - i);
129 perm.swap(i, j);
130 }
131 for (i, fs) in fold_start.iter_mut().enumerate() {
132 *fs = i * l / nr_fold;
133 }
134 }
135
136 let mut target = vec![0.0; l];
138
139 for i in 0..nr_fold {
140 let begin = fold_start[i];
141 let end = fold_start[i + 1];
142
143 let sub_l = l - (end - begin);
145 let mut sub_labels = Vec::with_capacity(sub_l);
146 let mut sub_instances = Vec::with_capacity(sub_l);
147
148 for &pi in &perm[..begin] {
149 sub_labels.push(prob.labels[pi]);
150 sub_instances.push(prob.instances[pi].clone());
151 }
152 for &pi in &perm[end..l] {
153 sub_labels.push(prob.labels[pi]);
154 sub_instances.push(prob.instances[pi].clone());
155 }
156
157 let subprob = SvmProblem {
158 labels: sub_labels,
159 instances: sub_instances,
160 };
161 let submodel = svm_train(&subprob, param);
162
163 for j in begin..end {
165 target[perm[j]] = predict_cv_target(&submodel, param, &prob.instances[perm[j]]);
166 }
167 }
168
169 target
170}
171
/// Groups training instances by class label.
///
/// Returns `(label_list, start, count, perm)` where:
/// - `label_list[c]` is the c-th distinct label (order of first appearance,
///   except that a binary `{-1, +1}` problem is reordered so `+1` is class 0),
/// - `count[c]` is the number of instances of class `c`,
/// - `start[c]` is the offset of class `c` within `perm`,
/// - `perm` lists instance indices grouped by class (stable within a class).
fn group_classes(labels: &[f64]) -> (Vec<i32>, Vec<usize>, Vec<usize>, Vec<usize>) {
    let mut label_list: Vec<i32> = Vec::new();
    let mut count: Vec<usize> = Vec::new();
    // data_label[i] = class index assigned to instance i.
    let mut data_label: Vec<usize> = Vec::with_capacity(labels.len());

    for &raw in labels {
        let lab = raw as i32;
        match label_list.iter().position(|&known| known == lab) {
            Some(cls) => {
                count[cls] += 1;
                data_label.push(cls);
            }
            None => {
                data_label.push(label_list.len());
                label_list.push(lab);
                count.push(1);
            }
        }
    }

    // LIBSVM convention: when the binary labels appeared as (-1, +1), swap
    // them so that class 0 is +1.
    if label_list == [-1, 1] {
        label_list.swap(0, 1);
        count.swap(0, 1);
        for cls in data_label.iter_mut() {
            *cls = 1 - *cls;
        }
    }

    let nr_class = label_list.len();
    let mut start = vec![0usize; nr_class];
    for c in 1..nr_class {
        start[c] = start[c - 1] + count[c - 1];
    }

    // Scatter instance indices into their class's region of `perm`;
    // `cursor` tracks the next free slot per class.
    let mut cursor = start.clone();
    let mut perm = vec![0usize; labels.len()];
    for (i, &cls) in data_label.iter().enumerate() {
        perm[cursor[cls]] = i;
        cursor[cls] += 1;
    }

    (label_list, start, count, perm)
}
220
#[cfg(test)]
mod tests {
    //! Tests for cross-validation: synthetic separable data, real datasets
    //! loaded from the repository's `data` directory, edge-case fold counts,
    //! and the probability-vs-vote prediction path of `predict_cv_target`.
    use super::*;
    use crate::io::load_problem;
    use crate::types::{KernelType, SvmModel, SvmNode};
    use std::path::PathBuf;

    /// Path to the shared dataset directory, resolved relative to this crate.
    /// NOTE(review): assumes the workspace layout places `data/` two levels
    /// above the crate manifest — verify if tests are moved or vendored.
    fn data_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("..")
            .join("..")
            .join("data")
    }

    /// 5-fold CV on a tiny synthetic binary problem: checks output length and
    /// that every prediction is one of the two training labels.
    #[test]
    fn cross_validation_basic() {
        let labels = vec![1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0];
        let instances: Vec<Vec<SvmNode>> = (0..10)
            .map(|i| {
                vec![SvmNode {
                    index: 1,
                    value: i as f64 * 0.1,
                }]
            })
            .collect();

        let prob = SvmProblem { labels, instances };
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };

        let target = svm_cross_validation(&prob, &param, 5);
        assert_eq!(target.len(), 10);
        for &pred in &target {
            assert!(pred == 1.0 || pred == -1.0);
        }
    }

    /// 5-fold CV on the heart_scale dataset with an RBF C-SVC; accuracy
    /// should comfortably exceed 70% (LIBSVM's reference setup scores higher).
    /// NOTE(review): requires `data/heart_scale` to exist on disk.
    #[test]
    fn cross_validation_classification() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            // gamma = 1 / num_features (heart_scale has 13 features).
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());

        let correct = target
            .iter()
            .zip(problem.labels.iter())
            .filter(|(&pred, &label)| pred == label)
            .count();
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "5-fold CV accuracy {:.1}% too low (expected >70%)",
            accuracy * 100.0
        );
    }

    /// 5-fold CV on the housing_scale dataset with epsilon-SVR; checks the
    /// MSE is finite and below a loose sanity bound.
    /// NOTE(review): requires `data/housing_scale` to exist on disk.
    #[test]
    fn cross_validation_regression() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            // gamma = 1 / num_features (housing_scale has 13 features).
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());

        let mse: f64 = target
            .iter()
            .zip(problem.labels.iter())
            .map(|(&pred, &label)| (pred - label).powi(2))
            .sum::<f64>()
            / problem.labels.len() as f64;
        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 500.0, "MSE {} too high", mse);
    }

    /// nr_fold = 0 must clamp to leave-one-out (nr_fold = l) rather than
    /// panicking or returning an empty result.
    #[test]
    fn cross_validation_zero_folds_clamps_to_leave_one_out() {
        let labels = vec![1.0, -1.0, 1.0, -1.0];
        let instances: Vec<Vec<SvmNode>> = vec![
            vec![SvmNode {
                index: 1,
                value: 1.0,
            }],
            vec![SvmNode {
                index: 1,
                value: -1.0,
            }],
            vec![SvmNode {
                index: 1,
                value: 0.8,
            }],
            vec![SvmNode {
                index: 1,
                value: -0.9,
            }],
        ];
        let prob = SvmProblem { labels, instances };
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Linear,
            c: 1.0,
            eps: 0.001,
            ..Default::default()
        };

        let target = svm_cross_validation(&prob, &param, 0);
        assert_eq!(target.len(), prob.labels.len());
        for &pred in &target {
            assert!(pred == 1.0 || pred == -1.0);
        }
    }

    /// An empty problem must short-circuit to an empty result (no training).
    #[test]
    fn cross_validation_empty_problem_returns_empty() {
        let prob = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let target = svm_cross_validation(&prob, &SvmParameter::default(), 5);
        assert!(target.is_empty());
    }

    /// Builds a hand-crafted model whose probability-based label disagrees
    /// with the majority-vote label (vote says +1, probabilities say -1), so
    /// we can observe which path `predict_cv_target` takes when
    /// `probability` is enabled for classification.
    #[test]
    fn predict_cv_target_uses_probability_label_for_classification() {
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Linear,
            probability: true,
            ..Default::default()
        };
        let model = SvmModel {
            param: param.clone(),
            nr_class: 2,
            sv: vec![
                vec![SvmNode {
                    index: 1,
                    value: 1.0,
                }],
                vec![SvmNode {
                    index: 1,
                    value: 1.0,
                }],
            ],
            sv_coef: vec![vec![1.0, -1.0]],
            rho: vec![-1.0],
            // prob_a/prob_b chosen so the sigmoid flips the predicted class
            // relative to the raw decision value.
            prob_a: vec![1.0],
            prob_b: vec![0.0],
            prob_density_marks: Vec::new(),
            sv_indices: vec![1, 2],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };
        let x = vec![SvmNode {
            index: 1,
            value: 1.0,
        }];

        let vote_label = predict(&model, &x);
        let (prob_label, _) = predict_probability(&model, &x).unwrap();

        // Sanity: the two prediction paths genuinely disagree on this model.
        assert_eq!(vote_label, 1.0);
        assert_eq!(prob_label, -1.0);
        // With probability=true and C-SVC, the probability label must win.
        assert_eq!(predict_cv_target(&model, &param, &x), prob_label);
    }
}