opensrdk_kernel_method/neural_network/deep_neural_network.rs

1use opensrdk_symbolic_computation::Expression;
2
3use super::ActivationFunction;
4use crate::{KernelAdd, KernelError, KernelMul, PositiveDefiniteKernel};
5use std::{
6    fmt::Debug,
7    ops::{Add, Mul},
8};
9
10/// https://arxiv.org/abs/1711.00165
11#[derive(Clone, Debug)]
12pub struct DeepNeuralNetwork<'a> {
13    layers: Vec<&'a dyn ActivationFunction>,
14}
15
16impl<'a> DeepNeuralNetwork<'a> {
17    pub fn new(layers: Vec<&'a dyn ActivationFunction>) -> Self {
18        Self { layers }
19    }
20}
21impl<'a> PositiveDefiniteKernel for DeepNeuralNetwork<'a> {
22    fn params_len(&self) -> usize {
23        2 * (1 + self.layers.len())
24    }
25
26    fn expression(
27        &self,
28        x: Expression,
29        x_prime: Expression,
30        params: &[Expression],
31    ) -> Result<Expression, KernelError> {
32        if params.len() != self.params_len() {
33            return Err(KernelError::ParametersLengthMismatch.into());
34        }
35        // if x.len() != x_prime.len() {
36        //     return Err(KernelError::InvalidArgument.into());
37        // }
38        todo!()
39    }
40
41    // fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
42    //     if params.len() != self.params_len() {
43    //         return Err(KernelError::ParametersLengthMismatch.into());
44    //     }
45    //     if x.len() != xprime.len() {
46    //         return Err(KernelError::InvalidArgument.into());
47    //     }
48
49    //     let layer0 = Constant + Constant * Linear;
50    //     let mut previous_layer_kernel = (
51    //         layer0.value(&params[0..2], x, xprime)?,
52    //         layer0.value(&params[0..2], x, x)?,
53    //         layer0.value(&params[0..2], xprime, xprime)?,
54    //     );
55    //     let params = &params[2..];
56
57    //     for (i, &layer) in self.layers.iter().enumerate() {
58    //         let sigma_b = params[(i + 1) * 2];
59    //         let sigma_w = params[(i + 1) * 2 + 1];
60    //         let f = layer.f(previous_layer_kernel);
61    //         let fxx = layer.f((
62    //             previous_layer_kernel.1,
63    //             previous_layer_kernel.1,
64    //             previous_layer_kernel.1,
65    //         ));
66    //         let fxpxp = layer.f((
67    //             previous_layer_kernel.2,
68    //             previous_layer_kernel.2,
69    //             previous_layer_kernel.2,
70    //         ));
71
72    //         previous_layer_kernel = (
73    //             sigma_b + sigma_w * f,
74    //             sigma_b + sigma_w * fxx,
75    //             sigma_b + sigma_w * fxpxp,
76    //         );
77    //     }
78
79    //     Ok(previous_layer_kernel.0)
80    // }
81}
82
83impl<'a, R> Add<R> for DeepNeuralNetwork<'a>
84where
85    R: PositiveDefiniteKernel,
86{
87    type Output = KernelAdd<Self, R>;
88    fn add(self, rhs: R) -> Self::Output {
89        Self::Output::new(self, rhs)
90    }
91}
92
93impl<'a, R> Mul<R> for DeepNeuralNetwork<'a>
94where
95    R: PositiveDefiniteKernel,
96{
97    type Output = KernelMul<Self, R>;
98
99    fn mul(self, rhs: R) -> Self::Output {
100        Self::Output::new(self, rhs)
101    }
102}
103
104// #[cfg(test)]
105// mod tests {
106//     use crate::*;
107
108//     #[test]
109//     fn it_works() {
110//         let activfunc = ReLU;
111//         let kernel = DeepNeuralNetwork::new(vec![&activfunc]);
112
113//         let test_value = kernel.value(
114//             &[1.0, 1.0, 3.0, 4.0, 6.0],
115//             &vec![0.0, 0.0, 0.0],
116//             &vec![0.0, 0.0, 0.0],
117//         );
118
119//         match test_value {
120//             Err(KernelError::ParametersLengthMismatch) => (),
121//             _ => panic!(),
122//         };
123//     }
124// }