border_core/trainer/
sampler.rs1use crate::{record::Record, Agent, Env, ExperienceBufferBase, ReplayBufferBase, StepProcessor};
3use anyhow::Result;
4
5pub struct Sampler<E, P>
17where
18 E: Env,
19 P: StepProcessor<E>,
20{
21 env: E,
22 prev_obs: Option<E::Obs>,
23 step_processor: P,
24 n_env_steps_for_fps: usize,
26
27 time: f32,
29
30 n_env_steps_in_episode: usize,
32
33 n_env_steps_total: usize,
35
36 interval_env_record: Option<usize>,
40}
41
42impl<E, P> Sampler<E, P>
43where
44 E: Env,
45 P: StepProcessor<E>,
46{
47 pub fn new(env: E, step_processor: P) -> Self {
49 Self {
50 env,
51 prev_obs: None,
52 step_processor,
53 n_env_steps_for_fps: 0,
54 time: 0f32,
55 n_env_steps_in_episode: 0,
56 n_env_steps_total: 0,
57 interval_env_record: None,
58 }
59 }
60
61 pub fn sample_and_push<A, R, R_>(&mut self, agent: &mut A, buffer: &mut R_) -> Result<Record>
66 where
67 A: Agent<E, R>,
68 R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
69 R_: ExperienceBufferBase<Item = R::Item>,
70 {
71 let now = std::time::SystemTime::now();
72
73 if self.prev_obs.is_none() {
75 self.prev_obs = Some(self.env.reset(None)?);
78 self.step_processor
79 .reset(self.prev_obs.as_ref().unwrap().clone());
80 }
81
82 let (step, mut record, is_done) = {
84 let act = agent.sample(self.prev_obs.as_ref().unwrap());
85 let (step, mut record) = self.env.step_with_reset(&act);
86 self.n_env_steps_in_episode += 1;
87 self.n_env_steps_total += 1;
88 let is_done = step.is_done(); if let Some(interval) = &self.interval_env_record {
90 if self.n_env_steps_total % interval != 0 {
91 record = Record::empty();
92 }
93 } else {
94 record = Record::empty();
95 }
96 (step, record, is_done)
97 };
98
99 self.prev_obs = match is_done {
101 true => Some(step.init_obs.clone()),
102 false => Some(step.obs.clone()),
103 };
104
105 let transition = self.step_processor.process(step);
107
108 buffer.push(transition)?;
110
111 if is_done {
113 self.step_processor
114 .reset(self.prev_obs.as_ref().unwrap().clone());
115 record.insert(
116 "episode_length",
117 crate::record::RecordValue::Scalar(self.n_env_steps_in_episode as _),
118 );
119 self.n_env_steps_in_episode = 0;
120 }
121
122 if let Ok(time) = now.elapsed() {
124 self.n_env_steps_for_fps += 1;
125 self.time += time.as_millis() as f32;
126 }
127
128 Ok(record)
129 }
130
131 pub fn fps(&mut self) -> f32 {
136 if self.time == 0f32 {
137 0f32
138 } else {
139 let fps = self.n_env_steps_for_fps as f32 / self.time * 1000f32;
140 self.reset_fps_counter();
141 fps
142 }
143 }
144
145 pub fn reset_fps_counter(&mut self) {
147 self.n_env_steps_for_fps = 0;
148 self.time = 0f32;
149 }
150}