rama_http/layer/auth/
add_authorization.rs

//! Add authorization to requests using the [`Authorization`] header.
//!
//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
//!
//! # Example
//!
//! ```
//! use bytes::Bytes;
//!
//! use rama_http::layer::validate_request::{ValidateRequestHeader, ValidateRequestHeaderLayer};
//! use rama_http::layer::auth::AddAuthorizationLayer;
//! use rama_http::{Body, Request, Response, StatusCode, header::AUTHORIZATION};
//! use rama_core::service::service_fn;
//! use rama_core::{Context, Service, Layer};
//! use rama_core::error::BoxError;
//!
//! # async fn handle(request: Request) -> Result<Response, BoxError> {
//! #     Ok(Response::new(Body::default()))
//! # }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! # let service_that_requires_auth = ValidateRequestHeader::basic(
//! #     service_fn(handle),
//! #     "username",
//! #     "password",
//! # );
//! let mut client = (
//!     // Use basic auth with the given username and password
//!     AddAuthorizationLayer::basic("username", "password"),
//! ).layer(service_that_requires_auth);
//!
//! // Make a request; there is no need to add the `Authorization` header manually
//! let response = client
//!     .serve(Context::default(), Request::new(Body::default()))
//!     .await?;
//!
//! assert_eq!(StatusCode::OK, response.status());
//! # Ok(())
//! # }
//! ```
42
43use crate::{HeaderValue, Request, Response};
44use base64::Engine as _;
45use rama_core::{Context, Layer, Service};
46use rama_utils::macros::define_inner_service_accessors;
47use std::convert::TryFrom;
48use std::fmt;
49
/// Base64 engine used to encode `username:password` for the `Basic` authorization scheme
/// (standard alphabet, padded).
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
51
/// Layer that applies [`AddAuthorization`] which adds authorization to all requests using the
/// [`Authorization`] header.
///
/// See the [module docs](crate::layer::auth::add_authorization) for an example.
///
/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this
/// middleware.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
/// [`SetRequestHeader`]: crate::layer::set_header::SetRequestHeader
#[derive(Debug, Clone)]
pub struct AddAuthorizationLayer {
    /// Header value copied into every wrapped service; `None` means requests are left untouched.
    value: Option<HeaderValue>,
    /// When `true`, an `Authorization` header already present on a request is preserved.
    if_not_present: bool,
}
67
68impl AddAuthorizationLayer {
69    /// Create a new [`AddAuthorizationLayer`] that does not add any authorization.
70    ///
71    /// Can be useful if you only want to add authorization for some branches
72    /// of your service.
73    pub fn none() -> Self {
74        Self {
75            value: None,
76            if_not_present: false,
77        }
78    }
79
80    /// Authorize requests using a username and password pair.
81    ///
82    /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is
83    /// `base64_encode("{username}:{password}")`.
84    ///
85    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
86    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
87    pub fn basic(username: &str, password: &str) -> Self {
88        let encoded = BASE64.encode(format!("{}:{}", username, password));
89        let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap();
90        Self {
91            value: Some(value),
92            if_not_present: false,
93        }
94    }
95
96    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
97    ///
98    /// The `Authorization` header will be set to `Bearer {token}`.
99    ///
100    /// # Panics
101    ///
102    /// Panics if the token is not a valid [`HeaderValue`].
103    pub fn bearer(token: &str) -> Self {
104        let value =
105            HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header");
106        Self {
107            value: Some(value),
108            if_not_present: false,
109        }
110    }
111
112    /// Mark the header as [sensitive].
113    ///
114    /// This can for example be used to hide the header value from logs.
115    ///
116    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
117    pub fn as_sensitive(mut self, sensitive: bool) -> Self {
118        if let Some(value) = &mut self.value {
119            value.set_sensitive(sensitive);
120        }
121        self
122    }
123
124    /// Mark the header as [sensitive].
125    ///
126    /// This can for example be used to hide the header value from logs.
127    ///
128    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
129    pub fn set_as_sensitive(&mut self, sensitive: bool) -> &mut Self {
130        if let Some(value) = &mut self.value {
131            value.set_sensitive(sensitive);
132        }
133        self
134    }
135
136    /// Preserve the existing `Authorization` header if it exists.
137    ///
138    /// This can be useful if you want to use different authorization headers for different requests.
139    pub fn if_not_present(mut self, value: bool) -> Self {
140        self.if_not_present = value;
141        self
142    }
143
144    /// Preserve the existing `Authorization` header if it exists.
145    ///
146    /// This can be useful if you want to use different authorization headers for different requests.
147    pub fn set_if_not_present(&mut self, value: bool) -> &mut Self {
148        self.if_not_present = value;
149        self
150    }
151}
152
153impl<S> Layer<S> for AddAuthorizationLayer {
154    type Service = AddAuthorization<S>;
155
156    fn layer(&self, inner: S) -> Self::Service {
157        AddAuthorization {
158            inner,
159            value: self.value.clone(),
160            if_not_present: self.if_not_present,
161        }
162    }
163}
164
/// Middleware that adds authorization to all requests using the [`Authorization`] header.
///
/// See the [module docs](crate::layer::auth::add_authorization) for an example.
///
/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this
/// middleware.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
/// [`SetRequestHeader`]: crate::layer::set_header::SetRequestHeader
pub struct AddAuthorization<S> {
    /// The wrapped service that receives the (possibly modified) request.
    inner: S,
    /// Header value inserted into each request; `None` means requests are forwarded untouched.
    value: Option<HeaderValue>,
    /// When `true`, an `Authorization` header already present on a request is preserved.
    if_not_present: bool,
}
179
180impl<S> AddAuthorization<S> {
181    /// Create a new [`AddAuthorization`] that does not add any authorization.
182    ///
183    /// Can be useful if you only want to add authorization for some branches
184    /// of your service.
185    pub fn none(inner: S) -> Self {
186        AddAuthorizationLayer::none().layer(inner)
187    }
188
189    /// Authorize requests using a username and password pair.
190    ///
191    /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is
192    /// `base64_encode("{username}:{password}")`.
193    ///
194    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
195    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
196    pub fn basic(inner: S, username: &str, password: &str) -> Self {
197        AddAuthorizationLayer::basic(username, password).layer(inner)
198    }
199
200    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
201    ///
202    /// The `Authorization` header will be set to `Bearer {token}`.
203    ///
204    /// # Panics
205    ///
206    /// Panics if the token is not a valid [`HeaderValue`].
207    pub fn bearer(inner: S, token: &str) -> Self {
208        AddAuthorizationLayer::bearer(token).layer(inner)
209    }
210
211    define_inner_service_accessors!();
212
213    /// Mark the header as [sensitive].
214    ///
215    /// This can for example be used to hide the header value from logs.
216    ///
217    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
218    pub fn as_sensitive(mut self, sensitive: bool) -> Self {
219        if let Some(value) = &mut self.value {
220            value.set_sensitive(sensitive);
221        }
222        self
223    }
224
225    /// Mark the header as [sensitive].
226    ///
227    /// This can for example be used to hide the header value from logs.
228    ///
229    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
230    pub fn set_as_sensitive(&mut self, sensitive: bool) -> &mut Self {
231        if let Some(value) = &mut self.value {
232            value.set_sensitive(sensitive);
233        }
234        self
235    }
236
237    /// Preserve the existing `Authorization` header if it exists.
238    ///
239    /// This can be useful if you want to use different authorization headers for different requests.
240    pub fn if_not_present(mut self, value: bool) -> Self {
241        self.if_not_present = value;
242        self
243    }
244
245    /// Preserve the existing `Authorization` header if it exists.
246    ///
247    /// This can be useful if you want to use different authorization headers for different requests.
248    pub fn set_if_not_present(&mut self, value: bool) -> &mut Self {
249        self.if_not_present = value;
250        self
251    }
252}
253
254impl<S: fmt::Debug> fmt::Debug for AddAuthorization<S> {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        f.debug_struct("AddAuthorization")
257            .field("inner", &self.inner)
258            .field("value", &self.value)
259            .field("if_not_present", &self.if_not_present)
260            .finish()
261    }
262}
263
264impl<S: Clone> Clone for AddAuthorization<S> {
265    fn clone(&self) -> Self {
266        AddAuthorization {
267            inner: self.inner.clone(),
268            value: self.value.clone(),
269            if_not_present: self.if_not_present,
270        }
271    }
272}
273
274impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for AddAuthorization<S>
275where
276    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
277    ReqBody: Send + 'static,
278    ResBody: Send + 'static,
279    State: Clone + Send + Sync + 'static,
280{
281    type Response = S::Response;
282    type Error = S::Error;
283
284    async fn serve(
285        &self,
286        ctx: Context<State>,
287        mut req: Request<ReqBody>,
288    ) -> Result<Self::Response, Self::Error> {
289        if let Some(value) = &self.value {
290            if !self.if_not_present || !req.headers().contains_key(http::header::AUTHORIZATION) {
291                req.headers_mut()
292                    .insert(http::header::AUTHORIZATION, value.clone());
293            }
294        }
295        self.inner.serve(ctx, req).await
296    }
297}
298
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{Body, Request, Response, StatusCode};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;

    #[tokio::test]
    async fn basic() {
        // service that requires auth for all requests
        let svc = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        // make a client that adds auth
        let client = AddAuthorization::basic(svc, "foo", "bar");

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn token() {
        // service that requires auth for all requests
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));

        // make a client that adds auth
        let client = AddAuthorization::bearer(svc, "foo");

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn making_header_sensitive() {
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(
            |request: Request<Body>| async move {
                let auth = request.headers().get(http::header::AUTHORIZATION).unwrap();
                assert!(auth.is_sensitive());

                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));

        let client = AddAuthorization::bearer(svc, "foo").as_sensitive(true);

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn none_does_not_add_header() {
        let svc = service_fn(|request: Request<Body>| async move {
            assert!(request.headers().get(http::header::AUTHORIZATION).is_none());

            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let client = AddAuthorization::none(svc);

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn if_not_present_preserves_existing_header() {
        // the server only accepts `foo`, while the client is configured with `bar`:
        // the request only succeeds if the pre-existing header is kept
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));
        let client = AddAuthorization::bearer(svc, "bar").if_not_present(true);

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer foo"),
        );

        let res = client.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn overrides_existing_header_by_default() {
        // the request carries a wrong token: the request only succeeds if the
        // client replaces it with the configured one
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));
        let client = AddAuthorization::bearer(svc, "foo");

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            http::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer wrong"),
        );

        let res = client.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}