Skip to main content

apalis_pgmq/
sink.rs

1use std::{
2    collections::VecDeque,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use apalis_core::backend::codec::Codec;
8use chrono::Utc;
9use futures::{FutureExt, Sink};
10use sqlx::{PgPool, Row, postgres::PgRow};
11
12use crate::{PGMQueue, PgMqTask, config::Config, errors::PgmqError, query::enqueue_batch};
13
pin_project_lite::pin_project! {
    // Buffering state behind the `Sink` impl for `PGMQueue`, which reaches it
    // through its `sink` field. No field is marked `#[pin]`.
    pub(super) struct PgMqSink<T, C> {
        // Pool handle cloned into every batch-insert future.
        conn: PgPool,
        // Queue configuration; `queue()` supplies the target queue name.
        config: Config<C>,
        // Tasks accepted by `start_send`, waiting for the next flush.
        items: VecDeque<PgMqTask<T>>,
        // Submitted batch inserts, polled front-to-back.
        pending_sends: VecDeque<PendingSend>,
        // Ties the codec type `C` to the sink without storing a value of it.
        _codec: std::marker::PhantomData<C>,
    }
}
23
24impl<T, C> Clone for PgMqSink<T, C> {
25    fn clone(&self) -> Self {
26        Self {
27            conn: self.conn.clone(),
28            config: self.config.clone(),
29            items: VecDeque::new(),
30            pending_sends: VecDeque::new(),
31            _codec: std::marker::PhantomData,
32        }
33    }
34}
35
36impl<T, C> PgMqSink<T, C> {
37    pub(crate) fn new(conn: PgPool, config: Config<C>) -> Self {
38        Self {
39            conn,
40            config,
41            items: VecDeque::new(),
42            pending_sends: VecDeque::new(),
43            _codec: std::marker::PhantomData,
44        }
45    }
46}
47
/// One submitted batch insert that has not finished yet.
struct PendingSend {
    /// Resolves to the `msg_id`s returned by `send_batch`, or the error it
    /// produced. `Send + 'static` lets the boxed future be stored in the sink
    /// without borrowing from anything.
    future:
        Pin<Box<dyn std::future::Future<Output = Result<Vec<i64>, PgmqError>> + Send + 'static>>,
}
52
/// One task, already encoded, ready to be bound into the `send_batch` query.
struct MessageWithDelay {
    /// Payload produced by the codec's `encode`.
    bytes: Vec<u8>,
    /// Delay in seconds until the task's `run_at`, from `calculate_delay_seconds`.
    delay: u64,
    /// Task context headers wrapped as a JSON object; `poll_flush` always sets `Some`.
    headers: Option<serde_json::Value>,
}
58
59impl<T, C> Sink<PgMqTask<T>> for PGMQueue<T, C>
60where
61    T: Send + 'static + Unpin,
62    C: Codec<T, Compact = Vec<u8>> + Unpin,
63    C::Error: std::error::Error + Send + Sync + 'static,
64{
65    type Error = PgmqError;
66
67    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68        let this = &mut self.get_mut().sink;
69
70        // Poll pending sends
71        while let Some(pending) = this.pending_sends.front_mut() {
72            match pending.future.as_mut().poll(cx) {
73                Poll::Ready(Ok(_msg_ids)) => {
74                    this.pending_sends.pop_front();
75                    println!("Completed pending send to PgMq");
76                }
77                Poll::Ready(Err(e)) => {
78                    this.pending_sends.pop_front();
79                    return Poll::Ready(Err(e));
80                }
81                Poll::Pending => {
82                    return Poll::Pending;
83                }
84            }
85        }
86
87        Poll::Ready(Ok(()))
88    }
89
90    fn start_send(self: Pin<&mut Self>, item: PgMqTask<T>) -> Result<(), Self::Error> {
91        let this = &mut self.get_mut().sink;
92
93        this.items.push_back(item);
94        Ok(())
95    }
96
97    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98        let this = &mut self.get_mut().sink;
99
100        let queue_name = this.config.queue();
101
102        // Collect all messages with their individual delays and headers
103        let mut messages: Vec<MessageWithDelay> = Vec::new();
104
105        while let Some(item) = this.items.pop_front() {
106            let delay = calculate_delay_seconds(item.parts.run_at as i64);
107            let bytes = C::encode(&item.args).map_err(|e| PgmqError::ParsingError(Box::new(e)))?;
108            let headers = Some(serde_json::Value::Object(item.parts.ctx.headers));
109
110            messages.push(MessageWithDelay {
111                bytes,
112                delay,
113                headers,
114            });
115        }
116
117        // Create a single pending send for all messages
118        if !messages.is_empty() {
119            let conn = this.conn.clone();
120            let queue_name = queue_name.to_string();
121
122            let future = async move { send_batch(&conn, &queue_name, &messages).await }.boxed();
123
124            this.pending_sends.push_back(PendingSend { future });
125        }
126
127        // Now poll all pending sends
128        while let Some(pending) = this.pending_sends.front_mut() {
129            match pending.future.as_mut().poll(cx) {
130                Poll::Ready(Ok(msg_ids)) => {
131                    this.pending_sends.pop_front();
132                    println!("Pushed {} jobs to PgMq", msg_ids.len());
133                }
134                Poll::Ready(Err(e)) => {
135                    this.pending_sends.pop_front();
136                    return Poll::Ready(Err(e));
137                }
138                Poll::Pending => {
139                    return Poll::Pending;
140                }
141            }
142        }
143
144        Poll::Ready(Ok(()))
145    }
146
147    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148        self.poll_flush(cx)
149    }
150}
151
152fn calculate_delay_seconds(run_at: i64) -> u64 {
153    let now = Utc::now().timestamp();
154    if run_at > now {
155        (run_at - now).max(0) as u64
156    } else {
157        0
158    }
159}
160
161async fn send_batch(
162    conn: &PgPool,
163    queue_name: &str,
164    messages: &[MessageWithDelay],
165) -> Result<Vec<i64>, PgmqError> {
166    let mut msg_ids: Vec<i64> = Vec::new();
167    let query = enqueue_batch(queue_name, messages.len())?;
168
169    let mut q = sqlx::query(&query);
170
171    // Bind delays, messages, and headers in the correct order
172    for msg in messages.iter() {
173        q = q.bind(msg.delay as i64);
174        q = q.bind(&msg.bytes);
175        q = q.bind(&msg.headers);
176    }
177
178    let rows: Vec<PgRow> = q.fetch_all(conn).await?;
179    for row in rows.iter() {
180        msg_ids.push(row.get("msg_id"));
181    }
182    Ok(msg_ids)
183}