1use super::service_ready::Ready;
2use crate::core::Layer as TowerLayer;
3use crate::core::Service as TowerService;
4use rama_core::error::{BoxError, ErrorContext};
5use std::{
6 fmt,
7 marker::PhantomData,
8 ops::{Deref, DerefMut},
9 pin::Pin,
10 sync::Arc,
11};
12
/// Newtype around a [`rama_core::Context`] so it can be stored in (and later
/// removed from) a request's extensions map under a dedicated type.
///
/// The field is public so the wrapped context can be moved back out.
#[derive(Clone)]
pub struct ContextWrap<S>(pub rama_core::Context<S>);
16
17impl<S: fmt::Debug> fmt::Debug for ContextWrap<S> {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 f.debug_tuple("ContextWrap").field(&self.0).finish()
20 }
21}
22
/// A request type that can temporarily carry a [`rama_core::Context`].
///
/// Tower services receive only a request, not a rama context. The adapters in
/// this module therefore inject the context into the request before handing it
/// to tower. They extract it again before calling back into a rama service.
pub trait ContextSmuggler<S> {
    /// Stores `ctx` inside `self`.
    ///
    /// The implementations in this module overwrite any previously stored context.
    fn inject_ctx(&mut self, ctx: rama_core::Context<S>);

    /// Takes the stored context out of `self`.
    ///
    /// Returns `None` when no context is currently stored, for example when
    /// none was injected or it was already extracted.
    fn try_extract_ctx(&mut self) -> Option<rama_core::Context<S>>;
}
41
#[cfg(feature = "http")]
mod http {
    use super::*;
    use rama_http_types::Request;

    /// Carries the context in the request's extensions, wrapped in
    /// [`ContextWrap`] so it is keyed by its own dedicated type.
    impl<B, S: Clone + Send + Sync + 'static> ContextSmuggler<S> for Request<B> {
        fn inject_ctx(&mut self, ctx: rama_core::Context<S>) {
            // Any previously inserted `ContextWrap<S>` is replaced and dropped.
            self.extensions_mut().insert(ContextWrap(ctx));
        }

