// rama_tower/layer.rs

1use super::service_ready::Ready;
2use crate::core::Layer as TowerLayer;
3use crate::core::Service as TowerService;
4use rama_core::error::{BoxError, ErrorContext};
5use std::{
6    fmt,
7    marker::PhantomData,
8    ops::{Deref, DerefMut},
9    pin::Pin,
10    sync::Arc,
11};
12
13#[derive(Clone)]
14/// Wrapper type that can be used to smuggle a ctx into a request's extensions.
15pub struct ContextWrap<S>(pub rama_core::Context<S>);
16
17impl<S: fmt::Debug> fmt::Debug for ContextWrap<S> {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        f.debug_tuple("ContextWrap").field(&self.0).finish()
20    }
21}
22
23/// Trait to be implemented for any request that can "smuggle" [`Context`]s.
24///
25/// - if the `http` feature is enabled it will already be implemented for
26///   [`rama_http_types::Request`];
27/// - for types that do have this capability and you work with tower services
28///   which do not care about the specific type of the request that passes through it,
29///   you can make use of [`RequestStatePair`] using the tower map-request capabilities,
30///   to easily swap between the pair and direct request format.
31///
32/// [`Context`]: rama_core::Context
33pub trait ContextSmuggler<S> {
34    /// inject the context into the smuggler.
35    fn inject_ctx(&mut self, ctx: rama_core::Context<S>);
36
37    /// try to extract the smuggled context out of the smuggle,
38    /// which is only possible once.
39    fn try_extract_ctx(&mut self) -> Option<rama_core::Context<S>>;
40}
41
#[cfg(feature = "http")]
mod http {
    use super::{ContextSmuggler, ContextWrap};
    use rama_http_types::Request;

