// apalis_sqlite/shared.rs

1use std::{
2    cmp::max,
3    collections::{HashMap, HashSet},
4    future::ready,
5    marker::PhantomData,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll},
9};
10
11use crate::{
12    CompactType, Config, JOBS_TABLE, SqliteContext, SqliteStorage, SqliteTask,
13    ack::{LockTaskLayer, SqliteAck},
14    callback::{DbEvent, update_hook_callback},
15    fetcher::SqlitePollFetcher,
16    initial_heartbeat, keep_alive,
17};
18use crate::{from_row::SqliteTaskRow, sink::SqliteSink};
19use apalis_codec::json::JsonCodec;
20use apalis_core::{
21    backend::{Backend, BackendExt, TaskStream, codec::Codec, queue::Queue, shared::MakeShared},
22    layers::Stack,
23    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
24};
25use apalis_sql::TaskRow;
26use futures::{
27    FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
28    channel::mpsc::{self, Receiver, Sender},
29    future::{BoxFuture, Shared},
30    lock::Mutex,
31    stream::{self, BoxStream, select},
32};
33use sqlx::{Sqlite, SqlitePool, pool::PoolOptions, sqlite::SqliteOperation};
34use ulid::Ulid;
35
/// Shared Sqlite storage backend that can be used across multiple workers
#[derive(Clone, Debug)]
pub struct SharedSqliteStorage<Decode> {
    // Connection pool shared by every backend created via `make_shared`.
    pool: SqlitePool,
    // Queue name -> sender feeding that queue's fetcher with newly inserted tasks.
    registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>>,
    // Background future that listens for DB insert events and dispatches tasks.
    // Wrapped in `Shared` so every fetcher clone can poll (and thereby drive) it.
    drive: Shared<BoxFuture<'static, ()>>,
    // Zero-sized marker tying the storage to the codec used for task arguments.
    _marker: PhantomData<Decode>,
}
44
impl<Decode> SharedSqliteStorage<Decode> {
    /// Get a reference to the underlying Sqlite connection pool
    ///
    /// Useful for running setup/migrations (see `SqliteStorage::setup` usage in
    /// the tests) or issuing ad-hoc queries against the same database.
    #[must_use]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
}
52
impl SharedSqliteStorage<JsonCodec<CompactType>> {
    /// Create a new shared Sqlite storage backend with the given database URL
    #[must_use]
    pub fn new(url: &str) -> Self {
        Self::new_with_codec(url)
    }

    /// Create a new shared Sqlite storage backend with the given database URL and codec
    ///
    /// NOTE(review): connections are presumably kept alive indefinitely
    /// (`max_lifetime(None)`, `idle_timeout(None)`) so the per-connection
    /// update hook registered in `new_with_pool_options` is never dropped by
    /// pool recycling — TODO confirm.
    #[must_use]
    pub fn new_with_codec<Codec>(url: &str) -> SharedSqliteStorage<Codec> {
        Self::new_with_pool_options(
            url,
            PoolOptions::new().max_lifetime(None).idle_timeout(None),
        )
    }

