burn_train/logger/
async_logger.rs1use super::Logger;
2use std::sync::mpsc;
3
/// Commands sent from [`AsyncLogger`] to its background thread.
enum Message<T> {
    /// An item to forward to the wrapped logger.
    Log(T),
    /// Request an acknowledgement on the given channel once every message
    /// queued before this one has been handled.
    Sync(mpsc::Sender<()>),
    /// Stop the background thread; anything queued after it is not handled.
    End,
}
9pub struct AsyncLogger<T> {
11 sender: mpsc::Sender<Message<T>>,
12 handler: Option<std::thread::JoinHandle<()>>,
13}
14
15#[derive(new)]
16struct LoggerThread<T, L: Logger<T>> {
17 logger: L,
18 receiver: mpsc::Receiver<Message<T>>,
19}
20
21impl<T, L> LoggerThread<T, L>
22where
23 L: Logger<T>,
24{
25 fn run(mut self) {
26 for item in self.receiver.iter() {
27 match item {
28 Message::Log(item) => {
29 self.logger.log(item);
30 }
31 Message::End => {
32 return;
33 }
34 Message::Sync(callback) => {
35 callback
36 .send(())
37 .expect("Can return result with the callback channel.");
38 }
39 }
40 }
41 }
42}
43
44impl<T: Send + Sync + 'static> AsyncLogger<T> {
45 pub fn new<L>(logger: L) -> Self
47 where
48 L: Logger<T> + 'static,
49 {
50 let (sender, receiver) = mpsc::channel();
51 let thread = LoggerThread::new(logger, receiver);
52
53 let handler = Some(std::thread::spawn(move || thread.run()));
54
55 Self { sender, handler }
56 }
57
58 pub(crate) fn sync(&self) {
60 let (sender, receiver) = mpsc::channel();
61
62 self.sender
63 .send(Message::Sync(sender))
64 .expect("Can send message to logger thread.");
65
66 receiver
67 .recv()
68 .expect("Should sync, otherwise the thread is dead.");
69 }
70}
71
72impl<T: Send> Logger<T> for AsyncLogger<T> {
73 fn log(&mut self, item: T) {
74 self.sender
75 .send(Message::Log(item))
76 .expect("Can log using the logger thread.");
77 }
78}
79
80impl<T> Drop for AsyncLogger<T> {
81 fn drop(&mut self) {
82 self.sender
83 .send(Message::End)
84 .expect("Can send the end message to the logger thread.");
85 let handler = self.handler.take();
86
87 if let Some(handler) = handler {
88 handler.join().expect("The logger thread should stop.");
89 }
90 }
91}