rama_core/layer/
add_extension.rs

1//! Middleware that clones a value into the incoming [Context].
2//!
3//! [Context]: https://docs.rs/rama/latest/rama/context/struct.Context.html
4//!
5//! # Example
6//!
7//! ```
8//! use std::{sync::Arc, convert::Infallible};
9//!
10//! use rama_core::{Context, Service, Layer, service::service_fn};
11//! use rama_core::layer::add_extension::AddExtensionLayer;
12//! use rama_core::error::BoxError;
13//!
14//! # struct DatabaseConnectionPool;
15//! # impl DatabaseConnectionPool {
16//! #     fn new() -> DatabaseConnectionPool { DatabaseConnectionPool }
17//! # }
18//! #
19//! // Shared state across all request handlers --- in this case, a pool of database connections.
20//! struct State {
21//!     pool: DatabaseConnectionPool,
22//! }
23//!
24//! async fn handle<S>(ctx: Context<S>, req: ()) -> Result<(), Infallible>
25//! where
26//!    S: Clone + Send + Sync + 'static,
27//! {
//!     // Grab the state from the context extensions.
29//!     let state = ctx.get::<Arc<State>>().unwrap();
30//!
31//!     Ok(req)
32//! }
33//!
34//! # #[tokio::main]
35//! # async fn main() -> Result<(), BoxError> {
36//! // Construct the shared state.
37//! let state = State {
38//!     pool: DatabaseConnectionPool::new(),
39//! };
40//!
41//! let mut service = (
42//!     // Share an `Arc<State>` with all requests.
43//!     AddExtensionLayer::new(Arc::new(state)),
44//! ).into_layer(service_fn(handle));
45//!
46//! // Call the service.
47//! let response = service
48//!     .serve(Context::default(), ())
49//!     .await?;
50//! # Ok(())
51//! # }
52//! ```
53
54use crate::{Context, Layer, Service};
55use rama_utils::macros::define_inner_service_accessors;
56use std::fmt;
57
/// [`Layer`] for adding some shareable value to incoming [Context].
///
/// Services built by this layer insert a clone of the stored value into the
/// [Context] of every request they serve. Values that are cheap to clone,
/// such as an `Arc`, are therefore a natural fit.
///
/// [Context]: https://docs.rs/rama/latest/rama/context/struct.Context.html
#[derive(Clone)]
pub struct AddExtensionLayer<T> {
    value: T,
}

impl<T: fmt::Debug> fmt::Debug for AddExtensionLayer<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AddExtensionLayer")
            .field("value", &self.value)
            .finish()
    }
}

impl<T> AddExtensionLayer<T> {
    /// Create a new [`AddExtensionLayer`] holding `value`.
    ///
    /// [`Layer::layer`] gives each produced service its own clone of `value`.
    /// [`Layer::into_layer`] moves `value` into the service without cloning it.
    pub const fn new(value: T) -> Self {
        AddExtensionLayer { value }
    }
}
90
91impl<S, T> Layer<S> for AddExtensionLayer<T>
92where
93    T: Clone,
94{
95    type Service = AddExtension<S, T>;
96
97    fn layer(&self, inner: S) -> Self::Service {
98        AddExtension {
99            inner,
100            value: self.value.clone(),
101        }
102    }
103
104    fn into_layer(self, inner: S) -> Self::Service {
105        AddExtension {
106            inner,
107            value: self.value,
108        }
109    }
110}
111
/// Middleware for adding some shareable value to incoming [Context].
///
/// On every request it inserts a clone of its stored value into the [Context]
/// and then delegates to the wrapped `inner` service.
///
/// [Context]: https://docs.rs/rama/latest/rama/context/struct.Context.html
#[derive(Clone)]
pub struct AddExtension<S, T> {
    inner: S,
    value: T,
}

impl<S: fmt::Debug, T: fmt::Debug> fmt::Debug for AddExtension<S, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("AddExtension")
            .field("inner", &self.inner)
            .field("value", &self.value)
            .finish()
    }
}
141
142impl<S, T> AddExtension<S, T> {
143    /// Create a new [`AddExtension`].
144    pub const fn new(inner: S, value: T) -> Self {
145        Self { inner, value }
146    }
147
148    define_inner_service_accessors!();
149}
150
151impl<State, Request, S, T> Service<State, Request> for AddExtension<S, T>
152where
153    State: Clone + Send + Sync + 'static,
154    Request: Send + 'static,
155    S: Service<State, Request>,
156    T: Clone + Send + Sync + 'static,
157{
158    type Response = S::Response;
159    type Error = S::Error;
160
161    fn serve(
162        &self,
163        mut ctx: Context<State>,
164        req: Request,
165    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
166        ctx.insert(self.value.clone());
167        self.inner.serve(ctx, req)
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, service::service_fn};
    use std::{convert::Infallible, sync::Arc};

    /// Minimal shared state, handed to the handler as an `Arc` extension.
    struct State(i32);

    #[tokio::test]
    async fn basic() {
        let layer = AddExtensionLayer::new(Arc::new(State(1)));
        let handler = service_fn(async |ctx: Context<()>, _req: ()| {
            // The middleware must have inserted the `Arc<State>` by now.
            let state = ctx.get::<Arc<State>>().unwrap();
            Ok::<_, Infallible>(state.0)
        });
        let svc = layer.into_layer(handler);

        let res = svc.serve(Context::default(), ()).await.unwrap();

        assert_eq!(1, res);
    }
}