burn_train/metric/processor/async_wrapper.rs

1use crate::metric::processor::{EvaluatorEvent, EventProcessorEvaluation};
2
3use super::{EventProcessorTraining, LearnerEvent};
4use async_channel::{Receiver, Sender};
5
6/// Event processor for the training process.
7pub struct AsyncProcessorTraining<P: EventProcessorTraining> {
8    sender: Sender<Message<P>>,
9}
10
11/// Event processor for the model evaluation.
12pub struct AsyncProcessorEvaluation<P: EventProcessorEvaluation> {
13    sender: Sender<EvalMessage<P>>,
14}
15
16struct WorkerTraining<P: EventProcessorTraining> {
17    processor: P,
18    rec: Receiver<Message<P>>,
19}
20
21struct WorkerEvaluation<P: EventProcessorEvaluation> {
22    processor: P,
23    rec: Receiver<EvalMessage<P>>,
24}
25
26impl<P: EventProcessorTraining + 'static> WorkerTraining<P> {
27    pub fn start(processor: P, rec: Receiver<Message<P>>) {
28        let mut worker = Self { processor, rec };
29
30        std::thread::spawn(move || {
31            while let Ok(msg) = worker.rec.recv_blocking() {
32                match msg {
33                    Message::Train(event) => worker.processor.process_train(event),
34                    Message::Valid(event) => worker.processor.process_valid(event),
35                    Message::Renderer(callback) => {
36                        callback.send_blocking(worker.processor.renderer()).unwrap();
37                        return;
38                    }
39                }
40            }
41        });
42    }
43}
44impl<P: EventProcessorEvaluation + 'static> WorkerEvaluation<P> {
45    pub fn start(processor: P, rec: Receiver<EvalMessage<P>>) {
46        let mut worker = Self { processor, rec };
47
48        std::thread::spawn(move || {
49            while let Ok(event) = worker.rec.recv_blocking() {
50                match event {
51                    EvalMessage::Test(event) => worker.processor.process_test(event),
52                    EvalMessage::Renderer(sender) => {
53                        sender.send_blocking(worker.processor.renderer()).unwrap();
54                        return;
55                    }
56                }
57            }
58        });
59    }
60}
61
62impl<P: EventProcessorTraining + 'static> AsyncProcessorTraining<P> {
63    /// Create an event processor for training.
64    pub fn new(processor: P) -> Self {
65        let (sender, rec) = async_channel::bounded(1);
66
67        WorkerTraining::start(processor, rec);
68
69        Self { sender }
70    }
71}
72
73impl<P: EventProcessorEvaluation + 'static> AsyncProcessorEvaluation<P> {
74    /// Create an event processor for model evaluation.
75    pub fn new(processor: P) -> Self {
76        let (sender, rec) = async_channel::bounded(1);
77
78        WorkerEvaluation::start(processor, rec);
79
80        Self { sender }
81    }
82}
83
84enum Message<P: EventProcessorTraining> {
85    Train(LearnerEvent<P::ItemTrain>),
86    Valid(LearnerEvent<P::ItemValid>),
87    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
88}
89
90enum EvalMessage<P: EventProcessorEvaluation> {
91    Test(EvaluatorEvent<P::ItemTest>),
92    Renderer(Sender<Box<dyn crate::renderer::MetricsRenderer>>),
93}
94
95impl<P: EventProcessorTraining> EventProcessorTraining for AsyncProcessorTraining<P> {
96    type ItemTrain = P::ItemTrain;
97    type ItemValid = P::ItemValid;
98
99    fn process_train(&mut self, event: LearnerEvent<Self::ItemTrain>) {
100        self.sender.send_blocking(Message::Train(event)).unwrap();
101    }
102
103    fn process_valid(&mut self, event: LearnerEvent<Self::ItemValid>) {
104        self.sender.send_blocking(Message::Valid(event)).unwrap();
105    }
106
107    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
108        let (sender, rec) = async_channel::bounded(1);
109        self.sender
110            .send_blocking(Message::Renderer(sender))
111            .unwrap();
112
113        match rec.recv_blocking() {
114            Ok(value) => value,
115            Err(err) => panic!("{err:?}"),
116        }
117    }
118}
119
120impl<P: EventProcessorEvaluation> EventProcessorEvaluation for AsyncProcessorEvaluation<P> {
121    type ItemTest = P::ItemTest;
122
123    fn process_test(&mut self, event: EvaluatorEvent<Self::ItemTest>) {
124        self.sender.send_blocking(EvalMessage::Test(event)).unwrap();
125    }
126
127    fn renderer(self) -> Box<dyn crate::renderer::MetricsRenderer> {
128        let (sender, rec) = async_channel::bounded(1);
129        self.sender
130            .send_blocking(EvalMessage::Renderer(sender))
131            .unwrap();
132
133        match rec.recv_blocking() {
134            Ok(value) => value,
135            Err(err) => panic!("{err:?}"),
136        }
137    }
138}