use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use tokio::fs::{File, OpenOptions};
use tokio::io::{AsyncWrite, AsyncWriteExt};
/// Wire format used for the serde payload.
///
/// `Bincode` (via `oxicode`) is the compact, fast default; `Json` is
/// human-readable; `MessagePack` and `Cbor` are compact self-describing
/// binary formats.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SerializationFormat {
    Json,
    MessagePack,
    Cbor,
    Bincode,
}
/// Compression codec applied to the serialized bytes.
///
/// `None` stores the payload as-is; the remaining variants trade
/// compression ratio against speed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionAlgorithm {
    None,
    Lz4,
    Zstd,
    Deflate,
}
/// Tuning knobs for [`BinarySerializer`]: wire format, compression codec,
/// codec level, and whether an integrity header is prepended.
#[derive(Debug, Clone)]
pub struct SerializationConfig {
    /// Serde wire format for the payload.
    pub format: SerializationFormat,
    /// Compression applied to the serialized bytes (before checksumming).
    pub compression: CompressionAlgorithm,
    /// Codec-specific level; used by Zstd and Deflate, ignored by Lz4.
    pub compression_level: i32,
    /// When true, a CRC32 + length header is prepended to the output.
    pub enable_checksums: bool,
}
impl Default for SerializationConfig {
fn default() -> Self {
Self {
format: SerializationFormat::Bincode,
compression: CompressionAlgorithm::Lz4,
compression_level: 6,
enable_checksums: true,
}
}
}
/// Serializes values to a framed binary representation:
/// serde payload -> optional compression -> optional CRC32+length header.
pub struct BinarySerializer {
    // Immutable settings chosen at construction time.
    config: SerializationConfig,
}
impl BinarySerializer {
    /// Creates a serializer that applies the given configuration.
    pub fn new(config: SerializationConfig) -> Self {
        Self { config }
    }

    /// Serializes `data` into framed bytes:
    /// serde-encode, then compress, then (when checksums are enabled)
    /// prepend an 8-byte header of `CRC32 (le u32) + payload length (le u32)`.
    ///
    /// The checksum and length cover the (possibly compressed) payload, so
    /// corruption is detected before any decompression is attempted.
    ///
    /// # Errors
    /// Fails on encode/compress errors, or when the payload exceeds
    /// `u32::MAX` bytes (the header's length field is a `u32`).
    pub fn serialize<T: Serialize>(&self, data: &T) -> Result<Vec<u8>> {
        let mut binary_data = match self.config.format {
            SerializationFormat::Json => serde_json::to_vec(data)?,
            SerializationFormat::MessagePack => rmp_serde::to_vec(data)?,
            SerializationFormat::Cbor => {
                let mut buf = Vec::new();
                ciborium::ser::into_writer(data, &mut buf)
                    .map_err(|e| anyhow::anyhow!("CBOR serialization failed: {e}"))?;
                buf
            }
            SerializationFormat::Bincode => {
                oxicode::serde::encode_to_vec(data, oxicode::config::standard())?
            }
        };
        binary_data = match self.config.compression {
            CompressionAlgorithm::None => binary_data,
            CompressionAlgorithm::Lz4 => self.compress_lz4(&binary_data)?,
            CompressionAlgorithm::Zstd => self.compress_zstd(&binary_data)?,
            CompressionAlgorithm::Deflate => self.compress_deflate(&binary_data)?,
        };
        if self.config.enable_checksums {
            // The header stores the payload length as a u32; reject oversized
            // payloads instead of silently truncating the length field (the
            // previous `as u32` cast would corrupt the frame for >= 4 GiB).
            let length = u32::try_from(binary_data.len())
                .map_err(|_| anyhow::anyhow!("Payload too large for u32 length header"))?;
            let checksum = crc32fast::hash(&binary_data);
            let mut result = Vec::with_capacity(binary_data.len() + 8);
            result.extend_from_slice(&checksum.to_le_bytes());
            result.extend_from_slice(&length.to_le_bytes());
            result.extend_from_slice(&binary_data);
            Ok(result)
        } else {
            Ok(binary_data)
        }
    }

