use super::neural_network_trait::{Layer, LossFunction, Optimizer};
use crate::error::{IoError, ModelError};
use crate::neural_network::Tensor;
use crate::neural_network::layer::TrainingParameters;
use crate::neural_network::layer::layer_weight::LayerWeight;
use crate::neural_network::layer::serialize_weight::{
    LayerInfo, SerializableLayer, SerializableLayerWeight, SerializableSequential,
    apply_weights_to_layer,
};
use indicatif::{ProgressBar, ProgressStyle};
use ndarray::{Array, Axis, IxDyn, s};
use ndarray_rand::rand::{rng, seq::SliceRandom};
use serde_json::{from_reader, to_writer_pretty};
use std::fs::File;
use std::io::{BufWriter, Write};
/// A feed-forward model that applies its layers in the order they were
/// added via `add`. Training requires `compile` to be called first so an
/// optimizer and loss function are available.
pub struct Sequential {
// Layers in forward-pass order; each is trait-object boxed.
layers: Vec<Box<dyn Layer>>,
// Set by `compile`; training validation fails while this is `None`.
optimizer: Option<Box<dyn Optimizer>>,
// Set by `compile`; training validation fails while this is `None`.
loss: Option<Box<dyn LossFunction>>,
}
impl Sequential {
pub fn new() -> Self {
Self {
layers: Vec::new(),
optimizer: None,
loss: None,
}
}
/// Appends `layer` to the end of the model.
///
/// Returns `&mut Self` so calls can be chained.
pub fn add<L: 'static + Layer>(&mut self, layer: L) -> &mut Self {
    let boxed: Box<dyn Layer> = Box::new(layer);
    self.layers.push(boxed);
    self
}
/// Configures the optimizer and loss function used by `fit` /
/// `fit_with_batches`. Any previously configured pair is replaced.
///
/// Returns `&mut Self` so calls can be chained.
pub fn compile<O, LFunc>(&mut self, optimizer: O, loss: LFunc) -> &mut Self
where
    O: 'static + Optimizer,
    LFunc: 'static + LossFunction,
{
    let boxed_optimizer: Box<dyn Optimizer> = Box::new(optimizer);
    let boxed_loss: Box<dyn LossFunction> = Box::new(loss);
    self.optimizer = Some(boxed_optimizer);
    self.loss = Some(boxed_loss);
    self
}
/// Checks that the model is ready to train (compiled, has layers) and that
/// `x`/`y` are non-empty with matching sample counts along axis 0.
///
/// # Errors
/// Returns `ModelError::InputValidationError` describing the first failed
/// check. Fix: 0-dimensional tensors previously slipped past the
/// `is_empty` guard (a 0-d array holds one element) and made
/// `x.shape()[0]` panic; they now produce a validation error instead.
fn validate_training_inputs(&self, x: &Tensor, y: &Tensor) -> Result<(), ModelError> {
    if self.optimizer.is_none() {
        return Err(ModelError::InputValidationError(
            "Optimizer not specified".to_string(),
        ));
    }
    if self.loss.is_none() {
        return Err(ModelError::InputValidationError(
            "Loss function not specified".to_string(),
        ));
    }
    if self.layers.is_empty() {
        return Err(ModelError::InputValidationError(
            "Layers not specified".to_string(),
        ));
    }
    if x.is_empty() || y.is_empty() {
        return Err(ModelError::InputValidationError(
            "Input tensors cannot be empty".to_string(),
        ));
    }
    // Guard before indexing shape()[0]: a 0-d tensor has an empty shape.
    if x.ndim() == 0 || y.ndim() == 0 {
        return Err(ModelError::InputValidationError(
            "Input tensors must have at least one dimension".to_string(),
        ));
    }
    if x.shape()[0] != y.shape()[0] {
        return Err(ModelError::InputValidationError(format!(
            "Batch size mismatch: input has {} samples, target has {} samples",
            x.shape()[0],
            y.shape()[0]
        )));
    }
    Ok(())
}
fn train_batch(&mut self, x: &Tensor, y: &Tensor) -> Result<f32, ModelError> {
let mut layers_iter = self.layers.iter_mut();
let first_layer = layers_iter
.next()
.ok_or_else(|| ModelError::InputValidationError("No layers in model".to_string()))?;
first_layer.set_training_if_mode_dependent(true);
let mut output = first_layer.forward(x)?;
for layer in layers_iter {
layer.set_training_if_mode_dependent(true);
output = layer.forward(&output)?;
}
let loss_value = self.loss.as_ref().unwrap().compute_loss(y, &output);
let mut grad = self.loss.as_ref().unwrap().compute_grad(y, &output);
for layer in self.layers.iter_mut().rev() {
grad = match layer.backward(&grad) {
Ok(grad) => grad,
Err(e) => return Err(e),
};
if let Some(ref mut optimizer) = self.optimizer {
optimizer.update(&mut **layer);
}
}
Ok(loss_value)
}
/// Trains the network on the full dataset for `epochs` epochs (one
/// full-batch gradient step per epoch), showing a progress bar with the
/// current loss.
///
/// # Errors
/// Returns `ModelError::InputValidationError` when the model is not
/// compiled or the inputs are invalid; forward/backward errors propagate.
pub fn fit(&mut self, x: &Tensor, y: &Tensor, epochs: u32) -> Result<&mut Self, ModelError> {
    self.validate_training_inputs(x, y)?;
    let n_samples = x.shape()[0];
    let bar_style = ProgressStyle::default_bar()
        .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Loss: {msg}")
        .expect("Failed to set progress bar template")
        .progress_chars("█▓░");
    let progress_bar = ProgressBar::new(u64::from(epochs));
    progress_bar.set_style(bar_style);
    let mut remaining = epochs;
    while remaining > 0 {
        let loss_value = self.train_batch(x, y)?;
        progress_bar.set_message(format!("{:.6}", loss_value));
        progress_bar.inc(1);
        remaining -= 1;
    }
    progress_bar.finish_with_message("Training completed");
    println!(
        "\nNeural network training completed: {} samples, {} epochs",
        n_samples, epochs
    );
    Ok(self)
}
pub fn fit_with_batches(
&mut self,
x: &Tensor,
y: &Tensor,
epochs: u32,
batch_size: usize,
) -> Result<&mut Self, ModelError> {
self.validate_training_inputs(x, y)?;
let n_samples = x.shape()[0];
if batch_size == 0 {
return Err(ModelError::InputValidationError(
"Batch size must be greater than 0".to_string(),
));
}
if batch_size > n_samples {
return Err(ModelError::InputValidationError(format!(
"Batch size ({}) cannot be larger than dataset size ({})",
batch_size, n_samples
)));
}
let create_batch_tensors =
|x: &Tensor, y: &Tensor, indices: &[usize]| -> Result<(Tensor, Tensor), ModelError> {
let batch_size = indices.len();
let mut x_batch_shape = x.shape().to_vec();
x_batch_shape[0] = batch_size;
let mut y_batch_shape = y.shape().to_vec();
y_batch_shape[0] = batch_size;
let mut x_batch_data = Vec::new();
let mut y_batch_data = Vec::new();
for &idx in indices {
let x_sample = x.slice(s![idx, ..]);
x_batch_data.extend(x_sample.iter().cloned());
let y_sample = y.slice(s![idx, ..]);
y_batch_data.extend(y_sample.iter().cloned());
}
let x_batch =
Array::from_shape_vec(IxDyn(&x_batch_shape), x_batch_data).map_err(|e| {
ModelError::ProcessingError(format!(
"Failed to create batch tensor for x: {}",
e
))
})?;
let y_batch =
Array::from_shape_vec(IxDyn(&y_batch_shape), y_batch_data).map_err(|e| {
ModelError::ProcessingError(format!(
"Failed to create batch tensor for y: {}",
e
))
})?;
Ok((x_batch, y_batch))
};
let mut indices: Vec<usize> = (0..n_samples).collect();
let total_batches = (n_samples + batch_size - 1) / batch_size;
let total_iterations = epochs as u64 * total_batches as u64;
let progress_bar = ProgressBar::new(total_iterations);
progress_bar.set_style(
ProgressStyle::default_bar()
.template("[{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} | Epoch {msg}")
.expect("Failed to set progress bar template")
.progress_chars("█▓░"),
);
for epoch in 0..epochs {
indices.shuffle(&mut rng());
let mut epoch_loss = 0.0;
let mut batch_count = 0;
for batch_indices in indices.chunks(batch_size) {
batch_count += 1;
let (batch_x, batch_y) = create_batch_tensors(x, y, batch_indices)?;
let batch_loss = self.train_batch(&batch_x, &batch_y)?;
epoch_loss += batch_loss;
let avg_loss = epoch_loss / batch_count as f32;
progress_bar.set_message(format!(
"{}/{} | Avg Loss: {:.6}",
epoch + 1,
epochs,
avg_loss
));
progress_bar.inc(1);
}
}
progress_bar.finish_with_message("Training completed");
println!(
"\nNeural network batch training completed: {} samples, {} batch size, {} epochs",
n_samples, batch_size, epochs
);
Ok(self)
}
/// Runs a forward pass in inference mode (mode-dependent layers such as
/// dropout are switched to evaluation behavior) and returns the output.
///
/// # Errors
/// Returns `ModelError::InputValidationError` when `x` is empty or the
/// model has no layers; layer `forward` errors propagate.
pub fn predict(&mut self, x: &Tensor) -> Result<Tensor, ModelError> {
    if x.is_empty() {
        return Err(ModelError::InputValidationError(
            "Input tensor cannot be empty".to_string(),
        ));
    }
    if self.layers.is_empty() {
        return Err(ModelError::InputValidationError(
            "Model has no layers".to_string(),
        ));
    }
    // Thread the activation through every layer; `None` means "not started
    // yet", in which case the first layer consumes `x` directly.
    let mut current: Option<Tensor> = None;
    for layer in self.layers.iter_mut() {
        layer.set_training_if_mode_dependent(false);
        let next = match current.take() {
            None => layer.forward(x)?,
            Some(t) => layer.forward(&t)?,
        };
        current = Some(next);
    }
    Ok(current.expect("model verified non-empty above"))
}
/// Prints a Keras-style table of the model's layers: layer name/type,
/// output shape, and parameter count, followed by total, trainable and
/// non-trainable parameter counts (byte sizes assume 4 bytes/param, f32).
pub fn summary(&self) {
    const COL1: usize = 33;
    const COL2: usize = 24;
    const COL3: usize = 15;
    // Prints one horizontal rule with the given corner/fill/junction chars.
    let rule = |left: &str, fill: &str, mid: &str, right: &str| {
        println!(
            "{}{}{}{}{}{}{}",
            left,
            fill.repeat(COL1),
            mid,
            fill.repeat(COL2),
            mid,
            fill.repeat(COL3),
            right
        );
    };
    println!("Model: \"sequential\"");
    rule("┏", "━", "┳", "┓");
    println!(
        "┃ {:<31} ┃ {:<22} ┃ {:>13} ┃",
        "Layer (type)", "Output Shape", "Param #"
    );
    rule("┡", "━", "╇", "┩");
    let mut total_params: usize = 0;
    let mut trainable_param_count: usize = 0;
    let mut non_trainable_param_count: usize = 0;
    for (i, layer) in self.layers.iter().enumerate() {
        // First layer is shown as plain "Layer", the rest as "Layer_<i>".
        let layer_name = match i {
            0 => "Layer".to_string(),
            n => format!("Layer_{}", n),
        };
        let shown_count = match layer.param_count() {
            TrainingParameters::Trainable(count) => {
                trainable_param_count += count;
                total_params += count;
                count
            }
            TrainingParameters::NonTrainable(count) => {
                non_trainable_param_count += count;
                total_params += count;
                count
            }
            _ => 0,
        };
        println!(
            "│ {:<31} │ {:<22} │ {:>13} │",
            format!("{} ({})", layer_name, layer.layer_type()),
            layer.output_shape(),
            shown_count
        );
    }
    rule("└", "─", "┴", "┘");
    // Byte figures assume 4 bytes per parameter (f32 weights).
    println!(" Total params: {} ({} B)", total_params, total_params * 4);
    println!(
        " Trainable params: {} ({} B)",
        trainable_param_count,
        trainable_param_count * 4
    );
    println!(
        " Non-trainable params: {} ({} B)",
        non_trainable_param_count,
        non_trainable_param_count * 4
    );
}
/// Returns a borrowed view of every layer's weights, in layer order.
pub fn get_weights(&self) -> Vec<LayerWeight<'_>> {
    self.layers
        .iter()
        .map(|layer| layer.get_weights())
        .collect()
}
/// Serializes layer metadata and weights to `path` as pretty-printed JSON.
///
/// # Errors
/// `IoError::StdIoError` on file create/flush failure,
/// `IoError::JsonError` on serialization failure.
pub fn save_to_path(&self, path: &str) -> Result<(), IoError> {
    let mut serializable_layers = Vec::with_capacity(self.layers.len());
    for layer in &self.layers {
        let weights = layer.get_weights();
        serializable_layers.push(SerializableLayer {
            info: LayerInfo {
                layer_type: layer.layer_type().to_string(),
                output_shape: layer.output_shape(),
            },
            weights: SerializableLayerWeight::from_layer_weight(&weights),
        });
    }
    let serializable_model = SerializableSequential {
        layers: serializable_layers,
    };
    let file = File::create(path).map_err(IoError::StdIoError)?;
    let mut writer = BufWriter::new(file);
    to_writer_pretty(&mut writer, &serializable_model).map_err(IoError::JsonError)?;
    // Flush explicitly: Drop would swallow any write error.
    writer.flush().map_err(IoError::StdIoError)?;
    Ok(())
}
/// Loads weights from a JSON file previously written by `save_to_path` and
/// applies them to this model's layers, position by position.
///
/// The model must already contain the same number of layers as the file.
///
/// # Errors
/// `IoError::JsonError` on parse failure; `IoError::StdIoError` with
/// `InvalidData` when the layer counts differ; weight-application errors
/// propagate.
pub fn load_from_path(&mut self, path: &str) -> Result<(), IoError> {
    let reader = IoError::load_in_buf_reader(path)?;
    let serializable_model: SerializableSequential =
        from_reader(reader).map_err(IoError::JsonError)?;
    let expected = self.layers.len();
    let found = serializable_model.layers.len();
    if found != expected {
        return Err(IoError::StdIoError(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!(
                "Layer count mismatch: model has {} layers, file has {} layers",
                expected, found
            ),
        )));
    }
    // Counts match, so zip pairs every layer with its serialized weights.
    for (layer, serialized) in self.layers.iter_mut().zip(serializable_model.layers.iter()) {
        apply_weights_to_layer(
            &mut **layer,
            &serialized.weights,
            &serialized.info.layer_type,
        )?;
    }
    Ok(())
}
}