Skip to main content

ringkernel_core/
checkpoint.rs

1//! Kernel checkpointing for persistent state snapshot and restore.
2//!
3//! This module provides infrastructure for checkpointing persistent GPU kernels,
4//! enabling fault tolerance, migration, and debugging capabilities.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────────────────────────────┐
10//! │                    CheckpointableKernel                         │
11//! │  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐  │
12//! │  │ Control     │  │ Queue       │  │ Device Memory           │  │
13//! │  │ Block       │  │ State       │  │ (pressure, halo, etc.)  │  │
14//! │  └─────────────┘  └─────────────┘  └─────────────────────────┘  │
15//! └─────────────────────────────────────────────────────────────────┘
16//!                              │
17//!                              ▼
18//! ┌─────────────────────────────────────────────────────────────────┐
19//! │                        Checkpoint                               │
20//! │  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐  │
21//! │  │ Header      │  │ Metadata    │  │ Compressed Data Chunks  │  │
22//! │  │ (magic,ver) │  │ (kernel_id) │  │ (control,queues,memory) │  │
23//! │  └─────────────┘  └─────────────┘  └─────────────────────────┘  │
24//! └─────────────────────────────────────────────────────────────────┘
25//!                              │
26//!                              ▼
27//! ┌─────────────────────────────────────────────────────────────────┐
28//! │                   CheckpointStorage                             │
29//! │  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐  │
30//! │  │ File        │  │ Memory      │  │ Cloud (S3/GCS)          │  │
31//! │  │ Backend     │  │ Backend     │  │ Backend                 │  │
32//! │  └─────────────┘  └─────────────┘  └─────────────────────────┘  │
33//! └─────────────────────────────────────────────────────────────────┘
34//! ```
35//!
36//! # Example
37//!
38//! ```ignore
39//! use ringkernel_core::checkpoint::{Checkpoint, FileStorage, CheckpointableKernel};
40//!
41//! // Create checkpoint from running kernel
42//! let checkpoint = kernel.create_checkpoint()?;
43//!
44//! // Save to file
45//! let storage = FileStorage::new("/checkpoints");
46//! storage.save(&checkpoint, "sim_step_1000")?;
47//!
48//! // Later: restore from checkpoint
49//! let checkpoint = storage.load("sim_step_1000")?;
50//! kernel.restore_from_checkpoint(&checkpoint)?;
51//! ```
52
53use std::collections::HashMap;
54use std::io::{Read, Write};
55use std::path::{Path, PathBuf};
56use std::time::{Duration, SystemTime, UNIX_EPOCH};
57
58use crate::error::{Result, RingKernelError};
59use crate::hlc::HlcTimestamp;
60
61// ============================================================================
62// Checkpoint Format Constants
63// ============================================================================
64
/// Magic number identifying checkpoint files.
///
/// Read as a big-endian integer it spells the ASCII text `"RKCKPT01"`.
/// Checkpoint headers serialize it with `to_le_bytes`, so the first eight
/// bytes of a checkpoint file hold that text reversed (`"10TPKCKR"`).
pub const CHECKPOINT_MAGIC: u64 = u64::from_be_bytes(*b"RKCKPT01");

/// Current checkpoint format version; headers with a newer version are rejected.
pub const CHECKPOINT_VERSION: u32 = 1;

/// Maximum supported checkpoint size: 1 GiB (2^30 bytes).
pub const MAX_CHECKPOINT_SIZE: usize = 1 << 30;
73
/// Kind tag stored in every chunk header; the discriminant is the on-disk value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ChunkType {
    /// Control block state (256 bytes typically).
    ControlBlock = 1,
    /// H2K queue header and pending messages.
    H2KQueue = 2,
    /// K2H queue header and pending messages.
    K2HQueue = 3,
    /// HLC timestamp state.
    HlcState = 4,
    /// Device memory region (e.g., pressure field).
    DeviceMemory = 5,
    /// K2K routing table.
    K2KRouting = 6,
    /// Halo exchange buffers.
    HaloBuffers = 7,
    /// Telemetry statistics.
    Telemetry = 8,
    /// Custom application data.
    Custom = 100,
}

