use async_trait::async_trait;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Notify;
use tracing::{debug, error, info, warn};
/// Decides *when* a worker should look for new work.
///
/// Neither method returns a task: implementations only signal that it is time
/// to check, and fetching the work itself is left to the caller.
#[async_trait]
pub trait WorkDistributor: Send + Sync {
/// Waits until the caller should re-check for work.
///
/// A return does not guarantee that work exists. The implementations in this
/// module also return after a timeout and when interrupted by `shutdown`.
async fn wait_for_work(&self);
/// Requests shutdown and wakes callers currently blocked in `wait_for_work`.
fn shutdown(&self);
}
/// Work distributor driven by PostgreSQL `LISTEN`/`NOTIFY` on the `task_ready`
/// channel, with a periodic fallback wake-up.
#[cfg(feature = "postgres")]
pub struct PostgresDistributor {
/// Connection string given to `new`. It is stored but not currently read,
/// hence the `dead_code` allowance.
#[allow(dead_code)]
database_url: String,
/// Wakes callers blocked in `wait_for_work`; signalled by the listener task and by `shutdown`.
notify: Arc<Notify>,
/// Set once shutdown is requested; checked by the listener and connection tasks.
shutdown: Arc<std::sync::atomic::AtomicBool>,
/// Listener task spawned by `new`; taken and aborted on drop.
listener_handle: Option<tokio::task::JoinHandle<()>>,
}
#[cfg(feature = "postgres")]
impl PostgresDistributor {
    /// Longest a worker waits without a `NOTIFY` before re-checking for work.
    ///
    /// Also used as the listener's receive timeout. It bounds how long a task can
    /// go unnoticed if its notification is missed, e.g. after the listener
    /// connection has failed.
    const POLL_FALLBACK: Duration = Duration::from_secs(30);

    /// Connects to `database_url`, subscribes to the `task_ready` channel and
    /// starts the background listener.
    ///
    /// # Errors
    ///
    /// Returns an error if the connection cannot be established or `LISTEN` fails.
    pub async fn new(
        database_url: &str,
    ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        let notify = Arc::new(Notify::new());
        let shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let listener_handle =
            Self::spawn_listener(database_url.to_string(), notify.clone(), shutdown.clone())
                .await?;
        Ok(Self {
            database_url: database_url.to_string(),
            notify,
            shutdown,
            listener_handle: Some(listener_handle),
        })
    }

    /// Wakes every caller currently blocked in `wait_for_work`. If nobody is
    /// waiting, it leaves a single permit for the next caller.
    ///
    /// `notify_waiters` alone stores nothing. Without the permit, a `NOTIFY`
    /// arriving while all workers are busy would be lost until the fallback
    /// timeout. The permit may cause one spurious wake-up, which is harmless
    /// because a return from `wait_for_work` only means "re-check for work".
    fn wake_workers(notify: &Notify) {
        notify.notify_waiters();
        notify.notify_one();
    }

