1use std::convert::Infallible;
2use std::task::{Poll, Context};
3use std::future::Future;
4use std::marker::{PhantomData, PhantomPinned};
5use std::pin::{Pin};
6use std::sync::{Arc, Mutex};
7use axum_core::body::BoxBody;
8use axum_core::extract::{FromRequestParts};
9use axum_core::response::{IntoResponse, Response};
10use futures_core::future::BoxFuture;
11use tower_service::Service;
12use http::{Request, StatusCode};
13use http::request::Parts;
14use axum_core::BoxError;
15use futures_core::ready;
16use tower_layer::Layer;
17
/// A value extracted from a request that can be compared against a reference
/// value to decide whether the request may proceed.
pub trait Guard {
    /// Returns `true` when `self` satisfies the requirement described by `expected`.
    fn check_guard(&self, expected: &Self) -> bool;
}
21
22pub trait GuardServiceExt : Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Clone + Send {
23 fn or<Other>(self,other:Other)
24 -> OrGuardService<Self,Other>
25 where
26 Other:Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Clone + Send,
27 <Other as Service<Parts>>::Future: Send,
28 <Self as Service<Parts>>::Future: Send {
29 OrGuardService{
30 left:self,
31 right: other,
32 }
33 }
34 fn and<Other>(self,other:Other)
35 -> AndGuardService<Self,Other>
36 where
37 Other:Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Clone + Send,
38 <Other as Service<Parts>>::Future: Send,
39 <Self as Service<Parts>>::Future: Send {
40 AndGuardService{
41 left:self,
42 right: other,
43 }
44 }
45 fn into_layer<ReqBody>(self) -> GuardLayer<Self,ReqBody> {
46 GuardLayer::with(self)
47 }
48}
49impl<T> GuardServiceExt for T
50 where
51 T: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Clone + Send {}
52
53pub struct GuardLayer<GuardService,ReqBody>
54 where
55 GuardService:Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)>
56 + Send + Clone + 'static {
57 guard_service:GuardService,
58 _marker:PhantomData<ReqBody>
59}
60
61impl<GuardService,ReqBody> Clone for GuardLayer<GuardService,ReqBody>
62 where
63 GuardService:Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)>
64 + Send + Clone + 'static {
65 fn clone(&self) -> Self {
66 Self{
67 guard_service: self.guard_service.clone(),
68 _marker: PhantomData,
69 }
70 }
71}
72impl<GuardService,ReqBody> GuardLayer<GuardService,ReqBody>
73 where
74 GuardService:Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)>
75 + Send + Clone + 'static {
76 pub fn with(guard:GuardService) -> Self {
77 Self{ guard_service:guard, _marker:PhantomData }
78 }
79}
80impl<S,GuardService,ReqBody> Layer<S> for GuardLayer<GuardService,ReqBody>
81 where
82 S:Service<Request<ReqBody>> + Clone,
83 GuardService:Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)>
84 + Send + Clone + 'static {
85 type Service = GuardServiceWrapper<S,GuardService,ReqBody>;
86
87 fn layer(&self, inner: S) -> Self::Service {
88 GuardServiceWrapper{
89 inner,
90 guard_service:self.guard_service.clone(),
91 _marker:PhantomData
92 }
93 }
94}
95pub struct GuardServiceWrapper<S,GuardService,ReqBody>
96 where
97 S:Service<Request<ReqBody>> + Clone,
98 GuardService:Service<Parts, Response=GuardServiceResponse>
99 + Send + Clone + 'static {
100 inner:S,
101 guard_service:GuardService,
102 _marker:PhantomData<ReqBody>
103}
104impl<S,GuardService,ReqBody> Clone for GuardServiceWrapper<S,GuardService,ReqBody>
105 where
106 S:Service<Request<ReqBody>> + Clone,
107 GuardService:Service<Parts, Response=GuardServiceResponse>
108 + Send + Clone + 'static {
109 fn clone(&self) -> Self {
110 Self{
111 inner: self.inner.clone(),
112 guard_service: self.guard_service.clone(),
113 _marker: PhantomData
114 }
115 }
116}
117impl<S,ReqBody,GuardService> Service<Request<ReqBody>> for
118GuardServiceWrapper<S,GuardService,ReqBody>
119 where
120 ReqBody:Send +'static,
121 S:Service<Request<ReqBody>, Response = Response, Error = Infallible> + Send + Clone + 'static,
122 <S as Service<Request<ReqBody>>>::Future: Send,
123 <GuardService as Service<Parts>>::Future: Send,
124 GuardService:Service<Parts, Response=GuardServiceResponse, Error = (StatusCode,String)>
125 + Send + Clone + 'static{
126 type Response = Response;
127 type Error = Infallible;
128 type Future = BoxFuture<'static,Result<Response,Infallible>>;
129
130 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
131 Poll::Ready(Ok(()))
132 }
133
134 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
135 let (parts,body) = req.into_parts();
136 let f = self.guard_service.call(parts);
137 let mut inner = self.inner.clone();
138 Box::pin(async move {
139 match f.await {
140 Ok(GuardServiceResponse(result,parts)) => {
141 if result.0 {
142 inner.call(Request::from_parts(parts,body))
143 .await
144 } else {
145 Ok((StatusCode::UNAUTHORIZED,result.1.unwrap_or_default()).into_response())
146 }
147 },
148 Err(status) => {
149 Ok(status.into_response())
150 }
151 }
152 })
153 }
154}
155impl<State,G> GuardService< State,G>
200 where
201 State:Clone,
202 G: Clone + FromRequestParts<State, Rejection = (StatusCode,String)> + Guard {
203 pub fn new(state:State,expected_guard:G,err_msg:&'static str) -> GuardService< State,G> {
204 Self{ state, expected_guard,err_msg}
205 }
206}
/// Guard service that extracts a `G` from the request parts (using `state`)
/// and compares it with `expected_guard` via [`Guard::check_guard`].
#[derive(Clone)]
pub struct GuardService<State, G>
where
    State: Clone,
    G: Clone,
{
    /// State passed to `G::from_request_parts`.
    state: State,
    /// Reference value the extracted guard must match.
    expected_guard: G,
    /// Message reported when the check fails.
    err_msg: &'static str,
}
216
217
218impl<State,G> Service<Parts> for GuardService<State,G>
219 where
220 State: Sync + Send + Clone + 'static,
221 G: Clone + FromRequestParts<State, Rejection = (StatusCode,String)> + Guard + Sync + Send + 'static, {
222 type Response = GuardServiceResponse;
223 type Error = (StatusCode,String);
224 type Future = BoxFuture<'static,Result<Self::Response,Self::Error>>;
225
226 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
227 Poll::Ready(Ok(()))
228 }
229
230 fn call(&mut self, mut req: Parts) -> Self::Future {
231 let expected = self.expected_guard.clone();
232 let state = self.state.clone();
233 let err_msg = self.err_msg;
234 Box::pin(async move {
235 let result = match G::from_request_parts(&mut req, &state).await {
236 Ok(guard) => {
237 guard.check_guard(&expected)
238 },
239 Err(status) => {
240 return Err(status);
241 }
242 };
243 let err_msg = if !result {
244 Some(String::from(err_msg))
245 } else {None};
246 Ok(GuardServiceResponse((result,err_msg),req))
247 })
248 }
249}
250
251#[derive(Clone)]
252pub struct AndGuardService<S1,S2>
253 where
254 S1: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
255 <S1 as Service<Parts>>::Future: Send,
256 S2: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
257 <S2 as Service<Parts>>::Future: Send,{
258 left:S1,
259 right:S2,
260}
261impl<S1,S2> AndGuardService<S1,S2>
262 where
263 S1: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
264 <S1 as Service<Parts>>::Future: Send,
265 S2: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
266 <S2 as Service<Parts>>::Future: Send, {
267 pub fn new(left:S1,right:S2) -> Self{
268 Self{ left, right }
269 }
270}
271impl<S1,S2> Service<Parts> for AndGuardService<S1,S2>
272 where
273 S1: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
274 <S1 as Service<Parts>>::Future: Send,
275 S2: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
276 <S2 as Service<Parts>>::Future: Send, {
277 type Response = GuardServiceResponse;
278 type Error = (StatusCode,String);
279 type Future = BoxFuture<'static,Result<Self::Response,Self::Error>>;
280
281 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
282 Poll::Ready(Ok(()))
283 }
284
285 fn call(&mut self, parts: Parts) -> Self::Future {
286 let mut left = self.left.clone();
287 let mut right = self.right.clone();
288 Box::pin(async move {
289 let GuardServiceResponse(result,parts) =
290 left.call(parts).await?;
291 if result.0{
292 right.call(parts).await
293 } else {
294 Ok(GuardServiceResponse((false,result.1), parts, ))
295 }
296 })
297
298 }
299}
300#[derive(Clone)]
301pub struct OrGuardService<S1,S2>
302 where
303 S1: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
304 <S1 as Service<Parts>>::Future: Send,
305 S2: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
306 <S2 as Service<Parts>>::Future: Send,{
307 left:S1,
308 right:S2,
309}
310impl<S1,S2> OrGuardService<S1,S2>
311 where
312 S1: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
313 <S1 as Service<Parts>>::Future: Send,
314 S2: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
315 <S2 as Service<Parts>>::Future: Send, {
316 pub fn new(left:S1,right:S2) -> Self{
317 Self{ left, right }
318 }
319}
320impl<S1,S2> Service<Parts> for OrGuardService<S1,S2>
321 where
322 S1: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
323 <S1 as Service<Parts>>::Future: Send,
324 S2: Service<Parts, Response=GuardServiceResponse,Error=(StatusCode,String)> + Send + Clone + 'static,
325 <S2 as Service<Parts>>::Future: Send, {
326 type Response = GuardServiceResponse;
327 type Error = (StatusCode,String);
328 type Future = BoxFuture<'static,Result<Self::Response,Self::Error>>;
329
330 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
331 Poll::Ready(Ok(()))
332 }
333
334 fn call(&mut self, parts: Parts) -> Self::Future {
335 let mut left = self.left.clone();
336 let mut right = self.right.clone();
337 Box::pin(async move {
338 let GuardServiceResponse(result,parts) =
339 left.call(parts).await?;
340 if result.0{
341 Ok(GuardServiceResponse((true,None), parts))
342 } else {
343 right.call(parts).await
344 }
345 })
346
347 }
348}
349
350
351pub struct GuardServiceResponse( (bool,Option<String>), Parts);
352
353
354
#[cfg(test)]
pub mod tests {
    use super::*;

