1use std::{
2 collections::VecDeque,
3 marker::PhantomData,
4 pin::Pin,
5 sync::{
6 Arc,
7 atomic::{AtomicUsize, Ordering},
8 },
9 task::{Context, Poll},
10 time::Duration,
11};
12
13use apalis_core::{
14 backend::{
15 TaskStream,
16 codec::Codec,
17 poll_strategy::{PollContext, PollStrategyExt},
18 },
19 task::Task,
20 timer::Delay,
21 worker::context::WorkerContext,
22};
23use futures::{
24 FutureExt, Stream, StreamExt, TryFutureExt,
25 future::{BoxFuture, ready},
26 stream,
27};
28
29use crate::{CompactType, Config, Error, PgContext, PgPool, PgTask, queries};
30
/// Marker fetcher source whose compact stream merges task ids delivered by
/// `queries::notify_task_ids` with the regular `PgPollFetcher` polling
/// (see its `PgFetcherSource` implementation).
#[derive(Debug, Clone, Default)]
pub struct PgNotify;
34
35pub(crate) fn register_then_stream<S>(
46 register: impl Future<Output = Result<Option<PgTask<CompactType>>, Error>> + Send + 'static,
47 body: S,
48) -> TaskStream<PgTask<CompactType>, Error>
49where
50 S: Stream<Item = Result<Option<PgTask<CompactType>>, Error>> + Send + 'static,
51{
52 let mut body_slot = Some(body);
53 stream::once(register)
54 .flat_map(move |res| match res {
55 Ok(none) => {
56 let b = body_slot
57 .take()
58 .expect("registration flat_map invoked twice");
59 stream::once(ready(Ok(none))).chain(b).left_stream()
60 }
61 Err(e) => stream::once(ready(Err(e))).right_stream(),
62 })
63 .boxed()
64}
65
66pub(crate) fn decode_task_stream<Args, Decode>(
70 compact: TaskStream<PgTask<CompactType>, Error>,
71) -> TaskStream<PgTask<Args>, Error>
72where
73 Args: Send + 'static,
74 Decode: Codec<Args, Compact = CompactType> + 'static,
75 Decode::Error: std::error::Error + Send + Sync + 'static,
76{
77 compact
78 .map(|row| match row {
79 Ok(Some(task)) => {
80 Ok(Some(task.try_map(|t| {
81 Decode::decode(&t).map_err(|e| Error::Decode(e.into()))
82 })?))
83 }
84 Ok(None) => Ok(None),
85 Err(error) => Err(error),
86 })
87 .boxed()
88}
89
90impl PgFetcherSource for PgNotify {
91 const STORAGE_NAME: &'static str = "PostgresStorageWithNotify";
92
93 fn into_compact_stream(
94 self,
95 pool: PgPool,
96 config: Config,
97 worker: WorkerContext,
98 lease_token: Arc<str>,
99 ) -> TaskStream<PgTask<CompactType>, Error> {
100 let register_worker = queries::initial_heartbeat(
101 pool.clone(),
102 config.clone(),
103 worker.clone(),
104 Self::STORAGE_NAME,
105 lease_token,
106 )
107 .map_ok(|_| None);
108
109 let lazy_fetcher = queries::batch_ids_into_tasks(
116 pool.clone(),
117 config.queue().to_string(),
118 worker.name().to_owned(),
119 config.buffer_size().max(1),
120 queries::notify_task_ids(
121 pool.clone(),
122 config.queue().to_string(),
123 config.buffer_size().max(1),
124 ),
125 )
126 .boxed();
127
128 let eager_fetcher = PgPollFetcher::<CompactType>::new(&pool, &config, &worker);
129 let combined = futures::stream::select(lazy_fetcher, eager_fetcher);
130 register_then_stream(register_worker, combined)
131 }
132}
133
134pub(crate) trait PgFetcherSource: Sized + Send + 'static {
141 const STORAGE_NAME: &'static str;
142
143 fn into_compact_stream(
144 self,
145 pool: PgPool,
146 config: Config,
147 worker: apalis_core::worker::context::WorkerContext,
148 lease_token: Arc<str>,
149 ) -> TaskStream<PgTask<CompactType>, Error>;
150}
151
152impl<Decode> PgFetcherSource for PgFetcher<CompactType, Decode>
153where
154 Decode: Send + 'static,
155{
156 const STORAGE_NAME: &'static str = crate::STORAGE_NAME;
157
158 fn into_compact_stream(
159 self,
160 pool: PgPool,
161 config: Config,
162 worker: apalis_core::worker::context::WorkerContext,
163 lease_token: Arc<str>,
164 ) -> TaskStream<PgTask<CompactType>, Error> {
165 let register_worker = queries::initial_heartbeat(
166 pool.clone(),
167 config.clone(),
168 worker.clone(),
169 Self::STORAGE_NAME,
170 lease_token,
171 )
172 .map_ok(|_| None);
173 let fetcher = PgPollFetcher::<CompactType>::new(&pool, &config, &worker);
174 register_then_stream(register_worker, fetcher)
175 }
176}
177
/// Boxed stream built from the configured poll strategy; each `()` item is a
/// signal for `PgPollFetcher` to start a fetch.
type Poller = Pin<Box<dyn Stream<Item = ()> + Send>>;
179
/// States of the `PgPollFetcher` state machine.
enum StreamState<Args> {
    /// Waiting for the poll strategy to signal that a fetch is due.
    WaitForPoll(Poller),
    /// The poll strategy stream ended; wait for this delay, then fetch anyway.
    StrategyEnded(Delay),
    /// A fetch (built by `start_fetch` from `queries::fetch_next`) is in progress.
    Fetch(BoxFuture<'static, Result<Vec<PgTask<Args>>, Error>>),
    /// Tasks from the last fetch that have not been yielded yet.
    Buffered(VecDeque<PgTask<Args>>),
}
186
/// Type-level marker for the plain poll-driven fetcher source; its
/// `PgFetcherSource` implementation streams from `PgPollFetcher` only.
///
/// NOTE(review): `#[derive(Clone, Debug)]` requires `Compact` and `Decode` to be
/// `Clone`/`Debug` even though they are only held as `PhantomData`.
#[derive(Clone, Debug)]
pub struct PgFetcher<Compact, Decode> {
    /// Zero-sized carrier for the type parameters.
    pub _marker: PhantomData<(Compact, Decode)>,
}
192
/// Stream that runs a fetch query whenever the configured poll strategy signals,
/// buffers the fetched batch, and yields its tasks one at a time.
pub(crate) struct PgPollFetcher<Compact> {
    /// Pool handed to `queries::fetch_next`.
    pool: PgPool,
    /// Backend configuration; also supplies the poll strategy.
    config: Config,
    /// Worker context passed to `queries::fetch_next` and to the poll strategy.
    worker: WorkerContext,
    /// Current position in the fetch state machine.
    state: StreamState<Compact>,
    /// Size of the last fetched batch (reset to 0 after an empty or failed fetch),
    /// shared with the poll strategy through its `PollContext`.
    previous_task_count: Arc<AtomicUsize>,
}
201
202impl<Compact> Clone for PgPollFetcher<Compact> {
203 fn clone(&self) -> Self {
204 let previous_task_count = Arc::new(AtomicUsize::new(0));
205 Self {
206 pool: self.pool.clone(),
207 config: self.config.clone(),
208 worker: self.worker.clone(),
209 state: poll_state(&self.config, &self.worker, previous_task_count.clone()),
210 previous_task_count,
211 }
212 }
213}
214
215impl PgPollFetcher<CompactType> {
216 #[must_use]
218 pub fn new(pool: &PgPool, config: &Config, worker: &WorkerContext) -> Self {
219 let previous_task_count = Arc::new(AtomicUsize::new(0));
220 Self {
221 pool: pool.clone(),
222 config: config.clone(),
223 worker: worker.clone(),
224 state: poll_state(config, worker, previous_task_count.clone()),
225 previous_task_count,
226 }
227 }
228}
229
/// How long `PgPollFetcher` waits before fetching again once its poll strategy
/// stream has ended (returned `Poll::Ready(None)`), instead of ending the task stream.
const STRATEGY_EXHAUSTED_BACKOFF: Duration = Duration::from_millis(100);
235
236impl PgPollFetcher<CompactType> {
237 fn start_fetch(&self) -> StreamState<CompactType> {
238 StreamState::Fetch(
239 queries::fetch_next(self.pool.clone(), self.config.clone(), self.worker.clone())
240 .boxed(),
241 )
242 }
243}
244
245impl<Compact> PgPollFetcher<Compact> {
246 #[cfg(test)]
249 #[must_use]
250 pub(crate) fn take_pending(&mut self) -> VecDeque<PgTask<Compact>> {
251 match &mut self.state {
252 StreamState::Buffered(tasks) => std::mem::take(tasks),
253 _ => VecDeque::new(),
254 }
255 }
256}
257
impl Stream for PgPollFetcher<CompactType> {
    type Item = Result<Option<Task<CompactType, PgContext, ulid::Ulid>>, Error>;

