opensrdk_kernel_method/ard/mod.rs
1use super::PositiveDefiniteKernel;
2use crate::{KernelAdd, KernelError, KernelMul};
3use opensrdk_symbolic_computation::Expression;
4use rayon::prelude::*;
5use std::{ops::Add, ops::Mul};
6
7fn weighted_norm_pow(x: Expression, x_prime: Expression, params: &[Expression]) -> Expression {
8 todo!()
9 // params
10 // .par_iter()
11 // .zip(x.par_iter())
12 // .zip(x_prime.par_iter())
13 // .map(|((relevance, xi), x_primei)| relevance * (xi - x_primei).powi(2))
14 // .sum()
15}
// TODO: `weighted_norm_pow` must be rewritten to operate on symbolic `Expression`s.
17
/// ARD kernel, parameterized by its parameter count.
///
/// The wrapped `usize` serves two purposes:
/// - it is the value reported by `params_len`;
/// - it is the exact number of entries `expression` requires in `params`.
///
/// The commented-out implementation in this module suggests the intended form is
/// `exp(-sum_i params[i] * (x[i] - x_prime[i])^2)`, with one relevance weight per
/// input dimension. TODO confirm once the symbolic version exists.
#[derive(Debug, Clone)]
pub struct ARD(pub usize);
20
21impl PositiveDefiniteKernel for ARD {
22 fn params_len(&self) -> usize {
23 self.0
24 }
25
26 fn expression(
27 &self,
28 x: Expression,
29 x_prime: Expression,
30 params: &[Expression],
31 ) -> Result<Expression, KernelError> {
32 if params.len() != self.0 {
33 return Err(KernelError::ParametersLengthMismatch.into());
34 }
35 // if x.len() != self.0 || x_prime.len() != self.0 {
36 // return Err(KernelError::InvalidArgument.into());
37 // }
38 todo!()
39
40 // let fx = (-weighted_norm_pow(¶ms, x, x_prime)).exp();
41
42 // Ok(fx)
43 }
44}
45
46impl<R> Add<R> for ARD
47where
48 R: PositiveDefiniteKernel,
49{
50 type Output = KernelAdd<Self, R>;
51
52 fn add(self, rhs: R) -> Self::Output {
53 Self::Output::new(self, rhs)
54 }
55}
56
57impl<R> Mul<R> for ARD
58where
59 R: PositiveDefiniteKernel,
60{
61 type Output = KernelMul<Self, R>;
62
63 fn mul(self, rhs: R) -> Self::Output {
64 Self::Output::new(self, rhs)
65 }
66}
67
68// impl ValueDifferentiableKernel<Vec<f64>> for ARD {
69// fn ln_diff_value(
70// &self,
71// params: &[f64],
72// x: &Vec<f64>,
73// xprime: &Vec<f64>,
74// ) -> Result<Vec<f64>, KernelError> {
75// let diff = params
76// .par_iter()
77// .zip(x.par_iter())
78// .zip(xprime.par_iter())
79// .map(|((relevance, xi), xprimei)| -2.0 * relevance * (xi - xprimei))
80// .collect::<Vec<f64>>();
81// Ok(diff)
82// }
83// }
84
85// impl ParamsDifferentiableKernel<Vec<f64>> for ARD {
86// fn ln_diff_params(
87// &self,
88// params: &[f64],
89// x: &Vec<f64>,
90// xprime: &Vec<f64>,
91// ) -> Result<Vec<f64>, KernelError> {
92// let diff = params
93// .par_iter()
94// .zip(x.par_iter())
95// .zip(xprime.par_iter())
96// .map(|((_relevance, xi), xprimei)| -(xi - xprimei).powi(2))
97// .collect::<Vec<f64>>();
98// Ok(diff)
99// }
100// }
101
102// #[cfg(test)]
103// mod tests {
104// use crate::*;
105// #[test]
106// fn it_works() {
107// let kernel = ARD(3);
108
109// let test_value = kernel
110// .value(&[1.0, 0.0, 0.0], &vec![1.0, 2.0, 3.0], &vec![0.0, 2.0, 1.0])
111// .unwrap();
112
113// assert_eq!(test_value, (-1f64).exp());
114// }
115// }