use crate::Message;
use once_cell::sync::Lazy;
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
use std::sync::{Arc, RwLock};
/// Errors produced while encoding, decoding, framing, or routing serializable messages.
#[derive(Debug, Clone)]
pub enum SerializationError {
    /// Turning a message into bytes failed; carries a human-readable reason.
    SerializeFailed(String),
    /// Turning bytes back into a message failed; carries a human-readable reason.
    DeserializeFailed(String),
    /// No deserializer is registered under the carried type id.
    UnknownMessageType(String),
    /// Framed input (e.g. an envelope's length prefixes or UTF-8 type id) is malformed.
    InvalidFormat(String),
    /// Reports an expected type name alongside the one actually found.
    TypeMismatch { expected: String, found: String },
}

impl fmt::Display for SerializationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except `TypeMismatch` renders as "<fixed prefix><detail>".
        let (prefix, detail) = match self {
            Self::SerializeFailed(detail) => ("Serialization failed: ", detail),
            Self::DeserializeFailed(detail) => ("Deserialization failed: ", detail),
            Self::UnknownMessageType(detail) => ("Unknown message type: ", detail),
            Self::InvalidFormat(detail) => ("Invalid format: ", detail),
            Self::TypeMismatch { expected, found } => {
                return write!(f, "Type mismatch: expected {}, found {}", expected, found);
            }
        };
        write!(f, "{}{}", prefix, detail)
    }
}

impl std::error::Error for SerializationError {}
/// A message that can be type-erased, identified by a string tag, and encoded to bytes.
///
/// The `Any + Send + Sync` supertraits let implementors be shared across threads and be
/// downcast back to their concrete type through [`SerializableMessage::as_any`].
pub trait SerializableMessage: Any + Send + Sync {
    /// Tag identifying this message's type. `SerializableEnvelope::wrap` records it, and it
    /// is the key used to look up the matching deserializer in a [`MessageRegistry`].
    fn message_type_id(&self) -> &'static str;
    /// Returns `self` as `&dyn Any` so callers can `downcast_ref` to the concrete type.
    fn as_any(&self) -> &dyn Any;
    /// Encodes this message into bytes.
    fn serialize(&self) -> Result<Vec<u8>, SerializationError>;
}
/// A boxed, thread-safe decoder that turns a byte payload into a message; one is
/// registered per type id in a [`MessageRegistry`].
pub type Deserializer =
    Box<dyn Fn(&[u8]) -> Result<Box<dyn SerializableMessage>, SerializationError> + Send + Sync>;
/// Internal table type: type id -> shared handle to the deserializer for that id.
type DeserializerTable = HashMap<String, Arc<Deserializer>>;

/// A thread-safe map from message type ids to the [`Deserializer`] that decodes them.
///
/// Cloning is cheap and yields a handle to the *same* table (the state lives behind an
/// `Arc`), so a registration or `clear` made through one clone is visible through all.
#[derive(Clone)]
pub struct MessageRegistry {
    // Entries are `Arc`-wrapped so `deserialize` can take its own handle to a deserializer
    // and release the lock before running it.
    deserializers: Arc<RwLock<DeserializerTable>>,
}

