#![doc = include_str!("../README.md")]
use std::{fmt::Debug, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, BackendExt, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    features_table,
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    StreamExt, TryFutureExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{PgPool, postgres::{PgConnectOptions, PgListener, Postgres}};
use ulid::Ulid;

pub use crate::{
    ack::{LockTaskLayer, PgAck},
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod config;
mod fetcher;
mod from_row;

/// The context attached to every Postgres-backed task.
pub type PgContext = apalis_sql::context::SqlContext;
pub use config::Config;

mod queries;
pub mod shared;
pub mod sink;

/// A [`Task`] as stored by this backend: arguments plus a [`PgContext`],
/// identified by a [`Ulid`].
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

/// The compact wire representation of task arguments: raw bytes, JSON-encoded
/// by default via [`JsonCodec`].
pub type CompactType = Vec<u8>;

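/// Marker fetcher that augments polling with Postgres `LISTEN`/`NOTIFY`.
///
/// Storages built with [`PostgresStorage::new_with_notify`] subscribe to the
/// `apalis::job::insert` channel and wake up as soon as a matching task is
/// inserted, instead of waiting for the next poll interval.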
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

#[doc = features_table! {
    setup = r#"
    # {
    # use apalis_postgres::PostgresStorage;
    # use sqlx::PgPool;
    # let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await.unwrap();
    # PostgresStorage::setup(&pool).await.unwrap();
    # PostgresStorage::new(&pool)
    # };
    "#,

    Backend => supported("Supports storage and retrieval of tasks", true),
    TaskSink => supported("Ability to push new tasks", true),
    Serialization => supported("Serialization support for arguments", true),
    Workflow => supported("Flexible enough to support workflows", true),
    WebUI => supported("Expose a web interface for monitoring tasks", true),
    FetchById => supported("Allow fetching a task by its ID", false),
    RegisterWorker => supported("Allow registering a worker with the backend", false),
    MakeShared => supported("Share one connection across multiple workers via [`SharedPostgresStorage`]", false),
    WaitForCompletion => supported("Wait for tasks to complete without blocking", true),
    ResumeById => supported("Resume a task by its ID", false),
    ResumeAbandoned => supported("Resume abandoned tasks", false),
    ListWorkers => supported("List all workers registered with the backend", false),
    ListTasks => supported("List all tasks in the backend", false),
}]
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
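    /// Runs the migrations bundled with this crate, creating the tables the
    /// storage relies on. A minimal sketch, assuming `DATABASE_URL` points at
    /// a reachable Postgres instance:
    ///
    /// ```no_run
    /// # async fn demo() -> Result<(), sqlx::Error> {
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await?;
    /// PostgresStorage::setup(&pool).await?;
    /// # Ok(())
    /// # }
    /// ```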
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

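    /// Returns the migrator holding this crate's embedded migrations, for
    /// callers that want to run or inspect them manually.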
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
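    /// Creates a storage whose queue name defaults to the type name of
    /// `Args`. A minimal sketch, assuming `DATABASE_URL` is set and using
    /// `u32` as a stand-in argument type:
    ///
    /// ```no_run
    /// # async fn demo() -> Result<(), sqlx::Error> {
    /// use apalis_postgres::{PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await?;
    /// let storage: PostgresStorage<u32> = PostgresStorage::new(&pool);
    /// # Ok(())
    /// # }
    /// ```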
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

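    /// Creates a storage with an explicit [`Config`], e.g. to control the
    /// queue name, buffer size, or polling strategy.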
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

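    /// Creates a storage that combines polling with `LISTEN`/`NOTIFY`
    /// wake-ups (see [`PgNotify`]). A minimal sketch, assuming `DATABASE_URL`
    /// is set and a hypothetical queue named `"emails"`:
    ///
    /// ```no_run
    /// # async fn demo() -> Result<(), sqlx::Error> {
    /// use apalis_postgres::{Config, PgPool, PostgresStorage};
    ///
    /// let pool = PgPool::connect(std::env::var("DATABASE_URL").unwrap().as_str()).await?;
    /// let config = Config::new("emails");
    /// let storage = PostgresStorage::<u32>::new_with_notify(&pool, &config);
    /// # Ok(())
    /// # }
    /// ```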
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

    /// Returns a reference to the underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns a reference to this storage's configuration.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
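    /// Swaps the codec used for task arguments, rebuilding the sink so that
    /// pushed tasks are encoded with `NewCodec` as well.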
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Keep this worker's heartbeat fresh while it runs.
        let keep_alive = keep_alive_stream(pool, config, worker);
        // Periodically re-enqueue tasks locked by workers that have gone silent.
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Fetch tasks in their compact form, then decode the arguments.
        self.poll_basic(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;
    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_basic(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
{
    fn poll_basic(&self, worker: &WorkerContext) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        // Register the worker with an initial heartbeat (emitting no task),
        // then hand the stream over to the polling fetcher.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<CompactType>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Context = PgContext;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Keep this worker's heartbeat fresh while it runs.
        let keep_alive = keep_alive_stream(pool, config, worker);
        // Periodically re-enqueue tasks locked by workers that have gone silent.
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Fetch tasks in their compact form, then decode the arguments.
        self.poll_with_notify(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode> BackendExt for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Compact = CompactType;
    type Codec = Decode;
    type CompactStream = TaskStream<PgTask<CompactType>, Self::Error>;

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_with_notify(worker).boxed()
    }
}

impl<Args, Decode> PostgresStorage<Args, CompactType, Decode, PgNotify> {
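    /// Builds the notify-driven task stream: after an initial heartbeat
    /// registers the worker, ids announced on `apalis::job::insert` are
    /// batched, fetched and locked for this worker, while the regular poller
    /// runs alongside as a fallback.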
    pub fn poll_with_notify(
        &self,
        worker: &WorkerContext,
    ) -> TaskStream<PgTask<CompactType>, sqlx::Error> {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        // Subscribe to insert notifications. Failing to connect or to issue
        // LISTEN is unrecoverable here and panics the stream.
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map_ok(|_| None);
        let register = stream::once(register_worker);
        // Keep only notifications for this queue, batch the announced ids,
        // then fetch and lock the matching rows for this worker.
        let lazy_fetcher = fetcher
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/queue_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|batch| match batch {
                // Flatten each fetched batch; a failed batch yields one error.
                Ok(tasks) => stream::iter(tasks).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // The polling fetcher still runs alongside, picking up tasks that
        // were inserted before the listener connected or whose notifications
        // were missed.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<CompactType>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

/// Payload of a NOTIFY message on the `apalis::job::insert` channel: the
/// queue (`job_type`) the new task belongs to and its id.
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: TaskId,
}

#[cfg(test)]
mod tests {
    use std::{
        collections::HashMap,
        env,
        time::{Duration, Instant},
    };

    use apalis_workflow::{Workflow, WorkflowSink};

    use apalis_core::{
        backend::poll_strategy::{IntervalStrategy, StrategyBuilder},
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        // Poll every 6s so only the NOTIFY path can deliver the task quickly.
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(6)))
                .build(),
        );
        let backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut b = backend.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_secs(2)).await;
            let mut items = stream::repeat_with(|| {
                Task::builder(42u32)
                    .with_ctx(PgContext::new().with_priority(1))
                    .build()
            })
            .take(1);
            b.push_all(&mut items).await.unwrap();
        });

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            wrk.stop().unwrap();
            Ok(())
        }

        let instant = Instant::now();
        let worker = WorkerBuilder::new("rango-tango-2")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
        let run_for = instant.elapsed();
        // The task is pushed after 2s; a worker relying on the 6s poll
        // interval could not finish within 4s.
        assert!(
            run_for < Duration::from_secs(4),
            "Worker did not use notify mechanism"
        );
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = Workflow::new("text-pipeline")
            .and_then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, BoxDynError>(processed)
            })
            .and_then(|text: String| async move {
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, BoxDynError>(results)
            })
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .and_then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let config = Config::new("test").with_poll_interval(
            StrategyBuilder::new()
                .apply(IntervalStrategy::new(Duration::from_secs(1)))
                .build(),
        );
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push_start(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}