use crate::error::{DbxError, DbxResult};
use crate::storage::encryption::EncryptionConfig;
use crate::wal::WalRecord;
use std::fs::{File, OpenOptions};
use std::io::{BufRead, BufReader, Write};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
/// Additional authenticated data passed to every `encrypt_with_aad` /
/// `decrypt_with_aad` call for WAL lines. Presumably versions the on-disk
/// format (assuming `EncryptionConfig` is an AEAD cipher, a line sealed with
/// different AAD will fail to decrypt) — confirm against `EncryptionConfig`.
const WAL_AAD: &[u8] = b"dbx-wal-v1";
/// Write-ahead log whose records are stored one per line as base64-encoded,
/// encrypted JSON.
pub struct EncryptedWal {
// Append-mode handle (see `open`/`rekey`); the mutex lets `append` and `sync`
// work through `&self`.
log_file: Mutex<File>,
// Location of the log; `replay` and `rekey` reopen the file by this path.
path: PathBuf,
// Next sequence number `append` will hand out; recovered on open by
// `scan_max_sequence`.
sequence: AtomicU64,
// Key material used to encrypt and decrypt every line.
encryption: EncryptionConfig,
}
impl EncryptedWal {
    /// Opens (creating if necessary) the encrypted WAL at `path`.
    ///
    /// The existing log is scanned to recover the next sequence number, so
    /// appends after a reopen continue where the previous handle stopped.
    /// Lines that fail to decrypt (e.g. wrong key) are still counted here;
    /// they are only reported as errors by [`replay`](Self::replay).
    ///
    /// # Errors
    /// Returns an error if the file cannot be opened or read.
    pub fn open(path: &Path, encryption: EncryptionConfig) -> DbxResult<Self> {
        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .append(true)
            .open(path)?;
        let max_seq = Self::scan_max_sequence(path, &encryption)?;
        Ok(Self {
            log_file: Mutex::new(file),
            path: path.to_path_buf(),
            sequence: AtomicU64::new(max_seq),
            encryption,
        })
    }

    /// Recovers the next sequence number from an existing log.
    ///
    /// Every non-empty line advances the counter by one. A line that decrypts
    /// to a [`WalRecord::Checkpoint`] first raises the counter to at least the
    /// checkpoint's `sequence`. Lines that fail to decrypt are counted but
    /// otherwise ignored, so opening with the wrong key still succeeds.
    /// A missing file yields 0.
    // NOTE(review): `append` does not apply the checkpoint bump, so a live
    // handle and a reopened one can disagree after a checkpoint — confirm
    // which behavior is intended.
    fn scan_max_sequence(path: &Path, encryption: &EncryptionConfig) -> DbxResult<u64> {
        let file = match File::open(path) {
            Ok(f) => f,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
            Err(e) => return Err(e.into()),
        };
        let mut max_seq = 0u64;
        for line in BufReader::new(file).lines() {
            let line = line?;
            if line.is_empty() {
                continue;
            }
            if let Ok(WalRecord::Checkpoint { sequence }) = Self::decrypt_line(&line, encryption) {
                max_seq = max_seq.max(sequence);
            }
            max_seq += 1;
        }
        Ok(max_seq)
    }

    /// Decodes one log line: base64 → decrypt with [`WAL_AAD`] → JSON record.
    ///
    /// # Errors
    /// `DbxError::Encryption` on bad base64, whatever `decrypt_with_aad`
    /// returns on decryption failure, and `DbxError::Wal` on bad JSON.
    fn decrypt_line(line: &str, encryption: &EncryptionConfig) -> DbxResult<WalRecord> {
        use base64::Engine;
        use base64::engine::general_purpose::STANDARD;
        let ciphertext = STANDARD
            .decode(line.as_bytes())
            .map_err(|e| DbxError::Encryption(format!("base64 decode failed: {}", e)))?;
        let json_bytes = encryption.decrypt_with_aad(&ciphertext, WAL_AAD)?;
        serde_json::from_slice(&json_bytes)
            .map_err(|e| DbxError::Wal(format!("deserialization failed: {}", e)))
    }

