ringkernel_core/
checkpoint.rs

1//! Kernel checkpointing for persistent state snapshot and restore.
2//!
3//! This module provides infrastructure for checkpointing persistent GPU kernels,
4//! enabling fault tolerance, migration, and debugging capabilities.
5//!
6//! # Architecture
7//!
8//! ```text
9//! ┌─────────────────────────────────────────────────────────────────┐
10//! │                    CheckpointableKernel                         │
11//! │  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐  │
12//! │  │ Control     │  │ Queue       │  │ Device Memory           │  │
13//! │  │ Block       │  │ State       │  │ (pressure, halo, etc.)  │  │
14//! │  └─────────────┘  └─────────────┘  └─────────────────────────┘  │
15//! └─────────────────────────────────────────────────────────────────┘
16//!                              │
17//!                              ▼
18//! ┌─────────────────────────────────────────────────────────────────┐
19//! │                        Checkpoint                               │
20//! │  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐  │
21//! │  │ Header      │  │ Metadata    │  │ Compressed Data Chunks  │  │
22//! │  │ (magic,ver) │  │ (kernel_id) │  │ (control,queues,memory) │  │
23//! │  └─────────────┘  └─────────────┘  └─────────────────────────┘  │
24//! └─────────────────────────────────────────────────────────────────┘
25//!                              │
26//!                              ▼
27//! ┌─────────────────────────────────────────────────────────────────┐
28//! │                   CheckpointStorage                             │
29//! │  ┌─────────────┐  ┌─────────────┐  ┌─────────────────────────┐  │
30//! │  │ File        │  │ Memory      │  │ Cloud (S3/GCS)          │  │
31//! │  │ Backend     │  │ Backend     │  │ Backend                 │  │
32//! │  └─────────────┘  └─────────────┘  └─────────────────────────┘  │
33//! └─────────────────────────────────────────────────────────────────┘
34//! ```
35//!
36//! # Example
37//!
38//! ```ignore
//! use ringkernel_core::checkpoint::{Checkpoint, CheckpointStorage, CheckpointableKernel, FileStorage};
40//!
41//! // Create checkpoint from running kernel
42//! let checkpoint = kernel.create_checkpoint()?;
43//!
44//! // Save to file
45//! let storage = FileStorage::new("/checkpoints");
46//! storage.save(&checkpoint, "sim_step_1000")?;
47//!
48//! // Later: restore from checkpoint
49//! let checkpoint = storage.load("sim_step_1000")?;
50//! kernel.restore_from_checkpoint(&checkpoint)?;
51//! ```
52
53use std::collections::HashMap;
54use std::io::{Read, Write};
55use std::path::{Path, PathBuf};
56use std::time::{Duration, SystemTime, UNIX_EPOCH};
57
58use crate::error::{Result, RingKernelError};
59use crate::hlc::HlcTimestamp;
60
61// ============================================================================
62// Checkpoint Format Constants
63// ============================================================================
64
/// Magic number identifying checkpoint data.
///
/// Read most-significant byte first, the value spells "RKCKPT01" in ASCII.
/// `CheckpointHeader::to_bytes` writes it little-endian, so the bytes appear
/// in reverse order in a serialized header.
pub const CHECKPOINT_MAGIC: u64 = 0x524B434B50543031;

/// Current checkpoint format version.
///
/// `CheckpointHeader::validate` rejects headers whose version is greater than
/// this value.
pub const CHECKPOINT_VERSION: u32 = 1;

/// Maximum supported checkpoint size: 1 GiB (1024 * 1024 * 1024 bytes).
///
/// `CheckpointHeader::validate` rejects headers whose `total_size` exceeds it.
pub const MAX_CHECKPOINT_SIZE: usize = 1024 * 1024 * 1024;
73
/// Tag identifying what kind of state a checkpoint data chunk holds.
///
/// The enum is `#[repr(u32)]` and every variant carries an explicit
/// discriminant; [`ChunkType::from_u32`] performs the reverse mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ChunkType {
    /// Control block state (256 bytes typically).
    ControlBlock = 1,
    /// H2K queue header and pending messages.
    H2KQueue = 2,
    /// K2H queue header and pending messages.
    K2HQueue = 3,
    /// HLC timestamp state.
    HlcState = 4,
    /// Device memory region (e.g., pressure field).
    DeviceMemory = 5,
    /// K2K routing table.
    K2KRouting = 6,
    /// Halo exchange buffers.
    HaloBuffers = 7,
    /// Telemetry statistics.
    Telemetry = 8,
    /// Custom application data.
    Custom = 100,
}

impl ChunkType {
    /// Every defined variant; searched by [`ChunkType::from_u32`].
    const ALL: [ChunkType; 9] = [
        ChunkType::ControlBlock,
        ChunkType::H2KQueue,
        ChunkType::K2HQueue,
        ChunkType::HlcState,
        ChunkType::DeviceMemory,
        ChunkType::K2KRouting,
        ChunkType::HaloBuffers,
        ChunkType::Telemetry,
        ChunkType::Custom,
    ];

