opensrdk_kernel_method/
spectral_mixture.rs

1use super::PositiveDefiniteKernel;
2use crate::{KernelAdd, KernelError, KernelMul};
3use opensrdk_symbolic_computation::Expression;
4use rayon::prelude::*;
5use std::{f64::consts::PI, ops::Add, ops::Mul};
6
/// Spectral mixture kernel.
///
/// Reference: <http://www.cs.cmu.edu/~andrewgw/andrewgwthesis.pdf>
#[derive(Debug, Clone)]
pub struct SpectralMixture {
    /// First size parameter. Presumably the input dimension: the disabled reference
    /// code in `expression` compares it against `x.len()` — TODO confirm.
    p: usize,
    /// Second size parameter. Presumably the number of mixture components: the
    /// disabled reference code slices `q` leading parameters as `w` — TODO confirm.
    q: usize,
}

impl SpectralMixture {
    /// Creates a kernel from the two size parameters, stored as given without validation.
    pub fn new(p: usize, q: usize) -> Self {
        SpectralMixture { p, q }
    }
}
19
20impl PositiveDefiniteKernel for SpectralMixture {
21    fn params_len(&self) -> usize {
22        self.q + self.p * self.q + self.p * self.q
23    }
24
25    fn expression(
26        &self,
27        x: Expression,
28        x_prime: Expression,
29        params: &[Expression],
30    ) -> Result<Expression, KernelError> {
31        todo!()
32        // if params.len() != self.params_len() {
33        //     return Err(KernelError::ParametersLengthMismatch.into());
34        // }
35        // if self.p != x.len() {
36        //     return Err(KernelError::ParametersLengthMismatch.into());
37        // }
38        // if x.len() != x_prime.len() {
39        //     return Err(KernelError::InvalidArgument.into());
40        // }
41
42        // let w = &params[0..self.q];
43        // let v = &params[self.q..self.q + self.p * self.q];
44        // let mu = &params[self.q + self.p * self.q..self.q + self.p * self.q + self.p * self.q];
45
46        // let fx = (0..self.q)
47        //     .into_par_iter()
48        //     .map(|q| {
49        //         w[q] * (0..self.p)
50        //             .into_par_iter()
51        //             .map(|p| {
52        //                 (-2.0 * PI.powi(2) * (x[p] - x_prime[p]).powi(2) * v[self.p * p + q]).exp()
53        //                     * (2.0 * PI * (x[p] - x_prime[p]) * mu[self.p * p + q]).cos()
54        //             })
55        //             .product::<f64>()
56        //     })
57        // //     .sum();
58
59        // Ok(fx)
60    }
61}
62
63impl<R> Add<R> for SpectralMixture
64where
65    R: PositiveDefiniteKernel,
66{
67    type Output = KernelAdd<Self, R>;
68
69    fn add(self, rhs: R) -> Self::Output {
70        Self::Output::new(self, rhs)
71    }
72}
73
74impl<R> Mul<R> for SpectralMixture
75where
76    R: PositiveDefiniteKernel,
77{
78    type Output = KernelMul<Self, R>;
79
80    fn mul(self, rhs: R) -> Self::Output {
81        Self::Output::new(self, rhs)
82    }
83}
84
85// impl ValueDifferentiableKernel<Vec<f64>> for SpectralMixture {
86//     fn ln_diff_value(
87//         &self,
88//         params: &[f64],
89//         x: &Vec<f64>,
90//         xprime: &Vec<f64>,
91//     ) -> Result<Vec<f64>, KernelError> {
92//         let value = self.value(params, x, xprime).unwrap();
93//         let w = &params[0..self.q];
94//         let v = &params[self.q..self.q + self.p * self.q];
95//         let mu = &params[self.q + self.p * self.q..self.q + self.p * self.q + self.p * self.q];
96
97//         let diff = (0..self.p)
98//             .into_par_iter()
99//             .map(|p| {
100//                 (0..self.q)
101//                     .into_par_iter()
102//                     .map(|q| {
103//                         let each_wd = w[q]
104//                             * (0..self.p)
105//                                 .into_par_iter()
106//                                 .map(|i| {
107//                                     (-2.0
108//                                         * PI.powi(2)
109//                                         * (x[i] - xprime[i]).powi(2)
110//                                         * v[self.p * i + q])
111//                                         .exp()
112//                                         * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
113//                                 })
114//                                 .product::<f64>();
115//                         let diff_d = (4.0 * PI.powi(2) * (x[p] - xprime[p]) * v[self.p * p + q])
116//                             + (2.0 * PI * (x[p] - xprime[p]) * mu[self.p * p + q]).tan()
117//                                 * ((-2.0) * PI * (x[p] - xprime[p]) * mu[self.p * p + q]);
118//                         diff_d * each_wd / value
119//                     })
120//                     .sum()
121//             })
122//             .collect::<Vec<f64>>();
123
124//         Ok(diff)
125//     }
126// }
127
128// impl ParamsDifferentiableKernel<Vec<f64>> for SpectralMixture {
129//     fn ln_diff_params(
130//         &self,
131//         params: &[f64],
132//         x: &Vec<f64>,
133//         xprime: &Vec<f64>,
134//     ) -> Result<Vec<f64>, KernelError> {
135//         let value = self.value(params, x, xprime).unwrap();
136//         let w = &params[0..self.q];
137//         let v = &params[self.q..self.q + self.p * self.q];
138//         let mu = &params[self.q + self.p * self.q..self.q + self.p * self.q + self.p * self.q];
139
140//         let diff_w = (0..self.q)
141//             .into_par_iter()
142//             .map(|q| {
143//                 (0..self.p)
144//                     .into_par_iter()
145//                     .map(|i| {
146//                         (-2.0 * PI.powi(2) * (x[i] - xprime[i]).powi(2) * v[self.p * i + q]).exp()
147//                             * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
148//                     })
149//                     .product::<f64>()
150//             })
151//             .collect::<Vec<f64>>();
152
153//         let diff_mu = (0..self.q)
154//             .into_par_iter()
155//             .map(|q| {
156//                 let each_wd = w[q]
157//                     * (0..self.p)
158//                         .into_par_iter()
159//                         .map(|i| {
160//                             (-2.0 * PI.powi(2) * (x[i] - xprime[i]).powi(2) * v[self.p * i + q])
161//                                 .exp()
162//                                 * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
163//                         })
164//                         .product::<f64>();
165//                 (0..self.p)
166//                     .into_par_iter()
167//                     .map(|p| {
168//                         let diff_d = (2.0 * PI * (x[p] - xprime[p]) * mu[self.p * p + q]).tan()
169//                             * ((-2.0) * PI * (x[p] - xprime[p]));
170//                         diff_d * each_wd / value
171//                     })
172//                     .collect::<Vec<f64>>()
173//             })
174//             .collect::<Vec<Vec<f64>>>()
175//             .concat();
176
177//         let diff_v = (0..self.q)
178//             .into_par_iter()
179//             .map(|q| {
180//                 let each_wd = w[q]
181//                     * (0..self.p)
182//                         .into_par_iter()
183//                         .map(|i| {
184//                             (-2.0 * PI.powi(2) * (x[i] - xprime[i]).powi(2) * v[self.p * i + q])
185//                                 .exp()
186//                                 * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
187//                         })
188//                         .product::<f64>();
189//                 (0..self.p)
190//                     .into_par_iter()
191//                     .map(|p| {
192//                         let diff_d = 4.0 * PI.powi(2) * (x[p] - xprime[p]).powi(2);
193//                         diff_d * each_wd / value
194//                     })
195//                     .collect::<Vec<f64>>()
196//             })
197//             .collect::<Vec<Vec<f64>>>()
198//             .concat();
199
200//         let diff = [diff_w, diff_v, diff_mu].concat();
201
202//         Ok(diff)
203//     }
// }
205
206// #[cfg(test)]
207// mod tests {
208//     use crate::*;
209//     #[test]
210//     fn it_works() {
211//         let kernel = SpectralMixture::new(1, 2);
212
213//         let test_value = kernel.expression(
214//             Expression::from(vec![0.0, 0.0, 0.0]),
215//             Expression::from(vec![0.0, 0.0, 0.0]),
216//             &[Expression::from([1.0]), Expression::from([1.0])],
217//         );
218
219//         match test_value {
220//             Err(KernelError::ParametersLengthMismatch) => (),
221//             _ => panic!(),
222//         };
223//     }
224// }