opensrdk_kernel_method/ard/
mod.rs

1use super::PositiveDefiniteKernel;
2use crate::{KernelAdd, KernelError, KernelMul};
3use opensrdk_symbolic_computation::Expression;
4use rayon::prelude::*;
5use std::{ops::Add, ops::Mul};
6
7fn weighted_norm_pow(x: Expression, x_prime: Expression, params: &[Expression]) -> Expression {
8    todo!()
9    // params
10    //     .par_iter()
11    //     .zip(x.par_iter())
12    //     .zip(x_prime.par_iter())
13    //     .map(|((relevance, xi), x_primei)| relevance * (xi - x_primei).powi(2))
14    //     .sum()
15}
16//must rewite this function!
17
18#[derive(Clone, Debug)]
19pub struct ARD(pub usize);
20
21impl PositiveDefiniteKernel for ARD {
22    fn params_len(&self) -> usize {
23        self.0
24    }
25
26    fn expression(
27        &self,
28        x: Expression,
29        x_prime: Expression,
30        params: &[Expression],
31    ) -> Result<Expression, KernelError> {
32        if params.len() != self.0 {
33            return Err(KernelError::ParametersLengthMismatch.into());
34        }
35        // if x.len() != self.0 || x_prime.len() != self.0 {
36        //     return Err(KernelError::InvalidArgument.into());
37        // }
38        todo!()
39
40        // let fx = (-weighted_norm_pow(&params, x, x_prime)).exp();
41
42        // Ok(fx)
43    }
44}
45
46impl<R> Add<R> for ARD
47where
48    R: PositiveDefiniteKernel,
49{
50    type Output = KernelAdd<Self, R>;
51
52    fn add(self, rhs: R) -> Self::Output {
53        Self::Output::new(self, rhs)
54    }
55}
56
57impl<R> Mul<R> for ARD
58where
59    R: PositiveDefiniteKernel,
60{
61    type Output = KernelMul<Self, R>;
62
63    fn mul(self, rhs: R) -> Self::Output {
64        Self::Output::new(self, rhs)
65    }
66}
67
68// impl ValueDifferentiableKernel<Vec<f64>> for ARD {
69//     fn ln_diff_value(
70//         &self,
71//         params: &[f64],
72//         x: &Vec<f64>,
73//         xprime: &Vec<f64>,
74//     ) -> Result<Vec<f64>, KernelError> {
75//         let diff = params
76//             .par_iter()
77//             .zip(x.par_iter())
78//             .zip(xprime.par_iter())
79//             .map(|((relevance, xi), xprimei)| -2.0 * relevance * (xi - xprimei))
80//             .collect::<Vec<f64>>();
81//         Ok(diff)
82//     }
83// }
84
85// impl ParamsDifferentiableKernel<Vec<f64>> for ARD {
86//     fn ln_diff_params(
87//         &self,
88//         params: &[f64],
89//         x: &Vec<f64>,
90//         xprime: &Vec<f64>,
91//     ) -> Result<Vec<f64>, KernelError> {
92//         let diff = params
93//             .par_iter()
94//             .zip(x.par_iter())
95//             .zip(xprime.par_iter())
96//             .map(|((_relevance, xi), xprimei)| -(xi - xprimei).powi(2))
97//             .collect::<Vec<f64>>();
98//         Ok(diff)
99//     }
100// }
101
102// #[cfg(test)]
103// mod tests {
104//     use crate::*;
105//     #[test]
106//     fn it_works() {
107//         let kernel = ARD(3);
108
109//         let test_value = kernel
110//             .value(&[1.0, 0.0, 0.0], &vec![1.0, 2.0, 3.0], &vec![0.0, 2.0, 1.0])
111//             .unwrap();
112
113//         assert_eq!(test_value, (-1f64).exp());
114//     }
115// }