border_tch_agent/mlp/
base.rs

1use super::{mlp, MlpConfig};
2use crate::model::{SubModel, SubModel2};
3use tch::{nn, nn::Module, Device, Tensor};
4
5/// Multilayer perceptron with ReLU activation function.
6pub struct Mlp {
7    config: MlpConfig,
8    device: Device,
9    seq: nn::Sequential,
10}
11
12impl Mlp {
13    fn create_net(var_store: &nn::VarStore, config: &MlpConfig) -> nn::Sequential {
14        let p = &(var_store.root() / "mlp");
15        let mut seq = nn::seq();
16        let mut in_dim = config.in_dim;
17
18        for (i, &out_dim) in config.units.iter().enumerate() {
19            seq = seq.add(nn::linear(
20                p / format!("{}{}", "ln", i),
21                in_dim,
22                out_dim,
23                Default::default(),
24            ));
25            seq = seq.add_fn(|x| x.relu());
26            in_dim = out_dim;
27        }
28
29        seq = seq.add(nn::linear(
30            p / format!("{}{}", "ln", config.units.len()),
31            in_dim,
32            config.out_dim,
33            Default::default(),
34        ));
35
36        if config.activation_out {
37            seq = seq.add_fn(|x| x.relu());
38        }
39
40        seq
41    }
42}
43
44impl SubModel for Mlp {
45    type Config = MlpConfig;
46    type Input = Tensor;
47    type Output = Tensor;
48
49    fn forward(&self, x: &Self::Input) -> Tensor {
50        self.seq.forward(&x.to(self.device))
51    }
52
53    fn build(var_store: &nn::VarStore, config: Self::Config) -> Self {
54        let device = var_store.device();
55        let seq = Self::create_net(var_store, &config);
56
57        Self {
58            config,
59            device,
60            seq,
61        }
62    }
63
64    fn clone_with_var_store(&self, var_store: &nn::VarStore) -> Self {
65        let config = self.config.clone();
66        let device = var_store.device();
67        let seq = Self::create_net(&var_store, &config);
68
69        Self {
70            config,
71            device,
72            seq,
73        }
74    }
75}
76
77impl SubModel2 for Mlp {
78    type Config = MlpConfig;
79    type Input1 = Tensor;
80    type Input2 = Tensor;
81    type Output = Tensor;
82
83    fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output {
84        let input1: Tensor = input1.to(self.device);
85        let input2: Tensor = input2.to(self.device);
86        let input = Tensor::cat(&[input1, input2], -1);
87        self.seq.forward(&input.to(self.device))
88    }
89
90    fn build(var_store: &nn::VarStore, config: Self::Config) -> Self {
91        let units = &config.units;
92        let in_dim = *units.last().unwrap_or(&config.in_dim);
93        let out_dim = config.out_dim;
94        let p = &(var_store.root() / "mlp");
95        let seq = mlp("ln", var_store, &config).add(nn::linear(
96            p / format!("ln{}", units.len()),
97            in_dim,
98            out_dim,
99            Default::default(),
100        ));
101
102        Self {
103            config,
104            device: var_store.device(),
105            seq,
106        }
107    }
108
109    fn clone_with_var_store(&self, var_store: &nn::VarStore) -> Self {
110        let config = self.config.clone();
111        let device = var_store.device();
112        let seq = Self::create_net(&var_store, &config);
113
114        Self {
115            config,
116            device,
117            seq,
118        }
119    }
120}