1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
use super::{Checkpointer, CheckpointerError};
use burn_core::{record::Record, tensor::backend::Backend};
use std::sync::mpsc;

/// Requests sent from an `AsyncCheckpointer` to its checkpointer thread.
enum Message<R, B: Backend> {
    /// Restore the record of an epoch and send the outcome back to the requester.
    ///
    /// Fields, in order: the epoch to restore, the device passed to the wrapped
    /// checkpointer's `restore`, and the sender on which the result is returned.
    Restore(
        usize,
        B::Device,
        mpsc::SyncSender<Result<R, CheckpointerError>>,
    ),
    /// Save a record: the epoch followed by the record itself.
    Save(usize, R),
    /// Delete the checkpoint of the given epoch.
    Delete(usize),
    /// Make the checkpointer thread leave its message loop.
    End,
}

/// Worker state moved into the checkpointer thread: the wrapped checkpointer
/// and the channel end from which requests are read.
///
/// The `new` constructor is generated by `#[derive(new)]` and takes the fields
/// in declaration order.
#[derive(new)]
struct CheckpointerThread<C, R, B: Backend> {
    /// Checkpointer that actually performs the save, restore and delete calls.
    checkpointer: C,
    /// Receiving end of the channel carrying the requests to process.
    receiver: mpsc::Receiver<Message<R, B>>,
}

impl<C, R, B> CheckpointerThread<C, R, B>
where
    C: Checkpointer<R, B>,
    R: Record<B>,
    B: Backend,
{
    /// Process requests until `Message::End` is received or the channel is closed.
    ///
    /// Each request is forwarded to the wrapped checkpointer. The outcome of a
    /// restore is sent back through the callback channel carried by the request.
    ///
    /// # Panics
    ///
    /// Panics when a save or a delete fails, or when the restore outcome cannot
    /// be sent back to the requester.
    fn run(self) {
        // `recv` returns an error once every sender is gone, which ends the loop.
        while let Ok(message) = self.receiver.recv() {
            match message {
                Message::End => break,
                Message::Save(epoch, state) => {
                    self.checkpointer
                        .save(epoch, state)
                        .expect("Can save the state.");
                }
                Message::Delete(epoch) => {
                    self.checkpointer
                        .delete(epoch)
                        .expect("Can delete the state.");
                }
                Message::Restore(epoch, device, callback) => {
                    let outcome = self.checkpointer.restore(epoch, &device);
                    callback
                        .send(outcome)
                        .expect("Can send response through callback channel.");
                }
            }
        }
    }
}

/// Async checkpointer.
///
/// Wraps another checkpointer and hands every operation over to a dedicated
/// thread through a channel.
pub struct AsyncCheckpointer<R, B>
where
    B: Backend,
{
    /// Sending end of the channel used to pass requests to the checkpointer thread.
    sender: mpsc::SyncSender<Message<R, B>>,
    /// Join handle of the checkpointer thread, taken when the value is dropped.
    handler: Option<std::thread::JoinHandle<()>>,
}

impl<R, B> AsyncCheckpointer<R, B>
where
    R: Record<B> + 'static,
    B: Backend,
{
    /// Create a new async checkpointer.
    ///
    /// Spawns a thread that takes ownership of `checkpointer` and executes the
    /// requests it receives.
    ///
    /// # Arguments
    ///
    /// * `checkpointer` - The checkpointer that performs the actual work.
    ///
    /// # Returns
    ///
    /// The async checkpointer.
    pub fn new<C>(checkpointer: C) -> Self
    where
        C: Checkpointer<R, B> + Send + 'static,
    {
        // Zero-capacity channel: nothing is queued, so a request is only handed
        // over once the thread is ready to receive it. At most one checkpoint
        // operation is therefore in progress while the caller keeps running.
        let (sender, receiver) = mpsc::sync_channel(0);
        let worker = CheckpointerThread::new(checkpointer, receiver);
        let join_handle = std::thread::spawn(move || worker.run());

        Self {
            sender,
            handler: Some(join_handle),
        }
    }
}

impl<R, B> Checkpointer<R, B> for AsyncCheckpointer<R, B>
where
    R: Record<B> + 'static,
    B: Backend,
{
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
        self.sender
            .send(Message::Save(epoch, record))
            .expect("Can send message to checkpointer thread.");

        Ok(())
    }

    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
        let (sender, receiver) = mpsc::sync_channel(1);
        self.sender
            .send(Message::Restore(epoch, device.clone(), sender))
            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;

        if let Ok(record) = receiver.recv() {
            return record;
        };

        Err(CheckpointerError::Unknown("Channel error.".to_string()))
    }

    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
        self.sender
            .send(Message::Delete(epoch))
            .map_err(|e| CheckpointerError::Unknown(e.to_string()))?;

        Ok(())
    }
}

impl<E, B> Drop for AsyncCheckpointer<E, B>
where
    B: Backend,
{
    /// Ask the checkpointer thread to stop and wait for it to finish.
    ///
    /// # Panics
    ///
    /// Panics when the end message cannot be sent or when the checkpointer
    /// thread itself panicked, unless the current thread is already unwinding
    /// from another panic.
    fn drop(&mut self) {
        let end_sent = self.sender.send(Message::End);

        // Panicking while the thread is already unwinding aborts the process,
        // so failures are only escalated when no panic is in flight.
        if !std::thread::panicking() {
            end_sent.expect("Can send the end message to the checkpointer thread.");
        }

        // A failed send means the receiver is gone, i.e. the thread has left its
        // message loop, so joining below cannot wait on a thread that is still
        // expecting requests.
        if let Some(handler) = self.handler.take() {
            let joined = handler.join();

            if !std::thread::panicking() {
                joined.expect("The checkpointer thread should stop.");
            }
        }
    }
}