1use super::header::{
6 ENCRYPTED_FLAG, HEADER_SIZE, MAX_WAL_PAYLOAD_SIZE, RecordHeader, WAL_FORMAT_VERSION, WAL_MAGIC,
7};
8use crate::error::{Result, WalError};
9use crate::preamble::PREAMBLE_SIZE;
10
11#[derive(Debug, Clone)]
13pub struct WalRecord {
14 pub header: RecordHeader,
15 pub payload: Vec<u8>,
16}
17
18impl WalRecord {
19 pub fn new(
30 record_type: u32,
31 lsn: u64,
32 tenant_id: u64,
33 vshard_id: u32,
34 payload: Vec<u8>,
35 encryption_key: Option<&crate::crypto::WalEncryptionKey>,
36 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
37 ) -> Result<Self> {
38 if payload.len() > MAX_WAL_PAYLOAD_SIZE {
39 return Err(WalError::PayloadTooLarge {
40 size: payload.len(),
41 max: MAX_WAL_PAYLOAD_SIZE,
42 });
43 }
44
45 let (final_payload, encrypted) = if let Some(key) = encryption_key {
46 let temp_header = RecordHeader {
47 magic: WAL_MAGIC,
48 format_version: WAL_FORMAT_VERSION,
49 record_type,
50 lsn,
51 tenant_id,
52 vshard_id,
53 payload_len: 0,
54 reserved: [0u8; 16],
55 crc32c: 0,
56 };
57 let header_bytes = temp_header.to_bytes();
58 let aad = build_aad(preamble_bytes, &header_bytes);
61 let ciphertext = key.encrypt_aad(lsn, &aad, &payload)?;
62 (ciphertext, true)
63 } else {
64 (payload, false)
65 };
66
67 let record_type = if encrypted {
68 record_type | ENCRYPTED_FLAG
69 } else {
70 record_type
71 };
72
73 let mut header = RecordHeader {
74 magic: WAL_MAGIC,
75 format_version: WAL_FORMAT_VERSION,
76 record_type,
77 lsn,
78 tenant_id,
79 vshard_id,
80 payload_len: final_payload.len() as u32,
81 reserved: [0u8; 16],
82 crc32c: 0,
83 };
84
85 header.crc32c = header.compute_checksum(&final_payload);
86
87 Ok(Self {
88 header,
89 payload: final_payload,
90 })
91 }
92
93 pub fn decrypt_payload(
99 &self,
100 epoch: &[u8; 4],
101 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
102 encryption_key: Option<&crate::crypto::WalEncryptionKey>,
103 ) -> Result<Vec<u8>> {
104 if !self.is_encrypted() {
105 return Ok(self.payload.clone());
106 }
107
108 let key = encryption_key.ok_or_else(|| WalError::EncryptionError {
109 detail: "record is encrypted but no decryption key provided".into(),
110 })?;
111
112 let mut aad_header = self.header;
113 aad_header.record_type &= !ENCRYPTED_FLAG;
114 aad_header.payload_len = 0;
115 aad_header.crc32c = 0;
116 let header_bytes = aad_header.to_bytes();
117 let aad = build_aad(preamble_bytes, &header_bytes);
118
119 key.decrypt_aad(epoch, self.header.lsn, &aad, &self.payload)
120 }
121
122 pub fn decrypt_payload_ring(
127 &self,
128 epoch: &[u8; 4],
129 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
130 ring: Option<&crate::crypto::KeyRing>,
131 ) -> Result<Vec<u8>> {
132 if !self.is_encrypted() {
133 return Ok(self.payload.clone());
134 }
135
136 let ring = ring.ok_or_else(|| WalError::EncryptionError {
137 detail: "record is encrypted but no decryption key ring provided".into(),
138 })?;
139
140 let mut aad_header = self.header;
141 aad_header.record_type &= !ENCRYPTED_FLAG;
142 aad_header.payload_len = 0;
143 aad_header.crc32c = 0;
144 let header_bytes = aad_header.to_bytes();
145 let aad = build_aad(preamble_bytes, &header_bytes);
146
147 ring.decrypt_aad(epoch, self.header.lsn, &aad, &self.payload)
148 }
149
150 pub fn is_encrypted(&self) -> bool {
152 self.header.record_type & ENCRYPTED_FLAG != 0
153 }
154
155 pub fn logical_record_type(&self) -> u32 {
157 self.header.record_type & !ENCRYPTED_FLAG
158 }
159
160 pub fn verify_checksum(&self) -> Result<()> {
162 let expected = self.header.crc32c;
163 let actual = self.header.compute_checksum(&self.payload);
164 if expected != actual {
165 return Err(WalError::ChecksumMismatch {
166 lsn: self.header.lsn,
167 expected,
168 actual,
169 });
170 }
171 Ok(())
172 }
173
174 pub fn wire_size(&self) -> usize {
176 HEADER_SIZE + self.payload.len()
177 }
178}
179
180pub(crate) fn build_aad(
185 preamble_bytes: Option<&[u8; PREAMBLE_SIZE]>,
186 header_bytes: &[u8; HEADER_SIZE],
187) -> Vec<u8> {
188 match preamble_bytes {
189 Some(p) => {
190 let mut aad = Vec::with_capacity(PREAMBLE_SIZE + HEADER_SIZE);
191 aad.extend_from_slice(p);
192 aad.extend_from_slice(header_bytes);
193 aad
194 }
195 None => header_bytes.to_vec(),
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::super::types::RecordType;
    use super::*;

    /// Builds an unencrypted record for tenant 0 / vshard 0 with no preamble,
    /// which is the only shape these tests need.
    fn plain_record(record_type: u32, lsn: u64, payload: Vec<u8>) -> Result<WalRecord> {
        WalRecord::new(record_type, lsn, 0, 0, payload, None, None)
    }

    /// A freshly built record must pass its own checksum verification.
    #[test]
    fn checksum_roundtrip() {
        let record = plain_record(RecordType::Put as u32, 1, b"hello nodedb".to_vec())
            .expect("small plaintext payload must be accepted");
        assert!(record.verify_checksum().is_ok());
    }

    /// Inverting every bit of the first payload byte must surface as a mismatch.
    #[test]
    fn checksum_detects_corruption() {
        let mut record = plain_record(RecordType::Put as u32, 1, b"hello nodedb".to_vec())
            .expect("small plaintext payload must be accepted");
        record.payload[0] = !record.payload[0];

        let outcome = record.verify_checksum();
        assert!(matches!(outcome, Err(WalError::ChecksumMismatch { .. })));
    }

    /// One byte over the limit is rejected before a record is built.
    #[test]
    fn payload_too_large_rejected() {
        let oversized = vec![0u8; MAX_WAL_PAYLOAD_SIZE + 1];
        let outcome = plain_record(RecordType::Put as u32, 1, oversized);
        assert!(matches!(outcome, Err(WalError::PayloadTooLarge { .. })));
    }

    /// An anchor payload survives being wrapped in a record and read back.
    #[test]
    fn anchor_payload_in_record() {
        use super::super::anchor::LsnMsAnchorPayload;

        let anchor = LsnMsAnchorPayload::new(42, 1_700_000_000_000);
        let record = plain_record(
            RecordType::LsnMsAnchor as u32,
            42,
            anchor.to_bytes().to_vec(),
        )
        .expect("anchor payload must be accepted");

        assert!(record.verify_checksum().is_ok());
        assert_eq!(record.logical_record_type(), RecordType::LsnMsAnchor as u32);

        let decoded = LsnMsAnchorPayload::from_bytes(&record.payload)
            .expect("anchor payload must decode");
        assert_eq!(decoded, anchor);
    }
}