use crate::backend::option::{OptionType, TaskOptions};
use crate::client::Client;
use crate::proto::{SchedulerEnqueueEvent, SchedulerEntry};
use crate::task::Task;
use chrono::{DateTime, Utc};
use cron::Schedule;
use prost::Message;
use std::collections::HashMap;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use uuid::Uuid;
/// A task that the scheduler enqueues repeatedly according to a cron expression.
#[derive(Debug, Clone)]
pub struct PeriodicTask {
/// Task name; passed to `Task::new` when the task is enqueued and reported as the
/// entry's task type in heartbeats.
pub name: String,
/// Cron expression in the syntax accepted by `cron::Schedule::from_str`.
pub cron: String,
/// Raw payload bytes handed to `Task::new` on every enqueue.
pub payload: Vec<u8>,
/// Target queue name; the constructors keep it equal to `options.queue`.
pub queue: String,
/// Options assigned to each task instance before it is enqueued.
pub options: TaskOptions,
/// Schedule parsed from `cron`.
pub schedule: Schedule,
/// Next UTC instant at which the task is due, or `None` when the schedule yields
/// no upcoming time.
pub next_tick: Option<DateTime<Utc>>,
}
impl PeriodicTask {
pub fn new(name: String, cron: String, payload: Vec<u8>, queue: String) -> anyhow::Result<Self> {
let schedule = Schedule::from_str(&cron)?;
let next_tick = schedule.upcoming(Utc).next();
let options = TaskOptions {
queue: queue.clone(),
..Default::default()
};
Ok(Self {
name,
cron,
payload,
queue,
options,
schedule,
next_tick,
})
}
pub fn new_with_options(
name: String,
cron: String,
payload: Vec<u8>,
options: TaskOptions,
) -> anyhow::Result<Self> {
let schedule = Schedule::from_str(&cron)?;
let next_tick = schedule.upcoming(Utc).next();
let queue = options.queue.clone();
Ok(Self {
name,
cron,
payload,
queue,
options,
schedule,
next_tick,
})
}
}
/// Shared slot for the join handles of the two background tasks, in the order
/// (main loop, heartbeat). It is `None` until `start` stores the handles and again
/// after `stop` takes them.
type SchedulerHandles = Arc<tokio::sync::Mutex<Option<(JoinHandle<()>, JoinHandle<()>)>>>;
/// Enqueues registered [`PeriodicTask`]s when their cron schedule fires and
/// periodically publishes the registered entries through the scheduler broker.
pub struct Scheduler {
/// Client used to enqueue tasks and to obtain the scheduler broker.
client: Arc<Client>,
/// Identifier of this scheduler instance, formatted as `hostname:pid:uuid`.
id: String,
/// Registered periodic tasks keyed by their generated entry id.
tasks: Arc<RwLock<HashMap<String, PeriodicTask>>>,
/// Set by `start` and cleared by `stop`; the background tasks exit once it is false.
running: Arc<AtomicBool>,
/// Wakes the main loop early when tasks are (un)registered or the scheduler stops.
notify: Arc<Notify>,
/// Join handles of the background tasks while the scheduler is running.
handles: SchedulerHandles,
/// Period of the heartbeat task; also the main loop's sleep when no tick is pending.
heartbeat_interval: Duration,
/// Optional tenant passed to the broker when writing or clearing scheduler entries.
acl_tenant: Option<String>,
}
impl Scheduler {
/// Creates a scheduler that is not bound to any ACL tenant.
///
/// Delegates to `new_with_tenant` with `acl_tenant` set to `None`; see that
/// constructor for how `heartbeat_interval` is defaulted.
pub async fn new(
    client: Arc<Client>,
    heartbeat_interval: Option<Duration>,
) -> anyhow::Result<Self> {
    let no_tenant: Option<String> = None;
    Self::new_with_tenant(client, heartbeat_interval, no_tenant).await
}
/// Creates a stopped scheduler, optionally bound to an ACL tenant.
///
/// The scheduler id is `hostname:pid:uuid`; an unavailable hostname becomes an empty
/// string. `heartbeat_interval` defaults to 10 seconds when `None`.
pub async fn new_with_tenant(
    client: Arc<Client>,
    heartbeat_interval: Option<Duration>,
    acl_tenant: Option<String>,
) -> anyhow::Result<Self> {
    let host = hostname::get().unwrap_or_default();
    let id = format!(
        "{}:{}:{}",
        host.to_string_lossy(),
        std::process::id(),
        Uuid::new_v4()
    );
    let heartbeat_interval = heartbeat_interval.unwrap_or(Duration::from_secs(10));
    let scheduler = Self {
        client,
        id,
        tasks: Arc::new(RwLock::new(HashMap::new())),
        running: Arc::new(AtomicBool::new(false)),
        notify: Arc::new(Notify::new()),
        handles: Arc::new(tokio::sync::Mutex::new(None)),
        heartbeat_interval,
        acl_tenant,
    };
    Ok(scheduler)
}
pub async fn register(&self, mut task: PeriodicTask, queue: &str) -> anyhow::Result<String> {
let entry_id = Uuid::new_v4().to_string();
task.queue = queue.to_string();
task.options.queue = queue.to_string();
let schedule = Schedule::from_str(&task.cron)?;
task.schedule = schedule;
task.next_tick = task.schedule.upcoming(Utc).next();
let mut guard = self
.tasks
.write()
.map_err(|e| anyhow::anyhow!("lock poisoned: {e}"))?;
guard.insert(entry_id.clone(), task);
drop(guard);
self.notify.notify_one();
Ok(entry_id)
}
/// Removes the entry registered under `entry_id`; unknown ids are ignored.
///
/// The main loop is notified afterwards so it recomputes its next wake-up.
///
/// # Errors
/// Returns an error when the task map lock is poisoned.
pub async fn unregister(&self, entry_id: &str) -> anyhow::Result<()> {
    // The write guard is a temporary and is released at the end of this statement.
    self.tasks
        .write()
        .map_err(|e| anyhow::anyhow!("lock poisoned: {e}"))?
        .remove(entry_id);

    self.notify.notify_one();
    Ok(())
}
/// Returns the entry ids of all registered tasks, in no particular order.
///
/// A poisoned lock yields an empty list rather than an error.
pub fn list_tasks(&self) -> Vec<String> {
    self.tasks
        .read()
        .map(|registry| registry.keys().cloned().collect::<Vec<String>>())
        .unwrap_or_default()
}
#[cfg_attr(not(test), doc(hidden))]
pub async fn start(&self) {
if self.running.swap(true, Ordering::SeqCst) {
return;
}
let tasks = self.tasks.clone();
let running = self.running.clone();
let notify = self.notify.clone();
let client = self.client.clone();
let heartbeat_interval = self.heartbeat_interval;
let main_handle = tokio::spawn(async move {
loop {
if !running.load(Ordering::Relaxed) {
break;
}
let now = Utc::now();
let mut min_next: Option<DateTime<Utc>> = None;
let mut due_entries = Vec::new();
{
if let Ok(mut tasks) = tasks.write() {
for (entry_id, task) in tasks.iter_mut() {
let next_tick = task.next_tick;
if let Some(next) = next_tick {
if next <= now {
due_entries.push((
entry_id.clone(),
task.name.clone(),
task.payload.clone(),
task.options.clone(),
));
task.next_tick = task.schedule.upcoming(Utc).next();
}
if min_next.is_none() || min_next.map(|m| next < m).unwrap_or(false) {
min_next = Some(next);
}
}
}
}
}
for (entry_id, name, payload, options) in due_entries {
if let Ok(mut t) = Task::new(&name, &payload) {
t.options = options;
let _ = client.enqueue(t).await;
let event = SchedulerEnqueueEvent {
task_id: name.clone(),
enqueue_time: Some(prost_types::Timestamp {
seconds: now.timestamp(),
nanos: now.timestamp_subsec_nanos() as i32,
}),
};
let broker = client.get_scheduler_broker();
let entry_id = entry_id.clone();
let event_clone = event.clone();
tokio::spawn(async move {
let _ = broker
.record_scheduler_enqueue_event(&event_clone, &entry_id)
.await;
});
}
}
let sleep_dur = min_next
.map(|t| (t - now).to_std().unwrap_or(Duration::from_secs(1)))
.unwrap_or(heartbeat_interval);
tokio::select! {
_ = tokio::time::sleep(sleep_dur) => {},
_ = notify.notified() => {},
}
}
});
let heartbeat_handle = self.spawn_heartbeat();
let mut handles_guard = self.handles.lock().await;
*handles_guard = Some((main_handle, heartbeat_handle));
}
fn spawn_heartbeat(&self) -> JoinHandle<()> {
let tasks = self.tasks.clone();
let running = self.running.clone();
let client = self.client.clone();
let scheduler_id = self.id.clone();
let heartbeat_interval = self.heartbeat_interval;
let acl_tenant = self.acl_tenant.clone();
tokio::spawn(async move {
let scheduler_broker = client.get_scheduler_broker();
let mut ticker = tokio::time::interval(heartbeat_interval);
loop {
tokio::select! {
_ = ticker.tick() => {
if !running.load(Ordering::Relaxed) {
break;
}
let mut all_entries = Vec::new();
{
if let Ok(tasks) = tasks.read() {
for (entry_id, task) in tasks.iter() {
let next_tick = task.next_tick;
let entry = SchedulerEntry {
id: entry_id.clone(),
spec: task.cron.clone(),
task_type: task.name.clone(),
task_payload: task.payload.clone(),
enqueue_options: Self::stringify_options(&task.options),
next_enqueue_time: next_tick.map(|t| prost_types::Timestamp {
seconds: t.timestamp(),
nanos: t.timestamp_subsec_nanos() as i32,
}),
prev_enqueue_time: None,
};
all_entries.push(entry);
}
}
}
let _ = scheduler_broker.write_scheduler_entries(&all_entries, &scheduler_id, (heartbeat_interval * 2).as_secs(), acl_tenant.as_deref()).await;
}
_ = async {
while running.load(Ordering::Relaxed) {
tokio::time::sleep(heartbeat_interval).await;
}
} => {
let _ = scheduler_broker.clear_scheduler_entries(&scheduler_id, acl_tenant.as_deref()).await;
break;
}
}
}
})
}
/// Loads the entries the broker holds for `scheduler_id` and decodes them.
///
/// A failed broker call yields an empty list, and entries that fail to decode as
/// `SchedulerEntry` are skipped.
pub async fn list_entries(&self, scheduler_id: &str) -> Vec<SchedulerEntry> {
    let raw_map = self
        .client
        .get_scheduler_broker()
        .scheduler_entries_script(scheduler_id)
        .await
        .unwrap_or_default();
    raw_map
        .into_iter()
        .filter_map(|(_id, bytes)| SchedulerEntry::decode(&*bytes).ok())
        .collect()
}
/// Loads enqueue events from the broker and decodes them; `count` is forwarded to
/// the broker unchanged.
///
/// A failed broker call yields an empty list, and events that fail to decode as
/// `SchedulerEnqueueEvent` are skipped.
pub async fn list_events(&self, count: usize) -> Vec<SchedulerEnqueueEvent> {
    let raw_list = self
        .client
        .get_scheduler_broker()
        .scheduler_events_script(count)
        .await
        .unwrap_or_default();
    raw_list
        .into_iter()
        .filter_map(|bytes| SchedulerEnqueueEvent::decode(&*bytes).ok())
        .collect()
}
/// Stops the scheduler and waits for its background tasks to finish.
///
/// Clears the `running` flag, wakes the main loop, awaits the main-loop and heartbeat
/// handles (if any were stored by `start`), and finally asks the broker to clear this
/// scheduler's entries. Broker and join errors are ignored.
#[cfg_attr(not(test), doc(hidden))]
pub async fn stop(&self) {
    self.running.store(false, Ordering::SeqCst);
    self.notify.notify_one();

    // Take the handles in a single statement so the mutex guard is released before
    // the joins below are awaited.
    let taken = self.handles.lock().await.take();
    if let Some((main_loop, heartbeat)) = taken {
        let _ = main_loop.await;
        let _ = heartbeat.await;
    }

    let broker = self.client.get_scheduler_broker();
    let tenant = self.acl_tenant.as_deref();
    let _ = broker.clear_scheduler_entries(&self.id, tenant).await;
}
/// Renders `opts` as a list of strings, one per `OptionType` obtained by converting
/// `opts`, each formatted with `to_string`.
pub fn stringify_options(opts: &TaskOptions) -> Vec<String> {
    let option_types: Vec<OptionType> = opts.into();
    let mut rendered = Vec::with_capacity(option_types.len());
    for option in &option_types {
        rendered.push(option.to_string());
    }
    rendered
}
pub fn parse_options(opts: &[String]) -> TaskOptions {
let mut option_types = Vec::new();
for opt_str in opts {
if let Ok(opt) = OptionType::parse(opt_str) {
option_types.push(opt);
}
}
option_types.into()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// True when any rendered option string contains `needle`.
    fn rendered_contains(rendered: &[String], needle: &str) -> bool {
        rendered.iter().any(|s| s.contains(needle))
    }

    /// Connects to the local Redis instance that the ignored integration tests require.
    async fn local_client() -> Arc<Client> {
        use crate::backend::RedisConnectionType;
        use crate::client::Client;
        let redis_connection_config =
            RedisConnectionType::single("redis://localhost:6379").unwrap();
        Arc::new(Client::new(redis_connection_config).await.unwrap())
    }

    // Each configured option must appear in the rendered strings.
    #[test]
    fn test_stringify_options() {
        let opts = TaskOptions {
            queue: "critical".to_string(),
            max_retry: 5,
            timeout: Some(Duration::from_secs(60)),
            retention: Some(Duration::from_secs(3600)),
            ..Default::default()
        };
        let strings = Scheduler::stringify_options(&opts);
        assert!(rendered_contains(&strings, "Queue(\"critical\")"));
        assert!(rendered_contains(&strings, "MaxRetry(5)"));
        assert!(rendered_contains(&strings, "Timeout(60)"));
        assert!(rendered_contains(&strings, "Retention(3600)"));
    }

    // Rendered option strings must parse back into the matching option fields.
    #[test]
    fn test_parse_options() {
        let option_strings: Vec<String> = [
            "Queue(\"critical\")",
            "MaxRetry(5)",
            "Timeout(60)",
            "Retention(3600)",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();
        let opts = Scheduler::parse_options(&option_strings);
        assert_eq!(opts.queue, "critical");
        assert_eq!(opts.max_retry, 5);
        assert_eq!(opts.timeout, Some(Duration::from_secs(60)));
        assert_eq!(opts.retention, Some(Duration::from_secs(3600)));
    }

    // stringify followed by parse must preserve the checked fields.
    #[test]
    fn test_options_roundtrip() {
        let original_opts = TaskOptions {
            queue: "high_priority".to_string(),
            max_retry: 10,
            timeout: Some(Duration::from_secs(120)),
            retention: Some(Duration::from_secs(7200)),
            task_id: Some("task-abc-123".to_string()),
            ..Default::default()
        };
        let strings = Scheduler::stringify_options(&original_opts);
        let parsed_opts = Scheduler::parse_options(&strings);
        assert_eq!(parsed_opts.queue, original_opts.queue);
        assert_eq!(parsed_opts.max_retry, original_opts.max_retry);
        assert_eq!(parsed_opts.timeout, original_opts.timeout);
        assert_eq!(parsed_opts.retention, original_opts.retention);
        assert_eq!(parsed_opts.task_id, original_opts.task_id);
    }

    // `new_with_options` must keep the supplied options and mirror their queue.
    #[test]
    fn test_periodic_task_creation_with_options() {
        let opts = TaskOptions {
            queue: "scheduled".to_string(),
            max_retry: 3,
            timeout: Some(Duration::from_secs(30)),
            ..Default::default()
        };
        let task = PeriodicTask::new_with_options(
            "email:daily".to_string(),
            "0 0 0 * * *".to_string(),
            b"daily email payload".to_vec(),
            opts.clone(),
        )
        .unwrap();
        assert_eq!(task.name, "email:daily");
        assert_eq!(task.queue, "scheduled");
        assert_eq!(task.options.queue, "scheduled");
        assert_eq!(task.options.max_retry, 3);
        assert_eq!(task.options.timeout, Some(Duration::from_secs(30)));
    }

    // `new` must set both the task queue and the option queue.
    #[test]
    fn test_periodic_task_default_creation() {
        let task = PeriodicTask::new(
            "test:task".to_string(),
            "0 */5 * * * *".to_string(),
            b"test payload".to_vec(),
            "default".to_string(),
        )
        .unwrap();
        assert_eq!(task.name, "test:task");
        assert_eq!(task.queue, "default");
        assert_eq!(task.options.queue, "default");
    }

    // Requires a Redis server on localhost:6379, hence ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_scheduler_with_tenant() {
        let client = local_client().await;
        let scheduler_no_tenant = Scheduler::new(client.clone(), None).await.unwrap();
        assert_eq!(scheduler_no_tenant.acl_tenant, None);
        let scheduler_with_tenant =
            Scheduler::new_with_tenant(client.clone(), None, Some("tenant1".to_string()))
                .await
                .unwrap();
        assert_eq!(
            scheduler_with_tenant.acl_tenant,
            Some("tenant1".to_string())
        );
        // The tenant must not be baked into the scheduler id.
        assert!(!scheduler_with_tenant.id.starts_with("tenant1:"));
    }

    // Requires a Redis server on localhost:6379, hence ignored by default.
    #[tokio::test]
    #[ignore]
    async fn test_scheduler_tenant_isolation() {
        let client = local_client().await;
        let scheduler_tenant1 =
            Scheduler::new_with_tenant(client.clone(), None, Some("tenant1".to_string()))
                .await
                .unwrap();
        let scheduler_tenant2 =
            Scheduler::new_with_tenant(client.clone(), None, Some("tenant2".to_string()))
                .await
                .unwrap();
        assert_ne!(scheduler_tenant1.id, scheduler_tenant2.id);
        assert_eq!(scheduler_tenant1.acl_tenant, Some("tenant1".to_string()));
        assert_eq!(scheduler_tenant2.acl_tenant, Some("tenant2".to_string()));
    }
}