impl MessageRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            deserializers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Registers `deserializer` under `type_id`, replacing any previous entry for that id.
    pub fn register(&mut self, type_id: &str, deserializer: Deserializer) {
        self.write_table()
            .insert(type_id.to_string(), Arc::new(deserializer));
    }

    /// Decodes `data` with the deserializer registered under `type_id`.
    ///
    /// # Errors
    /// Returns [`SerializationError::UnknownMessageType`] if nothing is registered under
    /// `type_id`; otherwise returns whatever the deserializer itself returns.
    pub fn deserialize(
        &self,
        type_id: &str,
        data: &[u8],
    ) -> Result<Box<dyn SerializableMessage>, SerializationError> {
        // Clone the handle out so the read guard (a temporary of this statement) is dropped
        // *before* the user-supplied deserializer runs. Holding it across the call would
        // deadlock if the deserializer re-entered the registry, either via `register` or via
        // a nested `deserialize` queued behind a waiting writer (std's `RwLock` does not
        // guarantee recursive reads succeed).
        let deserializer = self
            .read_table()
            .get(type_id)
            .cloned()
            .ok_or_else(|| SerializationError::UnknownMessageType(type_id.to_string()))?;
        deserializer(data)
    }

    /// Returns `true` if a deserializer is registered under `type_id`.
    pub fn has_type(&self, type_id: &str) -> bool {
        self.read_table().contains_key(type_id)
    }

    /// Number of registered type ids.
    pub fn len(&self) -> usize {
        self.read_table().len()
    }

    /// Returns `true` if no type ids are registered.
    pub fn is_empty(&self) -> bool {
        self.read_table().is_empty()
    }

    /// Removes every registration (visible through all clones of this registry).
    pub fn clear(&mut self) {
        self.write_table().clear();
    }

    /// Shared access to the table.
    ///
    /// A poisoned lock is recovered rather than turned into a panic: the map remains a valid
    /// `HashMap` even if a panic occurred while it was locked, and one failing thread should
    /// not permanently disable the registry for every other caller.
    fn read_table(&self) -> std::sync::RwLockReadGuard<'_, DeserializerTable> {
        self.deserializers
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Exclusive access to the table; poisoning is recovered as in `read_table`.
    fn write_table(&self) -> std::sync::RwLockWriteGuard<'_, DeserializerTable> {
        self.deserializers
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl Default for MessageRegistry {
    fn default() -> Self {
        Self::new()
    }
}
/// Process-wide registry used by [`register_message_type`] and handed out by
/// [`get_global_registry`]; it is created empty on first access.
static GLOBAL_REGISTRY: Lazy<Arc<RwLock<MessageRegistry>>> =
    Lazy::new(|| Arc::new(RwLock::new(MessageRegistry::default())));

/// Registers `deserializer` under `type_id` in the process-wide registry.
///
/// # Panics
/// Panics if the outer lock around the global registry is poisoned.
pub fn register_message_type(type_id: &str, deserializer: Deserializer) {
    let mut registry = GLOBAL_REGISTRY.write().unwrap();
    registry.register(type_id, deserializer);
}

/// Returns a new shared handle to the process-wide registry.
pub fn get_global_registry() -> Arc<RwLock<MessageRegistry>> {
    let shared: &Arc<RwLock<MessageRegistry>> = &GLOBAL_REGISTRY;
    shared.clone()
}
/// A self-describing frame pairing a message's type id with its serialized bytes.
///
/// Wire format produced by [`SerializableEnvelope::to_bytes`] and accepted by
/// [`SerializableEnvelope::from_bytes`]:
///
/// ```text
/// [type_id_len: u32 LE][type_id: UTF-8 bytes][data_len: u32 LE][data bytes]
/// ```
///
/// Every envelope built by [`SerializableEnvelope::wrap`] or
/// [`SerializableEnvelope::from_bytes`] has a type id and payload no longer than
/// `u32::MAX` bytes, so both always fit their length prefixes.
#[derive(Debug, Clone)]
pub struct SerializableEnvelope {
    type_id: String,
    data: Vec<u8>,
}

impl SerializableEnvelope {
    /// Serializes `msg` and records its type id.
    ///
    /// # Errors
    /// Propagates any error from [`SerializableMessage::serialize`]. Returns
    /// [`SerializationError::SerializeFailed`] if the type id or payload exceeds `u32::MAX`
    /// bytes, since such an envelope could not be framed by `to_bytes`.
    pub fn wrap(msg: &dyn SerializableMessage) -> Result<Self, SerializationError> {
        let type_id = msg.message_type_id().to_string();
        let data = msg.serialize()?;
        Self::ensure_frameable("type_id", type_id.len())?;
        Self::ensure_frameable("data", data.len())?;
        Ok(Self { type_id, data })
    }

