// sklears_simd/kernels.rs

//! SIMD-optimized kernel functions for machine learning
2
3use crate::vector::dot_product;
4
5#[cfg(feature = "no-std")]
6use alloc::{vec, vec::Vec};
7#[cfg(not(feature = "no-std"))]
8use std::{vec, vec::Vec};
9
10/// SIMD-optimized RBF (Gaussian) kernel function
11pub fn rbf_kernel(x: &[f32], y: &[f32], gamma: f32) -> f32 {
12    let distance_squared = euclidean_distance_squared(x, y);
13    (-gamma * distance_squared).exp()
14}
15
16/// SIMD-optimized polynomial kernel function
17pub fn polynomial_kernel(x: &[f32], y: &[f32], degree: f32, coef0: f32, gamma: f32) -> f32 {
18    let dot_prod = dot_product(x, y);
19    (gamma * dot_prod + coef0).powf(degree)
20}
21
22/// SIMD-optimized linear kernel function
23pub fn linear_kernel(x: &[f32], y: &[f32]) -> f32 {
24    dot_product(x, y)
25}
26
27/// SIMD-optimized sigmoid kernel function
28pub fn sigmoid_kernel(x: &[f32], y: &[f32], gamma: f32, coef0: f32) -> f32 {
29    let dot_prod = dot_product(x, y);
30    (gamma * dot_prod + coef0).tanh()
31}
32
33/// Helper function to compute squared Euclidean distance
34fn euclidean_distance_squared(a: &[f32], b: &[f32]) -> f32 {
35    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
36
37    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
38    {
39        if crate::simd_feature_detected!("avx2") {
40            return unsafe { euclidean_distance_squared_avx2(a, b) };
41        } else if crate::simd_feature_detected!("sse2") {
42            return unsafe { euclidean_distance_squared_sse2(a, b) };
43        }
44    }
45
46    euclidean_distance_squared_scalar(a, b)
47}
48
/// Portable scalar squared Euclidean distance.
///
/// Pairs elements with `zip`, so if the slices differ in length only the
/// common prefix contributes.
fn euclidean_distance_squared_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_gap = |(&p, &q): (&f32, &f32)| {
        let gap = p - q;
        gap * gap
    };
    a.iter().zip(b).map(squared_gap).sum()
}
58
/// SSE2 implementation of the squared Euclidean distance.
///
/// Accumulates four lanes per iteration using unaligned loads, reduces the
/// accumulator horizontally, then finishes the last `len % 4` elements with
/// scalar code. Only the first `min(a.len(), b.len())` elements are used, so
/// mismatched slices can never cause an out-of-bounds read.
///
/// # Safety
///
/// The caller must ensure the running CPU supports SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn euclidean_distance_squared_sse2(a: &[f32], b: &[f32]) -> f32 {
    // The intrinsics live in different modules on 32- and 64-bit x86; importing
    // only `x86_64` breaks the build for `target_arch = "x86"`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Clamp both slices to their common length so every raw-pointer load below
    // stays in bounds, even if a caller skips the length check.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= len {
        // SAFETY: `i + 4 <= len` and both slices hold exactly `len` elements,
        // so each 4-lane unaligned load reads in-bounds, initialized memory.
        let a_vec = _mm_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(a_vec, b_vec);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        i += 4;
    }

    // Horizontal reduction of the four partial sums.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail for the remaining `len % 4` elements.
    for (&x, &y) in a[i..].iter().zip(&b[i..]) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}
