use super::{EventProcessorTraining, ItemLazy, LearnerEvent, MetricsTraining};
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation, MetricsEvaluation};
use crate::metric::store::{EpochSummary, EventStoreClient, Split};
use crate::renderer::{
EvaluationProgress, MetricState, MetricsRenderer, ProgressType, TrainingProgress,
};
use std::sync::Arc;
/// Training event processor that records metric events in an [`EventStoreClient`] and
/// forwards metric registrations, updates and progress to a [`MetricsRenderer`].
///
/// `T` is the lazily-synced item type of the training split and `V` the one of the
/// validation split.
pub struct FullEventProcessorTraining<T: ItemLazy, V: ItemLazy> {
// Training and validation metrics, updated from every processed item.
metrics: MetricsTraining<T, V>,
// Receives metric definitions, per-item metric states and progress indicators.
renderer: Box<dyn MetricsRenderer>,
// Event store client (shared via `Arc`) where init, update and end-of-epoch events go.
store: Arc<EventStoreClient>,
}
/// Evaluation event processor that records metric events in an [`EventStoreClient`] and
/// forwards metric registrations, updates and progress to a [`MetricsRenderer`].
///
/// `T` is the lazily-synced item type produced by the evaluator.
pub struct FullEventProcessorEvaluation<T: ItemLazy> {
// Test metrics, updated from every processed item.
metrics: MetricsEvaluation<T>,
// Receives metric definitions, per-item metric states and progress indicators.
renderer: Box<dyn MetricsRenderer>,
// Event store client (shared via `Arc`) where init and update events go.
store: Arc<EventStoreClient>,
}
impl<T: ItemLazy, V: ItemLazy> FullEventProcessorTraining<T, V> {
    /// Assembles a training processor from its metrics, renderer and event store client.
    pub(crate) fn new(
        metrics: MetricsTraining<T, V>,
        renderer: Box<dyn MetricsRenderer>,
        store: Arc<EventStoreClient>,
    ) -> Self {
        Self {
            metrics,
            renderer,
            store,
        }
    }

    /// Builds the progress indicators displayed for one training step.
    ///
    /// The `"Epoch"` indicator (from `global_progress`) always comes first. It is
    /// followed by an `"Iteration"` value and an `"Items"` indicator, each only when
    /// the corresponding optional field of `progress` is set.
    fn progress_indicators(&self, progress: &TrainingProgress) -> Vec<ProgressType> {
        let epoch = ProgressType::Detailed {
            tag: String::from("Epoch"),
            progress: progress.global_progress.clone(),
        };
        let iteration = progress.iteration.map(|value| ProgressType::Value {
            tag: String::from("Iteration"),
            value,
        });
        let items = progress
            .progress
            .as_ref()
            .map(|items| ProgressType::Detailed {
                tag: String::from("Items"),
                progress: items.clone(),
            });

        // `Option` is iterable, so absent indicators simply contribute nothing.
        std::iter::once(epoch)
            .chain(iteration)
            .chain(items)
            .collect()
    }
}
impl<T: ItemLazy> FullEventProcessorEvaluation<T> {
    /// Assembles an evaluation processor from its metrics, renderer and event store client.
    pub(crate) fn new(
        metrics: MetricsEvaluation<T>,
        renderer: Box<dyn MetricsRenderer>,
        store: Arc<EventStoreClient>,
    ) -> Self {
        Self {
            metrics,
            renderer,
            store,
        }
    }

    /// Builds the progress indicators displayed for one evaluation step.
    ///
    /// An `"Iteration"` value comes first when `progress.iteration` is set. It is always
    /// followed by the `"Items"` indicator.
    fn progress_indicators(&self, progress: &EvaluationProgress) -> Vec<ProgressType> {
        let iteration = progress.iteration.map(|value| ProgressType::Value {
            tag: String::from("Iteration"),
            value,
        });
        let items = ProgressType::Detailed {
            tag: String::from("Items"),
            progress: progress.progress.clone(),
        };

        iteration
            .into_iter()
            .chain(std::iter::once(items))
            .collect()
    }
}
impl<T: ItemLazy> EventProcessorEvaluation for FullEventProcessorEvaluation<T> {
    type ItemTest = T;

