// failure_ext/context_futures.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under both the MIT license found in the
 * LICENSE-MIT file in the root directory of this source tree and the Apache
 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
 * of this source tree.
 */
9
10use std::fmt::Display;
11
12use futures::Future;
13use futures::Poll;
14
15/// "Context" support for futures.
16pub trait FutureErrorContext: Future + Sized {
17    /// Add context to the error returned by this future
18    fn context<D>(self, context: D) -> ContextFut<Self, D>
19    where
20        D: Display + Send + Sync + 'static;
21
22    /// Add context created by provided function to the error returned by this future
23    fn with_context<D, F>(self, f: F) -> WithContextFut<Self, F>
24    where
25        D: Display + Send + Sync + 'static,
26        F: FnOnce() -> D;
27}
28
29impl<F, E> FutureErrorContext for F
30where
31    F: Future<Error = E> + Sized,
32    E: Into<anyhow::Error>,
33{
34    fn context<D>(self, displayable: D) -> ContextFut<Self, D>
35    where
36        D: Display + Send + Sync + 'static,
37    {
38        ContextFut::new(self, displayable)
39    }
40
41    fn with_context<D, O>(self, f: O) -> WithContextFut<Self, O>
42    where
43        D: Display + Send + Sync + 'static,
44        O: FnOnce() -> D,
45    {
46        WithContextFut::new(self, f)
47    }
48}
49
/// Future returned by [`FutureErrorContext::context`].
///
/// Yields the inner future's item unchanged. If the inner future fails, its
/// error is converted into `anyhow::Error` and `displayable` is attached to it
/// as context.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct ContextFut<A, D> {
    inner: A,
    // `Some` until the inner future reports an error; it is then taken so the
    // context value can be moved into that error.
    displayable: Option<D>,
}
54
55impl<A, D> ContextFut<A, D> {
56    pub fn new(future: A, displayable: D) -> Self {
57        Self {
58            inner: future,
59            displayable: Some(displayable),
60        }
61    }
62}
63
64impl<A, E, D> Future for ContextFut<A, D>
65where
66    A: Future<Error = E>,
67    E: Into<anyhow::Error>,
68    D: Display + Send + Sync + 'static,
69{
70    type Item = A::Item;
71    type Error = anyhow::Error;
72
73    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
74        match self.inner.poll() {
75            Err(err) => Err(err.into().context(
76                self.displayable
77                    .take()
78                    .expect("poll called after future completion"),
79            )),
80            Ok(item) => Ok(item),
81        }
82    }
83}
84
/// Future returned by [`FutureErrorContext::with_context`].
///
/// Yields the inner future's item unchanged. If the inner future fails, the
/// stored closure is called and the value it returns is attached as context to
/// the error (converted into `anyhow::Error`).
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct WithContextFut<A, F> {
    inner: A,
    // Despite the field name, this holds the context-producing closure rather
    // than the context itself. It stays `Some` until the inner future reports an
    // error, at which point it is taken and called once.
    displayable: Option<F>,
}
89
90impl<A, F> WithContextFut<A, F> {
91    pub fn new(future: A, displayable: F) -> Self {
92        Self {
93            inner: future,
94            displayable: Some(displayable),
95        }
96    }
97}
98
99impl<A, E, F, D> Future for WithContextFut<A, F>
100where
101    A: Future<Error = E>,
102    E: Into<anyhow::Error>,
103    D: Display + Send + Sync + 'static,
104    F: FnOnce() -> D,
105{
106    type Item = A::Item;
107    type Error = anyhow::Error;
108
109    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
110        match self.inner.poll() {
111            Err(err) => {
112                let f = self
113                    .displayable
114                    .take()
115                    .expect("poll called after future completion");
116
117                let context = f();
118                Err(err.into().context(context))
119            }
120            Ok(item) => Ok(item),
121        }
122    }
123}
124
#[cfg(test)]
mod test {
    use anyhow::format_err;
    use futures::future::err;
    use futures::future::ok;
    use futures::future::poll_fn;
    use futures::Async;

    use super::*;

    /// Build a future that fails with `foo` (wrapped in `bar` context when
    /// `with_bar` is set) on *every* poll.
    ///
    /// The poll-after-completion tests need the inner future to keep failing,
    /// so that the second poll reaches the wrapper's own "poll called after
    /// future completion" check. `future::err` is one-shot: futures 0.1's
    /// `FutureResult` panics with its own message when polled twice. A bare
    /// `#[should_panic]` around it would therefore pass without ever exercising
    /// the wrapper.
    fn always_fails(with_bar: bool) -> impl Future<Item = (), Error = anyhow::Error> {
        poll_fn(move || -> Poll<(), anyhow::Error> {
            let error = format_err!("foo");
            Err(if with_bar { error.context("bar") } else { error })
        })
    }

    /// Render the error chain, outermost context first.
    fn chain_of(error: &anyhow::Error) -> Vec<String> {
        error.chain().map(|cause| cause.to_string()).collect()
    }

    #[test]
    #[should_panic(expected = "poll called after future completion")]
    fn poll_after_completion_fail() {
        let mut fut = always_fails(true).context("baz");
        assert!(fut.poll().is_err());
        let _ = fut.poll();
    }

    #[test]
    #[should_panic(expected = "poll called after future completion")]
    fn poll_after_completion_fail_with_context() {
        let mut fut = always_fails(true).with_context(|| "baz");
        assert!(fut.poll().is_err());
        let _ = fut.poll();
    }

    #[test]
    #[should_panic(expected = "poll called after future completion")]
    fn poll_after_completion_error() {
        let mut fut = always_fails(false).context("baz");
        assert!(fut.poll().is_err());
        let _ = fut.poll();
    }

    #[test]
    #[should_panic(expected = "poll called after future completion")]
    fn poll_after_completion_error_with_context() {
        let mut fut = always_fails(false).with_context(|| "baz");
        assert!(fut.poll().is_err());
        let _ = fut.poll();
    }

    #[test]
    fn context_is_attached() {
        let mut fut = err::<(), _>(format_err!("foo").context("bar")).context("baz");
        let error = fut.poll().expect_err("inner future fails");
        assert_eq!(chain_of(&error), ["baz", "bar", "foo"]);
    }

    #[test]
    fn with_context_is_attached() {
        let mut fut = err::<(), _>(format_err!("foo")).with_context(|| "baz");
        let error = fut.poll().expect_err("inner future fails");
        assert_eq!(chain_of(&error), ["baz", "foo"]);
    }

    #[test]
    fn with_context_not_called_on_success() {
        let mut fut = ok::<u32, anyhow::Error>(1)
            .with_context(|| -> &'static str { panic!("context closure ran on success") });
        assert_eq!(fut.poll().unwrap(), Async::Ready(1));
    }
}