1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
use opensrdk_symbolic_computation::Expression;
use super::ActivationFunction;
use crate::{KernelAdd, KernelError, KernelMul, PositiveDefiniteKernel};
use std::{
fmt::Debug,
ops::{Add, Mul},
};
/// Deep-neural-network kernel built by stacking activation-function layers on a
/// base kernel.
///
/// Follows Lee et al., "Deep Neural Networks as Gaussian Processes",
/// <https://arxiv.org/abs/1711.00165>.
#[derive(Clone, Debug)]
pub struct DeepNeuralNetwork<'a> {
    /// Activation functions that make up the network's layers, borrowed for `'a`.
    layers: Vec<&'a dyn ActivationFunction>,
}
impl<'a> DeepNeuralNetwork<'a> {
pub fn new(layers: Vec<&'a dyn ActivationFunction>) -> Self {
Self { layers }
}
}
impl<'a> PositiveDefiniteKernel for DeepNeuralNetwork<'a> {
    /// Number of hyper-parameters: two for the base layer plus a
    /// `(sigma_b, sigma_w)` pair for every activation layer.
    fn params_len(&self) -> usize {
        2 * (1 + self.layers.len())
    }

    /// Builds the symbolic kernel expression `k(x, x')`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::ParametersLengthMismatch`] when `params.len()` differs
    /// from [`Self::params_len`].
    ///
    /// # Panics
    ///
    /// Always panics once the parameter check passes: the symbolic version of this
    /// kernel has not been implemented yet (see the reference sketch below).
    fn expression(
        &self,
        x: Expression,
        x_prime: Expression,
        params: &[Expression],
    ) -> Result<Expression, KernelError> {
        if params.len() != self.params_len() {
            // The variant is already a `KernelError`; no conversion is needed.
            return Err(KernelError::ParametersLengthMismatch);
        }
        // NOTE(review): the earlier numeric API rejected inputs of different lengths
        // with `KernelError::InvalidArgument`; decide whether `x` / `x_prime` need an
        // equivalent check when this is ported to `Expression`.
        todo!("DeepNeuralNetwork::expression has not been ported to symbolic Expressions yet")
    }

    // Reference numeric implementation from the earlier `value(&[f64], ..)` API, kept
    // as a porting guide. Each step passes the previous layer's triple
    // (k(x, x'), k(x, x), k(x', x')) through the activation's `f` and rescales the
    // results with that layer's (sigma_b, sigma_w) pair. After the two base parameters
    // are sliced off, layer `i` owns `params[2 * i]` and `params[2 * i + 1]`. An older
    // version used `(i + 1) * 2`, which skipped a pair and indexed past the end on the
    // last layer.
    //
    // fn value(&self, params: &[f64], x: &Vec<f64>, xprime: &Vec<f64>) -> Result<f64, KernelError> {
    //     if params.len() != self.params_len() {
    //         return Err(KernelError::ParametersLengthMismatch);
    //     }
    //     if x.len() != xprime.len() {
    //         return Err(KernelError::InvalidArgument);
    //     }
    //     let layer0 = Constant + Constant * Linear;
    //     let mut previous_layer_kernel = (
    //         layer0.value(&params[0..2], x, xprime)?,
    //         layer0.value(&params[0..2], x, x)?,
    //         layer0.value(&params[0..2], xprime, xprime)?,
    //     );
    //     let params = &params[2..];
    //     for (i, &layer) in self.layers.iter().enumerate() {
    //         let sigma_b = params[2 * i];
    //         let sigma_w = params[2 * i + 1];
    //         let f = layer.f(previous_layer_kernel);
    //         let fxx = layer.f((
    //             previous_layer_kernel.1,
    //             previous_layer_kernel.1,
    //             previous_layer_kernel.1,
    //         ));
    //         let fxpxp = layer.f((
    //             previous_layer_kernel.2,
    //             previous_layer_kernel.2,
    //             previous_layer_kernel.2,
    //         ));
    //         previous_layer_kernel = (
    //             sigma_b + sigma_w * f,
    //             sigma_b + sigma_w * fxx,
    //             sigma_b + sigma_w * fxpxp,
    //         );
    //     }
    //     Ok(previous_layer_kernel.0)
    // }
}
impl<'a, R> Add<R> for DeepNeuralNetwork<'a>
where
R: PositiveDefiniteKernel,
{
type Output = KernelAdd<Self, R>;
fn add(self, rhs: R) -> Self::Output {
Self::Output::new(self, rhs)
}
}
impl<'a, R> Mul<R> for DeepNeuralNetwork<'a>
where
R: PositiveDefiniteKernel,
{
type Output = KernelMul<Self, R>;
fn mul(self, rhs: R) -> Self::Output {
Self::Output::new(self, rhs)
}
}
// #[cfg(test)]
// mod tests {
// use crate::*;
// #[test]
// fn it_works() {
// let activfunc = ReLU;
// let kernel = DeepNeuralNetwork::new(vec![&activfunc]);
// let test_value = kernel.value(
// &[1.0, 1.0, 3.0, 4.0, 6.0],
// &vec![0.0, 0.0, 0.0],
// &vec![0.0, 0.0, 0.0],
// );
// match test_value {
// Err(KernelError::ParametersLengthMismatch) => (),
// _ => panic!(),
// };
// }
// }