use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use tracing::{debug, info};
/// A pluggable byte-to-byte codec that can be registered with [`SerializerRegistry`].
///
/// Implementors must supply [`serialize`](CustomSerializer::serialize),
/// [`deserialize`](CustomSerializer::deserialize) and
/// [`format_name`](CustomSerializer::format_name); every other method has a permissive default.
pub trait CustomSerializer: Send + Sync {
    /// Encodes `data` into this format.
    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>>;

    /// Decodes `data` that was produced by [`CustomSerializer::serialize`].
    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>>;

    /// Short identifier of the format (for example `"bson"`).
    fn format_name(&self) -> &str;

    /// Version string of the format. Defaults to `"1.0.0"`.
    fn format_version(&self) -> &str {
        "1.0.0"
    }

    /// Usage statistics reported by the implementation. The default is an all-zero value.
    fn stats(&self) -> SerializerStats {
        Default::default()
    }

    /// Whether the implementation claims zero-copy support. Defaults to `false`.
    fn supports_zero_copy(&self) -> bool {
        false
    }

    /// Leading bytes that identify this format, consulted by
    /// [`SerializerRegistry::detect_format`]. Returning `None` (the default) means the
    /// format is never auto-detected.
    fn magic_bytes(&self) -> Option<&[u8]> {
        None
    }

    /// Checks `_data` against `_schema`. The default accepts everything.
    fn validate_schema(&self, _schema: &[u8], _data: &[u8]) -> Result<bool> {
        Ok(true)
    }
}
/// Counters an individual [`CustomSerializer`] may report through
/// [`CustomSerializer::stats`].
///
/// Timing fields are in milliseconds, as their `_ms` suffix indicates. How the counters are
/// populated is up to each implementation. NOTE(review): nothing in this file fills them in.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct SerializerStats {
    /// Bytes handled by serialization (exact meaning defined by the implementor).
    pub bytes_serialized: u64,
    /// Bytes handled by deserialization (exact meaning defined by the implementor).
    pub bytes_deserialized: u64,
    /// Number of serialization calls.
    pub serialization_count: u64,
    /// Number of deserialization calls.
    pub deserialization_count: u64,
    /// Mean serialization time in milliseconds.
    pub avg_serialization_time_ms: f64,
    /// Mean deserialization time in milliseconds.
    pub avg_deserialization_time_ms: f64,
    /// Number of failed operations.
    pub error_count: u64,
}
pub struct SerializerRegistry {
serializers: Arc<RwLock<HashMap<String, Box<dyn CustomSerializer>>>>,
benchmarks: Arc<RwLock<HashMap<String, SerializerBenchmark>>>,
}
impl SerializerRegistry {
pub fn new() -> Self {
Self {
serializers: Arc::new(RwLock::new(HashMap::new())),
benchmarks: Arc::new(RwLock::new(HashMap::new())),
}
}
pub async fn register(&self, name: &str, serializer: Box<dyn CustomSerializer>) -> Result<()> {
let mut serializers = self.serializers.write().await;
if serializers.contains_key(name) {
return Err(anyhow!("Serializer '{}' already registered", name));
}
serializers.insert(name.to_string(), serializer);
info!("Registered custom serializer: {}", name);
Ok(())
}
pub async fn unregister(&self, name: &str) -> Result<()> {
let mut serializers = self.serializers.write().await;
if serializers.remove(name).is_some() {
info!("Unregistered serializer: {}", name);
Ok(())
} else {
Err(anyhow!("Serializer '{}' not found", name))
}
}
pub async fn get(&self, name: &str) -> Result<String> {
let serializers = self.serializers.read().await;
if serializers.contains_key(name) {
Ok(name.to_string())
} else {
Err(anyhow!("Serializer '{}' not found", name))
}
}
pub async fn list(&self) -> Vec<String> {
let serializers = self.serializers.read().await;
serializers.keys().cloned().collect()
}
pub async fn serialize(&self, format: &str, data: &[u8]) -> Result<Vec<u8>> {
let serializers = self.serializers.read().await;
let serializer = serializers
.get(format)
.ok_or_else(|| anyhow!("Serializer '{}' not found", format))?;
let start = Instant::now();
let result = serializer.serialize(data)?;
let duration = start.elapsed();
drop(serializers);
self.update_benchmark(format, duration.as_secs_f64() * 1000.0, true)
.await;
Ok(result)
}
pub async fn deserialize(&self, format: &str, data: &[u8]) -> Result<Vec<u8>> {
let serializers = self.serializers.read().await;
let serializer = serializers
.get(format)
.ok_or_else(|| anyhow!("Serializer '{}' not found", format))?;
let start = Instant::now();
let result = serializer.deserialize(data)?;
let duration = start.elapsed();
drop(serializers);
self.update_benchmark(format, duration.as_secs_f64() * 1000.0, false)
.await;
Ok(result)
}
pub async fn detect_format(&self, data: &[u8]) -> Option<String> {
let serializers = self.serializers.read().await;
for (name, serializer) in serializers.iter() {
if let Some(magic) = serializer.magic_bytes() {
if data.len() >= magic.len() && &data[0..magic.len()] == magic {
return Some(name.clone());
}
}
}
None
}
pub async fn get_benchmark(&self, format: &str) -> Option<SerializerBenchmark> {
let benchmarks = self.benchmarks.read().await;
benchmarks.get(format).cloned()
}
pub async fn all_benchmarks(&self) -> HashMap<String, SerializerBenchmark> {
let benchmarks = self.benchmarks.read().await;
benchmarks.clone()
}
async fn update_benchmark(&self, format: &str, duration_ms: f64, is_serialization: bool) {
let mut benchmarks = self.benchmarks.write().await;
let benchmark = benchmarks
.entry(format.to_string())
.or_insert_with(SerializerBenchmark::default);
if is_serialization {
benchmark.serialization_times.push(duration_ms);
benchmark.serialization_count += 1;
} else {
benchmark.deserialization_times.push(duration_ms);
benchmark.deserialization_count += 1;
}
}
}
impl Default for SerializerRegistry {
fn default() -> Self {
Self::new()
}
}
/// Timing history for a single serialization format.
///
/// [`SerializerRegistry`] records samples here in milliseconds. The counters count recorded
/// calls. The `*_times` vectors hold the individual samples used for means and percentiles.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SerializerBenchmark {
    /// Number of recorded serialization calls.
    pub serialization_count: u64,
    /// Number of recorded deserialization calls.
    pub deserialization_count: u64,
    /// Individual serialization durations (milliseconds).
    pub serialization_times: Vec<f64>,
    /// Individual deserialization durations (milliseconds).
    pub deserialization_times: Vec<f64>,
    /// Time of the most recent update, if the producer records one; `None` otherwise.
    pub last_updated: Option<DateTime<Utc>>,
}

