apalis_postgres/
fetcher.rs1use std::{
2 collections::VecDeque,
3 marker::PhantomData,
4 pin::Pin,
5 task::{Context, Poll},
6 time::{Duration, Instant},
7};
8
9use apalis_core::{task::Task, timer::Delay, worker::context::WorkerContext};
10use apalis_sql::from_row::TaskRow;
11use futures::{Future, FutureExt, future::BoxFuture, stream::Stream};
12use pin_project::pin_project;
13
14use sqlx::{PgPool, Pool, Postgres};
15use ulid::Ulid;
16
17use crate::{CompactType, PgContext, PgTask, config::Config, from_row::PgTaskRow};
18
19async fn fetch_next(
20 pool: PgPool,
21 config: Config,
22 worker: WorkerContext,
23) -> Result<Vec<Task<CompactType, PgContext, Ulid>>, sqlx::Error> {
24 let job_type = config.queue().to_string();
25 let buffer_size = config.buffer_size() as i32;
26
27 sqlx::query_file_as!(
28 PgTaskRow,
29 "queries/task/fetch_next.sql",
30 worker.name(),
31 job_type,
32 buffer_size
33 )
34 .fetch_all(&pool)
35 .await?
36 .into_iter()
37 .map(|r| {
38 let row: TaskRow = r.try_into()?;
39 row.try_into_task_compact()
40 .map_err(|e| sqlx::Error::Protocol(e.to_string()))
41 })
42 .collect()
43}
44
45enum StreamState<Args> {
46 Ready,
47 Delay(Delay),
48 Fetch(BoxFuture<'static, Result<Vec<PgTask<Args>>, sqlx::Error>>),
49 Buffered(VecDeque<PgTask<Args>>),
50}
51
/// Zero-sized marker type carrying the `Args`, `Compact` and `Decode` type parameters.
///
/// `Clone` and `Debug` are implemented by hand rather than derived. `#[derive]` would
/// require every type parameter to implement the trait, even though only a
/// `PhantomData` is stored. The manual impls accept any parameters and behave the same.
pub struct PgFetcher<Args, Compact, Decode> {
    pub _marker: PhantomData<(Args, Compact, Decode)>,
}

impl<Args, Compact, Decode> Clone for PgFetcher<Args, Compact, Decode> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<Args, Compact, Decode> std::fmt::Debug for PgFetcher<Args, Compact, Decode> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the previously derived output.
        f.debug_struct("PgFetcher")
            .field("_marker", &self._marker)
            .finish()
    }
}
57
58#[pin_project]
59pub struct PgPollFetcher<Compact> {
60 pool: PgPool,
61 config: Config,
62 wrk: WorkerContext,
63 #[pin]
64 state: StreamState<Compact>,
65 current_backoff: Duration,
66 last_fetch_time: Option<Instant>,
67}
68
69impl<Compact> Clone for PgPollFetcher<Compact> {
70 fn clone(&self) -> Self {
71 Self {
72 pool: self.pool.clone(),
73 config: self.config.clone(),
74 wrk: self.wrk.clone(),
75 state: StreamState::Ready,
76 current_backoff: self.current_backoff,
77 last_fetch_time: self.last_fetch_time,
78 }
79 }
80}
81
82impl PgPollFetcher<CompactType> {
83 pub fn new(pool: &Pool<Postgres>, config: &Config, wrk: &WorkerContext) -> Self {
84 let initial_backoff = Duration::from_secs(1);
85 Self {
86 pool: pool.clone(),
87 config: config.clone(),
88 wrk: wrk.clone(),
89 state: StreamState::Ready,
90 current_backoff: initial_backoff,
91 last_fetch_time: None,
92 }
93 }
94}
95
96impl Stream for PgPollFetcher<CompactType> {
97 type Item = Result<Option<PgTask<CompactType>>, sqlx::Error>;
98
99 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
100 let this = self.get_mut();
101
102 loop {
103 match this.state {
104 StreamState::Ready => {
105 let stream =
106 fetch_next(this.pool.clone(), this.config.clone(), this.wrk.clone());
107 this.state = StreamState::Fetch(stream.boxed());
108 }
109 StreamState::Delay(ref mut delay) => match Pin::new(delay).poll(cx) {
110 Poll::Pending => return Poll::Pending,
111 Poll::Ready(_) => this.state = StreamState::Ready,
112 },
113
114 StreamState::Fetch(ref mut fut) => match fut.poll_unpin(cx) {
115 Poll::Pending => return Poll::Pending,
116 Poll::Ready(item) => match item {
117 Ok(requests) => {
118 if requests.is_empty() {
119 let next = this.next_backoff(this.current_backoff);
120 this.current_backoff = next;
121 let delay = Delay::new(this.current_backoff);
122 this.state = StreamState::Delay(delay);
123 } else {
124 let mut buffer = VecDeque::new();
125 for request in requests {
126 buffer.push_back(request);
127 }
128 this.current_backoff = Duration::from_secs(1);
129 this.state = StreamState::Buffered(buffer);
130 }
131 }
132 Err(e) => {
133 let next = this.next_backoff(this.current_backoff);
134 this.current_backoff = next;
135 this.state = StreamState::Delay(Delay::new(next));
136 return Poll::Ready(Some(Err(e)));
137 }
138 },
139 },
140
141 StreamState::Buffered(ref mut buffer) => {
142 if let Some(request) = buffer.pop_front() {
143 if buffer.is_empty() {
145 this.state = StreamState::Ready;
147 }
148 return Poll::Ready(Some(Ok(Some(request))));
149 } else {
150 this.state = StreamState::Ready;
152 }
153 }
154 }
155 }
156 }
157}
158
159impl<Compact> PgPollFetcher<Compact> {
160 fn next_backoff(&self, current: Duration) -> Duration {
161 let doubled = current * 2;
162 std::cmp::min(doubled, Duration::from_secs(60 * 5))
163 }
164
165 #[allow(unused)]
166 pub fn take_pending(&mut self) -> VecDeque<PgTask<Compact>> {
167 match &mut self.state {
168 StreamState::Buffered(tasks) => std::mem::take(tasks),
169 _ => VecDeque::new(),
170 }
171 }
172}