// chunked_wal/wal/wal_record.rs

1use std::fmt;
2use std::fmt::Formatter;
3use std::io;
4use std::io::Cursor;
5use std::io::Read;
6
7use byteorder::BigEndian;
8use byteorder::ReadBytesExt;
9use byteorder::WriteBytesExt;
10use codeq::Encode;
11use codeq::config::CodeqConfig;
12
13use crate::WalTypes;
14use crate::types::Checksum;
15
/// Record type id reserved for checkpoint records.
///
/// For historical reasons and compatibility, the WAL reserves record types
/// `0..=4` for user actions, and `5` for checkpoints.
///
/// Only the checkpoint id is enforced by the record codec:
/// - Encoding an action whose `type_id()` equals this value is rejected.
/// - On decode, a leading `5` is always interpreted as a checkpoint.
///
/// Action ids outside `0..=4` are not rejected.
pub const CHECKPOINT_RECORD_TYPE: u32 = 5;
19
20/// Generic record stored in the Write-Ahead Log (WAL).
21///
22/// The WAL only distinguishes user actions from state-machine checkpoints.
23/// The concrete action and checkpoint payloads are defined by the user of the
24/// WAL.
25#[derive(Clone, PartialEq, Eq)]
26pub enum WALRecord<W>
27where W: WalTypes
28{
29    /// A user-defined command.
30    Action(W::Action),
31
32    /// A state-machine checkpoint persisted by the WAL.
33    Checkpoint(W::Checkpoint),
34}
35
36impl<W> fmt::Debug for WALRecord<W>
37where W: WalTypes
38{
39    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
40        match self {
41            WALRecord::Action(action) => fmt::Debug::fmt(action, f),
42            WALRecord::Checkpoint(checkpoint) => {
43                f.debug_tuple("State").field(checkpoint).finish()
44            }
45        }
46    }
47}
48
49impl<W> fmt::Display for WALRecord<W>
50where
51    W: WalTypes,
52    W::Action: fmt::Display,
53    W::Checkpoint: fmt::Display,
54{
55    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
56        match self {
57            WALRecord::Action(action) => fmt::Display::fmt(action, f),
58            WALRecord::Checkpoint(checkpoint) => {
59                write!(f, "Checkpoint({})", checkpoint)
60            }
61        }
62    }
63}
64
65impl<W> codeq::Encode for WALRecord<W>
66where W: WalTypes
67{
68    fn encode<Wt: io::Write>(&self, mut w: Wt) -> Result<usize, io::Error> {
69        match self {
70            WALRecord::Action(action) => {
71                let type_id = action.type_id().ok_or_else(|| {
72                    io::Error::new(
73                        io::ErrorKind::InvalidInput,
74                        "action encoding does not provide a leading type id",
75                    )
76                })?;
77
78                if type_id == CHECKPOINT_RECORD_TYPE {
79                    return Err(io::Error::new(
80                        io::ErrorKind::InvalidInput,
81                        format!(
82                            "action type id {} conflicts with checkpoint",
83                            CHECKPOINT_RECORD_TYPE
84                        ),
85                    ));
86                }
87
88                action.encode(&mut w)
89            }
90            WALRecord::Checkpoint(checkpoint) => {
91                let mut n = 0;
92                let mut cw = Checksum::new_writer(&mut w);
93
94                cw.write_u32::<BigEndian>(CHECKPOINT_RECORD_TYPE)?;
95                n += 4;
96
97                n += checkpoint.encode(&mut cw)?;
98                n += cw.write_checksum()?;
99
100                Ok(n)
101            }
102        }
103    }
104}
105
106/// Implements decoding for WALRecord.
107///
108/// The wrapper inspects the record type and replays it for the decoder.
109/// Checkpoint records reread the reserved checkpoint type so v1 checksum
110/// verification still covers the type and payload.
111impl<W> codeq::Decode for WALRecord<W>
112where W: WalTypes
113{
114    fn decode<R: io::Read>(mut r: R) -> Result<Self, io::Error> {
115        let mut type_bytes = [0; 4];
116        r.read_exact(&mut type_bytes)?;
117
118        let type_id = u32::from_be_bytes(type_bytes);
119
120        if type_id != CHECKPOINT_RECORD_TYPE {
121            let mut r = Cursor::new(type_bytes).chain(r);
122            let action = W::Action::decode(&mut r)?;
123            let decoded_type_id = action.type_id().ok_or_else(|| {
124                io::Error::new(
125                    io::ErrorKind::InvalidData,
126                    "decoded action does not provide a leading type id",
127                )
128            })?;
129
130            if decoded_type_id != type_id {
131                return Err(io::Error::new(
132                    io::ErrorKind::InvalidData,
133                    format!(
134                        "action type id mismatch: encoded {}, decoded {}",
135                        type_id, decoded_type_id
136                    ),
137                ));
138            }
139
140            return Ok(Self::Action(action));
141        }
142
143        let mut cr = Checksum::new_reader(Cursor::new(type_bytes).chain(r));
144        cr.read_u32::<BigEndian>()?;
145        let rec = Self::Checkpoint(W::Checkpoint::decode(&mut cr)?);
146        cr.verify_checksum(|| "Record::decode()")?;
147
148        Ok(rec)
149    }
150}
151
#[cfg(test)]
mod tests {
    // Unit tests for the WALRecord wrapper codec. `TestAction` is a
    // well-behaved action. `ConflictAction` is deliberately ill-behaved, so
    // the wrapper's own type-id checks can be exercised independently of the
    // action codec.

