mod embedded;
mod future;
mod schema;

use future::{Metrics as _, Timeout as _};

use std::{
    collections::{BTreeSet, VecDeque},
    error::Error,
    future::Future,
    ops::Deref,
    sync::Arc,
    time::Duration,
};

use background_jobs_core::{Backoff, JobInfo, JobResult, MaxRetries, NewJobInfo, ReturnJobInfo};
use dashmap::DashMap;
use diesel::{
    data_types::PgInterval,
    dsl::IntervalDsl,
    prelude::*,
    sql_types::{Interval, Timestamp},
};
use diesel_async::{
    pooled_connection::{
        deadpool::{BuildError, Hook, Pool, PoolError},
        AsyncDieselConnectionManager, ManagerConfig,
    },
    AsyncPgConnection, RunQueryDsl,
};
use futures_core::future::BoxFuture;
use serde_json::Value;
use time::PrimitiveDateTime;
use tokio::{sync::Notify, task::JoinHandle};
use tokio_postgres::{tls::NoTlsStream, AsyncMessage, Connection, NoTls, Notification, Socket};
use tracing::Instrument;
use url::Url;
use uuid::Uuid;
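// Shape of the custom connection-setup callback handed to the deadpool manager as its
// `custom_setup` hook; every new pooled connection is opened through a function of this type.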
type ConfigFn =
    Box<dyn Fn(&str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> + Send + Sync + 'static>;
/// A postgres-backed job store implementing the `background_jobs_core::Storage` trait.
#[derive(Clone)]
pub struct Storage {
    inner: Arc<Inner>,
    #[allow(dead_code)]
    drop_handle: Arc<DropHandle<()>>,
}

/// Shared state: the connection pool plus one `Notify` per queue name.
struct Inner {
    pool: Pool<AsyncPgConnection>,
    queue_notifications: DashMap<String, Arc<Notify>>,
}

/// Wrapper around a `JoinHandle` that aborts the task when dropped.
struct DropHandle<T> {
    handle: JoinHandle<T>,
}
fn spawn<F: Future + Send + 'static>(
    name: &str,
    future: F,
) -> std::io::Result<DropHandle<F::Output>>
where
    F::Output: Send,
{
    Ok(DropHandle {
        handle: spawn_detach(name, future)?,
    })
}

#[cfg(tokio_unstable)]
fn spawn_detach<F: Future + Send + 'static>(
    name: &str,
    future: F,
) -> std::io::Result<JoinHandle<F::Output>>
where
    F::Output: Send,
{
    tokio::task::Builder::new().name(name).spawn(future)
}

#[cfg(not(tokio_unstable))]
fn spawn_detach<F: Future + Send + 'static>(
    name: &str,
    future: F,
) -> std::io::Result<JoinHandle<F::Output>>
where
    F::Output: Send,
{
    let _ = name;
    Ok(tokio::spawn(future))
}
/// Errors that can occur while connecting to postgres and preparing the job store.
#[derive(Debug)]
pub enum ConnectPostgresError {
    /// Failed to connect to postgres to run migrations
    ConnectForMigration(tokio_postgres::Error),

    /// Failed to run migrations
    Migration(refinery::Error),

    /// Failed to build the postgres connection pool
    BuildPool(BuildError),

    /// Failed to spawn a background task
    SpawnTask(std::io::Error),
}

/// Errors that can occur while interacting with the job queue.
#[derive(Debug)]
pub enum PostgresError {
    /// Failed to check out a connection from the pool
    Pool(PoolError),

    /// A database query failed
    Diesel(diesel::result::Error),

    /// The database did not respond in time
    DbTimeout,
}
struct JobNotifierState<'a> {
    inner: &'a Inner,
    capacity: usize,
    jobs: BTreeSet<Uuid>,
    jobs_ordered: VecDeque<Uuid>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, diesel_derive_enum::DbEnum)]
#[ExistingTypePath = "crate::schema::sql_types::JobStatus"]
enum JobStatus {
    New,
    Running,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, diesel_derive_enum::DbEnum)]
#[ExistingTypePath = "crate::schema::sql_types::BackoffStrategy"]
enum BackoffStrategy {
    Linear,
    Exponential,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, diesel_derive_enum::DbEnum)]
#[ExistingTypePath = "crate::schema::sql_types::RetryStrategy"]
enum RetryStrategy {
    Infinite,
    Count,
}

#[derive(diesel::Insertable, diesel::Queryable, diesel::Selectable)]
#[diesel(table_name = crate::schema::job_queue)]
struct PostgresJob {
    id: Uuid,
    name: String,
    queue: String,
    args: Value,
    retry_count: i32,
    max_retries: i32,
    retry: RetryStrategy,
    backoff_multiplier: i32,
    backoff: BackoffStrategy,
    next_queue: PrimitiveDateTime,
    heartbeat_interval: PgInterval,
}
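// How `JobInfo` maps onto the flattened postgres columns above: `MaxRetries` becomes the `retry`
// strategy plus a `max_retries` count, `Backoff` becomes the `backoff` strategy plus a
// `backoff_multiplier`, timestamps are normalized to UTC, and `heartbeat_interval` (milliseconds
// in `JobInfo`) is stored as a postgres interval whose `microseconds` field is divided by 1_000
// on the way back out.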
impl From<JobInfo> for PostgresJob {
    fn from(value: JobInfo) -> Self {
        let JobInfo {
            id,
            name,
            queue,
            args,
            retry_count,
            max_retries,
            backoff_strategy,
            next_queue,
            heartbeat_interval,
        } = value;

        let next_queue = next_queue.to_offset(time::UtcOffset::UTC);

        PostgresJob {
            id,
            name,
            queue,
            args,
            retry_count: retry_count as _,
            max_retries: match max_retries {
                MaxRetries::Count(count) => count as _,
                MaxRetries::Infinite => 0,
            },
            retry: match max_retries {
                MaxRetries::Infinite => RetryStrategy::Infinite,
                MaxRetries::Count(_) => RetryStrategy::Count,
            },
            backoff_multiplier: match backoff_strategy {
                Backoff::Linear(multiplier) => multiplier as _,
                Backoff::Exponential(multiplier) => multiplier as _,
            },
            backoff: match backoff_strategy {
                Backoff::Linear(_) => BackoffStrategy::Linear,
                Backoff::Exponential(_) => BackoffStrategy::Exponential,
            },
            next_queue: PrimitiveDateTime::new(next_queue.date(), next_queue.time()),
            heartbeat_interval: (heartbeat_interval as i32).milliseconds(),
        }
    }
}

impl From<PostgresJob> for JobInfo {
    fn from(value: PostgresJob) -> Self {
        let PostgresJob {
            id,
            name,
            queue,
            args,
            retry_count,
            max_retries,
            retry,
            backoff_multiplier,
            backoff,
            next_queue,
            heartbeat_interval,
        } = value;

        JobInfo {
            id,
            name,
            queue,
            args,
            retry_count: retry_count as _,
            max_retries: match retry {
                RetryStrategy::Count => MaxRetries::Count(max_retries as _),
                RetryStrategy::Infinite => MaxRetries::Infinite,
            },
            backoff_strategy: match backoff {
                BackoffStrategy::Linear => Backoff::Linear(backoff_multiplier as _),
                BackoffStrategy::Exponential => Backoff::Exponential(backoff_multiplier as _),
            },
            next_queue: next_queue.assume_utc(),
            heartbeat_interval: (heartbeat_interval.microseconds / 1_000) as _,
        }
    }
}
#[async_trait::async_trait]
impl background_jobs_core::Storage for Storage {
    type Error = PostgresError;

    #[tracing::instrument]
    async fn info(
        &self,
        job_id: Uuid,
    ) -> Result<Option<background_jobs_core::JobInfo>, Self::Error> {
        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        let opt = {
            use schema::job_queue::dsl::*;

            job_queue
                .select(PostgresJob::as_select())
                .filter(id.eq(job_id))
                .get_result(&mut conn)
                .metrics("background-jobs.postgres.info")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .optional()
                .map_err(PostgresError::Diesel)?
        };

        if let Some(postgres_job) = opt {
            Ok(Some(postgres_job.into()))
        } else {
            Ok(None)
        }
    }

    #[tracing::instrument(skip_all)]
    async fn push(&self, job: NewJobInfo) -> Result<Uuid, Self::Error> {
        self.insert(job.build()).await
    }

    #[tracing::instrument(skip(self))]
    async fn pop(&self, in_queue: &str, in_runner_id: Uuid) -> Result<JobInfo, Self::Error> {
        loop {
            tracing::trace!("pop: looping");

            let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

            let notifier: Arc<Notify> = self
                .inner
                .queue_notifications
                .entry(String::from(in_queue))
                .or_insert_with(|| Arc::new(Notify::const_new()))
                .clone();
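            // LISTEN is issued on the pooled connection; the underlying tokio_postgres connection
            // (spawned in `spawn_db_notification_task`) forwards NOTIFYs on `queue_status_channel`
            // into the flume channel, and `delegate_notifications` wakes the per-queue `Notify`
            // this task may park on below.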
            diesel::sql_query("LISTEN queue_status_channel;")
                .execute(&mut conn)
                .metrics("background-jobs.postgres.listen")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .map_err(PostgresError::Diesel)?;
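            // Requeue jobs whose heartbeat is more than five heartbeat intervals old: clear the
            // heartbeat and runner, and mark them New so another runner can claim them.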
            let count = {
                use schema::job_queue::dsl::*;

                diesel::update(job_queue)
                    .filter(heartbeat.is_not_null().and(heartbeat.assume_not_null().le(
                        diesel::dsl::sql::<Timestamp>("NOW() - heartbeat_interval * 5"),
                    )))
                    .set((
                        heartbeat.eq(Option::<PrimitiveDateTime>::None),
                        status.eq(JobStatus::New),
                        runner_id.eq(Option::<Uuid>::None),
                    ))
                    .execute(&mut conn)
                    .metrics("background-jobs.postgres.requeue")
                    .await
                    .map_err(PostgresError::Diesel)?
            };

            if count > 0 {
                tracing::info!("Reset {count} jobs");
            }
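            // Select the oldest ready job for this queue with FOR UPDATE SKIP LOCKED so concurrent
            // runners never contend for the same row, then atomically mark it Running for this
            // runner and return it.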
            let id_query = {
                use schema::job_queue::dsl::*;

                let queue_alias = diesel::alias!(schema::job_queue as queue_alias);

                queue_alias
                    .select(queue_alias.field(id))
                    .filter(
                        queue_alias
                            .field(status)
                            .eq(JobStatus::New)
                            .and(queue_alias.field(queue).eq(in_queue))
                            .and(queue_alias.field(next_queue).le(diesel::dsl::now)),
                    )
                    .order(queue_alias.field(next_queue))
                    .for_update()
                    .skip_locked()
                    .single_value()
            };

            let opt = {
                use schema::job_queue::dsl::*;

                diesel::update(job_queue)
                    .filter(id.nullable().eq(id_query))
                    .filter(status.eq(JobStatus::New))
                    .set((
                        heartbeat.eq(diesel::dsl::now),
                        status.eq(JobStatus::Running),
                        runner_id.eq(in_runner_id),
                    ))
                    .returning(PostgresJob::as_returning())
                    .get_result(&mut conn)
                    .metrics("background-jobs.postgres.claim")
                    .timeout(Duration::from_secs(5))
                    .await
                    .map_err(|_| PostgresError::DbTimeout)?
                    .optional()
                    .map_err(PostgresError::Diesel)?
            };

            if let Some(postgres_job) = opt {
                return Ok(postgres_job.into());
            }
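            // Nothing was claimable right now. Estimate how long until a queued job becomes ready
            // (falling back to 5 seconds when the queue is empty) and use that as the upper bound
            // on how long to wait for a notification below.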
            let sleep_duration = {
                use schema::job_queue::dsl::*;

                job_queue
                    .filter(queue.eq(in_queue).and(status.eq(JobStatus::New)))
                    .select(diesel::dsl::sql::<Interval>("NOW() - next_queue"))
                    .get_result::<PgInterval>(&mut conn)
                    .metrics("background-jobs.postgres.next-queue")
                    .timeout(Duration::from_secs(5))
                    .await
                    .map_err(|_| PostgresError::DbTimeout)?
                    .optional()
                    .map_err(PostgresError::Diesel)?
                    .map(|interval| {
                        if interval.microseconds < 0 {
                            Duration::from_micros(interval.microseconds.abs_diff(0))
                        } else {
                            Duration::from_secs(0)
                        }
                    })
                    .unwrap_or(Duration::from_secs(5))
            };
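            // Return the connection to the pool before parking; waiting on the notifier should not
            // hold a pooled connection.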
            drop(conn);
            if notifier.notified().timeout(sleep_duration).await.is_ok() {
                tracing::debug!("Notified");
            } else {
                tracing::debug!("Timed out");
            }
        }
    }

    #[tracing::instrument(skip(self))]
    async fn heartbeat(&self, job_id: Uuid, in_runner_id: Uuid) -> Result<(), Self::Error> {
        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        {
            use schema::job_queue::dsl::*;

            diesel::update(job_queue)
                .filter(id.eq(job_id))
                .set((heartbeat.eq(diesel::dsl::now), runner_id.eq(in_runner_id)))
                .execute(&mut conn)
                .metrics("background-jobs.postgres.heartbeat")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .map_err(PostgresError::Diesel)?;
        }

        Ok(())
    }

    #[tracing::instrument(skip(self))]
    async fn complete(&self, return_job_info: ReturnJobInfo) -> Result<bool, Self::Error> {
        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        let job = {
            use schema::job_queue::dsl::*;

            diesel::delete(job_queue)
                .filter(id.eq(return_job_info.id))
                .returning(PostgresJob::as_returning())
                .get_result(&mut conn)
                .metrics("background-jobs.postgres.complete")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .optional()
                .map_err(PostgresError::Diesel)?
        };

        let mut job: JobInfo = if let Some(job) = job {
            job.into()
        } else {
            return Ok(true);
        };
        match return_job_info.result {
            // finished jobs stay deleted
            JobResult::Success => Ok(true),
            // jobs that never ran (or had no registered handler) are re-inserted unchanged
            JobResult::Unexecuted | JobResult::Unregistered => {
                self.insert(job).await?;

                Ok(false)
            }
            // failed jobs with retries remaining are re-inserted with updated retry state
            JobResult::Failure if job.prepare_retry() => {
                self.insert(job).await?;

                Ok(false)
            }
            // failed jobs that are out of retries stay deleted
            JobResult::Failure => Ok(true),
        }
    }
}
impl Storage {
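    /// Connect to postgres, run migrations, and build the connection pool.
    ///
    /// Passing `Some(table_name)` overrides the refinery migration table name.
    ///
    /// A minimal usage sketch (placeholder URL; assumes a reachable postgres instance):
    ///
    /// ```rust,ignore
    /// let url = url::Url::parse("postgres://user:password@localhost:5432/jobs")?;
    /// let storage = Storage::connect(url, None).await?;
    /// ```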
    pub async fn connect(
        postgres_url: Url,
        migration_table: Option<&str>,
    ) -> Result<Self, ConnectPostgresError> {
        let (mut client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls)
            .await
            .map_err(ConnectPostgresError::ConnectForMigration)?;

        let handle = spawn("postgres-migrations", conn)?;

        let mut runner = embedded::migrations::runner();

        if let Some(table_name) = migration_table {
            runner.set_migration_table_name(table_name);
        }

        runner
            .run_async(&mut client)
            .await
            .map_err(ConnectPostgresError::Migration)?;

        handle.abort();
        let _ = handle.await;

        let parallelism = std::thread::available_parallelism()
            .map(|u| u.into())
            .unwrap_or(1_usize);

        let (tx, rx) = flume::bounded(10);

        let mut config = ManagerConfig::default();
        config.custom_setup = build_handler(tx);

        let mgr = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
            postgres_url,
            config,
        );

        let pool = Pool::builder(mgr)
            .runtime(deadpool::Runtime::Tokio1)
            .wait_timeout(Some(Duration::from_secs(10)))
            .create_timeout(Some(Duration::from_secs(2)))
            .recycle_timeout(Some(Duration::from_secs(2)))
            .post_create(Hook::sync_fn(|_, _| {
                metrics::counter!("background-jobs.postgres.pool.connection.create").increment(1);
                Ok(())
            }))
            .post_recycle(Hook::sync_fn(|_, _| {
                metrics::counter!("background-jobs.postgres.pool.connection.recycle").increment(1);
                Ok(())
            }))
            .max_size(parallelism * 8)
            .build()
            .map_err(ConnectPostgresError::BuildPool)?;

        let inner = Arc::new(Inner {
            pool,
            queue_notifications: DashMap::new(),
        });

        let handle = spawn(
            "postgres-delegate-notifications",
            delegate_notifications(rx, inner.clone(), parallelism * 8),
        )?;

        let drop_handle = Arc::new(handle);

        Ok(Storage { inner, drop_handle })
    }

    async fn insert(&self, job_info: JobInfo) -> Result<Uuid, PostgresError> {
        let postgres_job: PostgresJob = job_info.into();
        let id = postgres_job.id;

        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        {
            use schema::job_queue::dsl::*;

            postgres_job
                .insert_into(job_queue)
                .execute(&mut conn)
                .metrics("background-jobs.postgres.insert")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .map_err(PostgresError::Diesel)?;
        }

        Ok(id)
    }
}
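// JobNotifierState deduplicates notification payloads: `jobs` is the membership set and
// `jobs_ordered` is a FIFO of the same ids, so once more than `capacity` ids have been seen the
// oldest is evicted from both. Each fresh (job, queue) payload wakes one waiter on that queue's
// `Notify`.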
impl<'a> JobNotifierState<'a> {
    fn handle(&mut self, payload: &str) {
        let Some((job_id, queue_name)) = payload.split_once(' ') else {
            tracing::warn!("Invalid queue payload {payload}");
            return;
        };

        let Ok(job_id) = job_id.parse::<Uuid>() else {
            tracing::warn!("Invalid job ID {job_id}");
            return;
        };

        if !self.jobs.insert(job_id) {
            return;
        }

        self.jobs_ordered.push_back(job_id);

        if self.jobs_ordered.len() > self.capacity {
            if let Some(job_id) = self.jobs_ordered.pop_front() {
                self.jobs.remove(&job_id);
            }
        }

        self.inner
            .queue_notifications
            .entry(queue_name.to_string())
            .or_insert_with(|| Arc::new(Notify::const_new()))
            .notify_one();

        metrics::counter!("background-jobs.postgres.job-notifier.notified", "queue" => queue_name.to_string()).increment(1);
    }
}

async fn delegate_notifications(
    receiver: flume::Receiver<Notification>,
    inner: Arc<Inner>,
    capacity: usize,
) {
    let mut job_notifier_state = JobNotifierState {
        inner: &inner,
        capacity,
        jobs: BTreeSet::new(),
        jobs_ordered: VecDeque::new(),
    };

    while let Ok(notification) = receiver.recv_async().await {
        tracing::trace!("delegate_notifications: looping");
        metrics::counter!("background-jobs.postgres.notification").increment(1);

        match notification.channel() {
            "queue_status_channel" => {
                job_notifier_state.handle(notification.payload());
            }
            channel => {
                tracing::info!(
                    "Unhandled postgres notification: {channel}: {}",
                    notification.payload()
                );
            }
        }
    }

    tracing::warn!("Notification delegator shutting down");
}
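// Custom connection setup for the deadpool manager: every pooled diesel connection is built from a
// raw tokio_postgres client whose driving `Connection` is moved onto a background task that
// forwards NOTIFY messages into the flume channel consumed by `delegate_notifications`.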
fn build_handler(sender: flume::Sender<Notification>) -> ConfigFn {
    Box::new(
        move |config: &str| -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> {
            let sender = sender.clone();

            let connect_span = tracing::trace_span!(parent: None, "connect future");

            Box::pin(
                async move {
                    let (client, conn) =
                        tokio_postgres::connect(config, tokio_postgres::tls::NoTls)
                            .await
                            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;

                    spawn_db_notification_task(sender, conn)
                        .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;

                    AsyncPgConnection::try_from(client).await
                }
                .instrument(connect_span),
            )
        },
    )
}

fn spawn_db_notification_task(
    sender: flume::Sender<Notification>,
    mut conn: Connection<Socket, NoTlsStream>,
) -> std::io::Result<()> {
    spawn_detach("postgres-notifications", async move {
        while let Some(res) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
            tracing::trace!("db_notification_task: looping");

            match res {
                Err(e) => {
                    tracing::error!("Database Connection {e:?}");
                    return;
                }
                Ok(AsyncMessage::Notice(e)) => {
                    tracing::warn!("Database Notice {e:?}");
                }
                Ok(AsyncMessage::Notification(notification)) => {
                    if sender.send_async(notification).await.is_err() {
                        tracing::warn!("Missed notification. Are we shutting down?");
                    }
                }
                Ok(_) => {
                    tracing::warn!("Unhandled AsyncMessage!!! Please contact the developer of this application");
                }
            }
        }
    })?;

    Ok(())
}

impl<T> Future for DropHandle<T> {
    type Output = <JoinHandle<T> as Future>::Output;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        std::pin::Pin::new(&mut self.get_mut().handle).poll(cx)
    }
}

impl<T> Drop for DropHandle<T> {
    fn drop(&mut self) {
        self.handle.abort();
    }
}

impl<T> Deref for DropHandle<T> {
    type Target = JoinHandle<T>;

    fn deref(&self) -> &Self::Target {
        &self.handle
    }
}

impl std::fmt::Debug for Storage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Storage").finish()
    }
}

impl From<refinery::Error> for ConnectPostgresError {
    fn from(value: refinery::Error) -> Self {
        Self::Migration(value)
    }
}

impl From<tokio_postgres::Error> for ConnectPostgresError {
    fn from(value: tokio_postgres::Error) -> Self {
        Self::ConnectForMigration(value)
    }
}

impl From<BuildError> for ConnectPostgresError {
    fn from(value: BuildError) -> Self {
        Self::BuildPool(value)
    }
}

impl From<std::io::Error> for ConnectPostgresError {
    fn from(value: std::io::Error) -> Self {
        Self::SpawnTask(value)
    }
}

impl std::fmt::Display for ConnectPostgresError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BuildPool(_) => write!(f, "Failed to build postgres connection pool"),
            Self::ConnectForMigration(_) => {
                write!(f, "Failed to connect to postgres for migrations")
            }
            Self::Migration(_) => write!(f, "Failed to run migrations"),
            Self::SpawnTask(_) => write!(f, "Failed to spawn task"),
        }
    }
}

impl std::error::Error for ConnectPostgresError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::BuildPool(e) => Some(e),
            Self::ConnectForMigration(e) => Some(e),
            Self::Migration(e) => Some(e),
            Self::SpawnTask(e) => Some(e),
        }
    }
}

impl std::fmt::Display for PostgresError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pool(_) => write!(f, "Error in db pool"),
            Self::Diesel(_) => write!(f, "Error in database"),
            Self::DbTimeout => write!(f, "Timed out waiting for postgres"),
        }
    }
}

impl Error for PostgresError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Pool(e) => Some(e),
            Self::Diesel(e) => Some(e),
            Self::DbTimeout => None,
        }
    }
}