burn_train/checkpoint/async_checkpoint.rs

1use super::{Checkpointer, CheckpointerError};
2use burn_core::{record::Record, tensor::backend::Backend};
3use std::sync::mpsc;
4
5enum Message<R, B: Backend> {
6    Restore(
7        usize,
8        B::Device,
9        mpsc::SyncSender<Result<R, CheckpointerError>>,
10    ),
11    Save(usize, R),
12    Delete(usize),
13    End,
14}
15
16#[derive(new)]
17struct CheckpointerThread<C, R, B: Backend> {
18    checkpointer: C,
19    receiver: mpsc::Receiver<Message<R, B>>,
20}
21
22impl<C, R, B> CheckpointerThread<C, R, B>
23where
24    C: Checkpointer<R, B>,
25    R: Record<B>,
26    B: Backend,
27{
28    fn run(self) {
29        for item in self.receiver.iter() {
30            match item {
31                Message::Restore(epoch, device, callback) => {
32                    let record = self.checkpointer.restore(epoch, &device);
33                    callback
34                        .send(record)
35                        .expect("Can send response through callback channel.");
36                }
37                Message::Save(epoch, state) => self
38                    .checkpointer
39                    .save(epoch, state)
40                    .expect("Can save the state."),
41                Message::Delete(epoch) => self
42                    .checkpointer
43                    .delete(epoch)
44                    .expect("Can delete the state."),
45                Message::End => {
46                    return;
47                }
48            };
49        }
50    }
51}
52
53/// Async checkpointer.
54pub struct AsyncCheckpointer<Record, B: Backend> {
55    sender: mpsc::SyncSender<Message<Record, B>>,
56    handler: Option<std::thread::JoinHandle<()>>,
57}
58
59impl<R, B> AsyncCheckpointer<R, B>
60where
61    R: Record<B> + 'static,
62    B: Backend,
63{
64    /// Create a new async checkpointer.
65    ///
66    /// # Arguments
67    ///
68    /// * `checkpointer` - The checkpointer.
69    ///
70    /// # Returns
71    ///
72    /// The async checkpointer.
73    pub fn new<C>(checkpointer: C) -> Self
74    where
75        C: Checkpointer<R, B> + Send + 'static,
76    {
77        // Only on checkpoint can be done in advance.
78        let (sender, receiver) = mpsc::sync_channel(0);
79        let thread = CheckpointerThread::new(checkpointer, receiver);
80        let handler = Some(std::thread::spawn(move || thread.run()));
81
82        Self { sender, handler }
83    }
84}
85
86impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
87where
88    R: Record<B> + 'static,
89    B: Backend,
90{
91    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
92        self.sender
93            .send(Message::Save(epoch, record))
94            .expect("Can send message to checkpointer thread.");
95
96        Ok(())
97    }
98
99    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
100        let (sender, receiver) = mpsc::sync_channel(1);
101        self.sender
102            .send(Message::Restore(epoch, device.clone(), sender))
103            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
104
105        if let Ok(record) = receiver.recv() {
106            return record;
107        };
108
109        Err(CheckpointerError::Unknown("Channel error.".to_string()))
110    }
111
112    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
113        self.sender
114            .send(Message::Delete(epoch))
115            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
116
117        Ok(())
118    }
119}
120
121impl<E, B> Drop for AsyncCheckpointer<E, B>
122where
123    B: Backend,
124{
125    fn drop(&mut self) {
126        self.sender
127            .send(Message::End)
128            .expect("Can send the end message to the checkpointer thread.");
129        let handler = self.handler.take();
130
131        if let Some(handler) = handler {
132            handler
133                .join()
134                .expect("The checkpointer thread should stop.");
135        }
136    }
137}