impl ChunkType {
    /// Map a raw on-disk tag back to its variant.
    ///
    /// Returns `None` for any value that is not one of the discriminants above.
    pub fn from_u32(value: u32) -> Option<Self> {
        const KNOWN: [ChunkType; 9] = [
            ChunkType::ControlBlock,
            ChunkType::H2KQueue,
            ChunkType::K2HQueue,
            ChunkType::HlcState,
            ChunkType::DeviceMemory,
            ChunkType::K2KRouting,
            ChunkType::HaloBuffers,
            ChunkType::Telemetry,
            ChunkType::Custom,
        ];
        KNOWN.iter().copied().find(|&kind| kind as u32 == value)
    }
}
115
116// ============================================================================
117// Checkpoint Header
118// ============================================================================
119
120/// Checkpoint file header (64 bytes, fixed size).
121#[derive(Debug, Clone, Copy)]
122#[repr(C)]
123pub struct CheckpointHeader {
124    /// Magic number for format identification.
125    pub magic: u64,
126    /// Format version number.
127    pub version: u32,
128    /// Header size in bytes.
129    pub header_size: u32,
130    /// Total checkpoint size in bytes (including header).
131    pub total_size: u64,
132    /// Number of data chunks.
133    pub chunk_count: u32,
134    /// Compression algorithm (0 = none, 1 = lz4, 2 = zstd).
135    pub compression: u32,
136    /// CRC32 checksum of all data after header.
137    pub checksum: u32,
138    /// Flags (reserved for future use).
139    pub flags: u32,
140    /// Timestamp when checkpoint was created (UNIX epoch microseconds).
141    pub created_at: u64,
142    /// Reserved for alignment.
143    pub _reserved: [u8; 8],
144}
145
146impl CheckpointHeader {
147    /// Create a new checkpoint header.
148    pub fn new(chunk_count: u32, total_size: u64) -> Self {
149        let now = SystemTime::now()
150            .duration_since(UNIX_EPOCH)
151            .unwrap_or(Duration::ZERO);
152
153        Self {
154            magic: CHECKPOINT_MAGIC,
155            version: CHECKPOINT_VERSION,
156            header_size: std::mem::size_of::<Self>() as u32,
157            total_size,
158            chunk_count,
159            compression: 0,
160            checksum: 0,
161            flags: 0,
162            created_at: now.as_micros() as u64,
163            _reserved: [0; 8],
164        }
165    }
166
167    /// Validate the header.
168    pub fn validate(&self) -> Result<()> {
169        if self.magic != CHECKPOINT_MAGIC {
170            return Err(RingKernelError::InvalidCheckpoint(
171                "Invalid magic number".to_string(),
172            ));
173        }
174        if self.version > CHECKPOINT_VERSION {
175            return Err(RingKernelError::InvalidCheckpoint(format!(
176                "Unsupported version: {} (max: {})",
177                self.version, CHECKPOINT_VERSION
178            )));
179        }
180        if self.total_size as usize > MAX_CHECKPOINT_SIZE {
181            return Err(RingKernelError::InvalidCheckpoint(format!(
182                "Checkpoint too large: {} bytes (max: {})",
183                self.total_size, MAX_CHECKPOINT_SIZE
184            )));
185        }
186        Ok(())
187    }
188
189    /// Serialize to bytes.
190    pub fn to_bytes(&self) -> [u8; 64] {
191        let mut bytes = [0u8; 64];
192        bytes[0..8].copy_from_slice(&self.magic.to_le_bytes());
193        bytes[8..12].copy_from_slice(&self.version.to_le_bytes());
194        bytes[12..16].copy_from_slice(&self.header_size.to_le_bytes());
195        bytes[16..24].copy_from_slice(&self.total_size.to_le_bytes());
196        bytes[24..28].copy_from_slice(&self.chunk_count.to_le_bytes());
197        bytes[28..32].copy_from_slice(&self.compression.to_le_bytes());
198        bytes[32..36].copy_from_slice(&self.checksum.to_le_bytes());
199        bytes[36..40].copy_from_slice(&self.flags.to_le_bytes());
200        bytes[40..48].copy_from_slice(&self.created_at.to_le_bytes());
201        bytes
202    }
203
204    /// Deserialize from bytes.
205    pub fn from_bytes(bytes: &[u8; 64]) -> Self {
206        // All slice-to-array conversions below are infallible because the input
207        // is a fixed-size [u8; 64] and all subslice ranges are compile-time constants.
208        Self {
209            magic: u64::from_le_bytes(bytes[0..8].try_into().expect("slice is exactly 8 bytes")),
210            version: u32::from_le_bytes(bytes[8..12].try_into().expect("slice is exactly 4 bytes")),
211            header_size: u32::from_le_bytes(
212                bytes[12..16].try_into().expect("slice is exactly 4 bytes"),
213            ),
214            total_size: u64::from_le_bytes(
215                bytes[16..24].try_into().expect("slice is exactly 8 bytes"),
216            ),
217            chunk_count: u32::from_le_bytes(
218                bytes[24..28].try_into().expect("slice is exactly 4 bytes"),
219            ),
220            compression: u32::from_le_bytes(
221                bytes[28..32].try_into().expect("slice is exactly 4 bytes"),
222            ),
223            checksum: u32::from_le_bytes(
224                bytes[32..36].try_into().expect("slice is exactly 4 bytes"),
225            ),
226            flags: u32::from_le_bytes(bytes[36..40].try_into().expect("slice is exactly 4 bytes")),
227            created_at: u64::from_le_bytes(
228                bytes[40..48].try_into().expect("slice is exactly 8 bytes"),
229            ),
230            _reserved: bytes[48..56].try_into().expect("slice is exactly 8 bytes"),
231        }
232    }
233}
234
235// ============================================================================
236// Chunk Header
237// ============================================================================
238
239/// Header for each data chunk (32 bytes).
240#[derive(Debug, Clone, Copy)]
241#[repr(C)]
242pub struct ChunkHeader {
243    /// Chunk type identifier.
244    pub chunk_type: u32,
245    /// Chunk flags (compression, etc.).
246    pub flags: u32,
247    /// Uncompressed data size.
248    pub uncompressed_size: u64,
249    /// Compressed data size (same as uncompressed if not compressed).
250    pub compressed_size: u64,
251    /// Chunk-specific identifier (e.g., memory region name hash).
252    pub chunk_id: u64,
253}
254
255impl ChunkHeader {
256    /// Create a new chunk header.
257    pub fn new(chunk_type: ChunkType, data_size: usize) -> Self {
258        Self {
259            chunk_type: chunk_type as u32,
260            flags: 0,
261            uncompressed_size: data_size as u64,
262            compressed_size: data_size as u64,
263            chunk_id: 0,
264        }
265    }
266
267    /// Set the chunk ID.
268    pub fn with_id(mut self, id: u64) -> Self {
269        self.chunk_id = id;
270        self
271    }
272
273    /// Serialize to bytes.
274    pub fn to_bytes(&self) -> [u8; 32] {
275        let mut bytes = [0u8; 32];
276        bytes[0..4].copy_from_slice(&self.chunk_type.to_le_bytes());
277        bytes[4..8].copy_from_slice(&self.flags.to_le_bytes());
278        bytes[8..16].copy_from_slice(&self.uncompressed_size.to_le_bytes());
279        bytes[16..24].copy_from_slice(&self.compressed_size.to_le_bytes());
280        bytes[24..32].copy_from_slice(&self.chunk_id.to_le_bytes());
281        bytes
282    }
283
284    /// Deserialize from bytes.
285    pub fn from_bytes(bytes: &[u8; 32]) -> Self {
286        // All slice-to-array conversions below are infallible because the input
287        // is a fixed-size [u8; 32] and all subslice ranges are compile-time constants.
288        Self {
289            chunk_type: u32::from_le_bytes(
290                bytes[0..4].try_into().expect("slice is exactly 4 bytes"),
291            ),
292            flags: u32::from_le_bytes(bytes[4..8].try_into().expect("slice is exactly 4 bytes")),
293            uncompressed_size: u64::from_le_bytes(
294                bytes[8..16].try_into().expect("slice is exactly 8 bytes"),
295            ),
296            compressed_size: u64::from_le_bytes(
297                bytes[16..24].try_into().expect("slice is exactly 8 bytes"),
298            ),
299            chunk_id: u64::from_le_bytes(
300                bytes[24..32].try_into().expect("slice is exactly 8 bytes"),
301            ),
302        }
303    }
304}
305
306// ============================================================================
307// Checkpoint Metadata
308// ============================================================================
309
310/// Kernel-specific metadata stored in checkpoint.
311#[derive(Debug, Clone, Default)]
312pub struct CheckpointMetadata {
313    /// Unique kernel identifier.
314    pub kernel_id: String,
315    /// Kernel type (e.g., "fdtd_3d", "wave_sim").
316    pub kernel_type: String,
317    /// Current simulation step.
318    pub current_step: u64,
319    /// Grid dimensions.
320    pub grid_size: (u32, u32, u32),
321    /// Tile/block dimensions.
322    pub tile_size: (u32, u32, u32),
323    /// HLC timestamp at checkpoint time.
324    pub hlc_timestamp: HlcTimestamp,
325    /// Custom key-value metadata.
326    pub custom: HashMap<String, String>,
327}
328
329impl CheckpointMetadata {
330    /// Create new metadata for a kernel.
331    pub fn new(kernel_id: impl Into<String>, kernel_type: impl Into<String>) -> Self {
332        Self {
333            kernel_id: kernel_id.into(),
334            kernel_type: kernel_type.into(),
335            ..Default::default()
336        }
337    }
338
339    /// Set current step.
340    pub fn with_step(mut self, step: u64) -> Self {
341        self.current_step = step;
342        self
343    }
344
345    /// Set grid size.
346    pub fn with_grid_size(mut self, width: u32, height: u32, depth: u32) -> Self {
347        self.grid_size = (width, height, depth);
348        self
349    }
350
351    /// Set tile size.
352    pub fn with_tile_size(mut self, x: u32, y: u32, z: u32) -> Self {
353        self.tile_size = (x, y, z);
354        self
355    }
356
357    /// Set HLC timestamp.
358    pub fn with_hlc(mut self, hlc: HlcTimestamp) -> Self {
359        self.hlc_timestamp = hlc;
360        self
361    }
362
363    /// Add custom metadata.
364    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
365        self.custom.insert(key.into(), value.into());
366        self
367    }
368
369    /// Serialize metadata to bytes.
370    pub fn to_bytes(&self) -> Vec<u8> {
371        let mut bytes = Vec::new();
372
373        // Kernel ID (length-prefixed string)
374        let kernel_id_bytes = self.kernel_id.as_bytes();
375        bytes.extend_from_slice(&(kernel_id_bytes.len() as u32).to_le_bytes());
376        bytes.extend_from_slice(kernel_id_bytes);
377
378        // Kernel type
379        let kernel_type_bytes = self.kernel_type.as_bytes();
380        bytes.extend_from_slice(&(kernel_type_bytes.len() as u32).to_le_bytes());
381        bytes.extend_from_slice(kernel_type_bytes);
382
383        // Current step
384        bytes.extend_from_slice(&self.current_step.to_le_bytes());
385
386        // Grid size
387        bytes.extend_from_slice(&self.grid_size.0.to_le_bytes());
388        bytes.extend_from_slice(&self.grid_size.1.to_le_bytes());
389        bytes.extend_from_slice(&self.grid_size.2.to_le_bytes());
390
391        // Tile size
392        bytes.extend_from_slice(&self.tile_size.0.to_le_bytes());
393        bytes.extend_from_slice(&self.tile_size.1.to_le_bytes());
394        bytes.extend_from_slice(&self.tile_size.2.to_le_bytes());
395
396        // HLC timestamp
397        bytes.extend_from_slice(&self.hlc_timestamp.physical.to_le_bytes());
398        bytes.extend_from_slice(&self.hlc_timestamp.logical.to_le_bytes());
399        bytes.extend_from_slice(&self.hlc_timestamp.node_id.to_le_bytes());
400
401        // Custom metadata count
402        bytes.extend_from_slice(&(self.custom.len() as u32).to_le_bytes());
403
404        // Custom key-value pairs
405        for (key, value) in &self.custom {
406            let key_bytes = key.as_bytes();
407            bytes.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
408            bytes.extend_from_slice(key_bytes);
409
410            let value_bytes = value.as_bytes();
411            bytes.extend_from_slice(&(value_bytes.len() as u32).to_le_bytes());
412            bytes.extend_from_slice(value_bytes);
413        }
414
415        bytes
416    }
417
418    /// Deserialize metadata from bytes.
419    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
420        let mut offset = 0;
421
422        // Helper to read u32
423        let read_u32 = |off: &mut usize| -> Result<u32> {
424            if *off + 4 > bytes.len() {
425                return Err(RingKernelError::InvalidCheckpoint(
426                    "Unexpected end of metadata".to_string(),
427                ));
428            }
429            // Bounds checked above, so the 4-byte slice conversion is infallible
430            let val = u32::from_le_bytes(
431                bytes[*off..*off + 4]
432                    .try_into()
433                    .expect("bounds checked 4-byte slice"),
434            );
435            *off += 4;
436            Ok(val)
437        };
438
439        // Helper to read u64
440        let read_u64 = |off: &mut usize| -> Result<u64> {
441            if *off + 8 > bytes.len() {
442                return Err(RingKernelError::InvalidCheckpoint(
443                    "Unexpected end of metadata".to_string(),
444                ));
445            }
446            // Bounds checked above, so the 8-byte slice conversion is infallible
447            let val = u64::from_le_bytes(
448                bytes[*off..*off + 8]
449                    .try_into()
450                    .expect("bounds checked 8-byte slice"),
451            );
452            *off += 8;
453            Ok(val)
454        };
455
456        // Helper to read string
457        let read_string = |off: &mut usize| -> Result<String> {
458            let len = read_u32(off)? as usize;
459            if *off + len > bytes.len() {
460                return Err(RingKernelError::InvalidCheckpoint(
461                    "Unexpected end of metadata".to_string(),
462                ));
463            }
464            let s = String::from_utf8(bytes[*off..*off + len].to_vec())
465                .map_err(|e| RingKernelError::InvalidCheckpoint(e.to_string()))?;
466            *off += len;
467            Ok(s)
468        };
469
470        let kernel_id = read_string(&mut offset)?;
471        let kernel_type = read_string(&mut offset)?;
472        let current_step = read_u64(&mut offset)?;
473
474        let grid_size = (
475            read_u32(&mut offset)?,
476            read_u32(&mut offset)?,
477            read_u32(&mut offset)?,
478        );
479
480        let tile_size = (
481            read_u32(&mut offset)?,
482            read_u32(&mut offset)?,
483            read_u32(&mut offset)?,
484        );
485
486        let hlc_timestamp = HlcTimestamp {
487            physical: read_u64(&mut offset)?,
488            logical: read_u64(&mut offset)?,
489            node_id: read_u64(&mut offset)?,
490        };
491
492        let custom_count = read_u32(&mut offset)? as usize;
493        let mut custom = HashMap::new();
494
495        for _ in 0..custom_count {
496            let key = read_string(&mut offset)?;
497            let value = read_string(&mut offset)?;
498            custom.insert(key, value);
499        }
500
501        Ok(Self {
502            kernel_id,
503            kernel_type,
504            current_step,
505            grid_size,
506            tile_size,
507            hlc_timestamp,
508            custom,
509        })
510    }
511}
512
513// ============================================================================
514// Checkpoint Data Chunk
515// ============================================================================
516
517/// A single data chunk in a checkpoint.
518#[derive(Debug, Clone)]
519pub struct DataChunk {
520    /// Chunk header.
521    pub header: ChunkHeader,
522    /// Chunk data (may be compressed).
523    pub data: Vec<u8>,
524}
525
526impl DataChunk {
527    /// Create a new data chunk.
528    pub fn new(chunk_type: ChunkType, data: Vec<u8>) -> Self {
529        Self {
530            header: ChunkHeader::new(chunk_type, data.len()),
531            data,
532        }
533    }
534
535    /// Create a chunk with a custom ID.
536    pub fn with_id(chunk_type: ChunkType, data: Vec<u8>, id: u64) -> Self {
537        Self {
538            header: ChunkHeader::new(chunk_type, data.len()).with_id(id),
539            data,
540        }
541    }
542
543    /// Get the chunk type.
544    pub fn chunk_type(&self) -> Option<ChunkType> {
545        ChunkType::from_u32(self.header.chunk_type)
546    }
547}
548
549// ============================================================================
550// Checkpoint
551// ============================================================================
552
553/// Complete checkpoint containing all kernel state.
554#[derive(Debug, Clone)]
555pub struct Checkpoint {
556    /// Checkpoint header.
557    pub header: CheckpointHeader,
558    /// Kernel metadata.
559    pub metadata: CheckpointMetadata,
560    /// Data chunks.
561    pub chunks: Vec<DataChunk>,
562}
563
564impl Checkpoint {
565    /// Create a new checkpoint.
566    pub fn new(metadata: CheckpointMetadata) -> Self {
567        Self {
568            header: CheckpointHeader::new(0, 0),
569            metadata,
570            chunks: Vec::new(),
571        }
572    }
573
574    /// Add a data chunk.
575    pub fn add_chunk(&mut self, chunk: DataChunk) {
576        self.chunks.push(chunk);
577    }
578
579    /// Add control block data.
580    pub fn add_control_block(&mut self, data: Vec<u8>) {
581        self.add_chunk(DataChunk::new(ChunkType::ControlBlock, data));
582    }
583
584    /// Add H2K queue data.
585    pub fn add_h2k_queue(&mut self, data: Vec<u8>) {
586        self.add_chunk(DataChunk::new(ChunkType::H2KQueue, data));
587    }
588
589    /// Add K2H queue data.
590    pub fn add_k2h_queue(&mut self, data: Vec<u8>) {
591        self.add_chunk(DataChunk::new(ChunkType::K2HQueue, data));
592    }
593
594    /// Add HLC state.
595    pub fn add_hlc_state(&mut self, data: Vec<u8>) {
596        self.add_chunk(DataChunk::new(ChunkType::HlcState, data));
597    }
598
599    /// Add device memory region.
600    pub fn add_device_memory(&mut self, name: &str, data: Vec<u8>) {
601        // Use hash of name as chunk ID
602        use std::collections::hash_map::DefaultHasher;
603        use std::hash::{Hash, Hasher};
604        let mut hasher = DefaultHasher::new();
605        name.hash(&mut hasher);
606        let id = hasher.finish();
607
608        self.add_chunk(DataChunk::with_id(ChunkType::DeviceMemory, data, id));
609    }
610
611    /// Get a chunk by type.
612    pub fn get_chunk(&self, chunk_type: ChunkType) -> Option<&DataChunk> {
613        self.chunks
614            .iter()
615            .find(|c| c.chunk_type() == Some(chunk_type))
616    }
617
618    /// Get all chunks of a type.
619    pub fn get_chunks(&self, chunk_type: ChunkType) -> Vec<&DataChunk> {
620        self.chunks
621            .iter()
622            .filter(|c| c.chunk_type() == Some(chunk_type))
623            .collect()
624    }
625
626    /// Calculate total size in bytes.
627    pub fn total_size(&self) -> usize {
628        let header_size = std::mem::size_of::<CheckpointHeader>();
629        let metadata_bytes = self.metadata.to_bytes();
630        let metadata_size = 4 + metadata_bytes.len(); // length prefix + data
631
632        let chunks_size: usize = self
633            .chunks
634            .iter()
635            .map(|c| std::mem::size_of::<ChunkHeader>() + c.data.len())
636            .sum();
637
638        header_size + metadata_size + chunks_size
639    }
640
641    /// Serialize checkpoint to bytes.
642    pub fn to_bytes(&self) -> Vec<u8> {
643        let mut bytes = Vec::new();
644
645        // Metadata as bytes
646        let metadata_bytes = self.metadata.to_bytes();
647
648        // Calculate total size
649        let total_size = self.total_size();
650
651        // Create header with correct values
652        let header = CheckpointHeader::new(self.chunks.len() as u32, total_size as u64);
653
654        // Write header
655        bytes.extend_from_slice(&header.to_bytes());
656
657        // Write metadata (length-prefixed)
658        bytes.extend_from_slice(&(metadata_bytes.len() as u32).to_le_bytes());
659        bytes.extend_from_slice(&metadata_bytes);
660
661        // Write chunks
662        for chunk in &self.chunks {
663            bytes.extend_from_slice(&chunk.header.to_bytes());
664            bytes.extend_from_slice(&chunk.data);
665        }
666
667        // Calculate checksum (simple CRC32 of data after header) and update in place
668        let checksum = crc32_simple(&bytes[64..]);
669        bytes[32..36].copy_from_slice(&checksum.to_le_bytes());
670
671        bytes
672    }
673
674    /// Deserialize checkpoint from bytes.
675    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
676        if bytes.len() < 64 {
677            return Err(RingKernelError::InvalidCheckpoint(
678                "Checkpoint too small".to_string(),
679            ));
680        }
681
682        // Read header - slice is exactly 64 bytes (checked above)
683        let header = CheckpointHeader::from_bytes(
684            bytes[0..64]
685                .try_into()
686                .expect("input validated to be >= 64 bytes"),
687        );
688        header.validate()?;
689
690        // Verify checksum
691        let expected_checksum = crc32_simple(&bytes[64..]);
692        if header.checksum != expected_checksum {
693            return Err(RingKernelError::InvalidCheckpoint(format!(
694                "Checksum mismatch: expected {}, got {}",
695                expected_checksum, header.checksum
696            )));
697        }
698
699        let mut offset = 64;
700
701        // Read metadata length
702        if offset + 4 > bytes.len() {
703            return Err(RingKernelError::InvalidCheckpoint(
704                "Missing metadata length".to_string(),
705            ));
706        }
707        // Bounds checked above, so the 4-byte slice conversion is infallible
708        let metadata_len = u32::from_le_bytes(
709            bytes[offset..offset + 4]
710                .try_into()
711                .expect("bounds checked 4-byte slice"),
712        ) as usize;
713        offset += 4;
714
715        // Read metadata
716        if offset + metadata_len > bytes.len() {
717            return Err(RingKernelError::InvalidCheckpoint(
718                "Metadata truncated".to_string(),
719            ));
720        }
721        let metadata = CheckpointMetadata::from_bytes(&bytes[offset..offset + metadata_len])?;
722        offset += metadata_len;
723
724        // Read chunks
725        let mut chunks = Vec::new();
726        for _ in 0..header.chunk_count {
727            if offset + 32 > bytes.len() {
728                return Err(RingKernelError::InvalidCheckpoint(
729                    "Chunk header truncated".to_string(),
730                ));
731            }
732
733            // Bounds checked above, so the 32-byte slice conversion is infallible
734            let chunk_header = ChunkHeader::from_bytes(
735                bytes[offset..offset + 32]
736                    .try_into()
737                    .expect("bounds checked 32-byte slice"),
738            );
739            offset += 32;
740
741            let data_len = chunk_header.compressed_size as usize;
742            if offset + data_len > bytes.len() {
743                return Err(RingKernelError::InvalidCheckpoint(
744                    "Chunk data truncated".to_string(),
745                ));
746            }
747
748            let data = bytes[offset..offset + data_len].to_vec();
749            offset += data_len;
750
751            chunks.push(DataChunk {
752                header: chunk_header,
753                data,
754            });
755        }
756
757        Ok(Self {
758            header,
759            metadata,
760            chunks,
761        })
762    }
763}
764
765// ============================================================================
766// Incremental (Delta) Checkpoints
767// ============================================================================
768
769/// Metadata custom key that records a delta checkpoint's parent.
770///
771/// A checkpoint produced via [`Checkpoint::delta_from`] sets this key to
772/// the string returned by [`Checkpoint::content_digest`] of the base
773/// checkpoint it was diffed against. [`Checkpoint::applied_with_delta`]
774/// (and anyone reading the delta) can use it to validate the delta is
775/// being applied on top of the correct base.
776pub const DELTA_PARENT_DIGEST_KEY: &str = "ringkernel.delta.parent_digest";
777
778/// Key per chunk that records a stable identity across checkpoints.
779///
780/// [`DataChunk::chunk_identity`] builds a `(ChunkType, Option<chunk_id>)`
781/// pair which is stable across snapshots of the same kernel: a control
782/// block or HLC-state chunk has identity `(ControlBlock, None)`; a
783/// device-memory chunk has identity `(DeviceMemory, Some(name_hash))`.
784/// Two chunks with the same identity refer to the same region —
785/// potentially with different bytes — across a base + delta.
786impl DataChunk {
787    /// Stable (kind, id) identity used for delta diffing. `None` id is
788    /// used for chunk kinds that appear at most once in a checkpoint.
789    pub fn chunk_identity(&self) -> Option<(ChunkType, Option<u64>)> {
790        let kind = self.chunk_type()?;
791        let id = match kind {
792            ChunkType::DeviceMemory | ChunkType::Custom => Some(self.header.chunk_id),
793            _ => None,
794        };
795        Some((kind, id))
796    }
797}
798
799impl Checkpoint {
800    /// Stable content digest of the checkpoint's data chunks.
801    ///
802    /// This is a CRC32 over each chunk's `(identity, data)` in the order
803    /// they were added. Two checkpoints with the same content digest
804    /// have identical chunk contents at identical identities. Used as
805    /// the stable "parent id" for delta checkpoints.
806    pub fn content_digest(&self) -> String {
807        let mut acc: u32 = 0xFFFF_FFFF;
808        for chunk in &self.chunks {
809            if let Some((kind, id)) = chunk.chunk_identity() {
810                let mut header = [0u8; 16];
811                header[0..4].copy_from_slice(&(kind as u32).to_le_bytes());
812                header[4..12].copy_from_slice(&id.unwrap_or(0).to_le_bytes());
813                acc = crc32_update(acc, &header);
814                acc = crc32_update(acc, &chunk.data);
815            }
816        }
817        format!("{:08x}", !acc)
818    }
819
820    /// Produce a delta checkpoint: chunks present in `new` whose bytes
821    /// differ from the corresponding chunk in `base` (same identity),
822    /// plus chunks that are new in `new`. Unchanged chunks are omitted.
823    ///
824    /// The delta's `metadata.custom` records the base's content digest
825    /// under [`DELTA_PARENT_DIGEST_KEY`] so the reader can verify it
826    /// before applying.
827    ///
828    /// Restore via [`Checkpoint::applied_with_delta`].
829    pub fn delta_from(base: &Checkpoint, new: &Checkpoint) -> Checkpoint {
830        use std::collections::HashMap;
831        let mut base_index: HashMap<(ChunkType, Option<u64>), &DataChunk> = HashMap::new();
832        for chunk in &base.chunks {
833            if let Some(id) = chunk.chunk_identity() {
834                base_index.insert(id, chunk);
835            }
836        }
837
838        let mut delta = Checkpoint::new(new.metadata.clone());
839        delta.metadata = delta
840            .metadata
841            .with_custom(DELTA_PARENT_DIGEST_KEY, base.content_digest());
842        for chunk in &new.chunks {
843            let Some(identity) = chunk.chunk_identity() else {
844                continue;
845            };
846            match base_index.get(&identity) {
847                Some(old) if old.data == chunk.data => { /* unchanged, skip */ }
848                _ => delta.chunks.push(chunk.clone()),
849            }
850        }
851        delta
852    }
853
854    /// Apply a delta produced by [`Checkpoint::delta_from`] on top of
855    /// `base`, returning the resulting full checkpoint. Chunks in the
856    /// delta replace chunks with the same identity in `base`; chunks
857    /// only in `base` carry over unchanged.
858    ///
859    /// Errors if the delta's recorded parent digest does not match
860    /// `base.content_digest()` — this catches accidental application
861    /// on top of the wrong base.
862    pub fn applied_with_delta(base: &Checkpoint, delta: &Checkpoint) -> Result<Checkpoint> {
863        if let Some(recorded) = delta.metadata.custom.get(DELTA_PARENT_DIGEST_KEY) {
864            let actual = base.content_digest();
865            if recorded != &actual {
866                return Err(RingKernelError::InvalidCheckpoint(format!(
867                    "delta parent digest mismatch: expected {recorded}, got {actual}"
868                )));
869            }
870        }
871
872        use std::collections::HashMap;
873        let mut out = Checkpoint::new(delta.metadata.clone());
874        let mut delta_index: HashMap<(ChunkType, Option<u64>), &DataChunk> = HashMap::new();
875        for chunk in &delta.chunks {
876            if let Some(id) = chunk.chunk_identity() {
877                delta_index.insert(id, chunk);
878            }
879        }
880
881        // Base chunks first, replaced by delta if present.
882        let mut replaced: std::collections::HashSet<(ChunkType, Option<u64>)> =
883            std::collections::HashSet::new();
884        for chunk in &base.chunks {
885            match chunk.chunk_identity() {
886                Some(id) if delta_index.contains_key(&id) => {
887                    out.chunks.push(delta_index[&id].clone());
888                    replaced.insert(id);
889                }
890                _ => out.chunks.push(chunk.clone()),
891            }
892        }
893        // Chunks in delta that weren't in base.
894        for chunk in &delta.chunks {
895            if let Some(id) = chunk.chunk_identity() {
896                if !replaced.contains(&id) {
897                    out.chunks.push(chunk.clone());
898                }
899            }
900        }
901        Ok(out)
902    }
903}
904
/// Feed `data` into a running CRC32 (IEEE, reflected polynomial
/// `0xEDB88320`) and return the updated register.
///
/// The register is kept un-inverted between calls: seed it with
/// `0xFFFF_FFFF`, call this any number of times, and apply `!` once at the
/// end to obtain the final digest. Used by [`Checkpoint::content_digest`].
fn crc32_update(crc: u32, data: &[u8]) -> u32 {
    const POLY: u32 = 0xEDB88320;
    data.iter().fold(crc, |register, &byte| {
        // Shift the byte through the register one bit at a time.
        (0..8).fold(register ^ u32::from(byte), |r, _| {
            if r & 1 == 1 {
                (r >> 1) ^ POLY
            } else {
                r >> 1
            }
        })
    })
}
919
920// ============================================================================
921// Simple CRC32 Implementation
922// ============================================================================
923
924/// Simple CRC32 checksum (IEEE polynomial).
925fn crc32_simple(data: &[u8]) -> u32 {
926    const CRC32_TABLE: [u32; 256] = crc32_table();
927
928    let mut crc = 0xFFFFFFFF;
929    for byte in data {
930        let index = ((crc ^ (*byte as u32)) & 0xFF) as usize;
931        crc = CRC32_TABLE[index] ^ (crc >> 8);
932    }
933    !crc
934}
935
/// Build the 256-entry lookup table for the reflected IEEE CRC-32
/// polynomial (`0xEDB88320`).
///
/// Entry `n` is the register obtained by running the byte value `n` through
/// eight rounds of the bitwise algorithm. Declared `const fn` so the table
/// can be produced at compile time.
const fn crc32_table() -> [u32; 256] {
    const POLY: u32 = 0xEDB88320;
    let mut entries = [0u32; 256];
    let mut n: usize = 0;
    while n < 256 {
        let mut reg = n as u32;
        let mut round = 0;
        while round < 8 {
            // Branch-free: the mask is all ones exactly when the low bit is set.
            reg = (reg >> 1) ^ (POLY & (reg & 1).wrapping_neg());
            round += 1;
        }
        entries[n] = reg;
        n += 1;
    }
    entries
}
956
957// ============================================================================
958// CheckpointableKernel Trait
959// ============================================================================
960
961/// Trait for kernels that support checkpointing.
962pub trait CheckpointableKernel {
963    /// Create a checkpoint of the current kernel state.
964    ///
965    /// This should pause the kernel, serialize all state, and return a checkpoint.
966    fn create_checkpoint(&self) -> Result<Checkpoint>;
967
968    /// Restore kernel state from a checkpoint.
969    ///
970    /// This should pause the kernel, deserialize state, and resume.
971    fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()>;
972
973    /// Get the kernel ID for checkpointing.
974    fn checkpoint_kernel_id(&self) -> &str;
975
976    /// Get the kernel type for checkpointing.
977    fn checkpoint_kernel_type(&self) -> &str;
978
979    /// Check if the kernel supports incremental checkpoints.
980    fn supports_incremental(&self) -> bool {
981        false
982    }
983
984    /// Create an incremental checkpoint (only changed state since last checkpoint).
985    fn create_incremental_checkpoint(&self, _base: &Checkpoint) -> Result<Checkpoint> {
986        // Default: fall back to full checkpoint
987        self.create_checkpoint()
988    }
989}
990
991// ============================================================================
992// Checkpoint Storage Trait
993// ============================================================================
994
995/// Trait for checkpoint storage backends.
996pub trait CheckpointStorage: Send + Sync {
997    /// Save a checkpoint with the given name.
998    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()>;
999
1000    /// Load a checkpoint by name.
1001    fn load(&self, name: &str) -> Result<Checkpoint>;
1002
1003    /// List all available checkpoints.
1004    fn list(&self) -> Result<Vec<String>>;
1005
1006    /// Delete a checkpoint.
1007    fn delete(&self, name: &str) -> Result<()>;
1008
1009    /// Check if a checkpoint exists.
1010    fn exists(&self, name: &str) -> bool;
1011}
1012
1013// ============================================================================
1014// File Storage Backend
1015// ============================================================================
1016
/// File-based checkpoint storage.
///
/// A checkpoint named `name` is stored as `<base_path>/<name>.rkcp`.
#[derive(Debug, Clone)]
pub struct FileStorage {
    /// Base directory for checkpoint files.
    base_path: PathBuf,
}

