// libsvm_rs/kernel.rs

//! Kernel functions matching the original LIBSVM.
//!
//! Provides both:
//! - Standalone `k_function` for prediction (operates on sparse node slices)
//! - `Kernel` struct for training (precomputes x_square for RBF, stores refs)

use std::cmp::Ordering;

use crate::types::{KernelType, SvmNode, SvmParameter};

// ─── Integer power (matches LIBSVM's powi) ─────────────────────────

/// Integer power by repeated squaring. Matches LIBSVM's `powi(base, times)`.
///
/// Walks the binary digits of `times` from least to most significant,
/// multiplying the running result by the current power of `base` whenever
/// the digit is set.
///
/// For `times <= 0` the loop body never runs and the result is `1.0`
/// (same as the C code, which only iterates while `t > 0`); negative
/// exponents are therefore *not* treated as reciprocals.
#[inline]
pub fn powi(base: f64, times: i32) -> f64 {
    let mut factor = base;
    let mut result = 1.0;
    let mut remaining = times;
    while remaining > 0 {
        // Low bit set: this power of `base` contributes to the result.
        if remaining % 2 == 1 {
            result *= factor;
        }
        factor *= factor;
        remaining /= 2;
    }
    result
}

// ─── Sparse dot product ─────────────────────────────────────────────

32/// Sparse dot product of two sorted-by-index node slices.
33///
34/// This is the merge-based O(n+m) algorithm from LIBSVM.
35#[inline]
36pub fn dot(x: &[SvmNode], y: &[SvmNode]) -> f64 {
37    let mut sum = 0.0;
38    let mut ix = 0;
39    let mut iy = 0;
40    let x_len = x.len();
41    let y_len = y.len();
42    while ix < x_len && iy < y_len {
43        let x_node = &x[ix];
44        let y_node = &y[iy];
45        if x_node.index == y_node.index {
46            sum += x_node.value * y_node.value;
47            ix += 1;
48            iy += 1;
49        } else if x_node.index > y_node.index {
50            iy += 1;
51        } else {
52            ix += 1;
53        }
54    }
55    sum
56}
57
58/// Squared Euclidean distance for sparse vectors (used by RBF k_function).
59///
60/// Computes ‖x - y‖² without computing the difference vector.
61#[inline]
62fn sparse_sq_dist(x: &[SvmNode], y: &[SvmNode]) -> f64 {
63    let mut sum = 0.0;
64    let mut ix = 0;
65    let mut iy = 0;
66    while ix < x.len() && iy < y.len() {
67        if x[ix].index == y[iy].index {
68            let d = x[ix].value - y[iy].value;
69            sum += d * d;
70            ix += 1;
71            iy += 1;
72        } else if x[ix].index > y[iy].index {
73            sum += y[iy].value * y[iy].value;
74            iy += 1;
75        } else {
76            sum += x[ix].value * x[ix].value;
77            ix += 1;
78        }
79    }
80    // Drain remaining elements
81    while ix < x.len() {
82        sum += x[ix].value * x[ix].value;
83        ix += 1;
84    }
85    while iy < y.len() {
86        sum += y[iy].value * y[iy].value;
87        iy += 1;
88    }
89    sum
90}
91
// ─── Standalone kernel evaluation ───────────────────────────────────

