opensrdk_kernel_method/
rbf.rs

1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
8const PARAMS_LEN: usize = 1;
9
/// Radial basis function (RBF) kernel.
///
/// This is a stateless marker type: it carries no data, and the kernel
/// parameters are supplied at call time to the trait methods implemented
/// for it. Because it is zero-sized it is trivially copyable, comparable
/// and default-constructible.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct RBF;
12
13impl PositiveDefiniteKernel for RBF {
14    fn expression(
15        &self,
16        x: Expression,
17        x_prime: Expression,
18        params: &[Expression],
19    ) -> Result<Expression, KernelError> {
20        if params.len() != PARAMS_LEN {
21            return Err(KernelError::ParametersLengthMismatch.into());
22        }
23        // if x.len() != x_prime.len() {
24        //     return Err(KernelError::InvalidArgument.into());
25        // }
26
27        let diff = x - x_prime;
28
29        Ok((-diff.clone().dot(diff, &[[0, 0]]) / params[0].clone()).exp())
30    }
31
32    fn params_len(&self) -> usize {
33        1
34    }
35}
36
37impl<R> Add<R> for RBF
38where
39    R: PositiveDefiniteKernel,
40{
41    type Output = KernelAdd<Self, R>;
42
43    fn add(self, rhs: R) -> Self::Output {
44        KernelAdd::new(self, rhs)
45    }
46}
47
48impl<R> Mul<R> for RBF
49where
50    R: PositiveDefiniteKernel,
51{
52    type Output = KernelMul<Self, R>;
53
54    fn mul(self, rhs: R) -> Self::Output {
55        KernelMul::new(self, rhs)
56    }
57}
58
59// use super::PositiveDefiniteKernel;
60// use crate::{
61//     KernelAdd, KernelError, KernelMul, ParamsDifferentiableKernel, ValueDifferentiableKernel,
62// };
63// use opensrdk_linear_algebra::Vector;
64// use rayon::prelude::*;
65// use std::{ops::Add, ops::Mul};
66
67// const PARAMS_LEN: usize = 2;
68
69// #[derive(Clone, Debug)]
70// pub struct RBF;
71
72// impl RBF {
73//     fn norm_pow(
74//         &self,
75//         params: &[f64],
76//         x: &Vec<f64>,
77//         xprime: &Vec<f64>,
78//     ) -> Result<f64, KernelError> {
79//         if params.len() != PARAMS_LEN {
80//             return Err(KernelError::ParametersLengthMismatch.into());
81//         }
82//         if x.len() != xprime.len() {
83//             return Err(KernelError::InvalidArgument.into());
84//         }
85
86//         let norm_pow = x
87//             .par_iter()
88//             .zip(xprime.par_iter())
89//             .map(|(x_i, xprime_i)| (x_i - xprime_i).powi(2))
90//             .sum();
91
92//         Ok(norm_pow)
93//     }
94// }
95
96// impl PositiveDefiniteKernel<Vec<f64>> for RBF {
97//     fn params_len(&self) -> usize {
98//         PARAMS_LEN
99//     }
100
101//     fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
102//         let norm_pow = self.norm_pow(params, x, xprime)?;
103
104//         let fx = params[0] * (-norm_pow / params[1]).exp();
105
106//         Ok(fx)
107//     }
108// }
109
110// impl<R> Add<R> for RBF
111// where
112//     R: PositiveDefiniteKernel<Vec<f64>>,
113// {
114//     type Output = KernelAdd<Self, R, Vec<f64>>;
115
116//     fn add(self, rhs: R) -> Self::Output {
117//         Self::Output::new(self, rhs)
118//     }
119// }
120
121// impl<R> Mul<R> for RBF
122// where
123//     R: PositiveDefiniteKernel<Vec<f64>>,
124// {
125//     type Output = KernelMul<Self, R, Vec<f64>>;
126
127//     fn mul(self, rhs: R) -> Self::Output {
128//         Self::Output::new(self, rhs)
129//     }
130// }
131
132// impl ValueDifferentiableKernel<Vec<f64>> for RBF {
133//     fn ln_diff_value(
134//         &self,
135//         params: &[f64],
136//         x: &Vec<f64>,
137//         xprime: &Vec<f64>,
138//     ) -> Result<Vec<f64>, KernelError> {
139//         let diff = (-2.0 / params[1] * (x.clone().col_mat() - xprime.clone().col_mat())).vec();
140//         Ok(diff)
141//     }
142// }
143
144// impl ParamsDifferentiableKernel<Vec<f64>> for RBF {
145//     fn ln_diff_params(
146//         &self,
147//         params: &[f64],
148//         x: &Vec<f64>,
149//         xprime: &Vec<f64>,
150//     ) -> Result<Vec<f64>, KernelError> {
151//         let diff0 = 1.0 / params[0];
152//         let diff1 = 2.0 * params[1].powi(-2) * &self.norm_pow(params, x, xprime).unwrap();
153//         let diff = vec![diff0, diff1];
154//         Ok(diff)
155//     }
156// }
157
158// #[cfg(test)]
159// mod tests {
160//     use crate::*;
161//     #[test]
162//     fn it_works() {
163//         let kernel = RBF;
164//         let kernel_diff = kernel
165//             .expression(theta_array, samples_array, &kernel_params_expression)
166//             .unwrap()
167//             .ln();
168
169//         assert_eq!(test_value, (-1f64).exp());
170//     }
171//     #[test]
172//     fn it_works2() {
173//         let kernel = RBF;
174
175//         //let (func, grad) = kernel
176//         //    .value_with_grad(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
177//         //    .unwrap();
178
179//         //println!("{}", func);
180//         //println!("{:#?}", grad);
181
182//         let test_value = kernel
183//             .ln_diff_value(&[1.0, 1.0], &vec![1.0, 0.0, 0.0], &vec![0.0, 0.0, 0.0])
184//             .unwrap();
185
186//         println!("{:?}", test_value);
187//     }
188// }