rama_http/layer/auth/async_require_authorization.rs
//! Authorize requests asynchronously using a custom scheme, such as one based on the [`Authorization`] header.
2//!
3//! [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
4//!
5//! # Example
6//!
7//! ```
8//! use bytes::Bytes;
9//!
10//! use rama_http::layer::auth::{AsyncRequireAuthorizationLayer, AsyncAuthorizeRequest};
11//! use rama_http::{Body, Request, Response, StatusCode, header::AUTHORIZATION};
12//! use rama_core::service::service_fn;
13//! use rama_core::{Context, Service, Layer};
14//! use rama_core::error::BoxError;
15//!
16//! #[derive(Clone, Copy)]
17//! struct MyAuth;
18//!
19//! impl<S, B> AsyncAuthorizeRequest<S, B> for MyAuth
20//! where
21//! S: Clone + Send + Sync + 'static,
22//! B: Send + Sync + 'static,
23//! {
24//! type RequestBody = B;
25//! type ResponseBody = Body;
26//!
27//! async fn authorize(&self, mut ctx: Context<S>, request: Request<B>) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
28//! if let Some(user_id) = check_auth(&request).await {
//!             // Store `user_id` in the `Context` so it can be accessed by
//!             // services further down the stack.
31//! ctx.insert(user_id);
32//!
33//! Ok((ctx, request))
34//! } else {
35//! let unauthorized_response = Response::builder()
36//! .status(StatusCode::UNAUTHORIZED)
37//! .body(Body::default())
38//! .unwrap();
39//!
40//! Err(unauthorized_response)
41//! }
42//! }
43//! }
44//!
45//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
46//! // ...
47//! # None
48//! }
49//!
50//! #[derive(Clone, Debug)]
51//! struct UserId(String);
52//!
53//! async fn handle<S>(ctx: Context<S>, _request: Request) -> Result<Response, BoxError> {
//!     // Access the `UserId` that was stored by `MyAuth::authorize`. If `handle` gets
//!     // called the request was authorized and `UserId` will be present.
56//! let user_id = ctx
57//! .get::<UserId>()
58//! .expect("UserId will be there if request was authorized");
59//!
60//! println!("request from {:?}", user_id);
61//!
62//! Ok(Response::new(Body::default()))
63//! }
64//!
65//! # #[tokio::main]
66//! # async fn main() -> Result<(), BoxError> {
67//! let service = (
68//! // Authorize requests using `MyAuth`
69//! AsyncRequireAuthorizationLayer::new(MyAuth),
70//! ).layer(service_fn(handle::<()>));
71//! # Ok(())
72//! # }
73//! ```
74//!
//! Or using a closure. Any `Fn(Context<S>, Request<B>)` whose returned future resolves to
//! `Result<(Context<S>, Request<_>), Response<_>>` implements [`AsyncAuthorizeRequest`]:
//!
//! ```
//! use rama_http::layer::auth::AsyncRequireAuthorizationLayer;
//! use rama_http::{Body, Request, Response, StatusCode};
//! use rama_core::service::service_fn;
//! use rama_core::{Context, Layer};
//! use rama_core::error::BoxError;
//!
//! async fn check_auth<B>(request: &Request<B>) -> Option<UserId> {
//!     // ...
//! #   None
//! }
//!
//! #[derive(Clone, Debug)]
//! struct UserId(String);
//!
//! async fn handle(request: Request) -> Result<Response, BoxError> {
//! #   todo!();
//!     // ...
//! }
//!
//! # #[tokio::main]
//! # async fn main() -> Result<(), BoxError> {
//! let service = AsyncRequireAuthorizationLayer::new(
//!     |mut ctx: Context<()>, request: Request| async move {
//!         if let Some(user_id) = check_auth(&request).await {
//!             // Make the authenticated user available to inner services.
//!             ctx.insert(user_id);
//!             Ok((ctx, request))
//!         } else {
//!             let unauthorized_response = Response::builder()
//!                 .status(StatusCode::UNAUTHORIZED)
//!                 .body(Body::default())
//!                 .unwrap();
//!
//!             Err(unauthorized_response)
//!         }
//!     },
//! )
//! .layer(service_fn(handle));
//! # Ok(())
//! # }
//! ```
118
119use crate::{Request, Response};
120use rama_core::{Context, Layer, Service};
121use rama_utils::macros::define_inner_service_accessors;
122
/// Layer that applies [`AsyncRequireAuthorization`], which authorizes every request
/// asynchronously using a custom scheme `T` (see [`AsyncAuthorizeRequest`]).
///
/// The scheme decides what to inspect, for example the [`Authorization`] header.
///
/// See the [module docs](crate::layer::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Debug, Clone)]
pub struct AsyncRequireAuthorizationLayer<T> {
    auth: T,
}

