rama_http/layer/auth/add_authorization.rs
1//! Add authorization to requests using the [`Authorization`] header.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use bytes::Bytes;
9//!
10//! use rama_http::layer::validate_request::{ValidateRequestHeader, ValidateRequestHeaderLayer};
11//! use rama_http::layer::auth::AddAuthorizationLayer;
12//! use rama_http::{Body, Request, Response, StatusCode, header::AUTHORIZATION};
13//! use rama_core::service::service_fn;
14//! use rama_core::{Context, Service, Layer};
15//! use rama_core::error::BoxError;
16//!
17//! # async fn handle(request: Request) -> Result<Response, BoxError> {
18//! # Ok(Response::new(Body::default()))
19//! # }
20//!
21//! # #[tokio::main]
22//! # async fn main() -> Result<(), BoxError> {
23//! # let service_that_requires_auth = ValidateRequestHeader::basic(
24//! # service_fn(handle),
25//! # "username",
26//! # "password",
27//! # );
28//! let mut client = (
29//! // Use basic auth with the given username and password
30//! AddAuthorizationLayer::basic("username", "password"),
31//! ).layer(service_that_requires_auth);
32//!
33//! // Make a request, we don't have to add the `Authorization` header manually
34//! let response = client
35//! .serve(Context::default(), Request::new(Body::default()))
36//! .await?;
37//!
38//! assert_eq!(StatusCode::OK, response.status());
39//! # Ok(())
40//! # }
41//! ```
42
43use crate::{HeaderValue, Request, Response};
44use base64::Engine as _;
45use rama_core::{Context, Layer, Service};
46use rama_utils::macros::define_inner_service_accessors;
47use std::convert::TryFrom;
48use std::fmt;
49
/// Base64 engine used by [`AddAuthorizationLayer::basic`] to encode the
/// `{username}:{password}` credentials of a `Basic` authorization header.
///
/// This is the `base64` crate's `general_purpose::STANDARD` engine.
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
51
52/// Layer that applies [`AddAuthorization`] which adds authorization to all requests using the
53/// [`Authorization`] header.
54///
55/// See the [module docs](crate::layer::auth::add_authorization) for an example.
56///
57/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this
58/// middleware.
59///
60/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
61/// [`SetRequestHeader`]: crate::layer::set_header::SetRequestHeader
62#[derive(Debug, Clone)]
63pub struct AddAuthorizationLayer {
64 value: Option<HeaderValue>,
65 if_not_present: bool,
66}
67
68impl AddAuthorizationLayer {
69 /// Create a new [`AddAuthorizationLayer`] that does not add any authorization.
70 ///
71 /// Can be useful if you only want to add authorization for some branches
72 /// of your service.
73 pub fn none() -> Self {
74 Self {
75 value: None,
76 if_not_present: false,
77 }
78 }
79
80 /// Authorize requests using a username and password pair.
81 ///
82 /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is
83 /// `base64_encode("{username}:{password}")`.
84 ///
85 /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
86 /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
87 pub fn basic(username: &str, password: &str) -> Self {
88 let encoded = BASE64.encode(format!("{}:{}", username, password));
89 let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap();
90 Self {
91 value: Some(value),
92 if_not_present: false,
93 }
94 }
95
96 /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
97 ///
98 /// The `Authorization` header will be set to `Bearer {token}`.
99 ///
100 /// # Panics
101 ///
102 /// Panics if the token is not a valid [`HeaderValue`].
103 pub fn bearer(token: &str) -> Self {
104 let value =
105 HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header");
106 Self {
107 value: Some(value),
108 if_not_present: false,
109 }
110 }
111
112 /// Mark the header as [sensitive].
113 ///
114 /// This can for example be used to hide the header value from logs.
115 ///
116 /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
117 pub fn as_sensitive(mut self, sensitive: bool) -> Self {
118 if let Some(value) = &mut self.value {
119 value.set_sensitive(sensitive);
120 }
121 self
122 }
123
124 /// Mark the header as [sensitive].
125 ///
126 /// This can for example be used to hide the header value from logs.
127 ///
128 /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
129 pub fn set_as_sensitive(&mut self, sensitive: bool) -> &mut Self {
130 if let Some(value) = &mut self.value {
131 value.set_sensitive(sensitive);
132 }
133 self
134 }
135
136 /// Preserve the existing `Authorization` header if it exists.
137 ///
138 /// This can be useful if you want to use different authorization headers for different requests.
139 pub fn if_not_present(mut self, value: bool) -> Self {
140 self.if_not_present = value;
141 self
142 }
143
144 /// Preserve the existing `Authorization` header if it exists.
145 ///
146 /// This can be useful if you want to use different authorization headers for different requests.
147 pub fn set_if_not_present(&mut self, value: bool) -> &mut Self {
148 self.if_not_present = value;
149 self
150 }
151}
152
153impl<S> Layer<S> for AddAuthorizationLayer {
154 type Service = AddAuthorization<S>;
155
156 fn layer(&self, inner: S) -> Self::Service {
157 AddAuthorization {
158 inner,
159 value: self.value.clone(),
160 if_not_present: self.if_not_present,
161 }
162 }
163
164 fn into_layer(self, inner: S) -> Self::Service {
165 AddAuthorization {
166 inner,
167 value: self.value,
168 if_not_present: self.if_not_present,
169 }
170 }
171}
172
173/// Middleware that adds authorization all requests using the [`Authorization`] header.
174///
175/// See the [module docs](crate::layer::auth::add_authorization) for an example.
176///
177/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this
178/// middleware.
179///
180/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
181/// [`SetRequestHeader`]: crate::layer::set_header::SetRequestHeader
182pub struct AddAuthorization<S> {
183 inner: S,
184 value: Option<HeaderValue>,
185 if_not_present: bool,
186}
187
188impl<S> AddAuthorization<S> {
189 /// Create a new [`AddAuthorization`] that does not add any authorization.
190 ///
191 /// Can be useful if you only want to add authorization for some branches
192 /// of your service.
193 pub fn none(inner: S) -> Self {
194 AddAuthorizationLayer::none().layer(inner)
195 }
196
197 /// Authorize requests using a username and password pair.
198 ///
199 /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is
200 /// `base64_encode("{username}:{password}")`.
201 ///
202 /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
203 /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
204 pub fn basic(inner: S, username: &str, password: &str) -> Self {
205 AddAuthorizationLayer::basic(username, password).layer(inner)
206 }
207
208 /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
209 ///
210 /// The `Authorization` header will be set to `Bearer {token}`.
211 ///
212 /// # Panics
213 ///
214 /// Panics if the token is not a valid [`HeaderValue`].
215 pub fn bearer(inner: S, token: &str) -> Self {
216 AddAuthorizationLayer::bearer(token).layer(inner)
217 }
218
219 define_inner_service_accessors!();
220
221 /// Mark the header as [sensitive].
222 ///
223 /// This can for example be used to hide the header value from logs.
224 ///
225 /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
226 pub fn as_sensitive(mut self, sensitive: bool) -> Self {
227 if let Some(value) = &mut self.value {
228 value.set_sensitive(sensitive);
229 }
230 self
231 }
232
233 /// Mark the header as [sensitive].
234 ///
235 /// This can for example be used to hide the header value from logs.
236 ///
237 /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
238 pub fn set_as_sensitive(&mut self, sensitive: bool) -> &mut Self {
239 if let Some(value) = &mut self.value {
240 value.set_sensitive(sensitive);
241 }
242 self
243 }
244
245 /// Preserve the existing `Authorization` header if it exists.
246 ///
247 /// This can be useful if you want to use different authorization headers for different requests.
248 pub fn if_not_present(mut self, value: bool) -> Self {
249 self.if_not_present = value;
250 self
251 }
252
253 /// Preserve the existing `Authorization` header if it exists.
254 ///
255 /// This can be useful if you want to use different authorization headers for different requests.
256 pub fn set_if_not_present(&mut self, value: bool) -> &mut Self {
257 self.if_not_present = value;
258 self
259 }
260}
261
262impl<S: fmt::Debug> fmt::Debug for AddAuthorization<S> {
263 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264 f.debug_struct("AddAuthorization")
265 .field("inner", &self.inner)
266 .field("value", &self.value)
267 .field("if_not_present", &self.if_not_present)
268 .finish()
269 }
270}
271
272impl<S: Clone> Clone for AddAuthorization<S> {
273 fn clone(&self) -> Self {
274 AddAuthorization {
275 inner: self.inner.clone(),
276 value: self.value.clone(),
277 if_not_present: self.if_not_present,
278 }
279 }
280}
281
282impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for AddAuthorization<S>
283where
284 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
285 ReqBody: Send + 'static,
286 ResBody: Send + 'static,
287 State: Clone + Send + Sync + 'static,
288{
289 type Response = S::Response;
290 type Error = S::Error;
291
292 async fn serve(
293 &self,
294 ctx: Context<State>,
295 mut req: Request<ReqBody>,
296 ) -> Result<Self::Response, Self::Error> {
297 if let Some(value) = &self.value {
298 if !self.if_not_present
299 || !req
300 .headers()
301 .contains_key(rama_http_types::header::AUTHORIZATION)
302 {
303 req.headers_mut()
304 .insert(rama_http_types::header::AUTHORIZATION, value.clone());
305 }
306 }
307 self.inner.serve(ctx, req).await
308 }
309}
310
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{Body, Request, Response, StatusCode};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;

    #[tokio::test]
    async fn basic() {
        // service that requires auth for all requests
        let svc = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        // make a client that adds auth
        let client = AddAuthorization::basic(svc, "foo", "bar");

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn token() {
        // service that requires auth for all requests
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));

        // make a client that adds auth
        let client = AddAuthorization::bearer(svc, "foo");

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn making_header_sensitive() {
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(
            async |request: Request<Body>| {
                let auth = request
                    .headers()
                    .get(rama_http_types::header::AUTHORIZATION)
                    .unwrap();
                assert!(auth.is_sensitive());

                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));

        let client = AddAuthorization::bearer(svc, "foo").as_sensitive(true);

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn none_does_not_add_header() {
        let svc = service_fn(async |request: Request<Body>| {
            assert!(
                !request
                    .headers()
                    .contains_key(rama_http_types::header::AUTHORIZATION)
            );

            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let client = AddAuthorization::none(svc);

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn overrides_existing_header_by_default() {
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));

        // the caller-supplied (wrong) header must be replaced by the configured one
        let client = AddAuthorization::bearer(svc, "foo");

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            rama_http_types::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer wrong"),
        );

        let res = client.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn if_not_present_preserves_existing_header() {
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));

        // the configured token is wrong, so the request only passes if the
        // caller-supplied header is kept
        let client = AddAuthorization::bearer(svc, "wrong").if_not_present(true);

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            rama_http_types::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer foo"),
        );

        let res = client.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn if_not_present_adds_missing_header() {
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));

        let client = AddAuthorization::bearer(svc, "foo").if_not_present(true);

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}