use std::collections::HashMap;
use crate::{
EpisodeSummary, EvaluationItem, ItemLazy, MetricUpdater, MetricWrapper, NumericMetricUpdater,
metric::{
Adaptor, Metric, MetricDefinition, MetricId, MetricMetadata, Numeric, store::MetricsUpdate,
},
};
/// Metric registry for the RL loop, grouped by the event stream that feeds each metric.
///
/// Streams: training steps (`TS` items), environment steps and validation environment
/// steps (`ES` items), and episode ends / validation episode ends ([`EpisodeSummary`]).
/// Each stream owns a non-numeric list ([`MetricUpdater`]) and a numeric list
/// ([`NumericMetricUpdater`]).
pub(crate) struct RLMetrics<TS: ItemLazy, ES: ItemLazy> {
    // --- Non-numeric metrics, one list per stream ---
    train_step: Vec<Box<dyn MetricUpdater<TS::ItemSync>>>,
    env_step: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
    env_step_valid: Vec<Box<dyn MetricUpdater<ES::ItemSync>>>,
    episode_end: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,
    episode_end_valid: Vec<Box<dyn MetricUpdater<EpisodeSummary>>>,

    // --- Numeric metrics, one list per stream ---
    train_step_numeric: Vec<Box<dyn NumericMetricUpdater<TS::ItemSync>>>,
    env_step_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
    env_step_valid_numeric: Vec<Box<dyn NumericMetricUpdater<ES::ItemSync>>>,
    episode_end_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,
    episode_end_valid_numeric: Vec<Box<dyn NumericMetricUpdater<EpisodeSummary>>>,