    /// Opens a dedicated connection, issues `LISTEN task_ready`, and spawns the
    /// tasks that turn notifications into worker wake-ups.
    ///
    /// Two tasks are spawned:
    /// - One drives the `tokio_postgres` connection and forwards its async
    ///   messages over a channel.
    /// - The other (returned) owns the client and wakes workers on each
    ///   notification.
    ///
    /// Both stop once `shutdown` is set (checked between messages) or the
    /// connection ends.
    async fn spawn_listener(
        database_url: String,
        notify: Arc<Notify>,
        shutdown: Arc<std::sync::atomic::AtomicBool>,
    ) -> Result<tokio::task::JoinHandle<()>, Box<dyn std::error::Error + Send + Sync>> {
        use futures::StreamExt;
        use tokio::sync::mpsc;

        let (client, mut connection) =
            tokio_postgres::connect(&database_url, tokio_postgres::NoTls).await?;
        let (tx, mut rx) = mpsc::unbounded_channel();

        // Drive the connection and forward its async messages (notifications,
        // notices) to the listener. It must already be running when commands are
        // issued on `client`, otherwise the `LISTEN` below could not complete.
        let conn_shutdown = shutdown.clone();
        tokio::spawn(async move {
            let stream = futures::stream::poll_fn(move |cx| connection.poll_message(cx));
            futures::pin_mut!(stream);
            while !conn_shutdown.load(std::sync::atomic::Ordering::SeqCst) {
                match stream.next().await {
                    Some(Ok(msg)) => {
                        // The listener task has ended; nobody is left to forward to.
                        if tx.send(msg).is_err() {
                            break;
                        }
                    }
                    Some(Err(e)) => {
                        error!("PostgreSQL listener connection error: {}", e);
                        break;
                    }
                    // The connection has closed.
                    None => break,
                }
            }
        });

        // `LISTEN` takes no parameters, so use the simple-query protocol instead of
        // preparing a statement for it.
        client.batch_execute("LISTEN task_ready").await?;
        info!("PostgreSQL LISTEN/NOTIFY listener started on channel 'task_ready'");

        let handle = tokio::spawn(async move {
            // Own the client for as long as the listener runs so the LISTEN
            // session stays open.
            let _client = client;
            loop {
                if shutdown.load(std::sync::atomic::Ordering::SeqCst) {
                    debug!("PostgreSQL listener shutting down");
                    break;
                }
                match tokio::time::timeout(Self::POLL_FALLBACK, rx.recv()).await {
                    Ok(Some(tokio_postgres::AsyncMessage::Notification(notification))) => {
                        debug!(
                            "Received NOTIFY on channel '{}': {}",
                            notification.channel(),
                            notification.payload()
                        );
                        Self::wake_workers(&notify);
                    }
                    // Notices and any other async messages carry no work signal.
                    Ok(Some(_)) => {}
                    Ok(None) => {
                        warn!("PostgreSQL listener channel closed");
                        break;
                    }
                    Err(_) => {
                        // `wait_for_work` has its own fallback sleep, so this only
                        // nudges current waiters and leaves no permit behind.
                        debug!("LISTEN timeout, triggering fallback poll");
                        notify.notify_waiters();
                    }
                }
            }
        });
        Ok(handle)
    }
}
#[cfg(feature = "postgres")]
#[async_trait]
impl WorkDistributor for PostgresDistributor {
    /// Waits for a `NOTIFY`-driven wake-up, the fallback timeout, or shutdown.
    ///
    /// Returns immediately if `shutdown` has already been called, matching the
    /// SQLite distributor.
    async fn wait_for_work(&self) {
        // Create the `Notified` future before reading the flag. It observes any
        // `notify_waiters` call made after its creation, so a concurrent
        // `shutdown()` cannot slip in between the check and the wait.
        let notified = self.notify.notified();
        if self.shutdown.load(std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        tokio::select! {
            _ = notified => {
                debug!("Woke from NOTIFY signal");
            }
            _ = tokio::time::sleep(Self::POLL_FALLBACK) => {
                debug!("Woke from fallback poll timeout");
            }
        }
    }

    /// Marks the distributor as shut down and wakes all current waiters.
    ///
    /// The listener task notices the flag on its next loop iteration.
    fn shutdown(&self) {
        self.shutdown
            .store(true, std::sync::atomic::Ordering::SeqCst);
        self.notify.notify_waiters();
    }
}
#[cfg(feature = "postgres")]
impl Drop for PostgresDistributor {
    /// Requests shutdown, then aborts the listener task so it does not outlive
    /// the distributor. Otherwise it could sit in its 30 s receive timeout
    /// before noticing the flag.
    fn drop(&mut self) {
        // Fully qualified so it is clear this is the trait method, not the
        // `shutdown` field.
        <Self as WorkDistributor>::shutdown(self);
        if let Some(listener) = self.listener_handle.take() {
            listener.abort();
        }
    }
}
/// Work distributor that wakes workers on a fixed polling interval rather than
/// on database notifications.
#[cfg(feature = "sqlite")]
pub struct SqliteDistributor {
/// How long `wait_for_work` sleeps before returning.
poll_interval: Duration,
/// Set by `shutdown`; makes `wait_for_work` return immediately.
shutdown: Arc<std::sync::atomic::AtomicBool>,
/// Signalled by `shutdown` to interrupt waits already in progress.
notify: Arc<Notify>,
}
#[cfg(feature = "sqlite")]
impl SqliteDistributor {
    /// Poll cadence used when no explicit interval is given.
    const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(500);

    /// Builds a distributor that re-checks for work every 500 ms.
    pub fn new() -> Self {
        Self::with_poll_interval(Self::DEFAULT_POLL_INTERVAL)
    }

