Skip to main content

burn_train/evaluator/
base.rs

1use crate::{
2    AsyncProcessorEvaluation, FullEventProcessorEvaluation, InferenceStep, Interrupter,
3    evaluator::components::EvaluatorComponentTypes,
4    metric::processor::{EvaluatorEvent, EventProcessorEvaluation, LearnerItem},
5    renderer::{EvaluationName, MetricsRenderer},
6};
7use burn_core::{data::dataloader::DataLoader, module::Module};
8use std::sync::Arc;
9
10pub(crate) type TestBackend<EC> = <EC as EvaluatorComponentTypes>::Backend;
11pub(crate) type TestInput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Input;
12pub(crate) type TestOutput<EC> = <<EC as EvaluatorComponentTypes>::Model as InferenceStep>::Output;
13
14pub(crate) type TestLoader<EC> = Arc<dyn DataLoader<TestBackend<EC>, TestInput<EC>>>;
15
16/// Evaluates a model on a specific dataset.
17pub struct Evaluator<EC: EvaluatorComponentTypes> {
18    pub(crate) model: EC::Model,
19    pub(crate) interrupter: Interrupter,
20    pub(crate) event_processor:
21        AsyncProcessorEvaluation<FullEventProcessorEvaluation<TestOutput<EC>>>,
22}
23
24impl<EC: EvaluatorComponentTypes> Evaluator<EC> {
25    /// Run the evaluation on the given dataset.
26    ///
27    /// The data will be stored and displayed under the provided name.
28    pub fn eval<S: core::fmt::Display>(
29        mut self,
30        name: S,
31        dataloader: TestLoader<EC>,
32    ) -> Box<dyn MetricsRenderer> {
33        // Move dataloader to the model device
34        let dataloader = dataloader.to_device(self.model.devices().first().unwrap());
35
36        let name = EvaluationName::new(name);
37        let mut iterator = dataloader.iter();
38        let mut iteration = 0;
39
40        self.event_processor.process_test(EvaluatorEvent::Start);
41        while let Some(item) = iterator.next() {
42            let progress = iterator.progress();
43            iteration += 1;
44
45            let item = self.model.step(item);
46            let item = LearnerItem::new(item, progress, 0, 1, iteration, None);
47
48            self.event_processor
49                .process_test(EvaluatorEvent::ProcessedItem(name.clone(), item));
50
51            if self.interrupter.should_stop() {
52                log::info!("Testing interrupted.");
53                break;
54            }
55        }
56
57        self.event_processor.renderer()
58    }
59}