//! `apalis_postgres/sink.rs` — buffering [`futures::Sink`] implementation that
//! batches tasks and flushes them to Postgres in a single multi-row INSERT.
1use apalis_codec::json::JsonCodec;
2use apalis_sql::{DateTime, DateTimeExt, config::Config};
3use futures::{
4    FutureExt, Sink, TryFutureExt,
5    future::{BoxFuture, Shared},
6};
7use sqlx::{Executor, PgPool};
8use std::{
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13use ulid::Ulid;
14
15use crate::{CompactType, PgTask, PostgresStorage};
16
/// Boxed future for one flush operation. The error is wrapped in an `Arc`
/// because the future is wrapped in `Shared`, whose output must be `Clone`.
type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
18
/// A buffering sink that accumulates tasks in memory and writes them to
/// Postgres in one batched statement when flushed.
///
/// `Args` and `Codec` are phantom parameters; only `Compact` (the stored
/// representation of the task payload) appears in the buffer.
#[pin_project::pin_project]
pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
    // Connection pool used to open the flush transaction.
    pool: PgPool,
    // Queue configuration; provides the queue (job type) name at insert time.
    config: Config,
    // Tasks accepted but not yet written to the database.
    buffer: Vec<PgTask<Compact>>,
    // In-flight flush, if any. `Shared` lets repeated polls observe the same
    // transaction's outcome instead of starting a new one.
    #[pin]
    flush_future: Option<Shared<FlushFuture>>,
    _marker: std::marker::PhantomData<(Args, Codec)>,
}
28
29impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
30    fn clone(&self) -> Self {
31        Self {
32            pool: self.pool.clone(),
33            config: self.config.clone(),
34            buffer: Vec::new(),
35            flush_future: None,
36            _marker: std::marker::PhantomData,
37        }
38    }
39}
40
41pub fn push_tasks<'a, E>(
42    conn: E,
43    cfg: Config,
44    buffer: Vec<PgTask<CompactType>>,
45) -> impl futures::Future<Output = Result<(), sqlx::Error>> + Send + 'a
46where
47    E: Executor<'a, Database = sqlx::Postgres> + Send + 'a,
48{
49    let job_type = cfg.queue().to_string();
50    // Build the multi-row INSERT with UNNEST
51    let mut ids = Vec::new();
52    let mut job_data = Vec::new();
53    let mut run_ats = Vec::new();
54    let mut priorities = Vec::new();
55    let mut max_attempts_vec = Vec::new();
56    let mut metadata = Vec::new();
57    let mut idempotency_key: Vec<Option<String>> = Vec::new();
58
59    for task in buffer {
60        ids.push(
61            task.parts
62                .task_id
63                .map(|id| id.to_string())
64                .unwrap_or(Ulid::new().to_string()),
65        );
66        job_data.push(task.args);
67        run_ats.push(<DateTime as DateTimeExt>::from_unix_timestamp(
68            task.parts.run_at as i64,
69        ));
70        priorities.push(task.parts.ctx.priority());
71        max_attempts_vec.push(task.parts.ctx.max_attempts());
72        metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
73        idempotency_key.push(task.parts.idempotency_key);
74    }
75
76    sqlx::query_file!(
77        "queries/task/sink.sql",
78        &ids,
79        &job_type,
80        &job_data,
81        &max_attempts_vec,
82        &run_ats,
83        &priorities,
84        &metadata,
85        &idempotency_key as &[Option<String>]
86    )
87    .execute(conn)
88    .map_ok(|_| ())
89    .boxed()
90}
91
92impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
93    pub fn new(pool: &PgPool, config: &Config) -> Self {
94        Self {
95            pool: pool.clone(),
96            config: config.clone(),
97            buffer: Vec::new(),
98            _marker: std::marker::PhantomData,
99            flush_future: None,
100        }
101    }
102}
103
/// Feeds `PgTask`s into Postgres: `start_send` buffers in memory, and
/// `poll_flush` drains the buffer inside a single transaction.
impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
    for PostgresStorage<Args, CompactType, Encode, Fetcher>
where
    Args: Unpin + Send + Sync + 'static,
    Fetcher: Unpin,
{
    type Error = sqlx::Error;

    // The in-memory buffer is unbounded, so the sink can always accept more.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    // Buffer the task; nothing touches the database until `poll_flush`.
    fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
        // Add the item to the buffer
        self.get_mut().sink.buffer.push(item);
        Ok(())
    }

    // Flush state machine:
    //   1. nothing buffered and no flush running  -> done;
    //   2. work buffered and no flush running     -> take the buffer and start
    //      a transaction (begin / push_tasks / commit) as a `Shared` future;
    //   3. poll the stored future, putting it back only while it is pending,
    //      so repeated polls re-poll the same transaction.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();

        // If there's no existing future and buffer is empty, we're done
        if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
            return Poll::Ready(Ok(()));
        }

        // Create the future only if we don't have one and there's work to do
        if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
            let config = this.config.clone();
            // `mem::take` empties the buffer in place; items arriving after
            // this point belong to the next flush.
            let buffer = std::mem::take(&mut this.sink.buffer);
            let pool = this.sink.pool.clone();
            let fut = async move {
                // Errors are wrapped in `Arc` because `Shared` requires the
                // output to be `Clone` and `sqlx::Error` is not.
                let mut conn = pool.begin().map_err(Arc::new).await?;
                push_tasks(&mut *conn, config, buffer)
                    .map_err(Arc::new)
                    .await?;
                conn.commit().map_err(Arc::new).await?;
                Ok(())
            };
            this.sink.flush_future = Some(fut.boxed().shared());
        }

        // Poll the existing future
        if let Some(mut fut) = this.sink.flush_future.take() {
            match fut.poll_unpin(cx) {
                Poll::Ready(Ok(())) => {
                    // Future completed successfully, don't put it back
                    Poll::Ready(Ok(()))
                }
                Poll::Ready(Err(e)) => {
                    // Future completed with error, don't put it back
                    // NOTE(review): `into_inner(..).unwrap()` only succeeds if
                    // `e` is the sole `Arc` handle. That relies on `Shared`
                    // handing the output back by value to the last remaining
                    // handle (this future is never cloned elsewhere — clones
                    // of the sink reset `flush_future` to `None`). Confirm
                    // against the `futures` version in use, otherwise this
                    // unwrap can panic while `fut` still holds the cached Arc.
                    Poll::Ready(Err(Arc::<sqlx::Error>::into_inner(e).unwrap()))
                }
                Poll::Pending => {
                    // Future is still pending, put it back and return Pending
                    this.sink.flush_future = Some(fut);
                    Poll::Pending
                }
            }
        } else {
            // No future and no work to do
            Poll::Ready(Ok(()))
        }
    }

    // Closing is just a flush: there is no extra teardown beyond draining the
    // buffer.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.poll_flush(cx)
    }
}
172}