impl FileStorage {
    /// Create a new file storage backend rooted at `base_path`.
    ///
    /// The directory is not created here; saving creates it on demand.
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        Self {
            base_path: base_path.as_ref().to_path_buf(),
        }
    }

    /// Full path of the file backing the checkpoint `name`.
    ///
    /// `name` is joined to the base directory verbatim. Names containing
    /// path separators or `..` therefore resolve outside a flat layout
    /// under `base_path`, so callers should pass plain file-name-safe names.
    fn checkpoint_path(&self, name: &str) -> PathBuf {
        self.base_path.join(format!("{name}.rkcp"))
    }
}
1036
1037impl CheckpointStorage for FileStorage {
1038    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()> {
1039        // Ensure directory exists
1040        std::fs::create_dir_all(&self.base_path).map_err(|e| {
1041            RingKernelError::IoError(format!("Failed to create checkpoint directory: {}", e))
1042        })?;
1043
1044        let path = self.checkpoint_path(name);
1045        let bytes = checkpoint.to_bytes();
1046
1047        let mut file = std::fs::File::create(&path).map_err(|e| {
1048            RingKernelError::IoError(format!("Failed to create checkpoint file: {}", e))
1049        })?;
1050
1051        file.write_all(&bytes)
1052            .map_err(|e| RingKernelError::IoError(format!("Failed to write checkpoint: {}", e)))?;
1053
1054        Ok(())
1055    }
1056
1057    fn load(&self, name: &str) -> Result<Checkpoint> {
1058        let path = self.checkpoint_path(name);
1059
1060        let mut file = std::fs::File::open(&path).map_err(|e| {
1061            RingKernelError::IoError(format!("Failed to open checkpoint file: {}", e))
1062        })?;
1063
1064        let mut bytes = Vec::new();
1065        file.read_to_end(&mut bytes)
1066            .map_err(|e| RingKernelError::IoError(format!("Failed to read checkpoint: {}", e)))?;
1067
1068        Checkpoint::from_bytes(&bytes)
1069    }
1070
1071    fn list(&self) -> Result<Vec<String>> {
1072        let entries = std::fs::read_dir(&self.base_path).map_err(|e| {
1073            RingKernelError::IoError(format!("Failed to read checkpoint directory: {}", e))
1074        })?;
1075
1076        let mut names = Vec::new();
1077        for entry in entries.flatten() {
1078            let path = entry.path();
1079            if path.extension().map(|e| e == "rkcp").unwrap_or(false) {
1080                if let Some(stem) = path.file_stem() {
1081                    names.push(stem.to_string_lossy().to_string());
1082                }
1083            }
1084        }
1085
1086        names.sort();
1087        Ok(names)
1088    }
1089
1090    fn delete(&self, name: &str) -> Result<()> {
1091        let path = self.checkpoint_path(name);
1092        std::fs::remove_file(&path)
1093            .map_err(|e| RingKernelError::IoError(format!("Failed to delete checkpoint: {}", e)))?;
1094        Ok(())
1095    }
1096
1097    fn exists(&self, name: &str) -> bool {
1098        self.checkpoint_path(name).exists()
1099    }
1100}
1101
1102// ============================================================================
1103// Memory Storage Backend
1104// ============================================================================
1105
/// In-memory checkpoint storage (for testing and fast operations).
///
/// Checkpoints are held in serialized form, keyed by name, behind an
/// `RwLock` so the backend can be shared between threads.
#[derive(Debug)]
pub struct MemoryStorage {
    /// Serialized checkpoints keyed by name.
    checkpoints: std::sync::RwLock<HashMap<String, Vec<u8>>>,
}

