libsvm_rs/kernel.rs

//! Kernel functions matching the original LIBSVM.
//!
//! Provides both:
//! - Standalone `k_function` for prediction (operates on sparse node slices)
//! - `Kernel` struct for training (precomputes x_square for RBF, stores refs)

use crate::types::{KernelType, SvmNode, SvmParameter};

// ─── Integer power (matches LIBSVM's powi) ─────────────────────────

/// Integer power by squaring. Matches LIBSVM's `powi(base, times)`.
///
/// For negative `times`, returns 1.0 (same as the C code, which only
/// iterates while `t > 0`).
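///
/// # Example
///
/// An illustrative sketch, marked `ignore` because the public path of this
/// module is an assumption:
///
/// ```ignore
/// use libsvm_rs::kernel::powi;
///
/// assert_eq!(powi(2.0, 10), 1024.0);
/// assert_eq!(powi(3.0, 0), 1.0);
/// assert_eq!(powi(2.0, -1), 1.0); // negative exponent: the loop never runs
/// ```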
#[inline]
pub fn powi(base: f64, times: i32) -> f64 {
    let mut tmp = base;
    let mut ret = 1.0;
    let mut t = times;
    while t > 0 {
        if t % 2 == 1 {
            ret *= tmp;
        }
        tmp *= tmp;
        t /= 2;
    }
    ret
}

// ─── Sparse dot product ─────────────────────────────────────────────

/// Sparse dot product of two sorted-by-index node slices.
///
/// This is the merge-based O(n+m) algorithm from LIBSVM.
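///
/// # Example
///
/// An illustrative sketch, marked `ignore` since it assumes `kernel` and
/// `types` are public modules of the crate:
///
/// ```ignore
/// use libsvm_rs::kernel::dot;
/// use libsvm_rs::types::SvmNode;
///
/// // Only indices 1 and 5 appear in both vectors: 1.0*4.0 + 3.0*6.0 = 22.0
/// let x = [SvmNode { index: 1, value: 1.0 }, SvmNode { index: 5, value: 3.0 }];
/// let y = [SvmNode { index: 1, value: 4.0 }, SvmNode { index: 5, value: 6.0 }];
/// assert_eq!(dot(&x, &y), 22.0);
/// ```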
#[inline]
pub fn dot(x: &[SvmNode], y: &[SvmNode]) -> f64 {
    let mut sum = 0.0;
    let mut ix = 0;
    let mut iy = 0;
    while ix < x.len() && iy < y.len() {
        if x[ix].index == y[iy].index {
            sum += x[ix].value * y[iy].value;
            ix += 1;
            iy += 1;
        } else if x[ix].index > y[iy].index {
            iy += 1;
        } else {
            ix += 1;
        }
    }
    sum
}

/// Squared Euclidean distance for sparse vectors (used by RBF k_function).
///
/// Computes ‖x - y‖² without computing the difference vector.
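///
/// Worked example: for x = {1: 1.0} and y = {2: 1.0} the index sets are
/// disjoint, so ‖x - y‖² = 1.0² + 1.0² = 2.0.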
#[inline]
fn sparse_sq_dist(x: &[SvmNode], y: &[SvmNode]) -> f64 {
    let mut sum = 0.0;
    let mut ix = 0;
    let mut iy = 0;
    while ix < x.len() && iy < y.len() {
        if x[ix].index == y[iy].index {
            let d = x[ix].value - y[iy].value;
            sum += d * d;
            ix += 1;
            iy += 1;
        } else if x[ix].index > y[iy].index {
            sum += y[iy].value * y[iy].value;
            iy += 1;
        } else {
            sum += x[ix].value * x[ix].value;
            ix += 1;
        }
    }
    // Drain remaining elements
    while ix < x.len() {
        sum += x[ix].value * x[ix].value;
        ix += 1;
    }
    while iy < y.len() {
        sum += y[iy].value * y[iy].value;
        iy += 1;
    }
    sum
}

// ─── Standalone kernel evaluation ───────────────────────────────────

/// Evaluate the kernel function K(x, y) for the given parameters.
///
/// This is the standalone version used during prediction. Matches
/// LIBSVM's `Kernel::k_function`.
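///
/// # Example
///
/// An illustrative sketch, marked `ignore` since the module paths are an
/// assumption (the `Default` impl on `SvmParameter` is used by the tests
/// below):
///
/// ```ignore
/// use libsvm_rs::kernel::k_function;
/// use libsvm_rs::types::{KernelType, SvmNode, SvmParameter};
///
/// let x = [SvmNode { index: 1, value: 1.0 }, SvmNode { index: 2, value: 2.0 }];
/// let y = [SvmNode { index: 1, value: 3.0 }, SvmNode { index: 2, value: 4.0 }];
/// let param = SvmParameter { kernel_type: KernelType::Linear, ..Default::default() };
/// assert_eq!(k_function(&x, &y, &param), 11.0); // 1*3 + 2*4
/// ```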
pub fn k_function(x: &[SvmNode], y: &[SvmNode], param: &SvmParameter) -> f64 {
    match param.kernel_type {
        KernelType::Linear => dot(x, y),
        KernelType::Polynomial => powi(param.gamma * dot(x, y) + param.coef0, param.degree),
        KernelType::Rbf => (-param.gamma * sparse_sq_dist(x, y)).exp(),
        KernelType::Sigmoid => (param.gamma * dot(x, y) + param.coef0).tanh(),
        KernelType::Precomputed => {
            // For precomputed kernels, y[0].value holds the (1-based) sample
            // index of y; it is used as a positional index into row x, whose
            // position 0 carries the sample id and positions 1..=l the values.
            let col = y[0].value as usize;
            x.get(col).map_or(0.0, |n| n.value)
        }
    }
}

// ─── Kernel struct for training ─────────────────────────────────────

/// Kernel evaluator for training. Holds references to the dataset and
/// precomputes `x_square[i] = dot(x[i], x[i])` for RBF kernels.
///
/// Stores `Vec<&'a [SvmNode]>` so that the solver can swap entries
/// during shrinking (mirroring the C++ pointer-array swap trick).
///
/// The `kernel_function` method pointer pattern from C++ is replaced
/// by a match on `kernel_type` — the branch predictor handles this
/// efficiently since the type doesn't change during training.
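///
/// # Example
///
/// An illustrative sketch, marked `ignore` since the module paths are an
/// assumption:
///
/// ```ignore
/// use libsvm_rs::kernel::Kernel;
/// use libsvm_rs::types::{KernelType, SvmNode, SvmParameter};
///
/// let data = vec![
///     vec![SvmNode { index: 1, value: 1.0 }],
///     vec![SvmNode { index: 1, value: 2.0 }],
/// ];
/// let param = SvmParameter { kernel_type: KernelType::Linear, ..Default::default() };
/// let kern = Kernel::new(&data, &param);
/// assert_eq!(kern.evaluate(0, 1), 2.0); // dot([1.0], [2.0]) = 2.0
/// ```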
pub struct Kernel<'a> {
    x: Vec<&'a [SvmNode]>,
    x_square: Option<Vec<f64>>,
    kernel_type: KernelType,
    degree: i32,
    gamma: f64,
    coef0: f64,
}

impl<'a> Kernel<'a> {
    /// Create a new kernel evaluator for the given dataset and parameters.
    pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
        let x_refs: Vec<&'a [SvmNode]> = x.iter().map(|xi| xi.as_slice()).collect();
        let x_square = if param.kernel_type == KernelType::Rbf {
            Some(x_refs.iter().map(|xi| dot(xi, xi)).collect())
        } else {
            None
        };

        Self {
            x: x_refs,
            x_square,
            kernel_type: param.kernel_type,
            degree: param.degree,
            gamma: param.gamma,
            coef0: param.coef0,
        }
    }

    /// Evaluate K(x\[i\], x\[j\]) using precomputed data where possible.
    #[inline]
    pub fn evaluate(&self, i: usize, j: usize) -> f64 {
        match self.kernel_type {
            KernelType::Linear => dot(self.x[i], self.x[j]),
            KernelType::Polynomial => powi(
                self.gamma * dot(self.x[i], self.x[j]) + self.coef0,
                self.degree,
            ),
            KernelType::Rbf => {
                // Use precomputed x_square: ‖x_i - x_j‖² = x_sq[i] + x_sq[j] - 2*dot(x_i, x_j)
                let sq = self.x_square.as_ref().unwrap();
                let val = sq[i] + sq[j] - 2.0 * dot(self.x[i], self.x[j]);
                (-self.gamma * val).exp()
            }
            KernelType::Sigmoid => (self.gamma * dot(self.x[i], self.x[j]) + self.coef0).tanh(),
            KernelType::Precomputed => {
                let col = self.x[j][0].value as usize;
                self.x[i].get(col).map_or(0.0, |n| n.value)
            }
        }
    }

    /// Swap data-point references and precomputed squares at positions i and j.
    ///
    /// Used by QMatrix implementations during solver shrinking.
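    ///
    /// # Example
    ///
    /// An illustrative sketch, marked `ignore` (`data` and `param` as in
    /// [`Kernel::new`]): after `swap_index(0, 1)`, evaluations at position 1
    /// refer to the point that was at position 0.
    ///
    /// ```ignore
    /// let mut kern = Kernel::new(&data, &param);
    /// let k_0_2 = kern.evaluate(0, 2);
    /// kern.swap_index(0, 1);
    /// assert_eq!(kern.evaluate(1, 2), k_0_2);
    /// ```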
    pub fn swap_index(&mut self, i: usize, j: usize) {
        self.x.swap(i, j);
        if let Some(ref mut sq) = self.x_square {
            sq.swap(i, j);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SvmParameter;

    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    }

    #[test]
    fn powi_basic() {
        assert_eq!(powi(2.0, 10), 1024.0);
        assert_eq!(powi(3.0, 0), 1.0);
        assert_eq!(powi(5.0, 1), 5.0);
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
        // Negative exponent: LIBSVM returns 1.0 (loop doesn't execute)
        assert_eq!(powi(2.0, -1), 1.0);
    }

    #[test]
    fn dot_product() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0), (5, 3.0)]);
        let y = make_nodes(&[(1, 4.0), (2, 5.0), (5, 6.0)]);
        // Only indices 1 and 5 match: dot = 1*4 + 3*6 = 22
        assert!((dot(&x, &y) - 22.0).abs() < 1e-15);
    }

    #[test]
    fn dot_disjoint() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0)]);
        let y = make_nodes(&[(2, 5.0), (4, 6.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn dot_empty() {
        let x = make_nodes(&[]);
        let y = make_nodes(&[(1, 1.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn kernel_linear() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        assert!((k_function(&x, &y, &param) - 11.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_rbf() {
        let x = make_nodes(&[(1, 1.0), (2, 0.0)]);
        let y = make_nodes(&[(1, 0.0), (2, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        // ‖x-y‖² = 1+1 = 2, K = exp(-0.5 * 2) = exp(-1)
        let expected = (-1.0_f64).exp();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_poly() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Polynomial,
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
            ..Default::default()
        };
        // (gamma * dot + coef0)^degree = (1*11 + 1)^2 = 12^2 = 144
        assert!((k_function(&x, &y, &param) - 144.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_sigmoid() {
        let x = make_nodes(&[(1, 1.0)]);
        let y = make_nodes(&[(1, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Sigmoid,
            gamma: 1.0,
            coef0: 0.0,
            ..Default::default()
        };
        // tanh(1*1 + 0) = tanh(1)
        let expected = 1.0_f64.tanh();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_precomputed() {
        // Precomputed rows use 0:sample_id followed by 1..l kernel values.
        let x = make_nodes(&[(0, 1.0), (1, 1.5), (2, 2.5)]);
        let y = make_nodes(&[(0, 2.0), (1, 1.5), (2, 2.5)]);
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };

        // y[0].value = 2 => take x[2] => 2.5
        assert!((k_function(&x, &y, &param) - 2.5).abs() < 1e-15);

        let data = vec![x.clone(), y.clone()];
        let kern = Kernel::new(&data, &param);
        // evaluate(0, 1) uses row 1's sample id (2) as the column index
        assert!((kern.evaluate(0, 1) - 2.5).abs() < 1e-15);
    }

    #[test]
    fn kernel_struct_matches_standalone() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };

        let kern = Kernel::new(&data, &param);

        // Verify Kernel::evaluate matches k_function for all pairs
        for i in 0..data.len() {
            for j in 0..data.len() {
                let via_struct = kern.evaluate(i, j);
                let via_func = k_function(&data[i], &data[j], &param);
                assert!(
                    (via_struct - via_func).abs() < 1e-15,
                    "mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    via_struct,
                    via_func
                );
            }
        }
    }
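
    // An added sketch (not part of the original test set): checks that
    // swap_index re-targets evaluate, mirroring the pointer swap the C++
    // solver performs during shrinking. Uses only the API defined above.
    #[test]
    fn swap_index_retargets_evaluate() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };

        let mut kern = Kernel::new(&data, &param);
        kern.swap_index(0, 1);

        // Position 0 now holds data[1]; position 2 is untouched.
        let expected = k_function(&data[1], &data[2], &param);
        assert!((kern.evaluate(0, 2) - expected).abs() < 1e-15);
        // And position 1 now holds data[0].
        let expected = k_function(&data[0], &data[2], &param);
        assert!((kern.evaluate(1, 2) - expected).abs() < 1e-15);
    }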

    #[test]
    fn rbf_self_kernel_is_one() {
        let x = make_nodes(&[(1, 3.0), (5, -2.0), (10, 0.7)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 1.0,
            ..Default::default()
        };
        // K(x, x) = exp(-γ * 0) = 1
        assert!((k_function(&x, &x, &param) - 1.0).abs() < 1e-15);
    }
}
344}