opensrdk_kernel_method/rbf.rs
1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
/// Number of entries the `params` slice must contain for the [`RBF`] kernel.
///
/// The single entry is the divisor applied inside the exponential.
const PARAMS_LEN: usize = 1_usize;
9
/// RBF kernel.
///
/// A data-less unit type: everything the kernel needs is handed to its trait
/// methods as arguments.
#[derive(Debug, Clone)]
pub struct RBF;
12
13impl PositiveDefiniteKernel for RBF {
14 fn expression(
15 &self,
16 x: Expression,
17 x_prime: Expression,
18 params: &[Expression],
19 ) -> Result<Expression, KernelError> {
20 if params.len() != PARAMS_LEN {
21 return Err(KernelError::ParametersLengthMismatch.into());
22 }
23 // if x.len() != x_prime.len() {
24 // return Err(KernelError::InvalidArgument.into());
25 // }
26
27 let diff = x - x_prime;
28
29 Ok((-diff.clone().dot(diff, &[[0, 0]]) / params[0].clone()).exp())
30 }
31
32 fn params_len(&self) -> usize {
33 1
34 }
35}
36
37impl<R> Add<R> for RBF
38where
39 R: PositiveDefiniteKernel,
40{
41 type Output = KernelAdd<Self, R>;
42
43 fn add(self, rhs: R) -> Self::Output {
44 KernelAdd::new(self, rhs)
45 }
46}
47
48impl<R> Mul<R> for RBF
49where
50 R: PositiveDefiniteKernel,
51{
52 type Output = KernelMul<Self, R>;
53
54 fn mul(self, rhs: R) -> Self::Output {
55 KernelMul::new(self, rhs)
56 }
57}
58
59// use super::PositiveDefiniteKernel;
60// use crate::{
61// KernelAdd, KernelError, KernelMul, ParamsDifferentiableKernel, ValueDifferentiableKernel,
62// };
63// use opensrdk_linear_algebra::Vector;
64// use rayon::prelude::*;
65// use std::{ops::Add, ops::Mul};
66
67// const PARAMS_LEN: usize = 2;
68
69// #[derive(Clone, Debug)]
70// pub struct RBF;
71
72// impl RBF {
73// fn norm_pow(
74// &self,
75// params: &[f64],
76// x: &Vec<f64>,
77// xprime: &Vec<f64>,
78// ) -> Result<f64, KernelError> {
79// if params.len() != PARAMS_LEN {
80// return Err(KernelError::ParametersLengthMismatch.into());
81// }
82// if x.len() != xprime.len() {
83// return Err(KernelError::InvalidArgument.into());
84// }
85
86// let norm_pow = x
87// .par_iter()
88// .zip(xprime.par_iter())
89// .map(|(x_i, xprime_i)| (x_i - xprime_i).powi(2))
90// .sum();
91
92// Ok(norm_pow)
93// }
94// }
95
96// impl PositiveDefiniteKernel<Vec<f64>> for RBF {
97// fn params_len(&self) -> usize {
98// PARAMS_LEN
99// }
100
101// fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
102// let norm_pow = self.norm_pow(params, x, xprime)?;
103
104// let fx = params[0] * (-norm_pow / params[1]).exp();
105
106// Ok(fx)
107// }
108// }
109
110// impl<R> Add<R> for RBF
111// where
112// R: PositiveDefiniteKernel<Vec<f64>>,
113// {
114// type Output = KernelAdd<Self, R, Vec<f64>>;
115
116// fn add(self, rhs: R) -> Self::Output {
117// Self::Output::new(self, rhs)
118// }
119// }
120
121// impl<R> Mul<R> for RBF
122// where
123// R: PositiveDefiniteKernel<Vec<f64>>,
124// {
125// type Output = KernelMul<Self, R, Vec<f64>>;
126
127// fn mul(self, rhs: R) -> Self::Output {
128// Self::Output::new(self, rhs)
129// }
130// }
131
132// impl ValueDifferentiableKernel<Vec<f64>> for RBF {
133// fn ln_diff_value(
134// &self,
135// params: &[f64],
136// x: &Vec<f64>,
137// xprime: &Vec<f64>,
138// ) -> Result<Vec<f64>, KernelError> {
139// let diff = (-2.0 / params[1] * (x.clone().col_mat() - xprime.clone().col_mat())).vec();
140// Ok(diff)
141// }
142// }
143
144// impl ParamsDifferentiableKernel<Vec<f64>> for RBF {
145// fn ln_diff_params(
146// &self,
147// params: &[f64],
148// x: &Vec<f64>,
149// xprime: &Vec<f64>,
150// ) -> Result<Vec<f64>, KernelError> {
151// let diff0 = 1.0 / params[0];
152// let diff1 = 2.0 * params[1].powi(-2) * &self.norm_pow(params, x, xprime).unwrap();
153// let diff = vec![diff0, diff1];
154// Ok(diff)
155// }
156// }
157
158// #[cfg(test)]
159// mod tests {
160// use crate::*;
161// #[test]
162// fn it_works() {
163// let kernel = RBF;
164// let kernel_diff = kernel
165// .expression(theta_array, samples_array, &kernel_params_expression)
166// .unwrap()
167// .ln();
168
169// assert_eq!(test_value, (-1f64).exp());
170// }
171// #[test]
172// fn it_works2() {
173// let kernel = RBF;
174
175// //let (func, grad) = kernel
176// // .value_with_grad(&[1.0, 1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
177// // .unwrap();
178
179// //println!("{}", func);
180// //println!("{:#?}", grad);
181
182// let test_value = kernel
183// .ln_diff_value(&[1.0, 1.0], &vec![1.0, 0.0, 0.0], &vec![0.0, 0.0, 0.0])
184// .unwrap();
185
186// println!("{:?}", test_value);
187// }
188// }