use crate::checkpoint::{Checkpoint, CheckpointError};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::watch;
use tokio::time::{interval, Duration};
/// Tunables for the periodic checkpoint scheduler.
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// File path handed to `Checkpoint::save_atomic` for every save.
    pub checkpoint_path: PathBuf,
    /// Period between scheduled checkpoints, in seconds.
    pub interval_secs: u64,
    /// Forwarded to `Checkpoint::save_atomic`; presumably the number of older
    /// checkpoint copies retained — confirm against `save_atomic`.
    pub backup_count: u32,
}

impl Default for SchedulerConfig {
    /// Checkpoint to `./data/checkpoint.json` every five minutes with a
    /// `backup_count` of 3.
    fn default() -> Self {
        const FIVE_MINUTES_SECS: u64 = 5 * 60;

        let checkpoint_path = PathBuf::from("./data/checkpoint.json");
        Self {
            checkpoint_path,
            interval_secs: FIVE_MINUTES_SECS,
            backup_count: 3,
        }
    }
}
/// Source of the state snapshot written into each checkpoint.
///
/// Implementors must be `Send + Sync`: the scheduler keeps the provider in an
/// `Arc` and calls it from a spawned Tokio task.
pub trait CheckpointProvider: Send + Sync {
    /// Current integrity state, passed to `Checkpoint::new` on every save.
    fn integrity(&self) -> nklave_core::state::integrity::StateIntegrity;

    /// Per-validator state keyed by a 48-byte identifier (presumably the
    /// validator public key — confirm with implementors).
    fn validator_states(
        &self,
    ) -> std::collections::HashMap<[u8; 48], nklave_core::state::validator::ValidatorState>;
}
/// Control handle for a running checkpoint scheduler task.
///
/// Lets the owner query whether the background task is still looping, read
/// when the last checkpoint was written, and request shutdown.
pub struct CheckpointSchedulerHandle {
    /// Publishing `true` asks the scheduler task to write a final checkpoint and exit.
    shutdown_tx: watch::Sender<bool>,
    /// Shared with the scheduler task, which clears it when it stops.
    running: Arc<AtomicBool>,
    /// Unix timestamp (seconds) of the last successful save; `0` until the first one.
    last_checkpoint_time: Arc<AtomicU64>,
}

impl CheckpointSchedulerHandle {
    /// Returns `true` until the scheduler task has stopped.
    pub fn is_running(&self) -> bool {
        let flag = &self.running;
        flag.load(Ordering::Relaxed)
    }

    /// Seconds elapsed since the last successful checkpoint.
    ///
    /// Returns `0` both when no checkpoint has been written yet and when one
    /// was written this second; callers cannot tell those cases apart here.
    /// A wall clock that reads earlier than the stored timestamp (or earlier
    /// than the Unix epoch) also yields `0` rather than underflowing.
    pub fn checkpoint_age_secs(&self) -> u64 {
        match self.last_checkpoint_time.load(Ordering::Relaxed) {
            0 => 0,
            saved_at => std::time::UNIX_EPOCH
                .elapsed()
                .unwrap_or_default()
                .as_secs()
                .saturating_sub(saved_at),
        }
    }

    /// Asks the scheduler task to write a final checkpoint and stop.
    ///
    /// Returns immediately without waiting for the task. A send failure (the
    /// task has already exited and dropped its receiver) is deliberately ignored.
    pub fn shutdown(&self) {
        let _ = self.shutdown_tx.send(true);
    }

    /// Returns the Unix timestamp (seconds) of the last successful checkpoint,
    /// or `0` if none has been written.
    ///
    /// NOTE(review): despite its name, this does NOT request a checkpoint; it
    /// only reads the stored timestamp. Confirm no caller relies on it to force
    /// a save before renaming or changing it.
    pub fn trigger_checkpoint(&self) -> u64 {
        let stamp = &self.last_checkpoint_time;
        stamp.load(Ordering::Relaxed)
    }
}
/// Entry point for the periodic checkpoint writer; see [`CheckpointScheduler::start`].
pub struct CheckpointScheduler;

