border_tch_agent/
mlp.rs

//! Multilayer perceptron.
2mod base;
3mod config;
4mod mlp2;
5pub use base::Mlp;
6pub use config::MlpConfig;
7pub use mlp2::Mlp2;
8use tch::nn;
9
10fn mlp(prefix: &str, var_store: &nn::VarStore, config: &MlpConfig) -> nn::Sequential {
11    let mut seq = nn::seq();
12    let mut in_dim = config.in_dim;
13    let p = &(var_store.root() / "mlp");
14
15    for (i, &n) in config.units.iter().enumerate() {
16        seq = seq.add(nn::linear(
17            p / format!("{}{}", prefix, i),
18            in_dim,
19            n,
20            Default::default(),
21        ));
22        seq = seq.add_fn(|x| x.relu());
23        in_dim = n;
24    }
25
26    seq
27}