1use aegis_common::{AegisError, BlockId, BlockType, CompressionType, EncryptionType, Result};
17use aes_gcm::{
18 aead::{Aead, KeyInit},
19 Aes256Gcm, Nonce,
20};
21use bytes::{BufMut, Bytes, BytesMut};
22use parking_lot::Mutex;
23use serde::{Deserialize, Serialize};
24use std::sync::OnceLock;
25
/// Nominal size in bytes of a serialized block header, as used by `Block::size`.
pub const BLOCK_HEADER_SIZE: usize = 32;

/// Default block size in bytes (8 KiB).
pub const DEFAULT_BLOCK_SIZE: usize = 8 * 1024;

/// Intended upper bound on a block's size in bytes (1 MiB).
///
/// NOTE(review): not enforced by `Block` in this file; confirm whether callers rely on it.
pub const MAX_BLOCK_SIZE: usize = 1 << 20;

/// Length of an AES-GCM nonce in bytes (96 bits).
pub const AES_GCM_NONCE_SIZE: usize = 12;

/// Length of an AES-256 key in bytes (256 bits).
pub const AES_256_KEY_SIZE: usize = 32;
35
36static ENCRYPTION_KEY: OnceLock<[u8; AES_256_KEY_SIZE]> = OnceLock::new();
44static ENCRYPTION_KEY_INIT: Mutex<bool> = Mutex::new(false);
45
46fn get_encryption_key() -> Result<&'static [u8; AES_256_KEY_SIZE]> {
49 if let Some(key) = ENCRYPTION_KEY.get() {
51 return Ok(key);
52 }
53
54 let _guard = ENCRYPTION_KEY_INIT.lock();
56
57 if let Some(key) = ENCRYPTION_KEY.get() {
59 return Ok(key);
60 }
61
62 let hex_key = std::env::var("AEGIS_ENCRYPTION_KEY").map_err(|_| {
63 AegisError::Encryption("AEGIS_ENCRYPTION_KEY environment variable not set".to_string())
64 })?;
65
66 let key_bytes = hex::decode(&hex_key).map_err(|e| {
67 AegisError::Encryption(format!(
68 "Invalid hex encoding in AEGIS_ENCRYPTION_KEY: {}",
69 e
70 ))
71 })?;
72
73 if key_bytes.len() != AES_256_KEY_SIZE {
74 return Err(AegisError::Encryption(format!(
75 "AEGIS_ENCRYPTION_KEY must be {} bytes ({} hex chars), got {} bytes",
76 AES_256_KEY_SIZE,
77 AES_256_KEY_SIZE * 2,
78 key_bytes.len()
79 )));
80 }
81
82 let mut key = [0u8; AES_256_KEY_SIZE];
83 key.copy_from_slice(&key_bytes);
84
85 let _ = ENCRYPTION_KEY.set(key);
87
88 ENCRYPTION_KEY.get().ok_or_else(|| {
89 AegisError::Encryption("Failed to retrieve encryption key after initialization".to_string())
90 })
91}
92
93fn encrypt_aes256gcm(plaintext: &[u8]) -> Result<Vec<u8>> {
96 let key = get_encryption_key()?;
97 let cipher = Aes256Gcm::new_from_slice(key)
98 .map_err(|e| AegisError::Encryption(format!("Failed to create cipher: {}", e)))?;
99
100 let mut nonce_bytes = [0u8; AES_GCM_NONCE_SIZE];
102 getrandom::getrandom(&mut nonce_bytes)
103 .map_err(|e| AegisError::Encryption(format!("Failed to generate nonce: {}", e)))?;
104 let nonce = Nonce::from_slice(&nonce_bytes);
105
106 let ciphertext = cipher
107 .encrypt(nonce, plaintext)
108 .map_err(|e| AegisError::Encryption(format!("Encryption failed: {}", e)))?;
109
110 let mut result = Vec::with_capacity(AES_GCM_NONCE_SIZE + ciphertext.len());
112 result.extend_from_slice(&nonce_bytes);
113 result.extend(ciphertext);
114
115 Ok(result)
116}
117
118fn decrypt_aes256gcm(encrypted_data: &[u8]) -> Result<Vec<u8>> {
121 if encrypted_data.len() < AES_GCM_NONCE_SIZE {
122 return Err(AegisError::Encryption(
123 "Encrypted data too short: missing nonce".to_string(),
124 ));
125 }
126
127 let key = get_encryption_key()?;
128 let cipher = Aes256Gcm::new_from_slice(key)
129 .map_err(|e| AegisError::Encryption(format!("Failed to create cipher: {}", e)))?;
130
131 let nonce = Nonce::from_slice(&encrypted_data[..AES_GCM_NONCE_SIZE]);
132 let ciphertext = &encrypted_data[AES_GCM_NONCE_SIZE..];
133
134 let plaintext = cipher
135 .decrypt(nonce, ciphertext)
136 .map_err(|e| AegisError::Encryption(format!("Decryption failed: {}", e)))?;
137
138 Ok(plaintext)
139}
140
141#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
147pub struct BlockHeader {
148 pub block_id: BlockId,
149 pub block_type: BlockType,
150 pub compression: CompressionType,
151 pub encryption: EncryptionType,
152 pub data_size: u32,
153 pub uncompressed_size: u32,
154 pub checksum: u32,
155 pub version: u16,
156 pub flags: u16,
157}
158
159impl BlockHeader {
160 pub fn new(block_id: BlockId, block_type: BlockType) -> Self {
161 Self {
162 block_id,
163 block_type,
164 compression: CompressionType::None,
165 encryption: EncryptionType::None,
166 data_size: 0,
167 uncompressed_size: 0,
168 checksum: 0,
169 version: 1,
170 flags: 0,
171 }
172 }
173
174 pub fn to_bytes(&self) -> Result<Bytes> {
176 bincode::serialize(self)
177 .map(Bytes::from)
178 .map_err(|e| AegisError::Serialization(e.to_string()))
179 }
180
181 pub fn from_bytes(data: &[u8]) -> Result<Self> {
183 bincode::deserialize(data).map_err(|e| AegisError::Serialization(e.to_string()))
184 }
185}
186
187#[derive(Debug, Clone)]
193pub struct Block {
194 pub header: BlockHeader,
195 pub data: Bytes,
196}
197
198impl Block {
199 pub fn new(block_id: BlockId, block_type: BlockType, data: Bytes) -> Self {
201 let checksum = crc32fast::hash(&data);
202 let data_size = data.len() as u32;
203
204 let header = BlockHeader {
205 block_id,
206 block_type,
207 compression: CompressionType::None,
208 encryption: EncryptionType::None,
209 data_size,
210 uncompressed_size: data_size,
211 checksum,
212 version: 1,
213 flags: 0,
214 };
215
216 Self { header, data }
217 }
218
219 pub fn empty(block_id: BlockId, block_type: BlockType) -> Self {
221 Self::new(block_id, block_type, Bytes::new())
222 }
223
224 pub fn verify_checksum(&self) -> bool {
226 let computed = crc32fast::hash(&self.data);
227 computed == self.header.checksum
228 }
229
230 pub fn update_checksum(&mut self) {
232 self.header.checksum = crc32fast::hash(&self.data);
233 self.header.data_size = self.data.len() as u32;
234 }
235
236 pub fn compress(&mut self, compression: CompressionType) -> Result<()> {
238 if self.header.compression != CompressionType::None {
239 return Ok(());
240 }
241
242 let compressed = match compression {
243 CompressionType::None => return Ok(()),
244 CompressionType::Lz4 => lz4_flex::compress_prepend_size(&self.data),
245 CompressionType::Zstd => zstd::encode_all(self.data.as_ref(), 3)
246 .map_err(|e| AegisError::Storage(e.to_string()))?,
247 CompressionType::Snappy => {
248 let mut encoder = snap::raw::Encoder::new();
249 encoder
250 .compress_vec(&self.data)
251 .map_err(|e| AegisError::Storage(e.to_string()))?
252 }
253 };
254
255 self.header.uncompressed_size = self.header.data_size;
256 self.data = Bytes::from(compressed);
257 self.header.data_size = self.data.len() as u32;
258 self.header.compression = compression;
259 self.update_checksum();
260
261 Ok(())
262 }
263
264 pub fn decompress(&mut self) -> Result<()> {
266 if self.header.compression == CompressionType::None {
267 return Ok(());
268 }
269
270 let decompressed = match self.header.compression {
271 CompressionType::None => return Ok(()),
272 CompressionType::Lz4 => lz4_flex::decompress_size_prepended(&self.data)
273 .map_err(|e| AegisError::Storage(e.to_string()))?,
274 CompressionType::Zstd => zstd::decode_all(self.data.as_ref())
275 .map_err(|e| AegisError::Storage(e.to_string()))?,
276 CompressionType::Snappy => {
277 let mut decoder = snap::raw::Decoder::new();
278 decoder
279 .decompress_vec(&self.data)
280 .map_err(|e| AegisError::Storage(e.to_string()))?
281 }
282 };
283
284 self.data = Bytes::from(decompressed);
285 self.header.data_size = self.data.len() as u32;
286 self.header.compression = CompressionType::None;
287 self.update_checksum();
288
289 Ok(())
290 }
291
292 pub fn encrypt(&mut self, encryption: EncryptionType) -> Result<()> {
295 if self.header.encryption != EncryptionType::None {
296 return Ok(()); }
298
299 let encrypted = match encryption {
300 EncryptionType::None => return Ok(()),
301 EncryptionType::Aes256Gcm => encrypt_aes256gcm(&self.data)?,
302 };
303
304 self.data = Bytes::from(encrypted);
305 self.header.data_size = self.data.len() as u32;
306 self.header.encryption = encryption;
307 self.update_checksum();
308
309 Ok(())
310 }
311
312 pub fn decrypt(&mut self) -> Result<()> {
315 if self.header.encryption == EncryptionType::None {
316 return Ok(()); }
318
319 let decrypted = match self.header.encryption {
320 EncryptionType::None => return Ok(()),
321 EncryptionType::Aes256Gcm => decrypt_aes256gcm(&self.data)?,
322 };
323
324 self.data = Bytes::from(decrypted);
325 self.header.data_size = self.data.len() as u32;
326 self.header.encryption = EncryptionType::None;
327 self.update_checksum();
328
329 Ok(())
330 }
331
332 pub fn to_bytes(&self) -> Result<Bytes> {
334 let header_bytes = self.header.to_bytes()?;
335 let mut buf = BytesMut::with_capacity(header_bytes.len() + self.data.len() + 8);
336
337 buf.put_u32_le(header_bytes.len() as u32);
338 buf.put(header_bytes);
339 buf.put_u32_le(self.data.len() as u32);
340 buf.put(self.data.clone());
341
342 Ok(buf.freeze())
343 }
344
345 pub fn from_bytes(data: &[u8]) -> Result<Self> {
347 if data.len() < 8 {
348 return Err(AegisError::Corruption("Block too small".to_string()));
349 }
350
351 let header_len = u32::from_le_bytes([data[0], data[1], data[2], data[3]]) as usize;
352 if data.len() < 4 + header_len + 4 {
353 return Err(AegisError::Corruption("Block header truncated".to_string()));
354 }
355
356 let header = BlockHeader::from_bytes(&data[4..4 + header_len])?;
357
358 let data_offset = 4 + header_len;
359 let data_len = u32::from_le_bytes([
360 data[data_offset],
361 data[data_offset + 1],
362 data[data_offset + 2],
363 data[data_offset + 3],
364 ]) as usize;
365
366 if data.len() < data_offset + 4 + data_len {
367 return Err(AegisError::Corruption("Block data truncated".to_string()));
368 }
369
370 let block_data = Bytes::copy_from_slice(&data[data_offset + 4..data_offset + 4 + data_len]);
371
372 Ok(Self {
373 header,
374 data: block_data,
375 })
376 }
377
378 pub fn size(&self) -> usize {
380 BLOCK_HEADER_SIZE + self.data.len()
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Once;

    /// Hex encoding of the 32-byte key shared by every encryption test.
    const TEST_KEY_HEX: &str = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";

    /// Installs the shared test key exactly once.
    ///
    /// `get_encryption_key` caches the first key it loads for the whole process,
    /// and tests run on parallel threads. Every test must therefore agree on a
    /// single value, and no test may write a different value to
    /// `AEGIS_ENCRYPTION_KEY`: doing so races with tests that are mid-encryption.
    fn install_test_key() {
        static INIT: Once = Once::new();
        INIT.call_once(|| std::env::set_var("AEGIS_ENCRYPTION_KEY", TEST_KEY_HEX));
    }

    #[test]
    fn test_block_roundtrip() {
        let data = Bytes::from("Hello, Aegis!");
        let block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        assert!(block.verify_checksum());

        let serialized = block.to_bytes().expect("to_bytes should succeed");
        let deserialized = Block::from_bytes(&serialized).expect("from_bytes should succeed");

        assert_eq!(deserialized.header.block_id, BlockId(1));
        assert_eq!(deserialized.data, data);
        assert!(deserialized.verify_checksum());
    }

    #[test]
    fn test_from_bytes_rejects_truncated_input() {
        assert!(Block::from_bytes(&[0u8; 7]).is_err());
        // The header length prefix claims far more bytes than are present.
        assert!(Block::from_bytes(&[0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0]).is_err());

        // A valid encoding whose payload has been cut short.
        let block = Block::new(BlockId(3), BlockType::TableData, Bytes::from("payload"));
        let serialized = block.to_bytes().expect("to_bytes should succeed");
        assert!(Block::from_bytes(&serialized[..serialized.len() - 1]).is_err());
    }

    #[test]
    fn test_block_compression_lz4() {
        let data = Bytes::from("Hello, Aegis! ".repeat(100));
        let mut block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        block
            .compress(CompressionType::Lz4)
            .expect("LZ4 compression should succeed");
        assert!(block.header.data_size < block.header.uncompressed_size);

        block
            .decompress()
            .expect("LZ4 decompression should succeed");
        assert_eq!(block.data, data);
    }

    #[test]
    fn test_block_compression_zstd() {
        let data = Bytes::from("Hello, Aegis! ".repeat(100));
        let mut block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        block
            .compress(CompressionType::Zstd)
            .expect("Zstd compression should succeed");
        assert!(block.header.data_size < block.header.uncompressed_size);

        block
            .decompress()
            .expect("Zstd decompression should succeed");
        assert_eq!(block.data, data);
    }

    #[test]
    fn test_block_compression_snappy() {
        let data = Bytes::from("Hello, Aegis! ".repeat(100));
        let mut block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        block
            .compress(CompressionType::Snappy)
            .expect("Snappy compression should succeed");
        assert!(block.header.data_size < block.header.uncompressed_size);
        assert!(block.verify_checksum());

        block
            .decompress()
            .expect("Snappy decompression should succeed");
        assert_eq!(block.data, data);
    }

    #[test]
    fn test_block_encryption_aes256gcm() {
        install_test_key();

        let data = Bytes::from("Secret data to encrypt!");
        let mut block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        block
            .encrypt(EncryptionType::Aes256Gcm)
            .expect("AES-256-GCM encryption should succeed");
        assert_eq!(block.header.encryption, EncryptionType::Aes256Gcm);
        assert_ne!(block.data, data);
        assert!(block.data.len() > data.len());

        block
            .decrypt()
            .expect("AES-256-GCM decryption should succeed");
        assert_eq!(block.header.encryption, EncryptionType::None);
        assert_eq!(block.data, data);
    }

    #[test]
    fn test_block_encryption_roundtrip() {
        install_test_key();

        let data = Bytes::from("Hello, encrypted Aegis!");
        let mut block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        block
            .encrypt(EncryptionType::Aes256Gcm)
            .expect("Encryption should succeed");
        assert!(block.verify_checksum());

        let serialized = block.to_bytes().expect("to_bytes should succeed");
        let mut deserialized = Block::from_bytes(&serialized).expect("from_bytes should succeed");

        assert_eq!(deserialized.header.encryption, EncryptionType::Aes256Gcm);
        assert!(deserialized.verify_checksum());

        deserialized.decrypt().expect("Decryption should succeed");
        assert_eq!(deserialized.data, data);
    }

    #[test]
    fn test_block_compression_then_encryption() {
        install_test_key();

        let data = Bytes::from("Hello, Aegis! ".repeat(100));
        let mut block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        block
            .compress(CompressionType::Lz4)
            .expect("Compression should succeed");
        let compressed_size = block.header.data_size;

        block
            .encrypt(EncryptionType::Aes256Gcm)
            .expect("Encryption should succeed");
        assert_eq!(block.header.compression, CompressionType::Lz4);
        assert_eq!(block.header.encryption, EncryptionType::Aes256Gcm);

        block.decrypt().expect("Decryption should succeed");
        assert_eq!(block.header.data_size, compressed_size);

        block.decompress().expect("Decompression should succeed");
        assert_eq!(block.data, data);
    }

    #[test]
    fn test_decrypt_rejects_truncated_payload() {
        // The length guard runs before the key is loaded, so no key is needed.
        assert!(decrypt_aes256gcm(&[0u8; AES_GCM_NONCE_SIZE - 1]).is_err());
        assert!(decrypt_aes256gcm(&[]).is_err());
    }

    #[test]
    fn test_decrypt_rejects_tampered_ciphertext() {
        install_test_key();

        let mut block = Block::new(BlockId(7), BlockType::TableData, Bytes::from("integrity"));
        block
            .encrypt(EncryptionType::Aes256Gcm)
            .expect("Encryption should succeed");

        let mut tampered = block.data.to_vec();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        block.data = Bytes::from(tampered);

        assert!(block.decrypt().is_err());
        // A failed decryption must leave the block marked as encrypted.
        assert_eq!(block.header.encryption, EncryptionType::Aes256Gcm);
    }

    #[test]
    fn test_double_encryption_is_noop() {
        install_test_key();

        let data = Bytes::from("Test data");
        let mut block = Block::new(BlockId(1), BlockType::TableData, data);

        block
            .encrypt(EncryptionType::Aes256Gcm)
            .expect("First encryption should succeed");
        let first_encrypted = block.data.clone();

        block
            .encrypt(EncryptionType::Aes256Gcm)
            .expect("Second encryption should succeed (no-op)");
        assert_eq!(block.data, first_encrypted);
    }

    #[test]
    fn test_double_decryption_is_noop() {
        install_test_key();

        let data = Bytes::from("Test data");
        let mut block = Block::new(BlockId(1), BlockType::TableData, data.clone());

        block
            .encrypt(EncryptionType::Aes256Gcm)
            .expect("Encryption should succeed");
        block.decrypt().expect("First decryption should succeed");
        let decrypted = block.data.clone();

        block
            .decrypt()
            .expect("Second decryption should succeed (no-op)");
        assert_eq!(block.data, decrypted);
        assert_eq!(block.data, data);
    }
}