    /// Reverses [`Self::serialize`]: validate the header (when checksums are
    /// enabled), decompress, then decode with the configured format.
    ///
    /// # Errors
    /// Fails on truncated input, checksum mismatch, or decompress/decode
    /// errors. The config must match the one used to serialize.
    pub fn deserialize<T: for<'de> Deserialize<'de>>(&self, data: &[u8]) -> Result<T> {
        let (binary_data, _expected_checksum) = if self.config.enable_checksums {
            if data.len() < 8 {
                return Err(anyhow::anyhow!("Data too short for checksum"));
            }
            let checksum = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
            let length = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
            if data.len() < 8 + length {
                return Err(anyhow::anyhow!("Data length mismatch"));
            }
            // Any trailing bytes beyond the declared length are ignored.
            let binary_data = &data[8..8 + length];
            let computed_checksum = crc32fast::hash(binary_data);
            if checksum != computed_checksum {
                return Err(anyhow::anyhow!("Checksum validation failed"));
            }
            (binary_data, Some(checksum))
        } else {
            (data, None)
        };
        let decompressed_data = match self.config.compression {
            CompressionAlgorithm::None => binary_data.to_vec(),
            CompressionAlgorithm::Lz4 => self.decompress_lz4(binary_data)?,
            CompressionAlgorithm::Zstd => self.decompress_zstd(binary_data)?,
            CompressionAlgorithm::Deflate => self.decompress_deflate(binary_data)?,
        };
        let result = match self.config.format {
            SerializationFormat::Json => serde_json::from_slice(&decompressed_data)?,
            SerializationFormat::MessagePack => rmp_serde::from_slice(&decompressed_data)?,
            SerializationFormat::Cbor => ciborium::de::from_reader(&decompressed_data[..])
                .map_err(|e| anyhow::anyhow!("CBOR deserialization failed: {e}"))?,
            SerializationFormat::Bincode => {
                oxicode::serde::decode_from_slice(&decompressed_data, oxicode::config::standard())
                    .map(|(v, _)| v)?
            }
        };
        Ok(result)
    }

    /// LZ4 compression; the configured `compression_level` is not used here.
    fn compress_lz4(&self, data: &[u8]) -> Result<Vec<u8>> {
        oxiarc_lz4::compress(data).map_err(|e| anyhow::anyhow!("LZ4 compression failed: {}", e))
    }

    /// LZ4 decompression with a 100 MiB output cap as a decompression-bomb guard.
    fn decompress_lz4(&self, data: &[u8]) -> Result<Vec<u8>> {
        oxiarc_lz4::decompress(data, 100 * 1024 * 1024)
            .map_err(|e| anyhow::anyhow!("LZ4 decompression failed: {}", e))
    }

    /// Zstd compression at the configured level.
    fn compress_zstd(&self, data: &[u8]) -> Result<Vec<u8>> {
        oxiarc_zstd::encode_all(data, self.config.compression_level)
            .map_err(|e| anyhow::anyhow!("Zstd compression failed: {}", e))
    }

    /// Zstd decompression.
    fn decompress_zstd(&self, data: &[u8]) -> Result<Vec<u8>> {
        oxiarc_zstd::decode_all(data)
            .map_err(|e| anyhow::anyhow!("Zstd decompression failed: {}", e))
    }

    /// Zlib-wrapped deflate compression at the configured level.
    fn compress_deflate(&self, data: &[u8]) -> Result<Vec<u8>> {
        use flate2::write::ZlibEncoder;
        use flate2::Compression;
        use std::io::Write;
        // flate2 levels are 0..=9; clamp so an out-of-range (e.g. negative)
        // configured level cannot wrap into an invalid value via the cast.
        let level = self.config.compression_level.clamp(0, 9) as u32;
        let mut encoder = ZlibEncoder::new(Vec::new(), Compression::new(level));
        encoder.write_all(data)?;
        encoder
            .finish()
            .map_err(|e| anyhow::anyhow!("Deflate compression failed: {}", e))
    }

