use std::collections::HashMap;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::error::{Result, RingKernelError};
use crate::hlc::HlcTimestamp;
/// Checkpoint file magic: the ASCII bytes `RKCKPT01` read as a big-endian `u64`
/// (`0x524B434B50543031`). Headers serialize it little-endian.
pub const CHECKPOINT_MAGIC: u64 = u64::from_be_bytes(*b"RKCKPT01");
/// Newest checkpoint format version this code writes and accepts.
pub const CHECKPOINT_VERSION: u32 = 1;
/// Largest `total_size` accepted by header validation: 1 GiB.
pub const MAX_CHECKPOINT_SIZE: usize = 1 << 30;
/// Kind of payload stored in a checkpoint data chunk.
///
/// The explicit discriminant is the raw `chunk_type` value written into chunk headers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum ChunkType {
    ControlBlock = 1,
    H2KQueue = 2,
    K2HQueue = 3,
    HlcState = 4,
    DeviceMemory = 5,
    K2KRouting = 6,
    HaloBuffers = 7,
    Telemetry = 8,
    Custom = 100,
}

impl ChunkType {
    /// Maps a raw discriminant back to its `ChunkType`, or `None` if it is not recognised.
    pub fn from_u32(value: u32) -> Option<Self> {
        // Every variant; the lookup compares each variant's discriminant with `value`.
        const KNOWN: [ChunkType; 9] = [
            ChunkType::ControlBlock,
            ChunkType::H2KQueue,
            ChunkType::K2HQueue,
            ChunkType::HlcState,
            ChunkType::DeviceMemory,
            ChunkType::K2KRouting,
            ChunkType::HaloBuffers,
            ChunkType::Telemetry,
            ChunkType::Custom,
        ];
        KNOWN.iter().copied().find(|&kind| kind as u32 == value)
    }
}
/// Fixed-size header at the start of every serialized checkpoint.
///
/// The struct is `#[repr(C)]`, but the on-disk form is written field by field in
/// little-endian order by [`CheckpointHeader::to_bytes`] and is always
/// [`CheckpointHeader::SERIALIZED_LEN`] (64) bytes long.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct CheckpointHeader {
    /// File magic; must equal [`CHECKPOINT_MAGIC`] (bytes 0..8).
    pub magic: u64,
    /// Format version; [`CheckpointHeader::validate`] rejects versions above [`CHECKPOINT_VERSION`].
    pub version: u32,
    /// Length in bytes of the serialized header.
    pub header_size: u32,
    /// Total serialized checkpoint size recorded by the writer.
    pub total_size: u64,
    /// Number of data chunks that follow the metadata section.
    pub chunk_count: u32,
    /// Compression identifier; [`CheckpointHeader::new`] always sets 0.
    pub compression: u32,
    /// Checksum slot (bytes 32..36 of the serialized header).
    pub checksum: u32,
    /// Flag bits; [`CheckpointHeader::new`] always sets 0.
    pub flags: u32,
    /// Creation time in microseconds since the Unix epoch.
    pub created_at: u64,
    /// Reserved bytes (48..56 of the serialized header).
    pub _reserved: [u8; 8],
}

impl CheckpointHeader {
    /// Number of bytes produced by [`to_bytes`](Self::to_bytes) and consumed by
    /// [`from_bytes`](Self::from_bytes).
    ///
    /// This is larger than `size_of::<CheckpointHeader>()` (56), which describes the
    /// in-memory `repr(C)` layout rather than the wire format.
    pub const SERIALIZED_LEN: usize = 64;

    /// Creates a header for `chunk_count` chunks and a checkpoint of `total_size` bytes.
    ///
    /// The header is stamped with the current wall-clock time. The stamp is 0 if the clock
    /// reads before the Unix epoch. `checksum`, `compression` and `flags` start at 0.
    pub fn new(chunk_count: u32, total_size: u64) -> Self {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or(Duration::ZERO);
        Self {
            magic: CHECKPOINT_MAGIC,
            version: CHECKPOINT_VERSION,
            // Record the serialized header length, not the (smaller) in-memory struct size.
            header_size: Self::SERIALIZED_LEN as u32,
            total_size,
            chunk_count,
            compression: 0,
            checksum: 0,
            flags: 0,
            created_at: now.as_micros() as u64,
            _reserved: [0; 8],
        }
    }

    /// Checks the magic number, the version and the size limit.
    ///
    /// # Errors
    /// Returns [`RingKernelError::InvalidCheckpoint`] in any of these cases:
    /// - the magic does not match [`CHECKPOINT_MAGIC`];
    /// - the version is newer than [`CHECKPOINT_VERSION`];
    /// - `total_size` exceeds [`MAX_CHECKPOINT_SIZE`].
    pub fn validate(&self) -> Result<()> {
        if self.magic != CHECKPOINT_MAGIC {
            return Err(RingKernelError::InvalidCheckpoint(
                "Invalid magic number".to_string(),
            ));
        }
        if self.version > CHECKPOINT_VERSION {
            return Err(RingKernelError::InvalidCheckpoint(format!(
                "Unsupported version: {} (max: {})",
                self.version, CHECKPOINT_VERSION
            )));
        }
        // Compare in u64: casting `total_size` down to `usize` would truncate on 32-bit
        // targets and let oversized values slip past the limit.
        if self.total_size > MAX_CHECKPOINT_SIZE as u64 {
            return Err(RingKernelError::InvalidCheckpoint(format!(
                "Checkpoint too large: {} bytes (max: {})",
                self.total_size, MAX_CHECKPOINT_SIZE
            )));
        }
        Ok(())
    }

    /// Encodes the header as 64 little-endian bytes.
    ///
    /// Bytes 56..64 are always zero.
    pub fn to_bytes(&self) -> [u8; 64] {
        let mut bytes = [0u8; 64];
        bytes[0..8].copy_from_slice(&self.magic.to_le_bytes());
        bytes[8..12].copy_from_slice(&self.version.to_le_bytes());
        bytes[12..16].copy_from_slice(&self.header_size.to_le_bytes());
        bytes[16..24].copy_from_slice(&self.total_size.to_le_bytes());
        bytes[24..28].copy_from_slice(&self.chunk_count.to_le_bytes());
        bytes[28..32].copy_from_slice(&self.compression.to_le_bytes());
        bytes[32..36].copy_from_slice(&self.checksum.to_le_bytes());
        bytes[36..40].copy_from_slice(&self.flags.to_le_bytes());
        bytes[40..48].copy_from_slice(&self.created_at.to_le_bytes());
        // Written so that `from_bytes(&h.to_bytes())` preserves `_reserved`.
        bytes[48..56].copy_from_slice(&self._reserved);
        bytes
    }

    /// Decodes a header previously produced by [`to_bytes`](Self::to_bytes).
    ///
    /// No validation is performed here; call [`validate`](Self::validate) afterwards.
    pub fn from_bytes(bytes: &[u8; 64]) -> Self {
        let u32_at = |at: usize| {
            u32::from_le_bytes(
                bytes[at..at + 4]
                    .try_into()
                    .expect("slice is exactly 4 bytes"),
            )
        };
        let u64_at = |at: usize| {
            u64::from_le_bytes(
                bytes[at..at + 8]
                    .try_into()
                    .expect("slice is exactly 8 bytes"),
            )
        };
        Self {
            magic: u64_at(0),
            version: u32_at(8),
            header_size: u32_at(12),
            total_size: u64_at(16),
            chunk_count: u32_at(24),
            compression: u32_at(28),
            checksum: u32_at(32),
            flags: u32_at(36),
            created_at: u64_at(40),
            _reserved: bytes[48..56]
                .try_into()
                .expect("slice is exactly 8 bytes"),
        }
    }
}
/// Per-chunk header; serialized as 32 little-endian bytes immediately before the chunk data.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct ChunkHeader {
    /// Raw [`ChunkType`] discriminant.
    pub chunk_type: u32,
    /// Flag bits (0 when created via [`ChunkHeader::new`]).
    pub flags: u32,
    /// Payload length before compression.
    pub uncompressed_size: u64,
    /// Payload length as stored; this many data bytes follow the header.
    pub compressed_size: u64,
    /// Identifier distinguishing chunks of the same type (0 by default).
    pub chunk_id: u64,
}

impl ChunkHeader {
    /// Creates a header for an uncompressed payload of `data_size` bytes with id 0.
    pub fn new(chunk_type: ChunkType, data_size: usize) -> Self {
        let size = data_size as u64;
        Self {
            chunk_type: chunk_type as u32,
            flags: 0,
            uncompressed_size: size,
            compressed_size: size,
            chunk_id: 0,
        }
    }

    /// Returns a copy of this header carrying `id` as its chunk id.
    pub fn with_id(self, id: u64) -> Self {
        Self { chunk_id: id, ..self }
    }

    /// Encodes the header as 32 little-endian bytes in field order.
    pub fn to_bytes(&self) -> [u8; 32] {
        let mut out = [0u8; 32];
        let mut at = 0;
        for field in [
            &self.chunk_type.to_le_bytes()[..],
            &self.flags.to_le_bytes()[..],
            &self.uncompressed_size.to_le_bytes()[..],
            &self.compressed_size.to_le_bytes()[..],
            &self.chunk_id.to_le_bytes()[..],
        ] {
            out[at..at + field.len()].copy_from_slice(field);
            at += field.len();
        }
        out
    }

