use std::{num::NonZeroUsize, time::Duration};
use chrono::Utc;
use futures::{stream, Stream};
use graphile_worker_shutdown_signal::ShutdownSignal;
use sqlx::{postgres::PgListener, PgPool};
use tokio::sync::mpsc;
use tracing::error;
use crate::{
errors::Result,
sql::{get_job::get_job, task_identifiers::SharedTaskDetails},
Job,
};
/// The reason a job signal was emitted, i.e. what prompted a look for work.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StreamSource {
    /// The periodic poll interval ticked.
    Polling,
    /// A notification arrived on the PostgreSQL `jobs:insert` channel.
    PgListener,
    /// NOTE(review): not produced in this section of the file; presumably used by a
    /// run-once mode elsewhere — confirm.
    RunOnce,
    /// A message arrived on the internal job-signal channel.
    Internal,
}
/// Sending half of the internal job-signal channel. Each `()` sent on it is a
/// request to wake the stream returned by [`job_signal_stream_with_receiver`].
pub type JobSignalSender = mpsc::Sender<()>;
/// Receiving half of the internal job-signal channel, handed to
/// [`job_signal_stream_with_receiver`].
pub type JobSignalReceiver = mpsc::Receiver<()>;
/// State threaded through the `stream::unfold` that produces job signals.
struct JobSignalStreamData {
    /// Timer whose ticks produce `StreamSource::Polling` signals.
    interval: tokio::time::Interval,
    /// PostgreSQL listener whose notifications produce `StreamSource::PgListener` signals.
    pg_listener: PgListener,
    /// Once this resolves, the signal stream ends.
    shutdown_signal: ShutdownSignal,
    /// Worker concurrency; sizes the burst of repeated signals after each wake-up.
    concurrency: usize,
    /// Repeats still owed for the last wake-up: how many remain, and which source they carry.
    yield_n: Option<(NonZeroUsize, StreamSource)>,
    /// Optional internal channel whose messages produce `StreamSource::Internal` signals.
    internal_rx: Option<mpsc::Receiver<()>>,
}
impl JobSignalStreamData {
    /// Bundles the stream's inputs into fresh state with no repeated signals pending.
    fn new(
        interval: tokio::time::Interval,
        pg_listener: PgListener,
        shutdown_signal: ShutdownSignal,
        concurrency: usize,
        internal_rx: Option<mpsc::Receiver<()>>,
    ) -> Self {
        Self {
            yield_n: None,
            concurrency,
            interval,
            pg_listener,
            internal_rx,
            shutdown_signal,
        }
    }
}
/// Creates a stream of job signals driven by the poll interval and by
/// notifications on the PostgreSQL `jobs:insert` channel.
///
/// No internal signal channel is attached. The stream ends when
/// `shutdown_signal` resolves. Each wake-up is repeated in a burst sized from
/// `concurrency`.
///
/// # Errors
///
/// Returns an error if the PostgreSQL listener cannot be connected or
/// subscribed.
pub async fn job_signal_stream(
    pg_pool: PgPool,
    poll_interval: Duration,
    shutdown_signal: ShutdownSignal,
    concurrency: usize,
) -> Result<impl Stream<Item = StreamSource>> {
    let no_internal_rx = None;
    job_signal_stream_internal(
        pg_pool,
        poll_interval,
        shutdown_signal,
        concurrency,
        no_internal_rx,
    )
    .await
}
/// Like [`job_signal_stream`], but also wakes up whenever a message arrives on
/// `internal_rx`, yielding `StreamSource::Internal`.
///
/// Besides ending on `shutdown_signal`, the stream also ends once
/// `internal_rx` reports that its channel is closed.
///
/// # Errors
///
/// Returns an error if the PostgreSQL listener cannot be connected or
/// subscribed.
pub async fn job_signal_stream_with_receiver(
    pg_pool: PgPool,
    poll_interval: Duration,
    shutdown_signal: ShutdownSignal,
    concurrency: usize,
    internal_rx: JobSignalReceiver,
) -> Result<impl Stream<Item = StreamSource>> {
    let attached_rx = Some(internal_rx);
    job_signal_stream_internal(pg_pool, poll_interval, shutdown_signal, concurrency, attached_rx)
        .await
}
/// Waits for the next message on the optional internal signal channel.
///
/// When no receiver was supplied this future never resolves, so a `select!`
/// branch built on it simply never fires. Otherwise it forwards the result of
/// `Receiver::recv`: `Some(())` for a message, `None` once the channel is
/// closed.
async fn recv_internal_signal(internal_rx: &mut Option<mpsc::Receiver<()>>) -> Option<()> {
    match internal_rx {
        Some(rx) => rx.recv().await,
        None => std::future::pending().await,
    }
}

