apalis_sqlite/
shared.rs

1use std::{
2    cmp::max,
3    collections::{HashMap, HashSet},
4    future::ready,
5    marker::PhantomData,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10
11use crate::{
12    CompactType, Config, JOBS_TABLE, SqliteStorage, SqliteTask,
13    ack::{LockTaskLayer, SqliteAck},
14    callback::{DbEvent, update_hook_callback},
15    fetcher::SqlitePollFetcher,
16    initial_heartbeat, keep_alive,
17};
18use crate::{from_row::SqliteTaskRow, sink::SqliteSink};
19use apalis_core::{
20    backend::{
21        Backend, BackendExt, TaskStream,
22        codec::{Codec, json::JsonCodec},
23        shared::MakeShared,
24    },
25    layers::Stack,
26    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
27};
28use apalis_sql::{context::SqlContext, from_row::TaskRow};
29use futures::{
30    FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
31    channel::mpsc::{self, Receiver, Sender},
32    future::{BoxFuture, Shared},
33    lock::Mutex,
34    stream::{self, BoxStream, select},
35};
36use sqlx::{Sqlite, SqlitePool, pool::PoolOptions, sqlite::SqliteOperation};
37use ulid::Ulid;
38
39/// Shared Sqlite storage backend that can be used across multiple workers
40#[derive(Clone, Debug)]
41pub struct SharedSqliteStorage<Decode> {
42    pool: SqlitePool,
43    registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>>,
44    drive: Shared<BoxFuture<'static, ()>>,
45    _marker: PhantomData<Decode>,
46}
47
48impl<Decode> SharedSqliteStorage<Decode> {
49    /// Get a reference to the underlying Sqlite connection pool
50    #[must_use]
51    pub fn pool(&self) -> &SqlitePool {
52        &self.pool
53    }
54}
55
56impl SharedSqliteStorage<JsonCodec<CompactType>> {
57    /// Create a new shared Sqlite storage backend with the given database URL
58    #[must_use]
59    pub fn new(url: &str) -> Self {
60        Self::new_with_codec(url)
61    }
62    /// Create a new shared Sqlite storage backend with the given database URL and codec
63    #[must_use]
64    pub fn new_with_codec<Codec>(url: &str) -> SharedSqliteStorage<Codec> {
65        let (tx, rx) = mpsc::unbounded::<DbEvent>();
66        let pool = PoolOptions::<Sqlite>::new()
67            .after_connect(move |conn, _meta| {
68                let mut tx = tx.clone();
69                Box::pin(async move {
70                    let mut lock_handle = conn.lock_handle().await?;
71                    lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
72                    Ok(())
73                })
74            })
75            .connect_lazy(url)
76            .expect("Failed to create Sqlite pool");
77
78        let registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>> =
79            Arc::new(Mutex::new(HashMap::default()));
80        let p = pool.clone();
81        let instances = registry.clone();
82        SharedSqliteStorage {
83            pool,
84            drive: async move {
85                rx.filter(|a| {
86                    ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
87                })
88                .ready_chunks(instances.try_lock().map(|r| r.len()).unwrap_or(10))
89                .then(|events| {
90                    let row_ids = events.iter().map(|e| e.rowid()).collect::<HashSet<i64>>();
91                    let instances = instances.clone();
92                    let pool = p.clone();
93                    async move {
94                        let instances = instances.lock().await;
95                        let job_types = serde_json::to_string(
96                            &instances.keys().cloned().collect::<Vec<String>>(),
97                        )
98                        .unwrap();
99                        let row_ids = serde_json::to_string(&row_ids).unwrap();
100                        let mut tx = pool.begin().await?;
101                        let buffer_size = max(10, instances.len()) as i32;
102                        let res: Vec<_> = sqlx::query_file_as!(
103                            SqliteTaskRow,
104                            "queries/backend/fetch_next_shared.sql",
105                            job_types,
106                            row_ids,
107                            buffer_size,
108                        )
109                        .fetch(&mut *tx)
110                        .map(|r| {
111                            let row: TaskRow = r?.try_into()?;
112                            row.try_into_task_compact()
113                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))
114                        })
115                        .try_collect()
116                        .await?;
117                        tx.commit().await?;
118                        Ok::<_, sqlx::Error>(res)
119                    }
120                })
121                .for_each(|r| async {
122                    match r {
123                        Ok(tasks) => {
124                            let mut instances = instances.lock().await;
125                            for task in tasks {
126                                if let Some(tx) = instances.get_mut(
127                                    task.parts
128                                        .ctx
129                                        .queue()
130                                        .as_ref()
131                                        .expect("Namespace must be set"),
132                                ) {
133                                    let _ = tx.send(task).await;
134                                }
135                            }
136                        }
137                        Err(e) => {
138                            log::error!("Error fetching tasks: {e:?}");
139                        }
140                    }
141                })
142                .await;
143            }
144            .boxed()
145            .shared(),
146            registry,
147            _marker: PhantomData,
148        }
149    }
150}
151
152/// Errors that can occur when creating a shared Sqlite storage backend
153#[derive(Debug, thiserror::Error)]
154pub enum SharedSqliteError {
155    /// Namespace already exists in the registry
156    #[error("Namespace {0} already exists")]
157    NamespaceExists(String),
158    /// Could not acquire registry loc
159    #[error("Could not acquire registry lock")]
160    RegistryLocked,
161}
162
163impl<Args, Decode: Codec<Args, Compact = CompactType>> MakeShared<Args>
164    for SharedSqliteStorage<Decode>
165{
166    type Backend = SqliteStorage<Args, Decode, SharedFetcher<CompactType>>;
167    type Config = Config;
168    type MakeError = SharedSqliteError;
169    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
170    where
171        Self::Config: Default,
172    {
173        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
174    }
175    fn make_shared_with_config(
176        &mut self,
177        config: Self::Config,
178    ) -> Result<Self::Backend, Self::MakeError> {
179        let (tx, rx) = mpsc::channel(config.buffer_size());
180        let mut r = self
181            .registry
182            .try_lock()
183            .ok_or(SharedSqliteError::RegistryLocked)?;
184        if r.insert(config.queue().to_string(), tx).is_some() {
185            return Err(SharedSqliteError::NamespaceExists(
186                config.queue().to_string(),
187            ));
188        }
189        let sink = SqliteSink::new(&self.pool, &config);
190        Ok(SqliteStorage {
191            config,
192            fetcher: SharedFetcher {
193                poller: self.drive.clone(),
194                receiver: Arc::new(std::sync::Mutex::new(rx)),
195            },
196            pool: self.pool.clone(),
197            sink,
198            job_type: PhantomData,
199            codec: PhantomData,
200        })
201    }
202}
203
204#[derive(Clone, Debug)]
205pub struct SharedFetcher<Compact> {
206    poller: Shared<BoxFuture<'static, ()>>,
207    receiver: Arc<std::sync::Mutex<Receiver<SqliteTask<Compact>>>>,
208}
209
210impl<Compact> Stream for SharedFetcher<Compact> {
211    type Item = SqliteTask<Compact>;
212
213    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
214        let this = self.get_mut();
215        // Keep the poller alive by polling it, but ignoring the output
216        let _ = this.poller.poll_unpin(cx);
217
218        // Delegate actual items to receiver
219        this.receiver.lock().unwrap().poll_next_unpin(cx)
220    }
221}
222
223impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
224where
225    Args: Send + 'static + Unpin + Sync,
226    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send + Sync,
227    Decode::Error: std::error::Error + Send + Sync + 'static,
228{
229    type Args = Args;
230
231    type IdType = Ulid;
232
233    type Error = sqlx::Error;
234
235    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
236
237    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
238
239    type Context = SqlContext;
240
241    type Layer = Stack<AcknowledgeLayer<SqliteAck>, LockTaskLayer>;
242
243    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
244        let keep_alive_interval = *self.config.keep_alive();
245        let pool = self.pool.clone();
246        let worker = worker.clone();
247        let config = self.config.clone();
248
249        stream::unfold((), move |()| async move {
250            apalis_core::timer::sleep(keep_alive_interval).await;
251            Some(((), ()))
252        })
253        .then(move |_| keep_alive(pool.clone(), config.clone(), worker.clone()))
254        .boxed()
255    }
256
257    fn middleware(&self) -> Self::Layer {
258        let lock = LockTaskLayer::new(self.pool.clone());
259        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
260        Stack::new(ack, lock)
261    }
262
263    fn poll(self, worker: &WorkerContext) -> Self::Stream {
264        self.poll_shared(worker)
265            .map(|a| match a {
266                Ok(Some(task)) => Ok(Some(
267                    task.try_map(|t| Decode::decode(&t))
268                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
269                )),
270                Ok(None) => Ok(None),
271                Err(e) => Err(e),
272            })
273            .boxed()
274    }
275}
276
277impl<Args, Decode: Send + 'static> BackendExt
278    for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
279where
280    Self: Backend<Args = Args, IdType = Ulid, Context = SqlContext, Error = sqlx::Error>,
281    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
282    Decode::Error: std::error::Error + Send + Sync + 'static,
283    Args: Send + 'static + Unpin,
284{
285    type Codec = Decode;
286    type Compact = CompactType;
287    type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
288
289    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
290        self.poll_shared(worker).boxed()
291    }
292}
293
294impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, SharedFetcher<CompactType>> {
295    fn poll_shared(
296        self,
297        worker: &WorkerContext,
298    ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + 'static {
299        let pool = self.pool.clone();
300        let worker = worker.clone();
301        // Initial registration heartbeat
302        // This ensures that the worker is registered before fetching any tasks
303        // This also ensures that the worker is marked as alive in case it crashes
304        // before fetching any tasks
305        // Subsequent heartbeats are handled in the heartbeat stream
306        let init = initial_heartbeat(
307            pool,
308            self.config.clone(),
309            worker.clone(),
310            "SharedSqliteStorage",
311        );
312        let starter = stream::once(init)
313            .map_ok(|_| None) // Noop after initial heartbeat
314            .boxed();
315        let lazy_fetcher = self.fetcher.map(|s| Ok(Some(s))).boxed();
316
317        let eager_fetcher = StreamExt::boxed(SqlitePollFetcher::<CompactType, Decode>::new(
318            &self.pool,
319            &self.config,
320            &worker,
321        ));
322        starter.chain(select(lazy_fetcher, eager_fetcher)).boxed()
323    }
324}
325
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::TaskSink, error::BoxDynError, task::task_id::TaskId,
        worker::builder::WorkerBuilder,
    };

    use super::*;

    /// Handler used by both workers: waits two seconds, then stops its own worker.
    async fn send_reminder<T, I>(
        _: T,
        _task_id: TaskId<I>,
        wrk: WorkerContext,
    ) -> Result<(), BoxDynError> {
        tokio::time::sleep(Duration::from_secs(2)).await;
        wrk.stop().unwrap();
        Ok(())
    }

    /// Two workers with different argument types share one storage and run until
    /// their handlers stop them.
    #[tokio::test]
    async fn basic_worker() {
        let mut shared = SharedSqliteStorage::new(":memory:");
        SqliteStorage::setup(shared.pool()).await.unwrap();

        let mut map_store = shared.make_shared().unwrap();
        let mut int_store = shared.make_shared().unwrap();

        let map_args = HashMap::<String, i32>::from([("value".to_string(), 42)]);
        map_store.push(map_args).await.unwrap();
        int_store.push(99).await.unwrap();

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);

        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}
370}