impl CheckpointScheduler {
    /// Spawns the background checkpoint task and returns a handle to control it.
    ///
    /// The task writes a checkpoint every `config.interval_secs` seconds. The
    /// first one is written after one full interval, not at startup. A final
    /// checkpoint is written when [`CheckpointSchedulerHandle::shutdown`] is called.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (`tokio::spawn` requires one).
    pub fn start<P>(provider: Arc<P>, config: SchedulerConfig) -> CheckpointSchedulerHandle
    where
        P: CheckpointProvider + 'static,
    {
        let (shutdown_tx, shutdown_rx) = watch::channel(false);
        let running = Arc::new(AtomicBool::new(true));
        let last_checkpoint_time = Arc::new(AtomicU64::new(0));

        let handle = CheckpointSchedulerHandle {
            shutdown_tx,
            running: Arc::clone(&running),
            last_checkpoint_time: Arc::clone(&last_checkpoint_time),
        };

        // Detached: the task is controlled through `handle`, not its JoinHandle.
        tokio::spawn(Self::run_scheduler(
            provider,
            config,
            shutdown_rx,
            running,
            last_checkpoint_time,
        ));

        handle
    }

    /// Main loop of the background task.
    ///
    /// Exits after writing a final checkpoint once `true` is published on
    /// `shutdown_rx`. If every handle is dropped without calling `shutdown`,
    /// the task keeps writing periodic checkpoints. It stops listening for
    /// shutdown because none can arrive any more.
    async fn run_scheduler<P>(
        provider: Arc<P>,
        config: SchedulerConfig,
        mut shutdown_rx: watch::Receiver<bool>,
        running: Arc<AtomicBool>,
        last_checkpoint_time: Arc<AtomicU64>,
    ) where
        P: CheckpointProvider + 'static,
    {
        /// Clears the shared `running` flag when dropped, so `is_running()`
        /// turns false even if this task unwinds from a panic.
        struct RunningGuard(Arc<AtomicBool>);

        impl Drop for RunningGuard {
            fn drop(&mut self) {
                self.0.store(false, Ordering::Relaxed);
            }
        }

        let running_guard = RunningGuard(running);

        // `tokio::time::interval` panics on a zero period, so clamp it to one second.
        if config.interval_secs == 0 {
            tracing::warn!("Checkpoint interval_secs is 0; using 1 second instead");
        }
        let period = Duration::from_secs(config.interval_secs.max(1));

        let mut ticker = interval(period);
        // After a stall (slow save, suspended host), save once and re-space the
        // schedule. The default behavior would burst back-to-back saves of the
        // same state.
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // The first tick completes immediately; consume it so the first
        // scheduled checkpoint happens one full period after startup.
        ticker.tick().await;

        tracing::info!(
            interval_secs = period.as_secs(),
            path = %config.checkpoint_path.display(),
            "Checkpoint scheduler started"
        );

        // Cleared once every handle is dropped. After that, `changed()` resolves
        // to `Err` immediately on every poll and would otherwise spin this loop.
        let mut shutdown_listening = true;

        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    if let Err(e) =
                        Self::save_checkpoint(&provider, &config, &last_checkpoint_time)
                    {
                        tracing::error!(error = %e, "Failed to save scheduled checkpoint");
                    }
                }
                changed = shutdown_rx.changed(), if shutdown_listening => {
                    if changed.is_err() {
                        tracing::warn!(
                            "Checkpoint scheduler handle dropped; continuing without shutdown control"
                        );
                        shutdown_listening = false;
                    } else if *shutdown_rx.borrow() {
                        tracing::info!("Checkpoint scheduler shutting down");
                        if let Err(e) =
                            Self::save_checkpoint(&provider, &config, &last_checkpoint_time)
                        {
                            tracing::error!(error = %e, "Failed to save final checkpoint");
                        }
                        break;
                    }
                }
            }
        }

        // Clear `running` before the final log line, matching the normal exit order.
        drop(running_guard);
        tracing::info!("Checkpoint scheduler stopped");
    }

    /// Snapshots the provider's state, writes it with `Checkpoint::save_atomic`,
    /// and records the wall-clock time of the save.
    ///
    /// Runs synchronously on the calling task. NOTE(review): if `save_atomic`
    /// performs blocking file I/O, this stalls a runtime worker thread.
    /// Consider `tokio::task::spawn_blocking` if saves are slow.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Checkpoint::save_atomic`. The stored
    /// timestamp and the checkpoint-age metric are only updated on success.
    fn save_checkpoint<P>(
        provider: &Arc<P>,
        config: &SchedulerConfig,
        last_checkpoint_time: &Arc<AtomicU64>,
    ) -> Result<(), CheckpointError>
    where
        P: CheckpointProvider,
    {
        let integrity = provider.integrity();
        let validators = provider.validator_states();
        let checkpoint = Checkpoint::new(&integrity, validators);

        tracing::debug!(
            sequence = checkpoint.sequence,
            validators = checkpoint.validators.len(),
            "Saving scheduled checkpoint"
        );

        checkpoint.save_atomic(&config.checkpoint_path, config.backup_count)?;

        // Seconds since the Unix epoch; a pre-epoch clock records 0.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        last_checkpoint_time.store(now, Ordering::Relaxed);
        nklave_core::metrics::set_checkpoint_age(0);

        tracing::info!(
            sequence = checkpoint.sequence,
            path = %config.checkpoint_path.display(),
            "Scheduled checkpoint saved"
        );
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nklave_core::state::integrity::StateIntegrity;
    use nklave_core::state::validator::ValidatorState;
    use std::collections::HashMap;
    use tempfile::TempDir;

    /// Provider stub that hands out a `StateIntegrity::new()` snapshot and no validators.
    struct MockProvider {
        integrity: StateIntegrity,
        validators: HashMap<[u8; 48], ValidatorState>,
    }

    impl MockProvider {
        fn new() -> Self {
            let integrity = StateIntegrity::new();
            let validators = HashMap::new();
            Self { integrity, validators }
        }
    }

    impl CheckpointProvider for MockProvider {
        fn integrity(&self) -> StateIntegrity {
            self.integrity.clone()
        }

        fn validator_states(&self) -> HashMap<[u8; 48], ValidatorState> {
            self.validators.clone()
        }
    }

    /// Scheduler config that ticks every `interval_secs` and writes
    /// `checkpoint.json` inside `dir`, with `backup_count: 2`.
    fn config_in(dir: &TempDir, interval_secs: u64) -> SchedulerConfig {
        SchedulerConfig {
            checkpoint_path: dir.path().join("checkpoint.json"),
            interval_secs,
            backup_count: 2,
        }
    }

    #[tokio::test]
    async fn test_scheduler_starts_and_stops() {
        let tmp = TempDir::new().unwrap();
        let handle =
            CheckpointScheduler::start(Arc::new(MockProvider::new()), config_in(&tmp, 1));
        assert!(handle.is_running());

        tokio::time::sleep(Duration::from_millis(100)).await;
        handle.shutdown();
        // Give the background task a moment to observe the signal and exit.
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(!handle.is_running());
    }

    #[tokio::test]
    async fn test_scheduler_creates_checkpoint() {
        let tmp = TempDir::new().unwrap();
        let config = config_in(&tmp, 1);
        let expected_file = config.checkpoint_path.clone();

        let handle = CheckpointScheduler::start(Arc::new(MockProvider::new()), config);
        // Long enough for at least one scheduled tick at a 1-second interval.
        tokio::time::sleep(Duration::from_secs(2)).await;
        handle.shutdown();
        tokio::time::sleep(Duration::from_millis(100)).await;

        assert!(expected_file.exists());
    }
}