1use std::{
2 collections::VecDeque,
3 marker::PhantomData,
4 pin::Pin,
5 sync::{Arc, atomic::AtomicUsize},
6 task::{Context, Poll},
7};
8
9use apalis_core::{
10 backend::{
11 codec::Codec,
12 poll_strategy::{PollContext, PollStrategyExt},
13 },
14 task::Task,
15 worker::context::WorkerContext,
16};
17use apalis_sql::{context::SqlContext, from_row::TaskRow};
18use futures::{FutureExt, future::BoxFuture, stream::Stream};
19use pin_project::pin_project;
20use sqlx::{Pool, Sqlite, SqlitePool};
21use ulid::Ulid;
22
23use crate::{CompactType, SqliteTask, config::Config, from_row::SqliteTaskRow};
24
25pub async fn fetch_next<Args, D: Codec<Args, Compact = CompactType>>(
26 pool: SqlitePool,
27 config: Config,
28 worker: WorkerContext,
29) -> Result<Vec<Task<Args, SqlContext, Ulid>>, sqlx::Error>
30where
31 D::Error: std::error::Error + Send + Sync + 'static,
32 Args: 'static,
33{
34 let job_type = config.queue().to_string();
35 let buffer_size = config.buffer_size() as i32;
36 let worker = worker.name().to_string();
37 sqlx::query_file_as!(
38 SqliteTaskRow,
39 "queries/backend/fetch_next.sql",
40 worker,
41 job_type,
42 buffer_size
43 )
44 .fetch_all(&pool)
45 .await?
46 .into_iter()
47 .map(|r| {
48 let row: TaskRow = r.try_into()?;
49 row.try_into_task::<D, Args, Ulid>()
50 .map_err(|e| sqlx::Error::Protocol(e.to_string()))
51 })
52 .collect()
53}
54
55enum StreamState<Args> {
56 Ready,
57 Delay,
58 Fetch(BoxFuture<'static, Result<Vec<SqliteTask<Args>>, sqlx::Error>>),
59 Buffered(VecDeque<SqliteTask<Args>>),
60 Empty,
61}
62
/// Zero-sized type-level marker tying together the argument, compact and decoder types.
///
/// `Clone` and `Debug` are implemented by hand instead of derived. A derive would
/// require `Args`, `Compact` and `Decode` to implement those traits, even though
/// only a [`PhantomData`] is stored.
pub struct SqliteFetcher<Args, Compact, Decode> {
    pub _marker: PhantomData<(Args, Compact, Decode)>,
}

impl<Args, Compact, Decode> Clone for SqliteFetcher<Args, Compact, Decode> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<Args, Compact, Decode> std::fmt::Debug for SqliteFetcher<Args, Compact, Decode> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the previously derived output.
        f.debug_struct("SqliteFetcher")
            .field("_marker", &self._marker)
            .finish()
    }
}
68
/// A [`Stream`] that repeatedly fetches batches of tasks from SQLite on behalf of one worker.
///
/// Each fetched batch is buffered and yielded one task at a time. When a fetch
/// returns no rows, the stream waits on a delay stream built from the configured
/// poll strategy before fetching again.
#[pin_project]
pub struct SqlitePollFetcher<Args, Compact, Decode> {
    /// Connection pool used for every fetch query.
    pool: SqlitePool,
    /// Backend configuration (queue name, buffer size, poll strategy).
    config: Config,
    /// Worker on whose behalf tasks are fetched.
    wrk: WorkerContext,
    _marker: PhantomData<(Compact, Decode)>,
    /// Current position in the fetch state machine.
    #[pin]
    state: StreamState<Args>,

    /// Delay stream built from `config.poll_strategy()`; `None` until first polled.
    #[pin]
    delay_stream: Option<Pin<Box<dyn Stream<Item = ()> + Send>>>,

