1use crate::types::{KernelType, SvmNode, SvmParameter};
8
/// Raises `base` to the integer power `times` by repeated squaring.
///
/// The exponent is consumed one binary digit at a time: whenever the lowest
/// remaining bit is set, the current square is multiplied into the result.
///
/// When `times` is zero or negative the loop never runs and `1.0` is
/// returned; negative exponents are *not* turned into reciprocals.
#[inline]
pub fn powi(base: f64, times: i32) -> f64 {
    let mut result = 1.0;
    let mut square = base;
    let mut remaining = times;
    while remaining > 0 {
        // `remaining` is strictly positive here, so the low bit is its parity.
        if (remaining & 1) == 1 {
            result *= square;
        }
        square *= square;
        remaining >>= 1;
    }
    result
}
29
30#[inline]
36pub fn dot(x: &[SvmNode], y: &[SvmNode]) -> f64 {
37 let mut sum = 0.0;
38 let mut ix = 0;
39 let mut iy = 0;
40 while ix < x.len() && iy < y.len() {
41 if x[ix].index == y[iy].index {
42 sum += x[ix].value * y[iy].value;
43 ix += 1;
44 iy += 1;
45 } else if x[ix].index > y[iy].index {
46 iy += 1;
47 } else {
48 ix += 1;
49 }
50 }
51 sum
52}
53
54#[inline]
58fn sparse_sq_dist(x: &[SvmNode], y: &[SvmNode]) -> f64 {
59 let mut sum = 0.0;
60 let mut ix = 0;
61 let mut iy = 0;
62 while ix < x.len() && iy < y.len() {
63 if x[ix].index == y[iy].index {
64 let d = x[ix].value - y[iy].value;
65 sum += d * d;
66 ix += 1;
67 iy += 1;
68 } else if x[ix].index > y[iy].index {
69 sum += y[iy].value * y[iy].value;
70 iy += 1;
71 } else {
72 sum += x[ix].value * x[ix].value;
73 ix += 1;
74 }
75 }
76 while ix < x.len() {
78 sum += x[ix].value * x[ix].value;
79 ix += 1;
80 }
81 while iy < y.len() {
82 sum += y[iy].value * y[iy].value;
83 iy += 1;
84 }
85 sum
86}
87
88pub fn k_function(x: &[SvmNode], y: &[SvmNode], param: &SvmParameter) -> f64 {
95 match param.kernel_type {
96 KernelType::Linear => dot(x, y),
97 KernelType::Polynomial => powi(param.gamma * dot(x, y) + param.coef0, param.degree),
98 KernelType::Rbf => (-param.gamma * sparse_sq_dist(x, y)).exp(),
99 KernelType::Sigmoid => (param.gamma * dot(x, y) + param.coef0).tanh(),
100 KernelType::Precomputed => {
101 let col = y[0].value as usize;
104 x.get(col).map_or(0.0, |n| n.value)
105 }
106 }
107}
108
/// Kernel evaluator over a borrowed collection of sparse vectors.
///
/// Built by [`Kernel::new`], which keeps one borrowed slice per input vector
/// and copies the kernel settings out of the supplied `SvmParameter`.
pub struct Kernel<'a> {
    /// One borrowed slice per input vector; entries are reordered by `swap_index`.
    x: Vec<&'a [SvmNode]>,
    /// `dot(x[i], x[i])` for every vector. `Some` only when the kernel type
    /// is `Rbf`, `None` otherwise.
    x_square: Option<Vec<f64>>,
    /// Which kernel formula `evaluate` applies.
    kernel_type: KernelType,
    /// Exponent passed to `powi` by the polynomial kernel.
    degree: i32,
    /// Scale factor used by the polynomial, RBF and sigmoid kernels.
    gamma: f64,
    /// Additive offset used by the polynomial and sigmoid kernels.
    coef0: f64,
}
128
129impl<'a> Kernel<'a> {
130 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
132 let x_refs: Vec<&'a [SvmNode]> = x.iter().map(|xi| xi.as_slice()).collect();
133 let x_square = if param.kernel_type == KernelType::Rbf {
134 Some(x_refs.iter().map(|xi| dot(xi, xi)).collect())
135 } else {
136 None
137 };
138
139 Self {
140 x: x_refs,
141 x_square,
142 kernel_type: param.kernel_type,
143 degree: param.degree,
144 gamma: param.gamma,
145 coef0: param.coef0,
146 }
147 }
148
149 #[inline]
151 pub fn evaluate(&self, i: usize, j: usize) -> f64 {
152 match self.kernel_type {
153 KernelType::Linear => dot(self.x[i], self.x[j]),
154 KernelType::Polynomial => powi(
155 self.gamma * dot(self.x[i], self.x[j]) + self.coef0,
156 self.degree,
157 ),
158 KernelType::Rbf => {
159 let sq = self.x_square.as_ref().unwrap();
161 let val = sq[i] + sq[j] - 2.0 * dot(self.x[i], self.x[j]);
162 (-self.gamma * val).exp()
163 }
164 KernelType::Sigmoid => (self.gamma * dot(self.x[i], self.x[j]) + self.coef0).tanh(),
165 KernelType::Precomputed => {
166 let col = self.x[j][0].value as usize;
167 self.x[i].get(col).map_or(0.0, |n| n.value)
168 }
169 }
170 }
171
172 pub fn swap_index(&mut self, i: usize, j: usize) {
176 self.x.swap(i, j);
177 if let Some(ref mut sq) = self.x_square {
178 sq.swap(i, j);
179 }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SvmParameter;

    /// Builds a sparse vector from `(index, value)` pairs.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    }

    #[test]
    fn powi_basic() {
        assert_eq!(powi(2.0, 10), 1024.0);
        assert_eq!(powi(3.0, 0), 1.0);
        assert_eq!(powi(5.0, 1), 5.0);
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
        // Negative exponents skip the loop entirely and yield 1.0.
        assert_eq!(powi(2.0, -1), 1.0);
    }

    #[test]
    fn dot_product() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0), (5, 3.0)]);
        let y = make_nodes(&[(1, 4.0), (2, 5.0), (5, 6.0)]);
        // Shared indices are 1 and 5: 1*4 + 3*6 = 22.
        assert!((dot(&x, &y) - 22.0).abs() < 1e-15);
    }

    #[test]
    fn dot_disjoint() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0)]);
        let y = make_nodes(&[(2, 5.0), (4, 6.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn dot_empty() {
        let x = make_nodes(&[]);
        let y = make_nodes(&[(1, 1.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn kernel_linear() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        // 1*3 + 2*4 = 11.
        assert!((k_function(&x, &y, &param) - 11.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_rbf() {
        let x = make_nodes(&[(1, 1.0), (2, 0.0)]);
        let y = make_nodes(&[(1, 0.0), (2, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        // Squared distance is 2, so the kernel is exp(-0.5 * 2) = exp(-1).
        let expected = (-1.0_f64).exp();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_poly() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Polynomial,
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
            ..Default::default()
        };
        // (1 * 11 + 1)^2 = 144.
        assert!((k_function(&x, &y, &param) - 144.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_sigmoid() {
        let x = make_nodes(&[(1, 1.0)]);
        let y = make_nodes(&[(1, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Sigmoid,
            gamma: 1.0,
            coef0: 0.0,
            ..Default::default()
        };
        // tanh(1 * 1 + 0).
        let expected = 1.0_f64.tanh();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_precomputed() {
        let x = make_nodes(&[(0, 1.0), (1, 1.5), (2, 2.5)]);
        // y[0].value == 2.0 selects position 2 of `x`.
        let y = make_nodes(&[(0, 2.0), (1, 1.5), (2, 2.5)]);
        let param = SvmParameter {
            kernel_type: KernelType::Precomputed,
            ..Default::default()
        };

        assert!((k_function(&x, &y, &param) - 2.5).abs() < 1e-15);

        // The struct-based evaluator must agree with the free function.
        let data = vec![x.clone(), y.clone()];
        let kern = Kernel::new(&data, &param);
        assert!((kern.evaluate(0, 1) - 2.5).abs() < 1e-15);
    }

    #[test]
    fn kernel_struct_matches_standalone() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };

        let kern = Kernel::new(&data, &param);

        // Compare every ordered pair, including the diagonal.
        for i in 0..data.len() {
            for j in 0..data.len() {
                let via_struct = kern.evaluate(i, j);
                let via_func = k_function(&data[i], &data[j], &param);
                assert!(
                    (via_struct - via_func).abs() < 1e-15,
                    "mismatch at ({},{}): {} vs {}",
                    i,
                    j,
                    via_struct,
                    via_func
                );
            }
        }
    }

    #[test]
    fn rbf_self_kernel_is_one() {
        let x = make_nodes(&[(1, 3.0), (5, -2.0), (10, 0.7)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 1.0,
            ..Default::default()
        };
        // Distance from a vector to itself is zero, so exp(-0) = 1.
        assert!((k_function(&x, &x, &param) - 1.0).abs() < 1e-15);
    }
}