use std::{fmt::Debug, marker::PhantomData};

use apalis_core::{
    backend::{
        Backend, TaskStream,
        codec::{Codec, json::JsonCodec},
    },
    layers::Stack,
    task::{Task, task_id::TaskId},
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::from_row::TaskRow;
use futures::{
    FutureExt, StreamExt, TryStreamExt,
    future::ready,
    stream::{self, BoxStream, select},
};
use serde::Deserialize;
use sqlx::{PgPool, postgres::PgListener};
use ulid::Ulid;

use crate::{
    ack::{LockTaskLayer, PgAck},
    config::Config,
    context::PgContext,
    fetcher::{PgFetcher, PgPollFetcher},
    queries::{
        keep_alive::{initial_heartbeat, keep_alive_stream},
        reenqueue_orphaned::reenqueue_orphaned_stream,
    },
    sink::PgSink,
};

mod ack;
mod config;
mod fetcher;
mod from_row {
    use chrono::{DateTime, Utc};

    /// Raw task row as selected from Postgres, with every column nullable.
    #[derive(Debug)]
    pub struct PgTaskRow {
        pub job: Option<Vec<u8>>,
        pub id: Option<String>,
        pub job_type: Option<String>,
        pub status: Option<String>,
        pub attempts: Option<i32>,
        pub max_attempts: Option<i32>,
        pub run_at: Option<DateTime<Utc>>,
        pub last_result: Option<serde_json::Value>,
        pub lock_at: Option<DateTime<Utc>>,
        pub lock_by: Option<String>,
        pub done_at: Option<DateTime<Utc>>,
        pub priority: Option<i32>,
        pub metadata: Option<serde_json::Value>,
    }

    impl TryInto<apalis_sql::from_row::TaskRow> for PgTaskRow {
        type Error = sqlx::Error;

        fn try_into(self) -> Result<apalis_sql::from_row::TaskRow, Self::Error> {
            Ok(apalis_sql::from_row::TaskRow {
                job: self.job.unwrap_or_default(),
                id: self
                    .id
                    .ok_or_else(|| sqlx::Error::Protocol("Missing id".into()))?,
                job_type: self
                    .job_type
                    .ok_or_else(|| sqlx::Error::Protocol("Missing job_type".into()))?,
                status: self
                    .status
                    .ok_or_else(|| sqlx::Error::Protocol("Missing status".into()))?,
                attempts: self
                    .attempts
                    .ok_or_else(|| sqlx::Error::Protocol("Missing attempts".into()))?
                    as usize,
                max_attempts: self.max_attempts.map(|v| v as usize),
                run_at: self.run_at,
                last_result: self.last_result,
                lock_at: self.lock_at,
                lock_by: self.lock_by,
                done_at: self.done_at,
                priority: self.priority.map(|v| v as usize),
                metadata: self.metadata,
            })
        }
    }
}
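// A minimal sketch of the conversion contract above: a row missing a
// required column such as `id` must surface as a `sqlx::Error::Protocol`
// rather than panic. (The field layout mirrors `PgTaskRow` as defined here.)
#[cfg(test)]
mod from_row_tests {
    use super::from_row::PgTaskRow;

    #[test]
    fn missing_id_is_a_protocol_error() {
        let row = PgTaskRow {
            job: None,
            id: None,
            job_type: Some("test".to_owned()),
            status: Some("Pending".to_owned()),
            attempts: Some(0),
            max_attempts: None,
            run_at: None,
            last_result: None,
            lock_at: None,
            lock_by: None,
            done_at: None,
            priority: None,
            metadata: None,
        };
        let res: Result<apalis_sql::from_row::TaskRow, _> = row.try_into();
        assert!(res.is_err());
    }
}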
mod context {
    pub type PgContext = apalis_sql::context::SqlContext;
}
mod queries;
pub mod shared;
pub mod sink;

/// A Postgres task: arguments plus the SQL context, keyed by a [`Ulid`].
pub type PgTask<Args> = Task<Args, PgContext, Ulid>;

/// The compact wire format tasks are encoded to by default.
pub type CompactType = Vec<u8>;

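/// A Postgres-backed task storage.
///
/// `Args` is the job payload type, `Compact` the encoded wire format,
/// `Codec` the encoder/decoder between the two, and `Fetcher` the strategy
/// used to pull tasks: the polling [`PgFetcher`] by default, or a
/// [`PgListener`] when built via [`PostgresStorage::new_with_notify`].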
#[pin_project::pin_project]
pub struct PostgresStorage<
    Args,
    Compact = CompactType,
    Codec = JsonCodec<CompactType>,
    Fetcher = PgFetcher<Args, Compact, Codec>,
> {
    _marker: PhantomData<(Args, Compact, Codec)>,
    pool: PgPool,
    config: Config,
    #[pin]
    fetcher: Fetcher,
    #[pin]
    sink: PgSink<Args, Compact, Codec>,
}

impl<Args, Compact, Codec, Fetcher: Clone> Clone
    for PostgresStorage<Args, Compact, Codec, Fetcher>
{
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
            pool: self.pool.clone(),
            config: self.config.clone(),
            fetcher: self.fetcher.clone(),
            sink: self.sink.clone(),
        }
    }
}

impl PostgresStorage<(), (), ()> {
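    /// Sets up the storage by running the bundled migrations on the given
    /// pool. A minimal usage sketch (the connection string is illustrative):
    ///
    /// ```rust,ignore
    /// let pool = sqlx::PgPool::connect("postgres://localhost/apalis_dev").await?;
    /// PostgresStorage::setup(&pool).await?;
    /// ```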
    #[cfg(feature = "migrate")]
    pub async fn setup(pool: &PgPool) -> Result<(), sqlx::Error> {
        Self::migrations().run(pool).await?;
        Ok(())
    }

    /// Returns the migrator for the migrations bundled with this crate.
    #[cfg(feature = "migrate")]
    pub fn migrations() -> sqlx::migrate::Migrator {
        sqlx::migrate!("./migrations")
    }
}

impl<Args> PostgresStorage<Args> {
    /// Creates a storage for `Args` on the given pool, with a queue named
    /// after the argument type.
    pub fn new(pool: &PgPool) -> Self {
        let config = Config::new(std::any::type_name::<Args>());
        Self::new_with_config(pool, &config)
    }

    /// Creates a storage on the given pool with an explicit [`Config`].
    pub fn new_with_config(pool: &PgPool, config: &Config) -> Self {
        let sink = PgSink::new(pool, config);
        Self {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher: PgFetcher {
                _marker: PhantomData,
            },
            sink,
        }
    }

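    /// Creates a storage that is woken by Postgres `LISTEN`/`NOTIFY` on the
    /// `apalis::job::insert` channel, alongside the regular polling fallback.
    ///
    /// A minimal usage sketch (`Email` and the queue name are placeholders):
    ///
    /// ```rust,ignore
    /// let config = Config::new("emails");
    /// let backend = PostgresStorage::<Email>::new_with_notify(&pool, &config).await;
    /// ```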
    pub async fn new_with_notify(
        pool: &PgPool,
        config: &Config,
    ) -> PostgresStorage<Args, CompactType, JsonCodec<CompactType>, PgListener> {
        let sink = PgSink::new(pool, config);
        let mut fetcher = PgListener::connect_with(pool)
            .await
            .expect("Failed to create listener");
        fetcher
            .listen("apalis::job::insert")
            .await
            .expect("Failed to listen on apalis::job::insert");
        PostgresStorage {
            _marker: PhantomData,
            pool: pool.clone(),
            config: config.clone(),
            fetcher,
            sink,
        }
    }

    /// Returns a reference to the underlying connection pool.
    pub fn pool(&self) -> &PgPool {
        &self.pool
    }

    /// Returns a reference to this storage's configuration.
    pub fn config(&self) -> &Config {
        &self.config
    }
}

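// Polling backend: tasks are pulled by the `PgPollFetcher` stream, while the
// heartbeat keeps this worker's registration fresh and re-enqueues tasks
// orphaned by workers that stopped beating.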
impl<Args, Decode> Backend
    for PostgresStorage<Args, CompactType, Decode, PgFetcher<Args, CompactType, Decode>>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type Compact = CompactType;

    type IdType = Ulid;

    type Context = PgContext;

    type Codec = Decode;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        // Periodically refresh this worker's registration, and in parallel
        // reclaim tasks whose workers have missed their keep-alive window.
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        // Register the worker once, then hand the stream over to the poller.
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorage",
        )
        .map(|_| Ok(None));
        let register = stream::once(register_worker);
        register
            .chain(PgPollFetcher::<Args, CompactType, Decode>::new(
                &self.pool,
                &self.config,
                worker,
            ))
            .boxed()
    }
}

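// Notify-driven backend: `NOTIFY` payloads on `apalis::job::insert` wake a
// lazy fetcher that locks the announced task ids in batches, while the
// regular poll fetcher runs alongside as an eager fallback; `select` merges
// the two streams so a missed notification only delays a task until the
// next poll.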
impl<Args, Decode> Backend for PostgresStorage<Args, CompactType, Decode, PgListener>
where
    Args: Send + 'static + Unpin,
    Decode: Codec<Args, Compact = CompactType> + 'static + Send,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type Compact = CompactType;

    type IdType = Ulid;

    type Context = PgContext;

    type Codec = Decode;

    type Error = sqlx::Error;

    type Stream = TaskStream<PgTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<PgAck>>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let pool = self.pool.clone();
        let config = self.config.clone();
        let worker = worker.clone();
        let keep_alive = keep_alive_stream(pool, config, worker);
        let reenqueue = reenqueue_orphaned_stream(
            self.pool.clone(),
            self.config.clone(),
            *self.config.keep_alive(),
        )
        .map_ok(|_| ());
        futures::stream::select(keep_alive, reenqueue).boxed()
    }

    fn middleware(&self) -> Self::Layer {
        Stack::new(
            LockTaskLayer::new(self.pool.clone()),
            AcknowledgeLayer::new(PgAck::new(self.pool.clone())),
        )
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        let pool = self.pool.clone();
        let worker_id = worker.name().to_owned();
        let namespace = self.config.queue().to_string();
        let register_worker = initial_heartbeat(
            self.pool.clone(),
            self.config.clone(),
            worker.clone(),
            "PostgresStorageWithNotify",
        )
        .map(|_| Ok(None));
        let register = stream::once(register_worker);
        let lazy_fetcher = self
            .fetcher
            .into_stream()
            // Keep only insert events addressed to this queue.
            .filter_map(move |notification| {
                let namespace = namespace.clone();
                async move {
                    let pg_notification = notification.ok()?;
                    let payload = pg_notification.payload();
                    let ev: InsertEvent = serde_json::from_str(payload).ok()?;

                    if ev.job_type == namespace {
                        return Some(ev.id);
                    }
                    None
                }
            })
            .map(|t| t.to_string())
            // Batch ids so a burst of notifications becomes a single query.
            .ready_chunks(self.config.buffer_size())
            .then(move |ids| {
                let pool = pool.clone();
                let worker_id = worker_id.clone();
                async move {
                    let mut tx = pool.begin().await?;
                    use crate::from_row::PgTaskRow;
                    let res: Vec<_> = sqlx::query_file_as!(
                        PgTaskRow,
                        "queries/task/lock_by_id.sql",
                        &ids,
                        &worker_id
                    )
                    .fetch(&mut *tx)
                    .map(|r| {
                        let row: TaskRow = r?.try_into()?;
                        Ok(Some(
                            row.try_into_task::<Decode, Args, Ulid>()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))?,
                        ))
                    })
                    .collect()
                    .await;
                    tx.commit().await?;
                    Ok::<_, sqlx::Error>(res)
                }
            })
            // Flatten each locked batch back into a stream of tasks,
            // surfacing a batch-level error as a single stream item.
            .flat_map(|batch| match batch {
                Ok(tasks) => stream::iter(tasks).boxed(),
                Err(e) => stream::once(ready(Err(e))).boxed(),
            })
            .boxed();

        let eager_fetcher = StreamExt::boxed(PgPollFetcher::<Args, CompactType, Decode>::new(
            &self.pool,
            &self.config,
            worker,
        ));
        register.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

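/// Payload of an `apalis::job::insert` notification, e.g. (shape assumed
/// from the deserializer below):
/// `{"job_type": "my-queue", "id": "01ARZ3NDEKTSV4RRFFQ69G5FAV"}`.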
#[derive(Debug, Deserialize)]
pub struct InsertEvent {
    job_type: String,
    id: TaskId,
}

#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env, time::Duration};

    use apalis_workflow::{WorkFlow, WorkflowError};

    use apalis_core::{
        error::BoxDynError,
        task::data::Data,
        worker::{builder::WorkerBuilder, event::Event, ext::event_listener::EventListenerExt},
    };
    use serde::{Deserialize, Serialize};

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or("postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let mut backend = PostgresStorage::new(&pool);

        let mut items = stream::repeat_with(HashMap::default).take(1);
        backend.push_stream(&mut items).await.unwrap();

        async fn send_reminder(
            _: HashMap<String, String>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn notify_worker() {
        use apalis_core::backend::TaskSink;
        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or("postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let config = Config::new("test");
        let mut backend = PostgresStorage::new_with_notify(&pool, &config).await;

        let mut items = stream::repeat_with(|| {
            Task::builder(42u32)
                .with_ctx(PgContext::new().with_priority(1))
                .build()
        })
        .take(1);
        backend.push_all(&mut items).await.unwrap();

        async fn send_reminder(_: u32, wrk: WorkerContext) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let worker = WorkerBuilder::new("rango-tango-1")
            .backend(backend)
            .build(send_reminder);
        worker.run().await.unwrap();
    }

    #[tokio::test]
    async fn test_workflow_complete() {
        use apalis_core::backend::WeakTaskSink;
        #[derive(Debug, Serialize, Deserialize, Clone)]
        struct PipelineConfig {
            min_confidence: f32,
            enable_sentiment: bool,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct UserInput {
            text: String,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Classified {
            text: String,
            label: String,
            confidence: f32,
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Summary {
            text: String,
            sentiment: Option<String>,
        }

        let workflow = WorkFlow::new("text-pipeline")
            .then(|input: UserInput, mut worker: WorkerContext| async move {
                worker.emit(&Event::Custom(Box::new(format!(
                    "Preprocessing input: {}",
                    input.text
                ))));
                let processed = input.text.to_lowercase();
                Ok::<_, WorkflowError>(processed)
            })
            .then(|text: String| async move {
                let confidence = 0.85;
                let items = text.split_whitespace().collect::<Vec<_>>();
                let results = items
                    .into_iter()
                    .map(|x| Classified {
                        text: x.to_string(),
                        label: if x.contains("rust") {
                            "Tech"
                        } else {
                            "General"
                        }
                        .to_string(),
                        confidence,
                    })
                    .collect::<Vec<_>>();
                Ok::<_, WorkflowError>(results)
            })
            .filter_map(
                |c: Classified| async move { if c.confidence >= 0.6 { Some(c) } else { None } },
            )
            .filter_map(move |c: Classified, config: Data<PipelineConfig>| {
                let cfg = config.enable_sentiment;
                async move {
                    if !cfg {
                        return Some(Summary {
                            text: c.text,
                            sentiment: None,
                        });
                    }

                    let sentiment = if c.text.contains("delightful") {
                        "positive"
                    } else {
                        "neutral"
                    };
                    Some(Summary {
                        text: c.text,
                        sentiment: Some(sentiment.to_string()),
                    })
                }
            })
            .then(|a: Vec<Summary>, mut worker: WorkerContext| async move {
                dbg!(&a);
                worker.emit(&Event::Custom(Box::new(format!(
                    "Generated {} summaries",
                    a.len()
                ))));
                worker.stop()
            });

        let pool = PgPool::connect(
            env::var("DATABASE_URL")
                .unwrap_or("postgres://postgres:postgres@localhost/apalis_dev".to_owned())
                .as_str(),
        )
        .await
        .unwrap();
        let config = Config::new("test");
        let mut backend: PostgresStorage<Vec<u8>> =
            PostgresStorage::new_with_config(&pool, &config);

        let input = UserInput {
            text: "Rust makes systems programming delightful!".to_string(),
        };
        backend.push(input).await.unwrap();

        let worker = WorkerBuilder::new("rango-tango")
            .backend(backend)
            .data(PipelineConfig {
                min_confidence: 0.8,
                enable_sentiment: true,
            })
            .on_event(|ctx, ev| match ev {
                Event::Custom(msg) => {
                    if let Some(m) = msg.downcast_ref::<String>() {
                        println!("Custom Message: {m}");
                    }
                }
                Event::Error(_) => {
                    println!("On Error = {ev:?}");
                    ctx.stop().unwrap();
                }
                _ => {
                    println!("On Event = {ev:?}");
                }
            })
            .build(workflow);
        worker.run().await.unwrap();
    }
}