burn_train/checkpoint/
async_checkpoint.rs1use super::{Checkpointer, CheckpointerError};
2use crate::Interrupter;
3use burn_core::{record::Record, tensor::backend::Backend};
4use std::sync::mpsc;
5
6enum Message<R, B: Backend> {
7 Restore(
8 usize,
9 B::Device,
10 mpsc::SyncSender<Result<R, CheckpointerError>>,
11 Option<Interrupter>,
12 ),
13 Save(usize, R, Option<Interrupter>),
14 Delete(usize, Option<Interrupter>),
15 End,
16}
17
18#[derive(new)]
19struct CheckpointerThread<C, R, B: Backend> {
20 checkpointer: C,
21 receiver: mpsc::Receiver<Message<R, B>>,
22}
23
24impl<C, R, B> CheckpointerThread<C, R, B>
25where
26 C: Checkpointer<R, B>,
27 R: Record<B>,
28 B: Backend,
29{
30 fn run(self) {
31 for item in self.receiver.iter() {
32 match item {
33 Message::Restore(epoch, device, callback, interrupter) => {
34 let record = self.checkpointer.restore(epoch, &device);
35 callback.send(record).unwrap_or_else(|err| {
36 interrupter.map_or_else(
37 || {
38 panic!(
39 "Error when sending response through callback channel: {err}"
40 )
41 },
42 |int| int.stop(Some(&err.to_string())),
43 )
44 });
45 }
46 Message::Save(epoch, state, interrupter) => {
47 self.checkpointer.save(epoch, state).unwrap_or_else(|err| {
48 interrupter.map_or_else(
49 || panic!("Error when saving the state: {err}"),
50 |int| int.stop(Some(&err.to_string())),
51 )
52 });
53 }
54 Message::Delete(epoch, interrupter) => {
55 self.checkpointer.delete(epoch).unwrap_or_else(|err| {
56 interrupter.map_or_else(
57 || panic!("Error when deleting the state: {err}"),
58 |int| int.stop(Some(&err.to_string())),
59 )
60 });
61 }
62
63 Message::End => {
64 return;
65 }
66 };
67 }
68 }
69}
70
71pub struct AsyncCheckpointer<Record, B: Backend> {
73 sender: mpsc::SyncSender<Message<Record, B>>,
74 handler: Option<std::thread::JoinHandle<()>>,
75 interrupter: Option<Interrupter>,
76}
77
78impl<R, B> AsyncCheckpointer<R, B>
79where
80 R: Record<B> + 'static,
81 B: Backend,
82{
83 pub fn new<C>(checkpointer: C) -> Self
93 where
94 C: Checkpointer<R, B> + Send + 'static,
95 {
96 let (sender, receiver) = mpsc::sync_channel(0);
98 let thread = CheckpointerThread::new(checkpointer, receiver);
99 let handler = Some(std::thread::spawn(move || thread.run()));
100
101 Self {
102 sender,
103 handler,
104 interrupter: None,
105 }
106 }
107
108 pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
110 self.interrupter = Some(interrupter);
111 self
112 }
113}
114
115impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
116where
117 R: Record<B> + 'static,
118 B: Backend,
119{
120 fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
121 self.sender
122 .send(Message::Save(epoch, record, self.interrupter.clone()))
123 .expect("Can send message to checkpointer thread.");
124
125 Ok(())
126 }
127
128 fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
129 let (sender, receiver) = mpsc::sync_channel(1);
130 self.sender
131 .send(Message::Restore(
132 epoch,
133 device.clone(),
134 sender,
135 self.interrupter.clone(),
136 ))
137 .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
138
139 if let Ok(record) = receiver.recv() {
140 return record;
141 };
142
143 Err(CheckpointerError::Unknown("Channel error.".to_string()))
144 }
145
146 fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
147 self.sender
148 .send(Message::Delete(epoch, self.interrupter.clone()))
149 .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;
150
151 Ok(())
152 }
153}
154
155impl<E, B> Drop for AsyncCheckpointer<E, B>
156where
157 B: Backend,
158{
159 fn drop(&mut self) {
160 self.sender
161 .send(Message::End)
162 .expect("Can send the end message to the checkpointer thread.");
163 let handler = self.handler.take();
164
165 if let Some(handler) = handler {
166 handler
167 .join()
168 .expect("The checkpointer thread should stop.");
169 }
170 }
171}