    /// Decodes a header produced by [`to_bytes`](Self::to_bytes).
    pub fn from_bytes(bytes: &[u8; 32]) -> Self {
        let word = |at: usize| {
            u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]])
        };
        let dword = |at: usize| {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(&bytes[at..at + 8]);
            u64::from_le_bytes(raw)
        };
        Self {
            chunk_type: word(0),
            flags: word(4),
            uncompressed_size: dword(8),
            compressed_size: dword(16),
            chunk_id: dword(24),
        }
    }
}
/// Kernel-level description stored alongside the data chunks of a checkpoint.
#[derive(Debug, Clone, Default)]
pub struct CheckpointMetadata {
    /// Identifier of the checkpointed kernel.
    pub kernel_id: String,
    /// Kernel type name.
    pub kernel_type: String,
    /// Step counter at the time of the checkpoint.
    pub current_step: u64,
    /// Grid dimensions `(width, height, depth)`.
    pub grid_size: (u32, u32, u32),
    /// Tile dimensions `(x, y, z)`.
    pub tile_size: (u32, u32, u32),
    /// HLC timestamp recorded with the checkpoint.
    pub hlc_timestamp: HlcTimestamp,
    /// Free-form string key/value pairs.
    pub custom: HashMap<String, String>,
}

impl CheckpointMetadata {
    /// Creates metadata for the given kernel, with every other field defaulted.
    pub fn new(kernel_id: impl Into<String>, kernel_type: impl Into<String>) -> Self {
        Self {
            kernel_id: kernel_id.into(),
            kernel_type: kernel_type.into(),
            ..Default::default()
        }
    }

    /// Sets `current_step`.
    pub fn with_step(mut self, step: u64) -> Self {
        self.current_step = step;
        self
    }

    /// Sets `grid_size`.
    pub fn with_grid_size(mut self, width: u32, height: u32, depth: u32) -> Self {
        self.grid_size = (width, height, depth);
        self
    }

    /// Sets `tile_size`.
    pub fn with_tile_size(mut self, x: u32, y: u32, z: u32) -> Self {
        self.tile_size = (x, y, z);
        self
    }

    /// Sets `hlc_timestamp`.
    pub fn with_hlc(mut self, hlc: HlcTimestamp) -> Self {
        self.hlc_timestamp = hlc;
        self
    }

    /// Inserts (or overwrites) a custom key/value pair.
    pub fn with_custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.custom.insert(key.into(), value.into());
        self
    }

    /// Serializes the metadata into the layout read by [`from_bytes`](Self::from_bytes).
    ///
    /// All integers are little-endian. The fields are written in this order:
    /// 1. `kernel_id` and `kernel_type` as `u32`-length-prefixed UTF-8;
    /// 2. `current_step`;
    /// 3. the six grid/tile dimensions as `u32`;
    /// 4. the HLC `physical`, `logical` and `node_id` fields;
    /// 5. a `u32` entry count;
    /// 6. length-prefixed key/value pairs.
    ///
    /// Custom entries are written sorted by key, so equal metadata always produces equal
    /// bytes (and therefore equal checkpoint checksums). `HashMap` iteration order is
    /// unspecified and would otherwise vary between runs.
    pub fn to_bytes(&self) -> Vec<u8> {
        /// Appends `s` as a `u32` little-endian byte length followed by its UTF-8 bytes.
        fn put_str(out: &mut Vec<u8>, s: &str) {
            out.extend_from_slice(&(s.len() as u32).to_le_bytes());
            out.extend_from_slice(s.as_bytes());
        }

        let mut bytes = Vec::new();
        put_str(&mut bytes, &self.kernel_id);
        put_str(&mut bytes, &self.kernel_type);
        bytes.extend_from_slice(&self.current_step.to_le_bytes());
        for dim in [
            self.grid_size.0,
            self.grid_size.1,
            self.grid_size.2,
            self.tile_size.0,
            self.tile_size.1,
            self.tile_size.2,
        ] {
            bytes.extend_from_slice(&dim.to_le_bytes());
        }
        bytes.extend_from_slice(&self.hlc_timestamp.physical.to_le_bytes());
        bytes.extend_from_slice(&self.hlc_timestamp.logical.to_le_bytes());
        bytes.extend_from_slice(&self.hlc_timestamp.node_id.to_le_bytes());

        let mut entries: Vec<(&String, &String)> = self.custom.iter().collect();
        // Keys are unique, so an unstable sort yields a fully deterministic order.
        entries.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        bytes.extend_from_slice(&(entries.len() as u32).to_le_bytes());
        for (key, value) in entries {
            put_str(&mut bytes, key);
            put_str(&mut bytes, value);
        }
        bytes
    }

    /// Parses metadata produced by [`to_bytes`](Self::to_bytes). Trailing bytes are ignored.
    ///
    /// # Errors
    /// Returns [`RingKernelError::InvalidCheckpoint`] if the input ends early, a length prefix
    /// runs past the buffer, or a string is not valid UTF-8.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        /// Returns the next `len` bytes at `*off` and advances the cursor.
        ///
        /// `len` comes from untrusted input, so the end offset is computed with `checked_add`
        /// to fail cleanly instead of overflowing.
        fn take<'a>(bytes: &'a [u8], off: &mut usize, len: usize) -> Result<&'a [u8]> {
            let end = off
                .checked_add(len)
                .filter(|&end| end <= bytes.len())
                .ok_or_else(|| {
                    RingKernelError::InvalidCheckpoint("Unexpected end of metadata".to_string())
                })?;
            let slice = &bytes[*off..end];
            *off = end;
            Ok(slice)
        }

        let mut offset = 0;
        let read_u32 = |off: &mut usize| -> Result<u32> {
            let raw = take(bytes, off, 4)?;
            Ok(u32::from_le_bytes(
                raw.try_into().expect("take returned exactly 4 bytes"),
            ))
        };
        let read_u64 = |off: &mut usize| -> Result<u64> {
            let raw = take(bytes, off, 8)?;
            Ok(u64::from_le_bytes(
                raw.try_into().expect("take returned exactly 8 bytes"),
            ))
        };
        let read_string = |off: &mut usize| -> Result<String> {
            let len = read_u32(off)? as usize;
            let raw = take(bytes, off, len)?;
            String::from_utf8(raw.to_vec())
                .map_err(|e| RingKernelError::InvalidCheckpoint(e.to_string()))
        };

        let kernel_id = read_string(&mut offset)?;
        let kernel_type = read_string(&mut offset)?;
        let current_step = read_u64(&mut offset)?;
        let grid_size = (
            read_u32(&mut offset)?,
            read_u32(&mut offset)?,
            read_u32(&mut offset)?,
        );
        let tile_size = (
            read_u32(&mut offset)?,
            read_u32(&mut offset)?,
            read_u32(&mut offset)?,
        );
        let hlc_timestamp = HlcTimestamp {
            physical: read_u64(&mut offset)?,
            logical: read_u64(&mut offset)?,
            node_id: read_u64(&mut offset)?,
        };
        // The count is untrusted, so no capacity is reserved from it. Each entry consumes at
        // least 8 bytes, so a bogus count fails quickly on end-of-input.
        let custom_count = read_u32(&mut offset)? as usize;
        let mut custom = HashMap::new();
        for _ in 0..custom_count {
            let key = read_string(&mut offset)?;
            let value = read_string(&mut offset)?;
            custom.insert(key, value);
        }
        Ok(Self {
            kernel_id,
            kernel_type,
            current_step,
            grid_size,
            tile_size,
            hlc_timestamp,
            custom,
        })
    }
}
/// A single typed payload inside a checkpoint, paired with its chunk header.
#[derive(Debug, Clone)]
pub struct DataChunk {
    /// Header describing the payload (type, sizes, id).
    pub header: ChunkHeader,
    /// Raw payload bytes.
    pub data: Vec<u8>,
}

impl DataChunk {
    /// Wraps `data` as a chunk of `chunk_type` with chunk id 0.
    pub fn new(chunk_type: ChunkType, data: Vec<u8>) -> Self {
        let header = ChunkHeader::new(chunk_type, data.len());
        Self { header, data }
    }

    /// Wraps `data` as a chunk of `chunk_type` carrying chunk id `id`.
    pub fn with_id(chunk_type: ChunkType, data: Vec<u8>, id: u64) -> Self {
        let mut chunk = Self::new(chunk_type, data);
        chunk.header = chunk.header.with_id(id);
        chunk
    }

    /// Decodes the header's raw type, or `None` if it is not a known [`ChunkType`].
    pub fn chunk_type(&self) -> Option<ChunkType> {
        let raw = self.header.chunk_type;
        ChunkType::from_u32(raw)
    }
}
/// An in-memory checkpoint: file header, kernel metadata and an ordered list of data chunks.
///
/// Serialized layout (integers little-endian):
/// `[header: 64 bytes][metadata_len: u32][metadata][chunk header: 32 bytes][chunk data]...`
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Header as constructed or as read by [`Checkpoint::from_bytes`].
    ///
    /// [`Checkpoint::to_bytes`] builds a fresh header instead of writing this one.
    pub header: CheckpointHeader,
    /// Kernel-level metadata.
    pub metadata: CheckpointMetadata,
    /// Data chunks in serialization order.
    pub chunks: Vec<DataChunk>,
}

impl Checkpoint {
    /// Serialized length of the file header (`CheckpointHeader::to_bytes` yields `[u8; 64]`).
    const HEADER_LEN: usize = 64;
    /// Serialized length of a chunk header (`ChunkHeader::to_bytes` yields `[u8; 32]`).
    const CHUNK_HEADER_LEN: usize = 32;
    /// Width of the `u32` metadata length prefix.
    const METADATA_LEN_PREFIX: usize = 4;

    /// Creates a checkpoint with the given metadata and no chunks.
    pub fn new(metadata: CheckpointMetadata) -> Self {
        Self {
            header: CheckpointHeader::new(0, 0),
            metadata,
            chunks: Vec::new(),
        }
    }

    /// Appends `chunk`.
    pub fn add_chunk(&mut self, chunk: DataChunk) {
        self.chunks.push(chunk);
    }

    /// Appends a `ControlBlock` chunk.
    pub fn add_control_block(&mut self, data: Vec<u8>) {
        self.add_chunk(DataChunk::new(ChunkType::ControlBlock, data));
    }

