use std::cmp::{Ord, Ordering, PartialOrd};
use std::collections::BinaryHeap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crossbeam_channel::{unbounded, Receiver, Sender, TryRecvError};
/// Wall-clock timestamp expressed as milliseconds since the Unix epoch.
pub type UnixTimestampMs = u64;

/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
#[inline]
pub fn now_ms() -> UnixTimestampMs {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("System time went backwards");
    // `as_millis` yields a `u128`; it is narrowed to the alias' `u64`.
    since_epoch.as_millis() as u64
}
/// States of the scheduler loop driven by `CronStateMachine`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CronState {
    /// Initial state: look for shutdown requests, incoming tasks and due tasks.
    CheckEvents,
    /// Entered after a task was received; keeps pulling tasks off the channel.
    DrainChannel,
    /// Reserved for task execution. Polling from this state is treated as
    /// unreachable by the state machine.
    ExecutingTask,
    /// Nothing to do right now; the machine sleeps before checking again.
    Sleeping,
    /// Terminal state: the run loop exits once it is reached.
    Terminated,
}
/// Events that drive transitions of the scheduler state machine.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CronEvent {
    /// A task was pulled off the channel and pushed onto the queue.
    TaskReceived,
    /// The sleep period is over.
    TimerExpired,
    /// The earliest queued task has reached its scheduled time.
    TaskDue,
    /// A task finished running.
    TaskCompleted {
        /// Whether the task reported success.
        success: bool,
        /// Whether the task should be queued again.
        should_requeue: bool,
    },
    /// The shared termination flag was observed as set.
    TerminationRequested,
    /// The task channel reported that it is disconnected.
    ChannelDisconnected,
    /// Nothing happened during this poll.
    NoEvents,
}
/// Descriptive information attached to a scheduled task.
#[derive(Debug, Clone)]
pub enum TaskMetadata {
    /// Anonymous task without a recurrence interval.
    OneShot,
    /// Anonymous task with a recurrence interval.
    Recurring {
        /// Recurrence interval in milliseconds.
        interval_ms: u64,
    },
    /// Task with a caller-supplied name and an optional recurrence interval.
    Named {
        /// Name used when logging about the task.
        name: String,
        /// Recurrence interval in milliseconds, or `None` for a one-off task.
        recurring_interval_ms: Option<u64>,
    },
}

impl TaskMetadata {
    /// Returns the recurrence interval in milliseconds, or `None` when the
    /// task carries no interval.
    #[inline]
    pub fn recurrence_interval(&self) -> Option<u64> {
        match *self {
            Self::Recurring { interval_ms } => Some(interval_ms),
            Self::Named {
                recurring_interval_ms,
                ..
            } => recurring_interval_ms,
            Self::OneShot => None,
        }
    }

    /// Returns the display name of the task: the supplied name for `Named`
    /// tasks, otherwise a fixed label for the variant.
    pub fn name(&self) -> &str {
        match self {
            Self::Named { name, .. } => name.as_str(),
            Self::Recurring { .. } => "recurring",
            Self::OneShot => "one-shot",
        }
    }
}
/// A unit of work queued for execution at a given wall-clock time.
pub struct ScheduledTask {
/// Time, in milliseconds since the Unix epoch, at which the task becomes due.
pub scheduled_time_ms: UnixTimestampMs,
/// Name and recurrence information for the task.
pub metadata: TaskMetadata,
/// The work itself. The scheduler treats a `true` return as success and a
/// `false` return as failure when it runs the task.
pub task: Box<dyn FnMut() -> bool + Send>,
}
/// Manual `Debug` implementation: the boxed closure cannot be formatted, so it
/// is rendered as the placeholder `"<fn>"`.
impl std::fmt::Debug for ScheduledTask {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ScheduledTask");
        out.field("scheduled_time_ms", &self.scheduled_time_ms);
        out.field("metadata", &self.metadata);
        out.field("task", &"<fn>");
        out.finish()
    }
}
/// Orders tasks by scheduled time in *reverse*, so that the max-heap
/// `BinaryHeap` yields the earliest task first.
impl Ord for ScheduledTask {
    fn cmp(&self, rhs: &Self) -> Ordering {
        self.scheduled_time_ms.cmp(&rhs.scheduled_time_ms).reverse()
    }
}

