1use std::fmt;
2use std::fmt::Formatter;
3use std::io;
4use std::io::Cursor;
5use std::io::Read;
6
7use byteorder::BigEndian;
8use byteorder::ReadBytesExt;
9use byteorder::WriteBytesExt;
10use codeq::Encode;
11use codeq::config::CodeqConfig;
12
13use crate::WalTypes;
14use crate::types::Checksum;
15
/// Record type id reserved for checkpoint records.
///
/// Every encoded record begins with a big-endian `u32` type id. This value
/// (`00 00 00 05` on disk) marks a checkpoint, so an action must never report
/// it as its own type id.
pub const CHECKPOINT_RECORD_TYPE: u32 = 0x0000_0005;
19
20#[derive(Clone, PartialEq, Eq)]
26pub enum WALRecord<W>
27where W: WalTypes
28{
29 Action(W::Action),
31
32 Checkpoint(W::Checkpoint),
34}
35
36impl<W> fmt::Debug for WALRecord<W>
37where W: WalTypes
38{
39 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
40 match self {
41 WALRecord::Action(action) => fmt::Debug::fmt(action, f),
42 WALRecord::Checkpoint(checkpoint) => {
43 f.debug_tuple("State").field(checkpoint).finish()
44 }
45 }
46 }
47}
48
49impl<W> fmt::Display for WALRecord<W>
50where
51 W: WalTypes,
52 W::Action: fmt::Display,
53 W::Checkpoint: fmt::Display,
54{
55 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
56 match self {
57 WALRecord::Action(action) => fmt::Display::fmt(action, f),
58 WALRecord::Checkpoint(checkpoint) => {
59 write!(f, "Checkpoint({})", checkpoint)
60 }
61 }
62 }
63}
64
65impl<W> codeq::Encode for WALRecord<W>
66where W: WalTypes
67{
68 fn encode<Wt: io::Write>(&self, mut w: Wt) -> Result<usize, io::Error> {
69 match self {
70 WALRecord::Action(action) => {
71 let type_id = action.type_id().ok_or_else(|| {
72 io::Error::new(
73 io::ErrorKind::InvalidInput,
74 "action encoding does not provide a leading type id",
75 )
76 })?;
77
78 if type_id == CHECKPOINT_RECORD_TYPE {
79 return Err(io::Error::new(
80 io::ErrorKind::InvalidInput,
81 format!(
82 "action type id {} conflicts with checkpoint",
83 CHECKPOINT_RECORD_TYPE
84 ),
85 ));
86 }
87
88 action.encode(&mut w)
89 }
90 WALRecord::Checkpoint(checkpoint) => {
91 let mut n = 0;
92 let mut cw = Checksum::new_writer(&mut w);
93
94 cw.write_u32::<BigEndian>(CHECKPOINT_RECORD_TYPE)?;
95 n += 4;
96
97 n += checkpoint.encode(&mut cw)?;
98 n += cw.write_checksum()?;
99
100 Ok(n)
101 }
102 }
103 }
104}
105
106impl<W> codeq::Decode for WALRecord<W>
112where W: WalTypes
113{
114 fn decode<R: io::Read>(mut r: R) -> Result<Self, io::Error> {
115 let mut type_bytes = [0; 4];
116 r.read_exact(&mut type_bytes)?;
117
118 let type_id = u32::from_be_bytes(type_bytes);
119
120 if type_id != CHECKPOINT_RECORD_TYPE {
121 let mut r = Cursor::new(type_bytes).chain(r);
122 let action = W::Action::decode(&mut r)?;
123 let decoded_type_id = action.type_id().ok_or_else(|| {
124 io::Error::new(
125 io::ErrorKind::InvalidData,
126 "decoded action does not provide a leading type id",
127 )
128 })?;
129
130 if decoded_type_id != type_id {
131 return Err(io::Error::new(
132 io::ErrorKind::InvalidData,
133 format!(
134 "action type id mismatch: encoded {}, decoded {}",
135 type_id, decoded_type_id
136 ),
137 ));
138 }
139
140 return Ok(Self::Action(action));
141 }
142
143 let mut cr = Checksum::new_reader(Cursor::new(type_bytes).chain(r));
144 cr.read_u32::<BigEndian>()?;
145 let rec = Self::Checkpoint(W::Checkpoint::decode(&mut cr)?);
146 cr.verify_checksum(|| "Record::decode()")?;
147
148 Ok(rec)
149 }
150}
151
#[cfg(test)]
mod tests {
    use std::fmt;
    use std::io;
    use std::sync::mpsc::SyncSender;

