apalis_sqlite/
fetcher.rs

1use std::{
2    collections::VecDeque,
3    marker::PhantomData,
4    pin::Pin,
5    sync::{Arc, atomic::AtomicUsize},
6    task::{Context, Poll},
7};
8
9use apalis_core::{
10    backend::{
11        codec::Codec,
12        poll_strategy::{PollContext, PollStrategyExt},
13    },
14    task::Task,
15    worker::context::WorkerContext,
16};
17use apalis_sql::{context::SqlContext, from_row::TaskRow};
18use futures::{FutureExt, future::BoxFuture, stream::Stream};
19use pin_project::pin_project;
20use sqlx::{Pool, Sqlite, SqlitePool};
21use ulid::Ulid;
22
23use crate::{CompactType, SqliteTask, config::Config, from_row::SqliteTaskRow};
24
25pub async fn fetch_next<Args, D: Codec<Args, Compact = CompactType>>(
26    pool: SqlitePool,
27    config: Config,
28    worker: WorkerContext,
29) -> Result<Vec<Task<Args, SqlContext, Ulid>>, sqlx::Error>
30where
31    D::Error: std::error::Error + Send + Sync + 'static,
32    Args: 'static,
33{
34    let job_type = config.queue().to_string();
35    let buffer_size = config.buffer_size() as i32;
36    let worker = worker.name().to_string();
37    sqlx::query_file_as!(
38        SqliteTaskRow,
39        "queries/backend/fetch_next.sql",
40        worker,
41        job_type,
42        buffer_size
43    )
44    .fetch_all(&pool)
45    .await?
46    .into_iter()
47    .map(|r| {
48        let row: TaskRow = r.try_into()?;
49        row.try_into_task::<D, Args, Ulid>()
50            .map_err(|e| sqlx::Error::Protocol(e.to_string()))
51    })
52    .collect()
53}
54
55enum StreamState<Args> {
56    Ready,
57    Delay,
58    Fetch(BoxFuture<'static, Result<Vec<SqliteTask<Args>>, sqlx::Error>>),
59    Buffered(VecDeque<SqliteTask<Args>>),
60    Empty,
61}
62
/// Dispatcher for fetching tasks from a SQLite backend via [`SqlitePollFetcher`].
///
/// This is a zero-sized marker type: it only records the argument type, the compact
/// (serialized) type and the codec type at the type level.
pub struct SqliteFetcher<Args, Compact, Decode> {
    pub _marker: PhantomData<(Args, Compact, Decode)>,
}

// `Clone` and `Debug` are implemented by hand rather than derived. `#[derive]` would add
// `Args: Clone, Compact: Clone, Decode: Clone` (resp. `Debug`) bounds even though only a
// `PhantomData` is stored, making the marker unusable with non-`Clone` task types.
impl<Args, Compact, Decode> Clone for SqliteFetcher<Args, Compact, Decode> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<Args, Compact, Decode> std::fmt::Debug for SqliteFetcher<Args, Compact, Decode> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output shape as the derived impl: `SqliteFetcher { _marker: PhantomData<..> }`.
        f.debug_struct("SqliteFetcher")
            .field("_marker", &self._marker)
            .finish()
    }
}

/// Polling-based fetcher for retrieving tasks from a SQLite backend
///
/// For `Compact = CompactType` it implements [`Stream`]. It repeatedly runs the fetch
/// query, yields the fetched tasks one at a time, and waits on the configured poll
/// strategy whenever a fetch comes back empty.
#[pin_project]
pub struct SqlitePollFetcher<Args, Compact, Decode> {
    /// Connection pool the fetch queries run against.
    pool: SqlitePool,
    /// Backend configuration (queue name, buffer size and poll strategy are read from it).
    config: Config,
    /// Context of the worker this fetcher serves; its name is bound into the fetch query.
    wrk: WorkerContext,
    // Only ties the `Compact` / `Decode` type parameters to the struct.
    _marker: PhantomData<(Compact, Decode)>,
    /// Current state of the fetch / buffer / delay state machine.
    #[pin]
    state: StreamState<Args>,

    /// Delay stream built from the configured poll strategy; created lazily on first poll.
    #[pin]
    delay_stream: Option<Pin<Box<dyn Stream<Item = ()> + Send>>>,

    /// Counter handed to the poll strategy through its `PollContext`.
    // NOTE(review): nothing in this file ever writes to this counter, so the strategy
    // always observes 0 — confirm whether fetched batch sizes should be stored here.
    prev_count: Arc<AtomicUsize>,
}