    /// Zlib-wrapped deflate decompression.
    fn decompress_deflate(&self, data: &[u8]) -> Result<Vec<u8>> {
        use flate2::read::ZlibDecoder;
        use std::io::Read;
        let mut decoder = ZlibDecoder::new(data);
        let mut result = Vec::new();
        decoder
            .read_to_end(&mut result)
            .map_err(|e| anyhow::anyhow!("Deflate decompression failed: {}", e))?;
        Ok(result)
    }
}
/// Writes to a temporary file and atomically renames it over the target
/// path on commit, so readers never observe a half-written file.
pub struct AtomicFileWriter {
    // Path of the temp file that receives all writes.
    temp_path: PathBuf,
    // Destination that becomes visible only after `commit`.
    final_path: PathBuf,
    // `Some` while writable; taken by `commit` so Drop knows it ran.
    file: Option<File>,
}
impl AtomicFileWriter {
pub async fn new<P: AsRef<Path>>(path: P) -> Result<Self> {
let final_path = path.as_ref().to_path_buf();
let temp_path = final_path.with_extension("tmp");
let file = OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&temp_path)
.await?;
Ok(Self {
temp_path,
final_path,
file: Some(file),
})
}
pub async fn write(&mut self, data: &[u8]) -> Result<()> {
if let Some(ref mut file) = self.file {
file.write_all(data).await?;
file.sync_all().await?;
}
Ok(())
}
pub async fn commit(mut self) -> Result<()> {
if let Some(file) = self.file.take() {
file.sync_all().await?;
drop(file);
tokio::fs::rename(&self.temp_path, &self.final_path).await?;
}
Ok(())
}
pub async fn abort(self) -> Result<()> {
if self.temp_path.exists() {
tokio::fs::remove_file(&self.temp_path).await?;
}
Ok(())
}
}
impl Drop for AtomicFileWriter {
    /// Best-effort cleanup of an uncommitted temp file. Errors are
    /// deliberately ignored: a destructor has no way to report them.
    fn drop(&mut self) {
        match self.temp_path.exists() {
            true => {
                let _ = std::fs::remove_file(&self.temp_path);
            }
            false => {}
        }
    }
}
impl AsyncWrite for AtomicFileWriter {
    // All three poll methods delegate to the inner temp-file handle.
    // `AtomicFileWriter` is Unpin (PathBuf and Option<File> are Unpin), so
    // `get_mut` is available and the inner `File` can be re-pinned per poll.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.get_mut().file.as_mut() {
            Some(file) => std::pin::Pin::new(file).poll_write(cx, buf),
            // Writing after the handle was taken (e.g. post-commit) is an error.
            None => std::task::Poll::Ready(Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "File not open",
            ))),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // Flushing a missing file is treated as a successful no-op.
        match self.get_mut().file.as_mut() {
            Some(file) => std::pin::Pin::new(file).poll_flush(cx),
            None => std::task::Poll::Ready(Ok(())),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        // Shutdown on a missing file is likewise a successful no-op.
        match self.get_mut().file.as_mut() {
            Some(file) => std::pin::Pin::new(file).poll_shutdown(cx),
            None => std::task::Poll::Ready(Ok(())),
        }
    }
}
/// Checks files for corruption; optionally validates the CRC32 frame header
/// written by `BinarySerializer` (deep scan).
pub struct CorruptionDetector {
    // When true, content checksums are verified in addition to metadata checks.
    enable_deep_scan: bool,
}
impl CorruptionDetector {
    /// `enable_deep_scan` additionally validates the CRC32 header (when
    /// present) instead of only checking file metadata.
    pub fn new(enable_deep_scan: bool) -> Self {
        Self { enable_deep_scan }
    }

    /// Returns `Ok(false)` for missing or size-inconsistent files, `Ok(true)`
    /// otherwise. With deep scan enabled, also verifies the checksum header.
    pub async fn verify_file_integrity<P: AsRef<Path>>(&self, path: P) -> Result<bool> {
        let path = path.as_ref();
        if !path.exists() {
            return Ok(false);
        }
        let metadata = tokio::fs::metadata(path).await?;
        // An empty file carries no payload to corrupt; treat it as intact.
        if metadata.len() == 0 {
            return Ok(true);
        }
        let data = tokio::fs::read(path).await?;
        // A size change between stat and read suggests concurrent truncation.
        if (data.len() as u64) != metadata.len() {
            return Ok(false);
        }
        if self.enable_deep_scan {
            self.verify_content_integrity(&data).await
        } else {
            Ok(true)
        }
    }

    /// Validates the `CRC32 (le u32) + length (le u32)` header produced by
    /// `BinarySerializer::serialize`. Data too short to carry a complete
    /// frame is assumed to be checksum-less and passes.
    async fn verify_content_integrity(&self, data: &[u8]) -> Result<bool> {
        if data.len() >= 8 {
            let checksum = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
            let length = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
            if data.len() >= 8 + length {
                let content = &data[8..8 + length];
                let computed = crc32fast::hash(content);
                return Ok(checksum == computed);
            }
        }
        Ok(true)
    }

