1use scirs2_core::ndarray::{Array2, ArrayView1};
8use sklears_core::error::{Result as SklResult, SklearsError};
10
11pub use crate::kernel_trait::Kernel;
13
14#[derive(Debug, Clone)]
16pub struct RBF {
17 length_scale: f64,
18}
19
20impl RBF {
21 pub fn new(length_scale: f64) -> Self {
22 Self { length_scale }
23 }
24}
25
26impl Kernel for RBF {
27 fn compute_kernel_matrix(
28 &self,
29 X1: &Array2<f64>,
30 X2: Option<&Array2<f64>>,
31 ) -> SklResult<Array2<f64>> {
32 let X2 = X2.unwrap_or(X1);
33 let n1 = X1.nrows();
34 let n2 = X2.nrows();
35 let mut K = Array2::<f64>::zeros((n1, n2));
36
37 for i in 0..n1 {
38 for j in 0..n2 {
39 let x1 = X1.row(i);
40 let x2 = X2.row(j);
41 K[[i, j]] = self.kernel(&x1, &x2);
42 }
43 }
44 Ok(K)
45 }
46
47 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
48 let mut sq_dist = 0.0;
49 for (a, b) in x1.iter().zip(x2.iter()) {
50 sq_dist += (a - b).powi(2);
51 }
52 (-sq_dist / (2.0 * self.length_scale.powi(2))).exp()
53 }
54
55 fn get_params(&self) -> Vec<f64> {
56 vec![self.length_scale]
57 }
58
59 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
60 if params.len() != 1 {
61 return Err(SklearsError::InvalidInput(
62 "RBF kernel requires exactly 1 parameter".to_string(),
63 ));
64 }
65 self.length_scale = params[0];
66 Ok(())
67 }
68
69 fn clone_box(&self) -> Box<dyn Kernel> {
70 Box::new(self.clone())
71 }
72}
73
74#[derive(Debug, Clone)]
94pub struct ARDRBF {
95 length_scales: scirs2_core::ndarray::Array1<f64>,
97}
98
99impl ARDRBF {
100 pub fn new(length_scales: scirs2_core::ndarray::Array1<f64>) -> Self {
106 Self { length_scales }
107 }
108
109 pub fn new_uniform(n_dims: usize, length_scale: f64) -> Self {
116 Self {
117 length_scales: scirs2_core::ndarray::Array1::from_elem(n_dims, length_scale),
118 }
119 }
120
121 pub fn n_dimensions(&self) -> usize {
123 self.length_scales.len()
124 }
125}
126
127impl Kernel for ARDRBF {
128 fn compute_kernel_matrix(
129 &self,
130 X1: &Array2<f64>,
131 X2: Option<&Array2<f64>>,
132 ) -> SklResult<Array2<f64>> {
133 let X2 = X2.unwrap_or(X1);
134 let n1 = X1.nrows();
135 let n2 = X2.nrows();
136
137 if X1.ncols() != self.length_scales.len() {
139 return Err(SklearsError::InvalidInput(format!(
140 "X1 has {} dimensions but kernel expects {}",
141 X1.ncols(),
142 self.length_scales.len()
143 )));
144 }
145 if X2.ncols() != self.length_scales.len() {
146 return Err(SklearsError::InvalidInput(format!(
147 "X2 has {} dimensions but kernel expects {}",
148 X2.ncols(),
149 self.length_scales.len()
150 )));
151 }
152
153 let mut K = Array2::<f64>::zeros((n1, n2));
154
155 for i in 0..n1 {
156 for j in 0..n2 {
157 let x1 = X1.row(i);
158 let x2 = X2.row(j);
159 K[[i, j]] = self.kernel(&x1, &x2);
160 }
161 }
162 Ok(K)
163 }
164
165 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
166 let mut weighted_sq_dist = 0.0;
168 for ((a, b), length_scale) in x1.iter().zip(x2.iter()).zip(self.length_scales.iter()) {
169 let diff = a - b;
170 weighted_sq_dist += (diff * diff) / (length_scale * length_scale);
171 }
172 (-0.5 * weighted_sq_dist).exp()
173 }
174
175 fn get_params(&self) -> Vec<f64> {
176 self.length_scales.to_vec()
177 }
178
179 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
180 if params.len() != self.length_scales.len() {
181 return Err(SklearsError::InvalidInput(format!(
182 "ARD RBF kernel requires exactly {} parameters (one per dimension), got {}",
183 self.length_scales.len(),
184 params.len()
185 )));
186 }
187 for (i, ¶m) in params.iter().enumerate() {
188 self.length_scales[i] = param;
189 }
190 Ok(())
191 }
192
193 fn clone_box(&self) -> Box<dyn Kernel> {
194 Box::new(self.clone())
195 }
196}
197
198#[derive(Debug, Clone)]
200pub struct Matern {
201 length_scale: f64,
202 nu: f64,
203}
204
205impl Matern {
206 pub fn new(length_scale: f64, nu: f64) -> Self {
207 Self { length_scale, nu }
208 }
209}
210
211impl Kernel for Matern {
212 fn compute_kernel_matrix(
213 &self,
214 X1: &Array2<f64>,
215 X2: Option<&Array2<f64>>,
216 ) -> SklResult<Array2<f64>> {
217 let X2 = X2.unwrap_or(X1);
218 let n1 = X1.nrows();
219 let n2 = X2.nrows();
220 let mut K = Array2::<f64>::zeros((n1, n2));
221
222 for i in 0..n1 {
223 for j in 0..n2 {
224 let x1 = X1.row(i);
225 let x2 = X2.row(j);
226 K[[i, j]] = self.kernel(&x1, &x2);
227 }
228 }
229 Ok(K)
230 }
231
232 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
233 let mut sq_dist = 0.0;
234 for (a, b) in x1.iter().zip(x2.iter()) {
235 sq_dist += (a - b).powi(2);
236 }
237 let dist = sq_dist.sqrt();
238
239 if dist == 0.0 {
240 return 1.0;
241 }
242
243 let sqrt_3_dist = (3.0_f64).sqrt() * dist / self.length_scale;
244 (1.0 + sqrt_3_dist) * (-sqrt_3_dist).exp()
245 }
246
247 fn get_params(&self) -> Vec<f64> {
248 vec![self.length_scale, self.nu]
249 }
250
251 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
252 if params.len() != 2 {
253 return Err(SklearsError::InvalidInput(
254 "Matern kernel requires exactly 2 parameters".to_string(),
255 ));
256 }
257 self.length_scale = params[0];
258 self.nu = params[1];
259 Ok(())
260 }
261
262 fn clone_box(&self) -> Box<dyn Kernel> {
263 Box::new(self.clone())
264 }
265}
266
267#[derive(Debug, Clone)]
269pub struct Linear {
270 sigma_0_sq: f64,
271 sigma_1_sq: f64,
272}
273
274impl Linear {
275 pub fn new(sigma_0_sq: f64, sigma_1_sq: f64) -> Self {
276 Self {
277 sigma_0_sq,
278 sigma_1_sq,
279 }
280 }
281}
282
283impl Kernel for Linear {
284 fn compute_kernel_matrix(
285 &self,
286 X1: &Array2<f64>,
287 X2: Option<&Array2<f64>>,
288 ) -> SklResult<Array2<f64>> {
289 let X2 = X2.unwrap_or(X1);
290 let n1 = X1.nrows();
291 let n2 = X2.nrows();
292 let mut K = Array2::<f64>::zeros((n1, n2));
293
294 for i in 0..n1 {
295 for j in 0..n2 {
296 let x1 = X1.row(i);
297 let x2 = X2.row(j);
298 K[[i, j]] = self.kernel(&x1, &x2);
299 }
300 }
301 Ok(K)
302 }
303
304 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
305 let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
306 self.sigma_0_sq + self.sigma_1_sq * dot_product
307 }
308
309 fn get_params(&self) -> Vec<f64> {
310 vec![self.sigma_0_sq, self.sigma_1_sq]
311 }
312
313 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
314 if params.len() != 2 {
315 return Err(SklearsError::InvalidInput(
316 "Linear kernel requires exactly 2 parameters".to_string(),
317 ));
318 }
319 self.sigma_0_sq = params[0];
320 self.sigma_1_sq = params[1];
321 Ok(())
322 }
323
324 fn clone_box(&self) -> Box<dyn Kernel> {
325 Box::new(self.clone())
326 }
327}
328
329#[derive(Debug, Clone)]
331pub struct Polynomial {
332 gamma: f64,
333 coef0: f64,
334 degree: f64,
335}
336
337impl Polynomial {
338 pub fn new(gamma: f64, coef0: f64, degree: f64) -> Self {
339 Self {
340 gamma,
341 coef0,
342 degree,
343 }
344 }
345}
346
347impl Kernel for Polynomial {
348 fn compute_kernel_matrix(
349 &self,
350 X1: &Array2<f64>,
351 X2: Option<&Array2<f64>>,
352 ) -> SklResult<Array2<f64>> {
353 let X2 = X2.unwrap_or(X1);
354 let n1 = X1.nrows();
355 let n2 = X2.nrows();
356 let mut K = Array2::<f64>::zeros((n1, n2));
357
358 for i in 0..n1 {
359 for j in 0..n2 {
360 let x1 = X1.row(i);
361 let x2 = X2.row(j);
362 K[[i, j]] = self.kernel(&x1, &x2);
363 }
364 }
365 Ok(K)
366 }
367
368 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
369 let dot_product: f64 = x1.iter().zip(x2.iter()).map(|(a, b)| a * b).sum();
370 (self.gamma * dot_product + self.coef0).powf(self.degree)
371 }
372
373 fn get_params(&self) -> Vec<f64> {
374 vec![self.gamma, self.coef0, self.degree]
375 }
376
377 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
378 if params.len() != 3 {
379 return Err(SklearsError::InvalidInput(
380 "Polynomial kernel requires exactly 3 parameters".to_string(),
381 ));
382 }
383 self.gamma = params[0];
384 self.coef0 = params[1];
385 self.degree = params[2];
386 Ok(())
387 }
388
389 fn clone_box(&self) -> Box<dyn Kernel> {
390 Box::new(self.clone())
391 }
392}
393
394#[derive(Debug, Clone)]
396pub struct RationalQuadratic {
397 length_scale: f64,
398 alpha: f64,
399}
400
401impl RationalQuadratic {
402 pub fn new(length_scale: f64, alpha: f64) -> Self {
403 Self {
404 length_scale,
405 alpha,
406 }
407 }
408}
409
410impl Kernel for RationalQuadratic {
411 fn compute_kernel_matrix(
412 &self,
413 X1: &Array2<f64>,
414 X2: Option<&Array2<f64>>,
415 ) -> SklResult<Array2<f64>> {
416 let X2 = X2.unwrap_or(X1);
417 let n1 = X1.nrows();
418 let n2 = X2.nrows();
419 let mut K = Array2::<f64>::zeros((n1, n2));
420
421 for i in 0..n1 {
422 for j in 0..n2 {
423 let x1 = X1.row(i);
424 let x2 = X2.row(j);
425 K[[i, j]] = self.kernel(&x1, &x2);
426 }
427 }
428 Ok(K)
429 }
430
431 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
432 let mut sq_dist = 0.0;
433 for (a, b) in x1.iter().zip(x2.iter()) {
434 sq_dist += (a - b).powi(2);
435 }
436 (1.0 + sq_dist / (2.0 * self.alpha * self.length_scale.powi(2))).powf(-self.alpha)
437 }
438
439 fn get_params(&self) -> Vec<f64> {
440 vec![self.length_scale, self.alpha]
441 }
442
443 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
444 if params.len() != 2 {
445 return Err(SklearsError::InvalidInput(
446 "RationalQuadratic kernel requires exactly 2 parameters".to_string(),
447 ));
448 }
449 self.length_scale = params[0];
450 self.alpha = params[1];
451 Ok(())
452 }
453
454 fn clone_box(&self) -> Box<dyn Kernel> {
455 Box::new(self.clone())
456 }
457}
458
459#[derive(Debug, Clone)]
461pub struct ExpSineSquared {
462 length_scale: f64,
463 periodicity: f64,
464}
465
466impl ExpSineSquared {
467 pub fn new(length_scale: f64, periodicity: f64) -> Self {
468 Self {
469 length_scale,
470 periodicity,
471 }
472 }
473}
474
475impl Kernel for ExpSineSquared {
476 fn compute_kernel_matrix(
477 &self,
478 X1: &Array2<f64>,
479 X2: Option<&Array2<f64>>,
480 ) -> SklResult<Array2<f64>> {
481 let X2 = X2.unwrap_or(X1);
482 let n1 = X1.nrows();
483 let n2 = X2.nrows();
484 let mut K = Array2::<f64>::zeros((n1, n2));
485
486 for i in 0..n1 {
487 for j in 0..n2 {
488 let x1 = X1.row(i);
489 let x2 = X2.row(j);
490 K[[i, j]] = self.kernel(&x1, &x2);
491 }
492 }
493 Ok(K)
494 }
495
496 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
497 let dist = x1
498 .iter()
499 .zip(x2.iter())
500 .map(|(a, b)| (a - b).powi(2))
501 .sum::<f64>()
502 .sqrt();
503 let sin_term = (std::f64::consts::PI * dist / self.periodicity).sin();
504 (-2.0 * sin_term.powi(2) / self.length_scale.powi(2)).exp()
505 }
506
507 fn get_params(&self) -> Vec<f64> {
508 vec![self.length_scale, self.periodicity]
509 }
510
511 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
512 if params.len() != 2 {
513 return Err(SklearsError::InvalidInput(
514 "ExpSineSquared kernel requires exactly 2 parameters".to_string(),
515 ));
516 }
517 self.length_scale = params[0];
518 self.periodicity = params[1];
519 Ok(())
520 }
521
522 fn clone_box(&self) -> Box<dyn Kernel> {
523 Box::new(self.clone())
524 }
525}
526
527#[derive(Debug, Clone)]
529pub struct WhiteKernel {
530 noise_level: f64,
531}
532
533impl WhiteKernel {
534 pub fn new(noise_level: f64) -> Self {
535 Self { noise_level }
536 }
537}
538
539impl Kernel for WhiteKernel {
540 fn compute_kernel_matrix(
541 &self,
542 X1: &Array2<f64>,
543 X2: Option<&Array2<f64>>,
544 ) -> SklResult<Array2<f64>> {
545 let n1 = X1.nrows();
546 let n2 = X2.map_or(n1, |x| x.nrows());
547 let mut K = Array2::<f64>::zeros((n1, n2));
548
549 if X2.is_none() {
551 for i in 0..n1 {
552 K[[i, i]] = self.noise_level;
553 }
554 }
555 Ok(K)
556 }
557
558 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
559 let identical = x1.iter().zip(x2.iter()).all(|(a, b)| (a - b).abs() < 1e-10);
561 if identical {
562 self.noise_level
563 } else {
564 0.0
565 }
566 }
567
568 fn get_params(&self) -> Vec<f64> {
569 vec![self.noise_level]
570 }
571
572 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
573 if params.len() != 1 {
574 return Err(SklearsError::InvalidInput(
575 "WhiteKernel requires exactly 1 parameter".to_string(),
576 ));
577 }
578 self.noise_level = params[0];
579 Ok(())
580 }
581
582 fn clone_box(&self) -> Box<dyn Kernel> {
583 Box::new(self.clone())
584 }
585}
586
587#[derive(Debug, Clone)]
589pub struct ConstantKernel {
590 constant_value: f64,
591}
592
593impl ConstantKernel {
594 pub fn new(constant_value: f64) -> Self {
595 Self { constant_value }
596 }
597}
598
599impl Kernel for ConstantKernel {
600 fn compute_kernel_matrix(
601 &self,
602 X1: &Array2<f64>,
603 X2: Option<&Array2<f64>>,
604 ) -> SklResult<Array2<f64>> {
605 let n1 = X1.nrows();
606 let n2 = X2.map_or(n1, |x| x.nrows());
607 Ok(Array2::<f64>::from_elem((n1, n2), self.constant_value))
608 }
609
610 fn kernel(&self, _x1: &ArrayView1<f64>, _x2: &ArrayView1<f64>) -> f64 {
611 self.constant_value
612 }
613
614 fn get_params(&self) -> Vec<f64> {
615 vec![self.constant_value]
616 }
617
618 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
619 if params.len() != 1 {
620 return Err(SklearsError::InvalidInput(
621 "ConstantKernel requires exactly 1 parameter".to_string(),
622 ));
623 }
624 self.constant_value = params[0];
625 Ok(())
626 }
627
628 fn clone_box(&self) -> Box<dyn Kernel> {
629 Box::new(self.clone())
630 }
631}
632
633#[derive(Debug, Clone)]
635pub struct SumKernel {
636 kernels: Vec<Box<dyn Kernel>>,
637}
638
639impl SumKernel {
640 pub fn new(kernels: Vec<Box<dyn Kernel>>) -> Self {
641 Self { kernels }
642 }
643}
644
645impl Kernel for SumKernel {
646 fn compute_kernel_matrix(
647 &self,
648 X1: &Array2<f64>,
649 X2: Option<&Array2<f64>>,
650 ) -> SklResult<Array2<f64>> {
651 if self.kernels.is_empty() {
652 return Err(SklearsError::InvalidInput(
653 "SumKernel requires at least one kernel".to_string(),
654 ));
655 }
656
657 let mut result = self.kernels[0].compute_kernel_matrix(X1, X2)?;
658 for kernel in &self.kernels[1..] {
659 let k_matrix = kernel.compute_kernel_matrix(X1, X2)?;
660 result = result + k_matrix;
661 }
662 Ok(result)
663 }
664
665 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
666 self.kernels.iter().map(|k| k.kernel(x1, x2)).sum()
667 }
668
669 fn get_params(&self) -> Vec<f64> {
670 self.kernels.iter().flat_map(|k| k.get_params()).collect()
671 }
672
673 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
674 let mut offset = 0;
675 for kernel in &mut self.kernels {
676 let n_params = kernel.get_params().len();
677 if offset + n_params > params.len() {
678 return Err(SklearsError::InvalidInput(
679 "Not enough parameters for SumKernel".to_string(),
680 ));
681 }
682 kernel.set_params(¶ms[offset..offset + n_params])?;
683 offset += n_params;
684 }
685 Ok(())
686 }
687
688 fn clone_box(&self) -> Box<dyn Kernel> {
689 Box::new(Self {
690 kernels: self.kernels.iter().map(|k| k.clone_box()).collect(),
691 })
692 }
693}
694
695#[derive(Debug, Clone)]
697pub struct ProductKernel {
698 kernels: Vec<Box<dyn Kernel>>,
699}
700
701impl ProductKernel {
702 pub fn new(kernels: Vec<Box<dyn Kernel>>) -> Self {
703 Self { kernels }
704 }
705}
706
707impl Kernel for ProductKernel {
708 fn compute_kernel_matrix(
709 &self,
710 X1: &Array2<f64>,
711 X2: Option<&Array2<f64>>,
712 ) -> SklResult<Array2<f64>> {
713 if self.kernels.is_empty() {
714 return Err(SklearsError::InvalidInput(
715 "ProductKernel requires at least one kernel".to_string(),
716 ));
717 }
718
719 let mut result = self.kernels[0].compute_kernel_matrix(X1, X2)?;
720 for kernel in &self.kernels[1..] {
721 let k_matrix = kernel.compute_kernel_matrix(X1, X2)?;
722 result = result * k_matrix;
723 }
724 Ok(result)
725 }
726
727 fn kernel(&self, x1: &ArrayView1<f64>, x2: &ArrayView1<f64>) -> f64 {
728 self.kernels.iter().map(|k| k.kernel(x1, x2)).product()
729 }
730
731 fn get_params(&self) -> Vec<f64> {
732 self.kernels.iter().flat_map(|k| k.get_params()).collect()
733 }
734
735 fn set_params(&mut self, params: &[f64]) -> SklResult<()> {
736 let mut offset = 0;
737 for kernel in &mut self.kernels {
738 let n_params = kernel.get_params().len();
739 if offset + n_params > params.len() {
740 return Err(SklearsError::InvalidInput(
741 "Not enough parameters for ProductKernel".to_string(),
742 ));
743 }
744 kernel.set_params(¶ms[offset..offset + n_params])?;
745 offset += n_params;
746 }
747 Ok(())
748 }
749
750 fn clone_box(&self) -> Box<dyn Kernel> {
751 Box::new(Self {
752 kernels: self.kernels.iter().map(|k| k.clone_box()).collect(),
753 })
754 }
755}
756
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1};

    /// Absolute tolerance shared by the floating-point comparisons below.
    const TOL: f64 = 1e-10;

    /// `true` when `a` and `b` differ by less than [`TOL`].
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    /// Two-dimensional ARD kernel with unit length scales.
    fn unit_ard_2d() -> ARDRBF {
        ARDRBF::new(Array1::from_vec(vec![1.0, 1.0]))
    }

    #[test]
    fn test_ardrbf_creation() {
        let kernel = ARDRBF::new(Array1::from_vec(vec![1.0, 2.0, 0.5]));
        assert_eq!(kernel.n_dimensions(), 3);
        assert_eq!(kernel.get_params(), vec![1.0, 2.0, 0.5]);
    }

    #[test]
    fn test_ardrbf_creation_uniform() {
        let kernel = ARDRBF::new_uniform(4, 1.5);
        assert_eq!(kernel.n_dimensions(), 4);
        assert_eq!(kernel.get_params(), vec![1.5; 4]);
    }

    #[test]
    fn test_ardrbf_kernel_identical_points() {
        let kernel = unit_ard_2d();
        let a = array![1.0, 2.0];
        let b = array![1.0, 2.0];

        let value = kernel.kernel(&a.view(), &b.view());
        assert!(close(value, 1.0), "Kernel of identical points should be 1.0");
    }

    #[test]
    fn test_ardrbf_kernel_different_points() {
        let kernel = unit_ard_2d();
        let origin = array![0.0, 0.0];
        let corner = array![1.0, 1.0];

        let value = kernel.kernel(&origin.view(), &corner.view());
        assert!(value > 0.0 && value < 1.0, "Kernel should be between 0 and 1");
        // Squared distance 2 with unit length scales: exp(-0.5 * 2) = exp(-1).
        assert!(close(value, (-1.0f64).exp()), "Kernel value mismatch");
    }

    #[test]
    fn test_ardrbf_kernel_matrix() {
        let kernel = unit_ard_2d();
        let points = array![[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]];
        let gram = kernel.compute_kernel_matrix(&points, None).unwrap();

        assert_eq!(gram.dim(), (3, 3));
        // Unit diagonal: every point is perfectly correlated with itself.
        for i in 0..3 {
            assert!(close(gram[[i, i]], 1.0));
        }
        // Symmetry of the self-covariance matrix.
        for &(i, j) in &[(0, 1), (0, 2), (1, 2)] {
            assert!(close(gram[[i, j]], gram[[j, i]]));
        }
    }

    #[test]
    fn test_ardrbf_relevance_determination() {
        // Short scale on dim 0, long scale on dim 1.
        let kernel = ARDRBF::new(Array1::from_vec(vec![0.1, 10.0]));

        let origin = array![0.0, 0.0];
        let step_dim1 = array![1.0, 0.0];
        let step_dim2 = array![0.0, 1.0];
        let k_dim1 = kernel.kernel(&origin.view(), &step_dim1.view());
        let k_dim2 = kernel.kernel(&origin.view(), &step_dim2.view());

        // A unit step along the short-scale axis decorrelates far more.
        assert!(
            k_dim1 < k_dim2,
            "Dimension with smaller length scale should have more effect"
        );
    }

    #[test]
    fn test_ardrbf_set_params() {
        let mut kernel = unit_ard_2d();
        kernel.set_params(&[2.0, 3.0]).unwrap();
        assert_eq!(kernel.get_params(), vec![2.0, 3.0]);
    }

    #[test]
    fn test_ardrbf_set_params_wrong_size() {
        let mut kernel = unit_ard_2d();
        // Three values for a two-dimensional kernel.
        let outcome = kernel.set_params(&[1.0, 2.0, 3.0]);
        assert!(
            outcome.is_err(),
            "Should error with wrong number of parameters"
        );
    }

    #[test]
    fn test_ardrbf_dimension_validation() {
        let kernel = unit_ard_2d();
        // One row with three columns, but the kernel expects two.
        let too_wide = array![[0.0, 0.0, 0.0]];
        let outcome = kernel.compute_kernel_matrix(&too_wide, None);
        assert!(
            outcome.is_err(),
            "Should error with wrong number of dimensions"
        );
    }

    #[test]
    fn test_ardrbf_clone() {
        let original = ARDRBF::new(Array1::from_vec(vec![1.0, 2.0]));
        let copy = original.clone();

        assert_eq!(original.get_params(), copy.get_params());
        assert_eq!(original.n_dimensions(), copy.n_dimensions());
    }

    #[test]
    fn test_ardrbf_vs_rbf_isotropic() {
        // With identical per-dimension scales ARD collapses to plain RBF.
        let scale = 1.5;
        let ard = ARDRBF::new_uniform(2, scale);
        let rbf = RBF::new(scale);

        let p = array![1.0, 2.0];
        let q = array![3.0, 4.0];
        let k_ard = ard.kernel(&p.view(), &q.view());
        let k_rbf = rbf.kernel(&p.view(), &q.view());

        assert!(
            close(k_ard, k_rbf),
            "ARD RBF with uniform length scales should match RBF"
        );
    }
}