Skip to main content

juncture_checkpoint/
serde.rs

1//! Checkpoint serialization
2//!
3//! Provides serialization abstractions and implementations for storing checkpoint data
4//! in multiple formats (`MessagePack`, JSON, and optionally encrypted).
5
6use crate::error::CheckpointError;
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9
10#[cfg(feature = "encryption")]
11use aes_gcm::{Aes256Gcm, Nonce, aead::Aead};
12
13#[cfg(feature = "encryption")]
14use aes_gcm::aead::{AeadCore, KeyInit, OsRng};
15
16#[cfg(feature = "encryption")]
17use aes_gcm::aead::generic_array::GenericArray;
18
19#[cfg(feature = "encryption")]
20use pbkdf2::pbkdf2_hmac;
21
22#[cfg(feature = "encryption")]
23use sha2::Sha256;
24
/// Serialization format
///
/// Identifies the wire format of stored checkpoint data. The type is a plain
/// fieldless tag, so it is `Copy` and `Hash`: it can be passed by value and
/// used as a map or set key without cloning.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SerializationFormat {
    /// `MessagePack` binary format (default, high performance)
    #[default]
    MessagePack,

    /// JSON text format (human readable, debug friendly)
    Json,
}
37
/// Serializer kind for checkpoint data
///
/// An enum-dispatched serializer that can be stored in checkpoint savers without
/// requiring dynamic dispatch. Defaults to `MessagePack`.
///
/// The type is a fieldless tag, so it is `Copy` and supports equality and
/// hashing; two savers can be checked for using the same serializer with `==`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum SerializerKind {
    /// `MessagePack` binary format (default, high performance)
    #[default]
    MessagePack,
    /// JSON text format (human readable, debug friendly)
    Json,
}
50
51impl SerializerKind {
52    /// Serialize a serializable value to bytes using this serializer
53    ///
54    /// # Errors
55    ///
56    /// Returns [`CheckpointError::Serialize`] if serialization fails.
57    pub fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
58        match self {
59            Self::MessagePack => {
60                rmp_serde::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
61            }
62            Self::Json => {
63                serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
64            }
65        }
66    }
67
68    /// Deserialize bytes to a deserializable type using this serializer
69    ///
70    /// # Errors
71    ///
72    /// Returns [`CheckpointError::Deserialize`] if deserialization fails.
73    pub fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
74        match self {
75            Self::MessagePack => {
76                rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
77            }
78            Self::Json => {
79                serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
80            }
81        }
82    }
83
84    /// Get the format identifier
85    #[must_use]
86    pub const fn format(&self) -> SerializationFormat {
87        match self {
88            Self::MessagePack => SerializationFormat::MessagePack,
89            Self::Json => SerializationFormat::Json,
90        }
91    }
92}
93
94/// Checkpoint serializer trait
95///
96/// Abstraction over different serialization formats, allowing checkpoint storage
97/// to use JSON, `MessagePack`, or custom serialization strategies.
98pub trait CheckpointSerializer: Send + Sync + 'static {
99    /// Serialize a JSON value to bytes
100    ///
101    /// # Errors
102    ///
103    /// Returns [`CheckpointError::Serialize`] if serialization fails.
104    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError>;
105
106    /// Deserialize bytes back to a JSON value
107    ///
108    /// # Errors
109    ///
110    /// Returns [`CheckpointError::Deserialize`] if deserialization fails.
111    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError>;
112
113    /// Serialize any serializable type to bytes
114    ///
115    /// # Errors
116    ///
117    /// Returns [`CheckpointError::Serialize`] if serialization fails.
118    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError>;
119
120    /// Deserialize bytes to any deserializable type
121    ///
122    /// # Errors
123    ///
124    /// Returns [`CheckpointError::Deserialize`] if deserialization fails.
125    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError>;
126
127    /// Get the format identifier
128    #[must_use]
129    fn format(&self) -> SerializationFormat;
130}
131
/// `MessagePack` serializer
///
/// Stateless marker type whose [`CheckpointSerializer`] implementation encodes
/// data in the binary `MessagePack` format. This is the default serializer for
/// production use.
#[derive(Clone, Debug, Default)]
pub struct MsgpackSerializer;