    /// Encodes one record as a log line (without the trailing newline):
    /// JSON → encrypt with [`WAL_AAD`] → base64. Inverse of `decrypt_line`.
    fn encode_record(record: &WalRecord, encryption: &EncryptionConfig) -> DbxResult<String> {
        use base64::Engine;
        use base64::engine::general_purpose::STANDARD;
        let json = serde_json::to_vec(record)
            .map_err(|e| DbxError::Wal(format!("serialization failed: {}", e)))?;
        let ciphertext = encryption.encrypt_with_aad(&json, WAL_AAD)?;
        Ok(STANDARD.encode(&ciphertext))
    }

    /// Encrypts and appends `record`, returning its sequence number.
    ///
    /// The record is written but not fsynced; call [`sync`](Self::sync) for
    /// durability. The sequence number is assigned while holding the file
    /// lock and only after the write succeeds, so concurrent appenders get
    /// numbers matching their order in the file, and a failed append does not
    /// consume a number.
    ///
    /// # Errors
    /// Serialization, encryption, lock-poisoning, or I/O failures.
    pub fn append(&self, record: &WalRecord) -> DbxResult<u64> {
        // Encrypt before taking the lock: it is the expensive step and touches
        // no shared state.
        let mut line = Self::encode_record(record, &self.encryption)?;
        line.push('\n');
        let mut file = self
            .log_file
            .lock()
            .map_err(|e| DbxError::Wal(format!("lock failed: {}", e)))?;
        // Single write so the record and its terminator are not issued as
        // separate writes.
        file.write_all(line.as_bytes())?;
        Ok(self.sequence.fetch_add(1, Ordering::SeqCst))
    }

    /// Flushes everything appended so far to stable storage (`File::sync_all`).
    ///
    /// # Errors
    /// Lock poisoning or the underlying I/O error.
    pub fn sync(&self) -> DbxResult<()> {
        let file = self
            .log_file
            .lock()
            .map_err(|e| DbxError::Wal(format!("lock failed: {}", e)))?;
        file.sync_all()?;
        Ok(())
    }

    /// Reads and decrypts every record in file order, via a fresh handle on
    /// the log path.
    ///
    /// # Errors
    /// Fails on the first line that cannot be read, base64-decoded, decrypted
    /// (e.g. wrong key), or deserialized.
    pub fn replay(&self) -> DbxResult<Vec<WalRecord>> {
        let reader = BufReader::new(File::open(&self.path)?);
        let mut records = Vec::new();
        for line in reader.lines() {
            let line = line?;
            if line.is_empty() {
                continue;
            }
            records.push(Self::decrypt_line(&line, &self.encryption)?);
        }
        Ok(records)
    }

    /// Returns the sequence number the next successful `append` will return.
    pub fn current_sequence(&self) -> u64 {
        self.sequence.load(Ordering::SeqCst)
    }

    /// Returns the encryption configuration currently used for this log.
    pub fn encryption_config(&self) -> &EncryptionConfig {
        &self.encryption
    }

    /// Re-encrypts every record under `new_encryption` and replaces the log
    /// file, returning the number of records rewritten.
    ///
    /// Records are written to a sibling temporary file, fsynced, and renamed
    /// over the log. If writing or renaming fails, the temporary file is
    /// removed and the existing log is left in place. The sequence counter is
    /// unchanged.
    ///
    /// # Errors
    /// Fails if the current log cannot be replayed with the current key, or on
    /// serialization, encryption, or I/O errors.
    pub fn rekey(&mut self, new_encryption: EncryptionConfig) -> DbxResult<usize> {
        let records = self.replay()?;
        let count = records.len();
        let tmp_path = self.rekey_tmp_path();

        if let Err(e) = Self::write_records(&tmp_path, &records, &new_encryption) {
            // Best-effort cleanup; the write error is the one worth reporting.
            let _ = std::fs::remove_file(&tmp_path);
            return Err(e);
        }
        if let Err(e) = std::fs::rename(&tmp_path, &self.path) {
            let _ = std::fs::remove_file(&tmp_path);
            return Err(e.into());
        }

        let file = OpenOptions::new()
            .create(true)
            .read(true)
            .append(true)
            .open(&self.path)?;
        // `&mut self` guarantees exclusive access, so no lock acquisition is needed.
        *self
            .log_file
            .get_mut()
            .map_err(|e| DbxError::Wal(format!("lock failed: {}", e)))? = file;
        self.encryption = new_encryption;
        Ok(count)
    }

