1use crate::types::{KernelType, SvmNode, SvmParameter};
8
/// Raises `base` to the integer power `times` by binary exponentiation
/// (square-and-multiply), using O(log times) multiplications.
///
/// Only positive exponents do any work: for `times <= 0` the loop body never
/// runs and the result is `1.0` (no reciprocal is taken for negative powers).
#[inline]
pub fn powi(base: f64, times: i32) -> f64 {
    let mut result = 1.0;
    let mut square = base;
    let mut remaining = times;
    while remaining > 0 {
        // The low bit of the remaining exponent decides whether the current
        // square contributes to the product.
        if remaining & 1 != 0 {
            result *= square;
        }
        square *= square;
        remaining >>= 1;
    }
    result
}
29
30#[inline]
36pub fn dot(x: &[SvmNode], y: &[SvmNode]) -> f64 {
37 let mut sum = 0.0;
38 let mut ix = 0;
39 let mut iy = 0;
40 while ix < x.len() && iy < y.len() {
41 if x[ix].index == y[iy].index {
42 sum += x[ix].value * y[iy].value;
43 ix += 1;
44 iy += 1;
45 } else if x[ix].index > y[iy].index {
46 iy += 1;
47 } else {
48 ix += 1;
49 }
50 }
51 sum
52}
53
54#[inline]
58fn sparse_sq_dist(x: &[SvmNode], y: &[SvmNode]) -> f64 {
59 let mut sum = 0.0;
60 let mut ix = 0;
61 let mut iy = 0;
62 while ix < x.len() && iy < y.len() {
63 if x[ix].index == y[iy].index {
64 let d = x[ix].value - y[iy].value;
65 sum += d * d;
66 ix += 1;
67 iy += 1;
68 } else if x[ix].index > y[iy].index {
69 sum += y[iy].value * y[iy].value;
70 iy += 1;
71 } else {
72 sum += x[ix].value * x[ix].value;
73 ix += 1;
74 }
75 }
76 while ix < x.len() {
78 sum += x[ix].value * x[ix].value;
79 ix += 1;
80 }
81 while iy < y.len() {
82 sum += y[iy].value * y[iy].value;
83 iy += 1;
84 }
85 sum
86}
87
88pub fn k_function(x: &[SvmNode], y: &[SvmNode], param: &SvmParameter) -> f64 {
95 match param.kernel_type {
96 KernelType::Linear => dot(x, y),
97 KernelType::Polynomial => powi(param.gamma * dot(x, y) + param.coef0, param.degree),
98 KernelType::Rbf => (-param.gamma * sparse_sq_dist(x, y)).exp(),
99 KernelType::Sigmoid => (param.gamma * dot(x, y) + param.coef0).tanh(),
100 KernelType::Precomputed => {
101 y.first()
104 .and_then(|node| x.get(node.value as usize))
105 .map_or(0.0, |n| n.value)
106 }
107 }
108}
109
110pub struct Kernel<'a> {
122 x: Vec<&'a [SvmNode]>,
123 x_square: Option<Vec<f64>>,
124 kernel_type: KernelType,
125 degree: i32,
126 gamma: f64,
127 coef0: f64,
128}
129
130impl<'a> Kernel<'a> {
131 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
133 let x_refs: Vec<&'a [SvmNode]> = x.iter().map(|xi| xi.as_slice()).collect();
134 let x_square = if param.kernel_type == KernelType::Rbf {
135 Some(x_refs.iter().map(|xi| dot(xi, xi)).collect())
136 } else {
137 None
138 };
139
140 Self {
141 x: x_refs,
142 x_square,
143 kernel_type: param.kernel_type,
144 degree: param.degree,
145 gamma: param.gamma,
146 coef0: param.coef0,
147 }
148 }
149
150 #[inline]
152 pub fn evaluate(&self, i: usize, j: usize) -> f64 {
153 match self.kernel_type {
154 KernelType::Linear => dot(self.x[i], self.x[j]),
155 KernelType::Polynomial => powi(
156 self.gamma * dot(self.x[i], self.x[j]) + self.coef0,
157 self.degree,
158 ),
159 KernelType::Rbf => {
160 let val = if let Some(sq) = &self.x_square {
162 sq[i] + sq[j] - 2.0 * dot(self.x[i], self.x[j])
163 } else {
164 sparse_sq_dist(self.x[i], self.x[j])
165 };
166 (-self.gamma * val).exp()
167 }
168 KernelType::Sigmoid => (self.gamma * dot(self.x[i], self.x[j]) + self.coef0).tanh(),
169 KernelType::Precomputed => self.x[j]
170 .first()
171 .and_then(|node| self.x[i].get(node.value as usize))
172 .map_or(0.0, |n| n.value),
173 }
174 }
175
176 pub fn swap_index(&mut self, i: usize, j: usize) {
180 self.x.swap(i, j);
181 if let Some(ref mut sq) = self.x_square {
182 sq.swap(i, j);
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SvmParameter;

    /// Builds a sparse vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    }

    #[test]
    fn powi_basic() {
        assert_eq!(powi(2.0, 10), 1024.0);
        assert_eq!(powi(3.0, 0), 1.0);
        assert_eq!(powi(5.0, 1), 5.0);
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
        // Non-positive exponents never enter the loop and yield 1.0.
        assert_eq!(powi(2.0, -1), 1.0);
    }

    #[test]
    fn dot_product() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0), (5, 3.0)]);
        let y = make_nodes(&[(1, 4.0), (2, 5.0), (5, 6.0)]);
        // Only indices 1 and 5 overlap: 1*4 + 3*6 = 22.
        assert!((dot(&x, &y) - 22.0).abs() < 1e-15);
    }

    #[test]
    fn dot_disjoint() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0)]);
        let y = make_nodes(&[(2, 5.0), (4, 6.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn dot_empty() {
        let x = make_nodes(&[]);
        let y = make_nodes(&[(1, 1.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn kernel_linear() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        assert!((k_function(&x, &y, &param) - 11.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_rbf() {
        let x = make_nodes(&[(1, 1.0), (2, 0.0)]);
        let y = make_nodes(&[(1, 0.0), (2, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        // ‖x − y‖² = 2, so exp(-0.5 * 2) = exp(-1).
        let expected = (-1.0_f64).exp();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_poly() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Polynomial,
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
            ..Default::default()
        };
        // (1 * 11 + 1)^2 = 144.
        assert!((k_function(&x, &y, &param) - 144.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_sigmoid() {
        let x = make_nodes(&[(1, 1.0)]);
        let y = make_nodes(&[(1, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Sigmoid,
            gamma: 1.0,
            coef0: 0.0,
            ..Default::default()
        };
        let expected = 1.0_f64.tanh();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_precomputed() {
        // x[0] holds x's own serial number; x[k] holds the k-th kernel entry.
        let x = make_nodes(&[(0, 1.0), (1, 1.5), (2, 2.5)]);
        let y = make_nodes(&[(0, 2.0), (1, 1.5), (2, 2.5)]);
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };

        // y's serial number (2.0) selects x[2].
        assert!((k_function(&x, &y, &param) - 2.5).abs() < 1e-15);

        let data = vec![x.clone(), y.clone()];
        let kern = Kernel::new(&data, &param);
        assert!((kern.evaluate(0, 1) - 2.5).abs() < 1e-15);
    }

    #[test]
    fn kernel_struct_matches_standalone() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };

        let kern = Kernel::new(&data, &param);

        for i in 0..data.len() {
            for j in 0..data.len() {
                let via_struct = kern.evaluate(i, j);
                let via_func = k_function(&data[i], &data[j], &param);
                assert!(
                    (via_struct - via_func).abs() < 1e-15,
                    "mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    via_struct,
                    via_func
                );
            }
        }
    }

    #[test]
    fn kernel_struct_matches_standalone_non_rbf() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        for kernel_type in [
            KernelType::Linear,
            KernelType::Polynomial,
            KernelType::Sigmoid,
        ] {
            let param = SvmParameter {
                kernel_type,
                gamma: 0.5,
                coef0: 1.0,
                degree: 3,
                ..Default::default()
            };
            let kern = Kernel::new(&data, &param);
            for i in 0..data.len() {
                for j in 0..data.len() {
                    // Both paths run the exact same arithmetic for these kernels.
                    assert_eq!(
                        kern.evaluate(i, j),
                        k_function(&data[i], &data[j], &param)
                    );
                }
            }
        }
    }

    #[test]
    fn swap_index_keeps_rbf_cache_aligned() {
        let data = vec![
            make_nodes(&[(1, 1.0)]),
            make_nodes(&[(1, 0.5), (2, 2.0)]),
            make_nodes(&[(3, -1.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.25,
            ..Default::default()
        };
        let mut kern = Kernel::new(&data, &param);
        kern.swap_index(0, 2);

        // Slot 0 now holds data[2]; its cached norm must have moved with it.
        let expected = k_function(&data[2], &data[1], &param);
        assert!((kern.evaluate(0, 1) - expected).abs() < 1e-12);
        assert!((kern.evaluate(0, 0) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn rbf_self_kernel_is_one() {
        let x = make_nodes(&[(1, 3.0), (5, -2.0), (10, 0.7)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 1.0,
            ..Default::default()
        };
        assert!((k_function(&x, &x, &param) - 1.0).abs() < 1e-15);
    }

    #[test]
    fn precomputed_kernel_missing_sample_serial_number_returns_zero() {
        let x = make_nodes(&[(0, 1.0), (1, 2.0)]);
        let y = Vec::new();
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };
        assert_eq!(k_function(&x, &y, &param), 0.0);
    }
}
359}