1use std::{
2 cmp::max,
3 collections::{HashMap, HashSet},
4 future::ready,
5 marker::PhantomData,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10
11use crate::{
12 CompactType, Config, JOBS_TABLE, SqliteContext, SqliteStorage, SqliteTask,
13 ack::{LockTaskLayer, SqliteAck},
14 callback::{DbEvent, update_hook_callback},
15 fetcher::SqlitePollFetcher,
16 initial_heartbeat, keep_alive,
17};
18use crate::{from_row::SqliteTaskRow, sink::SqliteSink};
19use apalis_codec::json::JsonCodec;
20use apalis_core::{
21 backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue, shared::MakeShared},
22 layers::Stack,
23 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
24};
25use apalis_sql::TaskRow;
26use futures::{
27 FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
28 channel::mpsc::{self, Receiver, Sender},
29 future::{BoxFuture, Shared},
30 lock::Mutex,
31 stream::{self, BoxStream, select},
32};
33use sqlx::{Sqlite, SqlitePool, pool::PoolOptions, sqlite::SqliteOperation};
34use ulid::Ulid;
35
36#[derive(Clone, Debug)]
38pub struct SharedSqliteStorage<Decode> {
39 pool: SqlitePool,
40 registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>>,
41 drive: Shared<BoxFuture<'static, ()>>,
42 _marker: PhantomData<Decode>,
43}
44
45impl<Decode> SharedSqliteStorage<Decode> {
46 #[must_use]
48 pub fn pool(&self) -> &SqlitePool {
49 &self.pool
50 }
51}
52
53impl SharedSqliteStorage<JsonCodec<CompactType>> {
54 #[must_use]
56 pub fn new(url: &str) -> Self {
57 Self::new_with_codec(url)
58 }
59
60 #[must_use]
62 pub fn new_with_codec<Codec>(url: &str) -> SharedSqliteStorage<Codec> {
63 Self::new_with_pool_options(
64 url,
65 PoolOptions::new().max_lifetime(None).idle_timeout(None),
66 )
67 }
68
69 #[must_use]
71 pub fn new_with_pool_options<Codec>(
72 url: &str,
73 options: PoolOptions<Sqlite>,
74 ) -> SharedSqliteStorage<Codec> {
75 let (tx, rx) = mpsc::unbounded::<DbEvent>();
76 let pool = options
77 .after_connect(move |conn, _meta| {
78 let mut tx = tx.clone();
79 Box::pin(async move {
80 let mut lock_handle = conn.lock_handle().await?;
81 lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
82 Ok(())
83 })
84 })
85 .connect_lazy(url)
86 .expect("Failed to create Sqlite pool");
87
88 let registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>> =
89 Arc::new(Mutex::new(HashMap::default()));
90 let p = pool.clone();
91 let instances = registry.clone();
92 SharedSqliteStorage {
93 pool,
94 drive: async move {
95 rx.filter(|a| {
96 ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
97 })
98 .ready_chunks(instances.try_lock().map(|r| r.len()).unwrap_or(10))
99 .then(|events| {
100 let row_ids = events.iter().map(|e| e.rowid()).collect::<HashSet<i64>>();
101 let instances = instances.clone();
102 let pool = p.clone();
103 async move {
104 let instances = instances.lock().await;
105 let job_types = serde_json::to_string(
106 &instances.keys().cloned().collect::<Vec<String>>(),
107 )
108 .unwrap();
109 let row_ids = serde_json::to_string(&row_ids).unwrap();
110 let mut tx = pool.begin().await?;
111 let buffer_size = max(10, instances.len()) as i32;
112 let res: Vec<_> = sqlx::query_file_as!(
113 SqliteTaskRow,
114 "queries/backend/fetch_next_shared.sql",
115 job_types,
116 row_ids,
117 buffer_size,
118 )
119 .fetch(&mut *tx)
120 .map(|r| {
121 let row: TaskRow = r?.try_into()?;
122 row.try_into_task_compact()
123 .map_err(|e| sqlx::Error::Protocol(e.to_string()))
124 })
125 .try_collect()
126 .await?;
127 tx.commit().await?;
128 Ok::<_, sqlx::Error>(res)
129 }
130 })
131 .for_each(|r| async {
132 match r {
133 Ok(tasks) => {
134 let mut instances = instances.lock().await;
135 for task in tasks {
136 if let Some(tx) = instances.get_mut(
137 task.parts
138 .ctx
139 .queue()
140 .as_ref()
141 .expect("Namespace must be set"),
142 ) {
143 let _ = tx.send(task).await;
144 }
145 }
146 }
147 Err(e) => {
148 log::error!("Error fetching tasks: {e:?}");
149 }
150 }
151 })
152 .await;
153 }
154 .boxed()
155 .shared(),
156 registry,
157 _marker: PhantomData,
158 }
159 }
160}
161
162#[derive(Debug, thiserror::Error)]
164pub enum SharedSqliteError {
165 #[error("Namespace {0} already exists")]
167 NamespaceExists(String),
168 #[error("Could not acquire registry lock")]
170 RegistryLocked,
171}
172
173impl<Args, Decode: Codec<Args, Compact = CompactType>> MakeShared<Args>
174 for SharedSqliteStorage<Decode>
175{
176 type Backend = SqliteStorage<Args, Decode, SharedFetcher<CompactType>>;
177 type Config = Config;
178 type MakeError = SharedSqliteError;
179 fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
180 where
181 Self::Config: Default,
182 {
183 Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
184 }
185 fn make_shared_with_config(
186 &mut self,
187 config: Self::Config,
188 ) -> Result<Self::Backend, Self::MakeError> {
189 let (tx, rx) = mpsc::channel(config.buffer_size());
190 let mut r = self
191 .registry
192 .try_lock()
193 .ok_or(SharedSqliteError::RegistryLocked)?;
194 if r.insert(config.queue().to_string(), tx).is_some() {
195 return Err(SharedSqliteError::NamespaceExists(
196 config.queue().to_string(),
197 ));
198 }
199 let sink = SqliteSink::new(&self.pool, &config);
200 Ok(SqliteStorage {
201 config,
202 fetcher: SharedFetcher {
203 poller: self.drive.clone(),
204 receiver: Arc::new(std::sync::Mutex::new(rx)),
205 },
206 pool: self.pool.clone(),
207 sink,
208 job_type: PhantomData,
209 codec: PhantomData,
210 })
211 }
212}
213
214#[derive(Clone, Debug)]
215pub struct SharedFetcher<Compact> {
216 poller: Shared<BoxFuture<'static, ()>>,
217 receiver: Arc<std::sync::Mutex<Receiver<SqliteTask<Compact>>>>,
218}
219
220impl<Compact> Stream for SharedFetcher<Compact> {
221 type Item = SqliteTask<Compact>;
222
223 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224 let this = self.get_mut();
225 let _ = this.poller.poll_unpin(cx);
227
228 this.receiver.lock().unwrap().poll_next_unpin(cx)
230 }
231}
232
233impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
234where
235 Args: Send + 'static + Unpin + Sync,
236 Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send + Sync,
237 Decode::Error: std::error::Error + Send + Sync + 'static,
238{
239 type Args = Args;
240
241 type IdType = Ulid;
242
243 type Error = sqlx::Error;
244
245 type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
246
247 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
248
249 type Context = SqliteContext;
250
251 type Layer = Stack<AcknowledgeLayer<SqliteAck>, LockTaskLayer>;
252
253 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
254 let keep_alive_interval = *self.config.keep_alive();
255 let pool = self.pool.clone();
256 let worker = worker.clone();
257 let config = self.config.clone();
258
259 stream::unfold((), move |()| async move {
260 apalis_core::timer::sleep(keep_alive_interval).await;
261 Some(((), ()))
262 })
263 .then(move |_| keep_alive(pool.clone(), config.clone(), worker.clone()))
264 .boxed()
265 }
266
267 fn middleware(&self) -> Self::Layer {
268 let lock = LockTaskLayer::new(self.pool.clone());
269 let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
270 Stack::new(ack, lock)
271 }
272
273 fn poll(self, worker: &WorkerContext) -> Self::Stream {
274 self.poll_shared(worker)
275 .map(|a| match a {
276 Ok(Some(task)) => Ok(Some(
277 task.try_map(|t| Decode::decode(&t))
278 .map_err(|e| sqlx::Error::Decode(e.into()))?,
279 )),
280 Ok(None) => Ok(None),
281 Err(e) => Err(e),
282 })
283 .boxed()
284 }
285}
286
287impl<Args, Decode: Send + 'static> BackendExt
288 for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
289where
290 Self: Backend<Args = Args, IdType = Ulid, Context = SqliteContext, Error = sqlx::Error>,
291 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
292 Decode::Error: std::error::Error + Send + Sync + 'static,
293 Args: Send + 'static + Unpin,
294{
295 type Codec = Decode;
296 type Compact = CompactType;
297 type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
298
299 fn get_queue(&self) -> Queue {
300 self.config.queue().clone()
301 }
302
303 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
304 self.poll_shared(worker).boxed()
305 }
306}
307
308impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, SharedFetcher<CompactType>> {
309 fn poll_shared(
310 self,
311 worker: &WorkerContext,
312 ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + 'static {
313 let pool = self.pool.clone();
314 let worker = worker.clone();
315 let init = initial_heartbeat(
321 pool,
322 self.config.clone(),
323 worker.clone(),
324 "SharedSqliteStorage",
325 );
326 let starter = stream::once(init)
327 .map_ok(|_| None) .boxed();
329 let lazy_fetcher = self.fetcher.map(|s| Ok(Some(s))).boxed();
330
331 let eager_fetcher = StreamExt::boxed(SqlitePollFetcher::<CompactType, Decode>::new(
332 &self.pool,
333 &self.config,
334 &worker,
335 ));
336 starter.chain(select(lazy_fetcher, eager_fetcher)).boxed()
337 }
338}
339
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::TaskSink, error::BoxDynError, task::task_id::TaskId,
        worker::builder::WorkerBuilder,
    };

    use super::*;

    /// Handler shared by both workers: waits two seconds, then stops the worker
    /// that ran it.
    async fn send_reminder<T, I>(
        _: T,
        _task_id: TaskId<I>,
        wrk: WorkerContext,
    ) -> Result<(), BoxDynError> {
        tokio::time::sleep(Duration::from_secs(2)).await;
        wrk.stop().unwrap();
        Ok(())
    }

    /// Two differently-typed queues registered on one shared storage each
    /// receive and process their own task.
    #[tokio::test]
    async fn basic_worker() {
        let mut shared = SharedSqliteStorage::new(":memory:");
        SqliteStorage::setup(shared.pool()).await.unwrap();

        let mut map_backend = shared.make_shared().unwrap();
        let mut int_backend = shared.make_shared().unwrap();

        let map_payload = HashMap::<String, i32>::from([("value".to_string(), 42)]);
        map_backend.push(map_payload).await.unwrap();
        int_backend.push(99).await.unwrap();

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_backend)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_backend)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}