    /// Definition of every registered metric, keyed by the metric's id.
    metric_definitions: HashMap<MetricId, MetricDefinition>,
}
// Written by hand rather than derived: `#[derive(Default)]` would add
// `TS: Default` and `ES: Default` bounds that the item types need not satisfy.
impl<TS: ItemLazy, ES: ItemLazy> Default for RLMetrics<TS, ES> {
    /// Creates a registry with no metrics registered on any stream.
    fn default() -> Self {
        Self {
            train_step: Vec::new(),
            env_step: Vec::new(),
            env_step_valid: Vec::new(),
            episode_end: Vec::new(),
            episode_end_valid: Vec::new(),
            train_step_numeric: Vec::new(),
            env_step_numeric: Vec::new(),
            env_step_valid_numeric: Vec::new(),
            episode_end_numeric: Vec::new(),
            episode_end_valid_numeric: Vec::new(),
            metric_definitions: HashMap::new(),
        }
    }
}
/// Registration and per-event update logic.
///
/// Every `register_*` method wraps the metric in a [`MetricWrapper`], records its
/// [`MetricDefinition`], and appends it to one stream's list. Each `update_*` method
/// runs only the metrics registered on its stream.
impl<TS: ItemLazy, ES: ItemLazy> RLMetrics<TS, ES> {
    /// Registers a non-numeric metric on the environment-step stream.
    ///
    /// It is updated by [`Self::update_env_step`].
    pub(crate) fn register_text_metric_agent<Me: Metric + 'static>(&mut self, metric: Me)
    where
        ES::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.env_step.push(Box::new(metric));
    }

    /// Registers a numeric metric on the environment-step stream.
    ///
    /// It is updated by [`Self::update_env_step`].
    pub(crate) fn register_agent_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
    where
        ES::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.env_step_numeric.push(Box::new(metric));
    }

    /// Registers a non-numeric metric on the training-step stream.
    ///
    /// It is updated by [`Self::update_train_step`].
    pub(crate) fn register_text_metric_train<Me: Metric + 'static>(&mut self, metric: Me)
    where
        TS::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.train_step.push(Box::new(metric));
    }

    /// Registers a numeric metric on the training-step stream.
    ///
    /// It is updated by [`Self::update_train_step`].
    pub(crate) fn register_metric_train<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
    where
        TS::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.train_step_numeric.push(Box::new(metric));
    }

    /// Registers a non-numeric metric on the validation environment-step stream.
    ///
    /// It is updated by [`Self::update_env_step_valid`].
    pub(crate) fn register_text_metric_agent_valid<Me: Metric + 'static>(&mut self, metric: Me)
    where
        ES::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.env_step_valid.push(Box::new(metric));
    }

    /// Registers a numeric metric on the validation environment-step stream.
    ///
    /// It is updated by [`Self::update_env_step_valid`].
    pub(crate) fn register_agent_metric_valid<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
    where
        ES::ItemSync: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.env_step_valid_numeric.push(Box::new(metric));
    }

    /// Registers a non-numeric metric on the episode-end stream.
    ///
    /// It is updated by [`Self::update_episode_end`].
    pub(crate) fn register_text_metric_episode<Me: Metric + 'static>(&mut self, metric: Me)
    where
        EpisodeSummary: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.episode_end.push(Box::new(metric));
    }

    /// Registers a numeric metric on the episode-end stream.
    ///
    /// It is updated by [`Self::update_episode_end`].
    pub(crate) fn register_episode_metric<Me: Metric + Numeric + 'static>(&mut self, metric: Me)
    where
        EpisodeSummary: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.episode_end_numeric.push(Box::new(metric));
    }

    /// Registers a non-numeric metric on the validation episode-end stream.
    ///
    /// It is updated by [`Self::update_episode_end_valid`].
    pub(crate) fn register_text_metric_episode_valid<Me: Metric + 'static>(&mut self, metric: Me)
    where
        EpisodeSummary: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.episode_end_valid.push(Box::new(metric));
    }

    /// Registers a numeric metric on the validation episode-end stream.
    ///
    /// It is updated by [`Self::update_episode_end_valid`].
    pub(crate) fn register_episode_metric_valid<Me: Metric + Numeric + 'static>(
        &mut self,
        metric: Me,
    ) where
        EpisodeSummary: Adaptor<Me::Input> + 'static,
    {
        let metric = MetricWrapper::new(metric);
        self.register_definition(&metric);
        self.episode_end_valid_numeric.push(Box::new(metric));
    }

    /// Records the definition of `metric` under its id.
    ///
    /// If a definition with the same id already exists, it is replaced
    /// (`HashMap::insert` semantics).
    fn register_definition<Me: Metric>(&mut self, metric: &MetricWrapper<Me>) {
        self.metric_definitions.insert(
            metric.id.clone(),
            MetricDefinition::new(metric.id.clone(), &metric.metric),
        );
    }

    /// Returns a copy of every registered metric definition.
    ///
    /// This only reads the registry, so it takes `&self`. Callers holding `&mut`
    /// access keep working. The order follows `HashMap` iteration and is unspecified.
    pub(crate) fn metric_definitions(&self) -> Vec<MetricDefinition> {
        self.metric_definitions.values().cloned().collect()
    }

    /// Feeds `item` to every training-step metric and collects the resulting entries.
    ///
    /// Non-numeric metrics run first, then numeric ones. Each list is processed in
    /// registration order.
    pub(crate) fn update_train_step(
        &mut self,
        item: &EvaluationItem<TS::ItemSync>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries = self
            .train_step
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        let entries_numeric = self
            .train_step_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        MetricsUpdate::new(entries, entries_numeric)
    }

    /// Feeds `item` to every environment-step metric and collects the resulting entries.
    ///
    /// Non-numeric metrics run first, then numeric ones. Each list is processed in
    /// registration order.
    pub(crate) fn update_env_step(
        &mut self,
        item: &EvaluationItem<ES::ItemSync>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries = self
            .env_step
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        let entries_numeric = self
            .env_step_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        MetricsUpdate::new(entries, entries_numeric)
    }

    /// Feeds `item` to every validation environment-step metric and collects the
    /// resulting entries.
    ///
    /// Non-numeric metrics run first, then numeric ones. Each list is processed in
    /// registration order.
    pub(crate) fn update_env_step_valid(
        &mut self,
        item: &EvaluationItem<ES::ItemSync>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries = self
            .env_step_valid
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        let entries_numeric = self
            .env_step_valid_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        MetricsUpdate::new(entries, entries_numeric)
    }

    /// Feeds an episode summary to every episode-end metric and collects the
    /// resulting entries.
    ///
    /// Non-numeric metrics run first, then numeric ones. Each list is processed in
    /// registration order.
    pub(crate) fn update_episode_end(
        &mut self,
        item: &EvaluationItem<EpisodeSummary>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries = self
            .episode_end
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        let entries_numeric = self
            .episode_end_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        MetricsUpdate::new(entries, entries_numeric)
    }

    /// Feeds an episode summary to every validation episode-end metric and collects
    /// the resulting entries.
    ///
    /// Non-numeric metrics run first, then numeric ones. Each list is processed in
    /// registration order.
    pub(crate) fn update_episode_end_valid(
        &mut self,
        item: &EvaluationItem<EpisodeSummary>,
        metadata: &MetricMetadata,
    ) -> MetricsUpdate {
        let entries = self
            .episode_end_valid
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        let entries_numeric = self
            .episode_end_valid_numeric
            .iter_mut()
            .map(|metric| metric.update(&item.item, metadata))
            .collect::<Vec<_>>();
        MetricsUpdate::new(entries, entries_numeric)
    }
}