impl MsgpackSerializer {
    /// Create a new `MessagePack` serializer
    ///
    /// Equivalent to `MsgpackSerializer::default()`, but usable in `const`
    /// contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {}
    }
}
146
147impl CheckpointSerializer for MsgpackSerializer {
148    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
149        rmp_serde::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
150    }
151
152    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
153        rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
154    }
155
156    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
157        rmp_serde::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
158    }
159
160    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
161        rmp_serde::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
162    }
163
164    fn format(&self) -> SerializationFormat {
165        SerializationFormat::MessagePack
166    }
167}
168
/// JSON serializer
///
/// Stateless marker type whose [`CheckpointSerializer`] implementation encodes
/// data as compact JSON text. Useful for debugging and development
/// environments, where human-readable checkpoints matter more than size.
#[derive(Clone, Debug, Default)]
pub struct JsonSerializer;

impl JsonSerializer {
    /// Create a new JSON serializer
    ///
    /// Equivalent to `JsonSerializer::default()`, but usable in `const`
    /// contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {}
    }
}
183
184impl CheckpointSerializer for JsonSerializer {
185    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
186        serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
187    }
188
189    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
190        serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
191    }
192
193    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
194        serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
195    }
196
197    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
198        serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
199    }
200
201    fn format(&self) -> SerializationFormat {
202        SerializationFormat::Json
203    }
204}
205
/// JSON+ serializer (optionally pretty-printed)
///
/// Behaves like `JsonSerializer`, except that output can be pretty-printed
/// (newlines and indentation) for better human readability. Pretty-printing is
/// on unless it is disabled through [`JsonPlusSerializer::with_pretty`].
#[derive(Clone, Debug)]
pub struct JsonPlusSerializer {
    /// Whether serialized output is pretty-printed
    pretty: bool,
}

impl JsonPlusSerializer {
    /// Create a new JSON+ serializer with pretty-printing enabled
    #[must_use]
    pub const fn new() -> Self {
        Self::with_pretty(true)
    }

    /// Create a new JSON+ serializer, choosing whether to pretty-print
    ///
    /// With `pretty == false` the output is compact JSON.
    #[must_use]
    pub const fn with_pretty(pretty: bool) -> Self {
        Self { pretty }
    }
}

impl Default for JsonPlusSerializer {
    /// Same as [`JsonPlusSerializer::new`]: pretty-printing enabled.
    fn default() -> Self {
        Self::with_pretty(true)
    }
}
234
235impl CheckpointSerializer for JsonPlusSerializer {
236    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
237        if self.pretty {
238            serde_json::to_vec_pretty(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
239        } else {
240            serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
241        }
242    }
243
244    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
245        serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
246    }
247
248    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
249        if self.pretty {
250            serde_json::to_vec_pretty(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
251        } else {
252            serde_json::to_vec(value).map_err(|e| CheckpointError::Serialize(Box::new(e)))
253        }
254    }
255
256    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
257        serde_json::from_slice(data).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
258    }
259
260    fn format(&self) -> SerializationFormat {
261        SerializationFormat::Json
262    }
263}
264
/// Encrypted serializer wrapper
///
/// Wraps any inner serializer with AES-256-GCM encryption for secure storage:
/// data is first encoded by the inner serializer, and the resulting bytes are
/// then encrypted.
///
/// # Feature
///
/// Only available when the `encryption` feature is enabled.
#[cfg(feature = "encryption")]
#[derive(Clone)]
pub struct EncryptedSerializer<S: CheckpointSerializer> {
    /// Inner serializer that produces the plaintext bytes before encryption
    /// (and decodes them again after decryption)
    inner: S,
    /// AES-256-GCM cipher (initialized once at construction)
    cipher: Aes256Gcm,
}

