use crate::error::CheckpointError;
use serde::Serialize;
use serde::de::DeserializeOwned;
#[cfg(feature = "encryption")]
use aes_gcm::{Aes256Gcm, Nonce, aead::Aead};
#[cfg(feature = "encryption")]
use aes_gcm::aead::{AeadCore, KeyInit, OsRng};
#[cfg(feature = "encryption")]
use aes_gcm::aead::generic_array::GenericArray;
#[cfg(feature = "encryption")]
use pbkdf2::pbkdf2_hmac;
#[cfg(feature = "encryption")]
use sha2::Sha256;
/// Wire encoding of a serialized checkpoint payload.
///
/// The default is [`SerializationFormat::MessagePack`].
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub enum SerializationFormat {
    /// Binary MessagePack encoding (the default).
    #[default]
    MessagePack,
    /// JSON encoding, as produced by `serde_json`.
    Json,
}
/// Selects which encoding `SerializerKind::serialize` / `deserialize` use.
///
/// The default is [`SerializerKind::MessagePack`]. `PartialEq`/`Eq` are derived so
/// kinds can be compared directly, matching [`SerializationFormat`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub enum SerializerKind {
    /// Encode with MessagePack (`rmp_serde`).
    #[default]
    MessagePack,
    /// Encode with JSON (`serde_json`).
    Json,
}
impl SerializerKind {
    /// Serializes `value` with the selected encoding.
    ///
    /// MessagePack output uses `rmp_serde::to_vec_named`, so structs are written as
    /// maps keyed by field name. The positional-array form produced by
    /// `rmp_serde::to_vec` misaligns fields when a struct uses
    /// `#[serde(skip_serializing_if)]`. It also makes stored checkpoints unreadable
    /// once fields are added or reordered. Derived `Deserialize` impls accept both
    /// map- and array-shaped structs, so older positional data still loads through
    /// [`SerializerKind::deserialize`].
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointError::Serialize`] if the underlying encoder fails.
    pub fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
        match self {
            Self::MessagePack => {
                rmp_serde::to_vec_named(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
            }
            Self::Json => {
                serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
            }
        }
    }

    /// Deserializes `data` with the selected encoding.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointError::Deserialize`] if `data` is not valid for the
    /// selected encoding or does not match `T`.
    pub fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
        match self {
            Self::MessagePack => {
                rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
            }
            Self::Json => {
                serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
            }
        }
    }

    /// Returns the [`SerializationFormat`] corresponding to this kind.
    #[must_use]
    pub const fn format(&self) -> SerializationFormat {
        match self {
            Self::MessagePack => SerializationFormat::MessagePack,
            Self::Json => SerializationFormat::Json,
        }
    }
}
/// Converts checkpoint data to and from raw bytes in a particular wire format.
///
/// Implementors must be `Send + Sync + 'static`. Because [`serialize`] and
/// [`deserialize`] are generic methods, this trait cannot be used as a
/// `dyn CheckpointSerializer` trait object.
///
/// [`serialize`]: CheckpointSerializer::serialize
/// [`deserialize`]: CheckpointSerializer::deserialize
pub trait CheckpointSerializer: Send + Sync + 'static {
    /// Reports the [`SerializationFormat`] associated with this serializer.
    #[must_use]
    fn format(&self) -> SerializationFormat;

    /// Encodes any `Serialize` value into bytes.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointError`] if the value cannot be encoded.
    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError>;

    /// Decodes bytes into any owned `Deserialize` type.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointError`] if `data` is malformed or does not match `T`.
    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError>;

    /// Encodes a dynamically-typed JSON value into bytes.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointError`] if the value cannot be encoded.
    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError>;

    /// Decodes bytes into a dynamically-typed JSON value.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointError`] if `data` is malformed.
    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError>;
}
/// [`CheckpointSerializer`] backed by MessagePack (`rmp_serde`).
///
/// Stateless: construct it with [`MsgpackSerializer::new`] or `Default`.
#[derive(Debug, Default, Clone)]
pub struct MsgpackSerializer;

impl MsgpackSerializer {
    /// Creates a MessagePack serializer; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        MsgpackSerializer
    }
}
impl CheckpointSerializer for MsgpackSerializer {
    /// Encodes a JSON value as MessagePack.
    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
        rmp_serde::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
    }

    /// Decodes MessagePack bytes into a JSON value.
    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
        rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
    }

    /// Encodes `value` as MessagePack, writing structs as field-name-keyed maps.
    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
        // `rmp_serde::to_vec` writes structs as positional arrays. That misaligns
        // fields when `#[serde(skip_serializing_if)]` omits one, and it breaks
        // stored checkpoints when fields are added or reordered. Named maps avoid
        // both problems.
        rmp_serde::to_vec_named(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
    }

    /// Decodes MessagePack bytes into `T`.
    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
        // Derived `Deserialize` impls accept both map- and array-shaped structs,
        // so payloads written with the positional encoding remain readable.
        rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
    }

    fn format(&self) -> SerializationFormat {
        SerializationFormat::MessagePack
    }
}
/// [`CheckpointSerializer`] that encodes payloads as compact JSON via `serde_json`.
///
/// Stateless: construct it with [`JsonSerializer::new`] or `Default`.
#[derive(Default, Debug, Clone)]
pub struct JsonSerializer;

