rama_http/layer/auth/
require_authorization.rs

//! Authorize requests based on their [`Authorization`] header, using [`ValidateRequest`].
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use bytes::Bytes;
9//!
10//! use rama_http::layer::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
11//! use rama_http::{Body, Request, Response, StatusCode, header::AUTHORIZATION};
12//! use rama_core::service::service_fn;
13//! use rama_core::{Context, Service, Layer};
14//! use rama_core::error::BoxError;
15//!
16//! async fn handle(request: Request) -> Result<Response, BoxError> {
17//!     Ok(Response::new(Body::default()))
18//! }
19//!
20//! # #[tokio::main]
21//! # async fn main() -> Result<(), BoxError> {
22//! let mut service = (
23//!     // Require the `Authorization` header to be `Bearer passwordlol`
24//!     ValidateRequestHeaderLayer::bearer("passwordlol"),
25//! ).layer(service_fn(handle));
26//!
27//! // Requests with the correct token are allowed through
28//! let request = Request::builder()
29//!     .header(AUTHORIZATION, "Bearer passwordlol")
30//!     .body(Body::default())
31//!     .unwrap();
32//!
33//! let response = service
34//!     .serve(Context::default(), request)
35//!     .await?;
36//!
37//! assert_eq!(StatusCode::OK, response.status());
38//!
//! // Requests without a valid token get a `401 Unauthorized` response
40//! let request = Request::builder()
41//!     .body(Body::default())
42//!     .unwrap();
43//!
44//! let response = service
45//!     .serve(Context::default(), request)
46//!     .await?;
47//!
48//! assert_eq!(StatusCode::UNAUTHORIZED, response.status());
49//! # Ok(())
50//! # }
51//! ```
52//!
53//! Custom validation can be made by implementing [`ValidateRequest`].
54
55use crate::layer::validate_request::{
56    ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer,
57};
58use crate::{
59    Request, Response, StatusCode,
60    header::{self, HeaderValue},
61};
62use base64::Engine as _;
63use rama_core::Context;
64use std::{fmt, marker::PhantomData, sync::Arc};
65
66use rama_net::user::UserId;
67
/// Base64 engine (standard alphabet, padded) used to encode `Basic` credentials.
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
69
70impl<C> ValidateRequestHeaderLayer<AuthorizeContext<C>> {
71    /// Allow anonymous requests.
72    pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
73        self.validate.allow_anonymous = allow_anonymous;
74        self
75    }
76
77    /// Allow anonymous requests.
78    pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
79        self.validate.allow_anonymous = allow_anonymous;
80        self
81    }
82}
83
84impl<S, C> ValidateRequestHeader<S, AuthorizeContext<C>> {
85    /// Allow anonymous requests.
86    pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
87        self.validate.allow_anonymous = allow_anonymous;
88        self
89    }
90
91    /// Allow anonymous requests.
92    pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
93        self.validate.allow_anonymous = allow_anonymous;
94        self
95    }
96}
97
98impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Basic<ResBody>>> {
99    /// Authorize requests using a username and password pair.
100    ///
101    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
102    /// `base64_encode("{username}:{password}")`.
103    ///
104    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
105    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
106    pub fn basic(inner: S, username: &str, value: &str) -> Self
107    where
108        ResBody: Default,
109    {
110        Self::custom(inner, AuthorizeContext::new(Basic::new(username, value)))
111    }
112}
113
114impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Basic<ResBody>>> {
115    /// Authorize requests using a username and password pair.
116    ///
117    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
118    /// `base64_encode("{username}:{password}")`.
119    ///
120    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
121    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
122    pub fn basic(username: &str, password: &str) -> Self
123    where
124        ResBody: Default,
125    {
126        Self::custom(AuthorizeContext::new(Basic::new(username, password)))
127    }
128}
129
130impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Bearer<ResBody>>> {
131    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
132    ///
133    /// The `Authorization` header is required to be `Bearer {token}`.
134    ///
135    /// # Panics
136    ///
137    /// Panics if the token is not a valid [`HeaderValue`].
138    pub fn bearer(inner: S, token: &str) -> Self
139    where
140        ResBody: Default,
141    {
142        Self::custom(inner, AuthorizeContext::new(Bearer::new(token)))
143    }
144}
145
146impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Bearer<ResBody>>> {
147    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
148    ///
149    /// The `Authorization` header is required to be `Bearer {token}`.
150    ///
151    /// # Panics
152    ///
153    /// Panics if the token is not a valid [`HeaderValue`].
154    pub fn bearer(token: &str) -> Self
155    where
156        ResBody: Default,
157    {
158        Self::custom(AuthorizeContext::new(Bearer::new(token)))
159    }
160}
161
162/// Type that performs "bearer token" authorization.
163///
164/// See [`ValidateRequestHeader::bearer`] for more details.
165pub struct Bearer<ResBody> {
166    header_value: HeaderValue,
167    _ty: PhantomData<fn() -> ResBody>,
168}
169
170impl<ResBody> Bearer<ResBody> {
171    fn new(token: &str) -> Self
172    where
173        ResBody: Default,
174    {
175        Self {
176            header_value: format!("Bearer {}", token)
177                .parse()
178                .expect("token is not a valid header value"),
179            _ty: PhantomData,
180        }
181    }
182}
183
184impl<ResBody> Clone for Bearer<ResBody> {
185    fn clone(&self) -> Self {
186        Self {
187            header_value: self.header_value.clone(),
188            _ty: PhantomData,
189        }
190    }
191}
192
193impl<ResBody> fmt::Debug for Bearer<ResBody> {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        f.debug_struct("Bearer")
196            .field("header_value", &self.header_value)
197            .finish()
198    }
199}
200
201// TODO: revisit ValidateRequest and related types so we do not require
202// the associated Response types for all these traits. E.g. by forcing
203// downstream users that their response bodies can be turned into the standard `rama::http::Body`
204impl<S, B, C> ValidateRequest<S, B> for AuthorizeContext<C>
205where
206    C: Authorizer,
207    B: Send + 'static,
208    S: Clone + Send + Sync + 'static,
209{
210    type ResponseBody = C::ResBody;
211
212    async fn validate(
213        &self,
214        ctx: Context<S>,
215        request: Request<B>,
216    ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
217        match request.headers().get(header::AUTHORIZATION) {
218            Some(header_value) if self.credential.is_valid(header_value) => Ok((ctx, request)),
219            None if self.allow_anonymous => {
220                let mut ctx = ctx;
221                ctx.insert(UserId::Anonymous);
222                Ok((ctx, request))
223            }
224            _ => {
225                let mut res = Response::new(Self::ResponseBody::default());
226                *res.status_mut() = StatusCode::UNAUTHORIZED;
227
228                if let Some(www_auth) = C::www_authenticate_header() {
229                    res.headers_mut().insert(header::WWW_AUTHENTICATE, www_auth);
230                } else {
231                    res.headers_mut()
232                        .insert(header::WWW_AUTHENTICATE, "Bearer".parse().unwrap());
233                }
234
235                Err(res)
236            }
237        }
238    }
239}
240
241/// Type that performs basic authorization.
242///
243/// See [`ValidateRequestHeader::basic`] for more details.
244pub struct Basic<ResBody> {
245    header_value: HeaderValue,
246    _ty: PhantomData<fn() -> ResBody>,
247}
248
249impl<ResBody> Basic<ResBody> {
250    fn new(username: &str, password: &str) -> Self
251    where
252        ResBody: Default,
253    {
254        let encoded = BASE64.encode(format!("{}:{}", username, password));
255        let header_value = format!("Basic {}", encoded).parse().unwrap();
256        Self {
257            header_value,
258            _ty: PhantomData,
259        }
260    }
261}
262
263impl<ResBody> Clone for Basic<ResBody> {
264    fn clone(&self) -> Self {
265        Self {
266            header_value: self.header_value.clone(),
267            _ty: PhantomData,
268        }
269    }
270}
271
272impl<ResBody> fmt::Debug for Basic<ResBody> {
273    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274        f.debug_struct("Basic")
275            .field("header_value", &self.header_value)
276            .finish()
277    }
278}
279
280// Private module with the actual implementation details
281mod sealed {
282    use super::*;
283
284    /// Private trait that contains the actual authorization logic
285    pub trait AuthorizerSeal: Send + Sync + 'static {
286        /// Check if the given header value is valid for this authorizer.
287        fn is_valid(&self, header_value: &HeaderValue) -> bool;
288
289        /// Return the WWW-Authenticate header value if applicable.
290        fn www_authenticate_header() -> Option<HeaderValue>;
291    }
292
293    impl<ResBody: Default + Send + 'static> AuthorizerSeal for Basic<ResBody> {
294        fn is_valid(&self, header_value: &HeaderValue) -> bool {
295            header_value == self.header_value
296        }
297
298        fn www_authenticate_header() -> Option<HeaderValue> {
299            Some(HeaderValue::from_static("Basic"))
300        }
301    }
302
303    impl<ResBody: Default + Send + 'static> AuthorizerSeal for Bearer<ResBody> {
304        fn is_valid(&self, header_value: &HeaderValue) -> bool {
305            header_value == self.header_value
306        }
307
308        fn www_authenticate_header() -> Option<HeaderValue> {
309            None
310        }
311    }
312
313    impl<T, const N: usize> AuthorizerSeal for [T; N]
314    where
315        T: AuthorizerSeal,
316    {
317        fn is_valid(&self, header_value: &HeaderValue) -> bool {
318            self.iter().any(|auth| auth.is_valid(header_value))
319        }
320
321        fn www_authenticate_header() -> Option<HeaderValue> {
322            T::www_authenticate_header()
323        }
324    }
325
326    impl<T> AuthorizerSeal for Vec<T>
327    where
328        T: AuthorizerSeal,
329    {
330        fn is_valid(&self, header_value: &HeaderValue) -> bool {
331            self.iter().any(|auth| auth.is_valid(header_value))
332        }
333
334        fn www_authenticate_header() -> Option<HeaderValue> {
335            T::www_authenticate_header()
336        }
337    }
338
339    impl<T> AuthorizerSeal for Arc<T>
340    where
341        T: AuthorizerSeal,
342    {
343        fn is_valid(&self, header_value: &HeaderValue) -> bool {
344            (**self).is_valid(header_value)
345        }
346
347        fn www_authenticate_header() -> Option<HeaderValue> {
348            T::www_authenticate_header()
349        }
350    }
351}
352
353/// Trait for authorizing requests.
354pub trait Authorizer: sealed::AuthorizerSeal {
355    type ResBody: Default + Send + 'static;
356}
357
358// Implement the public trait for our existing types
359impl<ResBody: Default + Send + 'static> Authorizer for Basic<ResBody> {
360    type ResBody = ResBody;
361}
362impl<ResBody: Default + Send + 'static> Authorizer for Bearer<ResBody> {
363    type ResBody = ResBody;
364}
365impl<T: Authorizer, const N: usize> Authorizer for [T; N] {
366    type ResBody = T::ResBody;
367}
368impl<T: Authorizer> Authorizer for Vec<T> {
369    type ResBody = T::ResBody;
370}
371impl<T: Authorizer> Authorizer for Arc<T> {
372    type ResBody = T::ResBody;
373}
374
/// Credential configuration used by [`ValidateRequestHeader`] and
/// [`ValidateRequestHeaderLayer`] to authorize requests by their `Authorization` header.
///
/// Pairs a credential `C` (for example [`Basic`], [`Bearer`], or a collection of them)
/// with a flag controlling whether requests without an `Authorization` header are
/// let through.
#[derive(Clone, Debug)]
pub struct AuthorizeContext<C> {
    /// The credential the incoming `Authorization` header is checked against.
    credential: C,
    /// Whether requests without an `Authorization` header are allowed.
    allow_anonymous: bool,
}

