1use crate::common::*;
2use tokio::sync::oneshot;
3
4pub trait TryStreamExt
6where
7 Self: TryStream,
8{
9 fn try_enumerate(self) -> TryEnumerate<Self, Self::Ok, Self::Error>;
18
19 fn take_until_error(self) -> TakeUntilError<Self, Self::Ok, Self::Error>;
21
22 fn catch_error(self) -> (ErrorNotify<Self::Error>, CatchError<Self>);
28
29 fn try_stateful_then<B, U, F, Fut>(
31 self,
32 init: B,
33 f: F,
34 ) -> TryStatefulThen<Self, B, Self::Ok, U, Self::Error, F, Fut>
35 where
36 F: FnMut(B, Self::Ok) -> Fut,
37 Fut: Future<Output = Result<Option<(B, U)>, Self::Error>>;
38
39 fn try_stateful_map<B, U, F>(
41 self,
42 init: B,
43 f: F,
44 ) -> TryStatefulMap<Self, B, Self::Ok, U, Self::Error, F>
45 where
46 F: FnMut(B, Self::Ok) -> Result<Option<(B, U)>, Self::Error>;
47}
48
49impl<S, T, E> TryStreamExt for S
50where
51 S: Stream<Item = Result<T, E>>,
52{
53 fn try_enumerate(self) -> TryEnumerate<Self, T, E> {
54 TryEnumerate {
55 counter: 0,
56 fused: false,
57 _phantom: PhantomData,
58 stream: self,
59 }
60 }
61
62 fn take_until_error(self) -> TakeUntilError<Self, T, E> {
63 TakeUntilError {
64 _phantom: PhantomData,
65 is_terminated: false,
66 stream: self,
67 }
68 }
69
70 fn try_stateful_then<B, U, F, Fut>(
71 self,
72 init: B,
73 f: F,
74 ) -> TryStatefulThen<Self, B, T, U, E, F, Fut>
75 where
76 F: FnMut(B, T) -> Fut,
77 Fut: Future<Output = Result<Option<(B, U)>, E>>,
78 {
79 TryStatefulThen {
80 stream: self,
81 future: None,
82 state: Some(init),
83 f,
84 _phantom: PhantomData,
85 }
86 }
87
88 fn try_stateful_map<B, U, F>(self, init: B, f: F) -> TryStatefulMap<Self, B, T, U, E, F>
89 where
90 F: FnMut(B, T) -> Result<Option<(B, U)>, E>,
91 {
92 TryStatefulMap {
93 stream: self,
94 state: Some(init),
95 f,
96 _phantom: PhantomData,
97 }
98 }
99
100 fn catch_error(self) -> (ErrorNotify<E>, CatchError<S>) {
101 let (tx, rx) = oneshot::channel();
102 let stream = CatchError {
103 sender: Some(tx),
104 stream: self,
105 };
106 let notify = ErrorNotify { receiver: rx };
107
108 (notify, stream)
109 }
110}
111
112pub use take_until_error::*;
113mod take_until_error {
114 use super::*;
115
116 #[pin_project]
118 pub struct TakeUntilError<St, T, E>
119 where
120 St: ?Sized,
121 {
122 pub(super) _phantom: PhantomData<(T, E)>,
123 pub(super) is_terminated: bool,
124 #[pin]
125 pub(super) stream: St,
126 }
127
128 impl<St, T, E> Stream for TakeUntilError<St, T, E>
129 where
130 St: Stream<Item = Result<T, E>>,
131 {
132 type Item = Result<T, E>;
133
134 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
135 let this = self.project();
136
137 Ready({
138 if *this.is_terminated {
139 None
140 } else if let Some(result) = ready!(this.stream.poll_next(cx)) {
141 if result.is_err() {
142 *this.is_terminated = true;
143 }
144 Some(result)
145 } else {
146 *this.is_terminated = true;
147 None
148 }
149 })
150 }
151 }
152}
153
154pub use try_stateful_then::*;
155mod try_stateful_then {
156 use super::*;
157
158 #[pin_project]
160 pub struct TryStatefulThen<St, B, T, U, E, F, Fut>
161 where
162 St: ?Sized,
163 {
164 #[pin]
165 pub(super) future: Option<Fut>,
166 pub(super) state: Option<B>,
167 pub(super) f: F,
168 pub(super) _phantom: PhantomData<(T, U, E)>,
169 #[pin]
170 pub(super) stream: St,
171 }
172
173 impl<St, B, T, U, E, F, Fut> Stream for TryStatefulThen<St, B, T, U, E, F, Fut>
174 where
175 St: Stream<Item = Result<T, E>>,
176 F: FnMut(B, T) -> Fut,
177 Fut: Future<Output = Result<Option<(B, U)>, E>>,
178 {
179 type Item = Result<U, E>;
180
181 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182 let mut this = self.project();
183
184 Poll::Ready(loop {
185 if let Some(fut) = this.future.as_mut().as_pin_mut() {
186 let result = ready!(fut.poll(cx));
187 this.future.set(None);
188
189 match result {
190 Ok(Some((state, item))) => {
191 *this.state = Some(state);
192 break Some(Ok(item));
193 }
194 Ok(None) => {
195 break None;
196 }
197 Err(err) => break Some(Err(err)),
198 }
199 } else if let Some(state) = this.state.take() {
200 match this.stream.as_mut().poll_next(cx) {
201 Ready(Some(Ok(item))) => {
202 this.future.set(Some((this.f)(state, item)));
203 }
204 Ready(Some(Err(err))) => break Some(Err(err)),
205 Ready(None) => break None,
206 Pending => {
207 *this.state = Some(state);
208 return Pending;
209 }
210 }
211 } else {
212 break None;
213 }
214 })
215 }
216 }
217}
218
219pub use try_stateful_map::*;
220mod try_stateful_map {
221 use super::*;
222
223 #[pin_project]
225 pub struct TryStatefulMap<St, B, T, U, E, F>
226 where
227 St: ?Sized,
228 {
229 pub(super) state: Option<B>,
230 pub(super) f: F,
231 pub(super) _phantom: PhantomData<(T, U, E)>,
232 #[pin]
233 pub(super) stream: St,
234 }
235
236 impl<St, B, T, U, E, F> Stream for TryStatefulMap<St, B, T, U, E, F>
237 where
238 St: Stream<Item = Result<T, E>>,
239 F: FnMut(B, T) -> Result<Option<(B, U)>, E>,
240 {
241 type Item = Result<U, E>;
242
243 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
244 let mut this = self.project();
245
246 Poll::Ready({
247 if let Some(state) = this.state.take() {
248 match this.stream.as_mut().poll_next(cx) {
249 Ready(Some(Ok(in_item))) => {
250 let result = (this.f)(state, in_item);
251
252 match result {
253 Ok(Some((state, out_item))) => {
254 *this.state = Some(state);
255 Some(Ok(out_item))
256 }
257 Ok(None) => None,
258 Err(err) => Some(Err(err)),
259 }
260 }
261 Ready(Some(Err(err))) => Some(Err(err)),
262 Ready(None) => None,
263 Pending => {
264 *this.state = Some(state);
265 return Pending;
266 }
267 }
268 } else {
269 None
270 }
271 })
272 }
273 }
274}
275
276pub use try_enumerate::*;
277mod try_enumerate {
278 use super::*;
279
280 #[derive(Derivative)]
282 #[derivative(Debug)]
283 #[pin_project]
284 pub struct TryEnumerate<S, T, E>
285 where
286 S: ?Sized,
287 {
288 pub(super) counter: usize,
289 pub(super) fused: bool,
290 pub(super) _phantom: PhantomData<(T, E)>,
291 #[pin]
292 #[derivative(Debug = "ignore")]
293 pub(super) stream: S,
294 }
295
296 impl<S, T, E> Stream for TryEnumerate<S, T, E>
297 where
298 S: Stream<Item = Result<T, E>>,
299 {
300 type Item = Result<(usize, T), E>;
301
302 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
303 let mut this = self.project();
304
305 Ready({
306 if *this.fused {
307 None
308 } else {
309 match ready!(Pin::new(&mut this.stream).poll_next(cx)) {
310 Some(Ok(item)) => {
311 let index = *this.counter;
312 *this.counter += 1;
313 Some(Ok((index, item)))
314 }
315 Some(Err(err)) => {
316 *this.fused = true;
317 Some(Err(err))
318 }
319 None => None,
320 }
321 }
322 })
323 }
324 }
325
326 impl<S, T, E> FusedStream for TryEnumerate<S, T, E>
327 where
328 S: Stream<Item = Result<T, E>>,
329 {
330 fn is_terminated(&self) -> bool {
331 self.fused
332 }
333 }
334}
335
336pub use catch_error::*;
337mod catch_error {
338 use super::*;
339
340 #[pin_project]
342 pub struct CatchError<St>
343 where
344 St: ?Sized + TryStream,
345 {
346 pub(super) sender: Option<oneshot::Sender<St::Error>>,
347 #[pin]
348 pub(super) stream: St,
349 }
350
351 impl<St> Stream for CatchError<St>
352 where
353 St: TryStream,
354 {
355 type Item = St::Ok;
356
357 fn poll_next(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Option<Self::Item>> {
358 let this = self.project();
359
360 Ready({
361 if let Some(sender) = this.sender.take() {
362 match this.stream.try_poll_next(ctx) {
363 Ready(Some(Ok(item))) => {
364 *this.sender = Some(sender);
365 Some(item)
366 }
367 Ready(Some(Err(err))) => {
368 let _ = sender.send(err);
369 None
370 }
371 Ready(None) => {
372 drop(sender);
373 None
374 }
375 Pending => {
376 *this.sender = Some(sender);
377 return Pending;
378 }
379 }
380 } else {
381 None
382 }
383 })
384 }
385 }
386
387 #[pin_project]
389 pub struct ErrorNotify<E> {
390 #[pin]
391 pub(super) receiver: oneshot::Receiver<E>,
392 }
393
394 impl<E> ErrorNotify<E> {
395 pub fn try_catch(mut self) -> ControlFlow<Result<(), E>, Self> {
396 use oneshot::error::TryRecvError::*;
397
398 match self.receiver.try_recv() {
399 Ok(err) => Break(Err(err)),
400 Err(Empty) => Continue(self),
401 Err(Closed) => Break(Ok(())),
402 }
403 }
404 }
405
406 impl<E> Future for ErrorNotify<E> {
407 type Output = Result<(), E>;
408
409 fn poll(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Self::Output> {
410 let this = self.project();
411
412 Ready(match ready!(this.receiver.poll(ctx)) {
413 Ok(err) => Err(err),
414 Err(_) => Ok(()),
415 })
416 }
417 }
418}
419
// Unit tests for the combinators defined in this file. Test functions are
// declared inside the project-local `async_test!` macro, so only plain `//`
// comments are used here.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::async_test;

    async_test! {
        // `take_until_error` forwards items up to and including the first `Err`.
        async fn take_until_error_test() {
            // Empty input yields nothing.
            {
                let vec: Vec<Result<(), ()>> = stream::empty().take_until_error().collect().await;
                assert_eq!(vec, []);
            }

            // Without errors every item is forwarded.
            {
                let vec: Vec<Result<_, ()>> = stream::iter([Ok(0), Ok(1), Ok(2), Ok(3)])
                    .take_until_error()
                    .collect()
                    .await;
                assert_eq!(vec, [Ok(0), Ok(1), Ok(2), Ok(3)]);
            }

            // The error itself is forwarded; items after it (`Ok(3)`) are not.
            {
                let vec: Vec<Result<_, _>> = stream::iter([Ok(0), Ok(1), Err(2), Ok(3)])
                    .take_until_error()
                    .collect()
                    .await;
                assert_eq!(vec, [Ok(0), Ok(1), Err(2),]);
            }
        }


        // `try_stateful_then` threads a running sum through an async closure.
        async fn try_stateful_then_test() {
            // All-`Ok` input: emits the running sums.
            {
                let values: Result<Vec<_>, ()> = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1)])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        Ok(Some((new_acc, new_acc)))
                    })
                    .try_collect()
                    .await;

                assert_eq!(values, Ok(vec![3, 4, 8, 9]));
            }

            // An `Err` from the input stream is forwarded and ends the stream.
            {
                let mut stream = stream::iter([Ok(3), Ok(1), Err(()), Ok(1)])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        Ok(Some((new_acc, new_acc)))
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, Some(Err(())));
                assert_eq!(stream.next().await, None);
            }

            // An `Err` returned by the closure (sum reaches 8) is emitted and
            // ends the stream before the input's own error is reached.
            {
                let mut stream = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1), Err(())])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        if new_acc != 8 {
                            Ok(Some((new_acc, new_acc)))
                        } else {
                            Err(())
                        }
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, Some(Err(())));
                assert_eq!(stream.next().await, None);
            }

            // `Ok(None)` from the closure ends the stream without emitting anything.
            {
                let mut stream = stream::iter([Ok(3), Ok(1), Ok(4), Ok(1), Err(())])
                    .try_stateful_then(0, |acc, val| async move {
                        let new_acc = acc + val;
                        if new_acc != 8 {
                            Ok(Some((new_acc, new_acc)))
                        } else {
                            Ok(None)
                        }
                    })
                    .boxed();

                assert_eq!(stream.next().await, Some(Ok(3)));
                assert_eq!(stream.next().await, Some(Ok(4)));
                assert_eq!(stream.next().await, None);
            }
        }


        // `catch_error` splits a fallible stream into `Ok` values plus a notifier.
        async fn catch_error_test() {
            // Empty stream: nothing yielded, notifier reports no error.
            {
                let (notify, stream) = stream::empty::<Result<(), ()>>().catch_error();

                let vec: Vec<_> = stream.collect().await;
                let result = notify.await;

                assert_eq!(vec, []);
                assert_eq!(result, Ok(()));
            }

            // All-`Ok` stream: values are unwrapped, notifier reports no error.
            {
                let (notify, stream) =
                    stream::iter([Result::<_, ()>::Ok(0), Ok(1), Ok(2), Ok(3)]).catch_error();

                let vec: Vec<_> = stream.collect().await;
                let result = notify.await;

                assert_eq!(vec, [0, 1, 2, 3]);
                assert_eq!(result, Ok(()));
            }

            // The stream stops at the first error, which the notifier receives.
            {
                let (notify, stream) = stream::iter([Ok(0), Ok(1), Err(2), Ok(3)]).catch_error();

                let vec: Vec<_> = stream.collect().await;
                let result = notify.await;

                assert_eq!(vec, [0, 1]);
                assert_eq!(result, Err(2));
            }

            // `try_catch` is `Continue` until the stream finishes, then `Break(Ok(()))`.
            {
                let (notify, mut stream) = stream::empty::<Result<(), ()>>().catch_error();

                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, None);
                assert!(matches!(notify.try_catch(), Break(Ok(()))));
            }

            // `try_catch` stays `Continue` while items are still being yielded.
            {
                let (notify, mut stream) = stream::iter([Result::<_, ()>::Ok(0)]).catch_error();

                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, Some(0));
                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, None);
                assert!(matches!(notify.try_catch(), Break(Ok(()))));
            }

            // Once the stream hits an error, `try_catch` yields `Break(Err(..))`.
            {
                let (notify, mut stream) = stream::iter([Ok(0), Err(2)]).catch_error();

                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, Some(0));
                let notify = match notify.try_catch() {
                    Continue(notify) => notify,
                    _ => unreachable!(),
                };

                assert_eq!(stream.next().await, None);
                assert!(matches!(notify.try_catch(), Break(Err(2))));
            }
        }
    }
}