use opensrdk_symbolic_computation::Expression;
use super::ActivationFunction;
use crate::{KernelAdd, KernelError, KernelMul, PositiveDefiniteKernel};
use std::{
fmt::Debug,
ops::{Add, Mul},
};
/// Kernel described by a list of activation functions.
///
/// The struct only borrows its [`ActivationFunction`] trait objects, so every
/// referenced activation function must outlive `'a`.
///
/// NOTE(review): presumably this models the kernel induced by a deep neural
/// network with one activation function per layer — confirm against the
/// intended formulation.
#[derive(Clone, Debug)]
pub struct DeepNeuralNetwork<'a> {
/// Borrowed activation functions; the element count feeds into the number
/// of kernel parameters.
layers: Vec<&'a dyn ActivationFunction>,
}
impl<'a> DeepNeuralNetwork<'a> {
pub fn new(layers: Vec<&'a dyn ActivationFunction>) -> Self {
Self { layers }
}
}
impl<'a> PositiveDefiniteKernel for DeepNeuralNetwork<'a> {
    /// Returns the number of parameters `expression` expects:
    /// `2 * (1 + number of stored activation functions)`.
    ///
    /// NOTE(review): presumably two parameters per layer plus two for the
    /// input stage — confirm once `expression` is implemented.
    fn params_len(&self) -> usize {
        2 * (1 + self.layers.len())
    }

    /// Builds the symbolic kernel expression for the inputs `x` and `x_prime`.
    ///
    /// # Errors
    ///
    /// Returns [`KernelError::ParametersLengthMismatch`] when `params.len()`
    /// differs from [`params_len`](Self::params_len).
    ///
    /// # Panics
    ///
    /// The expression itself is not implemented yet: whenever the length
    /// check passes, this method panics via `todo!()`.
    fn expression(
        &self,
        x: Expression,
        x_prime: Expression,
        params: &[Expression],
    ) -> Result<Expression, KernelError> {
        if params.len() != self.params_len() {
            // The error type already matches the return type, so no conversion is needed.
            return Err(KernelError::ParametersLengthMismatch);
        }

        todo!()
    }
}
impl<'a, R> Add<R> for DeepNeuralNetwork<'a>
where
R: PositiveDefiniteKernel,
{
type Output = KernelAdd<Self, R>;
fn add(self, rhs: R) -> Self::Output {
Self::Output::new(self, rhs)
}
}
impl<'a, R> Mul<R> for DeepNeuralNetwork<'a>
where
R: PositiveDefiniteKernel,
{
type Output = KernelMul<Self, R>;
fn mul(self, rhs: R) -> Self::Output {
Self::Output::new(self, rhs)
}
}