border_tch_agent/cnn/
base.rs1use super::AtariCnnConfig;
2use crate::model::SubModel;
3use tch::{nn, nn::Module, Device, Tensor};
4
5#[allow(clippy::upper_case_acronyms)]
6pub struct AtariCnn {
8 n_stack: i64,
9 out_dim: i64,
10 device: Device,
11 seq: nn::Sequential,
12 skip_linear: bool,
13}
14
15impl AtariCnn {
16 fn stride(s: i64) -> nn::ConvConfig {
17 nn::ConvConfig {
18 stride: s,
19 ..Default::default()
20 }
21 }
22
23 fn create_net(var_store: &nn::VarStore, n_stack: i64, out_dim: i64) -> nn::Sequential {
24 let p = &var_store.root();
25 nn::seq()
26 .add_fn(|xs| xs.squeeze_dim(2).internal_cast_float(true) / 255)
27 .add(nn::conv2d(p / "c1", n_stack, 32, 8, Self::stride(4)))
28 .add_fn(|xs| xs.relu())
29 .add(nn::conv2d(p / "c2", 32, 64, 4, Self::stride(2)))
30 .add_fn(|xs| xs.relu())
31 .add(nn::conv2d(p / "c3", 64, 64, 3, Self::stride(1)))
32 .add_fn(|xs| xs.relu().flat_view())
33 .add(nn::linear(p / "l1", 3136, 512, Default::default()))
34 .add_fn(|xs| xs.relu())
35 .add(nn::linear(p / "l2", 512, out_dim as _, Default::default()))
36 }
37
38 fn create_net_wo_linear(var_store: &nn::VarStore, n_stack: i64) -> nn::Sequential {
39 let p = &var_store.root();
40 nn::seq()
41 .add_fn(|xs| xs.squeeze_dim(2).internal_cast_float(true) / 255)
42 .add(nn::conv2d(p / "c1", n_stack, 32, 8, Self::stride(4)))
43 .add_fn(|xs| xs.relu())
44 .add(nn::conv2d(p / "c2", 32, 64, 4, Self::stride(2)))
45 .add_fn(|xs| xs.relu())
46 .add(nn::conv2d(p / "c3", 64, 64, 3, Self::stride(1)))
47 .add_fn(|xs| xs.relu().flat_view())
48 }
49}
50
51impl SubModel for AtariCnn {
52 type Config = AtariCnnConfig;
53 type Input = Tensor;
54 type Output = Tensor;
55
56 fn forward(&self, x: &Self::Input) -> Tensor {
57 self.seq.forward(&x.to(self.device))
58 }
59
60 fn build(var_store: &nn::VarStore, config: Self::Config) -> Self {
61 let n_stack = config.n_stack;
62 let out_dim = config.out_dim;
63 let device = var_store.device();
64 let skip_linear = config.skip_linear;
65 let seq = if config.skip_linear {
66 Self::create_net_wo_linear(var_store, n_stack)
67 } else {
68 Self::create_net(var_store, n_stack, out_dim)
69 };
70
71 Self {
72 n_stack,
73 out_dim,
74 device,
75 seq,
76 skip_linear,
77 }
78 }
79
80 fn clone_with_var_store(&self, var_store: &nn::VarStore) -> Self {
81 let n_stack = self.n_stack;
82 let out_dim = self.out_dim;
83 let skip_linear = self.skip_linear;
84 let device = var_store.device();
85 let seq = if skip_linear {
86 Self::create_net_wo_linear(&var_store, n_stack)
87 } else {
88 Self::create_net(&var_store, n_stack, out_dim)
89 };
90
91 Self {
92 n_stack,
93 out_dim,
94 device,
95 seq,
96 skip_linear,
97 }
98 }
99}