use crate::autograd::Variable;
use crate::nn::loss::Loss;
use crate::nn::Module;
use crate::optim::Optimizer;
use crate::training::TrainableModel;
use anyhow::Result;
use num_traits::Float;
use std::fmt::Debug;
/// A linear stack of layers that are applied one after another.
///
/// Layers are kept as boxed trait objects so different layer types can live in
/// the same container. An optimizer, a loss function and metric names are
/// attached afterwards via `compile`.
pub struct Sequential<T>
where
    T: Float + Send + Sync + 'static,
    T: Debug + Clone,
    T: ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Layers in execution order; `forward` feeds each layer's output into the next one.
    layers: Vec<Box<dyn Module<T> + Send + Sync>>,
    /// Becomes `true` after a successful `compile`, and `false` again after `clear`.
    compiled: bool,
    /// Optimizer handed to `compile`, if any.
    optimizer: Option<Box<dyn Optimizer + Send + Sync>>,
    /// Loss function handed to `compile`, if any.
    loss_fn: Option<Box<dyn Loss<T> + Send + Sync>>,
    /// Metric names handed to `compile`; printed by `summary`.
    metrics: Vec<String>,
    /// Mode flag: set by `train`, cleared by `eval`.
    training: bool,
    /// Optional display name, printed in the `summary` header.
    name: Option<String>,
}
impl<T> std::fmt::Debug for Sequential<T>
where
    T: Float + Send + Sync + 'static,
    T: Debug + Clone,
    T: ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Formats the model as
    /// `Sequential { layers: <count>, compiled, metrics, training, name }`.
    ///
    /// Layers are shown only as a count, and the optimizer and loss function are
    /// not printed at all. Presumably this is because the boxed trait objects
    /// carry no `Debug` bound here.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let layer_count = self.layers.len();
        let mut builder = f.debug_struct("Sequential");
        builder.field("layers", &layer_count);
        builder.field("compiled", &self.compiled);
        builder.field("metrics", &self.metrics);
        builder.field("training", &self.training);
        builder.field("name", &self.name);
        builder.finish()
    }
}
impl<T> Sequential<T>
where
    T: Float
        + Send
        + Sync
        + 'static
        + Debug
        + Clone
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
{
    /// Creates an empty, unnamed, uncompiled model with the training flag cleared.
    pub fn new() -> Self {
        Self {
            layers: Vec::new(),
            compiled: false,
            optimizer: None,
            loss_fn: None,
            metrics: Vec::new(),
            training: false,
            name: None,
        }
    }

    /// Creates an empty model carrying `name`; otherwise identical to [`Sequential::new`].
    ///
    /// The name is printed in the header of [`Sequential::summary`].
    pub fn with_name(name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..Self::new()
        }
    }

    /// Appends `layer` to the end of the stack and returns the model, for chaining.
    ///
    /// The model is taken by value, so ignoring the return value drops it.
    #[must_use]
    pub fn add<M>(mut self, layer: M) -> Self
    where
        M: Module<T> + Send + Sync + 'static,
    {
        self.layers.push(Box::new(layer));
        self
    }

    /// Appends `layer` to the end of the stack in place.
    ///
    /// The compiled flag is left untouched.
    pub fn add_layer<M>(&mut self, layer: M)
    where
        M: Module<T> + Send + Sync + 'static,
    {
        self.layers.push(Box::new(layer));
    }

    /// Inserts `layer` at position `index`, shifting later layers back by one.
    ///
    /// `index == self.len()` appends. The compiled flag is left untouched.
    ///
    /// # Errors
    ///
    /// Returns an error if `index > self.len()`; the model is not modified.
    pub fn insert<M>(&mut self, index: usize, layer: M) -> Result<()>
    where
        M: Module<T> + Send + Sync + 'static,
    {
        if index > self.layers.len() {
            return Err(anyhow::anyhow!(
                "Index {} out of bounds for {} layers",
                index,
                self.layers.len()
            ));
        }
        self.layers.insert(index, Box::new(layer));
        Ok(())
    }

    /// Removes (and drops) the layer at `index`, shifting later layers forward by one.
    ///
    /// The compiled flag is left untouched, even if this leaves the model empty.
    ///
    /// # Errors
    ///
    /// Returns an error if `index >= self.len()`; the model is not modified.
    pub fn remove(&mut self, index: usize) -> Result<()> {
        if index >= self.layers.len() {
            return Err(anyhow::anyhow!(
                "Index {} out of bounds for {} layers",
                index,
                self.layers.len()
            ));
        }
        self.layers.remove(index);
        Ok(())
    }

    /// Number of layers in the model.
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// Returns `true` if the model has no layers.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Attaches an optimizer, a loss function and metric names, and marks the model compiled.
    ///
    /// Calling this again replaces the previous configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the model has no layers; nothing is changed in that case.
    pub fn compile<O, L>(&mut self, optimizer: O, loss_fn: L, metrics: Vec<String>) -> Result<()>
    where
        O: Optimizer + Send + Sync + 'static,
        L: Loss<T> + Send + Sync + 'static,
    {
        if self.layers.is_empty() {
            return Err(anyhow::anyhow!(
                "Cannot compile empty model. Add layers first."
            ));
        }
        self.optimizer = Some(Box::new(optimizer));
        self.loss_fn = Some(Box::new(loss_fn));
        self.metrics = metrics;
        self.compiled = true;
        Ok(())
    }

    /// Returns `true` once `compile` has succeeded, unless `clear` has been called since.
    pub fn is_compiled(&self) -> bool {
        self.compiled
    }

    /// Renders a human-readable table describing the model.
    ///
    /// Each row shows the layer index, a placeholder type label and that layer's
    /// parameter count. Parameter counts are the number of `Variable`s a layer
    /// returns from `parameters()`, which may differ from the number of scalar
    /// weights. Every parameter is reported as trainable.
    ///
    /// NOTE(review): the "Output Shape" column is not populated, and layer type
    /// names are placeholders (`Layer_<i> (type)`).
    pub fn summary(&self) -> String {
        use std::fmt::Write as _;

        let mut out = String::new();
        // Writing into a `String` cannot fail, so the `fmt::Result`s below are discarded.
        match &self.name {
            Some(name) => {
                let _ = writeln!(out, "Model: \"{}\"", name);
            }
            None => out.push_str("Sequential Model\n"),
        }
        out.push_str("_________________________________________________________________\n");
        out.push_str("Layer (type) Output Shape Param # \n");
        out.push_str("=================================================================\n");

        // Per-layer counts and the total are gathered in one pass, so `parameters()`
        // is called once per layer. The total uses the same measure as `total_parameters`.
        let mut total_params = 0;
        for (i, layer) in self.layers.iter().enumerate() {
            let layer_params = layer.parameters().len();
            total_params += layer_params;
            let _ = writeln!(out, "{:<2} Layer_{} (type) {}", i, i, layer_params);
        }

        out.push_str("=================================================================\n");
        let _ = writeln!(out, "Total params: {}", total_params);
        let _ = writeln!(out, "Trainable params: {}", total_params);
        out.push_str("Non-trainable params: 0\n");
        out.push_str("_________________________________________________________________\n");

        if self.compiled {
            out.push_str("\nModel compiled with:\n");
            out.push_str(" - Optimizer: Configured\n");
            out.push_str(" - Loss: Configured\n");
            let _ = writeln!(out, " - Metrics: {:?}", self.metrics);
        } else {
            out.push_str("\nModel not compiled yet. Call compile() before training.\n");
        }
        out
    }

    /// Removes all layers and drops the compile configuration (optimizer, loss, metrics).
    ///
    /// The model name and the training flag keep their current values.
    pub fn clear(&mut self) {
        self.layers.clear();
        self.compiled = false;
        self.optimizer = None;
        self.loss_fn = None;
        self.metrics.clear();
    }

    /// Returns the layer at `index`, or `None` if `index` is out of range.
    pub fn get_layer(&self, index: usize) -> Option<&Box<dyn Module<T> + Send + Sync>> {
        self.layers.get(index)
    }

    /// All layers, in execution order.
    pub fn layers(&self) -> &[Box<dyn Module<T> + Send + Sync>] {
        &self.layers
    }

    /// Checks that the model has layers and has been compiled.
    ///
    /// # Errors
    ///
    /// Returns an error if the model has no layers. Otherwise, returns an error if
    /// the model has not been compiled.
    pub fn validate(&self) -> Result<()> {
        if self.layers.is_empty() {
            return Err(anyhow::anyhow!("Model has no layers"));
        }
        if !self.compiled {
            return Err(anyhow::anyhow!("Model is not compiled"));
        }
        Ok(())
    }

    /// Total number of parameter `Variable`s across all layers.
    ///
    /// This is the sum of each layer's `parameters().len()`. It counts `Variable`
    /// handles, which may differ from the number of scalar weights.
    pub fn total_parameters(&self) -> usize {
        self.layers
            .iter()
            .map(|layer| layer.parameters().len())
            .sum()
    }

    /// Currently identical to [`Sequential::total_parameters`]: no parameter is tracked as frozen.
    pub fn trainable_parameters(&self) -> usize {
        self.total_parameters()
    }
}
impl<T> Default for Sequential<T>
where
T: Float
+ Send
+ Sync
+ 'static
+ Debug
+ Clone
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
{
fn default() -> Self {
Self::new()
}
}
impl<T> Module<T> for Sequential<T>
where
    T: Float + Send + Sync + 'static,
    T: Debug + Clone,
    T: ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Passes `input` through every layer in order, giving each layer the previous one's output.
    ///
    /// With no layers, a clone of `input` is returned unchanged.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        self.layers
            .iter()
            .fold(input.clone(), |activation, layer| layer.forward(&activation))
    }

    /// Sets this model's training flag and calls `train` on every layer.
    fn train(&mut self) {
        self.training = true;
        self.layers.iter_mut().for_each(|layer| layer.train());
    }

    /// Clears this model's training flag and calls `eval` on every layer.
    fn eval(&mut self) {
        self.training = false;
        self.layers.iter_mut().for_each(|layer| layer.eval());
    }

    /// Concatenates the parameters returned by each layer, preserving layer order.
    fn parameters(&self) -> Vec<Variable<T>> {
        self.layers
            .iter()
            .flat_map(|layer| layer.parameters())
            .collect()
    }

    /// Exposes the model as `Any` so callers can downcast it.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
