Skip to main content

burn_train/metric/processor/
async_wrapper.rs

1use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};
2
3use super::EventProcessorTraining;
4use async_channel::{Receiver, Sender};
5
6/// Event processor for the training process.
7pub struct AsyncProcessorTraining<ET, EV> {
8    sender: Sender<Message<ET, EV>>,
9}
10
11/// Event processor for the model evaluation.
12pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
13    sender: Sender<EvalMessage<P>>,
14}
15
16struct WorkerTraining<ET, EV, P: EventProcessorTraining<ET, EV>> {
17    processor: P,
18    rec: Receiver<Message<ET, EV>>,
19}
20
21struct WorkerEvaluation<P: EventProcessorEvaluation> {
22    processor: P,
23    rec: Receiver<EvalMessage<P>>,
24}
25
26impl<ET: Send + 'static, EV: Send + 'static, P: EventProcessorTraining<ET, EV> + 'static>
27    WorkerTraining<ET, EV, P>
28{
29    pub fn start(processor: P, rec: Receiver<Message<ET, EV>>) {
30        let mut worker = Self { processor, rec };
31        std::thread::Builder::new()
32            .name("train-worker".into())
33            .spawn(move || {
34                while let Ok(msg) = worker.rec.recv_blocking() {
35                    match msg {
36                        Message::Train(event) => worker.processor.process_train(event),
37                        Message::Valid(event) => worker.processor.process_valid(event),
38                        Message::Renderer(callback) => {
39                            callback.send_blocking(worker.processor.renderer()).unwrap();
40                            return;
41                        }
42                    }
43                }
44            })
45            .unwrap();
46    }
47}
48impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
49    pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {
50        let mut worker = Self { processor, rec };
51
52        std::thread::Builder::new()
53            .name("evel-worker".into())
54            .spawn(move || {
55                while let Ok(event) = worker.rec.recv_blocking() {
56                    match event {
57                        EvalMessage::Test(event) => worker.processor.process_test(event),
58                        EvalMessage::Renderer(sender) => {
59                            sender.send_blocking(worker.processor.renderer()).unwrap();
60                            return;
61                        }
62                    }
63                }
64            })
65            .unwrap();
66    }
67}
68
69impl<ET: Send + 'static, EV: Send + 'static> AsyncProcessorTraining<ET, EV> {
70    /// Create an event processor for training.
71    pub fn new<P: EventProcessorTraining<ET, EV> + 'static>(processor: P) -> Self {
72        let (sender, rec) = async_channel::bounded(1);
73
74        WorkerTraining::start(processor, rec);
75
76        Self { sender }
77    }
78}
79
80impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
81    /// Create an event processor for model evaluation.
82    pub fn new(processor: P) -> Self {
83        let (sender, rec) = async_channel::bounded(1);
84
85        WorkerEvaluation::start(processor, rec);
86
87        Self { sender }
88    }
89}
90
91enum Message<EventTrain, EventValid> {
92    Train(EventTrain),
93    Valid(EventValid),
94    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
95}
96
97enum EvalMessage<P: EventProcessorEvaluation> {
98    Test(EvaluatorEvent<P::ItemTest>),
99    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
100}
101
102impl<ET: Send, EV: Send> EventProcessorTraining<ET, EV> for AsyncProcessorTraining<ET, EV> {
103    fn process_train(&mut self, event: ET) {
104        self.sender.send_blocking(Message::Train(event)).unwrap();
105    }
106
107    fn process_valid(&mut self, event: EV) {
108        self.sender.send_blocking(Message::Valid(event)).unwrap();
109    }
110
111    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
112        let (sender, rec) = async_channel::bounded(1);
113        self.sender
114            .send_blocking(Message::Renderer(sender))
115            .unwrap();
116
117        match rec.recv_blocking() {
118            Ok(value) => value,
119            Err(err) => panic!("{err:?}"),
120        }
121    }
122}
123
124impl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {
125    type ItemTest = P::ItemTest;
126
127    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
128        self.sender.send_blocking(EvalMessage::Test(event)).unwrap();
129    }
130
131    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
132        let (sender, rec) = async_channel::bounded(1);
133        self.sender
134            .send_blocking(EvalMessage::Renderer(sender))
135            .unwrap();
136
137        match rec.recv_blocking() {
138            Ok(value) => value,
139            Err(err) => panic!("{err:?}"),
140        }
141    }
142}