// opensrdk_kernel_method/periodic.rs
1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
/// Number of parameter expressions the `Periodic` kernel requires:
/// `expression` rejects any `params` slice whose length differs from this
/// value and then reads `params[0]` and `params[1]`.
const PARAMS_LEN: usize = 2;
9
/// Marker type for the periodic kernel.
///
/// The type carries no state; the two kernel parameters are supplied to
/// `PositiveDefiniteKernel::expression` at call time, which builds
/// `exp(params[0] * cos(sqrt(d · d) / params[1]))` with `d = x - x_prime`.
#[derive(Debug, Clone)]
pub struct Periodic;
12
13impl PositiveDefiniteKernel for Periodic {
14 fn expression(
15 &self,
16 x: Expression,
17 x_prime: Expression,
18 params: &[Expression],
19 ) -> Result<Expression, KernelError> {
20 if params.len() != PARAMS_LEN {
21 return Err(KernelError::ParametersLengthMismatch.into());
22 }
23 // if x.len() != x_prime.len() {
24 // return Err(KernelError::InvalidArgument.into());
25 // }
26 let diff = x - x_prime;
27
28 Ok((params[0].clone()
29 * (diff
30 .clone()
31 .dot(diff, &[[0, 0]])
32 .pow(Expression::from(1.0 / 2.0))
33 / params[1].clone())
34 .cos())
35 .exp())
36 }
37
38 fn params_len(&self) -> usize {
39 2
40 }
41}
42
43impl<R> Add<R> for Periodic
44where
45 R: PositiveDefiniteKernel,
46{
47 type Output = KernelAdd<Self, R>;
48
49 fn add(self, rhs: R) -> Self::Output {
50 KernelAdd::new(self, rhs)
51 }
52}
53
54impl<R> Mul<R> for Periodic
55where
56 R: PositiveDefiniteKernel,
57{
58 type Output = KernelMul<Self, R>;
59
60 fn mul(self, rhs: R) -> Self::Output {
61 KernelMul::new(self, rhs)
62 }
63}
64
65// use super::PositiveDefiniteKernel;
66// use crate::{
67// KernelAdd, KernelError, KernelMul, ParamsDifferentiableKernel, ValueDifferentiableKernel,
68// };
69// use opensrdk_linear_algebra::Vector;
70// use rayon::prelude::*;
71// use std::{ops::Add, ops::Mul};
72
73// const PARAMS_LEN: usize = 2;
74
75// #[derive(Clone, Debug)]
76// pub struct Periodic;
77
78// impl Periodic {
79// fn norm(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
80// if params.len() != PARAMS_LEN {
81// return Err(KernelError::ParametersLengthMismatch.into());
82// }
83// if x.len() != xprime.len() {
84// return Err(KernelError::InvalidArgument.into());
85// }
86
87// let v = x
88// .par_iter()
89// .zip(xprime.par_iter())
90// .map(|(x_i, xprime_i)| (x_i - xprime_i).powi(2))
91// .sum::<f64>()
92// .sqrt();
93
94// Ok(v)
95// }
96// }
97
98// impl PositiveDefiniteKernel<Vec<f64>> for Periodic {
99// fn params_len(&self) -> usize {
100// PARAMS_LEN
101// }
102
103// fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
104// let norm = self.norm(params, x, xprime)?;
105
106// let fx = (params[0] * (norm / params[1]).cos()).exp();
107
108// Ok(fx)
109// }
110// }
111
112// impl<R> Add<R> for Periodic
113// where
114// R: PositiveDefiniteKernel<Vec<f64>>,
115// {
116// type Output = KernelAdd<Self, R, Vec<f64>>;
117
118// fn add(self, rhs: R) -> Self::Output {
119// Self::Output::new(self, rhs)
120// }
121// }
122
123// impl<R> Mul<R> for Periodic
124// where
125// R: PositiveDefiniteKernel<Vec<f64>>,
126// {
127// type Output = KernelMul<Self, R, Vec<f64>>;
128
129// fn mul(self, rhs: R) -> Self::Output {
130// Self::Output::new(self, rhs)
131// }
132// }
133
134// impl ValueDifferentiableKernel<Vec<f64>> for Periodic {
135// fn ln_diff_value(
136// &self,
137// params: &[f64],
138// x: &Vec<f64>,
139// xprime: &Vec<f64>,
140// ) -> Result<Vec<f64>, KernelError> {
141// let value = &self.value(params, x, xprime)?;
142// let diff = (-value.sin() * 2.0 / params[1]
143// * (x.clone().col_mat() - xprime.clone().col_mat()))
144// .vec();
145// Ok(diff)
146// }
147// }
148
149// impl ParamsDifferentiableKernel<Vec<f64>> for Periodic {
150// fn ln_diff_params(
151// &self,
152// params: &[f64],
153// x: &Vec<f64>,
154// xprime: &Vec<f64>,
155// ) -> Result<Vec<f64>, KernelError> {
156// let value = &self.value(params, x, xprime)?;
157// let diff0 = 1.0 / params[0];
158// let diff1 = value.sin() * 2.0 * params[1].powi(-2) * &self.norm(params, x, xprime)?;
159// let diff = vec![diff0, diff1];
160// Ok(diff)
161// }
162// }
163
164// #[cfg(test)]
165// mod tests {
166// use crate::*;
167// #[test]
168// fn it_works() {
169// let kernel = Periodic;
170
171// let test_value = kernel.value(&[1.0], &vec![0.0, 0.0, 0.0], &vec![0.0, 0.0, 0.0]);
172
173// match test_value {
174// Err(KernelError::ParametersLengthMismatch) => (),
175// _ => panic!(),
176// };
177// }
178
179// #[test]
180// fn it_works2() {
181// let kernel = Periodic;
182
183// let test_value = kernel
184// .value(&[1.0, 1.0], &vec![0.0, 0.0, 0.0], &vec![0.0, 0.0, 0.0])
185// .unwrap();
186
187// assert_eq!(test_value, 1f64.exp());
188// }
189// }