border_core/trainer/sampler.rs
//! Experience sampling and replay buffer management.
//!
//! This module provides functionality for sampling experiences from the environment
//! and storing them in a replay buffer. It handles the interaction between the agent,
//! the environment, and the replay buffer, and hands the environment's per-step
//! metrics back to the caller.
//!
//! # Sampling Process
//!
//! The sampling process involves:
//!
//! 1. Environment Interaction:
//!    * The agent observes the environment state
//!    * The agent selects an action, which is executed in the environment
//!    * The environment transitions to a new state
//!
//! 2. Experience Processing:
//!    * Convert the environment step into a transition
//!    * Store the transition in the replay buffer
//!    * Reset per-episode processing state when an episode ends
//!
//! 3. Performance Monitoring:
//!    * Return the environment's step record (metrics such as episode length)
//!      to the caller
24use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor};
25use anyhow::Result;
26
27/// Manages the sampling of experiences from the environment.
28///
29/// This struct handles the interaction between the agent and environment,
30/// processes the resulting experiences, and stores them in a replay buffer.
31/// It also tracks various metrics about the sampling process.
32///
33/// # Type Parameters
34///
35/// * `E` - The environment type
36/// * `P` - The step processor type
37pub struct Sampler<E, P>
38where
39 E: Env,
40 P: StepProcessor<E>,
41{
42 /// The environment being sampled from
43 env: E,
44
45 /// Previous observation from the environment
46 prev_obs: Option<E::Obs>,
47
48 /// Processor for converting steps into transitions
49 step_processor: P,
50}
51
52impl<E, P> Sampler<E, P>
53where
54 E: Env,
55 P: StepProcessor<E>,
56{
57 /// Creates a new sampler with the given environment and step processor.
58 ///
59 /// # Arguments
60 ///
61 /// * `env` - The environment to sample from
62 /// * `step_processor` - The processor for converting steps into transitions
63 ///
64 /// # Returns
65 ///
66 /// A new `Sampler` instance
67 pub fn new(env: E, step_processor: P) -> Self {
68 Self {
69 env,
70 prev_obs: None,
71 step_processor,
72 }
73 }
74
75 /// Samples an experience and pushes it to the replay buffer.
76 ///
77 /// This method:
78 /// 1. Resets the environment if needed
79 /// 2. Samples an action from the agent
80 /// 3. Applies the action to the environment
81 /// 4. Processes the resulting step
82 /// 5. Stores the experience in the replay buffer
83 ///
84 /// # Arguments
85 ///
86 /// * `agent` - The agent to sample actions from
87 /// * `buffer` - The replay buffer to store experiences in
88 ///
89 /// # Returns
90 ///
91 /// A `Record` containing metrics about the sampling process
92 ///
93 /// # Errors
94 ///
95 /// Returns an error if:
96 /// * The environment fails to reset
97 /// * The environment step fails
98 /// * The replay buffer operation fails
99 pub fn sample_and_push<R, R_>(
100 &mut self,
101 agent: &mut Box<dyn Agent<E, R>>,
102 buffer: &mut R_,
103 ) -> Result<Record>
104 where
105 R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
106 R_: ExperienceBufferBase<Item = R::Item>,
107 {
108 // Reset environment(s) if required
109 if self.prev_obs.is_none() {
110 // For a vectorized environments, reset all environments in `env`
111 // by giving `None` to reset() method
112 self.prev_obs = Some(self.env.reset(None)?);
113 self.step_processor
114 .reset(self.prev_obs.as_ref().unwrap().clone());
115 }
116
117 // Sample an action and apply it to the environment
118 let (step, record, is_done) = {
119 let act = agent.sample(self.prev_obs.as_ref().unwrap());
120 let (step, record) = self.env.step_with_reset(&act);
121 let is_done = step.is_done(); // not support vectorized env
122 (step, record, is_done)
123 };
124
125 // Update previouos observation
126 self.prev_obs = match is_done {
127 true => Some(step.init_obs.clone().expect("Failed to unwrap init_obs")),
128 false => Some(step.obs.clone()),
129 };
130
131 // Produce transition
132 let transition = self.step_processor.process(step);
133
134 // Push transition
135 buffer.push(transition)?;
136
137 // Reset step processor
138 if is_done {
139 self.step_processor
140 .reset(self.prev_obs.as_ref().unwrap().clone());
141 }
142
143 Ok(record)
144 }
145}