1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
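//! Implicit, per-future context for async functions. Wrap a future with
//! [`provide_async_context`]; code running inside it can then read or modify the context with
//! [`with_async_context`] / [`with_async_context_mut`] without it being passed explicitly.
//! This can be used to implement React-like hooks.
//!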
#![feature(box_into_inner)]

use core::future::Future;
use std::{any::Any, cell::RefCell, pin::Pin, sync::Mutex, task::Poll};

use pin_project::pin_project;

thread_local! {
    /// Type-erased context of the `AsyncContext` currently being polled on this thread.
    ///
    /// Starts out holding the unit value `()`, meaning "no context installed".
    static CTX: RefCell<Box<dyn Any>> = {
        let empty: Box<dyn Any> = Box::new(());
        RefCell::new(empty)
    };
}

/// A future bundled with the context value made available to it.
///
/// Created with [`provide_async_context`]. While the wrapped future is being polled, the
/// context can be accessed with [`with_async_context`] or [`with_async_context_mut`].
/// Awaiting an `AsyncContext` yields the wrapped future's output together with the
/// (possibly modified) context, as a `(T, C)` tuple.
#[pin_project]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct AsyncContext<C, T, F>
where
    C: 'static,
    F: Future<Output = T>,
{
    /// The context, parked here between polls. It is `None` while the wrapped future is
    /// being polled (the value is moved into thread-local storage) and after completion.
    ctx: Mutex<Option<C>>,
    /// The wrapped future; structurally pinned.
    #[pin]
    future: F,
}

/// Wraps `future` so that `ctx` is available to it while it runs.
///
/// Inside `future`, the context can be read with [`with_async_context`] or modified with
/// [`with_async_context_mut`]. Awaiting the returned [`AsyncContext`] yields the future's
/// output paired with the context.
pub fn provide_async_context<C: 'static, T, F: Future<Output = T>>(
    ctx: C,
    future: F,
) -> AsyncContext<C, T, F> {
    let parked = Mutex::new(Some(ctx));
    AsyncContext {
        future,
        ctx: parked,
    }
}

/// Puts a saved value back into the `CTX` thread-local when dropped.
///
/// Used by the `Future` impl for [`AsyncContext`] so the thread-local is restored even if the
/// wrapped future panics; otherwise the panicking future's context would stay installed and be
/// visible to unrelated code that later runs on the same thread.
struct CtxRestoreGuard {
    previous: Option<Box<dyn Any>>,
}

impl CtxRestoreGuard {
    /// Installs `ctx` in `CTX`, remembering the value it replaces.
    fn install(ctx: Box<dyn Any>) -> Self {
        Self {
            previous: Some(CTX.replace(ctx)),
        }
    }

    /// Puts the remembered value back into `CTX`, returning the value that was installed in
    /// its place. Consumes (and thereby disarms) the guard.
    fn restore(mut self) -> Box<dyn Any> {
        let previous = self
            .previous
            .take()
            .expect("guard holds the previous context until restored or dropped");
        CTX.replace(previous)
    }
}

impl Drop for CtxRestoreGuard {
    fn drop(&mut self) {
        // Still `Some` only if `restore` was never called, i.e. the inner poll unwound.
        if let Some(previous) = self.previous.take() {
            CTX.set(previous);
        }
    }
}

impl<C, T, F> Future for AsyncContext<C, T, F>
where
    C: 'static,
    F: Future<Output = T>,
{
    /// The wrapped future's output together with the (possibly modified) context.
    type Output = (T, C);

    /// Polls the wrapped future with this future's context installed in `CTX`.
    ///
    /// The context is moved out of `self` into thread-local storage for the duration of the
    /// inner poll and moved back afterwards. Whatever `CTX` held before (`()` or the context of
    /// an enclosing `AsyncContext`) is restored once the inner poll returns or unwinds, so
    /// nested `AsyncContext`s do not clobber each other.
    ///
    /// # Panics
    ///
    /// Panics if polled again after returning `Poll::Ready`.
    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let this = self.project();
        // `Pin<&mut Self>` already gives exclusive access, so no lock is needed. The mutex is
        // never locked here and thus cannot be poisoned by this code; recover the data anyway.
        let slot = this
            .ctx
            .get_mut()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let ctx: C = slot.take().expect("No context found");

        // Save the previous thread-local value instead of overwriting (and dropping) it, so an
        // enclosing `AsyncContext` still finds its own context after this poll.
        let guard = CtxRestoreGuard::install(Box::new(ctx));
        let poll = this.future.poll(cx);
        let ctx: C = *guard
            .restore()
            .downcast::<C>()
            .expect("async context type changed while the inner future was polled");

        match poll {
            Poll::Ready(value) => Poll::Ready((value, ctx)),
            Poll::Pending => {
                // Park the (possibly modified) context for the next poll.
                *slot = Some(ctx);
                Poll::Pending
            }
        }
    }
}

/// Calls `f` with a shared reference to the current async context and returns its result.
///
/// `f` receives `Some(&C)` when the context currently stored for this thread has type `C`
/// (i.e. when called while a future wrapped by [`provide_async_context`] with a `C` context
/// is being polled), and `None` otherwise. Outside any such future the stored context is `()`.
///
/// # Panics
///
/// Panics if the context is currently borrowed mutably, e.g. when called from inside the
/// closure passed to [`with_async_context_mut`].
pub fn with_async_context<C, F, R>(f: F) -> R
where
    F: FnOnce(Option<&C>) -> R,
    C: 'static,
{
    CTX.with(|value| f(value.borrow().downcast_ref::<C>()))
}

/// Calls `f` with a mutable reference to the current async context, allowing it to be read
/// or modified, and returns `f`'s result.
///
/// `f` receives `Some(&mut C)` when the context currently stored for this thread has type `C`,
/// and `None` otherwise. Modifications are made to the same value that the surrounding
/// [`AsyncContext`] keeps between polls and returns alongside the future's output.
///
/// # Panics
///
/// Panics if the context is already borrowed, e.g. when called from inside the closure passed
/// to [`with_async_context`] or to another `with_async_context_mut`.
pub fn with_async_context_mut<C, F, R>(f: F) -> R
where
    F: FnOnce(Option<&mut C>) -> R,
    C: 'static,
{
    CTX.with(|value| f(value.borrow_mut().downcast_mut::<C>()))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// A provided value is readable inside the wrapped future and is handed back, unchanged,
    /// alongside the future's output.
    #[tokio::test]
    async fn it_works() {
        async fn runs_with_context() -> String {
            with_async_context(|value: Option<&String>| value.cloned())
                .expect("context should be available inside the wrapped future")
        }

        let (value, ctx) = provide_async_context("foobar".to_string(), runs_with_context()).await;

        assert_eq!("foobar", value);
        assert_eq!("foobar", ctx);
    }

    /// Modifications made through `with_async_context_mut` survive the wrapped future
    /// yielding (`Poll::Pending`) and are returned with the output.
    #[tokio::test]
    async fn mutations_survive_pending() {
        async fn bump_across_yield() -> u32 {
            with_async_context_mut(|count: Option<&mut u32>| *count.unwrap() += 1);
            tokio::task::yield_now().await;
            with_async_context_mut(|count: Option<&mut u32>| *count.unwrap() += 1);
            with_async_context(|count: Option<&u32>| *count.unwrap())
        }

        let (seen, ctx) = provide_async_context(0u32, bump_across_yield()).await;

        assert_eq!(2, seen);
        assert_eq!(2, ctx);
    }

    /// Missing context, or a context of a different type, is reported as `None`.
    #[tokio::test]
    async fn mismatched_or_missing_context_is_none() {
        assert!(with_async_context(|value: Option<&String>| value.is_none()));

        let (is_none, _) = provide_async_context(42u32, async {
            with_async_context(|value: Option<&String>| value.is_none())
        })
        .await;

        assert!(is_none);
    }
}