Skip to main content

context_logger/
future.rs

1//! Future types.
2
3use std::task::Poll;
4
5use pin_project::pin_project;
6
7use crate::{LogContext, scope::LogScope};
8
9/// Extension trait for futures to propagate contextual logging information.
10///
11/// This trait adds ability to attach a [`LogContext`] for any [`Future`],
12/// ensuring that logs emitted during the future's execution will include
13/// the contextual properties even if the future is polled across different threads.
14pub trait FutureExt: Sized + crate::private::Sealed {
15    /// Attaches a log context to this future.
16    ///
17    /// The attached [context](LogContext) will be activated every time the instrumented
18    /// future is polled.
19    ///
20    /// # Examples
21    ///
22    /// ```
23    /// use context_logger::{LogContext, FutureExt};
24    /// use log::info;
25    ///
26    /// async fn process_user_data(user_id: u64) {
27    ///     // Create a context with user information
28    ///     let context = LogContext::new()
29    ///         .with_local_record("user_id", user_id)
30    ///         .with_local_record("operation", "process_data");
31    ///
32    ///     async {
33    ///         info!("Starting user data processing"); // Will include context
34    ///
35    ///         // Do some async work...
36    ///
37    ///         info!("User data processing complete"); // Still includes context
38    ///     }
39    ///     .in_log_context(context)
40    ///     .await;
41    /// }
42    /// ```
43    fn in_log_context(self, context: LogContext) -> LogContextFuture<Self>;
44}
45
46impl<F> FutureExt for F
47where
48    F: Future,
49{
50    fn in_log_context(self, context: LogContext) -> LogContextFuture<Self> {
51        LogContextFuture {
52            inner: self,
53            log_context: Some(context),
54        }
55    }
56}
57
58/// A future with an attached logging context.
59///
60/// This type is created by the [`FutureExt::in_log_context`].
61///
62/// # Note
63///
64/// If the wrapped future will panic, the next `poll` invocation will panic unconditionally.
65#[pin_project]
66#[derive(Debug)]
67pub struct LogContextFuture<F> {
68    #[pin]
69    inner: F,
70    log_context: Option<LogContext>,
71}
72
73impl<F> Future for LogContextFuture<F>
74where
75    F: Future,
76{
77    type Output = F::Output;
78
79    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
80        let this = self.project();
81
82        let log_context = this
83            .log_context
84            .take()
85            .expect("An attempt to poll panicked future");
86
87        let guard = LogScope::enter(log_context);
88        let result = this.inner.poll(cx);
89        this.log_context.replace(guard.exit());
90
91        result
92    }
93}
94
#[cfg(test)]
mod tests {
    use std::panic::AssertUnwindSafe;

    use futures_util::FutureExt as _;
    use pretty_assertions::assert_eq;

    use super::FutureExt;
    use crate::{LogContext, LogValue, scope::stack::SCOPE_STACK};

    /// Looks up `key` among the local records of the top frame of `SCOPE_STACK`,
    /// returning `None` when there is no frame or no such record.
    fn find_local_value(key: &str) -> Option<String> {
        SCOPE_STACK.with(|stack| {
            let frame = stack.top()?;
            frame.0.local.find(key).map(ToString::to_string)
        })
    }

    /// Checks that an inner context shadows the outer one and that the outer one is
    /// visible again (and then gone) once each instrumented future completes.
    async fn check_nested_different_contexts(answer: u32) {
        let context = LogContext::new().with_local_record("answer", answer);

        async {
            tokio::task::yield_now().await;

            async {
                tokio::task::yield_now().await;
                assert_eq!(find_local_value("answer"), Some("None".to_string()));
            }
            .in_log_context(LogContext::new().with_local_record("answer", LogValue::null()))
            .await;

            tokio::task::yield_now().await;
            assert_eq!(find_local_value("answer"), Some(answer.to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_future_with_context() {
        let context = LogContext::new().with_local_record("answer", 42);

        async {
            tokio::task::yield_now().await;
            assert_eq!(find_local_value("answer"), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_panicked_future() {
        let context = LogContext::new().with_local_record("answer", 42);

        AssertUnwindSafe(
            async {
                tokio::task::yield_now().await;
                panic!("Goodbye cruel world");
            }
            .in_log_context(context),
        )
        .catch_unwind()
        .await
        .unwrap_err();

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_poll_after_panic_panics() {
        let future = async {
            tokio::task::yield_now().await;
            panic!("Goodbye cruel world");
        }
        .in_log_context(LogContext::new().with_local_record("answer", 42));
        let mut future = std::pin::pin!(future);

        // The inner future panics, so the wrapper never gets its context back.
        AssertUnwindSafe(future.as_mut())
            .catch_unwind()
            .await
            .unwrap_err();

        // Any further poll must fail fast with the wrapper's own message.
        let payload = AssertUnwindSafe(future.as_mut())
            .catch_unwind()
            .await
            .unwrap_err();
        let message = payload
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| payload.downcast_ref::<&str>().copied());
        assert_eq!(message, Some("An attempt to poll panicked future"));

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_nested_future_with_common_context() {
        let context = LogContext::new().with_local_record("answer", 42);

        async {
            tokio::task::yield_now().await;

            async {
                tokio::task::yield_now().await;
                assert_eq!(find_local_value("answer"), Some("42".to_string()));
            }
            .await;

            assert_eq!(find_local_value("answer"), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_nested_future_with_different_contexts() {
        check_nested_different_contexts(42).await;
    }

    #[tokio::test]
    async fn test_join_multiple_tasks_single_thread() {
        let tasks = (0..128).map(check_nested_different_contexts);
        futures_util::future::join_all(tasks).await;
    }

    // A plain `#[tokio::test]` runs on the current-thread runtime; request a
    // multi-threaded runtime so spawned tasks really can be polled on different threads.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_join_multiple_tasks_multi_thread() {
        let handles = (0..64).map(|i| {
            tokio::spawn(futures_util::future::join_all(
                (0..128).map(|j| check_nested_different_contexts(j * i)),
            ))
        });

        let results = futures_util::future::join_all(handles).await;
        for result in results {
            result.unwrap();
        }
    }
}