1use crate::{
2 ApplicationLoggerInstaller, Evaluator, FileApplicationLoggerInstaller, Interrupter, TestStep,
3 evaluator::components::EvaluatorComponentTypesMarker,
4 logger::FileMetricLogger,
5 metric::{
6 Adaptor, ItemLazy, Metric,
7 processor::{AsyncProcessorEvaluation, FullEventProcessorEvaluation, MetricsEvaluation},
8 store::{EventStoreClient, LogEventStore},
9 },
10 renderer::{MetricsRenderer, default_renderer},
11};
12use burn_core::{module::Module, prelude::Backend};
13use std::{
14 collections::BTreeSet,
15 marker::PhantomData,
16 path::{Path, PathBuf},
17 sync::Arc,
18};
19
20pub struct EvaluatorBuilder<B: Backend, TI, TO: ItemLazy> {
25 tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
26 event_store: LogEventStore,
27 summary_metrics: BTreeSet<String>,
28 renderer: Option<Box<dyn MetricsRenderer + 'static>>,
29 interrupter: Interrupter,
30 metrics: MetricsEvaluation<TO>,
31 directory: PathBuf,
32 summary: bool,
33 _p: PhantomData<(B, TI, TO)>,
34}
35
36impl<B: Backend, TI, TO: ItemLazy + 'static> EvaluatorBuilder<B, TI, TO> {
37 pub fn new(directory: impl AsRef<Path>) -> Self {
43 let directory = directory.as_ref().to_path_buf();
44 let log_file = directory.join("evaluation.log");
45
46 Self {
47 tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(log_file))),
48 event_store: LogEventStore::default(),
49 summary_metrics: Default::default(),
50 renderer: None,
51 interrupter: Interrupter::new(),
52 summary: false,
53 metrics: MetricsEvaluation::default(),
54 directory,
55 _p: PhantomData,
56 }
57 }
58
59 pub fn metrics<M: EvalMetricRegistration<TI, TO>>(self, metrics: M) -> Self {
61 metrics.register(self)
62 }
63
64 pub fn metrics_text<M: EvalTextMetricRegistration<TI, TO>>(self, metrics: M) -> Self {
66 metrics.register(self)
67 }
68
69 pub fn with_application_logger(
73 mut self,
74 logger: Option<Box<dyn ApplicationLoggerInstaller>>,
75 ) -> Self {
76 self.tracing_logger = logger;
77 self
78 }
79
80 pub fn metric_numeric<Me: Metric + crate::metric::Numeric + 'static>(
82 mut self,
83 metric: Me,
84 ) -> Self
85 where
86 <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
87 {
88 self.summary_metrics.insert(metric.name().to_string());
89 self.metrics.register_test_metric_numeric(metric);
90 self
91 }
92
93 pub fn metric<Me: Metric + 'static>(mut self, metric: Me) -> Self
95 where
96 <TO as ItemLazy>::ItemSync: Adaptor<Me::Input>,
97 {
98 self.summary_metrics.insert(metric.name().to_string());
99 self.metrics.register_test_metric(metric);
100 self
101 }
102
103 pub fn renderer(mut self, renderer: Box<dyn MetricsRenderer + 'static>) -> Self {
109 self.renderer = Some(renderer);
110 self
111 }
112
113 pub fn summary(mut self) -> Self {
117 self.summary = true;
118 self
119 }
120
121 #[allow(clippy::type_complexity)]
123 pub fn build<M>(
124 mut self,
125 model: M,
126 ) -> Evaluator<
127 EvaluatorComponentTypesMarker<
128 B,
129 M,
130 AsyncProcessorEvaluation<FullEventProcessorEvaluation<TO>>,
131 TI,
132 TO,
133 >,
134 >
135 where
136 TI: Send + 'static,
137 M: Module<B> + TestStep<TI, TO> + core::fmt::Display + 'static,
138 {
139 let renderer = self
140 .renderer
141 .unwrap_or_else(|| default_renderer(self.interrupter.clone(), None));
142
143 self.event_store
144 .register_logger_test(FileMetricLogger::new_eval(self.directory.join("test")));
145 let event_store = Arc::new(EventStoreClient::new(self.event_store));
146
147 let event_processor = AsyncProcessorEvaluation::new(FullEventProcessorEvaluation::new(
148 self.metrics,
149 renderer,
150 event_store,
151 ));
152
153 Evaluator {
154 model,
155 interrupter: self.interrupter,
156 event_processor,
157 }
158 }
159}
160
161pub trait EvalMetricRegistration<TI, TO: ItemLazy>: Sized {
163 fn register<B: Backend>(
165 self,
166 builder: EvaluatorBuilder<B, TI, TO>,
167 ) -> EvaluatorBuilder<B, TI, TO>;
168}
169
170pub trait EvalTextMetricRegistration<TI, TO: ItemLazy>: Sized {
172 fn register<B: Backend>(
174 self,
175 builder: EvaluatorBuilder<B, TI, TO>,
176 ) -> EvaluatorBuilder<B, TI, TO>;
177}
178
/// Implements [`EvalTextMetricRegistration`] and [`EvalMetricRegistration`] for the tuple whose
/// element types are named by the given identifiers.
///
/// The text variant forwards each element to [`EvaluatorBuilder::metric`]. The numeric variant
/// forwards each element to [`EvaluatorBuilder::metric_numeric`]. Elements are registered
/// left to right.
macro_rules! gen_tuple {
    ($($M:ident),*) => {
        impl<TI: 'static, TO: ItemLazy + 'static, $($M,)*> EvalTextMetricRegistration<TI, TO>
            for ($($M,)*)
        where
            $($M: Metric + 'static,)*
            $(TO::ItemSync: Adaptor<$M::Input>,)*
        {
            // The tuple is destructured into bindings that reuse the type parameter names
            // (`M1`, `M2`, ...), which are not snake_case.
            #[allow(non_snake_case)]
            fn register<B: Backend>(
                self,
                builder: EvaluatorBuilder<B, TI, TO>,
            ) -> EvaluatorBuilder<B, TI, TO> {
                let ($($M,)*) = self;
                builder $(.metric($M))*
            }
        }

        impl<TI: 'static, TO: ItemLazy + 'static, $($M,)*> EvalMetricRegistration<TI, TO>
            for ($($M,)*)
        where
            $($M: Metric + $crate::metric::Numeric + 'static,)*
            $(TO::ItemSync: Adaptor<$M::Input>,)*
        {
            // Same destructuring as above: the bindings reuse the type parameter names.
            #[allow(non_snake_case)]
            fn register<B: Backend>(
                self,
                builder: EvaluatorBuilder<B, TI, TO>,
            ) -> EvaluatorBuilder<B, TI, TO> {
                let ($($M,)*) = self;
                builder $(.metric_numeric($M))*
            }
        }
    };
}
214
215gen_tuple!(M1);
216gen_tuple!(M1, M2);
217gen_tuple!(M1, M2, M3);
218gen_tuple!(M1, M2, M3, M4);
219gen_tuple!(M1, M2, M3, M4, M5);
220gen_tuple!(M1, M2, M3, M4, M5, M6);