// nomad_protocol/extensions/checkpoint.rs

//! Checkpoint extension (0x0006)
//!
//! Provides state snapshots for recovery, initial sync, or periodic
//! consistency verification. A full checkpoint contains the complete state
//! at a specific point in time; an incremental checkpoint carries only the
//! delta from an earlier (base) checkpoint.
//!
//! Wire format for extension negotiation:
//! ```text
//! +0  Flags (1 byte)
//!     - bit 0: Client can request checkpoints
//!     - bit 1: Server sends periodic checkpoints
//!     - bit 2: Incremental checkpoints supported (delta from previous)
//!     - bit 3: Compressed checkpoints supported
//! +1  Max checkpoint size (4 bytes LE32) - maximum uncompressed size
//! +5  Checkpoint interval hint (2 bytes LE16) - suggested seconds between checkpoints
//! ```
//!
//! Wire format for checkpoint frame:
//! ```text
//! +0   Checkpoint ID (8 bytes LE64)
//! +8   Flags (1 byte)
//!      - bit 0: Is compressed
//!      - bit 1: Is incremental (delta from base_id)
//!      - bit 2: Has signature
//! +9   State number (8 bytes LE64) - sync state this checkpoint represents
//! +17  Base checkpoint ID (8 bytes LE64) - for incremental, 0 otherwise
//! +25  Uncompressed size (4 bytes LE32)
//! +29  Payload length (4 bytes LE32)
//! +33  Payload data (Payload length bytes)
//! +N   [Optional] Signature (32 bytes) if Has signature flag set,
//!      where N = 33 + Payload length
//! ```

33use super::negotiation::{ext_type, Extension, NegotiationError};

/// Feature bits carried in the first byte of the checkpoint extension
/// negotiation payload (the `flags` field of `CheckpointConfig`).
pub mod checkpoint_config_flags {
    /// Bit 0: the client may ask for a checkpoint on demand.
    pub const CLIENT_REQUEST: u8 = 1 << 0;
    /// Bit 1: the server emits checkpoints periodically.
    pub const PERIODIC: u8 = 1 << 1;
    /// Bit 2: incremental checkpoints (a delta against a base checkpoint) are understood.
    pub const INCREMENTAL: u8 = 1 << 2;
    /// Bit 3: compressed checkpoint payloads are understood.
    pub const COMPRESSED: u8 = 1 << 3;
}

/// Bits of the one-byte flags field at offset 8 of a checkpoint frame header.
pub mod checkpoint_frame_flags {
    /// Bit 0: the payload bytes are compressed.
    pub const COMPRESSED: u8 = 1 << 0;
    /// Bit 1: the payload is a delta against the base checkpoint.
    pub const INCREMENTAL: u8 = 1 << 1;
    /// Bit 2: a 32-byte signature follows the payload.
    pub const SIGNED: u8 = 1 << 2;
}

