//! Training loop utilities: [`Trainer`] drives environment sampling,
//! agent optimization, periodic evaluation, and model saving.
mod config;
mod sampler;

use std::time::{Duration, SystemTime};

use crate::{
    record::{Record, RecordValue::Scalar, Recorder},
    Agent, Env, Evaluator, ExperienceBufferBase, ReplayBufferBase, StepProcessor,
};
use anyhow::Result;
pub use config::TrainerConfig;
use log::info;
pub use sampler::Sampler;
/// Manages the training loop and its bookkeeping: step counters, timers for
/// compute-cost metrics, periodic evaluation, and model checkpointing.
pub struct Trainer {
    /// Interval between optimization steps, in environment steps.
    opt_interval: usize,
    /// Interval of recording average compute costs, in optimization steps.
    record_compute_cost_interval: usize,
    /// Interval of recording agent information, in optimization steps.
    record_agent_info_interval: usize,
    /// Interval of flushing records, in optimization steps.
    flush_records_interval: usize,
    /// Interval of evaluation, in optimization steps.
    eval_interval: usize,
    /// Interval of saving a model checkpoint, in optimization steps.
    /// A value of 0 disables periodic checkpointing.
    save_interval: usize,
    /// Total number of optimization steps after which training stops.
    max_opts: usize,
    /// Number of environment steps to collect before optimization starts.
    warmup_period: usize,
    // Counters and timers used to compute average sampling and
    // optimization times; reset by `reset_counters()`.
    samples_counter: usize,
    timer_for_samples: Duration,
    opt_steps_counter: usize,
    timer_for_opt_steps: Duration,
    /// Best evaluation score seen so far; used to save the "best" model.
    max_eval_reward: f32,
    /// Total number of environment steps taken.
    env_steps: usize,
    /// Total number of optimization steps taken.
    opt_steps: usize,
}
impl Trainer {
    /// Builds a [`Trainer`] from a [`TrainerConfig`], with all counters
    /// and timers zeroed.
    pub fn build(config: TrainerConfig) -> Self {
        Self {
            opt_interval: config.opt_interval,
            record_compute_cost_interval: config.record_compute_cost_interval,
            record_agent_info_interval: config.record_agent_info_interval,
            flush_records_interval: config.flush_record_interval,
            eval_interval: config.eval_interval,
            save_interval: config.save_interval,
            max_opts: config.max_opts,
            warmup_period: config.warmup_period,
            samples_counter: 0,
            timer_for_samples: Duration::new(0, 0),
            opt_steps_counter: 0,
            timer_for_opt_steps: Duration::new(0, 0),
            max_eval_reward: f32::MIN,
            env_steps: 0,
            opt_steps: 0,
        }
    }
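
    // A minimal construction sketch (hedged): whether `TrainerConfig`
    // implements `Default` or exposes builder-style setters is an assumption
    // here, not something this module guarantees.
    //
    //     let config = TrainerConfig::default();
    //     let mut trainer = Trainer::build(config);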
    /// Resets the counters and timers used for computing average sampling
    /// and optimization times.
    fn reset_counters(&mut self) {
        self.samples_counter = 0;
        self.timer_for_samples = Duration::new(0, 0);
        self.opt_steps_counter = 0;
        self.timer_for_opt_steps = Duration::new(0, 0);
    }
    /// Returns `(avr_opt_time, avr_sample_time)`: the average optimization
    /// and sampling times in milliseconds since the counters were last
    /// reset. `-1.0` is a sentinel meaning no steps of that kind were taken.
    fn average_time(&self) -> (f32, f32) {
        let avr_opt_time = match self.opt_steps_counter {
            0 => -1f32,
            n => self.timer_for_opt_steps.as_millis() as f32 / n as f32,
        };
        let avr_sample_time = match self.samples_counter {
            0 => -1f32,
            n => self.timer_for_samples.as_millis() as f32 / n as f32,
        };
        (avr_opt_time, avr_sample_time)
    }
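
    // Worked example (illustrative numbers): with 450 ms accumulated in
    // `timer_for_opt_steps` over `opt_steps_counter == 100`, `average_time`
    // yields 4.5 ms per optimization step; a counter of zero yields the
    // sentinel -1.0 rather than dividing by zero.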
    /// Performs a single optimization step when one is due and returns the
    /// resulting [`Record`] together with a flag telling whether an
    /// optimization step was actually taken.
    ///
    /// No optimization happens during the warmup period or on environment
    /// steps that are not multiples of `opt_interval`. Agent information is
    /// collected only every `record_agent_info_interval` optimization steps.
    pub fn train_step<E, R>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
    ) -> Result<(Record, bool)>
    where
        E: Env,
        R: ReplayBufferBase,
    {
        if self.env_steps < self.warmup_period || self.env_steps % self.opt_interval != 0 {
            // Skip optimization during warmup and between optimization intervals.
            Ok((Record::empty(), false))
        } else if (self.opt_steps + 1) % self.record_agent_info_interval == 0 {
            // Optimize and collect agent information for recording.
            let timer = SystemTime::now();
            let record_agent = agent.opt_with_record(buffer);
            self.opt_steps += 1;
            self.timer_for_opt_steps += timer.elapsed()?;
            self.opt_steps_counter += 1;
            Ok((record_agent, true))
        } else {
            // Optimize without recording agent information.
            let timer = SystemTime::now();
            agent.opt(buffer);
            self.opt_steps += 1;
            self.timer_for_opt_steps += timer.elapsed()?;
            self.opt_steps_counter += 1;
            Ok((Record::empty(), true))
        }
    }
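
    // Scheduling example (illustrative numbers): with `warmup_period == 100`
    // and `opt_interval == 2`, `train_step` is a no-op while `env_steps` is
    // below 100, then optimizes on every environment step divisible by 2,
    // i.e. steps 100, 102, 104, ... Agent info is recorded on the
    // optimization steps where `(opt_steps + 1) % record_agent_info_interval == 0`.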
    /// Runs periodic evaluation and model saving after an optimization step.
    ///
    /// The model is saved as `"best"` whenever the evaluation score exceeds
    /// the best score seen so far, and as a numbered checkpoint every
    /// `save_interval` optimization steps.
    fn post_process<E, R, D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        evaluator: &mut D,
        recorder: &mut Box<dyn Recorder<E, R>>,
        record: &mut Record,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // Periodic evaluation; the agent is switched to evaluation mode and back.
        if self.opt_steps % self.eval_interval == 0 {
            info!("Starts evaluation of the trained model");
            agent.eval();
            let (score, record_eval) = evaluator.evaluate(agent)?;
            agent.train();
            record.merge_inplace(record_eval);

            // Save the best model seen so far.
            if score > self.max_eval_reward {
                self.max_eval_reward = score;
                recorder.save_model("best".as_ref(), agent)?;
            }
        }

        // Periodic checkpointing, named after the current optimization step.
        if (self.save_interval > 0) && (self.opt_steps % self.save_interval == 0) {
            recorder.save_model(format!("{}", self.opt_steps).as_ref(), agent)?;
        }

        Ok(())
    }
    /// Runs the online training loop: samples from the environment, pushes
    /// processed transitions into the replay buffer, optimizes the agent,
    /// and records metrics until `max_opts` optimization steps are reached.
    pub fn train<E, P, R, D>(
        &mut self,
        env: E,
        step_proc: P,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
    ) -> Result<()>
    where
        E: Env,
        P: StepProcessor<E>,
        R: ExperienceBufferBase<Item = P::Output> + ReplayBufferBase,
        D: Evaluator<E>,
    {
        let mut sampler = Sampler::new(env, step_proc);
        agent.train();

        loop {
            // Sample one environment step and push it into the buffer,
            // measuring the time spent on sampling.
            let now = SystemTime::now();
            let record = sampler.sample_and_push(agent, buffer)?;
            self.timer_for_samples += now.elapsed()?;
            self.samples_counter += 1;
            self.env_steps += 1;

            // Perform an optimization step when due and merge its record.
            let (mut record, is_opt) = {
                let (r, is_opt) = self.train_step(agent, buffer)?;
                (record.merge(r), is_opt)
            };

            // Evaluation and model saving.
            if is_opt {
                self.post_process(agent, evaluator, recorder, &mut record)?;
            }

            // Record average compute costs, then reset the counters.
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, avr_sample_time) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                record.insert("average_sample_time", Scalar(avr_sample_time));
                self.reset_counters();
            }

            // Store and periodically flush records.
            if !record.is_empty() {
                recorder.store(record);
            }
            if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
                recorder.flush(self.opt_steps as _);
            }

            // Stop after the configured number of optimization steps.
            if self.opt_steps == self.max_opts {
                return Ok(());
            }
        }
    }
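
    // A hedged end-to-end sketch of driving the online loop. All concrete
    // types (`MyEnv`, `MyStepProc`, `MyAgent`, `MyBuffer`, `MyRecorder`,
    // `MyEvaluator`) are placeholders for implementations of the
    // corresponding traits, not items provided by this crate:
    //
    //     let mut trainer = Trainer::build(config);
    //     let mut agent: Box<dyn Agent<MyEnv, MyBuffer>> = Box::new(MyAgent::new());
    //     let mut buffer = MyBuffer::new();
    //     let mut recorder: Box<dyn Recorder<MyEnv, MyBuffer>> = Box::new(MyRecorder::new());
    //     let mut evaluator = MyEvaluator::new();
    //     trainer.train(env, step_proc, &mut agent, &mut buffer, &mut recorder, &mut evaluator)?;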
    /// Runs the offline training loop: the agent is optimized on a fixed
    /// replay buffer without any environment interaction, so the warmup
    /// period and the optimization interval are disabled.
    pub fn train_offline<E, R, D>(
        &mut self,
        agent: &mut Box<dyn Agent<E, R>>,
        buffer: &mut R,
        recorder: &mut Box<dyn Recorder<E, R>>,
        evaluator: &mut D,
    ) -> Result<()>
    where
        E: Env,
        R: ReplayBufferBase,
        D: Evaluator<E>,
    {
        // No sampling happens offline, so optimize on every loop iteration.
        self.warmup_period = 0;
        self.opt_interval = 1;
        agent.train();

        loop {
            let record = Record::empty();
            self.env_steps += 1;

            // Perform an optimization step and merge its record.
            let (mut record, is_opt) = {
                let (r, is_opt) = self.train_step(agent, buffer)?;
                (record.merge(r), is_opt)
            };

            // Evaluation and model saving.
            if is_opt {
                self.post_process(agent, evaluator, recorder, &mut record)?;
            }

            // Record the average optimization time; there is no sampling time offline.
            if self.opt_steps % self.record_compute_cost_interval == 0 {
                let (avr_opt_time, _) = self.average_time();
                record.insert("average_opt_time", Scalar(avr_opt_time));
                self.reset_counters();
            }

            // Store and periodically flush records.
            if !record.is_empty() {
                recorder.store(record);
            }
            if is_opt && ((self.opt_steps - 1) % self.flush_records_interval == 0) {
                recorder.flush(self.opt_steps as _);
            }

            // Stop after the configured number of optimization steps.
            if self.opt_steps == self.max_opts {
                return Ok(());
            }
        }
    }
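
    // Offline variant of the sketch above: no environment or step processor
    // is needed, only a pre-filled replay buffer (concrete types are again
    // placeholders):
    //
    //     trainer.train_offline(&mut agent, &mut buffer, &mut recorder, &mut evaluator)?;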
}