opensrdk_kernel_method/linear.rs
1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
/// Number of hyperparameters accepted by [`Linear`]; the linear kernel has none.
const PARAMS_LEN: usize = 0;
/// Linear (dot-product) kernel with no hyperparameters.
///
/// Its symbolic form is produced by `PositiveDefiniteKernel::expression`, which
/// contracts `x` with `x_prime` through `Expression::dot`.
#[derive(Clone, Debug)]
pub struct Linear;
11
12impl PositiveDefiniteKernel for Linear {
13 fn expression(
14 &self,
15 x: Expression,
16 x_prime: Expression,
17 params: &[Expression],
18 ) -> Result<Expression, KernelError> {
19 if params.len() != PARAMS_LEN {
20 return Err(KernelError::ParametersLengthMismatch.into());
21 }
22 // if x.len() != x_prime.len() {
23 // return Err(KernelError::InvalidArgument.into());
24 // }
25 Ok(x.clone().dot(x_prime, &[[0, 0]]))
26 }
27
28 fn params_len(&self) -> usize {
29 0
30 }
31}
32
33impl<R> Add<R> for Linear
34where
35 R: PositiveDefiniteKernel,
36{
37 type Output = KernelAdd<Self, R>;
38
39 fn add(self, rhs: R) -> Self::Output {
40 KernelAdd::new(self, rhs)
41 }
42}
43
44impl<R> Mul<R> for Linear
45where
46 R: PositiveDefiniteKernel,
47{
48 type Output = KernelMul<Self, R>;
49
50 fn mul(self, rhs: R) -> Self::Output {
51 KernelMul::new(self, rhs)
52 }
53}
54
55// use super::PositiveDefiniteKernel;
56// use crate::{
57// KernelAdd, KernelError, KernelMul, ParamsDifferentiableKernel, ValueDifferentiableKernel,
58// };
59// use opensrdk_linear_algebra::*;
60// use rayon::prelude::*;
61// use std::{ops::Add, ops::Mul};
62
63// const PARAMS_LEN: usize = 0;
64
65// #[derive(Clone, Debug)]
66// pub struct Linear;
67
68// impl PositiveDefiniteKernel<Vec<f64>> for Linear {
69// fn params_len(&self) -> usize {
70// PARAMS_LEN
71// }
72
73// fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
74// if params.len() != PARAMS_LEN {
75// return Err(KernelError::ParametersLengthMismatch.into());
76// }
77// if x.len() != xprime.len() {
78// return Err(KernelError::InvalidArgument.into());
79// }
80
81// let fx = x
82// .par_iter()
83// .zip(xprime.par_iter())
84// .map(|(x_i, xprime_i)| x_i * xprime_i)
85// .sum();
86
87// Ok(fx)
88// }
89// }
90
91// impl<R> Add<R> for Linear
92// where
93// R: PositiveDefiniteKernel<Vec<f64>>,
94// {
95// type Output = KernelAdd<Self, R, Vec<f64>>;
96
97// fn add(self, rhs: R) -> Self::Output {
98// Self::Output::new(self, rhs)
99// }
100// }
101
102// impl<R> Mul<R> for Linear
103// where
104// R: PositiveDefiniteKernel<Vec<f64>>,
105// {
106// type Output = KernelMul<Self, R, Vec<f64>>;
107
108// fn mul(self, rhs: R) -> Self::Output {
109// Self::Output::new(self, rhs)
110// }
111// }
112
113// impl ValueDifferentiableKernel<Vec<f64>> for Linear {
114// fn ln_diff_value(
115// &self,
116// params: &[f64],
117// x: &Vec<f64>,
118// xprime: &Vec<f64>,
119// ) -> Result<Vec<f64>, KernelError> {
120// let value = &self.value(params, x, xprime)?;
121// let diff = (2.0 / value * x.clone().col_mat()).vec();
122// Ok(diff)
123// }
124// }
125
126// impl ParamsDifferentiableKernel<Vec<f64>> for Linear {
127// fn ln_diff_params(
128// &self,
129// _params: &[f64],
130// _x: &Vec<f64>,
131// _xprime: &Vec<f64>,
132// ) -> Result<Vec<f64>, KernelError> {
133// let diff = vec![];
134// Ok(diff)
135// }
136// }
137
138// #[cfg(test)]
139// mod tests {
140// use crate::*;
141// #[test]
142// fn it_works() {
143// let kernel = Linear;
144
145// let test_value = kernel
146// .value(&[], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
147// .unwrap();
148
149// assert_eq!(test_value, 10.0);
150// }
151// }