1#![no_std]
91extern crate no_std_compat as std;
92
93use std::prelude::v1::*;
94
95use core::mem;
96use std::collections::BTreeMap;
97
98pub use product_os_http::{
99 Request, Response,
100 request::Parts as RequestParts,
101 response::Parts as ResponseParts,
102 StatusCode,
103 header::{ HeaderMap, HeaderName, HeaderValue }
104};
105
106pub use product_os_http_body::{ BodyBytes, Bytes };
107
108#[cfg(feature = "middleware")]
109pub use product_os_http_body::*;
110
111pub use axum::{
112 routing::{get, post, put, patch, delete, head, trace, options, any, MethodRouter},
113 middleware::{from_fn, from_fn_with_state, Next},
114 handler::Handler,
115 response::{ IntoResponse, Redirect },
116 body::{ HttpBody, Body },
117 Router,
118 Json,
119 Form,
120 extract::{
121 State, Path, Query, FromRef, FromRequest, FromRequestParts,
122 ws::{ WebSocketUpgrade, WebSocket, Message }
123 },
124 BoxError,
125 http::Uri
126};
127
128use axum::routing::Route;
130
131
132pub use crate::extractors::RequestMethod;
133
134
135
136#[cfg(feature = "debug")]
137pub use axum_macros::debug_handler;
138
139pub use tower::{
140 Layer, Service, ServiceBuilder, ServiceExt, MakeService,
141 make::AsService, make::MakeConnection, make::IntoService,
142 make::Shared, make::future::SharedFuture,
143 util::service_fn, util::ServiceFn
144};
145
146pub use product_os_request::Method;
147use regex::Regex;
148
149
150#[cfg(feature = "middleware")]
151pub mod middleware;
152
153mod extractors;
154mod default_headers;
155mod service_wrapper;
156pub mod dual_protocol;
157
158pub use crate::default_headers::DefaultHeadersLayer;
159pub use crate::dual_protocol::{UpgradeHttpLayer, Protocol};
160
161#[cfg(feature = "middleware")]
162pub use crate::middleware::*;
163
164#[cfg(feature = "cors")]
165use tower_http::cors::{CorsLayer, Any};
166
167use crate::service_wrapper::{WrapperLayer, WrapperService};
168
169
/// Failure categories the router can report back to a client.
///
/// Every variant wraps a human-readable message describing what went wrong.
#[derive(Debug)]
pub enum ProductOSRouterError {
    /// Failure related to the request headers.
    Headers(String),
    /// Failure related to the query string.
    Query(String),
    /// Failure related to the request body.
    Body(String),
    /// The caller could not be authenticated.
    Authentication(String),
    /// The caller is not permitted to perform the operation.
    Authorization(String),
    /// Failure while processing the request.
    Process(String),
    /// The requested service is currently unavailable.
    Unavailable(String),
}
201
202impl ProductOSRouterError {
203 fn message(&self) -> &str {
205 match self {
206 ProductOSRouterError::Headers(m) => m,
207 ProductOSRouterError::Query(m) => m,
208 ProductOSRouterError::Body(m) => m,
209 ProductOSRouterError::Authentication(m) => m,
210 ProductOSRouterError::Authorization(m) => m,
211 ProductOSRouterError::Process(m) => m,
212 ProductOSRouterError::Unavailable(m) => m,
213 }
214 }
215
216 fn status_code(&self) -> StatusCode {
218 match self {
219 ProductOSRouterError::Headers(_) => StatusCode::BAD_REQUEST,
220 ProductOSRouterError::Query(_) => StatusCode::BAD_REQUEST,
221 ProductOSRouterError::Body(_) => StatusCode::BAD_REQUEST,
222 ProductOSRouterError::Authentication(_) => StatusCode::UNAUTHORIZED,
223 ProductOSRouterError::Authorization(_) => StatusCode::FORBIDDEN,
224 ProductOSRouterError::Process(_) => StatusCode::GATEWAY_TIMEOUT,
225 ProductOSRouterError::Unavailable(_) => StatusCode::SERVICE_UNAVAILABLE,
226 }
227 }
228
229 fn build_error_response(message: &str, status_code: &StatusCode) -> Response<Body> {
234 let safe_message = message.replace("\"", "'");
235 let status_code_u16 = status_code.as_u16();
236
237 Response::builder()
238 .status(status_code_u16)
239 .header("content-type", "application/json")
240 .body(Body::from(format!("{{\"error\": \"{safe_message}\"}}")))
241 .unwrap_or_else(|_| {
242 Response::builder()
243 .status(StatusCode::INTERNAL_SERVER_ERROR.as_u16())
244 .body(Body::from("{\"error\": \"internal error\"}"))
245 .expect("fallback response must be valid")
246 })
247 }
248
249 pub fn error_response(error: ProductOSRouterError, status_code: &StatusCode) -> Response<Body> {
272 Self::build_error_response(error.message(), status_code)
273 }
274
275 pub fn handle_error(error: &ProductOSRouterError) -> Response<Body> {
297 Self::build_error_response(error.message(), &error.status_code())
298 }
299}
300
301
302pub struct ProductOSRouter<S = ()> {
346 router: Router<S>,
347 state: S
348}
349
350impl<S: Clone + Send + Sync + 'static> ProductOSRouter<S> {
351 fn take_router(&mut self) -> Router<S> {
356 mem::replace(&mut self.router, Router::new())
357 }
358}
359
360impl ProductOSRouter<()> {
361 pub fn new() -> Self {
374 ProductOSRouter::new_with_state(())
375 }
376}
377
378impl Default for ProductOSRouter<()> {
379 fn default() -> Self {
380 Self::new()
381 }
382}
383
384impl ProductOSRouter<()> {
385 pub fn param_to_field(path: &str) -> String {
406 use std::sync::OnceLock;
407
408 static PARAM_MATCHER: OnceLock<Regex> = OnceLock::new();
409 static SYMBOL_MATCHER: OnceLock<Regex> = OnceLock::new();
410
411 let param_matcher = PARAM_MATCHER.get_or_init(|| {
412 Regex::new("([:*])([a-zA-Z][-_a-zA-Z0-9]*)").unwrap()
413 });
414 let symbol_matcher = SYMBOL_MATCHER.get_or_init(|| {
415 Regex::new("[*?&]").unwrap()
416 });
417
418 let mut path_new = path.to_owned();
419 for (value, [_prefix, variable_name]) in param_matcher.captures_iter(path).map(|c| c.extract()) {
420 let mut value_sanitized = value.to_owned();
421 for (prefix, []) in symbol_matcher.captures_iter(value).map(|c| c.extract()) {
422 let mut symbol = String::from("[");
423 symbol.push(prefix.chars().next().unwrap());
424 symbol.push(']');
425
426 value_sanitized = value_sanitized.replace(prefix, symbol.as_str());
427 }
428
429 let exact_matcher = Regex::new(value_sanitized.as_str()).unwrap();
430
431 let mut variable = String::from("{");
432 variable.push_str(variable_name);
433 variable.push('}');
434
435 path_new = exact_matcher.replace(path_new.as_str(), variable).to_string()
436 }
437
438 path_new
439 }
440
441 pub fn add_service_no_state<Z>(&mut self, path: &str, service: Z)
445 where
446 Z: Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Clone + Send + Sync + 'static,
447 Z::Future: Send + 'static
448 {
449 let wrapper = WrapperService::new(service);
450 let path = ProductOSRouter::param_to_field(path);
451 self.router = self.take_router().route_service(path.as_str(), wrapper);
452 }
453
454 pub fn add_middleware_no_state<L, ResBody>(&mut self, middleware: L)
458 where
459 L: Layer<Route> + Clone + Send + Sync + 'static,
460 L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
461 ResBody: HttpBody<Data = Bytes> + Send + 'static,
462 ResBody::Error: Into<BoxError>,
463 <L::Service as Service<Request<Body>>>::Future: Send + 'static,
464 <L::Service as Service<Request<Body>>>::Error: Into<BoxError> + 'static,
465 {
466 let wrapper: WrapperLayer<L, Route, ResBody> = WrapperLayer::new(middleware);
467 self.router = self.take_router().layer(wrapper)
468 }
469
470 pub fn add_default_header_no_state(&mut self, header_name: &str, header_value: &str) {
474 let mut default_headers = HeaderMap::new();
475 default_headers.insert(HeaderName::from_bytes(header_name.as_bytes()).unwrap(), HeaderValue::from_str(header_value).unwrap());
476
477 self.add_middleware_no_state(DefaultHeadersLayer::new(default_headers));
478 }
479
480 pub fn not_implemented() -> Response<Body> {
485 Response::builder()
486 .status(StatusCode::NOT_IMPLEMENTED.as_u16())
487 .header("content-type", "application/json")
488 .body(Body::from("{}"))
489 .unwrap()
490 }
491
492 #[deprecated(since = "0.0.41", note = "Use not_implemented() which correctly returns 501, or build a 404 response explicitly")]
498 pub fn not_implemented_legacy() -> Response<Body> {
499 Response::builder()
500 .status(StatusCode::NOT_FOUND.as_u16())
501 .header("content-type", "application/json")
502 .body(Body::from("{}"))
503 .unwrap()
504 }
505
506
507
508 pub fn add_route_no_state(&mut self, path: &str, service_handler: MethodRouter)
510 {
511 let service= ServiceBuilder::new().service(service_handler);
512 let path = ProductOSRouter::param_to_field(path);
513 self.router = self.take_router().route(path.as_str(), service);
514 }
515
516 pub fn set_fallback_no_state(&mut self, service_handler: MethodRouter)
518 {
519 let service= ServiceBuilder::new().service(service_handler);
520 self.router = self.take_router().fallback(service);
521 }
522
523 pub fn add_get_no_state<H, T>(&mut self, path: &str, handler: H)
543 where
544 H: Handler<T, ()>,
545 T: 'static
546 {
547 let method_router = ProductOSRouter::convert_handler_to_method_router(Method::GET, handler, ());
548 let path = ProductOSRouter::param_to_field(path);
549 self.router = self.take_router().route(path.as_str(), method_router);
550 }
551
552 pub fn add_post_no_state<H, T>(&mut self, path: &str, handler: H)
578 where
579 H: Handler<T, ()>,
580 T: 'static
581 {
582 let method_router = ProductOSRouter::convert_handler_to_method_router(Method::POST, handler, ());
583 let path = ProductOSRouter::param_to_field(path);
584 self.router = self.take_router().route(path.as_str(), method_router);
585 }
586
587 pub fn add_handler_no_state<H, T>(&mut self, path: &str, method: Method, handler: H)
589 where
590 H: Handler<T, ()>,
591 T: 'static
592 {
593 let method_router = ProductOSRouter::convert_handler_to_method_router(method, handler, ());
594 let path = ProductOSRouter::param_to_field(path);
595 self.router = self.take_router().route(path.as_str(), method_router);
596 }
597
598 pub fn set_fallback_handler_no_state<H, T>(&mut self, handler: H)
600 where
601 H: Handler<T, ()>,
602 T: 'static
603 {
604 self.router = self.take_router().fallback(handler).with_state(());
605 }
606
607 #[cfg(feature = "cors")]
609 pub fn add_cors_handler_no_state<H, T>(&mut self, path: &str, method: Method, handler: H)
610 where
611 H: Handler<T, ()>,
612 T: 'static
613 {
614 let method_router = ProductOSRouter::add_cors_method_router(method, handler, ());
615 let path = ProductOSRouter::param_to_field(path);
616 self.router = self.take_router().route(path.as_str(), method_router);
617 }
618
619 #[cfg(feature = "ws")]
621 pub fn add_ws_handler_no_state<H, T>(&mut self, path: &str, ws_handler: H)
622 where
623 H: Handler<T, ()>,
624 T: 'static
625 {
626 let service: MethodRouter = ProductOSRouter::convert_handler_to_method_router(Method::GET, ws_handler, ());
628 let path = ProductOSRouter::param_to_field(path);
629 self.router = self.take_router().route(path.as_str(), service);
630 }
631
632 #[cfg(feature = "sse")]
634 pub fn add_sse_handler_no_state<H, T>(&mut self, path: &str, sse_handler: H)
635 where
636 H: Handler<T, ()>,
637 T: 'static
638 {
639 let service: MethodRouter = ProductOSRouter::convert_handler_to_method_router(Method::GET, sse_handler, ());
641 let path = ProductOSRouter::param_to_field(path);
642 self.router = self.take_router().route(path.as_str(), service);
643 }
644
645 pub fn add_handlers_no_state<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
647 where
648 H: Handler<T, ()>,
649 T: 'static
650 {
651 let mut method_router: MethodRouter = MethodRouter::new();
652
653 for (method, handler) in handlers {
654 method_router = ProductOSRouter::add_handler_to_method_router(method_router, method, handler);
655 }
656
657 let path = ProductOSRouter::param_to_field(path);
658 self.router = self.take_router().route(path.as_str(), method_router);
659 }
660
661 #[cfg(feature = "cors")]
668 pub fn add_cors_middleware_no_state(&mut self) {
669 self.add_middleware_no_state(
670 CorsLayer::new()
671 .allow_origin(Any)
672 .allow_methods(Any)
673 .allow_headers(Any),
674 );
675 }
676
677 #[cfg(feature = "cors")]
679 pub fn add_cors_handlers_no_state<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
680 where
681 H: Handler<T, ()>,
682 T: 'static
683 {
684 let mut service: MethodRouter = MethodRouter::new();
685
686 for (method, handler) in handlers {
687 service = ProductOSRouter::add_cors_handler_to_method_router(service, method, handler);
688 }
689
690 let path = ProductOSRouter::param_to_field(path);
691 self.router = self.take_router().route(path.as_str(), service);
692 }
693
694}
695
696impl<S> ProductOSRouter<S>
697where
698 S: Clone + Send + Sync + 'static
699{
700 pub fn get_router(&self) -> Router {
704 self.router.clone().with_state(self.state.clone())
705 }
706
707 pub fn get_router_mut(&mut self) -> &mut Router<S> {
709 &mut self.router
710 }
711
712 pub fn new_with_state(state: S) -> Self
716 where
717 S: Clone + Send + 'static
718 {
719 Self {
720 router: Router::new().with_state(state.clone()),
721 state
722 }
723 }
724
725 pub fn add_service<Z>(&mut self, path: &str, service: Z)
727 where
728 Z: Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Clone + Send + Sync + 'static,
729 Z::Response: IntoResponse,
730 Z::Future: Send + 'static
731 {
732 let wrapper = WrapperService::new(service);
733 let path = ProductOSRouter::param_to_field(path);
734 self.router = self.take_router().route_service(path.as_str(), wrapper);
735 }
736
737 pub fn add_middleware<L, ResBody>(&mut self, middleware: L)
739 where
740 L: Layer<Route> + Clone + Send + Sync + 'static,
741 L::Service: Service<Request<Body>, Response = Response<ResBody>> + Clone + Send + Sync + 'static,
742 ResBody: HttpBody<Data = Bytes> + Send + 'static,
743 ResBody::Error: Into<BoxError>,
744 <L::Service as Service<Request<Body>>>::Future: Send + 'static,
745 <L::Service as Service<Request<Body>>>::Error: Into<BoxError> + 'static,
746 {
747 let wrapper: WrapperLayer<L, Route, ResBody> = WrapperLayer::new(middleware);
748 self.router = self.take_router().layer(wrapper)
749 }
750
751 pub fn add_default_header(&mut self, header_name: &str, header_value: &str) {
760 let mut default_headers = HeaderMap::new();
761 default_headers.insert(HeaderName::from_bytes(header_name.as_bytes()).unwrap(), HeaderValue::from_str(header_value).unwrap());
762
763 self.add_middleware(DefaultHeadersLayer::new(default_headers));
764 }
765
766 #[deprecated(since = "0.0.41", note = "Use add_default_header(&str, &str) instead")]
771 pub fn add_default_header_owned(&mut self, header_name: String, header_value: String) {
772 self.add_default_header(header_name.as_str(), header_value.as_str());
773 }
774
775
776
777 pub fn add_route(&mut self, path: &str, service_handler: MethodRouter<S>)
779 where
780 S: Clone + Send + 'static,
781 {
782 let service= ServiceBuilder::new().service(service_handler);
783 let path = ProductOSRouter::param_to_field(path);
784 self.router = self.take_router().route(path.as_str(), service);
785 }
786
787 pub fn set_fallback(&mut self, service_handler: MethodRouter<S>)
789 where
790 S: Clone + Send + 'static
791 {
792 let service= ServiceBuilder::new().service(service_handler);
793 self.router = self.take_router().fallback(service);
794 }
795
796 pub fn add_get<H, T>(&mut self, path: &str, handler: H)
798 where
799 H: Handler<T, S>,
800 T: 'static
801 {
802 let method_router = ProductOSRouter::convert_handler_to_method_router(Method::GET, handler, self.state.clone());
803 let path = ProductOSRouter::param_to_field(path);
804 self.router = self.take_router().route(path.as_str(), method_router);
805 }
806
807 pub fn add_post<H, T>(&mut self, path: &str, handler: H)
809 where
810 H: Handler<T, S>,
811 T: 'static
812 {
813 let method_router = ProductOSRouter::convert_handler_to_method_router(Method::POST, handler, self.state.clone());
814 let path = ProductOSRouter::param_to_field(path);
815 self.router = self.take_router().route(path.as_str(), method_router);
816 }
817
818 pub fn add_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
820 where
821 H: Handler<T, S>,
822 T: 'static
823 {
824 let method_router = ProductOSRouter::convert_handler_to_method_router(method, handler, self.state.clone());
825 let path = ProductOSRouter::param_to_field(path);
826 self.router = self.take_router().route(path.as_str(), method_router);
827 }
828
829 pub fn add_state(&mut self, state: S) {
831 self.router = self.take_router().with_state(state);
832 }
833
834 pub fn set_fallback_handler<H, T>(&mut self, handler: H)
836 where
837 H: Handler<T, S>,
838 T: 'static
839 {
840 self.router = self.take_router().fallback(handler).with_state(self.state.clone());
841 }
842
843 #[cfg(feature = "cors")]
845 pub fn add_cors_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
846 where
847 H: Handler<T, S>,
848 T: 'static
849 {
850 let method_router = ProductOSRouter::add_cors_method_router(method, handler, self.state.clone());
851 let path = ProductOSRouter::param_to_field(path);
852 self.router = self.take_router().route(path.as_str(), method_router);
853 }
854
855 #[cfg(feature = "ws")]
857 pub fn add_ws_handler<H, T>(&mut self, path: &str, ws_handler: H)
858 where
859 H: Handler<T, S>,
860 T: 'static
861 {
862 let service: MethodRouter<S> = ProductOSRouter::convert_handler_to_method_router(Method::GET, ws_handler, self.state.clone());
864 let path = ProductOSRouter::param_to_field(path);
865 self.router = self.take_router().route(path.as_str(), service);
866 }
867
868 #[cfg(feature = "sse")]
870 pub fn add_sse_handler<H, T>(&mut self, path: &str, sse_handler: H)
871 where
872 H: Handler<T, S>,
873 T: 'static
874 {
875 let service: MethodRouter<S> = ProductOSRouter::convert_handler_to_method_router(Method::GET, sse_handler, self.state.clone());
877 let path = ProductOSRouter::param_to_field(path);
878 self.router = self.take_router().route(path.as_str(), service);
879 }
880
881 pub fn add_handlers<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
883 where
884 H: Handler<T, S>,
885 T: 'static
886 {
887 let mut method_router: MethodRouter<S> = MethodRouter::new();
888
889 for (method, handler) in handlers {
890 method_router = ProductOSRouter::add_handler_to_method_router(method_router, method, handler);
891 }
892
893 let path = ProductOSRouter::param_to_field(path);
894 self.router = self.take_router().route(path.as_str(), method_router);
895 }
896
897 #[cfg(feature = "cors")]
899 pub fn add_cors_handlers<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
900 where
901 H: Handler<T, S>,
902 T: 'static
903 {
904 let mut service: MethodRouter<S> = MethodRouter::new();
905
906 for (method, handler) in handlers {
907 service = ProductOSRouter::add_cors_handler_to_method_router(service, method, handler);
908 }
909
910 let path = ProductOSRouter::param_to_field(path);
911 self.router = self.take_router().route(path.as_str(), service);
912 }
913
914 #[cfg(feature = "cors")]
921 pub fn add_cors_middleware(&mut self) {
922 self.add_middleware(
923 CorsLayer::new()
924 .allow_origin(Any)
925 .allow_methods(Any)
926 .allow_headers(Any),
927 );
928 }
929
930 fn add_handler_to_method_router<H, T>(method_router: MethodRouter<S>, method: Method, handler: H) -> MethodRouter<S>
931 where
932 H: Handler<T, S>,
933 T: 'static
934 {
935 match method {
936 Method::GET => method_router.get(handler),
937 Method::POST => method_router.post(handler),
938 Method::PUT => method_router.put(handler),
939 Method::PATCH => method_router.patch(handler),
940 Method::DELETE => method_router.delete(handler),
941 Method::TRACE => method_router.trace(handler),
942 Method::HEAD => method_router.head(handler),
943 Method::CONNECT => method_router.get(handler),
945 Method::OPTIONS => method_router.options(handler),
946 Method::ANY => method_router.fallback(handler),
947 }
948 }
949
950 fn convert_handler_to_method_router<H, T>(method: Method, handler: H, state: S) -> MethodRouter<S>
951 where
952 H: Handler<T, S>,
953 T: 'static,
954 {
955 match method {
956 Method::GET => get(handler).with_state(state),
957 Method::POST => post(handler).with_state(state),
958 Method::PUT => put(handler).with_state(state),
959 Method::PATCH => patch(handler).with_state(state),
960 Method::DELETE => delete(handler).with_state(state),
961 Method::TRACE => trace(handler).with_state(state),
962 Method::HEAD => head(handler).with_state(state),
963 Method::CONNECT => get(handler).with_state(state),
965 Method::OPTIONS => options(handler).with_state(state),
966 Method::ANY => any(handler).with_state(state)
967 }
968 }
969
970 #[cfg(feature = "cors")]
971 fn add_cors_handler_to_method_router<H, T>(method_router: MethodRouter<S>, method: Method, handler: H) -> MethodRouter<S>
972 where
973 H: Handler<T, S>,
974 T: 'static
975 {
976 ProductOSRouter::add_handler_to_method_router(method_router, method, handler).layer(CorsLayer::new()
977 .allow_origin(Any)
978 .allow_methods(Any))
979 }
980
981 #[cfg(feature = "cors")]
982 fn add_cors_method_router<H, T>(method: Method, handler: H, state: S) -> MethodRouter<S>
983 where
984 H: Handler<T, S>,
985 T: 'static
986 {
987 ProductOSRouter::convert_handler_to_method_router(method, handler, state).layer(CorsLayer::new()
988 .allow_origin(Any)
989 .allow_methods(Any))
990 }
991}
992
993
994use std::ops::{Deref, DerefMut};
998
/// Transparent wrapper around a value of type `T`.
///
/// The inner value is public and also reachable through `Deref` / `DerefMut`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Extension<T>(pub T);

