1use crate::types::{KernelType, SvmNode, SvmParameter};
8
/// Raises `base` to the integer power `times` by repeated squaring.
///
/// The loop only runs while the remaining exponent is positive, so any
/// `times <= 0` (including negative exponents) yields `1.0`; negative
/// exponents are *not* turned into reciprocals.
#[inline]
pub fn powi(base: f64, times: i32) -> f64 {
    let mut square = base;
    let mut acc = 1.0;
    let mut remaining = times;
    while remaining > 0 {
        // Lowest bit set: this power of two contributes to the result.
        if remaining & 1 == 1 {
            acc *= square;
        }
        square *= square;
        remaining >>= 1;
    }
    acc
}
29
30#[inline]
36pub fn dot(x: &[SvmNode], y: &[SvmNode]) -> f64 {
37 let mut sum = 0.0;
38 let mut ix = 0;
39 let mut iy = 0;
40 while ix < x.len() && iy < y.len() {
41 if x[ix].index == y[iy].index {
42 sum += x[ix].value * y[iy].value;
43 ix += 1;
44 iy += 1;
45 } else if x[ix].index > y[iy].index {
46 iy += 1;
47 } else {
48 ix += 1;
49 }
50 }
51 sum
52}
53
54#[inline]
58fn sparse_sq_dist(x: &[SvmNode], y: &[SvmNode]) -> f64 {
59 let mut sum = 0.0;
60 let mut ix = 0;
61 let mut iy = 0;
62 while ix < x.len() && iy < y.len() {
63 if x[ix].index == y[iy].index {
64 let d = x[ix].value - y[iy].value;
65 sum += d * d;
66 ix += 1;
67 iy += 1;
68 } else if x[ix].index > y[iy].index {
69 sum += y[iy].value * y[iy].value;
70 iy += 1;
71 } else {
72 sum += x[ix].value * x[ix].value;
73 ix += 1;
74 }
75 }
76 while ix < x.len() {
78 sum += x[ix].value * x[ix].value;
79 ix += 1;
80 }
81 while iy < y.len() {
82 sum += y[iy].value * y[iy].value;
83 iy += 1;
84 }
85 sum
86}
87
88pub fn k_function(x: &[SvmNode], y: &[SvmNode], param: &SvmParameter) -> f64 {
95 match param.kernel_type {
96 KernelType::Linear => dot(x, y),
97 KernelType::Polynomial => {
98 powi(param.gamma * dot(x, y) + param.coef0, param.degree)
99 }
100 KernelType::Rbf => {
101 (-param.gamma * sparse_sq_dist(x, y)).exp()
102 }
103 KernelType::Sigmoid => {
104 (param.gamma * dot(x, y) + param.coef0).tanh()
105 }
106 KernelType::Precomputed => {
107 let col = y[0].value as usize;
110 x.get(col).map_or(0.0, |n| n.value)
111 }
112 }
113}
114
/// Kernel evaluator over a borrowed set of sample vectors.
///
/// The kernel settings are copied out of an `SvmParameter` at construction
/// time, after which `evaluate(i, j)` computes the kernel value between
/// samples `i` and `j` of `x`.
pub struct Kernel<'a> {
    /// Borrowed sample vectors; `evaluate` indexes into this slice.
    x: &'a [Vec<SvmNode>],
    /// `dot(x[k], x[k])` for every sample `k`, filled in by `new` only when
    /// the kernel type is RBF and `None` otherwise. `swap_x_square` permutes
    /// this cache only; `x` itself is never reordered here.
    x_square: Option<Vec<f64>>,
    /// Selects which kernel formula `evaluate` applies.
    kernel_type: KernelType,
    /// Exponent used by the polynomial kernel.
    degree: i32,
    /// Scale factor used by the polynomial, RBF and sigmoid kernels.
    gamma: f64,
    /// Additive offset used by the polynomial and sigmoid kernels.
    coef0: f64,
}
131
132impl<'a> Kernel<'a> {
133 pub fn new(x: &'a [Vec<SvmNode>], param: &SvmParameter) -> Self {
135 let x_square = if param.kernel_type == KernelType::Rbf {
136 Some(x.iter().map(|xi| dot(xi, xi)).collect())
137 } else {
138 None
139 };
140
141 Self {
142 x,
143 x_square,
144 kernel_type: param.kernel_type,
145 degree: param.degree,
146 gamma: param.gamma,
147 coef0: param.coef0,
148 }
149 }
150
151 #[inline]
153 pub fn evaluate(&self, i: usize, j: usize) -> f64 {
154 match self.kernel_type {
155 KernelType::Linear => dot(&self.x[i], &self.x[j]),
156 KernelType::Polynomial => {
157 powi(self.gamma * dot(&self.x[i], &self.x[j]) + self.coef0, self.degree)
158 }
159 KernelType::Rbf => {
160 let sq = self.x_square.as_ref().unwrap();
162 let val = sq[i] + sq[j] - 2.0 * dot(&self.x[i], &self.x[j]);
163 (-self.gamma * val).exp()
164 }
165 KernelType::Sigmoid => {
166 (self.gamma * dot(&self.x[i], &self.x[j]) + self.coef0).tanh()
167 }
168 KernelType::Precomputed => {
169 let col = self.x[j][0].value as usize;
170 self.x[i].get(col).map_or(0.0, |n| n.value)
171 }
172 }
173 }
174
175 pub fn swap_x_square(&mut self, i: usize, j: usize) {
182 if let Some(ref mut sq) = self.x_square {
183 sq.swap(i, j);
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::SvmParameter;

    /// Builds a sparse vector from `(index, value)` pairs, in the given order.
    fn make_nodes(pairs: &[(i32, f64)]) -> Vec<SvmNode> {
        pairs
            .iter()
            .map(|&(index, value)| SvmNode { index, value })
            .collect()
    }

    #[test]
    fn powi_basic() {
        assert_eq!(powi(2.0, 10), 1024.0);
        assert_eq!(powi(3.0, 0), 1.0);
        assert_eq!(powi(5.0, 1), 5.0);
        assert!((powi(2.0, 3) - 8.0).abs() < 1e-15);
        // A negative exponent never enters the squaring loop, so the result
        // stays at the initial 1.0.
        assert_eq!(powi(2.0, -1), 1.0);
    }

    #[test]
    fn dot_product() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0), (5, 3.0)]);
        let y = make_nodes(&[(1, 4.0), (2, 5.0), (5, 6.0)]);
        // Shared indices are 1 and 5: 1*4 + 3*6 = 22.
        assert!((dot(&x, &y) - 22.0).abs() < 1e-15);
    }

    #[test]
    fn dot_disjoint() {
        let x = make_nodes(&[(1, 1.0), (3, 2.0)]);
        let y = make_nodes(&[(2, 5.0), (4, 6.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn dot_empty() {
        let x = make_nodes(&[]);
        let y = make_nodes(&[(1, 1.0)]);
        assert_eq!(dot(&x, &y), 0.0);
    }

    #[test]
    fn kernel_linear() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Linear,
            ..Default::default()
        };
        // 1*3 + 2*4 = 11.
        assert!((k_function(&x, &y, &param) - 11.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_rbf() {
        let x = make_nodes(&[(1, 1.0), (2, 0.0)]);
        let y = make_nodes(&[(1, 0.0), (2, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };
        // Squared distance is 2, so the kernel is exp(-0.5 * 2) = exp(-1).
        let expected = (-1.0_f64).exp();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_poly() {
        let x = make_nodes(&[(1, 1.0), (2, 2.0)]);
        let y = make_nodes(&[(1, 3.0), (2, 4.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Polynomial,
            gamma: 1.0,
            coef0: 1.0,
            degree: 2,
            ..Default::default()
        };
        // (1 * 11 + 1)^2 = 144.
        assert!((k_function(&x, &y, &param) - 144.0).abs() < 1e-15);
    }

    #[test]
    fn kernel_sigmoid() {
        let x = make_nodes(&[(1, 1.0)]);
        let y = make_nodes(&[(1, 1.0)]);
        let param = SvmParameter {
            kernel_type: KernelType::Sigmoid,
            gamma: 1.0,
            coef0: 0.0,
            ..Default::default()
        };
        // tanh(1 * 1 + 0).
        let expected = 1.0_f64.tanh();
        assert!((k_function(&x, &y, &param) - expected).abs() < 1e-15);
    }

    #[test]
    fn kernel_struct_matches_standalone() {
        let data = vec![
            make_nodes(&[(1, 0.5), (3, -1.0)]),
            make_nodes(&[(1, -0.25), (2, 0.75)]),
            make_nodes(&[(2, 1.0), (3, 0.5)]),
        ];
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 0.5,
            ..Default::default()
        };

        let kern = Kernel::new(&data, &param);

        // The cached-norm formula in `Kernel::evaluate` must agree with the
        // direct squared-distance formula in `k_function` for every pair.
        for i in 0..data.len() {
            for j in 0..data.len() {
                let via_struct = kern.evaluate(i, j);
                let via_func = k_function(&data[i], &data[j], &param);
                assert!(
                    (via_struct - via_func).abs() < 1e-15,
                    "mismatch at ({},{}): {} vs {}",
                    i, j, via_struct, via_func
                );
            }
        }
    }

    #[test]
    fn rbf_self_kernel_is_one() {
        let x = make_nodes(&[(1, 3.0), (5, -2.0), (10, 0.7)]);
        let param = SvmParameter {
            kernel_type: KernelType::Rbf,
            gamma: 1.0,
            ..Default::default()
        };
        // Zero distance to itself gives exp(0) = 1.
        assert!((k_function(&x, &x, &param) - 1.0).abs() < 1e-15);
    }
}