use bitcode::{Decode, Encode};
use crc32fast::Hasher;
use std::fs;
use std::path::Path;
use std::sync::Arc;
use super::atomic_write;
use super::error::{IndexPersistenceError, Result};
use crate::encryption::cipher::Cipher;
/// Encodes `data` with `bitcode`, appends a CRC32 of the encoded bytes, and writes the
/// result to `path` through `atomic_write`.
///
/// On-disk layout: `payload || crc32(payload).to_le_bytes()` (4-byte little-endian
/// trailer), which is the layout [`load_encoded_with_crc`] verifies.
///
/// # Errors
///
/// Propagates any error returned by `atomic_write`.
pub fn save_encoded_with_crc<T: Encode>(data: &T, path: &Path) -> Result<()> {
    let mut framed = bitcode::encode(data);
    let crc = {
        let mut crc_state = Hasher::new();
        crc_state.update(&framed);
        crc_state.finalize()
    };
    framed.extend_from_slice(&crc.to_le_bytes());
    atomic_write(path, &framed)
}
/// Loads a `bitcode`-encoded value written by [`save_encoded_with_crc`], verifying the
/// 4-byte little-endian CRC32 trailer before decoding.
///
/// `max_size` is the maximum number of bytes that will be read from `path`; `context` is a
/// human-readable label used as the prefix of the size-limit error message.
///
/// # Errors
///
/// - `IndexPersistenceError::SizeLimitExceeded` if the file is larger than `max_size`
///   bytes, either according to its metadata or because it grew past the limit between the
///   metadata check and the read.
/// - `IndexPersistenceError::Corrupted` if the file is shorter than the 4-byte trailer or
///   the stored checksum does not match the payload.
/// - I/O errors (metadata, open, read) and `bitcode` decode errors, converted via `?`.
pub fn load_encoded_with_crc<T: for<'a> Decode<'a>>(
    path: &Path,
    max_size: u64,
    context: &str,
) -> Result<T> {
    use std::io::Read;

    /// Length of the CRC32 trailer appended by `save_encoded_with_crc`.
    const CRC_LEN: usize = 4;

    let metadata = fs::metadata(path)?;
    if metadata.len() > max_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "{} file size {} exceeds limit {}",
                context,
                metadata.len(),
                max_size
            ),
        });
    }

    // The metadata check is only advisory: the file can grow or be replaced before it is
    // read. Read at most `max_size + 1` bytes so an oversized file is still detected without
    // ever buffering more than the limit allows. `saturating_add` keeps `u64::MAX` valid.
    let mut bytes = Vec::with_capacity(usize::try_from(metadata.len()).unwrap_or(0));
    fs::File::open(path)?
        .take(max_size.saturating_add(1))
        .read_to_end(&mut bytes)?;
    // usize -> u64 is lossless on all supported targets.
    if bytes.len() as u64 > max_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "{} file size exceeds limit {} (file changed after size check)",
                context, max_size
            ),
        });
    }

    if bytes.len() < CRC_LEN {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: "File too small to contain CRC32 checksum".into(),
        });
    }
    let (data, checksum_bytes) = bytes.split_at(bytes.len() - CRC_LEN);
    // `split_at` guarantees `checksum_bytes` is exactly CRC_LEN bytes long.
    let mut trailer = [0u8; CRC_LEN];
    trailer.copy_from_slice(checksum_bytes);
    let stored_checksum = u32::from_le_bytes(trailer);

    let mut hasher = Hasher::new();
    hasher.update(data);
    let computed_checksum = hasher.finalize();
    if computed_checksum != stored_checksum {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: format!(
                "CRC32 checksum mismatch: expected {}, got {}",
                stored_checksum, computed_checksum
            )
            .into(),
        });
    }

    let decoded: T = bitcode::decode(data)?;
    Ok(decoded)
}
/// Encodes `data` with `bitcode`, appends a CRC32 of the encoded bytes, encrypts the framed
/// plaintext with `cipher`, and writes the ciphertext to `path` through `atomic_write`.
///
/// The second argument to `cipher.encrypt` is an empty slice (presumably associated data —
/// confirm against the `Cipher` trait).
///
/// # Errors
///
/// - `IndexPersistenceError::Serialization` if `cipher.encrypt` fails.
/// - Any error returned by `atomic_write`.
// NOTE(review): the reason for `dead_code` is not visible here; presumably no production
// caller exists yet. Confirm, and remove the attribute once the function is used.
#[allow(dead_code)]
pub fn save_encoded_encrypted<T: Encode>(
    data: &T,
    path: &Path,
    cipher: &Arc<dyn Cipher>,
) -> Result<()> {
    let mut framed = bitcode::encode(data);
    let crc = {
        let mut crc_state = Hasher::new();
        crc_state.update(&framed);
        crc_state.finalize()
    };
    framed.extend_from_slice(&crc.to_le_bytes());

    let ciphertext = match cipher.encrypt(&framed, &[]) {
        Ok(sealed) => sealed,
        Err(e) => {
            return Err(IndexPersistenceError::Serialization(format!(
                "Encryption failed: {e}"
            )))
        }
    };
    atomic_write(path, &ciphertext)
}
/// Loads a value written by [`save_encoded_encrypted`]. It decrypts the file with `cipher`,
/// verifies the 4-byte little-endian CRC32 trailer of the plaintext, then decodes the payload
/// with `bitcode`.
///
/// `max_size` bounds the number of bytes read from disk, i.e. the encrypted file, not the
/// plaintext. `context` is a human-readable label used as the prefix of the size-limit
/// error message.
///
/// # Errors
///
/// - `IndexPersistenceError::SizeLimitExceeded` if the encrypted file is larger than
///   `max_size` bytes, either according to its metadata or because it grew past the limit
///   between the metadata check and the read.
/// - `IndexPersistenceError::Corrupted` if `cipher.decrypt` fails, if the plaintext is
///   shorter than the 4-byte trailer, or if the stored checksum does not match.
/// - I/O errors (metadata, open, read) and `bitcode` decode errors, converted via `?`.
// NOTE(review): the reason for `dead_code` is not visible here; presumably no production
// caller exists yet. Confirm, and remove the attribute once the function is used.
#[allow(dead_code)]
pub fn load_encoded_encrypted<T: for<'a> Decode<'a>>(
    path: &Path,
    max_size: u64,
    context: &str,
    cipher: &Arc<dyn Cipher>,
) -> Result<T> {
    use std::io::Read;

    /// Length of the CRC32 trailer appended by `save_encoded_encrypted`.
    const CRC_LEN: usize = 4;

    let metadata = fs::metadata(path)?;
    if metadata.len() > max_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "{} file size {} exceeds limit {}",
                context,
                metadata.len(),
                max_size
            ),
        });
    }

    // The metadata check is only advisory: the file can grow or be replaced before it is
    // read. Read at most `max_size + 1` bytes so an oversized file is still detected without
    // ever buffering more than the limit allows. `saturating_add` keeps `u64::MAX` valid.
    let mut encrypted = Vec::with_capacity(usize::try_from(metadata.len()).unwrap_or(0));
    fs::File::open(path)?
        .take(max_size.saturating_add(1))
        .read_to_end(&mut encrypted)?;
    // usize -> u64 is lossless on all supported targets.
    if encrypted.len() as u64 > max_size {
        return Err(IndexPersistenceError::SizeLimitExceeded {
            message: format!(
                "{} file size exceeds limit {} (file changed after size check)",
                context, max_size
            ),
        });
    }

    let plaintext =
        cipher
            .decrypt(&encrypted, &[])
            .map_err(|e| IndexPersistenceError::Corrupted {
                path: path.to_path_buf(),
                source: format!("Decryption failed: {e}").into(),
            })?;

    if plaintext.len() < CRC_LEN {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: "Decrypted data too small to contain CRC32 checksum".into(),
        });
    }
    let (data, checksum_bytes) = plaintext.split_at(plaintext.len() - CRC_LEN);
    // `split_at` guarantees `checksum_bytes` is exactly CRC_LEN bytes long.
    let mut trailer = [0u8; CRC_LEN];
    trailer.copy_from_slice(checksum_bytes);
    let stored_checksum = u32::from_le_bytes(trailer);

    let mut hasher = Hasher::new();
    hasher.update(data);
    let computed_checksum = hasher.finalize();
    if computed_checksum != stored_checksum {
        return Err(IndexPersistenceError::Corrupted {
            path: path.to_path_buf(),
            source: format!(
                "CRC32 checksum mismatch: expected {}, got {}",
                stored_checksum, computed_checksum
            )
            .into(),
        });
    }

    let decoded: T = bitcode::decode(data)?;
    Ok(decoded)
}
#[cfg(test)]
mod tests {
    //! Round-trip, corruption, and size-limit tests for the CRC-framed and encrypted
    //! `bitcode` persistence helpers.
    use super::*;
    use std::fs;
    use tempfile::NamedTempFile;

