rama_http/layer/auth/
require_authorization.rs

//! Authorize requests based on their [`Authorization`] header, using [`ValidateRequest`].
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use bytes::Bytes;
9//!
10//! use rama_http::layer::validate_request::{ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer};
11//! use rama_http::{Body, Request, Response, StatusCode, header::AUTHORIZATION};
12//! use rama_core::service::service_fn;
13//! use rama_core::{Context, Service, Layer};
14//! use rama_core::error::BoxError;
15//!
16//! async fn handle(request: Request) -> Result<Response, BoxError> {
17//!     Ok(Response::new(Body::default()))
18//! }
19//!
20//! # #[tokio::main]
21//! # async fn main() -> Result<(), BoxError> {
22//! let mut service = (
23//!     // Require the `Authorization` header to be `Bearer passwordlol`
24//!     ValidateRequestHeaderLayer::bearer("passwordlol"),
25//! ).layer(service_fn(handle));
26//!
27//! // Requests with the correct token are allowed through
28//! let request = Request::builder()
29//!     .header(AUTHORIZATION, "Bearer passwordlol")
30//!     .body(Body::default())
31//!     .unwrap();
32//!
33//! let response = service
34//!     .serve(Context::default(), request)
35//!     .await?;
36//!
37//! assert_eq!(StatusCode::OK, response.status());
38//!
//! // Requests without a valid token (here: no token at all) get a `401 Unauthorized` response
40//! let request = Request::builder()
41//!     .body(Body::default())
42//!     .unwrap();
43//!
44//! let response = service
45//!     .serve(Context::default(), request)
46//!     .await?;
47//!
48//! assert_eq!(StatusCode::UNAUTHORIZED, response.status());
49//! # Ok(())
50//! # }
51//! ```
52//!
//! Custom validation logic can be plugged in by implementing [`ValidateRequest`].
54
55use base64::Engine as _;
56use std::{fmt, marker::PhantomData};
57
58use crate::layer::validate_request::{
59    ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer,
60};
61use crate::{
62    header::{self, HeaderValue},
63    Request, Response, StatusCode,
64};
65use rama_core::Context;
66
67use rama_net::user::UserId;
68
/// The standard base64 engine (`general_purpose::STANDARD`), used to encode the
/// `username:password` pair for the `Basic` authorization scheme.
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
70
71impl<C> ValidateRequestHeaderLayer<AuthorizeContext<C>> {
72    /// Allow anonymous requests.
73    pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
74        self.validate.allow_anonymous = allow_anonymous;
75        self
76    }
77
78    /// Allow anonymous requests.
79    pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
80        self.validate.allow_anonymous = allow_anonymous;
81        self
82    }
83}
84
85impl<S, C> ValidateRequestHeader<S, AuthorizeContext<C>> {
86    /// Allow anonymous requests.
87    pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
88        self.validate.allow_anonymous = allow_anonymous;
89        self
90    }
91
92    /// Allow anonymous requests.
93    pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
94        self.validate.allow_anonymous = allow_anonymous;
95        self
96    }
97}
98
99impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Basic<ResBody>>> {
100    /// Authorize requests using a username and password pair.
101    ///
102    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
103    /// `base64_encode("{username}:{password}")`.
104    ///
105    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
106    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
107    pub fn basic(inner: S, username: &str, value: &str) -> Self
108    where
109        ResBody: Default,
110    {
111        Self::custom(inner, AuthorizeContext::new(Basic::new(username, value)))
112    }
113}
114
115impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Basic<ResBody>>> {
116    /// Authorize requests using a username and password pair.
117    ///
118    /// The `Authorization` header is required to be `Basic {credentials}` where `credentials` is
119    /// `base64_encode("{username}:{password}")`.
120    ///
121    /// Since the username and password is sent in clear text it is recommended to use HTTPS/TLS
122    /// with this method. However use of HTTPS/TLS is not enforced by this middleware.
123    pub fn basic(username: &str, password: &str) -> Self
124    where
125        ResBody: Default,
126    {
127        Self::custom(AuthorizeContext::new(Basic::new(username, password)))
128    }
129}
130
131impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Bearer<ResBody>>> {
132    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
133    ///
134    /// The `Authorization` header is required to be `Bearer {token}`.
135    ///
136    /// # Panics
137    ///
138    /// Panics if the token is not a valid [`HeaderValue`].
139    pub fn bearer(inner: S, token: &str) -> Self
140    where
141        ResBody: Default,
142    {
143        Self::custom(inner, AuthorizeContext::new(Bearer::new(token)))
144    }
145}
146
147impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Bearer<ResBody>>> {
148    /// Authorize requests using a "bearer token". Commonly used for OAuth 2.
149    ///
150    /// The `Authorization` header is required to be `Bearer {token}`.
151    ///
152    /// # Panics
153    ///
154    /// Panics if the token is not a valid [`HeaderValue`].
155    pub fn bearer(token: &str) -> Self
156    where
157        ResBody: Default,
158    {
159        Self::custom(AuthorizeContext::new(Bearer::new(token)))
160    }
161}
162
163/// Type that performs "bearer token" authorization.
164///
165/// See [`ValidateRequestHeader::bearer`] for more details.
166pub struct Bearer<ResBody> {
167    header_value: HeaderValue,
168    _ty: PhantomData<fn() -> ResBody>,
169}
170
171impl<ResBody> Bearer<ResBody> {
172    fn new(token: &str) -> Self
173    where
174        ResBody: Default,
175    {
176        Self {
177            header_value: format!("Bearer {}", token)
178                .parse()
179                .expect("token is not a valid header value"),
180            _ty: PhantomData,
181        }
182    }
183}
184
185impl<ResBody> Clone for Bearer<ResBody> {
186    fn clone(&self) -> Self {
187        Self {
188            header_value: self.header_value.clone(),
189            _ty: PhantomData,
190        }
191    }
192}
193
194impl<ResBody> fmt::Debug for Bearer<ResBody> {
195    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196        f.debug_struct("Bearer")
197            .field("header_value", &self.header_value)
198            .finish()
199    }
200}
201
202impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Bearer<ResBody>>
203where
204    ResBody: Default + Send + 'static,
205    B: Send + 'static,
206    S: Clone + Send + Sync + 'static,
207{
208    type ResponseBody = ResBody;
209
210    async fn validate(
211        &self,
212        ctx: Context<S>,
213        request: Request<B>,
214    ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
215        match request.headers().get(header::AUTHORIZATION) {
216            Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
217            None if self.allow_anonymous => {
218                let mut ctx = ctx;
219                ctx.insert(UserId::Anonymous);
220                Ok((ctx, request))
221            }
222            _ => {
223                let mut res = Response::new(ResBody::default());
224                *res.status_mut() = StatusCode::UNAUTHORIZED;
225                Err(res)
226            }
227        }
228    }
229}
230
231/// Type that performs basic authorization.
232///
233/// See [`ValidateRequestHeader::basic`] for more details.
234pub struct Basic<ResBody> {
235    header_value: HeaderValue,
236    _ty: PhantomData<fn() -> ResBody>,
237}
238
239impl<ResBody> Basic<ResBody> {
240    fn new(username: &str, password: &str) -> Self
241    where
242        ResBody: Default,
243    {
244        let encoded = BASE64.encode(format!("{}:{}", username, password));
245        let header_value = format!("Basic {}", encoded).parse().unwrap();
246        Self {
247            header_value,
248            _ty: PhantomData,
249        }
250    }
251}
252
253impl<ResBody> Clone for Basic<ResBody> {
254    fn clone(&self) -> Self {
255        Self {
256            header_value: self.header_value.clone(),
257            _ty: PhantomData,
258        }
259    }
260}
261
262impl<ResBody> fmt::Debug for Basic<ResBody> {
263    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264        f.debug_struct("Basic")
265            .field("header_value", &self.header_value)
266            .finish()
267    }
268}
269
270impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Basic<ResBody>>
271where
272    ResBody: Default + Send + 'static,
273    B: Send + 'static,
274    S: Clone + Send + Sync + 'static,
275{
276    type ResponseBody = ResBody;
277
278    async fn validate(
279        &self,
280        ctx: Context<S>,
281        request: Request<B>,
282    ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
283        match request.headers().get(header::AUTHORIZATION) {
284            Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
285            None if self.allow_anonymous => {
286                let mut ctx = ctx;
287                ctx.insert(UserId::Anonymous);
288                Ok((ctx, request))
289            }
290            _ => {
291                let mut res = Response::new(ResBody::default());
292                *res.status_mut() = StatusCode::UNAUTHORIZED;
293                res.headers_mut()
294                    .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
295                Err(res)
296            }
297        }
298    }
299}
300
/// Request validator shared by the [`Basic`] and [`Bearer`] authorization schemes.
///
/// Pairs the credential `C` a request must present with the anonymous-access policy. It is
/// the [`ValidateRequest`] implementation plugged into [`ValidateRequestHeader`] and
/// [`ValidateRequestHeaderLayer`] by their `basic`/`bearer` constructors.
pub struct AuthorizeContext<C> {
    /// The credential a request must present in its `Authorization` header.
    credential: C,
    /// Whether requests without an `Authorization` header are accepted as anonymous.
    allow_anonymous: bool,
}

