1use apalis_codec::json::JsonCodec;
2use apalis_sql::{DateTime, DateTimeExt, config::Config};
3use futures::{
4 FutureExt, Sink, TryFutureExt,
5 future::{BoxFuture, Shared},
6};
7use sqlx::{Executor, PgPool};
8use std::{
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use ulid::Ulid;
14
15use crate::{CompactType, PgTask, PostgresStorage};
16
17type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
18
19#[pin_project::pin_project]
20pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
21 pool: PgPool,
22 config: Config,
23 buffer: Vec<PgTask<Compact>>,
24 #[pin]
25 flush_future: Option<Shared<FlushFuture>>,
26 _marker: std::marker::PhantomData<(Args, Codec)>,
27}
28
29impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
30 fn clone(&self) -> Self {
31 Self {
32 pool: self.pool.clone(),
33 config: self.config.clone(),
34 buffer: Vec::new(),
35 flush_future: None,
36 _marker: std::marker::PhantomData,
37 }
38 }
39}
40
41pub fn push_tasks<'a, E>(
42 conn: E,
43 cfg: Config,
44 buffer: Vec<PgTask<CompactType>>,
45) -> impl futures::Future<Output = Result<(), sqlx::Error>> + Send + 'a
46where
47 E: Executor<'a, Database = sqlx::Postgres> + Send + 'a,
48{
49 let job_type = cfg.queue().to_string();
50 let mut ids = Vec::new();
52 let mut job_data = Vec::new();
53 let mut run_ats = Vec::new();
54 let mut priorities = Vec::new();
55 let mut max_attempts_vec = Vec::new();
56 let mut metadata = Vec::new();
57 let mut idempotency_key: Vec<Option<String>> = Vec::new();
58
59 for task in buffer {
60 ids.push(
61 task.parts
62 .task_id
63 .map(|id| id.to_string())
64 .unwrap_or(Ulid::new().to_string()),
65 );
66 job_data.push(task.args);
67 run_ats.push(<DateTime as DateTimeExt>::from_unix_timestamp(
68 task.parts.run_at as i64,
69 ));
70 priorities.push(task.parts.ctx.priority());
71 max_attempts_vec.push(task.parts.ctx.max_attempts());
72 metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
73 idempotency_key.push(task.parts.idempotency_key);
74 }
75
76 sqlx::query_file!(
77 "queries/task/sink.sql",
78 &ids,
79 &job_type,
80 &job_data,
81 &max_attempts_vec,
82 &run_ats,
83 &priorities,
84 &metadata,
85 &idempotency_key as &[Option<String>]
86 )
87 .execute(conn)
88 .map_ok(|_| ())
89 .boxed()
90}
91
92impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
93 pub fn new(pool: &PgPool, config: &Config) -> Self {
94 Self {
95 pool: pool.clone(),
96 config: config.clone(),
97 buffer: Vec::new(),
98 _marker: std::marker::PhantomData,
99 flush_future: None,
100 }
101 }
102}
103
104impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
105 for PostgresStorage<Args, CompactType, Encode, Fetcher>
106where
107 Args: Unpin + Send + Sync + 'static,
108 Fetcher: Unpin,
109{
110 type Error = sqlx::Error;
111
112 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
113 Poll::Ready(Ok(()))
114 }
115
116 fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
117 self.get_mut().sink.buffer.push(item);
119 Ok(())
120 }
121
122 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
123 let this = self.get_mut();
124
125 if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
127 return Poll::Ready(Ok(()));
128 }
129
130 if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
132 let config = this.config.clone();
133 let buffer = std::mem::take(&mut this.sink.buffer);
134 let pool = this.sink.pool.clone();
135 let fut = async move {
136 let mut conn = pool.begin().map_err(Arc::new).await?;
137 push_tasks(&mut *conn, config, buffer)
138 .map_err(Arc::new)
139 .await?;
140 conn.commit().map_err(Arc::new).await?;
141 Ok(())
142 };
143 this.sink.flush_future = Some(fut.boxed().shared());
144 }
145
146 if let Some(mut fut) = this.sink.flush_future.take() {
148 match fut.poll_unpin(cx) {
149 Poll::Ready(Ok(())) => {
150 Poll::Ready(Ok(()))
152 }
153 Poll::Ready(Err(e)) => {
154 Poll::Ready(Err(Arc::<sqlx::Error>::into_inner(e).unwrap()))
156 }
157 Poll::Pending => {
158 this.sink.flush_future = Some(fut);
160 Poll::Pending
161 }
162 }
163 } else {
164 Poll::Ready(Ok(()))
166 }
167 }
168
169 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
170 self.poll_flush(cx)
171 }
172}