// burn_train/logger/async_logger.rs

1use super::Logger;
2use std::sync::mpsc;
3
/// Messages sent from an `AsyncLogger` handle to its background logger thread.
///
/// The channel delivers them in send order, and the thread handles them one at a time.
enum Message<T> {
    /// An item to hand to the wrapped logger.
    Log(T),
    /// Acknowledge on the given channel once every message queued before this one
    /// has been handled.
    Sync(mpsc::Sender<()>),
    /// Tell the background thread to leave its loop.
    End,
}
9/// Async logger.
10pub struct AsyncLogger<T> {
11    sender: mpsc::Sender<Message<T>>,
12    handler: Option<std::thread::JoinHandle<()>>,
13}
14
15#[derive(new)]
16struct LoggerThread<T, L: Logger<T>> {
17    logger: L,
18    receiver: mpsc::Receiver<Message<T>>,
19}
20
21impl<T, L> LoggerThread<T, L>
22where
23    L: Logger<T>,
24{
25    fn run(mut self) {
26        for item in self.receiver.iter() {
27            match item {
28                Message::Log(item) => {
29                    self.logger.log(item);
30                }
31                Message::End => {
32                    return;
33                }
34                Message::Sync(callback) => {
35                    callback
36                        .send(())
37                        .expect("Can return result with the callback channel.");
38                }
39            }
40        }
41    }
42}
43
44impl<T: Send + Sync + 'static> AsyncLogger<T> {
45    /// Create a new async logger.
46    pub fn new<L>(logger: L) -> Self
47    where
48        L: Logger<T> + 'static,
49    {
50        let (sender, receiver) = mpsc::channel();
51        let thread = LoggerThread::new(logger, receiver);
52
53        let handler = Some(std::thread::spawn(move || thread.run()));
54
55        Self { sender, handler }
56    }
57
58    /// Sync the async logger.
59    pub(crate) fn sync(&self) {
60        let (sender, receiver) = mpsc::channel();
61
62        self.sender
63            .send(Message::Sync(sender))
64            .expect("Can send message to logger thread.");
65
66        receiver
67            .recv()
68            .expect("Should sync, otherwise the thread is dead.");
69    }
70}
71
72impl<T: Send> Logger<T> for AsyncLogger<T> {
73    fn log(&mut self, item: T) {
74        self.sender
75            .send(Message::Log(item))
76            .expect("Can log using the logger thread.");
77    }
78}
79
80impl<T> Drop for AsyncLogger<T> {
81    fn drop(&mut self) {
82        self.sender
83            .send(Message::End)
84            .expect("Can send the end message to the logger thread.");
85        let handler = self.handler.take();
86
87        if let Some(handler) = handler {
88            handler.join().expect("The logger thread should stop.");
89        }
90    }
91}