mod embedded;
mod future;
mod schema;

use future::{Metrics as _, Timeout as _};

use std::{
    collections::{BTreeSet, VecDeque},
    error::Error,
    future::Future,
    ops::Deref,
    sync::Arc,
    time::Duration,
};

use background_jobs_core::{Backoff, JobInfo, JobResult, MaxRetries, NewJobInfo, ReturnJobInfo};
use dashmap::DashMap;
use diesel::{
    data_types::PgInterval,
    dsl::IntervalDsl,
    prelude::*,
    sql_types::{Interval, Timestamp},
};
use diesel_async::{
    pooled_connection::{
        deadpool::{BuildError, Hook, Pool, PoolError},
        AsyncDieselConnectionManager, ManagerConfig,
    },
    AsyncPgConnection, RunQueryDsl,
};
use futures_core::future::BoxFuture;
use serde_json::Value;
use time::PrimitiveDateTime;
use tokio::{sync::Notify, task::JoinHandle};
use tokio_postgres::{tls::NoTlsStream, AsyncMessage, Connection, NoTls, Notification, Socket};
use tracing::Instrument;
use url::Url;
use uuid::Uuid;

type ConfigFn =
    Box<dyn Fn(&str) -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> + Send + Sync + 'static>;

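/// Postgres-backed [`background_jobs_core::Storage`] implementation.
///
/// Clones are cheap: they share the same connection pool and notification state.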
#[derive(Clone)]
pub struct Storage {
    inner: Arc<Inner>,
    #[allow(dead_code)]
    drop_handle: Arc<DropHandle<()>>,
}

struct Inner {
    pool: Pool<AsyncPgConnection>,
    queue_notifications: DashMap<String, Arc<Notify>>,
}

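/// A [`JoinHandle`] wrapper that aborts its task when dropped.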
struct DropHandle<T> {
    handle: JoinHandle<T>,
}

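// Spawn a task whose handle aborts it on drop (see [`DropHandle`]).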
fn spawn<F: Future + Send + 'static>(
    name: &str,
    future: F,
) -> std::io::Result<DropHandle<F::Output>>
where
    F::Output: Send,
{
    Ok(DropHandle {
        handle: spawn_detach(name, future)?,
    })
}

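// Under the tokio_unstable cfg, tasks are given names so they can be
// identified in tools like tokio-console.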
#[cfg(tokio_unstable)]
fn spawn_detach<F: Future + Send + 'static>(
    name: &str,
    future: F,
) -> std::io::Result<JoinHandle<F::Output>>
where
    F::Output: Send,
{
    tokio::task::Builder::new().name(name).spawn(future)
}

#[cfg(not(tokio_unstable))]
fn spawn_detach<F: Future + Send + 'static>(
    name: &str,
    future: F,
) -> std::io::Result<JoinHandle<F::Output>>
where
    F::Output: Send,
{
    let _ = name;
    Ok(tokio::spawn(future))
}

#[derive(Debug)]
pub enum ConnectPostgresError {
    /// Error connecting to postgres to run migrations
    ConnectForMigration(tokio_postgres::Error),

    /// Error running migrations
    Migration(refinery::Error),

    /// Error building the connection pool
    BuildPool(BuildError),

    /// Error spawning a tokio task
    SpawnTask(std::io::Error),
}

#[derive(Debug)]
pub enum PostgresError {
    Pool(PoolError),

    Diesel(diesel::result::Error),

    DbTimeout,
}

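/// Deduplicates job notifications: remembers up to `capacity` recently seen
/// job IDs so repeated NOTIFY payloads for the same job only wake a worker once.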
struct JobNotifierState<'a> {
    inner: &'a Inner,
    capacity: usize,
    jobs: BTreeSet<Uuid>,
    jobs_ordered: VecDeque<Uuid>,
}

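// The following enums mirror custom postgres types; diesel_derive_enum maps
// each variant onto the existing SQL type named by ExistingTypePath.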
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, diesel_derive_enum::DbEnum)]
#[ExistingTypePath = "crate::schema::sql_types::JobStatus"]
enum JobStatus {
    New,
    Running,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, diesel_derive_enum::DbEnum)]
#[ExistingTypePath = "crate::schema::sql_types::BackoffStrategy"]
enum BackoffStrategy {
    Linear,
    Exponential,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, diesel_derive_enum::DbEnum)]
#[ExistingTypePath = "crate::schema::sql_types::RetryStrategy"]
enum RetryStrategy {
    Infinite,
    Count,
}

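/// A row in the `job_queue` table.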
#[derive(diesel::Insertable, diesel::Queryable, diesel::Selectable)]
#[diesel(table_name = crate::schema::job_queue)]
struct PostgresJob {
    id: Uuid,
    name: String,
    queue: String,
    args: Value,
    retry_count: i32,
    max_retries: i32,
    retry: RetryStrategy,
    backoff_multiplier: i32,
    backoff: BackoffStrategy,
    next_queue: PrimitiveDateTime,
    heartbeat_interval: PgInterval,
}

impl From<JobInfo> for PostgresJob {
    fn from(value: JobInfo) -> Self {
        let JobInfo {
            id,
            name,
            queue,
            args,
            retry_count,
            max_retries,
            backoff_strategy,
            next_queue,
            heartbeat_interval,
        } = value;

        let next_queue = next_queue.to_offset(time::UtcOffset::UTC);

        PostgresJob {
            id,
            name,
            queue,
            args,
            retry_count: retry_count as _,
            max_retries: match max_retries {
                MaxRetries::Count(count) => count as _,
                MaxRetries::Infinite => 0,
            },
            retry: match max_retries {
                MaxRetries::Infinite => RetryStrategy::Infinite,
                MaxRetries::Count(_) => RetryStrategy::Count,
            },
            backoff_multiplier: match backoff_strategy {
                Backoff::Linear(multiplier) => multiplier as _,
                Backoff::Exponential(multiplier) => multiplier as _,
            },
            backoff: match backoff_strategy {
                Backoff::Linear(_) => BackoffStrategy::Linear,
                Backoff::Exponential(_) => BackoffStrategy::Exponential,
            },
            next_queue: PrimitiveDateTime::new(next_queue.date(), next_queue.time()),
            heartbeat_interval: (heartbeat_interval as i32).milliseconds(),
        }
    }
}

impl From<PostgresJob> for JobInfo {
    fn from(value: PostgresJob) -> Self {
        let PostgresJob {
            id,
            name,
            queue,
            args,
            retry_count,
            max_retries,
            retry,
            backoff_multiplier,
            backoff,
            next_queue,
            heartbeat_interval,
        } = value;

        JobInfo {
            id,
            name,
            queue,
            args,
            retry_count: retry_count as _,
            max_retries: match retry {
                RetryStrategy::Count => MaxRetries::Count(max_retries as _),
                RetryStrategy::Infinite => MaxRetries::Infinite,
            },
            backoff_strategy: match backoff {
                BackoffStrategy::Linear => Backoff::Linear(backoff_multiplier as _),
                BackoffStrategy::Exponential => Backoff::Exponential(backoff_multiplier as _),
            },
            next_queue: next_queue.assume_utc(),
            heartbeat_interval: (heartbeat_interval.microseconds / 1_000) as _,
        }
    }
}

impl background_jobs_core::Storage for Storage {
    type Error = PostgresError;

    #[tracing::instrument]
    async fn info(
        &self,
        job_id: Uuid,
    ) -> Result<Option<background_jobs_core::JobInfo>, Self::Error> {
        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        let opt = {
            use schema::job_queue::dsl::*;

            job_queue
                .select(PostgresJob::as_select())
                .filter(id.eq(job_id))
                .get_result(&mut conn)
                .metrics("background-jobs.postgres.info")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .optional()
                .map_err(PostgresError::Diesel)?
        };

        if let Some(postgres_job) = opt {
            Ok(Some(postgres_job.into()))
        } else {
            Ok(None)
        }
    }

    #[tracing::instrument(skip_all)]
    async fn push(&self, job: NewJobInfo) -> Result<Uuid, Self::Error> {
        self.insert(job.build()).await
    }

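    // pop loops until a job can be claimed: reset jobs whose heartbeats have
    // lapsed, try to claim one with FOR UPDATE SKIP LOCKED, then wait on
    // LISTEN/NOTIFY (with a timeout) before trying again.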
    #[tracing::instrument(skip(self))]
    async fn pop(&self, in_queue: &str, in_runner_id: Uuid) -> Result<JobInfo, Self::Error> {
        loop {
            tracing::trace!("pop: looping");

            let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

            let notifier: Arc<Notify> = self
                .inner
                .queue_notifications
                .entry(String::from(in_queue))
                .or_insert_with(|| Arc::new(Notify::const_new()))
                .clone();

            diesel::sql_query("LISTEN queue_status_channel;")
                .execute(&mut conn)
                .metrics("background-jobs.postgres.listen")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .map_err(PostgresError::Diesel)?;

            // Requeue jobs whose runner has missed five heartbeat intervals.
            let count = {
                use schema::job_queue::dsl::*;

                diesel::update(job_queue)
                    .filter(heartbeat.is_not_null().and(heartbeat.assume_not_null().le(
                        diesel::dsl::sql::<Timestamp>("NOW() - heartbeat_interval * 5"),
                    )))
                    .set((
                        heartbeat.eq(Option::<PrimitiveDateTime>::None),
                        status.eq(JobStatus::New),
                        runner_id.eq(Option::<Uuid>::None),
                    ))
                    .execute(&mut conn)
                    .metrics("background-jobs.postgres.requeue")
                    .await
                    .map_err(PostgresError::Diesel)?
            };

            if count > 0 {
                tracing::info!("Reset {count} jobs");
            }

            // Select the ID of the next eligible job in this queue, skipping
            // rows other runners have already locked.
            let id_query = {
                use schema::job_queue::dsl::*;

                let queue_alias = diesel::alias!(schema::job_queue as queue_alias);

                queue_alias
                    .select(queue_alias.field(id))
                    .filter(
                        queue_alias
                            .field(status)
                            .eq(JobStatus::New)
                            .and(queue_alias.field(queue).eq(in_queue))
                            .and(queue_alias.field(next_queue).le(diesel::dsl::now)),
                    )
                    .order(queue_alias.field(next_queue))
                    .for_update()
                    .skip_locked()
                    .single_value()
            };

            // Claim the job: mark it running and record this runner's ID.
            let opt = {
                use schema::job_queue::dsl::*;

                diesel::update(job_queue)
                    .filter(id.nullable().eq(id_query))
                    .filter(status.eq(JobStatus::New))
                    .set((
                        heartbeat.eq(diesel::dsl::now),
                        status.eq(JobStatus::Running),
                        runner_id.eq(in_runner_id),
                    ))
                    .returning(PostgresJob::as_returning())
                    .get_result(&mut conn)
                    .metrics("background-jobs.postgres.claim")
                    .timeout(Duration::from_secs(5))
                    .await
                    .map_err(|_| PostgresError::DbTimeout)?
                    .optional()
                    .map_err(PostgresError::Diesel)?
            };

            if let Some(postgres_job) = opt {
                return Ok(postgres_job.into());
            }

            // Nothing to claim; sleep until the next scheduled job is due, or
            // five seconds if this queue has no pending jobs.
            let sleep_duration = {
                use schema::job_queue::dsl::*;

                job_queue
                    .filter(queue.eq(in_queue).and(status.eq(JobStatus::New)))
                    .select(diesel::dsl::sql::<Interval>("NOW() - next_queue"))
                    .get_result::<PgInterval>(&mut conn)
                    .metrics("background-jobs.postgres.next-queue")
                    .timeout(Duration::from_secs(5))
                    .await
                    .map_err(|_| PostgresError::DbTimeout)?
                    .optional()
                    .map_err(PostgresError::Diesel)?
                    .map(|interval| {
                        if interval.microseconds < 0 {
                            Duration::from_micros(interval.microseconds.abs_diff(0))
                        } else {
                            Duration::from_secs(0)
                        }
                    })
                    .unwrap_or(Duration::from_secs(5))
            };

            // Return the connection to the pool before waiting.
            drop(conn);
            if notifier.notified().timeout(sleep_duration).await.is_ok() {
                tracing::debug!("Notified");
            } else {
                tracing::debug!("Timed out");
            }
        }
    }

    #[tracing::instrument(skip(self))]
    async fn heartbeat(&self, job_id: Uuid, in_runner_id: Uuid) -> Result<(), Self::Error> {
        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        {
            use schema::job_queue::dsl::*;

            diesel::update(job_queue)
                .filter(id.eq(job_id))
                .set((heartbeat.eq(diesel::dsl::now), runner_id.eq(in_runner_id)))
                .execute(&mut conn)
                .metrics("background-jobs.postgres.heartbeat")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .map_err(PostgresError::Diesel)?;
        }

        Ok(())
    }

    #[tracing::instrument(skip(self))]
    async fn complete(&self, return_job_info: ReturnJobInfo) -> Result<bool, Self::Error> {
        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        // Remove the job unconditionally; it is reinserted below if it needs
        // to run again.
        let job = {
            use schema::job_queue::dsl::*;

            diesel::delete(job_queue)
                .filter(id.eq(return_job_info.id))
                .returning(PostgresJob::as_returning())
                .get_result(&mut conn)
                .metrics("background-jobs.postgres.complete")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .optional()
                .map_err(PostgresError::Diesel)?
        };

        let mut job: JobInfo = if let Some(job) = job {
            job.into()
        } else {
            return Ok(true);
        };

        match return_job_info.result {
            // successful jobs are not retried
            JobResult::Success => Ok(true),
            // unexecuted and unregistered jobs are reinserted as-is
            JobResult::Unexecuted | JobResult::Unregistered => {
                self.insert(job).await?;

                Ok(false)
            }
            // failed jobs with retries remaining are reinserted
            JobResult::Failure if job.prepare_retry() => {
                self.insert(job).await?;

                Ok(false)
            }
            // failed jobs with no retries remaining are dropped
            JobResult::Failure => Ok(true),
        }
    }
}

impl Storage {
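    /// Connect to postgres, run any pending migrations, and build the
    /// connection pool.
    ///
    /// A minimal usage sketch; the connection URL here is a placeholder:
    ///
    /// ```rust,ignore
    /// let storage = Storage::connect(
    ///     Url::parse("postgres://user:password@localhost/db")?,
    ///     None,
    /// )
    /// .await?;
    /// ```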
    pub async fn connect(
        postgres_url: Url,
        migration_table: Option<&str>,
    ) -> Result<Self, ConnectPostgresError> {
        let (mut client, conn) = tokio_postgres::connect(postgres_url.as_str(), NoTls)
            .await
            .map_err(ConnectPostgresError::ConnectForMigration)?;

        // The connection must be polled on a task for the client to make progress.
        let handle = spawn("postgres-migrations", conn)?;

        let mut runner = embedded::migrations::runner();

        if let Some(table_name) = migration_table {
            runner.set_migration_table_name(table_name);
        }

        runner
            .run_async(&mut client)
            .await
            .map_err(ConnectPostgresError::Migration)?;

        // Migrations are done; tear down the dedicated migration connection.
        handle.abort();
        let _ = handle.await;

        let parallelism = std::thread::available_parallelism()
            .map(|u| u.into())
            .unwrap_or(1_usize);

        let (tx, rx) = flume::bounded(10);

        let mut config = ManagerConfig::default();
        config.custom_setup = build_handler(tx);

        let mgr = AsyncDieselConnectionManager::<AsyncPgConnection>::new_with_config(
            postgres_url,
            config,
        );

        let pool = Pool::builder(mgr)
            .runtime(deadpool::Runtime::Tokio1)
            .wait_timeout(Some(Duration::from_secs(10)))
            .create_timeout(Some(Duration::from_secs(2)))
            .recycle_timeout(Some(Duration::from_secs(2)))
            .post_create(Hook::sync_fn(|_, _| {
                metrics::counter!("background-jobs.postgres.pool.connection.create").increment(1);
                Ok(())
            }))
            .post_recycle(Hook::sync_fn(|_, _| {
                metrics::counter!("background-jobs.postgres.pool.connection.recycle").increment(1);
                Ok(())
            }))
            .max_size(parallelism * 8)
            .build()
            .map_err(ConnectPostgresError::BuildPool)?;

        let inner = Arc::new(Inner {
            pool,
            queue_notifications: DashMap::new(),
        });

        let handle = spawn(
            "postgres-delegate-notifications",
            delegate_notifications(rx, inner.clone(), parallelism * 8),
        )?;

        let drop_handle = Arc::new(handle);

        Ok(Storage { inner, drop_handle })
    }

    async fn insert(&self, job_info: JobInfo) -> Result<Uuid, PostgresError> {
        let postgres_job: PostgresJob = job_info.into();
        let id = postgres_job.id;

        let mut conn = self.inner.pool.get().await.map_err(PostgresError::Pool)?;

        {
            use schema::job_queue::dsl::*;

            postgres_job
                .insert_into(job_queue)
                .execute(&mut conn)
                .metrics("background-jobs.postgres.insert")
                .timeout(Duration::from_secs(5))
                .await
                .map_err(|_| PostgresError::DbTimeout)?
                .map_err(PostgresError::Diesel)?;
        }

        Ok(id)
    }
}

impl<'a> JobNotifierState<'a> {
    fn handle(&mut self, payload: &str) {
        // Payloads have the form "<job id> <queue name>".
        let Some((job_id, queue_name)) = payload.split_once(' ') else {
            tracing::warn!("Invalid queue payload {payload}");
            return;
        };

        let Ok(job_id) = job_id.parse::<Uuid>() else {
            tracing::warn!("Invalid job ID {job_id}");
            return;
        };

        if !self.jobs.insert(job_id) {
            // We've already seen this job
            return;
        }

        self.jobs_ordered.push_back(job_id);

        // Evict the oldest remembered job ID once over capacity.
        if self.jobs_ordered.len() > self.capacity {
            if let Some(job_id) = self.jobs_ordered.pop_front() {
                self.jobs.remove(&job_id);
            }
        }

        self.inner
            .queue_notifications
            .entry(queue_name.to_string())
            .or_insert_with(|| Arc::new(Notify::const_new()))
            .notify_one();

        metrics::counter!("background-jobs.postgres.job-notifier.notified", "queue" => queue_name.to_string())
            .increment(1);
    }
}

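// Fan notifications from the flume channel out to per-queue Notify handles.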
async fn delegate_notifications(
    receiver: flume::Receiver<Notification>,
    inner: Arc<Inner>,
    capacity: usize,
) {
    let mut job_notifier_state = JobNotifierState {
        inner: &inner,
        capacity,
        jobs: BTreeSet::new(),
        jobs_ordered: VecDeque::new(),
    };

    while let Ok(notification) = receiver.recv_async().await {
        tracing::trace!("delegate_notifications: looping");
        metrics::counter!("background-jobs.postgres.notification").increment(1);

        match notification.channel() {
            "queue_status_channel" => {
                job_notifier_state.handle(notification.payload());
            }
            channel => {
                tracing::info!(
                    "Unhandled postgres notification: {channel}: {}",
                    notification.payload()
                );
            }
        }
    }

    tracing::warn!("Notification delegator shutting down");
}

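// deadpool custom-setup hook: connect via tokio_postgres directly so the
// connection half can be captured for LISTEN/NOTIFY forwarding.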
fn build_handler(sender: flume::Sender<Notification>) -> ConfigFn {
    Box::new(
        move |config: &str| -> BoxFuture<'_, ConnectionResult<AsyncPgConnection>> {
            let sender = sender.clone();

            let connect_span = tracing::trace_span!(parent: None, "connect future");

            Box::pin(
                async move {
                    let (client, conn) =
                        tokio_postgres::connect(config, tokio_postgres::tls::NoTls)
                            .await
                            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;

                    // The spawned task drives the connection and ends with it;
                    // it is not otherwise tied to the pool's lifetime.
                    spawn_db_notification_task(sender, conn)
                        .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;

                    AsyncPgConnection::try_from(client).await
                }
                .instrument(connect_span),
            )
        },
    )
}

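// Drive the tokio_postgres connection to completion, forwarding any
// notifications it yields.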
fn spawn_db_notification_task(
    sender: flume::Sender<Notification>,
    mut conn: Connection<Socket, NoTlsStream>,
) -> std::io::Result<()> {
    spawn_detach("postgres-notifications", async move {
        while let Some(res) = std::future::poll_fn(|cx| conn.poll_message(cx)).await {
            tracing::trace!("db_notification_task: looping");

            match res {
                Err(e) => {
                    tracing::error!("Database Connection {e:?}");
                    return;
                }
                Ok(AsyncMessage::Notice(e)) => {
                    tracing::warn!("Database Notice {e:?}");
                }
                Ok(AsyncMessage::Notification(notification)) => {
                    if sender.send_async(notification).await.is_err() {
                        tracing::warn!("Missed notification. Are we shutting down?");
                    }
                }
                Ok(_) => {
                    tracing::warn!(
                        "Unhandled AsyncMessage!!! Please contact the developer of this application"
                    );
                }
            }
        }
    })?;

    Ok(())
}

impl<T> Future for DropHandle<T> {
    type Output = <JoinHandle<T> as Future>::Output;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        std::pin::Pin::new(&mut self.get_mut().handle).poll(cx)
    }
}

impl<T> Drop for DropHandle<T> {
    fn drop(&mut self) {
        self.handle.abort();
    }
}

impl<T> Deref for DropHandle<T> {
    type Target = JoinHandle<T>;

    fn deref(&self) -> &Self::Target {
        &self.handle
    }
}

impl std::fmt::Debug for Storage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Storage").finish()
    }
}

impl From<refinery::Error> for ConnectPostgresError {
    fn from(value: refinery::Error) -> Self {
        Self::Migration(value)
    }
}

impl From<tokio_postgres::Error> for ConnectPostgresError {
    fn from(value: tokio_postgres::Error) -> Self {
        Self::ConnectForMigration(value)
    }
}

impl From<BuildError> for ConnectPostgresError {
    fn from(value: BuildError) -> Self {
        Self::BuildPool(value)
    }
}

impl From<std::io::Error> for ConnectPostgresError {
    fn from(value: std::io::Error) -> Self {
        Self::SpawnTask(value)
    }
}

impl std::fmt::Display for ConnectPostgresError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::BuildPool(_) => write!(f, "Failed to build postgres connection pool"),
            Self::ConnectForMigration(_) => {
                write!(f, "Failed to connect to postgres for migrations")
            }
            Self::Migration(_) => write!(f, "Failed to run migrations"),
            Self::SpawnTask(_) => write!(f, "Failed to spawn task"),
        }
    }
}

impl std::error::Error for ConnectPostgresError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::BuildPool(e) => Some(e),
            Self::ConnectForMigration(e) => Some(e),
            Self::Migration(e) => Some(e),
            Self::SpawnTask(e) => Some(e),
        }
    }
}

impl std::fmt::Display for PostgresError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Pool(_) => write!(f, "Error in db pool"),
            Self::Diesel(_) => write!(f, "Error in database"),
            Self::DbTimeout => write!(f, "Timed out waiting for postgres"),
        }
    }
}

impl Error for PostgresError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Pool(e) => Some(e),
            Self::Diesel(e) => Some(e),
            Self::DbTimeout => None,
        }
    }
}