burn_train/checkpoint/
async_checkpoint.rs1use super::{Checkpointer, CheckpointerError};
2use burn_core::{record::Record, tensor::backend::Backend};
3use std::sync::mpsc;
4
5enum Message<R, B: Backend> {
6 Restore(
7 usize,
8 B::Device,
9 mpsc::SyncSender<Result<R, CheckpointerError>>,
10 ),
11 Save(usize, R),
12 Delete(usize),
13 End,
14}
15
16#[derive(new)]
17struct CheckpointerThread<C, R, B: Backend> {
18 checkpointer: C,
19 receiver: mpsc::Receiver<Message<R, B>>,
20}
21
22impl<C, R, B> CheckpointerThread<C, R, B>
23where
24 C: Checkpointer<R, B>,
25 R: Record<B>,
26 B: Backend,
27{
28 fn run(self) {
29 for item in self.receiver.iter() {
30 match item {
31 Message::Restore(epoch, device, callback) => {
32 let record = self.checkpointer.restore(epoch, &device);
33 callback
34 .send(record)
35 .expect("Can send response through callback channel.");
36 }
37 Message::Save(epoch, state) => self
38 .checkpointer
39 .save(epoch, state)
40 .expect("Can save the state."),
41 Message::Delete(epoch) => self
42 .checkpointer
43 .delete(epoch)
44 .expect("Can delete the state."),
45 Message::End => {
46 return;
47 }
48 };
49 }
50 }
51}
52
53pub struct AsyncCheckpointer<Record, B: Backend> {
55 sender: mpsc::SyncSender<Message<Record, B>>,
56 handler: Option<std::thread::JoinHandle<()>>,
57}
58
59impl<R, B> AsyncCheckpointer<R, B>
60where
61 R: Record<B> + 'static,
62 B: Backend,
63{
64 pub fn new<C>(checkpointer: C) -> Self
74 where
75 C: Checkpointer<R, B> + Send + 'static,
76 {
77 let (sender, receiver) = mpsc::sync_channel(0);
79 let thread = CheckpointerThread::new(checkpointer, receiver);
80 let handler = Some(std::thread::spawn(move || thread.run()));
81
82 Self { sender, handler }
83 }
84}
85
86impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
87where
88 R: Record<B> + 'static,
89 B: Backend,
90{
91 fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
92 self.sender
93 .send(Message::Save(epoch, record))
94 .expect("Can send message to checkpointer thread.");
95
96 Ok(())
97 }
98
99 fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
100 let (sender, receiver) = mpsc::sync_channel(1);
101 self.sender
102 .send(Message::Restore(epoch, device.clone(), sender))
103 .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
104
105 if let Ok(record) = receiver.recv() {
106 return record;
107 };
108
109 Err(CheckpointerError::Unknown("Channel error.".to_string()))
110 }
111
112 fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
113 self.sender
114 .send(Message::Delete(epoch))
115 .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
116
117 Ok(())
118 }
119}
120
121impl<E, B> Drop for AsyncCheckpointer<E, B>
122where
123 B: Backend,
124{
125 fn drop(&mut self) {
126 self.sender
127 .send(Message::End)
128 .expect("Can send the end message to the checkpointer thread.");
129 let handler = self.handler.take();
130
131 if let Some(handler) = handler {
132 handler
133 .join()
134 .expect("The checkpointer thread should stop.");
135 }
136 }
137}