// burn_train/evaluator/components.rs

1use crate::metric::{ItemLazy, processor::EventProcessorEvaluation};
2use burn_core::{module::Module, prelude::Backend};
3use std::marker::PhantomData;
4
5/// All components necessary to evaluate a model grouped in one trait.
6pub trait EvaluatorComponentTypes {
7    /// The backend in used for the evaluation.
8    type Backend: Backend;
9    /// The model to evaluate.
10    type Model: Module<Self::Backend>
11        + TestStep<Self::TestInput, Self::TestOutput>
12        + core::fmt::Display
13        + 'static;
14    type EventProcessor: EventProcessorEvaluation<ItemTest = Self::TestOutput> + 'static;
15    /// Type of input to the evaluation step
16    type TestInput: Send + 'static;
17    /// Type of output of the evaluation step
18    type TestOutput: ItemLazy + 'static;
19}
20
/// Trait to be implemented by models that can be run on test data.
///
/// `TI` is the type of the item passed to [`TestStep::step`] and `TO` is the type
/// of the output it returns.
pub trait TestStep<TI, TO> {
    /// Runs a test step.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to test on.
    ///
    /// # Returns
    ///
    /// The test output.
    fn step(&self, item: TI) -> TO;
}
34
/// Type-level marker that supplies [evaluation components](EvaluatorComponentTypes)
/// entirely through its generic parameters.
///
/// It carries no runtime data; the single zero-sized field only exists so that every
/// type parameter is used by the struct definition.
pub struct EvaluatorComponentTypesMarker<
    B,  // backend
    M,  // model
    E,  // event processor
    TI, // test input
    TO, // test output
> {
    _p: PhantomData<(B, M, E, TI, TO)>,
}
39
40impl<B, M, E, TI, TO> EvaluatorComponentTypes for EvaluatorComponentTypesMarker<B, M, E, TI, TO>
41where
42    B: Backend,
43    M: Module<B> + TestStep<TI, TO> + core::fmt::Display + 'static,
44    E: EventProcessorEvaluation<ItemTest = TO> + 'static,
45    TI: Send + 'static,
46    TO: ItemLazy + 'static,
47{
48    type Backend = B;
49    type Model = M;
50    type EventProcessor = E;
51    type TestInput = TI;
52    type TestOutput = TO;
53}