//! apalis_postgres/sink.rs — buffering `Sink` that batch-inserts tasks into Postgres.

1use apalis_codec::json::JsonCodec;
2use apalis_sql::{DateTime, DateTimeExt, config::Config};
3use futures::{
4    FutureExt, Sink, TryFutureExt,
5    future::{BoxFuture, Shared},
6};
7use sqlx::{Executor, PgPool};
8use std::{
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13use ulid::Ulid;
14
15use crate::{CompactType, PgTask, PostgresStorage};
16
/// Future for one in-flight flush. The error is wrapped in `Arc` because the
/// future is stored as `Shared`, whose output must be `Clone` — and
/// `sqlx::Error` is not.
type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
18
/// Buffering sink that writes tasks to Postgres in batches.
///
/// Items are only appended to `buffer` on `start_send`; the actual INSERT is
/// issued by `poll_flush`, which drains the buffer inside one transaction.
#[pin_project::pin_project]
pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
    // Pool used to open a transaction per flush.
    pool: PgPool,
    // Queue configuration; supplies the job-type / queue name.
    config: Config,
    // Tasks accepted but not yet written to the database.
    buffer: Vec<PgTask<Compact>>,
    // In-flight flush, if any. `Shared` so the stored handle can be taken,
    // polled, and put back across repeated `poll_flush` calls.
    #[pin]
    flush_future: Option<Shared<FlushFuture>>,
    // Ties the otherwise-unused `Args`/`Codec` parameters to the type.
    _marker: std::marker::PhantomData<(Args, Codec)>,
}
28
// Manual impl (instead of `#[derive(Clone)]`) for two reasons: it avoids
// requiring `Compact: Clone`, and — deliberately — a clone does NOT inherit
// the pending buffer or the in-flight flush. Each clone starts with an empty
// buffer so buffered tasks are flushed exactly once, by the sink that
// accepted them.
impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
    fn clone(&self) -> Self {
        Self {
            pool: self.pool.clone(),
            config: self.config.clone(),
            buffer: Vec::new(),
            flush_future: None,
            _marker: std::marker::PhantomData,
        }
    }
}
40
41pub fn push_tasks<'a, E>(
42    conn: E,
43    cfg: Config,
44    buffer: Vec<PgTask<CompactType>>,
45) -> impl futures::Future<Output = Result<(), sqlx::Error>> + Send + 'a
46where
47    E: Executor<'a, Database = sqlx::Postgres> + Send + 'a,
48{
49    let job_type = cfg.queue().to_string();
50    // Build the multi-row INSERT with UNNEST
51    let mut ids = Vec::new();
52    let mut job_data = Vec::new();
53    let mut run_ats = Vec::new();
54    let mut priorities = Vec::new();
55    let mut max_attempts_vec = Vec::new();
56    let mut metadata = Vec::new();
57
58    for task in buffer {
59        ids.push(
60            task.parts
61                .task_id
62                .map(|id| id.to_string())
63                .unwrap_or(Ulid::new().to_string()),
64        );
65        job_data.push(task.args);
66        run_ats.push(<DateTime as DateTimeExt>::from_unix_timestamp(
67            task.parts.run_at as i64,
68        ));
69        priorities.push(task.parts.ctx.priority());
70        max_attempts_vec.push(task.parts.ctx.max_attempts());
71        metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
72    }
73
74    sqlx::query_file!(
75        "queries/task/sink.sql",
76        &ids,
77        &job_type,
78        &job_data,
79        &max_attempts_vec,
80        &run_ats,
81        &priorities,
82        &metadata
83    )
84    .execute(conn)
85    .map_ok(|_| ())
86    .boxed()
87}
88
89impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
90    pub fn new(pool: &PgPool, config: &Config) -> Self {
91        Self {
92            pool: pool.clone(),
93            config: config.clone(),
94            buffer: Vec::new(),
95            _marker: std::marker::PhantomData,
96            flush_future: None,
97        }
98    }
99}
100
101impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
102    for PostgresStorage<Args, CompactType, Encode, Fetcher>
103where
104    Args: Unpin + Send + Sync + 'static,
105    Fetcher: Unpin,
106{
107    type Error = sqlx::Error;
108
109    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        Poll::Ready(Ok(()))
111    }
112
113    fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
114        // Add the item to the buffer
115        self.get_mut().sink.buffer.push(item);
116        Ok(())
117    }
118
119    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120        let this = self.get_mut();
121
122        // If there's no existing future and buffer is empty, we're done
123        if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
124            return Poll::Ready(Ok(()));
125        }
126
127        // Create the future only if we don't have one and there's work to do
128        if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
129            let config = this.config.clone();
130            let buffer = std::mem::take(&mut this.sink.buffer);
131            let pool = this.sink.pool.clone();
132            let fut = async move {
133                let mut conn = pool.begin().map_err(Arc::new).await?;
134                push_tasks(&mut *conn, config, buffer)
135                    .map_err(Arc::new)
136                    .await?;
137                conn.commit().map_err(Arc::new).await?;
138                Ok(())
139            };
140            this.sink.flush_future = Some(fut.boxed().shared());
141        }
142
143        // Poll the existing future
144        if let Some(mut fut) = this.sink.flush_future.take() {
145            match fut.poll_unpin(cx) {
146                Poll::Ready(Ok(())) => {
147                    // Future completed successfully, don't put it back
148                    Poll::Ready(Ok(()))
149                }
150                Poll::Ready(Err(e)) => {
151                    // Future completed with error, don't put it back
152                    Poll::Ready(Err(Arc::<sqlx::Error>::into_inner(e).unwrap()))
153                }
154                Poll::Pending => {
155                    // Future is still pending, put it back and return Pending
156                    this.sink.flush_future = Some(fut);
157                    Poll::Pending
158                }
159            }
160        } else {
161            // No future and no work to do
162            Poll::Ready(Ok(()))
163        }
164    }
165
166    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167        self.poll_flush(cx)
168    }
169}