opensrdk_kernel_method/mul.rs
1use crate::{KernelAdd, KernelError, PositiveDefiniteKernel};
2use opensrdk_symbolic_computation::Expression;
3use std::ops::Add;
4use std::{fmt::Debug, ops::Mul};
5
6#[derive(Clone, Debug)]
7pub struct KernelMul<L, R>
8where
9 L: PositiveDefiniteKernel,
10 R: PositiveDefiniteKernel,
11{
12 lhs: L,
13 rhs: R,
14}
15
16impl<L, R> KernelMul<L, R>
17where
18 L: PositiveDefiniteKernel,
19 R: PositiveDefiniteKernel,
20{
21 pub fn new(lhs: L, rhs: R) -> Self {
22 Self { lhs, rhs }
23 }
24}
25
26impl<L, R> PositiveDefiniteKernel for KernelMul<L, R>
27where
28 L: PositiveDefiniteKernel,
29 R: PositiveDefiniteKernel,
30{
31 fn params_len(&self) -> usize {
32 self.lhs.params_len() + self.rhs.params_len()
33 }
34 fn expression(
35 &self,
36 x: Expression,
37 x_prime: Expression,
38 params: &[Expression],
39 ) -> Result<Expression, KernelError> {
40 let lhs_params_len = self.lhs.params_len();
41 let fx = self
42 .lhs
43 .expression(x.clone(), x_prime.clone(), ¶ms[..lhs_params_len])?;
44 let gx = self.rhs.expression(x, x_prime, ¶ms[lhs_params_len..])?;
45
46 let hx = fx * gx;
47
48 Ok(hx)
49 }
50}
51
52impl<Rhs, L, R> Add<Rhs> for KernelMul<L, R>
53where
54 Rhs: PositiveDefiniteKernel,
55 L: PositiveDefiniteKernel,
56 R: PositiveDefiniteKernel,
57{
58 type Output = KernelAdd<Self, Rhs>;
59
60 fn add(self, rhs: Rhs) -> Self::Output {
61 Self::Output::new(self, rhs)
62 }
63}
64
65impl<Rhs, L, R> Mul<Rhs> for KernelMul<L, R>
66where
67 Rhs: PositiveDefiniteKernel,
68 L: PositiveDefiniteKernel,
69 R: PositiveDefiniteKernel,
70{
71 type Output = KernelMul<Self, Rhs>;
72
73 fn mul(self, rhs: Rhs) -> Self::Output {
74 Self::Output::new(self, rhs)
75 }
76}
77
78// impl<L, R> ValueDifferentiableKernel for KernelMul<L, R>
79// where
80// L: ValueDifferentiableKernel,
81// R: ValueDifferentiableKernel<T>,
82// T: Value,
83// {
84// fn ln_diff_value(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
85// let diff_rhs = &self
86// .rhs
87// .ln_diff_value(params, x, xprime)
88// .unwrap()
89// .clone()
90// .col_mat();
91// let diff_lhs = &self
92// .lhs
93// .ln_diff_value(params, x, xprime)
94// .unwrap()
95// .clone()
96// .col_mat();
97// let diff = (diff_rhs + diff_lhs.clone()).vec();
98// Ok(diff)
99// }
100// }
101
102// impl<L, R, T> ParamsDifferentiableKernel<T> for KernelMul<L, R, T>
103// where
104// L: ParamsDifferentiableKernel<T>,
105// R: ParamsDifferentiableKernel<T>,
106// T: Value,
107// {
108// fn ln_diff_params(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
109// let diff_rhs = &self
110// .rhs
111// .ln_diff_params(params, x, xprime)
112// .unwrap()
113// .clone()
114// .col_mat();
115// let diff_lhs = &self
116// .lhs
117// .ln_diff_params(params, x, xprime)
118// .unwrap()
119// .clone()
120// .col_mat();
121// let diff = (diff_rhs + diff_lhs.clone()).vec();
122// Ok(diff)
123// }
124// }