/// Checkpoint capabilities advertised by one peer during extension negotiation.
///
/// Serialized as the 7-byte extension payload: flags, max size, interval hint.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CheckpointConfig {
    /// Bitwise OR of `checkpoint_config_flags` values.
    pub flags: u8,
    /// Largest uncompressed checkpoint this peer accepts, in bytes.
    pub max_size: u32,
    /// Suggested time between periodic checkpoints, in seconds.
    pub interval_secs: u16,
}
67
68impl Default for CheckpointConfig {
69    fn default() -> Self {
70        Self {
71            flags: checkpoint_config_flags::CLIENT_REQUEST | checkpoint_config_flags::COMPRESSED,
72            max_size: 16 * 1024 * 1024, // 16 MB
73            interval_secs: 300,          // 5 minutes
74        }
75    }
76}
77
78impl CheckpointConfig {
79    /// Create config with all features
80    pub fn full() -> Self {
81        Self {
82            flags: checkpoint_config_flags::CLIENT_REQUEST
83                | checkpoint_config_flags::PERIODIC
84                | checkpoint_config_flags::INCREMENTAL
85                | checkpoint_config_flags::COMPRESSED,
86            max_size: 64 * 1024 * 1024, // 64 MB
87            interval_secs: 60,
88        }
89    }
90
91    /// Check if client can request checkpoints
92    pub fn supports_client_request(&self) -> bool {
93        (self.flags & checkpoint_config_flags::CLIENT_REQUEST) != 0
94    }
95
96    /// Check if server sends periodic checkpoints
97    pub fn supports_periodic(&self) -> bool {
98        (self.flags & checkpoint_config_flags::PERIODIC) != 0
99    }
100
101    /// Check if incremental checkpoints are supported
102    pub fn supports_incremental(&self) -> bool {
103        (self.flags & checkpoint_config_flags::INCREMENTAL) != 0
104    }
105
106    /// Check if compressed checkpoints are supported
107    pub fn supports_compressed(&self) -> bool {
108        (self.flags & checkpoint_config_flags::COMPRESSED) != 0
109    }
110
111    /// Wire size
112    pub const fn wire_size() -> usize {
113        7 // flags(1) + max_size(4) + interval(2)
114    }
115
116    /// Encode to extension
117    pub fn to_extension(&self) -> Extension {
118        let mut data = Vec::with_capacity(Self::wire_size());
119        data.push(self.flags);
120        data.extend_from_slice(&self.max_size.to_le_bytes());
121        data.extend_from_slice(&self.interval_secs.to_le_bytes());
122        Extension::new(ext_type::CHECKPOINT, data)
123    }
124
125    /// Decode from extension
126    pub fn from_extension(ext: &Extension) -> Option<Self> {
127        if ext.ext_type != ext_type::CHECKPOINT || ext.data.len() < Self::wire_size() {
128            return None;
129        }
130        Some(Self {
131            flags: ext.data[0],
132            max_size: u32::from_le_bytes([ext.data[1], ext.data[2], ext.data[3], ext.data[4]]),
133            interval_secs: u16::from_le_bytes([ext.data[5], ext.data[6]]),
134        })
135    }
136
137    /// Negotiate between client and server
138    pub fn negotiate(client: &Self, server: &Self) -> Self {
139        Self {
140            flags: client.flags & server.flags,
141            max_size: client.max_size.min(server.max_size),
142            interval_secs: client.interval_secs.max(server.interval_secs), // Use longer interval
143        }
144    }
145}

/// Size in bytes of the fixed checkpoint frame header: checkpoint id (8) +
/// flags (1) + state number (8) + base id (8) + uncompressed size (4) +
/// payload length (4).
pub const CHECKPOINT_HEADER_SIZE: usize = 8 + 1 + 8 + 8 + 4 + 4;

