use std::{fmt::Debug, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    FutureExt, StreamExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
pub use sqlx::{
    PgPool,
    postgres::{PgConnectOptions, PgListener, Postgres},
};
use ulid::Ulid;

use crate::{
    ack::{LockTaskLayer, PgAck},
    config::Config,
    context::PgContext,
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
pub mod config;
mod fetcher;
mod from_row {
    use chrono::{DateTime, Utc};

    /// Raw row shape returned by the task queries; every column is optional
    /// to match the nullability of the query output.
    #[derive(Debug)]
    pub struct PgTaskRow {
        pub job: Option<Vec<u8>>,
        pub id: Option<String>,
        pub job_type: Option<String>,
        pub status: Option<String>,
        pub attempts: Option<i32>,
        pub max_attempts: Option<i32>,
        pub run_at: Option<DateTime<Utc>>,
        pub last_result: Option<serde_json::Value>,
        pub lock_at: Option<DateTime<Utc>>,
        pub lock_by: Option<String>,
        pub done_at: Option<DateTime<Utc>>,
        pub priority: Option<i32>,
        pub metadata: Option<serde_json::Value>,
    }

    impl TryInto<apalis_sql::from_row::TaskRow> for PgTaskRow {
        type Error = sqlx::Error;

        fn try_into(self) -> Result<apalis_sql::from_row::TaskRow, Self::Error> {
            Ok(apalis_sql::from_row::TaskRow {
                job: self.job.unwrap_or_default(),
                id: self
                    .id
                    .ok_or_else(|| sqlx::Error::Protocol("Missing id".into()))?,
                job_type: self
                    .job_type
                    .ok_or_else(|| sqlx::Error::Protocol("Missing job_type".into()))?,
                status: self
                    .status
                    .ok_or_else(|| sqlx::Error::Protocol("Missing status".into()))?,
                attempts: self
                    .attempts
                    .ok_or_else(|| sqlx::Error::Protocol("Missing attempts".into()))?
                    as usize,
                max_attempts: self.max_attempts.map(|v| v as usize),
                run_at: self.run_at,
                last_result: self.last_result,
                lock_at: self.lock_at,
                lock_by: self.lock_by,
                done_at: self.done_at,
                priority: self.priority.map(|v| v as usize),
                metadata: self.metadata,
            })
        }
    }
}
pub mod context {
    pub type PgContext = apalis_sql::context::SqlContext;
}
mod queries;
pub mod shared;
pub mod sink;

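/// A task as stored in Postgres: user-provided `Args` plus the SQL execution
/// context, identified by a [`Ulid`].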
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

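/// The compact wire format tasks are serialized to; with the default
/// [`JsonCodec`] this holds JSON-encoded bytes.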
pub type CompactType = Vec<u8>;

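/// Marker type selecting the `LISTEN`/`NOTIFY`-driven fetcher used by
/// [`PostgresStorage::new_with_notify`].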
#[derive(Debug, Clone, Default)]
pub struct PgNotify {
    _private: PhantomData<()>,
}

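/// A Postgres-backed task backend.
///
/// Generic over the argument type, the compact wire format, the codec used to
/// (de)serialize arguments, and the fetch strategy: polling by default, or
/// `LISTEN`/`NOTIFY` via [`PostgresStorage::new_with_notify`].
///
/// # Example
///
/// A minimal sketch, assuming a reachable database at `DATABASE_URL`, the
/// `migrate` feature enabled for [`PostgresStorage::setup`], and that this
/// crate is used as `apalis_postgres`:
///
/// ```rust,no_run
/// use apalis_core::{
///     error::BoxDynError,
///     worker::{builder::WorkerBuilder, context::WorkerContext},
/// };
/// use apalis_postgres::{PgPool, PostgresStorage};
///
/// # async fn example() {
/// let pool = PgPool::connect(&std::env::var("DATABASE_URL").unwrap())
///     .await
///     .unwrap();
/// PostgresStorage::setup(&pool).await.unwrap();
/// let backend: PostgresStorage<u32> = PostgresStorage::new(&pool);
///
/// async fn handler(job: u32, _wrk: WorkerContext) -> Result<(), BoxDynError> {
///     println!("processing {job}");
///     Ok(())
/// }
///
/// let worker = WorkerBuilder::new("worker-1").backend(backend).build(handler);
/// worker.run().await.unwrap();
/// # }
/// ```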
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
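    /// Runs the bundled migrations against `pool`, creating the schema this
    /// storage relies on.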
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

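    /// Returns the migrator for the migrations shipped with this crate, for
    /// callers that want to run them as part of their own migration flow.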
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
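    /// Creates a storage with a default [`Config`] namespaced by the type
    /// name of `Args`.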
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

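    /// Creates a polling storage for the queue described by `config`.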
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

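    /// Creates a storage that pairs the regular poller with Postgres
    /// `LISTEN`/`NOTIFY`: inserts announced on the `apalis::job::insert`
    /// channel wake the worker without waiting for the next poll interval.
    ///
    /// A minimal sketch, assuming this crate is used as `apalis_postgres`:
    ///
    /// ```rust,no_run
    /// use apalis_postgres::{PgPool, PostgresStorage, config::Config};
    ///
    /// # async fn example() {
    /// let pool = PgPool::connect("postgres://localhost/apalis_dev")
    ///     .await
    ///     .unwrap();
    /// let config = Config::new("emails");
    /// let backend = PostgresStorage::<u32>::new_with_notify(&pool, &config);
    /// # }
    /// ```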
    pub fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgNotify> {
        let sink = PgSink::new(pool, config);

        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgNotify::default(),
            sink,
        }
    }

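    /// Returns a reference to the underlying connection pool.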
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

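    /// Returns the queue configuration this storage was built with.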
    pub fn config(&self) -> &Config {
        &self.config
    }
}

impl<Args, Compact, Codec, Fetcher> PostgresStorage<Args, Compact, Codec, Fetcher> {
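    /// Swaps the codec used to (de)serialize task arguments, rebuilding the
    /// sink for the new codec while keeping the pool, config and fetcher.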
    pub fn with_codec<NewCodec>(self) -> PostgresStorage<Args, Compact, NewCodec, Fetcher> {
        PostgresStorage {
            _marker: PhantomData,
            sink: PgSink::new(&self.pool, &self.config),
            pool: self.pool,
            config: self.config,
            fetcher: self.fetcher,
        }
    }
}

impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type Compact = CompactType;

    type IdType = Ulid;

    type Context = PgContext;

    type Codec = Decode;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Merge two periodic jobs: refreshing this worker's registration and
        // re-enqueueing tasks orphaned by workers that stopped heartbeating.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Register the worker once, then hand the stream over to the poller.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map(|_| Ok(None));
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<Args, CompactType, Decode>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgNotify>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type Compact = CompactType;

    type IdType = Ulid;

    type Context = PgContext;

    type Codec = Decode;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Same beat as the polling backend: keep the worker registration
        // fresh and reclaim orphaned tasks.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        // Subscribe to insert notifications; the stream cannot start without
        // the listener, so connection failures here panic.
        let listener = async move {
            let mut fetcher = PgListener::connect_with(&pool)
                .await
                .expect("Failed to create listener");
            fetcher.listen("apalis::job::insert").await.unwrap();
            fetcher
        };
        let fetcher = stream::once(listener).flat_map(|f| f.into_stream());
        let pool = self.pool.clone();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map(|_| Ok(None));
        let register = stream::once(register_worker);
        // Lazy path: react to NOTIFY payloads, keep only inserts for this
        // queue, then lock and fetch the notified tasks in batches.
        let lazy_fetcher = fetcher
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/lock_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task::<Decode, Args, Ulid>()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            .flat_map(|batch| match batch {
                Ok(tasks) => stream::iter(tasks).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        // Eager path: the regular poller still runs, picking up tasks that
        // were enqueued before this worker subscribed or whose notification
        // was missed.
        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<Args, CompactType, Decode>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

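/// The JSON payload carried by `apalis::job::insert` notifications,
/// presumably emitted by a trigger installed with the migrations, e.g.
/// `{"job_type": "emails", "id": "..."}`.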
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: TaskId,
}

#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env, time::Duration};

    use apalis_workflow::{WorkFlow, WorkflowError};

    use apalis_core::{
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let config = Config::new("test");
        let mut backend = PostgresStorage::new_with_notify(&pool, &config);

        let mut items = stream::repeat_with(|| {
            Task::builder(42u32)
                .with_ctx(PgContext::new().with_priority(1))
                .build()
        })
        .take(1);
        backend.push_all(&mut items).await.unwrap();

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        use apalis_core::backend::WeakTaskSink;
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = WorkFlow::new("text-pipeline")
            .then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, WorkflowError>(processed)
            })
            .then(|text: String| async move {
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, WorkflowError>(results)
            })
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or_else(|_| "postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let config = Config::new("test");
        let mut backend: PostgresStorage<Vec<u8>> =
            PostgresStorage::new_with_config(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}