use crate::{Context, Layer, Service};
use rama_utils::macros::define_inner_service_accessors;
use std::fmt;

/// [`Layer`] that wraps a service in [`AddExtension`], so that a clone of
/// `value` is inserted into the [`Context`] of every request the wrapped
/// service handles.
///
/// Construct it with [`AddExtensionLayer::new`].
pub struct AddExtensionLayer<T> {
    /// The value inserted into each request's [`Context`].
    value: T,
}

65impl<T: fmt::Debug> std::fmt::Debug for AddExtensionLayer<T> {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("AddExtensionLayer")
68 .field("value", &self.value)
69 .finish()
70 }
71}
72
73impl<T> Clone for AddExtensionLayer<T>
74where
75 T: Clone,
76{
77 fn clone(&self) -> Self {
78 Self {
79 value: self.value.clone(),
80 }
81 }
82}
83
84impl<T> AddExtensionLayer<T> {
85 pub const fn new(value: T) -> Self {
87 AddExtensionLayer { value }
88 }
89}
90
91impl<S, T> Layer<S> for AddExtensionLayer<T>
92where
93 T: Clone,
94{
95 type Service = AddExtension<S, T>;
96
97 fn layer(&self, inner: S) -> Self::Service {
98 AddExtension {
99 inner,
100 value: self.value.clone(),
101 }
102 }
103
104 fn into_layer(self, inner: S) -> Self::Service {
105 AddExtension {
106 inner,
107 value: self.value,
108 }
109 }
110}
111
/// Middleware [`Service`] that inserts a clone of `value` into the request
/// [`Context`] before delegating to the `inner` service.
///
/// Usually created through [`AddExtensionLayer`], or directly with
/// [`AddExtension::new`].
pub struct AddExtension<S, T> {
    /// The wrapped service that actually handles the request.
    inner: S,
    /// The value cloned into each request's [`Context`].
    value: T,
}

120impl<S: fmt::Debug, T: fmt::Debug> std::fmt::Debug for AddExtension<S, T> {
121 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
122 f.debug_struct("AddExtension")
123 .field("inner", &self.inner)
124 .field("value", &self.value)
125 .finish()
126 }
127}
128
129impl<S, T> Clone for AddExtension<S, T>
130where
131 S: Clone,
132 T: Clone,
133{
134 fn clone(&self) -> Self {
135 Self {
136 inner: self.inner.clone(),
137 value: self.value.clone(),
138 }
139 }
140}
141
142impl<S, T> AddExtension<S, T> {
143 pub const fn new(inner: S, value: T) -> Self {
145 Self { inner, value }
146 }
147
148 define_inner_service_accessors!();
149}
150
151impl<State, Request, S, T> Service<State, Request> for AddExtension<S, T>
152where
153 State: Clone + Send + Sync + 'static,
154 Request: Send + 'static,
155 S: Service<State, Request>,
156 T: Clone + Send + Sync + 'static,
157{
158 type Response = S::Response;
159 type Error = S::Error;
160
161 fn serve(
162 &self,
163 mut ctx: Context<State>,
164 req: Request,
165 ) -> impl Future<Output = Result<Self::Response, Self::Error>> + Send + '_ {
166 ctx.insert(self.value.clone());
167 self.inner.serve(ctx, req)
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Context, service::service_fn};
    use std::{convert::Infallible, sync::Arc};

    /// Extension payload; the `i32` lets each test check which value arrived.
    struct State(i32);

    #[tokio::test]
    async fn basic() {
        let state = Arc::new(State(1));

        let svc = AddExtensionLayer::new(state).into_layer(service_fn(
            async |ctx: Context<()>, _req: ()| {
                let state = ctx.get::<Arc<State>>().unwrap();
                Ok::<_, Infallible>(state.0)
            },
        ));

        let res = svc.serve(Context::default(), ()).await.unwrap();

        assert_eq!(1, res);
    }

    /// `Layer::layer` borrows the layer and clones its value, so the same
    /// layer can wrap several services.
    #[tokio::test]
    async fn layer_by_ref_is_reusable() {
        let ext_layer = AddExtensionLayer::new(Arc::new(State(2)));

        let first = ext_layer.layer(service_fn(async |ctx: Context<()>, _req: ()| {
            Ok::<_, Infallible>(ctx.get::<Arc<State>>().unwrap().0)
        }));
        let second = ext_layer.layer(service_fn(async |ctx: Context<()>, _req: ()| {
            Ok::<_, Infallible>(ctx.get::<Arc<State>>().unwrap().0 * 10)
        }));

        assert_eq!(2, first.serve(Context::default(), ()).await.unwrap());
        assert_eq!(20, second.serve(Context::default(), ()).await.unwrap());
    }

    /// `AddExtension::new` behaves like the layer-built service.
    #[tokio::test]
    async fn new_service_inserts_value() {
        let svc = AddExtension::new(
            service_fn(async |ctx: Context<()>, _req: ()| {
                Ok::<_, Infallible>(ctx.get::<Arc<State>>().unwrap().0)
            }),
            Arc::new(State(3)),
        );

        assert_eq!(3, svc.serve(Context::default(), ()).await.unwrap());
    }
}