apalis_sqlite/
fetcher.rs

1use std::{
2    collections::VecDeque,
3    marker::PhantomData,
4    pin::Pin,
5    sync::{Arc, atomic::AtomicUsize},
6    task::{Context, Poll},
7};
8
9use apalis_core::{
10    backend::poll_strategy::{PollContext, PollStrategyExt},
11    task::Task,
12    worker::context::WorkerContext,
13};
14use apalis_sql::{context::SqlContext, from_row::TaskRow};
15use futures::{FutureExt, future::BoxFuture, stream::Stream};
16use pin_project::pin_project;
17use sqlx::{Pool, Sqlite, SqlitePool};
18use ulid::Ulid;
19
20use crate::{CompactType, SqliteTask, config::Config, from_row::SqliteTaskRow};
21
22/// Fetch the next batch of tasks from the sqlite backend
23pub async fn fetch_next(
24    pool: SqlitePool,
25    config: Config,
26    worker: WorkerContext,
27) -> Result<Vec<Task<CompactType, SqlContext, Ulid>>, sqlx::Error>
28where
29{
30    let job_type = config.queue().to_string();
31    let buffer_size = config.buffer_size() as i32;
32    let worker = worker.name().clone();
33    sqlx::query_file_as!(
34        SqliteTaskRow,
35        "queries/backend/fetch_next.sql",
36        worker,
37        job_type,
38        buffer_size
39    )
40    .fetch_all(&pool)
41    .await?
42    .into_iter()
43    .map(|r| {
44        let row: TaskRow = r.try_into()?;
45        row.try_into_task_compact::<Ulid>()
46            .map_err(|e| sqlx::Error::Protocol(e.to_string()))
47    })
48    .collect()
49}
50
51enum StreamState {
52    Ready,
53    Delay,
54    Fetch(BoxFuture<'static, Result<Vec<SqliteTask<CompactType>>, sqlx::Error>>),
55    Buffered(VecDeque<SqliteTask<CompactType>>),
56    Empty,
57}
58
/// Dispatcher for fetching tasks from a SQLite backend via [`SqlitePollFetcher`].
///
/// This is a unit struct and carries no data.
#[derive(Debug, Clone)]
pub struct SqliteFetcher;
62/// Polling-based fetcher for retrieving tasks from a SQLite backend
63#[pin_project]
64pub struct SqlitePollFetcher<Compact, Decode> {
65    pool: SqlitePool,
66    config: Config,
67    wrk: WorkerContext,
68    _marker: PhantomData<(Compact, Decode)>,
69    #[pin]
70    state: StreamState,
71
72    #[pin]
73    delay_stream: Option<Pin<Box<dyn Stream<Item = ()> + Send>>>,
74
75    prev_count: Arc<AtomicUsize>,
76}
77
78impl<Compact, Decode> std::fmt::Debug for SqlitePollFetcher<Compact, Decode> {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        f.debug_struct("SqlitePollFetcher")
81            .field("pool", &self.pool)
82            .field("config", &self.config)
83            .field("wrk", &self.wrk)
84            .field("_marker", &self._marker)
85            .field("prev_count", &self.prev_count)
86            .finish()
87    }
88}
89
90impl<Compact, Decode> Clone for SqlitePollFetcher<Compact, Decode> {
91    fn clone(&self) -> Self {
92        Self {
93            pool: self.pool.clone(),
94            config: self.config.clone(),
95            wrk: self.wrk.clone(),
96            _marker: PhantomData,
97            state: StreamState::Ready,
98            delay_stream: None,
99            prev_count: Arc::new(AtomicUsize::new(0)),
100        }
101    }
102}
103
104impl<Decode> SqlitePollFetcher<CompactType, Decode> {
105    /// Create a new SqlitePollFetcher
106    #[must_use]
107    pub fn new(pool: &Pool<Sqlite>, config: &Config, wrk: &WorkerContext) -> Self {
108        Self {
109            pool: pool.clone(),
110            config: config.clone(),
111            wrk: wrk.clone(),
112            _marker: PhantomData,
113            state: StreamState::Ready,
114            delay_stream: None,
115            prev_count: Arc::new(AtomicUsize::new(0)),
116        }
117    }
118}
119
120impl<Decode> Stream for SqlitePollFetcher<CompactType, Decode> {
121    type Item = Result<Option<SqliteTask<CompactType>>, sqlx::Error>;
122
123    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
124        let this = self.get_mut();
125        if this.delay_stream.is_none() {
126            let strategy = this
127                .config
128                .poll_strategy()
129                .clone()
130                .build_stream(&PollContext::new(this.wrk.clone(), this.prev_count.clone()));
131            this.delay_stream = Some(Box::pin(strategy));
132        }
133
134        loop {
135            match this.state {
136                StreamState::Ready => {
137                    let stream =
138                        fetch_next(this.pool.clone(), this.config.clone(), this.wrk.clone());
139                    this.state = StreamState::Fetch(stream.boxed());
140                }
141                StreamState::Delay => {
142                    if let Some(delay_stream) = this.delay_stream.as_mut() {
143                        match delay_stream.as_mut().poll_next(cx) {
144                            Poll::Pending => return Poll::Pending,
145                            Poll::Ready(Some(_)) => {
146                                this.state = StreamState::Ready;
147                            }
148                            Poll::Ready(None) => {
149                                this.state = StreamState::Empty;
150                                return Poll::Ready(None);
151                            }
152                        }
153                    } else {
154                        this.state = StreamState::Empty;
155                        return Poll::Ready(None);
156                    }
157                }
158
159                StreamState::Fetch(ref mut fut) => match fut.poll_unpin(cx) {
160                    Poll::Pending => return Poll::Pending,
161                    Poll::Ready(item) => match item {
162                        Ok(requests) => {
163                            if requests.is_empty() {
164                                this.state = StreamState::Delay;
165                            } else {
166                                let mut buffer = VecDeque::new();
167                                for request in requests {
168                                    buffer.push_back(request);
169                                }
170
171                                this.state = StreamState::Buffered(buffer);
172                            }
173                        }
174                        Err(e) => {
175                            this.state = StreamState::Empty;
176                            return Poll::Ready(Some(Err(e)));
177                        }
178                    },
179                },
180
181                StreamState::Buffered(ref mut buffer) => {
182                    if let Some(request) = buffer.pop_front() {
183                        // Yield the next buffered item
184                        if buffer.is_empty() {
185                            // Buffer is now empty, transition to ready for next fetch
186                            this.state = StreamState::Ready;
187                        }
188                        return Poll::Ready(Some(Ok(Some(request))));
189                    } else {
190                        // Buffer is empty, transition to ready
191                        this.state = StreamState::Ready;
192                    }
193                }
194
195                StreamState::Empty => return Poll::Ready(None),
196            }
197        }
198    }
199}
200
201impl<Compact, Decode> SqlitePollFetcher<Compact, Decode> {
202    /// Take pending tasks from the fetcher
203    pub fn take_pending(&mut self) -> VecDeque<SqliteTask<Vec<u8>>> {
204        match &mut self.state {
205            StreamState::Buffered(tasks) => std::mem::take(tasks),
206            _ => VecDeque::new(),
207        }
208    }
209}