impl PartialOrd for ScheduledTask {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

/// Two tasks compare equal when they share a scheduled time; metadata and the
/// closure take no part in the comparison.
impl PartialEq for ScheduledTask {
    fn eq(&self, rhs: &Self) -> bool {
        let (mine, theirs) = (self.scheduled_time_ms, rhs.scheduled_time_ms);
        mine == theirs
    }
}

impl Eq for ScheduledTask {}
/// Counters describing scheduler activity, updated with relaxed atomics.
#[derive(Default)]
pub struct CronStats {
    /// Tasks that were run, whatever the outcome.
    pub tasks_executed: AtomicU64,
    /// Tasks recorded as failed.
    pub tasks_failed: AtomicU64,
    /// Tasks recorded as panicked.
    pub tasks_panicked: AtomicU64,
    /// State-machine transitions processed.
    pub transitions: AtomicU64,
}

impl CronStats {
    /// Adds one to `counter`.
    #[inline]
    fn bump(counter: &AtomicU64) {
        counter.fetch_add(1, AtomicOrdering::Relaxed);
    }

    /// Records a task that ran and succeeded.
    #[inline]
    fn record_success(&self) {
        Self::bump(&self.tasks_executed);
    }

    /// Records a task that ran and failed.
    #[inline]
    fn record_failure(&self) {
        Self::bump(&self.tasks_executed);
        Self::bump(&self.tasks_failed);
    }

    /// Records a task that ran and panicked.
    #[inline]
    fn record_panic(&self) {
        Self::bump(&self.tasks_executed);
        Self::bump(&self.tasks_panicked);
    }

    /// Records one state-machine transition.
    #[inline]
    fn record_transition(&self) {
        Self::bump(&self.transitions);
    }

    /// Copies the current counter values into a plain struct.
    ///
    /// Each counter is read with its own relaxed load, so the four values are
    /// not captured as a single atomic unit.
    pub fn snapshot(&self) -> CronStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(AtomicOrdering::Relaxed);
        CronStatsSnapshot {
            tasks_executed: read(&self.tasks_executed),
            tasks_failed: read(&self.tasks_failed),
            tasks_panicked: read(&self.tasks_panicked),
            transitions: read(&self.transitions),
        }
    }
}

