// opensrdk_kernel_method/linear.rs

1use std::ops::{Add, Mul};
2
3use crate::KernelError;
4
5use super::{KernelAdd, KernelMul, PositiveDefiniteKernel};
6use opensrdk_symbolic_computation::Expression;
7
/// Number of hyper-parameters taken by the [`Linear`] kernel (it has none).
const PARAMS_LEN: usize = 0;

/// Linear (dot-product) kernel.
///
/// This is a stateless unit struct, so it is freely copyable and has an obvious default value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct Linear;
11
12impl PositiveDefiniteKernel for Linear {
13    fn expression(
14        &self,
15        x: Expression,
16        x_prime: Expression,
17        params: &[Expression],
18    ) -> Result<Expression, KernelError> {
19        if params.len() != PARAMS_LEN {
20            return Err(KernelError::ParametersLengthMismatch.into());
21        }
22        // if x.len() != x_prime.len() {
23        //     return Err(KernelError::InvalidArgument.into());
24        // }
25        Ok(x.clone().dot(x_prime, &[[0, 0]]))
26    }
27
28    fn params_len(&self) -> usize {
29        0
30    }
31}
32
33impl<R> Add<R> for Linear
34where
35    R: PositiveDefiniteKernel,
36{
37    type Output = KernelAdd<Self, R>;
38
39    fn add(self, rhs: R) -> Self::Output {
40        KernelAdd::new(self, rhs)
41    }
42}
43
44impl<R> Mul<R> for Linear
45where
46    R: PositiveDefiniteKernel,
47{
48    type Output = KernelMul<Self, R>;
49
50    fn mul(self, rhs: R) -> Self::Output {
51        KernelMul::new(self, rhs)
52    }
53}
54
55// use super::PositiveDefiniteKernel;
56// use crate::{
57//     KernelAdd, KernelError, KernelMul, ParamsDifferentiableKernel, ValueDifferentiableKernel,
58// };
59// use opensrdk_linear_algebra::*;
60// use rayon::prelude::*;
61// use std::{ops::Add, ops::Mul};
62
63// const PARAMS_LEN: usize = 0;
64
65// #[derive(Clone, Debug)]
66// pub struct Linear;
67
68// impl PositiveDefiniteKernel<Vec<f64>> for Linear {
69//     fn params_len(&self) -> usize {
70//         PARAMS_LEN
71//     }
72
73//     fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
74//         if params.len() != PARAMS_LEN {
75//             return Err(KernelError::ParametersLengthMismatch.into());
76//         }
77//         if x.len() != xprime.len() {
78//             return Err(KernelError::InvalidArgument.into());
79//         }
80
81//         let fx = x
82//             .par_iter()
83//             .zip(xprime.par_iter())
84//             .map(|(x_i, xprime_i)| x_i * xprime_i)
85//             .sum();
86
87//         Ok(fx)
88//     }
89// }
90
91// impl<R> Add<R> for Linear
92// where
93//     R: PositiveDefiniteKernel<Vec<f64>>,
94// {
95//     type Output = KernelAdd<Self, R, Vec<f64>>;
96
97//     fn add(self, rhs: R) -> Self::Output {
98//         Self::Output::new(self, rhs)
99//     }
100// }
101
102// impl<R> Mul<R> for Linear
103// where
104//     R: PositiveDefiniteKernel<Vec<f64>>,
105// {
106//     type Output = KernelMul<Self, R, Vec<f64>>;
107
108//     fn mul(self, rhs: R) -> Self::Output {
109//         Self::Output::new(self, rhs)
110//     }
111// }
112
113// impl ValueDifferentiableKernel<Vec<f64>> for Linear {
114//     fn ln_diff_value(
115//         &self,
116//         params: &[f64],
117//         x: &Vec<f64>,
118//         xprime: &Vec<f64>,
119//     ) -> Result<Vec<f64>, KernelError> {
120//         let value = &self.value(params, x, xprime)?;
121//         let diff = (2.0 / value * x.clone().col_mat()).vec();
122//         Ok(diff)
123//     }
124// }
125
126// impl ParamsDifferentiableKernel<Vec<f64>> for Linear {
127//     fn ln_diff_params(
128//         &self,
129//         _params: &[f64],
130//         _x: &Vec<f64>,
131//         _xprime: &Vec<f64>,
132//     ) -> Result<Vec<f64>, KernelError> {
133//         let diff = vec![];
134//         Ok(diff)
135//     }
136// }
137
138// #[cfg(test)]
139// mod tests {
140//     use crate::*;
141//     #[test]
142//     fn it_works() {
143//         let kernel = Linear;
144
145//         let test_value = kernel
146//             .value(&[], &vec![1.0, 2.0, 3.0], &vec![3.0, 2.0, 1.0])
147//             .unwrap();
148
149//         assert_eq!(test_value, 10.0);
150//     }
151// }