/// Fixed-size header that precedes every checkpoint payload on the wire.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct CheckpointHeader {
    /// Identifier of this checkpoint.
    pub checkpoint_id: u64,
    /// Bitwise OR of `checkpoint_frame_flags` values.
    pub flags: u8,
    /// Sync state number captured by this checkpoint.
    pub state_num: u64,
    /// Checkpoint the delta applies to; 0 for a full checkpoint.
    pub base_id: u64,
    /// Payload length before compression, in bytes.
    pub uncompressed_size: u32,
    /// Number of payload bytes actually present in the frame.
    pub payload_len: u32,
}
166
167impl CheckpointHeader {
168    /// Create a full checkpoint header
169    pub fn full(checkpoint_id: u64, state_num: u64, size: u32) -> Self {
170        Self {
171            checkpoint_id,
172            flags: 0,
173            state_num,
174            base_id: 0,
175            uncompressed_size: size,
176            payload_len: size,
177        }
178    }
179
180    /// Create an incremental checkpoint header
181    pub fn incremental(checkpoint_id: u64, state_num: u64, base_id: u64, size: u32) -> Self {
182        Self {
183            checkpoint_id,
184            flags: checkpoint_frame_flags::INCREMENTAL,
185            state_num,
186            base_id,
187            uncompressed_size: size,
188            payload_len: size,
189        }
190    }
191
192    /// Check if checkpoint is compressed
193    pub fn is_compressed(&self) -> bool {
194        (self.flags & checkpoint_frame_flags::COMPRESSED) != 0
195    }
196
197    /// Check if checkpoint is incremental
198    pub fn is_incremental(&self) -> bool {
199        (self.flags & checkpoint_frame_flags::INCREMENTAL) != 0
200    }
201
202    /// Check if checkpoint has signature
203    pub fn is_signed(&self) -> bool {
204        (self.flags & checkpoint_frame_flags::SIGNED) != 0
205    }
206
207    /// Set compressed flag and actual payload size
208    pub fn set_compressed(&mut self, compressed_len: u32) {
209        self.flags |= checkpoint_frame_flags::COMPRESSED;
210        self.payload_len = compressed_len;
211    }
212
213    /// Set signed flag
214    pub fn set_signed(&mut self) {
215        self.flags |= checkpoint_frame_flags::SIGNED;
216    }
217
218    /// Encode header to bytes
219    pub fn encode(&self) -> [u8; CHECKPOINT_HEADER_SIZE] {
220        let mut buf = [0u8; CHECKPOINT_HEADER_SIZE];
221        buf[0..8].copy_from_slice(&self.checkpoint_id.to_le_bytes());
222        buf[8] = self.flags;
223        buf[9..17].copy_from_slice(&self.state_num.to_le_bytes());
224        buf[17..25].copy_from_slice(&self.base_id.to_le_bytes());
225        buf[25..29].copy_from_slice(&self.uncompressed_size.to_le_bytes());
226        buf[29..33].copy_from_slice(&self.payload_len.to_le_bytes());
227        buf
228    }
229
230    /// Decode header from bytes
231    pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
232        if data.len() < CHECKPOINT_HEADER_SIZE {
233            return Err(NegotiationError::TooShort {
234                expected: CHECKPOINT_HEADER_SIZE,
235                actual: data.len(),
236            });
237        }
238
239        Ok(Self {
240            checkpoint_id: u64::from_le_bytes(
241                data[0..8].try_into().expect("length checked"),
242            ),
243            flags: data[8],
244            state_num: u64::from_le_bytes(data[9..17].try_into().expect("length checked")),
245            base_id: u64::from_le_bytes(data[17..25].try_into().expect("length checked")),
246            uncompressed_size: u32::from_le_bytes(
247                data[25..29].try_into().expect("length checked"),
248            ),
249            payload_len: u32::from_le_bytes(data[29..33].try_into().expect("length checked")),
250        })
251    }
252}
253
254/// A complete checkpoint (header + payload)
255#[derive(Debug, Clone)]
256pub struct Checkpoint {
257    /// Checkpoint header
258    pub header: CheckpointHeader,
259    /// Checkpoint payload (may be compressed)
260    pub payload: Vec<u8>,
261    /// Optional signature
262    pub signature: Option<[u8; 32]>,
263}
264
265impl Checkpoint {
266    /// Create a full checkpoint
267    pub fn new(checkpoint_id: u64, state_num: u64, data: Vec<u8>) -> Self {
268        let header = CheckpointHeader::full(checkpoint_id, state_num, data.len() as u32);
269        Self {
270            header,
271            payload: data,
272            signature: None,
273        }
274    }
275
276    /// Total wire size
277    pub fn wire_size(&self) -> usize {
278        CHECKPOINT_HEADER_SIZE
279            + self.payload.len()
280            + if self.signature.is_some() { 32 } else { 0 }
281    }
282
283    /// Encode to bytes
284    pub fn encode(&self) -> Vec<u8> {
285        let mut buf = Vec::with_capacity(self.wire_size());
286        buf.extend_from_slice(&self.header.encode());
287        buf.extend_from_slice(&self.payload);
288        if let Some(sig) = &self.signature {
289            buf.extend_from_slice(sig);
290        }
291        buf
292    }
293
294    /// Decode from bytes
295    pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
296        let header = CheckpointHeader::decode(data)?;
297
298        let payload_start = CHECKPOINT_HEADER_SIZE;
299        let payload_end = payload_start + header.payload_len as usize;
300
301        if data.len() < payload_end {
302            return Err(NegotiationError::TooShort {
303                expected: payload_end,
304                actual: data.len(),
305            });
306        }
307
308        let payload = data[payload_start..payload_end].to_vec();
309
310        let signature = if header.is_signed() {
311            let sig_start = payload_end;
312            let sig_end = sig_start + 32;
313            if data.len() < sig_end {
314                return Err(NegotiationError::TooShort {
315                    expected: sig_end,
316                    actual: data.len(),
317                });
318            }
319            Some(data[sig_start..sig_end].try_into().expect("length checked"))
320        } else {
321            None
322        };
323
324        Ok(Self {
325            header,
326            payload,
327            signature,
328        })
329    }
330}