    /// Appends an `H2KQueue` chunk.
    pub fn add_h2k_queue(&mut self, data: Vec<u8>) {
        self.add_chunk(DataChunk::new(ChunkType::H2KQueue, data));
    }

    /// Appends a `K2HQueue` chunk.
    pub fn add_k2h_queue(&mut self, data: Vec<u8>) {
        self.add_chunk(DataChunk::new(ChunkType::K2HQueue, data));
    }

    /// Appends an `HlcState` chunk.
    pub fn add_hlc_state(&mut self, data: Vec<u8>) {
        self.add_chunk(DataChunk::new(ChunkType::HlcState, data));
    }

    /// Appends a `DeviceMemory` chunk whose chunk id is a hash of `name`.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is documented as unspecified and may change
    /// between Rust releases. Ids persisted in older checkpoints might therefore not match
    /// ids recomputed by a newer toolchain. Kept as-is because `CheckpointBuilder::device_memory`
    /// derives ids the same way, and the two must agree.
    pub fn add_device_memory(&mut self, name: &str, data: Vec<u8>) {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        name.hash(&mut hasher);
        let id = hasher.finish();
        self.add_chunk(DataChunk::with_id(ChunkType::DeviceMemory, data, id));
    }

    /// Returns the first chunk of `chunk_type`, if any.
    pub fn get_chunk(&self, chunk_type: ChunkType) -> Option<&DataChunk> {
        self.chunks
            .iter()
            .find(|c| c.chunk_type() == Some(chunk_type))
    }

    /// Returns every chunk of `chunk_type`, in order.
    pub fn get_chunks(&self, chunk_type: ChunkType) -> Vec<&DataChunk> {
        self.chunks
            .iter()
            .filter(|c| c.chunk_type() == Some(chunk_type))
            .collect()
    }

    /// Returns the number of bytes [`to_bytes`](Self::to_bytes) produces for this checkpoint.
    pub fn total_size(&self) -> usize {
        Self::serialized_size(self.metadata.to_bytes().len(), &self.chunks)
    }

    /// Serialized size for a metadata section of `metadata_len` bytes followed by `chunks`.
    ///
    /// Uses the 64-byte wire header length; `size_of::<CheckpointHeader>()` is only 56 and
    /// would under-report the size by 8 bytes.
    fn serialized_size(metadata_len: usize, chunks: &[DataChunk]) -> usize {
        let chunk_bytes: usize = chunks
            .iter()
            .map(|c| Self::CHUNK_HEADER_LEN + c.data.len())
            .sum();
        Self::HEADER_LEN + Self::METADATA_LEN_PREFIX + metadata_len + chunk_bytes
    }

    /// Serializes the checkpoint.
    ///
    /// A new header is built from the chunk count, the total size and the current time, so
    /// `self.header` is not written. Its checksum field holds the CRC-32 of every byte after
    /// the header.
    pub fn to_bytes(&self) -> Vec<u8> {
        // Serialize the metadata once and reuse it for both the size and the body.
        let metadata_bytes = self.metadata.to_bytes();
        let total_size = Self::serialized_size(metadata_bytes.len(), &self.chunks);
        let header = CheckpointHeader::new(self.chunks.len() as u32, total_size as u64);

        let mut bytes = Vec::with_capacity(total_size);
        bytes.extend_from_slice(&header.to_bytes());
        bytes.extend_from_slice(&(metadata_bytes.len() as u32).to_le_bytes());
        bytes.extend_from_slice(&metadata_bytes);
        for chunk in &self.chunks {
            bytes.extend_from_slice(&chunk.header.to_bytes());
            bytes.extend_from_slice(&chunk.data);
        }
        debug_assert_eq!(bytes.len(), total_size);

        // Patch the checksum into the header's checksum field (bytes 32..36).
        let checksum = crc32_simple(&bytes[Self::HEADER_LEN..]);
        bytes[32..36].copy_from_slice(&checksum.to_le_bytes());
        bytes
    }

    /// Parses a checkpoint produced by [`to_bytes`](Self::to_bytes).
    ///
    /// # Errors
    /// Returns [`RingKernelError::InvalidCheckpoint`] in any of these cases:
    /// - the input is shorter than a header;
    /// - the header fails `validate`;
    /// - the checksum does not match;
    /// - any section is truncated, including chunk lengths that overflow or exceed the
    ///   buffer;
    /// - the metadata is malformed.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self> {
        /// Returns `bytes[start..start + len]`, or `None` if that range overflows or runs past
        /// the end of the buffer (lengths here come from untrusted input).
        fn slice_at(bytes: &[u8], start: usize, len: usize) -> Option<&[u8]> {
            bytes.get(start..start.checked_add(len)?)
        }
        let truncated = |what: &str| RingKernelError::InvalidCheckpoint(what.to_string());

        if bytes.len() < Self::HEADER_LEN {
            return Err(RingKernelError::InvalidCheckpoint(
                "Checkpoint too small".to_string(),
            ));
        }
        let header = CheckpointHeader::from_bytes(
            bytes[..Self::HEADER_LEN]
                .try_into()
                .expect("input validated to be >= 64 bytes"),
        );
        header.validate()?;
        let expected_checksum = crc32_simple(&bytes[Self::HEADER_LEN..]);
        if header.checksum != expected_checksum {
            return Err(RingKernelError::InvalidCheckpoint(format!(
                "Checksum mismatch: expected {}, got {}",
                expected_checksum, header.checksum
            )));
        }

        let mut offset = Self::HEADER_LEN;
        let len_bytes = slice_at(bytes, offset, Self::METADATA_LEN_PREFIX)
            .ok_or_else(|| truncated("Missing metadata length"))?;
        let metadata_len = u32::from_le_bytes(
            len_bytes
                .try_into()
                .expect("slice_at returned exactly 4 bytes"),
        ) as usize;
        offset += Self::METADATA_LEN_PREFIX;
        let metadata_bytes =
            slice_at(bytes, offset, metadata_len).ok_or_else(|| truncated("Metadata truncated"))?;
        let metadata = CheckpointMetadata::from_bytes(metadata_bytes)?;
        offset += metadata_len;

        // No capacity is reserved from the untrusted `chunk_count`; truncation ends the loop.
        let mut chunks = Vec::new();
        for _ in 0..header.chunk_count {
            let raw_header = slice_at(bytes, offset, Self::CHUNK_HEADER_LEN)
                .ok_or_else(|| truncated("Chunk header truncated"))?;
            let chunk_header = ChunkHeader::from_bytes(
                raw_header
                    .try_into()
                    .expect("slice_at returned exactly 32 bytes"),
            );
            offset += Self::CHUNK_HEADER_LEN;
            // `try_from` rejects sizes that do not fit in `usize` (32-bit targets) instead of
            // silently truncating them and misaligning every following chunk.
            let data = usize::try_from(chunk_header.compressed_size)
                .ok()
                .and_then(|len| slice_at(bytes, offset, len))
                .ok_or_else(|| truncated("Chunk data truncated"))?;
            offset += data.len();
            chunks.push(DataChunk {
                header: chunk_header,
                data: data.to_vec(),
            });
        }
        Ok(Self {
            header,
            metadata,
            chunks,
        })
    }
}
/// Metadata `custom` key under which a delta checkpoint records the `content_digest` of the
/// base checkpoint it was computed against.
pub const DELTA_PARENT_DIGEST_KEY: &str = "ringkernel.delta.parent_digest";

impl DataChunk {
    /// Key used to match chunks across checkpoints when computing or applying deltas.
    ///
    /// Returns `None` when the chunk type is unknown. `DeviceMemory` and `Custom` chunks are
    /// keyed by their `header.chunk_id`. Every other kind is identified by its type alone
    /// (id component `None`).
    pub fn chunk_identity(&self) -> Option<(ChunkType, Option<u64>)> {
        let kind = self.chunk_type()?;
        let keyed_by_id = matches!(kind, ChunkType::DeviceMemory | ChunkType::Custom);
        Some((kind, keyed_by_id.then(|| self.header.chunk_id)))
    }
}
impl Checkpoint {
pub fn content_digest(&self) -> String {
let mut acc: u32 = 0xFFFF_FFFF;
for chunk in &self.chunks {
if let Some((kind, id)) = chunk.chunk_identity() {
let mut header = [0u8; 16];
header[0..4].copy_from_slice(&(kind as u32).to_le_bytes());
header[4..12].copy_from_slice(&id.unwrap_or(0).to_le_bytes());
acc = crc32_update(acc, &header);
acc = crc32_update(acc, &chunk.data);
}
}
format!("{:08x}", !acc)
}
pub fn delta_from(base: &Checkpoint, new: &Checkpoint) -> Checkpoint {
use std::collections::HashMap;
let mut base_index: HashMap<(ChunkType, Option<u64>), &DataChunk> = HashMap::new();
for chunk in &base.chunks {
if let Some(id) = chunk.chunk_identity() {
base_index.insert(id, chunk);
}
}
let mut delta = Checkpoint::new(new.metadata.clone());
delta.metadata = delta
.metadata
.with_custom(DELTA_PARENT_DIGEST_KEY, base.content_digest());
for chunk in &new.chunks {
let Some(identity) = chunk.chunk_identity() else {
continue;
};
match base_index.get(&identity) {
Some(old) if old.data == chunk.data => { }
_ => delta.chunks.push(chunk.clone()),
}
}
delta
}
pub fn applied_with_delta(base: &Checkpoint, delta: &Checkpoint) -> Result<Checkpoint> {
if let Some(recorded) = delta.metadata.custom.get(DELTA_PARENT_DIGEST_KEY) {
let actual = base.content_digest();
if recorded != &actual {
return Err(RingKernelError::InvalidCheckpoint(format!(
"delta parent digest mismatch: expected {recorded}, got {actual}"
)));
}
}
use std::collections::HashMap;
let mut out = Checkpoint::new(delta.metadata.clone());
let mut delta_index: HashMap<(ChunkType, Option<u64>), &DataChunk> = HashMap::new();
for chunk in &delta.chunks {
if let Some(id) = chunk.chunk_identity() {
delta_index.insert(id, chunk);
}
}
let mut replaced: std::collections::HashSet<(ChunkType, Option<u64>)> =
std::collections::HashSet::new();
for chunk in &base.chunks {
match chunk.chunk_identity() {
Some(id) if delta_index.contains_key(&id) => {
out.chunks.push(delta_index[&id].clone());
replaced.insert(id);
}
_ => out.chunks.push(chunk.clone()),
}
}
for chunk in &delta.chunks {
if let Some(id) = chunk.chunk_identity() {
if !replaced.contains(&id) {
out.chunks.push(chunk.clone());
}
}
}
Ok(out)
}
}
/// Folds `data` into a running CRC-32 (IEEE, reflected polynomial `0xEDB88320`) register.
///
/// `crc` is the raw, non-inverted register. Start from `0xFFFF_FFFF` and invert the final
/// value (`!crc`) to obtain the standard checksum. Calls can be chained to digest data in
/// pieces. Table-driven: one lookup per byte instead of 8 shift/xor rounds, with
/// bit-for-bit identical results.
fn crc32_update(crc: u32, data: &[u8]) -> u32 {
    const CRC32_TABLE: [u32; 256] = crc32_table();
    data.iter().fold(crc, |state, &byte| {
        CRC32_TABLE[((state ^ u32::from(byte)) & 0xFF) as usize] ^ (state >> 8)
    })
}

