1use std::{
2 collections::VecDeque,
3 pin::Pin,
4 task::{Context, Poll},
5};
6
7use apalis_core::backend::codec::Codec;
8use chrono::Utc;
9use futures::{FutureExt, Sink};
10use sqlx::{PgPool, Row, postgres::PgRow};
11
12use crate::{PGMQueue, PgMqTask, config::Config, errors::PgmqError, query::enqueue_batch};
13
pin_project_lite::pin_project! {
    // Buffering state behind `PGMQueue`'s `Sink` impl (reached through its
    // `sink` field). Tasks accepted by `start_send` wait in `items` until
    // `poll_flush` encodes them into a single batch insert. That insert is
    // then tracked in `pending_sends` until it resolves.
    // No field is marked `#[pin]`, so the `Sink` impl reaches the fields
    // through `get_mut()` rather than a pin projection.
    pub(super) struct PgMqSink<T, C> {
        // Connection pool; cloned into each batch-insert future.
        conn: PgPool,
        // Queue configuration; `queue()` supplies the target queue name.
        config: Config<C>,
        // Tasks buffered by `start_send`, not yet encoded or sent.
        items: VecDeque<PgMqTask<T>>,
        // Batch inserts started by `poll_flush`, polled front-to-back.
        pending_sends: VecDeque<PendingSend>,
        // Marker tying the codec type `C` to the sink.
        _codec: std::marker::PhantomData<C>,
    }
}
23
24impl<T, C> Clone for PgMqSink<T, C> {
25 fn clone(&self) -> Self {
26 Self {
27 conn: self.conn.clone(),
28 config: self.config.clone(),
29 items: VecDeque::new(),
30 pending_sends: VecDeque::new(),
31 _codec: std::marker::PhantomData,
32 }
33 }
34}
35
36impl<T, C> PgMqSink<T, C> {
37 pub(crate) fn new(conn: PgPool, config: Config<C>) -> Self {
38 Self {
39 conn,
40 config,
41 items: VecDeque::new(),
42 pending_sends: VecDeque::new(),
43 _codec: std::marker::PhantomData,
44 }
45 }
46}
47
48struct PendingSend {
49 future:
50 Pin<Box<dyn std::future::Future<Output = Result<Vec<i64>, PgmqError>> + Send + 'static>>,
51}
52
53struct MessageWithDelay {
54 bytes: Vec<u8>,
55 delay: u64,
56 headers: Option<serde_json::Value>,
57}
58
59impl<T, C> Sink<PgMqTask<T>> for PGMQueue<T, C>
60where
61 T: Send + 'static + Unpin,
62 C: Codec<T, Compact = Vec<u8>> + Unpin,
63 C::Error: std::error::Error + Send + Sync + 'static,
64{
65 type Error = PgmqError;
66
67 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68 let this = &mut self.get_mut().sink;
69
70 while let Some(pending) = this.pending_sends.front_mut() {
72 match pending.future.as_mut().poll(cx) {
73 Poll::Ready(Ok(_msg_ids)) => {
74 this.pending_sends.pop_front();
75 println!("Completed pending send to PgMq");
76 }
77 Poll::Ready(Err(e)) => {
78 this.pending_sends.pop_front();
79 return Poll::Ready(Err(e));
80 }
81 Poll::Pending => {
82 return Poll::Pending;
83 }
84 }
85 }
86
87 Poll::Ready(Ok(()))
88 }
89
90 fn start_send(self: Pin<&mut Self>, item: PgMqTask<T>) -> Result<(), Self::Error> {
91 let this = &mut self.get_mut().sink;
92
93 this.items.push_back(item);
94 Ok(())
95 }
96
97 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
98 let this = &mut self.get_mut().sink;
99
100 let queue_name = this.config.queue();
101
102 let mut messages: Vec<MessageWithDelay> = Vec::new();
104
105 while let Some(item) = this.items.pop_front() {
106 let delay = calculate_delay_seconds(item.parts.run_at as i64);
107 let bytes = C::encode(&item.args).map_err(|e| PgmqError::ParsingError(Box::new(e)))?;
108 let headers = Some(serde_json::Value::Object(item.parts.ctx.headers));
109
110 messages.push(MessageWithDelay {
111 bytes,
112 delay,
113 headers,
114 });
115 }
116
117 if !messages.is_empty() {
119 let conn = this.conn.clone();
120 let queue_name = queue_name.to_string();
121
122 let future = async move { send_batch(&conn, &queue_name, &messages).await }.boxed();
123
124 this.pending_sends.push_back(PendingSend { future });
125 }
126
127 while let Some(pending) = this.pending_sends.front_mut() {
129 match pending.future.as_mut().poll(cx) {
130 Poll::Ready(Ok(msg_ids)) => {
131 this.pending_sends.pop_front();
132 println!("Pushed {} jobs to PgMq", msg_ids.len());
133 }
134 Poll::Ready(Err(e)) => {
135 this.pending_sends.pop_front();
136 return Poll::Ready(Err(e));
137 }
138 Poll::Pending => {
139 return Poll::Pending;
140 }
141 }
142 }
143
144 Poll::Ready(Ok(()))
145 }
146
147 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
148 self.poll_flush(cx)
149 }
150}
151
152fn calculate_delay_seconds(run_at: i64) -> u64 {
153 let now = Utc::now().timestamp();
154 if run_at > now {
155 (run_at - now).max(0) as u64
156 } else {
157 0
158 }
159}
160
161async fn send_batch(
162 conn: &PgPool,
163 queue_name: &str,
164 messages: &[MessageWithDelay],
165) -> Result<Vec<i64>, PgmqError> {
166 let mut msg_ids: Vec<i64> = Vec::new();
167 let query = enqueue_batch(queue_name, messages.len())?;
168
169 let mut q = sqlx::query(&query);
170
171 for msg in messages.iter() {
173 q = q.bind(msg.delay as i64);
174 q = q.bind(&msg.bytes);
175 q = q.bind(&msg.headers);
176 }
177
178 let rows: Vec<PgRow> = q.fetch_all(conn).await?;
179 for row in rows.iter() {
180 msg_ids.push(row.get("msg_id"));
181 }
182 Ok(msg_ids)
183}