    /// Advances the fetch state machine.
    ///
    /// The normal cycle is `WaitForPoll` -> `Fetch` -> `Buffered` -> `WaitForPoll`.
    /// An empty or failed fetch goes straight back to `WaitForPoll`, and an ended
    /// poll strategy detours through `StrategyEnded`.
    ///
    /// The stream never returns `Poll::Ready(None)` and never yields `Ok(None)`.
    /// It yields `Ok(Some(task))` per fetched task and `Err(error)` when a fetch fails.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();

        // Keep stepping through states until one yields an item or has to wait.
        loop {
            match &mut this.state {
                // Wait for the poll strategy to signal that a fetch is due.
                StreamState::WaitForPoll(poller) => match poller.poll_next_unpin(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(Some(())) => {
                        this.state = this.start_fetch();
                    }
                    // The strategy stream ended: back off briefly rather than
                    // terminating the task stream.
                    Poll::Ready(None) => {
                        this.state =
                            StreamState::StrategyEnded(Delay::new(STRATEGY_EXHAUSTED_BACKOFF));
                    }
                },
                // Back-off elapsed: fetch once. Every `Fetch` outcome below
                // eventually rebuilds the poll strategy via `poll_state`.
                StreamState::StrategyEnded(delay) => match Pin::new(delay).poll(cx) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(()) => {
                        this.state = this.start_fetch();
                    }
                },
                StreamState::Fetch(fetch) => match fetch.poll_unpin(cx) {
                    Poll::Pending => return Poll::Pending,
                    // Nothing fetched: record an empty batch and wait for the next signal.
                    Poll::Ready(Ok(tasks)) if tasks.is_empty() => {
                        this.previous_task_count.store(0, Ordering::Relaxed);
                        this.state = poll_state(
                            &this.config,
                            &this.worker,
                            this.previous_task_count.clone(),
                        );
                    }
                    // Remember the batch size (visible to the poll strategy through
                    // its `PollContext`), then drain the batch one task per call.
                    Poll::Ready(Ok(tasks)) => {
                        this.previous_task_count
                            .store(tasks.len(), Ordering::Relaxed);
                        this.state = StreamState::Buffered(VecDeque::from(tasks));
                    }
                    // Surface the error to the caller and restart the poll strategy.
                    Poll::Ready(Err(error)) => {
                        this.previous_task_count.store(0, Ordering::Relaxed);
                        this.state = poll_state(
                            &this.config,
                            &this.worker,
                            this.previous_task_count.clone(),
                        );
                        return Poll::Ready(Some(Err(error)));
                    }
                },
                StreamState::Buffered(buffer) => {
                    if let Some(task) = buffer.pop_front() {
                        // Go back to waiting as soon as the last task leaves the
                        // buffer, so the next call polls the strategy directly.
                        if buffer.is_empty() {
                            this.state = poll_state(
                                &this.config,
                                &this.worker,
                                this.previous_task_count.clone(),
                            );
                        }
                        return Poll::Ready(Some(Ok(Some(task))));
                    }
                    // Buffer was already empty on entry: restart polling.
                    this.state =
                        poll_state(&this.config, &this.worker, this.previous_task_count.clone());
                }
            }
        }
    }
}
325
326fn poll_state<Compact>(
327 config: &Config,
328 worker: &WorkerContext,
329 previous_task_count: Arc<AtomicUsize>,
330) -> StreamState<Compact> {
331 let context = PollContext::new(worker.clone(), previous_task_count);
332 StreamState::WaitForPoll(config.poll_strategy().clone().build_stream(&context))
333}
334
#[cfg(test)]
mod tests {
    use std::{
        collections::VecDeque,
        pin::Pin,
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        task::{Context, Poll},
        time::Duration,
    };

