apalis_postgres/sink.rs

use std::{
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use apalis_core::backend::codec::json::JsonCodec;
use chrono::DateTime;
use futures::{
    FutureExt, Sink, TryFutureExt,
    future::{BoxFuture, Shared},
};
use sqlx::PgPool;
use ulid::Ulid;

use crate::{CompactType, PgTask, PostgresStorage, config::Config};

/// Future type for an in-flight batch insert. The error is wrapped in an
/// `Arc` so the output stays `Clone`, which `Shared` requires.
type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;

/// Buffering sink that accumulates tasks in memory and writes them to
/// Postgres in a single batch when flushed.
#[pin_project::pin_project]
pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
    pool: PgPool,
    config: Config,
    /// Tasks accepted by `start_send` but not yet written to the database.
    buffer: Vec<PgTask<Compact>>,
    /// The currently running flush, if any.
    #[pin]
    flush_future: Option<Shared<FlushFuture>>,
    _marker: std::marker::PhantomData<(Args, Codec)>,
}

impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
    fn clone(&self) -> Self {
        // A clone shares the pool and config but starts with an empty buffer
        // and no in-flight flush of its own.
        Self {
            pool: self.pool.clone(),
            config: self.config.clone(),
            buffer: Vec::new(),
            flush_future: None,
            _marker: std::marker::PhantomData,
        }
    }
}

/// Inserts a batch of tasks for the configured queue using a single
/// multi-row statement.
pub async fn push_tasks(
    pool: PgPool,
    cfg: Config,
    buffer: Vec<PgTask<CompactType>>,
) -> Result<(), sqlx::Error> {
    let job_type = cfg.queue().to_string();
    // Collect one column vector per field so the query can UNNEST them into rows.
    let mut ids = Vec::new();
    let mut job_data = Vec::new();
    let mut run_ats = Vec::new();
    let mut priorities = Vec::new();
    let mut max_attempts_vec = Vec::new();
    let mut metadata = Vec::new();

    for task in buffer {
        ids.push(
            task.parts
                .task_id
                .map(|id| id.to_string())
                .unwrap_or_else(|| Ulid::new().to_string()),
        );
        job_data.push(task.args);
        run_ats.push(
            DateTime::from_timestamp(task.parts.run_at as i64, 0)
                .ok_or(sqlx::Error::ColumnNotFound("run_at".to_owned()))?,
        );
        priorities.push(task.parts.ctx.priority());
        max_attempts_vec.push(task.parts.ctx.max_attempts());
        metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
    }

    sqlx::query_file!(
        "queries/task/sink.sql",
        &ids,
        &job_type,
        &job_data,
        &max_attempts_vec,
        &run_ats,
        &priorities,
        &metadata
    )
    .execute(&pool)
    .await?;
    Ok(())
}
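
// `queries/task/sink.sql` is not part of this file. As a rough, illustrative
// sketch only, an UNNEST-based bulk insert matching the bind parameters above
// usually looks something like the following (the table and column names and
// array types here are assumptions, not the real query):
//
//   INSERT INTO tasks (id, job_type, job, max_attempts, run_at, priority, metadata)
//   SELECT t.id, $2, t.job, t.max_attempts, t.run_at, t.priority, t.metadata
//   FROM UNNEST($1::text[], $3::jsonb[], $4::int4[], $5::timestamptz[], $6::int4[], $7::jsonb[])
//        AS t(id, job, max_attempts, run_at, priority, metadata);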

impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
    pub fn new(pool: &PgPool, config: &Config) -> Self {
        Self {
            pool: pool.clone(),
            config: config.clone(),
            buffer: Vec::new(),
            flush_future: None,
            _marker: std::marker::PhantomData,
        }
    }
}

impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
    for PostgresStorage<Args, CompactType, Encode, Fetcher>
where
    Args: Unpin + Send + Sync + 'static,
    Fetcher: Unpin,
{
    type Error = sqlx::Error;

    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // The buffer is unbounded, so the sink is always ready to accept items.
        Poll::Ready(Ok(()))
    }

    fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
        // Buffer the item; it is only written to Postgres on flush.
        self.get_mut().sink.buffer.push(item);
        Ok(())
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.get_mut();

        // Nothing in flight and nothing buffered: the flush is trivially complete.
        if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
            return Poll::Ready(Ok(()));
        }

        // Start a flush only if one is not already in flight and there is work to do.
        if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
            let pool = this.pool.clone();
            let config = this.config.clone();
            let buffer = std::mem::take(&mut this.sink.buffer);
            let sink_fut = push_tasks(pool, config, buffer).map_err(Arc::new);
            this.sink.flush_future = Some(sink_fut.boxed().shared());
        }

        // Poll the in-flight flush.
        if let Some(mut fut) = this.sink.flush_future.take() {
            match fut.poll_unpin(cx) {
                Poll::Ready(Ok(())) => {
                    // Flush completed successfully; drop the future.
                    Poll::Ready(Ok(()))
                }
                Poll::Ready(Err(e)) => {
                    // The shared handle may still hold a copy of the `Arc`, so
                    // `Arc::into_inner(e).unwrap()` could panic here. Recover the
                    // error when we are the sole owner and wrap it otherwise.
                    let err = Arc::try_unwrap(e)
                        .unwrap_or_else(|arc| sqlx::Error::Io(std::io::Error::other(arc)));
                    Poll::Ready(Err(err))
                }
                Poll::Pending => {
                    // Still running: put the future back and try again later.
                    this.sink.flush_future = Some(fut);
                    Poll::Pending
                }
            }
        } else {
            // No in-flight flush and no buffered work.
            Poll::Ready(Ok(()))
        }
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.poll_flush(cx)
    }
}
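
// --------------------------------------------------------------------------
// Illustrative sketch (not part of the original file): one way a caller might
// drive the `Sink` implementation above with `futures::SinkExt`. The helper
// name `drain_into` is hypothetical; only the trait bounds visible in this
// file are assumed.
#[allow(dead_code)]
async fn drain_into<S>(
    mut sink: S,
    tasks: Vec<PgTask<CompactType>>,
) -> Result<(), sqlx::Error>
where
    S: Sink<PgTask<CompactType>, Error = sqlx::Error> + Unpin,
{
    use futures::SinkExt;

    for task in tasks {
        // `feed` only drives `poll_ready` + `start_send`, so tasks accumulate
        // in the in-memory buffer without touching the database yet.
        sink.feed(task).await?;
    }
    // A single `flush` drives `poll_flush`, which writes the whole buffer in
    // one multi-row insert via `push_tasks`.
    sink.flush().await
}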