1use std::task::Poll;
4
5use pin_project::pin_project;
6
7use crate::LogContext;
8
9pub trait FutureExt: Sized + private::Sealed {
15 fn in_log_context(self, context: LogContext) -> LogContextFuture<Self>;
44}
45
46impl<F> FutureExt for F
47where
48 F: Future,
49{
50 fn in_log_context(self, context: LogContext) -> LogContextFuture<Self> {
51 LogContextFuture {
52 inner: self,
53 log_context: Some(context),
54 }
55 }
56}
57
58#[pin_project]
66#[derive(Debug)]
67pub struct LogContextFuture<F> {
68 #[pin]
69 inner: F,
70 log_context: Option<LogContext>,
71}
72
73impl<F> Future for LogContextFuture<F>
74where
75 F: Future,
76{
77 type Output = F::Output;
78
79 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
80 let this = self.project();
81
82 let log_context = this
83 .log_context
84 .take()
85 .expect("An attempt to poll panicked future");
86
87 let guard = log_context.enter();
88 let result = this.inner.poll(cx);
89 this.log_context.replace(guard.exit());
90
91 result
92 }
93}
94
/// Crate-private sealing module.
///
/// `Sealed` cannot be named outside this crate, so downstream code cannot
/// implement `FutureExt` itself.
mod private {
    /// Supertrait of `FutureExt`, implemented for every future.
    pub trait Sealed {}

    impl<T> Sealed for T where T: std::future::Future {}
}
100
#[cfg(test)]
mod tests {
    use std::panic::AssertUnwindSafe;

    use futures_util::FutureExt as _;
    use pretty_assertions::assert_eq;

    use super::FutureExt;
    use crate::{ContextValue, LogContext, stack::CONTEXT_STACK};

    /// Returns the `idx`-th property of the context at the top of
    /// `CONTEXT_STACK`, rendered with `to_string`.
    ///
    /// Returns `None` when the stack has no top entry. Panics if `idx` is out
    /// of range for the top entry.
    fn get_property(idx: usize) -> Option<String> {
        CONTEXT_STACK.with(|stack| {
            let top = stack.top();
            top.map(|properties| properties[idx].1.to_string())
        })
    }

    /// Runs a future in a context recording `answer` and nests a second future
    /// whose context overrides `answer` with a null value.
    ///
    /// Checks that each level sees its own value across yield points, and that
    /// no context remains on the stack once both futures have completed.
    async fn check_nested_different_contexts(answer: u32) {
        let context = LogContext::new().record("answer", answer);

        async {
            tokio::task::yield_now().await;

            async {
                tokio::task::yield_now().await;
                assert_eq!(get_property(0), Some("None".to_string()));
            }
            .in_log_context(LogContext::new().record("answer", ContextValue::null()))
            .await;

            tokio::task::yield_now().await;
            assert_eq!(get_property(0), Some(answer.to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(get_property(0), None);
    }

    /// The context is visible inside the wrapped future after a yield and gone afterwards.
    #[tokio::test]
    async fn test_future_with_context() {
        let context = LogContext::new().record("answer", 42);

        async {
            tokio::task::yield_now().await;
            assert_eq!(get_property(0), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(get_property(0), None);
    }

    /// A panic inside the wrapped future must not leave its context on the stack.
    #[tokio::test]
    async fn test_panicked_future() {
        let context = LogContext::new().record("answer", 42);

        AssertUnwindSafe(
            async {
                tokio::task::yield_now().await;
                panic!("Goodbye cruel world");
            }
            .in_log_context(context),
        )
        .catch_unwind()
        .await
        .unwrap_err();

        assert_eq!(get_property(0), None);
    }

    /// An unwrapped inner future inherits the context of the wrapped outer one.
    #[tokio::test]
    async fn test_nested_future_with_common_context() {
        let context = LogContext::new().record("answer", 42);

        async {
            tokio::task::yield_now().await;

            async {
                tokio::task::yield_now().await;
                assert_eq!(get_property(0), Some("42".to_string()));
            }
            .await;

            assert_eq!(get_property(0), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(get_property(0), None);
    }

    /// Nested futures with distinct contexts each observe their own context.
    #[tokio::test]
    async fn test_nested_future_with_different_contexts() {
        check_nested_different_contexts(42).await;
    }

    /// Many futures interleaved by `join_all` on one task keep their contexts separate.
    #[tokio::test]
    async fn test_join_multiple_tasks_single_thread() {
        let tasks = (0..128).map(check_nested_different_contexts);
        futures_util::future::join_all(tasks).await;
    }

    /// Spawned batches of nested-context futures on a multi-threaded runtime.
    ///
    /// A bare `#[tokio::test]` builds a current-thread runtime, which would run
    /// every spawned task on the same thread. The flavor and worker count are
    /// set explicitly so that tasks can actually be polled from several threads.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_join_multiple_tasks_multi_thread() {
        let handles = (0..64).map(|i| {
            tokio::spawn(futures_util::future::join_all(
                (0..128).map(|j| check_nested_different_contexts(j * i)),
            ))
        });

        let results = futures_util::future::join_all(handles).await;
        for result in results {
            result.unwrap();
        }
    }
}