use super::scheduler_runtime::SharedState;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
/// Commands understood by the background state-saver task (see `run`).
#[derive(Clone, Copy, Debug)]
pub enum StateSaverMessage {
    /// Marks the state as needing a save; `run` defers the save to its next interval tick.
    SaveRequest,
    /// Asks `run` to perform one final save and then exit its loop.
    Shutdown,
}
/// Cloneable handle for signalling the background state-saver task.
///
/// Every clone shares the same channel sender and the same slot holding the
/// task's `JoinHandle`, so any clone can request saves or shut the task down.
#[derive(Clone)]
pub struct StateSaverHandle {
    /// Sender for saver commands; presumably paired with the receiver handed to `run`
    /// (TODO confirm at the spawn site).
    tx: mpsc::UnboundedSender<StateSaverMessage>,
    /// Shared slot for the saver task's join handle. It is `None` until
    /// `set_task_handle` stores one, and again after `shutdown_and_wait` takes it.
    task_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
}
impl StateSaverHandle {
pub fn new(tx: mpsc::UnboundedSender<StateSaverMessage>) -> Self {
Self {
tx,
task_handle: Arc::new(Mutex::new(None)),
}
}
pub fn mark_dirty(&self) {
let _ = self.tx.send(StateSaverMessage::SaveRequest);
}
pub async fn shutdown_and_wait(&self) -> anyhow::Result<()> {
self.tx
.send(StateSaverMessage::Shutdown)
.map_err(|e| anyhow::anyhow!("Failed to send shutdown message: {}", e))?;
if let Some(handle) = self.task_handle.lock().await.take() {
handle
.await
.map_err(|e| anyhow::anyhow!("State saver task panicked: {}", e))?;
}
Ok(())
}
pub fn set_task_handle(&self, handle: JoinHandle<()>) {
if let Ok(mut guard) = self.task_handle.try_lock() {
*guard = Some(handle);
}
}
}
/// Runs the state-saver loop until a `Shutdown` message arrives or the channel closes.
///
/// A `SaveRequest` only sets a pending flag. The actual save happens on the next
/// tick of a `save_interval` timer, so bursts of requests coalesce into one save.
/// `Shutdown` always triggers a final save before returning. If the channel closes,
/// a final save is done only when a request was still pending.
///
/// `tokio::time::interval` panics on a zero period, so a zero `save_interval` is
/// replaced by a small fallback period and a warning is logged.
pub async fn run(
    shared_state: SharedState,
    mut rx: mpsc::UnboundedReceiver<StateSaverMessage>,
    save_interval: Duration,
) {
    const ZERO_INTERVAL_FALLBACK: Duration = Duration::from_millis(100);

    let save_interval = if save_interval.is_zero() {
        tracing::warn!(
            "State saver configured with a zero save interval; using {}ms instead",
            ZERO_INTERVAL_FALLBACK.as_millis()
        );
        ZERO_INTERVAL_FALLBACK
    } else {
        save_interval
    };

    // as_secs_f64 keeps sub-second intervals visible (as_secs would print "0s").
    tracing::info!(
        "State saver started with save interval: {}s",
        save_interval.as_secs_f64()
    );

    // The first tick completes immediately; harmless because nothing is pending yet.
    let mut interval = tokio::time::interval(save_interval);
    // A slow save must not cause a burst of catch-up ticks afterwards.
    interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
    // Set by SaveRequest and cleared after each periodic save.
    let mut pending_save = false;

    loop {
        tokio::select! {
            _ = interval.tick() => {
                if pending_save {
                    tracing::debug!("Periodic save triggered (dirty flag was set)");
                    perform_save(&shared_state).await;
                    pending_save = false;
                }
            }
            msg = rx.recv() => {
                match msg {
                    Some(StateSaverMessage::SaveRequest) => {
                        pending_save = true;
                        tracing::trace!("Save request received, will save on next interval");
                    }
                    Some(StateSaverMessage::Shutdown) => {
                        tracing::info!("State saver received shutdown signal");
                        perform_save(&shared_state).await;
                        tracing::info!("Final state save completed, state saver exiting");
                        break;
                    }
                    None => {
                        // All senders were dropped without an explicit Shutdown.
                        tracing::warn!("State saver channel closed unexpectedly");
                        if pending_save {
                            tracing::info!("Performing final save before exit");
                            perform_save(&shared_state).await;
                        }
                        break;
                    }
                }
            }
        }
    }
    tracing::info!("State saver task finished");
}
/// Takes the write lock on `shared_state`, calls `save_state_if_dirty`, and logs the duration.
///
/// The write guard is held until the function returns. A save taking more than
/// 100ms is logged at `warn`; anything faster is logged at `trace`.
async fn perform_save(shared_state: &SharedState) {
    let timer = std::time::Instant::now();
    let mut guard = shared_state.write().await;
    guard.save_state_if_dirty().await;

    let took_ms = timer.elapsed().as_millis();
    if took_ms <= 100 {
        tracing::trace!("State save completed in {}ms", took_ms);
    } else {
        tracing::warn!("State save took {}ms (slow I/O)", took_ms);
    }
}