impl<C> AuthorizeContext<C> {
    /// Create a new [`AuthorizeContext`] with the given credential.
    ///
    /// Anonymous requests are rejected by default.
    pub(crate) fn new(credential: C) -> Self {
        Self {
            credential,
            allow_anonymous: false,
        }
    }
}
407
#[cfg(test)]
mod tests {
    use super::*;

    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{Body, header};

    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};

    #[tokio::test]
    async fn valid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn basic_missing_header_is_unauthorized() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        // `Bearer` has no challenge of its own, so the validator's fallback is used.
        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Bearer");
    }

    #[tokio::test]
    async fn bearer_missing_header_is_unauthorized() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Bearer");
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn multiple_basic_auth_vec() {
        let auth1 = Basic::new("user1", "pass1");
        let auth2 = Basic::new("user2", "pass2");
        let auth_vec = vec![auth1, auth2];
        let auth_context = AuthorizeContext::new(auth_vec);
        let service = ValidateRequestHeaderLayer::custom(auth_context).into_layer(service_fn(echo));

        // Test first credential
        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("user1:pass1")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::OK, response.status());

        // Test second credential
        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("user2:pass2")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::OK, response.status());

        // Test invalid credential
        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("invalid:invalid")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::UNAUTHORIZED, response.status());
    }

    #[tokio::test]
    async fn multiple_basic_auth_array() {
        // Two distinct credentials, so the test exercises more than one array entry.
        let auth_array = [Basic::new("user1", "pass1"), Basic::new("user2", "pass2")];
        let auth_context = AuthorizeContext::new(auth_array);
        let service = ValidateRequestHeaderLayer::custom(auth_context).into_layer(service_fn(echo));

        for (credentials, expected) in [
            ("user1:pass1", StatusCode::OK),
            ("user2:pass2", StatusCode::OK),
            ("invalid:invalid", StatusCode::UNAUTHORIZED),
        ] {
            let request = Request::builder()
                .header(
                    header::AUTHORIZATION,
                    format!("Basic {}", BASE64.encode(credentials)),
                )
                .body(Body::default())
                .unwrap();
            let response = service.serve(Context::default(), request).await.unwrap();
            assert_eq!(expected, response.status(), "credentials: {credentials}");
        }
    }

    #[tokio::test]
    async fn arc_basic_auth() {
        let auth = Basic::new("user", "pass");
        let arc_auth = Arc::new(auth);
        let auth_context = AuthorizeContext::new(arc_auth);
        let service = ValidateRequestHeaderLayer::custom(auth_context).into_layer(service_fn(echo));

        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("user:pass")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::OK, response.status());
    }

    #[tokio::test]
    async fn basic_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn bearer_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wrong")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that answers `200 OK` with the request body echoed back.
    ///
    /// The body type parameter is named `ReqBody` so it does not shadow `crate::Body`.
    async fn echo<ReqBody>(req: Request<ReqBody>) -> Result<Response<ReqBody>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}