use std::{result::Result as StdResult, str::FromStr, sync::Arc, time::Duration as StdDuration};
use jiff::{tz::TimeZone, Zoned};
use jiff_cron::{Schedule, ScheduleIterator};
use sqlx::postgres::{PgAdvisoryLock, PgListener};
use tokio::time::Instant;
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use crate::{
queue::{shutdown_channel, try_acquire_advisory_lock, Error as QueueError},
Queue, Task,
};
pub(crate) type Result<T = ()> = std::result::Result<T, Error>;
/// Errors that can be returned while running a [`Scheduler`].
///
/// Every variant is transparent: it displays and sources exactly as the
/// wrapped error does.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An error from the queue module (`crate::queue::Error`).
#[error(transparent)]
Queue(#[from] QueueError),
/// An error reported by `sqlx`.
#[error(transparent)]
Database(#[from] sqlx::Error),
/// An error reported by `jiff`.
#[error(transparent)]
Jiff(#[from] jiff::Error),
/// An error reported by `jiff_cron`.
#[error(transparent)]
Cron(#[from] jiff_cron::error::Error),
}
/// Enqueues a task onto a queue at each occurrence of the queue's schedule.
///
/// See `run` for the loop itself and `shutdown` for stopping it.
pub struct Scheduler<T: Task> {
// Queue the task is enqueued onto; also supplies the connection pool and the
// stored schedule read by `run`.
queue: Arc<Queue<T>>,
// Advisory lock that `run` tries to acquire before doing any work.
queue_lock: PgAdvisoryLock,
// Task passed to `Queue::enqueue` on every scheduled occurrence.
task: Arc<T>,
// Cancelled to make a running `run` loop release its lock and stop.
shutdown_token: CancellationToken,
}
impl<T: Task> Scheduler<T> {
pub fn new(queue: Arc<Queue<T>>, task: T) -> Self {
let task = Arc::new(task);
let queue_lock = queue_scheduler_lock(&queue.name);
Self {
queue,
queue_lock,
task,
shutdown_token: CancellationToken::new(),
}
}
pub fn set_shutdown_token(&mut self, shutdown_token: CancellationToken) {
self.shutdown_token = shutdown_token;
}
pub fn shutdown(&self) {
self.shutdown_token.cancel();
}
#[instrument(skip(self), fields(queue.name = self.queue.name), err)]
pub async fn run(&self) -> Result {
let conn = self.queue.pool.acquire().await?;
let Some(guard) = try_acquire_advisory_lock(conn, &self.queue_lock).await? else {
tracing::trace!("Scheduler could not acquire lock, exiting");
return Ok(());
};
let Some((zoned_schedule, input)) = self.queue.task_schedule(&self.queue.pool).await?
else {
return Ok(());
};
let mut shutdown_listener = PgListener::connect_with(&self.queue.pool).await?;
let chan = shutdown_channel();
shutdown_listener.listen(chan).await?;
for next in zoned_schedule.iter() {
tracing::debug!(?next, "Waiting until next scheduled task enqueue");
tokio::select! {
notify_shutdown = shutdown_listener.recv() => {
match notify_shutdown {
Ok(_) => {
self.shutdown_token.cancel();
},
Err(err) => {
tracing::error!(%err, "Postgres shutdown notification error");
}
}
}
_ = self.shutdown_token.cancelled() => {
guard.release_now().await?;
break
}
_ = wait_until(&next) => {
self.process_next_schedule(&input).await?
}
}
}
Ok(())
}
#[instrument(skip_all, fields(task.id = tracing::field::Empty), err)]
async fn process_next_schedule(&self, input: &T::Input) -> Result {
let task_id = self
.queue
.enqueue(&self.queue.pool, &self.task, input)
.await?;
tracing::Span::current().record("task.id", task_id.as_hyphenated().to_string());
Ok(())
}
}
/// Builds the advisory lock for a queue's scheduler, keyed by
/// `"{queue_name}-scheduler"`.
fn queue_scheduler_lock(queue_name: &str) -> PgAdvisoryLock {
    let lock_key = format!("{queue_name}-scheduler");
    PgAdvisoryLock::new(lock_key)
}
/// Waits until the current time reaches `next`.
///
/// The current time is re-read after every sleep, so the function only
/// returns once `Zoned::now()` is at or past `next` (or no time remains).
async fn wait_until(next: &Zoned) {
    let zone = next.time_zone();

    loop {
        let current = Zoned::now().with_time_zone(zone.to_owned());

        if current < *next {
            let remaining = next.duration_until(&current).unsigned_abs();

            if remaining != StdDuration::ZERO {
                let deadline = Instant::now() + remaining;
                tokio::time::sleep_until(deadline).await;
                continue;
            }
        }

        return;
    }
}
/// A cron schedule paired with the time zone used to compute its upcoming
/// occurrences.
#[derive(Debug, PartialEq)]
pub struct ZonedSchedule {
// Parsed cron schedule.
schedule: Schedule,
// Time zone passed to `Schedule::upcoming` when iterating.
timezone: TimeZone,
}
impl ZonedSchedule {
pub fn new(cron_expr: &str, time_zone_name: &str) -> StdResult<Self, ZonedScheduleError> {
let schedule = cron_expr.parse()?;
let timezone = TimeZone::get(time_zone_name)?;
assert!(
timezone.iana_name().is_some(),
"Time zones must use IANA names for now"
);
Ok(Self { schedule, timezone })
}
pub(crate) fn cron_expr(&self) -> String {
self.schedule.to_string()
}
pub(crate) fn iana_name(&self) -> &str {
self.timezone
.iana_name()
.expect("iana_name should always be Some because new ensures valid time zone")
}
pub fn iter(&self) -> ZonedScheduleIterator {
ZonedScheduleIterator {
upcoming: self.schedule.upcoming(self.timezone.clone()),
}
}
}
/// Iterator over the upcoming occurrences of a [`ZonedSchedule`], as returned
/// by [`ZonedSchedule::iter`].
pub struct ZonedScheduleIterator<'a> {
// Underlying `jiff_cron` iterator that every call to `next` delegates to.
upcoming: ScheduleIterator<'a>,
}
impl Iterator for ZonedScheduleIterator<'_> {
    type Item = Zoned;

    /// Yields the next upcoming occurrence by delegating to the wrapped
    /// schedule iterator.
    fn next(&mut self) -> Option<Self::Item> {
        let Self { upcoming } = self;
        upcoming.next()
    }
}
/// Errors that can occur while constructing or parsing a [`ZonedSchedule`].
#[derive(Debug, thiserror::Error)]
pub enum ZonedScheduleError {
/// An error reported by `jiff`, e.g. while resolving the time zone name.
#[error(transparent)]
Jiff(#[from] jiff::Error),
/// An error reported by `jiff_cron`, e.g. while parsing the cron expression.
#[error(transparent)]
Cron(#[from] jiff_cron::error::Error),
/// The input string was malformed; the payload describes what was wrong.
#[error("Parsing error: {0}")]
Parse(String),
}
impl FromStr for ZonedSchedule {
    type Err = ZonedScheduleError;

    /// Parses the `<cron expression>[<time zone name>]` form.
    ///
    /// The string must end with `]`. Everything before the first `[` is the
    /// cron expression, and everything between that `[` and the final `]` is
    /// the time zone name.
    fn from_str(s: &str) -> StdResult<Self, Self::Err> {
        let Some(without_close) = s.strip_suffix(']') else {
            return Err(ZonedScheduleError::Parse("Missing closing ']'".to_string()));
        };

        let Some((cron_expr, time_zone_name)) = without_close.split_once('[') else {
            return Err(ZonedScheduleError::Parse("Missing opening '['".to_string()));
        };

        ZonedSchedule::new(cron_expr, time_zone_name)
    }
}
// Unit tests for `ZonedSchedule` construction/parsing and for `wait_until`.
#[cfg(test)]
mod tests {
use std::time::SystemTime;
use jiff::ToSpan;
use super::*;
// A well-formed cron expression with a known time zone name is accepted.
#[test]
fn zoned_schedule_creation_valid() {
let cron_expr = "0 0 * * * * *"; let time_zone_name = "UTC";
let schedule = ZonedSchedule::new(cron_expr, time_zone_name);
assert!(
schedule.is_ok(),
"Expected ZonedSchedule to be created successfully"
);
}
// A string that is not a cron expression is rejected.
#[test]
fn zoned_schedule_creation_invalid_cron() {
let cron_expr = "invalid cron";
let time_zone_name = "UTC";
let schedule = ZonedSchedule::new(cron_expr, time_zone_name);
assert!(
schedule.is_err(),
"Expected error due to invalid cron expression"
);
}
// An unknown time zone name is rejected.
#[test]
fn zoned_schedule_creation_invalid_time_zone() {
let cron_expr = "0 0 * * * * *";
let time_zone_name = "Invalid/TimeZone";
let schedule = ZonedSchedule::new(cron_expr, time_zone_name);
assert!(schedule.is_err(), "Expected error due to invalid time zone");
}
// The `<cron>[<time zone>]` string form parses through `FromStr`.
#[test]
fn zoned_schedule_parses() {
"0 0 * * * *[America/Los_Angeles]"
.parse::<ZonedSchedule>()
.expect("A schedule should be parsed");
}
// A target ten seconds in the past must return without sleeping.
#[tokio::test]
async fn wait_until_past_time() {
let tz = TimeZone::UTC;
let next = Zoned::now()
.with_time_zone(tz.to_owned())
.saturating_sub(10.seconds());
let start = SystemTime::now();
wait_until(&next).await;
let elapsed = start.elapsed().unwrap();
assert!(
elapsed < StdDuration::from_millis(10),
"Expected immediate return"
);
}
// A target five seconds ahead: Tokio's clock is paused and advanced by the
// wait, then completion is checked against the target time.
// NOTE(review): `wait_until` re-reads `Zoned::now()`, which presumably follows
// the system clock rather than Tokio's paused clock — confirm this test's
// real-world runtime is acceptable.
#[tokio::test]
async fn wait_until_future_time() {
let tz = TimeZone::UTC;
let next = Zoned::now()
.with_time_zone(tz.to_owned())
.saturating_add(5.seconds());
tokio::time::pause();
let handle = tokio::spawn({
let next = next.clone();
async move { wait_until(&next).await }
});
tokio::time::advance(StdDuration::from_secs(5)).await;
handle.await.expect("Failed to run wait_until");
// Time by which completion overshot the target; the conversion is unwrapped,
// so it is expected to succeed here.
let elapsed: StdDuration = (&Zoned::now().with_time_zone(tz.to_owned()) - &next)
.try_into()
.unwrap();
assert!(
elapsed < StdDuration::from_millis(10),
"Expected precise completion"
);
}
// A target equal to the current time must return without sleeping.
#[tokio::test]
async fn wait_until_exact_time() {
let tz = TimeZone::UTC;
let next = Zoned::now().with_time_zone(tz.to_owned());
let start = SystemTime::now();
wait_until(&next).await;
let elapsed = start.elapsed().unwrap();
assert!(
elapsed < StdDuration::from_millis(10),
"Expected immediate return"
);
}
}