1use crate::error::CheckpointError;
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9
10#[cfg(feature = "encryption")]
11use aes_gcm::{Aes256Gcm, Nonce, aead::Aead};
12
13#[cfg(feature = "encryption")]
14use aes_gcm::aead::{AeadCore, KeyInit, OsRng};
15
16#[cfg(feature = "encryption")]
17use aes_gcm::aead::generic_array::GenericArray;
18
19#[cfg(feature = "encryption")]
20use pbkdf2::pbkdf2_hmac;
21
22#[cfg(feature = "encryption")]
23use sha2::Sha256;
24
/// On-disk encoding used for checkpoint payloads.
///
/// MessagePack is the default. `Hash` is derived alongside equality so a format can be
/// used as a key in hash maps and sets.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum SerializationFormat {
    /// Compact binary MessagePack encoding (the default).
    #[default]
    MessagePack,

    /// JSON text encoding.
    Json,
}
37
/// Runtime selector for one of the built-in encoders.
///
/// Equality and hashing are derived so two kinds can be compared (or used as map keys)
/// directly, rather than only indirectly through `SerializerKind::format`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum SerializerKind {
    /// Encode with `rmp_serde` (the default).
    #[default]
    MessagePack,
    /// Encode with `serde_json`.
    Json,
}
50
51impl SerializerKind {
52 pub fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
58 match self {
59 Self::MessagePack => {
60 rmp_serde::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
61 }
62 Self::Json => {
63 serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
64 }
65 }
66 }
67
68 pub fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
74 match self {
75 Self::MessagePack => {
76 rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
77 }
78 Self::Json => {
79 serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
80 }
81 }
82 }
83
84 #[must_use]
86 pub const fn format(&self) -> SerializationFormat {
87 match self {
88 Self::MessagePack => SerializationFormat::MessagePack,
89 Self::Json => SerializationFormat::Json,
90 }
91 }
92}
93
94pub trait CheckpointSerializer: Send + Sync + 'static {
99 fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError>;
105
106 fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError>;
112
113 fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError>;
119
120 fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError>;
126
127 #[must_use]
129 fn format(&self) -> SerializationFormat;
130}
131
/// Stateless checkpoint serializer backed by MessagePack (`rmp_serde`).
#[derive(Debug, Default, Clone)]
pub struct MsgpackSerializer;

impl MsgpackSerializer {
    /// Returns the (zero-sized) MessagePack serializer.
    #[must_use]
    pub const fn new() -> Self {
        MsgpackSerializer
    }
}
146
147impl CheckpointSerializer for MsgpackSerializer {
148 fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
149 rmp_serde::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
150 }
151
152 fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
153 rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
154 }
155
156 fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
157 rmp_serde::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
158 }
159
160 fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
161 rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
162 }
163
164 fn format(&self) -> SerializationFormat {
165 SerializationFormat::MessagePack
166 }
167}
168
/// Stateless checkpoint serializer producing compact JSON (`serde_json`).
#[derive(Debug, Default, Clone)]
pub struct JsonSerializer;

impl JsonSerializer {
    /// Returns the (zero-sized) JSON serializer.
    #[must_use]
    pub const fn new() -> Self {
        JsonSerializer
    }
}
183
184impl CheckpointSerializer for JsonSerializer {
185 fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
186 serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
187 }
188
189 fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
190 serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
191 }
192
193 fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
194 serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
195 }
196
197 fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
198 serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
199 }
200
201 fn format(&self) -> SerializationFormat {
202 SerializationFormat::Json
203 }
204}
205
/// JSON serializer whose output can optionally be pretty-printed.
///
/// Pretty printing is on for [`JsonPlusSerializer::new`] and `Default`; use
/// [`JsonPlusSerializer::with_pretty`] to choose explicitly.
#[derive(Debug, Clone)]
pub struct JsonPlusSerializer {
    /// `true` selects `serde_json::to_vec_pretty`, `false` selects `serde_json::to_vec`.
    pretty: bool,
}