/// Shared implementation behind `job_signal_stream` and
/// `job_signal_stream_with_receiver`.
///
/// The stream waits for the first of the following. They are checked in this
/// order because the `select!` is `biased`:
/// 1. the poll interval ticks, giving `StreamSource::Polling`;
/// 2. a notification arrives on `jobs:insert`, giving `StreamSource::PgListener`;
/// 3. a message arrives on `internal_rx` (if supplied), giving `StreamSource::Internal`;
///    if that channel is closed the stream ends;
/// 4. `shutdown_signal` resolves, which ends the stream.
///
/// Each wake-up is emitted `concurrency` times in a row (at least once, so a
/// `concurrency` of 0 is tolerated instead of panicking). This lets up to
/// `concurrency` consumers start looking for work at the same time.
///
/// # Errors
///
/// Returns an error if the `PgListener` cannot connect or cannot `LISTEN` on
/// `jobs:insert`.
async fn job_signal_stream_internal(
    pg_pool: PgPool,
    poll_interval: Duration,
    shutdown_signal: ShutdownSignal,
    concurrency: usize,
    internal_rx: Option<mpsc::Receiver<()>>,
) -> Result<impl Stream<Item = StreamSource>> {
    let interval = tokio::time::interval(poll_interval);
    let mut pg_listener = PgListener::connect_with(&pg_pool).await?;
    pg_listener.listen("jobs:insert").await?;
    let stream_data = JobSignalStreamData::new(
        interval,
        pg_listener,
        shutdown_signal,
        concurrency,
        internal_rx,
    );
    let stream = stream::unfold(stream_data, |mut f| async {
        // Pay out any repeats still owed for the previous wake-up before waiting again.
        if let Some((n, source)) = f.yield_n.take() {
            f.yield_n = NonZeroUsize::new(n.get() - 1).map(|rest| (rest, source));
            return Some((source, f));
        }
        let source = tokio::select! {
            biased;
            _ = f.interval.tick() => StreamSource::Polling,
            // NOTE(review): the result of `recv` (including any error) is discarded, so a
            // failed receive still counts as a wake-up; confirm this cannot spin if the
            // connection stays down.
            _ = f.pg_listener.recv() => StreamSource::PgListener,
            res = recv_internal_signal(&mut f.internal_rx) => match res {
                Some(()) => StreamSource::Internal,
                // The internal channel is closed: end the stream.
                None => return None,
            },
            _ = &mut f.shutdown_signal => return None,
        };
        // This call emits `source` once; queue the remaining `concurrency - 1` copies.
        // `saturating_sub` + `NonZeroUsize::new` yields `None` for concurrency 0 or 1,
        // so no repeats are queued and nothing can panic.
        f.yield_n = NonZeroUsize::new(f.concurrency.saturating_sub(1)).map(|rest| (rest, source));
        Some((source, f))
    });
    Ok(stream)
}
/// Returns a stream of jobs fetched with `get_job` for `worker_id`.
///
/// Each item pulled from the stream triggers one `get_job` call. The stream
/// ends when any of the following happens:
/// - `get_job` returns `Ok(None)`;
/// - `get_job` returns an error, which is logged once;
/// - `shutdown_signal` resolves before the fetch completes.
///
/// When `use_local_time` is true, the current UTC time (`Utc::now`) is passed
/// to `get_job`; otherwise `None` is passed.
pub fn job_stream(
    pg_pool: PgPool,
    shutdown_signal: ShutdownSignal,
    task_details: SharedTaskDetails,
    escaped_schema: String,
    worker_id: String,
    forbidden_flags: Vec<String>,
    use_local_time: bool,
) -> impl Stream<Item = Job> {
    futures::stream::unfold((), move |()| {
        // Each fetch future must own its inputs (it cannot borrow from this `FnMut`
        // closure), so the captured values are cloned per item.
        let pg_pool = pg_pool.clone();
        let task_details = task_details.clone();
        let escaped_schema = escaped_schema.clone();
        let worker_id = worker_id.clone();
        let forbidden_flags = forbidden_flags.clone();
        let job_fut = async move {
            let now = use_local_time.then(Utc::now);
            let task_details_guard = task_details.read().await;
            let job = get_job(
                &pg_pool,
                &task_details_guard,
                &escaped_schema,
                &worker_id,
                &forbidden_flags,
                now,
            )
            .await;
            match job {
                Ok(Some(job)) => Some((job, ())),
                Ok(None) => None,
                Err(e) => {
                    // Log the failure exactly once, then end the stream.
                    error!("Could not get job : {:?}", e);
                    None
                }
            }
        };
        let shutdown_fut = shutdown_signal.clone();
        async move {
            tokio::select! {
                res = job_fut => res,
                _ = shutdown_fut => None
            }
        }
    })
}