1use core::fmt;
2use core::future::Future;
3use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6use tokio_stream::Stream;
7
8use super::Sink;
9
/// Sink adapter that maps every item sent to it before forwarding it.
///
/// Each `U` handed to `start_send` is turned by `f` into a future. Once that
/// future resolves to `Ok(item)`, `item` is passed on to the wrapped sink
/// `Si`. At most one mapped item is in flight at a time; it is held in
/// `state` until its future completes.
#[must_use = "sinks do nothing unless polled"]
pub struct With<Si, Item, U, Fut, F, E> {
    /// The wrapped sink that receives the mapped items.
    sink: Si,
    /// Mapping function applied to each incoming `U`.
    f: F,
    /// The mapping future for the item currently being processed, if any.
    state: Option<Fut>,
    /// Ties the error type `E` to the struct without storing a value of it.
    ///
    /// `fn() -> E` (rather than plain `E`) keeps `With` covariant in `E`. It
    /// also avoids making `With` inherit `E`'s `Send`/`Sync`/unwind-safety
    /// status or drop-check obligations, since no `E` is ever owned here.
    _phantom_e: PhantomData<fn() -> E>,
    /// Records the `U -> Item` relationship for variance purposes only.
    _phantom_item: PhantomData<fn(U) -> Item>,
}
19
20impl<Si: Unpin, Item, U, Fut: Unpin, F, E> Unpin for With<Si, Item, U, Fut, F, E> {}
21
22impl<Si, Item, U, Fut, F, E> fmt::Debug for With<Si, Item, U, Fut, F, E>
23where
24 Si: fmt::Debug,
25 Fut: fmt::Debug,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("With")
29 .field("sink", &self.sink)
30 .field("state", &self.state)
31 .finish()
32 }
33}
34
35impl<Si, Item, U, Fut, F, E> With<Si, Item, U, Fut, F, E>
36where
37 Si: Sink<Item>,
38 F: FnMut(U) -> Fut,
39 E: From<Si::Error>,
40 Fut: Future<Output = Result<Item, E>>,
41{
42 pub(super) fn new(sink: Si, f: F) -> Self {
43 Self {
44 state: None,
45 sink,
46 f,
47 _phantom_item: PhantomData,
48 _phantom_e: PhantomData,
49 }
50 }
51}
52
53impl<Si, Item, U, Fut, F, E> Clone for With<Si, Item, U, Fut, F, E>
54where
55 Si: Clone,
56 F: Clone,
57 Fut: Clone,
58{
59 fn clone(&self) -> Self {
60 Self {
61 state: self.state.clone(),
62 sink: self.sink.clone(),
63 f: self.f.clone(),
64 _phantom_item: PhantomData,
65 _phantom_e: PhantomData,
66 }
67 }
68}
69
70impl<S, Item, U, Fut, F, E> Stream for With<S, Item, U, Fut, F, E>
72where
73 S: Stream + Sink<Item>,
74 F: FnMut(U) -> Fut,
75 Fut: Future<Output = Result<Item, E>>,
76 E: From<S::Error>,
77{
78 type Item = S::Item;
79
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 let this = unsafe { self.get_unchecked_mut() };
82 let sink = unsafe { Pin::new_unchecked(&mut this.sink) };
83 sink.poll_next(cx)
84 }
85
86 fn size_hint(&self) -> (usize, Option<usize>) {
87 self.sink.size_hint()
88 }
89}
90
91impl<Si, Item, U, Fut, F, E> With<Si, Item, U, Fut, F, E>
92where
93 Si: Sink<Item>,
94 F: FnMut(U) -> Fut,
95 Fut: Future<Output = Result<Item, E>>,
96 E: From<Si::Error>,
97{
98 pub fn get_ref(&self) -> &Si {
100 &self.sink
101 }
102
103 pub fn get_mut(&mut self) -> &mut Si {
108 &mut self.sink
109 }
110
111 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
116 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
117 }
118
119 pub fn into_inner(self) -> Si {
124 self.sink
125 }
126
127 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
129 let this = unsafe { self.get_unchecked_mut() };
130
131 if let Some(fut) = this.state.as_mut() {
132 match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
133 Poll::Ready(item_res) => {
134 this.state = None;
135 match item_res {
136 Ok(item) => {
137 let sink = unsafe { Pin::new_unchecked(&mut this.sink) };
138 match sink.start_send(item) {
139 Err(e) => Poll::Ready(Err(e.into())),
140 Ok(()) => Poll::Ready(Ok(())),
141 }
142 }
143 Err(e) => Poll::Ready(Err(e)),
144 }
145 }
146 Poll::Pending => Poll::Pending,
147 }
148 } else {
149 Poll::Ready(Ok(()))
150 }
151 }
152}
153
154impl<Si, Item, U, Fut, F, E> Sink<U> for With<Si, Item, U, Fut, F, E>
155where
156 Si: Sink<Item>,
157 F: FnMut(U) -> Fut,
158 Fut: Future<Output = Result<Item, E>>,
159 E: From<Si::Error>,
160{
161 type Error = E;
162
163 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
164 match self.as_mut().poll(cx) {
165 Poll::Ready(Ok(())) => {
166 match unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
167 .poll_ready(cx)
168 {
169 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
170 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
171 Poll::Pending => Poll::Pending,
172 }
173 }
174 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
175 Poll::Pending => Poll::Pending,
176 }
177 }
178
179 fn start_send(self: Pin<&mut Self>, item: U) -> Result<(), Self::Error> {
180 let this = unsafe { self.get_unchecked_mut() };
181
182 assert!(this.state.is_none());
183 this.state = Some((this.f)(item));
184 Ok(())
185 }
186
187 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
188 match self.as_mut().poll(cx) {
189 Poll::Ready(Ok(())) => {
190 match unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
191 .poll_flush(cx)
192 {
193 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
194 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
195 Poll::Pending => Poll::Pending,
196 }
197 }
198 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
199 Poll::Pending => Poll::Pending,
200 }
201 }
202
203 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
204 match self.as_mut().poll(cx) {
205 Poll::Ready(Ok(())) => {
206 match unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
207 .poll_close(cx)
208 {
209 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
210 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
211 Poll::Pending => Poll::Pending,
212 }
213 }
214 Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
215 Poll::Pending => Poll::Pending,
216 }
217 }
218}