1use std::task::Poll;
4
5use pin_project::pin_project;
6
7use crate::{LogContext, scope::LogScope};
8
9pub trait FutureExt: Sized + crate::private::Sealed {
15 fn in_log_context(self, context: LogContext) -> LogContextFuture<Self>;
44}
45
46impl<F> FutureExt for F
47where
48 F: Future,
49{
50 fn in_log_context(self, context: LogContext) -> LogContextFuture<Self> {
51 LogContextFuture {
52 inner: self,
53 log_context: Some(context),
54 }
55 }
56}
57
58#[pin_project]
66#[derive(Debug)]
67pub struct LogContextFuture<F> {
68 #[pin]
69 inner: F,
70 log_context: Option<LogContext>,
71}
72
73impl<F> Future for LogContextFuture<F>
74where
75 F: Future,
76{
77 type Output = F::Output;
78
79 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
80 let this = self.project();
81
82 let log_context = this
83 .log_context
84 .take()
85 .expect("An attempt to poll panicked future");
86
87 let guard = LogScope::enter(log_context);
88 let result = this.inner.poll(cx);
89 this.log_context.replace(guard.exit());
90
91 result
92 }
93}
94
#[cfg(test)]
mod tests {
    use std::panic::AssertUnwindSafe;

    use futures_util::FutureExt as _;
    use pretty_assertions::assert_eq;

    use super::FutureExt;
    use crate::{LogContext, LogValue, scope::stack::SCOPE_STACK};

    /// Returns the local record `key` from the top frame of `SCOPE_STACK`,
    /// rendered with `ToString`. Returns `None` when there is no frame or no
    /// such record.
    fn find_local_value(key: &str) -> Option<String> {
        SCOPE_STACK.with(|stack| {
            let frame = stack.top()?;
            frame.0.local.find(key).map(ToString::to_string)
        })
    }

    /// Runs a future under a context whose "answer" is `answer` and nests a
    /// second future under a context whose "answer" is `LogValue::null()`,
    /// yielding between checks. Asserts that each level sees its own value and
    /// that no value is visible once both have completed.
    async fn check_nested_different_contexts(answer: u32) {
        let context = LogContext::new().with_local_record("answer", answer);

        async {
            tokio::task::yield_now().await;

            async {
                tokio::task::yield_now().await;
                assert_eq!(find_local_value("answer"), Some("None".to_string()));
            }
            .in_log_context(LogContext::new().with_local_record("answer", LogValue::null()))
            .await;

            tokio::task::yield_now().await;
            assert_eq!(find_local_value("answer"), Some(answer.to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_future_with_context() {
        let context = LogContext::new().with_local_record("answer", 42);

        async {
            tokio::task::yield_now().await;
            assert_eq!(find_local_value("answer"), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_panicked_future() {
        let context = LogContext::new().with_local_record("answer", 42);

        AssertUnwindSafe(
            async {
                tokio::task::yield_now().await;
                panic!("Goodbye cruel world");
            }
            .in_log_context(context),
        )
        .catch_unwind()
        .await
        .unwrap_err();

        // The scope entered for the panicking poll must not outlive the unwind.
        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_nested_future_with_common_context() {
        let context = LogContext::new().with_local_record("answer", 42);

        async {
            tokio::task::yield_now().await;

            // An unwrapped inner future inherits the enclosing context.
            async {
                tokio::task::yield_now().await;
                assert_eq!(find_local_value("answer"), Some("42".to_string()));
            }
            .await;

            assert_eq!(find_local_value("answer"), Some("42".to_string()));
        }
        .in_log_context(context)
        .await;

        assert_eq!(find_local_value("answer"), None);
    }

    #[tokio::test]
    async fn test_nested_future_with_different_contexts() {
        check_nested_different_contexts(42).await;
    }

    #[tokio::test]
    async fn test_join_multiple_tasks_single_thread() {
        // All futures are interleaved inside a single task, so each wrapper
        // must swap its context in and out correctly on every poll.
        let tasks = (0..128).map(check_nested_different_contexts);
        futures_util::future::join_all(tasks).await;
    }

    // A bare `#[tokio::test]` builds a current-thread runtime, on which every
    // spawned task is polled by the test thread. Request a multi-thread runtime
    // with several workers so the spawned tasks can actually be polled from
    // different threads, which is what this test is meant to exercise.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_join_multiple_tasks_multi_thread() {
        let handles = (0..64).map(|i| {
            tokio::spawn(futures_util::future::join_all(
                (0..128).map(|j| check_nested_different_contexts(j * i)),
            ))
        });

        let results = futures_util::future::join_all(handles).await;
        for result in results {
            // Surface any assertion failure that happened inside a spawned task.
            result.unwrap();
        }
    }
}