/// Standard CRC-32 (IEEE) checksum of `data`, e.g. `crc32_simple(b"123456789") == 0xCBF43926`.
fn crc32_simple(data: &[u8]) -> u32 {
    !crc32_update(0xFFFF_FFFF, data)
}

/// Builds the 256-entry lookup table for the reflected CRC-32 polynomial `0xEDB88320`.
const fn crc32_table() -> [u32; 256] {
    let mut table = [0u32; 256];
    let mut i = 0;
    while i < 256 {
        let mut crc = i as u32;
        let mut j = 0;
        while j < 8 {
            if crc & 1 != 0 {
                crc = (crc >> 1) ^ 0xEDB88320;
            } else {
                crc >>= 1;
            }
            j += 1;
        }
        table[i] = crc;
        i += 1;
    }
    table
}
/// A kernel whose state can be captured into, and restored from, a [`Checkpoint`].
pub trait CheckpointableKernel {
    /// Identifier of this kernel for checkpointing purposes.
    fn checkpoint_kernel_id(&self) -> &str;

    /// Type name of this kernel for checkpointing purposes.
    fn checkpoint_kernel_type(&self) -> &str;

    /// Captures a full checkpoint of the kernel's current state.
    fn create_checkpoint(&self) -> Result<Checkpoint>;

    /// Restores the kernel's state from `checkpoint`.
    fn restore_from_checkpoint(&mut self, checkpoint: &Checkpoint) -> Result<()>;

    /// Whether this kernel implements true incremental checkpoints. Defaults to `false`.
    fn supports_incremental(&self) -> bool {
        false
    }

    /// Creates a checkpoint relative to `_base`.
    ///
    /// The default ignores `_base` and returns a full checkpoint from
    /// [`create_checkpoint`](Self::create_checkpoint).
    fn create_incremental_checkpoint(&self, _base: &Checkpoint) -> Result<Checkpoint> {
        self.create_checkpoint()
    }
}
/// Backend that persists checkpoints under string names.
///
/// Implementors must be `Send + Sync`.
pub trait CheckpointStorage: Send + Sync {
    /// Returns `true` if a checkpoint named `name` is stored.
    fn exists(&self, name: &str) -> bool;

    /// Returns the names of all stored checkpoints.
    fn list(&self) -> Result<Vec<String>>;

    /// Loads and decodes the checkpoint named `name`.
    fn load(&self, name: &str) -> Result<Checkpoint>;

    /// Stores `checkpoint` under `name`.
    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()>;

    /// Removes the checkpoint named `name`.
    fn delete(&self, name: &str) -> Result<()>;
}
/// Checkpoint storage that keeps one `<name>.rkcp` file per checkpoint in a directory.
pub struct FileStorage {
    /// Directory holding the checkpoint files.
    base_path: PathBuf,
}

impl FileStorage {
    /// Creates a storage rooted at `base_path`. No filesystem access happens here.
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        let base_path = PathBuf::from(base_path.as_ref());
        Self { base_path }
    }

    /// Path of the file backing the checkpoint called `name`.
    fn checkpoint_path(&self, name: &str) -> PathBuf {
        let file_name = format!("{name}.rkcp");
        self.base_path.join(file_name)
    }
}
impl CheckpointStorage for FileStorage {
    /// Writes `checkpoint` to `<base_path>/<name>.rkcp`, creating the directory if needed.
    ///
    /// The bytes are first written and synced to `<name>.rkcp.tmp`, then renamed over the
    /// final path. An interrupted save therefore never leaves a truncated `.rkcp` file behind,
    /// and never destroys an existing checkpoint of the same name. `list` ignores the
    /// temporary file because its extension is `tmp`.
    ///
    /// NOTE(review): `name` is joined onto `base_path` unchecked; callers passing untrusted
    /// names could escape the directory with `..` components.
    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()> {
        std::fs::create_dir_all(&self.base_path).map_err(|e| {
            RingKernelError::IoError(format!("Failed to create checkpoint directory: {}", e))
        })?;
        let path = self.checkpoint_path(name);
        let tmp_path = path.with_extension("rkcp.tmp");
        let bytes = checkpoint.to_bytes();

        let written = {
            let mut file = std::fs::File::create(&tmp_path).map_err(|e| {
                RingKernelError::IoError(format!("Failed to create checkpoint file: {}", e))
            })?;
            file.write_all(&bytes).and_then(|()| file.sync_all())
            // `file` is closed at the end of this block, before the rename below.
        };
        if let Err(e) = written {
            // Best-effort cleanup of the partial temp file; the write error is what matters.
            let _ = std::fs::remove_file(&tmp_path);
            return Err(RingKernelError::IoError(format!(
                "Failed to write checkpoint: {}",
                e
            )));
        }

        std::fs::rename(&tmp_path, &path).map_err(|e| {
            let _ = std::fs::remove_file(&tmp_path);
            RingKernelError::IoError(format!("Failed to finalize checkpoint: {}", e))
        })
    }

    /// Reads and decodes `<base_path>/<name>.rkcp`.
    fn load(&self, name: &str) -> Result<Checkpoint> {
        let path = self.checkpoint_path(name);
        let mut file = std::fs::File::open(&path).map_err(|e| {
            RingKernelError::IoError(format!("Failed to open checkpoint file: {}", e))
        })?;
        let mut bytes = Vec::new();
        file.read_to_end(&mut bytes)
            .map_err(|e| RingKernelError::IoError(format!("Failed to read checkpoint: {}", e)))?;
        Checkpoint::from_bytes(&bytes)
    }

    /// Returns the sorted stems of all `*.rkcp` files in the base directory.
    ///
    /// A base directory that does not exist yet holds no checkpoints, so it yields an empty
    /// list rather than an error (consistent with `exists`, which reports `false`). Entries
    /// that cannot be read are skipped.
    fn list(&self) -> Result<Vec<String>> {
        let entries = match std::fs::read_dir(&self.base_path) {
            Ok(entries) => entries,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
            Err(e) => {
                return Err(RingKernelError::IoError(format!(
                    "Failed to read checkpoint directory: {}",
                    e
                )))
            }
        };
        let mut names: Vec<String> = entries
            .flatten()
            .map(|entry| entry.path())
            .filter(|path| path.extension().map_or(false, |ext| ext == "rkcp"))
            .filter_map(|path| {
                path.file_stem()
                    .map(|stem| stem.to_string_lossy().into_owned())
            })
            .collect();
        names.sort();
        Ok(names)
    }

    /// Removes `<base_path>/<name>.rkcp`.
    fn delete(&self, name: &str) -> Result<()> {
        let path = self.checkpoint_path(name);
        std::fs::remove_file(&path)
            .map_err(|e| RingKernelError::IoError(format!("Failed to delete checkpoint: {}", e)))?;
        Ok(())
    }

    /// Returns `true` if `<base_path>/<name>.rkcp` exists.
    fn exists(&self, name: &str) -> bool {
        self.checkpoint_path(name).exists()
    }
}
/// In-process checkpoint storage that keeps serialized checkpoints in a map keyed by name.
pub struct MemoryStorage {
    /// Serialized checkpoint bytes keyed by checkpoint name, behind a reader/writer lock.
    checkpoints: std::sync::RwLock<HashMap<String, Vec<u8>>>,
}

impl MemoryStorage {
    /// Creates an empty storage.
    pub fn new() -> Self {
        Self {
            checkpoints: std::sync::RwLock::default(),
        }
    }
}

impl Default for MemoryStorage {
    /// Same as [`MemoryStorage::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl CheckpointStorage for MemoryStorage {
    /// Serializes `checkpoint` and stores the bytes under `name`, replacing any previous entry.
    fn save(&self, checkpoint: &Checkpoint, name: &str) -> Result<()> {
        let encoded = checkpoint.to_bytes();
        self.checkpoints
            .write()
            .map_err(|_| RingKernelError::IoError("Failed to acquire write lock".to_string()))?
            .insert(name.to_string(), encoded);
        Ok(())
    }

    /// Decodes the bytes stored under `name`. The read lock is held while decoding.
    fn load(&self, name: &str) -> Result<Checkpoint> {
        let guard = self
            .checkpoints
            .read()
            .map_err(|_| RingKernelError::IoError("Failed to acquire read lock".to_string()))?;
        match guard.get(name) {
            Some(encoded) => Checkpoint::from_bytes(encoded),
            None => Err(RingKernelError::IoError(format!(
                "Checkpoint not found: {}",
                name
            ))),
        }
    }