impl<T> Deref for Extension<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let Extension(inner) = self;
        inner
    }
}

impl<T> DerefMut for Extension<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let Extension(inner) = self;
        inner
    }
}
1035
1036#[async_trait::async_trait]
1037impl<T, S> FromRequestParts<S> for Extension<T>
1038where
1039 T: Clone + Send + Sync + 'static,
1040 S: Send + Sync,
1041{
1042 type Rejection = (product_os_http::StatusCode, &'static str);
1043
1044 fn from_request_parts(parts: &mut product_os_http::request::Parts, _state: &S) -> impl core::future::Future<Output = Result<Self, Self::Rejection>> + Send {
1045 let value = parts
1046 .extensions
1047 .get::<T>()
1048 .cloned()
1049 .map(Extension)
1050 .ok_or((
1051 product_os_http::StatusCode::INTERNAL_SERVER_ERROR,
1052 "Extension not found in request",
1053 ));
1054
1055 async move { value }
1056 }
1057}
1058
1059pub fn extension_layer<T>(value: T) -> tower::util::MapRequestLayer<impl Fn(Request<Body>) -> Request<Body> + Clone>
1062where
1063 T: Clone + Send + Sync + 'static,
1064{
1065 tower::util::MapRequestLayer::new(move |mut req: Request<Body>| {
1066 req.extensions_mut().insert(value.clone());
1067 req
1068 })
1069}
1070