// opensrdk_kernel_method/add.rs

1use crate::{KernelError, KernelMul, PositiveDefiniteKernel};
2use opensrdk_symbolic_computation::Expression;
3use std::fmt::Debug;
4use std::{ops::Add, ops::Mul};
5
6#[derive(Clone, Debug)]
7pub struct KernelAdd<L, R>
8where
9    L: PositiveDefiniteKernel,
10    R: PositiveDefiniteKernel,
11{
12    lhs: L,
13    rhs: R,
14}
15
16impl<L, R> KernelAdd<L, R>
17where
18    L: PositiveDefiniteKernel,
19    R: PositiveDefiniteKernel,
20{
21    pub fn new(lhs: L, rhs: R) -> Self {
22        Self { lhs, rhs }
23    }
24}
25
26impl<L, R> PositiveDefiniteKernel for KernelAdd<L, R>
27where
28    L: PositiveDefiniteKernel,
29    R: PositiveDefiniteKernel,
30{
31    fn params_len(&self) -> usize {
32        self.lhs.params_len() + self.rhs.params_len()
33    }
34
35    fn expression(
36        &self,
37        x: Expression,
38        x_prime: Expression,
39        params: &[Expression],
40    ) -> Result<Expression, KernelError> {
41        let lhs_params_len = self.lhs.params_len();
42        let fx = self
43            .lhs
44            .expression(x.clone(), x_prime.clone(), &params[..lhs_params_len])?;
45        let gx = self.rhs.expression(x, x_prime, &params[lhs_params_len..])?;
46
47        let hx = fx + gx;
48
49        Ok(hx)
50    }
51}
52
53impl<Rhs, L, R> Add<Rhs> for KernelAdd<L, R>
54where
55    Rhs: PositiveDefiniteKernel,
56    L: PositiveDefiniteKernel,
57    R: PositiveDefiniteKernel,
58{
59    type Output = KernelAdd<Self, Rhs>;
60
61    fn add(self, rhs: Rhs) -> Self::Output {
62        Self::Output::new(self, rhs)
63    }
64}
65
66impl<Rhs, L, R> Mul<Rhs> for KernelAdd<L, R>
67where
68    Rhs: PositiveDefiniteKernel,
69    L: PositiveDefiniteKernel,
70    R: PositiveDefiniteKernel,
71{
72    type Output = KernelMul<Self, Rhs>;
73
74    fn mul(self, rhs: Rhs) -> Self::Output {
75        Self::Output::new(self, rhs)
76    }
77}
78
// NOTE(review): disabled implementation, kept for reference. Before re-enabling it:
// the header below uses the type parameter `T` without declaring it, and both
// sub-kernels receive the full `params` slice instead of having it split at
// `lhs.params_len()` the way `expression` does.
// impl<L, R> ValueDifferentiableKernel for KernelAdd<L, R>
// where
//     L: ValueDifferentiableKernel<T>,
//     R: ValueDifferentiableKernel<T>,
//     T: Value,
// {
//     fn ln_diff_value(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
//         let diff_rhs = &self
//             .rhs
//             .ln_diff_value(params, x, xprime)
//             .unwrap()
//             .clone()
//             .col_mat();
//         let diff_lhs = &self
//             .lhs
//             .ln_diff_value(params, x, xprime)
//             .unwrap()
//             .clone()
//             .col_mat();
//         let value_rhs = vec![self.rhs.value(params, x, xprime).unwrap()].col_mat();
//         let value_lhs = vec![self.lhs.value(params, x, xprime).unwrap()].col_mat();
//         let diff = ((&value_rhs * diff_rhs + &value_lhs * diff_lhs)
//             * (&value_rhs + value_lhs)[(0, 0)].powi(-1))
//         .vec();
//         Ok(diff)
//     }
// }

// NOTE(review): disabled implementation, kept for reference. Before re-enabling it:
// the header names `KernelAdd<L, R, T>`, but the struct takes only two type
// parameters; also, both sub-kernels receive the full `params` slice instead of
// having it split at `lhs.params_len()` the way `expression` does.
// impl<L, R, T> ParamsDifferentiableKernel<T> for KernelAdd<L, R, T>
// where
//     L: ParamsDifferentiableKernel<T>,
//     R: ParamsDifferentiableKernel<T>,
//     T: Value,
// {
//     fn ln_diff_params(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
//         let diff_rhs = &self
//             .rhs
//             .ln_diff_params(params, x, xprime)
//             .unwrap()
//             .clone()
//             .col_mat();
//         let diff_lhs = &self
//             .lhs
//             .ln_diff_params(params, x, xprime)
//             .unwrap()
//             .clone()
//             .col_mat();
//         let value_rhs = vec![self.rhs.value(params, x, xprime).unwrap()].col_mat();
//         let value_lhs = vec![self.lhs.value(params, x, xprime).unwrap()].col_mat();
//         let diff = ((&value_rhs * diff_rhs + &value_lhs * diff_lhs)
//             * (&value_rhs + value_lhs)[(0, 0)].powi(-1))
//         .vec();
//         Ok(diff)
//     }
// }