ember_rl/algorithms/dqn/
inference.rs1use std::marker::PhantomData;
2use std::path::Path;
3
4use burn::prelude::*;
5use burn::record::{CompactRecorder, RecorderError};
6use rl_traits::{Environment, Policy};
7
8use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
9use super::config::DqnConfig;
10use super::network::QNetwork;
11
12pub struct DqnPolicy<E, Enc, Act, B: Backend> {
33 net: QNetwork<B>,
34 encoder: Enc,
35 action_mapper: Act,
36 device: B::Device,
37 _env: PhantomData<E>,
38}
39
40impl<E, Enc, Act, B> DqnPolicy<E, Enc, Act, B>
41where
42 E: Environment,
43 Enc: ObservationEncoder<E::Observation, B>,
44 Act: DiscreteActionMapper<E::Action>,
45 B: Backend,
46{
47 pub fn new(encoder: Enc, action_mapper: Act, config: &DqnConfig, device: B::Device) -> Self {
51 let obs_size = encoder.obs_size();
52 let num_actions = action_mapper.num_actions();
53
54 let mut layer_sizes = vec![obs_size];
55 layer_sizes.extend_from_slice(&config.hidden_sizes);
56 layer_sizes.push(num_actions);
57
58 let net = QNetwork::new(&layer_sizes, &device);
59
60 Self {
61 net,
62 encoder,
63 action_mapper,
64 device,
65 _env: PhantomData,
66 }
67 }
68
69 pub fn from_network(net: QNetwork<B>, encoder: Enc, action_mapper: Act, device: B::Device) -> Self {
74 Self {
75 net,
76 encoder,
77 action_mapper,
78 device,
79 _env: PhantomData,
80 }
81 }
82
83 pub fn load(mut self, path: impl AsRef<Path>) -> Result<Self, RecorderError> {
88 self.net = self
89 .net
90 .load_file(path.as_ref().to_path_buf(), &CompactRecorder::new(), &self.device)?;
91 Ok(self)
92 }
93}
94
95impl<E, Enc, Act, B> Policy<E::Observation, E::Action> for DqnPolicy<E, Enc, Act, B>
96where
97 E: Environment,
98 Enc: ObservationEncoder<E::Observation, B>,
99 Act: DiscreteActionMapper<E::Action>,
100 B: Backend,
101{
102 fn act(&self, obs: &E::Observation) -> E::Action {
103 let obs_tensor = self.encoder.encode(obs, &self.device).unsqueeze_dim(0);
104 let q_values = self.net.forward(obs_tensor);
105 let best_idx = q_values
106 .argmax(1)
107 .into_data()
108 .to_vec::<i64>()
109 .unwrap()[0] as usize;
110 self.action_mapper.index_to_action(best_idx)
111 }
112}