    /// Create a new shared Sqlite storage backend with the given database URL and pool options
    ///
    /// Wires up two pieces:
    /// - an update hook installed on every pooled connection that forwards
    ///   row-change events into an in-process channel, and
    /// - a single `drive` future that batches those events, fetches the
    ///   corresponding tasks, and routes each task to the worker registered
    ///   for its queue.
    #[must_use]
    pub fn new_with_pool_options<Codec>(
        url: &str,
        options: PoolOptions<Sqlite>,
    ) -> SharedSqliteStorage<Codec> {
        // Channel carrying raw sqlite update-hook events from every
        // connection into the drive future below.
        let (tx, rx) = mpsc::unbounded::<DbEvent>();
        let pool = options
            .after_connect(move |conn, _meta| {
                let mut tx = tx.clone();
                Box::pin(async move {
                    // Install the update hook on the raw connection so writes
                    // performed through this connection emit `DbEvent`s.
                    let mut lock_handle = conn.lock_handle().await?;
                    lock_handle.set_update_hook(move |ev| update_hook_callback(ev, &mut tx));
                    Ok(())
                })
            })
            .connect_lazy(url)
            .expect("Failed to create Sqlite pool");

        // Maps queue name -> sender feeding that queue's `SharedFetcher`.
        let registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>> =
            Arc::new(Mutex::new(HashMap::default()));
        let p = pool.clone();
        let instances = registry.clone();
        SharedSqliteStorage {
            pool,
            // The drive future is `Shared` so every fetcher can poll it; it
            // only makes progress while at least one fetcher is being polled
            // (see `SharedFetcher::poll_next`).
            drive: async move {
                // Only inserts into the jobs table are relevant for dispatch.
                rx.filter(|a| {
                    ready(a.operation() == &SqliteOperation::Insert && a.table_name() == JOBS_TABLE)
                })
                // Batch whatever events are already ready.
                // NOTE(review): the chunk capacity is computed exactly once,
                // when this future first runs, and `ready_chunks` panics on a
                // capacity of 0 — this relies on at least one worker being
                // registered by then (or the lock being contended, falling
                // back to 10). Consider `max(1, len)` — TODO confirm.
                .ready_chunks(instances.try_lock().map(|r| r.len()).unwrap_or(10))
                .then(|events| {
                    // Deduplicate row ids within a batch.
                    let row_ids = events.iter().map(|e| e.rowid()).collect::<HashSet<i64>>();
                    let instances = instances.clone();
                    let pool = p.clone();
                    async move {
                        let instances = instances.lock().await;
                        // Serializing a Vec<String>/HashSet<i64> to JSON
                        // cannot fail, hence the unwraps.
                        let job_types = serde_json::to_string(
                            &instances.keys().cloned().collect::<Vec<String>>(),
                        )
                        .unwrap();
                        let row_ids = serde_json::to_string(&row_ids).unwrap();
                        let mut tx = pool.begin().await?;
                        let buffer_size = max(10, instances.len()) as i32;
                        // Fetch up to `buffer_size` tasks matching the newly
                        // inserted row ids and the registered queues, inside
                        // one transaction.
                        let res: Vec<_> = sqlx::query_file_as!(
                            SqliteTaskRow,
                            "queries/backend/fetch_next_shared.sql",
                            job_types,
                            row_ids,
                            buffer_size,
                        )
                        .fetch(&mut *tx)
                        .map(|r| {
                            let row: TaskRow = r?.try_into()?;
                            row.try_into_task_compact()
                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))
                        })
                        .try_collect()
                        .await?;
                        tx.commit().await?;
                        Ok::<_, sqlx::Error>(res)
                    }
                })
                .for_each(|r| async {
                    match r {
                        Ok(tasks) => {
                            // Route each fetched task to the fetcher that
                            // registered for its queue.
                            let mut instances = instances.lock().await;
                            for task in tasks {
                                if let Some(tx) = instances.get_mut(
                                    task.parts
                                        .ctx
                                        .queue()
                                        .as_ref()
                                        .expect("Namespace must be set"),
                                ) {
                                    // A send error means the receiving fetcher
                                    // was dropped; the task is skipped here —
                                    // TODO confirm the recovery path for such
                                    // tasks.
                                    let _ = tx.send(task).await;
                                }
                            }
                        }
                        Err(e) => {
                            log::error!("Error fetching tasks: {e:?}");
                        }
                    }
                })
                .await;
            }
            .boxed()
            .shared(),
            registry,
            _marker: PhantomData,
        }
    }
}
161
/// Errors that can occur when creating a shared Sqlite storage backend
#[derive(Debug, thiserror::Error)]
pub enum SharedSqliteError {
    /// Namespace already exists in the registry
    #[error("Namespace {0} already exists")]
    NamespaceExists(String),
    /// Could not acquire the registry lock (it was held elsewhere)
    #[error("Could not acquire registry lock")]
    RegistryLocked,
}
172
173impl<Args, Decode: Codec<Args, Compact = CompactType>> MakeShared<Args>
174    for SharedSqliteStorage<Decode>
175{
176    type Backend = SqliteStorage<Args, Decode, SharedFetcher<CompactType>>;
177    type Config = Config;
178    type MakeError = SharedSqliteError;
179    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
180    where
181        Self::Config: Default,
182    {
183        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
184    }
185    fn make_shared_with_config(
186        &mut self,
187        config: Self::Config,
188    ) -> Result<Self::Backend, Self::MakeError> {
189        let (tx, rx) = mpsc::channel(config.buffer_size());
190        let mut r = self
191            .registry
192            .try_lock()
193            .ok_or(SharedSqliteError::RegistryLocked)?;
194        if r.insert(config.queue().to_string(), tx).is_some() {
195            return Err(SharedSqliteError::NamespaceExists(
196                config.queue().to_string(),
197            ));
198        }
199        let sink = SqliteSink::new(&self.pool, &config);
200        Ok(SqliteStorage {
201            config,
202            fetcher: SharedFetcher {
203                poller: self.drive.clone(),
204                receiver: Arc::new(std::sync::Mutex::new(rx)),
205            },
206            pool: self.pool.clone(),
207            sink,
208            job_type: PhantomData,
209            codec: PhantomData,
210        })
211    }
212}
213
/// Per-queue fetcher that yields tasks pushed by the shared `drive` future.
#[derive(Clone, Debug)]
pub struct SharedFetcher<Compact> {
    // Shared background dispatcher; polled from `poll_next` to keep it running.
    poller: Shared<BoxFuture<'static, ()>>,
    // Receiving end of this queue's task channel (sender lives in the registry).
    receiver: Arc<std::sync::Mutex<Receiver<SqliteTask<Compact>>>>,
}
219
impl<Compact> Stream for SharedFetcher<Compact> {
    type Item = SqliteTask<Compact>;

    // Polls the shared dispatcher first (so task routing makes progress even
    // though its output is ignored), then yields whatever the channel has.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Keep the poller alive by polling it, but ignoring the output
        let _ = this.poller.poll_unpin(cx);