impl JsonSerializer {
    /// Creates a JSON serializer; usable in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        JsonSerializer
    }
}
impl CheckpointSerializer for JsonSerializer {
    fn format(&self) -> SerializationFormat {
        SerializationFormat::Json
    }

    /// Encodes `value` as compact JSON.
    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
        let encoded = serde_json::to_vec(value);
        encoded.map_err(|err| CheckpointError::Serialize(Box::new(err)))
    }

    /// Parses JSON bytes into `T`.
    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
        let decoded = serde_json::from_slice::<T>(data);
        decoded.map_err(|err| CheckpointError::Deserialize(Box::new(err)))
    }

    /// Encodes a JSON value as compact JSON text.
    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
        let encoded = serde_json::to_vec(value);
        encoded.map_err(|err| CheckpointError::Serialize(Box::new(err)))
    }

    /// Parses JSON bytes into a JSON value.
    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
        let decoded = serde_json::from_slice::<serde_json::Value>(data);
        decoded.map_err(|err| CheckpointError::Deserialize(Box::new(err)))
    }
}
/// JSON [`CheckpointSerializer`] that can emit indented, human-readable output.
///
/// [`JsonPlusSerializer::new`] and `Default` enable pretty-printing;
/// [`JsonPlusSerializer::with_pretty`] lets the caller choose.
#[derive(Debug, Clone)]
pub struct JsonPlusSerializer {
    // `true` selects `serde_json::to_vec_pretty`, `false` the compact `to_vec`.
    pretty: bool,
}

impl JsonPlusSerializer {
    /// Creates a serializer with pretty-printing enabled.
    #[must_use]
    pub const fn new() -> Self {
        Self::with_pretty(true)
    }

    /// Creates a serializer that pretty-prints only when `pretty` is `true`.
    #[must_use]
    pub const fn with_pretty(pretty: bool) -> Self {
        Self { pretty }
    }
}

impl Default for JsonPlusSerializer {
    // Same as `JsonPlusSerializer::new()`: pretty-printing enabled.
    fn default() -> Self {
        Self::new()
    }
}
impl CheckpointSerializer for JsonPlusSerializer {
    fn format(&self) -> SerializationFormat {
        SerializationFormat::Json
    }