    /// Map a raw `u32` tag back to its variant.
    ///
    /// Returns `None` when `value` is not the discriminant of any variant.
    pub fn from_u32(value: u32) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| *candidate as u32 == value)
    }
}
115
116// ============================================================================
117// Checkpoint Header
118// ============================================================================
119
120/// Checkpoint file header (64 bytes, fixed size).
121#[derive(Debug, Clone, Copy)]
122#[repr(C)]
123pub struct CheckpointHeader {
124    /// Magic number for format identification.
125    pub magic: u64,
126    /// Format version number.
127    pub version: u32,
128    /// Header size in bytes.
129    pub header_size: u32,
130    /// Total checkpoint size in bytes (including header).
131    pub total_size: u64,
132    /// Number of data chunks.
133    pub chunk_count: u32,
134    /// Compression algorithm (0 = none, 1 = lz4, 2 = zstd).
135    pub compression: u32,
136    /// CRC32 checksum of all data after header.
137    pub checksum: u32,
138    /// Flags (reserved for future use).
139    pub flags: u32,
140    /// Timestamp when checkpoint was created (UNIX epoch microseconds).
141    pub created_at: u64,
142    /// Reserved for alignment.
143    pub _reserved: [u8; 8],
144}
145
146impl CheckpointHeader {
147    /// Create a new checkpoint header.
148    pub fn new(chunk_count: u32, total_size: u64) -> Self {
149        let now = SystemTime::now()
150            .duration_since(UNIX_EPOCH)
151            .unwrap_or(Duration::ZERO);
152
153        Self {
154            magic: CHECKPOINT_MAGIC,
155            version: CHECKPOINT_VERSION,
156            header_size: std::mem::size_of::<Self>() as u32,
157            total_size,
158            chunk_count,
159            compression: 0,
160            checksum: 0,
161            flags: 0,
162            created_at: now.as_micros() as u64,
163            _reserved: [0; 8],
164        }
165    }
166
167    /// Validate the header.
168    pub fn validate(&self) -> Result<()> {
169        if self.magic != CHECKPOINT_MAGIC {
170            return Err(RingKernelError::InvalidCheckpoint(
171                "Invalid magic number".to_string(),
172            ));
173        }
174        if self.version > CHECKPOINT_VERSION {
175            return Err(RingKernelError::InvalidCheckpoint(format!(
176                "Unsupported version: {} (max: {})",
177                self.version, CHECKPOINT_VERSION
178            )));
179        }
180        if self.total_size as usize > MAX_CHECKPOINT_SIZE {
181            return Err(RingKernelError::InvalidCheckpoint(format!(
182                "Checkpoint too large: {} bytes (max: {})",
183                self.total_size, MAX_CHECKPOINT_SIZE
184            )));
185        }
186        Ok(())
187    }
188
189    /// Serialize to bytes.
190    pub fn to_bytes(&self) -> [u8; 64] {
191        let mut bytes = [0u8; 64];
192        bytes[0..8].copy_from_slice(&self.magic.to_le_bytes());
193        bytes[8..12].copy_from_slice(&self.version.to_le_bytes());
194        bytes[12..16].copy_from_slice(&self.header_size.to_le_bytes());
195        bytes[16..24].copy_from_slice(&self.total_size.to_le_bytes());
196        bytes[24..28].copy_from_slice(&self.chunk_count.to_le_bytes());
197        bytes[28..32].copy_from_slice(&self.compression.to_le_bytes());
198        bytes[32..36].copy_from_slice(&self.checksum.to_le_bytes());
199        bytes[36..40].copy_from_slice(&self.flags.to_le_bytes());
200        bytes[40..48].copy_from_slice(&self.created_at.to_le_bytes());
201        bytes
202    }
203
204    /// Deserialize from bytes.
205    pub fn from_bytes(bytes: &[u8; 64]) -> Self {
206        Self {
207            magic: u64::from_le_bytes(bytes[0..8].try_into().unwrap()),
208            version: u32::from_le_bytes(bytes[8..12].try_into().unwrap()),
209            header_size: u32::from_le_bytes(bytes[12..16].try_into().unwrap()),
210            total_size: u64::from_le_bytes(bytes[16..24].try_into().unwrap()),
211            chunk_count: u32::from_le_bytes(bytes[24..28].try_into().unwrap()),
212            compression: u32::from_le_bytes(bytes[28..32].try_into().unwrap()),
213            checksum: u32::from_le_bytes(bytes[32..36].try_into().unwrap()),
214            flags: u32::from_le_bytes(bytes[36..40].try_into().unwrap()),
215            created_at: u64::from_le_bytes(bytes[40..48].try_into().unwrap()),
216            _reserved: bytes[48..56].try_into().unwrap(),
217        }
218    }
219}
220
221// ============================================================================
222// Chunk Header
223// ============================================================================
224
225/// Header for each data chunk (32 bytes).
226#[derive(Debug, Clone, Copy)]
227#[repr(C)]
228pub struct ChunkHeader {
229    /// Chunk type identifier.
230    pub chunk_type: u32,
231    /// Chunk flags (compression, etc.).
232    pub flags: u32,
233    /// Uncompressed data size.
234    pub uncompressed_size: u64,
235    /// Compressed data size (same as uncompressed if not compressed).
236    pub compressed_size: u64,
237    /// Chunk-specific identifier (e.g., memory region name hash).
238    pub chunk_id: u64,
239}
240
241impl ChunkHeader {
242    /// Create a new chunk header.
243    pub fn new(chunk_type: ChunkType, data_size: usize) -> Self {
244        Self {
245            chunk_type: chunk_type as u32,
246            flags: 0,
247            uncompressed_size: data_size as u64,
248            compressed_size: data_size as u64,
249            chunk_id: 0,
250        }
251    }
252
253    /// Set the chunk ID.
254    pub fn with_id(mut self, id: u64) -> Self {
255        self.chunk_id = id;
256        self
257    }
258
259    /// Serialize to bytes.
260    pub fn to_bytes(&self) -> [u8; 32] {
261        let mut bytes = [0u8; 32];
262        bytes[0..4].copy_from_slice(&self.chunk_type.to_le_bytes());
263        bytes[4..8].copy_from_slice(&self.flags.to_le_bytes());
264        bytes[8..16].copy_from_slice(&self.uncompressed_size.to_le_bytes());
265        bytes[16..24].copy_from_slice(&self.compressed_size.to_le_bytes());
266        bytes[24..32].copy_from_slice(&self.chunk_id.to_le_bytes());
267        bytes
268    }
269
270    /// Deserialize from bytes.
271    pub fn from_bytes(bytes: &[u8; 32]) -> Self {
272        Self {
273            chunk_type: u32::from_le_bytes(bytes[0..4].try_into().unwrap()),
274            flags: u32::from_le_bytes(bytes[4..8].try_into().unwrap()),
275            uncompressed_size: u64::from_le_bytes(bytes[8..16].try_into().unwrap()),
276            compressed_size: u64::from_le_bytes(bytes[16..24].try_into().unwrap()),
277            chunk_id: u64::from_le_bytes(bytes[24..32].try_into().unwrap()),
278        }
279    }
280}
281
282// ============================================================================
283// Checkpoint Metadata
284// ============================================================================
285
286/// Kernel-specific metadata stored in checkpoint.
287#[derive(Debug, Clone, Default)]
288pub struct CheckpointMetadata {
289    /// Unique kernel identifier.
290    pub kernel_id: String,
291    /// Kernel type (e.g., "fdtd_3d", "wave_sim").
292    pub kernel_type: String,
293    /// Current simulation step.
294    pub current_step: u64,
295    /// Grid dimensions.
296    pub grid_size: (u32, u32, u32),
297    /// Tile/block dimensions.
298    pub tile_size: (u32, u32, u32),
299    /// HLC timestamp at checkpoint time.
300    pub hlc_timestamp: HlcTimestamp,
301    /// Custom key-value metadata.
302    pub custom: HashMap<String, String>,
303}
304
305impl CheckpointMetadata {
306    /// Create new metadata for a kernel.
307    pub fn new(kernel_id: impl Into<String>, kernel_type: impl Into<String>) -> Self {
308        Self {
309            kernel_id: kernel_id.into(),
310            kernel_type: kernel_type.into(),
311            ..Default::default()
312        }
313    }
314
315    /// Set current step.
316    pub fn with_step(mut self, step: u64) -> Self {
317        self.current_step = step;
318        self
319    }
320
321    /// Set grid size.
322    pub fn with_grid_size(mut self, width: u32, height: u32, depth: u32) -> Self {
323        self.grid_size = (width, height, depth);
324        self
325    }
326
327    /// Set tile size.
328    pub fn with_tile_size(mut self, x: u32, y: u32, z: u32) -> Self {
329        self.tile_size = (x, y, z);
330        self
331    }
332
333    /// Set HLC timestamp.
334    pub fn with_hlc(mut self, hlc: HlcTimestamp) -> Self {
335        self.hlc_timestamp = hlc;
336        self
337    }
338
339    /// Add custom metadata.
340    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
341        self.custom.insert(key.into(), value.into());
342        self
343    }
344
345    /// Serialize metadata to bytes.
346    pub fn to_bytes(&self) -> Vec<u8> {
347        let mut bytes = Vec::new();
348
349        // Kernel ID (length-prefixed string)
350        let kernel_id_bytes = self.kernel_id.as_bytes();
351        bytes.extend_from_slice(&(kernel_id_bytes.len() as u32).to_le_bytes());
352        bytes.extend_from_slice(kernel_id_bytes);
353
354        // Kernel type
355        let kernel_type_bytes = self.kernel_type.as_bytes();
356        bytes.extend_from_slice(&(kernel_type_bytes.len() as u32).to_le_bytes());
357        bytes.extend_from_slice(kernel_type_bytes);
358
359        // Current step
360        bytes.extend_from_slice(&self.current_step.to_le_bytes());
361
362        // Grid size
363        bytes.extend_from_slice(&self.grid_size.0.to_le_bytes());
364        bytes.extend_from_slice(&self.grid_size.1.to_le_bytes());
365        bytes.extend_from_slice(&self.grid_size.2.to_le_bytes());
366
367        // Tile size
368        bytes.extend_from_slice(&self.tile_size.0.to_le_bytes());
369        bytes.extend_from_slice(&self.tile_size.1.to_le_bytes());
370        bytes.extend_from_slice(&self.tile_size.2.to_le_bytes());
371
372        // HLC timestamp
373        bytes.extend_from_slice(&self.hlc_timestamp.physical.to_le_bytes());
374        bytes.extend_from_slice(&self.hlc_timestamp.logical.to_le_bytes());
375        bytes.extend_from_slice(&self.hlc_timestamp.node_id.to_le_bytes());
376
377        // Custom metadata count
378        bytes.extend_from_slice(&(self.custom.len() as u32).to_le_bytes());
379
380        // Custom key-value pairs
381        for (key, value) in &self.custom {
382            let key_bytes = key.as_bytes();
383            bytes.extend_from_slice(&(key_bytes.len() as u32).to_le_bytes());
384            bytes.extend_from_slice(key_bytes);
385
386            let value_bytes = value.as_bytes();
387            bytes.extend_from_slice(&(value_bytes.len() as u32).to_le_bytes());
388            bytes.extend_from_slice(value_bytes);
389        }
390
391        bytes
392    }
393
394    /// Deserialize metadata from bytes.
395    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
396        let mut offset = 0;
397
398        // Helper to read u32
399        let read_u32 = |off: &mut usize| -> Result<u32> {
400            if *off + 4 > bytes.len() {
401                return Err(RingKernelError::InvalidCheckpoint(
402                    "Unexpected end of metadata".to_string(),
403                ));
404            }
405            let val = u32::from_le_bytes(bytes[*off..*off + 4].try_into().unwrap());
406            *off += 4;
407            Ok(val)
408        };
409
410        // Helper to read u64
411        let read_u64 = |off: &mut usize| -> Result<u64> {
412            if *off + 8 > bytes.len() {
413                return Err(RingKernelError::InvalidCheckpoint(
414                    "Unexpected end of metadata".to_string(),
415                ));
416            }
417            let val = u64::from_le_bytes(bytes[*off..*off + 8].try_into().unwrap());
418            *off += 8;
419            Ok(val)
420        };
421
422        // Helper to read string
423        let read_string = |off: &mut usize| -> Result<String> {
424            let len = read_u32(off)? as usize;
425            if *off + len > bytes.len() {
426                return Err(RingKernelError::InvalidCheckpoint(
427                    "Unexpected end of metadata".to_string(),
428                ));
429            }
430            let s = String::from_utf8(bytes[*off..*off + len].to_vec())
431                .map_err(|e| RingKernelError::InvalidCheckpoint(e.to_string()))?;
432            *off += len;
433            Ok(s)
434        };
435
436        let kernel_id = read_string(&mut offset)?;
437        let kernel_type = read_string(&mut offset)?;
438        let current_step = read_u64(&mut offset)?;
439
440        let grid_size = (
441            read_u32(&mut offset)?,
442            read_u32(&mut offset)?,
443            read_u32(&mut offset)?,
444        );
445
446        let tile_size = (
447            read_u32(&mut offset)?,
448            read_u32(&mut offset)?,
449            read_u32(&mut offset)?,
450        );
451
452        let hlc_timestamp = HlcTimestamp {
453            physical: read_u64(&mut offset)?,
454            logical: read_u64(&mut offset)?,
455            node_id: read_u64(&mut offset)?,
456        };
457
458        let custom_count = read_u32(&mut offset)? as usize;
459        let mut custom = HashMap::new();
460
461        for _ in 0..custom_count {
462            let key = read_string(&mut offset)?;
463            let value = read_string(&mut offset)?;
464            custom.insert(key, value);
465        }
466
467        Ok(Self {
468            kernel_id,
469            kernel_type,
470            current_step,
471            grid_size,
472            tile_size,
473            hlc_timestamp,
474            custom,
475        })
476    }
477}
478
479// ============================================================================
480// Checkpoint Data Chunk
481// ============================================================================
482
483/// A single data chunk in a checkpoint.
484#[derive(Debug, Clone)]
485pub struct DataChunk {
486    /// Chunk header.
487    pub header: ChunkHeader,
488    /// Chunk data (may be compressed).
489    pub data: Vec<u8>,
490}
491
492impl DataChunk {
493    /// Create a new data chunk.
494    pub fn new(chunk_type: ChunkType, data: Vec<u8>) -> Self {
495        Self {
496            header: ChunkHeader::new(chunk_type, data.len()),
497            data,
498        }
499    }
500
501    /// Create a chunk with a custom ID.
502    pub fn with_id(chunk_type: ChunkType, data: Vec<u8>, id: u64) -> Self {
503        Self {
504            header: ChunkHeader::new(chunk_type, data.len()).with_id(id),
505            data,
506        }
507    }
508
509    /// Get the chunk type.
510    pub fn chunk_type(&self) -> Option<ChunkType> {
511        ChunkType::from_u32(self.header.chunk_type)
512    }
513}
514
515// ============================================================================
516// Checkpoint
517// ============================================================================
518
519/// Complete checkpoint containing all kernel state.
520#[derive(Debug, Clone)]
521pub struct Checkpoint {
522    /// Checkpoint header.
523    pub header: CheckpointHeader,
524    /// Kernel metadata.
525    pub metadata: CheckpointMetadata,
526    /// Data chunks.
527    pub chunks: Vec<DataChunk>,
528}
529
530impl Checkpoint {
531    /// Create a new checkpoint.
532    pub fn new(metadata: CheckpointMetadata) -> Self {
533        Self {
534            header: CheckpointHeader::new(0, 0),
535            metadata,
536            chunks: Vec::new(),
537        }
538    }
539
540    /// Add a data chunk.
541    pub fn add_chunk(&mut self, chunk: DataChunk) {
542        self.chunks.push(chunk);
543    }
544
545    /// Add control block data.
546    pub fn add_control_block(&mut self, data: Vec<u8>) {
547        self.add_chunk(DataChunk::new(ChunkType::ControlBlock, data));
548    }
549
550    /// Add H2K queue data.
551    pub fn add_h2k_queue(&mut self, data: Vec<u8>) {
552        self.add_chunk(DataChunk::new(ChunkType::H2KQueue, data));
553    }
554
555    /// Add K2H queue data.
556    pub fn add_k2h_queue(&mut self, data: Vec<u8>) {
557        self.add_chunk(DataChunk::new(ChunkType::K2HQueue, data));
558    }
559
560    /// Add HLC state.
561    pub fn add_hlc_state(&mut self, data: Vec<u8>) {
562        self.add_chunk(DataChunk::new(ChunkType::HlcState, data));
563    }
564
565    /// Add device memory region.
566    pub fn add_device_memory(&mut self, name: &str, data: Vec<u8>) {
567        // Use hash of name as chunk ID
568        use std::collections::hash_map::DefaultHasher;
569        use std::hash::{Hash, Hasher};
570        let mut hasher = DefaultHasher::new();
571        name.hash(&mut hasher);
572        let id = hasher.finish();
573
574        self.add_chunk(DataChunk::with_id(ChunkType::DeviceMemory, data, id));
575    }
576
577    /// Get a chunk by type.
578    pub fn get_chunk(&self, chunk_type: ChunkType) -> Option<&DataChunk> {
579        self.chunks
580            .iter()
581            .find(|c| c.chunk_type() == Some(chunk_type))
582    }
583
584    /// Get all chunks of a type.
585    pub fn get_chunks(&self, chunk_type: ChunkType) -> Vec<&DataChunk> {
586        self.chunks
587            .iter()
588            .filter(|c| c.chunk_type() == Some(chunk_type))
589            .collect()
590    }
591
592    /// Calculate total size in bytes.
593    pub fn total_size(&self) -> usize {
594        let header_size = std::mem::size_of::<CheckpointHeader>();
595        let metadata_bytes = self.metadata.to_bytes();
596        let metadata_size = 4 + metadata_bytes.len(); // length prefix + data
597
598        let chunks_size: usize = self
599            .chunks
600            .iter()
601            .map(|c| std::mem::size_of::<ChunkHeader>() + c.data.len())
602            .sum();
603
604        header_size + metadata_size + chunks_size
605    }
606
607    /// Serialize checkpoint to bytes.
608    pub fn to_bytes(&self) -> Vec<u8> {
609        let mut bytes = Vec::new();
610
611        // Metadata as bytes
612        let metadata_bytes = self.metadata.to_bytes();
613
614        // Calculate total size
615        let total_size = self.total_size();
616
617        // Create header with correct values
618        let header = CheckpointHeader::new(self.chunks.len() as u32, total_size as u64);
619
620        // Write header
621        bytes.extend_from_slice(&header.to_bytes());
622
623        // Write metadata (length-prefixed)
624        bytes.extend_from_slice(&(metadata_bytes.len() as u32).to_le_bytes());
625        bytes.extend_from_slice(&metadata_bytes);
626
627        // Write chunks
628        for chunk in &self.chunks {
629            bytes.extend_from_slice(&chunk.header.to_bytes());
630            bytes.extend_from_slice(&chunk.data);
631        }
632
633        // Calculate checksum (simple CRC32 of data after header) and update in place
634        let checksum = crc32_simple(&bytes[64..]);
635        bytes[32..36].copy_from_slice(&checksum.to_le_bytes());
636
637        bytes
638    }
639
640    /// Deserialize checkpoint from bytes.
641    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
642        if bytes.len() < 64 {
643            return Err(RingKernelError::InvalidCheckpoint(
644                "Checkpoint too small".to_string(),
645            ));
646        }
647
648        // Read header
649        let header = CheckpointHeader::from_bytes(bytes[0..64].try_into().unwrap());
650        header.validate()?;
651
652        // Verify checksum
653        let expected_checksum = crc32_simple(&bytes[64..]);
654        if header.checksum != expected_checksum {
655            return Err(RingKernelError::InvalidCheckpoint(format!(
656                "Checksum mismatch: expected {}, got {}",
657                expected_checksum, header.checksum
658            )));
659        }
660
661        let mut offset = 64;
662
663        // Read metadata length
664        if offset + 4 > bytes.len() {
665            return Err(RingKernelError::InvalidCheckpoint(
666                "Missing metadata length".to_string(),
667            ));
668        }
669        let metadata_len =
670            u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()) as usize;
671        offset += 4;
672
673        // Read metadata
674        if offset + metadata_len > bytes.len() {
675            return Err(RingKernelError::InvalidCheckpoint(
676                "Metadata truncated".to_string(),
677            ));
678        }
679        let metadata = CheckpointMetadata::from_bytes(&bytes[offset..offset + metadata_len])?;
680        offset += metadata_len;
681
682        // Read chunks
683        let mut chunks = Vec::new();
684        for _ in 0..header.chunk_count {
685            if offset + 32 > bytes.len() {
686                return Err(RingKernelError::InvalidCheckpoint(
687                    "Chunk header truncated".to_string(),
688                ));
689            }
690
691            let chunk_header =
692                ChunkHeader::from_bytes(bytes[offset..offset + 32].try_into().unwrap());
693            offset += 32;
694
695            let data_len = chunk_header.compressed_size as usize;
696            if offset + data_len > bytes.len() {
697                return Err(RingKernelError::InvalidCheckpoint(
698                    "Chunk data truncated".to_string(),
699                ));
700            }
701
702            let data = bytes[offset..offset + data_len].to_vec();
703            offset += data_len;
704
705            chunks.push(DataChunk {
706                header: chunk_header,
707                data,
708            });
709        }
710
711        Ok(Self {
712            header,
713            metadata,
714            chunks,
715        })
716    }
717}
718
719// ============================================================================
720// Simple CRC32 Implementation
721// ============================================================================
722
723/// Simple CRC32 checksum (IEEE polynomial).
724fn crc32_simple(data: &[u8]) -> u32 {
725    const CRC32_TABLE: [u32; 256] = crc32_table();
726
727    let mut crc = 0xFFFFFFFF;
728    for byte in data {
729        let index = ((crc ^ (*byte as u32)) & 0xFF) as usize;
730        crc = CRC32_TABLE[index] ^ (crc >> 8);
731    }
732    !crc
733}
734
/// Build the 256-entry CRC32 lookup table; usable in `const` context.
///
/// Entry `n` is the result of running byte value `n` through eight
/// shift-and-conditionally-xor steps with the reflected polynomial.
const fn crc32_table() -> [u32; 256] {
    const POLY: u32 = 0xEDB88320;

    let mut entries = [0u32; 256];
    let mut byte: usize = 0;
    while byte < 256 {
        let mut value = byte as u32;
        let mut remaining_bits = 8;
        while remaining_bits > 0 {
            // `0 - (value & 1)` wraps to all ones when the low bit is set and
            // stays zero otherwise, selecting whether POLY is xor-ed in.
            let mask = 0u32.wrapping_sub(value & 1);
            value = (value >> 1) ^ (POLY & mask);
            remaining_bits -= 1;
        }
        entries[byte] = value;
        byte += 1;
    }
    entries
}
755
756// ============================================================================
757// CheckpointableKernel Trait
758// ============================================================================
759
/// Trait for kernels that support checkpointing.
///
/// Implementors provide full checkpoint creation and restoration plus two
/// identifying strings. Incremental checkpointing is optional: the provided
/// defaults report it as unsupported and fall back to a full checkpoint.
pub trait CheckpointableKernel {
    /// Create a checkpoint of the current kernel state.
    ///
    /// This should pause the kernel, serialize all state, and return a checkpoint.
    ///
    /// # Errors
    ///
    /// Implementation-defined; returned when the checkpoint cannot be created.
    fn create_checkpoint(&self) -> Result<Checkpoint>;