        // Delegate actual items to receiver
        // NOTE(review): `lock().unwrap()` panics if a previous holder of this
        // std::sync::Mutex panicked (poisoning); acceptable only if the
        // receiver is not locked elsewhere — TODO confirm.
        this.receiver.lock().unwrap().poll_next_unpin(cx)
    }
}
232
233impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
234where
235    Args: Send + 'static + Unpin + Sync,
236    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send + Sync,
237    Decode::Error: std::error::Error + Send + Sync + 'static,
238{
239    type Args = Args;
240
241    type IdType = Ulid;
242
243    type Error = sqlx::Error;
244
245    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
246
247    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
248
249    type Context = SqliteContext;
250
251    type Layer = Stack<AcknowledgeLayer<SqliteAck>, LockTaskLayer>;
252
253    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
254        let keep_alive_interval = *self.config.keep_alive();
255        let pool = self.pool.clone();
256        let worker = worker.clone();
257        let config = self.config.clone();
258
259        stream::unfold((), move |()| async move {
260            apalis_core::timer::sleep(keep_alive_interval).await;
261            Some(((), ()))
262        })
263        .then(move |_| keep_alive(pool.clone(), config.clone(), worker.clone()))
264        .boxed()
265    }
266
267    fn middleware(&self) -> Self::Layer {
268        let lock = LockTaskLayer::new(self.pool.clone());
269        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
270        Stack::new(ack, lock)
271    }
272
273    fn poll(self, worker: &WorkerContext) -> Self::Stream {
274        self.poll_shared(worker)
275            .map(|a| match a {
276                Ok(Some(task)) => Ok(Some(
277                    task.try_map(|t| Decode::decode(&t))
278                        .map_err(|e| sqlx::Error::Decode(e.into()))?,
279                )),
280                Ok(None) => Ok(None),
281                Err(e) => Err(e),
282            })
283            .boxed()
284    }
285}
286
287impl<Args, Decode: Send + 'static> BackendExt
288    for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
289where
290    Self: Backend<Args = Args, IdType = Ulid, Context = SqliteContext, Error = sqlx::Error>,
291    Decode: Codec<Args, Compact = CompactType> + Send + 'static,
292    Decode::Error: std::error::Error + Send + Sync + 'static,
293    Args: Send + 'static + Unpin,
294{
295    type Codec = Decode;
296    type Compact = CompactType;
297    type CompactStream = TaskStream<SqliteTask<Self::Compact>, sqlx::Error>;
298
299    fn get_queue(&self) -> Queue {
300        self.config.queue().clone()
301    }
302
303    fn poll_compact(self, worker: &WorkerContext) -> Self::CompactStream {
304        self.poll_shared(worker).boxed()
305    }
306}
307
impl<Args, Decode: Send + 'static> SqliteStorage<Args, Decode, SharedFetcher<CompactType>> {
    /// Build the task stream shared by `poll` and `poll_compact`: an initial
    /// registration heartbeat, followed by a merge of the push-based shared
    /// fetcher and a periodic polling fetcher.
    fn poll_shared(
        self,
        worker: &WorkerContext,
    ) -> impl Stream<Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>> + 'static {
        let pool = self.pool.clone();
        let worker = worker.clone();
        // Initial registration heartbeat
        // This ensures that the worker is registered before fetching any tasks
        // This also ensures that the worker is marked as alive in case it crashes
        // before fetching any tasks
        // Subsequent heartbeats are handled in the heartbeat stream
        let init = initial_heartbeat(
            pool,
            self.config.clone(),
            worker.clone(),
            "SharedSqliteStorage",
        );
        let starter = stream::once(init)
            .map_ok(|_| None) // Noop after initial heartbeat
            .boxed();
        // Push path: tasks routed to this queue by the shared drive future.
        let lazy_fetcher = self.fetcher.map(|s| Ok(Some(s))).boxed();

        // Pull path: periodic polling — presumably this also picks up tasks
        // inserted before this worker registered or missed by the update
        // hook — TODO confirm.
        let eager_fetcher = StreamExt::boxed(SqlitePollFetcher::<CompactType, Decode>::new(
            &self.pool,
            &self.config,
            &worker,
        ));
        starter.chain(select(lazy_fetcher, eager_fetcher)).boxed()
    }
}
339
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::TaskSink, error::BoxDynError, task::task_id::TaskId,
        worker::builder::WorkerBuilder,
    };

    use super::*;

    /// Smoke test: two backends with different argument types share one
    /// in-memory database, and each worker consumes its own queue.
    #[tokio::test]
    async fn basic_worker() {
        let mut store = SharedSqliteStorage::new(":memory:");
        SqliteStorage::setup(store.pool()).await.unwrap();

        // Queue names derive from the argument types (HashMap<..> vs i32),
        // so the two registrations do not collide.
        let mut map_store = store.make_shared().unwrap();

        let mut int_store = store.make_shared().unwrap();

        map_store
            .push(HashMap::<String, i32>::from([("value".to_string(), 42)]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        // Handler sleeps briefly, then stops its own worker so the test ends.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}