burn_train/metric/processor/
async_wrapper.rs1use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};
2
3use super::EventProcessorTraining;
4use async_channel::{Receiver, Sender};
5
6pub struct AsyncProcessorTraining<ET, EV> {
8 sender: Sender<Message<ET, EV>>,
9}
10
11pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
13 sender: Sender<EvalMessage<P>>,
14}
15
16struct WorkerTraining<ET, EV, P: EventProcessorTraining<ET, EV>> {
17 processor: P,
18 rec: Receiver<Message<ET, EV>>,
19}
20
21struct WorkerEvaluation<P: EventProcessorEvaluation> {
22 processor: P,
23 rec: Receiver<EvalMessage<P>>,
24}
25
26impl<ET: Send + 'static, EV: Send + 'static, P: EventProcessorTraining<ET, EV> + 'static>
27 WorkerTraining<ET, EV, P>
28{
29 pub fn start(processor: P, rec: Receiver<Message<ET, EV>>) {
30 let mut worker = Self { processor, rec };
31 std::thread::Builder::new()
32 .name("train-worker".into())
33 .spawn(move || {
34 while let Ok(msg) = worker.rec.recv_blocking() {
35 match msg {
36 Message::Train(event) => worker.processor.process_train(event),
37 Message::Valid(event) => worker.processor.process_valid(event),
38 Message::Renderer(callback) => {
39 callback.send_blocking(worker.processor.renderer()).unwrap();
40 return;
41 }
42 }
43 }
44 })
45 .unwrap();
46 }
47}
48impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
49 pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {
50 let mut worker = Self { processor, rec };
51
52 std::thread::Builder::new()
53 .name("evel-worker".into())
54 .spawn(move || {
55 while let Ok(event) = worker.rec.recv_blocking() {
56 match event {
57 EvalMessage::Test(event) => worker.processor.process_test(event),
58 EvalMessage::Renderer(sender) => {
59 sender.send_blocking(worker.processor.renderer()).unwrap();
60 return;
61 }
62 }
63 }
64 })
65 .unwrap();
66 }
67}
68
69impl<ET: Send + 'static, EV: Send + 'static> AsyncProcessorTraining<ET, EV> {
70 pub fn new<P: EventProcessorTraining<ET, EV> + 'static>(processor: P) -> Self {
72 let (sender, rec) = async_channel::bounded(1);
73
74 WorkerTraining::start(processor, rec);
75
76 Self { sender }
77 }
78}
79
80impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
81 pub fn new(processor: P) -> Self {
83 let (sender, rec) = async_channel::bounded(1);
84
85 WorkerEvaluation::start(processor, rec);
86
87 Self { sender }
88 }
89}
90
91enum Message<EventTrain, EventValid> {
92 Train(EventTrain),
93 Valid(EventValid),
94 Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
95}
96
97enum EvalMessage<P: EventProcessorEvaluation> {
98 Test(EvaluatorEvent<P::ItemTest>),
99 Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
100}
101
102impl<ET: Send, EV: Send> EventProcessorTraining<ET, EV> for AsyncProcessorTraining<ET, EV> {
103 fn process_train(&mut self, event: ET) {
104 self.sender.send_blocking(Message::Train(event)).unwrap();
105 }
106
107 fn process_valid(&mut self, event: EV) {
108 self.sender.send_blocking(Message::Valid(event)).unwrap();
109 }
110
111 fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
112 let (sender, rec) = async_channel::bounded(1);
113 self.sender
114 .send_blocking(Message::Renderer(sender))
115 .unwrap();
116
117 match rec.recv_blocking() {
118 Ok(value) => value,
119 Err(err) => panic!("{err:?}"),
120 }
121 }
122}
123
124impl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {
125 type ItemTest = P::ItemTest;
126
127 fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
128 self.sender.send_blocking(EvalMessage::Test(event)).unwrap();
129 }
130
131 fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
132 let (sender, rec) = async_channel::bounded(1);
133 self.sender
134 .send_blocking(EvalMessage::Renderer(sender))
135 .unwrap();
136
137 match rec.recv_blocking() {
138 Ok(value) => value,
139 Err(err) => panic!("{err:?}"),
140 }
141 }
142}