async_context/
lib.rs

1//! A way to have some context within async functions. This can be used to implement React-like hooks.
2//!
3#![feature(box_into_inner)]
4
5use core::future::Future;
6use std::{any::Any, cell::RefCell, pin::Pin, sync::Mutex, task::Poll};
7
8use pin_project::pin_project;
9
thread_local! {
    /// Per-thread slot holding the context of the [AsyncContext] currently being polled.
    ///
    /// Contains `()` whenever no context is being provided on this thread.
    static CTX: RefCell<Box<dyn Any>> = {
        let empty: Box<dyn Any> = Box::new(());
        RefCell::new(empty)
    };
}
13
14/// Stores a future along with the async context provided for it.
15/// Create AsyncContext using [provide_async_context]
16/// Access the context using [with_async_context] or [with_async_context_mut]
17#[pin_project]
18pub struct AsyncContext<C, T, F>
19where
20    C: 'static,
21    F: Future<Output = T>,
22{
23    ctx: Mutex<Option<C>>,
24    #[pin]
25    future: F,
26}
27
28/// Wraps a future with some async context.
29/// Within the future, the provided context can be retrieved using [with_async_context] or [with_async_context_mut]
30pub fn provide_async_context<C, T, F>(ctx: C, future: F) -> AsyncContext<C, T, F>
31where
32    C: 'static,
33    F: Future<Output = T>,
34{
35    AsyncContext {
36        ctx: Mutex::new(Some(ctx)),
37        future,
38    }
39}
40
41impl<C, T, F> Future for AsyncContext<C, T, F>
42where
43    F: Future<Output = T>,
44{
45    type Output = (T, C);
46
47    fn poll(
48        self: std::pin::Pin<&mut Self>,
49        cx: &mut std::task::Context<'_>,
50    ) -> core::task::Poll<Self::Output> {
51        let ctx: C = self
52            .ctx
53            .lock()
54            .expect("Failed to lock context mutex")
55            .take()
56            .expect("No context found");
57        CTX.set(Box::new(ctx));
58        let projection = self.project();
59        let future: Pin<&mut F> = projection.future;
60        let poll = future.poll(cx);
61        let ctx: C = Box::into_inner(CTX.replace(Box::new(())).downcast().unwrap());
62        match poll {
63            Poll::Ready(value) => return Poll::Ready((value, ctx)),
64            Poll::Pending => {
65                projection
66                    .ctx
67                    .lock()
68                    .expect("Feiled to lock context mutex")
69                    .replace(ctx);
70                Poll::Pending
71            }
72        }
73    }
74}
75
76/// Retrieves immutable ref for async context in order to read values.
77pub fn with_async_context<C, F, R>(f: F) -> R
78where
79    F: FnOnce(Option<&C>) -> R,
80    C: 'static,
81{
82    return CTX.with(|value| f(value.borrow().downcast_ref::<C>()));
83}
84
85/// Retrieves mutable ref for async context in order to read values.
86pub fn with_async_context_mut<C, F, R>(f: F) -> R
87where
88    F: FnOnce(Option<&mut C>) -> R,
89    C: 'static,
90{
91    return CTX.with(|value| f(value.borrow_mut().downcast_mut::<C>()));
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn it_works() {
        async fn runs_with_context() -> String {
            with_async_context(|value: Option<&String>| value.unwrap().clone())
        }

        let async_context = provide_async_context("foobar".to_string(), runs_with_context());

        let (value, ctx) = async_context.await;

        assert_eq!("foobar", value);
        // The context is handed back alongside the future's output.
        assert_eq!("foobar", ctx);
    }

    #[tokio::test]
    async fn mutations_are_returned_with_the_context() {
        async fn increments() -> u32 {
            with_async_context_mut(|counter: Option<&mut u32>| {
                let counter = counter.unwrap();
                *counter += 1;
                *counter
            })
        }

        let (seen, ctx) = provide_async_context(41_u32, increments()).await;

        assert_eq!(seen, 42);
        assert_eq!(ctx, 42);
    }

    #[tokio::test]
    async fn context_survives_pending_polls() {
        async fn reads_across_await() -> (u32, u32) {
            let before = with_async_context(|n: Option<&u32>| *n.unwrap());
            // Yielding to the runtime makes this poll return `Pending`, exercising the path
            // that stores the context between polls and reinstalls it afterwards.
            tokio::task::yield_now().await;
            let after = with_async_context(|n: Option<&u32>| *n.unwrap());
            (before, after)
        }

        let ((before, after), ctx) = provide_async_context(7_u32, reads_across_await()).await;

        assert_eq!((before, after, ctx), (7, 7, 7));
    }

    #[test]
    fn no_context_outside_a_provider() {
        assert!(with_async_context(|ctx: Option<&String>| ctx.is_none()));
        assert!(with_async_context_mut(|ctx: Option<&mut String>| ctx.is_none()));
    }
}
111}