burn_train/evaluator/builder.rs

1use crate::{
2    ApplicationLoggerInstaller, Evaluator, FileApplicationLoggerInstaller, Interrupter, TestStep,
3    evaluator::components::EvaluatorComponentTypesMarker,
4    logger::FileMetricLogger,
5    metric::{
6        Adaptor, ItemLazy, Metric,
7        processor::{AsyncProcessorEvaluation, FullEventProcessorEvaluation, MetricsEvaluation},
8        store::{EventStoreClient, LogEventStore},
9    },
10    renderer::{MetricsRenderer, default_renderer},
11};
12use burn_core::{module::Module, prelude::Backend};
13use std::{
14    collections::BTreeSet,
15    marker::PhantomData,
16    path::{Path, PathBuf},
17    sync::Arc,
18};
19
20/// Struct to configure and create an [evaluator](Evaluator).
21///
22/// The generics components of the builder should probably not be set manually, as they are
23/// optimized for Rust type inference.
24pub struct EvaluatorBuilder<B: Backend, TI, TO: ItemLazy> {
25    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
26    event_store: LogEventStore,
27    summary_metrics: BTreeSet<String>,
28    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
29    interrupter: Interrupter,
30    metrics: MetricsEvaluation<TO>,
31    directory: PathBuf,
32    summary: bool,
33    _p: PhantomData<(B, TI, TO)>,
34}
35
36impl<B: Backend, TI, TO: ItemLazy + 'static> EvaluatorBuilder<B, TI, TO> {
37    /// Creates a new evaluator builder.
38    ///
39    /// # Arguments
40    ///
41    /// * `directory` - The directory to save the checkpoints.
42    pub fn new(directory: impl AsRef<Path>) -> Self {
43        let directory = directory.as_ref().to_path_buf();
44        let log_file = directory.join("evaluation.log");
45
46        Self {
47            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(log_file))),
48            event_store: LogEventStore::default(),
49            summary_metrics: Default::default(),
50            renderer: None,
51            interrupter: Interrupter::new(),
52            summary: false,
53            metrics: MetricsEvaluation::default(),
54            directory,
55            _p: PhantomData,
56        }
57    }
58
59    /// Registers [numeric](crate::metric::Numeric) test [metrics](Metric).
60    pub fn metrics<M: EvalMetricRegistration<TI, TO>>(self, metrics: M) -> Self {
61        metrics.register(self)
62    }
63
64    /// Registers text [metrics](Metric).
65    pub fn metrics_text<M: EvalTextMetricRegistration<TI, TO>>(self, metrics: M) -> Self {
66        metrics.register(self)
67    }
68
69    /// By default, Rust logs are captured and written into
70    /// `evaluation.log`. If disabled, standard Rust log handling
71    /// will apply.
72    pub fn with_application_logger(
73        mut self,
74        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
75    ) -> Self {
76        self.tracing_logger = logger;
77        self
78    }
79
80    /// Register a [numeric](crate::metric::Numeric) test [metric](Metric).
81    pub fn metric_numeric<Me: Metric + crate::metric::Numeric + 'static>(
82        mut self,
83        metric: Me,
84    ) -> Self
85    where
86        <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
87    {
88        self.summary_metrics.insert(metric.name().to_string());
89        self.metrics.register_test_metric_numeric(metric);
90        self
91    }
92
93    /// Register a text test [metric](Metric).
94    pub fn metric<Me: Metric + 'static>(mut self, metric: Me) -> Self
95    where
96        <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
97    {
98        self.summary_metrics.insert(metric.name().to_string());
99        self.metrics.register_test_metric(metric);
100        self
101    }
102
103    /// Replace the default CLI renderer with a custom one.
104    ///
105    /// # Arguments
106    ///
107    /// * `renderer` - The custom renderer.
108    pub fn renderer(mut self, renderer: Box<dyn MetricsRenderer + 'static>) -> Self {
109        self.renderer = Some(renderer);
110        self
111    }
112
113    /// Enable the evaluation summary report.
114    ///
115    /// The summary will be displayed at the end of `.eval()`.
116    pub fn summary(mut self) -> Self {
117        self.summary = true;
118        self
119    }
120
121    /// Builds the evaluator.
122    #[allow(clippy::type_complexity)]
123    pub fn build<M>(
124        mut self,
125        model: M,
126    ) -> Evaluator<
127        EvaluatorComponentTypesMarker<
128            B,
129            M,
130            AsyncProcessorEvaluation<FullEventProcessorEvaluation<TO>>,
131            TI,
132            TO,
133        >,
134    >
135    where
136        TI: Send + 'static,
137        M: Module<B> + TestStep<TI, TO> + core::fmt::Display + 'static,
138    {
139        let renderer = self
140            .renderer
141            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), None));
142
143        self.event_store
144            .register_logger_test(FileMetricLogger::new_eval(self.directory.join("test")));
145        let event_store = Arc::new(EventStoreClient::new(self.event_store));
146
147        let event_processor = AsyncProcessorEvaluation::new(FullEventProcessorEvaluation::new(
148            self.metrics,
149            renderer,
150            event_store,
151        ));
152
153        Evaluator {
154            model,
155            interrupter: self.interrupter,
156            event_processor,
157        }
158    }
159}
160
161/// Trait to fake variadic generics.
162pub trait EvalMetricRegistration<TI, TO: ItemLazy>: Sized {
163    /// Register the metrics.
164    fn register<B: Backend>(
165        self,
166        builder: EvaluatorBuilder<B, TI, TO>,
167    ) -> EvaluatorBuilder<B, TI, TO>;
168}
169
170/// Trait to fake variadic generics.
171pub trait EvalTextMetricRegistration<TI, TO: ItemLazy>: Sized {
172    /// Register the metrics.
173    fn register<B: Backend>(
174        self,
175        builder: EvaluatorBuilder<B, TI, TO>,
176    ) -> EvaluatorBuilder<B, TI, TO>;
177}
178
179macro_rules! gen_tuple {
180    ($($M:ident),*) => {
181        impl<$($M,)* TI: 'static, TO: ItemLazy+'static> EvalTextMetricRegistration<TI, TO> for ($($M,)*)
182        where
183            $(TO::ItemSync: Adaptor<$M::Input>,)*
184            $($M: Metric + 'static,)*
185        {
186            #[allow(non_snake_case)]
187            fn register<B: Backend>(
188                self,
189                builder: EvaluatorBuilder<B, TI, TO>,
190            ) -> EvaluatorBuilder<B, TI, TO> {
191                let ($($M,)*) = self;
192                $(let builder = builder.metric($M);)*
193                builder
194            }
195        }
196
197        impl<$($M,)* TI: 'static, TO: ItemLazy+'static> EvalMetricRegistration<TI, TO> for ($($M,)*)
198        where
199            $(TO::ItemSync: Adaptor<$M::Input>,)*
200            $($M: Metric + $crate::metric::Numeric+ 'static,)*
201        {
202            #[allow(non_snake_case)]
203            fn register<B: Backend>(
204                self,
205                builder: EvaluatorBuilder<B, TI, TO>,
206            ) -> EvaluatorBuilder<B, TI, TO> {
207                let ($($M,)*) = self;
208                $(let builder = builder.metric_numeric($M);)*
209                builder
210            }
211        }
212    };
213}
214
215gen_tuple!(M1);
216gen_tuple!(M1, M2);
217gen_tuple!(M1, M2, M3);
218gen_tuple!(M1, M2, M3, M4);
219gen_tuple!(M1, M2, M3, M4, M5);
220gen_tuple!(M1, M2, M3, M4, M5, M6);