use std::sync::Arc;
use crate::{
Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,
RLEventProcessorType, RLResult,
metric::{processor::EventProcessorTraining, store::EventStoreClient},
};
/// Runtime pieces shared by every RL training strategy.
///
/// Bundles the training budget, optional checkpointing state, event
/// reporting, and summary configuration that a strategy's training loop
/// consumes. The whole struct is moved into [`RLStrategy::train_loop`].
pub struct RLComponents<RLC: RLComponentsTypes> {
    /// Total number of training steps to run — NOTE(review): not read in
    /// this file; confirm whether it counts env steps or learner updates.
    pub num_steps: usize,
    /// Epoch of a previously saved checkpoint to resume from, if any;
    /// training then restarts at `checkpoint + 1`.
    pub checkpoint: Option<usize>,
    /// Loads saved state on resume; when `None`, no checkpoint is restored
    /// even if `checkpoint` is set.
    pub checkpointer: Option<RLCheckpointer<RLC>>,
    /// Presumably the number of batches to accumulate gradients over before
    /// an optimizer step — not used in this file; verify against the loop.
    pub grad_accumulation: Option<usize>,
    /// Handle for interrupting training — not used in this file; presumably
    /// polled by `train_loop` implementations.
    pub interrupter: Interrupter,
    /// Receives `RLEvent`s (`Start`, `End(..)`) emitted around training.
    pub event_processor: RLEventProcessorType<RLC>,
    /// Shared client for the metric event store.
    pub event_store: Arc<EventStoreClient>,
    /// Optional end-of-training summary config; its `init()` result (if Ok)
    /// is passed to `RLEvent::End`.
    pub summary: Option<LearnerSummaryConfig>,
}
/// Selects how training is carried out.
#[derive(Clone)]
pub enum RLStrategies<RLC: RLComponentsTypes> {
    /// Built-in off-policy training, configured by an [`OffPolicyConfig`].
    OffPolicyStrategy(OffPolicyConfig),
    /// User-provided strategy implementing [`RLStrategy`], shared behind an
    /// `Arc` so the enum stays `Clone`.
    Custom(CustomRLStrategy<RLC>),
}
pub type CustomRLStrategy<LC> = Arc<dyn RLStrategy<LC>>;
/// A reinforcement-learning training strategy.
///
/// Implementors supply [`RLStrategy::train_loop`]; the provided
/// [`RLStrategy::train`] wraps it with checkpoint restoration, the
/// `Start`/`End` event bracket, and summary reporting.
pub trait RLStrategy<RLC: RLComponentsTypes> {
    /// Runs a full training session and returns the trained policy.
    ///
    /// Restores `learner_agent` from a checkpoint when one is configured,
    /// emits [`RLEvent::Start`], delegates to [`RLStrategy::train_loop`],
    /// then emits [`RLEvent::End`] with the (best-effort) summary.
    fn train(
        &self,
        mut learner_agent: RLC::LearningAgent,
        mut training_components: RLComponents<RLC>,
        env_init: RLC::EnvInit,
    ) -> RLResult<RLC::Policy> {
        // Resume from a saved epoch when configured; otherwise start at 1.
        let starting_epoch = if let Some(checkpoint) = training_components.checkpoint {
            if let Some(checkpointer) = training_components.checkpointer.as_mut() {
                learner_agent =
                    checkpointer.load_checkpoint(learner_agent, &Default::default(), checkpoint);
            }
            checkpoint + 1
        } else {
            1
        };

        // Keep the summary config: `training_components` is moved into the
        // loop below, but the summary is only needed after it returns.
        let summary_config = training_components.summary.clone();

        training_components
            .event_processor
            .process_train(RLEvent::Start);

        let (policy, mut event_processor) = self.train_loop(
            training_components,
            &mut learner_agent,
            starting_epoch,
            env_init,
        );

        // Best-effort summary: an `init()` failure is dropped as `None`.
        let summary = summary_config.and_then(|config| config.init().ok());
        event_processor.process_train(RLEvent::End(summary));

        let renderer = event_processor.renderer();
        RLResult { policy, renderer }
    }

    /// The strategy-specific training loop.
    ///
    /// Takes ownership of `training_components`, trains starting from
    /// `starting_epoch`, and returns the resulting policy together with the
    /// event processor so [`RLStrategy::train`] can emit the final events.
    fn train_loop(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>);
}