burn_train/evaluator/
base.rs

1use crate::{
2    Interrupter,
3    evaluator::components::{EvaluatorComponentTypes, TestStep},
4    metric::processor::{EvaluatorEvent, EventProcessorEvaluation, LearnerItem},
5    renderer::{EvaluationName, MetricsRenderer},
6};
7use burn_core::{data::dataloader::DataLoader, module::Module};
8use std::sync::Arc;
9
10/// Evaluates a model on a specific dataset.
11pub struct Evaluator<EC: EvaluatorComponentTypes> {
12    pub(crate) model: EC::Model,
13    pub(crate) interrupter: Interrupter,
14    pub(crate) event_processor: EC::EventProcessor,
15}
16
17pub(crate) type TestBackend<EC> = <EC as EvaluatorComponentTypes>::Backend;
18pub(crate) type TestInput<EC> = <EC as EvaluatorComponentTypes>::TestInput;
19
20pub(crate) type TestLoader<EC> = Arc<dyn DataLoader<TestBackend<EC>, TestInput<EC>>>;
21
22impl<EC: EvaluatorComponentTypes> Evaluator<EC> {
23    /// Run the evaluation on the given dataset.
24    ///
25    /// The data will be stored and displayed under the provided name.
26    pub fn eval<S: core::fmt::Display>(
27        mut self,
28        name: S,
29        dataloader: TestLoader<EC>,
30    ) -> Box<dyn MetricsRenderer> {
31        // Move dataloader to the model device
32        let dataloader = dataloader.to_device(self.model.devices().first().unwrap());
33
34        let name = EvaluationName::new(name);
35        let mut iterator = dataloader.iter();
36        let mut iteration = 0;
37
38        while let Some(item) = iterator.next() {
39            let progress = iterator.progress();
40            iteration += 1;
41
42            let item = self.model.step(item);
43            let item = LearnerItem::new(item, progress, 0, 1, iteration, None);
44
45            self.event_processor
46                .process_test(EvaluatorEvent::ProcessedItem(name.clone(), item));
47
48            if self.interrupter.should_stop() {
49                log::info!("Testing interrupted.");
50                break;
51            }
52        }
53
54        self.event_processor.renderer()
55    }
56}