    use axum::body::Body;
    use axum::middleware::Next;
    use axum::routing::get;
    use axum::Router;
    use axum_core::extract::FromRequestParts;
    use http::{HeaderValue, Request, StatusCode};
    use tokio::time::{sleep, Duration};
    use tower::util::ServiceExt;

    /// Test guard value. As *state* it names the header to read; as the
    /// extracted or expected guard it holds the header's value.
    #[derive(Clone, Debug, PartialEq)]
    pub struct ArbitraryData {
        data: String,
    }

    impl Guard for ArbitraryData {
        fn check_guard(&self, expected: &Self) -> bool {
            self == expected
        }
    }

    // Reads the header named by `state.data`. A missing header, or one that
    // cannot be read as a string, is rejected with `500 Internal Server Error`.
    #[async_trait::async_trait]
    impl FromRequestParts<ArbitraryData> for ArbitraryData {
        type Rejection = (StatusCode, String);

        async fn from_request_parts(
            parts: &mut Parts,
            state: &ArbitraryData,
        ) -> Result<Self, Self::Rejection> {
            Ok(Self {
                data: parts
                    .headers
                    .get(state.data.clone())
                    .ok_or((StatusCode::INTERNAL_SERVER_ERROR, "error".into()))?
                    .to_str()
                    .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "error".into()))?
                    .to_string(),
            })
        }
    }

    async fn ok() -> StatusCode {
        StatusCode::OK
    }

    #[tokio::test]
    async fn test_guard_service_ok() {
        let (mut parts, _) = Request::new(()).into_parts();
        parts
            .headers
            .insert("data", HeaderValue::from_static("other_data"));
        assert_eq!(
            GuardService::new(
                ArbitraryData { data: "data".into() },
                ArbitraryData { data: "other_data".into() },
                "err",
            )
            .call(parts)
            .await
            .unwrap()
            .0,
            (true, None)
        );
    }

    #[tokio::test]
    async fn test_guard_service_not_from_request_parts_error() {
        let (mut parts, _) = Request::new(()).into_parts();
        parts
            .headers
            .insert("BIG-UH-OH", HeaderValue::from_static("other_data"));
        let result = GuardService::new(
            ArbitraryData { data: "data".into() },
            ArbitraryData { data: "other_data".into() },
            "err",
        )
        .call(parts)
        .await;
        assert_eq!(
            result.err().map(|err| err.0),
            Some(StatusCode::INTERNAL_SERVER_ERROR)
        );
    }

    #[tokio::test]
    async fn test_guard_service_expected_failed() {
        let (mut parts, _) = Request::new(()).into_parts();
        parts
            .headers
            .insert("data", HeaderValue::from_static("other_data"));
        assert_eq!(
            GuardService::new(
                ArbitraryData { data: "data".into() },
                ArbitraryData { data: "NOT OTHER DATA MY BAD".into() },
                "err",
            )
            .call(parts)
            .await
            .unwrap()
            .0,
            (false, Some("err".into()))
        );
    }

    #[tokio::test]
    async fn test_and() {
        let (mut parts, _) = Request::new(()).into_parts();
        parts.headers.insert("data", HeaderValue::from_static("data"));
        parts
            .headers
            .insert("other_data", HeaderValue::from_static("other_data"));
        let data = ArbitraryData { data: "data".into() };
        // Must name the "other_data" header, otherwise the second guard just
        // re-checks the "data" header and the test proves nothing extra.
        let other_data = ArbitraryData { data: "other_data".into() };
        assert!(
            AndGuardService::new(
                GuardService::new(data.clone(), data.clone(), "err"),
                GuardService::new(other_data.clone(), other_data.clone(), "err"),
            )
            .call(parts)
            .await
            .unwrap()
            .0
            .0
        )
    }

    #[tokio::test]
    async fn test_or() {
        let (mut parts, _) = Request::new(()).into_parts();
        parts.headers.insert("data", HeaderValue::from_static("data"));
        parts
            .headers
            .insert("other_data", HeaderValue::from_static("NUH UH BRUH NUHHHHH"));
        let data = ArbitraryData { data: "data".into() };
        let other_data = ArbitraryData { data: "other_data".into() };
        assert!(
            OrGuardService::new(
                GuardService::new(data.clone(), data.clone(), "err"),
                GuardService::new(other_data.clone(), other_data.clone(), "err"),
            )
            .call(parts)
            .await
            .unwrap()
            .0
            .0
        )
    }

    #[tokio::test]
    async fn test_or_falls_through_to_right() {
        let (mut parts, _) = Request::new(()).into_parts();
        parts
            .headers
            .insert("data", HeaderValue::from_static("not the data"));
        parts
            .headers
            .insert("other_data", HeaderValue::from_static("other_data"));
        let data = ArbitraryData { data: "data".into() };
        let other_data = ArbitraryData { data: "other_data".into() };
        assert!(
            OrGuardService::new(
                GuardService::new(data.clone(), data.clone(), "err"),
                GuardService::new(other_data.clone(), other_data.clone(), "err"),
            )
            .call(parts)
            .await
            .unwrap()
            .0
            .0
        )
    }

    #[tokio::test]
    async fn test_and_deep() {
        let (mut parts, _) = Request::new(()).into_parts();
        parts.headers.insert("data", HeaderValue::from_static("data"));
        let data = ArbitraryData { data: "data".into() };
        assert!(
            AndGuardService::new(
                AndGuardService::new(
                    AndGuardService::new(
                        GuardService::new(data.clone(), data.clone(), "err"),
                        GuardService::new(data.clone(), data.clone(), "err"),
                    ),
                    AndGuardService::new(
                        GuardService::new(data.clone(), data.clone(), "err"),
                        GuardService::new(data.clone(), data.clone(), "err"),
                    ),
                ),
                AndGuardService::new(
                    AndGuardService::new(
                        GuardService::new(data.clone(), data.clone(), "err"),
                        GuardService::new(data.clone(), data.clone(), "err"),
                    ),
                    AndGuardService::new(
                        GuardService::new(data.clone(), data.clone(), "err"),
                        GuardService::new(data.clone(), data.clone(), "err"),
                    ),
                ),
            )
            .call(parts)
            .await
            .unwrap()
            .0
            .0
        )
    }

    #[tokio::test]
    async fn test_layer() {
        let req = Request::builder()
            .header("data", "data")
            .body(BoxBody::default())
            .unwrap();

        let data = ArbitraryData { data: "data".into() };
        let app = Router::new()
            .route("/", get(ok))
            .layer(GuardLayer::with(GuardService::new(
                data.clone(),
                data.clone(),
                "err",
            )));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK)
    }

    #[tokio::test]
    async fn test_layer_or_ok() {
        let req = Request::builder()
            .header("data", "data")
            .header("not_data", "wazzup i'm a criminal")
            .body(BoxBody::default())
            .unwrap();

        let data = ArbitraryData { data: "data".into() };
        let app = Router::new().route("/", get(ok)).layer(GuardLayer::with(
            GuardService::new(data.clone(), data.clone(), "err").or(GuardService::new(
                ArbitraryData { data: "not_data".into() },
                data,
                "err",
            )),
        ));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK)
    }

    #[tokio::test]
    async fn test_layer_or_not_ok() {
        let req = Request::builder()
            .header("data", "wazzup i'm a criminal")
            .header("not_data", "wazzup i'm a criminal")
            .body(BoxBody::default())
            .unwrap();

        let data = ArbitraryData { data: "data".into() };
        let app = Router::new().route("/", get(ok)).layer(GuardLayer::with(
            GuardService::new(data.clone(), data.clone(), "err").or(GuardService::new(
                ArbitraryData { data: "not_data".into() },
                data,
                "err",
            )),
        ));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED)
    }

    #[tokio::test]
    async fn test_layer_and_ok() {
        let req = Request::builder()
            .header("data", "data")
            .header("not_data", "not_data")
            .body(BoxBody::default())
            .unwrap();

        let data = ArbitraryData { data: "data".into() };
        let definitely_still_data = ArbitraryData { data: "not_data".into() };
        let app = Router::new().route("/", get(ok)).layer(GuardLayer::with(
            GuardService::new(data.clone(), data.clone(), "err").and(GuardService::new(
                definitely_still_data.clone(),
                definitely_still_data,
                "err",
            )),
        ));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK)
    }

    #[tokio::test]
    async fn test_layer_and_not_ok() {
        let req = Request::builder()
            .header("data", "data")
            .header("not_data", "GRRR I SO HACK U")
            .body(BoxBody::default())
            .unwrap();

        let data = ArbitraryData { data: "data".into() };
        let definitely_still_data = ArbitraryData { data: "not_data".into() };
        let app = Router::new().route("/", get(ok)).layer(GuardLayer::with(
            GuardService::new(data.clone(), data.clone(), "err").and(GuardService::new(
                definitely_still_data.clone(),
                definitely_still_data,
                "err",
            )),
        ));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED)
    }

    #[tokio::test]
    async fn test_deep_layer_ok() {
        let req = Request::builder()
            .header("1", "1")
            .header("2", "2")
            .header("3", "3")
            .header("4", "4")
            .header("5", "5")
            .header("6", "6")
            .header("7", "8")
            .body(BoxBody::default())
            .unwrap();

        let one = ArbitraryData { data: "1".into() };
        let two = ArbitraryData { data: "2".into() };
        let three = ArbitraryData { data: "3".into() };
        let four = ArbitraryData { data: "4".into() };
        let five = ArbitraryData { data: "5".into() };
        let six = ArbitraryData { data: "6".into() };
        let seven = ArbitraryData { data: "7".into() };
        let bad = ArbitraryData { data: "bad".into() };

        // one AND (two OR (three OR four)) AND five AND six, OR seven
        let app = Router::new().route("/", get(ok)).layer(GuardLayer::with(
            GuardService::new(one.clone(), one.clone(), "err")
                .and(
                    GuardService::new(two.clone(), bad.clone(), "err").or(GuardService::new(
                        three.clone(),
                        bad.clone(),
                        "err",
                    )
                    .or(GuardService::new(four.clone(), four.clone(), "err"))),
                )
                .and(GuardService::new(five.clone(), five.clone(), "err"))
                .and(GuardService::new(six.clone(), six.clone(), "err"))
                .or(GuardService::new(seven.clone(), bad.clone(), "err")),
        ));
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK)
    }

    async fn time_time<B>(req: Request<B>, next: Next<B>) -> Result<Response, StatusCode> {
        sleep(Duration::from_millis(10)).await;
        Ok(next.run(req).await)
    }

    #[tokio::test]
    async fn test_into_layer() {
        let req = Request::builder()
            .header("data", "data")
            .body(BoxBody::default())
            .unwrap();

        let data = ArbitraryData { data: "data".into() };
        let app = Router::new()
            .route("/", get(ok))
            .layer(GuardService::new(data.clone(), data.clone(), "err").into_layer());
        let resp = app.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK)
    }

    #[tokio::test]
    async fn test_happy_with_layered_polls() {
        let app = Router::new()
            .route("/", get(ok))
            .layer(axum::middleware::from_fn(time_time))
            .layer(GuardLayer::with(GuardService::new(
                ArbitraryData { data: "x".into() },
                ArbitraryData { data: "x".into() },
                "err",
            )))
            .layer(axum::middleware::from_fn(time_time))
            .layer(tower_http::timeout::TimeoutLayer::new(Duration::from_secs(1)));
        let response = app
            .oneshot(
                Request::builder()
                    .header("x", "x")
                    .uri("/")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }
}