impl SerializerBenchmark {
    /// Mean serialization time, or `0.0` when there are no samples.
    pub fn avg_serialization_time(&self) -> f64 {
        Self::mean(&self.serialization_times)
    }

    /// Mean deserialization time, or `0.0` when there are no samples.
    pub fn avg_deserialization_time(&self) -> f64 {
        Self::mean(&self.deserialization_times)
    }

    /// 95th-percentile serialization time (lower nearest-rank), or `0.0` when empty.
    pub fn p95_serialization_time(&self) -> f64 {
        Self::percentile(&self.serialization_times, 0.95)
    }

    /// 95th-percentile deserialization time (lower nearest-rank), or `0.0` when empty.
    pub fn p95_deserialization_time(&self) -> f64 {
        Self::percentile(&self.deserialization_times, 0.95)
    }

    /// Arithmetic mean of `samples`, or `0.0` for an empty slice.
    fn mean(samples: &[f64]) -> f64 {
        if samples.is_empty() {
            0.0
        } else {
            samples.iter().sum::<f64>() / samples.len() as f64
        }
    }

    /// Value at rank `floor((len - 1) * p)` of the ascending-ordered samples, or `0.0` if empty.
    ///
    /// `p` is clamped to `[0, 1]` so out-of-range values cannot index out of bounds. Samples
    /// are ordered with `f64::total_cmp`. That is a true total order, NaN included, unlike
    /// `partial_cmp` with an `Equal` fallback, which can misorder or make sorting panic.
    fn percentile(times: &[f64], p: f64) -> f64 {
        if times.is_empty() {
            return 0.0;
        }
        let p = p.clamp(0.0, 1.0);
        let index = ((times.len() - 1) as f64 * p) as usize;
        // Only the element at `index` must land in sorted position, so partial selection
        // (O(n)) is enough; a full sort is unnecessary.
        let mut scratch = times.to_vec();
        *scratch.select_nth_unstable_by(index, f64::total_cmp).1
    }
}
/// Placeholder "BSON" codec that frames the payload with a 4-byte `b"BSON"` tag.
///
/// NOTE(review): this does not encode BSON documents. `serialize` only prepends the tag and
/// `deserialize` only removes it.
pub struct BsonSerializer;