impl MemoryStorage {
    /// Create an empty memory storage backend.
    pub fn new() -> Self {
        Self {
            checkpoints: std::sync::RwLock::new(HashMap::new()),
        }
    }
}

impl Default for MemoryStorage {
    /// Equivalent to [`MemoryStorage::new`].
    fn default() -> Self {
        Self::new()
    }
}
1126
1127impl CheckpointStorage for MemoryStorage {
1128    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()> {
1129        let bytes = checkpoint.to_bytes();
1130        let mut checkpoints = self
1131            .checkpoints
1132            .write()
1133            .map_err(|_| RingKernelError::IoError("Failed to acquire write lock".to_string()))?;
1134        checkpoints.insert(name.to_string(), bytes);
1135        Ok(())
1136    }
1137
1138    fn load(&self, name: &str) -> Result<Checkpoint> {
1139        let checkpoints = self
1140            .checkpoints
1141            .read()
1142            .map_err(|_| RingKernelError::IoError("Failed to acquire read lock".to_string()))?;
1143
1144        let bytes = checkpoints
1145            .get(name)
1146            .ok_or_else(|| RingKernelError::IoError(format!("Checkpoint not found: {}", name)))?;
1147
1148        Checkpoint::from_bytes(bytes)
1149    }
1150
1151    fn list(&self) -> Result<Vec<String>> {
1152        let checkpoints = self
1153            .checkpoints
1154            .read()
1155            .map_err(|_| RingKernelError::IoError("Failed to acquire read lock".to_string()))?;
1156
1157        let mut names: Vec<_> = checkpoints.keys().cloned().collect();
1158        names.sort();
1159        Ok(names)
1160    }
1161
1162    fn delete(&self, name: &str) -> Result<()> {
1163        let mut checkpoints = self
1164            .checkpoints
1165            .write()
1166            .map_err(|_| RingKernelError::IoError("Failed to acquire write lock".to_string()))?;
1167
1168        checkpoints
1169            .remove(name)
1170            .ok_or_else(|| RingKernelError::IoError(format!("Checkpoint not found: {}", name)))?;
1171
1172        Ok(())
1173    }
1174
1175    fn exists(&self, name: &str) -> bool {
1176        self.checkpoints
1177            .read()
1178            .map(|c| c.contains_key(name))
1179            .unwrap_or(false)
1180    }
1181}
1182
1183// ============================================================================
1184// Checkpoint Builder
1185// ============================================================================
1186
1187/// Builder for creating checkpoints incrementally.
1188pub struct CheckpointBuilder {
1189    metadata: CheckpointMetadata,
1190    chunks: Vec<DataChunk>,
1191}
1192
1193impl CheckpointBuilder {
1194    /// Create a new checkpoint builder.
1195    pub fn new(kernel_id: impl Into<String>, kernel_type: impl Into<String>) -> Self {
1196        Self {
1197            metadata: CheckpointMetadata::new(kernel_id, kernel_type),
1198            chunks: Vec::new(),
1199        }
1200    }
1201
1202    /// Set the current step.
1203    pub fn step(mut self, step: u64) -> Self {
1204        self.metadata.current_step = step;
1205        self
1206    }
1207
1208    /// Set grid size.
1209    pub fn grid_size(mut self, width: u32, height: u32, depth: u32) -> Self {
1210        self.metadata.grid_size = (width, height, depth);
1211        self
1212    }
1213
1214    /// Set tile size.
1215    pub fn tile_size(mut self, x: u32, y: u32, z: u32) -> Self {
1216        self.metadata.tile_size = (x, y, z);
1217        self
1218    }
1219
1220    /// Set HLC timestamp.
1221    pub fn hlc(mut self, hlc: HlcTimestamp) -> Self {
1222        self.metadata.hlc_timestamp = hlc;
1223        self
1224    }
1225
1226    /// Add custom metadata.
1227    pub fn custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1228        self.metadata.custom.insert(key.into(), value.into());
1229        self
1230    }
1231
1232    /// Add control block data.
1233    pub fn control_block(mut self, data: Vec<u8>) -> Self {
1234        self.chunks
1235            .push(DataChunk::new(ChunkType::ControlBlock, data));
1236        self
1237    }
1238
1239    /// Add H2K queue data.
1240    pub fn h2k_queue(mut self, data: Vec<u8>) -> Self {
1241        self.chunks.push(DataChunk::new(ChunkType::H2KQueue, data));
1242        self
1243    }
1244
1245    /// Add K2H queue data.
1246    pub fn k2h_queue(mut self, data: Vec<u8>) -> Self {
1247        self.chunks.push(DataChunk::new(ChunkType::K2HQueue, data));
1248        self
1249    }
1250
1251    /// Add device memory region.
1252    pub fn device_memory(mut self, name: &str, data: Vec<u8>) -> Self {
1253        use std::collections::hash_map::DefaultHasher;
1254        use std::hash::{Hash, Hasher};
1255        let mut hasher = DefaultHasher::new();
1256        name.hash(&mut hasher);
1257        let id = hasher.finish();
1258
1259        self.chunks
1260            .push(DataChunk::with_id(ChunkType::DeviceMemory, data, id));
1261        self
1262    }
1263
1264    /// Add a custom chunk.
1265    pub fn chunk(mut self, chunk: DataChunk) -> Self {
1266        self.chunks.push(chunk);
1267        self
1268    }
1269
1270    /// Build the checkpoint.
1271    pub fn build(self) -> Checkpoint {
1272        let mut checkpoint = Checkpoint::new(self.metadata);
1273        checkpoint.chunks = self.chunks;
1274        checkpoint
1275    }
1276}
1277
1278// ============================================================================
1279// Checkpoint Configuration
1280// ============================================================================
1281
/// Configuration for periodic actor state checkpointing.
///
/// Decides how often snapshots are taken, how many are kept, and where they
/// are written. Start from [`CheckpointConfig::new`] or `Default`, then
/// adjust with the `with_*` methods.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Time between periodic snapshots.
    pub interval: Duration,
    /// Maximum number of checkpoints kept per kernel. Once exceeded, the
    /// oldest checkpoint is deleted. `0` means no limit.
    pub max_snapshots: usize,
    /// Directory used for file-based checkpoints.
    pub storage_path: PathBuf,
    /// Master switch for checkpointing.
    pub enabled: bool,
    /// Prefix for checkpoint names (e.g., "actor_0").
    pub name_prefix: String,
}

