1use std::{
2 pin::Pin,
3 sync::Arc,
4 task::{Context, Poll},
5};
6
7use futures::{
8 FutureExt, Sink,
9 future::{BoxFuture, Shared},
10};
11use sqlx::SqlitePool;
12use ulid::Ulid;
13
14use crate::{CompactType, SqliteStorage, SqliteTask, config::Config};
15
16type FlushFuture = BoxFuture<'static, Result<(), Arc<sqlx::Error>>>;
17
18#[pin_project::pin_project]
20#[derive(Debug)]
21pub struct SqliteSink<Args, Compact, Codec> {
22 pool: SqlitePool,
23 config: Config,
24 buffer: Vec<SqliteTask<Compact>>,
25 #[pin]
26 flush_future: Option<Shared<FlushFuture>>,
27 _marker: std::marker::PhantomData<(Args, Codec)>,
28}
29
30impl<Args, Compact, Codec> Clone for SqliteSink<Args, Compact, Codec> {
31 fn clone(&self) -> Self {
32 Self {
33 pool: self.pool.clone(),
34 config: self.config.clone(),
35 buffer: Vec::new(),
36 flush_future: None,
37 _marker: std::marker::PhantomData,
38 }
39 }
40}
41
42pub async fn push_tasks(
44 pool: SqlitePool,
45 cfg: Config,
46 buffer: Vec<SqliteTask<CompactType>>,
47) -> Result<(), Arc<sqlx::Error>> {
48 let mut tx = pool.begin().await?;
49 for task in buffer {
50 let id = task
51 .parts
52 .task_id
53 .map(|id| id.to_string())
54 .unwrap_or(Ulid::new().to_string());
55 let run_at = task.parts.run_at as i64;
56 let max_attempts = task.parts.ctx.max_attempts();
57 let priority = task.parts.ctx.priority();
58 let args = task.args;
59 let job_type = cfg.queue().to_string();
61 let meta = serde_json::to_string(&task.parts.ctx.meta()).unwrap_or_default();
62 sqlx::query_file!(
63 "queries/task/sink.sql",
64 args,
65 id,
66 job_type,
67 max_attempts,
68 run_at,
69 priority,
70 meta
71 )
72 .execute(&mut *tx)
73 .await?;
74 }
75 tx.commit().await?;
76
77 Ok(())
78}
79
80impl<Args, Compact, Codec> SqliteSink<Args, Compact, Codec> {
81 #[must_use]
83 pub fn new(pool: &SqlitePool, config: &Config) -> Self {
84 Self {
85 pool: pool.clone(),
86 config: config.clone(),
87 buffer: Vec::new(),
88 _marker: std::marker::PhantomData,
89 flush_future: None,
90 }
91 }
92}
93
94impl<Args, Encode, Fetcher> Sink<SqliteTask<CompactType>> for SqliteStorage<Args, Encode, Fetcher>
95where
96 Args: Send + Sync + 'static,
97{
98 type Error = sqlx::Error;
99
100 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101 Poll::Ready(Ok(()))
102 }
103
104 fn start_send(self: Pin<&mut Self>, item: SqliteTask<CompactType>) -> Result<(), Self::Error> {
105 self.project().sink.buffer.push(item);
107 Ok(())
108 }
109
110 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
111 let mut this = self.project();
112
113 if this.sink.flush_future.is_none() && this.sink.buffer.is_empty() {
115 return Poll::Ready(Ok(()));
116 }
117
118 if this.sink.flush_future.is_none() && !this.sink.buffer.is_empty() {
120 let pool = this.pool.clone();
121 let config = this.config.clone();
122 let buffer = std::mem::take(&mut this.sink.buffer);
123 let sink_fut = push_tasks(pool, config, buffer);
124 this.sink.flush_future = Some((Box::pin(sink_fut) as FlushFuture).shared());
125 }
126
127 if let Some(mut fut) = this.sink.flush_future.take() {
129 match fut.poll_unpin(cx) {
130 Poll::Ready(Ok(())) => {
131 Poll::Ready(Ok(()))
133 }
134 Poll::Ready(Err(e)) => {
135 Poll::Ready(Err(Arc::into_inner(e).unwrap()))
137 }
138 Poll::Pending => {
139 this.sink.flush_future = Some(fut);
141 Poll::Pending
142 }
143 }
144 } else {
145 Poll::Ready(Ok(()))
147 }
148 }
149
150 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
151 self.poll_flush(cx)
152 }
153}