opensrdk_kernel_method/
periodic.rs

1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
/// Number of hyperparameters expected by [`Periodic`] (`params[0]` and `params[1]`).
const PARAMS_LEN: usize = 2;

/// Periodic kernel with the symbolic form `exp(params[0] * cos(‖x − x'‖ / params[1]))`.
///
/// The type carries no state; all hyperparameters are supplied per call.
/// That is why it can cheaply be `Copy`, `Default` and compared for equality.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Periodic;
12
13impl PositiveDefiniteKernel for Periodic {
14    fn expression(
15        &self,
16        x: Expression,
17        x_prime: Expression,
18        params: &[Expression],
19    ) -> Result<Expression, KernelError> {
20        if params.len() != PARAMS_LEN {
21            return Err(KernelError::ParametersLengthMismatch.into());
22        }
23        // if x.len() != x_prime.len() {
24        //     return Err(KernelError::InvalidArgument.into());
25        // }
26        let diff = x - x_prime;
27
28        Ok((params[0].clone()
29            * (diff
30                .clone()
31                .dot(diff, &[[0, 0]])
32                .pow(Expression::from(1.0 / 2.0))
33                / params[1].clone())
34            .cos())
35        .exp())
36    }
37
38    fn params_len(&self) -> usize {
39        2
40    }
41}
42
43impl<R> Add<R> for Periodic
44where
45    R: PositiveDefiniteKernel,
46{
47    type Output = KernelAdd<Self, R>;
48
49    fn add(self, rhs: R) -> Self::Output {
50        KernelAdd::new(self, rhs)
51    }
52}
53
54impl<R> Mul<R> for Periodic
55where
56    R: PositiveDefiniteKernel,
57{
58    type Output = KernelMul<Self, R>;
59
60    fn mul(self, rhs: R) -> Self::Output {
61        KernelMul::new(self, rhs)
62    }
63}
64
65// use super::PositiveDefiniteKernel;
66// use crate::{
67//     KernelAdd, KernelError, KernelMul, ParamsDifferentiableKernel, ValueDifferentiableKernel,
68// };
69// use opensrdk_linear_algebra::Vector;
70// use rayon::prelude::*;
71// use std::{ops::Add, ops::Mul};
72
73// const PARAMS_LEN: usize = 2;
74
75// #[derive(Clone, Debug)]
76// pub struct Periodic;
77
78// impl Periodic {
79//     fn norm(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
80//         if params.len() != PARAMS_LEN {
81//             return Err(KernelError::ParametersLengthMismatch.into());
82//         }
83//         if x.len() != xprime.len() {
84//             return Err(KernelError::InvalidArgument.into());
85//         }
86
87//         let v = x
88//             .par_iter()
89//             .zip(xprime.par_iter())
90//             .map(|(x_i, xprime_i)| (x_i - xprime_i).powi(2))
91//             .sum::<f64>()
92//             .sqrt();
93
94//         Ok(v)
95//     }
96// }
97
98// impl PositiveDefiniteKernel<Vec<f64>> for Periodic {
99//     fn params_len(&self) -> usize {
100//         PARAMS_LEN
101//     }
102
103//     fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
104//         let norm = self.norm(params, x, xprime)?;
105
106//         let fx = (params[0] * (norm / params[1]).cos()).exp();
107
108//         Ok(fx)
109//     }
110// }
111
112// impl<R> Add<R> for Periodic
113// where
114//     R: PositiveDefiniteKernel<Vec<f64>>,
115// {
116//     type Output = KernelAdd<Self, R, Vec<f64>>;
117
118//     fn add(self, rhs: R) -> Self::Output {
119//         Self::Output::new(self, rhs)
120//     }
121// }
122
123// impl<R> Mul<R> for Periodic
124// where
125//     R: PositiveDefiniteKernel<Vec<f64>>,
126// {
127//     type Output = KernelMul<Self, R, Vec<f64>>;
128
129//     fn mul(self, rhs: R) -> Self::Output {
130//         Self::Output::new(self, rhs)
131//     }
132// }
133
134// impl ValueDifferentiableKernel<Vec<f64>> for Periodic {
135//     fn ln_diff_value(
136//         &self,
137//         params: &[f64],
138//         x: &Vec<f64>,
139//         xprime: &Vec<f64>,
140//     ) -> Result<Vec<f64>, KernelError> {
141//         let value = &self.value(params, x, xprime)?;
142//         let diff = (-value.sin() * 2.0 / params[1]
143//             * (x.clone().col_mat() - xprime.clone().col_mat()))
144//         .vec();
145//         Ok(diff)
146//     }
147// }
148
149// impl ParamsDifferentiableKernel<Vec<f64>> for Periodic {
150//     fn ln_diff_params(
151//         &self,
152//         params: &[f64],
153//         x: &Vec<f64>,
154//         xprime: &Vec<f64>,
155//     ) -> Result<Vec<f64>, KernelError> {
156//         let value = &self.value(params, x, xprime)?;
157//         let diff0 = 1.0 / params[0];
158//         let diff1 = value.sin() * 2.0 * params[1].powi(-2) * &self.norm(params, x, xprime)?;
159//         let diff = vec![diff0, diff1];
160//         Ok(diff)
161//     }
162// }
163
164// #[cfg(test)]
165// mod tests {
166//     use crate::*;
167//     #[test]
168//     fn it_works() {
169//         let kernel = Periodic;
170
171//         let test_value = kernel.value(&[1.0], &vec![0.0, 0.0, 0.0], &vec![0.0, 0.0, 0.0]);
172
173//         match test_value {
174//             Err(KernelError::ParametersLengthMismatch) => (),
175//             _ => panic!(),
176//         };
177//     }
178
179//     #[test]
180//     fn it_works2() {
181//         let kernel = Periodic;
182
183//         let test_value = kernel
184//             .value(&[1.0, 1.0], &vec![0.0, 0.0, 0.0], &vec![0.0, 0.0, 0.0])
185//             .unwrap();
186
187//         assert_eq!(test_value, 1f64.exp());
188//     }
189// }