1use std::{
2 future::Future,
3 marker::PhantomData,
4 pin::Pin,
5 sync::Mutex,
6 task::{Context, Poll},
7};
8
9use apalis_codec::json::JsonCodec;
10use futures::{FutureExt, Sink};
11
12use crate::{CompactType, Config, Error, PgPool, PgTask, PostgresStorage, queries};
13
14type FlushFuture = Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'static>>;
18
/// Buffering sink state behind [`PostgresStorage`]'s `Sink` implementation.
///
/// Tasks accepted through `start_send` are appended to an in-memory buffer.
/// A flush moves the whole buffer into a single `queries::push_tasks` call.
/// `Args` and `Codec` are type-level markers only; no values of them are stored.
pub struct PgSink<Args, Codec = JsonCodec<CompactType>> {
    /// Pool handed (as a clone) to every flush future.
    pool: PgPool,
    /// Queue configuration; `buffer_size()` bounds the buffer (minimum 1).
    config: Config,
    /// Already-encoded tasks waiting for the next flush.
    buffer: Vec<PgTask<CompactType>>,
    /// The batch insert currently in flight, if any.
    ///
    /// NOTE(review): this file only touches it through `Mutex::get_mut`, so the
    /// lock is never actually taken. Presumably the `Mutex` exists so `PgSink` is
    /// `Sync` even though the boxed future is only `Send`. Confirm before removing.
    flush_future: Mutex<Option<FlushFuture>>,
    _marker: PhantomData<(Args, Codec)>,
}
27
28impl<Args, Codec> std::fmt::Debug for PgSink<Args, Codec> {
29 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
30 f.debug_struct("PgSink")
31 .field("config", &self.config)
32 .field("buffer_len", &self.buffer.len())
33 .finish_non_exhaustive()
34 }
35}
36
37impl<Args, Codec> Clone for PgSink<Args, Codec> {
38 fn clone(&self) -> Self {
45 Self {
46 pool: self.pool.clone(),
47 config: self.config.clone(),
48 buffer: Vec::new(),
49 flush_future: Mutex::new(None),
50 _marker: PhantomData,
51 }
52 }
53}
54
55impl<Args, Codec> PgSink<Args, Codec> {
56 #[must_use]
58 pub fn new(pool: &PgPool, config: &Config) -> Self {
59 Self {
60 pool: pool.clone(),
61 config: config.clone(),
62 buffer: Vec::new(),
63 flush_future: Mutex::new(None),
64 _marker: PhantomData,
65 }
66 }
67}
68
69impl<Args, Codec> PgSink<Args, Codec> {
70 fn capacity(&self) -> usize {
73 self.config.buffer_size().max(1)
74 }
75
76 fn needs_flush_before_ready(&mut self) -> bool {
79 self.flush_future
80 .get_mut()
81 .expect("flush_future mutex poisoned")
82 .is_some()
83 || self.buffer.len() >= self.capacity()
84 }
85
86 fn try_push(&mut self, item: PgTask<CompactType>) -> Result<(), Error> {
89 let cap = self.capacity();
90 if self.buffer.len() >= cap {
91 return Err(Error::SinkBufferFull(cap));
92 }
93 self.buffer.push(item);
94 Ok(())
95 }
96
97 fn poll_flush_inner(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
101 let flush_future = self
105 .flush_future
106 .get_mut()
107 .expect("flush_future mutex poisoned");
108
109 if flush_future.is_none() && self.buffer.is_empty() {
110 return Poll::Ready(Ok(()));
111 }
112
113 if flush_future.is_none() {
114 let pool = self.pool.clone();
115 let config = self.config.clone();
116 let buffer = std::mem::take(&mut self.buffer);
117 *flush_future = Some(Box::pin(queries::push_tasks(pool, config, buffer)));
118 }
119
120 let Some(future) = flush_future.as_mut() else {
121 return Poll::Ready(Ok(()));
122 };
123
124 match future.poll_unpin(cx) {
125 Poll::Pending => Poll::Pending,
126 Poll::Ready(result) => {
127 *flush_future = None;
128 Poll::Ready(result)
129 }
130 }
131 }
132}
133
134impl<Args, Encode, Fetcher> Sink<PgTask<CompactType>> for PostgresStorage<Args, Encode, Fetcher>
135where
136 Args: Send + Sync + 'static,
137 Fetcher: Unpin,
138{
139 type Error = Error;
140
141 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142 let this = self.get_mut();
143 if this.sink.needs_flush_before_ready() {
144 this.sink.poll_flush_inner(cx)
145 } else {
146 Poll::Ready(Ok(()))
147 }
148 }
149
150 fn start_send(self: Pin<&mut Self>, item: PgTask<CompactType>) -> Result<(), Self::Error> {
151 self.get_mut().sink.try_push(item)
152 }
153
154 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
155 self.get_mut().sink.poll_flush_inner(cx)
156 }
157
158 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
159 self.poll_flush(cx)
160 }
161}
162
#[cfg(test)]
mod tests {
    //! Unit tests for `PgSink` and the `Sink` impl on `PostgresStorage`.
    //!
    //! The pool points at `127.0.0.1:1`, an address not expected to accept
    //! connections, and is built with `build_unchecked`, so no live database is
    //! needed. Tests that actually create a `queries::push_tasks` future are
    //! gated behind the `tokio` feature.

    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use diesel::{
        PgConnection,
        r2d2::{ConnectionManager, Pool},
    };
    use futures::{Sink, future, task::noop_waker_ref};
    use lets_expect::{AssertionError, AssertionResult, *};

