border_tch_agent/mlp/
base.rs1use super::{mlp, MlpConfig};
2use crate::model::{SubModel, SubModel2};
3use tch::{nn, nn::Module, Device, Tensor};
4
5pub struct Mlp {
7 config: MlpConfig,
8 device: Device,
9 seq: nn::Sequential,
10}
11
12impl Mlp {
13 fn create_net(var_store: &nn::VarStore, config: &MlpConfig) -> nn::Sequential {
14 let p = &(var_store.root() / "mlp");
15 let mut seq = nn::seq();
16 let mut in_dim = config.in_dim;
17
18 for (i, &out_dim) in config.units.iter().enumerate() {
19 seq = seq.add(nn::linear(
20 p / format!("{}{}", "ln", i),
21 in_dim,
22 out_dim,
23 Default::default(),
24 ));
25 seq = seq.add_fn(|x| x.relu());
26 in_dim = out_dim;
27 }
28
29 seq = seq.add(nn::linear(
30 p / format!("{}{}", "ln", config.units.len()),
31 in_dim,
32 config.out_dim,
33 Default::default(),
34 ));
35
36 if config.activation_out {
37 seq = seq.add_fn(|x| x.relu());
38 }
39
40 seq
41 }
42}
43
44impl SubModel for Mlp {
45 type Config = MlpConfig;
46 type Input = Tensor;
47 type Output = Tensor;
48
49 fn forward(&self, x: &Self::Input) -> Tensor {
50 self.seq.forward(&x.to(self.device))
51 }
52
53 fn build(var_store: &nn::VarStore, config: Self::Config) -> Self {
54 let device = var_store.device();
55 let seq = Self::create_net(var_store, &config);
56
57 Self {
58 config,
59 device,
60 seq,
61 }
62 }
63
64 fn clone_with_var_store(&self, var_store: &nn::VarStore) -> Self {
65 let config = self.config.clone();
66 let device = var_store.device();
67 let seq = Self::create_net(&var_store, &config);
68
69 Self {
70 config,
71 device,
72 seq,
73 }
74 }
75}
76
77impl SubModel2 for Mlp {
78 type Config = MlpConfig;
79 type Input1 = Tensor;
80 type Input2 = Tensor;
81 type Output = Tensor;
82
83 fn forward(&self, input1: &Self::Input1, input2: &Self::Input2) -> Self::Output {
84 let input1: Tensor = input1.to(self.device);
85 let input2: Tensor = input2.to(self.device);
86 let input = Tensor::cat(&[input1, input2], -1);
87 self.seq.forward(&input.to(self.device))
88 }
89
90 fn build(var_store: &nn::VarStore, config: Self::Config) -> Self {
91 let units = &config.units;
92 let in_dim = *units.last().unwrap_or(&config.in_dim);
93 let out_dim = config.out_dim;
94 let p = &(var_store.root() / "mlp");
95 let seq = mlp("ln", var_store, &config).add(nn::linear(
96 p / format!("ln{}", units.len()),
97 in_dim,
98 out_dim,
99 Default::default(),
100 ));
101
102 Self {
103 config,
104 device: var_store.device(),
105 seq,
106 }
107 }
108
109 fn clone_with_var_store(&self, var_store: &nn::VarStore) -> Self {
110 let config = self.config.clone();
111 let device = var_store.device();
112 let seq = Self::create_net(&var_store, &config);
113
114 Self {
115 config,
116 device,
117 seq,
118 }
119 }
120}