opensrdk_kernel_method/add.rs
1use crate::{KernelError, KernelMul, PositiveDefiniteKernel};
2use opensrdk_symbolic_computation::Expression;
3use std::fmt::Debug;
4use std::{ops::Add, ops::Mul};
5
6#[derive(Clone, Debug)]
7pub struct KernelAdd<L, R>
8where
9 L: PositiveDefiniteKernel,
10 R: PositiveDefiniteKernel,
11{
12 lhs: L,
13 rhs: R,
14}
15
16impl<L, R> KernelAdd<L, R>
17where
18 L: PositiveDefiniteKernel,
19 R: PositiveDefiniteKernel,
20{
21 pub fn new(lhs: L, rhs: R) -> Self {
22 Self { lhs, rhs }
23 }
24}
25
26impl<L, R> PositiveDefiniteKernel for KernelAdd<L, R>
27where
28 L: PositiveDefiniteKernel,
29 R: PositiveDefiniteKernel,
30{
31 fn params_len(&self) -> usize {
32 self.lhs.params_len() + self.rhs.params_len()
33 }
34
35 fn expression(
36 &self,
37 x: Expression,
38 x_prime: Expression,
39 params: &[Expression],
40 ) -> Result<Expression, KernelError> {
41 let lhs_params_len = self.lhs.params_len();
42 let fx = self
43 .lhs
44 .expression(x.clone(), x_prime.clone(), ¶ms[..lhs_params_len])?;
45 let gx = self.rhs.expression(x, x_prime, ¶ms[lhs_params_len..])?;
46
47 let hx = fx + gx;
48
49 Ok(hx)
50 }
51}
52
53impl<Rhs, L, R> Add<Rhs> for KernelAdd<L, R>
54where
55 Rhs: PositiveDefiniteKernel,
56 L: PositiveDefiniteKernel,
57 R: PositiveDefiniteKernel,
58{
59 type Output = KernelAdd<Self, Rhs>;
60
61 fn add(self, rhs: Rhs) -> Self::Output {
62 Self::Output::new(self, rhs)
63 }
64}
65
66impl<Rhs, L, R> Mul<Rhs> for KernelAdd<L, R>
67where
68 Rhs: PositiveDefiniteKernel,
69 L: PositiveDefiniteKernel,
70 R: PositiveDefiniteKernel,
71{
72 type Output = KernelMul<Self, Rhs>;
73
74 fn mul(self, rhs: Rhs) -> Self::Output {
75 Self::Output::new(self, rhs)
76 }
77}
78
79// impl<L, R> ValueDifferentiableKernel for KernelAdd<L, R>
80// where
81// L: ValueDifferentiableKernel<T>,
82// R: ValueDifferentiableKernel<T>,
83// T: Value,
84// {
85// fn ln_diff_value(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
86// let diff_rhs = &self
87// .rhs
88// .ln_diff_value(params, x, xprime)
89// .unwrap()
90// .clone()
91// .col_mat();
92// let diff_lhs = &self
93// .lhs
94// .ln_diff_value(params, x, xprime)
95// .unwrap()
96// .clone()
97// .col_mat();
98// let value_rhs = vec![self.rhs.value(params, x, xprime).unwrap()].col_mat();
99// let value_lhs = vec![self.lhs.value(params, x, xprime).unwrap()].col_mat();
100// let diff = ((&value_rhs * diff_rhs + &value_lhs * diff_lhs)
101// * (&value_rhs + value_lhs)[(0, 0)].powi(-1))
102// .vec();
103// Ok(diff)
104// }
105// }
106
107// impl<L, R, T> ParamsDifferentiableKernel<T> for KernelAdd<L, R, T>
108// where
109// L: ParamsDifferentiableKernel<T>,
110// R: ParamsDifferentiableKernel<T>,
111// T: Value,
112// {
113// fn ln_diff_params(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
114// let diff_rhs = &self
115// .rhs
116// .ln_diff_params(params, x, xprime)
117// .unwrap()
118// .clone()
119// .col_mat();
120// let diff_lhs = &self
121// .lhs
122// .ln_diff_params(params, x, xprime)
123// .unwrap()
124// .clone()
125// .col_mat();
126// let value_rhs = vec![self.rhs.value(params, x, xprime).unwrap()].col_mat();
127// let value_lhs = vec![self.lhs.value(params, x, xprime).unwrap()].col_mat();
128// let diff = ((&value_rhs * diff_rhs + &value_lhs * diff_lhs)
129// * (&value_rhs + value_lhs)[(0, 0)].powi(-1))
130// .vec();
131// Ok(diff)
132// }
133// }