#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer> EncryptedSerializer<S> {
    /// Create a new encrypted serializer
    ///
    /// Initializes the AES-256-GCM cipher from the provided 32-byte key.
    /// The cipher is stored and reused for all encryption/decryption operations.
    ///
    /// This constructor cannot fail or panic: the key length is enforced at
    /// compile time by the `&[u8; 32]` parameter type, which is exactly the
    /// length `GenericArray::from_slice` expects here.
    #[must_use]
    pub fn new(inner: S, key: &[u8; 32]) -> Self {
        Self {
            inner,
            cipher: Aes256Gcm::new(GenericArray::from_slice(key)),
        }
    }

    /// Create from a passphrase using PBKDF2
    ///
    /// Derives a 32-byte key from the provided passphrase and salt using
    /// PBKDF2-HMAC-SHA256 with 100,000 iterations, then initializes the cipher
    /// exactly as [`EncryptedSerializer::new`] does.
    ///
    /// The same passphrase, salt, and iteration count must be used to read data
    /// back: changing any of them derives a different key.
    ///
    /// NOTE(review): the iteration count was chosen from an earlier OWASP
    /// recommendation; current OWASP guidance for PBKDF2-HMAC-SHA256 is believed
    /// to be higher (600,000). Raising it needs a migration path for existing
    /// ciphertexts — confirm before changing.
    ///
    /// # Errors
    ///
    /// This function currently always returns `Ok`; key derivation as performed
    /// here has no failure path. The `Result` return type is kept so the
    /// signature stays stable for callers.
    pub fn from_passphrase(
        inner: S,
        passphrase: &str,
        salt: &[u8; 32],
    ) -> Result<Self, CheckpointError> {
        // Fixed work factor; part of the key-derivation contract (see docs above).
        const PBKDF2_ROUNDS: u32 = 100_000;

        let mut derived_key = [0u8; 32];
        pbkdf2_hmac::<Sha256>(passphrase.as_bytes(), salt, PBKDF2_ROUNDS, &mut derived_key);
        Ok(Self::new(inner, &derived_key))
    }
}
315
#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer + std::fmt::Debug> std::fmt::Debug for EncryptedSerializer<S> {
    /// Formats the wrapper for diagnostics.
    ///
    /// The inner serializer is printed with its own `Debug` output; the cipher
    /// is replaced by a fixed placeholder string rather than being formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("EncryptedSerializer");
        builder.field("inner", &self.inner);
        builder.field("cipher", &"<aes-256-gcm cipher>");
        builder.finish()
    }
}
325
#[cfg(feature = "encryption")]
impl<S: CheckpointSerializer> CheckpointSerializer for EncryptedSerializer<S> {
    /// Encode `value` with the inner serializer, then encrypt the result.
    ///
    /// Output layout: a freshly generated 12-byte nonce followed by the
    /// ciphertext returned by the cipher.
    fn serialize_value(&self, value: &serde_json::Value) -> Result<Vec<u8>, CheckpointError> {
        // Plaintext comes from the wrapped serializer.
        let encoded = self.inner.serialize_value(value)?;

        // A new random nonce is drawn for every call.
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);

        let sealed = self
            .cipher
            .encrypt(&nonce, encoded.as_slice())
            .map_err(|e| CheckpointError::serialize_msg(format!("Encryption failed: {e}")))?;

