ferrite/network/module/sequential.rs

use std::collections::HashMap;
use std::sync::{Arc, RwLock};

use super::module::*;
use crate::tensor::*;

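/// A feed-forward container: layers are applied in order, each layer's
/// output becoming the next layer's input.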
pub struct Sequential {
  /// The layers, applied in insertion order.
  layers: Vec<Box<dyn Module>>,
  /// Whether the container (and its layers) are in training mode.
  training: bool,
}

impl Sequential {
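  /// Creates a container from an ordered list of layers. The container
  /// starts in evaluation mode (`training == false`).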
  pub fn new(layers: Vec<Box<dyn Module>>) -> Self {
    Self {
      layers,
      training: false,
    }
  }

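  /// Appends a layer to the end of the pipeline.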
  pub fn add(&mut self, layer: Box<dyn Module>) {
    self.layers.push(layer);
  }
}

impl Module for Sequential {
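  /// Threads the input through every layer in order and returns the
  /// final output.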
  fn forward(&mut self, input: &Tensor) -> Tensor {
    let mut current = input.clone();
    for layer in self.layers.iter_mut() {
      current = layer.forward(&current);
    }
    current
  }

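  /// Collects the parameters of every layer into one map, namespacing
  /// each name with its layer index (e.g. `layer_0.weight`).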
  fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
    let mut params = HashMap::new();
    for (idx, layer) in self.layers.iter().enumerate() {
      for (name, param) in layer.parameters() {
        let full_name = format!("layer_{}.{}", idx, name);
        params.insert(full_name, param);
      }
    }
    params
  }

  // Implemented on the trait (rather than as an inherent method) so that
  // callers holding a `Box<dyn Module>`, including a nested `Sequential`,
  // dispatch to this namespaced traversal.
  /// Visits every parameter of every layer by reference, prefixing each
  /// name with its layer index (e.g. `layer_0.weight`).
  fn visit_parameters(&self, f: &mut dyn FnMut(&str, &Tensor)) {
    for (idx, layer) in self.layers.iter().enumerate() {
      // Wrap the caller's visitor so every reported name carries a
      // `layer_{idx}.` prefix.
      let mut prefixed_f = |name: &str, tensor: &Tensor| {
        let full_name = format!("layer_{}.{}", idx, name);
        f(&full_name, tensor);
      };
      layer.visit_parameters(&mut prefixed_f);
    }
  }

  /// Puts the container and every layer into training mode.
  fn train(&mut self) {
    self.training = true;
    for layer in &mut self.layers {
      layer.train();
    }
  }

  /// Puts the container and every layer into evaluation mode.
  fn eval(&mut self) {
    self.training = false;
    for layer in &mut self.layers {
      layer.eval();
    }
  }

  /// Clears the gradients of every layer's parameters.
  fn zero_grad(&mut self) {
    for layer in &mut self.layers {
      layer.zero_grad();
    }
  }
}
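
#[cfg(test)]
mod tests {
  use super::*;

  // A minimal smoke-test sketch. `Identity` is a hypothetical stand-in
  // module defined only for this test; it assumes the `Module` trait's
  // required methods are exactly the ones `Sequential` implements above.
  struct Identity;

  impl Module for Identity {
    // Pass the input through unchanged.
    fn forward(&mut self, input: &Tensor) -> Tensor {
      input.clone()
    }

    // An identity layer has no learnable parameters.
    fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
      HashMap::new()
    }

    fn visit_parameters(&self, _f: &mut dyn FnMut(&str, &Tensor)) {}

    fn train(&mut self) {}

    fn eval(&mut self) {}

    fn zero_grad(&mut self) {}
  }

  #[test]
  fn modes_propagate_and_identity_layers_have_no_parameters() {
    let mut model = Sequential::new(vec![Box::new(Identity), Box::new(Identity)]);
    model.add(Box::new(Identity));

    // Mode flags are tracked on the container itself.
    model.train();
    assert!(model.training);
    model.eval();
    assert!(!model.training);

    // Identity layers expose no parameters, so the namespaced map is empty.
    assert!(model.parameters().is_empty());
  }
}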