use std::ops::Range;
use crate::core::{data::{input::Input, matrix::Matrix}, layer::methods::activations::Activations};
use rand::RngCore;
use serde::{Serialize, Deserialize};
use super::{dense::Dense, methods::pair::GradientPair, conv::Convolutional};
/// Behaviour shared by every network layer (the `Dense` and `Convolutional` layers
/// built by `LayerTypes::to_layer` implement it).
///
/// `#[typetag::serde]` allows `Box<dyn Layer>` trait objects to be serialized and
/// deserialized. Implementors must be `Send + Sync`.
#[typetag::serde]
pub trait Layer: Send + Sync {
    /// Runs the layer on `_inputs` and returns its output.
    ///
    /// The default implementation ignores its input and returns an empty 0x0 `Matrix`.
    /// Layers presumably override it (TODO confirm every implementor does).
    fn forward(&self, _inputs: &Box<dyn Input>) -> Box<dyn Input> {
        Box::new(Matrix::new_empty(0, 0))
    }

    /// Backward step. The meaning of `gradients`, `errors`, `data` and of the returned
    /// value is implementation-defined; confirm against the implementors.
    fn backward(
        &mut self,
        gradients: Box<dyn Input>,
        errors: Box<dyn Input>,
        data: Box<dyn Input>,
    ) -> Box<dyn Input>;

    /// Combines several gradients into one. The name suggests an average.
    /// NOTE(review): confirm the reduction in each implementation.
    fn avg_gradient(&self, gradients: Vec<&Box<dyn Input>>) -> Box<dyn Input>;

    /// Applies a gradient pair to the layer.
    ///
    /// `clip`, when `Some`, is presumably the range gradients are clamped to
    /// (TODO confirm).
    fn update_gradients(
        &mut self,
        gradient_pair: (&Box<dyn Input>, &Box<dyn Input>),
        clip: Option<Range<f32>>,
    );

    /// Returns the layer's stored data; what is stored is implementation-defined.
    fn get_data(&self) -> Box<dyn Input>;

    /// Replaces the layer's stored data with `data`.
    fn set_data(&mut self, data: &Box<dyn Input>);

    /// Transforms `errors` for this layer; the exact transformation is
    /// implementation-defined.
    fn update_errors(&self, errors: Box<dyn Input>) -> Box<dyn Input>;

    /// Computes a `GradientPair` from `data`, `data_at` and `errors`.
    fn get_gradients(
        &self,
        data: &Box<dyn Input>,
        data_at: &Box<dyn Input>,
        errors: &Box<dyn Input>,
    ) -> GradientPair;

    /// The layer's activation function, if any. Defaults to `None`.
    fn get_activation(&self) -> Option<Activations> {
        None
    }

    /// Three-component shape of the layer. The meaning of each component is
    /// implementation-defined (TODO confirm).
    fn shape(&self) -> (usize, usize, usize);

    /// Loss value reported by the layer.
    fn get_loss(&self) -> f32;

    /// Returns a gradient value from the layer; semantics are implementation-defined.
    fn update_gradient(&self) -> Box<dyn Input>;

    /// Returns the layer's weights.
    fn get_weights(&self) -> Box<dyn Input>;

    /// Returns the layer's biases.
    fn get_biases(&self) -> Box<dyn Input>;
}
/// Serializable description of a layer; `LayerTypes::to_layer` turns it into a
/// concrete `Box<dyn Layer>`.
// The upper-case variant names are kept on purpose. The derived serde format uses
// variant names, and callers match on them, so renaming would break both.
#[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Clone)]
pub enum LayerTypes {
    /// Fully-connected layer: `(rows, activation, learning)`.
    ///
    /// `learning` is presumably the learning rate (TODO confirm).
    DENSE(usize, Activations, f32),
    /// Convolutional layer: `(shape, kernels, stride, filters, activation, learning)`.
    CONV(
        (usize, usize, usize),
        (usize, usize),
        usize,
        usize,
        Activations,
        f32,
    ),
}
/// Serializable description of a network's input layer. `InputTypes::to_layer`
/// converts it into a `LayerTypes` with default activation and learning values.
// The upper-case variant names are kept on purpose. The derived serde format uses
// variant names, and callers match on them, so renaming would break both.
#[allow(clippy::upper_case_acronyms)]
#[derive(Serialize, Deserialize, Clone)]
pub enum InputTypes {
    /// Dense input: `(size)`.
    DENSE(usize),
    /// Convolutional input: `(shape, kernel_shape, stride, filters)`.
    CONV(
        (usize, usize, usize),
        (usize, usize),
        usize,
        usize,
    ),
}
impl LayerTypes {
    /// Builds the concrete [`Layer`] described by this variant.
    ///
    /// * `prev_rows` is passed as the first argument of `Dense::new`. It is presumably
    ///   the size of the preceding layer (TODO confirm). The `CONV` arm does not use it.
    /// * `rand` is forwarded unchanged to the layer constructor.
    pub fn to_layer(&self, prev_rows: usize, rand: &mut Box<dyn RngCore>) -> Box<dyn Layer> {
        match self {
            // `Activations` and `f32` are `Copy` (the CONV arm already dereferences
            // them), so dereference here too instead of calling `.clone()`.
            LayerTypes::DENSE(rows, activation, learning) => Box::new(Dense::new(
                prev_rows,
                *rows,
                *activation,
                *learning,
                rand,
            )),
            // `Convolutional::new` takes its arguments in a different order from the
            // variant's field order.
            LayerTypes::CONV(shape, kernels, stride, filters, activation, learning) => {
                Box::new(Convolutional::new(
                    *filters, *kernels, *shape, *stride, *activation, *learning, rand,
                ))
            }
        }
    }

    /// Size reported for this layer: `rows` for `DENSE`, `shape.0 * shape.1` for `CONV`.
    pub fn get_size(&self) -> usize {
        match self {
            LayerTypes::DENSE(rows, ..) => *rows,
            // NOTE(review): `shape.2` and the kernel/stride/filter settings are not
            // used here. Confirm this matches what the next layer expects.
            LayerTypes::CONV(shape, ..) => shape.0 * shape.1,
        }
    }
}
impl InputTypes {
    /// Converts this input description into the equivalent `LayerTypes` variant.
    ///
    /// The activation is always `Activations::SIGMOID` and the `f32` parameter is
    /// always `1.0`; neither is configurable from `InputTypes`.
    pub fn to_layer(&self) -> LayerTypes {
        match self {
            InputTypes::DENSE(size) => LayerTypes::DENSE(*size, Activations::SIGMOID, 1.0),
            InputTypes::CONV(shape, kernel_shape, stride, filters) => LayerTypes::CONV(
                *shape,
                *kernel_shape,
                *stride,
                *filters,
                Activations::SIGMOID,
                1.0,
            ),
        }
    }

    /// Size of the input: `size` for `DENSE`, `shape.0 * shape.1` for `CONV`.
    /// This mirrors `LayerTypes::get_size`.
    pub fn get_size(&self) -> usize {
        match self {
            InputTypes::DENSE(size) => *size,
            InputTypes::CONV(shape, ..) => shape.0 * shape.1,
        }
    }
}