use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use sha2::{Digest as _, Sha256};
use crate::config::{SchedulerBackend, SchedulerConfig};
use crate::state::AppState;
use crate::task::TaskCoordination;
use crate::{AutumnError, AutumnResult};
/// Boxed future returned by [`SchedulerCoordinator`] methods, keeping the
/// trait object-safe without `async fn` in the trait definition.
pub type SchedulerFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Strategy for deciding which replica runs a given scheduled-task tick.
pub trait SchedulerCoordinator: Send + Sync {
    /// Static name of this coordination backend (e.g. `"in_process"`, `"postgres"`).
    fn backend(&self) -> &'static str;
    /// Identifier of the current replica; used as the lease's leader id.
    fn replica_id(&self) -> &str;
    /// Attempts to acquire the lease for one tick of `task_name`.
    ///
    /// `tick_key` identifies the specific tick (see `fixed_delay_tick_key` /
    /// `cron_tick_key`). Returns `Ok(Some(lease))` when this replica should
    /// run the tick, `Ok(None)` when another replica holds it, and `Err`
    /// on backend failure.
    fn try_acquire<'a>(
        &'a self,
        task_name: &'a str,
        tick_key: &'a str,
        coordination: TaskCoordination,
    ) -> SchedulerFuture<'a, AutumnResult<Option<SchedulerLease>>>;
}
/// Proof that the holder may run one scheduled tick.
///
/// Call [`SchedulerLease::release`] when the tick finishes; dropping the
/// lease skips backend-specific cleanup (e.g. the Postgres advisory unlock).
pub struct SchedulerLease {
    // Name of the backend that granted this lease (for logging/metrics).
    backend: String,
    // Replica id of the lease holder.
    leader_id: String,
    // Test-only counter bumped by `release()` so tests can observe cleanup.
    #[cfg(test)]
    release_count: Option<Arc<std::sync::atomic::AtomicUsize>>,
    // Present when the lease is backed by a Postgres advisory lock.
    #[cfg(feature = "db")]
    postgres: Option<PostgresAdvisoryLease>,
}
impl SchedulerLease {
    /// Builds a lease that needs no backend cleanup on release.
    pub(crate) fn local(backend: impl Into<String>, leader_id: impl Into<String>) -> Self {
        Self {
            backend: backend.into(),
            leader_id: leader_id.into(),
            #[cfg(test)]
            release_count: None,
            #[cfg(feature = "db")]
            postgres: None,
        }
    }

    /// Test-only constructor: `release_count` is incremented every time
    /// `release` runs, so tests can assert cleanup happened.
    #[cfg(test)]
    pub(crate) fn tracked(
        backend: impl Into<String>,
        leader_id: impl Into<String>,
        release_count: Arc<std::sync::atomic::AtomicUsize>,
    ) -> Self {
        Self {
            backend: backend.into(),
            leader_id: leader_id.into(),
            release_count: Some(release_count),
            #[cfg(feature = "db")]
            postgres: None,
        }
    }

    /// Wraps an acquired Postgres advisory lock; `release` will unlock it.
    #[cfg(feature = "db")]
    fn postgres(leader_id: impl Into<String>, lease: PostgresAdvisoryLease) -> Self {
        Self {
            backend: "postgres".to_owned(),
            leader_id: leader_id.into(),
            #[cfg(test)]
            release_count: None,
            postgres: Some(lease),
        }
    }

    /// Name of the backend that granted this lease.
    #[must_use]
    pub fn backend(&self) -> &str {
        &self.backend
    }

    /// Replica id of the lease holder.
    #[must_use]
    pub fn leader_id(&self) -> &str {
        &self.leader_id
    }

