1use core::fmt;
2use core::future::Future;
3use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6use tokio_stream::Stream;
7use tokio_stream_util::FusedStream;
8
9use super::Sink;
10
/// Sink adapter that maps every incoming item of type `U` through an
/// asynchronous closure before forwarding the result to an inner sink.
///
/// Calling `start_send` with a `U` invokes the closure `f`, which yields a
/// future resolving to `Result<Item, E>`. The `poll_*` methods drive that
/// future and forward a successful `Item` to the wrapped sink `Si`. At most
/// one such future is in flight at a time.
#[must_use = "sinks do nothing unless polled"]
pub struct With<Si, Item, U, Fut, F> {
    /// The wrapped sink that receives the transformed items.
    sink: Si,
    /// Closure turning an incoming `U` into a transformation future.
    f: F,
    /// The transformation future currently in flight, if any.
    state: Option<Fut>,
    /// Ties `U` and `Item` to the type without storing them. A `fn` pointer is
    /// always `Send`, `Sync` and `Unpin`, so this marker never restricts those
    /// auto traits for `With`.
    _phantom: PhantomData<fn(U) -> Item>,
}
19
20impl<Si: Unpin, Item, U, Fut: Unpin, F> Unpin for With<Si, Item, U, Fut, F> {}
21
22impl<Si, Item, U, Fut, F> fmt::Debug for With<Si, Item, U, Fut, F>
23where
24 Si: fmt::Debug,
25 Fut: fmt::Debug,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("With")
29 .field("sink", &self.sink)
30 .field("state", &self.state)
31 .finish()
32 }
33}
34
35impl<Si, Item, U, Fut, F> With<Si, Item, U, Fut, F>
36where
37 Si: Sink<Item>,
38 F: FnMut(U) -> Fut,
39 Fut: Future,
40{
41 pub(super) fn new<E>(sink: Si, f: F) -> Self
42 where
43 Fut: Future<Output = Result<Item, E>>,
44 E: From<Si::Error>,
45 {
46 Self {
47 state: None,
48 sink,
49 f,
50 _phantom: PhantomData,
51 }
52 }
53}
54
55impl<Si, Item, U, Fut, F> Clone for With<Si, Item, U, Fut, F>
56where
57 Si: Clone,
58 F: Clone,
59 Fut: Clone,
60{
61 fn clone(&self) -> Self {
62 Self {
63 state: self.state.clone(),
64 sink: self.sink.clone(),
65 f: self.f.clone(),
66 _phantom: PhantomData,
67 }
68 }
69}
70
71impl<S, Item, U, Fut, F> Stream for With<S, Item, U, Fut, F>
73where
74 S: Stream + Sink<Item>,
75 F: FnMut(U) -> Fut,
76 Fut: Future,
77{
78 type Item = S::Item;
79
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 let this = unsafe { self.get_unchecked_mut() };
82 let sink = unsafe { Pin::new_unchecked(&mut this.sink) };
83 sink.poll_next(cx)
84 }
85
86 fn size_hint(&self) -> (usize, Option<usize>) {
87 self.sink.size_hint()
88 }
89}
90
91impl<S, Item, U, Fut, F> FusedStream for With<S, Item, U, Fut, F>
92where
93 S: FusedStream + Sink<Item>,
94 F: FnMut(U) -> Fut,
95 Fut: Future,
96{
97 fn is_terminated(&self) -> bool {
98 self.sink.is_terminated()
99 }
100}
101
102impl<Si, Item, U, Fut, F, E> With<Si, Item, U, Fut, F>
103where
104 Si: Sink<Item>,
105 F: FnMut(U) -> Fut,
106 Fut: Future<Output = Result<Item, E>>,
107 E: From<Si::Error>,
108{
109 pub fn get_ref(&self) -> &Si {
111 &self.sink
112 }
113
114 pub fn get_mut(&mut self) -> &mut Si {
119 &mut self.sink
120 }
121
122 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
127 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
128 }
129
130 pub fn into_inner(self) -> Si {
135 self.sink
136 }
137
138 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
140 let this = unsafe { self.get_unchecked_mut() };
141
142 if let Some(fut) = this.state.as_mut() {
143 let item_res = match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
144 Poll::Ready(res) => res,
145 Poll::Pending => return Poll::Pending,
146 };
147 this.state = None;
148 let item = match item_res {
149 Ok(item) => item,
150 Err(e) => return Poll::Ready(Err(e)),
151 };
152 let sink = unsafe { Pin::new_unchecked(&mut this.sink) };
153 if let Err(e) = sink.start_send(item) {
154 return Poll::Ready(Err(e.into()));
155 }
156 }
157
158 Poll::Ready(Ok(()))
159 }
160}
161
162impl<Si, Item, U, Fut, F, E> Sink<U> for With<Si, Item, U, Fut, F>
163where
164 Si: Sink<Item>,
165 F: FnMut(U) -> Fut,
166 Fut: Future<Output = Result<Item, E>>,
167 E: From<Si::Error>,
168{
169 type Error = E;
170
171 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
172 match self.as_mut().poll(cx) {
173 Poll::Ready(Ok(())) => {}
174 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
175 Poll::Pending => return Poll::Pending,
176 }
177 let sink = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) };
178 match sink.poll_ready(cx) {
179 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
180 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
181 Poll::Pending => Poll::Pending,
182 }
183 }
184
185 fn start_send(self: Pin<&mut Self>, item: U) -> Result<(), Self::Error> {
186 let this = unsafe { self.get_unchecked_mut() };
187
188 assert!(this.state.is_none());
189 this.state = Some((this.f)(item));
190 Ok(())
191 }
192
193 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
194 match self.as_mut().poll(cx) {
195 Poll::Ready(Ok(())) => {}
196 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
197 Poll::Pending => return Poll::Pending,
198 };
199 let sink = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) };
200 match sink.poll_flush(cx) {
201 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
202 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
203 Poll::Pending => Poll::Pending,
204 }
205 }
206
207 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208 match self.as_mut().poll(cx) {
209 Poll::Ready(Ok(())) => {}
210 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
211 Poll::Pending => return Poll::Pending,
212 };
213 let sink = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) };
214 match sink.poll_close(cx) {
215 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
216 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
217 Poll::Pending => Poll::Pending,
218 }
219 }
220}