1use base64::Engine as _;
56use std::{fmt, marker::PhantomData};
57
58use crate::layer::validate_request::{
59 ValidateRequest, ValidateRequestHeader, ValidateRequestHeaderLayer,
60};
61use crate::{
62 header::{self, HeaderValue},
63 Request, Response, StatusCode,
64};
65use rama_core::Context;
66
67use rama_net::user::UserId;
68
69const BASE64: base64::engine::GeneralPurpose = base64::engine::general_purpose::STANDARD;
70
71impl<C> ValidateRequestHeaderLayer<AuthorizeContext<C>> {
72 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
74 self.validate.allow_anonymous = allow_anonymous;
75 self
76 }
77
78 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
80 self.validate.allow_anonymous = allow_anonymous;
81 self
82 }
83}
84
85impl<S, C> ValidateRequestHeader<S, AuthorizeContext<C>> {
86 pub fn set_allow_anonymous(&mut self, allow_anonymous: bool) -> &mut Self {
88 self.validate.allow_anonymous = allow_anonymous;
89 self
90 }
91
92 pub fn with_allow_anonymous(mut self, allow_anonymous: bool) -> Self {
94 self.validate.allow_anonymous = allow_anonymous;
95 self
96 }
97}
98
99impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Basic<ResBody>>> {
100 pub fn basic(inner: S, username: &str, value: &str) -> Self
108 where
109 ResBody: Default,
110 {
111 Self::custom(inner, AuthorizeContext::new(Basic::new(username, value)))
112 }
113}
114
115impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Basic<ResBody>>> {
116 pub fn basic(username: &str, password: &str) -> Self
124 where
125 ResBody: Default,
126 {
127 Self::custom(AuthorizeContext::new(Basic::new(username, password)))
128 }
129}
130
131impl<S, ResBody> ValidateRequestHeader<S, AuthorizeContext<Bearer<ResBody>>> {
132 pub fn bearer(inner: S, token: &str) -> Self
140 where
141 ResBody: Default,
142 {
143 Self::custom(inner, AuthorizeContext::new(Bearer::new(token)))
144 }
145}
146
147impl<ResBody> ValidateRequestHeaderLayer<AuthorizeContext<Bearer<ResBody>>> {
148 pub fn bearer(token: &str) -> Self
156 where
157 ResBody: Default,
158 {
159 Self::custom(AuthorizeContext::new(Bearer::new(token)))
160 }
161}
162
163pub struct Bearer<ResBody> {
167 header_value: HeaderValue,
168 _ty: PhantomData<fn() -> ResBody>,
169}
170
171impl<ResBody> Bearer<ResBody> {
172 fn new(token: &str) -> Self
173 where
174 ResBody: Default,
175 {
176 Self {
177 header_value: format!("Bearer {}", token)
178 .parse()
179 .expect("token is not a valid header value"),
180 _ty: PhantomData,
181 }
182 }
183}
184
185impl<ResBody> Clone for Bearer<ResBody> {
186 fn clone(&self) -> Self {
187 Self {
188 header_value: self.header_value.clone(),
189 _ty: PhantomData,
190 }
191 }
192}
193
194impl<ResBody> fmt::Debug for Bearer<ResBody> {
195 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
196 f.debug_struct("Bearer")
197 .field("header_value", &self.header_value)
198 .finish()
199 }
200}
201
202impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Bearer<ResBody>>
203where
204 ResBody: Default + Send + 'static,
205 B: Send + 'static,
206 S: Clone + Send + Sync + 'static,
207{
208 type ResponseBody = ResBody;
209
210 async fn validate(
211 &self,
212 ctx: Context<S>,
213 request: Request<B>,
214 ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
215 match request.headers().get(header::AUTHORIZATION) {
216 Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
217 None if self.allow_anonymous => {
218 let mut ctx = ctx;
219 ctx.insert(UserId::Anonymous);
220 Ok((ctx, request))
221 }
222 _ => {
223 let mut res = Response::new(ResBody::default());
224 *res.status_mut() = StatusCode::UNAUTHORIZED;
225 Err(res)
226 }
227 }
228 }
229}
230
231pub struct Basic<ResBody> {
235 header_value: HeaderValue,
236 _ty: PhantomData<fn() -> ResBody>,
237}
238
239impl<ResBody> Basic<ResBody> {
240 fn new(username: &str, password: &str) -> Self
241 where
242 ResBody: Default,
243 {
244 let encoded = BASE64.encode(format!("{}:{}", username, password));
245 let header_value = format!("Basic {}", encoded).parse().unwrap();
246 Self {
247 header_value,
248 _ty: PhantomData,
249 }
250 }
251}
252
253impl<ResBody> Clone for Basic<ResBody> {
254 fn clone(&self) -> Self {
255 Self {
256 header_value: self.header_value.clone(),
257 _ty: PhantomData,
258 }
259 }
260}
261
262impl<ResBody> fmt::Debug for Basic<ResBody> {
263 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
264 f.debug_struct("Basic")
265 .field("header_value", &self.header_value)
266 .finish()
267 }
268}
269
270impl<S, B, ResBody> ValidateRequest<S, B> for AuthorizeContext<Basic<ResBody>>
271where
272 ResBody: Default + Send + 'static,
273 B: Send + 'static,
274 S: Clone + Send + Sync + 'static,
275{
276 type ResponseBody = ResBody;
277
278 async fn validate(
279 &self,
280 ctx: Context<S>,
281 request: Request<B>,
282 ) -> Result<(Context<S>, Request<B>), Response<Self::ResponseBody>> {
283 match request.headers().get(header::AUTHORIZATION) {
284 Some(actual) if actual == self.credential.header_value => Ok((ctx, request)),
285 None if self.allow_anonymous => {
286 let mut ctx = ctx;
287 ctx.insert(UserId::Anonymous);
288 Ok((ctx, request))
289 }
290 _ => {
291 let mut res = Response::new(ResBody::default());
292 *res.status_mut() = StatusCode::UNAUTHORIZED;
293 res.headers_mut()
294 .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
295 Err(res)
296 }
297 }
298 }
299}
300
/// Validator state shared by the Basic and Bearer schemes: the expected
/// credential plus the anonymous-access switch.
pub struct AuthorizeContext<C> {
    // Expected credential (e.g. `Basic<_>` or `Bearer<_>`).
    credential: C,
    // When true, requests with no `Authorization` header are let through as
    // anonymous; `false` unless changed via `set_allow_anonymous` /
    // `with_allow_anonymous`.
    allow_anonymous: bool,
}
305
306impl<C> AuthorizeContext<C> {
307 pub(crate) fn new(credential: C) -> Self {
308 Self {
309 credential,
310 allow_anonymous: false,
311 }
312 }
313}
314
315impl<C: Clone> Clone for AuthorizeContext<C> {
316 fn clone(&self) -> Self {
317 Self {
318 credential: self.credential.clone(),
319 allow_anonymous: self.allow_anonymous,
320 }
321 }
322}
323
324impl<C: fmt::Debug> fmt::Debug for AuthorizeContext<C> {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 f.debug_struct("AuthorizeContext")
327 .field("credential", &self.credential)
328 .field("allow_anonymous", &self.allow_anonymous)
329 .finish()
330 }
331}
332
#[cfg(test)]
mod tests {
    // Kept as a glob so every item of the parent module is in scope; the
    // allow silences the warning when some re-exported names go unused.
    #[allow(unused_imports)]
    use super::*;

