// border_core/trainer/sampler.rs

//! Experience sampling and replay buffer management.
//!
//! This module provides functionality for sampling experiences from the environment
//! and storing them in a replay buffer. It handles the interaction between the agent,
//! the environment, and the replay buffer, while also tracking performance metrics.
//!
//! # Sampling Process
//!
//! The sampling process involves:
//!
//! 1. Environment Interaction:
//!    * The agent observes the current environment state
//!    * The agent selects an action, which is executed in the environment
//!    * The environment transitions to a new state
//!
//! 2. Experience Processing:
//!    * Convert the environment step into a transition
//!    * Store the transition in the replay buffer
//!    * Track episode length and performance metrics
//!
//! 3. Performance Monitoring:
//!    * Monitor episode length
//!    * Record environment metrics
24use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor};
25use anyhow::Result;
26
27/// Manages the sampling of experiences from the environment.
28///
29/// This struct handles the interaction between the agent and environment,
30/// processes the resulting experiences, and stores them in a replay buffer.
31/// It also tracks various metrics about the sampling process.
32///
33/// # Type Parameters
34///
35/// * `E` - The environment type
36/// * `P` - The step processor type
37pub struct Sampler<E, P>
38where
39    E: Env,
40    P: StepProcessor<E>,
41{
42    /// The environment being sampled from
43    env: E,
44
45    /// Previous observation from the environment
46    prev_obs: Option<E::Obs>,
47
48    /// Processor for converting steps into transitions
49    step_processor: P,
50}
51
52impl<E, P> Sampler<E, P>
53where
54    E: Env,
55    P: StepProcessor<E>,
56{
57    /// Creates a new sampler with the given environment and step processor.
58    ///
59    /// # Arguments
60    ///
61    /// * `env` - The environment to sample from
62    /// * `step_processor` - The processor for converting steps into transitions
63    ///
64    /// # Returns
65    ///
66    /// A new `Sampler` instance
67    pub fn new(env: E, step_processor: P) -> Self {
68        Self {
69            env,
70            prev_obs: None,
71            step_processor,
72        }
73    }
74
75    /// Samples an experience and pushes it to the replay buffer.
76    ///
77    /// This method:
78    /// 1. Resets the environment if needed
79    /// 2. Samples an action from the agent
80    /// 3. Applies the action to the environment
81    /// 4. Processes the resulting step
82    /// 5. Stores the experience in the replay buffer
83    ///
84    /// # Arguments
85    ///
86    /// * `agent` - The agent to sample actions from
87    /// * `buffer` - The replay buffer to store experiences in
88    ///
89    /// # Returns
90    ///
91    /// A `Record` containing metrics about the sampling process
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if:
96    /// * The environment fails to reset
97    /// * The environment step fails
98    /// * The replay buffer operation fails
99    pub fn sample_and_push<R, R_>(
100        &mut self,
101        agent: &mut Box<dyn Agent<E, R>>,
102        buffer: &mut R_,
103    ) -> Result<Record>
104    where
105        R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
106        R_: ExperienceBufferBase<Item = R::Item>,
107    {
108        // Reset environment(s) if required
109        if self.prev_obs.is_none() {
110            // For a vectorized environments, reset all environments in `env`
111            // by giving `None` to reset() method
112            self.prev_obs = Some(self.env.reset(None)?);
113            self.step_processor
114                .reset(self.prev_obs.as_ref().unwrap().clone());
115        }
116
117        // Sample an action and apply it to the environment
118        let (step, record, is_done) = {
119            let act = agent.sample(self.prev_obs.as_ref().unwrap());
120            let (step, record) = self.env.step_with_reset(&act);
121            let is_done = step.is_done(); // not support vectorized env
122            (step, record, is_done)
123        };
124
125        // Update previouos observation
126        self.prev_obs = match is_done {
127            true => Some(step.init_obs.clone().expect("Failed to unwrap init_obs")),
128            false => Some(step.obs.clone()),
129        };
130
131        // Produce transition
132        let transition = self.step_processor.process(step);
133
134        // Push transition
135        buffer.push(transition)?;
136
137        // Reset step processor
138        if is_done {
139            self.step_processor
140                .reset(self.prev_obs.as_ref().unwrap().clone());
141        }
142
143        Ok(record)
144    }
145}