    /// Returns all stored names in ascending order.
    fn list(&self) -> Result<Vec<String>> {
        let guard = self
            .checkpoints
            .read()
            .map_err(|_| RingKernelError::IoError("Failed to acquire read lock".to_string()))?;
        let mut names = guard.keys().cloned().collect::<Vec<String>>();
        // Map keys are unique, so an unstable sort gives the same order as a stable one.
        names.sort_unstable();
        Ok(names)
    }

    /// Removes the entry for `name`; errors if there is none.
    fn delete(&self, name: &str) -> Result<()> {
        let mut guard = self
            .checkpoints
            .write()
            .map_err(|_| RingKernelError::IoError("Failed to acquire write lock".to_string()))?;
        if guard.remove(name).is_none() {
            return Err(RingKernelError::IoError(format!(
                "Checkpoint not found: {}",
                name
            )));
        }
        Ok(())
    }

    /// Returns `true` if `name` is stored; a poisoned lock reports `false`.
    fn exists(&self, name: &str) -> bool {
        match self.checkpoints.read() {
            Ok(map) => map.contains_key(name),
            Err(_) => false,
        }
    }
}
pub struct CheckpointBuilder {
metadata: CheckpointMetadata,
chunks: Vec<DataChunk>,
}
impl CheckpointBuilder {
pub fn new(kernel_id: impl Into<String>, kernel_type: impl Into<String>) -> Self {
Self {
metadata: CheckpointMetadata::new(kernel_id, kernel_type),
chunks: Vec::new(),
}
}
pub fn step(mut self, step: u64) -> Self {
self.metadata.current_step = step;
self
}
pub fn grid_size(mut self, width: u32, height: u32, depth: u32) -> Self {
self.metadata.grid_size = (width, height, depth);
self
}
pub fn tile_size(mut self, x: u32, y: u32, z: u32) -> Self {
self.metadata.tile_size = (x, y, z);
self
}
pub fn hlc(mut self, hlc: HlcTimestamp) -> Self {
self.metadata.hlc_timestamp = hlc;
self
}
pub fn custom(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.metadata.custom.insert(key.into(), value.into());
self
}
pub fn control_block(mut self, data: Vec<u8>) -> Self {
self.chunks
.push(DataChunk::new(ChunkType::ControlBlock, data));
self
}
pub fn h2k_queue(mut self, data: Vec<u8>) -> Self {
self.chunks.push(DataChunk::new(ChunkType::H2KQueue, data));
self
}
pub fn k2h_queue(mut self, data: Vec<u8>) -> Self {
self.chunks.push(DataChunk::new(ChunkType::K2HQueue, data));
self
}
pub fn device_memory(mut self, name: &str, data: Vec<u8>) -> Self {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
name.hash(&mut hasher);
let id = hasher.finish();
self.chunks
.push(DataChunk::with_id(ChunkType::DeviceMemory, data, id));
self
}
pub fn chunk(mut self, chunk: DataChunk) -> Self {
self.chunks.push(chunk);
self
}
pub fn build(self) -> Checkpoint {
let mut checkpoint = Checkpoint::new(self.metadata);
checkpoint.chunks = self.chunks;
checkpoint
}
}
/// Settings for periodic actor snapshots managed by `CheckpointManager`.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Minimum time between snapshots of the same actor.
    pub interval: Duration,
    /// Checkpoints kept per actor before the oldest are deleted; 0 keeps all.
    pub max_snapshots: usize,
    /// Directory used when the manager builds its own file-backed storage.
    pub storage_path: PathBuf,
    /// Whether periodic snapshots are scheduled.
    pub enabled: bool,
    /// Prefix of generated checkpoint names.
    pub name_prefix: String,
}

impl Default for CheckpointConfig {
    /// 30 s interval, 5 snapshots kept, `/tmp/ringkernel/checkpoints`, enabled,
    /// prefix `checkpoint`.
    fn default() -> Self {
        CheckpointConfig {
            name_prefix: String::from("checkpoint"),
            enabled: true,
            storage_path: PathBuf::from("/tmp/ringkernel/checkpoints"),
            max_snapshots: 5,
            interval: Duration::from_secs(30),
        }
    }
}

impl CheckpointConfig {
    /// Default configuration with a custom snapshot `interval`.
    pub fn new(interval: Duration) -> Self {
        Self {
            interval,
            ..Self::default()
        }
    }

    /// Sets how many checkpoints are kept per actor (0 keeps all).
    pub fn with_max_snapshots(self, max: usize) -> Self {
        Self {
            max_snapshots: max,
            ..self
        }
    }

    /// Sets the storage directory.
    pub fn with_storage_path(self, path: impl AsRef<Path>) -> Self {
        Self {
            storage_path: path.as_ref().to_path_buf(),
            ..self
        }
    }

    /// Sets the checkpoint name prefix.
    pub fn with_name_prefix(self, prefix: impl Into<String>) -> Self {
        Self {
            name_prefix: prefix.into(),
            ..self
        }
    }

    /// Enables or disables periodic snapshots.
    pub fn with_enabled(self, enabled: bool) -> Self {
        Self { enabled, ..self }
    }
}
/// Request asking an actor to produce a state snapshot.
///
/// Created by `CheckpointManager::poll_due_snapshots` and
/// `CheckpointManager::request_snapshot`.
#[derive(Debug, Clone)]
pub struct SnapshotRequest {
/// Manager-assigned id (starting at 1); `complete_snapshot` matches responses on it.
pub request_id: u64,
/// Slot of the actor being asked for a snapshot.
pub actor_slot: u32,
/// Buffer offset for the snapshot; the manager always issues 0.
pub buffer_offset: u32,
/// Wall-clock time the request was created.
pub issued_at: SystemTime,
}
/// An actor's reply to a [`SnapshotRequest`].
#[derive(Debug, Clone)]
pub struct SnapshotResponse {
/// Id of the request being answered.
pub request_id: u64,
/// Actor slot reported by the responder.
pub actor_slot: u32,
/// `false` makes `complete_snapshot` count a failure and return an error.
pub success: bool,
/// Snapshot payload; stored as the `actor_state` device-memory chunk.
pub data: Vec<u8>,
/// Actor step at snapshot time; becomes the checkpoint step and part of its name.
pub step: u64,
}
/// Manager bookkeeping for an outstanding request: the request plus the kernel identity
/// (id and type) registered for its actor slot when it was issued.
#[derive(Debug, Clone)]
struct PendingSnapshot {
request: SnapshotRequest,
kernel_id: String,
kernel_type: String,
}
/// Schedules periodic actor snapshots, persists completed ones and prunes old ones.
pub struct CheckpointManager {
    /// Scheduling, naming and retention settings.
    config: CheckpointConfig,
    /// Backend the finished checkpoints are written to.
    storage: Box<dyn CheckpointStorage>,
    /// Registered actors: slot -> (kernel_id, kernel_type).
    actors: HashMap<u32, (String, String)>,
    /// When each actor's most recent snapshot attempt completed (success or failure).
    last_snapshot: HashMap<u32, std::time::Instant>,
    /// Outstanding requests keyed by request id.
    pending: HashMap<u64, PendingSnapshot>,
    /// Next request id to hand out (starts at 1).
    next_request_id: u64,
    /// Checkpoint names saved by this manager per actor slot, oldest first, without duplicates.
    checkpoint_history: HashMap<u32, Vec<String>>,
    /// Snapshots successfully persisted.
    total_completed: u64,
    /// Snapshots reported as failed or that could not be persisted.
    total_failed: u64,
}

impl CheckpointManager {
    /// Creates a manager that stores checkpoints as files under `config.storage_path`.
    pub fn new(config: CheckpointConfig) -> Self {
        let storage = Box::new(FileStorage::new(&config.storage_path));
        Self::with_storage(config, storage)
    }

    /// Creates a manager that persists checkpoints through `storage`.
    pub fn with_storage(config: CheckpointConfig, storage: Box<dyn CheckpointStorage>) -> Self {
        Self {
            config,
            storage,
            actors: HashMap::new(),
            last_snapshot: HashMap::new(),
            pending: HashMap::new(),
            next_request_id: 1,
            checkpoint_history: HashMap::new(),
            total_completed: 0,
            total_failed: 0,
        }
    }

    /// Registers (or re-registers) the actor in `actor_slot`.
    pub fn register_actor(
        &mut self,
        actor_slot: u32,
        kernel_id: impl Into<String>,
        kernel_type: impl Into<String>,
    ) {
        self.actors
            .insert(actor_slot, (kernel_id.into(), kernel_type.into()));
    }

    /// Stops scheduling snapshots for `actor_slot`.
    ///
    /// Outstanding requests and the checkpoint history are left untouched.
    pub fn unregister_actor(&mut self, actor_slot: u32) {
        self.actors.remove(&actor_slot);
        self.last_snapshot.remove(&actor_slot);
    }

    /// Whether periodic snapshots are enabled.
    pub fn is_enabled(&self) -> bool {
        self.config.enabled
    }

    /// The active configuration.
    pub fn config(&self) -> &CheckpointConfig {
        &self.config
    }

    /// Number of outstanding requests.
    pub fn pending_count(&self) -> usize {
        self.pending.len()
    }

    /// Number of snapshots successfully persisted.
    pub fn total_completed(&self) -> u64 {
        self.total_completed
    }

    /// Number of snapshots that failed or could not be persisted.
    pub fn total_failed(&self) -> u64 {
        self.total_failed
    }

