1#![doc = include_str!("../README.md")]
2
3use std::{fmt::Debug, marker::PhantomData};
4
5pub use apalis_codec::json::JsonCodec;
6use apalis_core::{
7 backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
8 task::{Task, task_id::TaskId},
9 worker::context::WorkerContext,
10};
11pub use apalis_sql::{config::Config, from_row::TaskRow};
12use diesel::{
13 PgConnection,
14 r2d2::{ConnectionManager, Pool},
15};
16use futures::{StreamExt, TryStreamExt};
17use ulid::Ulid;
18
19pub use crate::{
20 ack::{PgAck, PgMiddleware, lock_task, lock_task_in_queue},
21 error::Error,
22 fetcher::{PgFetcher, PgNotify},
23 lifecycle::{refresh_queue_stats_snapshot, setup, verify_schema},
24 pool::{build_pool, build_pool_with},
25 queries::migrations::MIGRATIONS,
26 shared::{SharedFetcher, SharedPostgresError, SharedPostgresStorage},
27 sink::PgSink,
28};
29
30mod ack;
31mod admin;
32mod error;
33mod fetcher;
34mod lifecycle;
35mod models;
36mod notify_event;
37mod pool;
38mod queries;
39mod runtime;
40mod shared;
41mod sink;
42
43pub(crate) use notify_event::InsertEvent;
44pub mod schema;
45
/// r2d2 connection pool of diesel [`PgConnection`]s; every storage in this crate is built on one.
pub type PgPool = Pool<ConnectionManager<PgConnection>>;
/// Context attached to each task produced by this backend (`SqlContext` over [`PgPool`]).
pub type PgContext = apalis_sql::context::SqlContext<PgPool>;
/// A task carrying `Args`, a [`PgContext`], and a [`Ulid`]-based identifier.
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;
/// Identifier type of tasks stored by this backend.
pub type PgTaskId = TaskId<Ulid>;
/// Encoded payload type: codecs used with the storage produce it (`Compact = CompactType`)
/// and the push paths insert tasks carrying it.
pub type CompactType = Vec<u8>;

// Storage name constant; not referenced in this file — presumably used by other modules
// (e.g. for reporting). TODO confirm.
pub(crate) const STORAGE_NAME: &str = "PostgresStorage";
61
/// Returns the name this crate identifies itself by: `"apalis-diesel-postgres"`.
///
/// Usable in `const` contexts.
#[must_use]
pub const fn crate_name() -> &'static str {
    const NAME: &str = "apalis-diesel-postgres";
    NAME
}
67
// Compile-time guard. This closure is never called, but type-checking its body forces the
// compiler to prove that each listed storage instantiation is `Send + Sync`. Any change that
// breaks either auto trait for these types fails the build right here, not in downstream code.
const _: fn() = || {
    // Only the trait bounds matter; the function body is intentionally empty.
    const fn assert_send_sync<T: Send + Sync>() {}
    // Default codec + default fetcher (`PgFetcher`).
    assert_send_sync::<PostgresStorage<()>>();
    // Notify-based fetcher.
    assert_send_sync::<PostgresStorage<(), JsonCodec<CompactType>, PgNotify>>();
    // Shared fetcher.
    assert_send_sync::<PostgresStorage<(), JsonCodec<CompactType>, SharedFetcher>>();
    assert_send_sync::<SharedPostgresStorage<()>>();
};
76
/// Postgres-backed task storage built on a diesel r2d2 [`PgPool`].
///
/// Type parameters:
/// - `Args`: the task argument type.
/// - `Codec`: codec used to convert `Args` to and from [`CompactType`]
///   (defaults to [`JsonCodec`]).
/// - `Fetcher`: source of the compact task stream used when the storage is polled
///   (defaults to [`PgFetcher`]; `new_with_notify` selects [`PgNotify`]).
pub struct PostgresStorage<
    Args,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<CompactType, Codec>,