94/// Evaluate the kernel function K(x, y) for the given parameters.
95///
96/// This is the standalone version used during prediction. Matches
97/// LIBSVM's `Kernel::k_function`.
98pub fn k_function(x: &[SvmNode], y: &[SvmNode], param: &SvmParameter) -> f64 {
99    match param.kernel_type {
100        KernelType::Linear => dot(x, y),
101        KernelType::Polynomial => powi(param.gamma * dot(x, y) + param.coef0, param.degree),
102        KernelType::Rbf => (-param.gamma * sparse_sq_dist(x, y)).exp(),
103        KernelType::Sigmoid => (param.gamma * dot(x, y) + param.coef0).tanh(),
104        KernelType::Precomputed => {
105            // For precomputed kernels, x[y[0].value as index] gives the value.
106            // y[0].value is the column index (1-based SV index).
107            //
108            // Safety note: `node.value as usize` is a float→usize cast that saturates
109            // on out-of-range values (negative or > usize::MAX) — this matches upstream
110            // C++ behaviour. `.get()` then maps any out-of-range index to None, and
111            // `.map_or(0.0, ...)` returns 0.0 in that case (silent mis-map, no panic).
112            // Intentional: check_parameter validates the precomputed kernel at train time;
113            // at predict time the model loader validates the 0:serial_number row header.
114            y.first()
115                .and_then(|node| x.get(node.value as usize))
116                .map_or(0.0, |n| n.value)
117        }
118    }
119}
120
// ─── Kernel struct for training ─────────────────────────────────────