85impl<Args, Compact, Decode> Clone for SqlitePollFetcher<Args, Compact, Decode> {
86    fn clone(&self) -> Self {
87        Self {
88            pool: self.pool.clone(),
89            config: self.config.clone(),
90            wrk: self.wrk.clone(),
91            _marker: PhantomData,
92            state: StreamState::Ready,
93            delay_stream: None,
94            prev_count: Arc::new(AtomicUsize::new(0)),
95        }
96    }
97}
98
99impl<Args: 'static, Decode> SqlitePollFetcher<Args, CompactType, Decode> {
100    pub fn new(pool: &Pool<Sqlite>, config: &Config, wrk: &WorkerContext) -> Self
101    where
102        Decode: Codec<Args, Compact = CompactType> + 'static,
103        Decode::Error: std::error::Error + Send + Sync + 'static,
104    {
105        Self {
106            pool: pool.clone(),
107            config: config.clone(),
108            wrk: wrk.clone(),
109            _marker: PhantomData,
110            state: StreamState::Ready,
111            delay_stream: None,
112            prev_count: Arc::new(AtomicUsize::new(0)),
113        }
114    }
115}
116
117impl<Args, Decode> Stream for SqlitePollFetcher<Args, CompactType, Decode>
118where
119    Decode::Error: std::error::Error + Send + Sync + 'static,
120    Args: Send + 'static + Unpin,
121    Decode: Codec<Args, Compact = CompactType> + 'static,
122{
123    type Item = Result<Option<SqliteTask<Args>>, sqlx::Error>;
124
125    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
126        let this = self.get_mut();
127        if this.delay_stream.is_none() {
128            let strategy = this
129                .config
130                .poll_strategy()
131                .clone()
132                .build_stream(&PollContext::new(this.wrk.clone(), this.prev_count.clone()));
133            this.delay_stream = Some(Box::pin(strategy));
134        }
135
136        loop {
137            match this.state {
138                StreamState::Ready => {
139                    let stream = fetch_next::<Args, Decode>(
140                        this.pool.clone(),
141                        this.config.clone(),
142                        this.wrk.clone(),
143                    );
144                    this.state = StreamState::Fetch(stream.boxed());
145                }
146                StreamState::Delay => {
147                    if let Some(delay_stream) = this.delay_stream.as_mut() {
148                        match delay_stream.as_mut().poll_next(cx) {
149                            Poll::Pending => return Poll::Pending,
150                            Poll::Ready(Some(_)) => {
151                                this.state = StreamState::Ready;
152                            }
153                            Poll::Ready(None) => {
154                                this.state = StreamState::Empty;
155                                return Poll::Ready(None);
156                            }
157                        }
158                    } else {
159                        this.state = StreamState::Empty;
160                        return Poll::Ready(None);
161                    }
162                }
163
164                StreamState::Fetch(ref mut fut) => match fut.poll_unpin(cx) {
165                    Poll::Pending => return Poll::Pending,
166                    Poll::Ready(item) => match item {
167                        Ok(requests) => {
168                            if requests.is_empty() {
169                                this.state = StreamState::Delay;
170                            } else {
171                                let mut buffer = VecDeque::new();
172                                for request in requests {
173                                    buffer.push_back(request);
174                                }
175
176                                this.state = StreamState::Buffered(buffer);
177                            }
178                        }
179                        Err(e) => {
180                            this.state = StreamState::Empty;
181                            return Poll::Ready(Some(Err(e)));
182                        }
183                    },
184                },
185
186                StreamState::Buffered(ref mut buffer) => {
187                    if let Some(request) = buffer.pop_front() {
188                        // Yield the next buffered item
189                        if buffer.is_empty() {
190                            // Buffer is now empty, transition to ready for next fetch
191                            this.state = StreamState::Ready;
192                        }
193                        return Poll::Ready(Some(Ok(Some(request))));
194                    } else {
195                        // Buffer is empty, transition to ready
196                        this.state = StreamState::Ready;
197                    }
198                }
199
200                StreamState::Empty => return Poll::Ready(None),
201            }
202        }
203    }
204}
205
206impl<Args, Compact, Decode> SqlitePollFetcher<Args, Compact, Decode> {
207    pub fn take_pending(&mut self) -> VecDeque<SqliteTask<Args>> {
208        match &mut self.state {
209            StreamState::Buffered(tasks) => std::mem::take(tasks),
210            _ => VecDeque::new(),
211        }
212    }
213}