Skip to main content

ember_rl/algorithms/dqn/
inference.rs

1use std::marker::PhantomData;
2use std::path::Path;
3
4use burn::prelude::*;
5use burn::record::{CompactRecorder, RecorderError};
6use rl_traits::{Environment, Policy};
7
8use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
9use super::config::DqnConfig;
10use super::network::QNetwork;
11
/// A DQN agent in inference-only mode.
///
/// Holds just the Q-network, encoder, and action mapper -- no optimizer,
/// no replay buffer, no exploration. Requires only `B: Backend` (not
/// `AutodiffBackend`), so it can run on plain `NdArray` without any
/// autodiff overhead.
///
/// Load from a checkpoint saved by `DqnAgent::save`:
///
/// ```rust,ignore
/// use burn::backend::NdArray;
///
/// let policy = DqnPolicy::<CartPoleEnv, _, _, NdArray>::new(
///     VecEncoder::new(4),
///     UsizeActionMapper::new(2),
///     &config,
///     device,
/// )
/// .load("cartpole_dqn")?;
/// ```
pub struct DqnPolicy<E, Enc, Act, B: Backend> {
    // Q-network queried greedily (argmax over action values) in `act`.
    // NOTE: the `B: Backend` bound must stay on the struct because the
    // `device` field names the associated type `B::Device`.
    net: QNetwork<B>,
    // Turns environment observations into input tensors for `net`.
    encoder: Enc,
    // Maps the argmax index back to a concrete environment action.
    action_mapper: Act,
    // Device on which observation tensors are created and `net` runs.
    device: B::Device,
    // Ties the policy to one `Environment` type without storing an instance.
    _env: PhantomData<E>,
}
39
40impl<E, Enc, Act, B> DqnPolicy<E, Enc, Act, B>
41where
42    E: Environment,
43    Enc: ObservationEncoder<E::Observation, B>,
44    Act: DiscreteActionMapper<E::Action>,
45    B: Backend,
46{
47    /// Build an uninitialised policy with the given architecture.
48    ///
49    /// The network weights are random until `load` is called.
50    pub fn new(encoder: Enc, action_mapper: Act, config: &DqnConfig, device: B::Device) -> Self {
51        let obs_size = encoder.obs_size();
52        let num_actions = action_mapper.num_actions();
53
54        let mut layer_sizes = vec![obs_size];
55        layer_sizes.extend_from_slice(&config.hidden_sizes);
56        layer_sizes.push(num_actions);
57
58        let net = QNetwork::new(&layer_sizes, &device);
59
60        Self {
61            net,
62            encoder,
63            action_mapper,
64            device,
65            _env: PhantomData,
66        }
67    }
68
69    /// Create a policy directly from a pre-built network.
70    ///
71    /// Used by `DqnAgent::into_policy()` to convert a trained agent into
72    /// an inference-only policy without touching the disk.
73    pub fn from_network(net: QNetwork<B>, encoder: Enc, action_mapper: Act, device: B::Device) -> Self {
74        Self {
75            net,
76            encoder,
77            action_mapper,
78            device,
79            _env: PhantomData,
80        }
81    }
82
83    /// Load network weights from a checkpoint file.
84    ///
85    /// The checkpoint must have been saved with `DqnAgent::save` (`.mpk` format). The
86    /// architecture (hidden sizes) must match exactly.
87    pub fn load(mut self, path: impl AsRef<Path>) -> Result<Self, RecorderError> {
88        self.net = self
89            .net
90            .load_file(path.as_ref().to_path_buf(), &CompactRecorder::new(), &self.device)?;
91        Ok(self)
92    }
93}
94
95impl<E, Enc, Act, B> Policy<E::Observation, E::Action> for DqnPolicy<E, Enc, Act, B>
96where
97    E: Environment,
98    Enc: ObservationEncoder<E::Observation, B>,
99    Act: DiscreteActionMapper<E::Action>,
100    B: Backend,
101{
102    fn act(&self, obs: &E::Observation) -> E::Action {
103        let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
104        let q_values = self.net.forward(obs_tensor);
105        let best_idx = q_values
106            .argmax(1)
107            .into_data()
108            .to_vec::<i64>()
109            .unwrap()[0] as usize;
110        self.action_mapper.index_to_action(best_idx)
111    }
112}