impl<C> AuthorizeContext<C> {
    /// Create a validator for `credential` that rejects anonymous requests.
    pub(crate) fn new(credential: C) -> Self {
        Self {
            credential,
            allow_anonymous: false,
        }
    }
}
314
315impl<C: Clone> Clone for AuthorizeContext<C> {
316    fn clone(&self) -> Self {
317        Self {
318            credential: self.credential.clone(),
319            allow_anonymous: self.allow_anonymous,
320        }
321    }
322}
323
324impl<C: fmt::Debug> fmt::Debug for AuthorizeContext<C> {
325    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326        f.debug_struct("AuthorizeContext")
327            .field("credential", &self.credential)
328            .field("allow_anonymous", &self.allow_anonymous)
329            .finish()
330    }
331}
332
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;

    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{header, Body};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};

    /// Inner test service: answers `200 OK`, echoing the request body back.
    async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    #[tokio::test]
    async fn valid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn basic_rejects_missing_header_by_default() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_rejects_missing_header_by_default() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn bearer_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wrong")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn set_allow_anonymous_enables_and_disables_anonymous_access() {
        let mut auth_layer = ValidateRequestHeaderLayer::bearer("foobar");

        auth_layer.set_allow_anonymous(true);
        let service = auth_layer.clone().layer(service_fn(echo));
        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        auth_layer.set_allow_anonymous(false);
        let service = auth_layer.layer(service_fn(echo));
        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }
}