    /// Handles a single evaluation event.
    ///
    /// - `Start`: records a `MetricsInit` event, then registers every metric definition
    ///   with the renderer.
    /// - `ProcessedItem`: syncs the item and updates the test metrics with it. The update
    ///   is recorded under the item's split name, each generic and numeric entry is
    ///   forwarded to the renderer, and finally the progress is rendered.
    /// - `End`: notifies the renderer; its result is discarded via `.ok()`.
    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
        match event {
            EvaluatorEvent::Start => {
                let definitions = self.metrics.metric_definitions();
                // NOTE(review): the init event is sent through `add_event_train` even in
                // the evaluation processor (`add_event_test` needs a split name that
                // `Start` does not carry) — confirm this is intended.
                let init = crate::metric::store::Event::MetricsInit(definitions.clone());
                self.store.add_event_train(init);
                for definition in definitions.iter() {
                    self.renderer.register_metric(definition.clone());
                }
            }
            EvaluatorEvent::ProcessedItem(name, item) => {
                let synced = item.sync();
                let progress = (&synced).into();
                let metadata = (&synced).into();
                let update = self.metrics.update_test(&synced, &metadata);

                let record = crate::metric::store::Event::MetricsUpdate(update.clone());
                self.store.add_event_test(record, name.name.clone());

                for entry in update.entries {
                    self.renderer
                        .update_test(name.clone(), MetricState::Generic(entry));
                }
                for numeric in update.entries_numeric {
                    self.renderer.update_test(
                        name.clone(),
                        MetricState::Numeric(numeric.entry, numeric.numeric_entry),
                    );
                }

                let indicators = self.progress_indicators(&progress);
                self.renderer.render_test(progress, indicators);
            }
            EvaluatorEvent::End(summary) => {
                // Ending the renderer is best-effort here: its result is not propagated.
                self.renderer.on_test_end(summary).ok();
            }
        }
    }

    /// Consumes the processor and hands back its renderer.
    fn renderer(self) -> Box<dyn MetricsRenderer> {
        self.renderer
    }
}
impl<T: ItemLazy, V: ItemLazy> EventProcessorTraining<LearnerEvent<T>, LearnerEvent<V>>
    for FullEventProcessorTraining<T, V>
{
    /// Handles a single training event.
    ///
    /// - `Start`: records a `MetricsInit` event, then registers every metric definition
    ///   with the renderer.
    /// - `ProcessedItem`: syncs the item and updates the training metrics with it. The
    ///   update is recorded, each generic and numeric entry is forwarded to the renderer,
    ///   and finally the progress is rendered.
    /// - `EndEpoch`: records an epoch summary for the train split, then ends the
    ///   training metrics' epoch.
    /// - `End`: notifies the renderer; its result is discarded via `.ok()`.
    fn process_train(&mut self, event: LearnerEvent<T>) {
        match event {
            LearnerEvent::Start => {
                let definitions = self.metrics.metric_definitions();
                let init = crate::metric::store::Event::MetricsInit(definitions.clone());
                self.store.add_event_train(init);
                for definition in definitions.iter() {
                    self.renderer.register_metric(definition.clone());
                }
            }
            LearnerEvent::ProcessedItem(item) => {
                let synced = item.sync();
                let progress = (&synced).into();
                let metadata = (&synced).into();
                let update = self.metrics.update_train(&synced, &metadata);

                let record = crate::metric::store::Event::MetricsUpdate(update.clone());
                self.store.add_event_train(record);

                for entry in update.entries {
                    self.renderer.update_train(MetricState::Generic(entry));
                }
                for numeric in update.entries_numeric {
                    self.renderer.update_train(MetricState::Numeric(
                        numeric.entry,
                        numeric.numeric_entry,
                    ));
                }

                let indicators = self.progress_indicators(&progress);
                self.renderer.render_train(progress, indicators);
            }
            LearnerEvent::EndEpoch(epoch) => {
                let summary = EpochSummary::new(epoch, Split::Train);
                self.store
                    .add_event_train(crate::metric::store::Event::EndEpoch(summary));
                self.metrics.end_epoch_train();
            }
            LearnerEvent::End(summary) => {
                // Ending the renderer is best-effort here: its result is not propagated.
                self.renderer.on_train_end(summary).ok();
            }
        }
    }

    /// Handles a single validation event.
    ///
    /// Mirrors the `ProcessedItem` and `EndEpoch` handling of `process_train`, but targets
    /// the validation metrics, the validation store events and `Split::Valid`.
    /// `Start` and `End` are intentionally ignored.
    fn process_valid(&mut self, event: LearnerEvent<V>) {
        match event {
            // No-op: metric definitions are only registered in `process_train` on `Start`.
            LearnerEvent::Start => {}
            LearnerEvent::ProcessedItem(item) => {
                let synced = item.sync();
                let progress = (&synced).into();
                let metadata = (&synced).into();
                let update = self.metrics.update_valid(&synced, &metadata);

                let record = crate::metric::store::Event::MetricsUpdate(update.clone());
                self.store.add_event_valid(record);

                for entry in update.entries {
                    self.renderer.update_valid(MetricState::Generic(entry));
                }
                for numeric in update.entries_numeric {
                    self.renderer.update_valid(MetricState::Numeric(
                        numeric.entry,
                        numeric.numeric_entry,
                    ));
                }

                let indicators = self.progress_indicators(&progress);
                self.renderer.render_valid(progress, indicators);
            }
            LearnerEvent::EndEpoch(epoch) => {
                let summary = EpochSummary::new(epoch, Split::Valid);
                self.store
                    .add_event_valid(crate::metric::store::Event::EndEpoch(summary));
                self.metrics.end_epoch_valid();
            }
            // No-op: the validation end summary is not forwarded to the renderer.
            LearnerEvent::End(_) => {}
        }
    }

    /// Consumes the processor and hands back its renderer.
    fn renderer(self) -> Box<dyn MetricsRenderer> {
        self.renderer
    }
}