    use super::*;

    /// Single-connection pool aimed at an unreachable address.
    ///
    /// `build_unchecked` returns without waiting for a connection. The 10 ms
    /// connection timeout keeps any checkout attempt short.
    fn unchecked_pool() -> PgPool {
        let manager = ConnectionManager::<PgConnection>::new("postgres://127.0.0.1:1/not-used");
        Pool::builder()
            .max_size(1)
            .connection_timeout(std::time::Duration::from_millis(10))
            .build_unchecked(manager)
    }

    /// A minimal task with a fixed byte payload.
    fn task() -> PgTask<CompactType> {
        PgTask::new(b"payload".to_vec())
    }

    /// A standalone `PgSink` configured with the given buffer size.
    fn sink(buffer_size: usize) -> PgSink<Vec<u8>> {
        PgSink::new(
            &unchecked_pool(),
            &Config::new("sink-unit").set_buffer_size(buffer_size),
        )
    }

    /// A `PostgresStorage` with the given buffer size, for exercising the `Sink` impl.
    fn storage(buffer_size: usize) -> PostgresStorage<Vec<u8>> {
        let pool = unchecked_pool();
        let config = Config::new("sink-unit").set_buffer_size(buffer_size);
        PostgresStorage::<Vec<u8>>::new_with_config(&pool, &config)
    }

    /// Runs `start_send` once against a pre-filled buffer.
    ///
    /// Pushes `existing_items` tasks directly into the buffer (bypassing
    /// `poll_ready`), then calls `start_send` once. Returns the resulting
    /// buffer length.
    fn start_send_via_storage(buffer_size: usize, existing_items: usize) -> Result<usize, Error> {
        let mut storage = storage(buffer_size);
        for _ in 0..existing_items {
            storage.sink.buffer.push(task());
        }
        Pin::new(&mut storage).start_send(task())?;
        Ok(storage.sink.buffer.len())
    }

    /// Pre-fills the buffer with `existing_items` tasks, then polls `poll_ready`
    /// once with a no-op waker.
    fn poll_ready_via_storage(
        buffer_size: usize,
        existing_items: usize,
    ) -> Poll<Result<(), Error>> {
        let mut storage = storage(buffer_size);
        for _ in 0..existing_items {
            storage.sink.buffer.push(task());
        }
        let mut cx = Context::from_waker(noop_waker_ref());
        Pin::new(&mut storage).poll_ready(&mut cx)
    }

    /// What `poll_ready` returned, plus the sink state observed right after it.
    struct ReadyObservation {
        poll: Poll<Result<(), Error>>,
        buffer_len: usize,
        has_flush_future: bool,
    }

    /// Installs a never-completing flush future, then polls `poll_ready` once.
    fn poll_ready_in_flight() -> ReadyObservation {
        let mut storage = storage(2);
        storage.sink.flush_future =
            Mutex::new(Some(Box::pin(future::pending::<Result<(), Error>>())));
        let mut cx = Context::from_waker(noop_waker_ref());
        let poll = Pin::new(&mut storage).poll_ready(&mut cx);
        let has_flush_future = storage
            .sink
            .flush_future
            .get_mut()
            .expect("flush_future mutex poisoned")
            .is_some();
        ReadyObservation {
            poll,
            buffer_len: storage.sink.buffer.len(),
            has_flush_future,
        }
    }

