futures_util/stream/
catch_unwind.rs1use std::prelude::v1::*;
2use std::any::Any;
3use std::panic::{catch_unwind, UnwindSafe, AssertUnwindSafe};
4use std::mem;
5
6use futures_core::{Poll, Async, Stream};
7use futures_core::task;
8
9#[derive(Debug)]
13#[must_use = "streams do nothing unless polled"]
14pub struct CatchUnwind<S> where S: Stream {
15 state: CatchUnwindState<S>,
16}
17
18pub fn new<S>(stream: S) -> CatchUnwind<S>
19 where S: Stream + UnwindSafe,
20{
21 CatchUnwind {
22 state: CatchUnwindState::Stream(stream),
23 }
24}
25
/// Internal state machine for [`CatchUnwind`].
#[derive(Debug)]
enum CatchUnwindState<S> {
    /// The inner stream is held and will be polled on the next `poll_next`.
    Stream(S),
    /// The inner stream panicked and is no longer held; the next poll yields
    /// `None` and moves to `Done`.
    Eof,
    /// `None` has already been yielded after a panic; polling again panics.
    Done,
}
32
33impl<S> Stream for CatchUnwind<S>
34 where S: Stream + UnwindSafe,
35{
36 type Item = Result<S::Item, S::Error>;
37 type Error = Box<Any + Send>;
38
39 fn poll_next(&mut self, cx: &mut task::Context) -> Poll<Option<Self::Item>, Self::Error> {
40 let mut stream = match mem::replace(&mut self.state, CatchUnwindState::Eof) {
41 CatchUnwindState::Done => panic!("cannot poll after eof"),
42 CatchUnwindState::Eof => {
43 self.state = CatchUnwindState::Done;
44 return Ok(Async::Ready(None));
45 }
46 CatchUnwindState::Stream(stream) => stream,
47 };
48 let res = catch_unwind(AssertUnwindSafe(|| (stream.poll_next(cx), stream)));
49 match res {
50 Err(e) => Err(e), Ok((poll, stream)) => {
52 self.state = CatchUnwindState::Stream(stream);
53 match poll {
54 Err(e) => Ok(Async::Ready(Some(Err(e)))),
55 Ok(Async::Pending) => Ok(Async::Pending),
56 Ok(Async::Ready(Some(r))) => Ok(Async::Ready(Some(Ok(r)))),
57 Ok(Async::Ready(None)) => Ok(Async::Ready(None)),
58 }
59 }
60 }
61 }
62}