1use crate::{Context, Layer, Service};
6use rama_utils::macros::define_inner_service_accessors;
7use std::{fmt, marker::PhantomData};
8
/// A [`Layer`] that wraps an inner service in a [`GetExtension`] service.
///
/// * `T` is the extension type looked up in the request [`Context`].
/// * `F` is the callback that receives a clone of that extension.
/// * `Fut` is the future returned by the callback.
pub struct GetExtensionLayer<T, Fut, F> {
    /// Callback handed to every service produced by this layer.
    callback: F,
    /// Zero-sized marker tying `T` and `Fut` to the type without storing a
    /// value of either.
    _phantom: PhantomData<fn(T) -> Fut>,
}
16
17impl<T, Fut, F: fmt::Debug> std::fmt::Debug for GetExtensionLayer<T, Fut, F> {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 f.debug_struct("GetExtensionLayer")
20 .field("callback", &self.callback)
21 .field(
22 "_phantom",
23 &format_args!("{}", std::any::type_name::<fn(T) -> Fut>()),
24 )
25 .finish()
26 }
27}
28
29impl<T, Fut, F> Clone for GetExtensionLayer<T, Fut, F>
30where
31 F: Clone,
32{
33 fn clone(&self) -> Self {
34 Self {
35 callback: self.callback.clone(),
36 _phantom: PhantomData,
37 }
38 }
39}
40
41impl<T, Fut, F> GetExtensionLayer<T, Fut, F>
42where
43 F: FnOnce(T) -> Fut + Clone + Send + Sync + 'static,
44 Fut: Future<Output = ()> + Send + 'static,
45{
46 pub const fn new(callback: F) -> Self {
48 GetExtensionLayer {
49 callback,
50 _phantom: PhantomData,
51 }
52 }
53}
54
55impl<S, T, Fut, F> Layer<S> for GetExtensionLayer<T, Fut, F>
56where
57 F: Clone,
58{
59 type Service = GetExtension<S, T, Fut, F>;
60
61 fn layer(&self, inner: S) -> Self::Service {
62 GetExtension {
63 inner,
64 callback: self.callback.clone(),
65 _phantom: PhantomData,
66 }
67 }
68
69 fn into_layer(self, inner: S) -> Self::Service {
70 GetExtension {
71 inner,
72 callback: self.callback,
73 _phantom: PhantomData,
74 }
75 }
76}
77
/// A [`Service`] wrapper that, when an extension of type `T` is present in the
/// request [`Context`], passes a clone of it to a callback before forwarding
/// the request to the inner service.
pub struct GetExtension<S, T, Fut, F> {
    /// The wrapped service that actually handles the request.
    inner: S,
    /// Callback invoked with the cloned extension value.
    callback: F,
    /// Zero-sized marker tying `T` and `Fut` to the type without storing a
    /// value of either.
    _phantom: PhantomData<fn(T) -> Fut>,
}
86
87impl<S: fmt::Debug, T, Fut, F: fmt::Debug> std::fmt::Debug for GetExtension<S, T, Fut, F> {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 f.debug_struct("GetExtension")
90 .field("inner", &self.inner)
91 .field("callback", &self.callback)
92 .field(
93 "_phantom",
94 &format_args!("{}", std::any::type_name::<fn(T) -> Fut>()),
95 )
96 .finish()
97 }
98}
99
100impl<S, T, Fut, F> Clone for GetExtension<S, T, Fut, F>
101where
102 S: Clone,
103 F: Clone,
104{
105 fn clone(&self) -> Self {
106 Self {
107 inner: self.inner.clone(),
108 callback: self.callback.clone(),
109 _phantom: PhantomData,
110 }
111 }
112}
113
114impl<S, T, Fut, F> GetExtension<S, T, Fut, F> {
115 pub const fn new(inner: S, callback: F) -> Self
117 where
118 F: FnOnce(T) -> Fut + Clone + Send + Sync + 'static,
119 Fut: Future<Output = ()> + Send + 'static,
120 {
121 Self {
122 inner,
123 callback,
124 _phantom: PhantomData,
125 }
126 }
127
128 define_inner_service_accessors!();
129}
130
131impl<State, Request, S, T, Fut, F> Service<State, Request> for GetExtension<S, T, Fut, F>
132where
133 State: Clone + Send + Sync + 'static,
134 Request: Send + 'static,
135 S: Service<State, Request>,
136 T: Clone + Send + Sync + 'static,
137 F: FnOnce(T) -> Fut + Clone + Send + Sync + 'static,
138 Fut: Future<Output = ()> + Send + 'static,
139{
140 type Response = S::Response;
141 type Error = S::Error;
142
143 async fn serve(
144 &self,
145 ctx: Context<State>,
146 req: Request,
147 ) -> Result<Self::Response, Self::Error> {
148 if let Some(value) = ctx.get::<T>() {
149 let value = value.clone();
150 (self.callback.clone())(value).await;
151 }
152 self.inner.serve(ctx, req).await
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, service::service_fn};
    use std::{
        convert::Infallible,
        sync::{
            Arc,
            atomic::{AtomicI32, Ordering},
        },
    };

    /// Extension value placed in the context for the test.
    #[derive(Debug, Clone)]
    struct State(i32);

    /// The callback must observe the extension, and the inner service must
    /// still see it in the context afterwards.
    #[tokio::test]
    async fn get_extension_basic() {
        let observed = Arc::new(AtomicI32::new(0));

        let layer = {
            let observed = Arc::clone(&observed);
            GetExtensionLayer::new(async move |state: State| {
                observed.store(state.0, Ordering::Release);
            })
        };
        let svc = layer.into_layer(service_fn(async |ctx: Context<()>, _req: ()| {
            let state = ctx.get::<State>().unwrap();
            Ok::<_, Infallible>(state.0)
        }));

        let mut ctx = Context::default();
        ctx.insert(State(42));

        let res = svc.serve(ctx, ()).await.unwrap();
        assert_eq!(42, res);

        assert_eq!(42, observed.load(Ordering::Acquire));
    }
}