1use crate::error::{DbxError, DbxResult};
30use crate::storage::encryption::EncryptionConfig;
31use crate::wal::WalRecord;
32
33use std::fs::{File, OpenOptions};
34use std::io::{BufRead, BufReader, Write};
35use std::path::{Path, PathBuf};
36use std::sync::Mutex;
37use std::sync::atomic::{AtomicU64, Ordering};
38
/// Associated data passed to every WAL `encrypt_with_aad` / `decrypt_with_aad`
/// call. The same value must be used for writing and reading; presumably the
/// cipher authenticates it, so changing it would make existing logs
/// undecryptable. NOTE(review): confirm against `EncryptionConfig`.
const WAL_AAD: &[u8] = "dbx-wal-v1".as_bytes();
41
/// Append-only write-ahead log whose records are stored encrypted.
///
/// Each record is serialized to JSON, encrypted through [`EncryptionConfig`]
/// with the fixed associated data `WAL_AAD`, base64-encoded, and written as a
/// single newline-terminated line.
pub struct EncryptedWal {
    /// Handle to the log file, opened in append mode; the mutex serializes
    /// appends and syncs.
    log_file: Mutex<File>,
    /// Location of the log on disk; re-opened by replay and rekey.
    path: PathBuf,
    /// Sequence number that the next appended record will receive.
    sequence: AtomicU64,
    /// Encryption settings used to encrypt new records and decrypt stored ones.
    encryption: EncryptionConfig,
}
83
84impl EncryptedWal {
85 pub fn open(path: &Path, encryption: EncryptionConfig) -> DbxResult<Self> {
87 let file = OpenOptions::new()
88 .create(true)
89 .read(true)
90 .append(true)
91 .open(path)?;
92
93 let max_seq = Self::scan_max_sequence(path, &encryption)?;
95
96 Ok(Self {
97 log_file: Mutex::new(file),
98 path: path.to_path_buf(),
99 sequence: AtomicU64::new(max_seq),
100 encryption,
101 })
102 }
103
104 fn scan_max_sequence(path: &Path, encryption: &EncryptionConfig) -> DbxResult<u64> {
106 let file = match File::open(path) {
107 Ok(f) => f,
108 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
109 Err(e) => return Err(e.into()),
110 };
111
112 let reader = BufReader::new(file);
113 let mut max_seq = 0u64;
114
115 for line in reader.lines() {
116 let line = line?;
117 if line.is_empty() {
118 continue;
119 }
120
121 if let Ok(record) = Self::decrypt_line(&line, encryption)
123 && let WalRecord::Checkpoint { sequence } = record
124 {
125 max_seq = max_seq.max(sequence);
126 }
127 max_seq += 1;
129 }
130
131 Ok(max_seq)
132 }
133
134 fn decrypt_line(line: &str, encryption: &EncryptionConfig) -> DbxResult<WalRecord> {
136 use base64::Engine;
137 use base64::engine::general_purpose::STANDARD;
138
139 let ciphertext = STANDARD
140 .decode(line.as_bytes())
141 .map_err(|e| DbxError::Encryption(format!("base64 decode failed: {}", e)))?;
142
143 let json_bytes = encryption.decrypt_with_aad(&ciphertext, WAL_AAD)?;
144
145 serde_json::from_slice(&json_bytes)
146 .map_err(|e| DbxError::Wal(format!("deserialization failed: {}", e)))
147 }
148
149 pub fn append(&self, record: &WalRecord) -> DbxResult<u64> {
154 use base64::Engine;
155 use base64::engine::general_purpose::STANDARD;
156
157 let seq = self.sequence.fetch_add(1, Ordering::SeqCst);
158
159 let json = serde_json::to_vec(record)
161 .map_err(|e| DbxError::Wal(format!("serialization failed: {}", e)))?;
162
163 let ciphertext = self.encryption.encrypt_with_aad(&json, WAL_AAD)?;
164 let encoded = STANDARD.encode(&ciphertext);
165
166 let mut file = self
168 .log_file
169 .lock()
170 .map_err(|e| DbxError::Wal(format!("lock failed: {}", e)))?;
171
172 file.write_all(encoded.as_bytes())?;
173 file.write_all(b"\n")?;
174
175 Ok(seq)
176 }
177
178 pub fn sync(&self) -> DbxResult<()> {
180 let file = self
181 .log_file
182 .lock()
183 .map_err(|e| DbxError::Wal(format!("lock failed: {}", e)))?;
184
185 file.sync_all()?;
186 Ok(())
187 }
188
189 pub fn replay(&self) -> DbxResult<Vec<WalRecord>> {
191 let file = File::open(&self.path)?;
192 let reader = BufReader::new(file);
193 let mut records = Vec::new();
194
195 for line in reader.lines() {
196 let line = line?;
197 if line.is_empty() {
198 continue;
199 }
200
201 let record = Self::decrypt_line(&line, &self.encryption)?;
202 records.push(record);
203 }
204
205 Ok(records)
206 }
207
208 pub fn current_sequence(&self) -> u64 {
210 self.sequence.load(Ordering::SeqCst)
211 }
212
213 pub fn encryption_config(&self) -> &EncryptionConfig {
215 &self.encryption
216 }
217
218 pub fn rekey(&mut self, new_encryption: EncryptionConfig) -> DbxResult<usize> {
228 use base64::Engine;
229 use base64::engine::general_purpose::STANDARD;
230
231 let records = self.replay()?;
233 let count = records.len();
234
235 let tmp_path = self.path.with_extension("rekey.tmp");
237 {
238 let mut tmp_file = File::create(&tmp_path)?;
239 for record in &records {
240 let json = serde_json::to_vec(record)
241 .map_err(|e| DbxError::Wal(format!("serialization failed: {}", e)))?;
242 let ciphertext = new_encryption.encrypt_with_aad(&json, WAL_AAD)?;
243 let encoded = STANDARD.encode(&ciphertext);
244 tmp_file.write_all(encoded.as_bytes())?;
245 tmp_file.write_all(b"\n")?;
246 }
247 tmp_file.sync_all()?;
248 }
249
250 std::fs::rename(&tmp_path, &self.path)?;
252
253 let file = OpenOptions::new()
255 .create(true)
256 .read(true)
257 .append(true)
258 .open(&self.path)?;
259
260 *self
261 .log_file
262 .lock()
263 .map_err(|e| DbxError::Wal(format!("lock failed: {}", e)))? = file;
264 self.encryption = new_encryption;
265
266 Ok(count)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// Key used by most tests.
    fn test_encryption() -> EncryptionConfig {
        EncryptionConfig::from_password("test-wal-password")
    }

    #[test]
    fn append_and_replay_round_trip() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();

        let record1 = WalRecord::Insert {
            table: "users".to_string(),
            key: b"user:1".to_vec(),
            value: b"Alice".to_vec(),
            ts: 0,
        };
        let record2 = WalRecord::Delete {
            table: "users".to_string(),
            key: b"user:2".to_vec(),
            ts: 1,
        };

        let seq1 = wal.append(&record1).unwrap();
        let seq2 = wal.append(&record2).unwrap();
        wal.sync().unwrap();

        assert_eq!(seq1, 0);
        assert_eq!(seq2, 1);

        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0], record1);
        assert_eq!(records[1], record2);
    }

    #[test]
    fn sync_durability() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();

        let record = WalRecord::Insert {
            table: "test".to_string(),
            key: b"key".to_vec(),
            value: b"value".to_vec(),
            ts: 0,
        };

        wal.append(&record).unwrap();
        wal.sync().unwrap();

        let wal2 = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let records = wal2.replay().unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], record);
    }

    #[test]
    fn wrong_key_cannot_replay() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();

        let record = WalRecord::Insert {
            table: "secret".to_string(),
            key: b"key".to_vec(),
            value: b"value".to_vec(),
            ts: 0,
        };

        wal.append(&record).unwrap();
        wal.sync().unwrap();

        // Opening with the wrong key succeeds; decryption only fails on replay.
        let wrong_enc = EncryptionConfig::from_password("wrong-password");
        let wal2 = EncryptedWal::open(temp.path(), wrong_enc).unwrap();
        let result = wal2.replay();
        assert!(result.is_err(), "Replay with wrong key should fail");
    }

    #[test]
    fn empty_wal_replay() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 0);
    }

    #[test]
    fn checkpoint_record() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();

        let checkpoint = WalRecord::Checkpoint { sequence: 42 };
        wal.append(&checkpoint).unwrap();
        wal.sync().unwrap();

        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], checkpoint);
    }

    #[test]
    fn multiple_record_types() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();

        let records_to_write = vec![
            WalRecord::Insert {
                table: "t".to_string(),
                key: b"k1".to_vec(),
                value: b"v1".to_vec(),
                ts: 0,
            },
            WalRecord::Delete {
                table: "t".to_string(),
                key: b"k2".to_vec(),
                ts: 1,
            },
            WalRecord::Commit { tx_id: 1 },
            WalRecord::Rollback { tx_id: 2 },
            WalRecord::Checkpoint { sequence: 10 },
        ];

        for r in &records_to_write {
            wal.append(r).unwrap();
        }
        wal.sync().unwrap();

        let replayed = wal.replay().unwrap();
        assert_eq!(replayed, records_to_write);
    }

    #[test]
    fn raw_file_is_not_readable() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();

        let record = WalRecord::Insert {
            table: "secret".to_string(),
            key: b"key".to_vec(),
            value: b"sensitive_data".to_vec(),
            ts: 0,
        };

        wal.append(&record).unwrap();
        wal.sync().unwrap();

        let raw = std::fs::read_to_string(temp.path()).unwrap();
        // Every stored line must be pure standard base64, so no plaintext JSON
        // punctuation (`{`, `"`, `:`) can appear. This replaces a check for
        // the substring "key", which random base64 output can contain.
        for line in raw.lines().filter(|l| !l.is_empty()) {
            assert!(
                line.bytes()
                    .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'/' | b'=')),
                "WAL line is not base64: {line}"
            );
        }
        assert!(!raw.contains("secret"));
        assert!(!raw.contains("sensitive_data"));
    }

    #[test]
    fn rekey_preserves_records() {
        let temp = NamedTempFile::new().unwrap();
        let enc_old = EncryptionConfig::from_password("old-key");
        let mut wal = EncryptedWal::open(temp.path(), enc_old).unwrap();

        let record1 = WalRecord::Insert {
            table: "t".to_string(),
            key: b"k1".to_vec(),
            value: b"v1".to_vec(),
            ts: 0,
        };
        let record2 = WalRecord::Delete {
            table: "t".to_string(),
            key: b"k2".to_vec(),
            ts: 1,
        };

        wal.append(&record1).unwrap();
        wal.append(&record2).unwrap();
        wal.sync().unwrap();

        let enc_new = EncryptionConfig::from_password("new-key");
        let count = wal.rekey(enc_new.clone()).unwrap();
        assert_eq!(count, 2);

        // The same instance now reads with the new key.
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0], record1);
        assert_eq!(records[1], record2);

        // A fresh instance with the new key can read the rekeyed log...
        let wal2 = EncryptedWal::open(temp.path(), enc_new).unwrap();
        let records2 = wal2.replay().unwrap();
        assert_eq!(records2.len(), 2);

        // ...while the old key no longer works.
        let enc_old2 = EncryptionConfig::from_password("old-key");
        let wal3 = EncryptedWal::open(temp.path(), enc_old2).unwrap();
        let result = wal3.replay();
        assert!(result.is_err());
    }

    #[test]
    fn reopen_continues_sequence() {
        let temp = NamedTempFile::new().unwrap();
        {
            let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
            for ts in 0..3 {
                wal.append(&WalRecord::Delete {
                    table: "t".to_string(),
                    key: b"k".to_vec(),
                    ts,
                })
                .unwrap();
            }
            wal.sync().unwrap();
            assert_eq!(wal.current_sequence(), 3);
        }

        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        assert_eq!(wal.current_sequence(), 3);
        let seq = wal.append(&WalRecord::Commit { tx_id: 7 }).unwrap();
        assert_eq!(seq, 3);
    }

    #[test]
    fn append_after_rekey_uses_new_key() {
        let temp = NamedTempFile::new().unwrap();
        let mut wal =
            EncryptedWal::open(temp.path(), EncryptionConfig::from_password("old-key")).unwrap();

        let before = WalRecord::Commit { tx_id: 1 };
        wal.append(&before).unwrap();
        wal.sync().unwrap();

        let enc_new = EncryptionConfig::from_password("new-key");
        wal.rekey(enc_new.clone()).unwrap();

        let after = WalRecord::Rollback { tx_id: 2 };
        let seq = wal.append(&after).unwrap();
        assert_eq!(seq, 1);
        wal.sync().unwrap();

        let reopened = EncryptedWal::open(temp.path(), enc_new).unwrap();
        assert_eq!(reopened.replay().unwrap(), vec![before, after]);
        assert!(!temp.path().with_extension("rekey.tmp").exists());
    }

    #[test]
    fn rekey_with_wrong_current_key_fails_and_keeps_log() {
        let temp = NamedTempFile::new().unwrap();
        let record = WalRecord::Commit { tx_id: 9 };
        {
            let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
            wal.append(&record).unwrap();
            wal.sync().unwrap();
        }

        let mut wrong =
            EncryptedWal::open(temp.path(), EncryptionConfig::from_password("wrong-password"))
                .unwrap();
        assert!(
            wrong
                .rekey(EncryptionConfig::from_password("new-key"))
                .is_err()
        );

        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        assert_eq!(wal.replay().unwrap(), vec![record]);
        assert!(!temp.path().with_extension("rekey.tmp").exists());
    }
}