rama_http/layer/auth/
add_authorization.rs

1//! Add authorization to requests using the [`Authorization`] header.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use bytes::Bytes;
9//!
10//! use rama_http::layer::validate_request::{ValidateRequestHeader, ValidateRequestHeaderLayer};
11//! use rama_http::layer::auth::AddAuthorizationLayer;
12//! use rama_http::{Body, Request, Response, StatusCode, header::AUTHORIZATION};
13//! use rama_core::service::service_fn;
14//! use rama_core::{Context, Service, Layer};
15//! use rama_core::error::BoxError;
16//!
17//! # async fn handle(request: Request) -> Result<Response, BoxError> {
18//! #     Ok(Response::new(Body::default()))
19//! # }
20//!
21//! # #[tokio::main]
22//! # async fn main() -> Result<(), BoxError> {
23//! # let service_that_requires_auth = ValidateRequestHeader::basic(
24//! #     service_fn(handle),
25//! #     "username",
26//! #     "password",
27//! # );
28//! let mut client = (
29//!     // Use basic auth with the given username and password
30//!     AddAuthorizationLayer::basic("username", "password"),
31//! ).layer(service_that_requires_auth);
32//!
33//! // Make a request, we don't have to add the `Authorization` header manually
34//! let response = client
35//!     .serve(Context::default(), Request::new(Body::default()))
36//!     .await?;
37//!
38//! assert_eq!(StatusCode::OK, response.status());
39//! # Ok(())
40//! # }
41//! ```
42
43use crate::{HeaderValue, Request, Response};
44use base64::Engine as _;
45use rama_core::{Context, Layer, Service};
46use rama_utils::macros::define_inner_service_accessors;
47use std::convert::TryFrom;
48use std::fmt;
49
/// Base64 engine (the `base64` crate's `general_purpose::STANDARD`) used by
/// [`AddAuthorizationLayer::basic`] to encode `username:password` credentials.
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
51
/// Layer that applies [`AddAuthorization`] which adds authorization to all requests using the
/// [`Authorization`] header.
///
/// See the [module docs](crate::layer::auth::add_authorization) for an example.
///
/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this
/// middleware.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
/// [`SetRequestHeader`]: crate::layer::set_header::SetRequestHeader
#[derive(Debug, Clone)]
pub struct AddAuthorizationLayer {
    // Header value to insert into each request; `None` means no authorization is added.
    value: Option<HeaderValue>,
    // When `true`, an `Authorization` header already present on the request is kept as-is.
    if_not_present: bool,
}
67
68impl AddAuthorizationLayer {
69    /// Create a new [`AddAuthorizationLayer`] that does not add any authorization.
70    ///
71    /// Can be useful if you only want to add authorization for some branches
72    /// of your service.
73    pub fn none() -> Self {
74        Self {
75            value: None,
76            if_not_present: false,
77        }
78    }
79
80    /// Authorize requests using a username and password pair.
81    ///
82    /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is
83    /// `base64_encode("{username}:{password}")`.
84    ///
85    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
86    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
87    pub fn basic(username: &str, password: &str) -> Self {
88        let encoded = BASE64.encode(format!("{}:{}", username, password));
89        let value = HeaderValue::try_from(format!("Basic {}", encoded)).unwrap();
90        Self {
91            value: Some(value),
92            if_not_present: false,
93        }
94    }
95
96    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
97    ///
98    /// The `Authorization` header will be set to `Bearer {token}`.
99    ///
100    /// # Panics
101    ///
102    /// Panics if the token is not a valid [`HeaderValue`].
103    pub fn bearer(token: &str) -> Self {
104        let value =
105            HeaderValue::try_from(format!("Bearer {}", token)).expect("token is not valid header");
106        Self {
107            value: Some(value),
108            if_not_present: false,
109        }
110    }
111
112    /// Mark the header as [sensitive].
113    ///
114    /// This can for example be used to hide the header value from logs.
115    ///
116    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
117    pub fn as_sensitive(mut self, sensitive: bool) -> Self {
118        if let Some(value) = &mut self.value {
119            value.set_sensitive(sensitive);
120        }
121        self
122    }
123
124    /// Mark the header as [sensitive].
125    ///
126    /// This can for example be used to hide the header value from logs.
127    ///
128    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
129    pub fn set_as_sensitive(&mut self, sensitive: bool) -> &mut Self {
130        if let Some(value) = &mut self.value {
131            value.set_sensitive(sensitive);
132        }
133        self
134    }
135
136    /// Preserve the existing `Authorization` header if it exists.
137    ///
138    /// This can be useful if you want to use different authorization headers for different requests.
139    pub fn if_not_present(mut self, value: bool) -> Self {
140        self.if_not_present = value;
141        self
142    }
143
144    /// Preserve the existing `Authorization` header if it exists.
145    ///
146    /// This can be useful if you want to use different authorization headers for different requests.
147    pub fn set_if_not_present(&mut self, value: bool) -> &mut Self {
148        self.if_not_present = value;
149        self
150    }
151}
152
153impl<S> Layer<S> for AddAuthorizationLayer {
154    type Service = AddAuthorization<S>;
155
156    fn layer(&self, inner: S) -> Self::Service {
157        AddAuthorization {
158            inner,
159            value: self.value.clone(),
160            if_not_present: self.if_not_present,
161        }
162    }
163
164    fn into_layer(self, inner: S) -> Self::Service {
165        AddAuthorization {
166            inner,
167            value: self.value,
168            if_not_present: self.if_not_present,
169        }
170    }
171}
172
173/// Middleware that adds authorization all requests using the [`Authorization`] header.
174///
175/// See the [module docs](crate::layer::auth::add_authorization) for an example.
176///
177/// You can also use [`SetRequestHeader`] if you have a use case that isn't supported by this
178/// middleware.
179///
180/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
181/// [`SetRequestHeader`]: crate::layer::set_header::SetRequestHeader
182pub struct AddAuthorization<S> {
183    inner: S,
184    value: Option<HeaderValue>,
185    if_not_present: bool,
186}
187
188impl<S> AddAuthorization<S> {
189    /// Create a new [`AddAuthorization`] that does not add any authorization.
190    ///
191    /// Can be useful if you only want to add authorization for some branches
192    /// of your service.
193    pub fn none(inner: S) -> Self {
194        AddAuthorizationLayer::none().layer(inner)
195    }
196
197    /// Authorize requests using a username and password pair.
198    ///
199    /// The `Authorization` header will be set to `Basic {credentials}` where `credentials` is
200    /// `base64_encode("{username}:{password}")`.
201    ///
202    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
203    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
204    pub fn basic(inner: S, username: &str, password: &str) -> Self {
205        AddAuthorizationLayer::basic(username, password).layer(inner)
206    }
207
208    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
209    ///
210    /// The `Authorization` header will be set to `Bearer {token}`.
211    ///
212    /// # Panics
213    ///
214    /// Panics if the token is not a valid [`HeaderValue`].
215    pub fn bearer(inner: S, token: &str) -> Self {
216        AddAuthorizationLayer::bearer(token).layer(inner)
217    }
218
219    define_inner_service_accessors!();
220
221    /// Mark the header as [sensitive].
222    ///
223    /// This can for example be used to hide the header value from logs.
224    ///
225    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
226    pub fn as_sensitive(mut self, sensitive: bool) -> Self {
227        if let Some(value) = &mut self.value {
228            value.set_sensitive(sensitive);
229        }
230        self
231    }
232
233    /// Mark the header as [sensitive].
234    ///
235    /// This can for example be used to hide the header value from logs.
236    ///
237    /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive
238    pub fn set_as_sensitive(&mut self, sensitive: bool) -> &mut Self {
239        if let Some(value) = &mut self.value {
240            value.set_sensitive(sensitive);
241        }
242        self
243    }
244
245    /// Preserve the existing `Authorization` header if it exists.
246    ///
247    /// This can be useful if you want to use different authorization headers for different requests.
248    pub fn if_not_present(mut self, value: bool) -> Self {
249        self.if_not_present = value;
250        self
251    }
252
253    /// Preserve the existing `Authorization` header if it exists.
254    ///
255    /// This can be useful if you want to use different authorization headers for different requests.
256    pub fn set_if_not_present(&mut self, value: bool) -> &mut Self {
257        self.if_not_present = value;
258        self
259    }
260}
261
262impl<S: fmt::Debug> fmt::Debug for AddAuthorization<S> {
263    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264        f.debug_struct("AddAuthorization")
265            .field("inner", &self.inner)
266            .field("value", &self.value)
267            .field("if_not_present", &self.if_not_present)
268            .finish()
269    }
270}
271
272impl<S: Clone> Clone for AddAuthorization<S> {
273    fn clone(&self) -> Self {
274        AddAuthorization {
275            inner: self.inner.clone(),
276            value: self.value.clone(),
277            if_not_present: self.if_not_present,
278        }
279    }
280}
281
282impl<S, State, ReqBody, ResBody> Service<State, Request<ReqBody>> for AddAuthorization<S>
283where
284    S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
285    ReqBody: Send + 'static,
286    ResBody: Send + 'static,
287    State: Clone + Send + Sync + 'static,
288{
289    type Response = S::Response;
290    type Error = S::Error;
291
292    async fn serve(
293        &self,
294        ctx: Context<State>,
295        mut req: Request<ReqBody>,
296    ) -> Result<Self::Response, Self::Error> {
297        if let Some(value) = &self.value {
298            if !self.if_not_present
299                || !req
300                    .headers()
301                    .contains_key(rama_http_types::header::AUTHORIZATION)
302            {
303                req.headers_mut()
304                    .insert(rama_http_types::header::AUTHORIZATION, value.clone());
305            }
306        }
307        self.inner.serve(ctx, req).await
308    }
309}
310
#[cfg(test)]
mod tests {
    // `super::*` pulls in the parent module's private imports (`Layer`, `HeaderValue`, ...);
    // not all of them are used by every test, hence the allow.
    #[allow(unused_imports)]
    use super::*;

    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{Body, Request, Response, StatusCode};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Service};
    use std::convert::Infallible;

    #[tokio::test]
    async fn basic() {
        // service that requires auth for all requests
        let svc = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        // make a client that adds auth
        let client = AddAuthorization::basic(svc, "foo", "bar");

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn token() {
        // service that requires auth for all requests
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));

        // make a client that adds auth
        let client = AddAuthorization::bearer(svc, "foo");

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn making_header_sensitive() {
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(
            async |request: Request<Body>| {
                let auth = request
                    .headers()
                    .get(rama_http_types::header::AUTHORIZATION)
                    .unwrap();
                assert!(auth.is_sensitive());

                Ok::<_, Infallible>(Response::new(Body::empty()))
            },
        ));

        let client = AddAuthorization::bearer(svc, "foo").as_sensitive(true);

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn overwrites_existing_header_by_default() {
        // the service only accepts the token configured on the client,
        // so the stale header on the request must be replaced
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));
        let client = AddAuthorization::bearer(svc, "foo");

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            rama_http_types::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer stale"),
        );

        let res = client.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn if_not_present_keeps_existing_header() {
        // the service only accepts the token already present on the request,
        // so the client must not overwrite it
        let svc = ValidateRequestHeaderLayer::bearer("existing").layer(service_fn(echo));
        let client = AddAuthorization::bearer(svc, "foo").if_not_present(true);

        let mut req = Request::new(Body::empty());
        req.headers_mut().insert(
            rama_http_types::header::AUTHORIZATION,
            HeaderValue::from_static("Bearer existing"),
        );

        let res = client.serve(Context::default(), req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn if_not_present_adds_missing_header() {
        let svc = ValidateRequestHeaderLayer::bearer("foo").layer(service_fn(echo));
        let client = AddAuthorization::bearer(svc, "foo").if_not_present(true);

        let res = client
            .serve(Context::default(), Request::new(Body::empty()))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // Generic over the body type; named `B` so it does not shadow the imported `Body` type.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}
382}