    /// Counter passed to the poll strategy through `PollContext`.
    // NOTE(review): nothing in this file writes to this counter; confirm whether
    // fetches are expected to record their batch size here.
    prev_count: Arc<AtomicUsize>,
}
84
85impl<Args, Compact, Decode> Clone for SqlitePollFetcher<Args, Compact, Decode> {
86 fn clone(&self) -> Self {
87 Self {
88 pool: self.pool.clone(),
89 config: self.config.clone(),
90 wrk: self.wrk.clone(),
91 _marker: PhantomData,
92 state: StreamState::Ready,
93 delay_stream: None,
94 prev_count: Arc::new(AtomicUsize::new(0)),
95 }
96 }
97}
98
99impl<Args: 'static, Decode> SqlitePollFetcher<Args, CompactType, Decode> {
100 pub fn new(pool: &Pool<Sqlite>, config: &Config, wrk: &WorkerContext) -> Self
101 where
102 Decode: Codec<Args, Compact = CompactType> + 'static,
103 Decode::Error: std::error::Error + Send + Sync + 'static,
104 {
105 Self {
106 pool: pool.clone(),
107 config: config.clone(),
108 wrk: wrk.clone(),
109 _marker: PhantomData,
110 state: StreamState::Ready,
111 delay_stream: None,
112 prev_count: Arc::new(AtomicUsize::new(0)),
113 }
114 }
115}
116
117impl<Args, Decode> Stream for SqlitePollFetcher<Args, CompactType, Decode>
118where
119 Decode::Error: std::error::Error + Send + Sync + 'static,
120 Args: Send + 'static + Unpin,
121 Decode: Codec<Args, Compact = CompactType> + 'static,
122{
123 type Item = Result<Option<SqliteTask<Args>>, sqlx::Error>;
124
125 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
126 let this = self.get_mut();
127 if this.delay_stream.is_none() {
128 let strategy = this
129 .config
130 .poll_strategy()
131 .clone()
132 .build_stream(&PollContext::new(this.wrk.clone(), this.prev_count.clone()));
133 this.delay_stream = Some(Box::pin(strategy));
134 }
135
136 loop {
137 match this.state {
138 StreamState::Ready => {
139 let stream = fetch_next::<Args, Decode>(
140 this.pool.clone(),
141 this.config.clone(),
142 this.wrk.clone(),
143 );
144 this.state = StreamState::Fetch(stream.boxed());
145 }
146 StreamState::Delay => {
147 if let Some(delay_stream) = this.delay_stream.as_mut() {
148 match delay_stream.as_mut().poll_next(cx) {
149 Poll::Pending => return Poll::Pending,
150 Poll::Ready(Some(_)) => {
151 this.state = StreamState::Ready;
152 }
153 Poll::Ready(None) => {
154 this.state = StreamState::Empty;
155 return Poll::Ready(None);
156 }
157 }
158 } else {
159 this.state = StreamState::Empty;
160 return Poll::Ready(None);
161 }
162 }
163
164 StreamState::Fetch(ref mut fut) => match fut.poll_unpin(cx) {
165 Poll::Pending => return Poll::Pending,
166 Poll::Ready(item) => match item {
167 Ok(requests) => {
168 if requests.is_empty() {
169 this.state = StreamState::Delay;
170 } else {
171 let mut buffer = VecDeque::new();
172 for request in requests {
173 buffer.push_back(request);
174 }
175
176 this.state = StreamState::Buffered(buffer);
177 }
178 }
179 Err(e) => {
180 this.state = StreamState::Empty;
181 return Poll::Ready(Some(Err(e)));
182 }
183 },
184 },
185
186 StreamState::Buffered(ref mut buffer) => {
187 if let Some(request) = buffer.pop_front() {
188 if buffer.is_empty() {
190 this.state = StreamState::Ready;
192 }
193 return Poll::Ready(Some(Ok(Some(request))));
194 } else {
195 this.state = StreamState::Ready;
197 }
198 }
199
200 StreamState::Empty => return Poll::Ready(None),
201 }
202 }
203 }
204}
205
206impl<Args, Compact, Decode> SqlitePollFetcher<Args, Compact, Decode> {
207 pub fn take_pending(&mut self) -> VecDeque<SqliteTask<Args>> {
208 match &mut self.state {
209 StreamState::Buffered(tasks) => std::mem::take(tasks),
210 _ => VecDeque::new(),
211 }
212 }
213}