/// What a peer asks for when it requests a checkpoint.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CheckpointRequest {
    /// The most recent full checkpoint (tag 0x00).
    Latest,
    /// The checkpoint for the given sync state number (tag 0x01).
    AtState(u64),
    /// An incremental checkpoint against the given base checkpoint id (tag 0x02).
    IncrementalFrom(u64),
}
342
343impl CheckpointRequest {
344    /// Encode to bytes
345    pub fn encode(&self) -> Vec<u8> {
346        match self {
347            Self::Latest => vec![0x00],
348            Self::AtState(state) => {
349                let mut buf = vec![0x01];
350                buf.extend_from_slice(&state.to_le_bytes());
351                buf
352            }
353            Self::IncrementalFrom(base) => {
354                let mut buf = vec![0x02];
355                buf.extend_from_slice(&base.to_le_bytes());
356                buf
357            }
358        }
359    }
360
361    /// Decode from bytes
362    pub fn decode(data: &[u8]) -> Result<Self, NegotiationError> {
363        if data.is_empty() {
364            return Err(NegotiationError::TooShort {
365                expected: 1,
366                actual: 0,
367            });
368        }
369
370        match data[0] {
371            0x00 => Ok(Self::Latest),
372            0x01 => {
373                if data.len() < 9 {
374                    return Err(NegotiationError::TooShort {
375                        expected: 9,
376                        actual: data.len(),
377                    });
378                }
379                let state = u64::from_le_bytes(data[1..9].try_into().expect("length checked"));
380                Ok(Self::AtState(state))
381            }
382            0x02 => {
383                if data.len() < 9 {
384                    return Err(NegotiationError::TooShort {
385                        expected: 9,
386                        actual: data.len(),
387                    });
388                }
389                let base = u64::from_le_bytes(data[1..9].try_into().expect("length checked"));
390                Ok(Self::IncrementalFrom(base))
391            }
392            _ => Err(NegotiationError::InvalidData),
393        }
394    }
395}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = CheckpointConfig::default();
        assert!(config.supports_client_request());
        assert!(!config.supports_periodic());
        assert!(config.supports_compressed());
    }

    #[test]
    fn test_config_extension_roundtrip() {
        let config = CheckpointConfig {
            flags: checkpoint_config_flags::CLIENT_REQUEST | checkpoint_config_flags::INCREMENTAL,
            max_size: 8 * 1024 * 1024,
            interval_secs: 120,
        };

        let ext = config.to_extension();
        let decoded = CheckpointConfig::from_extension(&ext).unwrap();
        assert_eq!(decoded, config);
    }

    #[test]
    fn test_config_from_extension_rejects_short_data() {
        let ext = Extension::new(
            ext_type::CHECKPOINT,
            vec![0u8; CheckpointConfig::wire_size() - 1],
        );
        assert!(CheckpointConfig::from_extension(&ext).is_none());
    }

    #[test]
    fn test_config_negotiate() {
        let client = CheckpointConfig {
            flags: checkpoint_config_flags::CLIENT_REQUEST | checkpoint_config_flags::COMPRESSED,
            max_size: 32 * 1024 * 1024,
            interval_secs: 60,
        };
        let server = CheckpointConfig {
            flags: checkpoint_config_flags::CLIENT_REQUEST | checkpoint_config_flags::PERIODIC,
            max_size: 16 * 1024 * 1024,
            interval_secs: 300,
        };

        let result = CheckpointConfig::negotiate(&client, &server);
        assert!(result.supports_client_request());
        assert!(!result.supports_compressed()); // Only client
        assert!(!result.supports_periodic()); // Only server
        assert_eq!(result.max_size, 16 * 1024 * 1024);
        assert_eq!(result.interval_secs, 300); // Use longer
    }

    #[test]
    fn test_config_negotiate_full_with_default() {
        let result =
            CheckpointConfig::negotiate(&CheckpointConfig::full(), &CheckpointConfig::default());
        assert_eq!(result.flags, CheckpointConfig::default().flags);
        assert_eq!(result.max_size, 16 * 1024 * 1024);
        assert_eq!(result.interval_secs, 300);
    }

    #[test]
    fn test_header_roundtrip() {
        let header = CheckpointHeader::full(12345, 100, 4096);
        let encoded = header.encode();
        let decoded = CheckpointHeader::decode(&encoded).unwrap();
        assert_eq!(decoded, header);
    }

    #[test]
    fn test_header_decode_too_short() {
        let short = [0u8; CHECKPOINT_HEADER_SIZE - 1];
        assert!(matches!(
            CheckpointHeader::decode(&short),
            Err(NegotiationError::TooShort {
                expected: 33,
                actual: 32
            })
        ));
    }

    #[test]
    fn test_incremental_header() {
        let header = CheckpointHeader::incremental(200, 150, 100, 1024);
        assert!(header.is_incremental());
        assert_eq!(header.base_id, 100);

        let encoded = header.encode();
        let decoded = CheckpointHeader::decode(&encoded).unwrap();
        assert!(decoded.is_incremental());
        assert_eq!(decoded.base_id, 100);
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let checkpoint = Checkpoint::new(42, 10, data.clone());

        let encoded = checkpoint.encode();
        let decoded = Checkpoint::decode(&encoded).unwrap();

        assert_eq!(decoded.header.checkpoint_id, 42);
        assert_eq!(decoded.header.state_num, 10);
        assert_eq!(decoded.payload, data);
        assert!(decoded.signature.is_none());
    }

    #[test]
    fn test_checkpoint_with_signature() {
        let mut checkpoint = Checkpoint::new(1, 1, vec![0xAB; 100]);
        checkpoint.header.set_signed();
        checkpoint.signature = Some([0xCD; 32]);

        let encoded = checkpoint.encode();
        let decoded = Checkpoint::decode(&encoded).unwrap();

        assert!(decoded.header.is_signed());
        assert_eq!(decoded.signature, Some([0xCD; 32]));
    }

    #[test]
    fn test_checkpoint_decode_truncated_payload() {
        // 33-byte header + 8-byte payload = 41 bytes; drop the last one.
        let encoded = Checkpoint::new(7, 3, vec![9; 8]).encode();
        let truncated = &encoded[..encoded.len() - 1];
        assert!(matches!(
            Checkpoint::decode(truncated),
            Err(NegotiationError::TooShort {
                expected: 41,
                actual: 40
            })
        ));
    }

    #[test]
    fn test_checkpoint_decode_missing_signature() {
        // Signed flag set but no signature attached: 133 bytes on the wire,
        // 133 + 32 = 165 required.
        let mut checkpoint = Checkpoint::new(1, 1, vec![0xAB; 100]);
        checkpoint.header.set_signed();
        let encoded = checkpoint.encode();
        assert!(matches!(
            Checkpoint::decode(&encoded),
            Err(NegotiationError::TooShort {
                expected: 165,
                actual: 133
            })
        ));
    }

    #[test]
    fn test_checkpoint_decode_oversized_payload_len() {
        // A header announcing a 4 GiB payload must be rejected, not panic.
        let mut header = CheckpointHeader::full(1, 1, 0);
        header.payload_len = u32::MAX;
        let encoded = header.encode();
        assert!(matches!(
            Checkpoint::decode(&encoded),
            Err(NegotiationError::TooShort { actual: 33, .. })
        ));
    }

    #[test]
    fn test_request_roundtrip() {
        for request in [
            CheckpointRequest::Latest,
            CheckpointRequest::AtState(999),
            CheckpointRequest::IncrementalFrom(500),
        ] {
            let encoded = request.encode();
            let decoded = CheckpointRequest::decode(&encoded).unwrap();
            assert_eq!(decoded, request);
        }
    }

    #[test]
    fn test_request_encoding_layout() {
        assert_eq!(CheckpointRequest::Latest.encode(), vec![0x00]);
        assert_eq!(
            CheckpointRequest::AtState(1).encode(),
            vec![0x01, 1, 0, 0, 0, 0, 0, 0, 0]
        );
        assert_eq!(
            CheckpointRequest::IncrementalFrom(2).encode(),
            vec![0x02, 2, 0, 0, 0, 0, 0, 0, 0]
        );
    }

    #[test]
    fn test_request_decode_errors() {
        assert!(matches!(
            CheckpointRequest::decode(&[]),
            Err(NegotiationError::TooShort {
                expected: 1,
                actual: 0
            })
        ));
        assert!(matches!(
            CheckpointRequest::decode(&[0x01, 1, 2, 3]),
            Err(NegotiationError::TooShort {
                expected: 9,
                actual: 4
            })
        ));
        assert!(matches!(
            CheckpointRequest::decode(&[0x02]),
            Err(NegotiationError::TooShort {
                expected: 9,
                actual: 1
            })
        ));
        assert!(matches!(
            CheckpointRequest::decode(&[0x03]),
            Err(NegotiationError::InvalidData)
        ));
    }

    #[test]
    fn test_compressed_header() {
        let mut header = CheckpointHeader::full(1, 1, 10000);
        header.set_compressed(2500);

        assert!(header.is_compressed());
        assert_eq!(header.uncompressed_size, 10000);
        assert_eq!(header.payload_len, 2500);
    }
}