1#![feature(box_into_inner)]
4
5use core::future::Future;
6use std::{any::Any, cell::RefCell, pin::Pin, sync::Mutex, task::Poll};
7
8use pin_project::pin_project;
9
thread_local! {
    /// Per-thread slot through which an `AsyncContext` exposes its context to
    /// the future it is polling.
    ///
    /// The value is type-erased as `dyn Any` so contexts of any `'static` type
    /// can share this single slot. It starts out holding `()`.
    static CTX: RefCell<Box<dyn Any>> = RefCell::new(Box::new(()) as Box<dyn Any>);
}
13
14#[pin_project]
18pub struct AsyncContext<C, T, F>
19where
20 C: 'static,
21 F: Future<Output = T>,
22{
23 ctx: Mutex<Option<C>>,
24 #[pin]
25 future: F,
26}
27
28pub fn provide_async_context<C, T, F>(ctx: C, future: F) -> AsyncContext<C, T, F>
31where
32 C: 'static,
33 F: Future<Output = T>,
34{
35 AsyncContext {
36 ctx: Mutex::new(Some(ctx)),
37 future,
38 }
39}
40
41impl<C, T, F> Future for AsyncContext<C, T, F>
42where
43 F: Future<Output = T>,
44{
45 type Output = (T, C);
46
47 fn poll(
48 self: std::pin::Pin<&mut Self>,
49 cx: &mut std::task::Context<'_>,
50 ) -> core::task::Poll<Self::Output> {
51 let ctx: C = self
52 .ctx
53 .lock()
54 .expect("Failed to lock context mutex")
55 .take()
56 .expect("No context found");
57 CTX.set(Box::new(ctx));
58 let projection = self.project();
59 let future: Pin<&mut F> = projection.future;
60 let poll = future.poll(cx);
61 let ctx: C = Box::into_inner(CTX.replace(Box::new(())).downcast().unwrap());
62 match poll {
63 Poll::Ready(value) => return Poll::Ready((value, ctx)),
64 Poll::Pending => {
65 projection
66 .ctx
67 .lock()
68 .expect("Feiled to lock context mutex")
69 .replace(ctx);
70 Poll::Pending
71 }
72 }
73 }
74}
75
76pub fn with_async_context<C, F, R>(f: F) -> R
78where
79 F: FnOnce(Option<&C>) -> R,
80 C: 'static,
81{
82 return CTX.with(|value| f(value.borrow().downcast_ref::<C>()));
83}
84
85pub fn with_async_context_mut<C, F, R>(f: F) -> R
87where
88 F: FnOnce(Option<&mut C>) -> R,
89 C: 'static,
90{
91 return CTX.with(|value| f(value.borrow_mut().downcast_mut::<C>()));
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn it_works() {
        async fn runs_with_context() -> String {
            with_async_context(|value: Option<&String>| value.unwrap().clone())
        }

        let async_context = provide_async_context("foobar".to_string(), runs_with_context());

        let (value, ctx) = async_context.await;

        assert_eq!("foobar", value);
        // The context itself is handed back alongside the output.
        assert_eq!("foobar", ctx);
    }

    #[tokio::test]
    async fn mutations_are_handed_back() {
        async fn bump() {
            with_async_context_mut(|counter: Option<&mut u32>| *counter.unwrap() += 1);
        }

        let ((), counter) = provide_async_context(41u32, bump()).await;

        assert_eq!(42, counter);
    }

    #[test]
    fn no_context_outside_provider() {
        assert!(with_async_context(|value: Option<&String>| value.is_none()));
    }
}
111}