border_core/trainer/
sampler.rs

1//! Samples transitions and pushes them into a replay buffer.
2use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor};
3use anyhow::Result;
4
5/// Encapsulates sampling steps. Specifically it does the followint steps:
6///
7/// 1. Samples an action from the [`Agent`], apply to the [`Env`] and takes [`Step`].
8/// 2. Convert [`Step`] into a transition (typically a batch) with [`StepProcessor`].
9/// 3. Pushes the trainsition to [`ReplayBufferBase`].
10/// 4. Count episode length and pushes to [`Record`].
11///
12/// TODO: being able to set `interval_env_record`
13///
14/// [`Step`]: crate::Step
15/// [`StepProcessor`]: crate::StepProcessor
16pub struct Sampler<E, P>
17where
18    E: Env,
19    P: StepProcessor<E>,
20{
21    env: E,
22    prev_obs: Option<E::Obs>,
23    step_processor: P,
24    /// Number of environment steps for counting frames per second.
25    n_env_steps_for_fps: usize,
26
27    /// Total time of takes n_frames.
28    time: f32,
29
30    /// Number of environment steps in an episode.
31    n_env_steps_in_episode: usize,
32
33    /// Total number of environment steps.
34    n_env_steps_total: usize,
35
36    /// Interval of recording from the environment in environment steps.
37    ///
38    /// Default to None (record from environment discarded)
39    interval_env_record: Option<usize>,
40}
41
42impl<E, P> Sampler<E, P>
43where
44    E: Env,
45    P: StepProcessor<E>,
46{
47    /// Creates a sampler.
48    pub fn new(env: E, step_processor: P) -> Self {
49        Self {
50            env,
51            prev_obs: None,
52            step_processor,
53            n_env_steps_for_fps: 0,
54            time: 0f32,
55            n_env_steps_in_episode: 0,
56            n_env_steps_total: 0,
57            interval_env_record: None,
58        }
59    }
60
61    /// Samples transitions and pushes them into the replay buffer.
62    ///
63    /// The replay buffer `R_`, to which samples will be pushed, has to accept
64    /// `Item` that are the same with `Agent::R`.
65    pub fn sample_and_push<A, R, R_>(&mut self, agent: &mut A, buffer: &mut R_) -> Result<Record>
66    where
67        A: Agent<E, R>,
68        R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
69        R_: ExperienceBufferBase<Item = R::Item>,
70    {
71        let now = std::time::SystemTime::now();
72
73        // Reset environment(s) if required
74        if self.prev_obs.is_none() {
75            // For a vectorized environments, reset all environments in `env`
76            // by giving `None` to reset() method
77            self.prev_obs = Some(self.env.reset(None)?);
78            self.step_processor
79                .reset(self.prev_obs.as_ref().unwrap().clone());
80        }
81
82        // Sample an action and apply it to the environment
83        let (step, mut record, is_done) = {
84            let act = agent.sample(self.prev_obs.as_ref().unwrap());
85            let (step, mut record) = self.env.step_with_reset(&act);
86            self.n_env_steps_in_episode += 1;
87            self.n_env_steps_total += 1;
88            let is_done = step.is_done(); // not support vectorized env
89            if let Some(interval) = &self.interval_env_record {
90                if self.n_env_steps_total % interval != 0 {
91                    record = Record::empty();
92                }
93            } else {
94                record = Record::empty();
95            }
96            (step, record, is_done)
97        };
98
99        // Update previouos observation
100        self.prev_obs = match is_done {
101            true => Some(step.init_obs.clone()),
102            false => Some(step.obs.clone()),
103        };
104
105        // Produce transition
106        let transition = self.step_processor.process(step);
107
108        // Push transition
109        buffer.push(transition)?;
110
111        // Reset step processor
112        if is_done {
113            self.step_processor
114                .reset(self.prev_obs.as_ref().unwrap().clone());
115            record.insert(
116                "episode_length",
117                crate::record::RecordValue::Scalar(self.n_env_steps_in_episode as _),
118            );
119            self.n_env_steps_in_episode = 0;
120        }
121
122        // Count environment steps
123        if let Ok(time) = now.elapsed() {
124            self.n_env_steps_for_fps += 1;
125            self.time += time.as_millis() as f32;
126        }
127
128        Ok(record)
129    }
130
131    /// Returns frames (environment steps) per second, then resets the internal counter.
132    ///
133    /// A frame involves taking action, applying it to the environment,
134    /// producing transition, and pushing it into the replay buffer.
135    pub fn fps(&mut self) -> f32 {
136        if self.time == 0f32 {
137            0f32
138        } else {
139            let fps = self.n_env_steps_for_fps as f32 / self.time * 1000f32;
140            self.reset_fps_counter();
141            fps
142        }
143    }
144
145    /// Reset stats for computing FPS.
146    pub fn reset_fps_counter(&mut self) {
147        self.n_env_steps_for_fps = 0;
148        self.time = 0f32;
149    }
150}