Skip to main content

libsvm_rs/
kernel.rs

1//! Kernel functions matching the original LIBSVM.
2//!
3//! Provides both:
4//! - Standalone `k_function` for prediction (operates on sparse node slices)
5//! - `Kernel` struct for training (precomputes x_square for RBF, stores refs)
6
7use crate::types::{KernelType, SvmNode, SvmParameter};
8
9// ─── Integer power (matches LIBSVM's powi) ─────────────────────────
10
/// Raises `base` to the integer exponent `times` using binary exponentiation.
///
/// Behaves like LIBSVM's `powi(base, times)`: the loop only runs while the
/// remaining exponent is positive, so a zero or negative `times` yields `1.0`
/// (no reciprocal is ever taken).
#[inline]
pub fn powi(base: f64, times: i32) -> f64 {
    let mut result = 1.0;
    // Holds base^(2^k) at the start of round k.
    let mut factor = base;
    let mut remaining = times;
    while remaining > 0 {
        // Low bit set: this power of two is part of the exponent.
        if (remaining & 1) == 1 {
            result *= factor;
        }
        factor *= factor;
        remaining >>= 1;
    }
    result
}
29
30// ─── Sparse dot product ─────────────────────────────────────────────
31
32/// Sparse dot product of two sorted-by-index node slices.
33///
34/// This is the merge-based O(n+m) algorithm from LIBSVM.
35#[inline]
36pub fn dot(x: &[SvmNode], y: &[SvmNode]) -> f64 {
37    let mut sum = 0.0;
38    let mut ix = 0;
39    let mut iy = 0;
40    while ix < x.len() && iy < y.len() {
41        if x[ix].index == y[iy].index {
42            sum += x[ix].value * y[iy].value;
43            ix += 1;
44            iy += 1;
45        } else if x[ix].index > y[iy].index {
46            iy += 1;
47        } else {
48            ix += 1;
49        }
50    }
51    sum
52}
53
54/// Squared Euclidean distance for sparse vectors (used by RBF k_function).
55///
56/// Computes ‖x - y‖² without computing the difference vector.
57#[inline]
58fn sparse_sq_dist(x: &[SvmNode], y: &[SvmNode]) -> f64 {
59    let mut sum = 0.0;
60    let mut ix = 0;
61    let mut iy = 0;
62    while ix < x.len() && iy < y.len() {
63        if x[ix].index == y[iy].index {
64            let d = x[ix].value - y[iy].value;
65            sum += d * d;
66            ix += 1;
67            iy += 1;
68        } else if x[ix].index > y[iy].index {
69            sum += y[iy].value * y[iy].value;
70            iy += 1;
71        } else {
72            sum += x[ix].value * x[ix].value;
73            ix += 1;
74        }
75    }
76    // Drain remaining elements
77    while ix < x.len() {
78        sum += x[ix].value * x[ix].value;
79        ix += 1;
80    }
81    while iy < y.len() {
82        sum += y[iy].value * y[iy].value;
83        iy += 1;
84    }
85    sum
86}
87
88// ─── Standalone kernel evaluation ───────────────────────────────────
89
90/// Evaluate the kernel function K(x, y) for the given parameters.
91///
92/// This is the standalone version used during prediction. Matches
93/// LIBSVM's `Kernel::k_function`.
94pub fn k_function(x: &[SvmNode], y: &[SvmNode], param: &SvmParameter) -> f64 {
95    match param.kernel_type {
96        KernelType::Linear => dot(x, y),
97        KernelType::Polynomial => powi(param.gamma * dot(x, y) + param.coef0, param.degree),
98        KernelType::Rbf => (-param.gamma * sparse_sq_dist(x, y)).exp(),
99        KernelType::Sigmoid => (param.gamma * dot(x, y) + param.coef0).tanh(),
100        KernelType::Precomputed => {
101            // For precomputed kernels, x[y[0].value as index] gives the value.
102            // y[0].value is the column index (1-based SV index).
103            //
104            // Safety note: `node.value as usize` is a float→usize cast that saturates
105            // on out-of-range values (negative or > usize::MAX) — this matches upstream
106            // C++ behaviour. `.get()` then maps any out-of-range index to None, and
107            // `.map_or(0.0, ...)` returns 0.0 in that case (silent mis-map, no panic).
108            // Intentional: check_parameter validates the precomputed kernel at train time;
109            // at predict time the model loader validates the 0:serial_number row header.
110            y.first()
111                .and_then(|node| x.get(node.value as usize))
112                .map_or(0.0, |n| n.value)
113        }
114    }
115}
116
117// ─── Kernel struct for training ─────────────────────────────────────
118
119/// Kernel evaluator for training. Holds references to the dataset and
120/// precomputes `x_square[i] = dot(x[i], x[i])` for RBF kernels.
121///
122/// Stores `Vec<&'a [SvmNode]>` so that the solver can swap entries
123/// during shrinking (mirroring the C++ pointer-array swap trick).
124///
125/// The `kernel_function` method pointer pattern from C++ is replaced
126/// by a match on `kernel_type` — the branch predictor handles this
127/// efficiently since the type doesn't change during training.
128pub struct Kernel<'a> {
129    x: Vec<&'a [SvmNode]>,
130    x_square: Option<Vec<f64>>,
131    kernel_type: KernelType,
132    degree: i32,
133    gamma: f64,
134    coef0: f64,
135}
136
137impl<'a> Kernel<'a> {
138    /// Create a new kernel evaluator for the given dataset and parameters.
139    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
140        let x_refs: Vec<&'a [SvmNode]> = x.iter().map(|xi| xi.as_slice()).collect();
141        let x_square = if param.kernel_type == KernelType::Rbf {
142            Some(x_refs.iter().map(|xi| dot(xi, xi)).collect())
143        } else {
144            None
145        };
146
147        Self {
148            x: x_refs,
149            x_square,
150            kernel_type: param.kernel_type,
151            degree: param.degree,
152            gamma: param.gamma,
153            coef0: param.coef0,
154        }
155    }
156
157    /// Evaluate K(x\[i\], x\[j\]) using precomputed data where possible.
158    #[inline]
159    pub fn evaluate(&self, i: usize, j: usize) -> f64 {
160        match self.kernel_type {
161            KernelType::Linear => dot(self.x[i], self.x[j]),
162            KernelType::Polynomial => powi(
163                self.gamma * dot(self.x[i], self.x[j]) + self.coef0,
164                self.degree,
165            ),
166            KernelType::Rbf => {
167                // Use precomputed x_square: ‖x_i - x_j‖² = x_sq[i] + x_sq[j] - 2*dot(x_i, x_j)
168                let val = if let Some(sq) = &self.x_square {
169                    sq[i] + sq[j] - 2.0 * dot(self.x[i], self.x[j])
170                } else {
171                    sparse_sq_dist(self.x[i], self.x[j])
172                };
173                (-self.gamma * val).exp()
174            }
175            KernelType::Sigmoid => (self.gamma * dot(self.x[i], self.x[j]) + self.coef0).tanh(),
176            // See the free-function kernel_function above for the safety note on the
177            // float→usize cast: saturates on out-of-range values, .get() maps to None,
178            // .map_or returns 0.0 (intentional, matches upstream; silent mis-map, no panic).
179            KernelType::Precomputed => self.x[j]
180                .first()
181                .and_then(|node| self.x[i].get(node.value as usize))
182                .map_or(0.0, |n| n.value),
183        }
184    }
185
186    /// Swap data-point references and precomputed squares at positions i and j.
187    ///
188    /// Used by QMatrix implementations during solver shrinking.
189    pub fn swap_index(&mut self, i: usize, j: usize) {
190        self.x.swap(i, j);
191        if let Some(ref mut sq) = self.x_square {
192            sq.swap(i, j);
193        }
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SvmParameter;

    /// Absolute tolerance for the approximate floating-point checks.
    const TOL: f64 = 1e-15;

    /// Builds a sparse vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        let mut nodes = Vec::with_capacity(pairs.len());
        for &(index, value) in pairs {
            nodes.push(SvmNode { index, value });
        }
        nodes
    }

    /// Default parameters with only the kernel type overridden.
    fn params_for(kernel_type: KernelType) -> SvmParameter {
        SvmParameter {
            kernel_type,
            ..Default::default()
        }
    }

    /// Fails the calling test unless `actual` is within `TOL` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn powi_basic() {
        // (base, exponent, exact result). The last row is a negative
        // exponent: LIBSVM returns 1.0 because the loop never executes.
        let exact = [
            (2.0, 10, 1024.0),
            (3.0, 0, 1.0),
            (5.0, 1, 5.0),
            (2.0, -1, 1.0),
        ];
        for &(base, exponent, expected) in &exact {
            assert_eq!(powi(base, exponent), expected);
        }
        assert_close(powi(2.0, 3), 8.0);
    }

    #[test]
    fn dot_product() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0), (5, 3.0)]);
        let y = make_nodes(&[(1, 4.0), (2, 5.0), (5, 6.0)]);
        // Shared indices are 1 and 5: 1*4 + 3*6 = 22
        assert_close(dot(&x, &y), 22.0);
    }

    #[test]
    fn dot_disjoint() {
        let odd = make_nodes(&[(1, 1.0), (3, 2.0)]);
        let even = make_nodes(&[(2, 5.0), (4, 6.0)]);
        assert_eq!(dot(&odd, &even), 0.0);
    }

    #[test]
    fn dot_empty() {
        let empty = make_nodes(&[]);
        let single = make_nodes(&[(1, 1.0)]);
        assert_eq!(dot(&empty, &single), 0.0);
    }

    #[test]
    fn kernel_linear() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = params_for(KernelType::Linear);
        // 1*3 + 2*4 = 11
        assert_close(k_function(&x, &y, &param), 11.0);
    }

    #[test]
    fn kernel_rbf() {
        let x = make_nodes(&[(1, 1.0), (2, 0.0)]);
        let y = make_nodes(&[(1, 0.0), (2, 1.0)]);
        let param = SvmParameter {
            gamma: 0.5,
            ..params_for(KernelType::Rbf)
        };
        // ‖x-y‖² = 1 + 1 = 2, so K = exp(-0.5 * 2) = exp(-1)
        assert_close(k_function(&x, &y, &param), (-1.0_f64).exp());
    }

    #[test]
    fn kernel_poly() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
            ..params_for(KernelType::Polynomial)
        };
        // (1 * 11 + 1)^2 = 12^2 = 144
        assert_close(k_function(&x, &y, &param), 144.0);
    }

    #[test]
    fn kernel_sigmoid() {
        let x = make_nodes(&[(1, 1.0)]);
        let y = make_nodes(&[(1, 1.0)]);
        let param = SvmParameter {
            gamma: 1.0,
            coef0: 0.0,
            ..params_for(KernelType::Sigmoid)
        };
        // tanh(1 * 1 + 0) = tanh(1)
        assert_close(k_function(&x, &y, &param), 1.0_f64.tanh());
    }

    #[test]
    fn kernel_precomputed() {
        // Precomputed rows are `0:sample_id` followed by kernel values 1..l.
        let x = make_nodes(&[(0, 1.0), (1, 1.5), (2, 2.5)]);
        let y = make_nodes(&[(0, 2.0), (1, 1.5), (2, 2.5)]);
        let param = params_for(KernelType::Precomputed);

        // y's sample id is 2, which selects x[2] = 2.5.
        assert_close(k_function(&x, &y, &param), 2.5);

        // The struct evaluator must pick the same column: row 1's sample id
        // (2) indexes into row 0.
        let data = vec![x, y];
        let kern = Kernel::new(&data, &param);
        assert_close(kern.evaluate(0, 1), 2.5);
    }

    #[test]
    fn kernel_struct_matches_standalone() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            gamma: 0.5,
            ..params_for(KernelType::Rbf)
        };
        let kern = Kernel::new(&data, &param);

        // Every ordered pair must agree between Kernel::evaluate and k_function.
        for (i, xi) in data.iter().enumerate() {
            for (j, xj) in data.iter().enumerate() {
                let via_struct = kern.evaluate(i, j);
                let via_func = k_function(xi, xj, &param);
                assert!(
                    (via_struct - via_func).abs() < TOL,
                    "mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    via_struct,
                    via_func
                );
            }
        }
    }

    #[test]
    fn rbf_self_kernel_is_one() {
        let x = make_nodes(&[(1, 3.0), (5, -2.0), (10, 0.7)]);
        let param = SvmParameter {
            gamma: 1.0,
            ..params_for(KernelType::Rbf)
        };
        // K(x, x) = exp(-γ * 0) = 1
        assert_close(k_function(&x, &x, &param), 1.0);
    }

    #[test]
    fn precomputed_kernel_missing_sample_serial_number_returns_zero() {
        let x = make_nodes(&[(0, 1.0), (1, 2.0)]);
        // No `0:sample_id` node at all on the y side.
        let y: &[SvmNode] = &[];
        let param = params_for(KernelType::Precomputed);
        assert_eq!(k_function(&x, y, &param), 0.0);
    }
}