//! apalis_postgres/sink.rs — buffered Postgres [`Sink`] for enqueueing tasks.

1use std::{
2    pin::Pin,
3    sync::Arc,
4    task::{Context, Poll},
5};
6
7use apalis_core::backend::codec::json::JsonCodec;
8use chrono::{DateTime, Utc};
9use futures::{
10    FutureExt, Sink, TryFutureExt,
11    future::{BoxFuture, Shared},
12};
13use sqlx::{Executor, PgPool};
14use ulid::Ulid;
15
16use crate::{CompactType, PgTask, PostgresStorage, config::Config};
17
/// Boxed future for one flush batch. The error is wrapped in `Arc` so the
/// future's output is `Clone`, which `Shared` requires (`sqlx::Error` itself
/// is not cloneable).
type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
19
/// A buffered [`Sink`] over Postgres.
///
/// Tasks handed to the sink accumulate in `buffer` and are written to the
/// database in a single multi-row INSERT when the sink is flushed.
#[pin_project::pin_project]
pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
    // Connection pool used to run the flush transaction.
    pool: PgPool,
    // Queue configuration (supplies the job type / queue name).
    config: Config,
    // Tasks awaiting the next flush.
    buffer: Vec<PgTask<Compact>>,
    // In-flight flush, `None` when no flush is running. `Shared` so repeated
    // polls observe the same completed result.
    #[pin]
    flush_future: Option<Shared<FlushFuture>>,
    // Carries the `Args`/`Codec` type parameters without storing values.
    _marker: std::marker::PhantomData<(Args, Codec)>,
}
29
// Manual `Clone` instead of `#[derive]`: it avoids a `Compact: Clone` bound,
// and deliberately gives the clone an EMPTY buffer and no in-flight flush —
// pending tasks stay with the original sink and are never duplicated.
impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
    fn clone(&self) -> Self {
        Self {
            pool: self.pool.clone(),
            config: self.config.clone(),
            // Fresh state: buffered tasks are not shared between clones.
            buffer: Vec::new(),
            flush_future: None,
            _marker: std::marker::PhantomData,
        }
    }
}
41
42pub fn push_tasks<'a, E>(
43    conn: E,
44    cfg: Config,
45    buffer: Vec<PgTask<CompactType>>,
46) -> impl futures::Future<Output = Result<(), sqlx::Error>> + Send + 'a
47where
48    E: Executor<'a, Database = sqlx::Postgres> + Send + 'a,
49{
50    let job_type = cfg.queue().to_string();
51    // Build the multi-row INSERT with UNNEST
52    let mut ids = Vec::new();
53    let mut job_data = Vec::new();
54    let mut run_ats = Vec::new();
55    let mut priorities = Vec::new();
56    let mut max_attempts_vec = Vec::new();
57    let mut metadata = Vec::new();
58
59    for task in buffer {
60        ids.push(
61            task.parts
62                .task_id
63                .map(|id| id.to_string())
64                .unwrap_or(Ulid::new().to_string()),
65        );
66        job_data.push(task.args);
67        run_ats.push(DateTime::from_timestamp(task.parts.run_at as i64, 0).unwrap_or(Utc::now()));
68        priorities.push(task.parts.ctx.priority());
69        max_attempts_vec.push(task.parts.ctx.max_attempts());
70        metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
71    }
72
73    sqlx::query_file!(
74        "queries/task/sink.sql",
75        &ids,
76        &job_type,
77        &job_data,
78        &max_attempts_vec,
79        &run_ats,
80        &priorities,
81        &metadata
82    )
83    .execute(conn)
84    .map_ok(|_| ())
85    .boxed()
86}
87
88impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
89    pub fn new(pool: &PgPool, config: &Config) -> Self {
90        Self {
91            pool: pool.clone(),
92            config: config.clone(),
93            buffer: Vec::new(),
94            _marker: std::marker::PhantomData,
95            flush_future: None,
96        }
97    }
98}
99
100impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
101    for PostgresStorage<Args, CompactType, Encode, Fetcher>
102where
103    Args: Unpin + Send + Sync + 'static,
104    Fetcher: Unpin,
105{
106    type Error = sqlx::Error;
107
108    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        Poll::Ready(Ok(()))
110    }
111
112    fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
113        // Add the item to the buffer
114        self.get_mut().sink.buffer.push(item);
115        Ok(())
116    }
117
118    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119        let this = self.get_mut();
120
121        // If there's no existing future and buffer is empty, we're done
122        if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
123            return Poll::Ready(Ok(()));
124        }
125
126        // Create the future only if we don't have one and there's work to do
127        if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
128            let config = this.config.clone();
129            let buffer = std::mem::take(&mut this.sink.buffer);
130            let pool = this.sink.pool.clone();
131            let fut = async move {
132                let mut conn = pool.begin().map_err(Arc::new).await?;
133                push_tasks(&mut *conn, config, buffer)
134                    .map_err(Arc::new)
135                    .await?;
136                conn.commit().map_err(Arc::new).await?;
137                Ok(())
138            };
139            this.sink.flush_future = Some(fut.boxed().shared());
140        }
141
142        // Poll the existing future
143        if let Some(mut fut) = this.sink.flush_future.take() {
144            match fut.poll_unpin(cx) {
145                Poll::Ready(Ok(())) => {
146                    // Future completed successfully, don't put it back
147                    Poll::Ready(Ok(()))
148                }
149                Poll::Ready(Err(e)) => {
150                    // Future completed with error, don't put it back
151                    Poll::Ready(Err(Arc::<sqlx::Error>::into_inner(e).unwrap()))
152                }
153                Poll::Pending => {
154                    // Future is still pending, put it back and return Pending
155                    this.sink.flush_future = Some(fut);
156                    Poll::Pending
157                }
158            }
159        } else {
160            // No future and no work to do
161            Poll::Ready(Ok(()))
162        }
163    }
164
165    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
166        self.poll_flush(cx)
167    }
168}