1use crate::{
2 item::ChokeItem,
3 ChokeSettings,
4 ChokeSettingsOrder,
5 ChokeStream,
6};
7use futures::{
8 Sink,
9 SinkExt,
10 StreamExt,
11};
12use std::{
13 pin::Pin,
14 task::{
15 Context,
16 Poll,
17 },
18};
19use tokio::sync::mpsc;
20use tokio_stream::wrappers::UnboundedReceiverStream;
21
/// Compile-time switch for the `debug!` tracing emitted by the `Sink` methods of
/// `ChokeSink`; set to `true` to log each `poll_ready`, `start_send`, `poll_flush`
/// and `poll_close` call.
const VERBOSE: bool = false;
23
24#[allow(clippy::type_complexity)]
26#[pin_project]
27pub struct ChokeSink<Si, T>
28where
29 Si: Sink<T> + Unpin,
30{
31 sink: Si,
33 choke_stream: ChokeStream<T>,
35 sender: mpsc::UnboundedSender<T>,
36 backpressure: bool,
37}
38
39impl<Si, T> ChokeSink<Si, T>
40where
41 Si: Sink<T> + Unpin,
42 T: ChokeItem,
43{
44 pub fn new(sink: Si, settings: ChokeSettings) -> Self {
45 let (tx, rx) = mpsc::unbounded_channel();
46 let stream = Box::new(UnboundedReceiverStream::new(rx));
47 Self {
48 sink,
49 sender: tx,
50 backpressure: settings.ordering.unwrap_or_default() == ChokeSettingsOrder::Backpressure,
51 choke_stream: ChokeStream::new(stream, settings),
52 }
53 }
54
55 pub fn into_inner(self) -> Si {
56 self.sink
57 }
58}
59
60impl<Si, T> Sink<T> for ChokeSink<Si, T>
61where
62 Si: Sink<T> + Unpin + 'static,
63 T: ChokeItem + Send + 'static,
64{
65 type Error = Si::Error;
66
67 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
68 if VERBOSE {
69 debug!(backpressure = %self.backpressure, pending = %self.choke_stream.pending(), "poll_ready");
70 }
71 if self.backpressure && self.choke_stream.pending() {
72 return Poll::Pending;
73 }
74 self.sink.poll_ready_unpin(cx)
75 }
76
77 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
78 if VERBOSE {
79 debug!(pending = %self.choke_stream.pending(), "start_send");
80 }
81 self.sender.send(item).expect("the stream owns the receiver");
82 Ok(())
83 }
84
85 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 if VERBOSE {
87 debug!(pending = %self.choke_stream.pending(), "poll_flush");
88 }
89
90 match self.choke_stream.poll_next_unpin(cx) {
91 Poll::Ready(Some(item)) => {
92 if VERBOSE {
93 debug!(pending = %self.choke_stream.pending(), "poll_flush: got item");
94 }
95 if let Err(err) = self.sink.start_send_unpin(item) {
96 return Poll::Ready(Err(err));
97 }
98 }
99 Poll::Pending => {
100 if self.choke_stream.has_dropped_item() {
101 self.choke_stream.reset_dropped_item();
102 return Poll::Ready(Ok(()));
103 }
104 }
105 _ => {}
106 }
107
108 self.sink.poll_flush_unpin(cx)
109 }
110
111 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
112 if VERBOSE {
113 debug!(pending = %self.choke_stream.pending(), "poll_close");
114 }
115
116 if self.choke_stream.pending() {
117 if let Poll::Ready(Err(err)) = self.poll_flush(cx) {
118 return Poll::Ready(Err(err));
119 };
120 Poll::Pending
121 } else {
122 self.sink.poll_close_unpin(cx)
123 }
124 }
125}