use crate::error::{NeuralError, Result};
use crate::layers::{Layer, ParamLayer};
use crate::losses::Loss;
use crate::models::{History, Model, TrainingConfig};
use crate::optimizers::Optimizer;
use scirs2_core::ndarray::{Array, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use std::fmt::{Debug, Display};
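/// A feed-forward model that applies its layers in the order they were added.
///
/// During training the model caches the original input and every layer's output so
/// that `backward` can replay the forward pass, and it records per-epoch losses in a
/// [`History`].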
pub struct Sequential<F: Float + Debug + ScalarOperand + NumAssign + 'static> {
layers: Vec<Box<dyn Layer<F> + Send + Sync>>,
layer_outputs: Vec<Array<F, scirs2_core::ndarray::IxDyn>>,
input: Option<Array<F, scirs2_core::ndarray::IxDyn>>,
history: History<F>,
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + 'static> Default
for Sequential<F>
{
fn default() -> Self {
Self::new()
}
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + 'static> Clone
for Sequential<F>
{
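// Boxed `dyn Layer` values cannot be cloned, so `clone` returns an empty model with a
// fresh history rather than a deep copy of the layers.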
fn clone(&self) -> Self {
Sequential {
layers: Vec::new(),
layer_outputs: Vec::new(),
input: None,
history: History::default(),
}
}
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + 'static>
Sequential<F>
{
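/// Creates an empty model with no layers and an empty training history.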
pub fn new() -> Self {
Self {
layers: Vec::new(),
layer_outputs: Vec::new(),
input: None,
history: History::default(),
}
}
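/// Builds a model from an already-constructed list of boxed layers.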
pub fn from_layers(layers: Vec<Box<dyn Layer<F> + Send + Sync>>) -> Self {
Self {
layers,
layer_outputs: Vec::new(),
input: None,
history: History::default(),
}
}
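/// Appends a layer to the end of the model and returns `&mut self` so calls can be chained.
///
/// # Example (sketch; `dense1` and `dense2` stand for any values implementing `Layer<f32> + Send + Sync`)
/// ```ignore
/// let mut model = Sequential::<f32>::new();
/// model.add_layer(dense1).add_layer(dense2);
/// ```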
pub fn add_layer<L: Layer<F> + 'static + Send + Sync>(&mut self, layer: L) -> &mut Self {
self.layers.push(Box::new(layer));
self
}
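/// Returns the number of layers in the model.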
pub fn num_layers(&self) -> usize {
self.layers.len()
}
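/// Returns an immutable view of the layers.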
pub fn layers(&self) -> &[Box<dyn Layer<F> + Send + Sync>] {
&self.layers
}
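/// Returns a mutable reference to the layer list.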
pub fn layers_mut(&mut self) -> &mut Vec<Box<dyn Layer<F> + Send + Sync>> {
&mut self.layers
}
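/// Returns the recorded training history.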
pub fn training_history(&self) -> &History<F> {
&self.history
}
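/// Returns a mutable reference to the training history.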
pub fn training_history_mut(&mut self) -> &mut History<F> {
&mut self.history
}
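/// Runs inference in batches of `batch_size` along axis 0 and concatenates the
/// per-batch outputs, which bounds peak memory for large inputs.
///
/// # Example (sketch; `model` and `x` are assumed to exist)
/// ```ignore
/// let predictions = model.predict_batched(&x, 64)?;
/// ```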
pub fn predict_batched(
&self,
inputs: &Array<F, scirs2_core::ndarray::IxDyn>,
batch_size: usize,
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
let input_shape = inputs.shape();
let num_samples = input_shape[0];
let mut outputs = Vec::new();
for i in (0..num_samples).step_by(batch_size) {
let end_idx = std::cmp::min(i + batch_size, num_samples);
let batch = inputs
.slice(scirs2_core::ndarray::s![i..end_idx, ..])
.to_owned()
.into_dyn();
let batch_output = self.forward(&batch)?;
outputs.push(batch_output);
}
if outputs.len() == 1 {
Ok(outputs.into_iter().next().expect("exactly one batch output"))
} else {
let mut concatenated = outputs[0].clone();
for output in outputs.into_iter().skip(1) {
concatenated = scirs2_core::ndarray::concatenate![
scirs2_core::ndarray::Axis(0),
concatenated,
output
];
}
Ok(concatenated)
}
}
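/// Trains the model for `config.epochs` epochs of mini-batch gradient descent.
///
/// If `config.validation_split > 0.0`, the trailing fraction of the data is held out
/// and evaluated after every epoch; otherwise the training loss is recorded in both
/// history tracks. Losses are appended to the model's [`History`].
///
/// # Example (sketch; `x`, `y`, `loss`, and `opt` are assumed to exist, and
/// `TrainingConfig` is assumed to implement `Default`)
/// ```ignore
/// let config = TrainingConfig {
///     epochs: 10,
///     batch_size: 32,
///     validation_split: 0.2,
///     verbose: 1,
///     ..Default::default()
/// };
/// model.fit(&x, &y, &config, &loss, &mut opt)?;
/// ```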
pub fn fit(
&mut self,
x_train: &Array<F, scirs2_core::ndarray::IxDyn>,
y_train: &Array<F, scirs2_core::ndarray::IxDyn>,
config: &TrainingConfig,
loss_fn: &dyn Loss<F>,
optimizer: &mut dyn Optimizer<F>,
) -> Result<()> {
let num_samples = x_train.shape()[0];
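// When validation_split > 0, hold out the trailing fraction of the (unshuffled)
// samples for validation.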
let val_split_idx = if config.validation_split > 0.0 {
((1.0 - config.validation_split) * num_samples as f64) as usize
} else {
num_samples
};
let (x_train_split, x_val) = if config.validation_split > 0.0 {
let x_train_split = x_train
.slice(scirs2_core::ndarray::s![0..val_split_idx, ..])
.to_owned()
.into_dyn();
let x_val = x_train
.slice(scirs2_core::ndarray::s![val_split_idx.., ..])
.to_owned()
.into_dyn();
(x_train_split, Some(x_val))
} else {
(x_train.clone(), None)
};
let (y_train_split, y_val) = if config.validation_split > 0.0 {
let y_train_split = y_train
.slice(scirs2_core::ndarray::s![0..val_split_idx, ..])
.to_owned()
.into_dyn();
let y_val = y_train
.slice(scirs2_core::ndarray::s![val_split_idx.., ..])
.to_owned()
.into_dyn();
(y_train_split, Some(y_val))
} else {
(y_train.clone(), None)
};
for epoch in 0..config.epochs {
let mut epoch_loss = F::zero();
let num_batches = x_train_split.shape()[0].div_ceil(config.batch_size);
for i in 0..num_batches {
let start_idx = i * config.batch_size;
let end_idx =
std::cmp::min(start_idx + config.batch_size, x_train_split.shape()[0]);
let batch_x = x_train_split
.slice(scirs2_core::ndarray::s![start_idx..end_idx, ..])
.to_owned()
.into_dyn();
let batch_y = y_train_split
.slice(scirs2_core::ndarray::s![start_idx..end_idx, ..])
.to_owned()
.into_dyn();
let batch_loss = self.train_batch(&batch_x, &batch_y, loss_fn, optimizer)?;
epoch_loss += batch_loss;
}
let avg_train_loss =
epoch_loss / F::from_usize(num_batches).unwrap_or_else(|| F::one());
self.history.train_loss.push(avg_train_loss);
if let (Some(x_val), Some(y_val)) = (&x_val, &y_val) {
let val_loss = self.evaluate(x_val, y_val, loss_fn)?;
self.history.val_loss.push(val_loss);
} else {
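// No validation data: record the training loss so both history tracks stay the same length.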
self.history.val_loss.push(avg_train_loss);
}
if config.verbose > 0 {
println!(
"Epoch {}/{} - loss: {:.4}",
epoch + 1,
config.epochs,
avg_train_loss
);
}
}
Ok(())
}
}
impl<F: Float + Debug + ScalarOperand + FromPrimitive + Display + NumAssign + 'static> Model<F>
for Sequential<F>
{
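/// Feeds `input` through every layer in order and returns the final output.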
fn forward(
&self,
input: &Array<F, scirs2_core::ndarray::IxDyn>,
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
let mut current_output = input.clone();
for layer in &self.layers {
current_output = layer.forward(&current_output)?;
}
Ok(current_output)
}
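/// Back-propagates `grad_output` through the layers using the activations cached by
/// the most recent `train_batch` call; errors if no forward pass has been cached.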
fn backward(
&self,
input: &Array<F, scirs2_core::ndarray::IxDyn>,
grad_output: &Array<F, scirs2_core::ndarray::IxDyn>,
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
if self.layer_outputs.is_empty() {
return Err(NeuralError::InferenceError(
"No forward pass performed before backward pass".to_string(),
));
}
let mut grad_input = grad_output.clone();
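// Each layer gets the same input it saw on the forward pass: the previous layer's
// cached output, or the cached (or provided) model input for the first layer.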
for (i, layer) in self.layers.iter().enumerate().rev() {
let layer_input = if i > 0 {
&self.layer_outputs[i - 1]
} else if let Some(saved_input) = &self.input {
saved_input
} else {
input
};
grad_input = layer.backward(layer_input, &grad_input)?;
}
Ok(grad_input)
}
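/// Applies each layer's own gradient-descent update with the given learning rate.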
fn update(&mut self, learning_rate: F) -> Result<()> {
for layer in &mut self.layers {
layer.update(learning_rate)?;
}
Ok(())
}
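/// Runs one optimization step on a single mini-batch: a forward pass that caches every
/// layer output, the loss and its gradient, a backward pass, and a parameter update
/// through the optimizer. Returns the batch loss.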
fn train_batch(
&mut self,
inputs: &Array<F, scirs2_core::ndarray::IxDyn>,
targets: &Array<F, scirs2_core::ndarray::IxDyn>,
loss_fn: &dyn Loss<F>,
optimizer: &mut dyn Optimizer<F>,
) -> Result<F> {
let mut layer_outputs = Vec::with_capacity(self.layers.len());
let mut current_output = inputs.clone();
for layer in &self.layers {
current_output = layer.forward(&current_output)?;
layer_outputs.push(current_output.clone());
}
self.input = Some(inputs.clone());
self.layer_outputs = layer_outputs;
let predictions = self
.layer_outputs
.last()
.ok_or_else(|| NeuralError::InferenceError("No layers in model".to_string()))?;
let loss = loss_fn.forward(predictions, targets)?;
let loss_grad = loss_fn.backward(predictions, targets)?;
let mut grad_input = loss_grad;
for (i, layer) in self.layers.iter_mut().enumerate().rev() {
let layer_input = if i > 0 {
&self.layer_outputs[i - 1]
} else {
inputs
};
grad_input = layer.backward(layer_input, &grad_input)?;
}
let mut all_params = Vec::new();
let mut all_grads = Vec::new();
let mut param_layers = Vec::new();
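// Collect parameters and gradients from layers that are stored as boxed
// `ParamLayer` trait objects; other layers are skipped.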
for (i, layer) in self.layers.iter().enumerate() {
if let Some(param_layer) = layer
.as_any()
.downcast_ref::<Box<dyn ParamLayer<F> + Send + Sync>>()
{
param_layers.push(i);
for param in param_layer.get_parameters() {
all_params.push(param.clone());
}
for grad in param_layer.get_gradients() {
all_grads.push(grad.clone());
}
}
}
optimizer.update(&mut all_params, &all_grads)?;
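// Write the optimizer-updated parameters back to the layers they came from,
// in the same order they were collected.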
let mut param_idx = 0;
for i in param_layers {
if let Some(param_layer) = self.layers[i]
.as_any_mut()
.downcast_mut::<Box<dyn ParamLayer<F> + Send + Sync>>()
{
let num_params = param_layer.get_parameters().len();
if param_idx + num_params <= all_params.len() {
let layer_params = all_params[param_idx..param_idx + num_params].to_vec();
param_layer.set_parameters(layer_params)?;
param_idx += num_params;
}
}
}
Ok(loss)
}
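/// Inference entry point; identical to `forward`.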
fn predict(
&self,
inputs: &Array<F, scirs2_core::ndarray::IxDyn>,
) -> Result<Array<F, scirs2_core::ndarray::IxDyn>> {
self.forward(inputs)
}
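/// Computes predictions for `inputs` and returns the loss against `targets`.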
fn evaluate(
&self,
inputs: &Array<F, scirs2_core::ndarray::IxDyn>,
targets: &Array<F, scirs2_core::ndarray::IxDyn>,
loss_fn: &dyn Loss<F>,
) -> Result<F> {
let predictions = self.forward(inputs)?;
loss_fn.forward(&predictions, targets)
}
}