use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor};
use anyhow::Result;
/// Couples an environment with a step processor and remembers the most
/// recent observation between calls.
///
/// `E` is the environment type and `P` is the processor that turns the
/// environment's steps into items that can be pushed to a buffer.
pub struct Sampler<E: Env, P: StepProcessor<E>> {
    /// Environment that is reset and stepped by the sampler.
    env: E,
    /// Observation the next action is sampled from; `None` until the
    /// environment has been reset for the first time.
    prev_obs: Option<E::Obs>,
    /// Processor applied to every step before it is pushed to a buffer.
    step_processor: P,
}
impl<E, P> Sampler<E, P>
where
    E: Env,
    P: StepProcessor<E>,
{
    /// Creates a sampler that owns `env` and `step_processor`.
    ///
    /// No observation is stored yet, so the environment is reset on the
    /// first call to [`Sampler::sample_and_push`].
    pub fn new(env: E, step_processor: P) -> Self {
        Self {
            env,
            prev_obs: None,
            step_processor,
        }
    }

    /// Performs one environment step and pushes the processed result to
    /// `buffer`.
    ///
    /// The sequence is:
    /// 1. If no previous observation is stored, reset the environment and
    ///    the step processor with the fresh observation.
    /// 2. Ask `agent` for an action given the previous observation and apply
    ///    it with `step_with_reset`.
    /// 3. Store the observation for the next call: `init_obs` of the step if
    ///    the step reports `is_done()`, otherwise its `obs`.
    /// 4. Run the step through the step processor, reset the processor if
    ///    the step was done, and push the processed item to `buffer`.
    ///
    /// Returns the [`Record`] produced by the environment step.
    ///
    /// # Errors
    ///
    /// Propagates the error from the environment reset or from
    /// `buffer.push`. The step processor is reset *before* the push, so a
    /// failing push cannot leave it out of sync with the stored observation.
    ///
    /// # Panics
    ///
    /// Panics if the step is done but its `init_obs` is `None`.
    pub fn sample_and_push<R, R_>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R_,
    ) -> Result<Record>
    where
        R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
        R_: ExperienceBufferBase<Item = R::Item>,
    {
        // Take the stored observation, or lazily reset on the first call.
        // If the reset fails, `prev_obs` is still `None`, so the next call
        // retries the reset.
        let prev_obs = match self.prev_obs.take() {
            Some(obs) => obs,
            None => {
                let obs = self.env.reset(None)?;
                self.step_processor.reset(obs.clone());
                obs
            }
        };

        let act = agent.sample(&prev_obs);
        let (step, record) = self.env.step_with_reset(&act);
        let is_done = step.is_done();

        // The observation must be cloned out before `step` is consumed by
        // the step processor below.
        let next_obs = if is_done {
            step.init_obs.clone().expect("Failed to unwrap init_obs")
        } else {
            step.obs.clone()
        };

        let transition = self.step_processor.process(step);

        // Bring the processor and the stored observation up to date before
        // the fallible push, so both stay consistent if the push errors.
        if is_done {
            self.step_processor.reset(next_obs.clone());
        }
        self.prev_obs = Some(next_obs);

        buffer.push(transition)?;

        Ok(record)
    }
}