1use core::fmt;
2use core::future::Future;
3use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll};
6use tokio_stream::Stream;
7
8use super::Sink;
9
/// Sink adapter that maps each incoming `U` through an asynchronous closure
/// and forwards the resulting `Item` to the wrapped sink.
///
/// `E` is the error type reported by the adapter; it is what the mapping
/// future fails with and what inner sink errors are converted into.
#[must_use = "sinks do nothing unless polled"]
pub struct With<Si, Item, U, Fut, F, E> {
    /// The wrapped sink that receives the mapped items.
    sink: Si,
    /// Closure invoked with every incoming value to produce the mapping future.
    f: F,
    /// Future for the item currently being mapped; `None` while idle.
    state: Option<Fut>,
    // The adapter never stores an `E`, so the marker uses a function-pointer
    // type: it ties `E` to the struct without making `Send`/`Sync` depend on
    // `E` (same approach as the `fn(U) -> Item` marker below).
    _phantom_e: PhantomData<fn() -> E>,
    _phantom_item: PhantomData<fn(U) -> Item>,
}
19
20impl<Si: Unpin, Item, U, Fut: Unpin, F, E> Unpin for With<Si, Item, U, Fut, F, E> {}
21
22impl<Si, Item, U, Fut, F, E> fmt::Debug for With<Si, Item, U, Fut, F, E>
23where
24 Si: fmt::Debug,
25 Fut: fmt::Debug,
26{
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 f.debug_struct("With")
29 .field("sink", &self.sink)
30 .field("state", &self.state)
31 .finish()
32 }
33}
34
35impl<Si, Item, U, Fut, F, E> With<Si, Item, U, Fut, F, E>
36where
37 Si: Sink<Item>,
38 F: FnMut(U) -> Fut,
39 E: From<Si::Error>,
40 Fut: Future<Output = Result<Item, E>>,
41{
42 pub(super) fn new(sink: Si, f: F) -> Self {
43 Self {
44 state: None,
45 sink,
46 f,
47 _phantom_item: PhantomData,
48 _phantom_e: PhantomData,
49 }
50 }
51}
52
53impl<Si, Item, U, Fut, F, E> Clone for With<Si, Item, U, Fut, F, E>
54where
55 Si: Clone,
56 F: Clone,
57 Fut: Clone,
58{
59 fn clone(&self) -> Self {
60 Self {
61 state: self.state.clone(),
62 sink: self.sink.clone(),
63 f: self.f.clone(),
64 _phantom_item: PhantomData,
65 _phantom_e: PhantomData,
66 }
67 }
68}
69
70impl<S, Item, U, Fut, F, E> Stream for With<S, Item, U, Fut, F, E>
72where
73 S: Stream + Sink<Item>,
74 F: FnMut(U) -> Fut,
75 Fut: Future<Output = Result<Item, E>>,
76 E: From<S::Error>,
77{
78 type Item = S::Item;
79
80 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
81 let this = unsafe { self.get_unchecked_mut() };
82 let sink = unsafe { Pin::new_unchecked(&mut this.sink) };
83 sink.poll_next(cx)
84 }
85
86 fn size_hint(&self) -> (usize, Option<usize>) {
87 self.sink.size_hint()
88 }
89}
90
91impl<Si, Item, U, Fut, F, E> With<Si, Item, U, Fut, F, E>
92where
93 Si: Sink<Item>,
94 F: FnMut(U) -> Fut,
95 Fut: Future<Output = Result<Item, E>>,
96 E: From<Si::Error>,
97{
98 pub fn get_ref(&self) -> &Si {
100 &self.sink
101 }
102
103 pub fn get_mut(&mut self) -> &mut Si {
108 &mut self.sink
109 }
110
111 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut Si> {
116 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) }
117 }
118
119 pub fn into_inner(self) -> Si {
124 self.sink
125 }
126
127 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), E>> {
129 let this = unsafe { self.get_unchecked_mut() };
130
131 if let Some(fut) = this.state.as_mut() {
132 let item_res = match unsafe { Pin::new_unchecked(fut) }.poll(cx) {
133 Poll::Ready(res) => res,
134 Poll::Pending => return Poll::Pending,
135 };
136 this.state = None;
137 let item = match item_res {
138 Ok(item) => item,
139 Err(e) => return Poll::Ready(Err(e)),
140 };
141 let sink = unsafe { Pin::new_unchecked(&mut this.sink) };
142 if let Err(e) = sink.start_send(item) {
143 return Poll::Ready(Err(e.into()));
144 }
145 }
146
147 Poll::Ready(Ok(()))
148 }
149}
150
151impl<Si, Item, U, Fut, F, E> Sink<U> for With<Si, Item, U, Fut, F, E>
152where
153 Si: Sink<Item>,
154 F: FnMut(U) -> Fut,
155 Fut: Future<Output = Result<Item, E>>,
156 E: From<Si::Error> + core::error::Error,
157{
158 type Error = E;
159
160 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161 match self.as_mut().poll(cx) {
162 Poll::Ready(Ok(())) => {}
163 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
164 Poll::Pending => return Poll::Pending,
165 }
166 let sink = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) };
167 match sink.poll_ready(cx) {
168 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
169 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
170 Poll::Pending => Poll::Pending,
171 }
172 }
173
174 fn start_send(self: Pin<&mut Self>, item: U) -> Result<(), Self::Error> {
175 let this = unsafe { self.get_unchecked_mut() };
176
177 assert!(this.state.is_none());
178 this.state = Some((this.f)(item));
179 Ok(())
180 }
181
182 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
183 match self.as_mut().poll(cx) {
184 Poll::Ready(Ok(())) => {}
185 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
186 Poll::Pending => return Poll::Pending,
187 };
188 let sink = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) };
189 match sink.poll_flush(cx) {
190 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
191 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
192 Poll::Pending => Poll::Pending,
193 }
194 }
195
196 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
197 match self.as_mut().poll(cx) {
198 Poll::Ready(Ok(())) => {}
199 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
200 Poll::Pending => return Poll::Pending,
201 };
202 let sink = unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().sink) };
203 match sink.poll_close(cx) {
204 Poll::Ready(Ok(())) => Poll::Ready(Ok(())),
205 Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
206 Poll::Pending => Poll::Pending,
207 }
208 }
209}