    use apalis_core::{task::builder::TaskBuilder, worker::context::WorkerContext};
    use diesel::{
        PgConnection,
        r2d2::{ConnectionManager, Pool},
    };
    use futures::{FutureExt, future, stream, task::noop_waker_ref};
    use lets_expect::{AssertionError, AssertionResult, *};

    use super::*;

    // Outcome of a single `poll_next` call plus the fetcher's state afterwards.
    struct PollObservation {
        poll: &'static str,
        state: &'static str,
        previous_task_count: usize,
    }

    // Pool pointing at an unused address; `build_unchecked` does not wait for a
    // connection, and none of the scenarios below issue a real query.
    fn unchecked_pool() -> PgPool {
        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
        Pool::builder()
            .max_size(1)
            .connection_timeout(Duration::from_millis(10))
            .build_unchecked(manager)
    }

    // Fetcher with an empty buffer and a non-zero counter, so tests can tell
    // whether a transition reset the counter.
    fn buffered_fetcher() -> PgPollFetcher<CompactType> {
        PgPollFetcher {
            pool: unchecked_pool(),
            config: Config::new("fetcher-test"),
            worker: WorkerContext::new::<()>("fetcher-worker"),
            state: StreamState::Buffered(VecDeque::new()),
            previous_task_count: Arc::new(AtomicUsize::new(12)),
        }
    }