impl Default for CheckpointConfig {
    /// 30-second interval, five retained snapshots, storage under
    /// `/tmp/ringkernel/checkpoints`, name prefix `checkpoint`, enabled.
    fn default() -> Self {
        let interval = Duration::from_secs(30);
        let storage_path = PathBuf::from("/tmp/ringkernel/checkpoints");
        let name_prefix = String::from("checkpoint");
        Self {
            interval,
            max_snapshots: 5,
            storage_path,
            enabled: true,
            name_prefix,
        }
    }
}

impl CheckpointConfig {
    /// The default configuration with the snapshot interval set to `interval`.
    pub fn new(interval: Duration) -> Self {
        Self {
            interval,
            ..Self::default()
        }
    }

    /// Replace the retention limit (`0` = unlimited).
    pub fn with_max_snapshots(self, max: usize) -> Self {
        Self {
            max_snapshots: max,
            ..self
        }
    }

    /// Replace the storage directory.
    pub fn with_storage_path(self, path: impl AsRef<Path>) -> Self {
        Self {
            storage_path: path.as_ref().to_path_buf(),
            ..self
        }
    }

    /// Replace the prefix used when naming checkpoints.
    pub fn with_name_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            name_prefix: prefix.into(),
            ..self
        }
    }

    /// Turn checkpointing on or off.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