    use codeq::Decode;
    use codeq::Encode;

    use crate::WalTypes;
    use crate::wal::wal_record::CHECKPOINT_RECORD_TYPE;
    use crate::wal::wal_record::WALRecord;

    const TEST_ACTION_TYPE: u32 = 1;

    /// Action encoded as `TEST_ACTION_TYPE` followed by a string payload.
    #[derive(Clone, PartialEq, Eq)]
    struct TestAction(String);

    impl fmt::Debug for TestAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Debug::fmt(&self.0, f)
        }
    }

    impl fmt::Display for TestAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            fmt::Display::fmt(&self.0, f)
        }
    }

    impl Encode for TestAction {
        fn encode<Wt: io::Write>(&self, mut w: Wt) -> Result<usize, io::Error> {
            let mut n = TEST_ACTION_TYPE.encode(&mut w)?;
            n += self.0.encode(&mut w)?;
            Ok(n)
        }

        fn type_id(&self) -> Option<u32> {
            Some(TEST_ACTION_TYPE)
        }
    }

    impl Decode for TestAction {
        fn decode<R: io::Read>(mut r: R) -> Result<Self, io::Error> {
            let type_id = u32::decode(&mut r)?;
            if type_id != TEST_ACTION_TYPE {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("unexpected action type id {}", type_id),
                ));
            }

            Ok(Self(String::decode(&mut r)?))
        }
    }

    /// Action that reports the reserved checkpoint type id. Used to check that
    /// `WALRecord::encode` refuses to write it.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct ConflictAction;

    impl fmt::Display for ConflictAction {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("conflict")
        }
    }

    impl Encode for ConflictAction {
        fn encode<Wt: io::Write>(&self, w: Wt) -> Result<usize, io::Error> {
            CHECKPOINT_RECORD_TYPE.encode(w)
        }

        fn type_id(&self) -> Option<u32> {
            Some(CHECKPOINT_RECORD_TYPE)
        }
    }

    impl Decode for ConflictAction {
        fn decode<R: io::Read>(r: R) -> Result<Self, io::Error> {
            u32::decode(r)?;
            Ok(Self)
        }
    }

    /// WAL types with a well-behaved, type-tagged action.
    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct TestWal;

    impl WalTypes for TestWal {
        type Action = TestAction;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    /// WAL types whose action (`String`) does not report a type id.
    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct NoTypeWal;

    impl WalTypes for NoTypeWal {
        type Action = String;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    /// WAL types whose action claims the reserved checkpoint type id.
    #[derive(Debug, Default, Clone, PartialEq, Eq)]
    struct ConflictWal;

    impl WalTypes for ConflictWal {
        type Action = ConflictAction;
        type Checkpoint = String;
        type Callback = SyncSender<Result<(), io::Error>>;
    }

    fn action(v: &str) -> WALRecord<TestWal> {
        WALRecord::Action(TestAction(v.to_string()))
    }

    fn checkpoint(v: &str) -> WALRecord<TestWal> {
        WALRecord::Checkpoint(v.to_string())
    }

    /// Expected encoding of `checkpoint("state")`.
    fn checkpoint_state_bytes() -> Vec<u8> {
        vec![
            0, 0, 0, 5, // record type id: CHECKPOINT_RECORD_TYPE
            0, 0, 0, 5, // string length prefix
            115, 116, 97, 116, 101, // "state"
            0, 0, 0, 0, 220, 33, 57, 147, // checksum trailer
        ]
    }

    #[test]
    fn test_action_debug_display_and_clone() {
        let rec = action("vote");

        assert_eq!("\"vote\"", format!("{:?}", rec));
        assert_eq!("vote", format!("{}", rec));
        assert_eq!(rec, rec.clone());
    }

    #[test]
    fn test_checkpoint_debug_display_and_clone() {
        let rec = checkpoint("state");

        assert_eq!("State(\"state\")", format!("{:?}", rec));
        assert_eq!("Checkpoint(state)", format!("{}", rec));
        assert_eq!(rec, rec.clone());
    }

    #[test]
    fn test_encode_action_delegates_to_action_codec() -> Result<(), io::Error> {
        let mut got = Vec::new();

        let n = action("vote").encode(&mut got)?;

        assert_eq!(got.len(), n);
        assert_eq!(vec![0, 0, 0, 1, 0, 0, 0, 4, 118, 111, 116, 101], got);
        Ok(())
    }

    #[test]
    fn test_encode_action_requires_type_id() {
        let mut got = Vec::new();
        let rec = WALRecord::<NoTypeWal>::Action("vote".to_string());

        let err = rec.encode(&mut got).unwrap_err();

        assert_eq!(io::ErrorKind::InvalidInput, err.kind());
        assert!(err.to_string().contains("does not provide"));
    }

    #[test]
    fn test_encode_action_rejects_checkpoint_type_id() {
        let mut got = Vec::new();
        let rec = WALRecord::<ConflictWal>::Action(ConflictAction);

        let err = rec.encode(&mut got).unwrap_err();

        assert_eq!(io::ErrorKind::InvalidInput, err.kind());
        assert!(err.to_string().contains("conflicts with checkpoint"));
        assert!(got.is_empty(), "nothing must be written for a rejected action");
    }

    #[test]
    fn test_encode_checkpoint_adds_type_and_checksum() -> Result<(), io::Error>
    {
        let mut got = Vec::new();

        let n = checkpoint("state").encode(&mut got)?;

        assert_eq!(CHECKPOINT_RECORD_TYPE, 5);
        assert_eq!(got.len(), n);
        assert_eq!(checkpoint_state_bytes(), got);
        Ok(())
    }

    #[test]
    fn test_decode_action_replays_record_type_bytes() -> Result<(), io::Error> {
        let mut bytes = Vec::new();
        action("vote").encode(&mut bytes)?;
        action("log").encode(&mut bytes)?;

        let mut r = &bytes[..];

        assert_eq!(action("vote"), WALRecord::<TestWal>::decode(&mut r)?);
        assert_eq!(action("log"), WALRecord::<TestWal>::decode(&mut r)?);
        assert_eq!(&[] as &[u8], r);
        Ok(())
    }

    #[test]
    fn test_decode_action_requires_type_id() {
        // The leading u32 (1) is not the checkpoint id, so the header is
        // replayed into `String::decode`. That decoder reads it as a length
        // prefix of 1, followed by the single byte "x".
        let bytes = [0, 0, 0, 1, b'x'];

        let err = WALRecord::<NoTypeWal>::decode(&mut bytes.as_slice())
            .expect_err("decoded action without a type id must fail");

        assert_eq!(io::ErrorKind::InvalidData, err.kind());
        assert!(err.to_string().contains("does not provide"));
    }

    #[test]
    fn test_decode_checkpoint_verifies_checksum() -> Result<(), io::Error> {
        let bytes = checkpoint_state_bytes();

        let got = WALRecord::<TestWal>::decode(&mut bytes.as_slice())?;

        assert_eq!(checkpoint("state"), got);
        Ok(())
    }

    #[test]
    fn test_decode_checkpoint_rejects_bad_checksum() {
        let mut bytes = checkpoint_state_bytes();
        *bytes.last_mut().unwrap() ^= 1;

        let err = WALRecord::<TestWal>::decode(&mut bytes.as_slice())
            .expect_err("corrupted checkpoint checksum must fail");

        assert_eq!(io::ErrorKind::InvalidData, err.kind());
        assert!(err.to_string().contains("Record::decode()"));
    }

    #[test]
    fn test_decode_rejects_short_record_type() {
        let err = WALRecord::<TestWal>::decode(&mut [0, 0, 0].as_slice())
            .expect_err("short record type must fail");

        assert_eq!(io::ErrorKind::UnexpectedEof, err.kind());
    }
}