    fn state_name(fetcher: &PgPollFetcher<CompactType>) -> &'static str {
        match &fetcher.state {
            StreamState::WaitForPoll(_) => "wait_for_poll",
            StreamState::StrategyEnded(_) => "strategy_ended",
            StreamState::Fetch(_) => "fetch",
            StreamState::Buffered(_) => "buffered",
        }
    }

    fn poll_observation(fetcher: &mut PgPollFetcher<CompactType>) -> PollObservation {
        let mut cx = Context::from_waker(noop_waker_ref());
        let poll = match Pin::new(&mut *fetcher).poll_next(&mut cx) {
            Poll::Ready(Some(Ok(Some(_)))) => "task",
            Poll::Ready(Some(Ok(None))) => "empty",
            Poll::Ready(Some(Err(_))) => "error",
            Poll::Ready(None) => "closed",
            Poll::Pending => "pending",
        };
        PollObservation {
            poll,
            state: state_name(fetcher),
            previous_task_count: fetcher.previous_task_count.load(Ordering::Relaxed),
        }
    }

    fn pending_poll_strategy_observation() -> PollObservation {
        let mut fetcher = buffered_fetcher();
        fetcher.state = StreamState::WaitForPoll(Box::pin(stream::pending()));
        poll_observation(&mut fetcher)
    }

    fn exhausted_poll_strategy_observation() -> PollObservation {
        let mut fetcher = buffered_fetcher();
        fetcher.state = StreamState::WaitForPoll(Box::pin(stream::empty::<()>()));
        poll_observation(&mut fetcher)
    }

    fn observed_strategy_exhaustion(result: &PollObservation) -> AssertionResult {
        match (result.poll, result.state) {
            ("pending", "strategy_ended") => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected exhausted strategy to transition into strategy_ended/pending, got {other:?}"
            )])),
        }
    }

    fn fetch_error_observation() -> PollObservation {
        let mut fetcher = buffered_fetcher();
        fetcher.state = StreamState::Fetch(future::ready(Err(Error::SinkBufferFull(1))).boxed());
        poll_observation(&mut fetcher)
    }

    fn empty_fetch_observation() -> PollObservation {
        let mut fetcher = buffered_fetcher();
        fetcher.state = StreamState::Fetch(future::ready(Ok(Vec::new())).boxed());
        poll_observation(&mut fetcher)
    }

    fn successful_fetch_observation() -> PollObservation {
        let mut fetcher = buffered_fetcher();
        let task = TaskBuilder::new(vec![1, 2, 3])
            .with_ctx(PgContext::new())
            .build();
        fetcher.state = StreamState::Fetch(future::ready(Ok(vec![task])).boxed());
        poll_observation(&mut fetcher)
    }

    // Reuses `state_name` so the two state-to-label mappings cannot drift apart.
    fn cloned_state(fetcher: &PgPollFetcher<CompactType>) -> &'static str {
        state_name(&fetcher.clone())
    }

    fn cloned_previous_task_count(fetcher: &PgPollFetcher<CompactType>) -> usize {
        fetcher.clone().previous_task_count.load(Ordering::Relaxed)
    }

    fn observed_fetch_error(result: &PollObservation) -> AssertionResult {
        match (result.poll, result.state, result.previous_task_count) {
            ("error", "wait_for_poll", 0) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected fetch error to reset the poll strategy, got {other:?}"
            )])),
        }
    }

    fn observed_empty_fetch(result: &PollObservation) -> AssertionResult {
        match (result.poll, result.state, result.previous_task_count) {
            ("pending", "wait_for_poll", 0) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected empty fetch to wait for configured polling, got {other:?}"
            )])),
        }
    }

    fn observed_successful_fetch(result: &PollObservation) -> AssertionResult {
        match (result.poll, result.state, result.previous_task_count) {
            ("task", "wait_for_poll", 1) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected successful fetch to yield one task and remember the count, got {other:?}"
            )])),
        }
    }

    fn observed_pending_strategy(result: &PollObservation) -> AssertionResult {
        match (result.poll, result.state, result.previous_task_count) {
            ("pending", "wait_for_poll", 12) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected pending strategy to prevent a database fetch, got {other:?}"
            )])),
        }
    }

    fn buffered_with(tasks: Vec<PgTask<CompactType>>) -> PgPollFetcher<CompactType> {
        let mut fetcher = buffered_fetcher();
        fetcher.state = StreamState::Buffered(VecDeque::from(tasks));
        fetcher
    }

    fn synthetic_task(payload: &[u8]) -> PgTask<CompactType> {
        TaskBuilder::new(payload.to_vec())
            .with_ctx(PgContext::new())
            .build()
    }

    fn take_pending_count(state_kind: &'static str) -> usize {
        let mut fetcher = match state_kind {
            "buffered_two" => buffered_with(vec![synthetic_task(b"one"), synthetic_task(b"two")]),
            "buffered_empty" => buffered_with(Vec::new()),
            "wait_for_poll" => {
                let mut fetcher = buffered_fetcher();
                fetcher.state = StreamState::WaitForPoll(Box::pin(stream::pending()));
                fetcher
            }
            "fetch" => {
                let mut fetcher = buffered_fetcher();
                fetcher.state = StreamState::Fetch(future::ready(Ok(Vec::new())).boxed());
                fetcher
            }
            "strategy_ended" => {
                let mut fetcher = buffered_fetcher();
                fetcher.state = StreamState::StrategyEnded(Delay::new(Duration::from_secs(60)));
                fetcher
            }
            other => panic!("unknown state kind: {other}"),
        };
        fetcher.take_pending().len()
    }

    // Returns (drained, remaining, state). Previously `remaining` was computed
    // and discarded, so "zero tasks left behind" was never actually asserted.
    fn take_pending_drains_then_reports_empty() -> (usize, usize, &'static str) {
        let mut fetcher = buffered_with(vec![synthetic_task(b"alpha"), synthetic_task(b"beta")]);
        let drained = fetcher.take_pending().len();
        let remaining = match &fetcher.state {
            StreamState::Buffered(tasks) => tasks.len(),
            _ => panic!("take_pending changed the state slot"),
        };
        (drained, remaining, state_name(&fetcher))
    }

    fn buffered_pop_front_observation() -> PollObservation {
        let mut fetcher = buffered_with(vec![synthetic_task(b"first"), synthetic_task(b"second")]);
        poll_observation(&mut fetcher)
    }

    fn observed_buffered_pop_front(result: &PollObservation) -> AssertionResult {
        match (result.poll, result.state, result.previous_task_count) {
            ("task", "buffered", 12) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected pop_front to yield a task while remaining buffered, got {other:?}"
            )])),
        }
    }

    fn buffered_drain_observation() -> &'static str {
        let mut fetcher = buffered_with(vec![synthetic_task(b"only")]);
        let mut cx = Context::from_waker(noop_waker_ref());
        let _ = Pin::new(&mut fetcher).poll_next(&mut cx);
        state_name(&fetcher)
    }

    lets_expect! {
        expect(cloned_state(&fetcher)) {
            let fetcher = buffered_fetcher();

            when original_stream_has_buffered_state {
                to resets_the_clone_to_poll_strategy { equal("wait_for_poll") }
            }
        }

        expect(cloned_previous_task_count(&fetcher)) {
            let fetcher = buffered_fetcher();

            when original_stream_remembers_a_previous_batch {
                to starts_the_clone_with_no_previous_count { equal(0) }
            }
        }

        expect(pending_poll_strategy_observation()) {
            when the_configured_poll_strategy_is_not_ready {
                to does_not_start_a_fetch { observed_pending_strategy }
            }
        }

        expect(exhausted_poll_strategy_observation()) {
            when the_configured_poll_strategy_returns_ready_none {
                to transitions_into_strategy_ended_and_waits_for_the_delay {
                    observed_strategy_exhaustion
                }
            }
        }

        expect(fetch_error_observation()) {
            when fetch_query_fails {
                to yields_the_error_and_waits_for_the_next_poll_signal { observed_fetch_error }
            }
        }

        expect(empty_fetch_observation()) {
            when fetch_returns_no_tasks {
                to waits_for_the_next_configured_poll_signal { observed_empty_fetch }
            }
        }

        expect(successful_fetch_observation()) {
            when fetch_returns_tasks {
                to yields_a_task_and_records_the_batch_size { observed_successful_fetch }
            }
        }

        expect(take_pending_count(state_kind)) {
            let state_kind = "buffered_two";

            when fetcher_is_in_buffered_state_with_two_tasks {
                to drains_every_buffered_task { equal(2) }
            }

            when fetcher_is_in_buffered_state_with_no_tasks {
                let state_kind = "buffered_empty";
                to returns_an_empty_drained_queue { equal(0) }
            }

            when fetcher_is_in_wait_for_poll_state {
                let state_kind = "wait_for_poll";
                to ignores_states_other_than_buffered { equal(0) }
            }

            when fetcher_is_in_fetch_state {
                let state_kind = "fetch";
                to ignores_states_other_than_buffered { equal(0) }
            }

            when fetcher_is_in_strategy_ended_state {
                let state_kind = "strategy_ended";
                to ignores_states_other_than_buffered { equal(0) }
            }
        }

        expect(take_pending_drains_then_reports_empty()) {
            when buffered_state_is_drained_via_take_pending {
                to leaves_the_fetcher_in_the_buffered_state_with_zero_tasks {
                    equal((2, 0, "buffered"))
                }
            }
        }

        expect(buffered_pop_front_observation()) {
            when buffer_holds_multiple_tasks {
                to pops_a_task_and_stays_in_buffered { observed_buffered_pop_front }
            }
        }

        expect(buffered_drain_observation()) {
            when buffer_holds_exactly_one_task {
                to transitions_to_wait_for_poll_after_emitting_the_task {
                    equal("wait_for_poll")
                }
            }
        }
    }
}