opensrdk_kernel_method/spectral_mixture.rs
1use super::PositiveDefiniteKernel;
2use crate::{KernelAdd, KernelError, KernelMul};
3use opensrdk_symbolic_computation::Expression;
4use rayon::prelude::*;
5use std::{f64::consts::PI, ops::Add, ops::Mul};
6
/// Spectral mixture kernel over `p`-dimensional inputs with `q` mixture components.
///
/// See <http://www.cs.cmu.edu/~andrewgw/andrewgwthesis.pdf>.
#[derive(Clone, Debug)]
pub struct SpectralMixture {
    /// Dimensionality of the input vectors.
    p: usize,
    /// Number of mixture components.
    q: usize,
}

impl SpectralMixture {
    /// Creates a spectral mixture kernel for `p`-dimensional inputs with `q` components.
    pub fn new(p: usize, q: usize) -> Self {
        SpectralMixture { q, p }
    }
}
19
20impl PositiveDefiniteKernel for SpectralMixture {
21 fn params_len(&self) -> usize {
22 self.q + self.p * self.q + self.p * self.q
23 }
24
25 fn expression(
26 &self,
27 x: Expression,
28 x_prime: Expression,
29 params: &[Expression],
30 ) -> Result<Expression, KernelError> {
31 todo!()
32 // if params.len() != self.params_len() {
33 // return Err(KernelError::ParametersLengthMismatch.into());
34 // }
35 // if self.p != x.len() {
36 // return Err(KernelError::ParametersLengthMismatch.into());
37 // }
38 // if x.len() != x_prime.len() {
39 // return Err(KernelError::InvalidArgument.into());
40 // }
41
42 // let w = ¶ms[0..self.q];
43 // let v = ¶ms[self.q..self.q + self.p * self.q];
44 // let mu = ¶ms[self.q + self.p * self.q..self.q + self.p * self.q + self.p * self.q];
45
46 // let fx = (0..self.q)
47 // .into_par_iter()
48 // .map(|q| {
49 // w[q] * (0..self.p)
50 // .into_par_iter()
51 // .map(|p| {
52 // (-2.0 * PI.powi(2) * (x[p] - x_prime[p]).powi(2) * v[self.p * p + q]).exp()
53 // * (2.0 * PI * (x[p] - x_prime[p]) * mu[self.p * p + q]).cos()
54 // })
55 // .product::<f64>()
56 // })
57 // // .sum();
58
59 // Ok(fx)
60 }
61}
62
63impl<R> Add<R> for SpectralMixture
64where
65 R: PositiveDefiniteKernel,
66{
67 type Output = KernelAdd<Self, R>;
68
69 fn add(self, rhs: R) -> Self::Output {
70 Self::Output::new(self, rhs)
71 }
72}
73
74impl<R> Mul<R> for SpectralMixture
75where
76 R: PositiveDefiniteKernel,
77{
78 type Output = KernelMul<Self, R>;
79
80 fn mul(self, rhs: R) -> Self::Output {
81 Self::Output::new(self, rhs)
82 }
83}
84
85// impl ValueDifferentiableKernel<Vec<f64>> for SpectralMixture {
86// fn ln_diff_value(
87// &self,
88// params: &[f64],
89// x: &Vec<f64>,
90// xprime: &Vec<f64>,
91// ) -> Result<Vec<f64>, KernelError> {
92// let value = self.value(params, x, xprime).unwrap();
// let w = &params[0..self.q];
// let v = &params[self.q..self.q + self.p * self.q];
// let mu = &params[self.q + self.p * self.q..self.q + self.p * self.q + self.p * self.q];
96
97// let diff = (0..self.p)
98// .into_par_iter()
99// .map(|p| {
100// (0..self.q)
101// .into_par_iter()
102// .map(|q| {
103// let each_wd = w[q]
104// * (0..self.p)
105// .into_par_iter()
106// .map(|i| {
107// (-2.0
108// * PI.powi(2)
109// * (x[i] - xprime[i]).powi(2)
110// * v[self.p * i + q])
111// .exp()
112// * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
113// })
114// .product::<f64>();
115// let diff_d = (4.0 * PI.powi(2) * (x[p] - xprime[p]) * v[self.p * p + q])
116// + (2.0 * PI * (x[p] - xprime[p]) * mu[self.p * p + q]).tan()
117// * ((-2.0) * PI * (x[p] - xprime[p]) * mu[self.p * p + q]);
118// diff_d * each_wd / value
119// })
120// .sum()
121// })
122// .collect::<Vec<f64>>();
123
124// Ok(diff)
125// }
126// }
127
128// impl ParamsDifferentiableKernel<Vec<f64>> for SpectralMixture {
129// fn ln_diff_params(
130// &self,
131// params: &[f64],
132// x: &Vec<f64>,
133// xprime: &Vec<f64>,
134// ) -> Result<Vec<f64>, KernelError> {
135// let value = self.value(params, x, xprime).unwrap();
// let w = &params[0..self.q];
// let v = &params[self.q..self.q + self.p * self.q];
// let mu = &params[self.q + self.p * self.q..self.q + self.p * self.q + self.p * self.q];
139
140// let diff_w = (0..self.q)
141// .into_par_iter()
142// .map(|q| {
143// (0..self.p)
144// .into_par_iter()
145// .map(|i| {
146// (-2.0 * PI.powi(2) * (x[i] - xprime[i]).powi(2) * v[self.p * i + q]).exp()
147// * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
148// })
149// .product::<f64>()
150// })
151// .collect::<Vec<f64>>();
152
153// let diff_mu = (0..self.q)
154// .into_par_iter()
155// .map(|q| {
156// let each_wd = w[q]
157// * (0..self.p)
158// .into_par_iter()
159// .map(|i| {
160// (-2.0 * PI.powi(2) * (x[i] - xprime[i]).powi(2) * v[self.p * i + q])
161// .exp()
162// * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
163// })
164// .product::<f64>();
165// (0..self.p)
166// .into_par_iter()
167// .map(|p| {
168// let diff_d = (2.0 * PI * (x[p] - xprime[p]) * mu[self.p * p + q]).tan()
169// * ((-2.0) * PI * (x[p] - xprime[p]));
170// diff_d * each_wd / value
171// })
172// .collect::<Vec<f64>>()
173// })
174// .collect::<Vec<Vec<f64>>>()
175// .concat();
176
177// let diff_v = (0..self.q)
178// .into_par_iter()
179// .map(|q| {
180// let each_wd = w[q]
181// * (0..self.p)
182// .into_par_iter()
183// .map(|i| {
184// (-2.0 * PI.powi(2) * (x[i] - xprime[i]).powi(2) * v[self.p * i + q])
185// .exp()
186// * (2.0 * PI * (x[i] - xprime[i]) * mu[self.p * i + q]).cos()
187// })
188// .product::<f64>();
189// (0..self.p)
190// .into_par_iter()
191// .map(|p| {
192// let diff_d = 4.0 * PI.powi(2) * (x[p] - xprime[p]).powi(2);
193// diff_d * each_wd / value
194// })
195// .collect::<Vec<f64>>()
196// })
197// .collect::<Vec<Vec<f64>>>()
198// .concat();
199
200// let diff = [diff_w, diff_v, diff_mu].concat();
201
202// Ok(diff)
203// }
// }
205
206// #[cfg(test)]
207// mod tests {
208// use crate::*;
209// #[test]
210// fn it_works() {
211// let kernel = SpectralMixture::new(1, 2);
212
213// let test_value = kernel.expression(
214// Expression::from(vec![0.0, 0.0, 0.0]),
215// Expression::from(vec![0.0, 0.0, 0.0]),
216// &[Expression::from([1.0]), Expression::from([1.0])],
217// );
218
219// match test_value {
220// Err(KernelError::ParametersLengthMismatch) => (),
221// _ => panic!(),
222// };
223// }
224// }