impl BsonSerializer {
    /// Tag written by `serialize`, required by `deserialize`, and used for format detection.
    const MAGIC: &[u8] = b"BSON";
}

impl CustomSerializer for BsonSerializer {
    /// Returns `MAGIC` followed by `data`.
    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut framed = Vec::with_capacity(Self::MAGIC.len() + data.len());
        framed.extend_from_slice(Self::MAGIC);
        framed.extend_from_slice(data);
        Ok(framed)
    }

    /// Strips the `BSON` tag.
    ///
    /// # Errors
    /// Fails if `data` does not start with the tag. This covers too-short input and data
    /// produced by a different format, which was previously accepted with its first 4 bytes
    /// silently dropped.
    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        data.strip_prefix(Self::MAGIC)
            .map(<[u8]>::to_vec)
            .ok_or_else(|| anyhow!("Invalid BSON data"))
    }

    fn format_name(&self) -> &str {
        "bson"
    }

    fn magic_bytes(&self) -> Option<&[u8]> {
        Some(Self::MAGIC)
    }
}
/// Placeholder "Thrift" codec that frames the payload with a 4-byte `b"THFT"` tag.
///
/// NOTE(review): this does not perform Thrift encoding. `serialize` only prepends the tag
/// and `deserialize` only removes it.
pub struct ThriftSerializer;

impl ThriftSerializer {
    /// Tag written by `serialize`, required by `deserialize`, and used for format detection.
    const MAGIC: &[u8] = b"THFT";
}

impl CustomSerializer for ThriftSerializer {
    /// Returns `MAGIC` followed by `data`.
    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut framed = Vec::with_capacity(Self::MAGIC.len() + data.len());
        framed.extend_from_slice(Self::MAGIC);
        framed.extend_from_slice(data);
        Ok(framed)
    }

    /// Strips the `THFT` tag.
    ///
    /// # Errors
    /// Fails if `data` does not start with the tag. This covers too-short input and data
    /// produced by a different format, which was previously accepted with its first 4 bytes
    /// silently dropped.
    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        data.strip_prefix(Self::MAGIC)
            .map(<[u8]>::to_vec)
            .ok_or_else(|| anyhow!("Invalid Thrift data"))
    }

    fn format_name(&self) -> &str {
        "thrift"
    }

    fn magic_bytes(&self) -> Option<&[u8]> {
        Some(Self::MAGIC)
    }
}
/// Placeholder "FlexBuffers" codec that frames the payload with a 4-byte `b"FLEX"` tag.
///
/// NOTE(review): this does not perform FlexBuffers encoding. It also reports
/// `supports_zero_copy() == true` even though `deserialize` copies the payload into a new
/// `Vec`. Confirm what callers rely on that flag for.
pub struct FlexBuffersSerializer;

impl FlexBuffersSerializer {
    /// Tag written by `serialize`, required by `deserialize`, and used for format detection.
    const MAGIC: &[u8] = b"FLEX";
}

impl CustomSerializer for FlexBuffersSerializer {
    /// Returns `MAGIC` followed by `data`.
    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut framed = Vec::with_capacity(Self::MAGIC.len() + data.len());
        framed.extend_from_slice(Self::MAGIC);
        framed.extend_from_slice(data);
        Ok(framed)
    }

    /// Strips the `FLEX` tag.
    ///
    /// # Errors
    /// Fails if `data` does not start with the tag. This covers too-short input and data
    /// produced by a different format, which was previously accepted with its first 4 bytes
    /// silently dropped.
    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        data.strip_prefix(Self::MAGIC)
            .map(<[u8]>::to_vec)
            .ok_or_else(|| anyhow!("Invalid FlexBuffers data"))
    }

    fn format_name(&self) -> &str {
        "flexbuffers"
    }

    fn magic_bytes(&self) -> Option<&[u8]> {
        Some(Self::MAGIC)
    }

    fn supports_zero_copy(&self) -> bool {
        true
    }
}
/// Placeholder "RON" codec that frames the payload with the 4-byte tag `b"RON\0"`.
///
/// NOTE(review): this does not produce RON text. `serialize` only prepends the tag and
/// `deserialize` only removes it.
pub struct RonSerializer;