    /// Restore kernel state from a checkpoint.
    ///
    /// This should pause the kernel, deserialize state, and resume.
    ///
    /// # Errors
    ///
    /// Implementation-defined; returned when the state cannot be restored
    /// from `checkpoint`.
    fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()>;

    /// Get the kernel ID for checkpointing.
    fn checkpoint_kernel_id(&self) -> &str;

    /// Get the kernel type for checkpointing.
    fn checkpoint_kernel_type(&self) -> &str;

    /// Check if the kernel supports incremental checkpoints.
    ///
    /// The default implementation returns `false`.
    fn supports_incremental(&self) -> bool {
        false
    }

    /// Create an incremental checkpoint (only changed state since last checkpoint).
    ///
    /// The default implementation ignores `_base` and returns a full
    /// checkpoint from [`CheckpointableKernel::create_checkpoint`].
    fn create_incremental_checkpoint(&self, _base: &Checkpoint) -> Result<Checkpoint> {
        // Default: no delta computation, delegate to the full checkpoint path.
        self.create_checkpoint()
    }
}
789
790// ============================================================================
791// Checkpoint Storage Trait
792// ============================================================================
793
/// Trait for checkpoint storage backends.
///
/// Checkpoints are addressed by a caller-chosen `name`. Implementations must
/// be `Send + Sync`, and every method takes `&self`, so a backend that
/// mutates state needs interior mutability.
pub trait CheckpointStorage: Send + Sync {
    /// Save a checkpoint with the given name.
    ///
    /// # Errors
    ///
    /// Backend-defined; returned when the checkpoint cannot be stored.
    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()>;

