ferrite/network/module/
sequential.rs1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3
4use super::module::*;
5use crate::tensor::*;
6
7
8pub struct Sequential {
9 layers: Vec<Box<dyn Module>>,
10 training: bool,
11}
12
13impl Sequential {
14 pub fn new(layers: Vec<Box<dyn Module>>) -> Self {
15 Self {
16 layers,
17 training: false,
18 }
19 }
20
21 pub fn add(&mut self, layer: Box<dyn Module>) {
22 self.layers.push(layer);
23 }
24
25 fn visit_parameters(&self, f: &mut dyn FnMut(&str, &Tensor)) {
26 for (idx, layer) in self.layers.iter().enumerate() {
27 let mut prefixed_f = |name: &str, tensor: &Tensor| {
29 let full_name = format!("layer_{}.{}", idx, name);
30 f(&full_name, tensor);
31 };
32 layer.visit_parameters(&mut prefixed_f);
33 }
34 }
35}
36
37impl Module for Sequential {
38 fn forward(&mut self, input: &Tensor) -> Tensor {
39 let mut current = input.clone();
40 for layer in self.layers.iter_mut() {
41 current = layer.forward(¤t);
42 }
43 current
44 }
45
46 fn parameters(&self) -> HashMap<String, Arc<RwLock<Tensor>>> {
47 let mut params = HashMap::new();
48 for (idx, layer) in self.layers.iter().enumerate() {
49 for (name, param) in layer.parameters() {
50 let full_name = format!("layer_{}.{}", idx, name);
51 params.insert(full_name, param);
52 }
53 }
54 params
55 }
56
57 fn train(&mut self) {
58 self.training = true;
59 for layer in &mut self.layers {
60 layer.train();
61 }
62 }
63
64 fn eval(&mut self) {
65 self.training = false;
66 for layer in &mut self.layers {
67 layer.eval();
68 }
69 }
70
71 fn zero_grad(&mut self) {
72 for layer in &mut self.layers {
73 layer.zero_grad();
74 }
75 }
76}