    /// Smuggles the context through the request extensions, wrapped in a [`ContextWrap`].
    impl<B, S> ContextSmuggler<S> for Request<B>
    where
        S: Clone + Send + Sync + 'static,
    {
        fn inject_ctx(&mut self, ctx: rama_core::Context<S>) {
            self.extensions_mut().insert(ContextWrap(ctx));
        }

        fn try_extract_ctx(&mut self) -> Option<rama_core::Context<S>> {
            // Removing (rather than reading) the wrapper makes extraction one-shot.
            let ContextWrap(ctx): ContextWrap<S> = self.extensions_mut().remove()?;
            Some(ctx)
        }
    }
}
59
60/// Simple implementation of a [`ContextSmuggler`].
61pub struct RequestStatePair<R, S> {
62    /// the inner reuqest
63    pub request: R,
64    /// the storage to "smuggle" the ctx"
65    pub ctx: Option<rama_core::Context<S>>,
66}
67
68impl<R: fmt::Debug, S: fmt::Debug> fmt::Debug for RequestStatePair<R, S> {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        f.debug_struct("RequestStatePair")
71            .field("request", &self.request)
72            .field("ctx", &self.ctx)
73            .finish()
74    }
75}
76
77impl<R: Clone, S: Clone> Clone for RequestStatePair<R, S> {
78    fn clone(&self) -> Self {
79        Self {
80            request: self.request.clone(),
81            ctx: self.ctx.clone(),
82        }
83    }
84}
85
86impl<R, S> RequestStatePair<R, S> {
87    pub const fn new(req: R) -> Self {
88        Self {
89            request: req,
90            ctx: None,
91        }
92    }
93}
94
95impl<R, S> Deref for RequestStatePair<R, S> {
96    type Target = R;
97
98    fn deref(&self) -> &Self::Target {
99        &self.request
100    }
101}
102
103impl<R, S> DerefMut for RequestStatePair<R, S> {
104    fn deref_mut(&mut self) -> &mut Self::Target {
105        &mut self.request
106    }
107}
108
109impl<R, S> ContextSmuggler<S> for RequestStatePair<R, S> {
110    fn inject_ctx(&mut self, ctx: rama_core::Context<S>) {
111        self.ctx = Some(ctx);
112    }
113
114    fn try_extract_ctx(&mut self) -> Option<rama_core::Context<S>> {
115        self.ctx.take()
116    }
117}
118
/// Adapter to use a [`tower::Layer`]-[`tower::Service`] as a [`rama::Layer`]-[`rama::Service`].
///
/// The produced [`tower::Service`] will be wrapped by a [`LayerAdapterService`] making it
/// a fully compatible [`rama::Service`] ready to be plugged into a rama stack.
///
/// Note that you should use [`ServiceAdapter`] or [`SharedServiceAdapter`] for non-layer services.
///
/// [`tower::Service`]: tower_service::Service
/// [`tower::Layer`]: tower_layer::Layer
/// [`rama::Layer`]: rama_core::Layer
/// [`rama::Service`]: rama_core::Service
/// [`ServiceAdapter`]: super::ServiceAdapter
/// [`SharedServiceAdapter`]: super::SharedServiceAdapter
pub struct LayerAdapter<L, State> {
    /// the adapted tower layer
    inner: L,
    // `State` is only tracked at the type level; `fn() -> State` keeps the
    // adapter's auto traits independent of `State`.
    _state: PhantomData<fn() -> State>,
}
135
136impl<L: Send + Sync + 'static, State> LayerAdapter<L, State> {
137    /// Adapt a [`tower::Layer`] into a [`rama::Layer`].
138    ///
139    /// See [`LayerAdapter`] for more information.
140    ///
141    /// [`tower::Layer`]: tower_layer::Layer
142    /// [`rama::Layer`]: crate::Layer
143    pub fn new(layer: L) -> Self {
144        Self {
145            inner: layer,
146            _state: PhantomData,
147        }
148    }
149
150    /// Consume itself to return the inner [`tower::Layer`] back.
151    ///
152    /// [`tower::Layer`]: tower_layer::Layer
153    pub fn into_inner(self) -> L {
154        self.inner
155    }
156}
157
158impl<L: fmt::Debug, State> fmt::Debug for LayerAdapter<L, State> {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        f.debug_struct("LayerAdapter")
161            .field("inner", &self.inner)
162            .finish()
163    }
164}
165
/// Adapter to use a [`rama::Service`] as a [`tower::Service`],
/// so that it can be wrapped by a [`tower::Layer`].
///
/// The rama service is stored behind an [`Arc`], so clones of this adapter
/// share the same inner service.
///
/// [`tower::Service`]: tower_service::Service
/// [`tower::Layer`]: tower_layer::Layer
/// [`rama::Service`]: rama_core::Service
pub struct TowerAdapterService<S, State> {
    /// the shared, wrapped rama service
    inner: Arc<S>,
    // `State` is only tracked at the type level.
    _state: PhantomData<fn() -> State>,
}
176
177impl<S, State> TowerAdapterService<S, State> {
178    /// Reference to the inner [`rama::Service`].
179    ///
180    /// [`rama::Service`]: rama_core::Service
181    pub fn inner(&self) -> &S {
182        self.inner.as_ref()
183    }
184}
185
186impl<S: fmt::Debug, State> fmt::Debug for TowerAdapterService<S, State> {
187    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188        f.debug_struct("TowerAdapterService")
189            .field("inner", &self.inner)
190            .finish()
191    }
192}
193
194impl<S, State> Clone for TowerAdapterService<S, State> {
195    fn clone(&self) -> Self {
196        Self {
197            inner: self.inner.clone(),
198            _state: PhantomData,
199        }
200    }
201}
202
/// [`rama::Service`] wrapper around the [`tower::Service`] that a
/// [`tower::Layer`] produces when used through [`LayerAdapter`].
///
/// [`tower::Service`]: tower_service::Service
/// [`tower::Layer`]: tower_layer::Layer
/// [`rama::Service`]: rama_core::Service
#[derive(Clone)]
pub struct LayerAdapterService<T>(T);
211
212impl<T: fmt::Debug> fmt::Debug for LayerAdapterService<T> {
213    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214        f.debug_tuple("LayerAdapterService").field(&self.0).finish()
215    }
216}
217
218impl<L, S, State> rama_core::Layer<S> for LayerAdapter<L, State>
219where
220    L: TowerLayer<TowerAdapterService<S, State>, Service: Clone + Send + Sync + 'static>,
221{
222    type Service = LayerAdapterService<L::Service>;
223
224    fn layer(&self, inner: S) -> Self::Service {
225        let tower_svc = TowerAdapterService {
226            inner: Arc::new(inner),
227            _state: PhantomData,
228        };
229        let layered_tower_svc = self.inner.layer(tower_svc);
230        LayerAdapterService(layered_tower_svc)
231    }
232}
233
234impl<T, State, Request> TowerService<Request> for TowerAdapterService<T, State>
235where
236    T: rama_core::Service<State, Request, Error: Into<BoxError>>,
237    State: Clone + Send + Sync + 'static,
238    Request: ContextSmuggler<State> + Send + 'static,
239{
240    type Response = T::Response;
241    type Error = BoxError;
242    type Future =
243        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
244
245    fn poll_ready(
246        &mut self,
247        _cx: &mut std::task::Context<'_>,
248    ) -> std::task::Poll<Result<(), Self::Error>> {
249        std::task::Poll::Ready(Ok(()))
250    }
251
252    fn call(&mut self, mut req: Request) -> Self::Future {
253        let svc = self.inner.clone();
254        Box::pin(async move {
255            let ctx: rama_core::Context<State> = req
256                .try_extract_ctx()
257                .context("extract context from req smuggler")?;
258            svc.serve(ctx, req).await.map_err(Into::into)
259        })
260    }
261}
262
263impl<T, State, Request> rama_core::Service<State, Request> for LayerAdapterService<T>
264where
265    T: TowerService<Request, Response: Send + 'static, Error: Send + 'static, Future: Send>
266        + Clone
267        + Send
268        + Sync
269        + 'static,
270    State: Clone + Send + Sync + 'static,
271    Request: ContextSmuggler<State> + Send + 'static,
272{
273    type Response = T::Response;
274    type Error = T::Error;
275
276    fn serve(
277        &self,
278        ctx: rama_core::Context<State>,
279        mut req: Request,
280    ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
281        req.inject_ctx(ctx);
282        let svc = self.0.clone();
283        async move {
284            let mut svc = svc;
285            let ready = Ready::new(&mut svc);
286            let ready_svc = ready.await?;
287            ready_svc.call(req).await
288        }
289    }
290}