apalis_postgres/
fetcher.rs

1use std::{
2    collections::VecDeque,
3    marker::PhantomData,
4    pin::Pin,
5    task::{Context, Poll},
6    time::{Duration, Instant},
7};
8
9use apalis_core::{task::Task, timer::Delay, worker::context::WorkerContext};
10use apalis_sql::from_row::TaskRow;
11use futures::{Future, FutureExt, future::BoxFuture, stream::Stream};
12use pin_project::pin_project;
13
14use sqlx::{PgPool, Pool, Postgres};
15use ulid::Ulid;
16
17use crate::{CompactType, PgContext, PgTask, config::Config, from_row::PgTaskRow};
18
19async fn fetch_next(
20    pool: PgPool,
21    config: Config,
22    worker: WorkerContext,
23) -> Result<Vec<Task<CompactType, PgContext, Ulid>>, sqlx::Error> {
24    let job_type = config.queue().to_string();
25    let buffer_size = config.buffer_size() as i32;
26
27    sqlx::query_file_as!(
28        PgTaskRow,
29        "queries/task/fetch_next.sql",
30        worker.name(),
31        job_type,
32        buffer_size
33    )
34    .fetch_all(&pool)
35    .await?
36    .into_iter()
37    .map(|r| {
38        let row: TaskRow = r.try_into()?;
39        row.try_into_task_compact()
40            .map_err(|e| sqlx::Error::Protocol(e.to_string()))
41    })
42    .collect()
43}
44
/// Polling state of the `PgPollFetcher` stream.
enum StreamState<Args> {
    /// Nothing in progress; the next poll starts a new fetch.
    Ready,
    /// Waiting for a back-off timer to elapse before returning to `Ready`.
    Delay(Delay),
    /// A fetch is in flight; it resolves to the fetched tasks or a database error.
    Fetch(BoxFuture<'static, Result<Vec<PgTask<Args>>, sqlx::Error>>),
    /// Fetched tasks waiting to be yielded one at a time, front first.
    Buffered(VecDeque<PgTask<Args>>),
}
51
/// Dispatcher for fetching tasks from a PostgreSQL backend via [`PgPollFetcher`].
///
/// The struct stores no runtime data: its only field is a [`PhantomData`]
/// marker for the three type parameters. `Clone` and `Debug` are implemented
/// by hand so that they are available for every choice of `Args`, `Compact`
/// and `Decode`; deriving them would require each parameter to implement the
/// trait even though no value of those types is ever stored.
pub struct PgFetcher<Args, Compact, Decode> {
    /// Zero-sized marker tying the fetcher to its type parameters.
    pub _marker: PhantomData<(Args, Compact, Decode)>,
}

impl<Args, Compact, Decode> Clone for PgFetcher<Args, Compact, Decode> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<Args, Compact, Decode> std::fmt::Debug for PgFetcher<Args, Compact, Decode> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the derived implementation would print.
        f.debug_struct("PgFetcher")
            .field("_marker", &self._marker)
            .finish()
    }
}
57
/// Stream that repeatedly fetches batches of tasks through a Postgres pool
/// and yields them one at a time, backing off between fetches that return
/// nothing or fail.
#[pin_project]
pub struct PgPollFetcher<Compact> {
    /// Connection pool handed to every fetch.
    pool: PgPool,
    /// Configuration handed to every fetch (the queue name and buffer size are read from it).
    config: Config,
    /// Worker context handed to every fetch.
    wrk: WorkerContext,
    /// Current position in the ready / delay / fetch / buffered cycle.
    #[pin]
    state: StreamState<Compact>,
    /// Most recent back-off duration: doubled after an empty or failed fetch,
    /// reset to one second after a fetch that returned tasks.
    current_backoff: Duration,
    /// NOTE(review): in the code visible here this is only initialised to
    /// `None` and copied by `clone` — confirm whether it is still needed.
    last_fetch_time: Option<Instant>,
}
68
69impl<Compact> Clone for PgPollFetcher<Compact> {
70    fn clone(&self) -> Self {
71        Self {
72            pool: self.pool.clone(),
73            config: self.config.clone(),
74            wrk: self.wrk.clone(),
75            state: StreamState::Ready,
76            current_backoff: self.current_backoff,
77            last_fetch_time: self.last_fetch_time,
78        }
79    }
80}
81
82impl PgPollFetcher<CompactType> {
83    pub fn new(pool: &Pool<Postgres>, config: &Config, wrk: &WorkerContext) -> Self {
84        let initial_backoff = Duration::from_secs(1);
85        Self {
86            pool: pool.clone(),
87            config: config.clone(),
88            wrk: wrk.clone(),
89            state: StreamState::Ready,
90            current_backoff: initial_backoff,
91            last_fetch_time: None,
92        }
93    }
94}
95
96impl Stream for PgPollFetcher<CompactType> {
97    type Item = Result<Option<PgTask<CompactType>>, sqlx::Error>;
98
99    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100        let this = self.get_mut();
101
102        loop {
103            match this.state {
104                StreamState::Ready => {
105                    let stream =
106                        fetch_next(this.pool.clone(), this.config.clone(), this.wrk.clone());
107                    this.state = StreamState::Fetch(stream.boxed());
108                }
109                StreamState::Delay(ref mut delay) => match Pin::new(delay).poll(cx) {
110                    Poll::Pending => return Poll::Pending,
111                    Poll::Ready(_) => this.state = StreamState::Ready,
112                },
113
114                StreamState::Fetch(ref mut fut) => match fut.poll_unpin(cx) {
115                    Poll::Pending => return Poll::Pending,
116                    Poll::Ready(item) => match item {
117                        Ok(requests) => {
118                            if requests.is_empty() {
119                                let next = this.next_backoff(this.current_backoff);
120                                this.current_backoff = next;
121                                let delay = Delay::new(this.current_backoff);
122                                this.state = StreamState::Delay(delay);
123                            } else {
124                                let mut buffer = VecDeque::new();
125                                for request in requests {
126                                    buffer.push_back(request);
127                                }
128                                this.current_backoff = Duration::from_secs(1);
129                                this.state = StreamState::Buffered(buffer);
130                            }
131                        }
132                        Err(e) => {
133                            let next = this.next_backoff(this.current_backoff);
134                            this.current_backoff = next;
135                            this.state = StreamState::Delay(Delay::new(next));
136                            return Poll::Ready(Some(Err(e)));
137                        }
138                    },
139                },
140
141                StreamState::Buffered(ref mut buffer) => {
142                    if let Some(request) = buffer.pop_front() {
143                        // Yield the next buffered item
144                        if buffer.is_empty() {
145                            // Buffer is now empty, transition to ready for next fetch
146                            this.state = StreamState::Ready;
147                        }
148                        return Poll::Ready(Some(Ok(Some(request))));
149                    } else {
150                        // Buffer is empty, transition to ready
151                        this.state = StreamState::Ready;
152                    }
153                }
154            }
155        }
156    }
157}
158
159impl<Compact> PgPollFetcher<Compact> {
160    fn next_backoff(&self, current: Duration) -> Duration {
161        let doubled = current * 2;
162        std::cmp::min(doubled, Duration::from_secs(60 * 5))
163    }
164
165    #[allow(unused)]
166    pub fn take_pending(&mut self) -> VecDeque<PgTask<Compact>> {
167        match &mut self.state {
168            StreamState::Buffered(tasks) => std::mem::take(tasks),
169            _ => VecDeque::new(),
170        }
171    }
172}