#![doc = include_str!("../README.md")]
use std::{fmt::Debug, marker::PhantomData};

use apalis_codec::json::JsonCodec;
use apalis_core::{
    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
pub use apalis_sql::{config::Config, from_row::TaskRow};
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{PgPool, postgres::PgConnectOptions, postgres::PgListener, postgres::Postgres};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod fetcher;
mod from_row;
mod queries;
pub mod shared;
pub mod sink;

/// Task context for Postgres-backed tasks, provided by [`apalis_sql`]'s `SqlContext`.
pub type PgContext = apalis_sql::context::SqlContext<PgPool>;

/// A task stored in Postgres: `Args` plus a [`PgContext`], identified by a [`Ulid`].
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

/// The identifier type for Postgres-backed tasks, a [`Ulid`]-based [`TaskId`].
pub type PgTaskId = TaskId<Ulid>;

/// The compact on-the-wire representation of task arguments (JSON bytes with the default codec).
pub type CompactType = Vec<u8>;

#[doc = features_table! {
    setup = r#"
    # {
    # use apalis_postgres::PostgresStorage;
    # use sqlx::PgPool;
    # let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
    # PostgresStorage::setup(&pool).await.unwrap();
    # PostgresStorage::new(&pool)
    # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
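/// A task storage backed by Postgres, generic over the argument type `Args`,
/// the compact wire format, the codec used to encode arguments, and the
/// fetch strategy.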
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

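/// Marker fetcher type that selects the `LISTEN/NOTIFY`-driven fetch strategy
/// for [`PostgresStorage`]; see [`PostgresStorage::new_with_notify`].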
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
    #[cfg(feature = "migrate")]
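    /// Prepares the database by running this crate's bundled migrations.
    ///
    /// A minimal sketch of one-time setup (assumes `DATABASE_URL` points at a
    /// reachable Postgres instance):
    ///
    /// ```no_run
    /// # async fn example() -> Result<(), sqlx::Error> {
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await?;
    /// PostgresStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```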
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

    #[cfg(feature = "migrate")]
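    /// Returns the embedded [`sqlx::migrate::Migrator`], for callers that want
    /// to run or compose the migrations themselves.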
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
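    /// Creates a storage for `Args` on the given pool, deriving the queue name
    /// from `Args`'s type name.
    ///
    /// A short sketch (assumes [`PostgresStorage::setup`] has already run
    /// against this pool):
    ///
    /// ```no_run
    /// # fn example(pool: apalis_postgres::PgPool) {
    /// use apalis_postgres::PostgresStorage;
    ///
    /// let storage: PostgresStorage<String> = PostgresStorage::new(&pool);
    /// # }
    /// ```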
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

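    /// Like [`PostgresStorage::new`], but with an explicit [`Config`] that
    /// controls the queue name and polling behaviour.
    ///
    /// A hedged sketch; the `"emails"` queue name is illustrative:
    ///
    /// ```no_run
    /// # fn example(pool: apalis_postgres::PgPool) {
    /// use apalis_postgres::{Config, PostgresStorage};
    ///
    /// let config = Config::new("emails");
    /// let storage: PostgresStorage<String> = PostgresStorage::new_with_config(&pool, &config);
    /// # }
    /// ```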
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

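    /// Builds a storage whose fetcher is driven by Postgres `LISTEN/NOTIFY` on
    /// the `apalis::job::insert` channel, so freshly pushed tasks are picked up
    /// without waiting for the next poll interval. A regular poll fetcher still
    /// runs alongside as a fallback; see [`PostgresStorage::poll_with_notify`].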
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

    /// Returns a reference to the underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns the [`Config`] this storage was built with.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
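    /// Swaps the codec used to encode and decode task arguments. The new codec
    /// must target the same `Compact` representation, since the sink is rebuilt
    /// over the existing pool and config.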
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically renew this worker's lease and re-enqueue tasks whose
        // workers have stopped heartbeating.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Fetch compact (encoded) tasks, then decode the arguments.
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        // Register the worker with an initial heartbeat before streaming tasks.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Same decode step as the polling backend, but over the notify stream.
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
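    /// Streams tasks by combining a `LISTEN/NOTIFY` subscription with the
    /// regular poll fetcher. Notifications on `apalis::job::insert` are
    /// filtered down to this queue, batched up to the configured buffer size,
    /// and the matching tasks are locked and fetched for this worker inside a
    /// single transaction; the poll fetcher keeps running as a fallback.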
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        // Subscribe to insert notifications on a pooled connection.
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        let lazy_fetcher = fetcher
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    // Keep only insert events that target this queue.
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    // Lock and fetch the notified tasks for this worker in a
                    // single transaction.
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

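/// Payload carried by `apalis::job::insert` notifications: the queue the task
/// was pushed to and the task's id.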
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: PgTaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use the notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            .and_then(|text: String| async move {
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}
672}