    /// Encodes `value` as JSON, indented when `pretty` is set.
    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
        let encoded = if self.pretty {
            serde_json::to_vec_pretty(value)
        } else {
            serde_json::to_vec(value)
        };
        encoded.map_err(|err| CheckpointError::Serialize(Box::new(err)))
    }

    /// Parses JSON bytes (pretty or compact) into `T`.
    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
        let decoded = serde_json::from_slice::<T>(data);
        decoded.map_err(|err| CheckpointError::Deserialize(Box::new(err)))
    }

    /// Encodes a JSON value, indented when `pretty` is set.
    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
        let encoded = if self.pretty {
            serde_json::to_vec_pretty(value)
        } else {
            serde_json::to_vec(value)
        };
        encoded.map_err(|err| CheckpointError::Serialize(Box::new(err)))
    }

    /// Parses JSON bytes (pretty or compact) into a JSON value.
    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
        let decoded = serde_json::from_slice::<serde_json::Value>(data);
        decoded.map_err(|err| CheckpointError::Deserialize(Box::new(err)))
    }
}
/// Wraps another [`CheckpointSerializer`] and encrypts its output with AES-256-GCM.
///
/// Each encrypted payload is a 12-byte random nonce followed by the AES-GCM
/// ciphertext of the inner serializer's bytes.
#[cfg(feature = "encryption")]
#[derive(Clone)]
pub struct EncryptedSerializer<S>
where
    S: CheckpointSerializer,
{
    // Produces/consumes the plaintext bytes that get encrypted/decrypted.
    inner: S,
    // AES-256-GCM cipher keyed once at construction.
    cipher: Aes256Gcm,
}
#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer> EncryptedSerializer<S> {
    /// PBKDF2 iteration count used by [`EncryptedSerializer::from_passphrase`].
    ///
    /// Changing it changes the derived key, which would make previously written
    /// ciphertexts undecryptable.
    // NOTE(review): 100,000 rounds is below current OWASP guidance for
    // PBKDF2-HMAC-SHA256; raising it needs a migration/versioning story first.
    const PBKDF2_ROUNDS: u32 = 100_000;

    /// Wraps `inner`, using the raw 32-byte `key` as the AES-256-GCM key.
    #[must_use]
    pub fn new(inner: S, key: &[u8; 32]) -> Self {
        let cipher = Aes256Gcm::new(GenericArray::from_slice(key));
        Self { inner, cipher }
    }

    /// Wraps `inner`, deriving the AES-256 key from `passphrase` and `salt`.
    ///
    /// The key is derived with PBKDF2-HMAC-SHA256 at 100,000 iterations. Data
    /// written by one instance can only be decrypted by an instance built from
    /// the same passphrase and salt.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn from_passphrase(
        inner: S,
        passphrase: &str,
        salt: &[u8; 32],
    ) -> Result<Self, CheckpointError> {
        let mut key = [0u8; 32];
        pbkdf2_hmac::<Sha256>(passphrase.as_bytes(), salt, Self::PBKDF2_ROUNDS, &mut key);
        // NOTE(review): the derived key is not zeroized after use; consider a
        // zeroizing wrapper if in-memory key hygiene matters here.
        Ok(Self::new(inner, &key))
    }
}
#[cfg(feature = "encryption")]
impl<S> std::fmt::Debug for EncryptedSerializer<S>
where
    S: CheckpointSerializer + std::fmt::Debug,
{
    // Shows the wrapped serializer but only a placeholder for the cipher, so key
    // material never ends up in logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EncryptedSerializer");
        out.field("inner", &self.inner);
        out.field("cipher", &"<aes-256-gcm cipher>");
        out.finish()
    }
}
#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer> EncryptedSerializer<S> {
    /// Length in bytes of the AES-GCM nonce prepended to every payload.
    const NONCE_LEN: usize = 12;

    /// Encrypts `plaintext` under a fresh random nonce and returns
    /// `nonce || ciphertext`.
    fn encrypt_bytes(&self, plaintext: &[u8]) -> Result<Vec<u8>, CheckpointError> {
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        let ciphertext = self
            .cipher
            .encrypt(&nonce, plaintext)
            .map_err(|e| CheckpointError::serialize_msg(format!("Encryption failed: {e}")))?;
        let mut result = Vec::with_capacity(Self::NONCE_LEN + ciphertext.len());
        result.extend_from_slice(&nonce);
        result.extend_from_slice(&ciphertext);
        Ok(result)
    }

    /// Splits `nonce || ciphertext`, authenticates and decrypts it, and returns
    /// the plaintext bytes.
    fn decrypt_bytes(&self, data: &[u8]) -> Result<Vec<u8>, CheckpointError> {
        if data.len() < Self::NONCE_LEN {
            return Err(CheckpointError::deserialize_msg(
                "Encrypted data too short".to_string(),
            ));
        }
        let (nonce_bytes, ciphertext) = data.split_at(Self::NONCE_LEN);
        let nonce = Nonce::from_slice(nonce_bytes);
        self.cipher
            .decrypt(nonce, ciphertext)
            .map_err(|e| CheckpointError::deserialize_msg(format!("Decryption failed: {e}")))
    }
}

