border_tch_agent/cnn/
base.rs

1use super::CnnConfig;
2use crate::model::SubModel;
3use tch::{nn, nn::Module, Device, Tensor};
4
5#[allow(clippy::upper_case_acronyms)]
6/// Convolutional neural network, which has the same architecture of the DQN paper.
7pub struct Cnn {
8    n_stack: i64,
9    out_dim: i64,
10    device: Device,
11    seq: nn::Sequential,
12    skip_linear: bool,
13}
14
15impl Cnn {
16    fn stride(s: i64) -> nn::ConvConfig {
17        nn::ConvConfig {
18            stride: s,
19            ..Default::default()
20        }
21    }
22
23    fn create_net(var_store: &nn::VarStore, n_stack: i64, out_dim: i64) -> nn::Sequential {
24        let p = &var_store.root();
25        nn::seq()
26            .add_fn(|xs| xs.squeeze_dim(2).internal_cast_float(true) / 255)
27            .add(nn::conv2d(p / "c1", n_stack, 32, 8, Self::stride(4)))
28            .add_fn(|xs| xs.relu())
29            .add(nn::conv2d(p / "c2", 32, 64, 4, Self::stride(2)))
30            .add_fn(|xs| xs.relu())
31            .add(nn::conv2d(p / "c3", 64, 64, 3, Self::stride(1)))
32            .add_fn(|xs| xs.relu().flat_view())
33            .add(nn::linear(p / "l1", 3136, 512, Default::default()))
34            .add_fn(|xs| xs.relu())
35            .add(nn::linear(p / "l2", 512, out_dim as _, Default::default()))
36    }
37
38    fn create_net_wo_linear(var_store: &nn::VarStore, n_stack: i64) -> nn::Sequential {
39        let p = &var_store.root();
40        nn::seq()
41            .add_fn(|xs| xs.squeeze_dim(2).internal_cast_float(true) / 255)
42            .add(nn::conv2d(p / "c1", n_stack, 32, 8, Self::stride(4)))
43            .add_fn(|xs| xs.relu())
44            .add(nn::conv2d(p / "c2", 32, 64, 4, Self::stride(2)))
45            .add_fn(|xs| xs.relu())
46            .add(nn::conv2d(p / "c3", 64, 64, 3, Self::stride(1)))
47            .add_fn(|xs| xs.relu().flat_view())
48    }
49}
50
51impl SubModel for Cnn {
52    type Config = CnnConfig;
53    type Input = Tensor;
54    type Output = Tensor;
55
56    fn forward(&self, x: &Self::Input) -> Tensor {
57        self.seq.forward(&x.to(self.device))
58    }
59
60    fn build(var_store: &nn::VarStore, config: Self::Config) -> Self {
61        let n_stack = config.n_stack;
62        let out_dim = config.out_dim;
63        let device = var_store.device();
64        let skip_linear = config.skip_linear;
65        let seq = if config.skip_linear {
66            Self::create_net_wo_linear(var_store, n_stack)
67        } else {
68            Self::create_net(var_store, n_stack, out_dim)
69        };
70
71        // // Debug: check weight scale
72        // for (k, v) in var_store.variables() {
73        //     if k.starts_with("c") {
74        //         let m: f32 = v.mean(tch::Kind::Float).into();
75        //         let s: f32 = v.std(false).into();
76        //         println!("{}: mean={}, std={}", k, m, s);
77        //     }
78        // }
79        // panic!();
80
81        Self {
82            n_stack,
83            out_dim,
84            device,
85            seq,
86            skip_linear,
87        }
88    }
89
90    fn clone_with_var_store(&self, var_store: &nn::VarStore) -> Self {
91        let n_stack = self.n_stack;
92        let out_dim = self.out_dim;
93        let skip_linear = self.skip_linear;
94        let device = var_store.device();
95        let seq = if skip_linear {
96            Self::create_net_wo_linear(&var_store, n_stack)
97        } else {
98            Self::create_net(&var_store, n_stack, out_dim)
99        };
100
101        Self {
102            n_stack,
103            out_dim,
104            device,
105            seq,
106            skip_linear,
107        }
108    }
109}