Skip to main content

apalis_diesel_postgres/
sink.rs

1use std::{
2    future::Future,
3    marker::PhantomData,
4    pin::Pin,
5    sync::Mutex,
6    task::{Context, Poll},
7};
8
9use apalis_codec::json::JsonCodec;
10use futures::{FutureExt, Sink};
11
12use crate::{CompactType, Config, Error, PgPool, PgTask, PostgresStorage, queries};
13
// Boxed future that performs one batch flush. The trait object is `Send` but
// not `Sync`, so `PgSink` stores it inside a `Mutex` to keep `PgSink: Sync`
// even when the inner future isn't (ntex's `BlockingResult` is `Send`-only).
// Access goes through `Mutex::get_mut`, which keeps the hot path lock-free.
type FlushFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>;
18
/// Buffered task sink used by [`PostgresStorage`].
///
/// Tasks accumulate in `buffer`; when a flush is driven the whole buffer is
/// moved into a single `queries::push_tasks` future stored in `flush_future`.
pub struct PgSink<Args, Codec = JsonCodec<CompactType>> {
    /// Connection pool, cloned into every flush future.
    pool: PgPool,
    /// Storage configuration; supplies the buffer capacity.
    config: Config,
    /// Tasks accepted since the last flush was started.
    buffer: Vec<PgTask<CompactType>>,
    /// The flush currently in flight, if any. The `Mutex` is only there to
    /// keep the struct `Sync`; it is accessed through `Mutex::get_mut`.
    flush_future: Mutex<Option<FlushFuture>>,
    /// Ties the otherwise unused `Args` / `Codec` parameters to the type.
    _marker: PhantomData<(Args, Codec)>,
}
27
28impl<Args, Codec> std::fmt::Debug for PgSink<Args, Codec> {
29    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30        f.debug_struct("PgSink")
31            .field("config", &self.config)
32            .field("buffer_len", &self.buffer.len())
33            .finish_non_exhaustive()
34    }
35}
36
37impl<Args, Codec> Clone for PgSink<Args, Codec> {
38    /// Returns a fresh sink sharing the same pool/config; the buffer and any
39    /// in-flight flush are intentionally **not** cloned. Each `PgSink` owns its
40    /// pipeline state: cloning a sink that holds buffered tasks would either
41    /// silently duplicate (double-insert) or silently drop them on flush. The
42    /// clone starts empty, so callers responsible for pending work should
43    /// flush before cloning.
44    fn clone(&self) -> Self {
45        Self {
46            pool: self.pool.clone(),
47            config: self.config.clone(),
48            buffer: Vec::new(),
49            flush_future: Mutex::new(None),
50            _marker: PhantomData,
51        }
52    }
53}
54
55impl<Args, Codec> PgSink<Args, Codec> {
56    /// Create a sink for the given pool and config.
57    #[must_use]
58    pub fn new(pool: &PgPool, config: &Config) -> Self {
59        Self {
60            pool: pool.clone(),
61            config: config.clone(),
62            buffer: Vec::new(),
63            flush_future: Mutex::new(None),
64            _marker: PhantomData,
65        }
66    }
67}
68
69impl<Args, Codec> PgSink<Args, Codec> {
70    /// Buffer capacity from the underlying config (clamped to ≥1 so a
71    /// misconfigured `buffer_size(0)` does not deadlock the sink).
72    fn capacity(&self) -> usize {
73        self.config.buffer_size().max(1)
74    }
75
76    /// Whether `poll_ready` must drive a flush before accepting more work —
77    /// either a flush is already in flight, or the buffer is at capacity.
78    fn needs_flush_before_ready(&mut self) -> bool {
79        self.flush_future
80            .get_mut()
81            .expect("flush_future mutex poisoned")
82            .is_some()
83            || self.buffer.len() >= self.capacity()
84    }
85
86    /// Try to enqueue a single task into the buffer, returning
87    /// `Error::SinkBufferFull` when capacity has been reached.
88    fn try_push(&mut self, item: PgTask<CompactType>) -> Result<(), Error> {
89        let cap = self.capacity();
90        if self.buffer.len() >= cap {
91            return Err(Error::SinkBufferFull(cap));
92        }
93        self.buffer.push(item);
94        Ok(())
95    }
96
97    /// Drive the buffered batch toward completion. Starts a new flush future
98    /// when none is in flight and the buffer is non-empty; otherwise polls the
99    /// existing future and clears it once it resolves.
100    fn poll_flush_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
101        // `&mut self` makes `Mutex::get_mut` infallible-by-borrow — no lock
102        // acquisition, just unique-borrow projection. The mutex exists purely
103        // to satisfy `PgSink: Sync` when the inner future is not `Sync` (ntex).
104        let flush_future = self
105            .flush_future
106            .get_mut()
107            .expect("flush_future mutex poisoned");
108
109        if flush_future.is_none() && self.buffer.is_empty() {
110            return Poll::Ready(Ok(()));
111        }
112
113        if flush_future.is_none() {
114            let pool = self.pool.clone();
115            let config = self.config.clone();
116            let buffer = std::mem::take(&mut self.buffer);
117            *flush_future = Some(Box::pin(queries::push_tasks(pool, config, buffer)));
118        }
119
120        let Some(future) = flush_future.as_mut() else {
121            return Poll::Ready(Ok(()));
122        };
123
124        match future.poll_unpin(cx) {
125            Poll::Pending => Poll::Pending,
126            Poll::Ready(result) => {
127                *flush_future = None;
128                Poll::Ready(result)
129            }
130        }
131    }
132}
133
134impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>> for PostgresStorage<Args, Encode, Fetcher>
135where
136    Args: Send + Sync + 'static,
137    Fetcher: Unpin,
138{
139    type Error = Error;
140
141    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142        let this = self.get_mut();
143        if this.sink.needs_flush_before_ready() {
144            this.sink.poll_flush_inner(cx)
145        } else {
146            Poll::Ready(Ok(()))
147        }
148    }
149
150    fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
151        self.get_mut().sink.try_push(item)
152    }
153
154    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
155        self.get_mut().sink.poll_flush_inner(cx)
156    }
157
158    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
159        self.poll_flush(cx)
160    }
161}
162
163#[cfg(test)]
164mod tests {
165    use std::{
166        pin::Pin,
167        task::{Context, Poll},
168    };
169
170    use diesel::{
171        PgConnection,
172        r2d2::{ConnectionManager, Pool},
173    };
174    use futures::{Sink, future, task::noop_waker_ref};
175    use lets_expect::{AssertionError, AssertionResult, *};
176
177    use super::*;
178
179    fn unchecked_pool() -> PgPool {
180        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
181        Pool::builder()
182            .max_size(1)
183            .connection_timeout(std::time::Duration::from_millis(10))
184            .build_unchecked(manager)
185    }
186
187    fn task() -> PgTask<CompactType> {
188        PgTask::new(b"payload".to_vec())
189    }
190
191    fn sink(buffer_size: usize) -> PgSink<Vec<u8>> {
192        PgSink::new(
193            &unchecked_pool(),
194            &Config::new("sink-unit").set_buffer_size(buffer_size),
195        )
196    }
197
198    fn storage(buffer_size: usize) -> PostgresStorage<Vec<u8>> {
199        let pool = unchecked_pool();
200        let config = Config::new("sink-unit").set_buffer_size(buffer_size);
201        PostgresStorage::<Vec<u8>>::new_with_config(&pool, &config)
202    }
203
204    /// `start_send_via_storage` exercises the public `Sink` impl. The returned
205    /// `len` is the buffer length after the final send (only set on success).
206    fn start_send_via_storage(buffer_size: usize, existing_items: usize) -> Result<usize, Error> {
207        let mut storage = storage(buffer_size);
208        for _ in 0..existing_items {
209            storage.sink.buffer.push(task());
210        }
211        Pin::new(&mut storage).start_send(task())?;
212        Ok(storage.sink.buffer.len())
213    }
214
215    fn poll_ready_via_storage(
216        buffer_size: usize,
217        existing_items: usize,
218    ) -> Poll<Result<(), Error>> {
219        let mut storage = storage(buffer_size);
220        for _ in 0..existing_items {
221            storage.sink.buffer.push(task());
222        }
223        let mut cx = Context::from_waker(noop_waker_ref());
224        Pin::new(&mut storage).poll_ready(&mut cx)
225    }
226
    /// Snapshot taken by `poll_ready_in_flight` after a single `poll_ready`.
    struct ReadyObservation {
        /// Value returned by `poll_ready`.
        poll: Poll<Result<(), Error>>,
        /// Number of tasks in the sink buffer after the poll.
        buffer_len: usize,
        /// Whether the sink still holds a flush future after the poll.
        has_flush_future: bool,
    }
232
233    fn poll_ready_in_flight() -> ReadyObservation {
234        let mut storage = storage(2);
235        storage.sink.flush_future =
236            Mutex::new(Some(Box::pin(future::pending::<Result<(), Error>>())));
237        let mut cx = Context::from_waker(noop_waker_ref());
238        let poll = Pin::new(&mut storage).poll_ready(&mut cx);
239        let has_flush_future = storage
240            .sink
241            .flush_future
242            .get_mut()
243            .expect("flush_future mutex poisoned")
244            .is_some();
245        ReadyObservation {
246            poll,
247            buffer_len: storage.sink.buffer.len(),
248            has_flush_future,
249        }
250    }
251
    /// `FlushObservation` captures the state of a sink after
    /// `poll_flush_sink_with_state` has polled `poll_flush_inner` once.
    struct FlushObservation {
        /// Value returned by `poll_flush_inner`.
        poll: Poll<Result<(), Error>>,
        /// Whether the sink holds no flush future after the poll.
        future_cleared: bool,
        /// Number of tasks left in the sink buffer after the poll.
        buffer_len: usize,
    }
259
260    fn poll_flush_sink_with_state(
261        buffer_size: usize,
262        buffered: usize,
263        future: Option<FlushFuture>,
264    ) -> FlushObservation {
265        let mut sink = sink(buffer_size);
266        for _ in 0..buffered {
267            sink.buffer.push(task());
268        }
269        sink.flush_future = Mutex::new(future);
270        let mut cx = Context::from_waker(noop_waker_ref());
271        let poll = sink.poll_flush_inner(&mut cx);
272        let future_cleared = sink
273            .flush_future
274            .get_mut()
275            .expect("flush_future mutex poisoned")
276            .is_none();
277        FlushObservation {
278            poll,
279            future_cleared,
280            buffer_len: sink.buffer.len(),
281        }
282    }
283
284    fn poll_flush_idle() -> FlushObservation {
285        poll_flush_sink_with_state(1, 0, None)
286    }
287
288    fn poll_flush_in_flight_ready(result: Result<(), Error>) -> FlushObservation {
289        poll_flush_sink_with_state(1, 0, Some(Box::pin(future::ready(result))))
290    }
291
292    fn poll_flush_in_flight_pending() -> FlushObservation {
293        poll_flush_sink_with_state(1, 0, Some(Box::pin(future::pending())))
294    }
295
296    /// `poll_flush_creates_future` exercises the `flush_future.is_none() &&
297    /// !buffer.is_empty()` branch: the function builds a new flush future from
298    /// the buffer and immediately polls it. Against an unreachable pool the
299    /// inner `push_tasks` future resolves to Err on first poll, so this returns
300    /// a `Ready(Err(...))` observation with the buffer drained.
301    #[cfg_attr(not(feature = "tokio"), allow(dead_code))]
302    fn poll_flush_creates_future() -> FlushObservation {
303        poll_flush_sink_with_state(2, 1, None)
304    }
305
306    fn poll_close_via_storage(buffered: usize) -> Poll<Result<(), Error>> {
307        let mut storage = storage(2);
308        for _ in 0..buffered {
309            storage.sink.buffer.push(task());
310        }
311        let mut cx = Context::from_waker(noop_waker_ref());
312        Pin::new(&mut storage).poll_close(&mut cx)
313    }
314
315    fn cloned_sink_buffer_len(buffered_items: usize) -> usize {
316        let mut sink = sink(3);
317        for _ in 0..buffered_items {
318            sink.buffer.push(task());
319        }
320        sink.clone().buffer.len()
321    }
322
323    fn cloned_sink_state_drops_flush_future() -> bool {
324        let mut sink = sink(3);
325        sink.buffer.push(task());
326        sink.flush_future = Mutex::new(Some(Box::pin(future::pending::<Result<(), Error>>())));
327        sink.clone()
328            .flush_future
329            .get_mut()
330            .expect("flush_future mutex poisoned")
331            .is_none()
332    }
333
334    fn cloned_sink_buffer_size(buffer_size: usize) -> usize {
335        sink(buffer_size).clone().config.buffer_size()
336    }
337
338    fn sink_debug(buffered_items: usize) -> String {
339        let mut sink = sink(3);
340        for _ in 0..buffered_items {
341            sink.buffer.push(task());
342        }
343        format!("{sink:?}")
344    }
345
346    fn sink_buffer_full(result: &Result<usize, Error>) -> AssertionResult {
347        match result {
348            Err(Error::SinkBufferFull(1)) => Ok(()),
349            other => Err(AssertionError::new(vec![format!(
350                "expected sink buffer full at capacity 1, got {other:?}"
351            )])),
352        }
353    }
354
355    fn poll_ready_ok(result: &Poll<Result<(), Error>>) -> AssertionResult {
356        match result {
357            Poll::Ready(Ok(())) => Ok(()),
358            other => Err(AssertionError::new(vec![format!(
359                "expected ready ok, got {other:?}"
360            )])),
361        }
362    }
363
364    #[cfg_attr(not(feature = "tokio"), allow(dead_code))]
365    fn poll_started_flush(result: &Poll<Result<(), Error>>) -> AssertionResult {
366        match result {
367            Poll::Pending | Poll::Ready(Err(_)) => Ok(()),
368            other => Err(AssertionError::new(vec![format!(
369                "expected backpressure to start flushing, got {other:?}"
370            )])),
371        }
372    }
373
374    fn observation_is_idle_ok(obs: &FlushObservation) -> AssertionResult {
375        match (&obs.poll, obs.future_cleared, obs.buffer_len) {
376            (Poll::Ready(Ok(())), true, 0) => Ok(()),
377            other => Err(AssertionError::new(vec![format!(
378                "expected idle Ready(Ok), got {other:?}"
379            )])),
380        }
381    }
382
383    fn observation_is_ready_ok_and_cleared(obs: &FlushObservation) -> AssertionResult {
384        match (&obs.poll, obs.future_cleared) {
385            (Poll::Ready(Ok(())), true) => Ok(()),
386            other => Err(AssertionError::new(vec![format!(
387                "expected Ready(Ok) with cleared future, got {other:?}"
388            )])),
389        }
390    }
391
392    fn observation_is_ready_err_and_cleared(obs: &FlushObservation) -> AssertionResult {
393        match (&obs.poll, obs.future_cleared) {
394            (Poll::Ready(Err(_)), true) => Ok(()),
395            other => Err(AssertionError::new(vec![format!(
396                "expected Ready(Err) with cleared future, got {other:?}"
397            )])),
398        }
399    }
400
401    fn observation_stays_pending(obs: &FlushObservation) -> AssertionResult {
402        match (&obs.poll, obs.future_cleared) {
403            (Poll::Pending, false) => Ok(()),
404            other => Err(AssertionError::new(vec![format!(
405                "expected Pending with future retained, got {other:?}"
406            )])),
407        }
408    }
409
410    #[cfg_attr(not(feature = "tokio"), allow(dead_code))]
411    fn observation_drained_buffer_into_future(obs: &FlushObservation) -> AssertionResult {
412        if obs.buffer_len != 0 {
413            return Err(AssertionError::new(vec![format!(
414                "expected buffer to be drained into the flush future, got {} items",
415                obs.buffer_len
416            )]));
417        }
418        // The flush future is created from the buffer. Either it is still
419        // running (Pending + future retained) or it has resolved (Ready +
420        // future cleared). Both observations confirm the drain happened; we
421        // reject the inconsistent combinations explicitly so the test cannot
422        // pass with stale state.
423        match (&obs.poll, obs.future_cleared) {
424            (Poll::Pending, false) => Ok(()),
425            (Poll::Ready(_), true) => Ok(()),
426            (Poll::Pending, true) => Err(AssertionError::new(vec![
427                "flush returned Pending but the future was cleared".to_owned(),
428            ])),
429            (Poll::Ready(_), false) => Err(AssertionError::new(vec![
430                "flush returned Ready but the future was retained".to_owned(),
431            ])),
432        }
433    }
434
435    fn keeps_in_flight_flush(observation: &ReadyObservation) -> AssertionResult {
436        match (
437            &observation.poll,
438            observation.buffer_len,
439            observation.has_flush_future,
440        ) {
441            (Poll::Pending, 0, true) => Ok(()),
442            other => Err(AssertionError::new(vec![format!(
443                "expected pending in-flight flush, got {other:?}"
444            )])),
445        }
446    }
447
448    fn debug_mentions_public_fields(result: &String) -> AssertionResult {
449        if result.contains("PgSink") && result.contains("config") && result.contains("buffer_len") {
450            Ok(())
451        } else {
452            Err(AssertionError::new(vec![format!(
453                "expected sink debug output with public fields, got {result}"
454            )]))
455        }
456    }
457
    // Runtime-free specs: every case here either never starts a real flush or
    // uses a hand-made `ready` / `pending` future, so no executor is needed.
    lets_expect! {
        // `start_send`: defaults are capacity 2 with an empty buffer; the
        // nested `when` blocks override them per case.
        expect(start_send_via_storage(buffer_size, existing_items)) {
            let buffer_size = 2;
            let existing_items = 0;

            when buffer_has_room_below_capacity {
                to buffers_the_task { be_ok_and equal(1) }
            }

            when buffer_is_at_capacity_already {
                let buffer_size = 1;
                let existing_items = 1;
                to rejects_the_send { sink_buffer_full }
            }

            // A configured size of 0 is clamped to 1 by `PgSink::capacity`,
            // so one existing item already fills the buffer.
            when configured_capacity_is_zero_and_minimum_one_is_full {
                let buffer_size = 0;
                let existing_items = 1;
                to rejects_the_send_via_the_minimum_capacity { sink_buffer_full }
            }
        }

        // `poll_ready` with room in the buffer and nothing in flight.
        expect(poll_ready_via_storage(buffer_size, existing_items)) {
            let buffer_size = 2;
            let existing_items = 0;

            when buffer_is_below_capacity_and_no_flush_is_in_flight {
                to returns_ready_without_flushing { poll_ready_ok }
            }
        }

        // `poll_ready` while a never-resolving flush is installed.
        expect(poll_ready_in_flight()) {
            when an_earlier_flush_is_still_in_flight {
                to waits_for_the_flush_to_complete { keeps_in_flight_flush }
            }
        }

        // `poll_flush_inner` with nothing to do.
        expect(poll_flush_idle()) {
            when there_is_neither_a_pending_flush_nor_buffered_work {
                to completes_immediately_without_touching_the_database {
                    observation_is_idle_ok
                }
            }
        }

        // `poll_flush_inner` with an in-flight future that resolves on the
        // first poll, once with `Ok` and once with an error.
        expect(poll_flush_in_flight_ready(result)) {
            let result = Ok(());

            when the_in_flight_flush_resolves_successfully {
                to returns_ready_ok_and_clears_the_future {
                    observation_is_ready_ok_and_cleared
                }
            }

            when the_in_flight_flush_resolves_with_an_error {
                let result = Err(Error::SinkBufferFull(1));
                to surfaces_the_error_and_clears_the_future {
                    observation_is_ready_err_and_cleared
                }
            }
        }

        // `poll_flush_inner` with an in-flight future that never resolves.
        expect(poll_flush_in_flight_pending()) {
            when the_in_flight_flush_is_still_pending {
                to stays_pending_and_keeps_the_future {
                    observation_stays_pending
                }
            }
        }

        // `poll_close` on an empty sink.
        expect(poll_close_via_storage(buffered)) {
            let buffered = 0;

            when the_sink_is_already_drained {
                to delegates_to_flush_and_completes { poll_ready_ok }
            }
        }

        // `Clone`: the buffer and the in-flight flush are not carried over,
        // the configuration is.
        expect(cloned_sink_buffer_len(buffered_items)) {
            let buffered_items = 2;

            when the_original_sink_has_buffered_tasks {
                to starts_the_clone_with_an_empty_buffer { equal(0) }
            }
        }

        expect(cloned_sink_state_drops_flush_future()) {
            when the_original_sink_has_an_in_flight_flush {
                to does_not_share_the_in_flight_flush_future { equal(true) }
            }
        }

        expect(cloned_sink_buffer_size(buffer_size)) {
            let buffer_size = 4;

            when the_original_sink_has_custom_capacity {
                to keeps_the_capacity_configuration { equal(4) }
            }
        }

        // `Debug`: the assertion only checks that the type name, `config` and
        // `buffer_len` appear; it does not check that the pool is absent.
        expect(sink_debug(buffered_items)) {
            let buffered_items = 2;

            when the_sink_has_buffered_items {
                to describes_the_sink_without_exposing_the_pool {
                    debug_mentions_public_fields
                }
            }
        }
    }
568
    // Specs that start a real `queries::push_tasks` flush against the
    // unreachable test pool. They are gated on the `tokio` feature and run
    // under `#tokio_test`.
    #[cfg(feature = "tokio")]
    mod tokio_tests {
        use super::*;

        // `poll_ready` on a full buffer with nothing in flight must start a
        // flush instead of reporting readiness.
        lets_expect! { #tokio_test
            expect(poll_ready_via_storage(buffer_size, existing_items)) {
                let buffer_size = 1;
                let existing_items = 1;

                when buffer_is_at_capacity_without_a_flush_in_flight {
                    to starts_flushing_before_accepting_more_work { poll_started_flush }
                }
            }
        }

        // NOTE(review): the case name says "resolves to an error", but
        // `poll_started_flush` also accepts `Poll::Pending`.
        lets_expect! { #tokio_test
            expect(poll_flush_sink_with_state(buffer_size, buffered, None).poll) {
                let buffer_size = 2;
                let buffered = 1;

                when poll_flush_runs_on_a_real_runtime_with_buffered_work {
                    to resolves_to_an_error_against_an_unreachable_pool { poll_started_flush }
                }
            }
        }

        // Buffered work is moved into a new flush future, both through
        // `poll_flush_inner` directly and through `Sink::poll_close`.
        lets_expect! { #tokio_test
            expect(poll_flush_creates_future()) {
                when there_is_no_in_flight_flush_but_the_buffer_has_work {
                    to drains_the_buffer_into_a_new_flush_future {
                        observation_drained_buffer_into_future
                    }
                }
            }

            expect(poll_close_via_storage(1)) {
                when there_is_buffered_work_to_flush_before_closing {
                    to starts_flushing_the_buffered_work_before_completing { poll_started_flush }
                }
            }
        }
    }
611}