//! opensrdk_kernel_method/neural_network/deep_neural_network.rs
1use opensrdk_symbolic_computation::Expression;
2
3use super::ActivationFunction;
4use crate::{KernelAdd, KernelError, KernelMul, PositiveDefiniteKernel};
5use std::{
6 fmt::Debug,
7 ops::{Add, Mul},
8};
9
/// Kernel induced by a deep neural network in the infinite-width limit
/// (the "Deep Neural Networks as Gaussian Processes" construction).
///
/// Reference: https://arxiv.org/abs/1711.00165
#[derive(Clone, Debug)]
pub struct DeepNeuralNetwork<'a> {
    // One activation function per hidden layer, in network order.
    // Borrowed trait objects so heterogeneous activations can be mixed
    // without taking ownership.
    layers: Vec<&'a dyn ActivationFunction>,
}
15
16impl<'a> DeepNeuralNetwork<'a> {
17 pub fn new(layers: Vec<&'a dyn ActivationFunction>) -> Self {
18 Self { layers }
19 }
20}
21impl<'a> PositiveDefiniteKernel for DeepNeuralNetwork<'a> {
22 fn params_len(&self) -> usize {
23 2 * (1 + self.layers.len())
24 }
25
26 fn expression(
27 &self,
28 x: Expression,
29 x_prime: Expression,
30 params: &[Expression],
31 ) -> Result<Expression, KernelError> {
32 if params.len() != self.params_len() {
33 return Err(KernelError::ParametersLengthMismatch.into());
34 }
35 // if x.len() != x_prime.len() {
36 // return Err(KernelError::InvalidArgument.into());
37 // }
38 todo!()
39 }
40
41 // fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
42 // if params.len() != self.params_len() {
43 // return Err(KernelError::ParametersLengthMismatch.into());
44 // }
45 // if x.len() != xprime.len() {
46 // return Err(KernelError::InvalidArgument.into());
47 // }
48
49 // let layer0 = Constant + Constant * Linear;
50 // let mut previous_layer_kernel = (
51 // layer0.value(¶ms[0..2], x, xprime)?,
52 // layer0.value(¶ms[0..2], x, x)?,
53 // layer0.value(¶ms[0..2], xprime, xprime)?,
54 // );
55 // let params = ¶ms[2..];
56
57 // for (i, &layer) in self.layers.iter().enumerate() {
58 // let sigma_b = params[(i + 1) * 2];
59 // let sigma_w = params[(i + 1) * 2 + 1];
60 // let f = layer.f(previous_layer_kernel);
61 // let fxx = layer.f((
62 // previous_layer_kernel.1,
63 // previous_layer_kernel.1,
64 // previous_layer_kernel.1,
65 // ));
66 // let fxpxp = layer.f((
67 // previous_layer_kernel.2,
68 // previous_layer_kernel.2,
69 // previous_layer_kernel.2,
70 // ));
71
72 // previous_layer_kernel = (
73 // sigma_b + sigma_w * f,
74 // sigma_b + sigma_w * fxx,
75 // sigma_b + sigma_w * fxpxp,
76 // );
77 // }
78
79 // Ok(previous_layer_kernel.0)
80 // }
81}
82
83impl<'a, R> Add<R> for DeepNeuralNetwork<'a>
84where
85 R: PositiveDefiniteKernel,
86{
87 type Output = KernelAdd<Self, R>;
88 fn add(self, rhs: R) -> Self::Output {
89 Self::Output::new(self, rhs)
90 }
91}
92
93impl<'a, R> Mul<R> for DeepNeuralNetwork<'a>
94where
95 R: PositiveDefiniteKernel,
96{
97 type Output = KernelMul<Self, R>;
98
99 fn mul(self, rhs: R) -> Self::Output {
100 Self::Output::new(self, rhs)
101 }
102}
103
104// #[cfg(test)]
105// mod tests {
106// use crate::*;
107
108// #[test]
109// fn it_works() {
110// let activfunc = ReLU;
111// let kernel = DeepNeuralNetwork::new(vec![&activfunc]);
112
113// let test_value = kernel.value(
114// &[1.0, 1.0, 3.0, 4.0, 6.0],
115// &vec![0.0, 0.0, 0.0],
116// &vec![0.0, 0.0, 0.0],
117// );
118
119// match test_value {
120// Err(KernelError::ParametersLengthMismatch) => (),
121// _ => panic!(),
122// };
123// }
124// }