1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
//! Multilayer perceptron.
mod base;
mod config;
mod mlp2;
pub use base::MLP;
pub use config::MLPConfig;
pub use mlp2::MLP2;
use tch::{nn};

/// Builds a stack of `Linear -> ReLU` pairs, one pair per entry of `config.units`.
///
/// The first linear layer reads `config.in_dim` features. Each later layer reads the output
/// width of the layer before it. Layer `k` (1-based) registers its parameters under the
/// sub-path `"{prefix}{k}"` of `var_store.root()`.
///
/// A ReLU follows every linear layer, including the last one. Appending any final
/// projection is left to the caller. If `config.units` is empty, the returned
/// [`nn::Sequential`] contains no layers.
fn mlp(prefix: &str, var_store: &nn::VarStore, config: &MLPConfig) -> nn::Sequential {
    let root = var_store.root();

    // Thread the running input width through the fold alongside the growing sequence;
    // only the sequence is needed once every layer has been appended.
    let (layers, _out_width) = config.units.iter().enumerate().fold(
        (nn::seq(), config.in_dim),
        |(layers, width), (idx, &units)| {
            let linear = nn::linear(
                &root / format!("{}{}", prefix, idx + 1),
                width,
                units,
                Default::default(),
            );
            (layers.add(linear).add_fn(|x| x.relu()), units)
        },
    );

    layers
}