1use crate::layer::validate_request::{
56 ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer,
57};
58use crate::{
59 Request, Response, StatusCode,
60 header::{self, HeaderValue},
61};
62use base64::Engine as _;
63use rama_core::Context;
64use std::{fmt, marker::PhantomData, sync::Arc};
65
66use rama_net::user::UserId;
67
/// Base64 engine used to encode `username:password` pairs for the `Basic` scheme.
const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
69
70impl<C> ValidateRequestHeaderLayer<AuthorizeContext<C>> {
71 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
73 self.validate.allow_anonymous = allow_anonymous;
74 self
75 }
76
77 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
79 self.validate.allow_anonymous = allow_anonymous;
80 self
81 }
82}
83
84impl<S, C> ValidateRequestHeader<S, AuthorizeContext<C>> {
85 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
87 self.validate.allow_anonymous = allow_anonymous;
88 self
89 }
90
91 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
93 self.validate.allow_anonymous = allow_anonymous;
94 self
95 }
96}
97
98impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Basic<ResBody>>> {
99 pub fn basic(inner: S, username: &str, value: &str) -> Self
107 where
108 ResBody: Default,
109 {
110 Self::custom(inner, AuthorizeContext::new(Basic::new(username, value)))
111 }
112}
113
114impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Basic<ResBody>>> {
115 pub fn basic(username: &str, password: &str) -> Self
123 where
124 ResBody: Default,
125 {
126 Self::custom(AuthorizeContext::new(Basic::new(username, password)))
127 }
128}
129
130impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Bearer<ResBody>>> {
131 pub fn bearer(inner: S, token: &str) -> Self
139 where
140 ResBody: Default,
141 {
142 Self::custom(inner, AuthorizeContext::new(Bearer::new(token)))
143 }
144}
145
146impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Bearer<ResBody>>> {
147 pub fn bearer(token: &str) -> Self
155 where
156 ResBody: Default,
157 {
158 Self::custom(AuthorizeContext::new(Bearer::new(token)))
159 }
160}
161
/// Authorizer that accepts a single bearer token.
///
/// Holds the complete expected `Authorization` header value (`Bearer <token>`),
/// which incoming request headers are compared against.
pub struct Bearer<ResBody> {
    /// Expected `Authorization` header value.
    header_value: HeaderValue,
    // Ties the type to `ResBody` without storing one. `fn() -> ResBody` is
    // `Send + Sync` for any `ResBody`; presumably chosen for that reason.
    _ty: PhantomData<fn() -> ResBody>,
}
169
170impl<ResBody> Bearer<ResBody> {
171 fn new(token: &str) -> Self
172 where
173 ResBody: Default,
174 {
175 Self {
176 header_value: format!("Bearer {}", token)
177 .parse()
178 .expect("token is not a valid header value"),
179 _ty: PhantomData,
180 }
181 }
182}
183
184impl<ResBody> Clone for Bearer<ResBody> {
185 fn clone(&self) -> Self {
186 Self {
187 header_value: self.header_value.clone(),
188 _ty: PhantomData,
189 }
190 }
191}
192
193impl<ResBody> fmt::Debug for Bearer<ResBody> {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("Bearer")
196 .field("header_value", &self.header_value)
197 .finish()
198 }
199}
200
201impl<S, B, C> ValidateRequest<S, B> for AuthorizeContext<C>
205where
206 C: Authorizer,
207 B: Send + 'static,
208 S: Clone + Send + Sync + 'static,
209{
210 type ResponseBody = C::ResBody;
211
212 async fn validate(
213 &self,
214 ctx: Context<S>,
215 request: Request<B>,
216 ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
217 match request.headers().get(header::AUTHORIZATION) {
218 Some(header_value) if self.credential.is_valid(header_value) => Ok((ctx, request)),
219 None if self.allow_anonymous => {
220 let mut ctx = ctx;
221 ctx.insert(UserId::Anonymous);
222 Ok((ctx, request))
223 }
224 _ => {
225 let mut res = Response::new(Self::ResponseBody::default());
226 *res.status_mut() = StatusCode::UNAUTHORIZED;
227
228 if let Some(www_auth) = C::www_authenticate_header() {
229 res.headers_mut().insert(header::WWW_AUTHENTICATE, www_auth);
230 } else {
231 res.headers_mut()
232 .insert(header::WWW_AUTHENTICATE, "Bearer".parse().unwrap());
233 }
234
235 Err(res)
236 }
237 }
238 }
239}
240
/// Authorizer that accepts a single `username:password` pair.
///
/// Holds the complete expected `Authorization` header value
/// (`Basic <base64(username:password)>`), which incoming request headers are
/// compared against.
pub struct Basic<ResBody> {
    /// Expected `Authorization` header value.
    header_value: HeaderValue,
    // Ties the type to `ResBody` without storing one. `fn() -> ResBody` is
    // `Send + Sync` for any `ResBody`; presumably chosen for that reason.
    _ty: PhantomData<fn() -> ResBody>,
}
248
249impl<ResBody> Basic<ResBody> {
250 fn new(username: &str, password: &str) -> Self
251 where
252 ResBody: Default,
253 {
254 let encoded = BASE64.encode(format!("{}:{}", username, password));
255 let header_value = format!("Basic {}", encoded).parse().unwrap();
256 Self {
257 header_value,
258 _ty: PhantomData,
259 }
260 }
261}
262
263impl<ResBody> Clone for Basic<ResBody> {
264 fn clone(&self) -> Self {
265 Self {
266 header_value: self.header_value.clone(),
267 _ty: PhantomData,
268 }
269 }
270}
271
272impl<ResBody> fmt::Debug for Basic<ResBody> {
273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
274 f.debug_struct("Basic")
275 .field("header_value", &self.header_value)
276 .finish()
277 }
278}
279
280mod sealed {
282 use super::*;
283
284 pub trait AuthorizerSeal: Send + Sync + 'static {
286 fn is_valid(&self, header_value: &HeaderValue) -> bool;
288
289 fn www_authenticate_header() -> Option<HeaderValue>;
291 }
292
293 impl<ResBody: Default + Send + 'static> AuthorizerSeal for Basic<ResBody> {
294 fn is_valid(&self, header_value: &HeaderValue) -> bool {
295 header_value == self.header_value
296 }
297
298 fn www_authenticate_header() -> Option<HeaderValue> {
299 Some(HeaderValue::from_static("Basic"))
300 }
301 }
302
303 impl<ResBody: Default + Send + 'static> AuthorizerSeal for Bearer<ResBody> {
304 fn is_valid(&self, header_value: &HeaderValue) -> bool {
305 header_value == self.header_value
306 }
307
308 fn www_authenticate_header() -> Option<HeaderValue> {
309 None
310 }
311 }
312
313 impl<T, const N: usize> AuthorizerSeal for [T; N]
314 where
315 T: AuthorizerSeal,
316 {
317 fn is_valid(&self, header_value: &HeaderValue) -> bool {
318 self.iter().any(|auth| auth.is_valid(header_value))
319 }
320
321 fn www_authenticate_header() -> Option<HeaderValue> {
322 T::www_authenticate_header()
323 }
324 }
325
326 impl<T> AuthorizerSeal for Vec<T>
327 where
328 T: AuthorizerSeal,
329 {
330 fn is_valid(&self, header_value: &HeaderValue) -> bool {
331 self.iter().any(|auth| auth.is_valid(header_value))
332 }
333
334 fn www_authenticate_header() -> Option<HeaderValue> {
335 T::www_authenticate_header()
336 }
337 }
338
339 impl<T> AuthorizerSeal for Arc<T>
340 where
341 T: AuthorizerSeal,
342 {
343 fn is_valid(&self, header_value: &HeaderValue) -> bool {
344 (**self).is_valid(header_value)
345 }
346
347 fn www_authenticate_header() -> Option<HeaderValue> {
348 T::www_authenticate_header()
349 }
350 }
351}
352
353pub trait Authorizer: sealed::AuthorizerSeal {
355 type ResBody: Default + Send + 'static;
356}
357
358impl<ResBody: Default + Send + 'static> Authorizer for Basic<ResBody> {
360 type ResBody = ResBody;
361}
362impl<ResBody: Default + Send + 'static> Authorizer for Bearer<ResBody> {
363 type ResBody = ResBody;
364}
365impl<T: Authorizer, const N: usize> Authorizer for [T; N] {
366 type ResBody = T::ResBody;
367}
368impl<T: Authorizer> Authorizer for Vec<T> {
369 type ResBody = T::ResBody;
370}
371impl<T: Authorizer> Authorizer for Arc<T> {
372 type ResBody = T::ResBody;
373}
374
/// Validator state for authorization checks.
///
/// Holds the credential(s) to accept and whether requests without an
/// `Authorization` header may pass as anonymous.
pub struct AuthorizeContext<C> {
    /// Credential checker, e.g. [`Basic`], [`Bearer`], or a collection of them.
    credential: C,
    /// When `true`, requests with no `Authorization` header are accepted and
    /// tagged as anonymous instead of being rejected.
    allow_anonymous: bool,
}
379
380impl<C> AuthorizeContext<C> {
381 pub(crate) fn new(credential: C) -> Self {
383 Self {
384 credential,
385 allow_anonymous: false,
386 }
387 }
388}
389
390impl<C: Clone> Clone for AuthorizeContext<C> {
391 fn clone(&self) -> Self {
392 Self {
393 credential: self.credential.clone(),
394 allow_anonymous: self.allow_anonymous,
395 }
396 }
397}
398
399impl<C: fmt::Debug> fmt::Debug for AuthorizeContext<C> {
400 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
401 f.debug_struct("AuthorizeContext")
402 .field("credential", &self.credential)
403 .field("allow_anonymous", &self.allow_anonymous)
404 .finish()
405 }
406}
407
// Behavioral tests for the Basic/Bearer authorization validator, driven through
// `ValidateRequestHeaderLayer` wrapping an echo service.
#[cfg(test)]
mod tests {
    use super::*;

    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{Body, header};

    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};

    // Matching Basic credentials reach the inner service.
    #[tokio::test]
    async fn valid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // Wrong Basic credentials yield 401 plus a `WWW-Authenticate: Basic` challenge.
    #[tokio::test]
    async fn invalid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    // Matching bearer token reaches the inner service.
    #[tokio::test]
    async fn valid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // The scheme is matched byte-exactly, so a lowercase `basic` is rejected.
    // NOTE(review): RFC 7235 treats auth-scheme names as case-insensitive; this
    // test pins the stricter behavior.
    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // Usernames are compared case-sensitively.
    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // A different bearer token is rejected.
    #[tokio::test]
    async fn invalid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // The scheme is matched byte-exactly, so a lowercase `bearer` is rejected.
    // NOTE(review): RFC 7235 treats auth-scheme names as case-insensitive; this
    // test pins the stricter behavior.
    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // Tokens are compared case-sensitively.
    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // A `Vec` of authorizers accepts any of its credentials and rejects the rest.
    #[tokio::test]
    async fn multiple_basic_auth_vec() {
        let auth1 = Basic::new("user1", "pass1");
        let auth2 = Basic::new("user2", "pass2");
        let auth_vec = vec![auth1, auth2];
        let auth_context = AuthorizeContext::new(auth_vec);
        let service = ValidateRequestHeaderLayer::custom(auth_context).into_layer(service_fn(echo));

        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("user1:pass1")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::OK, response.status());

        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("user2:pass2")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::OK, response.status());

        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("invalid:invalid")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::UNAUTHORIZED, response.status());
    }

    // A fixed-size array of authorizers works as a credential set.
    #[tokio::test]
    async fn multiple_basic_auth_array() {
        let auth1 = Basic::new("user1", "pass1");
        let auth_array = [auth1.clone(), auth1.clone()];
        let auth_context = AuthorizeContext::new(auth_array);
        let service = ValidateRequestHeaderLayer::custom(auth_context).into_layer(service_fn(echo));

        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("user1:pass1")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::OK, response.status());
    }

    // An `Arc`-wrapped authorizer delegates to the inner one.
    #[tokio::test]
    async fn arc_basic_auth() {
        let auth = Basic::new("user", "pass");
        let arc_auth = Arc::new(auth);
        let auth_context = AuthorizeContext::new(arc_auth);
        let service = ValidateRequestHeaderLayer::custom(auth_context).into_layer(service_fn(echo));

        let request = Request::builder()
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("user:pass")),
            )
            .body(Body::default())
            .unwrap();
        let response = service.serve(Context::default(), request).await.unwrap();
        assert_eq!(StatusCode::OK, response.status());
    }

    // With anonymous access enabled, a request without `Authorization` passes.
    #[tokio::test]
    async fn basic_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // Anonymous access does not excuse a present-but-wrong `Authorization` header.
    #[tokio::test]
    async fn basic_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // With anonymous access enabled, a request without `Authorization` passes.
    #[tokio::test]
    async fn bearer_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // Anonymous access does not excuse a present-but-wrong `Authorization` header.
    #[tokio::test]
    async fn bearer_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .into_layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wrong")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    // Inner service: responds with the request body and an implicit `200 OK`.
    async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}