product_os_router/
lib.rs

1#![no_std]
2extern crate no_std_compat as std;
3
4use std::prelude::v1::*;
5
6
7use std::collections::BTreeMap;
8
9pub use product_os_http::{
10    Request, Response,
11    request::Parts as RequestParts,
12    response::Parts as ResponseParts,
13    StatusCode,
14    header::{ HeaderMap, HeaderName, HeaderValue }
15};
16
17pub use product_os_http_body::{ BodyBytes, Bytes };
18
19/*
20pub use axum::http::{
21    Request,
22    request::Parts as RequestParts,
23    response::Parts as ResponseParts,
24    StatusCode,
25    header::{ HeaderMap, HeaderName, HeaderValue }
26};
27*/
28
29#[cfg(feature = "middleware")]
30pub use product_os_http_body::*;
31
32
33pub use axum::{
34    routing::*,
35    middleware::*,
36    handler::Handler,
37    response::{ IntoResponse, Redirect },
38    body::{ HttpBody, Body },
39    Router,
40    Json,
41    Form,
42    extract::{
43        *,
44        ws::{ WebSocketUpgrade, WebSocket, Message }
45    },
46    BoxError,
47    http::Uri
48};
49
50
51pub use crate::extractors::RequestMethod;
52
53
54
55// #[cfg(feature = "sessions")]
56// pub use axum_extra::extract::Query as QueryAdvanced;
57
58use axum::RequestExt;
59
60
61#[cfg(feature = "debug")]
62pub use axum_macros::debug_handler;
63
64/*
65pub use product_os_tower::{
66    Layer, Service, ServiceBuilder, ServiceExt,
67    make::Shared, util::service_fn, util::ServiceFn
68};
69*/
70
71pub use tower::{
72    Layer, Service, ServiceBuilder, ServiceExt, MakeService,
73    make::AsService, make::MakeConnection, make::IntoService,
74    make::Shared, make::future::SharedFuture,
75    util::service_fn, util::ServiceFn
76};
77
78pub use product_os_request::Method;
79use regex::Regex;
80
81
82#[cfg(feature = "middleware")]
83mod middleware;
84
85mod extractors;
86mod default_headers;
87mod service_wrapper;
88mod dual_protocol;
89
90pub use crate::default_headers::DefaultHeadersLayer;
91pub use crate::dual_protocol::UpgradeHttpLayer;
92
93
94#[cfg(feature = "middleware")]
95pub use crate::middleware::*;
96
97#[cfg(feature = "cors")]
98use tower_http::cors::CorsLayer;
99
100use crate::service_wrapper::{WrapperLayer, WrapperService};
101// use product_os_tower_http::cors::CorsLayer;
102
103
104
/// Failure categories reported by the router. Every variant carries the
/// message that is placed in the JSON error body.
#[derive(Debug)]
pub enum ProductOSRouterError {
    /// Problem with the request headers.
    Headers(String),
    /// Problem with the query string.
    Query(String),
    /// Problem with the request body.
    Body(String),
    /// Authentication failure.
    Authentication(String),
    /// Authorization failure.
    Authorization(String),
    /// Failure while processing the request.
    Process(String),
    /// The service is unavailable.
    Unavailable(String),
}
115
116impl ProductOSRouterError {
117    pub fn error_response(error: ProductOSRouterError, status_code: &StatusCode) -> Response<Body> {
118        let error_message = match error {
119            ProductOSRouterError::Headers(m) => { m }
120            ProductOSRouterError::Query(m) => { m }
121            ProductOSRouterError::Body(m) => { m }
122            ProductOSRouterError::Authentication(m) => { m }
123            ProductOSRouterError::Authorization(m) => { m }
124            ProductOSRouterError::Process(m) => { m }
125            ProductOSRouterError::Unavailable(m) => { m }
126        }.replace("\"", "'");
127
128        let status_code_u16 = status_code.as_u16();
129
130        Response::builder()
131            .status(status_code_u16)
132            .header("content-type", "application/json")
133            .body(
134                Body::from(
135                    format!("{{\
136                            \"error\": \"{error_message}\"\
137                        }}")))
138            .unwrap()
139    }
140
141    pub fn handle_error(error: &ProductOSRouterError) -> Response<Body> {
142        match error {
143            ProductOSRouterError::Headers(msg) => Response::builder()
144                .status(StatusCode::BAD_REQUEST.as_u16())
145                .body(Body::from(format!("{{\
146                    \"error\": \"{msg}\"\
147                    }}")))
148                .unwrap(),
149            ProductOSRouterError::Query(msg) => Response::builder()
150                .status(StatusCode::BAD_REQUEST.as_u16())
151                .body(Body::from(format!("{{\
152                    \"error\": \"{msg}\"\
153                    }}")))
154                .unwrap(),
155            ProductOSRouterError::Body(msg) => Response::builder()
156                .status(StatusCode::BAD_REQUEST.as_u16())
157                .body(Body::from(format!("{{\
158                    \"error\": \"{msg}\"\
159                    }}")))
160                .unwrap(),
161            ProductOSRouterError::Authentication(msg) => Response::builder()
162                .status(StatusCode::UNAUTHORIZED.as_u16())
163                .body(Body::from(format!("{{\
164                    \"error\": \"{msg}\"\
165                    }}")))
166                .unwrap(),
167            ProductOSRouterError::Authorization(msg) => Response::builder()
168                .status(StatusCode::FORBIDDEN.as_u16())
169                .body(Body::from(format!("{{\
170                    \"error\": \"{msg}\"\
171                    }}")))
172                .unwrap(),
173            ProductOSRouterError::Process(msg) => Response::builder()
174                .status(StatusCode::GATEWAY_TIMEOUT.as_u16())
175                .body(Body::from(format!("{{\
176                    \"error\": \"{msg}\"\
177                    }}")))
178                .unwrap(),
179            ProductOSRouterError::Unavailable(msg) => Response::builder()
180                .status(StatusCode::SERVICE_UNAVAILABLE.as_u16())
181                .body(Body::from(format!("{{\
182                    \"error\": \"{msg}\"\
183                    }}")))
184                .unwrap()
185        }
186    }
187}
188
189
190pub struct ProductOSRouter<S = ()> {
191    router: Router<S>,
192    state: S
193}
194
195impl ProductOSRouter<()> {
196    pub fn new() -> Self {
197        ProductOSRouter::new_with_state(())
198    }
199
200    pub fn param_to_field(path: &str) -> String {
201        let param_matcher = Regex::new("([:*])([a-zA-Z][-_a-zA-Z0-9]*)").unwrap();
202        let symbol_matcher = Regex::new("[*?&]").unwrap();
203
204        let mut path_new = path.to_owned();
205        for (value, [prefix, variable_name]) in param_matcher.captures_iter(path).map(|c| c.extract()) {
206            let mut value_sanitized = value.to_owned();
207            for (prefix, []) in symbol_matcher.captures_iter(value).map(|c| c.extract()) {
208                let mut symbol = String::from("[");
209                symbol.push_str(prefix);
210                symbol.push_str("]");
211
212                value_sanitized = value_sanitized.replace(prefix, symbol.as_str());
213            }
214
215            let exact_matcher = Regex::new(value_sanitized.as_str()).unwrap();
216
217            let mut variable = String::from("{");
218            variable.push_str(variable_name);
219            variable.push_str("}");
220
221            path_new = exact_matcher.replace(path_new.as_str(), variable).to_string()
222        }
223
224        path_new
225    }
226
227    pub fn add_service_no_state<Z>(&mut self, path: &str, service: Z)
228        where
229            Z: Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Clone + Send + Sync + 'static,
230            Z::Future: Send + 'static
231    {
232        let wrapper = WrapperService::new(service);
233        let path = ProductOSRouter::param_to_field(path);
234        self.router = self.router.clone().route_service(path.as_str(), wrapper);
235    }
236
237    pub fn add_middleware_no_state<L>(&mut self, middleware: L)
238        where
239            L: Layer<Route> + Clone + Send + Sync + 'static,
240            L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
241            <L::Service as Service<Request<Body>>>::Future: Send + 'static,
242            <L::Service as Service<Request<Body>>>::Error: Into<BoxError> + 'static,
243            <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
244    {
245        let wrapper = WrapperLayer::new(middleware);
246        self.router = self.router.clone().layer(wrapper)
247    }
248
249    pub fn add_default_header_no_state(&mut self, header_name: &str, header_value: &str) {
250        let mut default_headers = HeaderMap::new();
251        default_headers.insert(HeaderName::from_bytes(header_name.as_bytes()).unwrap(), HeaderValue::from_str(header_value).unwrap());
252
253        self.add_middleware_no_state(DefaultHeadersLayer::new(default_headers));
254    }
255
256    pub fn not_implemented() -> Response<Body> {
257        Response::builder()
258            .status(StatusCode::NOT_FOUND.as_u16())
259            .header("content-type", "application/json")
260            .body(Body::from("{}"))
261            .unwrap()
262    }
263
264
265
266    pub fn add_route_no_state(&mut self, path: &str, service_handler: MethodRouter)
267    {
268        let service= ServiceBuilder::new().service(service_handler);
269        let path = ProductOSRouter::param_to_field(path);
270        self.router = self.router.clone().route(path.as_str(), service);
271    }
272
273    pub fn set_fallback_no_state(&mut self, service_handler: MethodRouter)
274    {
275        let service= ServiceBuilder::new().service(service_handler);
276        self.router = self.router.clone().fallback(service);
277    }
278
279    pub fn add_get_no_state<H, T>(&mut self, path: &str, handler: H)
280        where
281            H: Handler<T, ()>,
282            T: 'static
283    {
284        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::GET, handler, self.state.clone());
285        let path = ProductOSRouter::param_to_field(path);
286        self.router = self.router.clone().route(path.as_str(), method_router);
287    }
288
289    pub fn add_post_no_state<H, T>(&mut self, path: &str, handler: H)
290        where
291            H: Handler<T, ()>,
292            T: 'static
293    {
294        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::POST, handler, self.state.clone());
295        let path = ProductOSRouter::param_to_field(path);
296        self.router = self.router.clone().route(path.as_str(), method_router);
297    }
298
299    pub fn add_handler_no_state<H, T>(&mut self, path: &str, method: Method, handler: H)
300        where
301            H: Handler<T, ()>,
302            T: 'static
303    {
304        let method_router = ProductOSRouter::convert_handler_to_method_router(method, handler, self.state.clone());
305        let path = ProductOSRouter::param_to_field(path);
306        self.router = self.router.clone().route(path.as_str(), method_router);
307    }
308
309    pub fn set_fallback_handler_no_state<H, T>(&mut self, handler: H)
310        where
311            H: Handler<T, ()>,
312            T: 'static
313    {
314        self.router = self.router.clone().fallback(handler).with_state(self.state.clone());
315    }
316
317    #[cfg(feature = "cors")]
318    pub fn add_cors_handler_no_state<H, T>(&mut self, path: &str, method: Method, handler: H)
319        where
320            H: Handler<T, ()>,
321            T: 'static
322    {
323        let method_router = ProductOSRouter::add_cors_method_router(method, handler, self.state.clone());
324        let path = ProductOSRouter::param_to_field(path);
325        self.router = self.router.clone().route(path.as_str(), method_router);
326    }
327
328    #[cfg(feature = "ws")]
329    pub fn add_ws_handler_no_state<H, T>(&mut self, path: &str, ws_handler: H)
330        where
331            H: Handler<T, ()>,
332            T: 'static
333    {
334        // https://docs.rs/axum/latest/axum/extract/ws/index.html
335        let service: MethodRouter = ProductOSRouter::convert_handler_to_method_router(Method::GET, ws_handler, self.state.clone());
336        let path = ProductOSRouter::param_to_field(path);
337        self.router = self.router.clone().route(path.as_str(), service);
338    }
339
340    #[cfg(feature = "sse")]
341    pub fn add_sse_handler_no_state<H, T>(&mut self, path: &str, sse_handler: H)
342        where
343            H: Handler<T, ()>,
344            T: 'static
345    {
346        // https://docs.rs/axum/latest/axum/response/sse/index.html
347        let service: MethodRouter = ProductOSRouter::convert_handler_to_method_router(Method::GET, sse_handler, self.state.clone());
348        let path = ProductOSRouter::param_to_field(path);
349        self.router = self.router.clone().route(path.as_str(), service);
350    }
351
352    pub fn add_handlers_no_state<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
353        where
354            H: Handler<T, ()>,
355            T: 'static
356    {
357        let mut method_router: MethodRouter = MethodRouter::new();
358
359        for (method, handler) in handlers {
360            method_router = ProductOSRouter::add_handler_to_method_router(method_router, method, handler);
361        }
362
363        let path = ProductOSRouter::param_to_field(path);
364        self.router = self.router.clone().route(path.as_str(), method_router);
365    }
366
367    #[cfg(feature = "cors")]
368    pub fn add_cors_handlers_no_state<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
369        where
370            H: Handler<T, ()>,
371            T: 'static
372    {
373        let mut service: MethodRouter = MethodRouter::new();
374
375        for (method, handler) in handlers {
376            service = ProductOSRouter::add_cors_handler_to_method_router(service, method, handler);
377        }
378
379        let path = ProductOSRouter::param_to_field(path);
380        self.router = self.router.clone().route(path.as_str(), service);
381    }
382
383    fn add_handler_to_method_router_no_state<H, T>(method_router: MethodRouter, method: Method, handler: H) -> MethodRouter
384        where
385            H: Handler<T, ()>,
386            T: 'static
387    {
388        match method {
389            Method::GET => method_router.get(handler),
390            Method::POST => method_router.post(handler),
391            Method::PUT => method_router.put(handler),
392            Method::PATCH => method_router.patch(handler),
393            Method::DELETE => method_router.delete(handler),
394            Method::TRACE => method_router.trace(handler),
395            Method::HEAD => method_router.head(handler),
396            Method::CONNECT => method_router.get(handler),
397            Method::OPTIONS => method_router.get(handler),
398            Method::ANY => method_router.get(handler),
399        }
400    }
401
402    fn convert_handler_to_method_router_no_state<H, T>(method: Method, handler: H) -> MethodRouter
403        where
404            H: Handler<T, ()>,
405            T: 'static,
406    {
407        match method {
408            Method::GET => get(handler),
409            Method::POST => post(handler),
410            Method::PUT => put(handler),
411            Method::PATCH => patch(handler),
412            Method::DELETE => delete(handler),
413            Method::TRACE => trace(handler),
414            Method::HEAD => head(handler),
415            Method::CONNECT => get(handler),
416            Method::OPTIONS => get(handler),
417            Method::ANY => any(handler)
418        }
419    }
420
421    #[cfg(feature = "cors")]
422    fn add_cors_handler_to_method_router_no_state<H, T>(method_router: MethodRouter, method: Method, handler: H) -> MethodRouter
423        where
424            H: Handler<T, ()>,
425            T: 'static
426    {
427        ProductOSRouter::add_handler_to_method_router_no_state(method_router, method, handler).layer(CorsLayer::new()
428            .allow_origin(tower_http::cors::any())
429            .allow_methods(tower_http::cors::any()))
430    }
431
432    #[cfg(feature = "cors")]
433    fn add_cors_method_router_no_state<H, T>(method: Method, handler: H) -> MethodRouter
434        where
435            H: Handler<T, ()>,
436            T: 'static
437    {
438        ProductOSRouter::convert_handler_to_method_router_no_state(method, handler).layer(CorsLayer::new()
439            .allow_origin(tower_http::cors::any())
440            .allow_methods(tower_http::cors::any()))
441    }
442}
443
444impl<S> ProductOSRouter<S>
445where
446    S: Clone + Send + Sync + 'static
447{
448    pub fn get_router(&self) -> Router {
449        self.router.clone().with_state(self.state.clone())
450    }
451
    /// Gives mutable access to the inner axum router, without any state
    /// bound to it.
    pub fn get_router_mut(&mut self) -> &mut Router<S> {
        &mut self.router
    }
455
456    pub fn new_with_state(state: S) -> Self
457    where
458        S: Clone + Send + 'static
459    {
460        Self {
461            router: Router::new().with_state(state.clone()),
462            state
463        }
464    }
465
466    pub fn add_service<Z>(&mut self, path: &str, service: Z)
467        where
468            Z: Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Clone + Send + Sync + 'static,
469            Z::Response: IntoResponse,
470            Z::Future: Send + 'static
471    {
472        let wrapper = WrapperService::new(service);
473        let path = ProductOSRouter::param_to_field(path);
474        self.router = self.router.clone().route_service(path.as_str(), wrapper);
475    }
476
477    pub fn add_middleware<L>(&mut self, middleware: L)
478        where
479            L: Layer<Route> + Clone + Send + Sync + 'static,
480            L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
481            <L::Service as Service<Request<Body>>>::Future: Send + 'static,
482            <L::Service as Service<Request<Body>>>::Error: Into<BoxError> + 'static,
483            <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
484    {
485        let wrapper = WrapperLayer::new(middleware);
486        self.router = self.router.clone().layer(wrapper)
487    }
488
489    pub fn add_default_header(&mut self, header_name: String, header_value: String) {
490        let mut default_headers = HeaderMap::new();
491        default_headers.insert(HeaderName::from_bytes(header_name.as_bytes()).unwrap(), HeaderValue::from_str(header_value.as_str()).unwrap());
492
493        self.add_middleware(DefaultHeadersLayer::new(default_headers));
494    }
495
496
497
498    pub fn add_route(&mut self, path: &str, service_handler: MethodRouter<S>)
499    where
500        S: Clone + Send + 'static,
501    {
502        let service= ServiceBuilder::new().service(service_handler);
503        let path = ProductOSRouter::param_to_field(path);
504        self.router = self.router.clone().route(path.as_str(), service);
505    }
506
507    pub fn set_fallback(&mut self, service_handler: MethodRouter<S>)
508    where
509        S: Clone + Send + 'static
510    {
511        let service= ServiceBuilder::new().service(service_handler);
512        self.router = self.router.clone().fallback(service);
513    }
514
515    pub fn add_get<H, T>(&mut self, path: &str, handler: H)
516    where
517        H: Handler<T, S>,
518        T: 'static
519    {
520        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::GET, handler, self.state.clone());
521        let path = ProductOSRouter::param_to_field(path);
522        self.router = self.router.clone().route(path.as_str(), method_router);
523    }
524
525    pub fn add_post<H, T>(&mut self, path: &str, handler: H)
526    where
527        H: Handler<T, S>,
528        T: 'static
529    {
530        let method_router = ProductOSRouter::convert_handler_to_method_router(Method::POST, handler, self.state.clone());
531        let path = ProductOSRouter::param_to_field(path);
532        self.router = self.router.clone().route(path.as_str(), method_router);
533    }
534
535    pub fn add_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
536    where
537        H: Handler<T, S>,
538        T: 'static
539    {
540        let method_router = ProductOSRouter::convert_handler_to_method_router(method, handler, self.state.clone());
541        let path = ProductOSRouter::param_to_field(path);
542        self.router = self.router.clone().route(path.as_str(), method_router);
543    }
544
    /// Calls `with_state(state)` on a clone of the inner router and stores
    /// the result as the new inner router.
    ///
    /// NOTE(review): the wrapper's own `state` field is not updated here, so
    /// `get_router` and the handler registration methods keep using the
    /// state passed at construction - confirm this is intended.
    pub fn add_state(&mut self, state: S) {
        self.router = self.router.clone().with_state(state);
    }