    /// Type id recorded when the envelope was built.
    pub fn type_id(&self) -> &str {
        &self.type_id
    }

    /// Serialized payload bytes.
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Decodes the payload with the deserializer `registry` holds for this envelope's type id.
    ///
    /// # Errors
    /// Returns whatever [`MessageRegistry::deserialize`] returns.
    pub fn unwrap(
        &self,
        registry: &MessageRegistry,
    ) -> Result<Box<dyn SerializableMessage>, SerializationError> {
        registry.deserialize(&self.type_id, &self.data)
    }

    /// Encodes the envelope in the wire format described on [`SerializableEnvelope`].
    pub fn to_bytes(&self) -> Vec<u8> {
        let type_id_bytes = self.type_id.as_bytes();
        let mut result = Vec::with_capacity(8 + type_id_bytes.len() + self.data.len());
        result.extend_from_slice(&Self::encode_len(type_id_bytes.len()));
        result.extend_from_slice(type_id_bytes);
        result.extend_from_slice(&Self::encode_len(self.data.len()));
        result.extend_from_slice(&self.data);
        result
    }

    /// Parses an envelope from the wire format described on [`SerializableEnvelope`].
    ///
    /// Any bytes after the declared payload are ignored.
    ///
    /// # Errors
    /// Returns [`SerializationError::InvalidFormat`] if `bytes` is too short for the length
    /// prefixes, a declared length runs past the end of `bytes`, or the type id is not
    /// valid UTF-8.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, SerializationError> {
        let invalid = |reason: &str| SerializationError::InvalidFormat(reason.to_string());
        if bytes.len() < 8 {
            return Err(invalid("Envelope too short"));
        }
        let type_id_len =
            Self::read_len(bytes, 0).ok_or_else(|| invalid("Envelope too short"))?;
        // All offset arithmetic is checked: on 32-bit targets a hostile prefix such as
        // `u32::MAX` would otherwise wrap `usize` and slip past the bounds checks.
        // The type id must be followed by at least the 4-byte data length prefix.
        let type_id_end = 4usize
            .checked_add(type_id_len)
            .filter(|&end| end.checked_add(4).map_or(false, |min| min <= bytes.len()))
            .ok_or_else(|| invalid("Invalid type_id length"))?;
        let type_id = String::from_utf8(bytes[4..type_id_end].to_vec())
            .map_err(|e| SerializationError::InvalidFormat(format!("Invalid UTF-8: {}", e)))?;
        let data_len_offset = type_id_end;
        let data_len = Self::read_len(bytes, data_len_offset)
            .ok_or_else(|| invalid("Missing data length"))?;
        // Cannot overflow: `data_len_offset + 4 <= bytes.len()` was verified above.
        let data_offset = data_len_offset + 4;
        let data_end = data_offset
            .checked_add(data_len)
            .filter(|&end| end <= bytes.len())
            .ok_or_else(|| invalid("Invalid data length"))?;
        let data = bytes[data_offset..data_end].to_vec();
        Ok(Self { type_id, data })
    }

    /// Boxes a type-erased message into the crate's [`Message`] type.
    ///
    /// The stored value is the `Box<dyn SerializableMessage>` itself, which is exactly what
    /// [`Self::from_message`] downcasts to.
    pub fn to_message(msg: Box<dyn SerializableMessage>) -> Message {
        Box::new(msg)
    }

    /// Recovers a message stored with [`Self::to_message`], or `None` if `msg` holds a value
    /// of some other type.
    pub fn from_message(msg: &Message) -> Option<&dyn SerializableMessage> {
        msg.downcast_ref::<Box<dyn SerializableMessage>>()
            .map(|b| b.as_ref())
    }

    /// Reads a little-endian `u32` length prefix at `offset`, or `None` if fewer than four
    /// bytes remain there. `u32 -> usize` is lossless on 32- and 64-bit targets.
    fn read_len(bytes: &[u8], offset: usize) -> Option<usize> {
        let end = offset.checked_add(4)?;
        match *bytes.get(offset..end)? {
            [b0, b1, b2, b3] => Some(u32::from_le_bytes([b0, b1, b2, b3]) as usize),
            _ => None,
        }
    }

    /// Encodes `len` as a little-endian `u32` length prefix.
    ///
    /// # Panics
    /// Panics if `len > u32::MAX`. `wrap` and `from_bytes` never build such an envelope, so
    /// this fires only if that invariant is broken; it replaces a silent `as u32` truncation
    /// that would have produced a corrupt frame.
    fn encode_len(len: usize) -> [u8; 4] {
        assert!(
            len <= u32::MAX as usize,
            "envelope length {} does not fit in a u32 prefix",
            len
        );
        (len as u32).to_le_bytes()
    }

    /// Rejects a field whose length cannot be represented by a `u32` length prefix.
    fn ensure_frameable(field: &str, len: usize) -> Result<(), SerializationError> {
        if len <= u32::MAX as usize {
            Ok(())
        } else {
            Err(SerializationError::SerializeFailed(format!(
                "{} is {} bytes, which exceeds the u32::MAX length-prefix limit",
                field, len
            )))
        }
    }
}
/// Implements `SerializableMessage` for `$type`.
///
/// - `$type_id`: a `&'static str` expression returned from `message_type_id`.
/// - `$serialize_fn`: a function or non-capturing closure coercible to
///   `fn(&$type) -> Result<Vec<u8>, SerializationError>`.
///
/// Std paths in the expansion are fully qualified so it still compiles when the call site
/// has its own `Result` or `Vec` in scope, such as a crate-local `type Result<T> = ...` alias.
#[macro_export]
macro_rules! impl_serializable {
    ($type:ty, $type_id:expr, $serialize_fn:expr) => {
        impl $crate::serialization::SerializableMessage for $type {
            fn message_type_id(&self) -> &'static str {
                $type_id
            }
            fn as_any(&self) -> &dyn ::std::any::Any {
                self
            }
            fn serialize(
                &self,
            ) -> ::std::result::Result<
                ::std::vec::Vec<u8>,
                $crate::serialization::SerializationError,
            > {
                // Binding through a fn-pointer type checks `$serialize_fn`'s signature and
                // accepts both fn items and non-capturing closures.
                let serialize: fn(
                    &Self,
                ) -> ::std::result::Result<
                    ::std::vec::Vec<u8>,
                    $crate::serialization::SerializationError,
                > = $serialize_fn;
                serialize(self)
            }
        }
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, PartialEq)]
    struct TestMessage {
        value: u32,
    }

    impl SerializableMessage for TestMessage {
        fn message_type_id(&self) -> &'static str {
            "TestMessage"
        }
        fn as_any(&self) -> &dyn Any {
            self
        }
        fn serialize(&self) -> Result<Vec<u8>, SerializationError> {
            Ok(self.value.to_le_bytes().to_vec())
        }
    }

    /// Decodes exactly four little-endian bytes into a `TestMessage`.
    fn deserialize_test_message(
        data: &[u8],
    ) -> Result<Box<dyn SerializableMessage>, SerializationError> {
        if data.len() != 4 {
            return Err(SerializationError::DeserializeFailed(
                "Invalid length".to_string(),
            ));
        }
        let value = u32::from_le_bytes([data[0], data[1], data[2], data[3]]);
        Ok(Box::new(TestMessage { value }))
    }

    #[test]
    fn test_message_serialization() {
        let msg = TestMessage { value: 42 };
        let data = msg.serialize().unwrap();
        assert_eq!(data, vec![42, 0, 0, 0]);
    }

    #[test]
    fn test_registry() {
        let mut registry = MessageRegistry::new();
        assert!(registry.is_empty());
        registry.register("TestMessage", Box::new(deserialize_test_message));
        assert_eq!(registry.len(), 1);
        assert!(registry.has_type("TestMessage"));
        assert!(!registry.has_type("UnknownType"));
    }

    #[test]
    fn test_registry_deserialize() {
        let mut registry = MessageRegistry::new();
        registry.register("TestMessage", Box::new(deserialize_test_message));
        let data = vec![42, 0, 0, 0];
        let msg = registry.deserialize("TestMessage", &data).unwrap();
        let test_msg = msg.as_any().downcast_ref::<TestMessage>().unwrap();
        assert_eq!(test_msg.value, 42);
    }

    #[test]
    fn test_registry_deserializer_error_propagates() {
        let mut registry = MessageRegistry::new();
        registry.register("TestMessage", Box::new(deserialize_test_message));
        let result = registry.deserialize("TestMessage", &[1, 2, 3]);
        assert!(matches!(
            result,
            Err(SerializationError::DeserializeFailed(_))
        ));
    }

    #[test]
    fn test_registry_unknown_type() {
        let registry = MessageRegistry::new();
        let result = registry.deserialize("Unknown", &[]);
        assert!(matches!(
            result,
            Err(SerializationError::UnknownMessageType(_))
        ));
    }

    #[test]
    fn test_envelope_wrap_unwrap() {
        let msg = TestMessage { value: 123 };
        let envelope = SerializableEnvelope::wrap(&msg).unwrap();
        assert_eq!(envelope.type_id(), "TestMessage");
        let mut registry = MessageRegistry::new();
        registry.register("TestMessage", Box::new(deserialize_test_message));
        let unwrapped = envelope.unwrap(&registry).unwrap();
        let result = unwrapped.as_any().downcast_ref::<TestMessage>().unwrap();
        assert_eq!(result.value, 123);
    }

    #[test]
    fn test_envelope_to_from_bytes() {
        let msg = TestMessage { value: 999 };
        let envelope = SerializableEnvelope::wrap(&msg).unwrap();
        let bytes = envelope.to_bytes();
        let reconstructed = SerializableEnvelope::from_bytes(&bytes).unwrap();
        assert_eq!(envelope.type_id(), reconstructed.type_id());
        assert_eq!(envelope.data(), reconstructed.data());
    }

    #[test]
    fn test_envelope_from_bytes_invalid() {
        let result = SerializableEnvelope::from_bytes(&[1, 2, 3]);
        assert!(matches!(result, Err(SerializationError::InvalidFormat(_))));
    }

    #[test]
    fn test_envelope_from_bytes_huge_type_id_len() {
        // A type_id length prefix of u32::MAX must be rejected, not overflow or panic.
        let result = SerializableEnvelope::from_bytes(&[0xFF, 0xFF, 0xFF, 0xFF, 0, 0, 0, 0]);
        assert!(matches!(result, Err(SerializationError::InvalidFormat(_))));
    }

    #[test]
    fn test_envelope_from_bytes_truncated_payload() {
        // type_id "T", declared payload of 10 bytes, only 2 present.
        let bytes = [1, 0, 0, 0, b'T', 10, 0, 0, 0, 1, 2];
        let result = SerializableEnvelope::from_bytes(&bytes);
        assert!(matches!(result, Err(SerializationError::InvalidFormat(_))));
    }

    #[test]
    fn test_envelope_from_bytes_invalid_utf8_type_id() {
        let bytes = [1, 0, 0, 0, 0xFF, 0, 0, 0, 0];
        let result = SerializableEnvelope::from_bytes(&bytes);
        assert!(matches!(result, Err(SerializationError::InvalidFormat(_))));
    }

    #[test]
    fn test_registry_clone() {
        let mut registry = MessageRegistry::new();
        registry.register("TestMessage", Box::new(deserialize_test_message));
        let cloned = registry.clone();
        assert!(cloned.has_type("TestMessage"));
    }

    #[test]
    fn test_registry_clear() {
        let mut registry = MessageRegistry::new();
        registry.register("TestMessage", Box::new(deserialize_test_message));
        assert_eq!(registry.len(), 1);
        registry.clear();
        assert!(registry.is_empty());
    }
}