1use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
7use std::collections::HashMap;
8use std::sync::Arc;
9
10pub trait Kernel: Send + Sync + std::fmt::Debug {
12 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64;
14
15 fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
17 let (n_x, _) = x.dim();
18 let (n_y, _) = y.dim();
19 let mut kernel_matrix = Array2::zeros((n_x, n_y));
20
21 for i in 0..n_x {
22 for j in 0..n_y {
23 kernel_matrix[[i, j]] = self.compute(x.row(i), y.row(j));
24 }
25 }
26
27 kernel_matrix
28 }
29
30 fn parameters(&self) -> HashMap<String, f64>;
32}
33
/// Identifies a kernel family together with its hyperparameters.
///
/// The formulas below are those of the concrete kernel types each variant is mapped to
/// by `create_kernel` and by the `Kernel` impl on `KernelType`.
#[derive(Debug, Clone, PartialEq)]
pub enum KernelType {
    /// Plain dot product `<x, y>`.
    Linear,
    /// Gaussian kernel `exp(-gamma * ||x - y||^2)`.
    Rbf { gamma: f64 },
    /// `(gamma * <x, y> + coef0)^degree`, with the exponent applied via `powf`.
    Polynomial { gamma: f64, coef0: f64, degree: f64 },
    /// `tanh(gamma * <x, y> + coef0)`.
    Sigmoid { gamma: f64, coef0: f64 },
    /// Kernel values supplied as data. `create_kernel` panics for this variant and the
    /// `Kernel` impl on `KernelType` returns `0.0` for every pair.
    Precomputed,
    /// A named user kernel. `create_kernel` panics for this variant and the `Kernel` impl
    /// on `KernelType` falls back to the plain dot product.
    Custom(String),
    /// Cosine similarity; `0.0` when either vector has zero norm.
    Cosine,
    /// `exp(-gamma * sum_i (x_i - y_i)^2 / (x_i + y_i))`, skipping terms with
    /// `x_i + y_i <= 0`.
    ChiSquared { gamma: f64 },
    /// Histogram intersection `sum_i min(x_i, y_i)`.
    Intersection,
    /// Kernel built from `sqrt(x_i * y_i)` terms (`KernelFunction` additionally
    /// normalises its inputs for this variant).
    Hellinger,
    /// `exp(-JS(x, y))`, the exponentiated Jensen-Shannon divergence.
    JensenShannon,
    /// `exp(-2 * sum_i sin^2(pi * (x_i - y_i) / period) / length_scale^2)`.
    Periodic { length_scale: f64, period: f64 },
}
62
63pub fn create_kernel(kernel_type: KernelType) -> Box<dyn Kernel> {
65 match kernel_type {
66 KernelType::Linear => Box::new(LinearKernel),
67 KernelType::Rbf { gamma } => Box::new(RbfKernel { gamma }),
68 KernelType::Polynomial {
69 gamma,
70 coef0,
71 degree,
72 } => Box::new(PolynomialKernel {
73 gamma,
74 coef0,
75 degree,
76 }),
77 KernelType::Sigmoid { gamma, coef0 } => Box::new(SigmoidKernel { gamma, coef0 }),
78 KernelType::Cosine => Box::new(CosineKernel),
79 KernelType::ChiSquared { gamma } => Box::new(ChiSquaredKernel { gamma }),
80 KernelType::Intersection => Box::new(IntersectionKernel),
81 KernelType::Hellinger => Box::new(HellingerKernel),
82 KernelType::JensenShannon => Box::new(JensenShannonKernel),
83 KernelType::Periodic {
84 length_scale,
85 period,
86 } => Box::new(PeriodicKernel {
87 length_scale,
88 period,
89 }),
90 KernelType::Precomputed => panic!("Precomputed kernels must be created with data"),
91 KernelType::Custom(name) => panic!("Custom kernel '{}' not implemented", name),
92 }
93}
94
95impl<K: Kernel> Kernel for Arc<K> {
96 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
97 (**self).compute(x, y)
98 }
99
100 fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
101 (**self).compute_matrix(x, y)
102 }
103
104 fn parameters(&self) -> HashMap<String, f64> {
105 (**self).parameters()
106 }
107}
108
109#[derive(Debug, Clone)]
111pub struct LinearKernel;
112
113impl Default for LinearKernel {
114 fn default() -> Self {
115 Self::new()
116 }
117}
118
119impl LinearKernel {
120 pub fn new() -> Self {
121 Self
122 }
123}
124
125impl Kernel for LinearKernel {
126 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
127 x.dot(&y)
128 }
129
130 fn parameters(&self) -> HashMap<String, f64> {
131 HashMap::new()
132 }
133}
134
135#[derive(Debug, Clone)]
137pub struct RbfKernel {
138 pub gamma: f64,
139}
140
141impl RbfKernel {
142 pub fn new(gamma: f64) -> Self {
143 Self { gamma }
144 }
145}
146
147impl Kernel for RbfKernel {
148 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
149 let diff = &x.to_owned() - &y.to_owned();
150 let squared_distance = diff.dot(&diff);
151 (-self.gamma * squared_distance).exp()
152 }
153
154 fn parameters(&self) -> HashMap<String, f64> {
155 let mut params = HashMap::new();
156 params.insert("gamma".to_string(), self.gamma);
157 params
158 }
159}
160
161#[derive(Debug, Clone)]
163pub struct PolynomialKernel {
164 pub gamma: f64,
165 pub coef0: f64,
166 pub degree: f64,
167}
168
169impl PolynomialKernel {
170 pub fn new(gamma: f64, coef0: f64, degree: f64) -> Self {
171 Self {
172 gamma,
173 coef0,
174 degree,
175 }
176 }
177}
178
179impl Kernel for PolynomialKernel {
180 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
181 let dot_product = x.dot(&y);
182 (self.gamma * dot_product + self.coef0).powf(self.degree)
183 }
184
185 fn parameters(&self) -> HashMap<String, f64> {
186 let mut params = HashMap::new();
187 params.insert("gamma".to_string(), self.gamma);
188 params.insert("coef0".to_string(), self.coef0);
189 params.insert("degree".to_string(), self.degree);
190 params
191 }
192}
193
194#[derive(Debug, Clone)]
196pub struct SigmoidKernel {
197 pub gamma: f64,
198 pub coef0: f64,
199}
200
201impl SigmoidKernel {
202 pub fn new(gamma: f64, coef0: f64) -> Self {
203 Self { gamma, coef0 }
204 }
205}
206
207impl Kernel for SigmoidKernel {
208 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
209 let dot_product = x.dot(&y);
210 (self.gamma * dot_product + self.coef0).tanh()
211 }
212
213 fn parameters(&self) -> HashMap<String, f64> {
214 let mut params = HashMap::new();
215 params.insert("gamma".to_string(), self.gamma);
216 params.insert("coef0".to_string(), self.coef0);
217 params
218 }
219}
220
221#[derive(Debug, Clone)]
223pub struct CosineKernel;
224
225impl Kernel for CosineKernel {
226 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
227 let dot_product = x.dot(&y);
228 let x_norm = x.dot(&x).sqrt();
229 let y_norm = y.dot(&y).sqrt();
230
231 if x_norm == 0.0 || y_norm == 0.0 {
232 0.0
233 } else {
234 dot_product / (x_norm * y_norm)
235 }
236 }
237
238 fn parameters(&self) -> HashMap<String, f64> {
239 HashMap::new()
240 }
241}
242
243#[derive(Debug, Clone)]
245pub struct ChiSquaredKernel {
246 pub gamma: f64,
247}
248
249impl ChiSquaredKernel {
250 pub fn new(gamma: f64) -> Self {
251 Self { gamma }
252 }
253}
254
255impl Kernel for ChiSquaredKernel {
256 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
257 let chi_squared_distance = x
258 .iter()
259 .zip(y.iter())
260 .map(|(a, b)| {
261 if a + b > 0.0 {
262 (a - b).powi(2) / (a + b)
263 } else {
264 0.0
265 }
266 })
267 .sum::<f64>();
268
269 (-self.gamma * chi_squared_distance).exp()
270 }
271
272 fn parameters(&self) -> HashMap<String, f64> {
273 let mut params = HashMap::new();
274 params.insert("gamma".to_string(), self.gamma);
275 params
276 }
277}
278
279#[derive(Debug, Clone)]
281pub struct IntersectionKernel;
282
283impl Kernel for IntersectionKernel {
284 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
285 x.iter().zip(y.iter()).map(|(a, b)| a.min(*b)).sum()
286 }
287
288 fn parameters(&self) -> HashMap<String, f64> {
289 HashMap::new()
290 }
291}
292
293#[derive(Debug, Clone)]
295pub struct PeriodicKernel {
296 pub length_scale: f64,
297 pub period: f64,
298}
299
300impl PeriodicKernel {
301 pub fn new(length_scale: f64, period: f64) -> Self {
302 Self {
303 length_scale,
304 period,
305 }
306 }
307}
308
309impl Kernel for PeriodicKernel {
310 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
311 let diff = &x.to_owned() - &y.to_owned();
312 let sin_term = diff.mapv(|d| (std::f64::consts::PI * d / self.period).sin());
313 let sin_squared = sin_term.dot(&sin_term);
314 (-2.0 * sin_squared / (self.length_scale * self.length_scale)).exp()
315 }
316
317 fn parameters(&self) -> HashMap<String, f64> {
318 let mut params = HashMap::new();
319 params.insert("length_scale".to_string(), self.length_scale);
320 params.insert("period".to_string(), self.period);
321 params
322 }
323}
324
325#[derive(Debug, Clone)]
327pub struct CustomKernel {
328 pub name: String,
329 pub function: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64,
330}
331
332impl CustomKernel {
333 pub fn new(name: String, function: fn(ArrayView1<f64>, ArrayView1<f64>) -> f64) -> Self {
334 Self { name, function }
335 }
336}
337
338impl Kernel for CustomKernel {
339 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
340 (self.function)(x, y)
341 }
342
343 fn parameters(&self) -> HashMap<String, f64> {
344 HashMap::new()
345 }
346}
347
348#[derive(Debug, Clone)]
350pub struct KernelFunction {
351 kernel_type: KernelType,
352}
353
354impl KernelFunction {
355 pub fn new(kernel_type: KernelType) -> Self {
356 Self { kernel_type }
357 }
358
359 pub fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
360 match &self.kernel_type {
361 KernelType::Linear => LinearKernel.compute(x, y),
362 KernelType::Rbf { gamma } => RbfKernel::new(*gamma).compute(x, y),
363 KernelType::Polynomial {
364 gamma,
365 coef0,
366 degree,
367 } => PolynomialKernel::new(*gamma, *coef0, *degree).compute(x, y),
368 KernelType::Sigmoid { gamma, coef0 } => {
369 SigmoidKernel::new(*gamma, *coef0).compute(x, y)
370 }
371 KernelType::Cosine => CosineKernel.compute(x, y),
372 KernelType::ChiSquared { gamma } => ChiSquaredKernel::new(*gamma).compute(x, y),
373 KernelType::Intersection => IntersectionKernel.compute(x, y),
374 KernelType::Periodic {
375 length_scale,
376 period,
377 } => PeriodicKernel::new(*length_scale, *period).compute(x, y),
378 KernelType::Precomputed => {
379 0.0 }
382 KernelType::Custom(_name) => {
383 x.dot(&y)
385 }
386 KernelType::Hellinger => {
387 let x_normalized = normalize_vector(&x.to_owned());
389 let y_normalized = normalize_vector(&y.to_owned());
390 x_normalized
391 .iter()
392 .zip(y_normalized.iter())
393 .map(|(a, b)| (a * b).sqrt())
394 .sum::<f64>()
395 .sqrt()
396 }
397 KernelType::JensenShannon => {
398 let x_normalized = normalize_vector(&x.to_owned());
400 let y_normalized = normalize_vector(&y.to_owned());
401
402 let mut js_divergence = 0.0;
403 for i in 0..x_normalized.len() {
404 let p = x_normalized[i];
405 let q = y_normalized[i];
406 let m = (p + q) / 2.0;
407
408 if p > 0.0 && m > 0.0 {
409 js_divergence += p * (p / m).ln();
410 }
411 if q > 0.0 && m > 0.0 {
412 js_divergence += q * (q / m).ln();
413 }
414 }
415 js_divergence /= 2.0;
416
417 (-js_divergence).exp()
418 }
419 }
420 }
421
422 pub fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
423 let (n_x, _) = x.dim();
424 let (n_y, _) = y.dim();
425 let mut kernel_matrix = Array2::zeros((n_x, n_y));
426
427 for i in 0..n_x {
428 for j in 0..n_y {
429 kernel_matrix[[i, j]] = self.compute(x.row(i), y.row(j));
430 }
431 }
432
433 kernel_matrix
434 }
435
436 pub fn kernel_type(&self) -> &KernelType {
437 &self.kernel_type
438 }
439}
440
441fn normalize_vector(vec: &Array1<f64>) -> Array1<f64> {
443 let sum: f64 = vec.iter().sum();
444 if sum == 0.0 {
445 vec.clone()
446 } else {
447 vec / sum
448 }
449}
450
451#[derive(Debug, Clone)]
453pub struct Graph {
454 pub adjacency_matrix: Array2<f64>,
455 pub node_labels: Option<Array1<usize>>,
456 pub edge_labels: Option<Array2<usize>>,
457}
458
459impl Graph {
460 pub fn new(adjacency_matrix: Array2<f64>) -> Self {
461 Self {
462 adjacency_matrix,
463 node_labels: None,
464 edge_labels: None,
465 }
466 }
467
468 pub fn with_node_labels(mut self, labels: Array1<usize>) -> Self {
469 self.node_labels = Some(labels);
470 self
471 }
472
473 pub fn with_edge_labels(mut self, labels: Array2<usize>) -> Self {
474 self.edge_labels = Some(labels);
475 self
476 }
477}
478
479#[derive(Debug, Clone)]
481pub struct RandomWalkKernel {
482 pub lambda: f64, pub max_steps: usize,
484}
485
486impl RandomWalkKernel {
487 pub fn new(lambda: f64, max_steps: usize) -> Self {
488 Self { lambda, max_steps }
489 }
490
491 pub fn compute_graph_kernel(&self, g1: &Graph, g2: &Graph) -> f64 {
492 let n1 = g1.adjacency_matrix.nrows();
497 let n2 = g2.adjacency_matrix.nrows();
498
499 let density1 = g1.adjacency_matrix.sum() / (n1 * n1) as f64;
501 let density2 = g2.adjacency_matrix.sum() / (n2 * n2) as f64;
502
503 (-(density1 - density2).abs()).exp()
504 }
505}
506
507#[derive(Debug)]
509pub struct HellingerKernel;
510
511impl Kernel for HellingerKernel {
512 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
513 x.iter()
515 .zip(y.iter())
516 .map(|(xi, yi)| (xi * yi).sqrt())
517 .sum()
518 }
519
520 fn parameters(&self) -> HashMap<String, f64> {
521 HashMap::new()
522 }
523}
524
525#[derive(Debug)]
527pub struct JensenShannonKernel;
528
529impl JensenShannonKernel {
530 fn jensen_shannon_divergence(&self, p: ArrayView1<f64>, q: ArrayView1<f64>) -> f64 {
531 let m: Vec<f64> = p
533 .iter()
534 .zip(q.iter())
535 .map(|(pi, qi)| 0.5 * (pi + qi))
536 .collect();
537 let m = Array1::from_vec(m);
538
539 let kl_pm = self.kl_divergence(p, m.view());
540 let kl_qm = self.kl_divergence(q, m.view());
541
542 0.5 * kl_pm + 0.5 * kl_qm
543 }
544
545 fn kl_divergence(&self, p: ArrayView1<f64>, q: ArrayView1<f64>) -> f64 {
546 p.iter()
548 .zip(q.iter())
549 .map(|(pi, qi)| {
550 if *pi > 0.0 && *qi > 0.0 {
551 pi * (pi / qi).ln()
552 } else {
553 0.0
554 }
555 })
556 .sum()
557 }
558}
559
560impl Kernel for JensenShannonKernel {
561 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
562 let js_div = self.jensen_shannon_divergence(x, y);
564 (-js_div).exp()
565 }
566
567 fn parameters(&self) -> HashMap<String, f64> {
568 HashMap::new()
569 }
570}
571
572impl Kernel for KernelType {
574 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
575 match self {
576 KernelType::Linear => LinearKernel.compute(x, y),
577 KernelType::Rbf { gamma } => RbfKernel::new(*gamma).compute(x, y),
578 KernelType::Polynomial {
579 gamma,
580 coef0,
581 degree,
582 } => PolynomialKernel::new(*gamma, *coef0, *degree).compute(x, y),
583 KernelType::Sigmoid { gamma, coef0 } => {
584 SigmoidKernel::new(*gamma, *coef0).compute(x, y)
585 }
586 KernelType::Cosine => CosineKernel.compute(x, y),
587 KernelType::ChiSquared { gamma } => ChiSquaredKernel::new(*gamma).compute(x, y),
588 KernelType::Intersection => IntersectionKernel.compute(x, y),
589 KernelType::Hellinger => HellingerKernel.compute(x, y),
590 KernelType::JensenShannon => JensenShannonKernel.compute(x, y),
591 KernelType::Periodic {
592 length_scale,
593 period,
594 } => PeriodicKernel::new(*length_scale, *period).compute(x, y),
595 KernelType::Precomputed => {
596 0.0 }
599 KernelType::Custom(_name) => {
600 x.dot(&y)
602 }
603 }
604 }
605
606 fn parameters(&self) -> HashMap<String, f64> {
607 match self {
608 KernelType::Linear => HashMap::new(),
609 KernelType::Rbf { gamma } => {
610 let mut params = HashMap::new();
611 params.insert("gamma".to_string(), *gamma);
612 params
613 }
614 KernelType::Polynomial {
615 gamma,
616 coef0,
617 degree,
618 } => {
619 let mut params = HashMap::new();
620 params.insert("gamma".to_string(), *gamma);
621 params.insert("coef0".to_string(), *coef0);
622 params.insert("degree".to_string(), *degree);
623 params
624 }
625 KernelType::Sigmoid { gamma, coef0 } => {
626 let mut params = HashMap::new();
627 params.insert("gamma".to_string(), *gamma);
628 params.insert("coef0".to_string(), *coef0);
629 params
630 }
631 KernelType::Cosine => HashMap::new(),
632 KernelType::ChiSquared { gamma } => {
633 let mut params = HashMap::new();
634 params.insert("gamma".to_string(), *gamma);
635 params
636 }
637 KernelType::Intersection => HashMap::new(),
638 KernelType::Hellinger => HashMap::new(),
639 KernelType::JensenShannon => HashMap::new(),
640 KernelType::Periodic {
641 length_scale,
642 period,
643 } => {
644 let mut params = HashMap::new();
645 params.insert("length_scale".to_string(), *length_scale);
646 params.insert("period".to_string(), *period);
647 params
648 }
649 KernelType::Precomputed => HashMap::new(),
650 KernelType::Custom(_name) => HashMap::new(),
651 }
652 }
653}
654
655impl Kernel for Box<dyn Kernel> {
657 fn compute(&self, x: ArrayView1<f64>, y: ArrayView1<f64>) -> f64 {
658 self.as_ref().compute(x, y)
659 }
660
661 fn compute_matrix(&self, x: &Array2<f64>, y: &Array2<f64>) -> Array2<f64> {
662 self.as_ref().compute_matrix(x, y)
663 }
664
665 fn parameters(&self) -> HashMap<String, f64> {
666 self.as_ref().parameters()
667 }
668}
669
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_linear_kernel() {
        let kernel = LinearKernel;
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let y = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 32.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rbf_kernel() {
        let kernel = RbfKernel::new(1.0);
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![1.0, 2.0]);

        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_rbf_kernel_distinct_points() {
        let kernel = RbfKernel::new(0.5);
        let x = Array1::from_vec(vec![0.0, 0.0]);
        let y = Array1::from_vec(vec![1.0, 1.0]);

        // ||x - y||^2 = 2, so the value is exp(-0.5 * 2).
        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, (-1.0_f64).exp(), epsilon = 1e-12);
    }

    #[test]
    fn test_polynomial_kernel() {
        let kernel = PolynomialKernel::new(1.0, 1.0, 2.0);
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![3.0, 4.0]);

        let result = kernel.compute(x.view(), y.view());
        let expected = (1.0_f64 * (1.0 * 3.0 + 2.0 * 4.0) + 1.0).powf(2.0);
        assert_abs_diff_eq!(result, expected, epsilon = 1e-10);
    }

    #[test]
    fn test_cosine_kernel() {
        let kernel = CosineKernel;
        let x = Array1::from_vec(vec![1.0, 0.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);

        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_chi_squared_kernel() {
        let kernel = ChiSquaredKernel::new(1.0);
        let x = Array1::from_vec(vec![1.0, 0.0]);
        let y = Array1::from_vec(vec![0.0, 1.0]);

        // Each coordinate contributes (1 - 0)^2 / (1 + 0) = 1.
        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, (-2.0_f64).exp(), epsilon = 1e-12);
    }

    #[test]
    fn test_intersection_kernel() {
        let x = Array1::from_vec(vec![1.0, 5.0, 2.0]);
        let y = Array1::from_vec(vec![3.0, 1.0, 2.0]);

        let result = IntersectionKernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 4.0, epsilon = 1e-12);
    }

    #[test]
    fn test_hellinger_and_jensen_shannon_on_identical_distributions() {
        let p = Array1::from_vec(vec![0.25, 0.75]);
        assert_abs_diff_eq!(HellingerKernel.compute(p.view(), p.view()), 1.0, epsilon = 1e-12);

        let q = Array1::from_vec(vec![0.5, 0.5]);
        assert_abs_diff_eq!(
            JensenShannonKernel.compute(q.view(), q.view()),
            1.0,
            epsilon = 1e-12
        );

        // KernelFunction normalises its inputs first: [1, 3] becomes [0.25, 0.75].
        let raw = Array1::from_vec(vec![1.0, 3.0]);
        let kernel_fn = KernelFunction::new(KernelType::Hellinger);
        assert_abs_diff_eq!(kernel_fn.compute(raw.view(), raw.view()), 1.0, epsilon = 1e-12);
    }

    #[test]
    fn test_periodic_kernel_is_one_after_full_period() {
        let kernel = PeriodicKernel::new(1.0, 2.0);
        let x = Array1::from_vec(vec![0.0]);
        let y = Array1::from_vec(vec![2.0]);

        let result = kernel.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kernel_function() {
        let kernel_fn = KernelFunction::new(KernelType::Rbf { gamma: 0.5 });
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![1.0, 2.0]);

        let result = kernel_fn.compute(x.view(), y.view());
        assert_abs_diff_eq!(result, 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kernel_matrix() {
        let kernel_fn = KernelFunction::new(KernelType::Linear);
        let x = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();

        let kernel_matrix = kernel_fn.compute_matrix(&x, &y);

        assert_eq!(kernel_matrix.dim(), (2, 2));
        assert_abs_diff_eq!(kernel_matrix[[0, 0]], 5.0, epsilon = 1e-10);
        assert_abs_diff_eq!(kernel_matrix[[0, 1]], 11.0, epsilon = 1e-10);
        assert_abs_diff_eq!(kernel_matrix[[1, 0]], 11.0, epsilon = 1e-10);
        assert_abs_diff_eq!(kernel_matrix[[1, 1]], 25.0, epsilon = 1e-10);
    }

    #[test]
    fn test_trait_default_compute_matrix_non_square() {
        let kernel = LinearKernel::new();
        let x = Array2::from_shape_vec((2, 1), vec![1.0, 2.0]).unwrap();
        let y = Array2::from_shape_vec((3, 1), vec![1.0, 10.0, 100.0]).unwrap();

        let kernel_matrix = kernel.compute_matrix(&x, &y);

        assert_eq!(kernel_matrix.dim(), (2, 3));
        assert_abs_diff_eq!(kernel_matrix[[0, 1]], 10.0, epsilon = 1e-10);
        assert_abs_diff_eq!(kernel_matrix[[1, 2]], 200.0, epsilon = 1e-10);
    }

    #[test]
    fn test_shared_and_boxed_kernels_forward() {
        let x = Array1::from_vec(vec![1.0, 2.0]);
        let y = Array1::from_vec(vec![3.0, 4.0]);

        let shared = Arc::new(LinearKernel::new());
        assert_abs_diff_eq!(shared.compute(x.view(), y.view()), 11.0, epsilon = 1e-10);

        let boxed = create_kernel(KernelType::Linear);
        assert_abs_diff_eq!(boxed.compute(x.view(), y.view()), 11.0, epsilon = 1e-10);
    }

    #[test]
    fn test_kernel_type_parameters() {
        let params = KernelType::Polynomial {
            gamma: 0.5,
            coef0: 1.0,
            degree: 3.0,
        }
        .parameters();
        assert_eq!(params.len(), 3);
        assert_abs_diff_eq!(params["degree"], 3.0, epsilon = 1e-12);

        assert!(KernelType::Linear.parameters().is_empty());
        assert!(KernelType::Custom("mine".to_string()).parameters().is_empty());
    }

    #[test]
    fn test_random_walk_identical_graphs() {
        let adjacency = Array2::from_shape_vec((2, 2), vec![0.0, 1.0, 1.0, 0.0]).unwrap();
        let g1 = Graph::new(adjacency.clone());
        let g2 = Graph::new(adjacency);
        let kernel = RandomWalkKernel::new(0.1, 5);

        assert_abs_diff_eq!(kernel.compute_graph_kernel(&g1, &g2), 1.0, epsilon = 1e-12);
    }
}