impl<T> AsyncRequireAuthorizationLayer<T> {
    /// Create a layer that authorizes requests using the custom scheme `auth`.
    ///
    /// `auth` is cloned into every service produced by [`Layer::layer`] and moved
    /// into the service produced by `into_layer`.
    pub const fn new(auth: T) -> AsyncRequireAuthorizationLayer<T> {
        Self { auth }
    }
}
140
141impl<S, T> Layer<S> for AsyncRequireAuthorizationLayer<T>
142where
143 T: Clone,
144{
145 type Service = AsyncRequireAuthorization<S, T>;
146
147 fn layer(&self, inner: S) -> Self::Service {
148 AsyncRequireAuthorization::new(inner, self.auth.clone())
149 }
150
151 fn into_layer(self, inner: S) -> Self::Service {
152 AsyncRequireAuthorization::new(inner, self.auth)
153 }
154}
155
/// Middleware that authorizes every request asynchronously before passing it to
/// the inner service `S`.
///
/// The authorization scheme `T` decides whether a request is allowed (see
/// [`AsyncAuthorizeRequest`]), for example by inspecting the [`Authorization`] header.
/// Rejected requests never reach `S`; the scheme's response is returned instead.
///
/// See the [module docs](crate::layer::auth::async_require_authorization) for an example.
///
/// [`Authorization`]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Authorization
#[derive(Clone, Debug)]
pub struct AsyncRequireAuthorization<S, T> {
    inner: S,
    auth: T,
}
166
167impl<S, T> AsyncRequireAuthorization<S, T> {
168 /// Authorize requests using a custom scheme.
169 ///
170 /// The `Authorization` header is required to have the value provided.
171 pub const fn new(inner: S, auth: T) -> AsyncRequireAuthorization<S, T> {
172 Self { inner, auth }
173 }
174
175 define_inner_service_accessors!();
176}
177
178impl<ReqBody, ResBody, S, State, Auth> Service<State, Request<ReqBody>>
179 for AsyncRequireAuthorization<S, Auth>
180where
181 Auth: AsyncAuthorizeRequest<State, ReqBody, ResponseBody = ResBody> + Send + Sync + 'static,
182 S: Service<State, Request<Auth::RequestBody>, Response = Response<ResBody>>,
183 ReqBody: Send + 'static,
184 ResBody: Send + 'static,
185 State: Clone + Send + Sync + 'static,
186{
187 type Response = Response<ResBody>;
188 type Error = S::Error;
189
190 async fn serve(
191 &self,
192 ctx: Context<State>,
193 req: Request<ReqBody>,
194 ) -> Result<Self::Response, Self::Error> {
195 let (ctx, req) = match self.auth.authorize(ctx, req).await {
196 Ok(req) => req,
197 Err(res) => return Ok(res),
198 };
199 self.inner.serve(ctx, req).await
200 }
201}
202
203/// Trait for authorizing requests.
204pub trait AsyncAuthorizeRequest<S, B> {
205 /// The type of request body returned by `authorize`.
206 ///
207 /// Set this to `B` unless you need to change the request body type.
208 type RequestBody;
209
210 /// The body type used for responses to unauthorized requests.
211 type ResponseBody;
212
213 /// Authorize the request.
214 ///
215 /// If the future resolves to `Ok(request)` then the request is allowed through, otherwise not.
216 fn authorize(
217 &self,
218 ctx: Context<S>,
219 request: Request<B>,
220 ) -> impl Future<
221 Output = Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>>,
222 > + Send
223 + '_;
224}
225
226impl<S, B, F, Fut, ReqBody, ResBody> AsyncAuthorizeRequest<S, B> for F
227where
228 F: Fn(Context<S>, Request<B>) -> Fut + Send + Sync + 'static,
229 Fut:
230 Future<Output = Result<(Context<S>, Request<ReqBody>), Response<ResBody>>> + Send + 'static,
231 B: Send + 'static,
232 S: Clone + Send + Sync + 'static,
233 ReqBody: Send + 'static,
234 ResBody: Send + 'static,
235{
236 type RequestBody = ReqBody;
237 type ResponseBody = ResBody;
238
239 async fn authorize(
240 &self,
241 ctx: Context<S>,
242 request: Request<B>,
243 ) -> Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>> {
244 self(ctx, request).await
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    use crate::{Body, StatusCode, header};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;

    /// Bearer token accepted by the test authorizers.
    const VALID_TOKEN: &str = "69420";

    /// User id that [`MyAuth`] stores in the context on success.
    const USER_ID: &str = "6969";

    /// Test authorizer: accepts `Authorization: Bearer 69420` and stores a
    /// [`UserId`] in the [`Context`] for the inner service.
    #[derive(Clone, Copy)]
    struct MyAuth;

    impl<S, B> AsyncAuthorizeRequest<S, B> for MyAuth
    where
        S: Clone + Send + Sync + 'static,
        B: Send + 'static,
    {
        type RequestBody = B;
        type ResponseBody = Body;

        async fn authorize(
            &self,
            mut ctx: Context<S>,
            request: Request<B>,
        ) -> Result<(Context<S>, Request<Self::RequestBody>), Response<Self::ResponseBody>>
        {
            if has_valid_bearer(&request) {
                ctx.insert(UserId(USER_ID.to_owned()));
                Ok((ctx, request))
            } else {
                Err(unauthorized())
            }
        }
    }

    #[derive(Debug, Clone)]
    struct UserId(String);

    /// Returns `true` when the request carries `Authorization: Bearer <VALID_TOKEN>`.
    fn has_valid_bearer<B>(request: &Request<B>) -> bool {
        request
            .headers()
            .get(header::AUTHORIZATION)
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.strip_prefix("Bearer "))
            == Some(VALID_TOKEN)
    }

    /// Canned `401 Unauthorized` response used by the test authorizers.
    fn unauthorized() -> Response {
        Response::builder()
            .status(StatusCode::UNAUTHORIZED)
            .body(Body::empty())
            .unwrap()
    }

    /// Build a `GET /` request, optionally carrying an `Authorization` header.
    fn request_with_auth(auth: Option<&str>) -> Request {
        let mut builder = Request::get("/");
        if let Some(value) = auth {
            builder = builder.header(header::AUTHORIZATION, value);
        }
        builder.body(Body::empty()).unwrap()
    }

    /// Inner service that echoes the request body back with `200 OK`.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    /// Inner service that answers `200 OK` only if the authorizer stored the
    /// expected [`UserId`] in the context, and `500` otherwise.
    async fn expect_user_id(ctx: Context<()>, req: Request) -> Result<Response, BoxError> {
        let status = match ctx.get::<UserId>() {
            Some(UserId(id)) if id == USER_ID => StatusCode::OK,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        };
        let mut res = Response::new(req.into_body());
        *res.status_mut() = status;
        Ok(res)
    }

    #[tokio::test]
    async fn require_async_auth_works() {
        let service = AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(echo));

        let res = service
            .serve(Context::default(), request_with_auth(Some("Bearer 69420")))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_401() {
        let service = AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(echo));

        let res = service
            .serve(Context::default(), request_with_auth(Some("Bearer deez")))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_401_without_header() {
        let service = AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(echo));

        let res = service
            .serve(Context::default(), request_with_auth(None))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn require_async_auth_passes_context_to_inner_service() {
        let service =
            AsyncRequireAuthorizationLayer::new(MyAuth).layer(service_fn(expect_user_id));

        let res = service
            .serve(Context::default(), request_with_auth(Some("Bearer 69420")))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn require_async_auth_with_closure() {
        let service = AsyncRequireAuthorizationLayer::new(
            |ctx: Context<()>, request: Request| async move {
                if has_valid_bearer(&request) {
                    Ok((ctx, request))
                } else {
                    Err(unauthorized())
                }
            },
        )
        .layer(service_fn(echo));

        let allowed = service
            .serve(Context::default(), request_with_auth(Some("Bearer 69420")))
            .await
            .unwrap();
        assert_eq!(allowed.status(), StatusCode::OK);

        let denied = service
            .serve(Context::default(), request_with_auth(Some("Bearer deez")))
            .await
            .unwrap();
        assert_eq!(denied.status(), StatusCode::UNAUTHORIZED);
    }
}