impl<T> TrainableModel<T> for Sequential<T>
where
    T: Float
        + Send
        + Sync
        + 'static
        + Debug
        + Clone
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
{
    /// Delegates to the `Module` implementation (layer-by-layer forward pass).
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        Module::forward(self, input)
    }

    /// Delegates to `Module::train`: sets the training flag and switches every layer.
    fn train(&mut self) {
        Module::train(self);
    }

    /// Delegates to `Module::eval`: clears the training flag and switches every layer.
    fn eval(&mut self) {
        Module::eval(self);
    }

    /// Borrowed view of the model's parameters.
    ///
    /// NOTE(review): this always returns an empty `Vec`. Layers are stored as
    /// `dyn Module<T>`, whose `parameters()` (as used in this file) returns owned
    /// `Variable`s, so there is nothing inside `self` to borrow from.
    /// `Module::parameters` does return the layers' parameters, as owned values.
    /// Whether those alias the model's storage depends on `Variable`; confirm.
    /// TODO: store parameters (or change the trait) so real references can be returned.
    fn parameters(&self) -> Vec<&Variable<T>> {
        Vec::new()
    }

    /// Mutable view of the model's parameters.
    ///
    /// NOTE(review): always empty, for the same reason as `parameters` above.
    fn parameters_mut(&mut self) -> Vec<&mut Variable<T>> {
        Vec::new()
    }
}
/// Fluent builder that assembles a [`Sequential`] model layer by layer.
pub struct SequentialBuilder<T>
where
    T: Float + Send + Sync + 'static,
    T: Debug + Clone,
    T: ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// The model under construction; handed out by `build`.
    model: Sequential<T>,
}
impl<T> SequentialBuilder<T>
where
    T: Float
        + Send
        + Sync
        + 'static
        + Debug
        + Clone
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
{
    /// Starts a builder for an empty, unnamed model.
    pub fn new() -> Self {
        Self {
            model: Sequential::new(),
        }
    }

    /// Starts a builder for an empty model carrying `name`.
    pub fn with_name(name: impl Into<String>) -> Self {
        Self {
            model: Sequential::with_name(name),
        }
    }

    /// Appends `layer` to the model under construction.
    ///
    /// The builder is consumed, so the returned builder must be used.
    #[must_use]
    pub fn add<M>(self, layer: M) -> Self
    where
        M: Module<T> + Send + Sync + 'static,
    {
        Self {
            model: self.model.add(layer),
        }
    }

    /// Finishes building and returns the model.
    ///
    /// The builder never calls `compile`, so the returned model is uncompiled.
    #[must_use]
    pub fn build(self) -> Sequential<T> {
        self.model
    }
}
impl<T> Default for SequentialBuilder<T>
where
T: Float
+ Send
+ Sync
+ 'static
+ Debug
+ Clone
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sequential_creation() {
        let model: Sequential<f32> = Sequential::new();
        assert_eq!(model.len(), 0);
        assert!(model.is_empty());
        assert!(!model.is_compiled());
    }

    #[test]
    fn test_sequential_with_name() {
        let model: Sequential<f32> = Sequential::with_name("test_model");
        assert_eq!(model.name, Some("test_model".to_string()));
    }

    #[test]
    fn test_sequential_default_matches_new() {
        let model: Sequential<f32> = Sequential::default();
        assert!(model.is_empty());
        assert!(!model.is_compiled());
        assert_eq!(model.name, None);
    }

    #[test]
    fn test_sequential_summary() {
        let model: Sequential<f32> = Sequential::new();
        let summary = model.summary();
        assert!(summary.contains("Sequential Model"));
        assert!(summary.contains("Total params: 0"));
    }

    #[test]
    fn test_named_summary_header() {
        let model: Sequential<f32> = Sequential::with_name("test_model");
        let summary = model.summary();
        assert!(summary.starts_with("Model: \"test_model\"\n"));
        assert!(summary.contains("Model not compiled yet"));
    }

    #[test]
    fn test_sequential_builder() {
        let model: Sequential<f32> = SequentialBuilder::new().build();
        assert_eq!(model.len(), 0);
        assert!(!model.is_compiled());
    }

    #[test]
    fn test_builder_with_name() {
        let model: Sequential<f32> = SequentialBuilder::with_name("built").build();
        assert_eq!(model.name.as_deref(), Some("built"));
    }

    #[test]
    fn test_sequential_validation() {
        let model: Sequential<f32> = Sequential::new();
        assert!(model.validate().is_err());
    }

    #[test]
    fn test_validation_message_for_empty_model() {
        let model: Sequential<f32> = Sequential::new();
        let err = model.validate().unwrap_err();
        assert!(err.to_string().contains("no layers"));
    }

    #[test]
    fn test_remove_out_of_bounds_errors() {
        let mut model: Sequential<f32> = Sequential::new();
        assert!(model.remove(0).is_err());
        assert!(model.is_empty());
    }

    #[test]
    fn test_empty_model_accessors() {
        let model: Sequential<f32> = Sequential::new();
        assert!(model.get_layer(0).is_none());
        assert!(model.layers().is_empty());
        assert_eq!(model.total_parameters(), 0);
        assert_eq!(model.trainable_parameters(), 0);
    }

    #[test]
    fn test_debug_reports_layer_count_and_name() {
        let model: Sequential<f32> = Sequential::with_name("dbg");
        let text = format!("{:?}", model);
        assert!(text.contains("layers: 0"));
        assert!(text.contains("\"dbg\""));
    }

    #[test]
    fn test_sequential_clear() {
        let mut model: Sequential<f32> = Sequential::new();
        model.clear();
        assert_eq!(model.len(), 0);
        assert!(!model.is_compiled());
    }
}