#![doc = include_str!("../README.md")]
use std::{fmt::Debug, marker::PhantomData};

use apalis_codec::json::JsonCodec;
use apalis_core::{
    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue},
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
pub use apalis_sql::{config::Config, from_row::TaskRow};
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{PgPool, postgres::PgConnectOptions, postgres::PgListener, postgres::Postgres};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod fetcher;
mod from_row;

/// The context attached to every task stored in Postgres.
pub type PgContext = apalis_sql::context::SqlContext<PgPool>;
mod queries;
pub mod shared;
pub mod sink;

/// A task as stored and retrieved by the Postgres backend.
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

/// The identifier type used for Postgres-backed tasks.
pub type PgTaskId = TaskId<Ulid>;

/// The compact (encoded) representation that task arguments are stored as.
pub type CompactType = Vec<u8>;

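/// A task storage backed by Postgres.
///
/// Tasks are fetched by polling the database by default; construct the storage
/// with [`PostgresStorage::new_with_notify`] to also wake workers via
/// `LISTEN`/`NOTIFY` on `apalis::job::insert`.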
#[doc = features_table! {
    setup = r#"
    # {
    # use apalis_postgres::PostgresStorage;
    # use sqlx::PgPool;
    # let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
    # PostgresStorage::setup(&pool).await.unwrap();
    # PostgresStorage::new(&pool)
    # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

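/// Marker fetcher that makes [`PostgresStorage`] wake up on Postgres
/// `LISTEN`/`NOTIFY` events instead of relying on polling alone.
/// See [`PostgresStorage::new_with_notify`].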
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
    #[cfg(feature = "migrate")]
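    /// Run this crate's bundled migrations against the given pool.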
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

    #[cfg(feature = "migrate")]
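    /// The [`sqlx::migrate::Migrator`] holding this crate's migrations.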
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
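    /// Create a storage for `Args` with a default [`Config`] namespaced by the
    /// type name of `Args`.
    ///
    /// A minimal sketch, assuming `DATABASE_URL` points at a database that has
    /// had [`PostgresStorage::setup`] run against it:
    ///
    /// ```no_run
    /// # async fn example() {
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
    ///     .await
    ///     .unwrap();
    /// let storage = PostgresStorage::<u32>::new(&pool);
    /// # }
    /// ```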
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

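    /// Create a storage for `Args` with an explicit [`Config`], e.g. to set the
    /// queue name or poll interval.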
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

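    /// Create a storage whose workers are woken by `LISTEN`/`NOTIFY` on
    /// `apalis::job::insert`, with polling as a fallback.
    ///
    /// A sketch under the same assumptions as [`PostgresStorage::new`]; the
    /// `"emails"` queue name is a placeholder:
    ///
    /// ```no_run
    /// # async fn example() {
    /// use apalis_postgres::{Config, PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
    ///     .await
    ///     .unwrap();
    /// let config = Config::new("emails");
    /// let storage = PostgresStorage::<u32>::new_with_notify(&pool, &config);
    /// # }
    /// ```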
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

    /// The underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// The configuration this storage was built with.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
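    /// Swap the codec used to encode and decode task arguments, keeping the
    /// pool, config and fetcher. Note that this rebuilds the sink.
    ///
    /// A sketch, assuming a hypothetical `MyCodec` type that implements the
    /// required [`Codec`] trait:
    ///
    /// ```ignore
    /// let storage = PostgresStorage::<u32>::new(&pool).with_codec::<MyCodec>();
    /// ```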
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically renew this worker's lease and re-enqueue tasks whose
        // workers have stopped heartbeating.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Decode the compact task payloads into `Args`.
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        // Register the worker first, then hand the stream over to the poller.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;

    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn get_queue(&self) -> Queue {
        self.config.queue().clone()
    }

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
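    /// Build the task stream for the notify-backed storage: an initial
    /// registration heartbeat, followed by a `select` between tasks fetched in
    /// response to `apalis::job::insert` notifications and tasks found by the
    /// regular poller.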
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        // Lazy path: react to insert notifications for this queue, batch the
        // incoming ids, then lock and fetch them in one transaction.
        let lazy_fetcher = fetcher
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    use crate::from_row::PgTaskRow;
                    let mut tx = pool.begin().await?;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|vec| match vec {
                Ok(vec) => stream::iter(vec).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // Eager path: the regular poller, which also picks up tasks whose
        // notifications were missed.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

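/// The JSON payload carried by an `apalis::job::insert` notification.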
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: PgTaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        // Poll slowly so that only a LISTEN/NOTIFY wakeup can deliver the task
        // within the asserted window.
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            .and_then(|text: String| async move {
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(env::var("DATABASE_URL").unwrap().as_str())
            .await
            .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}