opensrdk_kernel_method/
constant.rs

1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
/// Number of hyperparameters taken by the [`Constant`] kernel: the constant value itself.
const PARAMS_LEN: usize = 1;

/// Constant kernel `k(x, x') = c`, where `c` is the kernel's single hyperparameter.
///
/// The struct carries no state, so it is cheap to copy. Deriving `Copy` lets one
/// `Constant` value be reused in several kernel compositions (`+` / `*` consume
/// their operands by value) without explicit `.clone()` calls.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct Constant;
12
13impl PositiveDefiniteKernel for Constant {
14    fn expression(
15        &self,
16        x: Expression,
17        x_prime: Expression,
18        params: &[Expression],
19    ) -> Result<Expression, KernelError> {
20        if params.len() != PARAMS_LEN {
21            return Err(KernelError::ParametersLengthMismatch.into());
22        }
23        // if x.len() != x_prime.len() {
24        //     return Err(KernelError::InvalidArgument.into());
25        // }
26        Ok(params[0].clone())
27    }
28
29    fn params_len(&self) -> usize {
30        1
31    }
32}
33
34impl<R> Add<R> for Constant
35where
36    R: PositiveDefiniteKernel,
37{
38    type Output = KernelAdd<Self, R>;
39
40    fn add(self, rhs: R) -> Self::Output {
41        KernelAdd::new(self, rhs)
42    }
43}
44
45impl<R> Mul<R> for Constant
46where
47    R: PositiveDefiniteKernel,
48{
49    type Output = KernelMul<Self, R>;
50
51    fn mul(self, rhs: R) -> Self::Output {
52        KernelMul::new(self, rhs)
53    }
54}
55
56// use super::PositiveDefiniteKernel;
57// use crate::{KernelAdd, KernelError, KernelMul};
58// use crate::{ParamsDifferentiableKernel, Value, ValueDifferentiableKernel};
59// use std::fmt::Debug;
60// use std::{ops::Add, ops::Mul};
61
62// const PARAMS_LEN: usize = 1;
63
64// #[derive(Clone, Debug)]
65// pub struct Constant;
66
67// impl<T> PositiveDefiniteKernel<T> for Constant
68// where
69//     T: Value,
70// {
71//     fn params_len(&self) -> usize {
72//         PARAMS_LEN
73//     }
74
75//     fn value(&self, params: &[f64], _: &T, _: &T) -> Result<f64, KernelError> {
76//         if params.len() != PARAMS_LEN {
77//             return Err(KernelError::ParametersLengthMismatch.into());
78//         }
79
80//         let fx = params[0];
81
82//         Ok(fx)
83//     }
84// }
85
86// impl<R> Add<R> for Constant
87// where
88//     R: PositiveDefiniteKernel<Vec<f64>>,
89// {
90//     type Output = KernelAdd<Self, R, Vec<f64>>;
91
92//     fn add(self, rhs: R) -> Self::Output {
93//         Self::Output::new(self, rhs)
94//     }
95// }
96
97// impl<R> Mul<R> for Constant
98// where
99//     R: PositiveDefiniteKernel<Vec<f64>>,
100// {
101//     type Output = KernelMul<Self, R, Vec<f64>>;
102
103//     fn mul(self, rhs: R) -> Self::Output {
104//         Self::Output::new(self, rhs)
105//     }
106// }
107
108// impl ValueDifferentiableKernel<Vec<f64>> for Constant {
109//     fn ln_diff_value(
110//         &self,
111//         _params: &[f64],
112//         x: &Vec<f64>,
113//         _xprime: &Vec<f64>,
114//     ) -> Result<Vec<f64>, KernelError> {
115//         let diff = vec![0.0; x.len()];
116//         Ok(diff)
117//     }
118// }
119
120// impl ParamsDifferentiableKernel<Vec<f64>> for Constant {
121//     fn ln_diff_params(
122//         &self,
123//         _params: &[f64],
124//         _x: &Vec<f64>,
125//         _xprime: &Vec<f64>,
126//     ) -> Result<Vec<f64>, KernelError> {
127//         let diff = vec![1.0];
128//         Ok(diff)
129//     }
130// }
131
132// #[cfg(test)]
133// mod tests {
134//     use crate::*;
135//     #[test]
136//     fn it_works() {
137//         let kernel = Constant;
138
139//         let test_value = kernel
140//             .value(&[1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
141//             .unwrap();
142
143//         assert_eq!(test_value, 1.0);
144//     }
145// }