// apalis_sqlite/shared.rs

1use std::{
2    cmp::max,
3    collections::{HashMap, HashSet},
4    ffi::c_void,
5    future::ready,
6    marker::PhantomData,
7    pin::Pin,
8    sync::Arc,
9    task::{Context, Poll},
10};
11
12use crate::{
13    CompactType, Config, INSERT_OPERATION, JOBS_TABLE, SqliteStorage, SqliteTask,
14    ack::{LockTaskLayer, SqliteAck},
15    callback::{DbEvent, update_hook_callback},
16    fetcher::SqlitePollFetcher,
17    initial_heartbeat, keep_alive,
18};
19use crate::{from_row::SqliteTaskRow, sink::SqliteSink};
20use apalis_core::{
21    backend::{
22        Backend, TaskStream,
23        codec::{Codec, json::JsonCodec},
24        shared::MakeShared,
25    },
26    layers::Stack,
27    worker::{context::WorkerContext, ext::ack::AcknowledgeLayer},
28};
29use apalis_sql::{context::SqlContext, from_row::TaskRow};
30use futures::{
31    FutureExt, SinkExt, Stream, StreamExt, TryStreamExt,
32    channel::mpsc::{self, Receiver, Sender},
33    future::{BoxFuture, Shared},
34    lock::Mutex,
35    stream::{self, BoxStream, select},
36};
37use libsqlite3_sys::{sqlite3, sqlite3_update_hook};
38use sqlx::SqlitePool;
39use ulid::Ulid;
40
/// A shared SQLite-backed storage that multiplexes several task queues over a
/// single database update-hook listener.
///
/// `Decode` is the codec used to decode task arguments from the compact
/// on-disk representation.
pub struct SharedSqliteStorage<Decode> {
    // Connection pool used both for the update-hook listener and for fetching.
    pool: SqlitePool,
    // Maps a queue (namespace) name to the sender feeding that queue's fetcher.
    registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>>,
    // The shared driver future that routes DB insert events to registered
    // queues; it is kept alive by each `SharedFetcher` polling its clone.
    drive: Shared<BoxFuture<'static, ()>>,
    _marker: PhantomData<Decode>,
}
47
48impl SharedSqliteStorage<JsonCodec<CompactType>> {
49    pub fn new(pool: SqlitePool) -> Self {
50        SharedSqliteStorage::new_with_codec(pool)
51    }
52    pub fn new_with_codec<Codec>(pool: SqlitePool) -> SharedSqliteStorage<Codec> {
53        let registry: Arc<Mutex<HashMap<String, Sender<SqliteTask<CompactType>>>>> =
54            Arc::new(Mutex::new(HashMap::default()));
55        let p = pool.clone();
56        let instances = registry.clone();
57        SharedSqliteStorage {
58            pool,
59            drive: async move {
60                let (tx, rx) = mpsc::unbounded::<DbEvent>();
61
62                let mut conn = p.acquire().await.unwrap();
63                // Get raw sqlite3* handle
64                let handle: *mut sqlite3 =
65                    conn.lock_handle().await.unwrap().as_raw_handle().as_ptr();
66
67                // Put sender in a Box so it has a stable memory address
68                let tx_box = Box::new(tx);
69                let tx_ptr = Box::into_raw(tx_box) as *mut c_void;
70
71                unsafe {
72                    sqlite3_update_hook(handle, Some(update_hook_callback), tx_ptr);
73                }
74
75                rx.filter(|a| {
76                    ready(a.operation() == INSERT_OPERATION && a.table_name() == JOBS_TABLE)
77                })
78                .ready_chunks(instances.try_lock().map(|r| r.len()).unwrap_or(10))
79                .then(|events| {
80                    let row_ids = events.iter().map(|e| e.rowid()).collect::<HashSet<i64>>();
81                    let instances = instances.clone();
82                    let pool = p.clone();
83                    async move {
84                        let instances = instances.lock().await;
85                        let job_types = serde_json::to_string(
86                            &instances.keys().cloned().collect::<Vec<String>>(),
87                        )
88                        .unwrap();
89                        let row_ids = serde_json::to_string(&row_ids).unwrap();
90                        let mut tx = pool.begin().await?;
91                        let buffer_size = max(10, instances.len()) as i32;
92                        let res: Vec<_> = sqlx::query_file_as!(
93                            SqliteTaskRow,
94                            "queries/backend/fetch_next_shared.sql",
95                            job_types,
96                            row_ids,
97                            buffer_size,
98                        )
99                        .fetch(&mut *tx)
100                        .map(|r| {
101                            let row: TaskRow = r?.try_into()?;
102                            row.try_into_task_compact()
103                                .map_err(|e| sqlx::Error::Protocol(e.to_string()))
104                        })
105                        .try_collect()
106                        .await?;
107                        tx.commit().await?;
108                        Ok::<_, sqlx::Error>(res)
109                    }
110                })
111                .for_each(|r| async {
112                    match r {
113                        Ok(tasks) => {
114                            let mut instances = instances.lock().await;
115                            for task in tasks {
116                                if let Some(tx) = instances.get_mut(
117                                    task.parts
118                                        .ctx
119                                        .queue()
120                                        .as_ref()
121                                        .expect("Namespace must be set"),
122                                ) {
123                                    let _ = tx.send(task).await;
124                                }
125                            }
126                        }
127                        Err(e) => {
128                            log::error!("Error fetching tasks: {e:?}");
129                        }
130                    }
131                })
132                .await;
133            }
134            .boxed()
135            .shared(),
136            registry,
137            _marker: PhantomData,
138        }
139    }
140}
/// Errors returned when registering a new queue on the shared storage.
///
/// NOTE(review): the name says "Postgres" but this type lives in the sqlite
/// backend — it looks copy-pasted from the postgres crate. Renaming would
/// break the public API, so it is only flagged here.
#[derive(Debug, thiserror::Error)]
pub enum SharedPostgresError {
    // A backend was already created for this queue/namespace.
    #[error("Namespace {0} already exists")]
    NamespaceExists(String),
    // The registry mutex could not be acquired without blocking.
    #[error("Could not acquire registry lock")]
    RegistryLocked,
}
148
149impl<Args, Decode: Codec<Args, Compact = CompactType>> MakeShared<Args>
150    for SharedSqliteStorage<Decode>
151{
152    type Backend = SqliteStorage<Args, Decode, SharedFetcher<CompactType>>;
153    type Config = Config;
154    type MakeError = SharedPostgresError;
155    fn make_shared(&mut self) -> Result<Self::Backend, Self::MakeError>
156    where
157        Self::Config: Default,
158    {
159        Self::make_shared_with_config(self, Config::new(std::any::type_name::<Args>()))
160    }
161    fn make_shared_with_config(
162        &mut self,
163        config: Self::Config,
164    ) -> Result<Self::Backend, Self::MakeError> {
165        let (tx, rx) = mpsc::channel(config.buffer_size());
166        let mut r = self
167            .registry
168            .try_lock()
169            .ok_or(SharedPostgresError::RegistryLocked)?;
170        if r.insert(config.queue().to_string(), tx).is_some() {
171            return Err(SharedPostgresError::NamespaceExists(
172                config.queue().to_string(),
173            ));
174        }
175        let sink = SqliteSink::new(&self.pool, &config);
176        Ok(SqliteStorage {
177            config,
178            fetcher: SharedFetcher {
179                poller: self.drive.clone(),
180                receiver: rx,
181            },
182            pool: self.pool.clone(),
183            sink,
184            job_type: PhantomData,
185            codec: PhantomData,
186        })
187    }
188}
189
/// Per-queue task stream backed by the shared driver.
///
/// Holds a clone of the shared driver future so the update-hook listener keeps
/// running while any fetcher is alive, plus the receiving end of this queue's
/// channel.
pub struct SharedFetcher<Compact> {
    // Shared driver; polled (output ignored) to keep the listener running.
    poller: Shared<BoxFuture<'static, ()>>,
    // Receives tasks routed to this queue by the driver.
    receiver: Receiver<SqliteTask<Compact>>,
}
194
impl<Compact> Stream for SharedFetcher<Compact> {
    type Item = SqliteTask<Compact>;

