chokepoint/
sink.rs

1use crate::{
2    item::ChokeItem,
3    ChokeSettings,
4    ChokeSettingsOrder,
5    ChokeStream,
6};
7use futures::{
8    Sink,
9    SinkExt,
10    StreamExt,
11};
12use std::{
13    pin::Pin,
14    task::{
15        Context,
16        Poll,
17    },
18};
19use tokio::sync::mpsc;
20use tokio_stream::wrappers::UnboundedReceiverStream;
21
/// Compile-time switch for the per-call `debug!` tracing in the `Sink` implementation of
/// [`ChokeSink`]; while it is `false` those `debug!` calls are skipped.
const VERBOSE: bool = false;
23
24/// A [`futures::Sink`] that uses an underlaying [`ChokeStream`] to control how items are forwarded to the inner sink.
25#[allow(clippy::type_complexity)]
26#[pin_project]
27pub struct ChokeSink<Si, T>
28where
29    Si: Sink<T> + Unpin,
30{
31    /// The inner sink that gets written to.
32    sink: Si,
33    /// The choke stream that controls how items are forwarded to the inner sink.
34    choke_stream: ChokeStream<T>,
35    sender: mpsc::UnboundedSender<T>,
36    backpressure: bool,
37}
38
39impl<Si, T> ChokeSink<Si, T>
40where
41    Si: Sink<T> + Unpin,
42    T: ChokeItem,
43{
44    pub fn new(sink: Si, settings: ChokeSettings) -> Self {
45        let (tx, rx) = mpsc::unbounded_channel();
46        let stream = Box::new(UnboundedReceiverStream::new(rx));
47        Self {
48            sink,
49            sender: tx,
50            backpressure: settings.ordering.unwrap_or_default() == ChokeSettingsOrder::Backpressure,
51            choke_stream: ChokeStream::new(stream, settings),
52        }
53    }
54
55    pub fn into_inner(self) -> Si {
56        self.sink
57    }
58}
59
60impl<Si, T> Sink<T> for ChokeSink<Si, T>
61where
62    Si: Sink<T> + Unpin + 'static,
63    T: ChokeItem + Send + 'static,
64{
65    type Error = Si::Error;
66
67    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68        if VERBOSE {
69            debug!(backpressure = %self.backpressure, pending = %self.choke_stream.pending(), "poll_ready");
70        }
71        if self.backpressure && self.choke_stream.pending() {
72            return Poll::Pending;
73        }
74        self.sink.poll_ready_unpin(cx)
75    }
76
77    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
78        if VERBOSE {
79            debug!(pending = %self.choke_stream.pending(), "start_send");
80        }
81        self.sender.send(item).expect("the stream owns the receiver");
82        Ok(())
83    }
84
85    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        if VERBOSE {
87            debug!(pending = %self.choke_stream.pending(), "poll_flush");
88        }
89
90        match self.choke_stream.poll_next_unpin(cx) {
91            Poll::Ready(Some(item)) => {
92                if VERBOSE {
93                    debug!(pending = %self.choke_stream.pending(), "poll_flush: got item");
94                }
95                if let Err(err) = self.sink.start_send_unpin(item) {
96                    return Poll::Ready(Err(err));
97                }
98            }
99            Poll::Pending => {
100                if self.choke_stream.has_dropped_item() {
101                    self.choke_stream.reset_dropped_item();
102                    return Poll::Ready(Ok(()));
103                }
104            }
105            _ => {}
106        }
107
108        self.sink.poll_flush_unpin(cx)
109    }
110
111    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112        if VERBOSE {
113            debug!(pending = %self.choke_stream.pending(), "poll_close");
114        }
115
116        if self.choke_stream.pending() {
117            if let Poll::Ready(Err(err)) = self.poll_flush(cx) {
118                return Poll::Ready(Err(err));
119            };
120            Poll::Pending
121        } else {
122            self.sink.poll_close_unpin(cx)
123        }
124    }
125}