123/// Kernel evaluator for training. Holds references to the dataset and
124/// precomputes `x_square[i] = dot(x[i], x[i])` for RBF kernels.
125///
126/// Stores `Vec<&'a [SvmNode]>` so that the solver can swap entries
127/// during shrinking (mirroring the C++ pointer-array swap trick).
128///
129/// The `kernel_function` method pointer pattern from C++ is replaced
130/// by a match on `kernel_type` — the branch predictor handles this
131/// efficiently since the type doesn't change during training.
132pub struct Kernel<'a> {
133    x: Vec<&'a [SvmNode]>,
134    x_square: Option<Vec<f64>>,
135    kernel_type: KernelType,
136    degree: i32,
137    gamma: f64,
138    coef0: f64,
139}
140
141impl<'a> Kernel<'a> {
142    /// Create a new kernel evaluator for the given dataset and parameters.
143    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
144        let x_refs: Vec<&'a [SvmNode]> = x.iter().map(|xi| xi.as_slice()).collect();
145        let x_square = if param.kernel_type == KernelType::Rbf {
146            Some(x_refs.iter().map(|xi| dot(xi, xi)).collect())
147        } else {
148            None
149        };
150
151        Self {
152            x: x_refs,
153            x_square,
154            kernel_type: param.kernel_type,
155            degree: param.degree,
156            gamma: param.gamma,
157            coef0: param.coef0,
158        }
159    }
160
161    /// Evaluate K(x\[i\], x\[j\]) using precomputed data where possible.
162    #[inline]
163    pub fn evaluate(&self, i: usize, j: usize) -> f64 {
164        let xi = self.x[i];
165        let xj = self.x[j];
166        match self.kernel_type {
167            KernelType::Linear => dot(xi, xj),
168            KernelType::Polynomial => powi(self.gamma * dot(xi, xj) + self.coef0, self.degree),
169            KernelType::Rbf => {
170                // Use precomputed x_square: ‖x_i - x_j‖² = x_sq[i] + x_sq[j] - 2*dot(x_i, x_j)
171                let val = if let Some(sq) = &self.x_square {
172                    sq[i] + sq[j] - 2.0 * dot(xi, xj)
173                } else {
174                    sparse_sq_dist(xi, xj)
175                };
176                (-self.gamma * val).exp()
177            }
178            KernelType::Sigmoid => (self.gamma * dot(xi, xj) + self.coef0).tanh(),
179            // See the free-function kernel_function above for the safety note on the
180            // float→usize cast: saturates on out-of-range values, .get() maps to None,
181            // .map_or returns 0.0 (intentional, matches upstream; silent mis-map, no panic).
182            KernelType::Precomputed => xj
183                .first()
184                .and_then(|node| xi.get(node.value as usize))
185                .map_or(0.0, |n| n.value),
186        }
187    }
188
189    /// Swap data-point references and precomputed squares at positions i and j.
190    ///
191    /// Used by QMatrix implementations during solver shrinking.
192    pub fn swap_index(&mut self, i: usize, j: usize) {
193        self.x.swap(i, j);
194        if let Some(ref mut sq) = self.x_square {
195            sq.swap(i, j);
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SvmParameter;

    /// Build a sparse row from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    }

    #[test]
    fn powi_basic() {
        assert_eq!(powi(2.0, 10), 1024.0);
        assert_eq!(powi(3.0, 0), 1.0);
        assert_eq!(powi(5.0, 1), 5.0);
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
        // Negative exponent: LIBSVM returns 1.0 (loop doesn't execute)
        assert_eq!(powi(2.0, -1), 1.0);
    }

    #[test]
    fn dot_product() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0), (5, 3.0)]);
        let y = make_nodes(&[(1, 4.0), (2, 5.0), (5, 6.0)]);
        // Shared indices are 1 and 5: dot = 1*4 + 3*6 = 4 + 18 = 22
        assert!((dot(&x, &y) - 22.0).abs() < 1e-15);
    }

    #[test]
    fn dot_disjoint() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0)]);
        let y = make_nodes(&[(2, 5.0), (4, 6.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn dot_empty() {
        let x = make_nodes(&[]);
        let y = make_nodes(&[(1, 1.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn kernel_linear() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        // dot = 1*3 + 2*4 = 11
        assert!((k_function(&x, &y, &param) - 11.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_rbf() {
        let x = make_nodes(&[(1, 1.0), (2, 0.0)]);
        let y = make_nodes(&[(1, 0.0), (2, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        // ‖x-y‖² = 1+1 = 2, K = exp(-0.5 * 2) = exp(-1)
        let expected = (-1.0_f64).exp();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_poly() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Polynomial,
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
            ..Default::default()
        };
        // (gamma * dot + coef0)^degree = (1 * 11 + 1)^2 = 12^2 = 144
        assert!((k_function(&x, &y, &param) - 144.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_sigmoid() {
        let x = make_nodes(&[(1, 1.0)]);
        let y = make_nodes(&[(1, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Sigmoid,
            gamma: 1.0,
            coef0: 0.0,
            ..Default::default()
        };
        // tanh(1*1 + 0) = tanh(1)
        let expected = 1.0_f64.tanh();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_precomputed() {
        // Precomputed rows use 0:sample_id followed by 1..l kernel values.
        let x = make_nodes(&[(0, 1.0), (1, 1.5), (2, 2.5)]);
        let y = make_nodes(&[(0, 2.0), (1, 1.5), (2, 2.5)]);
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };

        // y[0].value = 2 => take x[2] => 2.5
        assert!((k_function(&x, &y, &param) - 2.5).abs() < 1e-15);

        let data = vec![x.clone(), y.clone()];
        let kern = Kernel::new(&data, &param);
        // evaluate(0,1) uses column index from row 1 sample id (2)
        assert!((kern.evaluate(0, 1) - 2.5).abs() < 1e-15);
    }

    #[test]
    fn kernel_struct_matches_standalone() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };

        let kern = Kernel::new(&data, &param);

        // Verify Kernel::evaluate matches k_function for all pairs
        for i in 0..data.len() {
            for j in 0..data.len() {
                let via_struct = kern.evaluate(i, j);
                let via_func = k_function(&data[i], &data[j], &param);
                assert!(
                    (via_struct - via_func).abs() < 1e-15,
                    "mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    via_struct,
                    via_func
                );
            }
        }
    }

    #[test]
    fn rbf_self_kernel_is_one() {
        let x = make_nodes(&[(1, 3.0), (5, -2.0), (10, 0.7)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 1.0,
            ..Default::default()
        };
        // K(x, x) = exp(-γ * 0) = 1
        assert!((k_function(&x, &x, &param) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn precomputed_kernel_missing_sample_serial_number_returns_zero() {
        let x = make_nodes(&[(0, 1.0), (1, 2.0)]);
        let y = Vec::new();
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        // An empty `y` has no serial-number node, so the lookup yields 0.0.
        assert_eq!(k_function(&x, &y, &param), 0.0);
    }
}