548
549    pub fn set_fallback_handler<H, T>(&mut self, handler: H)
550    where
551        H: Handler<T, S>,
552        T: 'static
553    {
554        self.router = self.router.clone().fallback(handler).with_state(self.state.clone());
555    }
556
557    #[cfg(feature = "cors")]
558    pub fn add_cors_handler<H, T>(&mut self, path: &str, method: Method, handler: H)
559    where
560        H: Handler<T, S>,
561        T: 'static
562    {
563        let method_router = ProductOSRouter::add_cors_method_router(method, handler, self.state.clone());
564        let path = ProductOSRouter::param_to_field(path);
565        self.router = self.router.clone().route(path.as_str(), method_router);
566    }
567
568    #[cfg(feature = "ws")]
569    pub fn add_ws_handler<H, T>(&mut self, path: &str, ws_handler: H)
570    where
571        H: Handler<T, S>,
572        T: 'static
573    {
574        // https://docs.rs/axum/latest/axum/extract/ws/index.html
575        let service: MethodRouter<S> = ProductOSRouter::convert_handler_to_method_router(Method::GET, ws_handler, self.state.clone());
576        let path = ProductOSRouter::param_to_field(path);
577        self.router = self.router.clone().route(path.as_str(), service);
578    }
579
580    #[cfg(feature = "sse")]
581    pub fn add_sse_handler<H, T>(&mut self, path: &str, sse_handler: H)
582        where
583            H: Handler<T, S>,
584            T: 'static
585    {
586        // https://docs.rs/axum/latest/axum/response/sse/index.html
587        let service: MethodRouter<S> = ProductOSRouter::convert_handler_to_method_router(Method::GET, sse_handler, self.state.clone());
588        let path = ProductOSRouter::param_to_field(path);
589        self.router = self.router.clone().route(path.as_str(), service);
590    }
591
592    pub fn add_handlers<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
593    where
594        H: Handler<T, S>,
595        T: 'static
596    {
597        let mut method_router: MethodRouter<S> = MethodRouter::new();
598
599        for (method, handler) in handlers {
600            method_router = ProductOSRouter::add_handler_to_method_router(method_router, method, handler);
601        }
602
603        let path = ProductOSRouter::param_to_field(path);
604        self.router = self.router.clone().route(path.as_str(), method_router);
605    }
606
607    #[cfg(feature = "cors")]
608    pub fn add_cors_handlers<H, T>(&mut self, path: &str, handlers: BTreeMap<Method, H>)
609    where
610        H: Handler<T, S>,
611        T: 'static
612    {
613        let mut service: MethodRouter<S> = MethodRouter::new();
614
615        for (method, handler) in handlers {
616            service = ProductOSRouter::add_cors_handler_to_method_router(service, method, handler);
617        }
618
619        let path = ProductOSRouter::param_to_field(path);
620        self.router = self.router.clone().route(path.as_str(), service);
621    }
622
623    fn add_handler_to_method_router<H, T>(method_router: MethodRouter<S>, method: Method, handler: H) -> MethodRouter<S>
624    where
625        H: Handler<T, S>,
626        T: 'static
627    {
628        match method {
629            Method::GET => method_router.get(handler),
630            Method::POST => method_router.post(handler),
631            Method::PUT => method_router.put(handler),
632            Method::PATCH => method_router.patch(handler),
633            Method::DELETE => method_router.delete(handler),
634            Method::TRACE => method_router.trace(handler),
635            Method::HEAD => method_router.head(handler),
636            Method::CONNECT => method_router.get(handler),
637            Method::OPTIONS => method_router.get(handler),
638            Method::ANY => method_router.get(handler),
639        }
640    }
641
642    fn convert_handler_to_method_router<H, T>(method: Method, handler: H, state: S) -> MethodRouter<S>
643    where
644        H: Handler<T, S>,
645        T: 'static,
646    {
647        match method {
648            Method::GET => get(handler).with_state(state),
649            Method::POST => post(handler).with_state(state),
650            Method::PUT => put(handler).with_state(state),
651            Method::PATCH => patch(handler).with_state(state),
652            Method::DELETE => delete(handler).with_state(state),
653            Method::TRACE => trace(handler).with_state(state),
654            Method::HEAD => head(handler).with_state(state),
655            Method::CONNECT => get(handler).with_state(state),
656            Method::OPTIONS => get(handler).with_state(state),
657            Method::ANY => any(handler).with_state(state)
658        }
659    }
660
661    #[cfg(feature = "cors")]
662    fn add_cors_handler_to_method_router<H, T>(method_router: MethodRouter<S>, method: Method, handler: H) -> MethodRouter<S>
663    where
664        H: Handler<T, S>,
665        T: 'static
666    {
667        ProductOSRouter::add_handler_to_method_router(method_router, method, handler).layer(CorsLayer::new()
668            .allow_origin(tower_http::cors::any())
669            .allow_methods(tower_http::cors::any()))
670    }
671
672    #[cfg(feature = "cors")]
673    fn add_cors_method_router<H, T>(method: Method, handler: H, state: S) -> MethodRouter<S>
674    where
675        H: Handler<T, S>,
676        T: 'static
677    {
678        ProductOSRouter::convert_handler_to_method_router(method, handler, state).layer(CorsLayer::new()
679            .allow_origin(tower_http::cors::any())
680            .allow_methods(tower_http::cors::any()))
681    }
682}
683
684
685
686
687
688/*
689fn add_middleware<L>(router: Router, middleware: L) -> Router
690    where
691        L: Layer<Route<Body>>,
692        L::Service: Service<Request<Body>, Response = Response<Body>, Error = BoxError> + Clone + Send + 'static,
693        <L::Service as Service<Request<Body>>>::Future: Send + 'static
694{
695    router.layer(middleware)
696}
697*/
698
699
700
701
702
703/*
704pub fn add_get<H, T, B>(handler: H) -> MethodRouter
705    where
706        H: Handler<T, B>,
707        B: Send + 'static,
708        T: 'static
709{
710    ServiceBuilder::new().service(get(handler))
711}
712
713
714pub fn add_route(router: &mut Router, path: &str, service_handler: MethodRouter) {
715    let service= ServiceBuilder::new().service(service_handler);
716    *router = router.clone().route(path, service);
717}
718 */
719
720/*
721
722WS ----
723
724use axum::{
725    extract::ws::{WebSocketUpgrade, WebSocket},
726    routing::get,
727    response::IntoResponse,
728    Router,
729};
730
731let app = Router::new().route("/ws", get(handler));
732
733async fn handler(ws: WebSocketUpgrade) -> impl IntoResponse {
734    ws.on_upgrade(handle_socket)
735}
736
737async fn handle_socket(mut socket: WebSocket) {
738    while let Some(msg) = socket.recv().await {
739        let msg = if let Ok(msg) = msg {
740            msg
741        } else {
742            // client disconnected
743            return;
744        };
745
746        if socket.send(msg).await.is_err() {
747            // client disconnected
748            return;
749        }
750    }
751}
752
753-- include state for WS
754
755use axum::{
756    extract::{ws::{WebSocketUpgrade, WebSocket}, State},
757    response::Response,
758    routing::get,
759    Router,
760};
761
762#[derive(Clone)]
763struct AppState {
764    // ...
765}
766
767async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> Response {
768    ws.on_upgrade(|socket| handle_socket(socket, state))
769}
770
771async fn handle_socket(socket: WebSocket, state: AppState) {
772    // ...
773}
774
775let app = Router::new()
776    .route("/ws", get(handler))
777    .with_state(AppState { /* ... */ });
778
779-- concurrent writes with WS
780
781use axum::{Error, extract::ws::{WebSocket, Message}};
782use futures::{sink::SinkExt, stream::{StreamExt, SplitSink, SplitStream}};
783
784async fn handle_socket(mut socket: WebSocket) {
785    let (mut sender, mut receiver) = socket.split();
786
787    tokio::spawn(write(sender));
788    tokio::spawn(read(receiver));
789}
790
791async fn read(receiver: SplitStream<WebSocket>) {
792    // ...
793}
794
795async fn write(sender: SplitSink<WebSocket, Message>) {
796    // ...
797}
798
799
800SSE ----
801
802use axum::{
803    Router,
804    routing::get,
805    response::sse::{Event, KeepAlive, Sse},
806};
807use std::{time::Duration, convert::BoxError};
808use tokio_stream::StreamExt as _ ;
809use futures::stream::{self, Stream};
810
811let app = Router::new().route("/sse", get(sse_handler));
812
813async fn sse_handler() -> Sse<impl Stream<Item = Result<Event, BoxError>>> {
814    // A `Stream` that repeats an event every second
815    let stream = stream::repeat_with(|| Event::default().data("hi!"))
816        .map(Ok)
817        .throttle(Duration::from_secs(1));
818
819    Sse::new(stream).keep_alive(KeepAlive::default())
820}
821
822 */