    use std::fmt;
    use std::io;
    use std::sync::mpsc::SyncSender;

    use codeq::Decode;
    use codeq::Encode;

    use crate::WalTypes;
    use crate::wal::wal_record::CHECKPOINT_RECORD_TYPE;
    use crate::wal::wal_record::WALRecord;

    const TEST_ACTION_TYPE: u32 = 1;

    #[derive(Clone, PartialEq, Eq)]
    struct TestAction(String);

    impl fmt::Debug for TestAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Debug::fmt(&self.0, f)
        }
    }

    impl fmt::Display for TestAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Display::fmt(&self.0, f)
        }
    }

    impl Encode for TestAction {
        fn encode<Wt: io::Write>(&self, mut w: Wt) -> Result<usize, io::Error> {
            let mut n = TEST_ACTION_TYPE.encode(&mut w)?;
            n += self.0.encode(&mut w)?;
            Ok(n)
        }

        fn type_id(&self) -> Option<u32> {
            Some(TEST_ACTION_TYPE)
        }
    }

    impl Decode for TestAction {
        fn decode<R: io::Read>(mut r: R) -> Result<Self, io::Error> {
            let type_id = u32::decode(&mut r)?;
            if type_id != TEST_ACTION_TYPE {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("unexpected action type id {}", type_id),
                ));
            }

            Ok(Self(String::decode(&mut r)?))
        }
    }

    /// Action fixture with two deliberate faults:
    /// - It always claims the reserved checkpoint type id.
    /// - Its decoder accepts any leading type id.
    ///
    /// Used to reach the wrapper's own conflict and mismatch checks.
    #[derive(Clone, PartialEq, Eq)]
    struct ConflictAction(String);

    impl fmt::Debug for ConflictAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Debug::fmt(&self.0, f)
        }
    }

    impl fmt::Display for ConflictAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Display::fmt(&self.0, f)
        }
    }

    impl Encode for ConflictAction {
        fn encode<Wt: io::Write>(&self, mut w: Wt) -> Result<usize, io::Error> {
            let mut n = CHECKPOINT_RECORD_TYPE.encode(&mut w)?;
            n += self.0.encode(&mut w)?;
            Ok(n)
        }

        fn type_id(&self) -> Option<u32> {
            Some(CHECKPOINT_RECORD_TYPE)
        }
    }

    impl Decode for ConflictAction {
        fn decode<R: io::Read>(mut r: R) -> Result<Self, io::Error> {
            // Intentionally not validated: the wrapper must catch mismatches.
            let _type_id = u32::decode(&mut r)?;
            Ok(Self(String::decode(&mut r)?))
        }
    }

    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct TestWal;

    impl WalTypes for TestWal {
        type Action = TestAction;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct NoTypeWal;

    impl WalTypes for NoTypeWal {
        type Action = String;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct ConflictWal;

    impl WalTypes for ConflictWal {
        type Action = ConflictAction;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    fn action(v: &str) -> WALRecord<TestWal> {
        WALRecord::Action(TestAction(v.to_string()))
    }

    fn checkpoint(v: &str) -> WALRecord<TestWal> {
        WALRecord::Checkpoint(v.to_string())
    }

    fn checkpoint_state_bytes() -> Vec<u8> {
        vec![
            0, 0, 0, 5, // checkpoint record type
            0, 0, 0, 5, // checkpoint string len
            115, 116, 97, 116, 101, // checkpoint string: state
            0, 0, 0, 0, 220, 33, 57, 147, // checksum
        ]
    }

    #[test]
    fn test_action_debug_display_and_clone() {
        let rec = action("vote");

        assert_eq!("\"vote\"", format!("{:?}", rec));
        assert_eq!("vote", format!("{}", rec));
        assert_eq!(rec, rec.clone());
    }

    #[test]
    fn test_checkpoint_debug_display_and_clone() {
        let rec = checkpoint("state");

        assert_eq!("State(\"state\")", format!("{:?}", rec));
        assert_eq!("Checkpoint(state)", format!("{}", rec));
        assert_eq!(rec, rec.clone());
    }

    #[test]
    fn test_encode_action_delegates_to_action_codec() -> Result<(), io::Error> {
        let mut got = Vec::new();

        let n = action("vote").encode(&mut got)?;

        assert_eq!(got.len(), n);
        assert_eq!(vec![0, 0, 0, 1, 0, 0, 0, 4, 118, 111, 116, 101], got);
        Ok(())
    }

    #[test]
    fn test_encode_action_requires_type_id() {
        let mut got = Vec::new();
        let rec = WALRecord::<NoTypeWal>::Action("vote".to_string());

        let err = rec.encode(&mut got).unwrap_err();

        assert_eq!(io::ErrorKind::InvalidInput, err.kind());
        assert!(err.to_string().contains("does not provide"));
    }

    #[test]
    fn test_encode_action_rejects_checkpoint_type_id() {
        let mut got = Vec::new();
        let rec =
            WALRecord::<ConflictWal>::Action(ConflictAction("vote".to_string()));

        let err = rec.encode(&mut got).unwrap_err();

        assert_eq!(io::ErrorKind::InvalidInput, err.kind());
        assert!(err.to_string().contains("conflicts with checkpoint"));
        assert!(got.is_empty(), "a rejected action must not write any bytes");
    }

    #[test]
    fn test_encode_checkpoint_adds_type_and_checksum() -> Result<(), io::Error>
    {
        let mut got = Vec::new();

        let n = checkpoint("state").encode(&mut got)?;

        assert_eq!(CHECKPOINT_RECORD_TYPE, 5);
        assert_eq!(got.len(), n);
        assert_eq!(checkpoint_state_bytes(), got);
        Ok(())
    }

    #[test]
    fn test_decode_action_replays_record_type_bytes() -> Result<(), io::Error> {
        let mut bytes = Vec::new();
        action("vote").encode(&mut bytes)?;
        action("log").encode(&mut bytes)?;

        let mut r = &bytes[..];

        assert_eq!(action("vote"), WALRecord::<TestWal>::decode(&mut r)?);
        assert_eq!(action("log"), WALRecord::<TestWal>::decode(&mut r)?);
        assert_eq!(&[] as &[u8], r);
        Ok(())
    }

    #[test]
    fn test_decode_action_requires_type_id() {
        // Leading "type" 4 doubles as the string length prefix for "vote".
        let mut r: &[u8] = &[0, 0, 0, 4, 118, 111, 116, 101];

        let err = WALRecord::<NoTypeWal>::decode(&mut r)
            .expect_err("action without a type id must fail to decode");

        assert_eq!(io::ErrorKind::InvalidData, err.kind());
        assert!(err.to_string().contains("does not provide"));
    }

    #[test]
    fn test_decode_action_rejects_type_id_mismatch() {
        // Encoded type id 2, but ConflictAction always reports 5.
        let mut r: &[u8] = &[0, 0, 0, 2, 0, 0, 0, 4, 118, 111, 116, 101];

        let err = WALRecord::<ConflictWal>::decode(&mut r)
            .expect_err("mismatched action type id must fail");

        assert_eq!(io::ErrorKind::InvalidData, err.kind());
        assert!(err.to_string().contains("encoded 2, decoded 5"));
    }

    #[test]
    fn test_decode_checkpoint_verifies_checksum() -> Result<(), io::Error> {
        let bytes = checkpoint_state_bytes();

        let got = WALRecord::<TestWal>::decode(&mut bytes.as_slice())?;

        assert_eq!(checkpoint("state"), got);
        Ok(())
    }

    #[test]
    fn test_decode_checkpoint_rejects_bad_checksum() {
        let mut bytes = checkpoint_state_bytes();
        *bytes.last_mut().unwrap() ^= 1;

        let err = WALRecord::<TestWal>::decode(&mut bytes.as_slice())
            .expect_err("corrupted checkpoint checksum must fail");

        assert_eq!(io::ErrorKind::InvalidData, err.kind());
        assert!(err.to_string().contains("Record::decode()"));
    }

    #[test]
    fn test_decode_rejects_short_record_type() {
        let err = WALRecord::<TestWal>::decode(&mut [0, 0, 0].as_slice())
            .expect_err("short record type must fail");

        assert_eq!(io::ErrorKind::UnexpectedEof, err.kind());
    }
}