    /// What `poll_flush_inner` returned, plus the sink state observed right after it.
    struct FlushObservation {
        poll: Poll<Result<(), Error>>,
        /// `true` when no flush future remains stored after the poll.
        future_cleared: bool,
        /// Tasks still in the buffer after the poll.
        buffer_len: usize,
    }

    /// Polls `poll_flush_inner` once from a prepared state and records the outcome.
    ///
    /// The sink starts with `buffered` tasks and the given in-flight future
    /// (if any).
    fn poll_flush_sink_with_state(
        buffer_size: usize,
        buffered: usize,
        future: Option<FlushFuture>,
    ) -> FlushObservation {
        let mut sink = sink(buffer_size);
        for _ in 0..buffered {
            sink.buffer.push(task());
        }
        sink.flush_future = Mutex::new(future);
        let mut cx = Context::from_waker(noop_waker_ref());
        let poll = sink.poll_flush_inner(&mut cx);
        let future_cleared = sink
            .flush_future
            .get_mut()
            .expect("flush_future mutex poisoned")
            .is_none();
        FlushObservation {
            poll,
            future_cleared,
            buffer_len: sink.buffer.len(),
        }
    }

    /// Nothing in flight and nothing buffered.
    fn poll_flush_idle() -> FlushObservation {
        poll_flush_sink_with_state(1, 0, None)
    }

    /// An in-flight flush that resolves immediately with `result`.
    fn poll_flush_in_flight_ready(result: Result<(), Error>) -> FlushObservation {
        poll_flush_sink_with_state(1, 0, Some(Box::pin(future::ready(result))))
    }

    /// An in-flight flush that never resolves.
    fn poll_flush_in_flight_pending() -> FlushObservation {
        poll_flush_sink_with_state(1, 0, Some(Box::pin(future::pending())))
    }

    // One buffered task and no in-flight flush, so polling creates and polls a
    // real `queries::push_tasks` future. Only used by the `tokio`-gated tests,
    // presumably because polling that future needs a Tokio runtime.
    #[cfg_attr(not(feature = "tokio"), allow(dead_code))]
    fn poll_flush_creates_future() -> FlushObservation {
        poll_flush_sink_with_state(2, 1, None)
    }

    /// Pre-fills a capacity-2 storage with `buffered` tasks and polls `poll_close` once.
    fn poll_close_via_storage(buffered: usize) -> Poll<Result<(), Error>> {
        let mut storage = storage(2);
        for _ in 0..buffered {
            storage.sink.buffer.push(task());
        }
        let mut cx = Context::from_waker(noop_waker_ref());
        Pin::new(&mut storage).poll_close(&mut cx)
    }

    /// Buffer length of a clone taken from a sink holding `buffered_items` tasks.
    fn cloned_sink_buffer_len(buffered_items: usize) -> usize {
        let mut sink = sink(3);
        for _ in 0..buffered_items {
            sink.buffer.push(task());
        }
        sink.clone().buffer.len()
    }

    /// Whether a clone of a sink with an in-flight flush has no flush future of its own.
    fn cloned_sink_state_drops_flush_future() -> bool {
        let mut sink = sink(3);
        sink.buffer.push(task());
        sink.flush_future = Mutex::new(Some(Box::pin(future::pending::<Result<(), Error>>())));
        sink.clone()
            .flush_future
            .get_mut()
            .expect("flush_future mutex poisoned")
            .is_none()
    }

    /// Configured buffer size as seen by a clone.
    fn cloned_sink_buffer_size(buffer_size: usize) -> usize {
        sink(buffer_size).clone().config.buffer_size()
    }

    /// `Debug` rendering of a capacity-3 sink holding `buffered_items` tasks.
    fn sink_debug(buffered_items: usize) -> String {
        let mut sink = sink(3);
        for _ in 0..buffered_items {
            sink.buffer.push(task());
        }
        format!("{sink:?}")
    }