        // Frame as nonce || ciphertext so the reader can recover the nonce.
        let mut framed = Vec::with_capacity(12 + sealed.len());
        framed.extend_from_slice(&nonce);
        framed.extend_from_slice(&sealed);
        Ok(framed)
    }

    /// Decrypt `data` and decode the plaintext with the inner serializer.
    ///
    /// Expects the layout written by `serialize_value`: the first 12 bytes are
    /// the nonce, the remainder is the ciphertext. Input shorter than the nonce
    /// is rejected before any decryption is attempted.
    fn deserialize_value(&self, data: &[u8]) -> Result<serde_json::Value, CheckpointError> {
        /// Length in bytes of the nonce prefix.
        const NONCE_LEN: usize = 12;

        if data.len() < NONCE_LEN {
            return Err(CheckpointError::deserialize_msg(
                "Encrypted data too short".to_string(),
            ));
        }

        let (nonce_bytes, sealed) = data.split_at(NONCE_LEN);

        let plaintext = self
            .cipher
            .decrypt(Nonce::from_slice(nonce_bytes), sealed)
            .map_err(|e| CheckpointError::deserialize_msg(format!("Decryption failed: {e}")))?;

        self.inner.deserialize_value(&plaintext)
    }

    /// Encode any serializable value by first converting it to a
    /// [`serde_json::Value`] and then going through `serialize_value`.
    fn serialize<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CheckpointError> {
        serde_json::to_value(value)
            .map_err(|e| CheckpointError::Serialize(Box::new(e)))
            .and_then(|as_json| self.serialize_value(&as_json))
    }

    /// Decode any deserializable value: decrypt to a [`serde_json::Value`] via
    /// `deserialize_value`, then convert that value into `T`.
    fn deserialize<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CheckpointError> {
        self.deserialize_value(data).and_then(|as_json| {
            serde_json::from_value(as_json).map_err(|e| CheckpointError::Deserialize(Box::new(e)))
        })
    }

    /// Reports the inner serializer's format.
    ///
    /// This describes the plaintext encoding only; the bytes produced by this
    /// wrapper start with a random nonce, not with a format marker.
    fn format(&self) -> SerializationFormat {
        self.inner.format()
    }
}
386
387/// Detect serialization format from raw bytes
388///
389/// Examines the byte sequence to determine if it's `MessagePack` or JSON format.
390///
391/// # Examples
392///
393/// ```
394/// use juncture_checkpoint::serde::{detect_format, SerializationFormat};
395///
396/// let json_data = b"{\"key\":\"value\"}";
397/// let format = detect_format(json_data);
398/// assert_eq!(format, SerializationFormat::Json);
399/// ```
400#[must_use]
401pub fn detect_format(data: &[u8]) -> SerializationFormat {
402    // MessagePack format detection
403    // Common MessagePack markers: 0x82 (fixmap), 0x83 (fixmap), 0xde (map16)
404    // JSON format: starts with '{' (0x7b) or '[' (0x5b) or whitespace
405    if data.is_empty() {
406        return SerializationFormat::Json;
407    }
408
409    let first_byte = data[0];
410
411    // JSON format
412    if first_byte == b'{' || first_byte == b'[' || first_byte.is_ascii_whitespace() {
413        return SerializationFormat::Json;
414    }
415
416    // MessagePack format detection (heuristic)
417    // fixmap: 0x80-0x8f, fixarray: 0x90-0x9f, map16: 0xde, map32: 0xdf
418    // array16: 0xdc, array32: 0xdd
419    if (0x80..=0x9f).contains(&first_byte)
420        || first_byte == 0xde
421        || first_byte == 0xdf
422        || first_byte == 0xdc
423        || first_byte == 0xdd
424    {
425        return SerializationFormat::MessagePack;
426    }
427
428    // Default to JSON for unknown formats
429    SerializationFormat::Json
430}
431
432/// Deserialize bytes using format auto-detection
433///
434/// Detects whether the data is `MessagePack` or JSON, then deserializes
435/// using the appropriate serializer. Falls back to JSON deserialization
436/// if detection is ambiguous.
437///
438/// This function provides backwards compatibility when reading checkpoints
439/// that were written with a different serializer (e.g., old JSON data
440/// read by a saver now defaulting to `MessagePack`).
441///
442/// # Errors
443///
444/// Returns [`CheckpointError::Deserialize`] if neither `MessagePack` nor JSON
445/// deserialization succeeds.
446pub fn deserialize_auto<T: DeserializeOwned>(data: &[u8]) -> Result<T, CheckpointError> {
447    let format = detect_format(data);
448    match format {
449        SerializationFormat::MessagePack => {
450            // Try msgpack first, fall back to JSON if detection was wrong
451            MsgpackSerializer::new()
452                .deserialize::<T>(data)
453                .or_else(|_| JsonSerializer::new().deserialize::<T>(data))
454        }
455        SerializationFormat::Json => JsonSerializer::new().deserialize::<T>(data),
456    }
457}
458
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A JSON value survives a `MessagePack` encode/decode cycle unchanged.
    #[test]
    fn test_msgpack_serializer_roundtrip() {
        let codec = MsgpackSerializer::new();
        let original = json!({"key": "value", "number": 42});

        let bytes = codec.serialize_value(&original).unwrap();
        let restored = codec.deserialize_value(&bytes).unwrap();

        assert_eq!(original, restored);
    }

    /// A JSON value survives a JSON encode/decode cycle unchanged.
    #[test]
    fn test_json_serializer_roundtrip() {
        let codec = JsonSerializer::new();
        let original = json!({"key": "value", "number": 42});

        let bytes = codec.serialize_value(&original).unwrap();
        let restored = codec.deserialize_value(&bytes).unwrap();

        assert_eq!(original, restored);
    }

    /// The default JSON+ serializer emits multi-line output that still
    /// round-trips.
    #[test]
    fn test_json_plus_serializer_pretty() {
        let codec = JsonPlusSerializer::new();
        let original = json!({"key": "value", "nested": {"a": 1}});

        let bytes = codec.serialize_value(&original).unwrap();
        let text = std::str::from_utf8(&bytes).unwrap();

        // Pretty-printed output spans several lines.
        assert!(text.contains('\n'));

        let restored = codec.deserialize_value(&bytes).unwrap();
        assert_eq!(original, restored);
    }

    /// Text starting with `{` is classified as JSON.
    #[test]
    fn test_checkpoint_detect_format_json() {
        let detected = detect_format(b"{\"key\":\"value\"}");
        assert_eq!(detected, SerializationFormat::Json);
    }

    /// Bytes produced by the `MessagePack` serializer for a JSON object are
    /// classified as `MessagePack`.
    #[test]
    fn test_checkpoint_detect_format_msgpack() {
        let payload = json!({"key": "value"});
        let bytes = MsgpackSerializer::new().serialize_value(&payload).unwrap();

        let detected = detect_format(&bytes);
        assert_eq!(detected, SerializationFormat::MessagePack);
    }

    /// Empty input is classified as JSON.
    #[test]
    fn test_checkpoint_detect_format_empty() {
        let detected = detect_format(&[]);
        assert_eq!(detected, SerializationFormat::Json);
    }

    /// Encrypting then decrypting with the same random key returns the
    /// original value.
    #[cfg(feature = "encryption")]
    #[test]
    fn test_encrypted_serializer() {
        use aes_gcm::aead::rand_core::RngCore;

        let mut key = [0u8; 32];
        OsRng.fill_bytes(&mut key);

        let codec = EncryptedSerializer::new(JsonSerializer::new(), &key);
        let original = json!({"secret": "data"});

        let sealed = codec.serialize_value(&original).unwrap();

        // Encrypted data should be larger (nonce + ciphertext)
        assert!(sealed.len() > original.to_string().len());

        let opened = codec.deserialize_value(&sealed).unwrap();
        assert_eq!(original, opened);
    }

    /// `SerializationFormat` equality distinguishes the two variants.
    #[test]
    fn test_serialization_format_eq() {
        assert_eq!(
            SerializationFormat::MessagePack,
            SerializationFormat::MessagePack
        );
        assert_eq!(SerializationFormat::Json, SerializationFormat::Json);
        assert_ne!(SerializationFormat::MessagePack, SerializationFormat::Json);
    }
}
556
557// Rust guideline compliant 2026-05-20