Skip to main content

libsvm_rs/
cross_validation.rs

1//! Cross-validation for SVM models.
2//!
3//! Provides stratified cross-validation for classification and simple
4//! random-split cross-validation for regression/one-class problems.
5//! Matches LIBSVM's `svm_cross_validation` (svm.cpp:2437–2556).
6
7use crate::predict::{predict, predict_probability};
8use crate::train::svm_train;
9use crate::types::{SvmModel, SvmNode, SvmParameter, SvmProblem, SvmType};
10
11// ─── RNG helper ──────────────────────────────────────────────────────
12
/// Advance the 64-bit linear congruential generator held in `state` and
/// return a pseudo-random index derived from the new state.
///
/// The state is updated in place with wrapping arithmetic; the result is the
/// new state shifted right by 33 bits, i.e. only its high-order bits are used.
fn rng_next(state: &mut u64) -> usize {
    const MULTIPLIER: u64 = 6364136223846793005;
    const INCREMENT: u64 = 1442695040888963407;

    let advanced = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = advanced;
    (advanced >> 33) as usize
}
19
20fn predict_cv_target(model: &SvmModel, param: &SvmParameter, x: &[SvmNode]) -> f64 {
21    if param.probability && matches!(param.svm_type, SvmType::CSvc | SvmType::NuSvc) {
22        predict_probability(model, x)
23            .map(|(label, _)| label)
24            .unwrap_or_else(|| predict(model, x))
25    } else {
26        predict(model, x)
27    }
28}
29
30// ─── Public API ──────────────────────────────────────────────────────
31
32/// Perform k-fold cross-validation on an SVM problem.
33///
34/// Returns a `Vec<f64>` of length `prob.labels.len()` where `target[i]`
35/// is the prediction for instance `i` when it was held out.
36///
37/// - **Classification** (C-SVC, ν-SVC) with `nr_fold < l`: stratified
38///   splitting that preserves class ratios across folds.
39/// - **Regression / one-class** or `nr_fold == l`: simple random split.
40///
41/// If `nr_fold > l`, clamps to `l` (leave-one-out).
42pub fn svm_cross_validation(
43    prob: &SvmProblem,
44    param: &SvmParameter,
45    mut nr_fold: usize,
46) -> Vec<f64> {
47    let l = prob.labels.len();
48
49    if l == 0 {
50        return Vec::new();
51    }
52
53    if nr_fold == 0 {
54        crate::info(
55            "WARNING: # folds (0) <= 0. Will use # folds = # data instead \
56             (i.e., leave-one-out cross validation)\n",
57        );
58        nr_fold = l;
59    }
60
61    if nr_fold > l {
62        crate::info(&format!(
63            "WARNING: # folds ({}) > # data ({}). Will use # folds = # data instead \
64             (i.e., leave-one-out cross validation)\n",
65            nr_fold, l
66        ));
67        nr_fold = l;
68    }
69
70    let mut rng: u64 = 1;
71    let mut perm: Vec<usize> = (0..l).collect();
72    let mut fold_start = vec![0usize; nr_fold + 1];
73
74    // ── Fold assignment ──────────────────────────────────────────
75    if matches!(param.svm_type, SvmType::CSvc | SvmType::NuSvc) && nr_fold < l {
76        // Stratified: group by class, shuffle within class, distribute
77        let (label_list, start, count, group_perm) = group_classes(&prob.labels);
78        let nr_class = label_list.len();
79
80        let mut index = group_perm;
81
82        // Shuffle within each class
83        for c in 0..nr_class {
84            let s = start[c];
85            let n = count[c];
86            for i in 0..n {
87                let j = i + rng_next(&mut rng) % (n - i);
88                index.swap(s + i, s + j);
89            }
90        }
91
92        // Compute fold sizes
93        let mut fold_count = vec![0usize; nr_fold];
94        for (i, fc) in fold_count.iter_mut().enumerate() {
95            for &cnt in &count {
96                *fc += ((i + 1) * cnt) / nr_fold - (i * cnt) / nr_fold;
97            }
98        }
99
100        fold_start[0] = 0;
101        for i in 0..nr_fold {
102            fold_start[i + 1] = fold_start[i] + fold_count[i];
103        }
104
105        // Distribute samples to folds, preserving class balance
106        // (C++ increments fold_start[i] as a running pointer; we
107        // use a separate offset array to keep fold_start immutable.)
108        let mut offset = vec![0usize; nr_fold];
109        for c in 0..nr_class {
110            for i in 0..nr_fold {
111                let begin = start[c] + (i * count[c]) / nr_fold;
112                let end = start[c] + ((i + 1) * count[c]) / nr_fold;
113                for &idx in &index[begin..end] {
114                    perm[fold_start[i] + offset[i]] = idx;
115                    offset[i] += 1;
116                }
117            }
118        }
119
120        // Rebuild fold_start from fold_count
121        fold_start[0] = 0;
122        for i in 0..nr_fold {
123            fold_start[i + 1] = fold_start[i] + fold_count[i];
124        }
125    } else {
126        // Simple random shuffle
127        for i in 0..l {
128            let j = i + rng_next(&mut rng) % (l - i);
129            perm.swap(i, j);
130        }
131        for (i, fs) in fold_start.iter_mut().enumerate() {
132            *fs = i * l / nr_fold;
133        }
134    }
135
136    // ── Evaluate each fold ───────────────────────────────────────
137    let mut target = vec![0.0; l];
138
139    for i in 0..nr_fold {
140        let begin = fold_start[i];
141        let end = fold_start[i + 1];
142
143        // Build sub-problem excluding held-out [begin..end)
144        let sub_l = l - (end - begin);
145        let mut sub_labels = Vec::with_capacity(sub_l);
146        let mut sub_instances = Vec::with_capacity(sub_l);
147
148        for &pi in &perm[..begin] {
149            sub_labels.push(prob.labels[pi]);
150            sub_instances.push(prob.instances[pi].clone());
151        }
152        for &pi in &perm[end..l] {
153            sub_labels.push(prob.labels[pi]);
154            sub_instances.push(prob.instances[pi].clone());
155        }
156
157        let subprob = SvmProblem {
158            labels: sub_labels,
159            instances: sub_instances,
160        };
161        let submodel = svm_train(&subprob, param);
162
163        // Predict held-out
164        for j in begin..end {
165            target[perm[j]] = predict_cv_target(&submodel, param, &prob.instances[perm[j]]);
166        }
167    }
168
169    target
170}
171
172// ─── Internal helpers ────────────────────────────────────────────────
173
/// Partition sample indices by class label (the original notes this is the
/// same logic as `train::svm_group_classes`).
///
/// Labels are truncated to `i32` and classes are numbered in order of first
/// appearance, except that a two-class problem whose labels appear as `-1`
/// then `+1` is reordered so that `+1` comes first.
///
/// Returns `(label_list, start, count, perm)` where:
/// - `label_list[c]` is the label of class `c`,
/// - `count[c]` is the number of samples in class `c`,
/// - `start[c]` is the offset of class `c`'s segment inside `perm`,
/// - `perm` lists all sample indices grouped by class, keeping the input
///   order within each class.
fn group_classes(labels: &[f64]) -> (Vec<i32>, Vec<usize>, Vec<usize>, Vec<usize>) {
    let mut classes: Vec<i32> = Vec::new();
    let mut sizes: Vec<usize> = Vec::new();

    // Class slot of every sample; new classes are appended as they are met.
    let mut slot_of: Vec<usize> = labels
        .iter()
        .map(|&raw| {
            let key = raw as i32;
            let found = classes.iter().position(|&seen| seen == key);
            match found {
                Some(slot) => {
                    sizes[slot] += 1;
                    slot
                }
                None => {
                    let slot = classes.len();
                    classes.push(key);
                    sizes.push(1);
                    slot
                }
            }
        })
        .collect();

    // Binary -1/+1 problem with -1 seen first: put +1 in slot 0.
    if classes == [-1, 1] {
        classes.swap(0, 1);
        sizes.swap(0, 1);
        for slot in &mut slot_of {
            // Only slots 0 and 1 exist here, so this flips them.
            *slot = 1 - *slot;
        }
    }

    // Exclusive prefix sum of the class sizes gives each segment's offset.
    let start: Vec<usize> = sizes
        .iter()
        .scan(0usize, |running, &size| {
            let first = *running;
            *running += size;
            Some(first)
        })
        .collect();

    // Drop each sample index into the next free position of its class segment.
    let mut cursor = start.clone();
    let mut grouped = vec![0usize; labels.len()];
    for (sample, &slot) in slot_of.iter().enumerate() {
        grouped[cursor[slot]] = sample;
        cursor[slot] += 1;
    }

    (classes, start, sizes, grouped)
}
220
221// ─── Tests ───────────────────────────────────────────────────────────
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::load_problem;
    use crate::types::{KernelType, SvmModel, SvmNode};
    use std::path::PathBuf;

    /// Path to the `data` directory two levels above this crate's manifest.
    fn data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("..");
        dir.push("..");
        dir.push("data");
        dir
    }

    /// A sparse instance with a single feature at index 1.
    fn single_feature(value: f64) -> Vec<SvmNode> {
        vec![SvmNode { index: 1, value }]
    }

    /// True when the prediction is one of the two binary labels.
    fn is_binary_label(pred: f64) -> bool {
        pred == 1.0 || pred == -1.0
    }

    // Small synthetic binary problem: every prediction must be a valid label.
    #[test]
    fn cross_validation_basic() {
        let labels: Vec<f64> = (0..10)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        let instances: Vec<Vec<SvmNode>> =
            (0..10).map(|i| single_feature(i as f64 * 0.1)).collect();

        let prob = SvmProblem { labels, instances };
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };

        let target = svm_cross_validation(&prob, &param, 5);
        assert_eq!(target.len(), 10);
        for &pred in &target {
            assert!(is_binary_label(pred));
        }
    }

    // 5-fold CV on the heart_scale data set should reach a sane accuracy.
    #[test]
    fn cross_validation_classification() {
        let problem = load_problem(&data_dir().join("heart_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());

        let mut correct = 0usize;
        for (pred, label) in target.iter().zip(&problem.labels) {
            if pred == label {
                correct += 1;
            }
        }
        let accuracy = correct as f64 / problem.labels.len() as f64;
        assert!(
            accuracy > 0.70,
            "5-fold CV accuracy {:.1}% too low (expected >70%)",
            accuracy * 100.0
        );
    }

    // 5-fold CV for ε-SVR on housing_scale: the error must be finite and bounded.
    #[test]
    fn cross_validation_regression() {
        let problem = load_problem(&data_dir().join("housing_scale")).unwrap();
        let param = SvmParameter {
            svm_type: SvmType::EpsilonSvr,
            kernel_type: KernelType::Rbf,
            gamma: 1.0 / 13.0,
            c: 1.0,
            p: 0.1,
            cache_size: 100.0,
            eps: 0.001,
            shrinking: true,
            ..Default::default()
        };

        let target = svm_cross_validation(&problem, &param, 5);
        assert_eq!(target.len(), problem.labels.len());

        let squared_error: f64 = target
            .iter()
            .zip(problem.labels.iter())
            .map(|(&pred, &label)| (pred - label).powi(2))
            .sum::<f64>();
        let mse: f64 = squared_error / problem.labels.len() as f64;
        assert!(mse.is_finite(), "MSE is not finite");
        assert!(mse < 500.0, "MSE {} too high", mse);
    }

    // A fold count of zero falls back to leave-one-out instead of failing.
    #[test]
    fn cross_validation_zero_folds_clamps_to_leave_one_out() {
        let labels = vec![1.0, -1.0, 1.0, -1.0];
        let instances: Vec<Vec<SvmNode>> = [1.0, -1.0, 0.8, -0.9]
            .iter()
            .map(|&value| single_feature(value))
            .collect();
        let prob = SvmProblem { labels, instances };
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Linear,
            c: 1.0,
            eps: 0.001,
            ..Default::default()
        };

        let target = svm_cross_validation(&prob, &param, 0);
        assert_eq!(target.len(), prob.labels.len());
        for &pred in &target {
            assert!(is_binary_label(pred));
        }
    }

    // No samples in, no predictions out.
    #[test]
    fn cross_validation_empty_problem_returns_empty() {
        let prob = SvmProblem {
            labels: Vec::new(),
            instances: Vec::new(),
        };
        let target = svm_cross_validation(&prob, &SvmParameter::default(), 5);
        assert!(target.is_empty());
    }

    // With probability estimates enabled, the CV target must follow the
    // probability-based label even when it disagrees with the voting label.
    #[test]
    fn predict_cv_target_uses_probability_label_for_classification() {
        let param = SvmParameter {
            svm_type: SvmType::CSvc,
            kernel_type: KernelType::Linear,
            probability: true,
            ..Default::default()
        };
        let model = SvmModel {
            param: param.clone(),
            nr_class: 2,
            sv: vec![single_feature(1.0), single_feature(1.0)],
            sv_coef: vec![vec![1.0, -1.0]],
            rho: vec![-1.0],
            prob_a: vec![1.0],
            prob_b: vec![0.0],
            prob_density_marks: Vec::new(),
            sv_indices: vec![1, 2],
            label: vec![1, -1],
            n_sv: vec![1, 1],
        };
        let x = single_feature(1.0);

        let vote_label = predict(&model, &x);
        let (prob_label, _) = predict_probability(&model, &x).unwrap();

        assert_eq!(vote_label, 1.0);
        assert_eq!(prob_label, -1.0);
        assert_eq!(predict_cv_target(&model, &param, &x), prob_label);
    }
}