impl JsonPlusSerializer {
    /// Creates a serializer with pretty printing enabled.
    #[must_use]
    pub const fn new() -> Self {
        Self::with_pretty(true)
    }

    /// Creates a serializer with pretty printing set to `pretty`.
    #[must_use]
    pub const fn with_pretty(pretty: bool) -> Self {
        Self { pretty }
    }
}

impl Default for JsonPlusSerializer {
    /// Same as [`JsonPlusSerializer::new`]: pretty printing enabled.
    fn default() -> Self {
        Self::new()
    }
}
234
235impl CheckpointSerializer for JsonPlusSerializer {
236 fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
237 if self.pretty {
238 serde_json::to_vec_pretty(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
239 } else {
240 serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
241 }
242 }
243
244 fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
245 serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
246 }
247
248 fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
249 if self.pretty {
250 serde_json::to_vec_pretty(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
251 } else {
252 serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
253 }
254 }
255
256 fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
257 serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
258 }
259
260 fn format(&self) -> SerializationFormat {
261 SerializationFormat::Json
262 }
263}
264
#[cfg(feature = "encryption")]
/// Wraps another [`CheckpointSerializer`] and seals its output with AES-256-GCM.
///
/// Each encrypted payload is laid out as a 12-byte random nonce followed by the
/// AES-GCM ciphertext (which carries the authentication tag).
#[derive(Clone)]
pub struct EncryptedSerializer<S: CheckpointSerializer> {
    /// Serializer that produces (and later parses) the plaintext bytes.
    inner: S,
    /// Cipher keyed once at construction; never printed by `Debug`.
    cipher: Aes256Gcm,
}
280
#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer> EncryptedSerializer<S> {
    /// Wraps `inner` so that its output is encrypted with AES-256-GCM under `key`.
    pub fn new(inner: S, key: &[u8; 32]) -> Self {
        Self {
            inner,
            cipher: Aes256Gcm::new(GenericArray::from_slice(key)),
        }
    }

    /// Derives a 256-bit key from `passphrase` and `salt` with PBKDF2-HMAC-SHA256
    /// (100,000 rounds) and wraps `inner` with it.
    ///
    /// Data encrypted earlier can only be decrypted with the same passphrase, salt, and
    /// round count, so the round count must not change without a migration path.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub fn from_passphrase(
        inner: S,
        passphrase: &str,
        salt: &[u8; 32],
    ) -> Result<Self, CheckpointError> {
        const PBKDF2_ROUNDS: u32 = 100_000;

        let mut derived_key = [0u8; 32];
        pbkdf2_hmac::<Sha256>(passphrase.as_bytes(), salt, PBKDF2_ROUNDS, &mut derived_key);
        Ok(Self::new(inner, &derived_key))
    }
}
315
#[cfg(feature = "encryption")]
impl<S> std::fmt::Debug for EncryptedSerializer<S>
where
    S: CheckpointSerializer + std::fmt::Debug,
{
    /// Shows the wrapped serializer but prints a fixed placeholder in place of the
    /// cipher, so no key state is formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EncryptedSerializer");
        out.field("inner", &self.inner);
        out.field("cipher", &"<aes-256-gcm cipher>");
        out.finish()
    }
}
325
#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer> CheckpointSerializer for EncryptedSerializer<S> {
    /// Serializes with the inner serializer, then encrypts.
    ///
    /// Output layout: 12-byte nonce || ciphertext (with authentication tag).
    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
        let plaintext = self.inner.serialize_value(value)?;

        // Fresh random 96-bit nonce per payload; stored in the clear ahead of the ciphertext.
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        let sealed = self
            .cipher
            .encrypt(&nonce, plaintext.as_slice())
            .map_err(|e| CheckpointError::serialize_msg(format!("Encryption failed: {e}")))?;

        let mut framed = Vec::with_capacity(nonce.len() + sealed.len());
        framed.extend_from_slice(&nonce);
        framed.extend_from_slice(&sealed);
        Ok(framed)
    }

    /// Splits off the nonce, decrypts and authenticates, then hands the plaintext to the
    /// inner serializer.
    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
        const NONCE_LEN: usize = 12;

        if data.len() < NONCE_LEN {
            return Err(CheckpointError::deserialize_msg(
                "Encrypted data too short".to_string(),
            ));
        }
        let (nonce_bytes, ciphertext) = data.split_at(NONCE_LEN);

        let plaintext = self
            .cipher
            .decrypt(Nonce::from_slice(nonce_bytes), ciphertext)
            .map_err(|e| CheckpointError::deserialize_msg(format!("Decryption failed: {e}")))?;

        self.inner.deserialize_value(&plaintext)
    }

    // Typed values are routed through `serde_json::Value` so that encryption always
    // happens in `serialize_value` / `deserialize_value`.
    // NOTE(review): this means `T` must be representable as a `serde_json::Value` even
    // when `inner` is a MessagePack serializer — confirm this is intended.
    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
        serde_json::to_value(value)
            .map_err(|e| CheckpointError::Serialize(Box::new(e)))
            .and_then(|json| self.serialize_value(&json))
    }

    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
        self.deserialize_value(data).and_then(|json| {
            serde_json::from_value::<T>(json).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
        })
    }

    /// Reports the inner serializer's format (the plaintext encoding).
    fn format(&self) -> SerializationFormat {
        self.inner.format()
    }
}
386
387#[must_use]
401pub fn detect_format(data: &[u8]) -> SerializationFormat {
402 if data.is_empty() {
406 return SerializationFormat::Json;
407 }
408
409 let first_byte = data[0];
410
411 if first_byte == b'{' || first_byte == b'[' || first_byte.is_ascii_whitespace() {
413 return SerializationFormat::Json;
414 }
415
416 if (0x80..=0x9f).contains(&first_byte)
420 || first_byte == 0xde
421 || first_byte == 0xdf
422 || first_byte == 0xdc
423 || first_byte == 0xdd
424 {
425 return SerializationFormat::MessagePack;
426 }
427
428 SerializationFormat::Json
430}
431
432pub fn deserialize_auto<T: DeserializeOwned>(data: &[u8]) -> Result<T, CheckpointError> {
447 let format = detect_format(data);
448 match format {
449 SerializationFormat::MessagePack => {
450 MsgpackSerializer::new()
452 .deserialize::<T>(data)
453 .or_else(|_| JsonSerializer::new().deserialize::<T>(data))
454 }
455 SerializationFormat::Json => JsonSerializer::new().deserialize::<T>(data),
456 }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_msgpack_serializer_roundtrip() {
        let ser = MsgpackSerializer::new();
        let original = json!({"key": "value", "number": 42});

        let serialized_data = ser.serialize_value(&original).unwrap();
        let deserialized = ser.deserialize_value(&serialized_data).unwrap();

        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_json_serializer_roundtrip() {
        let ser = JsonSerializer::new();
        let original = json!({"key": "value", "number": 42});

        let serialized_data = ser.serialize_value(&original).unwrap();
        let deserialized = ser.deserialize_value(&serialized_data).unwrap();

        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_json_plus_serializer_pretty() {
        let ser = JsonPlusSerializer::new();
        let original = json!({"key": "value", "nested": {"a": 1}});

        let serialized_data = ser.serialize_value(&original).unwrap();
        let serialized_str = std::str::from_utf8(&serialized_data).unwrap();

        assert!(serialized_str.contains('\n'));

        let deserialized = ser.deserialize_value(&serialized_data).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_json_plus_serializer_compact() {
        let ser = JsonPlusSerializer::with_pretty(false);
        let original = json!({"key": "value", "nested": {"a": 1}});

        let serialized_data = ser.serialize_value(&original).unwrap();
        // Compact output must not contain line breaks.
        assert!(!serialized_data.contains(&b'\n'));

        let deserialized = ser.deserialize_value(&serialized_data).unwrap();
        assert_eq!(original, deserialized);
    }

    #[test]
    fn test_serializer_kind_roundtrip_and_format() {
        let original = json!({"key": "value", "list": [1, 2, 3]});

        let kinds = [SerializerKind::MessagePack, SerializerKind::Json];
        for kind in &kinds {
            let bytes = kind.serialize(&original).unwrap();
            let restored: serde_json::Value = kind.deserialize(&bytes).unwrap();
            assert_eq!(original, restored);
        }

        assert_eq!(
            SerializerKind::MessagePack.format(),
            SerializationFormat::MessagePack
        );
        assert_eq!(SerializerKind::Json.format(), SerializationFormat::Json);
    }

    #[test]
    fn test_checkpoint_detect_format_json() {
        let json_data = b"{\"key\":\"value\"}";
        let format = detect_format(json_data);
        assert_eq!(format, SerializationFormat::Json);
    }

    #[test]
    fn test_checkpoint_detect_format_msgpack() {
        let serializer = MsgpackSerializer::new();
        let value = json!({"key": "value"});
        let msgpack_data = serializer.serialize_value(&value).unwrap();

        let format = detect_format(&msgpack_data);
        assert_eq!(format, SerializationFormat::MessagePack);
    }

    #[test]
    fn test_checkpoint_detect_format_empty() {
        let format = detect_format(&[]);
        assert_eq!(format, SerializationFormat::Json);
    }

    #[test]
    fn test_deserialize_auto_detects_both_formats() {
        let original = json!({"key": "value", "number": 42});

        let msgpack_data = MsgpackSerializer::new().serialize_value(&original).unwrap();
        let from_msgpack: serde_json::Value = deserialize_auto(&msgpack_data).unwrap();
        assert_eq!(original, from_msgpack);

        let json_data = JsonSerializer::new().serialize_value(&original).unwrap();
        let from_json: serde_json::Value = deserialize_auto(&json_data).unwrap();
        assert_eq!(original, from_json);
    }

    #[test]
    fn test_deserialize_auto_rejects_garbage() {
        let not_json: Result<serde_json::Value, _> = deserialize_auto(b"not json");
        assert!(not_json.is_err());

        // A fixmap header announcing one entry, with no entry following it.
        let truncated: Result<serde_json::Value, _> = deserialize_auto(&[0x81u8]);
        assert!(truncated.is_err());
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_serializer() {
        use aes_gcm::aead::rand_core::RngCore;

        let inner = JsonSerializer::new();
        let mut key = [0u8; 32];
        OsRng.fill_bytes(&mut key);

        let serializer = EncryptedSerializer::new(inner, &key);
        let original = json!({"secret": "data"});

        let encrypted = serializer.serialize_value(&original).unwrap();

        assert!(encrypted.len() > original.to_string().len());

        let decrypted = serializer.deserialize_value(&encrypted).unwrap();
        assert_eq!(original, decrypted);
    }

    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_serializer_rejects_bad_input() {
        let serializer = EncryptedSerializer::new(JsonSerializer::new(), &[7u8; 32]);
        let original = json!({"secret": "data"});

        // Flipping a bit of the ciphertext/tag must make authentication fail.
        let mut tampered = serializer.serialize_value(&original).unwrap();
        let last = tampered.len() - 1;
        tampered[last] ^= 0x01;
        assert!(serializer.deserialize_value(&tampered).is_err());

        // Input shorter than the 12-byte nonce is rejected up front.
        assert!(serializer.deserialize_value(&tampered[..11]).is_err());

        // A serializer keyed differently cannot decrypt an untouched payload.
        let untouched = serializer.serialize_value(&original).unwrap();
        let other = EncryptedSerializer::new(JsonSerializer::new(), &[8u8; 32]);
        assert!(other.deserialize_value(&untouched).is_err());
    }

    #[test]
    fn test_serialization_format_eq() {
        assert_eq!(
            SerializationFormat::MessagePack,
            SerializationFormat::MessagePack
        );
        assert_eq!(SerializationFormat::Json, SerializationFormat::Json);
        assert_ne!(SerializationFormat::MessagePack, SerializationFormat::Json);
    }
}
556
557