burn_train/checkpoint/
async_checkpoint.rs

1use super::{Checkpointer, CheckpointerError};
2use crate::Interrupter;
3use burn_core::{record::Record, tensor::backend::Backend};
4use std::sync::mpsc;
5
6enum Message<R, B: Backend> {
7    Restore(
8        usize,
9        B::Device,
10        mpsc::SyncSender<Result<R, CheckpointerError>>,
11        Option<Interrupter>,
12    ),
13    Save(usize, R, Option<Interrupter>),
14    Delete(usize, Option<Interrupter>),
15    End,
16}
17
18#[derive(new)]
19struct CheckpointerThread<C, R, B: Backend> {
20    checkpointer: C,
21    receiver: mpsc::Receiver<Message<R, B>>,
22}
23
24impl<C, R, B> CheckpointerThread<C, R, B>
25where
26    C: Checkpointer<R, B>,
27    R: Record<B>,
28    B: Backend,
29{
30    fn run(self) {
31        for item in self.receiver.iter() {
32            match item {
33                Message::Restore(epoch, device, callback, interrupter) => {
34                    let record = self.checkpointer.restore(epoch, &device);
35                    callback.send(record).unwrap_or_else(|err| {
36                        interrupter.map_or_else(
37                            || {
38                                panic!(
39                                    "Error when sending response through callback channel: {err}"
40                                )
41                            },
42                            |int| int.stop(Some(&err.to_string())),
43                        )
44                    });
45                }
46                Message::Save(epoch, state, interrupter) => {
47                    self.checkpointer.save(epoch, state).unwrap_or_else(|err| {
48                        interrupter.map_or_else(
49                            || panic!("Error when saving the state: {err}"),
50                            |int| int.stop(Some(&err.to_string())),
51                        )
52                    });
53                }
54                Message::Delete(epoch, interrupter) => {
55                    self.checkpointer.delete(epoch).unwrap_or_else(|err| {
56                        interrupter.map_or_else(
57                            || panic!("Error when deleting the state: {err}"),
58                            |int| int.stop(Some(&err.to_string())),
59                        )
60                    });
61                }
62
63                Message::End => {
64                    return;
65                }
66            };
67        }
68    }
69}
70
71/// Async checkpointer.
72pub struct AsyncCheckpointer<Record, B: Backend> {
73    sender: mpsc::SyncSender<Message<Record, B>>,
74    handler: Option<std::thread::JoinHandle<()>>,
75    interrupter: Option<Interrupter>,
76}
77
78impl<R, B> AsyncCheckpointer<R, B>
79where
80    R: Record<B> + 'static,
81    B: Backend,
82{
83    /// Create a new async checkpointer.
84    ///
85    /// # Arguments
86    ///
87    /// * `checkpointer` - The checkpointer.
88    ///
89    /// # Returns
90    ///
91    /// The async checkpointer.
92    pub fn new<C>(checkpointer: C) -> Self
93    where
94        C: Checkpointer<R, B> + Send + 'static,
95    {
96        // Only on checkpoint can be done in advance.
97        let (sender, receiver) = mpsc::sync_channel(0);
98        let thread = CheckpointerThread::new(checkpointer, receiver);
99        let handler = Some(std::thread::spawn(move || thread.run()));
100
101        Self {
102            sender,
103            handler,
104            interrupter: None,
105        }
106    }
107
108    /// Assign a handle used to interrupt training in case of checkpointing error.
109    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
110        self.interrupter = Some(interrupter);
111        self
112    }
113}
114
115impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
116where
117    R: Record<B> + 'static,
118    B: Backend,
119{
120    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
121        self.sender
122            .send(Message::Save(epoch, record, self.interrupter.clone()))
123            .expect("Can send message to checkpointer thread.");
124
125        Ok(())
126    }
127
128    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
129        let (sender, receiver) = mpsc::sync_channel(1);
130        self.sender
131            .send(Message::Restore(
132                epoch,
133                device.clone(),
134                sender,
135                self.interrupter.clone(),
136            ))
137            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
138
139        if let Ok(record) = receiver.recv() {
140            return record;
141        };
142
143        Err(CheckpointerError::Unknown("Channel error.".to_string()))
144    }
145
146    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
147        self.sender
148            .send(Message::Delete(epoch, self.interrupter.clone()))
149            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
150
151        Ok(())
152    }
153}
154
155impl<E, B> Drop for AsyncCheckpointer<E, B>
156where
157    B: Backend,
158{
159    fn drop(&mut self) {
160        self.sender
161            .send(Message::End)
162            .expect("Can send the end message to the checkpointer thread.");
163        let handler = self.handler.take();
164
165        if let Some(handler) = handler {
166            handler
167                .join()
168                .expect("The checkpointer thread should stop.");
169        }
170    }
171}