Skip to main content

scirs2_integrate/distributed/
checkpointing.rs

1//! Checkpointing and fault tolerance for distributed integration
2//!
3//! This module provides checkpointing capabilities for fault tolerance in
4//! distributed ODE solving, allowing recovery from node failures.
5
6use crate::common::IntegrateFloat;
7use crate::distributed::types::{
8    ChunkId, ChunkResult, ChunkResultStatus, DistributedError, DistributedResult,
9    FaultToleranceMode, JobId, NodeId,
10};
11use scirs2_core::ndarray::Array1;
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::fs::{self, File};
14use std::io::{Read, Write};
15use std::path::{Path, PathBuf};
16use std::sync::atomic::{AtomicU64, Ordering};
17use std::sync::{Arc, Mutex, RwLock};
18use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
19
/// Checkpoint manager for distributed computation
///
/// Thread-safe: in-memory checkpoints live behind an `RwLock`, IDs come from
/// an atomic counter, and creation timestamps sit behind a `Mutex`, so one
/// manager can be shared across worker threads (e.g. via `Arc`).
pub struct CheckpointManager<F: IntegrateFloat> {
    /// Checkpoint storage directory (only used when persistence is enabled)
    storage_path: PathBuf,
    /// Active checkpoints by job, ordered oldest-to-newest per job
    checkpoints: RwLock<HashMap<JobId, Vec<Checkpoint<F>>>>,
    /// Next checkpoint ID (monotonically increasing, process-local)
    next_checkpoint_id: AtomicU64,
    /// Configuration
    config: CheckpointConfig,
    /// Recent checkpoint creation times (bounded history, newest at the back)
    checkpoint_times: Mutex<VecDeque<Instant>>,
}
33
/// Configuration for checkpointing
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Maximum checkpoints to keep per job; older ones are dropped
    /// (and deleted from disk when persistence is enabled)
    pub max_checkpoints_per_job: usize,
    /// Checkpoint interval (number of chunks completed since the last one)
    pub interval_chunks: usize,
    /// Checkpoint interval (wall-clock duration since the last one)
    pub interval_duration: Duration,
    /// Enable disk persistence
    pub persist_to_disk: bool,
    /// Compress checkpoints
    /// NOTE(review): not acted on anywhere in this module — confirm before relying on it
    pub compress: bool,
    /// Verify checkpoints by reading them back after writing
    pub verify_writes: bool,
}
50
51impl Default for CheckpointConfig {
52    fn default() -> Self {
53        Self {
54            max_checkpoints_per_job: 5,
55            interval_chunks: 10,
56            interval_duration: Duration::from_secs(60),
57            persist_to_disk: true,
58            compress: false,
59            verify_writes: true,
60        }
61    }
62}
63
/// A checkpoint containing computation state
///
/// Snapshot of a job's progress from which execution can be resumed after a
/// failure. Integrity is re-verified on restore via `validation_hash`.
#[derive(Debug, Clone)]
pub struct Checkpoint<F: IntegrateFloat> {
    /// Unique checkpoint ID
    pub id: u64,
    /// Job this checkpoint belongs to
    pub job_id: JobId,
    /// Timestamp when created
    pub timestamp: SystemTime,
    /// Completed chunks (successful results only)
    pub completed_chunks: Vec<ChunkCheckpoint<F>>,
    /// In-progress chunks (must be re-executed on recovery)
    pub in_progress_chunks: Vec<ChunkId>,
    /// Global state (e.g., iteration count)
    pub global_state: CheckpointGlobalState<F>,
    /// Validation hash over chunk IDs, node IDs, state values, and the
    /// integer counters of the global state (its float fields are not hashed)
    pub validation_hash: u64,
}
82
/// Checkpoint data for a single chunk
#[derive(Debug, Clone)]
pub struct ChunkCheckpoint<F: IntegrateFloat> {
    /// Chunk ID
    pub chunk_id: ChunkId,
    /// Final integration time reached by this chunk
    pub final_time: F,
    /// Final state vector
    pub final_state: Array1<F>,
    /// Final derivative (if available)
    pub final_derivative: Option<Array1<F>>,
    /// Node that processed this chunk
    pub node_id: NodeId,
    /// Wall-clock time spent processing the chunk
    pub processing_time: Duration,
}
99
/// Global state for checkpoint
///
/// Note: `derive(Default)` adds an `F: Default` bound to the generated impl;
/// the standard float types used with this module satisfy it.
#[derive(Debug, Clone, Default)]
pub struct CheckpointGlobalState<F: IntegrateFloat> {
    /// Current iteration
    pub iteration: usize,
    /// Total chunks completed
    pub chunks_completed: usize,
    /// Total chunks remaining
    pub chunks_remaining: usize,
    /// Current time progress (not covered by the validation hash)
    pub current_time: F,
    /// Accumulated error estimate (not covered by the validation hash)
    pub error_estimate: F,
}
114
115impl<F: IntegrateFloat> CheckpointManager<F> {
116    /// Create a new checkpoint manager
117    pub fn new(storage_path: PathBuf, config: CheckpointConfig) -> DistributedResult<Self> {
118        // Create storage directory if it doesn't exist
119        if config.persist_to_disk {
120            fs::create_dir_all(&storage_path).map_err(|e| {
121                DistributedError::CheckpointError(format!(
122                    "Failed to create checkpoint directory: {}",
123                    e
124                ))
125            })?;
126        }
127
128        Ok(Self {
129            storage_path,
130            checkpoints: RwLock::new(HashMap::new()),
131            next_checkpoint_id: AtomicU64::new(1),
132            config,
133            checkpoint_times: Mutex::new(VecDeque::new()),
134        })
135    }
136
137    /// Create a new checkpoint
138    pub fn create_checkpoint(
139        &self,
140        job_id: JobId,
141        completed_chunks: Vec<ChunkResult<F>>,
142        in_progress_chunks: Vec<ChunkId>,
143        global_state: CheckpointGlobalState<F>,
144    ) -> DistributedResult<u64> {
145        let checkpoint_id = self.next_checkpoint_id.fetch_add(1, Ordering::SeqCst);
146
147        // Convert chunk results to checkpoint format
148        let chunk_checkpoints: Vec<ChunkCheckpoint<F>> = completed_chunks
149            .into_iter()
150            .filter(|r| r.status == ChunkResultStatus::Success)
151            .map(|r| ChunkCheckpoint {
152                chunk_id: r.chunk_id,
153                final_time: r.time_points.last().copied().unwrap_or(F::zero()),
154                final_state: r.final_state.clone(),
155                final_derivative: r.final_derivative.clone(),
156                node_id: r.node_id,
157                processing_time: r.processing_time,
158            })
159            .collect();
160
161        // Calculate validation hash
162        let validation_hash = self.calculate_hash(&chunk_checkpoints, &global_state);
163
164        let checkpoint = Checkpoint {
165            id: checkpoint_id,
166            job_id,
167            timestamp: SystemTime::now(),
168            completed_chunks: chunk_checkpoints,
169            in_progress_chunks,
170            global_state,
171            validation_hash,
172        };
173
174        // Store in memory
175        {
176            let mut checkpoints = self.checkpoints.write().map_err(|_| {
177                DistributedError::CheckpointError("Failed to acquire checkpoint lock".to_string())
178            })?;
179
180            let job_checkpoints = checkpoints.entry(job_id).or_insert_with(Vec::new);
181            job_checkpoints.push(checkpoint.clone());
182
183            // Trim old checkpoints
184            while job_checkpoints.len() > self.config.max_checkpoints_per_job {
185                let removed = job_checkpoints.remove(0);
186                if self.config.persist_to_disk {
187                    let _ = self.delete_from_disk(job_id, removed.id);
188                }
189            }
190        }
191
192        // Persist to disk
193        if self.config.persist_to_disk {
194            self.save_to_disk(&checkpoint)?;
195        }
196
197        // Record checkpoint time
198        if let Ok(mut times) = self.checkpoint_times.lock() {
199            times.push_back(Instant::now());
200            while times.len() > 100 {
201                times.pop_front();
202            }
203        }
204
205        Ok(checkpoint_id)
206    }
207
208    /// Get the latest checkpoint for a job
209    pub fn get_latest_checkpoint(&self, job_id: JobId) -> Option<Checkpoint<F>> {
210        match self.checkpoints.read() {
211            Ok(checkpoints) => checkpoints.get(&job_id).and_then(|cps| cps.last().cloned()),
212            Err(_) => None,
213        }
214    }
215
216    /// Get a specific checkpoint
217    pub fn get_checkpoint(&self, job_id: JobId, checkpoint_id: u64) -> Option<Checkpoint<F>> {
218        match self.checkpoints.read() {
219            Ok(checkpoints) => checkpoints
220                .get(&job_id)
221                .and_then(|cps| cps.iter().find(|cp| cp.id == checkpoint_id).cloned()),
222            Err(_) => None,
223        }
224    }
225
226    /// Restore from a checkpoint
227    pub fn restore(
228        &self,
229        job_id: JobId,
230        checkpoint_id: Option<u64>,
231    ) -> DistributedResult<Checkpoint<F>> {
232        let checkpoint = if let Some(id) = checkpoint_id {
233            self.get_checkpoint(job_id, id)
234        } else {
235            self.get_latest_checkpoint(job_id)
236        };
237
238        let checkpoint = checkpoint.ok_or_else(|| {
239            DistributedError::CheckpointError(format!("No checkpoint found for job {:?}", job_id))
240        })?;
241
242        // Validate checkpoint
243        let expected_hash =
244            self.calculate_hash(&checkpoint.completed_chunks, &checkpoint.global_state);
245        if expected_hash != checkpoint.validation_hash {
246            return Err(DistributedError::CheckpointError(
247                "Checkpoint validation failed".to_string(),
248            ));
249        }
250
251        Ok(checkpoint)
252    }
253
254    /// Delete all checkpoints for a job
255    pub fn cleanup_job(&self, job_id: JobId) -> DistributedResult<()> {
256        if let Ok(mut checkpoints) = self.checkpoints.write() {
257            if let Some(job_cps) = checkpoints.remove(&job_id) {
258                if self.config.persist_to_disk {
259                    for cp in job_cps {
260                        let _ = self.delete_from_disk(job_id, cp.id);
261                    }
262                }
263            }
264        }
265        Ok(())
266    }
267
268    /// Calculate a hash for validation
269    fn calculate_hash(
270        &self,
271        chunks: &[ChunkCheckpoint<F>],
272        global_state: &CheckpointGlobalState<F>,
273    ) -> u64 {
274        use std::collections::hash_map::DefaultHasher;
275        use std::hash::{Hash, Hasher};
276
277        let mut hasher = DefaultHasher::new();
278
279        // Hash chunk data
280        for chunk in chunks {
281            chunk.chunk_id.0.hash(&mut hasher);
282            chunk.node_id.0.hash(&mut hasher);
283
284            // Hash state values
285            for val in chunk.final_state.iter() {
286                let bits = val.to_f64().unwrap_or(0.0).to_bits();
287                bits.hash(&mut hasher);
288            }
289        }
290
291        // Hash global state
292        global_state.iteration.hash(&mut hasher);
293        global_state.chunks_completed.hash(&mut hasher);
294        global_state.chunks_remaining.hash(&mut hasher);
295
296        hasher.finish()
297    }
298
299    /// Save checkpoint to disk
300    fn save_to_disk(&self, checkpoint: &Checkpoint<F>) -> DistributedResult<()> {
301        let filename = format!(
302            "checkpoint_{}_{}.bin",
303            checkpoint.job_id.value(),
304            checkpoint.id
305        );
306        let path = self.storage_path.join(&filename);
307
308        // Serialize checkpoint (simplified - use proper serialization in production)
309        let data = self.serialize_checkpoint(checkpoint)?;
310
311        let mut file = File::create(&path).map_err(|e| {
312            DistributedError::CheckpointError(format!("Failed to create checkpoint file: {}", e))
313        })?;
314
315        file.write_all(&data).map_err(|e| {
316            DistributedError::CheckpointError(format!("Failed to write checkpoint: {}", e))
317        })?;
318
319        // Verify if configured
320        if self.config.verify_writes {
321            let mut verify_file = File::open(&path).map_err(|e| {
322                DistributedError::CheckpointError(format!(
323                    "Failed to verify checkpoint file: {}",
324                    e
325                ))
326            })?;
327
328            let mut verify_data = Vec::new();
329            verify_file.read_to_end(&mut verify_data).map_err(|e| {
330                DistributedError::CheckpointError(format!("Failed to read back checkpoint: {}", e))
331            })?;
332
333            if verify_data != data {
334                return Err(DistributedError::CheckpointError(
335                    "Checkpoint verification failed".to_string(),
336                ));
337            }
338        }
339
340        Ok(())
341    }
342
343    /// Delete checkpoint from disk
344    fn delete_from_disk(&self, job_id: JobId, checkpoint_id: u64) -> DistributedResult<()> {
345        let filename = format!("checkpoint_{}_{}.bin", job_id.value(), checkpoint_id);
346        let path = self.storage_path.join(&filename);
347
348        if path.exists() {
349            fs::remove_file(&path).map_err(|e| {
350                DistributedError::CheckpointError(format!(
351                    "Failed to delete checkpoint file: {}",
352                    e
353                ))
354            })?;
355        }
356
357        Ok(())
358    }
359
360    /// Serialize checkpoint to bytes
361    fn serialize_checkpoint(&self, checkpoint: &Checkpoint<F>) -> DistributedResult<Vec<u8>> {
362        let mut data = Vec::new();
363
364        // Write header
365        data.extend_from_slice(&checkpoint.id.to_le_bytes());
366        data.extend_from_slice(&checkpoint.job_id.value().to_le_bytes());
367        data.extend_from_slice(&checkpoint.validation_hash.to_le_bytes());
368
369        // Write timestamp
370        let timestamp_secs = checkpoint
371            .timestamp
372            .duration_since(UNIX_EPOCH)
373            .unwrap_or(Duration::ZERO)
374            .as_secs();
375        data.extend_from_slice(&timestamp_secs.to_le_bytes());
376
377        // Write global state
378        data.extend_from_slice(&checkpoint.global_state.iteration.to_le_bytes());
379        data.extend_from_slice(&checkpoint.global_state.chunks_completed.to_le_bytes());
380        data.extend_from_slice(&checkpoint.global_state.chunks_remaining.to_le_bytes());
381
382        // Write chunk count
383        data.extend_from_slice(&(checkpoint.completed_chunks.len() as u64).to_le_bytes());
384
385        // Write each chunk
386        for chunk in &checkpoint.completed_chunks {
387            data.extend_from_slice(&chunk.chunk_id.0.to_le_bytes());
388            data.extend_from_slice(&chunk.node_id.0.to_le_bytes());
389
390            let time_f64 = chunk.final_time.to_f64().unwrap_or(0.0);
391            data.extend_from_slice(&time_f64.to_le_bytes());
392
393            // Write state
394            data.extend_from_slice(&(chunk.final_state.len() as u64).to_le_bytes());
395            for val in chunk.final_state.iter() {
396                let val_f64 = val.to_f64().unwrap_or(0.0);
397                data.extend_from_slice(&val_f64.to_le_bytes());
398            }
399        }
400
401        Ok(data)
402    }
403
404    /// Check if checkpoint is due
405    pub fn should_checkpoint(&self, chunks_since_last: usize) -> bool {
406        // Check chunk interval
407        if chunks_since_last >= self.config.interval_chunks {
408            return true;
409        }
410
411        // Check time interval
412        if let Ok(times) = self.checkpoint_times.lock() {
413            if let Some(last_time) = times.back() {
414                if last_time.elapsed() >= self.config.interval_duration {
415                    return true;
416                }
417            } else {
418                // No checkpoints yet, time to create one
419                return chunks_since_last > 0;
420            }
421        }
422
423        false
424    }
425
426    /// Get checkpoint statistics
427    pub fn get_statistics(&self) -> CheckpointStatistics {
428        let mut total_checkpoints = 0;
429        let mut total_chunks_saved = 0;
430
431        if let Ok(checkpoints) = self.checkpoints.read() {
432            for (_, job_cps) in checkpoints.iter() {
433                total_checkpoints += job_cps.len();
434                for cp in job_cps {
435                    total_chunks_saved += cp.completed_chunks.len();
436                }
437            }
438        }
439
440        CheckpointStatistics {
441            total_checkpoints,
442            total_chunks_saved,
443            storage_path: self.storage_path.clone(),
444        }
445    }
446}
447
/// Statistics about checkpointing
#[derive(Debug, Clone)]
pub struct CheckpointStatistics {
    /// Total number of in-memory checkpoints across all jobs
    pub total_checkpoints: usize,
    /// Total chunks saved across all checkpoints
    pub total_chunks_saved: usize,
    /// Storage path used for persisted checkpoints
    pub storage_path: PathBuf,
}
458
/// Fault tolerance coordinator
///
/// Tracks failed nodes, queues chunks for retry, and chooses a recovery
/// strategy according to the configured `FaultToleranceMode`.
pub struct FaultToleranceCoordinator<F: IntegrateFloat> {
    /// Checkpoint manager used for checkpoint-based recovery
    checkpoint_manager: Arc<CheckpointManager<F>>,
    /// Fault tolerance mode
    mode: FaultToleranceMode,
    /// Nodes currently marked as failed
    failed_nodes: RwLock<HashSet<NodeId>>,
    /// Chunks pending retry
    pending_retry: Mutex<Vec<ChunkId>>,
    /// Callbacks invoked after a successful job recovery
    recovery_callbacks: RwLock<Vec<Arc<dyn Fn(JobId) + Send + Sync>>>,
}
472
473impl<F: IntegrateFloat> FaultToleranceCoordinator<F> {
474    /// Create a new fault tolerance coordinator
475    pub fn new(checkpoint_manager: Arc<CheckpointManager<F>>, mode: FaultToleranceMode) -> Self {
476        Self {
477            checkpoint_manager,
478            mode,
479            failed_nodes: RwLock::new(HashSet::new()),
480            pending_retry: Mutex::new(Vec::new()),
481            recovery_callbacks: RwLock::new(Vec::new()),
482        }
483    }
484
485    /// Handle node failure
486    pub fn handle_node_failure(
487        &self,
488        node_id: NodeId,
489        affected_chunks: Vec<ChunkId>,
490    ) -> DistributedResult<RecoveryAction> {
491        // Record failed node
492        if let Ok(mut failed) = self.failed_nodes.write() {
493            failed.insert(node_id);
494        }
495
496        match self.mode {
497            FaultToleranceMode::None => {
498                // No recovery, just report failure
499                Err(DistributedError::NodeFailure(
500                    node_id,
501                    "Node failed, no fault tolerance enabled".to_string(),
502                ))
503            }
504            FaultToleranceMode::Standard => {
505                // Queue chunks for retry
506                if let Ok(mut pending) = self.pending_retry.lock() {
507                    pending.extend(affected_chunks.iter().cloned());
508                }
509                Ok(RecoveryAction::RetryChunks(affected_chunks))
510            }
511            FaultToleranceMode::HighAvailability => {
512                // Immediate failover with replicas (if available)
513                if let Ok(mut pending) = self.pending_retry.lock() {
514                    pending.extend(affected_chunks.iter().cloned());
515                }
516                Ok(RecoveryAction::FailoverAndRetry(affected_chunks))
517            }
518            FaultToleranceMode::CheckpointRecovery => {
519                // Full recovery from checkpoint
520                Ok(RecoveryAction::RestoreFromCheckpoint)
521            }
522        }
523    }
524
525    /// Handle chunk failure
526    pub fn handle_chunk_failure(
527        &self,
528        chunk_id: ChunkId,
529        node_id: NodeId,
530        error: &str,
531        can_retry: bool,
532    ) -> DistributedResult<RecoveryAction> {
533        if can_retry && self.mode != FaultToleranceMode::None {
534            if let Ok(mut pending) = self.pending_retry.lock() {
535                pending.push(chunk_id);
536            }
537            Ok(RecoveryAction::RetryChunks(vec![chunk_id]))
538        } else if self.mode == FaultToleranceMode::CheckpointRecovery {
539            Ok(RecoveryAction::RestoreFromCheckpoint)
540        } else {
541            Err(DistributedError::ChunkError(
542                chunk_id,
543                format!("Unrecoverable error on node {}: {}", node_id, error),
544            ))
545        }
546    }
547
548    /// Get chunks pending retry
549    pub fn get_pending_retries(&self) -> Vec<ChunkId> {
550        match self.pending_retry.lock() {
551            Ok(pending) => pending.clone(),
552            Err(_) => Vec::new(),
553        }
554    }
555
556    /// Clear pending retries
557    pub fn clear_pending_retries(&self) -> Vec<ChunkId> {
558        match self.pending_retry.lock() {
559            Ok(mut pending) => std::mem::take(&mut *pending),
560            Err(_) => Vec::new(),
561        }
562    }
563
564    /// Check if a node has failed
565    pub fn is_node_failed(&self, node_id: NodeId) -> bool {
566        match self.failed_nodes.read() {
567            Ok(failed) => failed.contains(&node_id),
568            Err(_) => false,
569        }
570    }
571
572    /// Mark node as recovered
573    pub fn mark_node_recovered(&self, node_id: NodeId) {
574        if let Ok(mut failed) = self.failed_nodes.write() {
575            failed.remove(&node_id);
576        }
577    }
578
579    /// Recover a job from its latest checkpoint
580    pub fn recover_job(&self, job_id: JobId) -> DistributedResult<Checkpoint<F>> {
581        let checkpoint = self.checkpoint_manager.restore(job_id, None)?;
582
583        // Invoke recovery callbacks
584        if let Ok(callbacks) = self.recovery_callbacks.read() {
585            for cb in callbacks.iter() {
586                cb(job_id);
587            }
588        }
589
590        Ok(checkpoint)
591    }
592
593    /// Register a recovery callback
594    pub fn on_recovery<F2>(&self, callback: F2)
595    where
596        F2: Fn(JobId) + Send + Sync + 'static,
597    {
598        if let Ok(mut callbacks) = self.recovery_callbacks.write() {
599            callbacks.push(Arc::new(callback));
600        }
601    }
602
603    /// Get failed node count
604    pub fn failed_node_count(&self) -> usize {
605        match self.failed_nodes.read() {
606            Ok(failed) => failed.len(),
607            Err(_) => 0,
608        }
609    }
610}
611
/// Action to take for recovery
#[derive(Debug, Clone)]
pub enum RecoveryAction {
    /// Retry the listed chunks on available nodes
    RetryChunks(Vec<ChunkId>),
    /// Failover to backup replicas and retry the listed chunks
    FailoverAndRetry(Vec<ChunkId>),
    /// Restore the whole job from its latest checkpoint
    RestoreFromCheckpoint,
    /// No action needed
    None,
}
624
#[cfg(test)]
mod tests {
    use super::*;

