rama_core/layer/
get_extension.rs

//! Middleware that invokes a callback with a clone of the value of the given type, if such a value is available in the current [`Context`].
//!
//! [`Context`]: https://docs.rs/rama/latest/rama/context/struct.Context.html
4
5use crate::{Context, Layer, Service};
6use rama_utils::macros::define_inner_service_accessors;
7use std::{fmt, marker::PhantomData};
8
/// [`Layer`] that wraps services in [`GetExtension`], a middleware which invokes a
/// callback with a clone of a value of type `T` when one is present in the incoming
/// [`Context`].
///
/// The layer only carries the callback; nothing is added to the [`Context`].
///
/// [`Context`]: https://docs.rs/rama/latest/rama/context/struct.Context.html
pub struct GetExtensionLayer<T, Fut, F> {
    /// Callback handed to every service produced by this layer.
    callback: F,
    /// Ties the otherwise unused `T` (callback argument) and `Fut` (callback
    /// future) parameters to the type without storing either. A `fn` pointer
    /// marker is `Send + Sync` regardless of `T` and `Fut`.
    _phantom: PhantomData<fn(T) -> Fut>,
}
16
17impl<T, Fut, F: fmt::Debug> std::fmt::Debug for GetExtensionLayer<T, Fut, F> {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("GetExtensionLayer")
20            .field("callback", &self.callback)
21            .field(
22                "_phantom",
23                &format_args!("{}", std::any::type_name::<fn(T) -> Fut>()),
24            )
25            .finish()
26    }
27}
28
29impl<T, Fut, F> Clone for GetExtensionLayer<T, Fut, F>
30where
31    F: Clone,
32{
33    fn clone(&self) -> Self {
34        Self {
35            callback: self.callback.clone(),
36            _phantom: PhantomData,
37        }
38    }
39}
40
41impl<T, Fut, F> GetExtensionLayer<T, Fut, F>
42where
43    F: FnOnce(T) -> Fut + Clone + Send + Sync + 'static,
44    Fut: Future<Output = ()> + Send + 'static,
45{
46    /// Create a new [`GetExtensionLayer`].
47    pub const fn new(callback: F) -> Self {
48        GetExtensionLayer {
49            callback,
50            _phantom: PhantomData,
51        }
52    }
53}
54
55impl<S, T, Fut, F> Layer<S> for GetExtensionLayer<T, Fut, F>
56where
57    F: Clone,
58{
59    type Service = GetExtension<S, T, Fut, F>;
60
61    fn layer(&self, inner: S) -> Self::Service {
62        GetExtension {
63            inner,
64            callback: self.callback.clone(),
65            _phantom: PhantomData,
66        }
67    }
68
69    fn into_layer(self, inner: S) -> Self::Service {
70        GetExtension {
71            inner,
72            callback: self.callback,
73            _phantom: PhantomData,
74        }
75    }
76}
77
/// Middleware that looks up a value of type `T` in the incoming [`Context`] and, when
/// one is present, passes a clone of it to a callback before delegating the request to
/// the inner service.
///
/// The [`Context`] itself is forwarded to the inner service unchanged.
///
/// [`Context`]: https://docs.rs/rama/latest/rama/context/struct.Context.html
pub struct GetExtension<S, T, Fut, F> {
    /// Wrapped service that ultimately serves the request.
    inner: S,
    /// Callback that receives the cloned value.
    callback: F,
    /// Ties the otherwise unused `T` (callback argument) and `Fut` (callback
    /// future) parameters to the type without storing either. A `fn` pointer
    /// marker is `Send + Sync` regardless of `T` and `Fut`.
    _phantom: PhantomData<fn(T) -> Fut>,
}
86
87impl<S: fmt::Debug, T, Fut, F: fmt::Debug> std::fmt::Debug for GetExtension<S, T, Fut, F> {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        f.debug_struct("GetExtension")
90            .field("inner", &self.inner)
91            .field("callback", &self.callback)
92            .field(
93                "_phantom",
94                &format_args!("{}", std::any::type_name::<fn(T) -> Fut>()),
95            )
96            .finish()
97    }
98}
99
100impl<S, T, Fut, F> Clone for GetExtension<S, T, Fut, F>
101where
102    S: Clone,
103    F: Clone,
104{
105    fn clone(&self) -> Self {
106        Self {
107            inner: self.inner.clone(),
108            callback: self.callback.clone(),
109            _phantom: PhantomData,
110        }
111    }
112}
113
114impl<S, T, Fut, F> GetExtension<S, T, Fut, F> {
115    /// Create a new [`GetExtension`].
116    pub const fn new(inner: S, callback: F) -> Self
117    where
118        F: FnOnce(T) -> Fut + Clone + Send + Sync + 'static,
119        Fut: Future<Output = ()> + Send + 'static,
120    {
121        Self {
122            inner,
123            callback,
124            _phantom: PhantomData,
125        }
126    }
127
128    define_inner_service_accessors!();
129}
130
131impl<State, Request, S, T, Fut, F> Service<State, Request> for GetExtension<S, T, Fut, F>
132where
133    State: Clone + Send + Sync + 'static,
134    Request: Send + 'static,
135    S: Service<State, Request>,
136    T: Clone + Send + Sync + 'static,
137    F: FnOnce(T) -> Fut + Clone + Send + Sync + 'static,
138    Fut: Future<Output = ()> + Send + 'static,
139{
140    type Response = S::Response;
141    type Error = S::Error;
142
143    async fn serve(
144        &self,
145        ctx: Context<State>,
146        req: Request,
147    ) -> Result<Self::Response, Self::Error> {
148        if let Some(value) = ctx.get::<T>() {
149            let value = value.clone();
150            (self.callback.clone())(value).await;
151        }
152        self.inner.serve(ctx, req).await
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, service::service_fn};
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicI32, Ordering},
        },
    };

    /// Extension value placed into the [`Context`] by the test below.
    #[derive(Debug, Clone)]
    struct State(i32);

    /// The callback observes the inserted [`State`], and the inner service still finds
    /// the same extension in the context afterwards.
    #[tokio::test]
    async fn get_extension_basic() {
        let observed = Arc::new(AtomicI32::new(0));

        let layer = {
            let observed = Arc::clone(&observed);
            GetExtensionLayer::new(async move |state: State| {
                observed.store(state.0, Ordering::Release);
            })
        };
        let inner = service_fn(async |ctx: Context<()>, _req: ()| {
            let state = ctx.get::<State>().unwrap();
            Ok::<_, Infallible>(state.0)
        });
        let svc = layer.into_layer(inner);

        let mut ctx = Context::default();
        ctx.insert(State(42));

        let res = svc.serve(ctx, ()).await.unwrap();
        assert_eq!(42, res);

        assert_eq!(42, observed.load(Ordering::Acquire));
    }
}