use std::sync::Arc;
use crate::{
EpisodeSummary, EvaluationItem, EventProcessorTraining, ItemLazy, LearnerSummary, RLMetrics,
metric::store::{Event, EventStoreClient, MetricsUpdate},
renderer::{MetricState, MetricsRenderer, ProgressType, TrainingProgress},
};
/// Event emitted during reinforcement-learning training, handled by
/// `RLEventProcessor::process_train`.
///
/// `TS` is the item type carried by training steps and `ES` the item type
/// carried by environment time steps.
pub enum RLEvent<TS, ES> {
/// Training started; the processor registers all metric definitions with the
/// event store and the renderer.
Start,
/// A training step completed; its item is fed to `update_train_step`.
TrainStep(EvaluationItem<TS>),
/// An environment step completed; its item updates the metrics and triggers
/// a render of the training progress.
TimeStep(EvaluationItem<ES>),
/// An episode finished; its summary is fed to `update_episode_end`.
EpisodeEnd(EvaluationItem<EpisodeSummary>),
/// Training ended; the optional summary is passed to the renderer's
/// `on_train_end`.
End(Option<LearnerSummary>),
}
/// Event emitted while evaluating an agent, handled by
/// `RLEventProcessor::process_valid`.
pub enum AgentEvaluationEvent<T> {
/// Evaluation started (currently ignored by the processor).
Start,
/// An environment step completed during evaluation; updates the
/// validation metrics.
TimeStep(EvaluationItem<T>),
/// An evaluation episode finished; updates the validation metrics and
/// triggers a render of the evaluation progress.
EpisodeEnd(EvaluationItem<EpisodeSummary>),
/// Evaluation ended (currently ignored by the processor).
End,
}
/// Routes reinforcement-learning training and evaluation events to the
/// metrics, the renderer and the event store.
///
/// Built with the `new` constructor generated by `#[derive(new)]`, which
/// takes the fields in declaration order: `metrics`, `renderer`, `store`.
#[derive(new)]
pub struct RLEventProcessor<TS: ItemLazy, ES: ItemLazy> {
/// Metric state updated from every incoming item.
metrics: RLMetrics<TS, ES>,
/// Display backend that receives metric updates and progress renders.
renderer: Box<dyn MetricsRenderer>,
/// Shared client that records metric definitions and updates.
store: Arc<EventStoreClient>,
}
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
    /// Builds the progress indicators displayed while training.
    ///
    /// Currently a single detailed `"Step"` indicator tracking
    /// `progress.global_progress`.
    fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        vec![ProgressType::Detailed {
            tag: String::from("Step"),
            progress: progress.global_progress.clone(),
        }]
    }

    /// Builds the progress indicators displayed while evaluating the agent.
    ///
    /// Evaluation shows exactly the same indicator as training, so this
    /// delegates to `progress_indicators` instead of duplicating its body;
    /// specialise it here if the evaluation display ever needs to differ.
    fn progress_indicators_eval(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        self.progress_indicators(progress)
    }
}
impl<TS: ItemLazy, ES: ItemLazy> RLEventProcessor<TS, ES> {
fn process_update_train(&mut self, update: MetricsUpdate) {
self.store
.add_event_train(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_train(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_train(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
}
fn process_update_valid(&mut self, update: MetricsUpdate) {
self.store
.add_event_valid(crate::metric::store::Event::MetricsUpdate(update.clone()));
update
.entries
.into_iter()
.for_each(|entry| self.renderer.update_valid(MetricState::Generic(entry)));
update
.entries_numeric
.into_iter()
.for_each(|numeric_update| {
self.renderer.update_valid(MetricState::Numeric(
numeric_update.entry,
numeric_update.numeric_entry,
))
});
}
}
impl<TS: ItemLazy, ES: ItemLazy> EventProcessorTraining<RLEvent<TS, ES>, AgentEvaluationEvent<ES>>
for RLEventProcessor<TS, ES>
{
fn process_train(&mut self, event: RLEvent<TS, ES>) {
match event {
RLEvent::Start => {
let definitions = self.metrics.metric_definitions();
self.store
.add_event_train(Event::MetricsInit(definitions.clone()));
definitions
.iter()
.for_each(|definition| self.renderer.register_metric(definition.clone()));
}
RLEvent::TrainStep(item) => {
let item = item.sync();
let metadata = (&item).into();
let update = self.metrics.update_train_step(&item, &metadata);
self.process_update_train(update);
}
RLEvent::TimeStep(item) => {
let item = item.sync();
let progress = (&item).into();
let metadata = (&item).into();
let update = self.metrics.update_env_step(&item, &metadata);
self.process_update_train(update);
let status = self.progress_indicators(&progress);
self.renderer.render_train(progress, status);
}
RLEvent::EpisodeEnd(item) => {
let item = item.sync();
let metadata = (&item).into();
let update = self.metrics.update_episode_end(&item, &metadata);
self.process_update_train(update);
}
RLEvent::End(learner_summary) => {
self.renderer.on_train_end(learner_summary).ok();
}
}
}
fn process_valid(&mut self, event: AgentEvaluationEvent<ES>) {
match event {
AgentEvaluationEvent::Start => {} AgentEvaluationEvent::TimeStep(item) => {
let item = item.sync();
let metadata = (&item).into();
let update = self.metrics.update_env_step_valid(&item, &metadata);
self.process_update_valid(update);
}
AgentEvaluationEvent::EpisodeEnd(item) => {
let item = item.sync();
let progress = (&item).into();
let metadata = (&item).into();
let update = self.metrics.update_episode_end_valid(&item, &metadata);
self.process_update_valid(update);
let status = self.progress_indicators_eval(&progress);
self.renderer.render_valid(progress, status);
}
AgentEvaluationEvent::End => {} }
}
fn renderer(self) -> Box<dyn MetricsRenderer> {
self.renderer
}
}