use std::marker::PhantomData;
use crate::{
AgentEnvAsyncLoop, AgentEnvLoop, AsyncAgentEnvLoopConfig, EvaluationItem,
EventProcessorTraining, MultiAgentEnvLoop, RLComponents, RLComponentsTypes, RLEvent,
RLEventProcessorType, RLStrategy,
};
// The `Config` derive macro below expects the crate to be in scope as `burn`.
use burn_core::{self as burn};
use burn_core::{config::Config, data::dataloader::Progress};
use burn_ndarray::NdArray;
use burn_rl::{AsyncPolicy, Policy, PolicyLearner, SliceAccess, TransitionBuffer};

/// Configuration for [`OffPolicyStrategy`].
#[derive(Config, Debug)]
pub struct OffPolicyConfig {
    /// Number of environments run in parallel during training.
    #[config(default = 1)]
    pub num_envs: usize,
    /// Maximum number of observations batched per policy inference call
    /// (capped at `num_envs`).
    #[config(default = 1)]
    pub autobatch_size: usize,
    /// Capacity of the transition replay buffer.
    #[config(default = 1024)]
    pub replay_buffer_size: usize,
    /// Number of environment steps collected between training phases.
    #[config(default = 1)]
    pub train_interval: usize,
    /// Number of gradient updates performed per training phase.
    #[config(default = 1)]
    pub train_steps: usize,
    /// Number of environment steps between evaluation runs.
    #[config(default = 10_000)]
    pub eval_interval: usize,
    /// Number of episodes to run per evaluation.
    #[config(default = 1)]
    pub eval_episodes: usize,
    /// Batch size sampled from the replay buffer for each gradient update.
    #[config(default = 32)]
    pub train_batch_size: usize,
    /// Minimum number of environment steps to collect before training starts.
    #[config(default = 0)]
    pub warmup_steps: usize,
}

/// Off-policy [`RLStrategy`] that interleaves parallel environment rollouts
/// with gradient updates on batches sampled from a replay buffer, evaluating
/// and checkpointing the policy at a fixed step interval.
pub struct OffPolicyStrategy<RLC: RLComponentsTypes> {
    config: OffPolicyConfig,
    _components: PhantomData<RLC>,
}

impl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {
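    /// Creates a new off-policy strategy from the given configuration.
    ///
    /// A minimal usage sketch; `MyComponents` below is a placeholder for a
    /// concrete [`RLComponentsTypes`] implementation:
    ///
    /// ```ignore
    /// let config = OffPolicyConfig::new().with_num_envs(4);
    /// let strategy = OffPolicyStrategy::<MyComponents>::new(config);
    /// ```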
pub fn new(config: OffPolicyConfig) -> Self {
Self {
config,
_components: PhantomData,
}
}
}

impl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>
where
RLC: RLComponentsTypes,
RLC::PolicyObs: SliceAccess<RLC::Backend>,
RLC::PolicyAction: SliceAccess<RLC::Backend>,
{
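    /// Runs the off-policy training loop and returns the final policy along
    /// with the event processor.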
fn train_loop(
&self,
training_components: RLComponents<RLC>,
learner_agent: &mut RLC::LearningAgent,
starting_epoch: usize,
env_init: RLC::EnvInit,
) -> (RLC::Policy, RLEventProcessorType<RLC>) {
let mut event_processor = training_components.event_processor;
let mut checkpointer = training_components.checkpointer;
let num_steps_total = training_components.num_steps;
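        // Parallel rollout loop: `num_envs` environments served by an async
        // policy that batches up to `autobatch_size` observations per call.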
let mut env_runner = MultiAgentEnvLoop::<NdArray, RLC>::new(
self.config.num_envs,
env_init.clone(),
AsyncPolicy::new(
self.config.num_envs.min(self.config.autobatch_size),
learner_agent.policy(),
),
false,
false,
&Default::default(),
);
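        // Separate single-environment loop for deterministic evaluation.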
let runner_config = AsyncAgentEnvLoopConfig {
eval: true,
deterministic: true,
id: 0,
};
let mut env_runner_valid = AgentEnvAsyncLoop::<NdArray, RLC>::new(
env_init,
AsyncPolicy::new(1, learner_agent.policy()),
runner_config,
&Default::default(),
None,
None,
);
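        // Replay buffer holding transitions for off-policy sampling.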
let device: <RLC::Backend as burn_core::prelude::Backend>::Device = Default::default();
let mut transition_buffer = TransitionBuffer::<
RLC::Backend,
RLC::PolicyObs,
RLC::PolicyAction,
>::new(self.config.replay_buffer_size, &device);
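        // Step count at which the next evaluation is due.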
let mut valid_next = self.config.eval_interval + starting_epoch - 1;
let mut progress = Progress {
items_processed: starting_epoch,
items_total: num_steps_total,
};
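        // Most recent policy state produced by training; pushed to the
        // rollout workers right before the next round of updates.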
let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =
None;
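        // Main loop: alternate between collecting `train_interval` steps and
        // running `train_steps` gradient updates until `num_steps_total`.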
while progress.items_processed < num_steps_total {
if training_components.interrupter.should_stop() {
let reason = training_components
.interrupter
.get_message()
.unwrap_or(String::from("Reason unknown"));
log::info!("Training interrupted: {reason}");
break;
}
let previous_steps = progress.items_processed;
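            // Collect a batch of environment steps with the current policy.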
let items = env_runner.run_steps(
self.config.train_interval,
&mut event_processor,
&training_components.interrupter,
&mut progress,
);
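            // Unpack each transition; rewards and done flags are read back
            // from the backend as scalars.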
for item in &items {
let t = &item.transition;
let state: RLC::PolicyObs = t.state.clone().into();
let next_state: RLC::PolicyObs = t.next_state.clone().into();
let action: RLC::PolicyAction = t.action.clone().into();
let reward = t.reward.to_data().to_vec::<f32>().unwrap()[0];
let done = t.done.to_data().to_vec::<f32>().unwrap()[0] > 0.5;
transition_buffer.push(state, next_state, action, reward, done);
}
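            // Train once the buffer can fill a batch and warmup is complete.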
if transition_buffer.len() >= self.config.train_batch_size
&& progress.items_processed >= self.config.warmup_steps
{
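                // Sync the rollout policy with the latest trained weights.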
if let Some(ref u) = intermediary_update {
env_runner.update_policy(u.clone());
}
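                // Each training step samples a fresh batch from the buffer.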
for _ in 0..self.config.train_steps {
let batch = transition_buffer.sample(self.config.train_batch_size);
let train_item = learner_agent.train(batch);
intermediary_update = Some(train_item.policy);
event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(
train_item.item,
progress.clone(),
None,
)));
}
}
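            // Evaluate when the step counter crosses the next eval boundary.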
if valid_next > previous_steps && valid_next <= progress.items_processed {
env_runner_valid.update_policy(learner_agent.policy().state());
env_runner_valid.run_episodes(
self.config.eval_episodes,
&mut event_processor,
&training_components.interrupter,
&mut progress,
);
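                // Checkpoint alongside each evaluation.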
if let Some(checkpointer) = &mut checkpointer {
checkpointer.checkpoint(
&env_runner.policy(),
learner_agent,
valid_next,
&training_components.event_store,
);
}
valid_next += self.config.eval_interval;
}
}
(learner_agent.policy(), event_processor)
}
}