1use std::{
2 collections::VecDeque,
3 marker::PhantomData,
4 pin::Pin,
5 sync::{Arc, atomic::AtomicUsize},
6 task::{Context, Poll},
7};
8
9use apalis_core::{
10 backend::poll_strategy::{PollContext, PollStrategyExt},
11 task::Task,
12 worker::context::WorkerContext,
13};
14use apalis_sql::{context::SqlContext, from_row::TaskRow};
15use futures::{FutureExt, future::BoxFuture, stream::Stream};
16use pin_project::pin_project;
17use sqlx::{Pool, Sqlite, SqlitePool};
18use ulid::Ulid;
19
20use crate::{CompactType, SqliteTask, config::Config, from_row::SqliteTaskRow};
21
22pub async fn fetch_next(
24 pool: SqlitePool,
25 config: Config,
26 worker: WorkerContext,
27) -> Result<Vec<Task<CompactType, SqlContext, Ulid>>, sqlx::Error>
28where
29{
30 let job_type = config.queue().to_string();
31 let buffer_size = config.buffer_size() as i32;
32 let worker = worker.name().clone();
33 sqlx::query_file_as!(
34 SqliteTaskRow,
35 "queries/backend/fetch_next.sql",
36 worker,
37 job_type,
38 buffer_size
39 )
40 .fetch_all(&pool)
41 .await?
42 .into_iter()
43 .map(|r| {
44 let row: TaskRow = r.try_into()?;
45 row.try_into_task_compact::<Ulid>()
46 .map_err(|e| sqlx::Error::Protocol(e.to_string()))
47 })
48 .collect()
49}
50
/// Internal state machine driving the `Stream` implementation of `SqlitePollFetcher`.
enum StreamState {
    /// A new fetch query should be started on the next loop iteration.
    Ready,
    /// The last fetch returned no rows; waiting for the poll strategy's next tick.
    Delay,
    /// A fetch query is in flight.
    Fetch(BoxFuture<'static, Result<Vec<SqliteTask<CompactType>>, sqlx::Error>>),
    /// Tasks returned by the last fetch that have not been yielded yet.
    Buffered(VecDeque<SqliteTask<CompactType>>),
    /// Terminal state: the delay stream ended or a fetch failed; the stream yields `None`.
    Empty,
}
58
/// Unit marker type for the SQLite fetcher.
///
/// NOTE(review): not referenced in this file; its role is defined by users elsewhere.
#[derive(Clone, Debug)]
pub struct SqliteFetcher;
/// A polling [`Stream`] that repeatedly fetches batches of tasks for one worker and
/// yields them one at a time.
///
/// NOTE(review): `state` and `delay_stream` are marked `#[pin]`, but the `Stream`
/// implementation accesses fields through `Pin::get_mut` rather than the pin projection.
#[pin_project]
pub struct SqlitePollFetcher<Compact, Decode> {
    /// Connection pool used for every fetch query.
    pool: SqlitePool,
    /// Provides the queue name, buffer size and poll strategy.
    config: Config,
    /// Context of the worker being fed; its name is bound into each fetch query.
    wrk: WorkerContext,
    _marker: PhantomData<(Compact, Decode)>,
    /// Current position in the ready / fetch / buffered / delay state machine.
    #[pin]
    state: StreamState,

    /// Tick stream built from the configured poll strategy; created lazily on first poll.
    #[pin]
    delay_stream: Option<Pin<Box<dyn Stream<Item = ()> + Send>>>,

    /// Counter shared with the poll strategy through its `PollContext`.
    prev_count: Arc<AtomicUsize>,
}
77
78impl<Compact, Decode> std::fmt::Debug for SqlitePollFetcher<Compact, Decode> {
79 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80 f.debug_struct("SqlitePollFetcher")
81 .field("pool", &self.pool)
82 .field("config", &self.config)
83 .field("wrk", &self.wrk)
84 .field("_marker", &self._marker)
85 .field("prev_count", &self.prev_count)
86 .finish()
87 }
88}
89
90impl<Compact, Decode> Clone for SqlitePollFetcher<Compact, Decode> {
91 fn clone(&self) -> Self {
92 Self {
93 pool: self.pool.clone(),
94 config: self.config.clone(),
95 wrk: self.wrk.clone(),
96 _marker: PhantomData,
97 state: StreamState::Ready,
98 delay_stream: None,
99 prev_count: Arc::new(AtomicUsize::new(0)),
100 }
101 }
102}
103
104impl<Decode> SqlitePollFetcher<CompactType, Decode> {
105 #[must_use]
107 pub fn new(pool: &Pool<Sqlite>, config: &Config, wrk: &WorkerContext) -> Self {
108 Self {
109 pool: pool.clone(),
110 config: config.clone(),
111 wrk: wrk.clone(),
112 _marker: PhantomData,
113 state: StreamState::Ready,
114 delay_stream: None,
115 prev_count: Arc::new(AtomicUsize::new(0)),
116 }
117 }
118}
119
120impl<Decode> Stream for SqlitePollFetcher<CompactType, Decode> {
121 type Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>;
122
123 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124 let this = self.get_mut();
125 if this.delay_stream.is_none() {
126 let strategy = this
127 .config
128 .poll_strategy()
129 .clone()
130 .build_stream(&PollContext::new(this.wrk.clone(), this.prev_count.clone()));
131 this.delay_stream = Some(Box::pin(strategy));
132 }
133
134 loop {
135 match this.state {
136 StreamState::Ready => {
137 let stream =
138 fetch_next(this.pool.clone(), this.config.clone(), this.wrk.clone());
139 this.state = StreamState::Fetch(stream.boxed());
140 }
141 StreamState::Delay => {
142 if let Some(delay_stream) = this.delay_stream.as_mut() {
143 match delay_stream.as_mut().poll_next(cx) {
144 Poll::Pending => return Poll::Pending,
145 Poll::Ready(Some(_)) => {
146 this.state = StreamState::Ready;
147 }
148 Poll::Ready(None) => {
149 this.state = StreamState::Empty;
150 return Poll::Ready(None);
151 }
152 }
153 } else {
154 this.state = StreamState::Empty;
155 return Poll::Ready(None);
156 }
157 }
158
159 StreamState::Fetch(ref mut fut) => match fut.poll_unpin(cx) {
160 Poll::Pending => return Poll::Pending,
161 Poll::Ready(item) => match item {
162 Ok(requests) => {
163 if requests.is_empty() {
164 this.state = StreamState::Delay;
165 } else {
166 let mut buffer = VecDeque::new();
167 for request in requests {
168 buffer.push_back(request);
169 }
170
171 this.state = StreamState::Buffered(buffer);
172 }
173 }
174 Err(e) => {
175 this.state = StreamState::Empty;
176 return Poll::Ready(Some(Err(e)));
177 }
178 },
179 },
180
181 StreamState::Buffered(ref mut buffer) => {
182 if let Some(request) = buffer.pop_front() {
183 if buffer.is_empty() {
185 this.state = StreamState::Ready;
187 }
188 return Poll::Ready(Some(Ok(Some(request))));
189 } else {
190 this.state = StreamState::Ready;
192 }
193 }
194
195 StreamState::Empty => return Poll::Ready(None),
196 }
197 }
198 }
199}
200
201impl<Compact, Decode> SqlitePollFetcher<Compact, Decode> {
202 pub fn take_pending(&mut self) -> VecDeque<SqliteTask<Vec<u8>>> {
204 match &mut self.state {
205 StreamState::Buffered(tasks) => std::mem::take(tasks),
206 _ => VecDeque::new(),
207 }
208 }
209}