88
/// AVX2 implementation of the squared Euclidean distance.
///
/// Accumulates eight lanes per iteration using unaligned loads, reduces the
/// accumulator horizontally, then finishes the last `len % 8` elements with
/// scalar code. Only the first `min(a.len(), b.len())` elements are used, so
/// mismatched slices can never cause an out-of-bounds read.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_squared_avx2(a: &[f32], b: &[f32]) -> f32 {
    // The intrinsics live in different modules on 32- and 64-bit x86; importing
    // only `x86_64` breaks the build for `target_arch = "x86"`.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // Clamp both slices to their common length so every raw-pointer load below
    // stays in bounds, even if a caller skips the length check.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= len {
        // SAFETY: `i + 8 <= len` and both slices hold exactly `len` elements,
        // so each 8-lane unaligned load reads in-bounds, initialized memory.
        let a_vec = _mm256_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm256_loadu_ps(b.as_ptr().add(i));
        let diff = _mm256_sub_ps(a_vec, b_vec);
        sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff));
        i += 8;
    }

    // Horizontal reduction of the eight partial sums.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut total = lanes.iter().sum::<f32>();

    // Scalar tail for the remaining `len % 8` elements.
    for (&x, &y) in a[i..].iter().zip(&b[i..]) {
        let diff = x - y;
        total += diff * diff;
    }

    total
}
118
119/// Kernel matrix computation for batch processing
120pub fn kernel_matrix(
121    x_data: &[Vec<f32>],
122    y_data: &[Vec<f32>],
123    kernel_type: KernelType,
124) -> Vec<Vec<f32>> {
125    let mut matrix = vec![vec![0.0; y_data.len()]; x_data.len()];
126
127    for i in 0..x_data.len() {
128        for j in 0..y_data.len() {
129            matrix[i][j] = match kernel_type {
130                KernelType::Linear => linear_kernel(&x_data[i], &y_data[j]),
131                KernelType::Rbf { gamma } => rbf_kernel(&x_data[i], &y_data[j], gamma),
132                KernelType::Polynomial {
133                    degree,
134                    gamma,
135                    coef0,
136                } => polynomial_kernel(&x_data[i], &y_data[j], degree, coef0, gamma),
137                KernelType::Sigmoid { gamma, coef0 } => {
138                    sigmoid_kernel(&x_data[i], &y_data[j], gamma, coef0)
139                }
140            };
141        }
142    }
143
144    matrix
145}
146
/// Kernel function selector, carrying each kernel's hyper-parameters, used by
/// the batch helpers `kernel_matrix` and `kernel_vector`.
///
/// `PartialEq` is derived so configurations can be compared directly; as with
/// `f32` itself, a variant holding a NaN parameter never equals itself.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum KernelType {
    /// Inner product `<x, y>`.
    Linear,
    /// Gaussian kernel `exp(-gamma * ||x - y||^2)`.
    Rbf { gamma: f32 },
    /// Polynomial kernel `(gamma * <x, y> + coef0)^degree`.
    Polynomial { degree: f32, gamma: f32, coef0: f32 },
    /// Sigmoid kernel `tanh(gamma * <x, y> + coef0)`.
    Sigmoid { gamma: f32, coef0: f32 },
}
155
156impl Default for KernelType {
157    fn default() -> Self {
158        KernelType::Rbf { gamma: 1.0 }
159    }
160}
161
162/// Compute kernel values for a single point against multiple points
163pub fn kernel_vector(x: &[f32], y_data: &[Vec<f32>], kernel_type: KernelType) -> Vec<f32> {
164    #[cfg(feature = "parallel")]
165    {
166        use rayon::prelude::*;
167        y_data
168            .par_iter()
169            .map(|y| match kernel_type {
170                KernelType::Linear => linear_kernel(x, y),
171                KernelType::Rbf { gamma } => rbf_kernel(x, y, gamma),
172                KernelType::Polynomial {
173                    degree,
174                    gamma,
175                    coef0,
176                } => polynomial_kernel(x, y, degree, coef0, gamma),
177                KernelType::Sigmoid { gamma, coef0 } => sigmoid_kernel(x, y, gamma, coef0),
178            })
179            .collect()
180    }
181
182    #[cfg(not(feature = "parallel"))]
183    {
184        y_data
185            .iter()
186            .map(|y| match kernel_type {
187                KernelType::Linear => linear_kernel(x, y),
188                KernelType::Rbf { gamma } => rbf_kernel(x, y, gamma),
189                KernelType::Polynomial {
190                    degree,
191                    gamma,
192                    coef0,
193                } => polynomial_kernel(x, y, degree, coef0, gamma),
194                KernelType::Sigmoid { gamma, coef0 } => sigmoid_kernel(x, y, gamma, coef0),
195            })
196            .collect()
197    }
198}
199
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Reference squared distance computed without any SIMD path.
    fn reference_distance_squared(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(p, q)| (p - q) * (p - q)).sum()
    }

    #[test]
    fn test_linear_kernel() {
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        let result = linear_kernel(&x, &y);
        let expected = 32.0; // 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32

        assert_relative_eq!(result, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_rbf_kernel() {
        let x = vec![1.0, 2.0];
        let y = vec![1.0, 2.0];
        let gamma = 1.0;

        let result = rbf_kernel(&x, &y, gamma);
        let expected = 1.0; // Same points should give 1.0

        assert_relative_eq!(result, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_rbf_kernel_distinct_points() {
        let x = vec![0.0, 0.0];
        let y = vec![3.0, 4.0];
        let gamma: f32 = 0.5;

        // ||x - y||^2 = 9 + 16 = 25
        let expected = (-gamma * 25.0).exp();

        assert_relative_eq!(rbf_kernel(&x, &y, gamma), expected, epsilon = 1e-6);
    }

    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_rbf_kernel_length_mismatch_panics() {
        let _ = rbf_kernel(&[1.0, 2.0], &[1.0], 1.0);
    }

    #[test]
    fn test_euclidean_distance_across_simd_widths() {
        // Lengths on either side of the 4- and 8-lane widths exercise the
        // vector loops as well as the scalar tails of every dispatch path.
        for len in [0usize, 1, 3, 4, 5, 7, 8, 9, 15, 16, 17, 33] {
            let a: Vec<f32> = (0..len).map(|i| i as f32 * 0.5).collect();
            let b: Vec<f32> = (0..len).map(|i| 1.0 - i as f32 * 0.25).collect();

            assert_relative_eq!(
                euclidean_distance_squared(&a, &b),
                reference_distance_squared(&a, &b),
                epsilon = 1e-3
            );
        }
    }

    #[test]
    fn test_polynomial_kernel() {
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];
        let degree = 2.0;
        let gamma = 1.0;
        let coef0 = 0.0;

        let result = polynomial_kernel(&x, &y, degree, coef0, gamma);
        let dot_prod: f32 = 1.0 * 3.0 + 2.0 * 4.0; // = 11
        let expected = dot_prod.powf(degree); // = 121

        assert_relative_eq!(result, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_polynomial_kernel_with_gamma_and_coef0() {
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        // (0.5 * 11 + 1)^3 = 6.5^3 = 274.625
        let result = polynomial_kernel(&x, &y, 3.0, 1.0, 0.5);

        assert_relative_eq!(result, 274.625, epsilon = 1e-4);
    }

    #[test]
    fn test_sigmoid_kernel() {
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];
        let gamma = 1.0;
        let coef0 = 0.0;

        let result = sigmoid_kernel(&x, &y, gamma, coef0);
        let dot_prod: f32 = 1.0 * 3.0 + 2.0 * 4.0; // = 11
        let expected = (gamma * dot_prod + coef0).tanh();

        assert_relative_eq!(result, expected, epsilon = 1e-6);
    }

    #[test]
    fn test_sigmoid_kernel_unsaturated() {
        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0];

        // tanh(0.1 * 11 - 1) = tanh(0.1) ~= 0.0996680
        let result = sigmoid_kernel(&x, &y, 0.1, -1.0);

        assert_relative_eq!(result, 0.099_668, epsilon = 1e-5);
    }

    #[test]
    fn test_kernel_matrix() {
        let x_data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let y_data = vec![vec![1.0, 2.0], vec![5.0, 6.0]];

        let matrix = kernel_matrix(&x_data, &y_data, KernelType::Linear);

        assert_eq!(matrix.len(), 2);
        assert_eq!(matrix[0].len(), 2);

        assert_relative_eq!(matrix[0][0], 5.0, epsilon = 1e-6); // 1*1 + 2*2
        assert_relative_eq!(matrix[0][1], 17.0, epsilon = 1e-6); // 1*5 + 2*6
        assert_relative_eq!(matrix[1][0], 11.0, epsilon = 1e-6); // 3*1 + 4*2
        assert_relative_eq!(matrix[1][1], 39.0, epsilon = 1e-6); // 3*5 + 4*6
    }

    #[test]
    fn test_kernel_matrix_empty_inputs() {
        let empty: Vec<Vec<f32>> = Vec::new();
        let points = vec![vec![1.0, 2.0]];

        assert!(kernel_matrix(&empty, &points, KernelType::default()).is_empty());

        let matrix = kernel_matrix(&points, &empty, KernelType::default());
        assert_eq!(matrix.len(), 1);
        assert!(matrix[0].is_empty());
    }

    #[test]
    fn test_kernel_vector() {
        let x = vec![1.0, 2.0];
        let y_data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let result = kernel_vector(&x, &y_data, KernelType::Linear);

        assert_eq!(result.len(), 2);
        assert_relative_eq!(result[0], 5.0, epsilon = 1e-6); // 1*1 + 2*2 = 5
        assert_relative_eq!(result[1], 11.0, epsilon = 1e-6); // 1*3 + 2*4 = 11
    }

    #[test]
    fn test_kernel_vector_matches_matrix_row() {
        let x = vec![1.0, 2.0];
        let y_data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let kernel = KernelType::Rbf { gamma: 0.1 };

        let row = kernel_vector(&x, &y_data, kernel);
        let matrix = kernel_matrix(&[x], &y_data, kernel);

        assert_eq!(row, matrix[0]);
    }

    #[test]
    fn test_default_kernel_is_unit_rbf() {
        assert!(matches!(
            KernelType::default(),
            KernelType::Rbf { gamma } if gamma == 1.0
        ));
    }
}