        fn try_extract_ctx(&mut self) -> Option<rama_core::Context<S>> {
            self.extensions_mut()
                .remove::<ContextWrap<S>>()
                .map(|ContextWrap(ctx)| ctx)
        }
    }
}
59
/// A request paired with an optional [`rama_core::Context`] slot.
///
/// This is a generic [`ContextSmuggler`] carrier. It dereferences to the
/// wrapped request, so most request methods are reachable directly on the pair.
pub struct RequestStatePair<R, S> {
    /// The wrapped request.
    pub request: R,
    /// The carried context. It is `None` when built with
    /// [`RequestStatePair::new`], and again after the context has been extracted.
    pub ctx: Option<rama_core::Context<S>>,
}
67
68impl<R: fmt::Debug, S: fmt::Debug> fmt::Debug for RequestStatePair<R, S> {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 f.debug_struct("RequestStatePair")
71 .field("request", &self.request)
72 .field("ctx", &self.ctx)
73 .finish()
74 }
75}
76
77impl<R: Clone, S: Clone> Clone for RequestStatePair<R, S> {
78 fn clone(&self) -> Self {
79 Self {
80 request: self.request.clone(),
81 ctx: self.ctx.clone(),
82 }
83 }
84}
85
86impl<R, S> RequestStatePair<R, S> {
87 pub const fn new(req: R) -> Self {
88 Self {
89 request: req,
90 ctx: None,
91 }
92 }
93}
94
95impl<R, S> Deref for RequestStatePair<R, S> {
96 type Target = R;
97
98 fn deref(&self) -> &Self::Target {
99 &self.request
100 }
101}
102
103impl<R, S> DerefMut for RequestStatePair<R, S> {
104 fn deref_mut(&mut self) -> &mut Self::Target {
105 &mut self.request
106 }
107}
108
109impl<R, S> ContextSmuggler<S> for RequestStatePair<R, S> {
110 fn inject_ctx(&mut self, ctx: rama_core::Context<S>) {
111 self.ctx = Some(ctx);
112 }
113
114 fn try_extract_ctx(&mut self) -> Option<rama_core::Context<S>> {
115 self.ctx.take()
116 }
117}
118
/// Adapts a tower [`Layer`](TowerLayer) so it can be used as a [`rama_core::Layer`].
///
/// When layering, the inner rama service is wrapped in a [`TowerAdapterService`].
/// The tower layer is applied to that wrapper, and the resulting tower service
/// is exposed as a [`LayerAdapterService`].
pub struct LayerAdapter<L, State> {
    inner: L,
    // `fn() -> State` records `State` as a type parameter without storing or
    // owning a `State`. This `PhantomData` form is always `Send + Sync`.
    _state: PhantomData<fn() -> State>,
}
135
136impl<L: Send + Sync + 'static, State> LayerAdapter<L, State> {
137 pub fn new(layer: L) -> Self {
144 Self {
145 inner: layer,
146 _state: PhantomData,
147 }
148 }
149
150 pub fn into_inner(self) -> L {
154 self.inner
155 }
156}
157
158impl<L: fmt::Debug, State> fmt::Debug for LayerAdapter<L, State> {
159 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160 f.debug_struct("LayerAdapter")
161 .field("inner", &self.inner)
162 .finish()
163 }
164}
165
/// Exposes a rama service as a tower [`Service`](TowerService).
///
/// [`LayerAdapter`] creates this wrapper when layering; it is the service the
/// tower layer wraps. The rama service is held in an [`Arc`], so cloning the
/// adapter only bumps a reference count. The futures returned by `call` own
/// their own handle to the service.
pub struct TowerAdapterService<S, State> {
    inner: Arc<S>,
    _state: PhantomData<fn() -> State>,
}
176
177impl<S, State> TowerAdapterService<S, State> {
178 pub fn inner(&self) -> &S {
182 self.inner.as_ref()
183 }
184}
185
186impl<S: fmt::Debug, State> fmt::Debug for TowerAdapterService<S, State> {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 f.debug_struct("TowerAdapterService")
189 .field("inner", &self.inner)
190 .finish()
191 }
192}
193
194impl<S, State> Clone for TowerAdapterService<S, State> {
195 fn clone(&self) -> Self {
196 Self {
197 inner: self.inner.clone(),
198 _state: PhantomData,
199 }
200 }
201}
202
/// The rama service produced by [`LayerAdapter`].
///
/// It wraps the tower service returned by the adapted tower layer.
#[derive(Clone)]
pub struct LayerAdapterService<T>(T);
211
212impl<T: fmt::Debug> fmt::Debug for LayerAdapterService<T> {
213 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214 f.debug_tuple("LayerAdapterService").field(&self.0).finish()
215 }
216}
217
218impl<L, S, State> rama_core::Layer<S> for LayerAdapter<L, State>
219where
220 L: TowerLayer<TowerAdapterService<S, State>, Service: Clone + Send + Sync + 'static>,
221{
222 type Service = LayerAdapterService<L::Service>;
223
224 fn layer(&self, inner: S) -> Self::Service {
225 let tower_svc = TowerAdapterService {
226 inner: Arc::new(inner),
227 _state: PhantomData,
228 };
229 let layered_tower_svc = self.inner.layer(tower_svc);
230 LayerAdapterService(layered_tower_svc)
231 }
232}
233
234impl<T, State, Request> TowerService<Request> for TowerAdapterService<T, State>
235where
236 T: rama_core::Service<State, Request, Error: Into<BoxError>>,
237 State: Clone + Send + Sync + 'static,
238 Request: ContextSmuggler<State> + Send + 'static,
239{
240 type Response = T::Response;
241 type Error = BoxError;
242 type Future =
243 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
244
245 fn poll_ready(
246 &mut self,
247 _cx: &mut std::task::Context<'_>,
248 ) -> std::task::Poll<Result<(), Self::Error>> {
249 std::task::Poll::Ready(Ok(()))
250 }
251
252 fn call(&mut self, mut req: Request) -> Self::Future {
253 let svc = self.inner.clone();
254 Box::pin(async move {
255 let ctx: rama_core::Context<State> = req
256 .try_extract_ctx()
257 .context("extract context from req smuggler")?;
258 svc.serve(ctx, req).await.map_err(Into::into)
259 })
260 }
261}
262
263impl<T, State, Request> rama_core::Service<State, Request> for LayerAdapterService<T>
264where
265 T: TowerService<Request, Response: Send + 'static, Error: Send + 'static, Future: Send>
266 + Clone
267 + Send
268 + Sync
269 + 'static,
270 State: Clone + Send + Sync + 'static,
271 Request: ContextSmuggler<State> + Send + 'static,
272{
273 type Response = T::Response;
274 type Error = T::Error;
275
276 fn serve(
277 &self,
278 ctx: rama_core::Context<State>,
279 mut req: Request,
280 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
281 req.inject_ctx(ctx);
282 let svc = self.0.clone();
283 async move {
284 let mut svc = svc;
285 let ready = Ready::new(&mut svc);
286 let ready_svc = ready.await?;
287 ready_svc.call(req).await
288 }
289 }
290}