    /// Load a checkpoint by name.
    ///
    /// # Errors
    ///
    /// Backend-defined; returned when the checkpoint cannot be read or
    /// decoded.
    fn load(&self, name: &str) -> Result<Checkpoint>;

    /// List the names of all available checkpoints.
    fn list(&self) -> Result<Vec<String>>;

    /// Delete a checkpoint.
    ///
    /// # Errors
    ///
    /// Backend-defined; returned when the checkpoint cannot be removed.
    fn delete(&self, name: &str) -> Result<()>;

    /// Check if a checkpoint exists.
    fn exists(&self, name: &str) -> bool;
}
811
812// ============================================================================
813// File Storage Backend
814// ============================================================================
815
/// Checkpoint storage backed by files in a single directory.
///
/// A checkpoint called `name` lives at `<base_path>/<name>.rkcp`.
pub struct FileStorage {
    /// Directory that holds the checkpoint files.
    base_path: PathBuf,
}

impl FileStorage {
    /// Create a file storage backend rooted at `base_path`.
    ///
    /// Nothing is touched on disk here.
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        let base_path = PathBuf::from(base_path.as_ref());
        Self { base_path }
    }

    /// Build the on-disk path for the checkpoint called `name`.
    ///
    /// `name` is used verbatim as the file stem; it is not sanitized.
    fn checkpoint_path(&self, name: &str) -> PathBuf {
        const SUFFIX: &str = ".rkcp";

        let mut file_name = String::with_capacity(name.len() + SUFFIX.len());
        file_name.push_str(name);
        file_name.push_str(SUFFIX);
        self.base_path.join(file_name)
    }
}
835
836impl CheckpointStorage for FileStorage {
837    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()> {
838        // Ensure directory exists
839        std::fs::create_dir_all(&self.base_path).map_err(|e| {
840            RingKernelError::IoError(format!("Failed to create checkpoint directory: {}", e))
841        })?;
842
843        let path = self.checkpoint_path(name);
844        let bytes = checkpoint.to_bytes();
845
846        let mut file = std::fs::File::create(&path).map_err(|e| {
847            RingKernelError::IoError(format!("Failed to create checkpoint file: {}", e))
848        })?;
849
850        file.write_all(&bytes)
851            .map_err(|e| RingKernelError::IoError(format!("Failed to write checkpoint: {}", e)))?;
852
853        Ok(())
854    }
855
856    fn load(&self, name: &str) -> Result<Checkpoint> {
857        let path = self.checkpoint_path(name);
858
859        let mut file = std::fs::File::open(&path).map_err(|e| {
860            RingKernelError::IoError(format!("Failed to open checkpoint file: {}", e))
861        })?;
862
863        let mut bytes = Vec::new();
864        file.read_to_end(&mut bytes)
865            .map_err(|e| RingKernelError::IoError(format!("Failed to read checkpoint: {}", e)))?;
866
867        Checkpoint::from_bytes(&bytes)
868    }
869
870    fn list(&self) -> Result<Vec<String>> {
871        let entries = std::fs::read_dir(&self.base_path).map_err(|e| {
872            RingKernelError::IoError(format!("Failed to read checkpoint directory: {}", e))
873        })?;
874
875        let mut names = Vec::new();
876        for entry in entries.flatten() {
877            let path = entry.path();
878            if path.extension().map(|e| e == "rkcp").unwrap_or(false) {
879                if let Some(stem) = path.file_stem() {
880                    names.push(stem.to_string_lossy().to_string());
881                }
882            }
883        }
884
885        names.sort();
886        Ok(names)
887    }
888
889    fn delete(&self, name: &str) -> Result<()> {
890        let path = self.checkpoint_path(name);
891        std::fs::remove_file(&path)
892            .map_err(|e| RingKernelError::IoError(format!("Failed to delete checkpoint: {}", e)))?;
893        Ok(())
894    }
895
896    fn exists(&self, name: &str) -> bool {
897        self.checkpoint_path(name).exists()
898    }
899}
900
901// ============================================================================
902// Memory Storage Backend
903// ============================================================================
904
/// Checkpoint storage that keeps every checkpoint in process memory.
///
/// Intended for tests and for callers that want fast save/load without
/// touching the filesystem.
#[derive(Default)]
pub struct MemoryStorage {
    /// Serialized checkpoint bytes keyed by checkpoint name.
    checkpoints: std::sync::RwLock<HashMap<String, Vec<u8>>>,
}

