1use std::{
2 cmp::max,
3 collections::{HashMap, HashSet},
4 future::ready,
5 marker::PhantomData,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10
11use crate::{
12 CompactType, Config, JOBS_TABLE, SqliteContext, SqliteStorage, SqliteTask,
13 ack::{LockTaskLayer, SqliteAck},
14 callback::{DbEvent, update_hook_callback},
15 fetcher::SqlitePollFetcher,
16 initial_heartbeat, keep_alive,
17};
18use crate::{from_row::SqliteTaskRow, sink::SqliteSink};
19use apalis_codec::json::JsonCodec;
20use apalis_core::{
21 backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue, shared::MakeShared},
22 layers::Stack,
23 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
24};
25use apalis_sql::from_row::TaskRow;
26use futures::{
27 FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
28 channel::mpsc::{self, Receiver, Sender},
29 future::{BoxFuture, Shared},
30 lock::Mutex,
31 stream::{self, BoxStream, select},
32};
33use sqlx::{Sqlite, SqlitePool, pool::PoolOptions, sqlite::SqliteOperation};
34use ulid::Ulid;
35
/// SQLite storage that can hand out several backends sharing one connection
/// pool and one driver future.
///
/// `Decode` names the codec type; it is only carried as a type parameter.
#[derive(Clone, Debug)]
pub struct SharedSqliteStorage<Decode> {
    /// Connection pool used by this storage and cloned into every backend it creates.
    pool: SqlitePool,
    /// Maps a queue name to the sender that feeds the backend registered for it.
    registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>>,
    /// Shared driver future; a clone is given to every `SharedFetcher`.
    drive: Shared<BoxFuture<'static, ()>>,
    /// Zero-sized marker tying the storage to its `Decode` parameter.
    _marker: PhantomData<Decode>,
}
44
45impl<Decode> SharedSqliteStorage<Decode> {
46 #[must_use]
48 pub fn pool(&self) -> &SqlitePool {
49 &self.pool
50 }
51}
52
53impl SharedSqliteStorage<JsonCodec<CompactType>> {
54 #[must_use]
56 pub fn new(url: &str) -> Self {
57 Self::new_with_codec(url)
58 }
59 #[must_use]
61 pub fn new_with_codec<Codec>(url: &str) -> SharedSqliteStorage<Codec> {
62 let (tx, rx) = mpsc::unbounded::<DbEvent>();
63 let pool = PoolOptions::<Sqlite>::new()
64 .after_connect(move |conn, _meta| {
65 let mut tx = tx.clone();
66 Box::pin(async move {
67 let mut lock_handle = conn.lock_handle().await?;
68 lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
69 Ok(())
70 })
71 })
72 .connect_lazy(url)
73 .expect("Failed to create Sqlite pool");
74
75 let registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>> =
76 Arc::new(Mutex::new(HashMap::default()));
77 let p = pool.clone();
78 let instances = registry.clone();
79 SharedSqliteStorage {
80 pool,
81 drive: async move {
82 rx.filter(|a| {
83 ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
84 })
85 .ready_chunks(instances.try_lock().map(|r| r.len()).unwrap_or(10))
86 .then(|events| {
87 let row_ids = events.iter().map(|e| e.rowid()).collect::<HashSet<i64>>();
88 let instances = instances.clone();
89 let pool = p.clone();
90 async move {
91 let instances = instances.lock().await;
92 let job_types = serde_json::to_string(
93 &instances.keys().cloned().collect::<Vec<String>>(),
94 )
95 .unwrap();
96 let row_ids = serde_json::to_string(&row_ids).unwrap();
97 let mut tx = pool.begin().await?;
98 let buffer_size = max(10, instances.len()) as i32;
99 let res: Vec<_> = sqlx::query_file_as!(
100 SqliteTaskRow,
101 "queries/backend/fetch_next_shared.sql",
102 job_types,
103 row_ids,
104 buffer_size,
105 )
106 .fetch(&mut *tx)
107 .map(|r| {
108 let row: TaskRow = r?.try_into()?;
109 row.try_into_task_compact()
110 .map_err(|e| sqlx::Error::Protocol(e.to_string()))
111 })
112 .try_collect()
113 .await?;
114 tx.commit().await?;
115 Ok::<_, sqlx::Error>(res)
116 }
117 })
118 .for_each(|r| async {
119 match r {
120 Ok(tasks) => {
121 let mut instances = instances.lock().await;
122 for task in tasks {
123 if let Some(tx) = instances.get_mut(
124 task.parts
125 .ctx
126 .queue()
127 .as_ref()
128 .expect("Namespace must be set"),
129 ) {
130 let _ = tx.send(task).await;
131 }
132 }
133 }
134 Err(e) => {
135 log::error!("Error fetching tasks: {e:?}");
136 }
137 }
138 })
139 .await;
140 }
141 .boxed()
142 .shared(),
143 registry,
144 _marker: PhantomData,
145 }
146 }
147}
148
/// Errors returned when creating a backend from a [`SharedSqliteStorage`].
#[derive(Debug, thiserror::Error)]
pub enum SharedSqliteError {
    /// A backend is already registered under this queue name.
    #[error("Namespace {0} already exists")]
    NamespaceExists(String),
    /// The registry lock could not be taken without waiting.
    #[error("Could not acquire registry lock")]
    RegistryLocked,
}
159
160impl<Args, Decode: Codec<Args, Compact = CompactType>> MakeShared<Args>
161 for SharedSqliteStorage<Decode>
162{
163 type Backend = SqliteStorage<Args, Decode, SharedFetcher<CompactType>>;
164 type Config = Config;
165 type MakeError = SharedSqliteError;
166 fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
167 where
168 Self::Config: Default,
169 {
170 Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
171 }
172 fn make_shared_with_config(
173 &mut self,
174 config: Self::Config,
175 ) -> Result<Self::Backend, Self::MakeError> {
176 let (tx, rx) = mpsc::channel(config.buffer_size());
177 let mut r = self
178 .registry
179 .try_lock()
180 .ok_or(SharedSqliteError::RegistryLocked)?;
181 if r.insert(config.queue().to_string(), tx).is_some() {
182 return Err(SharedSqliteError::NamespaceExists(
183 config.queue().to_string(),
184 ));
185 }
186 let sink = SqliteSink::new(&self.pool, &config);
187 Ok(SqliteStorage {
188 config,
189 fetcher: SharedFetcher {
190 poller: self.drive.clone(),
191 receiver: Arc::new(std::sync::Mutex::new(rx)),
192 },
193 pool: self.pool.clone(),
194 sink,
195 job_type: PhantomData,
196 codec: PhantomData,
197 })
198 }
199}
200
/// Stream of tasks for one queue of a shared storage.
#[derive(Clone, Debug)]
pub struct SharedFetcher<Compact> {
    /// Clone of the storage's shared driver future; polled on every `poll_next`.
    poller: Shared<BoxFuture<'static, ()>>,
    /// Receiving end of this queue's task channel, shared between clones.
    receiver: Arc<std::sync::Mutex<Receiver<SqliteTask<Compact>>>>,
}
206
207impl<Compact> Stream for SharedFetcher<Compact> {
208 type Item = SqliteTask<Compact>;
209
210 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
211 let this = self.get_mut();
212 let _ = this.poller.poll_unpin(cx);
214
215 this.receiver.lock().unwrap().poll_next_unpin(cx)
217 }
218}
219
220impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
221where
222 Args: Send + 'static + Unpin + Sync,
223 Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send + Sync,
224 Decode::Error: std::error::Error + Send + Sync + 'static,
225{
226 type Args = Args;
227
228 type IdType = Ulid;
229
230 type Error = sqlx::Error;
231
232 type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
233
234 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
235
236 type Context = SqliteContext;
237
238 type Layer = Stack<AcknowledgeLayer<SqliteAck>, LockTaskLayer>;
239
240 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
241 let keep_alive_interval = *self.config.keep_alive();
242 let pool = self.pool.clone();
243 let worker = worker.clone();
244 let config = self.config.clone();
245
246 stream::unfold((), move |()| async move {
247 apalis_core::timer::sleep(keep_alive_interval).await;
248 Some(((), ()))
249 })
250 .then(move |_| keep_alive(pool.clone(), config.clone(), worker.clone()))
251 .boxed()
252 }
253
254 fn middleware(&self) -> Self::Layer {
255 let lock = LockTaskLayer::new(self.pool.clone());
256 let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
257 Stack::new(ack, lock)
258 }
259
260 fn poll(self, worker: &WorkerContext) -> Self::Stream {
261 self.poll_shared(worker)
262 .map(|a| match a {
263 Ok(Some(task)) => Ok(Some(
264 task.try_map(|t| Decode::decode(&t))
265 .map_err(|e| sqlx::Error::Decode(e.into()))?,
266 )),
267 Ok(None) => Ok(None),
268 Err(e) => Err(e),
269 })
270 .boxed()
271 }
272}
273
274impl<Args, Decode: Send + 'static> BackendExt
275 for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
276where
277 Self: Backend<Args = Args, IdType = Ulid, Context = SqliteContext, Error = sqlx::Error>,
278 Decode: Codec<Args, Compact = CompactType> + Send + 'static,
279 Decode::Error: std::error::Error + Send + Sync + 'static,
280 Args: Send + 'static + Unpin,
281{
282 type Codec = Decode;
283 type Compact = CompactType;
284 type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
285
286 fn get_queue(&self) -> Queue {
287 self.config.queue().to_owned()
288 }
289
290 fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
291 self.poll_shared(worker).boxed()
292 }
293}
294
295impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, SharedFetcher<CompactType>> {
296 fn poll_shared(
297 self,
298 worker: &WorkerContext,
299 ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + 'static {
300 let pool = self.pool.clone();
301 let worker = worker.clone();
302 let init = initial_heartbeat(
308 pool,
309 self.config.clone(),
310 worker.clone(),
311 "SharedSqliteStorage",
312 );
313 let starter = stream::once(init)
314 .map_ok(|_| None) .boxed();
316 let lazy_fetcher = self.fetcher.map(|s| Ok(Some(s))).boxed();
317
318 let eager_fetcher = StreamExt::boxed(SqlitePollFetcher::<CompactType, Decode>::new(
319 &self.pool,
320 &self.config,
321 &worker,
322 ));
323 starter.chain(select(lazy_fetcher, eager_fetcher)).boxed()
324 }
325}
326
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::TaskSink, error::BoxDynError, task::task_id::TaskId,
        worker::builder::WorkerBuilder,
    };

    use super::*;

    /// Two backends created from one shared storage each receive a task and
    /// their workers run to completion.
    #[tokio::test]
    async fn basic_worker() {
        /// Handler used by both workers: waits two seconds, then stops its worker.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let mut shared = SharedSqliteStorage::new(":memory:");
        SqliteStorage::setup(shared.pool()).await.unwrap();

        // The argument type of each backend is inferred from what is pushed below.
        let mut map_backend = shared.make_shared().unwrap();
        let mut int_backend = shared.make_shared().unwrap();

        let payload = HashMap::<String, i32>::from([("value".to_string(), 42)]);
        map_backend.push(payload).await.unwrap();
        int_backend.push(99).await.unwrap();

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_backend)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_backend)
            .build(send_reminder);

        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}