1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
use crate::{KernelAdd, KernelError, PositiveDefiniteKernel};
use opensrdk_symbolic_computation::Expression;
use std::ops::Add;
use std::{fmt::Debug, ops::Mul};
/// Product of two positive-definite kernels.
///
/// Evaluates to `lhs(x, x') * rhs(x, x')`. The combined parameter list is the
/// left kernel's parameters followed by the right kernel's parameters.
#[derive(Debug, Clone)]
pub struct KernelMul<L: PositiveDefiniteKernel, R: PositiveDefiniteKernel> {
    /// Left-hand factor.
    lhs: L,
    /// Right-hand factor.
    rhs: R,
}
impl<L, R> KernelMul<L, R>
where
L: PositiveDefiniteKernel,
R: PositiveDefiniteKernel,
{
pub fn new(lhs: L, rhs: R) -> Self {
Self { lhs, rhs }
}
}
impl<L, R> PositiveDefiniteKernel for KernelMul<L, R>
where
L: PositiveDefiniteKernel,
R: PositiveDefiniteKernel,
{
fn params_len(&self) -> usize {
self.lhs.params_len() + self.rhs.params_len()
}
fn expression(
&self,
x: Expression,
x_prime: Expression,
params: &[Expression],
) -> Result<Expression, KernelError> {
let lhs_params_len = self.lhs.params_len();
let fx = self
.lhs
.expression(x.clone(), x_prime.clone(), ¶ms[..lhs_params_len])?;
let gx = self.rhs.expression(x, x_prime, ¶ms[lhs_params_len..])?;
let hx = fx * gx;
Ok(hx)
}
}
impl<Rhs, L, R> Add<Rhs> for KernelMul<L, R>
where
Rhs: PositiveDefiniteKernel,
L: PositiveDefiniteKernel,
R: PositiveDefiniteKernel,
{
type Output = KernelAdd<Self, Rhs>;
fn add(self, rhs: Rhs) -> Self::Output {
Self::Output::new(self, rhs)
}
}
impl<Rhs, L, R> Mul<Rhs> for KernelMul<L, R>
where
Rhs: PositiveDefiniteKernel,
L: PositiveDefiniteKernel,
R: PositiveDefiniteKernel,
{
type Output = KernelMul<Self, Rhs>;
fn mul(self, rhs: Rhs) -> Self::Output {
Self::Output::new(self, rhs)
}
}
// impl<L, R> ValueDifferentiableKernel for KernelMul<L, R>
// where
// L: ValueDifferentiableKernel,
// R: ValueDifferentiableKernel<T>,
// T: Value,
// {
// fn ln_diff_value(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
// let diff_rhs = &self
// .rhs
// .ln_diff_value(params, x, xprime)
// .unwrap()
// .clone()
// .col_mat();
// let diff_lhs = &self
// .lhs
// .ln_diff_value(params, x, xprime)
// .unwrap()
// .clone()
// .col_mat();
// let diff = (diff_rhs + diff_lhs.clone()).vec();
// Ok(diff)
// }
// }
// impl<L, R, T> ParamsDifferentiableKernel<T> for KernelMul<L, R, T>
// where
// L: ParamsDifferentiableKernel<T>,
// R: ParamsDifferentiableKernel<T>,
// T: Value,
// {
// fn ln_diff_params(&self, params: &[f64], x: &T, xprime: &T) -> Result<Vec<f64>, KernelError> {
// let diff_rhs = &self
// .rhs
// .ln_diff_params(params, x, xprime)
// .unwrap()
// .clone()
// .col_mat();
// let diff_lhs = &self
// .lhs
// .ln_diff_params(params, x, xprime)
// .unwrap()
// .clone()
// .col_mat();
// let diff = (diff_rhs + diff_lhs.clone()).vec();
// Ok(diff)
// }
// }