rama_http/layer/required_header/
response.rs1use crate::{
6 HeaderValue, Request, Response,
7 header::{self, DATE, RAMA_ID_HEADER_VALUE, SERVER},
8 headers::{Date, HeaderMapExt},
9};
10use rama_core::{Context, Layer, Service};
11use rama_utils::macros::define_inner_service_accessors;
12use std::{fmt, time::SystemTime};
13
14#[derive(Debug, Clone, Default)]
18pub struct AddRequiredResponseHeadersLayer {
19 overwrite: bool,
20 server_header_value: Option<HeaderValue>,
21}
22
23impl AddRequiredResponseHeadersLayer {
24 pub const fn new() -> Self {
26 Self {
27 overwrite: false,
28 server_header_value: None,
29 }
30 }
31
32 pub const fn overwrite(mut self, overwrite: bool) -> Self {
37 self.overwrite = overwrite;
38 self
39 }
40
41 pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
46 self.overwrite = overwrite;
47 self
48 }
49
50 pub fn server_header_value(mut self, value: HeaderValue) -> Self {
54 self.server_header_value = Some(value);
55 self
56 }
57
58 pub fn maybe_server_header_value(mut self, value: Option<HeaderValue>) -> Self {
62 self.server_header_value = value;
63 self
64 }
65
66 pub fn set_server_header_value(&mut self, value: HeaderValue) -> &mut Self {
70 self.server_header_value = Some(value);
71 self
72 }
73}
74
75impl<S> Layer<S> for AddRequiredResponseHeadersLayer {
76 type Service = AddRequiredResponseHeaders<S>;
77
78 fn layer(&self, inner: S) -> Self::Service {
79 AddRequiredResponseHeaders {
80 inner,
81 overwrite: self.overwrite,
82 server_header_value: self.server_header_value.clone(),
83 }
84 }
85
86 fn into_layer(self, inner: S) -> Self::Service {
87 AddRequiredResponseHeaders {
88 inner,
89 overwrite: self.overwrite,
90 server_header_value: self.server_header_value,
91 }
92 }
93}
94
95#[derive(Clone)]
97pub struct AddRequiredResponseHeaders<S> {
98 inner: S,
99 overwrite: bool,
100 server_header_value: Option<HeaderValue>,
101}
102
103impl<S> AddRequiredResponseHeaders<S> {
104 pub const fn new(inner: S) -> Self {
106 Self {
107 inner,
108 overwrite: false,
109 server_header_value: None,
110 }
111 }
112
113 pub const fn overwrite(mut self, overwrite: bool) -> Self {
118 self.overwrite = overwrite;
119 self
120 }
121
122 pub fn set_overwrite(&mut self, overwrite: bool) -> &mut Self {
127 self.overwrite = overwrite;
128 self
129 }
130
131 pub fn server_header_value(mut self, value: HeaderValue) -> Self {
135 self.server_header_value = Some(value);
136 self
137 }
138
139 pub fn maybe_server_header_value(mut self, value: Option<HeaderValue>) -> Self {
143 self.server_header_value = value;
144 self
145 }
146
147 pub fn set_server_header_value(&mut self, value: HeaderValue) -> &mut Self {
151 self.server_header_value = Some(value);
152 self
153 }
154
155 define_inner_service_accessors!();
156}
157
158impl<S> fmt::Debug for AddRequiredResponseHeaders<S>
159where
160 S: fmt::Debug,
161{
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 f.debug_struct("AddRequiredResponseHeaders")
164 .field("inner", &self.inner)
165 .field("server_header_value", &self.server_header_value)
166 .finish()
167 }
168}
169
170impl<ReqBody, ResBody, State, S> Service<State, Request<ReqBody>> for AddRequiredResponseHeaders<S>
171where
172 ReqBody: Send + 'static,
173 ResBody: Send + 'static,
174 State: Clone + Send + Sync + 'static,
175 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
176{
177 type Response = S::Response;
178 type Error = S::Error;
179
180 async fn serve(
181 &self,
182 ctx: Context<State>,
183 req: Request<ReqBody>,
184 ) -> Result<Self::Response, Self::Error> {
185 let mut resp = self.inner.serve(ctx, req).await?;
186
187 if self.overwrite {
188 resp.headers_mut().insert(
189 SERVER,
190 self.server_header_value
191 .as_ref()
192 .unwrap_or(&RAMA_ID_HEADER_VALUE)
193 .clone(),
194 );
195 } else if let header::Entry::Vacant(header) = resp.headers_mut().entry(SERVER) {
196 header.insert(
197 self.server_header_value
198 .as_ref()
199 .unwrap_or(&RAMA_ID_HEADER_VALUE)
200 .clone(),
201 );
202 }
203
204 if self.overwrite || !resp.headers().contains_key(DATE) {
205 resp.headers_mut()
206 .typed_insert(Date::from(SystemTime::now()));
207 }
208
209 Ok(resp)
210 }
211}
212
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Body;
    use rama_core::{Layer, service::service_fn};
    use std::convert::Infallible;

    /// Default layer on a bare response: both headers get added, and `Server`
    /// uses the rama default value.
    #[tokio::test]
    async fn add_required_response_headers() {
        let svc = AddRequiredResponseHeadersLayer::default().into_layer(service_fn(
            async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).unwrap(),
            RAMA_ID_HEADER_VALUE.as_ref()
        );
        assert!(resp.headers().contains_key(DATE));
    }

    /// A configured server value is used when the response has no `Server` header.
    #[tokio::test]
    async fn add_required_response_headers_custom_server() {
        let svc = AddRequiredResponseHeadersLayer::default()
            .server_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(Response::new(Body::empty()))
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).and_then(|v| v.to_str().ok()),
            Some("foo")
        );
        assert!(resp.headers().contains_key(DATE));
    }

    /// Without `overwrite`, headers already set by the inner service are kept
    /// untouched, even when a custom server value is configured.
    #[tokio::test]
    async fn add_required_response_headers_keeps_existing() {
        let svc = AddRequiredResponseHeadersLayer::default()
            .server_header_value(HeaderValue::from_static("custom"))
            .into_layer(service_fn(async |_ctx: Context<()>, _req: Request| {
                Ok::<_, Infallible>(
                    Response::builder()
                        .header(SERVER, "foo")
                        .header(DATE, "bar")
                        .body(Body::empty())
                        .unwrap(),
                )
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(resp.headers().get(SERVER).unwrap(), "foo");
        assert_eq!(resp.headers().get(DATE).unwrap(), "bar");
    }

    /// With `overwrite`, existing headers are replaced by the defaults.
    #[tokio::test]
    async fn add_required_response_headers_overwrite() {
        let svc = AddRequiredResponseHeadersLayer::new()
            .overwrite(true)
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(
                    Response::builder()
                        .header(SERVER, "foo")
                        .header(DATE, "bar")
                        .body(Body::empty())
                        .unwrap(),
                )
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).unwrap(),
            RAMA_ID_HEADER_VALUE.to_str().unwrap()
        );
        assert_ne!(resp.headers().get(DATE).unwrap(), "bar");
    }

    /// With `overwrite` and a custom server value, the inner service's `Server`
    /// header is replaced by the custom one. The inner value deliberately
    /// differs from the configured one, so the test fails if nothing is overwritten.
    #[tokio::test]
    async fn add_required_response_headers_overwrite_custom_server() {
        let svc = AddRequiredResponseHeadersLayer::new()
            .overwrite(true)
            .server_header_value(HeaderValue::from_static("foo"))
            .into_layer(service_fn(async |_ctx: Context<()>, req: Request| {
                assert!(!req.headers().contains_key(SERVER));
                assert!(!req.headers().contains_key(DATE));
                Ok::<_, Infallible>(
                    Response::builder()
                        .header(SERVER, "baz")
                        .header(DATE, "bar")
                        .body(Body::empty())
                        .unwrap(),
                )
            }));

        let req = Request::new(Body::empty());
        let resp = svc.serve(Context::default(), req).await.unwrap();

        assert_eq!(
            resp.headers().get(SERVER).and_then(|v| v.to_str().ok()),
            Some("foo")
        );
        assert_eq!(resp.headers().get_all(SERVER).iter().count(), 1);
        assert_ne!(resp.headers().get(DATE).unwrap(), "bar");
    }
}