    /// Performs backend-specific cleanup (currently: releasing the Postgres
    /// advisory lock, when present). Prefer calling this over dropping the
    /// lease so the lock is returned promptly.
    pub async fn release(self) -> AutumnResult<()> {
        #[cfg(test)]
        if let Some(release_count) = self.release_count {
            release_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        }
        #[cfg(feature = "db")]
        if let Some(lease) = self.postgres {
            return lease.release().await;
        }
        Ok(())
    }
}
/// Coordinator for single-process deployments: every tick is granted locally,
/// with no cross-replica arbitration.
#[derive(Debug, Clone)]
pub struct InProcessSchedulerCoordinator {
    // Reported as the leader id on every lease this coordinator grants.
    replica_id: String,
}
impl InProcessSchedulerCoordinator {
    /// Creates a coordinator that always grants leases to this replica.
    #[must_use]
    pub fn new(replica_id: impl Into<String>) -> Self {
        let replica_id = replica_id.into();
        Self { replica_id }
    }
}
impl SchedulerCoordinator for InProcessSchedulerCoordinator {
    fn backend(&self) -> &'static str {
        "in_process"
    }

    fn replica_id(&self) -> &str {
        &self.replica_id
    }

    /// Always grants the lease: a single process has no peers to contend with.
    fn try_acquire<'a>(
        &'a self,
        _task_name: &'a str,
        _tick_key: &'a str,
        coordination: TaskCoordination,
    ) -> SchedulerFuture<'a, AutumnResult<Option<SchedulerLease>>> {
        Box::pin(async move {
            // Per-replica grants carry a distinct backend label so callers can
            // distinguish them from fleet-wide grants.
            let backend = if coordination == TaskCoordination::PerReplica {
                "per_replica"
            } else {
                "in_process"
            };
            let lease = SchedulerLease::local(backend, self.replica_id.clone());
            Ok(Some(lease))
        })
    }
}
#[cfg(feature = "db")]
/// Coordinator that arbitrates fleet-wide ticks with Postgres advisory locks.
#[derive(Clone)]
pub struct PostgresAdvisorySchedulerCoordinator {
    // Pool that lock-holding connections are checked out of.
    pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
    // This replica's identifier, reported as the lease's leader id.
    replica_id: String,
    // Mixed into every lock key hash — presumably to isolate multiple
    // deployments sharing one database; confirm against config docs.
    key_prefix: String,
}
#[cfg(feature = "db")]
impl PostgresAdvisorySchedulerCoordinator {
    /// Builds a coordinator backed by Postgres advisory locks taken on
    /// connections from `pool`.
    #[must_use]
    pub fn new(
        pool: diesel_async::pooled_connection::deadpool::Pool<diesel_async::AsyncPgConnection>,
        replica_id: impl Into<String>,
        key_prefix: impl Into<String>,
    ) -> Self {
        let replica_id = replica_id.into();
        let key_prefix = key_prefix.into();
        Self {
            pool,
            replica_id,
            key_prefix,
        }
    }
}
#[cfg(feature = "db")]
impl SchedulerCoordinator for PostgresAdvisorySchedulerCoordinator {
    fn backend(&self) -> &'static str {
        "postgres"
    }

    fn replica_id(&self) -> &str {
        &self.replica_id
    }

    /// Fleet-wide ticks race for a Postgres advisory lock; per-replica ticks
    /// are granted locally without touching the database.
    fn try_acquire<'a>(
        &'a self,
        task_name: &'a str,
        tick_key: &'a str,
        coordination: TaskCoordination,
    ) -> SchedulerFuture<'a, AutumnResult<Option<SchedulerLease>>> {
        Box::pin(async move {
            // Per-replica tasks have no cross-replica contention to arbitrate.
            if matches!(coordination, TaskCoordination::PerReplica) {
                let lease = SchedulerLease::local("per_replica", self.replica_id.clone());
                return Ok(Some(lease));
            }
            let key = advisory_lock_key(&self.key_prefix, task_name, tick_key);
            let mut conn = match self.pool.get().await {
                Ok(conn) => conn,
                Err(error) => {
                    return Err(AutumnError::service_unavailable_msg(format!(
                        "scheduler postgres lock connection unavailable: {error}"
                    )));
                }
            };
            if !try_pg_advisory_lock(&mut conn, key).await? {
                // Another replica already owns this tick.
                return Ok(None);
            }
            // Keep the connection checked out: the advisory lock lives on it.
            let lock = PostgresAdvisoryLease {
                conn: Some(conn),
                key,
            };
            Ok(Some(SchedulerLease::postgres(self.replica_id.clone(), lock)))
        })
    }
}
#[cfg(feature = "db")]
/// An acquired Postgres advisory lock plus the pooled connection that holds it.
///
/// The connection must stay checked out for as long as the lock is held,
/// since `pg_try_advisory_lock` grants are tied to the session.
struct PostgresAdvisoryLease {
    // `None` after release, making release idempotent.
    conn:
        Option<diesel_async::pooled_connection::deadpool::Object<diesel_async::AsyncPgConnection>>,
    // 64-bit advisory lock key produced by `advisory_lock_key`.
    key: i64,
}
#[cfg(feature = "db")]
impl PostgresAdvisoryLease {
    /// Releases the session-level advisory lock and lets the connection
    /// return to the pool.
    ///
    /// # Errors
    /// Propagates failures from the `pg_advisory_unlock` query.
    async fn release(mut self) -> AutumnResult<()> {
        // `take` marks the lease as released, which also keeps `Drop` quiet.
        let Some(mut conn) = self.conn.take() else {
            return Ok(());
        };
        let released = unlock_pg_advisory_lock(&mut conn, self.key).await?;
        if !released {
            tracing::warn!(
                lock_key = self.key,
                "Postgres advisory scheduler lock was already released"
            );
        }
        Ok(())
    }
}
#[cfg(feature = "db")]
impl Drop for PostgresAdvisoryLease {
    fn drop(&mut self) {
        // BUGFIX: session-level advisory locks survive the connection being
        // returned to the pool, so a lease dropped without `release()` leaves
        // the lock held indefinitely and silently blocks every future
        // acquisition of this key. We cannot run the async unlock here, but
        // we can make the leak loud instead of invisible.
        if self.conn.is_some() {
            tracing::warn!(
                lock_key = self.key,
                "PostgresAdvisoryLease dropped without release(); advisory lock remains held by the pooled connection"
            );
        }
    }
}
/// Builds the scheduler coordinator selected by `config.backend`.
///
/// # Errors
/// Returns a service-unavailable error when the Postgres backend is selected
/// but no database pool is configured, or when the `db` feature is compiled
/// out entirely.
pub fn coordinator_from_config(
    config: &SchedulerConfig,
    state: &AppState,
) -> AutumnResult<Arc<dyn SchedulerCoordinator>> {
    let replica_id = config.resolved_replica_id();
    match config.backend {
        SchedulerBackend::InProcess => Ok(Arc::new(InProcessSchedulerCoordinator::new(replica_id))),
        SchedulerBackend::Postgres => {
            #[cfg(feature = "db")]
            {
                let pool = state.pool().cloned().ok_or_else(|| {
                    AutumnError::service_unavailable_msg(
                        "scheduler.backend = \"postgres\" requires a configured database pool",
                    )
                })?;
                Ok(Arc::new(PostgresAdvisorySchedulerCoordinator::new(
                    pool,
                    replica_id,
                    config.key_prefix.clone(),
                )))
            }
            #[cfg(not(feature = "db"))]
            {
                // Keep `state` "used" when the db feature is compiled out.
                let _ = state;
                Err(AutumnError::service_unavailable_msg(
                    "scheduler.backend = \"postgres\" requires the autumn-web db feature",
                ))
            }
        }
    }
}
/// Derives the coordination key for a fixed-delay tick.
///
/// Wall-clock time is bucketed by the delay interval, so every replica that
/// observes the same bucket produces the same key and contends for the same
/// lease. A zero delay is treated as a 1 ns interval to keep the division
/// defined (the bucket then equals the elapsed nanoseconds).
#[must_use]
pub fn fixed_delay_tick_key(task_name: &str, delay: Duration, unix_elapsed: Duration) -> String {
    let bucket = match delay.as_nanos() {
        // Zero delay: fall back to a 1 ns interval.
        0 => unix_elapsed.as_nanos(),
        interval => unix_elapsed.as_nanos() / interval,
    };
    format!("{task_name}:{bucket}")
}
/// Derives the coordination key for one cron firing: the task name joined
/// with the scheduled unix second, so every replica contends for the same tick.
#[must_use]
pub fn cron_tick_key(task_name: &str, unix_secs: u64) -> String {
    // name + ':' + up to 20 decimal digits of a u64.
    let mut key = String::with_capacity(task_name.len() + 21);
    key.push_str(task_name);
    key.push(':');
    key.push_str(&unix_secs.to_string());
    key
}
/// Maps `(key_prefix, task_name, tick_key)` onto the signed 64-bit keyspace
/// used by Postgres advisory locks.
///
/// The components are fed through SHA-256 with NUL separators so ambiguous
/// concatenations (e.g. `("ab", "c")` vs `("a", "bc")`) cannot collide; the
/// first 8 digest bytes, read big-endian, become the lock key.
#[must_use]
pub fn advisory_lock_key(key_prefix: &str, task_name: &str, tick_key: &str) -> i64 {
    let mut hasher = Sha256::new();
    hasher.update(key_prefix.as_bytes());
    hasher.update(b"\0");
    hasher.update(task_name.as_bytes());
    hasher.update(b"\0");
    hasher.update(tick_key.as_bytes());
    let digest = hasher.finalize();
    let head: [u8; 8] = digest[..8]
        .try_into()
        .expect("SHA-256 digest is 32 bytes");
    i64::from_be_bytes(head)
}
/// Whole seconds elapsed since the unix epoch (0 if the clock reads pre-epoch).
#[must_use]
pub fn now_unix_secs() -> u64 {
    let since_epoch = now_unix_duration();
    since_epoch.as_secs()
}
/// Time elapsed since the unix epoch, saturating to zero when the system
/// clock is set before 1970 instead of panicking.
#[must_use]
pub fn now_unix_duration() -> Duration {
    match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        // Pre-epoch clock: treat as zero.
        Err(_) => Duration::ZERO,
    }
}
#[cfg(feature = "db")]
/// Row shape for the `pg_try_advisory_lock` query result.
#[derive(diesel::QueryableByName)]
struct AdvisoryLockRow {
    #[diesel(sql_type = diesel::sql_types::Bool)]
    acquired: bool,
}
#[cfg(feature = "db")]
/// Runs `SELECT pg_try_advisory_lock($1)` on the pooled connection.
///
/// Non-blocking: `Ok(true)` means this session now holds the lock,
/// `Ok(false)` means another session already holds it. Query failures are
/// surfaced as internal-server errors.
async fn try_pg_advisory_lock(
    conn: &mut diesel_async::pooled_connection::deadpool::Object<diesel_async::AsyncPgConnection>,
    key: i64,
) -> AutumnResult<bool> {
    use diesel_async::RunQueryDsl as _;
    let row = diesel::sql_query("SELECT pg_try_advisory_lock($1) AS acquired")
        .bind::<diesel::sql_types::BigInt, _>(key)
        .get_result::<AdvisoryLockRow>(&mut **conn)
        .await
        .map_err(|error| AutumnError::internal_server_error_msg(error.to_string()))?;
    Ok(row.acquired)
}
#[cfg(feature = "db")]
/// Row shape for the `pg_advisory_unlock` query result.
#[derive(diesel::QueryableByName)]
struct AdvisoryUnlockRow {
    #[diesel(sql_type = diesel::sql_types::Bool)]
    released: bool,
}
#[cfg(feature = "db")]
/// Runs `SELECT pg_advisory_unlock($1)` on the pooled connection.
///
/// `Ok(false)` means the session did not actually hold the lock (Postgres
/// reports this rather than erroring). Query failures are surfaced as
/// internal-server errors.
async fn unlock_pg_advisory_lock(
    conn: &mut diesel_async::pooled_connection::deadpool::Object<diesel_async::AsyncPgConnection>,
    key: i64,
) -> AutumnResult<bool> {
    use diesel_async::RunQueryDsl as _;
    let row = diesel::sql_query("SELECT pg_advisory_unlock($1) AS released")
        .bind::<diesel::sql_types::BigInt, _>(key)
        .get_result::<AdvisoryUnlockRow>(&mut **conn)
        .await
        .map_err(|error| AutumnError::internal_server_error_msg(error.to_string()))?;
    Ok(row.released)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn cron_tick_key_uses_task_name_and_second() {
        assert_eq!(cron_tick_key("digest", 1_700_000_000), "digest:1700000000");
    }

    #[test]
    fn fixed_delay_tick_key_buckets_by_interval() {
        // 25s elapsed with a 10s interval lands in bucket 2.
        assert_eq!(
            fixed_delay_tick_key("digest", Duration::from_secs(10), Duration::from_secs(25)),
            "digest:2"
        );
        // Zero delay must not divide by zero: interval clamps to 1 ns.
        assert_eq!(
            fixed_delay_tick_key("t", Duration::ZERO, Duration::from_nanos(7)),
            "t:7"
        );
    }

    #[test]
    fn advisory_lock_key_is_deterministic_and_prefix_sensitive() {
        let key = advisory_lock_key("autumn", "digest", "digest:1");
        // Same inputs always hash to the same lock key.
        assert_eq!(key, advisory_lock_key("autumn", "digest", "digest:1"));
        // Different prefixes must land on different keys.
        assert_ne!(key, advisory_lock_key("other", "digest", "digest:1"));
        // NUL separators keep ambiguous concatenations apart.
        assert_ne!(
            advisory_lock_key("p", "ab", "c"),
            advisory_lock_key("p", "a", "bc")
        );
    }
}