    /// Unique scratch directory per test: the previous process-id-only path
    /// was shared by all tests in one run, so parallel tests could delete
    /// each other's checkpoint files mid-test.
    fn temp_storage_path(tag: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "scirs_checkpoint_test_{}_{}",
            std::process::id(),
            tag
        ))
    }

    #[test]
    fn test_checkpoint_creation() {
        let path = temp_storage_path("creation");
        let manager: CheckpointManager<f64> =
            CheckpointManager::new(path.clone(), CheckpointConfig::default())
                .expect("Failed to create manager");

        let job_id = JobId::new(1);
        let global_state = CheckpointGlobalState::default();

        let checkpoint_id = manager
            .create_checkpoint(job_id, Vec::new(), Vec::new(), global_state)
            .expect("Failed to create checkpoint");

        assert!(checkpoint_id > 0);

        let checkpoint = manager.get_latest_checkpoint(job_id);
        assert!(checkpoint.is_some());

        // Cleanup
        let _ = fs::remove_dir_all(&path);
    }

    #[test]
    fn test_checkpoint_restore() {
        let path = temp_storage_path("restore");
        // In-memory only; disk persistence is covered by test_checkpoint_creation.
        let config = CheckpointConfig {
            persist_to_disk: false,
            ..Default::default()
        };

        let manager: CheckpointManager<f64> =
            CheckpointManager::new(path.clone(), config).expect("Failed to create manager");

        let job_id = JobId::new(1);
        let global_state = CheckpointGlobalState {
            iteration: 5,
            chunks_completed: 10,
            ..Default::default()
        };

        let _ = manager.create_checkpoint(job_id, Vec::new(), Vec::new(), global_state.clone());

        let restored = manager.restore(job_id, None).expect("Failed to restore");
        assert_eq!(restored.global_state.iteration, 5);
        assert_eq!(restored.global_state.chunks_completed, 10);

        // Cleanup (no-op: nothing was persisted).
        let _ = fs::remove_dir_all(&path);
    }

    #[test]
    fn test_fault_tolerance_coordinator() {
        let path = temp_storage_path("fault_tolerance");
        let config = CheckpointConfig {
            persist_to_disk: false,
            ..Default::default()
        };

        let manager = Arc::new(
            CheckpointManager::<f64>::new(path.clone(), config).expect("Failed to create manager"),
        );

        let coordinator = FaultToleranceCoordinator::new(manager, FaultToleranceMode::Standard);

        let action = coordinator
            .handle_node_failure(NodeId::new(1), vec![ChunkId::new(1), ChunkId::new(2)])
            .expect("Failed to handle failure");

        match action {
            RecoveryAction::RetryChunks(chunks) => {
                assert_eq!(chunks.len(), 2);
            }
            _ => panic!("Expected RetryChunks action"),
        }

        assert!(coordinator.is_node_failed(NodeId::new(1)));

        // Cleanup (no-op: nothing was persisted).
        let _ = fs::remove_dir_all(&path);
    }
}