1use std::{
2 pin::Pin,
3 sync::Arc,
4 task::{Context, Poll},
5};
6
7use apalis_core::backend::codec::json::JsonCodec;
8use chrono::{DateTime, Utc};
9use futures::{
10 FutureExt, Sink, TryFutureExt,
11 future::{BoxFuture, Shared},
12};
13use sqlx::{Executor, PgPool};
14use ulid::Ulid;
15
16use crate::{CompactType, PgTask, PostgresStorage, config::Config};
17
18type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
19
20#[pin_project::pin_project]
21pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
22 pool: PgPool,
23 config: Config,
24 buffer: Vec<PgTask<Compact>>,
25 #[pin]
26 flush_future: Option<Shared<FlushFuture>>,
27 _marker: std::marker::PhantomData<(Args, Codec)>,
28}
29
30impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
31 fn clone(&self) -> Self {
32 Self {
33 pool: self.pool.clone(),
34 config: self.config.clone(),
35 buffer: Vec::new(),
36 flush_future: None,
37 _marker: std::marker::PhantomData,
38 }
39 }
40}
41
42pub fn push_tasks<'a, E>(
43 conn: E,
44 cfg: Config,
45 buffer: Vec<PgTask<CompactType>>,
46) -> impl futures::Future<Output = Result<(), sqlx::Error>> + Send + 'a
47where
48 E: Executor<'a, Database = sqlx::Postgres> + Send + 'a,
49{
50 let job_type = cfg.queue().to_string();
51 let mut ids = Vec::new();
53 let mut job_data = Vec::new();
54 let mut run_ats = Vec::new();
55 let mut priorities = Vec::new();
56 let mut max_attempts_vec = Vec::new();
57 let mut metadata = Vec::new();
58
59 for task in buffer {
60 ids.push(
61 task.parts
62 .task_id
63 .map(|id| id.to_string())
64 .unwrap_or(Ulid::new().to_string()),
65 );
66 job_data.push(task.args);
67 run_ats.push(DateTime::from_timestamp(task.parts.run_at as i64, 0).unwrap_or(Utc::now()));
68 priorities.push(task.parts.ctx.priority());
69 max_attempts_vec.push(task.parts.ctx.max_attempts());
70 metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
71 }
72
73 sqlx::query_file!(
74 "queries/task/sink.sql",
75 &ids,
76 &job_type,
77 &job_data,
78 &max_attempts_vec,
79 &run_ats,
80 &priorities,
81 &metadata
82 )
83 .execute(conn)
84 .map_ok(|_| ())
85 .boxed()
86}
87
88impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
89 pub fn new(pool: &PgPool, config: &Config) -> Self {
90 Self {
91 pool: pool.clone(),
92 config: config.clone(),
93 buffer: Vec::new(),
94 _marker: std::marker::PhantomData,
95 flush_future: None,
96 }
97 }
98}
99
100impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
101 for PostgresStorage<Args, CompactType, Encode, Fetcher>
102where
103 Args: Unpin + Send + Sync + 'static,
104 Fetcher: Unpin,
105{
106 type Error = sqlx::Error;
107
108 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 Poll::Ready(Ok(()))
110 }
111
112 fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
113 self.get_mut().sink.buffer.push(item);
115 Ok(())
116 }
117
118 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119 let this = self.get_mut();
120
121 if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
123 return Poll::Ready(Ok(()));
124 }
125
126 if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
128 let config = this.config.clone();
129 let buffer = std::mem::take(&mut this.sink.buffer);
130 let pool = this.sink.pool.clone();
131 let fut = async move {
132 let mut conn = pool.begin().map_err(Arc::new).await?;
133 push_tasks(&mut *conn, config, buffer)
134 .map_err(Arc::new)
135 .await?;
136 conn.commit().map_err(Arc::new).await?;
137 Ok(())
138 };
139 this.sink.flush_future = Some(fut.boxed().shared());
140 }
141
142 if let Some(mut fut) = this.sink.flush_future.take() {
144 match fut.poll_unpin(cx) {
145 Poll::Ready(Ok(())) => {
146 Poll::Ready(Ok(()))
148 }
149 Poll::Ready(Err(e)) => {
150 Poll::Ready(Err(Arc::<sqlx::Error>::into_inner(e).unwrap()))
152 }
153 Poll::Pending => {
154 this.sink.flush_future = Some(fut);
156 Poll::Pending
157 }
158 }
159 } else {
160 Poll::Ready(Ok(()))
162 }
163 }
164
165 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
166 self.poll_flush(cx)
167 }
168}