1use std::{
2 cmp::max,
3 collections::{HashMap, HashSet},
4 ffi::c_void,
5 future::ready,
6 marker::PhantomData,
7 pin::Pin,
8 sync::Arc,
9 task::{Context, Poll},
10};
11
12use crate::{
13 CompactType, Config, INSERT_OPERATION, JOBS_TABLE, SqliteStorage, SqliteTask,
14 ack::{LockTaskLayer, SqliteAck},
15 callback::{DbEvent, update_hook_callback},
16 fetcher::SqlitePollFetcher,
17 initial_heartbeat, keep_alive,
18};
19use crate::{from_row::SqliteTaskRow, sink::SqliteSink};
20use apalis_core::{
21 backend::{
22 Backend, TaskStream,
23 codec::{Codec, json::JsonCodec},
24 shared::MakeShared,
25 },
26 layers::Stack,
27 worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
28};
29use apalis_sql::{context::SqlContext, from_row::TaskRow};
30use futures::{
31 FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
32 channel::mpsc::{self, Receiver, Sender},
33 future::{BoxFuture, Shared},
34 lock::Mutex,
35 stream::{self, BoxStream, select},
36};
37use libsqlite3_sys::{sqlite3, sqlite3_update_hook};
38use sqlx::SqlitePool;
39use ulid::Ulid;
40
41pub struct SharedSqliteStorage<Decode> {
42 pool: SqlitePool,
43 registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>>,
44 drive: Shared<BoxFuture<'static, ()>>,
45 _marker: PhantomData<Decode>,
46}
47
48impl SharedSqliteStorage<JsonCodec<CompactType>> {
49 pub fn new(pool: SqlitePool) -> Self {
50 SharedSqliteStorage::new_with_codec(pool)
51 }
52 pub fn new_with_codec<Codec>(pool: SqlitePool) -> SharedSqliteStorage<Codec> {
53 let registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>> =
54 Arc::new(Mutex::new(HashMap::default()));
55 let p = pool.clone();
56 let instances = registry.clone();
57 SharedSqliteStorage {
58 pool,
59 drive: async move {
60 let (tx, rx) = mpsc::unbounded::<DbEvent>();
61
62 let mut conn = p.acquire().await.unwrap();
63 let handle: *mut sqlite3 =
65 conn.lock_handle().await.unwrap().as_raw_handle().as_ptr();
66
67 let tx_box = Box::new(tx);
69 let tx_ptr = Box::into_raw(tx_box) as *mut c_void;
70
71 unsafe {
72 sqlite3_update_hook(handle, Some(update_hook_callback), tx_ptr);
73 }
74
75 rx.filter(|a| {
76 ready(a.operation() == INSERT_OPERATION && a.table_name() == JOBS_TABLE)
77 })
78 .ready_chunks(instances.try_lock().map(|r| r.len()).unwrap_or(10))
79 .then(|events| {
80 let row_ids = events.iter().map(|e| e.rowid()).collect::<HashSet<i64>>();
81 let instances = instances.clone();
82 let pool = p.clone();
83 async move {
84 let instances = instances.lock().await;
85 let job_types = serde_json::to_string(
86 &instances.keys().cloned().collect::<Vec<String>>(),
87 )
88 .unwrap();
89 let row_ids = serde_json::to_string(&row_ids).unwrap();
90 let mut tx = pool.begin().await?;
91 let buffer_size = max(10, instances.len()) as i32;
92 let res: Vec<_> = sqlx::query_file_as!(
93 SqliteTaskRow,
94 "queries/backend/fetch_next_shared.sql",
95 job_types,
96 row_ids,
97 buffer_size,
98 )
99 .fetch(&mut *tx)
100 .map(|r| {
101 let row: TaskRow = r?.try_into()?;
102 row.try_into_task_compact()
103 .map_err(|e| sqlx::Error::Protocol(e.to_string()))
104 })
105 .try_collect()
106 .await?;
107 tx.commit().await?;
108 Ok::<_, sqlx::Error>(res)
109 }
110 })
111 .for_each(|r| async {
112 match r {
113 Ok(tasks) => {
114 let mut instances = instances.lock().await;
115 for task in tasks {
116 if let Some(tx) = instances.get_mut(
117 task.parts
118 .ctx
119 .queue()
120 .as_ref()
121 .expect("Namespace must be set"),
122 ) {
123 let _ = tx.send(task).await;
124 }
125 }
126 }
127 Err(e) => {
128 log::error!("Error fetching tasks: {e:?}");
129 }
130 }
131 })
132 .await;
133 }
134 .boxed()
135 .shared(),
136 registry,
137 _marker: PhantomData,
138 }
139 }
140}
141#[derive(Debug, thiserror::Error)]
142pub enum SharedPostgresError {
143 #[error("Namespace {0} already exists")]
144 NamespaceExists(String),
145 #[error("Could not acquire registry lock")]
146 RegistryLocked,
147}
148
149impl<Args, Decode: Codec<Args, Compact = CompactType>> MakeShared<Args>
150 for SharedSqliteStorage<Decode>
151{
152 type Backend = SqliteStorage<Args, Decode, SharedFetcher<CompactType>>;
153 type Config = Config;
154 type MakeError = SharedPostgresError;
155 fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
156 where
157 Self::Config: Default,
158 {
159 Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
160 }
161 fn make_shared_with_config(
162 &mut self,
163 config: Self::Config,
164 ) -> Result<Self::Backend, Self::MakeError> {
165 let (tx, rx) = mpsc::channel(config.buffer_size());
166 let mut r = self
167 .registry
168 .try_lock()
169 .ok_or(SharedPostgresError::RegistryLocked)?;
170 if r.insert(config.queue().to_string(), tx).is_some() {
171 return Err(SharedPostgresError::NamespaceExists(
172 config.queue().to_string(),
173 ));
174 }
175 let sink = SqliteSink::new(&self.pool, &config);
176 Ok(SqliteStorage {
177 config,
178 fetcher: SharedFetcher {
179 poller: self.drive.clone(),
180 receiver: rx,
181 },
182 pool: self.pool.clone(),
183 sink,
184 job_type: PhantomData,
185 codec: PhantomData,
186 })
187 }
188}
189
190pub struct SharedFetcher<Compact> {
191 poller: Shared<BoxFuture<'static, ()>>,
192 receiver: Receiver<SqliteTask<Compact>>,
193}
194
195impl<Compact> Stream for SharedFetcher<Compact> {
196 type Item = SqliteTask<Compact>;
197
198 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
199 let this = self.get_mut();
200 let _ = this.poller.poll_unpin(cx);
202
203 this.receiver.poll_next_unpin(cx)
205 }
206}
207
208impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
209where
210 Args: Send + 'static + Unpin + Sync,
211 Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send + Sync,
212 Decode::Error: std::error::Error + Send + Sync + 'static,
213{
214 type Args = Args;
215
216 type IdType = Ulid;
217
218 type Error = sqlx::Error;
219
220 type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
221
222 type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
223
224 type Codec = Decode;
225
226 type Compact = CompactType;
227
228 type Context = SqlContext;
229
230 type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;
231
232 fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
233 let keep_alive_interval = *self.config.keep_alive();
234 let pool = self.pool.clone();
235 let worker = worker.clone();
236 let config = self.config.clone();
237
238 stream::unfold((), move |()| async move {
239 apalis_core::timer::sleep(keep_alive_interval).await;
240 Some(((), ()))
241 })
242 .then(move |_| keep_alive(pool.clone(), config.clone(), worker.clone()))
243 .boxed()
244 }
245
246 fn middleware(&self) -> Self::Layer {
247 let lock = LockTaskLayer::new(self.pool.clone());
248 let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
249 Stack::new(lock, ack)
250 }
251
252 fn poll(self, worker: &WorkerContext) -> Self::Stream {
253 let pool = self.pool.clone();
254 let worker = worker.clone();
255 let init = initial_heartbeat(
261 pool.clone(),
262 self.config.clone(),
263 worker.clone(),
264 "SharedSqliteStorage",
265 );
266 let starter = stream::once(init)
267 .map_ok(|_| None) .boxed();
269 let lazy_fetcher = self
270 .fetcher
271 .map(|t| {
272 t.try_map(|args| Decode::decode(&args).map_err(|e| sqlx::Error::Decode(e.into())))
273 })
274 .flat_map(|vec| match vec {
275 Ok(task) => stream::iter(vec![Ok(Some(task))]).boxed(),
276 Err(e) => stream::once(ready(Err(e))).boxed(),
277 })
278 .boxed();
279
280 let eager_fetcher = StreamExt::boxed(SqlitePollFetcher::<Args, CompactType, Decode>::new(
281 &self.pool,
282 &self.config,
283 &worker,
284 ));
285 starter.chain(select(lazy_fetcher, eager_fetcher)).boxed()
286 }
287}
288
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::TaskSink, error::BoxDynError, task::task_id::TaskId,
        worker::builder::WorkerBuilder,
    };

    use super::*;

    /// Two queues with different argument types are created from one shared storage.
    /// Each receives one task, and each worker stops itself after handling it.
    #[tokio::test]
    async fn basic_worker() {
        let pool = SqlitePool::connect(":memory:").await.unwrap();
        SqliteStorage::setup(&pool).await.unwrap();

        let mut shared = SharedSqliteStorage::new(pool);
        let mut map_backend = shared.make_shared().unwrap();
        let mut int_backend = shared.make_shared().unwrap();

        let map_args = HashMap::<String, i32>::from([("value".to_string(), 42)]);
        map_backend.push(map_args).await.unwrap();
        int_backend.push(99).await.unwrap();

        // Shared handler: waits a moment, then stops the worker that ran it.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            let pause = Duration::from_secs(2);
            tokio::time::sleep(pause).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_backend)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_backend)
            .build(send_reminder);

        let outcome = tokio::try_join!(int_worker.run(), map_worker.run());
        outcome.unwrap();
    }
}