    /// Polls the shared driver first (to drive the update-hook listener and
    /// register this task's waker), then delegates item production to the
    /// queue's receiver. The ordering matters: the driver must make progress
    /// for the receiver to ever yield.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // Keep the poller alive by polling it, but ignoring the output
        let _ = this.poller.poll_unpin(cx);

        // Delegate actual items to receiver
        this.receiver.poll_next_unpin(cx)
    }
}
207
208impl<Args, Decode> Backend for SqliteStorage<Args, Decode, SharedFetcher<CompactType>>
209where
210    Args: Send + 'static + Unpin + Sync,
211    Decode: Codec<Args, Compact = CompactType> + 'static + Unpin + Send + Sync,
212    Decode::Error: std::error::Error + Send + Sync + 'static,
213{
214    type Args = Args;
215
216    type IdType = Ulid;
217
218    type Error = sqlx::Error;
219
220    type Stream = TaskStream<SqliteTask<Args>, sqlx::Error>;
221
222    type Beat = BoxStream<'static, Result<(), sqlx::Error>>;
223
224    type Codec = Decode;
225
226    type Compact = CompactType;
227
228    type Context = SqlContext;
229
230    type Layer = Stack<LockTaskLayer, AcknowledgeLayer<SqliteAck>>;
231
232    fn heartbeat(&self, worker: &WorkerContext) -> Self::Beat {
233        let keep_alive_interval = *self.config.keep_alive();
234        let pool = self.pool.clone();
235        let worker = worker.clone();
236        let config = self.config.clone();
237
238        stream::unfold((), move |()| async move {
239            apalis_core::timer::sleep(keep_alive_interval).await;
240            Some(((), ()))
241        })
242        .then(move |_| keep_alive(pool.clone(), config.clone(), worker.clone()))
243        .boxed()
244    }
245
246    fn middleware(&self) -> Self::Layer {
247        let lock = LockTaskLayer::new(self.pool.clone());
248        let ack = AcknowledgeLayer::new(SqliteAck::new(self.pool.clone()));
249        Stack::new(lock, ack)
250    }
251
252    fn poll(self, worker: &WorkerContext) -> Self::Stream {
253        let pool = self.pool.clone();
254        let worker = worker.clone();
255        // Initial registration heartbeat
256        // This ensures that the worker is registered before fetching any tasks
257        // This also ensures that the worker is marked as alive in case it crashes
258        // before fetching any tasks
259        // Subsequent heartbeats are handled in the heartbeat stream
260        let init = initial_heartbeat(
261            pool.clone(),
262            self.config.clone(),
263            worker.clone(),
264            "SharedSqliteStorage",
265        );
266        let starter = stream::once(init)
267            .map_ok(|_| None) // Noop after initial heartbeat
268            .boxed();
269        let lazy_fetcher = self
270            .fetcher
271            .map(|t| {
272                t.try_map(|args| Decode::decode(&args).map_err(|e| sqlx::Error::Decode(e.into())))
273            })
274            .flat_map(|vec| match vec {
275                Ok(task) => stream::iter(vec![Ok(Some(task))]).boxed(),
276                Err(e) => stream::once(ready(Err(e))).boxed(),
277            })
278            .boxed();
279
280        let eager_fetcher = StreamExt::boxed(SqlitePollFetcher::<Args, CompactType, Decode>::new(
281            &self.pool,
282            &self.config,
283            &worker,
284        ));
285        starter.chain(select(lazy_fetcher, eager_fetcher)).boxed()
286    }
287}
288
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use apalis_core::{
        backend::TaskSink, error::BoxDynError, task::task_id::TaskId,
        worker::builder::WorkerBuilder,
    };

    use super::*;

    /// Smoke test: two differently-typed queues (`HashMap<String, i32>` and
    /// `i32`) share one `SharedSqliteStorage`; one task is pushed to each and
    /// both workers must receive and process theirs before stopping.
    #[tokio::test]
    async fn basic_worker() {
        let pool = SqlitePool::connect(":memory:").await.unwrap();

        SqliteStorage::setup(&pool).await.unwrap();

        let mut store = SharedSqliteStorage::new(pool);

        // Two backends over the same shared store, one per argument type.
        let mut map_store = store.make_shared().unwrap();

        let mut int_store = store.make_shared().unwrap();

        map_store
            .push(HashMap::<String, i32>::from([("value".to_string(), 42)]))
            .await
            .unwrap();
        int_store.push(99).await.unwrap();

        // Generic handler used by both workers: waits briefly, then stops its
        // own worker so `run()` can return and `try_join!` can complete.
        async fn send_reminder<T, I>(
            _: T,
            _task_id: TaskId<I>,
            wrk: WorkerContext,
        ) -> Result<(), BoxDynError> {
            tokio::time::sleep(Duration::from_secs(2)).await;
            wrk.stop().unwrap();
            Ok(())
        }

        let int_worker = WorkerBuilder::new("rango-tango-2")
            .backend(int_store)
            .build(send_reminder);
        let map_worker = WorkerBuilder::new("rango-tango-1")
            .backend(map_store)
            .build(send_reminder);
        tokio::try_join!(int_worker.run(), map_worker.run()).unwrap();
    }
}