rama_http/layer/auth/
async_require_authorization.rs

1//! Authorize requests using the [`Authorization`] header asynchronously.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use bytes::Bytes;
9//!
10//! use rama_http::layer::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
11//! use rama_http::{Body, Request, Response, StatusCode, header::AUTHORIZATION};
12//! use rama_core::service::service_fn;
13//! use rama_core::{Context, Service, Layer};
14//! use rama_core::error::BoxError;
15//!
16//! #[derive(Clone, Copy)]
17//! struct MyAuth;
18//!
19//! impl<S, B> AsyncAuthorizeRequest<S, B> for MyAuth
20//! where
21//!     S: Clone + Send + Sync + 'static,
22//!     B: Send + Sync + 'static,
23//! {
24//!     type RequestBody = B;
25//!     type ResponseBody = Body;
26//!
27//!     async fn authorize(&self, mut ctx: Context<S>, request: Request<B>) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
28//!         if let Some(user_id) = check_auth(&request).await {
//!             // Store `user_id` in the context so it can be accessed by services
//!             // further down the stack.
31//!             ctx.insert(user_id);
32//!
33//!             Ok((ctx, request))
34//!         } else {
35//!             let unauthorized_response = Response::builder()
36//!                 .status(StatusCode::UNAUTHORIZED)
37//!                 .body(Body::default())
38//!                 .unwrap();
39//!
40//!             Err(unauthorized_response)
41//!         }
42//!     }
43//! }
44//!
45//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
46//!     // ...
47//!     # None
48//! }
49//!
50//! #[derive(Clone, Debug)]
51//! struct UserId(String);
52//!
53//! async fn handle<S>(ctx: Context<S>, _request: Request) -> Result<Response, BoxError> {
//!     // Access the `UserId` that was inserted in `authorize`. If `handle` gets called, the
//!     // request was authorized and `UserId` will be present.
56//!     let user_id = ctx
57//!         .get::<UserId>()
58//!         .expect("UserId will be there if request was authorized");
59//!
60//!     println!("request from {:?}", user_id);
61//!
62//!     Ok(Response::new(Body::default()))
63//! }
64//!
65//! # #[tokio::main]
66//! # async fn main() -> Result<(), BoxError> {
67//! let service = (
68//!     // Authorize requests using `MyAuth`
69//!     AsyncRequireAuthorizationLayer::new(MyAuth),
70//! ).layer(service_fn(handle::<()>));
71//! # Ok(())
72//! # }
73//! ```
74//!
75//! Or using a closure:
76//!
77//! ```
78//! use bytes::Bytes;
79//!
80//! use rama_http::layer::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
81//! use rama_http::{Body, Request, Response, StatusCode};
82//! use rama_core::service::service_fn;
83//! use rama_core::{Service, Layer};
84//! use rama_core::error::BoxError;
85//!
86//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
87//!     // ...
88//!     # None
89//! }
90//!
91//! #[derive(Debug)]
92//! struct UserId(String);
93//!
94//! async fn handle(request: Request) -> Result<Response, BoxError> {
95//!     # todo!();
96//!     // ...
97//! }
98//!
99//! # #[tokio::main]
100//! # async fn main() -> Result<(), BoxError> {
101//! let service =
//!     AsyncRequireAuthorizationLayer::new(|ctx: rama_core::Context<()>, request: Request| async move {
//!         if let Some(_user_id) = check_auth(&request).await {
//!             Ok((ctx, request))
105//!         } else {
106//!             let unauthorized_response = Response::builder()
107//!                 .status(StatusCode::UNAUTHORIZED)
108//!                 .body(Body::default())
109//!                 .unwrap();
110//!
111//!             Err(unauthorized_response)
112//!         }
113//!     })
114//!     .layer(service_fn(handle));
115//! # Ok(())
116//! # }
117//! ```
118
119use crate::{Request, Response};
120use rama_core::{Context, Layer, Service};
121use rama_utils::macros::define_inner_service_accessors;
122
/// [`Layer`] that wraps services in an [`AsyncRequireAuthorization`] middleware.
///
/// The wrapped authorizer `T` (typically an [`AsyncAuthorizeRequest`] implementation, e.g. one
/// inspecting the [`Authorization`] header) decides, per request, whether the inner service is
/// reached at all.
///
/// See the [module docs](crate::layer::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorizationLayer<T> {
    auth: T,
}
133
134impl<T> AsyncRequireAuthorizationLayer<T> {
135    /// Authorize requests using a custom scheme.
136    pub const fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
137        Self { auth }
138    }
139}
140
141impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
142where
143    T: Clone,
144{
145    type Service = AsyncRequireAuthorization<S, T>;
146
147    fn layer(&self, inner: S) -> Self::Service {
148        AsyncRequireAuthorization::new(inner, self.auth.clone())
149    }
150
151    fn into_layer(self, inner: S) -> Self::Service {
152        AsyncRequireAuthorization::new(inner, self.auth)
153    }
154}
155
/// Middleware that runs an authorizer (typically an [`AsyncAuthorizeRequest`] implementation,
/// e.g. one checking the [`Authorization`] header) before forwarding requests to `inner`.
///
/// Requests the authorizer rejects are answered with the response it produced; the inner
/// service is not called for them.
///
/// See the [module docs](crate::layer::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorization<S, T> {
    inner: S,
    auth: T,
}
166
167impl<S, T> AsyncRequireAuthorization<S, T> {
168    /// Authorize requests using a custom scheme.
169    ///
170    /// The `Authorization` header is required to have the value provided.
171    pub const fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
172        Self { inner, auth }
173    }
174
175    define_inner_service_accessors!();
176}
177
178impl<ReqBody, ResBody, S, State, Auth> Service<State, Request<ReqBody>>
179    for AsyncRequireAuthorization<S, Auth>
180where
181    Auth: AsyncAuthorizeRequest<State, ReqBody, ResponseBody = ResBody> + Send + Sync + 'static,
182    S: Service<State, Request<Auth::RequestBody>, Response = Response<ResBody>>,
183    ReqBody: Send + 'static,
184    ResBody: Send + 'static,
185    State: Clone + Send + Sync + 'static,
186{
187    type Response = Response<ResBody>;
188    type Error = S::Error;
189
190    async fn serve(
191        &self,
192        ctx: Context<State>,
193        req: Request<ReqBody>,
194    ) -> Result<Self::Response, Self::Error> {
195        let (ctx, req) = match self.auth.authorize(ctx, req).await {
196            Ok(req) => req,
197            Err(res) => return Ok(res),
198        };
199        self.inner.serve(ctx, req).await
200    }
201}
202
203/// Trait for authorizing requests.
204pub trait AsyncAuthorizeRequest<S, B> {
205    /// The type of request body returned by `authorize`.
206    ///
207    /// Set this to `B` unless you need to change the request body type.
208    type RequestBody;
209
210    /// The body type used for responses to unauthorized requests.
211    type ResponseBody;
212
213    /// Authorize the request.
214    ///
215    /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
216    fn authorize(
217        &self,
218        ctx: Context<S>,
219        request: Request<B>,
220    ) -> impl Future<
221        Output = Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>>,
222    > + Send
223    + '_;
224}
225
226impl<S, B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<S, B> for F
227where
228    F: Fn(Context<S>, Request<B>) -> Fut + Send + Sync + 'static,
229    Fut:
230        Future<Output = Result<(Context<S>, Request<ReqBody>), Response<ResBody>>> + Send + 'static,
231    B: Send + 'static,
232    S: Clone + Send + Sync + 'static,
233    ReqBody: Send + 'static,
234    ResBody: Send + 'static,
235{
236    type RequestBody = ReqBody;
237    type ResponseBody = ResBody;
238
239    async fn authorize(
240        &self,
241        ctx: Context<S>,
242        request: Request<B>,
243    ) -> Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>> {
244        self(ctx, request).await
245    }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Body, StatusCode, header};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;

    /// Test authorizer: accepts only `Authorization: Bearer 69420` and, on success, stores a
    /// [`UserId`] in the context for the inner service.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<S, B> AsyncAuthorizeRequest<S, B> for MyAuth
    where
        S: Clone + Send + Sync + 'static,
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;

        async fn authorize(
            &self,
            mut ctx: Context<S>,
            request: Request<B>,
        ) -> Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>>
        {
            let authorized = request
                .headers()
                .get(header::AUTHORIZATION)
                .and_then(|it: &rama_http_types::HeaderValue| it.to_str().ok())
                .and_then(|it| it.strip_prefix("Bearer "))
                .is_some_and(|token| token == "69420");

            if authorized {
                ctx.insert(UserId("6969".to_owned()));
                Ok((ctx, request))
            } else {
                Err(Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(Body::empty())
                    .unwrap())
            }
        }
    }

    #[derive(Debug, Clone)]
    struct UserId(String);

    /// Build a `GET /` request carrying the given `Authorization` header value.
    fn request_with_auth(value: &str) -> Request {
        Request::get("/")
            .header(header::AUTHORIZATION, value)
            .body(Body::empty())
            .unwrap()
    }

    #[tokio::test]
    async fn require_async_auth_works() {
        let service = AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(echo));

        let res = service
            .serve(Context::default(), request_with_auth("Bearer 69420"))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let service = AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(echo));

        let res = service
            .serve(Context::default(), request_with_auth("Bearer deez"))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_401_without_header() {
        let service = AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_exposes_context_to_inner_service() {
        let service =
            AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(expect_user_id));

        let res = service
            .serve(Context::default(), request_with_auth("Bearer 69420"))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_with_closure() {
        // Exercises the blanket `AsyncAuthorizeRequest` impl for closures.
        let auth = |ctx: Context<()>, request: Request| async move {
            if request.headers().contains_key(header::AUTHORIZATION) {
                Ok((ctx, request))
            } else {
                Err(Response::builder()
                    .status(StatusCode::UNAUTHORIZED)
                    .body(Body::empty())
                    .unwrap())
            }
        };
        let service = AsyncRequireAuthorizationLayer::new(auth).layer(service_fn(echo));

        let res = service
            .serve(Context::default(), request_with_auth("anything"))
            .await
            .unwrap();
        assert_eq!(res.status(), StatusCode::OK);

        let request = Request::get("/").body(Body::empty()).unwrap();
        let res = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service that echoes the request body back as the response body.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Inner service that fails the test unless `MyAuth` stored the expected [`UserId`].
    async fn expect_user_id(ctx: Context<()>, req: Request) -> Result<Response, BoxError> {
        let user_id = ctx
            .get::<UserId>()
            .expect("MyAuth inserts a UserId for authorized requests");
        assert_eq!(user_id.0, "6969");
        Ok(Response::new(req.into_body()))
    }
}