use crate::error::{DbxError, DbxResult};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
/// Callback that maps a value's raw bytes to its serialized byte form.
/// Both sides are opaque `&[u8]` buffers; the callback decides the format.
pub type SerializeFn = Arc<dyn Fn(&[u8]) -> DbxResult<Vec<u8>> + Send + Sync>;
/// Inverse of [`SerializeFn`]: maps serialized bytes back to raw bytes.
pub type DeserializeFn = Arc<dyn Fn(&[u8]) -> DbxResult<Vec<u8>> + Send + Sync>;
/// Thread-safe registry of per-type-name serialize/deserialize callbacks,
/// with zstd compression and SHA-256 checksum helpers.
pub struct SerializationRegistry {
    // NOTE(review): the `Arc` around each `RwLock` looks redundant — sharing
    // the registry itself via `Arc` would suffice unless the individual maps
    // are cloned out somewhere not visible here; confirm before simplifying.
    serializers: Arc<RwLock<HashMap<String, SerializeFn>>>,
    deserializers: Arc<RwLock<HashMap<String, DeserializeFn>>>,
}
impl SerializationRegistry {
    /// Creates an empty registry with no callbacks registered.
    pub fn new() -> Self {
        Self {
            serializers: Arc::new(RwLock::new(HashMap::new())),
            deserializers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers (or silently replaces) the serializer for `type_name`.
    pub fn register_serializer(&self, type_name: String, serializer: SerializeFn) {
        let mut map = self.serializers.write().unwrap();
        map.insert(type_name, serializer);
    }

    /// Registers (or silently replaces) the deserializer for `type_name`.
    pub fn register_deserializer(&self, type_name: String, deserializer: DeserializeFn) {
        let mut map = self.deserializers.write().unwrap();
        map.insert(type_name, deserializer);
    }

    /// Runs the registered serializer for `type_name` on `data`.
    ///
    /// # Errors
    /// `DbxError::Serialization` if no serializer is registered for
    /// `type_name`, or whatever error the callback itself returns.
    pub fn serialize(&self, type_name: &str, data: &[u8]) -> DbxResult<Vec<u8>> {
        let map = self.serializers.read().unwrap();
        match map.get(type_name) {
            Some(func) => func(data),
            None => Err(DbxError::Serialization(format!(
                "No serializer registered for type '{}'",
                type_name
            ))),
        }
    }

    /// Runs the registered deserializer for `type_name` on `data`.
    ///
    /// # Errors
    /// `DbxError::Serialization` if no deserializer is registered for
    /// `type_name`, or whatever error the callback itself returns.
    pub fn deserialize(&self, type_name: &str, data: &[u8]) -> DbxResult<Vec<u8>> {
        let map = self.deserializers.read().unwrap();
        match map.get(type_name) {
            Some(func) => func(data),
            None => Err(DbxError::Serialization(format!(
                "No deserializer registered for type '{}'",
                type_name
            ))),
        }
    }

    /// Lists every type name with a registered serializer.
    /// (Types that only have a deserializer are not included.)
    pub fn registered_types(&self) -> Vec<String> {
        let map = self.serializers.read().unwrap();
        map.keys().cloned().collect()
    }

    /// Compresses `data` with zstd at compression `level`.
    ///
    /// # Errors
    /// `DbxError::Serialization` wrapping the underlying zstd error.
    pub fn compress(&self, data: &[u8], level: i32) -> DbxResult<Vec<u8>> {
        match zstd::encode_all(data, level) {
            Ok(out) => Ok(out),
            Err(e) => Err(DbxError::Serialization(format!("Compression failed: {}", e))),
        }
    }

    /// Decompresses zstd-compressed `data`.
    ///
    /// # Errors
    /// `DbxError::Serialization` wrapping the underlying zstd error.
    pub fn decompress(&self, data: &[u8]) -> DbxResult<Vec<u8>> {
        match zstd::decode_all(data) {
            Ok(out) => Ok(out),
            Err(e) => Err(DbxError::Serialization(format!("Decompression failed: {}", e))),
        }
    }

    /// Returns the 32-byte SHA-256 digest of `data`.
    pub fn checksum(&self, data: &[u8]) -> Vec<u8> {
        Sha256::digest(data).to_vec()
    }

    /// Checks whether `data` hashes to `expected_checksum`.
    ///
    /// NOTE(review): the comparison is not constant-time — fine for integrity
    /// checks, not suitable for authenticating untrusted input.
    pub fn verify_checksum(&self, data: &[u8], expected_checksum: &[u8]) -> bool {
        self.checksum(data).as_slice() == expected_checksum
    }
}
impl Default for SerializationRegistry {
fn default() -> Self {
Self::new()
}
}
/// Two-level cache: a byte-budgeted in-memory L1 map with spill-over to an
/// on-disk L2 directory of `<key>.bin` files.
pub struct TwoLevelCache {
    // In-memory level-1 store, keyed by caller-supplied string keys.
    l1_cache: Arc<RwLock<HashMap<String, Vec<u8>>>>,
    // Budget for L1, measured as the sum of value lengths in bytes.
    l1_max_size: usize,
    // Running total of value bytes currently resident in L1.
    l1_current_size: Arc<RwLock<usize>>,
    // Directory that holds level-2 (disk) entries.
    l2_cache_dir: PathBuf,
}
impl TwoLevelCache {
    /// Creates a cache whose in-memory L1 holds at most `l1_max_size` bytes
    /// of values, spilling larger entries to files under `l2_cache_dir`.
    pub fn new(l1_max_size: usize, l2_cache_dir: PathBuf) -> Self {
        Self {
            l1_cache: Arc::new(RwLock::new(HashMap::new())),
            l1_max_size,
            l1_current_size: Arc::new(RwLock::new(0)),
            l2_cache_dir,
        }
    }

    /// Stores `value` under `key`: in L1 when it fits the byte budget,
    /// otherwise in an on-disk L2 file.
    ///
    /// NOTE(review): `key` is used verbatim as a file name for L2 entries;
    /// a key containing path separators could escape the cache directory —
    /// confirm callers only pass sanitized keys.
    ///
    /// # Errors
    /// Propagates I/O errors from writing the L2 file.
    pub fn put(&self, key: String, value: Vec<u8>) -> DbxResult<()> {
        let value_size = value.len();
        let mut current_size = self.l1_current_size.write().unwrap();
        let mut l1 = self.l1_cache.write().unwrap();
        // Budget check accounts for replacing an existing entry, and the
        // counter subtracts the replaced value's length — previously repeated
        // puts of the same key double-counted and inflated l1_current_size.
        let replaced = l1.get(&key).map_or(0, |old| old.len());
        if *current_size - replaced + value_size <= self.l1_max_size {
            if let Some(old) = l1.insert(key, value) {
                *current_size -= old.len();
            }
            *current_size += value_size;
        } else {
            // Evict any stale L1 copy of this key before spilling to L2 —
            // otherwise get() would keep returning the outdated L1 value.
            if let Some(old) = l1.remove(&key) {
                *current_size -= old.len();
            }
            drop(l1);
            drop(current_size);
            self.put_l2(&key, &value)?;
        }
        Ok(())
    }

    /// Looks up `key`, checking L1 first and falling back to L2 on disk.
    /// Returns `Ok(None)` when the key is in neither level.
    pub fn get(&self, key: &str) -> DbxResult<Option<Vec<u8>>> {
        if let Some(value) = self.l1_cache.read().unwrap().get(key) {
            return Ok(Some(value.clone()));
        }
        self.get_l2(key)
    }

    /// Writes `value` to `<l2_cache_dir>/<key>.bin`, creating the cache
    /// directory on first use.
    fn put_l2(&self, key: &str, value: &[u8]) -> DbxResult<()> {
        fs::create_dir_all(&self.l2_cache_dir)?;
        let file_path = self.l2_cache_dir.join(format!("{}.bin", key));
        fs::write(file_path, value)?;
        Ok(())
    }

    /// Reads `<l2_cache_dir>/<key>.bin`, returning `None` when absent.
    fn get_l2(&self, key: &str) -> DbxResult<Option<Vec<u8>>> {
        let file_path = self.l2_cache_dir.join(format!("{}.bin", key));
        if !file_path.exists() {
            return Ok(None);
        }
        Ok(Some(fs::read(file_path)?))
    }

    /// Empties both levels: clears L1, resets the byte counter, and removes
    /// the whole L2 directory if it exists.
    pub fn clear(&self) -> DbxResult<()> {
        self.l1_cache.write().unwrap().clear();
        *self.l1_current_size.write().unwrap() = 0;
        if self.l2_cache_dir.exists() {
            fs::remove_dir_all(&self.l2_cache_dir)?;
        }
        Ok(())
    }

    /// Total bytes of values currently resident in L1.
    pub fn l1_size(&self) -> usize {
        *self.l1_current_size.read().unwrap()
    }

    /// Number of entries currently resident in L1.
    pub fn l1_count(&self) -> usize {
        self.l1_cache.read().unwrap().len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips bytes through identity (de)serializers and checks that
    /// the registered type is listed.
    #[test]
    fn test_serialization_registry() {
        let registry = SerializationRegistry::new();

        let ser: SerializeFn = Arc::new(|bytes| Ok(bytes.to_vec()));
        registry.register_serializer("test_type".to_string(), ser);
        let de: DeserializeFn = Arc::new(|bytes| Ok(bytes.to_vec()));
        registry.register_deserializer("test_type".to_string(), de);

        let payload = b"hello world";
        let encoded = registry.serialize("test_type", payload).unwrap();
        assert_eq!(encoded, payload);
        let decoded = registry.deserialize("test_type", &encoded).unwrap();
        assert_eq!(decoded, payload);

        let types = registry.registered_types();
        assert!(types.contains(&"test_type".to_string()));
    }

    /// A small value fits the L1 budget and is served from memory.
    #[test]
    fn test_two_level_cache_l1() {
        let cache = TwoLevelCache::new(1024, PathBuf::from("target/test_cache_l1"));
        let key = "test_key".to_string();
        let value = b"test_value".to_vec();

        cache.put(key.clone(), value.clone()).unwrap();

        assert_eq!(cache.get(&key).unwrap(), Some(value));
        assert!(cache.l1_size() > 0);
        assert_eq!(cache.l1_count(), 1);
        let _ = cache.clear();
    }

    /// A value larger than the L1 budget spills to L2 and is still found.
    #[test]
    fn test_two_level_cache_l2() {
        let cache = TwoLevelCache::new(10, PathBuf::from("target/test_cache_l2"));
        let key = "large_key".to_string();
        let big = vec![0u8; 100];

        cache.put(key.clone(), big.clone()).unwrap();

        assert_eq!(cache.get(&key).unwrap(), Some(big));
        let _ = cache.clear();
    }

    /// clear() empties L1, resets its counters, and drops L2 entries.
    #[test]
    fn test_two_level_cache_clear() {
        let cache = TwoLevelCache::new(1024, PathBuf::from("target/test_cache_clear"));
        cache.put("key1".to_string(), b"value1".to_vec()).unwrap();
        cache.put("key2".to_string(), b"value2".to_vec()).unwrap();

        cache.clear().unwrap();

        assert_eq!(cache.l1_size(), 0);
        assert_eq!(cache.l1_count(), 0);
        assert_eq!(cache.get("key1").unwrap(), None);
    }
}