1use crate::types::{KernelType, SvmNode, SvmParameter};
8
/// Raises `base` to the integer power `times` by repeated squaring.
///
/// The exponent is consumed one binary digit at a time: whenever the lowest
/// remaining bit is set, the current squared factor is folded into the result.
///
/// A zero or negative `times` never enters the multiplication step, so the
/// function returns `1.0` for every non-positive exponent.
#[inline]
pub fn powi(base: f64, times: i32) -> f64 {
    let mut result = 1.0;
    let mut factor = base;
    let mut remaining = times;
    loop {
        if remaining <= 0 {
            break result;
        }
        // `remaining` is strictly positive here, so testing the low bit is the
        // same as testing for an odd exponent.
        if (remaining & 1) == 1 {
            result *= factor;
        }
        factor *= factor;
        // Halving a positive integer is a right shift by one.
        remaining >>= 1;
    }
}
29
30#[inline]
36pub fn dot(x: &[SvmNode], y: &[SvmNode]) -> f64 {
37 let mut sum = 0.0;
38 let mut ix = 0;
39 let mut iy = 0;
40 let x_len = x.len();
41 let y_len = y.len();
42 while ix < x_len && iy < y_len {
43 let x_node = &x[ix];
44 let y_node = &y[iy];
45 if x_node.index == y_node.index {
46 sum += x_node.value * y_node.value;
47 ix += 1;
48 iy += 1;
49 } else if x_node.index > y_node.index {
50 iy += 1;
51 } else {
52 ix += 1;
53 }
54 }
55 sum
56}
57
58#[inline]
62fn sparse_sq_dist(x: &[SvmNode], y: &[SvmNode]) -> f64 {
63 let mut sum = 0.0;
64 let mut ix = 0;
65 let mut iy = 0;
66 while ix < x.len() && iy < y.len() {
67 if x[ix].index == y[iy].index {
68 let d = x[ix].value - y[iy].value;
69 sum += d * d;
70 ix += 1;
71 iy += 1;
72 } else if x[ix].index > y[iy].index {
73 sum += y[iy].value * y[iy].value;
74 iy += 1;
75 } else {
76 sum += x[ix].value * x[ix].value;
77 ix += 1;
78 }
79 }
80 while ix < x.len() {
82 sum += x[ix].value * x[ix].value;
83 ix += 1;
84 }
85 while iy < y.len() {
86 sum += y[iy].value * y[iy].value;
87 iy += 1;
88 }
89 sum
90}
91
92pub fn k_function(x: &[SvmNode], y: &[SvmNode], param: &SvmParameter) -> f64 {
99 match param.kernel_type {
100 KernelType::Linear => dot(x, y),
101 KernelType::Polynomial => powi(param.gamma * dot(x, y) + param.coef0, param.degree),
102 KernelType::Rbf => (-param.gamma * sparse_sq_dist(x, y)).exp(),
103 KernelType::Sigmoid => (param.gamma * dot(x, y) + param.coef0).tanh(),
104 KernelType::Precomputed => {
105 y.first()
115 .and_then(|node| x.get(node.value as usize))
116 .map_or(0.0, |n| n.value)
117 }
118 }
119}
120
/// Kernel evaluator bound to a borrowed collection of rows.
///
/// Built by `Kernel::new`, which copies the kernel settings out of an
/// `SvmParameter` and keeps one borrowed slice per input row, so the evaluator
/// cannot outlive the data it was built from (lifetime `'a`).
pub struct Kernel<'a> {
    /// One borrowed slice per input row; `evaluate` and `swap_index` address
    /// rows by their position in this vector.
    x: Vec<&'a [SvmNode]>,
    /// `dot(row, row)` for every row, in the same order as `x`. Populated only
    /// when the kernel type is `Rbf`; `None` otherwise.
    x_square: Option<Vec<f64>>,
    /// Which kernel `evaluate` computes.
    kernel_type: KernelType,
    /// Exponent passed to `powi` by the polynomial kernel.
    degree: i32,
    /// Scale factor used by the polynomial, RBF and sigmoid kernels.
    gamma: f64,
    /// Additive term used by the polynomial and sigmoid kernels.
    coef0: f64,
}
140
141impl<'a> Kernel<'a> {
142 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
144 let x_refs: Vec<&'a [SvmNode]> = x.iter().map(|xi| xi.as_slice()).collect();
145 let x_square = if param.kernel_type == KernelType::Rbf {
146 Some(x_refs.iter().map(|xi| dot(xi, xi)).collect())
147 } else {
148 None
149 };
150
151 Self {
152 x: x_refs,
153 x_square,
154 kernel_type: param.kernel_type,
155 degree: param.degree,
156 gamma: param.gamma,
157 coef0: param.coef0,
158 }
159 }
160
161 #[inline]
163 pub fn evaluate(&self, i: usize, j: usize) -> f64 {
164 let xi = self.x[i];
165 let xj = self.x[j];
166 match self.kernel_type {
167 KernelType::Linear => dot(xi, xj),
168 KernelType::Polynomial => powi(self.gamma * dot(xi, xj) + self.coef0, self.degree),
169 KernelType::Rbf => {
170 let val = if let Some(sq) = &self.x_square {
172 sq[i] + sq[j] - 2.0 * dot(xi, xj)
173 } else {
174 sparse_sq_dist(xi, xj)
175 };
176 (-self.gamma * val).exp()
177 }
178 KernelType::Sigmoid => (self.gamma * dot(xi, xj) + self.coef0).tanh(),
179 KernelType::Precomputed => xj
183 .first()
184 .and_then(|node| xi.get(node.value as usize))
185 .map_or(0.0, |n| n.value),
186 }
187 }
188
189 pub fn swap_index(&mut self, i: usize, j: usize) {
193 self.x.swap(i, j);
194 if let Some(ref mut sq) = self.x_square {
195 sq.swap(i, j);
196 }
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SvmParameter;

    /// Builds a sparse vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    }

    #[test]
    fn powi_basic() {
        assert_eq!(powi(2.0, 10), 1024.0);
        assert_eq!(powi(3.0, 0), 1.0);
        assert_eq!(powi(5.0, 1), 5.0);
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
        // A negative exponent never enters the multiplication loop.
        assert_eq!(powi(2.0, -1), 1.0);
    }

    #[test]
    fn dot_product() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0), (5, 3.0)]);
        let y = make_nodes(&[(1, 4.0), (2, 5.0), (5, 6.0)]);
        // Shared indices are 1 and 5: 1*4 + 3*6 = 22.
        assert!((dot(&x, &y) - 22.0).abs() < 1e-15);
    }

    #[test]
    fn dot_disjoint() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0)]);
        let y = make_nodes(&[(2, 5.0), (4, 6.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn dot_empty() {
        let x = make_nodes(&[]);
        let y = make_nodes(&[(1, 1.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn kernel_linear() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        assert!((k_function(&x, &y, &param) - 11.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_rbf() {
        let x = make_nodes(&[(1, 1.0), (2, 0.0)]);
        let y = make_nodes(&[(1, 0.0), (2, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        // Squared distance is 2, so the kernel is exp(-0.5 * 2) = exp(-1).
        let expected = (-1.0_f64).exp();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_poly() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Polynomial,
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
            ..Default::default()
        };
        // (1 * 11 + 1)^2 = 144.
        assert!((k_function(&x, &y, &param) - 144.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_sigmoid() {
        let x = make_nodes(&[(1, 1.0)]);
        let y = make_nodes(&[(1, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Sigmoid,
            gamma: 1.0,
            coef0: 0.0,
            ..Default::default()
        };
        // tanh(1 * 1 + 0).
        let expected = 1.0_f64.tanh();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_precomputed() {
        // The first node of `y` holds 2.0, which selects position 2 of `x`.
        let x = make_nodes(&[(0, 1.0), (1, 1.5), (2, 2.5)]);
        let y = make_nodes(&[(0, 2.0), (1, 1.5), (2, 2.5)]);
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };

        assert!((k_function(&x, &y, &param) - 2.5).abs() < 1e-15);

        // The struct-based evaluator must agree with the free function.
        let data = vec![x.clone(), y.clone()];
        let kern = Kernel::new(&data, &param);
        assert!((kern.evaluate(0, 1) - 2.5).abs() < 1e-15);
    }

    #[test]
    fn kernel_struct_matches_standalone() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };

        let kern = Kernel::new(&data, &param);

        // Compare every ordered pair, including each row against itself.
        for i in 0..data.len() {
            for j in 0..data.len() {
                let via_struct = kern.evaluate(i, j);
                let via_func = k_function(&data[i], &data[j], &param);
                assert!(
                    (via_struct - via_func).abs() < 1e-15,
                    "mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    via_struct,
                    via_func
                );
            }
        }
    }

    #[test]
    fn rbf_self_kernel_is_one() {
        let x = make_nodes(&[(1, 3.0), (5, -2.0), (10, 0.7)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 1.0,
            ..Default::default()
        };
        // Distance from a vector to itself is zero, and exp(0) = 1.
        assert!((k_function(&x, &x, &param) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn precomputed_kernel_missing_sample_serial_number_returns_zero() {
        let x = make_nodes(&[(0, 1.0), (1, 2.0)]);
        let y: Vec<SvmNode> = Vec::new();
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        assert_eq!(k_function(&x, &y, &param), 0.0);
    }
}