// burn_train/metric/processor/async_wrapper.rs
use async_channel::{Receiver, Sender};

use super::{EventProcessorTraining, LearnerEvent};
use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};

6pub struct AsyncProcessorTraining<P: EventProcessorTraining> {
8 sender: Sender<Message<P>>,
9}
10
11pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
13 sender: Sender<EvalMessage<P>>,
14}
15
16struct WorkerTraining<P: EventProcessorTraining> {
17 processor: P,
18 rec: Receiver<Message<P>>,
19}
20
21struct WorkerEvaluation<P: EventProcessorEvaluation> {
22 processor: P,
23 rec: Receiver<EvalMessage<P>>,
24}
25
26impl<P: EventProcessorTraining + 'static> WorkerTraining<P> {
27 pub fn start(processor: P, rec: Receiver<Message<P>>) {
28 let mut worker = Self { processor, rec };
29
30 std::thread::spawn(move || {
31 while let Ok(msg) = worker.rec.recv_blocking() {
32 match msg {
33 Message::Train(event) => worker.processor.process_train(event),
34 Message::Valid(event) => worker.processor.process_valid(event),
35 Message::Renderer(callback) => {
36 callback.send_blocking(worker.processor.renderer()).unwrap();
37 return;
38 }
39 }
40 }
41 });
42 }
43}
44impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
45 pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {
46 let mut worker = Self { processor, rec };
47
48 std::thread::spawn(move || {
49 while let Ok(event) = worker.rec.recv_blocking() {
50 match event {
51 EvalMessage::Test(event) => worker.processor.process_test(event),
52 EvalMessage::Renderer(sender) => {
53 sender.send_blocking(worker.processor.renderer()).unwrap();
54 return;
55 }
56 }
57 }
58 });
59 }
60}
61
62impl<P: EventProcessorTraining + 'static> AsyncProcessorTraining<P> {
63 pub fn new(processor: P) -> Self {
65 let (sender, rec) = async_channel::bounded(1);
66
67 WorkerTraining::start(processor, rec);
68
69 Self { sender }
70 }
71}
72
73impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
74 pub fn new(processor: P) -> Self {
76 let (sender, rec) = async_channel::bounded(1);
77
78 WorkerEvaluation::start(processor, rec);
79
80 Self { sender }
81 }
82}
83
84enum Message<P: EventProcessorTraining> {
85 Train(LearnerEvent<P::ItemTrain>),
86 Valid(LearnerEvent<P::ItemValid>),
87 Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
88}
89
90enum EvalMessage<P: EventProcessorEvaluation> {
91 Test(EvaluatorEvent<P::ItemTest>),
92 Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
93}
94
95impl<P: EventProcessorTraining> EventProcessorTraining for AsyncProcessorTraining<P> {
96 type ItemTrain = P::ItemTrain;
97 type ItemValid = P::ItemValid;
98
99 fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
100 self.sender.send_blocking(Message::Train(event)).unwrap();
101 }
102
103 fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
104 self.sender.send_blocking(Message::Valid(event)).unwrap();
105 }
106
107 fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
108 let (sender, rec) = async_channel::bounded(1);
109 self.sender
110 .send_blocking(Message::Renderer(sender))
111 .unwrap();
112
113 match rec.recv_blocking() {
114 Ok(value) => value,
115 Err(err) => panic!("{err:?}"),
116 }
117 }
118}
119
120impl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {
121 type ItemTest = P::ItemTest;
122
123 fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
124 self.sender.send_blocking(EvalMessage::Test(event)).unwrap();
125 }
126
127 fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
128 let (sender, rec) = async_channel::bounded(1);
129 self.sender
130 .send_blocking(EvalMessage::Renderer(sender))
131 .unwrap();
132
133 match rec.recv_blocking() {
134 Ok(value) => value,
135 Err(err) => panic!("{err:?}"),
136 }
137 }
138}