/// Plain-value copy of [`CronStats`] taken by [`CronStats::snapshot`].
#[derive(Debug, Clone, Copy)]
pub struct CronStatsSnapshot {
    /// Tasks that were run, whatever the outcome.
    pub tasks_executed: u64,
    /// Tasks recorded as failed.
    pub tasks_failed: u64,
    /// Tasks recorded as panicked.
    pub tasks_panicked: u64,
    /// State-machine transitions processed.
    pub transitions: u64,
}
/// Single-threaded scheduler loop: receives tasks over a channel, keeps them
/// in a time-ordered queue and runs them when they become due.
pub struct CronStateMachine {
// Current state of the loop.
state: CronState,
// Pending tasks; `ScheduledTask`'s reversed ordering puts the earliest on top.
queue: BinaryHeap<ScheduledTask>,
// Receiving end of the channel on which new tasks arrive.
task_rx: Receiver<ScheduledTask>,
// Upper bound, in milliseconds, on a single sleep.
poll_interval_ms: u64,
// Shared flag; once it reads `true` the machine moves to `Terminated`.
terminating: Arc<AtomicBool>,
// Set once the task channel has reported a disconnect.
channel_disconnected: bool,
// Shared activity counters.
stats: Arc<CronStats>,
// Optional sender on which a single `()` is sent when `run` starts.
ready_tx: Option<Sender<()>>,
}
impl CronStateMachine {
/// Poll interval, in milliseconds, used by `spawn_cron`.
pub const DEFAULT_POLL_INTERVAL_MS: u64 = 100;

/// Builds a state machine in the `CheckEvents` state with an empty queue.
///
/// * `task_rx` - channel on which new tasks arrive.
/// * `terminating` - shared flag that requests shutdown when set.
/// * `stats` - shared counters updated while running.
/// * `poll_interval_ms` - upper bound on a single sleep, in milliseconds.
/// * `ready_tx` - if present, receives one `()` when `run` starts.
pub fn new(
    task_rx: Receiver<ScheduledTask>,
    terminating: Arc<AtomicBool>,
    stats: Arc<CronStats>,
    poll_interval_ms: u64,
    ready_tx: Option<Sender<()>>,
) -> Self {
    Self {
        // Caller-supplied collaborators and configuration.
        task_rx,
        terminating,
        stats,
        poll_interval_ms,
        ready_tx,
        // Fresh internal state.
        state: CronState::CheckEvents,
        queue: BinaryHeap::new(),
        channel_disconnected: false,
    }
}
/// Runs the scheduler loop on the calling thread until the machine reaches
/// `Terminated`.
///
/// Sends one `()` on the ready channel (if one was supplied) before entering
/// the loop, and logs the final counters on exit.
pub fn run(&mut self) {
    log::info!(
        "CronStateMachine started (poll={}ms, lock-free reactive design)",
        self.poll_interval_ms
    );

    // Signal readiness at most once; a dropped receiver is not an error.
    if let Some(ready) = self.ready_tx.take() {
        let _ = ready.send(());
    }

    while !matches!(self.state, CronState::Terminated) {
        let next_event = self.poll_event();
        self.transition(next_event);
    }

    let totals = self.stats.snapshot();
    log::info!(
        "CronStateMachine terminated (executed={}, failed={}, panicked={}, transitions={})",
        totals.tasks_executed,
        totals.tasks_failed,
        totals.tasks_panicked,
        totals.transitions,
    );
}
/// Determines the next event for the current state.
///
/// Priority order:
/// 1. a shutdown request,
/// 2. a due task at the head of the queue,
/// 3. the state-specific poll.
///
/// The shutdown flag is deliberately checked *before* the due-task check.
/// Otherwise a queue whose head is always due (for example a recurring task
/// with a zero interval, or a backlog of overdue tasks) would keep producing
/// `TaskDue` and the termination request would never be observed.
fn poll_event(&mut self) -> CronEvent {
    if self.terminating.load(AtomicOrdering::Acquire) {
        return CronEvent::TerminationRequested;
    }

    let head_is_due = self
        .queue
        .peek()
        .map_or(false, |task| task.scheduled_time_ms <= now_ms());
    if head_is_due {
        return CronEvent::TaskDue;
    }

    match self.state {
        CronState::CheckEvents => self.poll_check_events(),
        CronState::DrainChannel => self.poll_drain_channel(),
        CronState::ExecutingTask => unreachable!("ExecutingTask polls internally"),
        // The sleep itself already happened inside `transition`.
        CronState::Sleeping => CronEvent::TimerExpired,
        CronState::Terminated => unreachable!("Cannot poll from Terminated"),
    }
}
/// Poll used in the `CheckEvents` state.
///
/// Checks, in order: the shutdown flag, a disconnected channel with nothing
/// left to run, one incoming task, and finally whether the head of the queue
/// is due. Returns `NoEvents` when none of these apply.
fn poll_check_events(&mut self) -> CronEvent {
    if self.terminating.load(AtomicOrdering::Acquire) {
        return CronEvent::TerminationRequested;
    }
    if self.channel_disconnected && self.queue.is_empty() {
        return CronEvent::ChannelDisconnected;
    }

    match self.task_rx.try_recv() {
        Ok(incoming) => {
            self.queue.push(incoming);
            return CronEvent::TaskReceived;
        }
        Err(TryRecvError::Disconnected) => {
            // Report the disconnect only the first time it is observed.
            if !self.channel_disconnected {
                self.channel_disconnected = true;
                return CronEvent::ChannelDisconnected;
            }
        }
        Err(TryRecvError::Empty) => {}
    }

    let head_is_due = self
        .queue
        .peek()
        .map_or(false, |next| next.scheduled_time_ms <= now_ms());
    if head_is_due {
        CronEvent::TaskDue
    } else {
        CronEvent::NoEvents
    }
}
/// Poll used in the `DrainChannel` state: takes at most one task off the
/// channel per call.
///
/// When the channel is empty, reports `TaskDue` if the head of the queue is
/// due and `NoEvents` otherwise. A disconnected channel is remembered and
/// reported as `ChannelDisconnected`.
fn poll_drain_channel(&mut self) -> CronEvent {
    let received = self.task_rx.try_recv();
    match received {
        Ok(incoming) => {
            self.queue.push(incoming);
            CronEvent::TaskReceived
        }
        Err(TryRecvError::Disconnected) => {
            self.channel_disconnected = true;
            CronEvent::ChannelDisconnected
        }
        Err(TryRecvError::Empty) => match self.queue.peek() {
            Some(next) if next.scheduled_time_ms <= now_ms() => CronEvent::TaskDue,
            _ => CronEvent::NoEvents,
        },
    }
}
/// Applies `event` to the current state, running any side effects tied to the
/// transition, and stores the resulting state.
///
/// Every call is counted in the stats. If the resulting state is `Sleeping`,
/// the sleep is performed before returning.
fn transition(&mut self, event: CronEvent) {
self.stats.record_transition();
let old_state = self.state;
// Arms are tried top to bottom, so their order is significant.
self.state = match (self.state, event) {
// A termination request wins from any state.
(_, CronEvent::TerminationRequested) => {
log::debug!(
"Transition: {:?} + TerminationRequested → Terminated",
old_state
);
CronState::Terminated
}
// First task received: switch to draining the channel.
(CronState::CheckEvents, CronEvent::TaskReceived) => {
log::trace!("Transition: CheckEvents + TaskReceived → DrainChannel");
CronState::DrainChannel
}
(CronState::CheckEvents, CronEvent::TaskDue) => {
self.execute_one_task();
CronState::CheckEvents }
(CronState::CheckEvents, CronEvent::NoEvents) => CronState::Sleeping,
// With the channel gone, stop only once nothing is left in the queue.
(CronState::CheckEvents, CronEvent::ChannelDisconnected) => {
log::debug!("Channel disconnected, continuing with existing tasks");
if self.queue.is_empty() {
CronState::Terminated
} else {
CronState::CheckEvents
}
}
// Keep draining while tasks keep arriving.
(CronState::DrainChannel, CronEvent::TaskReceived) => {
CronState::DrainChannel
}
(CronState::DrainChannel, CronEvent::TaskDue) => {
self.execute_one_task();
CronState::CheckEvents
}
(CronState::DrainChannel, CronEvent::NoEvents) => CronState::Sleeping,
// The decision whether to terminate is left to the `CheckEvents` state.
(CronState::DrainChannel, CronEvent::ChannelDisconnected) => {
log::debug!("Channel disconnected during drain");
CronState::CheckEvents
}
(CronState::Sleeping, CronEvent::TimerExpired) => CronState::CheckEvents,
(CronState::Sleeping, CronEvent::TaskDue) => {
self.execute_one_task();
CronState::CheckEvents
}
// NOTE(review): nothing visible in this file assigns `ExecutingTask` to
// `self.state`, so this arm does not appear to be reachable — confirm.
(CronState::ExecutingTask, CronEvent::TaskCompleted { success, .. }) => {
if success {
self.stats.record_success();
} else {
self.stats.record_failure();
}
CronState::CheckEvents
}
(CronState::Terminated, _) => {
unreachable!("Cannot transition from Terminated")
}
// Any other pairing is logged and recovered from by re-checking events.
(state, event) => {
log::warn!("Unexpected transition: {:?} + {:?}", state, event);
CronState::CheckEvents
}
};
// The sleep happens here, as part of entering `Sleeping`.
if self.state == CronState::Sleeping {
self.do_sleep();
}
}
fn execute_one_task(&mut self) {
let Some(mut task) = self.queue.pop() else {
return;
};
log::trace!("Executing task: {}", task.metadata.name());
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (task.task)()));
match result {
Ok(true) => {
self.stats.record_success();
if let Some(interval) = task.metadata.recurrence_interval() {
task.scheduled_time_ms = now_ms() + interval;
self.queue.push(task);
log::trace!("Task requeued with interval {}ms", interval);
}
}
Ok(false) => {
self.stats.record_failure();
log::warn!(
"Task '{}' returned false, not re-queuing",
task.metadata.name()
);
}
Err(e) => {
self.stats.record_panic();
log::error!("Task '{}' panicked: {:?}", task.metadata.name(), e);
}
}
}
/// Blocks the current thread until the next task is due, but never for longer
/// than the poll interval. Returns immediately when the computed sleep is zero.
fn do_sleep(&self) {
    let cap = self.poll_interval_ms;
    let sleep_ms = match self.queue.peek() {
        // Zero when the task is already due, otherwise time remaining.
        Some(next) => next.scheduled_time_ms.saturating_sub(now_ms()).min(cap),
        None => cap,
    };
    if sleep_ms == 0 {
        return;
    }
    std::thread::sleep(Duration::from_millis(sleep_ms));
}
/// Returns the number of tasks currently held in the internal queue.
/// Tasks still sitting in the channel are not included.
pub fn pending_count(&self) -> usize {
self.queue.len()
}
/// Returns the state the machine is currently in.
pub fn current_state(&self) -> CronState {
self.state
}
}
/// Cloneable handle used to submit tasks to the scheduler and to request its
/// shutdown.
#[derive(Clone)]
pub struct CronHandle {
// Sending end of the channel on which tasks are submitted.
task_tx: Sender<ScheduledTask>,
// Shared flag set by `request_shutdown`.
terminating: Arc<AtomicBool>,
}
impl CronHandle {
/// Submits `task` to run at the absolute time `time_ms` (milliseconds since
/// the Unix epoch).
///
/// Returns `true` if the task was handed to the channel and `false` if the
/// send failed.
pub fn schedule_at<F>(&self, time_ms: UnixTimestampMs, metadata: TaskMetadata, task: F) -> bool
where
    F: FnMut() -> bool + Send + 'static,
{
    let boxed: Box<dyn FnMut() -> bool + Send> = Box::new(task);
    let submission = ScheduledTask {
        scheduled_time_ms: time_ms,
        metadata,
        task: boxed,
    };
    match self.task_tx.send(submission) {
        Ok(()) => true,
        Err(_) => false,
    }
}
/// Submits `task` to run `delay_ms` milliseconds from now.
///
/// The target time saturates at `u64::MAX` rather than overflowing, so a very
/// large delay neither panics nor wraps around to a time in the past.
///
/// Returns `true` if the task was handed to the channel and `false` if the
/// send failed.
pub fn schedule_after<F>(&self, delay_ms: u64, metadata: TaskMetadata, task: F) -> bool
where
    F: FnMut() -> bool + Send + 'static,
{
    let due_at = now_ms().saturating_add(delay_ms);
    self.schedule_at(due_at, metadata, task)
}
/// Submits a named task that first runs after `initial_delay_ms` and carries
/// a recurrence interval of `interval_ms`.
///
/// Returns `true` if the task was handed to the channel and `false` if the
/// send failed.
pub fn schedule_recurring<F>(
    &self,
    initial_delay_ms: u64,
    interval_ms: u64,
    name: &str,
    task: F,
) -> bool
where
    F: FnMut() -> bool + Send + 'static,
{
    let recurring_interval_ms = Some(interval_ms);
    let name = name.to_owned();
    self.schedule_after(
        initial_delay_ms,
        TaskMetadata::Named {
            name,
            recurring_interval_ms,
        },
        task,
    )
}
/// Submits a named task without a recurrence interval, to run after
/// `delay_ms` milliseconds.
///
/// Returns `true` if the task was handed to the channel and `false` if the
/// send failed.
pub fn schedule_once<F>(&self, delay_ms: u64, name: &str, task: F) -> bool
where
    F: FnMut() -> bool + Send + 'static,
{
    let one_off = TaskMetadata::Named {
        name: String::from(name),
        recurring_interval_ms: None,
    };
    self.schedule_after(delay_ms, one_off, task)
}
/// Sets the shared termination flag. The call only stores the flag; it does
/// not wait for the scheduler to stop.
pub fn request_shutdown(&self) {
self.terminating.store(true, AtomicOrdering::Release);
}
/// Returns `true` once the shared termination flag has been set.
pub fn is_shutting_down(&self) -> bool {
self.terminating.load(AtomicOrdering::Acquire)
}
}
/// Spawns the scheduler thread with the default poll interval.
///
/// See [`spawn_cron_with_interval`] for the meaning of the returned tuple.
pub fn spawn_cron(
    terminating: Arc<AtomicBool>,
) -> (
    CronHandle,
    std::thread::JoinHandle<()>,
    Arc<CronStats>,
    Receiver<()>,
) {
    let default_poll_ms = CronStateMachine::DEFAULT_POLL_INTERVAL_MS;
    spawn_cron_with_interval(terminating, default_poll_ms)
}
/// Spawns a thread named `cron-state-machine` that runs a `CronStateMachine`
/// with the given poll interval.
///
/// Returns, in order:
/// * a [`CronHandle`] for submitting tasks and requesting shutdown,
/// * the join handle of the scheduler thread,
/// * the shared [`CronStats`] counters,
/// * a receiver that gets one `()` when the scheduler loop starts.
///
/// # Panics
///
/// Panics if the operating system fails to spawn the thread.
pub fn spawn_cron_with_interval(
    terminating: Arc<AtomicBool>,
    poll_interval_ms: u64,
) -> (
    CronHandle,
    std::thread::JoinHandle<()>,
    Arc<CronStats>,
    Receiver<()>,
) {
    let stats = Arc::new(CronStats::default());
    let (task_tx, task_rx) = unbounded::<ScheduledTask>();
    let (ready_tx, ready_rx) = unbounded::<()>();

    // The worker owns its own clones of the shared state.
    let worker = {
        let stats = Arc::clone(&stats);
        let terminating = Arc::clone(&terminating);
        move || {
            let mut machine = CronStateMachine::new(
                task_rx,
                terminating,
                stats,
                poll_interval_ms,
                Some(ready_tx),
            );
            machine.run();
        }
    };

    let thread_handle = std::thread::Builder::new()
        .name("cron-state-machine".to_string())
        .spawn(worker)
        .expect("Failed to spawn cron state machine thread");

    (
        CronHandle {
            task_tx,
            terminating,
        },
        thread_handle,
        stats,
        ready_rx,
    )
}
#[cfg(test)]
mod tests;