impl RonSerializer {
    /// Tag written by `serialize`, required by `deserialize`, and used for format detection.
    const MAGIC: &[u8] = b"RON\0";
}

impl CustomSerializer for RonSerializer {
    /// Returns `MAGIC` followed by `data`.
    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut framed = Vec::with_capacity(Self::MAGIC.len() + data.len());
        framed.extend_from_slice(Self::MAGIC);
        framed.extend_from_slice(data);
        Ok(framed)
    }

    /// Strips the `RON\0` tag.
    ///
    /// # Errors
    /// Fails if `data` does not start with the tag. This covers too-short input and data
    /// produced by a different format, which was previously accepted with its first 4 bytes
    /// silently dropped.
    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        data.strip_prefix(Self::MAGIC)
            .map(<[u8]>::to_vec)
            .ok_or_else(|| anyhow!("Invalid RON data"))
    }

    fn format_name(&self) -> &str {
        "ron"
    }

    fn magic_bytes(&self) -> Option<&[u8]> {
        Some(Self::MAGIC)
    }
}
/// Placeholder "Ion" codec that frames the payload with the 4-byte tag `b"ION\x01"`.
///
/// NOTE(review): this does not perform Amazon Ion encoding. `serialize` only prepends the
/// tag and `deserialize` only removes it.
pub struct IonSerializer;

impl IonSerializer {
    /// Tag written by `serialize`, required by `deserialize`, and used for format detection.
    const MAGIC: &[u8] = b"ION\x01";
}

impl CustomSerializer for IonSerializer {
    /// Returns `MAGIC` followed by `data`.
    fn serialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        let mut framed = Vec::with_capacity(Self::MAGIC.len() + data.len());
        framed.extend_from_slice(Self::MAGIC);
        framed.extend_from_slice(data);
        Ok(framed)
    }

    /// Strips the `ION\x01` tag.
    ///
    /// # Errors
    /// Fails if `data` does not start with the tag. This covers too-short input and data
    /// produced by a different format, which was previously accepted with its first 4 bytes
    /// silently dropped.
    fn deserialize(&self, data: &[u8]) -> Result<Vec<u8>> {
        data.strip_prefix(Self::MAGIC)
            .map(<[u8]>::to_vec)
            .ok_or_else(|| anyhow!("Invalid Ion data"))
    }

    fn format_name(&self) -> &str {
        "ion"
    }

    fn magic_bytes(&self) -> Option<&[u8]> {
        Some(Self::MAGIC)
    }
}
pub struct SerializerBenchmarkSuite {
registry: Arc<SerializerRegistry>,
test_data: Vec<Vec<u8>>,
}
impl SerializerBenchmarkSuite {
pub fn new(registry: Arc<SerializerRegistry>) -> Self {
Self {
registry,
test_data: Self::generate_test_data(),
}
}
fn generate_test_data() -> Vec<Vec<u8>> {
use scirs2_core::random::{rng, RngExt};
let mut rand_gen = rng();
let sizes = [100, 1024, 10_240, 102_400];
sizes
.iter()
.map(|&size| (0..size).map(|_| rand_gen.random_range(0..=255)).collect())
.collect()
}
pub async fn benchmark(&self, format: &str, iterations: usize) -> Result<BenchmarkResults> {
let mut results = BenchmarkResults {
format: format.to_string(),
iterations,
serialization_times: Vec::new(),
deserialization_times: Vec::new(),
sizes: Vec::new(),
};
for test_data in &self.test_data {
let mut ser_times = Vec::new();
let mut deser_times = Vec::new();
for _ in 0..iterations {
let start = Instant::now();
let serialized = self.registry.serialize(format, test_data).await?;
ser_times.push(start.elapsed().as_secs_f64() * 1000.0);
let start = Instant::now();
self.registry.deserialize(format, &serialized).await?;
deser_times.push(start.elapsed().as_secs_f64() * 1000.0);
}
let avg_ser = ser_times.iter().sum::<f64>() / ser_times.len() as f64;
let avg_deser = deser_times.iter().sum::<f64>() / deser_times.len() as f64;
results.serialization_times.push(avg_ser);
results.deserialization_times.push(avg_deser);
results.sizes.push(test_data.len());
}
debug!("Benchmark completed for {}: {:?}", format, results);
Ok(results)
}
pub async fn compare(
&self,
formats: &[String],
iterations: usize,
) -> Result<Vec<BenchmarkResults>> {
let mut all_results = Vec::new();
for format in formats {
let results = self.benchmark(format, iterations).await?;
all_results.push(results);
}
Ok(all_results)
}
}
/// Output of one benchmark run for a single format.
///
/// The vectors `serialization_times`, `deserialization_times` and `sizes` are parallel:
/// index `i` in each refers to the same payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BenchmarkResults {
    /// Name of the benchmarked format.
    pub format: String,
    /// Round trips performed per payload.
    pub iterations: usize,
    /// Mean serialization time per payload, in milliseconds.
    pub serialization_times: Vec<f64>,
    /// Mean deserialization time per payload, in milliseconds.
    pub deserialization_times: Vec<f64>,
    /// Payload sizes in bytes.
    pub sizes: Vec<usize>,
}

