border_candle_agent/atari_cnn/
base.rs

1use super::AtariCnnConfig;
2use crate::model::SubModel1;
3use anyhow::Result;
4use candle_core::{DType::F32, Device, Tensor};
5use candle_nn::{
6    conv::Conv2dConfig,
7    conv2d, linear,
8    sequential::{seq, Sequential},
9    Module, VarBuilder,
10};
11
12#[allow(clippy::upper_case_acronyms)]
13#[allow(dead_code)]
14/// Convolutional neural network, which has the same architecture of the DQN paper.
15pub struct AtariCnn {
16    n_stack: i64,
17    out_dim: i64,
18    device: Device,
19    seq: Sequential,
20    skip_linear: bool,
21}
22
23impl AtariCnn {
24    fn stride(s: i64) -> Conv2dConfig {
25        Conv2dConfig {
26            stride: s as _,
27            ..Default::default()
28        }
29    }
30
31    fn create_net(vb: &VarBuilder, n_stack: i64, out_dim: i64) -> Result<Sequential> {
32        let seq = seq()
33            .add_fn(|xs| xs.squeeze(2)?.to_dtype(F32)? / 255.0)
34            .add(conv2d(n_stack as _, 32, 8, Self::stride(4), vb.pp("c1"))?)
35            .add_fn(|xs| xs.relu())
36            .add(conv2d(32, 64, 4, Self::stride(2), vb.pp("c2"))?)
37            .add_fn(|xs| xs.relu())
38            .add(conv2d(64, 64, 3, Self::stride(1), vb.pp("c3"))?)
39            .add_fn(|xs| xs.relu()?.flatten_from(1))
40            .add(linear(3136, 512, vb.pp("l1"))?)
41            .add_fn(|xs| xs.relu())
42            .add(linear(512, out_dim as _, vb.pp("l2"))?);
43
44        Ok(seq)
45    }
46
47    fn create_net_wo_linear(vb: &VarBuilder, n_stack: i64) -> Result<Sequential> {
48        let seq = seq()
49            .add_fn(|xs| xs.squeeze(2)?.to_dtype(F32)? / 255.0)
50            .add(conv2d(n_stack as _, 32, 8, Self::stride(4), vb.pp("c1"))?)
51            .add_fn(|xs| xs.relu())
52            .add(conv2d(32, 64, 4, Self::stride(2), vb.pp("c2"))?)
53            .add_fn(|xs| xs.relu())
54            .add(conv2d(64, 64, 3, Self::stride(1), vb.pp("c3"))?)
55            .add_fn(|xs| xs.relu()?.flatten_from(1));
56
57        Ok(seq)
58    }
59}
60
61impl SubModel1 for AtariCnn {
62    type Config = AtariCnnConfig;
63    type Input = Tensor;
64    type Output = Tensor;
65
66    fn forward(&self, x: &Self::Input) -> Tensor {
67        self.seq
68            .forward(&x.to_device(&self.device).unwrap())
69            .unwrap()
70    }
71
72    fn build(vb: VarBuilder, config: Self::Config) -> Self {
73        let n_stack = config.n_stack;
74        let out_dim = config.out_dim;
75        let device = vb.device().clone();
76        let skip_linear = config.skip_linear;
77        let seq = if config.skip_linear {
78            Self::create_net_wo_linear(&vb, n_stack)
79        } else {
80            Self::create_net(&vb, n_stack, out_dim)
81        }
82        .unwrap();
83
84        Self {
85            n_stack,
86            out_dim,
87            device,
88            seq,
89            skip_linear,
90        }
91    }
92}