// opensrdk_kernel_method/convolutional.rs

1use crate::{KernelError, PositiveDefiniteKernel};
2use opensrdk_symbolic_computation::Expression;
3use rayon::prelude::*;
4use std::fmt::Debug;
5
6pub trait Convolutable {
7    fn parts_len(&self) -> usize;
8    fn part(&self, index: usize) -> &Expression;
9}
10
11impl Convolutable for Expression {
12    fn parts_len(&self) -> usize {
13        1
14    }
15
16    fn part(&self, _: usize) -> &Expression {
17        self
18    }
19}
20
21#[derive(Clone, Debug)]
22pub struct Convolutional<K>
23where
24    K: PositiveDefiniteKernel,
25{
26    kernel: K,
27}
28
29impl<K> Convolutional<K>
30where
31    K: PositiveDefiniteKernel,
32{
33    pub fn new(kernel: K) -> Self {
34        Self { kernel }
35    }
36
37    pub fn kernel_ref(&self) -> &K {
38        &self.kernel
39    }
40}
41
42impl<K> PositiveDefiniteKernel for Convolutional<K>
43where
44    K: PositiveDefiniteKernel,
45{
46    fn params_len(&self) -> usize {
47        self.kernel.params_len()
48    }
49
50    fn expression(
51        &self,
52        x: Expression,
53        x_prime: Expression,
54        params: &[Expression],
55    ) -> Result<Expression, KernelError> {
56        if params.len() != self.kernel.params_len() {
57            return Err(KernelError::ParametersLengthMismatch.into());
58        }
59        let p = x.parts_len();
60        if p != x_prime.parts_len() {
61            return Err(KernelError::InvalidArgument.into());
62        }
63
64        todo!()
65
66        // let fx = (0..p)
67        //     .into_par_iter()
68        //     .map(|pi| {
69        //         let expression = self
70        //             .kernel
71        //             .expression(x.part(pi).clone(), x_prime.part(pi).clone(), params)
72        //             .unwrap();
73        //         expression
74        //     })
75        //     .sum::<Result<Expression, KernelError>>()?;
76
77        // Ok(fx)
78    }
79}
80
81// #[cfg(test)]
82// mod tests {
83//     use crate::*;
84//     #[test]
85//     fn it_works() {
86//         let kernel = Convolutional::new(RBF);
87
88//         let test_value = kernel.value(&[1.0], &vec![0.0, 0.0, 0.0], &vec![0.0, 0.0, 0.0]);
89
90//         match test_value {
91//             Err(KernelError::ParametersLengthMismatch) => (),
92//             _ => panic!(),
93//         };
94//     }
95// }