1use apalis_codec::json::JsonCodec;
2use apalis_sql::{DateTime, DateTimeExt, config::Config};
3use futures::{
4 FutureExt, Sink, TryFutureExt,
5 future::{BoxFuture, Shared},
6};
7use sqlx::{Executor, PgPool};
8use std::{
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13use ulid::Ulid;
14
15use crate::{CompactType, PgTask, PostgresStorage};
16
/// Boxed, `'static` future performing one batched flush of buffered tasks.
///
/// `sqlx::Error` is not `Clone`; wrapping it in `Arc` makes the output
/// cloneable, which `futures::future::Shared` requires of its inner future.
type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;

/// Buffering sink state used to push tasks into Postgres in batches.
///
/// Tasks are accumulated in `buffer`; while a batch write is in progress its
/// future is held in `flush_future`.
#[pin_project::pin_project]
pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
    /// Connection pool the flush future acquires its transaction from.
    pool: PgPool,
    /// Configuration this sink was created with.
    config: Config,
    /// Tasks accepted by the sink but not yet written to the database.
    buffer: Vec<PgTask<Compact>>,
    /// The flush currently in flight, if any.
    #[pin]
    flush_future: Option<Shared<FlushFuture>>,
    /// Binds the otherwise-unused `Args` and `Codec` type parameters.
    _marker: std::marker::PhantomData<(Args, Codec)>,
}
28
29impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
30 fn clone(&self) -> Self {
31 Self {
32 pool: self.pool.clone(),
33 config: self.config.clone(),
34 buffer: Vec::new(),
35 flush_future: None,
36 _marker: std::marker::PhantomData,
37 }
38 }
39}
40
41pub fn push_tasks<'a, E>(
42 conn: E,
43 cfg: Config,
44 buffer: Vec<PgTask<CompactType>>,
45) -> impl futures::Future<Output = Result<(), sqlx::Error>> + Send + 'a
46where
47 E: Executor<'a, Database = sqlx::Postgres> + Send + 'a,
48{
49 let job_type = cfg.queue().to_string();
50 let mut ids = Vec::new();
52 let mut job_data = Vec::new();
53 let mut run_ats = Vec::new();
54 let mut priorities = Vec::new();
55 let mut max_attempts_vec = Vec::new();
56 let mut metadata = Vec::new();
57
58 for task in buffer {
59 ids.push(
60 task.parts
61 .task_id
62 .map(|id| id.to_string())
63 .unwrap_or(Ulid::new().to_string()),
64 );
65 job_data.push(task.args);
66 run_ats.push(<DateTime as DateTimeExt>::from_unix_timestamp(
67 task.parts.run_at as i64,
68 ));
69 priorities.push(task.parts.ctx.priority());
70 max_attempts_vec.push(task.parts.ctx.max_attempts());
71 metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
72 }
73
74 sqlx::query_file!(
75 "queries/task/sink.sql",
76 &ids,
77 &job_type,
78 &job_data,
79 &max_attempts_vec,
80 &run_ats,
81 &priorities,
82 &metadata
83 )
84 .execute(conn)
85 .map_ok(|_| ())
86 .boxed()
87}
88
89impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
90 pub fn new(pool: &PgPool, config: &Config) -> Self {
91 Self {
92 pool: pool.clone(),
93 config: config.clone(),
94 buffer: Vec::new(),
95 _marker: std::marker::PhantomData,
96 flush_future: None,
97 }
98 }
99}
100
101impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
102 for PostgresStorage<Args, CompactType, Encode, Fetcher>
103where
104 Args: Unpin + Send + Sync + 'static,
105 Fetcher: Unpin,
106{
107 type Error = sqlx::Error;
108
109 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 Poll::Ready(Ok(()))
111 }
112
113 fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
114 self.get_mut().sink.buffer.push(item);
116 Ok(())
117 }
118
119 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 let this = self.get_mut();
121
122 if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
124 return Poll::Ready(Ok(()));
125 }
126
127 if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
129 let config = this.config.clone();
130 let buffer = std::mem::take(&mut this.sink.buffer);
131 let pool = this.sink.pool.clone();
132 let fut = async move {
133 let mut conn = pool.begin().map_err(Arc::new).await?;
134 push_tasks(&mut *conn, config, buffer)
135 .map_err(Arc::new)
136 .await?;
137 conn.commit().map_err(Arc::new).await?;
138 Ok(())
139 };
140 this.sink.flush_future = Some(fut.boxed().shared());
141 }
142
143 if let Some(mut fut) = this.sink.flush_future.take() {
145 match fut.poll_unpin(cx) {
146 Poll::Ready(Ok(())) => {
147 Poll::Ready(Ok(()))
149 }
150 Poll::Ready(Err(e)) => {
151 Poll::Ready(Err(Arc::<sqlx::Error>::into_inner(e).unwrap()))
153 }
154 Poll::Pending => {
155 this.sink.flush_future = Some(fut);
157 Poll::Pending
158 }
159 }
160 } else {
161 Poll::Ready(Ok(()))
163 }
164 }
165
166 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167 self.poll_flush(cx)
168 }
169}