1347
1348// ============================================================================
1349// Snapshot Request / Response (backend-agnostic protocol)
1350// ============================================================================
1351
/// A request to snapshot a specific actor's state.
///
/// This is the backend-agnostic representation of a snapshot command.
/// The CUDA backend maps this to an H2K `SnapshotActor` message.
///
/// Equality compares every field, including `issued_at`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnapshotRequest {
    /// Unique ID for this snapshot request (used for correlation).
    pub request_id: u64,
    /// Actor slot index to snapshot.
    pub actor_slot: u32,
    /// Offset in snapshot buffer (backend-specific).
    pub buffer_offset: u32,
    /// Timestamp when this request was issued.
    pub issued_at: SystemTime,
}
1367
/// A completed snapshot response from the device.
///
/// The CUDA backend maps this from a K2H `SnapshotComplete` message.
///
/// Equality compares every field, including the snapshot bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SnapshotResponse {
    /// The request ID this responds to.
    pub request_id: u64,
    /// Actor slot that was snapshotted.
    pub actor_slot: u32,
    /// Whether the snapshot succeeded.
    pub success: bool,
    /// Snapshot data (copied from device/mapped memory by the backend).
    pub data: Vec<u8>,
    /// Simulation step at snapshot time.
    pub step: u64,
}
1384
1385// ============================================================================
1386// Checkpoint Manager
1387// ============================================================================
1388
1389/// Tracks the state of a pending snapshot request.
1390#[derive(Debug, Clone)]
1391struct PendingSnapshot {
1392    /// The original request.
1393    request: SnapshotRequest,
1394    /// Kernel ID this snapshot belongs to.
1395    kernel_id: String,
1396    /// Kernel type for metadata.
1397    kernel_type: String,
1398}
1399
1400/// Manages periodic checkpointing for persistent GPU actors.
1401///
1402/// The `CheckpointManager` orchestrates the checkpoint lifecycle:
1403///
1404/// 1. Periodically determines when a snapshot is due
1405/// 2. Issues `SnapshotRequest`s (caller sends as H2K commands)
1406/// 3. Processes `SnapshotResponse`s (caller feeds from K2H responses)
1407/// 4. Persists completed checkpoints to storage
1408/// 5. Enforces retention policy (deletes old checkpoints)
1409///
1410/// # Usage
1411///
1412/// ```ignore
1413/// use ringkernel_core::checkpoint::{CheckpointConfig, CheckpointManager};
1414/// use std::time::Duration;
1415///
1416/// let config = CheckpointConfig::new(Duration::from_secs(10))
1417///     .with_max_snapshots(3)
1418///     .with_storage_path("/tmp/checkpoints");
1419///
1420/// let mut manager = CheckpointManager::new(config);
1421/// manager.register_actor(0, "wave_sim_0", "fdtd_3d");
1422///
1423/// // In your poll loop:
1424/// for request in manager.poll_due_snapshots() {
1425///     // Send as H2K SnapshotActor command
1426///     h2k_queue.send(H2KMessage::snapshot_actor(
1427///         request.request_id,
1428///         request.actor_slot,
1429///         request.buffer_offset,
1430///     ));
1431/// }
1432///
1433/// // When K2H SnapshotComplete arrives:
1434/// manager.complete_snapshot(SnapshotResponse { ... })?;
1435/// ```
1436pub struct CheckpointManager {
1437    /// Configuration.
1438    config: CheckpointConfig,
1439    /// Storage backend.
1440    storage: Box<dyn CheckpointStorage>,
1441    /// Registered actors: slot -> (kernel_id, kernel_type).
1442    actors: HashMap<u32, (String, String)>,
1443    /// Last snapshot time per actor slot.
1444    last_snapshot: HashMap<u32, std::time::Instant>,
1445    /// Pending snapshot requests awaiting completion.
1446    pending: HashMap<u64, PendingSnapshot>,
1447    /// Next request ID counter.
1448    next_request_id: u64,
1449    /// Ordered list of checkpoint names per actor (oldest first) for retention.
1450    checkpoint_history: HashMap<u32, Vec<String>>,
1451    /// Total snapshots completed (lifetime counter).
1452    total_completed: u64,
1453    /// Total snapshots failed (lifetime counter).
1454    total_failed: u64,
1455}
1456
1457impl CheckpointManager {
1458    /// Create a new checkpoint manager with file storage at the configured path.
1459    pub fn new(config: CheckpointConfig) -> Self {
1460        let storage = Box::new(FileStorage::new(&config.storage_path));
1461        Self {
1462            config,
1463            storage,
1464            actors: HashMap::new(),
1465            last_snapshot: HashMap::new(),
1466            pending: HashMap::new(),
1467            next_request_id: 1,
1468            checkpoint_history: HashMap::new(),
1469            total_completed: 0,
1470            total_failed: 0,
1471        }
1472    }
1473
1474    /// Create a checkpoint manager with a custom storage backend.
1475    pub fn with_storage(config: CheckpointConfig, storage: Box<dyn CheckpointStorage>) -> Self {
1476        Self {
1477            config,
1478            storage,
1479            actors: HashMap::new(),
1480            last_snapshot: HashMap::new(),
1481            pending: HashMap::new(),
1482            next_request_id: 1,
1483            checkpoint_history: HashMap::new(),
1484            total_completed: 0,
1485            total_failed: 0,
1486        }
1487    }
1488
1489    /// Register an actor for periodic checkpointing.
1490    pub fn register_actor(
1491        &mut self,
1492        actor_slot: u32,
1493        kernel_id: impl Into<String>,
1494        kernel_type: impl Into<String>,
1495    ) {
1496        self.actors
1497            .insert(actor_slot, (kernel_id.into(), kernel_type.into()));
1498    }
1499
1500    /// Unregister an actor from checkpointing.
1501    pub fn unregister_actor(&mut self, actor_slot: u32) {
1502        self.actors.remove(&actor_slot);
1503        self.last_snapshot.remove(&actor_slot);
1504    }
1505
1506    /// Check if checkpointing is enabled.
1507    pub fn is_enabled(&self) -> bool {
1508        self.config.enabled
1509    }
1510
1511    /// Get the checkpoint configuration.
1512    pub fn config(&self) -> &CheckpointConfig {
1513        &self.config
1514    }
1515
1516    /// Get the number of pending snapshot requests.
1517    pub fn pending_count(&self) -> usize {
1518        self.pending.len()
1519    }
1520
1521    /// Get total completed snapshots.
1522    pub fn total_completed(&self) -> u64 {
1523        self.total_completed
1524    }
1525
1526    /// Get total failed snapshots.
1527    pub fn total_failed(&self) -> u64 {
1528        self.total_failed
1529    }
1530
1531    /// Poll for actors that are due for a snapshot.
1532    ///
1533    /// Returns a list of `SnapshotRequest`s that should be sent to the device
1534    /// as H2K `SnapshotActor` commands.
1535    ///
1536    /// Each actor is only requested once per interval, and only if no prior
1537    /// request for that actor is still pending.
1538    pub fn poll_due_snapshots(&mut self) -> Vec<SnapshotRequest> {
1539        if !self.config.enabled {
1540            return Vec::new();
1541        }
1542
1543        let now = std::time::Instant::now();
1544        let interval = self.config.interval;
1545        let mut requests = Vec::new();
1546
1547        // Collect actor slots that are due (to avoid borrow conflict)
1548        let due_slots: Vec<u32> = self
1549            .actors
1550            .keys()
1551            .filter(|slot| {
1552                // Skip if there's already a pending request for this actor
1553                let has_pending = self
1554                    .pending
1555                    .values()
1556                    .any(|p| p.request.actor_slot == **slot);
1557                if has_pending {
1558                    return false;
1559                }
1560
1561                // Check if interval has elapsed since last snapshot
1562                match self.last_snapshot.get(slot) {
1563                    Some(last) => now.duration_since(*last) >= interval,
1564                    None => true, // Never snapshotted, due immediately
1565                }
1566            })
1567            .copied()
1568            .collect();
1569
1570        for slot in due_slots {
1571            let request_id = self.next_request_id;
1572            self.next_request_id += 1;
1573
1574            let request = SnapshotRequest {
1575                request_id,
1576                actor_slot: slot,
1577                buffer_offset: 0, // Backend fills in actual offset
1578                issued_at: SystemTime::now(),
1579            };
1580
1581            if let Some((kernel_id, kernel_type)) = self.actors.get(&slot) {
1582                self.pending.insert(
1583                    request_id,
1584                    PendingSnapshot {
1585                        request: request.clone(),
1586                        kernel_id: kernel_id.clone(),
1587                        kernel_type: kernel_type.clone(),
1588                    },
1589                );
1590            }
1591
1592            requests.push(request);
1593        }
1594
1595        requests
1596    }
1597
1598    /// Process a completed snapshot response from the device.
1599    ///
1600    /// If the snapshot succeeded, the data is persisted to storage and
1601    /// the retention policy is enforced.
1602    ///
1603    /// Returns the checkpoint name on success.
1604    pub fn complete_snapshot(&mut self, response: SnapshotResponse) -> Result<Option<String>> {
1605        let pending = match self.pending.remove(&response.request_id) {
1606            Some(p) => p,
1607            None => {
1608                // Unknown request ID -- may have been cancelled or already completed
1609                return Ok(None);
1610            }
1611        };
1612
1613        // Record the snapshot time regardless of success
1614        self.last_snapshot
1615            .insert(pending.request.actor_slot, std::time::Instant::now());
1616
1617        if !response.success {
1618            self.total_failed += 1;
1619            return Err(RingKernelError::InvalidCheckpoint(format!(
1620                "Snapshot failed for actor slot {}",
1621                response.actor_slot
1622            )));
1623        }
1624
1625        // Build a checkpoint from the snapshot data
1626        let checkpoint = CheckpointBuilder::new(&pending.kernel_id, &pending.kernel_type)
1627            .step(response.step)
1628            .custom("actor_slot", pending.request.actor_slot.to_string())
1629            .custom(
1630                "snapshot_request_id",
1631                pending.request.request_id.to_string(),
1632            )
1633            .device_memory("actor_state", response.data)
1634            .build();
1635
1636        // Generate checkpoint name
1637        let name = format!(
1638            "{}_{}_step_{}",
1639            self.config.name_prefix, pending.request.actor_slot, response.step
1640        );
1641
1642        // Persist to storage
1643        self.storage.save(&checkpoint, &name)?;
1644        self.total_completed += 1;
1645
1646        // Track in history for retention
1647        let history = self
1648            .checkpoint_history
1649            .entry(pending.request.actor_slot)
1650            .or_default();
1651        history.push(name.clone());
1652
1653        // Enforce retention policy
1654        if self.config.max_snapshots > 0 {
1655            while history.len() > self.config.max_snapshots {
1656                let oldest = history.remove(0);
1657                if let Err(e) = self.storage.delete(&oldest) {
1658                    tracing::warn!(
1659                        checkpoint = oldest,
1660                        error = %e,
1661                        "Failed to delete old checkpoint during retention cleanup"
1662                    );
1663                }
1664            }
1665        }
1666
1667        Ok(Some(name))
1668    }
1669
1670    /// Manually request a snapshot for a specific actor, bypassing the interval timer.
1671    ///
1672    /// This is useful for on-demand snapshots (e.g., before a risky operation)
1673    /// or in tests. Returns `None` if the actor is not registered.
1674    pub fn request_snapshot(&mut self, actor_slot: u32) -> Option<SnapshotRequest> {
1675        let (kernel_id, kernel_type) = self.actors.get(&actor_slot)?.clone();
1676
1677        let request_id = self.next_request_id;
1678        self.next_request_id += 1;
1679
1680        let request = SnapshotRequest {
1681            request_id,
1682            actor_slot,
1683            buffer_offset: 0,
1684            issued_at: SystemTime::now(),
1685        };
1686
1687        self.pending.insert(
1688            request_id,
1689            PendingSnapshot {
1690                request: request.clone(),
1691                kernel_id,
1692                kernel_type,
1693            },
1694        );
1695
1696        Some(request)
1697    }
1698
1699    /// Cancel a pending snapshot request.
1700    ///
1701    /// Returns true if the request was found and cancelled.
1702    pub fn cancel_pending(&mut self, request_id: u64) -> bool {
1703        self.pending.remove(&request_id).is_some()
1704    }
1705
1706    /// Cancel all pending snapshot requests.
1707    pub fn cancel_all_pending(&mut self) {
1708        self.pending.clear();
1709    }
1710
1711    /// Load the most recent checkpoint for an actor.
1712    pub fn load_latest(&self, actor_slot: u32) -> Result<Option<Checkpoint>> {
1713        if let Some(history) = self.checkpoint_history.get(&actor_slot) {
1714            if let Some(latest_name) = history.last() {
1715                return self.storage.load(latest_name).map(Some);
1716            }
1717        }
1718
1719        // Fallback: scan storage for checkpoints matching this actor's prefix
1720        let prefix = format!("{}_{}_", self.config.name_prefix, actor_slot);
1721        let all = self.storage.list()?;
1722        let matching: Vec<_> = all.iter().filter(|n| n.starts_with(&prefix)).collect();
1723
1724        if let Some(latest) = matching.last() {
1725            return self.storage.load(latest).map(Some);
1726        }
1727
1728        Ok(None)
1729    }
1730
1731    /// List all checkpoint names for an actor.
1732    pub fn list_checkpoints(&self, actor_slot: u32) -> Result<Vec<String>> {
1733        let prefix = format!("{}_{}_", self.config.name_prefix, actor_slot);
1734        let all = self.storage.list()?;
1735        Ok(all.into_iter().filter(|n| n.starts_with(&prefix)).collect())
1736    }
1737
1738    /// Get a reference to the storage backend.
1739    pub fn storage(&self) -> &dyn CheckpointStorage {
1740        &*self.storage
1741    }
1742}
1743
1744// ============================================================================
1745// Tests
1746// ============================================================================
1747
// Unit tests for checkpoint serialization, storage backends, the
// CheckpointManager scheduling/retention logic, and delta checkpoints.
#[cfg(test)]
mod tests {
    use super::*;