    /// Builds an AES-256-GCM cipher whose key is all zeros except the first two bytes.
    fn cipher_with_key_prefix(b0: u8, b1: u8) -> Arc<dyn Cipher> {
        use crate::encryption::Aes256GcmCipher;
        use zeroize::Zeroizing;
        let mut key = Zeroizing::new([0u8; 32]);
        key[0] = b0;
        key[1] = b1;
        Arc::new(Aes256GcmCipher::new(&key))
    }

    fn test_cipher() -> Arc<dyn Cipher> {
        cipher_with_key_prefix(0xAB, 0xCD)
    }

    fn different_cipher() -> Arc<dyn Cipher> {
        cipher_with_key_prefix(0x12, 0x34)
    }

    #[test]
    fn test_save_load_round_trip() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let data = 42u64;
        save_encoded_with_crc(&data, path).unwrap();
        let loaded: u64 = load_encoded_with_crc(path, 1024, "Test").unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_checksum_mismatch() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let data = 42u64;
        save_encoded_with_crc(&data, path).unwrap();
        let mut bytes = fs::read(path).unwrap();
        // Flip a payload byte so the trailing CRC32 no longer matches.
        bytes[0] ^= 0xFF;
        fs::write(path, &bytes).unwrap();
        let result: Result<u64> = load_encoded_with_crc(path, 1024, "Test");
        assert!(matches!(result, Err(IndexPersistenceError::Corrupted { .. })));
    }

    #[test]
    fn test_size_limit_exceeded() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let data = vec![0u8; 100];
        save_encoded_with_crc(&data, path).unwrap();
        let result: Result<Vec<u8>> = load_encoded_with_crc(path, 10, "Test");
        assert!(matches!(
            result,
            Err(IndexPersistenceError::SizeLimitExceeded { .. })
        ));
    }

    #[test]
    fn test_size_limit_is_inclusive() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        save_encoded_with_crc(&42u64, path).unwrap();
        let len = fs::metadata(path).unwrap().len();

        // A file of exactly `max_size` bytes is accepted...
        let at_limit: u64 = load_encoded_with_crc(path, len, "Test").unwrap();
        assert_eq!(at_limit, 42);

        // ...and one byte less is rejected. `len` >= 4 because of the CRC trailer.
        let below_limit: Result<u64> = load_encoded_with_crc(path, len - 1, "Test");
        assert!(matches!(
            below_limit,
            Err(IndexPersistenceError::SizeLimitExceeded { .. })
        ));
    }

    #[test]
    fn test_unbounded_size_limit() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        save_encoded_with_crc(&42u64, path).unwrap();
        let loaded: u64 = load_encoded_with_crc(path, u64::MAX, "Test").unwrap();
        assert_eq!(loaded, 42);

        let cipher = test_cipher();
        save_encoded_encrypted(&42u64, path, &cipher).unwrap();
        let loaded: u64 = load_encoded_encrypted(path, u64::MAX, "Test", &cipher).unwrap();
        assert_eq!(loaded, 42);
    }

    #[test]
    fn test_empty_file_is_corrupted() {
        // `NamedTempFile::new` creates an empty file.
        let file = NamedTempFile::new().unwrap();
        let result: Result<u64> = load_encoded_with_crc(file.path(), 1024, "Test");
        assert!(matches!(result, Err(IndexPersistenceError::Corrupted { .. })));
    }

    #[test]
    fn test_file_too_small() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        fs::write(path, [1u8, 2, 3]).unwrap();
        let result: Result<u64> = load_encoded_with_crc(path, 1024, "Test");
        assert!(result.is_err());
        match result.unwrap_err() {
            IndexPersistenceError::Corrupted { source, .. } => {
                assert!(source.to_string().contains("File too small"));
            }
            _ => panic!("Expected corrupted error for small file"),
        }
    }

    #[test]
    fn test_encrypted_save_load_round_trip() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let cipher = test_cipher();
        let data = 42u64;
        save_encoded_encrypted(&data, path, &cipher).unwrap();
        let loaded: u64 = load_encoded_encrypted(path, 4096, "Test", &cipher).unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_encrypted_complex_data_round_trip() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let cipher = test_cipher();
        let data = vec![1u8, 2, 3, 4, 5, 100, 200, 255];
        save_encoded_encrypted(&data, path, &cipher).unwrap();
        let loaded: Vec<u8> = load_encoded_encrypted(path, 4096, "Test", &cipher).unwrap();
        assert_eq!(loaded, data);
    }

    #[test]
    fn test_encrypted_tampered_file_fails() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let cipher = test_cipher();
        let data = 42u64;
        save_encoded_encrypted(&data, path, &cipher).unwrap();
        let mut bytes = fs::read(path).unwrap();
        let mid = bytes.len() / 2;
        bytes[mid] ^= 0xFF;
        fs::write(path, &bytes).unwrap();
        let result: Result<u64> = load_encoded_encrypted(path, 4096, "Test", &cipher);
        assert!(
            matches!(result, Err(IndexPersistenceError::Corrupted { .. })),
            "Expected Corrupted error for tampered encrypted file"
        );
    }

    #[test]
    fn test_encrypted_wrong_key_fails() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let cipher1 = test_cipher();
        let cipher2 = different_cipher();
        let data = 42u64;
        save_encoded_encrypted(&data, path, &cipher1).unwrap();
        let result: Result<u64> = load_encoded_encrypted(path, 4096, "Test", &cipher2);
        assert!(
            matches!(result, Err(IndexPersistenceError::Corrupted { .. })),
            "Expected Corrupted error when using wrong key"
        );
    }

    #[test]
    fn test_encrypted_size_limit_exceeded() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let cipher = test_cipher();
        let data = vec![0u8; 100];
        save_encoded_encrypted(&data, path, &cipher).unwrap();
        let result: Result<Vec<u8>> = load_encoded_encrypted(path, 10, "Test", &cipher);
        assert!(matches!(
            result,
            Err(IndexPersistenceError::SizeLimitExceeded { .. })
        ));
    }

    #[test]
    fn test_encrypted_file_not_readable_as_unencrypted() {
        let file = NamedTempFile::new().unwrap();
        let path = file.path();
        let cipher = test_cipher();
        let data = 42u64;
        save_encoded_encrypted(&data, path, &cipher).unwrap();
        let result: Result<u64> = load_encoded_with_crc(path, 4096, "Test");
        assert!(result.is_err());
    }
}