context_logger/
future.rs

1//! Future types.
2
3use std::task::Poll;
4
5use pin_project::pin_project;
6
7use crate::LogContext;
8
9/// Extension trait for futures to propagate contextual logging information.
10///
11/// This trait adds ability to attach a [`LogContext`] for any [`Future`],
12/// ensuring that logs emitted during the future's execution will include
13/// the contextual properties even if the future is polled across different threads.
14pub trait FutureExt: Sized + private::Sealed {
15    /// Attaches a log context to this future.
16    ///
17    /// The attached [context](LogContext) will be activated every time the instrumented
18    /// future is polled.
19    ///
20    /// # Examples
21    ///
22    /// ```
23    /// use context_logger::{LogContext, FutureExt};
24    /// use log::info;
25    ///
26    /// async fn process_user_data(user_id: u64) {
27    ///     // Create a context with user information
28    ///     let context = LogContext::new()
29    ///         .record("user_id", user_id)
30    ///         .record("operation", "process_data");
31    ///
32    ///     async {
33    ///         info!("Starting user data processing"); // Will include context
34    ///
35    ///         // Do some async work...
36    ///
37    ///         info!("User data processing complete"); // Still includes context
38    ///     }
39    ///     .in_log_context(context)
40    ///     .await;
41    /// }
42    /// ```
43    fn in_log_context(self, context: LogContext) -> LogContextFuture<Self>;
44}
45
46impl<F> FutureExt for F
47where
48    F: Future,
49{
50    fn in_log_context(self, context: LogContext) -> LogContextFuture<Self> {
51        LogContextFuture {
52            inner: self,
53            log_context: Some(context),
54        }
55    }
56}
57
58/// A future with an attached logging context.
59///
60/// This type is created by the [`FutureExt::in_log_context`].
61///
62/// # Note
63///
64/// If the wrapped future will panic, the next `poll` invocation will panic unconditionally.
65#[pin_project]
66#[derive(Debug)]
67pub struct LogContextFuture<F> {
68    #[pin]
69    inner: F,
70    log_context: Option<LogContext>,
71}
72
73impl<F> Future for LogContextFuture<F>
74where
75    F: Future,
76{
77    type Output = F::Output;
78
79    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
80        let this = self.project();
81
82        let log_context = this
83            .log_context
84            .take()
85            .expect("An attempt to poll panicked future");
86
87        let guard = log_context.enter();
88        let result = this.inner.poll(cx);
89        this.log_context.replace(guard.exit());
90
91        result
92    }
93}
94
95mod private {
96    pub trait Sealed {}
97
98    impl<F: Future> Sealed for F {}
99}
100
#[cfg(test)]
mod tests {
    use std::panic::AssertUnwindSafe;

    use futures_util::FutureExt as _;
    use pretty_assertions::assert_eq;

    use super::FutureExt;
    use crate::{ContextValue, LogContext, stack::CONTEXT_STACK};

    /// Returns the string form of property `idx` from the top of the context stack,
    /// or `None` when the stack has no top entry.
    fn get_property(idx: usize) -> Option<String> {
        CONTEXT_STACK.with(|stack| {
            let top = stack.top();
            top.map(|properties| properties[idx].1.to_string())
        })
    }

    /// Checks nested futures with different contexts:
    /// - the outer future (context `answer`) awaits an inner future (context `answer = null`);
    /// - each level yields before asserting the value it expects to see;
    /// - afterwards, no context must remain on the stack.
    async fn check_nested_different_contexts(answer: u32) {
        let context = LogContext::new().record("answer", answer);

        async {
            tokio::task::yield_now().await;

            async {
                tokio::task::yield_now().await;
                assert_eq!(get_property(0), Some("None".to_string()));
            }
            .in_log_context(LogContext::new().record("answer", ContextValue::null()))
            .await;

            tokio::task::yield_now().await;
            assert_eq!(get_property(0), Some(answer.to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(get_property(0), None);
    }

    #[tokio::test]
    async fn test_future_with_context() {
        let context = LogContext::new().record("answer", 42);

        async {
            tokio::task::yield_now().await;
            assert_eq!(get_property(0), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(get_property(0), None);
    }

    #[tokio::test]
    async fn test_panicked_future() {
        let context = LogContext::new().record("answer", 42);

        AssertUnwindSafe(
            async {
                tokio::task::yield_now().await;
                panic!("Goodbye cruel world");
            }
            .in_log_context(context),
        )
        .catch_unwind()
        .await
        .unwrap_err();

        // The context must not leak onto the stack after the inner future panicked.
        assert_eq!(get_property(0), None);
    }

    #[tokio::test]
    async fn test_nested_future_with_common_context() {
        let context = LogContext::new().record("answer", 42);

        async {
            tokio::task::yield_now().await;

            async {
                tokio::task::yield_now().await;
                assert_eq!(get_property(0), Some("42".to_string()));
            }
            .await;

            assert_eq!(get_property(0), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(get_property(0), None);
    }

    #[tokio::test]
    async fn test_nested_future_with_different_contexts() {
        check_nested_different_contexts(42).await;
    }

    #[tokio::test]
    async fn test_join_multiple_tasks_single_thread() {
        let tasks = (0..128).map(check_nested_different_contexts);
        futures_util::future::join_all(tasks).await;
    }

    // Plain `#[tokio::test]` builds a current-thread runtime, so spawned tasks would
    // never be polled on another thread. Request the multi-thread flavor (this needs
    // tokio's `rt-multi-thread` feature) so the contexts are exercised across worker
    // threads, as the test name promises.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_join_multiple_tasks_multi_thread() {
        let handles = (0..64).map(|i| {
            tokio::spawn(futures_util::future::join_all(
                (0..128).map(|j| check_nested_different_contexts(j * i)),
            ))
        });

        let results = futures_util::future::join_all(handles).await;
        for result in results {
            result.unwrap();
        }
    }
}