    /// Issues requests for every registered actor that is due for a snapshot.
    ///
    /// An actor is due when it has no outstanding request and either has never completed a
    /// snapshot attempt or its last attempt was at least `interval` ago. Returns an empty list
    /// when snapshots are disabled.
    pub fn poll_due_snapshots(&mut self) -> Vec<SnapshotRequest> {
        if !self.config.enabled {
            return Vec::new();
        }
        let now = std::time::Instant::now();
        let interval = self.config.interval;
        // Collect busy slots once instead of scanning `pending` for every actor.
        let busy: std::collections::HashSet<u32> = self
            .pending
            .values()
            .map(|p| p.request.actor_slot)
            .collect();
        let due_slots: Vec<u32> = self
            .actors
            .keys()
            .copied()
            .filter(|slot| !busy.contains(slot))
            .filter(|slot| match self.last_snapshot.get(slot) {
                Some(last) => now.duration_since(*last) >= interval,
                // Never snapshotted: due immediately.
                None => true,
            })
            .collect();

        let mut requests = Vec::with_capacity(due_slots.len());
        for slot in due_slots {
            if let Some((kernel_id, kernel_type)) = self.actors.get(&slot).cloned() {
                requests.push(self.issue_request(slot, kernel_id, kernel_type));
            }
        }
        requests
    }

    /// Persists the snapshot carried by `response` and applies retention.
    ///
    /// Returns `Ok(None)` for unknown or cancelled request ids. Otherwise returns
    /// `Ok(Some(name))` with the saved checkpoint name
    /// `"{name_prefix}_{actor_slot}_step_{step}"`. When `max_snapshots > 0`, the oldest
    /// checkpoints beyond that count are deleted; deletion failures are only logged.
    ///
    /// # Errors
    /// Returns [`RingKernelError::InvalidCheckpoint`] if the response reports failure, or the
    /// storage error if saving fails. Both cases increment `total_failed`.
    pub fn complete_snapshot(&mut self, response: SnapshotResponse) -> Result<Option<String>> {
        let Some(pending) = self.pending.remove(&response.request_id) else {
            return Ok(None);
        };
        // The pending record is authoritative for which actor this request was issued to.
        let actor_slot = pending.request.actor_slot;
        // Record the attempt even on failure, so a failing actor is retried only after `interval`.
        self.last_snapshot
            .insert(actor_slot, std::time::Instant::now());
        if !response.success {
            self.total_failed += 1;
            return Err(RingKernelError::InvalidCheckpoint(format!(
                "Snapshot failed for actor slot {}",
                actor_slot
            )));
        }

        let checkpoint = CheckpointBuilder::new(&pending.kernel_id, &pending.kernel_type)
            .step(response.step)
            .custom("actor_slot", actor_slot.to_string())
            .custom(
                "snapshot_request_id",
                pending.request.request_id.to_string(),
            )
            .device_memory("actor_state", response.data)
            .build();
        let name = format!(
            "{}_{}_step_{}",
            self.config.name_prefix, actor_slot, response.step
        );
        if let Err(e) = self.storage.save(&checkpoint, &name) {
            // A snapshot that could not be persisted counts as failed.
            self.total_failed += 1;
            return Err(e);
        }
        self.total_completed += 1;

        let history = self.checkpoint_history.entry(actor_slot).or_default();
        // Saving an existing name (e.g. the step did not advance) overwrote that file in place.
        // Drop the stale entry so retention can never delete the file that was just written.
        history.retain(|existing| existing != &name);
        history.push(name.clone());
        if self.config.max_snapshots > 0 {
            while history.len() > self.config.max_snapshots {
                let oldest = history.remove(0);
                if let Err(e) = self.storage.delete(&oldest) {
                    tracing::warn!(
                        checkpoint = oldest,
                        error = %e,
                        "Failed to delete old checkpoint during retention cleanup"
                    );
                }
            }
        }
        Ok(Some(name))
    }

    /// Issues an immediate request for `actor_slot`, ignoring schedule and enablement.
    ///
    /// Returns `None` if the slot is not registered.
    pub fn request_snapshot(&mut self, actor_slot: u32) -> Option<SnapshotRequest> {
        let (kernel_id, kernel_type) = self.actors.get(&actor_slot)?.clone();
        Some(self.issue_request(actor_slot, kernel_id, kernel_type))
    }

    /// Allocates a request id for `actor_slot`, records it as pending and returns the request.
    fn issue_request(
        &mut self,
        actor_slot: u32,
        kernel_id: String,
        kernel_type: String,
    ) -> SnapshotRequest {
        let request_id = self.next_request_id;
        self.next_request_id += 1;
        let request = SnapshotRequest {
            request_id,
            actor_slot,
            buffer_offset: 0,
            issued_at: SystemTime::now(),
        };
        self.pending.insert(
            request_id,
            PendingSnapshot {
                request: request.clone(),
                kernel_id,
                kernel_type,
            },
        );
        request
    }

    /// Forgets the outstanding request `request_id`; returns whether it existed.
    pub fn cancel_pending(&mut self, request_id: u64) -> bool {
        self.pending.remove(&request_id).is_some()
    }

    /// Forgets all outstanding requests.
    pub fn cancel_all_pending(&mut self) {
        self.pending.clear();
    }

    /// Loads the newest checkpoint for `actor_slot`, if any.
    ///
    /// The name recorded in this manager's history is used first. Otherwise storage is scanned
    /// for `"{name_prefix}_{actor_slot}_step_{n}"` and the highest numeric step wins.
    /// Plain string order would rank `step_99` above `step_100`. Names whose step does not
    /// parse rank lowest.
    pub fn load_latest(&self, actor_slot: u32) -> Result<Option<Checkpoint>> {
        if let Some(latest_name) = self
            .checkpoint_history
            .get(&actor_slot)
            .and_then(|history| history.last())
        {
            return self.storage.load(latest_name).map(Some);
        }
        let prefix = self.slot_prefix(actor_slot);
        let all = self.storage.list()?;
        let latest = all
            .iter()
            .filter(|name| name.starts_with(&prefix))
            .max_by(|a, b| {
                Self::step_of(a, &prefix)
                    .cmp(&Self::step_of(b, &prefix))
                    .then_with(|| a.cmp(b))
            });
        latest.map(|name| self.storage.load(name)).transpose()
    }

    /// Lists stored checkpoint names for `actor_slot`, in storage order.
    pub fn list_checkpoints(&self, actor_slot: u32) -> Result<Vec<String>> {
        let prefix = self.slot_prefix(actor_slot);
        let all = self.storage.list()?;
        Ok(all.into_iter().filter(|n| n.starts_with(&prefix)).collect())
    }

    /// The underlying storage backend.
    pub fn storage(&self) -> &dyn CheckpointStorage {
        &*self.storage
    }

    /// Name prefix shared by every checkpoint of `actor_slot`: `"{name_prefix}_{actor_slot}_"`.
    fn slot_prefix(&self, actor_slot: u32) -> String {
        format!("{}_{}_", self.config.name_prefix, actor_slot)
    }

