1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
use super::PositiveDefiniteKernel;
use crate::{KernelAdd, KernelError, KernelMul};
use opensrdk_symbolic_computation::Expression;
use rayon::prelude::*;
use std::{ops::Add, ops::Mul};
/// Weighted squared distance between `x` and `x_prime`, intended to compute
/// `sum_i params[i] * (x[i] - x_prime[i])^2`, i.e. the quantity that
/// `ARD::expression` is meant to negate and exponentiate.
///
/// Not implemented yet: calling this function panics via `todo!()`.
///
/// The commented-out reference implementation below was written for slice-like
/// inputs (it calls `par_iter()` on `x` and `x_prime` and zips them with
/// `params`). NOTE(review): it is unclear from here how per-dimension access to
/// an `Expression` should work — confirm before re-enabling it.
fn weighted_norm_pow(x: Expression, x_prime: Expression, params: &[Expression]) -> Expression {
todo!()
// params
// .par_iter()
// .zip(x.par_iter())
// .zip(x_prime.par_iter())
// .map(|((relevance, xi), x_primei)| relevance * (xi - x_primei).powi(2))
// .sum()
}
// Must rewrite this function!
/// ARD kernel (presumably "Automatic Relevance Determination").
///
/// The wrapped `usize` is the number of hyperparameters the kernel expects;
/// `params_len` returns it, and the disabled argument check in `expression`
/// also compares it against the input length.
#[derive(Debug, Clone)]
pub struct ARD(pub usize);
impl PositiveDefiniteKernel for ARD {
    /// Number of hyperparameters this kernel takes, i.e. the wrapped `usize`.
    fn params_len(&self) -> usize {
        self.0
    }

    /// Builds the symbolic expression for the kernel value `k(x, x_prime)`.
    ///
    /// # Errors
    ///
    /// Returns `KernelError::ParametersLengthMismatch` when `params.len()`
    /// differs from `self.0`.
    ///
    /// # Panics
    ///
    /// Construction of the expression is not implemented yet: once the
    /// parameter-length check passes, this panics via `todo!()`.
    fn expression(
        &self,
        x: Expression,
        x_prime: Expression,
        params: &[Expression],
    ) -> Result<Expression, KernelError> {
        if params.len() != self.0 {
            // The error type is already `KernelError`; no conversion needed.
            return Err(KernelError::ParametersLengthMismatch);
        }
        // TODO: re-enable the input-length validation once the dimensionality
        // of an `Expression` can be queried.
        // if x.len() != self.0 || x_prime.len() != self.0 {
        //     return Err(KernelError::InvalidArgument);
        // }
        todo!()
        // Intended result: exp(-sum_i params[i] * (x[i] - x_prime[i])^2), using
        // the argument order of `weighted_norm_pow(x, x_prime, params)`:
        // let fx = (-weighted_norm_pow(x, x_prime, params)).exp();
        // Ok(fx)
    }
}
impl<R> Add<R> for ARD
where
R: PositiveDefiniteKernel,
{
type Output = KernelAdd<Self, R>;
fn add(self, rhs: R) -> Self::Output {
Self::Output::new(self, rhs)
}
}
impl<R> Mul<R> for ARD
where
R: PositiveDefiniteKernel,
{
type Output = KernelMul<Self, R>;
fn mul(self, rhs: R) -> Self::Output {
Self::Output::new(self, rhs)
}
}
// impl ValueDifferentiableKernel<Vec<f64>> for ARD {
// fn ln_diff_value(
// &self,
// params: &[f64],
// x: &Vec<f64>,
// xprime: &Vec<f64>,
// ) -> Result<Vec<f64>, KernelError> {
// let diff = params
// .par_iter()
// .zip(x.par_iter())
// .zip(xprime.par_iter())
// .map(|((relevance, xi), xprimei)| -2.0 * relevance * (xi - xprimei))
// .collect::<Vec<f64>>();
// Ok(diff)
// }
// }
// impl ParamsDifferentiableKernel<Vec<f64>> for ARD {
// fn ln_diff_params(
// &self,
// params: &[f64],
// x: &Vec<f64>,
// xprime: &Vec<f64>,
// ) -> Result<Vec<f64>, KernelError> {
// let diff = params
// .par_iter()
// .zip(x.par_iter())
// .zip(xprime.par_iter())
// .map(|((_relevance, xi), xprimei)| -(xi - xprimei).powi(2))
// .collect::<Vec<f64>>();
// Ok(diff)
// }
// }
// #[cfg(test)]
// mod tests {
// use crate::*;
// #[test]
// fn it_works() {
// let kernel = ARD(3);
// let test_value = kernel
// .value(&[1.0, 0.0, 0.0], &vec![1.0, 2.0, 3.0], &vec![0.0, 2.0, 1.0])
// .unwrap();
// assert_eq!(test_value, (-1f64).exp());
// }
// }