1use std::{
2 pin::Pin,
3 sync::Arc,
4 task::{Context, Poll},
5};
6
7use futures::{
8 FutureExt, Sink,
9 future::{BoxFuture, Shared},
10};
11use sqlx::SqlitePool;
12use ulid::Ulid;
13
14use crate::{CompactType, SqliteStorage, SqliteTask, config::Config};
15
/// Boxed future that writes one batch of buffered tasks.
///
/// The error is wrapped in an [`Arc`] because the future is turned into a
/// [`Shared`] (see `.shared()` in the `Sink` impl below), which needs a `Clone`
/// output, and `sqlx::Error` itself is not cloned here.
type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
17
/// In-memory staging area for tasks that have been accepted but not yet
/// written to SQLite.
///
/// The `Sink` impl for `SqliteStorage` in this file pushes items into `buffer`.
/// On flush it moves the whole buffer into a `push_tasks` call and keeps the
/// resulting future in `flush_future` until it resolves.
///
/// `Args` and `Codec` appear only in `_marker`; no values of those types are stored.
#[pin_project::pin_project]
pub struct SqliteSink<Args, Compact, Codec> {
    // Database pool handle (cloned in `new` / `clone`).
    // NOTE(review): the visible `Sink::poll_flush` uses the storage's own `pool`,
    // not this field — confirm whether this one is used elsewhere.
    pool: SqlitePool,
    // Storage configuration (cloned in `new` / `clone`).
    // NOTE(review): likewise, `poll_flush` reads the storage's `config`, not this one.
    config: Config,
    // Tasks passed to `start_send` that have not yet been handed to `push_tasks`.
    buffer: Vec<SqliteTask<Compact>>,
    // The batch write currently in progress, if any; `None` when idle.
    #[pin]
    flush_future: Option<Shared<FlushFuture>>,
    _marker: std::marker::PhantomData<(Args, Codec)>,
}
27
28impl<Args, Compact, Codec> Clone for SqliteSink<Args, Compact, Codec> {
29 fn clone(&self) -> Self {
30 Self {
31 pool: self.pool.clone(),
32 config: self.config.clone(),
33 buffer: Vec::new(),
34 flush_future: None,
35 _marker: std::marker::PhantomData,
36 }
37 }
38}
39
40pub async fn push_tasks(
41 pool: SqlitePool,
42 cfg: Config,
43 buffer: Vec<SqliteTask<CompactType>>,
44) -> Result<(), Arc<sqlx::Error>> {
45 let mut tx = pool.begin().await?;
46 for task in buffer {
47 let id = task
48 .parts
49 .task_id
50 .map(|id| id.to_string())
51 .unwrap_or(Ulid::new().to_string());
52 let run_at = task.parts.run_at as i64;
53 let max_attempts = task.parts.ctx.max_attempts();
54 let priority = task.parts.ctx.priority();
55 let args = task.args;
56 let job_type = cfg.queue().to_string();
58 let meta = serde_json::to_string(&task.parts.ctx.meta()).unwrap_or_default();
59 sqlx::query_file!(
60 "queries/task/sink.sql",
61 args,
62 id,
63 job_type,
64 max_attempts,
65 run_at,
66 priority,
67 meta
68 )
69 .execute(&mut *tx)
70 .await?;
71 }
72 tx.commit().await?;
73
74 Ok(())
75}
76
77impl<Args, Compact, Codec> SqliteSink<Args, Compact, Codec> {
78 pub fn new(pool: &SqlitePool, config: &Config) -> Self {
79 Self {
80 pool: pool.clone(),
81 config: config.clone(),
82 buffer: Vec::new(),
83 _marker: std::marker::PhantomData,
84 flush_future: None,
85 }
86 }
87}
88
89impl<Args, Encode, Fetcher> Sink<SqliteTask<CompactType>> for SqliteStorage<Args, Encode, Fetcher>
90where
91 Args: Send + Sync + 'static,
92{
93 type Error = sqlx::Error;
94
95 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
96 Poll::Ready(Ok(()))
97 }
98
99 fn start_send(self: Pin<&mut Self>, item: SqliteTask<CompactType>) -> Result<(), Self::Error> {
100 self.project().sink.buffer.push(item);
102 Ok(())
103 }
104
105 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
106 let mut this = self.project();
107
108 if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
110 return Poll::Ready(Ok(()));
111 }
112
113 if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
115 let pool = this.pool.clone();
116 let config = this.config.clone();
117 let buffer = std::mem::take(&mut this.sink.buffer);
118 let sink_fut = push_tasks(pool, config, buffer);
119 this.sink.flush_future = Some((Box::pin(sink_fut) as FlushFuture).shared());
120 }
121
122 if let Some(mut fut) = this.sink.flush_future.take() {
124 match fut.poll_unpin(cx) {
125 Poll::Ready(Ok(())) => {
126 Poll::Ready(Ok(()))
128 }
129 Poll::Ready(Err(e)) => {
130 Poll::Ready(Err(Arc::into_inner(e).unwrap()))
132 }
133 Poll::Pending => {
134 this.sink.flush_future = Some(fut);
136 Poll::Pending
137 }
138 }
139 } else {
140 Poll::Ready(Ok(()))
142 }
143 }
144
145 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146 self.poll_flush(cx)
147 }
148}