    /// Scratch path used by `rekey`: the full log file name with `.rekey.tmp`
    /// appended, so logs sharing a stem (e.g. `data.wal` and `data.log`)
    /// never share a temporary file.
    fn rekey_tmp_path(&self) -> PathBuf {
        let mut name = self.path.clone().into_os_string();
        name.push(".rekey.tmp");
        PathBuf::from(name)
    }

    /// Creates (truncating) `path`, writes one encrypted line per record, and
    /// fsyncs the file.
    fn write_records(
        path: &Path,
        records: &[WalRecord],
        encryption: &EncryptionConfig,
    ) -> DbxResult<()> {
        let mut writer = std::io::BufWriter::new(File::create(path)?);
        for record in records {
            let line = Self::encode_record(record, encryption)?;
            writer.write_all(line.as_bytes())?;
            writer.write_all(b"\n")?;
        }
        // `into_inner` flushes and reports flush failures that `Drop` would swallow.
        let file = writer.into_inner().map_err(|e| e.into_error())?;
        file.sync_all()?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    fn test_encryption() -> EncryptionConfig {
        EncryptionConfig::from_password("test-wal-password")
    }

    #[test]
    fn append_and_replay_round_trip() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let record1 = WalRecord::Insert {
            table: "users".to_string(),
            key: b"user:1".to_vec(),
            value: b"Alice".to_vec(),
            ts: 0,
        };
        let record2 = WalRecord::Delete {
            table: "users".to_string(),
            key: b"user:2".to_vec(),
            ts: 1,
        };
        let seq1 = wal.append(&record1).unwrap();
        let seq2 = wal.append(&record2).unwrap();
        wal.sync().unwrap();
        assert_eq!(seq1, 0);
        assert_eq!(seq2, 1);
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0], record1);
        assert_eq!(records[1], record2);
    }

    #[test]
    fn sync_durability() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let record = WalRecord::Insert {
            table: "test".to_string(),
            key: b"key".to_vec(),
            value: b"value".to_vec(),
            ts: 0,
        };
        wal.append(&record).unwrap();
        wal.sync().unwrap();
        let wal2 = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let records = wal2.replay().unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], record);
    }

    #[test]
    fn reopen_resumes_sequence() {
        let temp = NamedTempFile::new().unwrap();
        {
            let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
            wal.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
            wal.append(&WalRecord::Commit { tx_id: 2 }).unwrap();
            wal.sync().unwrap();
        }
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        assert_eq!(wal.current_sequence(), 2);
        assert_eq!(wal.append(&WalRecord::Commit { tx_id: 3 }).unwrap(), 2);
        assert_eq!(wal.current_sequence(), 3);
    }

    #[test]
    fn wrong_key_cannot_replay() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let record = WalRecord::Insert {
            table: "secret".to_string(),
            key: b"key".to_vec(),
            value: b"value".to_vec(),
            ts: 0,
        };
        wal.append(&record).unwrap();
        wal.sync().unwrap();
        let wrong_enc = EncryptionConfig::from_password("wrong-password");
        let wal2 = EncryptedWal::open(temp.path(), wrong_enc).unwrap();
        let result = wal2.replay();
        assert!(result.is_err(), "Replay with wrong key should fail");
    }

    #[test]
    fn empty_wal_replay() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 0);
    }

    #[test]
    fn checkpoint_record() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let checkpoint = WalRecord::Checkpoint { sequence: 42 };
        wal.append(&checkpoint).unwrap();
        wal.sync().unwrap();
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0], checkpoint);
    }

    #[test]
    fn multiple_record_types() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let records_to_write = vec![
            WalRecord::Insert {
                table: "t".to_string(),
                key: b"k1".to_vec(),
                value: b"v1".to_vec(),
                ts: 0,
            },
            WalRecord::Delete {
                table: "t".to_string(),
                key: b"k2".to_vec(),
                ts: 1,
            },
            WalRecord::Commit { tx_id: 1 },
            WalRecord::Rollback { tx_id: 2 },
            WalRecord::Checkpoint { sequence: 10 },
        ];
        for r in &records_to_write {
            wal.append(r).unwrap();
        }
        wal.sync().unwrap();
        let replayed = wal.replay().unwrap();
        assert_eq!(replayed, records_to_write);
    }

    #[test]
    fn raw_file_is_not_readable() {
        let temp = NamedTempFile::new().unwrap();
        let wal = EncryptedWal::open(temp.path(), test_encryption()).unwrap();
        let record = WalRecord::Insert {
            table: "secret".to_string(),
            key: b"key".to_vec(),
            value: b"sensitive_data".to_vec(),
            ts: 0,
        };
        wal.append(&record).unwrap();
        wal.sync().unwrap();
        let raw = std::fs::read_to_string(temp.path()).unwrap();
        // The file must be nothing but base64 lines: no plaintext JSON leaks through.
        assert!(
            raw.bytes()
                .all(|b| b.is_ascii_alphanumeric() || matches!(b, b'+' | b'/' | b'=' | b'\n')),
            "WAL file contains non-base64 bytes"
        );
        // `"` and `_` are outside the base64 alphabet, so these markers cannot
        // appear by chance (unlike a bare short word such as `key`).
        assert!(!raw.contains("\"table\""));
        assert!(!raw.contains("\"key\""));
        assert!(!raw.contains("sensitive_data"));
        assert!(!raw.contains("secret"));
    }

    #[test]
    fn rekey_preserves_records() {
        let temp = NamedTempFile::new().unwrap();
        let enc_old = EncryptionConfig::from_password("old-key");
        let mut wal = EncryptedWal::open(temp.path(), enc_old).unwrap();
        let record1 = WalRecord::Insert {
            table: "t".to_string(),
            key: b"k1".to_vec(),
            value: b"v1".to_vec(),
            ts: 0,
        };
        let record2 = WalRecord::Delete {
            table: "t".to_string(),
            key: b"k2".to_vec(),
            ts: 1,
        };
        wal.append(&record1).unwrap();
        wal.append(&record2).unwrap();
        wal.sync().unwrap();
        let enc_new = EncryptionConfig::from_password("new-key");
        let count = wal.rekey(enc_new.clone()).unwrap();
        assert_eq!(count, 2);
        let records = wal.replay().unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0], record1);
        assert_eq!(records[1], record2);
        let wal2 = EncryptedWal::open(temp.path(), enc_new).unwrap();
        let records2 = wal2.replay().unwrap();
        assert_eq!(records2.len(), 2);
        let enc_old2 = EncryptionConfig::from_password("old-key");
        let wal3 = EncryptedWal::open(temp.path(), enc_old2).unwrap();
        let result = wal3.replay();
        assert!(result.is_err());
    }

    #[test]
    fn rekey_keeps_sequence_and_appends_under_new_key() {
        let temp = NamedTempFile::new().unwrap();
        let mut wal =
            EncryptedWal::open(temp.path(), EncryptionConfig::from_password("old-key")).unwrap();
        wal.append(&WalRecord::Commit { tx_id: 1 }).unwrap();
        wal.sync().unwrap();
        wal.rekey(EncryptionConfig::from_password("new-key")).unwrap();
        assert_eq!(wal.current_sequence(), 1);
        // Appends after rekey must land in the replaced file, under the new key.
        assert_eq!(wal.append(&WalRecord::Commit { tx_id: 2 }).unwrap(), 1);
        wal.sync().unwrap();
        let reopened =
            EncryptedWal::open(temp.path(), EncryptionConfig::from_password("new-key")).unwrap();
        assert_eq!(
            reopened.replay().unwrap(),
            vec![WalRecord::Commit { tx_id: 1 }, WalRecord::Commit { tx_id: 2 }]
        );
    }
}