1use std::{
2 pin::Pin,
3 sync::Arc,
4 task::{Context, Poll},
5};
6
7use apalis_core::backend::codec::json::JsonCodec;
8use chrono::DateTime;
9use futures::{
10 FutureExt, Sink, TryFutureExt,
11 future::{BoxFuture, Shared},
12};
13use sqlx::PgPool;
14use ulid::Ulid;
15
16use crate::{CompactType, PgTask, PostgresStorage, config::Config};
17
18type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
19
20#[pin_project::pin_project]
21pub struct PgSink<Args, Compact = CompactType, Codec = JsonCodec<CompactType>> {
22 pool: PgPool,
23 config: Config,
24 buffer: Vec<PgTask<Compact>>,
25 #[pin]
26 flush_future: Option<Shared<FlushFuture>>,
27 _marker: std::marker::PhantomData<(Args, Codec)>,
28}
29
30impl<Args, Compact, Codec> Clone for PgSink<Args, Compact, Codec> {
31 fn clone(&self) -> Self {
32 Self {
33 pool: self.pool.clone(),
34 config: self.config.clone(),
35 buffer: Vec::new(),
36 flush_future: None,
37 _marker: std::marker::PhantomData,
38 }
39 }
40}
41
42pub async fn push_tasks(
43 pool: PgPool,
44 cfg: Config,
45 buffer: Vec<PgTask<CompactType>>,
46) -> Result<(), sqlx::Error> {
47 let job_type = cfg.queue().to_string();
48 let mut ids = Vec::new();
50 let mut job_data = Vec::new();
51 let mut run_ats = Vec::new();
52 let mut priorities = Vec::new();
53 let mut max_attempts_vec = Vec::new();
54 let mut metadata = Vec::new();
55
56 for task in buffer {
57 ids.push(
58 task.parts
59 .task_id
60 .map(|id| id.to_string())
61 .unwrap_or(Ulid::new().to_string()),
62 );
63 job_data.push(task.args);
64 run_ats.push(
65 DateTime::from_timestamp(task.parts.run_at as i64, 0)
66 .ok_or(sqlx::Error::ColumnNotFound("run_at".to_owned()))?,
67 );
68 priorities.push(task.parts.ctx.priority());
69 max_attempts_vec.push(task.parts.ctx.max_attempts());
70 metadata.push(serde_json::Value::Object(task.parts.ctx.meta().clone()));
71 }
72
73 sqlx::query_file!(
74 "queries/task/sink.sql",
75 &ids,
76 &job_type,
77 &job_data,
78 &max_attempts_vec,
79 &run_ats,
80 &priorities,
81 &metadata
82 )
83 .execute(&pool)
84 .await?;
85 Ok(())
86}
87
88impl<Args, Compact, Codec> PgSink<Args, Compact, Codec> {
89 pub fn new(pool: &PgPool, config: &Config) -> Self {
90 Self {
91 pool: pool.clone(),
92 config: config.clone(),
93 buffer: Vec::new(),
94 _marker: std::marker::PhantomData,
95 flush_future: None,
96 }
97 }
98}
99
100impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>>
101 for PostgresStorage<Args, CompactType, Encode, Fetcher>
102where
103 Args: Unpin + Send + Sync + 'static,
104 Fetcher: Unpin,
105{
106 type Error = sqlx::Error;
107
108 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 Poll::Ready(Ok(()))
110 }
111
112 fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
113 self.get_mut().sink.buffer.push(item);
115 Ok(())
116 }
117
118 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
119 let this = self.get_mut();
120
121 if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
123 return Poll::Ready(Ok(()));
124 }
125
126 if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
128 let pool = this.pool.clone();
129 let config = this.config.clone();
130 let buffer = std::mem::take(&mut this.sink.buffer);
131 let sink_fut = push_tasks(pool, config, buffer).map_err(Arc::new);
132 this.sink.flush_future = Some(sink_fut.boxed().shared());
133 }
134
135 if let Some(mut fut) = this.sink.flush_future.take() {
137 match fut.poll_unpin(cx) {
138 Poll::Ready(Ok(())) => {
139 Poll::Ready(Ok(()))
141 }
142 Poll::Ready(Err(e)) => {
143 Poll::Ready(Err(Arc::<sqlx::Error>::into_inner(e).unwrap()))
145 }
146 Poll::Pending => {
147 this.sink.flush_future = Some(fut);
149 Poll::Pending
150 }
151 }
152 } else {
153 Poll::Ready(Ok(()))
155 }
156 }
157
158 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
159 self.poll_flush(cx)
160 }
161}