    /// Attempts to repair a file whose framed content is followed by trailing
    /// garbage (e.g. a torn append): keeps the original as `*.backup` and
    /// rewrites the file truncated to the framed length.
    ///
    /// Returns `Ok(true)` only when the retained payload's checksum matches
    /// its header — truncation cannot fix corruption inside the payload, and
    /// previously this reported success without checking, so a still-corrupt
    /// file could be declared "repaired".
    pub async fn attempt_repair<P: AsRef<Path>>(&self, path: P) -> Result<bool> {
        let path = path.as_ref();
        if !path.exists() {
            return Ok(false);
        }
        let data = tokio::fs::read(path).await?;
        if data.len() >= 8 {
            let checksum = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
            let length = u32::from_le_bytes([data[4], data[5], data[6], data[7]]) as usize;
            if length > 0 && data.len() >= 8 + length {
                // Verify the payload before rewriting anything.
                let content = &data[8..8 + length];
                if crc32fast::hash(content) != checksum {
                    return Ok(false);
                }
                let valid_data = &data[0..8 + length];
                let backup_path = path.with_extension("backup");
                tokio::fs::rename(path, &backup_path).await?;
                tokio::fs::write(path, valid_data).await?;
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Convenience alias for [`Self::verify_file_integrity`].
    pub async fn validate_file<P: AsRef<Path>>(&self, path: P) -> Result<bool> {
        self.verify_file_integrity(path).await
    }
}
/// Semantic version of the on-disk data schema, persisted alongside payloads
/// (see [`VersionedData`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct SchemaVersion {
    pub major: u32,
    pub minor: u32,
    pub patch: u32,
}
impl SchemaVersion {
    /// Schema version written by the current build.
    pub const CURRENT: SchemaVersion = SchemaVersion {
        major: 1,
        minor: 0,
        patch: 0,
    };

    /// A reader at `self` can handle data written at `other` when the major
    /// versions match and the reader's minor version is at least the
    /// writer's. Patch differences never affect compatibility.
    pub fn is_compatible(&self, other: &SchemaVersion) -> bool {
        if self.major != other.major {
            return false;
        }
        self.minor >= other.minor
    }
}
/// A payload paired with the schema version it was written under, so readers
/// can reject incompatible data before decoding further.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionedData<T> {
    /// Schema version recorded at write time.
    pub version: SchemaVersion,
    /// The wrapped payload.
    pub data: T,
}
impl<T> VersionedData<T> {
pub fn new(data: T) -> Self {
Self {
version: SchemaVersion::CURRENT,
data,
}
}
pub fn validate_compatibility(&self) -> Result<()> {
if !SchemaVersion::CURRENT.is_compatible(&self.version) {
return Err(anyhow::anyhow!(
"Incompatible schema version: {:?}, current: {:?}",
self.version,
SchemaVersion::CURRENT
));
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Round-trips a value through the default pipeline
    // (Bincode + LZ4 + checksums) and expects an exact match.
    #[tokio::test]
    async fn test_binary_serialization() {
        let config = SerializationConfig::default();
        let serializer = BinarySerializer::new(config);
        let data = vec!["hello".to_string(), "world".to_string()];
        let serialized = serializer.serialize(&data).unwrap();
        let deserialized: Vec<String> = serializer.deserialize(&serialized).unwrap();
        assert_eq!(data, deserialized);
    }

    // Data written through the atomic writer must be readable at the final
    // path after commit.
    #[tokio::test]
    async fn test_atomic_file_writer() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.dat");
        let mut writer = AtomicFileWriter::new(&file_path).await.unwrap();
        // write_all comes from AsyncWriteExt via the AsyncWrite impl.
        writer.write_all(b"test data").await.unwrap();
        writer.commit().await.unwrap();
        let content = tokio::fs::read(&file_path).await.unwrap();
        assert_eq!(content, b"test data");
    }

    // Deep scan should accept an intact framed file and reject one with a
    // flipped payload byte.
    #[tokio::test]
    async fn test_corruption_detection() {
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.dat");
        let config = SerializationConfig::default();
        let serializer = BinarySerializer::new(config);
        let detector = CorruptionDetector::new(true);
        let data = vec!["test".to_string()];
        let serialized = serializer.serialize(&data).unwrap();
        tokio::fs::write(&file_path, &serialized).await.unwrap();
        assert!(detector.validate_file(&file_path).await.unwrap());
        let mut corrupted = serialized.clone();
        // Flip a byte in the middle of the frame (inside the payload for any
        // non-trivial serialization), which must break the CRC check.
        let corrupt_idx = corrupted.len() / 2; corrupted[corrupt_idx] = !corrupted[corrupt_idx]; tokio::fs::write(&file_path, &corrupted).await.unwrap();
        assert!(!detector.validate_file(&file_path).await.unwrap());
    }

    // Compatibility is: same major, reader minor >= writer minor.
    #[test]
    fn test_schema_version_compatibility() {
        let v1_0_0 = SchemaVersion {
            major: 1,
            minor: 0,
            patch: 0,
        };
        let v1_1_0 = SchemaVersion {
            major: 1,
            minor: 1,
            patch: 0,
        };
        let v2_0_0 = SchemaVersion {
            major: 2,
            minor: 0,
            patch: 0,
        };
        assert!(v1_1_0.is_compatible(&v1_0_0));
        assert!(!v1_0_0.is_compatible(&v1_1_0));
        assert!(!v2_0_0.is_compatible(&v1_0_0));
    }
}