    use crate::layer::validate_request::ValidateRequestHeaderLayer;
    use crate::{header, Body};
    use rama_core::error::BoxError;
    use rama_core::service::service_fn;
    use rama_core::{Context, Layer, Service};

    #[tokio::test]
    async fn valid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_basic_token() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn basic_rejects_missing_header_by_default() {
        // Anonymous access is off unless explicitly enabled.
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);

        let www_authenticate = res.headers().get(header::WWW_AUTHENTICATE).unwrap();
        assert_eq!(www_authenticate, "Basic");
    }

    #[tokio::test]
    async fn valid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("basic {}", BASE64.encode("foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn basic_auth_is_case_sensitive_in_value() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("Foo:bar")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn invalid_bearer_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wat")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_rejects_missing_header_by_default() {
        // Anonymous access is off unless explicitly enabled.
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_prefix() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "bearer foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_token_is_case_sensitive_in_token() {
        let service = ValidateRequestHeaderLayer::bearer("foobar").layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer Foobar")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    /// Inner service used by every test: answers 200 OK, echoing the body.
    async fn echo<Body>(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }

    #[tokio::test]
    async fn basic_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn basic_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::basic("foo", "bar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/")
            .header(
                header::AUTHORIZATION,
                format!("Basic {}", BASE64.encode("wrong:credentials")),
            )
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn bearer_allows_anonymous_if_header_is_missing() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/").body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn bearer_fails_if_allow_anonymous_and_credentials_are_invalid() {
        let service = ValidateRequestHeaderLayer::bearer("foobar")
            .with_allow_anonymous(true)
            .layer(service_fn(echo));

        let request = Request::get("/")
            .header(header::AUTHORIZATION, "Bearer wrong")
            .body(Body::empty())
            .unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();

        assert_eq!(res.status(), StatusCode::UNAUTHORIZED);
    }
}