use crate::autograd::Variable;
use crate::nn::Module;
use crate::training::trainer::TrainableModel;
use num_traits::Float;
use std::fmt::Debug;
/// A container of boxed [`Module`] layers stored in the order they were added.
///
/// Alongside the layers it keeps a `training` flag and an optional name that
/// is printed by [`BasicSequential::summary`].
#[derive(Debug)]
pub struct BasicSequential<T>
where
    T: Float
        + Send
        + Sync
        + Debug
        + Clone
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
{
    /// Layers in insertion order; `forward` visits them front to back.
    layers: Vec<Box<dyn Module<T> + Send + Sync>>,
    /// Set to `true` by `train` and to `false` by `eval`; constructors start it at `false`.
    training: bool,
    /// Optional model name used as the summary header.
    name: Option<String>,
}
impl<T> BasicSequential<T>
where
    T: Float
        + Send
        + Sync
        + Debug
        + Clone
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
{
    /// Creates an unnamed model with no layers and the training flag cleared.
    pub fn new() -> Self {
        Self {
            name: None,
            training: false,
            layers: Vec::new(),
        }
    }

    /// Creates an empty model like [`BasicSequential::new`], but carrying `name`.
    pub fn with_name(name: impl Into<String>) -> Self {
        let mut model = Self::new();
        model.name = Some(name.into());
        model
    }

    /// Boxes `layer` and appends it after all previously added layers.
    pub fn add_layer<M>(&mut self, layer: M)
    where
        M: Module<T> + Send + Sync + 'static,
    {
        let boxed: Box<dyn Module<T> + Send + Sync> = Box::new(layer);
        self.layers.push(boxed);
    }

    /// Returns the number of layers currently stored.
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// Returns `true` when the model holds no layers.
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Removes every layer; the name and the training flag are left untouched.
    pub fn clear(&mut self) {
        self.layers.clear();
    }

    /// Sums, over all layers, the length of the vector returned by each
    /// layer's `parameters()`.
    ///
    /// NOTE(review): this counts returned `Variable` entries; whether one
    /// entry corresponds to a single scalar or a whole tensor depends on
    /// `Variable`, which is defined elsewhere — confirm the intended meaning.
    pub fn total_parameters(&self) -> usize {
        let mut count = 0;
        for layer in &self.layers {
            count += layer.parameters().len();
        }
        count
    }

    /// Builds a four-line, newline-terminated description of the model:
    /// a header (the quoted name, or a generic title when unnamed), the layer
    /// count, the value of [`BasicSequential::total_parameters`], and the
    /// training flag.
    pub fn summary(&self) -> String {
        let header = match self.name {
            Some(ref name) => format!("Model: \"{}\"\n", name),
            None => String::from("Basic Sequential Model\n"),
        };
        let sections = [
            header,
            format!("Layers: {}\n", self.layers.len()),
            format!("Total parameters: {}\n", self.total_parameters()),
            format!("Training mode: {}\n", self.training),
        ];
        sections.concat()
    }
}
impl<
T: Float
+ Send
+ Sync
+ Debug
+ Clone
+ 'static
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
> Default for BasicSequential<T>
{
fn default() -> Self {
Self::new()
}
}
impl<T> Module<T> for BasicSequential<T>
where
    T: Float
        + Send
        + Sync
        + Debug
        + Clone
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
{
    /// Runs `input` through every layer in insertion order, feeding each
    /// layer's output to the next one.
    ///
    /// With no layers the result is simply a clone of `input`.
    fn forward(&self, input: &Variable<T>) -> Variable<T> {
        self.layers
            .iter()
            .fold(input.clone(), |activation, layer| layer.forward(&activation))
    }

    /// Sets the training flag and calls `train` on every layer.
    fn train(&mut self) {
        self.training = true;
        self.layers.iter_mut().for_each(|layer| layer.train());
    }

    /// Clears the training flag and calls `eval` on every layer.
    fn eval(&mut self) {
        self.training = false;
        self.layers.iter_mut().for_each(|layer| layer.eval());
    }

    /// Concatenates the vectors returned by each layer's `parameters()`,
    /// preserving layer order.
    fn parameters(&self) -> Vec<Variable<T>> {
        self.layers
            .iter()
            .flat_map(|layer| layer.parameters())
            .collect()
    }

    /// Exposes the model as `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
impl<
T: Float
+ Send
+ Sync
+ Debug
+ Clone
+ 'static
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
> TrainableModel<T> for BasicSequential<T>
{
fn forward(&self, input: &Variable<T>) -> Variable<T> {
Module::forward(self, input)
}
fn train(&mut self) {
Module::train(self);
}
fn eval(&mut self) {
Module::eval(self);
}
fn parameters(&self) -> Vec<&Variable<T>> {
Vec::new()
}
fn parameters_mut(&mut self) -> Vec<&mut Variable<T>> {
Vec::new()
}
}
/// Consuming builder that assembles a [`BasicSequential`] layer by layer.
///
/// Derives `Debug` so that the builder, like the model it wraps, can be
/// inspected in diagnostics and embedded in other `Debug` types.
#[derive(Debug)]
pub struct BasicSequentialBuilder<
    T: Float
        + Send
        + Sync
        + Debug
        + Clone
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
> {
    /// The model under construction; handed out unchanged by `build`.
    model: BasicSequential<T>,
}
impl<T> BasicSequentialBuilder<T>
where
    T: Float
        + Send
        + Sync
        + Debug
        + Clone
        + 'static
        + ndarray::ScalarOperand
        + num_traits::FromPrimitive,
{
    /// Starts a builder around an unnamed, empty [`BasicSequential`].
    pub fn new() -> Self {
        let model = BasicSequential::new();
        Self { model }
    }

    /// Starts a builder around an empty [`BasicSequential`] carrying `name`.
    pub fn with_name(name: impl Into<String>) -> Self {
        let model = BasicSequential::with_name(name);
        Self { model }
    }

    /// Appends `layer` to the model being built and returns the builder so
    /// that calls can be chained.
    pub fn add<M>(self, layer: M) -> Self
    where
        M: Module<T> + Send + Sync + 'static,
    {
        let mut model = self.model;
        model.add_layer(layer);
        Self { model }
    }

    /// Consumes the builder and returns the assembled model.
    pub fn build(self) -> BasicSequential<T> {
        let Self { model } = self;
        model
    }
}
impl<
T: Float
+ Send
+ Sync
+ Debug
+ Clone
+ 'static
+ ndarray::ScalarOperand
+ num_traits::FromPrimitive,
> Default for BasicSequentialBuilder<T>
{
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    /// A freshly constructed model holds no layers and reports no parameters.
    #[test]
    fn test_basic_sequential_creation() {
        let model: BasicSequential<f32> = BasicSequential::new();
        assert_eq!(model.len(), 0);
        assert!(model.is_empty());
        assert_eq!(model.total_parameters(), 0);
    }

    /// The name given at construction appears, quoted, in the summary header.
    #[test]
    fn test_basic_sequential_with_name() {
        let model: BasicSequential<f32> = BasicSequential::with_name("test_model");
        let summary = model.summary();
        assert!(summary.contains("test_model"));
        assert!(summary.contains("Model: \"test_model\""));
    }

    /// Both builder entry points yield an empty model; `with_name` keeps the name.
    #[test]
    fn test_basic_sequential_builder() {
        let model: BasicSequential<f32> = BasicSequentialBuilder::new().build();
        assert_eq!(model.len(), 0);

        let named: BasicSequential<f32> = BasicSequentialBuilder::with_name("built").build();
        assert!(named.is_empty());
        assert!(named.summary().contains("Model: \"built\""));
    }

    /// An unnamed, empty model prints the generic header and zeroed counters.
    #[test]
    fn test_basic_sequential_summary() {
        let model: BasicSequential<f32> = BasicSequential::new();
        let summary = model.summary();
        assert!(summary.contains("Basic Sequential Model"));
        assert!(summary.contains("Layers: 0"));
        assert!(summary.contains("Total parameters: 0"));
        assert!(summary.contains("Training mode: false"));
    }

    /// Clearing an already empty model leaves it empty.
    #[test]
    fn test_basic_sequential_clear() {
        let mut model: BasicSequential<f32> = BasicSequential::new();
        model.clear();
        assert_eq!(model.len(), 0);
        assert!(model.is_empty());
    }

    /// `train` and `eval` flip the training flag, and the summary reflects it.
    #[test]
    fn test_basic_sequential_training_mode() {
        let mut model: BasicSequential<f32> = BasicSequential::new();
        // Constructors start with the flag cleared.
        assert!(!model.training);

        Module::train(&mut model);
        assert!(model.training);
        assert!(model.summary().contains("Training mode: true"));

        Module::eval(&mut model);
        assert!(!model.training);
        assert!(model.summary().contains("Training mode: false"));

        // Switching modes must not touch the layer list.
        assert_eq!(model.len(), 0);
    }

    /// Forwarding through an empty model still produces a shaped output.
    #[test]
    fn test_basic_sequential_forward() {
        let model: BasicSequential<f32> = BasicSequential::new();
        let input_data = vec![1.0, 2.0, 3.0];
        let input_tensor = Tensor::from_vec(input_data, vec![1, 3]);
        let input_var = Variable::new(input_tensor, false);
        let output = Module::forward(&model, &input_var);
        assert!(output.data().read().unwrap().shape().len() > 0);
    }
}