// ember_rl/algorithms/dqn/network.rs
1use burn::prelude::*;
2use burn::nn::{Linear, LinearConfig, Relu};
3
/// A feedforward Q-network.
///
/// Maps observations to Q-values: `[batch, obs_size] -> [batch, num_actions]`
///
/// Architecture: fully connected layers with ReLU activations between them.
/// Layer sizes are determined at construction from `DqnConfig::hidden_sizes`.
///
/// This is a Burn `Module`, meaning it owns its parameters and can be
/// serialised, cloned for the target network, and updated by an optimiser.
///
/// # Target network
///
/// DQN requires two copies of the network: the online network (updated every
/// step) and the target network (periodically synced from online). Burn's
/// `Module::clone()` gives us the target network for free — it performs a
/// deep clone of all parameters.
#[derive(Module, Debug)]
pub struct QNetwork<B: Backend> {
    /// Fully connected layers applied in order; built from consecutive
    /// pairs of the `layer_sizes` given to [`QNetwork::new`].
    layers: Vec<Linear<B>>,
    /// ReLU applied between layers (not after the final one). Stateless,
    /// so a single instance is shared across all hidden transitions.
    activation: Relu,
}
25
26impl<B: Backend> QNetwork<B> {
27 /// Build a Q-network with the given layer sizes.
28 ///
29 /// `layer_sizes` should be the full sequence from input to output:
30 /// `[obs_size, hidden_0, hidden_1, ..., num_actions]`
31 pub fn new(layer_sizes: &[usize], device: &B::Device) -> Self {
32 assert!(layer_sizes.len() >= 2, "need at least input and output sizes");
33
34 let layers = layer_sizes
35 .windows(2)
36 .map(|w| LinearConfig::new(w[0], w[1]).init(device))
37 .collect();
38
39 Self {
40 layers,
41 activation: Relu::new(),
42 }
43 }
44
45 /// Forward pass: observation batch -> Q-values.
46 ///
47 /// Input: `[batch_size, obs_size]`
48 /// Output: `[batch_size, num_actions]`
49 ///
50 /// ReLU is applied between all layers except the final output layer,
51 /// which is linear (Q-values are unbounded).
52 pub fn forward(&self, x: Tensor<B, 2>) -> Tensor<B, 2> {
53 let last_idx = self.layers.len() - 1;
54 let mut out = x;
55
56 for (i, layer) in self.layers.iter().enumerate() {
57 out = layer.forward(out);
58 if i < last_idx {
59 out = self.activation.forward(out);
60 }
61 }
62
63 out
64 }
65}