1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use super::PositiveDefiniteKernel;
use crate::{KernelAdd, KernelError, KernelMul};
use opensrdk_symbolic_computation::Expression;
use rayon::prelude::*;
use std::{ops::Add, ops::Mul};

/// Placeholder for the relevance-weighted squared difference between `x` and `x_prime`.
///
/// # Panics
///
/// Always panics: the body is a bare `todo!()`.
///
/// NOTE(review): the commented-out sketch inside the body zips `params`, `x` and
/// `x_prime` element-wise and sums `relevance * (xi - x_primei).powi(2)`. It appears to
/// predate the switch to `Expression` arguments; presumably the rewrite should build the
/// same quantity symbolically — confirm before implementing.
fn weighted_norm_pow(x: Expression, x_prime: Expression, params: &[Expression]) -> Expression {
    todo!()
    // Earlier element-wise sketch, kept for reference until the function is rewritten:
    // params
    //     .par_iter()
    //     .zip(x.par_iter())
    //     .zip(x_prime.par_iter())
    //     .map(|((relevance, xi), x_primei)| relevance * (xi - x_primei).powi(2))
    //     .sum()
}
// TODO: must rewrite this function!

/// ARD kernel descriptor (presumably "automatic relevance determination" — confirm).
///
/// The wrapped `usize` is the number of parameters the kernel expects: it is what
/// `params_len` returns and what `expression` checks `params.len()` against.
///
/// Derives `PartialEq`/`Eq` so kernels can be compared directly, e.g. in `assert_eq!`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ARD(pub usize);

impl PositiveDefiniteKernel for ARD {
    /// Returns the number of parameters this kernel expects, i.e. the wrapped `usize`.
    fn params_len(&self) -> usize {
        let Self(len) = self;
        *len
    }

    /// Builds the symbolic kernel expression for the inputs `x` and `x_prime`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::ParametersLengthMismatch`] when `params.len()` differs from
    /// the length stored in `self`.
    ///
    /// # Panics
    ///
    /// Not implemented yet: once the parameter-count check passes, this always panics
    /// through `todo!()`.
    fn expression(
        &self,
        x: Expression,
        x_prime: Expression,
        params: &[Expression],
    ) -> Result<Expression, KernelError> {
        let Self(expected_len) = self;
        if params.len() != *expected_len {
            return Err(KernelError::ParametersLengthMismatch);
        }

        // Disabled input-size validation from the earlier version, kept for reference:
        // if x.len() != self.0 || x_prime.len() != self.0 {
        //     return Err(KernelError::InvalidArgument.into());
        // }

        // Disabled body from the earlier version, kept for reference. Note that its
        // argument order does not match the current `weighted_norm_pow` signature.
        // let fx = (-weighted_norm_pow(&params, x, x_prime)).exp();
        // Ok(fx)
        todo!()
    }
}

impl<R> Add<R> for ARD
where
    R: PositiveDefiniteKernel,
{
    type Output = KernelAdd<Self, R>;

    fn add(self, rhs: R) -> Self::Output {
        Self::Output::new(self, rhs)
    }
}

impl<R> Mul<R> for ARD
where
    R: PositiveDefiniteKernel,
{
    type Output = KernelMul<Self, R>;

    fn mul(self, rhs: R) -> Self::Output {
        Self::Output::new(self, rhs)
    }
}

// impl ValueDifferentiableKernel<Vec<f64>> for ARD {
//     fn ln_diff_value(
//         &self,
//         params: &[f64],
//         x: &Vec<f64>,
//         xprime: &Vec<f64>,
//     ) -> Result<Vec<f64>, KernelError> {
//         let diff = params
//             .par_iter()
//             .zip(x.par_iter())
//             .zip(xprime.par_iter())
//             .map(|((relevance, xi), xprimei)| -2.0 * relevance * (xi - xprimei))
//             .collect::<Vec<f64>>();
//         Ok(diff)
//     }
// }

// impl ParamsDifferentiableKernel<Vec<f64>> for ARD {
//     fn ln_diff_params(
//         &self,
//         params: &[f64],
//         x: &Vec<f64>,
//         xprime: &Vec<f64>,
//     ) -> Result<Vec<f64>, KernelError> {
//         let diff = params
//             .par_iter()
//             .zip(x.par_iter())
//             .zip(xprime.par_iter())
//             .map(|((_relevance, xi), xprimei)| -(xi - xprimei).powi(2))
//             .collect::<Vec<f64>>();
//         Ok(diff)
//     }
// }

// #[cfg(test)]
// mod tests {
//     use crate::*;
//     #[test]
//     fn it_works() {
//         let kernel = ARD(3);

//         let test_value = kernel
//             .value(&[1.0, 0.0, 0.0], &vec![1.0, 2.0, 3.0], &vec![0.0, 2.0, 1.0])
//             .unwrap();

//         assert_eq!(test_value, (-1f64).exp());
//     }
// }