#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer> CheckpointSerializer for EncryptedSerializer<S> {
    /// Encodes `value` with the inner serializer, then encrypts the bytes.
    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
        let plaintext = self.inner.serialize_value(value)?;
        self.encrypt_bytes(&plaintext)
    }

    /// Decrypts `data`, then decodes the plaintext with the inner serializer.
    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
        let plaintext = self.decrypt_bytes(data)?;
        self.inner.deserialize_value(&plaintext)
    }

    /// Encodes `value` with the inner serializer, then encrypts the bytes.
    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
        // Encode directly with the inner serializer rather than detouring through
        // `serde_json::Value`. That detour turns non-finite floats into `null` and
        // rejects map keys that are not string-like (e.g. tuples), so it could
        // corrupt or refuse values the inner serializer handles natively.
        let plaintext = self.inner.serialize(value)?;
        self.encrypt_bytes(&plaintext)
    }

    /// Decrypts `data`, then decodes the plaintext into `T` with the inner serializer.
    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
        let plaintext = self.decrypt_bytes(data)?;
        self.inner.deserialize(&plaintext)
    }

    /// Returns the inner serializer's format.
    fn format(&self) -> SerializationFormat {
        // NOTE: this describes the plaintext encoding. The bytes this serializer
        // returns are ciphertext, so `detect_format` on them is not meaningful.
        self.inner.format()
    }
}
/// Guesses the encoding of a serialized checkpoint payload from its first byte.
///
/// Returns [`SerializationFormat::MessagePack`] when the first byte is non-ASCII
/// (`0x80..=0xff`). No JSON document can start with such a byte: JSON text begins
/// with ASCII whitespace, `{`, `[`, `"`, `-`, a digit, `t`, `f` or `n`.
///
/// Every MessagePack marker except positive fixint lies in that range. This
/// covers maps, arrays, strings, binary, extensions, `nil`, booleans, floats,
/// sized integers and negative fixints. Previously only map/array markers were
/// recognized, so MessagePack-encoded scalars were reported as JSON and could
/// not be decoded by `deserialize_auto`.
///
/// Everything else is reported as [`SerializationFormat::Json`]. That includes
/// empty input and MessagePack positive fixints (`0x00..=0x7f`), which are
/// indistinguishable from ASCII.
#[must_use]
pub fn detect_format(data: &[u8]) -> SerializationFormat {
    match data.first() {
        Some(&first) if !first.is_ascii() => SerializationFormat::MessagePack,
        _ => SerializationFormat::Json,
    }
}
/// Deserializes `data` as `T`, picking the decoder with [`detect_format`].
///
/// Payloads detected as MessagePack are decoded with [`MsgpackSerializer`],
/// retrying with [`JsonSerializer`] if that fails. All other payloads are
/// decoded with [`JsonSerializer`].
///
/// # Errors
///
/// Returns the decoder's error. If both MessagePack and the JSON fallback fail,
/// the MessagePack error is returned. For data detected as MessagePack, the
/// leading byte can never start a JSON document, so the JSON parser's error is
/// uninformative.
pub fn deserialize_auto<T: DeserializeOwned>(data: &[u8]) -> Result<T, CheckpointError> {
    match detect_format(data) {
        SerializationFormat::MessagePack => MsgpackSerializer::new()
            .deserialize::<T>(data)
            .or_else(|msgpack_err| {
                JsonSerializer::new()
                    .deserialize::<T>(data)
                    .map_err(|_json_err| msgpack_err)
            }),
        SerializationFormat::Json => JsonSerializer::new().deserialize::<T>(data),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Encodes `value` with `ser` and decodes it again.
    fn round_trip<S: CheckpointSerializer>(ser: &S, value: &Value) -> Value {
        let bytes = ser.serialize_value(value).unwrap();
        ser.deserialize_value(&bytes).unwrap()
    }

    #[test]
    fn test_msgpack_serializer_roundtrip() {
        let original = json!({"key": "value", "number": 42});
        assert_eq!(round_trip(&MsgpackSerializer::new(), &original), original);
    }

    #[test]
    fn test_json_serializer_roundtrip() {
        let original = json!({"key": "value", "number": 42});
        assert_eq!(round_trip(&JsonSerializer::new(), &original), original);
    }

    #[test]
    fn test_json_plus_serializer_pretty() {
        let ser = JsonPlusSerializer::new();
        let original = json!({"key": "value", "nested": {"a": 1}});
        let bytes = ser.serialize_value(&original).unwrap();
        // Pretty-printed output spans multiple lines.
        let text = std::str::from_utf8(&bytes).unwrap();
        assert!(text.contains('\n'));
        assert_eq!(ser.deserialize_value(&bytes).unwrap(), original);
    }

    #[test]
    fn test_checkpoint_detect_format_json() {
        assert_eq!(detect_format(b"{\"key\":\"value\"}"), SerializationFormat::Json);
    }

    #[test]
    fn test_checkpoint_detect_format_msgpack() {
        let bytes = MsgpackSerializer::new()
            .serialize_value(&json!({"key": "value"}))
            .unwrap();
        assert_eq!(detect_format(&bytes), SerializationFormat::MessagePack);
    }

    #[test]
    fn test_checkpoint_detect_format_empty() {
        assert_eq!(detect_format(&[]), SerializationFormat::Json);
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_serializer() {
        use aes_gcm::aead::rand_core::RngCore;

        let mut key = [0u8; 32];
        OsRng.fill_bytes(&mut key);
        let serializer = EncryptedSerializer::new(JsonSerializer::new(), &key);

        let original = json!({"secret": "data"});
        let encrypted = serializer.serialize_value(&original).unwrap();
        // The 12-byte nonce prefix alone makes the payload longer than the JSON text.
        assert!(encrypted.len() > original.to_string().len());
        assert_eq!(serializer.deserialize_value(&encrypted).unwrap(), original);
    }

    #[test]
    fn test_serialization_format_eq() {
        use super::SerializationFormat::{Json, MessagePack};

        assert_eq!(MessagePack, MessagePack);
        assert_eq!(Json, Json);
        assert_ne!(MessagePack, Json);
    }
}