> {
    // Ties `Args` and `Codec` to the type without storing values of either.
    _marker: PhantomData<(Args, Codec)>,
    // Pool handed to the sink, the heartbeat streams, the middleware and the fetcher.
    pub(crate) pool: PgPool,
    // Queue and runtime settings this storage was built with.
    pub(crate) config: Config,
    // Consumed when the storage is polled to produce the compact task stream.
    pub(crate) fetcher: Fetcher,
    // Task sink constructed from the same pool and config.
    pub(crate) sink: PgSink<Args, Codec>,
    // Token minted at construction and passed to the heartbeat, middleware and fetcher.
    // `Clone` copies the `Arc`, so clones share the same token.
    pub(crate) lease_token: std::sync::Arc<str>,
}
95
96impl<Args, Codec, Fetcher: Unpin> Unpin for PostgresStorage<Args, Codec, Fetcher> {}
101
102impl<Args, Codec, Fetcher: Debug> Debug for PostgresStorage<Args, Codec, Fetcher> {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("PostgresStorage")
105 .field("config", &self.config)
106 .field("fetcher", &self.fetcher)
107 .finish_non_exhaustive()
108 }
109}
110
111impl<Args, Codec, Fetcher: Clone> Clone for PostgresStorage<Args, Codec, Fetcher> {
112 fn clone(&self) -> Self {
113 Self {
114 _marker: PhantomData,
115 pool: self.pool.clone(),
116 config: self.config.clone(),
117 fetcher: self.fetcher.clone(),
118 sink: self.sink.clone(),
119 lease_token: self.lease_token.clone(),
120 }
121 }
122}
123
124impl<Args> PostgresStorage<Args> {
125 #[must_use]
135 pub fn new(pool: &PgPool) -> Self {
136 let config = Config::new(std::any::type_name::<Args>());
137 Self::new_with_config(pool, &config)
138 }
139
140 #[must_use]
145 pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
146 Self {
147 _marker: PhantomData,
148 pool: pool.clone(),
149 config: config.clone(),
150 fetcher: PgFetcher {
151 _marker: PhantomData,
152 },
153 sink: PgSink::new(pool, config),
154 lease_token: queries::worker::mint_lease_token().into(),
155 }
156 }
157
158 #[must_use]
171 pub fn new_with_notify(
172 pool: &PgPool,
173 config: &Config,
174 ) -> PostgresStorage<Args, JsonCodec<CompactType>, PgNotify> {
175 PostgresStorage {
176 _marker: PhantomData,
177 pool: pool.clone(),
178 config: config.clone(),
179 fetcher: PgNotify,
180 sink: PgSink::new(pool, config),
181 lease_token: queries::worker::mint_lease_token().into(),
182 }
183 }
184
185 #[must_use]
187 pub fn pool(&self) -> &PgPool {
188 &self.pool
189 }
190
191 #[must_use]
193 pub fn config(&self) -> &Config {
194 &self.config
195 }
196}
197
198impl<Args, Codec, Fetcher> PostgresStorage<Args, Codec, Fetcher> {
199 #[must_use]
201 pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, NewCodec, Fetcher> {
202 PostgresStorage {
203 _marker: PhantomData,
204 sink: PgSink::new(&self.pool, &self.config),
205 pool: self.pool,
206 config: self.config,
207 fetcher: self.fetcher,
208 lease_token: self.lease_token,
209 }
210 }
211
212 pub(crate) fn heartbeat_stream(
216 &self,
217 worker: &WorkerContext,
218 ) -> futures::stream::BoxStream<'static, Result<(), Error>> {
219 let keep_alive = queries::keep_alive_stream(
220 self.pool.clone(),
221 self.config.clone(),
222 worker.clone(),
223 std::sync::Arc::clone(&self.lease_token),
224 );
225 let reenqueue = queries::reenqueue_orphaned_stream(self.pool.clone(), self.config.clone())
226 .map_ok(|_| ());
227 futures::stream::select(keep_alive, reenqueue).boxed()
228 }
229}
230
231impl<Args, EncodeCodec, Fetcher> PostgresStorage<Args, EncodeCodec, Fetcher>
235where
236 EncodeCodec: Codec<Args, Compact = CompactType>,
237 EncodeCodec::Error: std::error::Error + Send + Sync + 'static,
238{
239 pub fn push_with_conn(&self, conn: &mut PgConnection, args: Args) -> Result<PgTaskId, Error> {
266 let encoded = EncodeCodec::encode(&args).map_err(|err| Error::Decode(Box::new(err)))?;
267 let task_id = PgTaskId::new(Ulid::new());
268 let mut task = PgTask::<CompactType>::new(encoded);
269 task.parts.task_id = Some(task_id);
270 queries::push_tasks_on_conn(conn, &self.config, vec![task])?;
271 Ok(task_id)
272 }
273
274 pub fn push_task_with_conn(
295 &self,
296 conn: &mut PgConnection,
297 task: PgTask<Args>,
298 ) -> Result<PgTaskId, Error> {
299 let encoded =
300 EncodeCodec::encode(&task.args).map_err(|err| Error::Decode(Box::new(err)))?;
301 let task_id = task
302 .parts
303 .task_id
304 .unwrap_or_else(|| PgTaskId::new(Ulid::new()));
305 let mut compact = PgTask::<CompactType> {
306 args: encoded,
307 parts: task.parts,
308 };
309 compact.parts.task_id = Some(task_id);
310 queries::push_tasks_on_conn(conn, &self.config, vec![compact])?;
311 Ok(task_id)
312 }
313}
314
315impl<Args, Decode, Fetcher> Backend for PostgresStorage<Args, Decode, Fetcher>
319where
320 Args: Send + 'static + Unpin,
321 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
322 Decode::Error: std::error::Error + Send + Sync + 'static,
323 Fetcher: crate::fetcher::PgFetcherSource,
324{
325 type Args = Args;
326 type IdType = Ulid;
327 type Context = PgContext;
328 type Error = Error;
329 type Stream = TaskStream<PgTask<Args>, Error>;
330 type Beat = futures::stream::BoxStream<'static, Result<(), Error>>;
331 type Layer = PgMiddleware;
332
333 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
334 self.heartbeat_stream(worker)
335 }
336
337 fn middleware(&self) -> Self::Layer {
338 PgMiddleware::with_lease_token(
339 self.pool.clone(),
340 self.config.ack(),
341 std::sync::Arc::clone(&self.lease_token),
342 )
343 }
344
345 fn poll(self, worker: &WorkerContext) -> Self::Stream {
346 let compact = self.fetcher.into_compact_stream(
347 self.pool,
348 self.config,
349 worker.clone(),
350 self.lease_token,
351 );
352 crate::fetcher::decode_task_stream::<Args, Decode>(compact)
353 }
354}
355
356impl<Args, Decode, Fetcher> BackendExt for PostgresStorage<Args, Decode, Fetcher>
357where
358 Args: Send + 'static + Unpin,
359 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
360 Decode::Error: std::error::Error + Send + Sync + 'static,
361 Fetcher: crate::fetcher::PgFetcherSource,
362{
363 type Compact = CompactType;
364 type Codec = Decode;
365 type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
366
367 fn get_queue(&self) -> Queue {
368 self.config.queue().clone()
369 }
370
371 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
372 self.fetcher
373 .into_compact_stream(self.pool, self.config, worker.clone(), self.lease_token)
374 }
375}
376
// Unit tests that need no live database. Pools come from `unchecked_pool`, which uses
// `build_unchecked` against `postgres://127.0.0.1:1/not-used` (an address where no database is
// expected). Only paths that do not need a working connection are exercised: row conversion,
// construction, accessors, and building (not driving) the backend streams.
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use apalis_core::{
        backend::{Backend, BackendExt},
        task::status::Status,
    };
    use apalis_sql::{DateTime, DateTimeExt, from_row::FromRowError};
    use diesel::{
        PgConnection,
        r2d2::{ConnectionManager, Pool},
    };
    use lets_expect::{AssertionError, AssertionResult, *};

    use super::*;

    // Builds a `TaskRow` with a fixed payload, attempts, priority, lock owner and metadata.
    // Only the id, status, run time and idempotency key vary between test cases.
    fn row(
        id: &str,
        status: &str,
        run_at: Option<DateTime>,
        idempotency_key: Option<&str>,
    ) -> TaskRow {
        TaskRow {
            job: b"payload".to_vec(),
            id: id.to_owned(),
            job_type: "unit-queue".to_owned(),
            status: status.to_owned(),
            attempts: 2,
            max_attempts: Some(3),
            run_at,
            last_result: None,
            lock_at: None,
            lock_by: Some("worker-a".to_owned()),
            done_at: None,
            priority: Some(7),
            metadata: Some(serde_json::json!({"kind": "unit"})),
            idempotency_key: idempotency_key.map(str::to_owned),
        }
    }

    // Passes when the conversion succeeded and carried over the payload, attempt count,
    // status, priority, lock owner and idempotency key that `row` set.
    fn compact_task_has_expected_parts(
        result: &Result<PgTask<CompactType>, FromRowError>,
    ) -> AssertionResult {
        match result {
            Ok(task)
                if task.args == b"payload"
                    && task.parts.attempt.current() == 2
                    && task.parts.status.load() == Status::Pending
                    && task.parts.ctx.priority() == 7
                    && task.parts.ctx.lock_by() == &Some("worker-a".to_owned())
                    && task.parts.idempotency_key == Some("same-key".to_owned()) =>
            {
                Ok(())
            }
            Ok(task) => Err(AssertionError::new(vec![format!(
                "unexpected task parts: {task:?}"
            )])),
            Err(error) => Err(AssertionError::new(vec![format!(
                "expected successful conversion, got {error:?}"
            )])),
        }
    }

    // Returns a matcher that passes only for `FromRowError::ColumnNotFound(column)`.
    fn column_not_found(column: &'static str) -> impl Fn(&FromRowError) -> AssertionResult {
        move |error| match error {
            FromRowError::ColumnNotFound(found) if found == column => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected missing column {column}, got {other:?}"
            )])),
        }
    }

    // Passes for any `FromRowError::DecodeError`, regardless of its payload.
    fn decode_error(error: &FromRowError) -> AssertionResult {
        match error {
            FromRowError::DecodeError(_) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected decode error, got {other:?}"
            )])),
        }
    }

    // Single-connection pool built with `build_unchecked`, which does not wait for a
    // connection. The 10 ms timeout keeps any accidental checkout from stalling a test.
    fn unchecked_pool() -> PgPool {
        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
        Pool::builder()
            .max_size(1)
            .connection_timeout(std::time::Duration::from_millis(10))
            .build_unchecked(manager)
    }

    // Returns a matcher checking the storage's queue name and buffer size; it is generic
    // over codec and fetcher so it works for default, notify and re-coded storages.
    fn storage_uses_queue_and_buffer<Args, Codec, Fetcher>(
        queue: &'static str,
        buffer_size: usize,
    ) -> impl Fn(&PostgresStorage<Args, Codec, Fetcher>) -> AssertionResult {
        move |storage| {
            if storage.config.queue().to_string() == queue
                && storage.config.buffer_size() == buffer_size
            {
                Ok(())
            } else {
                Err(AssertionError::new(vec![format!(
                    "expected queue {queue:?} and buffer {buffer_size}, got queue {:?} and buffer {}",
                    storage.config.queue(),
                    storage.config.buffer_size()
                )]))
            }
        }
    }

    // Passes when the `Debug` output names the type and includes the `config` field.
    fn debug_mentions_public_type(result: &String) -> AssertionResult {
        if result.contains("PostgresStorage") && result.contains("config") {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "debug output did not describe storage: {result}"
            )]))
        }
    }

    // True when the id's string form parses back as a ULID.
    fn task_id_alias_accepts_ulid(id: PgTaskId) -> bool {
        Ulid::from_str(&id.to_string()).is_ok()
    }

    // Storage built with `new`, so the queue name comes from the argument type (`String`).
    fn storage_for_type_name() -> PostgresStorage<String> {
        let pool = unchecked_pool();
        PostgresStorage::<String>::new(&pool)
    }

    // Storage built with `new_with_config` for the given queue and buffer size.
    fn storage_for_config(queue: &'static str, buffer_size: usize) -> PostgresStorage<String> {
        let pool = unchecked_pool();
        let config = Config::new(queue).set_buffer_size(buffer_size);
        PostgresStorage::<String>::new_with_config(&pool, &config)
    }

    // Storage built with `new_with_notify` for the given queue and buffer size.
    fn notify_storage_for_config(
        queue: &'static str,
        buffer_size: usize,
    ) -> PostgresStorage<String, JsonCodec<CompactType>, PgNotify> {
        let pool = unchecked_pool();
        let config = Config::new(queue).set_buffer_size(buffer_size);
        PostgresStorage::<String>::new_with_notify(&pool, &config)
    }

    // Clone of a config-built storage, used to check that `Clone` keeps the config.
    fn cloned_storage_for_config(
        queue: &'static str,
        buffer_size: usize,
    ) -> PostgresStorage<String> {
        storage_for_config(queue, buffer_size).clone()
    }

    // `Debug` output of a config-built storage.
    fn debug_storage() -> String {
        format!("{:?}", storage_for_config("debug-api", 10))
    }

    // Round-trips the codec slot through `()` and back to JSON, so the config must survive
    // two `with_codec` calls.
    fn storage_with_changed_codec() -> PostgresStorage<String, JsonCodec<CompactType>> {
        storage_for_config("codec-api", 6)
            .with_codec::<()>()
            .with_codec::<JsonCodec<CompactType>>()
    }

    // Type name of a storage whose codec slot was replaced by `()`.
    fn type_name_after_with_codec() -> &'static str {
        let pool = unchecked_pool();
        let storage = PostgresStorage::<String>::new(&pool).with_codec::<()>();
        std::any::type_name_of_val(&storage)
    }

    // Passes when the type name shows `String` as the args type and `()` in a generic slot,
    // accepting either `", (),"` or `",()"` spacing.
    fn type_name_contains_unit_codec(result: &&'static str) -> AssertionResult {
        if result.contains("PostgresStorage")
            && result.contains("alloc::string::String")
            && (result.contains(", (),") || result.contains(",()"))
        {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "expected with_codec::<()> to switch the codec slot, got {result}"
            )]))
        }
    }

    // Reads (queue name, buffer size, pool connection count) through the public accessors.
    fn storage_accessors() -> (String, usize, u32) {
        let storage = storage_for_config("accessor-api", 8);
        (
            storage.config().queue().to_string(),
            storage.config().buffer_size(),
            storage.pool().state().connections,
        )
    }

    // `BackendExt::get_queue` on a default (polling) storage.
    fn basic_get_queue() -> String {
        storage_for_config("basic-queue-api", 4)
            .get_queue()
            .to_string()
    }

    // `BackendExt::get_queue` on a notify storage.
    fn notify_get_queue() -> String {
        notify_storage_for_config("notify-queue-api", 4)
            .get_queue()
            .to_string()
    }

    // Builds middleware, heartbeat and compact stream for one storage flavour and returns
    // their type names. The streams are only constructed, never polled.
    fn backend_trait_surfaces(notify: bool) -> (String, String, String) {
        let worker = WorkerContext::new::<()>("backend-trait-worker");
        if notify {
            let storage = notify_storage_for_config("notify-trait-api", 2);
            let middleware = std::any::type_name_of_val(&storage.middleware()).to_owned();
            let heartbeat = std::any::type_name_of_val(&storage.heartbeat(&worker)).to_owned();
            let stream = std::any::type_name_of_val(&storage.poll_compact(&worker)).to_owned();
            (middleware, heartbeat, stream)
        } else {
            let storage = storage_for_config("basic-trait-api", 2);
            let middleware = std::any::type_name_of_val(&storage.middleware()).to_owned();
            let heartbeat = std::any::type_name_of_val(&storage.heartbeat(&worker)).to_owned();
            let stream = std::any::type_name_of_val(&storage.poll_compact(&worker)).to_owned();
            (middleware, heartbeat, stream)
        }
    }

    // Checks the queue name and buffer size returned by `storage_accessors`.
    // NOTE(review): the connection count (`result.2`) is not asserted.
    fn exposes_accessors(result: &(String, usize, u32)) -> AssertionResult {
        if result.0 == "accessor-api" && result.1 == 8 {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected storage accessors: {result:?}"
            )]))
        }
    }

    // Loose check on type names: middleware is `PgMiddleware`, and both streams are
    // stream types.
    fn constructs_backend_traits(result: &(String, String, String)) -> AssertionResult {
        if result.0.contains("PgMiddleware")
            && result.1.contains("Stream")
            && result.2.contains("Stream")
        {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "unexpected backend trait surfaces: {result:?}"
            )]))
        }
    }

    lets_expect! {
        expect(crate_name()) {
            to returns_the_crate_name { equal("apalis-diesel-postgres") }
        }

        // Row -> compact task conversion: one happy path plus three rejection cases, each
        // overriding a single input.
        expect(row(id, status, run_at, idempotency_key).try_into_task_compact::<Ulid, PgPool>()) {
            let id = &Ulid::new().to_string();
            let status = "Pending";
            let run_at = Some(DateTime::now());
            let idempotency_key = Some("same-key");

            when row_has_all_required_fields {
                to preserves_task_payload_and_context { compact_task_has_expected_parts }
            }

            when run_time_is_missing {
                let run_at = None;
                to rejects_the_row { be_err_and column_not_found("run_at") }
            }

            when status_is_unknown {
                let status = "Unknown";
                to rejects_the_row { be_err_and decode_error }
            }

            when id_is_not_a_ulid {
                let id = "not-a-ulid";
                to rejects_the_row { be_err_and decode_error }
            }
        }

        expect(storage) {
            let storage = storage_for_type_name();

            // Expects buffer size 10, i.e. relies on `Config::new`'s default buffer size.
            when storage_is_built_from_the_task_type {
                to uses_the_type_name_as_queue {
                    storage_uses_queue_and_buffer(std::any::type_name::<String>(), 10)
                }
            }

            when storage_is_built_with_an_explicit_config {
                let storage = storage_for_config("public-api", 3);
                to preserves_the_supplied_config { storage_uses_queue_and_buffer("public-api", 3) }
            }

            when storage_is_built_with_notify {
                let storage = notify_storage_for_config("notify-api", 2);
                to preserves_the_supplied_config { storage_uses_queue_and_buffer("notify-api", 2) }
            }

            when storage_is_cloned {
                let storage = cloned_storage_for_config("clone-api", 4);
                to keeps_the_queue_configuration { storage_uses_queue_and_buffer("clone-api", 4) }
            }
        }

        // NOTE(review): despite the name, this only checks that the output mentions
        // "PostgresStorage" and "config"; it does not assert that the pool is absent.
        expect(debug_storage()) {
            to describes_the_storage_without_exposing_the_pool { debug_mentions_public_type }
        }

        expect(storage_with_changed_codec()) {
            to keeps_the_existing_pool_config_and_fetcher { storage_uses_queue_and_buffer("codec-api", 6) }
        }

        expect(type_name_after_with_codec()) {
            when with_codec_replaces_the_codec_type_parameter {
                to swaps_the_codec_slot_in_the_generic_signature { type_name_contains_unit_codec }
            }
        }

        expect(storage_accessors()) {
            to exposes_the_pool_and_config { exposes_accessors }
        }

        expect(basic_get_queue()) {
            to returns_the_basic_queue { equal("basic-queue-api".to_owned()) }
        }

        expect(notify_get_queue()) {
            to returns_the_notify_queue { equal("notify-queue-api".to_owned()) }
        }

        // The same surface check runs for the polling storage and the notify storage.
        expect(backend_trait_surfaces(notify)) {
            let notify = false;

            when basic_polling_storage {
                to builds_heartbeat_middleware_and_compact_stream { constructs_backend_traits }
            }

            when notify_storage {
                let notify = true;
                to builds_heartbeat_middleware_and_compact_stream { constructs_backend_traits }
            }
        }

        expect(task_id_alias_accepts_ulid(task_id)) {
            let task_id = PgTaskId::from_str(&Ulid::new().to_string()).expect("generated ULID parses");

            to accepts_the_public_task_id_alias { be_true }
        }
    }
}