#![doc = include_str!("../README.md")]
use std::{fmt::Debug, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, BackendExt, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{
    PgPool,
    postgres::{PgConnectOptions, PgListener, Postgres},
};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod config;
mod fetcher;
mod from_row;

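/// The SQL task context used by this backend, re-exported from `apalis_sql`.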
pub type PgContext = apalis_sql::context::SqlContext;
pub use config::Config;

mod queries;
pub mod shared;
pub mod sink;

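/// A task stored in Postgres: arguments of type `Args`, a [`PgContext`], and a
/// [`Ulid`] task id.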
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

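/// The compact representation tasks are stored as; the default [`JsonCodec`]
/// encodes arguments to JSON bytes.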
pub type CompactType = Vec<u8>;

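/// Marker fetcher that listens for Postgres `NOTIFY` events (in addition to
/// polling) so tasks are picked up as soon as they are inserted.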
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

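/// A task storage backed by Postgres.
///
/// Generic over the argument type, the compact wire format, the codec used to
/// (de)serialize arguments, and the fetch strategy (polling by default,
/// `LISTEN`/`NOTIFY` via [`PgNotify`]).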
#[doc = features_table! {
    setup = r#"
    # {
    # use apalis_postgres::PostgresStorage;
    # use sqlx::PgPool;
    # let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
    # PostgresStorage::setup(&pool).await.unwrap();
    # PostgresStorage::new(&pool)
    # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for task arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Exposes a web interface for monitoring tasks", true),
    FetchById => supported("Allows fetching a task by its ID", false),
    RegisterWorker => supported("Allows registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
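    /// Runs the bundled SQL migrations, creating the tables this storage needs.
    ///
    /// A minimal sketch, assuming `DATABASE_URL` points at a reachable Postgres
    /// instance:
    ///
    /// ```no_run
    /// # async fn demo() -> Result<(), sqlx::Error> {
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap()).await?;
    /// PostgresStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```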
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

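    /// Returns the migrator for this crate's bundled migrations, for callers
    /// who prefer to run them alongside their own.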
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
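    /// Creates a storage for `Args` with a default [`Config`] namespaced by the
    /// type name of `Args`.
    ///
    /// A minimal sketch, assuming [`PostgresStorage::setup`] has already been run:
    ///
    /// ```no_run
    /// # fn demo(pool: apalis_postgres::PgPool) {
    /// use apalis_postgres::PostgresStorage;
    ///
    /// let storage = PostgresStorage::<u32>::new(&pool);
    /// # }
    /// ```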
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

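    /// Like [`PostgresStorage::new`], but with an explicit [`Config`] controlling
    /// the queue name, poll strategy, buffer size, and keep-alive interval.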
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

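    /// Creates a storage that pairs the regular poller with Postgres
    /// `LISTEN`/`NOTIFY`, so freshly pushed tasks are dispatched without waiting
    /// for the next poll interval.
    ///
    /// A minimal sketch, assuming migrations have been run:
    ///
    /// ```no_run
    /// # fn demo(pool: apalis_postgres::PgPool) {
    /// use apalis_postgres::{Config, PostgresStorage};
    ///
    /// let config = Config::new("emails");
    /// let storage = PostgresStorage::<u32>::new_with_notify(&pool, &config);
    /// # }
    /// ```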
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

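    /// Returns a reference to the underlying connection pool.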
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

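    /// Returns the configuration this storage was built with.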
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
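    /// Rebuilds this storage with a different argument codec, keeping the same
    /// pool, config, and fetcher. The sink is reconstructed so that it encodes
    /// with the new codec.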
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically refresh this worker's keep-alive record...
        let keep_alive = keep_alive_stream(pool, config, worker);
        // ...while also re-enqueuing tasks orphaned by workers that missed
        // their keep-alive deadline.
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Fetch tasks in their compact form, then decode the arguments.
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
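    /// Polls for due tasks, prepending a one-shot registration heartbeat so the
    /// worker is visible to the backend before its first fetch.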
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        // Announce the worker once so it is registered before the first fetch.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Same heartbeat as the polling backend: keep-alive plus orphan
        // re-enqueueing.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Fetch tasks in their compact form, then decode the arguments.
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;
    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
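    /// Builds the notify-backed task stream: a `LISTEN`-driven fetcher (woken by
    /// `apalis::job::insert` notifications) merged with the regular poll fetcher
    /// as a fallback, prefixed by a one-shot worker registration heartbeat.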
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        // Set up a dedicated LISTEN connection. Failing to establish it is
        // unrecoverable for this stream, hence the panics.
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        // Announce the worker once so it is registered before the first fetch.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        let lazy_fetcher = fetcher
            // Keep only insert events for this queue's namespace.
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            // Batch notified ids so one query can lock several tasks at once.
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec.into_iter().map(|res| match res {
                    Ok(t) => Ok(t),
                    Err(e) => Err(e),
                }))
                .boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // Fall back to regular polling so tasks are not missed if a
        // notification is dropped.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

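/// Payload of the `apalis::job::insert` notification emitted when a task is
/// inserted, e.g. `{"job_type": "emails", "id": "01H..."}` (the values shown
/// here are illustrative; the payload is produced by the SQL migrations).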
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: TaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        // Push a single task, then let the handler stop the worker.
        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        // Poll only every 6s so that a quick pickup can only come from NOTIFY.
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        // Push a task 2s in, well before the first poll would fire.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            // Step 1: preprocess the raw input.
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            // Step 2: classify each word.
            .and_then(|text: String| async move {
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            // Step 3: drop low-confidence classifications.
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            // Step 4: optionally attach a sentiment, based on injected config.
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            // Step 5: report and stop the worker.
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}