impl MemoryStorage {
    /// Create an empty in-memory store.
    pub fn new() -> Self {
        Self::default()
    }
}
925
926impl CheckpointStorage for MemoryStorage {
927    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()> {
928        let bytes = checkpoint.to_bytes();
929        let mut checkpoints = self
930            .checkpoints
931            .write()
932            .map_err(|_| RingKernelError::IoError("Failed to acquire write lock".to_string()))?;
933        checkpoints.insert(name.to_string(), bytes);
934        Ok(())
935    }
936
937    fn load(&self, name: &str) -> Result<Checkpoint> {
938        let checkpoints = self
939            .checkpoints
940            .read()
941            .map_err(|_| RingKernelError::IoError("Failed to acquire read lock".to_string()))?;
942
943        let bytes = checkpoints
944            .get(name)
945            .ok_or_else(|| RingKernelError::IoError(format!("Checkpoint not found: {}", name)))?;
946
947        Checkpoint::from_bytes(bytes)
948    }
949
950    fn list(&self) -> Result<Vec<String>> {
951        let checkpoints = self
952            .checkpoints
953            .read()
954            .map_err(|_| RingKernelError::IoError("Failed to acquire read lock".to_string()))?;
955
956        let mut names: Vec<_> = checkpoints.keys().cloned().collect();
957        names.sort();
958        Ok(names)
959    }
960
961    fn delete(&self, name: &str) -> Result<()> {
962        let mut checkpoints = self
963            .checkpoints
964            .write()
965            .map_err(|_| RingKernelError::IoError("Failed to acquire write lock".to_string()))?;
966
967        checkpoints
968            .remove(name)
969            .ok_or_else(|| RingKernelError::IoError(format!("Checkpoint not found: {}", name)))?;
970
971        Ok(())
972    }
973
974    fn exists(&self, name: &str) -> bool {
975        self.checkpoints
976            .read()
977            .map(|c| c.contains_key(name))
978            .unwrap_or(false)
979    }
980}
981
982// ============================================================================
983// Checkpoint Builder
984// ============================================================================
985
/// Builder for creating checkpoints incrementally.
///
/// Metadata fields are set through the chained setter methods and data chunks
/// are appended in call order; `build` turns the accumulated state into a
/// `Checkpoint`.
pub struct CheckpointBuilder {
    /// Metadata that is handed to the checkpoint when `build` is called.
    metadata: CheckpointMetadata,
    /// Data chunks collected so far, in the order they were added.
    chunks: Vec<DataChunk>,
}
991
992impl CheckpointBuilder {
993    /// Create a new checkpoint builder.
994    pub fn new(kernel_id: impl Into<String>, kernel_type: impl Into<String>) -> Self {
995        Self {
996            metadata: CheckpointMetadata::new(kernel_id, kernel_type),
997            chunks: Vec::new(),
998        }
999    }
1000
1001    /// Set the current step.
1002    pub fn step(mut self, step: u64) -> Self {
1003        self.metadata.current_step = step;
1004        self
1005    }
1006
1007    /// Set grid size.
1008    pub fn grid_size(mut self, width: u32, height: u32, depth: u32) -> Self {
1009        self.metadata.grid_size = (width, height, depth);
1010        self
1011    }
1012
1013    /// Set tile size.
1014    pub fn tile_size(mut self, x: u32, y: u32, z: u32) -> Self {
1015        self.metadata.tile_size = (x, y, z);
1016        self
1017    }
1018
1019    /// Set HLC timestamp.
1020    pub fn hlc(mut self, hlc: HlcTimestamp) -> Self {
1021        self.metadata.hlc_timestamp = hlc;
1022        self
1023    }
1024
1025    /// Add custom metadata.
1026    pub fn custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1027        self.metadata.custom.insert(key.into(), value.into());
1028        self
1029    }
1030
1031    /// Add control block data.
1032    pub fn control_block(mut self, data: Vec<u8>) -> Self {
1033        self.chunks
1034            .push(DataChunk::new(ChunkType::ControlBlock, data));
1035        self
1036    }
1037
1038    /// Add H2K queue data.
1039    pub fn h2k_queue(mut self, data: Vec<u8>) -> Self {
1040        self.chunks.push(DataChunk::new(ChunkType::H2KQueue, data));
1041        self
1042    }
1043
1044    /// Add K2H queue data.
1045    pub fn k2h_queue(mut self, data: Vec<u8>) -> Self {
1046        self.chunks.push(DataChunk::new(ChunkType::K2HQueue, data));
1047        self
1048    }
1049
1050    /// Add device memory region.
1051    pub fn device_memory(mut self, name: &str, data: Vec<u8>) -> Self {
1052        use std::collections::hash_map::DefaultHasher;
1053        use std::hash::{Hash, Hasher};
1054        let mut hasher = DefaultHasher::new();
1055        name.hash(&mut hasher);
1056        let id = hasher.finish();
1057
1058        self.chunks
1059            .push(DataChunk::with_id(ChunkType::DeviceMemory, data, id));
1060        self
1061    }
1062
1063    /// Add a custom chunk.
1064    pub fn chunk(mut self, chunk: DataChunk) -> Self {
1065        self.chunks.push(chunk);
1066        self
1067    }
1068
1069    /// Build the checkpoint.
1070    pub fn build(self) -> Checkpoint {
1071        let mut checkpoint = Checkpoint::new(self.metadata);
1072        checkpoint.chunks = self.chunks;
1073        checkpoint
1074    }
1075}
1076
1077// ============================================================================
1078// Tests
1079// ============================================================================
1080
#[cfg(test)]
mod tests {
    use super::*;

