1mod base;
3mod config;
4mod mlp2;
5pub use base::Mlp;
6pub use config::MlpConfig;
7pub use mlp2::Mlp2;
8use tch::nn;
9
10fn mlp(prefix: &str, var_store: &nn::VarStore, config: &MlpConfig) -> nn::Sequential {
11 let mut seq = nn::seq();
12 let mut in_dim = config.in_dim;
13 let p = &(var_store.root() / "mlp");
14
15 for (i, &n) in config.units.iter().enumerate() {
16 seq = seq.add(nn::linear(
17 p / format!("{}{}", prefix, i),
18 in_dim,
19 n,
20 Default::default(),
21 ));
22 seq = seq.add_fn(|x| x.relu());
23 in_dim = n;
24 }
25
26 seq
27}