use std::{
    cmp::max,
    collections::{HashMap, HashSet},
    future::ready,
    marker::PhantomData,
    pin::Pin,
    sync::Arc,
    task::{Context, Poll},
};

use crate::{
    CompactType, Config, JOBS_TABLE, SqliteStorage, SqliteTask,
    ack::{LockTaskLayer, SqliteAck},
    callback::{DbEvent, update_hook_callback},
    fetcher::SqlitePollFetcher,
    initial_heartbeat, keep_alive,
};
use crate::{from_row::SqliteTaskRow, sink::SqliteSink};
use apalis_core::{
    backend::{
        Backend, BackendExt, TaskStream,
        codec::{Codec, json::JsonCodec},
        shared::MakeShared,
    },
    layers::Stack,
    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
};
use apalis_sql::{context::SqlContext, from_row::TaskRow};
use futures::{
    FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
    channel::mpsc::{self, Receiver, Sender},
    future::{BoxFuture, Shared},
    lock::Mutex,
    stream::{self, BoxStream, select},
};
use sqlx::{Sqlite, SqlitePool, pool::PoolOptions, sqlite::SqliteOperation};
use ulid::Ulid;

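/// Shared SQLite storage: one pool, one update-hook listener, and a registry
/// of per-queue channels. Backends produced via [`MakeShared`] all poll the
/// same `drive` future, which routes newly inserted jobs to whichever
/// registered queue owns them.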
#[derive(Clone, Debug)]
pub struct SharedSqliteStorage<Decode> {
    pool: SqlitePool,
    registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>>,
    drive: Shared<BoxFuture<'static, ()>>,
    _marker: PhantomData<Decode>,
}

impl<Decode> SharedSqliteStorage<Decode> {
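    /// Returns a reference to the underlying [`SqlitePool`].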
    #[must_use]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
}

impl SharedSqliteStorage<JsonCodec<CompactType>> {
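    /// Creates a shared storage from a SQLite connection URL, using the
    /// default [`JsonCodec`] to encode task arguments.
    ///
    /// A minimal sketch mirroring the test below (assumes an async context
    /// with `MakeShared` and `TaskSink` in scope):
    ///
    /// ```rust,ignore
    /// let mut shared = SharedSqliteStorage::new(":memory:");
    /// SqliteStorage::setup(shared.pool()).await?;
    /// // Each task type gets its own backend over the same pool.
    /// let mut int_store = shared.make_shared()?;
    /// int_store.push(99).await?;
    /// ```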
    #[must_use]
    pub fn new(url: &str) -> Self {
        Self::new_with_codec(url)
    }
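
    /// Creates a shared storage with a custom codec. The pool connects
    /// lazily, and every new connection installs a SQLite update hook so
    /// that job inserts wake the shared dispatcher.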
    #[must_use]
    pub fn new_with_codec<Codec>(url: &str) -> SharedSqliteStorage<Codec> {
        let (tx, rx) = mpsc::unbounded::<DbEvent>();
        let pool = PoolOptions::<Sqlite>::new()
            .after_connect(move |conn, _meta| {
                let mut tx = tx.clone();
                Box::pin(async move {
                    let mut lock_handle = conn.lock_handle().await?;
                    lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
                    Ok(())
                })
            })
            .connect_lazy(url)
            .expect("Failed to create Sqlite pool");

        let registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>> =
            Arc::new(Mutex::new(HashMap::default()));
        let p = pool.clone();
        let instances = registry.clone();
        SharedSqliteStorage {
            pool,
            drive: async move {
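                // Only inserts into the jobs table are interesting; every
                // other update-hook event is filtered out below.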
                rx.filter(|a| {
                    ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
                })
                // `ready_chunks` panics if its capacity is zero, so clamp the
                // batch size to at least one even while the registry is empty.
                .ready_chunks(max(1, instances.try_lock().map(|r| r.len()).unwrap_or(10)))
                .then(|events| {
                    let row_ids = events.iter().map(|e| e.rowid()).collect::<HashSet<i64>>();
                    let instances = instances.clone();
                    let pool = p.clone();
                    async move {
                        let instances = instances.lock().await;
                        let job_types = serde_json::to_string(
                            &instances.keys().cloned().collect::<Vec<String>>(),
                        )
                        .unwrap();
                        let row_ids = serde_json::to_string(&row_ids).unwrap();
                        let mut tx = pool.begin().await?;
                        let buffer_size = max(10, instances.len()) as i32;
                        let res: Vec<_> = sqlx::query_file_as!(
                            SqliteTaskRow,
                            "queries/backend/fetch_next_shared.sql",
                            job_types,
                            row_ids,
                            buffer_size,
                        )
                        .fetch(&mut *tx)
                        .map(|r| {
                            let row: TaskRow = r?.try_into()?;
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))
                        })
                        .try_collect()
                        .await?;
                        tx.commit().await?;
                        Ok::<_, sqlx::Error>(res)
                    }
                })
                .for_each(|r| async {
                    match r {
                        Ok(tasks) => {
                            // Route each fetched task to the channel
                            // registered for its queue name.
                            let mut instances = instances.lock().await;
                            for task in tasks {
                                if let Some(tx) = instances.get_mut(
                                    task.parts
                                        .ctx
                                        .queue()
                                        .as_ref()
                                        .expect("Namespace must be set"),
                                ) {
                                    let _ = tx.send(task).await;
                                }
                            }
                        }
                        Err(e) => {
                            log::error!("Error fetching tasks: {e:?}");
                        }
                    }
                })
                .await;
            }
            .boxed()
            .shared(),
            registry,
            _marker: PhantomData,
        }
    }
}

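/// Errors that can occur while registering a queue on a [`SharedSqliteStorage`].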
#[derive(Debug, thiserror::Error)]
pub enum SharedSqliteError {
    #[error("Namespace {0} already exists")]
    NamespaceExists(String),
    #[error("Could not acquire registry lock")]
    RegistryLocked,
}

impl<Args, Decode: Codec<Args, Compact = CompactType>> MakeShared<Args>
    for SharedSqliteStorage<Decode>
{
    type Backend = SqliteStorage<Args, Decode, SharedFetcher<CompactType>>;
    type Config = Config;
    type MakeError = SharedSqliteError;

    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
    where
        Self::Config: Default,
    {
        self.make_shared_with_config(Config::new(std::any::type_name::<Args>()))
    }
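
    /// Registers `config.queue()` in the shared registry and returns a
    /// dedicated [`SqliteStorage`] backend for it. Fails if the queue is
    /// already registered or the registry lock is contended.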
    fn make_shared_with_config(
        &mut self,
        config: Self::Config,
    ) -> Result<Self::Backend, Self::MakeError> {
        let (tx, rx) = mpsc::channel(config.buffer_size());
        let mut r = self
            .registry
            .try_lock()
            .ok_or(SharedSqliteError::RegistryLocked)?;
        // Reject duplicates up front: a blind `insert` would replace (and
        // drop) the live sender already registered for this queue.
        let queue = config.queue().to_string();
        if r.contains_key(&queue) {
            return Err(SharedSqliteError::NamespaceExists(queue));
        }
        r.insert(queue, tx);
        let sink = SqliteSink::new(&self.pool, &config);
        Ok(SqliteStorage {
            config,
            fetcher: SharedFetcher {
                poller: self.drive.clone(),
                receiver: Arc::new(std::sync::Mutex::new(rx)),
            },
            pool: self.pool.clone(),
            sink,
            job_type: PhantomData,
            codec: PhantomData,
        })
    }
}

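/// The fetching half of a shared backend: it polls the shared dispatcher
/// future and yields the tasks routed to this queue's channel.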
#[derive(Clone, Debug)]
pub struct SharedFetcher<Compact> {
    poller: Shared<BoxFuture<'static, ()>>,
    receiver: Arc<std::sync::Mutex<Receiver<SqliteTask<Compact>>>>,
}

impl<Compact> Stream for SharedFetcher<Compact> {
    type Item = SqliteTask<Compact>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
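        // Poll the shared dispatcher first; `Shared` lets every fetcher
        // drive the same underlying future.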
        let _ = this.poller.poll_unpin(cx);

        this.receiver.lock().unwrap().poll_next_unpin(cx)
    }
}

impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
where
    Args: Send + 'static + Unpin + Sync,
    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send + Sync,
    Decode::Error: std::error::Error + Send + Sync + 'static,
{
    type Args = Args;

    type IdType = Ulid;

    type Error = sqlx::Error;

    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;

    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;

    type Context = SqlContext;

    type Layer = Stack<AcknowledgeLayer<SqliteAck>, LockTaskLayer>;

    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
        let keep_alive_interval = *self.config.keep_alive();
        let pool = self.pool.clone();
        let worker = worker.clone();
        let config = self.config.clone();

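        // Tick at the configured keep-alive interval, refreshing this
        // worker's keep-alive record on every tick.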
        stream::unfold((), move |()| async move {
            apalis_core::timer::sleep(keep_alive_interval).await;
            Some(((), ()))
        })
        .then(move |_| keep_alive(pool.clone(), config.clone(), worker.clone()))
        .boxed()
    }

    fn middleware(&self) -> Self::Layer {
        let lock = LockTaskLayer::new(self.pool.clone());
        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
        Stack::new(ack, lock)
    }

    fn poll(self, worker: &WorkerContext) -> Self::Stream {
        self.poll_shared(worker)
            .map(|a| match a {
                Ok(Some(task)) => Ok(Some(
                    task.try_map(|t| Decode::decode(&t))
                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
                )),
                Ok(None) => Ok(None),
                Err(e) => Err(e),
            })
            .boxed()
    }
}

impl<Args, Decode: Send + 'static> BackendExt
    for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
where
    Self: Backend<Args = Args, IdType = Ulid, Context = SqlContext, Error = sqlx::Error>,
    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
    Decode::Error: std::error::Error + Send + Sync + 'static,
    Args: Send + 'static + Unpin,
{
    type Codec = Decode;
    type Compact = CompactType;
    type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;

    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
        self.poll_shared(worker).boxed()
    }
}

impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, SharedFetcher<CompactType>> {
    fn poll_shared(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + 'static {
        let pool = self.pool.clone();
        let worker = worker.clone();
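        // Announce the worker once, then merge the hook-driven shared
        // fetcher (lazy) with the periodic poll fetcher (eager).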
        let init = initial_heartbeat(
            pool,
            self.config.clone(),
            worker.clone(),
            "SharedSqliteStorage",
        );
        let starter = stream::once(init).map_ok(|_| None).boxed();
        let lazy_fetcher = self.fetcher.map(|s| Ok(Some(s))).boxed();

        let eager_fetcher = StreamExt::boxed(SqlitePollFetcher::<CompactType, Decode>::new(
            &self.pool,
            &self.config,
            &worker,
        ));
        starter.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::TaskSink, error::BoxDynError, task::task_id::TaskId,
        worker::builder::WorkerBuilder,
    };

    use super::*;

    #[tokio::test]
    async fn basic_worker() {
        let mut store = SharedSqliteStorage::new(":memory:");
        SqliteStorage::setup(store.pool()).await.unwrap();

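        // Backends for two different `Args` types share one pool and one
        // dispatcher; each `make_shared` call registers its own queue.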
        let mut map_store = store.make_shared().unwrap();
        let mut int_store = store.make_shared().unwrap();

        map_store
            .push(HashMap::<String, i32>::from([("value".to_string(), 42)]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}