    /// A checkpoint header keeps its fields across `to_bytes` / `from_bytes`.
    #[test]
    fn test_checkpoint_header_roundtrip() {
        let original = CheckpointHeader::new(5, 1024);
        let encoded = original.to_bytes();
        let decoded = CheckpointHeader::from_bytes(&encoded);

        assert_eq!(decoded.magic, CHECKPOINT_MAGIC);
        assert_eq!(decoded.version, CHECKPOINT_VERSION);
        assert_eq!(decoded.chunk_count, 5);
        assert_eq!(decoded.total_size, 1024);
    }

    /// A chunk header keeps type, size and ID across `to_bytes` / `from_bytes`.
    #[test]
    fn test_chunk_header_roundtrip() {
        let original = ChunkHeader::new(ChunkType::DeviceMemory, 4096).with_id(12345);
        let encoded = original.to_bytes();
        let decoded = ChunkHeader::from_bytes(&encoded);

        assert_eq!(decoded.chunk_type, ChunkType::DeviceMemory as u32);
        assert_eq!(decoded.uncompressed_size, 4096);
        assert_eq!(decoded.chunk_id, 12345);
    }

    /// Metadata, including a custom entry, survives serialization.
    #[test]
    fn test_metadata_roundtrip() {
        let original = CheckpointMetadata::new("kernel_1", "fdtd_3d")
            .with_step(1000)
            .with_grid_size(64, 64, 64)
            .with_tile_size(8, 8, 8)
            .with_custom("version", "1.0");

        let encoded = original.to_bytes();
        let decoded = CheckpointMetadata::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.kernel_id, "kernel_1");
        assert_eq!(decoded.kernel_type, "fdtd_3d");
        assert_eq!(decoded.current_step, 1000);
        assert_eq!(decoded.grid_size, (64, 64, 64));
        assert_eq!(decoded.tile_size, (8, 8, 8));
        assert_eq!(
            decoded.custom.get("version").map(String::as_str),
            Some("1.0")
        );
    }

    /// A built checkpoint with two chunks survives serialization.
    #[test]
    fn test_checkpoint_roundtrip() {
        let original = CheckpointBuilder::new("test_kernel", "test_type")
            .step(500)
            .grid_size(32, 32, 32)
            .control_block(vec![1, 2, 3, 4])
            .device_memory("pressure_a", vec![5, 6, 7, 8, 9, 10])
            .build();

        let encoded = original.to_bytes();
        let decoded = Checkpoint::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.metadata.kernel_id, "test_kernel");
        assert_eq!(decoded.metadata.current_step, 500);
        assert_eq!(decoded.chunks.len(), 2);

        let control = decoded.get_chunk(ChunkType::ControlBlock).unwrap();
        assert_eq!(control.data, vec![1, 2, 3, 4]);
    }

    /// Save, exists, load, list and delete on the in-memory backend.
    #[test]
    fn test_memory_storage() {
        let storage = MemoryStorage::new();
        let checkpoint = CheckpointBuilder::new("mem_test", "test").step(100).build();

        storage.save(&checkpoint, "test_001").unwrap();
        assert!(storage.exists("test_001"));

        let loaded = storage.load("test_001").unwrap();
        assert_eq!(loaded.metadata.kernel_id, "mem_test");
        assert_eq!(loaded.metadata.current_step, 100);

        assert_eq!(storage.list().unwrap(), vec!["test_001".to_string()]);

        storage.delete("test_001").unwrap();
        assert!(!storage.exists("test_001"));
    }

    /// `crc32_simple` against known CRC32 values.
    #[test]
    fn test_crc32() {
        // Empty input.
        assert_eq!(crc32_simple(b""), 0);
        // Standard CRC32 check string.
        assert_eq!(crc32_simple(b"123456789"), 0xCBF43926);
    }

    /// A header whose magic field is zero must fail validation.
    #[test]
    fn test_checkpoint_validation() {
        let mut raw = [0u8; 64];
        // Wrong magic
        raw[..8].copy_from_slice(&0u64.to_le_bytes());

        let header = CheckpointHeader::from_bytes(&raw);
        assert!(header.validate().is_err());
    }

    /// Two 100 000-byte device-memory chunks survive serialization.
    #[test]
    fn test_large_checkpoint() {
        // Repeating 0, 1, ..., 255 pattern.
        let large_data: Vec<u8> = (0u8..=255).cycle().take(100_000).collect();

        let original = CheckpointBuilder::new("large_kernel", "stress_test")
            .step(999)
            .device_memory("field_a", large_data.clone())
            .device_memory("field_b", large_data)
            .build();

        let encoded = original.to_bytes();
        let decoded = Checkpoint::from_bytes(&encoded).unwrap();

        assert_eq!(decoded.chunks.len(), 2);
        let memory_chunks = decoded.get_chunks(ChunkType::DeviceMemory);
        assert_eq!(memory_chunks.len(), 2);
        assert_eq!(memory_chunks[0].data.len(), 100_000);
    }
}