use crate::core::{error::BellandeError, tensor::Tensor};
/// A single stage of a network that [`Sequential`] can chain together.
///
/// The `Send + Sync` supertraits mean a `Box<dyn NeuralLayer>` may be moved to
/// and shared between threads.
pub trait NeuralLayer: Send + Sync {
    // ---- passes -----------------------------------------------------------

    /// Computes this layer's output for `input`.
    ///
    /// `Sequential::forward` feeds each layer the previous layer's output.
    fn forward(&mut self, input: &Tensor) -> Result<Tensor, BellandeError>;

    /// Back-propagates `grad` through this layer.
    ///
    /// `Sequential::backward` calls layers in reverse order and hands each one
    /// the value returned by the layer after it (presumably the gradient with
    /// respect to this layer's output — confirm against implementors).
    fn backward(&mut self, grad: &Tensor) -> Result<Tensor, BellandeError>;

    // ---- mode switches ----------------------------------------------------

    /// Switches the layer into training mode.
    fn train(&mut self);

    /// Switches the layer into evaluation mode.
    fn eval(&mut self);

    // ---- parameter access -------------------------------------------------

    /// Returns this layer's parameter tensors.
    fn parameters(&self) -> Vec<Tensor>;

    /// Returns this layer's parameters paired with a name.
    ///
    /// NOTE(review): presumably these names are the keys accepted by
    /// `set_parameter` — confirm per implementation.
    fn named_parameters(&self) -> Vec<(String, Tensor)>;

    /// Replaces the parameter identified by `name` with `value`.
    ///
    /// # Errors
    /// Implementation-defined; presumably fails for an unknown `name`.
    fn set_parameter(&mut self, name: &str, value: Tensor) -> Result<(), BellandeError>;
}
/// An ordered container of boxed [`NeuralLayer`]s run one after another.
pub struct Sequential {
    /// Layers in execution order: `forward` visits them front-to-back,
    /// `backward` back-to-front.
    pub(crate) layers: Vec<Box<dyn NeuralLayer>>,
    /// Mode flag toggled by `train()` / `eval()`; `backward` refuses to run
    /// while it is `false`.
    pub(crate) training: bool,
}
impl Sequential {
    /// Creates an empty container in training mode.
    pub fn new() -> Self {
        Sequential {
            layers: Vec::new(),
            training: true,
        }
    }

    /// Appends `layer` to the end of the pipeline.
    ///
    /// Returns `&mut self` so calls can be chained: `seq.add(a).add(b);`.
    pub fn add(&mut self, layer: Box<dyn NeuralLayer>) -> &mut Self {
        self.layers.push(layer);
        self
    }

    /// Runs `input` through every layer in insertion order.
    ///
    /// With no layers this returns a clone of `input`.
    ///
    /// # Errors
    /// Returns the first error produced by a layer's `forward`; the remaining
    /// layers are not run.
    pub fn forward(&mut self, input: &Tensor) -> Result<Tensor, BellandeError> {
        self.layers
            .iter_mut()
            .try_fold(input.clone(), |current, layer| layer.forward(&current))
    }

    /// Propagates `grad` through the layers in reverse insertion order.
    ///
    /// With no layers this returns a clone of `grad` (training mode only).
    ///
    /// # Errors
    /// - `BellandeError::InvalidBackward` if the container is in eval mode
    ///   (see [`Sequential::eval`]).
    /// - The first error produced by a layer's `backward`; earlier layers are
    ///   not run.
    pub fn backward(&mut self, grad: &Tensor) -> Result<Tensor, BellandeError> {
        // Only the mode flag is checked here: this container does not record
        // whether `forward` actually ran, despite what the message says.
        if !self.training {
            return Err(BellandeError::InvalidBackward(
                "Forward pass not called before backward".into(),
            ));
        }
        self.layers
            .iter_mut()
            .rev()
            .try_fold(grad.clone(), |current_grad, layer| {
                layer.backward(&current_grad)
            })
    }

    /// Collects the parameters of every layer, in layer order.
    pub fn parameters(&self) -> Vec<Tensor> {
        self.layers
            .iter()
            .flat_map(|layer| layer.parameters())
            .collect()
    }

    /// Number of layers in the container.
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// `true` when no layers have been added.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Returns the layer at `index`, or `None` if out of range.
    pub fn get_layer(&self, index: usize) -> Option<&Box<dyn NeuralLayer>> {
        self.layers.get(index)
    }

    /// Returns a mutable reference to the layer at `index`, or `None` if out of range.
    pub fn get_layer_mut(&mut self, index: usize) -> Option<&mut Box<dyn NeuralLayer>> {
        self.layers.get_mut(index)
    }

    /// Switches the container and every layer into training mode.
    pub fn train(&mut self) {
        self.training = true;
        for layer in &mut self.layers {
            layer.train();
        }
    }

    /// Switches the container and every layer into evaluation mode.
    ///
    /// While in this mode, [`Sequential::backward`] returns an error.
    pub fn eval(&mut self) {
        self.training = false;
        for layer in &mut self.layers {
            layer.eval();
        }
    }
}
impl Default for Sequential {
fn default() -> Self {
Self::new()
}
}