    /// Builds a distributor that re-checks for work every `poll_interval`.
    ///
    /// The distributor starts in the running (not shut down) state.
    pub fn with_poll_interval(poll_interval: Duration) -> Self {
        let notify = Arc::new(Notify::new());
        // `AtomicBool::default()` is `false`: not shut down.
        let shutdown = Arc::new(std::sync::atomic::AtomicBool::default());
        Self {
            poll_interval,
            shutdown,
            notify,
        }
    }
}
#[cfg(feature = "sqlite")]
impl Default for SqliteDistributor {
    /// Same as `SqliteDistributor::new`: polls at the default interval.
    fn default() -> Self {
        Self::with_poll_interval(Self::DEFAULT_POLL_INTERVAL)
    }
}
#[cfg(feature = "sqlite")]
#[async_trait]
impl WorkDistributor for SqliteDistributor {
    /// Sleeps for one poll interval, returning early if shutdown is requested.
    ///
    /// Returns immediately if `shutdown` has already been called.
    async fn wait_for_work(&self) {
        // Create the `Notified` future before reading the flag. It observes any
        // `notify_waiters` call made after its creation, so a `shutdown()` landing
        // between the check and the wait still interrupts it. Otherwise the
        // caller would sleep a full poll interval.
        let notified = self.notify.notified();
        if self.shutdown.load(std::sync::atomic::Ordering::SeqCst) {
            return;
        }
        tokio::select! {
            _ = tokio::time::sleep(self.poll_interval) => {
                debug!("SQLite poll interval elapsed");
            }
            _ = notified => {
                debug!("SQLite distributor shutdown signal received");
            }
        }
    }

    /// Marks the distributor as shut down and wakes all current waiters.
    fn shutdown(&self) {
        self.shutdown
            .store(true, std::sync::atomic::Ordering::SeqCst);
        self.notify.notify_waiters();
    }
}
/// Builds the [`WorkDistributor`] matching `database`'s backend.
///
/// Only the SQLite distributor can be built here. The PostgreSQL distributor
/// needs a connection URL for its dedicated `LISTEN` connection, so callers must
/// use `PostgresDistributor::new` instead.
///
/// # Errors
///
/// Returns an error for the PostgreSQL backend and for any backend not handled
/// by the compiled-in features.
pub async fn create_work_distributor(
    database: &crate::Database,
) -> Result<Box<dyn WorkDistributor>, Box<dyn std::error::Error + Send + Sync>> {
    let backend = database.backend();
    match backend {
        #[cfg(feature = "sqlite")]
        crate::database::BackendType::Sqlite => {
            let distributor: Box<dyn WorkDistributor> = Box::new(SqliteDistributor::default());
            Ok(distributor)
        }
        #[cfg(feature = "postgres")]
        crate::database::BackendType::Postgres => Err(
            "PostgreSQL distributor requires database URL. Use PostgresDistributor::new() directly."
                .into(),
        ),
        // Catch-all for variants whose feature is disabled. The `allow` covers
        // builds where the arms above already cover every variant.
        #[allow(unreachable_patterns)]
        _ => Err("Unsupported database backend".into()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without shutdown, `wait_for_work` returns after roughly one poll interval.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_distributor_poll_interval() {
        let distributor = SqliteDistributor::with_poll_interval(Duration::from_millis(50));
        let start = std::time::Instant::now();
        distributor.wait_for_work().await;
        let elapsed = start.elapsed();
        assert!(
            elapsed >= Duration::from_millis(40),
            "returned too early: {:?}",
            elapsed
        );
        // Generous upper bound: still catches a wrong interval, but does not flake
        // when the scheduler is slow on a loaded CI machine.
        assert!(
            elapsed < Duration::from_millis(500),
            "returned too late: {:?}",
            elapsed
        );
    }

    /// A `shutdown()` issued while a caller is waiting interrupts the wait.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_distributor_shutdown() {
        let distributor = Arc::new(SqliteDistributor::with_poll_interval(Duration::from_secs(60)));
        let start = std::time::Instant::now();
        let stopper = Arc::clone(&distributor);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            stopper.shutdown();
        });
        distributor.wait_for_work().await;
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_secs(1), "wait not interrupted: {:?}", elapsed);
    }

    /// Once shut down, `wait_for_work` returns without sleeping for the interval.
    #[cfg(feature = "sqlite")]
    #[tokio::test]
    async fn test_sqlite_distributor_wait_after_shutdown_returns_immediately() {
        let distributor = SqliteDistributor::with_poll_interval(Duration::from_secs(60));
        distributor.shutdown();
        let start = std::time::Instant::now();
        distributor.wait_for_work().await;
        let elapsed = start.elapsed();
        assert!(elapsed < Duration::from_secs(1), "waited after shutdown: {:?}", elapsed);
    }
}