    /// Parses the step number from a name of the form `"{slot_prefix}step_{step}"`.
    fn step_of(name: &str, slot_prefix: &str) -> Option<u64> {
        name.strip_prefix(slot_prefix)?
            .strip_prefix("step_")?
            .parse()
            .ok()
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_checkpoint_header_roundtrip() {
let header = CheckpointHeader::new(5, 1024);
let bytes = header.to_bytes();
let restored = CheckpointHeader::from_bytes(&bytes);
assert_eq!(restored.magic, CHECKPOINT_MAGIC);
assert_eq!(restored.version, CHECKPOINT_VERSION);
assert_eq!(restored.chunk_count, 5);
assert_eq!(restored.total_size, 1024);
}
#[test]
fn test_chunk_header_roundtrip() {
let header = ChunkHeader::new(ChunkType::DeviceMemory, 4096).with_id(12345);
let bytes = header.to_bytes();
let restored = ChunkHeader::from_bytes(&bytes);
assert_eq!(restored.chunk_type, ChunkType::DeviceMemory as u32);
assert_eq!(restored.uncompressed_size, 4096);
assert_eq!(restored.chunk_id, 12345);
}
#[test]
fn test_metadata_roundtrip() {
let metadata = CheckpointMetadata::new("kernel_1", "fdtd_3d")
.with_step(1000)
.with_grid_size(64, 64, 64)
.with_tile_size(8, 8, 8)
.with_custom("version", "1.0");
let bytes = metadata.to_bytes();
let restored = CheckpointMetadata::from_bytes(&bytes).unwrap();
assert_eq!(restored.kernel_id, "kernel_1");
assert_eq!(restored.kernel_type, "fdtd_3d");
assert_eq!(restored.current_step, 1000);
assert_eq!(restored.grid_size, (64, 64, 64));
assert_eq!(restored.tile_size, (8, 8, 8));
assert_eq!(restored.custom.get("version"), Some(&"1.0".to_string()));
}
#[test]
fn test_checkpoint_roundtrip() {
let checkpoint = CheckpointBuilder::new("test_kernel", "test_type")
.step(500)
.grid_size(32, 32, 32)
.control_block(vec![1, 2, 3, 4])
.device_memory("pressure_a", vec![5, 6, 7, 8, 9, 10])
.build();
let bytes = checkpoint.to_bytes();
let restored = Checkpoint::from_bytes(&bytes).unwrap();
assert_eq!(restored.metadata.kernel_id, "test_kernel");
assert_eq!(restored.metadata.current_step, 500);
assert_eq!(restored.chunks.len(), 2);
let control = restored.get_chunk(ChunkType::ControlBlock).unwrap();
assert_eq!(control.data, vec![1, 2, 3, 4]);
}
#[test]
fn test_memory_storage() {
    let storage = MemoryStorage::new();
    let name = "test_001";

    // Save, then check that the entry is visible and loads with intact metadata.
    storage
        .save(&CheckpointBuilder::new("mem_test", "test").step(100).build(), name)
        .unwrap();
    assert!(storage.exists(name));

    let loaded = storage.load(name).unwrap();
    assert_eq!(loaded.metadata.kernel_id, "mem_test");
    assert_eq!(loaded.metadata.current_step, 100);
    assert_eq!(storage.list().unwrap(), vec!["test_001"]);

    // Deleting removes the entry again.
    storage.delete(name).unwrap();
    assert!(!storage.exists(name));
}
#[test]
fn test_crc32() {
    // (input, expected) pairs: the empty message, and the standard CRC-32
    // check value for the ASCII digits "123456789".
    for (input, expected) in [(&b""[..], 0), (&b"123456789"[..], 0xCBF4_3926)] {
        assert_eq!(crc32_simple(input), expected);
    }
}
#[test]
fn test_checkpoint_validation() {
    // Bad magic: a header decoded from zeroed bytes must be rejected.
    let mut bytes = [0u8; 64];
    bytes[0..8].copy_from_slice(&0u64.to_le_bytes());
    let header = CheckpointHeader::from_bytes(&bytes);
    assert!(header.validate().is_err());

    // A format version newer than this build supports must be rejected.
    let mut future = CheckpointHeader::new(1, 1024);
    future.version = CHECKPOINT_VERSION + 1;
    assert!(future.validate().is_err());

    // A declared total size above the hard cap must be rejected.
    let oversized = CheckpointHeader::new(1, MAX_CHECKPOINT_SIZE as u64 + 1);
    assert!(oversized.validate().is_err());
}
#[test]
fn test_large_checkpoint() {
    // 100 KB repeating byte ramp, stored twice under different buffer names.
    let large_data: Vec<u8> = (0..100_000).map(|i| (i % 256) as u8).collect();
    let checkpoint = CheckpointBuilder::new("large_kernel", "stress_test")
        .step(999)
        .device_memory("field_a", large_data.clone())
        .device_memory("field_b", large_data.clone())
        .build();
    let bytes = checkpoint.to_bytes();
    let restored = Checkpoint::from_bytes(&bytes).unwrap();
    assert_eq!(restored.chunks.len(), 2);

    let chunks = restored.get_chunks(ChunkType::DeviceMemory);
    assert_eq!(chunks.len(), 2);
    // Check every chunk's full payload, not just the first one's length.
    // `assert!` with a short message avoids dumping 100k bytes on failure.
    for (index, chunk) in chunks.iter().enumerate() {
        assert_eq!(chunk.data.len(), 100_000, "chunk {index} length");
        assert!(chunk.data == large_data, "chunk {index} payload differs");
    }
}
#[test]
fn test_checkpoint_config_defaults() {
    let defaults = CheckpointConfig::default();

    // Out of the box: enabled, 30 s cadence, five retained snapshots,
    // and the "checkpoint" name prefix.
    assert!(defaults.enabled);
    assert_eq!(defaults.interval, Duration::from_secs(30));
    assert_eq!(defaults.max_snapshots, 5);
    assert_eq!(defaults.name_prefix, "checkpoint");
}
#[test]
fn test_checkpoint_config_builder() {
    // Every builder setter should land in its corresponding field.
    let configured = CheckpointConfig::new(Duration::from_secs(10))
        .with_enabled(false)
        .with_name_prefix("actor")
        .with_storage_path("/var/checkpoints")
        .with_max_snapshots(3);

    assert!(!configured.enabled);
    assert_eq!(configured.name_prefix, "actor");
    assert_eq!(configured.storage_path, PathBuf::from("/var/checkpoints"));
    assert_eq!(configured.max_snapshots, 3);
    assert_eq!(configured.interval, Duration::from_secs(10));
}
#[test]
fn test_manager_disabled() {
    // Even with a 1 ms interval, a disabled manager must not schedule anything.
    let disabled = CheckpointConfig::new(Duration::from_millis(1)).with_enabled(false);
    let mut manager = CheckpointManager::with_storage(disabled, Box::new(MemoryStorage::new()));
    manager.register_actor(0, "kernel_0", "test");

    let due = manager.poll_due_snapshots();
    assert!(due.is_empty());
}
#[test]
fn test_manager_register_and_poll() {
    let fast = CheckpointConfig::new(Duration::from_millis(1));
    let mut manager = CheckpointManager::with_storage(fast, Box::new(MemoryStorage::new()));
    for (slot, kernel_id) in [(0, "sim_0"), (1, "sim_1")] {
        manager.register_actor(slot, kernel_id, "fdtd_3d");
    }

    // Each registered actor gets exactly one request, with distinct ids.
    let issued = manager.poll_due_snapshots();
    assert_eq!(issued.len(), 2);

    let mut slots: Vec<u32> = issued.iter().map(|req| req.actor_slot).collect();
    slots.sort_unstable();
    assert_eq!(slots, [0, 1]);
    assert_ne!(issued[0].request_id, issued[1].request_id);
}
#[test]
fn test_manager_no_duplicate_pending() {
    let fast = CheckpointConfig::new(Duration::from_millis(1));
    let mut manager = CheckpointManager::with_storage(fast, Box::new(MemoryStorage::new()));
    manager.register_actor(0, "sim_0", "fdtd_3d");

    let first_round = manager.poll_due_snapshots();
    assert_eq!(first_round.len(), 1);

    // The first request was never completed, so polling again must not reissue it.
    let second_round = manager.poll_due_snapshots();
    assert!(second_round.is_empty());
}
#[test]
fn test_manager_complete_snapshot() {
    let mut manager = CheckpointManager::with_storage(
        CheckpointConfig::new(Duration::from_secs(3600)).with_name_prefix("test"),
        Box::new(MemoryStorage::new()),
    );
    manager.register_actor(0, "sim_0", "fdtd_3d");

    let issued = manager.poll_due_snapshots();
    assert_eq!(issued.len(), 1);

    // Answer the request successfully and expect a stored checkpoint name back.
    let stored_as = manager
        .complete_snapshot(SnapshotResponse {
            request_id: issued[0].request_id,
            actor_slot: 0,
            success: true,
            data: vec![1, 2, 3, 4, 5],
            step: 1000,
        })
        .unwrap();
    assert!(stored_as.is_some());
    let stored_as = stored_as.unwrap();
    assert_eq!(stored_as, "test_0_step_1000");

    // The stored checkpoint carries the actor's registration metadata.
    assert!(manager.storage().exists(&stored_as));
    let reloaded = manager.storage().load(&stored_as).unwrap();
    assert_eq!(reloaded.metadata.kernel_id, "sim_0");
    assert_eq!(reloaded.metadata.kernel_type, "fdtd_3d");
    assert_eq!(reloaded.metadata.current_step, 1000);

    assert_eq!(manager.total_completed(), 1);
    assert_eq!(manager.total_failed(), 0);
}
#[test]
fn test_manager_failed_snapshot() {
    let hourly = CheckpointConfig::new(Duration::from_secs(3600));
    let mut manager = CheckpointManager::with_storage(hourly, Box::new(MemoryStorage::new()));
    manager.register_actor(0, "sim_0", "fdtd_3d");

    let issued = manager.poll_due_snapshots();
    let outcome = manager.complete_snapshot(SnapshotResponse {
        request_id: issued[0].request_id,
        actor_slot: 0,
        success: false,
        data: Vec::new(),
        step: 500,
    });

    // An unsuccessful response is reported as an error and counted as a failure only.
    assert!(outcome.is_err());
    assert_eq!(manager.total_failed(), 1);
    assert_eq!(manager.total_completed(), 0);
}
#[test]
fn test_manager_retention_policy() {
    let mut manager = CheckpointManager::with_storage(
        CheckpointConfig::new(Duration::from_secs(3600))
            .with_max_snapshots(2)
            .with_name_prefix("ret"),
        Box::new(MemoryStorage::new()),
    );
    manager.register_actor(0, "sim_0", "test");

    // Store three snapshots while only two may be retained.
    for step in [100u64, 200, 300] {
        let request_id = manager.request_snapshot(0).unwrap().request_id;
        manager
            .complete_snapshot(SnapshotResponse {
                request_id,
                actor_slot: 0,
                success: true,
                data: vec![step as u8],
                step,
            })
            .unwrap();
    }

    // The oldest snapshot was pruned; the two newest remain.
    let storage = manager.storage();
    assert!(!storage.exists("ret_0_step_100"));
    assert!(storage.exists("ret_0_step_200"));
    assert!(storage.exists("ret_0_step_300"));
    assert_eq!(manager.total_completed(), 3);
}
#[test]
fn test_manager_unknown_response() {
    let hourly = CheckpointConfig::new(Duration::from_secs(3600));
    let mut manager = CheckpointManager::with_storage(hourly, Box::new(MemoryStorage::new()));

    // No request with this id was ever issued. The response is ignored
    // (Ok(None)) rather than reported as an error.
    let stray = SnapshotResponse {
        request_id: 9999,
        actor_slot: 0,
        success: true,
        data: vec![1, 2, 3],
        step: 100,
    };
    let stored_as = manager.complete_snapshot(stray).unwrap();
    assert!(stored_as.is_none());
}
#[test]
fn test_manager_cancel_pending() {
    let fast = CheckpointConfig::new(Duration::from_millis(1));
    let mut manager = CheckpointManager::with_storage(fast, Box::new(MemoryStorage::new()));
    manager.register_actor(0, "sim_0", "test");

    let issued = manager.poll_due_snapshots();
    assert_eq!(manager.pending_count(), 1);

    // Cancelling a known request succeeds and clears it from the pending set.
    assert!(manager.cancel_pending(issued[0].request_id));
    assert_eq!(manager.pending_count(), 0);
}
#[test]
fn test_manager_cancel_all_pending() {
    let fast = CheckpointConfig::new(Duration::from_millis(1));
    let mut manager = CheckpointManager::with_storage(fast, Box::new(MemoryStorage::new()));
    for (slot, kernel_id) in [(0, "sim_0"), (1, "sim_1")] {
        manager.register_actor(slot, kernel_id, "test");
    }

    let _issued = manager.poll_due_snapshots();
    assert_eq!(manager.pending_count(), 2);

    // A bulk cancel drops every outstanding request at once.
    manager.cancel_all_pending();
    assert_eq!(manager.pending_count(), 0);
}
#[test]
fn test_manager_load_latest() {
    let mut manager = CheckpointManager::with_storage(
        CheckpointConfig::new(Duration::from_secs(3600)).with_name_prefix("lat"),
        Box::new(MemoryStorage::new()),
    );
    manager.register_actor(0, "sim_0", "test");

    // Nothing has been stored yet.
    assert!(manager.load_latest(0).unwrap().is_none());

    for step in [100u64, 200] {
        let request_id = manager.request_snapshot(0).unwrap().request_id;
        manager
            .complete_snapshot(SnapshotResponse {
                request_id,
                actor_slot: 0,
                success: true,
                data: vec![step as u8],
                step,
            })
            .unwrap();
    }

    // The most recent step wins.
    let newest = manager.load_latest(0).unwrap().expect("a checkpoint was stored");
    assert_eq!(newest.metadata.current_step, 200);
}
#[test]
fn test_manager_list_checkpoints() {
    let mut manager = CheckpointManager::with_storage(
        CheckpointConfig::new(Duration::from_secs(3600))
            .with_max_snapshots(10)
            .with_name_prefix("list"),
        Box::new(MemoryStorage::new()),
    );
    manager.register_actor(0, "sim_0", "test");
    manager.register_actor(1, "sim_1", "test");

    // Two snapshots per actor, interleaved across actors.
    for step in [100u64, 200] {
        for actor_slot in [0u32, 1] {
            let request_id = manager.request_snapshot(actor_slot).unwrap().request_id;
            manager
                .complete_snapshot(SnapshotResponse {
                    request_id,
                    actor_slot,
                    success: true,
                    data: vec![step as u8],
                    step,
                })
                .unwrap();
        }
    }

    // Each actor sees only its own checkpoints.
    for (slot, prefix) in [(0u32, "list_0_"), (1, "list_1_")] {
        let names = manager.list_checkpoints(slot).unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.iter().all(|name| name.starts_with(prefix)));
    }
}
#[test]
fn test_manager_unregister_actor() {
    let fast = CheckpointConfig::new(Duration::from_millis(1));
    let mut manager = CheckpointManager::with_storage(fast, Box::new(MemoryStorage::new()));
    manager.register_actor(0, "sim_0", "test");

    let before = manager.poll_due_snapshots();
    assert_eq!(before.len(), 1);

    // Once the actor is gone and its outstanding request is cleared,
    // nothing should be scheduled for it.
    manager.unregister_actor(0);
    manager.cancel_all_pending();

    let after = manager.poll_due_snapshots();
    assert!(after.is_empty());
}
#[test]
fn test_snapshot_request_response_roundtrip() {
    // A request and the response that answers it share id and slot.
    let request = SnapshotRequest {
        request_id: 42,
        actor_slot: 7,
        buffer_offset: 4096,
        issued_at: SystemTime::now(),
    };
    let response = SnapshotResponse {
        request_id: 42,
        actor_slot: 7,
        success: true,
        data: vec![0xDE, 0xAD, 0xBE, 0xEF],
        step: 5000,
    };

    assert_eq!(request.request_id, 42);
    assert_eq!(request.actor_slot, 7);
    assert_eq!(request.buffer_offset, 4096);

    assert_eq!(response.request_id, request.request_id);
    assert_eq!(response.actor_slot, request.actor_slot);
    assert!(response.success);
    assert_eq!(response.step, 5000);
}
#[test]
fn test_manager_interval_respected() {
    let hourly = CheckpointConfig::new(Duration::from_secs(3600));
    let mut manager = CheckpointManager::with_storage(hourly, Box::new(MemoryStorage::new()));
    manager.register_actor(0, "sim_0", "test");

    // A freshly registered actor is due on the first poll...
    let first_round = manager.poll_due_snapshots();
    assert_eq!(first_round.len(), 1);
    manager
        .complete_snapshot(SnapshotResponse {
            request_id: first_round[0].request_id,
            actor_slot: 0,
            success: true,
            data: vec![1],
            step: 100,
        })
        .unwrap();

    // ...but once that snapshot completes, it is not due again within the hour.
    let second_round = manager.poll_due_snapshots();
    assert!(second_round.is_empty());
}
#[test]
fn test_file_storage_roundtrip() {
    // Removes the scratch directory when dropped, so cleanup also runs when
    // an assertion below panics.
    struct ScratchDir(std::path::PathBuf);
    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best effort: the directory may never have been created.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    // Unique per process. A fixed name would be shared by concurrently running
    // test binaries, and files left by an aborted run would be visible to
    // `load_latest`.
    let tmp_dir = std::env::temp_dir().join(format!(
        "ringkernel_checkpoint_test_{}",
        std::process::id()
    ));
    // Start from an empty directory even if a stale one with this name exists.
    let _ = std::fs::remove_dir_all(&tmp_dir);
    // Declared before `manager` so the manager is dropped first and the
    // directory is removed last.
    let _cleanup = ScratchDir(tmp_dir.clone());

    let config = CheckpointConfig::new(Duration::from_millis(1))
        .with_storage_path(&tmp_dir)
        .with_name_prefix("file_test");
    let mut manager = CheckpointManager::new(config);
    manager.register_actor(0, "file_kernel", "test_type");

    let requests = manager.poll_due_snapshots();
    assert_eq!(requests.len(), 1);
    let response = SnapshotResponse {
        request_id: requests[0].request_id,
        actor_slot: 0,
        success: true,
        data: vec![10, 20, 30, 40, 50],
        step: 42,
    };
    let name = manager.complete_snapshot(response).unwrap().unwrap();

    // The checkpoint lands on disk as `<name>.rkcp` and loads back intact.
    let file_path = tmp_dir.join(format!("{}.rkcp", name));
    assert!(file_path.exists());
    let loaded = manager.load_latest(0).unwrap().unwrap();
    assert_eq!(loaded.metadata.kernel_id, "file_kernel");
    assert_eq!(loaded.metadata.current_step, 42);
}
/// Builds a "delta_test"/"sim" checkpoint at step 0 holding a control block
/// and a device-memory buffer named "pressure" with the given bytes.
fn build_sample_checkpoint(control: &[u8], mem: &[u8]) -> Checkpoint {
    let mut checkpoint =
        Checkpoint::new(CheckpointMetadata::new("delta_test", "sim").with_step(0));
    checkpoint.add_control_block(Vec::from(control));
    checkpoint.add_device_memory("pressure", Vec::from(mem));
    checkpoint
}
#[test]
fn delta_from_empty_when_new_matches_base() {
    let control: [u8; 3] = [1, 2, 3];
    let memory: [u8; 4] = [4, 5, 6, 7];
    let base = build_sample_checkpoint(&control, &memory);
    let new = build_sample_checkpoint(&control, &memory);

    let delta = Checkpoint::delta_from(&base, &new);
    assert!(delta.chunks.is_empty(), "unchanged chunks should be omitted");

    // Even an empty delta records the digest of the base it was computed against.
    let expected_parent = base.content_digest();
    assert_eq!(
        delta.metadata.custom.get(DELTA_PARENT_DIGEST_KEY),
        Some(&expected_parent)
    );
}
#[test]
fn delta_captures_changed_and_new_chunks() {
    let base = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
    let mut new = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 9, 9]);
    new.add_h2k_queue(vec![42, 42]);
    let delta = Checkpoint::delta_from(&base, &new);

    // True when the delta carries at least one chunk of the given type.
    let carries = |kind: ChunkType| delta.chunks.iter().any(|c| c.chunk_type() == Some(kind));

    // Modified memory and a brand-new queue are included; the untouched
    // control block is not.
    assert!(carries(ChunkType::DeviceMemory));
    assert!(carries(ChunkType::H2KQueue));
    assert!(!carries(ChunkType::ControlBlock));
}
#[test]
fn delta_apply_recovers_new_checkpoint() {
    let base = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
    let mut new = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 9, 9]);
    new.add_h2k_queue(vec![42, 42]);

    // base + delta(base -> new) must reproduce new, chunk for chunk.
    let delta = Checkpoint::delta_from(&base, &new);
    let restored = Checkpoint::applied_with_delta(&base, &delta).expect("apply");
    assert_eq!(restored.chunks.len(), new.chunks.len());

    for expected in &new.chunks {
        let id = expected.chunk_identity().unwrap();
        let actual = restored
            .chunks
            .iter()
            .find(|candidate| candidate.chunk_identity() == Some(id))
            .expect("identity present");
        assert_eq!(actual.data, expected.data, "chunk {id:?} bytes match");
    }
}
#[test]
fn delta_apply_rejects_wrong_base() {
    let base_a = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
    let base_b = build_sample_checkpoint(&[9, 9, 9], &[8, 8, 8, 8]);
    let new = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 9, 9]);

    // A delta computed from base_a must not be applied on top of base_b.
    let delta = Checkpoint::delta_from(&base_a, &new);
    let outcome = Checkpoint::applied_with_delta(&base_b, &delta);
    assert!(
        matches!(outcome, Err(RingKernelError::InvalidCheckpoint(_))),
        "different base should fail"
    );
}
#[test]
fn content_digest_stable_across_identical_chunks() {
    // Identical chunk contents must produce identical digests.
    let a = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
    let b = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 7]);
    assert_eq!(a.content_digest(), b.content_digest());

    // A digest that ignored the chunk bytes would trivially be "stable";
    // a single changed byte must be detected.
    let c = build_sample_checkpoint(&[1, 2, 3], &[4, 5, 6, 8]);
    assert_ne!(a.content_digest(), c.content_digest());
}
}