    /// Passes only for `Err(Error::SinkBufferFull(1))`, i.e. rejection at capacity 1.
    fn sink_buffer_full(result: &Result<usize, Error>) -> AssertionResult {
        match result {
            Err(Error::SinkBufferFull(1)) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected sink buffer full at capacity 1, got {other:?}"
            )])),
        }
    }

    /// Passes only for `Poll::Ready(Ok(()))`.
    fn poll_ready_ok(result: &Poll<Result<(), Error>>) -> AssertionResult {
        match result {
            Poll::Ready(Ok(())) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected ready ok, got {other:?}"
            )])),
        }
    }

    // Accepts `Pending` (the flush is still running) or `Ready(Err(_))` (the flush
    // failed against the unreachable pool). `Ready(Ok(()))` is rejected because a
    // flush against that pool is not expected to succeed.
    #[cfg_attr(not(feature = "tokio"), allow(dead_code))]
    fn poll_started_flush(result: &Poll<Result<(), Error>>) -> AssertionResult {
        match result {
            Poll::Pending | Poll::Ready(Err(_)) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected backpressure to start flushing, got {other:?}"
            )])),
        }
    }

    /// Requires `Ready(Ok(()))`, no stored future, and an empty buffer.
    fn observation_is_idle_ok(obs: &FlushObservation) -> AssertionResult {
        match (&obs.poll, obs.future_cleared, obs.buffer_len) {
            (Poll::Ready(Ok(())), true, 0) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected idle Ready(Ok), got {other:?}"
            )])),
        }
    }

    /// Requires `Ready(Ok(()))` with the flush future cleared.
    fn observation_is_ready_ok_and_cleared(obs: &FlushObservation) -> AssertionResult {
        match (&obs.poll, obs.future_cleared) {
            (Poll::Ready(Ok(())), true) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected Ready(Ok) with cleared future, got {other:?}"
            )])),
        }
    }

    /// Requires `Ready(Err(_))` with the flush future cleared.
    fn observation_is_ready_err_and_cleared(obs: &FlushObservation) -> AssertionResult {
        match (&obs.poll, obs.future_cleared) {
            (Poll::Ready(Err(_)), true) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected Ready(Err) with cleared future, got {other:?}"
            )])),
        }
    }

    /// Requires `Pending` with the flush future still stored.
    fn observation_stays_pending(obs: &FlushObservation) -> AssertionResult {
        match (&obs.poll, obs.future_cleared) {
            (Poll::Pending, false) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected Pending with future retained, got {other:?}"
            )])),
        }
    }

    // Requires the buffer to be empty and the poll result to agree with the stored
    // future: `Pending` keeps the future, `Ready` (either outcome) clears it.
    #[cfg_attr(not(feature = "tokio"), allow(dead_code))]
    fn observation_drained_buffer_into_future(obs: &FlushObservation) -> AssertionResult {
        if obs.buffer_len != 0 {
            return Err(AssertionError::new(vec![format!(
                "expected buffer to be drained into the flush future, got {} items",
                obs.buffer_len
            )]));
        }
        match (&obs.poll, obs.future_cleared) {
            (Poll::Pending, false) => Ok(()),
            (Poll::Ready(_), true) => Ok(()),
            (Poll::Pending, true) => Err(AssertionError::new(vec![
                "flush returned Pending but the future was cleared".to_owned(),
            ])),
            (Poll::Ready(_), false) => Err(AssertionError::new(vec![
                "flush returned Ready but the future was retained".to_owned(),
            ])),
        }
    }

    /// Requires `Pending`, an empty buffer, and the in-flight future still stored.
    fn keeps_in_flight_flush(observation: &ReadyObservation) -> AssertionResult {
        match (
            &observation.poll,
            observation.buffer_len,
            observation.has_flush_future,
        ) {
            (Poll::Pending, 0, true) => Ok(()),
            other => Err(AssertionError::new(vec![format!(
                "expected pending in-flight flush, got {other:?}"
            )])),
        }
    }

    /// Requires the debug string to mention `PgSink`, `config`, and `buffer_len`.
    ///
    /// These are the fields the `Debug` impl chooses to expose; the struct
    /// fields themselves are private.
    fn debug_mentions_public_fields(result: &String) -> AssertionResult {
        if result.contains("PgSink") && result.contains("config") && result.contains("buffer_len") {
            Ok(())
        } else {
            Err(AssertionError::new(vec![format!(
                "expected sink debug output with public fields, got {result}"
            )]))
        }
    }

    lets_expect! {
        // `start_send` respects the capacity, including the minimum capacity of 1.
        expect(start_send_via_storage(buffer_size, existing_items)) {
            let buffer_size = 2;
            let existing_items = 0;

            when buffer_has_room_below_capacity {
                to buffers_the_task { be_ok_and equal(1) }
            }

            when buffer_is_at_capacity_already {
                let buffer_size = 1;
                let existing_items = 1;
                to rejects_the_send { sink_buffer_full }
            }

            when configured_capacity_is_zero_and_minimum_one_is_full {
                let buffer_size = 0;
                let existing_items = 1;
                to rejects_the_send_via_the_minimum_capacity { sink_buffer_full }
            }
        }

        // `poll_ready` is immediately ready when no backpressure applies.
        expect(poll_ready_via_storage(buffer_size, existing_items)) {
            let buffer_size = 2;
            let existing_items = 0;

            when buffer_is_below_capacity_and_no_flush_is_in_flight {
                to returns_ready_without_flushing { poll_ready_ok }
            }
        }

        // `poll_ready` waits on an in-flight flush instead of accepting more work.
        expect(poll_ready_in_flight()) {
            when an_earlier_flush_is_still_in_flight {
                to waits_for_the_flush_to_complete { keeps_in_flight_flush }
            }
        }

        expect(poll_flush_idle()) {
            when there_is_neither_a_pending_flush_nor_buffered_work {
                to completes_immediately_without_touching_the_database {
                    observation_is_idle_ok
                }
            }
        }

        // A resolved in-flight flush is cleared whether it succeeded or failed.
        expect(poll_flush_in_flight_ready(result)) {
            let result = Ok(());

            when the_in_flight_flush_resolves_successfully {
                to returns_ready_ok_and_clears_the_future {
                    observation_is_ready_ok_and_cleared
                }
            }

            when the_in_flight_flush_resolves_with_an_error {
                let result = Err(Error::SinkBufferFull(1));
                to surfaces_the_error_and_clears_the_future {
                    observation_is_ready_err_and_cleared
                }
            }
        }

        expect(poll_flush_in_flight_pending()) {
            when the_in_flight_flush_is_still_pending {
                to stays_pending_and_keeps_the_future {
                    observation_stays_pending
                }
            }
        }

        expect(poll_close_via_storage(buffered)) {
            let buffered = 0;

            when the_sink_is_already_drained {
                to delegates_to_flush_and_completes { poll_ready_ok }
            }
        }

        // Clones keep configuration but not buffered work or in-flight flushes.
        expect(cloned_sink_buffer_len(buffered_items)) {
            let buffered_items = 2;

            when the_original_sink_has_buffered_tasks {
                to starts_the_clone_with_an_empty_buffer { equal(0) }
            }
        }

        expect(cloned_sink_state_drops_flush_future()) {
            when the_original_sink_has_an_in_flight_flush {
                to does_not_share_the_in_flight_flush_future { equal(true) }
            }
        }

        expect(cloned_sink_buffer_size(buffer_size)) {
            let buffer_size = 4;

            when the_original_sink_has_custom_capacity {
                to keeps_the_capacity_configuration { equal(4) }
            }
        }

        expect(sink_debug(buffered_items)) {
            let buffered_items = 2;

            when the_sink_has_buffered_items {
                to describes_the_sink_without_exposing_the_pool {
                    debug_mentions_public_fields
                }
            }
        }
    }

    // Tests that create and poll a real `queries::push_tasks` future against the
    // unreachable pool; they run under `#tokio_test`.
    #[cfg(feature = "tokio")]
    mod tokio_tests {
        use super::*;

        lets_expect! { #tokio_test
            // A full buffer makes `poll_ready` start a flush before accepting more work.
            expect(poll_ready_via_storage(buffer_size, existing_items)) {
                let buffer_size = 1;
                let existing_items = 1;

                when buffer_is_at_capacity_without_a_flush_in_flight {
                    to starts_flushing_before_accepting_more_work { poll_started_flush }
                }
            }
        }

        lets_expect! { #tokio_test
            // Accepts Pending as well as Err (see `poll_started_flush`).
            expect(poll_flush_sink_with_state(buffer_size, buffered, None).poll) {
                let buffer_size = 2;
                let buffered = 1;

                when poll_flush_runs_on_a_real_runtime_with_buffered_work {
                    to resolves_to_an_error_against_an_unreachable_pool { poll_started_flush }
                }
            }
        }

        lets_expect! { #tokio_test
            expect(poll_flush_creates_future()) {
                when there_is_no_in_flight_flush_but_the_buffer_has_work {
                    to drains_the_buffer_into_a_new_flush_future {
                        observation_drained_buffer_into_future
                    }
                }
            }

            expect(poll_close_via_storage(1)) {
                when there_is_buffered_work_to_flush_before_closing {
                    to starts_flushing_the_buffered_work_before_completing { poll_started_flush }
                }
            }
        }
    }
}