    /// Removes a temporary directory when dropped, so cleanup also happens when
    /// an assertion fails partway through a test.
    struct TempDirGuard(PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_checkpoint_header_roundtrip() {
        let header = CheckpointHeader::new(5, 1024);
        let bytes = header.to_bytes();
        let restored = CheckpointHeader::from_bytes(&bytes);

        assert_eq!(restored.magic, CHECKPOINT_MAGIC);
        assert_eq!(restored.version, CHECKPOINT_VERSION);
        assert_eq!(restored.chunk_count, 5);
        assert_eq!(restored.total_size, 1024);
    }

    #[test]
    fn test_chunk_header_roundtrip() {
        let header = ChunkHeader::new(ChunkType::DeviceMemory, 4096).with_id(12345);
        let bytes = header.to_bytes();
        let restored = ChunkHeader::from_bytes(&bytes);

        assert_eq!(restored.chunk_type, ChunkType::DeviceMemory as u32);
        assert_eq!(restored.uncompressed_size, 4096);
        assert_eq!(restored.chunk_id, 12345);
    }

    #[test]
    fn test_metadata_roundtrip() {
        let metadata = CheckpointMetadata::new("kernel_1", "fdtd_3d")
            .with_step(1000)
            .with_grid_size(64, 64, 64)
            .with_tile_size(8, 8, 8)
            .with_custom("version", "1.0");

        let bytes = metadata.to_bytes();
        let restored = CheckpointMetadata::from_bytes(&bytes).unwrap();

        assert_eq!(restored.kernel_id, "kernel_1");
        assert_eq!(restored.kernel_type, "fdtd_3d");
        assert_eq!(restored.current_step, 1000);
        assert_eq!(restored.grid_size, (64, 64, 64));
        assert_eq!(restored.tile_size, (8, 8, 8));
        assert_eq!(restored.custom.get("version"), Some(&"1.0".to_string()));
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let checkpoint = CheckpointBuilder::new("test_kernel", "test_type")
            .step(500)
            .grid_size(32, 32, 32)
            .control_block(vec![1, 2, 3, 4])
            .device_memory("pressure_a", vec![5, 6, 7, 8, 9, 10])
            .build();

        let bytes = checkpoint.to_bytes();
        let restored = Checkpoint::from_bytes(&bytes).unwrap();

        assert_eq!(restored.metadata.kernel_id, "test_kernel");
        assert_eq!(restored.metadata.current_step, 500);
        assert_eq!(restored.chunks.len(), 2);

        let control = restored.get_chunk(ChunkType::ControlBlock).unwrap();
        assert_eq!(control.data, vec![1, 2, 3, 4]);
    }

    #[test]
    fn test_memory_storage() {
        let storage = MemoryStorage::new();

        let checkpoint = CheckpointBuilder::new("mem_test", "test").step(100).build();

        storage.save(&checkpoint, "test_001").unwrap();
        assert!(storage.exists("test_001"));

        let loaded = storage.load("test_001").unwrap();
        assert_eq!(loaded.metadata.kernel_id, "mem_test");
        assert_eq!(loaded.metadata.current_step, 100);

        let list = storage.list().unwrap();
        assert_eq!(list, vec!["test_001"]);

        storage.delete("test_001").unwrap();
        assert!(!storage.exists("test_001"));
    }

    #[test]
    fn test_crc32() {
        // Known CRC32 values
        assert_eq!(crc32_simple(b""), 0);
        assert_eq!(crc32_simple(b"123456789"), 0xCBF43926);
    }

    #[test]
    fn test_checkpoint_validation() {
        // Test invalid magic
        let mut bytes = [0u8; 64];
        bytes[0..8].copy_from_slice(&0u64.to_le_bytes()); // Wrong magic

        let header = CheckpointHeader::from_bytes(&bytes);
        assert!(header.validate().is_err());
    }

    #[test]
    fn test_large_checkpoint() {
        // Test with larger data
        let large_data: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();

        let checkpoint = CheckpointBuilder::new("large_kernel", "stress_test")
            .step(999)
            .device_memory("field_a", large_data.clone())
            .device_memory("field_b", large_data.clone())
            .build();

        let bytes = checkpoint.to_bytes();
        let restored = Checkpoint::from_bytes(&bytes).unwrap();

        assert_eq!(restored.chunks.len(), 2);
        let chunks = restored.get_chunks(ChunkType::DeviceMemory);
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].data.len(), 100_000);
    }

    // ========================================================================
    // CheckpointConfig tests
    // ========================================================================

    #[test]
    fn test_checkpoint_config_defaults() {
        let config = CheckpointConfig::default();
        assert_eq!(config.interval, Duration::from_secs(30));
        assert_eq!(config.max_snapshots, 5);
        assert!(config.enabled);
        assert_eq!(config.name_prefix, "checkpoint");
    }

    #[test]
    fn test_checkpoint_config_builder() {
        let config = CheckpointConfig::new(Duration::from_secs(10))
            .with_max_snapshots(3)
            .with_storage_path("/var/checkpoints")
            .with_name_prefix("actor")
            .with_enabled(false);

        assert_eq!(config.interval, Duration::from_secs(10));
        assert_eq!(config.max_snapshots, 3);
        assert_eq!(config.storage_path, PathBuf::from("/var/checkpoints"));
        assert_eq!(config.name_prefix, "actor");
        assert!(!config.enabled);
    }

    // ========================================================================
    // CheckpointManager tests
    // ========================================================================

    #[test]
    fn test_manager_disabled() {
        let config = CheckpointConfig::new(Duration::from_millis(1)).with_enabled(false);
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "kernel_0", "test");

        // Should return no requests when disabled
        let requests = manager.poll_due_snapshots();
        assert!(requests.is_empty());
    }

    #[test]
    fn test_manager_register_and_poll() {
        let config = CheckpointConfig::new(Duration::from_millis(1));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "fdtd_3d");
        manager.register_actor(1, "sim_1", "fdtd_3d");

        // First poll: both actors are due immediately (never snapshotted)
        let requests = manager.poll_due_snapshots();
        assert_eq!(requests.len(), 2);

        // Verify request fields
        let slots: Vec<u32> = requests.iter().map(|r| r.actor_slot).collect();
        assert!(slots.contains(&0));
        assert!(slots.contains(&1));

        // Request IDs should be unique
        assert_ne!(requests[0].request_id, requests[1].request_id);
    }

    #[test]
    fn test_manager_no_duplicate_pending() {
        let config = CheckpointConfig::new(Duration::from_millis(1));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "fdtd_3d");

        // First poll: actor is due
        let requests = manager.poll_due_snapshots();
        assert_eq!(requests.len(), 1);

        // Second poll: actor already has a pending request
        let requests2 = manager.poll_due_snapshots();
        assert!(requests2.is_empty());
    }

    #[test]
    fn test_manager_complete_snapshot() {
        let config = CheckpointConfig::new(Duration::from_secs(3600)).with_name_prefix("test");
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "fdtd_3d");

        // Generate a snapshot request
        let requests = manager.poll_due_snapshots();
        assert_eq!(requests.len(), 1);
        let req = &requests[0];

        // Complete the snapshot
        let response = SnapshotResponse {
            request_id: req.request_id,
            actor_slot: 0,
            success: true,
            data: vec![1, 2, 3, 4, 5],
            step: 1000,
        };

        let name = manager.complete_snapshot(response).unwrap();
        assert!(name.is_some());
        let name = name.unwrap();
        assert_eq!(name, "test_0_step_1000");

        // Verify checkpoint was persisted
        assert!(manager.storage().exists(&name));

        // Load and verify
        let loaded = manager.storage().load(&name).unwrap();
        assert_eq!(loaded.metadata.kernel_id, "sim_0");
        assert_eq!(loaded.metadata.kernel_type, "fdtd_3d");
        assert_eq!(loaded.metadata.current_step, 1000);

        assert_eq!(manager.total_completed(), 1);
        assert_eq!(manager.total_failed(), 0);
    }

    #[test]
    fn test_manager_failed_snapshot() {
        let config = CheckpointConfig::new(Duration::from_secs(3600));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "fdtd_3d");

        // Precondition: exactly one request was issued (fail clearly instead of
        // panicking on an out-of-bounds index below).
        let requests = manager.poll_due_snapshots();
        assert_eq!(requests.len(), 1);
        let req = &requests[0];

        let response = SnapshotResponse {
            request_id: req.request_id,
            actor_slot: 0,
            success: false,
            data: Vec::new(),
            step: 500,
        };

        let result = manager.complete_snapshot(response);
        assert!(result.is_err());
        assert_eq!(manager.total_failed(), 1);
        assert_eq!(manager.total_completed(), 0);
    }

    #[test]
    fn test_manager_retention_policy() {
        let config = CheckpointConfig::new(Duration::from_secs(3600))
            .with_max_snapshots(2)
            .with_name_prefix("ret");
        let storage = Box::new(MemoryStorage::new());
        let mut manager = CheckpointManager::with_storage(config, storage);
        manager.register_actor(0, "sim_0", "test");

        // Create 3 snapshots using manual requests (bypasses interval timer)
        for step in [100u64, 200, 300] {
            let req = manager.request_snapshot(0).unwrap();

            let response = SnapshotResponse {
                request_id: req.request_id,
                actor_slot: 0,
                success: true,
                data: vec![step as u8],
                step,
            };
            manager.complete_snapshot(response).unwrap();
        }

        // Oldest checkpoint (step 100) should have been deleted
        assert!(!manager.storage().exists("ret_0_step_100"));
        // Steps 200 and 300 should remain
        assert!(manager.storage().exists("ret_0_step_200"));
        assert!(manager.storage().exists("ret_0_step_300"));

        assert_eq!(manager.total_completed(), 3);
    }

    #[test]
    fn test_manager_unknown_response() {
        let config = CheckpointConfig::new(Duration::from_secs(3600));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));

        // Response with unknown request_id
        let response = SnapshotResponse {
            request_id: 9999,
            actor_slot: 0,
            success: true,
            data: vec![1, 2, 3],
            step: 100,
        };

        let result = manager.complete_snapshot(response).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_manager_cancel_pending() {
        let config = CheckpointConfig::new(Duration::from_millis(1));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "test");

        let requests = manager.poll_due_snapshots();
        assert_eq!(manager.pending_count(), 1);

        let cancelled = manager.cancel_pending(requests[0].request_id);
        assert!(cancelled);
        assert_eq!(manager.pending_count(), 0);
    }

    #[test]
    fn test_manager_cancel_all_pending() {
        let config = CheckpointConfig::new(Duration::from_millis(1));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "test");
        manager.register_actor(1, "sim_1", "test");

        let _requests = manager.poll_due_snapshots();
        assert_eq!(manager.pending_count(), 2);

        manager.cancel_all_pending();
        assert_eq!(manager.pending_count(), 0);
    }

    #[test]
    fn test_manager_load_latest() {
        let config = CheckpointConfig::new(Duration::from_secs(3600)).with_name_prefix("lat");
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "test");

        // No checkpoints yet
        let latest = manager.load_latest(0).unwrap();
        assert!(latest.is_none());

        // Create two checkpoints using manual requests
        for step in [100u64, 200] {
            let req = manager.request_snapshot(0).unwrap();
            let response = SnapshotResponse {
                request_id: req.request_id,
                actor_slot: 0,
                success: true,
                data: vec![step as u8],
                step,
            };
            manager.complete_snapshot(response).unwrap();
        }

        // Latest should be step 200
        let latest = manager.load_latest(0).unwrap().unwrap();
        assert_eq!(latest.metadata.current_step, 200);
    }

    #[test]
    fn test_manager_list_checkpoints() {
        let config = CheckpointConfig::new(Duration::from_secs(3600))
            .with_max_snapshots(10)
            .with_name_prefix("list");
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "test");
        manager.register_actor(1, "sim_1", "test");

        // Create checkpoints for both actors using manual requests
        for step in [100u64, 200] {
            for actor_slot in [0u32, 1] {
                let req = manager.request_snapshot(actor_slot).unwrap();
                let response = SnapshotResponse {
                    request_id: req.request_id,
                    actor_slot,
                    success: true,
                    data: vec![step as u8],
                    step,
                };
                manager.complete_snapshot(response).unwrap();
            }
        }

        let actor0_checkpoints = manager.list_checkpoints(0).unwrap();
        let actor1_checkpoints = manager.list_checkpoints(1).unwrap();

        assert_eq!(actor0_checkpoints.len(), 2);
        assert_eq!(actor1_checkpoints.len(), 2);

        // Actor 0's checkpoints should only contain its own
        for name in &actor0_checkpoints {
            assert!(name.starts_with("list_0_"));
        }
        // Actor 1's checkpoints should only contain its own
        for name in &actor1_checkpoints {
            assert!(name.starts_with("list_1_"));
        }
    }

    #[test]
    fn test_manager_unregister_actor() {
        let config = CheckpointConfig::new(Duration::from_millis(1));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "test");

        let requests = manager.poll_due_snapshots();
        assert_eq!(requests.len(), 1);

        // Unregister the actor, then drop its outstanding request so that the
        // follow-up poll reflects only the effect of unregistration.
        manager.unregister_actor(0);
        manager.cancel_all_pending();

        let requests = manager.poll_due_snapshots();
        assert!(requests.is_empty());
    }

    #[test]
    fn test_snapshot_request_response_roundtrip() {
        // Test that the request/response types work correctly
        let request = SnapshotRequest {
            request_id: 42,
            actor_slot: 7,
            buffer_offset: 4096,
            issued_at: SystemTime::now(),
        };

        assert_eq!(request.request_id, 42);
        assert_eq!(request.actor_slot, 7);
        assert_eq!(request.buffer_offset, 4096);

        let response = SnapshotResponse {
            request_id: 42,
            actor_slot: 7,
            success: true,
            data: vec![0xDE, 0xAD, 0xBE, 0xEF],
            step: 5000,
        };

        assert_eq!(response.request_id, request.request_id);
        assert_eq!(response.actor_slot, request.actor_slot);
        assert!(response.success);
        assert_eq!(response.step, 5000);
    }

    #[test]
    fn test_manager_interval_respected() {
        // Use a long interval so second poll returns nothing
        let config = CheckpointConfig::new(Duration::from_secs(3600));
        let mut manager = CheckpointManager::with_storage(config, Box::new(MemoryStorage::new()));
        manager.register_actor(0, "sim_0", "test");

        // First poll: due immediately
        let requests = manager.poll_due_snapshots();
        assert_eq!(requests.len(), 1);

        // Complete the snapshot
        let response = SnapshotResponse {
            request_id: requests[0].request_id,
            actor_slot: 0,
            success: true,
            data: vec![1],
            step: 100,
        };
        manager.complete_snapshot(response).unwrap();

        // Second poll: not due yet (interval is 1 hour)
        let requests = manager.poll_due_snapshots();
        assert!(requests.is_empty());
    }

    #[test]
    fn test_file_storage_roundtrip() {
        // Use a per-process temp directory so concurrent test runs do not share
        // files, and start from a clean slate in case a previous run left files
        // behind.
        let tmp_dir = std::env::temp_dir().join(format!(
            "ringkernel_checkpoint_test_{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&tmp_dir);
        // Declared before `manager`, so it is dropped after it; cleanup runs
        // even if an assertion below panics.
        let _cleanup = TempDirGuard(tmp_dir.clone());

        let config = CheckpointConfig::new(Duration::from_millis(1))
            .with_storage_path(&tmp_dir)
            .with_name_prefix("file_test");
        let mut manager = CheckpointManager::new(config);
        manager.register_actor(0, "file_kernel", "test_type");

        let requests = manager.poll_due_snapshots();
        assert_eq!(requests.len(), 1);

        let response = SnapshotResponse {
            request_id: requests[0].request_id,
            actor_slot: 0,
            success: true,
            data: vec![10, 20, 30, 40, 50],
            step: 42,
        };

        let name = manager.complete_snapshot(response).unwrap().unwrap();

        // Verify the file exists on disk
        let file_path = tmp_dir.join(format!("{}.rkcp", name));
        assert!(file_path.exists());

        // Load it back
        let loaded = manager.load_latest(0).unwrap().unwrap();
        assert_eq!(loaded.metadata.kernel_id, "file_kernel");
        assert_eq!(loaded.metadata.current_step, 42);
    }

    // ========== Delta / incremental checkpoints ==========

    fn build_sample_checkpoint(control: &[u8], mem: &[u8]) -> Checkpoint {
        let meta = CheckpointMetadata::new("delta_test", "sim").with_step(0);
        let mut cp = Checkpoint::new(meta);
        cp.add_control_block(control.to_vec());
        cp.add_device_memory("pressure", mem.to_vec());
        cp
    }

    #[test]
    fn delta_from_empty_when_new_matches_base() {
        let base = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
        let new = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
        let delta = Checkpoint::delta_from(&base, &new);
        assert!(
            delta.chunks.is_empty(),
            "unchanged chunks should be omitted"
        );
        assert_eq!(
            delta.metadata.custom.get(DELTA_PARENT_DIGEST_KEY).cloned(),
            Some(base.content_digest())
        );
    }

    #[test]
    fn delta_captures_changed_and_new_chunks() {
        let base = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
        let mut new = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 9, 9]); // mem changed
        new.add_h2k_queue(vec![42, 42]); // new chunk
        let delta = Checkpoint::delta_from(&base, &new);

        // Device memory should be in delta; control block should not be.
        assert!(delta
            .chunks
            .iter()
            .any(|c| c.chunk_type() == Some(ChunkType::DeviceMemory)));
        assert!(delta
            .chunks
            .iter()
            .any(|c| c.chunk_type() == Some(ChunkType::H2KQueue)));
        assert!(!delta
            .chunks
            .iter()
            .any(|c| c.chunk_type() == Some(ChunkType::ControlBlock)));
    }

    #[test]
    fn delta_apply_recovers_new_checkpoint() {
        let base = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
        let mut new = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 9, 9]);
        new.add_h2k_queue(vec![42, 42]);
        let delta = Checkpoint::delta_from(&base, &new);

        let restored = Checkpoint::applied_with_delta(&base, &delta).expect("apply");
        // Restored should have same chunk identities as `new`.
        assert_eq!(restored.chunks.len(), new.chunks.len());
        for chunk in &new.chunks {
            let id = chunk.chunk_identity().unwrap();
            let found = restored
                .chunks
                .iter()
                .find(|c| c.chunk_identity() == Some(id))
                .expect("identity present");
            assert_eq!(found.data, chunk.data, "chunk {id:?} bytes match");
        }
    }

    #[test]
    fn delta_apply_rejects_wrong_base() {
        let base_a = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
        let base_b = build_sample_checkpoint(&[9, 9, 9], &[8, 8, 8, 8]);
        let new = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 9, 9]);
        let delta = Checkpoint::delta_from(&base_a, &new);
        let err = Checkpoint::applied_with_delta(&base_b, &delta)
            .expect_err("different base should fail");
        assert!(matches!(err, RingKernelError::InvalidCheckpoint(_)));
    }

    #[test]
    fn content_digest_stable_across_identical_chunks() {
        let a = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
        let b = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
        assert_eq!(a.content_digest(), b.content_digest());
    }
}