opensrdk_kernel_method/constant.rs
1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
/// Number of parameters the [`Constant`] kernel expects: the constant value itself.
const PARAMS_LEN: usize = 1;
9
/// Constant kernel.
///
/// A stateless marker type: the kernel carries no data of its own. Its value
/// is supplied entirely through the single parameter handed to the kernel at
/// evaluation time.
#[derive(Debug, Clone)]
pub struct Constant;
12
13impl PositiveDefiniteKernel for Constant {
14 fn expression(
15 &self,
16 x: Expression,
17 x_prime: Expression,
18 params: &[Expression],
19 ) -> Result<Expression, KernelError> {
20 if params.len() != PARAMS_LEN {
21 return Err(KernelError::ParametersLengthMismatch.into());
22 }
23 // if x.len() != x_prime.len() {
24 // return Err(KernelError::InvalidArgument.into());
25 // }
26 Ok(params[0].clone())
27 }
28
29 fn params_len(&self) -> usize {
30 1
31 }
32}
33
34impl<R> Add<R> for Constant
35where
36 R: PositiveDefiniteKernel,
37{
38 type Output = KernelAdd<Self, R>;
39
40 fn add(self, rhs: R) -> Self::Output {
41 KernelAdd::new(self, rhs)
42 }
43}
44
45impl<R> Mul<R> for Constant
46where
47 R: PositiveDefiniteKernel,
48{
49 type Output = KernelMul<Self, R>;
50
51 fn mul(self, rhs: R) -> Self::Output {
52 KernelMul::new(self, rhs)
53 }
54}
55
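// NOTE(review): everything below is commented-out code. It appears to be an
// earlier `f64`-valued implementation of this kernel (including derivative
// impls and a unit test) kept for reference; confirm whether it is still
// needed before removing it.
// use super::PositiveDefiniteKernel;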
57// use crate::{KernelAdd, KernelError, KernelMul};
58// use crate::{ParamsDifferentiableKernel, Value, ValueDifferentiableKernel};
59// use std::fmt::Debug;
60// use std::{ops::Add, ops::Mul};
61
62// const PARAMS_LEN: usize = 1;
63
64// #[derive(Clone, Debug)]
65// pub struct Constant;
66
67// impl<T> PositiveDefiniteKernel<T> for Constant
68// where
69// T: Value,
70// {
71// fn params_len(&self) -> usize {
72// PARAMS_LEN
73// }
74
75// fn value(&self, params: &[f64], _: &T, _: &T) -> Result<f64, KernelError> {
76// if params.len() != PARAMS_LEN {
77// return Err(KernelError::ParametersLengthMismatch.into());
78// }
79
80// let fx = params[0];
81
82// Ok(fx)
83// }
84// }
85
86// impl<R> Add<R> for Constant
87// where
88// R: PositiveDefiniteKernel<Vec<f64>>,
89// {
90// type Output = KernelAdd<Self, R, Vec<f64>>;
91
92// fn add(self, rhs: R) -> Self::Output {
93// Self::Output::new(self, rhs)
94// }
95// }
96
97// impl<R> Mul<R> for Constant
98// where
99// R: PositiveDefiniteKernel<Vec<f64>>,
100// {
101// type Output = KernelMul<Self, R, Vec<f64>>;
102
103// fn mul(self, rhs: R) -> Self::Output {
104// Self::Output::new(self, rhs)
105// }
106// }
107
108// impl ValueDifferentiableKernel<Vec<f64>> for Constant {
109// fn ln_diff_value(
110// &self,
111// _params: &[f64],
112// x: &Vec<f64>,
113// _xprime: &Vec<f64>,
114// ) -> Result<Vec<f64>, KernelError> {
115// let diff = vec![0.0; x.len()];
116// Ok(diff)
117// }
118// }
119
120// impl ParamsDifferentiableKernel<Vec<f64>> for Constant {
121// fn ln_diff_params(
122// &self,
123// _params: &[f64],
124// _x: &Vec<f64>,
125// _xprime: &Vec<f64>,
126// ) -> Result<Vec<f64>, KernelError> {
127// let diff = vec![1.0];
128// Ok(diff)
129// }
130// }
131
132// #[cfg(test)]
133// mod tests {
134// use crate::*;
135// #[test]
136// fn it_works() {
137// let kernel = Constant;
138
139// let test_value = kernel
140// .value(&[1.0], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
141// .unwrap();
142
143// assert_eq!(test_value, 1.0);
144// }
145// }