//! Q-network for DQN (ember_rl/algorithms/dqn/network.rs).
use burn::prelude::*;
use burn::nn::{Linear, LinearConfig, Relu};
/// A feedforward Q-network.
///
/// Maps observations to Q-values: `[batch, obs_size] -> [batch, num_actions]`
///
/// Architecture: fully connected layers with ReLU activations between them.
/// Layer sizes are determined at construction from `DqnConfig::hidden_sizes`.
///
/// This is a Burn `Module`, meaning it owns its parameters and can be
/// serialised, cloned for the target network, and updated by an optimiser.
///
/// # Target network
///
/// DQN requires two copies of the network: the online network (updated every
/// step) and the target network (periodically synced from online). Burn's
/// `Module::clone()` gives us the target network for free -- it performs a
/// deep clone of all parameters.
#[derive(Module, Debug)]
pub struct QNetwork<B: Backend> {
    // Linear layers applied in order; one per adjacent pair of sizes passed
    // to `QNetwork::new`, so there is always at least one layer.
    layers: Vec<Linear<B>>,
    // ReLU applied between layers. Stateless, so a single shared instance
    // is enough for the whole stack.
    activation: Relu,
}
26impl<B: Backend> QNetwork<B> {
27    /// Build a Q-network with the given layer sizes.
28    ///
29    /// `layer_sizes` should be the full sequence from input to output:
30    /// `[obs_size, hidden_0, hidden_1, ..., num_actions]`
31    pub fn new(layer_sizes: &[usize], device: &B::Device) -> Self {
32        assert!(layer_sizes.len() >= 2, "need at least input and output sizes");
33
34        let layers = layer_sizes
35            .windows(2)
36            .map(|w| LinearConfig::new(w[0], w[1]).init(device))
37            .collect();
38
39        Self {
40            layers,
41            activation: Relu::new(),
42        }
43    }
44
45    /// Forward pass: observation batch -> Q-values.
46    ///
47    /// Input:  `[batch_size, obs_size]`
48    /// Output: `[batch_size, num_actions]`
49    ///
50    /// ReLU is applied between all layers except the final output layer,
51    /// which is linear (Q-values are unbounded).
52    pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
53        let last_idx = self.layers.len() - 1;
54        let mut out = x;
55
56        for (i, layer) in self.layers.iter().enumerate() {
57            out = layer.forward(out);
58            if i < last_idx {
59                out = self.activation.forward(out);
60            }
61        }
62
63        out
64    }
65}