impl BenchmarkResults {
    /// Unweighted mean of the per-payload serialization times, or `0.0` if there are none.
    pub fn avg_serialization_time(&self) -> f64 {
        Self::mean_or_zero(&self.serialization_times)
    }

    /// Unweighted mean of the per-payload deserialization times, or `0.0` if there are none.
    pub fn avg_deserialization_time(&self) -> f64 {
        Self::mean_or_zero(&self.deserialization_times)
    }

    /// Arithmetic mean of `samples`; an empty slice yields `0.0`.
    fn mean_or_zero(samples: &[f64]) -> f64 {
        match samples.len() {
            0 => 0.0,
            count => samples.iter().sum::<f64>() / count as f64,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh instance of the built-in serializer with the given name.
    fn serializer_for(name: &str) -> Box<dyn CustomSerializer> {
        match name {
            "bson" => Box::new(BsonSerializer),
            "thrift" => Box::new(ThriftSerializer),
            "flexbuffers" => Box::new(FlexBuffersSerializer),
            "ron" => Box::new(RonSerializer),
            "ion" => Box::new(IonSerializer),
            other => panic!("no built-in serializer named {other}"),
        }
    }

    /// Builds a registry with each named built-in serializer registered under its name.
    async fn registry_with(names: &[&str]) -> SerializerRegistry {
        let registry = SerializerRegistry::new();
        for &name in names {
            registry.register(name, serializer_for(name)).await.unwrap();
        }
        registry
    }

    /// Whether `formats` contains `name`.
    fn has(formats: &[String], name: &str) -> bool {
        formats.iter().any(|f| f == name)
    }

    #[tokio::test]
    async fn test_register_custom_serializer() {
        let registry = registry_with(&["bson"]).await;
        assert!(has(&registry.list().await, "bson"));
    }

    #[tokio::test]
    async fn test_serialize_deserialize() {
        let registry = registry_with(&["bson"]).await;
        let payload = b"test data";
        let encoded = registry.serialize("bson", payload).await.unwrap();
        let decoded = registry.deserialize("bson", &encoded).await.unwrap();
        assert_eq!(decoded, payload);
    }

    #[tokio::test]
    async fn test_format_detection() {
        let registry = registry_with(&["bson", "thrift"]).await;
        let detected = registry.detect_format(b"BSONtest data").await;
        assert_eq!(detected, Some("bson".to_string()));
    }

    #[tokio::test]
    async fn test_benchmark() {
        let registry = Arc::new(registry_with(&["bson"]).await);
        let suite = SerializerBenchmarkSuite::new(Arc::clone(&registry));
        let results = suite.benchmark("bson", 10).await.unwrap();
        assert_eq!(results.format, "bson");
        assert_eq!(results.iterations, 10);
        assert!(!results.serialization_times.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_formats() {
        let registry = registry_with(&["bson", "thrift", "flexbuffers", "ron", "ion"]).await;
        assert_eq!(registry.list().await.len(), 5);
    }

    #[tokio::test]
    async fn test_unregister() {
        let registry = registry_with(&["bson"]).await;
        assert!(has(&registry.list().await, "bson"));
        registry.unregister("bson").await.unwrap();
        assert!(!has(&registry.list().await, "bson"));
    }
}