1use super::header::{
6 ENCRYPTED_FLAG, HEADER_SIZE, MAX_WAL_PAYLOAD_SIZE, RecordHeader, WAL_FORMAT_VERSION, WAL_MAGIC,
7};
8use crate::error::{Result, WalError};
9use crate::preamble::PREAMBLE_SIZE;
10
11#[derive(Debug, Clone)]
13pub struct WalRecord {
14 pub header: RecordHeader,
15 pub payload: Vec<u8>,
16}
17
18impl WalRecord {
19 #[allow(clippy::too_many_arguments)]
34 pub fn new(
35 record_type: u32,
36 lsn: u64,
37 tenant_id: u64,
38 vshard_id: u32,
39 database_id: u64,
40 payload: Vec<u8>,
41 encryption_key: Option<&crate::crypto::WalEncryptionKey>,
42 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
43 ) -> Result<Self> {
44 if payload.len() > MAX_WAL_PAYLOAD_SIZE {
45 return Err(WalError::PayloadTooLarge {
46 size: payload.len(),
47 max: MAX_WAL_PAYLOAD_SIZE,
48 });
49 }
50
51 let (final_payload, encrypted) = if let Some(key) = encryption_key {
52 let temp_header = RecordHeader {
53 magic: WAL_MAGIC,
54 format_version: WAL_FORMAT_VERSION,
55 record_type,
56 lsn,
57 tenant_id,
58 vshard_id,
59 payload_len: 0,
60 database_id,
61 reserved: [0u8; 8],
62 crc32c: 0,
63 };
64 let header_bytes = temp_header.to_bytes();
65 let aad = build_aad(preamble_bytes, &header_bytes);
68 let ciphertext = key.encrypt_aad(lsn, &aad, &payload)?;
69 (ciphertext, true)
70 } else {
71 (payload, false)
72 };
73
74 let record_type = if encrypted {
75 record_type | ENCRYPTED_FLAG
76 } else {
77 record_type
78 };
79
80 let mut header = RecordHeader {
81 magic: WAL_MAGIC,
82 format_version: WAL_FORMAT_VERSION,
83 record_type,
84 lsn,
85 tenant_id,
86 vshard_id,
87 payload_len: final_payload.len() as u32,
88 database_id,
89 reserved: [0u8; 8],
90 crc32c: 0,
91 };
92
93 header.crc32c = header.compute_checksum(&final_payload);
94
95 Ok(Self {
96 header,
97 payload: final_payload,
98 })
99 }
100
101 pub fn decrypt_payload(
107 &self,
108 epoch: &[u8; 4],
109 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
110 encryption_key: Option<&crate::crypto::WalEncryptionKey>,
111 ) -> Result<Vec<u8>> {
112 if !self.is_encrypted() {
113 return Ok(self.payload.clone());
114 }
115
116 let key = encryption_key.ok_or_else(|| WalError::EncryptionError {
117 detail: "record is encrypted but no decryption key provided".into(),
118 })?;
119
120 let mut aad_header = self.header;
121 aad_header.record_type &= !ENCRYPTED_FLAG;
122 aad_header.payload_len = 0;
123 aad_header.crc32c = 0;
124 let header_bytes = aad_header.to_bytes();
125 let aad = build_aad(preamble_bytes, &header_bytes);
126
127 key.decrypt_aad(epoch, self.header.lsn, &aad, &self.payload)
128 }
129
130 pub fn decrypt_payload_ring(
135 &self,
136 epoch: &[u8; 4],
137 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
138 ring: Option<&crate::crypto::KeyRing>,
139 ) -> Result<Vec<u8>> {
140 if !self.is_encrypted() {
141 return Ok(self.payload.clone());
142 }
143
144 let ring = ring.ok_or_else(|| WalError::EncryptionError {
145 detail: "record is encrypted but no decryption key ring provided".into(),
146 })?;
147
148 let mut aad_header = self.header;
149 aad_header.record_type &= !ENCRYPTED_FLAG;
150 aad_header.payload_len = 0;
151 aad_header.crc32c = 0;
152 let header_bytes = aad_header.to_bytes();
153 let aad = build_aad(preamble_bytes, &header_bytes);
154
155 ring.decrypt_aad(epoch, self.header.lsn, &aad, &self.payload)
156 }
157
158 pub fn is_encrypted(&self) -> bool {
160 self.header.record_type & ENCRYPTED_FLAG != 0
161 }
162
163 pub fn logical_record_type(&self) -> u32 {
165 self.header.record_type & !ENCRYPTED_FLAG
166 }
167
168 pub fn verify_checksum(&self) -> Result<()> {
170 let expected = self.header.crc32c;
171 let actual = self.header.compute_checksum(&self.payload);
172 if expected != actual {
173 return Err(WalError::ChecksumMismatch {
174 lsn: self.header.lsn,
175 expected,
176 actual,
177 });
178 }
179 Ok(())
180 }
181
182 pub fn wire_size(&self) -> usize {
184 HEADER_SIZE + self.payload.len()
185 }
186}
187
188pub(crate) fn build_aad(
193 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
194 header_bytes: &[u8; HEADER_SIZE],
195) -> Vec<u8> {
196 match preamble_bytes {
197 Some(p) => {
198 let mut aad = Vec::with_capacity(PREAMBLE_SIZE + HEADER_SIZE);
199 aad.extend_from_slice(p);
200 aad.extend_from_slice(header_bytes);
201 aad
202 }
203 None => header_bytes.to_vec(),
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::super::types::RecordType;
    use super::*;

    /// Builds an unencrypted record with zeroed tenant, vshard, and database ids.
    fn plain_record(record_type: u32, lsn: u64, payload: &[u8]) -> WalRecord {
        WalRecord::new(record_type, lsn, 0, 0, 0, payload.to_vec(), None, None)
            .expect("plaintext record within the size limit must build")
    }

    #[test]
    fn checksum_roundtrip() {
        let record = plain_record(RecordType::Put as u32, 1, b"hello nodedb");
        record.verify_checksum().unwrap();
    }

    #[test]
    fn checksum_detects_corruption() {
        let mut record = plain_record(RecordType::Put as u32, 1, b"hello nodedb");
        record.payload[0] ^= 0xFF;
        assert!(matches!(
            record.verify_checksum(),
            Err(WalError::ChecksumMismatch { .. })
        ));
    }

    #[test]
    fn payload_too_large_rejected() {
        let big_payload = vec![0u8; MAX_WAL_PAYLOAD_SIZE + 1];
        assert!(matches!(
            WalRecord::new(RecordType::Put as u32, 1, 0, 0, 0, big_payload, None, None),
            Err(WalError::PayloadTooLarge { .. })
        ));
    }

    #[test]
    fn empty_payload_roundtrip() {
        let record = plain_record(RecordType::Put as u32, 7, b"");
        record.verify_checksum().unwrap();
        assert_eq!(record.header.payload_len, 0);
        assert_eq!(record.wire_size(), HEADER_SIZE);
    }

    #[test]
    fn plaintext_record_accessors() {
        let payload = b"hello nodedb";
        let record = plain_record(RecordType::Put as u32, 3, payload);
        assert!(!record.is_encrypted());
        assert_eq!(record.logical_record_type(), RecordType::Put as u32);
        assert_eq!(record.header.payload_len as usize, payload.len());
        assert_eq!(record.wire_size(), HEADER_SIZE + payload.len());
        // Unencrypted records pass straight through both decrypt entry points.
        assert_eq!(
            record.decrypt_payload(&[0u8; 4], None, None).unwrap(),
            payload.to_vec()
        );
        assert_eq!(
            record.decrypt_payload_ring(&[0u8; 4], None, None).unwrap(),
            payload.to_vec()
        );
    }

    #[test]
    fn flagged_record_without_key_reports_missing_key() {
        // Without a key, `new` stores the record type as given, so the flag is preserved.
        let record = plain_record(RecordType::Put as u32 | ENCRYPTED_FLAG, 5, b"opaque");
        assert!(record.is_encrypted());
        assert_eq!(record.logical_record_type(), RecordType::Put as u32);
        assert!(matches!(
            record.decrypt_payload(&[0u8; 4], None, None),
            Err(WalError::EncryptionError { .. })
        ));
        assert!(matches!(
            record.decrypt_payload_ring(&[0u8; 4], None, None),
            Err(WalError::EncryptionError { .. })
        ));
    }

    #[test]
    fn anchor_payload_in_record() {
        use super::super::anchor::LsnMsAnchorPayload;
        let anchor = LsnMsAnchorPayload::new(42, 1_700_000_000_000);
        let record = plain_record(RecordType::LsnMsAnchor as u32, 42, &anchor.to_bytes());
        record.verify_checksum().unwrap();
        assert_eq!(record.logical_record_type(), RecordType::LsnMsAnchor as u32);
        let decoded = LsnMsAnchorPayload::from_bytes(&record.payload).unwrap();
        assert_eq!(decoded, anchor);
    }
}